├── .gitignore
├── LICENSE
├── README.md
├── common
│   ├── __init__.py
│   ├── manager.py
│   ├── quaternion.py
│   ├── se3.py
│   ├── so3.py
│   └── utils.py
├── dataset
│   ├── __init__.py
│   ├── data_loader.py
│   └── transformations.py
├── evaluate.py
├── experiments
│   ├── experiment_finet
│   │   └── params.json
│   └── params.json
├── images
│   └── FINet_poster.png
├── loss
│   ├── __init__.py
│   └── losses.py
├── model
│   ├── __init__.py
│   ├── module.py
│   └── net.py
├── requirements.txt
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | experiments/experiment_finet/summary
2 | experiments/experiment_finet/test_metrics_best.json
3 | experiments/experiment_finet/test_metrics_latest.json
4 | experiments/experiment_finet/test_model_best.pth
5 | experiments/experiment_finet/train.log
6 | experiments/experiment_finet/val_metrics_best.json
7 | experiments/experiment_finet/val_metrics_latest.json
8 | experiments/experiment_finet/val_model_best.pth
9 | experiments/experiment_finet/evaluate.log
10 | dataset/data/modelnet_os
11 | dataset/data/modelnet_ts
12 | dataset/data/OS_data.zip
13 | dataset/data/TS_data.zip
14 | experiments/experiment_finet/model_latest.pth
15 | __pycache__
16 | */__pycache__
17 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Megvii Technology
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [AAAI 2022] FINet: Dual Branches Feature Interaction for Partial-to-Partial Point Cloud Registration
2 |
3 | Hao Xu$^{1,2}$, Nianjin Ye$^2$, Guanghui Liu$^1$, Bing Zeng$^1$, Shuaicheng Liu$^1$
4 | $^1$ University of Electronic Science and Technology of China
5 | $^2$ Megvii Research
6 |
7 |
8 | This is the official implementation (MegEngine implementation) of our AAAI2022 paper [FINet](https://www.aaai.org/AAAI22Papers/AAAI-549.XuH.pdf).
9 |
10 | ## Presentation video:
11 | [[Youtube](https://www.youtube.com/watch?v=XDmE9iSx9WM)] [[Bilibili](https://www.bilibili.com/video/BV1z44y1s7up/)].
12 |
13 | ## Abstract
14 | Data association is important in point cloud registration. In this work, we propose to solve partial-to-partial registration from a new perspective by introducing multi-level feature interactions between the source and reference clouds at the feature extraction stage, so that registration can be achieved without the attention mechanisms or explicit mask estimation for overlap detection adopted in previous methods. Specifically, we present FINet, a feature interaction-based structure that enables and strengthens information association between the inputs at multiple stages. To achieve this, we first split the features into two components, one for rotation and one for translation, based on the fact that they belong to different solution spaces, yielding a dual-branch structure. Second, we insert several interaction modules into the feature extractor for data association. Third, we propose a transformation sensitivity loss to obtain rotation-attentive and translation-attentive features. Experiments demonstrate that our method achieves higher precision and robustness than state-of-the-art traditional and learning-based methods.
15 |
16 |
17 | ## Our Poster
18 |
19 | ![FINet poster](./images/FINet_poster.png)
20 |
21 | ## Dependencies
22 |
23 | * MegEngine==1.7.0
24 | * For other requirements, please refer to `requirements.txt`.
25 |
26 | ## Data Preparation
27 |
28 | Following [OMNet](https://github.com/megvii-research/OMNet), we use the OS and TS data of the ModelNet40 dataset.
29 |
30 | ### OS data
31 |
32 | We refer to the original data from PointNet as OS data, where point clouds are sampled only once from the corresponding CAD models. We offer two ways to use OS data: (1) download the data from its original link [original_OS_data.zip](http://modelnet.cs.princeton.edu/); or (2) download the data preprocessed by us from [our_OS_data.zip](https://drive.google.com/file/d/1rXnbXwD72tkeu8x6wboMP0X7iL9LiBPq/view?usp=sharing).
33 |
34 | ### TS data
35 |
36 | Since OS data leads to over-fitting, we propose TS data, where point clouds are randomly sampled twice from the CAD models. First download our preprocessed ModelNet40 dataset, in which 8 axisymmetric categories are removed and every CAD model has 40 randomly sampled point clouds. The download link is [TS_data.zip](https://drive.google.com/file/d/1DPBBI3Ulvp2Mx7SAZaBEyvADJzBvErFF/view?usp=sharing). All 40 point clouds of a CAD model are stacked into a (40, 2048, 3) NumPy array, which you can load with the following code:
37 |
38 | ```
39 | import numpy as np
40 | points = np.load("path_of_npy_file")
41 | print(points.shape, type(points))  # (40, 2048, 3) <class 'numpy.ndarray'>
42 | ```
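The README states that TS data samples each CAD model twice, but it does not show how a training pair is formed from the stacked array. The sketch below assumes that a source and a reference cloud are drawn from two different samplings of the same model. `sample_ts_pair` is a hypothetical helper, not the loader in `./dataset/data_loader.py`, and random data stands in for a real `.npy` file.

```python
import numpy as np


def sample_ts_pair(points, rng=None):
    """Hypothetical helper: draw a (source, reference) pair from one CAD model.

    points: array of shape (S, P, 3) holding S independent samplings of the
    model (S = 40 and P = 2048 for the TS data described above).
    """
    rng = np.random.default_rng() if rng is None else rng
    # Two distinct sampling indices, so the pair comes from different samplings
    i, j = rng.choice(points.shape[0], size=2, replace=False)
    return points[i], points[j]


# Synthetic stand-in for np.load("path_of_npy_file")
points = np.random.default_rng(0).standard_normal((40, 2048, 3)).astype(np.float32)
src, ref = sample_ts_pair(points, rng=np.random.default_rng(1))
print(src.shape, ref.shape)  # (2048, 3) (2048, 3)
```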
43 |
44 | Then put the data into `./dataset/data`; the directory contents should be as follows:
45 |
46 | ```
47 | ./dataset/data/
48 | ├── modelnet40_half1_rm_rotate.txt
49 | ├── modelnet40_half2_rm_rotate.txt
50 | ├── modelnet_os
51 | │   ├── modelnet_os_test.pickle
52 | │   ├── modelnet_os_train.pickle
53 | │   ├── modelnet_os_val.pickle
54 | │   ├── test [1146 entries exceeds filelimit, not opening dir]
55 | │   ├── train [4194 entries exceeds filelimit, not opening dir]
56 | │   └── val [1002 entries exceeds filelimit, not opening dir]
57 | └── modelnet_ts
58 |     ├── modelnet_ts_test.pickle
59 |     ├── modelnet_ts_train.pickle
60 |     ├── modelnet_ts_val.pickle
61 |     ├── shape_names.txt
62 |     ├── test [1146 entries exceeds filelimit, not opening dir]
63 |     ├── train [4196 entries exceeds filelimit, not opening dir]
64 |     └── val [1002 entries exceeds filelimit, not opening dir]
65 | ```
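Before training, it can help to confirm that the files match the layout above. The following is a minimal sketch using only the standard library. `expected_entries` and `missing_entries` are hypothetical helper names, and the entry list is transcribed from the tree above. The per-sample files inside `train`, `val`, and `test` are not checked.

```python
from pathlib import Path


def expected_entries():
    """Relative paths that should exist under ./dataset/data (from the tree above)."""
    entries = [
        "modelnet40_half1_rm_rotate.txt",
        "modelnet40_half2_rm_rotate.txt",
        "modelnet_ts/shape_names.txt",
    ]
    for kind in ("os", "ts"):
        for split in ("train", "val", "test"):
            entries.append(f"modelnet_{kind}/modelnet_{kind}_{split}.pickle")
            entries.append(f"modelnet_{kind}/{split}")  # directory of samples
    return entries


def missing_entries(root="./dataset/data"):
    """Return the expected relative paths that do not exist under root."""
    root = Path(root)
    return [rel for rel in expected_entries() if not (root / rel).exists()]


if __name__ == "__main__":
    missing = missing_entries()
    print("data layout OK" if not missing else "missing: " + ", ".join(missing))
```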
66 |
67 | ## Training and Evaluation
68 |
69 | ### Begin training
70 |
71 | For the ModelNet40 dataset, simply run:
72 |
73 | ```
74 | python3 train.py --model_dir=./experiments/experiment_finet/
75 | ```
76 |
77 | For other datasets, you need to add your own dataset class to `./dataset/data_loader.py`. Training with a smaller batch size, such as 16, may yield worse performance than training with a larger one, e.g., 64.
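A custom dataset class mainly needs to return one sample per index. The sketch below uses plain Python and NumPy and follows the map-style `__len__`/`__getitem__` protocol. In this repository, such a class would presumably subclass MegEngine's dataset base class and apply the transformations from `./dataset/transformations.py`; both steps are omitted here and are assumptions. `TwoSampleDataset`, its arguments, and the per-index seeding are hypothetical.

```python
import numpy as np


class TwoSampleDataset:
    """Hypothetical map-style dataset over per-model .npy files of shape (S, P, 3)."""

    def __init__(self, npy_paths, seed=0):
        self.npy_paths = list(npy_paths)
        self.seed = seed

    def __len__(self):
        return len(self.npy_paths)

    def __getitem__(self, index):
        samplings = np.load(self.npy_paths[index])  # loaded lazily, one model at a time
        # Seeding by index makes each item reproducible, e.g. for validation
        rng = np.random.default_rng(self.seed + index)
        i, j = rng.choice(samplings.shape[0], size=2, replace=False)
        return samplings[i].astype(np.float32), samplings[j].astype(np.float32)
```

For a training split, you would typically drop the per-index seeding so that each epoch sees new pairs.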
78 |
79 | ### Begin testing
80 |
81 | You need to download the pretrained checkpoint and run:
82 |
83 | ```
84 | python3 evaluate.py --model_dir=./experiments/experiment_finet --restore_file=./experiments/experiment_finet/test_model_best.pth
85 | ```
86 |
87 | These weights are for TS data with Gaussian noise. Note that their performance is slightly worse than the results reported in our paper, which were obtained with the PyTorch implementation.
88 |
89 | The MegEngine checkpoint for the ModelNet40 dataset can be downloaded from [Google Drive](https://drive.google.com/file/d/1nM9bzSYGYA8fsQ0-HSPLo4rOdkG5rxAS/view?usp=sharing).
90 |
91 | ## Citation
92 |
93 | ```
94 | @InProceedings{Xu_2022_AAAI,
95 | author={Xu, Hao and Ye, Nianjin and Liu, Guanghui and Zeng, Bing and Liu, Shuaicheng},
96 | title={FINet: Dual Branches Feature Interaction for Partial-to-Partial Point Cloud Registration},
97 | booktitle={Proceedings of the Thirty-Sixth AAAI Conference on Artificial Intelligence},
98 | year={2022}
99 | }
100 | ```
101 |
102 | ## Acknowledgments
103 |
104 | In this project we use (parts of) the official implementations of the following works:
105 |
106 | * [RPMNet](https://github.com/yewzijian/RPMNet) (ModelNet40 preprocessing and evaluation)
107 | * [PRNet](https://github.com/WangYueFt/prnet) (ModelNet40 preprocessing)
108 | * [OMNet](https://github.com/megvii-research/OMNet) (Code base)
109 |
110 | We thank the respective authors for open sourcing their methods.
111 |
--------------------------------------------------------------------------------
/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MegEngine/FINet/5c07e7c0cafad67461c574633e7bb77257af6a96/common/__init__.py
--------------------------------------------------------------------------------
/common/manager.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import megengine as mge
4 | import megengine.distributed as dist
5 | from collections import defaultdict
6 | from termcolor import colored
7 |
8 | from common import utils
9 |
10 |
11 | class Manager():
12 |     def __init__(self, model, optimizer, params, dataloaders, writer, logger, scheduler):
13 |         # params status
14 |         self.params = params
15 |         self.optimizer = optimizer
16 |         self.model = model
17 |         self.dataloaders = dataloaders
18 |         self.writer = writer
19 |         self.logger = logger
20 |         self.scheduler = scheduler
21 |
22 |         # metric_rule should be either Descende or Ascende
23 |         self.metric_rule = params.metric_rule
24 |
25 |         self.epoch = 0
26 |         self.step = 0
27 |
28 |         # lower is better
29 |         if self.metric_rule == "Descende":
30 |             self.best_val_score = 100
31 |             self.best_test_score = 100
32 |         # higher is better
33 |         elif self.metric_rule == "Ascende":
34 |             self.best_val_score = 0
35 |             self.best_test_score = 0
36 |
37 |         self.cur_val_score = 0
38 |         self.cur_test_score = 0
39 |
40 |         # train status
41 |         self.train_status = defaultdict(utils.AverageMeter)
42 |
43 |         # val status
44 |         self.val_status = defaultdict(utils.AverageMeter)
45 |
46 |         # test status
47 |         self.test_status = defaultdict(utils.AverageMeter)
48 |
49 |         # model status
50 |         self.loss_status = defaultdict(utils.AverageMeter)
51 |
52 |     def update_step(self):
53 |         self.step += 1
54 |
55 |     def update_epoch(self):
56 |         self.epoch += 1
57 |
58 |     def update_loss_status(self, loss, split, bs=None):
59 |         if split == "train":
60 |             for k, v in loss.items():
61 |                 bs = self.params.train_batch_size
62 |                 self.loss_status[k].update(val=v.item(), num=bs)
63 |         elif split == "val":
64 |             for k, v in loss.items():
65 |                 self.loss_status[k].update(val=v.item(), num=bs)
66 |         elif split == "test":
67 |             for k, v in loss.items():
68 |                 self.loss_status[k].update(val=v.item(), num=bs)
69 |         else:
70 |             raise ValueError("Wrong eval type: {}".format(split))
71 |
72 |     def update_metric_status(self, metrics, split, bs):
73 |         if split == "val":
74 |             for k, v in metrics.items():
75 |                 self.val_status[k].update(val=v.item(), num=bs)
76 |             self.cur_val_score = self.val_status[self.params.major_metric].avg
77 |         elif split == "test":
78 |             for k, v in metrics.items():
79 |                 self.test_status[k].update(val=v.item(), num=bs)
80 |             self.cur_test_score = self.test_status[self.params.major_metric].avg
81 |         else:
82 |             raise ValueError("Wrong eval type: {}".format(split))
83 |
84 |     def summarize_metric_status(self, metrics, split):
85 |         if split == "val":
86 |             for k in metrics:
87 |                 if k.endswith('MSE'):
88 |                     self.val_status[k[:-3] + 'RMSE'].set(val=np.sqrt(self.val_status[k].avg))
89 |                 else:
90 |                     continue
91 |         elif split == "test":
92 |             for k in metrics:
93 |                 if k.endswith('MSE'):
94 |                     self.test_status[k[:-3] + 'RMSE'].set(val=np.sqrt(self.test_status[k].avg))
95 |                 else:
96 |                     continue
97 |         else:
98 |             raise ValueError("Wrong eval type: {}".format(split))
99 |
100 |     def reset_loss_status(self):
101 |         for k, v in self.loss_status.items():
102 |             self.loss_status[k].reset()
103 |
104 |     def reset_metric_status(self, split):
105 |         if split == "val":
106 |             for k, v in self.val_status.items():
107 |                 self.val_status[k].reset()
108 |         elif split == "test":
109 |             for k, v in self.test_status.items():
110 |                 self.test_status[k].reset()
111 |         else:
112 |             raise ValueError("Wrong eval type: {}".format(split))
113 |
114 |     def print_train_info(self):
115 |         exp_name = os.path.basename(os.path.normpath(self.params.model_dir))  # also works with a trailing "/"
116 |         print_str = "{} Epoch: {:4d}, lr={:.1E} ".format(exp_name, self.epoch, self.scheduler.get_lr()[0])
117 |         print_str += "total loss: {:.3f}({:.3f})".format(self.loss_status['total'].val, self.loss_status['total'].avg)
118 |         return print_str
119 |
120 |     def print_metrics(self, split, title="Eval", color="red", only_best=True):
121 |         if split == "val":
122 |             metric_status = self.val_status
123 |             is_best = self.cur_val_score < self.best_val_score if self.metric_rule == "Descende" else self.cur_val_score > self.best_val_score
124 |         elif split == "test":
125 |             metric_status = self.test_status
126 |             is_best = self.cur_test_score < self.best_test_score if self.metric_rule == "Descende" else self.cur_test_score > self.best_test_score
127 |         else:
128 |             raise ValueError("Wrong split string: {}".format(split))
129 |         print_str = " | ".join("{}: {:.3f}".format(k, v.avg) for k, v in metric_status.items())
130 |         if only_best:
131 |             if is_best:
132 |                 utils.master_logger(self.logger,
133 |                                     colored("Best Epoch: {}, {} Results: {}".format(self.epoch, title, print_str), color, attrs=["bold"]),
134 |                                     dist.get_rank() == 0)
135 |         else:
136 |             utils.master_logger(self.logger, colored("Epoch: {}, {} Results: {}".format(self.epoch, title, print_str),
137 |                                                      color,
138 |                                                      attrs=["bold"]),
139 |                                 dist.get_rank() == 0)
140 |
141 |     def check_best_save_last_checkpoints(self, save_latest_freq=5, save_best_after=50):
142 |
143 |         state = {
144 |             "state_dict": self.model.state_dict(),
145 |             "optimizer": self.optimizer.state_dict(),
146 |             "scheduler": self.scheduler.state_dict(),
147 |             "step": self.step,
148 |             "epoch": self.epoch,
149 |         }
150 |         if "val" in self.dataloaders:
151 |             state["best_val_score"] = self.best_val_score
152 |         if "test" in self.dataloaders:
153 |             state["best_test_score"] = self.best_test_score
154 |
155 |         # save latest checkpoint
156 |         if self.epoch % save_latest_freq == 0:
157 |             latest_ckpt_name = os.path.join(self.params.model_dir, "model_latest.pth")
158 |             mge.save(state, latest_ckpt_name)
159 |             self.logger.info("Saved latest checkpoint to: {}".format(latest_ckpt_name))
160 |
161 |         # save the latest val metrics and check whether this is the best val checkpoint
162 |         if "val" in self.dataloaders:
163 |             val_latest_metrics_name = os.path.join(self.params.model_dir, "val_metrics_latest.json")
164 |             utils.save_dict_to_json(self.val_status, val_latest_metrics_name)
165 |
166 |             # lower is better
167 |             if self.metric_rule == "Descende":
168 |                 is_best = self.cur_val_score < self.best_val_score
169 |             # higher is better
170 |             elif self.metric_rule == "Ascende":
171 |                 is_best = self.cur_val_score > self.best_val_score
172 |             else:
173 |                 raise Exception("metric_rule should be either Descende or Ascende")
174 |
175 |             if is_best:
176 |                 # save metrics
177 |                 self.best_val_score = self.cur_val_score
178 |                 best_metrics_name = os.path.join(self.params.model_dir, "val_metrics_best.json")
179 |                 utils.save_dict_to_json(self.val_status, best_metrics_name)
180 |                 self.logger.info("Current is val best, score={:.3g}".format(self.best_val_score))
181 |                 # save checkpoint
182 |                 if self.epoch > save_best_after:
183 |                     best_ckpt_name = os.path.join(self.params.model_dir, "val_model_best.pth")
184 |                     mge.save(state, best_ckpt_name)
185 |                     self.logger.info("Saved val best checkpoint to: {}".format(best_ckpt_name))
186 |
187 |         # save the latest test metrics and check whether this is the best test checkpoint
188 |         if "test" in self.dataloaders:
189 |             test_latest_metrics_name = os.path.join(self.params.model_dir, "test_metrics_latest.json")
190 |             utils.save_dict_to_json(self.test_status, test_latest_metrics_name)
191 |             # lower is better
192 |             if self.metric_rule == "Descende":
193 |                 is_best = self.cur_test_score < self.best_test_score
194 |             # higher is better
195 |             elif self.metric_rule == "Ascende":
196 |                 is_best = self.cur_test_score > self.best_test_score
197 |             else:
198 |                 raise Exception("metric_rule should be either Descende or Ascende")
199 |             if is_best:
200 |                 # save metrics
201 |                 self.best_test_score = self.cur_test_score
202 |                 best_metrics_name = os.path.join(self.params.model_dir, "test_metrics_best.json")
203 |                 utils.save_dict_to_json(self.test_status, best_metrics_name)
204 |                 self.logger.info("Current is test best, score={:.3g}".format(self.best_test_score))
205 |                 # save checkpoint
206 |                 if self.epoch > save_best_after:
207 |                     best_ckpt_name = os.path.join(self.params.model_dir, "test_model_best.pth")
208 |                     mge.save(state, best_ckpt_name)
209 |                     self.logger.info("Saved test best checkpoint to: {}".format(best_ckpt_name))
210 |
211 |     def load_checkpoints(self):
212 |         state = mge.load(self.params.restore_file)
213 |         ckpt_component = []
214 |         if "state_dict" in state and self.model is not None:
215 |             try:
216 |                 self.model.load_state_dict(state["state_dict"])
217 |
218 |             except Exception:  # strict loading failed: fall back to custom loading of the net
219 |                 net_dict = self.model.state_dict()
220 |                 if "module" not in list(state["state_dict"].keys())[0]:
221 |                     state_dict = {"module." + k: v for k, v in state["state_dict"].items() if "module." + k in net_dict.keys()}
222 |                 else:
223 |                     state_dict = {k: v for k, v in state["state_dict"].items() if k in net_dict.keys()}
224 |                 net_dict.update(state_dict)
225 |                 self.model.load_state_dict(net_dict, strict=False)
226 |             ckpt_component.append("net")
227 |
228 |         if not self.params.only_weights:
229 |             if "optimizer" in state and self.optimizer is not None:
230 |                 try:
231 |                     self.optimizer.load_state_dict(state["optimizer"])
232 |
233 |                 except Exception:  # fall back to loading only the matching optimizer entries
234 |                     optimizer_dict = self.optimizer.state_dict()
235 |                     state_dict = {k: v for k, v in state["optimizer"].items() if k in optimizer_dict.keys()}
236 |                     optimizer_dict.update(state_dict)
237 |                     self.optimizer.load_state_dict(optimizer_dict)
238 |                 ckpt_component.append("opt")
239 |
240 |             if "scheduler" in state and self.scheduler is not None:
241 |                 try:
242 |                     self.scheduler.load_state_dict(state["scheduler"])
243 |
244 |                 except Exception:  # fall back to loading only the matching scheduler entries
245 |                     scheduler_dict = self.scheduler.state_dict()
246 |                     state_dict = {k: v for k, v in state["scheduler"].items() if k in scheduler_dict.keys()}
247 |                     scheduler_dict.update(state_dict)
248 |                     self.scheduler.load_state_dict(scheduler_dict)
249 |                 ckpt_component.append("sch")
250 |
251 |             if "step" in state:
252 |                 self.step = state["step"] + 1
253 |                 self.train_status["step"] = state["step"] + 1
254 |                 ckpt_component.append("step: {}".format(self.train_status["step"]))
255 |
256 |             if "epoch" in state:
257 |                 self.epoch = state["epoch"] + 1
258 |                 self.train_status["epoch"] = state["epoch"] + 1
259 |                 ckpt_component.append("epoch: {}".format(self.train_status["epoch"]))
260 |
261 |             if "best_val_score" in state:
262 |                 self.best_val_score = state["best_val_score"]
263 |                 ckpt_component.append("best val score: {:.3g}".format(self.best_val_score))
264 |
265 |             if "best_test_score" in state:
266 |                 self.best_test_score = state["best_test_score"]
267 |                 ckpt_component.append("best test score: {:.3g}".format(self.best_test_score))
268 |
269 |         ckpt_component = ", ".join(i for i in ckpt_component)
270 |         utils.master_logger(self.logger, "Loaded models from: {}".format(self.params.restore_file), dist.get_rank() == 0)
271 |         utils.master_logger(self.logger, "Ckpt load: {}".format(ckpt_component), dist.get_rank() == 0)
272 |
--------------------------------------------------------------------------------
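`Manager` relies on `utils.AverageMeter`, which is not part of this listing. Judging only from how it is called (`update(val=..., num=...)`, `set(val=...)`, `reset()`, `.val`, `.avg`), a minimal stand-in might look like the sketch below. This is an assumption, not the repository's implementation. The usage lines show why `summarize_metric_status` takes the square root of the batch-size-weighted mean MSE rather than averaging per-batch RMSE values. The hypothetical `is_improvement` helper mirrors the `metric_rule` comparison in `check_best_save_last_checkpoints`.

```python
import math


class AverageMeter:
    """Hypothetical stand-in for utils.AverageMeter, inferred from Manager's calls."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0  # most recent value
        self.sum = 0.0  # running sum of val * num
        self.count = 0  # running number of samples
        self.avg = 0.0  # sample-weighted mean

    def update(self, val, num=1):
        # val is a per-batch mean; weight it by the batch size num
        self.val = val
        self.sum += val * num
        self.count += num
        self.avg = self.sum / self.count

    def set(self, val):
        # Overwrite with a derived statistic, e.g. an RMSE computed from an averaged MSE
        self.val = val
        self.avg = val


def is_improvement(cur, best, metric_rule):
    """Same comparison as the metric_rule branches in check_best_save_last_checkpoints."""
    if metric_rule == "Descende":  # lower is better
        return cur < best
    if metric_rule == "Ascende":  # higher is better
        return cur > best
    raise ValueError("metric_rule should be either Descende or Ascende")


# Per-batch MSE of 4.0 over 2 samples, then 1.0 over 6 samples
mse = AverageMeter()
mse.update(val=4.0, num=2)
mse.update(val=1.0, num=6)
rmse = AverageMeter()
rmse.set(val=math.sqrt(mse.avg))
# Averaging per-batch RMSEs instead would give (2.0 * 2 + 1.0 * 6) / 8 = 1.25
print(mse.avg, round(rmse.avg, 4))  # 1.75 1.3229
```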
/common/quaternion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import megengine.functional as F
4 |
5 |
6 | def mge_qmul(q1, q2):
7 |     """
8 |     Multiply quaternion(s) as q2*q1 (Hamilton product): the result rotates by q1 first, then by q2.
9 |     Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
10 |     Returns q2*q1 as a tensor of shape (*, 4).
11 |     """
12 |     assert q1.shape[-1] == 4
13 |     assert q2.shape[-1] == 4
14 |
15 |     original_shape = q1.shape
16 |
17 |     # Compute outer product
18 |     terms = F.matmul(q1.reshape(-1, 4, 1), q2.reshape(-1, 1, 4))
19 |
20 |     w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
21 |     x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
22 |     y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
23 |     z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
24 |     return F.stack((w, x, y, z), axis=1).reshape(original_shape)
25 |
26 |
27 | def mge_qrot(q, v):
28 |     """
29 |     Rotate vector(s) v about the rotation described by quaternion(s) q.
30 |     Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
31 |     where * denotes any number of dimensions.
32 |     Returns a tensor of shape (*, 3).
33 |     """
34 |     assert q.shape[-1] == 4
35 |     assert v.shape[-1] == 3
36 |     assert q.shape[:-1] == v.shape[:-1]
37 |
38 |     original_shape = list(v.shape)
39 |     q = q.reshape(-1, 4)
40 |     v = v.reshape(-1, 3)
41 |
42 |     qvec = q[:, 1:]
43 |     uv = F.stack((
44 |         qvec[:, 1] * v[:, 2] - qvec[:, 2] * v[:, 1],
45 |         qvec[:, 2] * v[:, 0] - qvec[:, 0] * v[:, 2],
46 |         qvec[:, 0] * v[:, 1] - qvec[:, 1] * v[:, 0],
47 |     ),
48 |         axis=1)
49 |     uuv = F.stack((
50 |         qvec[:, 1] * uv[:, 2] - qvec[:, 2] * uv[:, 1],
51 |         qvec[:, 2] * uv[:, 0] - qvec[:, 0] * uv[:, 2],
52 |         qvec[:, 0] * uv[:, 1] - qvec[:, 1] * uv[:, 0],
53 |     ),
54 |         axis=1)
55 |     # uv = F.cross(qvec, v, dim=1)
56 |     # uuv = F.cross(qvec, uv, dim=1)
57 |     return (v + 2 * (q[:, :1] * uv + uuv)).reshape(original_shape)
58 |
59 |
60 | # TODO: check
61 | def mge_quat2euler(q, order, epsilon=0):
62 |     """
63 |     Convert quaternion(s) q to Euler angles.
64 |     Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
65 |     Returns a tensor of shape (*, 3).
66 |     """
67 |     assert q.shape[-1] == 4
68 |
69 |     original_shape = list(q.shape)
70 |     original_shape[-1] = 3
71 |     q = q.reshape(-1, 4)
72 |
73 |     q0 = q[:, 0]
74 |     q1 = q[:, 1]
75 |     q2 = q[:, 2]
76 |     q3 = q[:, 3]
77 |
78 |     if order == "xyz":
79 |         x = F.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
80 |         y = F.asin(F.clip(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
81 |         z = F.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
82 |     elif order == "yzx":
83 |         x = F.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
84 |         y = F.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
85 |         z = F.asin(F.clip(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
86 |     elif order == "zxy":
87 |         x = F.asin(F.clip(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
88 |         y = F.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
89 |         z = F.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
90 |     elif order == "xzy":
91 |         x = F.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
92 |         y = F.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
93 |         z = F.asin(F.clip(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
94 |     elif order == "yxz":
95 |         x = F.asin(F.clip(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
96 |         y = F.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
97 |         z = F.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
98 |     elif order == "zyx":
99 |         x = F.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
100 |         y = F.asin(F.clip(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
101 |         z = F.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
102 |     else:
103 |         raise ValueError("Unsupported Euler angle order: {}".format(order))
104 |
105 |     return F.stack((x, y, z), axis=1).reshape(original_shape)
106 |
107 |
108 | # TODO: check
109 | def mge_euler2quat(e, order):
110 |     """
111 |     Convert Euler angles to quaternions.
112 |     """
113 |     assert e.shape[-1] == 3
114 |
115 |     original_shape = [e.shape[0], 4]
116 |
117 |     x = e[:, 0]
118 |     y = e[:, 1]
119 |     z = e[:, 2]
120 |
121 |     rx = F.stack((F.cos(x / 2), F.sin(x / 2), F.zeros_like(x), F.zeros_like(x)), axis=1)
122 |     ry = F.stack((F.cos(y / 2), F.zeros_like(y), F.sin(y / 2), F.zeros_like(y)), axis=1)
123 |     rz = F.stack((F.cos(z / 2), F.zeros_like(z), F.zeros_like(z), F.sin(z / 2)), axis=1)
124 |
125 |     result = None
126 |     for coord in order:
127 |         if coord == "x":
128 |             r = rx
129 |         elif coord == "y":
130 |             r = ry
131 |         elif coord == "z":
132 |             r = rz
133 |         else:
134 |             raise ValueError("Unsupported axis in Euler angle order: {}".format(coord))
135 |         if result is None:
136 |             result = r
137 |         else:
138 |             result = mge_qmul(result, r)
139 |
140 |     # Reverse antipodal representation to have a non-negative "w"
141 |     if order in ["xyz", "yzx", "zxy"]:
142 |         result *= -1
143 |
144 |     return result.reshape(original_shape)
145 |
146 |
147 | def mge_quat2mat(pose):
148 |     # Separate each quaternion value.
149 |     q0, q1, q2, q3 = pose[:, 0], pose[:, 1], pose[:, 2], pose[:, 3]
150 |     # Convert quaternion to rotation matrix.
151 |     # Ref: http://www-evasion.inrialpes.fr/people/Franck.Hetroy/Teaching/ProjetsImage/2007/Bib/besl_mckay-pami1992.pdf
152 |     # "A Method for Registration of 3-D Shapes" by Paul J. Besl and Neil D. McKay.
153 |     R11 = q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3
154 |     R12 = 2 * (q1 * q2 - q0 * q3)
155 |     R13 = 2 * (q1 * q3 + q0 * q2)
156 |     R21 = 2 * (q1 * q2 + q0 * q3)
157 |     R22 = q0 * q0 + q2 * q2 - q1 * q1 - q3 * q3
158 |     R23 = 2 * (q2 * q3 - q0 * q1)
159 |     R31 = 2 * (q1 * q3 - q0 * q2)
160 |     R32 = 2 * (q2 * q3 + q0 * q1)
161 |     R33 = q0 * q0 + q3 * q3 - q1 * q1 - q2 * q2
162 |     R = F.stack((F.stack((R11, R12, R13), axis=0), F.stack((R21, R22, R23), axis=0), F.stack((R31, R32, R33), axis=0)), axis=0)
163 |
164 |     rot_mat = F.transpose(R, (2, 0, 1))  # (B, 3, 3)
165 |     translation = F.expand_dims(pose[:, 4:], axis=-1)  # (B, 3, 1)
166 |     transform = F.concat((rot_mat, translation), axis=2)
167 |     return transform  # (B, 3, 4)
168 |
169 |
170 | def mge_transform_pose(pose_old, pose_new):
171 |     quat_old, translate_old = pose_old[:, :4], pose_old[:, 4:]
172 |     quat_new, translate_new = pose_new[:, :4], pose_new[:, 4:]
173 |
174 |     quat = mge_qmul(quat_old, quat_new)
175 |     translate = mge_qrot(quat_new, translate_old) + translate_new
176 |     pose = F.concat((quat, translate), axis=1)
177 |
178 |     return pose
179 |
180 |
181 | # TODO: check
182 | def mge_qinv(q):
183 |     # expects q in (w, x, y, z) format; the conjugate equals the inverse only for unit quaternions
184 |     w = q[:, 0:1]
185 |     v = q[:, 1:]
186 |     inv = F.concat([w, -v], axis=1)
187 |     return inv
188 |
189 |
190 | def mge_quat_rotate(point_cloud, pose_7d):
191 |     ndim = point_cloud.ndim
192 |     if ndim == 2:
193 |         N, _ = point_cloud.shape
194 |         assert pose_7d.shape[0] == 1
195 |         # repeat the quaternion for each point in the cloud: (1, 4) -> (N, 4)
196 |         quat = F.broadcast_to(pose_7d[:, 0:4], (N, 4))
197 |         rotated_point_cloud = mge_qrot(quat, point_cloud)
198 |
199 |     elif ndim == 3:
200 |         B, N, _ = point_cloud.shape
201 |         quat = F.tile(F.expand_dims(pose_7d[:, 0:4], axis=1), (1, N, 1))
202 |         rotated_point_cloud = mge_qrot(quat, point_cloud)
203 |
204 |     else:
205 |         raise RuntimeError("point cloud dim must be 2 or 3 !")
206 |
207 |     return rotated_point_cloud
208 |
209 |
210 | def mge_quat_transform(pose_7d, pc, normal=None):
211 |     pc_t = mge_quat_rotate(pc, pose_7d) + pose_7d[:, 4:].reshape(-1, 1, 3)  # Ps" = R*Ps + t
212 |     if normal is not None:
213 |         normal_t = mge_quat_rotate(normal, pose_7d)
214 |         return pc_t, normal_t
215 |     else:
216 |         return pc_t
217 |
218 |
219 | def np_qmul(q, r):
220 |     q = torch.from_numpy(q).contiguous()
221 |     r = torch.from_numpy(r).contiguous()
222 |     return torch_qmul(q, r).numpy()
223 |
224 |
225 | def np_qrot(q, v):
226 |     q = torch.from_numpy(q).contiguous()
227 |     v = torch.from_numpy(v).contiguous()
228 |     return torch_qrot(q, v).numpy()
229 |
230 |
231 | def np_quat2euler(q, order, epsilon=0, use_gpu=False):
232 |     if use_gpu:
233 |         q = torch.from_numpy(q).cuda()
234 |         return torch_quat2euler(q, order, epsilon).cpu().numpy()
235 |     else:
236 |         q = torch.from_numpy(q).contiguous()
237 |         return torch_quat2euler(q, order, epsilon).numpy()
238 |
239 |
240 | def np_qfix(q):
241 |     """
242 |     Enforce quaternion continuity across the time dimension by selecting
243 |     the representation (q or -q) with minimal Euclidean distance (or, equivalently, maximal dot product)
244 |     between two consecutive frames.
245 |     Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
246 |     Returns a tensor of the same shape.
247 |     """
248 |     assert len(q.shape) == 3
249 |     assert q.shape[-1] == 4
250 |
251 |     result = q.copy()
252 |     dot_products = np.sum(q[1:] * q[:-1], axis=2)
253 |     mask = dot_products < 0
254 |     mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
255 |     result[1:][mask] *= -1
256 |     return result
257 |
258 |
259 | def np_expmap2quat(e):
260 |     """
261 |     Convert axis-angle rotations (aka exponential maps) to quaternions.
262 |     Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
263 |     Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
264 |     Returns a tensor of shape (*, 4).
265 |     """
266 |     assert e.shape[-1] == 3
267 |
268 |     original_shape = list(e.shape)
269 |     original_shape[-1] = 4
270 |     e = e.reshape(-1, 3)
271 |
272 |     theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
273 |     w = np.cos(0.5 * theta).reshape(-1, 1)
274 |     xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
275 |     return np.concatenate((w, xyz), axis=1).reshape(original_shape)
276 |
277 |
278 | def np_euler2quat(e, order):
279 |     """
280 |     Convert Euler angles to quaternions.
281 |     """
282 |     assert e.shape[-1] == 3
283 |
284 |     original_shape = list(e.shape)
285 |     original_shape[-1] = 4
286 |
287 |     e = e.reshape(-1, 3)
288 |
289 |     x = e[:, 0]
290 |     y = e[:, 1]
291 |     z = e[:, 2]
292 |
293 |     rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
294 |     ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
295 |     rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
296 |
297 |     result = None
298 |     for coord in order:
299 |         if coord == "x":
300 |             r = rx
301 |         elif coord == "y":
302 |             r = ry
303 |         elif coord == "z":
304 |             r = rz
305 |         else:
306 |             raise ValueError("Unsupported axis in Euler angle order: {}".format(coord))
307 |         if result is None:
308 |             result = r
309 |         else:
310 |             result = np_qmul(result, r)
311 |
312 |     # Reverse antipodal representation to have a non-negative "w"
313 |     if order in ["xyz", "yzx", "zxy"]:
314 |         result *= -1
315 |
316 |     return result.reshape(original_shape)
317 |
--------------------------------------------------------------------------------
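The NumPy helpers `np_qmul` and `np_qrot` above call `torch_qmul` and `torch_qrot`, which are neither defined nor imported in this listing. The sketch below shows NumPy-only equivalents of the same operations, with a consistency check. It is written to follow the conventions of `mge_qmul` (the result is q2*q1, rotating by q1 first), `mge_qrot`, and the rotation block of `mge_quat2mat`, with quaternions in (w, x, y, z) order. The `*_np` names are hypothetical.

```python
import numpy as np


def qmul_np(q1, q2):
    """Hamilton product q2 * q1 for (..., 4) arrays in (w, x, y, z) order."""
    w1, x1, y1, z1 = np.moveaxis(q1, -1, 0)
    w2, x2, y2, z2 = np.moveaxis(q2, -1, 0)
    return np.stack((
        w2 * w1 - x2 * x1 - y2 * y1 - z2 * z1,
        w2 * x1 + x2 * w1 + y2 * z1 - z2 * y1,
        w2 * y1 - x2 * z1 + y2 * w1 + z2 * x1,
        w2 * z1 + x2 * y1 - y2 * x1 + z2 * w1,
    ), axis=-1)


def qrot_np(q, v):
    """Rotate (..., 3) vectors by unit quaternions: v + 2 * (w * (u x v) + u x (u x v))."""
    u = q[..., 1:]
    uv = np.cross(u, v)
    uuv = np.cross(u, uv)
    return v + 2 * (q[..., :1] * uv + uuv)


def quat2mat_np(q):
    """3x3 rotation matrix of one unit quaternion, same entries as mge_quat2mat."""
    w, x, y, z = q
    return np.array([
        [w * w + x * x - y * y - z * z, 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), w * w + y * y - x * x - z * z, 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), w * w + z * z - x * x - y * y],
    ])


rng = np.random.default_rng(0)
q1, q2 = rng.standard_normal((2, 4))
q1, q2 = q1 / np.linalg.norm(q1), q2 / np.linalg.norm(q2)  # unit quaternions
v = rng.standard_normal(3)

# Composition: rotating by q2*q1 equals rotating by q1, then by q2
print(np.allclose(qrot_np(qmul_np(q1, q2), v), qrot_np(q2, qrot_np(q1, v))))  # True
# The quaternion rotation agrees with the matrix form
print(np.allclose(qrot_np(q1, v), quat2mat_np(q1) @ v))  # True
```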
/common/se3.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import megengine.functional as F
3 | import transforms3d.quaternions as t3d
4 | from scipy.spatial.transform import Rotation
5 |
6 |
7 | def mge_inverse(g):
8 |     """ Returns the inverse of the SE3 transform
9 |
10 |     Args:
11 |         g: (B, 3/4, 4) transform
12 |
13 |     Returns:
14 |         (B, 3, 4) matrix containing the inverse
15 |
16 |     """
17 |     # Compute inverse
18 |     rot = g[..., 0:3, 0:3]
19 |     trans = g[..., 0:3, 3]
20 |     inverse_transform = F.concat([rot.transpose(0, 2, 1), F.matmul(rot.transpose(0, 2, 1), F.expand_dims(-trans, axis=-1))], axis=-1)
21 |
22 |     return inverse_transform
23 |
24 |
25 | def mge_concatenate(a, b):
26 |     """Concatenate two SE3 transforms,
27 |     i.e. return a@b (but note that our SE3 is represented as a 3x4 matrix)
28 |
29 |     Args:
30 |         a: (B, 3/4, 4)
31 |         b: (B, 3/4, 4)
32 |
33 |     Returns:
34 |         (B, 3, 4)
35 |     """
36 |
37 |     rot1 = a[..., :3, :3]
38 |     trans1 = a[..., :3, 3]
39 |     rot2 = b[..., :3, :3]
40 |     trans2 = b[..., :3, 3]
41 |
42 |     rot_cat = F.matmul(rot1, rot2)
43 |     trans_cat = F.matmul(rot1, F.expand_dims(trans2, axis=-1)) + F.expand_dims(trans1, axis=-1)
44 |     concatenated = F.concat([rot_cat, trans_cat], axis=-1)
45 |
46 |     return concatenated
47 |
48 |
49 | def mge_transform(g, a, normals=None):
50 | """ Applies the SE3 transform
51 |
52 | Args:
53 | g: SE3 transformation matrix of size (B, 3/4, 4)
54 | a: Points to be transformed (B, N, 3); inputs of a different rank raise NotImplementedError
55 | normals: (Optional) normals of size (B, N, 3). If provided, they are rotated (not translated)
56 |
57 | Returns:
58 | transformed points of size (B, N, 3), plus the rotated normals if normals is given
59 |
60 | """
61 | R = g[..., :3, :3] # (B, 3, 3)
62 | p = g[..., :3, 3] # (B, 3)
63 |
64 | if len(g.shape) == len(a.shape):
65 | b = F.matmul(a, R.transpose(0, 2, 1)) + F.expand_dims(p, axis=1)
66 | else:
67 | raise NotImplementedError
68 |
69 | if normals is not None:
70 | rotated_normals = F.matmul(normals, R.transpose(0, 2, 1))
71 | return b, rotated_normals
72 |
73 | else:
74 | return b
75 |
76 |
77 | def np_identity():
78 | return np.eye(3, 4)
79 |
80 |
81 | def np_transform(g: np.ndarray, pts: np.ndarray):
82 | """ Applies the SE3 transform
83 |
84 | Args:
85 | g: SE3 transformation matrix of size ([B,] 3/4, 4)
86 | pts: Points to be transformed ([B,] N, 3)
87 |
88 | Returns:
89 | transformed points of size ([B,] N, 3)
90 | """
91 | rot = g[..., :3, :3] # ([B,] 3, 3)
92 | trans = g[..., :3, 3] # ([B,] 3)
93 |
94 | transformed = pts[..., :3] @ np.swapaxes(rot, -1, -2) + trans[..., None, :]
95 | return transformed
96 |
97 |
98 | def np_inverse(g: np.ndarray):
99 | """Returns the inverse of the SE3 transform
100 |
101 | Args:
102 | g: ([B,] 3/4, 4) transform
103 |
104 | Returns:
105 | ([B,] 3/4, 4) matrix containing the inverse
106 |
107 | """
108 | rot = g[..., :3, :3] # (3, 3)
109 | trans = g[..., :3, 3] # (3)
110 |
111 | inv_rot = np.swapaxes(rot, -1, -2)
112 | inverse_transform = np.concatenate([inv_rot, inv_rot @ -trans[..., None]], axis=-1)
113 | if g.shape[-2] == 4:
114 | inverse_transform = np.concatenate([inverse_transform, np.broadcast_to([0.0, 0.0, 0.0, 1.0], inverse_transform.shape[:-2] + (1, 4))], axis=-2)
115 |
116 | return inverse_transform
117 |
118 |
119 | def np_concatenate(a: np.ndarray, b: np.ndarray):
120 | """ Concatenate two SE3 transforms
121 |
122 | Args:
123 | a: First transform ([B,] 3/4, 4)
124 | b: Second transform ([B,] 3/4, 4)
125 |
126 | Returns:
127 | a @ b ([B,] 3/4, 4)
128 |
129 | """
130 |
131 | r_a, t_a = a[..., :3, :3], a[..., :3, 3]
132 | r_b, t_b = b[..., :3, :3], b[..., :3, 3]
133 |
134 | r_ab = r_a @ r_b
135 | t_ab = r_a @ t_b[..., None] + t_a[..., None]
136 |
137 | concatenated = np.concatenate([r_ab, t_ab], axis=-1)
138 |
139 | if a.shape[-2] == 4:
140 | concatenated = np.concatenate([concatenated, np.broadcast_to([0.0, 0.0, 0.0, 1.0], concatenated.shape[:-2] + (1, 4))], axis=-2)
141 |
142 | return concatenated
143 |
144 |
145 | def np_from_xyzquat(xyzquat):
146 | """Constructs SE3 matrix from x, y, z, qx, qy, qz, qw
147 |
148 | Args:
149 | xyzquat: np.array (7,) containing translation and quaternion
150 |
151 | Returns:
152 | SE3 matrix (4, 4)
153 | """
154 | rot = Rotation.from_quat(xyzquat[3:])
155 | trans = rot.apply(-xyzquat[:3])
156 | transform = np.concatenate([rot.as_matrix(), trans[:, None]], axis=1)
157 | transform = np.concatenate([transform, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
158 |
159 | return transform
160 |
161 |
162 | def np_mat2quat(transform):
163 | rotate = transform[:3, :3]
164 | translate = transform[:3, 3]
165 | quat = t3d.mat2quat(rotate)
166 | pose = np.concatenate([quat, translate], axis=0)
167 | return pose # (7, ): [qw, qx, qy, qz, tx, ty, tz]
168 |
169 |
170 | def np_quat2mat(pose):
171 | # pose: (B, 7) as [qw, qx, qy, qz, tx, ty, tz]; split out the quaternion components.
172 | q0, q1, q2, q3 = pose[:, 0], pose[:, 1], pose[:, 2], pose[:, 3]
173 | # Convert quaternion to rotation matrix.
174 | # Ref: http://www-evasion.inrialpes.fr/people/Franck.Hetroy/Teaching/ProjetsImage/2007/Bib/besl_mckay-pami1992.pdf
175 | # "A Method for Registration of 3-D Shapes" by Paul J. Besl and Neil D. McKay.
176 | R11 = q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3
177 | R12 = 2 * (q1 * q2 - q0 * q3)
178 | R13 = 2 * (q1 * q3 + q0 * q2)
179 | R21 = 2 * (q1 * q2 + q0 * q3)
180 | R22 = q0 * q0 + q2 * q2 - q1 * q1 - q3 * q3
181 | R23 = 2 * (q2 * q3 - q0 * q1)
182 | R31 = 2 * (q1 * q3 - q0 * q2)
183 | R32 = 2 * (q2 * q3 + q0 * q1)
184 | R33 = q0 * q0 + q3 * q3 - q1 * q1 - q2 * q2
185 | R = np.stack((np.stack((R11, R12, R13), axis=0), np.stack((R21, R22, R23), axis=0), np.stack((R31, R32, R33), axis=0)), axis=0)
186 |
187 | rot_mat = R.transpose((2, 0, 1)) # (B, 3, 3)
188 | translation = pose[:, 4:][:, :, None] # (B, 3, 1)
189 | transform = np.concatenate((rot_mat, translation), axis=2)
190 | return transform # (B, 3, 4)
191 |
--------------------------------------------------------------------------------
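The NumPy helpers above (`np_transform`, `np_inverse`, `np_concatenate`) rest on two identities for a 3x4 transform [R | t]: the inverse is [R^T | -R^T t], and a @ b applies b first, then a. Below is a minimal NumPy-only sketch of that algebra. It uses no MegEngine, SciPy or transforms3d, and it is meant as a sanity check rather than a replacement. The names `rotation_from_axis_angle`, `se3_transform`, `se3_inverse` and `se3_compose` are hypothetical, and unlike the repo's helpers the sketch always returns the 3x4 form.

```python
import numpy as np


def rotation_from_axis_angle(axis, angle):
    """Rodrigues' formula: rotation by `angle` radians about the 3-vector `axis`."""
    k = np.asarray(axis, dtype=np.float64)
    k = k / np.linalg.norm(k)
    K = np.array([[0.0, -k[2], k[1]], [k[2], 0.0, -k[0]], [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * (K @ K)


def se3_transform(g, pts):
    """p' = R p + t for ([B,] 3/4, 4) transforms and ([B,] N, 3) points."""
    rot, trans = g[..., :3, :3], g[..., :3, 3]
    return pts @ np.swapaxes(rot, -1, -2) + trans[..., None, :]


def se3_inverse(g):
    """[R | t]^-1 = [R^T | -R^T t], returned as ([B,] 3, 4)."""
    rot_t = np.swapaxes(g[..., :3, :3], -1, -2)
    trans = g[..., :3, 3]
    return np.concatenate([rot_t, rot_t @ -trans[..., None]], axis=-1)


def se3_compose(a, b):
    """a @ b as ([B,] 3, 4): applying the result equals applying b first, then a."""
    r_a, t_a = a[..., :3, :3], a[..., :3, 3]
    r_b, t_b = b[..., :3, :3], b[..., :3, 3]
    return np.concatenate([r_a @ r_b, r_a @ t_b[..., None] + t_a[..., None]], axis=-1)
```

A quick check of these identities is that composing a transform with its inverse gives `np.eye(3, 4)`, and that transforming and then inverse-transforming a batch of points returns the original points.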
/common/so3.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import megengine as mge
3 | from scipy.spatial.transform import Rotation
4 |
5 |
6 | def np_dcm2euler(mats: np.ndarray, seq: str = "zyx", degrees: bool = True):
7 | """Converts rotation matrix to euler angles
8 |
9 | Args:
10 | mats: (B, 3, 3) containing the B rotation matrices
11 | seq: Sequence of euler rotations (default: "zyx")
12 | degrees (bool): If true (default), will return in degrees instead of radians
13 |
14 | Returns:
15 | (B, 3) euler angles, one column per axis in the order given by seq
16 | """
17 |
18 | eulers = []
19 | for i in range(mats.shape[0]):
20 | r = Rotation.from_matrix(mats[i])
21 | eulers.append(r.as_euler(seq, degrees=degrees))
22 | return np.stack(eulers)
23 |
24 |
25 | def np_transform(g: np.ndarray, pts: np.ndarray):
26 | """ Applies the SO3 transform
27 |
28 | Args:
29 | g: SO3 transformation matrix of size (B, 3, 3)
30 | pts: Points to be transformed (B, N, 3)
31 |
32 | Returns:
33 | transformed points of size (B, N, 3)
34 |
35 | """
36 | rot = g[..., :3, :3] # (3, 3)
37 | transformed = pts[..., :3] @ np.swapaxes(rot, -1, -2)
38 | return transformed
39 |
40 |
41 | def np_inverse(g: np.ndarray):
42 | """Returns the inverse of the SO3 transform
43 |
44 | Args:
45 | g: ([B,] 3, 3) rotation matrix (only the top-left 3x3 block is used)
46 |
47 | Returns:
48 | ([B,] 3, 3) inverse rotation, i.e. the transpose of g
49 |
50 | """
51 | rot = g[..., :3, :3] # (3, 3)
52 |
53 | inv_rot = np.swapaxes(rot, -1, -2)
54 |
55 | return inv_rot
56 |
57 |
58 | def mge_dcm2euler(mats, seq, degrees=True):
59 | mats = mats.numpy()
60 | eulers = []
61 | for i in range(mats.shape[0]):
62 | r = Rotation.from_matrix(mats[i])
63 | eulers.append(r.as_euler(seq, degrees=degrees))
64 | return mge.tensor(np.stack(eulers))
65 |
--------------------------------------------------------------------------------
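`np_dcm2euler` delegates to SciPy, whose lowercase sequences are extrinsic. For R = Rx(ax) @ Ry(ay) @ Rz(az), the product built by the Euler transforms in dataset/transformations.py, `as_euler("zyx")` is therefore expected to return [az, ay, ax]. Below is a NumPy-only sketch of that decomposition so the convention can be checked without SciPy. The function name `euler_zyx_from_matrix` is hypothetical. Agreement with SciPy is an expectation to verify, for example against `Rotation.from_matrix(R).as_euler("zyx")`, and it only holds when the y angle lies strictly within (-90, 90) degrees.

```python
import numpy as np


def euler_zyx_from_matrix(mats, degrees=True):
    """Angles (about z, about y, about x) such that R = Rx(ax) @ Ry(ay) @ Rz(az).

    Assumes cos(ay) > 0, i.e. ay strictly within (-90, 90) degrees; at ay = +-90
    degrees (gimbal lock) the z and x angles are not uniquely determined.
    """
    mats = np.asarray(mats, dtype=np.float64)
    # R[0, :] = [cy*cz, -cy*sz, sy] and R[:, 2] = [sy, -sx*cy, cx*cy]
    az = np.arctan2(-mats[..., 0, 1], mats[..., 0, 0])
    ay = np.arcsin(np.clip(mats[..., 0, 2], -1.0, 1.0))
    ax = np.arctan2(-mats[..., 1, 2], mats[..., 2, 2])
    angles = np.stack((az, ay, ax), axis=-1)
    return np.degrees(angles) if degrees else angles
```

When rot_mag exceeds 90 degrees, the sampled y angle can leave (-90, 90). The decomposition then returns a different angle triple that describes the same rotation, so per-axis Euler errors should be read with that in mind.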
/common/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import megengine as mge
4 | import coloredlogs
5 |
6 |
7 | class Params():
8 | """Class that loads hyperparameters from a json file.
9 |
10 | Example:
11 | ```
12 | params = Params(json_path)
13 | print(params.learning_rate)
14 | params.learning_rate = 0.5 # change the value of learning_rate in params
15 | ```
16 | """
17 | def __init__(self, json_path):
18 | with open(json_path) as f:
19 | params = json.load(f)
20 | self.update(params)
21 |
22 | def save(self, json_path):
23 | with open(json_path, 'w') as f:
24 | json.dump(self.__dict__, f, indent=4)
25 |
26 | def update(self, dict):
27 | """Updates the parameters from a dict"""
28 | self.__dict__.update(dict)
29 |
30 | @property
31 | def dict(self):
32 | """Gives dict-like access to Params instance by `params.dict['learning_rate']`"""
33 | return self.__dict__
34 |
35 |
36 | class RunningAverage():
37 | """A simple class that maintains the running average of a quantity
38 |
39 | Example:
40 | ```
41 | loss_avg = RunningAverage()
42 | loss_avg.update(2)
43 | loss_avg.update(4)
44 | loss_avg() # returns 3.0
45 | ```
46 | """
47 | def __init__(self):
48 | self.steps = 0
49 | self.total = 0
50 |
51 | def update(self, val):
52 | self.total += val
53 | self.steps += 1
54 |
55 | def __call__(self):
56 | return self.total / float(self.steps)
57 |
58 |
59 | class AverageMeter():
60 | def __init__(self):
61 | self.reset()
62 |
63 | def reset(self):
64 | self.val = 0
65 | self.val_previous = 0
66 | self.avg = 0
67 | self.sum = 0
68 | self.count = 0
69 |
70 | def set(self, val):
71 | self.val = val
72 | self.avg = val
73 |
74 | def update(self, val, num):
75 | self.val_previous = self.val
76 | self.val = val
77 | self.sum += val * num
78 | self.count += num
79 | self.avg = self.sum / self.count
80 |
81 |
82 | def loss_meter_manager_intial(loss_meter_names):
83 | # Initialize the loss meters needed, one AverageMeter per meter name (same order)
84 | loss_meters = []
85 | for _ in loss_meter_names:
86 | # Each name gets its own independent AverageMeter instance
87 | loss_meters.append(AverageMeter())
88 |
89 | return loss_meters
90 |
91 |
92 | def tensor_mge(batch, check_on=True):
93 | if check_on:
94 | for k, v in batch.items():
95 | batch[k] = mge.Tensor(v)
96 | else:
97 | for k, v in batch.items():
98 | batch[k] = v.numpy()
99 | return batch
100 |
101 |
102 | def set_logger(log_path):
103 | """Set the logger to log info in terminal and file `log_path`.
104 |
105 | In general, it is useful to have a logger so that every output to the terminal is saved
106 | in a permanent file. Here it is saved to `log_path` (e.g. `model_dir/train.log`).
107 |
108 | Example:
109 | ```
110 | logging.info("Starting training...")
111 | ```
112 |
113 | Args:
114 | log_path: (string) where to log
115 | """
116 | logger = logging.getLogger()
117 | logger.setLevel(logging.INFO)
118 |
119 | # if not logger.handlers:
120 | # # Logging to a file
121 | # file_handler = logging.FileHandler(log_path)
122 | # file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
123 | # logger.addHandler(file_handler)
124 | #
125 | # # Logging to console
126 | # stream_handler = logging.StreamHandler()
127 | # stream_handler.setFormatter(logging.Formatter('%(message)s'))
128 | # logger.addHandler(stream_handler)
129 |
130 | coloredlogs.install(level='INFO', logger=logger, fmt='%(asctime)s %(name)s %(message)s')
131 | file_handler = logging.FileHandler(log_path)
132 | log_formatter = logging.Formatter('%(asctime)s - %(message)s')
133 | file_handler.setFormatter(log_formatter)
134 | logger.addHandler(file_handler)
135 | master_logger(logger, 'Output and logs will be saved to {}'.format(log_path))
136 | return logger
137 |
138 |
139 | def save_dict_to_json(d, json_path):
140 | """Saves dict of floats in json file
141 |
142 | Args:
143 | d: (dict) of float-castable values (np.float, int, float, etc.)
144 | json_path: (string) path to json file
145 | """
146 | save_dict = {}
147 | with open(json_path, "w") as f:
148 | # We need to convert the values to float for json (it doesn't accept np.array, np.float32, etc.)
149 | for k, v in d.items():
150 | if isinstance(v, AverageMeter):
151 | save_dict[k] = float(v.avg)
152 | else:
153 | save_dict[k] = float(v)
154 | json.dump(save_dict, f, indent=4)
155 |
156 |
157 | def master_logger(logger, info, is_master=False):
158 | if is_master:
159 | logger.info(info)
160 |
--------------------------------------------------------------------------------
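`AverageMeter.update(val, num)` keeps a sample-weighted mean, whereas `RunningAverage` averages the values it is given with equal weight. The two only differ when batch sizes differ. That can happen on the evaluation loaders, which use `SequentialSampler` without `drop_last=True` (assuming MegEngine's default of not dropping the last batch). Below is a small standard-library sketch of the arithmetic. `WeightedMeter` is a hypothetical stand-in that copies only the update rule, and the batch numbers are made up for illustration.

```python
class WeightedMeter:
    """Hypothetical stand-in with the same update rule as AverageMeter.update(val, num)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, num):
        # val: mean metric over one batch; num: number of samples in that batch
        self.sum += val * num
        self.count += num
        self.avg = self.sum / self.count


batch_means = [0.5, 0.5, 2.0]
batch_sizes = [32, 32, 4]  # illustrative: a smaller final batch from a sequential eval pass

meter = WeightedMeter()
for mean, size in zip(batch_means, batch_sizes):
    meter.update(mean, size)

unweighted = sum(batch_means) / len(batch_means)  # what RunningAverage would report
print(round(meter.avg, 4), unweighted)
```

With these numbers the weighted mean is 40 / 68 ≈ 0.588, while the unweighted mean of batch means is 1.0. The small last batch pulls the unweighted figure far more than its 4 samples warrant.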
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MegEngine/FINet/5c07e7c0cafad67461c574633e7bb77257af6a96/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/data_loader.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 |
5 | import numpy as np
6 | import h5py
7 |
8 | from megengine.data import DataLoader
9 | from megengine.data.dataset import Dataset
10 | from megengine.data.sampler import RandomSampler, SequentialSampler
11 | import megengine.distributed as dist
12 |
13 | from dataset.transformations import fetch_transform
14 | from common import utils
15 |
16 | _logger = logging.getLogger(__name__)
17 |
18 |
19 | class ModelNetNpy(Dataset):
20 | def __init__(self, dataset_path: str, dataset_mode: str, subset: str = "train", categories=None, transform=None):
21 | self._logger = logging.getLogger(self.__class__.__name__)
22 | self._root = dataset_path
23 | self._subset = subset
24 | self._is_master = dist.get_rank() == 0
25 |
26 | metadata_fpath = os.path.join(self._root, "modelnet_{}_{}.pickle".format(dataset_mode, subset))
27 | utils.master_logger(self._logger, "Loading data from {} for {}".format(metadata_fpath, subset), self._is_master)
28 |
29 | if not os.path.exists(dataset_path):
30 | raise FileNotFoundError("Not found dataset_path: {}".format(dataset_path))
31 |
32 | with open(os.path.join(dataset_path, "shape_names.txt")) as fid:
33 | self._classes = [l.strip() for l in fid]
34 | self._category2idx = {e[1]: e[0] for e in enumerate(self._classes)}
35 | self._idx2category = self._classes
36 |
37 | if categories is not None:
38 | categories_idx = [self._category2idx[c] for c in categories]
39 | utils.master_logger(self._logger, "Categories used: {}.".format(categories_idx), self._is_master)
40 | self._classes = categories
41 | else:
42 | categories_idx = None
43 | utils.master_logger(self._logger, "Using all categories.", self._is_master)
44 |
45 | self._data = self._read_pickle_files(os.path.join(dataset_path, "modelnet_{}_{}.pickle".format(dataset_mode, subset)),
46 | categories_idx)
47 |
48 | self._transform = transform
49 | utils.master_logger(self._logger, "Loaded {} {} instances.".format(len(self._data), subset), self._is_master)
50 |
51 | @property
52 | def classes(self):
53 | return self._classes
54 |
55 | @staticmethod
56 | def _read_pickle_files(fnames, categories):
57 |
58 | all_data_dict = []
59 | with open(fnames, "rb") as f:
60 | data = pickle.load(f)
61 |
62 | for category in categories:
63 | all_data_dict.extend(data[category])
64 |
65 | return all_data_dict
66 |
67 | def to_category(self, i):
68 | return self._idx2category[i]
69 |
70 | def __getitem__(self, item):
71 |
72 | data_path = self._data[item]
73 |
74 | # load and process data
75 | points = np.load(data_path)
76 | idx = np.array(int(os.path.splitext(os.path.basename(data_path))[0].split("_")[1]))
77 | label = np.array(int(os.path.splitext(os.path.basename(data_path))[0].split("_")[3]))
78 | sample = {"points": points, "label": label, "idx": idx}
79 |
80 | if self._transform:
81 | sample = self._transform(sample)
82 | return sample
83 |
84 | def __len__(self):
85 | return len(self._data)
86 |
87 |
88 | def fetch_dataloader(params):
89 | utils.master_logger(_logger, "Dataset type: {}, transform type: {}".format(params.dataset_type, params.transform_type),
90 | dist.get_rank() == 0)
91 |
92 | train_transforms, test_transforms = fetch_transform(params)
93 |
94 | if params.dataset_type == "modelnet_os":
95 | dataset_path = "./dataset/data/modelnet_os"
96 | train_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half1_rm_rotate.txt")]
97 | val_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half1_rm_rotate.txt")]
98 | test_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half2_rm_rotate.txt")]
99 | train_categories.sort()
100 | val_categories.sort()
101 | test_categories.sort()
102 | train_ds = ModelNetNpy(dataset_path, dataset_mode="os", subset="train", categories=train_categories, transform=train_transforms)
103 | val_ds = ModelNetNpy(dataset_path, dataset_mode="os", subset="val", categories=val_categories, transform=test_transforms)
104 | test_ds = ModelNetNpy(dataset_path, dataset_mode="os", subset="test", categories=test_categories, transform=test_transforms)
105 |
106 | elif params.dataset_type == "modelnet_ts":
107 | dataset_path = "./dataset/data/modelnet_ts"
108 | train_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half1_rm_rotate.txt")]
109 | val_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half1_rm_rotate.txt")]
110 | test_categories = [line.rstrip("\n") for line in open("./dataset/data/modelnet40_half2_rm_rotate.txt")]
111 | train_categories.sort()
112 | val_categories.sort()
113 | test_categories.sort()
114 | train_ds = ModelNetNpy(dataset_path, dataset_mode="ts", subset="train", categories=train_categories, transform=train_transforms)
115 | val_ds = ModelNetNpy(dataset_path, dataset_mode="ts", subset="val", categories=val_categories, transform=test_transforms)
116 | test_ds = ModelNetNpy(dataset_path, dataset_mode="ts", subset="test", categories=test_categories, transform=test_transforms)
117 |
118 | dataloaders = {}
119 | # add default train data loader
120 | train_sampler = RandomSampler(train_ds, batch_size=params.train_batch_size, drop_last=True)
121 | train_dl = DataLoader(train_ds, train_sampler, num_workers=params.num_workers)
122 | dataloaders["train"] = train_dl
123 |
124 | # choose val or test data loader for evaluation
125 | for split in ["val", "test"]:
126 | if split in params.eval_type:
127 | if split == "val":
128 | val_sampler = SequentialSampler(val_ds, batch_size=params.eval_batch_size)
129 | dl = DataLoader(val_ds, val_sampler, num_workers=params.num_workers)
130 | elif split == "test":
131 | test_sampler = SequentialSampler(test_ds, batch_size=params.eval_batch_size)
132 | dl = DataLoader(test_ds, test_sampler, num_workers=params.num_workers)
133 | else:
134 | raise ValueError("Unknown eval_type in params, should be in [val, test]")
135 | dataloaders[split] = dl
136 | else:
137 | dataloaders[split] = None
138 |
139 | return dataloaders
140 |
--------------------------------------------------------------------------------
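`ModelNetNpy.__getitem__` takes the sample index and class label from fixed positions of the underscore-split file stem, fields 1 and 3. The index matters beyond bookkeeping: the deterministic test transforms call `np.random.seed(sample["idx"])`. Below is a sketch of the same parsing as a standalone function. `parse_idx_label` and the example file names are hypothetical, since the actual naming scheme of the .npy files is not shown here.

```python
import os


def parse_idx_label(data_path):
    """Same parsing as ModelNetNpy.__getitem__: fields 1 and 3 of the
    underscore-separated file stem are read as the sample index and the label."""
    fields = os.path.splitext(os.path.basename(data_path))[0].split("_")
    return int(fields[1]), int(fields[3])


# Hypothetical file name following the assumed "<prefix>_<idx>_<tag>_<label>.npy" layout
print(parse_idx_label("/data/modelnet_os/train_000123_label_7.npy"))
```

Because the split is positional, any extra underscore before the label field shifts the fields. The parser then either reads the wrong numbers or raises `ValueError`.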
/dataset/transformations.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import megengine as mge
4 | import megengine.distributed as dist
5 | import numpy as np
6 | from common import se3, so3, utils
7 | from scipy.spatial.transform import Rotation
8 | from megengine.data.transform import Transform
9 | from scipy.stats import special_ortho_group
10 |
11 | _logger = logging.getLogger(__name__)
12 |
13 | def uniform_2_sphere(num: int = None):
14 | """Uniform sampling on a 2-sphere
15 |
16 | Source: https://gist.github.com/andrewbolster/10274979
17 |
18 | Args:
19 | num: Number of vectors to sample (or None if single)
20 |
21 | Returns:
22 | Random Vector (np.ndarray) of size (num, 3) with norm 1.
23 | If num is None returned value will have size (3,)
24 |
25 | """
26 | if num is not None:
27 | phi = np.random.uniform(0.0, 2 * np.pi, num)
28 | cos_theta = np.random.uniform(-1.0, 1.0, num)
29 | else:
30 | phi = np.random.uniform(0.0, 2 * np.pi)
31 | cos_theta = np.random.uniform(-1.0, 1.0)
32 |
33 | theta = np.arccos(cos_theta)
34 | x = np.sin(theta) * np.cos(phi)
35 | y = np.sin(theta) * np.sin(phi)
36 | z = np.cos(theta)
37 |
38 | return np.stack((x, y, z), axis=-1)
39 |
40 |
41 | class SplitSourceRef(Transform):
42 | """Clones the point cloud into separate source and reference point clouds"""
43 | def __init__(self, mode="os"):
44 | self.mode = mode
45 |
46 | def apply(self, sample):
47 | if "deterministic" in sample and sample["deterministic"]:
48 | np.random.seed(sample["idx"])
49 |
50 | if self.mode == "os":
51 | sample["points_raw"] = sample.pop("points").astype(np.float32)[:, :3]
52 | sample["points_src"] = sample["points_raw"].copy()
53 | sample["points_ref"] = sample["points_raw"].copy()
54 | sample["points_src_raw"] = sample["points_src"].copy().astype(np.float32)
55 | sample["points_ref_raw"] = sample["points_ref"].copy().astype(np.float32)
56 | elif self.mode == "ts":
57 | points_raw = sample.pop("points").astype(np.float32)
58 | points_raw = points_raw[np.random.choice(points_raw.shape[0], 2, replace=False), :, :]
59 | sample["points_src"] = points_raw[0, :, :].astype(np.float32)
60 | sample["points_ref"] = points_raw[1, :, :].astype(np.float32)
61 | sample["points_src_raw"] = sample["points_src"].copy()
62 | sample["points_ref_raw"] = sample["points_ref"].copy()
63 |
64 | else:
65 | raise NotImplementedError
66 |
67 | return sample
68 |
69 |
70 | class Resampler(Transform):
71 | def __init__(self, num: int):
72 | """Resamples a point cloud containing N points to one containing M
73 |
74 | Guaranteed to have no repeated points if M <= N.
75 | Otherwise, it is guaranteed that all points appear at least once.
76 |
77 | Args:
78 | num (int): Number of points to resample to, i.e. M
79 |
80 | """
81 | self.num = num
82 |
83 | @staticmethod
84 | def _resample(points, k):
85 | """Resamples the points such that there are exactly k points.
86 |
87 | If the input point cloud has <= k points, it is guaranteed the
88 | resampled point cloud contains every point in the input.
89 | If the input point cloud has > k points, it is guaranteed the
90 | resampled point cloud does not contain repeated points.
91 | """
92 | # Subsample without replacement, return as-is, or keep every point and pad with random duplicates
93 | if k < points.shape[0]:
94 | rand_idxs = np.random.choice(points.shape[0], k, replace=False)
95 | return points[rand_idxs, :]
96 | elif points.shape[0] == k:
97 | return points
98 | else:
99 | rand_idxs = np.concatenate([
100 | np.random.choice(points.shape[0], points.shape[0], replace=False),
101 | np.random.choice(points.shape[0], k - points.shape[0], replace=True)
102 | ])
103 | return points[rand_idxs, :]
104 |
105 | def apply(self, sample):
106 |
107 | if "deterministic" in sample and sample["deterministic"]:
108 | np.random.seed(sample["idx"])
109 |
110 | if "points" in sample:
111 | sample["points"] = self._resample(sample["points"], self.num)
112 | else:
113 | if "crop_proportion" not in sample:
114 | src_size, ref_size = self.num, self.num
115 | elif len(sample["crop_proportion"]) == 1:
116 | src_size = math.ceil(sample["crop_proportion"][0] * self.num)
117 | ref_size = self.num
118 | elif len(sample["crop_proportion"]) == 2:
119 | src_size = math.ceil(sample["crop_proportion"][0] * self.num)
120 | ref_size = math.ceil(sample["crop_proportion"][1] * self.num)
121 | else:
122 | raise ValueError("Crop proportion must have 1 or 2 elements")
123 |
124 | sample["points_src"] = self._resample(sample["points_src"], src_size)
125 | sample["points_ref"] = self._resample(sample["points_ref"], ref_size)
126 |
127 | # sample for the raw point clouds
128 | sample["points_src_raw"] = sample["points_src_raw"][:self.num, :]
129 | sample["points_ref_raw"] = sample["points_ref_raw"][:self.num, :]
130 |
131 | return sample
132 |
133 |
134 | class RandomJitter(Transform):
135 | """ generate perturbations """
136 | def __init__(self, noise_std=0.01, clip=0.05):
137 | self.noise_std = noise_std
138 | self.clip = clip
139 |
140 | def jitter(self, pts):
141 |
142 | noise = np.clip(np.random.normal(0.0, scale=self.noise_std, size=(pts.shape[0], 3)), a_min=-self.clip, a_max=self.clip)
143 | pts[:, :3] += noise # Add noise to xyz
144 |
145 | return pts
146 |
147 | def apply(self, sample):
148 |
149 | if "points" in sample:
150 | sample["points"] = self.jitter(sample["points"])
151 | else:
152 | sample["points_src"] = self.jitter(sample["points_src"])
153 | sample["points_ref"] = self.jitter(sample["points_ref"])
154 |
155 | return sample
156 |
157 |
158 | class RandomCrop(Transform):
159 | """Randomly crops the source point cloud and, if p_keep has two entries, the reference point cloud
160 |
161 | A direction is randomly sampled from S2, and we retain points which lie within the
162 | half-space oriented in this direction.
163 | If p_keep != 0.5, we shift the plane until approximately a fraction p_keep of the points is retained
164 | """
165 | def __init__(self, p_keep=None):
166 | if p_keep is None:
167 | p_keep = [0.7, 0.7] # Crop both clouds to 70%
168 | self.p_keep = np.array(p_keep, dtype=np.float32)
169 |
170 | @staticmethod
171 | def crop(points, p_keep):
172 | if p_keep == 1.0:
173 | mask = np.ones(shape=(points.shape[0], )) > 0
174 |
175 | else:
176 | rand_xyz = uniform_2_sphere()
177 | centroid = np.mean(points[:, :3], axis=0)
178 | points_centered = points[:, :3] - centroid
179 | dist_from_plane = np.dot(points_centered, rand_xyz)
180 |
181 | if p_keep == 0.5:
182 | mask = dist_from_plane > 0
183 | else:
184 | mask = dist_from_plane > np.percentile(dist_from_plane, (1.0 - p_keep) * 100)
185 |
186 | return points[mask, :]
187 |
188 | def apply(self, sample):
189 |
190 | if "deterministic" in sample and sample["deterministic"]:
191 | np.random.seed(sample["idx"])
192 |
193 | sample["crop_proportion"] = self.p_keep
194 |
195 | if len(sample["crop_proportion"]) == 1:
196 | sample["points_src"] = self.crop(sample["points_src"], self.p_keep[0])
197 | sample["points_ref"] = self.crop(sample["points_ref"], 1.0)
198 | else:
199 | sample["points_src"] = self.crop(sample["points_src"], self.p_keep[0])
200 | sample["points_ref"] = self.crop(sample["points_ref"], self.p_keep[1])
201 |
202 | return sample
203 |
204 |
205 | class RandomTransformSE3(Transform):
206 | def __init__(self, rot_mag: float = 180.0, trans_mag: float = 1.0, random_mag: bool = False):
207 | """Applies a random rigid transformation to the source point cloud
208 |
209 | Args:
210 | rot_mag (float): Maximum rotation in degrees
211 | trans_mag (float): Maximum translation. Random translation will
212 | be in the range [-trans_mag, trans_mag] in each axis
213 | random_mag (bool): If true, will randomize the maximum rotation, i.e. will bias towards small
214 | perturbations
215 | """
216 | self._rot_mag = rot_mag
217 | self._trans_mag = trans_mag
218 | self._random_mag = random_mag
219 |
220 | def generate_transform(self):
221 | """Generate a random SE3 transformation (3, 4) """
222 |
223 | if self._random_mag:
224 | attenuation = np.random.random()
225 | rot_mag, trans_mag = attenuation * self._rot_mag, attenuation * self._trans_mag
226 | else:
227 | rot_mag, trans_mag = self._rot_mag, self._trans_mag
228 |
229 | # Generate rotation
230 | rand_rot = special_ortho_group.rvs(3)
231 | axis_angle = Rotation.from_matrix(rand_rot).as_rotvec()
232 | axis_angle *= rot_mag / 180.0
233 | rand_rot = Rotation.from_rotvec(axis_angle).as_matrix()
234 |
235 | # Generate translation
236 | rand_trans = np.random.uniform(-trans_mag, trans_mag, 3)
237 | rand_SE3 = np.concatenate((rand_rot, rand_trans[:, None]), axis=1).astype(np.float32)
238 |
239 | return rand_SE3
240 |
241 | def apply_transform(self, p0, transform_mat):
242 | p1 = se3.np_transform(transform_mat, p0[:, :3])
243 | if p0.shape[1] == 6: # Need to rotate normals also
244 | n1 = so3.np_transform(transform_mat[:3, :3], p0[:, 3:6])
245 | p1 = np.concatenate((p1, n1), axis=-1)
246 |
247 | igt = transform_mat
248 | gt = se3.np_inverse(igt)
249 |
250 | return p1, gt, igt
251 |
252 | def transform(self, tensor):
253 | transform_mat = self.generate_transform()
254 | return self.apply_transform(tensor, transform_mat)
255 |
256 | def apply(self, sample):
257 |
258 | if "deterministic" in sample and sample["deterministic"]:
259 | np.random.seed(sample["idx"])
260 |
261 | if "points" in sample:
262 | sample["points"], _, _ = self.transform(sample["points"])
263 | else:
264 | src_transformed, transform_r_s, transform_s_r = self.transform(sample["points_src"])
265 | # Apply to source to get reference
266 | sample["transform_gt"] = transform_r_s
267 | sample["pose_gt"] = se3.np_mat2quat(transform_r_s)
268 | sample["transform_igt"] = transform_s_r
269 | sample["points_src"] = src_transformed
270 | # transform the raw source point cloud
271 | sample["points_src_raw"] = se3.np_transform(transform_s_r, sample["points_src_raw"][:, :3])
272 |
273 | return sample
274 |
275 |
276 | # noinspection PyPep8Naming
277 | class RandomTransformSE3_euler(RandomTransformSE3):
278 | """Same as RandomTransformSE3, but rotates using euler angle rotations
279 |
280 | This transformation is consistent with Deep Closest Point but does not
281 | generate uniform rotations
282 |
283 | """
284 | def generate_transform(self):
285 |
286 | if self._random_mag:
287 | attenuation = np.random.random()
288 | rot_mag, trans_mag = attenuation * self._rot_mag, attenuation * self._trans_mag
289 | else:
290 | rot_mag, trans_mag = self._rot_mag, self._trans_mag
291 |
292 | # Generate rotation
293 | anglex = np.random.uniform() * np.pi * rot_mag / 180.0
294 | angley = np.random.uniform() * np.pi * rot_mag / 180.0
295 | anglez = np.random.uniform() * np.pi * rot_mag / 180.0
296 |
297 | cosx = np.cos(anglex)
298 | cosy = np.cos(angley)
299 | cosz = np.cos(anglez)
300 | sinx = np.sin(anglex)
301 | siny = np.sin(angley)
302 | sinz = np.sin(anglez)
303 | Rx = np.array([[1, 0, 0], [0, cosx, -sinx], [0, sinx, cosx]])
304 | Ry = np.array([[cosy, 0, siny], [0, 1, 0], [-siny, 0, cosy]])
305 | Rz = np.array([[cosz, -sinz, 0], [sinz, cosz, 0], [0, 0, 1]])
306 | R_ab = Rx @ Ry @ Rz
307 | t_ab = np.random.uniform(-trans_mag, trans_mag, 3)
308 |
309 | rand_SE3 = np.concatenate((R_ab, t_ab[:, None]), axis=1).astype(np.float32)
310 | return rand_SE3
311 |
312 |
313 | class ShufflePoints(Transform):
314 | """Shuffles the order of the points"""
315 | def apply(self, sample):
316 | if "points" in sample:
317 | sample["points"] = np.random.permutation(sample["points"])
318 | else:
319 | sample["points_ref"] = np.random.permutation(sample["points_ref"])
320 | sample["points_src"] = np.random.permutation(sample["points_src"])
321 | return sample
322 |
323 |
324 | class SetDeterministic(Transform):
325 | """Adds a deterministic flag to the sample such that subsequent transforms
326 | use a fixed random seed where applicable. Used for test"""
327 | def apply(self, sample):
328 | sample["deterministic"] = True
329 | return sample
330 |
331 |
332 | class PRNet(Transform):
333 | def __init__(self, num_points, rot_mag, trans_mag, noise_std=0.01, clip=0.05, add_noise=True, only_z=False, partial=True):
334 | self.num_points = num_points
335 | self.rot_mag = rot_mag
336 | self.trans_mag = trans_mag
337 | self.noise_std = noise_std
338 | self.clip = clip
339 | self.add_noise = add_noise
340 | self.only_z = only_z
341 | self.partial = partial
342 |
343 | def apply_transform(self, p0, transform_mat):
344 | p1 = se3.np_transform(transform_mat, p0[:, :3])
345 | if p0.shape[1] == 6: # Need to rotate normals also
346 | n1 = so3.np_transform(transform_mat[:3, :3], p0[:, 3:6])
347 | p1 = np.concatenate((p1, n1), axis=-1)
348 |
349 | gt = transform_mat
350 |
351 | return p1, gt
352 |
353 | def jitter(self, pts):
354 | noise = np.clip(np.random.normal(0.0, scale=self.noise_std, size=(pts.shape[0], 3)), a_min=-self.clip, a_max=self.clip)
355 | pts[:, :3] += noise # Add noise to xyz
356 |
357 | return pts
358 |
359 | def knn(self, pts, random_pt, k):
360 | distance = np.sum((pts - random_pt)**2, axis=1)
361 | idx = np.argsort(distance)[:k] # (k,)
362 | return idx
363 |
364 | def apply(self, sample):
365 |
366 | if "deterministic" in sample and sample["deterministic"]:
367 | np.random.seed(sample["idx"])
368 |
369 | src = sample["points_src"]
370 | ref = sample["points_ref"]
371 | # Generate rigid transform
372 | anglex = np.random.uniform() * np.pi * self.rot_mag / 180.0
373 | angley = np.random.uniform() * np.pi * self.rot_mag / 180.0
374 | anglez = np.random.uniform() * np.pi * self.rot_mag / 180.0
375 |
376 | cosx = np.cos(anglex)
377 | cosy = np.cos(angley)
378 | cosz = np.cos(anglez)
379 | sinx = np.sin(anglex)
380 | siny = np.sin(angley)
381 | sinz = np.sin(anglez)
382 | Rx = np.array([[1, 0, 0], [0, cosx, -sinx], [0, sinx, cosx]])
383 | Ry = np.array([[cosy, 0, siny], [0, 1, 0], [-siny, 0, cosy]])
384 | Rz = np.array([[cosz, -sinz, 0], [sinz, cosz, 0], [0, 0, 1]])
385 |
386 | if not self.only_z:
387 | R_ab = Rx @ Ry @ Rz
388 | else:
389 | R_ab = Rz
390 | t_ab = np.random.uniform(-self.trans_mag, self.trans_mag, 3)
391 |
392 | rand_SE3 = np.concatenate((R_ab, t_ab[:, None]), axis=1).astype(np.float32)
393 | ref, transform_s_r = self.apply_transform(ref, rand_SE3)
394 | # The reference was transformed above; transform_gt maps source coordinates to reference coordinates
395 | sample["transform_gt"] = transform_s_r
396 | sample["pose_gt"] = se3.np_mat2quat(transform_s_r)
397 |
398 | # Crop and sample
399 | if self.partial:
400 | random_p1 = np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 1, -1])
401 | idx1 = self.knn(src, random_p1, k=768)
402 | random_p2 = random_p1
403 | idx2 = self.knn(ref, random_p2, k=768)
404 | else:
405 | idx1 = np.random.choice(src.shape[0], 1024, replace=False)
406 | idx2 = np.random.choice(ref.shape[0], 1024, replace=False)
407 | src = mge.tensor(src)
408 | ref = mge.tensor(ref)
409 |
410 | # add noise
411 | if self.add_noise:
412 | sample["points_src"] = self.jitter(src[idx1, :])
413 | sample["points_ref"] = self.jitter(ref[idx2, :])
414 | else:
415 | sample["points_src"] = src[idx1, :]
416 | sample["points_ref"] = ref[idx2, :]
417 |
418 | return sample
419 |
420 |
421 | class Compose(object):
422 | def __init__(self, transforms):
423 | self.transforms = transforms
424 |
425 | def __call__(self, input):
426 | for t in self.transforms:
427 | input = t.apply(input)
428 | return input
429 |
430 |
431 | def fetch_transform(params):
432 |
433 | if params.transform_type == "modelnet_os_rpmnet_noise":
434 | train_transforms = [
435 | SplitSourceRef(mode="os"),
436 | RandomCrop(params.partial_ratio),
437 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag),
438 | Resampler(params.num_points),
439 | RandomJitter(),
440 | ShufflePoints()
441 | ]
442 |
443 | test_transforms = [
444 | SetDeterministic(),
445 | SplitSourceRef(mode="os"),
446 | RandomCrop(params.partial_ratio),
447 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag),
448 | Resampler(params.num_points),
449 | RandomJitter(),
450 | ShufflePoints()
451 | ]
452 |
453 | elif params.transform_type == "modelnet_os_rpmnet_clean":
454 | train_transforms = [
455 | SplitSourceRef(mode="os"),
456 | RandomCrop(params.partial_ratio),
457 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag),
458 | Resampler(params.num_points),
459 | ShufflePoints()
460 | ]
461 |
462 | test_transforms = [
463 | SetDeterministic(),
464 | SplitSourceRef(mode="os"),
465 | RandomCrop(params.partial_ratio),
466 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag),
467 | Resampler(params.num_points),
468 | ShufflePoints()
469 | ]
470 |
471 | elif params.transform_type == "modelnet_ts_rpmnet_noise":
472 | train_transforms = [
473 | SplitSourceRef(mode="ts"),
474 | RandomCrop(params.partial_ratio),
475 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag),
476 | Resampler(params.num_points),
477 | RandomJitter(noise_std=params.noise_std),
478 | ShufflePoints()
479 | ]
480 |
481 | test_transforms = [
482 | SetDeterministic(),
483 | SplitSourceRef(mode="ts"),
484 | RandomCrop(params.partial_ratio),
485 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag),
486 | Resampler(params.num_points),
487 | RandomJitter(noise_std=params.noise_std),
488 | ShufflePoints()
489 | ]
490 |
491 | elif params.transform_type == "modelnet_ts_rpmnet_clean":
492 | train_transforms = [
493 | SplitSourceRef(mode="ts"),
494 | RandomCrop(params.partial_ratio),
495 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag),
496 | Resampler(params.num_points),
497 | ShufflePoints()
498 | ]
499 |
500 | test_transforms = [
501 | SetDeterministic(),
502 | SplitSourceRef(mode="ts"),
503 | RandomCrop(params.partial_ratio),
504 | RandomTransformSE3_euler(rot_mag=params.rot_mag, trans_mag=params.trans_mag),
505 | Resampler(params.num_points),
506 | ShufflePoints()
507 | ]
508 |
509 | elif params.transform_type == "modelnet_ts_prnet_noise":
510 | train_transforms = [
511 | SplitSourceRef(mode="ts"),
512 | ShufflePoints(),
513 | PRNet(num_points=params.num_points,
514 | rot_mag=params.rot_mag,
515 | trans_mag=params.trans_mag,
516 | noise_std=params.noise_std,
517 | add_noise=True)
518 | ]
519 |
520 | test_transforms = [
521 | SetDeterministic(),
522 | SplitSourceRef(mode="ts"),
523 | ShufflePoints(),
524 | PRNet(num_points=params.num_points,
525 | rot_mag=params.rot_mag,
526 | trans_mag=params.trans_mag,
527 | noise_std=params.noise_std,
528 | add_noise=True)
529 | ]
530 |
531 | elif params.transform_type == "modelnet_ts_prnet_clean":
532 | train_transforms = [
533 | SplitSourceRef(mode="ts"),
534 | ShufflePoints(),
535 | PRNet(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=False)
536 | ]
537 |
538 | test_transforms = [
539 | SetDeterministic(),
540 | SplitSourceRef(mode="ts"),
541 | ShufflePoints(),
542 | PRNet(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=False)
543 | ]
544 |
545 | elif params.transform_type == "modelnet_os_prnet_noise":
546 | train_transforms = [
547 | SplitSourceRef(mode="os"),
548 | ShufflePoints(),
549 | PRNet(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=True)
550 | ]
551 |
552 | test_transforms = [
553 | SetDeterministic(),
554 | SplitSourceRef(mode="os"),
555 | ShufflePoints(),
556 | PRNet(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=True)
557 | ]
558 |
559 | elif params.transform_type == "modelnet_os_prnet_clean":
560 | train_transforms = [
561 | SplitSourceRef(mode="os"),
562 | ShufflePoints(),
563 | PRNet(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=False)
564 | ]
565 |
566 | test_transforms = [
567 | SetDeterministic(),
568 | SplitSourceRef(mode="os"),
569 | ShufflePoints(),
570 | PRNet(num_points=params.num_points, rot_mag=params.rot_mag, trans_mag=params.trans_mag, add_noise=False)
571 | ]
572 |
573 | utils.master_logger(_logger, "Train transforms: {}".format(", ".join([type(t).__name__ for t in train_transforms])),
574 | dist.get_rank() == 0)
575 | utils.master_logger(_logger, "Val and Test transforms: {}".format(", ".join([type(t).__name__ for t in test_transforms])),
576 | dist.get_rank() == 0)
577 | train_transforms = Compose(train_transforms)
578 | test_transforms = Compose(test_transforms)
579 | return train_transforms, test_transforms
580 |
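The partial-overlap branch of `PRNet.apply` crops both clouds around one shared anchor point placed far outside the shape, keeping the k nearest points of each. Below is a minimal NumPy sketch of that step. It assumes (N, 3) xyz-only inputs. `farthest_anchor_crop` is a hypothetical helper, not part of this repository, and it draws from a `numpy.random.Generator` instead of the global `np.random` state that `apply` reseeds in deterministic mode.

```python
import numpy as np


def farthest_anchor_crop(src, ref, k=768, rng=None):
    """Hypothetical helper: keep the k points of each cloud nearest one far anchor.

    src, ref: (N, 3) arrays. Returns (src_crop, ref_crop), each of shape (k, 3).
    """
    rng = np.random.default_rng() if rng is None else rng
    # Anchor about 500 units away along (1, 1, 1) or (-1, -1, -1), plus a small
    # random offset, mirroring random_p1 in PRNet.apply.
    sign = rng.choice([1, -1])
    anchor = rng.random((1, 3)) + 500.0 * sign
    # Squared Euclidean distance to the anchor, as in PRNet.knn.
    d_src = np.sum((src - anchor) ** 2, axis=1)
    d_ref = np.sum((ref - anchor) ** 2, axis=1)
    # The same anchor is used for both clouds, so the two crops overlap.
    idx_src = np.argsort(d_src)[:k]
    idx_ref = np.argsort(d_ref)[:k]
    return src[idx_src], ref[idx_ref]


rng = np.random.default_rng(0)
pts = rng.uniform(-1.0, 1.0, size=(1024, 3))
src_crop, ref_crop = farthest_anchor_crop(pts, pts.copy(), k=768, rng=rng)
print(src_crop.shape)  # (768, 3)
```

The anchor lies far from the unit-scale ModelNet shapes, so the nearest k points roughly form a half-space cut. Because the reference has already been rotated, the same cut removes a different region from each cloud, which gives the partial overlap.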
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 |
5 | import dataset.data_loader as data_loader
6 |
7 | import model.net as net
8 |
9 | from common import utils
10 | from loss.losses import compute_losses, compute_metrics
11 | from common.manager import Manager
12 | import megengine.distributed as dist
13 | import megengine.functional as F
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--model_dir", default="experiments/base_model", help="Directory containing params.json")
17 | parser.add_argument("--restore_file", default="best", help="name of the file in --model_dir containing weights to load")
18 |
19 |
20 | def evaluate(model, manager):
21 | """Evaluate the model on the val and test splits, skipping any whose dataloader is None.
22 |
23 | Args:
24 | model: (megengine.module.Module) the neural network
25 | manager: a class instance that contains objects related to train and evaluate.
26 | """
27 | rank = dist.get_rank()
28 | world_size = dist.get_world_size()
29 | # set model to evaluation mode
30 | model.eval()
31 |
32 | # compute metrics over the dataset
33 | if manager.dataloaders["val"] is not None:
34 | # reset the loss and val metric status
35 | manager.reset_loss_status()
36 | manager.reset_metric_status("val")
37 | for data_batch in manager.dataloaders["val"]:
38 | # compute the real batch size
39 | bs = data_batch["points_src"].shape[0]
40 | # move to GPU if available
41 | data_batch = utils.tensor_mge(data_batch)
42 | # compute model output
43 | output_batch = model(data_batch)
44 | # compute all loss on this batch
45 | loss = compute_losses(output_batch, manager.params)
46 | metrics = compute_metrics(output_batch, manager.params)
47 | if world_size > 1:
48 | for k, v in loss.items():
49 | loss[k] = F.distributed.all_reduce_sum(v) / world_size
50 | for k, v in metrics.items():
51 | metrics[k] = F.distributed.all_reduce_sum(v) / world_size
52 | manager.update_loss_status(loss, "val", bs)
53 | # compute all metrics on this batch
54 | manager.update_metric_status(metrics, "val", bs)
55 |
56 | # update val data to tensorboard
57 | if rank == 0:
58 | # compute RMSE metrics
59 | manager.summarize_metric_status(metrics, "val")
60 |
61 | manager.writer.add_scalar("Loss/val", manager.loss_status["total"].avg, manager.epoch)
62 | # manager.logger.info("Loss/valid epoch {}: {:.4f}".format(manager.epoch, manager.loss_status["total"].avg))
63 | for k, v in manager.val_status.items():
64 | manager.writer.add_scalar("Metric/val/{}".format(k), v.avg, manager.epoch)
65 | # For each epoch, print the metric
66 | manager.print_metrics("val", title="Val", color="green")
67 |
68 | if manager.dataloaders["test"] is not None:
69 | # reset the loss and test metric status
70 | manager.reset_loss_status()
71 | manager.reset_metric_status("test")
72 | for data_batch in manager.dataloaders["test"]:
73 | # compute the real batch size
74 | bs = data_batch["points_src"].shape[0]
75 | # move to GPU if available
76 | data_batch = utils.tensor_mge(data_batch)
77 | # compute model output
78 | output_batch = model(data_batch)
79 | # compute all loss on this batch
80 | loss = compute_losses(output_batch, manager.params)
81 | metrics = compute_metrics(output_batch, manager.params)
82 | if world_size > 1:
83 | for k, v in loss.items():
84 | loss[k] = F.distributed.all_reduce_sum(v) / world_size
85 | for k, v in metrics.items():
86 | metrics[k] = F.distributed.all_reduce_sum(v) / world_size
87 | manager.update_loss_status(loss, "test", bs)
88 | # compute all metrics on this batch
89 | manager.update_metric_status(metrics, "test", bs)
90 |
91 | # update test data to tensorboard
92 | if rank == 0:
93 | # compute RMSE metrics
94 | manager.summarize_metric_status(metrics, "test")
95 |
96 | manager.writer.add_scalar("Loss/test", manager.loss_status["total"].avg, manager.epoch)
97 | # manager.logger.info("Loss/test epoch {}: {:.4f}".format(manager.epoch, manager.loss_status["total"].avg))
98 | for k, v in manager.test_status.items():
99 | manager.writer.add_scalar("Metric/test/{}".format(k), v.avg, manager.epoch)
100 | # For each epoch, print the metric
101 | manager.print_metrics("test", title="Test", color="red")
102 |
103 |
104 | def test(model, manager):
105 | """Test the model on the val and test splits; the caller loads the checkpoint beforehand.
106 |
107 | Args:
108 | model: (megengine.module.Module) the neural network
109 | manager: a class instance that contains objects related to train and evaluate.
110 | """
111 | # set model to evaluation mode
112 | model.eval()
113 |
114 | # compute metrics over the dataset
115 | if manager.dataloaders["val"] is not None:
116 | # reset the loss and val metric status
117 | manager.reset_loss_status()
118 | manager.reset_metric_status("val")
119 | for data_batch in manager.dataloaders["val"]:
120 | # compute the real batch size
121 | bs = data_batch["points_src"].shape[0]
122 | # move to GPU if available
123 | data_batch = utils.tensor_mge(data_batch)
124 | # compute model output
125 | output_batch = model(data_batch)
126 | # compute all loss on this batch
127 | loss = compute_losses(output_batch, manager.params)
128 | manager.update_loss_status(loss, "val", bs)
129 | # compute all metrics on this batch
130 | metrics = compute_metrics(output_batch, manager.params)
131 | manager.update_metric_status(metrics, "val", bs)
132 |
133 | # compute RMSE metrics
134 | manager.summarize_metric_status(metrics, "val")
135 | # For each epoch, update and print the metric
136 | manager.print_metrics("val", title="Val", color="green")
137 |
138 | if manager.dataloaders["test"] is not None:
139 | # reset the loss and test metric status
140 | manager.reset_loss_status()
141 | manager.reset_metric_status("test")
142 | for data_batch in manager.dataloaders["test"]:
143 | # compute the real batch size
144 | bs = data_batch["points_src"].shape[0]
145 | # move to GPU if available
146 | data_batch = utils.tensor_mge(data_batch)
147 | # compute model output
148 | output_batch = model(data_batch)
149 | # compute all loss on this batch
150 | loss = compute_losses(output_batch, manager.params)
151 | manager.update_loss_status(loss, "test", bs)
152 | # compute all metrics on this batch
153 | metrics = compute_metrics(output_batch, manager.params)
154 | manager.update_metric_status(metrics, "test", bs)
155 |
156 | # compute RMSE metrics
157 | manager.summarize_metric_status(metrics, "test")
158 | # For each epoch, print the metric
159 | manager.print_metrics("test", title="Test", color="red")
160 |
161 |
162 | if __name__ == "__main__":
163 | """
164 | Evaluate a saved checkpoint on the val and test sets.
165 | """
166 | # Load the parameters
167 | args = parser.parse_args()
168 | json_path = os.path.join(args.model_dir, "params.json")
169 | assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
170 | params = utils.Params(json_path)
171 | # Only load model weights
172 | params.only_weights = True
173 |
174 | # Update args into params
175 | params.update(vars(args))
176 |
177 | # Get the logger
178 | logger = utils.set_logger(os.path.join(args.model_dir, "evaluate.log"))
179 |
180 | # Create the input data pipeline
181 | logging.info("Creating the dataset...")
182 |
183 | # Fetch dataloaders
184 | dataloaders = data_loader.fetch_dataloader(params)
185 |
186 | # Define the model and optimizer
187 | model = net.fetch_net(params)
188 |
189 | # Initial status for checkpoint manager
190 | manager = Manager(model=model, optimizer=None, scheduler=None, params=params, dataloaders=dataloaders, writer=None, logger=logger)
191 |
192 | # Reload weights from the saved file
193 | manager.load_checkpoints()
194 |
195 | # Test the model
196 | logger.info("Starting test")
197 |
198 | # Evaluate
199 | test(model, manager)
200 |
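`evaluate` averages each per-batch loss and metric across ranks with `F.distributed.all_reduce_sum(v) / world_size`, then passes the batch size `bs` to the Manager's `update_*_status` calls and later reads `.avg`. The pure-Python sketch below shows that bookkeeping. It assumes the Manager keeps batch-size-weighted running means; its real implementation lives in common/manager.py, which is not shown here. `RunningAverage` and `all_reduce_mean` are hypothetical stand-ins.

```python
class RunningAverage:
    """Hypothetical batch-size-weighted running mean (stand-in for a Manager status meter)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        # value is already a mean over n samples, so weight it by n.
        self.total += float(value) * n
        self.count += n

    @property
    def avg(self):
        return self.total / self.count if self.count else 0.0


def all_reduce_mean(per_rank_values):
    """Hypothetical mimic of all_reduce_sum(v) / world_size on plain floats."""
    return sum(per_rank_values) / len(per_rank_values)


meter = RunningAverage()
meter.update(all_reduce_mean([0.2, 0.4]), n=32)  # two ranks, batch of 32
meter.update(all_reduce_mean([0.1, 0.1]), n=8)   # shorter final batch
print(round(meter.avg, 4))  # 0.26
```

Dividing the summed value by `world_size` gives an exact global mean only when every rank sees the same number of samples. `bs` is the local batch size, so an uneven final batch across ranks is weighted slightly off.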
--------------------------------------------------------------------------------
/experiments/experiment_finet/params.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_type": "modelnet_ts",
3 | "transform_type": "modelnet_ts_rpmnet_noise",
4 | "net_type": "finet",
5 | "net_config": {
6 | "dropout_ratio": 0.3,
7 | "reg_t_feats": "tr-t",
8 | "reg_R_feats": "tr-tr"
9 | },
10 | "loss_type": "finet",
11 | "loss_alpha1": 1,
12 | "loss_alpha2": 4,
13 | "loss_alpha3": 0.001,
14 | "loss_alpha4": 0.0025,
15 | "margin": [
16 | 0.01,
17 | 0.01
18 | ],
19 | "eval_type": [
20 | "val",
21 | "test"
22 | ],
23 | "major_metric": "score",
24 | "metric_rule": "Descende",
25 | "num_points": 1024,
26 | "rot_mag": 45,
27 | "trans_mag": 0.5,
28 | "partial_ratio": [
29 | 0.7,
30 | 0.7
31 | ],
32 | "noise_std": 0.01,
33 | "titer": 4,
34 | "overlap_dist": 0.1,
35 | "learning_rate": 1e-4,
36 | "gamma": 1,
37 | "num_epochs": 10000,
38 | "train_batch_size": 8,
39 | "eval_batch_size": 32,
40 | "save_summary_steps": 100,
41 | "num_workers": 8
42 | }
43 |
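In this config, `rot_mag` (degrees) and `trans_mag` bound the random rigid motion applied to each example. The NumPy sketch below follows the sampling in `PRNet.apply`: each Euler angle is drawn from [0, rot_mag) degrees, the rotation is composed as Rx @ Ry @ Rz (or Rz alone when `only_z` is set), and each translation component is drawn from [-trans_mag, trans_mag). `sample_rigid_transform` is a hypothetical helper. The RPMNet-style `RandomTransformSE3_euler` is not in this excerpt and may sample differently.

```python
import numpy as np


def sample_rigid_transform(rot_mag=45.0, trans_mag=0.5, only_z=False, rng=None):
    """Hypothetical helper: return a (3, 4) float32 [R | t] sampled like PRNet.apply."""
    rng = np.random.default_rng() if rng is None else rng
    # Angles in [0, rot_mag) degrees, converted to radians.
    ax, ay, az = rng.uniform(size=3) * np.pi * rot_mag / 180.0
    cx, cy, cz = np.cos([ax, ay, az])
    sx, sy, sz = np.sin([ax, ay, az])
    Rx = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    Ry = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
    Rz = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
    R = Rz if only_z else Rx @ Ry @ Rz
    # Per-axis translation in [-trans_mag, trans_mag).
    t = rng.uniform(-trans_mag, trans_mag, size=3)
    return np.concatenate((R, t[:, None]), axis=1).astype(np.float32)


T = sample_rigid_transform(rng=np.random.default_rng(0))
R, t = T[:, :3], T[:, 3]
print(T.shape)                                     # (3, 4)
print(np.allclose(R @ R.T, np.eye(3), atol=1e-5))  # True
```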
--------------------------------------------------------------------------------
/experiments/params.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_type": "modelnet_ts",
3 | "transform_type": "modelnet_ts_rpmnet_noise",
4 | "net_type": "finet",
5 | "net_config": {
6 | "dropout_ratio": 0.3,
7 | "reg_t_feats": "tr-t",
8 | "reg_R_feats": "tr-tr"
9 | },
10 | "loss_type": "finet",
11 | "loss_alpha1": 1,
12 | "loss_alpha2": 4,
13 | "loss_alpha3": 0.001,
14 | "loss_alpha4": 0.0025,
15 | "margin": [
16 | 0.01,
17 | 0.01
18 | ],
19 | "eval_type": [
20 | "val",
21 | "test"
22 | ],
23 | "major_metric": "score",
24 | "metric_rule": "Descende",
25 | "num_points": 1024,
26 | "rot_mag": 45,
27 | "trans_mag": 0.5,
28 | "partial_ratio": [
29 | 0.7,
30 | 0.7
31 | ],
32 | "noise_std": 0.01,
33 | "titer": 4,
34 | "overlap_dist": 0.1,
35 | "learning_rate": 1e-4,
36 | "gamma": 1,
37 | "num_epochs": 10000,
38 | "train_batch_size": 8,
39 | "eval_batch_size": 32,
40 | "save_summary_steps": 100,
41 | "num_workers": 8
42 | }
43 |
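The `net_config` strings choose which fused features the regressors receive, and the hard-coded input widths in `Regression.__init__` follow from that choice. The plain-Python sketch below works through the arithmetic using the layer sizes in model/module.py and the concatenations in `FINet.forward`. `regression_in_channels` and the constants are illustrative names, not repository code.

```python
# Widths of the five per-point Encoder stages in model/module.py. Their
# concatenation is max-pooled into one global feature of width C.
ENCODER_STAGE_WIDTHS = (64, 64, 128, 256, 512)
FUSION_OUT = 1024  # output width of Fusion.R_block3 / Fusion.t_block3

# Number of B x FUSION_OUT fused features FINet.forward concatenates per option.
R_FEATS_PARTS = {"tr-tr": 4, "tr-r": 3, "r-r": 2}
T_FEATS_PARTS = {"tr-t": 3, "t-t": 2}


def regression_in_channels(reg_R_feats, reg_t_feats):
    """Illustrative: return (R_in_channel, t_in_channel) for a net_config."""
    if reg_R_feats not in R_FEATS_PARTS:
        raise ValueError("Unknown reg_R_feats order {}".format(reg_R_feats))
    if reg_t_feats not in T_FEATS_PARTS:
        raise ValueError("Unknown reg_t_feats order {}".format(reg_t_feats))
    return R_FEATS_PARTS[reg_R_feats] * FUSION_OUT, T_FEATS_PARTS[reg_t_feats] * FUSION_OUT


C = sum(ENCODER_STAGE_WIDTHS)
print(C, 2 * C)                                 # 1024 2048: encoder width, Fusion input width
print(regression_in_channels("tr-tr", "tr-t"))  # (4096, 3072)
```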
--------------------------------------------------------------------------------
/images/FINet_poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MegEngine/FINet/5c07e7c0cafad67461c574633e7bb77257af6a96/images/FINet_poster.png
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MegEngine/FINet/5c07e7c0cafad67461c574633e7bb77257af6a96/loss/__init__.py
--------------------------------------------------------------------------------
/loss/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import megengine.functional as F
3 | from common import se3, so3
4 |
5 |
6 | def compute_losses(endpoints, params):
7 | loss = {}
8 | # compute losses
9 | if params.loss_type == "finet":
10 | num_iter = len(endpoints["all_pose_pair"])
11 | triplet_loss = {}
12 | for i in range(num_iter):
13 | # reg loss
14 | pose_pair = endpoints["all_pose_pair"][i]
15 | loss["quat_{}".format(i)] = F.nn.l1_loss(pose_pair[0][:, :4], pose_pair[1][:, :4]) * params.loss_alpha1
16 | loss["translate_{}".format(i)] = F.nn.square_loss(pose_pair[0][:, 4:], pose_pair[1][:, 4:]) * params.loss_alpha2
17 |
18 | # transformation sensitivity loss (TSL)
19 | if i < 2:
20 | all_R_feats = endpoints["all_R_feats"][i]
21 | all_t_feats = endpoints["all_t_feats"][i]
22 | # R feats triplet loss
23 | R_feats_pos = F.nn.square_loss(all_t_feats[0], all_t_feats[1])
24 | R_feats_neg = F.nn.square_loss(all_R_feats[0], all_R_feats[1])
25 | triplet_loss["R_feats_triplet_pos_{}".format(i)] = R_feats_pos
26 | triplet_loss["R_feats_triplet_neg_{}".format(i)] = R_feats_neg
27 | loss["R_feats_triplet_{}".format(i)] = (F.clip(-R_feats_neg + params.margin[i], lower=0.0) +
28 | R_feats_pos) * params.loss_alpha3
29 | # t feats triplet loss
30 | t_feats_pos = F.nn.square_loss(all_R_feats[0], all_R_feats[2])
31 | t_feats_neg = F.nn.square_loss(all_t_feats[0], all_t_feats[2])
32 | triplet_loss["t_feats_triplet_pos_{}".format(i)] = t_feats_pos
33 | triplet_loss["t_feats_triplet_neg_{}".format(i)] = t_feats_neg
34 | loss["t_feats_triplet_{}".format(i)] = (F.clip(-t_feats_neg + params.margin[i], lower=0.0) +
35 | t_feats_pos) * params.loss_alpha3
36 |
37 | # point-wise feature dropout loss (PFDL)
38 | all_dropout_R_feats = endpoints["all_dropout_R_feats"][i]
39 | all_dropout_t_feats = endpoints["all_dropout_t_feats"][i]
40 | loss["src_R_feats_dropout_{}".format(i)] = F.nn.square_loss(all_dropout_R_feats[0], all_dropout_R_feats[1]) * params.loss_alpha4
41 | loss["ref_R_feats_dropout_{}".format(i)] = F.nn.square_loss(all_dropout_R_feats[2], all_dropout_R_feats[3]) * params.loss_alpha4
42 | loss["src_t_feats_dropout_{}".format(i)] = F.nn.square_loss(all_dropout_t_feats[0], all_dropout_t_feats[1]) * params.loss_alpha4
43 | loss["ref_t_feats_dropout_{}".format(i)] = F.nn.square_loss(all_dropout_t_feats[2], all_dropout_t_feats[3]) * params.loss_alpha4
44 | # total loss
45 | total_losses = []
46 | for k in loss:
47 | total_losses.append(loss[k])
48 | loss["total"] = F.sum(F.concat(total_losses))
49 |
50 | else:
51 | raise NotImplementedError
52 | return loss
53 |
54 |
55 | def compute_metrics(endpoints, params):
56 | metrics = {}
57 | gt_transforms = endpoints["transform_pair"][0]
58 | pred_transforms = endpoints["transform_pair"][1]
59 |
60 | # Euler angles, Individual translation errors (Deep Closest Point convention)
61 | if "prnet" in params.transform_type:
62 | r_gt_euler_deg = so3.mge_dcm2euler(gt_transforms[:, :3, :3], seq="zyx")
63 | r_pred_euler_deg = so3.mge_dcm2euler(pred_transforms[:, :3, :3], seq="zyx")
64 | else:
65 | r_gt_euler_deg = so3.mge_dcm2euler(gt_transforms[:, :3, :3], seq="xyz")
66 | r_pred_euler_deg = so3.mge_dcm2euler(pred_transforms[:, :3, :3], seq="xyz")
67 | t_gt = gt_transforms[:, :3, 3]
68 | t_pred = pred_transforms[:, :3, 3]
69 |
70 | r_mse = F.mean((r_gt_euler_deg - r_pred_euler_deg)**2, axis=1)
71 | r_mae = F.mean(F.abs(r_gt_euler_deg - r_pred_euler_deg), axis=1)
72 | t_mse = F.mean((t_gt - t_pred)**2, axis=1)
73 | t_mae = F.mean(F.abs(t_gt - t_pred), axis=1)
74 |
75 | r_mse = F.mean(r_mse)
76 | t_mse = F.mean(t_mse)
77 | r_mae = F.mean(r_mae)
78 | t_mae = F.mean(t_mae)
79 |
80 | # Rotation, translation errors (isotropic, i.e. doesn't depend on error
81 | # direction, which is more representative of the actual error)
82 | concatenated = se3.mge_concatenate(se3.mge_inverse(gt_transforms), pred_transforms)
83 | rot_trace = concatenated[:, 0, 0] + concatenated[:, 1, 1] + concatenated[:, 2, 2]
84 | residual_rotdeg = F.acos(F.clip(0.5 * (rot_trace - 1), -1.0, 1.0)) * 180.0 / np.pi
85 | residual_transmag = F.norm(concatenated[:, :, 3], axis=-1)
86 | err_r = F.mean(residual_rotdeg)
87 | err_t = F.mean(residual_transmag)
88 |
89 | # weighted score of isotropic errors
90 | score = err_r * 0.01 + err_t
91 |
92 | metrics = {"R_MSE": r_mse, "R_MAE": r_mae, "t_MSE": t_mse, "t_MAE": t_mae, "Err_R": err_r, "Err_t": err_t, "score": score}
93 |
94 | return metrics
95 |
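`compute_metrics` measures the isotropic rotation error through the trace identity tr(R) = 1 + 2cos θ, applied to the residual transform T_gt⁻¹ · T_pred. The NumPy sketch below computes the same errors and assumes (B, 3, 4) [R | t] inputs. `isotropic_errors` is a hypothetical helper. It inverts [R | t] in closed form as [Rᵀ | -Rᵀt] rather than calling the repository's `se3.mge_inverse` / `se3.mge_concatenate`.

```python
import numpy as np


def isotropic_errors(T_gt, T_pred):
    """Hypothetical helper: per-sample rotation (deg) and translation errors.

    T_gt, T_pred: (B, 3, 4) arrays of [R | t].
    """
    R_gt, t_gt = T_gt[:, :, :3], T_gt[:, :, 3]
    R_pred, t_pred = T_pred[:, :, :3], T_pred[:, :, 3]
    # inv([R_gt | t_gt]) @ [R_pred | t_pred] = [R_gt^T R_pred | R_gt^T (t_pred - t_gt)]
    R_gt_inv = np.transpose(R_gt, (0, 2, 1))
    R_res = R_gt_inv @ R_pred
    t_res = np.einsum("bij,bj->bi", R_gt_inv, t_pred - t_gt)
    # Rotation angle from the trace, clipped against round-off outside [-1, 1].
    trace = np.trace(R_res, axis1=1, axis2=2)
    rot_deg = np.degrees(np.arccos(np.clip(0.5 * (trace - 1.0), -1.0, 1.0)))
    trans_mag = np.linalg.norm(t_res, axis=-1)
    return rot_deg, trans_mag


def rot_z(deg):
    a = np.radians(deg)
    return np.array([[np.cos(a), -np.sin(a), 0.0], [np.sin(a), np.cos(a), 0.0], [0.0, 0.0, 1.0]])


T_gt = np.concatenate((np.eye(3), np.zeros((3, 1))), axis=1)[None]
T_pred = np.concatenate((rot_z(30.0), np.array([[0.3], [0.4], [0.0]])), axis=1)[None]
err_r, err_t = isotropic_errors(T_gt, T_pred)
score = err_r.mean() * 0.01 + err_t.mean()  # same weighting as "score" in compute_metrics
```

Here a 30° yaw error and a (0.3, 0.4, 0) translation error give Err_R ≈ 30, Err_t ≈ 0.5 and a score of about 0.8. The 0.01 weight makes one degree of rotation error count the same as 0.01 units of translation error.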
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MegEngine/FINet/5c07e7c0cafad67461c574633e7bb77257af6a96/model/__init__.py
--------------------------------------------------------------------------------
/model/module.py:
--------------------------------------------------------------------------------
1 | import megengine as mge
2 | import megengine.module as nn
3 | import megengine.functional as F
4 |
5 |
6 | class Encoder(nn.Module):
7 | def __init__(self, config):
8 | super().__init__()
9 | self.config = config
10 |
11 | # R
12 | self.R_block1 = nn.Sequential(nn.Conv1d(3, 64, 1, bias=False), nn.BatchNorm1d(64), nn.ReLU())
13 | self.R_block2 = nn.Sequential(nn.Conv1d(64, 64, 1, bias=False), nn.BatchNorm1d(64), nn.ReLU())
14 | self.R_block3 = nn.Sequential(nn.Conv1d(128, 128, 1, bias=False), nn.BatchNorm1d(128), nn.ReLU())
15 | self.R_block4 = nn.Sequential(nn.Conv1d(128, 256, 1, bias=False), nn.BatchNorm1d(256), nn.ReLU())
16 | self.R_block5 = nn.Sequential(nn.Conv1d(512, 512, 1, bias=False), nn.BatchNorm1d(512), nn.ReLU())
17 |
18 | # t
19 | self.t_block1 = nn.Sequential(nn.Conv1d(3, 64, 1, bias=False), nn.BatchNorm1d(64), nn.ReLU())
20 | self.t_block2 = nn.Sequential(nn.Conv1d(64, 64, 1, bias=False), nn.BatchNorm1d(64), nn.ReLU())
21 | self.t_block3 = nn.Sequential(nn.Conv1d(128, 128, 1, bias=False), nn.BatchNorm1d(128), nn.ReLU())
22 | self.t_block4 = nn.Sequential(nn.Conv1d(128, 256, 1, bias=False), nn.BatchNorm1d(256), nn.ReLU())
23 | self.t_block5 = nn.Sequential(nn.Conv1d(512, 512, 1, bias=False), nn.BatchNorm1d(512), nn.ReLU())
24 |
25 | def forward(self, x, mask=None):
26 | B, C, N = x.shape
27 | if self.training:
28 | rand_mask = mge.random.uniform(size=(B, 1, N)) > self.config["dropout_ratio"]
29 | else:
30 | rand_mask = 1
31 |
32 | # R stage1
33 | R_feat_output1 = self.R_block1(x)
34 | if mask is not None:
35 | R_feat_output1 = R_feat_output1 * mask
36 | R_feat_output2 = self.R_block2(R_feat_output1)
37 | if mask is not None:
38 | R_feat_output2 = R_feat_output2 * mask
39 | R_feat_glob2 = F.max(R_feat_output2, axis=-1, keepdims=True)
40 |
41 | # t stage1
42 | t_feat_output1 = self.t_block1(x)
43 | if mask is not None:
44 | t_feat_output1 = t_feat_output1 * mask
45 | t_feat_output2 = self.t_block2(t_feat_output1)
46 | if mask is not None:
47 | t_feat_output2 = t_feat_output2 * mask
48 | t_feat_glob2 = F.max(t_feat_output2, axis=-1, keepdims=True)
49 |
50 | # exchange1
51 | src_R_feat_glob2, ref_R_feat_glob2 = F.split(R_feat_glob2, 2, axis=0)
52 | src_t_feat_glob2, ref_t_feat_glob2 = F.split(t_feat_glob2, 2, axis=0)
53 | exchange_R_feat = F.concat((F.repeat(ref_R_feat_glob2, N, axis=2), F.repeat(src_R_feat_glob2, N, axis=2)), axis=0)
54 | exchange_t_feat = F.concat((F.repeat(ref_t_feat_glob2, N, axis=2), F.repeat(src_t_feat_glob2, N, axis=2)), axis=0)
55 | exchange_R_feat = F.concat((R_feat_output2, exchange_R_feat.detach()), axis=1)
56 | exchange_t_feat = F.concat((t_feat_output2, exchange_t_feat.detach()), axis=1)
57 |
58 | # R stage2
59 | R_feat_output3 = self.R_block3(exchange_R_feat)
60 | if mask is not None:
61 | R_feat_output3 = R_feat_output3 * mask
62 | R_feat_output4 = self.R_block4(R_feat_output3)
63 | if mask is not None:
64 | R_feat_output4 = R_feat_output4 * mask
65 | R_feat_glob4 = F.max(R_feat_output4, axis=-1, keepdims=True)
66 |
67 | # t stage2
68 | t_feat_output3 = self.t_block3(exchange_t_feat)
69 | if mask is not None:
70 | t_feat_output3 = t_feat_output3 * mask
71 | t_feat_output4 = self.t_block4(t_feat_output3)
72 | if mask is not None:
73 | t_feat_output4 = t_feat_output4 * mask
74 | t_feat_glob4 = F.max(t_feat_output4, axis=-1, keepdims=True)
75 |
76 | # exchange2
77 | src_R_feat_glob4, ref_R_feat_glob4 = F.split(R_feat_glob4, 2, axis=0)
78 | src_t_feat_glob4, ref_t_feat_glob4 = F.split(t_feat_glob4, 2, axis=0)
79 | exchange_R_feat = F.concat((F.repeat(ref_R_feat_glob4, N, axis=2), F.repeat(src_R_feat_glob4, N, axis=2)), axis=0)
80 | exchange_t_feat = F.concat((F.repeat(ref_t_feat_glob4, N, axis=2), F.repeat(src_t_feat_glob4, N, axis=2)), axis=0)
81 | exchange_R_feat = F.concat((R_feat_output4, exchange_R_feat.detach()), axis=1)
82 | exchange_t_feat = F.concat((t_feat_output4, exchange_t_feat.detach()), axis=1)
83 |
84 | # R stage3
85 | R_feat_output5 = self.R_block5(exchange_R_feat)
86 | if mask is not None:
87 | R_feat_output5 = R_feat_output5 * mask
88 |
89 | # t stage3
90 | t_feat_output5 = self.t_block5(exchange_t_feat)
91 | if mask is not None:
92 | t_feat_output5 = t_feat_output5 * mask
93 |
94 | # final
95 | R_final_feat_output = F.concat((R_feat_output1, R_feat_output2, R_feat_output3, R_feat_output4, R_feat_output5), axis=1)
96 | t_final_feat_output = F.concat((t_feat_output1, t_feat_output2, t_feat_output3, t_feat_output4, t_feat_output5), axis=1)
97 |
98 | R_final_glob_feat = F.max(R_final_feat_output, axis=-1, keepdims=False)
99 | t_final_glob_feat = F.max(t_final_feat_output, axis=-1, keepdims=False)
100 |
101 | R_final_feat_dropout = R_final_feat_output * rand_mask
102 | R_final_feat_dropout = F.max(R_final_feat_dropout, axis=-1, keepdims=False)
103 |
104 | t_final_feat_dropout = t_final_feat_output * rand_mask
105 | t_final_feat_dropout = F.max(t_final_feat_dropout, axis=-1, keepdims=False)
106 |
107 | return [R_final_glob_feat, t_final_glob_feat, R_final_feat_dropout, t_final_feat_dropout]
108 |
109 |
110 | class Fusion(nn.Module):
111 | def __init__(self):
112 | super().__init__()
113 |
114 | # R
115 | self.R_block1 = nn.Sequential(nn.Linear(2048, 2048), nn.BatchNorm1d(2048), nn.ReLU())
116 | self.R_block2 = nn.Sequential(nn.Linear(2048, 1024), nn.BatchNorm1d(1024), nn.ReLU())
117 | self.R_block3 = nn.Sequential(nn.Linear(1024, 1024), nn.BatchNorm1d(1024), nn.ReLU())
118 |
119 | # t
120 | self.t_block1 = nn.Sequential(nn.Linear(2048, 2048), nn.BatchNorm1d(2048), nn.ReLU())
121 | self.t_block2 = nn.Sequential(nn.Linear(2048, 1024), nn.BatchNorm1d(1024), nn.ReLU())
122 | self.t_block3 = nn.Sequential(nn.Linear(1024, 1024), nn.BatchNorm1d(1024), nn.ReLU())
123 |
124 | def forward(self, R_feat, t_feat):
125 | # R
126 | fuse_R_feat = self.R_block1(R_feat)
127 | fuse_R_feat = self.R_block2(fuse_R_feat)
128 | fuse_R_feat = self.R_block3(fuse_R_feat)
129 | # t
130 | fuse_t_feat = self.t_block1(t_feat)
131 | fuse_t_feat = self.t_block2(fuse_t_feat)
132 | fuse_t_feat = self.t_block3(fuse_t_feat)
133 |
134 | return [fuse_R_feat, fuse_t_feat]
135 |
136 |
137 | class Regression(nn.Module):
138 | def __init__(self, config):
139 | super().__init__()
140 | self.config = config
141 | if self.config["reg_R_feats"] == "tr-tr":
142 | R_in_channel = 4096
143 | elif self.config["reg_R_feats"] == "tr-r":
144 | R_in_channel = 3072
145 | elif self.config["reg_R_feats"] == "r-r":
146 | R_in_channel = 2048
147 | else:
148 | raise ValueError("Unknown reg_R_feats order {}".format(self.config["reg_R_feats"]))
149 |
150 | if self.config["reg_t_feats"] == "tr-t":
151 | t_in_channel = 3072
152 | elif self.config["reg_t_feats"] == "t-t":
153 | t_in_channel = 2048
154 | else:
155 | raise ValueError("Unknown reg_t_feats order {}".format(self.config["reg_t_feats"]))
156 |
157 | self.R_net = nn.Sequential(
158 | # block 1
159 | nn.Linear(R_in_channel, 2048),
160 | nn.BatchNorm1d(2048),
161 | nn.ReLU(),
162 | # block 2
163 | nn.Linear(2048, 1024),
164 | nn.BatchNorm1d(1024),
165 | nn.ReLU(),
166 | # block 3
167 | nn.Linear(1024, 512),
168 | nn.BatchNorm1d(512),
169 | nn.ReLU(),
170 | # block 4
171 | nn.Linear(512, 256),
172 | nn.BatchNorm1d(256),
173 | nn.ReLU(),
174 | # final fc
175 | nn.Linear(256, 4),
176 | )
177 |
178 | self.t_net = nn.Sequential(
179 | # block 1
180 | nn.Linear(t_in_channel, 2048),
181 | nn.BatchNorm1d(2048),
182 | nn.ReLU(),
183 | # block 2
184 | nn.Linear(2048, 1024),
185 | nn.BatchNorm1d(1024),
186 | nn.ReLU(),
187 | # block 3
188 | nn.Linear(1024, 512),
189 | nn.BatchNorm1d(512),
190 | nn.ReLU(),
191 | # block 4
192 | nn.Linear(512, 256),
193 | nn.BatchNorm1d(256),
194 | nn.ReLU(),
195 | # final fc
196 | nn.Linear(256, 3),
197 | )
198 |
199 | def forward(self, R_feat, t_feat):
200 |
201 | pred_quat = self.R_net(R_feat)
202 | pred_quat = F.normalize(pred_quat, axis=1)
203 | pred_translate = self.t_net(t_feat)
204 |
205 | return [pred_quat, pred_translate]
206 |
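Both exchange steps in `Encoder.forward` swap the max-pooled global features between the source half and the reference half of the 2B batch. The dropout branch max-pools after zeroing randomly chosen points. The NumPy sketch below reproduces those two operations and assumes the same (2B, C, N) layout. The function names are hypothetical. NumPy has no counterpart to `.detach()`, so the sketch omits the stop-gradient on the exchanged features.

```python
import numpy as np


def exchange_global(feat):
    """Hypothetical: append the other cloud's global feature to every point.

    feat: (2B, C, N); rows [:B] are source points, rows [B:] reference points.
    Returns (2B, 2C, N).
    """
    n = feat.shape[-1]
    glob = feat.max(axis=-1, keepdims=True)         # (2B, C, 1) global max-pool
    src_glob, ref_glob = np.split(glob, 2, axis=0)  # (B, C, 1) each
    # Source rows receive the reference global feature and vice versa.
    swapped = np.concatenate((np.repeat(ref_glob, n, axis=2),
                              np.repeat(src_glob, n, axis=2)), axis=0)
    return np.concatenate((feat, swapped), axis=1)


def dropout_global(feat, ratio, rng):
    """Hypothetical: max-pool after dropping each point with probability `ratio`."""
    two_b, _, n = feat.shape
    keep = rng.uniform(size=(two_b, 1, n)) > ratio  # True = keep the point
    return (feat * keep).max(axis=-1)               # (2B, C)


rng = np.random.default_rng(0)
B, C, N = 2, 64, 128
feat = rng.random((2 * B, C, N))  # non-negative, like post-ReLU features
out = exchange_global(feat)
d = dropout_global(feat, 0.3, rng)
print(out.shape, d.shape)  # (4, 128, 128) (4, 64)
```

Zeroing by multiplication acts as point dropout under max-pooling only because the features follow a ReLU and are non-negative. A dropped point can therefore never raise the pooled maximum.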
--------------------------------------------------------------------------------
/model/net.py:
--------------------------------------------------------------------------------
1 | import megengine as mge
2 | import megengine.module as nn
3 | import megengine.functional as F
4 | from model.module import Encoder, Fusion, Regression
5 | from common import quaternion
6 | import math
7 |
8 |
9 | class FINet(nn.Module):
10 | def __init__(self, params):
11 | super().__init__()
12 | self.params = params
13 | self.num_iter = params.titer
14 | self.net_config = params.net_config
15 | self.encoder = [Encoder(self.net_config) for _ in range(self.num_iter)]
16 | self.fusion = [Fusion() for _ in range(self.num_iter)]
17 | self.regression = [Regression(self.net_config) for _ in range(self.num_iter)]
18 |
19 | for m in self.modules():
20 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
21 | nn.init.msra_normal_(m.weight, a=math.sqrt(5))
22 | if m.bias is not None:
23 | fan_in, _ = nn.init.calculate_fan_in_and_fan_out(m.weight)
24 | bound = 1 / math.sqrt(fan_in)
25 | nn.init.uniform_(m.bias, -bound, bound)
26 | # elif isinstance(m, nn.BatchNorm1d):
27 | # nn.init.ones_(m.weight)
28 | # nn.init.zeros_(m.bias)
29 |
30 | def forward(self, data):
31 | endpoints = {}
32 |
33 | xyz_src = data["points_src"][:, :, :3]
34 | xyz_ref = data["points_ref"][:, :, :3]
35 | transform_gt = data["transform_gt"]
36 | pose_gt = data["pose_gt"]
37 |
38 | # init endpoints
39 | all_R_feats = []
40 | all_t_feats = []
41 | all_dropout_R_feats = []
42 | all_dropout_t_feats = []
43 | all_transform_pair = []
44 | all_pose_pair = []
45 |
46 | # init params
47 | B = xyz_src.shape[0]
48 | init_quat = F.tile(mge.tensor([1, 0, 0, 0], dtype="float32"), (B, 1)) # (B, 4)
49 | init_translate = F.tile(mge.tensor([0, 0, 0], dtype="float32"), (B, 1)) # (B, 3)
50 | pose_pred = F.concat((init_quat, init_translate), axis=1) # (B, 7)
51 |
52 | # rename xyz_src
53 | xyz_src_iter = F.copy(xyz_src, device=xyz_src.device)
54 |
55 |         for i in range(self.num_iter):
56 |             # encoder
57 |             encoder = self.encoder[i]
58 |             enc_input = F.concat((xyz_src_iter.transpose(0, 2, 1).detach(), xyz_ref.transpose(0, 2, 1)), axis=0)  # 2B, 3, N
59 |             enc_feats = encoder(enc_input)
60 |             src_enc_feats = [feat[:B, ...] for feat in enc_feats]
61 |             ref_enc_feats = [feat[B:, ...] for feat in enc_feats]
62 |             enc_src_R_feat = src_enc_feats[0]  # B, C
63 |             enc_src_t_feat = src_enc_feats[1]  # B, C
64 |             enc_ref_R_feat = ref_enc_feats[0]  # B, C
65 |             enc_ref_t_feat = ref_enc_feats[1]  # B, C
66 |
67 |             # GFI
68 |             src_R_cat_feat = F.concat((enc_src_R_feat, enc_ref_R_feat), axis=-1)  # B, 2C
69 |             ref_R_cat_feat = F.concat((enc_ref_R_feat, enc_src_R_feat), axis=-1)  # B, 2C
70 |             src_t_cat_feat = F.concat((enc_src_t_feat, enc_ref_t_feat), axis=-1)  # B, 2C
71 |             ref_t_cat_feat = F.concat((enc_ref_t_feat, enc_src_t_feat), axis=-1)  # B, 2C
72 |             fusion_R_input = F.concat((src_R_cat_feat, ref_R_cat_feat), axis=0)  # 2B, 2C
73 |             fusion_t_input = F.concat((src_t_cat_feat, ref_t_cat_feat), axis=0)  # 2B, 2C
74 |             fusion_feats = self.fusion[i](fusion_R_input, fusion_t_input)
75 |             src_fusion_feats = [feat[:B, ...] for feat in fusion_feats]
76 |             ref_fusion_feats = [feat[B:, ...] for feat in fusion_feats]
77 |             src_R_feat = src_fusion_feats[0]  # B, C
78 |             src_t_feat = src_fusion_feats[1]  # B, C
79 |             ref_R_feat = ref_fusion_feats[0]  # B, C
80 |             ref_t_feat = ref_fusion_feats[1]  # B, C
81 |
82 |             # R feats
83 |             if self.net_config["reg_R_feats"] == "tr-tr":
84 |                 R_feats = F.concat((src_t_feat, src_R_feat, ref_t_feat, ref_R_feat), axis=-1)  # B, 4C
85 |
86 |             elif self.net_config["reg_R_feats"] == "tr-r":
87 |                 R_feats = F.concat((src_R_feat, src_t_feat, ref_R_feat), axis=-1)  # B, 3C
88 |
89 |             elif self.net_config["reg_R_feats"] == "r-r":
90 |                 R_feats = F.concat((src_R_feat, ref_R_feat), axis=-1)  # B, 2C
91 |
92 |             else:
93 |                 raise ValueError("Unknown reg_R_feats order {}".format(self.net_config["reg_R_feats"]))
94 |
95 |             # t feats
96 |             if self.net_config["reg_t_feats"] == "tr-t":
97 |                 src_t_feats = F.concat((src_t_feat, src_R_feat, ref_t_feat), axis=-1)  # B, 3C
98 |                 ref_t_feats = F.concat((ref_t_feat, ref_R_feat, src_t_feat), axis=-1)  # B, 3C
99 |
100 |             elif self.net_config["reg_t_feats"] == "t-t":
101 |                 src_t_feats = F.concat((src_t_feat, ref_t_feat), axis=-1)  # B, 2C
102 |                 ref_t_feats = F.concat((ref_t_feat, src_t_feat), axis=-1)  # B, 2C
103 |
104 |             else:
105 |                 raise ValueError("Unknown reg_t_feats order {}".format(self.net_config["reg_t_feats"]))
106 |
107 |             # regression: translation is the difference of the predicted ref and src centers
108 |             t_feats = F.concat((src_t_feats, ref_t_feats), axis=0)  # 2B, 3C or 2B, 2C
109 |             pred_quat, pred_center = self.regression[i](R_feats, t_feats)
110 |             src_pred_center, ref_pred_center = F.split(pred_center, 2, axis=0)
111 |             pred_translate = ref_pred_center - src_pred_center
112 |             pose_pred_iter = F.concat((pred_quat, pred_translate), axis=-1)  # B, 7
113 |
114 |             # extract features for computing the transformation sensitivity loss (TSL)
115 |             xyz_src_rotated = quaternion.mge_quat_rotate(xyz_src_iter.detach(), pose_pred_iter.detach())  # B, N, 3
116 |             xyz_src_translated = xyz_src_iter.detach() + F.expand_dims(pose_pred_iter.detach()[:, 4:], axis=1)  # B, N, 3
117 |
118 |             rotated_enc_input = F.concat((xyz_src_rotated.transpose(0, 2, 1).detach(), xyz_ref.transpose(0, 2, 1)), axis=0)  # 2B, 3, N
119 |             rotated_enc_feats = encoder(rotated_enc_input)
120 |             rotated_src_enc_feats = [feat[:B, ...] for feat in rotated_enc_feats]
121 |             rotated_enc_src_R_feat = rotated_src_enc_feats[0]  # B, C
122 |             rotated_enc_src_t_feat = rotated_src_enc_feats[1]  # B, C
123 |
124 |             translated_enc_input = F.concat((xyz_src_translated.transpose(0, 2, 1).detach(), xyz_ref.transpose(0, 2, 1)),
125 |                                             axis=0)  # 2B, 3, N
126 |             translated_enc_feats = encoder(translated_enc_input)
127 |             translated_src_enc_feats = [feat[:B, ...] for feat in translated_enc_feats]
128 |             translated_enc_src_R_feat = translated_src_enc_feats[0]  # B, C
129 |             translated_enc_src_t_feat = translated_src_enc_feats[1]  # B, C
130 |
131 |             # dropout features (encoder outputs 2 and 3)
132 |             dropout_src_R_feat = src_enc_feats[2]  # B, C
133 |             dropout_src_t_feat = src_enc_feats[3]  # B, C
134 |             dropout_ref_R_feat = ref_enc_feats[2]  # B, C
135 |             dropout_ref_t_feat = ref_enc_feats[3]  # B, C
136 |
137 |             # apply the incremental pose to the source and accumulate it into pose_pred
138 |             xyz_src_iter = quaternion.mge_quat_transform(pose_pred_iter, xyz_src_iter.detach())
139 |             pose_pred = quaternion.mge_transform_pose(pose_pred.detach(), pose_pred_iter)
140 |             transform_pred = quaternion.mge_quat2mat(pose_pred)
141 |
142 |             # add endpoints at each iteration
143 |             all_R_feats.append([enc_src_R_feat, rotated_enc_src_R_feat, translated_enc_src_R_feat])
144 |             all_t_feats.append([enc_src_t_feat, rotated_enc_src_t_feat, translated_enc_src_t_feat])
145 |             all_dropout_R_feats.append([dropout_src_R_feat, enc_src_R_feat, dropout_ref_R_feat, enc_ref_R_feat])
146 |             all_dropout_t_feats.append([dropout_src_t_feat, enc_src_t_feat, dropout_ref_t_feat, enc_ref_t_feat])
147 |             all_transform_pair.append([transform_gt, transform_pred])
148 |             all_pose_pair.append([pose_gt, pose_pred])
149 |
150 |             mge.coalesce_free_memory()
151 |
152 |         # collect the final endpoints
153 |         endpoints["all_R_feats"] = all_R_feats
154 |         endpoints["all_t_feats"] = all_t_feats
155 |         endpoints["all_dropout_R_feats"] = all_dropout_R_feats
156 |         endpoints["all_dropout_t_feats"] = all_dropout_t_feats
157 |         endpoints["all_transform_pair"] = all_transform_pair
158 |         endpoints["all_pose_pair"] = all_pose_pair
159 |         endpoints["transform_pair"] = [transform_gt, transform_pred]
160 |         endpoints["pose_pair"] = [pose_gt, pose_pred]
161 |
162 |         return endpoints
163 |
164 |
165 | def fetch_net(params):
166 |     if params.net_type == "finet":
167 |         net = FINet(params)
168 |
169 |     else:
170 |         raise NotImplementedError("Unknown net_type {}".format(params.net_type))
171 |     return net
172 |
--------------------------------------------------------------------------------
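`FINet.forward` tracks the running estimate as a 7-vector `pose_pred` (a w-first unit quaternion followed by a translation, as the identity initialisation `[1, 0, 0, 0]` and the `[:, 4:]` translation slice indicate). Each iteration moves `xyz_src_iter` by the incremental pose and then folds that increment into `pose_pred` with `quaternion.mge_transform_pose`. The pure-Python sketch below shows the composition rule such an update would need to satisfy, assuming the increment is applied after the current estimate: applying `(q1, t1)` and then `(q2, t2)` equals the single pose `(q2 ⊗ q1, R(q2) t1 + t2)`. The helpers `quat_mul`, `quat_rotate`, `apply_pose`, `compose_pose` and `axis_angle_quat` are hypothetical stand-ins written for illustration; they are not the functions in `common/quaternion.py`, whose argument order and conventions may differ.

```python
import math

# Hypothetical pure-Python helpers (not common/quaternion.py).
# Quaternions are w-first tuples (w, x, y, z); a pose is a 7-tuple
# (w, x, y, z, tx, ty, tz), matching the layout of pose_pred.


def quat_mul(a, b):
    """Hamilton product a ⊗ b."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return (
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    )


def quat_rotate(q, p):
    """Rotate the 3D point p by the unit quaternion q: q ⊗ (0, p) ⊗ conj(q)."""
    conj = (q[0], -q[1], -q[2], -q[3])
    _, x, y, z = quat_mul(quat_mul(q, (0.0, *p)), conj)
    return (x, y, z)


def apply_pose(pose, p):
    """Map the point p to R p + t."""
    rotated = quat_rotate(pose[:4], p)
    return tuple(r + t for r, t in zip(rotated, pose[4:]))


def compose_pose(prev, inc):
    """Single pose equivalent to applying `prev` first and then `inc`."""
    q = quat_mul(inc[:4], prev[:4])
    t = tuple(r + d for r, d in zip(quat_rotate(inc[:4], prev[4:]), inc[4:]))
    return q + t


def axis_angle_quat(axis, angle):
    """Unit quaternion for a rotation of `angle` radians about `axis`."""
    norm = math.sqrt(sum(c * c for c in axis))
    s = math.sin(angle / 2) / norm
    return (math.cos(angle / 2), axis[0] * s, axis[1] * s, axis[2] * s)


# Start from the identity pose, as FINet.forward does, and fold in two increments.
identity = (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
step1 = axis_angle_quat((0, 0, 1), math.pi / 2) + (1.0, 0.0, 0.0)
step2 = axis_angle_quat((1, 0, 0), math.pi / 3) + (0.0, 2.0, -1.0)
pose = compose_pose(compose_pose(identity, step1), step2)

p = (0.3, -0.7, 1.2)
sequential = apply_pose(step2, apply_pose(step1, p))  # transform twice
combined = apply_pose(pose, p)  # transform once with the accumulated pose
print(max(abs(a - b) for a, b in zip(sequential, combined)))
```

Checking that the accumulated pose reproduces the sequential transforms (the printed difference should be at floating-point noise level) is a cheap sanity test for any pose-composition helper, including a batched one.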
/requirements.txt:
--------------------------------------------------------------------------------
1 | coloredlogs==15.0.1
2 | h5py==3.5.0
3 | megengine==1.7.0
4 | numpy==1.21.4
5 | scipy==1.7.2
6 | tensorboardX==2.4
7 | termcolor==1.1.0
8 | tqdm==4.62.3
9 | transforms3d==0.3.1
10 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """Train the model"""
2 |
3 | import argparse
4 | import datetime
5 | import os
6 |
7 | import megengine as mge
8 | # mge.core.set_option("async_level", 0)
9 |
10 | from megengine.optimizer import Adam, MultiStepLR, LRScheduler
11 | from megengine.autodiff import GradManager
12 | import megengine.distributed as dist
13 | from tqdm import tqdm
14 |
15 | import dataset.data_loader as data_loader
16 | import model.net as net
17 |
18 | from common import utils
19 | from common.manager import Manager
20 | from evaluate import evaluate
21 | from loss.losses import compute_losses
22 | from tensorboardX import SummaryWriter
23 |
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument("--model_dir", default="experiments/experiment_finet", help="Directory containing params.json")
26 | parser.add_argument("--restore_file",
27 |                     default=None,
28 |                     help="Optional, name of the file in model_dir containing weights to reload before training")
29 | parser.add_argument("-ow", "--only_weights", action="store_true", help="If set, load only the model weights; otherwise restore the full training status.")
30 |
31 |
32 | def train(model, manager: Manager, gm):
33 |     rank = dist.get_rank()
34 |     # reset the loss status for this epoch
35 |     manager.reset_loss_status()
36 |     # set model to training mode
37 |     model.train()
38 |     # use tqdm for the progress bar (rank 0 only)
39 |     if rank == 0:
40 |         t = tqdm(total=len(manager.dataloaders["train"]))
41 |
42 |     for i, data_batch in enumerate(manager.dataloaders["train"]):
43 |         # move to GPU if available
44 |         data_batch = utils.tensor_mge(data_batch)
45 |
46 |         # build the progress-bar description
47 |         print_str = manager.print_train_info()
48 |
49 |         with gm:
50 |             # compute model output and loss
51 |             output_batch = model(data_batch)
52 |             loss = compute_losses(output_batch, manager.params)
53 |
54 |             # update loss status (current and average loss)
55 |             manager.update_loss_status(loss=loss, split="train")
56 |             gm.backward(loss["total"])
57 |
58 |         # update the parameters with the computed gradients, then clear them
59 |         manager.optimizer.step().clear_grad()
60 |
61 |         manager.update_step()
62 |         if rank == 0:
63 |             manager.writer.add_scalar("Loss/train", manager.loss_status["total"].val, manager.step)
64 |             t.set_description(desc=print_str)
65 |             t.update()
66 |
67 |     if rank == 0:
68 |         t.close()
69 |
70 |     manager.scheduler.step()
71 |     manager.update_epoch()
72 |
73 |
74 | def train_and_evaluate(model, manager: Manager):
75 |     rank = dist.get_rank()
76 |     # reload weights from restore_file if specified
77 |     if args.restore_file is not None:
78 |         manager.load_checkpoints()
79 |
80 |     world_size = dist.get_world_size()
81 |     if world_size > 1:
82 |         dist.bcast_list_(model.parameters())
83 |         dist.bcast_list_(model.buffers())
84 |
85 |     gm = GradManager().attach(
86 |         model.parameters(),
87 |         callbacks=dist.make_allreduce_cb("SUM") if world_size > 1 else None,
88 |     )
89 |
90 |     for epoch in range(manager.params.num_epochs):
91 |         # train for one epoch (one full pass over the training set)
92 |         train(model, manager, gm)
93 |
94 |         # evaluate for one epoch on the validation set
95 |         evaluate(model, manager)
96 |
97 |         # save the best model weights according to params.major_metric
98 |         if rank == 0:
99 |             manager.check_best_save_last_checkpoints(save_latest_freq=100, save_best_after=200)
100 |
101 |
102 | def main(params):
103 |     # DTR support
104 |     # mge.dtr.eviction_threshold = "5GB"
105 |     # mge.dtr.enable()
106 |
107 |     # Set the logger
108 |     logger = utils.set_logger(os.path.join(params.model_dir, "train.log"))
109 |
110 |     # Set the tensorboard writer
111 |     tb_dir = os.path.join(params.model_dir, "summary")
112 |     os.makedirs(tb_dir, exist_ok=True)
113 |     writer = SummaryWriter(log_dir=tb_dir)
114 |
115 |     # fetch dataloaders
116 |     dataloaders = data_loader.fetch_dataloader(params)
117 |
118 |     # Define the model and optimizer
119 |     model = net.fetch_net(params)
120 |
121 |     optimizer = Adam(model.parameters(), lr=params.learning_rate)
122 |     scheduler = MultiStepLR(optimizer, milestones=[])  # no milestones: constant learning rate
123 |
124 |     # initial status for checkpoint manager
125 |     manager = Manager(model=model,
126 |                       optimizer=optimizer,
127 |                       scheduler=scheduler,
128 |                       params=params,
129 |                       dataloaders=dataloaders,
130 |                       writer=writer,
131 |                       logger=logger)
132 |
133 |     # Train the model
134 |     utils.master_logger(logger, "Starting training for {} epoch(s)".format(params.num_epochs))
135 |
136 |     train_and_evaluate(model, manager)
137 |
138 |
139 | if __name__ == "__main__":
140 |     # Load the parameters from json file
141 |     args = parser.parse_args()
142 |     json_path = os.path.join(args.model_dir, "params.json")
143 |     assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
144 |     params = utils.Params(json_path)
145 |     params.update(vars(args))
146 |
147 |     train_proc = dist.launcher(main) if mge.device.get_device_count("gpu") > 1 else main
148 |     train_proc(params)
149 |
--------------------------------------------------------------------------------
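`main()` creates `MultiStepLR(optimizer, milestones=[])`, and `train()` calls `manager.scheduler.step()` once per epoch, so with an empty milestone list the learning rate should stay at `params.learning_rate` for the whole run. The plain-Python sketch below restates the usual multi-step rule (multiply the base rate by `gamma` once for every milestone already reached) to show what a non-empty list would change. `multistep_lr` is a hypothetical helper, not MegEngine's scheduler; the `gamma=0.1` default and the `bisect_right` counting are assumptions modelled on common MultiStepLR implementations, so check MegEngine's own scheduler for its exact off-by-one behavior.

```python
from bisect import bisect_right


def multistep_lr(base_lr, epoch, milestones, gamma=0.1):
    """Hypothetical helper: learning rate after `epoch` scheduler steps.

    The base rate is multiplied by `gamma` once for every milestone <= epoch.
    """
    num_decays = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** num_decays


base_lr = 1e-4  # illustrative value, not taken from params.json

# milestones=[] as in train.py: the rate never changes.
constant = [multistep_lr(base_lr, epoch, []) for epoch in range(6)]

# A hypothetical schedule that decays at epochs 2 and 4.
decayed = [multistep_lr(base_lr, epoch, [2, 4]) for epoch in range(6)]

print(constant)
print(decayed)
```

Passing, for example, `milestones=[2, 4]` to the scheduler in `main()` would, under these assumptions, divide the learning rate by 10 at each of those epochs.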