├── data ├── .gitkeep └── models │ └── .gitkeep ├── surfemb ├── __init__.py ├── scripts │ ├── __init__.py │ ├── misc │ │ ├── __init__.py │ │ ├── compact_model.py │ │ ├── surface_samples_recover_normals.py │ │ ├── surface_samples_sample_even.py │ │ ├── surface_samples_remesh_visible.py │ │ ├── format_results_for_eval.py │ │ ├── load_detection_results.py │ │ └── render_poses.py │ ├── train.py │ ├── infer.py │ ├── infer_refine_depth.py │ └── infer_debug.py ├── data │ ├── __init__.py │ ├── config.py │ ├── obj.py │ ├── tfms.py │ ├── pose_auxs.py │ ├── detector_crops.py │ ├── instance.py │ ├── renderer.py │ └── std_auxs.py ├── dep │ ├── siren.py │ └── unet.py ├── utils.py ├── pose_refine.py ├── surface_embedding.py └── pose_est.py ├── .gitignore ├── environment.yml ├── LICENSE └── README.md /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/models/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /surfemb/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /surfemb/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /surfemb/scripts/misc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /surfemb/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import instance, obj, pose_auxs, renderer, std_auxs, tfms 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea 2 | /.ipynb_checkpoints 3 | /data 4 | /sandbox 5 | __pycache__ 6 | *.zip 7 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: surfemb 2 | dependencies: 3 | - python=3.8 4 | - pip 5 | - pytorch::pytorch 6 | - pytorch::torchvision 7 | - cudatoolkit=10.2 8 | - pyg::pytorch-scatter 9 | - conda-forge::imgaug 10 | - conda-forge::albumentations 11 | - py-opencv 12 | - numpy 13 | - matplotlib 14 | - scipy 15 | - pip: 16 | - tqdm 17 | - moderngl 18 | - pytorch_lightning 19 | - wandb 20 | - trimesh 21 | - rtree 22 | - pymeshlab -------------------------------------------------------------------------------- /surfemb/scripts/misc/compact_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | By default, the optimizer parameters are also saved with the model. 
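This script re-saves a checkpoint keeping only the state dict and hyperparameters, so the resulting *.compact.ckpt is smaller.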
3 | """ 4 | import argparse 5 | 6 | import torch 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('model_path') 10 | args = parser.parse_args() 11 | 12 | ckpt = torch.load(args.model_path) 13 | torch.save(dict( 14 | state_dict=ckpt['state_dict'], 15 | hyper_parameters=ckpt['hyper_parameters'], 16 | ), args.model_path.replace('.ckpt', '.compact.ckpt')) 17 | -------------------------------------------------------------------------------- /surfemb/data/config.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | 4 | class DatasetConfig: 5 | model_folder = 'models' 6 | train_folder = 'train' 7 | test_folder = 'test' 8 | img_folder = 'rgb' 9 | depth_folder = 'depth' 10 | img_ext = 'png' 11 | depth_ext = 'png' 12 | 13 | 14 | config = defaultdict(lambda *_: DatasetConfig()) 15 | 16 | config['tless'] = tless = DatasetConfig() 17 | tless.model_folder = 'models_cad' 18 | tless.test_folder = 'test_primesense' 19 | tless.train_folder = 'train_primesense' 20 | 21 | config['hb'] = hb = DatasetConfig() 22 | hb.test_folder = 'test_primesense' 23 | 24 | config['itodd'] = itodd = DatasetConfig() 25 | itodd.depth_ext = 'tif' 26 | itodd.img_folder = 'gray' 27 | itodd.img_ext = 'tif' 28 | -------------------------------------------------------------------------------- /surfemb/scripts/misc/surface_samples_recover_normals.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import trimesh 5 | import trimesh.proximity 6 | from tqdm import tqdm 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('dataset') 10 | args = parser.parse_args() 11 | 12 | mesh_folder = Path('data/remesh_visible') / args.dataset 13 | even_samples_folder = Path('data/surface_samples') / args.dataset 14 | assert even_samples_folder.exists() 15 | normals_folder = Path('data/surface_samples_normals') / args.dataset 16 | normals_folder.mkdir(exist_ok=True, parents=True) 17 | 18 | for mesh_fp in tqdm(list(mesh_folder.glob('*.ply'))): 19 | samples_fp = even_samples_folder / mesh_fp.name 20 | normals_fp = normals_folder / mesh_fp.name 21 | 22 | mesh = trimesh.load_mesh(mesh_fp) # type: trimesh.Trimesh 23 | sample_pts = trimesh.load_mesh(samples_fp).vertices 24 | face_idx = trimesh.proximity.closest_point(mesh, sample_pts)[-1] 25 | normals = mesh.face_normals[face_idx] 26 | 27 | pc = trimesh.PointCloud(normals) 28 | pc.export(normals_fp) 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Rasmus Laurvig Haugaard 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /surfemb/scripts/misc/surface_samples_sample_even.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import pymeshlab 5 | from tqdm import tqdm 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('dataset') 9 | parser.add_argument('--n-samples', type=int, default=int(50e3)) 10 | args = parser.parse_args() 11 | 12 | n_samples = args.n_samples 13 | 14 | remesh_folder = Path(f'data/remesh_visible/{args.dataset}') 15 | assert remesh_folder.exists() 16 | even_samples_folder = Path(f'data/surface_samples/{args.dataset}') 17 | even_samples_folder.mkdir(exist_ok=True) 18 | ms = pymeshlab.MeshSet() 19 | 20 | for mesh_fp in tqdm(list(remesh_folder.glob('*.ply'))): 21 | ms.clear() 22 | samples_fp = even_samples_folder / mesh_fp.name 23 | 24 | print() 25 | print(mesh_fp) 26 | print() 27 | 28 | ms.load_new_mesh(str(mesh_fp.absolute())) 29 | mesh_id = ms.current_mesh_id() 30 | n = n_samples 31 | while True: 32 | ms.set_current_mesh(mesh_id) 33 | ms.poisson_disk_sampling(samplenum=n) 34 | n_actual_samples = ms.current_mesh().vertex_number() 35 | print(n_actual_samples) 36 | if n_actual_samples >= n_samples: 37 | ms.save_current_mesh(str(samples_fp.absolute()), save_vertex_normal=False, save_textures=False, 38 | save_vertex_quality=False, save_vertex_color=False, save_vertex_coord=False, 39 | save_vertex_radius=False) 40 | break 41 | else: 42 | ms.delete_current_mesh() 43 | n = n + n // 2 44 | -------------------------------------------------------------------------------- /surfemb/data/obj.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | import json 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | import trimesh 8 | 9 | 10 | class Obj: 11 | def __init__(self, obj_id, mesh: trimesh.Trimesh, diameter: float): 12 | self.obj_id = obj_id 13 | self.mesh = mesh 14 | self.diameter = diameter 15 | 16 | bounding_sphere = self.mesh.bounding_sphere.primitive 17 | self.offset, self.scale = bounding_sphere.center, bounding_sphere.radius 18 | 19 | self.mesh_norm = mesh.copy() 20 | self.mesh_norm.apply_translation(-self.offset) 21 | self.mesh_norm.apply_scale(1 / self.scale) 22 | 23 | def normalize(self, pts: np.ndarray): 24 | return (pts - self.offset) / self.scale 25 | 26 | def denormalize(self, pts_norm: np.ndarray): 27 | return pts_norm * self.scale + self.offset 28 | 29 | 30 | def load_obj(models_root: Path, obj_id: int): 31 | models_info = json.load((models_root / 'models_info.json').open()) 32 | mesh = trimesh.load_mesh(str(models_root / f'obj_{obj_id:06d}.ply')) 33 | diameter = models_info[str(obj_id)]['diameter'] 34 | return Obj(obj_id, mesh, diameter) 35 | 36 | 37 | def load_objs(models_root: Path, obj_ids: Iterable[int] = None, show_progressbar=True): 38 | objs = [] 39 | if obj_ids is None: 40 | obj_ids = 
sorted([int(p.name[4:10]) for p in models_root.glob('*.ply')]) 41 | for obj_id in tqdm(obj_ids, 'loading objects') if show_progressbar else obj_ids: 42 | objs.append(load_obj(models_root, obj_id)) 43 | return objs, obj_ids 44 | -------------------------------------------------------------------------------- /surfemb/scripts/misc/surface_samples_remesh_visible.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import pymeshlab 5 | import trimesh 6 | from tqdm import tqdm 7 | 8 | from ...data.config import config 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('dataset') 12 | parser.add_argument('--face-quality-threshold', type=float, default=1e-3) 13 | parser.add_argument('--remesh-percentage', type=float, default=5.) 14 | args = parser.parse_args() 15 | 16 | mesh_folder = Path('data/bop') / args.dataset / config[args.dataset].model_folder 17 | remesh_folder = Path('data/remesh_visible') / args.dataset 18 | remesh_folder.mkdir(exist_ok=True, parents=True) 19 | 20 | for mesh_fp in tqdm(list(mesh_folder.glob('*.ply'))): 21 | remesh_fp = remesh_folder / mesh_fp.name 22 | 23 | print() 24 | print(mesh_fp) 25 | print() 26 | 27 | ms = pymeshlab.MeshSet() 28 | ms.load_new_mesh(str(mesh_fp.absolute())) 29 | ms.repair_non_manifold_edges_by_removing_faces() 30 | ms.subdivision_surfaces_midpoint(iterations=10, threshold=pymeshlab.Percentage(args.remesh_percentage)) 31 | ms.ambient_occlusion(occmode='per-Face (deprecated)', reqviews=256) 32 | face_quality_array = ms.current_mesh().face_quality_array() 33 | minq = face_quality_array.min() 34 | if minq < args.face_quality_threshold: 35 | assert face_quality_array.max() > args.face_quality_threshold 36 | ms.select_by_face_quality(minq=minq, maxq=args.face_quality_threshold) 37 | ms.delete_selected_faces() 38 | ms.remove_unreferenced_vertices() 39 | ms.save_current_mesh(str(remesh_fp.absolute()), save_textures=False) 40 | 41 | area_reduction = trimesh.load_mesh(remesh_fp).area / trimesh.load_mesh(mesh_fp).area 42 | print() 43 | print(mesh_fp) 44 | print(f'area reduction {area_reduction}') 45 | print() 46 | -------------------------------------------------------------------------------- /surfemb/dep/siren.py: -------------------------------------------------------------------------------- 1 | # From https://vsitzmann.github.io/siren/ (MIT License) 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class SineLayer(nn.Module): 8 | def __init__(self, in_features, out_features, bias=True, 9 | is_first=False, omega_0=30.): 10 | super().__init__() 11 | self.omega_0 = omega_0 12 | self.is_first = is_first 13 | self.in_features = in_features 14 | self.linear = nn.Linear(in_features, out_features, bias=bias) 15 | self.init_weights() 16 | 17 | def init_weights(self): 18 | with torch.no_grad(): 19 | if self.is_first: 20 | self.linear.weight.uniform_(-1 / self.in_features, 21 | 1 / self.in_features) 22 | else: 23 | self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 24 | np.sqrt(6 / self.in_features) / self.omega_0) 25 | 26 | def forward(self, input): 27 | return torch.sin(self.omega_0 * self.linear(input)) 28 | 29 | 30 | class Siren(nn.Module): 31 | def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=True, 32 | first_omega_0=30., hidden_omega_0=30.): 33 | super().__init__() 34 | self.net = [] 35 | self.net.append(SineLayer(in_features, hidden_features, 36 | 
is_first=True, omega_0=first_omega_0)) 37 | for i in range(hidden_layers): 38 | self.net.append(SineLayer(hidden_features, hidden_features, 39 | is_first=False, omega_0=hidden_omega_0)) 40 | if outermost_linear: 41 | final_linear = nn.Linear(hidden_features, out_features) 42 | with torch.no_grad(): 43 | final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0, 44 | np.sqrt(6 / hidden_features) / hidden_omega_0) 45 | self.net.append(final_linear) 46 | else: 47 | self.net.append(SineLayer(hidden_features, out_features, 48 | is_first=False, omega_0=hidden_omega_0)) 49 | self.net = nn.Sequential(*self.net) 50 | 51 | def forward(self, coords): 52 | return self.net(coords) 53 | -------------------------------------------------------------------------------- /surfemb/scripts/misc/format_results_for_eval.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | from ...data.config import config 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('dataset') 11 | parser.add_argument('poses') 12 | parser.add_argument('--dont-use-refinement', dest='use_refinement', action='store_false') 13 | parser.add_argument('--dont-use-pose-score', dest='use_pose_score', action='store_false') 14 | args = parser.parse_args() 15 | 16 | detection_path = Path('data/detection_results') / args.dataset 17 | poses_fp = Path(args.poses) 18 | 19 | name = '-'.join(poses_fp.name.split('-')[:-1]) # dataset, run_id, [optionally "depth"] 20 | pose_scores_fp = poses_fp.parent / f'{name.replace("-depth", "")}-poses-scores.npy' 21 | pose_timings_fp = poses_fp.parent / f'{name}-poses-timings.npy' 22 | 23 | poses = np.load(str(poses_fp)) 24 | pose_scores = np.load(str(pose_scores_fp)) 25 | pose_timings = np.load(str(pose_timings_fp)) 26 | det_scene_ids = np.load(str(detection_path / 'scene_ids.npy')) 27 | det_view_ids = np.load(str(detection_path / 'view_ids.npy')) 28 | det_obj_ids = np.load(str(detection_path / 'obj_ids.npy')) 29 | det_scores = np.load(str(detection_path / 'scores.npy')) 30 | det_times = np.load(str(detection_path / 'times.npy')) 31 | assert len(det_scores) == len(pose_scores) 32 | 33 | scores = pose_scores if args.use_pose_score else det_scores 34 | poses = poses[1 if args.use_refinement else 0] 35 | pose_timings = pose_timings[1 if args.use_refinement else 0] 36 | 37 | Rs = poses[:, :3, :3] 38 | ts = poses[:, :3, 3] 39 | 40 | img_timings = defaultdict(lambda: 0) 41 | 42 | for t, scene_id, view_id in zip(pose_timings, det_scene_ids, det_view_ids): 43 | img_timings[(scene_id, view_id)] += t 44 | 45 | lines = [] 46 | for i in range(len(poses)): 47 | line = ','.join(( 48 | str(det_scene_ids[i]), 49 | str(det_view_ids[i]), 50 | str(det_obj_ids[i]), 51 | str(scores[i]), 52 | ' '.join((str(v) for v in Rs[i].reshape(-1))), 53 | ' '.join((str(v) for v in ts[i])), 54 | f'{det_times[i] + img_timings[(det_scene_ids[i], det_view_ids[i])]}\n', 55 | )) 56 | lines.append(line) 57 | 58 | if args.use_refinement: 59 | name += '-refine' 60 | if args.use_pose_score: 61 | name += '-pose-score' 62 | 63 | with open( 64 | f'data/results/{name}_{args.dataset}-{config[args.dataset].test_folder}.csv' 65 | , 'w' 66 | ) as f: 67 | f.writelines(lines) 68 | -------------------------------------------------------------------------------- /surfemb/data/tfms.py: -------------------------------------------------------------------------------- 1 | from typing import 
Union 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import albumentations as A 7 | 8 | imagenet_stats = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 9 | 10 | 11 | def normalize(img: np.ndarray): # (h, w, 3) -> (3, h, w) 12 | mu, std = imagenet_stats 13 | if img.dtype == np.uint8: 14 | img = img / 255 15 | img = (img - mu) / std 16 | return img.transpose(2, 0, 1).astype(np.float32) 17 | 18 | 19 | def denormalize(img: Union[np.ndarray, torch.Tensor]): 20 | mu, std = imagenet_stats 21 | if isinstance(img, torch.Tensor): 22 | mu, std = [torch.Tensor(v).type(img.dtype).to(img.device)[:, None, None] for v in (mu, std)] 23 | return img * std + mu 24 | 25 | 26 | class Unsharpen(A.ImageOnlyTransform): 27 | def __init__(self, k_limits=(3, 7), strength_limits=(0., 2.), p=0.5): 28 | super().__init__() 29 | self.k_limits = k_limits 30 | self.strength_limits = strength_limits 31 | self.p = p 32 | 33 | def apply(self, img, **params): 34 | if np.random.rand() > self.p: 35 | return img 36 | k = np.random.randint(self.k_limits[0] // 2, self.k_limits[1] // 2 + 1) * 2 + 1 37 | s = k / 3 38 | blur = cv2.GaussianBlur(img, (k, k), s) 39 | strength = np.random.uniform(*self.strength_limits) 40 | unsharpened = cv2.addWeighted(img, 1 + strength, blur, -strength, 0) 41 | return unsharpened 42 | 43 | 44 | class DebayerArtefacts(A.ImageOnlyTransform): 45 | def __init__(self, p=0.5): 46 | super().__init__() 47 | self.p = p 48 | 49 | def apply(self, img, **params): 50 | if np.random.rand() > self.p: 51 | return img 52 | assert img.dtype == np.uint8 53 | # permute channels before bayering/debayering to cover different bayer formats 54 | channel_idxs = np.random.permutation(3) 55 | channel_idxs_inv = np.empty(3, dtype=int) 56 | channel_idxs_inv[channel_idxs] = 0, 1, 2 57 | 58 | # assemble bayer image 59 | bayer = np.zeros(img.shape[:2], dtype=img.dtype) 60 | bayer[::2, ::2] = img[::2, ::2, channel_idxs[2]] 61 | bayer[1::2, ::2] = img[1::2, ::2, channel_idxs[1]] 62 | bayer[::2, 1::2] = img[::2, 1::2, channel_idxs[1]] 63 | bayer[1::2, 1::2] = img[1::2, 1::2, channel_idxs[0]] 64 | 65 | # debayer 66 | debayer_method = np.random.choice((cv2.COLOR_BAYER_BG2BGR, cv2.COLOR_BAYER_BG2BGR_EA)) 67 | debayered = cv2.cvtColor(bayer, debayer_method)[..., channel_idxs_inv] 68 | return debayered 69 | -------------------------------------------------------------------------------- /surfemb/data/pose_auxs.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import numpy as np 4 | 5 | from .instance import BopInstanceAux 6 | from .obj import Obj 7 | from .renderer import ObjCoordRenderer 8 | 9 | 10 | class ObjCoordAux(BopInstanceAux): 11 | def __init__(self, objs: Sequence[Obj], res: int, mask_key='mask_visib_crop', replace_mask=False, sigma=0.): 12 | self.objs, self.res = objs, res 13 | self.mask_key = mask_key 14 | self.replace_mask = replace_mask 15 | self.renderer = None 16 | self.sigma = sigma 17 | 18 | def get_renderer(self): 19 | # lazy instantiation of renderer to create the context in the worker process 20 | if self.renderer is None: 21 | self.renderer = ObjCoordRenderer(self.objs, self.res) 22 | return self.renderer 23 | 24 | def __call__(self, inst: dict, _) -> dict: 25 | renderer = self.get_renderer() 26 | K = inst['K_crop'].copy() 27 | 28 | if self.sigma > 0: 29 | # offset principal axis slightly to encourage all object coordinates within the pixel to have 30 | # som probability mass. 
Smoother probs -> more robust score and better posed refinement opt. problem. 31 | while True: 32 | offset = np.random.randn(2) 33 | if np.linalg.norm(offset) < 3: 34 | K[:2, 2] += offset * self.sigma 35 | break 36 | 37 | obj_coord = renderer.render(inst['obj_idx'], K, inst['cam_R_obj'], inst['cam_t_obj']).copy() 38 | if self.mask_key is not None: 39 | if self.replace_mask: 40 | mask = obj_coord[..., 3] 41 | else: 42 | mask = obj_coord[..., 3] * inst[self.mask_key] / 255 43 | obj_coord[..., 3] = mask 44 | inst[self.mask_key] = (mask * 255).astype(np.uint8) 45 | inst['obj_coord'] = obj_coord 46 | return inst 47 | 48 | 49 | class SurfaceSampleAux(BopInstanceAux): 50 | def __init__(self, objs: Sequence[Obj], n_samples: int, norm=True): 51 | self.objs, self.n_samples = objs, n_samples 52 | self.norm = norm 53 | 54 | def __call__(self, inst: dict, _) -> dict: 55 | obj = self.objs[inst['obj_idx']] 56 | mesh = obj.mesh_norm if self.norm else obj.mesh 57 | inst['surface_samples'] = mesh.sample(self.n_samples).astype(np.float32) 58 | return inst 59 | 60 | 61 | class MaskSamplesAux(BopInstanceAux): 62 | def __init__(self, n_samples: int, mask_key='mask_visib_crop'): 63 | self.mask_key = mask_key 64 | self.n_samples = n_samples 65 | 66 | def __call__(self, inst: dict, _): 67 | mask_arg = np.argwhere(inst[self.mask_key]) # (N, 2) 68 | idxs = np.random.choice(np.arange(len(mask_arg)), self.n_samples, replace=self.n_samples > len(mask_arg)) 69 | inst['mask_samples'] = mask_arg[idxs] # (n_samples, 2) 70 | return inst 71 | -------------------------------------------------------------------------------- /surfemb/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from contextlib import contextmanager 3 | from pathlib import Path 4 | 5 | import cv2 6 | import torch 7 | import torch.utils.data 8 | import trimesh 9 | import trimesh.sample 10 | 11 | 12 | @contextmanager 13 | def timer(text='', do=True): 14 | if do: 15 | start = time.time() 16 | try: 17 | yield 18 | finally: 19 | print(f'{text}: {time.time() - start:.4}s') 20 | else: 21 | yield 22 | 23 | 24 | @contextmanager 25 | def add_timing_to_list(l): 26 | start = time.time() 27 | try: 28 | yield 29 | finally: 30 | l.append(time.time() - start) 31 | 32 | 33 | def balanced_dataset_concat(a, b): 34 | # makes an approximately 50/50 concat 35 | # by adding copies of the smallest dataset 36 | if len(a) < len(b): 37 | a, b = b, a 38 | assert len(a) >= len(b) 39 | data = a 40 | for i in range(round(len(a) / len(b))): 41 | data += b 42 | return data 43 | 44 | 45 | def load_surface_samples(dataset, obj_ids, root=Path('data')): 46 | surface_samples = [trimesh.load_mesh(root / f'surface_samples/{dataset}/obj_{i:06d}.ply').vertices for i in obj_ids] 47 | surface_sample_normals = [trimesh.load_mesh(root / f'surface_samples_normals/{dataset}/obj_{i:06d}.ply').vertices 48 | for i in obj_ids] 49 | return surface_samples, surface_sample_normals 50 | 51 | 52 | class Rodrigues(torch.autograd.Function): 53 | @staticmethod 54 | def forward(ctx, rvec): 55 | R, jac = cv2.Rodrigues(rvec.detach().cpu().numpy()) 56 | jac = torch.from_numpy(jac).to(rvec.device) 57 | ctx.save_for_backward(jac) 58 | return torch.from_numpy(R).to(rvec.device) 59 | 60 | @staticmethod 61 | def backward(ctx, grad_output): 62 | jac, = ctx.saved_tensors 63 | return jac @ grad_output.to(jac.device).reshape(-1) 64 | 65 | 66 | def rotate_batch(batch: torch.Tensor): # (..., H, H) -> (4, ..., H, H) 67 | assert batch.shape[-1] == batch.shape[-2] 68 | 
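# the rotations below are composed of flips and transposes, so they are exact (no resampling)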
return torch.stack([ 69 | batch, # 0 deg 70 | torch.flip(batch, [-2]).transpose(-1, -2), # 90 deg 71 | torch.flip(batch, [-1, -2]), # 180 deg 72 | torch.flip(batch, [-1]).transpose(-1, -2), # 270 deg 73 | ]) # (4, ..., H, H) 74 | 75 | 76 | def rotate_batch_back(batch: torch.Tensor): # (4, ..., H, H) -> (4, ..., H, H) 77 | assert batch.shape[0] == 4 78 | assert batch.shape[-1] == batch.shape[-2] 79 | return torch.stack([ 80 | batch[0], # 0 deg 81 | torch.flip(batch[1], [-1]).transpose(-1, -2), # -90 deg 82 | torch.flip(batch[2], [-1, -2]), # -180 deg 83 | torch.flip(batch[3], [-2]).transpose(-1, -2), # -270 deg 84 | ]) # (4, ..., H, H) 85 | 86 | 87 | class EmptyDataset(torch.utils.data.Dataset): 88 | def __len__(self): 89 | return 0 90 | 91 | def __getitem__(self, item): 92 | return None 93 | -------------------------------------------------------------------------------- /surfemb/data/detector_crops.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Sequence 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | from tqdm import tqdm 8 | import torch.utils.data 9 | 10 | from .config import DatasetConfig 11 | from .instance import BopInstanceAux 12 | 13 | 14 | class DetectorCropDataset(torch.utils.data.Dataset): 15 | def __init__( 16 | self, dataset_root: Path, obj_ids, detection_folder: Path, cfg: DatasetConfig, 17 | auxs: Sequence[BopInstanceAux], show_progressbar=True, 18 | ): 19 | self.data_folder = dataset_root / cfg.test_folder 20 | self.img_folder = cfg.img_folder 21 | self.depth_folder = cfg.depth_folder 22 | self.img_ext = cfg.img_ext 23 | self.depth_ext = cfg.depth_ext 24 | 25 | self.bboxes = np.load(str(detection_folder / 'bboxes.npy')) 26 | self.obj_ids = np.load(str(detection_folder / 'obj_ids.npy')) 27 | self.scene_ids = np.load(str(detection_folder / 'scene_ids.npy')) 28 | self.view_ids = np.load(str(detection_folder / 'view_ids.npy')) 29 | self.obj_idxs = {obj_id: idx for idx, obj_id in enumerate(obj_ids)} 30 | 31 | self.auxs = auxs 32 | self.instances = [] 33 | scene_ids = sorted([int(scene_dir.name) for scene_dir in self.data_folder.glob('*') if scene_dir.is_dir()]) 34 | 35 | self.scene_cameras = defaultdict(lambda *_: []) 36 | 37 | for scene_id in tqdm(scene_ids, 'loading crop info') if show_progressbar else scene_ids: 38 | scene_folder = self.data_folder / f'{scene_id:06d}' 39 | self.scene_cameras[scene_id] = json.load((scene_folder / 'scene_camera.json').open()) 40 | 41 | for aux in self.auxs: 42 | aux.init(self) 43 | 44 | def __len__(self): 45 | return len(self.bboxes) 46 | 47 | def __getitem__(self, i): 48 | scene_id, view_id, obj_id = self.scene_ids[i], self.view_ids[i], self.obj_ids[i] 49 | instance = dict( 50 | scene_id=scene_id, img_id=view_id, obj_id=obj_id, obj_idx=self.obj_idxs[obj_id], 51 | K=np.array(self.scene_cameras[scene_id][str(view_id)]['cam_K']).reshape((3, 3)), 52 | mask_visib=self.bboxes[i], bbox=self.bboxes[i].round().astype(int), 53 | ) 54 | for aux in self.auxs: 55 | instance = aux(instance, self) 56 | return instance 57 | 58 | 59 | def _main(): 60 | import argparse 61 | import cv2 62 | from .config import config 63 | from . 
import std_auxs 64 | from .obj import load_objs 65 | 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('dataset') 68 | args = parser.parse_args() 69 | 70 | dataset_root = Path(f'bop/{args.dataset}') 71 | cfg = config[args.dataset] 72 | objs, obj_ids = load_objs(dataset_root / cfg.model_folder, None) 73 | 74 | data = DetectorCropDataset( 75 | dataset_root=dataset_root, cfg=cfg, obj_ids=obj_ids, 76 | detection_folder=Path(f'detection_results/{args.dataset}'), 77 | auxs=( 78 | std_auxs.RgbLoader(), 79 | std_auxs.RandomRotatedMaskCrop(224, max_angle=0, offset_scale=0, use_bbox=True), 80 | ), 81 | ) 82 | while True: 83 | i = np.random.randint(len(data)) 84 | img = data[i]['rgb_crop'] 85 | cv2.imshow('', img[..., ::-1]) 86 | if cv2.waitKey() == ord('q'): 87 | quit() 88 | 89 | 90 | if __name__ == '__main__': 91 | _main() 92 | -------------------------------------------------------------------------------- /surfemb/pose_refine.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import cv2 5 | from scipy.optimize import minimize 6 | 7 | from .utils import Rodrigues 8 | 9 | 10 | def refine_pose(R: np.ndarray, t: np.ndarray, query_img, renderer, obj_idx, K_crop, obj_, model, keys_verts, 11 | interpolation='bilinear', n_samples_denom=4096, method='BFGS'): 12 | """ 13 | Refines the pose estimate (R, t) by local maximization of the log prob (according to the queries / keys) 14 | of the initially visible surface. 15 | Bilinear interpolation and PyTorch autograd to get the gradient, and BFGS for optimization. 16 | """ 17 | h, w, _ = query_img.shape 18 | assert h == w 19 | res_crop = h 20 | device = model.device 21 | 22 | # Get the object coordinates and keys of the initially visible surface 23 | coord_img = renderer.render(obj_idx, K_crop, R, t) 24 | mask = coord_img[..., 3] == 1. 
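# the renderer writes alpha = 1 for rendered pixels (see ObjCoordRenderer's fragment shader), so this mask covers the initially visible surface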
25 | coord_norm_masked = torch.from_numpy(coord_img[..., :3][mask]).to(device) # (N, 3) 26 | keys_masked = model.infer_mlp(coord_norm_masked, obj_idx) # (N, emb_dim) 27 | coord_masked = coord_norm_masked * obj_.scale + torch.from_numpy(obj_.offset).to(device) 28 | coord_masked = torch.cat((coord_masked, torch.ones(len(coord_masked), 1, device=device)), dim=1) # (N, 4) 29 | K_crop = torch.from_numpy(K_crop).to(device) 30 | 31 | # precompute log denominator in softmax (log sum exp over keys) per query 32 | # needs to be batched or estimated with reduced amount of keys (as implemented here) because of memory requirements 33 | keys_sampled = keys_verts[torch.randperm(len(keys_verts), device=device)[:n_samples_denom]] 34 | denom_img = torch.logsumexp(query_img @ keys_sampled.T, dim=-1, keepdim=True) # (H, W, 1) 35 | coord_masked = coord_masked.float() 36 | K_crop = K_crop.float() 37 | 38 | def sample(img, p_img_norm): 39 | samples = F.grid_sample( 40 | img.permute(2, 0, 1)[None], # (1, d, H, W) 41 | p_img_norm[None, None], # (1, 1, N, 2) 42 | align_corners=False, 43 | padding_mode='border', 44 | mode=interpolation, 45 | ) # (1, d, 1, N) 46 | return samples[0, :, 0].T # (N, d) 47 | 48 | def objective(pose: np.ndarray, return_grad=False): 49 | pose = torch.from_numpy(pose).float() 50 | pose.requires_grad = return_grad 51 | Rt = torch.cat(( 52 | Rodrigues.apply(pose[:3]), 53 | pose[3:, None], 54 | ), dim=1).to(device) # (3, 4) 55 | 56 | P = K_crop @ Rt 57 | p_img = coord_masked @ P.T 58 | p_img = p_img[..., :2] / p_img[..., 2:] # (N, 2) 59 | # pytorch grid_sample coordinates 60 | p_img_norm = (p_img + 0.5) * (2 / res_crop) - 1 61 | 62 | query_sampled = sample(query_img, p_img_norm) # (N, emb_dim) 63 | log_nominator = (keys_masked * query_sampled).sum(dim=-1) # (N,) 64 | log_denominator = sample(denom_img, p_img_norm)[:, 0] # (N,) 65 | score = -(log_nominator.mean() - log_denominator.mean()) / 2 66 | 67 | if return_grad: 68 | score.backward() 69 | return pose.grad.detach().cpu().numpy() 70 | else: 71 | return score.item() 72 | 73 | rvec = cv2.Rodrigues(R)[0] 74 | pose = np.array((*rvec[:, 0], *t[:, 0])) 75 | result = minimize(fun=objective, x0=pose, jac=lambda pose: objective(pose, return_grad=True), method=method) 76 | 77 | pose = result.x 78 | R = cv2.Rodrigues(pose[:3])[0] 79 | t = pose[3:, None] 80 | return R, t, result.fun 81 | -------------------------------------------------------------------------------- /surfemb/dep/unet.py: -------------------------------------------------------------------------------- 1 | # Initially from https://github.com/usuyama/pytorch-unet (MIT License) 2 | # Architecture slightly changed (removed some expensive high-res convolutions) 3 | # and extended to allow multiple decoders 4 | import torch 5 | from torch import nn 6 | import torchvision 7 | 8 | 9 | def convrelu(in_channels, out_channels, kernel, padding): 10 | return nn.Sequential( 11 | nn.Conv2d(in_channels, out_channels, kernel, padding=padding), 12 | nn.ReLU(inplace=True), 13 | ) 14 | 15 | 16 | class ResNetUNet(nn.Module): 17 | def __init__(self, n_class, feat_preultimate=64, n_decoders=1): 18 | super().__init__() 19 | 20 | # shared encoder 21 | self.base_model = torchvision.models.resnet18(pretrained=True) 22 | self.base_layers = list(self.base_model.children()) 23 | 24 | self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2) 25 | self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4) 26 | self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8) 
27 | self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16) 28 | self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32) 29 | 30 | # n_decoders 31 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 32 | self.decoders = [dict( 33 | layer0_1x1=convrelu(64, 64, 1, 0), 34 | layer1_1x1=convrelu(64, 64, 1, 0), 35 | layer2_1x1=convrelu(128, 128, 1, 0), 36 | layer3_1x1=convrelu(256, 256, 1, 0), 37 | layer4_1x1=convrelu(512, 512, 1, 0), 38 | conv_up3=convrelu(256 + 512, 512, 3, 1), 39 | conv_up2=convrelu(128 + 512, 256, 3, 1), 40 | conv_up1=convrelu(64 + 256, 256, 3, 1), 41 | conv_up0=convrelu(64 + 256, 128, 3, 1), 42 | conv_original_size=convrelu(128, feat_preultimate, 3, 1), 43 | conv_last=nn.Conv2d(feat_preultimate, n_class, 1), 44 | ) for _ in range(n_decoders)] 45 | 46 | # register decoder modules 47 | for i, decoder in enumerate(self.decoders): 48 | for key, val in decoder.items(): 49 | setattr(self, f'decoder{i}_{key}', val) 50 | 51 | def forward(self, input, decoder_idx=None): 52 | if decoder_idx is None: 53 | assert len(self.decoders) == 1 54 | decoder_idx = [0] 55 | else: 56 | assert len(decoder_idx) == 1 or len(decoder_idx) == len(input) 57 | 58 | # encoder 59 | layer0 = self.layer0(input) 60 | layer1 = self.layer1(layer0) 61 | layer2 = self.layer2(layer1) 62 | layer3 = self.layer3(layer2) 63 | layer4 = self.layer4(layer3) 64 | layers = [layer0, layer1, layer2, layer3, layer4] 65 | 66 | # decoders 67 | out = [] 68 | for i, dec_idx in enumerate(decoder_idx): 69 | decoder = self.decoders[dec_idx] 70 | batch_slice = slice(None) if len(decoder_idx) == 1 else slice(i, i + 1) 71 | 72 | x = decoder['layer4_1x1'](layer4[batch_slice]) 73 | x = self.upsample(x) 74 | for layer_idx in 3, 2, 1, 0: 75 | layer_slice = layers[layer_idx][batch_slice] 76 | layer_projection = decoder[f'layer{layer_idx}_1x1'](layer_slice) 77 | x = torch.cat([x, layer_projection], dim=1) 78 | x = decoder[f'conv_up{layer_idx}'](x) 79 | x = self.upsample(x) 80 | 81 | x = decoder['conv_original_size'](x) 82 | out.append(decoder['conv_last'](x)) 83 | 84 | if len(decoder_idx) == 1: 85 | # out: 1 x (B, C, H, W) 86 | return out[0] 87 | else: 88 | # out: B x (1, C, H, W) 89 | return torch.stack(out)[:, 0] 90 | -------------------------------------------------------------------------------- /surfemb/scripts/misc/load_detection_results.py: -------------------------------------------------------------------------------- 1 | """ 2 | CosyPose must be installed to run this. 3 | Cached versions are made available to avoid this dependency. 
4 | """ 5 | 6 | import argparse 7 | from pathlib import Path 8 | import json 9 | from collections import defaultdict 10 | 11 | import torch 12 | import numpy as np 13 | import cv2 14 | 15 | from ...data.config import config 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('dataset') 19 | parser.add_argument('--debug', action='store_true') 20 | args = parser.parse_args() 21 | 22 | cfg = config[args.dataset] 23 | 24 | 25 | def recursive_dict(*_): 26 | return defaultdict(recursive_dict) 27 | 28 | 29 | # only save detection results from the target objects 30 | targets_raw = json.load(open(f'data/bop/{args.dataset}/test_targets_bop19.json')) 31 | inst_count = recursive_dict() 32 | target_count = 0 33 | for t in targets_raw: 34 | inst_count[t['scene_id']][t['im_id']][t['obj_id']] = t['inst_count'] 35 | target_count += t['inst_count'] 36 | print('target count', target_count) 37 | 38 | results = torch.load(f'../cosypose/local_data/results/bop-pbr--223026/dataset={args.dataset}/results.pth.tar') 39 | preds = results['predictions'] 40 | detections = preds['maskrcnn_detections/detections'] 41 | infos = preds['maskrcnn_detections/coarse/iteration=1'].infos 42 | times = infos.time.to_numpy() # times for detection results only are not available 43 | scene_ids = infos.scene_id.to_numpy() 44 | view_ids = infos.view_id.to_numpy() 45 | scores = infos.score.to_numpy() 46 | obj_ids = np.array([int(label[-6:]) for label in infos.label]) 47 | bboxes = detections.bboxes.numpy() 48 | print('det count', len(scores)) 49 | 50 | debug = args.debug 51 | mask_all = np.zeros(len(scores), dtype=bool) 52 | for scene_id in sorted(inst_count.keys()): 53 | inst_count_scene = inst_count[scene_id] 54 | scene_mask = scene_ids == scene_id 55 | for view_id in sorted(inst_count_scene.keys()): 56 | inst_count_view = inst_count_scene[view_id] 57 | view_mask = view_ids == view_id 58 | if debug: 59 | img = cv2.imread(f'bop/{args.dataset}/{cfg.test_folder}/{scene_id:06d}/{cfg.img_folder}/' 60 | f'{view_id:06d}.{cfg.img_ext}') 61 | break_view = False 62 | for obj_id in sorted(inst_count_view.keys()): 63 | obj_mask = obj_ids == obj_id 64 | mask = scene_mask & view_mask & obj_mask 65 | arg_mask = np.argwhere(mask).reshape(-1) 66 | mask_all[arg_mask] = True 67 | if debug: 68 | print(f'obj_id: {obj_id}, n_targets: {inst_count_view[obj_id]}, n_est: {mask.sum()}') 69 | print('scores: ', scores[arg_mask]) 70 | img_ = img.copy() 71 | for j, i in enumerate(arg_mask): 72 | l, t, r, b = bboxes[i] 73 | c = (0, 255, 0) if j < inst_count_view[obj_id] else (0, 0, 255) 74 | cv2.rectangle(img_, (l, t), (r, b), c) 75 | cv2.putText(img_, f'{scores[i]:.4f}', (int(l) + 2, int(b) - 2), cv2.FONT_HERSHEY_PLAIN, 1, c) 76 | cv2.imshow('', img_) 77 | key = cv2.waitKey() 78 | if key == ord('q'): 79 | quit() 80 | elif key == ord('s'): 81 | break_view = True 82 | break 83 | if break_view: 84 | break 85 | print('det masked count', mask_all.sum()) 86 | 87 | folder = Path('data/detection_results') 88 | folder.mkdir(exist_ok=True) 89 | folder = folder / args.dataset 90 | folder.mkdir(exist_ok=True) 91 | 92 | np.save(f'{folder}/scene_ids.npy', scene_ids[mask_all]) 93 | np.save(f'{folder}/view_ids.npy', view_ids[mask_all]) 94 | np.save(f'{folder}/scores.npy', scores[mask_all]) 95 | np.save(f'{folder}/obj_ids.npy', obj_ids[mask_all]) 96 | np.save(f'{folder}/bboxes.npy', bboxes[mask_all]) 97 | np.save(f'{folder}/times.npy', times[mask_all]) 98 | -------------------------------------------------------------------------------- /surfemb/data/instance.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Sequence 4 | import warnings 5 | 6 | import numpy as np 7 | from tqdm import tqdm 8 | import torch.utils.data 9 | 10 | from .config import DatasetConfig 11 | 12 | 13 | # BopInstanceDataset should only be used with test=True for debugging reasons 14 | # use detector_crops.DetectorCropDataset for actual test inference 15 | 16 | 17 | class BopInstanceDataset(torch.utils.data.Dataset): 18 | def __init__( 19 | self, dataset_root: Path, pbr: bool, test: bool, cfg: DatasetConfig, 20 | obj_ids: Sequence[int], 21 | scene_ids=None, min_visib_fract=0.1, min_px_count_visib=1024, 22 | auxs: Sequence['BopInstanceAux'] = tuple(), show_progressbar=True, 23 | ): 24 | self.pbr, self.test, self.cfg = pbr, test, cfg 25 | if pbr: 26 | assert not test 27 | self.data_folder = dataset_root / 'train_pbr' 28 | self.img_folder = 'rgb' 29 | self.depth_folder = 'depth' 30 | self.img_ext = 'jpg' 31 | self.depth_ext = 'png' 32 | else: 33 | self.data_folder = dataset_root / (cfg.test_folder if test else cfg.train_folder) 34 | self.img_folder = cfg.img_folder 35 | self.depth_folder = cfg.depth_folder 36 | self.img_ext = cfg.img_ext 37 | self.depth_ext = cfg.depth_ext 38 | 39 | self.auxs = auxs 40 | obj_idxs = {obj_id: idx for idx, obj_id in enumerate(obj_ids)} 41 | self.instances = [] 42 | if scene_ids is None: 43 | scene_ids = sorted([int(p.name) for p in self.data_folder.glob('*')]) 44 | for scene_id in tqdm(scene_ids, 'loading crop info') if show_progressbar else scene_ids: 45 | scene_folder = self.data_folder / f'{scene_id:06d}' 46 | scene_gt = json.load((scene_folder / 'scene_gt.json').open()) 47 | scene_gt_info = json.load((scene_folder / 'scene_gt_info.json').open()) 48 | scene_camera = json.load((scene_folder / 'scene_camera.json').open()) 49 | 50 | for img_id, poses in scene_gt.items(): 51 | img_info = scene_gt_info[img_id] 52 | K = np.array(scene_camera[img_id]['cam_K']).reshape((3, 3)).copy() 53 | if pbr: 54 | warnings.warn('Altering camera matrix, since PBR camera matrix doesnt seem to be correct') 55 | K[:2, 2] -= 0.5 56 | 57 | for pose_idx, pose in enumerate(poses): 58 | obj_id = pose['obj_id'] 59 | if obj_ids is not None and obj_id not in obj_ids: 60 | continue 61 | pose_info = img_info[pose_idx] 62 | if pose_info['visib_fract'] < min_visib_fract: 63 | continue 64 | if pose_info['px_count_visib'] < min_px_count_visib: 65 | continue 66 | 67 | bbox_visib = pose_info['bbox_visib'] 68 | bbox_obj = pose_info['bbox_obj'] 69 | 70 | cam_R_obj = np.array(pose['cam_R_m2c']).reshape(3, 3) 71 | cam_t_obj = np.array(pose['cam_t_m2c']).reshape(3, 1) 72 | 73 | self.instances.append(dict( 74 | scene_id=scene_id, img_id=int(img_id), K=K, obj_id=obj_id, pose_idx=pose_idx, 75 | bbox_visib=bbox_visib, bbox_obj=bbox_obj, cam_R_obj=cam_R_obj, cam_t_obj=cam_t_obj, 76 | obj_idx=obj_idxs[obj_id], 77 | )) 78 | 79 | for aux in self.auxs: 80 | aux.init(self) 81 | 82 | def __len__(self): 83 | return len(self.instances) 84 | 85 | def __getitem__(self, i): 86 | instance = self.instances[i].copy() 87 | for aux in self.auxs: 88 | instance = aux(instance, self) 89 | return instance 90 | 91 | 92 | class BopInstanceAux: 93 | def init(self, dataset: BopInstanceDataset): 94 | pass 95 | 96 | def __call__(self, data: dict, dataset: BopInstanceDataset) -> dict: 97 | pass 98 | 99 | 100 | def _main(): 101 | from .config import tless 102 | for pbr, test in (True, False), (False, False), (False, True): 
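# the three configurations are: PBR training, real training and real test splits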
103 | print(f'pbr: {pbr}, test: {test}') 104 | data = BopInstanceDataset(dataset_root=Path('bop/tless'), pbr=pbr, test=test, cfg=tless, obj_ids=range(1, 31)) 105 | print(len(data)) 106 | 107 | 108 | if __name__ == '__main__': 109 | _main() 110 | -------------------------------------------------------------------------------- /surfemb/data/renderer.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import numpy as np 4 | import moderngl 5 | 6 | from .obj import Obj 7 | 8 | 9 | def orthographic_matrix(left, right, bottom, top, near, far): 10 | return np.array(( 11 | (2 / (right - left), 0, 0, -(right + left) / (right - left)), 12 | (0, 2 / (top - bottom), 0, -(top + bottom) / (top - bottom)), 13 | (0, 0, -2 / (far - near), -(far + near) / (far - near)), 14 | (0, 0, 0, 1), 15 | )) 16 | 17 | 18 | def projection_matrix(K, w, h, near=10., far=10000.): # 1 cm to 10 m 19 | # transform from cv2 camera coordinates to opengl (flipping sign of y and z) 20 | view = np.eye(4) 21 | view[1:3] *= -1 22 | 23 | # see http://ksimek.github.io/2013/06/03/calibrated_cameras_in_opengl/ 24 | persp = np.zeros((4, 4)) 25 | persp[:2, :3] = K[:2, :3] 26 | persp[2, 2:] = near + far, near * far 27 | persp[3, 2] = -1 28 | # transform the camera matrix from cv2 to opengl as well (flipping sign of y and z) 29 | persp[:2, 1:3] *= -1 30 | 31 | # The origin of the image is in the *center* of the top left pixel. 32 | # The orthographic matrix should map the whole image *area* into the opengl NDC, therefore the -.5 below: 33 | orth = orthographic_matrix(-.5, w - .5, -.5, h - .5, near, far) 34 | return orth @ persp @ view 35 | 36 | 37 | class ObjCoordRenderer: 38 | def __init__(self, objs: Sequence[Obj], w: int, h: int = None, device_idx=0): 39 | self.objs = objs 40 | if h is None: 41 | h = w 42 | self.h, self.w = h, w 43 | self.ctx = moderngl.create_context(standalone=True, backend='egl', device_index=device_idx) 44 | self.ctx.disable(moderngl.CULL_FACE) 45 | self.ctx.enable(moderngl.DEPTH_TEST) 46 | self.fbo = self.ctx.simple_framebuffer((w, h), components=4, dtype='f4') 47 | self.near, self.far = 10., 10000., 48 | 49 | self.prog = self.ctx.program( 50 | vertex_shader=""" 51 | #version 330 52 | uniform vec3 offset; 53 | uniform float scale; 54 | uniform mat4 mvp; 55 | in vec3 in_vert; 56 | out vec3 color; 57 | void main() { 58 | gl_Position = mvp * vec4(in_vert, 1.0); 59 | color = (in_vert - offset) / scale; 60 | } 61 | """, 62 | fragment_shader=""" 63 | #version 330 64 | out vec4 fragColor; 65 | in vec3 color; 66 | void main() { 67 | fragColor = vec4(color, 1.0); 68 | } 69 | """, 70 | ) 71 | 72 | self.vaos = [] 73 | for obj in self.objs: 74 | vertices = obj.mesh.vertices[obj.mesh.faces].astype('f4') # (n, 3) 75 | vao = self.ctx.simple_vertex_array(self.prog, self.ctx.buffer(vertices), 'in_vert') 76 | self.vaos.append(vao) 77 | 78 | def read(self): 79 | return np.frombuffer(self.fbo.read(components=4, dtype='f4'), 'f4').reshape((self.h, self.w, 4)) 80 | 81 | def read_depth(self): 82 | depth = np.frombuffer(self.fbo.read(attachment=-1, dtype='f4'), 'f4').reshape(self.h, self.w) 83 | neg_mask = depth == 1 84 | near, far = 10., 10000. 
# TODO: use projection matrix instead of the default values 85 | depth = 2 * depth - 1 86 | depth = 2 * near * far / (far + near - depth * (far - near)) 87 | depth[neg_mask] = 0 88 | return depth 89 | 90 | def render(self, obj_idx, K, R, t, clear=True, read=True, read_depth=False): 91 | obj = self.objs[obj_idx] 92 | mv = np.concatenate(( 93 | np.concatenate((R, t), axis=1), 94 | [[0, 0, 0, 1]], 95 | )) 96 | mvp = projection_matrix(K, self.w, self.h, self.near, self.far) @ mv 97 | self.prog['mvp'].value = tuple(mvp.T.astype('f4').reshape(-1)) 98 | self.prog['scale'].value = obj.scale 99 | self.prog['offset'].value = tuple(obj.offset.astype('f4')) 100 | 101 | self.fbo.use() 102 | if clear: 103 | self.ctx.clear() 104 | self.vaos[obj_idx].render(mode=moderngl.TRIANGLES) 105 | if read_depth: 106 | return self.read_depth() 107 | elif read: 108 | return self.read() 109 | else: 110 | return None 111 | 112 | @staticmethod 113 | def extract_mask(model_coords_img: np.ndarray): 114 | return model_coords_img[..., 3] == 255 115 | 116 | def denormalize(self, model_coords: np.ndarray, obj_idx: int): 117 | return model_coords * self.objs[obj_idx].scale + self.objs[obj_idx].offset 118 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SurfEmb 2 | 3 | **SurfEmb: Dense and Continuous Correspondence Distributions 4 | for Object Pose Estimation with Learnt Surface Embeddings** 5 | Rasmus Laurvig Haugard, Anders Glent Buch 6 | IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2022 7 | [pre-print](https://arxiv.org/abs/2111.13489) | 8 | [project-site](https://surfemb.github.io/) 9 | 10 | The easiest way to explore correspondence distributions is through the [project site](https://surfemb.github.io/). 11 | 12 | The following describes how to reproduce the results. 13 | 14 | ## Install 15 | 16 | Download surfemb: 17 | 18 | ```shell 19 | $ git clone https://github.com/rasmushaugaard/surfemb.git 20 | $ cd surfemb 21 | ``` 22 | 23 | All following commands are expected to be run in the project root directory. 24 | 25 | [Install conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) 26 | , create a new environment, *surfemb*, and activate it: 27 | 28 | ```shell 29 | $ conda env create -f environment.yml 30 | $ conda activate surfemb 31 | ``` 32 | 33 | ## Download BOP data 34 | 35 | Download and extract datasets from the [BOP site](https://bop.felk.cvut.cz/datasets/). 36 | *Base archive*, and *object models* are needed for both training and inference. For training, *PBR-BlenderProc4BOP 37 | training images* are needed as well, and for inference, the *BOP'19/20 test images* are needed. 38 | 39 | Extract the datasets under ```data/bop``` (or make a symbolic link). 40 | 41 | ## Model 42 | 43 | Download a trained model (see *releases*): 44 | 45 | ```shell 46 | $ wget https://github.com/rasmushaugaard/surfemb/releases/download/v0.0.1/tless-2rs64lwh.compact.ckpt -P data/models 47 | ``` 48 | 49 | **OR** 50 | 51 | Train a model: 52 | 53 | ```shell 54 | $ python -m surfemb.scripts.train [dataset] --gpus [gpu ids] 55 | ``` 56 | 57 | For example, to train a model on *T-LESS* on *cuda:0* 58 | 59 | ```shell 60 | $ python -m surfemb.scripts.train tless --gpus 0 61 | ``` 62 | 63 | ## Inference data 64 | 65 | We use the detections from [CosyPose's](https://github.com/ylabbe/cosypose) MaskRCNN models, and sample surface points 66 | evenly for inference. 
67 | For ease of use, this data can be downloaded and extracted as follows: 68 | 69 | ```shell 70 | $ wget https://github.com/rasmushaugaard/surfemb/releases/download/v0.0.1/inference_data.zip 71 | $ unzip inference_data.zip 72 | ``` 73 | 74 | **OR** 75 | 76 |
77 | Extract detections and sample surface points 78 | 79 | ### Surface samples 80 | 81 | First, flip the normals of ITODD object 18, which is inside out. 82 | 83 | Then remove invisible parts of the objects 84 | 85 | ```shell 86 | $ python -m surfemb.scripts.misc.surface_samples_remesh_visible [dataset] 87 | ``` 88 | 89 | sample points evenly from the mesh surface 90 | 91 | ```shell 92 | $ python -m surfemb.scripts.misc.surface_samples_sample_even [dataset] 93 | ``` 94 | 95 | and recover the normals for the sampled points. 96 | 97 | ```shell 98 | $ python -m surfemb.scripts.misc.surface_samples_recover_normals [dataset] 99 | ``` 100 | 101 | ### Detection results 102 | 103 | Download CosyPose in the same directory as SurfEmb was downloaded in, install CosyPose and follow their guide to 104 | download their BOP-trained detection results. Then: 105 | 106 | ```shell 107 | $ python -m surfemb.scripts.misc.load_detection_results [dataset] 108 | ``` 109 | 110 |
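The detection results are stored as parallel NumPy arrays with one entry per detection, which the inference scripts load directly. A minimal sketch for inspecting them (file names follow `surfemb/data/detector_crops.py`; `tless` is only an example dataset):

```python
from pathlib import Path

import numpy as np

det = Path('data/detection_results/tless')
bboxes = np.load(det / 'bboxes.npy')        # (N, 4) left, top, right, bottom in pixels
obj_ids = np.load(det / 'obj_ids.npy')      # (N,) BOP object ids
scene_ids = np.load(det / 'scene_ids.npy')  # (N,) test scene per detection
view_ids = np.load(det / 'view_ids.npy')    # (N,) test image per detection
scores = np.load(det / 'scores.npy')        # (N,) detector confidences
print(f'{len(scores)} detections of {len(np.unique(obj_ids))} object classes')
```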
111 | 112 | ## Inference inspection 113 | 114 | To see pose estimation examples on the training images run 115 | 116 | ```shell 117 | $ python -m surfemb.scripts.infer_debug [model_path] --device [device] 118 | ``` 119 | 120 | *[device]* could for example be *cuda:0* or *cpu*. 121 | 122 | Add ```--real``` to use the test images with simulated crops based on the ground truth poses, or further 123 | add ```--detections``` to use the CosyPose detections. 124 | 125 | ## Inference for BOP evaluation 126 | 127 | Inference is run on the (real) test images with CosyPose detections: 128 | 129 | ```shell 130 | $ python -m surfemb.scripts.infer [model_path] --device [device] 131 | ``` 132 | 133 | Pose estimation results are saved to ```data/results```. 134 | To obtain results with depth (requires running normal inference first), run 135 | 136 | ```shell 137 | $ python -m surfemb.scripts.infer_refine_depth [model_path] --device [device] 138 | ``` 139 | 140 | The results can be formatted for BOP evaluation using 141 | 142 | ```shell 143 | $ python -m surfemb.scripts.misc.format_results_for_eval [poses_path] 144 | ``` 145 | 146 | Either upload the formatted results to the BOP Challenge website or evaluate using 147 | the [BOP toolkit](https://github.com/thodan/bop_toolkit). 148 | 149 | ## Extra 150 | 151 | Custom dataset: 152 | Format the dataset as a BOP dataset and put it in *data/bop*. -------------------------------------------------------------------------------- /surfemb/scripts/train.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | 4 | import torch.utils.data 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.loggers import WandbLogger 7 | import wandb 8 | 9 | from .. 
import utils 10 | from ..data import obj, instance 11 | from ..data.config import config 12 | from ..surface_embedding import SurfaceEmbeddingModel 13 | 14 | 15 | def worker_init_fn(*_): 16 | # each worker should only use one os thread 17 | # numpy/cv2 takes advantage of multithreading by default 18 | import os 19 | os.environ['OPENBLAS_NUM_THREADS'] = '1' 20 | os.environ['MKL_NUM_THREADS'] = '1' 21 | import cv2 22 | cv2.setNumThreads(0) 23 | 24 | # random seed 25 | import numpy as np 26 | np.random.seed(None) 27 | 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('dataset') 32 | parser.add_argument('--n-valid', type=int, default=200) 33 | parser.add_argument('--res-data', type=int, default=256) 34 | parser.add_argument('--res-crop', type=int, default=224) 35 | parser.add_argument('--batch-size', type=int, default=16) 36 | parser.add_argument('--num-workers', type=int, default=None) 37 | parser.add_argument('--min-visib-fract', type=float, default=0.1) 38 | parser.add_argument('--max-steps', type=int, default=500_000) 39 | parser.add_argument('--gpus', type=int, nargs='+', default=[0]) 40 | parser.add_argument('--debug', action='store_true') 41 | parser.add_argument('--ckpt', default=None) 42 | parser.add_argument('--no-synth', dest='synth', action='store_false') 43 | parser.add_argument('--real', action='store_true') 44 | 45 | parser = SurfaceEmbeddingModel.model_specific_args(parser) 46 | args = parser.parse_args() 47 | debug = args.debug 48 | root = Path('data/bop') / args.dataset 49 | cfg = config[args.dataset] 50 | 51 | # load objs 52 | objs, obj_ids = obj.load_objs(root / cfg.model_folder) 53 | assert len(obj_ids) > 0 54 | 55 | # model 56 | if args.ckpt: 57 | assert args.dataset == Path(args.ckpt).name.split('-')[0] 58 | model = SurfaceEmbeddingModel.load_from_checkpoint(args.ckpt) 59 | else: 60 | model = SurfaceEmbeddingModel(n_objs=len(obj_ids), **vars(args)) 61 | 62 | # datasets 63 | auxs = model.get_auxs(objs, args.res_crop) 64 | data = utils.EmptyDataset() 65 | if args.synth: 66 | data += instance.BopInstanceDataset( 67 | dataset_root=root, pbr=True, test=False, cfg=cfg, obj_ids=obj_ids, auxs=auxs, 68 | min_visib_fract=args.min_visib_fract, scene_ids=[1] if debug else None, 69 | ) 70 | if args.real: 71 | assert args.dataset in {'tless', 'tudl', 'ycbv'} 72 | data_real = instance.BopInstanceDataset( 73 | dataset_root=root, pbr=False, test=False, cfg=cfg, obj_ids=obj_ids, auxs=auxs, 74 | min_visib_fract=args.min_visib_fract, scene_ids=[1] if debug else None, 75 | ) 76 | if args.synth: 77 | data = utils.balanced_dataset_concat(data, data_real) 78 | else: 79 | data = data_real 80 | 81 | n_valid = args.n_valid 82 | data_train, data_valid = torch.utils.data.random_split( 83 | data, (len(data) - n_valid, n_valid), 84 | generator=torch.Generator().manual_seed(0), 85 | ) 86 | 87 | loader_args = dict( 88 | batch_size=args.batch_size, 89 | num_workers=torch.get_num_threads() if args.num_workers is None else args.num_workers, 90 | persistent_workers=True, shuffle=True, 91 | worker_init_fn=worker_init_fn, pin_memory=True, 92 | ) 93 | loader_train = torch.utils.data.DataLoader(data_train, drop_last=True, **loader_args) 94 | loader_valid = torch.utils.data.DataLoader(data_valid, **loader_args) 95 | 96 | # train 97 | log_dir = Path('data/logs') 98 | log_dir.mkdir(parents=True, exist_ok=True) 99 | run = wandb.init(project='surfemb', dir=log_dir) 100 | run.name = run.id 101 | 102 | logger = pl.loggers.WandbLogger(experiment=run) 103 | 
logger.log_hyperparams(args) 104 | 105 | model_ckpt_cb = pl.callbacks.ModelCheckpoint(dirpath='data/models/', save_top_k=0, save_last=True) 106 | model_ckpt_cb.CHECKPOINT_NAME_LAST = f'{args.dataset}-{run.id}' 107 | trainer = pl.Trainer( 108 | resume_from_checkpoint=args.ckpt, 109 | logger=logger, gpus=args.gpus, max_steps=args.max_steps, 110 | callbacks=[ 111 | pl.callbacks.LearningRateMonitor(), 112 | model_ckpt_cb, 113 | ], 114 | val_check_interval=min(1., n_valid / len(data) * 50) # spend ~1/50th of the time on validation 115 | ) 116 | trainer.fit(model, loader_train, loader_valid) 117 | 118 | 119 | if __name__ == '__main__': 120 | main() 121 | -------------------------------------------------------------------------------- /surfemb/scripts/infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import torch 6 | from tqdm import tqdm 7 | 8 | from .. import utils 9 | from ..data import detector_crops 10 | from ..data.config import config 11 | from ..data.obj import load_objs 12 | from ..data.renderer import ObjCoordRenderer 13 | from ..surface_embedding import SurfaceEmbeddingModel 14 | from ..pose_est import estimate_pose 15 | from ..pose_refine import refine_pose 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('model_path') 19 | parser.add_argument('--device', default='cuda:0') 20 | parser.add_argument('--res-data', type=int, default=256) 21 | parser.add_argument('--res-crop', type=int, default=224) 22 | parser.add_argument('--max-poses', type=int, default=10000) 23 | parser.add_argument('--max-pose-evaluations', type=int, default=1000) 24 | parser.add_argument('--no-rotation-ensemble', dest='rotation_ensemble', action='store_false') 25 | 26 | args = parser.parse_args() 27 | res_crop = args.res_crop 28 | device = torch.device(args.device) 29 | model_path = Path(args.model_path) 30 | assert model_path.is_file() 31 | model_name = model_path.name.split('.')[0] 32 | dataset = model_name.split('-')[0] 33 | 34 | results_dir = Path('data/results') 35 | results_dir.mkdir(exist_ok=True) 36 | poses_fp = results_dir / f'{model_name}-poses.npy' 37 | poses_scores_fp = results_dir / f'{model_name}-poses-scores.npy' 38 | poses_timings_fp = results_dir / f'{model_name}-poses-timings.npy' 39 | for fp in poses_fp, poses_scores_fp, poses_timings_fp: 40 | assert not fp.exists() 41 | 42 | # load model 43 | model = SurfaceEmbeddingModel.load_from_checkpoint(str(model_path)).eval().to(device) # type: SurfaceEmbeddingModel 44 | model.freeze() 45 | 46 | # load data 47 | root = Path('data/bop') / dataset 48 | cfg = config[dataset] 49 | objs, obj_ids = load_objs(root / cfg.model_folder) 50 | assert len(obj_ids) > 0 51 | surface_samples, surface_sample_normals = utils.load_surface_samples(dataset, obj_ids) 52 | data = detector_crops.DetectorCropDataset( 53 | dataset_root=root, cfg=cfg, obj_ids=obj_ids, 54 | detection_folder=Path(f'data/detection_results/{dataset}'), 55 | auxs=model.get_infer_auxs(objs=objs, crop_res=res_crop, from_detections=True) 56 | ) 57 | renderer = ObjCoordRenderer(objs, w=res_crop, h=res_crop) 58 | 59 | # infer 60 | all_poses = np.empty((2, len(data), 3, 4)) 61 | all_scores = np.ones(len(data)) * -np.inf 62 | time_forward, time_pnpransac, time_refine = [], [], [] 63 | 64 | 65 | def infer(i, d): 66 | obj_idx = d['obj_idx'] 67 | img = d['rgb_crop'] 68 | K_crop = d['K_crop'] 69 | 70 | with utils.add_timing_to_list(time_forward): 71 | mask_lgts, query_img = 
model.infer_cnn(img, obj_idx, rotation_ensemble=args.rotation_ensemble) 72 | mask_lgts[0, 0].item() # synchronize for timing 73 | 74 | # keys are independent of input (could be cached, but it's not the bottleneck) 75 | obj_ = objs[obj_idx] 76 | verts = surface_samples[obj_idx] 77 | verts_norm = (verts - obj_.offset) / obj_.scale 78 | obj_keys = model.infer_mlp(torch.from_numpy(verts_norm).float().to(device), obj_idx) 79 | verts = torch.from_numpy(verts).float().to(device) 80 | 81 | with utils.add_timing_to_list(time_pnpransac): 82 | R_est, t_est, scores, *_ = estimate_pose( 83 | mask_lgts=mask_lgts, query_img=query_img, 84 | obj_pts=verts, obj_normals=surface_sample_normals[obj_idx], obj_keys=obj_keys, 85 | obj_diameter=obj_.diameter, K=K_crop, 86 | ) 87 | success = len(scores) > 0 88 | if success: 89 | best_idx = torch.argmax(scores).item() 90 | all_scores[i] = scores[best_idx].item() 91 | R_est, t_est = R_est[best_idx].cpu().numpy(), t_est[best_idx].cpu().numpy()[:, None] 92 | else: 93 | R_est, t_est = np.eye(3), np.zeros((3, 1)) 94 | 95 | with utils.add_timing_to_list(time_refine): 96 | if success: 97 | R_est_r, t_est_r, score_r = refine_pose( 98 | R=R_est, t=t_est, query_img=query_img, K_crop=K_crop, 99 | renderer=renderer, obj_idx=obj_idx, obj_=obj_, model=model, keys_verts=obj_keys, 100 | ) 101 | else: 102 | R_est_r, t_est_r = R_est, t_est 103 | 104 | for j, (R, t) in enumerate([(R_est, t_est), (R_est_r, t_est_r)]): 105 | all_poses[j, i, :3, :3] = R 106 | all_poses[j, i, :3, 3:] = t 107 | 108 | 109 | for i, d in enumerate(tqdm(data, desc='running pose est.', smoothing=0)): 110 | infer(i, d) 111 | 112 | time_forward = np.array(time_forward) 113 | time_pnpransac = np.array(time_pnpransac) 114 | time_refine = np.array(time_refine) 115 | 116 | timings = np.stack(( 117 | time_forward + time_pnpransac, 118 | time_forward + time_pnpransac + time_refine 119 | )) 120 | 121 | np.save(str(poses_fp), all_poses) 122 | np.save(str(poses_scores_fp), all_scores) 123 | np.save(str(poses_timings_fp), timings) 124 | -------------------------------------------------------------------------------- /surfemb/scripts/misc/render_poses.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from pathlib import Path 4 | from collections import defaultdict 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | from ...data.renderer import ObjCoordRenderer 10 | from ...data.obj import load_objs 11 | from ...data.config import config 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('dataset') 15 | parser.add_argument('poses') 16 | parser.add_argument('--scene_id', type=int, default=None) 17 | parser.add_argument('--view_id', type=int, default=None) 18 | parser.add_argument('--no-refine', action='store_true') 19 | parser.add_argument('--no-bbox', action='store_true') 20 | parser.add_argument('--render-all', action='store_true') 21 | parser.add_argument('--alpha', type=float, default=0.7) 22 | 23 | args = parser.parse_args() 24 | 25 | dataset = args.dataset 26 | cfg = config[dataset] 27 | poses_fp = Path(args.poses) 28 | name = '-'.join(poses_fp.name.split('-')[:-1]) 29 | pose_scores_fp = poses_fp.parent / f'{name.replace("-depth", "")}-poses-scores.npy' 30 | poses = np.load(str(poses_fp))[0 if args.no_refine else 1] 31 | pose_scores = np.load(str(pose_scores_fp)) 32 | show_bbox = not args.no_bbox 33 | 34 | z = poses[:, 2, 3] 35 | 36 | root = Path('data/bop') / dataset 37 | objs, obj_ids = load_objs(root / cfg.model_folder) 38 | 
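# The BOP19 test targets (test_targets_bop19.json) give the expected instance count per
# (scene, image, object); below, the top-scoring detections per target are drawn in green and the
# remainder in red, and poses are rendered only for the top-scoring hypotheses unless --render-all is given.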
39 | datafolder = root / cfg.test_folder 40 | 41 | 42 | def recursive_dict(*_): 43 | return defaultdict(recursive_dict) 44 | 45 | 46 | inst_count = recursive_dict() 47 | for t in json.load(open(f'data/bop/{args.dataset}/test_targets_bop19.json')): 48 | inst_count[t['scene_id']][t['im_id']][t['obj_id']] = t['inst_count'] 49 | 50 | detection_folder = Path('data/detection_results') / dataset 51 | bboxes, det_obj_ids, scene_ids, det_scores, view_ids = [ 52 | np.load(str(detection_folder / f'{name}.npy')) for name in 53 | ('bboxes', 'obj_ids', 'scene_ids', 'scores', 'view_ids') 54 | ] 55 | 56 | start_idx = list(inst_count.keys()).index(args.scene_id) if args.scene_id is not None else 0 57 | for scene_id, scene in list(inst_count.items())[start_idx:]: 58 | scene_folder = datafolder / f'{scene_id:06d}' 59 | scene_camera = json.load((scene_folder / 'scene_camera.json').open()) 60 | 61 | print(f'scene {scene_id} has {len(scene)} views') 62 | start_idx = list(scene.keys()).index(args.view_id) if args.view_id is not None and scene_id == args.scene_id else 0 63 | for view_id, view in list(scene.items())[start_idx:]: 64 | print(scene_id, view_id) 65 | K = np.array(scene_camera[str(view_id)]['cam_K']).reshape(3, 3) 66 | img = cv2.imread(str(scene_folder / cfg.img_folder / f'{view_id:06d}.{cfg.img_ext}')) 67 | assert img is not None 68 | img_ = img.copy() 69 | h, w = img.shape[:2] 70 | 71 | renderer = ObjCoordRenderer(objs, w=w, h=h) 72 | renderer.ctx.clear() 73 | for obj_id, count in view.items(): 74 | obj_idx = obj_ids.index(obj_id) 75 | mask = (scene_ids == scene_id) & (view_ids == view_id) & (det_obj_ids == obj_id) 76 | if mask.sum() > count: 77 | score_threshold = sorted(det_scores[mask])[-count] 78 | pose_score_threshold = sorted(pose_scores[mask])[-count] 79 | else: 80 | score_threshold = -np.inf 81 | pose_score_threshold = -np.inf 82 | for bbox, det_score, pose, pose_score in zip(bboxes[mask], det_scores[mask], 83 | poses[mask], pose_scores[mask]): 84 | c = (0, 0, 255) if det_score < score_threshold else (0, 255, 0) 85 | if show_bbox: 86 | l, t, r, b = bbox.round().astype(int) 87 | cv2.rectangle(img, (l, t), (r, b), c) 88 | cv2.putText(img, f'{obj_id}/{det_score:.3f}', (l + 2, b - 2), cv2.FONT_HERSHEY_PLAIN, 1, c) 89 | if pose_score >= pose_score_threshold or args.render_all: 90 | R, t = pose[:3, :3], pose[:3, 3:] 91 | if np.allclose(t[:, 0], (0, 0, 0)): 92 | print('no pose found') 93 | else: 94 | renderer.render(obj_idx=obj_idx, K=K, R=R, t=t, clear=False, read=False) 95 | 96 | render = renderer.read().copy() 97 | mask = render[..., 3] != 0 98 | render = render[..., :3] * 0.5 + 0.5 99 | render_vis = np.tile(cv2.cvtColor(img_, cv2.COLOR_BGR2GRAY)[..., None], (1, 1, 3)) / 255 100 | # print(render_vis.shape) 101 | render_vis[mask] = render[mask] * args.alpha + render_vis[mask] * (1 - args.alpha) 102 | if show_bbox: 103 | for obj_id, count in view.items(): 104 | mask = (scene_ids == scene_id) & (view_ids == view_id) & (det_obj_ids == obj_id) 105 | if mask.sum() > count: 106 | score_threshold = sorted(pose_scores[mask])[-count] 107 | else: 108 | score_threshold = -np.inf 109 | for bbox, pose_score in zip(bboxes[mask], pose_scores[mask]): 110 | c = (0, 0, 1.) 
if pose_score < score_threshold else (0, 1., 0) 111 | l, t, r, b = bbox.round().astype(int) 112 | cv2.rectangle(render_vis, (l, t), (r, b), c) 113 | cv2.putText(render_vis, f'{obj_id}/{pose_score:.3f}', (l + 2, b - 2), cv2.FONT_HERSHEY_PLAIN, 1, c) 114 | 115 | cv2.imshow('render', render_vis) 116 | # cv2.imshow('', img) 117 | next_scene = False 118 | while True: 119 | key = cv2.waitKey() 120 | if key == ord('q'): 121 | quit() 122 | elif key == ord('n'): 123 | next_scene = True 124 | elif key in {225, 233}: # shift / alt 125 | continue 126 | else: 127 | pass # print(key) 128 | break 129 | if next_scene: 130 | break 131 | -------------------------------------------------------------------------------- /surfemb/data/std_auxs.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from .instance import BopInstanceDataset, BopInstanceAux 7 | from .tfms import normalize 8 | 9 | 10 | class RgbLoader(BopInstanceAux): 11 | def __init__(self, copy=False): 12 | self.copy = copy 13 | 14 | def __call__(self, inst: dict, dataset: BopInstanceDataset) -> dict: 15 | scene_id, img_id = inst['scene_id'], inst['img_id'] 16 | fp = dataset.data_folder / f'{scene_id:06d}/{dataset.img_folder}/{img_id:06d}.{dataset.img_ext}' 17 | rgb = cv2.imread(str(fp), cv2.IMREAD_COLOR)[..., ::-1] 18 | assert rgb is not None 19 | inst['rgb'] = rgb.copy() if self.copy else rgb 20 | return inst 21 | 22 | 23 | class MaskLoader(BopInstanceAux): 24 | def __init__(self, mask_type='mask_visib'): 25 | self.mask_type = mask_type 26 | 27 | def __call__(self, inst: dict, dataset: BopInstanceDataset) -> dict: 28 | scene_id, img_id, pose_idx = inst['scene_id'], inst['img_id'], inst['pose_idx'] 29 | mask_folder = dataset.data_folder / f'{scene_id:06d}' / self.mask_type 30 | mask = cv2.imread(str(mask_folder / f'{img_id:06d}_{pose_idx:06d}.png'), cv2.IMREAD_GRAYSCALE) 31 | assert mask is not None 32 | inst[self.mask_type] = mask 33 | return inst 34 | 35 | 36 | class RandomRotatedMaskCrop(BopInstanceAux): 37 | def __init__(self, crop_res: int, crop_scale=1.2, max_angle=np.pi, mask_key='mask_visib', 38 | crop_keys=('rgb', 'mask_visib'), offset_scale=1., use_bbox=False, 39 | rgb_interpolation=(cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC)): 40 | self.crop_res, self.crop_scale = crop_res, crop_scale 41 | self.max_angle = max_angle 42 | self.mask_key = mask_key 43 | self.crop_keys = crop_keys 44 | self.rgb_interpolation = rgb_interpolation 45 | self.offset_scale = offset_scale 46 | self.use_bbox = use_bbox 47 | self.definition_aux = RandomRotatedMaskCropDefinition(self) 48 | self.apply_aux = RandomRotatedMaskCropApply(self) 49 | 50 | def __call__(self, inst: dict, _) -> dict: 51 | inst = self.definition_aux(inst, _) 52 | inst = self.apply_aux(inst, _) 53 | return inst 54 | 55 | 56 | class RandomRotatedMaskCropDefinition(BopInstanceAux): 57 | def __init__(self, parent: RandomRotatedMaskCrop): 58 | self.p = parent 59 | 60 | def __call__(self, inst: dict, _) -> dict: 61 | theta = np.random.uniform(-self.p.max_angle, self.p.max_angle) 62 | S, C = np.sin(theta), np.cos(theta) 63 | R = np.array(( 64 | (C, -S), 65 | (S, C), 66 | )) 67 | 68 | if self.p.use_bbox: 69 | left, top, right, bottom = inst['bbox'] 70 | else: 71 | mask_arg_rotated = np.argwhere(inst[self.p.mask_key])[:, ::-1] @ R.T 72 | left, top = mask_arg_rotated.min(axis=0) 73 | right, bottom = mask_arg_rotated.max(axis=0) 74 | cy, cx = (top + bottom) / 2, (left + 
right) / 2 75 | 76 | # detector crops can probably be simulated better than this 77 | size = self.p.crop_res / max(bottom - top, right - left) / self.p.crop_scale 78 | size = size * np.random.uniform(1 - 0.05 * self.p.offset_scale, 1 + 0.05 * self.p.offset_scale) 79 | r = self.p.crop_res 80 | M = np.concatenate((R, [[-cx], [-cy]]), axis=1) * size 81 | M[:, 2] += r / 2 82 | 83 | offset = (r - r / self.p.crop_scale) / 2 * self.p.offset_scale 84 | M[:, 2] += np.random.uniform(-offset, offset, 2) 85 | Ms = np.concatenate((M, [[0, 0, 1]])) 86 | 87 | # calculate axis aligned bounding box in the original image of the rotated crop 88 | crop_corners = np.array(((0, 0, 1), (0, r, 1), (r, 0, 1), (r, r, 1))) - (0.5, 0.5, 0) # (4, 3) 89 | crop_corners = np.linalg.inv(Ms) @ crop_corners.T # (3, 4) 90 | crop_corners = crop_corners[:2] / crop_corners[2:] # (2, 4) 91 | left, top = np.floor(crop_corners.min(axis=1)).astype(int) 92 | right, bottom = np.ceil(crop_corners.max(axis=1)).astype(int) + 1 93 | left, top = np.maximum((left, top), 0) 94 | right, bottom = np.maximum((right, bottom), (left + 1, top + 1)) 95 | inst['AABB_crop'] = left, top, right, bottom 96 | 97 | inst['M_crop'] = M 98 | inst['K_crop'] = Ms @ inst['K'] 99 | return inst 100 | 101 | 102 | class RandomRotatedMaskCropApply(BopInstanceAux): 103 | def __init__(self, parent: RandomRotatedMaskCrop): 104 | self.p = parent 105 | 106 | def __call__(self, inst: dict, _) -> dict: 107 | r = self.p.crop_res 108 | for crop_key in self.p.crop_keys: 109 | im = inst[crop_key] 110 | interp = cv2.INTER_LINEAR if im.ndim == 2 else np.random.choice(self.p.rgb_interpolation) 111 | inst[f'{crop_key}_crop'] = cv2.warpAffine(im, inst['M_crop'], (r, r), flags=interp) 112 | return inst 113 | 114 | 115 | class TransformsAux(BopInstanceAux): 116 | def __init__(self, tfms, key='rgb_crop', crop_key=None): 117 | self.key = key 118 | self.tfms = tfms 119 | self.crop_key = crop_key 120 | 121 | def __call__(self, inst: dict, _) -> dict: 122 | if self.crop_key is not None: 123 | left, top, right, bottom = inst[self.crop_key] 124 | img_slice = slice(top, bottom), slice(left, right) 125 | else: 126 | img_slice = slice(None) 127 | img = inst[self.key] 128 | img[img_slice] = self.tfms(image=img[img_slice])['image'] 129 | return inst 130 | 131 | 132 | class NormalizeAux(BopInstanceAux): 133 | def __init__(self, key='rgb_crop', suffix=''): 134 | self.key = key 135 | self.suffix = suffix 136 | 137 | def __call__(self, inst: dict, _) -> dict: 138 | inst[f'{self.key}{self.suffix}'] = normalize(inst[self.key]) 139 | return inst 140 | 141 | 142 | class KeyFilterAux(BopInstanceAux): 143 | def __init__(self, keys=Set[str]): 144 | self.keys = keys 145 | 146 | def __call__(self, inst: dict, _) -> dict: 147 | return {k: v for k, v in inst.items() if k in self.keys} 148 | -------------------------------------------------------------------------------- /surfemb/scripts/infer_refine_depth.py: -------------------------------------------------------------------------------- 1 | """ 2 | Depth refinement: 3 | * run image crop through model 4 | * mask-multiply query with both estimated mask probs and the pose est mask 5 | * use query norm threshold to estimate a conservative 2d mask 6 | (estimating the visible mask in the encoder-decoder would make more sense.) 
7 | * find median of depth difference within the 2d mask (ignoring invalid depth) 8 | * find com with respect to the 2d mask and use that as the ray to adjust the depth along 9 | """ 10 | 11 | import argparse 12 | import json 13 | from pathlib import Path 14 | 15 | import cv2 16 | import numpy as np 17 | import torch 18 | from tqdm import tqdm 19 | import matplotlib.pyplot as plt 20 | 21 | from ..utils import add_timing_to_list 22 | from ..data.config import config 23 | from ..data.obj import load_objs 24 | from ..data.renderer import ObjCoordRenderer 25 | from ..data.detector_crops import DetectorCropDataset 26 | from ..surface_embedding import SurfaceEmbeddingModel 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('model_path') 30 | parser.add_argument('--device', required=True) 31 | parser.add_argument('--debug', action='store_true') 32 | 33 | args = parser.parse_args() 34 | model_path = Path(args.model_path) 35 | name = model_path.name.split('.')[0] 36 | dataset = name.split('-')[0] 37 | device = torch.device(args.device) 38 | 39 | cfg = config[dataset] 40 | crop_res = 224 41 | root = Path('data/bop') / dataset 42 | test_folder = root / cfg.test_folder 43 | assert root.exists() 44 | 45 | poses = np.load(f'data/results/{name}-poses.npy') 46 | poses_timings = np.load(f'data/results/{name}-poses-timings.npy') 47 | poses_depth = poses.copy() 48 | poses_depth_fp = Path(f'data/results/{name}-depth-poses.npy') 49 | poses_depth_timings_fp = Path(f'data/results/{name}-depth-poses-timings.npy') 50 | for fp in poses_depth_fp, poses_depth_timings_fp: 51 | assert not fp.exists() 52 | 53 | model = SurfaceEmbeddingModel.load_from_checkpoint(args.model_path).to(device) 54 | model.eval() 55 | model.freeze() 56 | 57 | objs, obj_ids = load_objs(root / cfg.model_folder) 58 | dataset = DetectorCropDataset( 59 | dataset_root=root, obj_ids=obj_ids, cfg=cfg, detection_folder=Path(f'data/detection_results/{dataset}'), 60 | auxs=model.get_infer_auxs(objs=objs, crop_res=crop_res) 61 | ) 62 | assert poses.shape[1] == len(dataset) 63 | 64 | crop_renderer = ObjCoordRenderer(objs=objs, w=crop_res, h=crop_res) 65 | n_failed = 0 66 | all_depth_timings = [[], []] 67 | for j in range(2): 68 | depth_timings = all_depth_timings[j] 69 | for i, d in enumerate(tqdm(dataset)): 70 | pose = poses[j, i] # (3, 4) 71 | R = pose[:3, :3] 72 | t = pose[:3, 3:] 73 | 74 | obj_idx, K_crop, K = d['obj_idx'], d['K_crop'], d['K'] 75 | 76 | depth_sensor = cv2.imread( 77 | str(test_folder / f'{d["scene_id"]:06d}/depth/{d["img_id"]:06d}.{cfg.depth_ext}'), 78 | cv2.IMREAD_UNCHANGED 79 | ) 80 | scene_camera_fp = test_folder / f'{d["scene_id"]:06d}/scene_camera.json' 81 | depth_scale = json.load(scene_camera_fp.open())[str(d['img_id'])]['depth_scale'] 82 | depth_sensor = depth_sensor * depth_scale 83 | h, w = depth_sensor.shape 84 | 85 | mask_lgts, query_img = [v.cpu() for v in model.infer_cnn(d['rgb_crop'], obj_idx)] 86 | # the above either doesn't count towards the time (loading images), 87 | # or has already been done in the initial pose estimate (cnn forward pass) 88 | # so timing starts here: 89 | with add_timing_to_list(depth_timings): 90 | depth_sensor_mask = (depth_sensor > 0).astype(np.float32) 91 | M = (K_crop @ np.linalg.inv(K))[:2] 92 | depth_sensor_mask_crop = cv2.warpAffine(depth_sensor_mask, M, (crop_res, crop_res), 93 | flags=cv2.INTER_LINEAR) == 1. 
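# Bilinear warping turns the {0, 1} depth-validity mask into fractional values at the borders of
# invalid-depth regions; keeping only pixels that are exactly 1 discards crop pixels whose warped
# depth would mix valid and invalid sensor readings.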
94 | depth_sensor_crop = cv2.warpAffine(depth_sensor, M, (crop_res, crop_res), flags=cv2.INTER_LINEAR) 95 | depth_render = crop_renderer.render(obj_idx, K_crop, R, t, read_depth=True) 96 | render_mask = depth_render > 0 97 | 98 | query_img_norm = torch.norm(query_img, dim=-1) * torch.sigmoid(mask_lgts) 99 | query_img_norm = query_img_norm.numpy() * render_mask * depth_sensor_mask_crop 100 | norm_sum = query_img_norm.sum() 101 | if norm_sum == 0: 102 | n_failed += 1 103 | continue 104 | query_img_norm /= norm_sum 105 | norm_mask = query_img_norm > (query_img_norm.max() * 0.8) 106 | yy, xx = np.argwhere(norm_mask).T # 2 x (N,) 107 | depth_diff = depth_sensor_crop[yy, xx] - depth_render[yy, xx] 108 | depth_adjustment = np.median(depth_diff) 109 | 110 | yx_coords = np.meshgrid(np.arange(crop_res), np.arange(crop_res)) 111 | yx_coords = np.stack(yx_coords[::-1], axis=-1) # (crop_res, crop_res, 2yx) 112 | yx_ray_2d = (yx_coords * query_img_norm[..., None]).sum(axis=(0, 1)) # y, x 113 | ray_3d = np.linalg.inv(K_crop) @ (*yx_ray_2d[::-1], 1) 114 | ray_3d /= ray_3d[2] 115 | 116 | t_depth_refined = t + ray_3d[:, None] * depth_adjustment 117 | poses_depth[j, i, :3, 3:] = t_depth_refined 118 | 119 | if args.debug: 120 | axs = plt.subplots(3, 4, figsize=(12, 9))[1] 121 | axs[0, 0].imshow(d['rgb']) 122 | axs[0, 0].set_title('rgb') 123 | axs[0, 1].imshow(depth_sensor) 124 | axs[0, 1].set_title('depth') 125 | axs[0, 2].imshow(depth_render) 126 | axs[0, 2].set_title('depth render') 127 | 128 | axs[1, 0].imshow(d['rgb_crop']) 129 | axs[1, 0].set_title('rgb crop') 130 | axs[1, 0].scatter(xx, yy) 131 | axs[1, 1].imshow(depth_sensor_crop) 132 | axs[1, 1].set_title('depth crop') 133 | axs[1, 2].imshow(query_img_norm) 134 | axs[1, 2].scatter(*yx_ray_2d[None, ::-1].T, c='r') 135 | axs[1, 2].set_title('query norm') 136 | axs[1, 3].imshow(norm_mask) 137 | axs[1, 3].set_title('norm mask') 138 | 139 | axs[2, 0].imshow(d['rgb_crop']) 140 | axs[2, 0].imshow(render_mask, alpha=0.5) 141 | axs[2, 0].set_title('initial pose') 142 | render_mask_after = crop_renderer.render(obj_idx, K_crop, R, t_depth_refined, read_depth=True) > 0 143 | axs[2, 1].imshow(d['rgb_crop']) 144 | axs[2, 1].imshow(render_mask_after, alpha=0.5) 145 | axs[2, 1].set_title('pose after depth refine') 146 | axs[2, 3].hist(depth_diff) 147 | axs[2, 3].plot([depth_adjustment, depth_adjustment], [0, max(axs[2, 3].get_ylim())]) 148 | axs[2, 3].set_title('depth diff. hist.') 149 | for ax in axs.reshape(-1)[:-1]: 150 | ax.axis('off') 151 | plt.tight_layout() 152 | plt.show() 153 | 154 | poses_depth_timings = poses_timings + np.array(all_depth_timings) 155 | print(f'{n_failed / (len(dataset) * 2):.3f} failed') 156 | 157 | if not args.debug: 158 | np.save(str(poses_depth_fp), poses_depth) 159 | np.save(str(poses_depth_timings_fp), poses_depth_timings) 160 | -------------------------------------------------------------------------------- /surfemb/scripts/infer_debug.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import cv2 5 | import torch.utils.data 6 | import numpy as np 7 | 8 | from .. import utils 9 | from ..data import obj 10 | from ..data.config import config 11 | from ..data import instance 12 | from ..data import detector_crops 13 | from ..data.renderer import ObjCoordRenderer 14 | from ..surface_embedding import SurfaceEmbeddingModel 15 | from .. import pose_est 16 | from .. 
import pose_refine 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('model_path') 20 | parser.add_argument('--real', action='store_true') 21 | parser.add_argument('--detection', action='store_true') 22 | parser.add_argument('--i', type=int, default=0) 23 | parser.add_argument('--device', default='cuda:0') 24 | 25 | args = parser.parse_args() 26 | data_i = args.i 27 | device = torch.device(args.device) 28 | model_path = Path(args.model_path) 29 | 30 | model = SurfaceEmbeddingModel.load_from_checkpoint(args.model_path) 31 | model.eval() 32 | model.freeze() 33 | model.to(device) 34 | 35 | dataset = model_path.name.split('-')[0] 36 | real = args.real 37 | detection = args.detection 38 | root = Path('data/bop') / dataset 39 | cfg = config[dataset] 40 | res_crop = 224 41 | 42 | objs, obj_ids = obj.load_objs(root / cfg.model_folder) 43 | renderer = ObjCoordRenderer(objs, res_crop) 44 | assert len(obj_ids) == model.n_objs 45 | surface_samples, surface_sample_normals = utils.load_surface_samples(dataset, obj_ids) 46 | auxs = model.get_infer_auxs(objs=objs, crop_res=res_crop, from_detections=detection) 47 | dataset_args = dict(dataset_root=root, obj_ids=obj_ids, auxs=auxs, cfg=cfg) 48 | if detection: 49 | assert args.real 50 | data = detector_crops.DetectorCropDataset( 51 | **dataset_args, detection_folder=Path(f'detection_results/{dataset}') 52 | ) 53 | else: 54 | data = instance.BopInstanceDataset(**dataset_args, pbr=not args.real, test=args.real) 55 | 56 | # initialize opencv windows 57 | cols = 4 58 | window_names = 'img', 'mask_est', 'queries', 'keys', \ 59 | 'dist', 'xy', 'xz', 'yz', \ 60 | 'pose', 'mask_score', 'coord_score', 'query_norm' 61 | for j, name in enumerate(window_names): 62 | row = j // cols 63 | col = j % cols 64 | cv2.imshow(name, np.zeros((res_crop, res_crop))) 65 | cv2.moveWindow(name, 100 + 300 * col, 100 + 300 * row) 66 | 67 | print() 68 | print('With an opencv window active:') 69 | print("press 'a', 'd' and 'x'(random) to get a new input image,") 70 | print("press 'e' to estimate pose, and 'r' to refine pose estimate,") 71 | print("press 'g' to see the ground truth pose,") 72 | print("press 'q' to quit.") 73 | while True: 74 | print() 75 | print('------------ new input -------------') 76 | inst = data[data_i] 77 | obj_idx = inst['obj_idx'] 78 | img = inst['rgb_crop'] 79 | K_crop = inst['K_crop'] 80 | obj_ = objs[obj_idx] 81 | print(f'i: {data_i}, obj_id: {obj_ids[obj_idx]}') 82 | 83 | with utils.timer('forward_cnn'): 84 | mask_lgts, query_img = model.infer_cnn(img, obj_idx) 85 | 86 | mask_prob = torch.sigmoid(mask_lgts) 87 | query_vis = model.get_emb_vis(query_img) 88 | query_norm_img = torch.norm(query_img, dim=-1) * mask_prob 89 | query_norm_img /= query_norm_img.max() 90 | cv2.imshow('query_norm', query_norm_img.cpu().numpy()) 91 | 92 | dist_img = torch.zeros(res_crop, res_crop, device=model.device) 93 | 94 | verts_np = surface_samples[obj_idx] 95 | verts = torch.from_numpy(verts_np).float().to(device) 96 | normals = surface_sample_normals[obj_idx] 97 | verts_norm = (verts_np - obj_.offset) / obj_.scale 98 | with utils.timer('forward_mlp'): 99 | keys_verts = model.infer_mlp(torch.from_numpy(verts_norm).float().to(model.device), obj_idx) # (N, emb_dim) 100 | keys_means = keys_verts.mean(dim=0) # (emb_dim,) 101 | 102 | if not detection: 103 | coord_img = torch.from_numpy(inst['obj_coord']).to(device) 104 | key_img = model.infer_mlp(coord_img[..., :3], obj_idx) 105 | key_mask = coord_img[..., 3] == 1 106 | keys = key_img[key_mask] # (N, emb_dim) 107 | 
key_vis = model.get_emb_vis(key_img, mask=key_mask, demean=keys_means) 108 | 109 | # corr vis 110 | uv_names = 'xy', 'xz', 'yz' 111 | uv_slices = slice(1, None, -1), slice(2, None, -2), slice(2, 0, -1) 112 | uv_uniques = [] 113 | uv_all = ((verts_norm + 1) * (res_crop / 2 - .5)).round().astype(int) 114 | for uv_name, uv_slice in zip(uv_names, uv_slices): 115 | view_uvs_unique, view_uvs_unique_inv = np.unique(uv_all[:, uv_slice], axis=0, return_inverse=True) 116 | uv_uniques.append((view_uvs_unique, view_uvs_unique_inv)) 117 | 118 | # visualize 119 | img_vis = img[..., ::-1].astype(np.float32) / 255 120 | grey = cv2.cvtColor(img_vis, cv2.COLOR_BGR2GRAY) 121 | 122 | for win_name in (*uv_names, 'dist', 'pose', 'mask_score', 'coord_score', 'keys'): 123 | cv2.imshow(win_name, np.zeros((res_crop, res_crop))) 124 | 125 | cv2.imshow('img', img_vis) 126 | cv2.imshow('mask_est', torch.sigmoid(mask_lgts).cpu().numpy()) 127 | cv2.imshow('queries', query_vis.cpu().numpy()) 128 | if not detection: 129 | cv2.imshow('keys', key_vis.cpu().numpy()) 130 | 131 | last_mouse_pos = 0, 0 132 | uv_pts_3d = [] 133 | current_pose = None 134 | down_sample_scale = 3 135 | 136 | 137 | def mouse_cb(event, x, y, flags=0, *_): 138 | global last_mouse_pos 139 | if detection: 140 | return 141 | if flags & cv2.EVENT_FLAG_CTRLKEY: 142 | return 143 | last_mouse_pos = x, y 144 | q = query_img[y, x] 145 | p_mask = mask_prob[y, x] 146 | 147 | key_probs = torch.softmax(keys @ q, dim=0) # (N,) 148 | dist_img[key_mask] = key_probs / key_probs.max() * p_mask 149 | dist_vis = np.stack((grey, grey, dist_img.cpu().numpy()), axis=-1) 150 | cv2.circle(dist_vis, (x, y), 10, (1., 1., 1.), 1, cv2.LINE_AA) 151 | cv2.imshow('dist', dist_vis) 152 | 153 | vert_probs = torch.softmax(keys_verts @ q, dim=0).cpu().numpy() # (N,) 154 | for uv_name, uv_slice, (view_uvs_unique, view_uvs_unique_inv) in zip(uv_names, uv_slices, uv_uniques): 155 | uvs_unique_probs = np.zeros(len(view_uvs_unique), dtype=vert_probs.dtype) 156 | np.add.at(uvs_unique_probs, view_uvs_unique_inv, vert_probs) 157 | prob_img = np.zeros((224, 224, 3)) 158 | yy, xx = view_uvs_unique.T 159 | prob_img[yy, xx, 2] = uvs_unique_probs / uvs_unique_probs.max() * p_mask.cpu().numpy() 160 | prob_img[yy, xx, :2] = 0.1 161 | for p, c in zip(uv_pts_3d, np.eye(3)[::-1]): 162 | p_norm = ((p - obj_.offset) / obj_.scale)[uv_slice] 163 | p_uv = ((p_norm + 1) * (res_crop / 2 - .5)).round().astype(int) 164 | cv2.drawMarker(prob_img, tuple(p_uv[::-1]), c, cv2.MARKER_CROSS, 10) 165 | cv2.imshow(uv_name, prob_img[::-1]) 166 | 167 | 168 | for name in window_names: 169 | cv2.setMouseCallback(name, mouse_cb) 170 | 171 | 172 | def debug_pose_hypothesis(R, t, obj_pts=None, img_pts=None): 173 | global uv_pts_3d, current_pose 174 | current_pose = R, t 175 | render = renderer.render(obj_idx, K_crop, R, t) 176 | render_mask = render[..., 3] == 1. 
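# Overlay: blend the crop 50/50 with the rendered object coordinates, mapping the renderer's
# [-1, 1] coordinate range to [0, 1] (render * 0.25 + 0.25 == ((render + 1) / 2) * 0.5).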
177 | pose_img = img_vis.copy() 178 | pose_img[render_mask] = pose_img[render_mask] * 0.5 + render[..., :3][render_mask] * 0.25 + 0.25 179 | 180 | if obj_pts is not None: 181 | colors = np.eye(3)[::-1] 182 | for (x, y), c in zip(img_pts.astype(int), colors): 183 | cv2.drawMarker(pose_img, (x, y), tuple(c), cv2.MARKER_CROSS, 10) 184 | uv_pts_3d = obj_pts 185 | mouse_cb(None, *last_mouse_pos) 186 | 187 | cv2.imshow('pose', pose_img) 188 | 189 | poses = np.eye(4) 190 | poses[:3, :3] = R 191 | poses[:3, 3:] = t 192 | pose_est.estimate_pose( 193 | mask_lgts=mask_lgts, query_img=query_img, 194 | obj_pts=verts, obj_normals=normals, obj_keys=keys_verts, 195 | obj_diameter=obj_.diameter, K=K_crop, down_sample_scale=down_sample_scale, 196 | visualize=True, poses=poses[None], 197 | ) 198 | 199 | 200 | def estimate_pose(): 201 | print() 202 | with utils.timer('pnp ransac'): 203 | R, t, scores, mask_scores, coord_scores, dist_2d, size_mask, normals_mask = pose_est.estimate_pose( 204 | mask_lgts=mask_lgts, query_img=query_img, down_sample_scale=down_sample_scale, 205 | obj_pts=verts, obj_normals=normals, obj_keys=keys_verts, 206 | obj_diameter=obj_.diameter, K=K_crop, 207 | ) 208 | if not len(scores): 209 | print('no pose') 210 | return None 211 | else: 212 | R, t, scores, mask_scores, coord_scores = [a.cpu().numpy() for a in 213 | (R, t, scores, mask_scores, coord_scores)] 214 | best_pose_idx = np.argmax(scores) 215 | R_, t_ = R[best_pose_idx], t[best_pose_idx, :, None] 216 | debug_pose_hypothesis(R_, t_) 217 | return R_, t_ 218 | 219 | 220 | while True: 221 | print() 222 | key = cv2.waitKey() 223 | if key == ord('q'): 224 | quit() 225 | elif key == ord('a'): 226 | data_i = (data_i - 1) % len(data) 227 | break 228 | elif key == ord('d'): 229 | data_i = (data_i + 1) % len(data) 230 | break 231 | elif key == ord('x'): 232 | data_i = np.random.randint(len(data)) 233 | break 234 | elif key == ord('e'): 235 | print('pose est:') 236 | estimate_pose() 237 | elif key == ord('g'): 238 | print('gt:') 239 | debug_pose_hypothesis(inst['cam_R_obj'], inst['cam_t_obj']) 240 | elif key == ord('r'): 241 | print('refine:') 242 | if current_pose is not None: 243 | with utils.timer('refinement'): 244 | R, t, score_r = pose_refine.refine_pose( 245 | R=current_pose[0], t=current_pose[1], query_img=query_img, keys_verts=keys_verts, 246 | obj_idx=obj_idx, obj_=obj_, K_crop=K_crop, model=model, renderer=renderer, 247 | ) 248 | trace = np.trace(R @ current_pose[0].T) 249 | angle = np.arccos(np.clip((trace - 1) / 2, -1, 1)) 250 | print(f'refinement angle diff: {np.rad2deg(angle):.1f} deg') 251 | debug_pose_hypothesis(R, t) 252 | 253 | mouse_cb(None, *last_mouse_pos) 254 | -------------------------------------------------------------------------------- /surfemb/surface_embedding.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Sequence, Union 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | import numpy as np 7 | import wandb 8 | import albumentations as A 9 | import cv2 10 | import pytorch_lightning as pl 11 | 12 | from . import data 13 | from .dep.unet import ResNetUNet 14 | from .dep.siren import Siren 15 | from .data.obj import Obj 16 | from .data.tfms import denormalize 17 | from . 
import utils 18 | 19 | # could be extended to allow other mlp architectures 20 | mlp_class_dict = dict( 21 | siren=Siren 22 | ) 23 | 24 | 25 | class SurfaceEmbeddingModel(pl.LightningModule): 26 | def __init__(self, n_objs: int, emb_dim=12, n_pos=1024, n_neg=1024, lr_cnn=3e-4, lr_mlp=3e-5, 27 | mlp_name='siren', mlp_hidden_features=256, mlp_hidden_layers=2, 28 | key_noise=1e-3, warmup_steps=2000, separate_decoders=True, 29 | **kwargs): 30 | """ 31 | :param emb_dim: number of embedding dimensions 32 | :param n_pos: number of positive (q, k) pairs from the object mask 33 | :param n_neg: number of negative keys, k-, from the object surface 34 | """ 35 | super().__init__() 36 | self.save_hyperparameters() 37 | 38 | self.n_objs, self.emb_dim = n_objs, emb_dim 39 | self.n_pos, self.n_neg = n_pos, n_neg 40 | self.lr_cnn, self.lr_mlp = lr_cnn, lr_mlp 41 | self.warmup_steps = warmup_steps 42 | self.key_noise = key_noise 43 | self.separate_decoders = separate_decoders 44 | 45 | # query model 46 | self.cnn = ResNetUNet( 47 | n_class=(emb_dim + 1) if separate_decoders else n_objs * (emb_dim + 1), 48 | n_decoders=n_objs if separate_decoders else 1, 49 | ) 50 | # key models 51 | mlp_class = mlp_class_dict[mlp_name] 52 | mlp_args = dict(in_features=3, out_features=emb_dim, 53 | hidden_features=mlp_hidden_features, hidden_layers=mlp_hidden_layers) 54 | self.mlps = torch.nn.Sequential(*[mlp_class(**mlp_args) for _ in range(n_objs)]) 55 | 56 | @staticmethod 57 | def model_specific_args(parent_parser: argparse.ArgumentParser): 58 | parser = parent_parser.add_argument_group(SurfaceEmbeddingModel.__name__) 59 | parser.add_argument('--emb-dim', type=int, default=12) 60 | parser.add_argument('--single-decoder', dest='separate_decoders', action='store_false') 61 | return parent_parser 62 | 63 | def get_auxs(self, objs: Sequence[Obj], crop_res: int): 64 | random_crop_aux = data.std_auxs.RandomRotatedMaskCrop(crop_res) 65 | return ( 66 | data.std_auxs.RgbLoader(), 67 | data.std_auxs.MaskLoader(), 68 | random_crop_aux.definition_aux, 69 | # Some image augmentations probably make most sense in the original image, before rotation / rescaling 70 | # by cropping. 'definition_aux' registers 'AABB_crop' such that the "expensive" image augmentation is only 71 | # performed where the crop is going to be taken from. 
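# The Compose below therefore only touches the 'AABB_crop' region and simulates camera/ISP
# artefacts (blur, sensor noise, debayering, unsharpening) before the crop is cut out.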
72 | data.std_auxs.TransformsAux(key='rgb', crop_key='AABB_crop', tfms=A.Compose([ 73 | A.GaussianBlur(blur_limit=(1, 3)), 74 | A.ISONoise(), 75 | A.GaussNoise(), 76 | data.tfms.DebayerArtefacts(), 77 | data.tfms.Unsharpen(), 78 | A.CLAHE(), # could probably be moved to the post-crop augmentations 79 | A.GaussianBlur(blur_limit=(1, 3)), 80 | ])), 81 | random_crop_aux.apply_aux, 82 | data.pose_auxs.ObjCoordAux(objs, crop_res, replace_mask=True), 83 | data.pose_auxs.SurfaceSampleAux(objs, self.n_neg), 84 | data.pose_auxs.MaskSamplesAux(self.n_pos), 85 | data.std_auxs.TransformsAux(tfms=A.Compose([ 86 | A.CoarseDropout(max_height=16, max_width=16, min_width=8, min_height=8), 87 | A.ColorJitter(hue=0.1), 88 | ])), 89 | data.std_auxs.NormalizeAux(), 90 | data.std_auxs.KeyFilterAux({'rgb_crop', 'obj_coord', 'obj_idx', 'surface_samples', 'mask_samples'}) 91 | ) 92 | 93 | def get_infer_auxs(self, objs: Sequence[Obj], crop_res: int, from_detections=True): 94 | auxs = [data.std_auxs.RgbLoader()] 95 | if not from_detections: 96 | auxs.append(data.std_auxs.MaskLoader()) 97 | auxs.append(data.std_auxs.RandomRotatedMaskCrop( 98 | crop_res, max_angle=0, 99 | offset_scale=0 if from_detections else 1, 100 | use_bbox=from_detections, 101 | rgb_interpolation=(cv2.INTER_LINEAR,), 102 | )) 103 | if not from_detections: 104 | auxs += [ 105 | data.pose_auxs.ObjCoordAux(objs, crop_res, replace_mask=True), 106 | data.pose_auxs.SurfaceSampleAux(objs, self.n_neg), 107 | data.pose_auxs.MaskSamplesAux(self.n_pos), 108 | ] 109 | return auxs 110 | 111 | def configure_optimizers(self): 112 | opt = torch.optim.Adam([ 113 | dict(params=self.cnn.parameters(), lr=1e-4), 114 | dict(params=self.mlps.parameters(), lr=3e-5), 115 | ]) 116 | sched = dict( 117 | scheduler=torch.optim.lr_scheduler.LambdaLR(opt, lambda i: min(1., i / self.warmup_steps)), 118 | interval='step' 119 | ) 120 | return [opt], [sched] 121 | 122 | def step(self, batch, log_prefix): 123 | img = batch['rgb_crop'] # (B, 3, H, W) 124 | coord_img = batch['obj_coord'] # (B, H, W, 4) [-1, 1] 125 | obj_idx = batch['obj_idx'] # (B,) 126 | coords_neg = batch['surface_samples'] # (B, n_neg, 3) [-1, 1] 127 | mask_samples = batch['mask_samples'] # (B, n_pos, 2) 128 | 129 | device = img.device 130 | B, _, H, W = img.shape 131 | assert coords_neg.shape[1] == self.n_neg 132 | mask = coord_img[..., 3] == 1. 
# (B, H, W) 133 | y, x = mask_samples.permute(2, 0, 1) # 2 x (B, n_pos) 134 | 135 | if self.separate_decoders: 136 | cnn_out = self.cnn(img, obj_idx) # (B, 1 + emb_dim, H, W) 137 | mask_lgts = cnn_out[:, 0] # (B, H, W) 138 | queries = cnn_out[:, 1:] # (B, emb_dim, H, W) 139 | else: 140 | cnn_out = self.cnn(img) # (B, n_objs + n_objs * emb_dim, H, W) 141 | mask_lgts = cnn_out[torch.arange(B), obj_idx] # (B, H, W) 142 | queries = cnn_out[:, self.n_objs:].view(B, self.n_objs, self.emb_dim, H, W) 143 | queries = queries[torch.arange(B), obj_idx] # (B, emb_dim, H, W) 144 | 145 | mask_prob = torch.sigmoid(mask_lgts) # (B, H, W) 146 | mask_loss = F.binary_cross_entropy(mask_prob, mask.type_as(mask_prob)) 147 | 148 | queries = queries[torch.arange(B).view(B, 1), :, y, x] # (B, n_pos, emb_dim) 149 | 150 | # compute similarities for positive pairs 151 | coords_pos = coord_img[torch.arange(B).view(B, 1), y, x, :3] # (B, n_pos, 3) [-1, 1] 152 | coords_pos += torch.randn_like(coords_pos) * self.key_noise 153 | keys_pos = torch.stack([self.mlps[i](c) for i, c in zip(obj_idx, coords_pos)]) # (B, n_pos, emb_dim) 154 | sim_pos = (queries * keys_pos).sum(dim=-1, keepdim=True) # (B, n_pos, 1) 155 | 156 | # compute similarities for negative pairs 157 | coords_neg += torch.randn_like(coords_neg) * self.key_noise 158 | keys_neg = torch.stack([self.mlps[i](v) for i, v in zip(obj_idx, coords_neg)]) # (B, n_neg, n_dim) 159 | sim_neg = queries @ keys_neg.permute(0, 2, 1) # (B, n_pos, n_neg) 160 | 161 | # loss 162 | lgts = torch.cat((sim_pos, sim_neg), dim=-1).permute(0, 2, 1) # (B, 1 + n_neg, n_pos) 163 | target = torch.zeros(B, self.n_pos, device=device, dtype=torch.long) 164 | nce_loss = F.cross_entropy(lgts, target) 165 | 166 | loss = mask_loss + nce_loss 167 | self.log(f'{log_prefix}/loss', loss) 168 | self.log(f'{log_prefix}/mask_loss', mask_loss) 169 | self.log(f'{log_prefix}/nce_loss', nce_loss) 170 | return loss 171 | 172 | def training_step(self, batch, _): 173 | return self.step(batch, 'train') 174 | 175 | def validation_step(self, batch, _): 176 | self.log_image_sample(batch) 177 | return self.step(batch, 'valid') 178 | 179 | def get_emb_vis(self, emb_img: torch.Tensor, mask: torch.Tensor = None, demean: torch.tensor = False): 180 | if demean is True: 181 | demean = emb_img[mask].view(-1, self.emb_dim).mean(dim=0) 182 | if demean is not False: 183 | emb_img = emb_img - demean 184 | shape = emb_img.shape[:-1] 185 | emb_img = emb_img.view(*shape, 3, -1).mean(dim=-1) 186 | if mask is not None: 187 | emb_img[~mask] = 0. 
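# Scale for display: divide by the largest absolute value (the epsilon avoids division by zero)
# and map the result from [-1, 1] to [0, 1].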
188 | emb_img /= torch.abs(emb_img).max() + 1e-9 189 | emb_img.mul_(0.5).add_(0.5) 190 | return emb_img 191 | 192 | def log_image_sample(self, batch, i=0): 193 | img = batch['rgb_crop'][i] 194 | obj_idx = batch['obj_idx'][i] 195 | coord_img = batch['obj_coord'][i] 196 | coord_mask = coord_img[..., 3] != 0 197 | 198 | mask_lgts, query_img = self.infer_cnn(img, obj_idx) 199 | query_img = self.get_emb_vis(query_img) 200 | mask_est = torch.tile(torch.sigmoid(mask_lgts)[..., None], (1, 1, 3)) 201 | 202 | key_img = self.infer_mlp(coord_img[..., :3], obj_idx) 203 | key_img = self.get_emb_vis(key_img, mask=coord_mask, demean=True) 204 | 205 | log_img = torch.cat(( 206 | denormalize(img).permute(1, 2, 0), mask_est, query_img, key_img, 207 | ), dim=1).cpu().numpy() 208 | self.trainer.logger.experiment.log(dict( 209 | embeddings=wandb.Image(log_img), 210 | global_step=self.trainer.global_step 211 | )) 212 | 213 | @torch.no_grad() 214 | def infer_cnn(self, img: Union[np.ndarray, torch.Tensor], obj_idx, rotation_ensemble=True): 215 | assert not self.training 216 | if isinstance(img, np.ndarray): 217 | if img.dtype == np.uint8: 218 | img = data.tfms.normalize(img) 219 | img = torch.from_numpy(img).to(self.device) 220 | _, h, w = img.shape 221 | 222 | if rotation_ensemble: 223 | img = utils.rotate_batch(img) # (4, 3, h, h) 224 | else: 225 | img = img[None] # (1, 3, h, w) 226 | cnn_out = self.cnn(img, [obj_idx] * len(img) if self.separate_decoders else None) 227 | if not self.separate_decoders: 228 | channel_idxs = [obj_idx] + list(self.n_objs + obj_idx * self.emb_dim + np.arange(self.emb_dim)) 229 | cnn_out = cnn_out[:, channel_idxs] 230 | # cnn_out: (B, 1+emb_dim, h, w) 231 | if rotation_ensemble: 232 | cnn_out = utils.rotate_batch_back(cnn_out).mean(dim=0) 233 | else: 234 | cnn_out = cnn_out[0] 235 | mask_lgts, query_img = cnn_out[0], cnn_out[1:] 236 | query_img = query_img.permute(1, 2, 0) # (h, w, emb_dim) 237 | return mask_lgts, query_img 238 | 239 | @torch.no_grad() 240 | def infer_mlp(self, pts_norm: Union[np.ndarray, torch.Tensor], obj_idx): 241 | assert not self.training 242 | if isinstance(pts_norm, np.ndarray): 243 | pts_norm = torch.from_numpy(pts_norm).to(self.device).float() 244 | return self.mlps[obj_idx](pts_norm) # (..., emb_dim) 245 | -------------------------------------------------------------------------------- /surfemb/pose_est.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | import torch_scatter 5 | import torch.nn.functional as F 6 | from scipy.spatial.transform import Rotation 7 | 8 | from .utils import timer 9 | 10 | 11 | def estimate_pose(mask_lgts: torch.tensor, query_img: torch.tensor, 12 | obj_pts: torch.tensor, obj_normals: torch.tensor, obj_keys: torch.tensor, obj_diameter: float, 13 | K: np.ndarray, max_poses=10000, max_pose_evaluations=1000, down_sample_scale=3, alpha=1.5, 14 | dist_2d_min=0.1, pnp_method=cv2.SOLVEPNP_AP3P, pose_batch_size=500, max_pool=True, 15 | avg_queries=True, do_prune=True, visualize=False, poses=None, debug=False): 16 | """ 17 | Builds correspondence distribution from queries and keys, 18 | samples correspondences with inversion sampling, 19 | samples poses from correspondences with P3P, 20 | prunes pose hypothesis, 21 | and scores pose hypotheses based on estimated mask and correspondence distribution. 
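Correspondence probabilities are the soft-maxed query-key similarities weighted by the estimated
mask probability. Four correspondences are drawn per hypothesis by inversion sampling from the
flattened (pixel, vertex) distribution (raised to the power alpha) and fed to P3P. Hypotheses are
pruned by the 2D spread of their correspondences, by plausible depth relative to the crop size,
and by a normal-based visibility check, and are then scored by the mean log-probabilities of the
estimated mask and of the correspondences under the projected model, each normalized by its
maximum entropy.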
22 | 23 | :param mask_lgts: (r, r) 24 | :param query_img: (r, r, e) 25 | :param obj_pts: (m, 3) 26 | :param obj_normals: (m, 3) 27 | :param obj_keys: (m, e) 28 | :param alpha: exponent factor for correspondence weighing 29 | :param K: (3, 3) camera intrinsics 30 | :param max_poses: number of poses to sample (before pruning) 31 | :param max_pose_evaluations: maximum number of poses to evaluate / score after pruning 32 | :param dist_2d_min: minimum 2d distance between at least one pair of correspondences for a hypothesis 33 | :param max_pool: max pool probs spatially to make score more robust (but less accurate), 34 | similar to a reprojection error threshold in common PnP RANSAC frameworks 35 | :param poses: evaluate these poses instead of sampling poses 36 | """ 37 | device = mask_lgts.device 38 | r = mask_lgts.shape[0] 39 | m, e = obj_keys.shape 40 | 41 | # down sample 42 | K = K.copy() 43 | K[:2, 2] += 0.5 # change origin to corner 44 | K[:2] /= down_sample_scale 45 | K[:2, 2] -= 0.5 # change origin back 46 | 47 | mask_log_prob, neg_mask_log_prob = [ 48 | F.max_pool2d(F.logsigmoid(lgts)[None], down_sample_scale)[0] 49 | for lgts in (mask_lgts, -mask_lgts) 50 | ] 51 | mask_lgts = F.avg_pool2d(mask_lgts[None], down_sample_scale)[0] 52 | res_sampled = len(mask_lgts) 53 | n = res_sampled ** 2 54 | mask_prob = torch.sigmoid(mask_lgts).view(n) 55 | 56 | yy = torch.arange(res_sampled, device=device) 57 | yy, xx = torch.meshgrid(yy, yy) 58 | yy, xx = (v.reshape(n) for v in (yy, xx)) 59 | img_pts = torch.stack((xx, yy), dim=1) # (n, 2) 60 | 61 | if max_pool: 62 | mask_log_prob = F.max_pool2d(mask_log_prob[None], 3, 1, 1)[0] 63 | neg_mask_log_prob = F.max_pool2d(neg_mask_log_prob[None], 3, 1, 1)[0] 64 | mask_log_prob = mask_log_prob.view(n) 65 | neg_mask_log_prob = neg_mask_log_prob.view(n) 66 | 67 | with timer('corr matrix', debug): 68 | if avg_queries: 69 | queries = F.avg_pool2d(query_img.permute(2, 0, 1), down_sample_scale).view(e, n).T # (n, e) 70 | corr_matrix_log = torch.log_softmax(queries @ obj_keys.T, dim=1) # (n, m) 71 | corr_matrix = corr_matrix_log.exp() 72 | else: 73 | # evaluate the whole corr_matrix followed by max pool for evaluation (corr_matrix_log), batched to avoid oom 74 | # use kernel centers for corr_matrix 75 | query_img = query_img[:res_sampled * down_sample_scale, :res_sampled * down_sample_scale].permute(2, 0, 1) 76 | corr_matrix = torch.empty(res_sampled, res_sampled, m, device=device, dtype=torch.float32) 77 | corr_matrix_log = torch.empty_like(corr_matrix) 78 | patch_out_len = int((res_sampled * 0.5) // down_sample_scale) 79 | patch_len = patch_out_len * down_sample_scale 80 | n_patches_len = int(np.ceil(res_sampled * down_sample_scale / patch_len)) 81 | offset = down_sample_scale // 2 82 | for i in range(n_patches_len): 83 | l, lo = patch_len * i, patch_out_len * i 84 | ro = lo + patch_out_len 85 | for j in range(n_patches_len): 86 | t, to = patch_len * j, patch_out_len * j 87 | bo = to + patch_out_len 88 | 89 | patch = query_img[:, t:t + patch_len, l:l + patch_len] 90 | shape = patch.shape[1:] 91 | patch_corr_log = torch.log_softmax(patch.reshape(e, -1).T @ obj_keys.T, dim=1).view(*shape, m) 92 | corr_matrix[to:bo, lo:ro] = patch_corr_log[offset::down_sample_scale, offset::down_sample_scale] 93 | corr_matrix_log[to:bo, lo:ro] = F.max_pool2d( 94 | patch_corr_log.permute(2, 0, 1), down_sample_scale).permute(1, 2, 0) 95 | corr_matrix_log = corr_matrix_log.view(n, m) 96 | corr_matrix = corr_matrix.view(n, m).exp_() 97 | corr_matrix *= mask_prob[:, None] 98 | 99 | if 
max_pool: 100 | # max pool spatially 101 | # batched over m to avoid oom 102 | corr_matrix_log = corr_matrix_log.view(res_sampled, res_sampled, m).permute(2, 0, 1) # (m, rs, rs) 103 | m_bs = 10000 104 | for i in range(int(np.ceil(m / m_bs))): 105 | l, r = i * m_bs, (i + 1) * m_bs 106 | corr_matrix_log[l:r] = F.max_pool2d(corr_matrix_log[l:r], kernel_size=3, stride=1, padding=1) 107 | corr_matrix_log = corr_matrix_log.permute(1, 2, 0).view(n, m) 108 | 109 | if poses is None: 110 | with timer('sample corr', debug): 111 | corr_matrix = corr_matrix.view(-1) 112 | corr_matrix.pow_(alpha) 113 | corr_matrix_cumsum = torch.cumsum(corr_matrix, dim=0, out=corr_matrix) 114 | corr_matrix_cumsum /= corr_matrix_cumsum[-1].item() 115 | corr_matrix = None # cumsum is overwritten. Reset variable to avoid accidental use 116 | corr_idx = torch.searchsorted(corr_matrix_cumsum, torch.rand(max_poses, 4, device=device)) # (max_poses, 4) 117 | del corr_matrix_cumsum # frees gpu memory 118 | 119 | p2d_idx, p3d_idx = corr_idx.div(m, rounding_mode='floor'), corr_idx % m 120 | p2d, p3d = img_pts[p2d_idx].float(), obj_pts[p3d_idx] # (max_poses, 4, 2 xy), (max_poses, 4, 3 xyz) 121 | n3d = obj_normals[p3d_idx[:, :3].cpu().numpy()] # (max_poses, 3, 3 nx ny nz) 122 | 123 | with timer('to cpu', debug): 124 | p2d, p3d = p2d.cpu().numpy(), p3d.cpu().numpy() 125 | 126 | if visualize: 127 | corr_2d_vis = np.zeros((r, r)) 128 | p2d_xx, p2d_yy = p2d.astype(int).reshape(-1, 2).T 129 | np.add.at(corr_2d_vis, (p2d_yy, p2d_xx), 1) 130 | corr_2d_vis /= corr_2d_vis.max() 131 | cv2.imshow('corr_2d_vis', corr_2d_vis) 132 | 133 | poses = np.zeros((max_poses, 3, 4)) 134 | poses_mask = np.zeros(max_poses, dtype=bool) 135 | with timer('pnp', debug): 136 | rotvecs = np.zeros((max_poses, 3)) 137 | for i in range(max_poses): 138 | ret, rvecs, tvecs = cv2.solveP3P(p3d[i], p2d[i], K, None, flags=pnp_method) 139 | if rvecs: 140 | j = np.random.randint(len(rvecs)) 141 | rotvecs[i] = rvecs[j][:, 0] 142 | poses[i, :3, 3:] = tvecs[j] 143 | poses_mask[i] = True 144 | poses[:, :3, :3] = Rotation.from_rotvec(rotvecs).as_matrix() 145 | poses, p2d, p3d, n3d = [a[poses_mask] for a in (poses, p2d, p3d, n3d)] 146 | 147 | with timer('pose pruning', debug): 148 | # Prune hypotheses where all correspondences come from the same small area in the image 149 | dist_2d = np.linalg.norm(p2d[:, :3, None] - p2d[:, None, :3], axis=-1).max(axis=(1, 2)) # (max_poses,) 150 | dist_2d_mask = dist_2d >= dist_2d_min * res_sampled 151 | 152 | # Prune hypotheses that are very close to or very far from the camera compared to the crop 153 | z = poses[:, 2, 3] 154 | z_min = K[0, 0] * obj_diameter / (res_sampled * 20) 155 | z_max = K[0, 0] * obj_diameter / (res_sampled * 0.5) 156 | size_mask = (z_min < z) & (z < z_max) 157 | 158 | # Prune hypotheses where correspondences are not visible, estimated by the face normal. 
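# A sampled surface point can only be a correct correspondence if it faces the camera,
# i.e. for each of the three P3P points: (R @ n) . (R @ p + t) < 0 in camera space.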
159 | Rt = poses[:, :3, :3].transpose(0, 2, 1) # (max_poses, 3, 3) 160 | n3d_cam = n3d @ Rt # (max_poses, 3 pts, 3 nxnynz) 161 | p3d_cam = p3d[:, :3] @ Rt + poses[:, None, :3, 3] # (max_poses, 3 pts, 3 xyz) 162 | normals_dot = (n3d_cam * p3d_cam).sum(axis=-1) # (max_poses, 3 pts) 163 | normals_mask = np.all(normals_dot < 0, axis=-1) # (max_poses,) 164 | 165 | # allow not pruning for debugging reasons 166 | if do_prune: 167 | poses = poses[dist_2d_mask & size_mask & normals_mask] # (n_poses, 3, 4) 168 | else: 169 | dist_2d, size_mask, normals_mask = None, None, None 170 | 171 | poses = poses[slice(None, max_pose_evaluations)] 172 | n_poses = len(poses) 173 | R = poses[:, :3, :3] # (n_poses, 3, 3) 174 | t = poses[:, :3, 3] # (n_poses, 3) 175 | 176 | if debug: 177 | print('n_poses', n_poses) 178 | 179 | def batch_score(R: torch.tensor, t: torch.tensor, visualize=False): 180 | n_poses = len(R) 181 | # project to image 182 | obj_pts_cam = obj_pts @ R.permute(0, 2, 1) + t[:, None] # (n_poses, m, 3) 183 | z = obj_pts_cam[..., 2] # (n_poses, m) 184 | obj_pts_img = obj_pts_cam @ K.T 185 | u = (obj_pts_img[..., :2] / obj_pts_img[..., 2:]).round_() # (n_poses, m, 2 xy) 186 | # ignore pts outside the image 187 | mask_neg = torch.any(torch.logical_or(u < 0, res_sampled <= u), dim=-1) # (n_poses, m) 188 | # convert 2D-coordinates to flat indexing 189 | u = u[..., 1].mul_(res_sampled).add_(u[..., 0]) # (n_poses, m) 190 | # use an ignore bin to allow batched scatter_min 191 | u[mask_neg] = n # index for the ignore bin 192 | # maybe u should be rounded before casting to long - or converted to long after rounding above 193 | # but a small test shows that there are no rounding errors 194 | u = u.long() 195 | 196 | # per pixel, find the vertex closest to the camera 197 | z, z_arg = torch_scatter.scatter_min(z, u, dim_size=n + 1) # 2x(n_poses, n + 1 ignore bin) 198 | z, z_arg = z[:, :-1], z_arg[:, :-1] # then discard the ignore bin: 2x(n_poses, n) 199 | # get mask of populated pixels 200 | mask = z > 0 # (n_poses, n) 201 | mask_pose_idx, mask_n_idx = torch.where(mask) # 2x (k,) 202 | z, z_arg = z[mask_pose_idx, mask_n_idx], z_arg[mask_pose_idx, mask_n_idx] # 2x (k,) 203 | u = u[mask_pose_idx, z_arg] # (k,) 204 | 205 | mask_score_2d = neg_mask_log_prob[None].expand(n_poses, n).clone() # (n_poses, n) 206 | mask_score_2d[mask_pose_idx, u] = mask_log_prob[u] 207 | mask_score = mask_score_2d.mean(dim=1) # (n_poses,) 208 | 209 | coord_score = corr_matrix_log[u, z_arg] # (k,) 210 | coord_score = torch_scatter.scatter_mean(coord_score, mask_pose_idx, dim_size=n_poses) # (n_poses,) 211 | # handle special case, where no mask pts are in the image 212 | coord_score_mask = torch.ones(n_poses, dtype=torch.bool, device=device) 213 | coord_score_mask[mask_pose_idx] = 0 214 | coord_score[coord_score_mask] = -np.inf 215 | 216 | # normalize by max entropy 217 | mask_score /= np.log(2) 218 | coord_score /= np.log(m) 219 | 220 | score = mask_score + coord_score # (n_poses,) 221 | 222 | if visualize: 223 | assert len(R) == 1 224 | mask_score_img = mask_score_2d[0].view(res_sampled, res_sampled).cpu().numpy() # [mi, 0] 225 | mask_score_img = 1 - mask_score_img / mask_score_img.min() 226 | cv2.imshow('mask_score', mask_score_img) 227 | 228 | coord_score_img = torch.zeros(res_sampled * res_sampled, 3, device=device) 229 | coord_score_img[:, 2] = 1. 
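# Pixels that receive a projected vertex are overwritten below in grayscale (brighter = higher
# correspondence log-probability); the remaining pixels keep the constant background channel.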
230 | coord_scores = corr_matrix_log[u, z_arg] # (k_best) 231 | coord_score_img[u] = (1 - coord_scores / coord_scores.min())[:, None] 232 | cv2.imshow('coord_score', coord_score_img.view(res_sampled, res_sampled, 3).cpu().numpy()) 233 | 234 | return score, mask_score, coord_score 235 | 236 | R, t, K = (torch.from_numpy(v).float().to(device) for v in (R, t, K)) 237 | pose_scores = torch.empty(n_poses, device=device) 238 | mask_scores = torch.empty(n_poses, device=device) 239 | coord_scores = torch.empty(n_poses, device=device) 240 | scores = pose_scores, mask_scores, coord_scores 241 | for batch_idx in range(np.ceil(n_poses / pose_batch_size).astype(int)): 242 | l = pose_batch_size * batch_idx 243 | r = l + pose_batch_size 244 | with timer('batch', debug): 245 | batch_scores = batch_score(R[l:r], t[l:r]) 246 | for container, items in zip(scores, batch_scores): 247 | container[l:r] = items 248 | 249 | if visualize and len(pose_scores) > 0: 250 | best_pose_idx = torch.argmax(pose_scores) 251 | batch_score(R[best_pose_idx:best_pose_idx + 1], t[best_pose_idx:best_pose_idx + 1], visualize=True) 252 | print('pose_score', pose_scores[best_pose_idx].item()) 253 | print('mask_score', mask_scores[best_pose_idx].item()) 254 | print('coord_score', coord_scores[best_pose_idx].item()) 255 | 256 | return R, t, pose_scores, mask_scores, coord_scores, dist_2d, size_mask, normals_mask 257 | --------------------------------------------------------------------------------
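Note: the correspondence sampling step in pose_est.estimate_pose above is plain inversion (CDF)
sampling over the flattened pixel-vertex probability matrix. Below is a minimal standalone sketch
of just that step, with toy shapes and variable names chosen purely for illustration (they are not
part of the repository code):

import torch

n, m, n_hyp = 64 * 64, 512, 100          # query pixels, surface samples, pose hypotheses
probs = torch.rand(n, m)                 # stand-in for the mask-weighted correspondence probabilities

flat = probs.reshape(-1)                 # flatten (pixel, vertex) pairs
cdf = torch.cumsum(flat, dim=0)
cdf /= cdf[-1].item()                    # normalize so the CDF ends at 1
u = torch.rand(n_hyp, 4)                 # 4 correspondences per hypothesis, as in estimate_pose
idx = torch.searchsorted(cdf, u)         # (n_hyp, 4) indices into the flattened matrix
pixel_idx, vertex_idx = idx.div(m, rounding_mode='floor'), idx % m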