├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── area_matchers ├── AGBasic.py ├── AGConfig.py ├── AGMatcherFree.py ├── AGUtils.py ├── AreaGrapher.py ├── AreaMatchDense.py ├── AreaPreprocessor.py ├── CoarseAreaMatcher.py ├── abstract_am.py ├── dmesa.py ├── mesa.py └── sem_am.py ├── assets ├── A2PM.png ├── Qua.png ├── mesa-ava.png ├── mesa-main.png └── run_MESA_on_win11.md ├── conf ├── area_matcher │ ├── dmesa.yaml │ ├── mesa-f.yaml │ └── sem_area_matcher.yaml ├── dataset │ ├── demo_pair.yaml │ ├── megadepth.yaml │ ├── scannet_sam.yaml │ └── scannet_seem.yaml ├── evaler │ └── instance_eval.yaml ├── experiment │ ├── a2pm_dmesa_egam_dkm_megadepth.yaml │ ├── a2pm_dmesa_egam_dkm_scannet.yaml │ ├── a2pm_dmesa_egam_loftr_megadepth.yaml │ ├── a2pm_dmesa_egam_loftr_scannet.yaml │ ├── a2pm_dmesa_egam_spsg_megadepth.yaml │ ├── a2pm_dmesa_egam_spsg_scannet.yaml │ ├── a2pm_mesa_egam_dkm_megadepth.yaml │ ├── a2pm_mesa_egam_dkm_scannet.yaml │ ├── a2pm_mesa_egam_loftr_megadepth.yaml │ ├── a2pm_mesa_egam_loftr_scannet.yaml │ ├── a2pm_mesa_egam_spsg_megadepth.yaml │ ├── a2pm_mesa_egam_spsg_scannet.yaml │ └── demo.yaml ├── geo_area_matcher │ ├── egam.yaml │ └── gam.yaml └── point_matcher │ ├── aspan_indoor.yaml │ ├── aspan_outdoor.yaml │ ├── dkm_indoor.yaml │ ├── dkm_outdoor.yaml │ ├── loftr_indoor.yaml │ ├── loftr_outdoor.yaml │ ├── spsg_indoor.yaml │ └── spsg_outdoor.yaml ├── dataloader ├── __init__.py ├── abstract_dataloader.py ├── demo_pair_loader.py ├── megadepth.py └── scannet.py ├── demo ├── color │ ├── 4119.965344.png │ └── 4120.813199.png ├── intrins │ ├── 4119.965344.txt │ └── 4120.813199.txt └── samres │ ├── 4119.965344.npy │ └── 4120.813199.npy ├── geo_area_matchers ├── MatchSampler.py ├── abstract_gam.py ├── egam.py └── gam.py ├── metric ├── Evaluation.py ├── eval_ratios.py └── instance_eval.py ├── point_matchers ├── __init__.py ├── abstract_point_matcher.py ├── aspanformer.py ├── dkm.py ├── loftr.py └── spsg.py ├── requirements.txt ├── scripts ├── demo.py ├── 
dmesa-dkm-md.sh ├── dmesa-dkm-sn.sh ├── megadepth_1500_pairs.txt ├── mesa-f-dkm-md.sh ├── mesa-f-dkm-sn.sh ├── outputs │ └── 2024-11-03 │ │ └── 15-43-14 │ │ ├── .hydra │ │ ├── config.yaml │ │ ├── hydra.yaml │ │ └── overrides.yaml │ │ └── demo.log ├── scannet_pairs.txt ├── test_a2pm.py └── test_in_dev.sh ├── segmentor ├── ImgSAMSeg.py ├── SAMSeger.py ├── __init__.py ├── sam_seg.sh ├── seg_res │ └── seg_res_dogs.jpg.npy └── seg_utils.py └── utils ├── common.py ├── geo.py ├── img_process.py ├── load.py ├── transformer.py └── vis.py /.gitignore: -------------------------------------------------------------------------------- 1 | ./res 2 | ./scripts/outputs 3 | assets/.DS_Store 4 | .DS_Store 5 | 6 | *.pyc 7 | *.log 8 | 9 | **__pycache__/* -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "point_matchers/ASpanFormer"] 2 | path = point_matchers/ASpanFormer 3 | url = git@github.com:Easonyesheng/ml-aspanformer.git 4 | [submodule "point_matchers/DKM"] 5 | path = point_matchers/DKM 6 | url = git@github.com:Easonyesheng/DKM.git 7 | [submodule "point_matchers/LoFTR"] 8 | path = point_matchers/LoFTR 9 | url = git@github.com:Easonyesheng/LoFTR.git 10 | [submodule "point_matchers/SuperGluePretrainedNetwork"] 11 | path = point_matchers/SuperGluePretrainedNetwork 12 | url = git@github.com:Easonyesheng/SuperGluePretrainedNetwork.git 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Eason Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, 
distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /area_matchers/AGConfig.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2023-06-29 10:32:53 4 | LastEditors: EasonZhang 5 | LastEditTime: 2023-12-27 11:28:36 6 | FilePath: /A2PM/configs/AGConfig.py 7 | Description: config for area graph 8 | 9 | Copyright (c) 2023 by EasonZhang, All Rights Reserved. 
10 | ''' 11 | 12 | 13 | 14 | preprocess_configs = { 15 | "min_area_size": 6000, 16 | "W": 640, 17 | "H": 480, 18 | "save_path": "/data0/zys/A2PM/testAG/graphResComplete", 19 | "max_wh_ratio": 4.0, 20 | "max_size_ratio_thd": 0.5, 21 | "tiny_area_size": 900, 22 | 'topK': 3, 23 | 'min_dist_thd': 100, 24 | 'seg_source': "SAM", # "SAM" or "Sem" 25 | } 26 | 27 | areagraph_configs = { 28 | 'preprocesser_config': preprocess_configs, 29 | 'sam_res_path': "/data0/zys/A2PM/testAG/res/SAMRes.npy", 30 | 'sem_res_path': "/data0/zys/A2PM/testAG/res/semRes.png", 31 | 'W': 640, 32 | 'H': 480, 33 | 'save_path': "/data0/zys/A2PM/testAG/graphResComplete", 34 | 'ori_img_path': "/data0/zys/A2PM/data/ScanData/scene0000_00/color/12.jpg", 35 | 'fs_overlap_thd': 0.8, 36 | 'level_num': 4, 37 | 'level_step': [560, 480, 256, 130, 0], 38 | 'show_flag': 0, 39 | } 40 | -------------------------------------------------------------------------------- /area_matchers/AGUtils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-19 23:09:10 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-07-19 21:02:10 6 | FilePath: /SA2M/hydra-mesa/area_matchers/AGUtils.py 7 | Description: TBD 8 | 9 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 
import sys
sys.path.append("..")

import numpy as np
import cv2
import os.path as osp
from sklearn.cluster import KMeans
import networkx as nx
import matplotlib.pyplot as plt
import copy
from collections import defaultdict
from loguru import logger
import maxflow

from utils.common import test_dir_if_not_create


class GraphCutSolver(object):
    """Binary s-t min-cut solver built on PyMaxflow."""

    def __init__(self) -> None:
        pass

    def solve(self, E_graph):
        """Run max-flow / min-cut on an energy graph.

        Args:
            E_graph: np.ndarray of shape (N+1, N+1).
                - E_graph[i, j] (i, j < N): capacity used for the edge between
                  node i and node j, passed as both forward and reverse
                  capacity; -1 means "no edge".
                - E_graph[-1, i]: capacity of node i's t-link to the source.
                - E_graph[i, -1]: capacity of node i's t-link to the sink.
        Returns:
            list[int]: indices of the nodes that end up in the source segment.
        """
        node_num = E_graph.shape[0] - 1  # last row / column hold the t-links
        g = maxflow.Graph[float](node_num, node_num)
        node_list = g.add_nodes(node_num)

        # pairwise links
        for i in range(node_num):
            for j in range(node_num):
                if E_graph[i, j] != -1:
                    g.add_edge(node_list[i], node_list[j], E_graph[i, j], E_graph[i, j])

        # terminal links: (capacity from source, capacity to sink)
        for i in range(node_num):
            g.add_tedge(node_list[i], E_graph[-1, i], E_graph[i, -1])

        g.maxflow()
        # get_grid_segments: False (0) -> source segment, True (1) -> sink segment
        labels = np.array(g.get_grid_segments(node_list)).astype(np.int32)
        source_nodes = np.where(labels == 0)[0]

        return source_nodes.tolist()


class KMCluster(object):
    """KMeans clustering of 2D points (e.g. area centers).

    The number of clusters is chosen from a list of candidates with the
    elbow method.
    """

    def __init__(self, save_path) -> None:
        """
        Args:
            save_path: directory where the optional diagnostic plots are saved.
        """
        self.center_list = []  # candidate cluster numbers, e.g. 1~N
        self.save_path = save_path

    def load_center_list(self, center_list):
        """Load the candidate cluster numbers to evaluate."""
        self.center_list = center_list

    def cluste_2d_points(self, points, show=False, name=""):
        """Cluster 2D points, choosing the cluster number with the elbow method.

        A KMeans model is fitted for every candidate in ``self.center_list``
        (stopping early once the inertia is ~0; candidates larger than the
        number of points are skipped). The chosen candidate maximises the
        second difference of the inertia curve,
        ``(I[i-1] - I[i]) - (I[i] - I[i+1])``: a large drop before it followed
        by a small drop after it.

        Args:
            points: np.ndarray of shape (M, 2), [[u, v], ...].
            show: if True, save elbow / cluster plots to ``self.save_path``.
            name: prefix of the plot file names.
        Returns:
            (best_cluster_num, best_labels), or (None, None) when no candidate
            cluster number is loaded or none can be fitted to the points.
        """
        if len(self.center_list) == 0:
            logger.error("center_list is empty, please load the center_list first")
            return None, None

        points = np.asarray(points)

        tried_nums = []  # candidate cluster numbers that were actually fitted
        inertia_list = []
        labels_list = []
        centers_list = []
        for cluster_num in self.center_list:
            if cluster_num > len(points):
                # KMeans requires at least as many samples as clusters
                continue
            kmeans = KMeans(n_clusters=cluster_num, random_state=0, n_init='auto').fit(points)
            tried_nums.append(cluster_num)
            inertia_list.append(kmeans.inertia_)
            labels_list.append(kmeans.labels_)
            centers_list.append(kmeans.cluster_centers_)
            if kmeans.inertia_ < 1e-6:
                break

        if len(tried_nums) == 0:
            logger.error("no candidate cluster number fits {} points".format(len(points)))
            return None, None

        # elbow method: second difference of the inertia curve
        front_back_diff = np.array([
            (inertia_list[i - 1] - inertia_list[i]) - (inertia_list[i] - inertia_list[i + 1])
            for i in range(1, len(inertia_list) - 1)
        ])
        if show:
            plt.plot(tried_nums[1:-1], front_back_diff, 'bx-')
            plt.savefig(osp.join(self.save_path, name + "_elbow_diff.png"))
            plt.close()

        if len(front_back_diff) == 0:
            best_cluster_num_idx = 0
        else:
            # front_back_diff[k] belongs to candidate k + 1
            best_cluster_num_idx = int(np.argmax(front_back_diff)) + 1
        best_cluster_num = tried_nums[best_cluster_num_idx]
        logger.debug("best_cluster_num: {}".format(best_cluster_num))

        if show:
            plt.plot(tried_nums, inertia_list, 'bx-')
            plt.savefig(osp.join(self.save_path, name + "_elbow.png"))
            plt.close()

        best_labels = labels_list[best_cluster_num_idx]

        if show:
            plt.scatter(points[:, 0], points[:, 1], c=best_labels, s=50, cmap='viridis')
            centers = centers_list[best_cluster_num_idx]
            plt.scatter(centers[:, 0], centers[:, 1], c='black', s=200, alpha=0.5)
            plt.savefig(osp.join(self.save_path, name + "_cluster.png"))
            plt.close()

        return best_cluster_num, best_labels


class AGViewer(object):
    """Visualisation helpers for the area graph."""

    def __init__(self, W, H, save_path) -> None:
        """
        Args:
            W, H: image width / height.
            save_path: directory where figures and images are written.
        """
        self.W = W
        self.H = H
        self.save_path = save_path

    def spring_layout_by_level(self, graph, AGNodes, level_num, node_dist=2):
        """Lay out the nodes in rows, one row per area level.

        Side effect: sets matplotlib's default figure size so the layout fits.

        Args:
            graph: the networkx graph being drawn (not read; positions are
                keyed by each node's index in ``AGNodes``).
            AGNodes: sequence of nodes exposing a ``.level`` attribute.
            level_num: deepest level laid out (levels 0..level_num).
            node_dist: horizontal spacing between neighbouring nodes.
        Returns:
            dict[int, np.ndarray]: node index -> (x, y) position.
        """
        fig_size_h = level_num * 3
        level_arr = np.array([node.level for node in AGNodes])

        # indices of the nodes on each level
        nodes_in_level = defaultdict(list)
        for lvl in range(level_num + 1):
            nodes_in_level[lvl] += np.where(level_arr == lvl)[0].tolist()

        max_node_num = max((len(nodes_in_level[lvl]) for lvl in range(level_num + 1)), default=0)
        fig_size_w = max_node_num * (node_dist + 2)

        plt.rcParams['figure.figsize'] = (int(fig_size_w), int(fig_size_h))

        pos = {}
        for lvl in range(level_num + 1):
            for j, node_idx in enumerate(nodes_in_level[lvl]):
                # odd levels are shifted right by one unit
                pos[int(node_idx)] = np.array([j * (node_dist + 2) + lvl % 2, fig_size_h - lvl - 1])

        return pos

    def draw_from_adjMat(self, adjMat, AGNodes, highlighted_idx_list=None, level_num=4, node_dist=2, name="", save=False):
        """Draw the area graph described by an adjacency matrix.

        Edges of weight 1 are drawn as undirected dashed blue (neighbour)
        edges, edges of weight 2 as directed green (father -> son) edges.

        Args:
            adjMat: (N, N) adjacency matrix whose entries are edge weights.
            AGNodes: nodes (with ``.level``) used for the layered layout.
            highlighted_idx_list: indices of nodes drawn in blue instead of
                red; None means no highlighting.
            level_num, node_dist: forwarded to ``spring_layout_by_level``.
            name: output file name without extension.
            save: if True, write ``<save_path>/<name>.png``.
        """
        highlighted = set([] if highlighted_idx_list is None else highlighted_idx_list)

        G = nx.from_numpy_array(adjMat, create_using=nx.DiGraph)
        pos = self.spring_layout_by_level(G, AGNodes, level_num=level_num, node_dist=node_dist)

        # split edges by type
        e_neibour = [(u, v) for (u, v, d) in G.edges(data=True) if d['weight'] == 1]
        e_father_son = [(u, v) for (u, v, d) in G.edges(data=True) if d['weight'] == 2]

        node_labels = {node: node for node in G.nodes()}
        highlight_nodes = [node for i, node in enumerate(G.nodes()) if i in highlighted]

        # nodes
        nx.draw_networkx_nodes(G, pos, node_size=3000, node_color='r')
        nx.draw_networkx_nodes(G, pos, nodelist=highlight_nodes, node_size=3000, node_color='b')
        # node labels
        nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=50)
        # neighbour edges: undirected
        nx.draw_networkx_edges(G, pos, edgelist=e_neibour, width=3, alpha=0.5, edge_color='b', style='dashed', arrows=False)
        # father-son edges: directed
        nx.draw_networkx_edges(G, pos, edgelist=e_father_son, width=3, alpha=0.5, edge_color='g', arrows=True)

        if save:
            test_dir_if_not_create(self.save_path)
            plt.savefig(osp.join(self.save_path, f"{name}.png"))

        plt.close()

    def draw_single_node_area_in_img(self, img_, area_node, color=(0, 255, 0), name="", save=False):
        """Draw one area's bounding box and index on a copy of an image.

        Args:
            img_: np.ndarray, shape (H, W, 3); not modified.
            area_node: AGNode with ``.area`` = [u_min, u_max, v_min, v_max]
                and ``.idx``.
            color: colour tuple passed to OpenCV (interpreted as BGR).
            name: output file name without extension.
            save: if True, write ``<save_path>/<name>.jpg``.
        Returns:
            np.ndarray: the annotated copy.
        """
        img = copy.deepcopy(img_)
        area = area_node.area
        idx = area_node.idx

        # area rectangle
        cv2.rectangle(img, (area[0], area[2]), (area[1], area[3]), color, 2)

        # area index at the rectangle centre
        cv2.putText(img, str(idx), ((area[0] + area[1]) // 2, (area[2] + area[3]) // 2), cv2.FONT_HERSHEY_SIMPLEX, 1.5, color, 3)

        if save:
            test_dir_if_not_create(self.save_path)
            cv2.imwrite(osp.join(self.save_path, f"{name}.jpg"), img)

        return img

    def draw_multi_nodes_areas_in_img(self, img, area_nodes, name="", save=False):
        """Draw several areas, each in a random colour, on a copy of an image.

        Args:
            img: np.ndarray, shape (H, W, 3); not modified.
            area_nodes: list of AGNode.
            name: output file name without extension.
            save: if True, write the final image to ``<save_path>/<name>.jpg``.
        Returns:
            np.ndarray: the annotated copy.
        """
        img_ = copy.deepcopy(img)
        for area_node in area_nodes:
            color = tuple(np.random.randint(0, 255, size=3).tolist())
            # intermediate images are not saved; only the final one is
            img_ = self.draw_single_node_area_in_img(img_, area_node, color, name=name, save=False)

        if save:
            test_dir_if_not_create(self.save_path)
            cv2.imwrite(osp.join(self.save_path, f"{name}.jpg"), img_)

        return img_


class MaskViewer(object):
    """Mask visualisation helpers."""

    def __init__(self, save_path) -> None:
        """
        Args:
            save_path: directory where the images are written.
        """
        self.save_path = save_path

    def draw_single_mask(self, mask, bbox, name):
        """Save a binary mask as an image with its bbox drawn in red.

        Args:
            mask: 2D boolean / 0-1 array.
            bbox: [u_min, u_max, v_min, v_max].
            name: output file name without extension (``.jpg`` is appended).
        """
        mask_show = mask.astype(np.uint8) * 255
        mask_show = cv2.cvtColor(mask_show, cv2.COLOR_GRAY2BGR)
        cv2.rectangle(mask_show, (bbox[0], bbox[2]), (bbox[1], bbox[3]), (0, 0, 255), 2)

        test_dir_if_not_create(self.save_path)
        cv2.imwrite(osp.join(self.save_path, f"{name}.jpg"), mask_show)

    def draw_multi_masks_in_one(self, area_info_list, W, H, name="", key="mask"):
        """Paint several masks, each in a distinct random colour, into one image.

        Args:
            area_info_list: list of dicts holding a mask under ``key`` (and
                ``"area_bbox"`` = [u_min, u_max, v_min, v_max] when
                ``key == "mask"``, in which case the bbox is drawn too).
            W, H: output image size; masks are resized to it.
            name: output file name without extension (``.png`` is appended).
            key: dict key of the mask to draw.
        """
        masks_show = np.zeros((H, W, 3), dtype=np.uint8)
        used_colors = []

        for area_info in area_info_list:
            mask = area_info[key].astype(np.uint8)
            mask = cv2.resize(mask, (W, H))

            # pick a colour not used yet
            color = np.random.randint(0, 255, size=3)
            while tuple(color.tolist()) in used_colors:
                color = np.random.randint(0, 255, size=3)
            masks_show[mask > 0] = color
            color = tuple(color.tolist())

            if key == "mask":
                bbox = [int(b) for b in area_info["area_bbox"]]
                cv2.rectangle(masks_show, (bbox[0], bbox[2]), (bbox[1], bbox[3]), color, 2)
            used_colors.append(color)

        test_dir_if_not_create(self.save_path)
        cv2.imwrite(osp.join(self.save_path, f"{name}.png"), masks_show)

# --------------------------------------------------------------------------------
# /area_matchers/CoarseAreaMatcher.py:
# --------------------------------------------------------------------------------
'''
Author: EasonZhang
Date: 2023-06-28 22:11:54
LastEditors: EasonZhang
LastEditTime: 2024-07-22 22:56:03
FilePath: /SA2M/hydra-mesa/area_matchers/CoarseAreaMatcher.py
Description: Input two sub-images, output inside coarse point matches using off-the-shelf point matcher.

Copyright (c) 2023 by EasonZhang, All Rights Reserved.
'''

import sys
sys.path.append("..")

import os
import os.path as osp
import numpy as np
import math
import cv2
from loguru import logger
# plt
import matplotlib.pyplot as plt


from utils.geo import tune_corrs_size, tune_mkps_size
from utils.common import test_dir_if_not_create


class CoarseAreaMatcher(object):
    """Match two area crops with an off-the-shelf point matcher and return
    coarse (patch-level) correspondences.
    """

    # matcher names handled by match_ret_activity
    _SUPPORTED_MATCHERS = ("ASpan", "LoFTR", "ASpanCD", "LoFTRCD")

    def __init__(self, configs=None) -> None:
        """
        Args:
            configs: dict with the keys matcher_name, datasetName, out_path,
                area_w, area_h, patch_size, conf_thd and pair_name.
        """
        configs = {} if configs is None else configs
        self.matcher_name = configs["matcher_name"]
        self.matcher = None  # created by init_matcher()
        self.datasetName = configs["datasetName"]
        self.out_path = configs["out_path"]
        self.area_w = configs["area_w"]
        self.area_h = configs["area_h"]
        self.patch_size = configs["patch_size"]
        self.conf_thd = configs["conf_thd"]
        self.pair_name = configs["pair_name"]

    def init_matcher(self):
        """Instantiate ``self.matcher`` according to matcher_name / datasetName.

        The "CD" (cross-domain) variants deliberately load the weights trained
        on the other domain (indoor <-> outdoor).

        Raises:
            NotImplementedError: unsupported matcher or dataset.
        """
        cur_path = osp.dirname(osp.abspath(__file__))

        if self.matcher_name == "ASpan":
            from point_matchers.aspanformer import ASpanMatcher
            logger.info("Initialize ASpan Matcher")
            if self.datasetName == "ScanNet" or self.datasetName == "KITTI" or self.datasetName == "ETH3D":
                weight_path = f"{cur_path}/../point_matchers/ASpanFormer/weights/indoor.ckpt"
            elif self.datasetName == "MegaDepth" or self.datasetName == "YFCC":
                weight_path = f"{cur_path}/../point_matchers/ASpanFormer/weights/outdoor.ckpt"
            else:
                raise NotImplementedError(f"Dataset {self.datasetName} not implemented yet!")

            aspan_configs = {
                "config_path": f"{cur_path}/../point_matchers/ASpanFormer/configs",
                "weights": weight_path,
                "dataset_name": self.datasetName,
            }
            self.matcher = ASpanMatcher(**aspan_configs)

        elif self.matcher_name == "ASpanCD":
            from point_matchers.aspanformer import ASpanMatcher
            logger.info("Initialize ASpan Matcher")
            if self.datasetName == "ScanNet" or self.datasetName == "KITTI":
                weight_path = f"{cur_path}/../point_matchers/ASpanFormer/weights/outdoor.ckpt"  # cross domain
            elif self.datasetName == "MegaDepth" or self.datasetName == "YFCC":
                weight_path = f"{cur_path}/../point_matchers/ASpanFormer/weights/indoor.ckpt"  # cross domain
            else:
                raise NotImplementedError(f"Dataset {self.datasetName} not implemented yet!")

            aspan_configs = {
                "config_path": f"{cur_path}/../point_matchers/ASpanFormer/configs",
                "weights": weight_path,
                "dataset_name": self.datasetName,
            }
            self.matcher = ASpanMatcher(**aspan_configs)

        # FIXME: finish loftr warpper
        elif self.matcher_name == "LoFTR":
            from point_matchers.loftr import LoFTRMatcher
            logger.info("Initialize LoFTR Matcher")
            if self.datasetName == "ScanNet" or self.datasetName == "KITTI" or self.datasetName == "ETH3D":
                weight_path = f"{cur_path}/../point_matchers/LoFTR/weights/indoor_ds_new.ckpt"
            elif self.datasetName == "MegaDepth":
                weight_path = f"{cur_path}/../point_matchers/LoFTR/weights/outdoor_ds.ckpt"
            else:
                raise NotImplementedError(f"Dataset {self.datasetName} not implemented yet!")

            loftr_configs = {
                "weights": weight_path,
                "cuda_idx": 0,
            }
            self.matcher = LoFTRMatcher(loftr_configs, self.datasetName, mode="tool")

        elif self.matcher_name == "LoFTRCD":
            from point_matchers.loftr import LoFTRMatcher
            logger.info("Initialize LoFTR Matcher")
            if self.datasetName == "ScanNet" or self.datasetName == "KITTI":
                weight_path = f"{cur_path}/../point_matchers/LoFTR/weights/outdoor_ds.ckpt"  # cross domain
            elif self.datasetName == "MegaDepth":
                weight_path = f"{cur_path}/../point_matchers/LoFTR/weights/indoor_ds_new.ckpt"  # cross domain
            else:
                raise NotImplementedError(f"Dataset {self.datasetName} not implemented yet!")

            loftr_configs = {
                "weights": weight_path,
                "cuda_idx": 0,
            }
            # previously the config was built but the matcher never created
            self.matcher = LoFTRMatcher(loftr_configs, self.datasetName, mode="tool")

        else:
            raise NotImplementedError(f"Matcher {self.matcher_name} not implemented yet!")

    def match(self, area0, area1, resize_flag=True):
        """Match two area images and return coarse matches in each area's frame.

        Args:
            area0, area1: (H, W, 3) area images.
            resize_flag: if True, both areas are resized to (area_w, area_h)
                before matching.
        Returns:
            (mkpts0_c, mkpts1_c, mconf, conf_mat) as np.ndarrays, with
            mkpts*_c of shape (N, 2); None if the matcher raised.
        """
        assert self.matcher is not None, "Matcher not initialized yet!"
        area0_h, area0_w, _ = area0.shape
        area1_h, area1_w, _ = area1.shape

        if resize_flag:
            area0 = cv2.resize(area0, (self.area_w, self.area_h))
            area1 = cv2.resize(area1, (self.area_w, self.area_h))

        logger.info(f"match areas with size: {area0.shape}, {area1.shape}")

        try:
            ret = self.matcher.get_coarse_mkpts_c(area0, area1)
        except Exception as e:
            logger.exception(e)
            return None

        mkpts0_c, mkpts1_c, mconf, conf_mat = ret
        # move to cpu / numpy
        mkpts0_c = mkpts0_c.cpu().numpy()
        mkpts1_c = mkpts1_c.cpu().numpy()
        mconf = mconf.cpu().numpy()
        conf_mat = conf_mat.cpu().numpy()

        # map from the (area_w, area_h) matching frame back to each area's size
        # NOTE(review): this rescale is also applied when resize_flag is False;
        # confirm the matcher output is in the (area_w, area_h) frame in that case.
        mkpts0_c = np.array(tune_mkps_size(mkpts0_c, self.area_w, self.area_h, area0_w, area0_h))
        mkpts1_c = np.array(tune_mkps_size(mkpts1_c, self.area_w, self.area_h, area1_w, area1_h))

        return mkpts0_c, mkpts1_c, mconf, conf_mat

    def match_ret_activity(self, area0, area1, sigma_thd=0.1, draw_match=False, name=""):
        """Match two areas and return how much of each is covered by confident matches.

        Both areas are resized to (area_w, area_h). Matches whose confidence
        is not above ``conf_thd`` are dropped. The activity of each area is
        the fraction of its pixels covered by the matched patches (see
        ``calc_activity_by_occ``). Activities below ``sigma_thd`` are set to 0.

        Args:
            area0, area1: area images.
            sigma_thd: activities below this value are zeroed.
            draw_match: if True, save a visualisation of the kept matches.
            name: prefix of the visualisation file name.
        Returns:
            (sigma0, sigma1), or None if the matcher raised.
        Raises:
            NotImplementedError: unsupported matcher name.
        """
        conf_thd = self.conf_thd
        assert self.matcher is not None, "Matcher not initialized yet!"

        area0 = cv2.resize(area0, (self.area_w, self.area_h))
        area1 = cv2.resize(area1, (self.area_w, self.area_h))

        if self.matcher_name not in self._SUPPORTED_MATCHERS:
            raise NotImplementedError(f"Matcher {self.matcher_name} not implemented yet!")

        try:
            ret = self.matcher.get_coarse_mkpts_c(area0, area1)
        except Exception as e:
            logger.exception(e)
            return None

        mkpts0_c, mkpts1_c, mconf, conf_mat = ret
        # move to cpu / numpy
        mkpts0_c = mkpts0_c.cpu().numpy()
        mkpts1_c = mkpts1_c.cpu().numpy()
        mconf = mconf.cpu().numpy()
        conf_mat = conf_mat.cpu().numpy()

        # keep confident matches; filter mconf too so it stays aligned with the keypoints
        keep = mconf > conf_thd
        mkpts0_c = mkpts0_c[keep]
        mkpts1_c = mkpts1_c[keep]
        mconf = mconf[keep]

        sigma0 = self.calc_activity_by_occ(mkpts0_c, self.area_w, self.area_h)
        sigma1 = self.calc_activity_by_occ(mkpts1_c, self.area_w, self.area_h)
        sigma0 = 0 if sigma0 < sigma_thd else sigma0
        sigma1 = 0 if sigma1 < sigma_thd else sigma1

        if draw_match:
            self.visulization(area0, area1, mkpts0_c, mkpts1_c, mconf, conf_mat, name=name + f"_{sigma0:.2f}_{sigma1:.2f}")

        return sigma0, sigma1

    def calc_activity_by_occ(self, mkpts, area_w, area_h):
        """Return the fraction of an area covered by the matched patches.

        Each keypoint is the centre of a square spanning ``patch_size`` pixels
        on each side of it (the centre is clamped so the square stays inside
        the area). ``mkpts`` itself is not modified.

        NOTE(review): ``visulization`` draws squares with half-width
        ``patch_size // 2``, whereas this uses ``patch_size``; confirm which
        is intended.

        Args:
            mkpts: (N, 2) keypoints in the (area_w, area_h) frame.
            area_w, area_h: area size.
        Returns:
            float: covered pixels / (area_w * area_h).
        """
        patch_r = self.patch_size
        occ_map = np.zeros((area_h, area_w), dtype=np.uint8)
        # clamp on a copy so the caller's keypoints are left untouched
        pts = np.array(mkpts, copy=True)
        for i in range(pts.shape[0]):
            u = int(min(max(pts[i][0], patch_r), area_w - patch_r))
            v = int(min(max(pts[i][1], patch_r), area_h - patch_r))
            occ_map[v - patch_r:v + patch_r, u - patch_r:u + patch_r] = 1

        occ_num = np.sum(occ_map)
        occ_ratio = occ_num / (area_w * area_h)
        return occ_ratio

    def visulization(self, area0, area1, mkpts0_c, mkpts1_c, mconf, conf_mat, name=""):
        """Save the two areas side by side with their matched patches overlaid.

        Each match with confidence >= conf_thd is drawn as a pair of filled
        squares of half-width ``patch_size // 2``, in the same random colour,
        blended over the images. The image goes to
        ``<out_path>/coarse_patch_matches_<pair_name>/<name>_patch_matches.png``.

        Args:
            area0, area1: area images (gray or 3-channel).
            mkpts0_c, mkpts1_c: (N, 2) matched patch centres in the frame of
                area0 / area1 respectively.
            mconf: (N,) match confidences aligned with the keypoints.
            conf_mat: confidence matrix (not used for drawing).
            name: file-name prefix; may contain sub-directories.
        """
        # gray -> 3 channels
        if len(area0.shape) == 2:
            area0 = cv2.cvtColor(area0, cv2.COLOR_GRAY2RGB)
        if len(area1.shape) == 2:
            area1 = cv2.cvtColor(area1, cv2.COLOR_GRAY2RGB)

        # bring keypoints into the (area_w, area_h) drawing frame
        mkpts0_c = tune_mkps_size(mkpts0_c, area0.shape[1], area0.shape[0], self.area_w, self.area_h)
        mkpts1_c = tune_mkps_size(mkpts1_c, area1.shape[1], area1.shape[0], self.area_w, self.area_h)

        area0 = cv2.resize(area0, (self.area_w, self.area_h))
        area1 = cv2.resize(area1, (self.area_w, self.area_h))

        patch_radius = self.patch_size // 2

        img_out = np.zeros((self.area_h, self.area_w * 2, 3), dtype=np.uint8)
        img_out_rect = np.zeros((self.area_h, self.area_w * 2, 3), dtype=np.uint8)
        img_out[:, :self.area_w, :] = area0
        img_out[:, self.area_w:, :] = area1

        for pt0, pt1, conf in zip(mkpts0_c, mkpts1_c, mconf):
            if conf < self.conf_thd:
                continue

            color = (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255))

            # clamp so the whole square stays inside its area; area1 is drawn on the right half
            u0 = int(min(max(pt0[0], patch_radius), self.area_w - patch_radius))
            v0 = int(min(max(pt0[1], patch_radius), self.area_h - patch_radius))
            u1 = int(min(max(pt1[0], patch_radius), self.area_w - patch_radius)) + self.area_w
            v1 = int(min(max(pt1[1], patch_radius), self.area_h - patch_radius))

            cv2.rectangle(img_out_rect, (u0 - patch_radius, v0 - patch_radius), (u0 + patch_radius, v0 + patch_radius), color, -1)
            cv2.rectangle(img_out_rect, (u1 - patch_radius, v1 - patch_radius), (u1 + patch_radius, v1 + patch_radius), color, -1)

        # blend the patch overlay over the images
        img_out = cv2.addWeighted(img_out, 1.0, img_out_rect, 0.5, 1)

        if name == "":
            name_img = f"{self.matcher_name}_patch_matches.png"
        else:
            name_img = f"{name}_patch_matches.png"

        out_folder = os.path.join(self.out_path, f"coarse_patch_matches_{self.pair_name}")
        # name may contain sub-directories: create them too
        test_dir_if_not_create(out_folder + "/" + "/".join(name_img.split("/")[:-1]))

        name_img = os.path.join(out_folder, name_img)

        logger.info(f"save patch matches image to {name_img}")
        cv2.imwrite(name_img, img_out)

# --------------------------------------------------------------------------------
# /area_matchers/abstract_am.py:
# --------------------------------------------------------------------------------
'''
Author: EasonZhang
Date: 2024-06-19 22:43:14
LastEditors: EasonZhang
LastEditTime: 2024-06-28 11:30:23
FilePath: /SA2M/hydra-mesa/area_matchers/abstract_am.py
Description: abstract area matcher for pre-processing

Copyright (c) 2024 by EasonZhang, All Rights Reserved.
'''

import os
import numpy as np
import abc

from utils.common import test_dir_if_not_create

class AbstractAreaMatcher(abc.ABC):
    """Interface shared by the area matchers."""

    def __init__(self) -> None:
        pass

    @abc.abstractmethod
    def name(self) -> str:
        """Return the matcher's display name."""
        raise NotImplementedError

    @abc.abstractmethod
    def init_dataloader(self, dataloader):
        """Bind the matcher to a dataloader providing the image pair."""
        raise NotImplementedError

    @abc.abstractmethod
    def area_matching(self):
        """ Main Func
        Returns:
            area_match_src: list of [u_min, u_max, v_min, v_max] in src img
            area_match_dst: list of [u_min, u_max, v_min, v_max] in dst img
        """
        raise NotImplementedError

    def set_outpath(self, outpath: str):
        """Set ``self.out_path`` to ``<outpath>/<scene_name>_<name0>_<name1>/am``.

        Run after init_dataloader: reads ``scene_name``, ``name0``, ``name1``
        and ``draw_verbose`` from the instance. The directory is created only
        when ``draw_verbose == 1``.
        """
        self.out_path = os.path.join(outpath, f"{self.scene_name}_{self.name0}_{self.name1}", 'am')
        if self.draw_verbose == 1:
            test_dir_if_not_create(self.out_path)

# --------------------------------------------------------------------------------
# /area_matchers/dmesa.py:
# --------------------------------------------------------------------------------


import os
import os.path as osp
import numpy as np
import copy 7 | import torch 8 | import cv2 9 | from loguru import logger 10 | from collections import defaultdict 11 | from copy import deepcopy 12 | from tqdm import tqdm 13 | import time 14 | 15 | from .AreaGrapher import AreaGraph 16 | from .AGConfig import areagraph_configs 17 | from .AGUtils import GraphCutSolver 18 | from .CoarseAreaMatcher import CoarseAreaMatcher 19 | from .AGBasic import AGNode 20 | from utils.vis import draw_matched_area, draw_matched_area_list, draw_matched_area_with_mkpts 21 | from utils.common import test_dir_if_not_create, validate_type 22 | from utils.geo import calc_areas_iou 23 | 24 | from dataloader.abstract_dataloader import AbstractDataloader 25 | from .abstract_am import AbstractAreaMatcher 26 | 27 | 28 | class DMesaAreaMatcher(AbstractAreaMatcher): 29 | """ DMESAAreaMatcher 30 | """ 31 | def __init__(self, 32 | W, 33 | H, 34 | coarse_matcher_name, 35 | level_num, 36 | level_step, 37 | stop_match_level, 38 | area_crop_mode, 39 | patch_size_ratio, 40 | valid_gaussian_width, 41 | source_area_selection_mode, 42 | iou_fusion_thd, 43 | patch_match_num_thd, 44 | match_mode, 45 | coarse_match_all_in_one, 46 | dual_match, 47 | datasetName, 48 | step_gmm=None, 49 | draw_verbose=0): 50 | """ 51 | """ 52 | self.coarse_matcher_name = coarse_matcher_name 53 | self.level_num = level_num 54 | self.level_step = level_step 55 | self.area_crop_mode = area_crop_mode 56 | self.patch_size_ratio = patch_size_ratio 57 | self.valid_gaussian_width = valid_gaussian_width 58 | self.source_area_selection_mode = source_area_selection_mode 59 | self.iou_fusion_thd = iou_fusion_thd 60 | self.patch_match_num_thd = patch_match_num_thd 61 | self.match_mode = match_mode 62 | if self.match_mode == "EM": 63 | self.step_gmm = step_gmm 64 | self.coarse_match_all_in_one = coarse_match_all_in_one 65 | self.dual_match = dual_match 66 | self.stop_match_level = stop_match_level 67 | self.W, self.H = W, H 68 | self.datasetName = datasetName 69 | self.draw_verbose = 
draw_verbose 70 | 71 | 72 | def init_dataloader(self, dataloader): 73 | """ 74 | """ 75 | validate_type(dataloader, AbstractDataloader) 76 | self.sem_path0, self.sem_path1 = dataloader.get_sem_paths() 77 | 78 | self.name0 = dataloader.image_name0 79 | self.name1 = dataloader.image_name1 80 | 81 | self.image0_path, self.image1_path = dataloader.img0_path, dataloader.img1_path 82 | 83 | self.scene_name = dataloader.scene_name 84 | 85 | self.img0, self.img1, self.scale0, self.scale1 = dataloader.load_images(self.W, self.H) 86 | 87 | def name(self): 88 | return "MesaAreaMatcher-TrainingFree" 89 | 90 | def init_area_matcher(self): 91 | """ 92 | """ 93 | AM_config = { 94 | "matcher_name": self.coarse_matcher_name, 95 | "datasetName": self.datasetName, 96 | "out_path": self.out_path, 97 | "level_num": self.level_num, 98 | "level_step": self.level_step, 99 | "stop_match_level": self.stop_match_level, 100 | "W": self.W, # original image size, NOTE also the match size 101 | "H": self.H, 102 | "area_crop_mode": self.area_crop_mode, # first expand to square, then padding 103 | "patch_size_ratio": self.patch_size_ratio, 104 | "valid_gaussian_width": self.valid_gaussian_width, 105 | "show_flag": self.draw_verbose, 106 | "source_area_selection_mode": self.source_area_selection_mode, 107 | "iou_fusion_thd": self.iou_fusion_thd, # iou threshold for repeat area identification 108 | "patch_match_num_thd": self.patch_match_num_thd, # threshold for patch match number 109 | "match_mode": self.match_mode, 110 | "coarse_match_all_in_one": self.coarse_match_all_in_one, 111 | "dual_match": self.dual_match, 112 | } 113 | 114 | if self.match_mode == "EM": 115 | AM_config.update({"step_gmm": self.step_gmm}) 116 | 117 | from .AreaMatchDense import AGMatcherDense 118 | self.area_matcher = AGMatcherDense(configs=AM_config) 119 | 120 | def area_matching(self, dataloader, out_path): 121 | """ 122 | """ 123 | logger.info(f"start area matching") 124 | 125 | self.init_dataloader(dataloader) 126 | 
self.set_outpath(out_path) 127 | 128 | self.init_area_matcher() 129 | self.area_matcher.path_loader(self.image0_path, self.sem_path0, self.image1_path, self.sem_path1, self.name0, self.name1) 130 | 131 | self.area_matcher.init_area_matcher() 132 | self.area_matcher.img_areagraph_construct(efficient=True) 133 | 134 | area_match_srcs, area_match_dsts = self.area_matcher.dense_area_matching_dual() 135 | 136 | 137 | self.area_match_srcs = area_match_srcs 138 | self.area_match_dsts = area_match_dsts 139 | 140 | if self.draw_verbose: 141 | flag = draw_matched_area_list(self.img0, self.img1, area_match_srcs, area_match_dsts, self.out_path, self.name0, self.name1, self.draw_verbose) 142 | if not flag: 143 | logger.critical(f"Something wrong with area matching, please check the code for {self.out_path.split('/')[-1]}") 144 | 145 | # draw each area's match 146 | for i, src_area in enumerate(area_match_srcs): 147 | dst_area = area_match_dsts[i] 148 | draw_matched_area(self.img0, self.img1, src_area, dst_area, (0,255,0), self.out_path, f"{i}_"+self.name0, self.name1, self.draw_verbose) 149 | 150 | logger.info(f"finish area matching") 151 | 152 | 153 | return area_match_srcs, area_match_dsts 154 | 155 | def area_matching_rt_time(self, dataloader, out_path): 156 | """ 157 | """ 158 | logger.info(f"start area matching") 159 | 160 | self.init_dataloader(dataloader) 161 | self.set_outpath(out_path) 162 | 163 | self.init_area_matcher() 164 | self.area_matcher.path_loader(self.image0_path, self.sem_path0, self.image1_path, self.sem_path1, self.name0, self.name1) 165 | 166 | self.area_matcher.init_area_matcher() 167 | times0 = cv2.getTickCount() 168 | self.area_matcher.img_areagraph_construct(efficient=True) 169 | 170 | area_match_srcs, area_match_dsts = self.area_matcher.dense_area_matching_dual() 171 | times1 = cv2.getTickCount() 172 | time_match = (times1 - times0) / cv2.getTickFrequency() 173 | 174 | 175 | self.area_match_srcs = area_match_srcs 176 | self.area_match_dsts = 
area_match_dsts 177 | 178 | if self.draw_verbose: 179 | flag = draw_matched_area_list(self.img0, self.img1, area_match_srcs, area_match_dsts, self.out_path, self.name0, self.name1, self.draw_verbose) 180 | if not flag: 181 | logger.critical(f"Something wrong with area matching, please check the code for {self.out_path.split('/')[-1]}") 182 | 183 | # draw each area's match 184 | for i, src_area in enumerate(area_match_srcs): 185 | dst_area = area_match_dsts[i] 186 | draw_matched_area(self.img0, self.img1, src_area, dst_area, (0,255,0), self.out_path, f"{i}_"+self.name0, self.name1, self.draw_verbose) 187 | 188 | logger.info(f"finish area matching") 189 | 190 | return area_match_srcs, area_match_dsts, time_match -------------------------------------------------------------------------------- /area_matchers/mesa.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-20 11:42:48 4 | LastEditors: Easonyesheng preacher@sjtu.edu.cn 5 | LastEditTime: 2024-07-30 15:35:34 6 | FilePath: /SA2M/hydra-mesa/area_matchers/mesa.py 7 | Description: traning-free version mesa 8 | 9 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 
10 | ''' 11 | 12 | import os 13 | import os.path as osp 14 | import numpy as np 15 | import copy 16 | import torch 17 | import cv2 18 | from loguru import logger 19 | from collections import defaultdict 20 | from copy import deepcopy 21 | from tqdm import tqdm 22 | 23 | from .AreaGrapher import AreaGraph 24 | from .AGConfig import areagraph_configs 25 | from .AGUtils import GraphCutSolver 26 | from .CoarseAreaMatcher import CoarseAreaMatcher 27 | from .AGBasic import AGNode 28 | from utils.vis import draw_matched_area, draw_matched_area_list, draw_matched_area_with_mkpts 29 | from utils.common import test_dir_if_not_create, validate_type 30 | from utils.geo import calc_areas_iou 31 | 32 | from dataloader.abstract_dataloader import AbstractDataloader 33 | from .abstract_am import AbstractAreaMatcher 34 | 35 | 36 | class MesaAreaMatcher(AbstractAreaMatcher): 37 | """ MESAAreaMatcher, NOTE this is a training free version 38 | """ 39 | def __init__(self, 40 | W, 41 | H, 42 | coarse_matcher_name, 43 | level_num, 44 | level_step, 45 | adj_weight, 46 | stop_match_level, 47 | coarse_match_thd, 48 | patch_size, 49 | similar_area_dist_thd, 50 | area_w, 51 | area_h, 52 | sigma_thd, 53 | global_energy_weights, 54 | iou_fusion_thd, 55 | candi_energy_thd, 56 | global_refine, 57 | global_energy_candy_range, 58 | fast_version, 59 | energy_norm_way, 60 | datasetName, 61 | draw_verbose=0): 62 | """ 63 | """ 64 | self.coarse_matcher_name = coarse_matcher_name 65 | self.level_num = level_num 66 | self.level_step = level_step 67 | self.adj_weight = adj_weight 68 | self.stop_match_level = stop_match_level 69 | self.coarse_match_thd = coarse_match_thd 70 | self.patch_size = patch_size 71 | self.similar_area_dist_thd = similar_area_dist_thd 72 | self.area_w = area_w 73 | self.area_h = area_h 74 | self.sigma_thd = sigma_thd 75 | self.global_energy_weights = global_energy_weights 76 | self.iou_fusion_thd = iou_fusion_thd 77 | self.candi_energy_thd = candi_energy_thd 78 | self.global_refine 
= global_refine 79 | self.global_energy_candy_range = global_energy_candy_range 80 | self.fast_version = fast_version 81 | self.energy_norm_way = energy_norm_way 82 | 83 | self.W, self.H = W, H 84 | self.datasetName = datasetName 85 | self.draw_verbose = draw_verbose 86 | 87 | 88 | def init_dataloader(self, dataloader): 89 | """ 90 | """ 91 | validate_type(dataloader, AbstractDataloader) 92 | self.sem_path0, self.sem_path1 = dataloader.get_sem_paths() 93 | 94 | self.name0 = dataloader.image_name0 95 | self.name1 = dataloader.image_name1 96 | 97 | self.image0_path, self.image1_path = dataloader.img0_path, dataloader.img1_path 98 | 99 | self.scene_name = dataloader.scene_name 100 | 101 | self.img0, self.img1, self.scale0, self.scale1 = dataloader.load_images(self.W, self.H) 102 | 103 | def name(self): 104 | return "MesaAreaMatcher-TrainingFree" 105 | 106 | def init_area_matcher(self): 107 | """ 108 | """ 109 | AM_config = { 110 | "matcher_name": self.coarse_matcher_name, 111 | "datasetName": self.datasetName, 112 | "out_path": self.out_path, 113 | "level_num": self.level_num, 114 | "level_step": self.level_step, 115 | "adj_weight": self.adj_weight, 116 | "stop_match_level": self.stop_match_level, 117 | "W": self.W, 118 | "H": self.H, 119 | "coarse_match_thd": self.coarse_match_thd, 120 | "patch_size": self.patch_size, 121 | "similar_area_dist_thd": self.similar_area_dist_thd, 122 | "area_w": self.area_w, 123 | "area_h": self.area_h, 124 | "show_flag": self.draw_verbose, 125 | "sigma_thd": self.sigma_thd, 126 | "global_energy_weights": self.global_energy_weights, 127 | "iou_fusion_thd": self.iou_fusion_thd, 128 | "candi_energy_thd": self.candi_energy_thd, 129 | "global_refine": self.global_refine, 130 | "global_energy_candy_range": self.global_energy_candy_range, 131 | "fast_version": self.fast_version, 132 | "energy_norm_way": "minmax", 133 | } 134 | 135 | from .AGMatcherFree import AGMatcherF 136 | self.area_matcher = AGMatcherF(configs=AM_config) 137 | 138 | def 
area_matching(self, dataloader, out_path): 139 | """ 140 | """ 141 | logger.info(f"start area matching") 142 | 143 | self.init_dataloader(dataloader) 144 | self.set_outpath(out_path) 145 | 146 | self.init_area_matcher() 147 | self.area_matcher.path_loader(self.image0_path, self.sem_path0, self.image1_path, self.sem_path1, self.name0, self.name1) 148 | self.area_matcher.init_area_matcher() 149 | self.area_matcher.img_areagraph_construct(efficient=True) 150 | 151 | # debug - end the total process for only AG construction 152 | # import sys 153 | # sys.exit(0) 154 | 155 | area_match_srcs, area_match_dsts = self.area_matcher.dual_graphical_match(self.draw_verbose) 156 | 157 | 158 | self.area_match_srcs = area_match_srcs 159 | self.area_match_dsts = area_match_dsts 160 | 161 | if self.draw_verbose: 162 | flag = draw_matched_area_list(self.img0, self.img1, area_match_srcs, area_match_dsts, self.out_path, self.name0, self.name1, self.draw_verbose) 163 | if not flag: 164 | logger.critical(f"Something wrong with area matching, please check the code for {self.out_path.split('/')[-1]}") 165 | 166 | # draw each area's match 167 | for i, src_area in enumerate(area_match_srcs): 168 | dst_area = area_match_dsts[i] 169 | draw_matched_area(self.img0, self.img1, src_area, dst_area, (0,255,0), self.out_path, f"{i}_"+self.name0, self.name1, self.draw_verbose) 170 | 171 | logger.success(f"finish area matching") 172 | 173 | 174 | return area_match_srcs, area_match_dsts 175 | -------------------------------------------------------------------------------- /assets/A2PM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Easonyesheng/A2PM-MESA/920393b317529b6f3feab2152de5dc61102a6e78/assets/A2PM.png -------------------------------------------------------------------------------- /assets/Qua.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Easonyesheng/A2PM-MESA/920393b317529b6f3feab2152de5dc61102a6e78/assets/Qua.png -------------------------------------------------------------------------------- /assets/mesa-ava.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Easonyesheng/A2PM-MESA/920393b317529b6f3feab2152de5dc61102a6e78/assets/mesa-ava.png -------------------------------------------------------------------------------- /assets/mesa-main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Easonyesheng/A2PM-MESA/920393b317529b6f3feab2152de5dc61102a6e78/assets/mesa-main.png -------------------------------------------------------------------------------- /assets/run_MESA_on_win11.md: -------------------------------------------------------------------------------- 1 | ## Windows下运行A2PM项目 2 | 3 | ## 一 进行Segmentation Preprocessing 4 | 5 | **1.** 6 | 下载[SAM](https://github.com/facebookresearch/segment-anything)粘贴到A2PM-MESA目录下; 7 | 8 | Download [SAM](https://github.com/facebookresearch/segment-anything) and paste it into the `A2PM-MESA` directory. 
9 | 10 | **2.** 11 | 下载[SAM2](https://github.com/facebookresearch/sam2),将其中的sam2粘贴到A2PM-MESA\segmentor目录下(别问为什么,SAM2项目路径内部问题,源码内部有整段解释为何会有包错误); 12 | 13 | Download [SAM2](https://github.com/facebookresearch/sam2) and copy its inner `sam2` folder into the `A2PM-MESA\segmentor` directory (this works around a package-path issue in the SAM2 project; the SAM2 source code contains a passage explaining why the package error occurs); 14 | 15 | **3.** 16 | 修改`SAMSeger.py`文件内 17 | Modify the following imports in `SAMSeger.py` 18 | 19 | ``` 20 | from SAM.segment_anything import sam_model_registry, SamAutomaticMaskGenerator 21 | from SAM2.sam2.build_sam import build_sam2 22 | from SAM2.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator 23 | ``` 24 | 25 | 改为 26 | to 27 | 28 | ``` 29 | from SAM.segment_anything import sam_model_registry, SamAutomaticMaskGenerator 30 | from sam2.build_sam import build_sam2 31 | from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator 32 | ``` 33 | 34 | **4.** 35 | 下载`sam_vit_h_4b8939.pth`模型文件放在`A2PM-MESA\segmentor`目录下,同时修改该目录下的`ImgSAMSeg.py`文件,将`"sam_model_path": f"{current_path}/SAM/sam_vit_h_4b8939.pth",`替换为你下载的模型地址 36 | 37 | Download the `sam_vit_h_4b8939.pth` model file and place it in the `A2PM-MESA\segmentor` directory. Additionally, modify the `ImgSAMSeg.py` file in the same directory.
38 | 39 | ```python 40 | SAM_configs = { 41 | "SAM_name": "SAM", 42 | "W": 640, 43 | "H": 480, 44 | "sam_model_type": "vit_h", 45 | "sam_model_path": f"{current_path}/SAM/sam_vit_h_4b8939.pth", 46 | "save_folder": "", 47 | "points_per_side": 16, 48 | } 49 | ``` 50 | 51 | Replace the `sam_model_path` value with the path to the model you downloaded. 52 | 53 | **5.** 54 | 55 | ```bash 56 | cd F:\VSCodeProject\A2PM-MESA\segmentor 57 | ``` 58 | (change this to your own path) 59 | 60 | ```bash 61 | python ImgSAMSeg.py --img_path "F:\VSCodeProject\A2PM-MESA\dataset\scannet_test_1500\scene0720_00\color\180.jpg" --save_folder "F:\VSCodeProject\A2PM-MESA\result\private\SA2M\data\SAMRes\scene0720_00" --save_name "180" 62 | ``` 63 | 64 | ```bash 65 | python ImgSAMSeg.py --img_path "F:\VSCodeProject\A2PM-MESA\dataset\scannet_test_1500\scene0720_00\color\300.jpg" --save_folder "F:\VSCodeProject\A2PM-MESA\result\private\SA2M\data\SAMRes\scene0720_00" --save_name "300" 66 | ``` 67 | 68 | 结果的地址要与 `A2PM-MESA\conf\dataset\scannet_sam.yaml` 文件下的路径一致 69 | 70 | The output folder must be consistent with the `sem_folder` path configured in `A2PM-MESA\conf\dataset\scannet_sam.yaml`. 71 | 72 | ## 二 根据第一步的结果进行MESA+DKM匹配 73 | 74 | ## Run MESA+DKM matching using the segmentation results from the first part 75 | 76 | ```bash 77 | cd F:\VSCodeProject\A2PM-MESA\scripts 78 | 79 | python test_a2pm.py +experiment=a2pm_mesa_egam_dkm_scannet 80 | ``` 81 | 82 | ## 三 进行基准测试 Benchmark testing 83 | 84 | **1.** 85 | 根据`A2PM-MESA\scripts\scannet_pairs.txt`中需要的图片依次调用第一步中的命令生成所需`.npy`文件(只进行部分测试,前十几张照片即可) 86 | 87 | For each image listed in `A2PM-MESA\scripts\scannet_pairs.txt`, run the segmentation command from the first part (`ImgSAMSeg.py`) to generate its `.npy` file (for a partial test, the first dozen or so images are enough).
88 | 89 | **2.** 90 | 在`A2PM-MESA\scripts`下新建`mesa-f-dkm-sn.py`文件,文件第8,10,11行需要自行修改 91 | 92 | (Create a `mesa-f-dkm-sn.py` file under `A2PM-MESA\scripts` and adjust lines 8, 10, and 11 to your environment. Note: the script calls `os.listdir` on the `ratios` result folder at startup, so make sure that folder exists before the first run; otherwise it raises `FileNotFoundError`.) 93 | 94 | ``` 95 | import os 96 | import subprocess 97 | 98 | # 设置参数 99 | dataset = "ScanNet" 100 | cuda_id = 0 101 | project_name = "mesa-f-egam-dkm-sn-eval-res" 102 | exp_root_path = "F:\\VSCodeProject\\A2PM-MESA" 103 | 104 | already_done_name_file_folder = os.path.join(exp_root_path, "result", "private","A2PM-MESA","res", f"{project_name}", "ratios") 105 | pair_txt = os.path.join(exp_root_path, "scripts", "scannet_pairs.txt") 106 | 107 | # 获取已经完成的文件 108 | already_done_name_file = None 109 | for file in os.listdir(already_done_name_file_folder): 110 | if file.endswith("pose_err_names.txt"): 111 | already_done_name_file = os.path.join(already_done_name_file_folder, file) 112 | break 113 | 114 | # 读取对文件 115 | with open(pair_txt, 'r') as f: 116 | for line in f: 117 | # 解析每一行 118 | arr = line.strip().split('_') 119 | scene_name = f"{arr[0]}_{arr[1]}" 120 | pair0 = arr[2] 121 | pair1 = arr[3] 122 | 123 | # 检查是否已经处理过 124 | if already_done_name_file and os.path.isfile(already_done_name_file): 125 | with open(already_done_name_file, 'r') as done_file: 126 | already_done_lines = done_file.readlines() 127 | if any(f"{scene_name}_{pair0}_{pair1}" in line for line in already_done_lines): 128 | print(f"{scene_name}_{pair0}_{pair1} already done") 129 | continue 130 | 131 | # 设置环境变量 132 | os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_id) 133 | 134 | # 执行 Python 脚本 135 | command = [ 136 | "python", "test_a2pm.py", 137 | "+experiment=a2pm_mesa_egam_dkm_scannet", 138 | "test_area_acc=False", 139 | "test_pm_acc=False", 140 | "verbose=0", 141 | f"name={project_name}", 142 | f"dataset_name={dataset}", 143 | f"dataset.scene_name={scene_name}", 144 | f"dataset.image_name0={pair0}", 145 | f"dataset.image_name1={pair1}" 146 | ] 147 | 148 | # 打印并执行命令 149 | print("Running command:",
" ".join(command)) 150 | subprocess.run(command) 151 | ``` 152 | 153 | **3.** 154 | 运行测试程序 155 | run the benchmark test program 156 | 157 | ```bash 158 | python mesa-f-dkm-sn.py 159 | ``` 160 | 161 | **4.** 162 | 统计metric 163 | 164 | ```bash 165 | cd ../metric 166 | python eval_ratios.py 167 | ``` 168 | - NOTE: eval_ratios.py#L21~L26 需要修改到对应路径和文件夹名字. 169 | 170 | -------------------------------------------------------------------------------- /conf/area_matcher/dmesa.yaml: -------------------------------------------------------------------------------- 1 | _target_: area_matchers.dmesa.DMesaAreaMatcher 2 | datasetName: 'ScanNet' 3 | W: 640 4 | H: 480 5 | coarse_matcher_name: 'ASpan' 6 | level_num: 4 7 | level_step: [560, 390, 256, 130, 0] 8 | stop_match_level: 3 9 | area_crop_mode: 'expand_padding' 10 | patch_size_ratio: 0.125 11 | valid_gaussian_width: 'sqrt2' 12 | source_area_selection_mode: 'direct' 13 | iou_fusion_thd: 0.8 14 | patch_match_num_thd: 30 15 | match_mode: 'pms_GF' 16 | coarse_match_all_in_one: 1 17 | dual_match: 1 18 | draw_verbose: 0 -------------------------------------------------------------------------------- /conf/area_matcher/mesa-f.yaml: -------------------------------------------------------------------------------- 1 | _target_: area_matchers.mesa.MesaAreaMatcher 2 | 3 | W: 640 4 | H: 480 5 | 6 | coarse_matcher_name: 'ASpan' 7 | level_num: 4 8 | level_step: [560, 390, 256, 130, 0] 9 | adj_weight: 0.01 10 | stop_match_level: 3 11 | coarse_match_thd: 0.3 12 | patch_size: 64 13 | similar_area_dist_thd: 20 14 | area_w: 480 15 | area_h: 480 16 | sigma_thd: 0.1 17 | global_energy_weights: [10,1,1,1] 18 | iou_fusion_thd: 0.8 19 | candi_energy_thd: 0.7 20 | global_refine: 1 21 | fast_version: 0 22 | global_energy_candy_range: 0.1 23 | energy_norm_way: 'minmax' 24 | datasetName: ??? 
25 | 26 | draw_verbose: 0 -------------------------------------------------------------------------------- /conf/area_matcher/sem_area_matcher.yaml: -------------------------------------------------------------------------------- 1 | _target_: area_matchers.sem_am.SemAreaMatcher 2 | 3 | semantic_mode: SEEM 4 | 5 | datasetName: ScanNet 6 | 7 | W: 640 8 | H: 480 9 | 10 | # params 11 | connected_thd: 3600 12 | radius_thd_up: 128 13 | radius_thd_down: 100 14 | desc_type: 2 15 | small_label_filted_thd_on_bound: 20 16 | small_label_filted_thd_inside_area: 900 17 | combined_obj_dist_thd: 200 18 | leave_multi_obj_match: 0 19 | obj_desc_match_thd: 0.5 20 | same_overlap_dist: 100 21 | label_list_area_thd: 400 22 | overlap_radius: 128 23 | overlap_desc_dist_thd: 0.25 24 | inv_overlap_pyramid_ratio: 8 25 | output_patch_size: 256 26 | 27 | draw_verbose: 0 -------------------------------------------------------------------------------- /conf/dataset/demo_pair.yaml: -------------------------------------------------------------------------------- 1 | _target_: dataloader.demo_pair_loader.DemoPairLoader 2 | 3 | root_path: /opt/data/private/SA2M/hydra-mesa/demo 4 | scene_name: "" 5 | image_name0: "4119.965344" 6 | image_name1: "4120.813199" 7 | color_folder: "color" 8 | color_post: "png" 9 | sem_folder: "samres" 10 | sem_post: "npy" 11 | intrin_folder: "intrins" 12 | intrin_post: "txt" 13 | -------------------------------------------------------------------------------- /conf/dataset/megadepth.yaml: -------------------------------------------------------------------------------- 1 | _target_: dataloader.megadepth.MegaDepthDataloader 2 | 3 | root_path: /opt/data/private/SA2M/data/megadepth_test_1500 4 | scene_name: ??? 5 | image_name0: ??? 6 | image_name1: ??? 
7 | 8 | sem_mode: SAM 9 | sem_folder: /opt/data/private/SA2M/data/SAMRes 10 | sem_post: npy -------------------------------------------------------------------------------- /conf/dataset/scannet_sam.yaml: -------------------------------------------------------------------------------- 1 | _target_: dataloader.scannet.ScanNetDataloader 2 | 3 | root_path: /opt/data/private/SA2M/data/scannet_test_1500 4 | color_folder: color 5 | color_post: jpg 6 | depth_folder: depth 7 | depth_post: png 8 | depth_factor: 1000.0 9 | K_folder: intrinsic 10 | pose_folder: pose 11 | pose_post: txt 12 | 13 | # SAM 14 | sem_mode: SAM 15 | sem_folder: /opt/data/private/SA2M/data/SAMRes 16 | sem_post: npy 17 | 18 | # # SEEM 19 | # sem_mode: SEEM 20 | # sem_folder: /opt/data/private/SA2M/data/SNSEEMRes/SAMViTB 21 | # sem_post: png 22 | 23 | # GT 24 | # sem_mode: GT 25 | # sem_folder: '' 26 | # sem_post: png 27 | 28 | 29 | scene_name: ??? 30 | image_name0: ??? 31 | image_name1: ??? -------------------------------------------------------------------------------- /conf/dataset/scannet_seem.yaml: -------------------------------------------------------------------------------- 1 | _target_: dataloader.scannet.ScanNetDataloader 2 | 3 | root_path: /opt/data/private/SA2M/data/scannet_test_1500 4 | color_folder: color 5 | color_post: jpg 6 | depth_folder: depth 7 | depth_post: png 8 | depth_factor: 1000.0 9 | K_folder: intrinsic 10 | pose_folder: pose 11 | pose_post: txt 12 | 13 | # SAM 14 | # sem_mode: SAM 15 | # sem_folder: /opt/data/private/SA2M/data/SAMRes 16 | # sem_post: npy 17 | 18 | # SEEM 19 | sem_mode: SEEM 20 | sem_folder: /opt/data/private/SA2M/data/SNSEEMRes/SAMViTB 21 | sem_post: png 22 | 23 | # GT 24 | # sem_mode: GT 25 | # sem_folder: '' 26 | # sem_post: png 27 | 28 | 29 | scene_name: ??? 30 | image_name0: ??? 31 | image_name1: ??? 
-------------------------------------------------------------------------------- /conf/evaler/instance_eval.yaml: -------------------------------------------------------------------------------- 1 | _target_: metric.instance_eval.InstanceEval 2 | 3 | sample_mode: random 4 | eval_corr_num: 1000 5 | sac_mode: MAGSAC 6 | out_path: '' 7 | draw_verbose: False -------------------------------------------------------------------------------- /conf/experiment/a2pm_dmesa_egam_dkm_megadepth.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /dataset: megadepth 4 | - /point_matcher: dkm_outdoor 5 | - /area_matcher: dmesa 6 | - /geo_area_matcher: egam 7 | - /evaler: instance_eval 8 | - _self_ 9 | 10 | name: 'dmesa-egam-dkm-megadepth' 11 | dataset_name: 'MegaDepth' 12 | test_area_acc: 0 13 | test_pose_err: 1 14 | test_pm_acc: 1 15 | pm_acc_thds: [0.0001, 0.0003, 0.0005, 0.0007, 0.001, 0.003, 0.005] 16 | 17 | # size info 18 | area_from_size_W: 640 19 | area_from_size_H: 480 20 | 21 | eval_from_size_W: 640 # Not work for MD 22 | eval_from_size_H: 480 # Not work for MD 23 | 24 | crop_from_size_W: 1296 # Not work for MD 25 | crop_from_size_H: 968 # Not work for MD 26 | 27 | crop_size_W: 832 28 | crop_size_H: 832 29 | 30 | # others 31 | verbose: 1 32 | out_path: /opt/data/private/A2PM-git/A2PM-MESA/res/${name}-res 33 | match_num: 5000 34 | 35 | ############################################# UPDATE ############################################# 36 | # match info for eval 37 | evaler: 38 | eval_corr_num: ${match_num} 39 | sample_mode: grid 40 | sac_mode: MAGSAC 41 | out_path: ${out_path} 42 | draw_verbose: ${verbose} 43 | 44 | # update point_matcher 45 | point_matcher: 46 | dataset_name: ${dataset_name} 47 | 48 | ## update the size info for gam 49 | geo_area_matcher: 50 | datasetName: ${dataset_name} 51 | alpha_list: [3] 52 | crop_size_W: ${crop_size_W} 53 | crop_size_H: ${crop_size_H} 54 | crop_from_size_W: 
${crop_from_size_W} 55 | crop_from_size_H: ${crop_from_size_H} 56 | eval_from_size_W: ${eval_from_size_W} 57 | eval_from_size_H: ${eval_from_size_H} 58 | area_from_size_W: ${area_from_size_W} 59 | area_from_size_H: ${area_from_size_H} 60 | reject_out_area_flag: 1 61 | verbose: ${verbose} 62 | 63 | 64 | ## update the size info for area_matcher 65 | area_matcher: 66 | datasetName: ${dataset_name} 67 | W: ${area_from_size_W} 68 | H: ${area_from_size_H} 69 | draw_verbose: ${verbose} 70 | 71 | -------------------------------------------------------------------------------- /conf/experiment/a2pm_dmesa_egam_dkm_scannet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /dataset: scannet_sam 4 | - /point_matcher: dkm_indoor 5 | - /area_matcher: dmesa 6 | - /geo_area_matcher: egam 7 | - /evaler: instance_eval 8 | - _self_ 9 | 10 | name: 'dmesa-egam-dkm-scannet' 11 | dataset_name: ScanNet 12 | test_area_acc: 0 13 | test_pose_err: 1 14 | test_pm_acc: 0 15 | pm_acc_thds: [1, 3, 5, 7, 9] 16 | 17 | # size info 18 | area_from_size_W: 640 # semantic size 19 | area_from_size_H: 480 20 | 21 | eval_from_size_W: 640 22 | eval_from_size_H: 480 23 | 24 | crop_from_size_W: 1296 25 | crop_from_size_H: 968 26 | 27 | crop_size_W: 640 28 | crop_size_H: 480 29 | 30 | # others 31 | verbose: 1 32 | out_path: /opt/data/private/A2PM-git/A2PM-MESA/res/${name}-res 33 | match_num: 5000 34 | 35 | ############################################# UPDATE ############################################# 36 | # match info for eval 37 | evaler: 38 | eval_corr_num: ${match_num} 39 | sample_mode: random 40 | sac_mode: MAGSAC 41 | out_path: ${out_path} 42 | draw_verbose: ${verbose} 43 | 44 | # update point_matcher 45 | point_matcher: 46 | dataset_name: ${dataset_name} 47 | 48 | ## update the size info for gam 49 | geo_area_matcher: 50 | datasetName: ${dataset_name} 51 | alpha_list: [3] 52 | crop_size_W: ${crop_size_W} 53 | crop_size_H: 
${crop_size_H} 54 | crop_from_size_W: ${crop_from_size_W} 55 | crop_from_size_H: ${crop_from_size_H} 56 | eval_from_size_W: ${eval_from_size_W} 57 | eval_from_size_H: ${eval_from_size_H} 58 | area_from_size_W: ${area_from_size_W} 59 | area_from_size_H: ${area_from_size_H} 60 | verbose: ${verbose} 61 | 62 | 63 | ## update the size info for area_matcher 64 | area_matcher: 65 | datasetName: ${dataset_name} 66 | W: ${area_from_size_W} 67 | H: ${area_from_size_H} 68 | draw_verbose: ${verbose} 69 | 70 | -------------------------------------------------------------------------------- /conf/experiment/a2pm_dmesa_egam_loftr_megadepth.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /dataset: megadepth 4 | - /point_matcher: loftr_outdoor 5 | - /area_matcher: dmesa 6 | - /geo_area_matcher: egam 7 | - /evaler: instance_eval 8 | - _self_ 9 | 10 | name: 'dmesa-loftr-egam-megadepth' 11 | dataset_name: MegaDepth 12 | test_area_acc: 0 13 | test_pose_err: 1 14 | test_pm_acc: 0 15 | pm_acc_thds: [0.0001, 0.0003, 0.0005, 0.0007, 0.001, 0.003, 0.005] 16 | 17 | # size info 18 | area_from_size_W: 640 # semantic size 19 | area_from_size_H: 480 20 | 21 | eval_from_size_W: 640 # Not work for MD 22 | eval_from_size_H: 480 23 | 24 | crop_from_size_W: 1296 # Not work for MD 25 | crop_from_size_H: 968 26 | 27 | crop_size_W: 1200 28 | crop_size_H: 1200 29 | 30 | # others 31 | verbose: 1 32 | out_path: /opt/data/private/A2PM-git/A2PM-MESA/res/${name}-res 33 | match_num: 1000 34 | 35 | ############################################# UPDATE ############################################# 36 | # match info for eval 37 | evaler: 38 | eval_corr_num: ${match_num} 39 | sample_mode: random 40 | sac_mode: MAGSAC 41 | out_path: ${out_path} 42 | draw_verbose: ${verbose} 43 | 44 | # update point_matcher 45 | point_matcher: 46 | dataset_name: ${dataset_name} 47 | 48 | ## update the size info for gam 49 | geo_area_matcher: 50 | 
datasetName: ${dataset_name} 51 | crop_size_W: ${crop_size_W} 52 | crop_size_H: ${crop_size_H} 53 | crop_from_size_W: ${crop_from_size_W} 54 | crop_from_size_H: ${crop_from_size_H} 55 | eval_from_size_W: ${eval_from_size_W} 56 | eval_from_size_H: ${eval_from_size_H} 57 | area_from_size_W: ${area_from_size_W} 58 | area_from_size_H: ${area_from_size_H} 59 | reject_out_area_flag: 1 60 | verbose: ${verbose} 61 | 62 | 63 | ## update the size info for area_matcher 64 | area_matcher: 65 | datasetName: ${dataset_name} 66 | W: ${area_from_size_W} 67 | H: ${area_from_size_H} 68 | draw_verbose: ${verbose} 69 | 70 | -------------------------------------------------------------------------------- /conf/experiment/a2pm_dmesa_egam_loftr_scannet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /dataset: scannet_sam 4 | - /point_matcher: loftr_indoor 5 | - /area_matcher: dmesa 6 | - /geo_area_matcher: egam 7 | - /evaler: instance_eval 8 | - _self_ 9 | 10 | name: 'dmesa-egam-loftr-scannet' 11 | dataset_name: ScanNet 12 | test_area_acc: 0 13 | test_pose_err: 1 14 | test_pm_acc: 0 15 | pm_acc_thds: [1, 3, 5, 7, 9] 16 | 17 | # size info 18 | area_from_size_W: 640 # semantic size 19 | area_from_size_H: 480 20 | 21 | eval_from_size_W: 640 22 | eval_from_size_H: 480 23 | 24 | crop_from_size_W: 1296 25 | crop_from_size_H: 968 26 | 27 | crop_size_W: 480 28 | crop_size_H: 480 29 | 30 | # others 31 | verbose: 1 32 | out_path: /opt/data/private/A2PM-git/A2PM-MESA/res/${name}-res 33 | match_num: 1000 34 | 35 | ############################################# UPDATE ############################################# 36 | # match info for eval 37 | evaler: 38 | eval_corr_num: ${match_num} 39 | sample_mode: random 40 | sac_mode: MAGSAC 41 | out_path: ${out_path} 42 | draw_verbose: ${verbose} 43 | 44 | # update point_matcher 45 | point_matcher: 46 | dataset_name: ${dataset_name} 47 | 48 | ## update the size info for gam 49 | 
geo_area_matcher: 50 | datasetName: ${dataset_name} 51 | crop_size_W: ${crop_size_W} 52 | crop_size_H: ${crop_size_H} 53 | crop_from_size_W: ${crop_from_size_W} 54 | crop_from_size_H: ${crop_from_size_H} 55 | eval_from_size_W: ${eval_from_size_W} 56 | eval_from_size_H: ${eval_from_size_H} 57 | area_from_size_W: ${area_from_size_W} 58 | area_from_size_H: ${area_from_size_H} 59 | verbose: ${verbose} 60 | 61 | 62 | ## update the size info for area_matcher 63 | area_matcher: 64 | datasetName: ${dataset_name} 65 | W: ${area_from_size_W} 66 | H: ${area_from_size_H} 67 | draw_verbose: ${verbose} 68 | 69 | -------------------------------------------------------------------------------- /conf/experiment/a2pm_dmesa_egam_spsg_megadepth.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /dataset: megadepth 4 | - /point_matcher: spsg_outdoor 5 | - /area_matcher: dmesa 6 | - /geo_area_matcher: egam 7 | - /evaler: instance_eval 8 | - _self_ 9 | 10 | name: 'dmesa-egam-spsg-megadepth' 11 | dataset_name: MegaDepth 12 | test_area_acc: 0 13 | test_pose_err: 1 14 | test_pm_acc: 0 15 | pm_acc_thds: [0.1, 0.5, 1, 5, 10] 16 | 17 | # size info 18 | area_from_size_W: 640 # semantic size 19 | area_from_size_H: 480 20 | 21 | eval_from_size_W: 640 # not work 22 | eval_from_size_H: 480 23 | 24 | crop_from_size_W: 1296 # not work 25 | crop_from_size_H: 968 26 | 27 | crop_size_W: 640 28 | crop_size_H: 640 29 | 30 | # others 31 | verbose: 1 32 | out_path: /opt/data/private/A2PM-git/A2PM-MESA/res/${name}-res 33 | match_num: 1000 34 | 35 | ############################################# UPDATE ############################################# 36 | # match info for eval 37 | evaler: 38 | eval_corr_num: ${match_num} 39 | sample_mode: random 40 | sac_mode: MAGSAC 41 | out_path: ${out_path} 42 | draw_verbose: ${verbose} 43 | 44 | # update point_matcher 45 | point_matcher: 46 | dataset_name: ${dataset_name} 47 | 48 | ## update the 
size info for gam 49 | geo_area_matcher: 50 | datasetName: ${dataset_name} 51 | crop_size_W: ${crop_size_W} 52 | crop_size_H: ${crop_size_H} 53 | crop_from_size_W: ${crop_from_size_W} 54 | crop_from_size_H: ${crop_from_size_H} 55 | eval_from_size_W: ${eval_from_size_W} 56 | eval_from_size_H: ${eval_from_size_H} 57 | area_from_size_W: ${area_from_size_W} 58 | area_from_size_H: ${area_from_size_H} 59 | reject_out_area_flag: 1 60 | verbose: ${verbose} 61 | # specific for spsg 62 | valid_inside_area_match_num: 10 63 | 64 | 65 | ## update the size info for area_matcher 66 | area_matcher: 67 | datasetName: ${dataset_name} 68 | W: ${area_from_size_W} 69 | H: ${area_from_size_H} 70 | draw_verbose: ${verbose} 71 | 72 | -------------------------------------------------------------------------------- /conf/experiment/a2pm_dmesa_egam_spsg_scannet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /dataset: scannet_sam 4 | - /point_matcher: spsg_indoor 5 | - /area_matcher: dmesa 6 | - /geo_area_matcher: egam 7 | - /evaler: instance_eval 8 | - _self_ 9 | 10 | name: 'dmesa-egam-spsg-scannet' 11 | dataset_name: ScanNet 12 | test_area_acc: 0 13 | test_pose_err: 1 14 | test_pm_acc: 0 15 | pm_acc_thds: [1, 3, 5] 16 | 17 | # size info 18 | area_from_size_W: 640 # semantic size 19 | area_from_size_H: 480 20 | 21 | eval_from_size_W: 640 22 | eval_from_size_H: 480 23 | 24 | crop_from_size_W: 1296 25 | crop_from_size_H: 968 26 | 27 | crop_size_W: 480 28 | crop_size_H: 480 29 | 30 | # others 31 | verbose: 1 32 | out_path: /opt/data/private/A2PM-git/A2PM-MESA/res/${name}-res 33 | match_num: 1000 34 | 35 | ############################################# UPDATE ############################################# 36 | # match info for eval 37 | evaler: 38 | eval_corr_num: ${match_num} 39 | sample_mode: random 40 | sac_mode: MAGSAC 41 | out_path: ${out_path} 42 | draw_verbose: ${verbose} 43 | 44 | # update point_matcher 45 
| point_matcher: 46 | dataset_name: ${dataset_name} 47 | 48 | ## update the size info for gam 49 | geo_area_matcher: 50 | datasetName: ${dataset_name} 51 | crop_size_W: ${crop_size_W} 52 | crop_size_H: ${crop_size_H} 53 | crop_from_size_W: ${crop_from_size_W} 54 | crop_from_size_H: ${crop_from_size_H} 55 | eval_from_size_W: ${eval_from_size_W} 56 | eval_from_size_H: ${eval_from_size_H} 57 | area_from_size_W: ${area_from_size_W} 58 | area_from_size_H: ${area_from_size_H} 59 | # specific for spsg 60 | valid_inside_area_match_num: 10 61 | verbose: ${verbose} 62 | 63 | 64 | ## update the size info for area_matcher 65 | area_matcher: 66 | datasetName: ${dataset_name} 67 | W: ${area_from_size_W} 68 | H: ${area_from_size_H} 69 | draw_verbose: ${verbose} 70 | 71 | -------------------------------------------------------------------------------- /conf/experiment/a2pm_mesa_egam_dkm_megadepth.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /dataset: megadepth 4 | - /point_matcher: dkm_outdoor 5 | - /area_matcher: mesa-f 6 | - /geo_area_matcher: egam 7 | - /evaler: instance_eval 8 | - _self_ 9 | 10 | name: mesa-f-egam-dkm-megadepth 11 | dataset_name: MegaDepth 12 | test_area_acc: 0 13 | test_pose_err: 1 14 | test_pm_acc: 0 15 | 16 | # size info 17 | area_from_size_W: 640 # semantic size 18 | area_from_size_H: 480 19 | 20 | eval_from_size_W: 640 # not used 21 | eval_from_size_H: 480 22 | 23 | crop_from_size_W: 1296 # not used 24 | crop_from_size_H: 968 25 | 26 | crop_size_W: 832 27 | crop_size_H: 832 28 | 29 | # others 30 | verbose: 1 31 | out_path: /opt/data/private/A2PM-git/A2PM-MESA/res/${name}-res 32 | match_num: 5000 33 | 34 | ############################################# UPDATE ############################################# 35 | # match info for eval 36 | evaler: 37 | eval_corr_num: ${match_num} 38 | sample_mode: grid 39 | sac_mode: MAGSAC 40 | out_path: ${out_path} 41 | 42 | # update 
point_matcher 43 | point_matcher: 44 | dataset_name: ${dataset_name} 45 | 46 | ## update the size info for gam 47 | geo_area_matcher: 48 | datasetName: ${dataset_name} 49 | std_match_num: ${match_num} 50 | alpha_list: [3] 51 | crop_mode: 2 52 | crop_size_W: ${crop_size_W} 53 | crop_size_H: ${crop_size_H} 54 | crop_from_size_W: ${crop_from_size_W} 55 | crop_from_size_H: ${crop_from_size_H} 56 | eval_from_size_W: ${eval_from_size_W} 57 | eval_from_size_H: ${eval_from_size_H} 58 | area_from_size_W: ${area_from_size_W} 59 | area_from_size_H: ${area_from_size_H} 60 | reject_out_area_flag: 1 61 | verbose: ${verbose} 62 | 63 | 64 | ## update the size info for area_matcher 65 | area_matcher: 66 | datasetName: ${dataset_name} 67 | W: ${area_from_size_W} 68 | H: ${area_from_size_H} 69 | draw_verbose: ${verbose} 70 | 71 | -------------------------------------------------------------------------------- /conf/experiment/a2pm_mesa_egam_dkm_scannet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /dataset: scannet_sam 4 | - /point_matcher: dkm_indoor 5 | - /area_matcher: mesa-f 6 | - /geo_area_matcher: egam 7 | - /evaler: instance_eval 8 | - _self_ 9 | 10 | name: mesa-f-egam-dkm-sn-eval 11 | dataset_name: ScanNet 12 | test_area_acc: 0 13 | test_pose_err: 1 14 | test_pm_acc: 0 15 | 16 | 17 | # size info 18 | area_from_size_W: 640 # semantic size 19 | area_from_size_H: 480 20 | 21 | eval_from_size_W: 640 22 | eval_from_size_H: 480 23 | 24 | crop_from_size_W: 1296 25 | crop_from_size_H: 968 26 | 27 | crop_size_W: 640 28 | crop_size_H: 480 29 | 30 | # others 31 | verbose: 1 32 | out_path: /opt/data/private/A2PM-git/A2PM-MESA/res/${name}-res 33 | match_num: 5000 34 | 35 | ############################################# UPDATE ############################################# 36 | # match info for eval 37 | evaler: 38 | eval_corr_num: ${match_num} 39 | sample_mode: grid 40 | sac_mode: MAGSAC 41 | out_path: 
${out_path} 42 | 43 | # update point_matcher 44 | point_matcher: 45 | dataset_name: ${dataset_name} 46 | 47 | ## update the size info for gam 48 | geo_area_matcher: 49 | datasetName: ${dataset_name} 50 | alpha_list: [3.5] 51 | std_match_num: ${match_num} 52 | crop_size_W: ${crop_size_W} 53 | crop_size_H: ${crop_size_H} 54 | crop_from_size_W: ${crop_from_size_W} 55 | crop_from_size_H: ${crop_from_size_H} 56 | eval_from_size_W: ${eval_from_size_W} 57 | eval_from_size_H: ${eval_from_size_H} 58 | area_from_size_W: ${area_from_size_W} 59 | area_from_size_H: ${area_from_size_H} 60 | verbose: ${verbose} 61 | 62 | 63 | ## update the size info for area_matcher 64 | area_matcher: 65 | datasetName: ${dataset_name} 66 | W: ${area_from_size_W} 67 | H: ${area_from_size_H} 68 | draw_verbose: ${verbose} 69 | 70 | -------------------------------------------------------------------------------- /conf/experiment/a2pm_mesa_egam_loftr_megadepth.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /dataset: megadepth 4 | - /point_matcher: loftr_outdoor 5 | - /area_matcher: mesa-f 6 | - /geo_area_matcher: egam 7 | - /evaler: instance_eval 8 | - _self_ 9 | 10 | name: mesa-f-egam-loftr-megadepth 11 | dataset_name: MegaDepth 12 | test_area_acc: 0 13 | test_pose_err: 1 14 | test_pm_acc: 0 15 | 16 | # size info 17 | area_from_size_W: 640 # semantic size 18 | area_from_size_H: 480 19 | 20 | eval_from_size_W: 640 21 | eval_from_size_H: 480 22 | 23 | crop_from_size_W: 1296 24 | crop_from_size_H: 968 25 | 26 | crop_size_W: 864 27 | crop_size_H: 864 28 | 29 | # others 30 | verbose: 1 31 | out_path: /opt/data/private/A2PM-git/A2PM-MESA/res/${name}-res 32 | match_num: 1000 33 | 34 | ############################################# UPDATE ############################################# 35 | # match info for eval 36 | evaler: 37 | eval_corr_num: ${match_num} 38 | sample_mode: random 39 | sac_mode: MAGSAC 40 | out_path: 
${out_path} 41 | 42 | # update point_matcher 43 | point_matcher: 44 | dataset_name: ${dataset_name} 45 | 46 | ## update the size info for gam 47 | geo_area_matcher: 48 | crop_size_W: ${crop_size_W} 49 | crop_size_H: ${crop_size_H} 50 | crop_from_size_W: ${crop_from_size_W} 51 | crop_from_size_H: ${crop_from_size_H} 52 | eval_from_size_W: ${eval_from_size_W} 53 | eval_from_size_H: ${eval_from_size_H} 54 | area_from_size_W: ${area_from_size_W} 55 | area_from_size_H: ${area_from_size_H} 56 | datasetName: ${dataset_name} 57 | reject_out_area_flag: 1 58 | sampler_name: 'GridFill' # specify the sampler name 59 | reject_out_area_flag: 1 60 | verbose: ${verbose} 61 | 62 | 63 | ## update the size info for area_matcher 64 | area_matcher: 65 | datasetName: ${dataset_name} 66 | W: ${area_from_size_W} 67 | H: ${area_from_size_H} 68 | draw_verbose: ${verbose} 69 | 70 | -------------------------------------------------------------------------------- /conf/experiment/a2pm_mesa_egam_loftr_scannet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /dataset: scannet_sam 4 | - /point_matcher: loftr_indoor 5 | - /area_matcher: mesa-f 6 | - /geo_area_matcher: egam 7 | - /evaler: instance_eval 8 | - _self_ 9 | 10 | name: mesa-f-egam-loftr-scannet 11 | dataset_name: ScanNet 12 | test_area_acc: 0 13 | test_pose_err: 1 14 | test_pm_acc: 0 15 | 16 | # size info 17 | area_from_size_W: 640 # semantic size 18 | area_from_size_H: 480 19 | 20 | eval_from_size_W: 640 21 | eval_from_size_H: 480 22 | 23 | crop_from_size_W: 1296 24 | crop_from_size_H: 968 25 | 26 | crop_size_W: 480 27 | crop_size_H: 480 28 | 29 | # others 30 | verbose: 1 31 | out_path: /opt/data/private/A2PM-git/A2PM-MESA/res/${name}-res 32 | match_num: 1000 33 | 34 | ############################################# UPDATE ############################################# 35 | # match info for eval 36 | evaler: 37 | eval_corr_num: ${match_num} 38 | sample_mode: 
grid 39 | sac_mode: MAGSAC 40 | out_path: ${out_path} 41 | 42 | # update point_matcher 43 | point_matcher: 44 | dataset_name: ${dataset_name} 45 | 46 | ## update the size info for gam 47 | geo_area_matcher: 48 | crop_size_W: ${crop_size_W} 49 | crop_size_H: ${crop_size_H} 50 | crop_from_size_W: ${crop_from_size_W} 51 | crop_from_size_H: ${crop_from_size_H} 52 | eval_from_size_W: ${eval_from_size_W} 53 | eval_from_size_H: ${eval_from_size_H} 54 | area_from_size_W: ${area_from_size_W} 55 | area_from_size_H: ${area_from_size_H} 56 | datasetName: ${dataset_name} 57 | verbose: ${verbose} 58 | 59 | 60 | ## update the size info for area_matcher 61 | area_matcher: 62 | datasetName: ${dataset_name} 63 | W: ${area_from_size_W} 64 | H: ${area_from_size_H} 65 | draw_verbose: ${verbose} 66 | 67 | -------------------------------------------------------------------------------- /conf/experiment/a2pm_mesa_egam_spsg_megadepth.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /dataset: megadepth 4 | - /point_matcher: spsg_outdoor 5 | - /area_matcher: mesa-f 6 | - /geo_area_matcher: egam 7 | - /evaler: instance_eval 8 | - _self_ 9 | 10 | name: mesa-f-egam-spsg-megadepth 11 | dataset_name: MegaDepth 12 | test_area_acc: 0 13 | test_pose_err: 1 14 | test_pm_acc: 0 15 | 16 | 17 | # size info 18 | area_from_size_W: 640 # semantic size 19 | area_from_size_H: 480 20 | 21 | eval_from_size_W: 640 22 | eval_from_size_H: 480 23 | 24 | crop_from_size_W: 1296 25 | crop_from_size_H: 968 26 | 27 | crop_size_W: 832 28 | crop_size_H: 832 29 | 30 | # others 31 | verbose: 1 32 | out_path: /opt/data/private/A2PM-git/A2PM-MESA/res/${name}-res 33 | match_num: 1000 34 | 35 | ############################################# UPDATE ############################################# 36 | # match info for eval 37 | evaler: 38 | eval_corr_num: ${match_num} 39 | sample_mode: random 40 | sac_mode: MAGSAC 41 | out_path: ${out_path} 42 | 
draw_verbose: ${verbose} 43 | 44 | # update point_matcher 45 | point_matcher: 46 | dataset_name: ${dataset_name} 47 | 48 | ## update the size info for gam 49 | geo_area_matcher: 50 | datasetName: ${dataset_name} 51 | crop_size_W: ${crop_size_W} 52 | crop_size_H: ${crop_size_H} 53 | crop_from_size_W: ${crop_from_size_W} 54 | crop_from_size_H: ${crop_from_size_H} 55 | eval_from_size_W: ${eval_from_size_W} 56 | eval_from_size_H: ${eval_from_size_H} 57 | area_from_size_W: ${area_from_size_W} 58 | area_from_size_H: ${area_from_size_H} 59 | reject_out_area_flag: 1 60 | # specific for spsg 61 | valid_inside_area_match_num: 10 62 | verbose: ${verbose} 63 | 64 | 65 | ## update the size info for area_matcher 66 | area_matcher: 67 | datasetName: ${dataset_name} 68 | W: ${area_from_size_W} 69 | H: ${area_from_size_H} 70 | draw_verbose: ${verbose} 71 | 72 | -------------------------------------------------------------------------------- /conf/experiment/a2pm_mesa_egam_spsg_scannet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /dataset: scannet_sam 4 | - /point_matcher: spsg_indoor 5 | - /area_matcher: mesa-f 6 | - /geo_area_matcher: egam 7 | - /evaler: instance_eval 8 | - _self_ 9 | 10 | name: mesa-f-egam-spsg-scannet 11 | dataset_name: ScanNet 12 | test_area_acc: 0 13 | test_pose_err: 1 14 | test_pm_acc: 0 15 | 16 | 17 | # size info 18 | area_from_size_W: 640 # semantic size 19 | area_from_size_H: 480 20 | 21 | eval_from_size_W: 640 22 | eval_from_size_H: 480 23 | 24 | crop_from_size_W: 1296 25 | crop_from_size_H: 968 26 | 27 | crop_size_W: 480 28 | crop_size_H: 480 29 | 30 | # others 31 | verbose: 1 32 | out_path: /opt/data/private/A2PM-git/A2PM-MESA/res/${name}-res 33 | match_num: 1000 34 | 35 | ############################################# UPDATE ############################################# 36 | # match info for eval 37 | evaler: 38 | eval_corr_num: ${match_num} 39 | sample_mode: grid 40 | 
sac_mode: MAGSAC 41 | out_path: ${out_path} 42 | draw_verbose: ${verbose} 43 | 44 | # update point_matcher 45 | point_matcher: 46 | dataset_name: ${dataset_name} 47 | 48 | ## update the size info for gam 49 | geo_area_matcher: 50 | crop_size_W: ${crop_size_W} 51 | crop_size_H: ${crop_size_H} 52 | crop_from_size_W: ${crop_from_size_W} 53 | crop_from_size_H: ${crop_from_size_H} 54 | eval_from_size_W: ${eval_from_size_W} 55 | eval_from_size_H: ${eval_from_size_H} 56 | area_from_size_W: ${area_from_size_W} 57 | area_from_size_H: ${area_from_size_H} 58 | # specific for spsg 59 | valid_inside_area_match_num: 10 60 | verbose: ${verbose} 61 | 62 | 63 | ## update the size info for area_matcher 64 | area_matcher: 65 | datasetName: ${dataset_name} 66 | W: ${area_from_size_W} 67 | H: ${area_from_size_H} 68 | draw_verbose: ${verbose} 69 | 70 | -------------------------------------------------------------------------------- /conf/experiment/demo.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /dataset: demo_pair 4 | - /point_matcher: spsg_indoor 5 | # - /area_matcher: mesa 6 | - /area_matcher: dmesa 7 | # - /geo_area_matcher: gam 8 | - /geo_area_matcher: egam 9 | - /evaler: instance_eval 10 | - _self_ 11 | 12 | name: dmesa-f-egam-spsg-single-pair-demo 13 | dataset_name: ScanNet # use ScanNet for indoor demo, MegaDepth for outdoor demo 14 | 15 | # size info 16 | area_from_size_W: 640 # semantic size 17 | area_from_size_H: 480 18 | 19 | eval_from_size_W: 640 # NOT USED, we eval on the original size 20 | eval_from_size_H: 480 21 | 22 | crop_from_size_W: 1296 # NOT USED, we crop on the original size 23 | crop_from_size_H: 968 24 | 25 | crop_size_W: 480 26 | crop_size_H: 480 27 | 28 | # others 29 | verbose: 1 30 | out_path: /opt/data/private/SA2M/hydra-mesa/res/${name}-res 31 | match_num: 1000 32 | 33 | ############################################# UPDATE ############################################# 
34 | # match info for eval 35 | evaler: 36 | eval_corr_num: ${match_num} 37 | sample_mode: grid 38 | sac_mode: MAGSAC 39 | out_path: ${out_path} 40 | draw_verbose: ${verbose} 41 | 42 | # update point_matcher 43 | point_matcher: 44 | dataset_name: ${dataset_name} 45 | 46 | ## update the size info for gam 47 | geo_area_matcher: 48 | datasetName: demo # SPECIFIC for gamer 49 | crop_size_W: ${crop_size_W} 50 | crop_size_H: ${crop_size_H} 51 | crop_from_size_W: ${crop_from_size_W} 52 | crop_from_size_H: ${crop_from_size_H} 53 | eval_from_size_W: ${eval_from_size_W} 54 | eval_from_size_H: ${eval_from_size_H} 55 | area_from_size_W: ${area_from_size_W} 56 | area_from_size_H: ${area_from_size_H} 57 | # specific for spsg 58 | valid_inside_area_match_num: 10 59 | verbose: ${verbose} 60 | 61 | 62 | ## update the size info for area_matcher 63 | area_matcher: 64 | datasetName: ${dataset_name} 65 | W: ${area_from_size_W} 66 | H: ${area_from_size_H} 67 | draw_verbose: ${verbose} -------------------------------------------------------------------------------- /conf/geo_area_matcher/egam.yaml: -------------------------------------------------------------------------------- 1 | _target_: geo_area_matchers.egam.EGeoAreaMatcher 2 | datasetName: 'ScanNet' 3 | 4 | area_from_size_W: 640 5 | area_from_size_H: 480 6 | 7 | crop_size_W: 640 8 | crop_size_H: 640 9 | 10 | crop_from_size_W: 1296 11 | crop_from_size_H: 968 12 | 13 | eval_from_size_W: 640 14 | eval_from_size_H: 480 15 | 16 | std_match_num: 1000 17 | alpha_list: #[0.5, 2.0, 3.0, 5.0, 10.0] 18 | - 0.5 19 | - 2.0 20 | - 3.0 21 | - 3.5 22 | - 5.0 23 | adaptive_size_thd: 1.0 24 | valid_inside_area_match_num: 100 25 | reject_out_area_flag: 0 26 | crop_mode: 0 27 | sac_mode: 'MAGSAC' 28 | 29 | # sampler use or not 30 | sampler_name: '' 31 | occ_size: 1 32 | common_occ_flag: 1 33 | 34 | verbose: 0 -------------------------------------------------------------------------------- /conf/geo_area_matcher/gam.yaml: 
-------------------------------------------------------------------------------- 1 | _target_: geo_area_matchers.gam.PRGeoAreaMatcher 2 | datasetName: ScanNet 3 | 4 | area_from_size_W: 640 5 | area_from_size_H: 480 6 | 7 | crop_size_W: 640 8 | crop_size_H: 640 9 | 10 | crop_from_size_W: 1296 11 | crop_from_size_H: 968 12 | 13 | eval_from_size_W: 640 14 | eval_from_size_H: 480 15 | 16 | std_match_num: 1000 17 | alpha_list: #[0.5, 2.0, 3.0, 5.0, 10.0] 18 | - 0.5 19 | - 2.0 20 | - 3.0 21 | - 5.0 22 | - 10.0 23 | filter_area_num: 1 24 | adaptive_size_thd: 0.6 25 | reject_out_area_flag: 0 26 | valid_inside_area_match_num: 200 27 | verbose: 0 -------------------------------------------------------------------------------- /conf/point_matcher/aspan_indoor.yaml: -------------------------------------------------------------------------------- 1 | _target_: point_matchers.aspanformer.ASpanMatcher 2 | 3 | config_path: '/opt/data/private/A2PM-git/A2PM-MESA/point_matchers/ASpanFormer/configs' 4 | weights: '/opt/data/private/SA2M/Matchers/ASpanFormer/weights/indoor.ckpt' 5 | dataset_name: ??? # ScanNet or ETH3D or KITTI 6 | 7 | -------------------------------------------------------------------------------- /conf/point_matcher/aspan_outdoor.yaml: -------------------------------------------------------------------------------- 1 | _target_: point_matchers.aspanformer.ASpanMatcher 2 | 3 | config_path: '/opt/data/private/A2PM-git/A2PM-MESA/point_matchers/ASpanFormer/configs' 4 | weights: '/opt/data/private/SA2M/Matchers/ASpanFormer/weights/outdoor.ckpt' 5 | dataset_name: ??? 
6 | 7 | -------------------------------------------------------------------------------- /conf/point_matcher/dkm_indoor.yaml: -------------------------------------------------------------------------------- 1 | _target_: point_matchers.dkm.DKMMatcher 2 | 3 | dataset_name: ScanNet 4 | weights: '/opt/data/private/SA2M/hydra-mesa/point_matchers/DKM/weights/DKMv3_indoor.pth' 5 | -------------------------------------------------------------------------------- /conf/point_matcher/dkm_outdoor.yaml: -------------------------------------------------------------------------------- 1 | _target_: point_matchers.dkm.DKMMatcher 2 | 3 | dataset_name: MegaDepth 4 | weights: '/opt/data/private/SA2M/hydra-mesa/point_matchers/DKM/weights/DKMv3_outdoor.pth' 5 | -------------------------------------------------------------------------------- /conf/point_matcher/loftr_indoor.yaml: -------------------------------------------------------------------------------- 1 | _target_: point_matchers.loftr.LoFTRMatcher 2 | 3 | # config_path: 'path/to/ASpanFormer/configs/' 4 | config_path: '/opt/data/private/A2PM-git/A2PM-MESA/point_matchers/LoFTR/configs' 5 | weights: '/opt/data/private/SA2M/hydra-mesa/point_matchers/LoFTR/weights/indoor_ds_new.ckpt' 6 | dataset_name: ??? # ScanNet or ETH3D or KITTI 7 | cross_domain: False -------------------------------------------------------------------------------- /conf/point_matcher/loftr_outdoor.yaml: -------------------------------------------------------------------------------- 1 | _target_: point_matchers.loftr.LoFTRMatcher 2 | 3 | # config_path: 'path/to/ASpanFormer/configs/' 4 | config_path: '/opt/data/private/A2PM-git/A2PM-MESA/point_matchers/LoFTR/configs' 5 | weights: '/opt/data/private/SA2M/hydra-mesa/point_matchers/LoFTR/weights/outdoor_ds.ckpt' 6 | dataset_name: ??? 
7 | cross_domain: False -------------------------------------------------------------------------------- /conf/point_matcher/spsg_indoor.yaml: -------------------------------------------------------------------------------- 1 | _target_: point_matchers.spsg.SPSGMatcher 2 | 3 | dataset_name: ScanNet 4 | weights: 'indoor' -------------------------------------------------------------------------------- /conf/point_matcher/spsg_outdoor.yaml: -------------------------------------------------------------------------------- 1 | _target_: point_matchers.spsg.SPSGMatcher 2 | 3 | dataset_name: MegaDepth 4 | weights: 'outdoor' -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Easonyesheng/A2PM-MESA/920393b317529b6f3feab2152de5dc61102a6e78/dataloader/__init__.py -------------------------------------------------------------------------------- /dataloader/abstract_dataloader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-12 22:14:07 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-07-18 21:14:31 6 | FilePath: /SA2M/hydra-mesa/dataloader/abstract_dataloader.py 7 | Description: abstract dataloader class 8 | what dose the dataloader do? 9 | - Input 10 | - image pair name 11 | - data root path 12 | * specific dataset gets specific file structure 13 | should be handled by the specific child class 14 | - Output 15 | - image data 16 | - size 17 | - geo info 18 | - K 19 | - pose 20 | - depth data (Optional) 21 | - size 22 | - depth map 23 | - semantic info 24 | - semantic mask (*.png or *.npy) 25 | 26 | 27 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 
28 | ''' 29 | 30 | import abc 31 | import numpy as np 32 | from typing import List, Optional, Any, Tuple 33 | import cv2 34 | 35 | from utils.load import load_cv_img_resize 36 | 37 | class AbstractDataloader(abc.ABC): 38 | """ 39 | Abstract class for dataloader 40 | """ 41 | 42 | def __init__(self, 43 | root_path, 44 | scene_name, 45 | image_name0, 46 | image_name1, 47 | ) -> None: 48 | super().__init__() 49 | self._name = 'AbstractDataloader' 50 | self.root_path = root_path 51 | self.scene_name = scene_name 52 | self.image_name0 = str(image_name0) 53 | self.image_name1 = str(image_name1) 54 | 55 | # self.paths init 56 | self.img0_path = None 57 | self.img1_path = None 58 | 59 | # depth info 60 | self.depth0_path = None 61 | self.depth1_path = None 62 | 63 | # semantic info 64 | self.sem0_path = None 65 | self.sem1_path = None 66 | 67 | # geo info 68 | self.K0_path = None 69 | self.K1_path = None 70 | self.pose0_path = None 71 | self.pose1_path = None 72 | 73 | def name(self) -> str: 74 | """ 75 | Return the name of the dataset 76 | """ 77 | return self._name 78 | 79 | @abc.abstractmethod 80 | def _path_assemble(self): 81 | """Assemble the paths 82 | Returns: 83 | assembled path to self 84 | """ 85 | pass 86 | 87 | def load_images(self, W=None, H=None, PMer=False): 88 | """ load images 89 | """ 90 | assert not PMer, 'Error: PMer is not supported in this dataloader' 91 | if W is None or H is None: 92 | # load as original size 93 | img0 = cv2.imread(self.img0_path, cv2.IMREAD_COLOR) 94 | img1 = cv2.imread(self.img1_path, cv2.IMREAD_COLOR) 95 | scale0 = 1 96 | scale1 = 1 97 | else: 98 | img0, scale0 = load_cv_img_resize(self.img0_path, W, H, 1) 99 | img1, scale1 = load_cv_img_resize(self.img1_path, W, H, 1) 100 | 101 | return img0, img1, scale0, scale1 102 | 103 | 104 | @abc.abstractmethod 105 | def load_Ks(self) -> Tuple[np.ndarray, np.ndarray]: 106 | """Load Ks from self path 107 | Returns: 108 | K0, K1: np.ndarray 3x3 109 | """ 110 | pass 111 | 112 | 
@abc.abstractmethod 113 | def load_depths(self) -> Tuple[np.ndarray, np.ndarray]: 114 | """ 115 | Load depth from self path 116 | """ 117 | pass 118 | 119 | @abc.abstractmethod 120 | def load_semantics(self) -> Tuple[np.ndarray, np.ndarray]: 121 | """ 122 | Load semantic from self path 123 | """ 124 | pass 125 | 126 | @abc.abstractmethod 127 | def load_poses(self) -> Tuple[np.ndarray, np.ndarray]: 128 | """ 129 | Load pose from self path 130 | """ 131 | pass 132 | 133 | @abc.abstractmethod 134 | def get_eval_info(self): 135 | """Return eval info as dict 136 | Returns: 137 | eval_info: dict 138 | - dataset_name: str 139 | - image0, image1: np.ndarray HxWx3 140 | - K0, K1: np.ndarray 3x3 141 | - P0, P1: np.ndarray 4x4 142 | - optional: depth_factor, depth0, depth1 143 | """ 144 | pass 145 | 146 | @abc.abstractmethod 147 | def get_sem_paths(self) -> Tuple[str, str]: 148 | """ 149 | Return semantic paths 150 | """ 151 | pass 152 | 153 | @abc.abstractmethod 154 | def tune_corrs_size_to_eval(self, corrs, match_W, match_H, eval_W, eval_H): 155 | """ 156 | Tune the corrs size to eval size 157 | """ 158 | pass -------------------------------------------------------------------------------- /dataloader/demo_pair_loader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-10-19 21:54:21 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-11-03 15:14:50 6 | FilePath: /SA2M/hydra-mesa/dataloader/demo_pair_loader.py 7 | Description: data loader for demo pair 8 | 9 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 
10 | ''' 11 | import os 12 | import cv2 13 | from loguru import logger 14 | from typing import List, Optional, Any, Tuple 15 | import numpy as np 16 | 17 | from .abstract_dataloader import AbstractDataloader 18 | 19 | 20 | class DemoPairLoader(AbstractDataloader): 21 | """ dataloader for demo pair 22 | """ 23 | 24 | def __init__(self, 25 | root_path, 26 | scene_name, # not used 27 | image_name0, 28 | image_name1, 29 | color_folder, 30 | color_post, 31 | sem_folder, 32 | sem_post, 33 | intrin_folder="", 34 | intrin_post="", 35 | ) -> None: 36 | super().__init__(root_path, scene_name, image_name0, image_name1) 37 | self.color_folder = color_folder 38 | self.color_post = color_post 39 | self.sem_folder = sem_folder 40 | self.sem_post = sem_post 41 | self.intrin_folder = intrin_folder 42 | self.intrin_post = intrin_post 43 | self._name = 'DemoPairLoader' 44 | 45 | self._path_assemble() 46 | 47 | def _path_assemble(self) -> None: 48 | """ assemble path 49 | """ 50 | self.img0_path = os.path.join(self.root_path, self.color_folder, self.image_name0 + f".{self.color_post}") 51 | self.img1_path = os.path.join(self.root_path, self.color_folder, self.image_name1 + f".{self.color_post}") 52 | 53 | self.sem0_path = os.path.join(self.root_path, self.sem_folder, self.scene_name, self.image_name0 + f".{self.sem_post}") 54 | self.sem1_path = os.path.join(self.root_path, self.sem_folder, self.scene_name, self.image_name1 + f".{self.sem_post}") 55 | 56 | self.K0_path = os.path.join(self.root_path, self.intrin_folder, self.scene_name, self.image_name0 + f".{self.intrin_post}") 57 | self.K1_path = os.path.join(self.root_path, self.intrin_folder, self.scene_name, self.image_name1 + f".{self.intrin_post}") 58 | 59 | def load_Ks(self, scale0, scale1): 60 | """ load Ks 61 | """ 62 | if self.intrin_folder == "": 63 | logger.warning(f"no intrinsic parameter provided, should avoid using egam and use gam.") 64 | return None, None 65 | 66 | K0 = open(self.K0_path, 'r').readlines() 67 | K0 = 
K0[0].strip().split(' ') 68 | fx0, fy0, cx0, cy0 = [float(x) for x in K0] 69 | if scale0 is List: 70 | scale0_x, scale0_y = scale0 71 | else: 72 | scale0_x = scale0 73 | scale0_y = scale0 74 | fx0 *= scale0_x 75 | fy0 *= scale0_y 76 | cx0 *= scale0_x 77 | cy0 *= scale0_y 78 | K0 = np.array([[fx0, 0, cx0], [0, fy0, cy0], [0, 0, 1]]) 79 | 80 | 81 | K1 = open(self.K1_path, 'r').readlines() 82 | K1 = K1[0].strip().split(' ') 83 | fx1, fy1, cx1, cy1 = [float(x) for x in K1] 84 | if scale1 is List: 85 | scale1_x, scale1_y = scale1 86 | else: 87 | scale1_x = scale1 88 | scale1_y = scale1 89 | fx1 *= scale1_x 90 | fy1 *= scale1_y 91 | cx1 *= scale1_x 92 | cy1 *= scale1_y 93 | K1 = np.array([[fx1, 0, cx1], [0, fy1, cy1], [0, 0, 1]]) 94 | 95 | return K0, K1 96 | 97 | 98 | 99 | def load_depths(self): 100 | """ load depth data 101 | """ 102 | raise NotImplementedError 103 | 104 | def load_poses(self): 105 | """ load pose 106 | Note: no GT pose provided in demo pair, can be implemented if you have 107 | """ 108 | logger.warning(f"load pose not implemented") 109 | return None, None 110 | 111 | def get_eval_info(self): 112 | """ get eval info 113 | """ 114 | raise NotImplementedError 115 | 116 | def get_sem_paths(self): 117 | """ get semantic paths 118 | """ 119 | return self.sem0_path, self.sem1_path 120 | 121 | def load_semantics(self) -> Tuple[np.ndarray, np.ndarray]: 122 | """ load semantic info 123 | """ 124 | sem0 = np.load(self.sem0_path, allow_pickle=True) 125 | sem1 = np.load(self.sem1_path, allow_pickle=True) 126 | return sem0, sem1 127 | 128 | def tune_corrs_size_to_eval(self): 129 | """ tune corrs size to eval 130 | """ 131 | raise NotImplementedError -------------------------------------------------------------------------------- /dataloader/megadepth.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-29 11:45:30 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-07-01 19:46:51 6 | 
FilePath: /SA2M/hydra-mesa/dataloader/megadepth.py 7 | Description: dataloader for MegaDepth 8 | Most data of MegaDepth is saved in npz file 9 | only npz_file_name is needed 10 | 11 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 12 | ''' 13 | 14 | 15 | import os 16 | import numpy as np 17 | import cv2 18 | from typing import List, Optional, Any, Tuple 19 | 20 | from .abstract_dataloader import AbstractDataloader 21 | from utils.load import ( # beyond the current directory, use absolute import, use sys.path.append('..') to add the parent directory in the main script 22 | load_cv_img_resize, 23 | load_cv_depth, 24 | load_K_txt, 25 | load_pose_txt, 26 | ) 27 | 28 | from utils.img_process import load_img_padding_rt_size 29 | from utils.geo import tune_corrs_size_diff 30 | 31 | class MegaDepthDataloader(AbstractDataloader): 32 | """ TODO: 33 | """ 34 | 35 | def __init__(self, 36 | root_path, 37 | scene_name, # not work 38 | image_name0, # specific structure 00xx_x_x_id 39 | image_name1, 40 | sem_mode, 41 | sem_folder, 42 | sem_post, 43 | ): 44 | """ 45 | """ 46 | super().__init__(root_path, scene_name, image_name0, image_name1) 47 | 48 | self._name = 'MegaDepthDataloader' 49 | 50 | self.sem_mode = sem_mode 51 | assert self.sem_mode in ['SAM'] 52 | self.sem_folder = sem_folder 53 | self.sem_post = sem_post 54 | 55 | self._path_assemble() 56 | 57 | def _path_assemble(self): 58 | """ 59 | """ 60 | npz_folder = os.path.join(self.root_path, "scene_info_val_1500") 61 | pair0 = self.image_name0 # npz_file_name_id, e.g. 
0000_0.1_0.2_id 62 | pair1 = self.image_name1 # npz_file_name has "_id" in the end 63 | # get the npz file name 64 | pair0_npz_name = "_".join(pair0.split("_")[:-1]) 65 | pair1_npz_name = "_".join(pair1.split("_")[:-1]) 66 | assert pair0_npz_name == pair1_npz_name, "pair0 and pair1 should be in the same npz file" 67 | id0 = (pair0.split("_")[-1]) 68 | id1 = (pair1.split("_")[-1]) 69 | 70 | npz_path = os.path.join(npz_folder, pair0_npz_name+".npz") 71 | 72 | npz_data = np.load(npz_path, allow_pickle=True) 73 | 74 | # get the image path 75 | img_path0 = npz_data["image_paths"][int(id0)] 76 | img_path1 = npz_data["image_paths"][int(id1)] 77 | img_folder0 = img_path0.split("/")[1] 78 | img_folder1 = img_path1.split("/")[1] 79 | img_name0 = img_path0.split("/")[-1].split(".")[0] 80 | img_name1 = img_path1.split("/")[-1].split(".")[0] 81 | 82 | if self.sem_mode == "SAM": 83 | sem_path = self.sem_folder 84 | sem_post = self.sem_post 85 | sem_path0 = os.path.join(sem_path, "MegaDepth1500", img_folder0, f'{img_name0}.{sem_post}') 86 | sem_path1 = os.path.join(sem_path, "MegaDepth1500", img_folder1, f'{img_name1}.{sem_post}') 87 | self.sem0_path = sem_path0 88 | self.sem1_path = sem_path1 89 | else: 90 | raise NotImplementedError(f"semantic mode {semantic_mode} not implemented") 91 | 92 | self.img0_path = os.path.join(self.root_path, img_path0) 93 | self.img1_path = os.path.join(self.root_path, img_path1) 94 | 95 | img0 = cv2.imread(self.img0_path, cv2.IMREAD_COLOR) 96 | img1 = cv2.imread(self.img1_path, cv2.IMREAD_COLOR) 97 | self.eval_W0, self.eval_H0 = img0.shape[1], img0.shape[0] 98 | self.eval_W1, self.eval_H1 = img1.shape[1], img1.shape[0] 99 | 100 | 101 | self.K0 = npz_data["intrinsics"][int(id0)].astype(np.float32) 102 | self.K1 = npz_data["intrinsics"][int(id1)].astype(np.float32) 103 | 104 | self.pose0 = npz_data["poses"][int(id0)] 105 | self.pose1 = npz_data["poses"][int(id1)] 106 | self.pose0 = np.matrix(self.pose0).astype(np.float32) 107 | self.pose1 = 
np.matrix(self.pose1).astype(np.float32) 108 | 109 | # override 110 | def load_images(self, W=None, H=None, PMer=False): 111 | """ 112 | """ 113 | if not PMer: 114 | return super().load_images(W, H) 115 | else: 116 | # specific for PMer: need padding 117 | match_color0, mask0, size0_ = load_img_padding_rt_size(self.img0_path, [W, H]) 118 | crop_W0, crop_H0 = size0_ # NOTE: only used for PMer 119 | match_color1, mask1, size1_ = load_img_padding_rt_size(self.img1_path, [W, H]) 120 | crop_W1, crop_H1 = size1_ 121 | 122 | return match_color0, mask0, crop_W0, crop_H0, match_color1, mask1, crop_W1, crop_H1 123 | 124 | def load_Ks(self, scale0=None, scale1=None): 125 | """ 126 | """ 127 | return self.K0, self.K1 128 | 129 | def load_poses(self): 130 | """ 131 | """ 132 | return self.pose0, self.pose1 133 | 134 | def load_depths(self): 135 | """ 136 | """ 137 | raise NotImplementedError("MegaDepth does not provide depth info") 138 | 139 | def load_semantics(self): 140 | """ 141 | """ 142 | if self.sem_mode == "SAM": 143 | sem0 = np.load(self.sem0_path, allow_pickle=True) 144 | sem1 = np.load(self.sem1_path, allow_pickle=True) 145 | else: 146 | raise NotImplementedError(f"sem_mode {self.sem_mode} not implemented") 147 | 148 | return sem0, sem1 149 | 150 | def get_sem_paths(self): 151 | """ 152 | """ 153 | return self.sem0_path, self.sem1_path 154 | 155 | def get_eval_info(self, eval_W, eval_H): 156 | """ 157 | """ 158 | image0 = cv2.imread(self.img0_path, cv2.IMREAD_COLOR) 159 | image1 = cv2.imread(self.img1_path, cv2.IMREAD_COLOR) 160 | 161 | eval_W0, eval_H0 = image0.shape[1], image0.shape[0] 162 | eval_W1, eval_H1 = image1.shape[1], image1.shape[0] 163 | 164 | K0, K1 = self.load_Ks() 165 | P0, P1 = self.load_poses() 166 | sem0, sem1 = self.load_semantics() 167 | 168 | eval_info = { 169 | "dataset_name": "MegaDepth", 170 | "image0": image0, 171 | "image1": image1, 172 | 'eval_W0': eval_W0, 173 | 'eval_H0': eval_H0, 174 | 'eval_W1': eval_W1, 175 | 'eval_H1': eval_H1, 176 | 
"K0": K0, 177 | "K1": K1, 178 | "P0": P0, 179 | "P1": P1, 180 | "sem0": sem0, 181 | "sem1": sem1, 182 | } 183 | 184 | return eval_info 185 | 186 | def tune_corrs_size_to_eval(self, corrs, match_W0, match_H0, match_W1, match_H1): 187 | """ match at the same size, eval at the original size 188 | """ 189 | eval_corrs = tune_corrs_size_diff(corrs, match_W0, match_W1, match_H0, match_H1, self.eval_W0, self.eval_W1, self.eval_H0, self.eval_H1) 190 | return eval_corrs 191 | 192 | 193 | -------------------------------------------------------------------------------- /dataloader/scannet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-12 22:14:22 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-07-08 22:34:49 6 | FilePath: /SA2M/hydra-mesa/dataloader/scannet.py 7 | Description: dataloader for scannet dataset, including training and ScanNet1500 8 | The file structure of the ScanNet dataset is as follows: 9 | root_path 10 | ├── scene_name 11 | │ ├── color 12 | │ │ ├── image_name0.jpg 13 | │ │ └── image_name1.jpg 14 | │ │ └── ... 15 | │ ├── depth 16 | │ │ ├── image_name0.png 17 | │ │ └── image_name1.png 18 | │ │ └── ... 19 | │ ├── pose 20 | │ │ ├── image_name0.txt 21 | │ │ └── image_name1.txt 22 | │ │ └── ... 23 | │ ├── intrinsic 24 | │ │ ├── intrinsic_color.txt 25 | │ │ └── intrinsic_depth.txt 26 | 27 | sem_folder 28 | ├── scene_name 29 | │ ├── image_name0.$post 30 | │ └── ... 31 | 32 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 
33 | ''' 34 | 35 | import os 36 | import numpy as np 37 | import cv2 38 | from typing import List, Optional, Any, Tuple 39 | 40 | from .abstract_dataloader import AbstractDataloader 41 | from utils.load import ( # beyond the current directory, use absolute import, use sys.path.append('..') to add the parent directory in the main script 42 | load_cv_img_resize, 43 | load_cv_depth, 44 | load_K_txt, 45 | load_pose_txt, 46 | ) 47 | 48 | from utils.geo import tune_corrs_size 49 | 50 | class ScanNetDataloader(AbstractDataloader): 51 | """ dataloader for ScanNet dataset 52 | """ 53 | 54 | def __init__(self, 55 | root_path, 56 | scene_name, 57 | image_name0, 58 | image_name1, 59 | color_folder, 60 | color_post, 61 | depth_folder, 62 | depth_post, 63 | depth_factor, 64 | K_folder, 65 | pose_folder, 66 | pose_post, 67 | sem_folder, 68 | sem_mode, 69 | sem_post, 70 | ) -> None: 71 | super().__init__(root_path, scene_name, image_name0, image_name1) 72 | 73 | self._name = 'ScanNetDataLoader' 74 | 75 | self.color_folder = color_folder 76 | self.color_post = color_post 77 | self.depth_folder = depth_folder 78 | self.depth_post = depth_post 79 | self.depth_factor = depth_factor 80 | self.K_folder = K_folder 81 | self.pose_folder = pose_folder 82 | self.pose_post = pose_post 83 | self.sem_folder = sem_folder 84 | self.sem_mode = sem_mode 85 | self.sem_post = sem_post 86 | 87 | self._path_assemble() 88 | 89 | def reset_imgs(self, scene_name, image_name0, image_name1): 90 | """ 91 | """ 92 | self.scene_name = scene_name 93 | self.image_name0 = image_name0 94 | self.image_name1 = image_name1 95 | self._path_assemble() 96 | 97 | def _path_assemble(self): 98 | """ assemble the paths 99 | """ 100 | # color image 101 | self.img0_path = os.path.join(self.root_path, self.scene_name, self.color_folder, f"{self.image_name0}.{self.color_post}") 102 | self.img1_path = os.path.join(self.root_path, self.scene_name, self.color_folder, f"{self.image_name1}.{self.color_post}") 103 | 104 | # depth 
image 105 | self.depth0_path = os.path.join(self.root_path, self.scene_name, self.depth_folder, f"{self.image_name0}.{self.depth_post}") 106 | self.depth1_path = os.path.join(self.root_path, self.scene_name, self.depth_folder, f"{self.image_name1}.{self.depth_post}") 107 | 108 | # intrinsic 109 | self.K0_path = os.path.join(self.root_path, self.scene_name, self.K_folder, f"intrinsic_color.txt") 110 | self.K1_path = os.path.join(self.root_path, self.scene_name, self.K_folder, f"intrinsic_color.txt") 111 | 112 | # pose 113 | self.pose0_path = os.path.join(self.root_path, self.scene_name, self.pose_folder, f"{self.image_name0}.{self.pose_post}") 114 | self.pose1_path = os.path.join(self.root_path, self.scene_name, self.pose_folder, f"{self.image_name1}.{self.pose_post}") 115 | 116 | # semantic 117 | assert self.sem_mode in ["GT", "SEEM", "SAM"], f"sem_mode {self.sem_mode} not implemented" 118 | if self.sem_mode == "GT": 119 | assert self.sem_folder == "label-filt", f"sem_folder {self.sem_folder} error" 120 | self.sem0_path = os.path.join(self.root_path, self.scene_name, self.sem_folder, f"{self.image_name0}.{self.sem_post}") 121 | self.sem1_path = os.path.join(self.root_path, self.scene_name, self.sem_folder, f"{self.image_name1}.{self.sem_post}") 122 | elif self.sem_mode == "SEEM" or self.sem_mode == "SAM": 123 | self.sem0_path = os.path.join(self.sem_folder, self.scene_name, f"{self.image_name0}.{self.sem_post}") 124 | self.sem1_path = os.path.join(self.sem_folder, self.scene_name, f"{self.image_name1}.{self.sem_post}") 125 | 126 | def load_Ks(self, scale0, scale1): 127 | """ load Ks 128 | Returns: 129 | K0, K1: np.mat, 3x3 130 | """ 131 | K0 = load_K_txt(self.K0_path, scale0) 132 | K1 = load_K_txt(self.K1_path, scale1) 133 | return K0, K1 134 | 135 | def load_poses(self): 136 | """ load poses 137 | Returns: 138 | P0, P1: np.mat, 4x4 139 | """ 140 | P0 = load_pose_txt(self.pose0_path) 141 | P1 = load_pose_txt(self.pose1_path) 142 | return P0, P1 143 | 144 | def 
get_depth_factor(self): 145 | return self.depth_factor 146 | 147 | def load_depths(self): 148 | """ 149 | """ 150 | depth0 = load_cv_depth(self.depth0_path) 151 | depth1 = load_cv_depth(self.depth1_path) 152 | return depth0, depth1 153 | 154 | def load_semantics(self, W=None, H=None): 155 | """ 156 | """ 157 | if self.sem_mode == "GT" or self.sem_mode == "SEEM": 158 | sem0, _ = load_cv_img_resize(self.sem0_path, W, H, -1) 159 | sem1, _ = load_cv_img_resize(self.sem1_path, W, H, -1) 160 | elif self.sem_mode == "SAM": 161 | sem0 = np.load(self.sem0_path, allow_pickle=True) 162 | sem1 = np.load(self.sem1_path, allow_pickle=True) 163 | else: 164 | raise NotImplementedError(f"sem_mode {self.sem_mode} not implemented") 165 | 166 | return sem0, sem1 167 | 168 | def get_sem_paths(self): 169 | """ 170 | """ 171 | return self.sem0_path, self.sem1_path 172 | 173 | def get_eval_info(self, eval_W, eval_H): 174 | """ for evaluation 175 | """ 176 | eval_info = {} 177 | image0, image1, scale0, scale1 = self.load_images(eval_W, eval_H) 178 | K0, K1 = self.load_Ks(scale0, scale1) 179 | P0, P1 = self.load_poses() 180 | depth0, depth1 = self.load_depths() 181 | sem0, sem1 = self.load_semantics(eval_W, eval_H) 182 | 183 | eval_info["dataset_name"] = 'ScanNet' 184 | eval_info["image0"] = image0 185 | eval_info["image1"] = image1 186 | eval_info["K0"] = K0 187 | eval_info["K1"] = K1 188 | eval_info["P0"] = P0 189 | eval_info["P1"] = P1 190 | eval_info["depth0"] = depth0 191 | eval_info["depth1"] = depth1 192 | eval_info["depth_factor"] = self.depth_factor 193 | eval_info["sem0"] = sem0 194 | eval_info["sem1"] = sem1 195 | 196 | return eval_info 197 | 198 | # specific for PMer 199 | def tune_corrs_size_to_eval(self, corrs, match_W, match_H, eval_W, eval_H): 200 | """ 201 | """ 202 | eval_corrs = tune_corrs_size(corrs, match_W, match_H, eval_W, eval_H) 203 | return eval_corrs 204 | -------------------------------------------------------------------------------- /demo/color/4119.965344.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Easonyesheng/A2PM-MESA/920393b317529b6f3feab2152de5dc61102a6e78/demo/color/4119.965344.png -------------------------------------------------------------------------------- /demo/color/4120.813199.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Easonyesheng/A2PM-MESA/920393b317529b6f3feab2152de5dc61102a6e78/demo/color/4120.813199.png -------------------------------------------------------------------------------- /demo/intrins/4119.965344.txt: -------------------------------------------------------------------------------- 1 | 726.28741455078 726.28741455078 354.6496887207 186.46566772461 2 | -------------------------------------------------------------------------------- /demo/intrins/4120.813199.txt: -------------------------------------------------------------------------------- 1 | 726.28741455078 726.28741455078 354.6496887207 186.46566772461 2 | -------------------------------------------------------------------------------- /demo/samres/4119.965344.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Easonyesheng/A2PM-MESA/920393b317529b6f3feab2152de5dc61102a6e78/demo/samres/4119.965344.npy -------------------------------------------------------------------------------- /demo/samres/4120.813199.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Easonyesheng/A2PM-MESA/920393b317529b6f3feab2152de5dc61102a6e78/demo/samres/4120.813199.npy -------------------------------------------------------------------------------- /geo_area_matchers/MatchSampler.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | sys.path.append('../') 4 | 5 | import os 6 | import numpy as np 7 | import cv2 8 | import 
random 9 | from loguru import logger 10 | 11 | from utils.geo import ( 12 | cal_corr_F_and_mean_sd_rt_sd, 13 | list_of_corrs2corr_list, 14 | ) 15 | 16 | from utils.vis import ( 17 | plot_matches_lists_lr, 18 | ) 19 | 20 | class BasicMatchSampler(object): 21 | """Basic Sampler for the MatchSampler 22 | Basic Flow: 23 | 24 | """ 25 | dft_configs = { 26 | "W0": 640, 27 | "H0": 480, 28 | "W1": 640, 29 | "H1": 480, 30 | "out_path": "", 31 | "sample_num": 1000, 32 | "draw_verbose": 0, 33 | } 34 | 35 | def __init__(self, configs) -> None: 36 | """ 37 | """ 38 | self.configs = {**self.dft_configs, **configs} 39 | self.W0 = self.configs["W0"] 40 | self.H0 = self.configs["H0"] 41 | self.W1 = self.configs["W1"] 42 | self.H1 = self.configs["H1"] 43 | self.sample_num = self.configs["sample_num"] 44 | 45 | self.total_corrs_list = None 46 | self.inside_area_corrs = None 47 | self.global_corrs = None 48 | self.sampled_corrs = None 49 | self.sampled_corrs_rand = None 50 | self.name = "" 51 | 52 | self.img0 = None 53 | self.img1 = None 54 | 55 | self.out_path = self.configs["out_path"] 56 | self.draw_verbose = self.configs["draw_verbose"] 57 | 58 | def load_corrs_from_GAM(self, corrs_list): 59 | """ Load the correspondences from the GAM 60 | Args: 61 | corrs_list (list of list): list of correspondences 62 | the last one is the global correspondences 63 | """ 64 | self.total_corrs_list = corrs_list 65 | self.inside_area_corrs = corrs_list[:-1] 66 | self.global_corrs = corrs_list[-1] 67 | 68 | def load_ori_imgs(self, img0, img1): 69 | """ Load the original images in eval from size 70 | Args: 71 | img0 (np.ndarray): original image 0 72 | img1 (np.ndarray): original image 1 73 | """ 74 | 75 | # if gray image, convert to color image 76 | if len(img0.shape) == 2: 77 | self.img0 = cv2.cvtColor(img0, cv2.COLOR_GRAY2BGR) 78 | else: 79 | self.img0 = img0 80 | 81 | if len(img1.shape) == 2: 82 | self.img1 = cv2.cvtColor(img1, cv2.COLOR_GRAY2BGR) 83 | else: 84 | self.img1 = img1 85 | 86 | def 
draw_before_sample(self): 87 | """ 88 | """ 89 | assert self.img0 is not None, "Please load the original images first" 90 | assert self.img1 is not None, "Please load the original images first" 91 | assert self.total_corrs_list is not None, "Please load the correspondences first" 92 | 93 | if self.draw_verbose: 94 | temp_all_corrs = list_of_corrs2corr_list(self.total_corrs_list) 95 | plot_matches_lists_lr(self.img0, self.img1, temp_all_corrs, self.out_path, name=f"{self.name}_before_sample_all") 96 | 97 | # random sample 98 | if len(temp_all_corrs) <= self.sample_num: 99 | temp_all_corrs_meet_num = temp_all_corrs 100 | else: 101 | temp_all_corrs_meet_num = random.sample(temp_all_corrs, self.sample_num) 102 | plot_matches_lists_lr(self.img0, self.img1, temp_all_corrs_meet_num, self.out_path, name=f"{self.name}_before_sample_{self.sample_num}") 103 | 104 | def draw_after_sample(self, sampled_corrs): 105 | assert self.img0 is not None, "Please load the original images first" 106 | assert self.img1 is not None, "Please load the original images first" 107 | assert self.sampled_corrs is not None, "Please sample the correspondences first" 108 | 109 | if self.draw_verbose: 110 | plot_matches_lists_lr(self.img0, self.img1, self.sampled_corrs, self.out_path, name=f"{self.name}_after_sample_all") 111 | plot_matches_lists_lr(self.img0, self.img1, self.sampled_corrs_rand, self.out_path, name=f"{self.name}_after_sample_{self.sample_num}_rand") 112 | 113 | def sample(self): 114 | """ Sample the correspondences 115 | Args: 116 | num_samples (int): number of samples 117 | """ 118 | raise NotImplementedError 119 | 120 | 121 | 122 | class GridFillSampler(BasicMatchSampler): 123 | """ 124 | """ 125 | 126 | def __init__(self, configs) -> None: 127 | super().__init__(configs) 128 | 129 | # specific params 130 | self.occ_size = self.configs["occ_size"] 131 | self.common_occ_flag = self.configs["common_occ_flag"] 132 | 133 | self.occ_img0 = np.zeros((self.H0, self.W0), dtype=np.uint8) 134 
| self.occ_img1 = np.zeros((self.H1, self.W1), dtype=np.uint8) 135 | 136 | def sample(self): 137 | """ 138 | Returns: 139 | sampled_corrs (list of corr): sampled correspondences 140 | """ 141 | assert self.inside_area_corrs is not None, "Please load the correspondences first" 142 | assert self.global_corrs is not None, "Please load the correspondences first" 143 | 144 | num_samples = self.sample_num 145 | 146 | self.draw_before_sample() 147 | 148 | # calc F and Sampson Distance for every correspondence inside the area 149 | temp_inside_corrs = list_of_corrs2corr_list(self.inside_area_corrs) 150 | if len(temp_inside_corrs) <= 10: 151 | logger.error(f"Too few correspondences inside the area, only {len(temp_inside_corrs)}") 152 | return None, None 153 | 154 | F, mean_sd, rt_sd_list = cal_corr_F_and_mean_sd_rt_sd(temp_inside_corrs) 155 | 156 | # sort the correspondences by Sampson Distance (smaller is at the front) 157 | rt_sd_list = np.array(rt_sd_list) 158 | sorted_idx = np.argsort(rt_sd_list) 159 | sorted_inside_area_corrs = [temp_inside_corrs[idx] for idx in sorted_idx] 160 | 161 | # fill the occ img with inside area correspondences 162 | sampled_corrs_inside = self.fill_occ_img(sorted_inside_area_corrs) 163 | logger.info(f"Sampled {len(sampled_corrs_inside)} correspondences inside the area") 164 | 165 | # fill the occ img with global correspondences 166 | sampled_corrs_global = self.fill_occ_img(self.global_corrs) 167 | logger.info(f"Sampled {len(sampled_corrs_global)} correspondences outside the area") 168 | 169 | # fuse 170 | sampled_corrs = sampled_corrs_inside + sampled_corrs_global 171 | logger.info(f"Sampled {len(sampled_corrs)} correspondences in total") 172 | 173 | self.sampled_corrs = sampled_corrs 174 | 175 | # random sample self.sample_num correspondences 176 | if len(sampled_corrs) <= num_samples: 177 | sampled_corrs_rand = sampled_corrs 178 | else: 179 | sampled_corrs_rand = random.sample(sampled_corrs, num_samples) 180 | 181 | self.sampled_corrs_rand 
= sampled_corrs_rand 182 | 183 | self.draw_after_sample(sampled_corrs) 184 | 185 | return sampled_corrs, sampled_corrs_rand 186 | 187 | def fill_occ_img(self, sorted_corrs): 188 | """ 189 | Args: 190 | sorted_corrs (list of corr): sorted correspondences 191 | Returns: 192 | sampled_corrs (list of corr): sampled correspondences 193 | """ 194 | sorted_corrs_np = np.array(sorted_corrs) 195 | sampled_corrs = [] 196 | 197 | for corr in sorted_corrs_np: 198 | corr_int = corr.astype(np.int64) 199 | u0, v0 = corr_int[:2] 200 | u1, v1 = corr_int[2:] 201 | 202 | if u0 < self.occ_size or u0 >= self.W0-self.occ_size or v0 < self.occ_size or v0 >= self.H0-self.occ_size: 203 | continue 204 | 205 | if u1 < self.occ_size or u1 >= self.W1-self.occ_size or v1 < self.occ_size or v1 >= self.H1-self.occ_size: 206 | continue 207 | 208 | if self.common_occ_flag: 209 | try: 210 | if self.occ_img0[v0, u0] == 0 and self.occ_img1[v1, u1] == 0: 211 | self.occ_img0[v0-self.occ_size:v0+self.occ_size, u0-self.occ_size:u0+self.occ_size] = 1 212 | self.occ_img1[v1-self.occ_size:v1+self.occ_size, u1-self.occ_size:u1+self.occ_size] = 1 213 | sampled_corrs.append(corr) 214 | else: 215 | continue 216 | except IndexError as e: 217 | logger.error(f"IndexError: {corr}, W0 {self.W0}, H0 {self.H0}, W1 {self.W1}, H1 {self.H1}") 218 | else: 219 | if self.occ_img0[v0, u0] == 0 or self.occ_img1[v1, u1] == 0: 220 | self.occ_img0[v0-self.occ_size:v0+self.occ_size, u0-self.occ_size:u0+self.occ_size] = 1 221 | self.occ_img1[v1-self.occ_size:v1+self.occ_size, u1-self.occ_size:u1+self.occ_size] = 1 222 | sampled_corrs.append(corr) 223 | else: 224 | continue 225 | 226 | return sampled_corrs 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | -------------------------------------------------------------------------------- /geo_area_matchers/abstract_gam.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-19 22:34:22 4 | LastEditors: 
EasonZhang 5 | LastEditTime: 2024-06-28 11:30:07 6 | FilePath: /SA2M/hydra-mesa/geo_area_matchers/abstract_gam.py 7 | Description: abstract geo area matcher for post-processing 8 | 9 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 10 | ''' 11 | 12 | import numpy as np 13 | import abc 14 | import os 15 | 16 | from utils.common import test_dir_if_not_create 17 | 18 | class AbstractGeoAreaMatcher(abc.ABC): 19 | def __init__(self) -> None: 20 | self.initialized = False 21 | pass 22 | 23 | @abc.abstractmethod 24 | def name(self) -> str: 25 | pass 26 | 27 | @abc.abstractmethod 28 | def init_dataloader(self, dataloader): 29 | raise NotImplementedError 30 | 31 | @abc.abstractmethod 32 | def load_point_matcher(self, point_matcher): 33 | raise NotImplementedError 34 | 35 | @abc.abstractmethod 36 | def load_ori_corrs(self, ori_corrs): 37 | raise NotImplementedError 38 | 39 | @abc.abstractmethod 40 | def init_gam(self): 41 | raise NotImplementedError 42 | 43 | @abc.abstractmethod 44 | def geo_area_matching_refine(self, matched_areas0, matched_areas1): 45 | """ Main Func 46 | Returns: 47 | alpha_corrs_dict: dict, inside-area corrs under each alpha 48 | alpha_inlier_idxs_dict: dict, inlier idxs of input areas under each alpha 49 | """ 50 | raise NotImplementedError 51 | 52 | @abc.abstractmethod 53 | def doubtful_area_match_predict(self, doubt_match_pairs): 54 | """ 55 | """ 56 | raise NotImplementedError 57 | 58 | 59 | def set_outpath(self, outpath: str): 60 | """Run after init_dataloader 61 | """ 62 | self.out_path = os.path.join(outpath, f"{self.scene_name}_{self.name0}_{self.name1}", 'gam') 63 | if self.draw_verbose == 1: 64 | test_dir_if_not_create(self.out_path) -------------------------------------------------------------------------------- /metric/eval_ratios.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-28 19:30:20 4 | LastEditors: Easonyesheng preacher@sjtu.edu.cn 5 | 
LastEditTime: 2024-07-31 20:29:17 6 | FilePath: /SA2M/hydra-mesa/metric/eval_ratios.py 7 | Description: scripts for evaluation of ratios 8 | 9 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 10 | ''' 11 | 12 | import sys 13 | sys.path.append('..') 14 | 15 | import os 16 | import numpy as np 17 | from loguru import logger 18 | 19 | from metric.Evaluation import AMEval, PoseAUCEval, TimeEval, MMAEval 20 | 21 | root_path = '/opt/data/private/A2PM-git/A2PM-MESA/res' 22 | 23 | # specific 24 | root_folder = 'dmesa-dkm-md-eval-res' 25 | root_folder = 'dmesa-dkm-scannet-res' 26 | root_folder = 'mesa-f-dkm-md-eval-res' 27 | # root_folder = 'mesa-f-dkm-sn-eval-res' 28 | 29 | 30 | 31 | 32 | baseline_name = 'pm' 33 | challenger_name = 'a2pm' 34 | folder_name = 'ratios' 35 | phis = ['0.5', '2.0', '3.0', '3.5', '5.0'] 36 | 37 | 38 | 39 | log_folder = f"{root_path}/{root_folder}" 40 | logger.add(f"{log_folder}/AMEval/res.log", rotation="500 MB", level="INFO", retention="10 days") 41 | 42 | # Pose Error Eval 43 | # specific 44 | output_path = os.path.join(root_path, root_folder, folder_name) 45 | 46 | pose_eval_cfg = { 47 | 'root_path': os.path.join(root_path, root_folder), 48 | 'folder_name': folder_name, 49 | 'baseline_name': baseline_name, 50 | 'challenger_name': challenger_name, 51 | 'phi_list': phis, 52 | 'output_path': output_path, 53 | } 54 | 55 | pose_eval = PoseAUCEval(pose_eval_cfg) 56 | pose_eval.run() 57 | 58 | 59 | # AMEval 60 | try: 61 | am_name = 'am' 62 | gam_name = 'gam' 63 | AMP_Thd = [0.6, 0.7, 0.8] 64 | res_folder = 'ratios' 65 | 66 | am_eval_cfg = { 67 | 'root_path': os.path.join(root_path, root_folder), 68 | 'name': am_name, 69 | 'AMP_Thd': AMP_Thd, 70 | } 71 | 72 | am_eval = AMEval(am_eval_cfg) 73 | am_eval.run_AMEval() 74 | 75 | for phi in phis: 76 | am_eval_cfg['name'] = f'{am_name}+{gam_name}-{phi}' 77 | am_eval = AMEval(am_eval_cfg) 78 | am_eval.run_AMEval() 79 | except Exception as e: 80 | logger.error(f"AMEval failed: {e}") 81 | 
-------------------------------------------------------------------------------- /metric/instance_eval.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-19 21:05:36 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-07-18 17:06:50 6 | FilePath: /SA2M/hydra-mesa/metric/instance_eval.py 7 | Description: the evaluator for instance-level metrics, including 8 | - area matching metrics TODO: 9 | - area overlap ratio (AOR) 10 | - point matching metrics TODO: 11 | - MMA w/ depth 12 | - MMA w/o depth 13 | - pose estimation metrics 14 | - pose error 15 | 16 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 17 | ''' 18 | 19 | import os 20 | import numpy as np 21 | from loguru import logger 22 | import random 23 | from typing import List, Optional, Any, Tuple 24 | 25 | from utils.geo import ( 26 | tune_corrs_size_diff, 27 | nms_for_corrs, 28 | compute_pose_error_simp, 29 | calc_area_match_performence_eff_MC, 30 | assert_match_reproj, 31 | assert_match_qFp, 32 | ) 33 | 34 | from utils.vis import plot_matches_with_mask_ud 35 | 36 | from utils.common import test_dir_if_not_create 37 | 38 | class InstanceEval(object): 39 | """ params are loaded when using 40 | """ 41 | 42 | def __init__(self, 43 | sample_mode, 44 | eval_corr_num, 45 | sac_mode, 46 | out_path, 47 | draw_verbose=False, 48 | ) -> None: 49 | self.eval_info = None 50 | self.sample_mode = sample_mode 51 | self.eval_corr_num = eval_corr_num 52 | self.sac_mode = sac_mode 53 | self.out_path = os.path.join(out_path, 'ratios') 54 | self.draw_verbose = draw_verbose 55 | test_dir_if_not_create(self.out_path) 56 | 57 | def init_data_loader(self, dataloader, eval_W, eval_H): 58 | """ 59 | """ 60 | try: 61 | self.eval_info = dataloader.get_eval_info(eval_W, eval_H) 62 | self.instance_name = f'{dataloader.scene_name}_{dataloader.image_name0}_{dataloader.image_name1}' 63 | except Exception as e: 64 | logger.error(f'Error: {e}') 65 | 66 | def 
eval_area_overlap_ratio(self, 67 | areas0, 68 | areas1, 69 | pre_name, 70 | ): 71 | """ 72 | Args: 73 | areas0/1: list of [u_min, u_max, v_min, v_max] 74 | """ 75 | assert self.eval_info is not None, 'Error: eval_info is not loaded, init dataloader pls' 76 | assert self.eval_info['dataset_name'] in ['ScanNet'], 'Error: dataset not supported' 77 | 78 | if len(areas0) < 1 or len(areas1) < 1: 79 | logger.error(f'Error: areas0/1 is empty') 80 | return 81 | 82 | try: 83 | acr, aor = calc_area_match_performence_eff_MC( 84 | areas0, 85 | areas1, 86 | self.eval_info['image0'], 87 | self.eval_info['image1'], 88 | self.eval_info['K0'], 89 | self.eval_info['K1'], 90 | self.eval_info['P0'], 91 | self.eval_info['P1'], 92 | self.eval_info['depth0'], 93 | self.eval_info['depth1'], 94 | self.eval_info['depth_factor'], 95 | ) 96 | mean_aor = np.mean(aor) 97 | except Exception as e: 98 | logger.error(f'Error: {e}') 99 | raise e 100 | 101 | # write to file 102 | name_file = os.path.join(self.out_path, f'{pre_name}_ameval_names.txt') 103 | aor_file = os.path.join(self.out_path, f'{pre_name}_aor.txt') 104 | acr_file = os.path.join(self.out_path, f'{pre_name}_acr.txt') 105 | 106 | # read names 107 | if not os.path.exists(name_file): 108 | exist_names = [] 109 | else: 110 | with open(name_file, 'r') as f: 111 | exist_names = f.readlines() 112 | exist_names = [name.strip() for name in exist_names] 113 | 114 | if self.instance_name in exist_names: 115 | pass 116 | else: 117 | with open(name_file, 'a') as f: 118 | f.write(f'{self.instance_name}\n') 119 | with open(aor_file, 'a') as f: 120 | f.write(f'{mean_aor}\n') 121 | with open(acr_file, 'a') as f: 122 | f.write(f'{acr}\n') 123 | 124 | def eval_point_match(self, 125 | corrs, 126 | pre_name, 127 | thds=[1,3,5], 128 | ): 129 | """ 130 | Args: 131 | corrs: all corrs should be in the eval size 132 | """ 133 | assert self.eval_info is not None, 'Error: eval_info is not loaded, init dataloader pls' 134 | assert self.sample_mode in ['random', 
'grid'], f'Error: sample_mode {sample_mode} not supported' 135 | 136 | eval_num = self.eval_corr_num 137 | sample_mode = self.sample_mode 138 | 139 | corr_num = len(corrs) 140 | if corr_num < 10: 141 | logger.error(f'Error: not enough corrs for pose estimation') 142 | return [] 143 | 144 | if corr_num > eval_num: 145 | if sample_mode == 'random': 146 | corrs = random.sample(corrs, eval_num) 147 | elif sample_mode == 'grid': 148 | corrs = np.array(corrs) 149 | corrs = nms_for_corrs(corrs, r=3) 150 | if len(corrs) > eval_num: 151 | corrs = random.sample(corrs, eval_num) 152 | 153 | good_ratios = [] 154 | masks = [] 155 | 156 | try: 157 | dataset_name = self.eval_info['dataset_name'] 158 | pose0 = self.eval_info['P0'] 159 | pose1 = self.eval_info['P1'] 160 | K0 = self.eval_info['K0'] 161 | K1 = self.eval_info['K1'] 162 | image0 = self.eval_info['image0'] # imgs are in eval size 163 | image1 = self.eval_info['image1'] 164 | if dataset_name == "ScanNet": 165 | depth0 = self.eval_info['depth0'] 166 | depth1 = self.eval_info['depth1'] 167 | depth_factor = self.eval_info['depth_factor'] 168 | for thd in thds: 169 | mask, bad_ratio, gt_pts = assert_match_reproj(corrs, 170 | depth0, depth1, depth_factor, 171 | K0, K1, 172 | pose0, pose1, 173 | thd, 0) 174 | good_ratio = (100 - bad_ratio)/100 175 | good_ratios.append(good_ratio) 176 | masks.append(mask) 177 | 178 | elif dataset_name in ['MegaDepth', 'YFCC', 'ETH3D']: 179 | for thd in thds: 180 | mask, bad_ratio = assert_match_qFp(corrs, K0, K1, pose0, pose1, thd) 181 | good_ratio = (1 - bad_ratio) 182 | good_ratios.append(good_ratio) 183 | masks.append(mask) 184 | else: 185 | raise NotImplementedError(f"dataset {dataset_name} not supported") 186 | 187 | # draw match images with masks for corrs 188 | if self.draw_verbose: 189 | # get the upper path of the out_path 190 | up_out_path = os.path.dirname(self.out_path) 191 | pm_out_path = os.path.join(up_out_path, self.instance_name, 'pm') 192 | test_dir_if_not_create(pm_out_path) 
193 | for i, thd in enumerate(thds): 194 | plot_matches_with_mask_ud( 195 | image0, image1, 196 | masks[i], corrs, pm_out_path, pre_name+f"_mma_{thd}") 197 | 198 | except Exception as e: 199 | logger.error(f'Error: {e}') 200 | raise e 201 | 202 | 203 | logger.success(f'point matches good ratios: {good_ratios}') 204 | 205 | # write to file 206 | name_file = os.path.join(self.out_path, f'{pre_name}_mma_names.txt') 207 | mma_file = os.path.join(self.out_path, f'{pre_name}_mmas.txt') 208 | 209 | # read names 210 | if not os.path.exists(name_file): 211 | exist_names = [] 212 | else: 213 | with open(name_file, 'r') as f: 214 | exist_names = f.readlines() 215 | exist_names = [name.strip() for name in exist_names] 216 | 217 | if self.instance_name in exist_names: 218 | pass 219 | else: 220 | with open(name_file, 'a') as f: 221 | f.write(f'{self.instance_name}\n') 222 | with open(mma_file, 'a') as f: 223 | for i, _ in enumerate(thds): 224 | f.write(f'{good_ratios[i]} ') 225 | f.write('\n') 226 | 227 | return good_ratios 228 | 229 | 230 | def eval_pose_error(self, 231 | corrs, 232 | pre_name, 233 | ): 234 | """ 235 | Args: 236 | corrs: list in eval size 237 | Returns: 238 | pose error: [R_err, t_err] 239 | """ 240 | assert self.eval_info is not None, 'Error: eval_info is not loaded, init dataloader pls' 241 | assert self.sample_mode in ['random', 'grid'], f'Error: sample_mode {sample_mode} not supported' 242 | 243 | corr_num = len(corrs) 244 | if corr_num < 10: 245 | logger.error(f'Error: not enough corrs for pose estimation') 246 | errs = [180, 180] 247 | 248 | eval_num = self.eval_corr_num 249 | sample_mode = self.sample_mode 250 | if corr_num > eval_num: 251 | if sample_mode == 'random': 252 | corrs = random.sample(corrs, eval_num) 253 | elif sample_mode == 'grid': 254 | corrs = np.array(corrs) 255 | corrs = nms_for_corrs(corrs, r=3) 256 | if len(corrs) > eval_num: 257 | corrs = random.sample(corrs, eval_num) 258 | 259 | try: 260 | dataset_name = 
self.eval_info['dataset_name'] 261 | pose0 = self.eval_info['P0'] 262 | pose1 = self.eval_info['P1'] 263 | K0 = self.eval_info['K0'] 264 | K1 = self.eval_info['K1'] 265 | if dataset_name == "MegaDepth": 266 | gt_pose = np.matmul(pose1, np.linalg.inv(pose0)) 267 | else: 268 | gt_pose = pose1.I @ pose0 269 | 270 | errs = compute_pose_error_simp(corrs, K0, K1, gt_pose, pix_thd=0.5, conf=0.9999, sac_mode=self.sac_mode) 271 | 272 | except Exception as e: 273 | logger.error(f'Error: {e}') 274 | if e is KeyError: 275 | raise e 276 | else: 277 | errs = [180, 180] 278 | 279 | # write to file 280 | name_file = os.path.join(self.out_path, f'{pre_name}_pose_err_names.txt') 281 | pose_err_file = os.path.join(self.out_path, f'{pre_name}_pose_errs.txt') 282 | 283 | # read names 284 | if not os.path.exists(name_file): 285 | exist_names = [] 286 | else: 287 | with open(name_file, 'r') as f: 288 | exist_names = f.readlines() 289 | exist_names = [name.strip() for name in exist_names] 290 | 291 | if self.instance_name in exist_names: 292 | pass 293 | else: 294 | with open(name_file, 'a') as f: 295 | f.write(f'{self.instance_name}\n') 296 | with open(pose_err_file, 'a') as f: 297 | f.write(f'{errs[0]} {errs[1]}\n') 298 | 299 | return errs 300 | 301 | 302 | @staticmethod 303 | def tune_corrs_size(corrs, 304 | src_W0=None, 305 | src_W1=None, 306 | src_H0=None, 307 | src_H1=None, 308 | dst_W0=None, 309 | dst_W1=None, 310 | dst_H0=None, 311 | dst_H1=None): 312 | """ 313 | """ 314 | return tune_corrs_size_diff(corrs, src_W0, src_W1, src_H0, src_H1, dst_W0, dst_W1, dst_H0, dst_H1) 315 | -------------------------------------------------------------------------------- /point_matchers/__init__.py: -------------------------------------------------------------------------------- 1 | from .ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer -------------------------------------------------------------------------------- /point_matchers/abstract_point_matcher.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-12 20:35:29 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-06-27 21:20:01 6 | FilePath: /SA2M/hydra-mesa/point_matchers/abstract_point_matcher.py 7 | Description: abstract point matcher class 8 | 9 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 10 | ''' 11 | 12 | import sys 13 | import numpy as np 14 | from typing import List, Optional, Any 15 | import abc 16 | 17 | 18 | class AbstractPointMatcher(abc.ABC): 19 | """ 20 | Abstract class for point matcher; subclasses set self._name, implement match() and store its result in self.matched_corrs 21 | """ 22 | def name(self) -> str: 23 | """ 24 | Return the name of the matcher (self._name, set by subclasses); raises NotImplementedError when it is unset 25 | """ 26 | try: 27 | return self._name 28 | except AttributeError: 29 | raise NotImplementedError("Not use the abstract class directly") 30 | 31 | def set_corr_num_init(self, num: int): 32 | """ 33 | Set the number of correspondences to be matched (self.match_num); the matchers in this package randomly subsample their output down to it, so call this before match() 34 | """ 35 | self.match_num = num 36 | 37 | @abc.abstractmethod 38 | def match(self, img0: np.ndarray, img1: np.ndarray, mask0: Optional[Any]=None, mask1: Optional[Any]=None) -> List[List[float]]: 39 | """ 40 | Match two images and return the correspondences 41 | Returns: 42 | self.matched_corrs 43 | """ 44 | pass 45 | 46 | def return_matches(self): 47 | """Return the correspondences stored by the last match() call (self.matched_corrs).""" 48 | return self.matched_corrs 49 | 50 | @staticmethod 51 | def convert_matches2list(mkpts0, mkpts1) -> List[List[float]]: 52 | """ 53 | Args: 54 | mkpts0/1: np.ndarray Nx2 55 | Returns: 56 | matches: list of [u0, v0, u1, v1], one per row of mkpts0/mkpts1 57 | """ 58 | matches = [] 59 | 60 | assert mkpts0.shape == mkpts1.shape, f"different shape: {mkpts0.shape} != {mkpts1.shape}" 61 | 62 | for i in range(mkpts0.shape[0]): 63 | u0, v0 = mkpts0[i,0], mkpts0[i,1] 64 | u1, v1 = mkpts1[i,0], mkpts1[i,1] 65 | 66 | matches.append([u0, v0, u1, v1]) 67 | 68 | return matches 69 | 70 | -------------------------------------------------------------------------------- /point_matchers/aspanformer.py:
-------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-12 21:03:00 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-07-25 23:16:34 6 | FilePath: /SA2M/hydra-mesa/point_matchers/aspanformer.py 7 | Description: aspanformer point matcher 8 | 9 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 10 | ''' 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | import numpy as np 15 | import cv2 16 | from typing import Any, List, Optional 17 | from loguru import logger 18 | import random 19 | 20 | from .ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer 21 | from .ASpanFormer.src.config.default import get_cfg_defaults 22 | from .ASpanFormer.src.utils.misc import lower_config 23 | 24 | from .abstract_point_matcher import AbstractPointMatcher 25 | 26 | 27 | class ASpanMatcher(AbstractPointMatcher): 28 | """ASpanFormer Matcher Wrapper 29 | """ 30 | 31 | def __init__( 32 | self, 33 | config_path: str, 34 | weights: str, 35 | dataset_name: str, 36 | ) -> None: 37 | super().__init__() 38 | self._name = "ASpanMatcher" 39 | _default_cfg = get_cfg_defaults() 40 | # indoor 41 | if dataset_name == "ScanNet" or dataset_name == "Matterport3D" or dataset_name == "KITTI" or dataset_name == "ETH3D": 42 | main_cfg_path = f"{config_path}/aspan/indoor/aspan_test.py" 43 | data_config = f"{config_path}/data/scannet_test_1500.py" 44 | # outdoor 45 | elif dataset_name == "MegaDepth" or dataset_name == "YFCC": 46 | main_cfg_path = f"{config_path}/aspan/outdoor/aspan_test.py" 47 | data_config = f"{config_path}/data/megadepth_test_1500.py" 48 | else: 49 | raise NotImplementedError(f"dataset {dataset_name} not implemented") 50 | 51 | _default_cfg.merge_from_file(main_cfg_path) 52 | _default_cfg.merge_from_file(data_config) 53 | 54 | _default_cfg = lower_config(_default_cfg) 55 | matcher = ASpanFormer(config=_default_cfg['aspan']) 56 | matcher.load_state_dict(torch.load(weights)["state_dict"],
strict=False) # strict=False: checkpoint keys that do not match the model are ignored 57 | 58 | self.matcher = matcher.eval().cuda() 59 | 60 | 61 | def match(self, img0: np.ndarray, img1: np.ndarray, mask0: Optional[Any]=None, mask1: Optional[Any]=None) -> List[List[float]]: 62 | """ 63 | Returns: 64 | matched_corrs: list of [u0, v0, u1, v1], randomly subsampled to at most self.match_num; masks (used only when both are given) are downsampled by 1/8 before being passed to the model 65 | """ 66 | if len(img0.shape) == 3: 67 | img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2GRAY) 68 | img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY) 69 | 70 | 71 | logger.info(f"img shape is {img0.shape}") 72 | 73 | img_tensor0 = torch.from_numpy(img0 / 255.)[None][None].cuda().float() 74 | img_tensor1 = torch.from_numpy(img1 / 255.)[None][None].cuda().float() 75 | 76 | batch = {"image0": img_tensor0, "image1": img_tensor1} 77 | 78 | if mask0 is not None and mask1 is not None: 79 | mask0 = torch.from_numpy(mask0).cuda() # type: ignore 80 | mask1 = torch.from_numpy(mask1).cuda() # type: ignore 81 | [ts_mask_0, ts_mask_1] = F.interpolate( 82 | torch.stack([mask0, mask1], dim=0)[None].float(), 83 | scale_factor=0.125, 84 | mode='nearest', 85 | recompute_scale_factor=False 86 | )[0].bool().to("cuda") 87 | batch.update({'mask0': ts_mask_0.unsqueeze(0), 'mask1': ts_mask_1.unsqueeze(0)}) 88 | 89 | 90 | with torch.no_grad(): 91 | self.matcher(batch) 92 | mkpts0 = batch['mkpts0_f'].cpu().numpy() # Nx2 93 | mkpts1 = batch['mkpts1_f'].cpu().numpy() 94 | m_bids = batch['m_bids'].cpu().numpy() # NOTE(review): unused here; loftr.py filters matches by m_bids == 0, which is equivalent for this single-pair batch 95 | 96 | 97 | self.matched_corrs = self.convert_matches2list(mkpts0, mkpts1) 98 | 99 | if len(self.matched_corrs) > self.match_num: 100 | logger.info(f"sample {self.match_num} corrs from {len(self.matched_corrs)} corrs") 101 | self.matched_corrs = random.sample(self.matched_corrs, self.match_num) 102 | 103 | self.corrs = self.matched_corrs # used in SGAM 104 | 105 | return self.matched_corrs 106 | 107 | def get_coarse_mkpts_c(self, area0, area1): 108 | """ match region and only use coarse level to get coarse mkpts; returns (mkpts0_c, mkpts1_c, mconf, conf_matrix) as stored in batch by the matcher's coarse_match_mkpts_c 109 | """ 110 | assert area0.shape == area1.shape 111 | 112 | if len(area0.shape) == 3: 113 | area0 =
cv2.cvtColor(area0, cv2.COLOR_BGR2GRAY) 114 | area1 = cv2.cvtColor(area1, cv2.COLOR_BGR2GRAY) 115 | 116 | area0_tensor = torch.from_numpy(area0 / 255.)[None][None].cuda().float() 117 | area1_tensor = torch.from_numpy(area1 / 255.)[None][None].cuda().float() 118 | 119 | batch = {"image0": area0_tensor, "image1": area1_tensor} 120 | 121 | with torch.no_grad(): 122 | self.matcher.coarse_match_mkpts_c(batch) 123 | 124 | conf_matrix = batch["conf_matrix"] 125 | mkpts0_c = batch["mkpts0_c"] 126 | mkpts1_c = batch["mkpts1_c"] 127 | mconf = batch["mconf"] 128 | 129 | return mkpts0_c, mkpts1_c, mconf, conf_matrix -------------------------------------------------------------------------------- /point_matchers/dkm.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | from copy import deepcopy 5 | 6 | import torch 7 | from PIL import Image 8 | import cv2 9 | import numpy as np 10 | import matplotlib.cm as cm 11 | from loguru import logger 12 | import random 13 | import torch.nn.functional as F 14 | 15 | from .DKM.dkm import DKMv3_outdoor, DKMv3_indoor 16 | from .abstract_point_matcher import AbstractPointMatcher 17 | 18 | 19 | class DKMMatcher(AbstractPointMatcher): 20 | """ DKMv3 matcher wrapper: indoor weights for ScanNet/KITTI/ETH3D, outdoor weights for MegaDepth/YFCC 21 | """ 22 | def __init__(self, 23 | dataset_name, 24 | weights, 25 | ): 26 | 27 | super().__init__() 28 | 29 | self._name = "DKM" 30 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 31 | 32 | self.datasetName = dataset_name 33 | if self.datasetName == "ScanNet" or self.datasetName == "KITTI" or self.datasetName == "ETH3D": # NOTE(review): ETH3D loads the indoor model here, but match() applies the outdoor resize/upsample settings to ETH3D — confirm this mix is intended 34 | self.matcher = DKMv3_indoor(device=device, path_to_weights=weights) 35 | elif self.datasetName == "MegaDepth" or self.datasetName == "YFCC": 36 | self.matcher = DKMv3_outdoor(device=device, path_to_weights=weights) 37 | else: 38 | raise NotImplementedError 39 | 40 | def match(self, img0, img1, mask0=None, mask1=None): 41 | """ 42 | Args: 43 | img0, img1: cv img 44 | NOTE the image size is not in
consideration, as the DKM outputs normalized coordinates. Returns: list of [u0, v0, u1, v1] in input-pixel coords, randomly subsampled to at most self.match_num; mask0/mask1 are ignored 45 | """ 46 | 47 | # turn to PIL image 48 | img0 = Image.fromarray(cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)) # NOTE(review): COLOR_BGR2RGB expects 3-channel input, so grayscale images are presumably rejected here 49 | img1 = Image.fromarray(cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)) 50 | # scale to original size 51 | w1, h1 = img0.size 52 | w2, h2 = img1.size 53 | 54 | assert w1 == w2 and h1 == h2, "DKM only support same size image" 55 | 56 | if self.datasetName == "MegaDepth" or self.datasetName == "YFCC" or self.datasetName == "ETH3D": 57 | self.matcher.h_resized = h1 58 | self.matcher.w_resized = w1 # the crop area size 59 | self.matcher.upsample_preds = True 60 | self.matcher.upsample_res = (1152, 1536) 61 | self.matcher.use_soft_mutual_nearest_neighbours = False 62 | elif self.datasetName == "ScanNet" or self.datasetName == "KITTI": 63 | self.matcher.h_resized = h1 64 | self.matcher.w_resized = w1 65 | self.matcher.upsample_preds = False 66 | else: 67 | raise NotImplementedError 68 | 69 | # match 70 | # import time 71 | # start = time.time() 72 | dense_matches, dense_certainty = self.matcher.match(img0, img1) 73 | # logger.info(f"DKM match time: {time.time() - start}") 74 | # sample 75 | sparse_matches,_ = self.matcher.sample( 76 | dense_matches, dense_certainty, 5000 77 | ) 78 | 79 | kpts1 = sparse_matches[:, :2] # normalized [-1, 1] coords -> pixels via W * (x + 1) / 2 and H * (y + 1) / 2 80 | kpts1 = ( 81 | torch.stack( 82 | ( 83 | w1 * (kpts1[:, 0] + 1) / 2, 84 | h1 * (kpts1[:, 1] + 1) / 2, 85 | ), 86 | axis=-1, 87 | ) 88 | ) 89 | kpts2 = sparse_matches[:, 2:] 90 | kpts2 = ( 91 | torch.stack( 92 | ( 93 | w2 * (kpts2[:, 0] + 1) / 2, 94 | h2 * (kpts2[:, 1] + 1) / 2, 95 | ), 96 | axis=-1, 97 | ) 98 | ) 99 | 100 | # put kpts into numpy 101 | kpts1 = kpts1.cpu().numpy() 102 | kpts2 = kpts2.cpu().numpy() 103 | 104 | self.matched_corrs = self.convert_matches2list(kpts1, kpts2) 105 | 106 | if len(self.matched_corrs) > self.match_num: 107 | logger.info(f"sample {self.match_num} corrs from {len(self.matched_corrs)} corrs") 108 | self.matched_corrs = random.sample(self.matched_corrs,
self.match_num) 109 | 110 | self.corrs = self.matched_corrs # used in SGAM 111 | 112 | return self.matched_corrs -------------------------------------------------------------------------------- /point_matchers/loftr.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import cv2 7 | from typing import Any, List, Optional 8 | from loguru import logger 9 | import random 10 | 11 | 12 | from .abstract_point_matcher import AbstractPointMatcher 13 | from .LoFTR.src.utils.plotting import make_matching_figure 14 | from .LoFTR.src.loftr import LoFTR, default_cfg 15 | from .LoFTR.src.config.default import get_cfg_defaults 16 | from .LoFTR.src.utils.misc import lower_config 17 | 18 | 19 | class LoFTRMatcher(AbstractPointMatcher): 20 | """ LoFTR Matcher Wrapper 21 | """ 22 | 23 | def __init__( 24 | self, 25 | config_path: str, 26 | dataset_name: str, 27 | weights: str, 28 | cross_domain: bool = False, 29 | ): 30 | super().__init__() 31 | self._name = "LoFTRMatcher" 32 | _default_cfg = get_cfg_defaults() 33 | 34 | if dataset_name == "ScanNet" or dataset_name == "Matterport3D" or dataset_name == "KITTI" or dataset_name == "ETH3D": 35 | if cross_domain: # cross_domain: load the model config of the other domain (outdoor config for indoor data and vice versa) 36 | main_cfg_path = f"{config_path}/loftr/outdoor/buggy_pos_enc/loftr_ds.py" 37 | else: 38 | main_cfg_path = f"{config_path}/loftr/indoor/scannet/loftr_ds_eval_new.py" 39 | 40 | data_config = f"{config_path}/data/scannet_test_1500.py" 41 | elif dataset_name == "MegaDepth" or dataset_name == "YFCC": 42 | if cross_domain: 43 | main_cfg_path = f"{config_path}/loftr/indoor/scannet/loftr_ds_eval_new.py" 44 | else: 45 | main_cfg_path = f"{config_path}/loftr/outdoor/buggy_pos_enc/loftr_ds.py" 46 | 47 | data_config = f"{config_path}/data/megadepth_test_1500.py" 48 | else: 49 | raise NotImplementedError(f"dataset_name {dataset_name} not implemented") 50 | 51 | 52 | _default_cfg.merge_from_file(main_cfg_path) 53
| _default_cfg.merge_from_file(data_config) 54 | _default_cfg = lower_config(_default_cfg) 55 | 56 | # _default_cfg['coarse']['temp_bug_fix'] = True # set to False when using the old ckpt 57 | 58 | self._default_cfg = _default_cfg['loftr'] 59 | 60 | torch.cuda.set_device(0) # NOTE(review): hard-codes device 0 (the first visible device under CUDA_VISIBLE_DEVICES) 61 | matcher = LoFTR(config=_default_cfg['loftr']) 62 | matcher.load_state_dict(torch.load(weights)["state_dict"]) 63 | self.matcher = matcher.eval().cuda() 64 | 65 | def match(self, img0, img1, mask0=None, mask1=None): 66 | """for SGAMer. Match two images (converted to grayscale, scaled to [0, 1]); masks, used only when both are given, are downsampled by 1/8. Returns list of [u0, v0, u1, v1], randomly subsampled to at most self.match_num.""" 67 | if len(img0.shape) == 3: 68 | img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2GRAY) 69 | img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY) 70 | 71 | img_tensor0 = torch.from_numpy(img0)[None][None].cuda() / 255. 72 | img_tensor1 = torch.from_numpy(img1)[None][None].cuda() / 255. 73 | 74 | batch = {"image0": img_tensor0, "image1": img_tensor1} 75 | 76 | if mask0 is not None and mask1 is not None: 77 | mask0 = torch.from_numpy(mask0).cuda() 78 | mask1 = torch.from_numpy(mask1).cuda() 79 | [ts_mask_0, ts_mask_1] = F.interpolate( 80 | torch.stack([mask0, mask1], dim=0)[None].float(), 81 | scale_factor=0.125, 82 | mode='nearest', 83 | recompute_scale_factor=False 84 | )[0].bool().to("cuda") 85 | batch.update({'mask0': ts_mask_0.unsqueeze(0), 'mask1': ts_mask_1.unsqueeze(0)}) 86 | 87 | with torch.no_grad(): 88 | self.matcher(batch) 89 | mkpts0 = batch['mkpts0_f'].cpu().numpy() # Nx2 90 | mkpts1 = batch['mkpts1_f'].cpu().numpy() 91 | m_bids = batch['m_bids'].cpu().numpy() 92 | mask = m_bids == 0 # only one batch 93 | mkpts0 = mkpts0[mask] 94 | mkpts1 = mkpts1[mask] 95 | 96 | self.matched_corrs = self.convert_matches2list(mkpts0, mkpts1) 97 | 98 | if len(self.matched_corrs) > self.match_num: 99 | logger.info(f"sample {self.match_num} corrs from {len(self.matched_corrs)} corrs") 100 | self.matched_corrs = random.sample(self.matched_corrs, self.match_num) 101 | 102 | self.corrs = self.matched_corrs # used in SGAM 103 | 104 | logger.info(f"matched corrs num is
{len(self.matched_corrs)}") 105 | 106 | return self.matched_corrs 107 | 108 | def get_coarse_mkpts_c(self, area0, area1): 109 | """ match region and only use coarse level to get coarse mkpts; returns (mkpts0_c, mkpts1_c, mconf, conf_matrix) as stored in batch by the matcher's coarse_match_mkpts_c 110 | """ 111 | assert area0.shape == area1.shape 112 | 113 | if len(area0.shape) == 3: 114 | area0 = cv2.cvtColor(area0, cv2.COLOR_BGR2GRAY) 115 | area1 = cv2.cvtColor(area1, cv2.COLOR_BGR2GRAY) 116 | 117 | area0_tensor = torch.from_numpy(area0 / 255.)[None][None].cuda().float() 118 | area1_tensor = torch.from_numpy(area1 / 255.)[None][None].cuda().float() 119 | 120 | batch = {"image0": area0_tensor, "image1": area1_tensor} 121 | 122 | with torch.no_grad(): 123 | self.matcher.coarse_match_mkpts_c(batch) 124 | 125 | conf_matrix = batch["conf_matrix"] 126 | mkpts0_c = batch["mkpts0_c"] 127 | mkpts1_c = batch["mkpts1_c"] 128 | mconf = batch["mconf"] 129 | 130 | return mkpts0_c, mkpts1_c, mconf, conf_matrix -------------------------------------------------------------------------------- /point_matchers/spsg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-13 22:05:13 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-06-29 11:03:39 6 | FilePath: /SA2M/hydra-mesa/point_matchers/spsg.py 7 | Description: SuperPoint + SuperGlue 8 | 9 | 10 | Copyright (c) 2024 by EasonZhang, All Rights Reserved.
11 | ''' 12 | 13 | import os 14 | from copy import deepcopy 15 | 16 | import torch 17 | import cv2 18 | import numpy as np 19 | from loguru import logger 20 | import random 21 | import torch.nn.functional as F 22 | 23 | from .abstract_point_matcher import AbstractPointMatcher 24 | from .SuperGluePretrainedNetwork.models.matching import Matching 25 | 26 | class SPSGMatcher(AbstractPointMatcher): 27 | """ SuperPoint + SuperGlue matcher wrapper. 28 | Specific: 29 | weights: SuperGlue weights, indoor or outdoor (passed as config['superglue']['weights']) 30 | """ 31 | 32 | def __init__(self, 33 | weights, 34 | dataset_name="ScanNet" 35 | ): 36 | super().__init__() 37 | 38 | """specific""" 39 | self._name = "SPSG" 40 | self.SG_weights = weights 41 | 42 | # use superglue recommended settings 43 | # indoor 44 | dataset = dataset_name 45 | if dataset == "ScanNet" or dataset == "Matterport3D" or dataset == "KITTI" or dataset == "ETH3D": 46 | config = { 47 | 'superpoint': { 48 | 'nms_radius': 4, 49 | 'keypoint_threshold': 0.005, 50 | 'max_keypoints': 1024, 51 | }, 52 | 'superglue': { 53 | 'weights': self.SG_weights, 54 | 'sinkhorn_iterations': 20, 55 | 'match_threshold': 0.2, 56 | } 57 | } 58 | 59 | elif dataset == "MegaDepth" or dataset == "YFCC": # outdoor: currently identical to the indoor settings above 60 | config = { 61 | 'superpoint': { 62 | 'nms_radius': 4, 63 | 'keypoint_threshold': 0.005, 64 | 'max_keypoints': 1024, 65 | }, 66 | 'superglue': { 67 | 'weights': self.SG_weights, 68 | 'sinkhorn_iterations': 20, 69 | 'match_threshold': 0.2, 70 | } 71 | } 72 | else: 73 | config = {} 74 | raise NotImplementedError(f"dataset {dataset} not implemented") 75 | 76 | # print(f"detector_config: {detector_config}") 77 | self.matcher = Matching(config).eval().to('cuda') 78 | 79 | def match(self, img0, img1, mask0=None, mask1=None): 80 | """ Match with SuperPoint + SuperGlue; mask0/mask1 are ignored. Returns list of [u0, v0, u1, v1] for keypoints with matches0 > -1, randomly subsampled to at most self.match_num 81 | """ 82 | # detect keypoints 83 | # turn to gray scale 84 | if len(img0.shape) == 3: 85 | img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2GRAY) 86 | img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY) 87 | 88 | img0 = torch.from_numpy(img0/255.).float()[None][None].to('cuda') 89 | img1 =
torch.from_numpy(img1/255.).float()[None][None].to('cuda') 90 | pred = self.matcher({'image0': img0, 'image1': img1}) 91 | pred = {k: v[0].detach().cpu().numpy() for k, v in pred.items()} 92 | kpts0, kpts1 = pred['keypoints0'], pred['keypoints1'] 93 | matches, conf = pred['matches0'], pred['matching_scores0'] 94 | 95 | valid = matches > -1 # matches0 == -1 marks unmatched keypoints 96 | mkpts0 = kpts0[valid] 97 | mkpts1 = kpts1[matches[valid]] 98 | mconf = conf[valid] # NOTE(review): mconf is not used below 99 | 100 | self.matched_corrs = self.convert_matches2list(mkpts0, mkpts1) 101 | 102 | if len(self.matched_corrs) > self.match_num: 103 | logger.info(f"sample {self.match_num} corrs from {len(self.matched_corrs)} corrs") 104 | self.matched_corrs = random.sample(self.matched_corrs, self.match_num) 105 | 106 | self.corrs = deepcopy(self.matched_corrs) # a copy, whereas the other matchers alias self.corrs to self.matched_corrs 107 | 108 | return self.matched_corrs 109 | 110 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24.1 2 | opencv_python==4.8.0.76 3 | albumentations==0.5.1 --no-binary=imgaug,albumentations 4 | ray>=1.0.1 5 | einops==0.3.0 6 | kornia==0.4.1 7 | loguru==0.5.3 8 | yacs>=0.1.8 9 | tqdm 10 | autopep8 11 | pylint 12 | ipython 13 | jupyterlab 14 | matplotlib 15 | h5py==3.1.0 16 | pytorch-lightning==1.3.5 17 | torchmetrics==0.6.0 # version problem: https://github.com/NVIDIA/DeepLearningExamples/issues/1113#issuecomment-1102969461 18 | joblib>=1.0.1 19 | hydra-core 20 | openpyxl 21 | loguru # NOTE(review): duplicates loguru==0.5.3 (line 7) 22 | pandas 23 | seaborn 24 | scikit-learn 25 | scipy==1.9.0 26 | PyMaxflow==1.3.0 27 | pyquaternion 28 | timm 29 | -------------------------------------------------------------------------------- /scripts/demo.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-12 20:31:50 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-11-03 15:05:05 6 | FilePath: /SA2M/hydra-mesa/scripts/demo.py 7 |
Description: test hydra-powered a2pm 8 | 9 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 10 | ''' 11 | import sys 12 | sys.path.append('..') 13 | 14 | import os 15 | import hydra 16 | from omegaconf import DictConfig, OmegaConf 17 | from loguru import logger 18 | 19 | from point_matchers.abstract_point_matcher import AbstractPointMatcher 20 | from dataloader.abstract_dataloader import AbstractDataloader 21 | from area_matchers.abstract_am import AbstractAreaMatcher 22 | from geo_area_matchers.abstract_gam import AbstractGeoAreaMatcher 23 | from utils.common import validate_type 24 | from utils.geo import list_of_corrs2corr_list 25 | 26 | import random 27 | # fix random seed 28 | random.seed(2) 29 | 30 | import torch 31 | # fix random seed 32 | torch.manual_seed(2) 33 | 34 | # @hydra.main(version_base=None, config_path="../conf", config_name="a2pm_mesa_egam_dkm_scannet") 35 | # @hydra.main(version_base=None, config_path="../conf", config_name="a2pm_mesa_egam_spsg_scannet") 36 | @hydra.main(version_base=None, config_path="../conf") 37 | def test(cfg: DictConfig) -> None: 38 | """ 39 | Run the A2PM demo pipeline on one image pair: point matching, area matching, then geo area matching refinement; logs the area and corr counts for each alpha 40 | """ 41 | 42 | # set full error 43 | os.environ["HYDRA_FULL_ERROR"] = '1' 44 | 45 | OmegaConf.resolve(cfg) 46 | 47 | if cfg.verbose==0: 48 | logger.remove() 49 | logger.add(sys.stdout, level="SUCCESS") 50 | elif cfg.verbose==1: 51 | logger.remove() 52 | logger.add(sys.stdout, level="INFO") 53 | else: 54 | raise NotImplementedError(f"verbose {cfg.verbose} not supported") 55 | 56 | logger.info(f"\n{OmegaConf.to_yaml(cfg)}") 57 | 58 | # load point matcher 59 | pmer = hydra.utils.instantiate(cfg.point_matcher) 60 | validate_type(pmer, AbstractPointMatcher) 61 | pmer.set_corr_num_init(cfg.match_num) 62 | 63 | # load dataloader 64 | dataloader = hydra.utils.instantiate(cfg.dataset) 65 | validate_type(dataloader, AbstractDataloader) 66 | 67 | img0, img1, _, _ = dataloader.load_images(cfg.crop_size_W, cfg.crop_size_H) 68 | ori_corrs = pmer.match(img0, img1,
None, None) 69 | 70 | # test amer 71 | amer = hydra.utils.instantiate(cfg.area_matcher) 72 | logger.info(f"amer: {amer.name()}") 73 | validate_type(amer, AbstractAreaMatcher) 74 | area_matches0, area_matches1 = amer.area_matching(dataloader, cfg.out_path) # here is the area matches from the area matcher, areas are represented by 4 coordinate list: [u_min, u_max, v_min, v_max] 75 | 76 | logger.success(f"area matching done, area_matches len: {len(area_matches0)}") 77 | 78 | # test gam 79 | gamer = hydra.utils.instantiate(cfg.geo_area_matcher) 80 | validate_type(gamer, AbstractGeoAreaMatcher) 81 | logger.info(f"gamer: {gamer.name()}") 82 | gamer.init_gam( 83 | dataloader=dataloader, 84 | point_matcher=pmer, 85 | ori_corrs=ori_corrs, 86 | out_path=cfg.out_path 87 | ) 88 | alpha_corrs_dict, alpha_inlier_idxs_dict, _ = gamer.geo_area_matching_refine(area_matches0, area_matches1) # the third return value is not used here 89 | 90 | 91 | 92 | logger.success(f"geo area matching done") 93 | for alpha in alpha_corrs_dict.keys(): 94 | logger.success(f"for alpha: {alpha}, areas num: {len(alpha_inlier_idxs_dict[alpha])}") 95 | # get inlier area matches 96 | inlier_area_matches0 = [area_matches0[i] for i in alpha_inlier_idxs_dict[alpha]] 97 | inlier_area_matches1 = [area_matches1[i] for i in alpha_inlier_idxs_dict[alpha]] 98 | 99 | # NOTE: description about how to use the matched areas 100 | # areas are represented by 4 coordinate list: [u_min, u_max, v_min, v_max] 101 | # the matched areas are stored in inlier_area_matches0 and inlier_area_matches1 102 | # for example, inlier_area_matches0[0]=[u0_min, u0_max, v0_min, v0_max] is matched with inlier_area_matches1[0]=[u1_min, u1_max, v1_min, v1_max] 103 | # you can choose a proper alpha by modifying the `alpha_list` in the config file of the geo area matcher you used 104 | # such as `alpha_list: [0.1, 0.2, 0.3, 0.4, 0.5]` in L17 of `conf/geo_area_matcher/egam.yaml` 105 | 106 | 107 | for alpha in alpha_corrs_dict.keys(): 108 | # draw 109 | corrs =
list_of_corrs2corr_list(alpha_corrs_dict[alpha]) 110 | logger.success(f"alpha: {alpha}, corrs num: {len(corrs)}") 111 | #TODO: draw corrs 112 | 113 | if __name__ == "__main__": 114 | test() 115 | pass -------------------------------------------------------------------------------- /scripts/dmesa-dkm-md.sh: -------------------------------------------------------------------------------- 1 | 2 | ### 3 | # @Author: EasonZhang 4 | # @Date: 2024-07-18 15:36:50 5 | # @LastEditors: Easonyesheng preacher@sjtu.edu.cn 6 | # @LastEditTime: 2024-07-27 11:48:52 7 | # @FilePath: /SA2M/hydra-mesa/scripts/qua-res-generator-dmesa-dkm-md.sh 8 | # @Description: TBD 9 | # 10 | # Copyright (c) 2024 by EasonZhang, All Rights Reserved. 11 | ### 12 | dataset=MegaDepth 13 | cuda_id=1 14 | project_name=dmesa-dkm-md-eval 15 | exp_root_path=/opt/data/private/A2PM-git/A2PM-MESA 16 | 17 | already_done_name_file_folder=${exp_root_path}/res/${project_name}-res/ratios 18 | already_done_name_file=$(ls ${already_done_name_file_folder}/*pose_err_names.txt | head -n 1) # only the first *pose_err_names.txt (if any) is used as the list of finished pairs 19 | 20 | pair_txt=${exp_root_path}/scripts/megadepth_1500_pairs.txt 21 | scene_name=MegaDepth 22 | 23 | # get the scene name and pair from pair_txt 24 | # each line of pair_txt is '{pair0} {pair1}'; the name checked below is MegaDepth_{pair0}_{pair1} 25 | while read line 26 | do 27 | echo "line: ${line}" 28 | # parse line to pair0 and pair1, line is {pair0} {pair1} 29 | pair0=$(echo ${line} | awk '{print $1}') 30 | pair1=$(echo ${line} | awk '{print $2}') 31 | 32 | echo "pair0: ${pair0}" 33 | echo "pair1: ${pair1}" 34 | 35 | # # get the last part of the pair1, which is separated by _ 36 | # pair1_last=$(echo ${pair1} | awk -F_ '{print $NF}') 37 | 38 | complete_pair_name=MegaDepth_${pair0}_${pair1} 39 | echo "complete_pair_name: ${complete_pair_name}" 40 | 41 | # if $scene_$pair0_$pair1 in already_done_name_file, continue 42 | if [ -f "${already_done_name_file}" ];then 43 | if grep -q "${complete_pair_name}" ${already_done_name_file};then # NOTE(review): substring regex match — '.' in names, or a name that is a prefix of a finished one, can skip a pair by mistake; consider grep -qxF -- if the names file holds exactly these names 44 | echo ${complete_pair_name} already
done 45 | continue 46 | fi 47 | fi 48 | 49 | echo "performing test on ${complete_pair_name}" 50 | 51 | CUDA_VISIBLE_DEVICES=$cuda_id python test_a2pm.py \ 52 | +experiment=a2pm_dmesa_egam_dkm_megadepth \ 53 | test_area_acc=False \ 54 | test_pm_acc=False \ 55 | verbose=0 \ 56 | name=${project_name} \ 57 | dataset_name=$dataset \ 58 | dataset.scene_name=$scene_name \ 59 | dataset.image_name0=$pair0 \ 60 | dataset.image_name1=$pair1 61 | # break 62 | done < $pair_txt 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /scripts/dmesa-dkm-sn.sh: -------------------------------------------------------------------------------- 1 | 2 | ### 3 | # @Author: EasonZhang 4 | # @Date: 2024-07-18 14:38:31 5 | # @LastEditors: Easonyesheng preacher@sjtu.edu.cn 6 | # @LastEditTime: 2024-07-29 12:12:16 7 | # @FilePath: /SA2M/hydra-mesa/scripts/qua-res-generator-dmesa-dkm-sn.sh 8 | # @Description: TBD 9 | # 10 | # Copyright (c) 2024 by EasonZhang, All Rights Reserved.
11 | ### 12 | dataset=ScanNet 13 | cuda_id=0 14 | project_name=dmesa-dkm-sn-eval 15 | exp_root_path=/opt/data/private/A2PM-git/A2PM-MESA 16 | 17 | already_done_name_file_folder=${exp_root_path}/res/${project_name}-res/ratios 18 | already_done_name_file=$(ls ${already_done_name_file_folder}/*pose_err_names.txt | head -n 1) 19 | 20 | pair_txt=${exp_root_path}/scripts/scannet_pairs.txt 21 | 22 | # get the scene name and pair from pair_txt 23 | # the format of pair_txt is: scene_name_pair0_pair1 24 | while read line 25 | do 26 | # split line by _ 27 | arr=(${line//_/ }) 28 | echo scene_name = ${arr[0]}_${arr[1]} 29 | echo pair0 = ${arr[2]} 30 | echo pair1 = ${arr[3]} 31 | scene_name=${arr[0]}_${arr[1]} 32 | pair0=${arr[2]} 33 | pair1=${arr[3]} 34 | 35 | # if $scene_name_$pair0_$pair1 in already_done_name_file, continue 36 | if [ -f "${already_done_name_file}" ];then 37 | if grep -q "${scene_name}_${pair0}_${pair1}" ${already_done_name_file};then 38 | echo ${scene_name}_${pair0}_${pair1} already done 39 | continue 40 | fi 41 | fi 42 | 43 | CUDA_VISIBLE_DEVICES=$cuda_id python test_a2pm.py \ 44 | +experiment=a2pm_dmesa_egam_dkm_scannet \ 45 | test_area_acc=False \ 46 | test_pm_acc=False \ 47 | verbose=0 \ 48 | name=${project_name} \ 49 | dataset_name=$dataset \ 50 | dataset.scene_name=$scene_name \ 51 | dataset.image_name0=$pair0 \ 52 | dataset.image_name1=$pair1 53 | # break 54 | done < $pair_txt 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /scripts/mesa-f-dkm-md.sh: -------------------------------------------------------------------------------- 1 | 2 | ### 3 | # @Author: EasonZhang 4 | # @Date: 2024-07-18 15:36:50 5 | # @LastEditors: Easonyesheng preacher@sjtu.edu.cn 6 | # @LastEditTime: 2024-07-29 10:43:38 7 | # @FilePath: /SA2M/hydra-mesa/scripts/qua-res-generator-dmesa-dkm-md.sh 8 | # @Description: TBD 9 | # 10 | # Copyright (c) 2024 by EasonZhang, All Rights Reserved. 
11 | ### 12 | dataset=MegaDepth 13 | cuda_id=1 14 | project_name=mesa-f-dkm-md-eval 15 | exp_root_path=/opt/data/private/A2PM-git/A2PM-MESA 16 | 17 | already_done_name_file_folder=${exp_root_path}/res/${project_name}-res/ratios 18 | already_done_name_file=$(ls ${already_done_name_file_folder}/*pose_err_names.txt | head -n 1) 19 | 20 | pair_txt=${exp_root_path}/scripts/megadepth_1500_pairs.txt 21 | scene_name=MegaDepth 22 | 23 | # get the scene name and pair from pair_txt 24 | # the format of pair_txt is: scene_name_pair0_pair1 25 | while read line 26 | do 27 | echo "line: ${line}" 28 | # parse line to pair0 and pair1, line is {pair0} {pair1} 29 | pair0=$(echo ${line} | awk '{print $1}') 30 | pair1=$(echo ${line} | awk '{print $2}') 31 | 32 | echo "pair0: ${pair0}" 33 | echo "pair1: ${pair1}" 34 | 35 | # # get the last part of the pair1, which is separated by _ 36 | # pair1_last=$(echo ${pair1} | awk -F_ '{print $NF}') 37 | 38 | complete_pair_name=MegaDepth_${pair0}_${pair1} 39 | echo "complete_pair_name: ${complete_pair_name}" 40 | 41 | # if $scene_$pair0_$pair1 in already_done_name_file, continue 42 | if [ -f "${already_done_name_file}" ];then 43 | if grep -q "${complete_pair_name}" ${already_done_name_file};then 44 | echo ${complete_pair_name} already done 45 | continue 46 | fi 47 | fi 48 | 49 | echo "performing test on ${complete_pair_name}" 50 | 51 | CUDA_VISIBLE_DEVICES=$cuda_id python test_a2pm.py \ 52 | +experiment=a2pm_mesa_egam_dkm_megadepth \ 53 | test_area_acc=False \ 54 | test_pm_acc=False \ 55 | verbose=0 \ 56 | name=${project_name} \ 57 | dataset_name=$dataset \ 58 | dataset.scene_name=$scene_name \ 59 | dataset.image_name0=$pair0 \ 60 | dataset.image_name1=$pair1 61 | # break 62 | done < $pair_txt 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /scripts/mesa-f-dkm-sn.sh: -------------------------------------------------------------------------------- 1 | 2 | ### 3 | # @Author: EasonZhang 4 | # 
@Date: 2024-07-18 14:38:31 5 | # @LastEditors: Easonyesheng preacher@sjtu.edu.cn 6 | # @LastEditTime: 2024-07-29 10:49:26 7 | # @FilePath: /SA2M/hydra-mesa/scripts/qua-res-generator-dmesa-dkm-sn.sh 8 | # @Description: TBD 9 | # 10 | # Copyright (c) 2024 by EasonZhang, All Rights Reserved. 11 | ### 12 | dataset=ScanNet 13 | cuda_id=0 14 | project_name=mesa-f-dkm-sn-eval 15 | exp_root_path=/opt/data/private/A2PM-git/A2PM-MESA 16 | 17 | already_done_name_file_folder=${exp_root_path}/res/${project_name}-res/ratios 18 | already_done_name_file=$(ls ${already_done_name_file_folder}/*pose_err_names.txt | head -n 1) 19 | 20 | pair_txt=${exp_root_path}/scripts/scannet_pairs.txt 21 | 22 | # get the scene name and pair from pair_txt 23 | # the format of pair_txt is: scene_name_pair0_pair1 24 | while read line 25 | do 26 | # split line by _ 27 | arr=(${line//_/ }) 28 | echo scene_name = ${arr[0]}_${arr[1]} 29 | echo pair0 = ${arr[2]} 30 | echo pair1 = ${arr[3]} 31 | scene_name=${arr[0]}_${arr[1]} 32 | pair0=${arr[2]} 33 | pair1=${arr[3]} 34 | 35 | # if $scene_name_$pair0_$pair1 in already_done_name_file, continue 36 | if [ -f "${already_done_name_file}" ];then 37 | if grep -q "${scene_name}_${pair0}_${pair1}" ${already_done_name_file};then 38 | echo ${scene_name}_${pair0}_${pair1} already done 39 | continue 40 | fi 41 | fi 42 | 43 | CUDA_VISIBLE_DEVICES=$cuda_id python test_a2pm.py \ 44 | +experiment=a2pm_mesa_egam_dkm_scannet \ 45 | test_area_acc=False \ 46 | test_pm_acc=False \ 47 | verbose=0 \ 48 | name=${project_name} \ 49 | dataset_name=$dataset \ 50 | dataset.scene_name=$scene_name \ 51 | dataset.image_name0=$pair0 \ 52 | dataset.image_name1=$pair1 53 | # break 54 | done < $pair_txt 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /scripts/outputs/2024-11-03/15-43-14/.hydra/config.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | _target_: 
dataloader.demo_pair_loader.DemoPairLoader 3 | root_path: /opt/data/private/SA2M/hydra-mesa/demo 4 | scene_name: '' 5 | image_name0: '4119.965344' 6 | image_name1: '4120.813199' 7 | color_folder: color 8 | color_post: png 9 | sem_folder: samres 10 | sem_post: npy 11 | intrin_folder: intrins 12 | intrin_post: txt 13 | point_matcher: 14 | _target_: point_matchers.spsg.SPSGMatcher 15 | dataset_name: ${dataset_name} 16 | weights: indoor 17 | area_matcher: 18 | _target_: area_matchers.dmesa.DMesaAreaMatcher 19 | datasetName: ${dataset_name} 20 | W: ${area_from_size_W} 21 | H: ${area_from_size_H} 22 | coarse_matcher_name: ASpan 23 | level_num: 4 24 | level_step: 25 | - 560 26 | - 390 27 | - 256 28 | - 130 29 | - 0 30 | stop_match_level: 3 31 | area_crop_mode: expand_padding 32 | patch_size_ratio: 0.125 33 | valid_gaussian_width: sqrt2 34 | source_area_selection_mode: direct 35 | iou_fusion_thd: 0.8 36 | patch_match_num_thd: 30 37 | match_mode: pms_GF 38 | coarse_match_all_in_one: 1 39 | dual_match: 1 40 | draw_verbose: ${verbose} 41 | geo_area_matcher: 42 | _target_: geo_area_matchers.egam.EGeoAreaMatcher 43 | datasetName: demo 44 | area_from_size_W: ${area_from_size_W} 45 | area_from_size_H: ${area_from_size_H} 46 | crop_size_W: ${crop_size_W} 47 | crop_size_H: ${crop_size_H} 48 | crop_from_size_W: ${crop_from_size_W} 49 | crop_from_size_H: ${crop_from_size_H} 50 | eval_from_size_W: ${eval_from_size_W} 51 | eval_from_size_H: ${eval_from_size_H} 52 | std_match_num: 1000 53 | alpha_list: 54 | - 0.5 55 | - 2.0 56 | - 3.0 57 | - 3.5 58 | - 5.0 59 | adaptive_size_thd: 1.0 60 | valid_inside_area_match_num: 10 61 | reject_out_area_flag: 0 62 | crop_mode: 0 63 | sac_mode: MAGSAC 64 | sampler_name: '' 65 | occ_size: 1 66 | common_occ_flag: 1 67 | verbose: ${verbose} 68 | evaler: 69 | _target_: metric.instance_eval.InstanceEval 70 | sample_mode: grid 71 | eval_corr_num: ${match_num} 72 | sac_mode: MAGSAC 73 | out_path: ${out_path} 74 | draw_verbose: ${verbose} 75 | name: 
dmesa-f-egam-spsg-single-pair-demo 76 | dataset_name: ScanNet 77 | area_from_size_W: 640 78 | area_from_size_H: 480 79 | eval_from_size_W: 640 80 | eval_from_size_H: 480 81 | crop_from_size_W: 1296 82 | crop_from_size_H: 968 83 | crop_size_W: 480 84 | crop_size_H: 480 85 | verbose: 1 86 | out_path: /opt/data/private/SA2M/hydra-mesa/res/${name}-res 87 | match_num: 1000 88 | -------------------------------------------------------------------------------- /scripts/outputs/2024-11-03/15-43-14/.hydra/hydra.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} 4 | sweep: 5 | dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} 6 | subdir: ${hydra.job.num} 7 | launcher: 8 | _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher 9 | sweeper: 10 | _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper 11 | max_batch_size: null 12 | params: null 13 | help: 14 | app_name: ${hydra.job.name} 15 | header: '${hydra.help.app_name} is powered by Hydra. 16 | 17 | ' 18 | footer: 'Powered by Hydra (https://hydra.cc) 19 | 20 | Use --hydra-help to view Hydra specific help 21 | 22 | ' 23 | template: '${hydra.help.header} 24 | 25 | == Configuration groups == 26 | 27 | Compose your configuration from those groups (group=option) 28 | 29 | 30 | $APP_CONFIG_GROUPS 31 | 32 | 33 | == Config == 34 | 35 | Override anything in the config (foo.bar=value) 36 | 37 | 38 | $CONFIG 39 | 40 | 41 | ${hydra.help.footer} 42 | 43 | ' 44 | hydra_help: 45 | template: 'Hydra (${hydra.runtime.version}) 46 | 47 | See https://hydra.cc for more info. 48 | 49 | 50 | == Flags == 51 | 52 | $FLAGS_HELP 53 | 54 | 55 | == Configuration groups == 56 | 57 | Compose your configuration from those groups (For example, append hydra/job_logging=disabled 58 | to command line) 59 | 60 | 61 | $HYDRA_CONFIG_GROUPS 62 | 63 | 64 | Use ''--cfg hydra'' to Show the Hydra config. 65 | 66 | ' 67 | hydra_help: ??? 
68 | hydra_logging: 69 | version: 1 70 | formatters: 71 | simple: 72 | format: '[%(asctime)s][HYDRA] %(message)s' 73 | handlers: 74 | console: 75 | class: logging.StreamHandler 76 | formatter: simple 77 | stream: ext://sys.stdout 78 | root: 79 | level: INFO 80 | handlers: 81 | - console 82 | loggers: 83 | logging_example: 84 | level: DEBUG 85 | disable_existing_loggers: false 86 | job_logging: 87 | version: 1 88 | formatters: 89 | simple: 90 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 91 | handlers: 92 | console: 93 | class: logging.StreamHandler 94 | formatter: simple 95 | stream: ext://sys.stdout 96 | file: 97 | class: logging.FileHandler 98 | formatter: simple 99 | filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log 100 | root: 101 | level: INFO 102 | handlers: 103 | - console 104 | - file 105 | disable_existing_loggers: false 106 | env: {} 107 | mode: RUN 108 | searchpath: [] 109 | callbacks: {} 110 | output_subdir: .hydra 111 | overrides: 112 | hydra: 113 | - hydra.mode=RUN 114 | task: 115 | - +experiment=demo 116 | job: 117 | name: demo 118 | chdir: null 119 | override_dirname: +experiment=demo 120 | id: ??? 121 | num: ??? 
122 | config_name: null 123 | env_set: {} 124 | env_copy: [] 125 | config: 126 | override_dirname: 127 | kv_sep: '=' 128 | item_sep: ',' 129 | exclude_keys: [] 130 | runtime: 131 | version: 1.3.2 132 | version_base: '1.3' 133 | cwd: /opt/data/private/A2PM-git/A2PM-MESA/scripts 134 | config_sources: 135 | - path: hydra.conf 136 | schema: pkg 137 | provider: hydra 138 | - path: /opt/data/private/A2PM-git/A2PM-MESA/conf 139 | schema: file 140 | provider: main 141 | - path: '' 142 | schema: structured 143 | provider: schema 144 | output_dir: /opt/data/private/A2PM-git/A2PM-MESA/scripts/outputs/2024-11-03/15-43-14 145 | choices: 146 | experiment: demo 147 | evaler: instance_eval 148 | geo_area_matcher: egam 149 | area_matcher: dmesa 150 | point_matcher: spsg_indoor 151 | dataset: demo_pair 152 | hydra/env: default 153 | hydra/callbacks: null 154 | hydra/job_logging: default 155 | hydra/hydra_logging: default 156 | hydra/hydra_help: default 157 | hydra/help: default 158 | hydra/sweeper: basic 159 | hydra/launcher: basic 160 | hydra/output: default 161 | verbose: false 162 | -------------------------------------------------------------------------------- /scripts/outputs/2024-11-03/15-43-14/.hydra/overrides.yaml: -------------------------------------------------------------------------------- 1 | - +experiment=demo 2 | -------------------------------------------------------------------------------- /scripts/outputs/2024-11-03/15-43-14/demo.log: -------------------------------------------------------------------------------- 1 | [2024-11-03 15:43:16,824][numexpr.utils][INFO] - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8. 2 | [2024-11-03 15:43:16,824][numexpr.utils][INFO] - NumExpr defaulting to 8 threads. 
3 | -------------------------------------------------------------------------------- /scripts/test_a2pm.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-12 20:31:50 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-07-26 16:29:24 6 | FilePath: /A2PM-MESA/scripts/test_a2pm.py 7 | Description: test hydra-powered a2pm 8 | 9 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 10 | ''' 11 | import sys 12 | sys.path.append('..') 13 | 14 | import os 15 | import hydra 16 | from omegaconf import DictConfig, OmegaConf 17 | from loguru import logger 18 | 19 | from point_matchers.abstract_point_matcher import AbstractPointMatcher 20 | from dataloader.abstract_dataloader import AbstractDataloader 21 | from area_matchers.abstract_am import AbstractAreaMatcher 22 | from geo_area_matchers.abstract_gam import AbstractGeoAreaMatcher 23 | from metric.instance_eval import InstanceEval 24 | from utils.common import validate_type 25 | from utils.geo import list_of_corrs2corr_list 26 | 27 | import random 28 | # fix random seed 29 | random.seed(2) 30 | 31 | import torch 32 | # fix random seed 33 | torch.manual_seed(2) 34 | 35 | @hydra.main(version_base=None, config_path="../conf") 36 | def test(cfg: DictConfig) -> None: 37 | """ 38 | Test A2PM 39 | """ 40 | 41 | # set full error 42 | os.environ["HYDRA_FULL_ERROR"] = '1' 43 | 44 | OmegaConf.resolve(cfg) 45 | 46 | if cfg.verbose==0: 47 | logger.remove() 48 | logger.add(sys.stdout, level="SUCCESS") 49 | elif cfg.verbose==1: 50 | logger.remove() 51 | logger.add(sys.stdout, level="INFO") 52 | else: 53 | raise NotImplementedError(f"verbose {cfg.verbose} not supported") 54 | 55 | logger.info(f"\n{OmegaConf.to_yaml(cfg)}") 56 | 57 | # load point matcher 58 | pmer = hydra.utils.instantiate(cfg.point_matcher) 59 | validate_type(pmer, AbstractPointMatcher) 60 | 61 | # load dataloader 62 | dataloader = hydra.utils.instantiate(cfg.dataset) 63 | 
validate_type(dataloader, AbstractDataloader) 64 | 65 | # load evaler 66 | evaler = hydra.utils.instantiate(cfg.evaler) 67 | evaler.init_data_loader(dataloader, cfg.eval_from_size_W, cfg.eval_from_size_H) 68 | 69 | # test pmer 70 | if cfg.dataset_name in ['ScanNet', 'KITTI']: # no need padding 71 | img0, img1, _, _ = dataloader.load_images(cfg.crop_size_W, cfg.crop_size_H) 72 | pmer.set_corr_num_init(cfg.match_num) 73 | ori_corrs = pmer.match(img0, img1, None, None) 74 | ori_corrs = dataloader.tune_corrs_size_to_eval(ori_corrs, cfg.crop_size_W, cfg.crop_size_H, cfg.eval_from_size_W, cfg.eval_from_size_H) 75 | 76 | elif cfg.dataset_name in ['MegaDepth', 'YFCC', 'ETH3D']: # need padding 77 | img0, mask0, match_in_W0, match_in_H0,\ 78 | img1, mask1, match_in_W1, match_in_H1 = dataloader.load_images(cfg.crop_size_W, cfg.crop_size_H, PMer=True) 79 | pmer.set_corr_num_init(cfg.match_num) 80 | ori_corrs = pmer.match(img0, img1, mask0, mask1) 81 | ori_corrs = dataloader.tune_corrs_size_to_eval(ori_corrs,\ 82 | match_in_W0, match_in_H0,\ 83 | match_in_W1, match_in_H1) # eval size is the same as the original image size 84 | else: 85 | raise NotImplementedError(f"dataset {cfg.dataset_name} not supported") 86 | 87 | # test amer 88 | amer = hydra.utils.instantiate(cfg.area_matcher) 89 | logger.info(f"amer: {amer.name()}") 90 | validate_type(amer, AbstractAreaMatcher) 91 | area_matches0, area_matches1 = amer.area_matching(dataloader, cfg.out_path) 92 | 93 | logger.success(f"area matching done, area_matches len: {len(area_matches0)}") 94 | 95 | if cfg.test_area_acc: 96 | # test area accuracy 97 | evaler.eval_area_overlap_ratio(area_matches0, area_matches1, 'am') 98 | 99 | # test gam 100 | gamer = hydra.utils.instantiate(cfg.geo_area_matcher) 101 | validate_type(gamer, AbstractGeoAreaMatcher) 102 | logger.info(f"gamer: {gamer.name()}") 103 | gamer.init_gam( 104 | dataloader=dataloader, 105 | point_matcher=pmer, 106 | ori_corrs=ori_corrs, 107 | out_path=cfg.out_path 108 | ) 109 | 
alpha_corrs_dict, alpha_inlier_idxs_dict, _ = gamer.geo_area_matching_refine(area_matches0, area_matches1) 110 | 111 | logger.success(f"geo area matching done") 112 | for alpha in alpha_corrs_dict.keys(): 113 | logger.success(f"for alpha: {alpha}, areas num: {len(alpha_inlier_idxs_dict[alpha])}") 114 | 115 | if cfg.test_area_acc: 116 | # test area accuracy 117 | for alpha in alpha_inlier_idxs_dict.keys(): 118 | # get inlier area for each alpha 119 | inlier_areas0 = [area_matches0[i] for i in alpha_inlier_idxs_dict[alpha]] 120 | inlier_areas1 = [area_matches1[i] for i in alpha_inlier_idxs_dict[alpha]] 121 | evaler.eval_area_overlap_ratio(inlier_areas0, inlier_areas1, f'am+gam-{alpha}') 122 | 123 | # test point matching accuracy 124 | if cfg.test_pm_acc: 125 | thds = cfg.pm_acc_thds 126 | 127 | # for pmer 128 | logger.success(f"ori corrs matching accuracy from pmer {pmer.name()} are: ") 129 | evaler.eval_point_match(ori_corrs, 'pm', thds) 130 | 131 | # for a2pmer 132 | for alpha in alpha_corrs_dict.keys(): 133 | logger.success(f"for alpha: {alpha} of {cfg.name}, matching accuracies are: ") 134 | corrs = list_of_corrs2corr_list(alpha_corrs_dict[alpha]) 135 | evaler.eval_point_match(corrs, f'a2pm-{alpha}', thds) 136 | 137 | 138 | # test pose error 139 | if cfg.test_pose_err: 140 | # for pmer 141 | logger.success(f"ori corrs pose error from pmer {pmer.name()} are: ") 142 | pose_err = evaler.eval_pose_error(ori_corrs, 'pm') 143 | 144 | # for a2pmer 145 | for alpha in alpha_corrs_dict.keys(): 146 | logger.success(f"for alpha: {alpha} of {cfg.name}, pose errors are: ") 147 | corrs = list_of_corrs2corr_list(alpha_corrs_dict[alpha]) 148 | pose_err = evaler.eval_pose_error(corrs, f'a2pm-{alpha}') 149 | 150 | if __name__ == "__main__": 151 | test() 152 | pass -------------------------------------------------------------------------------- /scripts/test_in_dev.sh: -------------------------------------------------------------------------------- 1 | ### 2 | # @Author: EasonZhang 
3 | # @Date: 2024-06-17 22:40:17 4 | # @LastEditors: Easonyesheng preacher@sjtu.edu.cn 5 | # @LastEditTime: 2024-07-27 15:35:46 6 | # @FilePath: /SA2M/hydra-mesa/scripts/test_in_dev.sh 7 | # @Description: TBD 8 | # 9 | # Copyright (c) 2024 by EasonZhang, All Rights Reserved. 10 | ### 11 | 12 | # ScanNet 13 | dataset=ScanNet 14 | scene=scene0720_00 15 | pair0=180 16 | pair1=2580 17 | 18 | # MESA+DKM+ScanNet - tested 19 | python test_a2pm.py \ 20 | +experiment=a2pm_mesa_egam_dkm_scannet \ 21 | name=test \ 22 | dataset_name=$dataset \ 23 | dataset.scene_name=$scene \ 24 | dataset.image_name0=$pair0 \ 25 | dataset.image_name1=$pair1 \ 26 | 27 | # # DMESA+DKM+ScanNet - tested 28 | # python test_a2pm.py \ 29 | # +experiment=a2pm_dmesa_egam_dkm_scannet \ 30 | # name=test \ 31 | # dataset_name=$dataset \ 32 | # dataset.scene_name=$scene \ 33 | # dataset.image_name0=$pair0 \ 34 | # dataset.image_name1=$pair1 \ 35 | 36 | # DMESA+SPSG+ScanNet - tested 37 | # python test_a2pm.py \ 38 | # +experiment=a2pm_dmesa_egam_spsg_scannet \ 39 | # name=test \ 40 | # dataset_name=$dataset \ 41 | # dataset.scene_name=$scene \ 42 | # dataset.image_name0=$pair0 \ 43 | # dataset.image_name1=$pair1 \ 44 | 45 | # MESA+SPSG+ScanNet - tested 46 | # python test_a2pm.py \ 47 | # +experiment=a2pm_mesa_egam_spsg_scannet \ 48 | # name=test \ 49 | # dataset_name=$dataset \ 50 | # dataset.scene_name=$scene \ 51 | # dataset.image_name0=$pair0 \ 52 | # dataset.image_name1=$pair1 \ 53 | 54 | # MESA+loftr+ScanNet - tested 55 | # python test_a2pm.py \ 56 | # +experiment=a2pm_mesa_egam_loftr_scannet \ 57 | # name=test \ 58 | # dataset_name=$dataset \ 59 | # dataset.scene_name=$scene \ 60 | # dataset.image_name0=$pair0 \ 61 | # dataset.image_name1=$pair1 \ 62 | 63 | # DMESA+LoFTR+ScanNet - tested 64 | # python test_a2pm.py \ 65 | # +experiment=a2pm_dmesa_egam_loftr_scannet \ 66 | # name=test \ 67 | # dataset_name=$dataset \ 68 | # dataset.scene_name=$scene \ 69 | # dataset.image_name0=$pair0 \ 70 | # 
dataset.image_name1=$pair1 \ 71 | 72 | ############################################ 73 | # MegaDepth 74 | dataset=MegaDepth 75 | scene=md # no use 76 | pair0='0022_0.1_0.3_1401' 77 | pair1='0022_0.1_0.3_810' 78 | 79 | # # MESA+DKM+MegaDepth - tested 80 | # python test_a2pm.py \ 81 | # +experiment=a2pm_mesa_egam_dkm_megadepth \ 82 | # name=test \ 83 | # dataset_name=$dataset \ 84 | # dataset.scene_name=$scene \ 85 | # dataset.image_name0=$pair0 \ 86 | # dataset.image_name1=$pair1 \ 87 | 88 | # # DMESA+DKM+MegaDepth - tested 89 | # python test_a2pm.py \ 90 | # +experiment=a2pm_dmesa_egam_dkm_megadepth \ 91 | # name=test \ 92 | # dataset_name=$dataset \ 93 | # dataset.scene_name=$scene \ 94 | # dataset.image_name0=$pair0 \ 95 | # dataset.image_name1=$pair1 \ 96 | 97 | 98 | # DMESA+SPSG+MegaDepth - tested 99 | # python test_a2pm.py \ 100 | # +experiment=a2pm_dmesa_egam_spsg_megadepth \ 101 | # name=test \ 102 | # dataset_name=$dataset \ 103 | # dataset.scene_name=$scene \ 104 | # dataset.image_name0=$pair0 \ 105 | # dataset.image_name1=$pair1 \ 106 | 107 | # MESA+SPSG+MegaDepth - tested 108 | # python test_a2pm.py \ 109 | # +experiment=a2pm_mesa_egam_spsg_megadepth \ 110 | # name=test \ 111 | # dataset_name=$dataset \ 112 | # dataset.scene_name=$scene \ 113 | # dataset.image_name0=$pair0 \ 114 | # dataset.image_name1=$pair1 \ 115 | 116 | # DMESA+LoFTR+MegaDepth - tested 117 | # python test_a2pm.py \ 118 | # +experiment=a2pm_dmesa_egam_loftr_megadepth \ 119 | # name=test \ 120 | # dataset_name=$dataset \ 121 | # dataset.scene_name=$scene \ 122 | # dataset.image_name0=$pair0 \ 123 | # dataset.image_name1=$pair1 \ 124 | 125 | # # MESA+LoFTR+MegaDepth - tested 126 | # python test_a2pm.py \ 127 | # +experiment=a2pm_mesa_egam_loftr_megadepth \ 128 | # name=test \ 129 | # dataset_name=$dataset \ 130 | # dataset.scene_name=$scene \ 131 | # dataset.image_name0=$pair0 \ 132 | # dataset.image_name1=$pair1 \ 
-------------------------------------------------------------------------------- /segmentor/ImgSAMSeg.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | Author: EasonZhang 4 | Date: 2023-06-15 17:06:04 5 | LastEditors: EasonZhang 6 | LastEditTime: 2024-09-11 15:43:21 7 | FilePath: /SA2M/hydra-mesa/segmentor/ImgSAMSeg.py 8 | Description: A script to segment the image using SAM 9 | 10 | Copyright (c) 2023 by EasonZhang, All Rights Reserved. 11 | ''' 12 | 13 | import sys 14 | sys.path.append("..") 15 | 16 | import os 17 | 18 | import argparse 19 | import numpy as np 20 | from loguru import logger 21 | logger.remove()#删去import logger之后自动产生的handler,不删除的话会出现重复输出的现象 22 | handler_id = logger.add(sys.stderr, level="INFO")#添加一个可以修改控制的handler 23 | import cv2 24 | 25 | from segmentor.SAMSeger import SAMSeger 26 | 27 | 28 | # current file path 29 | current_path = os.path.dirname(os.path.abspath(__file__)) 30 | 31 | # SAM configs 32 | SAM_configs = { 33 | "SAM_name": "SAM", 34 | "W": 640, 35 | "H": 480, 36 | "sam_model_type": "vit_h", 37 | "sam_model_path": f"{current_path}/../SAM/sam_vit_h_4b8939.pth", 38 | "save_folder": "", 39 | "points_per_side": 16, 40 | } 41 | 42 | # SAM2 configs 43 | SAM2_configs = { 44 | "SAM_name": "SAM2", 45 | "W": 640, 46 | "H": 480, 47 | "sam_model_type": "sam2_hiera_l.yaml", 48 | "sam_model_path": "/opt/data/private/SAM2/segment-anything-2/checkpoints/sam2_hiera_large.pt", 49 | "save_folder": "", 50 | "points_per_side": 16, 51 | } 52 | 53 | def SMASeg(args): 54 | """ 55 | """ 56 | 57 | if args.sam_name == "SAM": 58 | seg_configs = SAM_configs 59 | elif args.sam_name == "SAM2": 60 | seg_configs = SAM2_configs 61 | else: 62 | raise ValueError("Invalid SAM name") 63 | 64 | seg_configs["save_folder"] = args.save_folder 65 | seg_configs["W"] = args.W 66 | seg_configs["H"] = args.H 67 | 68 | sam_seger = SAMSeger(configs=seg_configs) 69 | 70 | img_path = args.img_path 71 | sam_res = 
sam_seger.segment(img_path=img_path, save_name=args.save_name, save_img_flag=False) 72 | 73 | def args_achieve(): 74 | """ 75 | """ 76 | parser = argparse.ArgumentParser(description="A script to segment the image using SAM") 77 | parser.add_argument("--sam_name", type=str, default="SAM", help="The name of the SAM model") 78 | parser.add_argument("--img_path", type=str, default="/data0/zys/A2PM/data/ScanData/scene0000_00/color/12.jpg", help="The path of the image to be segmented") 79 | parser.add_argument("--save_folder", type=str, default="/data0/zys/A2PM/testAG/res", help="The folder to save the segmented image") 80 | parser.add_argument("--save_name", type=str, default="SAMRes", help="The name of the segmented image") 81 | parser.add_argument("--W", type=int, default=640, help="The width of the image") 82 | parser.add_argument("--H", type=int, default=480, help="The height of the image") 83 | args = parser.parse_args() 84 | return args 85 | 86 | 87 | 88 | if __name__ == "__main__": 89 | args = args_achieve() 90 | SMASeg(args) -------------------------------------------------------------------------------- /segmentor/SAMSeger.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2023-05-17 15:57:27 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-09-11 15:14:53 6 | FilePath: /SA2M/hydra-mesa/segmentor/SAMSeger.py 7 | Description: SAM-based Image Segmenter 8 | 9 | Copyright (c) 2023 by EasonZhang, All Rights Reserved. 
10 | ''' 11 | 12 | import sys 13 | sys.path.append("..") 14 | 15 | import os 16 | import os.path as osp 17 | import numpy as np 18 | import cv2 19 | from loguru import logger 20 | from .seg_utils import MaskViewer 21 | import torch 22 | 23 | # TODO: Modify to your SAM path 24 | from SAM.segment_anything import sam_model_registry, SamAutomaticMaskGenerator 25 | from SAM2.sam2.build_sam import build_sam2 26 | from SAM2.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator 27 | 28 | class SAMSeger(object): 29 | """ 30 | """ 31 | def __init__(self, configs={}) -> None: 32 | """ 33 | Args: 34 | W 35 | H 36 | sam_model_type 37 | sam_model_path 38 | save_folder 39 | points_per_side 40 | """ 41 | self.W = configs["W"] 42 | self.H = configs["H"] 43 | self.SAM_name = configs["SAM_name"] 44 | assert self.SAM_name in ["SAM", "SAM2"] 45 | self.sam_model_type = configs["sam_model_type"] 46 | self.sam_model_path = configs["sam_model_path"] 47 | self.save_folder = configs["save_folder"] 48 | self.points_per_side = configs["points_per_side"] 49 | 50 | if self.SAM_name == "SAM": 51 | self.sam_model = sam_model_registry[self.sam_model_type](checkpoint=self.sam_model_path) 52 | self.sam_mask_generator = SamAutomaticMaskGenerator( 53 | model=self.sam_model, 54 | points_per_side=self.points_per_side, 55 | ) 56 | elif self.SAM_name == "SAM2": 57 | checkpoint = self.sam_model_path 58 | torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() 59 | model_cfg = self.sam_model_type 60 | sam2 = build_sam2(model_cfg, checkpoint, device='cuda', apply_postprocessing=False) 61 | mask_g = SAM2AutomaticMaskGenerator(sam2) 62 | self.sam_mask_generator = mask_g 63 | else: 64 | raise ValueError("SAM_name must be SAM or SAM2") 65 | 66 | self.viewer = MaskViewer(self.save_folder) 67 | 68 | def img_loader_SAM(self, path): 69 | """ 70 | """ 71 | image = cv2.imread(path, -1) 72 | image = cv2.resize(image, (self.W, self.H)) 73 | logger.info(f"load image as {image.shape}, {image.dtype}") 
74 | return image 75 | 76 | def img_loader_SAM2(self, path): 77 | import PIL.Image as Image 78 | img = Image.open(path) 79 | img = img.resize((self.W, self.H)) 80 | img = np.array(img.convert('RGB')) 81 | 82 | return img 83 | 84 | def segment(self, img_path, sort_flag=True, save_flag=True, save_img_flag=False, save_name=""): 85 | """ 86 | Args: 87 | img_path 88 | save_name 89 | Returns: 90 | masks : a list of mask 91 | mask = { 92 | segmentation : the mask 93 | area : the area of the mask in pixels 94 | bbox : the boundary box of the mask in XYWH format 95 | predicted_iou : the model's own prediction for the quality of the mask 96 | point_coords : the sampled input point that generated this mask 97 | stability_score : an additional measure of mask quality 98 | crop_box : the crop of the image used to generate this mask in XYWH format 99 | } 100 | """ 101 | if self.SAM_name == "SAM": 102 | img = self.img_loader_SAM(img_path) 103 | elif self.SAM_name == "SAM2": 104 | img = self.img_loader_SAM2(img_path) 105 | 106 | masks = self.sam_mask_generator.generate(img) 107 | 108 | if sort_flag: 109 | masks.sort(key=lambda x: x["area"], reverse=True) 110 | 111 | if save_flag: 112 | if save_name == "": 113 | save_name = osp.splitext(osp.basename(img_path))[0] 114 | save_full_name = osp.join(self.save_folder, save_name) 115 | if not osp.exists(self.save_folder): 116 | logger.info(f"create folder {self.save_folder}") 117 | os.makedirs(self.save_folder) 118 | 119 | np.save(save_full_name, masks) 120 | logger.info(f"save masks to {save_full_name}.npy") 121 | 122 | logger.info(f"segment {img_path} as {len(masks)} masks") 123 | 124 | if save_img_flag: 125 | self.viewer.draw_multi_masks_in_one(masks, self.W, self.H, name=save_name, key="segmentation") 126 | 127 | return masks 128 | 129 | def draw_masks(self, masks): 130 | """ 131 | """ 132 | for i, mask in enumerate(masks): 133 | self._draw_single_mask(mask["segmentation"], f"{i}") 134 | 135 | def _draw_single_mask(self, mask, name, 
flag=False): 136 | """ 137 | """ 138 | if not flag: return 139 | 140 | img = np.zeros((self.H, self.W), dtype=np.uint8) 141 | img[mask] = 255 142 | save_name = osp.join(self.save_folder, f"filtered_{name}.jpg") 143 | cv2.imwrite(save_name, img) -------------------------------------------------------------------------------- /segmentor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Easonyesheng/A2PM-MESA/920393b317529b6f3feab2152de5dc61102a6e78/segmentor/__init__.py -------------------------------------------------------------------------------- /segmentor/sam_seg.sh: -------------------------------------------------------------------------------- 1 | 2 | ### 3 | # @Author: EasonZhang 4 | # @Date: 2024-07-23 10:44:13 5 | # @LastEditors: EasonZhang 6 | # @LastEditTime: 2024-07-23 10:50:31 7 | # @FilePath: /SA2M/hydra-mesa/segmentor/sam_seg.sh 8 | # @Description: TBD 9 | # 10 | # Copyright (c) 2024 by EasonZhang, All Rights Reserved. 
11 | ### 12 | 13 | cuda_id=0 14 | img_path=/opt/data/private/SA2M/hydra-mesa/SAM/demo/src/assets/data/dogs.jpg 15 | save_folder=/opt/data/private/SA2M/hydra-mesa/segmentor/seg_res 16 | save_name=seg_res_dogs.jpg 17 | 18 | CUDA_VISIBLE_DEVICES=$cuda_id python ImgSAMSeg.py --img_path $img_path --save_folder $save_folder --save_name $save_name --sam_name SAM 19 | -------------------------------------------------------------------------------- /segmentor/seg_res/seg_res_dogs.jpg.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Easonyesheng/A2PM-MESA/920393b317529b6f3feab2152de5dc61102a6e78/segmentor/seg_res/seg_res_dogs.jpg.npy -------------------------------------------------------------------------------- /segmentor/seg_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-07-19 20:42:13 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-09-11 15:16:03 6 | FilePath: /SA2M/hydra-mesa/segmentor/seg_utils.py 7 | Description: TBD 8 | 9 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 
10 | ''' 11 | 12 | import cv2 13 | import numpy as np 14 | import os.path as osp 15 | 16 | class MaskViewer(object): 17 | """ Mask Visualization 18 | """ 19 | def __init__(self, save_path) -> None: 20 | """ 21 | """ 22 | self.save_path = save_path 23 | 24 | def draw_single_mask(self, mask, bbox, name): 25 | """ 26 | """ 27 | mask_show = mask.astype(np.uint8) * 255 28 | # to color img 29 | mask_show = cv2.cvtColor(mask_show, cv2.COLOR_GRAY2BGR) 30 | # draw bbox 31 | cv2.rectangle(mask_show, (bbox[0], bbox[2]), (bbox[1], bbox[3]), (0, 0, 255), 2) 32 | 33 | cv2.imwrite(osp.join(self.save_path, f"{name}.jpg"), mask_show) 34 | 35 | 36 | def draw_multi_masks_in_one(self, area_info_list, W, H, name="", key="mask"): 37 | """ 38 | """ 39 | masks_show = np.zeros((H, W, 3), dtype=np.uint8) 40 | exsit_colors = [] 41 | 42 | for area_info in area_info_list: 43 | mask = area_info[key].astype(np.uint8) 44 | mask = cv2.resize(mask, (W, H)) 45 | color = np.random.randint(0, 255, size=3) 46 | while tuple(color.tolist()) in exsit_colors: 47 | color = np.random.randint(0, 255, size=3) 48 | masks_show[mask > 0] = color 49 | color = tuple(color.tolist()) 50 | 51 | if key == "mask": 52 | bbox = area_info["area_bbox"] 53 | # turn color to scalar 54 | # draw bbox 55 | cv2.rectangle(masks_show, (bbox[0], bbox[2]), (bbox[1], bbox[3]), color, 2) 56 | exsit_colors.append(color) 57 | 58 | cv2.imwrite(osp.join(self.save_path, f"{name}.png"), masks_show) 59 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-12 22:42:41 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-06-20 23:33:40 6 | FilePath: /SA2M/hydra-mesa/utils/common.py 7 | Description: TBD 8 | 9 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 
10 | ''' 11 | 12 | from typing import Any, Type 13 | from loguru import logger 14 | 15 | # from nuplan 16 | def validate_type(instantiated_class: Any, desired_type: Type[Any]) -> None: 17 | """ 18 | Validate that constructed type is indeed the desired one 19 | :param instantiated_class: class that was created 20 | :param desired_type: type that the created class should have 21 | """ 22 | assert isinstance( 23 | instantiated_class, desired_type 24 | ), f"Class to be of type {desired_type}, but is {type(instantiated_class)}!" 25 | 26 | def test_dir_if_not_create(path): 27 | """ create folder 28 | 29 | Args: 30 | 31 | Returns: 32 | """ 33 | import os 34 | if os.path.isdir(path): 35 | return True 36 | else: 37 | logger.info(f'Create New Folder: {path}') 38 | os.makedirs(path) 39 | return True 40 | 41 | def clean_mat_idx(mat, idx): 42 | """ delete the mat value in mat[idx, :] and mat[:, idx] 43 | shrink the mat shape by 1 44 | """ 45 | assert mat.shape[0] > idx, f"mat.shape: {mat.shape} < idx: {idx}" 46 | mat = np.delete(mat, idx, axis=0) 47 | mat = np.delete(mat, idx, axis=1) 48 | if mat.shape[0] == 0: 49 | mat = None 50 | return mat 51 | 52 | def expand_mat_by1(mat): 53 | """ expand the mat by 1 (add a row and a column) 54 | """ 55 | if mat is None: 56 | mat = np.zeros((1, 1)) 57 | return mat 58 | 59 | mat = np.pad(mat, ((0, 1), (0, 1)), 'constant', constant_values=0) 60 | return mat -------------------------------------------------------------------------------- /utils/load.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-13 23:07:56 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-06-28 19:42:21 6 | FilePath: /SA2M/hydra-mesa/utils/load.py 7 | Description: load utils 8 | 9 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 
10 | ''' 11 | import numpy as np 12 | import yaml 13 | import cv2 14 | from loguru import logger 15 | import os 16 | import glob 17 | from itertools import combinations 18 | 19 | def load_cv_img_resize(img_path, W, H, mode=0): 20 | """ 21 | """ 22 | assert os.path.exists(img_path), f"img path {img_path} not exists" 23 | img = cv2.imread(img_path, mode) 24 | logger.info(f"load img from {img_path} with size {img.shape} resized to {W} x {H}") 25 | if mode == 1: 26 | H_ori, W_ori, _ = img.shape 27 | else: 28 | H_ori, W_ori = img.shape 29 | scale_u = W / W_ori 30 | scale_v = H / H_ori 31 | # print(f"ori W, H: {W_ori} x {H_ori}, with scale: {scale_u} , {scale_v}") 32 | img = cv2.resize(img, (W, H), cv2.INTER_AREA) # type: ignore 33 | return img, [scale_u, scale_v] 34 | 35 | def load_cv_depth(depth_path): 36 | """ for ScanNet Dataset 37 | """ 38 | assert os.path.exists(depth_path), f"depth path {depth_path} not exists" 39 | logger.info(f"load depth from {depth_path}") 40 | return cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) 41 | 42 | def load_K_txt(intri_path, scale=[1, 1]): 43 | """For ScanNet K 44 | Args: 45 | scale = [scale_u, scale_v] 46 | """ 47 | assert os.path.exists(intri_path), f"intri path {intri_path} not exists" 48 | K = np.loadtxt(intri_path) 49 | fu = K[0,0] * scale[0] 50 | fv = K[1,1] * scale[1] 51 | cu = K[0,2] * scale[0] 52 | cv = K[1,2] * scale[1] 53 | K_ = np.array([[fu, 0, cu], [0, fv, cv], [0, 0, 1]]) 54 | 55 | logger.info(f"load K from {intri_path} with scale {scale} is \n {K_}") 56 | return np.matrix(K_) 57 | 58 | def load_pose_txt(pose_path): 59 | """For ScanNet pose: cam2world 60 | txt file with 61 | P = 62 | |R t| 63 | |0 1| 64 | 65 | Returns: 66 | P : np.mat 67 | """ 68 | assert os.path.exists(pose_path), f"pose path {pose_path} not exists" 69 | P = np.loadtxt(pose_path) 70 | P = np.matrix(P) 71 | logger.info(f"load pose is \n{P}") 72 | return P 73 | 74 | -------------------------------------------------------------------------------- 
/utils/transformer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2023-09-22 17:04:24 4 | LastEditors: EasonZhang 5 | LastEditTime: 2023-12-27 14:56:40 6 | FilePath: /A2PM/utils/transformer.py 7 | Description: including some functions for transforming 8 | - transform SEEM segmentation results to SAM format 9 | - SAM format dict: 10 | - "segmentation" : segmentation binary mask 11 | - "bbox" : bounding box of the segmentation: [x1, y1, x2, y2] 12 | 13 | Copyright (c) 2023 by EasonZhang, All Rights Reserved. 14 | ''' 15 | 16 | import numpy as np 17 | import cv2 18 | import os 19 | from loguru import logger 20 | 21 | 22 | class SEEM2SAM(object): 23 | """ one folder a time 24 | """ 25 | cfg_dft = { 26 | "root_path": "", 27 | "floder_name": "", 28 | "out_path": "", 29 | } 30 | 31 | def __init__(self, cfg={}): 32 | """ 33 | """ 34 | self.cfg = {**self.cfg_dft, **cfg} 35 | self.root_path = self.cfg["root_path"] 36 | self.floder_name = self.cfg["floder_name"] 37 | self.out_path = self.cfg["out_path"] 38 | 39 | def load_seem_seg_img_name_list(self): 40 | """ load all png images in the folder 41 | 42 | """ 43 | img_name_list = [] 44 | for file_ in os.listdir(f"{self.root_path}/{self.floder_name}"): 45 | if file_.endswith(".png"): 46 | img_name_list.append(file_) 47 | return img_name_list 48 | 49 | def load_seem_seg_img(self, img_name): 50 | """ load png image 51 | 52 | """ 53 | img = cv2.imread(f"{self.root_path}/{self.floder_name}/{img_name}", -1) 54 | logger.info(f"load {img_name} with shape {img.shape}") 55 | return img 56 | 57 | def trans_png2npy(self, img, img_name="", save=True): 58 | """ transform segmentation png image to dict and save as npy file 59 | Args: 60 | img (np.ndarray): segmentation png image 61 | - each pixel value is the class index 62 | - get the same-values pixels as the segmentation 63 | - get the bounding box of the segmentation 64 | Returns: 65 | a list of dict: 
segmentation dict 66 | - "segmentation" : segmentation binary mask 67 | - "bbox" : bounding box of the segmentation: [x, y, w, h] 68 | """ 69 | # get all class index 70 | class_index_list = np.unique(img) 71 | 72 | # for each class index, get the segmentation and bounding box 73 | segmentation_dicts = [] 74 | for class_index in class_index_list: 75 | # if class_index == 0: 76 | # continue 77 | # get the segmentation 78 | segmentation = np.zeros_like(img) 79 | segmentation[img == class_index] = 1 80 | 81 | segmentation = self.get_connection_area(segmentation) 82 | 83 | # get the bounding box 84 | bbox = self.get_bbox(segmentation) # [x, y, w, h] 85 | 86 | # # draw the bounding box 87 | # save_path = f"/data2/zys/A2PM/testAGC/{class_index}.jpg" 88 | # color = np.random.randint(0, 255, (3)) 89 | # x1 = bbox[0] 90 | # y1 = bbox[1] 91 | # x2 = bbox[0] + bbox[2] 92 | # y2 = bbox[1] + bbox[3] 93 | # # covert segmentation to color image 94 | # segmentation_show = np.zeros_like(segmentation) 95 | # segmentation_show = cv2.cvtColor(segmentation_show, cv2.COLOR_GRAY2BGR) 96 | # segmentation_show[segmentation == 1] = color 97 | # cv2.rectangle(segmentation_show, (x1, y1), (x2, y2), (int(color[0]), int(color[1]), int(color[2])), 2) 98 | # cv2.imwrite(save_path, segmentation_show) 99 | 100 | # save as dict 101 | segmentation_dict = { 102 | "segmentation": segmentation, 103 | "bbox": bbox, 104 | } 105 | segmentation_dicts.append(segmentation_dict) 106 | 107 | if save: 108 | save_path = f"{self.out_path}/{self.floder_name}" 109 | logger.info(f"save segmentation dicts to {save_path}") 110 | if not os.path.exists(save_path): 111 | os.makedirs(save_path) 112 | 113 | np.save(f"{save_path}/{img_name[:-4]}.npy", segmentation_dicts) 114 | 115 | return segmentation_dicts 116 | 117 | def get_connection_area(self, segmentation): 118 | """ get the connection area of the segmentation, only save the biggest one 119 | Args: 120 | segmentation (np.ndarray): segmentation binary mask 121 | """ 122 
| 123 | bin_img = self._get_bin_img(segmentation) 124 | 125 | # get the connection area 126 | num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(bin_img, connectivity=8) 127 | 128 | # get the biggest area 129 | max_area = 0 130 | max_area_idx = 0 131 | for i in range(1, num_labels): 132 | area = stats[i, cv2.CC_STAT_AREA] 133 | if area > max_area: 134 | max_area = area 135 | max_area_idx = i 136 | 137 | # return the biggest area img 138 | max_area_img = np.zeros_like(bin_img) 139 | max_area_img[labels == max_area_idx] = 1 140 | 141 | return max_area_img 142 | 143 | def _get_bin_img(self, img): 144 | """ get binary image 145 | """ 146 | H, W = img.shape[0], img.shape[1] 147 | bin_img = np.zeros((H, W), dtype=np.uint8) 148 | 149 | bin_img[np.where(img==1)] = 255 150 | 151 | _, bin_img_rt = cv2.threshold(bin_img, 10, 255, cv2.THRESH_BINARY) 152 | 153 | kernel_close = np.ones((13,13), np.uint8) 154 | close = cv2.morphologyEx(bin_img_rt, cv2.MORPH_CLOSE, kernel_close) 155 | 156 | kernel_open = np.ones((5,5),np.uint8) 157 | opening = cv2.morphologyEx(close, cv2.MORPH_OPEN, kernel_open) 158 | 159 | # cv2.imwrite(os.path.join(self.out_path, "bin_img_" + str(pix_val)+"_"+str(name) + ".jpg"), opening) 160 | return opening 161 | 162 | def get_bbox(self, segmentation): 163 | """ get the bounding box of the segmentation 164 | Args: 165 | segmentation (np.ndarray): segmentation binary mask 166 | Returns: 167 | bbox (list): bounding box of the segmentation: [u, v, w, h] 168 | """ 169 | # get the index of the segmentation 170 | index = np.where(segmentation == 1) 171 | # u is the width-axis, v is the height-axis 172 | u_min, u_max = np.min(index[1]), np.max(index[1]) 173 | v_min, v_max = np.min(index[0]), np.max(index[0]) 174 | 175 | w = u_max - u_min 176 | h = v_max - v_min 177 | 178 | bbox = [u_min, v_min, w, h] 179 | 180 | return bbox 181 | 182 | 183 | def show_res(self, img_name, segmentation_dicts): 184 | """ show the segmentation results 185 | Args: 186 
| img (np.ndarray): original image 187 | segmentation_dicts (list): segmentation dicts 188 | """ 189 | H, W = segmentation_dicts[0]["segmentation"].shape 190 | img = np.zeros((H, W, 3), dtype=np.uint8) 191 | 192 | for segmentation_dict in segmentation_dicts: 193 | segmentation = segmentation_dict["segmentation"] 194 | bbox = segmentation_dict["bbox"] 195 | 196 | # show the segmentation 197 | color = np.random.randint(0, 255, (3)) 198 | img[segmentation == 1] = color 199 | 200 | # show the bounding box 201 | x1, y1, x2, y2 = bbox 202 | cv2.rectangle(img, (x1, y1), (x2, y2), (int(color[0]), int(color[1]), int(color[2])), 2) 203 | 204 | cv2.imwrite(f"{self.out_path}/{self.floder_name}/{img_name}", img) 205 | logger.info(f"save {img_name} with shape {img.shape}") 206 | 207 | def run(self, show_idx_end=-1): 208 | """ 209 | """ 210 | img_name_list = self.load_seem_seg_img_name_list() 211 | for i, img_name in enumerate(img_name_list): 212 | img = self.load_seem_seg_img(img_name) 213 | seg_dicts = self.trans_png2npy(img, img_name) 214 | if i <= show_idx_end: 215 | self.show_res(img_name, seg_dicts) 216 | -------------------------------------------------------------------------------- /utils/vis.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: EasonZhang 3 | Date: 2024-06-19 23:15:51 4 | LastEditors: EasonZhang 5 | LastEditTime: 2024-07-18 12:10:20 6 | FilePath: /SA2M/hydra-mesa/utils/vis.py 7 | Description: TBD 8 | 9 | Copyright (c) 2024 by EasonZhang, All Rights Reserved. 
10 | ''' 11 | import numpy as np 12 | import cv2 13 | import matplotlib.cm as cm 14 | import matplotlib 15 | import matplotlib.pyplot as plt 16 | import pandas as pd 17 | import seaborn as sns 18 | import random 19 | import copy 20 | 21 | import os 22 | from loguru import logger 23 | import collections 24 | 25 | from .img_process import img_to_color 26 | 27 | def draw_matched_area(img0, img1, area0, area1, color, out_path, name0, name1, save=True): 28 | """ 29 | """ 30 | img0 = copy.deepcopy(img0) 31 | img1 = copy.deepcopy(img1) 32 | 33 | if len(img0.shape) == 2: 34 | img0 = img_to_color(img0) 35 | if len(img1.shape) == 2: 36 | img1 = img_to_color(img1) 37 | 38 | W, H = img0.shape[1], img0.shape[0] 39 | 40 | out = stack_img(img0, img1) 41 | 42 | draw_matched_area_in_img(out, area0, area1, color) 43 | 44 | if save: 45 | cv2.imwrite(os.path.join(out_path, f"{name0}_{name1}_matched_area.png"), out) 46 | logger.info(f"save matched area to {os.path.join(out_path, f'{name0}_{name1}_matched_area.png')}") 47 | 48 | return out 49 | 50 | def draw_matched_area_list(img0, img1, area0_list, area1_list, out_path, name0, name1, save=True): 51 | """ 52 | """ 53 | n = len(area0_list) 54 | assert n == len(area1_list) 55 | 56 | color_map = get_n_colors(n) 57 | 58 | if len(img0.shape) == 2: 59 | img0 = img_to_color(img0) 60 | if len(img1.shape) == 2: 61 | img1 = img_to_color(img1) 62 | 63 | W, H = img0.shape[1], img0.shape[0] 64 | out = stack_img(img0, img1) 65 | 66 | flag = True 67 | for i in range(n): 68 | color = color_map[i] 69 | flag = draw_matched_area_in_img(out, area0_list[i], area1_list[i], color) 70 | 71 | if save: 72 | cv2.imwrite(os.path.join(out_path, f"{name0}_{name1}_matched_areas.png"), out) 73 | 74 | return flag 75 | 76 | def draw_matched_area_with_mkpts(img0, img1, area0, area1, mkpts0, mkpts1, color, out_path, name0, name1, save=True): 77 | """ 78 | """ 79 | if len(img0.shape) == 2: 80 | img0 = img_to_color(img0) 81 | if len(img1.shape) == 2: 82 | img1 = 
img_to_color(img1) 83 | 84 | W, H = img0.shape[1], img0.shape[0] 85 | 86 | out = stack_img(img0, img1) 87 | 88 | out = draw_matched_area_in_img(out, area0, area1, color) 89 | 90 | out = draw_mkpts_in_img(out, mkpts0, mkpts1, color) 91 | 92 | if save: 93 | cv2.imwrite(os.path.join(out_path, f"{name0}_{name1}_matched_area_kpts.png"), out) 94 | logger.info(f"save matched area to {os.path.join(out_path, f'{name0}_{name1}_matched_area_kpts.png')}") 95 | 96 | return out 97 | 98 | def paint_semantic(ins0, ins1, out_path="", name0="", name1="", save=True): 99 | """ fill color by sematic label 100 | """ 101 | assert len(ins0.shape) == len(ins1.shape) == 2 102 | 103 | H, W = ins0.shape 104 | 105 | label_list = [] 106 | label_color_dict = {} 107 | 108 | for i in range(H): 109 | for j in range(W): 110 | temp0 = ins0[i, j] 111 | temp1 = ins1[i, j] 112 | if temp0 not in label_list: 113 | label_list.append(temp0) 114 | if temp1 not in label_list: 115 | label_list.append(temp1) 116 | 117 | label_list = sorted(label_list) 118 | 119 | # print(label_list) 120 | 121 | N = len(label_list) 122 | cmaps_='gist_ncar' 123 | cmap = matplotlib.colors.ListedColormap(plt.get_cmap(cmaps_)(np.linspace(0, 1, N))) 124 | 125 | for i in range(N): 126 | # black is background 127 | if label_list[i] == 0: 128 | label_color_dict[label_list[i]] = [0, 0, 0] 129 | continue 130 | c = cmap(i) 131 | label_color_dict[label_list[i]] = [int(c[0]*255), int(c[1]*255), int(c[2]*255)] 132 | 133 | 134 | 135 | outImg0 = np.zeros((H,W,3)) 136 | # print(outImg0.shape) 137 | 138 | for i in range(H): 139 | for j in range(W): 140 | outImg0[i, j, :] = label_color_dict[ins0[i,j]] 141 | 142 | outImg1 = np.zeros((H,W,3)) 143 | # print(outImg0.shape) 144 | 145 | for i in range(H): 146 | for j in range(W): 147 | outImg1[i, j, :] = label_color_dict[ins1[i,j]] 148 | 149 | if save: 150 | cv2.imwrite(os.path.join(out_path, "{0}_color.jpg".format(name0)), outImg0) 151 | cv2.imwrite(os.path.join(out_path, 
"{0}_color.jpg".format(name1)), outImg1) 152 | 153 | return outImg0, outImg1 154 | 155 | def get_n_colors(n): 156 | """ 157 | """ 158 | label_color_dict = {} 159 | 160 | cmaps_='gist_ncar' 161 | cmap = matplotlib.colors.ListedColormap(plt.get_cmap(cmaps_)(np.linspace(0, 1, n))) 162 | 163 | for i in range(n): 164 | c = cmap(i) 165 | label_color_dict[i] = [int(c[0]*255), int(c[1]*255), int(c[2]*255)] 166 | 167 | return label_color_dict 168 | 169 | def stack_img(img0, img1): 170 | """ stack two image in horizontal 171 | Args: 172 | img0: numpy array 3 channel 173 | """ 174 | # assert img0.shape == img1.shape 175 | 176 | if len(img0.shape) == 2: 177 | img0 = img_to_color(img0) 178 | if len(img1.shape) == 2: 179 | img1 = img_to_color(img1) 180 | 181 | assert len(img0.shape) == 3 182 | 183 | W0, H0 = img0.shape[1], img0.shape[0] 184 | W1, H1 = img1.shape[1], img1.shape[0] 185 | 186 | H_s = max(H0, H1) 187 | W_s = W0 + W1 188 | 189 | out = 255 * np.ones((H_s, W_s, 3), np.uint8) 190 | 191 | try: 192 | out[:H0, :W0, :] = img0.copy() 193 | out[:H1, W0:, :] = img1.copy() 194 | except ValueError as e: 195 | logger.exception(e) 196 | logger.info(f"img0 shape is {img0.shape}, img1 shape is {img1.shape}") 197 | logger.info(f"out shape is {out.shape}") 198 | raise e 199 | 200 | return out 201 | 202 | def draw_matched_area_in_img(out, patch0, patch1, color): 203 | """ 204 | """ 205 | W = out.shape[1] // 2 206 | W = int(W) 207 | patch0 = [int(i) for i in patch0] 208 | patch1_s = [patch1[0]+W, patch1[1]+W, patch1[2], patch1[3]] 209 | try: 210 | patch1_s = [int(i) for i in patch1_s] 211 | except ValueError as e: 212 | logger.exception(e) 213 | return False 214 | 215 | 216 | # logger.info(f"patch0 are {patch0[0]}, {patch0[1]}, {patch0[2]}, {patch0[3]}") 217 | # logger.info(f"patch1 are {patch1_s[0]}, {patch1_s[1]}, {patch1_s[2]}, {patch1_s[3]}") 218 | 219 | cv2.rectangle(out, (patch0[0], patch0[2]), (patch0[1], patch0[3]), tuple(color), 3) 220 | try: 221 | cv2.rectangle(out, 
(patch1_s[0], patch1_s[2]), (patch1_s[1], patch1_s[3]), color, 3) 222 | except cv2.error: 223 | logger.exception("what?") 224 | return False 225 | 226 | line_s = [(patch0[0]+patch0[1])//2, (patch0[2]+patch0[3])//2] 227 | line_e = [(patch1_s[0]+patch1_s[1])//2, (patch1_s[2]+patch1_s[3])//2] 228 | 229 | cv2.line(out, (line_s[0], line_s[1]), (line_e[0], line_e[1]), color=color, thickness=3, lineType=cv2.LINE_AA) 230 | 231 | return True 232 | 233 | def plot_matches_lists_lr(image0, image1, matches, outPath, name): 234 | """ 235 | Args: 236 | matches: [u0, v0, u1,v1]s 237 | """ 238 | image0 = img_to_color(image0) 239 | image1 = img_to_color(image1) 240 | 241 | H0, W0 = image0.shape[0], image0.shape[1] 242 | H1, W1 = image1.shape[0], image1.shape[1] 243 | 244 | H, W = max(H0, H1), W0 + W1 245 | out = 255 * np.ones((H, W, 3), np.uint8) 246 | out[:H0, :W0, :] = image0 247 | out[:H1, W0:, :] = image1 248 | 249 | color = np.zeros((len(matches), 3), dtype=int) 250 | color[:, 1] = 255 251 | 252 | for match, c in zip(matches, color): 253 | c = c.tolist() 254 | u0, v0, u1, v1 = match 255 | # print(u0) 256 | u0 = int(u0) 257 | v0 = int(v0) 258 | u1 = int(u1) + W0 259 | v1 = int(v1) 260 | cv2.line(out, (u0, v0), (u1, v1), color=c, thickness=1, lineType=cv2.LINE_AA) 261 | cv2.circle(out, (u0, v0), 2, c, -1, lineType=cv2.LINE_AA) 262 | cv2.circle(out, (u1, v1), 2, c, -1, lineType=cv2.LINE_AA) 263 | 264 | path = os.path.join(outPath, name+".jpg") 265 | # logger.critical(f"save match list img to {path}") 266 | logger.info(f"save match list img to {path}") 267 | cv2.imwrite(path, out) 268 | 269 | def plot_matches_lists_ud(image0, image1, matches, outPath, name): 270 | """ 271 | Args: 272 | matches: [u0, v0, u1,v1]s 273 | """ 274 | image0 = img_to_color(image0) 275 | image1 = img_to_color(image1) 276 | 277 | H0, W0 = image0.shape[0], image0.shape[1] 278 | H1, W1 = image1.shape[0], image1.shape[1] 279 | 280 | H, W = H0 + H1, max(W0, W1) 281 | out = 255 * np.ones((H, W, 3), np.uint8) 282 
| out[:H0, :W0, :] = image0 283 | out[H0:, :W1, :] = image1 284 | 285 | color = np.zeros((len(matches), 3), dtype=int) 286 | color[:, 1] = 255 287 | 288 | for match, c in zip(matches, color): 289 | c = c.tolist() 290 | u0, v0, u1, v1 = match 291 | # print(u0) 292 | u0 = int(u0) 293 | v0 = int(v0) 294 | u1 = int(u1) 295 | v1 = int(v1) + H0 296 | cv2.line(out, (u0, v0), (u1, v1), color=c, thickness=1, lineType=cv2.LINE_AA) 297 | cv2.circle(out, (u0, v0), 2, c, -1, lineType=cv2.LINE_AA) 298 | cv2.circle(out, (u1, v1), 2, c, -1, lineType=cv2.LINE_AA) 299 | 300 | path = os.path.join(outPath, name+".jpg") 301 | logger.info(f"save match list img to {path}") 302 | cv2.imwrite(path, out) 303 | 304 | def plot_matches_with_mask_ud(image0, image1, mask, matches, outPath, name, sample_num=500): 305 | """ 306 | Args: 307 | mask: 0 -> false match 308 | """ 309 | # random sample 310 | if len(matches) > sample_num: 311 | matches = random.sample(matches, sample_num) 312 | 313 | image0 = img_to_color(image0) 314 | image1 = img_to_color(image1) 315 | 316 | H0, W0 = image0.shape[0], image0.shape[1] 317 | H1, W1 = image1.shape[0], image1.shape[1] 318 | 319 | H, W = H0 + H1, max(W0, W1) 320 | out = 255 * np.ones((H, W, 3), np.uint8) 321 | out[:H0, :W0, :] = image0 322 | out[H0:, :W1, :] = image1 323 | 324 | for i, match in enumerate(matches): 325 | if mask[i] == 0: c = [0, 0, 255] 326 | if mask[i] == 1: c = [0, 255, 0] 327 | if mask[i] == -1: continue 328 | 329 | u0, v0, u1, v1 = match 330 | u0 = int(u0) 331 | v0 = int(v0) 332 | u1 = int(u1) 333 | v1 = int(v1) + H0 334 | cv2.line(out, (u0, v0), (u1, v1), color=c, thickness=1, lineType=cv2.LINE_AA) 335 | cv2.circle(out, (u0, v0), 2, c, -1, lineType=cv2.LINE_AA) 336 | cv2.circle(out, (u1, v1), 2, c, -1, lineType=cv2.LINE_AA) 337 | 338 | 339 | path = os.path.join(outPath, name+".jpg") 340 | logger.info(f"save match list img to {path}") 341 | cv2.imwrite(path, out) 342 | 343 | return out 344 | 
--------------------------------------------------------------------------------