├── LICENSE ├── README.md ├── assets └── pipeline.png ├── configs └── config.py ├── data └── Real │ └── train │ └── mug_handle.pkl ├── datasets └── datasets_genpose.py ├── networks ├── decoder_head │ ├── rot_head.py │ └── trans_head.py ├── gf_algorithms │ ├── energynet.py │ ├── losses.py │ ├── samplers.py │ ├── score_utils.py │ ├── scorenet.py │ └── sde.py ├── posenet.py ├── posenet_agent.py ├── pts_encoder │ ├── pointnet2.py │ ├── pointnet2_utils │ │ ├── .gitignore │ │ ├── LICENSE │ │ ├── README.md │ │ ├── pointnet2 │ │ │ ├── pointnet2_modules.py │ │ │ ├── pointnet2_utils.py │ │ │ ├── pytorch_utils.py │ │ │ ├── setup.py │ │ │ └── src │ │ │ │ ├── ball_query.cpp │ │ │ │ ├── ball_query_gpu.cu │ │ │ │ ├── ball_query_gpu.h │ │ │ │ ├── cuda_utils.h │ │ │ │ ├── group_points.cpp │ │ │ │ ├── group_points_gpu.cu │ │ │ │ ├── group_points_gpu.h │ │ │ │ ├── interpolate.cpp │ │ │ │ ├── interpolate_gpu.cu │ │ │ │ ├── interpolate_gpu.h │ │ │ │ ├── pointnet2_api.cpp │ │ │ │ ├── sampling.cpp │ │ │ │ ├── sampling_gpu.cu │ │ │ │ └── sampling_gpu.h │ │ └── tools │ │ │ ├── _init_path.py │ │ │ ├── data │ │ │ └── KITTI │ │ │ │ └── ImageSets │ │ │ │ ├── test.txt │ │ │ │ ├── train.txt │ │ │ │ ├── trainval.txt │ │ │ │ └── val.txt │ │ │ ├── dataset.py │ │ │ ├── kitti_utils.py │ │ │ ├── pointnet2_msg.py │ │ │ └── train_and_eval.py │ └── pointnets.py └── reward.py ├── requirements.txt ├── runners ├── evaluation_single.py ├── evaluation_tracking.py └── trainer.py ├── scripts ├── eval_single.sh ├── eval_tracking.sh ├── tensorboard.sh ├── train_energy.sh └── train_score.sh └── utils ├── data_augmentation.py ├── datasets_utils.py ├── genpose_utils.py ├── metrics.py ├── misc.py ├── sgpa_utils.py ├── so3_visualize.py ├── tracking_utils.py └── visualize.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Jiyao Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GenPose: Generative Category-level Object Pose Estimation via Diffusion Models 2 | 3 | 4 | [![Website](https://img.shields.io/badge/Website-orange.svg )](https://sites.google.com/view/genpose) 5 | [![Arxiv](https://img.shields.io/badge/Arxiv-green.svg )](https://arxiv.org/pdf/2306.10531.pdf) 6 | [![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FJiyao06%2FGenPose&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)](https://hits.seeyoufarm.com) 7 | [![GitHub license](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/Jiyao06/GenPose/blob/main/LICENSE) 8 | [![SOTA](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/genpose-generative-category-level-object-pose/6d-pose-estimation-using-rgbd-on-real275)](https://paperswithcode.com/sota/6d-pose-estimation-using-rgbd-on-real275?p=genpose-generative-category-level-object-pose) 9 | 10 | The official Pytorch implementation of the NeurIPS 2023 paper, [GenPose](https://arxiv.org/pdf/2306.10531.pdf). 11 | 12 | 13 | ## News 14 | - **2024/08**: [Omni6DPoseAPI](https://github.com/Omni6DPose/Omni6DPoseAPI) and [genpose++](https://github.com/Omni6DPose/GenPose2) are released! 15 | - **2024/07** [Omni6DPose](https://jiyao06.github.io/Omni6DPose/), the largest and most diverse universal 6D object pose estimation benchmark, gets accepted to ECCV 2024. 16 | - **2023/09**: GenPose gets accepted to NeurIPS 2023 17 | 18 | ## Overview 19 | 20 | ![Pipeline](./assets/pipeline.png) 21 | 22 | **(I)** A score-based diffusion model and an energy-based diffusion model is trained via denoising score-matching. 23 | **(II)** a) We first generate pose candidates from the score-based model and then b) compute the pose energies for candidates via the energy-based model. 24 | c) Finally, we rank the candidates with the energies and then filter out low-ranking candidates. 25 | The remaining candidates are aggregated into the final output by mean-pooling. 26 | 27 | Contents of this repo are as follows: 28 | 29 | * [Overview](#overview) 30 | * [TODO](#todo) 31 | * [Requirements](#requirements) 32 | * [Installation](#installation) 33 | + [Install pytorch](#install-pytorch) 34 | + [Install pytorch3d from a local clone](#install-pytorch3d-from-a-local-clone) 35 | + [Install from requirements.txt](#install-from-requirementstxt) 36 | + [Compile pointnet2](#compile-pointnet2) 37 | * [Download dataset and models](#download-dataset-and-models) 38 | * [Training](#training) 39 | + [Score network](#score-network) 40 | + [Energy network](#energy-network) 41 | * [Evaluation](#evaluation) 42 | + [Evaluate on REAL275 dataset.](#evaluate-on-real275-dataset) 43 | + [Evaluate on CAMERA dataset.](#evaluate-on-camera-dataset) 44 | * [Citation](#citation) 45 | * [Contact](#contact) 46 | * [License](#license) 47 | 48 | ## UPDATE 49 | - Release the code for object pose tracking. 50 | - Release the preprocessed data for object pose tracking. 
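For a quick intuition of step **(II)** in the Overview above, the ranking-and-aggregation of pose candidates can be sketched in a few lines of PyTorch. This is a minimal illustration, not the repository's implementation: the tensor layout (two rotation axes plus a translation, matching `--pose_mode rot_matrix`), the `keep_ratio` value, and the naive re-normalization of the pooled rotation axes are assumptions made for readability; see the `--ranker` and `--pooling_mode` options in `configs/config.py` for the settings the evaluation scripts actually expose.

``` python
import torch

def aggregate_pose_candidates(candidates: torch.Tensor,
                              energies: torch.Tensor,
                              keep_ratio: float = 0.6) -> torch.Tensor:
    """Rank candidates by energy, drop the low-ranking ones, mean-pool the rest.

    candidates: [K, 9] poses as (Rx, Ry, T) -- two rotation axes followed by translation.
    energies:   [K] scalar energies from the energy network (assumed: higher = more likely pose).
    keep_ratio: illustrative fraction of candidates kept after ranking.
    """
    k_keep = max(1, int(keep_ratio * candidates.shape[0]))
    top_idx = torch.topk(energies, k_keep).indices   # highest-energy candidates survive the filter
    kept = candidates[top_idx]                       # [k_keep, 9]

    pooled = kept.mean(dim=0)                        # mean-pooling of the surviving candidates
    # re-normalize the two pooled rotation axes so the result remains a valid 6D rotation
    pooled[:3] = pooled[:3] / pooled[:3].norm()
    pooled[3:6] = pooled[3:6] / pooled[3:6].norm()
    return pooled
```

In the full pipeline the candidates come from sampling the score network (step II-a) and the energies from the energy network (step II-b); the snippet only covers the final ranking, filtering, and mean-pooling (step II-c).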
51 | 52 | ## Requirements 53 | - Ubuntu 20.04 54 | - Python 3.8.15 55 | - Pytorch 1.12.0 56 | - Pytorch3d 0.7.2 57 | - CUDA 11.3 58 | - 1 * NVIDIA RTX 3090 59 | 60 | ## Installation 61 | 62 | - ### Install pytorch 63 | ``` bash 64 | pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 65 | ``` 66 | 67 | 68 | - ### Install pytorch3d from a local clone 69 | ``` bash 70 | git clone https://github.com/facebookresearch/pytorch3d.git 71 | cd pytorch3d 72 | git checkout -f v0.7.2 73 | pip install -e . 74 | ``` 75 | 76 | - ### Install from requirements.txt 77 | ``` bash 78 | pip install -r requirements.txt 79 | ``` 80 | 81 | - ### Compile pointnet2 82 | ``` bash 83 | cd networks/pts_encoder/pointnet2_utils/pointnet2 84 | python setup.py install 85 | ``` 86 | 87 | ## Download dataset and models 88 | - Download camera_train, camera_val, real_train, real_test, ground-truth annotations and mesh models provided by NOCS and unzip the data. Then move the file "mug_handle.pkl" from this repository's "data/Real/train" folder to the corresponding unzipped folders. The file "mug_handle.pkl" is provided by [GPV-Pose](https://github.com/lolrudy/GPV_Pose/blob/master/mug_handle.pkl). Organize these files in $ROOT/data as follows: 89 | ``` bash 90 | data 91 | ├── CAMERA 92 | │ ├── train 93 | │ └── val 94 | ├── Real 95 | │ ├── train 96 | │ │ ├── mug_handle.pkl 97 | │ │ └── ... 98 | │ └── test 99 | ├── gts 100 | │ ├── val 101 | │ └── real_test 102 | └── obj_models 103 | ├── train 104 | ├── val 105 | ├── real_train 106 | └── real_test 107 | ``` 108 | 109 | - Preprocess NOCS files following SPD. 110 | 111 | We provide the preprocessed testing data (REAL275) and checkpoints here for a quick evaluation. Download and organize the files in $ROOT/results as follows: 112 | ``` bash 113 | results 114 | ├── ckpts 115 | │ ├── EnergyNet 116 | │ │ └── ckpt_genpose.pth 117 | │ └── ScoreNet 118 | │ └── ckpt_genpose.pth 119 | ├── evaluation_results 120 | │ ├── segmentation_logs_real_test.txt 121 | │ └── segmentation_results_real_test.pkl 122 | └── mrcnn_results 123 | ├── aligned_real_test 124 | ├── real_test 125 | └── val 126 | ``` 127 | The *ckpts* are the trained models of GenPose. 128 | 129 | The *evaluation_results* are the preprocessed testing data, which contains the segmentation results of Mask R-CNN, the segmented pointclouds of obejcts, and the ground-truth poses. 130 | 131 | The file *mrcnn_results* represents the segmentation results provided by SPD, and you also can find it here. Note that the file *mrcnn_results/aligned_real_test* contains the manually aligned segmentation results, used for object pose tracking. 132 | 133 | **Note**: You need to preprocess the dataset as mentioned before first if you want to evaluate on CAMERA dataset. 134 | 135 | ## Training 136 | Set the parameter '--data_path' in scripts/train_score.sh and scripts/train_energy.sh to your own path of NOCS dataset. 137 | 138 | - ### Score network 139 | Train the score network to generate the pose candidates. 140 | ``` bash 141 | bash scripts/train_score.sh 142 | ``` 143 | - ### Energy network 144 | Train the energy network to aggragate the pose candidates. 145 | ``` bash 146 | bash scripts/train_energy.sh 147 | ``` 148 | 149 | ## Evaluation 150 | Set the parameter *--data_path* in *scripts/eval_single.sh* and *scripts/eval_tracking* to your own path of NOCS dataset. 151 | 152 | - ### Evaluate on REAL275 dataset. 
153 | Set the parameter *--test_source* in *scripts/eval_single.sh* to *'real_test'* and run: 154 | ``` bash 155 | bash scripts/eval_single.sh 156 | ``` 157 | - ### Evaluate on CAMERA dataset. 158 | Set the parameter *--test_source* in *scripts/eval_single.sh* to *'val'* and run: 159 | ``` bash 160 | bash scripts/eval_single.sh 161 | ``` 162 | - ### Pose tracking on REAL275 dataset. 163 | ``` bash 164 | bash scripts/eval_tracking.sh 165 | ``` 166 | 167 | ## Citation 168 | If you find our work useful in your research, please consider citing: 169 | ``` bash 170 | @article{zhang2024generative, 171 | title={Generative category-level object pose estimation via diffusion models}, 172 | author={Zhang, Jiyao and Wu, Mingdong and Dong, Hao}, 173 | journal={Advances in Neural Information Processing Systems}, 174 | volume={36}, 175 | year={2024} 176 | } 177 | ``` 178 | 179 | ## Contact 180 | If you have any questions, please feel free to contact us: 181 | 182 | [Jiyao Zhang](https://jiyao06.github.io/): [jiyaozhang@stu.pku.edu.cn](mailto:jiyaozhang@stu.pku.edu.cn) 183 | 184 | [Mingdong Wu](https://aaronanima.github.io/): [wmingd@pku.edu.cn](mailto:wmingd@pku.edu.cn) 185 | 186 | [Hao Dong](https://zsdonghao.github.io/): [hao.dong@pku.edu.cn](mailto:hao.dong@pku.edu.cn) 187 | 188 | ## License 189 | This project is released under the MIT license. See [LICENSE](LICENSE) for additional details. 190 | -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiyao06/GenPose/3a374e9a1a1dddc825b7d507ebcf38f2e3510aec/assets/pipeline.png -------------------------------------------------------------------------------- /configs/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from ipdb import set_trace 3 | 4 | def get_config(): 5 | parser = argparse.ArgumentParser() 6 | 7 | """ dataset """ 8 | parser.add_argument('--synset_names', nargs='+', default=['bottle', 'bowl', 'camera', 'can', 'laptop', 'mug']) 9 | parser.add_argument('--selected_classes', nargs='+') 10 | parser.add_argument('--data_path', type=str) 11 | parser.add_argument('--o2c_pose', default=True, action='store_true') 12 | parser.add_argument('--batch_size', type=int, default=192) 13 | parser.add_argument('--max_batch_size', type=int, default=192) 14 | parser.add_argument('--mini_bs', type=int, default=192) 15 | parser.add_argument('--pose_mode', type=str, default='rot_matrix') 16 | parser.add_argument('--seed', type=int, default=0) 17 | parser.add_argument('--percentage_data_for_train', type=float, default=1.0) 18 | parser.add_argument('--percentage_data_for_val', type=float, default=1.0) 19 | parser.add_argument('--percentage_data_for_test', type=float, default=1.0) 20 | parser.add_argument('--train_source', type=str, default='CAMERA+Real') 21 | parser.add_argument('--val_source', type=str, default='CAMERA') 22 | parser.add_argument('--test_source', type=str, default='Real') 23 | parser.add_argument('--device', type=str, default='cuda') 24 | parser.add_argument('--num_points', type=int, default=1024) 25 | parser.add_argument('--per_obj', type=str, default='') 26 | parser.add_argument('--num_workers', type=int, default=32) 27 | 28 | 29 | """ model """ 30 | parser.add_argument('--posenet_mode', type=str, default='score') 31 | parser.add_argument('--hidden_dim', type=int, default=128) 32 | parser.add_argument('--sampler_mode', 
nargs='+') 33 | parser.add_argument('--sampling_steps', type=int) 34 | parser.add_argument('--sde_mode', type=str, default='ve') 35 | parser.add_argument('--sigma', type=float, default=25) # base-sigma for SDE 36 | parser.add_argument('--likelihood_weighting', default=False, action='store_true') 37 | parser.add_argument('--regression_head', type=str, default='Rx_Ry_and_T') 38 | parser.add_argument('--pointnet2_params', type=str, default='light') 39 | parser.add_argument('--pts_encoder', type=str, default='pointnet2') 40 | parser.add_argument('--energy_mode', type=str, default='IP') 41 | parser.add_argument('--s_theta_mode', type=str, default='score') 42 | parser.add_argument('--norm_energy', type=str, default='identical') 43 | 44 | 45 | """ training """ 46 | parser.add_argument('--agent_type', type=str, default='score', help='one of the [score, energy, energy_with_ranking]') 47 | parser.add_argument('--pretrained_score_model_path', type=str) 48 | parser.add_argument('--pretrained_energy_model_path', type=str) 49 | parser.add_argument('--distillation', default=False, action='store_true') 50 | parser.add_argument('--n_epochs', type=int, default=1000) 51 | parser.add_argument('--log_dir', type=str, default='debug') 52 | parser.add_argument('--optimizer', type=str, default='Adam') 53 | parser.add_argument('--eval_freq', type=int, default=100) 54 | parser.add_argument('--repeat_num', type=int, default=20) 55 | parser.add_argument('--grad_clip', type=float, default=1.) 56 | parser.add_argument('--ema_rate', type=float, default=0.999) 57 | parser.add_argument('--lr', type=float, default=1e-3) 58 | parser.add_argument('--warmup', type=int, default=100) 59 | parser.add_argument('--lr_decay', type=float, default=0.98) 60 | parser.add_argument('--use_pretrain', default=False, action='store_true') 61 | parser.add_argument('--parallel', default=False, action='store_true') 62 | parser.add_argument('--num_gpu', type=int, default=4) 63 | parser.add_argument('--is_train', default=False, action='store_true') 64 | 65 | 66 | """ testing """ 67 | parser.add_argument('--eval', default=False, action='store_true') 68 | parser.add_argument('--pred', default=False, action='store_true') 69 | parser.add_argument('--model_name', type=str) 70 | parser.add_argument('--eval_repeat_num', type=int, default=50) 71 | parser.add_argument('--save_video', default=False, action='store_true') 72 | parser.add_argument('--max_eval_num', type=int, default=10000000) 73 | parser.add_argument('--results_path', type=str, default='') 74 | parser.add_argument('--T0', type=float, default=1.0) 75 | 76 | 77 | """ nocs_mrcnn testing""" 78 | parser.add_argument('--img_size', type=int, default=256, help='cropped image size') 79 | parser.add_argument('--result_dir', type=str, default='', help='result directory') 80 | parser.add_argument('--model_dir_list', nargs='+') 81 | parser.add_argument('--energy_model_dir', type=str, default='', help='energy network ckpt directory') 82 | parser.add_argument('--score_model_dir', type=str, default='', help='score network ckpt directory') 83 | parser.add_argument('--ranker', type=str, default='energy_ranker', help='energy_ranker, gt_ranker or random') 84 | parser.add_argument('--pooling_mode', type=str, default='nearest', help='nearest or average') 85 | 86 | 87 | cfg = parser.parse_args() 88 | 89 | # dynamic zoom in parameters 90 | cfg.DYNAMIC_ZOOM_IN_PARAMS = { 91 | 'DZI_PAD_SCALE': 1.5, 92 | 'DZI_TYPE': 'uniform', 93 | 'DZI_SCALE_RATIO': 0.25, 94 | 'DZI_SHIFT_RATIO': 0.25 95 | } 96 | 97 | # pts aug 
parameters 98 | cfg.PTS_AUG_PARAMS = { 99 | 'aug_pc_pro': 0.2, 100 | 'aug_pc_r': 0.2, 101 | 'aug_rt_pro': 0.3, 102 | 'aug_bb_pro': 0.3, 103 | 'aug_bc_pro': 0.3 104 | } 105 | 106 | # 2D aug parameters 107 | cfg.DEFORM_2D_PARAMS = { 108 | 'roi_mask_r': 3, 109 | 'roi_mask_pro': 0.5 110 | } 111 | 112 | return cfg 113 | 114 | -------------------------------------------------------------------------------- /data/Real/train/mug_handle.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiyao06/GenPose/3a374e9a1a1dddc825b7d507ebcf38f2e3510aec/data/Real/train/mug_handle.pkl -------------------------------------------------------------------------------- /networks/decoder_head/rot_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from ipdb import set_trace 5 | 6 | 7 | class RotHead(nn.Module): 8 | def __init__(self, in_feat_dim, out_dim=3): 9 | super(RotHead, self).__init__() 10 | self.f = in_feat_dim 11 | self.k = out_dim 12 | 13 | self.conv1 = torch.nn.Conv1d(self.f, 1024, 1) 14 | self.conv2 = torch.nn.Conv1d(1024, 256, 1) 15 | self.conv3 = torch.nn.Conv1d(256, 256, 1) 16 | self.conv4 = torch.nn.Conv1d(256, self.k, 1) 17 | self.drop1 = nn.Dropout(0.2) 18 | self.bn1 = nn.BatchNorm1d(1024) 19 | self.bn2 = nn.BatchNorm1d(256) 20 | self.bn3 = nn.BatchNorm1d(256) 21 | 22 | def forward(self, x): 23 | x = F.relu(self.bn1(self.conv1(x))) 24 | x = F.relu(self.bn2(self.conv2(x))) 25 | 26 | x = torch.max(x, 2, keepdim=True)[0] 27 | 28 | x = F.relu(self.bn3(self.conv3(x))) 29 | x = self.drop1(x) 30 | x = self.conv4(x) 31 | 32 | x = x.squeeze(2) 33 | x = x.contiguous() 34 | 35 | return x 36 | 37 | 38 | def main(): 39 | points = torch.rand(2, 1350, 1024) # batchsize x feature x numofpoint 40 | rot_head = RotHead(in_feat_dim=1350, out_dim=3) 41 | rot = rot_head(points) 42 | print(rot.shape) 43 | 44 | 45 | if __name__ == "__main__": 46 | main() -------------------------------------------------------------------------------- /networks/decoder_head/trans_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from ipdb import set_trace 5 | 6 | # Point_center encode the segmented point cloud 7 | # one more conv layer compared to original paper 8 | 9 | class TransHead(nn.Module): 10 | def __init__(self, in_feat_dim, out_dim=3): 11 | super(TransHead, self).__init__() 12 | self.f = in_feat_dim 13 | self.k = out_dim 14 | 15 | self.conv1 = torch.nn.Conv1d(self.f, 1024, 1) 16 | 17 | self.conv2 = torch.nn.Conv1d(1024, 256, 1) 18 | self.conv3 = torch.nn.Conv1d(256, 256, 1) 19 | self.conv4 = torch.nn.Conv1d(256, self.k, 1) 20 | self.drop1 = nn.Dropout(0.2) 21 | self.bn1 = nn.BatchNorm1d(1024) 22 | self.bn2 = nn.BatchNorm1d(256) 23 | self.bn3 = nn.BatchNorm1d(256) 24 | self.relu1 = nn.ReLU() 25 | self.relu2 = nn.ReLU() 26 | self.relu3 = nn.ReLU() 27 | 28 | def forward(self, x): 29 | x = self.relu1(self.bn1(self.conv1(x))) 30 | x = self.relu2(self.bn2(self.conv2(x))) 31 | 32 | x = torch.max(x, 2, keepdim=True)[0] 33 | 34 | x = self.relu3(self.bn3(self.conv3(x))) 35 | x = self.drop1(x) 36 | x = self.conv4(x) 37 | 38 | x = x.squeeze(2) 39 | x = x.contiguous() 40 | return x 41 | 42 | 43 | def main(): 44 | feature = torch.rand(10, 1896, 1000) 45 | net = TransHead(in_feat_dim=1896, out_dim=3) 46 | out = net(feature) 47 | 
print(out.shape) 48 | 49 | if __name__ == "__main__": 50 | main() -------------------------------------------------------------------------------- /networks/gf_algorithms/energynet.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | 5 | from ipdb import set_trace 6 | from torch.autograd import Variable 7 | from networks.gf_algorithms.samplers import cond_ode_likelihood, cond_ode_sampler, cond_pc_sampler 8 | from networks.gf_algorithms.scorenet import GaussianFourierProjection 9 | from networks.pts_encoder.pointnet2 import Pointnet2ClsMSG 10 | from networks.pts_encoder.pointnets import PointNetfeat 11 | from utils.genpose_utils import get_pose_dim 12 | 13 | 14 | def zero_module(module): 15 | """ 16 | Zero out the parameters of a module and return it. 17 | """ 18 | for p in module.parameters(): 19 | p.detach().zero_() 20 | return module 21 | 22 | 23 | class TemporaryGrad(object): 24 | def __enter__(self): 25 | self.prev = torch.is_grad_enabled() 26 | torch.set_grad_enabled(True) 27 | 28 | def __exit__(self, exc_type, exc_value, traceback) -> None: 29 | torch.set_grad_enabled(self.prev) 30 | 31 | 32 | class PoseEnergyNet(nn.Module): 33 | def __init__( 34 | self, 35 | marginal_prob_func, 36 | pose_mode='quat_wxyz', 37 | regression_head='RT', 38 | device='cuda', 39 | energy_mode='L2', # ['DAE', 'L2', 'IP'] 40 | s_theta_mode='score', # ['score', 'decoder', 'identical']) 41 | norm_energy='identical' # ['identical', 'std', 'minus'] 42 | ): 43 | super(PoseEnergyNet, self).__init__() 44 | self.device = device 45 | self.regression_head = regression_head 46 | self.act = nn.ReLU(True) 47 | self.pose_dim = get_pose_dim(pose_mode) 48 | self.energy_mode = energy_mode 49 | self.s_theta_mode = s_theta_mode 50 | self.norm_energy = norm_energy 51 | 52 | ''' encode pose ''' 53 | self.pose_encoder = nn.Sequential( 54 | nn.Linear(self.pose_dim, 256), 55 | self.act, 56 | nn.Linear(256, 256), 57 | self.act, 58 | ) 59 | 60 | ''' encode t ''' 61 | self.t_encoder = nn.Sequential( 62 | GaussianFourierProjection(embed_dim=128), 63 | # self.act, 64 | nn.Linear(128, 128), 65 | self.act, 66 | ) 67 | 68 | ''' fusion tail ''' 69 | if self.regression_head == 'RT': 70 | self.fusion_tail = nn.Sequential( 71 | nn.Linear(128+256+1024, 512), 72 | self.act, 73 | zero_module(nn.Linear(512, self.pose_dim)), 74 | ) 75 | 76 | 77 | elif self.regression_head == 'R_and_T': 78 | ''' rotation regress head ''' 79 | self.fusion_tail_rot = nn.Sequential( 80 | nn.Linear(128+256+1024, 256), 81 | # nn.BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 82 | self.act, 83 | zero_module(nn.Linear(256, self.pose_dim-3)), 84 | ) 85 | 86 | ''' tranalation regress head ''' 87 | self.fusion_tail_trans = nn.Sequential( 88 | nn.Linear(128+256+1024, 256), 89 | # nn.BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 90 | self.act, 91 | zero_module(nn.Linear(256, 3)), 92 | ) 93 | 94 | 95 | elif self.regression_head == 'Rx_Ry_and_T': 96 | if pose_mode != 'rot_matrix': 97 | raise NotImplementedError 98 | ''' rotation_x_axis regress head ''' 99 | self.fusion_tail_rot_x = nn.Sequential( 100 | nn.Linear(128+256+1024, 256), 101 | # nn.BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 102 | self.act, 103 | zero_module(nn.Linear(256, 3)), 104 | ) 105 | self.fusion_tail_rot_y = nn.Sequential( 106 | nn.Linear(128+256+1024, 256), 107 | # nn.BatchNorm1d(256, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True), 108 | self.act, 109 | zero_module(nn.Linear(256, 3)), 110 | ) 111 | 112 | ''' tranalation regress head ''' 113 | self.fusion_tail_trans = nn.Sequential( 114 | nn.Linear(128+256+1024, 256), 115 | # nn.BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 116 | self.act, 117 | zero_module(nn.Linear(256, 3)), 118 | ) 119 | 120 | else: 121 | raise NotImplementedError 122 | 123 | 124 | self.marginal_prob_func = marginal_prob_func 125 | 126 | 127 | def output_zero_initial(self): 128 | if self.regression_head == 'RT': 129 | zero_module(self.fusion_tail[-1]) 130 | 131 | elif self.regression_head == 'R_and_T': 132 | zero_module(self.fusion_tail_rot[-1]) 133 | zero_module(self.fusion_tail_trans[-1]) 134 | 135 | elif self.regression_head == 'Rx_Ry_and_T': 136 | zero_module(self.fusion_tail_rot_x[-1]) 137 | zero_module(self.fusion_tail_rot_y[-1]) 138 | zero_module(self.fusion_tail_trans[-1]) 139 | else: 140 | raise NotImplementedError 141 | 142 | 143 | def get_energy(self, pts_feat, sampled_pose, t, decoupled_rt=True): 144 | t_feat = self.t_encoder(t.squeeze(1)) 145 | pose_feat = self.pose_encoder(sampled_pose) 146 | 147 | total_feat = torch.cat([pts_feat, t_feat, pose_feat], dim=-1) 148 | _, std = self.marginal_prob_func(total_feat, t) 149 | 150 | ''' get f_{theta} ''' 151 | if self.regression_head == 'RT': 152 | f_theta = self.fusion_tail(total_feat) 153 | elif self.regression_head == 'R_and_T': 154 | rot = self.fusion_tail_rot(total_feat) 155 | trans = self.fusion_tail_trans(total_feat) 156 | f_theta = torch.cat([rot, trans], dim=-1) 157 | elif self.regression_head == 'Rx_Ry_and_T': 158 | rot_x = self.fusion_tail_rot_x(total_feat) 159 | rot_y = self.fusion_tail_rot_y(total_feat) 160 | trans = self.fusion_tail_trans(total_feat) 161 | f_theta = torch.cat([rot_x, rot_y, trans], dim=-1) 162 | else: 163 | raise NotImplementedError 164 | 165 | ''' get s_{theta} ''' 166 | if self.s_theta_mode == 'score': 167 | s_theta = f_theta / std 168 | elif self.s_theta_mode == 'decoder': 169 | s_theta = sampled_pose - std * f_theta 170 | elif self.s_theta_mode == 'identical': 171 | s_theta = f_theta 172 | else: 173 | raise NotImplementedError 174 | 175 | ''' get energy ''' 176 | if self.energy_mode == 'DAE': 177 | energy = - 0.5 * torch.sum((sampled_pose - s_theta) ** 2, dim=-1) 178 | elif self.energy_mode == 'L2': 179 | energy = - 0.5 * torch.sum(s_theta ** 2, dim=-1) 180 | elif self.energy_mode == 'IP': # Inner Product 181 | energy = torch.sum(sampled_pose * s_theta, dim=-1) 182 | if decoupled_rt: 183 | energy_rot = torch.sum(sampled_pose[:, :-3] * s_theta[:, :-3], dim=-1) 184 | energy_trans = torch.sum(sampled_pose[:, -3:] * s_theta[:, -3:], dim=-1) 185 | energy = torch.cat((energy_rot.unsqueeze(-1), energy_trans.unsqueeze(-1)), dim=-1) 186 | else: 187 | raise NotImplementedError 188 | 189 | ''' normalisation ''' 190 | if self.norm_energy == 'identical': 191 | pass 192 | elif self.norm_energy == 'std': 193 | energy = energy / (std + 1e-7) 194 | elif self.norm_energy == 'minus': # Inner Product 195 | energy = - energy 196 | else: 197 | raise NotImplementedError 198 | return energy 199 | 200 | 201 | def forward(self, data, return_item='score'): 202 | pts_feat = data['pts_feat'] 203 | sampled_pose = data['sampled_pose'] 204 | t = data['t'] 205 | 206 | if return_item == 'energy': 207 | energy = self.get_energy(pts_feat, sampled_pose, t) 208 | return energy 209 | 210 | with TemporaryGrad(): 211 | inp_variable_sampled_pose = Variable(sampled_pose, 
requires_grad=True) 212 | energy = self.get_energy(pts_feat, inp_variable_sampled_pose, t, decoupled_rt=False) 213 | scores, = torch.autograd.grad(energy, inp_variable_sampled_pose, 214 | grad_outputs=energy.data.new(energy.shape).fill_(1), 215 | create_graph=True) 216 | # inp_variable_sampled_pose = None # release the variable 217 | if return_item == 'score': 218 | return scores 219 | elif return_item == 'score_and_energy': 220 | return scores, energy 221 | else: 222 | raise NotImplementedError 223 | 224 | 225 | -------------------------------------------------------------------------------- /networks/gf_algorithms/losses.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from ipdb import set_trace 4 | 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | 8 | def loss_fn_edm( 9 | model, 10 | data, 11 | marginal_prob_func, 12 | sde_fn, 13 | eps=1e-5, 14 | likelihood_weighting=False, 15 | P_mean=-1.2, 16 | P_std=1.2, 17 | sigma_data=1.4148, 18 | sigma_min=0.002, 19 | sigma_max=80, 20 | ): 21 | pts = data['zero_mean_pts'] 22 | y = data['zero_mean_gt_pose'] 23 | bs = pts.shape[0] 24 | 25 | # get noise n 26 | z = torch.randn_like(y) # [bs, pose_dim] 27 | # log_sigma_t = torch.randn([bs, 1], device=device) # [bs, 1] 28 | # sigma_t = (P_std * log_sigma_t + P_mean).exp() # [bs, 1] 29 | log_sigma_t = torch.rand([bs, 1], device=device) # [bs, 1] 30 | sigma_t = (math.log(sigma_min) + log_sigma_t * (math.log(sigma_max) - math.log(sigma_min))).exp() # [bs, 1] 31 | 32 | n = z * sigma_t 33 | 34 | perturbed_x = y + n # [bs, pose_dim] 35 | data['sampled_pose'] = perturbed_x 36 | data['t'] = sigma_t # t and sigma is interchangable in EDM 37 | data, output = model(data) # [bs, pose_dim] 38 | 39 | # set_trace() 40 | 41 | # same as VE 42 | loss_ = torch.mean(torch.sum(((output * sigma_t + z)**2).view(bs, -1), dim=-1)) 43 | 44 | return loss_ 45 | 46 | 47 | def loss_fn( 48 | model, 49 | data, 50 | marginal_prob_func, 51 | sde_fn, 52 | eps=1e-5, 53 | likelihood_weighting=False, 54 | teacher_model=None, 55 | pts_feat_teacher=None 56 | ): 57 | pts = data['zero_mean_pts'] 58 | gt_pose = data['zero_mean_gt_pose'] 59 | 60 | ''' get std ''' 61 | bs = pts.shape[0] 62 | random_t = torch.rand(bs, device=device) * (1. 
- eps) + eps # [bs, ] 63 | random_t = random_t.unsqueeze(-1) # [bs, 1] 64 | mu, std = marginal_prob_func(gt_pose, random_t) # [bs, pose_dim], [bs] 65 | std = std.view(-1, 1) # [bs, 1] 66 | 67 | ''' perturb data and get estimated score ''' 68 | z = torch.randn_like(gt_pose) # [bs, pose_dim] 69 | perturbed_x = mu + z * std # [bs, pose_dim] 70 | data['sampled_pose'] = perturbed_x 71 | data['t'] = random_t 72 | estimated_score = model(data) # [bs, pose_dim] 73 | 74 | ''' get target score ''' 75 | if teacher_model is None: 76 | # theoretic estimation 77 | target_score = - z * std / (std ** 2) 78 | else: 79 | # distillation 80 | pts_feat_student = data['pts_feat'].clone() 81 | data['pts_feat'] = pts_feat_teacher 82 | target_score = teacher_model(data) 83 | data['pts_feat'] = pts_feat_student 84 | 85 | ''' loss weighting ''' 86 | loss_weighting = std ** 2 87 | loss_ = torch.mean(torch.sum((loss_weighting * (estimated_score - target_score)**2).view(bs, -1), dim=-1)) 88 | 89 | return loss_ 90 | 91 | 92 | -------------------------------------------------------------------------------- /networks/gf_algorithms/samplers.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | from scipy import integrate 7 | from ipdb import set_trace 8 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 9 | from utils.genpose_utils import get_pose_dim 10 | from utils.misc import normalize_rotation 11 | 12 | 13 | def global_prior_likelihood(z, sigma_max): 14 | """The likelihood of a Gaussian distribution with mean zero and 15 | standard deviation sigma.""" 16 | # z: [bs, pose_dim] 17 | shape = z.shape 18 | N = np.prod(shape[1:]) # pose_dim 19 | return -N / 2. * torch.log(2*np.pi*sigma_max**2) - torch.sum(z**2, dim=-1) / (2 * sigma_max**2) 20 | 21 | 22 | def cond_ode_likelihood( 23 | score_model, 24 | data, 25 | prior, 26 | sde_coeff, 27 | marginal_prob_fn, 28 | atol=1e-5, 29 | rtol=1e-5, 30 | device='cuda', 31 | eps=1e-5, 32 | num_steps=None, 33 | pose_mode='quat_wxyz', 34 | init_x=None, 35 | ): 36 | 37 | pose_dim = get_pose_dim(pose_mode) 38 | batch_size = data['pts'].shape[0] 39 | epsilon = prior((batch_size, pose_dim)).to(device) 40 | init_x = data['sampled_pose'].clone().cpu().numpy() if init_x is None else init_x 41 | shape = init_x.shape 42 | init_logp = np.zeros((shape[0],)) # [bs] 43 | init_inp = np.concatenate([init_x.reshape(-1), init_logp], axis=0) 44 | 45 | def score_eval_wrapper(data): 46 | """A wrapper of the score-based model for use by the ODE solver.""" 47 | with torch.no_grad(): 48 | score = score_model(data) 49 | return score.cpu().numpy().reshape((-1,)) 50 | 51 | def divergence_eval(data, epsilon): 52 | """Compute the divergence of the score-based model with Skilling-Hutchinson.""" 53 | # save ckpt of sampled_pose 54 | origin_sampled_pose = data['sampled_pose'].clone() 55 | with torch.enable_grad(): 56 | # make sampled_pose differentiable 57 | data['sampled_pose'].requires_grad_(True) 58 | score = score_model(data) 59 | score_energy = torch.sum(score * epsilon) # [, ] 60 | grad_score_energy = torch.autograd.grad(score_energy, data['sampled_pose'])[0] # [bs, pose_dim] 61 | # reset sampled_pose 62 | data['sampled_pose'] = origin_sampled_pose 63 | return torch.sum(grad_score_energy * epsilon, dim=-1) # [bs, 1] 64 | 65 | def divergence_eval_wrapper(data): 66 | """A wrapper for evaluating the divergence of score for the black-box ODE solver.""" 67 | with torch.no_grad(): 68 | # Compute 
likelihood. 69 | div = divergence_eval(data, epsilon) # [bs, 1] 70 | return div.cpu().numpy().reshape((-1,)).astype(np.float64) 71 | 72 | def ode_func(t, inp): 73 | """The ODE function for use by the ODE solver.""" 74 | # split x, logp from inp 75 | x = inp[:-shape[0]] 76 | logp = inp[-shape[0]:] # haha, actually we do not need use logp here 77 | # calc x-grad 78 | x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device) 79 | time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t 80 | drift, diffusion = sde_coeff(torch.tensor(t)) 81 | drift = drift.cpu().numpy() 82 | diffusion = diffusion.cpu().numpy() 83 | data['sampled_pose'] = x 84 | data['t'] = time_steps 85 | x_grad = drift - 0.5 * (diffusion**2) * score_eval_wrapper(data) 86 | # calc logp-grad 87 | logp_grad = drift - 0.5 * (diffusion**2) * divergence_eval_wrapper(data) 88 | # concat curr grad 89 | return np.concatenate([x_grad, logp_grad], axis=0) 90 | 91 | # Run the black-box ODE solver, note the 92 | res = integrate.solve_ivp(ode_func, (eps, 1.0), init_inp, rtol=rtol, atol=atol, method='RK45') 93 | zp = torch.tensor(res.y[:, -1], device=device) # [bs * (pose_dim + 1)] 94 | z = zp[:-shape[0]].reshape(shape) # [bs, pose_dim] 95 | delta_logp = zp[-shape[0]:].reshape(shape[0]) # [bs,] logp 96 | _, sigma_max = marginal_prob_fn(None, torch.tensor(1.).to(device)) # we assume T = 1 97 | prior_logp = global_prior_likelihood(z, sigma_max) 98 | log_likelihoods = (prior_logp + delta_logp) / np.log(2) # negative log-likelihoods (nlls) 99 | return z, log_likelihoods 100 | 101 | 102 | def cond_pc_sampler( 103 | score_model, 104 | data, 105 | prior, 106 | sde_coeff, 107 | num_steps=500, 108 | snr=0.16, 109 | device='cuda', 110 | eps=1e-5, 111 | pose_mode='quat_wxyz', 112 | init_x=None, 113 | ): 114 | 115 | pose_dim = get_pose_dim(pose_mode) 116 | batch_size = data['pts'].shape[0] 117 | init_x = prior((batch_size, pose_dim)).to(device) if init_x is None else init_x 118 | time_steps = torch.linspace(1., eps, num_steps, device=device) 119 | step_size = time_steps[0] - time_steps[1] 120 | noise_norm = np.sqrt(pose_dim) 121 | x = init_x 122 | poses = [] 123 | with torch.no_grad(): 124 | for time_step in time_steps: 125 | batch_time_step = torch.ones(batch_size, device=device).unsqueeze(-1) * time_step 126 | # Corrector step (Langevin MCMC) 127 | data['sampled_pose'] = x 128 | data['t'] = batch_time_step 129 | grad = score_model(data) 130 | grad_norm = torch.norm(grad.reshape(batch_size, -1), dim=-1).mean() 131 | langevin_step_size = 2 * (snr * noise_norm / grad_norm)**2 132 | x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x) 133 | 134 | # normalisation 135 | if pose_mode == 'quat_wxyz' or pose_mode == 'quat_xyzw': 136 | # quat, should be normalised 137 | x[:, :4] /= torch.norm(x[:, :4], dim=-1, keepdim=True) 138 | elif pose_mode == 'euler_xyz': 139 | pass 140 | else: 141 | # rotation(x axis, y axis), should be normalised 142 | x[:, :3] /= torch.norm(x[:, :3], dim=-1, keepdim=True) 143 | x[:, 3:6] /= torch.norm(x[:, 3:6], dim=-1, keepdim=True) 144 | 145 | # Predictor step (Euler-Maruyama) 146 | drift, diffusion = sde_coeff(batch_time_step) 147 | drift = drift - diffusion**2*grad # R-SDE 148 | mean_x = x + drift * step_size 149 | x = mean_x + diffusion * torch.sqrt(step_size) * torch.randn_like(x) 150 | 151 | # normalisation 152 | x[:, :-3] = normalize_rotation(x[:, :-3], pose_mode) 153 | poses.append(x.unsqueeze(0)) 154 | 155 | xs = torch.cat(poses, dim=0) 156 | xs[:, :, -3:] 
+= data['pts_center'].unsqueeze(0).repeat(xs.shape[0], 1, 1) 157 | mean_x[:, -3:] += data['pts_center'] 158 | mean_x[:, :-3] = normalize_rotation(mean_x[:, :-3], pose_mode) 159 | # The last step does not include any noise 160 | return xs.permute(1, 0, 2), mean_x 161 | 162 | 163 | def cond_ode_sampler( 164 | score_model, 165 | data, 166 | prior, 167 | sde_coeff, 168 | atol=1e-5, 169 | rtol=1e-5, 170 | device='cuda', 171 | eps=1e-5, 172 | T=1.0, 173 | num_steps=None, 174 | pose_mode='quat_wxyz', 175 | denoise=True, 176 | init_x=None, 177 | ): 178 | pose_dim = get_pose_dim(pose_mode) 179 | batch_size=data['pts'].shape[0] 180 | init_x = prior((batch_size, pose_dim), T=T).to(device) if init_x is None else init_x + prior((batch_size, pose_dim), T=T).to(device) 181 | shape = init_x.shape 182 | 183 | def score_eval_wrapper(data): 184 | """A wrapper of the score-based model for use by the ODE solver.""" 185 | with torch.no_grad(): 186 | score = score_model(data) 187 | return score.cpu().numpy().reshape((-1,)) 188 | 189 | def ode_func(t, x): 190 | """The ODE function for use by the ODE solver.""" 191 | x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device) 192 | time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t 193 | drift, diffusion = sde_coeff(torch.tensor(t)) 194 | drift = drift.cpu().numpy() 195 | diffusion = diffusion.cpu().numpy() 196 | data['sampled_pose'] = x 197 | data['t'] = time_steps 198 | return drift - 0.5 * (diffusion**2) * score_eval_wrapper(data) 199 | 200 | # Run the black-box ODE solver, note the 201 | t_eval = None 202 | if num_steps is not None: 203 | # num_steps, from T -> eps 204 | t_eval = np.linspace(T, eps, num_steps) 205 | res = integrate.solve_ivp(ode_func, (T, eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45', t_eval=t_eval) 206 | xs = torch.tensor(res.y, device=device).T.view(-1, batch_size, pose_dim) # [num_steps, bs, pose_dim] 207 | x = torch.tensor(res.y[:, -1], device=device).reshape(shape) # [bs, pose_dim] 208 | # denoise, using the predictor step in P-C sampler 209 | if denoise: 210 | # Reverse diffusion predictor for denoising 211 | vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps 212 | drift, diffusion = sde_coeff(vec_eps) 213 | data['sampled_pose'] = x.float() 214 | data['t'] = vec_eps 215 | grad = score_model(data) 216 | drift = drift - diffusion**2*grad # R-SDE 217 | mean_x = x + drift * ((1-eps)/(1000 if num_steps is None else num_steps)) 218 | x = mean_x 219 | 220 | num_steps = xs.shape[0] 221 | xs = xs.reshape(batch_size*num_steps, -1) 222 | xs[:, :-3] = normalize_rotation(xs[:, :-3], pose_mode) 223 | xs = xs.reshape(num_steps, batch_size, -1) 224 | xs[:, :, -3:] += data['pts_center'].unsqueeze(0).repeat(xs.shape[0], 1, 1) 225 | x[:, :-3] = normalize_rotation(x[:, :-3], pose_mode) 226 | x[:, -3:] += data['pts_center'] 227 | return xs.permute(1, 0, 2), x 228 | 229 | 230 | def cond_edm_sampler( 231 | decoder_model, data, prior_fn, randn_like=torch.randn_like, 232 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, 233 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 234 | pose_mode='quat_wxyz', device='cuda' 235 | ): 236 | pose_dim = get_pose_dim(pose_mode) 237 | batch_size = data['pts'].shape[0] 238 | latents = prior_fn((batch_size, pose_dim)).to(device) 239 | 240 | # Time step discretization. 
note that sigma and t is interchangable 241 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 242 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 243 | t_steps = torch.cat([torch.as_tensor(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 244 | 245 | def decoder_wrapper(decoder, data, x, t): 246 | # save temp 247 | x_, t_= data['sampled_pose'], data['t'] 248 | # init data 249 | data['sampled_pose'], data['t'] = x, t 250 | # denoise 251 | data, denoised = decoder(data) 252 | # recover data 253 | data['sampled_pose'], data['t'] = x_, t_ 254 | return denoised.to(torch.float64) 255 | 256 | # Main sampling loop. 257 | x_next = latents.to(torch.float64) * t_steps[0] 258 | xs = [] 259 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 260 | x_cur = x_next 261 | 262 | # Increase noise temporarily. 263 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 264 | t_hat = torch.as_tensor(t_cur + gamma * t_cur) 265 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) 266 | 267 | # Euler step. 268 | denoised = decoder_wrapper(decoder_model, data, x_hat, t_hat) 269 | d_cur = (x_hat - denoised) / t_hat 270 | x_next = x_hat + (t_next - t_hat) * d_cur 271 | 272 | # Apply 2nd order correction. 273 | if i < num_steps - 1: 274 | denoised = decoder_wrapper(decoder_model, data, x_next, t_next) 275 | d_prime = (x_next - denoised) / t_next 276 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 277 | xs.append(x_next.unsqueeze(0)) 278 | 279 | xs = torch.stack(xs, dim=0) # [num_steps, bs, pose_dim] 280 | x = xs[-1] # [bs, pose_dim] 281 | 282 | # post-processing 283 | xs = xs.reshape(batch_size*num_steps, -1) 284 | xs[:, :-3] = normalize_rotation(xs[:, :-3], pose_mode) 285 | xs = xs.reshape(num_steps, batch_size, -1) 286 | xs[:, :, -3:] += data['pts_center'].unsqueeze(0).repeat(xs.shape[0], 1, 1) 287 | x[:, :-3] = normalize_rotation(x[:, :-3], pose_mode) 288 | x[:, -3:] += data['pts_center'] 289 | 290 | return xs.permute(1, 0, 2), x 291 | 292 | 293 | -------------------------------------------------------------------------------- /networks/gf_algorithms/score_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class ExponentialMovingAverage: 4 | """ 5 | Maintains (exponential) moving average of a set of parameters. 6 | """ 7 | 8 | def __init__(self, parameters, decay, use_num_updates=True): 9 | """ 10 | Args: 11 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 12 | `model.parameters()`. 13 | decay: The exponential decay. 14 | use_num_updates: Whether to use number of updates when computing 15 | averages. 16 | """ 17 | if decay < 0.0 or decay > 1.0: 18 | raise ValueError('Decay must be between 0 and 1') 19 | self.decay = decay 20 | self.num_updates = 0 if use_num_updates else None 21 | self.shadow_params = [p.clone().detach() 22 | for p in parameters if p.requires_grad] 23 | self.collected_params = [] 24 | 25 | def update(self, parameters): 26 | """ 27 | Update currently maintained parameters. 28 | 29 | Call this every time the parameters are updated, such as the result of 30 | the `optimizer.step()` call. 31 | 32 | Args: 33 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 34 | parameters used to initialize this object. 
35 | """ 36 | decay = self.decay 37 | if self.num_updates is not None: 38 | self.num_updates += 1 39 | decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) 40 | one_minus_decay = 1.0 - decay 41 | with torch.no_grad(): 42 | parameters = [p for p in parameters if p.requires_grad] 43 | for s_param, param in zip(self.shadow_params, parameters): 44 | s_param.sub_(one_minus_decay * (s_param - param)) # only update the ema-params 45 | 46 | 47 | def copy_to(self, parameters): 48 | """ 49 | Copy current parameters into given collection of parameters. 50 | 51 | Args: 52 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 53 | updated with the stored moving averages. 54 | """ 55 | parameters = [p for p in parameters if p.requires_grad] 56 | for s_param, param in zip(self.shadow_params, parameters): 57 | if param.requires_grad: 58 | param.data.copy_(s_param.data) 59 | 60 | def store(self, parameters): 61 | """ 62 | Save the current parameters for restoring later. 63 | 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | temporarily stored. 67 | """ 68 | self.collected_params = [param.clone() for param in parameters] 69 | 70 | def restore(self, parameters): 71 | """ 72 | Restore the parameters stored with the `store` method. 73 | Useful to validate the model with EMA parameters without affecting the 74 | original optimization process. Store the parameters before the 75 | `copy_to` method. After validation (or model saving), use this to 76 | restore the former parameters. 77 | 78 | Args: 79 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 80 | updated with the stored parameters. 81 | """ 82 | for c_param, param in zip(self.collected_params, parameters): 83 | param.data.copy_(c_param.data) 84 | 85 | def state_dict(self): 86 | return dict(decay=self.decay, num_updates=self.num_updates, 87 | shadow_params=self.shadow_params) 88 | 89 | def load_state_dict(self, state_dict): 90 | self.decay = state_dict['decay'] 91 | self.num_updates = state_dict['num_updates'] 92 | self.shadow_params = state_dict['shadow_params'] 93 | -------------------------------------------------------------------------------- /networks/gf_algorithms/sde.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import functools 4 | import torch 5 | import numpy as np 6 | from ipdb import set_trace 7 | from scipy import integrate 8 | from utils.genpose_utils import get_pose_dim 9 | 10 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 11 | 12 | 13 | #----- VE SDE ----- 14 | #------------------ 15 | def ve_marginal_prob(x, t, sigma_min=0.01, sigma_max=90): 16 | std = sigma_min * (sigma_max / sigma_min) ** t 17 | mean = x 18 | return mean, std 19 | 20 | def ve_sde(t, sigma_min=0.01, sigma_max=90): 21 | sigma = sigma_min * (sigma_max / sigma_min) ** t 22 | drift_coeff = torch.tensor(0) 23 | diffusion_coeff = sigma * torch.sqrt(torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=t.device)) 24 | return drift_coeff, diffusion_coeff 25 | 26 | def ve_prior(shape, sigma_min=0.01, sigma_max=90, T=1.0): 27 | _, sigma_max_prior = ve_marginal_prob(None, T, sigma_min=sigma_min, sigma_max=sigma_max) 28 | return torch.randn(*shape) * sigma_max_prior 29 | 30 | #----- VP SDE ----- 31 | #------------------ 32 | def vp_marginal_prob(x, t, beta_0=0.1, beta_1=20): 33 | log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0 34 | mean = torch.exp(log_mean_coeff) * x 35 | std = 
torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) 36 | return mean, std 37 | 38 | def vp_sde(t, beta_0=0.1, beta_1=20): 39 | beta_t = beta_0 + t * (beta_1 - beta_0) 40 | drift_coeff = -0.5 * beta_t 41 | diffusion_coeff = torch.sqrt(beta_t) 42 | return drift_coeff, diffusion_coeff 43 | 44 | def vp_prior(shape, beta_0=0.1, beta_1=20): 45 | return torch.randn(*shape) 46 | 47 | #----- sub-VP SDE ----- 48 | #---------------------- 49 | def subvp_marginal_prob(x, t, beta_0, beta_1): 50 | log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0 51 | mean = torch.exp(log_mean_coeff) * x 52 | std = 1 - torch.exp(2. * log_mean_coeff) 53 | return mean, std 54 | 55 | def subvp_sde(t, beta_0, beta_1): 56 | beta_t = beta_0 + t * (beta_1 - beta_0) 57 | drift_coeff = -0.5 * beta_t 58 | discount = 1. - torch.exp(-2 * beta_0 * t - (beta_1 - beta_0) * t ** 2) 59 | diffusion_coeff = torch.sqrt(beta_t * discount) 60 | return drift_coeff, diffusion_coeff 61 | 62 | def subvp_prior(shape, beta_0=0.1, beta_1=20): 63 | return torch.randn(*shape) 64 | 65 | #----- EDM SDE ----- 66 | #------------------ 67 | def edm_marginal_prob(x, t, sigma_min=0.002, sigma_max=80): 68 | std = t 69 | mean = x 70 | return mean, std 71 | 72 | def edm_sde(t, sigma_min=0.002, sigma_max=80): 73 | drift_coeff = torch.tensor(0) 74 | diffusion_coeff = torch.sqrt(2 * t) 75 | return drift_coeff, diffusion_coeff 76 | 77 | def edm_prior(shape, sigma_min=0.002, sigma_max=80): 78 | return torch.randn(*shape) * sigma_max 79 | 80 | def init_sde(sde_mode): 81 | # the SDE-related hyperparameters are copied from https://github.com/yang-song/score_sde_pytorch 82 | if sde_mode == 'edm': 83 | sigma_min = 0.002 84 | sigma_max = 80 85 | eps = 0.002 86 | prior_fn = functools.partial(edm_prior, sigma_min=sigma_min, sigma_max=sigma_max) 87 | marginal_prob_fn = functools.partial(edm_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max) 88 | sde_fn = functools.partial(edm_sde, sigma_min=sigma_min, sigma_max=sigma_max) 89 | T = sigma_max 90 | elif sde_mode == 've': 91 | sigma_min = 0.01 92 | sigma_max = 50 93 | eps = 1e-5 94 | marginal_prob_fn = functools.partial(ve_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max) 95 | sde_fn = functools.partial(ve_sde, sigma_min=sigma_min, sigma_max=sigma_max) 96 | T = 1.0 97 | prior_fn = functools.partial(ve_prior, sigma_min=sigma_min, sigma_max=sigma_max) 98 | elif sde_mode == 'vp': 99 | beta_0 = 0.1 100 | beta_1 = 20 101 | eps = 1e-3 102 | prior_fn = functools.partial(vp_prior, beta_0=beta_0, beta_1=beta_1) 103 | marginal_prob_fn = functools.partial(vp_marginal_prob, beta_0=beta_0, beta_1=beta_1) 104 | sde_fn = functools.partial(vp_sde, beta_0=beta_0, beta_1=beta_1) 105 | T = 1.0 106 | elif sde_mode == 'subvp': 107 | beta_0 = 0.1 108 | beta_1 = 20 109 | eps = 1e-3 110 | prior_fn = functools.partial(subvp_prior, beta_0=beta_0, beta_1=beta_1) 111 | marginal_prob_fn = functools.partial(subvp_marginal_prob, beta_0=beta_0, beta_1=beta_1) 112 | sde_fn = functools.partial(subvp_sde, beta_0=beta_0, beta_1=beta_1) 113 | T = 1.0 114 | else: 115 | raise NotImplementedError 116 | return prior_fn, marginal_prob_fn, sde_fn, eps, T 117 | 118 | -------------------------------------------------------------------------------- /networks/posenet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ipdb import set_trace 7 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 8 | from 
networks.pts_encoder.pointnets import PointNetfeat 9 | from networks.pts_encoder.pointnet2 import Pointnet2ClsMSG 10 | from networks.gf_algorithms.samplers import cond_ode_likelihood, cond_ode_sampler, cond_pc_sampler 11 | from networks.gf_algorithms.scorenet import PoseScoreNet, PoseDecoderNet 12 | from networks.gf_algorithms.energynet import PoseEnergyNet 13 | from networks.gf_algorithms.sde import init_sde 14 | from configs.config import get_config 15 | 16 | 17 | 18 | class GFObjectPose(nn.Module): 19 | def __init__(self, cfg, prior_fn, marginal_prob_fn, sde_fn, sampling_eps, T): 20 | super(GFObjectPose, self).__init__() 21 | 22 | self.cfg = cfg 23 | self.device = cfg.device 24 | self.is_testing = False 25 | 26 | ''' Load model, define SDE ''' 27 | # init SDE config 28 | self.prior_fn = prior_fn 29 | self.marginal_prob_fn = marginal_prob_fn 30 | self.sde_fn = sde_fn 31 | self.sampling_eps = sampling_eps 32 | self.T = T 33 | # self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps = init_sde(cfg.sde_mode) 34 | 35 | ''' encode pts ''' 36 | if self.cfg.pts_encoder == 'pointnet': 37 | self.pts_encoder = PointNetfeat(num_points=self.cfg.num_points, out_dim=1024) 38 | elif self.cfg.pts_encoder == 'pointnet2': 39 | self.pts_encoder = Pointnet2ClsMSG(0) 40 | elif self.cfg.pts_encoder == 'pointnet_and_pointnet2': 41 | self.pts_pointnet_encoder = PointNetfeat(num_points=self.cfg.num_points, out_dim=1024) 42 | self.pts_pointnet2_encoder = Pointnet2ClsMSG(0) 43 | self.fusion_layer = nn.Linear(2048, 1024) 44 | self.act = nn.ReLU() 45 | else: 46 | raise NotImplementedError 47 | 48 | ''' score network''' 49 | # if self.cfg.sde_mode == 'edm': 50 | # self.pose_score_net = PoseDecoderNet( 51 | # self.marginal_prob_fn, 52 | # sigma_data=1.4148, 53 | # pose_mode=self.cfg.pose_mode, 54 | # regression_head=self.cfg.regression_head 55 | # ) 56 | # else: 57 | per_point_feat = False 58 | if self.cfg.posenet_mode == 'score': 59 | self.pose_score_net = PoseScoreNet(self.marginal_prob_fn, self.cfg.pose_mode, self.cfg.regression_head, per_point_feat) 60 | elif self.cfg.posenet_mode == 'energy': 61 | self.pose_score_net = PoseEnergyNet( 62 | marginal_prob_func=self.marginal_prob_fn, 63 | pose_mode=self.cfg.pose_mode, 64 | regression_head=self.cfg.regression_head, 65 | energy_mode=self.cfg.energy_mode, 66 | s_theta_mode=self.cfg.s_theta_mode, 67 | norm_energy=self.cfg.norm_energy) 68 | ''' ToDo: ranking network ''' 69 | 70 | 71 | def extract_pts_feature(self, data): 72 | """extract the input pointcloud feature 73 | 74 | Args: 75 | data (dict): batch example without pointcloud feature. {'pts': [bs, num_pts, 3], 'sampled_pose': [bs, pose_dim], 't': [bs, 1]} 76 | Returns: 77 | data (dict): batch example with pointcloud feature. 
{'pts': [bs, num_pts, 3], 'pts_feat': [bs, c], 'sampled_pose': [bs, pose_dim], 't': [bs, 1]} 78 | """ 79 | pts = data['pts'] 80 | if self.cfg.pts_encoder == 'pointnet': 81 | pts_feat = self.pts_encoder(pts.permute(0, 2, 1)) # -> (bs, 3, 1024) 82 | elif self.cfg.pts_encoder in ['pointnet2']: 83 | pts_feat = self.pts_encoder(pts) 84 | elif self.cfg.pts_encoder == 'pointnet_and_pointnet2': 85 | pts_pointnet_feat = self.pts_pointnet_encoder(pts.permute(0, 2, 1)) 86 | pts_pointnet2_feat = self.pts_pointnet2_encoder(pts) 87 | pts_feat = self.fusion_layer(torch.cat((pts_pointnet_feat, pts_pointnet2_feat), dim=-1)) 88 | pts_feat = self.act(pts_feat) 89 | else: 90 | raise NotImplementedError 91 | return pts_feat 92 | 93 | 94 | def sample(self, data, sampler, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None): 95 | if sampler == 'pc': 96 | in_process_sample, res = cond_pc_sampler( 97 | score_model=self, 98 | data=data, 99 | prior=self.prior_fn, 100 | sde_coeff=self.sde_fn, 101 | num_steps=self.cfg.sampling_steps, 102 | snr=snr, 103 | device=self.device, 104 | eps=self.sampling_eps, 105 | pose_mode=self.cfg.pose_mode, 106 | init_x=init_x 107 | ) 108 | 109 | elif sampler == 'ode': 110 | T0 = self.T if T0 is None else T0 111 | in_process_sample, res = cond_ode_sampler( 112 | score_model=self, 113 | data=data, 114 | prior=self.prior_fn, 115 | sde_coeff=self.sde_fn, 116 | atol=atol, 117 | rtol=rtol, 118 | device=self.device, 119 | eps=self.sampling_eps, 120 | T=T0, 121 | num_steps=self.cfg.sampling_steps, 122 | pose_mode=self.cfg.pose_mode, 123 | denoise=denoise, 124 | init_x=init_x 125 | ) 126 | 127 | else: 128 | raise NotImplementedError 129 | 130 | return in_process_sample, res 131 | 132 | 133 | def calc_likelihood(self, data, atol=1e-5, rtol=1e-5): 134 | latent_code, log_likelihoods = cond_ode_likelihood( 135 | score_model=self, 136 | data=data, 137 | prior=self.prior_fn, 138 | sde_coeff=self.sde_fn, 139 | marginal_prob_fn=self.marginal_prob_fn, 140 | atol=atol, 141 | rtol=rtol, 142 | device=self.device, 143 | eps=self.sampling_eps, 144 | num_steps=self.cfg.sampling_steps, 145 | pose_mode=self.cfg.pose_mode, 146 | ) 147 | return log_likelihoods 148 | 149 | 150 | def forward(self, data, mode='score', init_x=None, T0=None): 151 | ''' 152 | Args: 153 | data, dict { 154 | 'pts': [bs, num_pts, 3] 155 | 'pts_feat': [bs, c] 156 | 'sampled_pose': [bs, pose_dim] 157 | 't': [bs, 1] 158 | } 159 | ''' 160 | if mode == 'score': 161 | out_score = self.pose_score_net(data) # normalisation 162 | return out_score 163 | elif mode == 'energy': 164 | out_energy = self.pose_score_net(data, return_item='energy') 165 | return out_energy 166 | elif mode == 'likelihood': 167 | likelihoods = self.calc_likelihood(data) 168 | return likelihoods 169 | elif mode == 'pts_feature': 170 | pts_feature = self.extract_pts_feature(data) 171 | return pts_feature 172 | elif mode == 'pc_sample': 173 | in_process_sample, res = self.sample(data, 'pc', init_x=init_x) 174 | return in_process_sample, res 175 | elif mode == 'ode_sample': 176 | in_process_sample, res = self.sample(data, 'ode', init_x=init_x, T0=T0) 177 | return in_process_sample, res 178 | else: 179 | raise NotImplementedError 180 | 181 | 182 | 183 | def test(): 184 | def get_parameter_number(model): 185 | total_num = sum(p.numel() for p in model.parameters()) 186 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 187 | return {'Total': total_num, 'Trainable': trainable_num} 188 | cfg = get_config() 189 | prior_fn, marginal_prob_fn, 
sde_fn, sampling_eps, T = init_sde('ve') 190 | net = GFObjectPose(cfg, prior_fn, marginal_prob_fn, sde_fn, sampling_eps, T) 191 | net_parameters_num= get_parameter_number(net) 192 | print(net_parameters_num['Total'], net_parameters_num['Trainable']) 193 | if __name__ == '__main__': 194 | test() 195 | 196 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import sys 4 | import os 5 | 6 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 7 | from networks.pts_encoder.pointnet2_utils.pointnet2.pointnet2_modules import PointnetFPModule, PointnetSAModuleMSG 8 | import networks.pts_encoder.pointnet2_utils.pointnet2.pytorch_utils as pt_utils 9 | from ipdb import set_trace 10 | from configs.config import get_config 11 | 12 | 13 | cfg = get_config() 14 | 15 | 16 | def get_model(input_channels=0): 17 | return Pointnet2MSG(input_channels=input_channels) 18 | 19 | MSG_CFG = { 20 | 'NPOINTS': [512, 256, 128, 64], 21 | 'RADIUS': [[0.01, 0.02], [0.02, 0.04], [0.04, 0.08], [0.08, 0.16]], 22 | 'NSAMPLE': [[16, 32], [16, 32], [16, 32], [16, 32]], 23 | 'MLPS': [[[16, 16, 32], [32, 32, 64]], 24 | [[64, 64, 128], [64, 96, 128]], 25 | [[128, 196, 256], [128, 196, 256]], 26 | [[256, 256, 512], [256, 384, 512]]], 27 | 'FP_MLPS': [[64, 64], [128, 128], [256, 256], [512, 512]], 28 | 'CLS_FC': [128], 29 | 'DP_RATIO': 0.5, 30 | } 31 | 32 | ClsMSG_CFG = { 33 | 'NPOINTS': [512, 256, 128, 64, None], 34 | 'RADIUS': [[0.01, 0.02], [0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]], 35 | 'NSAMPLE': [[16, 32], [16, 32], [16, 32], [16, 32], [None, None]], 36 | 'MLPS': [[[16, 16, 32], [32, 32, 64]], 37 | [[64, 64, 128], [64, 96, 128]], 38 | [[128, 196, 256], [128, 196, 256]], 39 | [[256, 256, 512], [256, 384, 512]], 40 | [[512, 512], [512, 512]]], 41 | 'DP_RATIO': 0.5, 42 | } 43 | 44 | ClsMSG_CFG_Dense = { 45 | 'NPOINTS': [512, 256, 128, None], 46 | 'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]], 47 | 'NSAMPLE': [[32, 64], [16, 32], [8, 16], [None, None]], 48 | 'MLPS': [[[16, 16, 32], [32, 32, 64]], 49 | [[64, 64, 128], [64, 96, 128]], 50 | [[128, 196, 256], [128, 196, 256]], 51 | [[256, 256, 512], [256, 384, 512]]], 52 | 'DP_RATIO': 0.5, 53 | } 54 | 55 | 56 | ########## Best before 29th April ########### 57 | ClsMSG_CFG_Light = { 58 | 'NPOINTS': [512, 256, 128, None], 59 | 'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]], 60 | 'NSAMPLE': [[16, 32], [16, 32], [16, 32], [None, None]], 61 | 'MLPS': [[[16, 16, 32], [32, 32, 64]], 62 | [[64, 64, 128], [64, 96, 128]], 63 | [[128, 196, 256], [128, 196, 256]], 64 | [[256, 256, 512], [256, 384, 512]]], 65 | 'DP_RATIO': 0.5, 66 | } 67 | 68 | 69 | ClsMSG_CFG_Lighter= { 70 | 'NPOINTS': [512, 256, 128, 64, None], 71 | 'RADIUS': [[0.01], [0.02], [0.04], [0.08], [None]], 72 | 'NSAMPLE': [[64], [32], [16], [8], [None]], 73 | 'MLPS': [[[32, 32, 64]], 74 | [[64, 64, 128]], 75 | [[128, 196, 256]], 76 | [[256, 256, 512]], 77 | [[512, 512, 1024]]], 78 | 'DP_RATIO': 0.5, 79 | } 80 | 81 | 82 | if cfg.pointnet2_params == 'light': 83 | SELECTED_PARAMS = ClsMSG_CFG_Light 84 | elif cfg.pointnet2_params == 'lighter': 85 | SELECTED_PARAMS = ClsMSG_CFG_Lighter 86 | elif cfg.pointnet2_params == 'dense': 87 | SELECTED_PARAMS = ClsMSG_CFG_Dense 88 | else: 89 | raise NotImplementedError 90 | 91 | 92 | class Pointnet2MSG(nn.Module): 93 | def __init__(self, 
input_channels=6): 94 | super().__init__() 95 | 96 | self.SA_modules = nn.ModuleList() 97 | channel_in = input_channels 98 | 99 | skip_channel_list = [input_channels] 100 | for k in range(MSG_CFG['NPOINTS'].__len__()): 101 | mlps = MSG_CFG['MLPS'][k].copy() 102 | channel_out = 0 103 | for idx in range(mlps.__len__()): 104 | mlps[idx] = [channel_in] + mlps[idx] 105 | channel_out += mlps[idx][-1] 106 | 107 | self.SA_modules.append( 108 | PointnetSAModuleMSG( 109 | npoint=MSG_CFG['NPOINTS'][k], 110 | radii=MSG_CFG['RADIUS'][k], 111 | nsamples=MSG_CFG['NSAMPLE'][k], 112 | mlps=mlps, 113 | use_xyz=True, 114 | bn=True 115 | ) 116 | ) 117 | skip_channel_list.append(channel_out) 118 | channel_in = channel_out 119 | 120 | self.FP_modules = nn.ModuleList() 121 | 122 | for k in range(MSG_CFG['FP_MLPS'].__len__()): 123 | pre_channel = MSG_CFG['FP_MLPS'][k + 1][-1] if k + 1 < len(MSG_CFG['FP_MLPS']) else channel_out 124 | self.FP_modules.append( 125 | PointnetFPModule(mlp=[pre_channel + skip_channel_list[k]] + MSG_CFG['FP_MLPS'][k]) 126 | ) 127 | 128 | cls_layers = [] 129 | pre_channel = MSG_CFG['FP_MLPS'][0][-1] 130 | for k in range(0, MSG_CFG['CLS_FC'].__len__()): 131 | cls_layers.append(pt_utils.Conv1d(pre_channel, MSG_CFG['CLS_FC'][k], bn=True)) 132 | pre_channel = MSG_CFG['CLS_FC'][k] 133 | cls_layers.append(pt_utils.Conv1d(pre_channel, 1, activation=None)) 134 | cls_layers.insert(1, nn.Dropout(0.5)) 135 | self.cls_layer = nn.Sequential(*cls_layers) 136 | 137 | 138 | def _break_up_pc(self, pc): 139 | xyz = pc[..., 0:3].contiguous() 140 | features = ( 141 | pc[..., 3:].transpose(1, 2).contiguous() 142 | if pc.size(-1) > 3 else None 143 | ) 144 | 145 | return xyz, features 146 | 147 | def forward(self, pointcloud: torch.cuda.FloatTensor): 148 | xyz, features = self._break_up_pc(pointcloud) 149 | 150 | l_xyz, l_features = [xyz], [features] 151 | for i in range(len(self.SA_modules)): 152 | li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) 153 | 154 | l_xyz.append(li_xyz) 155 | l_features.append(li_features) 156 | 157 | # set_trace()  # debugging hook; disabled so forward() runs end-to-end 158 | for i in range(-1, -(len(self.FP_modules) + 1), -1): 159 | l_features[i - 1] = self.FP_modules[i]( 160 | l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i] 161 | ) 162 | 163 | return l_features[0] 164 | 165 | 166 | class Pointnet2ClsMSG(nn.Module): 167 | def __init__(self, input_channels=6): 168 | super().__init__() 169 | 170 | self.SA_modules = nn.ModuleList() 171 | channel_in = input_channels 172 | 173 | for k in range(SELECTED_PARAMS['NPOINTS'].__len__()): 174 | mlps = SELECTED_PARAMS['MLPS'][k].copy() 175 | channel_out = 0 176 | for idx in range(mlps.__len__()): 177 | mlps[idx] = [channel_in] + mlps[idx] 178 | channel_out += mlps[idx][-1] 179 | 180 | self.SA_modules.append( 181 | PointnetSAModuleMSG( 182 | npoint=SELECTED_PARAMS['NPOINTS'][k], 183 | radii=SELECTED_PARAMS['RADIUS'][k], 184 | nsamples=SELECTED_PARAMS['NSAMPLE'][k], 185 | mlps=mlps, 186 | use_xyz=True, 187 | bn=True 188 | ) 189 | ) 190 | channel_in = channel_out 191 | 192 | 193 | def _break_up_pc(self, pc): 194 | xyz = pc[..., 0:3].contiguous() 195 | features = ( 196 | pc[..., 3:].transpose(1, 2).contiguous() 197 | if pc.size(-1) > 3 else None 198 | ) 199 | 200 | return xyz, features 201 | 202 | 203 | def forward(self, pointcloud: torch.cuda.FloatTensor): 204 | xyz, features = self._break_up_pc(pointcloud) 205 | 206 | l_xyz, l_features = [xyz], [features] 207 | for i in range(len(self.SA_modules)): 208 | li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) 209 |
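# each SA (set abstraction) level consumes the coordinates and pooled features of the level before it; the final level has npoint=None and groups all remaining points, so the loop ends with one global feature vector per cloud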
l_xyz.append(li_xyz) 210 | l_features.append(li_features) 211 | return l_features[-1].squeeze(-1) 212 | 213 | 214 | if __name__ == '__main__': 215 | seed = 100 216 | torch.manual_seed(seed) 217 | torch.cuda.manual_seed(seed) 218 | net = Pointnet2ClsMSG(0).cuda() 219 | pts = torch.randn(2, 1024, 3).cuda() 220 | print(torch.mean(pts, dim=1)) 221 | pre = net(pts) 222 | print(pre.shape) 223 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/.gitignore: -------------------------------------------------------------------------------- 1 | pointnet2/build/ 2 | pointnet2/dist/ 3 | pointnet2/pointnet2.egg-info/ 4 | __pycache__/ 5 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Shaoshuai Shi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/README.md: -------------------------------------------------------------------------------- 1 | # Pointnet2.PyTorch 2 | 3 | * PyTorch implementation of [PointNet++](https://arxiv.org/abs/1706.02413) based on [erikwijmans/Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch). 4 | * Faster than the original codes by re-implementing the CUDA operations. 5 | 6 | ## Installation 7 | ### Requirements 8 | * Linux (tested on Ubuntu 14.04/16.04) 9 | * Python 3.6+ 10 | * PyTorch 1.0 11 | 12 | ### Install 13 | Install this library by running the following command: 14 | 15 | ```shell 16 | cd pointnet2 17 | python setup.py install 18 | cd ../ 19 | ``` 20 | 21 | ## Examples 22 | Here I provide a simple example to use this library in the task of KITTI ourdoor foreground point cloud segmentation, and you could refer to the paper [PointRCNN](https://arxiv.org/abs/1812.04244) for the details of task description and foreground label generation. 23 | 24 | 1. 
Download the training data from [KITTI 3D object detection](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d) website and organize the downloaded files as follows: 25 | ``` 26 | Pointnet2.PyTorch 27 | ├── pointnet2 28 | ├── tools 29 | │ ├──data 30 | │ │ ├── KITTI 31 | │ │ │ ├── ImageSets 32 | │ │ │ ├── object 33 | │ │ │ │ ├──training 34 | │ │ │ │ ├──calib & velodyne & label_2 & image_2 35 | │ │ train_and_eval.py 36 | ``` 37 | 38 | 2. Run the following command to train and evaluate: 39 | ```shell 40 | cd tools 41 | python train_and_eval.py --batch_size 8 --epochs 100 --ckpt_save_interval 2 42 | ``` 43 | 44 | 45 | 46 | ## Project using this repo: 47 | * [PointRCNN](https://github.com/sshaoshuai/PointRCNN): 3D object detector from raw point cloud. 48 | 49 | ## Acknowledgement 50 | * [charlesq34/pointnet2](https://github.com/charlesq34/pointnet2): Paper author and official code repo. 51 | * [erikwijmans/Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch): Initial work of PyTorch implementation of PointNet++. 52 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/pointnet2_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from . import pointnet2_utils 6 | from . import pytorch_utils as pt_utils 7 | from typing import List 8 | 9 | 10 | class _PointnetSAModuleBase(nn.Module): 11 | 12 | def __init__(self): 13 | super().__init__() 14 | self.npoint = None 15 | self.groupers = None 16 | self.mlps = None 17 | self.pool_method = 'max_pool' 18 | 19 | def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor): 20 | """ 21 | :param xyz: (B, N, 3) tensor of the xyz coordinates of the features 22 | :param features: (B, N, C) tensor of the descriptors of the the features 23 | :param new_xyz: 24 | :return: 25 | new_xyz: (B, npoint, 3) tensor of the new features' xyz 26 | new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors 27 | """ 28 | new_features_list = [] 29 | 30 | xyz_flipped = xyz.transpose(1, 2).contiguous() 31 | if new_xyz is None: 32 | new_xyz = pointnet2_utils.gather_operation( 33 | xyz_flipped, 34 | pointnet2_utils.furthest_point_sample(xyz, self.npoint) 35 | ).transpose(1, 2).contiguous() if self.npoint is not None else None 36 | 37 | for i in range(len(self.groupers)): 38 | new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample) 39 | 40 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) 41 | 42 | if self.pool_method == 'max_pool': 43 | new_features = F.max_pool2d( 44 | new_features, kernel_size=[1, new_features.size(3)] 45 | ) # (B, mlp[-1], npoint, 1) 46 | elif self.pool_method == 'avg_pool': 47 | new_features = F.avg_pool2d( 48 | new_features, kernel_size=[1, new_features.size(3)] 49 | ) # (B, mlp[-1], npoint, 1) 50 | else: 51 | raise NotImplementedError 52 | 53 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) 54 | new_features_list.append(new_features) 55 | 56 | return new_xyz, torch.cat(new_features_list, dim=1) 57 | 58 | 59 | class PointnetSAModuleMSG(_PointnetSAModuleBase): 60 | """Pointnet set abstraction layer with multiscale grouping""" 61 | 62 | def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True, 63 | use_xyz: bool = True, 
pool_method='max_pool', instance_norm=False): 64 | """ 65 | :param npoint: int 66 | :param radii: list of float, list of radii to group with 67 | :param nsamples: list of int, number of samples in each ball query 68 | :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale 69 | :param bn: whether to use batchnorm 70 | :param use_xyz: 71 | :param pool_method: max_pool / avg_pool 72 | :param instance_norm: whether to use instance_norm 73 | """ 74 | super().__init__() 75 | 76 | assert len(radii) == len(nsamples) == len(mlps) 77 | 78 | self.npoint = npoint 79 | self.groupers = nn.ModuleList() 80 | self.mlps = nn.ModuleList() 81 | for i in range(len(radii)): 82 | radius = radii[i] 83 | nsample = nsamples[i] 84 | self.groupers.append( 85 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) 86 | if npoint is not None else pointnet2_utils.GroupAll(use_xyz) 87 | ) 88 | mlp_spec = mlps[i] 89 | if use_xyz: 90 | mlp_spec[0] += 3 91 | 92 | self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm)) 93 | self.pool_method = pool_method 94 | 95 | 96 | class PointnetSAModule(PointnetSAModuleMSG): 97 | """Pointnet set abstraction layer""" 98 | 99 | def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None, 100 | bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False): 101 | """ 102 | :param mlp: list of int, spec of the pointnet before the global max_pool 103 | :param npoint: int, number of features 104 | :param radius: float, radius of ball 105 | :param nsample: int, number of samples in the ball query 106 | :param bn: whether to use batchnorm 107 | :param use_xyz: 108 | :param pool_method: max_pool / avg_pool 109 | :param instance_norm: whether to use instance_norm 110 | """ 111 | super().__init__( 112 | mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz, 113 | pool_method=pool_method, instance_norm=instance_norm 114 | ) 115 | 116 | 117 | class PointnetFPModule(nn.Module): 118 | r"""Propigates the features of one set to another""" 119 | 120 | def __init__(self, *, mlp: List[int], bn: bool = True): 121 | """ 122 | :param mlp: list of int 123 | :param bn: whether to use batchnorm 124 | """ 125 | super().__init__() 126 | self.mlp = pt_utils.SharedMLP(mlp, bn=bn) 127 | 128 | def forward( 129 | self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor 130 | ) -> torch.Tensor: 131 | """ 132 | :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features 133 | :param known: (B, m, 3) tensor of the xyz positions of the known features 134 | :param unknow_feats: (B, C1, n) tensor of the features to be propigated to 135 | :param known_feats: (B, C2, m) tensor of features to be propigated 136 | :return: 137 | new_features: (B, mlp[-1], n) tensor of the features of the unknown features 138 | """ 139 | if known is not None: 140 | dist, idx = pointnet2_utils.three_nn(unknown, known) 141 | dist_recip = 1.0 / (dist + 1e-8) 142 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 143 | weight = dist_recip / norm 144 | 145 | interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight) 146 | else: 147 | interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1)) 148 | 149 | if unknow_feats is not None: 150 | new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n) 151 | else: 152 | new_features = interpolated_feats 
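# new_features: (B, C2 + C1, n) or (B, C2, n) -- features interpolated onto every 'unknown' point, optionally concatenated with the skip-connection features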
153 | 154 | new_features = new_features.unsqueeze(-1) 155 | 156 | new_features = self.mlp(new_features) 157 | 158 | return new_features.squeeze(-1) 159 | 160 | 161 | if __name__ == "__main__": 162 | pass 163 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch.autograd import Function 4 | import torch.nn as nn 5 | from typing import Tuple 6 | import sys 7 | 8 | import pointnet2_cuda as pointnet2 9 | 10 | 11 | class FurthestPointSampling(Function): 12 | @staticmethod 13 | def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor: 14 | """ 15 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 16 | minimum distance 17 | :param ctx: 18 | :param xyz: (B, N, 3) where N > npoint 19 | :param npoint: int, number of features in the sampled set 20 | :return: 21 | output: (B, npoint) tensor containing the set 22 | """ 23 | assert xyz.is_contiguous() 24 | 25 | B, N, _ = xyz.size() 26 | output = torch.cuda.IntTensor(B, npoint) 27 | temp = torch.cuda.FloatTensor(B, N).fill_(1e10) 28 | 29 | pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output) 30 | return output 31 | 32 | @staticmethod 33 | def backward(xyz, a=None): 34 | return None, None 35 | 36 | 37 | furthest_point_sample = FurthestPointSampling.apply 38 | 39 | 40 | class GatherOperation(Function): 41 | 42 | @staticmethod 43 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: 44 | """ 45 | :param ctx: 46 | :param features: (B, C, N) 47 | :param idx: (B, npoint) index tensor of the features to gather 48 | :return: 49 | output: (B, C, npoint) 50 | """ 51 | assert features.is_contiguous() 52 | assert idx.is_contiguous() 53 | 54 | B, npoint = idx.size() 55 | _, C, N = features.size() 56 | output = torch.cuda.FloatTensor(B, C, npoint) 57 | 58 | pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output) 59 | 60 | ctx.for_backwards = (idx, C, N) 61 | return output 62 | 63 | @staticmethod 64 | def backward(ctx, grad_out): 65 | idx, C, N = ctx.for_backwards 66 | B, npoint = idx.size() 67 | 68 | grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) 69 | grad_out_data = grad_out.data.contiguous() 70 | pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data) 71 | return grad_features, None 72 | 73 | 74 | gather_operation = GatherOperation.apply 75 | 76 | 77 | class ThreeNN(Function): 78 | 79 | @staticmethod 80 | def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 81 | """ 82 | Find the three nearest neighbors of unknown in known 83 | :param ctx: 84 | :param unknown: (B, N, 3) 85 | :param known: (B, M, 3) 86 | :return: 87 | dist: (B, N, 3) l2 distance to the three nearest neighbors 88 | idx: (B, N, 3) index of 3 nearest neighbors 89 | """ 90 | assert unknown.is_contiguous() 91 | assert known.is_contiguous() 92 | 93 | B, N, _ = unknown.size() 94 | m = known.size(1) 95 | dist2 = torch.cuda.FloatTensor(B, N, 3) 96 | idx = torch.cuda.IntTensor(B, N, 3) 97 | 98 | pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx) 99 | return torch.sqrt(dist2), idx 100 | 101 | @staticmethod 102 | def backward(ctx, a=None, b=None): 103 | return None, None 104 | 105 | 106 | three_nn = ThreeNN.apply 107 | 108 | 109 | class 
ThreeInterpolate(Function): 110 | 111 | @staticmethod 112 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: 113 | """ 114 | Performs weight linear interpolation on 3 features 115 | :param ctx: 116 | :param features: (B, C, M) Features descriptors to be interpolated from 117 | :param idx: (B, n, 3) three nearest neighbors of the target features in features 118 | :param weight: (B, n, 3) weights 119 | :return: 120 | output: (B, C, N) tensor of the interpolated features 121 | """ 122 | assert features.is_contiguous() 123 | assert idx.is_contiguous() 124 | assert weight.is_contiguous() 125 | 126 | B, c, m = features.size() 127 | n = idx.size(1) 128 | ctx.three_interpolate_for_backward = (idx, weight, m) 129 | output = torch.cuda.FloatTensor(B, c, n) 130 | 131 | pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output) 132 | return output 133 | 134 | @staticmethod 135 | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 136 | """ 137 | :param ctx: 138 | :param grad_out: (B, C, N) tensor with gradients of outputs 139 | :return: 140 | grad_features: (B, C, M) tensor with gradients of features 141 | None: 142 | None: 143 | """ 144 | idx, weight, m = ctx.three_interpolate_for_backward 145 | B, c, n = grad_out.size() 146 | 147 | grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_()) 148 | grad_out_data = grad_out.data.contiguous() 149 | 150 | pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data) 151 | return grad_features, None, None 152 | 153 | 154 | three_interpolate = ThreeInterpolate.apply 155 | 156 | 157 | class GroupingOperation(Function): 158 | 159 | @staticmethod 160 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: 161 | """ 162 | :param ctx: 163 | :param features: (B, C, N) tensor of features to group 164 | :param idx: (B, npoint, nsample) tensor containing the indicies of features to group with 165 | :return: 166 | output: (B, C, npoint, nsample) tensor 167 | """ 168 | assert features.is_contiguous() 169 | assert idx.is_contiguous() 170 | 171 | B, nfeatures, nsample = idx.size() 172 | _, C, N = features.size() 173 | output = torch.cuda.FloatTensor(B, C, nfeatures, nsample) 174 | 175 | pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output) 176 | 177 | ctx.for_backwards = (idx, N) 178 | return output 179 | 180 | @staticmethod 181 | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 182 | """ 183 | :param ctx: 184 | :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward 185 | :return: 186 | grad_features: (B, C, N) gradient of the features 187 | """ 188 | idx, N = ctx.for_backwards 189 | 190 | B, C, npoint, nsample = grad_out.size() 191 | grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) 192 | 193 | grad_out_data = grad_out.data.contiguous() 194 | pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data) 195 | return grad_features, None 196 | 197 | 198 | grouping_operation = GroupingOperation.apply 199 | 200 | 201 | class BallQuery(Function): 202 | 203 | @staticmethod 204 | def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor: 205 | """ 206 | :param ctx: 207 | :param radius: float, radius of the balls 208 | :param nsample: int, maximum number of features in the balls 209 | :param xyz: (B, N, 
3) xyz coordinates of the features 210 | :param new_xyz: (B, npoint, 3) centers of the ball query 211 | :return: 212 | idx: (B, npoint, nsample) tensor with the indicies of the features that form the query balls 213 | """ 214 | assert new_xyz.is_contiguous() 215 | assert xyz.is_contiguous() 216 | 217 | B, N, _ = xyz.size() 218 | npoint = new_xyz.size(1) 219 | idx = torch.cuda.IntTensor(B, npoint, nsample).zero_() 220 | 221 | pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx) 222 | return idx 223 | 224 | @staticmethod 225 | def backward(ctx, a=None): 226 | return None, None, None, None 227 | 228 | 229 | ball_query = BallQuery.apply 230 | 231 | 232 | class QueryAndGroup(nn.Module): 233 | def __init__(self, radius: float, nsample: int, use_xyz: bool = True): 234 | """ 235 | :param radius: float, radius of ball 236 | :param nsample: int, maximum number of features to gather in the ball 237 | :param use_xyz: 238 | """ 239 | super().__init__() 240 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 241 | 242 | def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]: 243 | """ 244 | :param xyz: (B, N, 3) xyz coordinates of the features 245 | :param new_xyz: (B, npoint, 3) centroids 246 | :param features: (B, C, N) descriptors of the features 247 | :return: 248 | new_features: (B, 3 + C, npoint, nsample) 249 | """ 250 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 251 | xyz_trans = xyz.transpose(1, 2).contiguous() 252 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 253 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 254 | 255 | if features is not None: 256 | grouped_features = grouping_operation(features, idx) 257 | if self.use_xyz: 258 | new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, C + 3, npoint, nsample) 259 | else: 260 | new_features = grouped_features 261 | else: 262 | assert self.use_xyz, "Cannot have not features and not use xyz as a feature!" 
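# without per-point descriptors, the grouped relative xyz coordinates serve as the features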
263 | new_features = grouped_xyz 264 | 265 | return new_features 266 | 267 | 268 | class GroupAll(nn.Module): 269 | def __init__(self, use_xyz: bool = True): 270 | super().__init__() 271 | self.use_xyz = use_xyz 272 | 273 | def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None): 274 | """ 275 | :param xyz: (B, N, 3) xyz coordinates of the features 276 | :param new_xyz: ignored 277 | :param features: (B, C, N) descriptors of the features 278 | :return: 279 | new_features: (B, C + 3, 1, N) 280 | """ 281 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 282 | if features is not None: 283 | grouped_features = features.unsqueeze(2) 284 | if self.use_xyz: 285 | new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, 3 + C, 1, N) 286 | else: 287 | new_features = grouped_features 288 | else: 289 | new_features = grouped_xyz 290 | 291 | return new_features 292 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import List, Tuple 3 | 4 | 5 | class SharedMLP(nn.Sequential): 6 | 7 | def __init__( 8 | self, 9 | args: List[int], 10 | *, 11 | bn: bool = False, 12 | activation=nn.ReLU(inplace=True), 13 | preact: bool = False, 14 | first: bool = False, 15 | name: str = "", 16 | instance_norm: bool = False, 17 | ): 18 | super().__init__() 19 | 20 | for i in range(len(args) - 1): 21 | self.add_module( 22 | name + 'layer{}'.format(i), 23 | Conv2d( 24 | args[i], 25 | args[i + 1], 26 | bn=(not first or not preact or (i != 0)) and bn, 27 | activation=activation 28 | if (not first or not preact or (i != 0)) else None, 29 | preact=preact, 30 | instance_norm=instance_norm 31 | ) 32 | ) 33 | 34 | 35 | class _ConvBase(nn.Sequential): 36 | 37 | def __init__( 38 | self, 39 | in_size, 40 | out_size, 41 | kernel_size, 42 | stride, 43 | padding, 44 | activation, 45 | bn, 46 | init, 47 | conv=None, 48 | batch_norm=None, 49 | bias=True, 50 | preact=False, 51 | name="", 52 | instance_norm=False, 53 | instance_norm_func=None 54 | ): 55 | super().__init__() 56 | 57 | bias = bias and (not bn) 58 | conv_unit = conv( 59 | in_size, 60 | out_size, 61 | kernel_size=kernel_size, 62 | stride=stride, 63 | padding=padding, 64 | bias=bias 65 | ) 66 | init(conv_unit.weight) 67 | if bias: 68 | nn.init.constant_(conv_unit.bias, 0) 69 | 70 | if bn: 71 | if not preact: 72 | bn_unit = batch_norm(out_size) 73 | else: 74 | bn_unit = batch_norm(in_size) 75 | if instance_norm: 76 | if not preact: 77 | in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False) 78 | else: 79 | in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False) 80 | 81 | if preact: 82 | if bn: 83 | self.add_module(name + 'bn', bn_unit) 84 | 85 | if activation is not None: 86 | self.add_module(name + 'activation', activation) 87 | 88 | if not bn and instance_norm: 89 | self.add_module(name + 'in', in_unit) 90 | 91 | self.add_module(name + 'conv', conv_unit) 92 | 93 | if not preact: 94 | if bn: 95 | self.add_module(name + 'bn', bn_unit) 96 | 97 | if activation is not None: 98 | self.add_module(name + 'activation', activation) 99 | 100 | if not bn and instance_norm: 101 | self.add_module(name + 'in', in_unit) 102 | 103 | 104 | class _BNBase(nn.Sequential): 105 | 106 | def __init__(self, in_size, batch_norm=None, name=""): 107 | super().__init__() 108 | self.add_module(name 
+ "bn", batch_norm(in_size)) 109 | 110 | nn.init.constant_(self[0].weight, 1.0) 111 | nn.init.constant_(self[0].bias, 0) 112 | 113 | 114 | class BatchNorm1d(_BNBase): 115 | 116 | def __init__(self, in_size: int, *, name: str = ""): 117 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 118 | 119 | 120 | class BatchNorm2d(_BNBase): 121 | 122 | def __init__(self, in_size: int, name: str = ""): 123 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 124 | 125 | 126 | class Conv1d(_ConvBase): 127 | 128 | def __init__( 129 | self, 130 | in_size: int, 131 | out_size: int, 132 | *, 133 | kernel_size: int = 1, 134 | stride: int = 1, 135 | padding: int = 0, 136 | activation=nn.ReLU(inplace=True), 137 | bn: bool = False, 138 | init=nn.init.kaiming_normal_, 139 | bias: bool = True, 140 | preact: bool = False, 141 | name: str = "", 142 | instance_norm=False 143 | ): 144 | super().__init__( 145 | in_size, 146 | out_size, 147 | kernel_size, 148 | stride, 149 | padding, 150 | activation, 151 | bn, 152 | init, 153 | conv=nn.Conv1d, 154 | batch_norm=BatchNorm1d, 155 | bias=bias, 156 | preact=preact, 157 | name=name, 158 | instance_norm=instance_norm, 159 | instance_norm_func=nn.InstanceNorm1d 160 | ) 161 | 162 | 163 | class Conv2d(_ConvBase): 164 | 165 | def __init__( 166 | self, 167 | in_size: int, 168 | out_size: int, 169 | *, 170 | kernel_size: Tuple[int, int] = (1, 1), 171 | stride: Tuple[int, int] = (1, 1), 172 | padding: Tuple[int, int] = (0, 0), 173 | activation=nn.ReLU(inplace=True), 174 | bn: bool = False, 175 | init=nn.init.kaiming_normal_, 176 | bias: bool = True, 177 | preact: bool = False, 178 | name: str = "", 179 | instance_norm=False 180 | ): 181 | super().__init__( 182 | in_size, 183 | out_size, 184 | kernel_size, 185 | stride, 186 | padding, 187 | activation, 188 | bn, 189 | init, 190 | conv=nn.Conv2d, 191 | batch_norm=BatchNorm2d, 192 | bias=bias, 193 | preact=preact, 194 | name=name, 195 | instance_norm=instance_norm, 196 | instance_norm_func=nn.InstanceNorm2d 197 | ) 198 | 199 | 200 | class FC(nn.Sequential): 201 | 202 | def __init__( 203 | self, 204 | in_size: int, 205 | out_size: int, 206 | *, 207 | activation=nn.ReLU(inplace=True), 208 | bn: bool = False, 209 | init=None, 210 | preact: bool = False, 211 | name: str = "" 212 | ): 213 | super().__init__() 214 | 215 | fc = nn.Linear(in_size, out_size, bias=not bn) 216 | if init is not None: 217 | init(fc.weight) 218 | if not bn: 219 | nn.init.constant(fc.bias, 0) 220 | 221 | if preact: 222 | if bn: 223 | self.add_module(name + 'bn', BatchNorm1d(in_size)) 224 | 225 | if activation is not None: 226 | self.add_module(name + 'activation', activation) 227 | 228 | self.add_module(name + 'fc', fc) 229 | 230 | if not preact: 231 | if bn: 232 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 233 | 234 | if activation is not None: 235 | self.add_module(name + 'activation', activation) 236 | 237 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='pointnet2', 6 | ext_modules=[ 7 | CUDAExtension('pointnet2_cuda', [ 8 | 'src/pointnet2_api.cpp', 9 | 10 | 'src/ball_query.cpp', 11 | 'src/ball_query_gpu.cu', 12 | 'src/group_points.cpp', 13 | 'src/group_points_gpu.cu', 14 | 'src/interpolate.cpp', 15 | 'src/interpolate_gpu.cu', 
16 | 'src/sampling.cpp', 17 | 'src/sampling_gpu.cu', 18 | ], 19 | extra_compile_args={'cxx': ['-g'], 20 | 'nvcc': ['-O2']}) 21 | ], 22 | cmdclass={'build_ext': BuildExtension} 23 | ) 24 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | // #include 4 | #include 5 | #include 6 | #include "ball_query_gpu.h" 7 | #include 8 | #include 9 | 10 | // extern THCState *state; 11 | 12 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") 13 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") 14 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 15 | 16 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 17 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) { 18 | CHECK_INPUT(new_xyz_tensor); 19 | CHECK_INPUT(xyz_tensor); 20 | const float *new_xyz = new_xyz_tensor.data(); 21 | const float *xyz = xyz_tensor.data(); 22 | int *idx = idx_tensor.data(); 23 | 24 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 25 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 26 | ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream); 27 | return 1; 28 | } -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "ball_query_gpu.h" 6 | #include "cuda_utils.h" 7 | 8 | 9 | __global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample, 10 | const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) { 11 | // new_xyz: (B, M, 3) 12 | // xyz: (B, N, 3) 13 | // output: 14 | // idx: (B, M, nsample) 15 | int bs_idx = blockIdx.y; 16 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 17 | if (bs_idx >= b || pt_idx >= m) return; 18 | 19 | new_xyz += bs_idx * m * 3 + pt_idx * 3; 20 | xyz += bs_idx * n * 3; 21 | idx += bs_idx * m * nsample + pt_idx * nsample; 22 | 23 | float radius2 = radius * radius; 24 | float new_x = new_xyz[0]; 25 | float new_y = new_xyz[1]; 26 | float new_z = new_xyz[2]; 27 | 28 | int cnt = 0; 29 | for (int k = 0; k < n; ++k) { 30 | float x = xyz[k * 3 + 0]; 31 | float y = xyz[k * 3 + 1]; 32 | float z = xyz[k * 3 + 2]; 33 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); 34 | if (d2 < radius2){ 35 | if (cnt == 0){ 36 | for (int l = 0; l < nsample; ++l) { 37 | idx[l] = k; 38 | } 39 | } 40 | idx[cnt] = k; 41 | ++cnt; 42 | if (cnt >= nsample) break; 43 | } 44 | } 45 | } 46 | 47 | 48 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \ 49 | const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) { 50 | // new_xyz: (B, M, 3) 51 | // xyz: (B, N, 3) 52 | // output: 53 | // idx: (B, M, nsample) 54 | 55 | cudaError_t err; 56 | 57 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) 58 | dim3 threads(THREADS_PER_BLOCK); 59 | 60 | ball_query_kernel_fast<<>>(b, n, m, radius, nsample, new_xyz, xyz, idx); 61 | // cudaDeviceSynchronize(); // for using printf in kernel function 62 | err = cudaGetLastError(); 63 | if (cudaSuccess != err) { 
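// a failed launch is reported and the process aborts rather than continuing with undefined results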
64 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 65 | exit(-1); 66 | } 67 | } -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/src/ball_query_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _BALL_QUERY_GPU_H 2 | #define _BALL_QUERY_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 10 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor); 11 | 12 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, 13 | const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream); 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | 6 | #define TOTAL_THREADS 1024 7 | #define THREADS_PER_BLOCK 256 8 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 9 | 10 | inline int opt_n_threads(int work_size) { 11 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 12 | 13 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 14 | } 15 | #endif 16 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | // #include 6 | #include "group_points_gpu.h" 7 | #include 8 | #include 9 | // extern THCState *state; 10 | 11 | 12 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 13 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { 14 | 15 | float *grad_points = grad_points_tensor.data(); 16 | const int *idx = idx_tensor.data(); 17 | const float *grad_out = grad_out_tensor.data(); 18 | 19 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 20 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 21 | group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream); 22 | return 1; 23 | } 24 | 25 | 26 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 27 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) { 28 | 29 | const float *points = points_tensor.data(); 30 | const int *idx = idx_tensor.data(); 31 | float *out = out_tensor.data(); 32 | 33 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 34 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 35 | group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream); 36 | return 1; 37 | } -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | #include "group_points_gpu.h" 6 | 7 | 8 | __global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample, 9 | const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) { 10 | // grad_out: (B, C, npoints, nsample) 11 | 
// idx: (B, npoints, nsample) 12 | // output: 13 | // grad_points: (B, C, N) 14 | int bs_idx = blockIdx.z; 15 | int c_idx = blockIdx.y; 16 | int index = blockIdx.x * blockDim.x + threadIdx.x; 17 | int pt_idx = index / nsample; 18 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; 19 | 20 | int sample_idx = index % nsample; 21 | grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; 22 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 23 | 24 | atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]); 25 | } 26 | 27 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 28 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { 29 | // grad_out: (B, C, npoints, nsample) 30 | // idx: (B, npoints, nsample) 31 | // output: 32 | // grad_points: (B, C, N) 33 | cudaError_t err; 34 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 35 | dim3 threads(THREADS_PER_BLOCK); 36 | 37 | group_points_grad_kernel_fast<<>>(b, c, n, npoints, nsample, grad_out, idx, grad_points); 38 | 39 | err = cudaGetLastError(); 40 | if (cudaSuccess != err) { 41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 42 | exit(-1); 43 | } 44 | } 45 | 46 | 47 | __global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample, 48 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { 49 | // points: (B, C, N) 50 | // idx: (B, npoints, nsample) 51 | // output: 52 | // out: (B, C, npoints, nsample) 53 | int bs_idx = blockIdx.z; 54 | int c_idx = blockIdx.y; 55 | int index = blockIdx.x * blockDim.x + threadIdx.x; 56 | int pt_idx = index / nsample; 57 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; 58 | 59 | int sample_idx = index % nsample; 60 | 61 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 62 | int in_idx = bs_idx * c * n + c_idx * n + idx[0]; 63 | int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; 64 | 65 | out[out_idx] = points[in_idx]; 66 | } 67 | 68 | 69 | void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 70 | const float *points, const int *idx, float *out, cudaStream_t stream) { 71 | // points: (B, C, N) 72 | // idx: (B, npoints, nsample) 73 | // output: 74 | // out: (B, C, npoints, nsample) 75 | cudaError_t err; 76 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 77 | dim3 threads(THREADS_PER_BLOCK); 78 | 79 | group_points_kernel_fast<<>>(b, c, n, npoints, nsample, points, idx, out); 80 | // cudaDeviceSynchronize(); // for using printf in kernel function 81 | err = cudaGetLastError(); 82 | if (cudaSuccess != err) { 83 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 84 | exit(-1); 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/src/group_points_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _GROUP_POINTS_GPU_H 2 | #define _GROUP_POINTS_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 11 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); 12 | 13 | void 
group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 14 | const float *points, const int *idx, float *out, cudaStream_t stream); 15 | 16 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); 18 | 19 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); 21 | 22 | #endif 23 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | // #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include "interpolate_gpu.h" 12 | 13 | // extern THCState *state; 14 | 15 | 16 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 17 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) { 18 | const float *unknown = unknown_tensor.data(); 19 | const float *known = known_tensor.data(); 20 | float *dist2 = dist2_tensor.data(); 21 | int *idx = idx_tensor.data(); 22 | 23 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 24 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 25 | three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream); 26 | } 27 | 28 | 29 | void three_interpolate_wrapper_fast(int b, int c, int m, int n, 30 | at::Tensor points_tensor, 31 | at::Tensor idx_tensor, 32 | at::Tensor weight_tensor, 33 | at::Tensor out_tensor) { 34 | 35 | const float *points = points_tensor.data(); 36 | const float *weight = weight_tensor.data(); 37 | float *out = out_tensor.data(); 38 | const int *idx = idx_tensor.data(); 39 | 40 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 41 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 42 | three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream); 43 | } 44 | 45 | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, 46 | at::Tensor grad_out_tensor, 47 | at::Tensor idx_tensor, 48 | at::Tensor weight_tensor, 49 | at::Tensor grad_points_tensor) { 50 | 51 | const float *grad_out = grad_out_tensor.data(); 52 | const float *weight = weight_tensor.data(); 53 | float *grad_points = grad_points_tensor.data(); 54 | const int *idx = idx_tensor.data(); 55 | 56 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 58 | three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream); 59 | } -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | #include "interpolate_gpu.h" 7 | 8 | 9 | __global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown, 10 | const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) { 11 | // unknown: (B, N, 3) 12 | // known: (B, M, 3) 13 | // output: 14 | // dist2: (B, N, 3) 15 | // idx: (B, N, 3) 16 | 17 | int bs_idx = blockIdx.y; 18 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 19 | if 
(bs_idx >= b || pt_idx >= n) return; 20 | 21 | unknown += bs_idx * n * 3 + pt_idx * 3; 22 | known += bs_idx * m * 3; 23 | dist2 += bs_idx * n * 3 + pt_idx * 3; 24 | idx += bs_idx * n * 3 + pt_idx * 3; 25 | 26 | float ux = unknown[0]; 27 | float uy = unknown[1]; 28 | float uz = unknown[2]; 29 | 30 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 31 | int besti1 = 0, besti2 = 0, besti3 = 0; 32 | for (int k = 0; k < m; ++k) { 33 | float x = known[k * 3 + 0]; 34 | float y = known[k * 3 + 1]; 35 | float z = known[k * 3 + 2]; 36 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 37 | if (d < best1) { 38 | best3 = best2; besti3 = besti2; 39 | best2 = best1; besti2 = besti1; 40 | best1 = d; besti1 = k; 41 | } 42 | else if (d < best2) { 43 | best3 = best2; besti3 = besti2; 44 | best2 = d; besti2 = k; 45 | } 46 | else if (d < best3) { 47 | best3 = d; besti3 = k; 48 | } 49 | } 50 | dist2[0] = best1; dist2[1] = best2; dist2[2] = best3; 51 | idx[0] = besti1; idx[1] = besti2; idx[2] = besti3; 52 | } 53 | 54 | 55 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, 56 | const float *known, float *dist2, int *idx, cudaStream_t stream) { 57 | // unknown: (B, N, 3) 58 | // known: (B, M, 3) 59 | // output: 60 | // dist2: (B, N, 3) 61 | // idx: (B, N, 3) 62 | 63 | cudaError_t err; 64 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) 65 | dim3 threads(THREADS_PER_BLOCK); 66 | 67 | three_nn_kernel_fast<<>>(b, n, m, unknown, known, dist2, idx); 68 | 69 | err = cudaGetLastError(); 70 | if (cudaSuccess != err) { 71 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 72 | exit(-1); 73 | } 74 | } 75 | 76 | 77 | __global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points, 78 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) { 79 | // points: (B, C, M) 80 | // idx: (B, N, 3) 81 | // weight: (B, N, 3) 82 | // output: 83 | // out: (B, C, N) 84 | 85 | int bs_idx = blockIdx.z; 86 | int c_idx = blockIdx.y; 87 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 88 | 89 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; 90 | 91 | weight += bs_idx * n * 3 + pt_idx * 3; 92 | points += bs_idx * c * m + c_idx * m; 93 | idx += bs_idx * n * 3 + pt_idx * 3; 94 | out += bs_idx * c * n + c_idx * n; 95 | 96 | out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]]; 97 | } 98 | 99 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, 100 | const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) { 101 | // points: (B, C, M) 102 | // idx: (B, N, 3) 103 | // weight: (B, N, 3) 104 | // output: 105 | // out: (B, C, N) 106 | 107 | cudaError_t err; 108 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 109 | dim3 threads(THREADS_PER_BLOCK); 110 | three_interpolate_kernel_fast<<>>(b, c, m, n, points, idx, weight, out); 111 | 112 | err = cudaGetLastError(); 113 | if (cudaSuccess != err) { 114 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 115 | exit(-1); 116 | } 117 | } 118 | 119 | 120 | __global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, 121 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) { 122 | // grad_out: (B, C, N) 123 | // weight: (B, N, 3) 124 | // output: 125 | // 
grad_points: (B, C, M) 126 | 127 | int bs_idx = blockIdx.z; 128 | int c_idx = blockIdx.y; 129 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 130 | 131 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; 132 | 133 | grad_out += bs_idx * c * n + c_idx * n + pt_idx; 134 | weight += bs_idx * n * 3 + pt_idx * 3; 135 | grad_points += bs_idx * c * m + c_idx * m; 136 | idx += bs_idx * n * 3 + pt_idx * 3; 137 | 138 | 139 | atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]); 140 | atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]); 141 | atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]); 142 | } 143 | 144 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, 145 | const int *idx, const float *weight, float *grad_points, cudaStream_t stream) { 146 | // grad_out: (B, C, N) 147 | // weight: (B, N, 3) 148 | // output: 149 | // grad_points: (B, C, M) 150 | 151 | cudaError_t err; 152 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 153 | dim3 threads(THREADS_PER_BLOCK); 154 | three_interpolate_grad_kernel_fast<<>>(b, c, n, m, grad_out, idx, weight, grad_points); 155 | 156 | err = cudaGetLastError(); 157 | if (cudaSuccess != err) { 158 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 159 | exit(-1); 160 | } 161 | } -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/src/interpolate_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _INTERPOLATE_GPU_H 2 | #define _INTERPOLATE_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 11 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor); 12 | 13 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, 14 | const float *known, float *dist2, int *idx, cudaStream_t stream); 15 | 16 | 17 | void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor, 18 | at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor); 19 | 20 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, 21 | const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream); 22 | 23 | 24 | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor, 25 | at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor); 26 | 27 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, 28 | const int *idx, const float *weight, float *grad_points, cudaStream_t stream); 29 | 30 | #endif 31 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/src/pointnet2_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "ball_query_gpu.h" 5 | #include "group_points_gpu.h" 6 | #include "sampling_gpu.h" 7 | #include "interpolate_gpu.h" 8 | 9 | 10 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 11 | m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast"); 12 | 13 | m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast"); 14 | m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast"); 15 
| 16 | m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast"); 17 | m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast"); 18 | 19 | m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper"); 20 | 21 | m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast"); 22 | m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast"); 23 | m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast"); 24 | } 25 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | // #include 5 | 6 | #include "sampling_gpu.h" 7 | #include 8 | #include 9 | 10 | // extern THCState *state; 11 | 12 | 13 | int gather_points_wrapper_fast(int b, int c, int n, int npoints, 14 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){ 15 | const float *points = points_tensor.data(); 16 | const int *idx = idx_tensor.data(); 17 | float *out = out_tensor.data(); 18 | 19 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 20 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 21 | gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream); 22 | return 1; 23 | } 24 | 25 | 26 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 27 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { 28 | 29 | const float *grad_out = grad_out_tensor.data(); 30 | const int *idx = idx_tensor.data(); 31 | float *grad_points = grad_points_tensor.data(); 32 | 33 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 34 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 35 | gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream); 36 | return 1; 37 | } 38 | 39 | 40 | int furthest_point_sampling_wrapper(int b, int n, int m, 41 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) { 42 | 43 | const float *points = points_tensor.data(); 44 | float *temp = temp_tensor.data(); 45 | int *idx = idx_tensor.data(); 46 | 47 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 48 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 49 | furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream); 50 | return 1; 51 | } 52 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | #include "sampling_gpu.h" 6 | 7 | 8 | __global__ void gather_points_kernel_fast(int b, int c, int n, int m, 9 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { 10 | // points: (B, C, N) 11 | // idx: (B, M) 12 | // output: 13 | // out: (B, C, M) 14 | 15 | int bs_idx = blockIdx.z; 16 | int c_idx = blockIdx.y; 17 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 18 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; 19 | 20 | out += bs_idx * c * m + c_idx * m + pt_idx; 21 | idx += bs_idx * m + pt_idx; 22 | points += bs_idx * c * n + c_idx * n; 23 | 
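// with the pointers offset above, each thread writes a single gathered element: out[b, c, m] = points[b, c, idx[b, m]]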
out[0] = points[idx[0]]; 24 | } 25 | 26 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 27 | const float *points, const int *idx, float *out, cudaStream_t stream) { 28 | // points: (B, C, N) 29 | // idx: (B, npoints) 30 | // output: 31 | // out: (B, C, npoints) 32 | 33 | cudaError_t err; 34 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 35 | dim3 threads(THREADS_PER_BLOCK); 36 | 37 | gather_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points, idx, out); 38 | 39 | err = cudaGetLastError(); 40 | if (cudaSuccess != err) { 41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 42 | exit(-1); 43 | } 44 | } 45 | 46 | __global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, 47 | const int *__restrict__ idx, float *__restrict__ grad_points) { 48 | // grad_out: (B, C, M) 49 | // idx: (B, M) 50 | // output: 51 | // grad_points: (B, C, N) 52 | 53 | int bs_idx = blockIdx.z; 54 | int c_idx = blockIdx.y; 55 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 56 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; 57 | 58 | grad_out += bs_idx * c * m + c_idx * m + pt_idx; 59 | idx += bs_idx * m + pt_idx; 60 | grad_points += bs_idx * c * n + c_idx * n; 61 | 62 | atomicAdd(grad_points + idx[0], grad_out[0]); 63 | } 64 | 65 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 66 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { 67 | // grad_out: (B, C, npoints) 68 | // idx: (B, npoints) 69 | // output: 70 | // grad_points: (B, C, N) 71 | 72 | cudaError_t err; 73 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 74 | dim3 threads(THREADS_PER_BLOCK); 75 | 76 | gather_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, grad_out, idx, grad_points); 77 | 78 | err = cudaGetLastError(); 79 | if (cudaSuccess != err) { 80 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 81 | exit(-1); 82 | } 83 | } 84 | 85 | 86 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){ 87 | const float v1 = dists[idx1], v2 = dists[idx2]; 88 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 89 | dists[idx1] = max(v1, v2); 90 | dists_i[idx1] = v2 > v1 ?
i2 : i1; 91 | } 92 | 93 | template <unsigned int block_size> 94 | __global__ void furthest_point_sampling_kernel(int b, int n, int m, 95 | const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) { 96 | // dataset: (B, N, 3) 97 | // tmp: (B, N) 98 | // output: 99 | // idx: (B, M) 100 | 101 | if (m <= 0) return; 102 | __shared__ float dists[block_size]; 103 | __shared__ int dists_i[block_size]; 104 | 105 | int batch_index = blockIdx.x; 106 | dataset += batch_index * n * 3; 107 | temp += batch_index * n; 108 | idxs += batch_index * m; 109 | 110 | int tid = threadIdx.x; 111 | const int stride = block_size; 112 | 113 | int old = 0; 114 | if (threadIdx.x == 0) 115 | idxs[0] = old; 116 | 117 | __syncthreads(); 118 | for (int j = 1; j < m; j++) { 119 | int besti = 0; 120 | float best = -1; 121 | float x1 = dataset[old * 3 + 0]; 122 | float y1 = dataset[old * 3 + 1]; 123 | float z1 = dataset[old * 3 + 2]; 124 | for (int k = tid; k < n; k += stride) { 125 | float x2, y2, z2; 126 | x2 = dataset[k * 3 + 0]; 127 | y2 = dataset[k * 3 + 1]; 128 | z2 = dataset[k * 3 + 2]; 129 | // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 130 | // if (mag <= 1e-3) 131 | // continue; 132 | 133 | float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 134 | float d2 = min(d, temp[k]); 135 | temp[k] = d2; 136 | besti = d2 > best ? k : besti; 137 | best = d2 > best ? d2 : best; 138 | } 139 | dists[tid] = best; 140 | dists_i[tid] = besti; 141 | __syncthreads(); 142 | 143 | if (block_size >= 1024) { 144 | if (tid < 512) { 145 | __update(dists, dists_i, tid, tid + 512); 146 | } 147 | __syncthreads(); 148 | } 149 | 150 | if (block_size >= 512) { 151 | if (tid < 256) { 152 | __update(dists, dists_i, tid, tid + 256); 153 | } 154 | __syncthreads(); 155 | } 156 | if (block_size >= 256) { 157 | if (tid < 128) { 158 | __update(dists, dists_i, tid, tid + 128); 159 | } 160 | __syncthreads(); 161 | } 162 | if (block_size >= 128) { 163 | if (tid < 64) { 164 | __update(dists, dists_i, tid, tid + 64); 165 | } 166 | __syncthreads(); 167 | } 168 | if (block_size >= 64) { 169 | if (tid < 32) { 170 | __update(dists, dists_i, tid, tid + 32); 171 | } 172 | __syncthreads(); 173 | } 174 | if (block_size >= 32) { 175 | if (tid < 16) { 176 | __update(dists, dists_i, tid, tid + 16); 177 | } 178 | __syncthreads(); 179 | } 180 | if (block_size >= 16) { 181 | if (tid < 8) { 182 | __update(dists, dists_i, tid, tid + 8); 183 | } 184 | __syncthreads(); 185 | } 186 | if (block_size >= 8) { 187 | if (tid < 4) { 188 | __update(dists, dists_i, tid, tid + 4); 189 | } 190 | __syncthreads(); 191 | } 192 | if (block_size >= 4) { 193 | if (tid < 2) { 194 | __update(dists, dists_i, tid, tid + 2); 195 | } 196 | __syncthreads(); 197 | } 198 | if (block_size >= 2) { 199 | if (tid < 1) { 200 | __update(dists, dists_i, tid, tid + 1); 201 | } 202 | __syncthreads(); 203 | } 204 | 205 | old = dists_i[0]; 206 | if (tid == 0) 207 | idxs[j] = old; 208 | } 209 | } 210 | 211 | void furthest_point_sampling_kernel_launcher(int b, int n, int m, 212 | const float *dataset, float *temp, int *idxs, cudaStream_t stream) { 213 | // dataset: (B, N, 3) 214 | // tmp: (B, N) 215 | // output: 216 | // idx: (B, M) 217 | 218 | cudaError_t err; 219 | unsigned int n_threads = opt_n_threads(n); 220 | 221 | switch (n_threads) { 222 | case 1024: 223 | furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 224 | case 512: 225 | furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 226 | case 256: 227 |
furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 228 | case 128: 229 | furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 230 | case 64: 231 | furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 232 | case 32: 233 | furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 234 | case 16: 235 | furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 236 | case 8: 237 | furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 238 | case 4: 239 | furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 240 | case 2: 241 | furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 242 | case 1: 243 | furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break; 244 | default: 245 | furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 246 | } 247 | 248 | err = cudaGetLastError(); 249 | if (cudaSuccess != err) { 250 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 251 | exit(-1); 252 | } 253 | } 254 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/pointnet2/src/sampling_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_GPU_H 2 | #define _SAMPLING_GPU_H 3 | 4 | #include <torch/serialize/tensor.h> 5 | #include <ATen/cuda/CUDAContext.h> 6 | #include <vector> 7 | 8 | 9 | int gather_points_wrapper_fast(int b, int c, int n, int npoints, 10 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); 11 | 12 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 13 | const float *points, const int *idx, float *out, cudaStream_t stream); 14 | 15 | 16 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); 18 | 19 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); 21 | 22 | 23 | int furthest_point_sampling_wrapper(int b, int n, int m, 24 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); 25 | 26 | void furthest_point_sampling_kernel_launcher(int b, int n, int m, 27 | const float *dataset, float *temp, int *idxs, cudaStream_t stream); 28 | 29 | #endif 30 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/tools/_init_path.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../')) 3 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/tools/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch.utils.data as torch_data 4 | import kitti_utils 5 | import cv2 6 | from PIL import Image 7 | 8 | 9 | USE_INTENSITY = False 10 | 11 | 12 | class KittiDataset(torch_data.Dataset): 13 | def __init__(self, root_dir, split='train', mode='TRAIN'): 14 | self.split = split 15 | self.mode = mode 16 | self.classes = ['Car'] 17 | is_test = self.split == 'test' 18 | self.imageset_dir = os.path.join(root_dir, 'KITTI', 'object', 'testing' if is_test else 'training') 19 | 20 | split_dir = 
os.path.join(root_dir, 'KITTI', 'ImageSets', split + '.txt') 21 | self.image_idx_list = [x.strip() for x in open(split_dir).readlines()] 22 | self.sample_id_list = [int(sample_id) for sample_id in self.image_idx_list] 23 | self.num_sample = self.image_idx_list.__len__() 24 | 25 | self.npoints = 16384 26 | 27 | self.image_dir = os.path.join(self.imageset_dir, 'image_2') 28 | self.lidar_dir = os.path.join(self.imageset_dir, 'velodyne') 29 | self.calib_dir = os.path.join(self.imageset_dir, 'calib') 30 | self.label_dir = os.path.join(self.imageset_dir, 'label_2') 31 | self.plane_dir = os.path.join(self.imageset_dir, 'planes') 32 | 33 | def get_image(self, idx): 34 | img_file = os.path.join(self.image_dir, '%06d.png' % idx) 35 | assert os.path.exists(img_file) 36 | return cv2.imread(img_file) # (H, W, 3) BGR mode 37 | 38 | def get_image_shape(self, idx): 39 | img_file = os.path.join(self.image_dir, '%06d.png' % idx) 40 | assert os.path.exists(img_file) 41 | im = Image.open(img_file) 42 | width, height = im.size 43 | return height, width, 3 44 | 45 | def get_lidar(self, idx): 46 | lidar_file = os.path.join(self.lidar_dir, '%06d.bin' % idx) 47 | assert os.path.exists(lidar_file) 48 | return np.fromfile(lidar_file, dtype=np.float32).reshape(-1, 4) 49 | 50 | def get_calib(self, idx): 51 | calib_file = os.path.join(self.calib_dir, '%06d.txt' % idx) 52 | assert os.path.exists(calib_file) 53 | return kitti_utils.Calibration(calib_file) 54 | 55 | def get_label(self, idx): 56 | label_file = os.path.join(self.label_dir, '%06d.txt' % idx) 57 | assert os.path.exists(label_file) 58 | return kitti_utils.get_objects_from_label(label_file) 59 | 60 | @staticmethod 61 | def get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape): 62 | val_flag_1 = np.logical_and(pts_img[:, 0] >= 0, pts_img[:, 0] < img_shape[1]) 63 | val_flag_2 = np.logical_and(pts_img[:, 1] >= 0, pts_img[:, 1] < img_shape[0]) 64 | val_flag_merge = np.logical_and(val_flag_1, val_flag_2) 65 | pts_valid_flag = np.logical_and(val_flag_merge, pts_rect_depth >= 0) 66 | return pts_valid_flag 67 | 68 | def filtrate_objects(self, obj_list): 69 | type_whitelist = self.classes 70 | if self.mode == 'TRAIN': 71 | type_whitelist = list(self.classes) 72 | if 'Car' in self.classes: 73 | type_whitelist.append('Van') 74 | 75 | valid_obj_list = [] 76 | for obj in obj_list: 77 | if obj.cls_type not in type_whitelist: 78 | continue 79 | 80 | valid_obj_list.append(obj) 81 | return valid_obj_list 82 | 83 | def __len__(self): 84 | return len(self.sample_id_list) 85 | 86 | def __getitem__(self, index): 87 | sample_id = int(self.sample_id_list[index]) 88 | calib = self.get_calib(sample_id) 89 | img_shape = self.get_image_shape(sample_id) 90 | pts_lidar = self.get_lidar(sample_id) 91 | 92 | # get valid point (projected points should be in image) 93 | pts_rect = calib.lidar_to_rect(pts_lidar[:, 0:3]) 94 | pts_intensity = pts_lidar[:, 3] 95 | 96 | pts_img, pts_rect_depth = calib.rect_to_img(pts_rect) 97 | pts_valid_flag = self.get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape) 98 | 99 | pts_rect = pts_rect[pts_valid_flag][:, 0:3] 100 | pts_intensity = pts_intensity[pts_valid_flag] 101 | 102 | if self.npoints < len(pts_rect): 103 | pts_depth = pts_rect[:, 2] 104 | pts_near_flag = pts_depth < 40.0 105 | far_idxs_choice = np.where(pts_near_flag == 0)[0] 106 | near_idxs = np.where(pts_near_flag == 1)[0] 107 | near_idxs_choice = np.random.choice(near_idxs, self.npoints - len(far_idxs_choice), replace=False) 108 | 109 | choice = np.concatenate((near_idxs_choice, 
far_idxs_choice), axis=0) \ 110 | if len(far_idxs_choice) > 0 else near_idxs_choice 111 | np.random.shuffle(choice) 112 | else: 113 | choice = np.arange(0, len(pts_rect), dtype=np.int32) 114 | if self.npoints > len(pts_rect): 115 | extra_choice = np.random.choice(choice, self.npoints - len(pts_rect), replace=False) 116 | choice = np.concatenate((choice, extra_choice), axis=0) 117 | np.random.shuffle(choice) 118 | 119 | ret_pts_rect = pts_rect[choice, :] 120 | ret_pts_intensity = pts_intensity[choice] - 0.5 # translate intensity to [-0.5, 0.5] 121 | 122 | pts_features = [ret_pts_intensity.reshape(-1, 1)] 123 | ret_pts_features = np.concatenate(pts_features, axis=1) if pts_features.__len__() > 1 else pts_features[0] 124 | 125 | sample_info = {'sample_id': sample_id} 126 | 127 | if self.mode == 'TEST': 128 | if USE_INTENSITY: 129 | pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1) # (N, C) 130 | else: 131 | pts_input = ret_pts_rect 132 | sample_info['pts_input'] = pts_input 133 | sample_info['pts_rect'] = ret_pts_rect 134 | sample_info['pts_features'] = ret_pts_features 135 | return sample_info 136 | 137 | gt_obj_list = self.filtrate_objects(self.get_label(sample_id)) 138 | 139 | gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list) 140 | 141 | # prepare input 142 | if USE_INTENSITY: 143 | pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1) # (N, C) 144 | else: 145 | pts_input = ret_pts_rect 146 | 147 | # generate training labels 148 | cls_labels = self.generate_training_labels(ret_pts_rect, gt_boxes3d) 149 | sample_info['pts_input'] = pts_input 150 | sample_info['pts_rect'] = ret_pts_rect 151 | sample_info['cls_labels'] = cls_labels 152 | return sample_info 153 | 154 | @staticmethod 155 | def generate_training_labels(pts_rect, gt_boxes3d): 156 | cls_label = np.zeros((pts_rect.shape[0]), dtype=np.int32) 157 | gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, rotate=True) 158 | extend_gt_boxes3d = kitti_utils.enlarge_box3d(gt_boxes3d, extra_width=0.2) 159 | extend_gt_corners = kitti_utils.boxes3d_to_corners3d(extend_gt_boxes3d, rotate=True) 160 | for k in range(gt_boxes3d.shape[0]): 161 | box_corners = gt_corners[k] 162 | fg_pt_flag = kitti_utils.in_hull(pts_rect, box_corners) 163 | cls_label[fg_pt_flag] = 1 164 | 165 | # enlarge the bbox3d, ignore nearby points 166 | extend_box_corners = extend_gt_corners[k] 167 | fg_enlarge_flag = kitti_utils.in_hull(pts_rect, extend_box_corners) 168 | ignore_flag = np.logical_xor(fg_pt_flag, fg_enlarge_flag) 169 | cls_label[ignore_flag] = -1 170 | 171 | return cls_label 172 | 173 | def collate_batch(self, batch): 174 | batch_size = batch.__len__() 175 | ans_dict = {} 176 | 177 | for key in batch[0].keys(): 178 | if isinstance(batch[0][key], np.ndarray): 179 | ans_dict[key] = np.concatenate([batch[k][key][np.newaxis, ...] 
for k in range(batch_size)], axis=0) 180 | 181 | else: 182 | ans_dict[key] = [batch[k][key] for k in range(batch_size)] 183 | if isinstance(batch[0][key], int): 184 | ans_dict[key] = np.array(ans_dict[key], dtype=np.int32) 185 | elif isinstance(batch[0][key], float): 186 | ans_dict[key] = np.array(ans_dict[key], dtype=np.float32) 187 | 188 | return ans_dict 189 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/tools/kitti_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial import Delaunay 3 | import scipy 4 | 5 | 6 | def cls_type_to_id(cls_type): 7 | type_to_id = {'Car': 1, 'Pedestrian': 2, 'Cyclist': 3, 'Van': 4} 8 | if cls_type not in type_to_id.keys(): 9 | return -1 10 | return type_to_id[cls_type] 11 | 12 | 13 | class Object3d(object): 14 | def __init__(self, line): 15 | label = line.strip().split(' ') 16 | self.src = line 17 | self.cls_type = label[0] 18 | self.cls_id = cls_type_to_id(self.cls_type) 19 | self.trucation = float(label[1]) 20 | self.occlusion = float(label[2]) # 0:fully visible 1:partly occluded 2:largely occluded 3:unknown 21 | self.alpha = float(label[3]) 22 | self.box2d = np.array((float(label[4]), float(label[5]), float(label[6]), float(label[7])), dtype=np.float32) 23 | self.h = float(label[8]) 24 | self.w = float(label[9]) 25 | self.l = float(label[10]) 26 | self.pos = np.array((float(label[11]), float(label[12]), float(label[13])), dtype=np.float32) 27 | self.dis_to_cam = np.linalg.norm(self.pos) 28 | self.ry = float(label[14]) 29 | self.score = float(label[15]) if label.__len__() == 16 else -1.0 30 | self.level_str = None 31 | self.level = self.get_obj_level() 32 | 33 | def get_obj_level(self): 34 | height = float(self.box2d[3]) - float(self.box2d[1]) + 1 35 | 36 | if height >= 40 and self.trucation <= 0.15 and self.occlusion <= 0: 37 | self.level_str = 'Easy' 38 | return 1 # Easy 39 | elif height >= 25 and self.trucation <= 0.3 and self.occlusion <= 1: 40 | self.level_str = 'Moderate' 41 | return 2 # Moderate 42 | elif height >= 25 and self.trucation <= 0.5 and self.occlusion <= 2: 43 | self.level_str = 'Hard' 44 | return 3 # Hard 45 | else: 46 | self.level_str = 'UnKnown' 47 | return 4 48 | 49 | def generate_corners3d(self): 50 | """ 51 | generate corners3d representation for this object 52 | :return corners_3d: (8, 3) corners of box3d in camera coord 53 | """ 54 | l, h, w = self.l, self.h, self.w 55 | x_corners = [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2] 56 | y_corners = [0, 0, 0, 0, -h, -h, -h, -h] 57 | z_corners = [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2] 58 | 59 | R = np.array([[np.cos(self.ry), 0, np.sin(self.ry)], 60 | [0, 1, 0], 61 | [-np.sin(self.ry), 0, np.cos(self.ry)]]) 62 | corners3d = np.vstack([x_corners, y_corners, z_corners]) # (3, 8) 63 | corners3d = np.dot(R, corners3d).T 64 | corners3d = corners3d + self.pos 65 | return corners3d 66 | 67 | def to_str(self): 68 | print_str = '%s %.3f %.3f %.3f box2d: %s hwl: [%.3f %.3f %.3f] pos: %s ry: %.3f' \ 69 | % (self.cls_type, self.trucation, self.occlusion, self.alpha, self.box2d, self.h, self.w, self.l, 70 | self.pos, self.ry) 71 | return print_str 72 | 73 | def to_kitti_format(self): 74 | kitti_str = '%s %.2f %d %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f' \ 75 | % (self.cls_type, self.trucation, int(self.occlusion), self.alpha, self.box2d[0], self.box2d[1], 76 | self.box2d[2], self.box2d[3], 
self.h, self.w, self.l, self.pos[0], self.pos[1], self.pos[2], 77 | self.ry) 78 | return kitti_str 79 | 80 | 81 | def get_calib_from_file(calib_file): 82 | with open(calib_file) as f: 83 | lines = f.readlines() 84 | 85 | obj = lines[2].strip().split(' ')[1:] 86 | P2 = np.array(obj, dtype=np.float32) 87 | obj = lines[3].strip().split(' ')[1:] 88 | P3 = np.array(obj, dtype=np.float32) 89 | obj = lines[4].strip().split(' ')[1:] 90 | R0 = np.array(obj, dtype=np.float32) 91 | obj = lines[5].strip().split(' ')[1:] 92 | Tr_velo_to_cam = np.array(obj, dtype=np.float32) 93 | 94 | return {'P2': P2.reshape(3, 4), 95 | 'P3': P3.reshape(3, 4), 96 | 'R0': R0.reshape(3, 3), 97 | 'Tr_velo2cam': Tr_velo_to_cam.reshape(3, 4)} 98 | 99 | 100 | class Calibration(object): 101 | def __init__(self, calib_file): 102 | if isinstance(calib_file, str): 103 | calib = get_calib_from_file(calib_file) 104 | else: 105 | calib = calib_file 106 | 107 | self.P2 = calib['P2'] # 3 x 4 108 | self.R0 = calib['R0'] # 3 x 3 109 | self.V2C = calib['Tr_velo2cam'] # 3 x 4 110 | 111 | def cart_to_hom(self, pts): 112 | """ 113 | :param pts: (N, 3 or 2) 114 | :return pts_hom: (N, 4 or 3) 115 | """ 116 | pts_hom = np.hstack((pts, np.ones((pts.shape[0], 1), dtype=np.float32))) 117 | return pts_hom 118 | 119 | def lidar_to_rect(self, pts_lidar): 120 | """ 121 | :param pts_lidar: (N, 3) 122 | :return pts_rect: (N, 3) 123 | """ 124 | pts_lidar_hom = self.cart_to_hom(pts_lidar) 125 | pts_rect = np.dot(pts_lidar_hom, np.dot(self.V2C.T, self.R0.T)) 126 | return pts_rect 127 | 128 | def rect_to_img(self, pts_rect): 129 | """ 130 | :param pts_rect: (N, 3) 131 | :return pts_img: (N, 2) 132 | """ 133 | pts_rect_hom = self.cart_to_hom(pts_rect) 134 | pts_2d_hom = np.dot(pts_rect_hom, self.P2.T) 135 | pts_img = (pts_2d_hom[:, 0:2].T / pts_rect_hom[:, 2]).T # (N, 2) 136 | pts_rect_depth = pts_2d_hom[:, 2] - self.P2.T[3, 2] # depth in rect camera coord 137 | return pts_img, pts_rect_depth 138 | 139 | def lidar_to_img(self, pts_lidar): 140 | """ 141 | :param pts_lidar: (N, 3) 142 | :return pts_img: (N, 2) 143 | """ 144 | pts_rect = self.lidar_to_rect(pts_lidar) 145 | pts_img, pts_depth = self.rect_to_img(pts_rect) 146 | return pts_img, pts_depth 147 | 148 | 149 | def get_objects_from_label(label_file): 150 | with open(label_file, 'r') as f: 151 | lines = f.readlines() 152 | objects = [Object3d(line) for line in lines] 153 | return objects 154 | 155 | 156 | def objs_to_boxes3d(obj_list): 157 | boxes3d = np.zeros((obj_list.__len__(), 7), dtype=np.float32) 158 | for k, obj in enumerate(obj_list): 159 | boxes3d[k, 0:3], boxes3d[k, 3], boxes3d[k, 4], boxes3d[k, 5], boxes3d[k, 6] \ 160 | = obj.pos, obj.h, obj.w, obj.l, obj.ry 161 | return boxes3d 162 | 163 | 164 | def boxes3d_to_corners3d(boxes3d, rotate=True): 165 | """ 166 | :param boxes3d: (N, 7) [x, y, z, h, w, l, ry] 167 | :param rotate: 168 | :return: corners3d: (N, 8, 3) 169 | """ 170 | boxes_num = boxes3d.shape[0] 171 | h, w, l = boxes3d[:, 3], boxes3d[:, 4], boxes3d[:, 5] 172 | x_corners = np.array([l / 2., l / 2., -l / 2., -l / 2., l / 2., l / 2., -l / 2., -l / 2.], dtype=np.float32).T # (N, 8) 173 | z_corners = np.array([w / 2., -w / 2., -w / 2., w / 2., w / 2., -w / 2., -w / 2., w / 2.], dtype=np.float32).T # (N, 8) 174 | 175 | y_corners = np.zeros((boxes_num, 8), dtype=np.float32) 176 | y_corners[:, 4:8] = -h.reshape(boxes_num, 1).repeat(4, axis=1) # (N, 8) 177 | 178 | if rotate: 179 | ry = boxes3d[:, 6] 180 | zeros, ones = np.zeros(ry.size, dtype=np.float32), np.ones(ry.size, dtype=np.float32) 
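# build one rotation matrix about the camera y-axis per box; rot_list below has shape (3, 3, N) and is transposed to (N, 3, 3)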
181 | rot_list = np.array([[np.cos(ry), zeros, -np.sin(ry)], 182 | [zeros, ones, zeros], 183 | [np.sin(ry), zeros, np.cos(ry)]]) # (3, 3, N) 184 | R_list = np.transpose(rot_list, (2, 0, 1)) # (N, 3, 3) 185 | 186 | temp_corners = np.concatenate((x_corners.reshape(-1, 8, 1), y_corners.reshape(-1, 8, 1), 187 | z_corners.reshape(-1, 8, 1)), axis=2) # (N, 8, 3) 188 | rotated_corners = np.matmul(temp_corners, R_list) # (N, 8, 3) 189 | x_corners, y_corners, z_corners = rotated_corners[:, :, 0], rotated_corners[:, :, 1], rotated_corners[:, :, 2] 190 | 191 | x_loc, y_loc, z_loc = boxes3d[:, 0], boxes3d[:, 1], boxes3d[:, 2] 192 | 193 | x = x_loc.reshape(-1, 1) + x_corners.reshape(-1, 8) 194 | y = y_loc.reshape(-1, 1) + y_corners.reshape(-1, 8) 195 | z = z_loc.reshape(-1, 1) + z_corners.reshape(-1, 8) 196 | 197 | corners = np.concatenate((x.reshape(-1, 8, 1), y.reshape(-1, 8, 1), z.reshape(-1, 8, 1)), axis=2) 198 | 199 | return corners.astype(np.float32) 200 | 201 | 202 | def enlarge_box3d(boxes3d, extra_width): 203 | """ 204 | :param boxes3d: (N, 7) [x, y, z, h, w, l, ry] 205 | """ 206 | if isinstance(boxes3d, np.ndarray): 207 | large_boxes3d = boxes3d.copy() 208 | else: 209 | large_boxes3d = boxes3d.clone() 210 | large_boxes3d[:, 3:6] += extra_width * 2 211 | large_boxes3d[:, 1] += extra_width 212 | return large_boxes3d 213 | 214 | 215 | def in_hull(p, hull): 216 | """ 217 | :param p: (N, K) test points 218 | :param hull: (M, K) M corners of a box 219 | :return (N) bool 220 | """ 221 | try: 222 | if not isinstance(hull, Delaunay): 223 | hull = Delaunay(hull) 224 | flag = hull.find_simplex(p) >= 0 225 | except scipy.spatial.qhull.QhullError: 226 | print('Warning: not a hull %s' % str(hull)) 227 | flag = np.zeros(p.shape[0], dtype=np.bool) 228 | 229 | return flag 230 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/tools/pointnet2_msg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import sys 4 | sys.path.append('..') 5 | from pointnet2.pointnet2_modules import PointnetFPModule, PointnetSAModuleMSG 6 | import pointnet2.pytorch_utils as pt_utils 7 | 8 | 9 | def get_model(input_channels=0): 10 | return Pointnet2MSG(input_channels=input_channels) 11 | 12 | 13 | NPOINTS = [4096, 1024, 256, 64] 14 | RADIUS = [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]] 15 | NSAMPLE = [[16, 32], [16, 32], [16, 32], [16, 32]] 16 | MLPS = [[[16, 16, 32], [32, 32, 64]], [[64, 64, 128], [64, 96, 128]], 17 | [[128, 196, 256], [128, 196, 256]], [[256, 256, 512], [256, 384, 512]]] 18 | FP_MLPS = [[128, 128], [256, 256], [512, 512], [512, 512]] 19 | CLS_FC = [128] 20 | DP_RATIO = 0.5 21 | 22 | 23 | class Pointnet2MSG(nn.Module): 24 | def __init__(self, input_channels=6): 25 | super().__init__() 26 | 27 | self.SA_modules = nn.ModuleList() 28 | channel_in = input_channels 29 | 30 | skip_channel_list = [input_channels] 31 | for k in range(NPOINTS.__len__()): 32 | mlps = MLPS[k].copy() 33 | channel_out = 0 34 | for idx in range(mlps.__len__()): 35 | mlps[idx] = [channel_in] + mlps[idx] 36 | channel_out += mlps[idx][-1] 37 | 38 | self.SA_modules.append( 39 | PointnetSAModuleMSG( 40 | npoint=NPOINTS[k], 41 | radii=RADIUS[k], 42 | nsamples=NSAMPLE[k], 43 | mlps=mlps, 44 | use_xyz=True, 45 | bn=True 46 | ) 47 | ) 48 | skip_channel_list.append(channel_out) 49 | channel_in = channel_out 50 | 51 | self.FP_modules = nn.ModuleList() 52 | 53 | for k in range(FP_MLPS.__len__()): 54 | 
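# FP level k takes its input from the next deeper FP level (or from the last SA output for the deepest level), concatenated with the skip features of SA level k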
pre_channel = FP_MLPS[k + 1][-1] if k + 1 < len(FP_MLPS) else channel_out 55 | self.FP_modules.append( 56 | PointnetFPModule(mlp=[pre_channel + skip_channel_list[k]] + FP_MLPS[k]) 57 | ) 58 | 59 | cls_layers = [] 60 | pre_channel = FP_MLPS[0][-1] 61 | for k in range(0, CLS_FC.__len__()): 62 | cls_layers.append(pt_utils.Conv1d(pre_channel, CLS_FC[k], bn=True)) 63 | pre_channel = CLS_FC[k] 64 | cls_layers.append(pt_utils.Conv1d(pre_channel, 1, activation=None)) 65 | cls_layers.insert(1, nn.Dropout(0.5)) 66 | self.cls_layer = nn.Sequential(*cls_layers) 67 | 68 | def _break_up_pc(self, pc): 69 | xyz = pc[..., 0:3].contiguous() 70 | features = ( 71 | pc[..., 3:].transpose(1, 2).contiguous() 72 | if pc.size(-1) > 3 else None 73 | ) 74 | 75 | return xyz, features 76 | 77 | def forward(self, pointcloud: torch.cuda.FloatTensor): 78 | xyz, features = self._break_up_pc(pointcloud) 79 | 80 | l_xyz, l_features = [xyz], [features] 81 | for i in range(len(self.SA_modules)): 82 | li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) 83 | 84 | print(li_xyz.shape, li_features.shape) 85 | 86 | l_xyz.append(li_xyz) 87 | l_features.append(li_features) 88 | 89 | for i in range(-1, -(len(self.FP_modules) + 1), -1): 90 | l_features[i - 1] = self.FP_modules[i]( 91 | l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i] 92 | ) 93 | 94 | pred_cls = self.cls_layer(l_features[0]).transpose(1, 2).contiguous() # (B, N, 1) 95 | return pred_cls 96 | 97 | if __name__ == '__main__': 98 | net = Pointnet2MSG(0).cuda() 99 | pts = torch.randn(2, 1024, 3).cuda() 100 | 101 | pre = net(pts) 102 | print(pre.shape) 103 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnet2_utils/tools/train_and_eval.py: -------------------------------------------------------------------------------- 1 | import _init_path 2 | import numpy as np 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.optim.lr_scheduler as lr_sched 8 | from torch.nn.utils import clip_grad_norm_ 9 | from torch.utils.data import DataLoader 10 | import tensorboard_logger as tb_log 11 | from dataset import KittiDataset 12 | import argparse 13 | import importlib 14 | 15 | parser = argparse.ArgumentParser(description="Arg parser") 16 | parser.add_argument("--batch_size", type=int, default=8) 17 | parser.add_argument("--epochs", type=int, default=100) 18 | parser.add_argument("--ckpt_save_interval", type=int, default=5) 19 | parser.add_argument('--workers', type=int, default=4) 20 | parser.add_argument("--mode", type=str, default='train') 21 | parser.add_argument("--ckpt", type=str, default='None') 22 | 23 | parser.add_argument("--net", type=str, default='pointnet2_msg') 24 | 25 | parser.add_argument('--lr', type=float, default=0.002) 26 | parser.add_argument('--lr_decay', type=float, default=0.2) 27 | parser.add_argument('--lr_clip', type=float, default=0.000001) 28 | parser.add_argument('--decay_step_list', type=list, default=[50, 70, 80, 90]) 29 | parser.add_argument('--weight_decay', type=float, default=0.001) 30 | 31 | parser.add_argument("--output_dir", type=str, default='output') 32 | parser.add_argument("--extra_tag", type=str, default='default') 33 | 34 | args = parser.parse_args() 35 | 36 | FG_THRESH = 0.3 37 | 38 | 39 | def log_print(info, log_f=None): 40 | print(info) 41 | if log_f is not None: 42 | print(info, file=log_f) 43 | 44 | 45 | class DiceLoss(nn.Module): 46 | def __init__(self, ignore_target=-1): 47 | super().__init__() 48 | 
self.ignore_target = ignore_target 49 | 50 | def forward(self, input, target): 51 | """ 52 | :param input: (N), logit 53 | :param target: (N), {0, 1} 54 | :return: 55 | """ 56 | input = torch.sigmoid(input.view(-1)) 57 | target = target.float().view(-1) 58 | mask = (target != self.ignore_target).float() 59 | return 1.0 - (torch.min(input, target) * mask).sum() / torch.clamp((torch.max(input, target) * mask).sum(), min=1.0) 60 | 61 | 62 | def train_one_epoch(model, train_loader, optimizer, epoch, lr_scheduler, total_it, tb_log, log_f): 63 | model.train() 64 | log_print('===============TRAIN EPOCH %d================' % epoch, log_f=log_f) 65 | loss_func = DiceLoss(ignore_target=-1) 66 | 67 | for it, batch in enumerate(train_loader): 68 | optimizer.zero_grad() 69 | 70 | pts_input, cls_labels = batch['pts_input'], batch['cls_labels'] 71 | pts_input = torch.from_numpy(pts_input).cuda(non_blocking=True).float() 72 | cls_labels = torch.from_numpy(cls_labels).cuda(non_blocking=True).long().view(-1) 73 | 74 | pred_cls = model(pts_input) 75 | pred_cls = pred_cls.view(-1) 76 | 77 | loss = loss_func(pred_cls, cls_labels) 78 | loss.backward() 79 | clip_grad_norm_(model.parameters(), 1.0) 80 | optimizer.step() 81 | 82 | total_it += 1 83 | 84 | pred_class = (torch.sigmoid(pred_cls) > FG_THRESH) 85 | fg_mask = cls_labels > 0 86 | correct = ((pred_class.long() == cls_labels) & fg_mask).float().sum() 87 | union = fg_mask.sum().float() + (pred_class > 0).sum().float() - correct 88 | iou = correct / torch.clamp(union, min=1.0) 89 | 90 | cur_lr = lr_scheduler.get_lr()[0] 91 | tb_log.log_value('learning_rate', cur_lr, epoch) 92 | if tb_log is not None: 93 | tb_log.log_value('train_loss', loss, total_it) 94 | tb_log.log_value('train_fg_iou', iou, total_it) 95 | 96 | log_print('training epoch %d: it=%d/%d, total_it=%d, loss=%.5f, fg_iou=%.3f, lr=%f' % 97 | (epoch, it, len(train_loader), total_it, loss.item(), iou.item(), cur_lr), log_f=log_f) 98 | 99 | return total_it 100 | 101 | 102 | def eval_one_epoch(model, eval_loader, epoch, tb_log=None, log_f=None): 103 | model.train() 104 | log_print('===============EVAL EPOCH %d================' % epoch, log_f=log_f) 105 | 106 | iou_list = [] 107 | for it, batch in enumerate(eval_loader): 108 | pts_input, cls_labels = batch['pts_input'], batch['cls_labels'] 109 | pts_input = torch.from_numpy(pts_input).cuda(non_blocking=True).float() 110 | cls_labels = torch.from_numpy(cls_labels).cuda(non_blocking=True).long().view(-1) 111 | 112 | pred_cls = model(pts_input) 113 | pred_cls = pred_cls.view(-1) 114 | 115 | pred_class = (torch.sigmoid(pred_cls) > FG_THRESH) 116 | fg_mask = cls_labels > 0 117 | correct = ((pred_class.long() == cls_labels) & fg_mask).float().sum() 118 | union = fg_mask.sum().float() + (pred_class > 0).sum().float() - correct 119 | iou = correct / torch.clamp(union, min=1.0) 120 | 121 | iou_list.append(iou.item()) 122 | log_print('EVAL: it=%d/%d, iou=%.3f' % (it, len(eval_loader), iou), log_f=log_f) 123 | 124 | iou_list = np.array(iou_list) 125 | avg_iou = iou_list.mean() 126 | if tb_log is not None: 127 | tb_log.log_value('eval_fg_iou', avg_iou, epoch) 128 | 129 | log_print('\nEpoch %d: Average IoU (samples=%d): %.6f' % (epoch, iou_list.__len__(), avg_iou), log_f=log_f) 130 | return avg_iou 131 | 132 | 133 | def save_checkpoint(model, epoch, ckpt_name): 134 | if isinstance(model, torch.nn.DataParallel): 135 | model_state = model.module.state_dict() 136 | else: 137 | model_state = model.state_dict() 138 | 139 | state = {'epoch': epoch, 'model_state': 
model_state} 140 | ckpt_name = '{}.pth'.format(ckpt_name) 141 | torch.save(state, ckpt_name) 142 | 143 | 144 | def load_checkpoint(model, filename): 145 | if os.path.isfile(filename): 146 | log_print("==> Loading from checkpoint %s" % filename) 147 | checkpoint = torch.load(filename) 148 | epoch = checkpoint['epoch'] 149 | model.load_state_dict(checkpoint['model_state']) 150 | log_print("==> Done") 151 | else: 152 | raise FileNotFoundError 153 | 154 | return epoch 155 | 156 | 157 | def train_and_eval(model, train_loader, eval_loader, tb_log, ckpt_dir, log_f): 158 | model.cuda() 159 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 160 | 161 | def lr_lbmd(cur_epoch): 162 | cur_decay = 1 163 | for decay_step in args.decay_step_list: 164 | if cur_epoch >= decay_step: 165 | cur_decay = cur_decay * args.lr_decay 166 | return max(cur_decay, args.lr_clip / args.lr) 167 | 168 | lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd) 169 | 170 | total_it = 0 171 | for epoch in range(1, args.epochs + 1): 172 | lr_scheduler.step(epoch) 173 | total_it = train_one_epoch(model, train_loader, optimizer, epoch, lr_scheduler, total_it, tb_log, log_f) 174 | 175 | if epoch % args.ckpt_save_interval == 0: 176 | with torch.no_grad(): 177 | avg_iou = eval_one_epoch(model, eval_loader, epoch, tb_log, log_f) 178 | ckpt_name = os.path.join(ckpt_dir, 'checkpoint_epoch_%d' % epoch) 179 | save_checkpoint(model, epoch, ckpt_name) 180 | 181 | 182 | if __name__ == '__main__': 183 | MODEL = importlib.import_module(args.net) # import network module 184 | model = MODEL.get_model(input_channels=0) 185 | 186 | eval_set = KittiDataset(root_dir='./data', mode='EVAL', split='val') 187 | eval_loader = DataLoader(eval_set, batch_size=args.batch_size, shuffle=False, pin_memory=True, 188 | num_workers=args.workers, collate_fn=eval_set.collate_batch) 189 | 190 | if args.mode == 'train': 191 | train_set = KittiDataset(root_dir='./data', mode='TRAIN', split='train') 192 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, pin_memory=True, 193 | num_workers=args.workers, collate_fn=train_set.collate_batch) 194 | # output dir config 195 | output_dir = os.path.join(args.output_dir, args.extra_tag) 196 | os.makedirs(output_dir, exist_ok=True) 197 | tb_log.configure(os.path.join(output_dir, 'tensorboard')) 198 | ckpt_dir = os.path.join(output_dir, 'ckpt') 199 | os.makedirs(ckpt_dir, exist_ok=True) 200 | 201 | log_file = os.path.join(output_dir, 'log.txt') 202 | log_f = open(log_file, 'w') 203 | 204 | for key, val in vars(args).items(): 205 | log_print("{:16} {}".format(key, val), log_f=log_f) 206 | 207 | # train and eval 208 | train_and_eval(model, train_loader, eval_loader, tb_log, ckpt_dir, log_f) 209 | log_f.close() 210 | elif args.mode == 'eval': 211 | epoch = load_checkpoint(model, args.ckpt) 212 | model.cuda() 213 | with torch.no_grad(): 214 | avg_iou = eval_one_epoch(model, eval_loader, epoch) 215 | else: 216 | raise NotImplementedError 217 | 218 | -------------------------------------------------------------------------------- /networks/pts_encoder/pointnets.py: -------------------------------------------------------------------------------- 1 | """refer to https://github.com/fxia22/pointnet.pytorch/blob/f0c2430b0b1529e3f76fb5d6cd6ca14be763d975/pointnet/model.py.""" 2 | 3 | from __future__ import print_function 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.utils.data 8 | from torch.autograd import Variable 9 | from ipdb import 
set_trace 10 | import numpy as np 11 | import torch.nn.functional as F 12 | 13 | 14 | class STN3d(nn.Module): 15 | def __init__(self): 16 | super(STN3d, self).__init__() 17 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 18 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 19 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 20 | self.fc1 = nn.Linear(1024, 512) 21 | self.fc2 = nn.Linear(512, 256) 22 | self.fc3 = nn.Linear(256, 9) 23 | self.relu = nn.ReLU() 24 | 25 | def forward(self, x): 26 | batchsize = x.size()[0] 27 | x = F.relu(self.conv1(x)) 28 | x = F.relu(self.conv2(x)) 29 | x = F.relu(self.conv3(x)) 30 | x = torch.max(x, 2, keepdim=True)[0] 31 | x = x.view(-1, 1024) 32 | 33 | x = F.relu(self.fc1(x)) 34 | x = F.relu(self.fc2(x)) 35 | x = self.fc3(x) 36 | 37 | iden = Variable(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1], dtype=torch.float32)).view(1, 9).repeat(batchsize, 1) 38 | if x.is_cuda: 39 | iden = iden.cuda() 40 | x = x + iden 41 | x = x.view(-1, 3, 3) 42 | return x 43 | 44 | 45 | class STNkd(nn.Module): 46 | def __init__(self, k=64): 47 | super(STNkd, self).__init__() 48 | self.conv1 = torch.nn.Conv1d(k, 64, 1) 49 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 50 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 51 | self.fc1 = nn.Linear(1024, 512) 52 | self.fc2 = nn.Linear(512, 256) 53 | self.fc3 = nn.Linear(256, k * k) 54 | self.relu = nn.ReLU() 55 | 56 | self.k = k 57 | 58 | def forward(self, x): 59 | batchsize = x.size()[0] 60 | x = F.relu(self.conv1(x)) 61 | x = F.relu(self.conv2(x)) 62 | x = F.relu(self.conv3(x)) 63 | x = torch.max(x, 2, keepdim=True)[0] 64 | x = x.view(-1, 1024) 65 | 66 | x = F.relu(self.fc1(x)) 67 | x = F.relu(self.fc2(x)) 68 | x = self.fc3(x) 69 | 70 | iden = ( 71 | Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))) 72 | .view(1, self.k * self.k) 73 | .repeat(batchsize, 1) 74 | ) 75 | if x.is_cuda: 76 | iden = iden.cuda() 77 | x = x + iden 78 | x = x.view(-1, self.k, self.k) 79 | return x 80 | 81 | 82 | # NOTE: removed BN 83 | class PointNetfeat(nn.Module): 84 | def __init__(self, num_points, global_feat=True, in_dim=3, out_dim=1024, feature_transform=False, **args): 85 | super(PointNetfeat, self).__init__() 86 | self.num_points = num_points 87 | self.out_dim = out_dim 88 | self.feature_transform = feature_transform 89 | # self.stn = STN3d(in_dim=in_dim) 90 | self.stn = STNkd(k=in_dim) 91 | self.conv1 = torch.nn.Conv1d(in_dim, 64, 1) 92 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 93 | self.conv3 = torch.nn.Conv1d(128, 512, 1) 94 | self.conv4 = torch.nn.Conv1d(512, out_dim, 1) 95 | self.global_feat = global_feat 96 | if self.feature_transform: 97 | self.fstn = STNkd(k=64) 98 | 99 | def forward(self, x, **args): 100 | n_pts = x.shape[2] 101 | trans = self.stn(x) 102 | x = x.transpose(2, 1) 103 | x = torch.bmm(x, trans) 104 | x = x.transpose(2, 1) 105 | x = F.relu(self.conv1(x)) 106 | 107 | if self.feature_transform: 108 | trans_feat = self.fstn(x) 109 | x = x.transpose(2, 1) 110 | x = torch.bmm(x, trans_feat) 111 | x = x.transpose(2, 1) 112 | 113 | pointfeat = x 114 | x = F.relu(self.conv2(x)) 115 | x = F.relu(self.conv3(x)) 116 | x = self.conv4(x) 117 | x = torch.max(x, 2, keepdim=True)[0] 118 | x = x.view(-1, self.out_dim) 119 | if self.global_feat: 120 | return x 121 | else: 122 | x = x.view(-1, self.out_dim, 1).repeat(1, 1, n_pts) 123 | return torch.cat([x, pointfeat], 1) 124 | 125 | 126 | def feature_transform_regularizer(trans): 127 | d = trans.size()[1] 128 | batchsize = trans.size()[0] 129 | I = torch.eye(d)[None, :, :] 130 | if trans.is_cuda: 131 | I = 
I.cuda() 132 | loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2))) 133 | return loss 134 | 135 | 136 | if __name__ == "__main__": 137 | sim_data = Variable(torch.rand(32, 3, 2500)) 138 | trans = STN3d() 139 | out = trans(sim_data) 140 | print("stn", out.size()) 141 | print("loss", feature_transform_regularizer(out)) 142 | 143 | sim_data_64d = Variable(torch.rand(32, 64, 2500)) 144 | trans = STNkd(k=64) 145 | out = trans(sim_data_64d) 146 | print("stn64d", out.size()) 147 | print("loss", feature_transform_regularizer(out)) 148 | 149 | pointfeat_g = PointNetfeat(global_feat=True, num_points=2500) 150 | out = pointfeat_g(sim_data) 151 | print("global feat", out.size()) 152 | 153 | pointfeat = PointNetfeat(global_feat=False, num_points=2500) 154 | out = pointfeat(sim_data) 155 | print("point feat", out.size()) 156 | 157 | -------------------------------------------------------------------------------- /networks/reward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import os 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | sys.path.append(os.getcwd()) 8 | 9 | from ipdb import set_trace 10 | from utils.genpose_utils import get_pose_dim 11 | from utils.metrics import get_metrics 12 | 13 | 14 | class RewardModel(nn.Module): 15 | def __init__(self, pose_mode): 16 | """ 17 | init func. 18 | 19 | Args: 20 | pose_mode (str): rotation representation (e.g. 'quat_wxyz'); determines the input pose dimension. 21 | """ 22 | super(RewardModel, self).__init__() 23 | pose_dim = get_pose_dim(pose_mode) 24 | self.act = nn.ReLU(True) 25 | 26 | ''' encode pose ''' 27 | self.pose_encoder = nn.Sequential( 28 | nn.Linear(pose_dim, 256), 29 | self.act, 30 | nn.Linear(256, 256), 31 | self.act, 32 | ) 33 | 34 | ''' decoder ''' 35 | self.reward_layer = nn.Sequential( 36 | nn.Linear(1024+256, 256), 37 | self.act, 38 | nn.Linear(256, 2), 39 | ) 40 | 41 | def forward( 42 | self, 43 | pts_feature, 44 | pose 45 | ): 46 | """ 47 | calculate the score of every pose 48 | 49 | Args: 50 | pts_feature (torch.tensor): [batch, 1024] 51 | pose (torch.tensor): [batch, pose_dim] 52 | Returns: 53 | reward (torch.tensor): [batch, 2], the score of the pose estimation results, 54 | the first item is rotation score and the second item is translation score.
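Example (shapes only; the tensors and variable names below are illustrative, not part of the repo):
feat = torch.randn(8, 1024)        # PointNet++ global feature
pose = torch.randn(8, 7)           # e.g. quat_wxyz rotation + translation
energy = reward_model(feat, pose)  # -> [8, 2]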
55 | """ 56 | 57 | pose_feature = self.pose_encoder(pose) 58 | feature = torch.cat((pts_feature, pose_feature), dim=-1) # [bs, 1024+256] 59 | reward = self.reward_layer(feature) # (batch, 1) 60 | return reward 61 | 62 | 63 | def sort_results(energy, metrics): 64 | """ Sorting the results according to the pose error (low to high) 65 | 66 | Args: 67 | energy (torch.tensor): [bs, repeat_num, 2] 68 | metrics (torch.tensor): [bs, repeat_num, 2] 69 | 70 | Return: 71 | sorted_energy (torch.tensor): [bs, repeat_num, 2] 72 | """ 73 | rot_error = metrics[..., 0] 74 | trans_error = metrics[..., 1] 75 | 76 | rot_index = torch.argsort(rot_error, dim=1, descending=False) 77 | trans_index = torch.argsort(trans_error, dim=1, descending=False) 78 | 79 | sorted_energy = energy.clone() 80 | sorted_energy[..., 0] = energy[..., 0].gather(1, rot_index) 81 | sorted_energy[..., 1] = energy[..., 1].gather(1, trans_index) 82 | 83 | return sorted_energy 84 | 85 | 86 | # def ranking_loss(energy): 87 | # """ Calculate the ranking loss 88 | 89 | # Args: 90 | # energy (torch.tensor): [bs, repeat_num, 2] 91 | 92 | # Returns: 93 | # loss (torch.tensor) 94 | # """ 95 | # loss, count = 0, 0 96 | # repeat_num = energy.shape[1] 97 | 98 | # for i in range(repeat_num - 1): 99 | # for j in range(i+1, repeat_num): 100 | # # diff = torch.log(torch.sigmoid(score[:, i, :] - score[:, j, :])) 101 | # diff = torch.sigmoid(-energy[:, i, :] + energy[:, j, :]) 102 | # loss += torch.mean(diff) 103 | # count += 1 104 | # loss = loss / count 105 | # return loss 106 | 107 | 108 | 109 | def ranking_loss(energy): 110 | """ Calculate the ranking loss 111 | 112 | Args: 113 | energy (torch.tensor): [bs, repeat_num, 2] 114 | 115 | Returns: 116 | loss (torch.tensor) 117 | """ 118 | loss, count = 0, 0 119 | repeat_num = energy.shape[1] 120 | 121 | for i in range(repeat_num - 1): 122 | for j in range(i+1, repeat_num): 123 | # diff = torch.log(torch.sigmoid(score[:, i, :] - score[:, j, :])) 124 | diff = 1 + (-energy[:, i, :] + energy[:, j, :]) / (torch.abs(energy[:, i, :] - energy[:, j, :]) + 1e-5) 125 | loss += torch.mean(diff) 126 | count += 1 127 | loss = loss / count 128 | return loss 129 | 130 | 131 | def sort_poses_by_energy(poses, energy): 132 | """ Rank the poses from highest to lowest energy 133 | 134 | Args: 135 | poses (torch.tensor): [bs, inference_num, pose_dim] 136 | energy (torch.tensor): [bs, inference_num, 2] 137 | 138 | Returns: 139 | sorted_poses (torch.tensor): [bs, inference_num, pose_dim] 140 | sorted_energy (torch.tensor): [bs, inference_num, 2] 141 | """ 142 | # get the sorted energy 143 | bs = poses.shape[0] 144 | repeat_num= poses.shape[1] 145 | sorted_energy, indices_1 = torch.sort(energy, descending=True, dim=1) 146 | indices_0 = torch.arange(0, energy.shape[0]).view(1, -1).to(energy.device).repeat(1, repeat_num) 147 | indices_1_rot = indices_1.permute(2, 1, 0)[0].reshape(1, -1) 148 | indices_1_trans = indices_1.permute(2, 1, 0)[1].reshape(1, -1) 149 | rot_index = torch.cat((indices_0, indices_1_rot), dim=0).cpu().numpy().tolist() 150 | trans_index = torch.cat((indices_0, indices_1_trans), dim=0).cpu().numpy().tolist() 151 | sorted_poses = poses[rot_index] 152 | sorted_poses[:, -3:] = poses[trans_index][:, -3:] 153 | sorted_poses = sorted_poses.view(repeat_num, bs, -1).permute(1, 0, 2) 154 | 155 | return sorted_poses, sorted_energy 156 | 157 | 158 | def test_ranking_loss(): 159 | energy = torch.tensor([[[100, 100], 160 | [9, 9], 161 | [8, 8], 162 | [10, 10]]]) 163 | loss = ranking_loss(energy) 164 | print(loss) 165 | 166 | 
if __name__ == '__main__': 167 | test_ranking_loss() 168 | # bs = 3 169 | # repeat_num = 5 170 | # pts_feature = torch.randn(bs, 1024) 171 | # pred_pose = torch.randn(bs, repeat_num, 7) 172 | # metrics = torch.randn(bs, repeat_num, 2) 173 | 174 | # reward_model = RewardModel(pose_mode='quat_wxyz') 175 | # reward = reward_model(pts_feature.unsqueeze(1).repeat(1, repeat_num, 1), pred_pose) 176 | # sorted_reward = sort_results(reward, metrics) 177 | # loss = ranking_loss(sorted_reward) 178 | 179 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python==4.2.0.32 2 | scipy==1.4.1 3 | numpy==1.23.5 4 | tensorboardX==2.5.1 -------------------------------------------------------------------------------- /scripts/eval_single.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python runners/evaluation_single.py \ 2 | --score_model_dir ScoreNet/ckpt_genpose.pth \ 3 | --energy_model_dir EnergyNet/ckpt_genpose.pth \ 4 | --data_path NOCS_DATASET_PATH \ 5 | --sampler_mode ode \ 6 | --max_eval_num 1000000 \ 7 | --percentage_data_for_test 1.0 \ 8 | --batch_size 256 \ 9 | --seed 0 \ 10 | --test_source real_test \ 11 | --result_dir results \ 12 | --eval_repeat_num 50 \ 13 | --pooling_mode average \ 14 | --ranker energy_ranker \ 15 | --T0 0.55 \ 16 | # --save_video \ 17 | -------------------------------------------------------------------------------- /scripts/eval_tracking.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python runners/evaluation_tracking.py \ 2 | --score_model_dir ScoreNet/ckpt_genpose.pth \ 3 | --energy_model_dir EnergyNet/ckpt_genpose.pth \ 4 | --data_path NOCS_DATASET_PATH \ 5 | --sampler_mode ode \ 6 | --max_eval_num 1000000 \ 7 | --percentage_data_for_test 1.0 \ 8 | --batch_size 256 \ 9 | --seed 0 \ 10 | --test_source aligned_real_test \ 11 | --result_dir results \ 12 | --eval_repeat_num 50 \ 13 | --pooling_mode average \ 14 | --ranker energy_ranker \ 15 | --T0 0.15 \ 16 | # --save_video \ 17 | -------------------------------------------------------------------------------- /scripts/tensorboard.sh: -------------------------------------------------------------------------------- 1 | tensorboard --logdir ./results/logs/ --port 0505 --reload_interval 1 --samples_per_plugin images=999 -------------------------------------------------------------------------------- /scripts/train_energy.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python runners/trainer.py \ 2 | --data_path NOCS_DATASET_PATH \ 3 | --log_dir EnergyNet \ 4 | --agent_type energy \ 5 | --sampler_mode ode \ 6 | --batch_size 192 \ 7 | --eval_freq 1 \ 8 | --n_epochs 200 \ 9 | --selected_classes bottle bowl camera can laptop mug \ 10 | --percentage_data_for_train 1.0 \ 11 | --percentage_data_for_test 1.0 \ 12 | --percentage_data_for_val 1.0 \ 13 | --seed 0 \ 14 | --is_train \ 15 | -------------------------------------------------------------------------------- /scripts/train_score.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python runners/trainer.py \ 2 | --data_path NOCS_DATASET_PATH \ 3 | --log_dir ScoreNet \ 4 | --agent_type score \ 5 | --sampler_mode ode \ 6 | --sampling_steps 500 \ 7 | --eval_freq 1 \ 8 | --n_epochs 1900 \ 9 | 
--percentage_data_for_train 1.0 \ 10 | --percentage_data_for_test 1.0 \ 11 | --percentage_data_for_val 1.0 \ 12 | --seed 0 \ 13 | --is_train \ 14 | -------------------------------------------------------------------------------- /utils/datasets_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | def get_2d_coord_np(width, height, low=0, high=1, fmt="CHW"): 5 | """ 6 | Args: 7 | width: 8 | height: 9 | Returns: 10 | xy: (2, height, width) 11 | """ 12 | # coords values are in [low, high] [0,1] or [-1,1] 13 | x = np.linspace(0, width-1, width, dtype=np.float32) 14 | y = np.linspace(0, height-1, height, dtype=np.float32) 15 | xy = np.asarray(np.meshgrid(x, y)) 16 | if fmt == "HWC": 17 | xy = xy.transpose(1, 2, 0) 18 | elif fmt == "CHW": 19 | pass 20 | else: 21 | raise ValueError(f"Unknown format: {fmt}") 22 | return xy 23 | 24 | 25 | def aug_bbox_DZI(hyper_params, bbox_xyxy, im_H, im_W): 26 | """Used for DZI, the augmented box is a square (maybe enlarged) 27 | Args: 28 | bbox_xyxy (np.ndarray): 29 | Returns: 30 | center, scale 31 | """ 32 | x1, y1, x2, y2 = bbox_xyxy.copy() 33 | cx = 0.5 * (x1 + x2) 34 | cy = 0.5 * (y1 + y2) 35 | bh = y2 - y1 36 | bw = x2 - x1 37 | if hyper_params['DZI_TYPE'].lower() == "uniform": 38 | scale_ratio = 1 + hyper_params['DZI_SCALE_RATIO'] * (2 * np.random.random_sample() - 1) # [1-0.25, 1+0.25] 39 | shift_ratio = hyper_params['DZI_SHIFT_RATIO'] * (2 * np.random.random_sample(2) - 1) # [-0.25, 0.25] 40 | bbox_center = np.array([cx + bw * shift_ratio[0], cy + bh * shift_ratio[1]]) # (cx, cy) 41 | scale = max(y2 - y1, x2 - x1) * scale_ratio * hyper_params['DZI_PAD_SCALE'] 42 | elif hyper_params['DZI_TYPE'].lower() == "roi10d": 43 | # shift (x1,y1), (x2,y2) by 15% in each direction 44 | _a = -0.15 45 | _b = 0.15 46 | x1 += bw * (np.random.rand() * (_b - _a) + _a) 47 | x2 += bw * (np.random.rand() * (_b - _a) + _a) 48 | y1 += bh * (np.random.rand() * (_b - _a) + _a) 49 | y2 += bh * (np.random.rand() * (_b - _a) + _a) 50 | x1 = min(max(x1, 0), im_W) 51 | x2 = min(max(x2, 0), im_W) 52 | y1 = min(max(y1, 0), im_H) 53 | y2 = min(max(y2, 0), im_H) 54 | bbox_center = np.array([0.5 * (x1 + x2), 0.5 * (y1 + y2)]) 55 | scale = max(y2 - y1, x2 - x1) * hyper_params['DZI_PAD_SCALE'] 56 | elif hyper_params['DZI_TYPE'].lower() == "truncnorm": 57 | raise NotImplementedError("DZI truncnorm not implemented yet.") 58 | else: 59 | bbox_center = np.array([cx, cy]) # (cx, cy) 60 | scale = max(y2 - y1, x2 - x1) 61 | scale = min(scale, max(im_H, im_W)) * 1.0 62 | return bbox_center, scale 63 | 64 | 65 | def aug_bbox_eval(bbox_xyxy, im_H, im_W): 66 | """Used for DZI, the augmented box is a square (maybe enlarged) 67 | Args: 68 | bbox_xyxy (np.ndarray): 69 | Returns: 70 | center, scale 71 | """ 72 | x1, y1, x2, y2 = bbox_xyxy.copy() 73 | cx = 0.5 * (x1 + x2) 74 | cy = 0.5 * (y1 + y2) 75 | bh = y2 - y1 76 | bw = x2 - x1 77 | bbox_center = np.array([cx, cy]) # (cx, cy) 78 | scale = max(y2 - y1, x2 - x1) 79 | scale = min(scale, max(im_H, im_W)) * 1.0 80 | return bbox_center, scale 81 | 82 | def crop_resize_by_warp_affine(img, center, scale, output_size, rot=0, interpolation=cv2.INTER_LINEAR): 83 | """ 84 | output_size: int or (w, h) 85 | NOTE: if img is (h,w,1), the output will be (h,w) 86 | """ 87 | if isinstance(scale, (int, float)): 88 | scale = (scale, scale) 89 | if isinstance(output_size, int): 90 | output_size = (output_size, output_size) 91 | trans = get_affine_transform(center, scale, rot, 
output_size) 92 | 93 | dst_img = cv2.warpAffine(img, trans, (int(output_size[0]), int(output_size[1])), flags=interpolation) 94 | 95 | return dst_img 96 | 97 | def get_affine_transform(center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=False): 98 | """ 99 | adapted from CenterNet: https://github.com/xingyizhou/CenterNet/blob/master/src/lib/utils/image.py 100 | center: ndarray: (cx, cy) 101 | scale: (w, h) 102 | rot: angle in deg 103 | output_size: int or (w, h) 104 | """ 105 | if isinstance(center, (tuple, list)): 106 | center = np.array(center, dtype=np.float32) 107 | 108 | if isinstance(scale, (int, float)): 109 | scale = np.array([scale, scale], dtype=np.float32) 110 | 111 | if isinstance(output_size, (int, float)): 112 | output_size = (output_size, output_size) 113 | 114 | scale_tmp = scale 115 | src_w = scale_tmp[0] 116 | dst_w = output_size[0] 117 | dst_h = output_size[1] 118 | 119 | rot_rad = np.pi * rot / 180 120 | src_dir = get_dir([0, src_w * -0.5], rot_rad) 121 | dst_dir = np.array([0, dst_w * -0.5], np.float32) 122 | 123 | src = np.zeros((3, 2), dtype=np.float32) 124 | dst = np.zeros((3, 2), dtype=np.float32) 125 | src[0, :] = center + scale_tmp * shift 126 | src[1, :] = center + src_dir + scale_tmp * shift 127 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 128 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir 129 | 130 | src[2:, :] = get_3rd_point(src[0, :], src[1, :]) 131 | dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) 132 | 133 | if inv: 134 | trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) 135 | else: 136 | trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) 137 | 138 | return trans 139 | 140 | def get_dir(src_point, rot_rad): 141 | sn, cs = np.sin(rot_rad), np.cos(rot_rad) 142 | 143 | src_result = [0, 0] 144 | src_result[0] = src_point[0] * cs - src_point[1] * sn 145 | src_result[1] = src_point[0] * sn + src_point[1] * cs 146 | 147 | return src_result 148 | 149 | def get_3rd_point(a, b): 150 | direct = a - b 151 | return b + np.array([-direct[1], direct[0]], dtype=np.float32) 152 | -------------------------------------------------------------------------------- /utils/genpose_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from ipdb import set_trace 6 | 7 | 8 | def get_pose_dim(rot_mode): 9 | assert rot_mode in ['quat_wxyz', 'quat_xyzw', 'euler_xyz', 'euler_xyz_sx_cx', 'rot_matrix'], \ 10 | f"the rotation mode {rot_mode} is not supported!" 11 | 12 | if rot_mode == 'quat_wxyz' or rot_mode == 'quat_xyzw': 13 | pose_dim = 7 14 | elif rot_mode == 'euler_xyz': 15 | pose_dim = 6 16 | elif rot_mode == 'euler_xyz_sx_cx' or rot_mode == 'rot_matrix': 17 | pose_dim = 9 18 | else: 19 | raise NotImplementedError 20 | return pose_dim 21 | 22 | ''' 23 | def rot6d_to_mat_batch(d6): 24 | """ 25 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix. 26 | Args: 27 | d6: 6D rotation representation, of size (*, 6) 28 | Returns: 29 | batch of rotation matrices of size (*, 3, 3) 30 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 31 | On the Continuity of Rotation Representations in Neural Networks. CVPR 2019. 
32 | Retrieved from http://arxiv.org/abs/1812.07035 33 | """ 34 | # poses 35 | x_raw = d6[..., 0:3] # bx3 36 | y_raw = d6[..., 3:6] # bx3 37 | 38 | x = F.normalize(x_raw, p=2, dim=-1) # bx3 39 | z = torch.cross(x, y_raw, dim=-1) # bx3 40 | z = F.normalize(z, p=2, dim=-1) # bx3 41 | y = torch.cross(z, x, dim=-1) # bx3 42 | 43 | # (*,3)x3 --> (*,3,3) 44 | return torch.stack((x, y, z), dim=-1) # (b,3,3) 45 | ''' 46 | 47 | def rot6d_to_mat_batch(d6): 48 | """ 49 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix. 50 | Args: 51 | d6: 6D rotation representation, of size (*, 6) 52 | Returns: 53 | batch of rotation matrices of size (*, 3, 3) 54 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 55 | On the Continuity of Rotation Representations in Neural Networks. CVPR 2019. 56 | Retrieved from http://arxiv.org/abs/1812.07035 57 | """ 58 | # poses 59 | x_raw = d6[..., 0:3] # bx3 60 | y_raw = d6[..., 3:6] # bx3 61 | 62 | x = x_raw / np.linalg.norm(x_raw, axis=-1, keepdims=True) # b*3 63 | z = np.cross(x, y_raw) # b*3 64 | z = z / np.linalg.norm(z, axis=-1, keepdims=True) # b*3 65 | y = np.cross(z, x) # b*3 66 | 67 | return np.stack((x, y, z), axis=-1) # (b,3,3) 68 | 69 | 70 | class TrainClock(object): 71 | """ Clock object to track epoch and step during training 72 | """ 73 | def __init__(self): 74 | self.epoch = 1 75 | self.minibatch = 0 76 | self.step = 0 77 | 78 | def tick(self): 79 | self.minibatch += 1 80 | self.step += 1 81 | 82 | def tock(self): 83 | self.epoch += 1 84 | self.minibatch = 0 85 | 86 | def make_checkpoint(self): 87 | return { 88 | 'epoch': self.epoch, 89 | 'minibatch': self.minibatch, 90 | 'step': self.step 91 | } 92 | 93 | def restore_checkpoint(self, clock_dict): 94 | self.epoch = clock_dict['epoch'] 95 | self.minibatch = clock_dict['minibatch'] 96 | self.step = clock_dict['step'] 97 | 98 | 99 | def merge_results(results_ori, results_new): 100 | if len(results_ori.keys()) == 0: 101 | return results_new 102 | else: 103 | results = { 104 | 'pred_pose': torch.cat([results_ori['pred_pose'], results_new['pred_pose']], dim=0), 105 | 'gt_pose': torch.cat([results_ori['gt_pose'], results_new['gt_pose']], dim=0), 106 | 'cls_id': torch.cat([results_ori['cls_id'], results_new['cls_id']], dim=0), 107 | 'handle_visibility': torch.cat([results_ori['handle_visibility'], results_new['handle_visibility']], dim=0), 108 | # 'path': results_ori['path'] + results_new['path'], 109 | } 110 | return results 111 | 112 | 113 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | 4 | import torch 5 | import numpy as np 6 | import pickle 7 | 8 | from utils.misc import get_rot_matrix, inverse_RT 9 | from utils.genpose_utils import get_pose_dim 10 | from ipdb import set_trace 11 | 12 | def rot_diff_rad(rot1, rot2, chosen_axis=None, flip_axis=False): 13 | if chosen_axis is not None: 14 | axis = {'x': 0, 'y': 1, 'z': 2} 15 | y1, y2 = rot1[..., axis], rot2[..., axis] # [Bs, 3] 16 | diff = torch.sum(y1 * y2, dim=-1) # [Bs] 17 | diff = torch.clamp(diff, min=-1.0, max=1.0) 18 | rad = torch.acos(diff) 19 | if not flip_axis: 20 | return rad 21 | else: 22 | return torch.min(rad, np.pi - rad) 23 | 24 | else: 25 | mat_diff = torch.matmul(rot1, rot2.transpose(-1, -2)) 26 | diff = mat_diff[..., 0, 0] + mat_diff[..., 1, 1] + mat_diff[..., 2, 2] 27 | diff = (diff - 1) / 2.0 28 | diff = torch.clamp(diff, min=-1.0, max=1.0) 
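# descriptive note: the clamp keeps the cosine term within [-1, 1] against floating-point error;
# the arccos on the next line then returns the geodesic angle theta = arccos((tr(R1 R2^T) - 1) / 2)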
29 | return torch.acos(diff) 30 | 31 | 32 | def rot_diff_degree(rot1, rot2, chosen_axis=None, flip_axis=False): 33 | return rot_diff_rad(rot1, rot2, chosen_axis=chosen_axis, flip_axis=flip_axis) / np.pi * 180.0 34 | 35 | 36 | def get_trans_error(trans_1, trans_2): 37 | diff = torch.norm(trans_1 - trans_2, dim=-1) 38 | return diff 39 | 40 | 41 | def get_rot_error(rot_1, rot_2, error_mode, chosen_axis=None, flip_axis=False): 42 | assert error_mode in ['radian', 'degree'], f"the rotation error mode {error_mode} is not supported!" 43 | if error_mode == 'radian': 44 | rot_error = rot_diff_rad(rot_1, rot_2, chosen_axis, flip_axis) 45 | else: 46 | rot_error = rot_diff_degree(rot_1, rot_2, chosen_axis, flip_axis) 47 | return rot_error 48 | 49 | 50 | def get_metrics_single_category(pose_1, pose_2, pose_mode, error_mode, chosen_axis=None, flip_axis=False, o2c_pose=False): 51 | assert pose_mode in ['quat_wxyz', 'quat_xyzw', 'euler_xyz', 'rot_matrix'],\ 52 | f"the rotation mode {pose_mode} is not supported!" 53 | 54 | if pose_mode == 'rot_matrix': 55 | index = 6 56 | elif pose_mode == 'euler_xyz': 57 | index = 3 58 | else: 59 | index = 4 60 | 61 | rot_1 = pose_1[:, :index] 62 | rot_2 = pose_2[:, :index] 63 | trans_1 = pose_1[:, index:] 64 | trans_2 = pose_2[:, index:] 65 | 66 | rot_matrix_1 = get_rot_matrix(rot_1, pose_mode) 67 | rot_matrix_2 = get_rot_matrix(rot_2, pose_mode) 68 | 69 | if o2c_pose == False: 70 | rot_matrix_1, trans_1 = inverse_RT(rot_matrix_1, trans_1) 71 | rot_matrix_2, trans_2 = inverse_RT(rot_matrix_2, trans_2) 72 | 73 | rot_error = get_rot_error(rot_matrix_1, rot_matrix_2, error_mode, chosen_axis, flip_axis) 74 | trans_error = get_trans_error(trans_1, trans_2) 75 | 76 | return rot_error.cpu().numpy(), trans_error.cpu().numpy() 77 | 78 | 79 | def compute_RT_errors(RT_1, RT_2, class_id, handle_visibility, synset_names): 80 | """ 81 | Args: 82 | sRT_1: [4, 4]. homogeneous affine transformation 83 | sRT_2: [4, 4]. homogeneous affine transformation 84 | 85 | Returns: 86 | theta: angle difference of R in degree 87 | shift: l2 difference of T in centimeter 88 | """ 89 | # make sure the last row is [0, 0, 0, 1] 90 | if RT_1 is None or RT_2 is None: 91 | return -1 92 | try: 93 | assert np.array_equal(RT_1[3, :], RT_2[3, :]) 94 | assert np.array_equal(RT_1[3, :], np.array([0, 0, 0, 1])) 95 | except AssertionError: 96 | print(RT_1[3, :], RT_2[3, :]) 97 | exit() 98 | 99 | R1 = RT_1[:3, :3] / np.cbrt(np.linalg.det(RT_1[:3, :3])) 100 | T1 = RT_1[:3, 3] 101 | R2 = RT_2[:3, :3] / np.cbrt(np.linalg.det(RT_2[:3, :3])) 102 | T2 = RT_2[:3, 3] 103 | # symmetric when rotating around y-axis 104 | if synset_names[class_id] in ['bottle', 'can', 'bowl'] or \ 105 | (synset_names[class_id] == 'mug' and handle_visibility == 0): 106 | y = np.array([0, 1, 0]) 107 | y1 = R1 @ y 108 | y2 = R2 @ y 109 | cos_theta = y1.dot(y2) / (np.linalg.norm(y1) * np.linalg.norm(y2)) 110 | else: 111 | R = R1 @ R2.transpose() 112 | cos_theta = (np.trace(R) - 1) / 2 113 | 114 | theta = np.arccos(np.clip(cos_theta, -1.0, 1.0)) * 180 / np.pi 115 | shift = np.linalg.norm(T1 - T2) * 100 116 | result = np.array([theta, shift]) 117 | 118 | return result 119 | 120 | ''' 121 | def compute_RT_overlaps(gt_class_ids, gt_sRT, gt_handle_visibility, pred_class_ids, pred_sRT, synset_names): 122 | """ Finds overlaps between prediction and ground truth instances. 
123 | 124 | Returns: 125 | overlaps: 126 | 127 | """ 128 | num_pred = len(pred_class_ids) 129 | num_gt = len(gt_class_ids) 130 | overlaps = np.zeros((num_pred, num_gt, 2)) 131 | 132 | for i in range(num_pred): 133 | for j in range(num_gt): 134 | overlaps[i, j, :] = compute_RT_errors(pred_sRT[i], gt_sRT[j], gt_class_ids[j], 135 | gt_handle_visibility[j], synset_names) 136 | return overlaps 137 | 138 | ''' 139 | 140 | 141 | def compute_RT_overlaps(class_ids, gt_RT, pred_RT, gt_handle_visibility, synset_names): 142 | """ Finds overlaps between prediction and ground truth instances. 143 | 144 | Returns: 145 | overlaps: 146 | 147 | """ 148 | num = len(class_ids) 149 | overlaps = np.zeros((num, 2)) 150 | 151 | for i in range(num): 152 | overlaps[i, :] = compute_RT_errors(pred_RT[i], gt_RT[i], class_ids[i], 153 | gt_handle_visibility[i], synset_names) 154 | return overlaps 155 | 156 | 157 | def get_metrics(pose_1, pose_2, class_ids, synset_names, gt_handle_visibility, pose_mode, o2c_pose=False): 158 | assert pose_mode in ['quat_wxyz', 'quat_xyzw', 'euler_xyz', 'euler_xyz_sx_cx', 'rot_matrix'],\ 159 | f"the rotation mode {pose_mode} is not supported!" 160 | 161 | index = get_pose_dim(pose_mode) - 3 162 | 163 | rot_1 = pose_1[:, :index] 164 | rot_2 = pose_2[:, :index] 165 | trans_1 = pose_1[:, index:] 166 | trans_2 = pose_2[:, index:] 167 | 168 | rot_matrix_1 = get_rot_matrix(rot_1, pose_mode) 169 | rot_matrix_2 = get_rot_matrix(rot_2, pose_mode) 170 | 171 | if o2c_pose == False: 172 | rot_matrix_1, trans_1 = inverse_RT(rot_matrix_1, trans_1) 173 | rot_matrix_2, trans_2 = inverse_RT(rot_matrix_2, trans_2) 174 | 175 | bs = pose_1.shape[0] 176 | RT_1 = torch.eye(4).unsqueeze(0).repeat([bs, 1, 1]) 177 | RT_2 = torch.eye(4).unsqueeze(0).repeat([bs, 1, 1]) 178 | 179 | RT_1[:, :3, :3] = rot_matrix_1 180 | RT_1[:, :3, 3] = trans_1 181 | RT_2[:, :3, :3] = rot_matrix_2 182 | RT_2[:, :3, 3] = trans_2 183 | 184 | error = compute_RT_overlaps(class_ids, RT_1.cpu().numpy(), RT_2.cpu().numpy(), gt_handle_visibility, synset_names) 185 | rot_error = error[:, 0] 186 | trans_error = error[:, 1] 187 | return rot_error, trans_error 188 | 189 | 190 | if __name__ == '__main__': 191 | gt_pose = torch.rand(8, 7) 192 | gt_pose[:, :4] /= torch.norm(gt_pose[:, :4], dim=-1, keepdim=True) 193 | noise_pose = gt_pose + torch.rand(8, 7) / 10 194 | noise_pose[:, :4] /= torch.norm(noise_pose[:, :4], dim=-1, keepdim=True) 195 | rot_error = get_rot_error(gt_pose[:, :4], noise_pose[:, :4], 'camera', 'quat_wxyz', 'degree') 196 | trans_error = get_trans_error(gt_pose[:, 4:], noise_pose[:, 4:]) 197 | print(rot_error, trans_error) 198 | 199 | 200 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import copy 4 | import pytorch3d 5 | import pytorch3d.io 6 | import torch 7 | import torch.distributed as dist 8 | from ipdb import set_trace 9 | os.sys.path.append('..') 10 | from utils.genpose_utils import get_pose_dim 11 | from scipy.spatial.transform import Rotation as R 12 | 13 | 14 | def parallel_setup(rank, world_size, seed): 15 | os.environ['MASTER_ADDR'] = 'localhost' 16 | os.environ['MASTER_PORT'] = '12355' 17 | 18 | # initialize the process group 19 | dist.init_process_group("gloo", rank=rank, world_size=world_size) 20 | 21 | # Explicitly setting seed to make sure that models created in two processes 22 | # start from same random weights and biases. 
23 | torch.manual_seed(seed) 24 | 25 | 26 | def parallel_cleanup(): 27 | dist.destroy_process_group() 28 | 29 | 30 | def exists_or_mkdir(path): 31 | if not os.path.exists(path): 32 | os.makedirs(path) 33 | return False 34 | else: 35 | return True 36 | 37 | 38 | def depth2xyz(depth_img, camera_params): 39 | # scale camera parameters 40 | h, w = depth_img.shape 41 | scale_x = w / camera_params['xres'] 42 | scale_y = h / camera_params['yres'] 43 | fx = camera_params['fx'] * scale_x 44 | fy = camera_params['fy'] * scale_y 45 | x_offset = camera_params['cx'] * scale_x 46 | y_offset = camera_params['cy'] * scale_y 47 | 48 | indices = np.indices((h, w), dtype=np.float32).transpose(1,2,0) 49 | z_e = depth_img 50 | x_e = (indices[..., 1] - x_offset) * z_e / fx 51 | y_e = (indices[..., 0] - y_offset) * z_e / fy 52 | xyz_img = np.stack([x_e, y_e, z_e], axis=-1) # Shape: [H x W x 3] 53 | return xyz_img 54 | 55 | 56 | def fps_down_sample(vertices, num_point_sampled): 57 | # FPS down sample 58 | # vertices.shape = (N,3) or (N,2) 59 | 60 | N = len(vertices) 61 | n = num_point_sampled 62 | assert n <= N, "Num of sampled point should be less than or equal to the size of vertices." 63 | _G = np.mean(vertices, axis=0) # centroid of vertices 64 | _d = np.linalg.norm(vertices - _G, axis=1, ord=2) 65 | farthest = np.argmax(_d) # Select the point farthest from the center of gravity as the starting point 66 | distances = np.inf * np.ones((N,)) 67 | flags = np.zeros((N,), np.bool_) # Whether the point is selected 68 | for i in range(n): 69 | flags[farthest] = True 70 | distances[farthest] = 0. 71 | p_farthest = vertices[farthest] 72 | dists = np.linalg.norm(vertices[~flags] - p_farthest, axis=1, ord=2) 73 | distances[~flags] = np.minimum(distances[~flags], dists) 74 | farthest = np.argmax(distances) 75 | return vertices[flags] 76 | 77 | 78 | def sample_data(data, num_sample): 79 | """ data is in N x ... 80 | we want to keep num_samplexC of them. 81 | if N > num_sample, we will randomly keep num_sample of them. 82 | if N < num_sample, we will randomly duplicate samples. 83 | """ 84 | N = data.shape[0] 85 | if (N == num_sample): 86 | return data, range(N) 87 | elif (N > num_sample): 88 | sample = np.random.choice(N, num_sample) 89 | return data[sample, ...], sample 90 | else: 91 | # print(N) 92 | sample = np.random.choice(N, num_sample-N) 93 | dup_data = data[sample, ...] 94 | return np.concatenate([data, dup_data], 0), list(range(N))+list(sample) 95 | 96 | 97 | def trans_form_quat_and_location(quaternion, location, quat_type='wxyz'): 98 | assert quat_type in ['wxyz', 'xyzw'], f"The type of quaternion {quat_type} is not supported!" 
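# descriptive note: scipy's Rotation.from_quat expects scalar-last (x, y, z, w) order, so a 'wxyz' input is reordered below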
99 | 100 | if quat_type == 'xyzw': 101 | quaternion_xyzw = quaternion 102 | else: 103 | quaternion_xyzw = [quaternion[1], quaternion[2], quaternion[3], quaternion[0]] 104 | 105 | scipy_rot = R.from_quat(quaternion_xyzw) 106 | rot = scipy_rot.as_matrix() 107 | 108 | location = location[np.newaxis, :].T 109 | transformation = np.concatenate((rot, location), axis=1) 110 | transformation = np.concatenate((transformation, np.array([[0, 0, 0, 1]])), axis=0) 111 | return transformation 112 | 113 | 114 | def get_rot_matrix(batch_pose, pose_mode='quat_wxyz'): 115 | """ 116 | pose_mode: 117 | 'quat_wxyz' -> batch_pose [B, 4] 118 | 'quat_xyzw' -> batch_pose [B, 4] 119 | 'euler_xyz' -> batch_pose [B, 3] 120 | 'rot_matrix' -> batch_pose [B, 6] 121 | 122 | Return: rot_matrix [B, 3, 3] 123 | """ 124 | assert pose_mode in ['quat_wxyz', 'quat_xyzw', 'euler_xyz', 'euler_xyz_sx_cx', 'rot_matrix'],\ 125 | f"the rotation mode {pose_mode} is not supported!" 126 | 127 | if pose_mode in ['quat_wxyz', 'quat_xyzw']: 128 | if pose_mode == 'quat_wxyz': 129 | quat_wxyz = batch_pose 130 | else: 131 | index = [3, 0, 1, 2] 132 | quat_wxyz = batch_pose[:, index] 133 | rot_mat = pytorch3d.transforms.quaternion_to_matrix(quat_wxyz) 134 | 135 | elif pose_mode == 'rot_matrix': 136 | rot_mat= pytorch3d.transforms.rotation_6d_to_matrix(batch_pose).permute(0, 2, 1) 137 | 138 | elif pose_mode == 'euler_xyz_sx_cx': 139 | rot_sin_theta = batch_pose[:, :3] 140 | rot_cos_theta = batch_pose[:, 3:6] 141 | theta = torch.atan2(rot_sin_theta, rot_cos_theta) 142 | rot_mat = pytorch3d.transforms.euler_angles_to_matrix(theta, 'ZYX') 143 | elif pose_mode == 'euler_xyz': 144 | rot_mat = pytorch3d.transforms.euler_angles_to_matrix(batch_pose, 'ZYX') 145 | else: 146 | raise NotImplementedError 147 | 148 | return rot_mat 149 | 150 | 151 | def transform_single_pts(pts, transformation): 152 | N = pts.shape[0] 153 | pts = np.concatenate((pts.T, np.ones(N)[np.newaxis, :]), axis=0) 154 | new_pts = transformation @ pts 155 | return new_pts.T[:, :3] 156 | 157 | 158 | def transform_batch_pts(batch_pts, batch_pose, pose_mode='quat_wxyz', inverse_pose=False): 159 | """ 160 | Args: 161 | batch_pts [B, N, C], N is the number of points, and C [x, y, z, ...] 162 | batch_pose [B, C], [quat/rot_mat/euler, trans] 163 | pose_mode is from ['quat_wxyz', 'quat_xyzw', 'euler_xyz', 'rot_matrix'] 164 | if inverse_pose is true, the transformation will be inversed 165 | Returns: 166 | new_pts [B, N, C] 167 | """ 168 | assert pose_mode in ['quat_wxyz', 'quat_xyzw', 'euler_xyz', 'euler_xyz_sx_cx', 'rot_matrix'],\ 169 | f"the rotation mode {pose_mode} is not supported!" 
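# descriptive note: the pose is expanded into a 4x4 homogeneous transform below and applied to the xyz
# channels of the points; any extra per-point channels are copied through unchanged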
170 | 171 | B = batch_pts.shape[0] 172 | index = get_pose_dim(pose_mode) - 3 173 | rot = batch_pose[:, :index] 174 | loc = batch_pose[:, index:] 175 | 176 | rot_mat = get_rot_matrix(rot, pose_mode) 177 | if inverse_pose == True: 178 | rot_mat, loc = inverse_RT(rot_mat, loc) 179 | loc = loc[..., np.newaxis] 180 | 181 | trans_mat = torch.cat((rot_mat, loc), dim=2) 182 | trans_mat = torch.cat((trans_mat, torch.tile(torch.tensor([[0, 0, 0, 1]]).to(trans_mat.device), (B, 1, 1))), dim=1) 183 | 184 | new_pts = copy.deepcopy(batch_pts) 185 | padding = torch.ones([batch_pts.shape[0], batch_pts.shape[1], 1]).to(batch_pts.device) 186 | pts = torch.cat((batch_pts[:, :, :3], padding), dim=2) 187 | new_pts[:, :, :3] = torch.matmul(trans_mat.to(torch.float32), pts.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] 188 | 189 | return new_pts 190 | 191 | 192 | def inverse_RT(batch_rot_mat, batch_trans): 193 | """ 194 | Args: 195 | batch_rot_mat [B, 3, 3] 196 | batch_trans [B, 3] 197 | Return: 198 | inversed_rot_mat [B, 3, 3] 199 | inversed_trans [B, 3] 200 | """ 201 | trans = batch_trans[..., np.newaxis] 202 | inversed_rot_mat = batch_rot_mat.permute(0, 2, 1) 203 | inversed_trans = - inversed_rot_mat @ trans 204 | return inversed_rot_mat, inversed_trans.squeeze(-1) 205 | 206 | 207 | """ https://arc.aiaa.org/doi/abs/10.2514/1.28949 """ 208 | """ https://stackoverflow.com/questions/12374087/average-of-multiple-quaternions """ 209 | """ http://tbirdal.blogspot.com/2019/10/i-allocate-this-post-to-providing.html """ 210 | def average_quaternion_torch(Q, weights=None): 211 | if weights is None: 212 | weights = torch.ones(len(Q), device=Q.device) / len(Q) 213 | A = torch.zeros((4, 4), device=Q.device) 214 | weight_sum = torch.sum(weights) 215 | 216 | oriented_Q = ((Q[:, 0:1] > 0).float() - 0.5) * 2 * Q 217 | A = torch.einsum("bi,bk->bik", (oriented_Q, oriented_Q)) 218 | A = torch.sum(torch.einsum("bij,b->bij", (A, weights)), 0) 219 | A /= weight_sum 220 | 221 | q_avg = torch.linalg.eigh(A)[1][:, -1] 222 | if q_avg[0] < 0: 223 | return -q_avg 224 | return q_avg 225 | 226 | 227 | def average_quaternion_batch(Q, weights=None): 228 | """calculate the average quaternion of the multiple quaternions 229 | Args: 230 | Q (tensor): [B, num_quaternions, 4] 231 | weights (tensor, optional): [B, num_quaternions]. Defaults to None. 
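Note: quaternions are first sign-aligned (q and -q describe the same rotation), and the average is the eigenvector of the weighted outer-product matrix with the largest eigenvalue (Markley et al.).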
232 | 233 | Returns: 234 | oriented_q_avg: average quaternion, [B, 4] 235 | """ 236 | 237 | if weights is None: 238 | weights = torch.ones((Q.shape[0], Q.shape[1]), device=Q.device) / Q.shape[1] 239 | A = torch.zeros((Q.shape[0], 4, 4), device=Q.device) 240 | weight_sum = torch.sum(weights, axis=-1) 241 | 242 | oriented_Q = ((Q[:, :, 0:1] > 0).float() - 0.5) * 2 * Q 243 | A = torch.einsum("abi,abk->abik", (oriented_Q, oriented_Q)) 244 | A = torch.sum(torch.einsum("abij,ab->abij", (A, weights)), 1) 245 | A /= weight_sum.reshape(A.shape[0], -1).unsqueeze(-1).repeat(1, 4, 4) 246 | 247 | q_avg = torch.linalg.eigh(A)[1][:, :, -1] 248 | oriented_q_avg = ((q_avg[:, 0:1] > 0).float() - 0.5) * 2 * q_avg 249 | return oriented_q_avg 250 | 251 | 252 | def average_quaternion_numpy(Q, W=None): 253 | if W is not None: 254 | Q *= W[:, None] 255 | eigvals, eigvecs = np.linalg.eig(Q.T@Q) 256 | return eigvecs[:, eigvals.argmax()] 257 | 258 | 259 | def normalize_rotation(rotation, rotation_mode): 260 | if rotation_mode == 'quat_wxyz' or rotation_mode == 'quat_xyzw': 261 | rotation /= torch.norm(rotation, dim=-1, keepdim=True) 262 | elif rotation_mode == 'rot_matrix': 263 | rot_matrix = get_rot_matrix(rotation, rotation_mode) 264 | rotation[:, :3] = rot_matrix[:, :, 0] 265 | rotation[:, 3:6] = rot_matrix[:, :, 1] 266 | elif rotation_mode == 'euler_xyz_sx_cx': 267 | rot_sin_theta = rotation[:, :3] 268 | rot_cos_theta = rotation[:, 3:6] 269 | theta = torch.atan2(rot_sin_theta, rot_cos_theta) 270 | rotation[:, :3] = torch.sin(theta) 271 | rotation[:, 3:6] = torch.cos(theta) 272 | elif rotation_mode == 'euler_xyz': 273 | pass 274 | else: 275 | raise NotImplementedError 276 | return rotation 277 | 278 | 279 | if __name__ == '__main__': 280 | quat = torch.randn(2, 3, 4) 281 | quat = quat / torch.linalg.norm(quat, axis=-1).unsqueeze(-1) 282 | quat = average_quaternion_batch(quat) 283 | 284 | 285 | -------------------------------------------------------------------------------- /utils/so3_visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import pytorch3d 5 | import cv2 6 | 7 | from matplotlib import rc 8 | from PIL import Image 9 | from pytorch3d import transforms 10 | from ipdb import set_trace 11 | 12 | 13 | rc("font", **{"family": "serif", "serif": ["Times New Roman"]}) 14 | EYE = np.eye(3) 15 | 16 | def visualize_so3_probabilities( 17 | rotations, 18 | probabilities, 19 | rotations_gt=None, 20 | choosed_rotation=None, 21 | ax=None, 22 | fig=None, 23 | display_threshold_probability=0, 24 | to_image=True, 25 | show_color_wheel=True, 26 | canonical_rotation=EYE, 27 | gt_size=600, 28 | choosed_size=300, 29 | y_offset=-30, 30 | dpi=600, 31 | ): 32 | """ 33 | Plot a single distribution on SO(3) using the tilt-colored method. 
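Each rotation is placed on a Mollweide projection at its (longitude, latitude) and colored by its tilt (roll) angle; ground-truth and selected rotations are drawn as larger circular markers.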
34 | Args: 35 | rotations: [N, 3, 3] tensor of rotation matrices 36 | probabilities: [N] tensor of probabilities 37 | rotations_gt: [N_gt, 3, 3] or [3, 3] ground truth rotation matrices 38 | ax: The matplotlib.pyplot.axis object to paint 39 | fig: The matplotlib.pyplot.figure object to paint 40 | display_threshold_probability: The probability threshold below which to omit 41 | the marker 42 | to_image: If True, return a tensor containing the pixels of the finished 43 | figure; if False return the figure itself 44 | show_color_wheel: If True, display the explanatory color wheel which matches 45 | color on the plot with tilt angle 46 | canonical_rotation: A [3, 3] rotation matrix representing the 'display 47 | rotation', to change the view of the distribution. It rotates the 48 | canonical axes so that the view of SO(3) on the plot is different, which 49 | can help obtain a more informative view. 50 | Returns: 51 | A matplotlib.pyplot.figure object, or a tensor of pixels if to_image=True. 52 | """ 53 | 54 | def _show_single_marker(ax, rotation, marker, edgecolors=True, facecolors=False, s=gt_size): 55 | eulers = transforms.matrix_to_euler_angles(torch.tensor(rotation), "ZXY") 56 | eulers = eulers.numpy() 57 | 58 | tilt_angle = eulers[0] 59 | latitude = eulers[1] 60 | longitude = eulers[2] 61 | 62 | color = cmap(0.5 + tilt_angle / 2 / np.pi) 63 | ax.scatter( 64 | longitude, 65 | latitude, 66 | s=s, 67 | edgecolors=color if edgecolors else "none", 68 | facecolors=facecolors if facecolors else "none", 69 | marker=marker, 70 | linewidth=5, 71 | ) 72 | 73 | if ax is None: 74 | fig = plt.figure(figsize=(4, 2), dpi=dpi) 75 | ax = fig.add_subplot(111, projection="mollweide") 76 | if rotations_gt is not None and len(rotations_gt.shape) == 2: 77 | rotations_gt = rotations_gt[None] 78 | if choosed_rotation is not None and len(choosed_rotation.shape) == 2: 79 | choosed_rotation = choosed_rotation[None] 80 | display_rotations = rotations @ canonical_rotation 81 | cmap = plt.cm.hsv 82 | scatterpoint_scaling = 4e3 83 | eulers_queries = transforms.matrix_to_euler_angles( 84 | torch.tensor(display_rotations), "ZXY" 85 | ) 86 | eulers_queries = eulers_queries.numpy() 87 | 88 | tilt_angles = eulers_queries[:, 0] 89 | longitudes = eulers_queries[:, 2] 90 | latitudes = eulers_queries[:, 1] 91 | 92 | which_to_display = probabilities > display_threshold_probability 93 | 94 | if rotations_gt is not None: 95 | display_rotations_gt = rotations_gt @ canonical_rotation 96 | 97 | for rotation in display_rotations_gt: 98 | _show_single_marker(ax, rotation, "o") 99 | # Cover up the centers with white markers 100 | for rotation in display_rotations_gt: 101 | _show_single_marker( 102 | ax, rotation, "o", edgecolors=False, facecolors="#ffffff" 103 | ) 104 | 105 | if choosed_rotation is not None: 106 | display_choosed_rotations = choosed_rotation @ canonical_rotation 107 | 108 | for rotation in display_choosed_rotations: 109 | _show_single_marker(ax, rotation, "o", s=choosed_size) 110 | # Cover up the centers with white markers 111 | for rotation in display_choosed_rotations: 112 | _show_single_marker( 113 | ax, rotation, "o", edgecolors=False, facecolors="#ffffff", s=choosed_size 114 | ) 115 | 116 | # Display the distribution 117 | ax.scatter( 118 | longitudes[which_to_display], 119 | latitudes[which_to_display], 120 | s=scatterpoint_scaling * probabilities[which_to_display], 121 | c=cmap(0.5 + tilt_angles[which_to_display] / 2.0 / np.pi), 122 | marker='.' 
123 | ) 124 | 125 | yticks = np.array([-60, -30, 0, 30, 60]) 126 | yticks_minor = np.arange(-75, 90, 15) 127 | ax.set_yticks(yticks_minor * np.pi / 180, minor=True) 128 | ax.set_yticks(yticks * np.pi / 180, [f"{y}°" for y in yticks], fontsize=10) 129 | xticks = np.array([-90, 0, 90]) 130 | xticks_minor = np.arange(-150, 180, 30) 131 | ax.set_xticks(xticks * np.pi / 180, []) 132 | ax.set_xticks(xticks_minor * np.pi / 180, minor=True) 133 | 134 | for xtick in xticks: 135 | # Manually set xticks 136 | x = xtick * np.pi / 180 137 | y = y_offset * np.pi / 180 138 | ax.text(x, y, f"{xtick}°", ha="center", va="center", fontsize=10) 139 | 140 | ax.grid(which="minor") 141 | ax.grid(which="major") 142 | 143 | if show_color_wheel: 144 | # Add a color wheel showing the tilt angle to color conversion. 145 | ax = fig.add_axes([0.85, 0.12, 0.12, 0.12], projection="polar") 146 | theta = np.linspace(-3 * np.pi / 2, np.pi / 2, 200) 147 | radii = np.linspace(0.4, 0.5, 2) 148 | _, theta_grid = np.meshgrid(radii, theta) 149 | colormap_val = 0.5 + theta_grid / np.pi / 2.0 150 | ax.pcolormesh(theta, radii, colormap_val.T, cmap=cmap, shading="auto") 151 | ax.set_yticklabels([]) 152 | ax.set_xticks(np.arange(0, 2 * np.pi, np.pi / 2)) 153 | ax.set_xticklabels( 154 | [ 155 | r"90$\degree$", 156 | r"180$\degree$", 157 | r"270$\degree$", 158 | r"0$\degree$", 159 | ], 160 | fontsize=6, 161 | ) 162 | ax.spines["polar"].set_visible(False) 163 | ax.grid(False) 164 | plt.text( 165 | 0.5, 166 | 0.5, 167 | "Roll", 168 | fontsize=6, 169 | horizontalalignment="center", 170 | verticalalignment="center", 171 | transform=ax.transAxes, 172 | ) 173 | 174 | if to_image: 175 | return plot_to_image(fig) 176 | else: 177 | return fig 178 | 179 | 180 | def plot_to_image(fig): 181 | fig.canvas.draw() 182 | image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 183 | image_from_plot = image_from_plot.reshape( 184 | fig.canvas.get_width_height()[::-1] + (3,) 185 | ) 186 | plt.close(fig) 187 | return image_from_plot 188 | 189 | 190 | def antialias(image, level=1): 191 | is_numpy = isinstance(image, np.ndarray) 192 | if is_numpy: 193 | image = Image.fromarray(image) 194 | for _ in range(level): 195 | size = np.array(image.size) // 2 196 | image = image.resize(size, Image.LANCZOS) 197 | if is_numpy: 198 | image = np.array(image) 199 | return image 200 | 201 | 202 | def unnormalize_image(image): 203 | if isinstance(image, torch.Tensor): 204 | image = image.cpu().numpy() 205 | if image.shape[0] == 3: 206 | image = image.transpose(1, 2, 0) 207 | mean = np.array([0.485, 0.456, 0.406]) 208 | std = np.array([0.229, 0.224, 0.225]) 209 | image = image * std + mean 210 | return (image * 255.0).astype(np.uint8) 211 | 212 | 213 | def visualize_so3(save_path, pred_rotations, gt_rotation, pred_rotation=None, probabilities=None, image=None): 214 | if image == None: 215 | fig = plt.figure(figsize=(5, 2), dpi=600) 216 | gs = fig.add_gridspec(1, 2) 217 | ax = fig.add_subplot(gs[0, :], projection="mollweide") 218 | else: 219 | fig = plt.figure(figsize=(5, 4), dpi=600) 220 | gs = fig.add_gridspec(2, 2) 221 | ax = fig.add_subplot(gs[1, :], projection="mollweide") 222 | bx = fig.add_subplot(gs[0, :]) 223 | bx.imshow(image.permute(1, 2, 0).cpu().numpy()) 224 | bx.axis("off") 225 | # rotations = np.concatenate((pred_rotations, pred_rotation, gt_rotation), axis=0) 226 | rotations = pred_rotations 227 | if probabilities is None: 228 | probabilities = np.ones(rotations.shape[0])/2000 229 | # probabilities[-2] = 0.002 230 | # probabilities[-1] = 
0.003 231 | 232 | so3_vis = visualize_so3_probabilities( 233 | rotations_gt=gt_rotation, 234 | choosed_rotation=pred_rotation, 235 | rotations=rotations, 236 | probabilities=probabilities, 237 | to_image=False, 238 | display_threshold_probability=0.00001, 239 | show_color_wheel=True, 240 | fig=fig, 241 | ax=ax, 242 | ) 243 | plt.savefig(save_path) 244 | 245 | 246 | if __name__ == '__main__': 247 | fig = plt.figure(figsize=(4, 4), dpi=300) 248 | gs = fig.add_gridspec(2, 2) 249 | ax1 = fig.add_subplot(gs[0, 0]) 250 | # ax1.imshow(unnormalize_image(images[0].cpu().numpy().transpose(1, 2, 0))) 251 | ax1.axis("off") 252 | ax2 = fig.add_subplot(gs[0, 1]) 253 | # ax2.imshow(unnormalize_image(images[1].cpu().numpy().transpose(1, 2, 0))) 254 | ax2.axis("off") 255 | ax3 = fig.add_subplot(gs[1, :], projection="mollweide") 256 | 257 | rotations = pytorch3d.transforms.random_rotations(10).cpu().numpy() 258 | probabilities = np.ones(10)/1000 259 | print(probabilities) 260 | so3_vis = visualize_so3_probabilities( 261 | rotations=rotations, 262 | probabilities=probabilities, 263 | to_image=False, 264 | display_threshold_probability=0.0005, 265 | show_color_wheel=True, 266 | fig=fig, 267 | ax=ax3, 268 | ) 269 | plt.savefig('./test.jpg') 270 | # plt.show() 271 | 272 | -------------------------------------------------------------------------------- /utils/tracking_utils.py: -------------------------------------------------------------------------------- 1 | ''' modified from CAPTRA https://github.com/HalfSummer11/CAPTRA/tree/5d7d088c3de49389a90b5fae280e96409e7246c6 ''' 2 | 3 | import torch 4 | import copy 5 | import math 6 | from ipdb import set_trace 7 | 8 | def normalize(q): 9 | assert q.shape[-1] == 4 10 | norm = q.norm(dim=-1, keepdim=True) 11 | return q.div(norm) 12 | 13 | 14 | def matrix_to_unit_quaternion(matrix): 15 | assert matrix.shape[-1] == matrix.shape[-2] == 3 16 | if not isinstance(matrix, torch.Tensor): 17 | matrix = torch.tensor(matrix) 18 | 19 | trace = 1 + matrix[..., 0, 0] + matrix[..., 1, 1] + matrix[..., 2, 2] 20 | trace = torch.clamp(trace, min=0.) 
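# descriptive note: for a proper rotation matrix the clamped quantity above equals 1 + tr(R) = 4*w^2, so r = 2|w| below;
# clamping at zero protects the square root from tiny negative values caused by numerical noise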
21 | r = torch.sqrt(trace) 22 | s = 1.0 / (2 * r + 1e-7) 23 | w = 0.5 * r 24 | x = (matrix[..., 2, 1] - matrix[..., 1, 2])*s 25 | y = (matrix[..., 0, 2] - matrix[..., 2, 0])*s 26 | z = (matrix[..., 1, 0] - matrix[..., 0, 1])*s 27 | 28 | q = torch.stack((w, x, y, z), dim=-1) 29 | 30 | return normalize(q) 31 | 32 | 33 | def generate_random_quaternion(quaternion_shape): 34 | assert quaternion_shape[-1] == 4 35 | rand_norm = torch.randn(quaternion_shape) 36 | rand_q = normalize(rand_norm) 37 | return rand_q 38 | 39 | 40 | def jitter_quaternion(q, theta): #[Bs, 4], [Bs, 1] 41 | new_q = generate_random_quaternion(q.shape).to(q.device) 42 | dot_product = torch.sum(q*new_q, dim=-1, keepdim=True) # 43 | shape = (tuple(1 for _ in range(len(dot_product.shape) - 1)) + (4, )) 44 | q_orthogonal = normalize(new_q - q * dot_product.repeat(*shape)) 45 | # theta = 2arccos(|p.dot(q)|) 46 | # |p.dot(q)| = cos(theta/2) 47 | tile_theta = theta.repeat(shape) 48 | jittered_q = q*torch.cos(tile_theta/2) + q_orthogonal*torch.sin(tile_theta/2) 49 | 50 | return jittered_q 51 | 52 | 53 | def assert_normalized(q, atol=1e-3): 54 | assert q.shape[-1] == 4 55 | norm = q.norm(dim=-1) 56 | norm_check = (norm - 1.0).abs() 57 | try: 58 | assert torch.max(norm_check) < atol 59 | except: 60 | print("normalization failure: {}.".format(torch.max(norm_check))) 61 | return -1 62 | return 0 63 | 64 | 65 | def unit_quaternion_to_matrix(q): 66 | assert_normalized(q) 67 | w, x, y, z= torch.unbind(q, dim=-1) 68 | matrix = torch.stack(( 1 - 2*y*y - 2*z*z, 2*x*y - 2*z*w, 2*x*z + 2*y* w, 69 | 2*x*y + 2*z*w, 1 - 2*x*x - 2*z*z, 2*y*z - 2*x*w, 70 | 2*x*z - 2*y*w, 2*y*z + 2*x*w, 1 - 2*x*x -2*y*y), 71 | dim=-1) 72 | matrix_shape = list(matrix.shape)[:-1]+[3,3] 73 | return matrix.view(matrix_shape).contiguous() 74 | 75 | 76 | def noisy_rot_matrix(matrix, rad, type='normal'): 77 | if type == 'normal': 78 | theta = torch.abs(torch.randn_like(matrix[..., 0, 0])) * rad 79 | elif type == 'uniform': 80 | theta = torch.rand_like(matrix[..., 0, 0]) * rad 81 | quater = matrix_to_unit_quaternion(matrix) 82 | new_quater = jitter_quaternion(quater, theta.unsqueeze(-1)) 83 | new_mat = unit_quaternion_to_matrix(new_quater) 84 | return new_mat 85 | 86 | 87 | def add_noise_to_RT(RT, type='normal', r=5.0, t=0.03): 88 | rand_type = type # 'uniform' or 'normal' --> we use 'normal' 89 | 90 | def random_tensor(base): 91 | if rand_type == 'uniform': 92 | return torch.rand_like(base) * 2.0 - 1.0 93 | elif rand_type == 'normal': 94 | return torch.randn_like(base) 95 | new_RT = copy.deepcopy(RT) 96 | new_RT[:, :3, :3] = noisy_rot_matrix(RT[:, :3, :3], r/180*math.pi, type=rand_type).reshape(RT[:, :3, :3].shape) 97 | norm = random_tensor(RT[:, 0, 0]) * t # [B, P] 98 | direction = random_tensor(RT[:, :3, 3].squeeze(-1)) # [B, P, 3] 99 | direction = direction / torch.clamp(direction.norm(dim=-1, keepdim=True), min=1e-9) # [B, P, 3] unit vecs 100 | new_RT[:, :3, 3] = RT[:, :3, 3] + (direction * norm.unsqueeze(-1)) # [B, P, 3, 1] 101 | 102 | return new_RT 103 | 104 | --------------------------------------------------------------------------------
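The snippet below is a minimal usage sketch rather than part of the repository: assuming the repository root is on PYTHONPATH (so that utils.misc and utils.tracking_utils import cleanly) and the dependencies from requirements.txt are installed, it perturbs a batch of identity poses with add_noise_to_RT and averages the perturbed orientations with average_quaternion_batch, illustrating how these helpers fit together. Averaging is done in quaternion space because the per-candidate rotation matrices cannot simply be mean-pooled element-wise.

import torch

from utils.misc import average_quaternion_batch
from utils.tracking_utils import add_noise_to_RT, matrix_to_unit_quaternion

# a batch of 4 identity object poses as 4x4 homogeneous transforms
RT = torch.eye(4).unsqueeze(0).repeat(4, 1, 1)

# perturb each pose with ~5 degrees of rotation noise and ~3 cm of translation noise
noisy_RT = add_noise_to_RT(RT, type='normal', r=5.0, t=0.03)

# convert the perturbed rotations to unit quaternions (w, x, y, z)
quats = matrix_to_unit_quaternion(noisy_RT[:, :3, :3])     # [4, 4]

# treat the 4 perturbed poses as candidates for a single object and average them
avg_quat = average_quaternion_batch(quats.unsqueeze(0))    # [1, 4]
avg_trans = noisy_RT[:, :3, 3].mean(dim=0, keepdim=True)   # [1, 3]
print(avg_quat, avg_trans)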