├── .gitignore ├── LICENSE ├── README.md ├── common ├── __init__.py ├── manager.py ├── quaternion.py ├── se3.py ├── so3.py └── utils.py ├── dataset ├── __init__.py ├── data_loader.py └── transformations.py ├── evaluate.py ├── experiments ├── experiment_finet │ └── params.json └── params.json ├── images └── FINet_poster.png ├── loss ├── __init__.py └── losses.py ├── model ├── __init__.py ├── module.py └── net.py ├── requirements.txt └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | experiments/experiment_finet/summary 2 | experiments/experiment_finet/test_metrics_best.json 3 | experiments/experiment_finet/test_metrics_latest.json 4 | experiments/experiment_finet/test_model_best.pth 5 | experiments/experiment_finet/train.log 6 | experiments/experiment_finet/val_metrics_best.json 7 | experiments/experiment_finet/val_metrics_latest.json 8 | experiments/experiment_finet/val_model_best.pth 9 | experiments/experiment_finet/evaluate.log 10 | dataset/data/modelnet_os 11 | dataset/data/modelnet_ts 12 | dataset/data/OS_data.zip 13 | dataset/data/TS_data.zip 14 | experiments/experiment_finet/model_latest.pth 15 | __pycache__ 16 | */__pycache__ 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Megvii Technology 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 
13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [AAAI 2022] FINet: Dual Branches Feature Interaction for Partial-to-Partial Point Cloud Registration 2 | 3 |

Hao Xu$^{1,2}$, Nianjin Ye$^2$, Guanghui Liu$^1$, Bing Zeng$^1$, Shuaicheng Liu$^1$

4 |

$^1$ University of Electronic Science and Technology of China

5 |

$^2$ Megvii Research

6 | 7 | 8 | This is the official MegEngine implementation of our AAAI 2022 paper [FINet](https://www.aaai.org/AAAI22Papers/AAAI-549.XuH.pdf). 9 | 10 | ## Presentation video 11 | [[YouTube](https://www.youtube.com/watch?v=XDmE9iSx9WM)] [[Bilibili](https://www.bilibili.com/video/BV1z44y1s7up/)]. 12 | 13 | ## Abstract 14 | Data association is important in point cloud registration. In this work, we propose to solve partial-to-partial registration from a new perspective, by introducing multi-level feature interactions between the source and the reference clouds at the feature extraction stage, such that registration can be realized without the attention mechanisms or explicit mask estimation for overlap detection adopted by previous methods. Specifically, we present FINet, a feature interaction-based structure that enables and strengthens information association between the inputs at multiple stages. To achieve this, we first split the features into two components, one for rotation and one for translation, based on the fact that they belong to different solution spaces, yielding a dual-branch structure. Second, we insert several interaction modules into the feature extractor for data association. Third, we propose a transformation sensitivity loss to obtain rotation-attentive and translation-attentive features. Experiments demonstrate that our method achieves higher precision and robustness than state-of-the-art traditional and learning-based methods. 15 | 16 | 17 | ## Our Poster 18 | 19 | ![FINet poster](./images/FINet_poster.png) 20 | 21 | ## Dependencies 22 | 23 | * MegEngine==1.7.0 24 | * For other requirements, please refer to `requirements.txt`. 25 | 26 | ## Data Preparation 27 | 28 | Following [OMNet](https://github.com/megvii-research/OMNet), we use the OS and TS data of the ModelNet40 dataset. 
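Each registration sample relates a source cloud to a reference cloud by a rigid transform. In `common/quaternion.py`, `mge_quat2mat` reads such a transform as a 7-vector `[qw, qx, qy, qz, tx, ty, tz]`, with the quaternion in `pose[:, :4]` and the translation in `pose[:, 4:]`, and builds a `(B, 3, 4)` matrix. Separately, `mge_quat_transform` applies `Ps' = R*Ps + t`. The NumPy sketch below mirrors that conversion for quick offline checks of poses. Note that `pose7d_to_mat` and `apply_transform` are hypothetical helper names, not part of the repository, and the sketch assumes unit quaternions in (w, x, y, z) order.

```python
import numpy as np


def pose7d_to_mat(pose):
    """Convert (B, 7) poses [qw, qx, qy, qz, tx, ty, tz] into (B, 3, 4) matrices.

    Uses the same rotation-matrix entries as mge_quat2mat; assumes unit quaternions.
    """
    q0, q1, q2, q3 = pose[:, 0], pose[:, 1], pose[:, 2], pose[:, 3]
    rot = np.stack([
        np.stack([q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3, 2 * (q1 * q2 - q0 * q3), 2 * (q1 * q3 + q0 * q2)], axis=-1),
        np.stack([2 * (q1 * q2 + q0 * q3), q0 * q0 + q2 * q2 - q1 * q1 - q3 * q3, 2 * (q2 * q3 - q0 * q1)], axis=-1),
        np.stack([2 * (q1 * q3 - q0 * q2), 2 * (q2 * q3 + q0 * q1), q0 * q0 + q3 * q3 - q1 * q1 - q2 * q2], axis=-1),
    ], axis=1)  # (B, 3, 3), rows stacked along axis 1
    trans = pose[:, 4:, None]  # (B, 3, 1)
    return np.concatenate([rot, trans], axis=2)  # (B, 3, 4)


def apply_transform(transform, points):
    """Apply (B, 3, 4) transforms to (B, N, 3) points: p' = R p + t."""
    rot, trans = transform[:, :, :3], transform[:, :, 3]
    return points @ np.swapaxes(rot, -1, -2) + trans[:, None, :]
```

For example, the pose `[cos(pi/4), 0, 0, sin(pi/4), 0, 0, 1]` is a 90° turn about z combined with a unit shift along z. It maps the point (1, 0, 0) to approximately (0, 1, 1).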
29 | 30 | ### OS data 31 | 32 | We refer to the original data from PointNet as OS data, where point clouds are sampled only once from the corresponding CAD models. We offer two ways to obtain OS data: (1) download it from the original link [original_OS_data.zip](http://modelnet.cs.princeton.edu/), or (2) download the version we preprocessed from [our_OS_data.zip](https://drive.google.com/file/d/1rXnbXwD72tkeu8x6wboMP0X7iL9LiBPq/view?usp=sharing). 33 | 34 | ### TS data 35 | 36 | Since OS data incurs an over-fitting issue, we propose TS data, where point clouds are randomly sampled twice from the CAD models. First download our preprocessed ModelNet40 dataset, in which 8 axisymmetric categories are removed and every CAD model has 40 randomly sampled point clouds. The download link is [TS_data.zip](https://drive.google.com/file/d/1DPBBI3Ulvp2Mx7SAZaBEyvADJzBvErFF/view?usp=sharing). All 40 point clouds of a CAD model are stacked into a (40, 2048, 3) numpy array, which you can load with the following code: 37 | 38 | ``` 39 | import numpy as np 40 | points = np.load("path_of_npy_file") 41 | print(points.shape, type(points)) # shape (40, 2048, 3); type numpy.ndarray 42 | ``` 43 | 44 | Then, put the data into `./dataset/data`; the directory contents should look as follows: 45 | 46 | ``` 47 | ./dataset/data/ 48 | ├── modelnet40_half1_rm_rotate.txt 49 | ├── modelnet40_half2_rm_rotate.txt 50 | ├── modelnet_os 51 | │   ├── modelnet_os_test.pickle 52 | │   ├── modelnet_os_train.pickle 53 | │   ├── modelnet_os_val.pickle 54 | │   ├── test [1146 entries exceeds filelimit, not opening dir] 55 | │   ├── train [4194 entries exceeds filelimit, not opening dir] 56 | │   └── val [1002 entries exceeds filelimit, not opening dir] 57 | └── modelnet_ts 58 | ├── modelnet_ts_test.pickle 59 | ├── modelnet_ts_train.pickle 60 | ├── modelnet_ts_val.pickle 61 | ├── shape_names.txt 62 | ├── test [1146 entries exceeds filelimit, not opening dir] 63 | ├── 
train [4196 entries exceeds filelimit, not opening dir] 64 | └── val [1002 entries exceeds filelimit, not opening dir] 65 | ``` 66 | 67 | ## Training and Evaluation 68 | 69 | ### Begin training 70 | 71 | For the ModelNet40 dataset, simply run: 72 | 73 | ``` 74 | python3 train.py --model_dir=./experiments/experiment_finet/ 75 | ``` 76 | 77 | For other datasets, add your own dataset class to `./dataset/data_loader.py`. Training with a smaller batch size, such as 16, may give worse performance than training with a larger one, e.g., 64. 78 | 79 | ### Begin testing 80 | 81 | Download the pretrained checkpoint and run: 82 | 83 | ``` 84 | python3 evaluate.py --model_dir=./experiments/experiment_finet --restore_file=./experiments/experiment_finet/test_model_best.pth 85 | ``` 86 | 87 | This checkpoint corresponds to TS data with Gaussian noise. Note that its performance is slightly worse than the results reported in our paper (PyTorch implementation). 88 | 89 | The MegEngine checkpoint for the ModelNet40 dataset can be downloaded from [Google Drive](https://drive.google.com/file/d/1nM9bzSYGYA8fsQ0-HSPLo4rOdkG5rxAS/view?usp=sharing). 
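`evaluate.py` is not part of this excerpt. However, `Manager` in `common/manager.py` shows how evaluation metrics are accumulated:

- `update_metric_status` feeds each batch value into a `utils.AverageMeter`, weighted by the batch size.
- `summarize_metric_status` adds a companion `RMSE` entry for every metric whose name ends in `MSE`. That entry equals the square root of the averaged MSE.

The pure-Python sketch below reproduces that bookkeeping under a few assumptions. `AverageMeter` here is a minimal stand-in, since the real one lives in `common/utils.py`, which is not shown. `summarize` is a hypothetical helper, and metric names such as `r_MSE` are illustrative only.

```python
import math
from collections import defaultdict


class AverageMeter:
    """Minimal running weighted mean (stand-in for utils.AverageMeter, which is not shown)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, num=1):
        # weight each batch value by its batch size
        self.sum += val * num
        self.count += num
        self.avg = self.sum / self.count

    def set(self, val):
        # overwrite the reported value directly (used for derived metrics)
        self.avg = val


def summarize(batches):
    """Average per-batch metrics, then add an RMSE entry for every '*MSE' metric."""
    status = defaultdict(AverageMeter)
    for metrics, batch_size in batches:
        for name, value in metrics.items():
            status[name].update(val=value, num=batch_size)
    for name in list(status):  # snapshot, so the RMSE entries added below are not revisited
        if name.endswith("MSE") and not name.endswith("RMSE"):
            status[name[:-3] + "RMSE"].set(val=math.sqrt(status[name].avg))
    return {name: meter.avg for name, meter in status.items()}
```

Taking the square root after averaging gives the RMSE over all samples, which in general differs from the mean of per-batch RMSE values. For example, two equally sized batches with MSE 4.0 and 1.0 give an RMSE of sqrt(2.5) ≈ 1.58. Averaging their per-batch RMSEs (2.0 and 1.0) would instead give 1.5.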
90 | 91 | ## Citation 92 | 93 | ``` 94 | @InProceedings{Xu_2022_AAAI, 95 | author={Xu, Hao and Ye, Nianjin and Liu, Guanghui and Zeng, Bing and Liu, Shuaicheng}, 96 | title={FINet: Dual Branches Feature Interaction for Partial-to-Partial Point Cloud Registration}, 97 | booktitle={Proceedings of the Thirty-Sixth AAAI Conference on Artificial Intelligence}, 98 | year={2022} 99 | } 100 | ``` 101 | 102 | ## Acknowledgments 103 | 104 | In this project we use (parts of) the official implementations of the following works: 105 | 106 | * [RPMNet](https://github.com/yewzijian/RPMNet) (ModelNet40 preprocessing and evaluation) 107 | * [PRNet](https://github.com/WangYueFt/prnet) (ModelNet40 preprocessing) 108 | * [OMNet](https://github.com/megvii-research/OMNet) (Code base) 109 | 110 | We thank the respective authors for open sourcing their methods. 111 | -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/FINet/5c07e7c0cafad67461c574633e7bb77257af6a96/common/__init__.py -------------------------------------------------------------------------------- /common/manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import megengine as mge 4 | import megengine.distributed as dist 5 | from collections import defaultdict 6 | from termcolor import colored 7 | 8 | from common import utils 9 | 10 | 11 | class Manager(): 12 | def __init__(self, model, optimizer, params, dataloaders, writer, logger, scheduler): 13 | # params status 14 | self.params = params 15 | self.optimizer = optimizer 16 | self.model = model 17 | self.dataloaders = dataloaders 18 | self.writer = writer 19 | self.logger = logger 20 | self.scheduler = scheduler 21 | 22 | # metric_rule should be either Descende or Ascende 23 | self.metric_rule = params.metric_rule 24 | 
25 | self.epoch = 0 26 | self.step = 0 27 | 28 | # lower is better 29 | if self.metric_rule == "Descende": 30 | self.best_val_score = float("inf") 31 | self.best_test_score = float("inf") 32 | # higher is better 33 | elif self.metric_rule == "Ascende": 34 | self.best_val_score = 0 35 | self.best_test_score = 0 36 | 37 | self.cur_val_score = 0 38 | self.cur_test_score = 0 39 | 40 | # train status 41 | self.train_status = defaultdict(utils.AverageMeter) 42 | 43 | # val status 44 | self.val_status = defaultdict(utils.AverageMeter) 45 | 46 | # test status 47 | self.test_status = defaultdict(utils.AverageMeter) 48 | 49 | # model status 50 | self.loss_status = defaultdict(utils.AverageMeter) 51 | 52 | def update_step(self): 53 | self.step += 1 54 | 55 | def update_epoch(self): 56 | self.epoch += 1 57 | 58 | def update_loss_status(self, loss, split, bs=None): 59 | if split == "train": 60 | for k, v in loss.items(): 61 | bs = self.params.train_batch_size 62 | self.loss_status[k].update(val=v.item(), num=bs) 63 | elif split == "val": 64 | for k, v in loss.items(): 65 | self.loss_status[k].update(val=v.item(), num=bs) 66 | elif split == "test": 67 | for k, v in loss.items(): 68 | self.loss_status[k].update(val=v.item(), num=bs) 69 | else: 70 | raise ValueError("Wrong eval type: {}".format(split)) 71 | 72 | def update_metric_status(self, metrics, split, bs): 73 | if split == "val": 74 | for k, v in metrics.items(): 75 | self.val_status[k].update(val=v.item(), num=bs) 76 | self.cur_val_score = self.val_status[self.params.major_metric].avg 77 | elif split == "test": 78 | for k, v in metrics.items(): 79 | self.test_status[k].update(val=v.item(), num=bs) 80 | self.cur_test_score = self.test_status[self.params.major_metric].avg 81 | else: 82 | raise ValueError("Wrong eval type: {}".format(split)) 83 | 84 | def summarize_metric_status(self, metrics, split): 85 | if split == "val": 86 | for k in metrics: 87 | if k.endswith('MSE'): 88 | self.val_status[k[:-3] + 'RMSE'].set(val=np.sqrt(self.val_status[k].avg)) 89 | else: 90 
| continue 91 | elif split == "test": 92 | for k in metrics: 93 | if k.endswith('MSE'): 94 | self.test_status[k[:-3] + 'RMSE'].set(val=np.sqrt(self.test_status[k].avg)) 95 | else: 96 | continue 97 | else: 98 | raise ValueError("Wrong eval type: {}".format(split)) 99 | 100 | def reset_loss_status(self): 101 | for k, v in self.loss_status.items(): 102 | self.loss_status[k].reset() 103 | 104 | def reset_metric_status(self, split): 105 | if split == "val": 106 | for k, v in self.val_status.items(): 107 | self.val_status[k].reset() 108 | elif split == "test": 109 | for k, v in self.test_status.items(): 110 | self.test_status[k].reset() 111 | else: 112 | raise ValueError("Wrong eval type: {}".format(split)) 113 | 114 | def print_train_info(self): 115 | exp_name = os.path.basename(os.path.normpath(self.params.model_dir)) # robust to a trailing "/" in model_dir 116 | print_str = "{} Epoch: {:4d}, lr={:.1E} ".format(exp_name, self.epoch, self.scheduler.get_lr()[0]) 117 | print_str += "total loss: {:.3f}({:.3f})".format(self.loss_status['total'].val, self.loss_status['total'].avg) 118 | return print_str 119 | 120 | def print_metrics(self, split, title="Eval", color="red", only_best=True): 121 | if split == "val": 122 | metric_status = self.val_status 123 | is_best = self.cur_val_score > self.best_val_score if self.metric_rule == "Ascende" else self.cur_val_score < self.best_val_score 124 | elif split == "test": 125 | metric_status = self.test_status 126 | is_best = self.cur_test_score > self.best_test_score if self.metric_rule == "Ascende" else self.cur_test_score < self.best_test_score 127 | else: 128 | raise ValueError("Wrong split string: {}".format(split)) 129 | print_str = " | ".join("{}: {:.3f}".format(k, v.avg) for k, v in metric_status.items()) 130 | if only_best: 131 | if is_best: 132 | utils.master_logger(self.logger, 133 | colored("Best Epoch: {}, {} Results: {}".format(self.epoch, title, print_str), color, attrs=["bold"]), 134 | dist.get_rank() == 0) 135 | else: 136 | utils.master_logger(self.logger, colored("Epoch: {}, {} Results: {}".format(self.epoch, title, print_str), 137 | color, 138 | attrs=["bold"]), 139 | dist.get_rank() == 0) 140 | 141 | def 
check_best_save_last_checkpoints(self, save_latest_freq=5, save_best_after=50): 142 | 143 | state = { 144 | "state_dict": self.model.state_dict(), 145 | "optimizer": self.optimizer.state_dict(), 146 | "scheduler": self.scheduler.state_dict(), 147 | "step": self.step, 148 | "epoch": self.epoch, 149 | } 150 | if "val" in self.dataloaders: 151 | state["best_val_score"] = self.best_val_score 152 | if "test" in self.dataloaders: 153 | state["best_test_score"] = self.best_test_score 154 | 155 | # save latest checkpoint 156 | if self.epoch % save_latest_freq == 0: 157 | latest_ckpt_name = os.path.join(self.params.model_dir, "model_latest.pth") 158 | mge.save(state, latest_ckpt_name) 159 | self.logger.info("Saved latest checkpoint to: {}".format(latest_ckpt_name)) 160 | 161 | # save val latest metrics, and check if val is best checkpoints 162 | if "val" in self.dataloaders: 163 | val_latest_metrics_name = os.path.join(self.params.model_dir, "val_metrics_latest.json") 164 | utils.save_dict_to_json(self.val_status, val_latest_metrics_name) 165 | 166 | # lower is better 167 | if self.metric_rule == "Descende": 168 | is_best = self.cur_val_score < self.best_val_score 169 | # higher is better 170 | elif self.metric_rule == "Ascende": 171 | is_best = self.cur_val_score > self.best_val_score 172 | else: 173 | raise Exception("metric_rule should be either Descende or Ascende") 174 | 175 | if is_best: 176 | # save metrics 177 | self.best_val_score = self.cur_val_score 178 | best_metrics_name = os.path.join(self.params.model_dir, "val_metrics_best.json") 179 | utils.save_dict_to_json(self.val_status, best_metrics_name) 180 | self.logger.info("Current is val best, score={:.3g}".format(self.best_val_score)) 181 | # save checkpoint 182 | if self.epoch > save_best_after: 183 | best_ckpt_name = os.path.join(self.params.model_dir, "val_model_best.pth") 184 | mge.save(state, best_ckpt_name) 185 | self.logger.info("Saved val best checkpoint to: {}".format(best_ckpt_name)) 186 | 187 | # save test latest metrics, 
and check if test is best checkpoints 188 | if "test" in self.dataloaders: 189 | test_latest_metrics_name = os.path.join(self.params.model_dir, "test_metrics_latest.json") 190 | utils.save_dict_to_json(self.test_status, test_latest_metrics_name) 191 | # lower is better 192 | if self.metric_rule == "Descende": 193 | is_best = self.cur_test_score < self.best_test_score 194 | # higher is better 195 | elif self.metric_rule == "Ascende": 196 | is_best = self.cur_test_score > self.best_test_score 197 | else: 198 | raise Exception("metric_rule should be either Descende or Ascende") 199 | if is_best: 200 | # save metrics 201 | self.best_test_score = self.cur_test_score 202 | best_metrics_name = os.path.join(self.params.model_dir, "test_metrics_best.json") 203 | utils.save_dict_to_json(self.test_status, best_metrics_name) 204 | self.logger.info("Current is test best, score={:.3g}".format(self.best_test_score)) 205 | # save checkpoint 206 | if self.epoch > save_best_after: 207 | best_ckpt_name = os.path.join(self.params.model_dir, "test_model_best.pth") 208 | mge.save(state, best_ckpt_name) 209 | self.logger.info("Saved test best checkpoint to: {}".format(best_ckpt_name)) 210 | 211 | def load_checkpoints(self): 212 | state = mge.load(self.params.restore_file) 213 | ckpt_component = [] 214 | if "state_dict" in state and self.model is not None: 215 | try: 216 | self.model.load_state_dict(state["state_dict"]) 217 | 218 | except Exception: # strict loading failed (e.g. mismatched keys); fall back to loading only the matching keys 219 | net_dict = self.model.state_dict() 220 | if "module" not in list(state["state_dict"].keys())[0]: 221 | state_dict = {"module." + k: v for k, v in state["state_dict"].items() if "module." 
+ k in net_dict.keys()} 222 | else: 223 | state_dict = {k: v for k, v in state["state_dict"].items() if k in net_dict.keys()} 224 | net_dict.update(state_dict) 225 | self.model.load_state_dict(net_dict, strict=False) 226 | ckpt_component.append("net") 227 | 228 | if not self.params.only_weights: 229 | if "optimizer" in state and self.optimizer is not None: 230 | try: 231 | self.optimizer.load_state_dict(state["optimizer"]) 232 | 233 | except Exception: # fall back to loading only the matching optimizer entries 234 | optimizer_dict = self.optimizer.state_dict() 235 | state_dict = {k: v for k, v in state["optimizer"].items() if k in optimizer_dict.keys()} 236 | optimizer_dict.update(state_dict) 237 | self.optimizer.load_state_dict(optimizer_dict) 238 | ckpt_component.append("opt") 239 | 240 | if "scheduler" in state and self.scheduler is not None: 241 | try: 242 | self.scheduler.load_state_dict(state["scheduler"]) 243 | 244 | except Exception: # fall back to loading only the matching scheduler entries 245 | scheduler_dict = self.scheduler.state_dict() 246 | state_dict = {k: v for k, v in state["scheduler"].items() if k in scheduler_dict.keys()} 247 | scheduler_dict.update(state_dict) 248 | self.scheduler.load_state_dict(scheduler_dict) 249 | ckpt_component.append("sch") 250 | 251 | if "step" in state: 252 | self.step = state["step"] + 1 253 | self.train_status["step"] = state["step"] + 1 254 | ckpt_component.append("step: {}".format(self.train_status["step"])) 255 | 256 | if "epoch" in state: 257 | self.epoch = state["epoch"] + 1 258 | self.train_status["epoch"] = state["epoch"] + 1 259 | ckpt_component.append("epoch: {}".format(self.train_status["epoch"])) 260 | 261 | if "best_val_score" in state: 262 | self.best_val_score = state["best_val_score"] 263 | ckpt_component.append("best val score: {:.3g}".format(self.best_val_score)) 264 | 265 | if "best_test_score" in state: 266 | self.best_test_score = state["best_test_score"] 267 | ckpt_component.append("best test score: 
{:.3g}".format(self.best_test_score)) 268 | 269 | ckpt_component = ", ".join(i for i in ckpt_component) 270 | utils.master_logger(self.logger, "Loaded models from: {}".format(self.params.restore_file), dist.get_rank() == 0) 271 | utils.master_logger(self.logger, "Ckpt load: {}".format(ckpt_component), dist.get_rank() == 0) 272 | -------------------------------------------------------------------------------- /common/quaternion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import megengine.functional as F 4 | 5 | 6 | def mge_qmul(q1, q2): 7 | """ 8 | Multiply quaternions as q2*q1 (Hamilton product), i.e. apply rotation q1 first and rotation q2 second. 9 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. 10 | Returns q2*q1 as a tensor of shape (*, 4). 11 | """ 12 | assert q1.shape[-1] == 4 13 | assert q2.shape[-1] == 4 14 | 15 | original_shape = q1.shape 16 | 17 | # Compute outer product 18 | terms = F.matmul(q1.reshape(-1, 4, 1), q2.reshape(-1, 1, 4)) 19 | 20 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] 21 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] 22 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] 23 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] 24 | return F.stack((w, x, y, z), axis=1).reshape(original_shape) 25 | 26 | 27 | def mge_qrot(q, v): 28 | """ 29 | Rotate vector(s) v about the rotation described by quaternion(s) q. 30 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, 31 | where * denotes any number of dimensions. 32 | Returns a tensor of shape (*, 3).
33 | """ 34 | assert q.shape[-1] == 4 35 | assert v.shape[-1] == 3 36 | assert q.shape[:-1] == v.shape[:-1] 37 | 38 | original_shape = list(v.shape) 39 | q = q.reshape(-1, 4) 40 | v = v.reshape(-1, 3) 41 | 42 | qvec = q[:, 1:] 43 | uv = F.stack(( 44 | qvec[:, 1] * v[:, 2] - qvec[:, 2] * v[:, 1], 45 | qvec[:, 2] * v[:, 0] - qvec[:, 0] * v[:, 2], 46 | qvec[:, 0] * v[:, 1] - qvec[:, 1] * v[:, 0], 47 | ), 48 | axis=1) 49 | uuv = F.stack(( 50 | qvec[:, 1] * uv[:, 2] - qvec[:, 2] * uv[:, 1], 51 | qvec[:, 2] * uv[:, 0] - qvec[:, 0] * uv[:, 2], 52 | qvec[:, 0] * uv[:, 1] - qvec[:, 1] * uv[:, 0], 53 | ), 54 | axis=1) 55 | # uv = F.cross(qvec, v, dim=1) 56 | # uuv = F.cross(qvec, uv, dim=1) 57 | return (v + 2 * (q[:, :1] * uv + uuv)).reshape(original_shape) 58 | 59 | 60 | # TODO: check 61 | def mge_quat2euler(q, order, epsilon=0): 62 | """ 63 | Convert quaternion(s) q to Euler angles. 64 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions. 65 | Returns a tensor of shape (*, 3). 
66 | """ 67 | assert q.shape[-1] == 4 68 | 69 | original_shape = list(q.shape) 70 | original_shape[-1] = 3 71 | q = q.reshape(-1, 4) 72 | 73 | q0 = q[:, 0] 74 | q1 = q[:, 1] 75 | q2 = q[:, 2] 76 | q3 = q[:, 3] 77 | 78 | if order == "xyz": 79 | x = F.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 80 | y = F.asin(F.clip(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) 81 | z = F.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 82 | elif order == "yzx": 83 | x = F.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 84 | y = F.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 85 | z = F.asin(F.clip(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) 86 | elif order == "zxy": 87 | x = F.asin(F.clip(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) 88 | y = F.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 89 | z = F.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) 90 | elif order == "xzy": 91 | x = F.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 92 | y = F.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 93 | z = F.asin(F.clip(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) 94 | elif order == "yxz": 95 | x = F.asin(F.clip(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) 96 | y = F.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) 97 | z = F.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 98 | elif order == "zyx": 99 | x = F.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 100 | y = F.asin(F.clip(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) 101 | z = F.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 102 | else: 103 | raise ValueError("Unsupported Euler order: {}".format(order)) 104 | 105 | return F.stack((x, y, z), axis=1).reshape(original_shape) 106 | 107 | 108 | # TODO: check 109 | def mge_euler2quat(e, order): 110 | """ 111 | Convert Euler angles to quaternions.
112 | """ 113 | assert e.shape[-1] == 3 114 | 115 | original_shape = [e.shape[0], 4] 116 | 117 | x = e[:, 0] 118 | y = e[:, 1] 119 | z = e[:, 2] 120 | # F.zeros_like already allocates on the input's device, so no explicit device transfer is needed 121 | rx = F.stack((F.cos(x / 2), F.sin(x / 2), F.zeros_like(x), F.zeros_like(x)), axis=1) 122 | ry = F.stack((F.cos(y / 2), F.zeros_like(y), F.sin(y / 2), F.zeros_like(y)), axis=1) 123 | rz = F.stack((F.cos(z / 2), F.zeros_like(z), F.zeros_like(z), F.sin(z / 2)), axis=1) 124 | 125 | result = None 126 | for coord in order: 127 | if coord == "x": 128 | r = rx 129 | elif coord == "y": 130 | r = ry 131 | elif coord == "z": 132 | r = rz 133 | else: 134 | raise ValueError("Unsupported Euler axis: {}".format(coord)) 135 | if result is None: 136 | result = r 137 | else: 138 | result = mge_qmul(result, r) 139 | 140 | # Reverse antipodal representation to have a non-negative "w" 141 | if order in ["xyz", "yzx", "zxy"]: 142 | result *= -1 143 | 144 | return result.reshape(original_shape) 145 | 146 | 147 | def mge_quat2mat(pose): 148 | # Separate each quaternion value. 149 | q0, q1, q2, q3 = pose[:, 0], pose[:, 1], pose[:, 2], pose[:, 3] 150 | # Convert quaternion to rotation matrix. 151 | # Ref: http://www-evasion.inrialpes.fr/people/Franck.Hetroy/Teaching/ProjetsImage/2007/Bib/besl_mckay-pami1992.pdf 152 | # "A Method for Registration of 3-D Shapes" by Paul J. Besl and Neil D. McKay.
153 | R11 = q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3 154 | R12 = 2 * (q1 * q2 - q0 * q3) 155 | R13 = 2 * (q1 * q3 + q0 * q2) 156 | R21 = 2 * (q1 * q2 + q0 * q3) 157 | R22 = q0 * q0 + q2 * q2 - q1 * q1 - q3 * q3 158 | R23 = 2 * (q2 * q3 - q0 * q1) 159 | R31 = 2 * (q1 * q3 - q0 * q2) 160 | R32 = 2 * (q2 * q3 + q0 * q1) 161 | R33 = q0 * q0 + q3 * q3 - q1 * q1 - q2 * q2 162 | R = F.stack((F.stack((R11, R12, R13), axis=0), F.stack((R21, R22, R23), axis=0), F.stack((R31, R32, R33), axis=0)), axis=0) 163 | 164 | rot_mat = F.transpose(R, (2, 0, 1)) # (B, 3, 3) 165 | translation = F.expand_dims(pose[:, 4:], axis=-1) # (B, 3, 1) 166 | transform = F.concat((rot_mat, translation), axis=2) 167 | return transform # (B, 3, 4) 168 | 169 | 170 | def mge_transform_pose(pose_old, pose_new): 171 | quat_old, translate_old = pose_old[:, :4], pose_old[:, 4:] 172 | quat_new, translate_new = pose_new[:, :4], pose_new[:, 4:] 173 | 174 | quat = mge_qmul(quat_old, quat_new) 175 | translate = mge_qrot(quat_new, translate_old) + translate_new 176 | pose = F.concat((quat, translate), axis=1) 177 | 178 | return pose 179 | 180 | 181 | # TODO: check 182 | def mge_qinv(q): 183 | # expects q in (w, x, y, z) format; the conjugate equals the inverse only for unit quaternions 184 | w = q[:, 0:1] 185 | v = q[:, 1:] 186 | inv = F.concat([w, -v], axis=1) 187 | return inv 188 | 189 | 190 | def mge_quat_rotate(point_cloud, pose_7d): 191 | ndim = point_cloud.ndim 192 | if ndim == 2: 193 | N, _ = point_cloud.shape 194 | assert pose_7d.shape[0] == 1 195 | # repeat the quaternion for every point in the cloud: (1, 4) -> (N, 4) 196 | quat = F.broadcast_to(pose_7d[:, 0:4], (N, 4)) 197 | rotated_point_cloud = mge_qrot(quat, point_cloud) 198 | 199 | elif ndim == 3: 200 | B, N, _ = point_cloud.shape 201 | quat = F.tile(F.expand_dims(pose_7d[:, 0:4], axis=1), (1, N, 1)) 202 | rotated_point_cloud = mge_qrot(quat, point_cloud) 203 | 204 | else: 205 | raise RuntimeError("point cloud dim must be 2 or 3 !") 206 | 207 | return rotated_point_cloud 208 | 209 | 210 | def mge_quat_transform(pose_7d, pc, normal=None):
211 | pc_t = mge_quat_rotate(pc, pose_7d) + pose_7d[:, 4:].reshape(-1, 1, 3) # Ps' = R*Ps + t 212 | if normal is not None: 213 | normal_t = mge_quat_rotate(normal, pose_7d) 214 | return pc_t, normal_t 215 | else: 216 | return pc_t 217 | 218 | 219 | def np_qmul(q, r): 220 | # numpy Hamilton product q*r (the torch_qmul helper used previously is not defined in this module) 221 | terms = np.matmul(r.reshape(-1, 4, 1), q.reshape(-1, 1, 4)) # terms[:, i, j] = r_i * q_j 222 | return np.stack((terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3], terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2], terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1], terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]), axis=1).reshape(q.shape) 223 | 224 | 225 | def np_qrot(q, v): 226 | # numpy port of mge_qrot: v' = v + 2 * (w * (u x v) + u x (u x v)), with u = q[..., 1:] 227 | uv = np.cross(q[..., 1:], v) 228 | return v + 2 * (q[..., :1] * uv + np.cross(q[..., 1:], uv)) 229 | 230 | 231 | def np_quat2euler(q, order, epsilon=0, use_gpu=False): 232 | import megengine as mge # local import: the torch_quat2euler helper used previously is not defined in this module 233 | device = "gpu0" if use_gpu else "cpu0" 234 | q = mge.tensor(np.asarray(q, dtype=np.float32), device=device) 235 | euler = mge_quat2euler(q, order, epsilon) 236 | # copy the result back to host memory as a numpy array 237 | return euler.numpy() 238 | 239 | 240 | def np_qfix(q): 241 | """ 242 | Enforce quaternion continuity across the time dimension by selecting 243 | the representation (q or -q) with minimal Euclidean distance (or, equivalently, maximal dot product) 244 | between two consecutive frames. 245 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. 246 | Returns a tensor of the same shape. 247 | """ 248 | assert len(q.shape) == 3 249 | assert q.shape[-1] == 4 250 | 251 | result = q.copy() 252 | dot_products = np.sum(q[1:] * q[:-1], axis=2) 253 | mask = dot_products < 0 254 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool) 255 | result[1:][mask] *= -1 256 | return result 257 | 258 | 259 | def np_expmap2quat(e): 260 | """ 261 | Convert axis-angle rotations (aka exponential maps) to quaternions. 262 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". 263 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions. 264 | Returns a tensor of shape (*, 4).
265 | """ 266 | assert e.shape[-1] == 3 267 | 268 | original_shape = list(e.shape) 269 | original_shape[-1] = 4 270 | e = e.reshape(-1, 3) 271 | 272 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1) 273 | w = np.cos(0.5 * theta).reshape(-1, 1) 274 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e 275 | return np.concatenate((w, xyz), axis=1).reshape(original_shape) 276 | 277 | 278 | def np_euler2quat(e, order): 279 | """ 280 | Convert Euler angles to quaternions. 281 | """ 282 | assert e.shape[-1] == 3 283 | 284 | original_shape = list(e.shape) 285 | original_shape[-1] = 4 286 | 287 | e = e.reshape(-1, 3) 288 | 289 | x = e[:, 0] 290 | y = e[:, 1] 291 | z = e[:, 2] 292 | 293 | rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1) 294 | ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1) 295 | rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1) 296 | 297 | result = None 298 | for coord in order: 299 | if coord == "x": 300 | r = rx 301 | elif coord == "y": 302 | r = ry 303 | elif coord == "z": 304 | r = rz 305 | else: 306 | raise ValueError("Unsupported Euler axis: {}".format(coord)) 307 | if result is None: 308 | result = r 309 | else: 310 | result = np_qmul(result, r) 311 | 312 | # Reverse antipodal representation to have a non-negative "w" 313 | if order in ["xyz", "yzx", "zxy"]: 314 | result *= -1 315 | 316 | return result.reshape(original_shape) 317 | -------------------------------------------------------------------------------- /common/se3.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import megengine.functional as F 3 | import transforms3d.quaternions as t3d 4 | from scipy.spatial.transform import Rotation 5 | 6 | 7 | def mge_inverse(g): 8 | """ Returns the inverse of the SE3 transform 9 | 10 | Args: 11 | g: (B, 3/4, 4) transform 12 | 13 | Returns: 14 | (B, 3, 4) matrix containing the inverse 15 | 16 | """ 17 | # Compute inverse 18 | rot =
g[..., 0:3, 0:3] 19 | trans = g[..., 0:3, 3] 20 | inverse_transform = F.concat([rot.transpose(0, 2, 1), F.matmul(rot.transpose(0, 2, 1), F.expand_dims(-trans, axis=-1))], axis=-1) 21 | 22 | return inverse_transform 23 | 24 | 25 | def mge_concatenate(a, b): 26 | """Concatenate two SE3 transforms, 27 | i.e. return a@b (but note that our SE3 is represented as a 3x4 matrix) 28 | 29 | Args: 30 | a: (B, 3/4, 4) 31 | b: (B, 3/4, 4) 32 | 33 | Returns: 34 | (B, 3/4, 4) 35 | """ 36 | 37 | rot1 = a[..., :3, :3] 38 | trans1 = a[..., :3, 3] 39 | rot2 = b[..., :3, :3] 40 | trans2 = b[..., :3, 3] 41 | 42 | rot_cat = F.matmul(rot1, rot2) 43 | trans_cat = F.matmul(rot1, F.expand_dims(trans2, axis=-1)) + F.expand_dims(trans1, axis=-1) 44 | concatenated = F.concat([rot_cat, trans_cat], axis=-1) 45 | 46 | return concatenated 47 | 48 | 49 | def mge_transform(g, a, normals=None): 50 | """ Applies the SE3 transform 51 | 52 | Args: 53 | g: SE3 transformation matrix of size ([1,] 3/4, 4) or (B, 3/4, 4) 54 | a: Points to be transformed (N, 3) or (B, N, 3) 55 | normals: (Optional). 
If provided, normals will be transformed 56 | 57 | Returns: 58 | transformed points of size (N, 3) or (B, N, 3) 59 | 60 | """ 61 | R = g[..., :3, :3] # (B, 3, 3) 62 | p = g[..., :3, 3] # (B, 3) 63 | 64 | if len(g.shape) == len(a.shape): 65 | b = F.matmul(a, R.transpose(0, 2, 1)) + F.expand_dims(p, axis=1) 66 | else: 67 | raise NotImplementedError 68 | 69 | if normals is not None: 70 | rotated_normals = F.matmul(normals, R.transpose(0, 2, 1)) 71 | return b, rotated_normals 72 | 73 | else: 74 | return b 75 | 76 | 77 | def np_identity(): 78 | return np.eye(3, 4) 79 | 80 | 81 | def np_transform(g: np.ndarray, pts: np.ndarray): 82 | """ Applies the SE3 transform 83 | 84 | Args: 85 | g: SE3 transformation matrix of size ([B,] 3/4, 4) 86 | pts: Points to be transformed ([B,] N, 3) 87 | 88 | Returns: 89 | transformed points of size ([B,] N, 3) 90 | """ 91 | rot = g[..., :3, :3] # (3, 3) 92 | trans = g[..., :3, 3] # (3) 93 | 94 | transformed = pts[..., :3] @ np.swapaxes(rot, -1, -2) + trans[..., None, :] 95 | return transformed 96 | 97 | 98 | def np_inverse(g: np.ndarray): 99 | """Returns the inverse of the SE3 transform 100 | 101 | Args: 102 | g: ([B,] 3/4, 4) transform 103 | 104 | Returns: 105 | ([B,] 3/4, 4) matrix containing the inverse 106 | 107 | """ 108 | rot = g[..., :3, :3] # (3, 3) 109 | trans = g[..., :3, 3] # (3) 110 | 111 | inv_rot = np.swapaxes(rot, -1, -2) 112 | inverse_transform = np.concatenate([inv_rot, inv_rot @ -trans[..., None]], axis=-1) 113 | if g.shape[-2] == 4: 114 | inverse_transform = np.concatenate([inverse_transform, [[0.0, 0.0, 0.0, 1.0]]], axis=-2) 115 | 116 | return inverse_transform 117 | 118 | 119 | def np_concatenate(a: np.ndarray, b: np.ndarray): 120 | """ Concatenate two SE3 transforms 121 | 122 | Args: 123 | a: First transform ([B,] 3/4, 4) 124 | b: Second transform ([B,] 3/4, 4) 125 | 126 | Returns: 127 | a @ b ([B,] 3/4, 4) 128 | 129 | """ 130 | 131 | r_a, t_a = a[..., :3, :3], a[..., :3, 3] 132 | r_b, t_b = b[..., :3, :3], b[..., :3, 3]
133 | 134 | r_ab = r_a @ r_b 135 | t_ab = r_a @ t_b[..., None] + t_a[..., None] 136 | 137 | concatenated = np.concatenate([r_ab, t_ab], axis=-1) 138 | 139 | if a.shape[-2] == 4: 140 | concatenated = np.concatenate([concatenated, [[0.0, 0.0, 0.0, 1.0]]], axis=-2) 141 | 142 | return concatenated 143 | 144 | 145 | def np_from_xyzquat(xyzquat): 146 | """Constructs SE3 matrix from x, y, z, qx, qy, qz, qw 147 | 148 | Args: 149 | xyzquat: np.array (7,) containing translation and quaternion (scalar-last, as expected by Rotation.from_quat) 150 | 151 | Returns: 152 | SE3 matrix (4, 4) 153 | """ 154 | rot = Rotation.from_quat(xyzquat[3:]) 155 | trans = rot.apply(-xyzquat[:3]) 156 | transform = np.concatenate([rot.as_matrix(), trans[:, None]], axis=1) 157 | transform = np.concatenate([transform, [[0.0, 0.0, 0.0, 1.0]]], axis=0) 158 | 159 | return transform 160 | 161 | 162 | def np_mat2quat(transform): 163 | rotate = transform[:3, :3] 164 | translate = transform[:3, 3] 165 | quat = t3d.mat2quat(rotate) 166 | pose = np.concatenate([quat, translate], axis=0) 167 | return pose # (7,): [qw, qx, qy, qz, x, y, z] 168 | 169 | 170 | def np_quat2mat(pose): 171 | # pose: (B, 7) as [q0 (w), q1, q2, q3, x, y, z]; separate each quaternion value. 172 | q0, q1, q2, q3 = pose[:, 0], pose[:, 1], pose[:, 2], pose[:, 3] 173 | # Convert quaternion to rotation matrix. 174 | # Ref: http://www-evasion.inrialpes.fr/people/Franck.Hetroy/Teaching/ProjetsImage/2007/Bib/besl_mckay-pami1992.pdf 175 | # "A Method for Registration of 3-D Shapes" by Paul J. Besl and Neil D. McKay (IEEE TPAMI, 1992).
176 | R11 = q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3 177 | R12 = 2 * (q1 * q2 - q0 * q3) 178 | R13 = 2 * (q1 * q3 + q0 * q2) 179 | R21 = 2 * (q1 * q2 + q0 * q3) 180 | R22 = q0 * q0 + q2 * q2 - q1 * q1 - q3 * q3 181 | R23 = 2 * (q2 * q3 - q0 * q1) 182 | R31 = 2 * (q1 * q3 - q0 * q2) 183 | R32 = 2 * (q2 * q3 + q0 * q1) 184 | R33 = q0 * q0 + q3 * q3 - q1 * q1 - q2 * q2 185 | R = np.stack((np.stack((R11, R12, R13), axis=0), np.stack((R21, R22, R23), axis=0), np.stack((R31, R32, R33), axis=0)), axis=0) 186 | 187 | rot_mat = R.transpose((2, 0, 1)) # (B, 3, 3) 188 | translation = pose[:, 4:][:, :, None] # (B, 3, 1) 189 | transform = np.concatenate((rot_mat, translation), axis=2) 190 | return transform # (B, 3, 4) 191 | -------------------------------------------------------------------------------- /common/so3.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import megengine as mge 3 | from scipy.spatial.transform import Rotation 4 | 5 | 6 | def np_dcm2euler(mats: np.ndarray, seq: str = "zyx", degrees: bool = True): 7 | """Converts rotation matrices to Euler angles 8 | 9 | Args: 10 | mats: (B, 3, 3) containing the B rotation matrices 11 | seq: Sequence of euler rotations (default: "zyx") 12 | degrees (bool): If true (default), will return in degrees instead of radians 13 | 14 | Returns: 15 | (B, 3) array of Euler angles in the order given by seq 16 | """ 17 | 18 | eulers = [] 19 | for i in range(mats.shape[0]): 20 | r = Rotation.from_matrix(mats[i]) 21 | eulers.append(r.as_euler(seq, degrees=degrees)) 22 | return np.stack(eulers) 23 | 24 | 25 | def np_transform(g: np.ndarray, pts: np.ndarray): 26 | """ Applies the SO3 transform 27 | 28 | Args: 29 | g: SO3 transformation matrix of size (B, 3, 3) 30 | pts: Points to be transformed (B, N, 3) 31 | 32 | Returns: 33 | transformed points of size (B, N, 3) 34 | 35 | """ 36 | rot = g[..., :3, :3] # ([B,] 3, 3) 37 | transformed = pts[..., :3] @ np.swapaxes(rot, -1, -2) 38 | return transformed 39 | 40 | 41 | def np_inverse(g:
np.ndarray): 42 | """Returns the inverse of the SO3 transform (rotation) 43 | 44 | Args: 45 | g: ([B,] 3, 3) rotation matrix; only the top-left 3x3 block of a larger matrix is used 46 | 47 | Returns: 48 | ([B,] 3, 3) matrix containing the inverse (the transpose) 49 | 50 | """ 51 | rot = g[..., :3, :3] # ([B,] 3, 3) 52 | 53 | inv_rot = np.swapaxes(rot, -1, -2) 54 | 55 | return inv_rot 56 | 57 | 58 | def mge_dcm2euler(mats, seq, degrees=True): 59 | mats = mats.numpy() 60 | eulers = [] 61 | for i in range(mats.shape[0]): 62 | r = Rotation.from_matrix(mats[i]) 63 | eulers.append(r.as_euler(seq, degrees=degrees)) 64 | return mge.tensor(np.stack(eulers)) 65 | -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import megengine as mge 4 | import coloredlogs 5 | 6 | 7 | class Params(): 8 | """Class that loads hyperparameters from a json file. 9 | 10 | Example: 11 | ``` 12 | params = Params(json_path) 13 | print(params.learning_rate) 14 | params.learning_rate = 0.5 # change the value of learning_rate in params 15 | ``` 16 | """ 17 | def __init__(self, json_path): 18 | with open(json_path) as f: 19 | params = json.load(f) 20 | self.update(params) 21 | 22 | def save(self, json_path): 23 | with open(json_path, 'w') as f: 24 | json.dump(self.__dict__, f, indent=4) 25 | 26 | def update(self, dict): 27 | """Updates parameters from a dict (e.g. one loaded from a json file)""" 28 | self.__dict__.update(dict) 29 | 30 | @property 31 | def dict(self): 32 | """Gives dict-like access to Params instance by `params.dict['learning_rate']`""" 33 | return self.__dict__ 34 | 35 | 36 | class RunningAverage(): 37 | """A simple class that maintains the running average of a quantity 38 | 39 | Example: 40 | ``` 41 | loss_avg = RunningAverage() 42 | loss_avg.update(2) 43 | loss_avg.update(4) 44 | loss_avg() # returns 3.0 45 | ``` 46 | """ 47 | def __init__(self): 48 | self.steps = 0 49 | self.total = 0 50 | 51 | def update(self, val): 52 | self.total += val 53 |
self.steps += 1 54 | 55 | def __call__(self): 56 | return self.total / float(self.steps) 57 | 58 | 59 | class AverageMeter(): 60 | def __init__(self): 61 | self.reset() 62 | 63 | def reset(self): 64 | self.val = 0 65 | self.val_previous = 0 66 | self.avg = 0 67 | self.sum = 0 68 | self.count = 0 69 | 70 | def set(self, val): 71 | self.val = val 72 | self.avg = val 73 | 74 | def update(self, val, num): 75 | self.val_previous = self.val 76 | self.val = val 77 | self.sum += val * num 78 | self.count += num 79 | self.avg = self.sum / self.count 80 | 81 | 82 | def loss_meter_manager_intial(loss_meter_names): 83 | # Initialize the loss meters needed, one per name in loss_meter_names 84 | loss_meters = [] 85 | for _ in loss_meter_names: 86 | # One fresh meter per name, returned in the same order as the names 87 | loss_meters.append(AverageMeter()) 88 | 89 | return loss_meters 90 | 91 | 92 | def tensor_mge(batch, check_on=True): 93 | if check_on: 94 | for k, v in batch.items(): 95 | batch[k] = mge.Tensor(v) 96 | else: 97 | for k, v in batch.items(): 98 | batch[k] = v.numpy() 99 | return batch 100 | 101 | 102 | def set_logger(log_path): 103 | """Set the logger to log info in terminal and file `log_path`. 104 | 105 | In general, it is useful to have a logger so that every output to the terminal is saved 106 | in a permanent file. Here we save it to `model_dir/train.log`.
107 | 108 | Example: 109 | ``` 110 | logging.info("Starting training...") 111 | ``` 112 | 113 | Args: 114 | log_path: (string) where to log 115 | """ 116 | logger = logging.getLogger() 117 | logger.setLevel(logging.INFO) 118 | 119 | # if not logger.handlers: 120 | # # Logging to a file 121 | # file_handler = logging.FileHandler(log_path) 122 | # file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 123 | # logger.addHandler(file_handler) 124 | # 125 | # # Logging to console 126 | # stream_handler = logging.StreamHandler() 127 | # stream_handler.setFormatter(logging.Formatter('%(message)s')) 128 | # logger.addHandler(stream_handler) 129 | 130 | coloredlogs.install(level='INFO', logger=logger, fmt='%(asctime)s %(name)s %(message)s') 131 | file_handler = logging.FileHandler(log_path) 132 | log_formatter = logging.Formatter('%(asctime)s - %(message)s') 133 | file_handler.setFormatter(log_formatter) 134 | logger.addHandler(file_handler) 135 | master_logger(logger, 'Output and logs will be saved to {}'.format(log_path)) 136 | return logger 137 | 138 | 139 | def save_dict_to_json(d, json_path): 140 | """Saves dict of floats in json file 141 | 142 | Args: 143 | d: (dict) of float-castable values (np.float, int, float, etc.) 
144 | json_path: (string) path to json file 145 | """ 146 | save_dict = {} 147 | with open(json_path, "w") as f: 148 | # Convert the values to float, since json cannot serialize np.array, np.float, etc. 149 | for k, v in d.items(): 150 | if isinstance(v, AverageMeter): 151 | save_dict[k] = float(v.avg) 152 | else: 153 | save_dict[k] = float(v) 154 | json.dump(save_dict, f, indent=4) 155 | 156 | 157 | def master_logger(logger, info, is_master=False): 158 | if is_master: 159 | logger.info(info) 160 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/FINet/5c07e7c0cafad67461c574633e7bb77257af6a96/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/data_loader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | 5 | import numpy as np 6 | import h5py 7 | 8 | from megengine.data import DataLoader 9 | from megengine.data.dataset import Dataset 10 | from megengine.data.sampler import RandomSampler, SequentialSampler 11 | import megengine.distributed as dist 12 | 13 | from dataset.transformations import fetch_transform 14 | from common import utils 15 | 16 | _logger = logging.getLogger(__name__) 17 | 18 | 19 | class ModelNetNpy(Dataset): 20 | def __init__(self, dataset_path: str, dataset_mode: str, subset: str = "train", categories=None, transform=None): 21 | self._logger = logging.getLogger(self.__class__.__name__) 22 | self._root = dataset_path 23 | self._subset = subset 24 | self._is_master = dist.get_rank() == 0 25 | 26 | metadata_fpath = os.path.join(self._root, "modelnet_{}_{}.pickle".format(dataset_mode, subset)) 27 | utils.master_logger(self._logger, "Loading data from {} for {}".format(metadata_fpath, subset), self._is_master) 28 | 29 | if
not os.path.exists(dataset_path): 30 | raise FileNotFoundError("Dataset path not found: {}".format(dataset_path)) 31 | 32 | with open(os.path.join(dataset_path, "shape_names.txt")) as fid: 33 | self._classes = [l.strip() for l in fid] 34 | self._category2idx = {e[1]: e[0] for e in enumerate(self._classes)} 35 | self._idx2category = self._classes 36 | 37 | if categories is not None: 38 | categories_idx = [self._category2idx[c] for c in categories] 39 | utils.master_logger(self._logger, "Categories used: {}.".format(categories_idx), self._is_master) 40 | self._classes = categories 41 | else: 42 | categories_idx = None 43 | utils.master_logger(self._logger, "Using all categories.", self._is_master) 44 | 45 | self._data = self._read_pickle_files(os.path.join(dataset_path, "modelnet_{}_{}.pickle".format(dataset_mode, subset)), 46 | categories_idx) 47 | 48 | self._transform = transform 49 | utils.master_logger(self._logger, "Loaded {} {} instances.".format(len(self._data), subset), self._is_master) 50 | 51 | @property 52 | def classes(self): 53 | return self._classes 54 | 55 | @staticmethod 56 | def _read_pickle_files(fnames, categories): 57 | 58 | all_data_dict = [] 59 | with open(fnames, "rb") as f: 60 | data = pickle.load(f) 61 | 62 | for category in categories: 63 | all_data_dict.extend(data[category]) 64 | 65 | return all_data_dict 66 | 67 | def to_category(self, i): 68 | return self._idx2category[i] 69 | 70 | def __getitem__(self, item): 71 | 72 | data_path = self._data[item] 73 | 74 | # load and process data 75 | points = np.load(data_path) 76 | idx = np.array(int(os.path.splitext(os.path.basename(data_path))[0].split("_")[1])) 77 | label = np.array(int(os.path.splitext(os.path.basename(data_path))[0].split("_")[3])) 78 | sample = {"points": points, "label": label, "idx": idx} 79 | 80 | if self._transform: 81 | sample = self._transform(sample) 82 | return sample 83 | 84 | def __len__(self): 85 | return len(self._data) 86 | 87 | 88 | def
fetch_dataloader(params): 89 | utils.master_logger(_logger, "Dataset type: {}, transform type: {}".format(params.dataset_type, params.transform_type), 90 | dist.get_rank() == 0) 91 | 92 | train_transforms, test_transforms = fetch_transform(params) 93 | 94 | if params.dataset_type == "modelnet_os": 95 | dataset_path = "./dataset/data/modelnet_os" 96 | train_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half1_rm_rotate.txt")] 97 | val_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half1_rm_rotate.txt")] 98 | test_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half2_rm_rotate.txt")] 99 | train_categories.sort() 100 | val_categories.sort() 101 | test_categories.sort() 102 | train_ds = ModelNetNpy(dataset_path, dataset_mode="os", subset="train", categories=train_categories, transform=train_transforms) 103 | val_ds = ModelNetNpy(dataset_path, dataset_mode="os", subset="val", categories=val_categories, transform=test_transforms) 104 | test_ds = ModelNetNpy(dataset_path, dataset_mode="os", subset="test", categories=test_categories, transform=test_transforms) 105 | 106 | elif params.dataset_type == "modelnet_ts": 107 | dataset_path = "./dataset/data/modelnet_ts" 108 | train_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half1_rm_rotate.txt")] 109 | val_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half1_rm_rotate.txt")] 110 | test_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half2_rm_rotate.txt")] 111 | train_categories.sort() 112 | val_categories.sort() 113 | test_categories.sort() 114 | train_ds = ModelNetNpy(dataset_path, dataset_mode="ts", subset="train", categories=train_categories, transform=train_transforms) 115 | val_ds = ModelNetNpy(dataset_path, dataset_mode="ts", subset="val", categories=val_categories, transform=test_transforms) 116 | test_ds = ModelNetNpy(dataset_path, 
dataset_mode="ts", subset="test", categories=test_categories, transform=test_transforms) 117 | 118 | dataloaders = {} 119 | # Add the default train data loader 120 | train_sampler = RandomSampler(train_ds, batch_size=params.train_batch_size, drop_last=True) 121 | train_dl = DataLoader(train_ds, train_sampler, num_workers=params.num_workers) 122 | dataloaders["train"] = train_dl 123 | 124 | # Choose the val and/or test data loaders for evaluation 125 | for split in ["val", "test"]: 126 | if split in params.eval_type: 127 | if split == "val": 128 | val_sampler = SequentialSampler(val_ds, batch_size=params.eval_batch_size) 129 | dl = DataLoader(val_ds, val_sampler, num_workers=params.num_workers) 130 | elif split == "test": 131 | test_sampler = SequentialSampler(test_ds, batch_size=params.eval_batch_size) 132 | dl = DataLoader(test_ds, test_sampler, num_workers=params.num_workers) 133 | else: 134 | raise ValueError("Unknown eval_type in params, should be in [val, test]") 135 | dataloaders[split] = dl 136 | else: 137 | dataloaders[split] = None 138 | 139 | return dataloaders 140 | -------------------------------------------------------------------------------- /dataset/transformations.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import megengine as mge 4 | import megengine.distributed as dist 5 | import numpy as np 6 | from common import se3, so3, utils 7 | from scipy.spatial.transform import Rotation 8 | from megengine.data.transform import Transform 9 | 10 | _logger = logging.getLogger(__name__) 11 | 12 | 13 | def uniform_2_sphere(num: int = None): 14 | """Uniform sampling on a 2-sphere 15 | 16 | Source: https://gist.github.com/andrewbolster/10274979 17 | 18 | Args: 19 | num: Number of vectors to sample (or None if single) 20 | 21 | Returns: 22 | Random vector (np.ndarray) of size (num, 3) with norm 1.
23 | If num is None returned value will have size (3,) 24 | 25 | """ 26 | if num is not None: 27 | phi = np.random.uniform(0.0, 2 * np.pi, num) 28 | cos_theta = np.random.uniform(-1.0, 1.0, num) 29 | else: 30 | phi = np.random.uniform(0.0, 2 * np.pi) 31 | cos_theta = np.random.uniform(-1.0, 1.0) 32 | 33 | theta = np.arccos(cos_theta) 34 | x = np.sin(theta) * np.cos(phi) 35 | y = np.sin(theta) * np.sin(phi) 36 | z = np.cos(theta) 37 | 38 | return np.stack((x, y, z), axis=-1) 39 | 40 | 41 | class SplitSourceRef(Transform): 42 | """Clones the point cloud into separate source and reference point clouds""" 43 | def __init__(self, mode="os"): 44 | self.mode = mode 45 | 46 | def apply(self, sample): 47 | if "deterministic" in sample and sample["deterministic"]: 48 | np.random.seed(sample["idx"]) 49 | 50 | if self.mode == "os": 51 | sample["points_raw"] = sample.pop("points").astype(np.float32)[:, :3] 52 | sample["points_src"] = sample["points_raw"].copy() 53 | sample["points_ref"] = sample["points_raw"].copy() 54 | sample["points_src_raw"] = sample["points_src"].copy().astype(np.float32) 55 | sample["points_ref_raw"] = sample["points_ref"].copy().astype(np.float32) 56 | elif self.mode == "ts": 57 | points_raw = sample.pop("points").astype(np.float32) 58 | points_raw = points_raw[np.random.choice(points_raw.shape[0], 2, replace=False), :, :] 59 | sample["points_src"] = points_raw[0, :, :].astype(np.float32) 60 | sample["points_ref"] = points_raw[1, :, :].astype(np.float32) 61 | sample["points_src_raw"] = sample["points_src"].copy() 62 | sample["points_ref_raw"] = sample["points_ref"].copy() 63 | 64 | else: 65 | raise NotImplementedError 66 | 67 | return sample 68 | 69 | 70 | class Resampler(Transform): 71 | def __init__(self, num: int): 72 | """Resamples a point cloud containing N points to one containing M 73 | 74 | Guaranteed to have no repeated points if M <= N. 75 | Otherwise, it is guaranteed that all points appear at least once. 
76 | 77 | Args: 78 | num (int): Number of points to resample to, i.e. M 79 | 80 | """ 81 | self.num = num 82 | 83 | @staticmethod 84 | def _resample(points, k): 85 | """Resamples the points such that there are exactly k points. 86 | 87 | If the input point cloud has <= k points, it is guaranteed the 88 | resampled point cloud contains every point in the input. 89 | If the input point cloud has > k points, it is guaranteed the 90 | resampled point cloud does not contain repeated points. 91 | """ 92 | # print("===", points.shape[0], k) 93 | if k < points.shape[0]: 94 | rand_idxs = np.random.choice(points.shape[0], k, replace=False) 95 | return points[rand_idxs, :] 96 | elif points.shape[0] == k: 97 | return points 98 | else: 99 | rand_idxs = np.concatenate([ 100 | np.random.choice(points.shape[0], points.shape[0], replace=False), 101 | np.random.choice(points.shape[0], k - points.shape[0], replace=True) 102 | ]) 103 | return points[rand_idxs, :] 104 | 105 | def apply(self, sample): 106 | 107 | if "deterministic" in sample and sample["deterministic"]: 108 | np.random.seed(sample["idx"]) 109 | 110 | if "points" in sample: 111 | sample["points"] = self._resample(sample["points"], self.num) 112 | else: 113 | if "crop_proportion" not in sample: 114 | src_size, ref_size = self.num, self.num 115 | elif len(sample["crop_proportion"]) == 1: 116 | src_size = math.ceil(sample["crop_proportion"][0] * self.num) 117 | ref_size = self.num 118 | elif len(sample["crop_proportion"]) == 2: 119 | src_size = math.ceil(sample["crop_proportion"][0] * self.num) 120 | ref_size = math.ceil(sample["crop_proportion"][1] * self.num) 121 | else: 122 | raise ValueError("Crop proportion must have 1 or 2 elements") 123 | 124 | sample["points_src"] = self._resample(sample["points_src"], src_size) 125 | sample["points_ref"] = self._resample(sample["points_ref"], ref_size) 126 | 127 | # Truncate the raw point clouds to their first self.num points 128 | sample["points_src_raw"] = sample["points_src_raw"][:self.num, :] 129 |
sample["points_ref_raw"] = sample["points_ref_raw"][:self.num, :] 130 | 131 | return sample 132 | 133 | 134 | class RandomJitter(Transform): 135 | """Perturbs the xyz coordinates with clipped Gaussian noise""" 136 | def __init__(self, noise_std=0.01, clip=0.05): 137 | self.noise_std = noise_std 138 | self.clip = clip 139 | 140 | def jitter(self, pts): 141 | 142 | noise = np.clip(np.random.normal(0.0, scale=self.noise_std, size=(pts.shape[0], 3)), a_min=-self.clip, a_max=self.clip) 143 | pts[:, :3] += noise # Add noise to xyz 144 | 145 | return pts 146 | 147 | def apply(self, sample): 148 | 149 | if "points" in sample: 150 | sample["points"] = self.jitter(sample["points"]) 151 | else: 152 | sample["points_src"] = self.jitter(sample["points_src"]) 153 | sample["points_ref"] = self.jitter(sample["points_ref"]) 154 | 155 | return sample 156 | 157 | 158 | class RandomCrop(Transform): 159 | """Randomly crops the point clouds, retaining approximately a fraction p_keep of the points 160 | 161 | A direction is randomly sampled from S2, and we retain points which lie within the 162 | half-space oriented in this direction.
163 | If p_keep != 0.5, we shift the plane until approximately a fraction p_keep of the points is retained. 164 | """ 165 | def __init__(self, p_keep=None): 166 | if p_keep is None: 167 | p_keep = [0.7, 0.7] # Crop both clouds to 70% 168 | self.p_keep = np.array(p_keep, dtype=np.float32) 169 | 170 | @staticmethod 171 | def crop(points, p_keep): 172 | if p_keep == 1.0: 173 | mask = np.ones(shape=(points.shape[0], )) > 0 174 | 175 | else: 176 | rand_xyz = uniform_2_sphere() 177 | centroid = np.mean(points[:, :3], axis=0) 178 | points_centered = points[:, :3] - centroid 179 | dist_from_plane = np.dot(points_centered, rand_xyz) 180 | 181 | if p_keep == 0.5: 182 | mask = dist_from_plane > 0 183 | else: 184 | mask = dist_from_plane > np.percentile(dist_from_plane, (1.0 - p_keep) * 100) 185 | 186 | return points[mask, :] 187 | 188 | def apply(self, sample): 189 | 190 | if "deterministic" in sample and sample["deterministic"]: 191 | np.random.seed(sample["idx"]) 192 | 193 | sample["crop_proportion"] = self.p_keep 194 | 195 | if len(sample["crop_proportion"]) == 1: 196 | sample["points_src"] = self.crop(sample["points_src"], self.p_keep[0]) 197 | sample["points_ref"] = self.crop(sample["points_ref"], 1.0) 198 | else: 199 | sample["points_src"] = self.crop(sample["points_src"], self.p_keep[0]) 200 | sample["points_ref"] = self.crop(sample["points_ref"], self.p_keep[1]) 201 | 202 | return sample 203 | 204 | 205 | class RandomTransformSE3(Transform): 206 | def __init__(self, rot_mag: float = 180.0, trans_mag: float = 1.0, random_mag: bool = False): 207 | """Applies a random rigid transformation to the source point cloud 208 | 209 | Args: 210 | rot_mag (float): Maximum rotation in degrees 211 | trans_mag (float): Maximum translation T. Random translation will 212 | be in the range [-T, T] in each axis 213 | random_mag (bool): If true, will randomly scale down the maximum rotation and translation, i.e.
will bias towards small 214 | perturbations 215 | """ 216 | self._rot_mag = rot_mag 217 | self._trans_mag = trans_mag 218 | self._random_mag = random_mag 219 | 220 | def generate_transform(self): 221 | """Generate a random SE3 transformation (3, 4) """ 222 | 223 | if self._random_mag: 224 | attenuation = np.random.random() 225 | rot_mag, trans_mag = attenuation * self._rot_mag, attenuation * self._trans_mag 226 | else: 227 | rot_mag, trans_mag = self._rot_mag, self._trans_mag 228 | 229 | # Generate rotation 230 | rand_rot = Rotation.random() # uniformly distributed random rotation 231 | axis_angle = rand_rot.as_rotvec() 232 | axis_angle *= rot_mag / 180.0 233 | rand_rot = Rotation.from_rotvec(axis_angle).as_matrix() 234 | 235 | # Generate translation 236 | rand_trans = np.random.uniform(-trans_mag, trans_mag, 3) 237 | rand_SE3 = np.concatenate((rand_rot, rand_trans[:, None]), axis=1).astype(np.float32) 238 | 239 | return rand_SE3 240 | 241 | def apply_transform(self, p0, transform_mat): 242 | p1 = se3.np_transform(transform_mat, p0[:, :3]) 243 | if p0.shape[1] == 6: # Need to rotate normals also 244 | n1 = so3.np_transform(transform_mat[:3, :3], p0[:, 3:6]) 245 | p1 = np.concatenate((p1, n1), axis=-1) 246 | 247 | igt = transform_mat 248 | gt = se3.np_inverse(igt) 249 | 250 | return p1, gt, igt 251 | 252 | def transform(self, tensor): 253 | transform_mat = self.generate_transform() 254 | return self.apply_transform(tensor, transform_mat) 255 | 256 | def apply(self, sample): 257 | 258 | if "deterministic" in sample and sample["deterministic"]: 259 | np.random.seed(sample["idx"]) 260 | 261 | if "points" in sample: 262 | sample["points"], _, _ = self.transform(sample["points"]) 263 | else: 264 | src_transformed, transform_r_s, transform_s_r = self.transform(sample["points_src"]) 265 | # Apply to source to get reference 266 | sample["transform_gt"] = transform_r_s 267 | sample["pose_gt"] = se3.np_mat2quat(transform_r_s) 268 | sample["transform_igt"] = transform_s_r 269 |
sample["points_src"] = src_transformed 270 | # Transform the raw source point cloud with the same rigid motion 271 | sample["points_src_raw"] = se3.np_transform(transform_s_r, sample["points_src_raw"][:, :3]) 272 | 273 | return sample 274 | 275 | 276 | # noinspection PyPep8Naming 277 | class RandomTransformSE3_euler(RandomTransformSE3): 278 | """Same as RandomTransformSE3, but rotates using Euler angle rotations 279 | 280 | This transformation is consistent with Deep Closest Point but does not 281 | generate uniformly distributed rotations 282 | 283 | """ 284 | def generate_transform(self): 285 | 286 | if self._random_mag: 287 | attenuation = np.random.random() 288 | rot_mag, trans_mag = attenuation * self._rot_mag, attenuation * self._trans_mag 289 | else: 290 | rot_mag, trans_mag = self._rot_mag, self._trans_mag 291 | 292 | # Generate rotation 293 | anglex = np.random.uniform() * np.pi * rot_mag / 180.0 294 | angley = np.random.uniform() * np.pi * rot_mag / 180.0 295 | anglez = np.random.uniform() * np.pi * rot_mag / 180.0 296 | 297 | cosx = np.cos(anglex) 298 | cosy = np.cos(angley) 299 | cosz = np.cos(anglez) 300 | sinx = np.sin(anglex) 301 | siny = np.sin(angley) 302 | sinz = np.sin(anglez) 303 | Rx = np.array([[1, 0, 0], [0, cosx, -sinx], [0, sinx, cosx]]) 304 | Ry = np.array([[cosy, 0, siny], [0, 1, 0], [-siny, 0, cosy]]) 305 | Rz = np.array([[cosz, -sinz, 0], [sinz, cosz, 0], [0, 0, 1]]) 306 | R_ab = Rx @ Ry @ Rz 307 | t_ab = np.random.uniform(-trans_mag, trans_mag, 3) 308 | 309 | rand_SE3 = np.concatenate((R_ab, t_ab[:, None]), axis=1).astype(np.float32) 310 | return rand_SE3 311 | 312 | 313 | class ShufflePoints(Transform): 314 | """Shuffles the order of the points""" 315 | def apply(self, sample): 316 | if "points" in sample: 317 | sample["points"] = np.random.permutation(sample["points"]) 318 | else: 319 | sample["points_ref"] = np.random.permutation(sample["points_ref"]) 320 | sample["points_src"] = np.random.permutation(sample["points_src"]) 321 | return sample 322 | 323 | 324 | class
SetDeterministic(Transform): 325 | """Adds a deterministic flag to the sample such that subsequent transforms 326 | use a fixed random seed where applicable. Used for testing.""" 327 | def apply(self, sample): 328 | sample["deterministic"] = True 329 | return sample 330 | 331 | 332 | class PRNet(Transform): 333 | def __init__(self, num_points, rot_mag, trans_mag, noise_std=0.01, clip=0.05, add_noise=True, only_z=False, partial=True): 334 | self.num_points = num_points 335 | self.rot_mag = rot_mag 336 | self.trans_mag = trans_mag 337 | self.noise_std = noise_std 338 | self.clip = clip 339 | self.add_noise = add_noise 340 | self.only_z = only_z 341 | self.partial = partial 342 | 343 | def apply_transform(self, p0, transform_mat): 344 | p1 = se3.np_transform(transform_mat, p0[:, :3]) 345 | if p0.shape[1] == 6: # Need to rotate normals also 346 | n1 = so3.np_transform(transform_mat[:3, :3], p0[:, 3:6]) 347 | p1 = np.concatenate((p1, n1), axis=-1) 348 | 349 | gt = transform_mat 350 | 351 | return p1, gt 352 | 353 | def jitter(self, pts): 354 | noise = np.clip(np.random.normal(0.0, scale=self.noise_std, size=(pts.shape[0], 3)), a_min=-self.clip, a_max=self.clip) 355 | pts[:, :3] += noise # Add noise to xyz 356 | 357 | return pts 358 | 359 | def knn(self, pts, random_pt, k): 360 | distance = np.sum((pts - random_pt)**2, axis=1) 361 | idx = np.argsort(distance)[:k] # (k,) 362 | return idx 363 | 364 | def apply(self, sample): 365 | 366 | if "deterministic" in sample and sample["deterministic"]: 367 | np.random.seed(sample["idx"]) 368 | 369 | src = sample["points_src"] 370 | ref = sample["points_ref"] 371 | # Generate rigid transform 372 | anglex = np.random.uniform() * np.pi * self.rot_mag / 180.0 373 | angley = np.random.uniform() * np.pi * self.rot_mag / 180.0 374 | anglez = np.random.uniform() * np.pi * self.rot_mag / 180.0 375 | 376 | cosx = np.cos(anglex) 377 | cosy = np.cos(angley) 378 | cosz = np.cos(anglez) 379 | sinx = np.sin(anglex) 380 | siny = np.sin(angley) 381 | sinz
= np.sin(anglez) 382 | Rx = np.array([[1, 0, 0], [0, cosx, -sinx], [0, sinx, cosx]]) 383 | Ry = np.array([[cosy, 0, siny], [0, 1, 0], [-siny, 0, cosy]]) 384 | Rz = np.array([[cosz, -sinz, 0], [sinz, cosz, 0], [0, 0, 1]]) 385 | 386 | if not self.only_z: 387 | R_ab = Rx @ Ry @ Rz 388 | else: 389 | R_ab = Rz 390 | t_ab = np.random.uniform(-self.trans_mag, self.trans_mag, 3) 391 | 392 | rand_SE3 = np.concatenate((R_ab, t_ab[:, None]), axis=1).astype(np.float32) 393 | ref, transform_s_r = self.apply_transform(ref, rand_SE3) 394 | # The ground truth maps the source onto the transformed reference 395 | sample["transform_gt"] = transform_s_r 396 | sample["pose_gt"] = se3.np_mat2quat(transform_s_r) 397 | 398 | # Crop and sample 399 | if self.partial: 400 | random_p1 = np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 1, -1]) 401 | idx1 = self.knn(src, random_p1, k=768) 402 | random_p2 = random_p1 403 | idx2 = self.knn(ref, random_p2, k=768) 404 | else: 405 | idx1 = np.random.choice(src.shape[0], 1024, replace=False) 406 | idx2 = np.random.choice(ref.shape[0], 1024, replace=False) 407 | src = mge.tensor(src) 408 | ref = mge.tensor(ref) 409 | 410 | # Optionally add noise 411 | if self.add_noise: 412 | sample["points_src"] = self.jitter(src[idx1, :]) 413 | sample["points_ref"] = self.jitter(ref[idx2, :]) 414 | else: 415 | sample["points_src"] = src[idx1, :] 416 | sample["points_ref"] = ref[idx2, :] 417 | 418 | return sample 419 | 420 | 421 | class Compose(object): 422 | def __init__(self, transforms): 423 | self.transforms = transforms 424 | 425 | def __call__(self, input): 426 | for t in self.transforms: 427 | input = t.apply(input) 428 | return input 429 | 430 | 431 | def fetch_transform(params): 432 | 433 | if params.transform_type == "modelnet_os_rpmnet_noise": 434 | train_transforms = [ 435 | SplitSourceRef(mode="os"), 436 | RandomCrop(params.partial_ratio), 437 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 438 |
Resampler(params.num_points), 439 | RandomJitter(), 440 | ShufflePoints() 441 | ] 442 | 443 | test_transforms = [ 444 | SetDeterministic(), 445 | SplitSourceRef(mode="os"), 446 | RandomCrop(params.partial_ratio), 447 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 448 | Resampler(params.num_points), 449 | RandomJitter(), 450 | ShufflePoints() 451 | ] 452 | 453 | elif params.transform_type == "modelnet_os_rpmnet_clean": 454 | train_transforms = [ 455 | SplitSourceRef(mode="os"), 456 | RandomCrop(params.partial_ratio), 457 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 458 | Resampler(params.num_points), 459 | ShufflePoints() 460 | ] 461 | 462 | test_transforms = [ 463 | SetDeterministic(), 464 | SplitSourceRef(mode="os"), 465 | RandomCrop(params.partial_ratio), 466 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 467 | Resampler(params.num_points), 468 | ShufflePoints() 469 | ] 470 | 471 | elif params.transform_type == "modelnet_ts_rpmnet_noise": 472 | train_transforms = [ 473 | SplitSourceRef(mode="ts"), 474 | RandomCrop(params.partial_ratio), 475 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 476 | Resampler(params.num_points), 477 | RandomJitter(noise_std=params.noise_std), 478 | ShufflePoints() 479 | ] 480 | 481 | test_transforms = [ 482 | SetDeterministic(), 483 | SplitSourceRef(mode="ts"), 484 | RandomCrop(params.partial_ratio), 485 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 486 | Resampler(params.num_points), 487 | RandomJitter(noise_std=params.noise_std), 488 | ShufflePoints() 489 | ] 490 | 491 | elif params.transform_type == "modelnet_ts_rpmnet_clean": 492 | train_transforms = [ 493 | SplitSourceRef(mode="ts"), 494 | RandomCrop(params.partial_ratio), 495 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 496 | Resampler(params.num_points), 497 | ShufflePoints() 498 
| ] 499 | 500 | test_transforms = [ 501 | SetDeterministic(), 502 | SplitSourceRef(mode="ts"), 503 | RandomCrop(params.partial_ratio), 504 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag), 505 | Resampler(params.num_points), 506 | ShufflePoints() 507 | ] 508 | 509 | elif params.transform_type == "modelnet_ts_prnet_noise": 510 | train_transforms = [ 511 | SplitSourceRef(mode="ts"), 512 | ShufflePoints(), 513 | PRNet(num_points=params.num_points, 514 | rot_mag=params.rot_mag, 515 | trans_mag=params.trans_mag, 516 | noise_std=params.noise_std, 517 | add_noise=True) 518 | ] 519 | 520 | test_transforms = [ 521 | SetDeterministic(), 522 | SplitSourceRef(mode="ts"), 523 | ShufflePoints(), 524 | PRNet(num_points=params.num_points, 525 | rot_mag=params.rot_mag, 526 | trans_mag=params.trans_mag, 527 | noise_std=params.noise_std, 528 | add_noise=True) 529 | ] 530 | 531 | elif params.transform_type == "modelnet_ts_prnet_clean": 532 | train_transforms = [ 533 | SplitSourceRef(mode="ts"), 534 | ShufflePoints(), 535 | PRNet(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=False) 536 | ] 537 | 538 | test_transforms = [ 539 | SetDeterministic(), 540 | SplitSourceRef(mode="ts"), 541 | ShufflePoints(), 542 | PRNet(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=False) 543 | ] 544 | 545 | elif params.transform_type == "modelnet_os_prnet_noise": 546 | train_transforms = [ 547 | SplitSourceRef(mode="os"), 548 | ShufflePoints(), 549 | PRNet(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=True) 550 | ] 551 | 552 | test_transforms = [ 553 | SetDeterministic(), 554 | SplitSourceRef(mode="os"), 555 | ShufflePoints(), 556 | PRNet(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=True) 557 | ] 558 | 559 | elif params.transform_type == "modelnet_os_prnet_clean": 560 | train_transforms = [ 561 | 
SplitSourceRef(mode="os"), 562 | ShufflePoints(), 563 | PRNet(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=False) 564 | ] 565 | 566 | test_transforms = [ 567 | SetDeterministic(), 568 | SplitSourceRef(mode="os"), 569 | ShufflePoints(), 570 | PRNet(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=False) 571 | ] 572 | 573 | utils.master_logger(_logger, "Train transforms: {}".format(", ".join([type(t).__name__ for t in train_transforms])), 574 | dist.get_rank() == 0) 575 | utils.master_logger(_logger, "Val and Test transforms: {}".format(", ".join([type(t).__name__ for t in test_transforms])), 576 | dist.get_rank() == 0) 577 | train_transforms = Compose(train_transforms) 578 | test_transforms = Compose(test_transforms) 579 | return train_transforms, test_transforms 580 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | 5 | import dataset.data_loader as data_loader 6 | 7 | import model.net as net 8 | 9 | from common import utils 10 | from loss.losses import compute_losses, compute_metrics 11 | from common.manager import Manager 12 | import megengine.distributed as dist 13 | import megengine.functional as F 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model_dir", default="experiments/base_model", help="Directory containing params.json") 17 | parser.add_argument("--restore_file", default="best", help="name of the file in --model_dir containing weights to load") 18 | 19 | 20 | def evaluate(model, manager): 21 | """Evaluate the model on the full validation and test sets. 22 | 23 | Args: 24 | model: (megengine.module.Module) the neural network 25 | manager: a class instance that contains objects related to training and evaluation.
26 | """ 27 | rank = dist.get_rank() 28 | world_size = dist.get_world_size() 29 | # set model to evaluation mode 30 | model.eval() 31 | 32 | # compute metrics over the dataset 33 | if manager.dataloaders["val"] is not None: 34 | # reset the loss and val metric status 35 | manager.reset_loss_status() 36 | manager.reset_metric_status("val") 37 | for data_batch in manager.dataloaders["val"]: 38 | # compute the real batch size 39 | bs = data_batch["points_src"].shape[0] 40 | # move to GPU if available 41 | data_batch = utils.tensor_mge(data_batch) 42 | # compute model output 43 | output_batch = model(data_batch) 44 | # compute all losses on this batch 45 | loss = compute_losses(output_batch, manager.params) 46 | metrics = compute_metrics(output_batch, manager.params) 47 | if world_size > 1: 48 | for k, v in loss.items(): 49 | loss[k] = F.distributed.all_reduce_sum(v) / world_size 50 | for k, v in metrics.items(): 51 | metrics[k] = F.distributed.all_reduce_sum(v) / world_size 52 | manager.update_loss_status(loss, "val", bs) 53 | # compute all metrics on this batch 54 | manager.update_metric_status(metrics, "val", bs) 55 | 56 | # update val data to tensorboard 57 | if rank == 0: 58 | # compute RMSE metrics 59 | manager.summarize_metric_status(metrics, "val") 60 | 61 | manager.writer.add_scalar("Loss/val", manager.loss_status["total"].avg, manager.epoch) 62 | # manager.logger.info("Loss/valid epoch {}: {:.4f}".format(manager.epoch, manager.loss_status["total"].avg)) 63 | for k, v in manager.val_status.items(): 64 | manager.writer.add_scalar("Metric/val/{}".format(k), v.avg, manager.epoch) 65 | # For each epoch, print the metric 66 | manager.print_metrics("val", title="Val", color="green") 67 | 68 | if manager.dataloaders["test"] is not None: 69 | # reset the loss and test metric status 70 | manager.reset_loss_status() 71 | manager.reset_metric_status("test") 72 | for data_batch in manager.dataloaders["test"]: 73 | # compute the real batch size 74 | bs = data_batch["points_src"].shape[0] 75 | # move to GPU if available 76 | data_batch = 
utils.tensor_mge(data_batch) 77 | # compute model output 78 | output_batch = model(data_batch) 79 | # compute all losses on this batch 80 | loss = compute_losses(output_batch, manager.params) 81 | metrics = compute_metrics(output_batch, manager.params) 82 | if world_size > 1: 83 | for k, v in loss.items(): 84 | loss[k] = F.distributed.all_reduce_sum(v) / world_size 85 | for k, v in metrics.items(): 86 | metrics[k] = F.distributed.all_reduce_sum(v) / world_size 87 | manager.update_loss_status(loss, "test", bs) 88 | # compute all metrics on this batch 89 | manager.update_metric_status(metrics, "test", bs) 90 | 91 | # update test data to tensorboard 92 | if rank == 0: 93 | # compute RMSE metrics 94 | manager.summarize_metric_status(metrics, "test") 95 | 96 | manager.writer.add_scalar("Loss/test", manager.loss_status["total"].avg, manager.epoch) 97 | # manager.logger.info("Loss/test epoch {}: {:.4f}".format(manager.epoch, manager.loss_status["total"].avg)) 98 | for k, v in manager.test_status.items(): 99 | manager.writer.add_scalar("Metric/test/{}".format(k), v.avg, manager.epoch) 100 | # For each epoch, print the metric 101 | manager.print_metrics("test", title="Test", color="red") 102 | 103 | 104 | def test(model, manager): 105 | """Test a model whose weights have already been restored from a checkpoint. 106 | 107 | Args: 108 | model: (megengine.module.Module) the neural network 109 | manager: a class instance that contains objects related to training and evaluation.
110 | """ 111 | # set model to evaluation mode 112 | model.eval() 113 | 114 | # compute metrics over the dataset 115 | if manager.dataloaders["val"] is not None: 116 | # reset the loss and val metric status 117 | manager.reset_loss_status() 118 | manager.reset_metric_status("val") 119 | for data_batch in manager.dataloaders["val"]: 120 | # compute the real batch size 121 | bs = data_batch["points_src"].shape[0] 122 | # move to GPU if available 123 | data_batch = utils.tensor_mge(data_batch) 124 | # compute model output 125 | output_batch = model(data_batch) 126 | # compute all losses on this batch 127 | loss = compute_losses(output_batch, manager.params) 128 | manager.update_loss_status(loss, "val", bs) 129 | # compute all metrics on this batch 130 | metrics = compute_metrics(output_batch, manager.params) 131 | manager.update_metric_status(metrics, "val", bs) 132 | 133 | # compute RMSE metrics 134 | manager.summarize_metric_status(metrics, "val") 135 | # print the val metrics 136 | manager.print_metrics("val", title="Val", color="green") 137 | 138 | if manager.dataloaders["test"] is not None: 139 | # reset the loss and test metric status 140 | manager.reset_loss_status() 141 | manager.reset_metric_status("test") 142 | for data_batch in manager.dataloaders["test"]: 143 | # compute the real batch size 144 | bs = data_batch["points_src"].shape[0] 145 | # move to GPU if available 146 | data_batch = utils.tensor_mge(data_batch) 147 | # compute model output 148 | output_batch = model(data_batch) 149 | # compute all losses on this batch 150 | loss = compute_losses(output_batch, manager.params) 151 | manager.update_loss_status(loss, "test", bs) 152 | # compute all metrics on this batch 153 | metrics = compute_metrics(output_batch, manager.params) 154 | manager.update_metric_status(metrics, "test", bs) 155 | 156 | # compute RMSE metrics 157 | manager.summarize_metric_status(metrics, "test") 158 | # print the test metrics 159 |
manager.print_metrics("test", title="Test", color="red") 160 | 161 | 162 | if __name__ == "__main__": 163 | """ 164 | Evaluate the model on the validation and test sets. 165 | """ 166 | # Load the parameters 167 | args = parser.parse_args() 168 | json_path = os.path.join(args.model_dir, "params.json") 169 | assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path) 170 | params = utils.Params(json_path) 171 | # Only load model weights 172 | params.only_weights = True 173 | 174 | # Update args into params 175 | params.update(vars(args)) 176 | 177 | # Get the logger 178 | logger = utils.set_logger(os.path.join(args.model_dir, "evaluate.log")) 179 | 180 | # Create the input data pipeline 181 | logging.info("Creating the dataset...") 182 | 183 | # Fetch dataloaders 184 | dataloaders = data_loader.fetch_dataloader(params) 185 | 186 | # Define the model (no optimizer is needed for evaluation) 187 | model = net.fetch_net(params) 188 | 189 | # Initialize the checkpoint manager 190 | manager = Manager(model=model, optimizer=None, scheduler=None, params=params, dataloaders=dataloaders, writer=None, logger=logger) 191 | 192 | # Reload weights from the saved file 193 | manager.load_checkpoints() 194 | 195 | # Test the model 196 | logger.info("Starting test") 197 | 198 | # Evaluate 199 | test(model, manager) 200 | -------------------------------------------------------------------------------- /experiments/experiment_finet/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_type": "modelnet_ts", 3 | "transform_type": "modelnet_ts_rpmnet_noise", 4 | "net_type": "finet", 5 | "net_config": { 6 | "dropout_ratio": 0.3, 7 | "reg_t_feats": "tr-t", 8 | "reg_R_feats": "tr-tr" 9 | }, 10 | "loss_type": "finet", 11 | "loss_alpha1": 1, 12 | "loss_alpha2": 4, 13 | "loss_alpha3": 0.001, 14 | "loss_alpha4": 0.0025, 15 | "margin": [ 16 | 0.01, 17 | 0.01 18 | ], 19 | "eval_type": [ 20 | "val", 21 | "test" 22 | ], 23 | "major_metric":
"score", 24 | "metric_rule": "Descende", 25 | "num_points": 1024, 26 | "rot_mag": 45, 27 | "trans_mag": 0.5, 28 | "partial_ratio": [ 29 | 0.7, 30 | 0.7 31 | ], 32 | "noise_std": 0.01, 33 | "titer": 4, 34 | "overlap_dist": 0.1, 35 | "learning_rate": 1e-4, 36 | "gamma": 1, 37 | "num_epochs": 10000, 38 | "train_batch_size": 8, 39 | "eval_batch_size": 32, 40 | "save_summary_steps": 100, 41 | "num_workers": 8 42 | } 43 | -------------------------------------------------------------------------------- /experiments/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_type": "modelnet_ts", 3 | "transform_type": "modelnet_ts_rpmnet_noise", 4 | "net_type": "finet", 5 | "net_config": { 6 | "dropout_ratio": 0.3, 7 | "reg_t_feats": "tr-t", 8 | "reg_R_feats": "tr-tr" 9 | }, 10 | "loss_type": "finet", 11 | "loss_alpha1": 1, 12 | "loss_alpha2": 4, 13 | "loss_alpha3": 0.001, 14 | "loss_alpha4": 0.0025, 15 | "margin": [ 16 | 0.01, 17 | 0.01 18 | ], 19 | "eval_type": [ 20 | "val", 21 | "test" 22 | ], 23 | "major_metric": "score", 24 | "metric_rule": "Descende", 25 | "num_points": 1024, 26 | "rot_mag": 45, 27 | "trans_mag": 0.5, 28 | "partial_ratio": [ 29 | 0.7, 30 | 0.7 31 | ], 32 | "noise_std": 0.01, 33 | "titer": 4, 34 | "overlap_dist": 0.1, 35 | "learning_rate": 1e-4, 36 | "gamma": 1, 37 | "num_epochs": 10000, 38 | "train_batch_size": 8, 39 | "eval_batch_size": 32, 40 | "save_summary_steps": 100, 41 | "num_workers": 8 42 | } 43 | -------------------------------------------------------------------------------- /images/FINet_poster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/FINet/5c07e7c0cafad67461c574633e7bb77257af6a96/images/FINet_poster.png -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MegEngine/FINet/5c07e7c0cafad67461c574633e7bb77257af6a96/loss/__init__.py -------------------------------------------------------------------------------- /loss/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import megengine.functional as F 3 | from common import se3, so3 4 | 5 | 6 | def compute_losses(endpoints, params): 7 | loss = {} 8 | # compute losses 9 | if params.loss_type == "finet": 10 | num_iter = len(endpoints["all_pose_pair"]) 11 | triplet_loss = {} 12 | for i in range(num_iter): 13 | # reg loss 14 | pose_pair = endpoints["all_pose_pair"][i] 15 | loss["quat_{}".format(i)] = F.nn.l1_loss(pose_pair[0][:, :4], pose_pair[1][:, :4]) * params.loss_alpha1 16 | loss["translate_{}".format(i)] = F.nn.square_loss(pose_pair[0][:, 4:], pose_pair[1][:, 4:]) * params.loss_alpha2 17 | 18 | # transformation sensitivity loss (TSL) 19 | if i < 2: 20 | all_R_feats = endpoints["all_R_feats"][i] 21 | all_t_feats = endpoints["all_t_feats"][i] 22 | # R feats triplet loss 23 | R_feats_pos = F.nn.square_loss(all_t_feats[0], all_t_feats[1]) 24 | R_feats_neg = F.nn.square_loss(all_R_feats[0], all_R_feats[1]) 25 | triplet_loss["R_feats_triplet_pos_{}".format(i)] = R_feats_pos 26 | triplet_loss["R_feats_triplet_neg_{}".format(i)] = R_feats_neg 27 | loss["R_feats_triplet_{}".format(i)] = (F.clip(-R_feats_neg + params.margin[i], lower=0.0) + 28 | R_feats_pos) * params.loss_alpha3 29 | # t feats triplet loss 30 | t_feats_pos = F.nn.square_loss(all_R_feats[0], all_R_feats[2]) 31 | t_feats_neg = F.nn.square_loss(all_t_feats[0], all_t_feats[2]) 32 | triplet_loss["t_feats_triplet_pos_{}".format(i)] = t_feats_pos 33 | triplet_loss["t_feats_triplet_neg_{}".format(i)] = t_feats_neg 34 | loss["t_feats_triplet_{}".format(i)] = (F.clip(-t_feats_neg + params.margin[i], lower=0.0) + 35 | t_feats_pos) * params.loss_alpha3 36 | 37 | # point-wise feature dropout loss (PFDL) 38 | 
all_dropout_R_feats = endpoints["all_dropout_R_feats"][i] 39 | all_dropout_t_feats = endpoints["all_dropout_t_feats"][i] 40 | loss["src_R_feats_dropout_{}".format(i)] = F.nn.square_loss(all_dropout_R_feats[0], all_dropout_R_feats[1]) * params.loss_alpha4 41 | loss["ref_R_feats_dropout_{}".format(i)] = F.nn.square_loss(all_dropout_R_feats[2], all_dropout_R_feats[3]) * params.loss_alpha4 42 | loss["src_t_feats_dropout_{}".format(i)] = F.nn.square_loss(all_dropout_t_feats[0], all_dropout_t_feats[1]) * params.loss_alpha4 43 | loss["ref_t_feats_dropout_{}".format(i)] = F.nn.square_loss(all_dropout_t_feats[2], all_dropout_t_feats[3]) * params.loss_alpha4 44 | # total loss 45 | total_losses = [] 46 | for k in loss: 47 | total_losses.append(loss[k]) 48 | loss["total"] = F.sum(F.concat(total_losses)) 49 | 50 | else: 51 | raise NotImplementedError 52 | return loss 53 | 54 | 55 | def compute_metrics(endpoints, params): 56 | metrics = {} 57 | gt_transforms = endpoints["transform_pair"][0] 58 | pred_transforms = endpoints["transform_pair"][1] 59 | 60 | # Euler angles, Individual translation errors (Deep Closest Point convention) 61 | if "prnet" in params.transform_type: 62 | r_gt_euler_deg = so3.mge_dcm2euler(gt_transforms[:, :3, :3], seq="zyx") 63 | r_pred_euler_deg = so3.mge_dcm2euler(pred_transforms[:, :3, :3], seq="zyx") 64 | else: 65 | r_gt_euler_deg = so3.mge_dcm2euler(gt_transforms[:, :3, :3], seq="xyz") 66 | r_pred_euler_deg = so3.mge_dcm2euler(pred_transforms[:, :3, :3], seq="xyz") 67 | t_gt = gt_transforms[:, :3, 3] 68 | t_pred = pred_transforms[:, :3, 3] 69 | 70 | r_mse = F.mean((r_gt_euler_deg - r_pred_euler_deg)**2, axis=1) 71 | r_mae = F.mean(F.abs(r_gt_euler_deg - r_pred_euler_deg), axis=1) 72 | t_mse = F.mean((t_gt - t_pred)**2, axis=1) 73 | t_mae = F.mean(F.abs(t_gt - t_pred), axis=1) 74 | 75 | r_mse = F.mean(r_mse) 76 | t_mse = F.mean(t_mse) 77 | r_mae = F.mean(r_mae) 78 | t_mae = F.mean(t_mae) 79 | 80 | # Rotation, translation errors (isotropic, i.e. 
doesn't depend on error 81 | # direction, which is more representative of the actual error) 82 | concatenated = se3.mge_concatenate(se3.mge_inverse(gt_transforms), pred_transforms) 83 | rot_trace = concatenated[:, 0, 0] + concatenated[:, 1, 1] + concatenated[:, 2, 2] 84 | residual_rotdeg = F.acos(F.clip(0.5 * (rot_trace - 1), -1.0, 1.0)) * 180.0 / np.pi 85 | residual_transmag = F.norm(concatenated[:, :, 3], axis=-1) 86 | err_r = F.mean(residual_rotdeg) 87 | err_t = F.mean(residual_transmag) 88 | 89 | # weighted score of isotropic errors 90 | score = err_r * 0.01 + err_t 91 | 92 | metrics = {"R_MSE": r_mse, "R_MAE": r_mae, "t_MSE": t_mse, "t_MAE": t_mae, "Err_R": err_r, "Err_t": err_t, "score": score} 93 | 94 | return metrics 95 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/FINet/5c07e7c0cafad67461c574633e7bb77257af6a96/model/__init__.py -------------------------------------------------------------------------------- /model/module.py: -------------------------------------------------------------------------------- 1 | import megengine as mge 2 | import megengine.module as nn 3 | import megengine.functional as F 4 | 5 | 6 | class Encoder(nn.Module): 7 | def __init__(self, config): 8 | super().__init__() 9 | self.config = config 10 | 11 | # R 12 | self.R_block1 = nn.Sequential(nn.Conv1d(3, 64, 1, bias=False), nn.BatchNorm1d(64), nn.ReLU()) 13 | self.R_block2 = nn.Sequential(nn.Conv1d(64, 64, 1, bias=False), nn.BatchNorm1d(64), nn.ReLU()) 14 | self.R_block3 = nn.Sequential(nn.Conv1d(128, 128, 1, bias=False), nn.BatchNorm1d(128), nn.ReLU()) 15 | self.R_block4 = nn.Sequential(nn.Conv1d(128, 256, 1, bias=False), nn.BatchNorm1d(256), nn.ReLU()) 16 | self.R_block5 = nn.Sequential(nn.Conv1d(512, 512, 1, bias=False), nn.BatchNorm1d(512), nn.ReLU()) 17 | 18 | # t 19 | self.t_block1 = 
nn.Sequential(nn.Conv1d(3, 64, 1, bias=False), nn.BatchNorm1d(64), nn.ReLU()) 20 | self.t_block2 = nn.Sequential(nn.Conv1d(64, 64, 1, bias=False), nn.BatchNorm1d(64), nn.ReLU()) 21 | self.t_block3 = nn.Sequential(nn.Conv1d(128, 128, 1, bias=False), nn.BatchNorm1d(128), nn.ReLU()) 22 | self.t_block4 = nn.Sequential(nn.Conv1d(128, 256, 1, bias=False), nn.BatchNorm1d(256), nn.ReLU()) 23 | self.t_block5 = nn.Sequential(nn.Conv1d(512, 512, 1, bias=False), nn.BatchNorm1d(512), nn.ReLU()) 24 | 25 | def forward(self, x, mask=None): 26 | B, C, N = x.shape 27 | if self.training: 28 | rand_mask = mge.random.uniform(size=(B, 1, N)) > self.config["dropout_ratio"] 29 | else: 30 | rand_mask = 1 31 | 32 | # R stage1 33 | R_feat_output1 = self.R_block1(x) 34 | if mask is not None: 35 | R_feat_output1 = R_feat_output1 * mask 36 | R_feat_output2 = self.R_block2(R_feat_output1) 37 | if mask is not None: 38 | R_feat_output2 = R_feat_output2 * mask 39 | R_feat_glob2 = F.max(R_feat_output2, axis=-1, keepdims=True) 40 | 41 | # t stage1 42 | t_feat_output1 = self.t_block1(x) 43 | if mask is not None: 44 | t_feat_output1 = t_feat_output1 * mask 45 | t_feat_output2 = self.t_block2(t_feat_output1) 46 | if mask is not None: 47 | t_feat_output2 = t_feat_output2 * mask 48 | t_feat_glob2 = F.max(t_feat_output2, axis=-1, keepdims=True) 49 | 50 | # exchange1 51 | src_R_feat_glob2, ref_R_feat_glob2 = F.split(R_feat_glob2, 2, axis=0) 52 | src_t_feat_glob2, ref_t_feat_glob2 = F.split(t_feat_glob2, 2, axis=0) 53 | exchange_R_feat = F.concat((F.repeat(ref_R_feat_glob2, N, axis=2), F.repeat(src_R_feat_glob2, N, axis=2)), axis=0) 54 | exchange_t_feat = F.concat((F.repeat(ref_t_feat_glob2, N, axis=2), F.repeat(src_t_feat_glob2, N, axis=2)), axis=0) 55 | exchange_R_feat = F.concat((R_feat_output2, exchange_R_feat.detach()), axis=1) 56 | exchange_t_feat = F.concat((t_feat_output2, exchange_t_feat.detach()), axis=1) 57 | 58 | # R stage2 59 | R_feat_output3 = self.R_block3(exchange_R_feat) 60 | if mask is not 
None: 61 | R_feat_output3 = R_feat_output3 * mask 62 | R_feat_output4 = self.R_block4(R_feat_output3) 63 | if mask is not None: 64 | R_feat_output4 = R_feat_output4 * mask 65 | R_feat_glob4 = F.max(R_feat_output4, axis=-1, keepdims=True) 66 | 67 | # t stage2 68 | t_feat_output3 = self.t_block3(exchange_t_feat) 69 | if mask is not None: 70 | t_feat_output3 = t_feat_output3 * mask 71 | t_feat_output4 = self.t_block4(t_feat_output3) 72 | if mask is not None: 73 | t_feat_output4 = t_feat_output4 * mask 74 | t_feat_glob4 = F.max(t_feat_output4, axis=-1, keepdims=True) 75 | 76 | # exchange2 77 | src_R_feat_glob4, ref_R_feat_glob4 = F.split(R_feat_glob4, 2, axis=0) 78 | src_t_feat_glob4, ref_t_feat_glob4 = F.split(t_feat_glob4, 2, axis=0) 79 | exchange_R_feat = F.concat((F.repeat(ref_R_feat_glob4, N, axis=2), F.repeat(src_R_feat_glob4, N, axis=2)), axis=0) 80 | exchange_t_feat = F.concat((F.repeat(ref_t_feat_glob4, N, axis=2), F.repeat(src_t_feat_glob4, N, axis=2)), axis=0) 81 | exchange_R_feat = F.concat((R_feat_output4, exchange_R_feat.detach()), axis=1) 82 | exchange_t_feat = F.concat((t_feat_output4, exchange_t_feat.detach()), axis=1) 83 | 84 | # R stage3 85 | R_feat_output5 = self.R_block5(exchange_R_feat) 86 | if mask is not None: 87 | R_feat_output5 = R_feat_output5 * mask 88 | 89 | # t stage3 90 | t_feat_output5 = self.t_block5(exchange_t_feat) 91 | if mask is not None: 92 | t_feat_output5 = t_feat_output5 * mask 93 | 94 | # final 95 | R_final_feat_output = F.concat((R_feat_output1, R_feat_output2, R_feat_output3, R_feat_output4, R_feat_output5), axis=1) 96 | t_final_feat_output = F.concat((t_feat_output1, t_feat_output2, t_feat_output3, t_feat_output4, t_feat_output5), axis=1) 97 | 98 | R_final_glob_feat = F.max(R_final_feat_output, axis=-1, keepdims=False) 99 | t_final_glob_feat = F.max(t_final_feat_output, axis=-1, keepdims=False) 100 | 101 | R_final_feat_dropout = R_final_feat_output * rand_mask 102 | R_final_feat_dropout = F.max(R_final_feat_dropout, axis=-1, 
keepdims=False) 103 | 104 | t_final_feat_dropout = t_final_feat_output * rand_mask 105 | t_final_feat_dropout = F.max(t_final_feat_dropout, axis=-1, keepdims=False) 106 | 107 | return [R_final_glob_feat, t_final_glob_feat, R_final_feat_dropout, t_final_feat_dropout] 108 | 109 | 110 | class Fusion(nn.Module): 111 | def __init__(self): 112 | super().__init__() 113 | 114 | # R 115 | self.R_block1 = nn.Sequential(nn.Linear(2048, 2048), nn.BatchNorm1d(2048), nn.ReLU()) 116 | self.R_block2 = nn.Sequential(nn.Linear(2048, 1024), nn.BatchNorm1d(1024), nn.ReLU()) 117 | self.R_block3 = nn.Sequential(nn.Linear(1024, 1024), nn.BatchNorm1d(1024), nn.ReLU()) 118 | 119 | # t 120 | self.t_block1 = nn.Sequential(nn.Linear(2048, 2048), nn.BatchNorm1d(2048), nn.ReLU()) 121 | self.t_block2 = nn.Sequential(nn.Linear(2048, 1024), nn.BatchNorm1d(1024), nn.ReLU()) 122 | self.t_block3 = nn.Sequential(nn.Linear(1024, 1024), nn.BatchNorm1d(1024), nn.ReLU()) 123 | 124 | def forward(self, R_feat, t_feat): 125 | # R 126 | fuse_R_feat = self.R_block1(R_feat) 127 | fuse_R_feat = self.R_block2(fuse_R_feat) 128 | fuse_R_feat = self.R_block3(fuse_R_feat) 129 | # t 130 | fuse_t_feat = self.t_block1(t_feat) 131 | fuse_t_feat = self.t_block2(fuse_t_feat) 132 | fuse_t_feat = self.t_block3(fuse_t_feat) 133 | 134 | return [fuse_R_feat, fuse_t_feat] 135 | 136 | 137 | class Regression(nn.Module): 138 | def __init__(self, config): 139 | super().__init__() 140 | self.config = config 141 | if self.config["reg_R_feats"] == "tr-tr": 142 | R_in_channel = 4096 143 | elif self.config["reg_R_feats"] == "tr-r": 144 | R_in_channel = 3072 145 | elif self.config["reg_R_feats"] == "r-r": 146 | R_in_channel = 2048 147 | else: 148 | raise ValueError("Unknown reg_R_feats order {}".format(self.config["reg_R_feats"])) 149 | 150 | if self.config["reg_t_feats"] == "tr-t": 151 | t_in_channel = 3072 152 | elif self.config["reg_t_feats"] == "t-t": 153 | t_in_channel = 2048 154 | else: 155 | raise ValueError("Unknown reg_t_feats 
order {}".format(self.config["reg_t_feats"])) 156 | 157 | self.R_net = nn.Sequential( 158 | # block 1 159 | nn.Linear(R_in_channel, 2048), 160 | nn.BatchNorm1d(2048), 161 | nn.ReLU(), 162 | # block 2 163 | nn.Linear(2048, 1024), 164 | nn.BatchNorm1d(1024), 165 | nn.ReLU(), 166 | # block 3 167 | nn.Linear(1024, 512), 168 | nn.BatchNorm1d(512), 169 | nn.ReLU(), 170 | # block 4 171 | nn.Linear(512, 256), 172 | nn.BatchNorm1d(256), 173 | nn.ReLU(), 174 | # final fc 175 | nn.Linear(256, 4), 176 | ) 177 | 178 | self.t_net = nn.Sequential( 179 | # block 1 180 | nn.Linear(t_in_channel, 2048), 181 | nn.BatchNorm1d(2048), 182 | nn.ReLU(), 183 | # block 2 184 | nn.Linear(2048, 1024), 185 | nn.BatchNorm1d(1024), 186 | nn.ReLU(), 187 | # block 3 188 | nn.Linear(1024, 512), 189 | nn.BatchNorm1d(512), 190 | nn.ReLU(), 191 | # block 4 192 | nn.Linear(512, 256), 193 | nn.BatchNorm1d(256), 194 | nn.ReLU(), 195 | # final fc 196 | nn.Linear(256, 3), 197 | ) 198 | 199 | def forward(self, R_feat, t_feat): 200 | 201 | pred_quat = self.R_net(R_feat) 202 | pred_quat = F.normalize(pred_quat, axis=1) 203 | pred_translate = self.t_net(t_feat) 204 | 205 | return [pred_quat, pred_translate] 206 | -------------------------------------------------------------------------------- /model/net.py: -------------------------------------------------------------------------------- 1 | import megengine as mge 2 | import megengine.module as nn 3 | import megengine.functional as F 4 | from model.module import Encoder, Fusion, Regression 5 | from common import quaternion 6 | import math 7 | 8 | 9 | class FINet(nn.Module): 10 | def __init__(self, params): 11 | super().__init__() 12 | self.params = params 13 | self.num_iter = params.titer 14 | self.net_config = params.net_config 15 | self.encoder = [Encoder(self.net_config) for _ in range(self.num_iter)] 16 | self.fusion = [Fusion() for _ in range(self.num_iter)] 17 | self.regression = [Regression(self.net_config) for _ in range(self.num_iter)] 18 | 19 | for m 
in self.modules(): 20 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear): 21 | nn.init.msra_normal_(m.weight, a=math.sqrt(5)) 22 | if m.bias is not None: 23 | fan_in, _ = nn.init.calculate_fan_in_and_fan_out(m.weight) 24 | bound = 1 / math.sqrt(fan_in) 25 | nn.init.uniform_(m.bias, -bound, bound) 26 | # elif isinstance(m, nn.BatchNorm1d): 27 | # nn.init.ones_(m.weight) 28 | # nn.init.zeros_(m.bias) 29 | 30 | def forward(self, data): 31 | endpoints = {} 32 | 33 | xyz_src = data["points_src"][:, :, :3] 34 | xyz_ref = data["points_ref"][:, :, :3] 35 | transform_gt = data["transform_gt"] 36 | pose_gt = data["pose_gt"] 37 | 38 | # init endpoints 39 | all_R_feats = [] 40 | all_t_feats = [] 41 | all_dropout_R_feats = [] 42 | all_dropout_t_feats = [] 43 | all_transform_pair = [] 44 | all_pose_pair = [] 45 | 46 | # initialize the pose as the identity (unit quaternion, zero translation) 47 | B = xyz_src.shape[0] 48 | init_quat = F.tile(mge.tensor([1, 0, 0, 0], dtype="float32"), (B, 1)) # (B, 4) 49 | init_translate = F.tile(mge.tensor([0, 0, 0], dtype="float32"), (B, 1)) # (B, 3) 50 | pose_pred = F.concat((init_quat, init_translate), axis=1) # (B, 7) 51 | 52 | # copy xyz_src; the copy is moved by the predicted pose after every iteration 53 | xyz_src_iter = F.copy(xyz_src, device=xyz_src.device) 54 | 55 | for i in range(self.num_iter): 56 | # encoder 57 | encoder = self.encoder[i] 58 | enc_input = F.concat((xyz_src_iter.transpose(0, 2, 1).detach(), xyz_ref.transpose(0, 2, 1)), axis=0) # 2B, C, N 59 | enc_feats = encoder(enc_input) 60 | src_enc_feats = [feat[:B, ...] for feat in enc_feats] 61 | ref_enc_feats = [feat[B:, ...]
for feat in enc_feats] 62 | enc_src_R_feat = src_enc_feats[0] # B, C 63 | enc_src_t_feat = src_enc_feats[1] # B, C 64 | enc_ref_R_feat = ref_enc_feats[0] # B, C 65 | enc_ref_t_feat = ref_enc_feats[1] # B, C 66 | 67 | # GFI 68 | src_R_cat_feat = F.concat((enc_src_R_feat, enc_ref_R_feat), axis=-1) # B, 2C 69 | ref_R_cat_feat = F.concat((enc_ref_R_feat, enc_src_R_feat), axis=-1) # B, 2C 70 | src_t_cat_feat = F.concat((enc_src_t_feat, enc_ref_t_feat), axis=-1) # B, 2C 71 | ref_t_cat_feat = F.concat((enc_ref_t_feat, enc_src_t_feat), axis=-1) # B, 2C 72 | fusion_R_input = F.concat((src_R_cat_feat, ref_R_cat_feat), axis=0) # 2B, 2C 73 | fusion_t_input = F.concat((src_t_cat_feat, ref_t_cat_feat), axis=0) # 2B, 2C 74 | fusion_feats = self.fusion[i](fusion_R_input, fusion_t_input) 75 | src_fusion_feats = [feat[:B, ...] for feat in fusion_feats] 76 | ref_fusion_feats = [feat[B:, ...] for feat in fusion_feats] 77 | src_R_feat = src_fusion_feats[0] # B, C 78 | src_t_feat = src_fusion_feats[1] # B, C 79 | ref_R_feat = ref_fusion_feats[0] # B, C 80 | ref_t_feat = ref_fusion_feats[1] # B, C 81 | 82 | # R feats 83 | if self.net_config["reg_R_feats"] == "tr-tr": 84 | R_feats = F.concat((src_t_feat, src_R_feat, ref_t_feat, ref_R_feat), axis=-1) # B, 4C 85 | 86 | elif self.net_config["reg_R_feats"] == "tr-r": 87 | R_feats = F.concat((src_R_feat, src_t_feat, ref_R_feat), axis=-1) # B, 3C 88 | 89 | elif self.net_config["reg_R_feats"] == "r-r": 90 | R_feats = F.concat((src_R_feat, ref_R_feat), axis=-1) # B, 2C 91 | 92 | else: 93 | raise ValueError("Unknown reg_R_feats order {}".format(self.net_config["reg_R_feats"])) 94 | 95 | # t feats 96 | if self.net_config["reg_t_feats"] == "tr-t": 97 | src_t_feats = F.concat((src_t_feat, src_R_feat, ref_t_feat), axis=-1) # B, 3C 98 | ref_t_feats = F.concat((ref_t_feat, ref_R_feat, src_t_feat), axis=-1) # B, 3C 99 | 100 | elif self.net_config["reg_t_feats"] == "t-t": 101 | src_t_feats = F.concat((src_t_feat, ref_t_feat), axis=-1) # B, 2C 102 |
ref_t_feats = F.concat((ref_t_feat, src_t_feat), axis=-1) # B, 2C 103 | 104 | else: 105 | raise ValueError("Unknown reg_t_feats order {}".format(self.net_config["reg_t_feats"])) 106 | 107 | # regression 108 | t_feats = F.concat((src_t_feats, ref_t_feats), axis=0) # 2B, 3C or 2B, 2C 109 | pred_quat, pred_center = self.regression[i](R_feats, t_feats) 110 | src_pred_center, ref_pred_center = F.split(pred_center, 2, axis=0) 111 | pred_translate = ref_pred_center - src_pred_center 112 | pose_pred_iter = F.concat((pred_quat, pred_translate), axis=-1) # B, 7 113 | 114 | # extract features for compute transformation sensitivity loss (TSL) 115 | xyz_src_rotated = quaternion.mge_quat_rotate(xyz_src_iter.detach(), pose_pred_iter.detach()) # B, N, 3 116 | xyz_src_translated = xyz_src_iter.detach() + F.expand_dims(pose_pred_iter.detach()[:, 4:], axis=1) # B, N, 3 117 | 118 | rotated_enc_input = F.concat((xyz_src_rotated.transpose(0, 2, 1).detach(), xyz_ref.transpose(0, 2, 1)), axis=0) # 2B, C, N 119 | rotated_enc_feats = encoder(rotated_enc_input) 120 | rotated_src_enc_feats = [feat[:B, ...] for feat in rotated_enc_feats] 121 | rotated_enc_src_R_feat = rotated_src_enc_feats[0] # B, C 122 | rotated_enc_src_t_feat = rotated_src_enc_feats[1] # B, C 123 | 124 | translated_enc_input = F.concat((xyz_src_translated.transpose(0, 2, 1).detach(), xyz_ref.transpose(0, 2, 1)), 125 | axis=0) # 2B, C, N 126 | translated_enc_feats = encoder(translated_enc_input) 127 | translated_src_enc_feats = [feat[:B, ...] 
for feat in translated_enc_feats] 128 | translated_enc_src_R_feat = translated_src_enc_feats[0] # B, C 129 | translated_enc_src_t_feat = translated_src_enc_feats[1] # B, C 130 | 131 | # dropout 132 | dropout_src_R_feat = src_enc_feats[2] # B, C 133 | dropout_src_t_feat = src_enc_feats[3] # B, C 134 | dropout_ref_R_feat = ref_enc_feats[2] # B, C 135 | dropout_ref_t_feat = ref_enc_feats[3] # B, C 136 | 137 | # do transform 138 | xyz_src_iter = quaternion.mge_quat_transform(pose_pred_iter, xyz_src_iter.detach()) 139 | pose_pred = quaternion.mge_transform_pose(pose_pred.detach(), pose_pred_iter) 140 | transform_pred = quaternion.mge_quat2mat(pose_pred) 141 | 142 | # add endpoints at each iteration 143 | all_R_feats.append([enc_src_R_feat, rotated_enc_src_R_feat, translated_enc_src_R_feat]) 144 | all_t_feats.append([enc_src_t_feat, rotated_enc_src_t_feat, translated_enc_src_t_feat]) 145 | all_dropout_R_feats.append([dropout_src_R_feat, enc_src_R_feat, dropout_ref_R_feat, enc_ref_R_feat]) 146 | all_dropout_t_feats.append([dropout_src_t_feat, enc_src_t_feat, dropout_ref_t_feat, enc_ref_t_feat]) 147 | all_transform_pair.append([transform_gt, transform_pred]) 148 | all_pose_pair.append([pose_gt, pose_pred]) 149 | 150 | mge.coalesce_free_memory() 151 | 152 | # add endpoints finally 153 | endpoints["all_R_feats"] = all_R_feats 154 | endpoints["all_t_feats"] = all_t_feats 155 | endpoints["all_dropout_R_feats"] = all_dropout_R_feats 156 | endpoints["all_dropout_t_feats"] = all_dropout_t_feats 157 | endpoints["all_transform_pair"] = all_transform_pair 158 | endpoints["all_pose_pair"] = all_pose_pair 159 | endpoints["transform_pair"] = [transform_gt, transform_pred] 160 | endpoints["pose_pair"] = [pose_gt, pose_pred] 161 | 162 | return endpoints 163 | 164 | 165 | def fetch_net(params): 166 | if params.net_type == "finet": 167 | net = FINet(params) 168 | 169 | else: 170 | raise NotImplementedError 171 | return net 172 | 
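The `forward` loop refines the registration iteratively. Each iteration predicts an incremental pose `pose_pred_iter` (a quaternion followed by a translation) and re-transforms the source cloud with `quaternion.mge_quat_transform`. It then folds the increment into `pose_pred` with `quaternion.mge_transform_pose`.

The bodies of those helpers are not part of this excerpt, so the pure-Python sketch below only illustrates the bookkeeping under two assumptions. First, quaternions are stored as (w, x, y, z), which matches `init_quat = [1, 0, 0, 0]` being the identity. Second, a pose maps a point p to R(q)p + t.

Under those assumptions, composing an increment (q_inc, t_inc) onto an accumulated pose (q_acc, t_acc) gives q_inc ⊗ q_acc and R(q_inc)t_acc + t_inc. Applying the accumulated pose once to the original points therefore matches moving them step by step. All function names here are hypothetical.

```python
import math
from typing import Tuple

# Hypothetical helpers: they illustrate the assumed pose convention and are
# not the repository's quaternion module.
Quat = Tuple[float, float, float, float]  # (w, x, y, z), assumed unit norm
Vec3 = Tuple[float, float, float]


def quat_mul(a: Quat, b: Quat) -> Quat:
    """Hamilton product a * b; as a rotation it applies b first, then a."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return (
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    )


def quat_rotate(q: Quat, p: Vec3) -> Vec3:
    """Rotate p by the unit quaternion q, computed as q * (0, p) * conj(q)."""
    w, x, y, z = q
    _, rx, ry, rz = quat_mul(quat_mul(q, (0.0, p[0], p[1], p[2])), (w, -x, -y, -z))
    return (rx, ry, rz)


def apply_pose(q: Quat, t: Vec3, p: Vec3) -> Vec3:
    """Assumed pose convention: p -> R(q) p + t."""
    r = quat_rotate(q, p)
    return (r[0] + t[0], r[1] + t[1], r[2] + t[2])


def compose_pose(q_acc: Quat, t_acc: Vec3, q_inc: Quat, t_inc: Vec3) -> Tuple[Quat, Vec3]:
    """Fold an incremental pose into the accumulated one.

    R_inc (R_acc p + t_acc) + t_inc = (R_inc R_acc) p + (R_inc t_acc + t_inc)
    """
    r = quat_rotate(q_inc, t_acc)
    return quat_mul(q_inc, q_acc), (r[0] + t_inc[0], r[1] + t_inc[1], r[2] + t_inc[2])


def axis_angle_quat(axis: Vec3, angle: float) -> Quat:
    """Unit quaternion for a rotation of `angle` radians about the unit vector `axis`."""
    s = math.sin(angle / 2.0)
    return (math.cos(angle / 2.0), axis[0] * s, axis[1] * s, axis[2] * s)


# Mirror the loop: start from the identity pose (like init_quat / init_translate),
# move the point step by step and accumulate the pose alongside.
q_acc, t_acc = (1.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0)
steps = [
    (axis_angle_quat((0.0, 0.0, 1.0), math.pi / 6), (0.1, 0.0, 0.0)),
    (axis_angle_quat((1.0, 0.0, 0.0), math.pi / 8), (0.0, -0.2, 0.05)),
]
p0 = (1.0, 2.0, 3.0)
p_iter = p0
for q_inc, t_inc in steps:
    p_iter = apply_pose(q_inc, t_inc, p_iter)  # analogous to updating xyz_src_iter
    q_acc, t_acc = compose_pose(q_acc, t_acc, q_inc, t_inc)  # analogous to updating pose_pred
p_once = apply_pose(q_acc, t_acc, p0)
print(max(abs(a - b) for a, b in zip(p_iter, p_once)))  # floating-point noise only
```

Because each increment is composed on the left, the argument order of `quat_mul` matters. Swapping it would describe applying the increment before the accumulated pose.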
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
coloredlogs==15.0.1
h5py==3.5.0
megengine==1.7.0
numpy==1.21.4
scipy==1.7.2
tensorboardX==2.4
termcolor==1.1.0
tqdm==4.62.3
transforms3d==0.3.1

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
"""Train the model"""

import argparse
import datetime
import os

import megengine as mge
# mge.core.set_option("async_level", 0)

from megengine.optimizer import Adam, MultiStepLR, LRScheduler
from megengine.autodiff import GradManager
import megengine.distributed as dist
from tqdm import tqdm

import dataset.data_loader as data_loader
import model.net as net

from common import utils
from common.manager import Manager
from evaluate import evaluate
from loss.losses import compute_losses
from tensorboardX import SummaryWriter

parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", default="experiments/experiment_finet", help="Directory containing params.json")
parser.add_argument("--restore_file",
                    default=None,
                    help="Optional, name of the file in model_dir containing weights to reload before training")
parser.add_argument("-ow", "--only_weights", action="store_true", help="Only load the model weights instead of the full training status.")


def train(model, manager: Manager, gm):
    rank = dist.get_rank()
    # reset the loss status and the val/test status
    manager.reset_loss_status()
    # set model to training mode
    model.train()
    # Use tqdm for progress bar
    if rank == 0:
        t = tqdm(total=len(manager.dataloaders["train"]))

    for i, data_batch in enumerate(manager.dataloaders["train"]):
        # move to GPU if available
        data_batch = utils.tensor_mge(data_batch)

        # information string for the progress bar
        print_str = manager.print_train_info()

        with gm:
            # compute model output and loss
            output_batch = model(data_batch)
            loss = compute_losses(output_batch, manager.params)

            # update loss status and print current loss and average loss
            manager.update_loss_status(loss=loss, split="train")
            gm.backward(loss["total"])

        # performs updates using calculated gradients
        manager.optimizer.step().clear_grad()

        manager.update_step()
        if rank == 0:
            manager.writer.add_scalar("Loss/train", manager.loss_status["total"].val, manager.step)
            t.set_description(desc=print_str)
            t.update()

    if rank == 0:
        t.close()

    manager.scheduler.step()
    manager.update_epoch()


def train_and_evaluate(model, manager: Manager):
    rank = dist.get_rank()
    # reload weights from restore_file if specified; read it from params (filled from the CLI
    # arguments in __main__) so this also works in worker processes started by dist.launcher
    if manager.params.restore_file is not None:
        manager.load_checkpoints()

    world_size = dist.get_world_size()
    if world_size > 1:
        dist.bcast_list_(model.parameters())
        dist.bcast_list_(model.buffers())

    gm = GradManager().attach(
        model.parameters(),
        callbacks=dist.make_allreduce_cb("SUM") if world_size > 1 else None,
    )

    for epoch in range(manager.params.num_epochs):
        # train for one epoch (one full pass over the training set)
        train(model, manager, gm)

        # Evaluate for one epoch on validation set
        evaluate(model, manager)

        # Save best model weights according to params.major_metric
        if rank == 0:
            manager.check_best_save_last_checkpoints(save_latest_freq=100, save_best_after=200)


def main(params):
    # DTR support
    # mge.dtr.eviction_threshold = "5GB"
    # mge.dtr.enable()

    # Set the logger
    logger = utils.set_logger(os.path.join(params.model_dir, "train.log"))

    # Set the tensorboard writer
    tb_dir = os.path.join(params.model_dir, "summary")
    os.makedirs(tb_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=tb_dir)

    # fetch dataloaders
    dataloaders = data_loader.fetch_dataloader(params)

    # Define the model and optimizer
    model = net.fetch_net(params)

    optimizer = Adam(model.parameters(), lr=params.learning_rate)
    scheduler = MultiStepLR(optimizer, milestones=[])

    # initial status for checkpoint manager
    manager = Manager(model=model,
                      optimizer=optimizer,
                      scheduler=scheduler,
                      params=params,
                      dataloaders=dataloaders,
                      writer=writer,
                      logger=logger)

    # Train the model
    utils.master_logger(logger, "Starting training for {} epoch(s)".format(params.num_epochs))

    train_and_evaluate(model, manager)


if __name__ == "__main__":
    # Load the parameters from json file
    args = parser.parse_args()
    json_path = os.path.join(args.model_dir, "params.json")
    assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
    params = utils.Params(json_path)
    params.update(vars(args))

    train_proc = dist.launcher(main) if mge.device.get_device_count("gpu") > 1 else main
    train_proc(params)

--------------------------------------------------------------------------------
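`main()` creates `MultiStepLR(optimizer, milestones=[])`, and `train()` calls `manager.scheduler.step()` once per epoch. With an empty milestone list the step-decay rule never fires, so training runs at the constant `params.learning_rate`.

The sketch below writes that rule out in plain Python: multiply the base rate by `gamma` once for every milestone already reached. `multistep_lr` is a hypothetical helper. `gamma=0.1` is assumed to match MegEngine's default and is not taken from the script.

```python
from bisect import bisect_right
from typing import Sequence


def multistep_lr(base_lr: float, milestones: Sequence[int], epoch: int, gamma: float = 0.1) -> float:
    """Hypothetical helper: learning rate after `epoch` scheduler steps under step decay.

    The rate is multiplied by `gamma` once for every milestone <= epoch;
    `milestones` must be sorted in increasing order. gamma=0.1 is an assumed default.
    """
    return base_lr * gamma ** bisect_right(list(milestones), epoch)


# milestones=[] as in train.py: the rate stays at base_lr for every epoch.
print([multistep_lr(1e-4, [], e) for e in (0, 50, 500)])
# a hypothetical schedule that decays at epochs 100 and 200
print([multistep_lr(1e-4, [100, 200], e) for e in (99, 100, 200)])
```

To actually decay the rate in `train.py`, non-empty epoch milestones would have to be passed to `MultiStepLR`. The source does not say which values suit FINet.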