├── .gitignore ├── README.md ├── custom_evaluate.py ├── custom_train.py ├── data ├── CustomDataset.py ├── ModelNet40.py └── __init__.py ├── loss ├── __init__.py ├── cuda │ └── emd_torch │ │ ├── pkg │ │ ├── emd_loss_layer.py │ │ ├── include │ │ │ ├── cuda │ │ │ │ └── emd.cuh │ │ │ ├── cuda_helper.h │ │ │ └── emd.h │ │ ├── layer │ │ │ ├── __init__.py │ │ │ └── emd_loss_layer.py │ │ └── src │ │ │ ├── cuda │ │ │ └── emd.cu │ │ │ └── emd.cpp │ │ └── setup.py └── earth_mover_distance.py ├── metrics ├── __init__.py ├── helper.py ├── metrics.py └── vis.png ├── modelnet40_evaluate.py ├── modelnet40_train.py ├── models ├── __init__.py ├── benchmark.py ├── fgr.py └── icp.py ├── requirements.txt └── utils ├── __init__.py ├── dist.py ├── format.py ├── process.py └── time.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | .idea/ 5 | test*.py 6 | checkpoints/ 7 | data/ILSVRC2012_img_val 8 | model/.DS_Store 9 | .DS_Store 10 | tmp.py 11 | work_dirs/ 12 | work_dirs* 13 | loss/cuda/emd_torch/PyTorch_EMD.egg-info/ 14 | loss/cuda/emd_torch/build/ 15 | loss/cuda/emd_torch/dist/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Our recent registration works**: 2 | 3 | - NgeNet: [paper](https://arxiv.org/pdf/2201.12094.pdf), [code](https://github.com/zhulf0804/NgeNet). We achieved **SoTA RR** (Registration Recall) of 92.9% on 3DMatch. 4 | - ROPNet: [paper](https://arxiv.org/pdf/2107.02583.pdf), [code](https://github.com/zhulf0804/ROPNet). Our solution based on ROPNet and [OverlapPredator](https://github.com/prs-eth/OverlapPredator) won second place in the [MVP Registration Challenge (ICCV Workshop 2021)](https://mvp-dataset.github.io/MVP/Registration.html). 
[[Technical Report](https://arxiv.org/pdf/2112.12053.pdf)] 5 | 6 | ## Introduction 7 | 8 | A simple point cloud registration pipeline based on deep learning. For details, please see this [Zhihu blog](https://zhuanlan.zhihu.com/p/289620126). 9 | 10 | ![](./metrics/vis.png) 11 | 12 | ## Install 13 | - requirements.txt: `pip install -r requirements.txt` 14 | - open3d-python==0.9.0.0: `python -m pip install open3d==0.9` 15 | - emd loss: `cd loss/cuda/emd_torch && python setup.py install` 16 | 17 | 18 | ## Start 19 | - Download the data from [[here](https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip), `435M`] 20 | - evaluate and visualize (first download the pretrained checkpoint [[Complete](https://pan.baidu.com/s/1L7fdgMAHYSDEbCNwLM1Crg), **pwd**: `c4z7`, `16.09 M`] or [[Partial](https://pan.baidu.com/s/1b1kRlKsxqmUwZZ7XJmcK4w), **pwd**: `pcno`, `16.09 M`]) 21 | 22 | ``` 23 | # Iterative Benchmark 24 | python modelnet40_evaluate.py --root your_data_path/modelnet40_ply_hdf5_2048 --checkpoint your_ckpt_path/test_min_loss.pth --cuda 25 | 26 | # Visualization 27 | # python modelnet40_evaluate.py --root your_data_path/modelnet40_ply_hdf5_2048 --checkpoint your_ckpt_path/test_min_loss.pth --show 28 | 29 | # ICP 30 | # python modelnet40_evaluate.py --root your_data_path/modelnet40_ply_hdf5_2048 --method icp 31 | 32 | # FGR 33 | # python modelnet40_evaluate.py --root your_data_path/modelnet40_ply_hdf5_2048 --method fgr --normal 34 | 35 | ``` 36 | 37 | - train 38 | 39 | ``` 40 | CUDA_VISIBLE_DEVICES=0 python modelnet40_train.py --root your_data_path/modelnet40_ply_hdf5_2048 41 | ``` 42 | 43 | ## Experiments 44 | 45 | - Point-to-Point Correspondences (**the R error is large due to EMDLoss; see [here](https://zhuanlan.zhihu.com/p/289620126)**) 46 | 47 | | Method | isotropic R | isotropic t | anisotropic R(mse, mae) | anisotropic t(mse, mae) | time(s) | 48 | | :---: | :---: | :---: | :---: | :---: | :---: | 49 | | ICP | 11.44 | 0.16 | 17.64(5.48) | 0.22(0.07) | 0.07 | 50 | 
| **FGR** | **0.01** | **0.00** | **0.07(0.00)** | **0.00(0.00)** | 0.19 | 51 | | IBenchmark | 5.68 | 0.07 | 9.77(2.69) | 0.12(0.03) | **0.022** | 52 | | **IBenchmark + ICP** | 3.65 | 0.04 | 9.22(1.66) | 0.11(0.02) | | 53 | 54 | - Noise Data (infer_npts = 1024) 55 | 56 | | Method | isotropic R | isotropic t | anisotropic R(mse, mae) | anisotropic t(mse, mae) | 57 | | :---: | :---: | :---: | :---: | :---: | 58 | | ICP | 12.14 | 0.17 | 18.32(5.86) | 0.23(0.08) | 59 | | FGR | **4.27** | **0.06** | 11.55(2.43) | **0.09(0.03)** | 60 | | IBenchmark | 6.25 | 0.08 | 9.28(2.94) | 0.12(0.04) | 61 | | **IBenchmark + ICP** | 5.10 | 0.07 | **10.51(2.39)** | 0.13(0.03) | 62 | 63 | - Partial-to-Complete Registration (infer_npts = 1024) 64 | 65 | | Method | isotropic R | isotropic t | anisotropic R(mse, mae) | anisotropic t(mse, mae) | 66 | | :---: | :---: | :---: | :---: | :---: | 67 | | ICP | 21.33 | 0.32 | 22.83(10.51) | 0.31(0.15) | 68 | | FGR | 9.49 | **0.12** | 19.51(5.58) | **0.17(0.06)** | 69 | | IBenchmark | 15.02 | 0.22 | 15.78(7.45) | 0.21(0.10) | 70 | | **IBenchmark + ICP** | **9.21** | 0.13 | **14.73(4.43)** | 0.18(0.06) | 71 | 72 | **Note**: 73 | - For detailed information on the metrics, please refer to [RPM-Net](https://arxiv.org/pdf/2003.13479.pdf) [CVPR 2020]. 74 | 75 | ## Train on Your Own Data 76 | - Prepare the data in the following structure: 77 | ``` 78 | |- CustomData(dir) 79 | |- train_data(dir) 80 | - train1.pcd 81 | - train2.pcd 82 | - ... 83 | |- val_data(dir) 84 | - val1.pcd 85 | - val2.pcd 86 | - ... 
87 | ``` 88 | - Train 89 | ``` 90 | python custom_train.py --root your_datapath/CustomData --train_npts 2048 91 | # Note: train_npts depends on your dataset 92 | ``` 93 | - Evaluate 94 | ``` 95 | # Evaluate, infer_npts depends on your dataset 96 | python custom_evaluate.py --root your_datapath/CustomData --infer_npts 2048 --checkpoint work_dirs/models/checkpoints/test_min_loss.pth --cuda 97 | 98 | # Visualize, infer_npts depends on your dataset 99 | python custom_evaluate.py --root your_datapath/CustomData --infer_npts 2048 --checkpoint work_dirs/models/checkpoints/test_min_loss.pth --show 100 | ``` 101 | 102 | ## Acknowledgements 103 | 104 | Thanks to the open-source [code](https://github.com/vinits5/pcrnet_pytorch), which helped me train the point cloud registration network successfully. 105 | -------------------------------------------------------------------------------- /custom_evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import open3d as o3d 4 | import random 5 | import time 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | from data import CustomData 10 | from models import IterativeBenchmark, icp 11 | from metrics import compute_metrics, summary_metrics, print_metrics 12 | from utils import npy2pcd, pcd2npy 13 | 14 | 15 | def config_params(): 16 | parser = argparse.ArgumentParser(description='Configuration Parameters') 17 | parser.add_argument('--root', required=True, help='the data path') 18 | parser.add_argument('--infer_npts', type=int, required=True, 19 | help='the number of points in each point cloud for inference') 20 | parser.add_argument('--in_dim', type=int, default=3, 21 | help='3 for (x, y, z) or 6 for (x, y, z, nx, ny, nz)') 22 | parser.add_argument('--niters', type=int, default=8, 23 | help='the number of iterations in one model forward pass') 24 | parser.add_argument('--gn', action='store_true', 25 | help='whether to use group 
normalization') 26 | parser.add_argument('--checkpoint', default='', 27 | help='the path to the trained checkpoint') 28 | parser.add_argument('--method', default='benchmark', 29 | help='choice=[benchmark, icp]') 30 | parser.add_argument('--cuda', action='store_true', 31 | help='whether to use the cuda') 32 | parser.add_argument('--show', action='store_true', 33 | help='whether to visualize') 34 | args = parser.parse_args() 35 | return args 36 | 37 | 38 | def evaluate_benchmark(args, test_loader): 39 | model = IterativeBenchmark(in_dim=args.in_dim, 40 | niters=args.niters, 41 | gn=args.gn) 42 | if args.cuda: 43 | model = model.cuda() 44 | model.load_state_dict(torch.load(args.checkpoint)) 45 | else: 46 | model.load_state_dict(torch.load(args.checkpoint, map_location=torch.device('cpu'))) 47 | model.eval() 48 | 49 | dura = [] 50 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = [], [], [], [], [], [] 51 | with torch.no_grad(): 52 | for i, (ref_cloud, src_cloud, gtR, gtt) in tqdm(enumerate(test_loader)): 53 | if args.cuda: 54 | ref_cloud, src_cloud, gtR, gtt = ref_cloud.cuda(), src_cloud.cuda(), \ 55 | gtR.cuda(), gtt.cuda() 56 | tic = time.time() 57 | R, t, pred_ref_cloud = model(src_cloud.permute(0, 2, 1).contiguous(), 58 | ref_cloud.permute(0, 2, 1).contiguous()) 59 | toc = time.time() 60 | dura.append(toc - tic) 61 | cur_r_mse, cur_r_mae, cur_t_mse, cur_t_mae, cur_r_isotropic, \ 62 | cur_t_isotropic = compute_metrics(R, t, gtR, gtt) 63 | r_mse.append(cur_r_mse) 64 | r_mae.append(cur_r_mae) 65 | t_mse.append(cur_t_mse) 66 | t_mae.append(cur_t_mae) 67 | r_isotropic.append(cur_r_isotropic.cpu().detach().numpy()) 68 | t_isotropic.append(cur_t_isotropic.cpu().detach().numpy()) 69 | 70 | if args.show: 71 | ref_cloud = torch.squeeze(ref_cloud).cpu().numpy() 72 | src_cloud = torch.squeeze(src_cloud).cpu().numpy() 73 | pred_ref_cloud = torch.squeeze(pred_ref_cloud[-1]).cpu().numpy() 74 | pcd1 = npy2pcd(ref_cloud, 0) 75 | pcd2 = npy2pcd(src_cloud, 1) 76 | pcd3 = 
npy2pcd(pred_ref_cloud, 2) 77 | o3d.visualization.draw_geometries([pcd1, pcd2, pcd3]) 78 | 79 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 80 | summary_metrics(r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic) 81 | 82 | return dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic 83 | 84 | 85 | def evaluate_icp(args, test_loader): 86 | dura = [] 87 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = [], [], [], [], [], [] 88 | for i, (ref_cloud, src_cloud, gtR, gtt) in tqdm(enumerate(test_loader)): 89 | if args.cuda: 90 | ref_cloud, src_cloud, gtR, gtt = ref_cloud.cuda(), src_cloud.cuda(), \ 91 | gtR.cuda(), gtt.cuda() 92 | 93 | ref_cloud = torch.squeeze(ref_cloud).cpu().numpy() 94 | src_cloud = torch.squeeze(src_cloud).cpu().numpy() 95 | 96 | tic = time.time() 97 | R, t, pred_ref_cloud = icp(npy2pcd(src_cloud), npy2pcd(ref_cloud)) 98 | toc = time.time() 99 | R = torch.from_numpy(np.expand_dims(R, 0)).to(gtR) 100 | t = torch.from_numpy(np.expand_dims(t, 0)).to(gtt) 101 | dura.append(toc - tic) 102 | 103 | cur_r_mse, cur_r_mae, cur_t_mse, cur_t_mae, cur_r_isotropic, \ 104 | cur_t_isotropic = compute_metrics(R, t, gtR, gtt) 105 | r_mse.append(cur_r_mse) 106 | r_mae.append(cur_r_mae) 107 | t_mse.append(cur_t_mse) 108 | t_mae.append(cur_t_mae) 109 | r_isotropic.append(cur_r_isotropic.cpu().detach().numpy()) 110 | t_isotropic.append(cur_t_isotropic.cpu().detach().numpy()) 111 | 112 | if args.show: 113 | pcd1 = npy2pcd(ref_cloud, 0) 114 | pcd2 = npy2pcd(src_cloud, 1) 115 | pcd3 = pred_ref_cloud 116 | o3d.visualization.draw_geometries([pcd1, pcd2, pcd3]) 117 | 118 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 119 | summary_metrics(r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic) 120 | 121 | return dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic 122 | 123 | 124 | if __name__ == '__main__': 125 | seed = 222 126 | random.seed(seed) 127 | np.random.seed(seed) 128 | 129 | args = config_params() 130 | 131 | test_set = 
CustomData(args.root, args.infer_npts, False) 132 | test_loader = DataLoader(test_set, batch_size=1, shuffle=False) 133 | 134 | if args.method == 'benchmark': 135 | dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 136 | evaluate_benchmark(args, test_loader) 137 | print_metrics(args.method, 138 | dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic) 139 | elif args.method == 'icp': 140 | dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 141 | evaluate_icp(args, test_loader) 142 | print_metrics(args.method, dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, 143 | t_isotropic) 144 | else: 145 | raise NotImplementedError -------------------------------------------------------------------------------- /custom_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import open3d 4 | import os 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | from torch.utils.tensorboard import SummaryWriter 9 | from tqdm import tqdm 10 | 11 | from data import CustomData 12 | from models import IterativeBenchmark 13 | from loss import EMDLosspy 14 | from metrics import compute_metrics, summary_metrics, print_train_info 15 | from utils import time_calc 16 | 17 | 18 | def setup_seed(seed): 19 | torch.backends.cudnn.deterministic = True 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | np.random.seed(seed) 23 | 24 | 25 | def config_params(): 26 | parser = argparse.ArgumentParser(description='Configuration Parameters') 27 | ## dataset 28 | parser.add_argument('--root', required=True, help='the data path') 29 | parser.add_argument('--train_npts', type=int, required=True, 30 | help='the points number of each pc for training') 31 | ## models training 32 | parser.add_argument('--seed', type=int, default=1234) 33 | parser.add_argument('--gn', action='store_true', 34 | help='whether to use group normalization') 35 | 
parser.add_argument('--epoches', type=int, default=400) 36 | parser.add_argument('--batchsize', type=int, default=16) 37 | parser.add_argument('--num_workers', type=int, default=4) 38 | parser.add_argument('--in_dim', type=int, default=3, 39 | help='3 for (x, y, z) or 6 for (x, y, z, nx, ny, nz)') 40 | parser.add_argument('--niters', type=int, default=8, 41 | help='the number of iterations in one model forward pass') 42 | parser.add_argument('--lr', type=float, default=0.0001, 43 | help='initial learning rate') 44 | parser.add_argument('--milestones', type=int, nargs='+', default=[50, 250], 45 | help='the epochs at which the lr decays') 46 | parser.add_argument('--gamma', type=float, default=0.1, 47 | help='lr decays to gamma * lr at every milestone') 48 | # logs 49 | parser.add_argument('--saved_path', default='work_dirs/models', 50 | help='the path to save training logs and checkpoints') 51 | parser.add_argument('--saved_frequency', type=int, default=10, 52 | help='the frequency to save the logs and checkpoints') 53 | args = parser.parse_args() 54 | return args 55 | 56 | 57 | def compute_loss(ref_cloud, pred_ref_clouds, loss_fn): 58 | losses = [] 59 | discount_factor = 0.5 60 | for i in range(len(pred_ref_clouds)): 61 | loss = loss_fn(ref_cloud[..., :3].contiguous(), 62 | pred_ref_clouds[i][..., :3].contiguous()) 63 | losses.append(discount_factor**(len(pred_ref_clouds) - i)*loss) 64 | return torch.sum(torch.stack(losses)) 65 | 66 | 67 | @time_calc 68 | def train_one_epoch(train_loader, model, loss_fn, optimizer): 69 | losses = [] 70 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = [], [], [], [], [], [] 71 | for ref_cloud, src_cloud, gtR, gtt in tqdm(train_loader): 72 | ref_cloud, src_cloud, gtR, gtt = ref_cloud.cuda(), src_cloud.cuda(), \ 73 | gtR.cuda(), gtt.cuda() 74 | optimizer.zero_grad() 75 | R, t, pred_ref_clouds = model(src_cloud.permute(0, 2, 1).contiguous(), 76 | ref_cloud.permute(0, 2, 1).contiguous()) 77 | loss = compute_loss(ref_cloud, 
pred_ref_clouds, loss_fn) 78 | loss.backward() 79 | optimizer.step() 
80 | 81 | cur_r_mse, cur_r_mae, cur_t_mse, cur_t_mae, cur_r_isotropic, \ 82 | cur_t_isotropic = compute_metrics(R, t, gtR, gtt) 83 | losses.append(loss.item()) 84 | r_mse.append(cur_r_mse) 85 | r_mae.append(cur_r_mae) 86 | t_mse.append(cur_t_mse) 87 | t_mae.append(cur_t_mae) 88 | r_isotropic.append(cur_r_isotropic.cpu().detach().numpy()) 89 | t_isotropic.append(cur_t_isotropic.cpu().detach().numpy()) 90 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 91 | summary_metrics(r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic) 92 | results = { 93 | 'loss': np.mean(losses), 94 | 'r_mse': r_mse, 95 | 'r_mae': r_mae, 96 | 't_mse': t_mse, 97 | 't_mae': t_mae, 98 | 'r_isotropic': r_isotropic, 99 | 't_isotropic': t_isotropic 100 | } 101 | return results 102 | 103 | 104 | @time_calc 105 | def test_one_epoch(test_loader, model, loss_fn): 106 | model.eval() 107 | losses = [] 108 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = [], [], [], [], [], [] 109 | with torch.no_grad(): 110 | for ref_cloud, src_cloud, gtR, gtt in tqdm(test_loader): 111 | ref_cloud, src_cloud, gtR, gtt = ref_cloud.cuda(), src_cloud.cuda(), \ 112 | gtR.cuda(), gtt.cuda() 113 | R, t, pred_ref_clouds = model(src_cloud.permute(0, 2, 1).contiguous(), 114 | ref_cloud.permute(0, 2, 1).contiguous()) 115 | loss = compute_loss(ref_cloud, pred_ref_clouds, loss_fn) 116 | cur_r_mse, cur_r_mae, cur_t_mse, cur_t_mae, cur_r_isotropic, \ 117 | cur_t_isotropic = compute_metrics(R, t, gtR, gtt) 118 | 119 | losses.append(loss.item()) 120 | r_mse.append(cur_r_mse) 121 | r_mae.append(cur_r_mae) 122 | t_mse.append(cur_t_mse) 123 | t_mae.append(cur_t_mae) 124 | r_isotropic.append(cur_r_isotropic.cpu().detach().numpy()) 125 | t_isotropic.append(cur_t_isotropic.cpu().detach().numpy()) 126 | model.train() 127 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 128 | summary_metrics(r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic) 129 | results = { 130 | 'loss': np.mean(losses), 131 | 'r_mse': r_mse, 
132 | 'r_mae': r_mae, 133 | 't_mse': t_mse, 134 | 't_mae': t_mae, 135 | 'r_isotropic': r_isotropic, 136 | 't_isotropic': t_isotropic 137 | } 138 | return results 139 | 140 | 141 | def main(): 142 | args = config_params() 143 | print(args) 144 | 145 | setup_seed(args.seed) 146 | if not os.path.exists(args.saved_path): 147 | os.makedirs(args.saved_path) 148 | summary_path = os.path.join(args.saved_path, 'summary') 149 | if not os.path.exists(summary_path): 150 | os.makedirs(summary_path) 151 | checkpoints_path = os.path.join(args.saved_path, 'checkpoints') 152 | if not os.path.exists(checkpoints_path): 153 | os.makedirs(checkpoints_path) 154 | 155 | train_set = CustomData(args.root, args.train_npts) 156 | test_set = CustomData(args.root, args.train_npts, False) 157 | train_loader = DataLoader(train_set, batch_size=args.batchsize, 158 | shuffle=True, num_workers=args.num_workers) 159 | test_loader = DataLoader(test_set, batch_size=args.batchsize, shuffle=False, 160 | num_workers=args.num_workers) 161 | model = IterativeBenchmark(in_dim=args.in_dim, niters=args.niters, gn = args.gn) 162 | model = model.cuda() 163 | loss_fn = EMDLosspy() 164 | loss_fn = loss_fn.cuda() 165 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 166 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 167 | milestones=args.milestones, 168 | gamma=args.gamma, 169 | last_epoch=-1) 170 | 171 | writer = SummaryWriter(summary_path) 172 | 173 | test_min_loss, test_min_r_mse_error, test_min_rot_error = \ 174 | float('inf'), float('inf'), float('inf') 175 | for epoch in range(args.epoches): 176 | print('=' * 20, epoch + 1, '=' * 20) 177 | train_results = train_one_epoch(train_loader, model, loss_fn, optimizer) 178 | print_train_info(train_results) 179 | test_results = test_one_epoch(test_loader, model, loss_fn) 180 | print_train_info(test_results) 181 | 182 | if epoch % args.saved_frequency == 0: 183 | writer.add_scalar('Loss/train', train_results['loss'], epoch + 1) 184 | 
writer.add_scalar('Loss/test', test_results['loss'], epoch + 1) 185 | writer.add_scalar('RError/train', train_results['r_mse'], epoch + 1) 186 | writer.add_scalar('RError/test', test_results['r_mse'], epoch + 1) 187 | writer.add_scalar('rotError/train', train_results['r_isotropic'], 188 | epoch + 1) 189 | writer.add_scalar('rotError/test', test_results['r_isotropic'], 190 | epoch + 1) 191 | writer.add_scalar('Lr', optimizer.param_groups[0]['lr'], epoch + 1) 192 | test_loss, test_r_error, test_rot_error = \ 193 | test_results['loss'], test_results['r_mse'], test_results[ 194 | 'r_isotropic'] 195 | if test_loss < test_min_loss: 196 | saved_path = os.path.join(checkpoints_path, "test_min_loss.pth") 197 | torch.save(model.state_dict(), saved_path) 198 | test_min_loss = test_loss 199 | if test_r_error < test_min_r_mse_error: 200 | saved_path = os.path.join(checkpoints_path, 201 | "test_min_rmse_error.pth") 202 | torch.save(model.state_dict(), saved_path) 203 | test_min_r_mse_error = test_r_error 204 | if test_rot_error < test_min_rot_error: 205 | saved_path = os.path.join(checkpoints_path, 206 | "test_min_rot_error.pth") 207 | torch.save(model.state_dict(), saved_path) 208 | test_min_rot_error = test_rot_error 209 | scheduler.step() 210 | 211 | 212 | if __name__ == '__main__': 213 | main() -------------------------------------------------------------------------------- /data/CustomDataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | from utils import readpcd 7 | from utils import pc_normalize, random_select_points, shift_point_cloud, \ 8 | jitter_point_cloud, generate_random_rotation_matrix, \ 9 | generate_random_tranlation_vector, transform 10 | 11 | 12 | class CustomData(Dataset): 13 | def __init__(self, root, npts, train=True): 14 | super(CustomData, self).__init__() 15 | dirname = 'train_data' if train else 'val_data' 16 | 
path = os.path.join(root, dirname) 17 | self.train = train 18 | self.files = [os.path.join(path, item) for item in sorted(os.listdir(path))] 19 | self.npts = npts 20 | 21 | def __getitem__(self, item): 22 | file = self.files[item] 23 | ref_cloud = readpcd(file, rtype='npy') 24 | ref_cloud = random_select_points(ref_cloud, m=self.npts) 25 | ref_cloud = pc_normalize(ref_cloud) 26 | R, t = generate_random_rotation_matrix(-20, 20), \ 27 | generate_random_tranlation_vector(-0.5, 0.5) 28 | src_cloud = transform(ref_cloud, R, t) 29 | if self.train: 30 | ref_cloud = jitter_point_cloud(ref_cloud) 31 | src_cloud = jitter_point_cloud(src_cloud) 32 | return ref_cloud, src_cloud, R, t 33 | 34 | def __len__(self): 35 | return len(self.files) -------------------------------------------------------------------------------- /data/ModelNet40.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import open3d as o3d 4 | import os 5 | import torch 6 | 7 | from torch.utils.data import Dataset 8 | import sys 9 | 10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 11 | ROOR_DIR = os.path.dirname(BASE_DIR) 12 | sys.path.append(ROOR_DIR) 13 | from utils import pc_normalize, random_select_points, shift_point_cloud, \ 14 | jitter_point_cloud, generate_random_rotation_matrix, \ 15 | generate_random_tranlation_vector, transform, random_crop 16 | 17 | 18 | class ModelNet40(Dataset): 19 | def __init__(self, root, npts, train=True, normal=False, mode='clean'): 20 | super(ModelNet40, self).__init__() 21 | self.npts = npts 22 | self.train = train 23 | self.normal = normal 24 | self.mode = mode 25 | files = [os.path.join(root, 'ply_data_train{}.h5'.format(i)) 26 | for i in range(5)] 27 | if not train: 28 | files = [os.path.join(root, 'ply_data_test{}.h5'.format(i)) 29 | for i in range(2)] 30 | self.data, self.labels = self.decode_h5(files) 31 | 32 | def decode_h5(self, files): 33 | points, normal, label = [], [], [] 34 
| for file in files: 35 | f = h5py.File(file, 'r') 36 | cur_points = f['data'][:].astype(np.float32) 37 | cur_normal = f['normal'][:].astype(np.float32) 38 | cur_label = f['label'][:].astype(np.float32) 39 | points.append(cur_points) 40 | normal.append(cur_normal) 41 | label.append(cur_label) 42 | points = np.concatenate(points, axis=0) 43 | normal = np.concatenate(normal, axis=0) 44 | data = np.concatenate([points, normal], axis=-1).astype(np.float32) 45 | label = np.concatenate(label, axis=0) 46 | return data, label 47 | 48 | def compose(self, mode, item): 49 | ref_cloud = self.data[item, ...] 50 | R, t = generate_random_rotation_matrix(), generate_random_tranlation_vector() 51 | if mode == 'clean': 52 | ref_cloud = random_select_points(ref_cloud, m=self.npts) 53 | src_cloud_points = transform(ref_cloud[:, :3], R, t) 54 | src_cloud_normal = transform(ref_cloud[:, 3:], R) 55 | src_cloud = np.concatenate([src_cloud_points, src_cloud_normal], 56 | axis=-1) 57 | return src_cloud, ref_cloud, R, t 58 | elif mode == 'partial': 59 | source_cloud = random_select_points(ref_cloud, m=self.npts) 60 | ref_cloud = random_select_points(ref_cloud, m=self.npts) 61 | src_cloud_points = transform(source_cloud[:, :3], R, t) 62 | src_cloud_normal = transform(source_cloud[:, 3:], R) 63 | src_cloud = np.concatenate([src_cloud_points, src_cloud_normal], 64 | axis=-1) 65 | src_cloud = random_crop(src_cloud, p_keep=0.7) 66 | return src_cloud, ref_cloud, R, t 67 | elif mode == 'noise': 68 | source_cloud = random_select_points(ref_cloud, m=self.npts) 69 | ref_cloud = random_select_points(ref_cloud, m=self.npts) 70 | src_cloud_points = transform(source_cloud[:, :3], R, t) 71 | src_cloud_normal = transform(source_cloud[:, 3:], R) 72 | src_cloud = np.concatenate([src_cloud_points, src_cloud_normal], 73 | axis=-1) 74 | return src_cloud, ref_cloud, R, t 75 | else: 76 | raise NotImplementedError 77 | 78 | def __getitem__(self, item): 79 | src_cloud, ref_cloud, R, t = self.compose(mode=self.mode, 
item=item) 80 | if self.train or self.mode == 'noise' or self.mode == 'partial': 81 | ref_cloud[:, :3] = jitter_point_cloud(ref_cloud[:, :3]) 82 | src_cloud[:, :3] = jitter_point_cloud(src_cloud[:, :3]) 83 | if not self.normal: 84 | ref_cloud, src_cloud = ref_cloud[:, :3], src_cloud[:, :3] 85 | return ref_cloud, src_cloud, R, t 86 | 87 | def __len__(self): 88 | return len(self.data) -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .ModelNet40 import ModelNet40 2 | from .CustomDataset import CustomData -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .earth_mover_distance import EMDLosspy 2 | -------------------------------------------------------------------------------- /loss/cuda/emd_torch/pkg/emd_loss_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import _emd_ext._emd as emd 5 | 6 | 7 | class EMDFunction(torch.autograd.Function): 8 | @staticmethod 9 | def forward(self, xyz1, xyz2): 10 | cost, match = emd.emd_forward(xyz1, xyz2) 11 | self.save_for_backward(xyz1, xyz2, match) 12 | return cost 13 | 14 | 15 | @staticmethod 16 | def backward(self, grad_output): 17 | xyz1, xyz2, match = self.saved_tensors 18 | grad_xyz1, grad_xyz2 = emd.emd_backward(xyz1, xyz2, match) 19 | return grad_xyz1 * grad_output.view(-1, 1, 1), grad_xyz2 * grad_output.view(-1, 1, 1) # chain rule: scale the per-batch gradients by the upstream gradient 20 | 21 | 22 | 23 | 24 | class EMDLoss(nn.Module): 25 | ''' 26 | Computes the (approximate) Earth Mover's Distance between two point sets. 
27 | 28 | IMPLEMENTATION LIMITATIONS: 29 | - Double tensors must have <=11 dimensions 30 | - Float tensors must have <=23 dimensions 31 | This is due to the use of CUDA shared memory in the computation. This shared memory is limited by the hardware to 48kB. 
32 | ''' 33 | 34 | def __init__(self): 35 | super(EMDLoss, self).__init__() 36 | 37 | def forward(self, xyz1, xyz2): 38 | 39 | assert xyz1.shape[-1] == xyz2.shape[-1], 'Both point sets must have the same dimensions!' 40 | assert xyz1.shape[1] == xyz2.shape[1], 'Both Point Clouds must have same number of points in it.' 41 | return EMDFunction.apply(xyz1, xyz2) -------------------------------------------------------------------------------- /loss/cuda/emd_torch/pkg/include/cuda/emd.cuh: -------------------------------------------------------------------------------- 1 | #ifndef EMD_CUH_ 2 | #define EMD_CUH_ 3 | 4 | #include "cuda_helper.h" 5 | 6 | template 7 | __global__ void approxmatch(const int b, const int n, const int m, const T * __restrict__ xyz1, const T * __restrict__ xyz2, T * __restrict__ match, T * temp){ 8 | T * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n; 9 | T multiL,multiR; 10 | if (n>=m){ 11 | multiL=1; 12 | multiR=n/m; 13 | }else{ 14 | multiL=m/n; 15 | multiR=1; 16 | } 17 | const int Block=1024; 18 | __shared__ T buf[Block*4]; 19 | for (int i=blockIdx.x;i=-2;j--){ 28 | T level=-powf(4.0f,j); 29 | if (j==-2){ 30 | level=0; 31 | } 32 | for (int k0=0;k0>>( 191 | b, n, m, 192 | xyz1.data(), 193 | xyz2.data(), 194 | match.data(), 195 | temp.data()); 196 | })); 197 | cudaDeviceSynchronize(); 198 | CUDA_CHECK(cudaGetLastError()) 199 | } 200 | 201 | template 202 | __global__ void matchcost(const int b, const int n, const int m, const T * __restrict__ xyz1, const T * __restrict__ xyz2, const T * __restrict__ match, T * __restrict__ out){ 203 | __shared__ T allsum[512]; 204 | const int Block=1024; 205 | __shared__ T buf[Block*3]; 206 | for (int i=blockIdx.x;i>>( 249 | b, n, m, 250 | xyz1.data(), 251 | xyz2.data(), 252 | match.data(), 253 | out.data()); 254 | })); 255 | CUDA_CHECK(cudaGetLastError()) 256 | } 257 | 258 | template 259 | __global__ void 
matchcostgrad2(const int b, const int n, const int m,const T * __restrict__ xyz1, const T * __restrict__ xyz2, const T * __restrict__ match, T * __restrict__ grad2){ 260 | __shared__ T sum_grad[256*3]; 261 | for (int i=blockIdx.x;i 302 | __global__ void matchcostgrad1(const int b, const int n, const int m, const T * __restrict__ xyz1, const T * __restrict__ xyz2, const T * __restrict__ match, T * __restrict__ grad1){ 303 | for (int i=blockIdx.x;i>>( 328 | b, n, m, 329 | xyz1.data(), 330 | xyz2.data(), 331 | match.data(), 332 | grad1.data()); 333 | })); 334 | CUDA_CHECK(cudaGetLastError()) 335 | 336 | AT_DISPATCH_FLOATING_TYPES(xyz1.type(), "matchcostgrad2", ([&] { 337 | matchcostgrad2<<>>( 338 | b, n, m, 339 | xyz1.data(), 340 | xyz2.data(), 341 | match.data(), 342 | grad2.data()); 343 | })); 344 | CUDA_CHECK(cudaGetLastError()) 345 | } 346 | 347 | #endif -------------------------------------------------------------------------------- /loss/cuda/emd_torch/pkg/include/cuda_helper.h: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_HELPER_H_ 2 | #define CUDA_HELPER_H_ 3 | 4 | #ifndef AT_CHECK 5 | #define AT_CHECK TORCH_CHECK 6 | #endif 7 | 8 | #define CUDA_CHECK(err) \ 9 | if (cudaSuccess != err) \ 10 | { \ 11 | fprintf(stderr, "CUDA kernel failed: %s (%s:%d)\n", \ 12 | cudaGetErrorString(err), __FILE__, __LINE__); \ 13 | std::exit(-1); \ 14 | } 15 | 16 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), \ 17 | #x " must be a CUDA tensor") 18 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), \ 19 | #x " must be contiguous") 20 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 21 | 22 | #endif -------------------------------------------------------------------------------- /loss/cuda/emd_torch/pkg/include/emd.h: -------------------------------------------------------------------------------- 1 | #ifndef EMD_H_ 2 | #define EMD_H_ 3 | 4 | #include 5 | #include 6 | 7 | #include "cuda_helper.h" 8 | 9 | 10 
| std::vector<at::Tensor> emd_forward_cuda( 11 | at::Tensor xyz1, 12 | at::Tensor xyz2); 13 | 14 | std::vector<at::Tensor> emd_backward_cuda( 15 | at::Tensor xyz1, 16 | at::Tensor xyz2, 17 | at::Tensor match); 18 | 19 | // * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 20 | // * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 21 | // CALL FUNCTION IMPLEMENTATIONS 22 | // * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 23 | // * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 24 | 25 | std::vector<at::Tensor> emd_forward( 26 | at::Tensor xyz1, 27 | at::Tensor xyz2) 28 | { 29 | CHECK_INPUT(xyz1); 30 | CHECK_INPUT(xyz2); 31 | 32 | return emd_forward_cuda(xyz1, xyz2); 33 | } 34 | 35 | std::vector<at::Tensor> emd_backward( 36 | at::Tensor xyz1, 37 | at::Tensor xyz2, 38 | at::Tensor match) 39 | { 40 | CHECK_INPUT(xyz1); 41 | CHECK_INPUT(xyz2); 42 | CHECK_INPUT(match); 43 | 44 | return emd_backward_cuda(xyz1, xyz2, match); 45 | } 46 | 47 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 48 | m.def("emd_forward", &emd_forward, "Compute Earth Mover's Distance"); 49 | m.def("emd_backward", &emd_backward, "Compute Gradients for Earth Mover's Distance"); 50 | } 51 | 52 | 53 | 54 | #endif -------------------------------------------------------------------------------- /loss/cuda/emd_torch/pkg/layer/__init__.py: -------------------------------------------------------------------------------- 1 | from .emd_loss_layer import EMDLoss -------------------------------------------------------------------------------- /loss/cuda/emd_torch/pkg/layer/emd_loss_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import _emd_ext._emd as emd 5 | 6 | 7 | class EMDFunction(torch.autograd.Function): 8 | @staticmethod 9 | def forward(self, xyz1, xyz2): 10 | cost, match = emd.emd_forward(xyz1, 
xyz2) 11 | self.save_for_backward(xyz1, xyz2, match) 12 | 
return cost 13 | 14 | 15 | @staticmethod 16 | def backward(self, grad_output): 17 | xyz1, xyz2, match = self.saved_tensors 18 | grad_xyz1, grad_xyz2 = emd.emd_backward(xyz1, xyz2, match) 19 | return grad_xyz1 * grad_output[:, None, None], grad_xyz2 * grad_output[:, None, None] 20 | 21 | 22 | 23 | 24 | class EMDLoss(nn.Module): 25 | ''' 26 | Computes the (approximate) Earth Mover's Distance between two point sets. 27 | 28 | IMPLEMENTATION LIMITATIONS: 29 | - Double tensors must have <=11 dimensions 30 | - Float tensors must have <=23 dimensions 31 | This is due to the use of CUDA shared memory in the computation. This shared memory is limited by the hardware to 48kB. 32 | ''' 33 | 34 | def __init__(self): 35 | super(EMDLoss, self).__init__() 36 | 37 | def forward(self, xyz1, xyz2): 38 | 39 | assert xyz1.shape[-1] == xyz2.shape[-1], 'Both point sets must have the same dimensionality' 40 | return EMDFunction.apply(xyz1, xyz2) -------------------------------------------------------------------------------- /loss/cuda/emd_torch/pkg/src/cuda/emd.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "cuda/emd.cuh" 6 | 7 | 8 | std::vector emd_forward_cuda( 9 | at::Tensor xyz1, // B x N1 x D 10 | at::Tensor xyz2) // B x N2 x D 11 | { 12 | // Some useful values 13 | const int batch_size = xyz1.size(0); 14 | const int num_pts_1 = xyz1.size(1); 15 | const int num_pts_2 = xyz2.size(1); 16 | 17 | // Allocate necessary data structures 18 | at::Tensor match = at::zeros({batch_size, num_pts_1, num_pts_2}, 19 | xyz1.options()); 20 | at::Tensor cost = at::zeros({batch_size}, xyz1.options()); 21 | at::Tensor temp = at::zeros({batch_size, 2 * (num_pts_1 + num_pts_2)}, 22 | xyz1.options()); 23 | 24 | // Find the approximate matching 25 | approxmatchLauncher( 26 | batch_size, num_pts_1, num_pts_2, 27 | xyz1, 28 | xyz2, 29 | match, 30 | temp 31 | ); 32 | 33 | // Compute the matching cost 34 | matchcostLauncher( 35 | batch_size, num_pts_1, num_pts_2, 36 | xyz1, 37 |
xyz2, 38 | match, 39 | cost 40 | ); 41 | 42 | return {cost, match}; 43 | } 44 | 45 | std::vector emd_backward_cuda( 46 | at::Tensor xyz1, 47 | at::Tensor xyz2, 48 | at::Tensor match) 49 | { 50 | // Some useful values 51 | const int batch_size = xyz1.size(0); 52 | const int num_pts_1 = xyz1.size(1); 53 | const int num_pts_2 = xyz2.size(1); 54 | 55 | // Allocate necessary data structures 56 | at::Tensor grad_xyz1 = at::zeros_like(xyz1); 57 | at::Tensor grad_xyz2 = at::zeros_like(xyz2); 58 | 59 | // Compute the gradient with respect to the two inputs (xyz1 and xyz2) 60 | matchcostgradLauncher( 61 | batch_size, num_pts_1, num_pts_2, 62 | xyz1, 63 | xyz2, 64 | match, 65 | grad_xyz1, 66 | grad_xyz2 67 | ); 68 | 69 | return {grad_xyz1, grad_xyz2}; 70 | } -------------------------------------------------------------------------------- /loss/cuda/emd_torch/pkg/src/emd.cpp: -------------------------------------------------------------------------------- 1 | #include "emd.h" 2 | -------------------------------------------------------------------------------- /loss/cuda/emd_torch/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='PyTorch EMD', 7 | version='0.0', 8 | author='Vinit Sarode', 9 | author_email='vinitsarode5@gmail.com', 10 | description='A PyTorch module for the earth mover\'s distance loss', 11 | ext_package='_emd_ext', 12 | ext_modules=[ 13 | CUDAExtension( 14 | name='_emd', 15 | sources=[ 16 | 'pkg/src/emd.cpp', 17 | 'pkg/src/cuda/emd.cu', 18 | ], 19 | include_dirs=['pkg/include'], 20 | ), 21 | ], 22 | packages=[ 23 | 'emd', 24 | ], 25 | package_dir={ 26 | 'emd' : 'pkg/layer' 27 | }, 28 | cmdclass={'build_ext': BuildExtension}, 29 | ) -------------------------------------------------------------------------------- /loss/earth_mover_distance.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import copy 5 | 6 | 7 | def emd(template: torch.Tensor, source: torch.Tensor): 8 | from emd import EMDLoss 9 | emd_loss = torch.mean(EMDLoss()(template, source))/(template.size()[1]) 10 | return emd_loss 11 | 12 | 13 | class EMDLosspy(nn.Module): 14 | def __init__(self): 15 | super(EMDLosspy, self).__init__() 16 | 17 | def forward(self, template, source): 18 | return emd(template, source) 19 | 20 | 21 | if __name__ == '__main__': 22 | loss = EMDLosspy() 23 | a = torch.randn(4, 5, 3).cuda() 24 | b = copy.deepcopy(a) 25 | v = loss(a, b) 26 | print(v) 27 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import anisotropic_R_error, anisotropic_t_error, \ 2 | isotropic_R_error, isotropic_t_error 3 | from .helper import compute_metrics, summary_metrics, print_metrics, \ 4 | print_train_info -------------------------------------------------------------------------------- /metrics/helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from metrics import anisotropic_R_error, anisotropic_t_error, \ 3 | isotropic_R_error, isotropic_t_error 4 | from utils import inv_R_t 5 | 6 | 7 | def compute_metrics(R, t, gtR, gtt): 8 | inv_R, inv_t = inv_R_t(gtR, gtt) 9 | cur_r_mse, cur_r_mae = anisotropic_R_error(R, inv_R) 10 | cur_t_mse, cur_t_mae = anisotropic_t_error(t, inv_t) 11 | cur_r_isotropic = isotropic_R_error(R, inv_R) 12 | cur_t_isotropic = isotropic_t_error(t, inv_t, inv_R) 13 | return cur_r_mse, cur_r_mae, cur_t_mse, cur_t_mae, cur_r_isotropic, \ 14 | cur_t_isotropic 15 | 16 | 17 | def summary_metrics(r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic): 18 | r_mse = np.concatenate(r_mse, axis=0) 19 
| r_mae = np.concatenate(r_mae, axis=0) 20 | t_mse = np.concatenate(t_mse, axis=0) 21 | t_mae = np.concatenate(t_mae, axis=0) 22 | r_isotropic = np.concatenate(r_isotropic, axis=0) 23 | t_isotropic = np.concatenate(t_isotropic, axis=0) 24 | 25 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 26 | np.sqrt(np.mean(r_mse)), np.mean(r_mae), np.sqrt(np.mean(t_mse)), \ 27 | np.mean(t_mae), np.mean(r_isotropic), np.mean(t_isotropic) 28 | return r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic 29 | 30 | 31 | def print_metrics(method, dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic): 32 | print('='*20, method, '='*20) 33 | print('time: {:.2f} s, mean: {:.5f} s'.format(np.sum(dura), np.mean(dura))) 34 | print('isotropic R(rot) error: {:.2f}'.format(r_isotropic)) 35 | print('isotropic t error: {:.2f}'.format(t_isotropic)) 36 | print('anisotropic mse R error: {:.2f}'.format(r_mse)) 37 | print('anisotropic mae R error: {:.2f}'.format(r_mae)) 38 | print('anisotropic mse t error: {:.2f}'.format(t_mse)) 39 | print('anisotropic mae t error: {:.2f}'.format(t_mae)) 40 | 41 | 42 | def print_train_info(results): 43 | print('Loss: {:.4f}, isotropic R: {:.4f}, isotropic t: {:.4f}, ' 44 | 'anisotropic R(mse, mae): {:.4f}, {:.4f}, ' 45 | 'anisotropic t(mse, mae): {:.4f}, {:.4f}'. 46 | format(results['loss'], results['r_isotropic'], 47 | results['t_isotropic'], results['r_mse'], results['r_mae'], 48 | results['t_mse'], results['t_mae'])) -------------------------------------------------------------------------------- /metrics/metrics.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from scipy.spatial.transform import Rotation 5 | from utils import inv_R_t 6 | 7 | 8 | def anisotropic_R_error(r1, r2, seq='xyz', degrees=True): 9 | ''' 10 | Calculate mse, mae Euler angle error.
11 | :param r1: shape=(B, 3, 3), pred 12 | :param r2: shape=(B, 3, 3), gt 13 | :return: 14 | ''' 15 | if isinstance(r1, torch.Tensor): 16 | r1 = r1.cpu().detach().numpy() 17 | if isinstance(r2, torch.Tensor): 18 | r2 = r2.cpu().detach().numpy() 19 | assert r1.shape == r2.shape 20 | eulers1, eulers2 = [], [] 21 | for i in range(r1.shape[0]): 22 | euler1 = Rotation.from_matrix(r1[i]).as_euler(seq=seq, degrees=degrees) 23 | euler2 = Rotation.from_matrix(r2[i]).as_euler(seq=seq, degrees=degrees) 24 | eulers1.append(euler1) 25 | eulers2.append(euler2) 26 | eulers1 = np.stack(eulers1, axis=0) 27 | eulers2 = np.stack(eulers2, axis=0) 28 | r_mse = np.mean((eulers1 - eulers2)**2, axis=-1) 29 | r_mae = np.mean(np.abs(eulers1 - eulers2), axis=-1) 30 | return r_mse, r_mae 31 | 32 | 33 | def anisotropic_t_error(t1, t2): 34 | ''' 35 | calculate translation mse and mae error. 36 | :param t1: shape=(B, 3) 37 | :param t2: shape=(B, 3) 38 | :return: 39 | ''' 40 | if isinstance(t1, torch.Tensor): 41 | t1 = t1.cpu().detach().numpy() 42 | if isinstance(t2, torch.Tensor): 43 | t2 = t2.cpu().detach().numpy() 44 | assert t1.shape == t2.shape 45 | t_mse = np.mean((t1 - t2) ** 2, axis=1) 46 | t_mae = np.mean(np.abs(t1 - t2), axis=1) 47 | return t_mse, t_mae 48 | 49 | 50 | def isotropic_R_error(r1, r2): 51 | ''' 52 | Calculate isotropic rotation degree error between r1 and r2. 
53 | :param r1: shape=(B, 3, 3), pred 54 | :param r2: shape=(B, 3, 3), gt 55 | :return: 56 | ''' 57 | r2_inv = r2.permute(0, 2, 1).contiguous() 58 | r1r2 = torch.matmul(r2_inv, r1) 59 | # device = r1.device 60 | # B = r1.shape[0] 61 | # mask = torch.unsqueeze(torch.eye(3).to(device), dim=0).repeat(B, 1, 1) 62 | # tr = torch.sum(torch.reshape(mask * r1r2, (B, 9)), dim=-1) 63 | tr = r1r2[:, 0, 0] + r1r2[:, 1, 1] + r1r2[:, 2, 2] 64 | rads = torch.acos(torch.clamp((tr - 1) / 2, -1, 1)) 65 | degrees = rads / math.pi * 180 66 | return degrees 67 | 68 | 69 | def isotropic_t_error(t1, t2, R2): 70 | ''' 71 | Calculate isotropic translation error between t1 and t2. 72 | :param t1: shape=(B, 3), pred_t 73 | :param t2: shape=(B, 3), gtt 74 | :param R2: shape=(B, 3, 3), gtR 75 | :return: 76 | ''' 77 | R2, t2 = inv_R_t(R2, t2) 78 | error = torch.squeeze(R2 @ t1[..., None], -1) + t2 79 | error = torch.norm(error, dim=-1) 80 | return error 81 | 82 | 83 | #def modified_CD(tranformed_src, ref): 84 | # pass 85 | 86 | 87 | # def rotation_error(r1, r2): 88 | # ''' 89 | # calculate mse r1-r2 error. 
90 | # :param r1: shape=(B, 3, 3), pred 91 | # :param r2: shape=(B, 3, 3), gt 92 | # :return: 93 | # ''' 94 | # r = torch.reshape(r1 - r2, (-1, 9)) 95 | # error = torch.mean(torch.sum(r ** 2, dim=1)) 96 | # return error -------------------------------------------------------------------------------- /metrics/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhulf0804/PCReg.PyTorch/19c26744b00615a680829e7804807c013aa8ae34/metrics/vis.png -------------------------------------------------------------------------------- /modelnet40_evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import open3d as o3d 4 | import random 5 | import time 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | from data import ModelNet40 10 | from models import Benchmark, IterativeBenchmark, icp, fgr 11 | from utils import npy2pcd, pcd2npy 12 | from metrics import compute_metrics, summary_metrics, print_metrics 13 | 14 | 15 | def config_params(): 16 | parser = argparse.ArgumentParser(description='Configuration Parameters') 17 | parser.add_argument('--root', required=True, help='the data path') 18 | parser.add_argument('--infer_npts', type=int, default=-1, 19 | help='the number of points in each point cloud for inference') 20 | parser.add_argument('--mode', default='clean', 21 | choices=['clean', 'partial', 'noise'], 22 | help='data mode used for evaluation') 23 | parser.add_argument('--normal', action='store_true', 24 | help='whether to use normal data') 25 | parser.add_argument('--niters', type=int, default=8, 26 | help='number of refinement iterations in one model forward pass') 27 | parser.add_argument('--gn', action='store_true', 28 | help='whether to use group normalization') 29 | parser.add_argument('--checkpoint', default='', 30 | help='the path to the trained checkpoint') 31 | parser.add_argument('--method',
default='benchmark', 32 | help='choice=[benchmark, icp, fgr, bm_icp]') 33 | parser.add_argument('--cuda', action='store_true', 34 | help='whether to use the cuda') 35 | parser.add_argument('--show', action='store_true', 36 | help='whether to visualize') 37 | args = parser.parse_args() 38 | return args 39 | 40 | 41 | def evaluate_benchmark(args, test_loader): 42 | in_dim = 6 if args.normal else 3 43 | model = IterativeBenchmark(in_dim=in_dim, niters=args.niters, gn=args.gn) 44 | if args.cuda: 45 | model = model.cuda() 46 | model.load_state_dict(torch.load(args.checkpoint)) 47 | else: 48 | model.load_state_dict(torch.load(args.checkpoint, map_location=torch.device('cpu'))) 49 | model.eval() 50 | 51 | dura = [] 52 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = [], [], [], [], [], [] 53 | with torch.no_grad(): 54 | for i, (ref_cloud, src_cloud, gtR, gtt) in tqdm(enumerate(test_loader)): 55 | if args.cuda: 56 | ref_cloud, src_cloud, gtR, gtt = ref_cloud.cuda(), src_cloud.cuda(), \ 57 | gtR.cuda(), gtt.cuda() 58 | tic = time.time() 59 | R, t, pred_ref_cloud = model(src_cloud.permute(0, 2, 1).contiguous(), 60 | ref_cloud.permute(0, 2, 1).contiguous()) 61 | toc = time.time() 62 | dura.append(toc - tic) 63 | cur_r_mse, cur_r_mae, cur_t_mse, cur_t_mae, cur_r_isotropic, \ 64 | cur_t_isotropic = compute_metrics(R, t, gtR, gtt) 65 | r_mse.append(cur_r_mse) 66 | r_mae.append(cur_r_mae) 67 | t_mse.append(cur_t_mse) 68 | t_mae.append(cur_t_mae) 69 | r_isotropic.append(cur_r_isotropic.cpu().detach().numpy()) 70 | t_isotropic.append(cur_t_isotropic.cpu().detach().numpy()) 71 | 72 | if args.show: 73 | ref_cloud = torch.squeeze(ref_cloud).cpu().numpy() 74 | src_cloud = torch.squeeze(src_cloud).cpu().numpy() 75 | pred_ref_cloud = torch.squeeze(pred_ref_cloud[-1]).cpu().numpy() 76 | pcd1 = npy2pcd(ref_cloud, 0) 77 | pcd2 = npy2pcd(src_cloud, 1) 78 | pcd3 = npy2pcd(pred_ref_cloud, 2) 79 | o3d.visualization.draw_geometries([pcd1, pcd2, pcd3]) 80 | 81 | r_mse, r_mae, t_mse, 
t_mae, r_isotropic, t_isotropic = \ 82 | summary_metrics(r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic) 83 | 84 | return dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic 85 | 86 | 87 | def evaluate_icp(args, test_loader): 88 | dura = [] 89 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = [], [], [], [], [], [] 90 | for i, (ref_cloud, src_cloud, gtR, gtt) in tqdm(enumerate(test_loader)): 91 | if args.cuda: 92 | ref_cloud, src_cloud, gtR, gtt = ref_cloud.cuda(), src_cloud.cuda(), \ 93 | gtR.cuda(), gtt.cuda() 94 | 95 | ref_cloud = torch.squeeze(ref_cloud).cpu().numpy() 96 | src_cloud = torch.squeeze(src_cloud).cpu().numpy() 97 | 98 | tic = time.time() 99 | R, t, pred_ref_cloud = icp(npy2pcd(src_cloud), npy2pcd(ref_cloud)) 100 | toc = time.time() 101 | R = torch.from_numpy(np.expand_dims(R, 0)).to(gtR) 102 | t = torch.from_numpy(np.expand_dims(t, 0)).to(gtt) 103 | dura.append(toc - tic) 104 | 105 | cur_r_mse, cur_r_mae, cur_t_mse, cur_t_mae, cur_r_isotropic, \ 106 | cur_t_isotropic = compute_metrics(R, t, gtR, gtt) 107 | r_mse.append(cur_r_mse) 108 | r_mae.append(cur_r_mae) 109 | t_mse.append(cur_t_mse) 110 | t_mae.append(cur_t_mae) 111 | r_isotropic.append(cur_r_isotropic.cpu().detach().numpy()) 112 | t_isotropic.append(cur_t_isotropic.cpu().detach().numpy()) 113 | 114 | if args.show: 115 | pcd1 = npy2pcd(ref_cloud, 0) 116 | pcd2 = npy2pcd(src_cloud, 1) 117 | pcd3 = pred_ref_cloud 118 | o3d.visualization.draw_geometries([pcd1, pcd2, pcd3]) 119 | 120 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 121 | summary_metrics(r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic) 122 | 123 | return dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic 124 | 125 | 126 | def evaluate_fgr(args, test_loader): 127 | dura = [] 128 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = [], [], [], [], [], [] 129 | for i, (ref_cloud, src_cloud, gtR, gtt) in tqdm(enumerate(test_loader)): 130 | if args.cuda: 131 | ref_cloud, src_cloud, gtR, 
gtt = ref_cloud.cuda(), src_cloud.cuda(), \ 132 | gtR.cuda(), gtt.cuda() 133 | 134 | ref_points = torch.squeeze(ref_cloud).cpu().numpy()[:, :3] 135 | src_points = torch.squeeze(src_cloud).cpu().numpy()[:, :3] 136 | ref_normals = torch.squeeze(ref_cloud).cpu().numpy()[:, 3:] 137 | src_normals = torch.squeeze(src_cloud).cpu().numpy()[:, 3:] 138 | 139 | tic = time.time() 140 | R, t, pred_ref_cloud = fgr(source=npy2pcd(src_points), 141 | target=npy2pcd(ref_points), 142 | src_normals=src_normals, 143 | tgt_normals=ref_normals) 144 | toc = time.time() 145 | R = torch.from_numpy(np.expand_dims(R, 0)).to(gtR) 146 | t = torch.from_numpy(np.expand_dims(t, 0)).to(gtt) 147 | dura.append(toc - tic) 148 | 149 | cur_r_mse, cur_r_mae, cur_t_mse, cur_t_mae, cur_r_isotropic, \ 150 | cur_t_isotropic = compute_metrics(R, t, gtR, gtt) 151 | r_mse.append(cur_r_mse) 152 | r_mae.append(cur_r_mae) 153 | t_mse.append(cur_t_mse) 154 | t_mae.append(cur_t_mae) 155 | r_isotropic.append(cur_r_isotropic.cpu().detach().numpy()) 156 | t_isotropic.append(cur_t_isotropic.cpu().detach().numpy()) 157 | 158 | if args.show: 159 | pcd1 = npy2pcd(ref_points, 0) 160 | pcd2 = npy2pcd(src_points, 1) 161 | pcd3 = pred_ref_cloud 162 | o3d.visualization.draw_geometries([pcd1, pcd2, pcd3]) 163 | 164 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 165 | summary_metrics(r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic) 166 | 167 | return dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic 168 | 169 | 170 | def evaluate_benchmark_icp(args, test_loader): 171 | in_dim = 6 if args.normal else 3 172 | model = IterativeBenchmark(in_dim=in_dim, niters=args.niters, gn=args.gn) 173 | if args.cuda: 174 | model = model.cuda() 175 | model.load_state_dict(torch.load(args.checkpoint)) 176 | else: 177 | model.load_state_dict(torch.load(args.checkpoint, map_location=torch.device('cpu'))) 178 | model.eval() 179 | 180 | dura = [] 181 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = [], [], [], [], [], 
[] 182 | with torch.no_grad(): 183 | for i, (ref_cloud, src_cloud, gtR, gtt) in tqdm(enumerate(test_loader)): 184 | if args.cuda: 185 | ref_cloud, src_cloud, gtR, gtt = ref_cloud.cuda(), src_cloud.cuda(), \ 186 | gtR.cuda(), gtt.cuda() 187 | tic = time.time() 188 | R1, t1, pred_ref_cloud = model(src_cloud.permute(0, 2, 1).contiguous(), 189 | ref_cloud.permute(0, 2, 1).contiguous()) 190 | ref_cloud = torch.squeeze(ref_cloud).cpu().numpy() 191 | src_cloud_tmp = torch.squeeze(pred_ref_cloud[-1]).cpu().numpy() 192 | R2, t2, pred_ref_cloud = icp(npy2pcd(src_cloud_tmp), npy2pcd(ref_cloud)) 193 | R2, t2 = torch.from_numpy(R2)[None, ...].to(R1), \ 194 | torch.from_numpy(t2)[None, ...].to(R1) 195 | R, t = R2 @ R1, torch.squeeze(R2 @ t1[:, :, None], dim=-1) + t2 196 | toc = time.time() 197 | dura.append(toc - tic) 198 | cur_r_mse, cur_r_mae, cur_t_mse, cur_t_mae, cur_r_isotropic, \ 199 | cur_t_isotropic = compute_metrics(R, t, gtR, gtt) 200 | r_mse.append(cur_r_mse) 201 | r_mae.append(cur_r_mae) 202 | t_mse.append(cur_t_mse) 203 | t_mae.append(cur_t_mae) 204 | r_isotropic.append(cur_r_isotropic.cpu().detach().numpy()) 205 | t_isotropic.append(cur_t_isotropic.cpu().detach().numpy()) 206 | 207 | if args.show: 208 | src_cloud = torch.squeeze(src_cloud).cpu().numpy() 209 | pcd1 = npy2pcd(ref_cloud, 0) 210 | pcd2 = npy2pcd(src_cloud, 1) 211 | pcd3 = pred_ref_cloud 212 | o3d.visualization.draw_geometries([pcd1, pcd2, pcd3]) 213 | 214 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 215 | summary_metrics(r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic) 216 | 217 | return dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic 218 | 219 | 220 | if __name__ == '__main__': 221 | seed = 222 222 | random.seed(seed) 223 | np.random.seed(seed) 224 | 225 | args = config_params() 226 | print(args) 227 | test_set = ModelNet40(root=args.root, 228 | npts=args.infer_npts, 229 | train=False, 230 | normal=args.normal, 231 | mode=args.mode) 232 | test_loader = 
DataLoader(test_set, batch_size=1, shuffle=False) 233 | 234 | if args.method == 'benchmark': 235 | dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 236 | evaluate_benchmark(args, test_loader) 237 | print_metrics(args.method, 238 | dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic) 239 | elif args.method == 'icp': 240 | dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 241 | evaluate_icp(args, test_loader) 242 | print_metrics(args.method, dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, 243 | t_isotropic) 244 | elif args.method == 'fgr': 245 | dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 246 | evaluate_fgr(args, test_loader) 247 | print_metrics(args.method, dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, 248 | t_isotropic) 249 | elif args.method == 'bm_icp': 250 | dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 251 | evaluate_benchmark_icp(args, test_loader) 252 | print_metrics(args.method, dura, r_mse, r_mae, t_mse, t_mae, r_isotropic, 253 | t_isotropic) 254 | else: 255 | raise ValueError -------------------------------------------------------------------------------- /modelnet40_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import open3d 4 | import os 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | from torch.utils.tensorboard import SummaryWriter 9 | from tqdm import tqdm 10 | 11 | from data import ModelNet40 12 | from models import IterativeBenchmark 13 | from loss import EMDLosspy 14 | from metrics import compute_metrics, summary_metrics, print_train_info 15 | from utils import time_calc 16 | 17 | 18 | def setup_seed(seed): 19 | torch.backends.cudnn.deterministic = True 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | np.random.seed(seed) 23 | 24 | 25 | def config_params(): 26 | parser = argparse.ArgumentParser(description='Configuration 
Parameters') 27 | ## dataset 28 | parser.add_argument('--root', required=True, help='the data path') 29 | parser.add_argument('--train_npts', type=int, default=1024, 30 | help='the number of points in each point cloud for training') 31 | parser.add_argument('--normal', action='store_true', 32 | help='whether to use normal data') 33 | parser.add_argument('--mode', default='clean', 34 | choices=['clean', 'partial', 'noise'], 35 | help='data mode used for training') 36 | ## models training 37 | parser.add_argument('--seed', type=int, default=1234) 38 | parser.add_argument('--gn', action='store_true', 39 | help='whether to use group normalization') 40 | parser.add_argument('--epoches', type=int, default=400) 41 | parser.add_argument('--batchsize', type=int, default=32) 42 | parser.add_argument('--num_workers', type=int, default=4) 43 | parser.add_argument('--niters', type=int, default=8, 44 | help='number of refinement iterations in one model forward pass') 45 | parser.add_argument('--lr', type=float, default=0.0001, 46 | help='initial learning rate') 47 | parser.add_argument('--milestones', type=int, nargs='+', default=[50, 250], 48 | help='epochs at which the lr decays') 49 | parser.add_argument('--gamma', type=float, default=0.1, 50 | help='lr decays to gamma * lr every decay epoch') 51 | # logs 52 | parser.add_argument('--saved_path', default='work_dirs/models', 53 | help='the path to save training logs and checkpoints') 54 | parser.add_argument('--saved_frequency', type=int, default=10, 55 | help='the frequency to save the logs and checkpoints') 56 | args = parser.parse_args() 57 | return args 58 | 59 | 60 | def compute_loss(ref_cloud, pred_ref_clouds, loss_fn): 61 | losses = [] 62 | discount_factor = 0.5 63 | for i in range(len(pred_ref_clouds)): 64 | loss = loss_fn(ref_cloud[..., :3].contiguous(), 65 | pred_ref_clouds[i][..., :3].contiguous()) 66 | losses.append(discount_factor**(len(pred_ref_clouds) - i)*loss) 67 | return torch.sum(torch.stack(losses)) 68 | 69 | 70 | @time_calc 71 | def train_one_epoch(train_loader, model, loss_fn,
optimizer): 72 | losses = [] 73 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = [], [], [], [], [], [] 74 | for ref_cloud, src_cloud, gtR, gtt in tqdm(train_loader): 75 | ref_cloud, src_cloud, gtR, gtt = ref_cloud.cuda(), src_cloud.cuda(), \ 76 | gtR.cuda(), gtt.cuda() 77 | optimizer.zero_grad() 78 | R, t, pred_ref_clouds = model(src_cloud.permute(0, 2, 1).contiguous(), 79 | ref_cloud.permute(0, 2, 1).contiguous()) 80 | loss = compute_loss(ref_cloud, pred_ref_clouds, loss_fn) 81 | loss.backward() 82 | optimizer.step() 83 | 84 | cur_r_mse, cur_r_mae, cur_t_mse, cur_t_mae, cur_r_isotropic, \ 85 | cur_t_isotropic = compute_metrics(R, t, gtR, gtt) 86 | losses.append(loss.item()) 87 | r_mse.append(cur_r_mse) 88 | r_mae.append(cur_r_mae) 89 | t_mse.append(cur_t_mse) 90 | t_mae.append(cur_t_mae) 91 | r_isotropic.append(cur_r_isotropic.cpu().detach().numpy()) 92 | t_isotropic.append(cur_t_isotropic.cpu().detach().numpy()) 93 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 94 | summary_metrics(r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic) 95 | results = { 96 | 'loss': np.mean(losses), 97 | 'r_mse': r_mse, 98 | 'r_mae': r_mae, 99 | 't_mse': t_mse, 100 | 't_mae': t_mae, 101 | 'r_isotropic': r_isotropic, 102 | 't_isotropic': t_isotropic 103 | } 104 | return results 105 | 106 | 107 | @time_calc 108 | def test_one_epoch(test_loader, model, loss_fn): 109 | model.eval() 110 | losses = [] 111 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = [], [], [], [], [], [] 112 | with torch.no_grad(): 113 | for ref_cloud, src_cloud, gtR, gtt in tqdm(test_loader): 114 | ref_cloud, src_cloud, gtR, gtt = ref_cloud.cuda(), src_cloud.cuda(), \ 115 | gtR.cuda(), gtt.cuda() 116 | R, t, pred_ref_clouds = model(src_cloud.permute(0, 2, 1).contiguous(), 117 | ref_cloud.permute(0, 2, 1).contiguous()) 118 | loss = compute_loss(ref_cloud, pred_ref_clouds, loss_fn) 119 | cur_r_mse, cur_r_mae, cur_t_mse, cur_t_mae, cur_r_isotropic, \ 120 | cur_t_isotropic = 
compute_metrics(R, t, gtR, gtt) 121 | 122 | losses.append(loss.item()) 123 | r_mse.append(cur_r_mse) 124 | r_mae.append(cur_r_mae) 125 | t_mse.append(cur_t_mse) 126 | t_mae.append(cur_t_mae) 127 | r_isotropic.append(cur_r_isotropic.cpu().detach().numpy()) 128 | t_isotropic.append(cur_t_isotropic.cpu().detach().numpy()) 129 | model.train() 130 | r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic = \ 131 | summary_metrics(r_mse, r_mae, t_mse, t_mae, r_isotropic, t_isotropic) 132 | results = { 133 | 'loss': np.mean(losses), 134 | 'r_mse': r_mse, 135 | 'r_mae': r_mae, 136 | 't_mse': t_mse, 137 | 't_mae': t_mae, 138 | 'r_isotropic': r_isotropic, 139 | 't_isotropic': t_isotropic 140 | } 141 | return results 142 | 143 | 144 | def main(): 145 | args = config_params() 146 | print(args) 147 | 148 | setup_seed(args.seed) 149 | if not os.path.exists(args.saved_path): 150 | os.makedirs(args.saved_path) 151 | summary_path = os.path.join(args.saved_path, 'summary') 152 | if not os.path.exists(summary_path): 153 | os.makedirs(summary_path) 154 | checkpoints_path = os.path.join(args.saved_path, 'checkpoints') 155 | if not os.path.exists(checkpoints_path): 156 | os.makedirs(checkpoints_path) 157 | 158 | train_set = ModelNet40(root=args.root, 159 | npts=args.train_npts, 160 | train=True, 161 | normal=args.normal, 162 | mode=args.mode) 163 | test_set = ModelNet40(root=args.root, 164 | npts=args.train_npts, 165 | train=False, 166 | normal=args.normal, 167 | mode=args.mode) 168 | train_loader = DataLoader(train_set, batch_size=args.batchsize, 169 | shuffle=True, num_workers=args.num_workers) 170 | test_loader = DataLoader(test_set, batch_size=args.batchsize, shuffle=False, 171 | num_workers=args.num_workers) 172 | 173 | in_dim = 6 if args.normal else 3 174 | model = IterativeBenchmark(in_dim=in_dim, niters=args.niters, gn=args.gn) 175 | model = model.cuda() 176 | loss_fn = EMDLosspy().cuda() 177 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 178 | scheduler = 
torch.optim.lr_scheduler.MultiStepLR(optimizer, 179 | milestones=args.milestones, 180 | gamma=args.gamma, 181 | last_epoch=-1) 182 | 183 | writer = SummaryWriter(summary_path) 184 | 185 | test_min_loss, test_min_r_mse_error, test_min_rot_error = \ 186 | float('inf'), float('inf'), float('inf') 187 | for epoch in range(args.epoches): 188 | print('=' * 20, epoch + 1, '=' * 20) 189 | train_results = train_one_epoch(train_loader, model, loss_fn, optimizer) 190 | print_train_info(train_results) 191 | test_results = test_one_epoch(test_loader, model, loss_fn) 192 | print_train_info(test_results) 193 | 194 | if epoch % args.saved_frequency == 0: 195 | writer.add_scalar('Loss/train', train_results['loss'], epoch + 1) 196 | writer.add_scalar('Loss/test', test_results['loss'], epoch + 1) 197 | writer.add_scalar('RError/train', train_results['r_mse'], epoch + 1) 198 | writer.add_scalar('RError/test', test_results['r_mse'], epoch + 1) 199 | writer.add_scalar('rotError/train', train_results['r_isotropic'], epoch + 1) 200 | writer.add_scalar('rotError/test', test_results['r_isotropic'], epoch + 1) 201 | writer.add_scalar('Lr', optimizer.param_groups[0]['lr'], epoch + 1) 202 | test_loss, test_r_error, test_rot_error = \ 203 | test_results['loss'], test_results['r_mse'], test_results['r_isotropic'] 204 | if test_loss < test_min_loss: 205 | saved_path = os.path.join(checkpoints_path, "test_min_loss.pth") 206 | torch.save(model.state_dict(), saved_path) 207 | test_min_loss = test_loss 208 | if test_r_error < test_min_r_mse_error: 209 | saved_path = os.path.join(checkpoints_path, "test_min_rmse_error.pth") 210 | torch.save(model.state_dict(), saved_path) 211 | test_min_r_mse_error = test_r_error 212 | if test_rot_error < test_min_rot_error: 213 | saved_path = os.path.join(checkpoints_path, "test_min_rot_error.pth") 214 | torch.save(model.state_dict(), saved_path) 215 | test_min_rot_error = test_rot_error 216 | scheduler.step() 217 | 218 | 219 | if __name__ == '__main__': 220 | main() 
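Two pieces of arithmetic tie the training script to the model: the discounted per-iteration loss in `compute_loss` (iteration i of n is weighted by 0.5**(n - i), so the final refinement step counts most), and the pose accumulation in `IterativeBenchmark.forward` and `evaluate_benchmark_icp`, where a new step (R2, t2) applied after an accumulated pose (R1, t1) yields R2 R1 and R2 t1 + t2. The NumPy sketch below checks both in isolation. It is an illustration under stated assumptions, not code from this repository: `iteration_weights`, `weighted_loss`, `compose` and `rot_z` are hypothetical helper names, and points are stored one per row.

```python
import numpy as np


def iteration_weights(niters, discount_factor=0.5):
    """Hypothetical helper: per-iteration loss weights, discount_factor**(niters - i)."""
    return [discount_factor ** (niters - i) for i in range(niters)]


def weighted_loss(per_iter_losses, discount_factor=0.5):
    """Hypothetical helper: discounted sum of per-iteration losses (last one weighs most)."""
    weights = iteration_weights(len(per_iter_losses), discount_factor)
    return sum(w * l for w, l in zip(weights, per_iter_losses))


def compose(R2, t2, R1, t1):
    """Pose equivalent to applying (R1, t1) first and then (R2, t2).

    x -> R2 @ (R1 @ x + t1) + t2 = (R2 @ R1) @ x + (R2 @ t1 + t2)
    """
    return R2 @ R1, R2 @ t1 + t2


def rot_z(theta):
    """Hypothetical helper: rotation by theta radians about the z axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


if __name__ == '__main__':
    print(iteration_weights(8))  # 0.5**8, 0.5**7, ..., 0.5
    R1, t1 = rot_z(0.3), np.array([0.1, 0.0, 0.2])
    R2, t2 = rot_z(-0.1), np.array([0.0, 0.05, 0.0])
    R, t = compose(R2, t2, R1, t1)
    x = np.random.randn(5, 3)  # 5 points, one per row
    step_by_step = (x @ R1.T + t1) @ R2.T + t2
    print(np.allclose(x @ R.T + t, step_by_step))  # True
```

Writing the weights in terms of the number of predicted clouds, rather than a fixed 8, keeps the schedule consistent when `--niters` is changed.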
-------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .benchmark import Benchmark, IterativeBenchmark 2 | from .fgr import fgr 3 | from .icp import icp -------------------------------------------------------------------------------- /models/benchmark.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import sys 6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 7 | ROOR_DIR = os.path.dirname(BASE_DIR) 8 | sys.path.append(ROOR_DIR) 9 | from utils import batch_quat2mat, batch_transform 10 | 11 | 12 | class PointNet(nn.Module): 13 | def __init__(self, in_dim, gn, mlps=[64, 64, 64, 128, 1024]): 14 | super(PointNet, self).__init__() 15 | self.backbone = nn.Sequential() 16 | for i, out_dim in enumerate(mlps): 17 | self.backbone.add_module(f'pointnet_conv_{i}', 18 | nn.Conv1d(in_dim, out_dim, 1, 1, 0)) 19 | if gn: 20 | self.backbone.add_module(f'pointnet_gn_{i}', 21 | nn.GroupNorm(8, out_dim)) 22 | self.backbone.add_module(f'pointnet_relu_{i}', 23 | nn.ReLU(inplace=True)) 24 | in_dim = out_dim 25 | 26 | def forward(self, x): 27 | x = self.backbone(x) 28 | x, _ = torch.max(x, dim=2) 29 | return x 30 | 31 | 32 | class Benchmark(nn.Module): 33 | def __init__(self, gn, in_dim1, in_dim2=2048, fcs=[1024, 1024, 512, 512, 256, 7]): 34 | super(Benchmark, self).__init__() 35 | self.in_dim1 = in_dim1 36 | self.encoder = PointNet(in_dim=in_dim1, gn=gn) 37 | self.decoder = nn.Sequential() 38 | for i, out_dim in enumerate(fcs): 39 | self.decoder.add_module(f'fc_{i}', nn.Linear(in_dim2, out_dim)) 40 | if out_dim != 7: 41 | if gn: 42 | self.decoder.add_module(f'gn_{i}',nn.GroupNorm(8, out_dim)) 43 | self.decoder.add_module(f'relu_{i}', nn.ReLU(inplace=True)) 44 | in_dim2 = out_dim 45 | 46 | def forward(self, x, y): 47 | x_f, 
y_f = self.encoder(x), self.encoder(y) 48 | concat = torch.cat((x_f, y_f), dim=1) 49 | out = self.decoder(concat) 50 | batch_t, batch_quat = out[:, :3], out[:, 3:] / torch.norm(out[:, 3:], dim=1, keepdim=True) 51 | batch_R = batch_quat2mat(batch_quat) 52 | if self.in_dim1 == 3: 53 | transformed_x = batch_transform(x.permute(0, 2, 1).contiguous(), 54 | batch_R, batch_t) 55 | elif self.in_dim1 == 6: 56 | transformed_pts = batch_transform(x.permute(0, 2, 1)[:, :, :3].contiguous(), 57 | batch_R, batch_t) 58 | transformed_nls = batch_transform(x.permute(0, 2, 1)[:, :, 3:].contiguous(), 59 | batch_R) 60 | transformed_x = torch.cat([transformed_pts, transformed_nls], dim=-1) 61 | else: 62 | raise ValueError(f'in_dim1 must be 3 or 6, got {self.in_dim1}') 63 | return batch_R, batch_t, transformed_x 64 | 65 | 66 | class IterativeBenchmark(nn.Module): 67 | def __init__(self, in_dim, niters, gn): 68 | super(IterativeBenchmark, self).__init__() 69 | self.benckmark = Benchmark(gn=gn, in_dim1=in_dim) 70 | self.niters = niters 71 | 72 | def forward(self, x, y): 73 | transformed_xs = [] 74 | device = x.device 75 | B = x.size()[0] 76 | transformed_x = torch.clone(x) 77 | batch_R_res = torch.eye(3).to(device).unsqueeze(0).repeat(B, 1, 1) 78 | batch_t_res = torch.zeros(3, 1).to(device).unsqueeze(0).repeat(B, 1, 1) 79 | for i in range(self.niters): 80 | batch_R, batch_t, transformed_x = self.benckmark(transformed_x, y) 81 | transformed_xs.append(transformed_x) 82 | batch_R_res = torch.matmul(batch_R, batch_R_res) 83 | batch_t_res = torch.matmul(batch_R, batch_t_res) \ 84 | + torch.unsqueeze(batch_t, -1) 85 | transformed_x = transformed_x.permute(0, 2, 1).contiguous() 86 | batch_t_res = torch.squeeze(batch_t_res, dim=-1) 87 | #transformed_x = transformed_x.permute(0, 2, 1).contiguous() 88 | return batch_R_res, batch_t_res, transformed_xs 89 | 90 | 91 | if __name__ == '__main__': 92 | x, y = torch.randn(4, 3, 5), torch.randn(4, 3, 5) 93 | net = IterativeBenchmark(in_dim=3, niters=2, gn=False) 94 | print(net) 95 | batch_R, batch_t,
transformed_x = net(x, y) -------------------------------------------------------------------------------- /models/fgr.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import open3d as o3d 3 | 4 | 5 | def fpfh(pcd, normals): 6 | pcd.normals = o3d.utility.Vector3dVector(normals) 7 | pcd_fpfh = o3d.registration.compute_fpfh_feature( 8 | pcd, 9 | o3d.geometry.KDTreeSearchParamHybrid(radius=0.3, max_nn=64)) 10 | return pcd_fpfh 11 | 12 | 13 | def execute_fast_global_registration(source, target, source_fpfh, target_fpfh): 14 | distance_threshold = 0.01 15 | result = o3d.registration.registration_fast_based_on_feature_matching( 16 | source, target, source_fpfh, target_fpfh, 17 | o3d.registration.FastGlobalRegistrationOption( 18 | maximum_correspondence_distance=distance_threshold)) 19 | transformation = result.transformation 20 | estimate = copy.deepcopy(source) 21 | estimate.transform(transformation) 22 | R, t = transformation[:3, :3], transformation[:3, 3] 23 | return R, t, estimate 24 | 25 | 26 | def fgr(source, target, src_normals, tgt_normals): 27 | source_fpfh = fpfh(source, src_normals) 28 | target_fpfh = fpfh(target, tgt_normals) 29 | R, t, estimate = execute_fast_global_registration(source=source, 30 | target=target, 31 | source_fpfh=source_fpfh, 32 | target_fpfh=target_fpfh) 33 | return R, t, estimate 34 | -------------------------------------------------------------------------------- /models/icp.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import open3d as o3d 4 | 5 | 6 | def icp(source, target): 7 | max_correspondence_distance = 2 # 0.5 in RPM-Net 8 | init = np.eye(4, dtype=np.float32) 9 | estimation_method = o3d.registration.TransformationEstimationPointToPoint() 10 | 11 | reg_p2p = o3d.registration.registration_icp( 12 | source=source, 13 | target=target, 14 | init=init, 15 | 
max_correspondence_distance=max_correspondence_distance, 16 | estimation_method=estimation_method 17 | ) 18 | 19 | transformation = reg_p2p.transformation 20 | estimate = copy.deepcopy(source) 21 | estimate.transform(transformation) 22 | R, t = transformation[:3, :3], transformation[:3, 3] 23 | return R, t, estimate -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | numpy 3 | torch==1.4.0 4 | tensorboard 5 | h5py 6 | scipy -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dist import get_dists 2 | from .format import readpcd, npy2pcd, pcd2npy 3 | from .process import pc_normalize, random_select_points, \ 4 | generate_random_rotation_matrix, generate_random_tranlation_vector, \ 5 | transform, batch_transform, quat2mat, batch_quat2mat, mat2quat, \ 6 | jitter_point_cloud, shift_point_cloud, random_scale_point_cloud, inv_R_t, \ 7 | random_crop 8 | from .time import time_calc -------------------------------------------------------------------------------- /utils/dist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_dists(points1, points2): 5 | ''' 6 | Calculate squared distances between two groups of points 7 | :param points1: shape=(B, N, 3) 8 | :param points2: shape=(B, M, 3) 9 | :return: shape=(B, N, M) 10 | ''' 11 | B, N, C = points1.shape 12 | _, M, _ = points2.shape 13 | dists = torch.sum(torch.pow(points1, 2), dim=-1).view(B, N, 1) + \ 14 | torch.sum(torch.pow(points2, 2), dim=-1).view(B, 1, M) 15 | dists -= 2 * torch.matmul(points1, points2.permute(0, 2, 1)) 16 | #dists = torch.where(dists < 0, torch.ones_like(dists) * 1e-7, dists) # Very Important for dist = 0.
17 | return dists.float() -------------------------------------------------------------------------------- /utils/format.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import open3d as o3d 4 | 5 | 6 | def readpcd(path, rtype='pcd'): 7 | assert rtype in ['pcd', 'npy'] 8 | pcd = o3d.io.read_point_cloud(path) 9 | if rtype == 'pcd': 10 | return pcd 11 | npy = np.asarray(pcd.points).astype(np.float32) 12 | return npy 13 | 14 | 15 | def npy2pcd(npy, ind=-1): 16 | colors = [[1.0, 0, 0], 17 | [0, 1.0, 0], 18 | [0, 0, 1.0]] 19 | color = colors[ind] if ind < 3 else [random.random() for _ in range(3)] 20 | pcd = o3d.geometry.PointCloud() 21 | pcd.points = o3d.utility.Vector3dVector(npy) 22 | if ind >= 0: 23 | pcd.paint_uniform_color(color) 24 | return pcd 25 | 26 | 27 | def pcd2npy(pcd): 28 | npy = np.asarray(pcd.points) 29 | return npy -------------------------------------------------------------------------------- /utils/process.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def pc_normalize(pc): 7 | mean = np.mean(pc, axis=0) 8 | pc -= mean 9 | m = np.max(np.sqrt(np.sum(np.power(pc, 2), axis=1))) 10 | pc /= m 11 | return pc 12 | 13 | 14 | def random_select_points(pc, m): 15 | if m < 0: 16 | idx = np.arange(pc.shape[0]) 17 | np.random.shuffle(idx) 18 | return pc[idx, :] 19 | n = pc.shape[0] 20 | replace = False if n >= m else True 21 | idx = np.random.choice(n, size=(m, ), replace=replace) 22 | return pc[idx, :] 23 | 24 | 25 | def generate_rotation_x_matrix(theta): 26 | mat = np.eye(3, dtype=np.float32) 27 | mat[1, 1] = math.cos(theta) 28 | mat[1, 2] = -math.sin(theta) 29 | mat[2, 1] = math.sin(theta) 30 | mat[2, 2] = math.cos(theta) 31 | return mat 32 | 33 | 34 | def generate_rotation_y_matrix(theta): 35 | mat = np.eye(3, dtype=np.float32) 36 | mat[0, 0] = math.cos(theta) 37 | mat[0, 2] 
= math.sin(theta) 38 | mat[2, 0] = -math.sin(theta) 39 | mat[2, 2] = math.cos(theta) 40 | return mat 41 | 42 | 43 | def generate_rotation_z_matrix(theta): 44 | mat = np.eye(3, dtype=np.float32) 45 | mat[0, 0] = math.cos(theta) 46 | mat[0, 1] = -math.sin(theta) 47 | mat[1, 0] = math.sin(theta) 48 | mat[1, 1] = math.cos(theta) 49 | return mat 50 | 51 | 52 | def generate_random_rotation_matrix(angle1=-45, angle2=45): 53 | thetax, thetay, thetaz = np.random.uniform(angle1, angle2, size=(3,)) 54 | matx = generate_rotation_x_matrix(thetax / 180 * math.pi) 55 | maty = generate_rotation_y_matrix(thetay / 180 * math.pi) 56 | matz = generate_rotation_z_matrix(thetaz / 180 * math.pi) 57 | return np.dot(matz, np.dot(maty, matx)) 58 | 59 | 60 | def generate_random_tranlation_vector(range1=-1, range2=1): 61 | translation_vector = np.random.uniform(range1, range2, size=(3, )).astype(np.float32) 62 | return translation_vector 63 | 64 | 65 | def transform(pc, R, t=None): 66 | pc = np.dot(pc, R.T) 67 | if t is not None: 68 | pc = pc + t 69 | return pc 70 | 71 | 72 | def batch_transform(batch_pc, batch_R, batch_t=None): 73 | ''' 74 | Apply p' = p @ R^T + t to every point of every batch element. 75 | :param batch_pc: shape=(B, N, 3) 76 | :param batch_R: shape=(B, 3, 3) 77 | :param batch_t: shape=(B, 3) 78 | :return: shape=(B, N, 3) 79 | ''' 80 | transformed_pc = torch.matmul(batch_pc, batch_R.permute(0, 2, 1).contiguous()) 81 | if batch_t is not None: 82 | transformed_pc = transformed_pc + torch.unsqueeze(batch_t, 1) 83 | return transformed_pc 84 | 85 | 86 | # The conversion between unit quaternions and rotation matrices follows 87 | # https://zhuanlan.zhihu.com/p/45404840 88 | 89 | def quat2mat(quat): 90 | w, x, y, z = quat 91 | R = np.zeros((3, 3), dtype=np.float32) 92 | R[0][0] = 1 - 2*y*y - 2*z*z 93 | R[0][1] = 2*x*y - 2*z*w 94 | R[0][2] = 2*x*z + 2*y*w 95 | R[1][0] = 2*x*y + 2*z*w 96 | R[1][1] = 1 - 2*x*x - 2*z*z 97 | R[1][2] = 2*y*z - 2*x*w 98 | R[2][0] = 2*x*z - 2*y*w 99 | R[2][1] = 2*y*z + 2*x*w 100 | R[2][2] = 1 - 2*x*x - 2*y*y
101 | return R 102 | 103 | 104 | def batch_quat2mat(batch_quat): 105 | ''' 106 | Convert unit quaternions (w, x, y, z) to rotation matrices. 107 | :param batch_quat: shape=(B, 4) 108 | :return: shape=(B, 3, 3) 109 | ''' 110 | w, x, y, z = batch_quat[:, 0], batch_quat[:, 1], batch_quat[:, 2], \ 111 | batch_quat[:, 3] 112 | device = batch_quat.device 113 | B = batch_quat.size()[0] 114 | R = torch.zeros(dtype=torch.float, size=(B, 3, 3)).to(device) 115 | R[:, 0, 0] = 1 - 2 * y * y - 2 * z * z 116 | R[:, 0, 1] = 2 * x * y - 2 * z * w 117 | R[:, 0, 2] = 2 * x * z + 2 * y * w 118 | R[:, 1, 0] = 2 * x * y + 2 * z * w 119 | R[:, 1, 1] = 1 - 2 * x * x - 2 * z * z 120 | R[:, 1, 2] = 2 * y * z - 2 * x * w 121 | R[:, 2, 0] = 2 * x * z - 2 * y * w 122 | R[:, 2, 1] = 2 * y * z + 2 * x * w 123 | R[:, 2, 2] = 1 - 2 * x * x - 2 * y * y 124 | return R 125 | 126 | 127 | def mat2quat(mat): 128 | w = math.sqrt(mat[0, 0] + mat[1, 1] + mat[2, 2] + 1) / 2  # assumes w > 0, i.e. rotation angle < 180 degrees 129 | x = (mat[2, 1] - mat[1, 2]) / (4 * w) 130 | y = (mat[0, 2] - mat[2, 0]) / (4 * w) 131 | z = (mat[1, 0] - mat[0, 1]) / (4 * w) 132 | return w, x, y, z 133 | 134 | 135 | def jitter_point_cloud(pc, sigma=0.01, clip=0.05): 136 | N, C = pc.shape 137 | assert clip > 0 138 | jittered_data = np.clip(sigma * np.random.randn(N, C), -1*clip, clip).astype(np.float32) 139 | jittered_data += pc 140 | return jittered_data 141 | 142 | 143 | def shift_point_cloud(pc, shift_range=0.1): 144 | N, C = pc.shape 145 | shifts = np.random.uniform(-shift_range, shift_range, (1, C)).astype(np.float32) 146 | pc += shifts 147 | return pc 148 | 149 | 150 | def random_scale_point_cloud(pc, scale_low=0.8, scale_high=1.25): 151 | scale = np.random.uniform(scale_low, scale_high, 1) 152 | pc *= scale 153 | return pc 154 | 155 | 156 | def inv_R_t(R, t): 157 | inv_R = R.permute(0, 2, 1).contiguous() 158 | inv_t = - inv_R @ t[..., None] 159 | return inv_R, torch.squeeze(inv_t, -1) 160 | 161 | 162 | def uniform_2_sphere(num: int = None): 163 | """Uniform sampling on a 2-sphere 164 | 165 | Source: https://gist.github.com/andrewbolster/10274979
166 | 167 | Args: 168 | num: Number of vectors to sample (or None if single) 169 | 170 | Returns: 171 | Random Vector (np.ndarray) of size (num, 3) with norm 1. 172 | If num is None returned value will have size (3,) 173 | 174 | """ 175 | if num is not None: 176 | phi = np.random.uniform(0.0, 2 * np.pi, num) 177 | cos_theta = np.random.uniform(-1.0, 1.0, num) 178 | else: 179 | phi = np.random.uniform(0.0, 2 * np.pi) 180 | cos_theta = np.random.uniform(-1.0, 1.0) 181 | 182 | theta = np.arccos(cos_theta) 183 | x = np.sin(theta) * np.cos(phi) 184 | y = np.sin(theta) * np.sin(phi) 185 | z = np.cos(theta) 186 | 187 | return np.stack((x, y, z), axis=-1) 188 | 189 | 190 | def random_crop(pc, p_keep): 191 | rand_xyz = uniform_2_sphere() 192 | centroid = np.mean(pc[:, :3], axis=0) 193 | pc_centered = pc[:, :3] - centroid 194 | 195 | dist_from_plane = np.dot(pc_centered, rand_xyz) 196 | mask = dist_from_plane > np.percentile(dist_from_plane, (1.0 - p_keep) * 100) 197 | return pc[mask, :] -------------------------------------------------------------------------------- /utils/time.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | def time_calc(func): 5 | def wrapper(*args, **kargs): 6 | start_time = time.time() 7 | f = func(*args, **kargs) 8 | print('{}: {:.2f} s'.format(func.__name__, time.time() - start_time)) 9 | return f 10 | return wrapper 11 | --------------------------------------------------------------------------------
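The pose utilities above use a row-vector convention (`transform` and `batch_transform` compute `p @ R.T + t`), and `IterativeBenchmark.forward` chains the per-iteration outputs as `R_res = R_i @ R_res` and `t_res = R_i @ t_res + t_i`. The NumPy check below is a sketch of that bookkeeping: `quat2mat` and `mat2quat` restate the formulas from `utils/process.py` for a single quaternion `(w, x, y, z)`, while `compose` and `random_unit_quat` are hypothetical helpers added only for illustration. It assumes unit quaternions and, for `mat2quat`, `w > 0` (rotation angle below 180 degrees), since that formula divides by `4w`.

```python
import numpy as np


def quat2mat(q):
    """Unit quaternion (w, x, y, z) -> 3x3 rotation matrix (formula from utils/process.py)."""
    w, x, y, z = q
    return np.array([
        [1 - 2*y*y - 2*z*z, 2*x*y - 2*z*w, 2*x*z + 2*y*w],
        [2*x*y + 2*z*w, 1 - 2*x*x - 2*z*z, 2*y*z - 2*x*w],
        [2*x*z - 2*y*w, 2*y*z + 2*x*w, 1 - 2*x*x - 2*y*y],
    ])


def mat2quat(R):
    """Rotation matrix -> (w, x, y, z); only well defined when w > 0."""
    w = np.sqrt(np.trace(R) + 1.0) / 2.0
    x = (R[2, 1] - R[1, 2]) / (4 * w)
    y = (R[0, 2] - R[2, 0]) / (4 * w)
    z = (R[1, 0] - R[0, 1]) / (4 * w)
    return np.array([w, x, y, z])


def transform(pc, R, t):
    """Row-vector convention used by transform/batch_transform: p' = p @ R.T + t."""
    return pc @ R.T + t


def compose(steps):
    """Hypothetical helper: accumulate (R_i, t_i) the way IterativeBenchmark.forward does."""
    R_res, t_res = np.eye(3), np.zeros(3)
    for R, t in steps:
        R_res = R @ R_res        # the newest rotation is applied last
        t_res = R @ t_res + t    # the accumulated translation is rotated by R_i
    return R_res, t_res


def random_unit_quat(rng):
    """Hypothetical helper: random unit quaternion with w >= 0."""
    q = rng.normal(size=4)
    q /= np.linalg.norm(q)
    return q if q[0] >= 0 else -q


rng = np.random.default_rng(0)
pc = rng.normal(size=(16, 3))
steps = [(quat2mat(random_unit_quat(rng)), rng.normal(size=3)) for _ in range(3)]

# Apply the three steps one after another ...
sequential = pc
for R, t in steps:
    sequential = transform(sequential, R, t)

# ... and compare with a single application of the composed pose.
R_res, t_res = compose(steps)
print(np.allclose(transform(pc, R_res, t_res), sequential))  # True
```

Reversing the product in `compose` (for example `R_res @ R`) generally breaks the equality, which is a quick way to see why the multiplication order in `IterativeBenchmark.forward` matters.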