├── config └── __init__.py ├── engine ├── __init__.py └── organize_loss.py ├── losses ├── __init__.py ├── nn_distance │ ├── setup.py │ ├── src │ │ └── nn_distance.cpp │ └── chamfer_loss.py ├── consistency_loss.py ├── backbone_loss.py ├── uncertainty_loss.py ├── shape_prior_loss.py └── fs_net_loss.py ├── tools ├── __init__.py ├── visualize │ ├── __init__.py │ ├── copy_pic.py │ ├── combine_dpn_camera.py │ ├── combine_dualposenet_result.py │ ├── combine_dpn_camera_dict.py │ ├── combine_sgpa_result.py │ ├── visualize_invalid_data.py │ ├── combine_result.py │ └── plot_map_sgpa.py ├── torch_utils │ ├── __init__.py │ └── solver │ │ ├── __init__.py │ │ ├── over9000.py │ │ ├── sgdp.py │ │ ├── adamp.py │ │ ├── lookahead.py │ │ ├── ralamb.py │ │ ├── optimize.py │ │ └── rmsprop_tf.py ├── pyTorchChamferDistance │ ├── __init__.py │ ├── chamfer_distance.py │ ├── chamfer_distance.cu │ └── chamfer_distance.cpp ├── plane_utils.py ├── training_utils.py ├── perspective3d.py ├── rot_utils.py ├── solver_utils.py ├── shape_prior_utils.py ├── dataset_utils.py └── align_utils.py ├── datasets └── __init__.py ├── evaluation ├── __init__.py ├── refine_mug_in_detection_dict_camera.py └── refine_mug_in_detection_dict.py ├── network ├── __init__.py ├── backbone_repo │ ├── __init__.py │ ├── ATSA │ │ ├── __init__.py │ │ ├── depth_attention_module.py │ │ └── model_depth.py │ ├── Resnet │ │ ├── __init__.py │ │ └── exts │ │ │ ├── setup.py │ │ │ ├── guideconv.cpp │ │ │ └── guideconv_kernel.cu │ └── pspnet.py ├── fs_net_repo │ ├── __init__.py │ ├── PoseR.py │ ├── PoseTs.py │ ├── pcl_encoder.py │ └── FaceRecon.py └── point_sample │ └── __init__.py ├── prepare_data ├── __init__.py ├── lib │ ├── nn_distance │ │ ├── setup.py │ │ ├── src │ │ │ └── nn_distance.cpp │ │ └── chamfer_loss.py │ ├── loss.py │ ├── auto_encoder.py │ ├── network.py │ ├── pspnet.py │ └── align.py ├── renderer.py └── gen_pts.py ├── .gitignore ├── pic └── pipeline.png ├── .idea ├── misc.xml ├── vcs.xml ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml ├── GPV_pose_shape_prior_release.iml └── deployment.xml ├── nnutils ├── demo_utils.py ├── logger.py └── utils.py ├── README.md └── requirements.txt /config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /engine/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prepare_data/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/visualize/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /network/backbone_repo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /network/fs_net_repo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /network/point_sample/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /network/backbone_repo/ATSA/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/torch_utils/solver/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /config/config_old/ 2 | .idea -------------------------------------------------------------------------------- /network/backbone_repo/Resnet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/pyTorchChamferDistance/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pic/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lolrudy/RBP_Pose/HEAD/pic/pipeline.png -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /nnutils/demo_utils.py: 
-------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Written by Yufei Ye (https://github.com/JudyYe) 3 | # -------------------------------------------------------- 4 | from __future__ import print_function 5 | 6 | import os 7 | import os.path as osp 8 | import numpy as np 9 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /network/backbone_repo/Resnet/exts/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='GuideConv', 6 | ext_modules=[ 7 | CUDAExtension('GuideConv', [ 8 | 'guideconv.cpp', 9 | 'guideconv_kernel.cu', 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /tools/visualize/copy_pic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | dir = '/data2/zrd/GPV_pose_result/visualize_bbox_all_separate_CAMERA' 4 | output_dir = '/data2/zrd/GPV_pose_result/visualize_bbox_CAMERA_sample' 5 | filename_list = os.listdir(dir) 6 | filename_list.sort() 7 | if not os.path.exists(output_dir): 8 | os.makedirs(output_dir) 9 | for name in filename_list[:3000]: 10 | shutil.copyfile(os.path.join(dir, name), os.path.join(output_dir, name)) -------------------------------------------------------------------------------- /tools/visualize/combine_dpn_camera.py: -------------------------------------------------------------------------------- 1 | data_dir = '/data2/zrd/GPV_pose_result/dualposenet_results/CAMERA25' 2 | import os, mmcv 3 | 4 | filelist = os.listdir(data_dir) 5 | filelist.sort() 6 | print(filelist) 7 | result_list = [] 8 | for filename in filelist: 9 | result = mmcv.load(os.path.join(data_dir, filename)) 10 | result_list.append(result) 11 | 12 | mmcv.dump(result_list, '/data2/zrd/GPV_pose_result/dualposenet_results/CAMERA25_results.pkl') 13 | -------------------------------------------------------------------------------- /tools/visualize/combine_dualposenet_result.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mmcv 3 | result_dir = '/data2/zrd/GPV_pose_result/dualposenet_results/REAL275' 4 | save_path = '/data2/zrd/GPV_pose_result/dualposenet_results/REAL275_results.pkl' 5 | 6 | file_list = os.listdir(result_dir) 7 | total_result = [] 8 | for file_name in file_list: 9 | result = mmcv.load(os.path.join(result_dir,file_name)) 10 | total_result.append(result) 11 | 12 | mmcv.dump(total_result, save_path) 13 | 14 | -------------------------------------------------------------------------------- /losses/nn_distance/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='nn_distance', 7 | ext_modules=[ 8 | CUDAExtension('nn_distance', [ 9 | 'src/nn_distance.cpp', 10 | 'src/nn_distance_cuda.cu', ], 11 | extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}) 12 | ], 13 | 14 | cmdclass={ 15 | 'build_ext': 
BuildExtension 16 | }) 17 | -------------------------------------------------------------------------------- /prepare_data/lib/nn_distance/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='nn_distance', 7 | ext_modules=[ 8 | CUDAExtension('nn_distance', [ 9 | 'src/nn_distance.cpp', 10 | 'src/nn_distance_cuda.cu', ], 11 | extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}) 12 | ], 13 | 14 | cmdclass={ 15 | 'build_ext': BuildExtension 16 | }) 17 | -------------------------------------------------------------------------------- /.idea/GPV_pose_shape_prior_release.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /tools/visualize/combine_dpn_camera_dict.py: -------------------------------------------------------------------------------- 1 | data_dir = '/data2/zrd/GPV_pose_result/dualposenet_results/CAMERA25' 2 | import os, mmcv 3 | 4 | filelist = os.listdir(data_dir) 5 | filelist.sort() 6 | print(filelist) 7 | result_dict = {} 8 | for filename in filelist: 9 | name = filename[8:-4] 10 | image_path = 'data/camera/'+name.replace('_', '/') 11 | result = mmcv.load(os.path.join(data_dir, filename)) 12 | result_dict[image_path] = result 13 | 14 | mmcv.dump(result_dict, '/data2/zrd/GPV_pose_result/dualposenet_results/CAMERA25_results_dict.pkl') 15 | -------------------------------------------------------------------------------- /tools/visualize/combine_sgpa_result.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mmcv 3 | result_dir = '/data/zrd/datasets/NOCS/results/sgpa_results/real' 4 | save_path = '/data/zrd/datasets/NOCS/results/sgpa_results/REAL275_results.pkl' 5 | 6 | file_list = os.listdir(result_dir) 7 | total_result = [] 8 | total_result_dict = {} 9 | for file_name in file_list: 10 | if file_name.endswith('pkl') and file_name.startswith('result'): 11 | result = mmcv.load(os.path.join(result_dir,file_name)) 12 | total_result.append(result) 13 | print(len(total_result)) 14 | mmcv.dump(total_result, save_path) 15 | 16 | -------------------------------------------------------------------------------- /tools/visualize/visualize_invalid_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | data_dir = '/data/zrd/project/GPV_pose_shape_prior/output/modelsave_prior_laptop_camera+real/invalid_image' 3 | import matplotlib.pyplot as plt 4 | import os 5 | rgb_path = os.path.join(data_dir, 'invalid_num_12_0_rgb.npy') 6 | rgb = np.load(rgb_path) 7 | rgb = np.rollaxis(rgb, 0, 3).astype(int) 8 | depth_path = rgb_path.replace('rgb', 'depth') 9 | depth = np.load(depth_path) 10 | depth = np.rollaxis(depth, 0, 3) 11 | depth = depth / np.max(depth) 12 | nocs_path = rgb_path.replace('rgb', 'coord') 13 | nocs = np.load(nocs_path) 14 | nocs = np.rollaxis(nocs, 0, 3) + 0.5 15 | plt.imshow(rgb) 16 | plt.show() 17 | plt.imshow(depth) 18 | plt.show() 19 | plt.imshow(nocs) 20 | plt.show() 21 | -------------------------------------------------------------------------------- /tools/torch_utils/solver/over9000.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/mgrankin/over9000 3 | #### 4 | 
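# Usage sketch (illustrative values only, not settings mandated by this repo): Over9000
# wraps a Ralamb (RAdam + LARS) optimizer inside Lookahead, so it is constructed like any
# other torch optimizer, e.g.
#   model = torch.nn.Linear(10, 2)
#   optimizer = Over9000(model.parameters(), lr=1e-3, weight_decay=1e-4)
#   optimizer.zero_grad(); model(torch.randn(4, 10)).sum().backward(); optimizer.step()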
5 | # import torch, math 6 | # from torch.optim.optimizer import Optimizer 7 | # import itertools as it 8 | from .lookahead import Lookahead 9 | from .ralamb import Ralamb 10 | 11 | 12 | # RangerLars = Over9000 = RAdam + LARS + LookAHead 13 | 14 | # Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py 15 | # RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20 16 | 17 | 18 | def Over9000(params, alpha=0.5, k=6, *args, **kwargs): 19 | ralamb = Ralamb(params, *args, **kwargs) 20 | return Lookahead(ralamb, alpha, k) 21 | 22 | 23 | RangerLars = Over9000 24 | -------------------------------------------------------------------------------- /losses/consistency_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import absl.flags as flags 4 | from absl import app 5 | 6 | FLAGS = flags.FLAGS # can control the weight of each term here 7 | 8 | class consistency_loss(nn.Module): 9 | def __init__(self, beta): 10 | super(consistency_loss, self).__init__() 11 | self.loss_func = nn.SmoothL1Loss(beta=beta) 12 | 13 | def forward(self, name_list, pred_list, gt_list): 14 | loss_list = {} 15 | if 'nocs_dist_consistency' in name_list: 16 | loss_list['nocs_dist'] = self.loss_func(pred_list['face_dis_prior'], pred_list['face_dis_pred']) 17 | return loss_list 18 | 19 | def cal_obj_mask(self, p_mask, g_mask): 20 | return self.loss_func(p_mask, g_mask.long().squeeze()) 21 | 22 | -------------------------------------------------------------------------------- /losses/backbone_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import absl.flags as flags 4 | from absl import app 5 | 6 | FLAGS = flags.FLAGS # can control the weight of each term here 7 | 8 | class backbone_mask_loss(nn.Module): 9 | def __init__(self): 10 | super(backbone_mask_loss, self).__init__() 11 | self.loss_func = nn.CrossEntropyLoss() 12 | # self.loss_func = nn.NLLLoss(ignore_index=-1) 13 | 14 | def forward(self, name_list, pred_list, gt_list): 15 | loss_list = {} 16 | 17 | if 'Obj_mask' in name_list: 18 | # gt_mask_np = gt_list['Mask'].detach().cpu().numpy() 19 | # pred_mask_np = pred_list['Mask'].detach().cpu().numpy() 20 | loss_list['obj_mask'] = FLAGS.mask_w * self.cal_obj_mask(pred_list['Mask'], gt_list['Mask']) 21 | return loss_list 22 | 23 | def cal_obj_mask(self, p_mask, g_mask): 24 | return self.loss_func(p_mask, g_mask.long().squeeze()) 25 | 26 | -------------------------------------------------------------------------------- /losses/nn_distance/src/nn_distance.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | int nn_distance_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 5 | 6 | 7 | int nn_distance_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 8 | 9 | 10 | int nn_distance_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { 11 | return nn_distance_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 12 | } 13 | 14 | 15 | int nn_distance_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 16 | 
at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 17 | return nn_distance_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); 18 | } 19 | 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("forward", &nn_distance_forward, "nn_distance forward (CUDA)"); 23 | m.def("backward", &nn_distance_backward, "nn_distance backward (CUDA)"); 24 | } -------------------------------------------------------------------------------- /prepare_data/lib/nn_distance/src/nn_distance.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | int nn_distance_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 5 | 6 | 7 | int nn_distance_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 8 | 9 | 10 | int nn_distance_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { 11 | return nn_distance_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 12 | } 13 | 14 | 15 | int nn_distance_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 16 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 17 | return nn_distance_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); 18 | } 19 | 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("forward", &nn_distance_forward, "nn_distance forward (CUDA)"); 23 | m.def("backward", &nn_distance_backward, "nn_distance backward (CUDA)"); 24 | } -------------------------------------------------------------------------------- /losses/uncertainty_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def laplacian_aleatoric_uncertainty_loss(input, target, log_variance, balance_weight, reduction='mean', sum_last_dim=False): 5 | ''' 6 | References: 7 | MonoPair: Monocular 3D Object Detection Using Pairwise Spatial Relationships, CVPR'20 8 | Geometry and Uncertainty in Deep Learning for Computer Vision, University of Cambridge 9 | ''' 10 | assert reduction in ['mean', 'sum'] 11 | if sum_last_dim: 12 | loss = 1.4142 * torch.exp(-0.5*log_variance) * torch.abs(input - target).sum(-1) + balance_weight * 0.5 * log_variance 13 | else: 14 | loss = 1.4142 * torch.exp(-0.5*log_variance) * torch.abs(input - target) + balance_weight * 0.5 * log_variance 15 | return loss.mean() if reduction == 'mean' else loss.sum() 16 | 17 | 18 | def gaussian_aleatoric_uncertainty_loss(input, target, log_variance, reduction='mean'): 19 | ''' 20 | References: 21 | What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?, Neuips'17 22 | Geometry and Uncertainty in Deep Learning for Computer Vision, University of Cambridge 23 | ''' 24 | assert reduction in ['mean', 'sum'] 25 | loss = 0.5 * torch.exp(-log_variance) * torch.abs(input - target)**2 + 0.5 * log_variance 26 | return loss.mean() if reduction == 'mean' else loss.sum() 27 | 28 | 29 | 30 | if __name__ == '__main__': 31 | pass 32 | -------------------------------------------------------------------------------- /network/backbone_repo/Resnet/exts/guideconv.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jie on 09/02/19. 
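// Pybind11 bindings for the local "guided convolution" forward/backward CUDA ops used by
// the ResNet backbone. Per the setup.py in this directory, this file is compiled together
// with guideconv_kernel.cu into a CUDAExtension named 'GuideConv' (pinned as
// GuideConv==0.0.0 in requirements.txt). Shape convention (see the comments below): a is
// BCHW, b is BCKKHW, where K is the local kernel size recovered as sqrt(b.size(1) / Co).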
3 | // 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | void Conv2d_LF_Cuda(at::Tensor x, at::Tensor y, at::Tensor z, size_t N1, size_t N2, size_t Ci, size_t Co, size_t B, 11 | size_t K); 12 | 13 | void 14 | Conv2d_LB_Cuda(at::Tensor x, at::Tensor y, at::Tensor gx, at::Tensor gy, at::Tensor gz, size_t N1, size_t N2, size_t Ci, 15 | size_t Co, size_t B, size_t K); 16 | 17 | 18 | at::Tensor Conv2dLocal_F( 19 | at::Tensor a, // BCHW 20 | at::Tensor b // BCKKHW 21 | ) { 22 | int N1, N2, Ci, Co, K, B; 23 | B = a.size(0); 24 | Ci = a.size(1); 25 | N1 = a.size(2); 26 | N2 = a.size(3); 27 | Co = Ci; 28 | K = sqrt(b.size(1) / Co); 29 | auto c = at::zeros_like(a); 30 | Conv2d_LF_Cuda(a, b, c, N1, N2, Ci, Co, B, K); 31 | return c; 32 | } 33 | 34 | 35 | std::tuple Conv2dLocal_B( 36 | at::Tensor a, 37 | at::Tensor b, 38 | at::Tensor gc 39 | ) { 40 | int N1, N2, Ci, Co, K, B; 41 | B = a.size(0); 42 | Ci = a.size(1); 43 | N1 = a.size(2); 44 | N2 = a.size(3); 45 | Co = Ci; 46 | K = sqrt(b.size(1) / Co); 47 | auto ga = at::zeros_like(a); 48 | auto gb = at::zeros_like(b); 49 | Conv2d_LB_Cuda(a, b, ga, gb, gc, N1, N2, Ci, Co, B, K); 50 | return std::make_tuple(ga, gb); 51 | } 52 | 53 | 54 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m 55 | ) { 56 | m.def("Conv2dLocal_F", &Conv2dLocal_F, "Conv2dLocal Forward (CUDA)"); 57 | m.def("Conv2dLocal_B", &Conv2dLocal_B, "Conv2dLocal Backward (CUDA)"); 58 | } -------------------------------------------------------------------------------- /engine/organize_loss.py: -------------------------------------------------------------------------------- 1 | def control_loss(Train_stage): 2 | if Train_stage == 'PoseNet_only': 3 | name_mask_list = [] 4 | name_fs_list = ['Rot1', 'Rot2', 'Rot1_cos', 'Rot2_cos', 'Rot_regular', 'Tran', 'Size', 'R_con'] 5 | name_recon_list = ['Per_point', 'Point_voting'] 6 | name_prop_list = ['Prop_pm', 'Prop_sym', 'Prop_point_cano'] 7 | elif Train_stage == 'seman_encoder_only': 8 | name_mask_list = ['Obj_mask'] 9 | name_fs_list = [] 10 | name_recon_list = [] 11 | name_prop_list = [] 12 | elif Train_stage == 'shape_prior_only': 13 | name_mask_list = [] 14 | name_fs_list = ['Rot1', 'Rot2', 'Rot1_cos', 'Rot2_cos', 'Rot_regular', 'Tran', 'Size', 'R_con'] 15 | name_recon_list = [] 16 | name_prop_list = ['Prop_pm', 'Prop_sym', 'Prop_point_cano'] 17 | elif Train_stage == 'prior+recon': 18 | name_mask_list = [] 19 | name_fs_list = ['Rot1', 'Rot2', 'Rot1_cos', 'Rot2_cos', 'Rot_regular', 'Tran', 'Size', 'R_con'] 20 | name_recon_list = ['Per_point', 'Point_voting'] 21 | name_prop_list = ['Prop_pm', 'Prop_sym', 'Prop_point_cano'] 22 | elif Train_stage == 'prior+recon+novote': 23 | name_mask_list = [] 24 | name_fs_list = ['Rot1', 'Rot2', 'Rot1_cos', 'Rot2_cos', 'Rot_regular', 'Tran', 'Size', 'R_con'] 25 | name_recon_list = ['Per_point', ] 26 | name_prop_list = ['Prop_pm', 'Prop_sym', 'Prop_point_cano'] 27 | elif Train_stage == 'FSNet_only': 28 | name_mask_list = [] 29 | name_fs_list = ['Rot1', 'Rot2', 'Tran', 'Size', 'Recon'] 30 | name_recon_list = [] 31 | name_prop_list = [] 32 | else: 33 | raise NotImplementedError 34 | return name_mask_list, name_fs_list, name_recon_list, name_prop_list 35 | -------------------------------------------------------------------------------- /network/fs_net_repo/PoseR.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import absl.flags as flags 5 | from absl import app 6 | from config.config import * 7 | FLAGS = 
flags.FLAGS 8 | from mmcv.cnn import normal_init, constant_init 9 | from torch.nn.modules.batchnorm import _BatchNorm 10 | 11 | class RotHead(nn.Module): 12 | def __init__(self): 13 | super(RotHead, self).__init__() 14 | self.f = FLAGS.feat_pcl + FLAGS.feat_global_pcl + FLAGS.feat_seman 15 | self.k = FLAGS.R_c 16 | 17 | self.conv1 = torch.nn.Conv1d(self.f, 1024, 1) 18 | 19 | self.conv2 = torch.nn.Conv1d(1024, 256, 1) 20 | self.conv3 = torch.nn.Conv1d(256, 256, 1) 21 | self.conv4 = torch.nn.Conv1d(256, self.k, 1) 22 | self.drop1 = nn.Dropout(0.2) 23 | self.bn1 = nn.BatchNorm1d(1024) 24 | self.bn2 = nn.BatchNorm1d(256) 25 | self.bn3 = nn.BatchNorm1d(256) 26 | self._init_weights() 27 | 28 | def forward(self, x): 29 | x = F.relu(self.bn1(self.conv1(x))) 30 | x = F.relu(self.bn2(self.conv2(x))) 31 | 32 | x = torch.max(x, 2, keepdim=True)[0] 33 | 34 | x = F.relu(self.bn3(self.conv3(x))) 35 | x = self.drop1(x) 36 | x = self.conv4(x) 37 | 38 | x = x.squeeze(2) 39 | x = x.contiguous() 40 | 41 | return x 42 | 43 | def _init_weights(self): 44 | for m in self.modules(): 45 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d)): 46 | normal_init(m, std=0.001) 47 | elif isinstance(m, (_BatchNorm, nn.GroupNorm)): 48 | constant_init(m, 1) 49 | elif isinstance(m, nn.Linear): 50 | normal_init(m, std=0.001) 51 | 52 | def main(argv): 53 | points = torch.rand(2, 1350, 1500) # batchsize x feature x numofpoint 54 | rot_head = RotHead() 55 | rot = rot_head(points) 56 | t = 1 57 | 58 | 59 | if __name__ == "__main__": 60 | app.run(main) 61 | -------------------------------------------------------------------------------- /tools/visualize/combine_result.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import mmcv 3 | import os 4 | import cv2 5 | from tools.eval_utils import get_bbox 6 | from tools.dataset_utils import crop_resize_by_warp_affine 7 | import matplotlib.pyplot as plt 8 | import math 9 | from evaluation.eval_utils_cass import get_3d_bbox, transform_coordinates_3d, compute_3d_iou_new 10 | from tqdm import tqdm 11 | 12 | pick_list = ['scene_1/0097', 'scene_2/0500', 'scene_3/0438', 'scene_4/0040', 'scene_6/0217'] 13 | pick_idx_list = [(0,1,2), (0,1,2), (1,3,4), (1,2,5), (0,1,3)] 14 | detection_dir = '/data2/zrd/datasets/NOCS/detection_dualposenet/data/segmentation_results/REAL275/results_test_' 15 | img_path_prefix = 'data/real/test/' 16 | dataset_dir = '/data2/zrd/datasets/NOCS' 17 | result_dir = '/data2/zrd/GPV_pose_result/visualize_bbox_pick' 18 | save_dir = '/data2/zrd/GPV_pose_result/visualize_bbox_pick_combine' 19 | 20 | if not os.path.exists(result_dir): 21 | os.makedirs(result_dir) 22 | pick_list = [img_path_prefix+item for item in pick_list] 23 | blank_len = 20 24 | 25 | for i, img_path in tqdm(enumerate(pick_list)): 26 | final_img = None 27 | fig_iou = plt.figure(figsize=(15, 10)) 28 | save_path = os.path.join(save_dir, img_path.replace('/', '_')) + f'_box_combine.png' 29 | for ll, j in enumerate(pick_idx_list[i]): 30 | our_result_path = os.path.join(result_dir, img_path.replace('/', '_')) + f'_box_{j}_our.png' 31 | our_result_pic = cv2.imread(our_result_path) 32 | dpn_result_path = os.path.join(result_dir, img_path.replace('/', '_')) + f'_box_{j}_dpn.png' 33 | dpn_result_pic = cv2.imread(dpn_result_path) 34 | blank_space = np.ones((blank_len, 256, 3))*255 35 | column = np.vstack((dpn_result_pic, blank_space, our_result_pic)) 36 | if final_img is None: 37 | final_img = column 38 | else: 39 | blank_space = 
np.ones((2*256+blank_len, blank_len, 3))*255 40 | final_img = np.hstack((final_img, blank_space, column)) 41 | cv2.imwrite(save_path, final_img) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RBP-Pose 2 | PyTorch implementation of RBP-Pose: Residual Bounding Box Projection for Category-Level Pose Estimation. 3 | 4 | [//]: # (([link](https://arxiv.org/abs/2203.07918))) 5 | 6 | ![pipeline](pic/pipeline.png) 7 | 8 | ## Required environment 9 | 10 | - Ubuntu 18.04 11 | - Python 3.8 12 | - PyTorch 1.10.1 13 | - CUDA 11.3 14 | 15 | 16 | 17 | ## Installing 18 | 19 | - Install the main requirements listed in 'requirements.txt'. 20 | - Install [Detectron2](https://github.com/facebookresearch/detectron2). 21 | 22 | ## Data Preparation 23 | To generate your own dataset, use the data preprocessing code provided in this [repository](https://github.com/mentian/object-deformnet/blob/master/preprocess/pose_data.py). 24 | Download the detection results from this [link](https://drive.google.com/drive/folders/1q8pjmHDfSUTna13F2R_gU3P-FYCjEP7A?usp=sharing). 25 | 26 | 27 | ## Trained model 28 | The trained model is available [here](https://drive.google.com/drive/folders/1q8pjmHDfSUTna13F2R_gU3P-FYCjEP7A?usp=sharing). 29 | 30 | ## Training 31 | Please note that some details have been changed from the original paper to make training more efficient. 32 | 33 | Specify the dataset directory and run the following commands. 34 | ```shell 35 | python -m engine.train --data_dir YOUR_DATA_DIR --model_save SAVE_DIR --training_stage shape_prior_only # first stage 36 | python -m engine.train --data_dir YOUR_DATA_DIR --model_save SAVE_DIR --resume 1 --resume_model MODEL_PATH --training_stage prior+recon+novote # second stage 37 | ``` 38 | 39 | Detailed configurations are in 'config/config.py'. 40 | 41 | ## Evaluation 42 | ```shell 43 | python -m evaluation.evaluate --data_dir YOUR_DATA_DIR --detection_dir DETECTION_DIR --resume 1 --resume_model MODEL_PATH --model_save SAVE_DIR 44 | ``` 45 | 46 | 47 | ## Acknowledgment 48 | Our implementation leverages code from [3dgcn](https://github.com/j1a0m0e4sNTU/3dgcn), [FS-Net](https://github.com/DC1991/FS_Net), 49 | [DualPoseNet](https://github.com/Gorilla-Lab-SCUT/DualPoseNet), and [SPD](https://github.com/mentian/object-deformnet).
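## Building the bundled CUDA extensions

The repository also ships two CUDA extensions that requirements.txt expects as locally built packages (nn-distance==0.0.0 and GuideConv==0.0.0). A minimal build sketch, assuming a CUDA toolkit matching your PyTorch build; adapt the paths if your checkout differs:

```shell
cd losses/nn_distance && python setup.py install && cd -
cd network/backbone_repo/Resnet/exts && python setup.py install && cd -
```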
50 | -------------------------------------------------------------------------------- /tools/plane_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def get_plane(pc, pc_w): 4 | if torch.max(pc_w) < 1e-6: 5 | print('plane weight is too small!!') 6 | print(torch.max(pc_w)) 7 | # return None, None, None 8 | 9 | pc_w = pc_w / torch.max(pc_w) 10 | # min least square 11 | n = pc.shape[0] 12 | A = torch.cat([pc[:, :2], torch.ones([n, 1], device=pc.device)], dim=-1) 13 | b = pc[:, 2].view(-1, 1) 14 | W = torch.diag(pc_w) 15 | WA = torch.mm(W, A) 16 | ATWA = torch.mm(A.permute(1, 0), WA) 17 | 18 | Wb = torch.mm(W, b) 19 | ATWb = torch.mm(A.permute(1, 0), Wb) 20 | try: 21 | if torch.linalg.matrix_rank(ATWA) == 3: 22 | X = torch.linalg.solve(ATWA, ATWb) 23 | else: 24 | ATWA_1 = torch.pinverse(ATWA) 25 | X = torch.mm(ATWA_1, ATWb) 26 | except: 27 | print('error when computing plane parameter due to ill-conditioned matrix') 28 | return None, None, None 29 | 30 | dn_up = torch.cat([X[0] * X[2], X[1] * X[2], -X[2]], dim=0), 31 | dn_norm = X[0] * X[0] + X[1] * X[1] + 1.0 32 | dn = dn_up[0] / dn_norm 33 | 34 | normal_n = dn / (torch.norm(dn) + 1e-10) 35 | for_p2plane = X[2] / torch.sqrt(dn_norm) 36 | return normal_n, dn, for_p2plane 37 | 38 | def get_plane_parameter(pc, pc_w): 39 | # min least square 40 | pc_w = pc_w / torch.max(pc_w) 41 | # min least square 42 | n = pc.shape[0] 43 | A = torch.cat([pc[:, :2], torch.ones([n, 1], device=pc.device)], dim=-1) 44 | b = pc[:, 2].view(-1, 1) 45 | W = torch.diag(pc_w) 46 | WA = torch.mm(W, A) 47 | ATWA = torch.mm(A.permute(1, 0), WA) 48 | 49 | Wb = torch.mm(W, b) 50 | ATWb = torch.mm(A.permute(1, 0), Wb) 51 | try: 52 | if torch.linalg.matrix_rank(ATWA) == 3: 53 | X = torch.linalg.solve(ATWA, ATWb) 54 | else: 55 | ATWA_1 = torch.pinverse(ATWA) 56 | X = torch.mm(ATWA_1, ATWb) 57 | except: 58 | # print('error when computing plane parameter due to ill-conditioned matrix') 59 | X = torch.ones((3, 1), device=A.device) 60 | return X -------------------------------------------------------------------------------- /prepare_data/lib/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .nn_distance.chamfer_loss import ChamferLoss 5 | 6 | 7 | class Loss(nn.Module): 8 | """ Loss for training DeformNet. 9 | Use NOCS coords to supervise training. 
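    The total loss returned by forward() is a weighted sum of four terms (see the code
    below): corr_wt * a smooth-L1 correspondence loss between the soft-assigned deformed
    prior coordinates and the ground-truth NOCS coordinates (quadratic, diff^2 / (2 * 0.1),
    for diff <= 0.1, and linear, diff - 0.05, beyond), entropy_wt * the entropy of the soft
    assignment matrix, cd_wt * the Chamfer distance between the deformed prior and the
    ground-truth model, and deform_wt * the L2 norm of the per-vertex deformations.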
10 | """ 11 | def __init__(self, corr_wt, cd_wt, entropy_wt, deform_wt): 12 | super(Loss, self).__init__() 13 | self.threshold = 0.1 14 | self.chamferloss = ChamferLoss() 15 | self.corr_wt = corr_wt 16 | self.cd_wt = cd_wt 17 | self.entropy_wt = entropy_wt 18 | self.deform_wt = deform_wt 19 | 20 | def forward(self, assign_mat, deltas, prior, nocs, model): 21 | """ 22 | Args: 23 | assign_mat: bs x n_pts x nv 24 | deltas: bs x nv x 3 25 | prior: bs x nv x 3 26 | """ 27 | inst_shape = prior + deltas 28 | # smooth L1 loss for correspondences 29 | soft_assign = F.softmax(assign_mat, dim=2) 30 | coords = torch.bmm(soft_assign, inst_shape) # bs x n_pts x 3 31 | diff = torch.abs(coords - nocs) 32 | less = torch.pow(diff, 2) / (2.0 * self.threshold) 33 | higher = diff - self.threshold / 2.0 34 | corr_loss = torch.where(diff > self.threshold, higher, less) 35 | corr_loss = torch.mean(torch.sum(corr_loss, dim=2)) 36 | corr_loss = self.corr_wt * corr_loss 37 | # entropy loss to encourage peaked distribution 38 | log_assign = F.log_softmax(assign_mat, dim=2) 39 | entropy_loss = torch.mean(-torch.sum(soft_assign * log_assign, 2)) 40 | entropy_loss = self.entropy_wt * entropy_loss 41 | # cd-loss for instance reconstruction 42 | cd_loss, _, _ = self.chamferloss(inst_shape, model) 43 | cd_loss = self.cd_wt * cd_loss 44 | # L2 regularizations on deformation 45 | deform_loss = torch.norm(deltas, p=2, dim=2).mean() 46 | deform_loss = self.deform_wt * deform_loss 47 | # total loss 48 | total_loss = corr_loss + entropy_loss + cd_loss + deform_loss 49 | return total_loss, corr_loss, cd_loss, entropy_loss, deform_loss 50 | -------------------------------------------------------------------------------- /nnutils/logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Written by Yufei Ye (https://github.com/JudyYe) 3 | # -------------------------------------------------------- 4 | from __future__ import print_function 5 | 6 | import os 7 | import torchvision.utils as vutils 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | 11 | class Logger(object): 12 | def __init__(self, model_name, save_dir): 13 | self.model_name = os.path.basename(model_name) 14 | self.plotter_dict = {} 15 | cmd = 'rm -rf %s' % os.path.join(save_dir, model_name) 16 | os.system(cmd) 17 | 18 | self.save_dir = os.path.join(save_dir, model_name, 'train') 19 | if not os.path.exists(self.save_dir): 20 | os.makedirs(self.save_dir) 21 | print('## Make Directory: ', self.save_dir) 22 | 23 | cmd = 'rm -rf %s' % self.save_dir 24 | os.system(cmd) 25 | print(cmd) 26 | self.tf_wr = SummaryWriter(self.save_dir) 27 | 28 | def add_loss(self, t, dictionary, pref=''): 29 | for key in dictionary: 30 | name = pref + key.replace(':', '/') 31 | self.tf_wr.add_scalar(name, dictionary[key], t) 32 | 33 | def add_hist_by_dim(self, t, z, name='', max_dim=10): 34 | dim = z.size(-1) 35 | dim = min(dim, max_dim) 36 | for d in range(dim): 37 | index = name + '/%d' % d 38 | self.tf_wr.add_histogram(index, z[:, d], t) 39 | 40 | def add_images(self, iteration, images, name=''): 41 | """ 42 | :param iteration: 43 | :param images: Tensor (N, C, H, W), in range (-1, 1) 44 | :param name: 45 | :return: 46 | """ 47 | # images = torch.stack(images, dim=0), 48 | # x = vutils.make_grid(images) 49 | images = images.cpu().detach() 50 | x = vutils.make_grid(images) 51 | self.tf_wr.add_image(name, x / 2 + 0.5, iteration) 52 | 53 | def print(self, t, epoch, losses, 
total_loss): 54 | print('[Epoch %2d] iter: %d of model' % (epoch, t), self.model_name) 55 | print('\tTotal Loss: %.6f' % total_loss) 56 | for k in losses: 57 | print('\t\t%s: %.6f' % (k, losses[k])) -------------------------------------------------------------------------------- /nnutils/utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Written by Yufei Ye (https://github.com/JudyYe) 3 | # -------------------------------------------------------- 4 | from __future__ import print_function 5 | 6 | import os 7 | import os.path as osp 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def to_cuda(datapoint): 13 | skip = ['index'] 14 | for key in datapoint: 15 | if key in skip: 16 | continue 17 | if isinstance(datapoint[key], list): 18 | datapoint[key] = [e.cuda() for e in datapoint[key]] 19 | else: 20 | if hasattr(datapoint[key], 'cuda'): 21 | datapoint[key] = datapoint[key].cuda() 22 | return datapoint 23 | 24 | 25 | def get_model_name(FLAGS): 26 | # dataset 27 | name = '%s/%s' % (FLAGS.exp, FLAGS.dataset) 28 | if FLAGS.dataset.startswith('oi'): 29 | name += '_%s%g' % (FLAGS.filter_model, FLAGS.filter_trunc,) 30 | 31 | if FLAGS.know_pose == 1: 32 | name += '_pose' 33 | if FLAGS.know_mean == 1: 34 | name += '_3d%d' % FLAGS.vox_loss 35 | 36 | # model 37 | name += '_%s' % (FLAGS.batch_size) 38 | name += '_%s' % (FLAGS.g_mod) 39 | name += '_%s' % (FLAGS.vol_render) 40 | 41 | # loss 42 | name += '_%s%dm%dc%d' % (FLAGS.mask_loss_type, FLAGS.d_loss_rgb, FLAGS.cyc_mask_loss, FLAGS.content_loss) 43 | name += '_%s' % (FLAGS.sample_view) 44 | name += '%d' % FLAGS.seed 45 | 46 | if FLAGS.prior_thin > 0: 47 | name += 'th%g' % FLAGS.prior_thin 48 | if FLAGS.prior_blob > 0: 49 | name += 'bl%g' % FLAGS.prior_blob 50 | if FLAGS.prior_same > 0: 51 | name += 'sa%g' % FLAGS.prior_same 52 | return name 53 | 54 | 55 | def load_my_state_dict(model: torch.nn.Module, state_dict): 56 | own_state = model.state_dict() 57 | for name, param in state_dict.items(): 58 | if name not in own_state: 59 | # print('Not found in checkpoint', name) 60 | continue 61 | if isinstance(param, torch.nn.Parameter): 62 | # backwards compatibility for serialized parameters 63 | param = param.data 64 | if param.size() != own_state[name].size(): 65 | # print('size not match', name, param.size(), own_state[name].size()) 66 | continue 67 | own_state[name].copy_(param) 68 | 69 | -------------------------------------------------------------------------------- /prepare_data/lib/auto_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class PointCloudEncoder(nn.Module): 7 | def __init__(self, emb_dim): 8 | super(PointCloudEncoder, self).__init__() 9 | self.conv1 = nn.Conv1d(3, 64, 1) 10 | self.conv2 = nn.Conv1d(64, 128, 1) 11 | self.conv3 = nn.Conv1d(256, 256, 1) 12 | self.conv4 = nn.Conv1d(256, 1024, 1) 13 | self.fc = nn.Linear(1024, emb_dim) 14 | 15 | def forward(self, xyz): 16 | """ 17 | Args: 18 | xyz: (B, 3, N) 19 | 20 | """ 21 | np = xyz.size()[2] 22 | x = F.relu(self.conv1(xyz)) 23 | x = F.relu(self.conv2(x)) 24 | global_feat = F.adaptive_max_pool1d(x, 1) 25 | x = torch.cat((x, global_feat.repeat(1, 1, np)), dim=1) 26 | x = F.relu(self.conv3(x)) 27 | x = F.relu(self.conv4(x)) 28 | x = torch.squeeze(F.adaptive_max_pool1d(x, 1), dim=2) 29 | embedding = self.fc(x) 30 | return embedding 31 | 32 | 33 | 
class PointCloudDecoder(nn.Module): 34 | def __init__(self, emb_dim, n_pts): 35 | super(PointCloudDecoder, self).__init__() 36 | self.fc1 = nn.Linear(emb_dim, 512) 37 | self.fc2 = nn.Linear(512, 1024) 38 | self.fc3 = nn.Linear(1024, 3*n_pts) 39 | 40 | def forward(self, embedding): 41 | """ 42 | Args: 43 | embedding: (B, 512) 44 | 45 | """ 46 | bs = embedding.size()[0] 47 | out = F.relu(self.fc1(embedding)) 48 | out = F.relu(self.fc2(out)) 49 | out = self.fc3(out) 50 | out_pc = out.view(bs, -1, 3) 51 | return out_pc 52 | 53 | 54 | class PointCloudAE(nn.Module): 55 | def __init__(self, emb_dim=512, n_pts=1024): 56 | super(PointCloudAE, self).__init__() 57 | self.encoder = PointCloudEncoder(emb_dim) 58 | self.decoder = PointCloudDecoder(emb_dim, n_pts) 59 | 60 | def forward(self, in_pc, emb=None): 61 | """ 62 | Args: 63 | in_pc: (B, N, 3) 64 | emb: (B, 512) 65 | 66 | Returns: 67 | emb: (B, emb_dim) 68 | out_pc: (B, n_pts, 3) 69 | 70 | """ 71 | if emb is None: 72 | xyz = in_pc.permute(0, 2, 1) 73 | emb = self.encoder(xyz) 74 | out_pc = self.decoder(emb) 75 | return emb, out_pc 76 | -------------------------------------------------------------------------------- /tools/pyTorchChamferDistance/chamfer_distance.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from torch.utils.cpp_extension import load 5 | import platform 6 | 7 | path = 'your own path' 8 | cd = load(name="cd", 9 | sources=[path+ "pyTorchChamferDistance/chamfer_distance.cpp", 10 | path + "pyTorchChamferDistance/chamfer_distance.cu"]) 11 | 12 | class ChamferDistanceFunction(torch.autograd.Function): 13 | @staticmethod 14 | def forward(ctx, xyz1, xyz2): 15 | batchsize, n, _ = xyz1.size() 16 | _, m, _ = xyz2.size() 17 | xyz1 = xyz1.contiguous() 18 | xyz2 = xyz2.contiguous() 19 | dist1 = torch.zeros(batchsize, n) 20 | dist2 = torch.zeros(batchsize, m) 21 | 22 | idx1 = torch.zeros(batchsize, n, dtype=torch.int) 23 | idx2 = torch.zeros(batchsize, m, dtype=torch.int) 24 | 25 | if not xyz1.is_cuda: 26 | cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 27 | else: 28 | dist1 = dist1.cuda() 29 | dist2 = dist2.cuda() 30 | idx1 = idx1.cuda() 31 | idx2 = idx2.cuda() 32 | cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2) 33 | 34 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 35 | 36 | return dist1, dist2 37 | 38 | @staticmethod 39 | def backward(ctx, graddist1, graddist2): 40 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 41 | 42 | graddist1 = graddist1.contiguous() 43 | graddist2 = graddist2.contiguous() 44 | 45 | gradxyz1 = torch.zeros(xyz1.size()) 46 | gradxyz2 = torch.zeros(xyz2.size()) 47 | 48 | if not graddist1.is_cuda: 49 | cd.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) 50 | else: 51 | gradxyz1 = gradxyz1.cuda() 52 | gradxyz2 = gradxyz2.cuda() 53 | cd.backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) 54 | 55 | return gradxyz1, gradxyz2 56 | 57 | 58 | class ChamferDistance(torch.nn.Module): 59 | def forward(self, xyz1, xyz2): 60 | return ChamferDistanceFunction.apply(xyz1, xyz2) 61 | 62 | 63 | if __name__ == '__main__': 64 | 65 | 66 | chamfer_dist = ChamferDistance() 67 | a = torch.randn(1, 100, 3) 68 | b = torch.randn(1, 50, 5) 69 | dist1, dist2 = chamfer_dist(a, b) 70 | loss = (torch.mean(dist1)) + (torch.mean(dist2)) 71 | print(loss) 72 | -------------------------------------------------------------------------------- /network/backbone_repo/ATSA/depth_attention_module.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from mmcv.cnn import constant_init, kaiming_init 4 | 5 | 6 | def last_zero_init(m): 7 | if isinstance(m, nn.Sequential): 8 | constant_init(m[-1], val=0) 9 | m[-1].inited = True 10 | else: 11 | constant_init(m, val=0) 12 | m.inited = True 13 | 14 | 15 | class DAM(nn.Module): 16 | 17 | def __init__(self, inplanes, planes): 18 | super(DAM, self).__init__() 19 | 20 | self.inplanes = inplanes 21 | self.planes = planes 22 | self.conv_mask = nn.Conv2d(self.inplanes, 1, kernel_size=(1, 1), stride=(1, 1), bias=False) 23 | self.softmax = nn.Softmax(dim=2) 24 | self.softmax_channel = nn.Softmax(dim=1) 25 | self.channel_mul_conv = nn.Sequential( 26 | nn.Conv2d(self.inplanes, self.planes, kernel_size=(1, 1), stride=(1, 1), bias=False), 27 | nn.LayerNorm([self.planes, 1, 1]), 28 | nn.ReLU(inplace=True), 29 | nn.Conv2d(self.planes, self.inplanes, kernel_size=(1, 1), stride=(1, 1), bias=False), 30 | ) 31 | self.reset_parameters() 32 | 33 | def reset_parameters(self): 34 | kaiming_init(self.conv_mask, mode='fan_in') 35 | self.conv_mask.inited = True 36 | last_zero_init(self.channel_mul_conv) 37 | 38 | 39 | def spatial_pool(self, depth_feature): 40 | batch, channel, height, width = depth_feature.size() 41 | input_x = depth_feature 42 | # [N, C, H * W] 43 | input_x = input_x.view(batch, channel, height * width) 44 | # [N, 1, C, H * W] 45 | input_x = input_x.unsqueeze(1) 46 | # [N, 1, H, W] 47 | context_mask = self.conv_mask(depth_feature) 48 | # [N, 1, H * W] 49 | context_mask = context_mask.view(batch, 1, height * width) 50 | # [N, 1, H * W] 51 | context_mask = self.softmax(context_mask) 52 | # [N, 1, H * W, 1] 53 | context_mask = context_mask.unsqueeze(3) 54 | # [N, 1, C, 1] 55 | # context attention 56 | context = torch.matmul(input_x, context_mask) 57 | # [N, C, 1, 1] 58 | context = context.view(batch, channel, 1, 1) 59 | 60 | return context 61 | 62 | def forward(self, x, depth_feature): 63 | # [N, C, 1, 1] 64 | context = self.spatial_pool(depth_feature) 65 | # [N, C, 1, 1] 66 | channel_mul_term = torch.sigmoid(self.channel_mul_conv(context)) 67 | # channel-wise attention 68 | out1 = torch.sigmoid(depth_feature * channel_mul_term) 69 | # fusion 70 | out = x * out1 71 | 72 | return torch.sigmoid(out) -------------------------------------------------------------------------------- /tools/training_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import absl.flags as flags 4 | 5 | FLAGS = flags.FLAGS 6 | from mmcv import Config 7 | from tools.solver_utils import build_lr_scheduler, build_optimizer_with_params 8 | 9 | 10 | # important parameters used here 11 | # total_iters: total_epoch x iteration per epoch 12 | 13 | def build_lr_rate(optimizer, total_iters): 14 | # build cfg from flags 15 | cfg = dict( 16 | SOLVER=dict( 17 | IMS_PER_BATCH=FLAGS.batch_size, 18 | TOTAL_EPOCHS=FLAGS.total_epoch, 19 | LR_SCHEDULER_NAME=FLAGS.lr_scheduler_name, 20 | REL_STEPS=(0.5, 0.75), 21 | ANNEAL_METHOD=FLAGS.anneal_method, # "cosine" 22 | ANNEAL_POINT=FLAGS.anneal_point, 23 | # REL_STEPS=(0.3125, 0.625, 0.9375), 24 | OPTIMIZER_CFG=dict(type=FLAGS.optimizer_type, lr=FLAGS.lr, weight_decay=0), 25 | WEIGHT_DECAY=FLAGS.weight_decay, 26 | WARMUP_FACTOR=FLAGS.warmup_factor, 27 | WARMUP_ITERS=FLAGS.warmup_iters, 28 | WARMUP_METHOD=FLAGS.warmup_method, 29 | GAMMA=FLAGS.gamma, 30 | POLY_POWER=FLAGS.poly_power, 31 | ), 32 | ) 
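    # The plain dict is wrapped in an mmcv Config below so the downstream solver utilities
    # can read the SOLVER.* fields as attributes; total_iters is expected to be
    # total_epoch x iterations-per-epoch (see the note at the top of this file).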
33 | cfg = Config(cfg) 34 | scheduler = build_lr_scheduler(cfg, optimizer, total_iters=total_iters) 35 | return scheduler 36 | 37 | 38 | def build_optimizer(params): 39 | # build cfg from flags 40 | cfg = dict( 41 | SOLVER=dict( 42 | IMS_PER_BATCH=FLAGS.batch_size, 43 | TOTAL_EPOCHS=FLAGS.total_epoch, 44 | LR_SCHEDULER_NAME=FLAGS.lr_scheduler_name, 45 | ANNEAL_METHOD=FLAGS.anneal_method, # "cosine" 46 | ANNEAL_POINT=FLAGS.anneal_point, 47 | # REL_STEPS=(0.3125, 0.625, 0.9375), 48 | OPTIMIZER_CFG=dict(type=FLAGS.optimizer_type, lr=FLAGS.lr, weight_decay=0), 49 | WEIGHT_DECAY=FLAGS.weight_decay, 50 | WARMUP_FACTOR=FLAGS.warmup_factor, 51 | WARMUP_ITERS=FLAGS.warmup_iters, 52 | ), 53 | ) 54 | cfg = Config(cfg) 55 | optimizer = build_optimizer_with_params(cfg, params) 56 | return optimizer 57 | 58 | 59 | def get_gt_v(Rs, axis=2): 60 | bs = Rs.shape[0] # bs x 3 x 3 61 | # TODO use 3 axis, the order remains: do we need to change order? 62 | if axis == 3: 63 | corners = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float).to(Rs.device) 64 | corners = corners.view(1, 3, 3).repeat(bs, 1, 1) # bs x 3 x 3 65 | gt_vec = torch.bmm(Rs, corners).transpose(2, 1).reshape(bs, -1) 66 | else: 67 | assert axis == 2 68 | corners = torch.tensor([[0, 0, 1], [0, 1, 0], [0, 0, 0]], dtype=torch.float).to(Rs.device) 69 | corners = corners.view(1, 3, 3).repeat(bs, 1, 1) # bs x 3 x 3 70 | gt_vec = torch.bmm(Rs, corners).transpose(2, 1).reshape(bs, -1) 71 | gt_green = gt_vec[:, 3:6] 72 | gt_red = gt_vec[:, (6, 7, 8)] 73 | return gt_green, gt_red 74 | 75 | 76 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 77 | assert isinstance(input, torch.Tensor) 78 | if posinf is None: 79 | posinf = torch.finfo(input.dtype).max 80 | if neginf is None: 81 | neginf = torch.finfo(input.dtype).min 82 | assert nan == 0 83 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 84 | -------------------------------------------------------------------------------- /tools/torch_utils/solver/sgdp.py: -------------------------------------------------------------------------------- 1 | """AdamP Copyright (c) 2020-present NAVER Corp. 
2 | 3 | MIT license 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.optim.optimizer import Optimizer, required 10 | import math 11 | 12 | 13 | class SGDP(Optimizer): 14 | def __init__( 15 | self, 16 | params, 17 | lr=required, 18 | momentum=0, 19 | dampening=0, 20 | weight_decay=0, 21 | nesterov=False, 22 | eps=1e-8, 23 | delta=0.1, 24 | wd_ratio=0.1, 25 | ): 26 | defaults = dict( 27 | lr=lr, 28 | momentum=momentum, 29 | dampening=dampening, 30 | weight_decay=weight_decay, 31 | nesterov=nesterov, 32 | eps=eps, 33 | delta=delta, 34 | wd_ratio=wd_ratio, 35 | ) 36 | super(SGDP, self).__init__(params, defaults) 37 | 38 | def _channel_view(self, x): 39 | return x.view(x.size(0), -1) 40 | 41 | def _layer_view(self, x): 42 | return x.view(1, -1) 43 | 44 | def _cosine_similarity(self, x, y, eps, view_func): 45 | x = view_func(x) 46 | y = view_func(y) 47 | 48 | return F.cosine_similarity(x, y, dim=1, eps=eps).abs_() 49 | 50 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 51 | wd = 1 52 | expand_size = [-1] + [1] * (len(p.shape) - 1) 53 | for view_func in [self._channel_view, self._layer_view]: 54 | 55 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 56 | 57 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 58 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 59 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 60 | wd = wd_ratio 61 | 62 | return perturb, wd 63 | 64 | return perturb, wd 65 | 66 | def step(self, closure=None): 67 | loss = None 68 | if closure is not None: 69 | loss = closure() 70 | 71 | for group in self.param_groups: 72 | momentum = group["momentum"] 73 | dampening = group["dampening"] 74 | nesterov = group["nesterov"] 75 | 76 | for p in group["params"]: 77 | if p.grad is None: 78 | continue 79 | grad = p.grad.data 80 | state = self.state[p] 81 | 82 | # State initialization 83 | if len(state) == 0: 84 | state["momentum"] = torch.zeros_like(p.data) 85 | 86 | # SGD 87 | buf = state["momentum"] 88 | buf.mul_(momentum).add_(grad, alpha=1 - dampening) 89 | if nesterov: 90 | d_p = grad + momentum * buf 91 | else: 92 | d_p = buf 93 | 94 | # Projection 95 | wd_ratio = 1 96 | if len(p.shape) > 1: 97 | d_p, wd_ratio = self._projection(p, grad, d_p, group["delta"], group["wd_ratio"], group["eps"]) 98 | 99 | # Weight decay 100 | if group["weight_decay"] > 0: 101 | p.data.mul_(1 - group["lr"] * group["weight_decay"] * wd_ratio / (1 - momentum)) 102 | 103 | # Step 104 | p.data.add_(d_p, alpha=-group["lr"]) 105 | 106 | return loss 107 | -------------------------------------------------------------------------------- /tools/perspective3d.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import absl.flags as flags 8 | from absl import app 9 | import tools.geom_utils as g 10 | import tools.image_utils as i_util 11 | 12 | FLAGS = flags.FLAGS 13 | 14 | class Perspective3d(nn.Module): 15 | def __init__(self, camK, z_range=1, R=None, t=None, s=None, det=None): 16 | super(Perspective3d, self).__init__() 17 | self.z_range = 1 18 | self.fx = camK[0, 0] 19 | self.fy = camK[1, 1] 20 | self.ux = camK[0, 2] 21 | self.uy = camK[1, 2] 22 | self.rot = R # from camera center to current object position 23 | self.tran = t 24 | self.s = s # scale 25 | self.det = det # compensation of the segmentation bs, 2 26 
| 27 | # in this function, the scale is set to 1 by default 28 | def ray2grid(self, xy_sample, z_sample, bs, device): 29 | height = width = xy_sample 30 | 31 | x_t = torch.linspace(0, width, width, dtype=torch.float32, device=device) # image space 32 | y_t = torch.linspace(0, height, height, dtype=torch.float32, device=device) # image space 33 | z_t = torch.linspace(-self.z_range / 2, self.z_range / 2, z_sample, dtype=torch.float32, 34 | device=device) # depth step 35 | 36 | z_t, y_t, x_t = torch.meshgrid(z_t, y_t, x_t) # [D, W, H] # cmt: this must be in ZYX order 37 | 38 | x_t = x_t.unsqueeze(0).repeat(bs, 1, 1, 1) 39 | y_t = y_t.unsqueeze(0).repeat(bs, 1, 1, 1) 40 | z_t = z_t.unsqueeze(0).repeat(bs, 1, 1, 1) 41 | 42 | Z_t = z_t + self.tran[..., -1] 43 | 44 | X_t = (x_t - self.ux + self.det[..., 0]) * Z_t / self.fx 45 | Y_t = (y_t - self.uy + self.det[..., 1]) * Z_t / self.fy 46 | 47 | ones = torch.ones_like(X_t) 48 | grid = torch.stack([X_t, Y_t, Z_t, ones], dim=-1) 49 | 50 | return grid 51 | 52 | def camera2world(self): 53 | rot_T = (g.homo_to_3x3(self.rot)).permute(0, 2, 1) 54 | rt_inv = g.rt_to_homo(rot_T, -torch.matmul(rot_T, self.tran.unsqueeze(-1))) 55 | scale_inv = g.diag_to_homo(1 / self.s) 56 | wTc = torch.matmul(scale_inv, rt_inv) 57 | return wTc 58 | 59 | 60 | def forward(self, voxels, xy_sample, z_sample): 61 | bs = voxels.shape[0] 62 | wTc = self.camera2world() 63 | cGrid = self.ray2grid(xy_sample, z_sample, bs, device=voxels.device) 64 | wGrid = torch.matmul(cGrid.view(bs, -1, 4), wTc.transpose(1, 2)).view(bs, z_sample, xy_sample, xy_sample, 4) 65 | wGrid = 2 * wGrid[..., 0:3] / wGrid[..., 3:4] # scale from [0.5, 0.5] to [-1, 1] 66 | voxels = F.grid_sample(voxels, wGrid, align_corners=True) 67 | return voxels 68 | 69 | 70 | def main(_): 71 | device = 'cuda' 72 | 73 | H = W = D = 16 74 | N = 1 75 | vox = torch.zeros([N, 1, D, H, W], device=device) 76 | vox[..., 0:D // 2, 0:H//2, 0:W//2] = 1 77 | 78 | for i in range(-30, 30, 10): 79 | param = torch.FloatTensor([[0, i / 180 * 3.14, 1, 0, 0, 2]]).to(device) 80 | scale, tran, rot = g.azel2uni(param) 81 | f = 375 82 | camK = torch.tensor([[f, 0.0, 128], [0.0, f, 128], [0, 0, 1]], dtype=torch.float32).to(device) 83 | det = torch.tensor([[100, 100]], dtype=torch.float32).to(device) 84 | det = det.view(1, 2) 85 | IH = 32 86 | layer = Perspective3d(camK=camK, z_range=1, R=rot, t=tran, s=scale, det=det).to(device) 87 | trans_vox = layer(vox, IH, IH) 88 | mask = torch.mean(trans_vox, dim=2) # (N, 1, H, W) 89 | 90 | save_dir = 'outputs' 91 | i_util.save_images(mask, os.path.join(save_dir, 'test_%d' % i)) 92 | 93 | if __name__ == "__main__": 94 | app.run(main) 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /prepare_data/renderer.py: -------------------------------------------------------------------------------- 1 | # Author: Tomas Hodan (hodantom@cmp.felk.cvut.cz) 2 | # Center for Machine Perception, Czech Technical University in Prague 3 | 4 | """Abstract class of a renderer and a factory function to create a renderer. 5 | 6 | The renderer produces an RGB/depth image of a 3D mesh model in a specified pose 7 | for given camera parameters and illumination settings. 8 | """ 9 | 10 | 11 | class Renderer(object): 12 | """Abstract class of a renderer.""" 13 | 14 | def __init__(self, width, height): 15 | """Constructor. 16 | 17 | :param width: Width of the rendered image. 18 | :param height: Height of the rendered image. 
19 | """ 20 | self.width = width 21 | self.height = height 22 | 23 | # 3D location of a point light (in the camera coordinates). 24 | self.light_cam_pos = (0, 0, 0) 25 | 26 | # Set light color and weights. 27 | self.light_color = (1.0, 1.0, 1.0) # Used only in C++ renderer. 28 | self.light_ambient_weight = 0.5 29 | self.light_diffuse_weight = 1.0 # Used only in C++ renderer. 30 | self.light_specular_weight = 0.0 # Used only in C++ renderer. 31 | self.light_specular_shininess = 0.0 # Used only in C++ renderer. 32 | 33 | def set_light_cam_pos(self, light_cam_pos): 34 | """Sets the 3D location of a point light. 35 | 36 | :param light_cam_pos: [X, Y, Z]. 37 | """ 38 | self.light_cam_pos = light_cam_pos 39 | 40 | def set_light_ambient_weight(self, light_ambient_weight): 41 | """Sets weight of the ambient light. 42 | 43 | :param light_ambient_weight: Scalar from 0 to 1. 44 | """ 45 | self.light_ambient_weight = light_ambient_weight 46 | 47 | def add_object(self, obj_id, model_path, **kwargs): 48 | """Loads an object model. 49 | 50 | :param obj_id: Object identifier. 51 | :param model_path: Path to the object model file. 52 | """ 53 | raise NotImplementedError 54 | 55 | def remove_object(self, obj_id): 56 | """Removes an object model. 57 | 58 | :param obj_id: Identifier of the object to remove. 59 | """ 60 | raise NotImplementedError 61 | 62 | def render_object(self, obj_id, R, t, fx, fy, cx, cy): 63 | """Renders an object model in the specified pose. 64 | 65 | :param obj_id: Object identifier. 66 | :param R: 3x3 ndarray with a rotation matrix. 67 | :param t: 3x1 ndarray with a translation vector. 68 | :param fx: Focal length (X axis). 69 | :param fy: Focal length (Y axis). 70 | :param cx: The X coordinate of the principal point. 71 | :param cy: The Y coordinate of the principal point. 72 | :return: Returns a dictionary with rendered images. 73 | """ 74 | raise NotImplementedError 75 | 76 | 77 | def create_renderer(width, height, renderer_type='cpp', mode='rgb+depth', 78 | shading='phong', bg_color=(0.0, 0.0, 0.0, 0.0)): 79 | """A factory to create a renderer. 80 | 81 | Note: Parameters mode, shading and bg_color are currently supported only by 82 | the Python renderer (renderer_type='python'). 83 | 84 | :param width: Width of the rendered image. 85 | :param height: Height of the rendered image. 86 | :param renderer_type: Type of renderer (options: 'cpp', 'python'). 87 | :param mode: Rendering mode ('rgb+depth', 'rgb', 'depth'). 88 | :param shading: Type of shading ('flat', 'phong'). 89 | :param bg_color: Color of the background (R, G, B, A). 90 | :return: Instance of a renderer of the specified type. 91 | """ 92 | if renderer_type == 'python': 93 | from . import renderer_py 94 | return renderer_py.RendererPython(width, height, mode, shading, bg_color) 95 | 96 | elif renderer_type == 'cpp': 97 | from . 
import renderer_cpp 98 | return renderer_cpp.RendererCpp(width, height) 99 | 100 | else: 101 | raise ValueError('Unknown renderer type.') 102 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | addict==2.4.0 3 | albumentations==1.0.3 4 | antlr4-python3-runtime==4.8 5 | anyio==3.3.2 6 | appdirs==1.4.4 7 | argon2-cffi==21.1.0 8 | astunparse==1.6.3 9 | attrs==21.2.0 10 | Babel==2.9.1 11 | backcall==0.2.0 12 | black==21.4b2 13 | bleach==1.5.0 14 | blessings==1.7 15 | cachetools==4.2.2 16 | certifi==2021.10.8 17 | cffi==1.14.6 18 | charset-normalizer==2.0.4 19 | click==8.0.1 20 | cloudpickle==2.0.0 21 | colorama==0.4.4 22 | cycler==0.10.0 23 | Cython==0.29.24 24 | debugpy==1.4.3 25 | decorator==5.1.0 26 | defusedxml==0.7.1 27 | deprecation==2.1.0 28 | detectron2==0.6+cu113 29 | entrypoints==0.3 30 | flatbuffers==2.0 31 | future==0.18.2 32 | fvcore==0.1.5.post20211023 33 | gast==0.3.3 34 | google-auth==1.35.0 35 | google-auth-oauthlib==0.4.6 36 | google-pasta==0.2.0 37 | gpustat==0.6.0 38 | grpcio==1.40.0 39 | GuideConv==0.0.0 40 | h5py==2.10.0 41 | html5lib==0.9999999 42 | hydra-core==1.1.1 43 | idna==3.2 44 | imageio==2.9.0 45 | imgaug==0.4.0 46 | importlib-resources==5.2.2 47 | iopath==0.1.8 48 | ipykernel==6.4.1 49 | ipython==7.28.0 50 | ipython-genutils==0.2.0 51 | ipywidgets==7.6.5 52 | jedi==0.18.0 53 | Jinja2==3.0.1 54 | joblib==1.0.1 55 | json5==0.9.6 56 | jsonschema==3.2.0 57 | jupyter-client==7.0.5 58 | jupyter-core==4.8.1 59 | jupyter-packaging==0.10.6 60 | jupyter-server==1.11.0 61 | jupyterlab==3.1.14 62 | jupyterlab-pygments==0.1.2 63 | jupyterlab-server==2.8.2 64 | jupyterlab-widgets==1.0.2 65 | keras==2.7.0 66 | Keras-Preprocessing==1.1.2 67 | kiwisolver==1.3.2 68 | libclang==12.0.0 69 | Markdown==3.3.4 70 | MarkupSafe==2.0.1 71 | matplotlib==3.4.3 72 | matplotlib-inline==0.1.3 73 | mistune==0.8.4 74 | mkl-fft==1.3.0 75 | mkl-random==1.2.2 76 | mkl-service==2.4.0 77 | mmcv-full==1.3.12 78 | mypy-extensions==0.4.3 79 | nbclassic==0.3.2 80 | nbclient==0.5.4 81 | nbconvert==6.2.0 82 | nbformat==5.1.3 83 | nest-asyncio==1.5.1 84 | networkx==2.6.3 85 | nn-distance==0.0.0 86 | nose==1.3.7 87 | notebook==6.4.4 88 | numpy==1.20.3 89 | nvidia-ml-py3==7.352.0 90 | oauthlib==3.1.1 91 | olefile==0.46 92 | omegaconf==2.1.1 93 | opencv-python==4.5.3.56 94 | opencv-python-headless==4.5.3.56 95 | opt-einsum==3.3.0 96 | packaging==21.0 97 | pandas==1.3.3 98 | pandocfilters==1.5.0 99 | parso==0.8.2 100 | pathspec==0.9.0 101 | pexpect==4.8.0 102 | pickleshare==0.7.5 103 | Pillow==8.3.1 104 | pip==21.0.1 105 | plyfile==0.7.4 106 | portalocker==2.3.1 107 | prometheus-client==0.11.0 108 | prompt-toolkit==3.0.20 109 | protobuf==3.17.3 110 | psutil==5.8.0 111 | ptyprocess==0.7.0 112 | pyasn1==0.4.8 113 | pyasn1-modules==0.2.8 114 | pycocotools==2.0.2 115 | pycparser==2.20 116 | pydot==1.4.2 117 | Pygments==2.10.0 118 | pyparsing==2.4.7 119 | pyrsistent==0.18.0 120 | python-dateutil==2.8.2 121 | pytz==2021.1 122 | PyWavelets==1.1.1 123 | PyYAML==5.4.1 124 | pyzmq==22.3.0 125 | regex==2021.8.28 126 | requests==2.26.0 127 | requests-oauthlib==1.3.0 128 | requests-unixsocket==0.2.0 129 | rsa==4.7.2 130 | scikit-image==0.18.3 131 | scikit-learn==1.0 132 | scipy==1.4.1 133 | Send2Trash==1.8.0 134 | setuptools==52.0.0.post20210125 135 | Shapely==1.7.1 136 | six==1.16.0 137 | sniffio==1.2.0 138 | tabulate==0.8.9 139 | tensorboard==2.7.0 140 | 
tensorboard-data-server==0.6.1 141 | tensorboard-plugin-wit==1.8.0 142 | tensorflow-cpu==2.7.0 143 | tensorflow-estimator==2.7.0 144 | tensorflow-io-gcs-filesystem==0.23.0 145 | termcolor==1.1.0 146 | terminado==0.12.1 147 | testpath==0.5.0 148 | threadpoolctl==2.2.0 149 | tifffile==2021.8.30 150 | toml==0.10.2 151 | tomlkit==0.7.2 152 | torch==1.10.1 153 | torch-encoding==1.2.1 154 | torchvision==0.11.2 155 | tornado==6.1 156 | tqdm==4.62.2 157 | traitlets==5.1.0 158 | typing-extensions==3.10.0.0 159 | urllib3==1.26.6 160 | wcwidth==0.2.5 161 | websocket-client==1.2.1 162 | Werkzeug==2.0.1 163 | wheel==0.37.0 164 | widgetsnbextension==3.5.1 165 | wrapt==1.12.1 166 | yacs==0.1.8 167 | yapf==0.31.0 168 | zipp==3.5.0 169 | -------------------------------------------------------------------------------- /tools/torch_utils/solver/adamp.py: -------------------------------------------------------------------------------- 1 | """AdamP Copyright (c) 2020-present NAVER Corp. 2 | 3 | MIT license 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.optim.optimizer import Optimizer, required 10 | import math 11 | 12 | 13 | class AdamP(Optimizer): 14 | def __init__( 15 | self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False 16 | ): 17 | defaults = dict( 18 | lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, delta=delta, wd_ratio=wd_ratio, nesterov=nesterov 19 | ) 20 | super(AdamP, self).__init__(params, defaults) 21 | 22 | def _channel_view(self, x): 23 | return x.view(x.size(0), -1) 24 | 25 | def _layer_view(self, x): 26 | return x.view(1, -1) 27 | 28 | def _cosine_similarity(self, x, y, eps, view_func): 29 | x = view_func(x) 30 | y = view_func(y) 31 | 32 | return F.cosine_similarity(x, y, dim=1, eps=eps).abs_() 33 | 34 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 35 | wd = 1 36 | expand_size = [-1] + [1] * (len(p.shape) - 1) 37 | for view_func in [self._channel_view, self._layer_view]: 38 | 39 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 40 | 41 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 42 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 43 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 44 | wd = wd_ratio 45 | 46 | return perturb, wd 47 | 48 | return perturb, wd 49 | 50 | def step(self, closure=None): 51 | loss = None 52 | if closure is not None: 53 | loss = closure() 54 | 55 | for group in self.param_groups: 56 | for p in group["params"]: 57 | if p.grad is None: 58 | continue 59 | 60 | grad = p.grad.data 61 | beta1, beta2 = group["betas"] 62 | nesterov = group["nesterov"] 63 | 64 | state = self.state[p] 65 | 66 | # State initialization 67 | if len(state) == 0: 68 | state["step"] = 0 69 | state["exp_avg"] = torch.zeros_like(p.data) 70 | state["exp_avg_sq"] = torch.zeros_like(p.data) 71 | 72 | # Adam 73 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 74 | 75 | state["step"] += 1 76 | bias_correction1 = 1 - beta1 ** state["step"] 77 | bias_correction2 = 1 - beta2 ** state["step"] 78 | 79 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 80 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 81 | 82 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"]) 83 | step_size = group["lr"] / bias_correction1 84 | 85 | if nesterov: 86 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom 87 | else: 88 | perturb = exp_avg / 
denom 89 | 90 | # Projection 91 | wd_ratio = 1 92 | if len(p.shape) > 1: 93 | perturb, wd_ratio = self._projection( 94 | p, grad, perturb, group["delta"], group["wd_ratio"], group["eps"] 95 | ) 96 | 97 | # Weight decay 98 | if group["weight_decay"] > 0: 99 | p.data.mul_(1 - group["lr"] * group["weight_decay"] * wd_ratio) 100 | 101 | # Step 102 | p.data.add_(perturb, alpha=-step_size) 103 | 104 | return loss 105 | -------------------------------------------------------------------------------- /tools/rot_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | import torch.nn.functional as F 6 | import absl.flags as flags 7 | FLAGS = flags.FLAGS 8 | def get_vertical_rot_vec(c1, c2, y, z): 9 | ## c1, c2 are weights 10 | ## y, x are rotation vectors 11 | y = y.view(-1) 12 | z = z.view(-1) 13 | rot_x = torch.cross(y, z) 14 | rot_x = rot_x / (torch.norm(rot_x) + 1e-8) 15 | # cal angle between y and z 16 | y_z_cos = torch.sum(y * z) 17 | y_z_theta = torch.acos(y_z_cos) 18 | theta_2 = c1 / (c1 + c2) * (y_z_theta - math.pi / 2) 19 | theta_1 = c2 / (c1 + c2) * (y_z_theta - math.pi / 2) 20 | # first rotate y 21 | c = torch.cos(theta_1) 22 | s = torch.sin(theta_1) 23 | rotmat_y = torch.tensor([[rot_x[0]*rot_x[0]*(1-c)+c, rot_x[0]*rot_x[1]*(1-c)-rot_x[2]*s, rot_x[0]*rot_x[2]*(1-c)+rot_x[1]*s], 24 | [rot_x[1]*rot_x[0]*(1-c)+rot_x[2]*s, rot_x[1]*rot_x[1]*(1-c)+c, rot_x[1]*rot_x[2]*(1-c)-rot_x[0]*s], 25 | [rot_x[0]*rot_x[2]*(1-c)-rot_x[1]*s, rot_x[2]*rot_x[1]*(1-c)+rot_x[0]*s, rot_x[2]*rot_x[2]*(1-c)+c]]).to(y.device) 26 | new_y = torch.mm(rotmat_y, y.view(-1, 1)) 27 | # then rotate z 28 | c = torch.cos(-theta_2) 29 | s = torch.sin(-theta_2) 30 | rotmat_z = torch.tensor([[rot_x[0] * rot_x[0] * (1 - c) + c, rot_x[0] * rot_x[1] * (1 - c) - rot_x[2] * s, 31 | rot_x[0] * rot_x[2] * (1 - c) + rot_x[1] * s], 32 | [rot_x[1] * rot_x[0] * (1 - c) + rot_x[2] * s, rot_x[1] * rot_x[1] * (1 - c) + c, 33 | rot_x[1] * rot_x[2] * (1 - c) - rot_x[0] * s], 34 | [rot_x[0] * rot_x[2] * (1 - c) - rot_x[1] * s, 35 | rot_x[2] * rot_x[1] * (1 - c) + rot_x[0] * s, rot_x[2] * rot_x[2] * (1 - c) + c]]).to( 36 | z.device) 37 | 38 | new_z = torch.mm(rotmat_z, z.view(-1, 1)) 39 | return new_y.view(-1), new_z.view(-1) 40 | 41 | def get_rot_mat_y_first(y, x): 42 | # poses 43 | 44 | y = F.normalize(y, p=2, dim=-1) # bx3 45 | z = torch.cross(x, y, dim=-1) # bx3 46 | z = F.normalize(z, p=2, dim=-1) # bx3 47 | x = torch.cross(y, z, dim=-1) # bx3 48 | 49 | # (*,3)x3 --> (*,3,3) 50 | return torch.stack((x, y, z), dim=-1) # (b,3,3) 51 | 52 | def get_rot_vec_vert_batch(c1, c2, y, z): 53 | bs = c1.shape[0] 54 | new_y = y 55 | new_z = z 56 | for i in range(bs): 57 | new_y[i, ...], new_z[i, ...] = get_vertical_rot_vec(c1[i, ...], c2[i, ...], y[i, ...], z[i, ...]) 58 | return new_y, new_z 59 | 60 | def get_R_batch(f_g_vec, f_r_vec, p_g_vec, p_r_vec, sym): 61 | bs = sym.shape[0] 62 | p_R_batch = torch.zeros((bs,3,3)).to(sym.device) 63 | for i in range(bs): 64 | if sym[i, 0] == 1: 65 | # estimate pred_R 66 | new_y, new_x = get_vertical_rot_vec(f_g_vec[i], 1e-5, p_g_vec[i, ...], p_r_vec[i, ...]) 67 | p_R = get_rot_mat_y_first(new_y.view(1, -1), new_x.view(1, -1))[0] # 3 x 3 68 | else: 69 | # estimate pred_R 70 | new_y, new_x = get_vertical_rot_vec(f_g_vec[i], f_r_vec[i], p_g_vec[i, ...], p_r_vec[i, ...]) 71 | p_R = get_rot_mat_y_first(new_y.view(1, -1), new_x.view(1, -1))[0] # 3 x 3 72 | p_R_batch[i,...] 
= p_R 73 | return p_R_batch 74 | 75 | if __name__ == '__main__': 76 | g_R=torch.tensor([[0.3126, 0.0018, -0.9499], 77 | [0.7303, -0.6400, 0.2391], 78 | [-0.6074, -0.7684, -0.2014]], device='cuda:0') 79 | y = g_R[:, 1] 80 | x = g_R[:, 0] 81 | c1 = 5 82 | c2 = 1 83 | y = y / torch.norm(y) 84 | x = x / torch.norm(x) 85 | L = torch.dot(y, x) 86 | Lp = torch.cross(x, y) 87 | Lp = Lp / torch.norm(Lp) 88 | new_y, nnew_x = get_vertical_rot_vec(c1, c2, y, x) 89 | M = torch.dot(new_y, nnew_x) 90 | Mp = torch.cross(new_y, nnew_x) 91 | Mp = Mp / torch.norm(Mp) 92 | new_R = get_rot_mat_y_first(new_y.view(1, -1), nnew_x.view(1, -1)) 93 | print('OK') -------------------------------------------------------------------------------- /tools/torch_utils/solver/lookahead.py: -------------------------------------------------------------------------------- 1 | """Lookahead Optimizer Wrapper. Implementation modified from: 2 | https://github.com/alphadl/lookahead.pytorch. 3 | 4 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 5 | """ 6 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lookahead.py 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | from collections import defaultdict 10 | 11 | # from lib.utils import logger 12 | 13 | 14 | class Lookahead(Optimizer): 15 | def __init__(self, base_optimizer, alpha=0.5, k=6): 16 | if not 0.0 <= alpha <= 1.0: 17 | raise ValueError(f"Invalid slow update rate: {alpha}") 18 | if not 1 <= k: 19 | raise ValueError(f"Invalid lookahead steps: {k}") 20 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 21 | self.base_optimizer = base_optimizer 22 | self.param_groups = self.base_optimizer.param_groups 23 | self.defaults = base_optimizer.defaults 24 | self.defaults.update(defaults) 25 | self.state = defaultdict(dict) 26 | # manually add our defaults to the param groups 27 | for name, default in defaults.items(): 28 | for group in self.param_groups: 29 | group.setdefault(name, default) 30 | 31 | def update_slow(self, group): 32 | for fast_p in group["params"]: 33 | if fast_p.grad is None: 34 | continue 35 | param_state = self.state[fast_p] 36 | if "slow_buffer" not in param_state: 37 | param_state["slow_buffer"] = torch.empty_like(fast_p.data) 38 | param_state["slow_buffer"].copy_(fast_p.data) 39 | slow = param_state["slow_buffer"] 40 | slow.add_(group["lookahead_alpha"], fast_p.data - slow) 41 | fast_p.data.copy_(slow) 42 | 43 | def sync_lookahead(self): 44 | for group in self.param_groups: 45 | self.update_slow(group) 46 | 47 | def step(self, closure=None): 48 | # assert id(self.param_groups) == id(self.base_optimizer.param_groups) 49 | loss = self.base_optimizer.step(closure) 50 | for group in self.param_groups: 51 | group["lookahead_step"] += 1 52 | if group["lookahead_step"] % group["lookahead_k"] == 0: 53 | self.update_slow(group) 54 | return loss 55 | 56 | def state_dict(self): 57 | fast_state_dict = self.base_optimizer.state_dict() 58 | slow_state = {(id(k) if isinstance(k, torch.Tensor) else k): v for k, v in self.state.items()} 59 | fast_state = fast_state_dict["state"] 60 | param_groups = fast_state_dict["param_groups"] 61 | return {"state": fast_state, "slow_state": slow_state, "param_groups": param_groups} 62 | 63 | def load_state_dict(self, state_dict): 64 | fast_state_dict = {"state": state_dict["state"], "param_groups": state_dict["param_groups"]} 65 | self.base_optimizer.load_state_dict(fast_state_dict) 66 | 67 | # We want to restore the slow state, 
but share param_groups reference 68 | # with base_optimizer. This is a bit redundant but least code 69 | slow_state_new = False 70 | if "slow_state" not in state_dict: 71 | print("Loading state_dict from optimizer without Lookahead applied.") 72 | state_dict["slow_state"] = defaultdict(dict) 73 | slow_state_new = True 74 | slow_state_dict = { 75 | "state": state_dict["slow_state"], 76 | "param_groups": state_dict["param_groups"], # this is pointless but saves code 77 | } 78 | super(Lookahead, self).load_state_dict(slow_state_dict) 79 | self.param_groups = self.base_optimizer.param_groups # make both ref same container 80 | if slow_state_new: 81 | # reapply defaults to catch missing lookahead specific ones 82 | for name, default in self.defaults.items(): 83 | for group in self.param_groups: 84 | group.setdefault(name, default) 85 | -------------------------------------------------------------------------------- /prepare_data/lib/nn_distance/chamfer_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import nn_distance 3 | 4 | 5 | class NnDistanceFunction(torch.autograd.Function): 6 | """ 3D point set to 3D point set distance. 7 | 8 | """ 9 | @staticmethod 10 | def forward(ctx, xyz1, xyz2): 11 | B, N, _ = xyz1.size() 12 | B, M, _ = xyz2.size() 13 | result = torch.empty(B, N, dtype=xyz1.dtype, device=xyz1.device) 14 | result_i = torch.empty(B, N, dtype=torch.int32, device=xyz1.device) 15 | result2 = torch.empty(B, M, dtype=xyz2.dtype, device=xyz2.device) 16 | result2_i = torch.empty(B, M, dtype=torch.int32, device=xyz2.device) 17 | nn_distance.forward(xyz1, xyz2, result, result2, result_i, result2_i) 18 | ctx.save_for_backward(xyz1, xyz2, result_i, result2_i) 19 | ctx.mark_non_differentiable(result_i, result2_i) 20 | return result, result2, result_i, result2_i 21 | 22 | @staticmethod 23 | def backward(ctx, d_dist1, d_dist2, d_i1, d_i2): 24 | B, N = d_dist1.size() 25 | B, M = d_dist2.size() 26 | xyz1, xyz2, idx1, idx2 = ctx.saved_variables 27 | d_xyz1 = torch.zeros_like(xyz1) 28 | d_xyz2 = torch.zeros_like(xyz2) 29 | gradient1, gradient2 = ctx.needs_input_grad 30 | nn_distance.backward(xyz1, xyz2, d_xyz1, d_xyz2, d_dist1, d_dist2, idx1, idx2) 31 | if not gradient1: 32 | return None, d_xyz2 33 | if not gradient2: 34 | return d_xyz1, None 35 | else: 36 | return d_xyz1, d_xyz2 37 | 38 | 39 | class ChamferLoss(torch.nn.Module): 40 | """ Chamfer Loss: bidirectional nearest neighbor distance of two point sets. 
41 | 42 | """ 43 | def __init__(self, threshold=None, backward_weight=1.0): 44 | super(ChamferLoss, self).__init__() 45 | # only consider distance smaller than threshold*mean(distance) (remove outlier) 46 | self.__threshold = threshold 47 | self.backward_weight = backward_weight 48 | 49 | def set_threshold(self, value): 50 | self.__threshold = value 51 | 52 | def unset_threshold(self): 53 | self.__threshold = None 54 | 55 | def forward(self, pred, gt): 56 | assert(pred.dim() == 3 and gt.dim() == 3), \ 57 | "input for ChamferLoss must be a 3D-tensor, but pred.size() is {} gt.size() is {}".format(pred.size(), gt.size()) 58 | # need transpose 59 | if pred.size(2) != 3: 60 | assert(pred.size(1) == 3), "ChamferLoss is implemented for 3D points" 61 | pred = pred.transpose(2, 1).contiguous() 62 | if gt.size(2) != 3: 63 | assert(gt.size(1) == 3), "ChamferLoss is implemented for 3D points" 64 | gt = gt.transpose(2, 1).contiguous() 65 | assert(pred.size(2) == 3 and gt.size(2) == 3), "ChamferLoss is implemented for 3D points" 66 | pred2gt, gt2pred, idx1, idx2 = NnDistanceFunction.apply(pred, gt) 67 | 68 | if self.__threshold is not None: 69 | threshold = self.__threshold 70 | forward_threshold = torch.mean(pred2gt, dim=1, keepdim=True) * threshold 71 | backward_threshold = torch.mean(gt2pred, dim=1, keepdim=True) * threshold 72 | # only care about distance within threshold (ignore strong outliers) 73 | pred2gt = torch.where(pred2gt < forward_threshold, pred2gt, torch.zeros_like(pred2gt)) 74 | gt2pred = torch.where(gt2pred < backward_threshold, gt2pred, torch.zeros_like(gt2pred)) 75 | 76 | pred2gt = torch.mean(pred2gt, dim=1) 77 | gt2pred = torch.mean(gt2pred, dim=1) 78 | cd_dist = pred2gt + self.backward_weight * gt2pred 79 | cd_loss = torch.mean(cd_dist) 80 | return cd_loss, idx1, idx2 81 | 82 | 83 | if __name__ == '__main__': 84 | from torch.autograd import gradcheck 85 | nndistance = NnDistanceFunction.apply 86 | pc1 = torch.randn([2, 60, 3], dtype=torch.float, requires_grad=True).cuda() 87 | pc2 = torch.randn([2, 30, 3], dtype=torch.float, requires_grad=True).cuda() 88 | test = gradcheck(nndistance, (pc1, pc2), eps=1e-3, atol=1e-3) 89 | print(test) 90 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 127 | -------------------------------------------------------------------------------- /tools/solver_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from typing import Any, Dict, List 3 | 4 | import torch 5 | from detectron2.config import CfgNode 6 | from detectron2.solver import WarmupCosineLR, WarmupMultiStepLR 7 | from tools.torch_utils.solver.lr_scheduler import flat_and_anneal_lr_scheduler 8 | from tools.torch_utils.solver.ranger2020 import Ranger 9 | import absl.flags as flags 10 | FLAGS = flags.FLAGS 11 | 12 | __all__ = ["build_lr_scheduler", "build_optimizer_with_params"] 13 | 14 | ''' 15 | def register_optimizer(name): 16 | """TODO: add more optimizers""" 17 | if name in OPTIMIZERS: 18 | return 19 | if name == "Ranger": 20 | from tools.torch_utils.solver.ranger import Ranger 21 | 22 | # from lib.torch_utils.solver.ranger2020 import Ranger 23 | OPTIMIZERS.register_module()(Ranger) 24 | elif name in ["AdaBelief", "RangerAdaBelief"]: 25 | from tools.torch_utils.solver.AdaBelief import AdaBelief 26 | from tools.torch_utils.solver.ranger_adabelief import RangerAdaBelief 27 | 28 | OPTIMIZERS.register_module()(AdaBelief) 29 | OPTIMIZERS.register_module()(RangerAdaBelief) 30 | elif name in ["SGDP", "AdamP"]: 31 | from tools.torch_utils.solver.adamp import AdamP 32 | from tools.torch_utils.solver.sgdp import SGDP 33 | 34 | OPTIMIZERS.register_module()(AdamP) 35 | OPTIMIZERS.register_module()(SGDP) 36 | elif name in ["SGD_GC", "SGD_GCC"]: 37 | from tools.torch_utils.solver.sgd_gc import SGD_GC, SGD_GCC 38 | 39 | OPTIMIZERS.register_module()(SGD_GC) 40 | OPTIMIZERS.register_module()(SGD_GCC) 41 | else: 42 | raise ValueError(f"Unknown optimizer name: {name}") 43 | ''' 44 | # note that this is adapted from mmcv, if you dont want to use ranger, 45 | # please use the provieded build from cfg in mmcv 46 | def build_optimizer_with_params(cfg, params): 47 | if cfg.SOLVER.OPTIMIZER_CFG == "": 48 | raise RuntimeError("please provide cfg.SOLVER.OPTIMIZER_CFG to build optimizer") 49 | if cfg.SOLVER.OPTIMIZER_CFG.type.lower() == "ranger": 50 | return Ranger(params=params, lr=cfg.SOLVER.OPTIMIZER_CFG.lr, weight_decay=cfg.SOLVER.OPTIMIZER_CFG.weight_decay) 51 | else: 52 | return None 53 | 54 | def build_lr_scheduler( 55 | cfg: CfgNode, optimizer: torch.optim.Optimizer, total_iters: int 56 | ) -> torch.optim.lr_scheduler._LRScheduler: 57 | """Build a LR scheduler from config.""" 58 | name = cfg.SOLVER.LR_SCHEDULER_NAME 59 | steps = [rel_step * total_iters for rel_step in cfg.SOLVER.REL_STEPS] 60 | if name == "WarmupMultiStepLR": 61 | return WarmupMultiStepLR( 62 | optimizer, 63 | steps, # cfg.SOLVER.STEPS, 64 | cfg.SOLVER.GAMMA, 65 | warmup_factor=cfg.SOLVER.WARMUP_FACTOR, 66 | warmup_iters=cfg.SOLVER.WARMUP_ITERS, 67 | warmup_method=cfg.SOLVER.WARMUP_METHOD, 68 | ) 69 | elif name == "WarmupCosineLR": 70 | return WarmupCosineLR( 71 | optimizer, 72 | total_iters, # cfg.SOLVER.MAX_ITER, 73 | warmup_factor=cfg.SOLVER.WARMUP_FACTOR, 74 | warmup_iters=cfg.SOLVER.WARMUP_ITERS, 75 | warmup_method=cfg.SOLVER.WARMUP_METHOD, 76 | ) 77 | elif name.lower() == "flat_and_anneal": 78 | return flat_and_anneal_lr_scheduler( 79 | optimizer, 80 | total_iters=total_iters, # NOTE: TOTAL_EPOCHS * len(train_loader) 81 | warmup_factor=cfg.SOLVER.WARMUP_FACTOR, 82 | warmup_iters=cfg.SOLVER.WARMUP_ITERS, 83 | warmup_method=cfg.SOLVER.WARMUP_METHOD, # default "linear" 84 | anneal_method=cfg.SOLVER.ANNEAL_METHOD, 85 | anneal_point=cfg.SOLVER.ANNEAL_POINT, # default 0.72 86 | steps=cfg.SOLVER.get("REL_STEPS", [2 / 3.0, 8 / 9.0]), # default [2/3., 8/9.], relative decay steps 87 | target_lr_factor=cfg.SOLVER.get("TARTGET_LR_FACTOR", 0), 88 | 
poly_power=cfg.SOLVER.get("POLY_POWER", 1.0), 89 | step_gamma=cfg.SOLVER.GAMMA, # default 0.1 90 | ) 91 | else: 92 | raise ValueError("Unknown LR scheduler: {}".format(name)) 93 | -------------------------------------------------------------------------------- /network/fs_net_repo/PoseTs.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import absl.flags as flags 5 | from absl import app 6 | from mmcv.cnn import normal_init, constant_init 7 | from torch.nn.modules.batchnorm import _BatchNorm 8 | from nnutils.layer_utils import get_norm, get_nn_act_func 9 | 10 | FLAGS = flags.FLAGS 11 | # Point_center encode the segmented point cloud 12 | # one more conv layer compared to original paper 13 | 14 | class Pose_Ts(nn.Module): 15 | def __init__(self): 16 | super(Pose_Ts, self).__init__() 17 | self.f = FLAGS.feat_pcl + FLAGS.feat_global_pcl + FLAGS.feat_seman + 3 18 | self.k = FLAGS.Ts_c 19 | 20 | self.conv1 = torch.nn.Conv1d(self.f, 1024, 1) 21 | 22 | self.conv2 = torch.nn.Conv1d(1024, 256, 1) 23 | self.conv3 = torch.nn.Conv1d(256, 256, 1) 24 | self.conv4 = torch.nn.Conv1d(256, self.k, 1) 25 | self.drop1 = nn.Dropout(0.2) 26 | self.bn1 = nn.BatchNorm1d(1024) 27 | self.bn2 = nn.BatchNorm1d(256) 28 | self.bn3 = nn.BatchNorm1d(256) 29 | self._init_weights() 30 | 31 | def forward(self, x): 32 | x = F.relu(self.bn1(self.conv1(x))) 33 | x = F.relu(self.bn2(self.conv2(x))) 34 | 35 | x = torch.max(x, 2, keepdim=True)[0] 36 | 37 | x = F.relu(self.bn3(self.conv3(x))) 38 | x = self.drop1(x) 39 | x = self.conv4(x) 40 | 41 | x = x.squeeze(2) 42 | x = x.contiguous() 43 | xt = x[:, 0:3] 44 | xs = x[:, 3:6] 45 | return xt, xs 46 | 47 | def _init_weights(self): 48 | for m in self.modules(): 49 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d)): 50 | normal_init(m, std=0.001) 51 | elif isinstance(m, (_BatchNorm, nn.GroupNorm)): 52 | constant_init(m, 1) 53 | elif isinstance(m, nn.Linear): 54 | normal_init(m, std=0.001) 55 | 56 | class Pose_Ts_global(nn.Module): 57 | def __init__( 58 | self, 59 | feat_dim=256, 60 | num_layers=2, 61 | norm="none", 62 | num_gn_groups=32, 63 | act="leaky_relu", 64 | num_classes=1, 65 | norm_input=False, 66 | ): 67 | super().__init__() 68 | in_dim = FLAGS.feat_global_pcl 69 | self.norm = get_norm(norm, feat_dim, num_gn_groups=num_gn_groups) 70 | self.act_func = act_func = get_nn_act_func(act) 71 | self.num_classes = num_classes 72 | self.linears = nn.ModuleList() 73 | if norm_input: 74 | self.linears.append(nn.BatchNorm1d(in_dim)) 75 | for _i in range(num_layers): 76 | _in_dim = in_dim if _i == 0 else feat_dim 77 | self.linears.append(nn.Linear(_in_dim, feat_dim)) 78 | self.linears.append(get_norm(norm, feat_dim, num_gn_groups=num_gn_groups)) 79 | self.linears.append(act_func) 80 | 81 | self.fc_t = nn.Linear(feat_dim, 3 * num_classes) 82 | self.fc_s = nn.Linear(feat_dim, 3 * num_classes) 83 | 84 | # init ------------------------------------ 85 | self._init_weights() 86 | 87 | def _init_weights(self): 88 | for m in self.modules(): 89 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d)): 90 | normal_init(m, std=0.001) 91 | elif isinstance(m, (_BatchNorm, nn.GroupNorm)): 92 | constant_init(m, 1) 93 | elif isinstance(m, nn.Linear): 94 | normal_init(m, std=0.001) 95 | normal_init(self.fc_t, std=0.01) 96 | normal_init(self.fc_s, std=0.01) 97 | 98 | def forward(self, x): 99 | """ 100 | x: should be flattened 101 | """ 102 | for _layer in self.linears: 103 
| x = _layer(x) 104 | 105 | trans = self.fc_t(x) 106 | scale = self.fc_s(x) 107 | return trans, scale 108 | 109 | 110 | def main(argv): 111 | feature = torch.rand(3, 3, 1000) 112 | obj_id = torch.randint(low=0, high=15, size=[3, 1]) 113 | net = Pose_Ts() 114 | out = net(feature, obj_id) 115 | t = 1 116 | 117 | if __name__ == "__main__": 118 | app.run(main) 119 | -------------------------------------------------------------------------------- /losses/nn_distance/chamfer_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import nn_distance 3 | 4 | 5 | class NnDistanceFunction(torch.autograd.Function): 6 | """ 3D point set to 3D point set distance. 7 | 8 | """ 9 | @staticmethod 10 | def forward(ctx, xyz1, xyz2): 11 | B, N, _ = xyz1.size() 12 | B, M, _ = xyz2.size() 13 | result = torch.empty(B, N, dtype=xyz1.dtype, device=xyz1.device) 14 | result_i = torch.empty(B, N, dtype=torch.int32, device=xyz1.device) 15 | result2 = torch.empty(B, M, dtype=xyz2.dtype, device=xyz2.device) 16 | result2_i = torch.empty(B, M, dtype=torch.int32, device=xyz2.device) 17 | nn_distance.forward(xyz1, xyz2, result, result2, result_i, result2_i) 18 | ctx.save_for_backward(xyz1, xyz2, result_i, result2_i) 19 | ctx.mark_non_differentiable(result_i, result2_i) 20 | return result, result2, result_i, result2_i 21 | 22 | @staticmethod 23 | def backward(ctx, d_dist1, d_dist2, d_i1, d_i2): 24 | B, N = d_dist1.size() 25 | B, M = d_dist2.size() 26 | xyz1, xyz2, idx1, idx2 = ctx.saved_variables 27 | d_xyz1 = torch.zeros_like(xyz1) 28 | d_xyz2 = torch.zeros_like(xyz2) 29 | gradient1, gradient2 = ctx.needs_input_grad 30 | nn_distance.backward(xyz1, xyz2, d_xyz1, d_xyz2, d_dist1, d_dist2, idx1, idx2) 31 | if not gradient1: 32 | return None, d_xyz2 33 | if not gradient2: 34 | return d_xyz1, None 35 | else: 36 | return d_xyz1, d_xyz2 37 | 38 | 39 | class ChamferLoss(torch.nn.Module): 40 | """ Chamfer Loss: bidirectional nearest neighbor distance of two point sets. 
41 | 42 | """ 43 | def __init__(self, threshold=None, backward_weight=1.0): 44 | super(ChamferLoss, self).__init__() 45 | # only consider distance smaller than threshold*mean(distance) (remove outlier) 46 | self.__threshold = threshold 47 | self.backward_weight = backward_weight 48 | 49 | def set_threshold(self, value): 50 | self.__threshold = value 51 | 52 | def unset_threshold(self): 53 | self.__threshold = None 54 | 55 | def forward(self, pred, gt): 56 | assert(pred.dim() == 3 and gt.dim() == 3), \ 57 | "input for ChamferLoss must be a 3D-tensor, but pred.size() is {} gt.size() is {}".format(pred.size(), gt.size()) 58 | # need transpose 59 | if pred.size(2) != 3: 60 | assert(pred.size(1) == 3), "ChamferLoss is implemented for 3D points" 61 | pred = pred.transpose(2, 1).contiguous() 62 | if gt.size(2) != 3: 63 | assert(gt.size(1) == 3), "ChamferLoss is implemented for 3D points" 64 | gt = gt.transpose(2, 1).contiguous() 65 | assert(pred.size(2) == 3 and gt.size(2) == 3), "ChamferLoss is implemented for 3D points" 66 | pred2gt, gt2pred, idx1, idx2 = NnDistanceFunction.apply(pred, gt) 67 | 68 | if self.__threshold is not None: 69 | threshold = self.__threshold 70 | forward_threshold = torch.mean(pred2gt, dim=1, keepdim=True) * threshold 71 | backward_threshold = torch.mean(gt2pred, dim=1, keepdim=True) * threshold 72 | # only care about distance within threshold (ignore strong outliers) 73 | pred2gt = torch.where(pred2gt < forward_threshold, pred2gt, torch.zeros_like(pred2gt)) 74 | gt2pred = torch.where(gt2pred < backward_threshold, gt2pred, torch.zeros_like(gt2pred)) 75 | 76 | pred2gt = torch.mean(pred2gt, dim=1) 77 | gt2pred = torch.mean(gt2pred, dim=1) 78 | cd_dist = pred2gt + self.backward_weight * gt2pred 79 | cd_loss = torch.mean(cd_dist) 80 | return cd_loss, idx1, idx2 81 | 82 | 83 | if __name__ == '__main__': 84 | from torch.autograd import gradcheck 85 | # nndistance = NnDistanceFunction.apply 86 | # pc1 = torch.randn([2, 60, 3], dtype=torch.float, requires_grad=True).cuda() 87 | # pc2 = torch.randn([2, 30, 3], dtype=torch.float, requires_grad=True).cuda() 88 | # test = gradcheck(nndistance, (pc1, pc2), eps=1e-3, atol=1e-3) 89 | # print(test) 90 | 91 | distance = ChamferLoss(backward_weight=0.0) 92 | pc1 = torch.randn([2, 60, 3], dtype=torch.float, requires_grad=True).cuda() 93 | pc2 = pc1.clone() 94 | pc2 = pc2[:,:30,:].contiguous() 95 | test = distance(pc2, pc1) 96 | print(test) 97 | -------------------------------------------------------------------------------- /prepare_data/lib/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from lib.pspnet import PSPNet 4 | 5 | 6 | class DeformNet(nn.Module): 7 | def __init__(self, n_cat=6, nv_prior=1024): 8 | super(DeformNet, self).__init__() 9 | self.n_cat = n_cat 10 | self.psp = PSPNet(bins=(1, 2, 3, 6), backend='resnet18') 11 | self.instance_color = nn.Sequential( 12 | nn.Conv1d(32, 64, 1), 13 | nn.ReLU(), 14 | ) 15 | self.instance_geometry = nn.Sequential( 16 | nn.Conv1d(3, 64, 1), 17 | nn.ReLU(), 18 | nn.Conv1d(64, 64, 1), 19 | nn.ReLU(), 20 | nn.Conv1d(64, 64, 1), 21 | nn.ReLU(), 22 | ) 23 | self.instance_global = nn.Sequential( 24 | nn.Conv1d(128, 128, 1), 25 | nn.ReLU(), 26 | nn.Conv1d(128, 1024, 1), 27 | nn.ReLU(), 28 | nn.AdaptiveAvgPool1d(1), 29 | ) 30 | self.category_local = nn.Sequential( 31 | nn.Conv1d(3, 64, 1), 32 | nn.ReLU(), 33 | nn.Conv1d(64, 64, 1), 34 | nn.ReLU(), 35 | nn.Conv1d(64, 64, 1), 36 | nn.ReLU(), 37 | ) 38 | 
self.category_global = nn.Sequential( 39 | nn.Conv1d(64, 128, 1), 40 | nn.ReLU(), 41 | nn.Conv1d(128, 1024, 1), 42 | nn.ReLU(), 43 | nn.AdaptiveAvgPool1d(1), 44 | ) 45 | self.assignment = nn.Sequential( 46 | nn.Conv1d(2176, 512, 1), 47 | nn.ReLU(), 48 | nn.Conv1d(512, 256, 1), 49 | nn.ReLU(), 50 | nn.Conv1d(256, n_cat*nv_prior, 1), 51 | ) 52 | self.deformation = nn.Sequential( 53 | nn.Conv1d(2112, 512, 1), 54 | nn.ReLU(), 55 | nn.Conv1d(512, 256, 1), 56 | nn.ReLU(), 57 | nn.Conv1d(256, n_cat*3, 1), 58 | ) 59 | # Initialize weights to be small so initial deformations aren't so big 60 | self.deformation[4].weight.data.normal_(0, 0.0001) 61 | 62 | def forward(self, points, img, choose, cat_id, prior): 63 | """ 64 | Args: 65 | points: bs x n_pts x 3 66 | img: bs x 3 x H x W 67 | choose: bs x n_pts 68 | cat_id: bs 69 | prior: bs x nv x 3 70 | 71 | Returns: 72 | assign_mat: bs x n_pts x nv 73 | inst_shape: bs x nv x 3 74 | deltas: bs x nv x 3 75 | log_assign: bs x n_pts x nv, for numerical stability 76 | 77 | """ 78 | bs, n_pts = points.size()[:2] 79 | nv = prior.size()[1] 80 | # instance-specific features 81 | points = points.permute(0, 2, 1) 82 | points = self.instance_geometry(points) 83 | out_img = self.psp(img) 84 | di = out_img.size()[1] 85 | emb = out_img.view(bs, di, -1) 86 | choose = choose.unsqueeze(1).repeat(1, di, 1) 87 | emb = torch.gather(emb, 2, choose).contiguous() 88 | emb = self.instance_color(emb) 89 | inst_local = torch.cat((points, emb), dim=1) # bs x 128 x n_pts 90 | inst_global = self.instance_global(inst_local) # bs x 1024 x 1 91 | # category-specific features 92 | cat_prior = prior.permute(0, 2, 1) 93 | cat_local = self.category_local(cat_prior) # bs x 64 x n_pts 94 | cat_global = self.category_global(cat_local) # bs x 1024 x 1 95 | # assignemnt matrix 96 | assign_feat = torch.cat((inst_local, inst_global.repeat(1, 1, n_pts), cat_global.repeat(1, 1, n_pts)), dim=1) # bs x 2176 x n_pts 97 | assign_mat = self.assignment(assign_feat) 98 | assign_mat = assign_mat.view(-1, nv, n_pts).contiguous() # bs, nc*nv, n_pts -> bs*nc, nv, n_pts 99 | index = cat_id + torch.arange(bs, dtype=torch.long).cuda() * self.n_cat 100 | assign_mat = torch.index_select(assign_mat, 0, index) # bs x nv x n_pts 101 | assign_mat = assign_mat.permute(0, 2, 1).contiguous() # bs x n_pts x nv 102 | # deformation field 103 | deform_feat = torch.cat((cat_local, cat_global.repeat(1, 1, nv), inst_global.repeat(1, 1, nv)), dim=1) # bs x 2112 x n_pts 104 | deltas = self.deformation(deform_feat) 105 | deltas = deltas.view(-1, 3, nv).contiguous() # bs, nc*3, nv -> bs*nc, 3, nv 106 | deltas = torch.index_select(deltas, 0, index) # bs x 3 x nv 107 | deltas = deltas.permute(0, 2, 1).contiguous() # bs x nv x 3 108 | 109 | return assign_mat, deltas 110 | -------------------------------------------------------------------------------- /tools/torch_utils/solver/ralamb.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/mgrankin/over9000 3 | #### 4 | 5 | import torch, math 6 | from torch.optim.optimizer import Optimizer 7 | 8 | # RAdam + LARS 9 | class Ralamb(Optimizer): 10 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 11 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 12 | self.buffer = [[None, None, None] for ind in range(10)] 13 | super(Ralamb, self).__init__(params, defaults) 14 | 15 | def __setstate__(self, state): 16 | super(Ralamb, self).__setstate__(state) 17 | 18 | 
def step(self, closure=None): 19 | 20 | loss = None 21 | if closure is not None: 22 | loss = closure() 23 | 24 | for group in self.param_groups: 25 | 26 | for p in group["params"]: 27 | if p.grad is None: 28 | continue 29 | grad = p.grad.data.float() 30 | if grad.is_sparse: 31 | raise RuntimeError("Ralamb does not support sparse gradients") 32 | 33 | p_data_fp32 = p.data.float() 34 | 35 | state = self.state[p] 36 | 37 | if len(state) == 0: 38 | state["step"] = 0 39 | state["exp_avg"] = torch.zeros_like(p_data_fp32) 40 | state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) 41 | else: 42 | state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) 43 | state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) 44 | 45 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 46 | beta1, beta2 = group["betas"] 47 | 48 | # Decay the first and second moment running average coefficient 49 | # m_t 50 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 51 | # v_t 52 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 53 | 54 | state["step"] += 1 55 | buffered = self.buffer[int(state["step"] % 10)] 56 | 57 | if state["step"] == buffered[0]: 58 | N_sma, radam_step_size = buffered[1], buffered[2] 59 | else: 60 | buffered[0] = state["step"] 61 | beta2_t = beta2 ** state["step"] 62 | N_sma_max = 2 / (1 - beta2) - 1 63 | N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) 64 | buffered[1] = N_sma 65 | 66 | # more conservative since it's an approximated value 67 | if N_sma >= 5: 68 | radam_step_size = math.sqrt( 69 | (1 - beta2_t) 70 | * (N_sma - 4) 71 | / (N_sma_max - 4) 72 | * (N_sma - 2) 73 | / N_sma 74 | * N_sma_max 75 | / (N_sma_max - 2) 76 | ) / (1 - beta1 ** state["step"]) 77 | else: 78 | radam_step_size = 1.0 / (1 - beta1 ** state["step"]) 79 | buffered[2] = radam_step_size 80 | 81 | if group["weight_decay"] != 0: 82 | p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32) 83 | 84 | # more conservative since it's an approximated value 85 | radam_step = p_data_fp32.clone() 86 | if N_sma >= 5: 87 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 88 | radam_step.addcdiv_(-radam_step_size * group["lr"], exp_avg, denom) 89 | else: 90 | radam_step.add_(-radam_step_size * group["lr"], exp_avg) 91 | 92 | radam_norm = radam_step.pow(2).sum().sqrt() 93 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 94 | if weight_norm == 0 or radam_norm == 0: 95 | trust_ratio = 1 96 | else: 97 | trust_ratio = weight_norm / radam_norm 98 | 99 | state["weight_norm"] = weight_norm 100 | state["adam_norm"] = radam_norm 101 | state["trust_ratio"] = trust_ratio 102 | 103 | if N_sma >= 5: 104 | p_data_fp32.addcdiv_(-radam_step_size * group["lr"] * trust_ratio, exp_avg, denom) 105 | else: 106 | p_data_fp32.add_(-radam_step_size * group["lr"] * trust_ratio, exp_avg) 107 | 108 | p.data.copy_(p_data_fp32) 109 | 110 | return loss 111 | -------------------------------------------------------------------------------- /network/fs_net_repo/pcl_encoder.py: -------------------------------------------------------------------------------- 1 | # follow FS-Net 2 | import torch.nn as nn 3 | import network.fs_net_repo.gcn3d as gcn3d 4 | import torch 5 | import torch.nn.functional as F 6 | from absl import app 7 | import absl.flags as flags 8 | 9 | FLAGS = flags.FLAGS 10 | 11 | # global feature num : the channels of feature from rgb and depth 12 | # grid_num : the volume resolution 13 | 14 | class PCL_Encoder(nn.Module): 15 | def __init__(self): 16 | super(PCL_Encoder, self).__init__() 17 | self.neighbor_num = 
FLAGS.gcn_n_num 18 | self.support_num = FLAGS.gcn_sup_num 19 | 20 | # 3D convolution for point cloud 21 | self.conv_0 = gcn3d.Conv_surface(kernel_num=128, support_num=self.support_num) 22 | self.conv_1 = gcn3d.Conv_layer(128, 128, support_num=self.support_num) 23 | self.pool_1 = gcn3d.Pool_layer(pooling_rate=4, neighbor_num=4) 24 | self.conv_2 = gcn3d.Conv_layer(128, 256, support_num=self.support_num) 25 | self.conv_3 = gcn3d.Conv_layer(256, 256, support_num=self.support_num) 26 | self.pool_2 = gcn3d.Pool_layer(pooling_rate=4, neighbor_num=4) 27 | self.conv_4 = gcn3d.Conv_layer(256, 512, support_num=self.support_num) 28 | 29 | self.bn1 = nn.BatchNorm1d(128) 30 | self.bn2 = nn.BatchNorm1d(256) 31 | self.bn3 = nn.BatchNorm1d(256) 32 | 33 | self.recon_num = 3 34 | self.face_recon_num = FLAGS.face_recon_c 35 | 36 | def forward(self, 37 | vertices: "tensor (bs, vetice_num, 3)", 38 | cat_id: "tensor (bs, 1)", 39 | ): 40 | """ 41 | Return: (bs, vertice_num, class_num) 42 | """ 43 | # concate feature 44 | bs, vertice_num, _ = vertices.size() 45 | # cat_id to one-hot 46 | if cat_id.shape[0] == 1: 47 | obj_idh = cat_id.view(-1, 1).repeat(cat_id.shape[0], 1) 48 | else: 49 | obj_idh = cat_id.view(-1, 1) 50 | 51 | one_hot = torch.zeros(bs, FLAGS.obj_c).to(cat_id.device).scatter_(1, obj_idh.long(), 1) 52 | # bs x verticenum x 6 53 | 54 | neighbor_index = gcn3d.get_neighbor_index(vertices, self.neighbor_num) 55 | # ss = time.time() 56 | fm_0 = F.relu(self.conv_0(neighbor_index, vertices), inplace=True) 57 | 58 | fm_1 = F.relu(self.bn1(self.conv_1(neighbor_index, vertices, fm_0).transpose(1, 2)).transpose(1, 2), 59 | inplace=True) 60 | v_pool_1, fm_pool_1 = self.pool_1(vertices, fm_1) 61 | # neighbor_index = gcn3d.get_neighbor_index(v_pool_1, self.neighbor_num) 62 | neighbor_index = gcn3d.get_neighbor_index(v_pool_1, 63 | min(self.neighbor_num, v_pool_1.shape[1] // 8)) 64 | fm_2 = F.relu(self.bn2(self.conv_2(neighbor_index, v_pool_1, fm_pool_1).transpose(1, 2)).transpose(1, 2), 65 | inplace=True) 66 | fm_3 = F.relu(self.bn3(self.conv_3(neighbor_index, v_pool_1, fm_2).transpose(1, 2)).transpose(1, 2), 67 | inplace=True) 68 | v_pool_2, fm_pool_2 = self.pool_2(v_pool_1, fm_3) 69 | # neighbor_index = gcn3d.get_neighbor_index(v_pool_2, self.neighbor_num) 70 | neighbor_index = gcn3d.get_neighbor_index(v_pool_2, min(self.neighbor_num, 71 | v_pool_2.shape[1] // 8)) 72 | fm_4 = self.conv_4(neighbor_index, v_pool_2, fm_pool_2) 73 | f_global = fm_4.max(1)[0] # (bs, f) 74 | 75 | nearest_pool_1 = gcn3d.get_nearest_index(vertices, v_pool_1) 76 | nearest_pool_2 = gcn3d.get_nearest_index(vertices, v_pool_2) 77 | fm_2 = gcn3d.indexing_neighbor(fm_2, nearest_pool_1).squeeze(2) 78 | fm_3 = gcn3d.indexing_neighbor(fm_3, nearest_pool_1).squeeze(2) 79 | fm_4 = gcn3d.indexing_neighbor(fm_4, nearest_pool_2).squeeze(2) 80 | one_hot = one_hot.unsqueeze(1).repeat(1, vertice_num, 1) # (bs, vertice_num, cat_one_hot) 81 | 82 | feat = torch.cat([fm_0, fm_1, fm_2, fm_3, fm_4, one_hot], dim=2) 83 | return feat, f_global 84 | 85 | 86 | def main(argv): 87 | classifier_seg3D = PCL_Encoder() 88 | 89 | points = torch.rand(2, 1000, 3) 90 | import numpy as np 91 | obj_idh = torch.ones((2, 1)) 92 | obj_idh[1, 0] = 5 93 | ''' 94 | if obj_idh.shape[0] == 1: 95 | obj_idh = obj_idh.view(-1, 1).repeat(points.shape[0], 1) 96 | else: 97 | obj_idh = obj_idh.view(-1, 1) 98 | 99 | one_hot = torch.zeros(points.shape[0], 6).scatter_(1, obj_idh.cpu().long(), 1) 100 | ''' 101 | feat = classifier_seg3D(points, obj_idh) 102 | t = 1 103 | 104 | 105 | 106 | if 
__name__ == "__main__": 107 | print(1) 108 | from config.config import * 109 | app.run(main) 110 | 111 | 112 | -------------------------------------------------------------------------------- /prepare_data/gen_pts.py: -------------------------------------------------------------------------------- 1 | # @Time : 12/05/2021 2 | # @Author : Wei Chen 3 | # @Project : Pycharm 4 | 5 | import cv2 6 | import numpy as np 7 | import os 8 | import _pickle as pickle 9 | from uti_tool import getFiles_cate, depth_2_mesh_all, depth_2_mesh_bbx 10 | from prepare_data.renderer import create_renderer 11 | 12 | def render_pre(model_path): 13 | renderer = create_renderer(640, 480, renderer_type='python') 14 | models = getFiles_ab_cate(model_path, '.ply') #model name example: laptop_air_1_norm.ply please adjust the 15 | # corresponding functions according to the model name. 16 | objs=[] 17 | for model in models: 18 | obj = model.split('.')[1] 19 | objs.append(obj) 20 | renderer.add_object(obj, model) 21 | return renderer 22 | 23 | def getFiles_ab_cate(file_dir,suf): 24 | L=[] 25 | for root, dirs, files in os.walk(file_dir): 26 | for file in files: 27 | if file.split('.')[1] == suf: 28 | L.append(os.path.join(root, file)) 29 | return L 30 | 31 | def get_dis_all(pc,dep,dd=15): 32 | 33 | N=pc.shape[0] 34 | M=dep.shape[0] 35 | depp=np.tile(dep,(1,N)) 36 | 37 | depmm=depp.reshape((M,N,3)) 38 | delta = depmm - pc 39 | diss=np.linalg.norm(delta,2, 2) 40 | 41 | aa=np.min(diss,1) 42 | bb=aa.reshape((M,1)) 43 | 44 | ids,cc=np.where(bb[:] 0.0)] * 1000.0 61 | 62 | numbs = 6000 63 | 64 | numbs2 = 1000 65 | if VIS.shape[0] > numbs2: 66 | choice2 = np.random.choice(VIS.shape[0], numbs2, replace=False) 67 | VIS = VIS[choice2, :] 68 | 69 | 70 | filename = save_path + ("/pose%08d.txt" % (idx)) 71 | w_namei = save_pathlab + ("/lab%08d.txt" % (idx)) 72 | 73 | dep3d_ = depth_2_mesh_bbx(depth, bbx, K, enl=0) 74 | 75 | if dep3d_.shape[0] > numbs: 76 | choice = np.random.choice(dep3d_.shape[0], numbs, replace=False) 77 | 78 | dep3d = dep3d_[choice, :] 79 | else: 80 | choice = np.random.choice(dep3d_.shape[0], numbs, replace=True) 81 | dep3d = dep3d_[choice, :] 82 | 83 | dep3d = dep3d[np.where(dep3d[:, 2] != 0.0)] 84 | 85 | 86 | threshold = 12 87 | 88 | ids = get_dis_all(VIS, dep3d[:, 0:3], dd=threshold) ## find the object points 89 | 90 | if len(ids) <= 10: 91 | if os.path.exists(filename): 92 | os.remove(filename) 93 | if os.path.exists(w_namei): 94 | os.remove(w_namei) 95 | 96 | if len(ids) > 10: 97 | 98 | np.savetxt(filename, dep3d, fmt='%f', delimiter=' ') 99 | lab = np.zeros((dep3d.shape[0], 1), dtype=np.uint) 100 | lab[ids, :] = 1 101 | np.savetxt(w_namei, lab, fmt='%d') 102 | 103 | 104 | 105 | 106 | def get_point_wise_lab(basepath, fold, renderer, sp): 107 | base_path = basepath + '%d/' % (fold) 108 | 109 | 110 | depths = getFiles_cate(base_path, '_depth', 4, -4) 111 | 112 | labels = getFiles_cate(base_path, '_label2', 4, -4) 113 | 114 | 115 | L_dep = depths 116 | 117 | Ki = np.array([[591.0125, 0, 322.525], [0, 590.16775, 244.11084], [0, 0, 1]]) 118 | 119 | Lidx = 1000 120 | if fold == 1: 121 | s = 0 122 | else: 123 | s = 0 124 | for i in range(s, len(L_dep)): 125 | 126 | lab = pickle.load(open(labels[i], 'rb')) 127 | 128 | depth = cv2.imread(L_dep[i], -1) 129 | img_id = int(L_dep[i][-14:-10]) 130 | for ii in range(len(lab['class_ids'])): 131 | 132 | 133 | obj = lab['model_list'][ii] 134 | 135 | seg = lab['bboxes'][ii].reshape((1, 4)) ## y1 x1 y2 x2 (ori x1,y1,w h) 136 | 137 | idx = (fold - 1) * Lidx + img_id 138 | 139 | R = 
lab['rotations'][ii] # .reshape((3, 3)) 140 | 141 | T = lab['translations'][ii].reshape((3, 1)) # -np.array([0,0,100]).reshape((3, 1)) 142 | 143 | 144 | if T[2] < 0: 145 | T[2] = -T[2] 146 | vis_part = renderer.render_object(obj, R, T, Ki[0, 0], Ki[1, 1], Ki[0, 2], Ki[1, 2])['depth'] 147 | 148 | bbx = [seg[0, 0], seg[0, 2], seg[0, 1], seg[0, 3]] 149 | 150 | if vis_part.max() > 0: 151 | get_one(depth, bbx, vis_part, Ki, idx, obj, sp) 152 | 153 | 154 | 155 | 156 | if __name__ == '__main__': 157 | path = 'your own object model path ' 158 | render_pre(path) 159 | 160 | 161 | 162 | 163 | 164 | 165 | -------------------------------------------------------------------------------- /evaluation/refine_mug_in_detection_dict_camera.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import mmcv 3 | import os 4 | import _pickle as cPickle 5 | import numpy as np 6 | from evaluation.eval_utils_cass import get_3d_bbox, transform_coordinates_3d, compute_3d_iou_new 7 | from tqdm import tqdm 8 | 9 | def get_origin_scale(model, nocs_scale): 10 | lx = 2 * max(max(model[:, 0]), -min(model[:, 0])) 11 | ly = max(model[:, 1]) - min(model[:, 1]) 12 | lz = max(model[:, 2]) - min(model[:, 2]) 13 | 14 | # real scale 15 | lx_t = lx * nocs_scale 16 | ly_t = ly * nocs_scale 17 | lz_t = lz * nocs_scale 18 | return np.array([lx_t, ly_t, lz_t]) 19 | 20 | def asymmetric_3d_iou(RT_1, RT_2, scales_1, scales_2): 21 | noc_cube_1 = get_3d_bbox(scales_1, 0) 22 | bbox_3d_1 = transform_coordinates_3d(noc_cube_1, RT_1) 23 | 24 | noc_cube_2 = get_3d_bbox(scales_2, 0) 25 | bbox_3d_2 = transform_coordinates_3d(noc_cube_2, RT_2) 26 | 27 | bbox_1_max = np.amax(bbox_3d_1, axis=0) 28 | bbox_1_min = np.amin(bbox_3d_1, axis=0) 29 | bbox_2_max = np.amax(bbox_3d_2, axis=0) 30 | bbox_2_min = np.amin(bbox_3d_2, axis=0) 31 | 32 | overlap_min = np.maximum(bbox_1_min, bbox_2_min) 33 | overlap_max = np.minimum(bbox_1_max, bbox_2_max) 34 | 35 | # intersections and union 36 | if np.amin(overlap_max - overlap_min) < 0: 37 | intersections = 0 38 | else: 39 | intersections = np.prod(overlap_max - overlap_min) 40 | union = np.prod(bbox_1_max - bbox_1_min) + \ 41 | np.prod(bbox_2_max - bbox_2_min) - intersections 42 | overlaps = intersections / union 43 | return overlaps 44 | 45 | data_dir = '/data/zrd/datasets/NOCS' 46 | detection_dir = os.path.join(data_dir, 'detection_dualposenet/data/segmentation_results') 47 | detection_dir_refine_mug = os.path.join(data_dir, 'detection_dualposenet/data/segmentation_results_refine_for_mug') 48 | dataset_split = 'CAMERA25' 49 | img_list_path = os.path.join(data_dir, 'CAMERA/val_list.txt') 50 | img_list = [os.path.join(img_list_path.split('/')[0], line.rstrip('\n')) 51 | for line in open(os.path.join(data_dir, img_list_path))] 52 | cat_name2id = {'bottle': 1, 'bowl': 2, 'camera': 3, 'can': 4, 'laptop': 5, 'mug': 6} 53 | with open(os.path.join(data_dir, 'obj_models/mug_meta.pkl'), 'rb') as f: 54 | mug_meta = cPickle.load(f) 55 | 56 | models = {} 57 | model_file_path = ['obj_models/camera_val.pkl'] 58 | for path in model_file_path: 59 | with open(os.path.join(data_dir, path), 'rb') as f: 60 | models.update(cPickle.load(f)) 61 | 62 | for img_path in tqdm(img_list): 63 | img_path = os.path.join(data_dir, 'CAMERA', img_path) 64 | 65 | scene = img_path.split('/')[-2] 66 | img_id = img_path.split('/')[-1] 67 | detection_file = os.path.join(detection_dir, dataset_split, f'results_val_{scene}_{img_id}.pkl') 68 | with open(detection_file, 'rb') as f: 69 | detection_dict = 
cPickle.load(f) 70 | with open(img_path + '_label.pkl', 'rb') as f: 71 | gts = cPickle.load(f) 72 | 73 | mug_idx = [] 74 | for idx_gt in range(len(gts['class_ids'])): 75 | gt_cat_id = gts['class_ids'][idx_gt] # convert to 0-indexed 76 | if gt_cat_id == cat_name2id['mug']: 77 | mug_idx.append(idx_gt) 78 | 79 | mug_idx_detection = [] 80 | for idx_gt in range(len(detection_dict['gt_class_ids'])): 81 | gt_cat_id = detection_dict['gt_class_ids'][idx_gt] # convert to 0-indexed 82 | if gt_cat_id == cat_name2id['mug']: 83 | mug_idx_detection.append(idx_gt) 84 | 85 | previous_detection_idx = None 86 | for idx_gt in mug_idx: 87 | max_iou = 0 88 | max_detection_idx = None 89 | rotation = gts['rotations'][idx_gt] 90 | translation = gts['translations'][idx_gt] 91 | model = models[gts['model_list'][idx_gt]].astype(np.float32) # 1024 points 92 | nocs_scale = gts['scales'][idx_gt] # nocs_scale = image file / model file 93 | scale = get_origin_scale(model, 1) 94 | model_name = gts['model_list'][idx_gt] 95 | # T0_mug = mug_meta[model_name][0] 96 | # s0_mug = mug_meta[model_name][1] 97 | RT = np.zeros((4, 4)) 98 | RT[:3, :3] = rotation * nocs_scale 99 | RT[:3, 3] = translation 100 | RT[3, 3] = 1 101 | for idx_detection in mug_idx_detection: 102 | iou = asymmetric_3d_iou(detection_dict['gt_RTs'][idx_detection], RT, 103 | detection_dict['gt_scales'][idx_detection], scale) 104 | if iou > max_iou: 105 | max_iou = iou 106 | max_detection_idx = idx_detection 107 | detection_dict['gt_RTs'][max_detection_idx] = RT 108 | detection_dict['gt_scales'][max_detection_idx] = scale 109 | assert max_detection_idx != previous_detection_idx 110 | previous_detection_idx = max_detection_idx 111 | 112 | detection_file_refine_mug = os.path.join(detection_dir_refine_mug, dataset_split, f'results_val_{scene}_{img_id}.pkl') 113 | if not os.path.exists(os.path.dirname(detection_file_refine_mug)): 114 | os.makedirs(os.path.dirname(detection_file_refine_mug)) 115 | with open(detection_file_refine_mug, 'wb') as f: 116 | cPickle.dump(detection_dict, f) -------------------------------------------------------------------------------- /evaluation/refine_mug_in_detection_dict.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import mmcv 3 | import os 4 | import _pickle as cPickle 5 | import numpy as np 6 | from evaluation.eval_utils_cass import get_3d_bbox, transform_coordinates_3d, compute_3d_iou_new 7 | from tqdm import tqdm 8 | 9 | def get_origin_scale(model, nocs_scale): 10 | lx = 2 * max(max(model[:, 0]), -min(model[:, 0])) 11 | ly = max(model[:, 1]) - min(model[:, 1]) 12 | lz = max(model[:, 2]) - min(model[:, 2]) 13 | 14 | # real scale 15 | lx_t = lx * nocs_scale 16 | ly_t = ly * nocs_scale 17 | lz_t = lz * nocs_scale 18 | return np.array([lx_t, ly_t, lz_t]) 19 | 20 | def asymmetric_3d_iou(RT_1, RT_2, scales_1, scales_2): 21 | noc_cube_1 = get_3d_bbox(scales_1, 0) 22 | bbox_3d_1 = transform_coordinates_3d(noc_cube_1, RT_1) 23 | 24 | noc_cube_2 = get_3d_bbox(scales_2, 0) 25 | bbox_3d_2 = transform_coordinates_3d(noc_cube_2, RT_2) 26 | 27 | bbox_1_max = np.amax(bbox_3d_1, axis=0) 28 | bbox_1_min = np.amin(bbox_3d_1, axis=0) 29 | bbox_2_max = np.amax(bbox_3d_2, axis=0) 30 | bbox_2_min = np.amin(bbox_3d_2, axis=0) 31 | 32 | overlap_min = np.maximum(bbox_1_min, bbox_2_min) 33 | overlap_max = np.minimum(bbox_1_max, bbox_2_max) 34 | 35 | # intersections and union 36 | if np.amin(overlap_max - overlap_min) < 0: 37 | intersections = 0 38 | else: 39 | intersections = np.prod(overlap_max - 
overlap_min) 40 | union = np.prod(bbox_1_max - bbox_1_min) + \ 41 | np.prod(bbox_2_max - bbox_2_min) - intersections 42 | overlaps = intersections / union 43 | return overlaps 44 | 45 | data_dir = '/data/zrd/datasets/NOCS' 46 | detection_dir = os.path.join(data_dir, 'detection_dualposenet/data/segmentation_results') 47 | detection_dir_refine_mug = os.path.join(data_dir, 'detection_dualposenet/data/segmentation_results_refine_for_mug') 48 | dataset_split = 'REAL275' # 'CAMERA25' 49 | img_list_path = os.path.join(data_dir, 'Real/test_list.txt') 50 | img_list = [os.path.join(img_list_path.split('/')[0], line.rstrip('\n')) 51 | for line in open(os.path.join(data_dir, img_list_path))] 52 | cat_name2id = {'bottle': 1, 'bowl': 2, 'camera': 3, 'can': 4, 'laptop': 5, 'mug': 6} 53 | with open(os.path.join(data_dir, 'obj_models/mug_meta.pkl'), 'rb') as f: 54 | mug_meta = cPickle.load(f) 55 | 56 | models = {} 57 | model_file_path = ['obj_models/real_test.pkl'] 58 | for path in model_file_path: 59 | with open(os.path.join(data_dir, path), 'rb') as f: 60 | models.update(cPickle.load(f)) 61 | 62 | for img_path in tqdm(img_list): 63 | img_path = os.path.join(data_dir, 'Real', img_path) 64 | 65 | scene = img_path.split('/')[-2] 66 | img_id = img_path.split('/')[-1] 67 | detection_file = os.path.join(detection_dir, dataset_split, f'results_test_{scene}_{img_id}.pkl') 68 | with open(detection_file, 'rb') as f: 69 | detection_dict = cPickle.load(f) 70 | with open(img_path + '_label.pkl', 'rb') as f: 71 | gts = cPickle.load(f) 72 | 73 | mug_idx = [] 74 | for idx_gt in range(len(gts['class_ids'])): 75 | gt_cat_id = gts['class_ids'][idx_gt] # convert to 0-indexed 76 | if gt_cat_id == cat_name2id['mug']: 77 | mug_idx.append(idx_gt) 78 | 79 | mug_idx_detection = [] 80 | for idx_gt in range(len(detection_dict['gt_class_ids'])): 81 | gt_cat_id = detection_dict['gt_class_ids'][idx_gt] # convert to 0-indexed 82 | if gt_cat_id == cat_name2id['mug']: 83 | mug_idx_detection.append(idx_gt) 84 | 85 | previous_detection_idx = None 86 | for idx_gt in mug_idx: 87 | max_iou = 0 88 | max_detection_idx = None 89 | rotation = gts['rotations'][idx_gt] 90 | translation = gts['translations'][idx_gt] 91 | model = models[gts['model_list'][idx_gt]].astype(np.float32) # 1024 points 92 | nocs_scale = gts['scales'][idx_gt] # nocs_scale = image file / model file 93 | scale = get_origin_scale(model, 1) 94 | model_name = gts['model_list'][idx_gt] 95 | # T0_mug = mug_meta[model_name][0] 96 | # s0_mug = mug_meta[model_name][1] 97 | RT = np.zeros((4, 4)) 98 | RT[:3, :3] = rotation * nocs_scale 99 | RT[:3, 3] = translation 100 | RT[3, 3] = 1 101 | for idx_detection in mug_idx_detection: 102 | iou = asymmetric_3d_iou(detection_dict['gt_RTs'][idx_detection], RT, 103 | detection_dict['gt_scales'][idx_detection], scale) 104 | if iou > max_iou: 105 | max_iou = iou 106 | max_detection_idx = idx_detection 107 | detection_dict['gt_RTs'][max_detection_idx] = RT 108 | detection_dict['gt_scales'][max_detection_idx] = scale 109 | assert max_detection_idx != previous_detection_idx 110 | previous_detection_idx = max_detection_idx 111 | 112 | detection_file_refine_mug = os.path.join(detection_dir_refine_mug, dataset_split, f'results_test_{scene}_{img_id}.pkl') 113 | if not os.path.exists(os.path.dirname(detection_file_refine_mug)): 114 | os.makedirs(os.path.dirname(detection_file_refine_mug)) 115 | with open(detection_file_refine_mug, 'wb') as f: 116 | cPickle.dump(detection_dict, f) 
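# --------------------------------------------------------------------------
# Hedged sketch (not part of the original repository): the two
# refine_mug_in_detection_dict*.py scripts above repeat the same matching
# step -- each ground-truth mug is assigned to the detection entry with the
# highest asymmetric 3D IoU, and that entry's gt_RTs / gt_scales are then
# overwritten. The helper below is one possible way to factor that step out.
# It assumes asymmetric_3d_iou() as defined in the script above and arrays
# shaped like detection_dict['gt_RTs'] / detection_dict['gt_scales']; the
# name match_gt_to_detections is illustrative only.

def match_gt_to_detections(gt_RT, gt_scale, det_RTs, det_scales, candidate_idx):
    """Return (best_idx, best_iou) over candidate_idx, or (None, 0.0) if nothing overlaps."""
    best_idx, best_iou = None, 0.0
    for idx in candidate_idx:
        # same argument order as the inline loop: detection box first, GT box second
        iou = asymmetric_3d_iou(det_RTs[idx], gt_RT, det_scales[idx], gt_scale)
        if iou > best_iou:
            best_idx, best_iou = idx, iou
    return best_idx, best_iou

# Example call inside the per-image loop, mirroring the original code path:
#   max_detection_idx, max_iou = match_gt_to_detections(
#       RT, scale, detection_dict['gt_RTs'], detection_dict['gt_scales'], mug_idx_detection)
# --------------------------------------------------------------------------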
-------------------------------------------------------------------------------- /prepare_data/lib/pspnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False) 9 | 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 14 | super(BasicBlock, self).__init__() 15 | self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) 16 | self.relu = nn.ReLU(inplace=True) 17 | self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) 18 | self.downsample = downsample 19 | self.stride = stride 20 | 21 | def forward(self, x): 22 | residual = x 23 | out = self.conv1(x) 24 | out = self.relu(out) 25 | out = self.conv2(out) 26 | if self.downsample is not None: 27 | residual = self.downsample(x) 28 | out += residual 29 | out = self.relu(out) 30 | return out 31 | 32 | 33 | class ResNet(nn.Module): 34 | def __init__(self, block, layers=(3, 4, 23, 3)): 35 | self.inplanes = 64 36 | super(ResNet, self).__init__() 37 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 40 | self.layer1 = self._make_layer(block, 64, layers[0]) 41 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 42 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 43 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) 44 | 45 | for m in self.modules(): 46 | if isinstance(m, nn.Conv2d): 47 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 48 | m.weight.data.normal_(0, math.sqrt(2./n)) 49 | elif isinstance(m, nn.BatchNorm2d): 50 | m.weight.data.fill_(1) 51 | m.bias.data.zero_() 52 | 53 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 54 | downsample = None 55 | if stride != 1 or self.inplanes != planes*block.expansion: 56 | downsample = nn.Sequential( 57 | nn.Conv2d(self.inplanes, planes*block.expansion, kernel_size=1, stride=stride, bias=False) 58 | ) 59 | layers = [block(self.inplanes, planes, stride, downsample)] 60 | self.inplanes = planes * block.expansion 61 | for i in range(1, blocks): 62 | layers.append(block(self.inplanes, planes, dilation=dilation)) 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | x = self.conv1(x) 67 | x = self.relu(x) 68 | x = self.maxpool(x) 69 | x = self.layer1(x) 70 | x = self.layer2(x) 71 | x = self.layer3(x) 72 | x = self.layer4(x) 73 | return x 74 | 75 | 76 | class PSPModule(nn.Module): 77 | def __init__(self, feat_dim, bins=(1, 2, 3, 6)): 78 | super(PSPModule, self).__init__() 79 | self.reduction_dim = feat_dim // len(bins) 80 | self.stages = [] 81 | self.stages = nn.ModuleList([self._make_stage(feat_dim, size) for size in bins]) 82 | 83 | def _make_stage(self, feat_dim, size): 84 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) 85 | conv = nn.Conv2d(feat_dim, self.reduction_dim, kernel_size=1, bias=False) 86 | relu = nn.ReLU(inplace=True) 87 | return nn.Sequential(prior, conv, relu) 88 | 89 | def forward(self, feats): 90 | h, w = feats.size(2), feats.size(3) 91 | priors = [feats] 92 | for stage in self.stages: 93 | 
priors.append(F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True)) 94 | return torch.cat(priors, 1) 95 | 96 | 97 | class PSPUpsample(nn.Module): 98 | def __init__(self, in_channels, out_channels): 99 | super(PSPUpsample, self).__init__() 100 | self.conv = nn.Sequential( 101 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 102 | nn.PReLU() 103 | ) 104 | 105 | def forward(self, x): 106 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) 107 | return self.conv(x) 108 | 109 | 110 | class PSPNet(nn.Module): 111 | def __init__(self, bins=(1, 2, 3, 6), backend='resnet18'): 112 | super(PSPNet, self).__init__() 113 | if backend == 'resnet18': 114 | self.feats = ResNet(BasicBlock, [2, 2, 2, 2]) 115 | feat_dim = 512 116 | else: 117 | raise NotImplementedError 118 | self.psp = PSPModule(feat_dim, bins) 119 | self.drop = nn.Dropout2d(p=0.15) 120 | self.up_1 = PSPUpsample(1024, 256) 121 | self.up_2 = PSPUpsample(256, 64) 122 | self.up_3 = PSPUpsample(64, 64) 123 | self.final = nn.Conv2d(64, 32, kernel_size=1) 124 | 125 | def forward(self, x): 126 | f = self.feats(x) 127 | p = self.psp(f) 128 | p = self.up_1(p) 129 | p = self.drop(p) 130 | p = self.up_2(p) 131 | p = self.drop(p) 132 | p = self.up_3(p) 133 | return self.final(p) 134 | -------------------------------------------------------------------------------- /network/backbone_repo/Resnet/exts/guideconv_kernel.cu: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jie on 09/02/19. 3 | // 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | namespace { 10 | 11 | template 12 | __global__ void 13 | conv2d_kernel_lf(scalar_t *__restrict__ x, scalar_t *__restrict__ y, scalar_t *__restrict__ z, size_t N1, 14 | size_t N2, size_t Ci, size_t Co, size_t B, 15 | size_t K) { 16 | int col_index = threadIdx.x + blockIdx.x * blockDim.x; 17 | int row_index = threadIdx.y + blockIdx.y * blockDim.y; 18 | int cha_index = threadIdx.z + blockIdx.z * blockDim.z; 19 | if ((row_index < N1) && (col_index < N2) && (cha_index < Co)) { 20 | for (int b = 0; b < B; b++) { 21 | scalar_t result = 0; 22 | for (int i = -int((K - 1) / 2.); i < (K + 1) / 2.; i++) { 23 | for (int j = -int((K - 1) / 2.); j < (K + 1) / 2.; j++) { 24 | 25 | if ((row_index + i < 0) || (row_index + i >= N1) || (col_index + j < 0) || 26 | (col_index + j >= N2)) { 27 | continue; 28 | } 29 | 30 | result += x[b * N1 * N2 * Ci + cha_index * N1 * N2 + (row_index + i) * N2 + col_index + j] * 31 | y[b * N1 * N2 * Ci * K * K + cha_index * N1 * N2 * K * K + 32 | (i + (K - 1) / 2) * K * N1 * N2 + 33 | (j + (K - 1) / 2) * N1 * N2 + row_index * N2 + col_index]; 34 | } 35 | } 36 | z[b * N1 * N2 * Co + cha_index * N1 * N2 + row_index * N2 + col_index] = result; 37 | } 38 | } 39 | } 40 | 41 | 42 | template 43 | __global__ void conv2d_kernel_lb(scalar_t *__restrict__ x, scalar_t *__restrict__ y, scalar_t *__restrict__ gx, 44 | scalar_t *__restrict__ gy, scalar_t *__restrict__ gz, size_t N1, size_t N2, 45 | size_t Ci, size_t Co, size_t B, 46 | size_t K) { 47 | int col_index = threadIdx.x + blockIdx.x * blockDim.x; 48 | int row_index = threadIdx.y + blockIdx.y * blockDim.y; 49 | int cha_index = threadIdx.z + blockIdx.z * blockDim.z; 50 | if ((row_index < N1) && (col_index < N2) && (cha_index < Co)) { 51 | for (int b = 0; b < B; b++) { 52 | scalar_t result = 0; 53 | for (int i = -int((K - 1) / 2.); i < (K + 1) / 2.; i++) { 54 | for (int j = -int((K - 1) / 2.); j < (K + 1) / 2.; j++) { 55 | 56 | if ((row_index - i < 
0) || (row_index - i >= N1) || (col_index - j < 0) || 57 | (col_index - j >= N2)) { 58 | continue; 59 | } 60 | result += gz[b * N1 * N2 * Ci + cha_index * N1 * N2 + (row_index - i) * N2 + col_index - j 61 | ] * 62 | y[b * N1 * N2 * Ci * K * K + cha_index * N1 * N2 * K * K + 63 | (i + (K - 1) / 2) * K * N1 * N2 + 64 | (j + (K - 1) / 2) * N1 * N2 + (row_index - i) * N2 + col_index - j]; 65 | gy[b * N1 * N2 * Ci * K * K + cha_index * N1 * N2 * K * K + (i + (K - 1) / 2) * K * N1 * N2 + 66 | (j + (K - 1) / 2) * N1 * N2 + (row_index - i) * N2 + col_index - j] = 67 | gz[b * N1 * N2 * Ci + cha_index * N1 * N2 + (row_index - i) * N2 + col_index - j 68 | ] * x[b * N1 * N2 * Ci + cha_index * N1 * N2 + row_index * N2 + col_index]; 69 | 70 | } 71 | } 72 | gx[b * N1 * N2 * Co + cha_index * N1 * N2 + row_index * N2 + col_index] = result; 73 | } 74 | } 75 | } 76 | } 77 | 78 | 79 | void Conv2d_LF_Cuda(at::Tensor x, at::Tensor y, at::Tensor z, size_t N1, size_t N2, size_t Ci, size_t Co, size_t B, 80 | size_t K) { 81 | dim3 blockSize(32, 32, 1); 82 | dim3 gridSize((N2 + blockSize.x - 1) / blockSize.x, (N1 + blockSize.y - 1) / blockSize.y, 83 | (Co + blockSize.z - 1) / blockSize.z); 84 | AT_DISPATCH_FLOATING_TYPES(x.type(), "Conv2d_LF", ([&] { 85 | conv2d_kernel_lf << < gridSize, blockSize >> > ( 86 | x.data(), y.data(), z.data(), 87 | N1, N2, Ci, Co, B, K); 88 | })); 89 | } 90 | 91 | 92 | void 93 | Conv2d_LB_Cuda(at::Tensor x, at::Tensor y, at::Tensor gx, at::Tensor gy, at::Tensor gz, size_t N1, size_t N2, size_t Ci, 94 | size_t Co, size_t B, size_t K) { 95 | dim3 blockSize(32, 32, 1); 96 | dim3 gridSize((N2 + blockSize.x - 1) / blockSize.x, (N1 + blockSize.y - 1) / blockSize.y, 97 | (Co + blockSize.z - 1) / blockSize.z); 98 | AT_DISPATCH_FLOATING_TYPES(x.type(), "Conv2d_LB", ([&] { 99 | conv2d_kernel_lb << < gridSize, blockSize >> > ( 100 | x.data(), y.data(), 101 | gx.data(), gy.data(), gz.data(), 102 | N1, N2, Ci, Co, B, K); 103 | })); 104 | } 105 | -------------------------------------------------------------------------------- /network/backbone_repo/ATSA/model_depth.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import logging 6 | import torch.nn as nn 7 | 8 | BN_MOMENTUM = 0.1 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | """3x3 convolution with padding""" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 15 | 16 | class BasicBlock(nn.Module): 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None): 19 | super(BasicBlock, self).__init__() 20 | self.conv1 = conv3x3(inplanes, planes, stride) 21 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 22 | self.relu = nn.ReLU(inplace=True) 23 | self.conv2 = conv3x3(planes, planes) 24 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 25 | self.downsample = downsample 26 | self.stride = stride 27 | 28 | def forward(self, x): 29 | residual = x 30 | 31 | out = self.conv1(x) 32 | out = self.bn1(out) 33 | out = self.relu(out) 34 | 35 | out = self.conv2(out) 36 | out = self.bn2(out) 37 | 38 | out += residual 39 | out = self.relu(out) 40 | 41 | return out 42 | 43 | class DepthNet(nn.Module): 44 | 45 | def __init__(self): 46 | super(DepthNet, self).__init__() 47 | # conv1 48 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) 49 | self.bn1_1 = nn.BatchNorm2d(64, 
eps=1e-05, momentum=0.1, affine=True) 50 | self.relu1_1 = nn.ReLU(inplace=True) 51 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 52 | self.bn1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 53 | self.relu1_2 = nn.ReLU(inplace=True) 54 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 2 layers 55 | 56 | # conv2 57 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 58 | self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 59 | self.relu2_1 = nn.ReLU(inplace=True) 60 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 61 | self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 62 | self.relu2_2 = nn.ReLU(inplace=True) 63 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 2 layers 64 | num_stages = 3 65 | blocks = BasicBlock 66 | num_blocks = [4, 4, 4] 67 | num_channels = [32, 32, 128] 68 | self.stage = self._make_stages(num_stages, blocks, num_blocks, num_channels) 69 | self.transition1 = nn.Sequential( 70 | nn.Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), 71 | nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 72 | nn.ReLU(inplace=True) 73 | ) 74 | self.transition2 = nn.Sequential( 75 | nn.Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), 76 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 77 | nn.ReLU(inplace=True) 78 | ) 79 | 80 | def _make_one_stage(self, stage_index, block, num_blocks, num_channels): 81 | layers = [] 82 | for i in range(0, num_blocks[stage_index]): 83 | layers.append( 84 | block( 85 | num_channels[stage_index], 86 | num_channels[stage_index] 87 | ) 88 | ) 89 | return nn.Sequential(*layers) 90 | 91 | def _make_stages(self, num_stages, block, num_blocks, num_channels): 92 | branches = [] 93 | 94 | for i in range(num_stages): 95 | branches.append( 96 | self._make_one_stage(i, block, num_blocks, num_channels) 97 | ) 98 | return nn.ModuleList(branches) 99 | 100 | def forward(self, d): 101 | #depth 分支 102 | d = self.relu1_1(self.bn1_1(self.conv1_1(d))) 103 | d = self.relu1_2(self.bn1_2(self.conv1_2(d))) 104 | d = self.pool1(d) # (128x128)*64 105 | 106 | d = self.relu2_1(self.bn2_1(self.conv2_1(d))) 107 | d = self.relu2_2(self.bn2_2(self.conv2_2(d))) 108 | d1 = self.pool2(d) # (64x64)*128 109 | dt2 = self.transition1(d1) 110 | d2 = self.stage[0](dt2) 111 | d3 = self.stage[1](d2) 112 | dt4 = self.transition2(d3) 113 | d4 = self.stage[2](dt4) 114 | return d2, d3, d4 115 | 116 | def init_weights(self): 117 | logger.info('=> init weights from normal distribution') 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 121 | nn.init.normal_(m.weight, std=0.001) 122 | for name, _ in m.named_parameters(): 123 | if name in ['bias']: 124 | nn.init.constant_(m.bias, 0) 125 | elif isinstance(m, nn.BatchNorm2d): 126 | nn.init.constant_(m.weight, 1) 127 | nn.init.constant_(m.bias, 0) 128 | elif isinstance(m, nn.ConvTranspose2d): 129 | nn.init.normal_(m.weight, std=0.001) 130 | for name, _ in m.named_parameters(): 131 | if name in ['bias']: 132 | nn.init.constant_(m.bias, 0) -------------------------------------------------------------------------------- /tools/shape_prior_utils.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from tools.rot_utils import get_vertical_rot_vec 3 | import torch 4 | 5 | def get_nocs_from_deform(prior, 
deform_field, assign_mat): 6 | inst_shape = prior + deform_field 7 | assign_mat = torch.softmax(assign_mat, dim=2) 8 | nocs_coords = torch.bmm(assign_mat, inst_shape) 9 | return nocs_coords 10 | 11 | 12 | def get_face_dis_from_nocs(nocs_coords, size): 13 | # y+, x+, z+, x-, z-, y- 14 | bs, p_num, _ = nocs_coords.shape 15 | face_dis = torch.zeros((bs, p_num, 6), dtype=nocs_coords.dtype).to(nocs_coords.device) 16 | for i in range(bs): 17 | coord_now = nocs_coords[i] 18 | face_dis_now = face_dis[i] 19 | size_now = size[i] 20 | s_x, s_y, s_z = size_now 21 | diag_len = torch.norm(size_now) 22 | s_x_norm, s_y_norm, s_z_norm = s_x / diag_len, s_y / diag_len, s_z / diag_len 23 | face_dis_now[:, 0] = (s_y_norm / 2 - coord_now[:, 1]) * diag_len 24 | face_dis_now[:, 1] = (s_x_norm / 2 - coord_now[:, 0]) * diag_len 25 | face_dis_now[:, 2] = (s_z_norm / 2 - coord_now[:, 2]) * diag_len 26 | face_dis_now[:, 3] = (s_x_norm / 2 + coord_now[:, 0]) * diag_len 27 | face_dis_now[:, 4] = (s_z_norm / 2 + coord_now[:, 2]) * diag_len 28 | face_dis_now[:, 5] = (s_y_norm / 2 + coord_now[:, 1]) * diag_len 29 | return face_dis 30 | 31 | def get_face_shift_from_dis(face_dis, rot_y, rot_x, f_y, f_x, use_rectify_normal=False): 32 | # y+, x+, z+, x-, z-, y- 33 | bs, p_num, _ = face_dis.shape 34 | face_shift = torch.zeros((bs, p_num, 18), dtype=face_dis.dtype).to(face_dis.device) 35 | face_dis = face_dis.unsqueeze(-1) 36 | for i in range(bs): 37 | dis_now = face_dis[i] 38 | face_shift_now = face_shift[i] 39 | if use_rectify_normal: 40 | rot_y_now, rot_x_now = get_vertical_rot_vec(f_y[i], f_x[i], rot_y[i, ...], rot_x[i, ...]) 41 | rot_z_now = torch.cross(rot_x_now, rot_y_now) 42 | else: 43 | rot_y_now = rot_y[i] 44 | rot_x_now = rot_x[i] 45 | rot_z_now = torch.cross(rot_x_now, rot_y_now) 46 | face_shift_now[:, 0:3] = dis_now[:, 0] * rot_y_now 47 | face_shift_now[:, 3:6] = dis_now[:, 1] * rot_x_now 48 | face_shift_now[:, 6:9] = dis_now[:, 2] * rot_z_now 49 | face_shift_now[:, 9:12] = - dis_now[:, 3] * rot_x_now 50 | face_shift_now[:, 12:15] = - dis_now[:, 4] * rot_z_now 51 | face_shift_now[:, 15:18] = - dis_now[:, 5] * rot_y_now 52 | return face_shift 53 | 54 | def get_point_depth_error(nocs_coords, PC, R, t, gt_s, model=None, nocs_scale=None, save_path=None): 55 | bs, p_num, _ = nocs_coords.shape 56 | 57 | diag_len = torch.norm(gt_s, dim=1) 58 | diag_len_ = diag_len.view(bs, 1) 59 | diag_len_ = diag_len_.repeat(1, p_num).view(bs, p_num, 1) 60 | coords = torch.mul(nocs_coords, diag_len_) 61 | coords = torch.bmm(R, coords.permute(0, 2, 1)) + t.view(bs, 3, 1) 62 | coords = coords.permute(0,2,1) 63 | distance = torch.norm(coords - PC, dim=2, p=1) 64 | 65 | coords_s = torch.mul(nocs_coords, diag_len_) 66 | pc_proj = torch.bmm(R.permute(0, 2, 1), (PC.permute(0, 2, 1) - t.view(bs, 3, 1))).permute(0, 2, 1) 67 | nocs_distance = torch.norm(coords_s - pc_proj, dim=2, p=1) 68 | 69 | # assert save_path is None 70 | if save_path is not None: 71 | if nocs_scale is None: 72 | nocs_scale = diag_len 73 | for i in range(bs): 74 | pc_now = PC[i] 75 | coords_now = coords[i] 76 | pc_proj_now = pc_proj[i] 77 | coords_s_now = coords_s[i] 78 | coords_ori_now = nocs_coords[i] 79 | pc_np = pc_now.detach().cpu().numpy() 80 | coord_np = coords_now.detach().cpu().numpy() 81 | pc_proj_np = pc_proj_now.detach().cpu().numpy() 82 | coords_s_np = coords_s_now.detach().cpu().numpy() 83 | coords_ori_np = coords_ori_now.detach().cpu().numpy() 84 | R_np = R[i].detach().cpu().numpy() 85 | t_np = t[i].detach().cpu().numpy() 86 | s_np = 
gt_s[i].detach().cpu().numpy() 87 | model_np = model[i].detach().cpu().numpy() * nocs_scale[i].detach().cpu().numpy() 88 | 89 | 90 | import numpy as np 91 | np.savetxt(save_path + f'_{i}_pc.txt', pc_np) 92 | np.savetxt(save_path + f'_{i}_coord2pc.txt', coord_np) 93 | np.savetxt(save_path + f'_{i}_pc2nocs.txt', pc_proj_np) 94 | np.savetxt(save_path + f'_{i}_coord.txt', coords_s_np) 95 | np.savetxt(save_path + f'_{i}_coord_ori.txt', coords_ori_np) 96 | np.savetxt(save_path + f'_{i}_model.txt', model_np) 97 | np.savetxt(save_path + f'_{i}_r.txt', R_np) 98 | np.savetxt(save_path + f'_{i}_t.txt', t_np) 99 | np.savetxt(save_path + f'_{i}_s.txt', s_np) 100 | return distance, nocs_distance 101 | 102 | def get_nocs_model(model_point): 103 | model_point_nocs = torch.zeros_like(model_point).to(model_point.device) 104 | bs = model_point.shape[0] 105 | for i in range(bs): 106 | model_now = model_point[i] 107 | lx = 2 * torch.max(torch.max(model_now[:, 0]), -torch.min(model_now[:, 0])) 108 | ly = torch.max(model_now[:, 1]) - torch.min(model_now[:, 1]) 109 | lz = torch.max(model_now[:, 2]) - torch.min(model_now[:, 2]) 110 | diagonal_len = torch.norm(torch.tensor([lx, ly, lz])) 111 | print('model diagonal in final', diagonal_len) 112 | model_point_nocs[i, :] = model_now / diagonal_len 113 | return model_point_nocs -------------------------------------------------------------------------------- /tools/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | def get_2d_coord_np(width, height, low=0, high=1, fmt="CHW"): 5 | """ 6 | Args: 7 | width: 8 | height: 9 | Returns: 10 | xy: (2, height, width) 11 | """ 12 | # coords values are in [low, high] [0,1] or [-1,1] 13 | x = np.linspace(0, width-1, width, dtype=np.float32) 14 | y = np.linspace(0, height-1, height, dtype=np.float32) 15 | xy = np.asarray(np.meshgrid(x, y)) 16 | if fmt == "HWC": 17 | xy = xy.transpose(1, 2, 0) 18 | elif fmt == "CHW": 19 | pass 20 | else: 21 | raise ValueError(f"Unknown format: {fmt}") 22 | return xy 23 | 24 | def aug_bbox_DZI(FLAGS, bbox_xyxy, im_H, im_W): 25 | """Used for DZI, the augmented box is a square (maybe enlarged) 26 | Args: 27 | bbox_xyxy (np.ndarray): 28 | Returns: 29 | center, scale 30 | """ 31 | x1, y1, x2, y2 = bbox_xyxy.copy() 32 | cx = 0.5 * (x1 + x2) 33 | cy = 0.5 * (y1 + y2) 34 | bh = y2 - y1 35 | bw = x2 - x1 36 | if FLAGS.DZI_TYPE.lower() == "uniform": 37 | scale_ratio = 1 + FLAGS.DZI_SCALE_RATIO * (2 * np.random.random_sample() - 1) # [1-0.25, 1+0.25] 38 | shift_ratio = FLAGS.DZI_SHIFT_RATIO * (2 * np.random.random_sample(2) - 1) # [-0.25, 0.25] 39 | bbox_center = np.array([cx + bw * shift_ratio[0], cy + bh * shift_ratio[1]]) # (h/2, w/2) 40 | scale = max(y2 - y1, x2 - x1) * scale_ratio * FLAGS.DZI_PAD_SCALE 41 | elif FLAGS.DZI_TYPE.lower() == "roi10d": 42 | # shift (x1,y1), (x2,y2) by 15% in each direction 43 | _a = -0.15 44 | _b = 0.15 45 | x1 += bw * (np.random.rand() * (_b - _a) + _a) 46 | x2 += bw * (np.random.rand() * (_b - _a) + _a) 47 | y1 += bh * (np.random.rand() * (_b - _a) + _a) 48 | y2 += bh * (np.random.rand() * (_b - _a) + _a) 49 | x1 = min(max(x1, 0), im_W) 50 | x2 = min(max(x1, 0), im_W) 51 | y1 = min(max(y1, 0), im_H) 52 | y2 = min(max(y2, 0), im_H) 53 | bbox_center = np.array([0.5 * (x1 + x2), 0.5 * (y1 + y2)]) 54 | scale = max(y2 - y1, x2 - x1) * FLAGS.DZI_PAD_SCALE 55 | elif FLAGS.DZI_TYPE.lower() == "truncnorm": 56 | raise NotImplementedError("DZI truncnorm not implemented yet.") 57 | else: 58 | 
bbox_center = np.array([cx, cy]) # (w/2, h/2) 59 | scale = max(y2 - y1, x2 - x1) 60 | scale = min(scale, max(im_H, im_W)) * 1.0 61 | return bbox_center, scale 62 | 63 | def aug_bbox_eval(bbox_xyxy, im_H, im_W): 64 | """Used for DZI, the augmented box is a square (maybe enlarged) 65 | Args: 66 | bbox_xyxy (np.ndarray): 67 | Returns: 68 | center, scale 69 | """ 70 | x1, y1, x2, y2 = bbox_xyxy.copy() 71 | cx = 0.5 * (x1 + x2) 72 | cy = 0.5 * (y1 + y2) 73 | bh = y2 - y1 74 | bw = x2 - x1 75 | bbox_center = np.array([cx, cy]) # (w/2, h/2) 76 | scale = max(y2 - y1, x2 - x1) 77 | scale = min(scale, max(im_H, im_W)) * 1.0 78 | return bbox_center, scale 79 | 80 | def crop_resize_by_warp_affine(img, center, scale, output_size, rot=0, interpolation=cv2.INTER_LINEAR): 81 | """ 82 | output_size: int or (w, h) 83 | NOTE: if img is (h,w,1), the output will be (h,w) 84 | """ 85 | if isinstance(scale, (int, float)): 86 | scale = (scale, scale) 87 | if isinstance(output_size, int): 88 | output_size = (output_size, output_size) 89 | trans = get_affine_transform(center, scale, rot, output_size) 90 | 91 | dst_img = cv2.warpAffine(img, trans, (int(output_size[0]), int(output_size[1])), flags=interpolation) 92 | 93 | return dst_img 94 | 95 | def get_affine_transform(center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=False): 96 | """ 97 | adapted from CenterNet: https://github.com/xingyizhou/CenterNet/blob/master/src/lib/utils/image.py 98 | center: ndarray: (cx, cy) 99 | scale: (w, h) 100 | rot: angle in deg 101 | output_size: int or (w, h) 102 | """ 103 | if isinstance(center, (tuple, list)): 104 | center = np.array(center, dtype=np.float32) 105 | 106 | if isinstance(scale, (int, float)): 107 | scale = np.array([scale, scale], dtype=np.float32) 108 | 109 | if isinstance(output_size, (int, float)): 110 | output_size = (output_size, output_size) 111 | 112 | scale_tmp = scale 113 | src_w = scale_tmp[0] 114 | dst_w = output_size[0] 115 | dst_h = output_size[1] 116 | 117 | rot_rad = np.pi * rot / 180 118 | src_dir = get_dir([0, src_w * -0.5], rot_rad) 119 | dst_dir = np.array([0, dst_w * -0.5], np.float32) 120 | 121 | src = np.zeros((3, 2), dtype=np.float32) 122 | dst = np.zeros((3, 2), dtype=np.float32) 123 | src[0, :] = center + scale_tmp * shift 124 | src[1, :] = center + src_dir + scale_tmp * shift 125 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 126 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir 127 | 128 | src[2:, :] = get_3rd_point(src[0, :], src[1, :]) 129 | dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) 130 | 131 | if inv: 132 | trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) 133 | else: 134 | trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) 135 | 136 | return trans 137 | 138 | def get_dir(src_point, rot_rad): 139 | sn, cs = np.sin(rot_rad), np.cos(rot_rad) 140 | 141 | src_result = [0, 0] 142 | src_result[0] = src_point[0] * cs - src_point[1] * sn 143 | src_result[1] = src_point[0] * sn + src_point[1] * cs 144 | 145 | return src_result 146 | 147 | def get_3rd_point(a, b): 148 | direct = a - b 149 | return b + np.array([-direct[1], direct[0]], dtype=np.float32) 150 | -------------------------------------------------------------------------------- /network/backbone_repo/pspnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 
8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False) 9 | 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 14 | super(BasicBlock, self).__init__() 15 | self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) 16 | self.relu = nn.ReLU(inplace=True) 17 | self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) 18 | self.downsample = downsample 19 | self.stride = stride 20 | 21 | def forward(self, x): 22 | residual = x 23 | out = self.conv1(x) 24 | out = self.relu(out) 25 | out = self.conv2(out) 26 | if self.downsample is not None: 27 | residual = self.downsample(x) 28 | out += residual 29 | out = self.relu(out) 30 | return out 31 | 32 | 33 | class ResNet(nn.Module): 34 | def __init__(self, block, layers=(3, 4, 23, 3)): 35 | self.inplanes = 64 36 | super(ResNet, self).__init__() 37 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 40 | self.layer1 = self._make_layer(block, 64, layers[0]) 41 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 42 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 43 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) 44 | 45 | for m in self.modules(): 46 | if isinstance(m, nn.Conv2d): 47 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 48 | m.weight.data.normal_(0, math.sqrt(2./n)) 49 | elif isinstance(m, nn.BatchNorm2d): 50 | m.weight.data.fill_(1) 51 | m.bias.data.zero_() 52 | 53 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 54 | downsample = None 55 | if stride != 1 or self.inplanes != planes*block.expansion: 56 | downsample = nn.Sequential( 57 | nn.Conv2d(self.inplanes, planes*block.expansion, kernel_size=1, stride=stride, bias=False) 58 | ) 59 | layers = [block(self.inplanes, planes, stride, downsample)] 60 | self.inplanes = planes * block.expansion 61 | for i in range(1, blocks): 62 | layers.append(block(self.inplanes, planes, dilation=dilation)) 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | x = self.conv1(x) 67 | x = self.relu(x) 68 | x = self.maxpool(x) 69 | x = self.layer1(x) 70 | x = self.layer2(x) 71 | x = self.layer3(x) 72 | x = self.layer4(x) 73 | return x 74 | 75 | 76 | class PSPModule(nn.Module): 77 | def __init__(self, feat_dim, bins=(1, 2, 3, 6)): 78 | super(PSPModule, self).__init__() 79 | self.reduction_dim = feat_dim // len(bins) 80 | self.stages = [] 81 | self.stages = nn.ModuleList([self._make_stage(feat_dim, size) for size in bins]) 82 | 83 | def _make_stage(self, feat_dim, size): 84 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) 85 | conv = nn.Conv2d(feat_dim, self.reduction_dim, kernel_size=1, bias=False) 86 | relu = nn.ReLU(inplace=True) 87 | return nn.Sequential(prior, conv, relu) 88 | 89 | def forward(self, feats): 90 | h, w = feats.size(2), feats.size(3) 91 | priors = [feats] 92 | for stage in self.stages: 93 | priors.append(F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True)) 94 | return torch.cat(priors, 1) 95 | 96 | 97 | class PSPUpsample(nn.Module): 98 | def __init__(self, in_channels, out_channels): 99 | super(PSPUpsample, self).__init__() 100 | self.conv = nn.Sequential( 101 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 102 | nn.PReLU() 
103 | ) 104 | 105 | def forward(self, x): 106 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) 107 | return self.conv(x) 108 | 109 | 110 | class PSPNet(nn.Module): 111 | def __init__(self, bins=(1, 2, 3, 6), backend='resnet18', output_mask=True, output_channel=32): 112 | super(PSPNet, self).__init__() 113 | if backend == 'resnet18': 114 | self.feats = ResNet(BasicBlock, [2, 2, 2, 2]) 115 | feat_dim = 512 116 | else: 117 | raise NotImplementedError 118 | self.psp = PSPModule(feat_dim, bins) 119 | self.drop = nn.Dropout2d(p=0.15) 120 | self.up_1 = PSPUpsample(1024, 256) 121 | self.up_2 = PSPUpsample(256, 64) 122 | self.up_3 = PSPUpsample(64, 64) 123 | self.final = nn.Conv2d(64, output_channel, kernel_size=1) 124 | self.output_mask = output_mask 125 | if output_mask: 126 | self.mask_final = nn.Sequential(nn.Conv2d(64, 2, kernel_size=1)) 127 | self.softmax = nn.Softmax(dim=1) 128 | self.log_softmax = nn.LogSoftmax(dim=1) 129 | 130 | def forward(self, x): 131 | f = self.feats(x) 132 | p = self.psp(f) 133 | p = self.up_1(p) 134 | p = self.drop(p) 135 | p = self.up_2(p) 136 | p = self.drop(p) 137 | p = self.up_3(p) 138 | if self.output_mask: 139 | mask = self.mask_final(p) 140 | # mask = self.log_softmax(mask) 141 | else: 142 | mask = None 143 | feature = self.final(p) 144 | return feature, mask 145 | 146 | -------------------------------------------------------------------------------- /tools/pyTorchChamferDistance/chamfer_distance.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | __global__ 7 | void ChamferDistanceKernel( 8 | int b, 9 | int n, 10 | const float* xyz, 11 | int m, 12 | const float* xyz2, 13 | float* result, 14 | int* result_i) 15 | { 16 | const int batch=512; 17 | __shared__ float buf[batch*3]; 18 | for (int i=blockIdx.x;ibest){ 130 | result[(i*n+j)]=best; 131 | result_i[(i*n+j)]=best_i; 132 | } 133 | } 134 | __syncthreads(); 135 | } 136 | } 137 | } 138 | 139 | void ChamferDistanceKernelLauncher( 140 | const int b, const int n, 141 | const float* xyz, 142 | const int m, 143 | const float* xyz2, 144 | float* result, 145 | int* result_i, 146 | float* result2, 147 | int* result2_i) 148 | { 149 | ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i); 150 | ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i); 151 | 152 | cudaError_t err = cudaGetLastError(); 153 | if (err != cudaSuccess) 154 | printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err)); 155 | } 156 | 157 | 158 | __global__ 159 | void ChamferDistanceGradKernel( 160 | int b, int n, 161 | const float* xyz1, 162 | int m, 163 | const float* xyz2, 164 | const float* grad_dist1, 165 | const int* idx1, 166 | float* grad_xyz1, 167 | float* grad_xyz2) 168 | { 169 | for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2); 204 | ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1); 205 | 206 | cudaError_t err = cudaGetLastError(); 207 | if (err != cudaSuccess) 208 | printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err)); 209 | } 210 | -------------------------------------------------------------------------------- /network/fs_net_repo/FaceRecon.py: -------------------------------------------------------------------------------- 1 | # follow FS-Net 2 | import torch.nn as nn 3 | import network.fs_net_repo.gcn3d as gcn3d 4 | import torch 5 | import torch.nn.functional as F 6 | from absl 
import app 7 | import absl.flags as flags 8 | from network.fs_net_repo.pcl_encoder import PCL_Encoder 9 | FLAGS = flags.FLAGS 10 | from mmcv.cnn import normal_init, constant_init 11 | from torch.nn.modules.batchnorm import _BatchNorm 12 | 13 | # global feature num : the channels of feature from rgb and depth 14 | # grid_num : the volume resolution 15 | 16 | class FaceRecon(nn.Module): 17 | def __init__(self): 18 | super(FaceRecon, self).__init__() 19 | self.neighbor_num = FLAGS.gcn_n_num 20 | self.support_num = FLAGS.gcn_sup_num 21 | # 3D convolution for point cloud 22 | self.recon_num = 3 23 | self.face_recon_num = FLAGS.face_recon_c 24 | 25 | dim_fuse = sum([128, 128, 256, 256, 512, FLAGS.obj_c, FLAGS.feat_seman]) 26 | # 16: total 6 categories, 256 is global feature 27 | self.conv1d_block = nn.Sequential( 28 | nn.Conv1d(dim_fuse, 512, 1), 29 | nn.BatchNorm1d(512), 30 | nn.ReLU(inplace=True), 31 | nn.Conv1d(512, 512, 1), 32 | nn.BatchNorm1d(512), 33 | nn.ReLU(inplace=True), 34 | nn.Conv1d(512, 256, 1), 35 | nn.BatchNorm1d(256), 36 | nn.ReLU(inplace=True), 37 | ) 38 | 39 | self.recon_head = nn.Sequential( 40 | nn.Conv1d(256, 128, 1), 41 | nn.BatchNorm1d(128), 42 | nn.ReLU(inplace=True), 43 | nn.Conv1d(128, self.recon_num, 1), 44 | ) 45 | 46 | self.face_decoder = nn.Sequential( 47 | nn.Conv1d(FLAGS.feat_face + 3, 512, 1), 48 | nn.BatchNorm1d(512), 49 | nn.ReLU(inplace=True), 50 | nn.Conv1d(512, 256, 1), 51 | nn.BatchNorm1d(256), 52 | nn.ReLU(inplace=True), 53 | ) 54 | 55 | self.vote_head_1= VoteHead() 56 | self.vote_head_2= VoteHead() 57 | self.vote_head_3= VoteHead() 58 | self.vote_head_4= VoteHead() 59 | self.vote_head_5= VoteHead() 60 | self.vote_head_6= VoteHead() 61 | self.vote_head_list = [self.vote_head_1, self.vote_head_2, self.vote_head_3, 62 | self.vote_head_4, self.vote_head_5, self.vote_head_6] 63 | self.mask_head = nn.Sequential( 64 | nn.Conv1d(256, 128, 1), 65 | nn.BatchNorm1d(128), 66 | nn.ReLU(inplace=True), 67 | nn.Conv1d(128, 1, 1), 68 | nn.Sigmoid() 69 | ) 70 | self._init_weights() 71 | 72 | def forward(self, 73 | feat: "tensor (bs, vetice_num, 256)", 74 | feat_global: "tensor (bs, 1, 256)", 75 | vertices: "tensor (bs, vetice_num, 3)", 76 | face_shift_prior: "tensor (bs, vetice_num, 18)", 77 | 78 | ): 79 | """ 80 | Return: (bs, vertice_num, class_num) 81 | """ 82 | # concate feature 83 | bs, vertice_num, _ = feat.size() 84 | feat_face_re = feat_global.view(bs, 1, feat_global.shape[1]).repeat(1, feat.shape[1], 1).permute(0, 2, 1) 85 | conv1d_input = feat.permute(0, 2, 1) # (bs, fuse_ch, vertice_num) 86 | conv1d_out = self.conv1d_block(conv1d_input) 87 | 88 | recon = self.recon_head(conv1d_out) 89 | # average pooling for face prediction 90 | feat_face_in = torch.cat([feat_face_re, conv1d_out, vertices.permute(0, 2, 1)], dim=1) 91 | feat = self.face_decoder(feat_face_in) 92 | mask = self.mask_head(feat) 93 | face_shift_delta = torch.zeros((bs, vertice_num, 18)).to(feat.device) 94 | face_log_var = torch.zeros((bs, vertice_num, 6)).to(feat.device) 95 | for i, vote_head in enumerate(self.vote_head_list): 96 | face_vote_result = vote_head(feat, face_shift_prior[:,:,3*i:3*i+3]) 97 | face_shift_delta[:,:,3*i:3*i+3] = face_vote_result[:,:,:3] 98 | face_log_var[:,:,i] = face_vote_result[:,:,3] 99 | 100 | return recon.permute(0, 2, 1), face_shift_delta, face_log_var, mask.squeeze() 101 | 102 | def _init_weights(self): 103 | for m in self.modules(): 104 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d)): 105 | normal_init(m, std=0.001) 106 | elif isinstance(m, (_BatchNorm, 
nn.GroupNorm)): 107 | constant_init(m, 1) 108 | elif isinstance(m, nn.Linear): 109 | normal_init(m, std=0.001) 110 | 111 | 112 | class VoteHead(nn.Module): 113 | def __init__(self): 114 | super(VoteHead, self).__init__() 115 | self.layer = nn.Sequential( 116 | nn.Conv1d(256 + 3, 128, 1), 117 | nn.BatchNorm1d(128), 118 | nn.ReLU(inplace=True), 119 | nn.Conv1d(128, 3 + 1, 1), 120 | ) 121 | 122 | 123 | def forward(self, 124 | feat: "tensor (bs, 256, vertice_num)", 125 | face_shift_prior: "tensor (bs, vertice_num, 3)" 126 | ): 127 | """ 128 | Return: (bs, vertice_num, class_num) 129 | """ 130 | feat_face_in = torch.cat([feat, face_shift_prior.permute(0, 2, 1)], dim=1) 131 | face = self.layer(feat_face_in) 132 | return face.permute(0, 2, 1) 133 | 134 | def main(argv): 135 | classifier_seg3D = FaceRecon() 136 | 137 | points = torch.rand(2, 1000, 3) 138 | import numpy as np 139 | obj_idh = torch.ones((2, 1)) 140 | obj_idh[1, 0] = 5 141 | ''' 142 | if obj_idh.shape[0] == 1: 143 | obj_idh = obj_idh.view(-1, 1).repeat(points.shape[0], 1) 144 | else: 145 | obj_idh = obj_idh.view(-1, 1) 146 | 147 | one_hot = torch.zeros(points.shape[0], 6).scatter_(1, obj_idh.cpu().long(), 1) 148 | ''' 149 | recon, face, feat = classifier_seg3D(points, obj_idh) 150 | face = face.squeeze(0) 151 | t = 1 152 | 153 | 154 | 155 | if __name__ == "__main__": 156 | print(1) 157 | from config.config import * 158 | app.run(main) 159 | 160 | 161 | -------------------------------------------------------------------------------- /tools/visualize/plot_map_sgpa.py: -------------------------------------------------------------------------------- 1 | import matplotlib as plt 2 | import os 3 | from evaluation.eval_utils_cass import * 4 | import mmcv 5 | from shutil import copyfile 6 | 7 | def plot_mAP(degree_thres_list, shift_thres_list, iou_thres_list, iou_3d_aps, pose_aps, output_path, suffix=''): 8 | # draw iou 3d AP vs. iou thresholds 9 | synset_names = ['BG'] + ['bottle', 'bowl', 'camera', 'can', 'laptop', 'mug'] 10 | num_classes = len(synset_names) 11 | iou_3d_aps = iou_3d_aps.copy() * 100 12 | pose_aps = pose_aps.copy() * 100 13 | 14 | fig_iou = plt.figure(figsize=(15, 5)) 15 | ax_iou = plt.subplot(131) 16 | plt.title('3D IOU') 17 | plt.ylabel(f'Average Precision ({suffix})', fontsize=12) 18 | plt.ylim((0, 100)) 19 | plt.xlim((0, 100)) 20 | plt.xlabel('Percent') 21 | iou_thres_list = np.array(iou_thres_list) * 100 22 | for cls_id in range(1, num_classes): 23 | class_name = synset_names[cls_id] 24 | ax_iou.plot(iou_thres_list, iou_3d_aps[cls_id, :], label=class_name) 25 | ax_iou.plot(iou_thres_list, iou_3d_aps[-1, :], label='mean') 26 | ax_iou.xaxis.set_major_locator(plt.MultipleLocator(25)) 27 | ax_iou.grid() 28 | # draw pose AP vs. 
thresholds 29 | ax_rot = plt.subplot(132) 30 | plt.ylim((0, 100)) 31 | plt.xlim((0, 45)) 32 | plt.title('Rotation') 33 | plt.xlabel('Degree') 34 | for cls_id in range(1, num_classes): 35 | class_name = synset_names[cls_id] 36 | ax_rot.plot( 37 | degree_thres_list[:-1], pose_aps[cls_id, :-1, -1], label=class_name) 38 | 39 | ax_rot.plot(degree_thres_list[:-1], pose_aps[-1, :-1, -1], label='mean') 40 | ax_rot.xaxis.set_major_locator(plt.MultipleLocator(15)) 41 | ax_rot.grid() 42 | 43 | ax_trans = plt.subplot(133) 44 | plt.ylim((0, 100)) 45 | plt.xlim((0, 10)) 46 | plt.title('Translation') 47 | plt.xlabel('Centimeter') 48 | for cls_id in range(1, num_classes): 49 | class_name = synset_names[cls_id] 50 | # print(class_name) 51 | ax_trans.plot(shift_thres_list[:-1], 52 | pose_aps[cls_id, -1, :-1], label=class_name) 53 | 54 | ax_trans.plot(shift_thres_list[:-1], pose_aps[-1, -1, :-1], label='mean') 55 | ax_trans.legend(loc='lower right') 56 | ax_trans.xaxis.set_major_locator(plt.MultipleLocator(5)) 57 | ax_trans.grid() 58 | fig_iou.savefig(output_path) 59 | plt.close(fig_iou) 60 | 61 | if __name__ == '__main__': 62 | dualposenet_path = '/data/zrd/datasets/NOCS/results/sgpa_results/REAL275_results.pkl' 63 | 64 | output_path = '/data/zrd/RBP_result/0221_sgpa' 65 | if not os.path.exists(output_path): 66 | os.makedirs(output_path) 67 | degree_thres_list = list(range(0, 61, 1)) 68 | shift_thres_list = [i / 2 for i in range(21)] 69 | iou_thres_list = [i / 100 for i in range(101)] 70 | synset_names = ['BG'] + ['bottle', 'bowl', 'camera', 'can', 'laptop', 'mug'] 71 | num_classes = 6 72 | 73 | # evaluate dualposenet results 74 | output_dpn_eval_file_path = '/data/zrd/datasets/NOCS/results/sgpa_results/REAL275_eval_results.pkl' 75 | if os.path.exists(output_dpn_eval_file_path): 76 | iou_aps, pose_aps = mmcv.load(output_dpn_eval_file_path) 77 | else: 78 | dualposenet_pred_results = mmcv.load(dualposenet_path) 79 | iou_aps, pose_aps = compute_degree_cm_mAP(dualposenet_pred_results, synset_names, output_path, degree_thres_list, 80 | shift_thres_list, 81 | iou_thres_list, iou_pose_thres=0.1, use_matches_for_pose=True) 82 | mmcv.dump([iou_aps, pose_aps], output_dpn_eval_file_path) 83 | degree_thres_list += [360] 84 | shift_thres_list += [100] 85 | 86 | iou_25_idx = iou_thres_list.index(0.25) 87 | iou_50_idx = iou_thres_list.index(0.5) 88 | iou_75_idx = iou_thres_list.index(0.75) 89 | degree_05_idx = degree_thres_list.index(5) 90 | degree_10_idx = degree_thres_list.index(10) 91 | shift_02_idx = shift_thres_list.index(2) 92 | shift_05_idx = shift_thres_list.index(5) 93 | shift_10_idx = shift_thres_list.index(10) 94 | 95 | messages = [] 96 | cls_idx = -1 97 | 98 | messages.append('average mAP:') 99 | messages.append('3D IoU at 25: {:.1f}'.format(iou_aps[cls_idx, iou_25_idx] * 100)) 100 | messages.append('3D IoU at 50: {:.1f}'.format(iou_aps[cls_idx, iou_50_idx] * 100)) 101 | messages.append('3D IoU at 75: {:.1f}'.format(iou_aps[cls_idx, iou_75_idx] * 100)) 102 | messages.append('5 degree, 2cm: {:.1f}'.format(pose_aps[cls_idx, degree_05_idx, shift_02_idx] * 100)) 103 | messages.append('5 degree, 5cm: {:.1f}'.format(pose_aps[cls_idx, degree_05_idx, shift_05_idx] * 100)) 104 | messages.append('10 degree, 2cm: {:.1f}'.format(pose_aps[cls_idx, degree_10_idx, shift_02_idx] * 100)) 105 | messages.append('10 degree, 5cm: {:.1f}'.format(pose_aps[cls_idx, degree_10_idx, shift_05_idx] * 100)) 106 | messages.append('10 degree, 10cm: {:.1f}'.format(pose_aps[cls_idx, degree_10_idx, shift_10_idx] * 100)) 107 | 108 | for cls_idx 
in range(1, len(synset_names)): 109 | messages.append('category {}'.format(synset_names[cls_idx])) 110 | messages.append('mAP:') 111 | messages.append('3D IoU at 25: {:.1f}'.format(iou_aps[cls_idx, iou_25_idx] * 100)) 112 | messages.append('3D IoU at 50: {:.1f}'.format(iou_aps[cls_idx, iou_50_idx] * 100)) 113 | messages.append('3D IoU at 75: {:.1f}'.format(iou_aps[cls_idx, iou_75_idx] * 100)) 114 | messages.append('5 degree, 2cm: {:.1f}'.format(pose_aps[cls_idx, degree_05_idx, shift_02_idx] * 100)) 115 | messages.append('5 degree, 5cm: {:.1f}'.format(pose_aps[cls_idx, degree_05_idx, shift_05_idx] * 100)) 116 | messages.append('10 degree, 2cm: {:.1f}'.format(pose_aps[cls_idx, degree_10_idx, shift_02_idx] * 100)) 117 | messages.append('10 degree, 5cm: {:.1f}'.format(pose_aps[cls_idx, degree_10_idx, shift_05_idx] * 100)) 118 | messages.append('10 degree, 10cm: {:.1f}'.format(pose_aps[cls_idx, degree_10_idx, shift_10_idx] * 100)) 119 | with open(os.path.join(output_path, 'eval_all_results.txt'), 'w') as file: 120 | for line in messages: 121 | print(line) 122 | file.write(line+'\n') 123 | plot_mAP(degree_thres_list, shift_thres_list, iou_thres_list, iou_aps, pose_aps, os.path.join(output_path, 'sgpa_mAP.png'), 124 | suffix='SGPA') 125 | -------------------------------------------------------------------------------- /losses/shape_prior_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .nn_distance.chamfer_loss import ChamferLoss 6 | 7 | 8 | class shape_prior_loss(nn.Module): 9 | """ Loss for training DeformNet. 10 | Use NOCS coords to supervise training. 11 | """ 12 | def __init__(self, corr_wt, cd_wt, entropy_wt, deform_wt, threshold, sym_wt): 13 | super(shape_prior_loss, self).__init__() 14 | self.threshold = threshold 15 | self.chamferloss = ChamferLoss() 16 | self.corr_wt = corr_wt 17 | self.cd_wt = cd_wt 18 | self.entropy_wt = entropy_wt 19 | self.deform_wt = deform_wt 20 | self.sym_wt = sym_wt 21 | self.symmetry_rotation_matrix_list = self.symmetry_rotation_matrix_y() 22 | self.symmetry_rotation_matrix_list_tensor = None 23 | 24 | def forward(self, assign_mat, deltas, prior, nocs, model, point_mask, sym): 25 | """ 26 | Args: 27 | assign_mat: bs x n_pts x nv 28 | deltas: bs x nv x 3 29 | prior: bs x nv x 3 30 | """ 31 | if self.symmetry_rotation_matrix_list_tensor is None: 32 | result = [] 33 | for rotation_matrix in self.symmetry_rotation_matrix_list: 34 | rotation_matrix = torch.from_numpy(rotation_matrix).float().to(nocs.device) 35 | result.append(rotation_matrix) 36 | self.symmetry_rotation_matrix_list_tensor = result 37 | loss_dict = {} 38 | 39 | inst_shape = prior + deltas 40 | # smooth L1 loss for correspondences 41 | soft_assign = F.softmax(assign_mat, dim=2) 42 | coords = torch.bmm(soft_assign, inst_shape) # bs x n_pts x 3 43 | corr_loss = self.cal_corr_loss(coords, nocs, point_mask, sym) 44 | corr_loss = self.corr_wt * corr_loss 45 | # entropy loss to encourage peaked distribution 46 | log_assign = F.log_softmax(assign_mat, dim=2) 47 | entropy_loss = torch.mean(-torch.sum(soft_assign * log_assign, 2)) 48 | entropy_loss = self.entropy_wt * entropy_loss 49 | # cd-loss for instance reconstruction 50 | cd_loss, _, _ = self.chamferloss(inst_shape, model) 51 | cd_loss = self.cd_wt * cd_loss 52 | # L2 regularizations on deformation 53 | deform_loss = torch.norm(deltas, p=2, dim=2).mean() 54 | deform_loss = self.deform_wt * deform_loss 55 
| loss_dict['corr_loss'] = corr_loss # predicted nocs coordinate loss 56 | loss_dict['entropy_loss'] = entropy_loss # entropy loss for assign matrix 57 | loss_dict['cd_loss'] = cd_loss # chamfer distance loss between ground truth shape and predicted full shape 58 | loss_dict['deform_loss'] = deform_loss # regularization loss for deformation field 59 | if self.sym_wt != 0: 60 | loss_dict['sym_loss'] = self.sym_wt * self.cal_sym_loss(inst_shape, sym) 61 | return loss_dict 62 | 63 | def cal_sym_loss(self, inst_shape, sym): 64 | # only calculate NOCS in y-axis for symmetric object 65 | bs = inst_shape.shape[0] 66 | sym_loss = 0 67 | if sym is not None: 68 | target_shape = inst_shape.clone() 69 | # symmetry aware 70 | for i in range(bs): 71 | if sym[i, 0] == 1 and torch.sum(sym[i, 1:]) > 0: # y axis reflection, can, bowl, bottle 72 | target_shape[i, :, 0] = -target_shape[i, :, 0] 73 | target_shape[i, :, 2] = -target_shape[i, :, 2] 74 | elif sym[i, 0] == 0 and sym[i, 1] == 1: # yx reflection, laptop, mug with handle 75 | target_shape[i, :, 2] = -target_shape[i, :, 2] 76 | sym_loss, _, _ = self.chamferloss(inst_shape, target_shape) 77 | return sym_loss / bs 78 | 79 | def symmetry_rotation_matrix_y(self, number=30): 80 | result = [] 81 | for i in range(number): 82 | theta = 2 * np.pi / number * i 83 | r = np.array([[np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]]) 84 | result.append(r) 85 | return result 86 | 87 | def cal_corr_loss(self, coords, nocs, point_mask, sym): 88 | # filter out invalid point 89 | point_mask = torch.stack([point_mask, point_mask, point_mask], dim=-1) 90 | coords = torch.where(point_mask, coords, torch.zeros_like(coords)) 91 | nocs = torch.where(point_mask, nocs, torch.zeros_like(nocs)) 92 | # only calculate NOCS in y-axis for symmetric object 93 | bs = nocs.shape[0] 94 | corr_loss = 0 95 | if sym is not None: 96 | # symmetry aware 97 | for i in range(bs): 98 | sym_now = sym[i, 0] 99 | coords_now = coords[i] 100 | nocs_now = nocs[i] 101 | if sym_now == 1: 102 | min_corr_loss_now = 1e5 103 | min_rotation_matrix = torch.eye(3).cuda() 104 | with torch.no_grad(): 105 | for rotation_matrix in self.symmetry_rotation_matrix_list_tensor: 106 | # this should be the inverse of rotation matrix, but it has no influence on result 107 | temp_corr_loss = self.cal_corr_loss_for_each_item(coords_now, torch.mm(nocs_now, rotation_matrix)) 108 | if temp_corr_loss < min_corr_loss_now: 109 | min_corr_loss_now = temp_corr_loss 110 | min_rotation_matrix = rotation_matrix 111 | corr_loss = corr_loss + self.cal_corr_loss_for_each_item(coords_now, torch.mm(nocs_now, min_rotation_matrix)) 112 | else: 113 | corr_loss = corr_loss + self.cal_corr_loss_for_each_item(coords_now, nocs_now) 114 | else: 115 | for i in range(bs): 116 | coords_now = coords[i] 117 | nocs_now = nocs[i] 118 | corr_loss = corr_loss + self.cal_corr_loss_for_each_item(coords_now, nocs_now) 119 | return corr_loss / bs 120 | 121 | def cal_corr_loss_for_each_item(self, coords, nocs): 122 | diff = torch.abs(coords - nocs) 123 | lower_corr_loss = torch.pow(diff, 2) / (2.0 * self.threshold) 124 | higher_corr_loss = diff - self.threshold / 2.0 125 | corr_loss_matrix = torch.where(diff > self.threshold, higher_corr_loss, lower_corr_loss) 126 | corr_loss_matrix = torch.sum(corr_loss_matrix, dim=-1) 127 | corr_loss = torch.mean(corr_loss_matrix) 128 | return corr_loss -------------------------------------------------------------------------------- /tools/align_utils.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | RANSAC for Similarity Transformation Estimation 3 | Modified from https://github.com/hughw19/NOCS_CVPR2019 4 | Originally Written by Srinath Sridhar 5 | """ 6 | import time 7 | import numpy as np 8 | 9 | 10 | def estimateSimilarityUmeyama(SourceHom, TargetHom): 11 | # Copy of original paper is at: http://web.stanford.edu/class/cs273/refs/umeyama.pdf 12 | SourceCentroid = np.mean(SourceHom[:3, :], axis=1) 13 | TargetCentroid = np.mean(TargetHom[:3, :], axis=1) 14 | nPoints = SourceHom.shape[1] 15 | CenteredSource = SourceHom[:3, :] - np.tile(SourceCentroid, (nPoints, 1)).transpose() 16 | CenteredTarget = TargetHom[:3, :] - np.tile(TargetCentroid, (nPoints, 1)).transpose() 17 | CovMatrix = np.matmul(CenteredTarget, np.transpose(CenteredSource)) / nPoints 18 | if np.isnan(CovMatrix).any(): 19 | print('nPoints:', nPoints) 20 | print(SourceHom.shape) 21 | print(TargetHom.shape) 22 | raise RuntimeError('There are NANs in the input.') 23 | 24 | U, D, Vh = np.linalg.svd(CovMatrix, full_matrices=True) 25 | d = (np.linalg.det(U) * np.linalg.det(Vh)) < 0.0 26 | if d: 27 | D[-1] = -D[-1] 28 | U[:, -1] = -U[:, -1] 29 | # rotation 30 | Rotation = np.matmul(U, Vh) 31 | # scale 32 | varP = np.var(SourceHom[:3, :], axis=1).sum() 33 | Scale = 1 / varP * np.sum(D) 34 | # translation 35 | Translation = TargetHom[:3, :].mean(axis=1) - SourceHom[:3, :].mean(axis=1).dot(Scale*Rotation.T) 36 | # transformation matrix 37 | OutTransform = np.identity(4) 38 | OutTransform[:3, :3] = Scale * Rotation 39 | OutTransform[:3, 3] = Translation 40 | 41 | return Scale, Rotation, Translation, OutTransform 42 | 43 | 44 | def estimateSimilarityTransform(source: np.array, target: np.array, verbose=False): 45 | """ Add RANSAC algorithm to account for outliers. 46 | 47 | """ 48 | assert source.shape[0] == target.shape[0], 'Source and Target must have same number of points.' 
49 | SourceHom = np.transpose(np.hstack([source, np.ones([source.shape[0], 1])])) 50 | TargetHom = np.transpose(np.hstack([target, np.ones([target.shape[0], 1])])) 51 | # Auto-parameter selection based on source heuristics 52 | # Assume source is object model or gt nocs map, which is of high quality 53 | SourceCentroid = np.mean(SourceHom[:3, :], axis=1) 54 | nPoints = SourceHom.shape[1] 55 | CenteredSource = SourceHom[:3, :] - np.tile(SourceCentroid, (nPoints, 1)).transpose() 56 | SourceDiameter = 2 * np.amax(np.linalg.norm(CenteredSource, axis=0)) 57 | InlierT = SourceDiameter / 10.0 # 0.1 of source diameter 58 | maxIter = 128 59 | confidence = 0.99 60 | 61 | if verbose: 62 | print('Inlier threshold: ', InlierT) 63 | print('Max number of iterations: ', maxIter) 64 | 65 | BestInlierRatio = 0 66 | BestInlierIdx = np.arange(nPoints) 67 | for i in range(0, maxIter): 68 | # Pick 5 random (but corresponding) points from source and target 69 | RandIdx = np.random.randint(nPoints, size=5) 70 | Scale, _, _, OutTransform = estimateSimilarityUmeyama(SourceHom[:, RandIdx], TargetHom[:, RandIdx]) 71 | PassThreshold = Scale * InlierT # propagate inlier threshold to target scale 72 | Diff = TargetHom - np.matmul(OutTransform, SourceHom) 73 | ResidualVec = np.linalg.norm(Diff[:3, :], axis=0) 74 | InlierIdx = np.where(ResidualVec < PassThreshold)[0] 75 | nInliers = InlierIdx.shape[0] 76 | InlierRatio = nInliers / nPoints 77 | # update best hypothesis 78 | if InlierRatio > BestInlierRatio: 79 | BestInlierRatio = InlierRatio 80 | BestInlierIdx = InlierIdx 81 | if verbose: 82 | print('Iteration: ', i) 83 | print('Inlier ratio: ', BestInlierRatio) 84 | # early break 85 | if (1 - (1 - BestInlierRatio ** 5) ** i) > confidence: 86 | break 87 | 88 | if(BestInlierRatio < 0.1): 89 | print('[ WARN ] - Something is wrong. Small BestInlierRatio: ', BestInlierRatio) 90 | return None, None, None, None 91 | 92 | SourceInliersHom = SourceHom[:, BestInlierIdx] 93 | TargetInliersHom = TargetHom[:, BestInlierIdx] 94 | Scale, Rotation, Translation, OutTransform = estimateSimilarityUmeyama(SourceInliersHom, TargetInliersHom) 95 | 96 | if verbose: 97 | print('BestInlierRatio:', BestInlierRatio) 98 | print('Rotation:\n', Rotation) 99 | print('Translation:\n', Translation) 100 | print('Scale:', Scale) 101 | 102 | return Scale, Rotation, Translation, OutTransform 103 | 104 | 105 | def backproject(depth, intrinsics, instance_mask): 106 | """ Back-projection, use opencv camera coordinate frame. 
107 | 108 | """ 109 | cam_fx = intrinsics[0, 0] 110 | cam_fy = intrinsics[1, 1] 111 | cam_cx = intrinsics[0, 2] 112 | cam_cy = intrinsics[1, 2] 113 | 114 | non_zero_mask = (depth > 0) 115 | final_instance_mask = np.logical_and(instance_mask, non_zero_mask) 116 | idxs = np.where(final_instance_mask) 117 | 118 | z = depth[idxs[0], idxs[1]] 119 | x = (idxs[1] - cam_cx) * z / cam_fx 120 | y = (idxs[0] - cam_cy) * z / cam_fy 121 | pts = np.stack((x, y, z), axis=1) 122 | 123 | return pts, idxs 124 | 125 | 126 | def align_nocs_to_depth(masks, coords, depth, intrinsics, instance_ids, img_path, verbose=False): 127 | num_instances = len(instance_ids) 128 | error_messages = '' 129 | elapses = [] 130 | scales = np.zeros(num_instances) 131 | rotations = np.zeros((num_instances, 3, 3)) 132 | translations = np.zeros((num_instances, 3)) 133 | 134 | for i in range(num_instances): 135 | mask = masks[:, :, i] 136 | coord = coords[:, :, i, :] 137 | pts, idxs = backproject(depth, intrinsics, mask) 138 | coord_pts = coord[idxs[0], idxs[1], :] - 0.5 139 | try: 140 | start = time.time() 141 | s, R, T, outtransform = estimateSimilarityTransform(coord_pts, pts, False) 142 | elapsed = time.time() - start 143 | if verbose: 144 | print('elapsed: ', elapsed) 145 | elapses.append(elapsed) 146 | except Exception as e: 147 | message = '[ Error ] aligning instance {} in {} fails. Message: {}.'.format(instance_ids[i], img_path, str(e)) 148 | print(message) 149 | error_messages += message + '\n' 150 | s = 1.0 151 | R = np.eye(3) 152 | T = np.zeros(3) 153 | outtransform = np.identity(4, dtype=np.float32) 154 | 155 | scales[i] = s / 1000.0 156 | rotations[i, :, :] = R 157 | translations[i, :] = T / 1000.0 158 | 159 | return scales, rotations, translations, error_messages, elapses -------------------------------------------------------------------------------- /prepare_data/lib/align.py: -------------------------------------------------------------------------------- 1 | """ 2 | RANSAC for Similarity Transformation Estimation 3 | Modified from https://github.com/hughw19/NOCS_CVPR2019 4 | Originally Written by Srinath Sridhar 5 | """ 6 | import time 7 | import numpy as np 8 | 9 | 10 | def estimateSimilarityUmeyama(SourceHom, TargetHom): 11 | # Copy of original paper is at: http://web.stanford.edu/class/cs273/refs/umeyama.pdf 12 | SourceCentroid = np.mean(SourceHom[:3, :], axis=1) 13 | TargetCentroid = np.mean(TargetHom[:3, :], axis=1) 14 | nPoints = SourceHom.shape[1] 15 | CenteredSource = SourceHom[:3, :] - np.tile(SourceCentroid, (nPoints, 1)).transpose() 16 | CenteredTarget = TargetHom[:3, :] - np.tile(TargetCentroid, (nPoints, 1)).transpose() 17 | CovMatrix = np.matmul(CenteredTarget, np.transpose(CenteredSource)) / nPoints 18 | if np.isnan(CovMatrix).any(): 19 | print('nPoints:', nPoints) 20 | print(SourceHom.shape) 21 | print(TargetHom.shape) 22 | raise RuntimeError('There are NANs in the input.') 23 | 24 | U, D, Vh = np.linalg.svd(CovMatrix, full_matrices=True) 25 | d = (np.linalg.det(U) * np.linalg.det(Vh)) < 0.0 26 | if d: 27 | D[-1] = -D[-1] 28 | U[:, -1] = -U[:, -1] 29 | # rotation 30 | Rotation = np.matmul(U, Vh) 31 | # scale 32 | varP = np.var(SourceHom[:3, :], axis=1).sum() 33 | Scale = 1 / varP * np.sum(D) 34 | # translation 35 | Translation = TargetHom[:3, :].mean(axis=1) - SourceHom[:3, :].mean(axis=1).dot(Scale*Rotation.T) 36 | # transformation matrix 37 | OutTransform = np.identity(4) 38 | OutTransform[:3, :3] = Scale * Rotation 39 | OutTransform[:3, 3] = Translation 40 | 41 | return Scale, Rotation, 
Translation, OutTransform 42 | 43 | 44 | def estimateSimilarityTransform(source: np.array, target: np.array, verbose=False): 45 | """ Add RANSAC algorithm to account for outliers. 46 | 47 | """ 48 | assert source.shape[0] == target.shape[0], 'Source and Target must have same number of points.' 49 | SourceHom = np.transpose(np.hstack([source, np.ones([source.shape[0], 1])])) 50 | TargetHom = np.transpose(np.hstack([target, np.ones([target.shape[0], 1])])) 51 | # Auto-parameter selection based on source heuristics 52 | # Assume source is object model or gt nocs map, which is of high quality 53 | SourceCentroid = np.mean(SourceHom[:3, :], axis=1) 54 | nPoints = SourceHom.shape[1] 55 | CenteredSource = SourceHom[:3, :] - np.tile(SourceCentroid, (nPoints, 1)).transpose() 56 | SourceDiameter = 2 * np.amax(np.linalg.norm(CenteredSource, axis=0)) 57 | InlierT = SourceDiameter / 10.0 # 0.1 of source diameter 58 | maxIter = 128 59 | confidence = 0.99 60 | 61 | if verbose: 62 | print('Inlier threshold: ', InlierT) 63 | print('Max number of iterations: ', maxIter) 64 | 65 | BestInlierRatio = 0 66 | BestInlierIdx = np.arange(nPoints) 67 | for i in range(0, maxIter): 68 | # Pick 5 random (but corresponding) points from source and target 69 | RandIdx = np.random.randint(nPoints, size=5) 70 | Scale, _, _, OutTransform = estimateSimilarityUmeyama(SourceHom[:, RandIdx], TargetHom[:, RandIdx]) 71 | PassThreshold = Scale * InlierT # propagate inlier threshold to target scale 72 | Diff = TargetHom - np.matmul(OutTransform, SourceHom) 73 | ResidualVec = np.linalg.norm(Diff[:3, :], axis=0) 74 | InlierIdx = np.where(ResidualVec < PassThreshold)[0] 75 | nInliers = InlierIdx.shape[0] 76 | InlierRatio = nInliers / nPoints 77 | # update best hypothesis 78 | if InlierRatio > BestInlierRatio: 79 | BestInlierRatio = InlierRatio 80 | BestInlierIdx = InlierIdx 81 | if verbose: 82 | print('Iteration: ', i) 83 | print('Inlier ratio: ', BestInlierRatio) 84 | # early break 85 | if (1 - (1 - BestInlierRatio ** 5) ** i) > confidence: 86 | break 87 | 88 | if(BestInlierRatio < 0.1): 89 | print('[ WARN ] - Something is wrong. Small BestInlierRatio: ', BestInlierRatio) 90 | return None, None, None, None 91 | 92 | SourceInliersHom = SourceHom[:, BestInlierIdx] 93 | TargetInliersHom = TargetHom[:, BestInlierIdx] 94 | Scale, Rotation, Translation, OutTransform = estimateSimilarityUmeyama(SourceInliersHom, TargetInliersHom) 95 | 96 | if verbose: 97 | print('BestInlierRatio:', BestInlierRatio) 98 | print('Rotation:\n', Rotation) 99 | print('Translation:\n', Translation) 100 | print('Scale:', Scale) 101 | 102 | return Scale, Rotation, Translation, OutTransform 103 | 104 | 105 | def backproject(depth, intrinsics, instance_mask): 106 | """ Back-projection, use opencv camera coordinate frame. 
107 | 108 | """ 109 | cam_fx = intrinsics[0, 0] 110 | cam_fy = intrinsics[1, 1] 111 | cam_cx = intrinsics[0, 2] 112 | cam_cy = intrinsics[1, 2] 113 | 114 | non_zero_mask = (depth > 0) 115 | final_instance_mask = np.logical_and(instance_mask, non_zero_mask) 116 | idxs = np.where(final_instance_mask) 117 | 118 | z = depth[idxs[0], idxs[1]] 119 | x = (idxs[1] - cam_cx) * z / cam_fx 120 | y = (idxs[0] - cam_cy) * z / cam_fy 121 | pts = np.stack((x, y, z), axis=1) 122 | 123 | return pts, idxs 124 | 125 | 126 | def align_nocs_to_depth(masks, coords, depth, intrinsics, instance_ids, img_path, verbose=False): 127 | num_instances = len(instance_ids) 128 | error_messages = '' 129 | elapses = [] 130 | scales = np.zeros(num_instances) 131 | rotations = np.zeros((num_instances, 3, 3)) 132 | translations = np.zeros((num_instances, 3)) 133 | 134 | for i in range(num_instances): 135 | mask = masks[:, :, i] 136 | coord = coords[:, :, i, :] 137 | pts, idxs = backproject(depth, intrinsics, mask) 138 | coord_pts = coord[idxs[0], idxs[1], :] - 0.5 139 | try: 140 | start = time.time() 141 | s, R, T, outtransform = estimateSimilarityTransform(coord_pts, pts, False) 142 | elapsed = time.time() - start 143 | if verbose: 144 | print('elapsed: ', elapsed) 145 | elapses.append(elapsed) 146 | except Exception as e: 147 | message = '[ Error ] aligning instance {} in {} fails. Message: {}.'.format(instance_ids[i], img_path, str(e)) 148 | print(message) 149 | error_messages += message + '\n' 150 | s = 1.0 151 | R = np.eye(3) 152 | T = np.zeros(3) 153 | outtransform = np.identity(4, dtype=np.float32) 154 | 155 | scales[i] = s / 1000.0 156 | rotations[i, :, :] = R 157 | translations[i, :] = T / 1000.0 158 | 159 | return scales, rotations, translations, error_messages, elapses 160 | -------------------------------------------------------------------------------- /losses/fs_net_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import absl.flags as flags 5 | from absl import app 6 | import mmcv 7 | FLAGS = flags.FLAGS # can control the weight of each term here 8 | 9 | 10 | class fs_net_loss(nn.Module): 11 | def __init__(self): 12 | super(fs_net_loss, self).__init__() 13 | if FLAGS.fsnet_loss_type == 'l1': 14 | self.loss_func_t = nn.L1Loss() 15 | self.loss_func_s = nn.L1Loss() 16 | self.loss_func_Rot1 = nn.L1Loss() 17 | self.loss_func_Rot2 = nn.L1Loss() 18 | self.loss_func_r_con = nn.L1Loss() 19 | self.loss_func_Recon = nn.L1Loss() 20 | elif FLAGS.fsnet_loss_type == 'smoothl1': # same as MSE 21 | self.loss_func_t = nn.SmoothL1Loss(beta=0.5) 22 | self.loss_func_s = nn.SmoothL1Loss(beta=0.5) 23 | self.loss_func_Rot1 = nn.SmoothL1Loss(beta=0.5) 24 | self.loss_func_Rot2 = nn.SmoothL1Loss(beta=0.5) 25 | self.loss_func_r_con = nn.SmoothL1Loss(beta=0.5) 26 | self.loss_func_Recon = nn.SmoothL1Loss(beta=0.3) 27 | else: 28 | raise NotImplementedError 29 | 30 | def forward(self, name_list, pred_list, gt_list, sym): 31 | loss_list = {} 32 | if "Rot1" in name_list: 33 | loss_list["Rot1"] = FLAGS.rot_1_w * self.cal_loss_Rot1(pred_list["Rot1"], gt_list["Rot1"]) 34 | 35 | if "Rot1_cos" in name_list: 36 | loss_list["Rot1_cos"] = FLAGS.rot_1_w * self.cal_cosine_dis(pred_list["Rot1"], gt_list["Rot1"]) 37 | 38 | if "Rot2" in name_list: 39 | loss_list["Rot2"] = FLAGS.rot_2_w * self.cal_loss_Rot2(pred_list["Rot2"], gt_list["Rot2"], sym) 40 | 41 | if "Rot2_cos" in name_list: 42 | loss_list["Rot2_cos"] = FLAGS.rot_2_w * 
self.cal_cosine_dis_sym(pred_list["Rot2"], gt_list["Rot2"], sym) 43 | 44 | if "Rot_regular" in name_list: 45 | loss_list["Rot_r_a"] = FLAGS.rot_regular * self.cal_rot_regular_angle(pred_list["Rot1"], 46 | pred_list["Rot2"], sym) 47 | 48 | if "Recon" in name_list: 49 | loss_list["Recon"] = FLAGS.recon_w * self.cal_loss_Recon(pred_list["Recon"], gt_list["Recon"]) 50 | 51 | if "Tran" in name_list: 52 | loss_list["Tran"] = FLAGS.tran_w * self.cal_loss_Tran(pred_list["Tran"], gt_list["Tran"]) 53 | 54 | if "Size" in name_list: 55 | loss_list["Size"] = FLAGS.size_w * self.cal_loss_Size(pred_list["Size"], gt_list["Size"]) 56 | 57 | if "R_con" in name_list: 58 | loss_list["R_con"] = FLAGS.r_con_w * self.cal_loss_R_con(pred_list["Rot1"], pred_list["Rot2"], 59 | gt_list["Rot1"], gt_list["Rot2"], 60 | pred_list["Rot1_f"], pred_list["Rot2_f"], sym) 61 | return loss_list 62 | 63 | def cal_loss_R_con(self, p_rot_g, p_rot_r, g_rot_g, g_rot_r, p_g_con, p_r_con, sym): 64 | dis_g = p_rot_g - g_rot_g # bs x 3 65 | dis_g_norm = torch.norm(dis_g, dim=-1) # bs 66 | p_g_con_gt = torch.exp(-13.7 * dis_g_norm * dis_g_norm) # bs 67 | res_g = self.loss_func_r_con(p_g_con_gt, p_g_con) 68 | res_r = 0.0 69 | bs = p_rot_g.shape[0] 70 | for i in range(bs): 71 | if sym[i, 0] == 0: 72 | dis_r = p_rot_r[i, ...] - g_rot_r[i, ...] 73 | dis_r_norm = torch.norm(dis_r) # 1 74 | p_r_con_gt = torch.exp(-13.7 * dis_r_norm * dis_r_norm) 75 | res_r += self.loss_func_r_con(p_r_con_gt, p_r_con[i]) 76 | res_r = res_r / bs 77 | return res_g + res_r 78 | 79 | 80 | def cal_loss_Rot1(self, pred_v, gt_v): 81 | bs = pred_v.shape[0] 82 | res = torch.zeros([bs], dtype=torch.float32, device=pred_v.device) 83 | for i in range(bs): 84 | pred_v_now = pred_v[i, ...] 85 | gt_v_now = gt_v[i, ...] 86 | res[i] = self.loss_func_Rot1(pred_v_now, gt_v_now) 87 | res = torch.mean(res) 88 | return res 89 | 90 | def cal_loss_Rot2(self, pred_v, gt_v, sym): 91 | bs = pred_v.shape[0] 92 | res = 0.0 93 | valid = 0.0 94 | for i in range(bs): 95 | sym_now = sym[i, 0] 96 | if sym_now == 1: 97 | continue 98 | else: 99 | pred_v_now = pred_v[i, ...] 100 | gt_v_now = gt_v[i, ...] 101 | res += self.loss_func_Rot2(pred_v_now, gt_v_now) 102 | valid += 1.0 103 | if valid > 0.0: 104 | res = res / valid 105 | return res 106 | 107 | def cal_cosine_dis(self, pred_v, gt_v): 108 | # pred_v bs x 6, gt_v bs x 6 109 | bs = pred_v.shape[0] 110 | res = torch.zeros([bs], dtype=torch.float32).to(pred_v.device) 111 | for i in range(bs): 112 | pred_v_now = pred_v[i, ...] 113 | gt_v_now = gt_v[i, ...] 114 | res[i] = (1.0 - torch.sum(pred_v_now * gt_v_now)) * 2.0 115 | res = torch.mean(res) 116 | return res 117 | 118 | def cal_cosine_dis_sym(self, pred_v, gt_v, sym): 119 | # pred_v bs x 6, gt_v bs x 6 120 | bs = pred_v.shape[0] 121 | res = 0.0 122 | valid = 0.0 123 | for i in range(bs): 124 | sym_now = sym[i, 0] 125 | if sym_now == 1: 126 | continue 127 | else: 128 | pred_v_now = pred_v[i, ...] 129 | gt_v_now = gt_v[i, ...] 130 | res += (1.0 - torch.sum(pred_v_now * gt_v_now)) * 2.0 131 | valid += 1.0 132 | if valid > 0.0: 133 | res = res / valid 134 | return res 135 | 136 | 137 | def cal_rot_regular_angle(self, pred_v1, pred_v2, sym): 138 | bs = pred_v1.shape[0] 139 | res = 0.0 140 | valid = 0.0 141 | for i in range(bs): 142 | if sym[i, 0] == 1: 143 | continue 144 | y_direction = pred_v1[i, ...] 145 | z_direction = pred_v2[i, ...] 
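            # the two predicted rotation axes should be orthogonal, so penalize |<y_direction, z_direction>|
            # (symmetric objects are skipped here, just as in cal_loss_Rot2 above)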
146 | residual = torch.dot(y_direction, z_direction) 147 | res += torch.abs(residual) 148 | valid += 1.0 149 | if valid > 0.0: 150 | res = res / valid 151 | return res 152 | 153 | def cal_loss_Recon(self, pred_recon, gt_recon): 154 | return self.loss_func_Recon(pred_recon, gt_recon) 155 | 156 | def cal_loss_Tran(self, pred_trans, gt_trans): 157 | return self.loss_func_t(pred_trans, gt_trans) 158 | 159 | def cal_loss_Size(self, pred_size, gt_size): 160 | return self.loss_func_s(pred_size, gt_size) 161 | -------------------------------------------------------------------------------- /tools/torch_utils/solver/optimize.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from torch.nn.utils import clip_grad 4 | import mmcv 5 | from mmcv.runner import obj_from_dict 6 | from lib.utils import logger 7 | from lib.utils.utils import msg 8 | 9 | 10 | def _get_optimizer(params, optimizer_cfg, use_hvd=False): 11 | # cfg.optimizer = dict(type='RMSprop', lr=1e-4, weight_decay=0) 12 | # cfg.optimizer = dict(type='Ranger', lr=1e-4) # , N_sma_threshhold=5, betas=(.95, 0.999)) # 4, (0.90, 0.999) 13 | optim_type_str = optimizer_cfg.pop("type") 14 | if optim_type_str.lower() in ["rangerlars", "over9000"]: # RangerLars 15 | optim_type_str = "lookahead_Ralamb" 16 | optim_split = optim_type_str.split("_") 17 | 18 | optim_type = optim_split[-1] 19 | logger.info(f"optimizer: {optim_type_str} {optim_split}") 20 | 21 | if optim_type == "Ranger": 22 | from lib.torch_utils.solver.ranger import Ranger 23 | 24 | optimizer_cls = Ranger 25 | elif optim_type == "Ralamb": 26 | from lib.torch_utils.solver.ralamb import Ralamb 27 | 28 | optimizer_cls = Ralamb 29 | elif optim_type == "RAdam": 30 | from lib.torch_utils.solver.radam import RAdam 31 | 32 | optimizer_cls = RAdam 33 | else: 34 | optimizer_cls = getattr(torch.optim, optim_type) 35 | opt_kwargs = {k: v for k, v in optimizer_cfg.items() if "lookahead" not in k} 36 | optimizer = optimizer_cls(params, **opt_kwargs) 37 | 38 | if len(optim_split) > 1 and not use_hvd: 39 | if optim_split[0].lower() == "lookahead": 40 | from lib.torch_utils.solver.lookahead import Lookahead 41 | 42 | # TODO: pass lookahead hyper-params 43 | optimizer = Lookahead( 44 | optimizer, alpha=optimizer_cfg.get("lookahead_alpha", 0.5), k=optimizer_cfg.get("lookahead_k", 6) 45 | ) 46 | # logger.info(msg(type(optimizer))) 47 | return optimizer 48 | 49 | 50 | def build_optimizer_on_params(params, optimizer_cfg, use_hvd=False): 51 | optimizer_cfg = optimizer_cfg.copy() 52 | return _get_optimizer(params, optimizer_cfg, use_hvd=use_hvd) 53 | 54 | 55 | def build_optimizer(model, optimizer_cfg, cfg=None, use_hvd=False): 56 | """Build optimizer from configs. 57 | 58 | Args: 59 | model (:obj:`nn.Module`): The model with parameters to be optimized. 60 | optimizer_cfg (dict): The config dict of the optimizer. 61 | cfg.optimizer 62 | Positional fields are: 63 | - type: class name of the optimizer. 64 | - lr: base learning rate. 65 | Optional fields are: 66 | - any arguments of the corresponding optimizer type, e.g., 67 | weight_decay, momentum, etc. 68 | - paramwise_options: a dict with 3 accepted fileds 69 | (bias_lr_mult, bias_decay_mult, norm_decay_mult). 
70 | `bias_lr_mult` and `bias_decay_mult` will be multiplied to 71 | the lr and weight decay respectively for all bias parameters 72 | (except for the normalization layers), and 73 | `norm_decay_mult` will be multiplied to the weight decay 74 | for all weight and bias parameters of normalization layers. 75 | 76 | Returns: 77 | torch.optim.Optimizer: The initialized optimizer. 78 | """ 79 | if hasattr(model, "module"): 80 | model = model.module 81 | 82 | optimizer_cfg = optimizer_cfg.copy() 83 | paramwise_options = optimizer_cfg.pop("paramwise_options", None) 84 | # if no paramwise option is specified, just use the global setting 85 | if paramwise_options is None: 86 | if cfg is not None and "train" in cfg and cfg.train.get("slow_base", False): 87 | base_params = [p for p_n, p in model.named_parameters() if "emb_head" not in p_n] 88 | active_params = [p for p_n, p in model.named_parameters() if "emb_head" in p_n] 89 | params = [ 90 | {"params": base_params, "lr": cfg.ref.slow_base_ratio * optimizer_cfg["lr"]}, 91 | {"params": active_params}, 92 | ] 93 | return _get_optimizer(params, optimizer_cfg, use_hvd=use_hvd) 94 | else: 95 | return _get_optimizer(model.parameters(), optimizer_cfg, use_hvd=use_hvd) 96 | else: 97 | assert isinstance(paramwise_options, dict) 98 | # get base lr and weight decay 99 | base_lr = optimizer_cfg["lr"] 100 | base_wd = optimizer_cfg.get("weight_decay", None) 101 | # weight_decay must be explicitly specified if mult is specified 102 | if "bias_decay_mult" in paramwise_options or "norm_decay_mult" in paramwise_options: 103 | assert base_wd is not None 104 | # get param-wise options 105 | bias_lr_mult = paramwise_options.get("bias_lr_mult", 1.0) 106 | bias_decay_mult = paramwise_options.get("bias_decay_mult", 1.0) 107 | norm_decay_mult = paramwise_options.get("norm_decay_mult", 1.0) 108 | # set param-wise lr and weight decay 109 | params = [] 110 | for name, param in model.named_parameters(): 111 | if not param.requires_grad: 112 | continue 113 | 114 | param_group = {"params": [param]} 115 | # for norm layers, overwrite the weight decay of weight and bias 116 | # TODO: obtain the norm layer prefixes dynamically 117 | if re.search(r"(bn|gn)(\d+)?.(weight|bias)", name): 118 | if base_wd is not None: 119 | param_group["weight_decay"] = base_wd * norm_decay_mult 120 | # for other layers, overwrite both lr and weight decay of bias 121 | elif name.endswith(".bias"): 122 | param_group["lr"] = base_lr * bias_lr_mult 123 | if base_wd is not None: 124 | param_group["weight_decay"] = base_wd * bias_decay_mult 125 | 126 | # NOTE: add 127 | if cfg is not None and "train" in cfg and cfg.train.get("slow_base", False): 128 | if "emb_head" not in name: # backbone parameters 129 | param_group["lr"] = cfg.ref.slow_base_ratio * base_lr 130 | # otherwise use the global settings 131 | 132 | params.append(param_group) 133 | return _get_optimizer(params, optimizer_cfg, use_hvd=use_hvd) 134 | 135 | 136 | def clip_grad_norm(params, max_norm=35, norm_type=2): 137 | """ 138 | clip_grad_norm = {'max_norm': 35, 'norm_type': 2} 139 | slow down training 140 | """ 141 | clip_grad.clip_grad_norm_(filter(lambda p: p.requires_grad, params), max_norm=max_norm, norm_type=norm_type) 142 | 143 | 144 | def clip_grad_value(params, clip_value=10): 145 | # slow down training 146 | clip_grad.clip_grad_value_(filter(lambda p: p.requires_grad, params), clip_value=clip_value) 147 | -------------------------------------------------------------------------------- /tools/torch_utils/solver/rmsprop_tf.py: 
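For reference before the next file: a minimal sketch of the mmcv-style optimizer config consumed by build_optimizer / _get_optimizer in optimize.py above. The concrete values and the model variable are illustrative assumptions; only the field names (type, lr, weight_decay, lookahead_alpha, lookahead_k, paramwise_options and its three *_mult fields) come from the code.

optimizer_cfg = dict(
    type="lookahead_Ralamb",   # "<wrapper>_<base>"; "RangerLars" / "over9000" are aliases for this combo
    lr=1e-4,
    weight_decay=1e-4,
    lookahead_alpha=0.5,       # consumed by the Lookahead wrapper, stripped before the base optimizer
    lookahead_k=6,
    paramwise_options=dict(bias_lr_mult=2.0, bias_decay_mult=0.0, norm_decay_mult=0.0),
)
optimizer = build_optimizer(model, optimizer_cfg)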
-------------------------------------------------------------------------------- 1 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/rmsprop_tf.py 2 | """TF/Caffe2 (eps inside sqrt): https://github.com/pytorch/pytorch/blob/v0.4.0/ 3 | caffe2/sgd/rmsprop_op_gpu.cu#L24 PyTorch (eps outside sqrt): 4 | https://github.com/pytorch/pytorch/blob/v0.4.0/torch/optim/rmsprop.py#L93 5 | https://github.com/pytorch/pytorch/issues/23796 6 | https://github.com/rwightman/pytorch-image-models/issues/11. 7 | 8 | the eps can be relatively large 9 | ./distributed_train.sh 8 ../ImageNet/ --model efficientnet_b0 -b 256 \ 10 | --sched step --epochs 500 --decay-epochs 3 --decay-rate 0.963 \ 11 | --opt rmsproptf --opt-eps .001 -j 8 --warmup-epochs 5 \ 12 | --weight-decay 1e-5 --drop 0.2 --color-jitter .06 --model-ema --lr .128 13 | """ 14 | import torch 15 | from torch.optim import Optimizer 16 | 17 | 18 | class RMSpropTF(Optimizer): 19 | """Implements RMSprop algorithm (TensorFlow style epsilon) 20 | NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt 21 | to closer match Tensorflow for matching hyper-params. 22 | Proposed by G. Hinton in his 23 | `course `_. 24 | The centered version first appears in `Generating Sequences 25 | With Recurrent Neural Networks `_. 26 | Arguments: 27 | params (iterable): iterable of parameters to optimize or dicts defining 28 | parameter groups 29 | lr (float, optional): learning rate (default: 1e-2) 30 | momentum (float, optional): momentum factor (default: 0) 31 | alpha (float, optional): smoothing (decay) constant (default: 0.9) 32 | eps (float, optional): term added to the denominator to improve 33 | numerical stability (default: 1e-10) 34 | centered (bool, optional) : if ``True``, compute the centered RMSProp, 35 | the gradient is normalized by an estimation of its variance 36 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 37 | decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101 38 | lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer 39 | update as per defaults in Tensorflow 40 | """ 41 | 42 | def __init__( 43 | self, 44 | params, 45 | lr=1e-2, 46 | alpha=0.9, 47 | eps=1e-10, 48 | weight_decay=0, 49 | momentum=0.0, 50 | centered=False, 51 | decoupled_decay=False, 52 | lr_in_momentum=True, 53 | ): 54 | if not 0.0 <= lr: 55 | raise ValueError("Invalid learning rate: {}".format(lr)) 56 | if not 0.0 <= eps: 57 | raise ValueError("Invalid epsilon value: {}".format(eps)) 58 | if not 0.0 <= momentum: 59 | raise ValueError("Invalid momentum value: {}".format(momentum)) 60 | if not 0.0 <= weight_decay: 61 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 62 | if not 0.0 <= alpha: 63 | raise ValueError("Invalid alpha value: {}".format(alpha)) 64 | 65 | defaults = dict( 66 | lr=lr, 67 | momentum=momentum, 68 | alpha=alpha, 69 | eps=eps, 70 | centered=centered, 71 | weight_decay=weight_decay, 72 | decoupled_decay=decoupled_decay, 73 | lr_in_momentum=lr_in_momentum, 74 | ) 75 | super(RMSpropTF, self).__init__(params, defaults) 76 | 77 | def __setstate__(self, state): 78 | super(RMSpropTF, self).__setstate__(state) 79 | for group in self.param_groups: 80 | group.setdefault("momentum", 0) 81 | group.setdefault("centered", False) 82 | 83 | def step(self, closure=None): 84 | """Performs a single optimization step. 
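        TF-style details visible in the update below: `eps` is added inside the square
        root, `square_avg` starts from ones rather than zeros, weight decay may be applied
        in decoupled form, and with `lr_in_momentum` (the default) the learning rate is
        folded into the momentum buffer as TensorFlow does.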
85 | 86 | Arguments: 87 | closure (callable, optional): A closure that reevaluates the model 88 | and returns the loss. 89 | """ 90 | loss = None 91 | if closure is not None: 92 | loss = closure() 93 | 94 | for group in self.param_groups: 95 | for p in group["params"]: 96 | if p.grad is None: 97 | continue 98 | grad = p.grad.data 99 | if grad.is_sparse: 100 | raise RuntimeError("RMSprop does not support sparse gradients") 101 | state = self.state[p] 102 | 103 | # State initialization 104 | if len(state) == 0: 105 | state["step"] = 0 106 | state["square_avg"] = torch.ones_like(p.data) # PyTorch inits to zero 107 | if group["momentum"] > 0: 108 | state["momentum_buffer"] = torch.zeros_like(p.data) 109 | if group["centered"]: 110 | state["grad_avg"] = torch.zeros_like(p.data) 111 | 112 | square_avg = state["square_avg"] 113 | one_minus_alpha = 1.0 - group["alpha"] 114 | 115 | state["step"] += 1 116 | 117 | if group["weight_decay"] != 0: 118 | if "decoupled_decay" in group and group["decoupled_decay"]: 119 | p.data.add_(-group["weight_decay"], p.data) 120 | else: 121 | grad = grad.add(group["weight_decay"], p.data) 122 | 123 | # Tensorflow order of ops for updating squared avg 124 | square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg) 125 | # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original 126 | 127 | if group["centered"]: 128 | grad_avg = state["grad_avg"] 129 | grad_avg.add_(one_minus_alpha, grad - grad_avg) 130 | # grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original 131 | avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group["eps"]).sqrt_() # eps moved in sqrt 132 | else: 133 | avg = square_avg.add(group["eps"]).sqrt_() # eps moved in sqrt 134 | 135 | if group["momentum"] > 0: 136 | buf = state["momentum_buffer"] 137 | # Tensorflow accumulates the LR scaling in the momentum buffer 138 | if "lr_in_momentum" in group and group["lr_in_momentum"]: 139 | buf.mul_(group["momentum"]).addcdiv_(group["lr"], grad, avg) 140 | p.data.add_(-buf) 141 | else: 142 | # PyTorch scales the param update by LR 143 | buf.mul_(group["momentum"]).addcdiv_(grad, avg) 144 | p.data.add_(-group["lr"], buf) 145 | else: 146 | p.data.addcdiv_(-group["lr"], grad, avg) 147 | 148 | return loss 149 | -------------------------------------------------------------------------------- /tools/pyTorchChamferDistance/chamfer_distance.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // CUDA forward declarations 4 | int ChamferDistanceKernelLauncher( 5 | const int b, const int n, 6 | const float* xyz, 7 | const int m, 8 | const float* xyz2, 9 | float* result, 10 | int* result_i, 11 | float* result2, 12 | int* result2_i); 13 | 14 | int ChamferDistanceGradKernelLauncher( 15 | const int b, const int n, 16 | const float* xyz1, 17 | const int m, 18 | const float* xyz2, 19 | const float* grad_dist1, 20 | const int* idx1, 21 | const float* grad_dist2, 22 | const int* idx2, 23 | float* grad_xyz1, 24 | float* grad_xyz2); 25 | 26 | 27 | void chamfer_distance_forward_cuda( 28 | const at::Tensor xyz1, 29 | const at::Tensor xyz2, 30 | const at::Tensor dist1, 31 | const at::Tensor dist2, 32 | const at::Tensor idx1, 33 | const at::Tensor idx2) 34 | { 35 | ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 36 | xyz2.size(1), xyz2.data(), 37 | dist1.data(), idx1.data(), 38 | dist2.data(), idx2.data()); 39 | } 40 | 41 | void chamfer_distance_backward_cuda( 42 | const at::Tensor xyz1, 43 | const at::Tensor xyz2, 44 | at::Tensor 
gradxyz1, 45 | at::Tensor gradxyz2, 46 | at::Tensor graddist1, 47 | at::Tensor graddist2, 48 | at::Tensor idx1, 49 | at::Tensor idx2) 50 | { 51 | ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 52 | xyz2.size(1), xyz2.data(), 53 | graddist1.data(), idx1.data(), 54 | graddist2.data(), idx2.data(), 55 | gradxyz1.data(), gradxyz2.data()); 56 | } 57 | 58 | 59 | void nnsearch( 60 | const int b, const int n, const int m, 61 | const float* xyz1, 62 | const float* xyz2, 63 | float* dist, 64 | int* idx) 65 | { 66 | for (int i = 0; i < b; i++) { 67 | for (int j = 0; j < n; j++) { 68 | const float x1 = xyz1[(i*n+j)*3+0]; 69 | const float y1 = xyz1[(i*n+j)*3+1]; 70 | const float z1 = xyz1[(i*n+j)*3+2]; 71 | double best = 0; 72 | int besti = 0; 73 | for (int k = 0; k < m; k++) { 74 | const float x2 = xyz2[(i*m+k)*3+0] - x1; 75 | const float y2 = xyz2[(i*m+k)*3+1] - y1; 76 | const float z2 = xyz2[(i*m+k)*3+2] - z1; 77 | const double d=x2*x2+y2*y2+z2*z2; 78 | if (k==0 || d < best){ 79 | best = d; 80 | besti = k; 81 | } 82 | } 83 | dist[i*n+j] = best; 84 | idx[i*n+j] = besti; 85 | } 86 | } 87 | } 88 | 89 | 90 | void chamfer_distance_forward( 91 | const at::Tensor xyz1, 92 | const at::Tensor xyz2, 93 | const at::Tensor dist1, 94 | const at::Tensor dist2, 95 | const at::Tensor idx1, 96 | const at::Tensor idx2) 97 | { 98 | const int batchsize = xyz1.size(0); 99 | const int n = xyz1.size(1); 100 | const int m = xyz2.size(1); 101 | 102 | const float* xyz1_data = xyz1.data(); 103 | const float* xyz2_data = xyz2.data(); 104 | float* dist1_data = dist1.data(); 105 | float* dist2_data = dist2.data(); 106 | int* idx1_data = idx1.data(); 107 | int* idx2_data = idx2.data(); 108 | 109 | nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data); 110 | nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data); 111 | } 112 | 113 | 114 | void chamfer_distance_backward( 115 | const at::Tensor xyz1, 116 | const at::Tensor xyz2, 117 | at::Tensor gradxyz1, 118 | at::Tensor gradxyz2, 119 | at::Tensor graddist1, 120 | at::Tensor graddist2, 121 | at::Tensor idx1, 122 | at::Tensor idx2) 123 | { 124 | const int b = xyz1.size(0); 125 | const int n = xyz1.size(1); 126 | const int m = xyz2.size(1); 127 | 128 | const float* xyz1_data = xyz1.data(); 129 | const float* xyz2_data = xyz2.data(); 130 | float* gradxyz1_data = gradxyz1.data(); 131 | float* gradxyz2_data = gradxyz2.data(); 132 | float* graddist1_data = graddist1.data(); 133 | float* graddist2_data = graddist2.data(); 134 | const int* idx1_data = idx1.data(); 135 | const int* idx2_data = idx2.data(); 136 | 137 | for (int i = 0; i < b*n*3; i++) 138 | gradxyz1_data[i] = 0; 139 | for (int i = 0; i < b*m*3; i++) 140 | gradxyz2_data[i] = 0; 141 | for (int i = 0;i < b; i++) { 142 | for (int j = 0; j < n; j++) { 143 | const float x1 = xyz1_data[(i*n+j)*3+0]; 144 | const float y1 = xyz1_data[(i*n+j)*3+1]; 145 | const float z1 = xyz1_data[(i*n+j)*3+2]; 146 | const int j2 = idx1_data[i*n+j]; 147 | 148 | const float x2 = xyz2_data[(i*m+j2)*3+0]; 149 | const float y2 = xyz2_data[(i*m+j2)*3+1]; 150 | const float z2 = xyz2_data[(i*m+j2)*3+2]; 151 | const float g = graddist1_data[i*n+j]*2; 152 | 153 | gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2); 154 | gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2); 155 | gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2); 156 | gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2)); 157 | gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2)); 158 | gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2)); 159 | } 160 | for (int j = 0; j < m; j++) { 161 | const 
float x1 = xyz2_data[(i*m+j)*3+0]; 162 | const float y1 = xyz2_data[(i*m+j)*3+1]; 163 | const float z1 = xyz2_data[(i*m+j)*3+2]; 164 | const int j2 = idx2_data[i*m+j]; 165 | const float x2 = xyz1_data[(i*n+j2)*3+0]; 166 | const float y2 = xyz1_data[(i*n+j2)*3+1]; 167 | const float z2 = xyz1_data[(i*n+j2)*3+2]; 168 | const float g = graddist2_data[i*m+j]*2; 169 | gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2); 170 | gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2); 171 | gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2); 172 | gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2)); 173 | gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2)); 174 | gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2)); 175 | } 176 | } 177 | } 178 | 179 | 180 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 181 | m.def("forward", &chamfer_distance_forward, "ChamferDistance forward"); 182 | m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)"); 183 | m.def("backward", &chamfer_distance_backward, "ChamferDistance backward"); 184 | m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)"); 185 | } 186 | --------------------------------------------------------------------------------
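The bindings above expose forward/backward (CPU) and forward_cuda/backward_cuda (GPU) entry points that fill pre-allocated distance and index tensors in place. Below is a minimal sketch of how such an extension could be JIT-built and exercised through the CPU path; the extension name, the source paths, and the presence of a companion chamfer_distance.cu file are assumptions, and in practice the package's own Python wrapper should be preferred.

import torch
from torch.utils.cpp_extension import load

# assumed locations of the sources shown above; adjust to the actual layout
cd = load(
    name="chamfer_distance_ext",
    sources=[
        "tools/pyTorchChamferDistance/chamfer_distance.cpp",
        "tools/pyTorchChamferDistance/chamfer_distance.cu",
    ],
)

xyz1 = torch.rand(2, 1024, 3)                     # source point clouds
xyz2 = torch.rand(2, 2048, 3)                     # target point clouds
dist1 = torch.zeros(2, 1024)                      # filled with squared NN distances xyz1 -> xyz2
dist2 = torch.zeros(2, 2048)                      # filled with squared NN distances xyz2 -> xyz1
idx1 = torch.zeros(2, 1024, dtype=torch.int32)    # index of the nearest neighbour in xyz2
idx2 = torch.zeros(2, 2048, dtype=torch.int32)    # index of the nearest neighbour in xyz1

cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)  # CPU brute-force nnsearch defined above
chamfer = dist1.mean() + dist2.mean()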