├── LICENSE ├── README.md ├── datasets ├── get_dataset.py ├── tudl │ ├── models │ │ ├── models_info.json │ │ ├── obj_000001.ply │ │ ├── obj_000002.ply │ │ └── obj_000003.ply │ ├── models_info.pth │ └── test_info.pth └── tudl_db.py ├── modules └── DCP │ ├── dcp.py │ └── rpmnet_emb │ ├── feature_net.py │ └── pointnet_util.py ├── results └── DiffusionReg-DiffusionDCP-tudl-diffusion_200_0.00010_0.05_0.05_0.03-nvids3_cosine │ ├── eval_results │ ├── model_epoch19_T5_cosine_tudl_000001_noiseTrue_v1.pth │ ├── model_epoch19_T5_cosine_tudl_000002_noiseTrue_v1.pth │ └── model_epoch19_T5_cosine_tudl_000003_noiseTrue_v1.pth │ └── model_epoch19.pth ├── test.py ├── train.py └── utils ├── attr_dict.py ├── commons.py ├── criterion.py ├── data_classes.py ├── diffusion_scheduler.py ├── get_lr_scheduler.py ├── losses.py ├── options.py └── se_math ├── __init__.py ├── invmat.py ├── mesh.py ├── se3.py ├── sinc.py ├── so3.py └── transforms.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jiang-HB 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SE(3) Diffusion Model-based Point Cloud Registration for Robust 6D Object Pose Estimation (NeurIPS2023) 2 | 3 | PyTorch implementation of the paper ["SE(3) Diffusion Model-based Point Cloud Registration for Robust 6D Object Pose Estimation"](https://openreview.net/pdf?id=Znpz1sv4IP). 4 | 5 | Haobo Jiang, Mathieu Salzmann, Zheng Dang, Jin Xie, and Jian Yang. 6 | 7 | Here is the [supplementary material](https://openreview.net/attachment?id=Znpz1sv4IP&name=supplementary_material). 8 | 9 | 10 | ## Introduction 11 | 12 | In this paper, we introduce an SE(3) diffusion model-based point cloud registration framework for 6D object pose estimation in real-world scenarios. Our approach formulates the 3D registration task as a denoising diffusion process, which progressively refines the pose of the source point cloud to obtain a precise alignment with the model point cloud. 13 | Training our framework involves two operations: An SE(3) diffusion process and an SE(3) reverse process. The SE(3) diffusion process gradually perturbs the optimal rigid transformation of a pair of point clouds by continuously injecting noise (perturbation transformation). 14 | By contrast, the SE(3) reverse process focuses on learning a denoising network that refines the noisy transformation step-by-step, bringing it closer to the optimal transformation for accurate pose estimation. Unlike standard diffusion models used in linear Euclidean spaces, our diffusion model operates on the SE(3) manifold. 
15 | This requires exploiting the linear Lie algebra se(3) associated with SE(3) to constrain the transformation transitions during the diffusion and reverse processes. Additionally, to effectively train our denoising network, we derive a registration-specific variational lower bound as the optimization objective for model learning. 16 | Furthermore, we show that our denoising network can be constructed with a surrogate registration model, making our approach applicable to different deep registration networks. Extensive experiments demonstrate that our diffusion registration framework presents outstanding pose estimation performance on the real-world TUD-L, LINEMOD, and Occluded-LINEMOD datasets. 17 | 18 | 19 | ## Dataset Preprocessing 20 | 21 | ### TUD-L 22 | 23 | The raw data of TUD-L can be downloaded from the BOP datasets: [training data](https://huggingface.co/datasets/bop-benchmark/datasets/resolve/main/tudl/tudl_train_real.zip), [testing data](https://huggingface.co/datasets/bop-benchmark/datasets/resolve/main/tudl/tudl_test_bop19.zip) and [object models](https://huggingface.co/datasets/bop-benchmark/datasets/resolve/main/tudl/tudl_models.zip). 24 | Also, please download the pre-processed files: [train_info.pth](https://drive.google.com/file/d/1p07nibykEeVPrXzQf69pWPIAE8GnjDXC/view?usp=sharing), [test_info.pth](https://drive.google.com/file/d/16CeFZ9hfUnh1eoisx7cPzWftEfZCfO9w/view?usp=sharing), and [models_info.pth](https://drive.google.com/file/d/1yFu56Wmr-DFiWmWfYT66SaThmRnaFcAi/view?usp=sharing). 25 | Please put them into the directory `./datasets/tudl/` as shown below: 26 | ``` 27 | . 
28 | ├── train 29 | │ ├── 000001 30 | │ ├── 000002 31 | │ └── 000003 32 | ├── test 33 | │ ├── 000001 34 | │ ├── 000002 35 | │ └── 000003 36 | ├── models 37 | │ ├── models_info.json 38 | │ ├── obj_000001.ply 39 | │ ├── obj_000002.ply 40 | │ └── obj_000003.ply 41 | ├── train_info.pth 42 | ├── test_info.pth 43 | └── models_info.pth 44 | ``` 45 | 46 | ## Pretrained Model 47 | 48 | We provide the pre-trained model of Diff-DCP on the TUD-L dataset in `./results/DiffusionReg-DiffusionDCP-tudl-diffusion_200_0.00010_0.05_0.05_0.03-nvids3_cosine/model_epoch19.pth`. 49 | 50 | ## Instructions for training and testing 51 | 52 | Training and testing can be run with: 53 | ```bash 54 | CUDA_VISIBLE_DEVICES=0 python3 train.py --net_type DiffusionDCP --db_nm tudl 55 | 56 | CUDA_VISIBLE_DEVICES=0 python3 test.py 57 | ``` 58 | 59 | ## Citation 60 | 61 | If you find this project useful, please cite: 62 | 63 | ```bibtex 64 | @inproceedings{jiang2023se, 65 | title={SE(3) Diffusion Model-based Point Cloud Registration for Robust 6D Object Pose Estimation}, 66 | author={Jiang, Haobo and Salzmann, Mathieu and Dang, Zheng and Xie, Jin and Yang, Jian}, 67 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 68 | year={2023} 69 | } 70 | ``` 71 | 72 | ## Acknowledgments 73 | We thank the authors of 74 | - [DCP](https://github.com/WangYueFt/dcp) 75 | - [RPMNet](https://github.com/yewzijian/RPMNet) 76 | 77 | for open-sourcing their methods. 
def get_dataset(opts, db_nm, partition, batch_size, shuffle, drop_last, cls_nm=None, n_cores=1):
    """Create the dataset object and its DataLoader for one partition.

    Only ``db_nm == "tudl"`` is recognised; for any other name the function
    returns ``(None, None)`` without constructing anything.

    Args:
        opts: Option object forwarded unchanged to the dataset constructor.
        db_nm: Dataset name.
        partition: ``"train"`` selects ``TUDL_DB_Train``; every other value
            selects ``TUDL_DB_Test``.
        batch_size, shuffle, drop_last: Forwarded to ``DataLoader``.
        cls_nm: Forwarded to the dataset constructor as its class-name argument.
        n_cores: Number of DataLoader worker processes.

    Returns:
        Tuple ``(loader, db)``.
    """
    if db_nm != "tudl":
        return None, None

    # Both partitions share the same loader settings; only the dataset class differs.
    dataset_cls = TUDL_DB_Train if partition in ["train"] else TUDL_DB_Test
    db = dataset_cls(opts, partition, cls_nm)
    loader = DataLoader(
        db,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=n_cores,
        pin_memory=True,
    )
    return loader, db
def __init__(self, opts, partition, cls_nms):
    """Store the configuration, resolve the TUD-L paths and load all metadata.

    Args:
        opts: Option object; ``db_nm``, ``is_test`` and ``is_debug`` are read here.
        partition: ``"train"`` selects the training split; any other value
            selects the test split.
        cls_nms: Iterable of class names to load.
    """
    # Run configuration.
    self.opts = opts
    self.db_nm = opts.db_nm
    self.is_test = opts.is_test
    self.partition = partition
    self.is_debug = opts.is_debug
    self.cls_nms = cls_nms

    # Fixed preprocessing constants.
    self.n_points, self.scale, self.depth_expand_ratio = 1024, 256, 0.8

    # Dataset locations.
    self.base_dir = "./datasets/tudl/"
    self.model_path = "./datasets/tudl/models_info.pth"
    self.gen_db_dir = "./datasets/tudl/"
    use_train_split = partition == "train"
    self.db_dir = os.path.join(self.base_dir, "train" if use_train_split else "test")
    self.db_path = "./datasets/tudl/train_info.pth" if use_train_split else "./datasets/tudl/test_info.pth"

    # load info
    self.models_info = self.get_models_info()
    annos_bundle = self.get_annos_info()
    self.annos_info, self.cropped_depths, self.bboxs, self.annos_list = annos_bundle
def get_annos_info(self):
    """Load per-frame annotations and the cropped depth / 3D box data for every class.

    For each class in ``self.cls_nms`` one annotation dict is built per frame.
    The cropped depth maps are read from a per-class cache file when it exists;
    otherwise they are generated from the depth and visibility-mask images and
    the cache file is written.

    Returns:
        Tuple ``(annos_info, cropped_depths, bboxs, anno_list)`` where
        ``annos_info`` maps class name to its annotation list and the other
        three are flat lists concatenated over the classes in order.
    """
    db_info = load_data(self.db_path)
    annos_info = {}
    cropped_depths, bboxs, anno_list = [], [], []
    for cls_nm in self.cls_nms:
        annos = []
        vid_db = db_info[cls_nm]
        # One annotation dict per entry of vid_db["frame_idxs"].
        for frame_idx, _ in enumerate(vid_db["frame_idxs"]):
            annos.append({
                "db_dir": self.db_dir,
                "cls_nm": cls_nm,
                "camera_info": vid_db["cameras_info"][frame_idx],
                "rgb_path": vid_db["rgb_paths"][frame_idx],
                "depth_path": vid_db["depth_paths"][frame_idx],
                "mask_path": vid_db["mask_paths"][frame_idx],
                "mask_visib_path": vid_db["mask_visib_paths"][frame_idx],
                "frame_idx": int(vid_db["frame_idxs"][frame_idx]),
                "bbox_obj_2d": vid_db["bboxes_obj_2d"][frame_idx],
                "bbox_visib_2d": vid_db["bboxes_visib_2d"][frame_idx],
                "bbox_obj_3d": vid_db["bboxes_obj_3d"][frame_idx],
                "cam_R_m2c": vid_db["cam_Rs_m2c"][frame_idx],
                "cam_t_m2c": vid_db["cam_ts_m2c"][frame_idx],
            })
        # Debug mode keeps only the first 100 frames of each class.
        if self.opts.is_debug:
            annos = annos[:100]
        annos_info[cls_nm] = annos
        anno_list.extend(annos)

        # Cache file name encodes dataset, partition, class and the debug flag.
        gen_depths_bboxes_path = os.path.join(self.gen_db_dir, "%s_%s_%s_depths_bboxes_tiny%d.pth" % (
            self.db_nm, self.partition, cls_nm, int(self.is_debug)))
        if os.path.exists(gen_depths_bboxes_path):
            print("Data exists. Loading...")
            data_info = load_data(gen_depths_bboxes_path)[cls_nm]
            for cropped_depth, bbox in data_info:
                cropped_depths.append(cropped_depth)
                # Boxes are brought to the same 1/scale units as the point clouds.
                bbox.center /= self.scale
                bbox.wlh /= self.scale
                bboxs.append(bbox)
        else:
            print("Data not exists. Generating...")
            data_info = defaultdict(list)
            for anno_idx, anno in enumerate(tqdm(annos)):
                depth = np.asarray(o3d.io.read_image(os.path.join(anno["db_dir"], anno["depth_path"])))  # [H, W]
                mask = np.asarray(o3d.io.read_image(os.path.join(anno["db_dir"], anno["mask_visib_path"])))  # [H, W]
                bbox_2d = anno["bbox_obj_2d"]  # [x, y, w, h]
                # Crop window: box centre +/- depth_expand_ratio * box size, clamped to the image.
                min_xy = np.maximum(np.floor(bbox_2d[:2] + bbox_2d[2:] / 2. - bbox_2d[2:] * self.depth_expand_ratio), 0.).astype(np.int32)
                max_xy = np.minimum(np.ceil(bbox_2d[:2] + bbox_2d[2:] / 2. + bbox_2d[2:] * self.depth_expand_ratio),
                                    np.asarray([depth.shape[1], depth.shape[0]])).astype(np.int32)
                cropped_depth = depth[min_xy[1]: max_xy[1], min_xy[0]: max_xy[0]]
                cropped_mask = mask[min_xy[1]: max_xy[1], min_xy[0]: max_xy[0]]
                # Stored record: [cropped depth, crop corners, full image H, W, cropped mask] plus the 3D box.
                data_info[cls_nm].append([[cropped_depth, min_xy, max_xy, *depth.shape, cropped_mask], anno["bbox_obj_3d"]])

            # The cache is written before the boxes are rescaled below, so it holds unscaled boxes.
            save_data(gen_depths_bboxes_path, data_info)

            # NOTE(review): these boxes are the same objects stored under
            # anno["bbox_obj_3d"]; `/=` presumably rescales them in place, which
            # would also change the annotation dicts - confirm against the box type.
            for idx, (cropped_depth, bbox) in enumerate(data_info[cls_nm]):
                cropped_depths.append(cropped_depth)
                bbox.center /= self.scale
                bbox.wlh /= self.scale
                bboxs.append(bbox)

    return annos_info, cropped_depths, bboxs, anno_list
def get_src_pcd(self, src_idx):
    """Return the regularized source point cloud and its ground-truth 3D box.

    A frame whose cropped cloud has 50 points or fewer is replaced by a
    randomly drawn frame of the same class.

    The previous fallback returned ``self.getitem(...)``, i.e. a complete
    sample dict, while every caller unpacks a ``(pcd, bbox)`` tuple, so sparse
    frames crashed the data loader. The fallback now always returns the tuple.

    Args:
        src_idx: Index into ``self.annos_list`` / ``self.cropped_depths``.

    Returns:
        Tuple ``(src_pcd, src_gt_bbox)``; ``src_pcd`` is the output of
        ``regularize_pcd`` requested with ``self.n_points // 2`` points.

    Raises:
        RuntimeError: If no usable frame is found after ``max_retries`` draws.
    """
    min_points = 50
    max_retries = 100

    cls_nm = self.annos_list[src_idx]["cls_nm"]
    candidates = None  # built lazily; the fallback path is rare
    idx = src_idx
    for _ in range(max_retries + 1):
        src_pcd = self.depth2pcd_single(idx)  # [3, N]
        if src_pcd.shape[1] > min_points:
            src_gt_bbox = self.bboxs[idx]
            src_pcd = regularize_pcd(src_pcd, self.n_points // 2, is_test=self.is_test)  # [3, N]
            return src_pcd, src_gt_bbox

        # Resample within the same class so the cloud still matches the model
        # the caller pairs it with.
        # NOTE(review): callers that also read pose fields from the original
        # annotation (e.g. the test split's R_gt / t_gt) will not match the
        # replacement frame - confirm whether that matters for evaluation.
        if candidates is None:
            candidates = [i for i, anno in enumerate(self.annos_list) if anno["cls_nm"] == cls_nm]
        idx = candidates[np.random.randint(0, len(candidates))]

    raise RuntimeError(
        "No frame of class %r with more than %d points found after %d retries"
        % (cls_nm, min_points, max_retries))
def getitem(self, idx):
    """Assemble one training sample for annotation ``idx``.

    Fetches the source cloud of the frame and the model cloud of the frame's
    class, then lets ``gen_reg_sample`` fill and return the sample dict.
    """
    cls_nm = self.annos_list[idx]["cls_nm"]

    source_cloud, source_bbox = self.get_src_pcd(idx)  # [3, N]
    model_cloud, model_bbox = self.get_model_pcd(cls_nm)  # [3, M]

    # gen_reg_sample receives a fresh dict and returns it populated.
    return self.gen_reg_sample({}, source_cloud, model_cloud, source_bbox, model_bbox, idx)
def integrate_trans(R, t):
    """
    Integrate SE3 transformations from R and t, support torch.Tensor and np.ndarray.
    Input
        - R: [3, 3] or [bs, 3, 3], rotation matrix
        - t: [3, 1] or [bs, 3, 1], translation matrix
             (a flat [3] / [bs, 3] translation is accepted as well)
    Output
        - trans: [4, 4] or [bs, 4, 4], SE3 transformation matrix

    Fixes over the previous version:
        - NumPy batched input: the identity is now expanded to the batch size
          (it used to stay [1, 4, 4], so assigning a [bs, 3, 3] rotation failed
          for bs > 1) and the translation is reshaped with ``reshape`` instead
          of the torch-only ``view(shape)`` call.
        - Unbatched input: a [3] translation is reshaped to [3, 1] instead of
          failing to broadcast.
    """
    batched = len(R.shape) == 3

    # Identity template in the same framework (and device) as R.
    if isinstance(R, torch.Tensor):
        eye = torch.eye(4, device=R.device)
        trans = eye[None].repeat(R.shape[0], 1, 1) if batched else eye
    else:
        eye = np.eye(4)
        trans = np.tile(eye[None], (R.shape[0], 1, 1)) if batched else eye

    # `reshape` exists on both tensors and arrays and tolerates non-contiguous input.
    if batched:
        trans[:, :3, :3] = R
        trans[:, :3, 3:4] = t.reshape(-1, 3, 1)
    else:
        trans[:3, :3] = R
        trans[:3, 3:4] = t.reshape(3, 1)
    return trans
def get_graph_feature(x, k=20):
    """Build edge features from each point's k nearest neighbours.

    Args:
        x: Tensor of shape (batch_size, num_dims, num_points).
        k: Number of neighbours gathered per point.

    Returns:
        Tensor of shape (batch_size, 2 * num_dims, num_points, k): for every
        point and neighbour, the neighbour's features concatenated with the
        point's own features along the channel axis.

    The index offsets are created on ``x.device``. They used to be created on
    a hard-coded CUDA device, which made the function fail on CPU-only
    machines and for CPU inputs.
    """
    idx = knn(x, k=k)  # (batch_size, num_points, k)
    batch_size, num_points, _ = idx.size()

    # Offset per-sample indices so they address the flattened
    # (batch_size * num_points, num_dims) view below.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)

    return feature
class LayerNorm(nn.Module):
    """Normalize the last dimension, then apply a learned gain and bias.

    ``a_2`` (initialised to ones) scales and ``b_2`` (initialised to zeros)
    shifts the normalized values; ``eps`` is added to the standard deviation
    before dividing.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        # Statistics are taken over the feature (last) axis only.
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        scaled = self.a_2 * (x - mu)
        return scaled / (sigma + self.eps) + self.b_2
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Holds four ``d_model -> d_model`` linear layers: one each for the query,
    key and value projections and one for the output. The ``dropout``
    constructor argument is accepted but not used: ``self.dropout`` is always
    ``None``.
    """

    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super().__init__()
        assert d_model % h == 0
        self.h = h
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = None

    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        # Same mask applied to all h heads.
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        q_proj, k_proj, v_proj, out_proj = self.linears

        def split_heads(projection, tensor):
            # d_model => h x d_k, with the head axis moved in front of the sequence axis
            return projection(tensor).view(nbatches, -1, self.h, self.d_k).transpose(1, 2).contiguous()

        # 1) Do all the linear projections in batch.
        query = split_heads(q_proj, query)
        key = split_heads(k_proj, key)
        value = split_heads(v_proj, value)

        # 2) Apply attention on all the projected vectors in batch; the
        #    attention weights are kept on the module for later inspection.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)

        # 3) "Concat" the heads using a view and apply the final linear.
        merged = x.transpose(1, 2).contiguous()
        merged = merged.view(nbatches, -1, self.h * self.d_k)
        return out_proj(merged)
= nn.BatchNorm2d(64) 374 | self.bn2 = nn.BatchNorm2d(64) 375 | self.bn3 = nn.BatchNorm2d(128) 376 | self.bn4 = nn.BatchNorm2d(256) 377 | self.bn5 = nn.BatchNorm2d(emb_dims) 378 | 379 | def forward(self, x): 380 | batch_size, num_dims, num_points = x.size() 381 | x = get_graph_feature(x) 382 | x = F.relu(self.bn1(self.conv1(x))) 383 | x1 = x.max(dim=-1, keepdim=True)[0] 384 | 385 | x = F.relu(self.bn2(self.conv2(x))) 386 | x2 = x.max(dim=-1, keepdim=True)[0] 387 | 388 | x = F.relu(self.bn3(self.conv3(x))) 389 | x3 = x.max(dim=-1, keepdim=True)[0] 390 | 391 | x = F.relu(self.bn4(self.conv4(x))) 392 | x4 = x.max(dim=-1, keepdim=True)[0] 393 | 394 | x = torch.cat((x1, x2, x3, x4), dim=1) 395 | 396 | x = F.relu(self.bn5(self.conv5(x))).view(batch_size, -1, num_points) 397 | return x 398 | 399 | 400 | class MLPHead(nn.Module): 401 | def __init__(self, emb_dims): 402 | super(MLPHead, self).__init__() 403 | self.emb_dims = emb_dims 404 | self.nn = nn.Sequential(nn.Linear(emb_dims * 2, emb_dims // 2), 405 | nn.BatchNorm1d(emb_dims // 2), 406 | nn.ReLU(), 407 | nn.Linear(emb_dims // 2, emb_dims // 4), 408 | nn.BatchNorm1d(emb_dims // 4), 409 | nn.ReLU(), 410 | nn.Linear(emb_dims // 4, emb_dims // 8), 411 | nn.BatchNorm1d(emb_dims // 8), 412 | nn.ReLU()) 413 | self.proj_rot = nn.Linear(emb_dims // 8, 4) 414 | self.proj_trans = nn.Linear(emb_dims // 8, 3) 415 | 416 | def forward(self, *input): 417 | src_embedding = input[0] 418 | tgt_embedding = input[1] 419 | pdb.set_trace() 420 | embedding = torch.cat((src_embedding, tgt_embedding), dim=1) 421 | embedding = self.nn(embedding.max(dim=-1)[0]) 422 | rotation = self.proj_rot(embedding) 423 | rotation = rotation / torch.norm(rotation, p=2, dim=1, keepdim=True) 424 | translation = self.proj_trans(embedding) 425 | return quat2mat(rotation), translation 426 | 427 | class Identity(nn.Module): 428 | def __init__(self): 429 | super(Identity, self).__init__() 430 | 431 | def forward(self, *input): 432 | return input 433 | 434 | class 
Transformer(nn.Module): 435 | def __init__(self, args): 436 | super(Transformer, self).__init__() 437 | self.emb_dims = args.emb_dims 438 | self.N = args.n_blocks 439 | self.dropout = args.dropout 440 | self.ff_dims = args.ff_dims 441 | self.n_heads = args.n_heads 442 | c = copy.deepcopy 443 | attn = MultiHeadedAttention(self.n_heads, self.emb_dims) 444 | ff = PositionwiseFeedForward(self.emb_dims, self.ff_dims, self.dropout) 445 | self.model = EncoderDecoder(Encoder(EncoderLayer(self.emb_dims, c(attn), c(ff), self.dropout), self.N), 446 | Decoder(DecoderLayer(self.emb_dims, c(attn), c(attn), c(ff), self.dropout), self.N), 447 | nn.Sequential(), 448 | nn.Sequential(), 449 | nn.Sequential()) 450 | 451 | def forward(self, *input): 452 | src = input[0] 453 | tgt = input[1] 454 | src = src.transpose(2, 1).contiguous() 455 | tgt = tgt.transpose(2, 1).contiguous() 456 | tgt_embedding = self.model(src, tgt, None, None).transpose(2, 1).contiguous() 457 | src_embedding = self.model(tgt, src, None, None).transpose(2, 1).contiguous() 458 | return src_embedding, tgt_embedding 459 | 460 | 461 | class SVDHead(nn.Module): 462 | def __init__(self, args): 463 | super(SVDHead, self).__init__() 464 | self.emb_dims = args.emb_dims 465 | self.reflect = nn.Parameter(torch.eye(3), requires_grad=False) 466 | self.reflect[2, 2] = -1 467 | 468 | def forward(self, *input): 469 | src_embedding = input[0] 470 | tgt_embedding = input[1] 471 | src = input[2] 472 | tgt = input[3] 473 | batch_size = src.size(0) 474 | 475 | d_k = src_embedding.size(1) 476 | scores = torch.matmul(src_embedding.transpose(2, 1).contiguous(), tgt_embedding) / math.sqrt(d_k) 477 | scores = torch.softmax(scores, dim=2) 478 | 479 | src_corr = torch.matmul(tgt, scores.transpose(2, 1).contiguous()) 480 | 481 | # from utils.commons import save_data 482 | # save_data("test.pth", [src[1].cpu().numpy(), src_corr[1].cpu().numpy()]) 483 | 484 | src_centered = src - src.mean(dim=2, keepdim=True) # [B, 3, N] 485 | 486 | 
src_corr_centered = src_corr - src_corr.mean(dim=2, keepdim=True) # [B, 3, M] 487 | 488 | H = torch.matmul(src_centered, src_corr_centered.transpose(2, 1).contiguous()) 489 | 490 | U, S, V = [], [], [] 491 | R = [] 492 | 493 | for i in range(src.size(0)): 494 | u, s, v = torch.svd(H[i]) 495 | r = torch.matmul(v, u.transpose(1, 0).contiguous()) 496 | r_det = torch.det(r) 497 | if r_det < 0: 498 | u, s, v = torch.svd(H[i]) 499 | v = torch.matmul(v, self.reflect) 500 | r = torch.matmul(v, u.transpose(1, 0).contiguous()) 501 | # r = r * self.reflect 502 | R.append(r) 503 | 504 | U.append(u) 505 | S.append(s) 506 | V.append(v) 507 | 508 | U = torch.stack(U, dim=0) 509 | V = torch.stack(V, dim=0) 510 | S = torch.stack(S, dim=0) 511 | R = torch.stack(R, dim=0) 512 | 513 | t = torch.matmul(-R, src.mean(dim=2, keepdim=True)) + src_corr.mean(dim=2, keepdim=True) 514 | return R, t.view(batch_size, 3) 515 | 516 | 517 | class DCP(nn.Module): 518 | def __init__(self, args): 519 | super(DCP, self).__init__() 520 | self.emb_dims = args.emb_dims 521 | self.args = args 522 | if args.emb_nn == 'pointnet': 523 | self.emb_nn = PointNet(emb_dims=self.emb_dims) 524 | elif args.emb_nn == 'dgcnn': 525 | self.emb_nn = DGCNN(emb_dims=self.emb_dims) 526 | elif args.emb_nn == "rpmnet_emb": 527 | from modules.DCP.rpmnet_emb.feature_net import FeatExtractionEarlyFusion 528 | self.emb_nn = FeatExtractionEarlyFusion(features=args.features, feature_dim=args.feat_dim, 529 | radius=args.radius, num_neighbors=args.num_neighbors) 530 | else: 531 | raise Exception('Not implemented') 532 | 533 | if args.pointer == 'identity': 534 | self.pointer = Identity() 535 | elif args.pointer == 'transformer': 536 | self.pointer = Transformer(args=args) 537 | else: 538 | raise Exception("Not implemented") 539 | 540 | if args.head == 'mlp': 541 | self.head = MLPHead(args=args) 542 | elif args.head == 'svd': 543 | self.head = SVDHead(args=args) 544 | else: 545 | raise Exception('Not implemented') 546 | 547 | def 
forward(self, data): 548 | 549 | src_pcd = data["src_pcd"].transpose(2, 1).contiguous() # [B, 3, N] 550 | tgt_pcd = data["model_pcd"].transpose(2, 1).contiguous() # [B, 3, M] 551 | if self.args.emb_nn in ["rpmnet_emb"]: 552 | src_pcd_norm = data["src_pcd_normal"] 553 | tgt_pcd_norm = data["model_pcd_normal"] 554 | src_embedding = self.emb_nn(src_pcd.transpose(2, 1), src_pcd_norm).transpose(2, 1).contiguous() 555 | tgt_embedding = self.emb_nn(tgt_pcd.transpose(2, 1), tgt_pcd_norm).transpose(2, 1).contiguous() 556 | elif self.args.emb_nn in ["pointnet", "dgcnn"]: 557 | src_embedding = self.emb_nn(src_pcd) 558 | tgt_embedding = self.emb_nn(tgt_pcd) 559 | else: 560 | raise NotImplementedError 561 | 562 | src_embedding_p, tgt_embedding_p = self.pointer(src_embedding, tgt_embedding) 563 | src_embedding = src_embedding + src_embedding_p 564 | tgt_embedding = tgt_embedding + tgt_embedding_p 565 | 566 | Rs_pred, ts_pred = self.head(src_embedding, tgt_embedding, src_pcd, tgt_pcd) 567 | return Rs_pred, ts_pred -------------------------------------------------------------------------------- /modules/DCP/rpmnet_emb/feature_net.py: -------------------------------------------------------------------------------- 1 | """Feature Extraction and Parameter Prediction networks 2 | """ 3 | import logging 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from modules.DCP.rpmnet_emb.pointnet_util import sample_and_group_multi 9 | 10 | _raw_features_sizes = {'xyz': 3, 'dxyz': 3, 'ppf': 4, 'centerfeat': 5} 11 | _raw_features_order = {'xyz': 0, 'dxyz': 1, 'ppf': 2, 'centerfeat': 3} 12 | 13 | class ParameterPredictionNet(nn.Module): 14 | def __init__(self, weights_dim): 15 | """PointNet based Parameter prediction network 16 | 17 | Args: 18 | weights_dim: Number of weights to predict (excluding beta), should be something like 19 | [3], or [64, 3], for 3 types of features 20 | """ 21 | 22 | super().__init__() 23 | 24 | self._logger = 
logging.getLogger(self.__class__.__name__) 25 | 26 | self.weights_dim = weights_dim 27 | 28 | # Pointnet 29 | self.prepool = nn.Sequential( 30 | nn.Conv1d(4, 64, 1), 31 | nn.GroupNorm(8, 64), 32 | nn.ReLU(), 33 | 34 | nn.Conv1d(64, 64, 1), 35 | nn.GroupNorm(8, 64), 36 | nn.ReLU(), 37 | 38 | nn.Conv1d(64, 64, 1), 39 | nn.GroupNorm(8, 64), 40 | nn.ReLU(), 41 | 42 | nn.Conv1d(64, 128, 1), 43 | nn.GroupNorm(8, 128), 44 | nn.ReLU(), 45 | 46 | nn.Conv1d(128, 1024, 1), 47 | nn.GroupNorm(16, 1024), 48 | nn.ReLU(), 49 | ) 50 | self.pooling = nn.AdaptiveMaxPool1d(1) 51 | self.postpool = nn.Sequential( 52 | nn.Linear(1024, 512), 53 | nn.GroupNorm(16, 512), 54 | nn.ReLU(), 55 | 56 | nn.Linear(512, 256), 57 | nn.GroupNorm(16, 256), 58 | nn.ReLU(), 59 | 60 | nn.Linear(256, 2 + np.prod(weights_dim)), 61 | ) 62 | 63 | self._logger.info('Predicting weights with dim {}.'.format(self.weights_dim)) 64 | 65 | def forward(self, x, y): 66 | """ Returns alpha, beta, and gating_weights (if needed) 67 | 68 | Args: 69 | x: List containing two point clouds, x[0] = src (B, J, 3), x[1] = ref (B, K, 3) 70 | 71 | Returns: 72 | beta, alpha, weightings 73 | """ 74 | 75 | src_padded = F.pad(x, (0, 1), mode='constant', value=0) 76 | ref_padded = F.pad(y, (0, 1), mode='constant', value=1) 77 | concatenated = torch.cat([src_padded, ref_padded], dim=1) 78 | 79 | prepool_feat = self.prepool(concatenated.permute(0, 2, 1)) 80 | pooled = torch.flatten(self.pooling(prepool_feat), start_dim=-2) 81 | raw_weights = self.postpool(pooled) 82 | 83 | beta = F.softplus(raw_weights[:, 0]) 84 | alpha = F.softplus(raw_weights[:, 1]) 85 | 86 | return beta, alpha 87 | 88 | 89 | class ParameterPredictionNetConstant(nn.Module): 90 | def __init__(self, weights_dim): 91 | """Parameter Prediction Network with single alpha/beta as parameter. 
92 | 93 | See: Ablation study (Table 4) in paper 94 | """ 95 | 96 | super().__init__() 97 | 98 | self._logger = logging.getLogger(self.__class__.__name__) 99 | 100 | self.anneal_weights = nn.Parameter(torch.zeros(2 + np.prod(weights_dim))) 101 | self.weights_dim = weights_dim 102 | 103 | self._logger.info('Predicting weights with dim {}.'.format(self.weights_dim)) 104 | 105 | def forward(self, x): 106 | """Returns beta, gating_weights""" 107 | 108 | batch_size = x[0].shape[0] 109 | raw_weights = self.anneal_weights 110 | beta = F.softplus(raw_weights[0].expand(batch_size)) 111 | alpha = F.softplus(raw_weights[1].expand(batch_size)) 112 | 113 | return beta, alpha 114 | 115 | 116 | def get_prepool(in_dim, out_dim): 117 | """Shared FC part in PointNet before max pooling""" 118 | net = nn.Sequential( 119 | nn.Conv2d(in_dim, out_dim // 2, 1), 120 | nn.GroupNorm(8, out_dim // 2), 121 | nn.ReLU(), 122 | nn.Conv2d(out_dim // 2, out_dim // 2, 1), 123 | nn.GroupNorm(8, out_dim // 2), 124 | nn.ReLU(), 125 | nn.Conv2d(out_dim // 2, out_dim, 1), 126 | nn.GroupNorm(8, out_dim), 127 | nn.ReLU(), 128 | ) 129 | return net 130 | 131 | 132 | def get_postpool(in_dim, out_dim): 133 | """Linear layers in PointNet after max pooling 134 | 135 | Args: 136 | in_dim: Number of input channels 137 | out_dim: Number of output channels. 
Typically smaller than in_dim 138 | 139 | """ 140 | net = nn.Sequential( 141 | nn.Conv1d(in_dim, in_dim, 1), 142 | nn.GroupNorm(8, in_dim), 143 | nn.ReLU(), 144 | nn.Conv1d(in_dim, out_dim, 1), 145 | nn.GroupNorm(8, out_dim), 146 | nn.ReLU(), 147 | nn.Conv1d(out_dim, out_dim, 1), 148 | ) 149 | 150 | return net 151 | 152 | 153 | class FeatExtractionEarlyFusion(nn.Module): 154 | """Feature extraction Module that extracts hybrid features""" 155 | def __init__(self, features, feature_dim, radius, num_neighbors): 156 | super().__init__() 157 | 158 | self._logger = logging.getLogger(self.__class__.__name__) 159 | self._logger.info('Using early fusion, feature dim = {}'.format(feature_dim)) 160 | self.radius = radius 161 | self.n_sample = num_neighbors 162 | 163 | self.features = sorted(features, key=lambda f: _raw_features_order[f]) 164 | self._logger.info('Feature extraction using features {}'.format(', '.join(self.features))) 165 | 166 | # Layers 167 | raw_dim = np.sum([_raw_features_sizes[f] for f in self.features]) # number of channels after concat 168 | self.prepool = get_prepool(raw_dim, feature_dim * 2) 169 | self.postpool = get_postpool(feature_dim * 2, feature_dim) 170 | 171 | def forward(self, xyz, normals): 172 | """Forward pass of the feature extraction network 173 | 174 | Args: 175 | xyz: (B, N, 3) 176 | normals: (B, N, 3) 177 | 178 | Returns: 179 | cluster features (B, N, C) 180 | 181 | """ 182 | features = sample_and_group_multi(-1, self.radius, self.n_sample, xyz, normals) 183 | features['xyz'] = features['xyz'][:, :, None, :] 184 | 185 | # Gate and concat 186 | concat = [] 187 | for i in range(len(self.features)): 188 | f = self.features[i] 189 | expanded = (features[f]).expand(-1, -1, self.n_sample, -1) 190 | concat.append(expanded) 191 | fused_input_feat = torch.cat(concat, -1) 192 | 193 | # Prepool_FC, pool, postpool-FC 194 | new_feat = fused_input_feat.permute(0, 3, 2, 1) # [B, 10, n_sample, N] 195 | new_feat = self.prepool(new_feat) 196 | 197 | 
pooled_feat = torch.max(new_feat, 2)[0] # Max pooling (B, C, N) 198 | 199 | post_feat = self.postpool(pooled_feat) # Post pooling dense layers 200 | cluster_feat = post_feat.permute(0, 2, 1) 201 | cluster_feat = cluster_feat / torch.norm(cluster_feat, dim=-1, keepdim=True) 202 | 203 | return cluster_feat # (B, N, C) 204 | 205 | -------------------------------------------------------------------------------- /modules/DCP/rpmnet_emb/pointnet_util.py: -------------------------------------------------------------------------------- 1 | """Utilities for PointNet related functions 2 | 3 | Modified from: 4 | Pytorch Implementation of PointNet and PointNet++ 5 | https://github.com/yanx27/Pointnet_Pointnet2_pytorch 6 | """ 7 | 8 | import torch 9 | 10 | 11 | def angle_difference(src, dst): 12 | """Calculate angle between each pair of vectors. 13 | Assumes points are l2-normalized to unit length. 14 | 15 | Input: 16 | src: source points, [B, N, C] 17 | dst: target points, [B, M, C] 18 | Output: 19 | dist: per-point square distance, [B, N, M] 20 | """ 21 | B, N, _ = src.shape 22 | _, M, _ = dst.shape 23 | dist = torch.matmul(src, dst.permute(0, 2, 1)) 24 | dist = torch.acos(dist) 25 | 26 | return dist 27 | 28 | 29 | def square_distance(src, dst): 30 | """Calculate Euclid distance between each two points. 
31 | src^T * dst = xn * xm + yn * ym + zn * zm; 32 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 33 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 34 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 35 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 36 | 37 | Args: 38 | src: source points, [B, N, C] 39 | dst: target points, [B, M, C] 40 | Returns: 41 | dist: per-point square distance, [B, N, M] 42 | """ 43 | B, N, _ = src.shape 44 | _, M, _ = dst.shape 45 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 46 | dist += torch.sum(src ** 2, dim=-1)[:, :, None] 47 | dist += torch.sum(dst ** 2, dim=-1)[:, None, :] 48 | return dist 49 | 50 | 51 | def index_points(points, idx): 52 | """Array indexing, i.e. retrieves relevant points based on indices 53 | 54 | Args: 55 | points: input points data_loader, [B, N, C] 56 | idx: sample index data_loader, [B, S]. S can be 2 dimensional 57 | Returns: 58 | new_points:, indexed points data_loader, [B, S, C] 59 | """ 60 | device = points.device 61 | B = points.shape[0] 62 | view_shape = list(idx.shape) 63 | view_shape[1:] = [1] * (len(view_shape) - 1) 64 | repeat_shape = list(idx.shape) 65 | repeat_shape[0] = 1 66 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 67 | new_points = points[batch_indices, idx, :] 68 | return new_points 69 | 70 | 71 | def farthest_point_sample(xyz, npoint): 72 | """Iterative farthest point sampling 73 | 74 | Args: 75 | xyz: pointcloud data_loader, [B, N, C] 76 | npoint: number of samples 77 | Returns: 78 | centroids: sampled pointcloud index, [B, npoint] 79 | """ 80 | device = xyz.device 81 | B, N, C = xyz.shape 82 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 83 | distance = torch.ones(B, N).to(device) * 1e10 84 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 85 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 86 | for i in range(npoint): 87 | centroids[:, i] = farthest 88 | centroid = 
xyz[batch_indices, farthest, :].view(B, 1, 3) 89 | dist = torch.sum((xyz - centroid) ** 2, -1) 90 | mask = dist < distance 91 | distance[mask] = dist[mask] 92 | farthest = torch.max(distance, -1)[1] 93 | return centroids 94 | 95 | 96 | def query_ball_point(radius, nsample, xyz, new_xyz, itself_indices=None): 97 | """ Grouping layer in PointNet++. 98 | 99 | Inputs: 100 | radius: local region radius 101 | nsample: max sample number in local region 102 | xyz: all points, (B, N, C) 103 | new_xyz: query points, (B, S, C) 104 | itself_indices (Optional): Indices of new_xyz into xyz (B, S). 105 | Used to try and prevent grouping the point itself into the neighborhood. 106 | If there is insufficient points in the neighborhood, or if left is none, the resulting cluster will 107 | still contain the center point. 108 | Returns: 109 | group_idx: grouped points index, [B, S, nsample] 110 | """ 111 | device = xyz.device 112 | B, N, C = xyz.shape 113 | _, S, _ = new_xyz.shape 114 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) # (B, S, N) 115 | sqrdists = square_distance(new_xyz, xyz) 116 | 117 | if itself_indices is not None: 118 | # Remove indices of the center points so that it will not be chosen 119 | batch_indices = torch.arange(B, dtype=torch.long).to(device)[:, None].repeat(1, S) # (B, S) 120 | row_indices = torch.arange(S, dtype=torch.long).to(device)[None, :].repeat(B, 1) # (B, S) 121 | group_idx[batch_indices, row_indices, itself_indices] = N 122 | 123 | group_idx[sqrdists > radius ** 2] = N 124 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 125 | if itself_indices is not None: 126 | group_first = itself_indices[:, :, None].repeat([1, 1, nsample]) 127 | else: 128 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 129 | mask = group_idx == N 130 | group_idx[mask] = group_first[mask] 131 | return group_idx 132 | 133 | 134 | def sample_and_group(npoint: int, radius: float, nsample: int, xyz: 
torch.Tensor, points: torch.Tensor, 135 | returnfps: bool=False): 136 | """ 137 | Args: 138 | npoint (int): Set to negative to compute for all points 139 | radius: 140 | nsample: 141 | xyz: input points position data_loader, [B, N, C] 142 | points: input points data_loader, [B, N, D] 143 | returnfps (bool) Whether to return furthest point indices 144 | Returns: 145 | new_xyz: sampled points position data_loader, [B, 1, C] 146 | new_points: sampled points data_loader, [B, 1, N, C+D] 147 | """ 148 | B, N, C = xyz.shape 149 | 150 | if npoint > 0: 151 | S = npoint 152 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 153 | new_xyz = index_points(xyz, fps_idx) 154 | else: 155 | S = xyz.shape[1] 156 | fps_idx = torch.arange(0, xyz.shape[1])[None, ...].repeat(xyz.shape[0], 1) 157 | new_xyz = xyz 158 | 159 | idx = query_ball_point(radius, nsample, xyz, new_xyz) # (B, N, nsample) 160 | grouped_xyz = index_points(xyz, idx) # (B, npoint, nsample, C) 161 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 162 | if points is not None: 163 | grouped_points = index_points(points, idx) 164 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 165 | else: 166 | new_points = grouped_xyz_norm 167 | if returnfps: 168 | return new_xyz, new_points, grouped_xyz, fps_idx 169 | else: 170 | return new_xyz, new_points 171 | 172 | 173 | def angle(v1: torch.Tensor, v2: torch.Tensor): 174 | """Compute angle between 2 vectors 175 | 176 | For robustness, we use the same formulation as in PPFNet, i.e. 177 | angle(v1, v2) = atan2(cross(v1, v2), dot(v1, v2)). 
178 | This handles the case where one of the vectors is 0.0, since torch.atan2(0.0, 0.0)=0.0 179 | 180 | Args: 181 | v1: (B, *, 3) 182 | v2: (B, *, 3) 183 | 184 | Returns: 185 | 186 | """ 187 | 188 | cross_prod = torch.stack([v1[..., 1] * v2[..., 2] - v1[..., 2] * v2[..., 1], 189 | v1[..., 2] * v2[..., 0] - v1[..., 0] * v2[..., 2], 190 | v1[..., 0] * v2[..., 1] - v1[..., 1] * v2[..., 0]], dim=-1) 191 | cross_prod_norm = torch.norm(cross_prod, dim=-1) 192 | dot_prod = torch.sum(v1 * v2, dim=-1) 193 | 194 | return torch.atan2(cross_prod_norm, dot_prod) 195 | 196 | 197 | def sample_and_group_multi(npoint: int, radius: float, nsample: int, xyz: torch.Tensor, normals: torch.Tensor, 198 | returnfps: bool = False, is_centerfeat: bool = False): 199 | """Sample and group for xyz, dxyz and ppf features 200 | 201 | Args: 202 | npoint(int): Number of clusters (equivalently, keypoints) to sample. 203 | Set to negative to compute for all points 204 | radius(int): Radius of cluster for computing local features 205 | nsample: Maximum number of points to consider per cluster 206 | xyz: XYZ coordinates of the points 207 | normals: Corresponding normals for the points (required for ppf computation) 208 | returnfps: Whether to return indices of FPS points and their neighborhood 209 | 210 | Returns: 211 | Dictionary containing the following fields ['xyz', 'dxyz', 'ppf']. 
212 | If returnfps is True, also returns: grouped_xyz, fps_idx 213 | """ 214 | 215 | B, N, C = xyz.shape 216 | 217 | if npoint > 0: 218 | S = npoint 219 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 220 | new_xyz = index_points(xyz, fps_idx) 221 | nr = index_points(normals, fps_idx)[:, :, None, :] 222 | else: 223 | S = xyz.shape[1] 224 | fps_idx = torch.arange(0, xyz.shape[1])[None, ...].repeat(xyz.shape[0], 1).to(xyz.device) 225 | new_xyz = xyz 226 | nr = normals[:, :, None, :] 227 | 228 | idx = query_ball_point(radius, nsample, xyz, new_xyz, fps_idx) # (B, npoint, nsample) 229 | grouped_xyz = index_points(xyz, idx) # (B, npoint, nsample, C) 230 | d = grouped_xyz - new_xyz.view(B, S, 1, C) # d = p_r - p_i (B, npoint, nsample, C) 231 | ni = index_points(normals, idx) 232 | 233 | nr_d = angle(nr, d) 234 | ni_d = angle(ni, d) 235 | nr_ni = angle(nr, ni) 236 | d_norm = torch.norm(d, dim=-1) 237 | 238 | xyz_feat = d # (B, npoint, n_sample, 3) 239 | ppf_feat = torch.stack([nr_d, ni_d, nr_ni, d_norm], dim=-1) # (B, npoint, n_sample, 4) 240 | 241 | if returnfps: 242 | return {'xyz': new_xyz, 'dxyz': xyz_feat, 'ppf': ppf_feat}, grouped_xyz, fps_idx 243 | else: 244 | return {'xyz': new_xyz, 'dxyz': xyz_feat, 'ppf': ppf_feat} 245 | 246 | def sample_and_group_multi_center(npoint: int, radius: float, nsample: int, xyz: torch.Tensor, normals: torch.Tensor, 247 | returnfps: bool = False, is_fake_center=False): 248 | """Sample and group for xyz, dxyz and ppf features 249 | 250 | Args: 251 | npoint(int): Number of clusters (equivalently, keypoints) to sample. 
252 | Set to negative to compute for all points 253 | radius(int): Radius of cluster for computing local features 254 | nsample: Maximum number of points to consider per cluster 255 | xyz: XYZ coordinates of the points 256 | normals: Corresponding normals for the points (required for ppf computation) 257 | returnfps: Whether to return indices of FPS points and their neighborhood 258 | 259 | Returns: 260 | Dictionary containing the following fields ['xyz', 'dxyz', 'ppf']. 261 | If returnfps is True, also returns: grouped_xyz, fps_idx 262 | """ 263 | 264 | B, N, C = xyz.shape 265 | 266 | if npoint > 0: 267 | S = npoint 268 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 269 | new_xyz = index_points(xyz, fps_idx) 270 | nr = index_points(normals, fps_idx)[:, :, None, :] 271 | else: 272 | S = xyz.shape[1] 273 | fps_idx = torch.arange(0, xyz.shape[1])[None, ...].repeat(xyz.shape[0], 1).to(xyz.device) 274 | new_xyz = xyz 275 | nr = normals[:, :, None, :] 276 | 277 | idx = query_ball_point(radius, nsample, xyz, new_xyz, fps_idx) # (B, npoint, nsample) 278 | grouped_xyz = index_points(xyz, idx) # (B, npoint, nsample, C) 279 | d = grouped_xyz - new_xyz.view(B, S, 1, C) # d = p_r - p_i (B, npoint, nsample, C) 280 | ni = index_points(normals, idx) 281 | 282 | nr_d = angle(nr, d) 283 | ni_d = angle(ni, d) 284 | nr_ni = angle(nr, ni) 285 | d_norm = torch.norm(d, dim=-1) 286 | 287 | xyz_feat = d # (B, npoint, n_sample, 3) 288 | ppf_feat = torch.stack([nr_d, ni_d, nr_ni, d_norm], dim=-1) # (B, npoint, n_sample, 4) 289 | 290 | if not is_fake_center: 291 | nr_dd0 = angle(nr, new_xyz[:, :, None, :]).expand(-1, -1, grouped_xyz.shape[2]) # (B, npoint, 1) 292 | nr_dd1 = angle(nr, grouped_xyz) # (B, npoint, n_sample) 293 | ni_dd = angle(ni, grouped_xyz) # (B, npoint, n_sample) 294 | d_norm1 = torch.norm(new_xyz, dim=-1, keepdim=True).expand(-1, -1, grouped_xyz.shape[2]) # (B, npoint, 1) 295 | d_norm2 = torch.norm(grouped_xyz, dim=-1) # (B, npoint, nsample) 296 | 
centerfeat = torch.stack([nr_dd0, nr_dd1, ni_dd, d_norm1, d_norm2], dim=-1) # (B, npoint, n_sample, 5) 297 | else: 298 | fake_center = xyz.mean(1, keepdim=True)[:, :, None, :] # [B, 1, 1, 3] 299 | nr_dd0 = angle(nr, new_xyz[:, :, None, :] - fake_center).expand(-1, -1, grouped_xyz.shape[2]) # (B, npoint, 1) 300 | nr_dd1 = angle(nr, grouped_xyz - fake_center) # (B, npoint, n_sample) 301 | ni_dd = angle(ni, grouped_xyz - fake_center) # (B, npoint, n_sample) 302 | d_norm1 = torch.norm(new_xyz - fake_center.squeeze(2), dim=-1, keepdim=True).expand(-1, -1, grouped_xyz.shape[2]) # (B, npoint, 1) 303 | d_norm2 = torch.norm(grouped_xyz - fake_center, dim=-1) # (B, npoint, nsample) 304 | centerfeat = torch.stack([nr_dd0, nr_dd1, ni_dd, d_norm1, d_norm2], dim=-1) # (B, npoint, n_sample, 5) 305 | 306 | 307 | if returnfps: 308 | return {'xyz': new_xyz, 'dxyz': xyz_feat, 'ppf': ppf_feat, "centerfeat": centerfeat}, grouped_xyz, fps_idx 309 | else: 310 | return {'xyz': new_xyz, 'dxyz': xyz_feat, 'ppf': ppf_feat, "centerfeat": centerfeat} 311 | 312 | # def sample_and_group_multi_center(npoint: int, radius: float, nsample: int, xyz: torch.Tensor, normals: torch.Tensor, returnfps: bool = False): 313 | # """Sample and group for xyz, dxyz and ppf features 314 | # 315 | # Args: 316 | # npoint(int): Number of clusters (equivalently, keypoints) to sample. 317 | # Set to negative to compute for all points 318 | # radius(int): Radius of cluster for computing local features 319 | # nsample: Maximum number of points to consider per cluster 320 | # xyz: XYZ coordinates of the points 321 | # normals: Corresponding normals for the points (required for ppf computation) 322 | # returnfps: Whether to return indices of FPS points and their neighborhood 323 | # 324 | # Returns: 325 | # Dictionary containing the following fields ['xyz', 'dxyz', 'ppf']. 
326 | # If returnfps is True, also returns: grouped_xyz, fps_idx 327 | # """ 328 | # 329 | # B, N, C = xyz.shape 330 | # 331 | # if npoint > 0: 332 | # S = npoint 333 | # fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 334 | # new_xyz = index_points(xyz, fps_idx) 335 | # nr = index_points(normals, fps_idx)[:, :, None, :] 336 | # else: 337 | # S = xyz.shape[1] 338 | # fps_idx = torch.arange(0, xyz.shape[1])[None, ...].repeat(xyz.shape[0], 1).to(xyz.device) 339 | # new_xyz = xyz 340 | # nr = normals[:, :, None, :] 341 | # 342 | # idx = query_ball_point(radius, nsample, xyz, new_xyz, fps_idx) # (B, npoint, nsample) 343 | # grouped_xyz = index_points(xyz, idx) # (B, npoint, nsample, C) 344 | # d = grouped_xyz - new_xyz.view(B, S, 1, C) # d = p_r - p_i (B, npoint, nsample, 3) 345 | # ni = index_points(normals, idx) 346 | # 347 | # nr_d = angle(nr, d) 348 | # ni_d = angle(ni, d) 349 | # nr_ni = angle(nr, ni) 350 | # d_norm = torch.norm(d, dim=-1) 351 | # 352 | # xyz_feat = d # (B, npoint, n_sample, 3) 353 | # ppf_feat = torch.stack([nr_d, ni_d, nr_ni, d_norm], dim=-1) # (B, npoint, n_sample, 4) 354 | # 355 | # if returnfps: 356 | # return {'xyz': new_xyz, 'dxyz': xyz_feat, 'ppf': ppf_feat}, grouped_xyz, fps_idx 357 | # else: 358 | # return {'xyz': new_xyz, 'dxyz': xyz_feat, 'ppf': ppf_feat} -------------------------------------------------------------------------------- /results/DiffusionReg-DiffusionDCP-tudl-diffusion_200_0.00010_0.05_0.05_0.03-nvids3_cosine/eval_results/model_epoch19_T5_cosine_tudl_000001_noiseTrue_v1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiang-HB/DiffusionReg/718122ad6e8c3e2ee4a7376d452f9872e873194f/results/DiffusionReg-DiffusionDCP-tudl-diffusion_200_0.00010_0.05_0.05_0.03-nvids3_cosine/eval_results/model_epoch19_T5_cosine_tudl_000001_noiseTrue_v1.pth -------------------------------------------------------------------------------- 
/results/DiffusionReg-DiffusionDCP-tudl-diffusion_200_0.00010_0.05_0.05_0.03-nvids3_cosine/eval_results/model_epoch19_T5_cosine_tudl_000002_noiseTrue_v1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiang-HB/DiffusionReg/718122ad6e8c3e2ee4a7376d452f9872e873194f/results/DiffusionReg-DiffusionDCP-tudl-diffusion_200_0.00010_0.05_0.05_0.03-nvids3_cosine/eval_results/model_epoch19_T5_cosine_tudl_000002_noiseTrue_v1.pth -------------------------------------------------------------------------------- /results/DiffusionReg-DiffusionDCP-tudl-diffusion_200_0.00010_0.05_0.05_0.03-nvids3_cosine/eval_results/model_epoch19_T5_cosine_tudl_000003_noiseTrue_v1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiang-HB/DiffusionReg/718122ad6e8c3e2ee4a7376d452f9872e873194f/results/DiffusionReg-DiffusionDCP-tudl-diffusion_200_0.00010_0.05_0.05_0.03-nvids3_cosine/eval_results/model_epoch19_T5_cosine_tudl_000003_noiseTrue_v1.pth -------------------------------------------------------------------------------- /results/DiffusionReg-DiffusionDCP-tudl-diffusion_200_0.00010_0.05_0.05_0.03-nvids3_cosine/model_epoch19.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiang-HB/DiffusionReg/718122ad6e8c3e2ee4a7376d452f9872e873194f/results/DiffusionReg-DiffusionDCP-tudl-diffusion_200_0.00010_0.05_0.05_0.03-nvids3_cosine/model_epoch19.pth -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np, random, torch, pdb 2 | from utils.options import opts 3 | from datasets.get_dataset import get_dataset 4 | from collections import OrderedDict 5 | from utils.commons import save_data, load_data 6 | from collections import defaultdict 7 | from utils.criterion 
from utils.criterion import mAP
from tqdm import tqdm
from utils.se_math import se3
from utils.diffusion_scheduler import DiffusionScheduler

# Seed every RNG and pin cuDNN to deterministic kernels so evaluation runs are repeatable.
opts.seed = 1234
np.random.seed(opts.seed)
random.seed(opts.seed)
torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def init_opts(opts):
    """Overwrite ``opts`` with the evaluation-time settings and return it.

    Switches to test mode, selects the cosine schedule, configures a 5-step
    reverse process (beta/sigma values below) and enables both result saving
    and noise injection.
    """
    opts.is_debug = False
    opts.is_test = True
    opts.schedule_type = ["linear", "cosine"][1]

    opts.is_save_res = True
    opts.n_diff_steps = 5
    opts.beta_1 = 0.2
    opts.beta_T = 0.8
    opts.sigma_r = 0.1
    opts.sigma_t = 0.01
    opts.is_add_noise = True

    return opts


def get_model(opts):
    """Build the DCP surrogate model, load its checkpoint and return it in eval mode.

    Side effects on ``opts``: sets ``model_type``, ``model_path``, ``save_path``
    and ``vs`` (the :class:`DiffusionScheduler`).
    """
    opts.model_type = "DiffusionDCP"
    opts.model_path = "./results/DiffusionReg-DiffusionDCP-tudl-diffusion_200_0.00010_0.05_0.05_0.03-nvids3_cosine/model_epoch19.pth"
    opts.save_path = f"./results/DiffusionReg-DiffusionDCP-tudl-diffusion_200_0.00010_0.05_0.05_0.03-nvids3_cosine/eval_results/model_epoch19_T{opts.n_diff_steps}_{opts.schedule_type}_{opts.db_nm}_{opts.vid_infos[0]}_noise{opts.is_add_noise}_v1.pth"

    # model config
    from modules.DCP.dcp import DCP
    surrogate_model = DCP(opts)
    opts.vs = DiffusionScheduler(opts)

    # Read the checkpoint once.  Checkpoints written from a DataParallel
    # wrapper carry a "module." prefix on every key; strip it only when it is
    # really there instead of guessing via a bare ``except``.
    state_dict = torch.load(opts.model_path, map_location=opts.device)
    prefix = "module."
    if state_dict and all(k.startswith(prefix) for k in state_dict):
        state_dict = OrderedDict((k[len(prefix):], v) for k, v in state_dict.items())
    surrogate_model.load_state_dict(state_dict)
    surrogate_model = surrogate_model.to(opts.device)
    surrogate_model.eval()

    print(opts.save_path)
    return surrogate_model


def main(opts):
    """Run the SE(3) reverse (denoising) process over the test set and save the poses.

    For every sample the pose estimate ``H_t`` starts at identity and is
    refined for ``t = T, T-1, ..., 2``.  The inverse of the last predicted
    ``H_0`` is recorded next to the ground truth from the loader.

    NOTE(review): the original indentation was lost in this dump; the nesting
    below (noise inside the step loop, recording once per sample, saving once
    after the loader loop) is the reconstruction -- confirm against the repo.
    """
    ## initial setting
    opts = init_opts(opts)
    if opts.n_diff_steps < 2:
        # The step loop below would never execute and H_0 would be unbound.
        raise ValueError("n_diff_steps must be >= 2, got %r" % (opts.n_diff_steps,))
    surrogate_model = get_model(opts)
    test_loader, test_db = get_dataset(opts, db_nm=opts.db_nm, cls_nm=opts.vid_infos, partition="test", batch_size=1, shuffle=False, drop_last=False, n_cores=4)
    rcd = defaultdict(list)
    # Per-component noise scale: 3 rotation entries then 3 translation entries.  [1, 6]
    scale = torch.cat([torch.ones(3) * opts.sigma_r, torch.ones(3) * opts.sigma_t])[None].to(opts.device)
    with torch.no_grad():
        for data in tqdm(test_loader):
            data = {k: v.to(opts.device).float() for k, v in data.items()}
            X, X_normal = data["src_pcd"].clone(), data["src_pcd_normal"].clone()
            Y, Y_normal = data["model_pcd"].clone(), data["model_pcd_normal"].clone()
            B = len(X)
            H_t = torch.eye(4)[None].expand(B, -1, -1).to(opts.device)  # [B, 4, 4]

            for t in range(opts.n_diff_steps, 1, -1):  # t = T, T-1, ..., 2
                # Apply the current pose estimate to the source points and normals.
                X_t = (H_t[:, :3, :3] @ X.transpose(2, 1) + H_t[:, :3, [3]]).transpose(2, 1)  # [B, N, 3]
                X_normal_t = (H_t[:, :3, :3] @ X_normal.transpose(2, 1)).transpose(2, 1)  # [B, N, 3]
                Rs_pred, ts_pred = surrogate_model.forward({
                    "src_pcd": X_t,
                    "src_pcd_normal": X_normal_t,
                    "model_pcd": Y,
                    "model_pcd_normal": Y_normal
                })
                _delta_H_t = torch.cat([Rs_pred, ts_pred.unsqueeze(-1)], dim=2)  # [B, 3, 4]
                # repeat() (not expand()) so the in-place write below targets real memory.
                delta_H_t = torch.eye(4)[None].repeat(B, 1, 1).to(opts.device)  # [B, 4, 4]
                delta_H_t[:, :3, :] = _delta_H_t
                H_0 = delta_H_t @ H_t

                # Blend the predicted clean pose and the current pose in the Lie algebra.
                gamma0 = opts.vs.gamma0[t]
                gamma1 = opts.vs.gamma1[t]
                H_t = se3.exp(gamma0 * se3.log(H_0) + gamma1 * se3.log(H_t))

                ### noise
                if opts.is_add_noise:
                    alpha_bar = opts.vs.alpha_bars[t]
                    alpha_bar_ = opts.vs.alpha_bars[t-1]
                    beta = opts.vs.betas[t]
                    cc = ((1 - alpha_bar_) / (1.- alpha_bar)) * beta
                    noise = torch.sqrt(cc) * scale * torch.randn(B, 6).to(opts.device)  # [B, 6]
                    H_noise = se3.exp(noise)
                    H_t = H_noise @ H_t  # [B, 4, 4]

            # Invert the final prediction: R^-1 and -R^-1 t.
            Rs_pred = H_0[:, :3, :3]
            ts_pred = H_0[:, :3, 3]
            Rs_pred = torch.inverse(Rs_pred)
            ts_pred = (- Rs_pred @ ts_pred[:, :, None])[:, :, 0]

            rcd["Rs_pred"].extend(list(Rs_pred.cpu().numpy()))
            rcd["ts_pred"].extend(list(ts_pred.cpu().numpy()))
            rcd["Rs_gt"].extend(list(data["R_gt_ms"].cpu().numpy()))
            rcd["ts_gt"].extend(list(data["t_gt_ms"].cpu().numpy()))

    if opts.is_save_res:
        save_data(opts.save_path, rcd)

    print(opts.save_path)


def cal_score():
    """Aggregate the saved per-object results and print rotation/translation mAP.

    Translations are multiplied by ``scale / 10`` before scoring.
    NOTE(review): indentation was lost in this dump; mAP is computed once after
    all result files have been merged -- confirm against the repo.
    """
    res_paths = [
        "./results/DiffusionReg-DiffusionDCP-tudl-diffusion_200_0.00010_0.05_0.05_0.03-nvids3_cosine/eval_results/model_epoch19_T5_cosine_tudl_000001_noiseTrue_v1.pth",
        "./results/DiffusionReg-DiffusionDCP-tudl-diffusion_200_0.00010_0.05_0.05_0.03-nvids3_cosine/eval_results/model_epoch19_T5_cosine_tudl_000002_noiseTrue_v1.pth",
        "./results/DiffusionReg-DiffusionDCP-tudl-diffusion_200_0.00010_0.05_0.05_0.03-nvids3_cosine/eval_results/model_epoch19_T5_cosine_tudl_000003_noiseTrue_v1.pth",
    ]
    scale = 256
    score = defaultdict(list)
    for res_path in res_paths:
        print(res_path)
        res = load_data(res_path)
        score["Rs_gt"].extend(res["Rs_gt"])
        score["Rs_pred"].extend(res["Rs_pred"])
        score["ts_gt"].extend([x * scale / 10. for x in res["ts_gt"]])
        score["ts_pred"].extend([x * scale / 10. for x in res["ts_pred"]])
    auc_R, auc_t = mAP(score["Rs_pred"], score["ts_pred"], score["Rs_gt"], score["ts_gt"])
    print("mAP_R (5/10/20 degree): %.3f, %.3f, %.3f | mAP_t (1/2/5 cm): %.3f, %.3f, %.3f |" % (*auc_R[:3], *auc_t[:3]))
    # Nothing in this file records timings, so averaging unconditionally
    # printed nan with a RuntimeWarning; report only when timings exist.
    if score["times"]:
        print(np.mean(score["times"]))


if __name__ == '__main__':
    opts.db_nm = "tudl"
    for cls_nm in ["000001", "000002", "000003"]:
        opts.vid_infos = [cls_nm]
        main(opts)

    # cal_score()
def init_opts(opts):
    """Populate ``opts`` with the training configuration and return it.

    Also derives ``opts.results_dir`` from the model/dataset/diffusion settings
    and creates that directory.

    Raises:
        NotImplementedError: if ``opts.db_nm`` is not ``"tudl"``.
    """
    opts.is_debug = False
    opts.model_nm = "DiffusionReg"
    opts.is_test = False
    opts.is_normal = True
    opts.n_cores = 6
    opts.schedule_type = ["linear", "cosine"][1]

    # dataset config
    if opts.db_nm == "tudl":
        opts.n_vids = 3
        opts.n_epoches, opts.n_start_epoches, opts.batch_size = 20, 0, 32
        opts.vid_infos = ["000001", "000002", "000003"]
    else:
        raise NotImplementedError

    # diffusion configuration
    opts.n_diff_steps = 200
    opts.beta_1 = 1e-4
    opts.beta_T = 0.05
    opts.sigma_r = 0.05
    opts.sigma_t = 0.03
    diffusion_str = f"diffusion_{opts.n_diff_steps}_{opts.beta_1:.5f}_{opts.beta_T:.2f}_{opts.sigma_r:.2f}_{opts.sigma_t:.2f}"

    opts.results_dir = f"./results/{opts.model_nm}-{opts.net_type}-{opts.db_nm}-{diffusion_str}-nvids{opts.n_vids}_{opts.schedule_type}"
    os.makedirs(opts.results_dir, exist_ok=True)
    print(opts.results_dir)
    return opts


def main(opts):
    """Train the DCP surrogate model with the SE(3) diffusion objective.

    Each batch contributes two losses that are summed: one on the source cloud
    perturbed by a sampled diffusion step plus noise, and one on the
    unperturbed source cloud.  A checkpoint is written every 200 iterations
    and at the end of every epoch unless ``opts.is_debug`` is set.

    NOTE(review): the original indentation was lost in this dump.  The running
    loss print is placed inside the batch loop (it precedes the ``i``-based
    checkpoint test) and ``print(opts.results_dir)`` once per epoch -- confirm
    against the repo.
    """
    opts = init_opts(opts)

    ## model setting
    from modules.DCP.dcp import DCP
    opts.vs = DiffusionScheduler(opts)
    surrogate_model = DCP(opts)

    if torch.cuda.device_count() > 1:
        surrogate_model = torch.nn.DataParallel(surrogate_model, range(torch.cuda.device_count()))
    surrogate_model = surrogate_model.to(opts.device)

    train_loader, train_db = get_dataset(opts, db_nm=opts.db_nm, cls_nm=opts.vid_infos, partition="train",
                                         batch_size=opts.batch_size, shuffle=True, drop_last=True, n_cores=opts.n_cores)
    optimizer = torch.optim.Adam(surrogate_model.parameters(), lr=opts.lr, betas=(0.9, 0.999))
    scheduler = get_lr_scheduler(opts, optimizer)
    cal = lambda x: np.mean(x).item()
    # Per-component noise scale: 3 rotation entries then 3 translation entries.  [1, 6]
    scale = torch.cat([torch.ones(3) * opts.sigma_r, torch.ones(3) * opts.sigma_t])[None].to(opts.device)

    ## training
    for epoch_idx in range(opts.n_epoches):
        ckpt_path = '%s/model_epoch%d.pth' % (opts.results_dir, epoch_idx)

        # train
        surrogate_model.train()

        rcd = defaultdict(list)
        for i, data in enumerate(tqdm(train_loader, 0)):

            data = {k: v.to(opts.device) for k, v in data.items()}

            # model prediction
            X, X_normal = data["src_pcd"], data["src_pcd_normal"]  # [B, N, 3]
            Y, Y_normal = data["model_pcd"], data["model_pcd_normal"]  # [B, M, 3]
            Rs_gt, ts_gt = data['transform_gt'][:, :3, :3], data["transform_gt"][:, :3, 3]  # [B, 3, 3], [B, 3]
            B = Rs_gt.shape[0]

            ### SE(3) diffusion process
            # repeat() (not expand()) so the in-place writes land in real,
            # non-aliased memory regardless of the device.
            H_0 = torch.eye(4)[None].repeat(B, 1, 1).to(opts.device)
            H_0[:, :3, :3], H_0[:, :3, 3] = Rs_gt, ts_gt
            H_T = torch.eye(4)[None].expand(B, -1, -1).to(opts.device)

            taus = opts.vs.uniform_sample_t(B)
            alpha_bars = opts.vs.alpha_bars[taus].to(opts.device)[:, None]  # [B, 1]
            H_t = se3.exp((1. - torch.sqrt(alpha_bars)) * se3.log(H_T @ torch.inverse(H_0))) @ H_0

            ### add noise
            noise = torch.sqrt(1. - alpha_bars) * scale * torch.randn(B, 6).to(opts.device)  # [B, 6]
            H_noise = se3.exp(noise)
            H_t_noise = H_noise @ H_t  # [B, 4, 4]

            T_t_R = H_t_noise[:, :3, :3]  # [B, 3, 3]
            T_t_t = H_t_noise[:, :3, 3]  # [B, 3]

            X_t = (T_t_R @ X.transpose(2, 1) + T_t_t.unsqueeze(-1)).transpose(2, 1)  # [B, N, 3]
            X_normal_t = (T_t_R @ X_normal.transpose(2, 1)).transpose(2, 1)  # [B, N, 3]

            # diffusion loss: predict the pose from the perturbed source cloud
            noisy_inputs = {
                "src_pcd": X_t,
                "src_pcd_normal": X_normal_t,
                "model_pcd": Y,
                "model_pcd_normal": Y_normal,
            }
            Rs_pred_rot, ts_pred_rot = surrogate_model.forward(noisy_inputs)
            pred_transforms = torch.cat([Rs_pred_rot, ts_pred_rot.unsqueeze(-1)], dim=2)  # [B, 3, 4]
            train_losses_diff = compute_losses_diff(opts, X, X_t, [pred_transforms], data['transform_gt'], loss_type="mae", reduction='mean')

            # original loss: predict the pose from the unperturbed source cloud
            clean_inputs = {
                "src_pcd": X,
                "src_pcd_normal": X_normal,
                "model_pcd": Y,
                "model_pcd_normal": Y_normal,
            }
            Rs_pred_rot1, ts_pred_rot1 = surrogate_model.forward(clean_inputs)
            pred_transforms1 = torch.cat([Rs_pred_rot1, ts_pred_rot1.unsqueeze(-1)], dim=2)  # [B, 3, 4]
            train_losses_origin = compute_losses(opts, X, [pred_transforms1], data['transform_gt'], loss_type="mae", reduction='mean')

            loss = train_losses_diff['total'] + train_losses_origin["total"]
            rcd["losses"].append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print("=== Train. Epoch [%d], losses: %1.3f ===" % (epoch_idx, cal(rcd["losses"])))

            # periodic mid-epoch checkpoint
            if i > 0 and i % 200 == 0 and not opts.is_debug:
                print("Save model. %s" % ckpt_path)
                torch.save(surrogate_model.state_dict(), ckpt_path)

        print(opts.results_dir)

        # save model
        if not opts.is_debug:
            print("Save model. %s" % ckpt_path)
            torch.save(surrogate_model.state_dict(), ckpt_path)
        else:
            print("Debug. Not save model.")

        scheduler.step()


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--net_type', default="DiffusionDCP", type=str, choices=['DiffusionDCP'])
    parser.add_argument('--db_nm', default="tudl", type=str, choices=["tudl"])
    args = parser.parse_args()

    opts.net_type = args.net_type
    opts.db_nm = args.db_nm
    main(opts)
def load_data(path):
    """Unpickle and return the object stored at ``path``.

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on files
    produced by :func:`save_data` or another trusted source.
    """
    # ``with`` closes the handle even when unpickling raises.
    with open(path, "rb") as file:
        return pickle.load(file)


def save_data(path, data):
    """Pickle ``data`` to ``path``, overwriting any existing file."""
    with open(path, "wb") as file:
        pickle.dump(data, file)


def cal_normal(pcd, radius=0.1, max_nn=30):
    """Estimate per-point normals of ``pcd`` with Open3D and return them as an array.

    ``radius`` and ``max_nn`` are forwarded to Open3D's hybrid KD-tree search.
    """
    _pcd = o3d.geometry.PointCloud()
    _pcd.points = o3d.utility.Vector3dVector(pcd)
    _pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=radius, max_nn=max_nn))
    return np.asarray(_pcd.normals)


def regularize_pcd(pcd, n_points, is_test):
    """Resample the columns of ``pcd`` to exactly ``n_points`` (float32).

    Column indices are drawn with ``np.random.randint``, so columns may repeat.
    When ``is_test`` is truthy the *global* NumPy RNG is re-seeded with 1
    first, which makes the draw repeatable (and resets the RNG for callers).
    """
    assert is_test is not None
    if is_test:
        np.random.seed(1)
    idxs = np.random.randint(low=0, high=pcd.shape[1], size=n_points, dtype=np.int64)
    return pcd[:, idxs].astype(np.float32)


def crop_pcd(pcd, bbox, offset=0, scale=1.0, is_scale_max=False, is_mask=False):
    """Keep the columns of ``pcd`` strictly inside the axis-aligned hull of ``bbox``.

    A deep copy of ``bbox`` is resized (its ``wlh`` optionally replaced by its
    largest side, then multiplied by ``scale``); the min/max of its corners,
    widened by ``offset``, form the crop limits.  ``bbox`` is not modified.

    Returns the cropped cloud, plus the boolean column mask when ``is_mask``.
    """
    bbox_tmp = copy.deepcopy(bbox)
    if is_scale_max:
        bbox_tmp.wlh = np.asarray([np.max(bbox_tmp.wlh)] * 3)
    bbox_tmp.wlh = bbox_tmp.wlh * scale
    corners = bbox_tmp.corners()
    maxi = np.max(corners, 1) + offset
    mini = np.min(corners, 1) - offset

    xyz = pcd[:3]
    # One vectorised comparison for all three axes instead of six chained ANDs.
    close = np.all((xyz > mini[:3, None]) & (xyz < maxi[:3, None]), axis=0)
    if is_mask:
        return pcd[:, close], close
    return pcd[:, close]


def depth2pcd(depth, min_xy, max_xy, H, W, cam_info, scale=None):
    """Back-project a cropped depth map into an ``[n_pixels, 3]`` point array.

    ``depth`` must have the shape of the crop ``[min_xy[1]:max_xy[1],
    min_xy[0]:max_xy[0]]`` of an ``H x W`` image.  ``cam_info[0..3]`` are used
    as fx, fy, cx, cy in the pinhole equations below; ``depth`` is divided by
    ``scale`` when given, otherwise by ``cam_info[4]``.
    """
    if scale is None:
        depth = depth / cam_info[[4]]  # [H, W]
    else:
        depth = depth / scale  # [H, W]

    # Pixel-coordinate grids for the full image, cropped to the region of interest.
    X_np, Y_np = np.meshgrid(np.arange(W), np.arange(H))  # [H, W]
    rows = slice(min_xy[1], max_xy[1])
    cols = slice(min_xy[0], max_xy[0])
    x = (X_np[rows, cols] - cam_info[2]) * depth / cam_info[0]
    y = (Y_np[rows, cols] - cam_info[3]) * depth / cam_info[1]
    return np.stack([x, y, depth], axis=2).reshape([-1, 3])
class PointCloud:
    """Thin wrapper around a ``[3, N]`` point matrix with in-place geometry helpers."""

    def __init__(self, points):
        """
        Class for manipulating and viewing point clouds.
        :param points: Input point cloud matrix, one point per column. Rows beyond
            the first three (e.g. intensity) are dropped.
        """
        self.points = points
        if self.points.shape[0] > 3:
            self.points = self.points[0:3, :]

    @staticmethod
    def load_pcd_bin(file_name):
        """
        Loads from binary format. Data is stored as (x, y, z, intensity, ring index).
        :param file_name: Path of a flat float32 file with five values per point.
        :return: Point cloud matrix (x, y, z, intensity), one point per column.
        """
        scan = np.fromfile(file_name, dtype=np.float32)
        points = scan.reshape((-1, 5))[:, :4]
        return points.T

    @classmethod
    def from_file(cls, file_name):
        """
        Instantiate from a .npy or .bin file.
        :param file_name: Path of the pointcloud file on disk.
        :return: A new instance of ``cls``.
        :raises ValueError: if the extension is neither ``.bin`` nor ``.npy``.
        """
        if file_name.endswith('.bin'):
            points = cls.load_pcd_bin(file_name)
        elif file_name.endswith('.npy'):
            points = np.load(file_name)
        else:
            raise ValueError('Unsupported filetype {}'.format(file_name))

        return cls(points)

    def nbr_points(self):
        """
        Returns the number of points.
        :return: Number of points (columns of ``self.points``).
        """
        return self.points.shape[1]

    def subsample(self, ratio):
        """
        Sub-samples the pointcloud.
        :param ratio: Fraction to keep.
        :return: self.
        """
        # np.random.choice keeps its default replace=True here, so the same
        # point may be selected more than once.
        selected_ind = np.random.choice(np.arange(0, self.nbr_points()),
                                        size=int(self.nbr_points() * ratio))
        self.points = self.points[:, selected_ind]
        return self

    def remove_close(self, radius):
        """
        Removes points whose |x| and |y| are both below ``radius``.
        :param radius: Half-size of the square (in x/y) around the origin to clear.
        :return: self (the original returned None; returning self matches the
            other mutators and is harmless to existing callers).
        """
        x_filt = np.abs(self.points[0, :]) < radius
        y_filt = np.abs(self.points[1, :]) < radius
        not_close = np.logical_not(np.logical_and(x_filt, y_filt))
        self.points = self.points[:, not_close]
        return self

    def translate(self, x):
        """
        Applies a translation to the point cloud.
        :param x: Translation in x, y, z.
        :return: self.
        """
        for i in range(3):
            self.points[i, :] = self.points[i, :] + x[i]

        return self

    def rotate(self, rot_matrix):
        """
        Applies a rotation.
        :param rot_matrix: Rotation matrix applied on the left of the points.
        :return: self.
        """
        self.points[:3, :] = np.dot(rot_matrix, self.points[:3, :])
        return self

    def transform(self, transf_matrix):
        """
        Applies a homogeneous transform.
        :param transf_matrix: Homogenous transformation matrix.
        :return: self (the original returned None; see ``remove_close``).
        """
        self.points[:3, :] = transf_matrix.dot(
            np.vstack((self.points[:3, :], np.ones(self.nbr_points()))))[:3, :]
        return self

    def convertToPytorch(self):
        """
        Helper for pytorch.
        :return: Pytorch tensor created from the points with ``torch.from_numpy``.
        """
        return torch.from_numpy(self.points)

    @classmethod
    def fromPytorch(cls, pytorchTensor, *legacy):
        """
        Builds a point cloud from a pytorch tensor.

        This used to be a ``@staticmethod`` that still declared ``cls``, so the
        natural call ``PointCloud.fromPytorch(tensor)`` raised ``TypeError``.
        It is now a real classmethod; the old calling convention
        ``PointCloud.fromPytorch(SomeClass, tensor)`` is still accepted.

        :param pytorchTensor: Tensor holding the point matrix.
        :return: A new point cloud wrapping ``pytorchTensor.numpy()``.
        """
        if legacy:
            # Legacy positional form: (class, tensor).
            cls, pytorchTensor = pytorchTensor, legacy[0]
        points = pytorchTensor.numpy()
        # points = points.reshape((-1, 5))[:, :4]
        return cls(points)

    def normalize(self, wlh):
        """
        Divides x, y, z by ``wlh[1]``, ``wlh[0]``, ``wlh[2]`` respectively.
        :param wlh: Box size as (width, length, height).
        :return: self.
        """
        normalizer = [wlh[1], wlh[0], wlh[2]]
        self.points = self.points / np.atleast_2d(normalizer).T
        return self
np.all(np.isnan(other.velocity)))) 164 | 165 | return center and wlh and orientation and label and score and vel 166 | 167 | def __repr__(self): 168 | repr_str = 'label: {}, score: {:.2f}, xyz: [{:.2f}, {:.2f}, {:.2f}], wlh: [{:.2f}, {:.2f}, {:.2f}], ' \ 169 | 'rot axis: [{:.2f}, {:.2f}, {:.2f}], ang(degrees): {:.2f}, ang(rad): {:.2f}, ' \ 170 | 'vel: {:.2f}, {:.2f}, {:.2f}, name: {}' 171 | 172 | return repr_str.format(self.label, self.score, self.center[0], self.center[1], self.center[2], self.wlh[0], 173 | self.wlh[1], self.wlh[2], self.orientation.axis[0], self.orientation.axis[1], 174 | self.orientation.axis[2], self.orientation.degrees, self.orientation.radians, 175 | self.velocity[0], self.velocity[1], self.velocity[2], self.name) 176 | 177 | def encode(self): 178 | """ 179 | Encodes the box instance to a JSON-friendly vector representation. 180 | :return: [: 16]. List of floats encoding the box. 181 | """ 182 | return self.center.tolist() + self.wlh.tolist() + self.orientation.elements.tolist() + [self.label] + [self.score] + self.velocity.tolist() + [self.name] 183 | 184 | @classmethod 185 | def decode(cls, data): 186 | """ 187 | Instantiates a Box instance from encoded vector representation. 188 | :param data: [: 16]. Output from encode. 189 | :return: . 190 | """ 191 | return BBox(data[0:3], data[3:6], Quaternion(data[6:10]), label=data[10], score=data[11], velocity=data[12:15], 192 | name=data[15]) 193 | 194 | @property 195 | def rotation_matrix(self): 196 | """ 197 | Return a rotation matrix. 198 | :return: . 199 | """ 200 | return self.orientation.rotation_matrix 201 | 202 | def translate(self, x): 203 | """ 204 | Applies a translation. 205 | :param x: . Translation in x, y, z direction. 206 | :return: . 207 | """ 208 | self.center += x 209 | return self 210 | 211 | def rotate(self, quaternion): 212 | """ 213 | Rotates box. 214 | :param quaternion: . Rotation to apply. 215 | :return: . 
216 | """ 217 | self.center = np.dot(quaternion.rotation_matrix, self.center) 218 | self.orientation = quaternion * self.orientation 219 | self.velocity = np.dot(quaternion.rotation_matrix, self.velocity) 220 | return self 221 | 222 | def transform(self, transf_matrix): 223 | transformed = np.dot(transf_matrix[0:3,0:4].T, self.center) 224 | self.center = transformed[0:3]/transformed[3] 225 | self.orientation = self.orientation * Quaternion(matrix = transf_matrix[0:3,0:3]) 226 | self.velocity = np.dot(transf_matrix[0:3,0:3], self.velocity) 227 | 228 | def corners(self, wlh_factor=1.0): 229 | """ 230 | Returns the bounding box corners. 231 | :param wlh_factor: . Multiply w, l, h by a factor to inflate or deflate the box. 232 | :return: . First four corners are the ones facing forward. 233 | The last four are the ones facing backwards. 234 | """ 235 | w, l, h = self.wlh * wlh_factor 236 | 237 | # 3D bounding box corners. (Convention: x points forward, y to the left, z up.) 238 | x_corners = l / 2 * np.array([1, 1, 1, 1, -1, -1, -1, -1]) 239 | y_corners = w / 2 * np.array([1, -1, -1, 1, 1, -1, -1, 1]) 240 | z_corners = h / 2 * np.array([1, 1, -1, -1, 1, 1, -1, -1]) 241 | corners = np.vstack((x_corners, y_corners, z_corners)) 242 | 243 | # Rotate 244 | corners = np.dot(self.orientation.rotation_matrix, corners) 245 | 246 | # Translate 247 | x, y, z = self.center 248 | corners[0, :] = corners[0, :] + x 249 | corners[1, :] = corners[1, :] + y 250 | corners[2, :] = corners[2, :] + z 251 | 252 | return corners 253 | 254 | def bottom_corners(self): 255 | """ 256 | Returns the four bottom corners. 257 | :return: . Bottom corners. First two face forward, last two face backwards. 
258 | """ 259 | return self.corners()[:, [2, 3, 7, 6]] 260 | -------------------------------------------------------------------------------- /utils/diffusion_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch, numpy as np 2 | 3 | class DiffusionScheduler(torch.nn.Module): 4 | 5 | def __init__(self, opts): 6 | super().__init__() 7 | self.num_steps = opts.n_diff_steps 8 | self.beta_1 = opts.beta_1 9 | self.beta_T = opts.beta_T 10 | self.mode = opts.schedule_type 11 | 12 | if self.mode == 'linear': 13 | betas = torch.linspace(self.beta_1, self.beta_T, steps=self.num_steps) 14 | elif self.mode == 'cosine': 15 | def betas_fn(s): 16 | T = self.num_steps 17 | 18 | def f(t, T): 19 | return (np.cos((t / T + s) / (1 + s) * np.pi / 2)) ** 2 20 | 21 | alphas = [] 22 | f0 = f(0, T) 23 | for t in range(T + 1): 24 | alphas.append(f(t, T) / f0) 25 | 26 | betas = [] 27 | for t in range(1, T + 1): 28 | betas.append(min(1 - alphas[t] / alphas[t - 1], 0.999)) 29 | return betas 30 | 31 | if opts.S is None: 32 | opts.S = 0.008 33 | betas = betas_fn(s=opts.S) 34 | 35 | betas = torch.FloatTensor(betas) 36 | 37 | self.betas = torch.cat([torch.zeros([1]), betas], dim=0) # Padding 38 | self.alphas = 1 - self.betas 39 | 40 | log_alphas = torch.log(self.alphas) 41 | for i in range(1, log_alphas.size(0)): # 1 to T 42 | log_alphas[i] += log_alphas[i - 1] 43 | self.alpha_bars = log_alphas.exp() 44 | 45 | self.gamma0 = torch.zeros_like(self.betas) 46 | self.gamma1 = torch.zeros_like(self.betas) 47 | self.gamma2 = torch.zeros_like(self.betas) 48 | for t in range(2, self.num_steps + 1): # 2 to T 49 | self.gamma0[t] = self.betas[t] * torch.sqrt(self.alpha_bars[t - 1]) / (1. - self.alpha_bars[t]) 50 | self.gamma1[t] = (1. - self.alpha_bars[t - 1]) * torch.sqrt(self.alphas[t]) / (1. - self.alpha_bars[t]) 51 | self.gamma2[t] = (1. - self.alpha_bars[t - 1]) * self.betas[t] / (1. 
def _registration_losses(gt_src_transformed, src_pcd, pred_transforms, loss_type, reduction):
    """Shared core of ``compute_losses`` and ``compute_losses_diff``.

    Compares ``src_pcd`` moved by every predicted transform against the
    already-transformed ground-truth points and adds a discounted total.
    """
    if loss_type == 'mse':
        # MSE loss to the groundtruth (does not take into account possible symmetries)
        criterion = torch.nn.MSELoss(reduction=reduction)
    elif loss_type == 'mae':
        # MAE (L1) loss to the groundtruth (does not take into account possible symmetries)
        criterion = torch.nn.L1Loss(reduction=reduction)
    else:
        raise NotImplementedError

    mode = reduction.lower()
    if mode not in ('mean', 'none'):
        # Previously this fell through with no loss recorded and failed later
        # inside torch.stack with an unrelated-looking error.
        raise NotImplementedError("unsupported reduction: %r" % (reduction,))

    losses = {}
    num_iter = len(pred_transforms)
    for i in range(num_iter):
        pred_src_transformed = transform(pred_transforms[i], src_pcd)
        value = criterion(pred_src_transformed, gt_src_transformed)
        if mode == 'none':
            # Average over points and coordinates, keep the batch dimension.
            value = torch.mean(value, dim=[-1, -2])
        losses['{}_{}'.format(loss_type, i)] = value

    discount_factor = 0.5  # Early iterations will be discounted
    total_losses = [losses['{}_{}'.format(loss_type, i)] * discount_factor ** (num_iter - i - 1)
                    for i in range(num_iter)]
    losses['total'] = torch.sum(torch.stack(total_losses), dim=0)
    return losses


def compute_losses_diff(opts, src_pcd0, src_pcd, pred_transforms, transform_gt, loss_type = 'mae', reduction = 'mean'):
    """Point-wise registration loss for the diffusion branch.

    The target is ``src_pcd0`` moved by ``transform_gt``; each prediction in
    ``pred_transforms`` is applied to ``src_pcd`` (the perturbed cloud).

    Returns a dict with one ``'<loss_type>_<i>'`` entry per prediction and a
    ``'total'`` entry in which prediction ``i`` is weighted by
    ``0.5 ** (len(pred_transforms) - i - 1)``.  ``opts`` is unused.

    Raises:
        NotImplementedError: for a ``loss_type`` other than ``'mse'``/``'mae'``
            or a ``reduction`` other than ``'mean'``/``'none'``.
    """
    gt_src_transformed = transform(transform_gt, src_pcd0)
    return _registration_losses(gt_src_transformed, src_pcd, pred_transforms, loss_type, reduction)


def compute_losses(opts, src_pcd, pred_transforms, transform_gt, loss_type = 'mae', reduction = 'mean'):
    """Point-wise registration loss between predicted and ground-truth alignments.

    Both the target (``transform_gt``) and every prediction are applied to the
    same ``src_pcd``.  Return value and errors are as in
    ``compute_losses_diff``.  ``opts`` is unused.
    """
    gt_src_transformed = transform(transform_gt, src_pcd)
    return _registration_losses(gt_src_transformed, src_pcd, pred_transforms, loss_type, reduction)


def transform(g, a, normals=None):
    """Apply rigid transform(s) ``g`` to points ``a`` (and optionally rotate normals).

    ``g[..., :3, :3]`` is used as the rotation and ``g[..., :3, 3]`` as the
    translation; ``a`` holds one point per row in its last two dimensions.
    Only the case where ``g`` and ``a`` have the same number of dimensions is
    supported.

    Returns the transformed points, or ``(points, rotated_normals)`` when
    ``normals`` is given.

    Raises:
        NotImplementedError: if ``g`` and ``a`` differ in number of dimensions.
    """
    R = g[..., :3, :3]  # (B, 3, 3)
    p = g[..., :3, 3]  # (B, 3)

    if g.dim() != a.dim():
        raise NotImplementedError
    b = torch.matmul(a, R.transpose(-1, -2)) + p[..., None, :]

    if normals is not None:
        rotated_normals = normals @ R.transpose(-1, -2)
        return b, rotated_normals

    return b
import invmat, se3, sinc, so3, transforms 3 | 4 | #EOF -------------------------------------------------------------------------------- /utils/se_math/invmat.py: -------------------------------------------------------------------------------- 1 | """ inverse matrix """ 2 | 3 | import torch 4 | 5 | 6 | def batch_inverse(x): 7 | """ M(n) -> M(n); x -> x^-1 """ 8 | batch_size, h, w = x.size() 9 | assert h == w 10 | y = torch.zeros_like(x) 11 | for i in range(batch_size): 12 | y[i, :, :] = x[i, :, :].inverse() 13 | return y 14 | 15 | 16 | def batch_inverse_dx(y): 17 | """ backward """ 18 | # Let y(x) = x^-1. 19 | # compute dy 20 | # dy = dy(j,k) 21 | # = - y(j,m) * dx(m,n) * y(n,k) 22 | # = - y(j,m) * y(n,k) * dx(m,n) 23 | # therefore, 24 | # dy(j,k)/dx(m,n) = - y(j,m) * y(n,k) 25 | batch_size, h, w = y.size() 26 | assert h == w 27 | # compute dy(j,k,m,n) = dy(j,k)/dx(m,n) = - y(j,m) * y(n,k) 28 | # = - (y(j,:))' * y'(k,:) 29 | yl = y.repeat(1, 1, h).view(batch_size * h * h, h, 1) 30 | yr = y.transpose(1, 2).repeat(1, h, 1).view(batch_size * h * h, 1, h) 31 | dy = - yl.bmm(yr).view(batch_size, h, h, h, h) 32 | 33 | # compute dy(m,n,j,k) = dy(j,k)/dx(m,n) = - y(j,m) * y(n,k) 34 | # = - (y'(m,:))' * y(n,:) 35 | # yl = y.transpose(1, 2).repeat(1, 1, h).view(batch_size*h*h, h, 1) 36 | # yr = y.repeat(1, h, 1).view(batch_size*h*h, 1, h) 37 | # dy = - yl.bmm(yr).view(batch_size, h, h, h, h) 38 | 39 | return dy 40 | 41 | 42 | def batch_pinv_dx(x): 43 | """ returns y = (x'*x)^-1 * x' and dy/dx. 
""" 44 | # y = (x'*x)^-1 * x' 45 | # = s^-1 * x' 46 | # = b * x' 47 | # d{y(j,k)}/d{x(m,n)} 48 | # = d{b(j,i) * x(k,i)}/d{x(m,n)} 49 | # = d{b(j,i)}/d{x(m,n)} * x(k,i) + b(j,i) * d{x(k,i)}/d{x(m,n)} 50 | # d{b(j,i)}/d{x(m,n)} 51 | # = d{b(j,i)}/d{s(p,q)} * d{s(p,q)}/d{x(m,n)} 52 | # = -b(j,p)*b(q,i) * d{s(p,q)}/d{x(m,n)} 53 | # d{s(p,q)}/d{x(m,n)} 54 | # = d{x(t,p)*x(t,q)}/d{x(m,n)} 55 | # = d{x(t,p)}/d{x(m,n)} * x(t,q) + x(t,p) * d{x(t,q)}/d{x(m,n)} 56 | batch_size, h, w = x.size() 57 | xt = x.transpose(1, 2) 58 | s = xt.bmm(x) 59 | b = batch_inverse(s) 60 | y = b.bmm(xt) 61 | 62 | # dx/dx 63 | ex = torch.eye(h * w).to(x).unsqueeze(0).view(1, h, w, h, w) 64 | # ds/dx = dx(t,_)/dx * x(t,_) + x(t,_) * dx(t,_)/dx 65 | ex1 = ex.view(1, h, w * h * w) # [t, p*m*n] 66 | dx1 = x.transpose(1, 2).matmul(ex1).view(batch_size, w, w, h, w) # [q, p,m,n] 67 | ds_dx = dx1.transpose(1, 2) + dx1 # [p, q, m, n] 68 | # db/ds 69 | db_ds = batch_inverse_dx(b) # [j, i, p, q] 70 | # db/dx = db/d{s(p,q)} * d{s(p,q)}/dx 71 | db1 = db_ds.view(batch_size, w * w, w * w).bmm(ds_dx.view(batch_size, w * w, h * w)) 72 | db_dx = db1.view(batch_size, w, w, h, w) # [j, i, m, n] 73 | # dy/dx = db(_,i)/dx * x(_,i) + b(_,i) * dx(_,i)/dx 74 | dy1 = db_dx.transpose(1, 2).contiguous().view(batch_size, w, w * h * w) 75 | dy1 = x.matmul(dy1).view(batch_size, h, w, h, w) # [k, j, m, n] 76 | ext = ex.transpose(1, 2).contiguous().view(1, w, h * h * w) 77 | dy2 = b.matmul(ext).view(batch_size, w, h, h, w) # [j, k, m, n] 78 | dy_dx = dy1.transpose(1, 2) + dy2 79 | 80 | return y, dy_dx 81 | 82 | 83 | class InvMatrix(torch.autograd.Function): 84 | """ M(n) -> M(n); x -> x^-1. 
class Mesh:
    """Minimal mesh container: vertices, faces and edges.

    The normalisation and rotation helpers modify the mesh in place and
    return ``self`` so calls can be chained.
    """

    def __init__(self):
        self._vertices = []  # array-like (N, D)
        self._faces = []  # array-like (M, K)
        self._edges = []  # array-like (L, 2)

    def clone(self):
        """Return an independent deep copy of this mesh."""
        return copy.deepcopy(self)

    def clear(self):
        """Reset every attribute (including ones added later) to an empty list."""
        for key in self.__dict__:
            self.__dict__[key] = []

    def add_attr(self, name):
        """Add (or reset) an attribute ``name`` holding an empty list."""
        self.__dict__[name] = []

    @property
    def vertex_array(self):
        """Vertices as a freshly allocated numpy array."""
        return numpy.array(self._vertices)

    @property
    def vertex_list(self):
        """Vertices as a list of tuples."""
        return list(map(tuple, self._vertices))

    @staticmethod
    def faces2polygons(faces, vertices):
        """Replace every vertex index in ``faces`` by the vertex itself."""
        return [[vertices[vidx] for vidx in face] for face in faces]

    @property
    def polygon_list(self):
        """Faces expressed as lists of vertices."""
        return Mesh.faces2polygons(self._faces, self._vertices)

    def _float_vertices(self):
        # Normalisation divides in place; on an integer array that would
        # silently truncate the result, so promote to float first.
        v = self.vertex_array
        if v.size and not numpy.issubdtype(v.dtype, numpy.floating):
            v = v.astype(float)
        return v

    def plot(self, fig=None, ax=None, *args, **kwargs):
        """Draw faces and vertices on a 3-d axes.

        Extra positional/keyword arguments are forwarded to ``ax.scatter``.
        Returns ``(fig, ax)``.
        """
        p = self.polygon_list
        v = self.vertex_array
        if fig is None:
            fig = matplotlib.pyplot.gcf()
        if ax is None:
            # Axes3D(fig) is no longer attached to the figure automatically
            # on current Matplotlib; add_subplot works on old and new versions.
            ax = fig.add_subplot(projection='3d')
        if p:
            ax.add_collection3d(Poly3DCollection(p))
        if v.size:  # v.shape is always truthy, even for an empty mesh
            ax.scatter(v[:, 0], v[:, 1], v[:, 2], *args, **kwargs)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        return fig, ax

    def on_unit_sphere(self, zero_mean=False):
        """Scale xyz so the farthest vertex lies on the unit sphere.

        With ``zero_mean`` the xyz centroid is moved to the origin first.
        Columns beyond the first three are left untouched. An empty mesh, or
        one whose vertices all sit at the origin, is left unscaled.
        """
        v = self._float_vertices()  # (N, D)
        if v.size == 0:
            return self
        if zero_mean:
            a = numpy.mean(v[:, 0:3], axis=0, keepdims=True)  # (1, 3)
            v[:, 0:3] = v[:, 0:3] - a
        n = numpy.linalg.norm(v[:, 0:3], axis=1)  # (N,)
        m = numpy.max(n)  # scalar
        if m > 0:
            v[:, 0:3] = v[:, 0:3] / m
        self._vertices = v
        return self

    def on_unit_cube(self, zero_mean=False):
        """Scale xyz so every coordinate lies in [-0.5, 0.5].

        With ``zero_mean`` the xyz centroid is moved to the origin first.
        The scale is taken from the xyz columns only, the same columns that
        are rescaled. An empty or all-zero mesh is left unscaled.
        """
        v = self._float_vertices()  # (N, D)
        if v.size == 0:
            return self
        if zero_mean:
            a = numpy.mean(v[:, 0:3], axis=0, keepdims=True)  # (1, 3)
            v[:, 0:3] = v[:, 0:3] - a
        m = numpy.max(numpy.abs(v[:, 0:3]))  # scalar
        if m > 0:
            v[:, 0:3] = v[:, 0:3] / (m * 2)
        self._vertices = v
        return self

    def rot_x(self):
        """Map (x, y, z) -> (x, -z, y)."""
        # camera local (up: +Y, front: -Z) -> model local (up: +Z, front: +Y).
        v = self.vertex_array
        if v.size == 0:
            return self
        t = numpy.copy(v[:, 1])
        v[:, 1] = -numpy.copy(v[:, 2])
        v[:, 2] = t
        self._vertices = list(map(tuple, v))
        return self

    def rot_zc(self):
        """Map (x, y, z) -> (-y, x, z)."""
        # R = [0, -1;
        #      1,  0]
        v = self.vertex_array
        if v.size == 0:
            return self
        x = numpy.copy(v[:, 0])
        y = numpy.copy(v[:, 1])
        v[:, 0] = -y
        v[:, 1] = x
        self._vertices = list(map(tuple, v))
        return self
def objread(filepath, points_only=True):
    """Load a Wavefront OBJ file into a Mesh.

    Only vertex positions (``v``) and, when ``points_only`` is false, face
    vertex indices (``f``) are kept; normals, texture coordinates and
    material statements are ignored.

    Args:
        filepath: Path of the OBJ file.
        points_only: When true, faces are not parsed and the returned mesh
            holds vertices only.

    Returns:
        Mesh whose ``_vertices`` is a list of float tuples and whose
        ``_faces`` (if requested) is a list of 0-based vertex index lists.
    """
    vertices = []
    faces = []

    # The context manager guarantees the handle is closed, also on a parse error.
    with open(filepath, "r") as fin:
        for line in fin:
            if line.startswith('#'):
                continue
            values = line.split()
            if not values:
                continue
            if values[0] == 'v':
                vertices.append(tuple(map(float, values[1:4])))
            elif values[0] == 'f' and not points_only:
                face = []
                for ref in values[1:]:
                    # Each reference is "v", "v/vt", "v//vn" or "v/vt/vn".
                    idx = int(ref.split('/')[0])
                    if idx < 0:
                        # Negative indices count back from the vertices
                        # read so far (-1 is the most recent one).
                        face.append(len(vertices) + idx)
                    else:
                        face.append(idx - 1)  # OBJ indices are 1-based
                faces.append(face)

    mesh = Mesh()
    mesh._vertices = vertices
    if points_only:
        return mesh

    mesh._faces = faces

    return mesh
matplotlib.pyplot.show() 250 | 251 | 252 | def test2(): 253 | mesh = plyread('1.ply', points_only=True) 254 | # mesh.on_unit_sphere() 255 | mesh.rot_x() 256 | mesh.plot(c='m') 257 | matplotlib.pyplot.show() 258 | 259 | 260 | def make_open3d_point_cloud(xyz, color=None): 261 | pcd = o3d.geometry.PointCloud() 262 | pcd.points = o3d.utility.Vector3dVector(xyz) 263 | if color is not None: 264 | if len(color) != len(xyz): 265 | color = np.tile(color, (len(xyz), 1)) 266 | pcd.colors = o3d.utility.Vector3dVector(color) 267 | return pcd 268 | 269 | def test3(): 270 | mesh = offread("../data/bed.off",False) 271 | mesh = offread_uniformed("../data/bed.off") 272 | points = mesh.vertex_array 273 | p1 = np.asarray(points) 274 | pcd = make_open3d_point_cloud(p1) 275 | o3d.visualization.draw_geometries([pcd]) 276 | 277 | 278 | test3() 279 | 280 | # EOF 281 | -------------------------------------------------------------------------------- /utils/se_math/se3.py: -------------------------------------------------------------------------------- 1 | """ 3-d rigid body transfomation group and corresponding Lie algebra. """ 2 | import torch 3 | from .sinc import sinc1, sinc2, sinc3 4 | from . 
import so3 5 | 6 | 7 | def twist_prod(x, y): 8 | x_ = x.view(-1, 6) 9 | y_ = y.view(-1, 6) 10 | 11 | xw, xv = x_[:, 0:3], x_[:, 3:6] 12 | yw, yv = y_[:, 0:3], y_[:, 3:6] 13 | 14 | zw = so3.cross_prod(xw, yw) 15 | zv = so3.cross_prod(xw, yv) + so3.cross_prod(xv, yw) 16 | 17 | z = torch.cat((zw, zv), dim=1) 18 | 19 | return z.view_as(x) 20 | 21 | 22 | def liebracket(x, y): 23 | return twist_prod(x, y) 24 | 25 | 26 | def mat(x): 27 | # size: [*, 6] -> [*, 4, 4] 28 | x_ = x.view(-1, 6) 29 | w1, w2, w3 = x_[:, 0], x_[:, 1], x_[:, 2] 30 | v1, v2, v3 = x_[:, 3], x_[:, 4], x_[:, 5] 31 | O = torch.zeros_like(w1) 32 | 33 | X = torch.stack(( 34 | torch.stack((O, -w3, w2, v1), dim=1), 35 | torch.stack((w3, O, -w1, v2), dim=1), 36 | torch.stack((-w2, w1, O, v3), dim=1), 37 | torch.stack((O, O, O, O), dim=1)), dim=1) 38 | return X.view(*(x.size()[0:-1]), 4, 4) 39 | 40 | 41 | def vec(X): 42 | X_ = X.view(-1, 4, 4) 43 | w1, w2, w3 = X_[:, 2, 1], X_[:, 0, 2], X_[:, 1, 0] 44 | v1, v2, v3 = X_[:, 0, 3], X_[:, 1, 3], X_[:, 2, 3] 45 | x = torch.stack((w1, w2, w3, v1, v2, v3), dim=1) 46 | return x.view(*X.size()[0:-2], 6) 47 | 48 | 49 | def genvec(): 50 | return torch.eye(6) 51 | 52 | 53 | def genmat(): 54 | return mat(genvec()) 55 | 56 | 57 | def exp(x): 58 | x_ = x.view(-1, 6) 59 | w, v = x_[:, 0:3], x_[:, 3:6] 60 | t = w.norm(p=2, dim=1).view(-1, 1, 1) 61 | W = so3.mat(w) 62 | S = W.bmm(W) 63 | I = torch.eye(3).to(w) 64 | 65 | # Rodrigues' rotation formula. 
66 | # R = cos(t)*eye(3) + sinc1(t)*W + sinc2(t)*(w*w'); 67 | # = eye(3) + sinc1(t)*W + sinc2(t)*S 68 | R = I + sinc1(t) * W + sinc2(t) * S 69 | 70 | # V = sinc1(t)*eye(3) + sinc2(t)*W + sinc3(t)*(w*w') 71 | # = eye(3) + sinc2(t)*W + sinc3(t)*S 72 | V = I + sinc2(t) * W + sinc3(t) * S 73 | 74 | p = V.bmm(v.contiguous().view(-1, 3, 1)) 75 | 76 | z = torch.Tensor([0, 0, 0, 1]).view(1, 1, 4).repeat(x_.size(0), 1, 1).to(x) 77 | Rp = torch.cat((R, p), dim=2) 78 | g = torch.cat((Rp, z), dim=1) 79 | 80 | return g.view(*(x.size()[0:-1]), 4, 4) 81 | 82 | 83 | def inverse(g): 84 | g_ = g.view(-1, 4, 4) 85 | R = g_[:, 0:3, 0:3] 86 | p = g_[:, 0:3, 3] 87 | Q = R.transpose(1, 2) 88 | q = -Q.matmul(p.unsqueeze(-1)) 89 | 90 | z = torch.Tensor([0, 0, 0, 1]).view(1, 1, 4).repeat(g_.size(0), 1, 1).to(g) 91 | Qq = torch.cat((Q, q), dim=2) 92 | ig = torch.cat((Qq, z), dim=1) 93 | 94 | return ig.view(*(g.size()[0:-2]), 4, 4) 95 | 96 | 97 | def log(g): 98 | g_ = g.view(-1, 4, 4) 99 | R = g_[:, 0:3, 0:3] 100 | p = g_[:, 0:3, 3] 101 | 102 | w = so3.log(R) 103 | H = so3.inv_vecs_Xg_ig(w) 104 | v = H.bmm(p.contiguous().view(-1, 3, 1)).view(-1, 3) 105 | 106 | x = torch.cat((w, v), dim=1) 107 | return x.view(*(g.size()[0:-2]), 6) 108 | 109 | 110 | def transform(g, a): 111 | # g : SE(3), * x 4 x 4 112 | # a : R^3, * x 3[x N] 113 | g_ = g.view(-1, 4, 4) 114 | R = g_[:, 0:3, 0:3].contiguous().view(*(g.size()[0:-2]), 3, 3) 115 | p = g_[:, 0:3, 3].contiguous().view(*(g.size()[0:-2]), 3) 116 | if len(g.size()) == len(a.size()): 117 | b = R.matmul(a) + p.unsqueeze(-1) 118 | else: 119 | b = R.matmul(a.unsqueeze(-1)).squeeze(-1) + p 120 | return b 121 | 122 | 123 | def group_prod(g, h): 124 | # g, h : SE(3) 125 | g1 = g.matmul(h) 126 | return g1 127 | 128 | 129 | class ExpMap(torch.autograd.Function): 130 | """ Exp: se(3) -> SE(3) 131 | """ 132 | 133 | @staticmethod 134 | def forward(ctx, x): 135 | """ Exp: R^6 -> M(4), 136 | size: [B, 6] -> [B, 4, 4], 137 | or [B, 1, 6] -> [B, 1, 4, 4] 138 | """ 139 
# Below this |t| the closed forms divide by (nearly) zero, so a truncated
# Taylor series is used instead.
_SERIES_THRESHOLD = 0.01


def _piecewise(t, series, closed_form):
    """Evaluate `series` where |t| < threshold and `closed_form` elsewhere.

    Both callables receive only the selected elements of `t`; the result has
    the shape of `t`.
    """
    r = torch.zeros_like(t)
    s = torch.abs(t) < _SERIES_THRESHOLD
    c = (s == 0)
    r[s] = series(t[s])
    r[c] = closed_form(t[c])
    return r


def sinc1(t):
    """ sinc1: t -> sin(t)/t """
    def series(u):
        u2 = u ** 2
        return 1 - u2 / 6 * (1 - u2 / 20 * (1 - u2 / 42))  # Taylor series O(t^8)

    return _piecewise(t, series, lambda u: torch.sin(u) / u)


def sinc1_dt(t):
    """ d/dt(sinc1) """
    def series(u):
        u2 = u ** 2
        return -u / 3 * (1 - u2 / 10 * (1 - u2 / 28 * (1 - u2 / 54)))  # Taylor series O(t^8)

    def closed_form(u):
        return torch.cos(u) / u - torch.sin(u) / u ** 2

    return _piecewise(t, series, closed_form)


def sinc1_dt_rt(t):
    """ d/dt(sinc1) / t """
    def series(u):
        u2 = u ** 2
        return -1 / 3 * (1 - u2 / 10 * (1 - u2 / 28 * (1 - u2 / 54)))  # Taylor series O(t^8)

    def closed_form(u):
        return (torch.cos(u) / u - torch.sin(u) / u ** 2) / u

    return _piecewise(t, series, closed_form)


def rsinc1(t):
    """ rsinc1: t -> t/sin(t)  (= 1/sinc1(t)) """
    def series(u):
        u2 = u ** 2
        return (((31 * u2) / 42 + 7) * u2 / 60 + 1) * u2 / 6 + 1  # Taylor series O(t^8)

    return _piecewise(t, series, lambda u: u / torch.sin(u))


def rsinc1_dt(t):
    """ d/dt(rsinc1) """
    def series(u):
        u2 = u ** 2
        return ((((127 * u2) / 30 + 31) * u2 / 28 + 7) * u2 / 30 + 1) * u / 3  # Taylor series O(t^8)

    def closed_form(u):
        return 1 / torch.sin(u) - (u * torch.cos(u)) / (torch.sin(u) * torch.sin(u))

    return _piecewise(t, series, closed_form)


def rsinc1_dt_csc(t):
    """ d/dt(rsinc1) / sin(t) """
    def series(u):
        u2 = u ** 2
        return u2 * (u2 * ((4 * u2) / 675 + 2 / 63) + 2 / 15) + 1 / 3  # Taylor series O(t^8)

    def closed_form(u):
        return (1 / torch.sin(u) - (u * torch.cos(u)) / (torch.sin(u) * torch.sin(u))) / torch.sin(u)

    return _piecewise(t, series, closed_form)


def sinc2(t):
    """ sinc2: t -> (1 - cos(t)) / (t**2) """
    def series(u):
        u2 = u ** 2
        return 1 / 2 * (1 - u2 / 12 * (1 - u2 / 30 * (1 - u2 / 56)))  # Taylor series O(t^8)

    return _piecewise(t, series, lambda u: (1 - torch.cos(u)) / u ** 2)


def sinc2_dt(t):
    """ d/dt(sinc2) """
    def series(u):
        u2 = u ** 2
        return -u / 12 * (1 - u2 / 5 * (1.0 / 3 - u2 / 56 * (1.0 / 2 - u2 / 135)))  # Taylor series O(t^8)

    def closed_form(u):
        u2 = u ** 2
        return torch.sin(u) / u2 - 2 * (1 - torch.cos(u)) / (u2 * u)

    return _piecewise(t, series, closed_form)


def sinc3(t):
    """ sinc3: t -> (t - sin(t)) / (t**3) """
    def series(u):
        u2 = u ** 2
        return 1 / 6 * (1 - u2 / 20 * (1 - u2 / 42 * (1 - u2 / 72)))  # Taylor series O(t^8)

    return _piecewise(t, series, lambda u: (u - torch.sin(u)) / (u ** 3))


def sinc3_dt(t):
    """ d/dt(sinc3) """
    def series(u):
        u2 = u ** 2
        return -u / 60 * (1 - u2 / 21 * (1 - u2 / 24 * (1.0 / 2 - u2 / 165)))  # Taylor series O(t^8)

    def closed_form(u):
        return (3 * torch.sin(u) - u * (torch.cos(u) + 2)) / (u ** 4)

    return _piecewise(t, series, closed_form)


def sinc4(t):
    """ sinc4: t -> 1/t^2 * (1/2 - sinc2(t))
                  = 1/t^2 * (1/2 - (1 - cos(t))/t^2)
    """
    # The series and the closed form are each evaluated on their own subset
    # of t, and the result is returned (the previous version mixed masked
    # and unmasked tensors and returned nothing).
    def series(u):
        u2 = u ** 2
        return 1 / 24 * (1 - u2 / 30 * (1 - u2 / 56 * (1 - u2 / 90)))  # Taylor series O(t^8)

    def closed_form(u):
        u2 = u ** 2
        return (0.5 - (1 - torch.cos(u)) / u2) / u2

    return _piecewise(t, series, closed_form)


class Sinc1_autograd(torch.autograd.Function):
    """ sinc1 with the analytic derivative sinc1_dt as backward. """

    @staticmethod
    def forward(ctx, theta):
        ctx.save_for_backward(theta)
        return sinc1(theta)

    @staticmethod
    def backward(ctx, grad_output):
        theta, = ctx.saved_tensors
        grad_theta = None
        if ctx.needs_input_grad[0]:
            grad_theta = grad_output * sinc1_dt(theta).to(grad_output)
        return grad_theta


Sinc1 = Sinc1_autograd.apply


class RSinc1_autograd(torch.autograd.Function):
    """ rsinc1 with the analytic derivative rsinc1_dt as backward. """

    @staticmethod
    def forward(ctx, theta):
        ctx.save_for_backward(theta)
        return rsinc1(theta)

    @staticmethod
    def backward(ctx, grad_output):
        theta, = ctx.saved_tensors
        grad_theta = None
        if ctx.needs_input_grad[0]:
            grad_theta = grad_output * rsinc1_dt(theta).to(grad_output)
        return grad_theta


RSinc1 = RSinc1_autograd.apply


class Sinc2_autograd(torch.autograd.Function):
    """ sinc2 with the analytic derivative sinc2_dt as backward. """

    @staticmethod
    def forward(ctx, theta):
        ctx.save_for_backward(theta)
        return sinc2(theta)

    @staticmethod
    def backward(ctx, grad_output):
        theta, = ctx.saved_tensors
        grad_theta = None
        if ctx.needs_input_grad[0]:
            grad_theta = grad_output * sinc2_dt(theta).to(grad_output)
        return grad_theta


Sinc2 = Sinc2_autograd.apply


class Sinc3_autograd(torch.autograd.Function):
    """ sinc3 with the analytic derivative sinc3_dt as backward. """

    @staticmethod
    def forward(ctx, theta):
        ctx.save_for_backward(theta)
        return sinc3(theta)

    @staticmethod
    def backward(ctx, grad_output):
        theta, = ctx.saved_tensors
        grad_theta = None
        if ctx.needs_input_grad[0]:
            grad_theta = grad_output * sinc3_dt(theta).to(grad_output)
        return grad_theta


Sinc3 = Sinc3_autograd.apply
import sinc 4 | from .sinc import sinc1, sinc2, sinc3 5 | 6 | 7 | def cross_prod(x, y): 8 | z = torch.cross(x.view(-1, 3), y.view(-1, 3), dim=1).view_as(x) 9 | return z 10 | 11 | 12 | def liebracket(x, y): 13 | return cross_prod(x, y) 14 | 15 | 16 | def mat(x): 17 | # size: [*, 3] -> [*, 3, 3] 18 | x_ = x.view(-1, 3) 19 | x1, x2, x3 = x_[:, 0], x_[:, 1], x_[:, 2] 20 | O = torch.zeros_like(x1) 21 | 22 | X = torch.stack(( 23 | torch.stack((O, -x3, x2), dim=1), 24 | torch.stack((x3, O, -x1), dim=1), 25 | torch.stack((-x2, x1, O), dim=1)), dim=1) 26 | return X.view(*(x.size()[0:-1]), 3, 3) 27 | 28 | 29 | def vec(X): 30 | X_ = X.view(-1, 3, 3) 31 | x1, x2, x3 = X_[:, 2, 1], X_[:, 0, 2], X_[:, 1, 0] 32 | x = torch.stack((x1, x2, x3), dim=1) 33 | return x.view(*X.size()[0:-2], 3) 34 | 35 | 36 | def genvec(): 37 | return torch.eye(3) 38 | 39 | 40 | def genmat(): 41 | return mat(genvec()) 42 | 43 | 44 | def RodriguesRotation(x): 45 | # for autograd 46 | w = x.view(-1, 3) 47 | t = w.norm(p=2, dim=1).view(-1, 1, 1) 48 | W = mat(w) 49 | S = W.bmm(W) 50 | I = torch.eye(3).to(w) 51 | 52 | # Rodrigues' rotation formula. 53 | # R = cos(t)*eye(3) + sinc1(t)*W + sinc2(t)*(w*w'); 54 | # R = eye(3) + sinc1(t)*W + sinc2(t)*S 55 | 56 | R = I + sinc.Sinc1(t) * W + sinc.Sinc2(t) * S 57 | 58 | return R.view(*(x.size()[0:-1]), 3, 3) 59 | 60 | 61 | def exp(x): 62 | w = x.view(-1, 3) 63 | t = w.norm(p=2, dim=1).view(-1, 1, 1) 64 | W = mat(w) 65 | S = W.bmm(W) 66 | I = torch.eye(3).to(w) 67 | 68 | # Rodrigues' rotation formula. 
def inverse(g):
    """ Inverse of rotation matrices [*, 3, 3]: the transpose of the last two dims. """
    return g.transpose(-1, -2)


def btrace(X):
    """ batch-trace: [*, N, N] -> [*] """
    # Sum of the main diagonal of the last two dimensions, for the whole
    # batch at once; also works without any batch dimension.
    return X.diagonal(dim1=-2, dim2=-1).sum(-1)


def log(g):
    """ Log: SO(3) -> so(3), size [*, 3, 3] -> [*, 3] (rotation vector). """
    eps = 1.0e-7
    R = g.view(-1, 3, 3)
    tr = btrace(R)
    # Rounding can push (tr - 1) / 2 marginally outside [-1, 1], where acos
    # returns NaN; clamp so near-identity / near-pi rotations stay finite.
    c = torch.clamp((tr - 1) / 2, -1.0, 1.0)
    t = torch.acos(c)
    sc = sinc1(t)
    idx0 = (torch.abs(sc) <= eps)
    idx1 = (torch.abs(sc) > eps)
    sc = sc.view(-1, 1, 1)

    X = torch.zeros_like(R)
    if idx1.any():
        # Generic case: the skew part of R is sinc1(t) * [w]x.
        X[idx1] = (R[idx1] - R[idx1].transpose(1, 2)) / (2 * sc[idx1])

    if idx0.any():
        # t[idx0] == math.pi
        # The skew part vanishes here, so recover the axis from the
        # symmetric part (R + I) instead and fix the signs from its
        # off-diagonal entries.
        t2 = t[idx0] ** 2
        A = (R[idx0] + torch.eye(3).type_as(R).unsqueeze(0)) * t2.view(-1, 1, 1) / 2
        aw1 = torch.sqrt(A[:, 0, 0])
        aw2 = torch.sqrt(A[:, 1, 1])
        aw3 = torch.sqrt(A[:, 2, 2])
        sgn_3 = torch.sign(A[:, 0, 2])
        sgn_3[sgn_3 == 0] = 1
        sgn_23 = torch.sign(A[:, 1, 2])
        sgn_23[sgn_23 == 0] = 1
        sgn_2 = sgn_23 * sgn_3
        w1 = aw1
        w2 = aw2 * sgn_2
        w3 = aw3 * sgn_3
        w = torch.stack((w1, w2, w3), dim=-1)
        W = mat(w)
        X[idx0] = W

    x = vec(X.view_as(g))
    return x


def transform(g, a):
    """ Rotate points `a` by `g`.

    g in SO(3): * x 3 x 3
    a in R^3:   * x 3[x N]
    When `a` has as many dimensions as `g` it is treated as a stack of
    column vectors [*, 3, N]; otherwise as vectors [*, 3].
    """
    if len(g.size()) == len(a.size()):
        b = g.matmul(a)
    else:
        b = g.matmul(a.unsqueeze(-1)).squeeze(-1)
    return b


def group_prod(g, h):
    """ Group product g * h of rotation matrices. """
    # g, h : SO(3)
    g1 = g.matmul(h)
    return g1
= exp(x) 150 | (== [Ad(exp(x))] * vecs_ig_Xg(x)) 151 | """ 152 | t = x.view(-1, 3).norm(p=2, dim=1).view(-1, 1, 1) 153 | X = mat(x) 154 | S = X.bmm(X) 155 | # B = x.view(-1,3,1).bmm(x.view(-1,1,3)) # B = x*x' 156 | I = torch.eye(3).to(X) 157 | 158 | # V = sinc1(t)*eye(3) + sinc2(t)*X + sinc3(t)*B 159 | # V = eye(3) + sinc2(t)*X + sinc3(t)*S 160 | 161 | V = I + sinc2(t) * X + sinc3(t) * S 162 | 163 | return V.view(*(x.size()[0:-1]), 3, 3) 164 | 165 | 166 | def inv_vecs_Xg_ig(x): 167 | """ H = inv(vecs_Xg_ig(x)) """ 168 | t = x.view(-1, 3).norm(p=2, dim=1).view(-1, 1, 1) 169 | X = mat(x) 170 | S = X.bmm(X) 171 | I = torch.eye(3).to(x) 172 | 173 | e = 0.01 174 | eta = torch.zeros_like(t) 175 | s = (t < e) 176 | c = (s == 0) 177 | t2 = t[s] ** 2 178 | eta[s] = ((t2 / 40 + 1) * t2 / 42 + 1) * t2 / 720 + 1 / 12 # O(t**8) 179 | eta[c] = (1 - (t[c] / 2) / torch.tan(t[c] / 2)) / (t[c] ** 2) 180 | 181 | H = I - 1 / 2 * X + eta * S 182 | return H.view(*(x.size()[0:-1]), 3, 3) 183 | 184 | 185 | class ExpMap(torch.autograd.Function): 186 | """ Exp: so(3) -> SO(3) 187 | """ 188 | 189 | @staticmethod 190 | def forward(ctx, x): 191 | """ Exp: R^3 -> M(3), 192 | size: [B, 3] -> [B, 3, 3], 193 | or [B, 1, 3] -> [B, 1, 3, 3] 194 | """ 195 | ctx.save_for_backward(x) 196 | g = exp(x) 197 | return g 198 | 199 | @staticmethod 200 | def backward(ctx, grad_output): 201 | x, = ctx.saved_tensors 202 | g = exp(x) 203 | gen_k = genmat().to(x) 204 | # gen_1 = gen_k[0, :, :] 205 | # gen_2 = gen_k[1, :, :] 206 | # gen_3 = gen_k[2, :, :] 207 | 208 | # Let z = f(g) = f(exp(x)) 209 | # dz = df/dgij * dgij/dxk * dxk 210 | # = df/dgij * (d/dxk)[exp(x)]_ij * dxk 211 | # = df/dgij * [gen_k*g]_ij * dxk 212 | 213 | dg = gen_k.matmul(g.view(-1, 1, 3, 3)) 214 | # (k, i, j) 215 | dg = dg.to(grad_output) 216 | 217 | go = grad_output.contiguous().view(-1, 1, 3, 3) 218 | dd = go * dg 219 | grad_input = dd.sum(-1).sum(-1) 220 | 221 | return grad_input 222 | 223 | 224 | Exp = ExpMap.apply 225 | 226 | # EOF 227 | 
-------------------------------------------------------------------------------- /utils/se_math/transforms.py: -------------------------------------------------------------------------------- 1 | """ gives some transform methods for 3d points """ 2 | import math 3 | 4 | import torch 5 | import torch.utils.data 6 | 7 | from . import so3 8 | from . import se3 9 | 10 | 11 | class Mesh2Points: 12 | def __init__(self): 13 | pass 14 | 15 | def __call__(self, mesh): 16 | mesh = mesh.clone() 17 | v = mesh.vertex_array 18 | return torch.from_numpy(v).type(dtype=torch.float) 19 | 20 | 21 | class OnUnitSphere: 22 | def __init__(self, zero_mean=False): 23 | self.zero_mean = zero_mean 24 | 25 | def __call__(self, tensor): 26 | if self.zero_mean: 27 | m = tensor.mean(dim=0, keepdim=True) # [N, D] -> [1, D] 28 | v = tensor - m 29 | else: 30 | v = tensor 31 | nn = v.norm(p=2, dim=1) # [N, D] -> [N] 32 | nmax = torch.max(nn) 33 | return v / nmax 34 | 35 | 36 | class OnUnitCube: 37 | def __init__(self): 38 | pass 39 | 40 | def method1(self, tensor): 41 | m = tensor.mean(dim=0, keepdim=True) # [N, D] -> [1, D] 42 | v = tensor - m 43 | s = torch.max(v.abs()) 44 | v = v / s * 0.5 45 | return v 46 | 47 | def method2(self, tensor): 48 | c = torch.max(tensor, dim=0)[0] - torch.min(tensor, dim=0)[0] # [N, D] -> [D] 49 | s = torch.max(c) # -> scalar 50 | v = tensor / s 51 | return v - v.mean(dim=0, keepdim=True) 52 | 53 | def __call__(self, tensor): 54 | # return self.method1(tensor) 55 | return self.method2(tensor) 56 | 57 | 58 | class Resampler: 59 | """ [N, D] -> [M, D] """ 60 | 61 | def __init__(self, num): 62 | self.num = num 63 | 64 | def __call__(self, tensor): 65 | num_points, dim_p = tensor.size() 66 | out = torch.zeros(self.num, dim_p).to(tensor) 67 | 68 | selected = 0 69 | while selected < self.num: 70 | remainder = self.num - selected 71 | idx = torch.randperm(num_points) 72 | sel = min(remainder, num_points) 73 | val = tensor[idx[:sel]] 74 | out[selected:(selected + sel)] = val 
75 | selected += sel 76 | return out 77 | 78 | 79 | class RandomTranslate: 80 | def __init__(self, mag=None, randomly=True): 81 | self.mag = 1.0 if mag is None else mag 82 | self.randomly = randomly 83 | self.igt = None 84 | 85 | def __call__(self, tensor): 86 | # tensor: [N, 3] 87 | amp = torch.rand(1) if self.randomly else 1.0 88 | t = torch.randn(1, 3).to(tensor) 89 | t = t / t.norm(p=2, dim=1, keepdim=True) * amp * self.mag 90 | 91 | g = torch.eye(4).to(tensor) 92 | g[0:3, 3] = t[0, :] 93 | self.igt = g # [4, 4] 94 | 95 | p1 = tensor + t 96 | return p1 97 | 98 | 99 | class RandomRotator: 100 | def __init__(self, mag=None, randomly=True): 101 | self.mag = math.pi if mag is None else mag 102 | self.randomly = randomly 103 | self.igt = None 104 | 105 | def __call__(self, tensor): 106 | # tensor: [N, 3] 107 | amp = torch.rand(1) if self.randomly else 1.0 108 | w = torch.randn(1, 3) 109 | w = w / w.norm(p=2, dim=1, keepdim=True) * amp * self.mag 110 | 111 | g = so3.exp(w).to(tensor) # [1, 3, 3] 112 | self.igt = g.squeeze(0) # [3, 3] 113 | 114 | p1 = so3.transform(g, tensor) # [1, 3, 3] x [N, 3] -> [N, 3] 115 | return p1 116 | 117 | 118 | class RandomRotatorZ: 119 | def __init__(self): 120 | self.mag = 2 * math.pi 121 | 122 | def __call__(self, tensor): 123 | # tensor: [N, 3] 124 | w = torch.Tensor([0, 0, 1]).view(1, 3) * torch.rand(1) * self.mag 125 | 126 | g = so3.exp(w).to(tensor) # [1, 3, 3] 127 | 128 | p1 = so3.transform(g, tensor) 129 | return p1 130 | 131 | 132 | class RandomJitter: 133 | """ generate perturbations """ 134 | 135 | def __init__(self, scale=0.01, clip=0.05): 136 | self.scale = scale 137 | self.clip = clip 138 | self.e = None 139 | 140 | def jitter(self, tensor): 141 | noise = torch.zeros_like(tensor).to(tensor) # [N, 3] 142 | noise.normal_(0, self.scale) 143 | noise.clamp_(-self.clip, self.clip) 144 | self.e = noise 145 | return tensor.add(noise) 146 | 147 | def __call__(self, tensor): 148 | return self.jitter(tensor) 149 | 150 | 151 | class 
RandomTransformSE3: 152 | """ rigid motion """ 153 | 154 | def __init__(self, mag=1, mag_randomly=False): 155 | self.mag = mag 156 | self.randomly = mag_randomly 157 | 158 | self.gt = None 159 | self.igt = None 160 | 161 | def generate_transform(self): 162 | # return: a twist-vector 163 | amp = self.mag 164 | if self.randomly: 165 | amp = torch.rand(1, 1) * self.mag 166 | x = torch.randn(1, 6) 167 | x = x / x.norm(p=2, dim=1, keepdim=True) * amp 168 | 169 | '''a = torch.rand(3) 170 | a = a * math.pi 171 | b = torch.zeros(1, 6) 172 | b[:, 0:3] = a 173 | x = x+b 174 | ''' 175 | return x # [1, 6] 176 | 177 | def apply_transform(self, p0, x): 178 | # p0: [N, 3] 179 | # x: [1, 6] 180 | g = se3.exp(x).to(p0) # [1, 4, 4] 181 | gt = se3.exp(-x).to(p0) # [1, 4, 4] 182 | 183 | p1 = se3.transform(g, p0) 184 | self.gt = gt.squeeze(0) # gt: p1 -> p0 185 | self.igt = g.squeeze(0) # igt: p0 -> p1 186 | return p1 187 | 188 | def transform(self, tensor): 189 | x = self.generate_transform() 190 | return self.apply_transform(tensor, x) 191 | 192 | def __call__(self, tensor): 193 | return self.transform(tensor) 194 | 195 | # EOF 196 | --------------------------------------------------------------------------------