├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── assets ├── caps-megadepth.png ├── caps-scannet.png ├── fcgf-3dmatch.png ├── fpfh-3dmatch.png └── overview.png └── code ├── README.md ├── dataset ├── README.md ├── base.py ├── megadepth_sgp.py ├── megadepth_test.py ├── megadepth_train.py ├── threedmatch_sgp.py ├── threedmatch_test.py └── threedmatch_train.py ├── geometry ├── common.py ├── image.py └── pointcloud.py ├── perception2d ├── adaptor.py ├── config_sgp.yml ├── config_sgp_sample.yml ├── config_test.yml ├── config_train.yml ├── sgp.py ├── test.py └── train.py ├── perception3d ├── adaptor.py ├── config_sgp.yml ├── config_sgp_sample.yml ├── config_test.yml ├── config_train.yml ├── sgp.py ├── test.py └── train.py └── sgp_base.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | pseudo-label/ 141 | 3dmatch_train/ 142 | logs/ 143 | out/ 144 | outputs/ 145 | caps_logs/ 146 | caps_outputs/ 147 | fcgf_outputs/ 148 | caps_pseudo_label/ 149 | fcgf_pseudo_label/ 150 | *.npz 151 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "code/ext/FCGF"] 2 | path = code/ext/FCGF 3 | url = https://github.com/chrischoy/FCGF.git 4 | [submodule "code/ext/caps"] 5 | path = code/ext/caps 6 | url = https://github.com/qianqianwang68/caps.git 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Wei Dong and Heng Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SGP: Self-supervised Geometric Perception 2 | [CVPR 2021 Oral] Self-supervised Geometric Perception 3 | https://arxiv.org/abs/2103.03114 4 | 5 | ## Introduction 6 | In short, SGP is, to the best of our knowledge, the first general framework for feature learning in geometric perception without any supervision from ground-truth geometric labels. 7 | 8 | SGP runs in an EM fashion. It iteratively performs robust estimation of the geometric models to generate pseudo-labels, and feature learning under the supervision of the noisy pseudo-labels. 
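The sketch below illustrates this loop in pseudo-code (names are illustrative only, not the actual API of this repository): a robust solver acts as the teacher that generates and verifies pseudo-labels, and the deep feature is the student trained on them.

```python
# Pseudo-code sketch of the SGP teacher-student loop (illustrative names only).
feature = handcrafted_feature                # bootstrap teacher: SIFT / FPFH
for meta_iter in range(max_meta_iters):
    pseudo_labels = []
    for pair in unlabeled_pairs:
        model, inliers = robust_solver(feature, pair)  # e.g., RANSAC
        if verified(model, inliers):                   # keep confident estimates only
            pseudo_labels.append((pair, model))
    feature = train_deep_feature(pseudo_labels)        # student: CAPS / FCGF
```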
9 | 10 | overview 11 | 12 | 13 | 14 | We applied SGP to camera pose estimation and point cloud registration, demonstrating performance that is on par with or even superior to supervised oracles on large-scale real datasets. 15 | 16 | ### Camera pose estimation 17 | 18 | Deep image features like [CAPS](https://github.com/qianqianwang68/caps) can be trained with relative pose labels generated by 5pt-RANSAC, bootstrapped with the handcrafted SIFT feature. They can later be used in robust relative camera pose estimation. 19 | 20 |
21 | 22 | 23 |
24 | 25 | ### Point cloud registration 26 | 27 | Deep 3D features like [FCGF](https://github.com/chrischoy/FCGF) can be trained with relative pose labels generated by 3pt-RANSAC, bootstrapped with the handcrafted FPFH feature. They can later be used in robust point cloud registration. 28 | 29 |
30 | 31 | 32 |
33 | 34 | 35 | 36 | ## Code 37 | 38 | Please see `code/` for detailed instructions about how to use the code base. 39 | 40 | 41 | 42 | ## Citation 43 | 44 | ``` 45 | @inproceedings{yang2021sgp, 46 | title={Self-supervised Geometric Perception}, 47 | author={Yang, Heng and Dong, Wei and Carlone, Luca and Koltun, Vladlen}, 48 | booktitle={CVPR}, 49 | year={2021} 50 | } 51 | ``` 52 | 53 | -------------------------------------------------------------------------------- /assets/caps-megadepth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theNded/SGP/63d33cc8bffde53676d9c4800f4b11804b53b360/assets/caps-megadepth.png -------------------------------------------------------------------------------- /assets/caps-scannet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theNded/SGP/63d33cc8bffde53676d9c4800f4b11804b53b360/assets/caps-scannet.png -------------------------------------------------------------------------------- /assets/fcgf-3dmatch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theNded/SGP/63d33cc8bffde53676d9c4800f4b11804b53b360/assets/fcgf-3dmatch.png -------------------------------------------------------------------------------- /assets/fpfh-3dmatch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theNded/SGP/63d33cc8bffde53676d9c4800f4b11804b53b360/assets/fpfh-3dmatch.png -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theNded/SGP/63d33cc8bffde53676d9c4800f4b11804b53b360/assets/overview.png -------------------------------------------------------------------------------- /code/README.md: -------------------------------------------------------------------------------- 1 | # Self-supervised Geometric Perception 2 | 3 | ## Disclaimer 4 | Compared to the code for the paper submission, this repository has been fully rewritten for better readability and easier generalization. Please file a GitHub issue if you find anything buggy. 5 | 6 | Since the final benchmark results depend on RANSAC (used for robust model estimation), we expect minor discrepancies compared to the numbers published in the paper (due to the randomness of RANSAC). Again, please submit an issue if a significant difference is observed. 7 | 8 | ### TODO 9 | - [ ] Release pretrained weights. 10 | 11 | ## Setup 12 | Clone the project by 13 | ``` 14 | git clone --recursive https://github.com/theNded/SGP.git 15 | ``` 16 | This will by default clone the submodules [FCGF](https://github.com/chrischoy/FCGF) and [CAPS](https://github.com/qianqianwang68/caps) for 3D and 2D perception, respectively. Please follow the instructions in the corresponding repositories to configure the submodule(s) of interest. 17 | 18 | ## Datasets 19 | For the 3D perception task, please download the [3DMatch dataset](https://drive.google.com/file/d/1P5xS4ZGrmuoElZbKeC6bM5UoWz8H9SL1/view) reorganized by us, which aggregates point clouds by scene. The reorganized [test set](https://drive.google.com/file/d/1AmmADbhk5X62Q6CnsbJcwm1BK0Uov1yG/view?usp=sharing) is also available.
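Judging from the loaders in `code/dataset/threedmatch_*.py`, each training scene folder is expected to contain per-fragment `.npz` point clouds (with a `pcd` array) plus an `overlaps.txt` listing overlapping pairs. The sketch below is inferred from that code; the file names are only illustrative:

```
<3dmatch_train_root>/
|_ <scene_name>/
   |_ cloud_bin_0.npz    # np.load(...)['pcd'] gives an (N, 3) point array
   |_ cloud_bin_1.npz
   |_ ...
   |_ overlaps.txt       # each line: <fname_src> <fname_dst> <overlap in [0, 1]>
```

The test scenes instead hold `.ply` fragments together with a `<scene>-evaluation/gt.log` file used for evaluation.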
20 | 21 | For the 2D perception task, please download the [MegaDepth dataset](https://drive.google.com/file/d/1-o4TRLx6qm8ehQevV7nExmVJXfMxj657/view) provided by the authors of CAPS. The test set has not been officially released, so please contact the [CAPS authors](https://github.com/qianqianwang68/caps) for the data. We only provide the data loader. 22 | 23 | ## Vanilla training and testing 24 | Copy and/or modify the `config_[train|test].yml` files in `perception3d`. The configurable parameters can be found in `perception3d/adaptor.py`. Then run 25 | ``` 26 | python perception3d/train.py --config /path/to/config.yml 27 | python perception3d/test.py --config /path/to/config.yml --weights /path/to/weights.pth 28 | ``` 29 | You may also add `--debug` to visualize the registration/alignment results. The same applies to 2D. 30 | 31 | For a sanity check, you may first use the pretrained weights of the deep features (i.e., the supervised oracles) that are available on the corresponding websites/GitHub repos. The system should be able to run seamlessly. 32 | 33 | Note that our codebase is non-intrusive, i.e., the original repositories are not modified; hence there are minor inconsistencies in configurations between 2D and 3D. For instance, pretrained weights are named `weights` for FCGF and `ckpt_path` for CAPS. Please carefully check the corresponding config options in `adaptor.py`. 34 | 35 | 36 | ## Self-supervised training 37 | The training runs in teacher-student meta loops, starting with a bootstrap step (`bs`) supervised by hand-crafted features (SIFT/FPFH), followed by the actual training loops (`00`, `01`) that train a deep feature (CAPS/FCGF) with itself as the teacher. After similarly configuring `config_sgp.yml`, run 38 | ``` 39 | python perception3d/sgp.py --config /path/to/config.yml 40 | ``` 41 | As the SGP process is time-consuming, it is suggested to first perform a sanity check on a minimal set of data, configured in `config_sgp_sample.yml`. 42 | 43 | To test the results per meta-iteration, by default run 44 | ```shell 45 | # 2D 46 | python perception2d/test.py --config perception2d/config_test.yml --ckpt_path caps_outputs/bs/caps_sgp/040000.pth 47 | # 3D 48 | python perception3d/test.py --config perception3d/config_test.yml --weights fcgf_outputs/bs/checkpoint.pth 49 | ``` 50 | for the feature trained in the bootstrap step (`bs`), and 51 | ```shell 52 | # 2D 53 | python perception2d/test.py --config perception2d/config_test.yml --ckpt_path caps_outputs/00/caps_sgp/040000.pth 54 | # 3D 55 | python perception3d/test.py --config perception3d/config_test.yml --weights fcgf_outputs/00/checkpoint.pth 56 | ``` 57 | for the feature trained in the 0-th meta-iteration (`00`) and the following meta-iterations. 58 | 59 | To restart or extend the current meta-iterations, change `restart_meta_iter` and `max_meta_iters` in the configuration. 60 | 61 | ## Extension 62 | To use your own dataset organized by scenes, check out `dataset/`. A README details how the datasets are organized and how you may extend the base class to parse your scenes (a minimal sketch is given at the end of this file). 63 | 64 | To train your own deep feature, check out `sgp_base.py` and the corresponding `perception2d/` or `perception3d/` files. They share a similar interface for the `bootstrap` teaching-learning and the `iterative` self-supervised teaching-learning.
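As a minimal sketch of the dataset extension mentioned above (hypothetical file layout and loading; the authoritative contract is `dataset/base.py`), a custom dataset only needs to describe its scenes and how to load a single file:

```python
import os
import numpy as np

from dataset.base import DatasetBase


class MyDataset(DatasetBase):
    # Describe one scene: file names, overlapping pairs, and optional metadata.
    def parse_scene(self, root, scene):
        scene_path = os.path.join(root, scene)
        fnames = sorted(os.listdir(scene_path))                # hypothetical: one data file per frame
        pairs = [(i, i + 1) for i in range(len(fnames) - 1)]   # hypothetical: consecutive frames overlap
        return {
            'folder': scene,
            'fnames': fnames,
            'pairs': pairs,
            'unary_info': [None] * len(fnames),
            'binary_info': [None] * len(pairs),
        }

    # Turn a file name into the actual data (an image, a point cloud, ...).
    def load_data(self, folder, fname):
        return np.load(os.path.join(self.root, folder, fname))
```

The SGP datasets in `dataset/megadepth_sgp.py` and `dataset/threedmatch_sgp.py` reuse the same hooks, writing pseudo-labels in the teaching phase and reading them back in the learning phase.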
65 | -------------------------------------------------------------------------------- /code/dataset/README.md: -------------------------------------------------------------------------------- 1 | # Dataloader 2 | 3 | The overall target of a dataloader of SGP is to provide the loader of tuples: 4 | ```python 5 | def __getitem__(self, idx): 6 | # some processing 7 | return data_src, data_dst, info_src, info_dst, info_pair 8 | ``` 9 | where each tuple contains 10 | - `data_src`, `data_dst`: image for 2D perception, point cloud for 3D perception. 11 | - `info`: additional properties, e.g. (unary) intrinsics for one image, (mutual) overlaps between two point clouds. They do not directly provide the supervision, but may serve as very weak supervision signals in geometry perception tasks. 12 | 13 | 14 | As SGP works on pairs of data with overlaps in a scene, we assume a large dataset is consisting of various smaller scenes where overlaps exist: 15 | ``` 16 | root/ 17 | |_ scene_0/ 18 | |_ data_0 19 | |_ data_1 20 | |_ ... 21 | |_ data_n 22 | |_ pairs.txt 23 | |_ metadata.txt 24 | |_ scene_1/ 25 | |_ ... 26 | |_ scene_m/ 27 | ``` 28 | Here, the root folder contains `m` scenes. Each scene includes `n` data files. 29 | 30 | Assuming we have some prior knowledge of the rough overlaps between data, a scene can also provide a file storing pair associations in pair.txt: 31 | ``` 32 | data_0 data_2 33 | data_0 data_8 34 | data_1 data_3 35 | ... 36 | ``` 37 | Otherwise a random selection will be applied. It is strongly recommended to specify a `pair.txt` to ensure valid self supervision. 38 | 39 | Optionally, `metadata.txt` could be provided for more info. For instance, image-wise intrinsic matrix could be provided per image, where the perception task uses the geometry model to estimate extrinsic matrix between frames: 40 | ``` 41 | data_0 fx_0 fy_0 cx_0 cy_0 42 | data_1 fx_1 fy_1 cx_1 cy_1 43 | ... 44 | ``` 45 | 46 | So the intermediate interface will be based on scenes: 47 | ```python 48 | def parse_scene(self, scene): 49 | # some processing 50 | return {'folder': scene, # str 51 | 'fnames': fnames, # len == n, list of str 52 | 'pairs': pairs, # len == m, list of (i, j) tuple 53 | # Optionally metadata 54 | 'unary_metadata' : unary_metadata, # len == n, list of object 55 | 'binary_metadata': mutual_metadata # len == m, list of object 56 | } 57 | ``` 58 | A list of such `scene`s construct the data field, where `collect_scenes` call `parse_scene`: 59 | ```python 60 | def __init__(self, root, scenes): 61 | self.root = root 62 | self.scenes = self.collect_scenes(root, scenes) 63 | ``` 64 | Now data length is given by the sum of `len(scene['pairs'])`, and the get item function is separated to get the scene id then the pair id, with a map array (details ommitted). 65 | ```python 66 | def __getitem__(self, idx): 67 | # Use the LUT 68 | scene_idx = self.scene_idx_map[idx] 69 | pair_idx = self.pair_idx_map[idx] 70 | 71 | # Access actual data 72 | scene = self.scenes[scene_idx] 73 | folder = scene['folder'] 74 | 75 | i, j = scene['pairs'][pair_idx] 76 | fname_src = scene['fnames'][i] 77 | fname_dst = scene['fnames'][j] 78 | 79 | print(i, j, fname_src, fname_dst) 80 | 81 | data_src = self.load_data(folder, fname_src) 82 | data_dst = self.load_data(folder, fname_dst) 83 | 84 | # Optional. 
Could be None 85 | metadata_src = scene['unary_metadata'][i] 86 | metadata_dst = scene['unary_metadata'][j] 87 | metadata_pair = scene['binary_metadata'][pair_idx] 88 | 89 | return data_src, data_dst, metadata_src, metadata_dst, metadata_pair 90 | ``` 91 | 92 | In reality, there could be minor changes in the dataset structure. For instance, there could be subfolders in a scene, and the corresponding `pairs.txt` and `metadata.txt` are renamed and outside the data folder. 93 | ``` 94 | root/ 95 | |_ scene_0/ 96 | |_ day/ 97 | |_ images/ 98 | |_ data_0.jpg 99 | |_ data_1.jpg 100 | |_ ... 101 | |_ pairs.txt 102 | |_ cameras.txt 103 | |_ night/ 104 | |_ images/ 105 | |_ data_0.jpg 106 | |_ data_1.jpg 107 | |_ ... 108 | |_ pairs.txt 109 | |_ cameras.txt 110 | ``` 111 | In this case, we only need to override `parse_scene` to re-interpret the low level structure, and override `collect_scenes` to collate various subscenes from a scene. -------------------------------------------------------------------------------- /code/dataset/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | 5 | class DatasetBase: 6 | def __init__(self, root, scenes): 7 | self.root = root 8 | self.scenes = self.collect_scenes(root, scenes) 9 | 10 | scene_ids = [] 11 | pair_ids = [] 12 | 13 | for i, scene in enumerate(self.scenes): 14 | num_pairs = len(scene['pairs']) 15 | scene_ids.append(np.ones((num_pairs), dtype=np.int) * i) 16 | pair_ids.append(np.arange(0, num_pairs, dtype=np.int)) 17 | 18 | self.scene_idx_map = np.concatenate(scene_ids) 19 | self.pair_idx_map = np.concatenate(pair_ids) 20 | 21 | def __len__(self): 22 | return len(self.scene_idx_map) 23 | 24 | def __getitem__(self, idx): 25 | # Use the LUT 26 | scene_idx = self.scene_idx_map[idx] 27 | pair_idx = self.pair_idx_map[idx] 28 | 29 | # Access actual data 30 | scene = self.scenes[scene_idx] 31 | folder = scene['folder'] 32 | 33 | i, j = scene['pairs'][pair_idx] 34 | fname_src = scene['fnames'][i] 35 | fname_dst = scene['fnames'][j] 36 | 37 | data_src = self.load_data(folder, fname_src) 38 | data_dst = self.load_data(folder, fname_dst) 39 | 40 | # Optional. 
Could be None 41 | info_src = scene['unary_info'][i] 42 | info_dst = scene['unary_info'][j] 43 | info_pair = scene['binary_info'][pair_idx] 44 | 45 | return data_src, data_dst, info_src, info_dst, info_pair 46 | 47 | # NOTE: override in inheritance 48 | def parse_scene(self, root, scene): 49 | return { 50 | 'folder': scene, 51 | 'fnames': [], 52 | 'pairs': [], 53 | 'unary_info': [], 54 | 'binary_info': [] 55 | } 56 | 57 | # NOTE: override in inheritance 58 | def load_data(self, folder, fname): 59 | return os.path.join(folder, fname) 60 | 61 | # NOTE: optionally override in inheritance, if a scene includes more than 1 subset 62 | def collect_scenes(self, root, scenes): 63 | return [self.parse_scene(root, scene) for scene in scenes] 64 | -------------------------------------------------------------------------------- /code/dataset/megadepth_sgp.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | file_path = os.path.abspath(__file__) 4 | project_path = os.path.dirname(os.path.dirname(file_path)) 5 | sys.path.append(project_path) 6 | 7 | import cv2 8 | import numpy as np 9 | import open3d as o3d 10 | 11 | from dataset.base import DatasetBase 12 | from geometry.image import skew, detect_keypoints, extract_feats, match_feats, estimate_essential, draw_matches 13 | 14 | from tqdm import tqdm 15 | 16 | PSEUDO_LABEL_FNAME = 'pseudo-label.log' 17 | 18 | 19 | # Train and test sets are identical for CAPS 20 | class DatasetMegaDepthSGP(DatasetBase): 21 | def __init__(self, 22 | data_root, 23 | scenes, 24 | label_root, 25 | mode, 26 | inlier_ratio_thr=0.3, 27 | num_matches_thr=100, 28 | sample_rate=0.2): 29 | self.label_root = label_root 30 | self.inlier_ratio_thr = inlier_ratio_thr 31 | self.num_matches_thr = num_matches_thr 32 | self.sample_rate = sample_rate 33 | 34 | if not os.path.exists(label_root): 35 | print( 36 | 'label root {} does not exist, entering teaching mode.'.format( 37 | label_root)) 38 | self.mode = 'teaching' 39 | os.makedirs(label_root, exist_ok=True) 40 | elif mode == 'teaching': 41 | print('label root {} will be overwritten to enter teaching mode'. 
42 | format(label_root)) 43 | self.mode = 'teaching' 44 | else: 45 | print('label root {} exists, entering learning mode.'.format( 46 | label_root)) 47 | self.mode = 'learning' 48 | 49 | super(DatasetMegaDepthSGP, self).__init__(data_root, scenes) 50 | 51 | # override 52 | def parse_scene(self, root, scene): 53 | if self.mode == 'teaching': 54 | return self._parse_scene_teaching(root, scene) 55 | elif self.mode == 'learning': 56 | return self._parse_scene_learning(root, scene) 57 | else: 58 | print('Unsupported mode, abort') 59 | exit() 60 | 61 | def write_pseudo_label(self, idx, label, info): 62 | scene_idx = self.scene_idx_map[idx] 63 | pair_idx = self.pair_idx_map[idx] 64 | 65 | # Access actual data 66 | scene = self.scenes[scene_idx] 67 | i, j = scene['pairs'][pair_idx] 68 | folder = scene['folder'] 69 | 70 | num_inliers, num_matches = info 71 | label_file = os.path.join(self.label_root, folder, PSEUDO_LABEL_FNAME) 72 | with open(label_file, 'a') as f: 73 | f.write('{} {} {} {} '.format(i, j, num_inliers, num_matches)) 74 | label_str = ' '.join(map(str, label.flatten())) 75 | f.write(label_str) 76 | f.write('\n') 77 | 78 | def _deterministic_shuffle_(self, seq): 79 | import random 80 | random.Random(15213).shuffle(seq) 81 | 82 | def _parse_scene_teaching(self, root, scene): 83 | # Generate pseudo labels 84 | label_path = os.path.join(self.label_root, scene) 85 | os.makedirs(label_path, exist_ok=True) 86 | label_file = os.path.join(label_path, PSEUDO_LABEL_FNAME) 87 | 88 | if os.path.exists(label_file): 89 | os.remove(label_file) 90 | with open(label_file, 'w') as f: 91 | pass 92 | 93 | scene_path = os.path.join(root, scene) 94 | 95 | fnames = os.listdir(os.path.join(scene_path, 'images')) 96 | fnames_map = {fname: i for i, fname in enumerate(fnames)} 97 | 98 | cam_fname = os.path.join(scene_path, 'img_cam.txt') 99 | with open(cam_fname, 'r') as f: 100 | cam_content = f.readlines() 101 | 102 | cnt = 0 103 | intrinsics = np.zeros((len(fnames), 3, 3)) 104 | extrinsics = np.zeros((len(fnames), 4, 4)) 105 | for line in cam_content: 106 | line = line.strip() 107 | if len(line) > 0 and line[0] != "#": 108 | lst = line.split() 109 | fname = lst[0] 110 | idx = fnames_map[fname] 111 | 112 | fx, fy = float(lst[3]), float(lst[4]) 113 | cx, cy = float(lst[5]), float(lst[6]) 114 | intrinsics[idx] = np.array([fx, 0, cx, 0, fy, cy, 0, 0, 115 | 1]).reshape((3, 3)) 116 | cnt += 1 117 | 118 | assert cnt == len(fnames) 119 | 120 | # Load pairs.txt 121 | pair_fname = os.path.join(scene_path, 'pairs.txt') 122 | with open(pair_fname, 'r') as f: 123 | pair_content = f.readlines() 124 | 125 | pairs = [] 126 | for line in pair_content: 127 | lst = line.strip().split(' ') 128 | src_fname = lst[0] 129 | dst_fname = lst[1] 130 | 131 | src_idx = fnames_map[src_fname] 132 | dst_idx = fnames_map[dst_fname] 133 | pairs.append((src_idx, dst_idx)) 134 | 135 | pairs_cnt = len(pairs) 136 | idx_selection = np.arange(pairs_cnt) 137 | self._deterministic_shuffle_(idx_selection) 138 | idx_selection = idx_selection[:int(self.sample_rate * 139 | pairs_cnt)].astype(int) 140 | 141 | return { 142 | 'folder': scene, 143 | 'fnames': fnames, 144 | 'pairs': np.asarray(pairs)[idx_selection], 145 | 'unary_info': intrinsics, 146 | 'binary_info': [None for i in range(len(pairs))] 147 | } 148 | 149 | def _parse_scene_learning(self, root, scene): 150 | # Load pseudo labels 151 | label_path = os.path.join(self.label_root, scene, PSEUDO_LABEL_FNAME) 152 | if not os.path.exists(label_path): 153 | raise Exception('{} not found', label_path) 154 
| 155 | scene_path = os.path.join(root, scene) 156 | 157 | fnames = os.listdir(os.path.join(scene_path, 'images')) 158 | fnames_map = {fname: i for i, fname in enumerate(fnames)} 159 | 160 | cam_fname = os.path.join(scene_path, 'img_cam.txt') 161 | with open(cam_fname, 'r') as f: 162 | cam_content = f.readlines() 163 | 164 | cnt = 0 165 | intrinsics = np.zeros((len(fnames), 3, 3)) 166 | for line in cam_content: 167 | line = line.strip() 168 | if len(line) > 0 and line[0] != "#": 169 | lst = line.split() 170 | fname = lst[0] 171 | idx = fnames_map[fname] 172 | 173 | fx, fy = float(lst[3]), float(lst[4]) 174 | cx, cy = float(lst[5]), float(lst[6]) 175 | 176 | intrinsics[idx] = np.array([fx, 0, cx, 0, fy, cy, 0, 0, 177 | 1]).reshape((3, 3)) 178 | cnt += 1 179 | 180 | assert cnt == len(fnames) 181 | 182 | with open(label_path, 'r') as f: 183 | pair_content = f.readlines() 184 | 185 | pairs = [] 186 | binary_info = [] 187 | 188 | for line in pair_content: 189 | lst = line.strip().split(' ') 190 | src_idx = int(lst[0]) 191 | dst_idx = int(lst[1]) 192 | 193 | num_inliers = float(lst[2]) 194 | num_matches = float(lst[3]) 195 | 196 | F_data = list(map(float, lst[4:])) 197 | F = np.array(F_data).reshape((3, 3)) 198 | 199 | if num_matches >= self.num_matches_thr \ 200 | and (num_inliers / num_matches) >= self.inlier_ratio_thr: 201 | pairs.append((src_idx, dst_idx)) 202 | binary_info.append(F) 203 | 204 | return { 205 | 'folder': scene, 206 | 'fnames': fnames, 207 | 'pairs': pairs, 208 | 'unary_info': intrinsics, 209 | 'binary_info': binary_info 210 | } 211 | 212 | # override 213 | def load_data(self, folder, fname): 214 | fname = os.path.join(self.root, folder, 'images', fname) 215 | return cv2.imread(fname) 216 | 217 | # override 218 | def collect_scenes(self, root, scenes): 219 | scene_collection = [] 220 | 221 | for scene in scenes: 222 | scene_path = os.path.join(root, scene) 223 | subdirs = os.listdir(scene_path) 224 | for subdir in subdirs: 225 | if subdir.startswith('dense') and \ 226 | os.path.isdir( 227 | os.path.join(scene_path, subdir)): 228 | scene_dict = self.parse_scene( 229 | root, os.path.join(scene, subdir, 'aligned')) 230 | scene_collection.append(scene_dict) 231 | 232 | return scene_collection 233 | -------------------------------------------------------------------------------- /code/dataset/megadepth_test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | file_path = os.path.abspath(__file__) 4 | project_path = os.path.dirname(os.path.dirname(file_path)) 5 | sys.path.append(project_path) 6 | 7 | import cv2 8 | import numpy as np 9 | import open3d as o3d 10 | 11 | from dataset.base import DatasetBase 12 | from geometry.image import compute_fundamental_from_poses, detect_keypoints, extract_feats, match_feats, estimate_essential, draw_matches 13 | 14 | 15 | # Train and test sets are identical for CAPS 16 | class DatasetMegaDepthTest(DatasetBase): 17 | def __init__(self, data_root, scenes, label_root): 18 | self.data_root = data_root 19 | super(DatasetMegaDepthTest, self).__init__(label_root, scenes) 20 | 21 | # override 22 | def parse_scene(self, label_root, scene): 23 | scene_path = os.path.join(label_root, scene) 24 | 25 | # Load cameras 26 | cam_fname = os.path.join(scene_path, 'img_cam.txt') 27 | with open(cam_fname, 'r') as f: 28 | cam_content = f.readlines() 29 | 30 | cnt = 0 31 | fnames = [] 32 | fnames_map = {} 33 | intrinsics = [] 34 | extrinsics = [] 35 | for i, line in enumerate(cam_content): 36 | line = 
line.strip() 37 | if len(line) > 0 and line[0] != "#": 38 | lst = line.split() 39 | seq = lst[0] 40 | fname = lst[1] 41 | 42 | fx, fy = float(lst[4]), float(lst[5]) 43 | cx, cy = float(lst[6]), float(lst[7]) 44 | 45 | R = np.array(lst[8:17]).reshape((3, 3)) 46 | t = np.array(lst[17:20]) 47 | T = np.eye(4) 48 | T[:3, :3] = R 49 | T[:3, 3] = t 50 | 51 | fnames.append( 52 | os.path.join(self.data_root, seq, 'dense', 'aligned', 53 | 'images', fname)) 54 | fnames_map[fname] = i 55 | intrinsics.append( 56 | np.array([fx, 0, cx, 0, fy, cy, 0, 0, 1]).reshape((3, 3))) 57 | extrinsics.append(T) 58 | 59 | # Load pairs.txt 60 | pair_fname = os.path.join(scene_path, 'pairs.txt') 61 | with open(pair_fname, 'r') as f: 62 | pair_content = f.readlines() 63 | 64 | pairs = [] 65 | for line in pair_content: 66 | lst = line.strip().split(' ') 67 | seq = lst[0] 68 | src_fname = lst[1] 69 | dst_fname = lst[2] 70 | 71 | src_idx = fnames_map[src_fname] 72 | dst_idx = fnames_map[dst_fname] 73 | pairs.append((src_idx, dst_idx)) 74 | 75 | return { 76 | 'folder': scene, 77 | 'fnames': fnames, 78 | 'pairs': pairs, 79 | 'unary_info': [(K, T) for K, T in zip(intrinsics, extrinsics)], 80 | 'binary_info': [None for i in range(len(pairs))] 81 | } 82 | 83 | # override 84 | def load_data(self, folder, fname): 85 | return cv2.imread(fname) 86 | -------------------------------------------------------------------------------- /code/dataset/megadepth_train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | file_path = os.path.abspath(__file__) 4 | project_path = os.path.dirname(os.path.dirname(file_path)) 5 | sys.path.append(project_path) 6 | 7 | import cv2 8 | import numpy as np 9 | import open3d as o3d 10 | 11 | from dataset.base import DatasetBase 12 | from geometry.image import compute_fundamental_from_poses, detect_keypoints, extract_feats, match_feats, estimate_essential, draw_matches 13 | 14 | 15 | class DatasetMegaDepthTrain(DatasetBase): 16 | def __init__(self, root, scenes): 17 | super(DatasetMegaDepthTrain, self).__init__(root, scenes) 18 | 19 | # override 20 | def parse_scene(self, root, scene): 21 | scene_path = os.path.join(root, scene) 22 | 23 | fnames = os.listdir(os.path.join(scene_path, 'images')) 24 | fnames_map = {fname: i for i, fname in enumerate(fnames)} 25 | 26 | # Load pairs.txt 27 | pair_fname = os.path.join(scene_path, 'pairs.txt') 28 | with open(pair_fname, 'r') as f: 29 | pair_content = f.readlines() 30 | 31 | pairs = [] 32 | for line in pair_content: 33 | lst = line.strip().split(' ') 34 | src_fname = lst[0] 35 | dst_fname = lst[1] 36 | 37 | src_idx = fnames_map[src_fname] 38 | dst_idx = fnames_map[dst_fname] 39 | pairs.append((src_idx, dst_idx)) 40 | 41 | cam_fname = os.path.join(scene_path, 'img_cam.txt') 42 | with open(cam_fname, 'r') as f: 43 | cam_content = f.readlines() 44 | 45 | cnt = 0 46 | intrinsics = np.zeros((len(fnames), 3, 3)) 47 | extrinsics = np.zeros((len(fnames), 4, 4)) 48 | for line in cam_content: 49 | line = line.strip() 50 | if len(line) > 0 and line[0] != "#": 51 | lst = line.split() 52 | fname = lst[0] 53 | idx = fnames_map[fname] 54 | 55 | fx, fy = float(lst[3]), float(lst[4]) 56 | cx, cy = float(lst[5]), float(lst[6]) 57 | 58 | R = np.array(lst[7:16]).reshape((3, 3)) 59 | t = np.array(lst[16:19]) 60 | T = np.eye(4) 61 | T[:3, :3] = R 62 | T[:3, 3] = t 63 | 64 | intrinsics[idx] = np.array([fx, 0, cx, 0, fy, cy, 0, 0, 65 | 1]).reshape((3, 3)) 66 | extrinsics[idx] = T 67 | cnt += 1 68 | 69 | assert cnt == len(fnames) 
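# Note: each img_cam.txt row parsed above provides the image file name, the
# intrinsics (fx, fy, cx, cy in 0-indexed fields 3-6), and the camera pose
# (row-major 3x3 R in fields 7-15, t in fields 16-18); they are packed into
# unary_info as (K, T) tuples below.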
70 | 71 | return { 72 | 'folder': scene, 73 | 'fnames': fnames, 74 | 'pairs': pairs, 75 | 'unary_info': [(K, T) for K, T in zip(intrinsics, extrinsics)], 76 | 'binary_info': [None for i in range(len(pairs))] 77 | } 78 | 79 | # override 80 | def load_data(self, folder, fname): 81 | fname = os.path.join(self.root, folder, 'images', fname) 82 | return cv2.imread(fname) 83 | 84 | # override 85 | def collect_scenes(self, root, scenes): 86 | scene_collection = [] 87 | 88 | for scene in scenes: 89 | scene_path = os.path.join(root, scene) 90 | subdirs = os.listdir(scene_path) 91 | for subdir in subdirs: 92 | if subdir.startswith('dense') and \ 93 | os.path.isdir( 94 | os.path.join(scene_path, subdir)): 95 | scene_dict = self.parse_scene( 96 | root, os.path.join(scene, subdir, 'aligned')) 97 | scene_collection.append(scene_dict) 98 | 99 | return scene_collection 100 | -------------------------------------------------------------------------------- /code/dataset/threedmatch_sgp.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | file_path = os.path.abspath(__file__) 4 | project_path = os.path.dirname(os.path.dirname(file_path)) 5 | sys.path.append(project_path) 6 | 7 | import glob 8 | 9 | import open3d as o3d 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | from dataset.base import DatasetBase 14 | from geometry.pointcloud import make_o3d_pointcloud, extract_feats, match_feats, solve, refine 15 | PSEUDO_LABEL_FNAME = 'pseudo-label.log' 16 | 17 | 18 | class Dataset3DMatchSGP(DatasetBase): 19 | ''' 20 | During teaching: labels are written to a separate directory 21 | During learning: it acts like the train, with labels in a separate directory 22 | ''' 23 | def __init__(self, data_root, scenes, label_root, mode, overlap_thr=0.3): 24 | self.label_root = label_root 25 | self.overlap_thr = overlap_thr 26 | 27 | if not os.path.exists(label_root): 28 | print( 29 | 'label root {} does not exist, entering teaching mode.'.format( 30 | label_root)) 31 | self.mode = 'teaching' 32 | os.makedirs(label_root, exist_ok=True) 33 | elif mode == 'teaching': 34 | print('label root {} will be overwritten to enter teaching mode'. 
35 | format(label_root)) 36 | self.mode = 'teaching' 37 | else: 38 | print('label root {} exists, entering learning mode.'.format( 39 | label_root)) 40 | self.mode = 'learning' 41 | 42 | super(Dataset3DMatchSGP, self).__init__(data_root, scenes) 43 | 44 | # override 45 | def parse_scene(self, root, scene): 46 | if self.mode == 'teaching': 47 | return self._parse_scene_teaching(root, scene) 48 | elif self.mode == 'learning': 49 | return self._parse_scene_learning(root, scene) 50 | else: 51 | print('Unsupported mode, abort') 52 | exit() 53 | 54 | # override 55 | def load_data(self, folder, fname): 56 | fname = os.path.join(self.root, folder, fname) 57 | return make_o3d_pointcloud(np.load(fname)['pcd']) 58 | 59 | def write_pseudo_label(self, idx, label, overlap): 60 | scene_idx = self.scene_idx_map[idx] 61 | pair_idx = self.pair_idx_map[idx] 62 | 63 | # Access actual data 64 | scene = self.scenes[scene_idx] 65 | i, j = scene['pairs'][pair_idx] 66 | folder = scene['folder'] 67 | 68 | label_file = os.path.join(self.label_root, folder, PSEUDO_LABEL_FNAME) 69 | with open(label_file, 'a') as f: 70 | f.write('{} {} {} '.format(i, j, overlap)) 71 | label_str = ' '.join(map(str, label.flatten())) 72 | f.write(label_str) 73 | f.write('\n') 74 | 75 | def _parse_scene_teaching(self, root, scene): 76 | # Generate pseudo labels 77 | label_path = os.path.join(self.label_root, scene) 78 | os.makedirs(label_path, exist_ok=True) 79 | label_file = os.path.join(label_path, PSEUDO_LABEL_FNAME) 80 | 81 | if os.path.exists(label_file): 82 | os.remove(label_file) 83 | with open(label_file, 'w') as f: 84 | pass 85 | 86 | # Load actual data 87 | scene_path = os.path.join(root, scene) 88 | 89 | # Load filenames 90 | l = len(scene_path) 91 | fnames = sorted(glob.glob(os.path.join(scene_path, '*.npz'))) 92 | fnames = [fname[l + 1:] for fname in fnames] 93 | 94 | # Load overlaps.txt 95 | pair_fname = os.path.join(scene_path, 'overlaps.txt') 96 | with open(pair_fname, 'r') as f: 97 | pair_content = f.readlines() 98 | 99 | pairs = [] 100 | binary_info = [] 101 | 102 | # For a 3DMatch dataset for teaching, 103 | # binary_info is (optional) for filtering: overlap 104 | for line in pair_content: 105 | lst = line.strip().split(' ') 106 | src_idx = int(lst[0].split('.')[0].split('_')[-1]) 107 | dst_idx = int(lst[1].split('.')[0].split('_')[-1]) 108 | overlap = float(lst[2]) 109 | 110 | if overlap >= self.overlap_thr: 111 | pairs.append((src_idx, dst_idx)) 112 | binary_info.append(overlap) 113 | 114 | return { 115 | 'folder': scene, 116 | 'fnames': fnames, 117 | 'pairs': pairs, 118 | 'unary_info': [None for i in range(len(fnames))], 119 | 'binary_info': binary_info 120 | } 121 | 122 | ''' 123 | Pseudo-Labels not available. Generate paths for writing to them later. 
124 | ''' 125 | 126 | def _parse_scene_learning(self, root, scene): 127 | # Load pseudo labels 128 | label_path = os.path.join(self.label_root, scene, PSEUDO_LABEL_FNAME) 129 | if not os.path.exists(label_path): 130 | raise Exception('{} not found', label_path) 131 | 132 | # Load actual data 133 | scene_path = os.path.join(root, scene) 134 | 135 | # Load filenames 136 | l = len(scene_path) 137 | fnames = sorted(glob.glob(os.path.join(scene_path, '*.npz'))) 138 | fnames = [fname[l + 1:] for fname in fnames] 139 | 140 | # Load overlaps.txt 141 | with open(label_path, 'r') as f: 142 | pair_content = f.readlines() 143 | 144 | pairs = [] 145 | binary_info = [] 146 | 147 | # For a 3DMatch dataset for learning, 148 | # binary_info is the pseudo label: src to dst transformation. 149 | for line in pair_content: 150 | lst = line.strip().split(' ') 151 | src_idx = int(lst[0].split('.')[0].split('_')[-1]) 152 | dst_idx = int(lst[1].split('.')[0].split('_')[-1]) 153 | overlap = float(lst[2]) 154 | T_data = list(map(float, lst[3:])) 155 | T = np.array(T_data).reshape((4, 4)) 156 | 157 | if overlap >= self.overlap_thr: 158 | pairs.append((src_idx, dst_idx)) 159 | binary_info.append(T) 160 | 161 | return { 162 | 'folder': scene, 163 | 'fnames': fnames, 164 | 'pairs': pairs, 165 | 'unary_info': [None for i in range(len(fnames))], 166 | 'binary_info': binary_info 167 | } 168 | -------------------------------------------------------------------------------- /code/dataset/threedmatch_test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | file_path = os.path.abspath(__file__) 4 | project_path = os.path.dirname(os.path.dirname(file_path)) 5 | sys.path.append(project_path) 6 | 7 | import glob 8 | 9 | import open3d as o3d 10 | import numpy as np 11 | 12 | from dataset.base import DatasetBase 13 | 14 | 15 | class Dataset3DMatchTest(DatasetBase): 16 | def __init__(self, root, scenes): 17 | super(Dataset3DMatchTest, self).__init__(root, scenes) 18 | 19 | # override 20 | def parse_scene(self, root, scene): 21 | scene_path = os.path.join(root, scene) 22 | 23 | l = len(scene_path) 24 | fnames = sorted( 25 | glob.glob(os.path.join(scene_path, '*.ply')), 26 | key=lambda fname: int(fname.split('.')[0].split('_')[-1])) 27 | fnames = [fname[l + 1:] for fname in fnames] 28 | 29 | # Load gt 30 | scene_gt_path = os.path.join(root, scene + '-evaluation') 31 | gt_fname = os.path.join(scene_gt_path, 'gt.log') 32 | with open(gt_fname, 'r') as f: 33 | pair_content = f.readlines() 34 | 35 | pairs = [] 36 | binary_info = [] 37 | 38 | # For a 3DMatch test dataset, 39 | # binary_info is the gt label: src to dst transformation. 
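# Each gt.log entry spans 5 lines: a tab-separated header whose first two fields
# are the fragment indices, followed by a 4x4 transformation written one row per
# line; the parsed matrix is inverted before being stored in binary_info.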
40 | for i in range(0, len(pair_content), 5): 41 | lst = pair_content[i].strip().split('\t') 42 | src_idx = int(lst[0]) 43 | dst_idx = int(lst[1]) 44 | 45 | res = map(lambda x: np.fromstring(x.strip(), sep='\t'), 46 | pair_content[i+1:i+5]) 47 | T_src2dst = np.stack(list(res)) 48 | pairs.append((src_idx, dst_idx)) 49 | binary_info.append(np.linalg.inv(T_src2dst)) 50 | 51 | return { 52 | 'folder': scene, 53 | 'fnames': fnames, 54 | 'pairs': pairs, 55 | 'unary_info': [None for i in range(len(fnames))], 56 | 'binary_info': binary_info 57 | } 58 | 59 | # override 60 | def load_data(self, folder, fname): 61 | fname = os.path.join(self.root, folder, fname) 62 | return o3d.io.read_point_cloud(fname) 63 | -------------------------------------------------------------------------------- /code/dataset/threedmatch_train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | file_path = os.path.abspath(__file__) 4 | project_path = os.path.dirname(os.path.dirname(file_path)) 5 | sys.path.append(project_path) 6 | 7 | import glob 8 | 9 | import open3d as o3d 10 | import numpy as np 11 | 12 | from dataset.base import DatasetBase 13 | from geometry.pointcloud import make_o3d_pointcloud 14 | 15 | class Dataset3DMatchTrain(DatasetBase): 16 | def __init__(self, root, scenes, overlap_thr=0.3): 17 | self.overlap_thr = overlap_thr 18 | super(Dataset3DMatchTrain, self).__init__(root, scenes) 19 | 20 | # override 21 | def parse_scene(self, root, scene): 22 | scene_path = os.path.join(root, scene) 23 | 24 | l = len(scene_path) 25 | fnames = sorted(glob.glob(os.path.join(scene_path, '*.npz'))) 26 | fnames = [fname[l + 1:] for fname in fnames] 27 | 28 | # Load overlaps.txt 29 | pair_fname = os.path.join(scene_path, 'overlaps.txt') 30 | with open(pair_fname, 'r') as f: 31 | pair_content = f.readlines() 32 | 33 | pairs = [] 34 | binary_info = [] 35 | 36 | # For a preprocessed 3DMatch training dataset, 37 | # binary_info is the gt label: pre-calibrated identity matrix. 
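# Each overlaps.txt line is '<fname_src> <fname_dst> <overlap>'; the fragment
# index is parsed from the trailing '_<idx>' in the file name, and pairs whose
# overlap is below overlap_thr are discarded.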
38 | for line in pair_content: 39 | lst = line.strip().split(' ') 40 | src_idx = int(lst[0].split('.')[0].split('_')[-1]) 41 | dst_idx = int(lst[1].split('.')[0].split('_')[-1]) 42 | overlap = float(lst[2]) 43 | 44 | if overlap >= self.overlap_thr: 45 | pairs.append((src_idx, dst_idx)) 46 | binary_info.append(np.eye(4)) 47 | 48 | return { 49 | 'folder': scene, 50 | 'fnames': fnames, 51 | 'pairs': pairs, 52 | 'unary_info': [None for i in range(len(fnames))], 53 | 'binary_info': binary_info 54 | } 55 | 56 | # override 57 | def load_data(self, folder, fname): 58 | fname = os.path.join(self.root, folder, fname) 59 | return make_o3d_pointcloud(np.load(fname)['pcd']) 60 | -------------------------------------------------------------------------------- /code/geometry/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def rotation_error(R0, R1): 5 | return np.abs( 6 | np.arccos(np.clip((np.trace(R0.T @ R1) - 1) / 2.0, -0.999999, 7 | 0.999999))) / np.pi * 180 8 | 9 | 10 | def translation_error(t0, t1): 11 | return np.linalg.norm(t0 - t1) 12 | 13 | 14 | def angular_translation_error(t0, t1): 15 | t0 = t0 / np.linalg.norm(t0) 16 | t1 = t1 / np.linalg.norm(t1) 17 | err = np.arccos(np.clip(np.inner(t0, t1), -0.999999, 18 | 0.999999)) / np.pi * 180 19 | return err 20 | -------------------------------------------------------------------------------- /code/geometry/image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | import torchvision.transforms as transforms 5 | 6 | def skew(x): 7 | return np.array([[0, -x[2], x[1]], [x[2], 0, -x[0]], [-x[1], x[0], 0]]) 8 | 9 | 10 | def compute_fundamental_from_poses(K_src, K_dst, T_src, T_dst): 11 | T_src2dst = T_dst.dot(np.linalg.inv(T_src)) 12 | R = T_src2dst[:3, :3] 13 | t = T_src2dst[:3, 3] 14 | tx = skew(t) 15 | E = np.dot(tx, R) 16 | return np.linalg.inv(K_dst).T.dot(E).dot(np.linalg.inv(K_src)) 17 | 18 | 19 | def detect_keypoints(im, detector, num_kpts=10000): 20 | gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) 21 | 22 | if detector == 'sift': 23 | sift = cv2.xfeatures2d.SIFT_create(nfeatures=num_kpts) 24 | kpts = sift.detect(gray) 25 | elif detector == 'orb': 26 | orb = cv2.ORB_create(nfeatures=num_kpts) 27 | kpts = orb.detect(gray) 28 | else: 29 | raise NotImplementedError('Unknown keypoint detector.') 30 | 31 | return kpts 32 | 33 | 34 | def extract_feats(im, kpts, feature_type, model=None): 35 | if feature_type == 'sift': 36 | sift = cv2.xfeatures2d.SIFT_create() 37 | kpts, feats = sift.compute(im, kpts) 38 | 39 | elif feature_type == 'orb': 40 | orb = cv2.ORB_create() 41 | kpts, feats = orb.compute(im, kpts) 42 | 43 | elif feature_type == 'caps': 44 | assert model is not None 45 | transform = transforms.Compose([ 46 | transforms.ToTensor(), 47 | transforms.Normalize(mean=(0.485, 0.456, 0.406), 48 | std=(0.229, 0.224, 0.225)), 49 | ]) 50 | 51 | kpts = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts]) 52 | kpts = torch.from_numpy(kpts).float() 53 | 54 | desc_c, desc_f = model.extract_features( 55 | transform(im).unsqueeze(0).to(model.device), 56 | kpts.unsqueeze(0).to(model.device)) 57 | 58 | feats = torch.cat((desc_c, desc_f), 59 | -1).squeeze(0).detach().cpu().numpy() 60 | else: 61 | raise NotImplementedError('Unknown feature descriptor.') 62 | 63 | return feats 64 | 65 | 66 | def match_feats(feats_src, 67 | feats_dst, 68 | feature_type, 69 | ratio_test=True, 70 | ratio_thr=0.6): 71 | if 
feature_type == 'orb': 72 | bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True) 73 | good = bf.match(feats_src, feats_dst) 74 | else: # sift and caps descriptor 75 | if ratio_test: 76 | bf = cv2.BFMatcher() 77 | matches = bf.knnMatch(feats_src, feats_dst, k=2) 78 | good = [] 79 | for m, n in matches: 80 | if m.distance < ratio_thr * n.distance: 81 | good.append(m) 82 | if len(good) < 50: 83 | matches = sorted(matches, 84 | key=lambda x: x[0].distance / x[1].distance) 85 | good = [m[0] for m in matches[:50]] 86 | 87 | else: 88 | bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True) 89 | good = bf.match(feats_src, feats_dst) 90 | if len(good) < 50: 91 | bf = cv2.BFMatcher() 92 | matches = bf.match(feats_src, feats_dst) 93 | matches = sorted(matches, key=lambda x: x.distance) 94 | good = [m for m in matches[:50]] 95 | 96 | return good 97 | 98 | 99 | def estimate_essential(kp1, kp2, matches, K1, K2, th=1e-4): 100 | src_pts = np.float32([kp1[m.queryIdx].pt 101 | for m in matches]).reshape(-1, 1, 2) 102 | dst_pts = np.float32([kp2[m.trainIdx].pt 103 | for m in matches]).reshape(-1, 1, 2) 104 | pts_l_norm = cv2.undistortPoints(src_pts, cameraMatrix=K1, distCoeffs=None) 105 | pts_r_norm = cv2.undistortPoints(dst_pts, cameraMatrix=K2, distCoeffs=None) 106 | E, mask = cv2.findEssentialMat(pts_l_norm, 107 | pts_r_norm, 108 | focal=1.0, 109 | pp=(0., 0.), 110 | method=cv2.RANSAC, 111 | prob=0.999, 112 | threshold=th) 113 | if E.shape != (3, 3): 114 | return np.eye(3), np.zeros((len(matches))), np.eye(3), np.zeros((3)) 115 | 116 | mask = np.squeeze(mask).astype(bool) 117 | _, R, t, _ = cv2.recoverPose(E, pts_l_norm[mask], pts_r_norm[mask]) 118 | t = np.squeeze(t) 119 | return E, mask, R, t 120 | 121 | 122 | def decolorize(img): 123 | return cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY), 124 | cv2.COLOR_GRAY2RGB) 125 | 126 | 127 | def draw_matches(kps1, kps2, tentatives, img1, img2, H, mask): 128 | if H is None: 129 | print("No homography found") 130 | return 131 | matchesMask = mask.ravel().tolist() 132 | h, w, ch = img1.shape 133 | pts = np.float32([[0, 0], [0, h - 1], [w - 1, h - 1], 134 | [w - 1, 0]]).reshape(-1, 1, 2) 135 | dst = cv2.perspectiveTransform(pts, H) 136 | img2_tr = cv2.polylines(decolorize(img2), [np.int32(dst)], True, 137 | (0, 0, 255), 3, cv2.LINE_AA) 138 | draw_params = dict( 139 | matchColor=(255, 255, 0), # draw matches in yellow color 140 | singlePointColor=None, 141 | matchesMask=matchesMask, # draw only inliers 142 | flags=2) 143 | return cv2.drawMatches(decolorize(img1), kps1, img2_tr, kps2, tentatives, 144 | None, **draw_params) 145 | -------------------------------------------------------------------------------- /code/geometry/pointcloud.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import numpy as np 3 | import torch 4 | 5 | import MinkowskiEngine as ME 6 | from scipy.spatial import cKDTree 7 | 8 | 9 | def make_o3d_pointcloud(xyz): 10 | pcd = o3d.geometry.PointCloud() 11 | pcd.points = o3d.utility.Vector3dVector(xyz) 12 | return pcd 13 | 14 | 15 | def extract_feats(pcd, feature_type, voxel_size, model=None): 16 | xyz = np.asarray(pcd.points) 17 | _, sel = ME.utils.sparse_quantize(xyz, 18 | return_index=True, 19 | quantization_size=voxel_size) 20 | xyz = xyz[sel] 21 | pcd = make_o3d_pointcloud(xyz) 22 | 23 | if feature_type == 'FPFH': 24 | radius_normal = voxel_size * 2 25 | pcd.estimate_normals( 26 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, 27 | max_nn=30)) 28 | radius_feat = voxel_size 
* 5 29 | feat = o3d.pipelines.registration.compute_fpfh_feature( 30 | pcd, 31 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feat, 32 | max_nn=100)) 33 | # (N, 33) 34 | return pcd, feat.data.T 35 | 36 | elif feature_type == 'FCGF': 37 | DEVICE = torch.device('cuda') 38 | coords = ME.utils.batched_coordinates( 39 | [torch.floor(torch.from_numpy(xyz) / voxel_size).int()]).to(DEVICE) 40 | 41 | feats = torch.ones(coords.size(0), 1).to(DEVICE) 42 | sinput = ME.SparseTensor(feats, coordinates=coords) # .to(DEVICE) 43 | 44 | # (N, 32) 45 | return pcd, model(sinput).F.detach().cpu().numpy() 46 | 47 | else: 48 | raise NotImplementedError( 49 | 'Unimplemented feature type {}'.format(feature_type)) 50 | 51 | 52 | def find_knn_cpu(feat0, feat1, knn=1, return_distance=False): 53 | feat1tree = cKDTree(feat1) 54 | dists, nn_inds = feat1tree.query(feat0, k=knn, n_jobs=-1) 55 | if return_distance: 56 | return nn_inds, dists 57 | else: 58 | return nn_inds 59 | 60 | 61 | def match_feats(feat_src, feat_dst, mutual_filter=True, k=1): 62 | if not mutual_filter: 63 | nns01 = find_knn_cpu(feat_src, feat_dst, knn=1, return_distance=False) 64 | corres01_idx0 = np.arange(len(nns01)).squeeze() 65 | corres01_idx1 = nns01.squeeze() 66 | return np.stack((corres01_idx0, corres01_idx1)).T 67 | else: 68 | # for each feat in src, find its k=1 nearest neighbours 69 | nns01 = find_knn_cpu(feat_src, feat_dst, knn=1, return_distance=False) 70 | # for each feat in dst, find its k nearest neighbours 71 | nns10 = find_knn_cpu(feat_dst, feat_src, knn=k, return_distance=False) 72 | # find corrs 73 | num_feats = len(nns01) 74 | corres01 = [] 75 | if k == 1: 76 | for i in range(num_feats): 77 | if i == nns10[nns01[i]]: 78 | corres01.append([i, nns01[i]]) 79 | else: 80 | for i in range(num_feats): 81 | if i in nns10[nns01[i]]: 82 | corres01.append([i, nns01[i]]) 83 | # print( 84 | # f'Before mutual filter: {num_feats}, after mutual_filter with k={k}: {len(corres01)}.' 
85 | # ) 86 | 87 | # Fallback if mutual filter is too aggressive 88 | if len(corres01) < 10: 89 | nns01 = find_knn_cpu(feat_src, 90 | feat_dst, 91 | knn=1, 92 | return_distance=False) 93 | corres01_idx0 = np.arange(len(nns01)).squeeze() 94 | corres01_idx1 = nns01.squeeze() 95 | return np.stack((corres01_idx0, corres01_idx1)).T 96 | 97 | return np.asarray(corres01) 98 | 99 | 100 | def weighted_procrustes(A, B, weights=None): 101 | num_pts = A.shape[1] 102 | if weights is None: 103 | weights = np.ones(num_pts) 104 | 105 | # compute weighted center 106 | A_center = A @ weights / np.sum(weights) 107 | B_center = B @ weights / np.sum(weights) 108 | 109 | # compute relative positions 110 | A_ref = A - A_center[:, np.newaxis] 111 | B_ref = B - B_center[:, np.newaxis] 112 | 113 | # compute rotation 114 | M = B_ref @ np.diag(weights) @ A_ref.T 115 | U, _, Vh = np.linalg.svd(M) 116 | S = np.identity(3) 117 | S[-1, -1] = np.linalg.det(U) * np.linalg.det(Vh) 118 | R = U @ S @ Vh 119 | 120 | # compute translation 121 | t = B_center - R @ A_center 122 | 123 | return R, t 124 | 125 | 126 | def solve(src, dst, corres, solver_type, distance_thr, ransac_iters, 127 | confidence): 128 | if solver_type.startswith('RANSAC'): 129 | corres = o3d.utility.Vector2iVector(corres) 130 | 131 | result = o3d.pipelines.registration.registration_ransac_based_on_correspondence( 132 | src, dst, corres, distance_thr, 133 | o3d.pipelines.registration.TransformationEstimationPointToPoint( 134 | False), 3, [], 135 | o3d.pipelines.registration.RANSACConvergenceCriteria( 136 | ransac_iters, confidence)) 137 | 138 | return result.transformation, result.fitness 139 | 140 | else: 141 | raise NotImplementedError( 142 | 'Unimplemented solver type {}'.format(solver_type)) 143 | 144 | 145 | def refine(src, dst, ransac_T, distance_thr): 146 | result = o3d.pipelines.registration.registration_icp( 147 | src, dst, distance_thr, ransac_T, 148 | o3d.pipelines.registration.TransformationEstimationPointToPoint()) 149 | icp_T = result.transformation 150 | icp_fitness = result.fitness 151 | 152 | fitness = icp_fitness * np.minimum( 153 | 1.0, 154 | float(len(dst.points)) / float(len(src.points))) 155 | 156 | return icp_T, fitness 157 | -------------------------------------------------------------------------------- /code/perception2d/adaptor.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | file_path = os.path.abspath(__file__) 4 | project_path = os.path.dirname(os.path.dirname(file_path)) 5 | sys.path.append(project_path) 6 | 7 | caps_path = os.path.join(project_path, 'ext', 'caps') 8 | sys.path.append(caps_path) 9 | 10 | from ext.caps.CAPS.caps_model import CAPSModel 11 | from ext.caps.utils import cycle 12 | 13 | from dataset.megadepth_train import DatasetMegaDepthTrain 14 | from dataset.megadepth_test import DatasetMegaDepthTest 15 | from dataset.megadepth_sgp import DatasetMegaDepthSGP 16 | from geometry.image import * 17 | 18 | from geometry.common import rotation_error, angular_translation_error 19 | 20 | from tensorboardX import SummaryWriter 21 | import configargparse 22 | 23 | import torch 24 | from torch.utils.data import Dataset 25 | import numpy as np 26 | import cv2 27 | 28 | import utils 29 | import collections 30 | from tqdm import tqdm 31 | import dataloader.data_utils as data_utils 32 | 33 | rand = np.random.RandomState(234) 34 | 35 | 36 | class CAPSConfigParser(configargparse.ArgParser): 37 | def __init__(self): 38 | super().__init__(default_config_files=[ 39 | 
os.path.join(os.path.dirname(__file__), 'caps_train_config.yml') 40 | ], 41 | conflict_handler='resolve') 42 | 43 | ## path options 44 | self.add('--datadir', type=str, help='the dataset directory') 45 | self.add("--logdir", 46 | type=str, 47 | default='caps_logs', 48 | help='dir of tensorboard logs') 49 | self.add("--outdir", 50 | type=str, 51 | default='caps_outputs', 52 | help='dir of output e.g., ckpts') 53 | self.add( 54 | "--ckpt_path", 55 | type=str, 56 | default='', 57 | help='specific checkpoint path to load the model from, ' 58 | 'if not specified, automatically reload from most recent checkpoints' 59 | ) 60 | self.add('--pseudo_label_dir', 61 | type=str, 62 | default='caps_pseudo_label', 63 | help='the pseudo-gt directory storing pairs and F matrices') 64 | self.add( 65 | '--label_dir', 66 | type=str, 67 | default='', 68 | help= 69 | 'the gt directory storing pairs and F matrices. Reserved for pose test set.' 70 | ) 71 | 72 | # SGP options 73 | self.add('--scenes', 74 | nargs='+', 75 | help='scenes used for training/testing') 76 | self.add('--inlier_ratio_thr', type=float, default=0.001) 77 | self.add('--num_matches_thr', type=int, default=100) 78 | self.add('--sample_rate', 79 | type=float, 80 | default=1, 81 | help='rate of samples from the huge megadepth dataset') 82 | self.add('--num_kpts', 83 | type=int, 84 | default=10000, 85 | help='number of key points detected during teaching') 86 | self.add('--match_ratio_test', 87 | type=bool, 88 | default=True, 89 | help='performs ratio test in feature matching') 90 | self.add('--match_ratio_thr', 91 | type=float, 92 | default=0.75, 93 | help='ratio between best and second best matchings') 94 | self.add('--ransac_thr', 95 | type=float, 96 | default=1e-3, 97 | help='RANSAC threshold in estimating essential matrices') 98 | 99 | self.add( 100 | '--restart_meta_iter', 101 | type=int, 102 | default=-1, 103 | help='start of teacher-student iterations. 
-1 indicates bootstrap') 104 | self.add('--max_meta_iters', 105 | type=int, 106 | default=2, 107 | help='number of teacher-student iterations') 108 | self.add('--finetune', 109 | action='store_true', 110 | help='train from previous checkpoint during SGP.') 111 | 112 | ## general options 113 | self.add("--exp_name", type=str, help='experiment name') 114 | self.add('--n_iters', 115 | type=int, 116 | default=100, 117 | help='max number of training iterations') 118 | self.add("--save_interval", 119 | type=int, 120 | default=100, 121 | help='frequency of weight ckpt saving') 122 | self.add('--phase', 123 | type=str, 124 | default='train', 125 | choices=['train', 'val', 'test']) 126 | 127 | # data options 128 | self.add('--workers', 129 | type=int, 130 | help='number of data loading workers', 131 | default=8) 132 | self.add('--num_pts', 133 | type=int, 134 | default=500, 135 | help='num of points trained in each pair') 136 | self.add('--train_kp', 137 | type=str, 138 | default='mixed', 139 | help='sift/random/mixed') 140 | self.add('--prune_kp', 141 | type=int, 142 | default=1, 143 | help='if prune non-matchable keypoints') 144 | 145 | # training options 146 | self.add('--batch_size', type=int, default=2, help='input batch size') 147 | self.add('--lr', type=float, default=1e-4, help='base learning rate') 148 | self.add( 149 | "--lrate_decay_steps", 150 | type=int, 151 | default=80000, 152 | help= 153 | 'decay learning rate by a factor every specified number of steps') 154 | self.add( 155 | "--lrate_decay_factor", 156 | type=float, 157 | default=0.5, 158 | help= 159 | 'decay learning rate by a factor every specified number of steps') 160 | 161 | ## model options 162 | self.add( 163 | '--backbone', 164 | type=str, 165 | default='resnet50', 166 | help= 167 | 'backbone for feature representation extraction. 
supported: resent' 168 | ) 169 | self.add( 170 | '--pretrained', 171 | type=int, 172 | default=1, 173 | help='if use ImageNet pretrained weights to initialize the network' 174 | ) 175 | self.add('--coarse_feat_dim', 176 | type=int, 177 | default=128, 178 | help='the feature dimension for coarse level features') 179 | self.add('--fine_feat_dim', 180 | type=int, 181 | default=128, 182 | help='the feature dimension for fine level features') 183 | self.add( 184 | '--prob_from', 185 | type=str, 186 | default='correlation', 187 | help= 188 | 'compute prob by softmax(correlation score), or softmax(-distance),' 189 | 'options: correlation|distance') 190 | self.add( 191 | '--window_size', 192 | type=float, 193 | default=0.125, 194 | help='the size of the window, w.r.t image width at the fine level') 195 | self.add('--use_nn', 196 | type=int, 197 | default=1, 198 | help='if use nearest neighbor in the coarse level') 199 | 200 | ## loss function options 201 | self.add('--std', 202 | type=int, 203 | default=1, 204 | help='reweight loss using the standard deviation') 205 | self.add('--w_epipolar_coarse', 206 | type=float, 207 | default=1, 208 | help='coarse level epipolar loss weight') 209 | self.add('--w_epipolar_fine', 210 | type=float, 211 | default=1, 212 | help='fine level epipolar loss weight') 213 | self.add('--w_cycle_coarse', 214 | type=float, 215 | default=0.1, 216 | help='coarse level cycle consistency loss weight') 217 | self.add('--w_cycle_fine', 218 | type=float, 219 | default=0.1, 220 | help='fine level cycle consistency loss weight') 221 | self.add('--w_std', 222 | type=float, 223 | default=0, 224 | help='the weight for the loss on std') 225 | self.add( 226 | '--th_cycle', 227 | type=float, 228 | default=0.025, 229 | help= 230 | 'if the distance (normalized scale) from the prediction to epipolar line > this th, ' 231 | 'do not add the cycle consistency loss') 232 | self.add( 233 | '--th_epipolar', 234 | type=float, 235 | default=0.5, 236 | help= 237 | 'if the distance (normalized scale) from the prediction to epipolar line > this th, ' 238 | 'do not add the epipolar loss') 239 | 240 | ## logging options 241 | self.add('--log_scalar_interval', 242 | type=int, 243 | default=20, 244 | help='print interval') 245 | self.add('--log_img_interval', 246 | type=int, 247 | default=500, 248 | help='log image interval') 249 | 250 | ## eval options 251 | self.add('--extract_img_dir', 252 | type=str, 253 | help='the directory of images to extract features') 254 | self.add('--extract_out_dir', 255 | type=str, 256 | help='the directory of images to extract features') 257 | 258 | def get_config(self): 259 | config = self.parse_args() 260 | return config 261 | 262 | 263 | def my_collate(batch): 264 | ''' Puts each data field into a tensor with outer dimension batch size ''' 265 | batch = list(filter(lambda b: b is not None, batch)) 266 | return torch.utils.data.dataloader.default_collate(batch) 267 | 268 | 269 | class DatasetMegaDepthAdaptor(Dataset): 270 | def __init__(self, dataset, config): 271 | self.dataset = dataset 272 | self.config = config 273 | 274 | if config.phase == 'train': 275 | # augment during training 276 | self.transform = transforms.Compose([ 277 | transforms.ToPILImage(), 278 | transforms.ColorJitter(brightness=1, 279 | contrast=1, 280 | saturation=1, 281 | hue=0.4), 282 | transforms.ToTensor(), 283 | transforms.Normalize(mean=(0.485, 0.456, 0.406), 284 | std=(0.229, 0.224, 0.225)), 285 | ]) 286 | else: 287 | self.transform = transforms.Compose([ 288 | transforms.ToTensor(), 289 | 
transforms.Normalize(mean=(0.485, 0.456, 0.406), 290 | std=(0.229, 0.224, 0.225)), 291 | ]) 292 | self.phase = config.phase 293 | 294 | def __getitem__(self, idx): 295 | pass 296 | 297 | def __len__(self): 298 | return len(self.dataset) 299 | 300 | 301 | # For vanilla train & test 302 | class DatasetMegaDepthTrainAdaptor(DatasetMegaDepthAdaptor): 303 | def __init__(self, dataset, config): 304 | super(DatasetMegaDepthTrainAdaptor, self).__init__(dataset, config) 305 | 306 | def __getitem__(self, idx): 307 | im_src, im_dst, cam_src, cam_dst, _ = self.dataset[idx] 308 | h, w = im_src.shape[:2] 309 | 310 | im1_ori = torch.from_numpy(im_src) 311 | im2_ori = torch.from_numpy(im_dst) 312 | 313 | im1_tensor = self.transform(im_src) 314 | im2_tensor = self.transform(im_dst) 315 | 316 | coord1 = data_utils.generate_query_kpts(im_src, self.config.train_kp, 317 | 10 * self.config.num_pts, h, w) 318 | 319 | # if no keypoints are detected 320 | if len(coord1) == 0: 321 | return None 322 | 323 | # prune query keypoints that are not likely to have correspondence in the other image 324 | coord1 = utils.random_choice(coord1, self.config.num_pts) 325 | coord1 = torch.from_numpy(coord1).float() 326 | 327 | K_src, T_src = cam_src 328 | K_dst, T_dst = cam_dst 329 | 330 | T_src2dst = torch.from_numpy(T_dst.dot(np.linalg.inv(T_src))) 331 | F = compute_fundamental_from_poses(K_src, K_dst, T_src, T_dst) 332 | F = torch.from_numpy(F).float() / (F[-1, -1] + 1e-16) 333 | 334 | out = { 335 | 'im1_ori': im1_ori, 336 | 'im2_ori': im2_ori, 337 | 'intrinsic1': K_src, 338 | 'intrinsic2': K_dst, 339 | 340 | # Additional, for training 341 | 'im1': im1_tensor, 342 | 'im2': im2_tensor, 343 | 'coord1': coord1, 344 | 'F': F, 345 | 'pose': T_src2dst 346 | } 347 | 348 | return out 349 | 350 | 351 | # For SGP train 352 | class DatasetMegaDepthSGPAdaptor(DatasetMegaDepthAdaptor): 353 | def __init__(self, dataset, config): 354 | super(DatasetMegaDepthSGPAdaptor, self).__init__(dataset, config) 355 | 356 | def __getitem__(self, idx): 357 | im1, im2, K_src, K_dst, F = self.dataset[idx] 358 | h, w = im1.shape[:2] 359 | 360 | im1_ori, im2_ori = torch.from_numpy(im1), torch.from_numpy(im2) 361 | 362 | im1_tensor = self.transform(im1) 363 | im2_tensor = self.transform(im2) 364 | 365 | coord1 = data_utils.generate_query_kpts(im1, self.config.train_kp, 366 | 10 * self.config.num_pts, h, w) 367 | 368 | # if no keypoints are detected 369 | if len(coord1) == 0: 370 | return None 371 | 372 | # prune query keypoints that are not likely to have correspondence in the other image 373 | coord1 = utils.random_choice(coord1, self.config.num_pts) 374 | coord1 = torch.from_numpy(coord1).float() 375 | 376 | F = torch.from_numpy(F).float() / (F[-1, -1] + 1e-16) 377 | 378 | out = { 379 | 'im1_ori': im1_ori, 380 | 'im2_ori': im2_ori, 381 | 'intrinsic1': K_src, 382 | 'intrinsic2': K_dst, 383 | 384 | # Additional, for training 385 | 'im1': im1_tensor, 386 | 'im2': im2_tensor, 387 | 'coord1': coord1, 388 | 'F': F, 389 | 390 | # Pose is required in the base but not used in CAPSModel 391 | 'pose': np.eye(4) 392 | } 393 | 394 | return out 395 | 396 | 397 | def align(im_src, im_dst, K_src, K_dst, detector, feature, model, config): 398 | kpts_src = detect_keypoints(im_src, detector, num_kpts=config.num_kpts) 399 | kpts_dst = detect_keypoints(im_dst, detector, num_kpts=config.num_kpts) 400 | 401 | # Too few keypoints 402 | if len(kpts_src) < 5 or len(kpts_dst) < 5: 403 | return np.eye(3), np.eye(3), np.ones((3)), [], [], [], np.zeros((0)) 404 | 405 | feats_src = 
extract_feats(im_src, kpts_src, feature, model) 406 | feats_dst = extract_feats(im_dst, kpts_dst, feature, model) 407 | matches = match_feats(feats_src, feats_dst, feature, 408 | config.match_ratio_test, config.match_ratio_thr) 409 | num_matches = len(matches) 410 | 411 | # Too few matches 412 | if num_matches <= 5: # 5-pts method 413 | return np.eye(3), np.eye(3), np.ones( 414 | (3)), kpts_src, kpts_dst, [], np.zeros((len(matches))) 415 | 416 | E, mask, R, t = estimate_essential(kpts_src, 417 | kpts_dst, 418 | matches, 419 | K_src, 420 | K_dst, 421 | th=config.ransac_thr) 422 | F = np.linalg.inv(K_dst).T.dot(E).dot(np.linalg.inv(K_src)) 423 | F = F / (F[-1, -1] + 1e-16) 424 | 425 | return F, R, t, kpts_src, kpts_dst, matches, mask 426 | 427 | 428 | def caps_train(dataset, config): 429 | # save a copy for the current config in out_folder 430 | out_folder = os.path.join(config.outdir, config.exp_name) 431 | os.makedirs(out_folder, exist_ok=True) 432 | f = os.path.join(out_folder, 'config.txt') 433 | with open(f, 'w') as file: 434 | for arg in vars(config): 435 | attr = getattr(config, arg) 436 | file.write('{} = {}\n'.format(arg, attr)) 437 | 438 | # tensorboard writer 439 | tb_log_dir = os.path.join(config.logdir, config.exp_name) 440 | print('tensorboard log files are stored in {}'.format(tb_log_dir)) 441 | writer = SummaryWriter(tb_log_dir) 442 | 443 | # megadepth data loader 444 | dataloader = torch.utils.data.DataLoader(dataset, 445 | batch_size=config.batch_size, 446 | shuffle=True, 447 | num_workers=config.workers, 448 | collate_fn=my_collate) 449 | 450 | model = CAPSModel(config) 451 | 452 | start_step = model.start_step 453 | dataloader_iter = iter(cycle(dataloader)) 454 | for step in range(start_step + 1, start_step + config.n_iters + 1): 455 | data = next(dataloader_iter) 456 | if data is None: 457 | continue 458 | 459 | model.set_input(data) 460 | model.optimize_parameters() 461 | model.write_summary(writer, step) 462 | if step % config.save_interval == 0 and step > 0: 463 | model.save_model(step) 464 | 465 | 466 | def caps_test(dataset, config): 467 | model = CAPSModel(config) 468 | 469 | r_errs = [] 470 | t_errs = [] 471 | 472 | for data in tqdm(dataset): 473 | im_src, im_dst, cam_src, cam_dst, _ = data 474 | 475 | K_src, T_src = cam_src 476 | K_dst, T_dst = cam_dst 477 | T_src2dst_gt = T_dst.dot(np.linalg.inv(T_src)) 478 | 479 | F, R, t, kpts_src, kpts_dst, matches, mask = align( 480 | im_src, im_dst, K_src, K_dst, 'sift', 'caps', model, config) 481 | 482 | r_err = rotation_error(R, T_src2dst_gt[:3, :3]) 483 | t_err = angular_translation_error(t, T_src2dst_gt[:3, 3]) 484 | r_errs.append(r_err) 485 | t_errs.append(t_err) 486 | 487 | if config.debug: 488 | im = draw_matches(kpts_src, kpts_dst, matches, im_src, im_dst, F, 489 | mask) 490 | cv2.imshow('matches', im) 491 | cv2.waitKey(-1) 492 | 493 | return np.array(r_errs), np.array(t_errs) 494 | -------------------------------------------------------------------------------- /code/perception2d/config_sgp.yml: -------------------------------------------------------------------------------- 1 | # training from scratch using a single gpu, default configs in config.py file 2 | exp_name: caps_sgp 3 | datadir: '/home/wei/Workspace/data/CAPS-MegaDepth-release-light/train' 4 | pseudo_label_dir: 'caps_pseudo_label' 5 | scenes: [0000, 0005] 6 | 7 | sample_rate: 0.2 8 | inlier_ratio_thr: 0.3 9 | n_iters: 40000 10 | save_interval: 10000 11 | -------------------------------------------------------------------------------- 
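The `align` routine in `perception2d/adaptor.py` above estimates an essential matrix E with RANSAC and converts it to a fundamental matrix via F = K_dst^{-T} E K_src^{-1}, normalized by its last element. The standalone sketch below (not part of the repository; intrinsics, pose, and points are all synthetic) checks that this relation satisfies the epipolar constraint x_dst^T F x_src = 0 on noise-free correspondences:

```python
# Sanity check of the F = K_dst^{-T} E K_src^{-1} conversion used in `align`.
# Everything here is synthetic; it is an illustration, not repository code.
import numpy as np
import cv2

rng = np.random.default_rng(0)
K = np.array([[800., 0., 320.], [0., 800., 240.], [0., 0., 1.]])  # shared intrinsics

# Synthetic relative pose (src -> dst) and 3D points in front of the camera
R, _ = cv2.Rodrigues(np.array([0.1, -0.2, 0.05]))
t = np.array([[0.5], [0.1], [0.02]])
X_src = rng.uniform([-1., -1., 4.], [1., 1., 8.], size=(100, 3))

# Project into both views
x_src = (K @ X_src.T).T
x_src = x_src[:, :2] / x_src[:, 2:]
X_dst = (R @ X_src.T + t).T
x_dst = (K @ X_dst.T).T
x_dst = x_dst[:, :2] / x_dst[:, 2:]

# Essential from pose (E = [t]_x R), then fundamental exactly as in `align`
t_hat = np.array([[0., -t[2, 0], t[1, 0]],
                  [t[2, 0], 0., -t[0, 0]],
                  [-t[1, 0], t[0, 0], 0.]])
E = t_hat @ R
F = np.linalg.inv(K).T @ E @ np.linalg.inv(K)
F = F / F[-1, -1]

# Epipolar residuals x_dst^T F x_src should be ~0 for noise-free matches
x1h = np.hstack([x_src, np.ones((len(x_src), 1))])
x2h = np.hstack([x_dst, np.ones((len(x_dst), 1))])
residuals = np.abs(np.sum(x2h * (F @ x1h.T).T, axis=1))
print('max |x_dst^T F x_src| =', residuals.max())
```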
/code/perception2d/config_sgp_sample.yml: -------------------------------------------------------------------------------- 1 | # training from scratch using a single gpu, default configs in config.py file 2 | exp_name: caps_sgp 3 | datadir: '/home/wei/Workspace/data/CAPS-MegaDepth-release-light/train' 4 | pseudo_label_dir: 'caps_pseudo_label' 5 | scenes: [sample] 6 | 7 | sample_rate: 0.5 8 | inlier_ratio_thr: 0.3 9 | n_iters: 10 10 | save_interval: 10 11 | -------------------------------------------------------------------------------- /code/perception2d/config_test.yml: -------------------------------------------------------------------------------- 1 | # training from scratch using a single gpu, default configs in config.py file 2 | exp_name: caps_test 3 | datadir: '/home/wei/Workspace/data/CAPS-MegaDepth-release-light/test' 4 | label_dir: '/home/wei/Workspace/data/CAPS-MegaDepth-release-light/test-pose' 5 | ckpt_path: '/home/wei/Downloads/caps-pretrained.pth' 6 | scenes: [easy] 7 | -------------------------------------------------------------------------------- /code/perception2d/config_train.yml: -------------------------------------------------------------------------------- 1 | # training from scratch using a single gpu, default configs in config.py file 2 | exp_name: caps_train 3 | datadir: /home/wei/Workspace/data/CAPS-MegaDepth-release-light/train 4 | scenes: [1001, 0020] 5 | -------------------------------------------------------------------------------- /code/perception2d/sgp.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | file_path = os.path.abspath(__file__) 4 | project_path = os.path.dirname(os.path.dirname(file_path)) 5 | sys.path.append(project_path) 6 | 7 | caps_path = os.path.join(project_path, 'ext', 'caps') 8 | sys.path.append(caps_path) 9 | 10 | import cv2 11 | import torch 12 | 13 | from sgp_base import SGPBase 14 | from dataset.megadepth_sgp import DatasetMegaDepthSGP 15 | from perception2d.adaptor import CAPSConfigParser, DatasetMegaDepthSGPAdaptor, CAPSModel, caps_train, caps_test, align 16 | from geometry.image import * 17 | 18 | 19 | class SGP2DFundamental(SGPBase): 20 | def __init__(self): 21 | super(SGP2DFundamental, self).__init__() 22 | 23 | # override 24 | def perception_bootstrap(self, src_data, dst_data, src_info, dst_info, 25 | config): 26 | F, R, t, kpts_src, kpts_dst, matches, mask = align( 27 | src_data, dst_data, src_info, dst_info, 'sift', 'sift', None, 28 | config) 29 | 30 | if config.debug: 31 | im = draw_matches(kpts_src, kpts_dst, matches, src_data, dst_data, 32 | F, mask) 33 | cv2.imshow('matches', im) 34 | cv2.waitKey(-1) 35 | 36 | return F, (mask.sum(), len(matches)) 37 | 38 | # override 39 | def perception(self, src_data, dst_data, src_info, dst_info, model, 40 | config): 41 | F, R, t, kpts_src, kpts_dst, matches, mask = align( 42 | src_data, dst_data, src_info, dst_info, 'sift', 'caps', model, 43 | config) 44 | 45 | if config.debug: 46 | im = draw_matches(kpts_src, kpts_dst, matches, src_data, dst_data, 47 | F, mask) 48 | cv2.imshow('matches', im) 49 | cv2.waitKey(-1) 50 | 51 | return F, (mask.sum(), len(matches)) 52 | 53 | # override 54 | def train_adaptor(self, sgp_dataset, config): 55 | caps_train(sgp_dataset, config) 56 | 57 | def run(self, config): 58 | base_outdir = config.outdir 59 | base_logdir = config.logdir 60 | base_pseudo_label_dir = config.pseudo_label_dir 61 | 62 | pseudo_label_path_bs = os.path.join(base_pseudo_label_dir, 'bs') 63 | 64 | if 
config.restart_meta_iter < 0: 65 | # Only sample a subset for teaching. 66 | teach_dataset = DatasetMegaDepthSGP(config.datadir, 67 | config.scenes, 68 | pseudo_label_path_bs, 69 | 'teaching', 70 | inlier_ratio_thr=config.inlier_ratio_thr, 71 | num_matches_thr=config.num_matches_thr, 72 | sample_rate=config.sample_rate) 73 | print('Dataset size: {}'.format(len(teach_dataset))) 74 | sgp.teach_bootstrap(teach_dataset, config) 75 | 76 | learn_dataset = DatasetMegaDepthSGPAdaptor( 77 | DatasetMegaDepthSGP(config.datadir, 78 | config.scenes, 79 | pseudo_label_path_bs, 80 | 'learning', 81 | inlier_ratio_thr=config.inlier_ratio_thr, 82 | num_matches_thr=config.num_matches_thr, 83 | sample_rate=1), config) 84 | config.outdir = os.path.join(base_outdir, 'bs') 85 | config.logdir = os.path.join(base_logdir, 'bs') 86 | sgp.learn(learn_dataset, config) 87 | 88 | config.match_ratio_test = False 89 | start_meta_iter = max(config.restart_meta_iter, 0) 90 | for i in range(start_meta_iter, config.max_meta_iters): 91 | pseudo_label_path_i = os.path.join(base_pseudo_label_dir, 92 | '{:02d}'.format(i)) 93 | teach_dataset = DatasetMegaDepthSGP(config.datadir, 94 | config.scenes, 95 | pseudo_label_path_i, 96 | 'teaching', 97 | inlier_ratio_thr=config.inlier_ratio_thr, 98 | num_matches_thr=config.num_matches_thr, 99 | sample_rate=config.sample_rate) 100 | model = CAPSModel(config) 101 | sgp.teach(teach_dataset, model, config) 102 | 103 | learn_dataset = DatasetMegaDepthSGPAdaptor( 104 | DatasetMegaDepthSGP(config.datadir, 105 | config.scenes, 106 | pseudo_label_path_i, 107 | 'learning', 108 | inlier_ratio_thr=config.inlier_ratio_thr, 109 | num_matches_thr=config.num_matches_thr, 110 | sample_rate=1), config) 111 | 112 | if not config.finetune: 113 | config.outdir = os.path.join(base_outdir, '{:02d}'.format(i)) 114 | config.logdir = os.path.join(base_logdir, '{:02d}'.format(i)) 115 | sgp.learn(learn_dataset, config) 116 | 117 | 118 | if __name__ == '__main__': 119 | parser = CAPSConfigParser() 120 | parser.add( 121 | '--config', 122 | is_config_file=True, 123 | default=os.path.join(os.path.dirname(__file__), 124 | 'config_sgp_sample.yml'), 125 | help='YAML config file path. Please refer to caps_config.yml as a ' 126 | 'reference. It overrides the default config file, but will be ' 127 | 'overridden by other command line inputs.') 128 | parser.add('--debug', action='store_true') 129 | config = parser.get_config() 130 | 131 | sgp = SGP2DFundamental() 132 | sgp.run(config) 133 | -------------------------------------------------------------------------------- /code/perception2d/test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | file_path = os.path.abspath(__file__) 4 | project_path = os.path.dirname(os.path.dirname(file_path)) 5 | sys.path.append(project_path) 6 | 7 | caps_path = os.path.join(project_path, 'ext', 'caps') 8 | sys.path.append(caps_path) 9 | 10 | from dataset.megadepth_test import DatasetMegaDepthTest 11 | from perception2d.adaptor import CAPSConfigParser, caps_test 12 | 13 | import numpy as np 14 | 15 | if __name__ == '__main__': 16 | parser = CAPSConfigParser() 17 | parser.add( 18 | '--config', 19 | is_config_file=True, 20 | default=os.path.join(os.path.dirname(__file__), 'config_test.yml'), 21 | help='YAML config file path. Please refer to caps_config.yml as a ' 22 | 'reference. 
It overrides the default config file, but will be ' 23 | 'overridden by other command line inputs.') 24 | parser.add('--debug', action='store_true') 25 | parser.add('--output', type=str, default='caps_test_result.npz') 26 | config = parser.get_config() 27 | 28 | # Note: for testing, our own interface suffices. 29 | config.match_ratio_test = False 30 | dataset = DatasetMegaDepthTest(config.datadir, config.scenes, config.label_dir) 31 | r_errs, t_errs = caps_test(dataset, config) 32 | 33 | rot_recall = (r_errs < 10.0) 34 | angular_trans_recall = (t_errs < 10.0) 35 | print('Rotation Recall: {}/{} = {}'.format( 36 | rot_recall.sum(), len(rot_recall), 37 | float(rot_recall.sum()) / len(rot_recall))) 38 | print('Translation Recall: {}/{} = {}'.format( 39 | angular_trans_recall.sum(), len(angular_trans_recall), 40 | float(angular_trans_recall.sum()) / len(angular_trans_recall))) 41 | 42 | np.savez(config.output, rotation_errs=r_errs, translation_errs=t_errs) 43 | -------------------------------------------------------------------------------- /code/perception2d/train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | file_path = os.path.abspath(__file__) 4 | project_path = os.path.dirname(os.path.dirname(file_path)) 5 | sys.path.append(project_path) 6 | 7 | caps_path = os.path.join(project_path, 'ext', 'caps') 8 | sys.path.append(caps_path) 9 | 10 | from dataset.megadepth_train import DatasetMegaDepthTrain 11 | from perception2d.adaptor import CAPSConfigParser, DatasetMegaDepthTrainAdaptor, caps_train 12 | 13 | if __name__ == '__main__': 14 | parser = CAPSConfigParser() 15 | parser.add( 16 | '--config', 17 | is_config_file=True, 18 | default=os.path.join(os.path.dirname(__file__), 'config_train.yml'), 19 | help='YAML config file path. Please refer to caps_config.yml as a ' 20 | 'reference. It overrides the default config file, but will be ' 21 | 'overridden by other command line inputs.') 22 | config = parser.get_config() 23 | 24 | # Note: for training, we need to wrap the dataset in an adaptor to provide a consistent interface.
25 | dataset = DatasetMegaDepthTrainAdaptor( 26 | DatasetMegaDepthTrain(config.datadir, config.scenes), config) 27 | caps_train(dataset, config) 28 | -------------------------------------------------------------------------------- /code/perception3d/adaptor.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | file_path = os.path.abspath(__file__) 4 | project_path = os.path.dirname(os.path.dirname(file_path)) 5 | sys.path.append(project_path) 6 | 7 | fcgf_path = os.path.join(project_path, 'ext', 'FCGF') 8 | sys.path.append(fcgf_path) 9 | 10 | import torch 11 | from easydict import EasyDict as edict 12 | 13 | from ext.FCGF.lib.data_loaders import * 14 | from ext.FCGF.lib.trainer import * 15 | from ext.FCGF.model import load_model 16 | 17 | from dataset.threedmatch_train import Dataset3DMatchTrain 18 | from dataset.threedmatch_test import Dataset3DMatchTest 19 | from dataset.threedmatch_sgp import Dataset3DMatchSGP 20 | 21 | from geometry.pointcloud import * 22 | from geometry.common import rotation_error, translation_error 23 | 24 | from tqdm import tqdm 25 | 26 | import configargparse 27 | 28 | 29 | def reload_config(config): 30 | dconfig = vars(config) 31 | 32 | if config.resume_dir: 33 | resume_config = json.load(open(config.resume_dir + '/config.json', 34 | 'r')) 35 | for k in dconfig: 36 | if k not in ['resume_dir'] and k in resume_config: 37 | dconfig[k] = resume_config[k] 38 | dconfig['resume'] = resume_config['out_dir'] + '/checkpoint.pth' 39 | 40 | return edict(dconfig) 41 | 42 | 43 | class FCGFConfigParser(configargparse.ArgParser): 44 | def __init__(self): 45 | super().__init__(default_config_files=[ 46 | os.path.join(os.path.dirname(__file__), 'fcgf_config.yml') 47 | ], 48 | conflict_handler='resolve') 49 | 50 | # Mainly used params 51 | self.add('--dataset_path', 52 | type=str, 53 | default="/home/wei/Workspace/data/threedmatch_reorg") 54 | self.add('--scenes', 55 | nargs='+', 56 | help='scenes used for training/testing') 57 | self.add('--out_dir', 58 | type=str, 59 | default='fcgf_outputs', 60 | help='outputs containing summary and checkpoints') 61 | 62 | self.add( 63 | '--restart_meta_iter', 64 | type=int, 65 | default=-1, 66 | help='Restart of teacher-student iterations. -1 indicates bootstrap' 67 | ) 68 | self.add('--meta_iters', 69 | type=int, 70 | default=2, 71 | help='number of teacher-student iterations') 72 | self.add('--finetune', 73 | action='store_true', 74 | help='train from previous checkpoint during SGP.') 75 | 76 | self.add('--pseudo_label_dir', 77 | type=str, 78 | help='the pseudo-gt directory storing pairs and T matrices') 79 | 80 | self.add('--overlap_thr', 81 | type=float, 82 | default=0.3, 83 | help='overlap threshold to filter outlier pairs') 84 | self.add('--voxel_size', type=float, default=0.05) 85 | self.add('--mutual_filter', type=bool, default=False) 86 | self.add('--ransac_iters', type=int, default=10000) 87 | self.add('--confidence', type=float, default=0.9999) 88 | 89 | # Other core configs from the FCGF repo. 
90 | # See https://github.com/chrischoy/FCGF/blob/master/config.py 91 | self.add('--trainer', 92 | type=str, 93 | default='HardestContrastiveLossTrainer') 94 | self.add('--save_freq_epoch', type=int, default=1) 95 | self.add('--batch_size', type=int, default=4) 96 | self.add('--val_batch_size', type=int, default=1) 97 | 98 | # Hard negative mining 99 | self.add('--use_hard_negative', type=bool, default=True) 100 | self.add('--hard_negative_sample_ratio', type=int, default=0.05) 101 | self.add('--hard_negative_max_num', type=int, default=3000) 102 | self.add('--num_pos_per_batch', type=int, default=1024) 103 | self.add('--num_hn_samples_per_batch', type=int, default=256) 104 | 105 | # Metric learning loss 106 | self.add('--neg_thresh', type=float, default=1.4) 107 | self.add('--pos_thresh', type=float, default=0.1) 108 | self.add('--neg_weight', type=float, default=1) 109 | 110 | # Data augmentation 111 | self.add('--use_random_scale', type=bool, default=False) 112 | self.add('--min_scale', type=float, default=0.8) 113 | self.add('--max_scale', type=float, default=1.2) 114 | self.add('--use_random_rotation', type=bool, default=True) 115 | self.add('--rotation_range', type=float, default=360) 116 | 117 | # Data loader configs 118 | self.add('--train_phase', type=str, default="train") 119 | self.add('--val_phase', type=str, default="val") 120 | self.add('--test_phase', type=str, default="test") 121 | 122 | self.add('--stat_freq', type=int, default=40) 123 | self.add('--test_valid', type=bool, default=True) 124 | self.add('--val_max_iter', type=int, default=400) 125 | self.add('--val_epoch_freq', type=int, default=1) 126 | self.add('--positive_pair_search_voxel_size_multiplier', 127 | type=float, 128 | default=1.5) 129 | 130 | self.add('--hit_ratio_thresh', type=float, default=0.1) 131 | 132 | # Triplets 133 | self.add('--triplet_num_pos', type=int, default=256) 134 | self.add('--triplet_num_hn', type=int, default=512) 135 | self.add('--triplet_num_rand', type=int, default=1024) 136 | 137 | # Network specific configurations 138 | self.add('--model', type=str, default='ResUNetBN2C') 139 | self.add('--model_n_out', 140 | type=int, 141 | default=32, 142 | help='Feature dimension') 143 | self.add('--conv1_kernel_size', type=int, default=5) 144 | self.add('--normalize_feature', type=bool, default=True) 145 | self.add('--dist_type', type=str, default='L2') 146 | self.add('--best_val_metric', type=str, default='feat_match_ratio') 147 | 148 | # Optimizer arguments 149 | self.add('--optimizer', type=str, default='SGD') 150 | self.add('--max_epoch', type=int, default=100) 151 | self.add('--lr', type=float, default=1e-1) 152 | self.add('--momentum', type=float, default=0.8) 153 | self.add('--sgd_momentum', type=float, default=0.9) 154 | self.add('--sgd_dampening', type=float, default=0.1) 155 | self.add('--adam_beta1', type=float, default=0.9) 156 | self.add('--adam_beta2', type=float, default=0.999) 157 | self.add('--weight_decay', type=float, default=1e-4) 158 | self.add('--iter_size', 159 | type=int, 160 | default=1, 161 | help='accumulate gradient') 162 | self.add('--bn_momentum', type=float, default=0.05) 163 | self.add('--exp_gamma', type=float, default=0.99) 164 | self.add('--scheduler', type=str, default='ExpLR') 165 | 166 | self.add('--use_gpu', type=bool, default=True) 167 | self.add('--weights', type=str, default=None) 168 | self.add('--weights_dir', type=str, default=None) 169 | self.add('--resume', type=str, default=None) 170 | self.add('--resume_dir', type=str, default=None) 171 | 
self.add('--train_num_thread', type=int, default=2) 172 | self.add('--val_num_thread', type=int, default=1) 173 | self.add('--test_num_thread', type=int, default=2) 174 | self.add('--fast_validation', type=bool, default=False) 175 | self.add( 176 | '--nn_max_n', 177 | type=int, 178 | default=500, 179 | help= 180 | 'The maximum number of features to find nearest neighbors in batch' 181 | ) 182 | 183 | def get_config(self): 184 | config = self.parse_args() 185 | config.device = 'cuda' if config.use_gpu else 'cpu' 186 | 187 | return reload_config(config) 188 | 189 | 190 | def get_trainer(trainer): 191 | if trainer == 'ContrastiveLossTrainer': 192 | return ContrastiveLossTrainer 193 | elif trainer == 'HardestContrastiveLossTrainer': 194 | return HardestContrastiveLossTrainer 195 | elif trainer == 'TripletLossTrainer': 196 | return TripletLossTrainer 197 | elif trainer == 'HardestTripletLossTrainer': 198 | return HardestTripletLossTrainer 199 | else: 200 | raise ValueError(f'Trainer {trainer} not found') 201 | 202 | 203 | class DatasetFCGFAdaptor(torch.utils.data.Dataset): 204 | ''' 205 | Wrapper dataset for our data format and FCGF's sample format 206 | ''' 207 | def __init__(self, dataset, config): 208 | self.dataset = dataset 209 | self.randg = np.random.RandomState() 210 | self.config = config 211 | 212 | def reset_seed(self, seed): 213 | self.randg.seed(seed) 214 | 215 | def apply_transform(self, pts, trans): 216 | R = trans[:3, :3] 217 | T = trans[:3, 3] 218 | pts = pts @ R.T + T 219 | return pts 220 | 221 | def __len__(self): 222 | return len(self.dataset) 223 | 224 | def __getitem__(self, idx): 225 | pcd0, pcd1, _, _, trans = self.dataset[idx] 226 | 227 | xyz0 = np.asarray(pcd0.points) 228 | xyz1 = np.asarray(pcd1.points) 229 | 230 | # Data augmentation 231 | T0 = sample_random_trans(xyz0, self.randg, 360) 232 | T1 = sample_random_trans(xyz1, self.randg, 360) 233 | trans = T1 @ trans @ np.linalg.inv(T0) 234 | 235 | xyz0 = self.apply_transform(xyz0, T0) 236 | xyz1 = self.apply_transform(xyz1, T1) 237 | 238 | # Voxelization after random transformation 239 | voxel_size = 0.05 240 | _, sel0 = ME.utils.sparse_quantize(xyz0, 241 | return_index=True, 242 | quantization_size=voxel_size) 243 | _, sel1 = ME.utils.sparse_quantize(xyz1, 244 | return_index=True, 245 | quantization_size=voxel_size) 246 | xyz0 = xyz0[sel0] 247 | xyz1 = xyz1[sel1] 248 | 249 | # Make point clouds using voxelized points 250 | pcd0 = make_o3d_pointcloud(xyz0) 251 | pcd1 = make_o3d_pointcloud(xyz1) 252 | matches = get_matching_indices(pcd0, pcd1, trans, voxel_size * 2) 253 | 254 | # Dummy features 255 | feats0 = np.ones((xyz0.shape[0], 1)) 256 | feats1 = np.ones((xyz1.shape[0], 1)) 257 | 258 | # Coordinates 259 | coords0 = np.floor(xyz0 / voxel_size) 260 | coords1 = np.floor(xyz1 / voxel_size) 261 | 262 | return (xyz0, xyz1, coords0, coords1, feats0, feats1, matches, trans) 263 | 264 | 265 | def load_fcgf_model(config): 266 | resume_ckpt_path = config.resume 267 | input_ckpt_path = config.weights 268 | out_ckpt_path = os.path.join(config.out_dir, 'checkpoint.pth') 269 | 270 | if resume_ckpt_path is not None and os.path.isfile(resume_ckpt_path): 271 | ckpt_path = resume_ckpt_path 272 | elif input_ckpt_path is not None and os.path.isfile(input_ckpt_path): 273 | ckpt_path = input_ckpt_path 274 | elif out_ckpt_path is not None and os.path.isfile(out_ckpt_path): 275 | ckpt_path = out_ckpt_path 276 | else: 277 | raise NotImplementedError('checkpoint not found, abort') 278 | 279 | print(f'load FCGF from checkpoint {ckpt_path}.') 
280 | checkpoint = torch.load(ckpt_path) 281 | ckpt_cfg = checkpoint['config'] 282 | 283 | Model = load_model(ckpt_cfg['model']) 284 | model = Model(in_channels=1, 285 | out_channels=ckpt_cfg['model_n_out'], 286 | bn_momentum=ckpt_cfg['bn_momentum'], 287 | normalize_feature=ckpt_cfg['normalize_feature'], 288 | conv1_kernel_size=ckpt_cfg['conv1_kernel_size'], 289 | D=3) 290 | model.load_state_dict(checkpoint['state_dict']) 291 | return model.to(config.device) 292 | 293 | 294 | def register(pcd_src, pcd_dst, feature, solver, model, config): 295 | pcd_src, feat_src = extract_feats(pcd_src, feature, config.voxel_size, 296 | model) 297 | pcd_dst, feat_dst = extract_feats(pcd_dst, feature, config.voxel_size, 298 | model) 299 | corrs = match_feats(feat_src, feat_dst, mutual_filter=config.mutual_filter) 300 | 301 | if len(corrs) < 10: 302 | print('Too few corres ({}), abort'.format(len(corrs))) 303 | return np.eye(4), 0 304 | 305 | T, fitness = solve(pcd_src, pcd_dst, corrs, solver, 306 | config.voxel_size * 1.4, config.ransac_iters, 307 | config.confidence) 308 | if fitness > 1e-6: 309 | T, fitness = refine(pcd_src, pcd_dst, T, config.voxel_size * 1.4) 310 | 311 | return T, fitness 312 | 313 | 314 | def fcgf_train(dataset, config): 315 | ch = logging.StreamHandler(sys.stdout) 316 | logging.getLogger().setLevel(logging.INFO) 317 | logging.basicConfig(format='%(asctime)s %(message)s', 318 | datefmt='%m/%d %H:%M:%S', 319 | handlers=[ch]) 320 | 321 | torch.manual_seed(0) 322 | torch.cuda.manual_seed(0) 323 | 324 | logging.basicConfig(level=logging.INFO, format="") 325 | 326 | dataloader = torch.utils.data.DataLoader(dataset, 327 | batch_size=8, 328 | shuffle=True, 329 | num_workers=8, 330 | collate_fn=collate_pair_fn, 331 | pin_memory=False, 332 | drop_last=True) 333 | Trainer = get_trainer(config.trainer) 334 | trainer = Trainer(config=config, data_loader=dataloader) 335 | trainer.train() 336 | 337 | 338 | def fcgf_test(dataset, config): 339 | model = load_fcgf_model(config) 340 | 341 | r_errs = [] 342 | t_errs = [] 343 | 344 | for data in tqdm(dataset): 345 | pcd_src, pcd_dst, _, _, T_gt = data 346 | 347 | T, fitness = register(pcd_src, pcd_dst, 'FCGF', 'RANSAC', model, 348 | config) 349 | r_err = rotation_error(T[:3, :3], T_gt[:3, :3]) 350 | t_err = translation_error(T[:3, 3], T_gt[:3, 3]) 351 | r_errs.append(r_err) 352 | t_errs.append(t_err) 353 | 354 | if config.debug: 355 | pcd_src.paint_uniform_color([1, 0, 0]) 356 | pcd_dst.paint_uniform_color([0, 1, 0]) 357 | o3d.visualization.draw_geometries([pcd_src.transform(T), pcd_dst]) 358 | 359 | return np.array(r_errs), np.array(t_errs) 360 | -------------------------------------------------------------------------------- /code/perception3d/config_sgp.yml: -------------------------------------------------------------------------------- 1 | dataset_path: '/home/wei/Workspace/data/threedmatch_reorg' 2 | pseudo_label_dir: fcgf_pseudo_label 3 | scenes: [7-scenes-chess@seq-01, 7-scenes-chess@seq-02, 7-scenes-chess@seq-03, 7-scenes-chess@seq-04, 7-scenes-chess@seq-05, 7-scenes-chess@seq-06, 7-scenes-fire@seq-01, 7-scenes-fire@seq-02, 7-scenes-fire@seq-03, 7-scenes-fire@seq-04, 7-scenes-heads@seq-01, 7-scenes-heads@seq-02, 7-scenes-office@seq-01, 7-scenes-office@seq-02, 7-scenes-office@seq-03, 7-scenes-office@seq-04, 7-scenes-office@seq-05, 7-scenes-office@seq-06, 7-scenes-office@seq-07, 7-scenes-office@seq-08, 7-scenes-office@seq-09, 7-scenes-office@seq-10, 7-scenes-pumpkin@seq-01, 7-scenes-pumpkin@seq-02, 7-scenes-pumpkin@seq-03, 7-scenes-pumpkin@seq-06, 
7-scenes-pumpkin@seq-07, 7-scenes-pumpkin@seq-08, 7-scenes-redkitchen@seq-01, 7-scenes-redkitchen@seq-02, 7-scenes-redkitchen@seq-03, 7-scenes-redkitchen@seq-04, 7-scenes-redkitchen@seq-05, 7-scenes-redkitchen@seq-06, 7-scenes-redkitchen@seq-07, 7-scenes-redkitchen@seq-08, 7-scenes-redkitchen@seq-11, 7-scenes-redkitchen@seq-12, 7-scenes-redkitchen@seq-13, 7-scenes-redkitchen@seq-14, 7-scenes-stairs@seq-01, 7-scenes-stairs@seq-02, 7-scenes-stairs@seq-03, 7-scenes-stairs@seq-04, 7-scenes-stairs@seq-05, 7-scenes-stairs@seq-06, analysis-by-synthesis-apt1-kitchen@seq-01, analysis-by-synthesis-apt1-living@seq-01, analysis-by-synthesis-apt2-bed@seq-01, analysis-by-synthesis-apt2-kitchen@seq-01, analysis-by-synthesis-apt2-living@seq-01, analysis-by-synthesis-apt2-luke@seq-01, analysis-by-synthesis-office2-5a@seq-01, analysis-by-synthesis-office2-5b@seq-01, bundlefusion-apt0@seq-01, bundlefusion-apt1@seq-01, bundlefusion-apt2@seq-01, bundlefusion-copyroom@seq-01, bundlefusion-office0@seq-01, bundlefusion-office1@seq-01, bundlefusion-office2@seq-01, bundlefusion-office3@seq-01, rgbd-scenes-v2-scene_01@seq-01, rgbd-scenes-v2-scene_02@seq-01, rgbd-scenes-v2-scene_03@seq-01, rgbd-scenes-v2-scene_04@seq-01, rgbd-scenes-v2-scene_05@seq-01, rgbd-scenes-v2-scene_06@seq-01, rgbd-scenes-v2-scene_07@seq-01, rgbd-scenes-v2-scene_08@seq-01, rgbd-scenes-v2-scene_09@seq-01, rgbd-scenes-v2-scene_10@seq-01, rgbd-scenes-v2-scene_11@seq-01, rgbd-scenes-v2-scene_12@seq-01, rgbd-scenes-v2-scene_13@seq-01, rgbd-scenes-v2-scene_14@seq-01, sun3d-brown_bm_1-brown_bm_1@seq-01, sun3d-brown_bm_4-brown_bm_4@seq-01, sun3d-brown_cogsci_1-brown_cogsci_1@seq-01, sun3d-brown_cs_2-brown_cs2@seq-01, sun3d-brown_cs_3-brown_cs3@seq-01, sun3d-harvard_c11-hv_c11_2@seq-01, sun3d-harvard_c3-hv_c3_1@seq-01, sun3d-harvard_c5-hv_c5_1@seq-01, sun3d-harvard_c6-hv_c6_1@seq-01, sun3d-harvard_c8-hv_c8_3@seq-01, sun3d-home_bksh-home_bksh_oct_30_2012_scan2_erika@seq-01, sun3d-home_md-home_md_scan9_2012_sep_30@seq-01, sun3d-hotel_nips2012-nips_4@seq-01, sun3d-hotel_sf-scan1@seq-01, sun3d-hotel_uc-scan3@seq-01, sun3d-hotel_umd-maryland_hotel1@seq-01, sun3d-hotel_umd-maryland_hotel3@seq-01, sun3d-mit_32_d507-d507_2@seq-01, sun3d-mit_46_ted_lab1-ted_lab_2@seq-01, sun3d-mit_76_417-76-417b@seq-01, sun3d-mit_76_studyroom-76-1studyroom2@seq-01, sun3d-mit_dorm_next_sj-dorm_next_sj_oct_30_2012_scan1_erika@seq-01, sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika@seq-01, sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika@seq-01] 4 | -------------------------------------------------------------------------------- /code/perception3d/config_sgp_sample.yml: -------------------------------------------------------------------------------- 1 | dataset_path: '/home/wei/Workspace/data/threedmatch_reorg' 2 | pseudo_label_dir: fcgf_pseudo_label 3 | scenes: [7-scenes-chess@seq-01] 4 | -------------------------------------------------------------------------------- /code/perception3d/config_test.yml: -------------------------------------------------------------------------------- 1 | dataset_path: '/home/wei/Workspace/data/threedmatch_test' 2 | weights: '/home/wei/Downloads/fcgf-pretrained.pth' 3 | scenes: [7-scenes-redkitchen, sun3d-home_at-home_at_scan1_2013_jan_1, sun3d-home_md-home_md_scan9_2012_sep_30, sun3d-hotel_uc-scan3, sun3d-hotel_umd-maryland_hotel1, sun3d-hotel_umd-maryland_hotel3, sun3d-mit_76_studyroom-76-1studyroom2, sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika] 4 | -------------------------------------------------------------------------------- 
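`fcgf_test` in `perception3d/adaptor.py` above reports per-pair rotation and translation errors through helpers imported from `geometry/common.py`, which is not reproduced in this listing. For reference, a common formulation of these two metrics is sketched below; the repository's own implementation may differ in detail:

```python
# Reference-only sketch of typical registration error metrics; the actual helpers
# used by fcgf_test live in geometry/common.py and may not match this exactly.
import numpy as np

def rotation_error_deg(R_est, R_gt):
    # Geodesic distance on SO(3): arccos((trace(R_est^T R_gt) - 1) / 2), in degrees.
    cos_angle = (np.trace(R_est.T @ R_gt) - 1.0) / 2.0
    return np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0)))

def translation_error(t_est, t_gt):
    # Euclidean distance between translation vectors (meters for 3DMatch data).
    return np.linalg.norm(np.asarray(t_est) - np.asarray(t_gt))

# perception3d/test.py then counts a pair as registered when both errors are small,
# using 15 degrees and 0.3 m as thresholds.
```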
/code/perception3d/config_train.yml: -------------------------------------------------------------------------------- 1 | dataset_path: '/home/wei/Workspace/data/threedmatch_reorg' 2 | scenes: [7-scenes-chess@seq-01, 7-scenes-chess@seq-02, 7-scenes-chess@seq-03, 7-scenes-chess@seq-04, 7-scenes-chess@seq-05, 7-scenes-chess@seq-06, 7-scenes-fire@seq-01, 7-scenes-fire@seq-02, 7-scenes-fire@seq-03, 7-scenes-fire@seq-04, 7-scenes-heads@seq-01, 7-scenes-heads@seq-02, 7-scenes-office@seq-01, 7-scenes-office@seq-02, 7-scenes-office@seq-03, 7-scenes-office@seq-04, 7-scenes-office@seq-05, 7-scenes-office@seq-06, 7-scenes-office@seq-07, 7-scenes-office@seq-08, 7-scenes-office@seq-09, 7-scenes-office@seq-10, 7-scenes-pumpkin@seq-01, 7-scenes-pumpkin@seq-02, 7-scenes-pumpkin@seq-03, 7-scenes-pumpkin@seq-06, 7-scenes-pumpkin@seq-07, 7-scenes-pumpkin@seq-08, 7-scenes-redkitchen@seq-01, 7-scenes-redkitchen@seq-02, 7-scenes-redkitchen@seq-03, 7-scenes-redkitchen@seq-04, 7-scenes-redkitchen@seq-05, 7-scenes-redkitchen@seq-06, 7-scenes-redkitchen@seq-07, 7-scenes-redkitchen@seq-08, 7-scenes-redkitchen@seq-11, 7-scenes-redkitchen@seq-12, 7-scenes-redkitchen@seq-13, 7-scenes-redkitchen@seq-14, 7-scenes-stairs@seq-01, 7-scenes-stairs@seq-02, 7-scenes-stairs@seq-03, 7-scenes-stairs@seq-04, 7-scenes-stairs@seq-05, 7-scenes-stairs@seq-06, analysis-by-synthesis-apt1-kitchen@seq-01, analysis-by-synthesis-apt1-living@seq-01, analysis-by-synthesis-apt2-bed@seq-01, analysis-by-synthesis-apt2-kitchen@seq-01, analysis-by-synthesis-apt2-living@seq-01, analysis-by-synthesis-apt2-luke@seq-01, analysis-by-synthesis-office2-5a@seq-01, analysis-by-synthesis-office2-5b@seq-01, bundlefusion-apt0@seq-01, bundlefusion-apt1@seq-01, bundlefusion-apt2@seq-01, bundlefusion-copyroom@seq-01, bundlefusion-office0@seq-01, bundlefusion-office1@seq-01, bundlefusion-office2@seq-01, bundlefusion-office3@seq-01, rgbd-scenes-v2-scene_01@seq-01, rgbd-scenes-v2-scene_02@seq-01, rgbd-scenes-v2-scene_03@seq-01, rgbd-scenes-v2-scene_04@seq-01, rgbd-scenes-v2-scene_05@seq-01, rgbd-scenes-v2-scene_06@seq-01, rgbd-scenes-v2-scene_07@seq-01, rgbd-scenes-v2-scene_08@seq-01, rgbd-scenes-v2-scene_09@seq-01, rgbd-scenes-v2-scene_10@seq-01, rgbd-scenes-v2-scene_11@seq-01, rgbd-scenes-v2-scene_12@seq-01, rgbd-scenes-v2-scene_13@seq-01, rgbd-scenes-v2-scene_14@seq-01, sun3d-brown_bm_1-brown_bm_1@seq-01, sun3d-brown_bm_4-brown_bm_4@seq-01, sun3d-brown_cogsci_1-brown_cogsci_1@seq-01, sun3d-brown_cs_2-brown_cs2@seq-01, sun3d-brown_cs_3-brown_cs3@seq-01, sun3d-harvard_c11-hv_c11_2@seq-01, sun3d-harvard_c3-hv_c3_1@seq-01, sun3d-harvard_c5-hv_c5_1@seq-01, sun3d-harvard_c6-hv_c6_1@seq-01, sun3d-harvard_c8-hv_c8_3@seq-01, sun3d-home_bksh-home_bksh_oct_30_2012_scan2_erika@seq-01, sun3d-home_md-home_md_scan9_2012_sep_30@seq-01, sun3d-hotel_nips2012-nips_4@seq-01, sun3d-hotel_sf-scan1@seq-01, sun3d-hotel_uc-scan3@seq-01, sun3d-hotel_umd-maryland_hotel1@seq-01, sun3d-hotel_umd-maryland_hotel3@seq-01, sun3d-mit_32_d507-d507_2@seq-01, sun3d-mit_46_ted_lab1-ted_lab_2@seq-01, sun3d-mit_76_417-76-417b@seq-01, sun3d-mit_76_studyroom-76-1studyroom2@seq-01, sun3d-mit_dorm_next_sj-dorm_next_sj_oct_30_2012_scan1_erika@seq-01, sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika@seq-01, sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika@seq-01] 3 | -------------------------------------------------------------------------------- /code/perception3d/sgp.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | file_path = 
os.path.abspath(__file__) 4 | project_path = os.path.dirname(os.path.dirname(file_path)) 5 | sys.path.append(project_path) 6 | 7 | fcgf_path = os.path.join(project_path, 'ext', 'FCGF') 8 | sys.path.append(fcgf_path) 9 | 10 | from tqdm import tqdm 11 | 12 | import torch 13 | import numpy as np 14 | import open3d as o3d 15 | 16 | from sgp_base import SGPBase 17 | from dataset.threedmatch_sgp import Dataset3DMatchSGP 18 | from perception3d.adaptor import DatasetFCGFAdaptor, FCGFConfigParser, load_fcgf_model, fcgf_train, reload_config, register 19 | 20 | 21 | class SGP3DRegistration(SGPBase): 22 | def __init__(self): 23 | super(SGP3DRegistration, self).__init__() 24 | 25 | # override 26 | def perception_bootstrap(self, src_data, dst_data, src_info, dst_info, 27 | config): 28 | T, fitness = register(src_data, dst_data, 'FPFH', 'RANSAC', None, 29 | config) 30 | if config.debug: 31 | src_data.paint_uniform_color([1, 0, 0]) 32 | dst_data.paint_uniform_color([0, 1, 0]) 33 | o3d.visualization.draw([src_data.transform(T), dst_data]) 34 | return T, fitness 35 | 36 | # override 37 | def perception(self, src_data, dst_data, src_info, dst_info, model, 38 | config): 39 | T, fitness = register(src_data, dst_data, 'FCGF', 'RANSAC', model, 40 | config) 41 | if config.debug: 42 | src_data.paint_uniform_color([1, 0, 0]) 43 | dst_data.paint_uniform_color([0, 1, 0]) 44 | o3d.visualization.draw([src_data.transform(T), dst_data]) 45 | return T, fitness 46 | 47 | # override 48 | def train_adaptor(self, sgp_dataset, config): 49 | fcgf_train(sgp_dataset, config) 50 | 51 | def run(self, config): 52 | epochs = config.max_epoch 53 | base_pseudo_label_dir = config.pseudo_label_dir 54 | base_outdir = config.out_dir 55 | 56 | # Bootstrap 57 | if config.restart_meta_iter < 0: 58 | pseudo_label_path_bs = os.path.join(base_pseudo_label_dir, 'bs') 59 | teach_dataset = Dataset3DMatchSGP(config.dataset_path, 60 | config.scenes, 61 | pseudo_label_path_bs, 'teaching', 62 | config.overlap_thr) 63 | # We need mutual filter for less reliable FPFH 64 | config.mutual_filter = True 65 | sgp.teach_bootstrap(teach_dataset, config) 66 | 67 | learn_dataset = DatasetFCGFAdaptor( 68 | Dataset3DMatchSGP(config.dataset_path, config.scenes, 69 | pseudo_label_path_bs, 'learning', 70 | config.overlap_thr), config) 71 | config.out_dir = os.path.join(base_outdir, 'bs') 72 | sgp.learn(learn_dataset, config) 73 | 74 | # Loop 75 | start_meta_iter = max(config.restart_meta_iter, 0) 76 | for i in range(start_meta_iter, config.meta_iters): 77 | pseudo_label_path_i = os.path.join(config.pseudo_label_dir, 78 | '{:02d}'.format(i)) 79 | teach_dataset = Dataset3DMatchSGP(config.dataset_path, 80 | config.scenes, 81 | pseudo_label_path_i, 'teaching', 82 | config.overlap_thr) 83 | 84 | # No mutual filter results in better FCGF teaching 85 | config.mutual_filter = False 86 | model = load_fcgf_model(config) 87 | sgp.teach(teach_dataset, model, config) 88 | 89 | learn_dataset = DatasetFCGFAdaptor( 90 | Dataset3DMatchSGP(config.dataset_path, config.scenes, 91 | pseudo_label_path_i, 'learning', 92 | config.overlap_thr), config) 93 | 94 | # There is a bug in FCGF finetuning. 95 | # Suppose previous epochs are [1, n], 96 | # then finetuning will be [n, n+n] (double counting n), instead of [n+1, n+n] 97 | # To address this without changing the original repo, we need to reduce max epochs by 1. 98 | # The actual finetuning iters will be correct, while FCGF's output will be slightly different. 
99 | if config.finetune: 100 | config.resume_dir = config.out_dir 101 | config = reload_config(config) 102 | config.max_epoch += (epochs - 1) 103 | else: 104 | config.out_dir = os.path.join(base_outdir, '{:02d}'.format(i)) 105 | 106 | sgp.learn(learn_dataset, config) 107 | 108 | 109 | if __name__ == '__main__': 110 | parser = FCGFConfigParser() 111 | parser.add( 112 | '--config', 113 | is_config_file=True, 114 | default=os.path.join(os.path.dirname(__file__), 115 | 'config_sgp_sample.yml'), 116 | help='YAML config file path. Please refer to caps_config.yml as a ' 117 | 'reference. It overrides the default config file, but will be ' 118 | 'overridden by other command line inputs.') 119 | parser.add('--debug', action='store_true') 120 | config = parser.get_config() 121 | 122 | sgp = SGP3DRegistration() 123 | sgp.run(config) 124 | -------------------------------------------------------------------------------- /code/perception3d/test.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | file_path = os.path.abspath(__file__) 4 | project_path = os.path.dirname(os.path.dirname(file_path)) 5 | sys.path.append(project_path) 6 | 7 | fcgf_path = os.path.join(project_path, 'ext', 'FCGF') 8 | sys.path.append(fcgf_path) 9 | 10 | from dataset.threedmatch_test import Dataset3DMatchTest 11 | from perception3d.adaptor import FCGFConfigParser, fcgf_test 12 | 13 | import numpy as np 14 | 15 | if __name__ == '__main__': 16 | parser = FCGFConfigParser() 17 | parser.add( 18 | '--config', 19 | is_config_file=True, 20 | default=os.path.join(os.path.dirname(__file__), 'config_test.yml'), 21 | help='YAML config file path. Please refer to caps_config.yml as a ' 22 | 'reference. It overrides the default config file, but will be ' 23 | 'overridden by other command line inputs.') 24 | parser.add('--debug', action='store_true') 25 | parser.add('--output', type=str, default='fcgf_test_result.npz') 26 | config = parser.get_config() 27 | 28 | dataset = Dataset3DMatchTest(config.dataset_path, config.scenes) 29 | r_errs, t_errs = fcgf_test(dataset, config) 30 | 31 | recall = (r_errs < 15.0) * (t_errs < 0.3) 32 | print('Recall: {}/{} = {}'.format(recall.sum(), len(recall), 33 | float(recall.sum()) / len(recall))) 34 | 35 | np.savez(config.output, rotation_errs=r_errs, translation_errs=t_errs) 36 | -------------------------------------------------------------------------------- /code/perception3d/train.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | file_path = os.path.abspath(__file__) 4 | project_path = os.path.dirname(os.path.dirname(file_path)) 5 | sys.path.append(project_path) 6 | 7 | fcgf_path = os.path.join(project_path, 'ext', 'FCGF') 8 | sys.path.append(fcgf_path) 9 | 10 | from dataset.threedmatch_train import Dataset3DMatchTrain 11 | from perception3d.adaptor import DatasetFCGFAdaptor, FCGFConfigParser, fcgf_train 12 | 13 | if __name__ == '__main__': 14 | parser = FCGFConfigParser() 15 | parser.add( 16 | '--config', 17 | is_config_file=True, 18 | default=os.path.join(os.path.dirname(__file__), 'config_train.yml'), 19 | help='YAML config file path. Please refer to caps_config.yml as a ' 20 | 'reference. 
It overrides the default config file, but will be ' 21 | 'overridden by other command line inputs.') 22 | config = parser.get_config() 23 | 24 | dataset = DatasetFCGFAdaptor( 25 | Dataset3DMatchTrain(config.dataset_path, config.scenes, config.overlap_thr), config) 26 | fcgf_train(dataset, config) 27 | -------------------------------------------------------------------------------- /code/sgp_base.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | class SGPBase: 4 | def __init__(self): 5 | pass 6 | 7 | def teach_bootstrap(self, sgp_dataset, config): 8 | ''' 9 | Teach without deep models. Use classical FPFH/SIFT features. 10 | ''' 11 | for i, data in tqdm(enumerate(sgp_dataset)): 12 | src_data, dst_data, src_info, dst_info, pair_info = data 13 | 14 | label, pair_info = self.perception_bootstrap( 15 | src_data, dst_data, src_info, dst_info, config) 16 | sgp_dataset.write_pseudo_label(i, label, pair_info) 17 | 18 | def teach(self, sgp_dataset, model, config): 19 | ''' 20 | Teach with deep models. Use learned FCGF/CAPS features. 21 | ''' 22 | for i, data in tqdm(enumerate(sgp_dataset)): 23 | src_data, dst_data, src_info, dst_info, pair_info = data 24 | 25 | # if self.is_valid(src_info, dst_info, pair_info): 26 | label, pair_info = self.perception(src_data, dst_data, src_info, 27 | dst_info, model, config) 28 | sgp_dataset.write_pseudo_label(i, label, pair_info) 29 | 30 | def learn(self, sgp_dataset, config): 31 | # Adapt and dispatch training script to external implementations 32 | self.train_adaptor(sgp_dataset, config) 33 | 34 | # override 35 | def train_adaptor(self, sgp_dataset, config): 36 | pass 37 | 38 | # override; takes config, matching the call in teach_bootstrap 39 | def perception_bootstrap(self, src_data, dst_data, src_info, dst_info, config): 40 | pass 41 | 42 | # override; takes model and config, matching the call in teach 43 | def perception(self, src_data, dst_data, src_info, dst_info, model, config): 44 | pass 45 | 46 | --------------------------------------------------------------------------------
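`sgp_base.py` above is the entire teacher-student skeleton: `teach_bootstrap` and `teach` run robust estimation to write pseudo-labels, and `learn` dispatches to an external training script. Adapting SGP to a new perception task therefore amounts to subclassing `SGPBase`, as `SGP2DFundamental` and `SGP3DRegistration` do for fundamental matrix estimation and point cloud registration. The sketch below only illustrates that surface; `my_robust_solver` and `my_train_loop` are hypothetical placeholders, not repository code:

```python
# Illustrative subclass of SGPBase for a hypothetical new perception task.
# my_robust_solver and my_train_loop are placeholders, not part of this repo.
from sgp_base import SGPBase


class SGPMyTask(SGPBase):
    # Teacher without a learned model: classical features + robust estimation.
    def perception_bootstrap(self, src_data, dst_data, src_info, dst_info, config):
        label, confidence = my_robust_solver(src_data, dst_data, feature='classical')
        # Return (pseudo-label, pair_info); pair_info is later used to filter bad pairs.
        return label, confidence

    # Teacher with the current student model: learned features + robust estimation.
    def perception(self, src_data, dst_data, src_info, dst_info, model, config):
        label, confidence = my_robust_solver(src_data, dst_data, feature=model)
        return label, confidence

    # Student: hand the pseudo-labeled dataset to an external training routine.
    def train_adaptor(self, sgp_dataset, config):
        my_train_loop(sgp_dataset, config)
```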