├── lib
│   ├── three
│   │   ├── __init__.py
│   │   ├── batchview.py
│   │   ├── core.py
│   │   └── rigid.py
│   ├── network.py
│   ├── rendering.py
│   ├── preprocess.py
│   └── geometry.py
├── assets
│   └── introduction_figure.png
├── Dataspace
│   └── README.md
├── .gitignore
├── checkpoints
│   └── README.md
├── requirements.txt
├── dataset
│   ├── LineMOD_Dataset.py
│   └── TLESS_Dataset.py
├── training
│   ├── config.py
│   ├── shapenet.py
│   ├── preprocess_shapenet.py
│   ├── data_augment.py
│   ├── pyrenderer.py
│   └── train_utils.py
├── evaluation
│   ├── config.py
│   ├── pplane_ICP.py
│   ├── TLESS_MPmask_OVE6D_sixd17.py
│   ├── LMO_RCNN_OVE6D_pipeline.py
│   └── LM_RCNN_OVE6D_pipeline.py
├── README.md
├── example
│   └── misc.py
└── utility
    ├── meshutils.py
    └── visualization.py

/lib/three/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import *
2 | from .rigid import *
3 | from .batchview import *
--------------------------------------------------------------------------------
/assets/introduction_figure.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/dingdingcai/OVE6D-pose/HEAD/assets/introduction_figure.png
--------------------------------------------------------------------------------
/Dataspace/README.md:
--------------------------------------------------------------------------------
1 | # This directory contains the datasets (in [BOP format](https://bop.felk.cvut.cz/datasets)) for evaluation.
2 | 
3 | The datasets should be organized as follows:
4 |     ./lm
5 |     ./lmo
6 |     ./tless
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | evaluation/eval_results
2 | evaluation/bop_pred_results
3 | evaluation/mv_pred_results
4 | evaluation/object_codebooks
5 | evaluation/*_GTmask.py
6 | evaluation/viewpoint_codebook
7 | example/.ipynb_checkpoints
8 | Dataspace2
9 | 
10 | notebook
11 | *__pycache__*
12 | *.pth
13 | training/checkpoints
--------------------------------------------------------------------------------
/checkpoints/README.md:
--------------------------------------------------------------------------------
1 | ## This directory contains the pre-trained weights of the OVE6D framework.
2 | 
3 | - 1. ``OVE6D_pose_model.pth``: pre-trained weights of the OVE6D pose model.
4 | - 2. ``lm_maskrcnn_model.pth``: pre-trained weights of [Mask-RCNN](https://github.com/facebookresearch/detectron2) for LINEMOD object segmentation.
5 | - 3. ``lmO_maskrcnn_model.pth``: pre-trained weights of Mask-RCNN for Occluded LINEMOD object segmentation.
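For reference, these checkpoints are consumed by the evaluation scripts further below; a minimal loading sketch (mirroring ``evaluation/TLESS_MPmask_OVE6D_sixd17.py``, with the checkpoint path taken relative to the repository root):

```python
import torch
from lib import network   # OVE6D network definition (lib/network.py)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the OVE6D model and load the pre-trained pose weights listed above.
model_net = network.OVE6D().to(device)
state_dict = torch.load('checkpoints/OVE6D_pose_model.pth', map_location=device)
model_net.load_state_dict(state_dict, strict=True)
model_net.eval()   # inference only, as in the evaluation scripts
```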
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyrender==0.1.45 2 | PyOpenGL==3.1.0 3 | PyOpenGL-accelerate==3.1.5 4 | scikit-image==0.18.1 5 | trimesh==3.9.9 6 | scipy==1.5.1 7 | Pillow==7.2.0 8 | imageio==2.9.0 9 | numpy==1.19.5 10 | structlog==21.1.0 11 | matplotlib==3.3.4 12 | tqdm==4.59.0 13 | imgaug==0.4.0 14 | opencv-python==4.5.1.48 15 | pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html 16 | PyYAML==5.4.1 17 | tensorboard==2.4.1 18 | Blender==2.80 19 | cudatoolkit==11.1 20 | python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.8/index.html 21 | pip install "git+https://github.com/facebookresearch/pytorch3d.git" 22 | -------------------------------------------------------------------------------- /lib/three/batchview.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | @torch.jit.script 5 | def bvmm(a, b): 6 | if a.shape[0] != b.shape[0]: 7 | raise ValueError("batch dimension must match") 8 | if a.shape[1] != b.shape[1]: 9 | raise ValueError("view dimension must match") 10 | 11 | nbatch, nview, nrow, ncol = a.shape 12 | a = a.view(-1, nrow, ncol) 13 | b = b.view(-1, nrow, ncol) 14 | out = torch.bmm(a, b) 15 | out = out.view(nbatch, nview, out.shape[1], out.shape[2]) 16 | return out 17 | 18 | 19 | def bv2b(x): 20 | if not x.is_contiguous(): 21 | return x.reshape(-1, *x.shape[2:]) 22 | return x.view(-1, *x.shape[2:]) 23 | 24 | 25 | def b2bv(x, num_view=-1, batch_size=-1): 26 | if num_view == -1 and batch_size == -1: 27 | raise ValueError('One of num_view or batch_size must be non-negative.') 28 | return x.view(batch_size, num_view, *x.shape[1:]) 29 | 30 | 31 | def vcat(tensors, batch_size): 32 | tensors = [b2bv(t, batch_size=batch_size) for t in tensors] 33 | return bv2b(torch.cat(tensors, dim=1)) 34 | 35 | 36 | def vsplit(tensor, sections): 37 | num_view = sum(sections) 38 | tensor = b2bv(tensor, num_view=num_view) 39 | return tuple(bv2b(t) for t in torch.split(tensor, sections, dim=1)) 40 | -------------------------------------------------------------------------------- /dataset/LineMOD_Dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from pathlib import Path 4 | 5 | class Dataset(): 6 | def __init__(self, data_dir): 7 | self.model_dir = Path(data_dir) / 'models_eval' 8 | self.cam_file = Path(data_dir) / 'camera.json' 9 | 10 | with open(self.cam_file, 'r') as cam_f: 11 | self.cam_info = json.load(cam_f) 12 | 13 | self.cam_K = torch.tensor([ 14 | [self.cam_info['fx'], 0, self.cam_info['cx']], 15 | [0.0, self.cam_info['fy'], self.cam_info['cy']], 16 | [0.0, 0.0, 1.0] 17 | ], dtype=torch.float32) 18 | 19 | self.cam_height = self.cam_info['height'] 20 | self.cam_width = self.cam_info['width'] 21 | 22 | self.model_info_file = self.model_dir / 'models_info.json' 23 | with open(self.model_info_file, 'r') as model_f: 24 | self.model_info = json.load(model_f) 25 | 26 | self.obj_model_file = dict() 27 | self.obj_diameter = dict() 28 | 29 | for model_file in sorted(self.model_dir.iterdir()): 30 | if str(model_file).endswith('.ply'): 31 | obj_id = int(model_file.name.split('_')[-1].split('.')[0]) 32 | self.obj_model_file[obj_id] = model_file 33 | self.obj_diameter[obj_id] = self.model_info[str(obj_id)]['diameter'] 34 
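A small usage sketch of the wrapper above (assuming the BOP-format LineMOD data has been placed under ``Dataspace/lm`` as described in ``Dataspace/README.md``):

```python
from pathlib import Path
from dataset import LineMOD_Dataset

# Point the wrapper at the BOP-format LineMOD root (must contain camera.json and models_eval/).
lm_dataset = LineMOD_Dataset.Dataset(Path('Dataspace') / 'lm')

print(lm_dataset.cam_K)                        # 3x3 intrinsics built from camera.json
print(lm_dataset.cam_height, lm_dataset.cam_width)

# Per-object mesh paths and diameters parsed from models_eval/models_info.json.
for obj_id, ply_file in sorted(lm_dataset.obj_model_file.items()):
    print(obj_id, ply_file.name, lm_dataset.obj_diameter[obj_id])
```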
| -------------------------------------------------------------------------------- /training/config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | BASE_LR = 1e-3 # starting learning rate 4 | MAX_EPOCHS = 50 # maximum training epochs 5 | NUM_VIEWS = 16 # the sampling number of viewpoint for each object 6 | WARMUP_EPOCHS = 0 # warmup epochs during training 7 | RANKING_MARGIN = 0.1 # the triplet margin for ranking 8 | USE_DATA_AUG = True # whether apply data augmentation during training process 9 | HIST_BIN_NUMS = 100 # the number of histogram bins 10 | MIN_DEPTH_PIXELS = 200 # the minimum number of valid depth values for a valid training depth image 11 | VISIB_FRAC = 0.1 # the minimum visible surface ratio 12 | 13 | RENDER_WIDTH = 720 # the width of rendered images 14 | RENDER_HEIGHT = 540 # the height of rendered images 15 | MIN_HIST_STAT = 50 # the histogram threshold for filtering out ambiguous inter-viewpoint training pairs 16 | RENDER_DIST = 5 # the radius distance factor of uniform sampling relative to object diameter. 17 | ZOOM_MODE = 'bilinear' # the target zooming mode (bilinear or nearest) 18 | ZOOM_SIZE = 128 # the target zooming size 19 | ZOOM_DIST_FACTOR = 0.01 # the distance factor of zooming (relative to object diameter) 20 | 21 | 22 | INTRINSIC = torch.tensor([ 23 | [1.0757e+03, 0.0000e+00, 3.6607e+02], 24 | [0.0000e+00, 1.0739e+03, 2.8972e+02], 25 | [0.0000e+00, 0.0000e+00, 1.0000e+00] 26 | ], dtype=torch.float32) 27 | 28 | 29 | # RENDER_WIDTH = 640 # the width of rendered images 30 | # RENDER_HEIGHT = 480 # the height of rendered images 31 | # MIN_HIST_STAT = 30 # the histogram threshold for filtering out ambiguous inter-viewpoint training pairs 32 | # RENDER_DIST = 5 # the radius distance factor of uniform sampling relative to object diameter. 
33 | # ZOOM_MODE = 'bilinear' # the target zooming mode (bilinear or nearest) 34 | # ZOOM_SIZE = 128 # the target zooming size 35 | # ZOOM_DIST_FACTOR = 8 # the distance factor of zooming (relative to object diameter) 36 | 37 | # INTRINSIC = torch.tensor([ 38 | # [615.1436, 0.000000, 315.3623], 39 | # [0.0000, 615.4991, 251.5415], 40 | # [0.0000, 0.000000, 1.000000], 41 | # ], dtype=torch.float32) 42 | 43 | 44 | -------------------------------------------------------------------------------- /dataset/TLESS_Dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from pathlib import Path 4 | 5 | 6 | class Dataset(): 7 | def __init__(self, data_dir, type='recon'): 8 | """ 9 | type[cad, recon]: using cad model or reconstructed model 10 | """ 11 | super().__init__() 12 | assert(type == 'cad' or type == 'recon'), "only support CAD model (cad) or reconstructed model (recon)" 13 | self.cam_file = Path(data_dir) / 'camera_primesense.json' 14 | with open(self.cam_file, 'r') as cam_f: 15 | self.cam_info = json.load(cam_f) 16 | 17 | # The below is the ground truth camera information of this dataset, which is supposed to be utilized to generate the codebook 18 | # self.cam_K = torch.tensor([ 19 | # [self.cam_info['fx'], 0, self.cam_info['cx']], 20 | # [0.0, self.cam_info['fy'], self.cam_info['cy']], 21 | # [0.0, 0.0, 1.0] 22 | # ], dtype=torch.float32) 23 | # self.cam_height = self.cam_info['height'] 24 | # self.cam_width = self.cam_info['width'] 25 | 26 | # But we use by chance the below information (of test_primesense/01/rgb/190.png) to generate object codebooks in our paper 27 | self.cam_K = torch.tensor([ 28 | [1.0757e+03, 0.0000e+00, 3.6607e+02], 29 | [0.0000e+00, 1.0739e+03, 2.8972e+02], 30 | [0.0000e+00, 0.0000e+00, 1.0000e+00], 31 | ], dtype=torch.float32) 32 | self.cam_height = 540 33 | self.cam_width = 720 34 | 35 | 36 | if type == "recon": 37 | self.model_dir = Path(data_dir) / 'models_reconst' 38 | else: 39 | self.model_dir = Path(data_dir) / 'models_cad' 40 | 41 | 42 | self.model_info_file = self.model_dir / 'models_info.json' 43 | 44 | # self.cam_height = 540 45 | # self.cam_width = 720 46 | # self.depth_scale = 0.1 47 | 48 | with open(self.model_info_file, 'r') as model_f: 49 | self.model_info = json.load(model_f) 50 | 51 | self.obj_model_file = dict() 52 | self.obj_diameter = dict() 53 | 54 | for model_file in sorted(self.model_dir.iterdir()): 55 | if str(model_file).endswith('.ply'): 56 | obj_id = int(model_file.name.split('_')[-1].split('.')[0]) 57 | self.obj_model_file[obj_id] = model_file 58 | self.obj_diameter[obj_id] = self.model_info[str(obj_id)]['diameter'] 59 | 60 | -------------------------------------------------------------------------------- /training/shapenet.py: -------------------------------------------------------------------------------- 1 | # import os 2 | # import sys 3 | # sys.path.append(os.path.abspath('.')) 4 | 5 | import structlog 6 | from pathlib import Path 7 | from training.pyrenderer import PyrenderDataset 8 | 9 | 10 | logger = structlog.get_logger(__name__) 11 | 12 | 13 | 14 | synsets_cat = { 15 | '02691156': 'airplane', '02773838': 'bag', '02808440': 'bathtub', '02818832': 'bed', '02843684': 'birdhouse', 16 | '02871439': 'bookshelf', '02924116': 'bus', '02933112': 'cabinet', '02942699': 'camera', '02958343': 'car', 17 | '03001627': 'chair', '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'display', '03325088': 'faucet', 18 | '03636649': 'lamp', '03642806': 'laptop', 
'03691459': 'loudspeaker', '03710193': 'mailbox', '03761084': 'microwaves', 19 | '03790512': 'motorbike', '03928116': 'piano', '03948459': 'pistol', '04004475': 'printer', '04090263': 'rifle', 20 | '04256520': 'sofa', '04379243': 'table', '04468005': 'train', '04530566': 'watercraft', '04554684': 'washer' 21 | } 22 | 23 | 24 | def get_shape_paths(dataset_dir, whitelist_synsets=None, blacklist_synsets=None): 25 | """ 26 | Returns shape paths for ShapeNet. 27 | 28 | Args: 29 | dataset_dir: the directory containing the dataset 30 | blacklist_synsets: a list of synsets to exclude 31 | 32 | Returns: 33 | 34 | """ 35 | shape_index_path = (dataset_dir / 'paths.txt') 36 | if shape_index_path.exists(): 37 | with open(shape_index_path, 'r') as f: 38 | paths = [Path(dataset_dir, p.strip()) for p in f.readlines()] 39 | else: 40 | paths = list(dataset_dir.glob('**/*.obj')) 41 | 42 | logger.info("total models", num_shape=len(paths)) 43 | 44 | if whitelist_synsets is not None: 45 | num_filtered = sum(1 for p in paths if p.parent.parent.parent.name in whitelist_synsets) 46 | paths = [p for p in paths if p.parent.parent.parent.name in whitelist_synsets] 47 | logger.info("selected shapes from whitelist", num_filtered=num_filtered) 48 | 49 | if blacklist_synsets is not None: 50 | num_filtered = sum(1 for p in paths if p.parent.parent.parent.name in blacklist_synsets) 51 | paths = [p for p in paths if p.parent.parent.parent.name not in blacklist_synsets] 52 | logger.info("selected shapes byond blacklist", num_filtered=num_filtered) 53 | 54 | return paths 55 | 56 | 57 | class ShapeNetV2(PyrenderDataset): 58 | def __init__(self, *args, data_dir, 59 | whitelist_synsets=None, 60 | blacklist_synsets=None, 61 | scale_jitter=(0.05, 0.5), 62 | **kwargs): 63 | shape_paths = get_shape_paths(data_dir, 64 | whitelist_synsets=whitelist_synsets, 65 | blacklist_synsets=blacklist_synsets, 66 | ) 67 | 68 | super().__init__(shape_paths, scale_jitter=scale_jitter, *args, **kwargs) 69 | 70 | -------------------------------------------------------------------------------- /evaluation/config.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | from pytorch3d.transforms import euler_angles_to_matrix 5 | 6 | RANDOM_SEED = 2021 # for reproduce the results of evaluation 7 | 8 | VIEWBOOK_BATCHSIZE = 200 # batch size for constructing viewpoint codebook, reduce this if out of GPU memory 9 | RENDER_WIDTH = 640 # the width of rendered images 10 | RENDER_HEIGHT = 480 # the height of rendered images 11 | RENDER_DIST = 5 # the radius distance factor of uniform sampling relative to object diameter. 
12 | RENDER_NUM_VIEWS = 4000 # the number of uniform sampling views from a sphere 13 | MODEL_SCALING = 1.0/1000 # TLESS object model scale from millimeter to meter 14 | 15 | ZOOM_SIZE = 128 # the target zooming size 16 | ZOOM_MODE = 'bilinear' # the target zooming mode (bilinear or nearest) 17 | ZOOM_DIST_FACTOR = 0.01 # the distance factor of zooming (relative to object diameter) 18 | DATASET_NAME = '' 19 | SAVE_FTMAP = True # store the latent feature maps of viewpoint (for rotation regression) 20 | 21 | HEMI_ONLY = True 22 | USE_ICP = True 23 | ICP_neighbors = 10 24 | ICP_min_planarity = 0.2 25 | ICP_max_iterations = 20 # max iterations for ICP 26 | ICP_correspondences = 1000 # the number of points selected for iteration 27 | 28 | VP_NUM_TOPK = 50 # the number of viewpoint retrievals 29 | POSE_NUM_TOPK = 5 # the number of pose hypotheses 30 | 31 | 32 | DATA_PATH = 'Dataspace' 33 | 34 | 35 | def BOP_REF_POSE(ref_R): 36 | unsqueeze = False 37 | if not isinstance(ref_R, torch.Tensor): 38 | ref_R = torch.tensor(ref_R, dtype=torch.float32) 39 | if ref_R.dim() == 2: 40 | ref_R = ref_R.unsqueeze(0) 41 | unsqueeze = True 42 | assert ref_R.dim() == 3 and ref_R.shape[-1] == 3, "rotation R dim must be B x 3 x 3" 43 | CAM_REF_POSE = torch.tensor(( 44 | (-1, 0, 0), 45 | (0, 1, 0), 46 | (0, 0, 1), 47 | ), dtype=torch.float32) 48 | 49 | XR = euler_angles_to_matrix(torch.tensor([180/180*math.pi, 0, 0]), "XYZ") 50 | R = (XR[None, ...] @ ref_R.clone()) 51 | R = CAM_REF_POSE.T[None, ...] @ R @ CAM_REF_POSE[None, ...] 52 | if unsqueeze: 53 | R = R.squeeze(0) 54 | return R 55 | 56 | def POSE_TO_BOP(ref_R): 57 | unsqueeze = False 58 | if not isinstance(ref_R, torch.Tensor): 59 | ref_R = torch.tensor(ref_R, dtype=torch.float32) 60 | if ref_R.dim() == 2: 61 | ref_R = ref_R.unsqueeze(0) 62 | unsqueeze = True 63 | assert ref_R.dim() == 3 and ref_R.shape[-1] == 3, "rotation R dim must be B x 3 x 3" 64 | CAM_REF_POSE = torch.tensor(( 65 | (-1, 0, 0), 66 | (0, 1, 0), 67 | (0, 0, 1), 68 | ), dtype=torch.float32) 69 | XR = euler_angles_to_matrix(torch.tensor([180/180*math.pi, 0, 0]), "XYZ") 70 | R = XR[None, ...] @ ref_R 71 | 72 | R = CAM_REF_POSE[None, ...] @ R @ CAM_REF_POSE.T[None, ...] 73 | if unsqueeze: 74 | R = R.squeeze(0) 75 | return R 76 | 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OVE6D: Object Viewpoint Encoding for Depth-based 6D Object Pose Estimation (CVPR 2022) 2 | - [Paper](https://arxiv.org/abs/2203.01072) 3 | - [Project page](https://dingdingcai.github.io/ove6d-pose/) 4 | - Another good implementation of this project can be found [here](https://github.com/EternalGoldenBraid/PoseEstimation_pipeline) with real demos. 5 | 6 |

7 | ![OVE6D introduction figure](assets/introduction_figure.png)
8 | 

9 | 
10 | ``` Bash
11 | @inproceedings{cai2022ove6d,
12 | title={OVE6D: Object Viewpoint Encoding for Depth-based 6D Object Pose Estimation},
13 | author={Cai, Dingding and Heikkil{\"a}, Janne and Rahtu, Esa},
14 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
15 | pages={6803--6813},
16 | year={2022}
17 | }
18 | ```
19 | 
20 | 
21 | ## Setup
22 | Please start by installing [Miniconda3](https://conda.io/projects/conda/en/latest/user-guide/install/linux.html) with Python 3.8 or above.
23 | 
24 | ## Dependencies
25 | This project requires the evaluation code from [bop_toolkit](https://github.com/thodan/bop_toolkit) and [sixd_toolkit](https://github.com/thodan/sixd_toolkit).
26 | 
27 | 
28 | ## Dataset
29 | Our evaluation is conducted on three datasets, all downloaded from the [BOP website](https://bop.felk.cvut.cz/datasets). All three datasets are stored in the same directory, e.g. ``Dataspace/lm``, ``Dataspace/lmo``, ``Dataspace/tless``.
30 | 
31 | ## Quantitative Evaluation
32 | Evaluation on the LineMOD and Occluded LineMOD datasets with an instance segmentation (Mask-RCNN) network, i.e. the entire pipeline of instance segmentation + pose estimation:
33 | 
34 | ``python LM_RCNN_OVE6D_pipeline.py`` for LineMOD.
35 | 
36 | ``python LMO_RCNN_OVE6D_pipeline.py`` for Occluded LineMOD.
37 | 
38 | Evaluation on the T-LESS dataset with the provided object segmentation masks (downloaded from [Multi-Path Encoder](https://github.com/DLR-RM/AugmentedAutoencoder/tree/multipath)):
39 | 
40 | ``python TLESS_MPmask_OVE6D_sixd17.py`` for T-LESS.
41 | 
42 | ## Training
43 | To train OVE6D, the ShapeNet dataset is required. You need to first pre-process ShapeNet with the provided script ``training/preprocess_shapenet.py``; [Blender](https://www.blender.org/) is required for this step. For more details, refer to [LatentFusion](https://github.com/NVlabs/latentfusion).
44 | 
45 | ## Pre-trained weights for OVE6D
46 | Our pre-trained OVE6D weights can be found [here](https://drive.google.com/drive/folders/16f2xOjQszVY4aC-oVboAD-Z40Aajoc1s?usp=sharing). Please download them and save them to the directory ``checkpoints/``.
47 | 
48 | ## Segmentation Masks
49 | 
50 | 
51 | - 1. For T-LESS, we use the [segmentation masks](https://drive.google.com/file/d/1UiJ6fo-2chlm4snNYzc7I_1MBLIzncWW/view?usp=sharing) provided by [Multi-Path Encoder](https://github.com/DLR-RM/AugmentedAutoencoder/tree/multipath).
52 | 
53 | - 2. For LineMOD and Occluded LineMOD, we fine-tuned a Mask-RCNN initialized with the weights from [Detectron2](https://github.com/facebookresearch/detectron2). The training data can be downloaded from [BOP](https://bop.felk.cvut.cz/datasets).
54 | 
55 | # Acknowledgement
56 | - 1. The code is partially based on [LatentFusion](https://github.com/NVlabs/latentfusion).
57 | - 2. The evaluation code is based on [bop_toolkit](https://github.com/thodan/bop_toolkit) and [sixd_toolkit](https://github.com/thodan/sixd_toolkit).
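For orientation, the sketch below condenses how the evaluation scripts wire these pieces together, mirroring ``evaluation/TLESS_MPmask_OVE6D_sixd17.py`` further down in this dump. It is illustrative rather than directly runnable: ``OVE6D_codebook_generation`` and ``OVE6D_mask_full_pose`` come from ``evaluation/utils.py`` (imported by the scripts but not included here), and the zero tensors below stand in for a real test depth map and segmentation mask.

```python
import torch
from pathlib import Path

from dataset import TLESS_Dataset
from lib import network, rendering
from evaluation import utils              # codebook / pose helpers used by the scripts
from evaluation import config as cfg

device = torch.device('cuda')

# 1) Dataset wrapper: camera intrinsics, object meshes and diameters.
cfg.DATASET_NAME = 'tless'
eval_dataset = TLESS_Dataset.Dataset(Path(cfg.DATA_PATH) / cfg.DATASET_NAME)
cfg.RENDER_WIDTH = eval_dataset.cam_width
cfg.RENDER_HEIGHT = eval_dataset.cam_height

# 2) Pre-trained OVE6D network (see checkpoints/README.md).
model_net = network.OVE6D().to(device)
model_net.load_state_dict(torch.load('checkpoints/OVE6D_pose_model.pth'), strict=True)
model_net.eval()

# 3) Per-object viewpoint codebooks, rendered once and cached on disk
#    (the full script nests this directory by zoom factor and number of views).
object_codebooks = utils.OVE6D_codebook_generation(
    codebook_dir='evaluation/object_codebooks/tless',
    model_func=model_net, dataset=eval_dataset, config=cfg, device=device)

# 4) One detection -> one full 6D pose. Placeholders below; the real scripts feed
#    the masked test depth (in metres) and the Mask-RCNN / Multi-Path masks.
obj_id = 1
obj_mask = torch.zeros(1, cfg.RENDER_HEIGHT, cfg.RENDER_WIDTH)
obj_depth = torch.zeros(1, cfg.RENDER_HEIGHT, cfg.RENDER_WIDTH)
cam_K = eval_dataset.cam_K[None, ...]     # 1x3x3 intrinsics of the test image

obj_renderer = rendering.Renderer(width=cfg.RENDER_WIDTH, height=cfg.RENDER_HEIGHT)
pose_ret = utils.OVE6D_mask_full_pose(
    model_func=model_net, obj_depth=obj_depth, obj_mask=obj_mask,
    obj_codebook=object_codebooks[obj_id], cam_K=cam_K,
    config=cfg, device=device, obj_renderer=obj_renderer)

print(pose_ret['raw_R'], pose_ret['raw_t'], pose_ret['raw_score'])
```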
58 | 59 | 60 | -------------------------------------------------------------------------------- /lib/three/core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | @torch.jit.script 5 | def acos_safe(t, eps: float = 1e-7): 6 | return torch.acos(torch.clamp(t, min=-1.0 + eps, max=1.0 - eps)) 7 | 8 | 9 | @torch.jit.script 10 | def ensure_batch_dim(tensor, num_dims: int): 11 | unsqueezed = False 12 | if len(tensor.shape) == num_dims: 13 | tensor = tensor.unsqueeze(0) 14 | unsqueezed = True 15 | 16 | return tensor, unsqueezed 17 | 18 | 19 | @torch.jit.script 20 | def normalize(vector, dim: int = -1): 21 | """ 22 | Normalizes the vector to a unit vector using the p-norm. 23 | Args: 24 | vector (tensor): the vector to normalize of size (*, 3) 25 | p (int): the norm order to use 26 | 27 | Returns: 28 | (tensor): A unit normalized vector of size (*, 3) 29 | """ 30 | return vector / torch.norm(vector, p=2.0, dim=dim, keepdim=True) 31 | 32 | 33 | @torch.jit.script 34 | def uniform(n: int, min_val: float, max_val: float): 35 | return (max_val - min_val) * torch.rand(n) + min_val 36 | 37 | 38 | def uniform_unit_vector(n): 39 | return normalize(torch.randn(n, 3), dim=1) 40 | 41 | 42 | def inner_product(a, b): 43 | return (a * b).sum(dim=-1) 44 | 45 | 46 | @torch.jit.script 47 | def homogenize(coords): 48 | ones = torch.ones_like(coords[..., 0, None]) 49 | return torch.cat((coords, ones), dim=-1) 50 | 51 | 52 | @torch.jit.script 53 | def dehomogenize(coords): 54 | return coords[..., :coords.size(-1) - 1] / coords[..., -1, None] 55 | 56 | 57 | def transform_coord_grid(grid, transform): 58 | if transform.size(0) != grid.size(0): 59 | raise ValueError('Batch dimensions must match.') 60 | 61 | out_shape = (*grid.shape[:-1], transform.size(1)) 62 | 63 | grid = homogenize(grid) 64 | coords = grid.view(grid.size(0), -1, grid.size(-1)) 65 | coords = transform @ coords.transpose(1, 2) 66 | coords = coords.transpose(1, 2) 67 | return dehomogenize(coords.view(*out_shape)) 68 | 69 | 70 | @torch.jit.script 71 | def transform_coords(coords, transform): 72 | coords, unsqueezed = ensure_batch_dim(coords, 2) 73 | 74 | coords = homogenize(coords) 75 | coords = transform @ coords.transpose(1, 2) 76 | coords = coords.transpose(1, 2) 77 | coords = dehomogenize(coords) 78 | if unsqueezed: 79 | coords = coords.squeeze(0) 80 | 81 | return coords 82 | 83 | 84 | @torch.jit.script 85 | def grid_to_coords(grid): 86 | return grid.view(grid.size(0), -1, grid.size(-1)) 87 | 88 | 89 | def spherical_to_cartesian(theta, phi, r=1.0): 90 | x = r * torch.cos(theta) * torch.sin(phi) 91 | y = r * torch.sin(theta) * torch.sin(phi) 92 | z = r * torch.cos(theta) 93 | return torch.stack((x, y, z), dim=-1) 94 | 95 | 96 | def points_bound(points): 97 | min_dim = torch.min(points, dim=0)[0] 98 | max_dim = torch.max(points, dim=0)[0] 99 | return torch.stack((min_dim, max_dim), dim=1) 100 | 101 | 102 | def points_radius(points): 103 | bounds = points_bound(points) 104 | centroid = bounds.mean(dim=1).unsqueeze(0) 105 | max_radius = torch.norm(points - centroid, dim=1).max() 106 | return max_radius 107 | 108 | 109 | def points_diameter(points): 110 | return 2* points_radius(points) 111 | 112 | 113 | def points_centroid(points): 114 | return points_bound(points).mean(dim=1) 115 | 116 | 117 | def points_bounding_size(points): 118 | bounds = points_bound(points) 119 | return torch.norm(bounds[:, 1] - bounds[:, 0]) 120 | 
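A quick, self-contained illustration of how the helpers above compose (toy values only, unrelated to any dataset):

```python
import torch
from lib.three import core

# A toy point set and a rigid 4x4 transform that only translates.
pts = torch.rand(5, 3)
T = torch.eye(4)
T[:3, 3] = torch.tensor([0.1, 0.0, 0.2])

# homogenize -> apply T -> dehomogenize, returning a (5, 3) tensor again.
moved = core.transform_coords(pts, T)
assert torch.allclose(moved, pts + T[:3, 3], atol=1e-6)

# Bounding-box statistics of the point set, in the same spirit as the object
# diameter computations used elsewhere in the repository.
print(core.points_centroid(pts))   # centre of the axis-aligned bounding box
print(core.points_radius(pts))     # max distance from that centre to any point
print(core.points_diameter(pts))   # 2 * radius
```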
-------------------------------------------------------------------------------- /training/preprocess_shapenet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from pathlib import Path 4 | 5 | import bpy 6 | 7 | import os 8 | 9 | MAX_SIZE = 5e7 10 | 11 | 12 | _package_dir = os.path.dirname(os.path.realpath(__file__)) 13 | 14 | 15 | def main(): 16 | # Drop blender arguments. 17 | argv = sys.argv[5:] 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument(dest='shapenet_dir', type=Path) 20 | parser.add_argument(dest='out_dir', type=Path) 21 | parser.add_argument('--strip-materials', action='store_true') 22 | parser.add_argument('--out-name', type=str, required=True) 23 | args = parser.parse_args(args=argv) 24 | 25 | paths = sorted(args.shapenet_dir.glob('**/model_normalized.obj')) 26 | 27 | for i, path in enumerate(paths): 28 | print(f"*** [{i+1}/{len(paths)}]") 29 | 30 | model_size = os.path.getsize(path) 31 | if model_size > MAX_SIZE: 32 | print("Model too big ({} > {})".format(model_size, MAX_SIZE)) 33 | continue 34 | 35 | synset_id = path.parent.parent.parent.name 36 | model_id = path.parent.parent.name 37 | # if model_id != '831918158307c1eef4757ae525403621': 38 | # continue 39 | print(f"Processing {path!s}") 40 | bpy.ops.wm.read_factory_settings(use_empty=True) 41 | bpy.ops.import_scene.obj(filepath=str(path), 42 | use_edges=True, 43 | use_smooth_groups=True, 44 | use_split_objects=True, 45 | use_split_groups=True, 46 | use_groups_as_vgroups=False, 47 | use_image_search=True) 48 | 49 | if len(bpy.data.objects) > 10: 50 | print("Too many objects. Skipping for now..") 51 | continue 52 | 53 | if args.strip_materials: 54 | print("Deleting materials.") 55 | for material in bpy.data.materials: 56 | # material.user_clear() 57 | bpy.data.materials.remove(material) 58 | 59 | for obj_idx, obj in enumerate(bpy.data.objects): 60 | bpy.context.view_layer.objects.active = obj 61 | bpy.ops.object.mode_set(mode='EDIT') 62 | bpy.ops.mesh.select_all(action='SELECT') 63 | print("Clearing split normals and removing doubles.") 64 | bpy.ops.mesh.customdata_custom_splitnormals_clear() 65 | bpy.ops.mesh.remove_doubles() 66 | bpy.ops.mesh.normals_make_consistent(inside=False) 67 | 68 | print("Unchecking auto_smooth") 69 | obj.data.use_auto_smooth = False 70 | 71 | bpy.ops.object.modifier_add(type='EDGE_SPLIT') 72 | print("Adding edge split modifier.") 73 | mod = obj.modifiers['EdgeSplit'] 74 | mod.split_angle = 20 75 | 76 | bpy.ops.object.mode_set(mode='OBJECT') 77 | 78 | print("Applying smooth shading.") 79 | bpy.ops.object.shade_smooth() 80 | 81 | print("Running smart UV project.") 82 | bpy.ops.uv.smart_project() 83 | 84 | bpy.context.active_object.select_set(state=False) 85 | 86 | out_path = args.out_dir / synset_id / model_id / 'models' / args.out_name 87 | print(out_path) 88 | out_path.parent.mkdir(exist_ok=True, parents=True) 89 | bpy.ops.export_scene.obj(filepath=str(out_path), 90 | group_by_material=True, 91 | keep_vertex_order=True, 92 | use_normals=True, use_uvs=True, 93 | use_materials=True, 94 | check_existing=False) 95 | print("Saved to {}".format(out_path)) 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | 101 | # for headless processing, without display 102 | # blender -b -P preprocess_shapenet.py -- "$SHAPENET_PATH" "$OUT_PATH" --strip-materials --out-name uv_unwrapped.obj 103 | 104 | # with display 105 | # blender -P preprocess_shapenet.py -- "$SHAPENET_PATH" "$OUT_PATH" --strip-materials --out-name 
uv_unwrapped.obj 106 | -------------------------------------------------------------------------------- /lib/three/rigid.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | from lib.three import core 7 | 8 | 9 | def intrinsic_to_3x4(matrix): 10 | matrix, unsqueezed = core.ensure_batch_dim(matrix, num_dims=2) 11 | 12 | zeros = torch.zeros(1, 3, 1, dtype=matrix.dtype).expand(matrix.shape[0], -1, -1).to(matrix.device) 13 | mat = torch.cat((matrix, zeros), dim=-1) 14 | 15 | if unsqueezed: 16 | mat = mat.squeeze(0) 17 | 18 | return mat 19 | 20 | 21 | @torch.jit.script 22 | def RT_to_matrix(R, T): 23 | R, unsqueezed = core.ensure_batch_dim(R, num_dims=2) 24 | if R.shape[-1] == 3: 25 | R = F.pad(R, (0, 1, 0, 1)) # 4 x 4 26 | if R.dim() == 2: 27 | R = R[None, ...] 28 | if T.dim() == 1: 29 | T = T[None, ...] 30 | R[:, :3, 3] = T 31 | R[:, -1, -1] = 1.0 32 | if unsqueezed: 33 | R = R.squeeze(0) 34 | return R 35 | 36 | 37 | @torch.jit.script 38 | def matrix_3x3_to_4x4(matrix): 39 | matrix, unsqueezed = core.ensure_batch_dim(matrix, num_dims=2) 40 | 41 | mat = F.pad(matrix, [0, 1, 0, 1]) 42 | mat[:, -1, -1] = 1.0 43 | 44 | if unsqueezed: 45 | mat = mat.squeeze(0) 46 | 47 | return mat 48 | 49 | 50 | @torch.jit.script 51 | def rotation_to_4x4(matrix): 52 | return matrix_3x3_to_4x4(matrix) 53 | 54 | 55 | @torch.jit.script 56 | def translation_to_4x4(translation): 57 | translation, unsqueezed = core.ensure_batch_dim(translation, num_dims=1) 58 | 59 | eye = torch.eye(4, device=translation.device) 60 | mat = F.pad(translation.unsqueeze(2), [3, 0, 0, 1]) + eye 61 | 62 | if unsqueezed: 63 | mat = mat.squeeze(0) 64 | 65 | return mat 66 | 67 | 68 | def translate_matrix(matrix, offset): 69 | matrix, unsqueezed = core.ensure_batch_dim(matrix, num_dims=2) 70 | 71 | out = inverse_transform(matrix) 72 | out[:, :3, 3] += offset 73 | out = inverse_transform(out) 74 | 75 | if unsqueezed: 76 | out = out.squeeze(0) 77 | 78 | return out 79 | 80 | 81 | def scale_matrix(matrix, scale): 82 | matrix, unsqueezed = core.ensure_batch_dim(matrix, num_dims=2) 83 | 84 | out = inverse_transform(matrix) 85 | out[:, :3, 3] *= scale 86 | out = inverse_transform(out) 87 | 88 | if unsqueezed: 89 | out = out.squeeze(0) 90 | 91 | return out 92 | 93 | 94 | def decompose(matrix): 95 | matrix, unsqueezed = core.ensure_batch_dim(matrix, num_dims=2) 96 | 97 | # Extract rotation matrix. 98 | origin = (torch.tensor([0.0, 0.0, 0.0, 1.0], device=matrix.device, dtype=matrix.dtype) 99 | .unsqueeze(1) 100 | .unsqueeze(0)) 101 | origin = origin.expand(matrix.size(0), -1, -1) 102 | R = torch.cat((matrix[:, :, :3], origin), dim=-1) 103 | 104 | # Extract translation matrix. 
105 | eye = torch.eye(4, 3, device=matrix.device).unsqueeze(0).expand(matrix.size(0), -1, -1) 106 | T = torch.cat((eye, matrix[:, :, 3].unsqueeze(-1)), dim=-1) 107 | 108 | if unsqueezed: 109 | R = R.squeeze(0) 110 | T = T.squeeze(0) 111 | 112 | return R, T 113 | 114 | 115 | def inverse_transform(matrix): 116 | matrix, unsqueezed = core.ensure_batch_dim(matrix, num_dims=2) 117 | 118 | R, T = decompose(matrix) 119 | R_inv = R.transpose(1, 2) 120 | t = T[:, :4, 3].unsqueeze(2) 121 | t_inv = (R_inv @ t)[:, :3].squeeze(2) 122 | 123 | out = torch.zeros_like(matrix) 124 | out[:, :3, :3] = R_inv[:, :3, :3] 125 | out[:, :3, 3] = -t_inv 126 | out[:, 3, 3] = 1 127 | 128 | if unsqueezed: 129 | out = out.squeeze(0) 130 | 131 | return out 132 | 133 | 134 | def extrinsic_to_position(extrinsic): 135 | extrinsic, unsqueezed = core.ensure_batch_dim(extrinsic, num_dims=2) 136 | 137 | rot_mat, trans_mat = decompose(extrinsic) 138 | position = rot_mat.transpose(2, 1) @ trans_mat[:, :, 3, None] 139 | position = core.dehomogenize(position.squeeze(-1)) 140 | 141 | if unsqueezed: 142 | position = position.squeeze(0) 143 | return position 144 | 145 | 146 | @torch.jit.script 147 | def random_translation(n: int, 148 | x_bound: Tuple[float, float], 149 | y_bound: Tuple[float, float], 150 | z_bound: Tuple[float, float]): 151 | trans_x = core.uniform(n, *x_bound) 152 | trans_y = core.uniform(n, *y_bound) 153 | trans_z = core.uniform(n, *z_bound) 154 | translation = torch.stack((trans_x, trans_y, trans_z), dim=-1) 155 | return translation 156 | -------------------------------------------------------------------------------- /example/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy import spatial 4 | 5 | def str2dict(ss): 6 | obj_score = dict() 7 | for obj_str in ss.split(','): 8 | obj_s = obj_str.strip() 9 | if len(obj_s) > 0: 10 | obj_id = obj_s.split(':')[0].strip() 11 | obj_s = obj_s.split(':')[1].strip() 12 | if len(obj_s) > 0: 13 | obj_score[int(obj_id)] = float(obj_s) 14 | return obj_score 15 | 16 | def cal_score(adi_str, add_str): 17 | adi_score = str2dict(adi_str) 18 | add_score = str2dict(add_str) 19 | add_score[10] = adi_score[10] 20 | add_score[11] = adi_score[11] 21 | if 3 in add_score: 22 | add_score.pop(3) 23 | if 7 in add_score: 24 | add_score.pop(7) 25 | return np.mean(list(add_score.values())) 26 | 27 | def printAD(add, adi, name='RAW'): 28 | print("{}: ADD:{:.5f}, ADI:{:.5f}, ADD(-S):{:.5f}".format( 29 | name, 30 | np.sum(list(str2dict(add).values()))/len(str2dict(add)), 31 | np.sum(list(str2dict(adi).values()))/len(str2dict(adi)), 32 | cal_score(adi_str=adi, add_str=add))) 33 | 34 | 35 | 36 | def box_2D_shape(points, pose, K): 37 | canonical_homo_pts = torch.tensor(vert2_to_bbox8(points).T, dtype=torch.float32) 38 | trans_homo = pose @ canonical_homo_pts 39 | homo_K = torch.zeros((3, 4), dtype=torch.float32) 40 | homo_K[:3, :3] = torch.tensor(K, dtype=torch.float32) 41 | bbox_2D = (homo_K @ trans_homo) 42 | bbox_2D = (bbox_2D[:2] / bbox_2D[2]).T.type(torch.int32)#.tolist() 43 | return bbox_2D 44 | 45 | 46 | def vert2_to_bbox8(corner_pts, homo=True): 47 | pts = list() 48 | for i in range(2): 49 | for j in range(2): 50 | for k in range(2): 51 | if homo: 52 | pt = [corner_pts[i, 0], corner_pts[j, 1], corner_pts[k, 2], 1.0] 53 | else: 54 | pt = [corner_pts[i, 0], corner_pts[j, 1], corner_pts[k, 2]] 55 | pts.append(pt) 56 | return np.asarray(pts) 57 | 58 | def bbox_to_shape(bbox_2D): 59 | connect_points = [[0, 2, 
3, 1, 0], [0, 4, 6, 2], [2, 3, 7, 6], [6, 4, 5, 7], [7, 3, 1, 5]] 60 | shape = list() 61 | for plane in connect_points: 62 | for idx in plane: 63 | point = (bbox_2D[idx][0], bbox_2D[idx][1]) 64 | shape.append(point) 65 | return shape 66 | 67 | # def calc_ADDS(gt_pose, pd_pose, obj_model): 68 | 69 | def transform_pts_Rt(pts, R, t): 70 | """Applies a rigid transformation to 3D points. 71 | 72 | :param pts: nx3 ndarray with 3D points. 73 | :param R: 3x3 ndarray with a rotation matrix. 74 | :param t: 3x1 ndarray with a translation vector. 75 | :return: nx3 ndarray with transformed 3D points. 76 | """ 77 | assert (pts.shape[1] == 3) 78 | pts_t = R.dot(pts.T) + t.reshape((3, 1)) 79 | return pts_t.T 80 | 81 | 82 | def add(R_est, t_est, R_gt, t_gt, pts): 83 | """Average Distance of Model Points for objects with no indistinguishable 84 | views - by Hinterstoisser et al. (ACCV'12). 85 | 86 | :param R_est: 3x3 ndarray with the estimated rotation matrix. 87 | :param t_est: 3x1 ndarray with the estimated translation vector. 88 | :param R_gt: 3x3 ndarray with the ground-truth rotation matrix. 89 | :param t_gt: 3x1 ndarray with the ground-truth translation vector. 90 | :param pts: nx3 ndarray with 3D model points. 91 | :return: The calculated error. 92 | """ 93 | pts_est = transform_pts_Rt(pts, R_est, t_est) 94 | pts_gt = transform_pts_Rt(pts, R_gt, t_gt) 95 | e = np.linalg.norm(pts_est - pts_gt, axis=1).mean() 96 | return e 97 | 98 | def adi(R_est, t_est, R_gt, t_gt, pts): 99 | """Average Distance of Model Points for objects with indistinguishable views 100 | - by Hinterstoisser et al. (ACCV'12). 101 | 102 | :param R_est: 3x3 ndarray with the estimated rotation matrix. 103 | :param t_est: 3x1 ndarray with the estimated translation vector. 104 | :param R_gt: 3x3 ndarray with the ground-truth rotation matrix. 105 | :param t_gt: 3x1 ndarray with the ground-truth translation vector. 106 | :param pts: nx3 ndarray with 3D model points. 107 | :return: The calculated error. 108 | """ 109 | pts_est = transform_pts_Rt(pts, R_est, t_est) 110 | pts_gt = transform_pts_Rt(pts, R_gt, t_gt) 111 | 112 | # Calculate distances to the nearest neighbors from vertices in the 113 | # ground-truth pose to vertices in the estimated pose. 
114 | nn_index = spatial.cKDTree(pts_est) 115 | nn_dists, _ = nn_index.query(pts_gt, k=1) 116 | 117 | e = nn_dists.mean() 118 | return e -------------------------------------------------------------------------------- /utility/meshutils.py: -------------------------------------------------------------------------------- 1 | import typing 2 | import trimesh 3 | import numpy as np 4 | 5 | import trimesh.remesh 6 | # from trimesh.visual.material import SimpleMaterial 7 | from scipy import linalg 8 | EPS = 10e-10 9 | 10 | def compute_vertex_normals(vertices, faces): 11 | normals = np.ones_like(vertices) 12 | triangles = vertices[faces] 13 | triangle_normals = np.cross(triangles[:, 1] - triangles[:, 0], 14 | triangles[:, 2] - triangles[:, 0]) 15 | triangle_normals /= (linalg.norm(triangle_normals, axis=1)[:, None] + EPS) 16 | normals[faces[:, 0]] += triangle_normals 17 | normals[faces[:, 1]] += triangle_normals 18 | normals[faces[:, 2]] += triangle_normals 19 | normals /= (linalg.norm(normals, axis=1)[:, None] + 0) 20 | 21 | return normals 22 | 23 | def are_trimesh_normals_corrupt(trimesh): 24 | corrupt_normals = linalg.norm(trimesh.vertex_normals, axis=1) == 0.0 25 | return corrupt_normals.sum() > 0 26 | 27 | def subdivide_mesh(mesh): 28 | attributes = {} 29 | if hasattr(mesh.visual, 'uv'): 30 | attributes = {'uv': mesh.visual.uv} 31 | vertices, faces, attributes = trimesh.remesh.subdivide( 32 | mesh.vertices, mesh.faces, attributes=attributes) 33 | mesh.vertices = vertices 34 | mesh.faces = faces 35 | if 'uv' in attributes: 36 | mesh.visual.uv = attributes['uv'] 37 | 38 | return mesh 39 | 40 | class Object3D(object): 41 | """Represents a graspable object.""" 42 | 43 | def __init__(self, path, load_materials=False): 44 | scene = trimesh.load(str(path)) 45 | if isinstance(scene, trimesh.Trimesh): 46 | scene = trimesh.Scene(scene) 47 | 48 | self.meshes: typing.List[trimesh.Trimesh] = list(scene.dump()) 49 | 50 | self.path = path 51 | self.scale = 1.0 52 | 53 | def to_scene(self): 54 | return trimesh.Scene(self.meshes) 55 | 56 | def are_normals_corrupt(self): 57 | for mesh in self.meshes: 58 | if are_trimesh_normals_corrupt(mesh): 59 | return True 60 | 61 | return False 62 | 63 | def recompute_normals(self): 64 | for mesh in self.meshes: 65 | mesh.vertex_normals = compute_vertex_normals(mesh.vertices, mesh.faces) 66 | 67 | return self 68 | 69 | def rescale(self, scale=1.0): 70 | """Set scale of object mesh. 71 | 72 | :param scale 73 | """ 74 | self.scale = scale 75 | for mesh in self.meshes: 76 | mesh.apply_scale(self.scale) 77 | 78 | return self 79 | 80 | def resize(self, size, ref='diameter'): 81 | """Set longest of all three lengths in Cartesian space. 
82 | 83 | :param size 84 | """ 85 | if ref == 'diameter': 86 | ref_scale = self.bounding_diameter 87 | else: 88 | ref_scale = self.bounding_size 89 | 90 | self.scale = size / ref_scale 91 | for mesh in self.meshes: 92 | mesh.apply_scale(self.scale) 93 | 94 | return self 95 | 96 | @property 97 | def centroid(self): 98 | return self.bounds.mean(axis=0) 99 | 100 | @property 101 | def bounding_size(self): 102 | return max(self.extents) 103 | 104 | @property 105 | def bounding_diameter(self): 106 | centroid = self.bounds.mean(axis=0) 107 | max_radius = linalg.norm(self.vertices - centroid, axis=1).max() 108 | return max_radius * 2 109 | 110 | @property 111 | def bounding_radius(self): 112 | return self.bounding_diameter / 2.0 113 | 114 | @property 115 | def extents(self): 116 | min_dim = np.min(self.vertices, axis=0) 117 | max_dim = np.max(self.vertices, axis=0) 118 | return max_dim - min_dim 119 | 120 | @property 121 | def bounds(self): 122 | min_dim = np.min(self.vertices, axis=0) 123 | max_dim = np.max(self.vertices, axis=0) 124 | return np.stack((min_dim, max_dim), axis=0) 125 | 126 | def recenter(self, method='bounds'): 127 | if method == 'mean': 128 | # Center the mesh. 129 | vertex_mean = np.mean(self.vertices, 0) 130 | translation = -vertex_mean 131 | elif method == 'bounds': 132 | center = self.bounds.mean(axis=0) 133 | translation = -center 134 | else: 135 | raise ValueError(f"Unknown method {method!r}") 136 | 137 | for mesh in self.meshes: 138 | mesh.apply_translation(translation) 139 | 140 | return self 141 | 142 | @property 143 | def vertices(self): 144 | return np.concatenate([mesh.vertices for mesh in self.meshes]) 145 | -------------------------------------------------------------------------------- /training/data_augment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | import imgaug.augmenters as iaa 5 | import torchvision.transforms.functional as tf 6 | 7 | from lib import geometry 8 | 9 | 10 | def divergence_depth(anchor_depth, query_depth, min_dep_pixels=100, bins_num=100): 11 | hist_diff = 0 12 | anc_val_idx = anchor_depth>0 13 | que_val_idx = query_depth>0 14 | if anc_val_idx.sum() > min_dep_pixels and que_val_idx.sum() > min_dep_pixels: 15 | anc_vals = anchor_depth[anc_val_idx] 16 | que_vals = query_depth[que_val_idx] 17 | min_val = torch.minimum(anc_vals.min(), que_vals.min()) 18 | max_val = torch.maximum(anc_vals.max(), que_vals.max()) 19 | anc_hist = torch.histc(anc_vals, bins=bins_num, min=min_val, max=max_val) 20 | que_hist = torch.histc(que_vals, bins=bins_num, min=min_val, max=max_val) 21 | hist_diff = (que_hist - anc_hist).abs().mean() 22 | return hist_diff 23 | 24 | 25 | def batch_data_morph(depths, min_dep_pixels=None, hole_size=5, edge_size=5): 26 | new_depths = list() 27 | unsqueeze = False 28 | use_filter = False 29 | if min_dep_pixels is not None and isinstance(min_dep_pixels, int): 30 | use_filter = True 31 | 32 | if depths.dim() == 2: 33 | depths = depths[None, ...] 
34 | if depths.dim() > 3: 35 | depths = depths.view(-1, depths.shape[-2], depths.shape[-1]) 36 | unsqueeze = True 37 | 38 | valid_idxes = torch.zeros(len(depths), dtype=torch.uint8) 39 | for ix, dep in enumerate(depths): 40 | dep = torch.tensor( 41 | cv2.morphologyEx( 42 | cv2.morphologyEx(dep.detach().cpu().numpy(), 43 | cv2.MORPH_CLOSE, 44 | np.ones((hole_size, hole_size), np.uint8) 45 | ), 46 | cv2.MORPH_OPEN, np.ones((edge_size, edge_size), np.uint8) 47 | ) 48 | ) 49 | new_depths.append(dep) 50 | if use_filter and (dep>0).sum() > min_dep_pixels: 51 | valid_idxes[ix] = 1 52 | new_depths = torch.stack(new_depths, dim=0).to(depths.device) 53 | if unsqueeze: 54 | new_depths = new_depths.unsqueeze(1) 55 | if use_filter: 56 | return new_depths, valid_idxes 57 | return new_depths 58 | 59 | 60 | def random_block_patches(tensor, max_area_cov=0.2, max_patch_nb=5): 61 | assert tensor.dim() == 4, "input must be BxCxHxW {}".format(tensor.shape) 62 | def square_patch(tensor, max_coverage=0.05): 63 | data_tensor = tensor.clone() 64 | batchsize, channel, height, width = data_tensor.shape 65 | coverage = torch.rand(len(data_tensor)) * (max_coverage - 0.01) + 0.01 66 | patches_size = (coverage.sqrt() * np.minimum(height, width)).type(torch.int64) 67 | square_mask = torch.zeros_like(data_tensor, dtype=torch.float32) 68 | x_offset = ((width - patches_size) * torch.rand(len(patches_size))).type(torch.int64) 69 | y_offset = ((height - patches_size) * torch.rand(len(patches_size))).type(torch.int64) 70 | for ix in range(batchsize): 71 | square_mask[ix, :, :patches_size[ix], :patches_size[ix]] = 1 72 | t_mask = tf.affine(img=square_mask[ix], angle=0, translate=(x_offset[ix], y_offset[ix]), scale=1.0, shear=0) 73 | data_tensor[ix] *= (1 - t_mask.to(data_tensor.device)) 74 | return data_tensor 75 | def circle_patch(tensor, max_coverage=0.05): 76 | data_tensor = tensor.clone() 77 | batchsize, channel, height, width = data_tensor.shape 78 | coverage = torch.rand(len(data_tensor)) * (max_coverage - 0.01) + 0.01 79 | patches_size = (coverage.sqrt() * np.minimum(height, width)).type(torch.int64) 80 | circle_mask = torch.zeros_like(data_tensor, dtype=torch.float32) 81 | radius = (patches_size / 2.0 - 0.5)[..., None, None, None] 82 | grid_map = torch.stack( 83 | torch.meshgrid(torch.linspace(0, height, height+1)[:-1], 84 | torch.linspace(0, width, width+1)[:-1]), dim=0 85 | ).expand(batchsize, -1, -1, -1) 86 | distance = ((grid_map[:, 0:1, :, :] - radius)**2 + (grid_map[:, 1:2, :, :] - radius)**2).sqrt() 87 | circle_mask[distance=0 111 | scaler = list(np.random.random(len(data))*(scale_jitter[1] - scale_jitter[0]) + scale_jitter[0]) 112 | 113 | aug = iaa.KeepSizeByResize( 114 | [ 115 | iaa.Resize(scaler), 116 | iaa.AdditiveLaplaceNoise(loc=0, scale=(0, 0.01), per_channel=True), 117 | # iaa.CoarseDropout(p=(0.01, 0.05), 118 | # size_percent=(0.1, 0.2), 119 | # ), 120 | iaa.Cutout(nb_iterations=(1, 5), 121 | position='normal', 122 | size=(0.01, 0.1), 123 | cval=0.0, 124 | fill_mode='constant', 125 | squared=0.1), 126 | iaa.GaussianBlur(sigma=(0.0, 1.5),), 127 | # iaa.AverageBlur(k=(2, 5)), 128 | ], 129 | interpolation=["nearest", "linear"], 130 | ) 131 | aug_depths = aug(images=data.detach().cpu().permute(0, 2, 3, 1).numpy()) 132 | aug_depths = torch.tensor(aug_depths).permute(0, 3, 1, 2).to(data.device) # B x C x H x W 133 | aug_depths[data<=0] = 0 134 | 135 | if nb_patch > 0: 136 | aug_depths = random_block_patches(aug_depths.clone().to(data.device), max_area_cov=area_patch, max_patch_nb=nb_patch) 137 | return 
aug_depths 138 | 139 | 140 | def zoom_and_crop(images, extrinsic, obj_diameter, cam_config, normalize=True, nan_check=False): 141 | device = images.device 142 | extrinsic = extrinsic.to(device) 143 | obj_diameter = obj_diameter.to(device) 144 | 145 | target_zoom_dist = cam_config.ZOOM_DIST_FACTOR * obj_diameter 146 | 147 | height, width = images.shape[-2:] 148 | cameras = geometry.Camera(intrinsic=cam_config.INTRINSIC.to(device), extrinsic=extrinsic.to(device), width=width, height=height) 149 | images_mask = torch.zeros_like(images) 150 | images_mask[images>0] = 1.0 151 | 152 | # substract mean depth value 153 | obj_dist = extrinsic[:, 2, 3] 154 | images -= images_mask * obj_dist[..., None, None, None].to(device) # substract the mean value 155 | 156 | # add noise based on object diameter 157 | random_noise = obj_diameter * (torch.rand_like(obj_diameter) - 0.5) # add noise to the depth image 158 | images += images_mask * random_noise[..., None, None, None] 159 | 160 | zoom_images, _ = cameras.zoom(images, target_size=cam_config.ZOOM_CROP_SIZE, target_dist=target_zoom_dist, scale_mode=cam_config.ZOOM_MODE) 161 | 162 | if nan_check: 163 | nan_cnt = torch.isnan(zoom_images.view(len(zoom_images), -1)).sum(dim=1) # calculate the amount of images containing NaN values 164 | val_idx = nan_cnt < 1 # return batch indexes of non-NaN images 165 | return zoom_images, val_idx 166 | return zoom_images 167 | -------------------------------------------------------------------------------- /evaluation/pplane_ICP.py: -------------------------------------------------------------------------------- 1 | """ 2 | The code for point-to-plane ICP is modified from the respository https://github.com/pglira/simpleICP/tree/master/python 3 | """ 4 | import time 5 | import torch 6 | import numpy as np 7 | from datetime import datetime 8 | from scipy import spatial, stats 9 | 10 | def depth_to_pointcloud(depth, K): 11 | if not isinstance(depth, torch.Tensor): 12 | depth = torch.tensor(depth, dtype=torch.float32) 13 | K = K.squeeze().to(depth.device) 14 | depth = depth.squeeze() 15 | 16 | vs, us = depth.nonzero(as_tuple=True) 17 | zs = depth[vs, us] 18 | xs = (us - K[0, 2]) * zs / K[0, 0] 19 | ys = (vs - K[1, 2]) * zs / K[1, 1] 20 | pts = torch.stack([xs, ys, zs], dim=1) 21 | return pts 22 | 23 | 24 | def torch_batch_cov(X): 25 | """ 26 | calculate covariance 27 | """ 28 | mean = torch.mean(X, dim=-1).unsqueeze(-1) 29 | X = X - mean 30 | cov = X @ X.transpose(-1, -2) / (X.shape[-1] - 1) 31 | return cov 32 | 33 | 34 | class PointCloud: 35 | def __init__(self, pts): 36 | self.xyz_pts = pts 37 | self.normals = None 38 | self.planarity = None 39 | self.no_points = len(pts) 40 | self.sel = None 41 | self.device=pts.device 42 | self.dtype = pts.dtype 43 | 44 | def select_n_points(self, n): 45 | if self.no_points > n: 46 | self.sel = torch.linspace(0, self.no_points-1, n).round().type(torch.int64).to(self.device) 47 | else: 48 | self.sel = torch.arange(self.no_points).to(self.device) 49 | 50 | def estimate_normals(self, neighbors): 51 | self.normals = torch.full((self.no_points, 3), float('nan'), dtype=self.dtype, device=self.device) 52 | self.planarity = torch.full((self.no_points, ), float('nan'), dtype=self.dtype, device=self.device) 53 | 54 | knn_dists = -(self.xyz_pts[self.sel].unsqueeze(1) - self.xyz_pts.unsqueeze(0)).norm(dim=2, p=2) # QxN 55 | _, idxNN_all_qp = torch.topk(knn_dists, k=neighbors, dim=1) 56 | 57 | selected_points = self.xyz_pts[idxNN_all_qp] 58 | batch_C = torch_batch_cov(selected_points.transpose(-2, -1)) 
59 | 60 | eig_vals, eig_vecs = np.linalg.eig(batch_C.detach().cpu().numpy()) 61 | eig_vals = torch.tensor(eig_vals).to(self.device) 62 | eig_vecs = torch.tensor(eig_vecs).to(self.device) 63 | 64 | _, idx_sort_vals = eig_vals.topk(k=eig_vals.shape[-1], dim=-1) # descending orders, Qx3 65 | idx_sort_vecs = idx_sort_vals[:, 2:3][..., None].repeat(1, 3, 1) # Qx3x3 66 | new_eig_vals = torch.gather(eig_vals, dim=1, index=idx_sort_vals).squeeze() # sorted eigen values by descending order 67 | new_eig_vecs = torch.gather(eig_vecs, dim=2, index=idx_sort_vecs).squeeze() # the vector whose corresponds to the smallest eigen value 68 | 69 | self.normals[self.sel] = new_eig_vecs 70 | self.planarity[self.sel] = (new_eig_vals[:, 1] - new_eig_vals[:, 2]) / new_eig_vals[:, 0] 71 | 72 | def transform(self, H): 73 | XInH = PointCloud.euler_coord_to_homogeneous_coord(self.xyz_pts) 74 | XOutH = (H @ XInH.T).T 75 | self.xyz_pts = PointCloud.homogeneous_coord_to_euler_coord(XOutH) 76 | 77 | 78 | @staticmethod 79 | def euler_coord_to_homogeneous_coord(XE): 80 | no_points = XE.shape[0] 81 | XH = torch.cat([XE, torch.ones(no_points, 1, device=XE.device)], dim=-1) 82 | return XH 83 | 84 | @staticmethod 85 | def homogeneous_coord_to_euler_coord(XH): 86 | XE = torch.stack([XH[:,0]/XH[:,3], XH[:,1]/XH[:,3], XH[:,2]/XH[:,3]], dim=-1) 87 | 88 | return XE 89 | 90 | def matching(pcfix, pcmov): 91 | knn_dists = -(pcfix.xyz_pts[pcfix.sel].unsqueeze(1) - pcmov.xyz_pts.unsqueeze(0)).norm(dim=2, p=2) # QxN 92 | pcmov.sel = torch.topk(knn_dists, k=1, dim=1)[1].squeeze() 93 | dxdyxdz = pcmov.xyz_pts[pcmov.sel] - pcfix.xyz_pts[pcfix.sel] 94 | nxnynz = pcfix.normals[pcfix.sel] # Qx3 95 | distances = (dxdyxdz * nxnynz).sum(dim=1) 96 | 97 | return distances 98 | 99 | 100 | def reject(pcfix, pcmov, min_planarity, distances): 101 | planarity = pcfix.planarity[pcfix.sel] 102 | med = distances.median() 103 | sigmad = (distances - torch.median(distances)).abs().median() * 1.4826 # normal 104 | 105 | keep_distance = abs(distances-med) <= 3 * sigmad 106 | keep_planarity = planarity > min_planarity 107 | keep = keep_distance & keep_planarity 108 | 109 | pcfix.sel = pcfix.sel[keep] 110 | pcmov.sel = pcmov.sel[keep] 111 | distances = distances[keep] 112 | 113 | return distances 114 | 115 | 116 | def estimate_rigid_body_transformation(pcfix, pcmov): 117 | fix_pts = pcfix.xyz_pts[pcfix.sel] 118 | dst_normals = pcfix.normals[pcfix.sel] 119 | 120 | mov_pts = pcmov.xyz_pts[pcmov.sel] 121 | x_mov = mov_pts[:, 0] 122 | y_mov = mov_pts[:, 1] 123 | z_mov = mov_pts[:, 2] 124 | 125 | nx_fix = dst_normals[:, 0] 126 | ny_fix = dst_normals[:, 1] 127 | nz_fix = dst_normals[:, 2] 128 | 129 | A = torch.stack([-z_mov*ny_fix + y_mov*nz_fix, 130 | z_mov*nx_fix - x_mov*nz_fix, 131 | -y_mov*nx_fix + x_mov*ny_fix, 132 | nx_fix, ny_fix, nz_fix], dim=-1).detach().cpu().numpy() 133 | 134 | b = (dst_normals * (fix_pts - mov_pts)).sum(dim=1).detach().cpu().numpy() # Sx3 -> S 135 | 136 | x, _, _, _ = np.linalg.lstsq(A, b) 137 | 138 | A = torch.tensor(A).to(pcfix.device) 139 | b = torch.tensor(b).to(pcfix.device) 140 | x = torch.tensor(x).to(pcfix.device) 141 | 142 | x = torch.clamp(x, torch.tensor(-0.5, device=pcfix.device), torch.tensor(0.5, device=pcfix.device)) 143 | 144 | residuals = A @ x - b 145 | 146 | R = euler_angles_to_linearized_rotation_matrix(x[0], x[1], x[2]) 147 | t = x[3:6] 148 | H = create_homogeneous_transformation_matrix(R, t) 149 | 150 | return H, residuals 151 | 152 | 153 | def euler_angles_to_linearized_rotation_matrix(alpha1, alpha2, alpha3): 154 | 
dR = torch.tensor([[ 1, -alpha3, alpha2], 155 | [ alpha3, 1, -alpha1], 156 | [-alpha2, alpha1, 1]]).to(alpha1.device) 157 | 158 | return dR 159 | 160 | 161 | def create_homogeneous_transformation_matrix(R, t): 162 | H = torch.tensor([[R[0,0], R[0,1], R[0,2], t[0]], 163 | [R[1,0], R[1,1], R[1,2], t[1]], 164 | [R[2,0], R[2,1], R[2,2], t[2]], 165 | [ 0, 0, 0, 1]]).to(R.device) 166 | 167 | return H 168 | 169 | def check_convergence_criteria(distances_new, distances_old, min_change): 170 | def change(new, old): 171 | return torch.abs((new - old) / old * 100) 172 | 173 | change_of_mean = change(torch.mean(distances_new), torch.mean(distances_old)) 174 | change_of_std = change(torch.std(distances_new), torch.std(distances_old)) 175 | 176 | return True if change_of_mean < min_change and change_of_std < min_change else False 177 | 178 | 179 | def sim_icp(X_fix, X_mov, correspondences=1000, neighbors=10, min_planarity=0.3, min_change=1, max_iterations=100, verbose=False): 180 | if len(X_fix) < neighbors: 181 | return torch.eye(4, dtype=X_fix.dtype).to(X_fix.device) 182 | pcfix = PointCloud(X_fix) 183 | pcmov = PointCloud(X_mov) 184 | 185 | pcfix.select_n_points(correspondences) 186 | sel_orig = pcfix.sel 187 | 188 | pcfix.estimate_normals(neighbors) # 500ms 189 | 190 | H = torch.eye(4, dtype=X_fix.dtype).to(X_fix.device) 191 | residual_distances = [] 192 | 193 | for i in range(0, max_iterations): 194 | initial_distances = matching(pcfix, pcmov) # 146ms 195 | # Todo Change initial_distances without return argument 196 | initial_distances = reject(pcfix, pcmov, min_planarity, initial_distances) # 3.3ms 197 | dH, residuals = estimate_rigid_body_transformation(pcfix, pcmov) 198 | residual_distances.append(residuals) 199 | pcmov.transform(dH) 200 | 201 | H = dH @ H 202 | pcfix.sel = sel_orig 203 | 204 | if i > 0: 205 | if check_convergence_criteria(residual_distances[i], residual_distances[i-1], min_change): 206 | break 207 | return H -------------------------------------------------------------------------------- /evaluation/TLESS_MPmask_OVE6D_sixd17.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | # import glob 5 | import json 6 | import yaml 7 | import time 8 | import torch 9 | import warnings 10 | import numpy as np 11 | from PIL import Image 12 | from pathlib import Path 13 | from os.path import join as pjoin 14 | 15 | warnings.filterwarnings("ignore") 16 | 17 | base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 18 | sys.path.append(base_path) 19 | 20 | 21 | from dataset import TLESS_Dataset 22 | from lib import network, rendering 23 | from evaluation import utils 24 | from evaluation import config as cfg 25 | 26 | # this function is borrowed from https://github.com/thodan/sixd_toolkit/blob/master/pysixd/inout.py 27 | def save_results_sixd17(path, res, run_time=-1): 28 | 29 | txt = 'run_time: ' + str(run_time) + '\n' # The first line contains run time 30 | txt += 'ests:\n' 31 | line_tpl = '- {{score: {:.8f}, ' \ 32 | 'R: [' + ', '.join(['{:.8f}'] * 9) + '], ' \ 33 | 't: [' + ', '.join(['{:.8f}'] * 3) + ']}}\n' 34 | for e in res['ests']: 35 | Rt = e['R'].flatten().tolist() + e['t'].flatten().tolist() 36 | txt += line_tpl.format(e['score'], *Rt) 37 | with open(path, 'w') as f: 38 | f.write(txt) 39 | 40 | gpu_id = 0 41 | # gpu_id = 1 42 | 43 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 44 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 45 | os.environ['EGL_DEVICE_ID'] = str(gpu_id) 46 | DEVICE = 
torch.device('cuda') 47 | 48 | datapath = Path(cfg.DATA_PATH) 49 | 50 | cfg.DATASET_NAME = 'tless' # dataset name 51 | 52 | eval_dataset = TLESS_Dataset.Dataset(datapath / cfg.DATASET_NAME) 53 | cfg.RENDER_WIDTH = eval_dataset.cam_width # the width of rendered images 54 | cfg.RENDER_HEIGHT = eval_dataset.cam_height # the height of rendered imagescd 55 | 56 | 57 | ckpt_file = pjoin(base_path, 58 | 'checkpoints', 59 | "OVE6D_pose_model.pth" 60 | ) 61 | 62 | model_net = network.OVE6D().to(DEVICE) 63 | 64 | model_net.load_state_dict(torch.load(ckpt_file), strict=True) 65 | model_net.eval() 66 | 67 | codebook_saving_dir = pjoin(base_path,'evaluation/object_codebooks', 68 | cfg.DATASET_NAME, 69 | 'zoom_{}'.format(cfg.ZOOM_DIST_FACTOR), 70 | 'views_{}'.format(str(cfg.RENDER_NUM_VIEWS))) 71 | 72 | 73 | object_codebooks = utils.OVE6D_codebook_generation(codebook_dir=codebook_saving_dir, 74 | model_func=model_net, 75 | dataset=eval_dataset, 76 | config=cfg, 77 | device=DEVICE) 78 | 79 | raw_pred_results = list() 80 | icp1_pred_results = list() 81 | icpk_pred_results = list() 82 | raw_pred_runtime = list() 83 | icp1_pred_runtime = list() 84 | icpk_pred_runtime = list() 85 | 86 | test_data_dir = datapath / 'tless' / 'test_primesense' 87 | rcnn_mask_dir = datapath / 'tless' / 'mask_RCNN_50' 88 | 89 | 90 | eval_dir = pjoin(base_path, 'evaluation/pred_results/TLESS') 91 | 92 | raw_file_mode = "raw-sampleN{}-viewpointK{}-poseP{}-mpmask_tless_primesense" 93 | if cfg.USE_ICP: 94 | icp1_file_mode = "icp1-sampleN{}-viewpointK{}-poseP{}-nbr{}-itr{}-pts{}-pla{}-mpmask_tless_primesense" 95 | icpk_file_mode = "icpk-sampleN{}-viewpointK{}-poseP{}-nbr{}-itr{}-pts{}-pla{}-mpmask_tless_primesense" 96 | 97 | obj_renderer = rendering.Renderer(width=cfg.RENDER_WIDTH, height=cfg.RENDER_HEIGHT) 98 | 99 | for scene_id in sorted(os.listdir(test_data_dir)): 100 | raw_eval_dir = pjoin(eval_dir, raw_file_mode.format( 101 | cfg.RENDER_NUM_VIEWS, cfg.VP_NUM_TOPK, cfg.POSE_NUM_TOPK)) 102 | scene_raw_eval_dir = pjoin(raw_eval_dir, scene_id) 103 | if not os.path.exists(scene_raw_eval_dir): 104 | os.makedirs(scene_raw_eval_dir) 105 | 106 | if cfg.USE_ICP: 107 | icp1_eval_dir = pjoin(eval_dir, icp1_file_mode.format( 108 | cfg.RENDER_NUM_VIEWS, cfg.VP_NUM_TOPK, cfg.POSE_NUM_TOPK, 109 | cfg.ICP_neighbors, cfg.ICP_max_iterations, cfg.ICP_correspondences, cfg.ICP_min_planarity, 110 | )) 111 | icpk_eval_dir = pjoin(eval_dir, icpk_file_mode.format( 112 | cfg.RENDER_NUM_VIEWS, cfg.VP_NUM_TOPK, cfg.POSE_NUM_TOPK, 113 | cfg.ICP_neighbors, cfg.ICP_max_iterations, cfg.ICP_correspondences, cfg.ICP_min_planarity, 114 | )) 115 | scene_icp1_eval_dir = pjoin(icp1_eval_dir, scene_id) 116 | if not os.path.exists(scene_icp1_eval_dir): 117 | os.makedirs(scene_icp1_eval_dir) 118 | scene_icpk_eval_dir = pjoin(icpk_eval_dir, scene_id) 119 | if not os.path.exists(scene_icpk_eval_dir): 120 | os.makedirs(scene_icpk_eval_dir) 121 | 122 | scene_dir = pjoin(test_data_dir, scene_id) 123 | if not os.path.isdir(scene_dir): 124 | continue 125 | 126 | cam_info_file = pjoin(scene_dir, 'scene_camera.json') 127 | with open(cam_info_file, 'r') as cam_f: 128 | scene_camera_info = json.load(cam_f) 129 | 130 | scene_mask_dir = pjoin(rcnn_mask_dir, "{:02d}".format(int(scene_id))) 131 | scene_rcnn_file = pjoin(scene_mask_dir, 'mask_rcnn_predict.yml') 132 | with open(scene_rcnn_file, 'r') as rcnn_f: 133 | scene_detect_info = yaml.load(rcnn_f, Loader=yaml.FullLoader) 134 | 135 | depth_dir = pjoin(scene_dir, 'depth') 136 | view_runtime = list() 137 | for depth_png in 
sorted(os.listdir(depth_dir)): 138 | if not depth_png.endswith('.png'): 139 | continue 140 | view_id = int(depth_png.split('.')[0]) # 000000.png 141 | view_rcnn_ret = scene_detect_info[view_id] # scene detection results 142 | view_cam_info = scene_camera_info[str(view_id)] # scene camera information 143 | 144 | depth_file = pjoin(depth_dir, depth_png) 145 | mask_file = pjoin(scene_mask_dir, 'masks', '{}.npy'.format(view_id)) # 0000001.npy 146 | view_masks = torch.tensor(np.load(mask_file), dtype=torch.float32) 147 | view_depth = torch.from_numpy(np.array(Image.open(depth_file), dtype=np.float32)) 148 | 149 | view_depth *= view_cam_info['depth_scale'] 150 | view_camK = torch.tensor(view_cam_info['cam_K'], dtype=torch.float32).view(3, 3)[None, ...] # 1x3x3 151 | view_timer = time.time() 152 | for obj_rcnn in view_rcnn_ret: # estimate the detected objects 153 | obj_timer = time.time() 154 | chan = obj_rcnn['np_channel_id'] 155 | obj_id = obj_rcnn['obj_id'] 156 | obj_conf = obj_rcnn['score'] 157 | if obj_conf < 0: # only consider the valid detected objects 158 | continue 159 | if len(view_masks.shape) == 2: 160 | obj_mask = view_masks 161 | else: 162 | obj_mask = view_masks[:, :, chan] # 1xHxW 163 | 164 | obj_depth = view_depth * obj_mask 165 | obj_depth = obj_depth * cfg.MODEL_SCALING # from mm to meter 166 | obj_codebook = object_codebooks[obj_id] 167 | obj_depth = obj_depth.unsqueeze(0) 168 | obj_mask = obj_mask.unsqueeze(0) 169 | pose_ret = utils.OVE6D_mask_full_pose(model_func=model_net, 170 | obj_depth=obj_depth, 171 | obj_mask=obj_mask, 172 | obj_codebook=obj_codebook, 173 | cam_K=view_camK, 174 | config=cfg, 175 | device=DEVICE, 176 | obj_renderer=obj_renderer) 177 | 178 | raw_preds = dict() 179 | raw_preds.setdefault('ests',[]).append({'score': pose_ret['raw_score'].squeeze().numpy(), 180 | 'R': cfg.POSE_TO_BOP(pose_ret['raw_R']).numpy().squeeze(), 181 | 't': pose_ret['raw_t'].squeeze().numpy() * 1000.0}) 182 | 183 | raw_ret_path = os.path.join(scene_raw_eval_dir, '%04d_%02d.yml' % (view_id, obj_id)) 184 | save_results_sixd17(raw_ret_path, raw_preds, run_time=pose_ret['raw_time']) 185 | raw_pred_runtime.append(pose_ret['raw_time']) 186 | 187 | if cfg.USE_ICP: 188 | icp1_preds = dict() 189 | icp1_preds.setdefault('ests',[]).append({'score': pose_ret['icp1_score'].squeeze().numpy(), 190 | 'R': cfg.POSE_TO_BOP(pose_ret['icp1_R']).numpy().squeeze(), 191 | 't': pose_ret['icp1_t'].squeeze().numpy() * 1000.0}) 192 | 193 | icp1_ret_path = os.path.join(scene_icp1_eval_dir, '%04d_%02d.yml' % (view_id, obj_id)) 194 | save_results_sixd17(icp1_ret_path, icp1_preds, run_time=pose_ret['icp1_time']) 195 | icp1_pred_runtime.append(pose_ret['icp1_time']) 196 | 197 | icpk_preds = dict() 198 | icpk_preds.setdefault('ests',[]).append({'score': pose_ret['icpk_score'].squeeze().numpy(), 199 | 'R': cfg.POSE_TO_BOP(pose_ret['icpk_R']).numpy().squeeze(), 200 | 't': pose_ret['icpk_t'].squeeze().numpy() * 1000.0}) 201 | 202 | icpk_ret_path = os.path.join(scene_icpk_eval_dir, '%04d_%02d.yml' % (view_id, obj_id)) 203 | save_results_sixd17(icpk_ret_path, icpk_preds, run_time=pose_ret['icpk_time']) 204 | icpk_pred_runtime.append(pose_ret['icpk_time']) 205 | 206 | view_runtime.append(time.time() - view_timer) 207 | if (view_id+1) % 100 == 0: 208 | print('scene:{}, image: {}, image_cost:{:.3f}, raw_t:{:.3f}, icp1_t:{:.3f}, icpk_t:{:.3f}'.format( 209 | int(scene_id), view_id+1, np.mean(view_runtime), 210 | np.mean(raw_pred_runtime), np.mean(icp1_pred_runtime), np.mean(icpk_pred_runtime))) 211 | 212 | 213 | print('{}, 
{}'.format(scene_id, time.strftime('%m_%d-%H:%M:%S', time.localtime()))) 214 | 215 | mean_raw_time = np.mean(raw_pred_runtime) 216 | print('raw_mean_runtime: {:.4f}'.format(mean_raw_time)) 217 | 218 | if cfg.USE_ICP: 219 | mean_icp1_time = np.mean(icp1_pred_runtime) 220 | mean_icpk_time = np.mean(icpk_pred_runtime) 221 | print('icp1_mean_runtime: {:.4f}'.format(mean_icp1_time)) 222 | print('icpk_mean_runtime: {:.4f}'.format(mean_icpk_time)) 223 | 224 | del obj_renderer 225 | 226 | 227 | -------------------------------------------------------------------------------- /training/pyrenderer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import structlog 5 | from torch.utils.data import Dataset 6 | 7 | from lib import rendering 8 | from lib.three import rigid 9 | 10 | from training import data_augment 11 | from training import train_utils 12 | 13 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 14 | logger = structlog.get_logger(__name__) 15 | 16 | 17 | class PyrenderDataset(Dataset): 18 | def __init__(self, shape_paths, config, 19 | x_bound=(-0.04,0.04), 20 | y_bound=(-0.02,0.02), 21 | scale_jitter=(0.5, 1.0), 22 | dist_jitter=(0.5, 1.5), 23 | aug_guassian_std=0.01, 24 | aug_rescale_jitter=(0.2, 0.8), 25 | aug_patch_area_ratio=0.2, 26 | aug_patch_max_num=1 27 | ): 28 | super().__init__() 29 | self.shape_paths = shape_paths 30 | self.width = config.RENDER_WIDTH 31 | self.height = config.RENDER_HEIGHT 32 | self.num_inputs = config.NUM_VIEWS 33 | self.intrinsic = config.INTRINSIC 34 | self.dist_base = config.RENDER_DIST 35 | self.data_augment = config.USE_DATA_AUG 36 | self.hist_bin_num = config.HIST_BIN_NUMS 37 | self.min_hist_filter_threshold = config.MIN_HIST_STAT 38 | self.min_dep_pixel_threshold = config.MIN_DEPTH_PIXELS 39 | 40 | self.x_bound = x_bound 41 | self.y_bound = y_bound 42 | self.scale_jitter = scale_jitter 43 | self.dist_jitter = torch.tensor(dist_jitter) 44 | 45 | self.aug_guassian_std = aug_guassian_std 46 | self.aug_rescale_jitter = aug_rescale_jitter 47 | self.aug_patch_area_ratio = aug_patch_area_ratio 48 | self.aug_patch_max_num = aug_patch_max_num 49 | 50 | self._renderer = None 51 | self._worker_id = None 52 | self._log = None 53 | 54 | def __len__(self): 55 | return len(self.shape_paths) 56 | 57 | def worker_init_fn(self, worker_id): 58 | self._worker_id = worker_id 59 | self._log = logger.bind(worker_id=worker_id) 60 | self._renderer = rendering.Renderer(width=self.width, height=self.height) 61 | # self._log.info('renderer initialized') 62 | 63 | def random_rotation(self, n): 64 | random_R = rendering.random_xyz_rotation(n) 65 | anchor_R = random_R @ rendering.evenly_distributed_rotation(n) 66 | outplane_R = rendering.random_xy_rotation(n) 67 | inplane_R = rendering.random_z_rotation(n) 68 | jitter_R = rendering.random_xy_rotation(n, rang_degree=3) 69 | return anchor_R, inplane_R, outplane_R, jitter_R 70 | 71 | def __getitem__(self, idx): 72 | 73 | anchor_R, inplane_R, outplane_R, jitter_R = self.random_rotation(self.num_inputs) 74 | 75 | scale_jitter = random.uniform(*self.scale_jitter) 76 | 77 | while True: 78 | model_path = random.choice(self.shape_paths) 79 | file_size = model_path.stat().st_size 80 | max_size = 2e7 81 | if file_size > max_size: 82 | # self._log.warning('skipping large model', path=model_path, max_size=max_size, file_size=file_size) 83 | continue 84 | try: 85 | obj, obj_diameter = rendering.load_object(model_path, scale=scale_jitter) 86 | 87 | obj_dist = 
self.dist_base * obj_diameter 88 | z_bound = (obj_dist * min(self.dist_jitter), obj_dist * max(self.dist_jitter)) # camera distance is set to be relative to object diameter 89 | 90 | anchor_T = rigid.random_translation(self.num_inputs, self.x_bound, self.y_bound, z_bound) 91 | inplane_T = rigid.random_translation(self.num_inputs, self.x_bound, self.y_bound, z_bound) 92 | outplane_T = rigid.random_translation(self.num_inputs, self.x_bound, self.y_bound, z_bound) 93 | 94 | context = rendering.SceneContext(obj, self.intrinsic) 95 | break 96 | except ValueError as e: 97 | continue 98 | # self._log.error('exception while loading mesh', exc_info=e) 99 | obj_diameters = obj_diameter.repeat(self.num_inputs) 100 | 101 | anchor_masks = list() 102 | anchor_depths = list() 103 | 104 | inplane_masks = list() 105 | inplane_depths = list() 106 | 107 | outplane_masks = list() 108 | outplane_depths = list() 109 | 110 | jitter_inplane_depths = list() 111 | 112 | valid_rot_idexes = list() # the discrepancy error count between anchor camera and its out-of-plane rotation 113 | 114 | # for R, T in zip(anchor_R, anchor_T): 115 | # context.set_pose(rotation=R, translation=T) 116 | # depth, mask = self._renderer.render(context)[1:] 117 | # anchor_masks.append(mask) 118 | # anchor_depths.append(depth) 119 | 120 | in_Rxyz = inplane_R @ anchor_R # object-space rotation 121 | for R, T in zip(in_Rxyz, inplane_T): 122 | context.set_pose(rotation=R, translation=T) 123 | depth, mask = self._renderer.render(context)[1:] 124 | inplane_masks.append(mask) 125 | inplane_depths.append(depth) 126 | 127 | 128 | jitter_in_Rxyz = jitter_R @ in_Rxyz # jittering the object-space rotation 129 | for R, T in zip(jitter_in_Rxyz, inplane_T): 130 | context.set_pose(rotation=R, translation=T) 131 | depth, mask = self._renderer.render(context)[1:] 132 | jitter_inplane_depths.append(depth) 133 | 134 | 135 | out_Rxy = outplane_R @ anchor_R # object-space rotation 136 | for R, T in zip(out_Rxy, outplane_T): 137 | context.set_pose(rotation=R, translation=T) 138 | depth, mask = self._renderer.render(context)[1:] 139 | outplane_masks.append(mask) 140 | outplane_depths.append(depth) 141 | 142 | 143 | # constant_T = torch.zeros_like(anchor_T) 144 | # constant_T[:, -1] = obj_dist # centerizing object with constant distance 145 | for anc_R, oup_R, inp_R, const_T in zip(anchor_R, out_Rxy, jitter_in_Rxyz, anchor_T): 146 | context.set_pose(rotation=anc_R, translation=const_T) 147 | anc_depth, anc_mask = self._renderer.render(context)[1:] 148 | context.set_pose(rotation=oup_R, translation=const_T) 149 | oup_depth = self._renderer.render(context)[1] 150 | 151 | anchor_masks.append(anc_mask) 152 | anchor_depths.append(anc_depth) 153 | 154 | 155 | # #calculate the viewpoint angles for inplane and outplane relative to anchor 156 | oup_vp_sim = (anc_R[2] * oup_R[2]).sum() # oup_vp_angle = arccos(oup_vp_sim) 157 | inp_vp_sim = (anc_R[2] * inp_R[2]).sum() # inp_vp_angle = arccos(inp_vp_sim) 158 | # #inp_vp_sim > oup_vp_sim is favored, inp_R is supposed to be closer to anc_R compared with oup_R 159 | 160 | # #the out-of-plane depth pairs (anc, out) are supposed to be having different depth distribution 161 | hist_diff = data_augment.divergence_depth(anc_depth, oup_depth, 162 | min_dep_pixels=self.min_dep_pixel_threshold, bins_num=self.hist_bin_num) 163 | if (inp_vp_sim <= oup_vp_sim) or (hist_diff < self.min_hist_filter_threshold): 164 | valid_rot_idexes.append(0) # invalid negative depth pair due to equivalent depth distribution 165 | else: 166 | 
valid_rot_idexes.append(1) 167 | 168 | del context 169 | valid_rot_indexes = torch.tensor(valid_rot_idexes, dtype=torch.uint8) 170 | 171 | anchor_masks = torch.stack(anchor_masks, dim=0).unsqueeze(1) 172 | anchor_depths = torch.stack(anchor_depths, dim=0).unsqueeze(1) 173 | 174 | inplane_masks = torch.stack(inplane_masks, dim=0).unsqueeze(1) 175 | inplane_depths = torch.stack(inplane_depths, dim=0).unsqueeze(1) 176 | 177 | jitter_inplane_depths = torch.stack(jitter_inplane_depths, dim=0).unsqueeze(1) 178 | 179 | outplane_masks = torch.stack(outplane_masks, dim=0).unsqueeze(1) 180 | outplane_depths = torch.stack(outplane_depths, dim=0).unsqueeze(1) 181 | 182 | anchor_extrinsic = rigid.RT_to_matrix(anchor_R, anchor_T) 183 | inplane_extrinsic = rigid.RT_to_matrix(in_Rxyz, inplane_T) 184 | outplane_extrinsic = rigid.RT_to_matrix(out_Rxy, outplane_T) 185 | 186 | outplane_depths_aug = outplane_depths 187 | inplane_depths_aug = jitter_inplane_depths 188 | 189 | valid_anc_idxes = torch.ones_like(valid_rot_indexes) 190 | valid_inp_idxes = torch.ones_like(valid_rot_indexes) 191 | valid_out_idxes = torch.ones_like(valid_rot_indexes) 192 | 193 | if self.data_augment: 194 | if random.random() > 0.5: 195 | inplane_depths_aug = data_augment.custom_aug(inplane_depths_aug, 196 | noise_level=self.aug_guassian_std, 197 | scale_jitter=self.aug_rescale_jitter, 198 | area_patch=self.aug_patch_area_ratio, 199 | nb_patch=self.aug_patch_max_num) 200 | if random.random() > 0.5: 201 | outplane_depths_aug = data_augment.custom_aug(outplane_depths_aug, 202 | noise_level=self.aug_guassian_std, 203 | scale_jitter=self.aug_rescale_jitter, 204 | area_patch=self.aug_patch_area_ratio, 205 | nb_patch=self.aug_patch_max_num) 206 | if random.random() > 0.5: 207 | inplane_depths_aug, valid_inp_idxes = data_augment.batch_data_morph(inplane_depths_aug, 208 | min_dep_pixels=self.min_dep_pixel_threshold, 209 | hole_size=5, 210 | edge_size=5) 211 | if random.random() > 0.5: 212 | outplane_depths_aug, valid_out_idxes = data_augment.batch_data_morph(outplane_depths_aug, 213 | min_dep_pixels=self.min_dep_pixel_threshold, 214 | hole_size=5, 215 | edge_size=5) 216 | 217 | return { 218 | 'anchor': { 219 | 'mask': anchor_masks, 220 | 'depth': anchor_depths, 221 | 'extrinsic': anchor_extrinsic, 222 | 'rotation_to_anchor': torch.eye(3).expand(self.num_inputs, -1, -1), 223 | 'valid_idx': valid_rot_indexes * valid_anc_idxes, 224 | 'obj_diameter': obj_diameters, 225 | }, 226 | 'inplane': { 227 | 'mask': inplane_masks, 228 | 'depth': inplane_depths, 229 | 'aug_depth': train_utils.background_filter(inplane_depths_aug, obj_diameters), 230 | 'extrinsic': inplane_extrinsic, 231 | 'rotation_to_anchor': inplane_R, 232 | 'valid_idx': valid_rot_indexes * valid_inp_idxes, 233 | 'obj_diameter': obj_diameters, 234 | }, 235 | 'outplane': { 236 | 'mask': outplane_masks, 237 | 'depth': outplane_depths, 238 | 'aug_depth': train_utils.background_filter(outplane_depths_aug, obj_diameters), 239 | 'extrinsic': outplane_extrinsic, 240 | 'rotation_to_anchor': outplane_R, 241 | 'valid_idx': valid_rot_indexes * valid_out_idxes, 242 | 'obj_diameter': obj_diameters, 243 | }, 244 | } -------------------------------------------------------------------------------- /lib/network.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import math 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from lib.geometry import inplane_2D_spatial_transform 7 | from lib import preprocess 8 | 9 | 10 | class 
OVE6D(nn.Module): 11 | def __init__(self): 12 | super(OVE6D, self).__init__() 13 | ###################################### backbone ############################################ 14 | self.stem_layer1 = nn.Sequential( 15 | nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1), 16 | nn.BatchNorm2d(16), 17 | nn.ReLU()) 18 | self.stem_layer2 = nn.Sequential( 19 | nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1), 20 | nn.BatchNorm2d(64), 21 | nn.ReLU()) 22 | self.stem_layer3 = nn.Sequential( 23 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 24 | nn.BatchNorm2d(64), 25 | nn.ReLU()) 26 | self.stem_layer4 = nn.Sequential( 27 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1), 28 | nn.BatchNorm2d(128), 29 | nn.ReLU()) 30 | self.stem_layer5 = nn.Sequential( 31 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 32 | nn.BatchNorm2d(128), 33 | nn.ReLU()) 34 | self.stem_layer6 = nn.Sequential( 35 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), 36 | nn.BatchNorm2d(256), 37 | nn.ReLU()) 38 | self.stem_layer7 = nn.Sequential( 39 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), 40 | nn.BatchNorm2d(256), 41 | nn.ReLU()) 42 | self.stem_layer8 = nn.Sequential( 43 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1), 44 | nn.BatchNorm2d(512), 45 | nn.ReLU()) 46 | self.backbone_layers = list() 47 | self.backbone_layers.append(self.stem_layer1) 48 | self.backbone_layers.append(self.stem_layer2) 49 | self.backbone_layers.append(self.stem_layer3) 50 | self.backbone_layers.append(self.stem_layer4) 51 | self.backbone_layers.append(self.stem_layer5) 52 | self.backbone_layers.append(self.stem_layer6) 53 | self.backbone_layers.append(self.stem_layer7) 54 | self.backbone_layers.append(self.stem_layer8) 55 | ###################################### backbone ############################################ 56 | 57 | ################################# viewpoint encoder head ######################################## 58 | self.vp_enc_transition = nn.Sequential( 59 | nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1), 60 | nn.BatchNorm2d(256), 61 | nn.ReLU()) 62 | self.vp_enc_pool = nn.AdaptiveMaxPool2d((1, 1)) 63 | self.vp_enc_fc = nn.Linear(in_features=256, out_features=64) 64 | ################################# viewpoint encoder head ######################################## 65 | 66 | 67 | ################################ in-plane transformation regression ####################################### 68 | self.vp_inp_transition = nn.Sequential( 69 | nn.Conv2d(in_channels=512, out_channels=128, kernel_size=1, stride=1), 70 | nn.BatchNorm2d(128), 71 | nn.ReLU()) 72 | 73 | self.vp_rot_fc1 = nn.Sequential( 74 | nn.Linear(in_features=4096, out_features=128), 75 | nn.ReLU()) 76 | self.vp_rot_fc2 = nn.Linear(in_features=128, out_features=2) 77 | 78 | self.vp_tls_fc1 = nn.Sequential( 79 | nn.Linear(in_features=4096, out_features=128), 80 | nn.ReLU()) 81 | self.vp_tls_fc2 = nn.Linear(in_features=128, out_features=2) 82 | 83 | ################################ in-plane transformation regression ####################################### 84 | 85 | 86 | ############################# orientation confidence ##################################### 87 | self.vp_conf_layer1 = nn.Sequential( 88 | nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, stride=1, padding=0), 89 | 
nn.BatchNorm2d(128), 90 | nn.ReLU()) 91 | self.vp_conf_layer2 = nn.Sequential( 92 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 93 | nn.BatchNorm2d(128), 94 | nn.ReLU()) 95 | self.vp_conf_pool = nn.AdaptiveAvgPool2d((1, 1)) 96 | self.vp_conf_fc = nn.Linear(128, 1) 97 | ############################# orientation confidence ##################################### 98 | 99 | for m in self.modules(): 100 | if isinstance(m, (nn.Conv2d, nn.Linear)): 101 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 102 | 103 | def backbone(self, x): 104 | """ 105 | The backbone network for extracting features 106 | """ 107 | H, W = x.shape[-2:] 108 | x = x.view(-1, 1, H, W) 109 | for layer in self.backbone_layers: 110 | shortcut = x 111 | x = layer(x) 112 | if x.shape == shortcut.shape: 113 | x += shortcut 114 | return x 115 | 116 | def viewpoint_encoder_head(self, x): 117 | """ 118 | encoder head for extracting viewpoint representation 119 | """ 120 | x = self.vp_enc_transition(x) 121 | x = self.vp_enc_pool(x) # B x c x 1 x 1 122 | x = x.view(x.shape[0], -1) # BV x CHW 123 | x = self.vp_enc_fc(x) 124 | x = F.normalize(x, dim=1) 125 | return x 126 | 127 | def vipri_encoder(self, x, return_maps=False): 128 | """ 129 | viewpoint in-plane rotation invariant encoding 130 | """ 131 | ft_map = self.backbone(x) # B x 1 x H x W => B x C x h x w 132 | vp_enc = self.viewpoint_encoder_head(ft_map) 133 | if return_maps: 134 | vp_map = self.vp_inp_transition(ft_map) 135 | return vp_map, vp_enc 136 | return vp_enc 137 | 138 | def regression_head(self, x, y): 139 | """ 140 | regression head for 2D in-plane transformation from x to y 141 | """ 142 | bs, ch = x.shape[:2] 143 | x = F.normalize(x.view(bs, -1), dim=1).view(bs, ch, -1) 144 | y = F.normalize(y.view(bs, -1), dim=1).view(bs, ch, -1) 145 | # z = (x.unsqueeze(3) * y.unsqueeze(2)).sum(dim=1) # BV x C x 64 x 64 146 | z = torch.bmm(x.permute(0, 2, 1), y) # BV x 64 x 64, feature map correlation 147 | z = z.view(bs, -1) # BV x 4096 148 | 149 | Rz = self.vp_rot_fc1(z) # BV x 4096 -> BV x 128 150 | Rz = self.vp_rot_fc2(Rz) # BV x 128 -> BV x 2 151 | Rz = F.normalize(Rz, dim=1) 152 | 153 | TxTy = self.vp_tls_fc1(z) # Bx4096 -> Bx128 154 | TxTy = self.vp_tls_fc2(TxTy) # Bx128 -> Bx2 -> Bx1x2 155 | TxTy = torch.tanh(TxTy) # range[-1.0, 1.0] 156 | 157 | row1 = torch.stack([Rz[:, 0], -Rz[:, 1], TxTy[:, 0]], dim=1) # cos(theta), -sin(theta), BV x 3 158 | row2 = torch.stack([Rz[:, 1], Rz[:, 0], TxTy[:, 1]], dim=1) # sin(theta), cos(theta), BV x 3 159 | 160 | theta = torch.stack([row1, row2], dim=1) # BV x 2 x 3, 2D in-plane transformation matrix 161 | 162 | return theta 163 | 164 | def spatial_transformation(self, x, theta): 165 | """ 166 | transform feature maps with the given transformation matrix 167 | x: BxCxHxW 168 | theta: Bx2x3 169 | """ 170 | stn_theta = theta.clone() # Bx2x3 171 | y = preprocess.spatial_transform_2D(x=x, theta=stn_theta, 172 | mode='bilinear', 173 | padding_mode='border', 174 | align_corners=False) 175 | return y 176 | 177 | def viewpoint_confidence(self, x, y): 178 | """ 179 | calcuate the consistency 180 | """ 181 | z = torch.cat([x, y], dim=1) # Bx2Cx8x8 182 | z = self.vp_conf_layer1(z) 183 | z = self.vp_conf_layer2(z) 184 | z = self.vp_conf_pool(z).view(z.size(0), -1) # BxCx1x1 -> BxC 185 | z = self.vp_conf_fc(z) # Bx256 -> Bx1 186 | z = torch.sigmoid(z) 187 | return z 188 | 189 | def inference(self, anc_map, inp_map): 190 | pd_theta = self.regression_head(x=anc_map, y=inp_map) 191 | stn_inp_map 
= self.spatial_transformation(x=anc_map, theta=pd_theta) # transform the anchor feature map with the predicted in-plane transformation
192 | pd_conf = self.viewpoint_confidence(x=inp_map, y=stn_inp_map)
193 | return pd_theta, pd_conf
194 | 
195 | def forward(self, x_anc_gt, x_oup_gt, x_inp_aug, x_oup_aug, inp_gt_theta):
196 | """
197 | input:
198 | x_anc_gt: rendered clean anchor viewpoint depth: B x V x H x W
199 | x_oup_gt: rendered clean out-of-plane rotated depth (rotated about the xy-axes)
200 | x_inp_aug: augmented in-plane rotated depth (rotated about the z-axis)
201 | x_oup_aug: augmented out-of-plane rotated depth
202 | inp_gt_theta: ground-truth 2D in-plane transformation from the anchor views to the in-plane views (Nx2x3)
203 | return:
204 | predicted in-plane transformation of the intra-viewpoint pair (Nx2x3),
205 | four orientation confidence scores (Nx1 each) and the viewpoint embedding triplets (Nx64 each)
206 | """
207 | # feature extractions and viewpoint embeddings
208 | z_anc_gt_map, z_anc_gt_vec = self.vipri_encoder(x_anc_gt, return_maps=True) # BV x 1 x 128 x 128 --> BV x 512 x 8 x 8
209 | z_oup_gt_map, _ = self.vipri_encoder(x_oup_gt, return_maps=True) # BV x 1 x 128 x 128 --> BV x 512 x 8 x 8
210 | 
211 | z_oup_aug_vec = self.vipri_encoder(x_oup_aug, return_maps=False)
212 | z_inp_aug_map, z_inp_aug_vec = self.vipri_encoder(x_inp_aug, return_maps=True)
213 | 
214 | # the regression branch is only trained with in-plane views
215 | inp_pd_theta = self.regression_head(x=z_anc_gt_map, y=z_inp_aug_map)
216 | oup_pd_theta = self.regression_head(x=z_oup_gt_map, y=z_inp_aug_map)
217 | oup_pd_theta = oup_pd_theta.detach() # detach: no gradient flows back into the regression branch from the out-of-plane prediction
218 | 
219 | # the transformed anchor feature map is supposed to match the GT in-plane feature map
220 | gt_stn_inp_map = self.spatial_transformation(x=z_anc_gt_map, theta=inp_gt_theta) # transform the feature map of anchor view
221 | # the transformation branch is trained with GT transformation only
222 | pd_stn_inp_map = self.spatial_transformation(x=z_anc_gt_map, theta=inp_pd_theta)
223 | pd_stn_oup_map = self.spatial_transformation(x=z_oup_gt_map, theta=oup_pd_theta)
224 | 
225 | z_inp_aug_map = z_inp_aug_map.detach()
226 | gt_stn_inp_map = gt_stn_inp_map.detach()
227 | pd_stn_inp_map = pd_stn_inp_map.detach()
228 | pd_stn_oup_map = pd_stn_oup_map.detach()
229 | 
230 | # the confidence branch is only trained with the predicted feature maps
231 | alpha = 0.2
232 | pd_stn_mix_map = alpha * pd_stn_inp_map + (1 - alpha) * pd_stn_oup_map
233 | pd_mix_cls = self.viewpoint_confidence(z_inp_aug_map, pd_stn_mix_map)
234 | 
235 | gt_inp_cls = self.viewpoint_confidence(x=z_inp_aug_map, y=gt_stn_inp_map)
236 | pd_inp_cls = self.viewpoint_confidence(x=z_inp_aug_map, y=pd_stn_inp_map)
237 | pd_oup_cls = self.viewpoint_confidence(x=z_inp_aug_map, y=pd_stn_oup_map)
238 | 
239 | return (inp_pd_theta,
240 | gt_inp_cls, pd_inp_cls, pd_oup_cls, pd_mix_cls,
241 | z_anc_gt_vec, z_inp_aug_vec, z_oup_aug_vec) # viewpoint embedding triplets
242 | 
243 | 
244 | 
245 | 
246 | 
247 | 
--------------------------------------------------------------------------------
/utility/visualization.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | 
3 | import math
4 | from contextlib import contextmanager
5 | from pathlib import Path
6 | 
7 | import imageio
8 | import numpy as np
9 | import structlog
10 | import tempfile
11 | import torch
12 | import torchvision
13 | from matplotlib import cm
14 | from matplotlib import pyplot as plt
15 | from matplotlib.colors import LinearSegmentedColormap
16 | from torch.nn import functional as F
17 | from 
tqdm.auto import tqdm 18 | 19 | logger = structlog.get_logger(__name__) 20 | 21 | 22 | _colormap_cache = {} 23 | 24 | 25 | def _build_colormap(name, num_bins=256): 26 | base = cm.get_cmap(name) 27 | color_list = base(np.linspace(0, 1, num_bins)) 28 | cmap_name = base.name + str(num_bins) 29 | colormap = LinearSegmentedColormap.from_list(cmap_name, color_list, num_bins) 30 | colormap = torch.tensor(colormap(np.linspace(0, 1, num_bins)), dtype=torch.float32)[:, :3] 31 | return colormap 32 | 33 | 34 | def get_colormap(name): 35 | if name not in _colormap_cache: 36 | _colormap_cache[name] = _build_colormap(name) 37 | return _colormap_cache[name] 38 | 39 | def colorize_tanh_depth(tensor, cmap='magma'): 40 | if len(tensor.shape) > 4: 41 | tensor = tensor.view(-1, *tensor.shape[-3:]) 42 | if len(tensor.shape) == 2: 43 | tensor = tensor.unsqueeze(0) 44 | if len(tensor.shape) == 4: 45 | tensor = tensor.squeeze(1) 46 | tensor = tensor.detach().cpu() # N x H x W 47 | 48 | tensor = torch.tanh(tensor.type(torch.float32)) # [-1, 1] 49 | cmin = tensor.min(dim=-1)[0].min(dim=-1)[0].unsqueeze(1).unsqueeze(1) # N x 1 x 1 50 | cmax = tensor.max(dim=-1)[0].max(dim=-1)[0].unsqueeze(1).unsqueeze(1) # N x 1 x 1 51 | tensor = (tensor - cmin) / (cmax - cmin + 1e-6) 52 | tensor = (tensor * 255).clamp(0.0, 255.0).long() 53 | colormap = get_colormap(cmap) 54 | colorized = colormap[tensor].permute(0, 3, 1, 2) 55 | return colorized 56 | 57 | def colorize_tensor(tensor, cmap='magma', cmin=0, cmax=1): 58 | if len(tensor.shape) > 4: 59 | tensor = tensor.view(-1, *tensor.shape[-3:]) 60 | if len(tensor.shape) == 2: 61 | tensor = tensor.unsqueeze(0) 62 | if len(tensor.shape) == 4: 63 | tensor = tensor.squeeze(1) 64 | tensor = tensor.detach().cpu() 65 | tensor = (tensor - cmin) / (cmax - cmin) 66 | tensor = (tensor * 255).clamp(0.0, 255.0).long() 67 | colormap = get_colormap(cmap) 68 | colorized = colormap[tensor].permute(0, 3, 1, 2) 69 | return colorized 70 | 71 | 72 | def colorize_depth(depth): 73 | if depth.min().item() < -0.1: 74 | return colorize_tensor(depth.squeeze(1) / 2.0 + 0.5) 75 | else: 76 | return colorize_tensor(depth.squeeze(1), cmin=depth.max() - 1.0, cmax=depth.max()) 77 | 78 | 79 | def colorize_numpy(array, to_byte=True): 80 | array = torch.tensor(array) 81 | colorized = colorize_tensor(array) 82 | colorized = colorized.squeeze().permute(1, 2, 0).numpy() 83 | if to_byte: 84 | colorized = (colorized * 255).astype(np.uint8) 85 | return colorized 86 | 87 | 88 | def make_grid(images, d_real=None, d_fake=None, output_size=128, count=None, row_size=1, 89 | shuffle=False, stride=1): 90 | # Ensure that the view dimension is collapsed. 91 | images = [im.view(-1, *im.shape[-3:]) for im in images if im is not None] 92 | 93 | if count is None: 94 | count = images[0].size(0) 95 | # Select `count` random examples. 96 | if shuffle: 97 | inds = torch.randperm(images[0].size(0))[::stride][:count] 98 | else: 99 | inds = torch.arange(0, images[0].size(0))[::stride][:count] 100 | images = [im.detach().cpu()[inds] for im in images] 101 | 102 | # Expand 1 channel images to 3 channels. 103 | images = [im.expand(-1, 3, -1, -1) for im in images] 104 | 105 | # Resize images to output size. 106 | images = [F.interpolate(im, output_size) for im in images] 107 | 108 | if d_real and d_fake: 109 | d_real = [t[inds] for t in d_real] 110 | d_fake = [t[inds] for t in d_fake] 111 | 112 | # Create discriminator score grid. 
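# The heatmaps in `d_real` / `d_fake` are clamped to [0, 1], resized to
# output_size // 2, concatenated along the width, colorized, and the real and
# fake rows are then stacked vertically before being appended to the image grid.
# Illustrative shapes (assumed here, not taken from the repo): with output_size=128
# and d_real = [h] for a single Nx1x8x8 heatmap, the colorized strip is Nx3x64x64
# and the stacked d_grid is Nx3x128x64.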
113 | d_real = colorize_tensor( 114 | torch.cat([F.interpolate(h.detach().cpu().clamp(0, 1), output_size // 2) 115 | for h in d_real], dim=3).squeeze(1)) 116 | d_fake = colorize_tensor( 117 | torch.cat([F.interpolate(h.detach().cpu().clamp(0, 1), output_size // 2) 118 | for h in d_fake], dim=3).squeeze(1)) 119 | d_grid = torch.cat((d_real, d_fake), dim=2) 120 | 121 | # Create final grid. 122 | grid = torch.cat((*images, d_grid), dim=3) 123 | else: 124 | grid = torch.cat(images, dim=3) 125 | 126 | return torchvision.utils.make_grid(grid, nrow=row_size, padding=2) 127 | 128 | 129 | def save_video(frames, path, fps=15): 130 | from moviepy.video.io.ImageSequenceClip import ImageSequenceClip 131 | 132 | temp_dir = tempfile.TemporaryDirectory() 133 | logger.info("saving video", num_frames=len(frames), fps=fps, 134 | path=path, temp_dir=temp_dir.name) 135 | try: 136 | for i, frame in enumerate(tqdm(frames)): 137 | if torch.is_tensor(frame): 138 | frame = frame.permute(1, 2, 0).detach().cpu().numpy() 139 | frame_path = Path(temp_dir.name, f'{i:08d}.jpg') 140 | imageio.imsave(frame_path, (frame * 255).astype(np.uint8)) 141 | 142 | video = ImageSequenceClip(temp_dir.name, fps=fps) 143 | video.write_videofile(str(path), preset='ultrafast', fps=fps) 144 | finally: 145 | temp_dir.cleanup() 146 | 147 | 148 | def save_frames(frames, save_dir): 149 | save_dir = Path(save_dir) 150 | save_dir.mkdir(exist_ok=True, parents=True) 151 | 152 | for i, frame in enumerate(tqdm(frames)): 153 | imageio.imsave(save_dir / f'{i:04d}.jpg', (frame * 255).astype(np.uint8)) 154 | 155 | 156 | def batch_grid(batch, nrow=4): 157 | batch = batch.view(-1, *batch.shape[-3:]) 158 | grid = torchvision.utils.make_grid(batch.detach().cpu(), nrow=nrow) 159 | return grid 160 | 161 | 162 | @contextmanager 163 | def plot_to_tensor(out_tensor, dpi=100): 164 | """ 165 | A context manager that yields an axis object. Plots will be copied to `out_tensor`. 166 | The output tensor should be a float32 tensor. 167 | 168 | Usage: 169 | ``` 170 | tensor = torch.tensor(3, 480, 640) 171 | with plot_to_tensor(tensor) as ax: 172 | ax.plot(...) 173 | ``` 174 | 175 | Args: 176 | out_tensor: tensor to write to 177 | dpi: the DPI to render at 178 | """ 179 | height, width = out_tensor.shape[-2:] 180 | fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi) 181 | ax = fig.add_subplot(111) 182 | ax.axis('off') 183 | fig.tight_layout(pad=0) 184 | 185 | yield ax 186 | 187 | # If we haven't already shown or saved the plot, then we need to 188 | # draw the figure first... 189 | fig.canvas.draw() 190 | 191 | # Now we can save it to a numpy array. 192 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 193 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 194 | plt.close() 195 | 196 | out_tensor.copy_((torch.tensor(data).float() / 255.0).permute(2, 0, 1)) 197 | 198 | 199 | @contextmanager 200 | def plot_to_array(height, width, rows=1, cols=1, dpi=100): 201 | """ 202 | A context manager that yields an axis object. Plots will be copied to `out_tensor`. 203 | The output tensor should be a float32 tensor. 204 | 205 | Usage: 206 | ``` 207 | with plot_to_array(480, 640, 2, 2) as (fig, axes, out_image): 208 | axes[0][0].plot(...) 
209 | ``` 210 | 211 | Args: 212 | height: the height of the canvas 213 | width: the width of the canvas 214 | rows: the number of axis rows 215 | cols: the number of axis columns 216 | dpi: the DPI to render at 217 | """ 218 | out_array = np.empty((height, width, 3), dtype=np.uint8) 219 | fig, axes = plt.subplots(rows, cols, figsize=(width / dpi, height / dpi), dpi=dpi) 220 | 221 | yield fig, axes, out_array 222 | 223 | # If we haven't already shown or saved the plot, then we need to 224 | # draw the figure first... 225 | fig.tight_layout(pad=0) 226 | fig.canvas.draw() 227 | 228 | # Now we can save it to a numpy array. 229 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 230 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 231 | plt.close() 232 | 233 | np.copyto(out_array, data) 234 | 235 | 236 | def apply_mask_gray(image, mask): 237 | image = (image - 0.5) * 2.0 238 | image = image * mask 239 | return (image + 1.0) / 2.0 240 | 241 | 242 | def show_batch(batch, nrow=16, title=None, padding=2, pad_value=1): 243 | batch = batch.view(-1, *batch.shape[-3:]) 244 | grid = torchvision.utils.make_grid(batch.detach().cpu(), 245 | nrow=nrow, 246 | padding=padding, 247 | pad_value=pad_value).permute(1, 2, 0) 248 | if title: 249 | plt.title(title) 250 | plt.axis('off') 251 | plt.imshow(grid) 252 | 253 | 254 | def plot_image_batches(path, images, num_cols=None, size=5): 255 | titles, images = list(zip(*images)) 256 | 257 | num_images = len(images) 258 | num_batch = max(len(x) for x in images if x is not None) 259 | grid_row_size = int(math.ceil(math.sqrt(num_batch))) 260 | 261 | if num_cols is None: 262 | num_cols = num_images 263 | num_rows = int(math.ceil(len(images) / num_cols)) 264 | 265 | aspect_ratio = images[0].shape[-1] / images[0].shape[-2] 266 | width = num_cols * size * aspect_ratio 267 | height = num_rows * (size + 1) # Room for titles. 
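# Worked example (hypothetical inputs, not from the repo): four titled batches of
# 3x128x128 images with size=5 and num_cols=None give num_cols=4, num_rows=1 and
# aspect_ratio=1.0, so the canvas is 20 x 6 inches; each batch is then tiled by
# show_batch() with at most grid_row_size = ceil(sqrt(num_batch)) images per row.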
268 | 
269 | fig = plt.figure(figsize=(width, height))
270 | for i in range(num_images):
271 | if images[i] is None:
272 | continue
273 | plt.subplot(num_rows, num_cols, i+1)
274 | show_batch(images[i],
275 | nrow=min(len(images[i]), grid_row_size),
276 | title=titles[i])
277 | 
278 | fig.tight_layout()
279 | fig.savefig(path)
280 | plt.close('all')
281 | 
282 | 
283 | def plot_grid(num_cols, figsize, plots):
284 | if num_cols is None:
285 | num_cols = len(plots)
286 | num_rows = int(math.ceil(len(plots) / num_cols))
287 | 
288 | fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
289 | for i, ax in enumerate(axes.flatten()):
290 | if i >= len(plots) or plots[i] is None:
291 | ax.axis('off')
292 | continue
293 | plot = plots[i]
294 | args = plot.args if plot.args else []
295 | kwargs = plot.kwargs if plot.kwargs else {}
296 | if isinstance(plot.func, str):
297 | getattr(ax, plot.func)(*args, **kwargs)
298 | else:
299 | plot.func(*args, **kwargs, ax=ax)
300 | ax.set_title(plot.title)
301 | if plot.params:
302 | for param_key, param_value in plot.params.items():
303 | getattr(ax, f'set_{param_key}')(param_value)
304 | # fig.set_facecolor('white')
305 | fig.tight_layout()
306 | 
307 | return fig
308 | 
309 | 
310 | def depth_to_disparity(depth):
311 | depth[depth > 0] = 1/depth[depth > 0]
312 | valid = depth[depth > 0]
313 | cmin = valid.min()
314 | cmax = valid.max()
315 | return (depth - cmin) / (cmax - cmin)
316 | 
317 | 
318 | def normalize_visulization(depth):
319 | 
320 | if isinstance(depth, torch.Tensor):
321 | depth = depth.squeeze().clone()
322 | else:
323 | depth = torch.tensor(depth).squeeze().clone()
324 | 
325 | mask = torch.zeros_like(depth)
326 | mask[depth>0] = 1
327 | min_dep = depth[mask.bool()].min()
328 | max_dep = depth[mask.bool()].max()
329 | mean_depth = 0.5*(min_dep + max_dep)* mask
330 | depth = depth - mean_depth
331 | return depth
332 | 
333 | 
334 | # Plot = namedtuple('Plot', ['title', 'args', 'kwargs', 'params', 'func'],
335 | # defaults=[None, None, None, 'plot'])
336 | 
337 | 
--------------------------------------------------------------------------------
/evaluation/LMO_RCNN_OVE6D_pipeline.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import cv2
4 | import sys
5 | import json
6 | # import yaml
7 | import time
8 | import torch
9 | import warnings
10 | import numpy as np
11 | from PIL import Image
12 | from pathlib import Path
13 | 
14 | from detectron2 import model_zoo
15 | from detectron2.config import get_cfg
16 | from detectron2.engine import DefaultPredictor
17 | 
18 | 
19 | 
20 | from os.path import join as pjoin
21 | from bop_toolkit_lib import inout
22 | warnings.filterwarnings("ignore")
23 | 
24 | 
25 | base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
26 | sys.path.append(base_path)
27 | 
28 | from lib import rendering, network
29 | 
30 | from dataset import LineMOD_Dataset
31 | from evaluation import utils
32 | from evaluation import config as cfg
33 | 
34 | gpu_id = 0
35 | # gpu_id = 1
36 | 
37 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
38 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
39 | os.environ['EGL_DEVICE_ID'] = str(gpu_id)
40 | DEVICE = torch.device('cuda')
41 | 
42 | 
43 | datapath = Path(cfg.DATA_PATH)
44 | 
45 | eval_dataset = 
LineMOD_Dataset.Dataset(datapath / 'lm') 46 | 47 | ################################################# MASK-RCNN Segmentation ################################################################## 48 | rcnnIdx_to_lmoIds_dict = {0:1, 1:5, 2:6, 3:8, 4:9, 5:10, 6:11, 7:12} 49 | rcnnIdx_to_lmoCats_dict = {0:'Ape', 1:'Can', 2:'Cat', 3:'Driller', 4:'Duck', 5:'Eggbox', 6:'Glue', 7:'Holepunch'} 50 | catId_to_catName_dict = {1:'Ape', 5:'Can', 6:'Cat', 8:'Driller', 9:'Duck', 10:'Eggbox', 11:'Glue', 12:'Holepunch'} 51 | rcnn_cfg = get_cfg() 52 | rcnn_cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")) 53 | rcnn_cfg.MODEL.WEIGHTS = pjoin(base_path, 'checkpoints', 'lmo_maskrcnn_model.pth') 54 | 55 | rcnn_cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(rcnnIdx_to_lmoCats_dict) 56 | rcnn_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.001 # the predicted category scores 57 | predictor = DefaultPredictor(rcnn_cfg) 58 | ################################################# MASK-RCNN Segmentation ################################################################## 59 | 60 | cfg.DATASET_NAME = 'lm' # dataset name 61 | cfg.RENDER_WIDTH = eval_dataset.cam_width # the width of rendered images 62 | cfg.RENDER_HEIGHT = eval_dataset.cam_height # the height of rendered images 63 | cfg.HEMI_ONLY = True 64 | 65 | ckpt_file = pjoin(base_path, 66 | 'checkpoints', 67 | "OVE6D_pose_model.pth" 68 | ) 69 | model_net = network.OVE6D().to(DEVICE) 70 | 71 | model_net.load_state_dict(torch.load(ckpt_file), strict=True) 72 | model_net.eval() 73 | 74 | codebook_saving_dir = pjoin(base_path,'evaluation/object_codebooks', 75 | cfg.DATASET_NAME, 76 | 'zoom_{}'.format(cfg.ZOOM_DIST_FACTOR), 77 | 'views_{}'.format(str(cfg.RENDER_NUM_VIEWS))) 78 | 79 | object_codebooks = utils.OVE6D_codebook_generation(codebook_dir=codebook_saving_dir, 80 | model_func=model_net, 81 | dataset=eval_dataset, 82 | config=cfg, 83 | device=DEVICE) 84 | raw_pred_results = list() 85 | icp1_pred_results = list() 86 | icpk_pred_results = list() 87 | raw_pred_runtime = list() 88 | icp1_pred_runtime = list() 89 | icpk_pred_runtime = list() 90 | 91 | rcnn_gt_results = dict() 92 | rcnn_pd_results = dict() 93 | 94 | test_data_dir = datapath / 'lmo' / 'test' # path to the test dataset of BOP 95 | eval_dir = pjoin(base_path, 'evaluation/pred_results/LMO') 96 | 97 | 98 | raw_file_mode = "raw-sampleN{}-viewpointK{}-poseP{}-rcnn_lmo-test.csv" 99 | if cfg.USE_ICP: 100 | icp1_file_mode = "icp1-sampleN{}-viewpointK{}-poseP{}-nbr{}-itr{}-pts{}-pla{}-rcnn_lmo-test.csv" 101 | icpk_file_mode = "icpk-sampleN{}-viewpointK{}-poseP{}-nbr{}-itr{}-pts{}-pla{}-rcnn_lmo-test.csv" 102 | 103 | obj_renderer = rendering.Renderer(width=cfg.RENDER_WIDTH, height=cfg.RENDER_HEIGHT) 104 | 105 | if not os.path.exists(eval_dir): 106 | os.makedirs(eval_dir) 107 | 108 | for scene_id in sorted(os.listdir(test_data_dir)): 109 | scene_dir = pjoin(test_data_dir, scene_id) 110 | if not os.path.isdir(scene_dir): 111 | continue 112 | cam_info_file = pjoin(scene_dir, 'scene_camera.json') 113 | with open(cam_info_file, 'r') as cam_f: 114 | scene_camera_info = json.load(cam_f) 115 | 116 | gt_pose_file = os.path.join(scene_dir, 'scene_gt.json') 117 | with open(gt_pose_file, 'r') as pose_f: 118 | pose_anno = json.load(pose_f) 119 | 120 | rgb_dir = pjoin(scene_dir, 'rgb') 121 | depth_dir = pjoin(scene_dir, 'depth') 122 | mask_dir = os.path.join(scene_dir, 'mask_visib') 123 | rcnn_runtime = list() 124 | view_runtime = list() 125 | for rgb_png in sorted(os.listdir(rgb_dir)): 126 | if not 
rgb_png.endswith('.png'): 127 | continue 128 | view_id_str = rgb_png.split('.')[0] 129 | view_id = int(view_id_str) 130 | view_timer = time.time() 131 | 132 | 133 | ###################### read gt mask ########################## 134 | target_gt_masks = dict() 135 | view_gt_poses = pose_anno[str(view_id)] 136 | for ix, gt_obj in enumerate(view_gt_poses): 137 | gt_obj_id = gt_obj['obj_id'] 138 | mask_file = os.path.join(mask_dir, "{:06d}_{:06d}.png".format(view_id, ix)) 139 | gt_msk = torch.tensor(cv2.imread(mask_file, 0)).type(torch.bool) 140 | target_gt_masks[gt_obj_id] = gt_msk 141 | if gt_obj_id not in rcnn_gt_results: 142 | rcnn_gt_results[gt_obj_id] = 0 143 | rcnn_gt_results[gt_obj_id] += 1 144 | ###################### read gt mask ########################## 145 | 146 | ###################### object segmentation ###################### 147 | img_name = "{:06d}.png".format(view_id) 148 | rgb_file = os.path.join(rgb_dir, img_name) 149 | rgb_img = cv2.imread(rgb_file) 150 | output = predictor(rgb_img) 151 | rcnn_pred_ids = output["instances"].pred_classes # cat_idx: 0 - 7 152 | rcnn_pred_masks = output["instances"].pred_masks 153 | # rcnn_pred_bboxes = output["instances"].pred_boxes 154 | rcnn_pred_scores = output["instances"].scores 155 | rcnn_cost = time.time() - view_timer 156 | rcnn_runtime.append(rcnn_cost) 157 | ###################### object segmentation ###################### 158 | 159 | obj_masks = rcnn_pred_masks # NxHxW 160 | 161 | view_cam_info = scene_camera_info[str(view_id)] # scene camera information 162 | depth_file = pjoin(depth_dir, "{:06d}.png".format(view_id)) 163 | view_depth = torch.tensor(np.array(Image.open(depth_file)), dtype=torch.float32) # HxW 164 | view_depth *= view_cam_info['depth_scale'] 165 | view_depth *= cfg.MODEL_SCALING # convert to meter scale from millimeter scale 166 | view_camK = torch.tensor(view_cam_info['cam_K'], dtype=torch.float32).view(3, 3)[None, ...] # 1x3x3 167 | 168 | cam_K = view_camK.to(DEVICE) 169 | view_depth = view_depth.to(DEVICE) 170 | obj_depths = view_depth[None, ...] 
* obj_masks 171 | 172 | unique_rcnn_obj_ids = torch.unique(rcnn_pred_ids) 173 | for uniq_rcnn_id in unique_rcnn_obj_ids: 174 | uniq_lmo_id = rcnnIdx_to_lmoIds_dict[uniq_rcnn_id.item()] 175 | uniq_obj_codebook = object_codebooks[uniq_lmo_id] 176 | 177 | uniq_obj_mask = obj_masks[rcnn_pred_ids==uniq_rcnn_id] 178 | uniq_obj_depth = obj_depths[rcnn_pred_ids==uniq_rcnn_id] 179 | uniq_obj_score = rcnn_pred_scores[rcnn_pred_ids==uniq_rcnn_id] 180 | 181 | mask_pixel_count = uniq_obj_mask.view(uniq_obj_mask.size(0), -1).sum(dim=1) 182 | 183 | valid_idx = (mask_pixel_count >= 100) 184 | if valid_idx.sum() == 0: 185 | mask_visib_ratio = mask_pixel_count / mask_pixel_count.max() 186 | valid_idx = mask_visib_ratio >= 0.05 187 | 188 | uniq_obj_mask = uniq_obj_mask[valid_idx] 189 | uniq_obj_depth = uniq_obj_depth[valid_idx] 190 | uniq_obj_score = uniq_obj_score[valid_idx] 191 | 192 | pose_ret = utils.OVE6D_rcnn_full_pose(model_func=model_net, 193 | obj_depths=uniq_obj_depth, 194 | obj_masks=uniq_obj_mask, 195 | obj_rcnn_scores=uniq_obj_score, 196 | obj_codebook=uniq_obj_codebook, 197 | cam_K=cam_K, 198 | config=cfg, 199 | device=DEVICE, 200 | obj_renderer=obj_renderer) 201 | select_rcnn_idx = pose_ret['rcnn_idx'] 202 | rcnn_pd_mask = uniq_obj_mask[select_rcnn_idx].cpu() 203 | rcnn_pd_score = uniq_obj_score[select_rcnn_idx].cpu() 204 | 205 | if uniq_lmo_id not in rcnn_pd_results: 206 | rcnn_pd_results[uniq_lmo_id] = list() 207 | 208 | if uniq_lmo_id in target_gt_masks: 209 | obj_gt_mask = target_gt_masks[uniq_lmo_id] 210 | inter_area = obj_gt_mask & rcnn_pd_mask 211 | outer_area = obj_gt_mask | rcnn_pd_mask 212 | iou = inter_area.sum() / outer_area.sum() 213 | rcnn_pd_results[uniq_lmo_id].append(iou.item()) 214 | else: 215 | rcnn_pd_results[uniq_lmo_id].append(0.0) 216 | 217 | raw_pred_results.append({'time': pose_ret['raw_time'], 218 | 'scene_id': int(scene_id), 219 | 'im_id': int(view_id), 220 | 'obj_id': int(uniq_lmo_id), 221 | 'score': pose_ret['raw_score'].squeeze().numpy(), 222 | 'R': cfg.POSE_TO_BOP(pose_ret['raw_R']).squeeze().numpy(), 223 | 't': pose_ret['raw_t'].squeeze().numpy() * 1000.0}) # convert estimated pose to BOP format 224 | raw_pred_runtime.append(pose_ret['raw_time']) 225 | if cfg.USE_ICP: 226 | icp1_pred_results.append({'time': pose_ret['icp1_rawicp_time'], 227 | 'scene_id': int(scene_id), 228 | 'im_id': int(view_id), 229 | 'obj_id': int(uniq_lmo_id), 230 | 'score': pose_ret['icp1_score'].squeeze().numpy(), 231 | 'R': cfg.POSE_TO_BOP(pose_ret['icp1_R']).squeeze().numpy(), 232 | 't': pose_ret['icp1_t'].squeeze().numpy() * 1000.0}) 233 | icp1_pred_runtime.append(pose_ret['icp1_rawicp_time']) 234 | 235 | icpk_pred_results.append({'time': pose_ret['icpk_rawicp_time'], 236 | 'scene_id': int(scene_id), 237 | 'im_id': int(view_id), 238 | 'obj_id': int(uniq_lmo_id), 239 | 'score': pose_ret['icpk_score'].squeeze().numpy(), 240 | 'R': cfg.POSE_TO_BOP(pose_ret['icpk_R']).squeeze().numpy(), 241 | 't': pose_ret['icpk_t'].squeeze().numpy() * 1000.0}) 242 | icpk_pred_runtime.append(pose_ret['icpk_rawicp_time']) 243 | 244 | view_runtime.append(time.time() - view_timer) 245 | if (view_id) % 100 == 0: 246 | print('scene:{}, image: {}, rcnn:{:.3f}, image_cost:{:.3f}, raw_t:{:.3f}, icp1_t:{:.3f}, icpk_t:{:.3f}'.format( 247 | int(scene_id), view_id+1, np.mean(rcnn_runtime), np.mean(view_runtime), 248 | np.mean(raw_pred_runtime), np.mean(icp1_pred_runtime), np.mean(icpk_pred_runtime))) 249 | 250 | print('{}, {}'.format(scene_id, time.strftime('%m_%d-%H:%M:%S', time.localtime()))) 251 | 252 | 
rawk_eval_file = pjoin(eval_dir, raw_file_mode.format( 253 | cfg.RENDER_NUM_VIEWS, cfg.VP_NUM_TOPK, cfg.POSE_NUM_TOPK 254 | )) 255 | inout.save_bop_results(rawk_eval_file, raw_pred_results) 256 | 257 | mean_raw_time = np.mean(raw_pred_runtime) 258 | print('raw_mean_runtime: {:.4f}, saving to {}'.format(mean_raw_time, rawk_eval_file)) 259 | 260 | if cfg.USE_ICP: 261 | icp1_eval_file = pjoin(eval_dir, icp1_file_mode.format( 262 | cfg.RENDER_NUM_VIEWS, cfg.VP_NUM_TOPK, cfg.POSE_NUM_TOPK, 263 | cfg.ICP_neighbors, cfg.ICP_max_iterations, cfg.ICP_correspondences, cfg.ICP_min_planarity, 264 | )) 265 | icpk_eval_file = pjoin(eval_dir, icpk_file_mode.format( 266 | cfg.RENDER_NUM_VIEWS, cfg.VP_NUM_TOPK, cfg.POSE_NUM_TOPK, 267 | cfg.ICP_neighbors, cfg.ICP_max_iterations, cfg.ICP_correspondences, cfg.ICP_min_planarity, 268 | )) 269 | inout.save_bop_results(icp1_eval_file, icp1_pred_results) 270 | inout.save_bop_results(icpk_eval_file, icpk_pred_results) 271 | 272 | mean_icp1_time = np.mean(icp1_pred_runtime) 273 | mean_icpk_time = np.mean(icpk_pred_runtime) 274 | print('icp1_mean_runtime: {:.4f}, saving to {}'.format(mean_icp1_time, icp1_eval_file)) 275 | print('icpk_mean_runtime: {:.4f}, saving to {}'.format(mean_icpk_time, icpk_eval_file)) 276 | 277 | del obj_renderer 278 | 279 | 280 | ##################### evaluate rcnn detection and segmentation performance #################### 281 | iou_T = 0.5 282 | rcnn_obj_ARs = list() 283 | rcnn_obj_APs = list() 284 | print(' #################################### IOU_Threshold = {:.2f} #################################### '.format(iou_T)) 285 | for obj_abs_id, obj_iou in rcnn_pd_results.items(): 286 | obj_name = catId_to_catName_dict[obj_abs_id] 287 | obj_rcnn_iou = np.array(obj_iou) 288 | 289 | all_pd_count = len(obj_rcnn_iou) 290 | all_gt_count = rcnn_gt_results[obj_abs_id] 291 | true_pd_count = sum(obj_rcnn_iou >= iou_T) 292 | 293 | obj_AP = true_pd_count / all_pd_count # True_PD / ALL_PD 294 | obj_AR = true_pd_count / all_gt_count # True_PD / ALL_GT 295 | 296 | rcnn_obj_APs.append(obj_AP) 297 | rcnn_obj_ARs.append(obj_AR) 298 | 299 | print('obj_id: {:02d}, obj_AR: {:.5f}, obj_AP: {:.5f}, All_GT:{}, All_PD:{}, True_PD:{}, obj_name: {}'.format( 300 | obj_abs_id, obj_AR, obj_AP, all_gt_count, all_pd_count, true_pd_count, obj_name)) 301 | 302 | mAR = np.mean(rcnn_obj_ARs) 303 | mAP = np.mean(rcnn_obj_APs) 304 | print('IOU_T:{:.5f}, mean_recall:{:.5f}, mean_precision: {:.5f}'.format(iou_T, mAR, mAP)) -------------------------------------------------------------------------------- /evaluation/LM_RCNN_OVE6D_pipeline.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import cv2 4 | import sys 5 | import json 6 | # import yaml 7 | import time 8 | import torch 9 | import warnings 10 | import numpy as np 11 | from PIL import Image 12 | from pathlib import Path 13 | 14 | from detectron2 import model_zoo 15 | from detectron2.config import get_cfg 16 | from detectron2.engine import DefaultPredictor 17 | 18 | 19 | 20 | from os.path import join as pjoin 21 | from bop_toolkit_lib import inout 22 | warnings.filterwarnings("ignore") 23 | 24 | 25 | base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 26 | sys.path.append(base_path) 27 | 28 | from lib import rendering, network 29 | 30 | from dataset import LineMOD_Dataset 31 | from evaluation import utils 32 | from evaluation import config as cfg 33 | 34 | gpu_id = 0 35 | # gpu_id = 1 36 | 37 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 38 | 
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 39 | os.environ['EGL_DEVICE_ID'] = str(gpu_id) 40 | DEVICE = torch.device('cuda') 41 | 42 | 43 | datapath = Path(cfg.DATA_PATH) 44 | 45 | eval_dataset = LineMOD_Dataset.Dataset(datapath / 'lm') 46 | 47 | ################################################# MASK-RCNN Segmentation ################################################################## 48 | rcnnIdx_to_lmIds_dict = {0:1, 1:2, 2:3, 3:4, 4:5, 5:6, 6:7, 7:8, 8:9, 9:10, 10:11, 11:12, 12:13, 13:14, 14:15} 49 | rcnnIdx_to_lmCats_dict ={0:'Ape', 1:'Benchvice', 2:'Bowl', 3:'Camera', 4:'Can', 5:'Cat', 6:'Cup', 7:'Driller', 50 | 8:'Duck', 9:'Eggbox', 10:'Glue', 11:'Holepunch', 12:'Iron', 13:'Lamp', 14:'Phone'} 51 | rcnn_cfg = get_cfg() 52 | # rcnn_cfg.INPUT.MASK_FORMAT = 'bitmask' 53 | rcnn_cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")) 54 | rcnn_cfg.MODEL.WEIGHTS = pjoin(base_path, 55 | 'checkpoints', 56 | 'lm_maskrcnn_model.pth') 57 | 58 | rcnn_cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(rcnnIdx_to_lmCats_dict) 59 | rcnn_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.001 # the predicted category scores 60 | predictor = DefaultPredictor(rcnn_cfg) 61 | ################################################# MASK-RCNN Segmentation ################################################################## 62 | 63 | 64 | cfg.DATASET_NAME = 'lm' # dataset name 65 | cfg.RENDER_WIDTH = eval_dataset.cam_width # the width of rendered images 66 | cfg.RENDER_HEIGHT = eval_dataset.cam_height # the height of rendered images 67 | 68 | cfg.HEMI_ONLY = True 69 | 70 | ckpt_file = pjoin(base_path, 71 | 'checkpoints', 72 | "OVE6D_pose_model.pth" 73 | ) 74 | model_net = network.OVE6D().to(DEVICE) 75 | 76 | model_net.load_state_dict(torch.load(ckpt_file), strict=True) 77 | model_net.eval() 78 | 79 | codebook_saving_dir = pjoin(base_path,'evaluation/object_codebooks', 80 | cfg.DATASET_NAME, 81 | 'zoom_{}'.format(cfg.ZOOM_DIST_FACTOR), 82 | 'views_{}'.format(str(cfg.RENDER_NUM_VIEWS))) 83 | 84 | 85 | object_codebooks = utils.OVE6D_codebook_generation(codebook_dir=codebook_saving_dir, 86 | model_func=model_net, 87 | dataset=eval_dataset, 88 | config=cfg, 89 | device=DEVICE) 90 | raw_pred_results = list() 91 | icp1_pred_results = list() 92 | icpk_pred_results = list() 93 | raw_pred_runtime = list() 94 | icp1_pred_runtime = list() 95 | icpk_pred_runtime = list() 96 | 97 | rcnn_gt_results = dict() 98 | rcnn_pd_results = dict() 99 | 100 | test_data_dir = datapath / 'lm' / 'test' # path to the test dataset of BOP 101 | eval_dir = pjoin(base_path, 'evaluation/pred_results/LM') 102 | 103 | raw_file_mode = "raw-sampleN{}-viewpointK{}-poseP{}-rcnn_lm-test.csv" 104 | if cfg.USE_ICP: 105 | icp1_file_mode = "icp1-sampleN{}-viewpointK{}-poseP{}-nbr{}-itr{}-pts{}-pla{}-rcnn_lm-test.csv" 106 | icpk_file_mode = "icpk-sampleN{}-viewpointK{}-poseP{}-nbr{}-itr{}-pts{}-pla{}-rcnn_lm-test.csv" 107 | 108 | 109 | obj_renderer = rendering.Renderer(width=cfg.RENDER_WIDTH, height=cfg.RENDER_HEIGHT) 110 | 111 | if not os.path.exists(eval_dir): 112 | os.makedirs(eval_dir) 113 | 114 | # single_proposal_icp_cost = list() 115 | # single_proposal_raw_cost = list() 116 | 117 | img_read_cost = list() 118 | bg_cost = list() 119 | zoom_cost = list() 120 | rot_cost = list() 121 | tsl_cost = list() 122 | 123 | raw_syn_render_cost = list() 124 | raw_selection_cost = list() 125 | raw_postprocess_cost = list() 126 | 127 | icp1_refinement_cost = list() 128 | icpk_refinement_cost = list() 129 | 130 | icpk_syn_render_cost = list() 
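# The *_cost lists initialised above and just below accumulate per-stage timings
# taken from pose_ret inside the evaluation loop; the script itself only prints the
# raw/ICP runtimes. A possible way to summarise the remaining stages after the loop
# (illustrative sketch, not part of the original script):
#   for name, costs in [('bg', bg_cost), ('zoom', zoom_cost), ('rot', rot_cost), ('tsl', tsl_cost)]:
#       if costs:
#           print('{}_mean_cost: {:.4f}'.format(name, np.mean(costs)))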
131 | icpk_selection_cost = list() 132 | icpk_postprocess_cost = list() 133 | 134 | for scene_id in sorted(os.listdir(test_data_dir)): 135 | tar_obj_id = int(scene_id) 136 | # if tar_obj_id not in [3, 7]: # skip these two objects 137 | # continue 138 | 139 | scene_dir = pjoin(test_data_dir, scene_id) 140 | if not os.path.isdir(scene_dir): 141 | continue 142 | cam_info_file = pjoin(scene_dir, 'scene_camera.json') 143 | with open(cam_info_file, 'r') as cam_f: 144 | scene_camera_info = json.load(cam_f) 145 | 146 | gt_pose_file = os.path.join(scene_dir, 'scene_gt.json') 147 | with open(gt_pose_file, 'r') as pose_f: 148 | pose_anno = json.load(pose_f) 149 | 150 | rgb_dir = pjoin(scene_dir, 'rgb') 151 | depth_dir = pjoin(scene_dir, 'depth') 152 | mask_dir = os.path.join(scene_dir, 'mask_visib') 153 | rcnn_runtime = list() 154 | view_runtime = list() 155 | for rgb_png in sorted(os.listdir(rgb_dir)): 156 | if not rgb_png.endswith('.png'): 157 | continue 158 | view_id_str = rgb_png.split('.')[0] 159 | view_id = int(view_id_str) 160 | view_timer = time.time() 161 | 162 | ###################### read gt mask ########################## 163 | # target_gt_masks = dict() 164 | # view_gt_poses = pose_anno[str(view_id)] 165 | # for ix, gt_obj in enumerate(view_gt_poses): 166 | # gt_obj_id = gt_obj['obj_id'] 167 | # mask_file = os.path.join(mask_dir, "{:06d}_{:06d}.png".format(view_id, ix)) 168 | # gt_msk = torch.tensor(cv2.imread(mask_file, 0)).type(torch.bool) 169 | # target_gt_masks[gt_obj_id] = gt_msk 170 | # if gt_obj_id not in rcnn_gt_results: 171 | # rcnn_gt_results[gt_obj_id] = 0 172 | # rcnn_gt_results[gt_obj_id] += 1 173 | ###################### read gt mask ########################## 174 | 175 | ###################### object segmentation ###################### 176 | img_name = "{:06d}.png".format(view_id) 177 | rgb_file = os.path.join(rgb_dir, img_name) 178 | rgb_img = cv2.imread(rgb_file) 179 | imread_cost = time.time() - view_timer 180 | img_read_cost.append(imread_cost) 181 | 182 | rcnn_timer = time.time() 183 | output = predictor(rgb_img) 184 | rcnn_pred_ids = output["instances"].pred_classes 185 | rcnn_pred_masks = output["instances"].pred_masks 186 | rcnn_pred_scores = output["instances"].scores 187 | # rcnn_pred_bboxes = output["instances"].pred_boxes 188 | rcnn_cost = time.time() - rcnn_timer 189 | rcnn_runtime.append(rcnn_cost) 190 | ###################### object segmentation ###################### 191 | 192 | obj_masks = rcnn_pred_masks # NxHxW 193 | 194 | view_cam_info = scene_camera_info[str(view_id)] # scene camera information 195 | depth_file = pjoin(depth_dir, "{:06d}.png".format(view_id)) 196 | view_depth = torch.tensor(np.array(Image.open(depth_file)), dtype=torch.float32) # HxW 197 | view_depth *= view_cam_info['depth_scale'] 198 | view_depth *= cfg.MODEL_SCALING # convert to meter scale from millimeter scale 199 | view_camK = torch.tensor(view_cam_info['cam_K'], dtype=torch.float32).view(3, 3)[None, ...] # 1x3x3 200 | 201 | cam_K = view_camK.to(DEVICE) 202 | view_depth = view_depth.to(DEVICE) 203 | obj_depths = view_depth[None, ...] 
* obj_masks 204 | 205 | tar_obj_codebook = object_codebooks[tar_obj_id] 206 | tar_rcnn_d = tar_obj_id - 1 207 | tar_obj_depths = obj_depths[tar_rcnn_d==rcnn_pred_ids] 208 | tar_obj_masks = rcnn_pred_masks[tar_rcnn_d==rcnn_pred_ids] 209 | tar_obj_scores = rcnn_pred_scores[tar_rcnn_d==rcnn_pred_ids] 210 | 211 | if len(tar_obj_scores) > 0: 212 | mask_pixel_count = tar_obj_masks.view(tar_obj_masks.size(0), -1).sum(dim=1) 213 | valid_idx = (mask_pixel_count >= 100) 214 | if valid_idx.sum() == 0: 215 | mask_visib_ratio = mask_pixel_count / mask_pixel_count.max() 216 | valid_idx = mask_visib_ratio >= 0.05 217 | 218 | tar_obj_masks = tar_obj_masks[valid_idx] 219 | tar_obj_depths = tar_obj_depths[valid_idx] 220 | tar_obj_scores = tar_obj_scores[valid_idx] 221 | 222 | pose_ret = utils.OVE6D_rcnn_full_pose(model_func=model_net, 223 | obj_depths=tar_obj_depths, 224 | obj_masks=tar_obj_masks, 225 | obj_rcnn_scores=tar_obj_scores, 226 | obj_codebook=tar_obj_codebook, 227 | cam_K=cam_K, 228 | config=cfg, 229 | device=DEVICE, 230 | obj_renderer=obj_renderer) 231 | select_rcnn_idx = pose_ret['rcnn_idx'] 232 | rcnn_pd_mask = tar_obj_masks[select_rcnn_idx].cpu() 233 | rcnn_pd_score = tar_obj_scores[select_rcnn_idx].cpu() 234 | raw_pred_results.append({'time': pose_ret['raw_time'], 235 | 'scene_id': int(scene_id), 236 | 'im_id': int(view_id), 237 | 'obj_id': int(tar_obj_id), 238 | 'score': pose_ret['raw_score'].squeeze().numpy(), 239 | 'R': cfg.POSE_TO_BOP(pose_ret['raw_R']).squeeze().numpy(), 240 | 't': pose_ret['raw_t'].squeeze().numpy() * 1000.0}) # convert estimated pose to BOP format 241 | 242 | bg_cost.append(pose_ret['bg_time']) 243 | zoom_cost.append(pose_ret['zoom_time']) 244 | rot_cost.append(pose_ret['rot_time']) 245 | tsl_cost.append(pose_ret['tsl_time']) 246 | 247 | raw_pred_runtime.append(pose_ret['raw_time']) 248 | raw_syn_render_cost.append(pose_ret['raw_syn_time']) 249 | raw_selection_cost.append(pose_ret['raw_select_time']) 250 | raw_postprocess_cost.append(pose_ret['raw_postp_time']) 251 | 252 | # single_proposal_raw_cost.append(pose_ret['top1_raw_time']) 253 | if cfg.USE_ICP: 254 | icp1_refinement_cost.append(pose_ret['icp1_ref_time']) 255 | icp1_pred_runtime.append(pose_ret['icp1_rawicp_time']) 256 | 257 | icpk_syn_render_cost.append(pose_ret['icpk_syn_time']) 258 | icpk_selection_cost.append(pose_ret['icpk_select_time']) 259 | icpk_postprocess_cost.append(pose_ret['icpk_postp_time']) 260 | 261 | icpk_refinement_cost.append(pose_ret['icpk_ref_time']) 262 | icpk_pred_runtime.append(pose_ret['icpk_rawicp_time']) 263 | 264 | icp1_pred_results.append({'time': pose_ret['icp1_rawicp_time'], 265 | 'scene_id': int(scene_id), 266 | 'im_id': int(view_id), 267 | 'obj_id': int(tar_obj_id), 268 | 'score': pose_ret['icp1_score'].squeeze().numpy(), 269 | 'R': cfg.POSE_TO_BOP(pose_ret['icp1_R']).squeeze().numpy(), 270 | 't': pose_ret['icp1_t'].squeeze().numpy() * 1000.0}) 271 | 272 | icpk_pred_results.append({'time': pose_ret['icpk_rawicp_time'], 273 | 'scene_id': int(scene_id), 274 | 'im_id': int(view_id), 275 | 'obj_id': int(tar_obj_id), 276 | 'score': pose_ret['icpk_score'].squeeze().numpy(), 277 | 'R': cfg.POSE_TO_BOP(pose_ret['icpk_R']).squeeze().numpy(), 278 | 't': pose_ret['icpk_t'].squeeze().numpy() * 1000.0}) 279 | 280 | 281 | 282 | view_runtime.append(time.time() - view_timer) 283 | if (view_id+1) % 100 == 0: 284 | print('scene:{}, image: {}, rcnn:{:.3f}, image_cost:{:.3f}, raw_t:{:.3f}, icp1_t:{:.3f}, icpk_t:{:.3f}'.format( 285 | int(scene_id), view_id+1, np.mean(rcnn_runtime), 
np.mean(view_runtime), 286 | np.mean(raw_pred_runtime), np.mean(icp1_pred_runtime), np.mean(icpk_pred_runtime))) 287 | 288 | print('{}, {}'.format(scene_id, time.strftime('%m_%d-%H:%M:%S', time.localtime()))) 289 | 290 | rawk_eval_file = pjoin(eval_dir, raw_file_mode.format( 291 | cfg.RENDER_NUM_VIEWS, cfg.VP_NUM_TOPK, cfg.POSE_NUM_TOPK)) 292 | inout.save_bop_results(rawk_eval_file, raw_pred_results) 293 | 294 | mean_raw_time = np.mean(raw_pred_runtime) 295 | print('raw_mean_runtime: {:.4f}, saving to {}'.format(mean_raw_time, rawk_eval_file)) 296 | 297 | if cfg.USE_ICP: 298 | icp1_eval_file = pjoin(eval_dir, icp1_file_mode.format( 299 | cfg.RENDER_NUM_VIEWS, cfg.VP_NUM_TOPK, cfg.POSE_NUM_TOPK, 300 | cfg.ICP_neighbors, cfg.ICP_max_iterations, cfg.ICP_correspondences, cfg.ICP_min_planarity, 301 | )) 302 | icpk_eval_file = pjoin(eval_dir, icpk_file_mode.format( 303 | cfg.RENDER_NUM_VIEWS, cfg.VP_NUM_TOPK, cfg.POSE_NUM_TOPK, 304 | cfg.ICP_neighbors, cfg.ICP_max_iterations, cfg.ICP_correspondences, cfg.ICP_min_planarity, 305 | )) 306 | inout.save_bop_results(icp1_eval_file, icp1_pred_results) 307 | inout.save_bop_results(icpk_eval_file, icpk_pred_results) 308 | 309 | mean_icp1_time = np.mean(icp1_pred_runtime) 310 | mean_icpk_time = np.mean(icpk_pred_runtime) 311 | print('icp1_mean_runtime: {:.4f}, saving to {}'.format(mean_icp1_time, icp1_eval_file)) 312 | print('icpk_mean_runtime: {:.4f}, saving to {}'.format(mean_icpk_time, icpk_eval_file)) 313 | 314 | del obj_renderer 315 | -------------------------------------------------------------------------------- /lib/rendering.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is partially borrowed from LatentFusion 3 | """ 4 | 5 | import os 6 | import math 7 | import torch 8 | import trimesh 9 | import pyrender 10 | import numpy as np 11 | import torch.nn.functional as F 12 | from pyrender import RenderFlags 13 | from pytorch3d.transforms import matrix_to_euler_angles, euler_angles_to_matrix 14 | 15 | from utility import meshutils 16 | 17 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 18 | 19 | def uniform_z_rotation(n, eps_degree=0): 20 | """ 21 | uniformly sample N examples range from 0 to 360 22 | """ 23 | assert n > 0, "sample number must be nonzero" 24 | eps_rad = eps_degree / 180.0 * math.pi 25 | x_radians = (torch.rand(n, dtype=torch.float32) * 2.0 - 1.0) * eps_rad # -eps, eps 26 | y_radians = (torch.rand(n, dtype=torch.float32) * 2.0 - 1.0) * eps_rad # -eps, eps 27 | z_radians = (torch.arange(n) + 1)/(n + 1) * math.pi * 2 28 | target_euler_radians = torch.stack([x_radians, y_radians, z_radians], dim=-1) 29 | target_rotation_matrix = euler_angles_to_matrix(target_euler_radians, "XYZ") 30 | return target_rotation_matrix 31 | 32 | def uniform_xy_rotation(n, eps_degree=0): 33 | """ 34 | uniformly sample N examples range from 0 to 360 35 | """ 36 | assert n > 0, "sample number must be nonzero" 37 | target_rotation_matrix = random_xyz_rotation(1) @ evenly_distributed_rotation(n) 38 | return target_rotation_matrix 39 | 40 | def random_z_rotation(n, eps_degree=0): 41 | """ 42 | randomly sample N examples range from 0 to 360 43 | """ 44 | eps_rad = eps_degree / 180. 
* math.pi 45 | x_radians = (torch.rand(n, dtype=torch.float32) * 2.0 - 1.0) * eps_rad # -eps, eps 46 | y_radians = (torch.rand(n, dtype=torch.float32) * 2.0 - 1.0) * eps_rad # -eps, eps 47 | z_radians = (torch.rand(n, dtype=torch.float32) * 2.0 - 1.0) * math.pi # -pi, pi 48 | target_euler_radians = torch.stack([x_radians, y_radians, z_radians], dim=-1) 49 | target_euler_matrix = euler_angles_to_matrix(target_euler_radians, "XYZ") 50 | return target_euler_matrix 51 | 52 | def random_xy_rotation(n, eps_degree=0, rang_degree=180): 53 | """ 54 | randomly sample N examples range from 0 to 360 55 | """ 56 | eps_rad = eps_degree / 180. * math.pi 57 | rang_rad = rang_degree / 180 * math.pi 58 | x_radians = (torch.rand(n, dtype=torch.float32) * 2.0 - 1.0) * rang_rad # -pi, pi 59 | y_radians = (torch.rand(n, dtype=torch.float32) * 2.0 - 1.0) * rang_rad # -pi, pi 60 | 61 | z_radians = (torch.rand(n, dtype=torch.float32) * 2.0 - 1.0) * eps_rad # -eps, eps 62 | 63 | target_euler_radians = torch.stack([x_radians, y_radians, z_radians], dim=-1) 64 | target_euler_matrix = euler_angles_to_matrix(target_euler_radians, "XYZ") 65 | return target_euler_matrix 66 | 67 | def random_xyz_rotation(n, eps_degree=180): 68 | """ 69 | randomly sample N examples range from 0 to 360 70 | """ 71 | eps_rad = eps_degree / 180. * math.pi 72 | x_radians = (torch.rand(n, dtype=torch.float32) * 2.0 - 1.0) * eps_rad # -pi, pi 73 | y_radians = (torch.rand(n, dtype=torch.float32) * 2.0 - 1.0) * eps_rad # -pi, pi 74 | z_radians = (torch.rand(n, dtype=torch.float32) * 2.0 - 1.0) * eps_rad # -eps, eps 75 | 76 | target_euler_radians = torch.stack([x_radians, y_radians, z_radians], dim=-1) 77 | target_euler_matrix = euler_angles_to_matrix(target_euler_radians, "XYZ") 78 | return target_euler_matrix 79 | 80 | def evenly_distributed_rotation(n, random_seed=None): 81 | """ 82 | uniformly sample N examples on a sphere 83 | """ 84 | def normalize(vector, dim: int = -1): 85 | return vector / torch.norm(vector, p=2.0, dim=dim, keepdim=True) 86 | 87 | if random_seed is not None: 88 | torch.manual_seed(random_seed) # fix the sampling of viewpoints for reproducing evaluation 89 | 90 | indices = torch.arange(0, n, dtype=torch.float32) + 0.5 91 | 92 | phi = torch.acos(1 - 2 * indices / n) 93 | theta = math.pi * (1 + 5 ** 0.5) * indices 94 | points = torch.stack([ 95 | torch.cos(theta) * torch.sin(phi), 96 | torch.sin(theta) * torch.sin(phi), 97 | torch.cos(phi),], dim=1) 98 | forward = -points 99 | 100 | down = normalize(torch.randn(n, 3), dim=1) 101 | right = normalize(torch.cross(down, forward)) 102 | down = normalize(torch.cross(forward, right)) 103 | R_mat = torch.stack([right, down, forward], dim=1) 104 | return R_mat 105 | 106 | def load_object(path, scale=1.0, size=1.0, recenter=True, resize=True, 107 | bound_type='diameter', load_materials=False) -> meshutils.Object3D: 108 | """ 109 | Loads an object model as an Object3D instance. 110 | 111 | Args: 112 | path: the path to the 3D model 113 | scale: a scaling factor to apply after all transformations 114 | size: the reference 'size' of the object if `resize` is True 115 | recenter: if True the object will be recentered at the centroid 116 | resize: if True the object will be resized to fit insize a cube of size `size` 117 | bound_type: how to compute size for resizing. 
Either 'diameter' or 'extents' 118 | 119 | Returns: 120 | (meshutils.Object3D): the loaded object model 121 | """ 122 | obj = meshutils.Object3D(path, load_materials=load_materials) 123 | 124 | if recenter: 125 | obj.recenter('bounds') 126 | 127 | if resize: 128 | if bound_type == 'diameter': 129 | object_scale = size / obj.bounding_diameter 130 | elif bound_type == 'extents': 131 | object_scale = size / obj.bounding_size 132 | else: 133 | raise ValueError(f"Unkown size_type {bound_type!r}") 134 | 135 | obj.rescale(object_scale) 136 | else: 137 | object_scale = 1.0 138 | 139 | if scale != 1.0: 140 | obj.rescale(scale) 141 | 142 | return obj, obj.bounding_diameter 143 | 144 | def _create_object_node(obj: meshutils.Object3D): 145 | smooth = True 146 | # Turn smooth shading off if vertex normals are unreliable. 147 | if obj.are_normals_corrupt(): 148 | smooth = False 149 | 150 | mesh = pyrender.Mesh.from_trimesh(obj.meshes, smooth=smooth) 151 | node = pyrender.Node(mesh=mesh) 152 | 153 | return node 154 | 155 | 156 | class SceneContext(object): 157 | """ 158 | A wrapper class containing all contextual information needed for rendering. 159 | """ 160 | 161 | def __init__(self, obj, intrinsic: torch.Tensor): 162 | self.obj = obj 163 | self.intrinsic = intrinsic.squeeze() 164 | self.extrinsic = None 165 | self.scene = pyrender.Scene(bg_color=(0, 0, 0, 0), ambient_light=(0.1, 0.1, 0.1)) 166 | 167 | fx = self.intrinsic[0, 0].item() 168 | fy = self.intrinsic[1, 1].item() 169 | cx = self.intrinsic[0, 2].item() 170 | cy = self.intrinsic[1, 2].item() 171 | 172 | self.camera = pyrender.IntrinsicsCamera(fx, fy, cx, cy) 173 | self.camera_node = self.scene.add(self.camera, name='camera') 174 | self.object_node = _create_object_node(self.obj) 175 | 176 | self.scene.add_node(self.object_node) 177 | 178 | def object_to_camera_pose(self, object_pose): 179 | """ 180 | Take an object pose and converts it to a camera pose. 181 | 182 | Takes a matrix that transforms object-space points to camera-space points and converts it 183 | to a matrix that takes OpenGL camera-space points and converts it into object-space points. 184 | """ 185 | CAM_REF_POSE = torch.tensor(( 186 | (1, 0, 0, 0), 187 | (0, -1, 0, 0), 188 | (0, 0, -1, 0), 189 | (0, 0, 0, 1), 190 | ), dtype=torch.float32) 191 | 192 | camera_transform = self.inverse_transform(object_pose) 193 | 194 | # We must flip the z-axis before performing our transformation so that the z-direction is 195 | # pointing in the correct direction when we feed this as OpenGL coordinates. 196 | return CAM_REF_POSE.t()[None, ...] @ camera_transform @ CAM_REF_POSE[None, ...] 197 | 198 | def set_pose(self, translation, rotation): 199 | extrinsic = self.RT_to_matrix(R=rotation, T=translation) 200 | self.extrinsic = extrinsic 201 | camera_pose = self.object_to_camera_pose(extrinsic).squeeze().numpy() 202 | assert len(camera_pose.shape) == 2, 'camera pose for pyrender must be 4 x 4' 203 | self.scene.set_pose(self.camera_node, camera_pose) 204 | 205 | def inverse_transform(self, matrix): 206 | if matrix.dim() == 2: 207 | matrix = matrix[None, ...] 
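        # Note on the math implemented just below (descriptive comment, added for clarity):
        # for a rigid transform T = [R | t; 0 1], the closed-form inverse is
        # T^{-1} = [R^T | -R^T t; 0 1]; the R_inv / t_inv assignments compute exactly
        # the transposed rotation and the rotated-and-negated translation.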
208 | R = matrix[:, :3, :3] # B x 3 x 3 209 | T = matrix[:, :3, 3:4] # B x 3 x 1 210 | R_inv = R.transpose(-2, -1) # B x 3 x 3 211 | t_inv = (R_inv @ T).squeeze(2)# B x 3 212 | 213 | out = torch.zeros_like(matrix) 214 | out[:, :3, :3] = R_inv[:, :3, :3] 215 | out[:, :3, 3] = -t_inv 216 | out[:, 3, 3] = 1 217 | return out 218 | 219 | def RT_to_matrix(self, R, T): 220 | if R.shape[-1] == 3: 221 | R = F.pad(R, (0, 1, 0, 1)) # 4 x 4 222 | if R.dim() == 2: 223 | R = R[None, ...] 224 | if T.dim() == 1: 225 | T = T[None, ...] 226 | R[:, :3, 3] = T 227 | R[:, -1, -1] = 1.0 228 | return R 229 | 230 | 231 | class Renderer(object): 232 | """ 233 | A thin wrapper around the PyRender renderer. 234 | """ 235 | def __init__(self, width, height): 236 | self._renderer = pyrender.OffscreenRenderer(width, height) 237 | self._render_flags = RenderFlags.SKIP_CULL_FACES | RenderFlags.RGBA 238 | 239 | @property 240 | def width(self): 241 | return self._renderer.viewport_width 242 | 243 | @property 244 | def height(self): 245 | return self._renderer.viewport_height 246 | 247 | def __del__(self): 248 | self._renderer.delete() 249 | 250 | def render(self, context): 251 | color, depth = self._renderer.render(context.scene, flags=self._render_flags) 252 | color = color.copy().astype(np.float32) / 255.0 253 | color = torch.tensor(color) 254 | depth = torch.tensor(depth) 255 | # mask = color[..., 3] 256 | mask = (depth > 0).float() 257 | color = color[..., :3] 258 | return color, depth, mask 259 | 260 | 261 | def rendering_views(obj_mesh, intrinsic, R, T, height=540, width=720): 262 | obj_scene = SceneContext(obj=obj_mesh, intrinsic=intrinsic) # define a scene 263 | obj_renderer = Renderer(width=width, height=height) # define a renderer 264 | obj_depths = list() 265 | obj_masks = list() 266 | if R.dim() == 2: 267 | R = R[None, ...] 268 | if T.dim() == 1: 269 | T = T[None, ...] 
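    # Usage sketch for rendering_views (illustrative only; the intrinsics, mesh path and
    # view count below are assumptions, not values taken from this repository):
    #   K = torch.tensor([[572.4, 0.0, 325.3], [0.0, 573.6, 242.0], [0.0, 0.0, 1.0]])
    #   obj, _ = load_object('path/to/obj_000001.ply', resize=False, recenter=False)
    #   views_R = evenly_distributed_rotation(n=8)             # 8 x 3 x 3 viewpoint rotations
    #   views_T = torch.tensor([0.0, 0.0, 0.6]).repeat(8, 1)   # 8 x 3 translations in metres
    #   depths, masks = rendering_views(obj, K, views_R, views_T)  # each 8 x 1 x H x W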
270 | for anc_R, anc_T in zip(R, T): 271 | obj_scene.set_pose(rotation=anc_R, translation=anc_T) 272 | color, depth, mask = obj_renderer.render(obj_scene) 273 | obj_depths.append(depth) 274 | obj_masks.append(mask) 275 | del obj_scene 276 | obj_depths = torch.stack(obj_depths, dim=0).unsqueeze(1) 277 | obj_masks = torch.stack(obj_masks, dim=0).unsqueeze(1) 278 | return obj_depths, obj_masks 279 | 280 | def render_uniform_sampling_views(model_path, intrinsic, scale=1.0, num_views=1000, dist=0.8, height=540, width=720): 281 | obj, obj_scale = load_object(model_path, resize=False, recenter=False) 282 | obj.rescale(scale=scale) # normalize from millimeter to meter scale 283 | obj_scene = SceneContext(obj=obj, intrinsic=intrinsic) # define a scene 284 | obj_renderer = Renderer(width=width, height=height) # define a renderer 285 | 286 | obj_R = evenly_distributed_rotation(n=num_views) # uniformly sampled rotation views from a sphere, N x 3 x 3 287 | obj_T = torch.zeros_like(obj_R[:, :, 0]) # constant distance, N x 3 288 | obj_T[:, -1] = dist 289 | 290 | obj_diameter = (((obj.vertices.max(0) - obj.vertices.min(0))**2).sum())**0.5 291 | obj_T = obj_T * obj_diameter # scaling according to specific object size 292 | 293 | obj_depths = list() 294 | obj_masks = list() 295 | 296 | for anc_R, anc_T in zip(obj_R, obj_T): 297 | obj_scene.set_pose(rotation=anc_R, translation=anc_T) 298 | color, depth, mask = obj_renderer.render(obj_scene) 299 | obj_depths.append(depth) 300 | obj_masks.append(mask) 301 | obj_depths = torch.stack(obj_depths, dim=0).unsqueeze(1) 302 | obj_masks = torch.stack(obj_masks, dim=0).unsqueeze(1) 303 | del obj_scene 304 | # del obj_renderer 305 | return obj_depths, obj_masks, obj_R, obj_T 306 | 307 | def render_RT_views(model_path, intrinsic, R, T, scale=1.0, height=540, width=720): 308 | obj_mesh, obj_scale = load_object(model_path, resize=False, recenter=False) 309 | obj_mesh.rescale(scale=scale) # normalize from millimeter to meter scale 310 | obj_scene = SceneContext(obj=obj_mesh, intrinsic=intrinsic) # define a scene 311 | obj_renderer = Renderer(width=width, height=height) # define a renderer 312 | obj_depths = list() 313 | obj_masks = list() 314 | if R.dim() == 2: 315 | R = R[None, ...] 316 | T = T[None, ...]
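    # Descriptive note: after the (optional) unsqueeze above, R is expected as N x 3 x 3 and
    # T as N x 3; each (R[i], T[i]) pair is rendered into one depth map and one binary mask below.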
317 | for anc_R, anc_T in zip(R, T): 318 | obj_scene.set_pose(rotation=anc_R, translation=anc_T) 319 | color, depth, mask = obj_renderer.render(obj_scene) 320 | obj_depths.append(depth) 321 | obj_masks.append(mask) 322 | del obj_scene 323 | # del obj_renderer 324 | obj_depths = torch.stack(obj_depths, dim=0).unsqueeze(1) 325 | obj_masks = torch.stack(obj_masks, dim=0).unsqueeze(1) 326 | return obj_depths, obj_masks 327 | 328 | def render_single_view(model_path, intrinsic, R, T, scale=1.0, height=540, width=720): 329 | assert R.dim() == 2 and T.dim() == 1, "unexpected pyrender R and T shapes: {}, {}".format(R.shape, T.shape) 330 | obj, obj_scale = load_object(model_path, resize=False, recenter=False) 331 | obj.rescale(scale=scale) # normalize from millimeter to meter scale 332 | obj_scene = SceneContext(obj=obj, intrinsic=intrinsic) # define a scene 333 | obj_renderer = Renderer(width=width, height=height) # define a renderer 334 | obj_scene.set_pose(rotation=R, translation=T) 335 | color, depth, mask = obj_renderer.render(obj_scene) 336 | del obj_scene 337 | # del obj_renderer 338 | return depth, mask 339 | 340 | def render_sampling_pair_views(mesh_file, intrinsic, num_views=1000, dist=0.8, height=540, width=720, dist_jitter=0.2): 341 | obj_trimesh = trimesh.load(mesh_file) 342 | obj_trimesh.vertices = obj_trimesh.vertices / 1000.0 343 | # obj_trimesh.vertices = obj_trimesh.vertices - obj_trimesh.vertices.mean(0) 344 | 345 | obj_mesh = pyrender.Mesh.from_trimesh(obj_trimesh) 346 | 347 | obj_scene = SceneContext(obj=obj_mesh, intrinsic=intrinsic) 348 | obj_renderer = Renderer(width=width, height=height) 349 | 350 | Rxy = random_xy_rotation(num_views, eps_degree=2) 351 | Rz = random_z_rotation(num_views, eps_degree=2) 352 | camera_T = torch.tensor([0.0, 0.0, dist], dtype=torch.float32).repeat(num_views, 1) 353 | camera_T = camera_T + (torch.rand_like(camera_T) - 0.5) * dist_jitter 354 | 355 | diameter = (((obj_trimesh.vertices.max(0)[0] - obj_trimesh.vertices.min(0)[0])**2).sum())**0.5 356 | camera_T = camera_T.clone() * diameter 357 | 358 | obj_Rxyz_depths = list() 359 | obj_Rxyz_masks = list() 360 | obj_Rxy_depths = list() 361 | obj_Rxy_masks = list() 362 | 363 | for anc_R, anc_T in zip(Rxy, camera_T): 364 | obj_scene.set_pose(rotation=anc_R, translation=anc_T) 365 | color, depth, mask = obj_renderer.render(obj_scene) 366 | obj_Rxy_depths.append(depth) 367 | obj_Rxy_masks.append(mask) 368 | 369 | obj_Rxy_depths = torch.stack(obj_Rxy_depths, dim=0) 370 | obj_Rxy_masks = torch.stack(obj_Rxy_masks, dim=0) 371 | 372 | Rxyz = Rz @ Rxy 373 | for anc_R, anc_T in zip(Rxyz, camera_T): 374 | obj_scene.set_pose(rotation=anc_R, translation=anc_T) 375 | color, depth, mask = obj_renderer.render(obj_scene) 376 | obj_Rxyz_depths.append(depth) 377 | obj_Rxyz_masks.append(mask) 378 | 379 | obj_Rxyz_depths = torch.stack(obj_Rxyz_depths, dim=0) 380 | obj_Rxyz_masks = torch.stack(obj_Rxyz_masks, dim=0) 381 | 382 | 383 | # del obj_renderer 384 | return obj_Rxy_depths, obj_Rxyz_depths, obj_Rxy_masks, obj_Rxyz_masks, Rxy, Rxyz, camera_T, Rz -------------------------------------------------------------------------------- /lib/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from lib import geometry 4 | 5 | def background_filter(depths, diameters, dist_factor=0.5): 6 | """ 7 | filter out the outliers beyond the object diameter 8 | """ 9 | new_depths = list() 10 | unsqueeze = False 11 | if not isinstance(diameters, torch.Tensor): 12 | diameters =
torch.tensor(diameters) 13 | if diameters.dim() == 0: 14 | diameters = diameters[None, ...] 15 | if depths.dim() == 2: 16 | depths = depths[None, ...] 17 | if depths.dim() > 3: 18 | depths = depths.view(-1, depths.shape[-2], depths.shape[-1]) 19 | diameters = diameters.view(-1) 20 | unsqueeze = True 21 | assert len(depths) == len(diameters) 22 | for ix, dep in enumerate(depths): 23 | hei, wid = dep.shape 24 | diameter = diameters[ix] 25 | if (dep>0).sum() < 10: 26 | new_depths.append(dep) 27 | continue 28 | 29 | dep_vec = dep.view(-1) 30 | dep_val = dep_vec[dep_vec>0].clone() 31 | med_val = dep_val.median() 32 | 33 | dep_dist = (dep_val - med_val).abs() 34 | dist, indx = torch.topk(dep_dist, k=len(dep_dist)) 35 | invalid_idx = indx[dist > dist_factor * diameter] 36 | dep_val[invalid_idx] = 0 37 | dep_vec[dep_vec>0] = dep_val 38 | new_dep = dep_vec.view(hei, wid) 39 | if (new_dep>0).sum() < 100: # too few valid depth values remain, so keep the original depth map 40 | new_depths.append(dep) 41 | else: 42 | new_depths.append(new_dep) 43 | 44 | new_depths = torch.stack(new_depths, dim=0).to(depths.device) 45 | if unsqueeze: 46 | new_depths = new_depths.unsqueeze(1) 47 | return new_depths 48 | 49 | def convert_3Dcoord_to_2Dpixel(obj_t, intrinsic): 50 | """ 51 | convert the 3D space coordinates (dx, dy, dz) to 2D pixel coordinates (px, py, dz) 52 | """ 53 | obj_t = obj_t.squeeze() 54 | K = intrinsic.squeeze().to(obj_t.device) 55 | 56 | assert(obj_t.dim() <= 2), 'the input dimension must be 3 or Nx3' 57 | assert(K.dim() <= 3), 'the input dimension must be 3x3 or Nx3x3' 58 | 59 | if obj_t.dim() == 1: 60 | obj_t = obj_t[None, ...] 61 | if K.dim() == 2: 62 | K = K.unsqueeze(0).expand(obj_t.size(0), -1, -1) # broadcast the single 3x3 intrinsic over the batch 63 | 64 | assert obj_t.size(0) == K.size(0), 'batch size must be equal' 65 | dz = obj_t[:, 2] 66 | px = obj_t[:, 0] / dz * K[:, 0, 0] + K[:, 0, 2] 67 | py = obj_t[:, 1] / dz * K[:, 1, 1] + K[:, 1, 2] 68 | new_t = torch.stack([px, py, dz], dim=1) 69 | return new_t 70 | 71 | def input_zoom_preprocess(images, target_dist, intrinsic, extrinsic=None, 72 | images_mask=None, normalize=True, dz=None, 73 | target_size=128, scale_mode='nearest'): 74 | device = images.device 75 | intrinsic = intrinsic.to(device) 76 | height, width = images.shape[-2:] 77 | 78 | assert(images.dim()==3 or images.dim()==4) 79 | if images.dim() == 3: 80 | images = images[None, ...] 81 | 82 | if images_mask is None: 83 | images_mask = torch.zeros_like(images) 84 | images_mask[images>0] = 1.0 85 | 86 | images_mask = images_mask.to(device) 87 | 88 | assert(images_mask.dim()==3 or images_mask.dim()==4) 89 | if images_mask.dim() == 3: 90 | images_mask = images_mask[None, ...]
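    # Descriptive note on the zoom step that follows: the object region is cropped and rescaled
    # so that it appears at a roughly canonical size. The crop scale is proportional to
    # target_dist / z, where z is the estimated (or given) object distance, and with
    # normalize=True the depth values inside the mask are shifted by -z, so the network sees
    # distance-normalized depth patches of size target_size x target_size.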
91 | 92 | if not isinstance(target_dist, torch.Tensor): 93 | target_dist = torch.tensor(target_dist) 94 | 95 | target_dist = target_dist.to(device) 96 | 97 | if extrinsic is None: 98 | obj_translations = torch.stack(geometry.estimate_translation(depth=images, 99 | mask=images_mask, 100 | intrinsic=intrinsic), dim=1).to(device) 101 | if dz is not None: 102 | obj_translations[:, 2] = dz.to(device) 103 | else: 104 | extrinsic = extrinsic.to(device) 105 | obj_translations = extrinsic[:, :3, 3] # N x 3 106 | 107 | obj_zs = obj_translations[:, 2] 108 | 109 | if normalize: 110 | images -= images_mask * obj_zs[..., None, None, None].to(device) 111 | 112 | if extrinsic is None: 113 | cameras = geometry.Camera(intrinsic=intrinsic, height=height, width=width) 114 | obj_centroids = geometry.masks_to_centroids(images_mask) 115 | zoom_images, zoom_camera = cameras.zoom(image=images, 116 | target_dist=target_dist, 117 | target_size=target_size, 118 | zs=obj_zs, 119 | centroid_uvs=obj_centroids, 120 | scale_mode=scale_mode) 121 | # zoom_masks, _ = cameras.zoom(image=images_mask, 122 | # target_dist=target_dist, 123 | # target_size=target_size, 124 | # zs=obj_zs, 125 | # centroid_uvs=obj_centroids, 126 | # scale_mode=scale_mode) 127 | else: 128 | cameras = geometry.Camera(intrinsic=intrinsic, extrinsic=extrinsic, width=width, height=height) 129 | zoom_images, zoom_camera = cameras.zoom(images, 130 | target_dist=target_dist, 131 | target_size=target_size, 132 | scale_mode=scale_mode) 133 | # zoom_masks, _ = cameras.zoom(images_mask, 134 | # target_dist=target_dist, 135 | # target_size=target_size, 136 | # scale_mode=scale_mode) 137 | return zoom_images, zoom_camera, obj_translations 138 | 139 | 140 | def inplane_residual_theta(gt_t, init_t, gt_Rz, config, target_dist, device): 141 | """ 142 | gt_t(Nx3): the ground truth translation 143 | est_t(Nx3: the initial translation (directly estimated from depth) 144 | gt_Rz(Nx3x3): the ground truth relative in-plane rotation along camera optical axis 145 | 146 | return: the relative transformation between the anchor image and the query image 147 | 148 | """ 149 | W = config.RENDER_WIDTH 150 | H = config.RENDER_HEIGHT 151 | fx = config.INTRINSIC[0, 0] 152 | fy = config.INTRINSIC[1, 1] 153 | cx = config.INTRINSIC[0, 2] 154 | cy = config.INTRINSIC[1, 2] 155 | 156 | gt_t = gt_t.clone().to(device) # Nx3 157 | init_t = init_t.clone().to(device) # Nx3 158 | Rz_rot = gt_Rz[:, :2, :2].clone().to(device) # Nx2x2 159 | 160 | gt_tx = gt_t[:, 0:1] 161 | gt_ty = gt_t[:, 1:2] 162 | gt_tz = gt_t[:, 2:3] 163 | 164 | init_tx = init_t[:, 0:1] 165 | init_ty = init_t[:, 1:2] 166 | init_tz = init_t[:, 2:3] 167 | 168 | if not isinstance(target_dist, torch.Tensor): 169 | target_dist = torch.tensor(target_dist) 170 | if target_dist.dim() == 1: 171 | target_dist = target_dist[..., None] # Nx1 172 | if target_dist.dim() != 0: 173 | assert(target_dist.dim() == init_tz.dim()), "shape must be same, however, {}, {}".format(target_dist.shape, init_tz.shape) 174 | 175 | init_scale = target_dist.to(device) / init_tz # Nx1 / config.ZOOM_CROP_SIZE 176 | 177 | gt_t[:, 0:1] = (gt_tx / gt_tz * fx + cx) / W # Nx1 * gt_scale # projection to 2D image plane 178 | gt_t[:, 1:2] = (gt_ty / gt_tz * fy + cy) / H # Nx1 * gt_scale 179 | 180 | init_t[:, 0:1] = (init_tx / init_tz * fx + cx) / W # Nx1 * init_scale 181 | init_t[:, 1:2] = (init_ty / init_tz * fy + cy) / H # Nx1 * init_scale 182 | 183 | offset_t = gt_t - init_t # N x 3 [dx, dy, dz] unit with (pixel, pixel, meter) 184 | offset_t[:, :2] = offset_t[:, :2] * 
init_scale 185 | 186 | res_T = torch.zeros((gt_t.size(0), 3, 3), device=device) # Nx3x3 187 | res_T[:, :2, :2] = Rz_rot 188 | res_T[:, :3, 2] = offset_t 189 | 190 | return res_T 191 | 192 | 193 | def spatial_transform_2D(x, theta, mode='nearest', padding_mode='border', align_corners=False): 194 | assert(x.dim()==3 or x.dim()==4) 195 | assert(theta.dim()==2 or theta.dim()==3) 196 | assert(theta.shape[-2]==2 and theta.shape[-1]==3), "theta must be Nx2x3" 197 | if x.dim() == 3: 198 | x = x[None, ...] 199 | if theta.dim() == 2: 200 | theta = theta[None, ...].repeat(x.size(0), 1, 1) 201 | 202 | stn_theta = theta.clone() 203 | stn_theta[:, :2, :2] = theta[:, :2, :2].transpose(-1, -2) 204 | stn_theta[:, :2, 2:3] = -(stn_theta[:, :2, :2] @ stn_theta[:, :2, 2:3]) 205 | 206 | grid = F.affine_grid(stn_theta.to(x.device), x.shape, align_corners=align_corners) 207 | new_x = F.grid_sample(x.type(grid.dtype), grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners) 208 | return new_x 209 | 210 | def recover_full_translation(init_t, offset_t, config, target_dist, device): 211 | W = config.RENDER_WIDTH 212 | H = config.RENDER_HEIGHT 213 | fx = config.INTRINSIC[0, 0] 214 | fy = config.INTRINSIC[1, 1] 215 | 216 | dx = offset_t[:, 0:1].to(device) # Bx1 217 | dy = offset_t[:, 1:2].to(device) # Bx1 218 | dz = offset_t[:, 2:3].to(device) # Bx1 219 | 220 | init_tx = init_t[:, 0:1].to(device) # Bx1 221 | init_ty = init_t[:, 1:2].to(device) # Bx1 222 | init_tz = init_t[:, 2:3].to(device) # Bx1 223 | 224 | if not isinstance(target_dist, torch.Tensor): 225 | target_dist = torch.tensor(target_dist) 226 | if target_dist.dim() == 1: 227 | target_dist = target_dist[..., None] # Nx1 228 | if target_dist.dim() != 0: 229 | assert(target_dist.dim() == init_tz.dim()), "shape must be same, however, {}, {}".format(target_dist.shape, init_tz.shape) 230 | 231 | init_scale = target_dist.to(device) / init_tz #/ config.ZOOM_CROP_SIZE 232 | 233 | est_tz = init_tz + dz.to(device) 234 | est_tx = est_tz * (W / init_scale / fx * dx + init_tx/init_tz) # Nx1 235 | est_ty = est_tz * (H / init_scale / fy * dy + init_ty/init_tz) 236 | 237 | # print(est_tx.shape, est_ty.shape, est_tz.shape) 238 | 239 | est_full_t = torch.cat([est_tx, est_ty, est_tz], dim=1) # Nx3 240 | 241 | return est_full_t 242 | 243 | 244 | def residual_inplane_transform(gt_t, init_t, gt_Rz, config, target_dist, device): 245 | """ 246 | gt_t(Nx3): the ground truth translation 247 | est_t(Nx3: the initial translation (directly estimated from depth) 248 | gt_Rz(Nx3x3): the ground truth relative in-plane rotation along camera optical axis 249 | return: the relative transformation between the anchor image and the query image 250 | """ 251 | # W = config.RENDER_WIDTH 252 | # H = config.RENDER_HEIGHT 253 | fx = config.INTRINSIC[0, 0] 254 | fy = config.INTRINSIC[1, 1] 255 | cx = config.INTRINSIC[0, 2] 256 | cy = config.INTRINSIC[1, 2] 257 | 258 | gt_t = gt_t.clone().to(device) # Nx3 259 | init_t = init_t.clone().to(device) # Nx3 260 | Rz_rot = gt_Rz[:, :2, :2].clone().to(device) # Nx2x2 261 | 262 | gt_tx = gt_t[:, 0:1] 263 | gt_ty = gt_t[:, 1:2] 264 | gt_tz = gt_t[:, 2:3] 265 | 266 | init_tx = init_t[:, 0:1] 267 | init_ty = init_t[:, 1:2] 268 | init_tz = init_t[:, 2:3] 269 | 270 | if not isinstance(target_dist, torch.Tensor): 271 | target_dist = torch.tensor(target_dist) 272 | if target_dist.dim() == 1: 273 | target_dist = target_dist[..., None] # Nx1 274 | if target_dist.dim() != 0: 275 | assert(target_dist.dim() == init_tz.dim()), "shape must be same, however, 
{}, {}".format(target_dist.shape, init_tz.shape) 276 | target_dist = target_dist.to(device) 277 | 278 | tz_offset_frac = gt_tz / init_tz # gt_tz = tz_factor * init_tz, the ratio bwteen the ground truth distance and initial distance 279 | 280 | gt_t[:, 0:1] = (gt_tx / gt_tz * fx + cx) # Nx1, pixel coordinate projected on 2D image plane 281 | gt_t[:, 1:2] = (gt_ty / gt_tz * fy + cy) # Nx1 282 | 283 | init_t[:, 0:1] = (init_tx / init_tz * fx + cx) # Nx1 284 | init_t[:, 1:2] = (init_ty / init_tz * fy + cy) # Nx1 285 | 286 | gt_crop_scaling = target_dist / gt_tz # the scaling factor for the cropped object patch 287 | # init_crop_scaling = target_dist / init_tz 288 | 289 | gt_bbox_size = gt_crop_scaling * config.ZOOM_SIZE # the bbox size of the cropped object with gt distance 290 | # init_bbox_size = gt_bbox_size * tz_offset_frac 291 | 292 | delta_px = gt_tx - init_tx # from source image center to target image center 293 | delta_py = gt_ty - init_ty # from source image center to target image center 294 | 295 | px_offset_frac = delta_px / gt_bbox_size # convert the offset relative to the target image size 296 | py_offset_frac = delta_py / gt_bbox_size # convert the offset relative to the target image size 297 | 298 | offset_t = torch.cat([px_offset_frac, py_offset_frac, tz_offset_frac], dim=1) 299 | 300 | res_T = torch.zeros((gt_t.size(0), 3, 3), device=device) # Nx3x3 301 | res_T[:, :2, :2] = Rz_rot 302 | res_T[:, :3, 2] = offset_t 303 | 304 | return res_T 305 | 306 | 307 | def recover_residual_translation(init_t, offset_t, config, target_dist, device): 308 | # W = config.RENDER_WIDTH 309 | # H = config.RENDER_HEIGHT 310 | fx = config.INTRINSIC[0, 0] 311 | fy = config.INTRINSIC[1, 1] 312 | cx = config.INTRINSIC[0, 2] 313 | cy = config.INTRINSIC[1, 2] 314 | 315 | init_t = init_t.clone().to(device) # Nx3 316 | offset_t = offset_t.clone().to(device) # Nx3 317 | 318 | init_tx = init_t[:, 0:1] # Bx1 319 | init_ty = init_t[:, 1:2] # Bx1 320 | init_tz = init_t[:, 2:3] # Bx1 321 | 322 | px_offset_frac = offset_t[:, 0:1] # Bx1 323 | py_offset_frac = offset_t[:, 1:2] # Bx1 324 | tz_offset_frac = offset_t[:, 2:3] # Bx1 325 | 326 | init_t[:, 0:1] = (init_tx / init_tz * fx + cx) # Nx1 * init_scale 327 | init_t[:, 1:2] = (init_ty / init_tz * fy + cy) # Nx1 * init_scale 328 | 329 | if not isinstance(target_dist, torch.Tensor): 330 | target_dist = torch.tensor(target_dist) 331 | if target_dist.dim() == 1: 332 | target_dist = target_dist[..., None] # Nx1 333 | if target_dist.dim() != 0: 334 | assert(target_dist.dim() == init_tz.dim()), "shape must be same, however, {}, {}".format(target_dist.shape, init_tz.shape) 335 | target_dist = target_dist.to(device) 336 | 337 | init_crop_scaling = target_dist / init_tz 338 | init_bbox_size = init_crop_scaling * config.ZOOM_SIZE 339 | pd_bbox_size = init_bbox_size / tz_offset_frac 340 | 341 | pd_delta_px = px_offset_frac * pd_bbox_size 342 | pd_delta_py = py_offset_frac * pd_bbox_size 343 | 344 | pd_px = init_t[:, 0:1] + pd_delta_px 345 | pd_py = init_t[:, 1:2] + pd_delta_py 346 | 347 | est_tz = tz_offset_frac * init_tz 348 | 349 | # est_tz = init_tz + tz_offset_frac * init_tz 350 | 351 | 352 | est_tx = (pd_px - cx) / fx * est_tz 353 | est_ty = (pd_py - cy) / fy * est_tz 354 | 355 | est_full_t = torch.cat([est_tx, est_ty, est_tz], dim=1) # Nx3 356 | 357 | return est_full_t 358 | 359 | 360 | 361 | def residual_inplane_transform3(gt_t, init_t, gt_Rz, config, target_dist, device): 362 | """ 363 | gt_t(Nx3): the ground truth translation 364 | est_t(Nx3: the initial translation 
(directly estimated from depth) 365 | gt_Rz(Nx3x3): the ground truth relative in-plane rotation along camera optical axis 366 | return: the relative transformation between the anchor image and the query image 367 | """ 368 | # W = config.RENDER_WIDTH 369 | # H = config.RENDER_HEIGHT 370 | fx = config.INTRINSIC[0, 0] 371 | fy = config.INTRINSIC[1, 1] 372 | cx = config.INTRINSIC[0, 2] 373 | cy = config.INTRINSIC[1, 2] 374 | 375 | gt_t = gt_t.clone().to(device) # Nx3 376 | init_t = init_t.clone().to(device) # Nx3 377 | Rz_rot = gt_Rz[:, :2, :2].clone().to(device) # Nx2x2 378 | 379 | gt_tx = gt_t[:, 0:1] 380 | gt_ty = gt_t[:, 1:2] 381 | gt_tz = gt_t[:, 2:3] 382 | 383 | init_tx = init_t[:, 0:1] 384 | init_ty = init_t[:, 1:2] 385 | init_tz = init_t[:, 2:3] 386 | 387 | # tz_offset_frac = (gt_tz - init_tz) / init_tz # gt_tz = init_tz + tz_offset_frac * init_tz 388 | 389 | tz_offset_frac = (gt_tz - init_tz)# / init_tz # gt_tz = init_tz + tz_offset_frac * init_tz 390 | 391 | if not isinstance(target_dist, torch.Tensor): 392 | target_dist = torch.tensor(target_dist) 393 | if target_dist.dim() == 1: 394 | target_dist = target_dist[..., None] # Nx1 395 | if target_dist.dim() != 0: 396 | assert(target_dist.dim() == init_tz.dim()), "shape must be same, however, {}, {}".format(target_dist.shape, init_tz.shape) 397 | target_dist = target_dist.to(device) 398 | 399 | # object GT 2D center in image 400 | gt_px = (gt_tx / gt_tz * fx + cx) # Nx1, pixel x-coordinate of the object gt_center 401 | gt_py = (gt_ty / gt_tz * fy + cy) # Nx1, pixel y-coordinate of the object gt_center 402 | 403 | # object initial 2D center in image 404 | init_px = (init_tx / init_tz * fx + cx) # Nx1 405 | init_py = (init_ty / init_tz * fy + cy) # Nx1 406 | 407 | offset_px = gt_px - init_px # from source image center to target image center 408 | offset_py = gt_py - init_py # from source image center to target image center 409 | 410 | # gt_box_size = 1.0 * target_dist / gt_tz * config.ZOOM_SIZE # cropped patch size with the gt depth 411 | init_box_size = 1.0 * target_dist / init_tz * config.ZOOM_SIZE # cropped patch size with the estimated depth 412 | 413 | px_offset_frac = offset_px / (init_box_size / 2.0) 414 | py_offset_frac = offset_py / (init_box_size / 2.0) 415 | 416 | offset_t = torch.cat([px_offset_frac, py_offset_frac, tz_offset_frac], dim=1) 417 | 418 | res_T = torch.zeros((gt_t.size(0), 3, 3), device=device) # Nx3x3 419 | res_T[:, :2, :2] = Rz_rot 420 | res_T[:, :3, 2] = offset_t 421 | 422 | return res_T 423 | 424 | 425 | def recover_residual_translation3(init_t, offset_t, config, target_dist, device): 426 | # W = config.RENDER_WIDTH 427 | # H = config.RENDER_HEIGHT 428 | fx = config.INTRINSIC[0, 0] 429 | fy = config.INTRINSIC[1, 1] 430 | cx = config.INTRINSIC[0, 2] 431 | cy = config.INTRINSIC[1, 2] 432 | 433 | init_t = init_t.clone().to(device) # Nx3 434 | offset_t = offset_t.clone().to(device) # Nx3 435 | 436 | init_tx = init_t[:, 0:1] # Bx1 437 | init_ty = init_t[:, 1:2] # Bx1 438 | init_tz = init_t[:, 2:3] # Bx1 439 | 440 | px_offset_frac = offset_t[:, 0:1] # Bx1 441 | py_offset_frac = offset_t[:, 1:2] # Bx1 442 | tz_offset_frac = offset_t[:, 2:3] # Bx1 443 | 444 | init_px = (init_tx / init_tz * fx + cx) # Nx1 * init_scale 445 | init_py = (init_ty / init_tz * fy + cy) # Nx1 * init_scale 446 | 447 | if not isinstance(target_dist, torch.Tensor): 448 | target_dist = torch.tensor(target_dist) 449 | if target_dist.dim() == 1: 450 | target_dist = target_dist[..., None] # Nx1 451 | if target_dist.dim() != 0: 452 | 
assert(target_dist.dim() == init_tz.dim()), "shape must be same, however, {}, {}".format(target_dist.shape, init_tz.shape) 453 | target_dist = target_dist.to(device) 454 | 455 | init_box_size = 1.0 * target_dist / init_tz * config.ZOOM_SIZE # cropped patch size with the estimated depth 456 | 457 | est_px = init_px + px_offset_frac / 2.0 * init_box_size 458 | est_py = init_py + py_offset_frac / 2.0 * init_box_size 459 | 460 | est_tz = init_tz + tz_offset_frac # * init_tz 461 | est_tx = (est_px - cx) / fx * est_tz 462 | est_ty = (est_py - cy) / fy * est_tz 463 | 464 | est_full_t = torch.cat([est_tx, est_ty, est_tz], dim=1) # Nx3 465 | 466 | return est_full_t 467 | 468 | 469 | 470 | 471 | -------------------------------------------------------------------------------- /training/train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn.functional as F 4 | from lib import geometry 5 | 6 | def background_filter(depths, diameters, dist_factor=0.5): 7 | """ 8 | filter out the outilers beyond the object diameter 9 | """ 10 | new_depths = list() 11 | unsqueeze = False 12 | if not isinstance(diameters, torch.Tensor): 13 | diameters = torch.tensor(diameters) 14 | if diameters.dim() == 0: 15 | diameters = diameters[None, ...] 16 | if depths.dim() == 2: 17 | depths = depths[None, ...] 18 | if depths.dim() > 3: 19 | depths = depths.view(-1, depths.shape[-2], depths.shape[-1]) 20 | diameters = diameters.view(-1) 21 | unsqueeze = True 22 | assert len(depths) == len(diameters) 23 | for ix, dep in enumerate(depths): 24 | hei, wid = dep.shape 25 | diameter = diameters[ix] 26 | if (dep>0).sum() < 10: 27 | new_depths.append(dep) 28 | continue 29 | 30 | dep_vec = dep.view(-1) 31 | dep_val = dep_vec[dep_vec>0].clone() 32 | med_val = dep_val.median() 33 | 34 | dep_dist = (dep_val - med_val).abs() 35 | dist, indx = torch.topk(dep_dist, k=len(dep_dist)) 36 | invalid_idx = indx[dist > dist_factor * diameter] 37 | dep_val[invalid_idx] = 0 38 | dep_vec[dep_vec>0] = dep_val 39 | new_dep = dep_vec.view(hei, wid) 40 | if (new_dep>0).sum() < 100: # the number of valid depth values is too small, then return old one 41 | new_depths.append(dep) 42 | else: 43 | new_depths.append(new_dep) 44 | 45 | new_depths = torch.stack(new_depths, dim=0).to(depths.device) 46 | if unsqueeze: 47 | new_depths = new_depths.unsqueeze(1) 48 | return new_depths 49 | 50 | 51 | def convert_3Dcoord_to_2Dpixel(obj_t, intrinsic): 52 | """ 53 | convert the 3D space coordinates (dx, dy, dz) to 2D pixel coordinates (px, py, dz) 54 | """ 55 | obj_t = obj_t.squeeze() 56 | K = intrinsic.squeeze().to(obj_t.device) 57 | 58 | assert(obj_t.dim() <= 2), 'the input dimension must be 3 or Nx3' 59 | assert(K.dim() <= 3), 'the input dimension must be 3x3 or Nx3x3' 60 | 61 | if obj_t.dim() == 1: 62 | obj_t = obj_t[None, ...] 
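    # Pinhole projection used below: px = fx * dx / dz + cx, py = fy * dy / dz + cy, with the
    # metric depth dz passed through unchanged. For example, with fx = fy = 500, cx = 320,
    # cy = 240 and a point (0.1, 0.0, 0.5) m, this gives (px, py, dz) = (420, 240, 0.5).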
63 | if K.dim() == 2: 64 | K = K.unsqueeze(0).expand(obj_t.size(0), 1, 1) 65 | 66 | assert obj_t.size(0) == K.size(0), 'batch size must be equal' 67 | dz = obj_t[:, 2] 68 | px = obj_t[:, 0] / dz * K[:, 0, 0] + K[:, 0, 2] 69 | py = obj_t[:, 1] / dz * K[:, 1, 1] + K[:, 1, 2] 70 | new_t = torch.stack([px, py, dz], dim=1) 71 | return new_t 72 | 73 | 74 | def input_zoom_preprocess(images, target_dist, intrinsic, extrinsic=None, 75 | images_mask=None, normalize=True, dz=None, 76 | target_size=128, scale_mode='nearest'): 77 | device = images.device 78 | intrinsic = intrinsic.to(device) 79 | height, width = images.shape[-2:] 80 | 81 | assert(images.dim()==3 or images.dim()==4) 82 | if images.dim() == 3: 83 | images = images[None, ...] 84 | 85 | if images_mask is None: 86 | images_mask = torch.zeros_like(images) 87 | images_mask[images>0] = 1.0 88 | 89 | images_mask = images_mask.to(device) 90 | 91 | assert(images_mask.dim()==3 or images_mask.dim()==4) 92 | if images_mask.dim() == 3: 93 | images_mask = images_mask[None, ...] 94 | 95 | if not isinstance(target_dist, torch.Tensor): 96 | target_dist = torch.tensor(target_dist) 97 | 98 | target_dist = target_dist.to(device) 99 | 100 | if extrinsic is None: 101 | obj_translations = torch.stack(geometry.estimate_translation(depth=images, 102 | mask=images_mask, 103 | intrinsic=intrinsic), dim=1).to(device) 104 | if dz is not None: 105 | obj_translations[:, 2] = dz.to(device) 106 | else: 107 | extrinsic = extrinsic.to(device) 108 | obj_translations = extrinsic[:, :3, 3] # N x 3 109 | 110 | obj_zs = obj_translations[:, 2] 111 | 112 | if normalize: 113 | images -= images_mask * obj_zs[..., None, None, None].to(device) 114 | 115 | if extrinsic is None: 116 | cameras = geometry.Camera(intrinsic=intrinsic, height=height, width=width) 117 | obj_centroids = geometry.masks_to_centroids(images_mask) 118 | zoom_images, zoom_camera = cameras.zoom(image=images, 119 | target_dist=target_dist, 120 | target_size=target_size, 121 | zs=obj_zs, 122 | centroid_uvs=obj_centroids, 123 | scale_mode=scale_mode) 124 | # zoom_masks, _ = cameras.zoom(image=images_mask, 125 | # target_dist=target_dist, 126 | # target_size=target_size, 127 | # zs=obj_zs, 128 | # centroid_uvs=obj_centroids, 129 | # scale_mode=scale_mode) 130 | else: 131 | cameras = geometry.Camera(intrinsic=intrinsic, extrinsic=extrinsic, width=width, height=height) 132 | zoom_images, zoom_camera = cameras.zoom(images, 133 | target_dist=target_dist, 134 | target_size=target_size, 135 | scale_mode=scale_mode) 136 | # zoom_masks, _ = cameras.zoom(images_mask, 137 | # target_dist=target_dist, 138 | # target_size=target_size, 139 | # scale_mode=scale_mode) 140 | return zoom_images, zoom_camera, obj_translations 141 | 142 | 143 | def inplane_residual_theta(gt_t, init_t, gt_Rz, config, target_dist, device): 144 | """ 145 | gt_t(Nx3): the ground truth translation 146 | est_t(Nx3: the initial translation (directly estimated from depth) 147 | gt_Rz(Nx3x3): the ground truth relative in-plane rotation along camera optical axis 148 | 149 | return: the relative transformation between the anchor image and the query image 150 | 151 | """ 152 | W = config.RENDER_WIDTH 153 | H = config.RENDER_HEIGHT 154 | fx = config.INTRINSIC[0, 0] 155 | fy = config.INTRINSIC[1, 1] 156 | cx = config.INTRINSIC[0, 2] 157 | cy = config.INTRINSIC[1, 2] 158 | 159 | gt_t = gt_t.clone().to(device) # Nx3 160 | init_t = init_t.clone().to(device) # Nx3 161 | Rz_rot = gt_Rz[:, :2, :2].clone().to(device) # Nx2x2 162 | 163 | gt_tx = gt_t[:, 0:1] 164 | gt_ty = 
gt_t[:, 1:2] 165 | gt_tz = gt_t[:, 2:3] 166 | 167 | init_tx = init_t[:, 0:1] 168 | init_ty = init_t[:, 1:2] 169 | init_tz = init_t[:, 2:3] 170 | 171 | if not isinstance(target_dist, torch.Tensor): 172 | target_dist = torch.tensor(target_dist) 173 | if target_dist.dim() == 1: 174 | target_dist = target_dist[..., None] # Nx1 175 | if target_dist.dim() != 0: 176 | assert(target_dist.dim() == init_tz.dim()), "shape must be same, however, {}, {}".format(target_dist.shape, init_tz.shape) 177 | 178 | init_scale = target_dist.to(device) / init_tz # Nx1 / config.ZOOM_CROP_SIZE 179 | 180 | gt_t[:, 0:1] = (gt_tx / gt_tz * fx + cx) / W # Nx1 * gt_scale # projection to 2D image plane 181 | gt_t[:, 1:2] = (gt_ty / gt_tz * fy + cy) / H # Nx1 * gt_scale 182 | 183 | init_t[:, 0:1] = (init_tx / init_tz * fx + cx) / W # Nx1 * init_scale 184 | init_t[:, 1:2] = (init_ty / init_tz * fy + cy) / H # Nx1 * init_scale 185 | 186 | offset_t = gt_t - init_t # N x 3 [dx, dy, dz] unit with (pixel, pixel, meter) 187 | offset_t[:, :2] = offset_t[:, :2] * init_scale 188 | 189 | res_T = torch.zeros((gt_t.size(0), 3, 3), device=device) # Nx3x3 190 | res_T[:, :2, :2] = Rz_rot 191 | res_T[:, :3, 2] = offset_t 192 | 193 | return res_T 194 | 195 | 196 | def spatial_transform_2D(x, theta, mode='nearest', padding_mode='border', align_corners=False): 197 | assert(x.dim()==3 or x.dim()==4) 198 | assert(theta.dim()==2 or theta.dim()==3) 199 | assert(theta.shape[-2]==2 and theta.shape[-1]==3), "theta must be Nx2x3" 200 | if x.dim() == 3: 201 | x = x[None, ...] 202 | if theta.dim() == 2: 203 | theta = theta[None, ...].repeat(x.size(0), 1, 1) 204 | 205 | stn_theta = theta.clone() 206 | stn_theta[:, :2, :2] = theta[:, :2, :2].transpose(-1, -2) 207 | stn_theta[:, :2, 2:3] = -(stn_theta[:, :2, :2] @ stn_theta[:, :2, 2:3]) 208 | 209 | grid = F.affine_grid(stn_theta.to(x.device), x.shape, align_corners=align_corners) 210 | new_x = F.grid_sample(x.type(grid.dtype), grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners) 211 | return new_x 212 | 213 | 214 | def recover_full_translation(init_t, offset_t, config, target_dist, device): 215 | W = config.RENDER_WIDTH 216 | H = config.RENDER_HEIGHT 217 | fx = config.INTRINSIC[0, 0] 218 | fy = config.INTRINSIC[1, 1] 219 | 220 | dx = offset_t[:, 0:1].to(device) # Bx1 221 | dy = offset_t[:, 1:2].to(device) # Bx1 222 | dz = offset_t[:, 2:3].to(device) # Bx1 223 | 224 | init_tx = init_t[:, 0:1].to(device) # Bx1 225 | init_ty = init_t[:, 1:2].to(device) # Bx1 226 | init_tz = init_t[:, 2:3].to(device) # Bx1 227 | 228 | if not isinstance(target_dist, torch.Tensor): 229 | target_dist = torch.tensor(target_dist) 230 | if target_dist.dim() == 1: 231 | target_dist = target_dist[..., None] # Nx1 232 | if target_dist.dim() != 0: 233 | assert(target_dist.dim() == init_tz.dim()), "shape must be same, however, {}, {}".format(target_dist.shape, init_tz.shape) 234 | 235 | init_scale = target_dist.to(device) / init_tz #/ config.ZOOM_CROP_SIZE 236 | 237 | est_tz = init_tz + dz.to(device) 238 | est_tx = est_tz * (W / init_scale / fx * dx + init_tx/init_tz) # Nx1 239 | est_ty = est_tz * (H / init_scale / fy * dy + init_ty/init_tz) 240 | 241 | # print(est_tx.shape, est_ty.shape, est_tz.shape) 242 | 243 | est_full_t = torch.cat([est_tx, est_ty, est_tz], dim=1) # Nx3 244 | 245 | return est_full_t 246 | 247 | 248 | def residual_inplane_transform(gt_t, init_t, gt_Rz, config, target_dist, device): 249 | """ 250 | gt_t(Nx3): the ground truth translation 251 | est_t(Nx3: the initial translation (directly 
estimated from depth) 252 | gt_Rz(Nx3x3): the ground truth relative in-plane rotation along camera optical axis 253 | return: the relative transformation between the anchor image and the query image 254 | """ 255 | # W = config.RENDER_WIDTH 256 | # H = config.RENDER_HEIGHT 257 | fx = config.INTRINSIC[0, 0] 258 | fy = config.INTRINSIC[1, 1] 259 | cx = config.INTRINSIC[0, 2] 260 | cy = config.INTRINSIC[1, 2] 261 | 262 | gt_t = gt_t.clone().to(device) # Nx3 263 | init_t = init_t.clone().to(device) # Nx3 264 | Rz_rot = gt_Rz[:, :2, :2].clone().to(device) # Nx2x2 265 | 266 | gt_tx = gt_t[:, 0:1] 267 | gt_ty = gt_t[:, 1:2] 268 | gt_tz = gt_t[:, 2:3] 269 | 270 | init_tx = init_t[:, 0:1] 271 | init_ty = init_t[:, 1:2] 272 | init_tz = init_t[:, 2:3] 273 | 274 | if not isinstance(target_dist, torch.Tensor): 275 | target_dist = torch.tensor(target_dist) 276 | if target_dist.dim() == 1: 277 | target_dist = target_dist[..., None] # Nx1 278 | if target_dist.dim() != 0: 279 | assert(target_dist.dim() == init_tz.dim()), "shape must be same, however, {}, {}".format(target_dist.shape, init_tz.shape) 280 | target_dist = target_dist.to(device) 281 | 282 | tz_offset_frac = gt_tz / init_tz # gt_tz = tz_factor * init_tz, the ratio bwteen the ground truth distance and initial distance 283 | 284 | gt_t[:, 0:1] = (gt_tx / gt_tz * fx + cx) # Nx1, pixel coordinate projected on 2D image plane 285 | gt_t[:, 1:2] = (gt_ty / gt_tz * fy + cy) # Nx1 286 | 287 | init_t[:, 0:1] = (init_tx / init_tz * fx + cx) # Nx1 288 | init_t[:, 1:2] = (init_ty / init_tz * fy + cy) # Nx1 289 | 290 | gt_crop_scaling = target_dist / gt_tz # the scaling factor for the cropped object patch 291 | # init_crop_scaling = target_dist / init_tz 292 | 293 | gt_bbox_size = gt_crop_scaling * config.ZOOM_SIZE # the bbox size of the cropped object with gt distance 294 | # init_bbox_size = gt_bbox_size * tz_offset_frac 295 | 296 | delta_px = gt_tx - init_tx # from source image center to target image center 297 | delta_py = gt_ty - init_ty # from source image center to target image center 298 | 299 | px_offset_frac = delta_px / gt_bbox_size # convert the offset relative to the target image size 300 | py_offset_frac = delta_py / gt_bbox_size # convert the offset relative to the target image size 301 | 302 | offset_t = torch.cat([px_offset_frac, py_offset_frac, tz_offset_frac], dim=1) 303 | 304 | res_T = torch.zeros((gt_t.size(0), 3, 3), device=device) # Nx3x3 305 | res_T[:, :2, :2] = Rz_rot 306 | res_T[:, :3, 2] = offset_t 307 | 308 | return res_T 309 | 310 | 311 | def recover_residual_translation(init_t, offset_t, config, target_dist, device): 312 | # W = config.RENDER_WIDTH 313 | # H = config.RENDER_HEIGHT 314 | fx = config.INTRINSIC[0, 0] 315 | fy = config.INTRINSIC[1, 1] 316 | cx = config.INTRINSIC[0, 2] 317 | cy = config.INTRINSIC[1, 2] 318 | 319 | init_t = init_t.clone().to(device) # Nx3 320 | offset_t = offset_t.clone().to(device) # Nx3 321 | 322 | init_tx = init_t[:, 0:1] # Bx1 323 | init_ty = init_t[:, 1:2] # Bx1 324 | init_tz = init_t[:, 2:3] # Bx1 325 | 326 | px_offset_frac = offset_t[:, 0:1] # Bx1 327 | py_offset_frac = offset_t[:, 1:2] # Bx1 328 | tz_offset_frac = offset_t[:, 2:3] # Bx1 329 | 330 | init_t[:, 0:1] = (init_tx / init_tz * fx + cx) # Nx1 * init_scale 331 | init_t[:, 1:2] = (init_ty / init_tz * fy + cy) # Nx1 * init_scale 332 | 333 | if not isinstance(target_dist, torch.Tensor): 334 | target_dist = torch.tensor(target_dist) 335 | if target_dist.dim() == 1: 336 | target_dist = target_dist[..., None] # Nx1 337 | if target_dist.dim() 
!= 0: 338 | assert(target_dist.dim() == init_tz.dim()), "shape must be same, however, {}, {}".format(target_dist.shape, init_tz.shape) 339 | target_dist = target_dist.to(device) 340 | 341 | init_crop_scaling = target_dist / init_tz 342 | init_bbox_size = init_crop_scaling * config.ZOOM_SIZE 343 | pd_bbox_size = init_bbox_size / tz_offset_frac 344 | 345 | pd_delta_px = px_offset_frac * pd_bbox_size 346 | pd_delta_py = py_offset_frac * pd_bbox_size 347 | 348 | pd_px = init_t[:, 0:1] + pd_delta_px 349 | pd_py = init_t[:, 1:2] + pd_delta_py 350 | 351 | est_tz = tz_offset_frac * init_tz 352 | 353 | # est_tz = init_tz + tz_offset_frac * init_tz 354 | 355 | 356 | est_tx = (pd_px - cx) / fx * est_tz 357 | est_ty = (pd_py - cy) / fy * est_tz 358 | 359 | est_full_t = torch.cat([est_tx, est_ty, est_tz], dim=1) # Nx3 360 | 361 | return est_full_t 362 | 363 | 364 | def residual_inplane_transform3(gt_t, init_t, gt_Rz, config, target_dist, device): 365 | """ 366 | gt_t(Nx3): the ground truth translation 367 | est_t(Nx3: the initial translation (directly estimated from depth) 368 | gt_Rz(Nx3x3): the ground truth relative in-plane rotation along camera optical axis 369 | return: the relative transformation between the anchor image and the query image 370 | """ 371 | # W = config.RENDER_WIDTH 372 | # H = config.RENDER_HEIGHT 373 | fx = config.INTRINSIC[0, 0] 374 | fy = config.INTRINSIC[1, 1] 375 | cx = config.INTRINSIC[0, 2] 376 | cy = config.INTRINSIC[1, 2] 377 | 378 | gt_t = gt_t.clone().to(device) # Nx3 379 | init_t = init_t.clone().to(device) # Nx3 380 | Rz_rot = gt_Rz[:, :2, :2].clone().to(device) # Nx2x2 381 | 382 | gt_tx = gt_t[:, 0:1] 383 | gt_ty = gt_t[:, 1:2] 384 | gt_tz = gt_t[:, 2:3] 385 | 386 | init_tx = init_t[:, 0:1] 387 | init_ty = init_t[:, 1:2] 388 | init_tz = init_t[:, 2:3] 389 | 390 | # tz_offset_frac = (gt_tz - init_tz) / init_tz # gt_tz = init_tz + tz_offset_frac * init_tz 391 | 392 | tz_offset_frac = (gt_tz - init_tz)# / init_tz # gt_tz = init_tz + tz_offset_frac * init_tz 393 | 394 | if not isinstance(target_dist, torch.Tensor): 395 | target_dist = torch.tensor(target_dist) 396 | if target_dist.dim() == 1: 397 | target_dist = target_dist[..., None] # Nx1 398 | if target_dist.dim() != 0: 399 | assert(target_dist.dim() == init_tz.dim()), "shape must be same, however, {}, {}".format(target_dist.shape, init_tz.shape) 400 | target_dist = target_dist.to(device) 401 | 402 | # object GT 2D center in image 403 | gt_px = (gt_tx / gt_tz * fx + cx) # Nx1, pixel x-coordinate of the object gt_center 404 | gt_py = (gt_ty / gt_tz * fy + cy) # Nx1, pixel y-coordinate of the object gt_center 405 | 406 | # object initial 2D center in image 407 | init_px = (init_tx / init_tz * fx + cx) # Nx1 408 | init_py = (init_ty / init_tz * fy + cy) # Nx1 409 | 410 | offset_px = gt_px - init_px # from source image center to target image center 411 | offset_py = gt_py - init_py # from source image center to target image center 412 | 413 | # gt_box_size = 1.0 * target_dist / gt_tz * config.ZOOM_SIZE # cropped patch size with the gt depth 414 | init_box_size = 1.0 * target_dist / init_tz * config.ZOOM_SIZE # cropped patch size with the estimated depth 415 | 416 | px_offset_frac = offset_px / (init_box_size / 2.0) 417 | py_offset_frac = offset_py / (init_box_size / 2.0) 418 | 419 | offset_t = torch.cat([px_offset_frac, py_offset_frac, tz_offset_frac], dim=1) 420 | 421 | res_T = torch.zeros((gt_t.size(0), 3, 3), device=device) # Nx3x3 422 | res_T[:, :2, :2] = Rz_rot 423 | res_T[:, :3, 2] = offset_t 424 | 425 | return 
res_T 426 | 427 | 428 | def recover_residual_translation3(init_t, offset_t, config, target_dist, device): 429 | # W = config.RENDER_WIDTH 430 | # H = config.RENDER_HEIGHT 431 | fx = config.INTRINSIC[0, 0] 432 | fy = config.INTRINSIC[1, 1] 433 | cx = config.INTRINSIC[0, 2] 434 | cy = config.INTRINSIC[1, 2] 435 | 436 | init_t = init_t.clone().to(device) # Nx3 437 | offset_t = offset_t.clone().to(device) # Nx3 438 | 439 | init_tx = init_t[:, 0:1] # Bx1 440 | init_ty = init_t[:, 1:2] # Bx1 441 | init_tz = init_t[:, 2:3] # Bx1 442 | 443 | px_offset_frac = offset_t[:, 0:1] # Bx1 444 | py_offset_frac = offset_t[:, 1:2] # Bx1 445 | tz_offset_frac = offset_t[:, 2:3] # Bx1 446 | 447 | init_px = (init_tx / init_tz * fx + cx) # Nx1 * init_scale 448 | init_py = (init_ty / init_tz * fy + cy) # Nx1 * init_scale 449 | 450 | if not isinstance(target_dist, torch.Tensor): 451 | target_dist = torch.tensor(target_dist) 452 | if target_dist.dim() == 1: 453 | target_dist = target_dist[..., None] # Nx1 454 | if target_dist.dim() != 0: 455 | assert(target_dist.dim() == init_tz.dim()), "shape must be same, however, {}, {}".format(target_dist.shape, init_tz.shape) 456 | target_dist = target_dist.to(device) 457 | 458 | init_box_size = 1.0 * target_dist / init_tz * config.ZOOM_SIZE # cropped patch size with the estimated depth 459 | 460 | est_px = init_px + px_offset_frac / 2.0 * init_box_size 461 | est_py = init_py + py_offset_frac / 2.0 * init_box_size 462 | 463 | est_tz = init_tz + tz_offset_frac # * init_tz 464 | est_tx = (est_px - cx) / fx * est_tz 465 | est_ty = (est_py - cy) / fy * est_tz 466 | 467 | est_full_t = torch.cat([est_tx, est_ty, est_tz], dim=1) # Nx3 468 | 469 | return est_full_t 470 | 471 | 472 | def dynamic_margin(x_vp, y_vp, max_margin=0.5, threshold_angle=math.pi/2): 473 | """ 474 | given two viewpoint vector (Nx3), calcuate the dynamic margin for the triplet loss 475 | """ 476 | assert(max_margin>=0 and max_margin<=1), "maximum margin must be between (0, 1)" 477 | vp_cosim = (x_vp * y_vp).sum(dim=1, keepdim=True) # Nx1 478 | vp_angle = torch.arccos(vp_cosim) 479 | threshold = torch.ones_like(vp_cosim) * threshold_angle 480 | vp_cosim[vp_angle>threshold] = 0.0 481 | dynamic_margin = max_margin * (1 - vp_cosim) # smaller margin for more similar viewpoint pairs 482 | return dynamic_margin 483 | 484 | -------------------------------------------------------------------------------- /lib/geometry.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is borrowed from LatentFusion https://github.com/NVlabs/latentfusion/blob/master/latentfusion/modules/geometry.py 3 | """ 4 | import torch 5 | from skimage import morphology 6 | from torch.nn import functional as F 7 | from lib import three 8 | 9 | 10 | def inplane_2D_spatial_transform(R, img, mode='nearest', padding_mode='border', align_corners=False): 11 | if R.dim() == 2: 12 | R = R[None, ...] 13 | Rz = R[:, :2, :2].transpose(-1, -2).clone() 14 | 15 | if img.dim() == 2: 16 | img = img[None, None, ...] 17 | if img.dim() == 3: 18 | img = img[None, ...] 
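    # Descriptive note: Rz (the transposed 2x2 in-plane rotation) is padded below to an
    # N x 2 x 3 affine matrix with zero translation and fed to F.affine_grid / F.grid_sample.
    # The transpose accounts for grid_sample's inverse-warping convention (the grid maps output
    # pixels back to input pixels), so the image content is rotated by R rather than by R^{-1}.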
19 | theta = F.pad(Rz, (0, 1)) 20 | grid = F.affine_grid(theta.to(img.device), img.shape, align_corners=align_corners) 21 | new_img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners) 22 | return new_img 23 | 24 | 25 | # @torch.jit.script 26 | def masks_to_viewports(masks, pad: float = 10): 27 | viewports = [] 28 | padding = torch.tensor([-pad, -pad, pad, pad], dtype=torch.float32, device=masks.device) 29 | 30 | for mask in masks: 31 | if mask.sum() == 0: 32 | height, width = mask.shape[-2:] 33 | viewport = torch.tensor([0, 0, width, height], dtype=torch.float32, device=masks.device) 34 | else: 35 | coords = torch.nonzero(mask.squeeze()).float() 36 | xmin = coords[:, 1].min() 37 | ymin = coords[:, 0].min() 38 | xmax = coords[:, 1].max() 39 | ymax = coords[:, 0].max() 40 | viewport = torch.stack([xmin, ymin, xmax, ymax]) 41 | viewport = viewport + padding 42 | viewports.append(viewport) 43 | 44 | return torch.stack(viewports, dim=0) 45 | 46 | # @torch.jit.script 47 | def masks_to_centroids(masks): 48 | viewports = masks_to_viewports(masks, 0.0) 49 | cu = (viewports[:, 2] + viewports[:, 0]) / 2.0 50 | cv = (viewports[:, 3] + viewports[:, 1]) / 2.0 51 | 52 | return torch.stack((cu, cv), dim=-1) 53 | 54 | 55 | def _erode_mask(mask, size=5): 56 | device = mask.device 57 | eroded = mask.cpu().squeeze(0).numpy() 58 | eroded = morphology.binary_erosion(eroded, selem=morphology.disk(size)) 59 | eroded = torch.tensor(eroded, device=device, dtype=torch.bool).unsqueeze(0) 60 | if len(eroded) < 10: 61 | return mask 62 | return eroded 63 | 64 | 65 | def _reject_outliers(data, m=1.5): 66 | mask = torch.abs(data - torch.median(data)) < m * torch.std(data) 67 | num_rejected = (~mask).sum().item() 68 | return data[mask], num_rejected 69 | 70 | 71 | def _reject_outliers_med(data, m=2.0): 72 | median = data.median() 73 | med = torch.median(torch.abs(data - median)) 74 | mask = torch.abs(data - median) / med < m 75 | num_rejected = (~mask).sum().item() 76 | return data[mask], num_rejected 77 | 78 | 79 | def estimate_camera_dist(depth, mask): 80 | num_batch = depth.shape[0] 81 | zs = torch.zeros(num_batch, device=depth.device) 82 | mask = mask.bool() 83 | for i in range(num_batch): 84 | _mask = _erode_mask(mask[i], size=3) # smooth mask, e.g. 
hole filling 85 | depth_vals = depth[i][_mask & (depth[i] > 0.0)] 86 | if len(depth_vals) > 0: 87 | depth_vals, num_rejected = _reject_outliers_med(depth_vals, m=3.0) 88 | if len(depth_vals) > 0: 89 | _min = depth_vals.min() 90 | _max = depth_vals.max() 91 | else: 92 | depth_vals = depth[i][_mask & (depth[i] > 0.0)] 93 | _min = depth_vals.min() 94 | _max = depth_vals.max() 95 | else: 96 | depth_vals = depth[i][depth[i] > 0.0] 97 | if len(depth_vals) > 0: 98 | _min = depth_vals.min() 99 | _max = depth_vals.max() 100 | else: 101 | _min = 1.0 102 | _max = 1.0 103 | zs[i] = (_min + _max) / 2.0 104 | return zs 105 | 106 | 107 | def estimate_translation(depth, mask, intrinsic): 108 | 109 | depth, _ = three.ensure_batch_dim(depth, num_dims=3) 110 | mask, _ = three.ensure_batch_dim(mask, num_dims=3) 111 | z_cam = estimate_camera_dist(depth, mask) 112 | centroid_uv = masks_to_centroids(mask) 113 | 114 | u0 = intrinsic[..., 0, 2] 115 | v0 = intrinsic[..., 1, 2] 116 | fu = intrinsic[..., 0, 0] 117 | fv = intrinsic[..., 1, 1] 118 | x_cam = (centroid_uv[:, 0] - u0) / fu * z_cam 119 | y_cam = (centroid_uv[:, 1] - v0) / fv * z_cam 120 | 121 | return x_cam, y_cam, z_cam 122 | 123 | 124 | def _grid_sample(tensor, grid, **kwargs): 125 | return F.grid_sample(tensor.float(), grid.float(),align_corners=False, **kwargs) 126 | 127 | 128 | # @torch.jit.script 129 | def bbox_to_grid(bbox, in_size, out_size): 130 | h = in_size[0] 131 | w = in_size[1] 132 | xmin = bbox[0].item() 133 | ymin = bbox[1].item() 134 | xmax = bbox[2].item() 135 | ymax = bbox[3].item() 136 | grid_y, grid_x = torch.meshgrid([ 137 | torch.linspace(ymin / h, ymax / h, out_size[0], device=bbox.device) * 2 - 1, 138 | torch.linspace(xmin / w, xmax / w, out_size[1], device=bbox.device) * 2 - 1, 139 | ]) 140 | return torch.stack((grid_x, grid_y), dim=-1) 141 | 142 | 143 | # @torch.jit.script 144 | def bboxes_to_grid(boxes, in_size, out_size): 145 | grids = torch.zeros(boxes.size(0), out_size[1], out_size[0], 2, device=boxes.device) 146 | for i in range(boxes.size(0)): 147 | box = boxes[i] 148 | grids[i, :, :, :] = bbox_to_grid(box, in_size, out_size) 149 | return grids 150 | 151 | 152 | class Camera(torch.nn.Module): 153 | def __init__(self, intrinsic, extrinsic=None, viewport=None, width=640, height=480, rotation=None, translation=None): 154 | super().__init__() 155 | if intrinsic.dim() == 2: 156 | intrinsic = intrinsic.unsqueeze(0) 157 | if intrinsic.shape[1] == 3 and intrinsic.shape[2] == 3: 158 | intrinsic = three.rigid.intrinsic_to_3x4(intrinsic) 159 | 160 | if viewport is None: 161 | viewport = (torch.tensor((0, 0, width, height), dtype=torch.float32).view(1, 4).expand(intrinsic.shape[0], -1)) 162 | if viewport.dim() == 1: 163 | viewport = viewport.unsqueeze(0) 164 | 165 | self.width = width 166 | self.height = height 167 | self.register_buffer('viewport', viewport.to(intrinsic.device)) # Nx4 168 | self.register_buffer('intrinsic', intrinsic) # Nx3x4 matrix 169 | 170 | if extrinsic is not None: 171 | if extrinsic.dim() == 2: 172 | extrinsic = extrinsic.unsqueeze(0) # Nx4x4 173 | homo_rotation_mat, homo_translation_mat = three.rigid.decompose(extrinsic) 174 | rotation = homo_rotation_mat[:, :3, :3].contiguous() # Nx3x3 175 | translation = homo_translation_mat[:, :3, -1].contiguous() # Nx3 176 | 177 | # if translation is None: 178 | # raise ValueError("translation must be given through extrinsic or explicitly.") 179 | # elif translation.dim() == 1: 180 | # translation = translation.unsqueeze(0) 181 | 182 | if translation is not None and 
152 | class Camera(torch.nn.Module):
153 |     def __init__(self, intrinsic, extrinsic=None, viewport=None, width=640, height=480, rotation=None, translation=None):
154 |         super().__init__()
155 |         if intrinsic.dim() == 2:
156 |             intrinsic = intrinsic.unsqueeze(0)
157 |         if intrinsic.shape[1] == 3 and intrinsic.shape[2] == 3:
158 |             intrinsic = three.rigid.intrinsic_to_3x4(intrinsic)
159 | 
160 |         if viewport is None:
161 |             viewport = (torch.tensor((0, 0, width, height), dtype=torch.float32).view(1, 4).expand(intrinsic.shape[0], -1))
162 |         if viewport.dim() == 1:
163 |             viewport = viewport.unsqueeze(0)
164 | 
165 |         self.width = width
166 |         self.height = height
167 |         self.register_buffer('viewport', viewport.to(intrinsic.device)) # Nx4
168 |         self.register_buffer('intrinsic', intrinsic) # Nx3x4 matrix
169 | 
170 |         if extrinsic is not None:
171 |             if extrinsic.dim() == 2:
172 |                 extrinsic = extrinsic.unsqueeze(0) # Nx4x4
173 |             homo_rotation_mat, homo_translation_mat = three.rigid.decompose(extrinsic)
174 |             rotation = homo_rotation_mat[:, :3, :3].contiguous() # Nx3x3
175 |             translation = homo_translation_mat[:, :3, -1].contiguous() # Nx3
176 | 
177 |         # if translation is None:
178 |         #     raise ValueError("translation must be given through extrinsic or explicitly.")
179 |         # elif translation.dim() == 1:
180 |         #     translation = translation.unsqueeze(0)
181 | 
182 |         if translation is not None and translation.dim() == 1:
183 |             translation = translation.unsqueeze(0)
184 | 
185 | 
186 |         # if rotation is None:
187 |         #     raise ValueError("rotation must be given through extrinsic or explicitly.")
188 |         # elif rotation.dim() == 2:
189 |         #     rotation = rotation.unsqueeze(0) # Nx3x3
190 | 
191 |         if rotation is not None and rotation.dim() == 2:
192 |             rotation = rotation.unsqueeze(0) # Nx3x3
193 |         if translation is not None:
194 |             self.register_buffer('translation', translation.to(intrinsic.device))
195 |         else:
196 |             self.register_buffer('translation', None)
197 |         if rotation is not None:
198 |             self.register_buffer('rotation', rotation.to(intrinsic.device))
199 |         else:
200 |             self.register_buffer('rotation', None)
201 | 
202 | 
203 | 
204 |     def to_kwargs(self):
205 |         return {
206 |             'intrinsic': self.intrinsic,
207 |             'extrinsic': self.extrinsic,
208 |             'viewport': self.viewport,
209 |             'height': self.height,
210 |             'width': self.width,
211 |         }
212 | 
213 |     @classmethod
214 |     def from_kwargs(cls, kwargs): # classmethod receives the class object, not an instance
215 |         _kwargs = {}
216 |         for k, v in kwargs.items():
217 |             if isinstance(v, list):
218 |                 _kwargs[k] = torch.tensor(v, dtype=torch.float32)
219 |             else:
220 |                 _kwargs[k] = v
221 |         return cls(**_kwargs)
222 | 
223 |     @property
224 |     def device(self):
225 |         return self.intrinsic.device
226 | 
227 |     @property
228 |     def translation_matrix(self):
229 |         eye = torch.eye(4, device=self.translation.device)
230 |         homo_translation_mat = F.pad(self.translation.unsqueeze(2), (3, 0, 0, 1)) # Nx3 ==> Nx4x4
231 |         homo_translation_mat += eye
232 |         return homo_translation_mat
233 | 
234 |     @property
235 |     def rotation_matrix(self):
236 |         homo_rotation_mat = F.pad(self.rotation, (0, 1, 0, 1)) # Nx3x3 ==> Nx4x4
237 |         homo_rotation_mat[:, -1, -1] = 1.0
238 |         return homo_rotation_mat
239 | 
240 |     @property
241 |     def extrinsic(self):
242 |         homo_extrinsic_mat = self.translation_matrix @ self.rotation_matrix
243 |         return homo_extrinsic_mat
244 | 
245 |     @extrinsic.setter
246 |     def extrinsic(self, extrinsic):
247 |         homo_rotation_mat, homo_translation_mat = three.rigid.decompose(extrinsic)
248 |         rotation = homo_rotation_mat[:, :3, :3].contiguous() # Nx3x3
249 |         translation = homo_translation_mat[:, :3, -1].contiguous() # Nx3
250 |         self.rotation.copy_(rotation)
251 |         self.translation.copy_(translation)
252 | 
253 |     @property
254 |     def inv_translation_matrix(self):
255 |         eye = torch.eye(4, device=self.translation.device)
256 |         homo_inv_translation_mat = F.pad(-self.translation.unsqueeze(2), (3, 0, 0, 1))
257 |         homo_inv_translation_mat += eye
258 |         return homo_inv_translation_mat
259 | 
260 |     @property
261 |     def inv_intrinsic(self):
262 |         return torch.inverse(self.intrinsic[:, :3, :3])
263 | 
264 |     @property
265 |     def viewport_height(self):
266 |         return self.viewport[:, 3] - self.viewport[:, 1]
267 | 
268 |     @property
269 |     def viewport_width(self):
270 |         return self.viewport[:, 2] - self.viewport[:, 0]
271 | 
272 |     @property
273 |     def viewport_centroid(self):
274 |         cx = (self.viewport[:, 2] + self.viewport[:, 0]) / 2.0
275 |         cy = (self.viewport[:, 3] + self.viewport[:, 1]) / 2.0
276 |         return torch.stack((cx, cy), dim=-1) # N x 2
277 | 
278 |     @property
279 |     def u0(self):
280 |         return self.intrinsic[:, 0, 2]
281 | 
282 |     @property
283 |     def v0(self):
284 |         return self.intrinsic[:, 1, 2]
285 | 
286 |     @property
287 |     def fu(self):
288 |         return self.intrinsic[:, 0, 0]
289 | 
290 |     @property
291 |     def fv(self):
292 |         return self.intrinsic[:, 1, 1]
293 | 
294 |     @property
295 |     def fov_u(self):
296 |         return torch.atan2(self.fu, self.viewport_width / 2.0)
297 | 
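As a usage sketch (not taken from the repository's evaluation scripts), a batch of intrinsics and poses can be wrapped in this module and the derived quantities read back through the properties above; the concrete numbers are placeholders, and `lib.three` is assumed to be importable so that the 3x3 intrinsics are promoted to 3x4 internally:

```python
import torch

# Placeholder intrinsics and pose for a single camera in the batch.
K = torch.tensor([[572.4,   0.0, 325.3],
                  [  0.0, 573.6, 242.0],
                  [  0.0,   0.0,   1.0]]).unsqueeze(0)  # 1x3x3
R = torch.eye(3).unsqueeze(0)                            # 1x3x3
t = torch.tensor([[0.0, 0.0, 0.85]])                     # 1x3

cam = Camera(intrinsic=K, rotation=R, translation=t, width=640, height=480)
print(cam.fu, cam.fv)       # focal lengths read from the intrinsic buffer
print(cam.extrinsic.shape)  # torch.Size([1, 4, 4]), composed as T @ R
```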
298 |     @property
299 |     def fov_v(self):
300 |         return torch.atan2(self.fv, self.viewport_height / 2.0)
301 | 
302 |     @property
303 |     def obj_to_cam(self):
304 |         return self.translation_matrix @ self.rotation_matrix # Nx4x4, i.e. camera extrinsic or object pose
305 | 
306 |     @property
307 |     def cam_to_obj(self):
308 |         return self.rotation_matrix.transpose(2, 1) @ self.inv_translation_matrix # Nx4x4
309 | 
310 |     @property
311 |     def obj_to_image(self):
312 |         """
313 |         projection onto the image plane based on the camera intrinsics
314 |         """
315 |         return self.intrinsic @ self.obj_to_cam # Nx3x4, projection
316 | 
317 |     @property
318 |     def position(self):
319 |         """
320 |         obtain the camera position based on the camera extrinsic
321 |         """
322 |         # C = (-R^T)*t
323 |         cam_position = -self.rotation_matrix[:, :3, :3].transpose(2, 1) @ self.translation_matrix[:, :3, 3, None]
324 |         cam_position = cam_position.squeeze(-1) # Nx3x1 ==> Nx3
325 |         return cam_position
326 |     @property
327 |     def direction(self):
328 |         """
329 |         the direction of the vector from the object center to the camera center, i.e. the normalized camera position
330 |         """
331 |         vector_direction = self.position / torch.norm(self.position, dim=1, p=2, keepdim=True) # Nx3
332 |         return vector_direction
333 | 
334 |     @property
335 |     def length(self):
336 |         return self.intrinsic.shape[0]
337 | 
338 |     def rotate(self, rotation):
339 |         rotation, unsqueezed = three.core.ensure_batch_dim(rotation, 2)
340 |         if rotation.shape[0] == 1:
341 |             rotation = rotation.expand_as(self.rotation)
342 |         self.rotation = rotation @ self.rotation
343 |         return self
344 | 
345 |     def translate(self, offset):
346 |         """
347 |         move the position of the camera by the given offset
348 |         """
349 |         assert offset.shape[-1] == 3 or offset.shape[-1] == 1, "offset must be a single number or a tuple (x, y, z)"
350 |         offset, unsqueezed = three.core.ensure_batch_dim(offset, 1) # 3==>1x3
351 |         if offset.shape[0] == 1:
352 |             offset = offset.expand_as(self.position) # N x 3
353 |         homo_position = three.core.homogenize(self.position + offset).unsqueeze(-1) # cam new position, Nx4x1
354 |         self.translation = -self.rotation_matrix @ homo_position.squeeze(2) # the relative translation of the object
355 |         return self
356 | 
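The `position` property relies on the standard identity that, for an extrinsic composed as X = T R, the camera centre expressed in the object frame is C = -R^T t. A quick numeric check with invented values (plain tensors, not the batched buffers used by the class):

```python
import torch

# Toy rotation (90 degrees about z) and translation.
R = torch.tensor([[0.0, -1.0, 0.0],
                  [1.0,  0.0, 0.0],
                  [0.0,  0.0, 1.0]])
t = torch.tensor([0.1, 0.0, 0.8])

C = -R.t() @ t
# C = [0.0, 0.1, -0.8]: the camera centre expressed in the object frame,
# matching what the position property computes batch-wise above.
```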
357 |     def zoom(self, image, target_size, target_dist,
358 |              zs=None, centroid_uvs=None, target_fu=None, target_fv=None, scale_mode='bilinear'):
359 |         """
360 |         zoom into the image and crop it to the given square size
361 |         Args:
362 |             image: the target image for the zooming transformation
363 |             target_size: the target crop image size
364 |             target_dist: the target zoom distance from the origin
365 |             target_fu: the target horizontal focal length
366 |             target_fv: the target vertical focal length
367 |             zs: the original distance from the object to the camera
368 |             centroid_uvs: the target center for zooming
369 |         """
370 |         K = self.intrinsic
371 |         fu = K[:, 0, 0]
372 |         fv = K[:, 1, 1]
373 |         if zs is None:
374 |             zs = self.translation_matrix[:, 2, 3] # if not given, set it from the camera extrinsic
375 | 
376 |         if target_fu is None:
377 |             target_fu = fu # if not given, set it from the camera intrinsic, fx
378 |         if target_fv is None:
379 |             target_fv = fv # if not given, set it from the camera intrinsic, fy
380 | 
381 |         if centroid_uvs is None:
382 |             origin = (torch.tensor((0, 0, 0, 1.0), device=self.device).view(1, -1, 1).expand(self.length, -1, -1))
383 |             uvs = K @ self.obj_to_cam @ origin # center of interest (the object origin projected into the image)
384 |             uvs = (uvs[:, :2] / uvs[:, 2, None]).transpose(2, 1).squeeze(1)
385 |             centroid_uvs = uvs.clone().float()
386 | 
387 |         if isinstance(target_size, torch.Tensor):
388 |             target_size = target_size.to(self.device)
389 | 
390 |         if isinstance(target_dist, torch.Tensor):
391 |             target_dist = target_dist.to(self.device)
392 | 
393 |         bbox_u = 1.0 * target_dist / zs / fu * target_fu * target_size / self.width
394 |         bbox_v = 1.0 * target_dist / zs / fv * target_fv * target_size / self.height
395 | 
396 |         center_u = centroid_uvs[:, 0] / self.width # object center from pixel coordinates to a scale ratio
397 |         center_v = centroid_uvs[:, 1] / self.height
398 | 
399 |         boxes = torch.zeros(centroid_uvs.size(0), 4, device=self.device)
400 |         boxes[:, 0] = (center_u - bbox_u / 2) * float(self.width)
401 |         boxes[:, 1] = (center_v - bbox_v / 2) * float(self.height)
402 |         boxes[:, 2] = (center_u + bbox_u / 2) * float(self.width)
403 |         boxes[:, 3] = (center_v + bbox_v / 2) * float(self.height)
404 |         camera_new = Camera(intrinsic=self.intrinsic,
405 |                             extrinsic=None,
406 |                             viewport=boxes,
407 |                             width=self.width,
408 |                             height=self.height,
409 |                             rotation=self.rotation,
410 |                             translation=self.translation)
411 |         if image is None:
412 |             return camera_new
413 | 
414 |         in_size = torch.tensor((self.height, self.width), device=self.device)
415 |         out_size = torch.tensor((target_size, target_size), device=self.device)
416 |         grids = bboxes_to_grid(boxes, in_size, out_size)
417 |         zoomed_image = _grid_sample(image, grids, mode=scale_mode, padding_mode='zeros')
418 | 
419 |         return zoomed_image, camera_new
420 | 
421 |     def crop_to_viewport(self, image, target_size, scale_mode='nearest'):
422 |         in_size = torch.tensor((self.height, self.width), device=self.device)
423 |         out_size = torch.tensor((target_size, target_size), device=self.device)
424 |         grid = bboxes_to_grid(self.viewport, in_size, out_size)
425 |         return _grid_sample(image, grid, mode=scale_mode)
426 | 
427 |     def uncrop(self, image, scale_mode='nearest'):
428 |         camera_new = Camera(intrinsic=self.intrinsic,
429 |                             extrinsic=None,
430 |                             width=self.width,
431 |                             height=self.height,
432 |                             rotation=self.rotation,
433 |                             translation=self.translation)
434 |         if image is None:
435 |             return camera_new
436 | 
437 |         yy, xx = torch.meshgrid([torch.arange(0, self.height, device=self.device, dtype=torch.float32),
438 |                                  torch.arange(0, self.width, device=self.device, dtype=torch.float32)])
439 |         yy = yy.unsqueeze(0).expand(image.shape[0], -1, -1)
440 |         xx = xx.unsqueeze(0).expand(image.shape[0], -1, -1)
441 |         yy = (yy - self.viewport[:, 1, None, None]) / self.viewport_height[:, None, None] * 2 - 1
442 |         xx = (xx - self.viewport[:, 0, None, None]) / self.viewport_width[:, None, None] * 2 - 1
443 |         grid = torch.stack((xx, yy), dim=-1)
444 |         uncropped_image = _grid_sample(image, grid, mode=scale_mode, padding_mode='zeros')
445 | 
446 |         return uncropped_image, camera_new
447 | 
448 |     def pixel_coords_uv(self, out_size):
449 |         if isinstance(out_size, int):
450 |             out_size = (out_size, out_size)
451 | 
452 |         v_pixel, u_pixel = torch.meshgrid([
453 |             torch.linspace(0.0, 1.0, out_size[0], device=self.device),
454 |             torch.linspace(0.0, 1.0, out_size[1], device=self.device),
455 |         ])
456 | 
457 |         u_pixel = u_pixel.expand(self.length, -1, -1)
458 |         u_pixel = (u_pixel * self.viewport_width.view(-1, 1, 1) + self.viewport[:, 0].view(-1, 1, 1))
459 |         v_pixel = v_pixel.expand(self.length, -1, -1)
460 |         v_pixel = (v_pixel * self.viewport_height.view(-1, 1, 1) + self.viewport[:, 1].view(-1, 1, 1))
461 | 
462 |         return u_pixel, v_pixel
463 | 
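A hedged sketch of calling `zoom()` to obtain an object-centred square crop; the input shapes and numeric values are assumptions for illustration, not the repository's evaluation settings, and `lib.three` is assumed to be importable:

```python
import torch

# Placeholder intrinsics, identity rotation, object 0.85 m in front of the camera.
K = torch.tensor([[572.4,   0.0, 325.3],
                  [  0.0, 573.6, 242.0],
                  [  0.0,   0.0,   1.0]]).unsqueeze(0)
cam = Camera(intrinsic=K,
             rotation=torch.eye(3).unsqueeze(0),
             translation=torch.tensor([[0.0, 0.0, 0.85]]),
             width=640, height=480)

depth = torch.rand(1, 1, 480, 640)  # N x C x H x W input
zoomed, zoom_cam = cam.zoom(depth, target_size=128, target_dist=0.7)
# zoomed is N x C x 128 x 128; zoom_cam.viewport holds the crop box in pixels.
```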
464 |     def depth_camera_coords(self, depth):
465 |         u_pixel, v_pixel = self.pixel_coords_uv((depth.shape[-2], depth.shape[-1]))
466 |         z_cam = depth.view_as(u_pixel)
467 | 
468 |         u0 = self.u0.view(-1, 1, 1)
469 |         v0 = self.v0.view(-1, 1, 1)
470 |         fu = self.fu.view(-1, 1, 1)
471 |         fv = self.fv.view(-1, 1, 1)
472 |         x_cam = (u_pixel - u0) / fu * z_cam
473 |         y_cam = (v_pixel - v0) / fv * z_cam
474 | 
475 |         return x_cam, y_cam, z_cam
476 | 
477 |     def __getitem__(self, idx):
478 |         return Camera(intrinsic=self.intrinsic[idx],
479 |                       extrinsic=None,
480 |                       viewport=self.viewport[idx],
481 |                       width=self.width,
482 |                       height=self.height,
483 |                       rotation=self.rotation[idx],
484 |                       translation=self.translation[idx])
485 | 
486 |     def __setitem__(self, idx, camera):
487 |         self.intrinsic[idx] = camera.intrinsic
488 |         self.viewport[idx] = camera.viewport
489 |         self.rotation[idx] = camera.rotation
490 |         self.translation[idx] = camera.translation
491 | 
492 |     def __len__(self):
493 |         return self.length
494 | 
495 |     def __iter__(self):
496 |         cameras = [self[i] for i in range(len(self))]
497 |         return iter(cameras)
498 | 
499 |     def clone(self):
500 |         return Camera(self.intrinsic.clone(),
501 |                       extrinsic=None,
502 |                       viewport=self.viewport.clone(),
503 |                       rotation=self.rotation.clone(),
504 |                       translation=self.translation.clone(),
505 |                       width=self.width,
506 |                       height=self.height)
507 | 
508 |     def detach(self):
509 |         return Camera(self.intrinsic.detach(),
510 |                       extrinsic=None,
511 |                       viewport=self.viewport.detach(),
512 |                       rotation=self.rotation.detach(),
513 |                       translation=self.translation.detach(),
514 |                       width=self.width,
515 |                       height=self.height)
516 |     def __repr__(self):
517 |         return (
518 |             f"Camera(count={self.intrinsic.size(0)})"
519 |         )
520 | 
521 | 
--------------------------------------------------------------------------------
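Finally, a hedged sketch of `depth_camera_coords`: back-projecting a dense depth map into camera-space XYZ coordinates. The intrinsics and depth values are placeholders, and `lib.three` is assumed to be importable so the `Camera` constructor works as above:

```python
import torch

# Placeholder intrinsics and a constant synthetic depth map.
K = torch.tensor([[572.4,   0.0, 325.3],
                  [  0.0, 573.6, 242.0],
                  [  0.0,   0.0,   1.0]]).unsqueeze(0)
cam = Camera(intrinsic=K,
             rotation=torch.eye(3).unsqueeze(0),
             translation=torch.zeros(1, 3),
             width=640, height=480)

depth = torch.full((1, 480, 640), 0.85)  # constant 0.85 m depth
x_cam, y_cam, z_cam = cam.depth_camera_coords(depth)
points = torch.stack((x_cam, y_cam, z_cam), dim=-1)  # 1 x 480 x 640 x 3 organised point cloud
```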