├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── assets
│   ├── caps-megadepth.png
│   ├── caps-scannet.png
│   ├── fcgf-3dmatch.png
│   ├── fpfh-3dmatch.png
│   └── overview.png
└── code
    ├── README.md
    ├── dataset
    │   ├── README.md
    │   ├── base.py
    │   ├── megadepth_sgp.py
    │   ├── megadepth_test.py
    │   ├── megadepth_train.py
    │   ├── threedmatch_sgp.py
    │   ├── threedmatch_test.py
    │   └── threedmatch_train.py
    ├── geometry
    │   ├── common.py
    │   ├── image.py
    │   └── pointcloud.py
    ├── perception2d
    │   ├── adaptor.py
    │   ├── config_sgp.yml
    │   ├── config_sgp_sample.yml
    │   ├── config_test.yml
    │   ├── config_train.yml
    │   ├── sgp.py
    │   ├── test.py
    │   └── train.py
    ├── perception3d
    │   ├── adaptor.py
    │   ├── config_sgp.yml
    │   ├── config_sgp_sample.yml
    │   ├── config_test.yml
    │   ├── config_train.yml
    │   ├── sgp.py
    │   ├── test.py
    │   └── train.py
    └── sgp_base.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
140 | pseudo-label/
141 | 3dmatch_train/
142 | logs/
143 | out/
144 | outputs/
145 | caps_logs/
146 | caps_outputs/
147 | fcgf_outputs/
148 | caps_pseudo_label/
149 | fcgf_pseudo_label/
150 | *.npz
151 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "code/ext/FCGF"]
2 | path = code/ext/FCGF
3 | url = https://github.com/chrischoy/FCGF.git
4 | [submodule "code/ext/caps"]
5 | path = code/ext/caps
6 | url = https://github.com/qianqianwang68/caps.git
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Wei Dong and Heng Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SGP: Self-supervised Geometric Perception
2 | [CVPR 2021 Oral] Self-supervised Geometric Perception
3 | https://arxiv.org/abs/2103.03114
4 |
5 | ## Introduction
6 | In short, SGP is, to the best of our knowledge, the first general framework for feature learning in geometric perception without any supervision from ground-truth geometric labels.
7 |
8 | SGP runs in an EM fashion: it alternates between robust estimation of geometric models to generate pseudo-labels, and feature learning under the supervision of these noisy pseudo-labels.
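
A minimal sketch of this alternation (function names are illustrative, not the actual interface of `code/sgp_base.py`):

```python
def sgp(pairs, bootstrap_feature, robust_solver, verify, train_student,
        num_meta_iters=2):
    """Minimal SGP sketch; all callables are supplied by the user."""
    feature = bootstrap_feature                  # teacher bootstrap: SIFT (2D) / FPFH (3D)
    for _ in range(num_meta_iters):
        pseudo_labels = []
        for pair in pairs:
            model, confidence = robust_solver(feature, pair)   # robust estimation, e.g. RANSAC
            if verify(model, confidence):                      # keep only confident pseudo-labels
                pseudo_labels.append((pair, model))
        feature = train_student(pseudo_labels)   # feature learning: CAPS (2D) / FCGF (3D)
    return feature
```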
9 |
10 |
11 |
12 |
13 |
14 | We applied SGP to camera pose estimation and point cloud registration, demonstrating performance on par with, or even superior to, supervised oracles on large-scale real datasets.
15 |
16 | ### Camera pose estimation
17 |
18 | Deep image features like [CAPS](https://github.com/qianqianwang68/caps) can be trained with relative pose labels generated by 5pt-RANSAC, bootstrapped with the handcrafted SIFT feature. They can later be used in robust relative camera pose estimation.
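
As a rough illustration of the 5pt-RANSAC "teaching" step using the helpers in `code/geometry/image.py` (the images `im_src`, `im_dst` and intrinsics `K_src`, `K_dst` are assumed to come from the MegaDepth loader):

```python
from geometry.image import (detect_keypoints, extract_feats, match_feats,
                            estimate_essential)

kpts_src = detect_keypoints(im_src, 'sift')
kpts_dst = detect_keypoints(im_dst, 'sift')
feats_src = extract_feats(im_src, kpts_src, 'sift')
feats_dst = extract_feats(im_dst, kpts_dst, 'sift')

matches = match_feats(feats_src, feats_dst, 'sift')
# 5pt-RANSAC inside; returns the essential matrix, inlier mask and relative pose
E, inlier_mask, R, t = estimate_essential(kpts_src, kpts_dst, matches, K_src, K_dst)
# The estimated model and its inlier statistics then serve as the pseudo-label for this pair.
```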
19 |
20 |
21 |

22 |

23 |
24 |
25 | ### Point cloud registration
26 |
27 | Deep 3D features like [FCGF](https://github.com/chrischoy/FCGF) can be trained with relative pose labels generated by 3pt-RANSAC, bootstrapped with the handcrafted FPFH feature. They can later be used in robust point cloud registration.
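
Similarly, a rough illustration of the FPFH + RANSAC "teaching" step using the helpers in `code/geometry/pointcloud.py` (the Open3D point clouds `pcd_src`, `pcd_dst` are assumed to come from the 3DMatch loader; parameter values are illustrative):

```python
from geometry.pointcloud import extract_feats, match_feats, solve, refine

voxel_size = 0.05
pcd_src, feat_src = extract_feats(pcd_src, 'FPFH', voxel_size)
pcd_dst, feat_dst = extract_feats(pcd_dst, 'FPFH', voxel_size)

corres = match_feats(feat_src, feat_dst, mutual_filter=True)
T, fitness = solve(pcd_src, pcd_dst, corres, 'RANSAC',
                   distance_thr=2 * voxel_size, ransac_iters=100000,
                   confidence=0.999)
T, fitness = refine(pcd_src, pcd_dst, T, distance_thr=voxel_size)  # ICP refinement
# The transformation T and its fitness then serve as the pseudo-label for this pair.
```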
28 |
29 |
30 |

31 |

32 |
33 |
34 |
35 |
36 | ## Code
37 |
38 | Please see `code/` for detailed instructions on how to use the codebase.
39 |
40 |
41 |
42 | ## Citation
43 |
44 | ```
45 | @inproceedings{yang2021sgp,
46 | title={Self-supervised Geometric Perception},
47 | author={Yang, Heng and Dong, Wei and Carlone, Luca and Koltun, Vladlen},
48 | booktitle={CVPR},
49 | year={2021}
50 | }
51 | ```
52 |
53 |
--------------------------------------------------------------------------------
/assets/caps-megadepth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theNded/SGP/63d33cc8bffde53676d9c4800f4b11804b53b360/assets/caps-megadepth.png
--------------------------------------------------------------------------------
/assets/caps-scannet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theNded/SGP/63d33cc8bffde53676d9c4800f4b11804b53b360/assets/caps-scannet.png
--------------------------------------------------------------------------------
/assets/fcgf-3dmatch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theNded/SGP/63d33cc8bffde53676d9c4800f4b11804b53b360/assets/fcgf-3dmatch.png
--------------------------------------------------------------------------------
/assets/fpfh-3dmatch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theNded/SGP/63d33cc8bffde53676d9c4800f4b11804b53b360/assets/fpfh-3dmatch.png
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theNded/SGP/63d33cc8bffde53676d9c4800f4b11804b53b360/assets/overview.png
--------------------------------------------------------------------------------
/code/README.md:
--------------------------------------------------------------------------------
1 | # Self-supervised Geometric Perception
2 |
3 | ## Disclaimer
4 | Compared to the code used for the paper submission, this repository has been fully rewritten for better readability and easier generalization. Please file a GitHub issue if anything is buggy.
5 |
6 | Since the final benchmark results depend on RANSAC (for robust model estimation), we expect minor discrepancies compared to the numbers published in the paper (due to the randomness of RANSAC). Again, please submit an issue if a significant difference is observed.
7 |
8 | ### TODO
9 | - [ ] Release pretrained weights.
10 |
11 | ## Setup
12 | Clone the project by
13 | ```
14 | git clone --recursive https://github.com/theNded/SGP.git
15 | ```
16 | This will by default clone the submodules [FCGF](https://github.com/chrischoy/FCGF) and [CAPS](https://github.com/qianqianwang68/caps) for 3D and 2D perception, respectively. Please follow the instructions in the corresponding repositories to configure the submodule(s) of interest.
17 |
18 | ## Datasets
19 | For the 3D perception task, please download the [3DMatch dataset](https://drive.google.com/file/d/1P5xS4ZGrmuoElZbKeC6bM5UoWz8H9SL1/view) reorganized by us that aggregates point clouds by scenes. The reorganized [test set](https://drive.google.com/file/d/1AmmADbhk5X62Q6CnsbJcwm1BK0Uov1yG/view?usp=sharing) is also available.
20 |
21 | For the 2D perception task, please download the [MegaDepth dataset](https://drive.google.com/file/d/1-o4TRLx6qm8ehQevV7nExmVJXfMxj657/view) provided by the author of CAPS. The test set has not been officially released, so please contact [CAPS authors](https://github.com/qianqianwang68/caps) for the data. We only provide the data loader.
22 |
23 | ## Vanilla training and testing
24 | Copy and/or modify the `config_[train|test].yml` files in `perception3d`. The configurable parameters can be found in `perception3d/adaptor.py`. Then run
25 | ```
26 | python perception3d/train.py --config /path/to/config.yml
27 | python perception3d/test.py --config /path/to/config.yml --weights /path/to/weights.pth
28 | ```
29 | You may also add `--debug` to visualize the registration/alignment results. The same applies to 2D.
30 |
31 | For a sanity check, you may first use the pretrained weights of the deep features (i.e., the supervised oracles) that are available on the corresponding websites/GitHub repos. The system should be able to run seamlessly.
32 |
33 | Note our codebase is non-intrusive, i.e., the original repositories are not modified; hence there are minor inconsistencies in configurations between 2D and 3D. For instance, pretrained weights are named `weights` for FCGF and `ckpt_path` for CAPS. Please carefully check the corresponding config options located in `adaptor.py`.
34 |
35 |
36 | ## Self-supervised training
37 | The training runs in teacher-student meta-loops, starting with a bootstrap step (`bs`) supervised by hand-crafted features (SIFT/FPFH), followed by the actual training loops (`00`, `01`) that train a deep feature (CAPS/FCGF) with its own pseudo-labels. After similarly configuring `config_sgp.yml`, run
38 | ```
39 | python perception3d/sgp.py --config /path/to/config.yml
40 | ```
41 | As the SGP process is time-consuming, we suggest first performing a sanity check on a minimal set of data, configured in `config_sgp_sample.yml`.
42 |
43 | To test the results of each meta-iteration with the default output paths, run
44 | ```shell
45 | # 2D
46 | python perception2d/test.py --config perception2d/config_test.yml --ckpt_path caps_outputs/bs/caps_sgp/040000.pth
47 | # 3D
48 | python perception3d/test.py --config perception3d/config_test.yml --weights fcgf_outputs/bs/checkpoint.pth
49 | ```
50 | for the feature trained in the bootstrap step (`bs`), and
51 | ```shell
52 | # 2D
53 | python perception2d/test.py --config perception2d/config_test.yml --ckpt_path caps_outputs/00/caps_sgp/040000.pth
54 | # 3D
55 | python perception3d/test.py --config perception3d/config_test.yml --weights fcgf_outputs/00/checkpoint.pth
56 | ```
57 | for the feature trained in the 0-th meta-iteration (`00`), and analogously for the following meta-iterations.
58 |
59 | To restart from or extend the current meta-iterations, change `restart_meta_iter` and `max_meta_iters` in the configuration.
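For example, a hypothetical excerpt of `config_sgp.yml` (values are illustrative; see `adaptor.py` for the full option list) could look like:
```
restart_meta_iter: 1   # -1 starts from the bootstrap step; >=0 resumes from that meta-iteration
max_meta_iters: 4      # total number of teacher-student meta-iterations
```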
60 |
61 | ## Extension
62 | To use your own dataset organized by scenes, check out `dataset/`. A README details how the datasets are organized and how you may extend the base class to parse your own scenes.
63 |
64 | To train your own deep feature, check out `sgp_base.py` and the corresponding `perception2d/` or `perception3d/` files. They share a similar interface for the `bootstrap` teaching-learning step and the `iterative` self-supervised teaching-learning loops.
65 |
--------------------------------------------------------------------------------
/code/dataset/README.md:
--------------------------------------------------------------------------------
1 | # Dataloader
2 |
3 | The overall goal of an SGP dataloader is to provide tuples of the form:
4 | ```python
5 | def __getitem__(self, idx):
6 | # some processing
7 | return data_src, data_dst, info_src, info_dst, info_pair
8 | ```
9 | where each tuple contains
10 | - `data_src`, `data_dst`: image for 2D perception, point cloud for 3D perception.
11 | - `info`: additional properties, e.g. the (unary) intrinsics of one image, or the (mutual) overlap between two point clouds. They do not directly provide supervision, but may serve as very weak supervision signals in geometric perception tasks.
12 |
13 |
14 | As SGP works on pairs of data that overlap within a scene, we assume a large dataset consists of smaller scenes in which such overlaps exist:
15 | ```
16 | root/
17 | |_ scene_0/
18 | |_ data_0
19 | |_ data_1
20 | |_ ...
21 | |_ data_n
22 | |_ pairs.txt
23 | |_ metadata.txt
24 | |_ scene_1/
25 | |_ ...
26 | |_ scene_m/
27 | ```
28 | Here, the root folder contains `m` scenes. Each scene includes `n` data files.
29 |
30 | Assuming we have some prior knowledge of the rough overlaps between data, a scene can also provide a file storing pair associations in `pairs.txt`:
31 | ```
32 | data_0 data_2
33 | data_0 data_8
34 | data_1 data_3
35 | ...
36 | ```
37 | Otherwise, a random selection is applied. It is strongly recommended to specify a `pairs.txt` to ensure valid self-supervision.
38 |
39 | Optionally, `metadata.txt` can provide further information. For instance, an intrinsic matrix can be given per image when the perception task estimates the extrinsic matrix between frames:
40 | ```
41 | data_0 fx_0 fy_0 cx_0 cy_0
42 | data_1 fx_1 fy_1 cx_1 cy_1
43 | ...
44 | ```
45 |
46 | The intermediate interface is therefore based on scenes:
47 | ```python
48 | def parse_scene(self, scene):
49 | # some processing
50 | return {'folder': scene, # str
51 | 'fnames': fnames, # len == n, list of str
52 | 'pairs': pairs, # len == m, list of (i, j) tuple
53 | # Optionally metadata
54 | 'unary_metadata' : unary_metadata, # len == n, list of object
55 | 'binary_metadata': mutual_metadata # len == m, list of object
56 | }
57 | ```
58 | A list of such `scene`s constructs the data field, where `collect_scenes` calls `parse_scene`:
59 | ```python
60 | def __init__(self, root, scenes):
61 | self.root = root
62 | self.scenes = self.collect_scenes(root, scenes)
63 | ```
64 | The dataset length is then given by the sum of `len(scene['pairs'])` over all scenes, and `__getitem__` first looks up the scene id and then the pair id via lookup arrays (sketched after the code block below).
65 | ```python
66 | def __getitem__(self, idx):
67 | # Use the LUT
68 | scene_idx = self.scene_idx_map[idx]
69 | pair_idx = self.pair_idx_map[idx]
70 |
71 | # Access actual data
72 | scene = self.scenes[scene_idx]
73 | folder = scene['folder']
74 |
75 | i, j = scene['pairs'][pair_idx]
76 | fname_src = scene['fnames'][i]
77 | fname_dst = scene['fnames'][j]
78 |
79 | print(i, j, fname_src, fname_dst)
80 |
81 | data_src = self.load_data(folder, fname_src)
82 | data_dst = self.load_data(folder, fname_dst)
83 |
84 | # Optional. Could be None
85 | metadata_src = scene['unary_metadata'][i]
86 | metadata_dst = scene['unary_metadata'][j]
87 | metadata_pair = scene['binary_metadata'][pair_idx]
88 |
89 | return data_src, data_dst, metadata_src, metadata_dst, metadata_pair
90 | ```
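
The lookup arrays used above are built once in the constructor; a minimal sketch mirroring `dataset/base.py`:

```python
import numpy as np

# Flatten (scene, pair) indices into one global index space.
scene_ids, pair_ids = [], []
for i, scene in enumerate(self.scenes):
    num_pairs = len(scene['pairs'])
    scene_ids.append(np.full(num_pairs, i, dtype=np.int64))
    pair_ids.append(np.arange(num_pairs, dtype=np.int64))

self.scene_idx_map = np.concatenate(scene_ids)
self.pair_idx_map = np.concatenate(pair_ids)
# __len__ is then simply len(self.scene_idx_map).
```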
91 |
92 | In practice, the dataset structure may vary slightly. For instance, a scene may contain subfolders, and the corresponding `pairs.txt` and `metadata.txt` may be renamed and located outside the data folder.
93 | ```
94 | root/
95 | |_ scene_0/
96 | |_ day/
97 | |_ images/
98 | |_ data_0.jpg
99 | |_ data_1.jpg
100 | |_ ...
101 | |_ pairs.txt
102 | |_ cameras.txt
103 | |_ night/
104 | |_ images/
105 | |_ data_0.jpg
106 | |_ data_1.jpg
107 | |_ ...
108 | |_ pairs.txt
109 | |_ cameras.txt
110 | ```
111 | In this case, we only need to override `parse_scene` to re-interpret the low-level structure, and override `collect_scenes` to collate the various subscenes of a scene, as sketched below.
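
A rough sketch of such overrides (directory names and the `load_pairs` helper are hypothetical):

```python
import os

from dataset.base import DatasetBase


class MySubsceneDataset(DatasetBase):
    # Re-interpret the low-level structure of a single subscene folder.
    def parse_scene(self, root, scene):
        scene_path = os.path.join(root, scene)
        fnames = sorted(os.listdir(os.path.join(scene_path, 'images')))
        pairs = self.load_pairs(os.path.join(scene_path, 'pairs.txt'))  # hypothetical helper
        return {'folder': scene, 'fnames': fnames, 'pairs': pairs,
                'unary_info': [None] * len(fnames),
                'binary_info': [None] * len(pairs)}

    # Collate the subscenes (e.g. 'day', 'night') of every scene.
    def collect_scenes(self, root, scenes):
        collection = []
        for scene in scenes:
            for subdir in sorted(os.listdir(os.path.join(root, scene))):
                if os.path.isdir(os.path.join(root, scene, subdir)):
                    collection.append(self.parse_scene(root, os.path.join(scene, subdir)))
        return collection
```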
--------------------------------------------------------------------------------
/code/dataset/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 |
5 | class DatasetBase:
6 | def __init__(self, root, scenes):
7 | self.root = root
8 | self.scenes = self.collect_scenes(root, scenes)
9 |
10 | scene_ids = []
11 | pair_ids = []
12 |
13 | for i, scene in enumerate(self.scenes):
14 | num_pairs = len(scene['pairs'])
15 |             scene_ids.append(np.ones((num_pairs), dtype=np.int64) * i)
16 |             pair_ids.append(np.arange(0, num_pairs, dtype=np.int64))
17 |
18 | self.scene_idx_map = np.concatenate(scene_ids)
19 | self.pair_idx_map = np.concatenate(pair_ids)
20 |
21 | def __len__(self):
22 | return len(self.scene_idx_map)
23 |
24 | def __getitem__(self, idx):
25 | # Use the LUT
26 | scene_idx = self.scene_idx_map[idx]
27 | pair_idx = self.pair_idx_map[idx]
28 |
29 | # Access actual data
30 | scene = self.scenes[scene_idx]
31 | folder = scene['folder']
32 |
33 | i, j = scene['pairs'][pair_idx]
34 | fname_src = scene['fnames'][i]
35 | fname_dst = scene['fnames'][j]
36 |
37 | data_src = self.load_data(folder, fname_src)
38 | data_dst = self.load_data(folder, fname_dst)
39 |
40 | # Optional. Could be None
41 | info_src = scene['unary_info'][i]
42 | info_dst = scene['unary_info'][j]
43 | info_pair = scene['binary_info'][pair_idx]
44 |
45 | return data_src, data_dst, info_src, info_dst, info_pair
46 |
47 | # NOTE: override in inheritance
48 | def parse_scene(self, root, scene):
49 | return {
50 | 'folder': scene,
51 | 'fnames': [],
52 | 'pairs': [],
53 | 'unary_info': [],
54 | 'binary_info': []
55 | }
56 |
57 | # NOTE: override in inheritance
58 | def load_data(self, folder, fname):
59 | return os.path.join(folder, fname)
60 |
61 | # NOTE: optionally override in inheritance, if a scene includes more than 1 subset
62 | def collect_scenes(self, root, scenes):
63 | return [self.parse_scene(root, scene) for scene in scenes]
64 |
--------------------------------------------------------------------------------
/code/dataset/megadepth_sgp.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | file_path = os.path.abspath(__file__)
4 | project_path = os.path.dirname(os.path.dirname(file_path))
5 | sys.path.append(project_path)
6 |
7 | import cv2
8 | import numpy as np
9 | import open3d as o3d
10 |
11 | from dataset.base import DatasetBase
12 | from geometry.image import skew, detect_keypoints, extract_feats, match_feats, estimate_essential, draw_matches
13 |
14 | from tqdm import tqdm
15 |
16 | PSEUDO_LABEL_FNAME = 'pseudo-label.log'
17 |
18 |
19 | # Train and test sets are identical for CAPS
20 | class DatasetMegaDepthSGP(DatasetBase):
21 | def __init__(self,
22 | data_root,
23 | scenes,
24 | label_root,
25 | mode,
26 | inlier_ratio_thr=0.3,
27 | num_matches_thr=100,
28 | sample_rate=0.2):
29 | self.label_root = label_root
30 | self.inlier_ratio_thr = inlier_ratio_thr
31 | self.num_matches_thr = num_matches_thr
32 | self.sample_rate = sample_rate
33 |
34 | if not os.path.exists(label_root):
35 | print(
36 | 'label root {} does not exist, entering teaching mode.'.format(
37 | label_root))
38 | self.mode = 'teaching'
39 | os.makedirs(label_root, exist_ok=True)
40 | elif mode == 'teaching':
41 | print('label root {} will be overwritten to enter teaching mode'.
42 | format(label_root))
43 | self.mode = 'teaching'
44 | else:
45 | print('label root {} exists, entering learning mode.'.format(
46 | label_root))
47 | self.mode = 'learning'
48 |
49 | super(DatasetMegaDepthSGP, self).__init__(data_root, scenes)
50 |
51 | # override
52 | def parse_scene(self, root, scene):
53 | if self.mode == 'teaching':
54 | return self._parse_scene_teaching(root, scene)
55 | elif self.mode == 'learning':
56 | return self._parse_scene_learning(root, scene)
57 | else:
58 | print('Unsupported mode, abort')
59 | exit()
60 |
61 | def write_pseudo_label(self, idx, label, info):
62 | scene_idx = self.scene_idx_map[idx]
63 | pair_idx = self.pair_idx_map[idx]
64 |
65 | # Access actual data
66 | scene = self.scenes[scene_idx]
67 | i, j = scene['pairs'][pair_idx]
68 | folder = scene['folder']
69 |
70 | num_inliers, num_matches = info
71 | label_file = os.path.join(self.label_root, folder, PSEUDO_LABEL_FNAME)
72 | with open(label_file, 'a') as f:
73 | f.write('{} {} {} {} '.format(i, j, num_inliers, num_matches))
74 | label_str = ' '.join(map(str, label.flatten()))
75 | f.write(label_str)
76 | f.write('\n')
77 |
78 | def _deterministic_shuffle_(self, seq):
79 | import random
80 | random.Random(15213).shuffle(seq)
81 |
82 | def _parse_scene_teaching(self, root, scene):
83 | # Generate pseudo labels
84 | label_path = os.path.join(self.label_root, scene)
85 | os.makedirs(label_path, exist_ok=True)
86 | label_file = os.path.join(label_path, PSEUDO_LABEL_FNAME)
87 |
88 | if os.path.exists(label_file):
89 | os.remove(label_file)
90 | with open(label_file, 'w') as f:
91 | pass
92 |
93 | scene_path = os.path.join(root, scene)
94 |
95 | fnames = os.listdir(os.path.join(scene_path, 'images'))
96 | fnames_map = {fname: i for i, fname in enumerate(fnames)}
97 |
98 | cam_fname = os.path.join(scene_path, 'img_cam.txt')
99 | with open(cam_fname, 'r') as f:
100 | cam_content = f.readlines()
101 |
102 | cnt = 0
103 | intrinsics = np.zeros((len(fnames), 3, 3))
104 | extrinsics = np.zeros((len(fnames), 4, 4))
105 | for line in cam_content:
106 | line = line.strip()
107 | if len(line) > 0 and line[0] != "#":
108 | lst = line.split()
109 | fname = lst[0]
110 | idx = fnames_map[fname]
111 |
112 | fx, fy = float(lst[3]), float(lst[4])
113 | cx, cy = float(lst[5]), float(lst[6])
114 | intrinsics[idx] = np.array([fx, 0, cx, 0, fy, cy, 0, 0,
115 | 1]).reshape((3, 3))
116 | cnt += 1
117 |
118 | assert cnt == len(fnames)
119 |
120 | # Load pairs.txt
121 | pair_fname = os.path.join(scene_path, 'pairs.txt')
122 | with open(pair_fname, 'r') as f:
123 | pair_content = f.readlines()
124 |
125 | pairs = []
126 | for line in pair_content:
127 | lst = line.strip().split(' ')
128 | src_fname = lst[0]
129 | dst_fname = lst[1]
130 |
131 | src_idx = fnames_map[src_fname]
132 | dst_idx = fnames_map[dst_fname]
133 | pairs.append((src_idx, dst_idx))
134 |
135 | pairs_cnt = len(pairs)
136 | idx_selection = np.arange(pairs_cnt)
137 | self._deterministic_shuffle_(idx_selection)
138 | idx_selection = idx_selection[:int(self.sample_rate *
139 | pairs_cnt)].astype(int)
140 |
141 | return {
142 | 'folder': scene,
143 | 'fnames': fnames,
144 | 'pairs': np.asarray(pairs)[idx_selection],
145 | 'unary_info': intrinsics,
146 | 'binary_info': [None for i in range(len(pairs))]
147 | }
148 |
149 | def _parse_scene_learning(self, root, scene):
150 | # Load pseudo labels
151 | label_path = os.path.join(self.label_root, scene, PSEUDO_LABEL_FNAME)
152 | if not os.path.exists(label_path):
153 | raise Exception('{} not found', label_path)
154 |
155 | scene_path = os.path.join(root, scene)
156 |
157 | fnames = os.listdir(os.path.join(scene_path, 'images'))
158 | fnames_map = {fname: i for i, fname in enumerate(fnames)}
159 |
160 | cam_fname = os.path.join(scene_path, 'img_cam.txt')
161 | with open(cam_fname, 'r') as f:
162 | cam_content = f.readlines()
163 |
164 | cnt = 0
165 | intrinsics = np.zeros((len(fnames), 3, 3))
166 | for line in cam_content:
167 | line = line.strip()
168 | if len(line) > 0 and line[0] != "#":
169 | lst = line.split()
170 | fname = lst[0]
171 | idx = fnames_map[fname]
172 |
173 | fx, fy = float(lst[3]), float(lst[4])
174 | cx, cy = float(lst[5]), float(lst[6])
175 |
176 | intrinsics[idx] = np.array([fx, 0, cx, 0, fy, cy, 0, 0,
177 | 1]).reshape((3, 3))
178 | cnt += 1
179 |
180 | assert cnt == len(fnames)
181 |
182 | with open(label_path, 'r') as f:
183 | pair_content = f.readlines()
184 |
185 | pairs = []
186 | binary_info = []
187 |
188 | for line in pair_content:
189 | lst = line.strip().split(' ')
190 | src_idx = int(lst[0])
191 | dst_idx = int(lst[1])
192 |
193 | num_inliers = float(lst[2])
194 | num_matches = float(lst[3])
195 |
196 | F_data = list(map(float, lst[4:]))
197 | F = np.array(F_data).reshape((3, 3))
198 |
199 | if num_matches >= self.num_matches_thr \
200 | and (num_inliers / num_matches) >= self.inlier_ratio_thr:
201 | pairs.append((src_idx, dst_idx))
202 | binary_info.append(F)
203 |
204 | return {
205 | 'folder': scene,
206 | 'fnames': fnames,
207 | 'pairs': pairs,
208 | 'unary_info': intrinsics,
209 | 'binary_info': binary_info
210 | }
211 |
212 | # override
213 | def load_data(self, folder, fname):
214 | fname = os.path.join(self.root, folder, 'images', fname)
215 | return cv2.imread(fname)
216 |
217 | # override
218 | def collect_scenes(self, root, scenes):
219 | scene_collection = []
220 |
221 | for scene in scenes:
222 | scene_path = os.path.join(root, scene)
223 | subdirs = os.listdir(scene_path)
224 | for subdir in subdirs:
225 | if subdir.startswith('dense') and \
226 | os.path.isdir(
227 | os.path.join(scene_path, subdir)):
228 | scene_dict = self.parse_scene(
229 | root, os.path.join(scene, subdir, 'aligned'))
230 | scene_collection.append(scene_dict)
231 |
232 | return scene_collection
233 |
--------------------------------------------------------------------------------
/code/dataset/megadepth_test.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | file_path = os.path.abspath(__file__)
4 | project_path = os.path.dirname(os.path.dirname(file_path))
5 | sys.path.append(project_path)
6 |
7 | import cv2
8 | import numpy as np
9 | import open3d as o3d
10 |
11 | from dataset.base import DatasetBase
12 | from geometry.image import compute_fundamental_from_poses, detect_keypoints, extract_feats, match_feats, estimate_essential, draw_matches
13 |
14 |
15 | # Train and test sets are identical for CAPS
16 | class DatasetMegaDepthTest(DatasetBase):
17 | def __init__(self, data_root, scenes, label_root):
18 | self.data_root = data_root
19 | super(DatasetMegaDepthTest, self).__init__(label_root, scenes)
20 |
21 | # override
22 | def parse_scene(self, label_root, scene):
23 | scene_path = os.path.join(label_root, scene)
24 |
25 | # Load cameras
26 | cam_fname = os.path.join(scene_path, 'img_cam.txt')
27 | with open(cam_fname, 'r') as f:
28 | cam_content = f.readlines()
29 |
30 | cnt = 0
31 | fnames = []
32 | fnames_map = {}
33 | intrinsics = []
34 | extrinsics = []
35 | for i, line in enumerate(cam_content):
36 | line = line.strip()
37 | if len(line) > 0 and line[0] != "#":
38 | lst = line.split()
39 | seq = lst[0]
40 | fname = lst[1]
41 |
42 | fx, fy = float(lst[4]), float(lst[5])
43 | cx, cy = float(lst[6]), float(lst[7])
44 |
45 | R = np.array(lst[8:17]).reshape((3, 3))
46 | t = np.array(lst[17:20])
47 | T = np.eye(4)
48 | T[:3, :3] = R
49 | T[:3, 3] = t
50 |
51 | fnames.append(
52 | os.path.join(self.data_root, seq, 'dense', 'aligned',
53 | 'images', fname))
54 |                 fnames_map[fname] = len(fnames) - 1  # index into the fnames list (comment lines are skipped)
55 | intrinsics.append(
56 | np.array([fx, 0, cx, 0, fy, cy, 0, 0, 1]).reshape((3, 3)))
57 | extrinsics.append(T)
58 |
59 | # Load pairs.txt
60 | pair_fname = os.path.join(scene_path, 'pairs.txt')
61 | with open(pair_fname, 'r') as f:
62 | pair_content = f.readlines()
63 |
64 | pairs = []
65 | for line in pair_content:
66 | lst = line.strip().split(' ')
67 | seq = lst[0]
68 | src_fname = lst[1]
69 | dst_fname = lst[2]
70 |
71 | src_idx = fnames_map[src_fname]
72 | dst_idx = fnames_map[dst_fname]
73 | pairs.append((src_idx, dst_idx))
74 |
75 | return {
76 | 'folder': scene,
77 | 'fnames': fnames,
78 | 'pairs': pairs,
79 | 'unary_info': [(K, T) for K, T in zip(intrinsics, extrinsics)],
80 | 'binary_info': [None for i in range(len(pairs))]
81 | }
82 |
83 | # override
84 | def load_data(self, folder, fname):
85 | return cv2.imread(fname)
86 |
--------------------------------------------------------------------------------
/code/dataset/megadepth_train.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | file_path = os.path.abspath(__file__)
4 | project_path = os.path.dirname(os.path.dirname(file_path))
5 | sys.path.append(project_path)
6 |
7 | import cv2
8 | import numpy as np
9 | import open3d as o3d
10 |
11 | from dataset.base import DatasetBase
12 | from geometry.image import compute_fundamental_from_poses, detect_keypoints, extract_feats, match_feats, estimate_essential, draw_matches
13 |
14 |
15 | class DatasetMegaDepthTrain(DatasetBase):
16 | def __init__(self, root, scenes):
17 | super(DatasetMegaDepthTrain, self).__init__(root, scenes)
18 |
19 | # override
20 | def parse_scene(self, root, scene):
21 | scene_path = os.path.join(root, scene)
22 |
23 | fnames = os.listdir(os.path.join(scene_path, 'images'))
24 | fnames_map = {fname: i for i, fname in enumerate(fnames)}
25 |
26 | # Load pairs.txt
27 | pair_fname = os.path.join(scene_path, 'pairs.txt')
28 | with open(pair_fname, 'r') as f:
29 | pair_content = f.readlines()
30 |
31 | pairs = []
32 | for line in pair_content:
33 | lst = line.strip().split(' ')
34 | src_fname = lst[0]
35 | dst_fname = lst[1]
36 |
37 | src_idx = fnames_map[src_fname]
38 | dst_idx = fnames_map[dst_fname]
39 | pairs.append((src_idx, dst_idx))
40 |
41 | cam_fname = os.path.join(scene_path, 'img_cam.txt')
42 | with open(cam_fname, 'r') as f:
43 | cam_content = f.readlines()
44 |
45 | cnt = 0
46 | intrinsics = np.zeros((len(fnames), 3, 3))
47 | extrinsics = np.zeros((len(fnames), 4, 4))
48 | for line in cam_content:
49 | line = line.strip()
50 | if len(line) > 0 and line[0] != "#":
51 | lst = line.split()
52 | fname = lst[0]
53 | idx = fnames_map[fname]
54 |
55 | fx, fy = float(lst[3]), float(lst[4])
56 | cx, cy = float(lst[5]), float(lst[6])
57 |
58 | R = np.array(lst[7:16]).reshape((3, 3))
59 | t = np.array(lst[16:19])
60 | T = np.eye(4)
61 | T[:3, :3] = R
62 | T[:3, 3] = t
63 |
64 | intrinsics[idx] = np.array([fx, 0, cx, 0, fy, cy, 0, 0,
65 | 1]).reshape((3, 3))
66 | extrinsics[idx] = T
67 | cnt += 1
68 |
69 | assert cnt == len(fnames)
70 |
71 | return {
72 | 'folder': scene,
73 | 'fnames': fnames,
74 | 'pairs': pairs,
75 | 'unary_info': [(K, T) for K, T in zip(intrinsics, extrinsics)],
76 | 'binary_info': [None for i in range(len(pairs))]
77 | }
78 |
79 | # override
80 | def load_data(self, folder, fname):
81 | fname = os.path.join(self.root, folder, 'images', fname)
82 | return cv2.imread(fname)
83 |
84 | # override
85 | def collect_scenes(self, root, scenes):
86 | scene_collection = []
87 |
88 | for scene in scenes:
89 | scene_path = os.path.join(root, scene)
90 | subdirs = os.listdir(scene_path)
91 | for subdir in subdirs:
92 | if subdir.startswith('dense') and \
93 | os.path.isdir(
94 | os.path.join(scene_path, subdir)):
95 | scene_dict = self.parse_scene(
96 | root, os.path.join(scene, subdir, 'aligned'))
97 | scene_collection.append(scene_dict)
98 |
99 | return scene_collection
100 |
--------------------------------------------------------------------------------
/code/dataset/threedmatch_sgp.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | file_path = os.path.abspath(__file__)
4 | project_path = os.path.dirname(os.path.dirname(file_path))
5 | sys.path.append(project_path)
6 |
7 | import glob
8 |
9 | import open3d as o3d
10 | import numpy as np
11 | from tqdm import tqdm
12 |
13 | from dataset.base import DatasetBase
14 | from geometry.pointcloud import make_o3d_pointcloud, extract_feats, match_feats, solve, refine
15 | PSEUDO_LABEL_FNAME = 'pseudo-label.log'
16 |
17 |
18 | class Dataset3DMatchSGP(DatasetBase):
19 | '''
20 | During teaching: labels are written to a separate directory
21 |     During learning: it acts like the train set, with labels read from a separate directory
22 | '''
23 | def __init__(self, data_root, scenes, label_root, mode, overlap_thr=0.3):
24 | self.label_root = label_root
25 | self.overlap_thr = overlap_thr
26 |
27 | if not os.path.exists(label_root):
28 | print(
29 | 'label root {} does not exist, entering teaching mode.'.format(
30 | label_root))
31 | self.mode = 'teaching'
32 | os.makedirs(label_root, exist_ok=True)
33 | elif mode == 'teaching':
34 | print('label root {} will be overwritten to enter teaching mode'.
35 | format(label_root))
36 | self.mode = 'teaching'
37 | else:
38 | print('label root {} exists, entering learning mode.'.format(
39 | label_root))
40 | self.mode = 'learning'
41 |
42 | super(Dataset3DMatchSGP, self).__init__(data_root, scenes)
43 |
44 | # override
45 | def parse_scene(self, root, scene):
46 | if self.mode == 'teaching':
47 | return self._parse_scene_teaching(root, scene)
48 | elif self.mode == 'learning':
49 | return self._parse_scene_learning(root, scene)
50 | else:
51 | print('Unsupported mode, abort')
52 | exit()
53 |
54 | # override
55 | def load_data(self, folder, fname):
56 | fname = os.path.join(self.root, folder, fname)
57 | return make_o3d_pointcloud(np.load(fname)['pcd'])
58 |
59 | def write_pseudo_label(self, idx, label, overlap):
60 | scene_idx = self.scene_idx_map[idx]
61 | pair_idx = self.pair_idx_map[idx]
62 |
63 | # Access actual data
64 | scene = self.scenes[scene_idx]
65 | i, j = scene['pairs'][pair_idx]
66 | folder = scene['folder']
67 |
68 | label_file = os.path.join(self.label_root, folder, PSEUDO_LABEL_FNAME)
69 | with open(label_file, 'a') as f:
70 | f.write('{} {} {} '.format(i, j, overlap))
71 | label_str = ' '.join(map(str, label.flatten()))
72 | f.write(label_str)
73 | f.write('\n')
74 |
75 | def _parse_scene_teaching(self, root, scene):
76 | # Generate pseudo labels
77 | label_path = os.path.join(self.label_root, scene)
78 | os.makedirs(label_path, exist_ok=True)
79 | label_file = os.path.join(label_path, PSEUDO_LABEL_FNAME)
80 |
81 | if os.path.exists(label_file):
82 | os.remove(label_file)
83 | with open(label_file, 'w') as f:
84 | pass
85 |
86 | # Load actual data
87 | scene_path = os.path.join(root, scene)
88 |
89 | # Load filenames
90 | l = len(scene_path)
91 | fnames = sorted(glob.glob(os.path.join(scene_path, '*.npz')))
92 | fnames = [fname[l + 1:] for fname in fnames]
93 |
94 | # Load overlaps.txt
95 | pair_fname = os.path.join(scene_path, 'overlaps.txt')
96 | with open(pair_fname, 'r') as f:
97 | pair_content = f.readlines()
98 |
99 | pairs = []
100 | binary_info = []
101 |
102 | # For a 3DMatch dataset for teaching,
103 | # binary_info is (optional) for filtering: overlap
104 | for line in pair_content:
105 | lst = line.strip().split(' ')
106 | src_idx = int(lst[0].split('.')[0].split('_')[-1])
107 | dst_idx = int(lst[1].split('.')[0].split('_')[-1])
108 | overlap = float(lst[2])
109 |
110 | if overlap >= self.overlap_thr:
111 | pairs.append((src_idx, dst_idx))
112 | binary_info.append(overlap)
113 |
114 | return {
115 | 'folder': scene,
116 | 'fnames': fnames,
117 | 'pairs': pairs,
118 | 'unary_info': [None for i in range(len(fnames))],
119 | 'binary_info': binary_info
120 | }
121 |
122 | '''
123 | Pseudo-Labels not available. Generate paths for writing to them later.
124 | '''
125 |
126 | def _parse_scene_learning(self, root, scene):
127 | # Load pseudo labels
128 | label_path = os.path.join(self.label_root, scene, PSEUDO_LABEL_FNAME)
129 | if not os.path.exists(label_path):
130 | raise Exception('{} not found', label_path)
131 |
132 | # Load actual data
133 | scene_path = os.path.join(root, scene)
134 |
135 | # Load filenames
136 | l = len(scene_path)
137 | fnames = sorted(glob.glob(os.path.join(scene_path, '*.npz')))
138 | fnames = [fname[l + 1:] for fname in fnames]
139 |
140 | # Load overlaps.txt
141 | with open(label_path, 'r') as f:
142 | pair_content = f.readlines()
143 |
144 | pairs = []
145 | binary_info = []
146 |
147 | # For a 3DMatch dataset for learning,
148 | # binary_info is the pseudo label: src to dst transformation.
149 | for line in pair_content:
150 | lst = line.strip().split(' ')
151 | src_idx = int(lst[0].split('.')[0].split('_')[-1])
152 | dst_idx = int(lst[1].split('.')[0].split('_')[-1])
153 | overlap = float(lst[2])
154 | T_data = list(map(float, lst[3:]))
155 | T = np.array(T_data).reshape((4, 4))
156 |
157 | if overlap >= self.overlap_thr:
158 | pairs.append((src_idx, dst_idx))
159 | binary_info.append(T)
160 |
161 | return {
162 | 'folder': scene,
163 | 'fnames': fnames,
164 | 'pairs': pairs,
165 | 'unary_info': [None for i in range(len(fnames))],
166 | 'binary_info': binary_info
167 | }
168 |
--------------------------------------------------------------------------------
/code/dataset/threedmatch_test.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | file_path = os.path.abspath(__file__)
4 | project_path = os.path.dirname(os.path.dirname(file_path))
5 | sys.path.append(project_path)
6 |
7 | import glob
8 |
9 | import open3d as o3d
10 | import numpy as np
11 |
12 | from dataset.base import DatasetBase
13 |
14 |
15 | class Dataset3DMatchTest(DatasetBase):
16 | def __init__(self, root, scenes):
17 | super(Dataset3DMatchTest, self).__init__(root, scenes)
18 |
19 | # override
20 | def parse_scene(self, root, scene):
21 | scene_path = os.path.join(root, scene)
22 |
23 | l = len(scene_path)
24 | fnames = sorted(
25 | glob.glob(os.path.join(scene_path, '*.ply')),
26 | key=lambda fname: int(fname.split('.')[0].split('_')[-1]))
27 | fnames = [fname[l + 1:] for fname in fnames]
28 |
29 | # Load gt
30 | scene_gt_path = os.path.join(root, scene + '-evaluation')
31 | gt_fname = os.path.join(scene_gt_path, 'gt.log')
32 | with open(gt_fname, 'r') as f:
33 | pair_content = f.readlines()
34 |
35 | pairs = []
36 | binary_info = []
37 |
38 | # For a 3DMatch test dataset,
39 | # binary_info is the gt label: src to dst transformation.
40 | for i in range(0, len(pair_content), 5):
41 | lst = pair_content[i].strip().split('\t')
42 | src_idx = int(lst[0])
43 | dst_idx = int(lst[1])
44 |
45 | res = map(lambda x: np.fromstring(x.strip(), sep='\t'),
46 | pair_content[i+1:i+5])
47 | T_src2dst = np.stack(list(res))
48 | pairs.append((src_idx, dst_idx))
49 | binary_info.append(np.linalg.inv(T_src2dst))
50 |
51 | return {
52 | 'folder': scene,
53 | 'fnames': fnames,
54 | 'pairs': pairs,
55 | 'unary_info': [None for i in range(len(fnames))],
56 | 'binary_info': binary_info
57 | }
58 |
59 | # override
60 | def load_data(self, folder, fname):
61 | fname = os.path.join(self.root, folder, fname)
62 | return o3d.io.read_point_cloud(fname)
63 |
--------------------------------------------------------------------------------
/code/dataset/threedmatch_train.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | file_path = os.path.abspath(__file__)
4 | project_path = os.path.dirname(os.path.dirname(file_path))
5 | sys.path.append(project_path)
6 |
7 | import glob
8 |
9 | import open3d as o3d
10 | import numpy as np
11 |
12 | from dataset.base import DatasetBase
13 | from geometry.pointcloud import make_o3d_pointcloud
14 |
15 | class Dataset3DMatchTrain(DatasetBase):
16 | def __init__(self, root, scenes, overlap_thr=0.3):
17 | self.overlap_thr = overlap_thr
18 | super(Dataset3DMatchTrain, self).__init__(root, scenes)
19 |
20 | # override
21 | def parse_scene(self, root, scene):
22 | scene_path = os.path.join(root, scene)
23 |
24 | l = len(scene_path)
25 | fnames = sorted(glob.glob(os.path.join(scene_path, '*.npz')))
26 | fnames = [fname[l + 1:] for fname in fnames]
27 |
28 | # Load overlaps.txt
29 | pair_fname = os.path.join(scene_path, 'overlaps.txt')
30 | with open(pair_fname, 'r') as f:
31 | pair_content = f.readlines()
32 |
33 | pairs = []
34 | binary_info = []
35 |
36 | # For a preprocessed 3DMatch training dataset,
37 | # binary_info is the gt label: pre-calibrated identity matrix.
38 | for line in pair_content:
39 | lst = line.strip().split(' ')
40 | src_idx = int(lst[0].split('.')[0].split('_')[-1])
41 | dst_idx = int(lst[1].split('.')[0].split('_')[-1])
42 | overlap = float(lst[2])
43 |
44 | if overlap >= self.overlap_thr:
45 | pairs.append((src_idx, dst_idx))
46 | binary_info.append(np.eye(4))
47 |
48 | return {
49 | 'folder': scene,
50 | 'fnames': fnames,
51 | 'pairs': pairs,
52 | 'unary_info': [None for i in range(len(fnames))],
53 | 'binary_info': binary_info
54 | }
55 |
56 | # override
57 | def load_data(self, folder, fname):
58 | fname = os.path.join(self.root, folder, fname)
59 | return make_o3d_pointcloud(np.load(fname)['pcd'])
60 |
--------------------------------------------------------------------------------
/code/geometry/common.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def rotation_error(R0, R1):
5 | return np.abs(
6 | np.arccos(np.clip((np.trace(R0.T @ R1) - 1) / 2.0, -0.999999,
7 | 0.999999))) / np.pi * 180
8 |
9 |
10 | def translation_error(t0, t1):
11 | return np.linalg.norm(t0 - t1)
12 |
13 |
14 | def angular_translation_error(t0, t1):
15 | t0 = t0 / np.linalg.norm(t0)
16 | t1 = t1 / np.linalg.norm(t1)
17 | err = np.arccos(np.clip(np.inner(t0, t1), -0.999999,
18 | 0.999999)) / np.pi * 180
19 | return err
20 |
--------------------------------------------------------------------------------
/code/geometry/image.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import torch
4 | import torchvision.transforms as transforms
5 |
6 | def skew(x):
7 | return np.array([[0, -x[2], x[1]], [x[2], 0, -x[0]], [-x[1], x[0], 0]])
8 |
9 |
10 | def compute_fundamental_from_poses(K_src, K_dst, T_src, T_dst):
11 | T_src2dst = T_dst.dot(np.linalg.inv(T_src))
12 | R = T_src2dst[:3, :3]
13 | t = T_src2dst[:3, 3]
14 | tx = skew(t)
15 | E = np.dot(tx, R)
16 | return np.linalg.inv(K_dst).T.dot(E).dot(np.linalg.inv(K_src))
17 |
18 |
19 | def detect_keypoints(im, detector, num_kpts=10000):
20 | gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
21 |
22 | if detector == 'sift':
23 | sift = cv2.xfeatures2d.SIFT_create(nfeatures=num_kpts)
24 | kpts = sift.detect(gray)
25 | elif detector == 'orb':
26 | orb = cv2.ORB_create(nfeatures=num_kpts)
27 | kpts = orb.detect(gray)
28 | else:
29 | raise NotImplementedError('Unknown keypoint detector.')
30 |
31 | return kpts
32 |
33 |
34 | def extract_feats(im, kpts, feature_type, model=None):
35 | if feature_type == 'sift':
36 | sift = cv2.xfeatures2d.SIFT_create()
37 | kpts, feats = sift.compute(im, kpts)
38 |
39 | elif feature_type == 'orb':
40 | orb = cv2.ORB_create()
41 | kpts, feats = orb.compute(im, kpts)
42 |
43 | elif feature_type == 'caps':
44 | assert model is not None
45 | transform = transforms.Compose([
46 | transforms.ToTensor(),
47 | transforms.Normalize(mean=(0.485, 0.456, 0.406),
48 | std=(0.229, 0.224, 0.225)),
49 | ])
50 |
51 | kpts = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts])
52 | kpts = torch.from_numpy(kpts).float()
53 |
54 | desc_c, desc_f = model.extract_features(
55 | transform(im).unsqueeze(0).to(model.device),
56 | kpts.unsqueeze(0).to(model.device))
57 |
58 | feats = torch.cat((desc_c, desc_f),
59 | -1).squeeze(0).detach().cpu().numpy()
60 | else:
61 | raise NotImplementedError('Unknown feature descriptor.')
62 |
63 | return feats
64 |
65 |
66 | def match_feats(feats_src,
67 | feats_dst,
68 | feature_type,
69 | ratio_test=True,
70 | ratio_thr=0.6):
71 | if feature_type == 'orb':
72 | bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
73 | good = bf.match(feats_src, feats_dst)
74 | else: # sift and caps descriptor
75 | if ratio_test:
76 | bf = cv2.BFMatcher()
77 | matches = bf.knnMatch(feats_src, feats_dst, k=2)
78 | good = []
79 | for m, n in matches:
80 | if m.distance < ratio_thr * n.distance:
81 | good.append(m)
82 | if len(good) < 50:
83 | matches = sorted(matches,
84 | key=lambda x: x[0].distance / x[1].distance)
85 | good = [m[0] for m in matches[:50]]
86 |
87 | else:
88 | bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
89 | good = bf.match(feats_src, feats_dst)
90 | if len(good) < 50:
91 | bf = cv2.BFMatcher()
92 | matches = bf.match(feats_src, feats_dst)
93 | matches = sorted(matches, key=lambda x: x.distance)
94 | good = [m for m in matches[:50]]
95 |
96 | return good
97 |
98 |
99 | def estimate_essential(kp1, kp2, matches, K1, K2, th=1e-4):
100 | src_pts = np.float32([kp1[m.queryIdx].pt
101 | for m in matches]).reshape(-1, 1, 2)
102 | dst_pts = np.float32([kp2[m.trainIdx].pt
103 | for m in matches]).reshape(-1, 1, 2)
104 | pts_l_norm = cv2.undistortPoints(src_pts, cameraMatrix=K1, distCoeffs=None)
105 | pts_r_norm = cv2.undistortPoints(dst_pts, cameraMatrix=K2, distCoeffs=None)
106 | E, mask = cv2.findEssentialMat(pts_l_norm,
107 | pts_r_norm,
108 | focal=1.0,
109 | pp=(0., 0.),
110 | method=cv2.RANSAC,
111 | prob=0.999,
112 | threshold=th)
113 |     if E is None or E.shape != (3, 3):
114 | return np.eye(3), np.zeros((len(matches))), np.eye(3), np.zeros((3))
115 |
116 | mask = np.squeeze(mask).astype(bool)
117 | _, R, t, _ = cv2.recoverPose(E, pts_l_norm[mask], pts_r_norm[mask])
118 | t = np.squeeze(t)
119 | return E, mask, R, t
120 |
121 |
122 | def decolorize(img):
123 | return cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY),
124 | cv2.COLOR_GRAY2RGB)
125 |
126 |
127 | def draw_matches(kps1, kps2, tentatives, img1, img2, H, mask):
128 | if H is None:
129 | print("No homography found")
130 | return
131 | matchesMask = mask.ravel().tolist()
132 | h, w, ch = img1.shape
133 | pts = np.float32([[0, 0], [0, h - 1], [w - 1, h - 1],
134 | [w - 1, 0]]).reshape(-1, 1, 2)
135 | dst = cv2.perspectiveTransform(pts, H)
136 | img2_tr = cv2.polylines(decolorize(img2), [np.int32(dst)], True,
137 | (0, 0, 255), 3, cv2.LINE_AA)
138 | draw_params = dict(
139 | matchColor=(255, 255, 0), # draw matches in yellow color
140 | singlePointColor=None,
141 | matchesMask=matchesMask, # draw only inliers
142 | flags=2)
143 | return cv2.drawMatches(decolorize(img1), kps1, img2_tr, kps2, tentatives,
144 | None, **draw_params)
145 |
--------------------------------------------------------------------------------
/code/geometry/pointcloud.py:
--------------------------------------------------------------------------------
1 | import open3d as o3d
2 | import numpy as np
3 | import torch
4 |
5 | import MinkowskiEngine as ME
6 | from scipy.spatial import cKDTree
7 |
8 |
9 | def make_o3d_pointcloud(xyz):
10 | pcd = o3d.geometry.PointCloud()
11 | pcd.points = o3d.utility.Vector3dVector(xyz)
12 | return pcd
13 |
14 |
15 | def extract_feats(pcd, feature_type, voxel_size, model=None):
16 | xyz = np.asarray(pcd.points)
17 | _, sel = ME.utils.sparse_quantize(xyz,
18 | return_index=True,
19 | quantization_size=voxel_size)
20 | xyz = xyz[sel]
21 | pcd = make_o3d_pointcloud(xyz)
22 |
23 | if feature_type == 'FPFH':
24 | radius_normal = voxel_size * 2
25 | pcd.estimate_normals(
26 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal,
27 | max_nn=30))
28 | radius_feat = voxel_size * 5
29 | feat = o3d.pipelines.registration.compute_fpfh_feature(
30 | pcd,
31 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feat,
32 | max_nn=100))
33 | # (N, 33)
34 | return pcd, feat.data.T
35 |
36 | elif feature_type == 'FCGF':
37 | DEVICE = torch.device('cuda')
38 | coords = ME.utils.batched_coordinates(
39 | [torch.floor(torch.from_numpy(xyz) / voxel_size).int()]).to(DEVICE)
40 |
41 | feats = torch.ones(coords.size(0), 1).to(DEVICE)
42 | sinput = ME.SparseTensor(feats, coordinates=coords) # .to(DEVICE)
43 |
44 | # (N, 32)
45 | return pcd, model(sinput).F.detach().cpu().numpy()
46 |
47 | else:
48 | raise NotImplementedError(
49 | 'Unimplemented feature type {}'.format(feature_type))
50 |
51 |
52 | def find_knn_cpu(feat0, feat1, knn=1, return_distance=False):
53 | feat1tree = cKDTree(feat1)
54 | dists, nn_inds = feat1tree.query(feat0, k=knn, n_jobs=-1)
55 | if return_distance:
56 | return nn_inds, dists
57 | else:
58 | return nn_inds
59 |
60 |
61 | def match_feats(feat_src, feat_dst, mutual_filter=True, k=1):
62 | if not mutual_filter:
63 | nns01 = find_knn_cpu(feat_src, feat_dst, knn=1, return_distance=False)
64 | corres01_idx0 = np.arange(len(nns01)).squeeze()
65 | corres01_idx1 = nns01.squeeze()
66 | return np.stack((corres01_idx0, corres01_idx1)).T
67 | else:
68 | # for each feat in src, find its k=1 nearest neighbours
69 | nns01 = find_knn_cpu(feat_src, feat_dst, knn=1, return_distance=False)
70 | # for each feat in dst, find its k nearest neighbours
71 | nns10 = find_knn_cpu(feat_dst, feat_src, knn=k, return_distance=False)
72 | # find corrs
73 | num_feats = len(nns01)
74 | corres01 = []
75 | if k == 1:
76 | for i in range(num_feats):
77 | if i == nns10[nns01[i]]:
78 | corres01.append([i, nns01[i]])
79 | else:
80 | for i in range(num_feats):
81 | if i in nns10[nns01[i]]:
82 | corres01.append([i, nns01[i]])
83 | # print(
84 | # f'Before mutual filter: {num_feats}, after mutual_filter with k={k}: {len(corres01)}.'
85 | # )
86 |
87 | # Fallback if mutual filter is too aggressive
88 | if len(corres01) < 10:
89 | nns01 = find_knn_cpu(feat_src,
90 | feat_dst,
91 | knn=1,
92 | return_distance=False)
93 | corres01_idx0 = np.arange(len(nns01)).squeeze()
94 | corres01_idx1 = nns01.squeeze()
95 | return np.stack((corres01_idx0, corres01_idx1)).T
96 |
97 | return np.asarray(corres01)
98 |
99 |
100 | def weighted_procrustes(A, B, weights=None):
101 | num_pts = A.shape[1]
102 | if weights is None:
103 | weights = np.ones(num_pts)
104 |
105 | # compute weighted center
106 | A_center = A @ weights / np.sum(weights)
107 | B_center = B @ weights / np.sum(weights)
108 |
109 | # compute relative positions
110 | A_ref = A - A_center[:, np.newaxis]
111 | B_ref = B - B_center[:, np.newaxis]
112 |
113 | # compute rotation
114 | M = B_ref @ np.diag(weights) @ A_ref.T
115 | U, _, Vh = np.linalg.svd(M)
116 | S = np.identity(3)
117 | S[-1, -1] = np.linalg.det(U) * np.linalg.det(Vh)
118 | R = U @ S @ Vh
119 |
120 | # compute translation
121 | t = B_center - R @ A_center
122 |
123 | return R, t
124 |
125 |
126 | def solve(src, dst, corres, solver_type, distance_thr, ransac_iters,
127 | confidence):
128 | if solver_type.startswith('RANSAC'):
129 | corres = o3d.utility.Vector2iVector(corres)
130 |
131 | result = o3d.pipelines.registration.registration_ransac_based_on_correspondence(
132 | src, dst, corres, distance_thr,
133 | o3d.pipelines.registration.TransformationEstimationPointToPoint(
134 | False), 3, [],
135 | o3d.pipelines.registration.RANSACConvergenceCriteria(
136 | ransac_iters, confidence))
137 |
138 | return result.transformation, result.fitness
139 |
140 | else:
141 | raise NotImplementedError(
142 | 'Unimplemented solver type {}'.format(solver_type))
143 |
144 |
145 | def refine(src, dst, ransac_T, distance_thr):
146 | result = o3d.pipelines.registration.registration_icp(
147 | src, dst, distance_thr, ransac_T,
148 | o3d.pipelines.registration.TransformationEstimationPointToPoint())
149 | icp_T = result.transformation
150 | icp_fitness = result.fitness
151 |
152 | fitness = icp_fitness * np.minimum(
153 | 1.0,
154 | float(len(dst.points)) / float(len(src.points)))
155 |
156 | return icp_T, fitness
157 |
--------------------------------------------------------------------------------
/code/perception2d/adaptor.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | file_path = os.path.abspath(__file__)
4 | project_path = os.path.dirname(os.path.dirname(file_path))
5 | sys.path.append(project_path)
6 |
7 | caps_path = os.path.join(project_path, 'ext', 'caps')
8 | sys.path.append(caps_path)
9 |
10 | from ext.caps.CAPS.caps_model import CAPSModel
11 | from ext.caps.utils import cycle
12 |
13 | from dataset.megadepth_train import DatasetMegaDepthTrain
14 | from dataset.megadepth_test import DatasetMegaDepthTest
15 | from dataset.megadepth_sgp import DatasetMegaDepthSGP
16 | from geometry.image import *
17 |
18 | from geometry.common import rotation_error, angular_translation_error
19 |
20 | from tensorboardX import SummaryWriter
21 | import configargparse
22 |
23 | import torch
24 | from torch.utils.data import Dataset
25 | import numpy as np
26 | import cv2
27 |
28 | import utils
29 | import collections
30 | from tqdm import tqdm
31 | import dataloader.data_utils as data_utils
32 |
33 | rand = np.random.RandomState(234)
34 |
35 |
36 | class CAPSConfigParser(configargparse.ArgParser):
37 | def __init__(self):
38 | super().__init__(default_config_files=[
39 | os.path.join(os.path.dirname(__file__), 'caps_train_config.yml')
40 | ],
41 | conflict_handler='resolve')
42 |
43 | ## path options
44 | self.add('--datadir', type=str, help='the dataset directory')
45 | self.add("--logdir",
46 | type=str,
47 | default='caps_logs',
48 | help='dir of tensorboard logs')
49 | self.add("--outdir",
50 | type=str,
51 | default='caps_outputs',
52 | help='dir of output e.g., ckpts')
53 | self.add(
54 | "--ckpt_path",
55 | type=str,
56 | default='',
57 | help='specific checkpoint path to load the model from, '
58 | 'if not specified, automatically reload from most recent checkpoints'
59 | )
60 | self.add('--pseudo_label_dir',
61 | type=str,
62 | default='caps_pseudo_label',
63 | help='the pseudo-gt directory storing pairs and F matrices')
64 | self.add(
65 | '--label_dir',
66 | type=str,
67 | default='',
68 | help=
69 | 'the gt directory storing pairs and F matrices. Reserved for pose test set.'
70 | )
71 |
72 | # SGP options
73 | self.add('--scenes',
74 | nargs='+',
75 | help='scenes used for training/testing')
76 | self.add('--inlier_ratio_thr', type=float, default=0.001)
77 | self.add('--num_matches_thr', type=int, default=100)
78 | self.add('--sample_rate',
79 | type=float,
80 | default=1,
81 | help='rate of samples from the huge megadepth dataset')
82 | self.add('--num_kpts',
83 | type=int,
84 | default=10000,
85 | help='number of key points detected during teaching')
86 | self.add('--match_ratio_test',
87 | type=bool,
88 | default=True,
89 | help='performs ratio test in feature matching')
90 | self.add('--match_ratio_thr',
91 | type=float,
92 | default=0.75,
93 | help='ratio between best and second best matchings')
94 | self.add('--ransac_thr',
95 | type=float,
96 | default=1e-3,
97 | help='RANSAC threshold in estimating essential matrices')
98 |
99 | self.add(
100 | '--restart_meta_iter',
101 | type=int,
102 | default=-1,
103 | help='start of teacher-student iterations. -1 indicates bootstrap')
104 | self.add('--max_meta_iters',
105 | type=int,
106 | default=2,
107 | help='number of teacher-student iterations')
108 | self.add('--finetune',
109 | action='store_true',
110 | help='train from previous checkpoint during SGP.')
111 |
112 | ## general options
113 | self.add("--exp_name", type=str, help='experiment name')
114 | self.add('--n_iters',
115 | type=int,
116 | default=100,
117 | help='max number of training iterations')
118 | self.add("--save_interval",
119 | type=int,
120 | default=100,
121 | help='frequency of weight ckpt saving')
122 | self.add('--phase',
123 | type=str,
124 | default='train',
125 | choices=['train', 'val', 'test'])
126 |
127 | # data options
128 | self.add('--workers',
129 | type=int,
130 | help='number of data loading workers',
131 | default=8)
132 | self.add('--num_pts',
133 | type=int,
134 | default=500,
135 | help='num of points trained in each pair')
136 | self.add('--train_kp',
137 | type=str,
138 | default='mixed',
139 | help='sift/random/mixed')
140 | self.add('--prune_kp',
141 | type=int,
142 | default=1,
143 | help='if prune non-matchable keypoints')
144 |
145 | # training options
146 | self.add('--batch_size', type=int, default=2, help='input batch size')
147 | self.add('--lr', type=float, default=1e-4, help='base learning rate')
148 | self.add(
149 | "--lrate_decay_steps",
150 | type=int,
151 | default=80000,
152 | help=
153 | 'decay learning rate by a factor every specified number of steps')
154 | self.add(
155 | "--lrate_decay_factor",
156 | type=float,
157 | default=0.5,
158 | help=
159 |             'multiplicative factor applied to the learning rate at each decay step')
160 |
161 | ## model options
162 | self.add(
163 | '--backbone',
164 | type=str,
165 | default='resnet50',
166 | help=
167 |             'backbone for feature representation extraction. supported: resnet'
168 | )
169 | self.add(
170 | '--pretrained',
171 | type=int,
172 | default=1,
173 | help='if use ImageNet pretrained weights to initialize the network'
174 | )
175 | self.add('--coarse_feat_dim',
176 | type=int,
177 | default=128,
178 | help='the feature dimension for coarse level features')
179 | self.add('--fine_feat_dim',
180 | type=int,
181 | default=128,
182 | help='the feature dimension for fine level features')
183 | self.add(
184 | '--prob_from',
185 | type=str,
186 | default='correlation',
187 | help=
188 |             'compute prob by softmax(correlation score) or softmax(-distance), '
189 | 'options: correlation|distance')
190 | self.add(
191 | '--window_size',
192 | type=float,
193 | default=0.125,
194 | help='the size of the window, w.r.t image width at the fine level')
195 | self.add('--use_nn',
196 | type=int,
197 | default=1,
198 | help='if use nearest neighbor in the coarse level')
199 |
200 | ## loss function options
201 | self.add('--std',
202 | type=int,
203 | default=1,
204 | help='reweight loss using the standard deviation')
205 | self.add('--w_epipolar_coarse',
206 | type=float,
207 | default=1,
208 | help='coarse level epipolar loss weight')
209 | self.add('--w_epipolar_fine',
210 | type=float,
211 | default=1,
212 | help='fine level epipolar loss weight')
213 | self.add('--w_cycle_coarse',
214 | type=float,
215 | default=0.1,
216 | help='coarse level cycle consistency loss weight')
217 | self.add('--w_cycle_fine',
218 | type=float,
219 | default=0.1,
220 | help='fine level cycle consistency loss weight')
221 | self.add('--w_std',
222 | type=float,
223 | default=0,
224 | help='the weight for the loss on std')
225 | self.add(
226 | '--th_cycle',
227 | type=float,
228 | default=0.025,
229 | help=
230 | 'if the distance (normalized scale) from the prediction to epipolar line > this th, '
231 | 'do not add the cycle consistency loss')
232 | self.add(
233 | '--th_epipolar',
234 | type=float,
235 | default=0.5,
236 | help=
237 | 'if the distance (normalized scale) from the prediction to epipolar line > this th, '
238 | 'do not add the epipolar loss')
239 |
240 | ## logging options
241 | self.add('--log_scalar_interval',
242 | type=int,
243 | default=20,
244 | help='print interval')
245 | self.add('--log_img_interval',
246 | type=int,
247 | default=500,
248 | help='log image interval')
249 |
250 | ## eval options
251 | self.add('--extract_img_dir',
252 | type=str,
253 | help='the directory of images to extract features')
254 | self.add('--extract_out_dir',
255 | type=str,
256 |                  help='the output directory for extracted features')
257 |
258 | def get_config(self):
259 | config = self.parse_args()
260 | return config
261 |
262 |
263 | def my_collate(batch):
264 |     ''' Drop failed (None) samples, then put each data field into a tensor with outer dimension batch size '''
265 | batch = list(filter(lambda b: b is not None, batch))
266 | return torch.utils.data.dataloader.default_collate(batch)
267 |
268 |
269 | class DatasetMegaDepthAdaptor(Dataset):
270 | def __init__(self, dataset, config):
271 | self.dataset = dataset
272 | self.config = config
273 |
274 | if config.phase == 'train':
275 | # augment during training
276 | self.transform = transforms.Compose([
277 | transforms.ToPILImage(),
278 | transforms.ColorJitter(brightness=1,
279 | contrast=1,
280 | saturation=1,
281 | hue=0.4),
282 | transforms.ToTensor(),
283 | transforms.Normalize(mean=(0.485, 0.456, 0.406),
284 | std=(0.229, 0.224, 0.225)),
285 | ])
286 | else:
287 | self.transform = transforms.Compose([
288 | transforms.ToTensor(),
289 | transforms.Normalize(mean=(0.485, 0.456, 0.406),
290 | std=(0.229, 0.224, 0.225)),
291 | ])
292 | self.phase = config.phase
293 |
294 | def __getitem__(self, idx):
295 | pass
296 |
297 | def __len__(self):
298 | return len(self.dataset)
299 |
300 |
301 | # For vanilla train & test
302 | class DatasetMegaDepthTrainAdaptor(DatasetMegaDepthAdaptor):
303 | def __init__(self, dataset, config):
304 | super(DatasetMegaDepthTrainAdaptor, self).__init__(dataset, config)
305 |
306 | def __getitem__(self, idx):
307 | im_src, im_dst, cam_src, cam_dst, _ = self.dataset[idx]
308 | h, w = im_src.shape[:2]
309 |
310 | im1_ori = torch.from_numpy(im_src)
311 | im2_ori = torch.from_numpy(im_dst)
312 |
313 | im1_tensor = self.transform(im_src)
314 | im2_tensor = self.transform(im_dst)
315 |
316 | coord1 = data_utils.generate_query_kpts(im_src, self.config.train_kp,
317 | 10 * self.config.num_pts, h, w)
318 |
319 | # if no keypoints are detected
320 | if len(coord1) == 0:
321 | return None
322 |
323 |         # randomly subsample query keypoints down to num_pts
324 | coord1 = utils.random_choice(coord1, self.config.num_pts)
325 | coord1 = torch.from_numpy(coord1).float()
326 |
327 | K_src, T_src = cam_src
328 | K_dst, T_dst = cam_dst
329 |
330 | T_src2dst = torch.from_numpy(T_dst.dot(np.linalg.inv(T_src)))
331 | F = compute_fundamental_from_poses(K_src, K_dst, T_src, T_dst)
332 | F = torch.from_numpy(F).float() / (F[-1, -1] + 1e-16)
333 |
334 | out = {
335 | 'im1_ori': im1_ori,
336 | 'im2_ori': im2_ori,
337 | 'intrinsic1': K_src,
338 | 'intrinsic2': K_dst,
339 |
340 | # Additional, for training
341 | 'im1': im1_tensor,
342 | 'im2': im2_tensor,
343 | 'coord1': coord1,
344 | 'F': F,
345 | 'pose': T_src2dst
346 | }
347 |
348 | return out
349 |
350 |
351 | # For SGP train
352 | class DatasetMegaDepthSGPAdaptor(DatasetMegaDepthAdaptor):
353 | def __init__(self, dataset, config):
354 | super(DatasetMegaDepthSGPAdaptor, self).__init__(dataset, config)
355 |
356 | def __getitem__(self, idx):
357 | im1, im2, K_src, K_dst, F = self.dataset[idx]
358 | h, w = im1.shape[:2]
359 |
360 | im1_ori, im2_ori = torch.from_numpy(im1), torch.from_numpy(im2)
361 |
362 | im1_tensor = self.transform(im1)
363 | im2_tensor = self.transform(im2)
364 |
365 | coord1 = data_utils.generate_query_kpts(im1, self.config.train_kp,
366 | 10 * self.config.num_pts, h, w)
367 |
368 | # if no keypoints are detected
369 | if len(coord1) == 0:
370 | return None
371 |
372 |         # randomly subsample query keypoints down to num_pts
373 | coord1 = utils.random_choice(coord1, self.config.num_pts)
374 | coord1 = torch.from_numpy(coord1).float()
375 |
376 | F = torch.from_numpy(F).float() / (F[-1, -1] + 1e-16)
377 |
378 | out = {
379 | 'im1_ori': im1_ori,
380 | 'im2_ori': im2_ori,
381 | 'intrinsic1': K_src,
382 | 'intrinsic2': K_dst,
383 |
384 | # Additional, for training
385 | 'im1': im1_tensor,
386 | 'im2': im2_tensor,
387 | 'coord1': coord1,
388 | 'F': F,
389 |
390 | # Pose is required in the base but not used in CAPSModel
391 | 'pose': np.eye(4)
392 | }
393 |
394 | return out
395 |
396 |
397 | def align(im_src, im_dst, K_src, K_dst, detector, feature, model, config):
398 | kpts_src = detect_keypoints(im_src, detector, num_kpts=config.num_kpts)
399 | kpts_dst = detect_keypoints(im_dst, detector, num_kpts=config.num_kpts)
400 |
401 | # Too few keypoints
402 | if len(kpts_src) < 5 or len(kpts_dst) < 5:
403 | return np.eye(3), np.eye(3), np.ones((3)), [], [], [], np.zeros((0))
404 |
405 | feats_src = extract_feats(im_src, kpts_src, feature, model)
406 | feats_dst = extract_feats(im_dst, kpts_dst, feature, model)
407 | matches = match_feats(feats_src, feats_dst, feature,
408 | config.match_ratio_test, config.match_ratio_thr)
409 | num_matches = len(matches)
410 |
411 | # Too few matches
412 | if num_matches <= 5: # 5-pts method
413 |         return (np.eye(3), np.eye(3), np.ones((3)),
414 |                 kpts_src, kpts_dst, [], np.zeros((len(matches))))
415 |
416 | E, mask, R, t = estimate_essential(kpts_src,
417 | kpts_dst,
418 | matches,
419 | K_src,
420 | K_dst,
421 | th=config.ransac_thr)
422 | F = np.linalg.inv(K_dst).T.dot(E).dot(np.linalg.inv(K_src))
423 | F = F / (F[-1, -1] + 1e-16)
424 |
425 | return F, R, t, kpts_src, kpts_dst, matches, mask
426 |
427 |
428 | def caps_train(dataset, config):
429 | # save a copy for the current config in out_folder
430 | out_folder = os.path.join(config.outdir, config.exp_name)
431 | os.makedirs(out_folder, exist_ok=True)
432 | f = os.path.join(out_folder, 'config.txt')
433 | with open(f, 'w') as file:
434 | for arg in vars(config):
435 | attr = getattr(config, arg)
436 | file.write('{} = {}\n'.format(arg, attr))
437 |
438 | # tensorboard writer
439 | tb_log_dir = os.path.join(config.logdir, config.exp_name)
440 | print('tensorboard log files are stored in {}'.format(tb_log_dir))
441 | writer = SummaryWriter(tb_log_dir)
442 |
443 | # megadepth data loader
444 | dataloader = torch.utils.data.DataLoader(dataset,
445 | batch_size=config.batch_size,
446 | shuffle=True,
447 | num_workers=config.workers,
448 | collate_fn=my_collate)
449 |
450 | model = CAPSModel(config)
451 |
452 | start_step = model.start_step
453 | dataloader_iter = iter(cycle(dataloader))
454 | for step in range(start_step + 1, start_step + config.n_iters + 1):
455 | data = next(dataloader_iter)
456 | if data is None:
457 | continue
458 |
459 | model.set_input(data)
460 | model.optimize_parameters()
461 | model.write_summary(writer, step)
462 | if step % config.save_interval == 0 and step > 0:
463 | model.save_model(step)
464 |
465 |
466 | def caps_test(dataset, config):
467 | model = CAPSModel(config)
468 |
469 | r_errs = []
470 | t_errs = []
471 |
472 | for data in tqdm(dataset):
473 | im_src, im_dst, cam_src, cam_dst, _ = data
474 |
475 | K_src, T_src = cam_src
476 | K_dst, T_dst = cam_dst
477 | T_src2dst_gt = T_dst.dot(np.linalg.inv(T_src))
478 |
479 | F, R, t, kpts_src, kpts_dst, matches, mask = align(
480 | im_src, im_dst, K_src, K_dst, 'sift', 'caps', model, config)
481 |
482 | r_err = rotation_error(R, T_src2dst_gt[:3, :3])
483 | t_err = angular_translation_error(t, T_src2dst_gt[:3, 3])
484 | r_errs.append(r_err)
485 | t_errs.append(t_err)
486 |
487 | if config.debug:
488 | im = draw_matches(kpts_src, kpts_dst, matches, im_src, im_dst, F,
489 | mask)
490 | cv2.imshow('matches', im)
491 | cv2.waitKey(-1)
492 |
493 | return np.array(r_errs), np.array(t_errs)
494 |
--------------------------------------------------------------------------------
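
The essential-to-fundamental conversion at the end of align() above is the standard relation F = K_dst^-T E K_src^-1, followed by normalizing F so that its bottom-right entry is 1 (with the same 1e-16 guard used elsewhere in the adaptor). A minimal numpy sketch of that step and of the epipolar constraint it encodes; the function names here are illustrative and not part of the repository:

    import numpy as np

    def essential_to_fundamental(E, K_src, K_dst):
        # F maps source-image pixels to epipolar lines in the destination image.
        F = np.linalg.inv(K_dst).T @ E @ np.linalg.inv(K_src)
        return F / (F[-1, -1] + 1e-16)  # scale-normalize, guarding against F[2, 2] == 0

    def epipolar_residual(F, x_src, x_dst):
        # For a correct correspondence in homogeneous pixel coordinates,
        # x_dst^T F x_src should be close to zero.
        return float(x_dst @ F @ x_src)
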
/code/perception2d/config_sgp.yml:
--------------------------------------------------------------------------------
1 | # SGP training from scratch on a single GPU; remaining defaults come from CAPSConfigParser in adaptor.py
2 | exp_name: caps_sgp
3 | datadir: '/home/wei/Workspace/data/CAPS-MegaDepth-release-light/train'
4 | pseudo_label_dir: 'caps_pseudo_label'
5 | scenes: [0000, 0005]
6 |
7 | sample_rate: 0.2
8 | inlier_ratio_thr: 0.3
9 | n_iters: 40000
10 | save_interval: 10000
11 |
--------------------------------------------------------------------------------
/code/perception2d/config_sgp_sample.yml:
--------------------------------------------------------------------------------
1 | # reduced SGP sample run (single scene, few iterations); remaining defaults come from CAPSConfigParser in adaptor.py
2 | exp_name: caps_sgp
3 | datadir: '/home/wei/Workspace/data/CAPS-MegaDepth-release-light/train'
4 | pseudo_label_dir: 'caps_pseudo_label'
5 | scenes: [sample]
6 |
7 | sample_rate: 0.5
8 | inlier_ratio_thr: 0.3
9 | n_iters: 10
10 | save_interval: 10
11 |
--------------------------------------------------------------------------------
/code/perception2d/config_test.yml:
--------------------------------------------------------------------------------
1 | # pose evaluation with a pretrained checkpoint; remaining defaults come from CAPSConfigParser in adaptor.py
2 | exp_name: caps_test
3 | datadir: '/home/wei/Workspace/data/CAPS-MegaDepth-release-light/test'
4 | label_dir: '/home/wei/Workspace/data/CAPS-MegaDepth-release-light/test-pose'
5 | ckpt_path: '/home/wei/Downloads/caps-pretrained.pth'
6 | scenes: [easy]
7 |
--------------------------------------------------------------------------------
/code/perception2d/config_train.yml:
--------------------------------------------------------------------------------
1 | # vanilla supervised training on a single GPU; remaining defaults come from CAPSConfigParser in adaptor.py
2 | exp_name: caps_train
3 | datadir: /home/wei/Workspace/data/CAPS-MegaDepth-release-light/train
4 | scenes: [1001, 0020]
5 |
--------------------------------------------------------------------------------
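
The four YAML files above override only a handful of the CAPSConfigParser defaults (scenes, data paths, iteration counts); every other option, from loss weights to the learning-rate schedule, falls back to the defaults defined in adaptor.py, and any command-line flag overrides both. A hedged sketch of how the entry scripts below consume them, mirroring their __main__ blocks (the import assumes code/ is on sys.path, and the YAML path is just an example):

    from perception2d.adaptor import CAPSConfigParser

    parser = CAPSConfigParser()
    parser.add('--config', is_config_file=True,
               default='code/perception2d/config_train.yml')  # example path
    # Precedence: parser defaults < values in the YAML file < command-line flags.
    config = parser.get_config()
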
/code/perception2d/sgp.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | file_path = os.path.abspath(__file__)
4 | project_path = os.path.dirname(os.path.dirname(file_path))
5 | sys.path.append(project_path)
6 |
7 | caps_path = os.path.join(project_path, 'ext', 'caps')
8 | sys.path.append(caps_path)
9 |
10 | import cv2
11 | import torch
12 |
13 | from sgp_base import SGPBase
14 | from dataset.megadepth_sgp import DatasetMegaDepthSGP
15 | from perception2d.adaptor import CAPSConfigParser, DatasetMegaDepthSGPAdaptor, CAPSModel, caps_train, caps_test, align
16 | from geometry.image import *
17 |
18 |
19 | class SGP2DFundamental(SGPBase):
20 | def __init__(self):
21 | super(SGP2DFundamental, self).__init__()
22 |
23 | # override
24 | def perception_bootstrap(self, src_data, dst_data, src_info, dst_info,
25 | config):
26 | F, R, t, kpts_src, kpts_dst, matches, mask = align(
27 | src_data, dst_data, src_info, dst_info, 'sift', 'sift', None,
28 | config)
29 |
30 | if config.debug:
31 | im = draw_matches(kpts_src, kpts_dst, matches, src_data, dst_data,
32 | F, mask)
33 | cv2.imshow('matches', im)
34 | cv2.waitKey(-1)
35 |
36 | return F, (mask.sum(), len(matches))
37 |
38 | # override
39 | def perception(self, src_data, dst_data, src_info, dst_info, model,
40 | config):
41 | F, R, t, kpts_src, kpts_dst, matches, mask = align(
42 | src_data, dst_data, src_info, dst_info, 'sift', 'caps', model,
43 | config)
44 |
45 | if config.debug:
46 | im = draw_matches(kpts_src, kpts_dst, matches, src_data, dst_data,
47 | F, mask)
48 | cv2.imshow('matches', im)
49 | cv2.waitKey(-1)
50 |
51 | return F, (mask.sum(), len(matches))
52 |
53 | # override
54 | def train_adaptor(self, sgp_dataset, config):
55 | caps_train(sgp_dataset, config)
56 |
57 | def run(self, config):
58 | base_outdir = config.outdir
59 | base_logdir = config.logdir
60 | base_pseudo_label_dir = config.pseudo_label_dir
61 |
62 | pseudo_label_path_bs = os.path.join(base_pseudo_label_dir, 'bs')
63 |
64 | if config.restart_meta_iter < 0:
65 | # Only sample a subset for teaching.
66 | teach_dataset = DatasetMegaDepthSGP(config.datadir,
67 | config.scenes,
68 | pseudo_label_path_bs,
69 | 'teaching',
70 | inlier_ratio_thr=config.inlier_ratio_thr,
71 | num_matches_thr=config.num_matches_thr,
72 | sample_rate=config.sample_rate)
73 | print('Dataset size: {}'.format(len(teach_dataset)))
74 |             self.teach_bootstrap(teach_dataset, config)
75 |
76 | learn_dataset = DatasetMegaDepthSGPAdaptor(
77 | DatasetMegaDepthSGP(config.datadir,
78 | config.scenes,
79 | pseudo_label_path_bs,
80 | 'learning',
81 | inlier_ratio_thr=config.inlier_ratio_thr,
82 | num_matches_thr=config.num_matches_thr,
83 | sample_rate=1), config)
84 | config.outdir = os.path.join(base_outdir, 'bs')
85 | config.logdir = os.path.join(base_logdir, 'bs')
86 |             self.learn(learn_dataset, config)
87 |
88 | config.match_ratio_test = False
89 | start_meta_iter = max(config.restart_meta_iter, 0)
90 | for i in range(start_meta_iter, config.max_meta_iters):
91 | pseudo_label_path_i = os.path.join(base_pseudo_label_dir,
92 | '{:02d}'.format(i))
93 | teach_dataset = DatasetMegaDepthSGP(config.datadir,
94 | config.scenes,
95 | pseudo_label_path_i,
96 | 'teaching',
97 | inlier_ratio_thr=config.inlier_ratio_thr,
98 | num_matches_thr=config.num_matches_thr,
99 | sample_rate=config.sample_rate)
100 | model = CAPSModel(config)
101 |             self.teach(teach_dataset, model, config)
102 |
103 | learn_dataset = DatasetMegaDepthSGPAdaptor(
104 | DatasetMegaDepthSGP(config.datadir,
105 | config.scenes,
106 | pseudo_label_path_i,
107 | 'learning',
108 | inlier_ratio_thr=config.inlier_ratio_thr,
109 | num_matches_thr=config.num_matches_thr,
110 | sample_rate=1), config)
111 |
112 | if not config.finetune:
113 | config.outdir = os.path.join(base_outdir, '{:02d}'.format(i))
114 | config.logdir = os.path.join(base_logdir, '{:02d}'.format(i))
115 |             self.learn(learn_dataset, config)
116 |
117 |
118 | if __name__ == '__main__':
119 | parser = CAPSConfigParser()
120 | parser.add(
121 | '--config',
122 | is_config_file=True,
123 | default=os.path.join(os.path.dirname(__file__),
124 | 'config_sgp_sample.yml'),
125 | help='YAML config file path. Please refer to caps_config.yml as a '
126 | 'reference. It overrides the default config file, but will be '
127 | 'overridden by other command line inputs.')
128 | parser.add('--debug', action='store_true')
129 | config = parser.get_config()
130 |
131 | sgp = SGP2DFundamental()
132 | sgp.run(config)
133 |
--------------------------------------------------------------------------------
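
Each meta-iteration of SGP2DFundamental.run above writes its pseudo fundamental matrices into its own subdirectory of pseudo_label_dir, and its checkpoints and logs into matching subdirectories of outdir and logdir (unless --finetune keeps reusing a single output directory). The bootstrap round teaches with SIFT-to-SIFT matching plus the ratio test; from the first learned round on, the ratio test is disabled for the CAPS teacher. A small sketch of the directory names produced for max_meta_iters = 2, matching the '{:02d}' formatting used above:

    # 'bs' is the SIFT bootstrap round, followed by one folder per teacher-student iteration.
    meta_rounds = ['bs'] + ['{:02d}'.format(i) for i in range(2)]
    print(meta_rounds)  # ['bs', '00', '01']
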
/code/perception2d/test.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | file_path = os.path.abspath(__file__)
4 | project_path = os.path.dirname(os.path.dirname(file_path))
5 | sys.path.append(project_path)
6 |
7 | caps_path = os.path.join(project_path, 'ext', 'caps')
8 | sys.path.append(caps_path)
9 |
10 | from dataset.megadepth_test import DatasetMegaDepthTest
11 | from perception2d.adaptor import CAPSConfigParser, caps_test
12 |
13 | import numpy as np
14 |
15 | if __name__ == '__main__':
16 | parser = CAPSConfigParser()
17 | parser.add(
18 | '--config',
19 | is_config_file=True,
20 | default=os.path.join(os.path.dirname(__file__), 'config_test.yml'),
21 | help='YAML config file path. Please refer to caps_config.yml as a '
22 | 'reference. It overrides the default config file, but will be '
23 | 'overridden by other command line inputs.')
24 | parser.add('--debug', action='store_true')
25 | parser.add('--output', type=str, default='caps_test_result.npz')
26 | config = parser.get_config()
27 |
28 |     # Note: for testing, our own interface suffices.
29 | config.match_ratio_test = False
30 | dataset = DatasetMegaDepthTest(config.datadir, config.scenes, config.label_dir)
31 | r_errs, t_errs = caps_test(dataset, config)
32 |
33 | rot_recall = (r_errs < 10.0)
34 | angular_trans_recall = (t_errs < 10.0)
35 | print('Rotation Recall: {}/{} = {}'.format(
36 | rot_recall.sum(), len(rot_recall),
37 | float(rot_recall.sum()) / len(rot_recall)))
38 | print('Translation Recall: {}/{} = {}'.format(
39 | angular_trans_recall.sum(), len(angular_trans_recall),
40 | float(angular_trans_recall.sum()) / len(angular_trans_recall)))
41 |
42 | np.savez(config.output, rotation_errs=r_errs, translation_errs=t_errs)
43 |
--------------------------------------------------------------------------------
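
test.py above reports recall at a 10-degree threshold for both the rotation error and the angular translation error, and saves the raw per-pair errors. The saved arrays can be re-analyzed offline, for example to sweep the threshold; a short sketch using the script's default output name:

    import numpy as np

    d = np.load('caps_test_result.npz')  # default --output of perception2d/test.py
    rot_recall = (d['rotation_errs'] < 10.0).mean()
    trans_recall = (d['translation_errs'] < 10.0).mean()  # translation error is angular, thresholded like rotation
    print(rot_recall, trans_recall)
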
/code/perception2d/train.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | file_path = os.path.abspath(__file__)
4 | project_path = os.path.dirname(os.path.dirname(file_path))
5 | sys.path.append(project_path)
6 |
7 | caps_path = os.path.join(project_path, 'ext', 'caps')
8 | sys.path.append(caps_path)
9 |
10 | from dataset.megadepth_train import DatasetMegaDepthTrain
11 | from perception2d.adaptor import CAPSConfigParser, DatasetMegaDepthTrainAdaptor, caps_train
12 |
13 | if __name__ == '__main__':
14 | parser = CAPSConfigParser()
15 | parser.add(
16 | '--config',
17 | is_config_file=True,
18 | default=os.path.join(os.path.dirname(__file__), 'config_train.yml'),
19 | help='YAML config file path. Please refer to caps_config.yml as a '
20 | 'reference. It overrides the default config file, but will be '
21 | 'overridden by other command line inputs.')
22 | config = parser.get_config()
23 |
24 |     # Note: for training, we wrap the dataset with an adaptor to provide a consistent interface.
25 | dataset = DatasetMegaDepthTrainAdaptor(
26 | DatasetMegaDepthTrain(config.datadir, config.scenes), config)
27 | caps_train(dataset, config)
28 |
--------------------------------------------------------------------------------
/code/perception3d/adaptor.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 |
3 | file_path = os.path.abspath(__file__)
4 | project_path = os.path.dirname(os.path.dirname(file_path))
5 | sys.path.append(project_path)
6 |
7 | fcgf_path = os.path.join(project_path, 'ext', 'FCGF')
8 | sys.path.append(fcgf_path)
9 |
10 | import torch
11 | from easydict import EasyDict as edict
12 |
13 | from ext.FCGF.lib.data_loaders import *
14 | from ext.FCGF.lib.trainer import *
15 | from ext.FCGF.model import load_model
16 |
17 | from dataset.threedmatch_train import Dataset3DMatchTrain
18 | from dataset.threedmatch_test import Dataset3DMatchTest
19 | from dataset.threedmatch_sgp import Dataset3DMatchSGP
20 |
21 | from geometry.pointcloud import *
22 | from geometry.common import rotation_error, translation_error
23 |
24 | from tqdm import tqdm
25 |
26 | import configargparse
27 |
28 |
29 | def reload_config(config):
30 | dconfig = vars(config)
31 |
32 | if config.resume_dir:
33 | resume_config = json.load(open(config.resume_dir + '/config.json',
34 | 'r'))
35 | for k in dconfig:
36 | if k not in ['resume_dir'] and k in resume_config:
37 | dconfig[k] = resume_config[k]
38 | dconfig['resume'] = resume_config['out_dir'] + '/checkpoint.pth'
39 |
40 | return edict(dconfig)
41 |
42 |
43 | class FCGFConfigParser(configargparse.ArgParser):
44 | def __init__(self):
45 | super().__init__(default_config_files=[
46 | os.path.join(os.path.dirname(__file__), 'fcgf_config.yml')
47 | ],
48 | conflict_handler='resolve')
49 |
50 | # Mainly used params
51 | self.add('--dataset_path',
52 | type=str,
53 | default="/home/wei/Workspace/data/threedmatch_reorg")
54 | self.add('--scenes',
55 | nargs='+',
56 | help='scenes used for training/testing')
57 | self.add('--out_dir',
58 | type=str,
59 | default='fcgf_outputs',
60 | help='outputs containing summary and checkpoints')
61 |
62 | self.add(
63 | '--restart_meta_iter',
64 | type=int,
65 | default=-1,
66 | help='Restart of teacher-student iterations. -1 indicates bootstrap'
67 | )
68 | self.add('--meta_iters',
69 | type=int,
70 | default=2,
71 | help='number of teacher-student iterations')
72 | self.add('--finetune',
73 | action='store_true',
74 | help='train from previous checkpoint during SGP.')
75 |
76 | self.add('--pseudo_label_dir',
77 | type=str,
78 | help='the pseudo-gt directory storing pairs and T matrices')
79 |
80 | self.add('--overlap_thr',
81 | type=float,
82 | default=0.3,
83 | help='overlap threshold to filter outlier pairs')
84 | self.add('--voxel_size', type=float, default=0.05)
85 | self.add('--mutual_filter', type=bool, default=False)
86 | self.add('--ransac_iters', type=int, default=10000)
87 | self.add('--confidence', type=float, default=0.9999)
88 |
89 | # Other core configs from the FCGF repo.
90 | # See https://github.com/chrischoy/FCGF/blob/master/config.py
91 | self.add('--trainer',
92 | type=str,
93 | default='HardestContrastiveLossTrainer')
94 | self.add('--save_freq_epoch', type=int, default=1)
95 | self.add('--batch_size', type=int, default=4)
96 | self.add('--val_batch_size', type=int, default=1)
97 |
98 | # Hard negative mining
99 | self.add('--use_hard_negative', type=bool, default=True)
100 |         self.add('--hard_negative_sample_ratio', type=float, default=0.05)
101 | self.add('--hard_negative_max_num', type=int, default=3000)
102 | self.add('--num_pos_per_batch', type=int, default=1024)
103 | self.add('--num_hn_samples_per_batch', type=int, default=256)
104 |
105 | # Metric learning loss
106 | self.add('--neg_thresh', type=float, default=1.4)
107 | self.add('--pos_thresh', type=float, default=0.1)
108 | self.add('--neg_weight', type=float, default=1)
109 |
110 | # Data augmentation
111 | self.add('--use_random_scale', type=bool, default=False)
112 | self.add('--min_scale', type=float, default=0.8)
113 | self.add('--max_scale', type=float, default=1.2)
114 | self.add('--use_random_rotation', type=bool, default=True)
115 | self.add('--rotation_range', type=float, default=360)
116 |
117 | # Data loader configs
118 | self.add('--train_phase', type=str, default="train")
119 | self.add('--val_phase', type=str, default="val")
120 | self.add('--test_phase', type=str, default="test")
121 |
122 | self.add('--stat_freq', type=int, default=40)
123 | self.add('--test_valid', type=bool, default=True)
124 | self.add('--val_max_iter', type=int, default=400)
125 | self.add('--val_epoch_freq', type=int, default=1)
126 | self.add('--positive_pair_search_voxel_size_multiplier',
127 | type=float,
128 | default=1.5)
129 |
130 | self.add('--hit_ratio_thresh', type=float, default=0.1)
131 |
132 | # Triplets
133 | self.add('--triplet_num_pos', type=int, default=256)
134 | self.add('--triplet_num_hn', type=int, default=512)
135 | self.add('--triplet_num_rand', type=int, default=1024)
136 |
137 | # Network specific configurations
138 | self.add('--model', type=str, default='ResUNetBN2C')
139 | self.add('--model_n_out',
140 | type=int,
141 | default=32,
142 | help='Feature dimension')
143 | self.add('--conv1_kernel_size', type=int, default=5)
144 | self.add('--normalize_feature', type=bool, default=True)
145 | self.add('--dist_type', type=str, default='L2')
146 | self.add('--best_val_metric', type=str, default='feat_match_ratio')
147 |
148 | # Optimizer arguments
149 | self.add('--optimizer', type=str, default='SGD')
150 | self.add('--max_epoch', type=int, default=100)
151 | self.add('--lr', type=float, default=1e-1)
152 | self.add('--momentum', type=float, default=0.8)
153 | self.add('--sgd_momentum', type=float, default=0.9)
154 | self.add('--sgd_dampening', type=float, default=0.1)
155 | self.add('--adam_beta1', type=float, default=0.9)
156 | self.add('--adam_beta2', type=float, default=0.999)
157 | self.add('--weight_decay', type=float, default=1e-4)
158 | self.add('--iter_size',
159 | type=int,
160 | default=1,
161 | help='accumulate gradient')
162 | self.add('--bn_momentum', type=float, default=0.05)
163 | self.add('--exp_gamma', type=float, default=0.99)
164 | self.add('--scheduler', type=str, default='ExpLR')
165 |
166 | self.add('--use_gpu', type=bool, default=True)
167 | self.add('--weights', type=str, default=None)
168 | self.add('--weights_dir', type=str, default=None)
169 | self.add('--resume', type=str, default=None)
170 | self.add('--resume_dir', type=str, default=None)
171 | self.add('--train_num_thread', type=int, default=2)
172 | self.add('--val_num_thread', type=int, default=1)
173 | self.add('--test_num_thread', type=int, default=2)
174 | self.add('--fast_validation', type=bool, default=False)
175 | self.add(
176 | '--nn_max_n',
177 | type=int,
178 | default=500,
179 | help=
180 | 'The maximum number of features to find nearest neighbors in batch'
181 | )
182 |
183 | def get_config(self):
184 | config = self.parse_args()
185 | config.device = 'cuda' if config.use_gpu else 'cpu'
186 |
187 | return reload_config(config)
188 |
189 |
190 | def get_trainer(trainer):
191 | if trainer == 'ContrastiveLossTrainer':
192 | return ContrastiveLossTrainer
193 | elif trainer == 'HardestContrastiveLossTrainer':
194 | return HardestContrastiveLossTrainer
195 | elif trainer == 'TripletLossTrainer':
196 | return TripletLossTrainer
197 | elif trainer == 'HardestTripletLossTrainer':
198 | return HardestTripletLossTrainer
199 | else:
200 | raise ValueError(f'Trainer {trainer} not found')
201 |
202 |
203 | class DatasetFCGFAdaptor(torch.utils.data.Dataset):
204 | '''
205 |     Wrapper dataset that converts our data format into FCGF's training sample format
206 | '''
207 | def __init__(self, dataset, config):
208 | self.dataset = dataset
209 | self.randg = np.random.RandomState()
210 | self.config = config
211 |
212 | def reset_seed(self, seed):
213 | self.randg.seed(seed)
214 |
215 | def apply_transform(self, pts, trans):
216 | R = trans[:3, :3]
217 | T = trans[:3, 3]
218 | pts = pts @ R.T + T
219 | return pts
220 |
221 | def __len__(self):
222 | return len(self.dataset)
223 |
224 | def __getitem__(self, idx):
225 | pcd0, pcd1, _, _, trans = self.dataset[idx]
226 |
227 | xyz0 = np.asarray(pcd0.points)
228 | xyz1 = np.asarray(pcd1.points)
229 |
230 | # Data augmentation
231 | T0 = sample_random_trans(xyz0, self.randg, 360)
232 | T1 = sample_random_trans(xyz1, self.randg, 360)
233 | trans = T1 @ trans @ np.linalg.inv(T0)
234 |
235 | xyz0 = self.apply_transform(xyz0, T0)
236 | xyz1 = self.apply_transform(xyz1, T1)
237 |
238 | # Voxelization after random transformation
239 |         voxel_size = self.config.voxel_size  # default 0.05, see --voxel_size
240 | _, sel0 = ME.utils.sparse_quantize(xyz0,
241 | return_index=True,
242 | quantization_size=voxel_size)
243 | _, sel1 = ME.utils.sparse_quantize(xyz1,
244 | return_index=True,
245 | quantization_size=voxel_size)
246 | xyz0 = xyz0[sel0]
247 | xyz1 = xyz1[sel1]
248 |
249 | # Make point clouds using voxelized points
250 | pcd0 = make_o3d_pointcloud(xyz0)
251 | pcd1 = make_o3d_pointcloud(xyz1)
252 | matches = get_matching_indices(pcd0, pcd1, trans, voxel_size * 2)
253 |
254 | # Dummy features
255 | feats0 = np.ones((xyz0.shape[0], 1))
256 | feats1 = np.ones((xyz1.shape[0], 1))
257 |
258 | # Coordinates
259 | coords0 = np.floor(xyz0 / voxel_size)
260 | coords1 = np.floor(xyz1 / voxel_size)
261 |
262 | return (xyz0, xyz1, coords0, coords1, feats0, feats1, matches, trans)
263 |
264 |
265 | def load_fcgf_model(config):
266 | resume_ckpt_path = config.resume
267 | input_ckpt_path = config.weights
268 | out_ckpt_path = os.path.join(config.out_dir, 'checkpoint.pth')
269 |
270 | if resume_ckpt_path is not None and os.path.isfile(resume_ckpt_path):
271 | ckpt_path = resume_ckpt_path
272 | elif input_ckpt_path is not None and os.path.isfile(input_ckpt_path):
273 | ckpt_path = input_ckpt_path
274 | elif out_ckpt_path is not None and os.path.isfile(out_ckpt_path):
275 | ckpt_path = out_ckpt_path
276 | else:
277 | raise NotImplementedError('checkpoint not found, abort')
278 |
279 | print(f'load FCGF from checkpoint {ckpt_path}.')
280 | checkpoint = torch.load(ckpt_path)
281 | ckpt_cfg = checkpoint['config']
282 |
283 | Model = load_model(ckpt_cfg['model'])
284 | model = Model(in_channels=1,
285 | out_channels=ckpt_cfg['model_n_out'],
286 | bn_momentum=ckpt_cfg['bn_momentum'],
287 | normalize_feature=ckpt_cfg['normalize_feature'],
288 | conv1_kernel_size=ckpt_cfg['conv1_kernel_size'],
289 | D=3)
290 | model.load_state_dict(checkpoint['state_dict'])
291 | return model.to(config.device)
292 |
293 |
294 | def register(pcd_src, pcd_dst, feature, solver, model, config):
295 | pcd_src, feat_src = extract_feats(pcd_src, feature, config.voxel_size,
296 | model)
297 | pcd_dst, feat_dst = extract_feats(pcd_dst, feature, config.voxel_size,
298 | model)
299 | corrs = match_feats(feat_src, feat_dst, mutual_filter=config.mutual_filter)
300 |
301 | if len(corrs) < 10:
302 | print('Too few corres ({}), abort'.format(len(corrs)))
303 | return np.eye(4), 0
304 |
305 | T, fitness = solve(pcd_src, pcd_dst, corrs, solver,
306 | config.voxel_size * 1.4, config.ransac_iters,
307 | config.confidence)
308 | if fitness > 1e-6:
309 | T, fitness = refine(pcd_src, pcd_dst, T, config.voxel_size * 1.4)
310 |
311 | return T, fitness
312 |
313 |
314 | def fcgf_train(dataset, config):
315 | ch = logging.StreamHandler(sys.stdout)
316 | logging.getLogger().setLevel(logging.INFO)
317 | logging.basicConfig(format='%(asctime)s %(message)s',
318 | datefmt='%m/%d %H:%M:%S',
319 | handlers=[ch])
320 |
321 | torch.manual_seed(0)
322 | torch.cuda.manual_seed(0)
323 |
324 | logging.basicConfig(level=logging.INFO, format="")
325 |
326 | dataloader = torch.utils.data.DataLoader(dataset,
327 | batch_size=8,
328 | shuffle=True,
329 | num_workers=8,
330 | collate_fn=collate_pair_fn,
331 | pin_memory=False,
332 | drop_last=True)
333 | Trainer = get_trainer(config.trainer)
334 | trainer = Trainer(config=config, data_loader=dataloader)
335 | trainer.train()
336 |
337 |
338 | def fcgf_test(dataset, config):
339 | model = load_fcgf_model(config)
340 |
341 | r_errs = []
342 | t_errs = []
343 |
344 | for data in tqdm(dataset):
345 | pcd_src, pcd_dst, _, _, T_gt = data
346 |
347 | T, fitness = register(pcd_src, pcd_dst, 'FCGF', 'RANSAC', model,
348 | config)
349 | r_err = rotation_error(T[:3, :3], T_gt[:3, :3])
350 | t_err = translation_error(T[:3, 3], T_gt[:3, 3])
351 | r_errs.append(r_err)
352 | t_errs.append(t_err)
353 |
354 | if config.debug:
355 | pcd_src.paint_uniform_color([1, 0, 0])
356 | pcd_dst.paint_uniform_color([0, 1, 0])
357 | o3d.visualization.draw_geometries([pcd_src.transform(T), pcd_dst])
358 |
359 | return np.array(r_errs), np.array(t_errs)
360 |
--------------------------------------------------------------------------------
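
For a quick sanity check of the 3D pipeline, the pieces in adaptor.py above can be used directly: load a checkpoint with load_fcgf_model, then register a single test pair with register. A hedged sketch, assuming ext/FCGF is on sys.path (as the entry scripts below arrange) and using placeholder paths for the checkpoint; dataset_path and scenes are expected to come from the YAML config or command-line flags:

    from perception3d.adaptor import FCGFConfigParser, load_fcgf_model, register
    from dataset.threedmatch_test import Dataset3DMatchTest

    config = FCGFConfigParser().get_config()
    config.weights = '/path/to/fcgf-pretrained.pth'  # placeholder checkpoint path
    model = load_fcgf_model(config)

    # Each sample unpacks as in fcgf_test above: (pcd_src, pcd_dst, _, _, T_gt).
    pcd_src, pcd_dst, _, _, T_gt = Dataset3DMatchTest(config.dataset_path, config.scenes)[0]
    T, fitness = register(pcd_src, pcd_dst, 'FCGF', 'RANSAC', model, config)
    print(fitness)
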
/code/perception3d/config_sgp.yml:
--------------------------------------------------------------------------------
1 | dataset_path: '/home/wei/Workspace/data/threedmatch_reorg'
2 | pseudo_label_dir: fcgf_pseudo_label
3 | scenes: [7-scenes-chess@seq-01, 7-scenes-chess@seq-02, 7-scenes-chess@seq-03, 7-scenes-chess@seq-04, 7-scenes-chess@seq-05, 7-scenes-chess@seq-06, 7-scenes-fire@seq-01, 7-scenes-fire@seq-02, 7-scenes-fire@seq-03, 7-scenes-fire@seq-04, 7-scenes-heads@seq-01, 7-scenes-heads@seq-02, 7-scenes-office@seq-01, 7-scenes-office@seq-02, 7-scenes-office@seq-03, 7-scenes-office@seq-04, 7-scenes-office@seq-05, 7-scenes-office@seq-06, 7-scenes-office@seq-07, 7-scenes-office@seq-08, 7-scenes-office@seq-09, 7-scenes-office@seq-10, 7-scenes-pumpkin@seq-01, 7-scenes-pumpkin@seq-02, 7-scenes-pumpkin@seq-03, 7-scenes-pumpkin@seq-06, 7-scenes-pumpkin@seq-07, 7-scenes-pumpkin@seq-08, 7-scenes-redkitchen@seq-01, 7-scenes-redkitchen@seq-02, 7-scenes-redkitchen@seq-03, 7-scenes-redkitchen@seq-04, 7-scenes-redkitchen@seq-05, 7-scenes-redkitchen@seq-06, 7-scenes-redkitchen@seq-07, 7-scenes-redkitchen@seq-08, 7-scenes-redkitchen@seq-11, 7-scenes-redkitchen@seq-12, 7-scenes-redkitchen@seq-13, 7-scenes-redkitchen@seq-14, 7-scenes-stairs@seq-01, 7-scenes-stairs@seq-02, 7-scenes-stairs@seq-03, 7-scenes-stairs@seq-04, 7-scenes-stairs@seq-05, 7-scenes-stairs@seq-06, analysis-by-synthesis-apt1-kitchen@seq-01, analysis-by-synthesis-apt1-living@seq-01, analysis-by-synthesis-apt2-bed@seq-01, analysis-by-synthesis-apt2-kitchen@seq-01, analysis-by-synthesis-apt2-living@seq-01, analysis-by-synthesis-apt2-luke@seq-01, analysis-by-synthesis-office2-5a@seq-01, analysis-by-synthesis-office2-5b@seq-01, bundlefusion-apt0@seq-01, bundlefusion-apt1@seq-01, bundlefusion-apt2@seq-01, bundlefusion-copyroom@seq-01, bundlefusion-office0@seq-01, bundlefusion-office1@seq-01, bundlefusion-office2@seq-01, bundlefusion-office3@seq-01, rgbd-scenes-v2-scene_01@seq-01, rgbd-scenes-v2-scene_02@seq-01, rgbd-scenes-v2-scene_03@seq-01, rgbd-scenes-v2-scene_04@seq-01, rgbd-scenes-v2-scene_05@seq-01, rgbd-scenes-v2-scene_06@seq-01, rgbd-scenes-v2-scene_07@seq-01, rgbd-scenes-v2-scene_08@seq-01, rgbd-scenes-v2-scene_09@seq-01, rgbd-scenes-v2-scene_10@seq-01, rgbd-scenes-v2-scene_11@seq-01, rgbd-scenes-v2-scene_12@seq-01, rgbd-scenes-v2-scene_13@seq-01, rgbd-scenes-v2-scene_14@seq-01, sun3d-brown_bm_1-brown_bm_1@seq-01, sun3d-brown_bm_4-brown_bm_4@seq-01, sun3d-brown_cogsci_1-brown_cogsci_1@seq-01, sun3d-brown_cs_2-brown_cs2@seq-01, sun3d-brown_cs_3-brown_cs3@seq-01, sun3d-harvard_c11-hv_c11_2@seq-01, sun3d-harvard_c3-hv_c3_1@seq-01, sun3d-harvard_c5-hv_c5_1@seq-01, sun3d-harvard_c6-hv_c6_1@seq-01, sun3d-harvard_c8-hv_c8_3@seq-01, sun3d-home_bksh-home_bksh_oct_30_2012_scan2_erika@seq-01, sun3d-home_md-home_md_scan9_2012_sep_30@seq-01, sun3d-hotel_nips2012-nips_4@seq-01, sun3d-hotel_sf-scan1@seq-01, sun3d-hotel_uc-scan3@seq-01, sun3d-hotel_umd-maryland_hotel1@seq-01, sun3d-hotel_umd-maryland_hotel3@seq-01, sun3d-mit_32_d507-d507_2@seq-01, sun3d-mit_46_ted_lab1-ted_lab_2@seq-01, sun3d-mit_76_417-76-417b@seq-01, sun3d-mit_76_studyroom-76-1studyroom2@seq-01, sun3d-mit_dorm_next_sj-dorm_next_sj_oct_30_2012_scan1_erika@seq-01, sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika@seq-01, sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika@seq-01]
4 |
--------------------------------------------------------------------------------
/code/perception3d/config_sgp_sample.yml:
--------------------------------------------------------------------------------
1 | dataset_path: '/home/wei/Workspace/data/threedmatch_reorg'
2 | pseudo_label_dir: fcgf_pseudo_label
3 | scenes: [7-scenes-chess@seq-01]
4 |
--------------------------------------------------------------------------------
/code/perception3d/config_test.yml:
--------------------------------------------------------------------------------
1 | dataset_path: '/home/wei/Workspace/data/threedmatch_test'
2 | weights: '/home/wei/Downloads/fcgf-pretrained.pth'
3 | scenes: [7-scenes-redkitchen, sun3d-home_at-home_at_scan1_2013_jan_1, sun3d-home_md-home_md_scan9_2012_sep_30, sun3d-hotel_uc-scan3, sun3d-hotel_umd-maryland_hotel1, sun3d-hotel_umd-maryland_hotel3, sun3d-mit_76_studyroom-76-1studyroom2, sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika]
4 |
--------------------------------------------------------------------------------
/code/perception3d/config_train.yml:
--------------------------------------------------------------------------------
1 | dataset_path: '/home/wei/Workspace/data/threedmatch_reorg'
2 | scenes: [7-scenes-chess@seq-01, 7-scenes-chess@seq-02, 7-scenes-chess@seq-03, 7-scenes-chess@seq-04, 7-scenes-chess@seq-05, 7-scenes-chess@seq-06, 7-scenes-fire@seq-01, 7-scenes-fire@seq-02, 7-scenes-fire@seq-03, 7-scenes-fire@seq-04, 7-scenes-heads@seq-01, 7-scenes-heads@seq-02, 7-scenes-office@seq-01, 7-scenes-office@seq-02, 7-scenes-office@seq-03, 7-scenes-office@seq-04, 7-scenes-office@seq-05, 7-scenes-office@seq-06, 7-scenes-office@seq-07, 7-scenes-office@seq-08, 7-scenes-office@seq-09, 7-scenes-office@seq-10, 7-scenes-pumpkin@seq-01, 7-scenes-pumpkin@seq-02, 7-scenes-pumpkin@seq-03, 7-scenes-pumpkin@seq-06, 7-scenes-pumpkin@seq-07, 7-scenes-pumpkin@seq-08, 7-scenes-redkitchen@seq-01, 7-scenes-redkitchen@seq-02, 7-scenes-redkitchen@seq-03, 7-scenes-redkitchen@seq-04, 7-scenes-redkitchen@seq-05, 7-scenes-redkitchen@seq-06, 7-scenes-redkitchen@seq-07, 7-scenes-redkitchen@seq-08, 7-scenes-redkitchen@seq-11, 7-scenes-redkitchen@seq-12, 7-scenes-redkitchen@seq-13, 7-scenes-redkitchen@seq-14, 7-scenes-stairs@seq-01, 7-scenes-stairs@seq-02, 7-scenes-stairs@seq-03, 7-scenes-stairs@seq-04, 7-scenes-stairs@seq-05, 7-scenes-stairs@seq-06, analysis-by-synthesis-apt1-kitchen@seq-01, analysis-by-synthesis-apt1-living@seq-01, analysis-by-synthesis-apt2-bed@seq-01, analysis-by-synthesis-apt2-kitchen@seq-01, analysis-by-synthesis-apt2-living@seq-01, analysis-by-synthesis-apt2-luke@seq-01, analysis-by-synthesis-office2-5a@seq-01, analysis-by-synthesis-office2-5b@seq-01, bundlefusion-apt0@seq-01, bundlefusion-apt1@seq-01, bundlefusion-apt2@seq-01, bundlefusion-copyroom@seq-01, bundlefusion-office0@seq-01, bundlefusion-office1@seq-01, bundlefusion-office2@seq-01, bundlefusion-office3@seq-01, rgbd-scenes-v2-scene_01@seq-01, rgbd-scenes-v2-scene_02@seq-01, rgbd-scenes-v2-scene_03@seq-01, rgbd-scenes-v2-scene_04@seq-01, rgbd-scenes-v2-scene_05@seq-01, rgbd-scenes-v2-scene_06@seq-01, rgbd-scenes-v2-scene_07@seq-01, rgbd-scenes-v2-scene_08@seq-01, rgbd-scenes-v2-scene_09@seq-01, rgbd-scenes-v2-scene_10@seq-01, rgbd-scenes-v2-scene_11@seq-01, rgbd-scenes-v2-scene_12@seq-01, rgbd-scenes-v2-scene_13@seq-01, rgbd-scenes-v2-scene_14@seq-01, sun3d-brown_bm_1-brown_bm_1@seq-01, sun3d-brown_bm_4-brown_bm_4@seq-01, sun3d-brown_cogsci_1-brown_cogsci_1@seq-01, sun3d-brown_cs_2-brown_cs2@seq-01, sun3d-brown_cs_3-brown_cs3@seq-01, sun3d-harvard_c11-hv_c11_2@seq-01, sun3d-harvard_c3-hv_c3_1@seq-01, sun3d-harvard_c5-hv_c5_1@seq-01, sun3d-harvard_c6-hv_c6_1@seq-01, sun3d-harvard_c8-hv_c8_3@seq-01, sun3d-home_bksh-home_bksh_oct_30_2012_scan2_erika@seq-01, sun3d-home_md-home_md_scan9_2012_sep_30@seq-01, sun3d-hotel_nips2012-nips_4@seq-01, sun3d-hotel_sf-scan1@seq-01, sun3d-hotel_uc-scan3@seq-01, sun3d-hotel_umd-maryland_hotel1@seq-01, sun3d-hotel_umd-maryland_hotel3@seq-01, sun3d-mit_32_d507-d507_2@seq-01, sun3d-mit_46_ted_lab1-ted_lab_2@seq-01, sun3d-mit_76_417-76-417b@seq-01, sun3d-mit_76_studyroom-76-1studyroom2@seq-01, sun3d-mit_dorm_next_sj-dorm_next_sj_oct_30_2012_scan1_erika@seq-01, sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika@seq-01, sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika@seq-01]
3 |
--------------------------------------------------------------------------------
/code/perception3d/sgp.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 |
3 | file_path = os.path.abspath(__file__)
4 | project_path = os.path.dirname(os.path.dirname(file_path))
5 | sys.path.append(project_path)
6 |
7 | fcgf_path = os.path.join(project_path, 'ext', 'FCGF')
8 | sys.path.append(fcgf_path)
9 |
10 | from tqdm import tqdm
11 |
12 | import torch
13 | import numpy as np
14 | import open3d as o3d
15 |
16 | from sgp_base import SGPBase
17 | from dataset.threedmatch_sgp import Dataset3DMatchSGP
18 | from perception3d.adaptor import DatasetFCGFAdaptor, FCGFConfigParser, load_fcgf_model, fcgf_train, reload_config, register
19 |
20 |
21 | class SGP3DRegistration(SGPBase):
22 | def __init__(self):
23 | super(SGP3DRegistration, self).__init__()
24 |
25 | # override
26 | def perception_bootstrap(self, src_data, dst_data, src_info, dst_info,
27 | config):
28 | T, fitness = register(src_data, dst_data, 'FPFH', 'RANSAC', None,
29 | config)
30 | if config.debug:
31 | src_data.paint_uniform_color([1, 0, 0])
32 | dst_data.paint_uniform_color([0, 1, 0])
33 | o3d.visualization.draw([src_data.transform(T), dst_data])
34 | return T, fitness
35 |
36 | # override
37 | def perception(self, src_data, dst_data, src_info, dst_info, model,
38 | config):
39 | T, fitness = register(src_data, dst_data, 'FCGF', 'RANSAC', model,
40 | config)
41 | if config.debug:
42 | src_data.paint_uniform_color([1, 0, 0])
43 | dst_data.paint_uniform_color([0, 1, 0])
44 | o3d.visualization.draw([src_data.transform(T), dst_data])
45 | return T, fitness
46 |
47 | # override
48 | def train_adaptor(self, sgp_dataset, config):
49 | fcgf_train(sgp_dataset, config)
50 |
51 | def run(self, config):
52 | epochs = config.max_epoch
53 | base_pseudo_label_dir = config.pseudo_label_dir
54 | base_outdir = config.out_dir
55 |
56 | # Bootstrap
57 | if config.restart_meta_iter < 0:
58 | pseudo_label_path_bs = os.path.join(base_pseudo_label_dir, 'bs')
59 | teach_dataset = Dataset3DMatchSGP(config.dataset_path,
60 | config.scenes,
61 | pseudo_label_path_bs, 'teaching',
62 | config.overlap_thr)
63 | # We need mutual filter for less reliable FPFH
64 | config.mutual_filter = True
65 |             self.teach_bootstrap(teach_dataset, config)
66 |
67 | learn_dataset = DatasetFCGFAdaptor(
68 | Dataset3DMatchSGP(config.dataset_path, config.scenes,
69 | pseudo_label_path_bs, 'learning',
70 | config.overlap_thr), config)
71 | config.out_dir = os.path.join(base_outdir, 'bs')
72 |             self.learn(learn_dataset, config)
73 |
74 | # Loop
75 | start_meta_iter = max(config.restart_meta_iter, 0)
76 | for i in range(start_meta_iter, config.meta_iters):
77 | pseudo_label_path_i = os.path.join(config.pseudo_label_dir,
78 | '{:02d}'.format(i))
79 | teach_dataset = Dataset3DMatchSGP(config.dataset_path,
80 | config.scenes,
81 | pseudo_label_path_i, 'teaching',
82 | config.overlap_thr)
83 |
84 | # No mutual filter results in better FCGF teaching
85 | config.mutual_filter = False
86 | model = load_fcgf_model(config)
87 |             self.teach(teach_dataset, model, config)
88 |
89 | learn_dataset = DatasetFCGFAdaptor(
90 | Dataset3DMatchSGP(config.dataset_path, config.scenes,
91 | pseudo_label_path_i, 'learning',
92 | config.overlap_thr), config)
93 |
94 | # There is a bug in FCGF finetuning.
95 | # Suppose previous epochs are [1, n],
96 | # then finetuning will be [n, n+n] (double counting n), instead of [n+1, n+n]
97 | # To address this without changing the original repo, we need to reduce max epochs by 1.
98 | # The actual finetuning iters will be correct, while FCGF's output will be slightly different.
99 | if config.finetune:
100 | config.resume_dir = config.out_dir
101 | config = reload_config(config)
102 | config.max_epoch += (epochs - 1)
103 | else:
104 | config.out_dir = os.path.join(base_outdir, '{:02d}'.format(i))
105 |
106 |             self.learn(learn_dataset, config)
107 |
108 |
109 | if __name__ == '__main__':
110 | parser = FCGFConfigParser()
111 | parser.add(
112 | '--config',
113 | is_config_file=True,
114 | default=os.path.join(os.path.dirname(__file__),
115 | 'config_sgp_sample.yml'),
116 |         help='YAML config file path. Please refer to fcgf_config.yml as a '
117 | 'reference. It overrides the default config file, but will be '
118 | 'overridden by other command line inputs.')
119 | parser.add('--debug', action='store_true')
120 | config = parser.get_config()
121 |
122 | sgp = SGP3DRegistration()
123 | sgp.run(config)
124 |
--------------------------------------------------------------------------------
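
The epoch bookkeeping in the finetune branch of SGP3DRegistration.run above is easiest to see with numbers. With max_epoch = 100, the bootstrap round trains epochs 1 through 100. When FCGF resumes from that checkpoint it restarts counting at epoch 100 rather than 101, so naively adding another 100 epochs would re-run epoch 100 and finish at 200, one epoch too many. Adding (epochs - 1) instead keeps the amount of actual training correct:

    n = 100                            # epochs already trained; config.max_epoch of the previous round
    epochs = 100                       # additional epochs requested for this finetune round
    new_max_epoch = n + (epochs - 1)   # 199: FCGF trains epochs [100, 199], i.e. exactly 100 more epochs
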
/code/perception3d/test.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 |
3 | file_path = os.path.abspath(__file__)
4 | project_path = os.path.dirname(os.path.dirname(file_path))
5 | sys.path.append(project_path)
6 |
7 | fcgf_path = os.path.join(project_path, 'ext', 'FCGF')
8 | sys.path.append(fcgf_path)
9 |
10 | from dataset.threedmatch_test import Dataset3DMatchTest
11 | from perception3d.adaptor import FCGFConfigParser, fcgf_test
12 |
13 | import numpy as np
14 |
15 | if __name__ == '__main__':
16 | parser = FCGFConfigParser()
17 | parser.add(
18 | '--config',
19 | is_config_file=True,
20 | default=os.path.join(os.path.dirname(__file__), 'config_test.yml'),
21 |         help='YAML config file path. Please refer to fcgf_config.yml as a '
22 | 'reference. It overrides the default config file, but will be '
23 | 'overridden by other command line inputs.')
24 | parser.add('--debug', action='store_true')
25 | parser.add('--output', type=str, default='fcgf_test_result.npz')
26 | config = parser.get_config()
27 |
28 | dataset = Dataset3DMatchTest(config.dataset_path, config.scenes)
29 | r_errs, t_errs = fcgf_test(dataset, config)
30 |
31 | recall = (r_errs < 15.0) * (t_errs < 0.3)
32 | print('Recall: {}/{} = {}'.format(recall.sum(), len(recall),
33 | float(recall.sum()) / len(recall)))
34 |
35 | np.savez(config.output, rotation_errs=r_errs, translation_errs=t_errs)
36 |
--------------------------------------------------------------------------------
/code/perception3d/train.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 |
3 | file_path = os.path.abspath(__file__)
4 | project_path = os.path.dirname(os.path.dirname(file_path))
5 | sys.path.append(project_path)
6 |
7 | fcgf_path = os.path.join(project_path, 'ext', 'FCGF')
8 | sys.path.append(fcgf_path)
9 |
10 | from dataset.threedmatch_train import Dataset3DMatchTrain
11 | from perception3d.adaptor import DatasetFCGFAdaptor, FCGFConfigParser, fcgf_train
12 |
13 | if __name__ == '__main__':
14 | parser = FCGFConfigParser()
15 | parser.add(
16 | '--config',
17 | is_config_file=True,
18 | default=os.path.join(os.path.dirname(__file__), 'config_train.yml'),
19 |         help='YAML config file path. Please refer to fcgf_config.yml as a '
20 | 'reference. It overrides the default config file, but will be '
21 | 'overridden by other command line inputs.')
22 | config = parser.get_config()
23 |
24 | dataset = DatasetFCGFAdaptor(
25 | Dataset3DMatchTrain(config.dataset_path, config.scenes, config.overlap_thr), config)
26 | fcgf_train(dataset, config)
27 |
--------------------------------------------------------------------------------
/code/sgp_base.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 |
3 | class SGPBase:
4 | def __init__(self):
5 | pass
6 |
7 | def teach_bootstrap(self, sgp_dataset, config):
8 | '''
9 | Teach without deep models. Use classical FPFH/SIFT features.
10 | '''
11 | for i, data in tqdm(enumerate(sgp_dataset)):
12 | src_data, dst_data, src_info, dst_info, pair_info = data
13 |
14 | label, pair_info = self.perception_bootstrap(
15 | src_data, dst_data, src_info, dst_info, config)
16 | sgp_dataset.write_pseudo_label(i, label, pair_info)
17 |
18 | def teach(self, sgp_dataset, model, config):
19 | '''
20 | Teach with deep models. Use learned FCGF/CAPS features.
21 | '''
22 | for i, data in tqdm(enumerate(sgp_dataset)):
23 | src_data, dst_data, src_info, dst_info, pair_info = data
24 |
25 | # if self.is_valid(src_info, dst_info, pair_info):
26 | label, pair_info = self.perception(src_data, dst_data, src_info,
27 | dst_info, model, config)
28 | sgp_dataset.write_pseudo_label(i, label, pair_info)
29 |
30 | def learn(self, sgp_dataset, config):
31 | # Adapt and dispatch training script to external implementations
32 | self.train_adaptor(sgp_dataset, config)
33 |
34 | # override
35 | def train_adaptor(self, sgp_dataset, config):
36 | pass
37 |
38 | # override
39 |     def perception_bootstrap(self, src_data, dst_data, src_info, dst_info, config):
40 | pass
41 |
42 | # override
43 |     def perception(self, src_data, dst_data, src_info, dst_info, model, config):
44 | pass
45 |
46 |
--------------------------------------------------------------------------------
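
sgp_base.py is deliberately minimal: a new perception task plugs into SGP by overriding the three hooks, exactly as SGP2DFundamental and SGP3DRegistration do above. A skeleton of such a subclass (illustrative only; the hook bodies, dataset, and training code are placeholders):

    from sgp_base import SGPBase

    class SGPMyTask(SGPBase):
        # Bootstrap teacher: classical, model-free perception returning (label, pair_info).
        def perception_bootstrap(self, src_data, dst_data, src_info, dst_info, config):
            raise NotImplementedError

        # Learned teacher: same contract, but driven by the current model.
        def perception(self, src_data, dst_data, src_info, dst_info, model, config):
            raise NotImplementedError

        # Dispatch the pseudo-labeled dataset to the task's own training loop.
        def train_adaptor(self, sgp_dataset, config):
            raise NotImplementedError
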