├── .gitignore
├── LICENSE
├── README.md
├── common
│   ├── __init__.py
│   ├── manager.py
│   ├── quaternion.py
│   ├── se3.py
│   ├── so3.py
│   └── utils.py
├── dataset
│   ├── __init__.py
│   ├── data_loader.py
│   └── transformations.py
├── evaluate.py
├── experiments
│   └── experiment_omnet
│       └── params.json
├── images
│   ├── OMNet_poster.png
│   └── pipeline.png
├── loss
│   ├── __init__.py
│   └── loss.py
├── model
│   ├── __init__.py
│   ├── module.py
│   └── net.py
├── requirements.txt
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | experiments/experiment_omnet/summary
2 | experiments/experiment_omnet/*.log
3 | experiments/experiment_omnet/*.pth
4 | experiments/experiment_omnet/val*
5 | experiments/experiment_omnet/test*
6 | dataset/data
7 | dataset/data.zip
8 | */__pycache__
9 | checkpoints
10 | configs

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Hao Xu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [ICCV 2021] OMNet: Learning Overlapping Mask for Partial-to-Partial Point Cloud Registration
2 | 
3 | This is the PyTorch implementation of our ICCV 2021 paper [OMNet](https://openaccess.thecvf.com/content/ICCV2021/papers/Xu_OMNet_Learning_Overlapping_Mask_for_Partial-to-Partial_Point_Cloud_Registration_ICCV_2021_paper.pdf). For our MegEngine implementation, please refer to [this repo](https://github.com/megvii-research/OMNet).
4 | 
5 | Our presentation video: [[Youtube](https://www.youtube.com/watch?v=u2lTKsom8oU)][[Bilibili](https://www.bilibili.com/video/BV1Ef4y1J7XP/)].
6 | 
7 | ## Our Poster
8 | 
9 | ![image](./images/OMNet_poster.png)
10 | 
11 | ## Dependencies
12 | 
13 | * PyTorch>=1.5.0
14 | * For other requirements, please refer to `requirements.txt`.
15 | 
16 | ## Data Preparation
17 | 
18 | ### OS data
19 | 
20 | We refer to the original data from PointNet as OS data, where point clouds are sampled only once from the corresponding CAD models. We offer two ways to obtain OS data: (1) download it from the original link [original_OS_data.zip](http://modelnet.cs.princeton.edu/).
(2) download the data preprocessed by us from [our_OS_data.zip](https://drive.google.com/file/d/1rXnbXwD72tkeu8x6wboMP0X7iL9LiBPq/view?usp=sharing).
21 | 
22 | ### TS data
23 | 
24 | Since OS data incurs an over-fitting issue, we propose our TS data, where point clouds are randomly sampled twice from each CAD model. You need to download our preprocessed ModelNet40 dataset first, in which 8 axisymmetric categories are removed and every CAD model has 40 randomly sampled point clouds. The download link is [TS_data.zip](https://drive.google.com/file/d/1DPBBI3Ulvp2Mx7SAZaBEyvADJzBvErFF/view?usp=sharing). All 40 point clouds of a CAD model are stacked into a (40, 2048, 3) numpy array, which you can load with the following code:
25 | 
26 | ```
27 | import numpy as np
28 | points = np.load("path_of_npy_file")
29 | print(points.shape, type(points))  # (40, 2048, 3) <class 'numpy.ndarray'>
30 | ```
31 | 
32 | Then, put the data into `./dataset/data`; the contents of the directories are as follows:
33 | 
34 | ```
35 | ./dataset/data/
36 | ├── modelnet40_half1_rm_rotate.txt
37 | ├── modelnet40_half2_rm_rotate.txt
38 | ├── modelnet_os
39 | │   ├── modelnet_os_test.pickle
40 | │   ├── modelnet_os_train.pickle
41 | │   ├── modelnet_os_val.pickle
42 | │   ├── test [1146 entries exceeds filelimit, not opening dir]
43 | │   ├── train [4194 entries exceeds filelimit, not opening dir]
44 | │   └── val [1002 entries exceeds filelimit, not opening dir]
45 | └── modelnet_ts
46 |     ├── modelnet_ts_test.pickle
47 |     ├── modelnet_ts_train.pickle
48 |     ├── modelnet_ts_val.pickle
49 |     ├── shape_names.txt
50 |     ├── test [1146 entries exceeds filelimit, not opening dir]
51 |     ├── train [4196 entries exceeds filelimit, not opening dir]
52 |     └── val [1002 entries exceeds filelimit, not opening dir]
53 | ```
54 | 
55 | ## Training and Evaluation
56 | 
57 | ### Begin training
58 | 
59 | For the ModelNet40 dataset, you can just run:
60 | 
61 | ```
62 | python3 train.py --model_dir=./experiments/experiment_omnet/
63 | ```
64 | 
65 | For other datasets, you need to add your own dataset code in `./dataset/data_loader.py`. Note that training with a smaller batch size, such as 16, may yield worse performance than training with a larger one, e.g., 64.
66 | 
67 | ### Begin testing
68 | 
69 | You need to download the pretrained checkpoint and run:
70 | 
71 | ```
72 | python3 evaluate.py --model_dir=./experiments/experiment_omnet --restore_file=./experiments/experiment_omnet/val_model_best.pth
73 | ```
74 | 
75 | The table below shows our performance on ModelNet40, where `val` and `test` indicate `Unseen Shapes` and `Unseen Categories`, respectively. `PRNet` and `RPMNet` indicate the partial-overlap manners used in [PRNet](https://arxiv.org/pdf/1910.12240.pdf) and [RPMNet](https://openaccess.thecvf.com/content_CVPR_2020/papers/Yew_RPM-Net_Robust_Point_Matching_Using_Learned_Features_CVPR_2020_paper.pdf), respectively.
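For reference, RMSE(R)/MAE(R) are anisotropic errors typically computed over Euler angles, while Error(R) is the isotropic geodesic angle of the residual rotation (the actual evaluation code lives in `evaluate.py`, which is not shown here). A minimal NumPy sketch of how such rotation metrics are commonly computed; `rotation_errors` is a hypothetical helper written for illustration only:

```
import numpy as np
from scipy.spatial.transform import Rotation


def rotation_errors(rot_gt, rot_pred, seq="zyx"):
    """rot_gt, rot_pred: (B, 3, 3) ground-truth and predicted rotation matrices."""
    euler_gt = np.stack([Rotation.from_matrix(r).as_euler(seq, degrees=True) for r in rot_gt])
    euler_pred = np.stack([Rotation.from_matrix(r).as_euler(seq, degrees=True) for r in rot_pred])
    rmse_r = np.sqrt(np.mean((euler_gt - euler_pred) ** 2))  # anisotropic RMSE(R)
    mae_r = np.mean(np.abs(euler_gt - euler_pred))  # anisotropic MAE(R)
    # isotropic Error(R): rotation angle of the residual rotation R_gt^T @ R_pred
    residual = np.einsum("bij,bik->bjk", rot_gt, rot_pred)
    trace = np.trace(residual, axis1=1, axis2=2)
    err_r = np.degrees(np.arccos(np.clip((trace - 1.0) / 2.0, -1.0, 1.0)))
    return rmse_r, mae_r, float(np.mean(err_r))
```

The translation metrics RMSE(t), MAE(t) and Error(t) are computed analogously over the translation vectors.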
76 | 77 | | dataset | | RMSE(R) | MAE(R) | RMSE(t) | MAE(t) | Error(R) | Error(t) | checkpoint | 78 | | :-------------: | :--: | :-----: | :----: | :-----: | :----: | :------: | :------: | :---------------------------------------------------------------------------------------------: | 79 | | OS_PRNet_clean | val | 0.912 | 0.339 | 0.0078 | 0.0049 | 0.639 | 0.0099 | [Google Drive](https://drive.google.com/file/d/1i6nsSPFriGYxD1rDGTpbtTBYmdQcT8St/view?usp=sharing) | 80 | | | test | 2.247 | 0.652 | 0.0177 | 0.0077 | 1.241 | 0.0154 | [Google Drive](https://drive.google.com/file/d/1LTR4rCT4eQ6JXXOekeUjhXQpwhlh9NEY/view?usp=sharing) | 81 | | TS_PRNet_clean | val | 1.032 | 0.506 | 0.0085 | 0.0057 | 0.984 | 0.0113 | [Google Drive](https://drive.google.com/file/d/1AdutxYe7FS88uoLMf7V6Mo9Tlb9hSDlF/view?usp=sharing) | 82 | | | test | 2.372 | 0.974 | 0.0146 | 0.0077 | 1.892 | 0.0152 | [Google Drive](https://drive.google.com/file/d/1A-6xTPGPAbmnwbnt81NjhN6Mw9VMHmwN/view?usp=sharing) | 83 | | OS_PRNet_noise | val | 1.029 | 0.573 | 0.0089 | 0.0061 | 1.077 | 0.0123 | [Google Drive](https://drive.google.com/file/d/1JbBlBW08PQrucbdpp-G-VlWdjlik6tO7/view?usp=sharing) | 84 | | | test | 2.318 | 0.957 | 0.0155 | 0.0078 | 1.809 | 0.0156 | [Google Drive](https://drive.google.com/file/d/154xYpstuQJ0eDk3rqbDShg5P5b17xlry/view?usp=sharing) | 85 | | TS_PRNet_noise | val | 1.314 | 0.771 | 0.0102 | 0.0074 | 1.490 | 0.0148 | [Google Drive](https://drive.google.com/file/d/1ZzetsjHC4POh8Irr1RfSl8boPJvCQFMx/view?usp=sharing) | 86 | | | test | 2.443 | 1.189 | 0.0181 | 0.0097 | 2.311 | 0.0193 | [Google Drive](https://drive.google.com/file/d/1eHi9pzAmL3jrYGmv6X9xy-8U7hAw9OdI/view?usp=sharing) | 87 | | OS_RPMNet_clean | val | 0.771 | 0.277 | 0.0154 | 0.0056 | 0.561 | 0.0122 | [Google Drive](https://drive.google.com/file/d/1_wGJTxaezFvb4xqABmFfrIR03Wq2c80U/view?usp=sharing) | 88 | | | test | 3.719 | 1.314 | 0.0392 | 0.0151 | 2.659 | 0.0321 | [Google Drive](https://drive.google.com/file/d/1IQ0DZ_OmaZPErPqm4DJ4NBWfnt1d1YG5/view?usp=sharing) | 89 | | TS_RPMNet_clean | val | 1.401 | 0.544 | 0.0241 | 0.0095 | 1.128 | 0.0202 | [Google Drive](https://drive.google.com/file/d/1IlUSzGoAXHzon5ZrwLPNBsTuICphhrAO/view?usp=sharing) | 90 | | | test | 4.016 | 1.622 | 0.0419 | 0.0184 | 3.205 | 0.0394 | [Google Drive](https://drive.google.com/file/d/1NJZcfHoXlCFTMVz01ZACiiTMNUEW1QQC/view?usp=sharing) | 91 | | OS_RPMNet_noise | val | 0.998 | 0.555 | 0.0172 | 0.0078 | 1.079 | 0.0167 | [Google Drive](https://drive.google.com/file/d/1LvhPwrtUs-A2AZWO1YgrvhgcxdZXTen-/view?usp=sharing) | 92 | | | test | 3.572 | 1.570 | 0.0391 | 0.0172 | 3.073 | 0.0359 | [Google Drive](https://drive.google.com/file/d/1xnHcKikXs8D9UuGchwo3YK21vG86zRtp/view?usp=sharing) | 93 | | TS_RPMNet_noise | val | 1.522 | 0.817 | 0.0189 | 0.0098 | 1.622 | 0.0208 | - | 94 | | | test | 4.356 | 1.924 | 0.0486 | 0.0223 | 3.834 | 0.0476 | - | 95 | 96 | ## Citation 97 | 98 | ``` 99 | @InProceedings{Xu_2021_ICCV, 100 | author={Xu, Hao and Liu, Shuaicheng and Wang, Guangfu and Liu, Guanghui and Zeng, Bing}, 101 | title={OMNet: Learning Overlapping Mask for Partial-to-Partial Point Cloud Registration}, 102 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 103 | month={October}, 104 | year={2021}, 105 | pages={3132-3141} 106 | } 107 | ``` 108 | 109 | ## Acknowledgments 110 | 111 | In this project we use (parts of) the official implementations of the following works: 112 | 113 | * [RPMNet](https://github.com/yewzijian/RPMNet) (ModelNet40 
preprocessing and evaluation) 114 | * [PRNet](https://github.com/WangYueFt/prnet) (ModelNet40 preprocessing) 115 | 116 | We thank the respective authors for open sourcing their methods. 117 | -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hxwork/OMNet_Pytorch/0c6d669c70f79d15b24d29a749b054a1adbe3c2f/common/__init__.py -------------------------------------------------------------------------------- /common/manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | from termcolor import colored 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | from common import utils 10 | 11 | 12 | class Manager(): 13 | def __init__(self, model, optimizer, scheduler, params, dataloaders, logger): 14 | # params status 15 | self.params = params 16 | 17 | self.model = model 18 | self.optimizer = optimizer 19 | self.scheduler = scheduler 20 | self.dataloaders = dataloaders 21 | self.logger = logger 22 | 23 | self.epoch = 0 24 | self.step = 0 25 | self.best_val_score = np.inf 26 | self.cur_val_score = np.inf 27 | self.best_test_score = np.inf 28 | self.cur_test_score = np.inf 29 | 30 | # train status 31 | self.train_status = defaultdict(utils.AverageMeter) 32 | 33 | # val status 34 | self.val_status = defaultdict(utils.AverageMeter) 35 | 36 | # test status 37 | self.test_status = defaultdict(utils.AverageMeter) 38 | 39 | # model status 40 | self.loss_status = defaultdict(utils.AverageMeter) 41 | 42 | # init local tensorboard and html 43 | self.init_tb_and_html() 44 | 45 | def init_tb_and_html(self): 46 | # tensorboard loss 47 | local_tb_dir = os.path.join(self.params.model_dir, "summary/loss") 48 | os.makedirs(local_tb_dir, exist_ok=True) 49 | self.local_loss_writter = SummaryWriter(log_dir=local_tb_dir) 50 | # tensorboard metric 51 | local_tb_dir = os.path.join(self.params.model_dir, "summary/metric") 52 | os.makedirs(local_tb_dir, exist_ok=True) 53 | self.local_metric_writter = SummaryWriter(log_dir=local_tb_dir) 54 | # html 55 | local_html_dir = os.path.join(self.params.model_dir, "summary/html") 56 | os.makedirs(local_html_dir, exist_ok=True) 57 | self.local_html_dir = local_html_dir 58 | 59 | def update_step(self): 60 | self.step += 1 61 | 62 | def update_epoch(self): 63 | self.epoch += 1 64 | 65 | def update_loss_status(self, loss, batch_size): 66 | for k, v in loss.items(): 67 | self.loss_status[k].update(val=v.item(), num=batch_size) 68 | 69 | def update_metric_status(self, metrics, split, batch_size): 70 | if split == "val": 71 | for k, v in metrics.items(): 72 | self.val_status[k].update(val=v.item(), num=batch_size) 73 | self.cur_val_score = self.val_status[self.params.major_metric].avg 74 | elif split == "test": 75 | for k, v in metrics.items(): 76 | self.test_status[k].update(val=v.item(), num=batch_size) 77 | self.cur_test_score = self.test_status[self.params.major_metric].avg 78 | else: 79 | raise ValueError("Wrong eval type: {}".format(split)) 80 | 81 | def summarize_metric_status(self, metrics, split): 82 | if split == "val": 83 | for k in metrics: 84 | if k.endswith('MSE'): 85 | self.val_status[k[:-3] + 'RMSE'].set(val=np.sqrt(self.val_status[k].avg)) 86 | else: 87 | continue 88 | elif split == "test": 89 | for k in metrics: 90 | if k.endswith('MSE'): 91 | self.test_status[k[:-3] + 
'RMSE'].set(val=np.sqrt(self.test_status[k].avg)) 92 | else: 93 | continue 94 | else: 95 | raise ValueError("Wrong eval type: {}".format(split)) 96 | 97 | def reset_loss_status(self): 98 | for k, v in self.loss_status.items(): 99 | self.loss_status[k].reset() 100 | 101 | def reset_metric_status(self, split): 102 | if split == "val": 103 | for k, v in self.val_status.items(): 104 | self.val_status[k].reset() 105 | elif split == "test": 106 | for k, v in self.test_status.items(): 107 | self.test_status[k].reset() 108 | else: 109 | raise ValueError("Wrong split string: {}".format(split)) 110 | 111 | def print_train_info(self): 112 | exp_name = self.params.model_dir.split('/')[-1] 113 | print_str = "{} Epoch: {:4d}, lr={:.4f} ".format(exp_name, self.epoch, self.scheduler.get_last_lr()[0]) 114 | print_str += "total loss: %.4f(%.4f)" % (self.loss_status['total'].val, self.loss_status['total'].avg) 115 | return print_str 116 | 117 | def print_metrics(self, split, title="Eval", color="red", only_best=False): 118 | if split == "val": 119 | metric_status = self.val_status 120 | is_best = self.cur_val_score < self.best_val_score 121 | elif split == "test": 122 | metric_status = self.test_status 123 | is_best = self.cur_test_score < self.best_test_score 124 | else: 125 | raise ValueError("Wrong split string: {}".format(split)) 126 | 127 | print_str = " | ".join("{}: {:4g}".format(k, v.avg) for k, v in metric_status.items()) 128 | if only_best: 129 | if is_best: 130 | self.logger.info(colored("Best Epoch: {}, {} Results: {}".format(self.epoch, title, print_str), color, attrs=["bold"])) 131 | else: 132 | self.logger.info(colored("Epoch: {}, {} Results: {}".format(self.epoch, title, print_str), color, attrs=["bold"])) 133 | 134 | def write_loss_to_tb(self, split): 135 | for k, v in self.loss_status.items(): 136 | if split == "train": 137 | self.local_loss_writter.add_scalar("train_Loss/{}".format(k), v.val, self.step) 138 | elif split == "val": 139 | self.local_loss_writter.add_scalar("val_Loss/{}".format(k), v.val, self.step) 140 | elif split == "test": 141 | self.local_loss_writter.add_scalar("test_Loss/{}".format(k), v.val, self.step) 142 | else: 143 | raise ValueError("Wrong split string: {}".format(split)) 144 | 145 | def write_metric_to_tb(self, split): 146 | if split == "val": 147 | for k, v in self.val_status.items(): 148 | self.local_metric_writter.add_scalar("val_Metric/{}".format(k), v.avg, self.epoch) 149 | elif split == "test": 150 | for k, v in self.test_status.items(): 151 | self.local_metric_writter.add_scalar("test_Metric/{}".format(k), v.avg, self.epoch) 152 | else: 153 | raise ValueError("Wrong split string: {}".format(split)) 154 | 155 | def check_best_save_last_checkpoints(self, save_latest_freq=5, save_best_after=50): 156 | 157 | state = { 158 | "state_dict": self.model.state_dict(), 159 | "optimizer": self.optimizer.state_dict(), 160 | "scheduler": self.scheduler.state_dict(), 161 | "step": self.step, 162 | "epoch": self.epoch, 163 | } 164 | if self.dataloaders["val"] is not None: 165 | state["best_val_score"] = self.best_val_score 166 | if self.dataloaders["test"] is not None: 167 | state["best_test_score"] = self.best_test_score 168 | 169 | # save latest checkpoint 170 | if self.epoch % save_latest_freq == 0: 171 | latest_ckpt_name = os.path.join(self.params.model_dir, "model_latest.pth") 172 | torch.save(state, latest_ckpt_name) 173 | self.logger.info("Saved latest checkpoint to: {}".format(latest_ckpt_name)) 174 | 175 | # save val latest metrics, and check if val is best 
checkpoints 176 | if self.dataloaders["val"] is not None: 177 | val_latest_metrics_name = os.path.join(self.params.model_dir, "val_metrics_latest.json") 178 | utils.save_dict_to_json(self.val_status, val_latest_metrics_name) 179 | is_best = self.cur_val_score < self.best_val_score 180 | if is_best: 181 | # save metrics 182 | self.best_val_score = self.cur_val_score 183 | best_metrics_name = os.path.join(self.params.model_dir, "val_metrics_best.json") 184 | utils.save_dict_to_json(self.val_status, best_metrics_name) 185 | self.logger.info("Current is val best, score={:.7f}".format(self.best_val_score)) 186 | # save checkpoint 187 | if self.epoch > save_best_after: 188 | best_ckpt_name = os.path.join(self.params.model_dir, "val_model_best.pth") 189 | torch.save(state, best_ckpt_name) 190 | self.logger.info("Saved val best checkpoint to: {}".format(best_ckpt_name)) 191 | 192 | # save test latest metrics, and check if test is best checkpoints 193 | if self.dataloaders["test"] is not None: 194 | test_latest_metrics_name = os.path.join(self.params.model_dir, "test_metrics_latest.json") 195 | utils.save_dict_to_json(self.test_status, test_latest_metrics_name) 196 | is_best = self.cur_test_score < self.best_test_score 197 | if is_best: 198 | # save metrics 199 | self.best_test_score = self.cur_test_score 200 | best_metrics_name = os.path.join(self.params.model_dir, "test_metrics_best.json") 201 | utils.save_dict_to_json(self.test_status, best_metrics_name) 202 | self.logger.info("Current is test best, score={:.7f}".format(self.best_test_score)) 203 | # save checkpoint 204 | if self.epoch > save_best_after: 205 | best_ckpt_name = os.path.join(self.params.model_dir, "test_model_best.pth") 206 | torch.save(state, best_ckpt_name) 207 | self.logger.info("Saved test best checkpoint to: {}".format(best_ckpt_name)) 208 | 209 | def load_checkpoints(self): 210 | state = torch.load(self.params.restore_file) 211 | 212 | ckpt_component = [] 213 | if "state_dict" in state and self.model is not None: 214 | try: 215 | self.model.load_state_dict(state["state_dict"]) 216 | except RuntimeError: 217 | print("Using custom loading net") 218 | net_dict = self.model.state_dict() 219 | if "module" not in list(state["state_dict"].keys())[0]: 220 | state_dict = {"module." + k: v for k, v in state["state_dict"].items() if "module." 
+ k in net_dict.keys()}
221 |                 else:
222 |                     state_dict = {k: v for k, v in state["state_dict"].items() if k in net_dict.keys()}
223 |                 net_dict.update(state_dict)
224 |                 self.model.load_state_dict(net_dict, strict=False)
225 |             ckpt_component.append("net")
226 | 
227 |         if not self.params.only_weights:
228 | 
229 |             if "optimizer" in state and self.optimizer is not None:
230 |                 try:
231 |                     self.optimizer.load_state_dict(state["optimizer"])
232 | 
233 |                 except RuntimeError:
234 |                     print("Using custom loading optimizer")
235 |                     optimizer_dict = self.optimizer.state_dict()
236 |                     state_dict = {k: v for k, v in state["optimizer"].items() if k in optimizer_dict.keys()}
237 |                     optimizer_dict.update(state_dict)
238 |                     self.optimizer.load_state_dict(optimizer_dict)
239 |                 ckpt_component.append("opt")
240 | 
241 |             if "scheduler" in state and self.scheduler is not None:
242 |                 try:
243 |                     self.scheduler.load_state_dict(state["scheduler"])
244 | 
245 |                 except RuntimeError:
246 |                     print("Using custom loading scheduler")
247 |                     scheduler_dict = self.scheduler.state_dict()
248 |                     state_dict = {k: v for k, v in state["scheduler"].items() if k in scheduler_dict.keys()}
249 |                     scheduler_dict.update(state_dict)
250 |                     self.scheduler.load_state_dict(scheduler_dict)
251 |                 ckpt_component.append("sch")
252 | 
253 |             if "step" in state:
254 |                 self.step = state["step"] + 1
255 |                 ckpt_component.append("step")
256 | 
257 |             if "epoch" in state:
258 |                 self.epoch = state["epoch"] + 1
259 |                 ckpt_component.append("epoch")
260 | 
261 |             if "best_val_score" in state:
262 |                 self.best_val_score = state["best_val_score"]
263 |                 ckpt_component.append("best val score: {:.3g}".format(self.best_val_score))
264 | 
265 |             if "best_test_score" in state:
266 |                 self.best_test_score = state["best_test_score"]
267 |                 ckpt_component.append("best test score: {:.3g}".format(self.best_test_score))
268 | 
269 |         ckpt_component = ", ".join(ckpt_component)
270 |         self.logger.info("Loaded models from: {}".format(self.params.restore_file))
271 |         self.logger.info("Ckpt load: {}".format(ckpt_component))
272 | 
--------------------------------------------------------------------------------
/common/quaternion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | 
5 | def torch_qmul(q1, q2):
6 |     """
7 |     Multiply quaternion(s): returns the product q2*q1, i.e. the rotation that applies q1 first and q2 second.
8 |     Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
9 |     Returns a tensor of shape (*, 4).
10 |     """
11 |     assert q1.shape[-1] == 4
12 |     assert q2.shape[-1] == 4
13 | 
14 |     original_shape = q1.shape
15 | 
16 |     # Compute outer product
17 |     terms = torch.bmm(q1.view(-1, 4, 1), q2.view(-1, 1, 4))
18 | 
19 |     w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
20 |     x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
21 |     y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
22 |     z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
23 |     return torch.stack((w, x, y, z), dim=1).view(original_shape)
24 | 
25 | 
26 | def torch_qrot(q, v):
27 |     """
28 |     Rotate vector(s) v about the rotation described by quaternion(s) q.
29 |     Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
30 |     where * denotes any number of dimensions.
31 |     Returns a tensor of shape (*, 3).
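    Example: with the w-first quaternion q = (0.7071, 0, 0, 0.7071), a 90-degree
    rotation about the z-axis, the vector v = (1, 0, 0) maps to approximately (0, 1, 0).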
32 | """ 33 | assert q.shape[-1] == 4 34 | assert v.shape[-1] == 3 35 | assert q.shape[:-1] == v.shape[:-1] 36 | 37 | original_shape = list(v.shape) 38 | q = q.view(-1, 4) 39 | v = v.view(-1, 3) 40 | 41 | qvec = q[:, 1:] 42 | uv = torch.cross(qvec, v, dim=1) 43 | uuv = torch.cross(qvec, uv, dim=1) 44 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) 45 | 46 | 47 | def torch_quat2euler(q, order, epsilon=0): 48 | """ 49 | Convert quaternion(s) q to Euler angles. 50 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions. 51 | Returns a tensor of shape (*, 3). 52 | """ 53 | assert q.shape[-1] == 4 54 | 55 | original_shape = list(q.shape) 56 | original_shape[-1] = 3 57 | q = q.view(-1, 4) 58 | 59 | q0 = q[:, 0] 60 | q1 = q[:, 1] 61 | q2 = q[:, 2] 62 | q3 = q[:, 3] 63 | 64 | if order == "xyz": 65 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 66 | y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) 67 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 68 | elif order == "yzx": 69 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 70 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 71 | z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) 72 | elif order == "zxy": 73 | x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) 74 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 75 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) 76 | elif order == "xzy": 77 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 78 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 79 | z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) 80 | elif order == "yxz": 81 | x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) 82 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) 83 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 84 | elif order == "zyx": 85 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 86 | y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) 87 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 88 | else: 89 | raise 90 | 91 | return torch.stack((x, y, z), dim=1).view(original_shape) 92 | 93 | 94 | def torch_euler2quat(e, order): 95 | """ 96 | Convert Euler angles to quaternions. 
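    Expects a tensor e of shape (N, 3) with angles in radians and an order string such as "xyz";
    returns quaternions of shape (N, 4) in (w, x, y, z) format.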
97 | """ 98 | assert e.size()[-1] == 3 99 | 100 | original_shape = [e.size()[0], 4] 101 | 102 | x = e[:, 0] 103 | y = e[:, 1] 104 | z = e[:, 2] 105 | 106 | rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x).cuda(), torch.zeros_like(x).cuda()), dim=1) 107 | ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y).cuda(), torch.sin(y / 2), torch.zeros_like(y).cuda()), dim=1) 108 | rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z).cuda(), torch.zeros_like(z).cuda(), torch.sin(z / 2)), dim=1) 109 | 110 | result = None 111 | for coord in order: 112 | if coord == "x": 113 | r = rx 114 | elif coord == "y": 115 | r = ry 116 | elif coord == "z": 117 | r = rz 118 | else: 119 | raise 120 | if result is None: 121 | result = r 122 | else: 123 | result = torch_qmul(result, r) 124 | 125 | # Reverse antipodal representation to have a non-negative "w" 126 | if order in ["xyz", "yzx", "zxy"]: 127 | result *= -1 128 | 129 | return result.reshape(original_shape) 130 | 131 | 132 | def torch_quat2mat(pose): 133 | # Separate each quaternion value. 134 | q0, q1, q2, q3 = pose[:, 0], pose[:, 1], pose[:, 2], pose[:, 3] 135 | # Convert quaternion to rotation matrix. 136 | # Ref: http://www-evasion.inrialpes.fr/people/Franck.Hetroy/Teaching/ProjetsImage/2007/Bib/besl_mckay-pami1992.pdf 137 | # A method for Registration of 3D shapes paper by Paul J. Besl and Neil D McKay. 138 | R11 = q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3 139 | R12 = 2 * (q1 * q2 - q0 * q3) 140 | R13 = 2 * (q1 * q3 + q0 * q2) 141 | R21 = 2 * (q1 * q2 + q0 * q3) 142 | R22 = q0 * q0 + q2 * q2 - q1 * q1 - q3 * q3 143 | R23 = 2 * (q2 * q3 - q0 * q1) 144 | R31 = 2 * (q1 * q3 - q0 * q2) 145 | R32 = 2 * (q2 * q3 + q0 * q1) 146 | R33 = q0 * q0 + q3 * q3 - q1 * q1 - q2 * q2 147 | R = torch.stack((torch.stack((R11, R12, R13), dim=0), torch.stack((R21, R22, R23), dim=0), torch.stack((R31, R32, R33), dim=0)), dim=0) 148 | 149 | rot_mat = R.permute((2, 0, 1)) # (B, 3, 3) 150 | translation = pose[:, 4:].unsqueeze(2) # (B, 3, 1) 151 | transform = torch.cat((rot_mat, translation), dim=2) 152 | return transform # (B, 3, 4) 153 | 154 | 155 | def torch_transform_pose(pose_old, pose_new): 156 | quat_old, translate_old = pose_old[:, :4], pose_old[:, 4:] 157 | quat_new, translate_new = pose_new[:, :4], pose_new[:, 4:] 158 | 159 | quat = torch_qmul(quat_old, quat_new) 160 | translate = torch_qrot(quat_new, translate_old) + translate_new 161 | pose = torch.cat((quat, translate), dim=1) 162 | 163 | return pose 164 | 165 | 166 | def torch_qinv(q): 167 | # expectes q in (w,x,y,z) format 168 | w = q[:, 0:1] 169 | v = q[:, 1:] 170 | inv = torch.cat([w, -v], dim=1) 171 | return inv 172 | 173 | 174 | def torch_quat_rotate(point_cloud: torch.Tensor, pose_7d: torch.Tensor): 175 | ndim = point_cloud.dim() 176 | if ndim == 2: 177 | N, _ = point_cloud.shape 178 | assert pose_7d.shape[0] == 1 179 | # repeat transformation vector for each point in shape 180 | quat = pose_7d[:, 0:4].expand([N, -1]) 181 | rotated_point_cloud = torch_qrot(quat, point_cloud) 182 | 183 | elif ndim == 3: 184 | B, N, _ = point_cloud.shape 185 | quat = pose_7d[:, 0:4].unsqueeze(1).expand([-1, N, -1]).contiguous() 186 | rotated_point_cloud = torch_qrot(quat, point_cloud) 187 | 188 | else: 189 | raise RuntimeError("point cloud dim must be 2 or 3 !") 190 | 191 | return rotated_point_cloud 192 | 193 | 194 | def torch_quat_transform(pose_7d: torch.Tensor, pc: torch.Tensor, normal: torch.Tensor = None): 195 | pc_t = torch_quat_rotate(pc, pose_7d) + pose_7d[:, 4:].view(-1, 1, 3).repeat(1, 
pc.shape[1], 1) # Ps" = R*Ps + t 196 | if normal is not None: 197 | normal_t = torch_quat_rotate(normal, pose_7d) 198 | return pc_t, normal_t 199 | else: 200 | return pc_t 201 | 202 | 203 | def np_qmul(q, r): 204 | q = torch.from_numpy(q).contiguous() 205 | r = torch.from_numpy(r).contiguous() 206 | return torch_qmul(q, r).numpy() 207 | 208 | 209 | def np_qrot(q, v): 210 | q = torch.from_numpy(q).contiguous() 211 | v = torch.from_numpy(v).contiguous() 212 | return torch_qrot(q, v).numpy() 213 | 214 | 215 | def np_quat2euler(q, order, epsilon=0, use_gpu=False): 216 | if use_gpu: 217 | q = torch.from_numpy(q).cuda() 218 | return torch_quat2euler(q, order, epsilon).cpu().numpy() 219 | else: 220 | q = torch.from_numpy(q).contiguous() 221 | return torch_quat2euler(q, order, epsilon).numpy() 222 | 223 | 224 | def np_qfix(q): 225 | """ 226 | Enforce quaternion continuity across the time dimension by selecting 227 | the representation (q or -q) with minimal euclidean_distance (or, equivalently, maximal dot product) 228 | between two consecutive frames. 229 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. 230 | Returns a tensor of the same shape. 231 | """ 232 | assert len(q.shape) == 3 233 | assert q.shape[-1] == 4 234 | 235 | result = q.copy() 236 | dot_products = np.sum(q[1:] * q[:-1], axis=2) 237 | mask = dot_products < 0 238 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool) 239 | result[1:][mask] *= -1 240 | return result 241 | 242 | 243 | def np_expmap2quat(e): 244 | """ 245 | Convert axis-angle rotations (aka exponential maps) to quaternions. 246 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". 247 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions. 248 | Returns a tensor of shape (*, 4). 249 | """ 250 | assert e.shape[-1] == 3 251 | 252 | original_shape = list(e.shape) 253 | original_shape[-1] = 4 254 | e = e.reshape(-1, 3) 255 | 256 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1) 257 | w = np.cos(0.5 * theta).reshape(-1, 1) 258 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e 259 | return np.concatenate((w, xyz), axis=1).reshape(original_shape) 260 | 261 | 262 | def np_euler2quat(e, order): 263 | """ 264 | Convert Euler angles to quaternions. 
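    Expects an array of shape (*, 3) with angles in radians;
    returns an array of shape (*, 4) in (w, x, y, z) format.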
265 | """ 266 | assert e.shape[-1] == 3 267 | 268 | original_shape = list(e.shape) 269 | original_shape[-1] = 4 270 | 271 | e = e.reshape(-1, 3) 272 | 273 | x = e[:, 0] 274 | y = e[:, 1] 275 | z = e[:, 2] 276 | 277 | rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1) 278 | ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1) 279 | rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1) 280 | 281 | result = None 282 | for coord in order: 283 | if coord == "x": 284 | r = rx 285 | elif coord == "y": 286 | r = ry 287 | elif coord == "z": 288 | r = rz 289 | else: 290 | raise 291 | if result is None: 292 | result = r 293 | else: 294 | result = np_qmul(result, r) 295 | 296 | # Reverse antipodal representation to have a non-negative "w" 297 | if order in ["xyz", "yzx", "zxy"]: 298 | result *= -1 299 | 300 | return result.reshape(original_shape) 301 | -------------------------------------------------------------------------------- /common/se3.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | # import transforms3d.euler as t3de 4 | import transforms3d.quaternions as t3d 5 | from scipy.spatial.transform import Rotation 6 | 7 | 8 | def torch_identity(batch_size): 9 | return torch.eye(3, 4)[None, ...].repeat(batch_size, 1, 1) 10 | 11 | 12 | def torch_inverse(g): 13 | """ Returns the inverse of the SE3 transform 14 | 15 | Args: 16 | g: (B, 3/4, 4) transform 17 | 18 | Returns: 19 | (B, 3, 4) matrix containing the inverse 20 | 21 | """ 22 | # Compute inverse 23 | rot = g[..., 0:3, 0:3] 24 | trans = g[..., 0:3, 3] 25 | inverse_transform = torch.cat([rot.transpose(-1, -2), rot.transpose(-1, -2) @ -trans[..., None]], dim=-1) 26 | 27 | return inverse_transform 28 | 29 | 30 | def torch_concatenate(a, b): 31 | """Concatenate two SE3 transforms, 32 | i.e. return a@b (but note that our SE3 is represented as a 3x4 matrix) 33 | 34 | Args: 35 | a: (B, 3/4, 4) 36 | b: (B, 3/4, 4) 37 | 38 | Returns: 39 | (B, 3/4, 4) 40 | """ 41 | 42 | rot1 = a[..., :3, :3] 43 | trans1 = a[..., :3, 3] 44 | rot2 = b[..., :3, :3] 45 | trans2 = b[..., :3, 3] 46 | 47 | rot_cat = rot1 @ rot2 48 | trans_cat = rot1 @ trans2[..., None] + trans1[..., None] 49 | concatenated = torch.cat([rot_cat, trans_cat], dim=-1) 50 | 51 | return concatenated 52 | 53 | 54 | def torch_transform(g, a, normals=None): 55 | """ Applies the SE3 transform 56 | 57 | Args: 58 | g: SE3 transformation matrix of size ([1,] 3/4, 4) or (B, 3/4, 4) 59 | a: Points to be transformed (N, 3) or (B, N, 3) 60 | normals: (Optional). If provided, normals will be transformed 61 | 62 | Returns: 63 | transformed points of size (N, 3) or (B, N, 3) 64 | 65 | """ 66 | R = g[..., :3, :3] # (B, 3, 3) 67 | p = g[..., :3, 3] # (B, 3) 68 | 69 | if len(g.size()) == len(a.size()): 70 | b = torch.matmul(a, R.transpose(-1, -2)) + p[..., None, :] 71 | else: 72 | raise NotImplementedError 73 | b = R.matmul(a.unsqueeze(-1)).squeeze(-1) + p # No batch. Not checked 74 | 75 | if normals is not None: 76 | rotated_normals = normals @ R.transpose(-1, -2) 77 | return b, rotated_normals 78 | 79 | else: 80 | return b 81 | 82 | 83 | def torch_mat2quat(M): 84 | all_pose = [] 85 | for i in range(M.size()[0]): 86 | rotate = M[i, :3, :3] 87 | translate = M[i, :3, 3] 88 | 89 | # Qyx refers to the contribution of the y input vector component to 90 | # the x output vector component. Qyx is therefore the same as 91 | # M[0,1]. 
The notation is from the Wikipedia article. 92 | Qxx, Qyx, Qzx, Qxy, Qyy, Qzy, Qxz, Qyz, Qzz = rotate.flatten() 93 | # print(Qxx, Qyx, Qzx, Qxy, Qyy, Qzy, Qxz, Qyz, Qzz) 94 | # Fill only lower half of symmetric matrix 95 | K = torch.tensor([[Qxx - Qyy - Qzz, 0, 0, 0], [Qyx + Qxy, Qyy - Qxx - Qzz, 0, 0], [Qzx + Qxz, Qzy + Qyz, Qzz - Qxx - Qyy, 0], 96 | [Qyz - Qzy, Qzx - Qxz, Qxy - Qyx, Qxx + Qyy + Qzz]]) / 3.0 97 | # Use Hermitian eigenvectors, values for speed 98 | vals, vecs = torch.symeig(K, True, False) 99 | # Select largest eigenvector, reorder to w,x,y,z quaternion 100 | 101 | q = vecs[[3, 0, 1, 2], torch.argmax(vals)] 102 | # Prefer quaternion with positive w 103 | # (q * -1 corresponds to same rotation as q) 104 | if q[0] < 0: 105 | q *= -1 106 | 107 | pose = torch.cat((q, translate), dim=0) 108 | all_pose.append(pose) 109 | all_pose = torch.stack(all_pose, dim=0) 110 | return all_pose # (B, 7) 111 | 112 | 113 | def np_identity(): 114 | return np.eye(3, 4) 115 | 116 | 117 | def np_transform(g: np.ndarray, pts: np.ndarray): 118 | """ Applies the SE3 transform 119 | 120 | Args: 121 | g: SE3 transformation matrix of size ([B,] 3/4, 4) 122 | pts: Points to be transformed ([B,] N, 3) 123 | 124 | Returns: 125 | transformed points of size (N, 3) 126 | """ 127 | rot = g[..., :3, :3] # (3, 3) 128 | trans = g[..., :3, 3] # (3) 129 | 130 | transformed = pts[..., :3] @ np.swapaxes(rot, -1, -2) + trans[..., None, :] 131 | return transformed 132 | 133 | 134 | def np_inverse(g: np.ndarray): 135 | """Returns the inverse of the SE3 transform 136 | 137 | Args: 138 | g: ([B,] 3/4, 4) transform 139 | 140 | Returns: 141 | ([B,] 3/4, 4) matrix containing the inverse 142 | 143 | """ 144 | rot = g[..., :3, :3] # (3, 3) 145 | trans = g[..., :3, 3] # (3) 146 | 147 | inv_rot = np.swapaxes(rot, -1, -2) 148 | inverse_transform = np.concatenate([inv_rot, inv_rot @ -trans[..., None]], axis=-1) 149 | if g.shape[-2] == 4: 150 | inverse_transform = np.concatenate([inverse_transform, [[0.0, 0.0, 0.0, 1.0]]], axis=-2) 151 | 152 | return inverse_transform 153 | 154 | 155 | def np_concatenate(a: np.ndarray, b: np.ndarray): 156 | """ Concatenate two SE3 transforms 157 | 158 | Args: 159 | a: First transform ([B,] 3/4, 4) 160 | b: Second transform ([B,] 3/4, 4) 161 | 162 | Returns: 163 | a*b ([B, ] 3/4, 4) 164 | 165 | """ 166 | 167 | r_a, t_a = a[..., :3, :3], a[..., :3, 3] 168 | r_b, t_b = b[..., :3, :3], b[..., :3, 3] 169 | 170 | r_ab = r_a @ r_b 171 | t_ab = r_a @ t_b[..., None] + t_a[..., None] 172 | 173 | concatenated = np.concatenate([r_ab, t_ab], axis=-1) 174 | 175 | if a.shape[-2] == 4: 176 | concatenated = np.concatenate([concatenated, [[0.0, 0.0, 0.0, 1.0]]], axis=-2) 177 | 178 | return concatenated 179 | 180 | 181 | def np_from_xyzquat(xyzquat): 182 | """Constructs SE3 matrix from x, y, z, qx, qy, qz, qw 183 | 184 | Args: 185 | xyzquat: np.array (7,) containing translation and quaterion 186 | 187 | Returns: 188 | SE3 matrix (4, 4) 189 | """ 190 | rot = Rotation.from_quat(xyzquat[3:]) 191 | trans = rot.apply(-xyzquat[:3]) 192 | transform = np.concatenate([rot.as_dcm(), trans[:, None]], axis=1) 193 | transform = np.concatenate([transform, [[0.0, 0.0, 0.0, 1.0]]], axis=0) 194 | 195 | return transform 196 | 197 | 198 | def np_mat2quat(transform): 199 | rotate = transform[:3, :3] 200 | translate = transform[:3, 3] 201 | quat = t3d.mat2quat(rotate) 202 | pose = np.concatenate([quat, translate], axis=0) 203 | return pose # (7, ) 204 | 205 | 206 | def np_quat2mat(pose): 207 | # Separate each quaternion value. 
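    # pose: (B, 7) array laid out as (qw, qx, qy, qz, tx, ty, tz),
    # i.e. a w-first unit quaternion followed by the translation.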
208 |     q0, q1, q2, q3 = pose[:, 0], pose[:, 1], pose[:, 2], pose[:, 3]
209 |     # Convert quaternion to rotation matrix.
210 |     # Ref: http://www-evasion.inrialpes.fr/people/Franck.Hetroy/Teaching/ProjetsImage/2007/Bib/besl_mckay-pami1992.pdf
211 |     # A method for Registration of 3D shapes paper by Paul J. Besl and Neil D McKay.
212 |     R11 = q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3
213 |     R12 = 2 * (q1 * q2 - q0 * q3)
214 |     R13 = 2 * (q1 * q3 + q0 * q2)
215 |     R21 = 2 * (q1 * q2 + q0 * q3)
216 |     R22 = q0 * q0 + q2 * q2 - q1 * q1 - q3 * q3
217 |     R23 = 2 * (q2 * q3 - q0 * q1)
218 |     R31 = 2 * (q1 * q3 - q0 * q2)
219 |     R32 = 2 * (q2 * q3 + q0 * q1)
220 |     R33 = q0 * q0 + q3 * q3 - q1 * q1 - q2 * q2
221 |     R = np.stack((np.stack((R11, R12, R13), axis=0), np.stack((R21, R22, R23), axis=0), np.stack((R31, R32, R33), axis=0)), axis=0)
222 | 
223 |     rot_mat = R.transpose((2, 0, 1))  # (B, 3, 3)
224 |     translation = pose[:, 4:][:, :, None]  # (B, 3, 1)
225 |     transform = np.concatenate((rot_mat, translation), axis=2)
226 |     return transform  # (B, 3, 4)
227 | 
--------------------------------------------------------------------------------
/common/so3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from pytorch3d import transforms as p3d_transforms
4 | from scipy.spatial.transform import Rotation
5 | 
6 | 
7 | def np_dcm2euler(mats: np.ndarray, seq: str = "zyx", degrees: bool = True):
8 |     """Converts rotation matrices to Euler angles
9 | 
10 |     Args:
11 |         mats: (B, 3, 3) containing the B rotation matrices
12 |         seq: Sequence of euler rotations (default: "zyx")
13 |         degrees (bool): If true (default), will return in degrees instead of radians
14 | 
15 |     Returns:
16 |         (B, 3) array of Euler angles
17 |     """
18 | 
19 |     eulers = []
20 |     for i in range(mats.shape[0]):
21 |         r = Rotation.from_matrix(mats[i])
22 |         eulers.append(r.as_euler(seq, degrees=degrees))
23 |     return np.stack(eulers)
24 | 
25 | 
26 | def np_transform(g: np.ndarray, pts: np.ndarray):
27 |     """ Applies the SO3 transform
28 | 
29 |     Args:
30 |         g: SO3 transformation matrix of size (B, 3, 3)
31 |         pts: Points to be transformed (B, N, 3)
32 | 
33 |     Returns:
34 |         transformed points of size (B, N, 3)
35 | 
36 |     """
37 |     rot = g[..., :3, :3]  # (B, 3, 3)
38 |     transformed = pts[..., :3] @ np.swapaxes(rot, -1, -2)
39 |     return transformed
40 | 
41 | 
42 | def np_inverse(g: np.ndarray):
43 |     """Returns the inverse of the SO3 transform
44 | 
45 |     Args:
46 |         g: ([B,] 3, 3) rotation matrix
47 | 
48 |     Returns:
49 |         ([B,] 3, 3) matrix containing the inverse
50 | 
51 |     """
52 |     rot = g[..., :3, :3]  # ([B,] 3, 3)
53 | 
54 |     inv_rot = np.swapaxes(rot, -1, -2)
55 | 
56 |     return inv_rot
57 | 
58 | 
59 | def torch_dcm2euler(mats, seq, degrees=True):
60 |     if seq == "xyz":
61 |         eulers = p3d_transforms.matrix_to_euler_angles(mats, "ZYX")
62 |     elif seq == "zyx":
63 |         eulers = p3d_transforms.matrix_to_euler_angles(mats, "XYZ")
64 |         eulers = eulers[:, [2, 1, 0]]
65 |     if degrees:
66 |         eulers = eulers / np.pi * 180
67 |     return eulers
68 | 
69 | 
70 | def torch_quat2mat(quat):
71 |     x, y, z, w = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3]  # note: (x, y, z, w) order here, unlike quaternion.py
72 | 
73 |     B = quat.size(0)
74 | 
75 |     w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
76 |     wx, wy, wz = w * x, w * y, w * z
77 |     xy, xz, yz = x * y, x * z, y * z
78 | 
79 |     rotMat = torch.stack([
80 |         w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy,
81 |         2 * wx + 2 * yz, w2 - x2 - y2 + z2
82 |     ],
83 |                          dim=1).reshape(B, 3, 3)
84 |     return rotMat
85 | 
86 | 
87 | if __name__ == "__main__":
88 |     anglex = np.pi / 2
89 |     angley = np.pi / 2
90 |     anglez = 0
91 |     cosx = np.cos(anglex)
92 |     cosy = np.cos(angley)
93 |     cosz = np.cos(anglez)
94 |     sinx = np.sin(anglex)
95 |     siny = np.sin(angley)
96 |     sinz = np.sin(anglez)
97 |     Rx = np.array([[1, 0, 0], [0, cosx, -sinx], [0, sinx, cosx]])
98 |     Ry = np.array([[cosy, 0, siny], [0, 1, 0], [-siny, 0, cosy]])
99 |     Rz = np.array([[cosz, -sinz, 0], [sinz, cosz, 0], [0, 0, 1]])
100 | 
101 |     np_mat = (Rx @ Ry @ Rz)[None, ...]
102 |     torch_mat = torch.from_numpy(np_mat)
103 | 
104 |     np_inv_mat = np_inverse(np_mat)
105 | 
106 |     np_euler = np_dcm2euler(np_inv_mat, "xyz")
107 |     torch_euler = torch_dcm2euler(torch_mat, "xyz")
108 |     print("=" * 50)
109 |     print(np_euler)
110 |     print(torch_euler)
111 | 
112 |     # src = np.array([[[1, 0, 0]]])
113 |     # ref_forward = np_transform(np_mat, src)
114 |     # print(ref_forward)
115 |     # ref_backward = np_transform(np_inv_mat, ref_forward)
116 |     # print(ref_backward)
--------------------------------------------------------------------------------
/common/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | 
4 | import torch
5 | import coloredlogs
6 | import numpy as np
7 | 
8 | 
9 | class Params():
10 |     """Class that loads hyperparameters from a json file.
11 | 
12 |     Example:
13 |     ```
14 |     params = Params(json_path)
15 |     print(params.learning_rate)
16 |     params.learning_rate = 0.5  # change the value of learning_rate in params
17 |     ```
18 |     """
19 |     def __init__(self, json_path):
20 |         with open(json_path) as f:
21 |             params = json.load(f)
22 |             self.update(params)
23 | 
24 |     def save(self, json_path):
25 |         with open(json_path, "w") as f:
26 |             json.dump(self.__dict__, f, indent=4)
27 | 
28 |     def update(self, new_dict):
29 |         """Updates parameters from a dict"""
30 |         self.__dict__.update(new_dict)
31 | 
32 |     @property
33 |     def dict(self):
34 |         """Gives dict-like access to Params instance by `params.dict["learning_rate"]`"""
35 |         return self.__dict__
36 | 
37 | 
38 | class RunningAverage():
39 |     """A simple class that maintains the running average of a quantity
40 |     Example:
41 |     ```
42 |     loss_avg = RunningAverage()
43 |     loss_avg.update(2)
44 |     loss_avg.update(4)
45 |     loss_avg()  # returns 3.0
46 |     ```
47 |     """
48 |     def __init__(self):
49 |         self.steps = 0
50 |         self.total = 0
51 | 
52 |     def update(self, val):
53 |         self.total += val
54 |         self.steps += 1
55 | 
56 |     def __call__(self):
57 |         return self.total / float(self.steps)
58 | 
59 | 
60 | class AverageMeter():
61 |     def __init__(self):
62 |         self.reset()
63 | 
64 |     def reset(self):
65 |         self.val = 0
66 |         self.val_previous = 0
67 |         self.avg = 0
68 |         self.sum = 0
69 |         self.count = 0
70 | 
71 |     def set(self, val):
72 |         self.val = val
73 |         self.avg = val
74 | 
75 |     def update(self, val, num):
76 |         self.val_previous = self.val
77 |         self.val = val
78 |         self.sum += val * num
79 |         self.count += num
80 |         self.avg = self.sum / self.count
81 | 
82 | 
83 | class NpzMaker():
84 |     @classmethod
85 |     def save_npz(cls, files, npz_save_path):
86 |         np.savez(npz_save_path, files=[files, 0])
87 | 
88 |     @classmethod
89 |     def load_npz(cls, npz_save_path):
90 |         with np.load(npz_save_path) as fin:
91 |             files = fin['files']
92 |             files = list(files)
93 |         return files[0]
94 | 
95 | 
96 | def loss_meter_manager_intial(loss_meter_names):
97 |     # Initialize the required loss meters according to the given meter names
98 |     loss_meters = []
99 |     for name in loss_meter_names:
100 |         exec("%s = %s" % (name, "AverageMeter()"))
101 |         exec("loss_meters.append(%s)" % name)
102 | 
103 |     return loss_meters
104 | 
105 | 
106 | def tensor_gpu(batch, check_on=True):
107 |     def check_on_gpu(tensor_):
108 |         # strings and lists are passed through unchanged; calling .cuda()/.float() on them would fail
109 |         if isinstance(tensor_, (str, list)):
110 |             return tensor_
111 |         tensor_g = tensor_.cuda()
112 |         return tensor_g.float()
113 | 
114 |     def check_off_gpu(tensor_):
115 |         if isinstance(tensor_, (str, list)):
116 |             return tensor_
117 | 
118 |         if tensor_.is_cuda:
119 |             tensor_c = tensor_.cpu()
120 |         else:
121 |             tensor_c = tensor_
122 |         tensor_c = tensor_c.detach().numpy()
123 |         return tensor_c
124 | 
125 |     if torch.cuda.is_available():
126 |         if check_on:
127 |             for k, v in batch.items():
128 |                 batch[k] = check_on_gpu(v)
129 |         else:
130 |             for k, v in batch.items():
131 |                 batch[k] = check_off_gpu(v)
132 |     else:
133 |         if check_on:
134 |             pass  # no GPU available, leave the batch as-is
135 |         else:
136 |             for k, v in batch.items():
137 |                 batch[k] = v.detach().numpy()
138 | 
139 |     return batch
140 | 
141 | 
142 | def set_logger(log_path):
143 |     """Set the logger to log info in terminal and file `log_path`.
144 | 
145 |     In general, it is useful to have a logger so that every output to the terminal is saved
146 |     in a permanent file. Here we save it to `model_dir/train.log`.
147 | 
148 |     Example:
149 |     ```
150 |     logging.info("Starting training...")
151 |     ```
152 | 
153 |     Args:
154 |         log_path: (string) where to log
155 |     """
156 |     logger = logging.getLogger()
157 |     logger.setLevel(logging.INFO)
158 | 
159 |     # if not logger.handlers:
160 |     #     # Logging to a file
161 |     #     file_handler = logging.FileHandler(log_path)
162 |     #     file_handler.setFormatter(logging.Formatter("%(asctime)s:%(levelname)s: %(message)s"))
163 |     #     logger.addHandler(file_handler)
164 |     #
165 |     #     # Logging to console
166 |     #     stream_handler = logging.StreamHandler()
167 |     #     stream_handler.setFormatter(logging.Formatter("%(message)s"))
168 |     #     logger.addHandler(stream_handler)
169 | 
170 |     coloredlogs.install(level="INFO", logger=logger, fmt="%(asctime)s %(name)s %(message)s")
171 |     file_handler = logging.FileHandler(log_path)
172 |     log_formatter = logging.Formatter("%(asctime)s - %(message)s")
173 |     file_handler.setFormatter(log_formatter)
174 |     logger.addHandler(file_handler)
175 |     logger.info("Output and logs will be saved to {}".format(log_path))
176 |     return logger
177 | 
178 | 
179 | def save_dict_to_json(d, json_path):
180 |     """Saves dict of floats in json file
181 | 
182 |     Args:
183 |         d: (dict) of float-castable values (np.float, int, float, etc.)
184 |         json_path: (string) path to json file
185 |     """
186 |     save_dict = {}
187 |     with open(json_path, "w") as f:
188 |         # We need to convert the values to float for json (it doesn't accept np.array, np.float, etc.)
189 |         for k, v in d.items():
190 |             if isinstance(v, AverageMeter):
191 |                 save_dict[k] = float(v.avg)
192 |             else:
193 |                 save_dict[k] = float(v)
194 |         json.dump(save_dict, f, indent=4)
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hxwork/OMNet_Pytorch/0c6d669c70f79d15b24d29a749b054a1adbe3c2f/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import random
4 | import logging
5 | 
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import DataLoader, Dataset
9 | 
10 | from dataset.transformations import fetch_transform
11 | 
12 | _logger = logging.getLogger(__name__)
13 | 
14 | 
15 | def worker_init_fn(worker_id):
16 |     rand_seed = random.randint(0, 2**32 - 1)
17 |     random.seed(rand_seed)
18 |     np.random.seed(rand_seed)
19 |     torch.manual_seed(rand_seed)
20 |     torch.cuda.manual_seed(rand_seed)
21 |     torch.cuda.manual_seed_all(rand_seed)
22 | 
23 | 
24 | class ModelNetNpy(Dataset):
25 |     def __init__(self, dataset_path: str, dataset_mode: str, subset: str = "train", categories=None, transform=None):
26 |         """ModelNet40 OS/TS data loaded from preprocessed .npy files.
27 |         """
28 |         self._logger = logging.getLogger(self.__class__.__name__)
29 |         self._root = dataset_path
30 |         self._subset = subset
31 | 
32 |         metadata_fpath = os.path.join(self._root, "modelnet_{}_{}.pickle".format(dataset_mode, subset))
33 |         self._logger.info("Loading data from {} for {}".format(metadata_fpath, subset))
34 | 
35 |         if not os.path.exists(dataset_path):
36 |             raise FileNotFoundError("Dataset path not found: {}".format(dataset_path))
37 | 
38 |         with open(os.path.join(dataset_path, "shape_names.txt")) as fid:
39 |             self._classes = [l.strip() for l in fid]
40 |             self._category2idx = {e[1]: e[0] for e in enumerate(self._classes)}
41 |             self._idx2category = self._classes
42 | 
43 |         if categories is not None:
44 |             categories_idx = [self._category2idx[c] for c in categories]
45 |             self._logger.info("Categories used: {}.".format(categories_idx))
46 |             self._classes = categories
47 |         else:
48 |             categories_idx = None
49 |             self._logger.info("Using all categories.")
50 | 
51 |         self._data = self._read_pickle_files(os.path.join(dataset_path, "modelnet_{}_{}.pickle".format(dataset_mode, subset)),
52 |                                              categories_idx)
53 | 
54 |         self._transform = transform
55 |         self._logger.info("Loaded {} {} instances.".format(len(self._data), subset))
56 | 
57 |     @property
58 |     def classes(self):
59 |         return self._classes
60 | 
61 |     @staticmethod
62 |     def _read_pickle_files(fnames, categories):
63 | 
64 |         all_data_dict = []
65 |         with open(fnames, "rb") as f:
66 |             data = pickle.load(f)
67 | 
68 |         for category in categories:
69 |             all_data_dict.extend(data[category])
70 | 
71 |         return all_data_dict
72 | 
73 |     def to_category(self, i):
74 |         return self._idx2category[i]
75 | 
76 |     def __getitem__(self, item):
77 | 
78 |         data_path = self._data[item]
79 | 
80 |         # load and process data
81 |         points = np.load(data_path)
82 |         idx = np.array(int(os.path.splitext(os.path.basename(data_path))[0].split("_")[1]))
83 |         label = np.array(int(os.path.splitext(os.path.basename(data_path))[0].split("_")[3]))
84 |         sample = {"points": points, "label": label, "idx": idx}
85 | 
86 |         if self._transform:
87 |             sample = self._transform(sample)
88 |         return sample
89 | 
90 |     def __len__(self):
91 |         return len(self._data)
92 | 
93 | 
94 | def fetch_dataloader(params):
95 |     _logger.info("Dataset type: {}, transform type: {}".format(params.dataset_type, params.transform_type))
96 |     train_transforms, test_transforms = fetch_transform(params)
97 |     if params.dataset_type == "modelnet_os":
98 |         dataset_path = "./dataset/data/modelnet_os"
99 |         train_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half1_rm_rotate.txt")]
100 |         val_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half1_rm_rotate.txt")]
101 |         test_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half2_rm_rotate.txt")]
102 |         train_categories.sort()
103 |         val_categories.sort()
104 |         test_categories.sort()
105 |         train_ds = ModelNetNpy(dataset_path, dataset_mode="os", subset="train", categories=train_categories, transform=train_transforms)
106 |         val_ds = ModelNetNpy(dataset_path, dataset_mode="os", subset="val", categories=val_categories, transform=test_transforms)
107 |         test_ds = ModelNetNpy(dataset_path, dataset_mode="os", subset="test", categories=test_categories, transform=test_transforms)
108 | 
109 |     elif params.dataset_type == "modelnet_ts":
110 |         dataset_path = "./dataset/data/modelnet_ts"
111 |         train_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half1_rm_rotate.txt")]
112 |         val_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half1_rm_rotate.txt")]
113 |         test_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half2_rm_rotate.txt")]
114 |         train_categories.sort()
115 |         val_categories.sort()
116 |         test_categories.sort()
117 |         train_ds = ModelNetNpy(dataset_path, dataset_mode="ts", subset="train", categories=train_categories, transform=train_transforms)
118 |         val_ds = ModelNetNpy(dataset_path, dataset_mode="ts", subset="val", categories=val_categories, transform=test_transforms)
119 |         test_ds = ModelNetNpy(dataset_path, dataset_mode="ts", subset="test", categories=test_categories, transform=test_transforms)
120 | 
121 |     else:
122 |         raise NotImplementedError
123 | 
124 |     dataloaders = {}
125 |     params.prefetch_factor = 5
126 |     # add default train data loader
127 |     train_dl = DataLoader(train_ds,
128 |                           batch_size=params.train_batch_size,
129 |                           shuffle=True,
130 |                           num_workers=params.num_workers,
131 |                           pin_memory=params.cuda,
132 |                           drop_last=True,
133 |                           prefetch_factor=params.prefetch_factor,
134 |                           worker_init_fn=worker_init_fn)
135 |     dataloaders["train"] = train_dl
136 | 
137 |     # choose the val or test data loader for evaluation
138 |     for split in ["val", "test"]:
139 |         if split in params.eval_type:
140 |             if split == "val":
141 |                 dl = DataLoader(val_ds,
142 |                                 batch_size=params.eval_batch_size,
143 |                                 shuffle=False,
144 |                                 num_workers=params.num_workers,
145 |                                 pin_memory=params.cuda,
146 |                                 prefetch_factor=params.prefetch_factor)
147 |             elif split == "test":
148 |                 dl = DataLoader(test_ds,
149 |                                 batch_size=params.eval_batch_size,
150 |                                 shuffle=False,
151 |                                 num_workers=params.num_workers,
152 |                                 pin_memory=params.cuda,
153 |                                 prefetch_factor=params.prefetch_factor)
154 |             else:
155 |                 raise ValueError("Unknown eval_type in params, should be in [val, test]")
156 |             dataloaders[split] = dl
157 |         else:
158 |             dataloaders[split] = None
159 | 
160 |
return dataloaders 161 | -------------------------------------------------------------------------------- /dataset/transformations.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | from common import se3, so3 7 | from scipy.spatial.transform import Rotation 8 | from scipy.stats import special_ortho_group 9 | 10 | _logger = logging.getLogger(__name__) 11 | 12 | 13 | def uniform_2_sphere(num: int = None): 14 | """Uniform sampling on a 2-sphere 15 | 16 | Source: https://gist.github.com/andrewbolster/10274979 17 | 18 | Args: 19 | num: Number of vectors to sample (or None if single) 20 | 21 | Returns: 22 | Random Vector (np.ndarray) of size (num, 3) with norm 1. 23 | If num is None returned value will have size (3,) 24 | 25 | """ 26 | if num is not None: 27 | phi = np.random.uniform(0.0, 2 * np.pi, num) 28 | cos_theta = np.random.uniform(-1.0, 1.0, num) 29 | else: 30 | phi = np.random.uniform(0.0, 2 * np.pi) 31 | cos_theta = np.random.uniform(-1.0, 1.0) 32 | 33 | theta = np.arccos(cos_theta) 34 | x = np.sin(theta) * np.cos(phi) 35 | y = np.sin(theta) * np.sin(phi) 36 | z = np.cos(theta) 37 | 38 | return np.stack((x, y, z), axis=-1) 39 | 40 | 41 | class SplitSourceRef: 42 | """Clones the point cloud into separate source and reference point clouds""" 43 | def __init__(self, mode="hdf"): 44 | self.mode = mode 45 | 46 | def __call__(self, sample): 47 | if "deterministic" in sample and sample["deterministic"]: 48 | np.random.seed(sample["idx"]) 49 | 50 | if self.mode == "hdf": 51 | sample["points_raw"] = sample.pop("points").astype(np.float32)[:, :3] 52 | sample["points_src"] = sample["points_raw"].copy() 53 | sample["points_ref"] = sample["points_raw"].copy() 54 | sample["points_src_raw"] = sample["points_src"].copy().astype(np.float32) 55 | sample["points_ref_raw"] = sample["points_ref"].copy().astype(np.float32) 56 | 57 | elif self.mode == "resample": 58 | points_raw = sample.pop("points").astype(np.float32) 59 | sample["points_src"] = points_raw[np.random.choice(points_raw.shape[0], 2048, replace=False), :] 60 | sample["points_ref"] = points_raw[np.random.choice(points_raw.shape[0], 2048, replace=False), :] 61 | sample["points_src_raw"] = sample["points_src"].copy().astype(np.float32) 62 | sample["points_ref_raw"] = sample["points_ref"].copy().astype(np.float32) 63 | 64 | elif self.mode == "donothing": 65 | points_raw = sample.pop("points").astype(np.float32) 66 | points_raw = points_raw[np.random.choice(points_raw.shape[0], 2, replace=False), :, :] 67 | sample["points_src"] = points_raw[0, :, :].astype(np.float32) 68 | sample["points_ref"] = points_raw[1, :, :].astype(np.float32) 69 | sample["points_src_raw"] = sample["points_src"].copy() 70 | sample["points_ref_raw"] = sample["points_ref"].copy() 71 | 72 | elif self.mode == "7scenes": 73 | points_src = sample["points_src"].astype(np.float32) 74 | points_ref = sample["points_ref"].astype(np.float32) 75 | sample["points_src_raw"] = points_src.astype(np.float32) 76 | sample["points_ref_raw"] = points_ref.astype(np.float32) 77 | if sample["num_points"] == -1: 78 | sample["points_src"] = points_src 79 | sample["points_ref"] = points_ref 80 | else: 81 | if points_src.shape[0] > sample["num_points"]: 82 | sample["points_src"] = points_src[np.random.choice(points_src.shape[0], sample["num_points"], replace=False), :] 83 | else: 84 | rand_idxs = np.concatenate([ 85 | np.random.choice(points_src.shape[0], 
points_src.shape[0], replace=False), 86 | np.random.choice(points_src.shape[0], sample["num_points"] - points_src.shape[0], replace=True) 87 | ]) 88 | sample["points_src"] = points_src[rand_idxs, :] 89 | 90 | if points_ref.shape[0] > sample["num_points"]: 91 | sample["points_ref"] = points_ref[np.random.choice(points_ref.shape[0], sample["num_points"], replace=False), :] 92 | else: 93 | rand_idxs = np.concatenate([ 94 | np.random.choice(points_ref.shape[0], points_ref.shape[0], replace=False), 95 | np.random.choice(points_ref.shape[0], sample["num_points"] - points_ref.shape[0], replace=True) 96 | ]) 97 | sample["points_ref"] = points_ref[rand_idxs, :] 98 | 99 | elif self.mode == "kitti_real":  # NOTE: gather_operation / furthest_point_sample below come from an external PointNet++ CUDA ops library; they are not imported in this file, so this mode fails unless those ops are provided 100 | points_src = sample["points_src"].astype(np.float32) 101 | points_ref = sample["points_ref"].astype(np.float32) 102 | points_src = torch.from_numpy(points_src).unsqueeze(0).cuda() 103 | points_ref = torch.from_numpy(points_ref).unsqueeze(0).cuda() 104 | points_src_flipped = points_src.transpose(1, 2).contiguous().detach() 105 | points_src_fps = gather_operation(points_src_flipped, furthest_point_sample(points_src, 106 | 4096)).transpose(1, 2).contiguous().detach() 107 | points_ref_flipped = points_ref.transpose(1, 2).contiguous().detach() 108 | points_ref_fps = gather_operation(points_ref_flipped, furthest_point_sample(points_ref, 109 | 4096)).transpose(1, 2).contiguous().detach() 110 | sample["points_src"] = points_src_fps.squeeze(0).cpu().numpy() 111 | sample["points_ref"] = points_ref_fps.squeeze(0).cpu().numpy() 112 | sample["points_src_raw"] = sample["points_src"].copy().astype(np.float32) 113 | sample["points_ref_raw"] = sample["points_ref"].copy().astype(np.float32) 114 | 115 | elif self.mode == "kitti_sync": 116 | points_raw = sample.pop("points").astype(np.float32) 117 | sample["points_src"] = points_raw[np.random.choice(points_raw.shape[0], 4096, replace=False), :] 118 | sample["points_ref"] = points_raw[np.random.choice(points_raw.shape[0], 4096, replace=False), :] 119 | sample["points_src_raw"] = sample["points_src"].copy().astype(np.float32) 120 | sample["points_ref_raw"] = sample["points_ref"].copy().astype(np.float32) 121 | 122 | else: 123 | raise NotImplementedError 124 | 125 | return sample 126 | 127 | 128 | class Resampler: 129 | def __init__(self, num: int): 130 | """Resamples a point cloud containing N points to one containing M 131 | 132 | Guaranteed to have no repeated points if M <= N. 133 | Otherwise, it is guaranteed that all points appear at least once. 134 | 135 | Args: 136 | num (int): Number of points to resample to, i.e. M 137 | 138 | """ 139 | self.num = num 140 | 141 | @staticmethod 142 | def _resample(points, k): 143 | """Resamples the points such that there are exactly k points. 144 | 145 | If the input point cloud has <= k points, it is guaranteed that the 146 | resampled point cloud contains every point in the input. 147 | If the input point cloud has > k points, it is guaranteed that the 148 | resampled point cloud does not contain repeated points. 
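        Illustrative usage (a sketch, not part of the original docstring):

            pts = np.random.rand(100, 3).astype(np.float32)
            Resampler._resample(pts, 64).shape    # (64, 3), no repeated points
            Resampler._resample(pts, 128).shape   # (128, 3), every input point kept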
149 | """ 150 | # print("===", points.shape[0], k) 151 | if k < points.shape[0]: 152 | rand_idxs = np.random.choice(points.shape[0], k, replace=False) 153 | return points[rand_idxs, :] 154 | elif points.shape[0] == k: 155 | return points 156 | else: 157 | rand_idxs = np.concatenate([ 158 | np.random.choice(points.shape[0], points.shape[0], replace=False), 159 | np.random.choice(points.shape[0], k - points.shape[0], replace=True) 160 | ]) 161 | return points[rand_idxs, :] 162 | 163 | def __call__(self, sample): 164 | 165 | if "deterministic" in sample and sample["deterministic"]: 166 | np.random.seed(sample["idx"]) 167 | 168 | if "points" in sample: 169 | sample["points"] = self._resample(sample["points"], self.num) 170 | else: 171 | if "crop_proportion" not in sample: 172 | src_size, ref_size = self.num, self.num 173 | elif len(sample["crop_proportion"]) == 1: 174 | src_size = math.ceil(sample["crop_proportion"][0] * self.num) 175 | ref_size = self.num 176 | elif len(sample["crop_proportion"]) == 2: 177 | src_size = math.ceil(sample["crop_proportion"][0] * self.num) 178 | ref_size = math.ceil(sample["crop_proportion"][1] * self.num) 179 | else: 180 | raise ValueError("Crop proportion must have 1 or 2 elements") 181 | 182 | sample["points_src"] = self._resample(sample["points_src"], src_size) 183 | sample["points_ref"] = self._resample(sample["points_ref"], ref_size) 184 | 185 | # sample for the raw point clouds 186 | sample["points_src_raw"] = sample["points_src_raw"][:self.num, :] 187 | sample["points_ref_raw"] = sample["points_ref_raw"][:self.num, :] 188 | 189 | return sample 190 | 191 | 192 | class RandomJitter: 193 | """ generate perturbations """ 194 | def __init__(self, noise_std=0.01, clip=0.05): 195 | self.noise_std = noise_std 196 | self.clip = clip 197 | 198 | def jitter(self, pts): 199 | 200 | noise = np.clip(np.random.normal(0.0, scale=self.noise_std, size=(pts.shape[0], 3)), a_min=-self.clip, a_max=self.clip) 201 | pts[:, :3] += noise # Add noise to xyz 202 | 203 | return pts 204 | 205 | def __call__(self, sample): 206 | 207 | if "points" in sample: 208 | sample["points"] = self.jitter(sample["points"]) 209 | else: 210 | sample["points_src"] = self.jitter(sample["points_src"]) 211 | sample["points_ref"] = self.jitter(sample["points_ref"]) 212 | 213 | return sample 214 | 215 | 216 | class RandomCrop: 217 | """Randomly crops the *source* point cloud, approximately retaining half the points 218 | 219 | A direction is randomly sampled from S2, and we retain points which lie within the 220 | half-space oriented in this direction. 
221 | If p_keep != 0.5, we shift the plane until approximately p_keep points are retained 222 | """ 223 | def __init__(self, p_keep=None): 224 | if p_keep is None: 225 | p_keep = [0.7, 0.7] # Crop both clouds to 70% 226 | self.p_keep = np.array(p_keep, dtype=np.float32) 227 | 228 | @staticmethod 229 | def crop(points, p_keep): 230 | if p_keep == 1.0: 231 | mask = np.ones(shape=(points.shape[0], )) > 0 232 | 233 | else: 234 | rand_xyz = uniform_2_sphere() 235 | centroid = np.mean(points[:, :3], axis=0) 236 | points_centered = points[:, :3] - centroid 237 | dist_from_plane = np.dot(points_centered, rand_xyz) 238 | 239 | if p_keep == 0.5: 240 | mask = dist_from_plane > 0 241 | else: 242 | mask = dist_from_plane > np.percentile(dist_from_plane, (1.0 - p_keep) * 100) 243 | 244 | return points[mask, :] 245 | 246 | def __call__(self, sample): 247 | 248 | if "deterministic" in sample and sample["deterministic"]: 249 | np.random.seed(sample["idx"]) 250 | 251 | sample["crop_proportion"] = self.p_keep 252 | 253 | if len(sample["crop_proportion"]) == 1: 254 | sample["points_src"] = self.crop(sample["points_src"], self.p_keep[0]) 255 | sample["points_ref"] = self.crop(sample["points_ref"], 1.0) 256 | else: 257 | sample["points_src"] = self.crop(sample["points_src"], self.p_keep[0]) 258 | sample["points_ref"] = self.crop(sample["points_ref"], self.p_keep[1]) 259 | 260 | return sample 261 | 262 | 263 | class RandomTransformSE3: 264 | def __init__(self, rot_mag: float = 180.0, trans_mag: float = 1.0, random_mag: bool = False): 265 | """Applies a random rigid transformation to the source point cloud 266 | 267 | Args: 268 | rot_mag (float): Maximum rotation in degrees 269 | trans_mag (float): Maximum translation T. Random translation will 270 | be in the range [-X,X] in each axis 271 | random_mag (bool): If true, will randomize the maximum rotation, i.e. 
will bias towards small 272 | perturbations 273 | """ 274 | self._rot_mag = rot_mag 275 | self._trans_mag = trans_mag 276 | self._random_mag = random_mag 277 | 278 | def generate_transform(self): 279 | """Generate a random SE3 transformation (3, 4) """ 280 | 281 | if self._random_mag: 282 | attenuation = np.random.random() 283 | rot_mag, trans_mag = attenuation * self._rot_mag, attenuation * self._trans_mag 284 | else: 285 | rot_mag, trans_mag = self._rot_mag, self._trans_mag 286 | 287 | # Generate rotation 288 | rand_rot = special_ortho_group.rvs(3) 289 | axis_angle = Rotation.as_rotvec(Rotation.from_matrix(rand_rot))  # from_dcm/as_dcm were renamed to from_matrix/as_matrix in SciPy >= 1.4 290 | axis_angle *= rot_mag / 180.0 291 | rand_rot = Rotation.from_rotvec(axis_angle).as_matrix() 292 | 293 | # Generate translation 294 | rand_trans = np.random.uniform(-trans_mag, trans_mag, 3) 295 | rand_SE3 = np.concatenate((rand_rot, rand_trans[:, None]), axis=1).astype(np.float32) 296 | 297 | return rand_SE3 298 | 299 | def apply_transform(self, p0, transform_mat): 300 | p1 = se3.np_transform(transform_mat, p0[:, :3]) 301 | if p0.shape[1] == 6: # Need to rotate normals also 302 | n1 = so3.transform(transform_mat[:3, :3], p0[:, 3:6]) 303 | p1 = np.concatenate((p1, n1), axis=-1) 304 | 305 | igt = transform_mat 306 | gt = se3.np_inverse(igt) 307 | 308 | return p1, gt, igt 309 | 310 | def transform(self, tensor): 311 | transform_mat = self.generate_transform() 312 | return self.apply_transform(tensor, transform_mat) 313 | 314 | def __call__(self, sample): 315 | 316 | if "deterministic" in sample and sample["deterministic"]: 317 | np.random.seed(sample["idx"]) 318 | 319 | if "points" in sample: 320 | sample["points"], _, _ = self.transform(sample["points"]) 321 | else: 322 | src_transformed, transform_r_s, transform_s_r = self.transform(sample["points_src"]) 323 | # Apply to source to get reference 324 | sample["transform_gt"] = transform_r_s 325 | sample["pose_gt"] = se3.np_mat2quat(transform_r_s) 326 | sample["transform_igt"] = transform_s_r 327 | sample["points_src"] = src_transformed 328 | # transform the raw source point cloud 329 | sample["points_src_raw"] = se3.np_transform(transform_s_r, sample["points_src_raw"][:, :3]) 330 | 331 | return sample 332 | 333 | 334 | # noinspection PyPep8Naming 335 | class RandomTransformSE3_euler(RandomTransformSE3): 336 | """Same as RandomTransformSE3, but rotates using Euler angle rotations 337 | 338 | This transformation is consistent with Deep Closest Point but does not 339 | generate uniform rotations 340 | 341 | """ 342 | def generate_transform(self): 343 | 344 | if self._random_mag: 345 | attenuation = np.random.random() 346 | rot_mag, trans_mag = attenuation * self._rot_mag, attenuation * self._trans_mag 347 | else: 348 | rot_mag, trans_mag = self._rot_mag, self._trans_mag 349 | 350 | # Generate rotation 351 | anglex = np.random.uniform() * np.pi * rot_mag / 180.0 352 | angley = np.random.uniform() * np.pi * rot_mag / 180.0 353 | anglez = np.random.uniform() * np.pi * rot_mag / 180.0 354 | 355 | cosx = np.cos(anglex) 356 | cosy = np.cos(angley) 357 | cosz = np.cos(anglez) 358 | sinx = np.sin(anglex) 359 | siny = np.sin(angley) 360 | sinz = np.sin(anglez) 361 | Rx = np.array([[1, 0, 0], [0, cosx, -sinx], [0, sinx, cosx]]) 362 | Ry = np.array([[cosy, 0, siny], [0, 1, 0], [-siny, 0, cosy]]) 363 | Rz = np.array([[cosz, -sinz, 0], [sinz, cosz, 0], [0, 0, 1]]) 364 | R_ab = Rx @ Ry @ Rz 365 | t_ab = np.random.uniform(-trans_mag, trans_mag, 3) 366 | 367 | rand_SE3 = np.concatenate((R_ab, t_ab[:, None]), axis=1).astype(np.float32) 368 | 
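        # Annotation (not in the original source): rand_SE3 is a (3, 4) [R | t]
        # matrix; a homogeneous 4x4 transform can be recovered, if ever needed, via
        #   np.concatenate([rand_SE3, np.array([[0., 0., 0., 1.]], dtype=np.float32)], axis=0)
        # Unlike the parent class, which rescales rotations drawn uniformly from
        # SO(3) via special_ortho_group, composing three independently sampled
        # Euler angles here does not yield uniformly distributed rotations
        # (cf. the class docstring).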
return rand_SE3 369 | 370 | 371 | class ShufflePoints: 372 | """Shuffles the order of the points""" 373 | def __call__(self, sample): 374 | if "points" in sample: 375 | sample["points"] = np.random.permutation(sample["points"]) 376 | else: 377 | sample["points_ref"] = np.random.permutation(sample["points_ref"]) 378 | sample["points_src"] = np.random.permutation(sample["points_src"]) 379 | 380 | return sample 381 | 382 | 383 | class SetDeterministic: 384 | """Adds a deterministic flag to the sample such that subsequent transforms 385 | use a fixed random seed where applicable. Used for test""" 386 | def __call__(self, sample): 387 | sample["deterministic"] = True 388 | return sample 389 | 390 | 391 | class PRNetTorch: 392 | def __init__(self, num_points, rot_mag, trans_mag, noise_std=0.01, clip=0.05, add_noise=True, only_z=False, partial=True): 393 | self.num_points = num_points 394 | self.rot_mag = rot_mag 395 | self.trans_mag = trans_mag 396 | self.noise_std = noise_std 397 | self.clip = clip 398 | self.add_noise = add_noise 399 | self.only_z = only_z 400 | self.partial = partial 401 | 402 | def apply_transform(self, p0, transform_mat): 403 | p1 = se3.np_transform(transform_mat, p0[:, :3]) 404 | if p0.shape[1] == 6: # Need to rotate normals also 405 | n1 = so3.transform(transform_mat[:3, :3], p0[:, 3:6]) 406 | p1 = np.concatenate((p1, n1), axis=-1) 407 | 408 | gt = transform_mat 409 | 410 | return p1, gt 411 | 412 | def jitter(self, pts): 413 | 414 | noise = np.clip(np.random.normal(0.0, scale=self.noise_std, size=(pts.shape[0], 3)), a_min=-self.clip, a_max=self.clip) 415 | noise = torch.from_numpy(noise).to(pts.device) 416 | pts[:, :3] += noise # Add noise to xyz 417 | 418 | return pts 419 | 420 | def knn(self, pts, random_pt, k): 421 | random_pt = torch.from_numpy(random_pt).to(pts.device) 422 | distance = torch.sum((pts - random_pt)**2, dim=1) 423 | idx = distance.topk(k=k, dim=0, largest=False)[1] # (batch_size, num_points, k) 424 | return idx 425 | 426 | def __call__(self, sample): 427 | 428 | if "deterministic" in sample and sample["deterministic"]: 429 | np.random.seed(sample["idx"]) 430 | 431 | src = sample["points_src"] 432 | ref = sample["points_ref"] 433 | # Generate rigid transform 434 | anglex = np.random.uniform() * np.pi * self.rot_mag / 180.0 435 | angley = np.random.uniform() * np.pi * self.rot_mag / 180.0 436 | anglez = np.random.uniform() * np.pi * self.rot_mag / 180.0 437 | 438 | cosx = np.cos(anglex) 439 | cosy = np.cos(angley) 440 | cosz = np.cos(anglez) 441 | sinx = np.sin(anglex) 442 | siny = np.sin(angley) 443 | sinz = np.sin(anglez) 444 | Rx = np.array([[1, 0, 0], [0, cosx, -sinx], [0, sinx, cosx]]) 445 | Ry = np.array([[cosy, 0, siny], [0, 1, 0], [-siny, 0, cosy]]) 446 | Rz = np.array([[cosz, -sinz, 0], [sinz, cosz, 0], [0, 0, 1]]) 447 | 448 | if not self.only_z: 449 | R_ab = Rx @ Ry @ Rz 450 | else: 451 | R_ab = Rz 452 | t_ab = np.random.uniform(-self.trans_mag, self.trans_mag, 3) 453 | 454 | rand_SE3 = np.concatenate((R_ab, t_ab[:, None]), axis=1).astype(np.float32) 455 | ref, transform_s_r = self.apply_transform(ref, rand_SE3) 456 | # Apply to source to get reference 457 | sample["transform_gt"] = transform_s_r 458 | sample["pose_gt"] = se3.np_mat2quat(transform_s_r) 459 | 460 | # Crop and sample 461 | if self.partial: 462 | src = torch.from_numpy(src) 463 | ref = torch.from_numpy(ref) 464 | random_p1 = np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 1, -1]) 465 | idx1 = self.knn(src, random_p1, k=768) 466 | # 
np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 2, -2]) 467 | random_p2 = random_p1 468 | idx2 = self.knn(ref, random_p2, k=768) 469 | else: 470 | idx1 = np.random.choice(src.shape[0], 1024, replace=False), 471 | idx2 = np.random.choice(ref.shape[0], 1024, replace=False), 472 | # src = np.squeeze(src, axis=-1) 473 | # ref = np.squeeze(ref, axis=-1) 474 | src = torch.from_numpy(src) 475 | ref = torch.from_numpy(ref) 476 | 477 | # add noise 478 | if self.add_noise: 479 | sample["points_src"] = self.jitter(src[idx1, :]) 480 | sample["points_ref"] = self.jitter(ref[idx2, :]) 481 | else: 482 | sample["points_src"] = src[idx1, :] 483 | sample["points_ref"] = ref[idx2, :] 484 | if sample["points_src"].size()[0] == 1: 485 | sample["points_src"] = sample["points_src"].squeeze(0) 486 | sample["points_ref"] = sample["points_ref"].squeeze(0) 487 | 488 | # # for inference time 489 | # if sample["points_src"].shape[0] < self.num_points: 490 | # rand_idxs = np.concatenate( 491 | # [np.random.choice(sample["points_src"].shape[0], sample["points_src"].shape[0], replace=False), 492 | # np.random.choice(sample["points_src"].shape[0], self.num_points - sample["points_src"].shape[0], 493 | # replace=True)]) 494 | # sample["points_src"] = sample["points_src"][rand_idxs, :] 495 | # rand_idxs = np.concatenate( 496 | # [np.random.choice(sample["points_ref"].shape[0], sample["points_ref"].shape[0], replace=False), 497 | # np.random.choice(sample["points_ref"].shape[0], self.num_points - sample["points_ref"].shape[0], 498 | # replace=True)]) 499 | # sample["points_ref"] = sample["points_ref"][rand_idxs, :] 500 | # else: 501 | # rand_idxs = np.random.choice(sample["points_src"].shape[0], self.num_points, replace=False), 502 | # sample["points_src"] = sample["points_src"][rand_idxs, :] 503 | # rand_idxs = np.random.choice(sample["points_ref"].shape[0], self.num_points, replace=False), 504 | # sample["points_ref"] = sample["points_ref"][rand_idxs, :] 505 | # if sample["points_src"].shape[0] == 1: 506 | # sample["points_src"] = sample["points_src"].squeeze(0) 507 | # sample["points_ref"] = sample["points_ref"].squeeze(0) 508 | 509 | return sample 510 | 511 | 512 | class PRNetTorchOverlapRatio: 513 | def __init__(self, num_points, rot_mag, trans_mag, noise_std=0.01, clip=0.05, add_noise=True, overlap_ratio=0.8): 514 | self.num_points = num_points 515 | self.rot_mag = rot_mag 516 | self.trans_mag = trans_mag 517 | self.noise_std = noise_std 518 | self.clip = clip 519 | self.add_noise = add_noise 520 | self.overlap_ratio = overlap_ratio 521 | 522 | def apply_transform(self, p0, transform_mat): 523 | p1 = se3.np_transform(transform_mat, p0[:, :3]) 524 | if p0.shape[1] == 6: # Need to rotate normals also 525 | n1 = so3.transform(transform_mat[:3, :3], p0[:, 3:6]) 526 | p1 = np.concatenate((p1, n1), axis=-1) 527 | 528 | gt = transform_mat 529 | 530 | return p1, gt 531 | 532 | def jitter(self, pts): 533 | 534 | noise = np.clip(np.random.normal(0.0, scale=self.noise_std, size=(pts.shape[0], 3)), a_min=-self.clip, a_max=self.clip) 535 | noise = torch.from_numpy(noise).to(pts.device) 536 | pts[:, :3] += noise # Add noise to xyz 537 | 538 | return pts 539 | 540 | def knn(self, pts, random_pt, k1, k2): 541 | random_pt = torch.from_numpy(random_pt).to(pts.device) 542 | distance = torch.sum((pts - random_pt)**2, dim=1) 543 | idx1 = distance.topk(k=k1, dim=0, largest=False)[1] # (batch_size, num_points, k) 544 | idx2 = distance.topk(k=k2, dim=0, largest=True)[1] 545 | return idx1, idx2 546 | 547 
| def __call__(self, sample): 548 | 549 | if "deterministic" in sample and sample["deterministic"]: 550 | np.random.seed(sample["idx"]) 551 | 552 | src = sample["points_src"] 553 | ref = sample["points_ref"] 554 | 555 | # Crop and sample 556 | src = torch.from_numpy(src) 557 | ref = torch.from_numpy(ref) 558 | random_p1 = np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 1, -1]) 559 | # pdb.set_trace() 560 | src_idx1, src_idx2 = self.knn(src, random_p1, k1=768, k2=2048 - 768) 561 | ref_idx1, ref_idx2 = self.knn(ref, random_p1, k1=768, k2=2048 - 768) 562 | 563 | k1_idx = ref_idx1[np.random.randint(ref_idx1.size()[0])] 564 | k1 = ref[k1_idx, :] 565 | k2_idx = ref_idx1[0] 566 | k2 = ref[k2_idx, :] 567 | 568 | distance = torch.sum((ref[ref_idx1, :] - k1)**2, dim=1) 569 | overlap_idx = distance.topk(k=int(768 * self.overlap_ratio), dim=0, largest=False)[1] 570 | k1_points = ref[ref_idx1, :][overlap_idx, :] 571 | distance = torch.sum((ref[ref_idx2, :] - k2)**2, dim=1) 572 | nonoverlap_idx = distance.topk(k=768 - int(768 * self.overlap_ratio), dim=0, largest=False)[1] 573 | k2_points = ref[ref_idx2, :][nonoverlap_idx, :] 574 | ref = torch.cat((k1_points, k2_points), dim=0) 575 | src = src[src_idx1, :] 576 | # pdb.set_trace() 577 | 578 | # Generate rigid transform 579 | anglex = np.random.uniform() * np.pi * self.rot_mag / 180.0 580 | angley = np.random.uniform() * np.pi * self.rot_mag / 180.0 581 | anglez = np.random.uniform() * np.pi * self.rot_mag / 180.0 582 | 583 | cosx = np.cos(anglex) 584 | cosy = np.cos(angley) 585 | cosz = np.cos(anglez) 586 | sinx = np.sin(anglex) 587 | siny = np.sin(angley) 588 | sinz = np.sin(anglez) 589 | Rx = np.array([[1, 0, 0], [0, cosx, -sinx], [0, sinx, cosx]]) 590 | Ry = np.array([[cosy, 0, siny], [0, 1, 0], [-siny, 0, cosy]]) 591 | Rz = np.array([[cosz, -sinz, 0], [sinz, cosz, 0], [0, 0, 1]]) 592 | 593 | R_ab = Rx @ Ry @ Rz 594 | t_ab = np.random.uniform(-self.trans_mag, self.trans_mag, 3) 595 | 596 | rand_SE3 = np.concatenate((R_ab, t_ab[:, None]), axis=1).astype(np.float32) 597 | ref, transform_s_r = self.apply_transform(ref, rand_SE3) 598 | # Apply to source to get reference 599 | sample["transform_gt"] = transform_s_r 600 | sample["pose_gt"] = se3.np_mat2quat(transform_s_r) 601 | 602 | # add noise 603 | if self.add_noise: 604 | sample["points_src"] = self.jitter(src) 605 | sample["points_ref"] = self.jitter(ref) 606 | else: 607 | sample["points_src"] = src 608 | sample["points_ref"] = ref 609 | 610 | if sample["points_src"].size()[0] == 1: 611 | sample["points_src"] = sample["points_src"].squeeze(0) 612 | sample["points_ref"] = sample["points_ref"].squeeze(0) 613 | 614 | return sample 615 | 616 | 617 | def fetch_transform(params): 618 | 619 | if params.transform_type == "modelnet_os_rpmnet_noise": 620 | train_transforms = [ 621 | SplitSourceRef(mode="hdf"), 622 | RandomCrop(params.partial_ratio), 623 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 624 | Resampler(params.num_points), 625 | RandomJitter(), 626 | ShufflePoints() 627 | ] 628 | 629 | test_transforms = [ 630 | SetDeterministic(), 631 | SplitSourceRef(mode="hdf"), 632 | RandomCrop(params.partial_ratio), 633 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 634 | Resampler(params.num_points), 635 | RandomJitter(), 636 | ShufflePoints() 637 | ] 638 | 639 | elif params.transform_type == "modelnet_os_rpmnet_clean": 640 | train_transforms = [ 641 | SplitSourceRef(mode="hdf"), 642 | 
RandomCrop(params.partial_ratio), 643 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 644 | Resampler(params.num_points), 645 | # RandomJitter(), 646 | ShufflePoints() 647 | ] 648 | 649 | test_transforms = [ 650 | SetDeterministic(), 651 | SplitSourceRef(mode="hdf"), 652 | RandomCrop(params.partial_ratio), 653 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 654 | Resampler(params.num_points), 655 | # RandomJitter(), 656 | ShufflePoints() 657 | ] 658 | 659 | elif params.transform_type == "modelnet_ts_rpmnet_noise": 660 | train_transforms = [ 661 | SplitSourceRef(mode="donothing"), 662 | RandomCrop(params.partial_ratio), 663 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 664 | Resampler(params.num_points), 665 | RandomJitter(noise_std=params.noise_std), 666 | ShufflePoints() 667 | ] 668 | 669 | test_transforms = [ 670 | SetDeterministic(), 671 | SplitSourceRef(mode="donothing"), 672 | RandomCrop(params.partial_ratio), 673 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 674 | Resampler(params.num_points), 675 | RandomJitter(noise_std=params.noise_std), 676 | ShufflePoints() 677 | ] 678 | 679 | elif params.transform_type == "modelnet_ts_rpmnet_clean": 680 | train_transforms = [ 681 | SplitSourceRef(mode="donothing"), 682 | RandomCrop(params.partial_ratio), 683 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 684 | Resampler(params.num_points), 685 | # RandomJitter(), 686 | ShufflePoints() 687 | ] 688 | 689 | test_transforms = [ 690 | SetDeterministic(), 691 | SplitSourceRef(mode="donothing"), 692 | RandomCrop(params.partial_ratio), 693 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 694 | Resampler(params.num_points), 695 | # RandomJitter(), 696 | ShufflePoints() 697 | ] 698 | 699 | elif params.transform_type == "modelnet_ts_prnet_noise": 700 | train_transforms = [ 701 | SplitSourceRef(mode="donothing"), 702 | ShufflePoints(), 703 | PRNetTorch(num_points=params.num_points, 704 | rot_mag=params.rot_mag, 705 | trans_mag=params.trans_mag, 706 | noise_std=params.noise_std, 707 | add_noise=True) 708 | ] 709 | 710 | test_transforms = [ 711 | SetDeterministic(), 712 | SplitSourceRef(mode="donothing"), 713 | ShufflePoints(), 714 | PRNetTorch(num_points=params.num_points, 715 | rot_mag=params.rot_mag, 716 | trans_mag=params.trans_mag, 717 | noise_std=params.noise_std, 718 | add_noise=True) 719 | ] 720 | 721 | elif params.transform_type == "modelnet_ts_prnet_clean": 722 | train_transforms = [ 723 | SplitSourceRef(mode="donothing"), 724 | ShufflePoints(), 725 | PRNetTorch(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=False) 726 | ] 727 | 728 | test_transforms = [ 729 | SetDeterministic(), 730 | SplitSourceRef(mode="donothing"), 731 | ShufflePoints(), 732 | PRNetTorch(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=False) 733 | ] 734 | 735 | elif params.transform_type == "modelnet_os_prnet_noise": 736 | train_transforms = [ 737 | SplitSourceRef(mode="hdf"), 738 | ShufflePoints(), 739 | PRNetTorch(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=True) 740 | ] 741 | 742 | test_transforms = [ 743 | SetDeterministic(), 744 | SplitSourceRef(mode="hdf"), 745 | ShufflePoints(), 746 | PRNetTorch(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=True) 
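            # (Annotation, not in the original source: transform order matters --
            # SetDeterministic must come first so that the stochastic transforms
            # after it see the fixed per-sample seed flag during evaluation.)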
747 | ] 748 | 749 | elif params.transform_type == "modelnet_os_prnet_clean": 750 | train_transforms = [ 751 | SplitSourceRef(mode="hdf"), 752 | ShufflePoints(), 753 | PRNetTorch(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=False) 754 | ] 755 | 756 | test_transforms = [ 757 | SetDeterministic(), 758 | SplitSourceRef(mode="hdf"), 759 | ShufflePoints(), 760 | PRNetTorch(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=False) 761 | ] 762 | 763 | elif params.transform_type == "modelnet_os_prnet_clean_onlyz": 764 | train_transforms = [ 765 | SplitSourceRef(mode="hdf"), 766 | ShufflePoints(), 767 | PRNetTorch(num_points=params.num_points, 768 | rot_mag=params.rot_mag, 769 | trans_mag=params.trans_mag, 770 | add_noise=False, 771 | only_z=True, 772 | partial=False) 773 | ] 774 | 775 | test_transforms = [ 776 | SetDeterministic(), 777 | SplitSourceRef(mode="hdf"), 778 | ShufflePoints(), 779 | PRNetTorch(num_points=params.num_points, 780 | rot_mag=params.rot_mag, 781 | trans_mag=params.trans_mag, 782 | add_noise=False, 783 | only_z=True, 784 | partial=False) 785 | ] 786 | 787 | elif params.transform_type == "modelnet_ts_prnet_clean_onlyz": 788 | train_transforms = [ 789 | SplitSourceRef(mode="donothing"), 790 | ShufflePoints(), 791 | PRNetTorch(num_points=params.num_points, 792 | rot_mag=params.rot_mag, 793 | trans_mag=params.trans_mag, 794 | add_noise=False, 795 | only_z=True, 796 | partial=False) 797 | ] 798 | 799 | test_transforms = [ 800 | SetDeterministic(), 801 | SplitSourceRef(mode="donothing"), 802 | ShufflePoints(), 803 | PRNetTorch(num_points=params.num_points, 804 | rot_mag=params.rot_mag, 805 | trans_mag=params.trans_mag, 806 | add_noise=False, 807 | only_z=True, 808 | partial=False) 809 | ] 810 | 811 | elif params.transform_type == "modelnet_ts_prnet_noise_overlap_ratio": 812 | train_transforms = [ 813 | SplitSourceRef(mode="donothing"), 814 | PRNetTorchOverlapRatio(num_points=params.num_points, 815 | rot_mag=params.rot_mag, 816 | trans_mag=params.trans_mag, 817 | add_noise=True, 818 | overlap_ratio=params.overlap_ratio), 819 | ShufflePoints() 820 | ] 821 | 822 | test_transforms = [ 823 | SetDeterministic(), 824 | SplitSourceRef(mode="donothing"), 825 | PRNetTorchOverlapRatio(num_points=params.num_points, 826 | rot_mag=params.rot_mag, 827 | trans_mag=params.trans_mag, 828 | add_noise=True, 829 | overlap_ratio=params.overlap_ratio), 830 | ShufflePoints() 831 | ] 832 | 833 | else: 834 | raise NotImplementedError 835 | 836 | _logger.info("Train transforms: {}".format(", ".join([type(t).__name__ for t in train_transforms]))) 837 | _logger.info("Val and Test transforms: {}".format(", ".join([type(t).__name__ for t in test_transforms]))) 838 | train_transforms = torchvision.transforms.Compose(train_transforms) 839 | test_transforms = torchvision.transforms.Compose(test_transforms) 840 | return train_transforms, test_transforms 841 | 842 | 843 | if __name__ == "__main__": 844 | print("hello world") 845 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from collections import defaultdict 4 | 5 | import time 6 | import torch 7 | 8 | import dataset.data_loader as data_loader 9 | import model.net as net 10 | from common import utils 11 | from common.manager import Manager 12 | from loss.loss import compute_loss, compute_metrics 13 | 14 | parser = 
argparse.ArgumentParser() 15 | parser.add_argument("--model_dir", type=str, default="./experiments/experiment_omnet", help="Directory containing params.json") 16 | parser.add_argument("--restore_file", type=str, help="name of the file in --model_dir containing weights to load") 17 | 18 | 19 | def test(model, manager): 20 | # set model to evaluation mode 21 | torch.cuda.empty_cache() 22 | model.eval() 23 | 24 | with torch.no_grad(): 25 | # compute metrics over the dataset 26 | if manager.dataloaders["val"] is not None: 27 | # inference time 28 | total_time = 0. 29 | all_endpoints = defaultdict(list) 30 | # loss status and val status initial 31 | manager.reset_loss_status() 32 | manager.reset_metric_status("val") 33 | for batch_idx, data_batch in enumerate(manager.dataloaders["val"]): 34 | # move to GPU if available 35 | data_batch = utils.tensor_gpu(data_batch) 36 | # compute model output 37 | output_batch = model(data_batch) 38 | 39 | # real batch size 40 | batch_size = data_batch["points_src"].size()[0] 41 | # compute all loss on this batch 42 | loss = compute_loss(output_batch, manager.params) 43 | manager.update_loss_status(loss, batch_size) 44 | # compute all metrics on this batch 45 | metrics = compute_metrics(output_batch, manager.params) 46 | manager.update_metric_status(metrics, "val", batch_size) 47 | 48 | # compute RMSE metrics 49 | manager.summarize_metric_status(metrics, "val") 50 | # For each epoch, update and print the metric 51 | manager.print_metrics("val", title="Val", color="green") 52 | 53 | if manager.dataloaders["test"] is not None: 54 | # inference time 55 | total_time = {"total": 0.} 56 | total_time_outside = 0. 57 | all_endpoints = defaultdict(list) 58 | # loss status and test status initial 59 | manager.reset_loss_status() 60 | manager.reset_metric_status("test") 61 | for batch_idx, data_batch in enumerate(manager.dataloaders["test"]): 62 | # move to GPU if available 63 | data_batch = utils.tensor_gpu(data_batch) 64 | # compute model output 65 | start_time = time.time() 66 | output_batch = model(data_batch) 67 | total_time_outside += time.time() - start_time 68 | 69 | # real batch size 70 | batch_size = data_batch["points_src"].size()[0] 71 | # compute all loss on this batch 72 | loss = compute_loss(output_batch, manager.params) 73 | manager.update_loss_status(loss, batch_size) 74 | # compute all metrics on this batch 75 | metrics = compute_metrics(output_batch, manager.params) 76 | manager.update_metric_status(metrics, "test", batch_size) 77 | 78 | # compute RMSE metrics 79 | manager.summarize_metric_status(metrics, "test") 80 | # For each epoch, print the metric 81 | manager.print_metrics("test", title="Test", color="red") 82 | 83 | 84 | if __name__ == '__main__': 85 | # Load the parameters 86 | args = parser.parse_args() 87 | json_path = os.path.join(args.model_dir, 'params.json') 88 | assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path) 89 | params = utils.Params(json_path) 90 | 91 | # Only load model weights 92 | params.only_weights = True 93 | 94 | # Update args into params 95 | params.update(vars(args)) 96 | 97 | # Get the logger 98 | logger = utils.set_logger(os.path.join(args.model_dir, 'evaluate.log')) 99 | 100 | # Use GPU if available 101 | params.cuda = torch.cuda.is_available() 102 | if params.cuda: 103 | num_gpu = torch.cuda.device_count() 104 | if num_gpu > 0: 105 | torch.cuda.set_device(0) 106 | gpu_ids = ", ".join(str(i) for i in [j for j in range(num_gpu)]) 107 | logger.info("Using GPU ids: 
[{}]".format(gpu_ids)) 108 | torch.backends.cudnn.deterministic = True 109 | torch.backends.cudnn.benchmark = False 110 | torch.backends.cudnn.enabled = False 111 | 112 | # Fetch dataloaders 113 | dataloaders = data_loader.fetch_dataloader(params) 114 | 115 | # Define the model and optimizer 116 | if params.cuda: 117 | model = net.fetch_net(params).cuda() 118 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) 119 | else: 120 | model = net.fetch_net(params) 121 | 122 | # Initial status for checkpoint manager 123 | manager = Manager(model=model, optimizer=None, scheduler=None, params=params, dataloaders=dataloaders, logger=logger) 124 | 125 | # Reload weights from the saved file 126 | manager.load_checkpoints() 127 | 128 | # Test the model 129 | logger.info("Starting test") 130 | 131 | # Evaluate 132 | test(model, manager) 133 | -------------------------------------------------------------------------------- /experiments/experiment_omnet/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "omnet", 3 | "model_dir": "experiments/experiment_omnet", 4 | "dataset_type": "modelnet_os", 5 | "transform_type": "modelnet_os_prnet_clean", 6 | "net_type": "omnet", 7 | "loss_type": "omnet", 8 | "eval_type": [ 9 | "val", 10 | "test" 11 | ], 12 | "major_metric": "score", 13 | "titer": 4, 14 | "loss_alpha1": 1, 15 | "loss_alpha2": 4, 16 | "gamma": 1, 17 | "rot_mag": 45, 18 | "trans_mag": 0.5, 19 | "noise_std": 0.01, 20 | "save_summary_steps": 100, 21 | "train_batch_size": 64, 22 | "eval_batch_size": 256, 23 | "num_workers": 8, 24 | "partial_ratio": [ 25 | 0.7, 26 | 0.7 27 | ], 28 | "num_points": 1024, 29 | "overlap_dist": 0.1, 30 | "num_epochs": 10000, 31 | "learning_rate": 0.0001 32 | } 33 | -------------------------------------------------------------------------------- /images/OMNet_poster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hxwork/OMNet_Pytorch/0c6d669c70f79d15b24d29a749b054a1adbe3c2f/images/OMNet_poster.png -------------------------------------------------------------------------------- /images/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hxwork/OMNet_Pytorch/0c6d669c70f79d15b24d29a749b054a1adbe3c2f/images/pipeline.png -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hxwork/OMNet_Pytorch/0c6d669c70f79d15b24d29a749b054a1adbe3c2f/loss/__init__.py -------------------------------------------------------------------------------- /loss/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from common import se3, so3 5 | from common.utils import tensor_gpu 6 | 7 | 8 | class LossL1(nn.Module): 9 | def __init__(self): 10 | super(LossL1, self).__init__() 11 | self.loss = nn.L1Loss() 12 | 13 | def __call__(self, input, target): 14 | return self.loss(input, target) 15 | 16 | 17 | class LossL2(nn.Module): 18 | def __init__(self): 19 | super(LossL2, self).__init__() 20 | self.loss = nn.MSELoss() 21 | 22 | def __call__(self, input, target): 23 | return self.loss(input, target) 24 | 25 | 26 | class LossCrossEntropy(nn.Module): 27 | def __init__(self, weight=None): 28 | 
super(LossCrossEntropy, self).__init__() 29 | self.loss = torch.nn.CrossEntropyLoss(weight=weight) 30 | 31 | def __call__(self, input, target, weight=None): 32 | return self.loss(input, target) 33 | 34 | 35 | def compute_loss(endpoints, params): 36 | loss = {} 37 | 38 | l1_criterion = LossL1() 39 | l2_criterion = LossL2() 40 | cls_criterion = LossCrossEntropy(weight=torch.tensor([0.7, 0.3]).cuda()) 41 | num_iter = len(endpoints["all_pose_pair"]) 42 | if params.loss_type == "omnet": 43 | for i in range(num_iter): 44 | # cls loss 45 | src_cls_pair, ref_cls_pair = endpoints['all_src_cls_pair'][i], endpoints['all_ref_cls_pair'][i] 46 | src_cls = cls_criterion(src_cls_pair[1], src_cls_pair[0].long()) 47 | ref_cls = cls_criterion(ref_cls_pair[1], ref_cls_pair[0].long()) 48 | loss['cls_{}'.format(i)] = (src_cls + ref_cls) / 2.0 49 | # reg loss 50 | pose_pair = endpoints["all_pose_pair"][i] 51 | loss["quat_{}".format(i)] = l1_criterion(pose_pair[0][:, :4], pose_pair[1][:, :4]) * params.loss_alpha1 52 | loss["translate_{}".format(i)] = l2_criterion(pose_pair[0][:, 4:], pose_pair[1][:, 4:]) * params.loss_alpha2 53 | # total loss 54 | total_loss = [] 55 | for k in loss: 56 | total_loss.append(loss[k].float()) 57 | loss["total"] = torch.sum(torch.stack(total_loss), dim=0) 58 | else: 59 | raise NotImplementedError 60 | 61 | return loss 62 | 63 | 64 | def compute_metrics(endpoints, params): 65 | metrics = {} 66 | with torch.no_grad(): 67 | gt_transforms = endpoints["transform_pair"][0] 68 | pred_transforms = endpoints["transform_pair"][1] 69 | 70 | # Euler angles, Individual translation errors (Deep Closest Point convention) 71 | if "prnet" in params.transform_type: 72 | r_gt_euler_deg = so3.torch_dcm2euler(gt_transforms[:, :3, :3], seq="zyx") 73 | r_pred_euler_deg = so3.torch_dcm2euler(pred_transforms[:, :3, :3], seq="zyx") 74 | else: 75 | r_gt_euler_deg = so3.torch_dcm2euler(gt_transforms[:, :3, :3], seq="xyz") 76 | r_pred_euler_deg = so3.torch_dcm2euler(pred_transforms[:, :3, :3], seq="xyz") 77 | t_gt = gt_transforms[:, :3, 3] 78 | t_pred = pred_transforms[:, :3, 3] 79 | 80 | r_mse = torch.mean((r_gt_euler_deg - r_pred_euler_deg)**2, dim=1) 81 | r_mae = torch.mean(torch.abs(r_gt_euler_deg - r_pred_euler_deg), dim=1) 82 | t_mse = torch.mean((t_gt - t_pred)**2, dim=1) 83 | t_mae = torch.mean(torch.abs(t_gt - t_pred), dim=1) 84 | 85 | r_mse = torch.mean(r_mse) 86 | t_mse = torch.mean(t_mse) 87 | r_mae = torch.mean(r_mae) 88 | t_mae = torch.mean(t_mae) 89 | 90 | # Rotation, translation errors (isotropic, i.e. 
doesn"t depend on error 91 | # direction, which is more representative of the actual error) 92 | concatenated = se3.torch_concatenate(se3.torch_inverse(gt_transforms), pred_transforms) 93 | rot_trace = concatenated[:, 0, 0] + concatenated[:, 1, 1] + concatenated[:, 2, 2] 94 | residual_rotdeg = torch.acos(torch.clamp(0.5 * (rot_trace - 1), min=-1.0, max=1.0)) * 180.0 / np.pi 95 | residual_transmag = concatenated[:, :, 3].norm(dim=-1) 96 | err_r = torch.mean(residual_rotdeg) 97 | err_t = torch.mean(residual_transmag) 98 | 99 | # weighted score of isotropic errors 100 | score = err_r * 0.01 + err_t 101 | 102 | metrics = {"R_MSE": r_mse, "R_MAE": r_mae, "t_MSE": t_mse, "t_MAE": t_mae, "Err_R": err_r, "Err_t": err_t, "score": score} 103 | metrics = tensor_gpu(metrics, check_on=False) 104 | 105 | return metrics 106 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hxwork/OMNet_Pytorch/0c6d669c70f79d15b24d29a749b054a1adbe3c2f/model/__init__.py -------------------------------------------------------------------------------- /model/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | # ======================================================================================================== 6 | 7 | 8 | class OMNetEncoder(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | self.conv_block1 = nn.Sequential(nn.Conv1d(3, 64, 1), nn.BatchNorm1d(64), nn.ReLU(inplace=True)) 12 | 13 | self.conv_block2 = nn.Sequential(nn.Conv1d(64, 64, 1), nn.BatchNorm1d(64), nn.ReLU(inplace=True)) 14 | 15 | self.conv_block3 = nn.Sequential(nn.Conv1d(64, 64, 1), nn.BatchNorm1d(64), nn.ReLU(inplace=True)) 16 | 17 | self.conv_block4 = nn.Sequential(nn.Conv1d(64, 128, 1), nn.BatchNorm1d(128), nn.ReLU(inplace=True)) 18 | 19 | self.conv_block5 = nn.Sequential(nn.Conv1d(128, 1024, 1), nn.BatchNorm1d(1024), nn.ReLU(inplace=True)) 20 | 21 | def forward(self, x, mask=None): 22 | x = self.conv_block1(x) 23 | point_feat64 = self.conv_block2(x) 24 | x = self.conv_block3(point_feat64) 25 | x = self.conv_block4(x) 26 | point_feat1024 = self.conv_block5(x) 27 | 28 | if mask is None: 29 | L = [point_feat64, point_feat1024] 30 | glob_feat = torch.max(point_feat1024, dim=-1, keepdim=True)[0] 31 | else: 32 | L = [point_feat64 * mask, point_feat1024 * mask] 33 | glob_feat = torch.max(point_feat1024 * mask, dim=-1, keepdim=True)[0] 34 | 35 | return L, glob_feat 36 | 37 | 38 | class OMNetFusion(nn.Module): 39 | def __init__(self): 40 | super().__init__() 41 | self.conv_block1 = nn.Sequential(nn.Conv1d(2048 + 64, 1024, 1), nn.BatchNorm1d(1024), nn.ReLU(inplace=True)) 42 | 43 | self.conv_block2 = nn.Sequential(nn.Conv1d(1024, 1024, 1), nn.BatchNorm1d(1024), nn.ReLU(inplace=True)) 44 | 45 | self.conv_block3 = nn.Sequential(nn.Conv1d(1024, 512, 1), nn.BatchNorm1d(512), nn.ReLU(inplace=True)) 46 | 47 | def forward(self, x, mask=None): 48 | 49 | point_feat1024_0 = self.conv_block1(x) 50 | point_feat1024_1 = self.conv_block2(point_feat1024_0) 51 | fuse_feat = self.conv_block3(point_feat1024_1) 52 | 53 | if mask is None: 54 | L = [point_feat1024_0, point_feat1024_1] 55 | fuse_feat = fuse_feat 56 | 57 | else: 58 | L = [point_feat1024_0 * mask, point_feat1024_1 * mask] 59 | fuse_feat = fuse_feat * mask 60 | 61 | return L, fuse_feat 62 | 63 | 64 | class 
OMNetDecoder(nn.Module): 65 | def __init__(self): 66 | super().__init__() 67 | self.conv_block1 = nn.Sequential( 68 | nn.Conv1d(512, 512, 1), 69 | nn.BatchNorm1d(512), 70 | nn.ReLU(inplace=True), 71 | ) 72 | 73 | self.conv_block2 = nn.Sequential( 74 | nn.Conv1d(512, 256, 1), 75 | nn.BatchNorm1d(256), 76 | nn.ReLU(inplace=True), 77 | ) 78 | 79 | self.conv_block3 = nn.Sequential( 80 | nn.Conv1d(256, 256, 1), 81 | nn.BatchNorm1d(256), 82 | nn.ReLU(inplace=True), 83 | ) 84 | 85 | self.conv_block4 = nn.Sequential(nn.Conv1d(256, 2, 1), ) 86 | 87 | def forward(self, x): 88 | point_feat512 = self.conv_block1(x) 89 | point_feat256_0 = self.conv_block2(point_feat512) 90 | point_feat256_1 = self.conv_block3(point_feat256_0) 91 | L = [point_feat512, point_feat256_0, point_feat256_1] 92 | cls = self.conv_block4(point_feat256_1) 93 | 94 | return L, cls 95 | 96 | 97 | class OMNetRegression(nn.Module): 98 | def __init__(self): 99 | super().__init__() 100 | self.fc_block1 = nn.Sequential( 101 | nn.Linear(3072, 2048), 102 | nn.BatchNorm1d(2048), 103 | nn.ReLU(inplace=True), 104 | ) 105 | 106 | self.fc_block2 = nn.Sequential( 107 | nn.Linear(2048, 1024), 108 | nn.BatchNorm1d(1024), 109 | nn.ReLU(inplace=True), 110 | ) 111 | 112 | self.fc_block3 = nn.Sequential( 113 | nn.Linear(1024, 512), 114 | nn.BatchNorm1d(512), 115 | nn.ReLU(inplace=True), 116 | ) 117 | 118 | self.fc_block4 = nn.Sequential( 119 | nn.Linear(512, 256), 120 | nn.BatchNorm1d(256), 121 | nn.ReLU(inplace=True), 122 | ) 123 | 124 | self.final_fc = nn.Sequential(nn.Linear(256, 7), ) 125 | 126 | def forward(self, x): 127 | x = self.fc_block1(x) 128 | x = self.fc_block2(x) 129 | x = self.fc_block3(x) 130 | x = self.fc_block4(x) 131 | pred_pose = self.final_fc(x) 132 | pred_quat, pred_translate, = pred_pose[:, :4], pred_pose[:, 4:] 133 | pred_quat = F.normalize(pred_quat, dim=1) 134 | pred_pose = torch.cat((pred_quat, pred_translate), dim=1) # (B, 7) 135 | 136 | return pred_pose 137 | -------------------------------------------------------------------------------- /model/net.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import transforms3d.euler as t3d 6 | from common import quaternion, se3 7 | 8 | from model.module import * 9 | 10 | 11 | class OMNet(nn.Module): 12 | def __init__(self, params): 13 | super().__init__() 14 | self._logger = logging.getLogger(self.__class__.__name__) 15 | self.num_iter = params.titer 16 | self.encoder = nn.ModuleList([OMNetEncoder() for _ in range(self.num_iter)]) 17 | self.fusion = nn.ModuleList([OMNetFusion() for _ in range(self.num_iter)]) 18 | self.decoder = nn.ModuleList([OMNetDecoder() for _ in range(self.num_iter)]) 19 | self.regression = nn.ModuleList([OMNetRegression() for _ in range(self.num_iter)]) 20 | self.overlap_dist = params.overlap_dist 21 | 22 | def generate_overlap_mask(self, points_src: torch.Tensor, points_ref: torch.Tensor, mask_src: torch.Tensor, mask_ref: torch.Tensor, 23 | transform_gt: torch.Tensor): 24 | points_src[torch.logical_not(mask_src), :] = 50.0 25 | points_ref[torch.logical_not(mask_ref), :] = 100.0 26 | points_src = se3.torch_transform(transform_gt, points_src) 27 | dist_matrix = torch.sqrt(torch.sum(torch.square(points_src[:, :, None, :] - points_ref[:, None, :, :]), dim=-1)) # (B, N, N) 28 | dist_s2r = torch.min(dist_matrix, dim=2)[0] 29 | dist_r2s = torch.min(dist_matrix, dim=1)[0] 30 | overlap_src_mask = dist_s2r < self.overlap_dist # (B, N) 31 | 
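        # Annotation (not in the original source): the 50.0 / 100.0 sentinel
        # values assigned to masked-out points above place them tens of units
        # away from the unit-scale clouds and from each other, so they can never
        # fall within overlap_dist of any point and are excluded from both masks.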
overlap_ref_mask = dist_r2s < self.overlap_dist # (B, N) 32 | 33 | return overlap_src_mask, overlap_ref_mask 34 | 35 | def forward(self, data): 36 | endpoints = {} 37 | 38 | xyz_src = data["points_src"][:, :, :3] 39 | xyz_ref = data["points_ref"][:, :, :3] 40 | transform_gt = data["transform_gt"] 41 | pose_gt = data["pose_gt"] 42 | 43 | # init endpoints 44 | all_src_cls_pair = [] 45 | all_ref_cls_pair = [] 46 | all_transform_pair = [] 47 | all_pose_pair = [] 48 | 49 | # init params 50 | B, src_N, _ = xyz_src.size() 51 | _, ref_N, _ = xyz_ref.size() 52 | init_quat = t3d.euler2quat(0., 0., 0., "sxyz") 53 | init_quat = torch.from_numpy(init_quat).expand(B, 4) 54 | init_translate = torch.from_numpy(np.array([[0., 0., 0.]])).expand(B, 3) 55 | pose_pred = torch.cat((init_quat, init_translate), dim=1).float().cuda() # (B, 7) 56 | transform_pred = quaternion.torch_quat2mat(pose_pred) 57 | src_pred_mask = torch.ones(size=(B, src_N), dtype=xyz_src.dtype).cuda() 58 | ref_pred_mask = torch.ones(size=(B, ref_N), dtype=xyz_ref.dtype).cuda() 59 | overlap_src_mask, overlap_ref_mask = self.generate_overlap_mask(xyz_src.clone(), xyz_ref.clone(), src_pred_mask, ref_pred_mask, 60 | transform_gt) 61 | 62 | # working copy of xyz_src, updated after every iteration 63 | xyz_src_iter = xyz_src.clone() 64 | 65 | for i in range(self.num_iter): 66 | # mask delay: keep all-ones masks for the first two iterations 67 | if i < 2: 68 | src_pred_mask = torch.ones(size=(B, src_N), dtype=xyz_src.dtype).cuda() 69 | ref_pred_mask = torch.ones(size=(B, ref_N), dtype=xyz_ref.dtype).cuda() 70 | 71 | # encoder 72 | src_encoder_feats, src_glob_feat = self.encoder[i](xyz_src_iter.transpose(1, 2).detach(), src_pred_mask.unsqueeze(1)) 73 | ref_encoder_feats, ref_glob_feat = self.encoder[i](xyz_ref.transpose(1, 2), ref_pred_mask.unsqueeze(1)) 74 | 75 | # fusion 76 | src_cat_feat = torch.cat((src_encoder_feats[0], src_glob_feat.repeat(1, 1, src_N), ref_glob_feat.repeat(1, 1, src_N)), dim=1) 77 | ref_cat_feat = torch.cat((ref_encoder_feats[0], ref_glob_feat.repeat(1, 1, ref_N), src_glob_feat.repeat(1, 1, ref_N)), dim=1) 78 | _, src_fused_feat = self.fusion[i](src_cat_feat, src_pred_mask.unsqueeze(1)) 79 | _, ref_fused_feat = self.fusion[i](ref_cat_feat, ref_pred_mask.unsqueeze(1)) 80 | 81 | # decoder 82 | src_decoder_feats, src_cls_pred = self.decoder[i](src_fused_feat) 83 | ref_decoder_feats, ref_cls_pred = self.decoder[i](ref_fused_feat) 84 | 85 | # regression 86 | src_feat = torch.cat(src_decoder_feats, dim=1) * src_pred_mask.unsqueeze(1) 87 | ref_feat = torch.cat(ref_decoder_feats, dim=1) * ref_pred_mask.unsqueeze(1) 88 | cat_feat = torch.cat((src_fused_feat, src_feat, ref_fused_feat, ref_feat), dim=1) 89 | cat_feat = torch.max(cat_feat, dim=-1)[0] 90 | pose_pred_iter = self.regression[i](cat_feat) # (B, 7) 91 | xyz_src_iter = quaternion.torch_quat_transform(pose_pred_iter, xyz_src_iter.detach()) 92 | pose_pred = quaternion.torch_transform_pose(pose_pred.detach(), pose_pred_iter) 93 | transform_pred = quaternion.torch_quat2mat(pose_pred) 94 | 95 | # compute overlap and cls gt 96 | overlap_src_mask, overlap_ref_mask = self.generate_overlap_mask(xyz_src.clone(), xyz_ref.clone(), src_pred_mask, ref_pred_mask, 97 | transform_gt) 98 | src_cls_gt = torch.ones(B, src_N).cuda() * overlap_src_mask 99 | ref_cls_gt = torch.ones(B, ref_N).cuda() * overlap_ref_mask 100 | src_pred_mask = torch.argmax(src_cls_pred, dim=1) 101 | ref_pred_mask = torch.argmax(ref_cls_pred, dim=1) 102 | 103 | # add endpoints 104 | all_src_cls_pair.append([src_cls_gt, src_cls_pred]) 105 | all_ref_cls_pair.append([ref_cls_gt, ref_cls_pred]) 106 | 
all_transform_pair.append([transform_gt, transform_pred]) 107 | all_pose_pair.append([pose_gt, pose_pred]) 108 | 109 | endpoints["all_src_cls_pair"] = all_src_cls_pair 110 | endpoints["all_ref_cls_pair"] = all_ref_cls_pair 111 | endpoints["all_transform_pair"] = all_transform_pair 112 | endpoints["all_pose_pair"] = all_pose_pair 113 | endpoints["transform_pair"] = [transform_gt, transform_pred] 114 | endpoints["pose_pair"] = [pose_gt, pose_pred] 115 | 116 | return endpoints 117 | 118 | 119 | def fetch_net(params): 120 | if params.net_type == "omnet": 121 | net = OMNet(params) 122 | 123 | else: 124 | raise NotImplementedError 125 | 126 | return net 127 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | coloredlogs==15.0.1 2 | h5py==3.1.0 3 | numpy==1.19.5 4 | pytorch3d==0.3.0 5 | scipy==1.5.3 6 | termcolor==1.1.0 7 | torch==1.9.1 8 | torchvision==0.10.1 9 | tqdm==4.62.3 10 | transforms3d==0.3.1 11 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | 5 | import torch 6 | import torch.optim as optim 7 | from tqdm import tqdm 8 | 9 | import dataset.data_loader as data_loader 10 | import model.net as net 11 | from loss.loss import compute_loss, compute_metrics 12 | from common import utils 13 | from common.manager import Manager 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model_dir", default="experiments/base_model", help="Directory containing params.json") 17 | parser.add_argument("--restore_file", 18 | default=None, 19 | help="Optional, name of the file in --model_dir containing weights to reload before training") 20 | parser.add_argument("-ow", "--only_weights", action="store_true", help="Only use weights to load or load all train status.") 21 | 22 | 23 | def train(model, manager: Manager): 24 | """Train the model on `num_steps` batches 25 | 26 | Args: 27 | model: (torch.nn.Module) the neural network 28 | optimizer: (torch.optim) optimizer for parameters of model 29 | loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch 30 | dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data 31 | metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch 32 | params: (Params) hyperparameters 33 | num_steps: (int) number of batches to train on, each of size params.batch_size 34 | """ 35 | 36 | # loss status initial 37 | manager.reset_loss_status() 38 | 39 | # set model to training mode 40 | torch.cuda.empty_cache() 41 | model.train() 42 | 43 | # Use tqdm for progress bar 44 | with tqdm(total=len(manager.dataloaders["train"])) as t: 45 | for batch_idx, data_batch in enumerate(manager.dataloaders["train"]): 46 | # move to GPU if available 47 | data_batch = utils.tensor_gpu(data_batch) 48 | 49 | # compute model output and loss 50 | output_batch = model(data_batch) 51 | losses = compute_loss(output_batch, manager.params) 52 | 53 | # real batch size 54 | batch_size = data_batch["points_src"].size()[0] 55 | 56 | # update loss status and print current loss and average loss 57 | manager.update_loss_status(loss=losses, batch_size=batch_size) 58 | 59 | # clear previous gradients, compute gradients of all variables wrt loss 60 | 
manager.optimizer.zero_grad() 61 | losses["total"].backward() 62 | # performs updates using calculated gradients 63 | manager.optimizer.step() 64 | 65 | manager.write_loss_to_tb(split="train") 66 | 67 | # update step: step += 1 68 | manager.update_step() 69 | 70 | # info print 71 | print_str = manager.print_train_info() 72 | 73 | t.set_description(desc=print_str) 74 | t.update() 75 | 76 | manager.scheduler.step() 77 | # update epoch: epoch += 1 78 | manager.update_epoch() 79 | 80 | 81 | def evaluate(model, manager: Manager): 82 | """Evaluate the model on `num_steps` batches. 83 | 84 | Args: 85 | model: (torch.nn.Module) the neural network 86 | loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch 87 | dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches data 88 | metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch 89 | params: (Params) hyperparameters 90 | num_steps: (int) number of batches to train on, each of size params.batch_size 91 | """ 92 | 93 | # set model to evaluation mode 94 | torch.cuda.empty_cache() 95 | model.eval() 96 | with torch.no_grad(): 97 | # compute metrics over the dataset 98 | if manager.dataloaders["val"] is not None: 99 | # loss status and val status initial 100 | manager.reset_loss_status() 101 | manager.reset_metric_status("val") 102 | for batch_idx, data_batch in enumerate(manager.dataloaders["val"]): 103 | # move to GPU if available 104 | data_batch = utils.tensor_gpu(data_batch) 105 | # compute model output 106 | output_batch = model(data_batch) 107 | # real batch size 108 | batch_size = data_batch["points_src"].size()[0] 109 | # compute all loss on this batch 110 | loss = compute_loss(output_batch, manager.params) 111 | manager.update_loss_status(loss, batch_size) 112 | # compute all metrics on this batch 113 | metrics = compute_metrics(output_batch, manager.params) 114 | manager.update_metric_status(metrics, "val", batch_size) 115 | 116 | # compute RMSE metrics 117 | manager.summarize_metric_status(metrics, "val") 118 | # update data to tensorboard 119 | manager.write_metric_to_tb(split="val") 120 | # For each epoch, update and print the metric 121 | manager.print_metrics("val", title="Val", color="green", only_best=True) 122 | 123 | if manager.dataloaders["test"] is not None: 124 | # loss status and test status initial 125 | manager.reset_loss_status() 126 | manager.reset_metric_status("test") 127 | for batch_idx, data_batch in enumerate(manager.dataloaders["test"]): 128 | # move to GPU if available 129 | data_batch = utils.tensor_gpu(data_batch) 130 | # compute model output 131 | output_batch = model(data_batch) 132 | # real batch size 133 | batch_size = data_batch["points_src"].size()[0] 134 | # compute all loss on this batch 135 | loss = compute_loss(output_batch, manager.params) 136 | manager.update_loss_status(loss, batch_size) 137 | # compute all metrics on this batch 138 | metrics = compute_metrics(output_batch, manager.params) 139 | manager.update_metric_status(metrics, "test", batch_size) 140 | 141 | # compute RMSE metrics 142 | manager.summarize_metric_status(metrics, "test") 143 | # update data to tensorboard 144 | manager.write_metric_to_tb(split="test") 145 | # For each epoch, update and print the metric 146 | manager.print_metrics("test", title="Test", color="red", only_best=True) 147 | 148 | 149 | def train_and_evaluate(model, manager: Manager): 150 | """Train the model and evaluate every epoch. 
151 | 152 | Args: 153 | model: (torch.nn.Module) the neural network 154 | train_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data 155 | """ 156 | 157 | # reload weights from restore_file if specified 158 | if args.restore_file is not None: 159 | manager.load_checkpoints() 160 | 161 | for epoch in range(manager.epoch, manager.params.num_epochs): 162 | # compute number of batches in one epoch (one full pass over the training set) 163 | train(model, manager) 164 | 165 | # Evaluate for one epoch on validation set 166 | evaluate(model, manager) 167 | 168 | # Check if current is best, save checkpoints if best, meanwhile, save latest checkpoints 169 | manager.check_best_save_last_checkpoints(save_latest_freq=100, save_best_after=1000) 170 | 171 | 172 | if __name__ == "__main__": 173 | # Load the parameters from json file 174 | args = parser.parse_args() 175 | json_path = os.path.join(args.model_dir, "params.json") 176 | assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path) 177 | params = utils.Params(json_path) 178 | 179 | # Update args into params 180 | params.update(vars(args)) 181 | 182 | # Set the logger 183 | logger = utils.set_logger(os.path.join(params.model_dir, "train.log")) 184 | 185 | # Set the tensorboard writer 186 | log_dir = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 187 | 188 | # use GPU if available 189 | params.cuda = torch.cuda.is_available() 190 | if params.cuda: 191 | num_gpu = torch.cuda.device_count() 192 | if num_gpu > 0: 193 | torch.cuda.set_device(0) 194 | gpu_ids = ", ".join(str(i) for i in [j for j in range(num_gpu)]) 195 | logger.info("Using GPU ids: [{}]".format(gpu_ids)) 196 | torch.backends.cudnn.deterministic = True 197 | torch.backends.cudnn.benchmark = False 198 | torch.backends.cudnn.enabled = False 199 | 200 | # fetch dataloaders 201 | dataloaders = data_loader.fetch_dataloader(params) 202 | 203 | # Define the model and optimizer 204 | if params.cuda: 205 | model = net.fetch_net(params).cuda() 206 | model = torch.nn.DataParallel(model, device_ids=range(num_gpu)) 207 | else: 208 | model = net.fetch_net(params) 209 | 210 | optimizer = optim.Adam(model.parameters(), lr=params.learning_rate) 211 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=params.gamma) 212 | 213 | # initial status for checkpoint manager 214 | manager = Manager(model=model, optimizer=optimizer, scheduler=scheduler, params=params, dataloaders=dataloaders, logger=logger) 215 | 216 | # Train the model 217 | logger.info("Starting training for {} epoch(s)".format(params.num_epochs)) 218 | 219 | train_and_evaluate(model, manager) 220 | --------------------------------------------------------------------------------