├── LICENSE
├── README.md
├── assets
│   └── pipeline.png
├── configs
│   └── config.py
├── data
│   └── Real
│       └── train
│           └── mug_handle.pkl
├── datasets
│   └── datasets_genpose.py
├── networks
│   ├── decoder_head
│   │   ├── rot_head.py
│   │   └── trans_head.py
│   ├── gf_algorithms
│   │   ├── energynet.py
│   │   ├── losses.py
│   │   ├── samplers.py
│   │   ├── score_utils.py
│   │   ├── scorenet.py
│   │   └── sde.py
│   ├── posenet.py
│   ├── posenet_agent.py
│   ├── pts_encoder
│   │   ├── pointnet2.py
│   │   ├── pointnet2_utils
│   │   │   ├── .gitignore
│   │   │   ├── LICENSE
│   │   │   ├── README.md
│   │   │   ├── pointnet2
│   │   │   │   ├── pointnet2_modules.py
│   │   │   │   ├── pointnet2_utils.py
│   │   │   │   ├── pytorch_utils.py
│   │   │   │   ├── setup.py
│   │   │   │   └── src
│   │   │   │       ├── ball_query.cpp
│   │   │   │       ├── ball_query_gpu.cu
│   │   │   │       ├── ball_query_gpu.h
│   │   │   │       ├── cuda_utils.h
│   │   │   │       ├── group_points.cpp
│   │   │   │       ├── group_points_gpu.cu
│   │   │   │       ├── group_points_gpu.h
│   │   │   │       ├── interpolate.cpp
│   │   │   │       ├── interpolate_gpu.cu
│   │   │   │       ├── interpolate_gpu.h
│   │   │   │       ├── pointnet2_api.cpp
│   │   │   │       ├── sampling.cpp
│   │   │   │       ├── sampling_gpu.cu
│   │   │   │       └── sampling_gpu.h
│   │   │   └── tools
│   │   │       ├── _init_path.py
│   │   │       ├── data
│   │   │       │   └── KITTI
│   │   │       │       └── ImageSets
│   │   │       │           ├── test.txt
│   │   │       │           ├── train.txt
│   │   │       │           ├── trainval.txt
│   │   │       │           └── val.txt
│   │   │       ├── dataset.py
│   │   │       ├── kitti_utils.py
│   │   │       ├── pointnet2_msg.py
│   │   │       └── train_and_eval.py
│   │   └── pointnets.py
│   └── reward.py
├── requirements.txt
├── runners
│   ├── evaluation_single.py
│   ├── evaluation_tracking.py
│   └── trainer.py
├── scripts
│   ├── eval_single.sh
│   ├── eval_tracking.sh
│   ├── tensorboard.sh
│   ├── train_energy.sh
│   └── train_score.sh
└── utils
    ├── data_augmentation.py
    ├── datasets_utils.py
    ├── genpose_utils.py
    ├── metrics.py
    ├── misc.py
    ├── sgpa_utils.py
    ├── so3_visualize.py
    ├── tracking_utils.py
    └── visualize.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Jiyao Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GenPose: Generative Category-level Object Pose Estimation via Diffusion Models
2 |
3 |
4 | [Website](https://sites.google.com/view/genpose)
5 | [Paper](https://arxiv.org/pdf/2306.10531.pdf)
6 | [Hits](https://hits.seeyoufarm.com)
7 | [License](https://github.com/Jiyao06/GenPose/blob/main/LICENSE)
8 | [PWC: 6D Pose Estimation using RGBD on REAL275](https://paperswithcode.com/sota/6d-pose-estimation-using-rgbd-on-real275?p=genpose-generative-category-level-object-pose)
9 |
10 | The official PyTorch implementation of the NeurIPS 2023 paper, [GenPose](https://arxiv.org/pdf/2306.10531.pdf).
11 |
12 |
13 | ## News
14 | - **2024/08**: [Omni6DPoseAPI](https://github.com/Omni6DPose/Omni6DPoseAPI) and [genpose++](https://github.com/Omni6DPose/GenPose2) are released!
15 | - **2024/07**: [Omni6DPose](https://jiyao06.github.io/Omni6DPose/), the largest and most diverse universal 6D object pose estimation benchmark, gets accepted to ECCV 2024.
16 | - **2023/09**: GenPose gets accepted to NeurIPS 2023.
17 |
18 | ## Overview
19 |
20 | ![pipeline](assets/pipeline.png)
21 |
22 | **(I)** A score-based diffusion model and an energy-based diffusion model are trained via denoising score matching.
23 | **(II)** a) We first generate pose candidates with the score-based model and then b) compute the pose energy of each candidate via the energy-based model.
24 | c) Finally, we rank the candidates by their energies and filter out the low-ranking candidates.
25 | The remaining candidates are aggregated into the final output by mean-pooling.
26 |
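In code, step (II) can be sketched roughly as follows. This snippet is illustrative only: the number of candidates, the pose dimensionality (two rotation axes plus a translation), and the 60% keep ratio are placeholders rather than the settings used in the released code (see *runners/* and *networks/* for the actual implementation).

``` python
import torch

K, pose_dim = 50, 9                                   # placeholders: K candidates, 6-D rotation + 3-D translation
candidates = torch.randn(K, pose_dim)                 # a) candidates sampled from the score-based model
energies = torch.randn(K)                             # b) per-candidate energies from the energy-based model

keep = int(0.6 * K)                                   # c) keep the highest-ranked candidates, drop the rest
top_idx = energies.argsort(descending=True)[:keep]    #    rank candidates by their energies
final_pose = candidates[top_idx].mean(dim=0)          #    aggregate the survivors by mean-pooling
# (a real implementation would also re-normalize the averaged rotation components)
```
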
27 | Contents of this repo are as follows:
28 |
29 | * [Overview](#overview)
30 | * [UPDATE](#update)
31 | * [Requirements](#requirements)
32 | * [Installation](#installation)
33 | + [Install pytorch](#install-pytorch)
34 | + [Install pytorch3d from a local clone](#install-pytorch3d-from-a-local-clone)
35 | + [Install from requirements.txt](#install-from-requirementstxt)
36 | + [Compile pointnet2](#compile-pointnet2)
37 | * [Download dataset and models](#download-dataset-and-models)
38 | * [Training](#training)
39 | + [Score network](#score-network)
40 | + [Energy network](#energy-network)
41 | * [Evaluation](#evaluation)
42 | + [Evaluate on REAL275 dataset.](#evaluate-on-real275-dataset)
43 | + [Evaluate on CAMERA dataset.](#evaluate-on-camera-dataset)
44 | * [Citation](#citation)
45 | * [Contact](#contact)
46 | * [License](#license)
47 |
48 | ## UPDATE
49 | - Release the code for object pose tracking.
50 | - Release the preprocessed data for object pose tracking.
51 |
52 | ## Requirements
53 | - Ubuntu 20.04
54 | - Python 3.8.15
55 | - Pytorch 1.12.0
56 | - Pytorch3d 0.7.2
57 | - CUDA 11.3
58 | - 1 * NVIDIA RTX 3090
59 |
60 | ## Installation
61 |
62 | - ### Install pytorch
63 | ``` bash
64 | pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
65 | ```
66 |
67 |
68 | - ### Install pytorch3d from a local clone
69 | ``` bash
70 | git clone https://github.com/facebookresearch/pytorch3d.git
71 | cd pytorch3d
72 | git checkout -f v0.7.2
73 | pip install -e .
74 | ```
75 |
76 | - ### Install from requirements.txt
77 | ``` bash
78 | pip install -r requirements.txt
79 | ```
80 |
81 | - ### Compile pointnet2
82 | ``` bash
83 | cd networks/pts_encoder/pointnet2_utils/pointnet2
84 | python setup.py install
85 | ```
86 |
87 | ## Download dataset and models
88 | - Download camera_train, camera_val, real_train, real_test, the ground-truth annotations, and the mesh models provided by NOCS, and unzip the data. Then move the file "mug_handle.pkl" from this repository's "data/Real/train" folder to the corresponding unzipped folder. The file "mug_handle.pkl" is provided by [GPV-Pose](https://github.com/lolrudy/GPV_Pose/blob/master/mug_handle.pkl). Organize these files in $ROOT/data as follows:
89 | ``` bash
90 | data
91 | ├── CAMERA
92 | │ ├── train
93 | │ └── val
94 | ├── Real
95 | │ ├── train
96 | │ │ ├── mug_handle.pkl
97 | │ │ └── ...
98 | │ └── test
99 | ├── gts
100 | │ ├── val
101 | │ └── real_test
102 | └── obj_models
103 | ├── train
104 | ├── val
105 | ├── real_train
106 | └── real_test
107 | ```
108 |
109 | - Preprocess NOCS files following SPD.
110 |
111 | We provide the preprocessed testing data (REAL275) and checkpoints here for a quick evaluation. Download and organize the files in $ROOT/results as follows:
112 | ``` bash
113 | results
114 | ├── ckpts
115 | │ ├── EnergyNet
116 | │ │ └── ckpt_genpose.pth
117 | │ └── ScoreNet
118 | │ └── ckpt_genpose.pth
119 | ├── evaluation_results
120 | │ ├── segmentation_logs_real_test.txt
121 | │ └── segmentation_results_real_test.pkl
122 | └── mrcnn_results
123 | ├── aligned_real_test
124 | ├── real_test
125 | └── val
126 | ```
127 | The *ckpts* folder contains the trained GenPose models.
128 |
129 | The *evaluation_results* folder contains the preprocessed testing data, including the Mask R-CNN segmentation results, the segmented point clouds of the objects, and the ground-truth poses.
130 |
131 | The *mrcnn_results* folder contains the segmentation results provided by SPD, which you can also find here. Note that *mrcnn_results/aligned_real_test* contains the manually aligned segmentation results used for object pose tracking.
132 |
133 | **Note**: If you want to evaluate on the CAMERA dataset, you need to preprocess the dataset as described above first.
134 |
135 | ## Training
136 | Set the parameter *--data_path* in *scripts/train_score.sh* and *scripts/train_energy.sh* to your own path to the NOCS dataset.
137 |
138 | - ### Score network
139 | Train the score network to generate the pose candidates.
140 | ``` bash
141 | bash scripts/train_score.sh
142 | ```
143 | - ### Energy network
144 | Train the energy network to aggregate the pose candidates.
145 | ``` bash
146 | bash scripts/train_energy.sh
147 | ```
148 |
149 | ## Evaluation
150 | Set the parameter *--data_path* in *scripts/eval_single.sh* and *scripts/eval_tracking.sh* to your own path to the NOCS dataset.
151 |
152 | - ### Evaluate on REAL275 dataset.
153 | Set the parameter *--test_source* in *scripts/eval_single.sh* to *'real_test'* and run:
154 | ``` bash
155 | bash scripts/eval_single.sh
156 | ```
157 | - ### Evaluate on CAMERA dataset.
158 | Set the parameter *--test_source* in *scripts/eval_single.sh* to *'val'* and run:
159 | ``` bash
160 | bash scripts/eval_single.sh
161 | ```
162 | - ### Pose tracking on REAL275 dataset.
163 | ``` bash
164 | bash scripts/eval_tracking.sh
165 | ```
166 |
167 | ## Citation
168 | If you find our work useful in your research, please consider citing:
169 | ``` bibtex
170 | @article{zhang2024generative,
171 | title={Generative category-level object pose estimation via diffusion models},
172 | author={Zhang, Jiyao and Wu, Mingdong and Dong, Hao},
173 | journal={Advances in Neural Information Processing Systems},
174 | volume={36},
175 | year={2024}
176 | }
177 | ```
178 |
179 | ## Contact
180 | If you have any questions, please feel free to contact us:
181 |
182 | [Jiyao Zhang](https://jiyao06.github.io/): [jiyaozhang@stu.pku.edu.cn](mailto:jiyaozhang@stu.pku.edu.cn)
183 |
184 | [Mingdong Wu](https://aaronanima.github.io/): [wmingd@pku.edu.cn](mailto:wmingd@pku.edu.cn)
185 |
186 | [Hao Dong](https://zsdonghao.github.io/): [hao.dong@pku.edu.cn](mailto:hao.dong@pku.edu.cn)
187 |
188 | ## License
189 | This project is released under the MIT license. See [LICENSE](LICENSE) for additional details.
190 |
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiyao06/GenPose/3a374e9a1a1dddc825b7d507ebcf38f2e3510aec/assets/pipeline.png
--------------------------------------------------------------------------------
/configs/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from ipdb import set_trace
3 |
4 | def get_config():
5 | parser = argparse.ArgumentParser()
6 |
7 | """ dataset """
8 | parser.add_argument('--synset_names', nargs='+', default=['bottle', 'bowl', 'camera', 'can', 'laptop', 'mug'])
9 | parser.add_argument('--selected_classes', nargs='+')
10 | parser.add_argument('--data_path', type=str)
11 | parser.add_argument('--o2c_pose', default=True, action='store_true')
12 | parser.add_argument('--batch_size', type=int, default=192)
13 | parser.add_argument('--max_batch_size', type=int, default=192)
14 | parser.add_argument('--mini_bs', type=int, default=192)
15 | parser.add_argument('--pose_mode', type=str, default='rot_matrix')
16 | parser.add_argument('--seed', type=int, default=0)
17 | parser.add_argument('--percentage_data_for_train', type=float, default=1.0)
18 | parser.add_argument('--percentage_data_for_val', type=float, default=1.0)
19 | parser.add_argument('--percentage_data_for_test', type=float, default=1.0)
20 | parser.add_argument('--train_source', type=str, default='CAMERA+Real')
21 | parser.add_argument('--val_source', type=str, default='CAMERA')
22 | parser.add_argument('--test_source', type=str, default='Real')
23 | parser.add_argument('--device', type=str, default='cuda')
24 | parser.add_argument('--num_points', type=int, default=1024)
25 | parser.add_argument('--per_obj', type=str, default='')
26 | parser.add_argument('--num_workers', type=int, default=32)
27 |
28 |
29 | """ model """
30 | parser.add_argument('--posenet_mode', type=str, default='score')
31 | parser.add_argument('--hidden_dim', type=int, default=128)
32 | parser.add_argument('--sampler_mode', nargs='+')
33 | parser.add_argument('--sampling_steps', type=int)
34 | parser.add_argument('--sde_mode', type=str, default='ve')
35 | parser.add_argument('--sigma', type=float, default=25) # base-sigma for SDE
36 | parser.add_argument('--likelihood_weighting', default=False, action='store_true')
37 | parser.add_argument('--regression_head', type=str, default='Rx_Ry_and_T')
38 | parser.add_argument('--pointnet2_params', type=str, default='light')
39 | parser.add_argument('--pts_encoder', type=str, default='pointnet2')
40 | parser.add_argument('--energy_mode', type=str, default='IP')
41 | parser.add_argument('--s_theta_mode', type=str, default='score')
42 | parser.add_argument('--norm_energy', type=str, default='identical')
43 |
44 |
45 | """ training """
46 | parser.add_argument('--agent_type', type=str, default='score', help='one of the [score, energy, energy_with_ranking]')
47 | parser.add_argument('--pretrained_score_model_path', type=str)
48 | parser.add_argument('--pretrained_energy_model_path', type=str)
49 | parser.add_argument('--distillation', default=False, action='store_true')
50 | parser.add_argument('--n_epochs', type=int, default=1000)
51 | parser.add_argument('--log_dir', type=str, default='debug')
52 | parser.add_argument('--optimizer', type=str, default='Adam')
53 | parser.add_argument('--eval_freq', type=int, default=100)
54 | parser.add_argument('--repeat_num', type=int, default=20)
55 | parser.add_argument('--grad_clip', type=float, default=1.)
56 | parser.add_argument('--ema_rate', type=float, default=0.999)
57 | parser.add_argument('--lr', type=float, default=1e-3)
58 | parser.add_argument('--warmup', type=int, default=100)
59 | parser.add_argument('--lr_decay', type=float, default=0.98)
60 | parser.add_argument('--use_pretrain', default=False, action='store_true')
61 | parser.add_argument('--parallel', default=False, action='store_true')
62 | parser.add_argument('--num_gpu', type=int, default=4)
63 | parser.add_argument('--is_train', default=False, action='store_true')
64 |
65 |
66 | """ testing """
67 | parser.add_argument('--eval', default=False, action='store_true')
68 | parser.add_argument('--pred', default=False, action='store_true')
69 | parser.add_argument('--model_name', type=str)
70 | parser.add_argument('--eval_repeat_num', type=int, default=50)
71 | parser.add_argument('--save_video', default=False, action='store_true')
72 | parser.add_argument('--max_eval_num', type=int, default=10000000)
73 | parser.add_argument('--results_path', type=str, default='')
74 | parser.add_argument('--T0', type=float, default=1.0)
75 |
76 |
77 | """ nocs_mrcnn testing"""
78 | parser.add_argument('--img_size', type=int, default=256, help='cropped image size')
79 | parser.add_argument('--result_dir', type=str, default='', help='result directory')
80 | parser.add_argument('--model_dir_list', nargs='+')
81 | parser.add_argument('--energy_model_dir', type=str, default='', help='energy network ckpt directory')
82 | parser.add_argument('--score_model_dir', type=str, default='', help='score network ckpt directory')
83 | parser.add_argument('--ranker', type=str, default='energy_ranker', help='energy_ranker, gt_ranker or random')
84 | parser.add_argument('--pooling_mode', type=str, default='nearest', help='nearest or average')
85 |
86 |
87 | cfg = parser.parse_args()
88 |
89 | # dynamic zoom in parameters
90 | cfg.DYNAMIC_ZOOM_IN_PARAMS = {
91 | 'DZI_PAD_SCALE': 1.5,
92 | 'DZI_TYPE': 'uniform',
93 | 'DZI_SCALE_RATIO': 0.25,
94 | 'DZI_SHIFT_RATIO': 0.25
95 | }
96 |
97 | # pts aug parameters
98 | cfg.PTS_AUG_PARAMS = {
99 | 'aug_pc_pro': 0.2,
100 | 'aug_pc_r': 0.2,
101 | 'aug_rt_pro': 0.3,
102 | 'aug_bb_pro': 0.3,
103 | 'aug_bc_pro': 0.3
104 | }
105 |
106 | # 2D aug parameters
107 | cfg.DEFORM_2D_PARAMS = {
108 | 'roi_mask_r': 3,
109 | 'roi_mask_pro': 0.5
110 | }
111 |
112 | return cfg
113 |
114 |
--------------------------------------------------------------------------------
/data/Real/train/mug_handle.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiyao06/GenPose/3a374e9a1a1dddc825b7d507ebcf38f2e3510aec/data/Real/train/mug_handle.pkl
--------------------------------------------------------------------------------
/networks/decoder_head/rot_head.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from ipdb import set_trace
5 |
6 |
7 | class RotHead(nn.Module):
8 | def __init__(self, in_feat_dim, out_dim=3):
9 | super(RotHead, self).__init__()
10 | self.f = in_feat_dim
11 | self.k = out_dim
12 |
13 | self.conv1 = torch.nn.Conv1d(self.f, 1024, 1)
14 | self.conv2 = torch.nn.Conv1d(1024, 256, 1)
15 | self.conv3 = torch.nn.Conv1d(256, 256, 1)
16 | self.conv4 = torch.nn.Conv1d(256, self.k, 1)
17 | self.drop1 = nn.Dropout(0.2)
18 | self.bn1 = nn.BatchNorm1d(1024)
19 | self.bn2 = nn.BatchNorm1d(256)
20 | self.bn3 = nn.BatchNorm1d(256)
21 |
22 | def forward(self, x):
23 | x = F.relu(self.bn1(self.conv1(x)))
24 | x = F.relu(self.bn2(self.conv2(x)))
25 |
26 | x = torch.max(x, 2, keepdim=True)[0]
27 |
28 | x = F.relu(self.bn3(self.conv3(x)))
29 | x = self.drop1(x)
30 | x = self.conv4(x)
31 |
32 | x = x.squeeze(2)
33 | x = x.contiguous()
34 |
35 | return x
36 |
37 |
38 | def main():
39 | points = torch.rand(2, 1350, 1024) # batchsize x feature x numofpoint
40 | rot_head = RotHead(in_feat_dim=1350, out_dim=3)
41 | rot = rot_head(points)
42 | print(rot.shape)
43 |
44 |
45 | if __name__ == "__main__":
46 | main()
--------------------------------------------------------------------------------
/networks/decoder_head/trans_head.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from ipdb import set_trace
5 |
6 | # TransHead encodes the segmented point cloud and regresses the translation (point center).
7 | # It uses one more conv layer than the original paper.
8 |
9 | class TransHead(nn.Module):
10 | def __init__(self, in_feat_dim, out_dim=3):
11 | super(TransHead, self).__init__()
12 | self.f = in_feat_dim
13 | self.k = out_dim
14 |
15 | self.conv1 = torch.nn.Conv1d(self.f, 1024, 1)
16 |
17 | self.conv2 = torch.nn.Conv1d(1024, 256, 1)
18 | self.conv3 = torch.nn.Conv1d(256, 256, 1)
19 | self.conv4 = torch.nn.Conv1d(256, self.k, 1)
20 | self.drop1 = nn.Dropout(0.2)
21 | self.bn1 = nn.BatchNorm1d(1024)
22 | self.bn2 = nn.BatchNorm1d(256)
23 | self.bn3 = nn.BatchNorm1d(256)
24 | self.relu1 = nn.ReLU()
25 | self.relu2 = nn.ReLU()
26 | self.relu3 = nn.ReLU()
27 |
28 | def forward(self, x):
29 | x = self.relu1(self.bn1(self.conv1(x)))
30 | x = self.relu2(self.bn2(self.conv2(x)))
31 |
32 | x = torch.max(x, 2, keepdim=True)[0]
33 |
34 | x = self.relu3(self.bn3(self.conv3(x)))
35 | x = self.drop1(x)
36 | x = self.conv4(x)
37 |
38 | x = x.squeeze(2)
39 | x = x.contiguous()
40 | return x
41 |
42 |
43 | def main():
44 | feature = torch.rand(10, 1896, 1000)
45 | net = TransHead(in_feat_dim=1896, out_dim=3)
46 | out = net(feature)
47 | print(out.shape)
48 |
49 | if __name__ == "__main__":
50 | main()
--------------------------------------------------------------------------------
/networks/gf_algorithms/energynet.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 |
5 | from ipdb import set_trace
6 | from torch.autograd import Variable
7 | from networks.gf_algorithms.samplers import cond_ode_likelihood, cond_ode_sampler, cond_pc_sampler
8 | from networks.gf_algorithms.scorenet import GaussianFourierProjection
9 | from networks.pts_encoder.pointnet2 import Pointnet2ClsMSG
10 | from networks.pts_encoder.pointnets import PointNetfeat
11 | from utils.genpose_utils import get_pose_dim
12 |
13 |
14 | def zero_module(module):
15 | """
16 | Zero out the parameters of a module and return it.
17 | """
18 | for p in module.parameters():
19 | p.detach().zero_()
20 | return module
21 |
22 |
23 | class TemporaryGrad(object):
24 | def __enter__(self):
25 | self.prev = torch.is_grad_enabled()
26 | torch.set_grad_enabled(True)
27 |
28 | def __exit__(self, exc_type, exc_value, traceback) -> None:
29 | torch.set_grad_enabled(self.prev)
30 |
31 |
32 | class PoseEnergyNet(nn.Module):
33 | def __init__(
34 | self,
35 | marginal_prob_func,
36 | pose_mode='quat_wxyz',
37 | regression_head='RT',
38 | device='cuda',
39 | energy_mode='L2', # ['DAE', 'L2', 'IP']
40 |         s_theta_mode='score', # ['score', 'decoder', 'identical']
41 | norm_energy='identical' # ['identical', 'std', 'minus']
42 | ):
43 | super(PoseEnergyNet, self).__init__()
44 | self.device = device
45 | self.regression_head = regression_head
46 | self.act = nn.ReLU(True)
47 | self.pose_dim = get_pose_dim(pose_mode)
48 | self.energy_mode = energy_mode
49 | self.s_theta_mode = s_theta_mode
50 | self.norm_energy = norm_energy
51 |
52 | ''' encode pose '''
53 | self.pose_encoder = nn.Sequential(
54 | nn.Linear(self.pose_dim, 256),
55 | self.act,
56 | nn.Linear(256, 256),
57 | self.act,
58 | )
59 |
60 | ''' encode t '''
61 | self.t_encoder = nn.Sequential(
62 | GaussianFourierProjection(embed_dim=128),
63 | # self.act,
64 | nn.Linear(128, 128),
65 | self.act,
66 | )
67 |
68 | ''' fusion tail '''
69 | if self.regression_head == 'RT':
70 | self.fusion_tail = nn.Sequential(
71 | nn.Linear(128+256+1024, 512),
72 | self.act,
73 | zero_module(nn.Linear(512, self.pose_dim)),
74 | )
75 |
76 |
77 | elif self.regression_head == 'R_and_T':
78 | ''' rotation regress head '''
79 | self.fusion_tail_rot = nn.Sequential(
80 | nn.Linear(128+256+1024, 256),
81 | # nn.BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
82 | self.act,
83 | zero_module(nn.Linear(256, self.pose_dim-3)),
84 | )
85 |
86 |             ''' translation regression head '''
87 | self.fusion_tail_trans = nn.Sequential(
88 | nn.Linear(128+256+1024, 256),
89 | # nn.BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
90 | self.act,
91 | zero_module(nn.Linear(256, 3)),
92 | )
93 |
94 |
95 | elif self.regression_head == 'Rx_Ry_and_T':
96 | if pose_mode != 'rot_matrix':
97 | raise NotImplementedError
98 | ''' rotation_x_axis regress head '''
99 | self.fusion_tail_rot_x = nn.Sequential(
100 | nn.Linear(128+256+1024, 256),
101 | # nn.BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
102 | self.act,
103 | zero_module(nn.Linear(256, 3)),
104 | )
105 | self.fusion_tail_rot_y = nn.Sequential(
106 | nn.Linear(128+256+1024, 256),
107 | # nn.BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
108 | self.act,
109 | zero_module(nn.Linear(256, 3)),
110 | )
111 |
112 |             ''' translation regression head '''
113 | self.fusion_tail_trans = nn.Sequential(
114 | nn.Linear(128+256+1024, 256),
115 | # nn.BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
116 | self.act,
117 | zero_module(nn.Linear(256, 3)),
118 | )
119 |
120 | else:
121 | raise NotImplementedError
122 |
123 |
124 | self.marginal_prob_func = marginal_prob_func
125 |
126 |
127 | def output_zero_initial(self):
128 | if self.regression_head == 'RT':
129 | zero_module(self.fusion_tail[-1])
130 |
131 | elif self.regression_head == 'R_and_T':
132 | zero_module(self.fusion_tail_rot[-1])
133 | zero_module(self.fusion_tail_trans[-1])
134 |
135 | elif self.regression_head == 'Rx_Ry_and_T':
136 | zero_module(self.fusion_tail_rot_x[-1])
137 | zero_module(self.fusion_tail_rot_y[-1])
138 | zero_module(self.fusion_tail_trans[-1])
139 | else:
140 | raise NotImplementedError
141 |
142 |
143 | def get_energy(self, pts_feat, sampled_pose, t, decoupled_rt=True):
144 | t_feat = self.t_encoder(t.squeeze(1))
145 | pose_feat = self.pose_encoder(sampled_pose)
146 |
147 | total_feat = torch.cat([pts_feat, t_feat, pose_feat], dim=-1)
148 | _, std = self.marginal_prob_func(total_feat, t)
149 |
150 | ''' get f_{theta} '''
151 | if self.regression_head == 'RT':
152 | f_theta = self.fusion_tail(total_feat)
153 | elif self.regression_head == 'R_and_T':
154 | rot = self.fusion_tail_rot(total_feat)
155 | trans = self.fusion_tail_trans(total_feat)
156 | f_theta = torch.cat([rot, trans], dim=-1)
157 | elif self.regression_head == 'Rx_Ry_and_T':
158 | rot_x = self.fusion_tail_rot_x(total_feat)
159 | rot_y = self.fusion_tail_rot_y(total_feat)
160 | trans = self.fusion_tail_trans(total_feat)
161 | f_theta = torch.cat([rot_x, rot_y, trans], dim=-1)
162 | else:
163 | raise NotImplementedError
164 |
165 | ''' get s_{theta} '''
166 | if self.s_theta_mode == 'score':
167 | s_theta = f_theta / std
168 | elif self.s_theta_mode == 'decoder':
169 | s_theta = sampled_pose - std * f_theta
170 | elif self.s_theta_mode == 'identical':
171 | s_theta = f_theta
172 | else:
173 | raise NotImplementedError
174 |
175 | ''' get energy '''
176 | if self.energy_mode == 'DAE':
177 | energy = - 0.5 * torch.sum((sampled_pose - s_theta) ** 2, dim=-1)
178 | elif self.energy_mode == 'L2':
179 | energy = - 0.5 * torch.sum(s_theta ** 2, dim=-1)
180 | elif self.energy_mode == 'IP': # Inner Product
181 | energy = torch.sum(sampled_pose * s_theta, dim=-1)
182 | if decoupled_rt:
183 | energy_rot = torch.sum(sampled_pose[:, :-3] * s_theta[:, :-3], dim=-1)
184 | energy_trans = torch.sum(sampled_pose[:, -3:] * s_theta[:, -3:], dim=-1)
185 | energy = torch.cat((energy_rot.unsqueeze(-1), energy_trans.unsqueeze(-1)), dim=-1)
186 | else:
187 | raise NotImplementedError
188 |
189 | ''' normalisation '''
190 | if self.norm_energy == 'identical':
191 | pass
192 | elif self.norm_energy == 'std':
193 | energy = energy / (std + 1e-7)
194 |         elif self.norm_energy == 'minus': # negate the energy
195 | energy = - energy
196 | else:
197 | raise NotImplementedError
198 | return energy
199 |
200 |
201 | def forward(self, data, return_item='score'):
202 | pts_feat = data['pts_feat']
203 | sampled_pose = data['sampled_pose']
204 | t = data['t']
205 |
206 | if return_item == 'energy':
207 | energy = self.get_energy(pts_feat, sampled_pose, t)
208 | return energy
209 |
210 | with TemporaryGrad():
211 | inp_variable_sampled_pose = Variable(sampled_pose, requires_grad=True)
212 | energy = self.get_energy(pts_feat, inp_variable_sampled_pose, t, decoupled_rt=False)
213 | scores, = torch.autograd.grad(energy, inp_variable_sampled_pose,
214 | grad_outputs=energy.data.new(energy.shape).fill_(1),
215 | create_graph=True)
216 | # inp_variable_sampled_pose = None # release the variable
217 | if return_item == 'score':
218 | return scores
219 | elif return_item == 'score_and_energy':
220 | return scores, energy
221 | else:
222 | raise NotImplementedError
223 |
224 |
225 |
--------------------------------------------------------------------------------
/networks/gf_algorithms/losses.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from ipdb import set_trace
4 |
5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6 |
7 |
8 | def loss_fn_edm(
9 | model,
10 | data,
11 | marginal_prob_func,
12 | sde_fn,
13 | eps=1e-5,
14 | likelihood_weighting=False,
15 | P_mean=-1.2,
16 | P_std=1.2,
17 | sigma_data=1.4148,
18 | sigma_min=0.002,
19 | sigma_max=80,
20 | ):
21 | pts = data['zero_mean_pts']
22 | y = data['zero_mean_gt_pose']
23 | bs = pts.shape[0]
24 |
25 | # get noise n
26 | z = torch.randn_like(y) # [bs, pose_dim]
27 | # log_sigma_t = torch.randn([bs, 1], device=device) # [bs, 1]
28 | # sigma_t = (P_std * log_sigma_t + P_mean).exp() # [bs, 1]
29 | log_sigma_t = torch.rand([bs, 1], device=device) # [bs, 1]
30 | sigma_t = (math.log(sigma_min) + log_sigma_t * (math.log(sigma_max) - math.log(sigma_min))).exp() # [bs, 1]
31 |
32 | n = z * sigma_t
33 |
34 | perturbed_x = y + n # [bs, pose_dim]
35 | data['sampled_pose'] = perturbed_x
36 |     data['t'] = sigma_t # t and sigma are interchangeable in EDM
37 | data, output = model(data) # [bs, pose_dim]
38 |
39 | # set_trace()
40 |
41 | # same as VE
42 | loss_ = torch.mean(torch.sum(((output * sigma_t + z)**2).view(bs, -1), dim=-1))
43 |
44 | return loss_
45 |
46 |
47 | def loss_fn(
48 | model,
49 | data,
50 | marginal_prob_func,
51 | sde_fn,
52 | eps=1e-5,
53 | likelihood_weighting=False,
54 | teacher_model=None,
55 | pts_feat_teacher=None
56 | ):
57 | pts = data['zero_mean_pts']
58 | gt_pose = data['zero_mean_gt_pose']
59 |
60 | ''' get std '''
61 | bs = pts.shape[0]
62 | random_t = torch.rand(bs, device=device) * (1. - eps) + eps # [bs, ]
63 | random_t = random_t.unsqueeze(-1) # [bs, 1]
64 | mu, std = marginal_prob_func(gt_pose, random_t) # [bs, pose_dim], [bs]
65 | std = std.view(-1, 1) # [bs, 1]
66 |
67 | ''' perturb data and get estimated score '''
68 | z = torch.randn_like(gt_pose) # [bs, pose_dim]
69 | perturbed_x = mu + z * std # [bs, pose_dim]
70 | data['sampled_pose'] = perturbed_x
71 | data['t'] = random_t
72 | estimated_score = model(data) # [bs, pose_dim]
73 |
74 | ''' get target score '''
75 | if teacher_model is None:
76 | # theoretic estimation
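        # For a perturbation x_tilde = mu + std * z, the conditional score is
        #   grad_{x_tilde} log N(x_tilde; mu, std^2 I) = -(x_tilde - mu) / std^2 = -z / std,
        # which is what the expression below evaluates to.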
77 | target_score = - z * std / (std ** 2)
78 | else:
79 | # distillation
80 | pts_feat_student = data['pts_feat'].clone()
81 | data['pts_feat'] = pts_feat_teacher
82 | target_score = teacher_model(data)
83 | data['pts_feat'] = pts_feat_student
84 |
85 | ''' loss weighting '''
86 | loss_weighting = std ** 2
87 | loss_ = torch.mean(torch.sum((loss_weighting * (estimated_score - target_score)**2).view(bs, -1), dim=-1))
88 |
89 | return loss_
90 |
91 |
92 |
--------------------------------------------------------------------------------
/networks/gf_algorithms/samplers.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import torch
4 | import numpy as np
5 |
6 | from scipy import integrate
7 | from ipdb import set_trace
8 | sys.path.append(os.path.dirname(os.path.dirname(__file__)))
9 | from utils.genpose_utils import get_pose_dim
10 | from utils.misc import normalize_rotation
11 |
12 |
13 | def global_prior_likelihood(z, sigma_max):
14 |     """The log-likelihood of an isotropic Gaussian distribution with mean zero and
15 |     standard deviation sigma_max."""
16 | # z: [bs, pose_dim]
17 | shape = z.shape
18 | N = np.prod(shape[1:]) # pose_dim
19 | return -N / 2. * torch.log(2*np.pi*sigma_max**2) - torch.sum(z**2, dim=-1) / (2 * sigma_max**2)
20 |
21 |
22 | def cond_ode_likelihood(
23 | score_model,
24 | data,
25 | prior,
26 | sde_coeff,
27 | marginal_prob_fn,
28 | atol=1e-5,
29 | rtol=1e-5,
30 | device='cuda',
31 | eps=1e-5,
32 | num_steps=None,
33 | pose_mode='quat_wxyz',
34 | init_x=None,
35 | ):
36 |
37 | pose_dim = get_pose_dim(pose_mode)
38 | batch_size = data['pts'].shape[0]
39 | epsilon = prior((batch_size, pose_dim)).to(device)
40 | init_x = data['sampled_pose'].clone().cpu().numpy() if init_x is None else init_x
41 | shape = init_x.shape
42 | init_logp = np.zeros((shape[0],)) # [bs]
43 | init_inp = np.concatenate([init_x.reshape(-1), init_logp], axis=0)
44 |
45 | def score_eval_wrapper(data):
46 | """A wrapper of the score-based model for use by the ODE solver."""
47 | with torch.no_grad():
48 | score = score_model(data)
49 | return score.cpu().numpy().reshape((-1,))
50 |
51 | def divergence_eval(data, epsilon):
52 | """Compute the divergence of the score-based model with Skilling-Hutchinson."""
53 | # save ckpt of sampled_pose
54 | origin_sampled_pose = data['sampled_pose'].clone()
55 | with torch.enable_grad():
56 | # make sampled_pose differentiable
57 | data['sampled_pose'].requires_grad_(True)
58 | score = score_model(data)
59 | score_energy = torch.sum(score * epsilon) # [, ]
60 | grad_score_energy = torch.autograd.grad(score_energy, data['sampled_pose'])[0] # [bs, pose_dim]
61 | # reset sampled_pose
62 | data['sampled_pose'] = origin_sampled_pose
63 | return torch.sum(grad_score_energy * epsilon, dim=-1) # [bs, 1]
64 |
65 | def divergence_eval_wrapper(data):
66 | """A wrapper for evaluating the divergence of score for the black-box ODE solver."""
67 | with torch.no_grad():
68 | # Compute likelihood.
69 | div = divergence_eval(data, epsilon) # [bs, 1]
70 | return div.cpu().numpy().reshape((-1,)).astype(np.float64)
71 |
72 | def ode_func(t, inp):
73 | """The ODE function for use by the ODE solver."""
74 | # split x, logp from inp
75 | x = inp[:-shape[0]]
76 |         logp = inp[-shape[0]:]  # logp is carried along by the solver, but it is not needed to compute the gradients
77 | # calc x-grad
78 | x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device)
79 | time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
80 | drift, diffusion = sde_coeff(torch.tensor(t))
81 | drift = drift.cpu().numpy()
82 | diffusion = diffusion.cpu().numpy()
83 | data['sampled_pose'] = x
84 | data['t'] = time_steps
85 | x_grad = drift - 0.5 * (diffusion**2) * score_eval_wrapper(data)
86 | # calc logp-grad
87 | logp_grad = drift - 0.5 * (diffusion**2) * divergence_eval_wrapper(data)
88 | # concat curr grad
89 | return np.concatenate([x_grad, logp_grad], axis=0)
90 |
91 |     # Run the black-box ODE solver
92 | res = integrate.solve_ivp(ode_func, (eps, 1.0), init_inp, rtol=rtol, atol=atol, method='RK45')
93 | zp = torch.tensor(res.y[:, -1], device=device) # [bs * (pose_dim + 1)]
94 | z = zp[:-shape[0]].reshape(shape) # [bs, pose_dim]
95 | delta_logp = zp[-shape[0]:].reshape(shape[0]) # [bs,] logp
96 | _, sigma_max = marginal_prob_fn(None, torch.tensor(1.).to(device)) # we assume T = 1
97 | prior_logp = global_prior_likelihood(z, sigma_max)
98 |     log_likelihoods = (prior_logp + delta_logp) / np.log(2)  # log-likelihoods in bits (nats / ln 2)
99 | return z, log_likelihoods
100 |
101 |
102 | def cond_pc_sampler(
103 | score_model,
104 | data,
105 | prior,
106 | sde_coeff,
107 | num_steps=500,
108 | snr=0.16,
109 | device='cuda',
110 | eps=1e-5,
111 | pose_mode='quat_wxyz',
112 | init_x=None,
113 | ):
114 |
115 | pose_dim = get_pose_dim(pose_mode)
116 | batch_size = data['pts'].shape[0]
117 | init_x = prior((batch_size, pose_dim)).to(device) if init_x is None else init_x
118 | time_steps = torch.linspace(1., eps, num_steps, device=device)
119 | step_size = time_steps[0] - time_steps[1]
120 | noise_norm = np.sqrt(pose_dim)
121 | x = init_x
122 | poses = []
123 | with torch.no_grad():
124 | for time_step in time_steps:
125 | batch_time_step = torch.ones(batch_size, device=device).unsqueeze(-1) * time_step
126 | # Corrector step (Langevin MCMC)
127 | data['sampled_pose'] = x
128 | data['t'] = batch_time_step
129 | grad = score_model(data)
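            # Corrector step size from the signal-to-noise heuristic of the
            # predictor-corrector sampler: eps = 2 * (snr * ||z|| / ||score||)^2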
130 | grad_norm = torch.norm(grad.reshape(batch_size, -1), dim=-1).mean()
131 | langevin_step_size = 2 * (snr * noise_norm / grad_norm)**2
132 | x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x)
133 |
134 | # normalisation
135 | if pose_mode == 'quat_wxyz' or pose_mode == 'quat_xyzw':
136 | # quat, should be normalised
137 | x[:, :4] /= torch.norm(x[:, :4], dim=-1, keepdim=True)
138 | elif pose_mode == 'euler_xyz':
139 | pass
140 | else:
141 | # rotation(x axis, y axis), should be normalised
142 | x[:, :3] /= torch.norm(x[:, :3], dim=-1, keepdim=True)
143 | x[:, 3:6] /= torch.norm(x[:, 3:6], dim=-1, keepdim=True)
144 |
145 | # Predictor step (Euler-Maruyama)
146 | drift, diffusion = sde_coeff(batch_time_step)
147 | drift = drift - diffusion**2*grad # R-SDE
148 | mean_x = x + drift * step_size
149 | x = mean_x + diffusion * torch.sqrt(step_size) * torch.randn_like(x)
150 |
151 | # normalisation
152 | x[:, :-3] = normalize_rotation(x[:, :-3], pose_mode)
153 | poses.append(x.unsqueeze(0))
154 |
155 | xs = torch.cat(poses, dim=0)
156 | xs[:, :, -3:] += data['pts_center'].unsqueeze(0).repeat(xs.shape[0], 1, 1)
157 | mean_x[:, -3:] += data['pts_center']
158 | mean_x[:, :-3] = normalize_rotation(mean_x[:, :-3], pose_mode)
159 | # The last step does not include any noise
160 | return xs.permute(1, 0, 2), mean_x
161 |
162 |
163 | def cond_ode_sampler(
164 | score_model,
165 | data,
166 | prior,
167 | sde_coeff,
168 | atol=1e-5,
169 | rtol=1e-5,
170 | device='cuda',
171 | eps=1e-5,
172 | T=1.0,
173 | num_steps=None,
174 | pose_mode='quat_wxyz',
175 | denoise=True,
176 | init_x=None,
177 | ):
178 | pose_dim = get_pose_dim(pose_mode)
179 | batch_size=data['pts'].shape[0]
180 | init_x = prior((batch_size, pose_dim), T=T).to(device) if init_x is None else init_x + prior((batch_size, pose_dim), T=T).to(device)
181 | shape = init_x.shape
182 |
183 | def score_eval_wrapper(data):
184 | """A wrapper of the score-based model for use by the ODE solver."""
185 | with torch.no_grad():
186 | score = score_model(data)
187 | return score.cpu().numpy().reshape((-1,))
188 |
189 | def ode_func(t, x):
190 | """The ODE function for use by the ODE solver."""
191 | x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device)
192 | time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
193 | drift, diffusion = sde_coeff(torch.tensor(t))
194 | drift = drift.cpu().numpy()
195 | diffusion = diffusion.cpu().numpy()
196 | data['sampled_pose'] = x
197 | data['t'] = time_steps
198 | return drift - 0.5 * (diffusion**2) * score_eval_wrapper(data)
199 |
200 |     # Run the black-box ODE solver
201 | t_eval = None
202 | if num_steps is not None:
203 | # num_steps, from T -> eps
204 | t_eval = np.linspace(T, eps, num_steps)
205 | res = integrate.solve_ivp(ode_func, (T, eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45', t_eval=t_eval)
206 | xs = torch.tensor(res.y, device=device).T.view(-1, batch_size, pose_dim) # [num_steps, bs, pose_dim]
207 | x = torch.tensor(res.y[:, -1], device=device).reshape(shape) # [bs, pose_dim]
208 | # denoise, using the predictor step in P-C sampler
209 | if denoise:
210 | # Reverse diffusion predictor for denoising
211 | vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps
212 | drift, diffusion = sde_coeff(vec_eps)
213 | data['sampled_pose'] = x.float()
214 | data['t'] = vec_eps
215 | grad = score_model(data)
216 | drift = drift - diffusion**2*grad # R-SDE
217 | mean_x = x + drift * ((1-eps)/(1000 if num_steps is None else num_steps))
218 | x = mean_x
219 |
220 | num_steps = xs.shape[0]
221 | xs = xs.reshape(batch_size*num_steps, -1)
222 | xs[:, :-3] = normalize_rotation(xs[:, :-3], pose_mode)
223 | xs = xs.reshape(num_steps, batch_size, -1)
224 | xs[:, :, -3:] += data['pts_center'].unsqueeze(0).repeat(xs.shape[0], 1, 1)
225 | x[:, :-3] = normalize_rotation(x[:, :-3], pose_mode)
226 | x[:, -3:] += data['pts_center']
227 | return xs.permute(1, 0, 2), x
228 |
229 |
230 | def cond_edm_sampler(
231 | decoder_model, data, prior_fn, randn_like=torch.randn_like,
232 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
233 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
234 | pose_mode='quat_wxyz', device='cuda'
235 | ):
236 | pose_dim = get_pose_dim(pose_mode)
237 | batch_size = data['pts'].shape[0]
238 | latents = prior_fn((batch_size, pose_dim)).to(device)
239 |
240 |     # Time step discretization. Note that sigma and t are interchangeable
241 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
242 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
243 | t_steps = torch.cat([torch.as_tensor(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
244 |
245 | def decoder_wrapper(decoder, data, x, t):
246 | # save temp
247 | x_, t_= data['sampled_pose'], data['t']
248 | # init data
249 | data['sampled_pose'], data['t'] = x, t
250 | # denoise
251 | data, denoised = decoder(data)
252 | # recover data
253 | data['sampled_pose'], data['t'] = x_, t_
254 | return denoised.to(torch.float64)
255 |
256 | # Main sampling loop.
257 | x_next = latents.to(torch.float64) * t_steps[0]
258 | xs = []
259 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
260 | x_cur = x_next
261 |
262 | # Increase noise temporarily.
263 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
264 | t_hat = torch.as_tensor(t_cur + gamma * t_cur)
265 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
266 |
267 | # Euler step.
268 | denoised = decoder_wrapper(decoder_model, data, x_hat, t_hat)
269 | d_cur = (x_hat - denoised) / t_hat
270 | x_next = x_hat + (t_next - t_hat) * d_cur
271 |
272 | # Apply 2nd order correction.
273 | if i < num_steps - 1:
274 | denoised = decoder_wrapper(decoder_model, data, x_next, t_next)
275 | d_prime = (x_next - denoised) / t_next
276 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
277 | xs.append(x_next.unsqueeze(0))
278 |
279 | xs = torch.stack(xs, dim=0) # [num_steps, bs, pose_dim]
280 | x = xs[-1] # [bs, pose_dim]
281 |
282 | # post-processing
283 | xs = xs.reshape(batch_size*num_steps, -1)
284 | xs[:, :-3] = normalize_rotation(xs[:, :-3], pose_mode)
285 | xs = xs.reshape(num_steps, batch_size, -1)
286 | xs[:, :, -3:] += data['pts_center'].unsqueeze(0).repeat(xs.shape[0], 1, 1)
287 | x[:, :-3] = normalize_rotation(x[:, :-3], pose_mode)
288 | x[:, -3:] += data['pts_center']
289 |
290 | return xs.permute(1, 0, 2), x
291 |
292 |
293 |
--------------------------------------------------------------------------------
/networks/gf_algorithms/score_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class ExponentialMovingAverage:
4 | """
5 | Maintains (exponential) moving average of a set of parameters.
6 | """
7 |
8 | def __init__(self, parameters, decay, use_num_updates=True):
9 | """
10 | Args:
11 | parameters: Iterable of `torch.nn.Parameter`; usually the result of
12 | `model.parameters()`.
13 | decay: The exponential decay.
14 | use_num_updates: Whether to use number of updates when computing
15 | averages.
16 | """
17 | if decay < 0.0 or decay > 1.0:
18 | raise ValueError('Decay must be between 0 and 1')
19 | self.decay = decay
20 | self.num_updates = 0 if use_num_updates else None
21 | self.shadow_params = [p.clone().detach()
22 | for p in parameters if p.requires_grad]
23 | self.collected_params = []
24 |
25 | def update(self, parameters):
26 | """
27 | Update currently maintained parameters.
28 |
29 | Call this every time the parameters are updated, such as the result of
30 | the `optimizer.step()` call.
31 |
32 | Args:
33 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of
34 | parameters used to initialize this object.
35 | """
36 | decay = self.decay
37 | if self.num_updates is not None:
38 | self.num_updates += 1
39 | decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
40 | one_minus_decay = 1.0 - decay
41 | with torch.no_grad():
42 | parameters = [p for p in parameters if p.requires_grad]
43 | for s_param, param in zip(self.shadow_params, parameters):
44 | s_param.sub_(one_minus_decay * (s_param - param)) # only update the ema-params
45 |
46 |
47 | def copy_to(self, parameters):
48 | """
49 | Copy current parameters into given collection of parameters.
50 |
51 | Args:
52 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
53 | updated with the stored moving averages.
54 | """
55 | parameters = [p for p in parameters if p.requires_grad]
56 | for s_param, param in zip(self.shadow_params, parameters):
57 | if param.requires_grad:
58 | param.data.copy_(s_param.data)
59 |
60 | def store(self, parameters):
61 | """
62 | Save the current parameters for restoring later.
63 |
64 | Args:
65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
66 | temporarily stored.
67 | """
68 | self.collected_params = [param.clone() for param in parameters]
69 |
70 | def restore(self, parameters):
71 | """
72 | Restore the parameters stored with the `store` method.
73 | Useful to validate the model with EMA parameters without affecting the
74 | original optimization process. Store the parameters before the
75 | `copy_to` method. After validation (or model saving), use this to
76 | restore the former parameters.
77 |
78 | Args:
79 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
80 | updated with the stored parameters.
81 | """
82 | for c_param, param in zip(self.collected_params, parameters):
83 | param.data.copy_(c_param.data)
84 |
85 | def state_dict(self):
86 | return dict(decay=self.decay, num_updates=self.num_updates,
87 | shadow_params=self.shadow_params)
88 |
89 | def load_state_dict(self, state_dict):
90 | self.decay = state_dict['decay']
91 | self.num_updates = state_dict['num_updates']
92 | self.shadow_params = state_dict['shadow_params']
93 |
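
# Minimal usage sketch (illustrative; `model`, `optimizer`, and the training loop below
# are placeholders, not part of this repository):
#
#   ema = ExponentialMovingAverage(model.parameters(), decay=0.999)
#   for batch in loader:
#       ...                                   # forward / backward / optimizer.step()
#       ema.update(model.parameters())        # track EMA weights after every update
#
#   ema.store(model.parameters())             # back up the raw training weights
#   ema.copy_to(model.parameters())           # evaluate or save with the EMA weights
#   ema.restore(model.parameters())           # then switch back to the training weights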
--------------------------------------------------------------------------------
/networks/gf_algorithms/sde.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import functools
4 | import torch
5 | import numpy as np
6 | from ipdb import set_trace
7 | from scipy import integrate
8 | from utils.genpose_utils import get_pose_dim
9 |
10 | sys.path.append(os.path.dirname(os.path.dirname(__file__)))
11 |
12 |
13 | #----- VE SDE -----
14 | #------------------
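# Marginal perturbation kernel: p_t(x | x_0) = N(x; x_0, sigma(t)^2 I) with
# sigma(t) = sigma_min * (sigma_max / sigma_min)^t, which gives the diffusion
# coefficient g(t) = sigma(t) * sqrt(2 * (log(sigma_max) - log(sigma_min))).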
15 | def ve_marginal_prob(x, t, sigma_min=0.01, sigma_max=90):
16 | std = sigma_min * (sigma_max / sigma_min) ** t
17 | mean = x
18 | return mean, std
19 |
20 | def ve_sde(t, sigma_min=0.01, sigma_max=90):
21 | sigma = sigma_min * (sigma_max / sigma_min) ** t
22 | drift_coeff = torch.tensor(0)
23 | diffusion_coeff = sigma * torch.sqrt(torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=t.device))
24 | return drift_coeff, diffusion_coeff
25 |
26 | def ve_prior(shape, sigma_min=0.01, sigma_max=90, T=1.0):
27 | _, sigma_max_prior = ve_marginal_prob(None, T, sigma_min=sigma_min, sigma_max=sigma_max)
28 | return torch.randn(*shape) * sigma_max_prior
29 |
30 | #----- VP SDE -----
31 | #------------------
32 | def vp_marginal_prob(x, t, beta_0=0.1, beta_1=20):
33 | log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
34 | mean = torch.exp(log_mean_coeff) * x
35 | std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
36 | return mean, std
37 |
38 | def vp_sde(t, beta_0=0.1, beta_1=20):
39 | beta_t = beta_0 + t * (beta_1 - beta_0)
40 | drift_coeff = -0.5 * beta_t
41 | diffusion_coeff = torch.sqrt(beta_t)
42 | return drift_coeff, diffusion_coeff
43 |
44 | def vp_prior(shape, beta_0=0.1, beta_1=20):
45 | return torch.randn(*shape)
46 |
47 | #----- sub-VP SDE -----
48 | #----------------------
49 | def subvp_marginal_prob(x, t, beta_0, beta_1):
50 | log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
51 | mean = torch.exp(log_mean_coeff) * x
52 | std = 1 - torch.exp(2. * log_mean_coeff)
53 | return mean, std
54 |
55 | def subvp_sde(t, beta_0, beta_1):
56 | beta_t = beta_0 + t * (beta_1 - beta_0)
57 | drift_coeff = -0.5 * beta_t
58 | discount = 1. - torch.exp(-2 * beta_0 * t - (beta_1 - beta_0) * t ** 2)
59 | diffusion_coeff = torch.sqrt(beta_t * discount)
60 | return drift_coeff, diffusion_coeff
61 |
62 | def subvp_prior(shape, beta_0=0.1, beta_1=20):
63 | return torch.randn(*shape)
64 |
65 | #----- EDM SDE -----
66 | #------------------
67 | def edm_marginal_prob(x, t, sigma_min=0.002, sigma_max=80):
68 | std = t
69 | mean = x
70 | return mean, std
71 |
72 | def edm_sde(t, sigma_min=0.002, sigma_max=80):
73 | drift_coeff = torch.tensor(0)
74 | diffusion_coeff = torch.sqrt(2 * t)
75 | return drift_coeff, diffusion_coeff
76 |
77 | def edm_prior(shape, sigma_min=0.002, sigma_max=80):
78 | return torch.randn(*shape) * sigma_max
79 |
80 | def init_sde(sde_mode):
81 | # the SDE-related hyperparameters are copied from https://github.com/yang-song/score_sde_pytorch
82 | if sde_mode == 'edm':
83 | sigma_min = 0.002
84 | sigma_max = 80
85 | eps = 0.002
86 | prior_fn = functools.partial(edm_prior, sigma_min=sigma_min, sigma_max=sigma_max)
87 | marginal_prob_fn = functools.partial(edm_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
88 | sde_fn = functools.partial(edm_sde, sigma_min=sigma_min, sigma_max=sigma_max)
89 | T = sigma_max
90 | elif sde_mode == 've':
91 | sigma_min = 0.01
92 | sigma_max = 50
93 | eps = 1e-5
94 | marginal_prob_fn = functools.partial(ve_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
95 | sde_fn = functools.partial(ve_sde, sigma_min=sigma_min, sigma_max=sigma_max)
96 | T = 1.0
97 | prior_fn = functools.partial(ve_prior, sigma_min=sigma_min, sigma_max=sigma_max)
98 | elif sde_mode == 'vp':
99 | beta_0 = 0.1
100 | beta_1 = 20
101 | eps = 1e-3
102 | prior_fn = functools.partial(vp_prior, beta_0=beta_0, beta_1=beta_1)
103 | marginal_prob_fn = functools.partial(vp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
104 | sde_fn = functools.partial(vp_sde, beta_0=beta_0, beta_1=beta_1)
105 | T = 1.0
106 | elif sde_mode == 'subvp':
107 | beta_0 = 0.1
108 | beta_1 = 20
109 | eps = 1e-3
110 | prior_fn = functools.partial(subvp_prior, beta_0=beta_0, beta_1=beta_1)
111 | marginal_prob_fn = functools.partial(subvp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
112 | sde_fn = functools.partial(subvp_sde, beta_0=beta_0, beta_1=beta_1)
113 | T = 1.0
114 | else:
115 | raise NotImplementedError
116 | return prior_fn, marginal_prob_fn, sde_fn, eps, T
117 |
118 |
--------------------------------------------------------------------------------
/networks/posenet.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import torch
4 | import torch.nn as nn
5 |
6 | from ipdb import set_trace
7 | sys.path.append(os.path.dirname(os.path.dirname(__file__)))
8 | from networks.pts_encoder.pointnets import PointNetfeat
9 | from networks.pts_encoder.pointnet2 import Pointnet2ClsMSG
10 | from networks.gf_algorithms.samplers import cond_ode_likelihood, cond_ode_sampler, cond_pc_sampler
11 | from networks.gf_algorithms.scorenet import PoseScoreNet, PoseDecoderNet
12 | from networks.gf_algorithms.energynet import PoseEnergyNet
13 | from networks.gf_algorithms.sde import init_sde
14 | from configs.config import get_config
15 |
16 |
17 |
18 | class GFObjectPose(nn.Module):
19 | def __init__(self, cfg, prior_fn, marginal_prob_fn, sde_fn, sampling_eps, T):
20 | super(GFObjectPose, self).__init__()
21 |
22 | self.cfg = cfg
23 | self.device = cfg.device
24 | self.is_testing = False
25 |
26 | ''' Load model, define SDE '''
27 | # init SDE config
28 | self.prior_fn = prior_fn
29 | self.marginal_prob_fn = marginal_prob_fn
30 | self.sde_fn = sde_fn
31 | self.sampling_eps = sampling_eps
32 | self.T = T
33 | # self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps = init_sde(cfg.sde_mode)
34 |
35 | ''' encode pts '''
36 | if self.cfg.pts_encoder == 'pointnet':
37 | self.pts_encoder = PointNetfeat(num_points=self.cfg.num_points, out_dim=1024)
38 | elif self.cfg.pts_encoder == 'pointnet2':
39 | self.pts_encoder = Pointnet2ClsMSG(0)
40 | elif self.cfg.pts_encoder == 'pointnet_and_pointnet2':
41 | self.pts_pointnet_encoder = PointNetfeat(num_points=self.cfg.num_points, out_dim=1024)
42 | self.pts_pointnet2_encoder = Pointnet2ClsMSG(0)
43 | self.fusion_layer = nn.Linear(2048, 1024)
44 | self.act = nn.ReLU()
45 | else:
46 | raise NotImplementedError
47 |
48 | ''' score network'''
49 | # if self.cfg.sde_mode == 'edm':
50 | # self.pose_score_net = PoseDecoderNet(
51 | # self.marginal_prob_fn,
52 | # sigma_data=1.4148,
53 | # pose_mode=self.cfg.pose_mode,
54 | # regression_head=self.cfg.regression_head
55 | # )
56 | # else:
57 | per_point_feat = False
58 | if self.cfg.posenet_mode == 'score':
59 | self.pose_score_net = PoseScoreNet(self.marginal_prob_fn, self.cfg.pose_mode, self.cfg.regression_head, per_point_feat)
60 | elif self.cfg.posenet_mode == 'energy':
61 | self.pose_score_net = PoseEnergyNet(
62 | marginal_prob_func=self.marginal_prob_fn,
63 | pose_mode=self.cfg.pose_mode,
64 | regression_head=self.cfg.regression_head,
65 | energy_mode=self.cfg.energy_mode,
66 | s_theta_mode=self.cfg.s_theta_mode,
67 | norm_energy=self.cfg.norm_energy)
68 | ''' ToDo: ranking network '''
69 |
70 |
71 | def extract_pts_feature(self, data):
72 | """extract the input pointcloud feature
73 |
74 | Args:
75 | data (dict): batch example without pointcloud feature. {'pts': [bs, num_pts, 3], 'sampled_pose': [bs, pose_dim], 't': [bs, 1]}
76 | Returns:
77 | data (dict): batch example with pointcloud feature. {'pts': [bs, num_pts, 3], 'pts_feat': [bs, c], 'sampled_pose': [bs, pose_dim], 't': [bs, 1]}
78 | """
79 | pts = data['pts']
80 | if self.cfg.pts_encoder == 'pointnet':
81 | pts_feat = self.pts_encoder(pts.permute(0, 2, 1)) # -> (bs, 3, 1024)
82 | elif self.cfg.pts_encoder in ['pointnet2']:
83 | pts_feat = self.pts_encoder(pts)
84 | elif self.cfg.pts_encoder == 'pointnet_and_pointnet2':
85 | pts_pointnet_feat = self.pts_pointnet_encoder(pts.permute(0, 2, 1))
86 | pts_pointnet2_feat = self.pts_pointnet2_encoder(pts)
87 | pts_feat = self.fusion_layer(torch.cat((pts_pointnet_feat, pts_pointnet2_feat), dim=-1))
88 | pts_feat = self.act(pts_feat)
89 | else:
90 | raise NotImplementedError
91 | return pts_feat
92 |
93 |
94 | def sample(self, data, sampler, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None):
95 | if sampler == 'pc':
96 | in_process_sample, res = cond_pc_sampler(
97 | score_model=self,
98 | data=data,
99 | prior=self.prior_fn,
100 | sde_coeff=self.sde_fn,
101 | num_steps=self.cfg.sampling_steps,
102 | snr=snr,
103 | device=self.device,
104 | eps=self.sampling_eps,
105 | pose_mode=self.cfg.pose_mode,
106 | init_x=init_x
107 | )
108 |
109 | elif sampler == 'ode':
110 | T0 = self.T if T0 is None else T0
111 | in_process_sample, res = cond_ode_sampler(
112 | score_model=self,
113 | data=data,
114 | prior=self.prior_fn,
115 | sde_coeff=self.sde_fn,
116 | atol=atol,
117 | rtol=rtol,
118 | device=self.device,
119 | eps=self.sampling_eps,
120 | T=T0,
121 | num_steps=self.cfg.sampling_steps,
122 | pose_mode=self.cfg.pose_mode,
123 | denoise=denoise,
124 | init_x=init_x
125 | )
126 |
127 | else:
128 | raise NotImplementedError
129 |
130 | return in_process_sample, res
131 |
132 |
133 | def calc_likelihood(self, data, atol=1e-5, rtol=1e-5):
134 | latent_code, log_likelihoods = cond_ode_likelihood(
135 | score_model=self,
136 | data=data,
137 | prior=self.prior_fn,
138 | sde_coeff=self.sde_fn,
139 | marginal_prob_fn=self.marginal_prob_fn,
140 | atol=atol,
141 | rtol=rtol,
142 | device=self.device,
143 | eps=self.sampling_eps,
144 | num_steps=self.cfg.sampling_steps,
145 | pose_mode=self.cfg.pose_mode,
146 | )
147 | return log_likelihoods
148 |
149 |
150 | def forward(self, data, mode='score', init_x=None, T0=None):
151 | '''
152 | Args:
153 | data, dict {
154 | 'pts': [bs, num_pts, 3]
155 | 'pts_feat': [bs, c]
156 | 'sampled_pose': [bs, pose_dim]
157 | 't': [bs, 1]
158 | }
159 | '''
160 | if mode == 'score':
161 |             out_score = self.pose_score_net(data)  # estimated score of the sampled pose
162 | return out_score
163 | elif mode == 'energy':
164 | out_energy = self.pose_score_net(data, return_item='energy')
165 | return out_energy
166 | elif mode == 'likelihood':
167 | likelihoods = self.calc_likelihood(data)
168 | return likelihoods
169 | elif mode == 'pts_feature':
170 | pts_feature = self.extract_pts_feature(data)
171 | return pts_feature
172 | elif mode == 'pc_sample':
173 | in_process_sample, res = self.sample(data, 'pc', init_x=init_x)
174 | return in_process_sample, res
175 | elif mode == 'ode_sample':
176 | in_process_sample, res = self.sample(data, 'ode', init_x=init_x, T0=T0)
177 | return in_process_sample, res
178 | else:
179 | raise NotImplementedError
180 |
181 |
182 |
183 | def test():
184 | def get_parameter_number(model):
185 | total_num = sum(p.numel() for p in model.parameters())
186 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
187 | return {'Total': total_num, 'Trainable': trainable_num}
188 | cfg = get_config()
189 | prior_fn, marginal_prob_fn, sde_fn, sampling_eps, T = init_sde('ve')
190 | net = GFObjectPose(cfg, prior_fn, marginal_prob_fn, sde_fn, sampling_eps, T)
191 | net_parameters_num= get_parameter_number(net)
192 | print(net_parameters_num['Total'], net_parameters_num['Trainable'])
193 | if __name__ == '__main__':
194 | test()
195 |
196 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import sys
4 | import os
5 |
6 | sys.path.append(os.path.dirname(os.path.dirname(__file__)))
7 | from networks.pts_encoder.pointnet2_utils.pointnet2.pointnet2_modules import PointnetFPModule, PointnetSAModuleMSG
8 | import networks.pts_encoder.pointnet2_utils.pointnet2.pytorch_utils as pt_utils
9 | from ipdb import set_trace
10 | from configs.config import get_config
11 |
12 |
13 | cfg = get_config()
14 |
15 |
16 | def get_model(input_channels=0):
17 | return Pointnet2MSG(input_channels=input_channels)
18 |
19 | MSG_CFG = {
20 | 'NPOINTS': [512, 256, 128, 64],
21 | 'RADIUS': [[0.01, 0.02], [0.02, 0.04], [0.04, 0.08], [0.08, 0.16]],
22 | 'NSAMPLE': [[16, 32], [16, 32], [16, 32], [16, 32]],
23 | 'MLPS': [[[16, 16, 32], [32, 32, 64]],
24 | [[64, 64, 128], [64, 96, 128]],
25 | [[128, 196, 256], [128, 196, 256]],
26 | [[256, 256, 512], [256, 384, 512]]],
27 | 'FP_MLPS': [[64, 64], [128, 128], [256, 256], [512, 512]],
28 | 'CLS_FC': [128],
29 | 'DP_RATIO': 0.5,
30 | }
31 |
32 | ClsMSG_CFG = {
33 | 'NPOINTS': [512, 256, 128, 64, None],
34 | 'RADIUS': [[0.01, 0.02], [0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
35 | 'NSAMPLE': [[16, 32], [16, 32], [16, 32], [16, 32], [None, None]],
36 | 'MLPS': [[[16, 16, 32], [32, 32, 64]],
37 | [[64, 64, 128], [64, 96, 128]],
38 | [[128, 196, 256], [128, 196, 256]],
39 | [[256, 256, 512], [256, 384, 512]],
40 | [[512, 512], [512, 512]]],
41 | 'DP_RATIO': 0.5,
42 | }
43 |
44 | ClsMSG_CFG_Dense = {
45 | 'NPOINTS': [512, 256, 128, None],
46 | 'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
47 | 'NSAMPLE': [[32, 64], [16, 32], [8, 16], [None, None]],
48 | 'MLPS': [[[16, 16, 32], [32, 32, 64]],
49 | [[64, 64, 128], [64, 96, 128]],
50 | [[128, 196, 256], [128, 196, 256]],
51 | [[256, 256, 512], [256, 384, 512]]],
52 | 'DP_RATIO': 0.5,
53 | }
54 |
55 |
56 | ########## Best before 29th April ###########
57 | ClsMSG_CFG_Light = {
58 | 'NPOINTS': [512, 256, 128, None],
59 | 'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
60 | 'NSAMPLE': [[16, 32], [16, 32], [16, 32], [None, None]],
61 | 'MLPS': [[[16, 16, 32], [32, 32, 64]],
62 | [[64, 64, 128], [64, 96, 128]],
63 | [[128, 196, 256], [128, 196, 256]],
64 | [[256, 256, 512], [256, 384, 512]]],
65 | 'DP_RATIO': 0.5,
66 | }
67 |
68 |
69 | ClsMSG_CFG_Lighter = {
70 | 'NPOINTS': [512, 256, 128, 64, None],
71 | 'RADIUS': [[0.01], [0.02], [0.04], [0.08], [None]],
72 | 'NSAMPLE': [[64], [32], [16], [8], [None]],
73 | 'MLPS': [[[32, 32, 64]],
74 | [[64, 64, 128]],
75 | [[128, 196, 256]],
76 | [[256, 256, 512]],
77 | [[512, 512, 1024]]],
78 | 'DP_RATIO': 0.5,
79 | }
80 |
81 |
82 | if cfg.pointnet2_params == 'light':
83 | SELECTED_PARAMS = ClsMSG_CFG_Light
84 | elif cfg.pointnet2_params == 'lighter':
85 | SELECTED_PARAMS = ClsMSG_CFG_Lighter
86 | elif cfg.pointnet2_params == 'dense':
87 | SELECTED_PARAMS = ClsMSG_CFG_Dense
88 | else:
89 | raise NotImplementedError
90 |
91 |
92 | class Pointnet2MSG(nn.Module):
93 | def __init__(self, input_channels=6):
94 | super().__init__()
95 |
96 | self.SA_modules = nn.ModuleList()
97 | channel_in = input_channels
98 |
99 | skip_channel_list = [input_channels]
100 | for k in range(len(MSG_CFG['NPOINTS'])):
101 | mlps = MSG_CFG['MLPS'][k].copy()
102 | channel_out = 0
103 | for idx in range(len(mlps)):
104 | mlps[idx] = [channel_in] + mlps[idx]
105 | channel_out += mlps[idx][-1]
106 |
107 | self.SA_modules.append(
108 | PointnetSAModuleMSG(
109 | npoint=MSG_CFG['NPOINTS'][k],
110 | radii=MSG_CFG['RADIUS'][k],
111 | nsamples=MSG_CFG['NSAMPLE'][k],
112 | mlps=mlps,
113 | use_xyz=True,
114 | bn=True
115 | )
116 | )
117 | skip_channel_list.append(channel_out)
118 | channel_in = channel_out
119 |
120 | self.FP_modules = nn.ModuleList()
121 |
122 | for k in range(len(MSG_CFG['FP_MLPS'])):
123 | pre_channel = MSG_CFG['FP_MLPS'][k + 1][-1] if k + 1 < len(MSG_CFG['FP_MLPS']) else channel_out
124 | self.FP_modules.append(
125 | PointnetFPModule(mlp=[pre_channel + skip_channel_list[k]] + MSG_CFG['FP_MLPS'][k])
126 | )
127 |
128 | cls_layers = []
129 | pre_channel = MSG_CFG['FP_MLPS'][0][-1]
130 | for k in range(len(MSG_CFG['CLS_FC'])):
131 | cls_layers.append(pt_utils.Conv1d(pre_channel, MSG_CFG['CLS_FC'][k], bn=True))
132 | pre_channel = MSG_CFG['CLS_FC'][k]
133 | cls_layers.append(pt_utils.Conv1d(pre_channel, 1, activation=None))
134 | cls_layers.insert(1, nn.Dropout(0.5))
135 | self.cls_layer = nn.Sequential(*cls_layers)
136 |
137 |
138 | def _break_up_pc(self, pc):
139 | xyz = pc[..., 0:3].contiguous()
140 | features = (
141 | pc[..., 3:].transpose(1, 2).contiguous()
142 | if pc.size(-1) > 3 else None
143 | )
144 |
145 | return xyz, features
146 |
147 | def forward(self, pointcloud: torch.cuda.FloatTensor):
148 | xyz, features = self._break_up_pc(pointcloud)
149 |
150 | l_xyz, l_features = [xyz], [features]
151 | for i in range(len(self.SA_modules)):
152 | li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
153 |
154 | l_xyz.append(li_xyz)
155 | l_features.append(li_features)
156 |
157 | # set_trace()  # leftover ipdb breakpoint, disabled so forward() can run end to end
158 | for i in range(-1, -(len(self.FP_modules) + 1), -1):
159 | l_features[i - 1] = self.FP_modules[i](
160 | l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i]
161 | )
162 |
163 | return l_features[0]
164 |
165 |
166 | class Pointnet2ClsMSG(nn.Module):
167 | def __init__(self, input_channels=6):
168 | super().__init__()
169 |
170 | self.SA_modules = nn.ModuleList()
171 | channel_in = input_channels
172 |
173 | for k in range(len(SELECTED_PARAMS['NPOINTS'])):
174 | mlps = SELECTED_PARAMS['MLPS'][k].copy()
175 | channel_out = 0
176 | for idx in range(len(mlps)):
177 | mlps[idx] = [channel_in] + mlps[idx]
178 | channel_out += mlps[idx][-1]
179 |
180 | self.SA_modules.append(
181 | PointnetSAModuleMSG(
182 | npoint=SELECTED_PARAMS['NPOINTS'][k],
183 | radii=SELECTED_PARAMS['RADIUS'][k],
184 | nsamples=SELECTED_PARAMS['NSAMPLE'][k],
185 | mlps=mlps,
186 | use_xyz=True,
187 | bn=True
188 | )
189 | )
190 | channel_in = channel_out
191 |
192 |
193 | def _break_up_pc(self, pc):
194 | xyz = pc[..., 0:3].contiguous()
195 | features = (
196 | pc[..., 3:].transpose(1, 2).contiguous()
197 | if pc.size(-1) > 3 else None
198 | )
199 |
200 | return xyz, features
201 |
202 |
203 | def forward(self, pointcloud: torch.cuda.FloatTensor):
204 | xyz, features = self._break_up_pc(pointcloud)
205 |
206 | l_xyz, l_features = [xyz], [features]
207 | for i in range(len(self.SA_modules)):
208 | li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
209 | l_xyz.append(li_xyz)
210 | l_features.append(li_features)
211 | return l_features[-1].squeeze(-1)
212 |
213 |
214 | if __name__ == '__main__':
215 | seed = 100
216 | torch.manual_seed(seed)
217 | torch.cuda.manual_seed(seed)
218 | net = Pointnet2ClsMSG(0).cuda()
219 | pts = torch.randn(2, 1024, 3).cuda()
220 | print(torch.mean(pts, dim=1))
221 | pre = net(pts)
222 | print(pre.shape)
223 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/.gitignore:
--------------------------------------------------------------------------------
1 | pointnet2/build/
2 | pointnet2/dist/
3 | pointnet2/pointnet2.egg-info/
4 | __pycache__/
5 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Shaoshuai Shi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/README.md:
--------------------------------------------------------------------------------
1 | # Pointnet2.PyTorch
2 |
3 | * PyTorch implementation of [PointNet++](https://arxiv.org/abs/1706.02413) based on [erikwijmans/Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch).
4 | * Faster than the original codes by re-implementing the CUDA operations.
5 |
6 | ## Installation
7 | ### Requirements
8 | * Linux (tested on Ubuntu 14.04/16.04)
9 | * Python 3.6+
10 | * PyTorch 1.0
11 |
12 | ### Install
13 | Install this library by running the following command:
14 |
15 | ```shell
16 | cd pointnet2
17 | python setup.py install
18 | cd ../
19 | ```
20 |
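After installing, a quick sanity check is to import the compiled extension (a hypothetical snippet; it assumes the build above succeeded and a CUDA-enabled PyTorch is installed):

```python
import torch               # import torch first so the CUDA/libtorch symbols are loaded
import pointnet2_cuda      # the extension built by setup.py
print(torch.cuda.is_available(), pointnet2_cuda.__name__)
```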
21 | ## Examples
22 | Here is a simple example of using this library for KITTI outdoor foreground point cloud segmentation; refer to the paper [PointRCNN](https://arxiv.org/abs/1812.04244) for details of the task definition and foreground label generation.
23 |
24 | 1. Download the training data from the [KITTI 3D object detection](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d) website and organize the downloaded files as follows:
25 | ```
26 | Pointnet2.PyTorch
27 | ├── pointnet2
28 | ├── tools
29 | │ ├──data
30 | │ │ ├── KITTI
31 | │ │ │ ├── ImageSets
32 | │ │ │ ├── object
33 | │ │ │ │ ├──training
34 | │ │ │ │ ├──calib & velodyne & label_2 & image_2
35 | │ │ train_and_eval.py
36 | ```
37 |
38 | 2. Run the following command to train and evaluate:
39 | ```shell
40 | cd tools
41 | python train_and_eval.py --batch_size 8 --epochs 100 --ckpt_save_interval 2
42 | ```
43 |
44 |
45 |
46 | ## Projects using this repo
47 | * [PointRCNN](https://github.com/sshaoshuai/PointRCNN): 3D object detector from raw point cloud.
48 |
49 | ## Acknowledgement
50 | * [charlesq34/pointnet2](https://github.com/charlesq34/pointnet2): Paper author and official code repo.
51 | * [erikwijmans/Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch): Initial work of PyTorch implementation of PointNet++.
52 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/pointnet2_modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from . import pointnet2_utils
6 | from . import pytorch_utils as pt_utils
7 | from typing import List, Tuple
8 |
9 |
10 | class _PointnetSAModuleBase(nn.Module):
11 |
12 | def __init__(self):
13 | super().__init__()
14 | self.npoint = None
15 | self.groupers = None
16 | self.mlps = None
17 | self.pool_method = 'max_pool'
18 |
19 | def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> Tuple[torch.Tensor, torch.Tensor]:
20 | """
21 | :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
22 | :param features: (B, N, C) tensor of the descriptors of the features
23 | :param new_xyz:
24 | :return:
25 | new_xyz: (B, npoint, 3) tensor of the new features' xyz
26 | new_features: (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
27 | """
28 | new_features_list = []
29 |
30 | xyz_flipped = xyz.transpose(1, 2).contiguous()
31 | if new_xyz is None:
32 | new_xyz = pointnet2_utils.gather_operation(
33 | xyz_flipped,
34 | pointnet2_utils.furthest_point_sample(xyz, self.npoint)
35 | ).transpose(1, 2).contiguous() if self.npoint is not None else None
36 |
37 | for i in range(len(self.groupers)):
38 | new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample)
39 |
40 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
41 |
42 | if self.pool_method == 'max_pool':
43 | new_features = F.max_pool2d(
44 | new_features, kernel_size=[1, new_features.size(3)]
45 | ) # (B, mlp[-1], npoint, 1)
46 | elif self.pool_method == 'avg_pool':
47 | new_features = F.avg_pool2d(
48 | new_features, kernel_size=[1, new_features.size(3)]
49 | ) # (B, mlp[-1], npoint, 1)
50 | else:
51 | raise NotImplementedError
52 |
53 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
54 | new_features_list.append(new_features)
55 |
56 | return new_xyz, torch.cat(new_features_list, dim=1)
57 |
58 |
59 | class PointnetSAModuleMSG(_PointnetSAModuleBase):
60 | """Pointnet set abstraction layer with multiscale grouping"""
61 |
62 | def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
63 | use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
64 | """
65 | :param npoint: int
66 | :param radii: list of float, list of radii to group with
67 | :param nsamples: list of int, number of samples in each ball query
68 | :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
69 | :param bn: whether to use batchnorm
70 | :param use_xyz:
71 | :param pool_method: max_pool / avg_pool
72 | :param instance_norm: whether to use instance_norm
73 | """
74 | super().__init__()
75 |
76 | assert len(radii) == len(nsamples) == len(mlps)
77 |
78 | self.npoint = npoint
79 | self.groupers = nn.ModuleList()
80 | self.mlps = nn.ModuleList()
81 | for i in range(len(radii)):
82 | radius = radii[i]
83 | nsample = nsamples[i]
84 | self.groupers.append(
85 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
86 | if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
87 | )
88 | mlp_spec = mlps[i]
89 | if use_xyz:
90 | mlp_spec[0] += 3
91 |
92 | self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
93 | self.pool_method = pool_method
94 |
95 |
96 | class PointnetSAModule(PointnetSAModuleMSG):
97 | """Pointnet set abstraction layer"""
98 |
99 | def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
100 | bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
101 | """
102 | :param mlp: list of int, spec of the pointnet before the global max_pool
103 | :param npoint: int, number of features
104 | :param radius: float, radius of ball
105 | :param nsample: int, number of samples in the ball query
106 | :param bn: whether to use batchnorm
107 | :param use_xyz:
108 | :param pool_method: max_pool / avg_pool
109 | :param instance_norm: whether to use instance_norm
110 | """
111 | super().__init__(
112 | mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz,
113 | pool_method=pool_method, instance_norm=instance_norm
114 | )
115 |
116 |
117 | class PointnetFPModule(nn.Module):
118 | r"""Propigates the features of one set to another"""
119 |
120 | def __init__(self, *, mlp: List[int], bn: bool = True):
121 | """
122 | :param mlp: list of int
123 | :param bn: whether to use batchnorm
124 | """
125 | super().__init__()
126 | self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
127 |
128 | def forward(
129 | self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
130 | ) -> torch.Tensor:
131 | """
132 | :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
133 | :param known: (B, m, 3) tensor of the xyz positions of the known features
134 | :param unknow_feats: (B, C1, n) tensor of the features to be propagated to
135 | :param known_feats: (B, C2, m) tensor of features to be propagated
136 | :return:
137 | new_features: (B, mlp[-1], n) tensor of the features of the unknown features
138 | """
139 | if known is not None:
140 | dist, idx = pointnet2_utils.three_nn(unknown, known)
141 | dist_recip = 1.0 / (dist + 1e-8)
142 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
143 | weight = dist_recip / norm
144 |
145 | interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
146 | else:
147 | interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
148 |
149 | if unknow_feats is not None:
150 | new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n)
151 | else:
152 | new_features = interpolated_feats
153 |
154 | new_features = new_features.unsqueeze(-1)
155 |
156 | new_features = self.mlp(new_features)
157 |
158 | return new_features.squeeze(-1)
159 |
160 |
161 | if __name__ == "__main__":
162 | pass
163 |
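A minimal usage sketch of the modules above (an illustrative example, not part of the file; it assumes the compiled `pointnet2_cuda` extension and a GPU are available):

```python
# Multi-scale set abstraction followed by feature propagation back to the input points.
import torch

sa = PointnetSAModuleMSG(
    npoint=128,                        # number of sampled centroids
    radii=[0.1, 0.2],                  # one ball radius per scale
    nsamples=[16, 32],                 # points gathered per ball
    mlps=[[0, 32, 64], [0, 32, 64]],   # per-scale MLPs; 3 xyz channels are added internally
    use_xyz=True,
).cuda()

xyz = torch.randn(2, 1024, 3).cuda()
new_xyz, new_feats = sa(xyz, None)     # (2, 128, 3), (2, 64 + 64, 128)

fp = PointnetFPModule(mlp=[128, 64]).cuda()
dense_feats = fp(xyz, new_xyz, None, new_feats)   # (2, 64, 1024)
```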
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/pointnet2_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from torch.autograd import Function
4 | import torch.nn as nn
5 | from typing import Tuple
6 | import sys
7 |
8 | import pointnet2_cuda as pointnet2
9 |
10 |
11 | class FurthestPointSampling(Function):
12 | @staticmethod
13 | def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
14 | """
15 | Uses iterative furthest point sampling to select a set of npoint features that have the largest
16 | minimum distance
17 | :param ctx:
18 | :param xyz: (B, N, 3) where N > npoint
19 | :param npoint: int, number of features in the sampled set
20 | :return:
21 | output: (B, npoint) tensor containing the set
22 | """
23 | assert xyz.is_contiguous()
24 |
25 | B, N, _ = xyz.size()
26 | output = torch.cuda.IntTensor(B, npoint)
27 | temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
28 |
29 | pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
30 | return output
31 |
32 | @staticmethod
33 | def backward(ctx, a=None):
34 | return None, None
35 |
36 |
37 | furthest_point_sample = FurthestPointSampling.apply
38 |
39 |
40 | class GatherOperation(Function):
41 |
42 | @staticmethod
43 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
44 | """
45 | :param ctx:
46 | :param features: (B, C, N)
47 | :param idx: (B, npoint) index tensor of the features to gather
48 | :return:
49 | output: (B, C, npoint)
50 | """
51 | assert features.is_contiguous()
52 | assert idx.is_contiguous()
53 |
54 | B, npoint = idx.size()
55 | _, C, N = features.size()
56 | output = torch.cuda.FloatTensor(B, C, npoint)
57 |
58 | pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
59 |
60 | ctx.for_backwards = (idx, C, N)
61 | return output
62 |
63 | @staticmethod
64 | def backward(ctx, grad_out):
65 | idx, C, N = ctx.for_backwards
66 | B, npoint = idx.size()
67 |
68 | grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
69 | grad_out_data = grad_out.data.contiguous()
70 | pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
71 | return grad_features, None
72 |
73 |
74 | gather_operation = GatherOperation.apply
75 |
76 |
77 | class ThreeNN(Function):
78 |
79 | @staticmethod
80 | def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
81 | """
82 | Find the three nearest neighbors of unknown in known
83 | :param ctx:
84 | :param unknown: (B, N, 3)
85 | :param known: (B, M, 3)
86 | :return:
87 | dist: (B, N, 3) l2 distance to the three nearest neighbors
88 | idx: (B, N, 3) index of 3 nearest neighbors
89 | """
90 | assert unknown.is_contiguous()
91 | assert known.is_contiguous()
92 |
93 | B, N, _ = unknown.size()
94 | m = known.size(1)
95 | dist2 = torch.cuda.FloatTensor(B, N, 3)
96 | idx = torch.cuda.IntTensor(B, N, 3)
97 |
98 | pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
99 | return torch.sqrt(dist2), idx
100 |
101 | @staticmethod
102 | def backward(ctx, a=None, b=None):
103 | return None, None
104 |
105 |
106 | three_nn = ThreeNN.apply
107 |
108 |
109 | class ThreeInterpolate(Function):
110 |
111 | @staticmethod
112 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
113 | """
114 | Performs weighted linear interpolation over the 3 nearest-neighbor features
115 | :param ctx:
116 | :param features: (B, C, M) Features descriptors to be interpolated from
117 | :param idx: (B, n, 3) three nearest neighbors of the target features in features
118 | :param weight: (B, n, 3) weights
119 | :return:
120 | output: (B, C, N) tensor of the interpolated features
121 | """
122 | assert features.is_contiguous()
123 | assert idx.is_contiguous()
124 | assert weight.is_contiguous()
125 |
126 | B, c, m = features.size()
127 | n = idx.size(1)
128 | ctx.three_interpolate_for_backward = (idx, weight, m)
129 | output = torch.cuda.FloatTensor(B, c, n)
130 |
131 | pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
132 | return output
133 |
134 | @staticmethod
135 | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
136 | """
137 | :param ctx:
138 | :param grad_out: (B, C, N) tensor with gradients of outputs
139 | :return:
140 | grad_features: (B, C, M) tensor with gradients of features
141 | None:
142 | None:
143 | """
144 | idx, weight, m = ctx.three_interpolate_for_backward
145 | B, c, n = grad_out.size()
146 |
147 | grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
148 | grad_out_data = grad_out.data.contiguous()
149 |
150 | pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
151 | return grad_features, None, None
152 |
153 |
154 | three_interpolate = ThreeInterpolate.apply
155 |
156 |
157 | class GroupingOperation(Function):
158 |
159 | @staticmethod
160 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
161 | """
162 | :param ctx:
163 | :param features: (B, C, N) tensor of features to group
164 | :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
165 | :return:
166 | output: (B, C, npoint, nsample) tensor
167 | """
168 | assert features.is_contiguous()
169 | assert idx.is_contiguous()
170 |
171 | B, nfeatures, nsample = idx.size()
172 | _, C, N = features.size()
173 | output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
174 |
175 | pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
176 |
177 | ctx.for_backwards = (idx, N)
178 | return output
179 |
180 | @staticmethod
181 | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
182 | """
183 | :param ctx:
184 | :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
185 | :return:
186 | grad_features: (B, C, N) gradient of the features
187 | """
188 | idx, N = ctx.for_backwards
189 |
190 | B, C, npoint, nsample = grad_out.size()
191 | grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
192 |
193 | grad_out_data = grad_out.data.contiguous()
194 | pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
195 | return grad_features, None
196 |
197 |
198 | grouping_operation = GroupingOperation.apply
199 |
200 |
201 | class BallQuery(Function):
202 |
203 | @staticmethod
204 | def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
205 | """
206 | :param ctx:
207 | :param radius: float, radius of the balls
208 | :param nsample: int, maximum number of features in the balls
209 | :param xyz: (B, N, 3) xyz coordinates of the features
210 | :param new_xyz: (B, npoint, 3) centers of the ball query
211 | :return:
212 | idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
213 | """
214 | assert new_xyz.is_contiguous()
215 | assert xyz.is_contiguous()
216 |
217 | B, N, _ = xyz.size()
218 | npoint = new_xyz.size(1)
219 | idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
220 |
221 | pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
222 | return idx
223 |
224 | @staticmethod
225 | def backward(ctx, a=None):
226 | return None, None, None, None
227 |
228 |
229 | ball_query = BallQuery.apply
230 |
231 |
232 | class QueryAndGroup(nn.Module):
233 | def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
234 | """
235 | :param radius: float, radius of ball
236 | :param nsample: int, maximum number of features to gather in the ball
237 | :param use_xyz:
238 | """
239 | super().__init__()
240 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
241 |
242 | def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
243 | """
244 | :param xyz: (B, N, 3) xyz coordinates of the features
245 | :param new_xyz: (B, npoint, 3) centroids
246 | :param features: (B, C, N) descriptors of the features
247 | :return:
248 | new_features: (B, 3 + C, npoint, nsample)
249 | """
250 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
251 | xyz_trans = xyz.transpose(1, 2).contiguous()
252 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
253 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
254 |
255 | if features is not None:
256 | grouped_features = grouping_operation(features, idx)
257 | if self.use_xyz:
258 | new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, C + 3, npoint, nsample)
259 | else:
260 | new_features = grouped_features
261 | else:
262 | assert self.use_xyz, "Cannot have features=None and use_xyz=False; there would be nothing to group!"
263 | new_features = grouped_xyz
264 |
265 | return new_features
266 |
267 |
268 | class GroupAll(nn.Module):
269 | def __init__(self, use_xyz: bool = True):
270 | super().__init__()
271 | self.use_xyz = use_xyz
272 |
273 | def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
274 | """
275 | :param xyz: (B, N, 3) xyz coordinates of the features
276 | :param new_xyz: ignored
277 | :param features: (B, C, N) descriptors of the features
278 | :return:
279 | new_features: (B, C + 3, 1, N)
280 | """
281 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
282 | if features is not None:
283 | grouped_features = features.unsqueeze(2)
284 | if self.use_xyz:
285 | new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, 3 + C, 1, N)
286 | else:
287 | new_features = grouped_features
288 | else:
289 | new_features = grouped_xyz
290 |
291 | return new_features
292 |
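The chain of low-level operations that the set-abstraction modules build on can also be exercised directly; a minimal sketch (assuming the compiled extension and a GPU are available):

```python
# Sample 256 centroids with FPS, gather their coordinates, then ball-query and group.
import torch

xyz = torch.randn(2, 1024, 3).cuda()                               # (B, N, 3)
idx = furthest_point_sample(xyz, 256)                              # (B, 256) int indices
new_xyz = gather_operation(xyz.transpose(1, 2).contiguous(), idx)  # (B, 3, 256)
new_xyz = new_xyz.transpose(1, 2).contiguous()                     # (B, 256, 3)

group = QueryAndGroup(radius=0.2, nsample=32, use_xyz=True)
grouped = group(xyz, new_xyz, features=None)                       # (B, 3, 256, 32)
```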
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/pytorch_utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from typing import List, Tuple
3 |
4 |
5 | class SharedMLP(nn.Sequential):
6 |
7 | def __init__(
8 | self,
9 | args: List[int],
10 | *,
11 | bn: bool = False,
12 | activation=nn.ReLU(inplace=True),
13 | preact: bool = False,
14 | first: bool = False,
15 | name: str = "",
16 | instance_norm: bool = False,
17 | ):
18 | super().__init__()
19 |
20 | for i in range(len(args) - 1):
21 | self.add_module(
22 | name + 'layer{}'.format(i),
23 | Conv2d(
24 | args[i],
25 | args[i + 1],
26 | bn=(not first or not preact or (i != 0)) and bn,
27 | activation=activation
28 | if (not first or not preact or (i != 0)) else None,
29 | preact=preact,
30 | instance_norm=instance_norm
31 | )
32 | )
33 |
34 |
35 | class _ConvBase(nn.Sequential):
36 |
37 | def __init__(
38 | self,
39 | in_size,
40 | out_size,
41 | kernel_size,
42 | stride,
43 | padding,
44 | activation,
45 | bn,
46 | init,
47 | conv=None,
48 | batch_norm=None,
49 | bias=True,
50 | preact=False,
51 | name="",
52 | instance_norm=False,
53 | instance_norm_func=None
54 | ):
55 | super().__init__()
56 |
57 | bias = bias and (not bn)
58 | conv_unit = conv(
59 | in_size,
60 | out_size,
61 | kernel_size=kernel_size,
62 | stride=stride,
63 | padding=padding,
64 | bias=bias
65 | )
66 | init(conv_unit.weight)
67 | if bias:
68 | nn.init.constant_(conv_unit.bias, 0)
69 |
70 | if bn:
71 | if not preact:
72 | bn_unit = batch_norm(out_size)
73 | else:
74 | bn_unit = batch_norm(in_size)
75 | if instance_norm:
76 | if not preact:
77 | in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
78 | else:
79 | in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)
80 |
81 | if preact:
82 | if bn:
83 | self.add_module(name + 'bn', bn_unit)
84 |
85 | if activation is not None:
86 | self.add_module(name + 'activation', activation)
87 |
88 | if not bn and instance_norm:
89 | self.add_module(name + 'in', in_unit)
90 |
91 | self.add_module(name + 'conv', conv_unit)
92 |
93 | if not preact:
94 | if bn:
95 | self.add_module(name + 'bn', bn_unit)
96 |
97 | if activation is not None:
98 | self.add_module(name + 'activation', activation)
99 |
100 | if not bn and instance_norm:
101 | self.add_module(name + 'in', in_unit)
102 |
103 |
104 | class _BNBase(nn.Sequential):
105 |
106 | def __init__(self, in_size, batch_norm=None, name=""):
107 | super().__init__()
108 | self.add_module(name + "bn", batch_norm(in_size))
109 |
110 | nn.init.constant_(self[0].weight, 1.0)
111 | nn.init.constant_(self[0].bias, 0)
112 |
113 |
114 | class BatchNorm1d(_BNBase):
115 |
116 | def __init__(self, in_size: int, *, name: str = ""):
117 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)
118 |
119 |
120 | class BatchNorm2d(_BNBase):
121 |
122 | def __init__(self, in_size: int, name: str = ""):
123 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
124 |
125 |
126 | class Conv1d(_ConvBase):
127 |
128 | def __init__(
129 | self,
130 | in_size: int,
131 | out_size: int,
132 | *,
133 | kernel_size: int = 1,
134 | stride: int = 1,
135 | padding: int = 0,
136 | activation=nn.ReLU(inplace=True),
137 | bn: bool = False,
138 | init=nn.init.kaiming_normal_,
139 | bias: bool = True,
140 | preact: bool = False,
141 | name: str = "",
142 | instance_norm=False
143 | ):
144 | super().__init__(
145 | in_size,
146 | out_size,
147 | kernel_size,
148 | stride,
149 | padding,
150 | activation,
151 | bn,
152 | init,
153 | conv=nn.Conv1d,
154 | batch_norm=BatchNorm1d,
155 | bias=bias,
156 | preact=preact,
157 | name=name,
158 | instance_norm=instance_norm,
159 | instance_norm_func=nn.InstanceNorm1d
160 | )
161 |
162 |
163 | class Conv2d(_ConvBase):
164 |
165 | def __init__(
166 | self,
167 | in_size: int,
168 | out_size: int,
169 | *,
170 | kernel_size: Tuple[int, int] = (1, 1),
171 | stride: Tuple[int, int] = (1, 1),
172 | padding: Tuple[int, int] = (0, 0),
173 | activation=nn.ReLU(inplace=True),
174 | bn: bool = False,
175 | init=nn.init.kaiming_normal_,
176 | bias: bool = True,
177 | preact: bool = False,
178 | name: str = "",
179 | instance_norm=False
180 | ):
181 | super().__init__(
182 | in_size,
183 | out_size,
184 | kernel_size,
185 | stride,
186 | padding,
187 | activation,
188 | bn,
189 | init,
190 | conv=nn.Conv2d,
191 | batch_norm=BatchNorm2d,
192 | bias=bias,
193 | preact=preact,
194 | name=name,
195 | instance_norm=instance_norm,
196 | instance_norm_func=nn.InstanceNorm2d
197 | )
198 |
199 |
200 | class FC(nn.Sequential):
201 |
202 | def __init__(
203 | self,
204 | in_size: int,
205 | out_size: int,
206 | *,
207 | activation=nn.ReLU(inplace=True),
208 | bn: bool = False,
209 | init=None,
210 | preact: bool = False,
211 | name: str = ""
212 | ):
213 | super().__init__()
214 |
215 | fc = nn.Linear(in_size, out_size, bias=not bn)
216 | if init is not None:
217 | init(fc.weight)
218 | if not bn:
219 | nn.init.constant_(fc.bias, 0)
220 |
221 | if preact:
222 | if bn:
223 | self.add_module(name + 'bn', BatchNorm1d(in_size))
224 |
225 | if activation is not None:
226 | self.add_module(name + 'activation', activation)
227 |
228 | self.add_module(name + 'fc', fc)
229 |
230 | if not preact:
231 | if bn:
232 | self.add_module(name + 'bn', BatchNorm1d(out_size))
233 |
234 | if activation is not None:
235 | self.add_module(name + 'activation', activation)
236 |
237 |
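A small sketch of how these building blocks are typically combined (illustrative only; it runs on CPU since no CUDA op is involved):

```python
# SharedMLP applies a stack of 1x1 Conv2d layers over (B, C, npoint, nsample) grouped
# features; Conv1d serves as a per-point head over (B, C, npoint).
import torch

mlp = SharedMLP([3, 32, 64], bn=True)
x = torch.randn(2, 3, 128, 16)
print(mlp(x).shape)                    # torch.Size([2, 64, 128, 16])

head = Conv1d(64, 1, activation=None)  # no activation on the output layer
y = torch.randn(2, 64, 128)
print(head(y).shape)                   # torch.Size([2, 1, 128])
```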
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 | setup(
5 | name='pointnet2',
6 | ext_modules=[
7 | CUDAExtension('pointnet2_cuda', [
8 | 'src/pointnet2_api.cpp',
9 |
10 | 'src/ball_query.cpp',
11 | 'src/ball_query_gpu.cu',
12 | 'src/group_points.cpp',
13 | 'src/group_points_gpu.cu',
14 | 'src/interpolate.cpp',
15 | 'src/interpolate_gpu.cu',
16 | 'src/sampling.cpp',
17 | 'src/sampling_gpu.cu',
18 | ],
19 | extra_compile_args={'cxx': ['-g'],
20 | 'nvcc': ['-O2']})
21 | ],
22 | cmdclass={'build_ext': BuildExtension}
23 | )
24 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/src/ball_query.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | // #include
4 | #include
5 | #include
6 | #include "ball_query_gpu.h"
7 | #include
8 | #include
9 |
10 | // extern THCState *state;
11 |
12 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
13 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
14 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
15 |
16 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
17 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
18 | CHECK_INPUT(new_xyz_tensor);
19 | CHECK_INPUT(xyz_tensor);
20 |     const float *new_xyz = new_xyz_tensor.data<float>();
21 |     const float *xyz = xyz_tensor.data<float>();
22 |     int *idx = idx_tensor.data<int>();
23 |
24 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
25 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream();
26 | ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream);
27 | return 1;
28 | }
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/src/ball_query_gpu.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | #include "ball_query_gpu.h"
6 | #include "cuda_utils.h"
7 |
8 |
9 | __global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample,
10 | const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
11 | // new_xyz: (B, M, 3)
12 | // xyz: (B, N, 3)
13 | // output:
14 | // idx: (B, M, nsample)
15 | int bs_idx = blockIdx.y;
16 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
17 | if (bs_idx >= b || pt_idx >= m) return;
18 |
19 | new_xyz += bs_idx * m * 3 + pt_idx * 3;
20 | xyz += bs_idx * n * 3;
21 | idx += bs_idx * m * nsample + pt_idx * nsample;
22 |
23 | float radius2 = radius * radius;
24 | float new_x = new_xyz[0];
25 | float new_y = new_xyz[1];
26 | float new_z = new_xyz[2];
27 |
28 | int cnt = 0;
29 | for (int k = 0; k < n; ++k) {
30 | float x = xyz[k * 3 + 0];
31 | float y = xyz[k * 3 + 1];
32 | float z = xyz[k * 3 + 2];
33 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
34 | if (d2 < radius2){
35 | if (cnt == 0){
36 | for (int l = 0; l < nsample; ++l) {
37 | idx[l] = k;
38 | }
39 | }
40 | idx[cnt] = k;
41 | ++cnt;
42 | if (cnt >= nsample) break;
43 | }
44 | }
45 | }
46 |
47 |
48 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \
49 | const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) {
50 | // new_xyz: (B, M, 3)
51 | // xyz: (B, N, 3)
52 | // output:
53 | // idx: (B, M, nsample)
54 |
55 | cudaError_t err;
56 |
57 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
58 | dim3 threads(THREADS_PER_BLOCK);
59 |
60 |     ball_query_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
61 | // cudaDeviceSynchronize(); // for using printf in kernel function
62 | err = cudaGetLastError();
63 | if (cudaSuccess != err) {
64 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
65 | exit(-1);
66 | }
67 | }
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/src/ball_query_gpu.h:
--------------------------------------------------------------------------------
1 | #ifndef _BALL_QUERY_GPU_H
2 | #define _BALL_QUERY_GPU_H
3 |
4 | #include
5 | #include
6 | #include
7 | #include
8 |
9 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
10 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);
11 |
12 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
13 | const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream);
14 |
15 | #endif
16 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/src/cuda_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef _CUDA_UTILS_H
2 | #define _CUDA_UTILS_H
3 |
4 | #include
5 |
6 | #define TOTAL_THREADS 1024
7 | #define THREADS_PER_BLOCK 256
8 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
9 |
10 | inline int opt_n_threads(int work_size) {
11 |     const int pow_2 = std::log(static_cast<float>(work_size)) / std::log(2.0);
12 |
13 | return max(min(1 << pow_2, TOTAL_THREADS), 1);
14 | }
15 | #endif
16 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/src/group_points.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | // #include
6 | #include "group_points_gpu.h"
7 | #include
8 | #include
9 | // extern THCState *state;
10 |
11 |
12 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
13 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
14 |
15 |     float *grad_points = grad_points_tensor.data<float>();
16 |     const int *idx = idx_tensor.data<int>();
17 |     const float *grad_out = grad_out_tensor.data<float>();
18 |
19 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream();
20 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
21 | group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream);
22 | return 1;
23 | }
24 |
25 |
26 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
27 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) {
28 |
29 |     const float *points = points_tensor.data<float>();
30 |     const int *idx = idx_tensor.data<int>();
31 |     float *out = out_tensor.data<float>();
32 |
33 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream();
34 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
35 | group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream);
36 | return 1;
37 | }
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/src/group_points_gpu.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include "cuda_utils.h"
5 | #include "group_points_gpu.h"
6 |
7 |
8 | __global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample,
9 | const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) {
10 | // grad_out: (B, C, npoints, nsample)
11 | // idx: (B, npoints, nsample)
12 | // output:
13 | // grad_points: (B, C, N)
14 | int bs_idx = blockIdx.z;
15 | int c_idx = blockIdx.y;
16 | int index = blockIdx.x * blockDim.x + threadIdx.x;
17 | int pt_idx = index / nsample;
18 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
19 |
20 | int sample_idx = index % nsample;
21 | grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
22 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
23 |
24 | atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]);
25 | }
26 |
27 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
28 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
29 | // grad_out: (B, C, npoints, nsample)
30 | // idx: (B, npoints, nsample)
31 | // output:
32 | // grad_points: (B, C, N)
33 | cudaError_t err;
34 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
35 | dim3 threads(THREADS_PER_BLOCK);
36 |
37 |     group_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, grad_out, idx, grad_points);
38 |
39 | err = cudaGetLastError();
40 | if (cudaSuccess != err) {
41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
42 | exit(-1);
43 | }
44 | }
45 |
46 |
47 | __global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample,
48 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
49 | // points: (B, C, N)
50 | // idx: (B, npoints, nsample)
51 | // output:
52 | // out: (B, C, npoints, nsample)
53 | int bs_idx = blockIdx.z;
54 | int c_idx = blockIdx.y;
55 | int index = blockIdx.x * blockDim.x + threadIdx.x;
56 | int pt_idx = index / nsample;
57 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
58 |
59 | int sample_idx = index % nsample;
60 |
61 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
62 | int in_idx = bs_idx * c * n + c_idx * n + idx[0];
63 | int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
64 |
65 | out[out_idx] = points[in_idx];
66 | }
67 |
68 |
69 | void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
70 | const float *points, const int *idx, float *out, cudaStream_t stream) {
71 | // points: (B, C, N)
72 | // idx: (B, npoints, nsample)
73 | // output:
74 | // out: (B, C, npoints, nsample)
75 | cudaError_t err;
76 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
77 | dim3 threads(THREADS_PER_BLOCK);
78 |
79 |     group_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, points, idx, out);
80 | // cudaDeviceSynchronize(); // for using printf in kernel function
81 | err = cudaGetLastError();
82 | if (cudaSuccess != err) {
83 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
84 | exit(-1);
85 | }
86 | }
87 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/src/group_points_gpu.h:
--------------------------------------------------------------------------------
1 | #ifndef _GROUP_POINTS_GPU_H
2 | #define _GROUP_POINTS_GPU_H
3 |
4 | #include
5 | #include
6 | #include
7 | #include
8 |
9 |
10 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
11 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
12 |
13 | void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
14 | const float *points, const int *idx, float *out, cudaStream_t stream);
15 |
16 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
18 |
19 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);
21 |
22 | #endif
23 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/src/interpolate.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | // #include
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include "interpolate_gpu.h"
12 |
13 | // extern THCState *state;
14 |
15 |
16 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
17 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
18 |     const float *unknown = unknown_tensor.data<float>();
19 |     const float *known = known_tensor.data<float>();
20 |     float *dist2 = dist2_tensor.data<float>();
21 |     int *idx = idx_tensor.data<int>();
22 |
23 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream();
24 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
25 | three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream);
26 | }
27 |
28 |
29 | void three_interpolate_wrapper_fast(int b, int c, int m, int n,
30 | at::Tensor points_tensor,
31 | at::Tensor idx_tensor,
32 | at::Tensor weight_tensor,
33 | at::Tensor out_tensor) {
34 |
35 |     const float *points = points_tensor.data<float>();
36 |     const float *weight = weight_tensor.data<float>();
37 |     float *out = out_tensor.data<float>();
38 |     const int *idx = idx_tensor.data<int>();
39 |
40 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream();
41 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
42 | three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream);
43 | }
44 |
45 | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
46 | at::Tensor grad_out_tensor,
47 | at::Tensor idx_tensor,
48 | at::Tensor weight_tensor,
49 | at::Tensor grad_points_tensor) {
50 |
51 |     const float *grad_out = grad_out_tensor.data<float>();
52 |     const float *weight = weight_tensor.data<float>();
53 |     float *grad_points = grad_points_tensor.data<float>();
54 |     const int *idx = idx_tensor.data<int>();
55 |
56 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream();
57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
58 | three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream);
59 | }
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/src/interpolate_gpu.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | #include "cuda_utils.h"
6 | #include "interpolate_gpu.h"
7 |
8 |
9 | __global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown,
10 | const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {
11 | // unknown: (B, N, 3)
12 | // known: (B, M, 3)
13 | // output:
14 | // dist2: (B, N, 3)
15 | // idx: (B, N, 3)
16 |
17 | int bs_idx = blockIdx.y;
18 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
19 | if (bs_idx >= b || pt_idx >= n) return;
20 |
21 | unknown += bs_idx * n * 3 + pt_idx * 3;
22 | known += bs_idx * m * 3;
23 | dist2 += bs_idx * n * 3 + pt_idx * 3;
24 | idx += bs_idx * n * 3 + pt_idx * 3;
25 |
26 | float ux = unknown[0];
27 | float uy = unknown[1];
28 | float uz = unknown[2];
29 |
30 | double best1 = 1e40, best2 = 1e40, best3 = 1e40;
31 | int besti1 = 0, besti2 = 0, besti3 = 0;
32 | for (int k = 0; k < m; ++k) {
33 | float x = known[k * 3 + 0];
34 | float y = known[k * 3 + 1];
35 | float z = known[k * 3 + 2];
36 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
37 | if (d < best1) {
38 | best3 = best2; besti3 = besti2;
39 | best2 = best1; besti2 = besti1;
40 | best1 = d; besti1 = k;
41 | }
42 | else if (d < best2) {
43 | best3 = best2; besti3 = besti2;
44 | best2 = d; besti2 = k;
45 | }
46 | else if (d < best3) {
47 | best3 = d; besti3 = k;
48 | }
49 | }
50 | dist2[0] = best1; dist2[1] = best2; dist2[2] = best3;
51 | idx[0] = besti1; idx[1] = besti2; idx[2] = besti3;
52 | }
53 |
54 |
55 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
56 | const float *known, float *dist2, int *idx, cudaStream_t stream) {
57 | // unknown: (B, N, 3)
58 | // known: (B, M, 3)
59 | // output:
60 | // dist2: (B, N, 3)
61 | // idx: (B, N, 3)
62 |
63 | cudaError_t err;
64 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
65 | dim3 threads(THREADS_PER_BLOCK);
66 |
67 |     three_nn_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, unknown, known, dist2, idx);
68 |
69 | err = cudaGetLastError();
70 | if (cudaSuccess != err) {
71 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
72 | exit(-1);
73 | }
74 | }
75 |
76 |
77 | __global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points,
78 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) {
79 | // points: (B, C, M)
80 | // idx: (B, N, 3)
81 | // weight: (B, N, 3)
82 | // output:
83 | // out: (B, C, N)
84 |
85 | int bs_idx = blockIdx.z;
86 | int c_idx = blockIdx.y;
87 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
88 |
89 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
90 |
91 | weight += bs_idx * n * 3 + pt_idx * 3;
92 | points += bs_idx * c * m + c_idx * m;
93 | idx += bs_idx * n * 3 + pt_idx * 3;
94 | out += bs_idx * c * n + c_idx * n;
95 |
96 | out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]];
97 | }
98 |
99 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
100 | const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) {
101 | // points: (B, C, M)
102 | // idx: (B, N, 3)
103 | // weight: (B, N, 3)
104 | // output:
105 | // out: (B, C, N)
106 |
107 | cudaError_t err;
108 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
109 | dim3 threads(THREADS_PER_BLOCK);
110 |     three_interpolate_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, m, n, points, idx, weight, out);
111 |
112 | err = cudaGetLastError();
113 | if (cudaSuccess != err) {
114 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
115 | exit(-1);
116 | }
117 | }
118 |
119 |
120 | __global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
121 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) {
122 | // grad_out: (B, C, N)
123 | // weight: (B, N, 3)
124 | // output:
125 | // grad_points: (B, C, M)
126 |
127 | int bs_idx = blockIdx.z;
128 | int c_idx = blockIdx.y;
129 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
130 |
131 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
132 |
133 | grad_out += bs_idx * c * n + c_idx * n + pt_idx;
134 | weight += bs_idx * n * 3 + pt_idx * 3;
135 | grad_points += bs_idx * c * m + c_idx * m;
136 | idx += bs_idx * n * 3 + pt_idx * 3;
137 |
138 |
139 | atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
140 | atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
141 | atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
142 | }
143 |
144 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
145 | const int *idx, const float *weight, float *grad_points, cudaStream_t stream) {
146 | // grad_out: (B, C, N)
147 | // weight: (B, N, 3)
148 | // output:
149 | // grad_points: (B, C, M)
150 |
151 | cudaError_t err;
152 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
153 | dim3 threads(THREADS_PER_BLOCK);
154 |     three_interpolate_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, m, grad_out, idx, weight, grad_points);
155 |
156 | err = cudaGetLastError();
157 | if (cudaSuccess != err) {
158 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
159 | exit(-1);
160 | }
161 | }
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/src/interpolate_gpu.h:
--------------------------------------------------------------------------------
1 | #ifndef _INTERPOLATE_GPU_H
2 | #define _INTERPOLATE_GPU_H
3 |
4 | #include
5 | #include
6 | #include
7 | #include
8 |
9 |
10 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
11 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);
12 |
13 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
14 | const float *known, float *dist2, int *idx, cudaStream_t stream);
15 |
16 |
17 | void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor,
18 | at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);
19 |
20 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
21 | const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream);
22 |
23 |
24 | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor,
25 | at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor);
26 |
27 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
28 | const int *idx, const float *weight, float *grad_points, cudaStream_t stream);
29 |
30 | #endif
31 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/src/pointnet2_api.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include "ball_query_gpu.h"
5 | #include "group_points_gpu.h"
6 | #include "sampling_gpu.h"
7 | #include "interpolate_gpu.h"
8 |
9 |
10 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
11 | m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast");
12 |
13 | m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast");
14 | m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast");
15 |
16 | m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
17 | m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");
18 |
19 | m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
20 |
21 | m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
22 | m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
23 | m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast");
24 | }
25 |
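These are the bindings that the autograd `Function`s in `pointnet2_utils.py` call into. A quick, hypothetical check that they are exposed after building the extension:

```python
import torch            # load libtorch symbols before importing the extension
import pointnet2_cuda

# Should list ball_query_wrapper, group_points_wrapper, furthest_point_sampling_wrapper, ...
print([name for name in dir(pointnet2_cuda) if not name.startswith('_')])
```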
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/src/sampling.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | // #include
5 |
6 | #include "sampling_gpu.h"
7 | #include
8 | #include
9 |
10 | // extern THCState *state;
11 |
12 |
13 | int gather_points_wrapper_fast(int b, int c, int n, int npoints,
14 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){
15 |     const float *points = points_tensor.data<float>();
16 |     const int *idx = idx_tensor.data<int>();
17 |     float *out = out_tensor.data<float>();
18 |
19 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream();
20 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
21 | gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream);
22 | return 1;
23 | }
24 |
25 |
26 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
27 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
28 |
29 |     const float *grad_out = grad_out_tensor.data<float>();
30 |     const int *idx = idx_tensor.data<int>();
31 |     float *grad_points = grad_points_tensor.data<float>();
32 |
33 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream();
34 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
35 | gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream);
36 | return 1;
37 | }
38 |
39 |
40 | int furthest_point_sampling_wrapper(int b, int n, int m,
41 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
42 |
43 |     const float *points = points_tensor.data<float>();
44 |     float *temp = temp_tensor.data<float>();
45 |     int *idx = idx_tensor.data<int>();
46 |
47 | // cudaStream_t stream = at::cuda::getCurrentCUDAStream();
48 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
49 | furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
50 | return 1;
51 | }
52 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/pointnet2/src/sampling_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 |
4 | #include "cuda_utils.h"
5 | #include "sampling_gpu.h"
6 |
7 |
8 | __global__ void gather_points_kernel_fast(int b, int c, int n, int m,
9 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
10 | // points: (B, C, N)
11 | // idx: (B, M)
12 | // output:
13 | // out: (B, C, M)
14 |
15 | int bs_idx = blockIdx.z;
16 | int c_idx = blockIdx.y;
17 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
18 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
19 |
20 | out += bs_idx * c * m + c_idx * m + pt_idx;
21 | idx += bs_idx * m + pt_idx;
22 | points += bs_idx * c * n + c_idx * n;
23 | out[0] = points[idx[0]];
24 | }
25 |
26 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
27 | const float *points, const int *idx, float *out, cudaStream_t stream) {
28 | // points: (B, C, N)
29 | // idx: (B, npoints)
30 | // output:
31 | // out: (B, C, npoints)
32 |
33 | cudaError_t err;
34 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
35 | dim3 threads(THREADS_PER_BLOCK);
36 |
37 | gather_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points, idx, out);
38 |
39 | err = cudaGetLastError();
40 | if (cudaSuccess != err) {
41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
42 | exit(-1);
43 | }
44 | }
45 |
46 | __global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
47 | const int *__restrict__ idx, float *__restrict__ grad_points) {
48 | // grad_out: (B, C, M)
49 | // idx: (B, M)
50 | // output:
51 | // grad_points: (B, C, N)
52 |
53 | int bs_idx = blockIdx.z;
54 | int c_idx = blockIdx.y;
55 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
56 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
57 |
58 | grad_out += bs_idx * c * m + c_idx * m + pt_idx;
59 | idx += bs_idx * m + pt_idx;
60 | grad_points += bs_idx * c * n + c_idx * n;
61 |
62 | atomicAdd(grad_points + idx[0], grad_out[0]);
63 | }
64 |
65 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
66 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
67 | // grad_out: (B, C, npoints)
68 | // idx: (B, npoints)
69 | // output:
70 | // grad_points: (B, C, N)
71 |
72 | cudaError_t err;
73 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
74 | dim3 threads(THREADS_PER_BLOCK);
75 |
76 | gather_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, grad_out, idx, grad_points);
77 |
78 | err = cudaGetLastError();
79 | if (cudaSuccess != err) {
80 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
81 | exit(-1);
82 | }
83 | }
84 |
85 |
86 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){
87 | const float v1 = dists[idx1], v2 = dists[idx2];
88 | const int i1 = dists_i[idx1], i2 = dists_i[idx2];
89 | dists[idx1] = max(v1, v2);
90 | dists_i[idx1] = v2 > v1 ? i2 : i1;
91 | }
92 |
93 | template <unsigned int block_size>
94 | __global__ void furthest_point_sampling_kernel(int b, int n, int m,
95 | const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
96 | // dataset: (B, N, 3)
97 | // tmp: (B, N)
98 | // output:
99 | // idx: (B, M)
100 |
101 | if (m <= 0) return;
102 | __shared__ float dists[block_size];
103 | __shared__ int dists_i[block_size];
104 |
105 | int batch_index = blockIdx.x;
106 | dataset += batch_index * n * 3;
107 | temp += batch_index * n;
108 | idxs += batch_index * m;
109 |
110 | int tid = threadIdx.x;
111 | const int stride = block_size;
112 |
113 | int old = 0;
114 | if (threadIdx.x == 0)
115 | idxs[0] = old;
116 |
117 | __syncthreads();
118 | for (int j = 1; j < m; j++) {
119 | int besti = 0;
120 | float best = -1;
121 | float x1 = dataset[old * 3 + 0];
122 | float y1 = dataset[old * 3 + 1];
123 | float z1 = dataset[old * 3 + 2];
124 | for (int k = tid; k < n; k += stride) {
125 | float x2, y2, z2;
126 | x2 = dataset[k * 3 + 0];
127 | y2 = dataset[k * 3 + 1];
128 | z2 = dataset[k * 3 + 2];
129 | // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
130 | // if (mag <= 1e-3)
131 | // continue;
132 |
133 | float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
134 | float d2 = min(d, temp[k]);
135 | temp[k] = d2;
136 | besti = d2 > best ? k : besti;
137 | best = d2 > best ? d2 : best;
138 | }
139 | dists[tid] = best;
140 | dists_i[tid] = besti;
141 | __syncthreads();
142 |
143 | if (block_size >= 1024) {
144 | if (tid < 512) {
145 | __update(dists, dists_i, tid, tid + 512);
146 | }
147 | __syncthreads();
148 | }
149 |
150 | if (block_size >= 512) {
151 | if (tid < 256) {
152 | __update(dists, dists_i, tid, tid + 256);
153 | }
154 | __syncthreads();
155 | }
156 | if (block_size >= 256) {
157 | if (tid < 128) {
158 | __update(dists, dists_i, tid, tid + 128);
159 | }
160 | __syncthreads();
161 | }
162 | if (block_size >= 128) {
163 | if (tid < 64) {
164 | __update(dists, dists_i, tid, tid + 64);
165 | }
166 | __syncthreads();
167 | }
168 | if (block_size >= 64) {
169 | if (tid < 32) {
170 | __update(dists, dists_i, tid, tid + 32);
171 | }
172 | __syncthreads();
173 | }
174 | if (block_size >= 32) {
175 | if (tid < 16) {
176 | __update(dists, dists_i, tid, tid + 16);
177 | }
178 | __syncthreads();
179 | }
180 | if (block_size >= 16) {
181 | if (tid < 8) {
182 | __update(dists, dists_i, tid, tid + 8);
183 | }
184 | __syncthreads();
185 | }
186 | if (block_size >= 8) {
187 | if (tid < 4) {
188 | __update(dists, dists_i, tid, tid + 4);
189 | }
190 | __syncthreads();
191 | }
192 | if (block_size >= 4) {
193 | if (tid < 2) {
194 | __update(dists, dists_i, tid, tid + 2);
195 | }
196 | __syncthreads();
197 | }
198 | if (block_size >= 2) {
199 | if (tid < 1) {
200 | __update(dists, dists_i, tid, tid + 1);
201 | }
202 | __syncthreads();
203 | }
204 |
205 | old = dists_i[0];
206 | if (tid == 0)
207 | idxs[j] = old;
208 | }
209 | }
210 |
211 | void furthest_point_sampling_kernel_launcher(int b, int n, int m,
212 | const float *dataset, float *temp, int *idxs, cudaStream_t stream) {
213 | // dataset: (B, N, 3)
214 | // tmp: (B, N)
215 | // output:
216 | // idx: (B, M)
217 |
218 | cudaError_t err;
219 | unsigned int n_threads = opt_n_threads(n);
220 |
221 | switch (n_threads) {
222 | case 1024:
223 | furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
224 | case 512:
225 | furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
226 | case 256:
227 | furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
228 | case 128:
229 | furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
230 | case 64:
231 | furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
232 | case 32:
233 | furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
234 | case 16:
235 | furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
236 | case 8:
237 | furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
238 | case 4:
239 | furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
240 | case 2:
241 | furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
242 | case 1:
243 | furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
244 | default:
245 | furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
246 | }
247 |
248 | err = cudaGetLastError();
249 | if (cudaSuccess != err) {
250 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
251 | exit(-1);
252 | }
253 | }
254 |
--------------------------------------------------------------------------------
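The templated kernel above performs greedy farthest point sampling: temp holds, for every point, its squared distance to the already-selected set, and each iteration picks the point with the largest such distance via a shared-memory reduction. A plain NumPy reference of the same greedy loop (single batch, no parallel reduction) can be useful for sanity-checking the kernel's output; the function name and usage are illustrative, not part of the repo.

    import numpy as np

    def farthest_point_sampling_ref(xyz, m):
        """Greedy FPS: xyz (N, 3) float32 -> m selected indices, mirroring the CUDA kernel's logic."""
        n = xyz.shape[0]
        idxs = np.zeros(m, dtype=np.int64)
        temp = np.full(n, 1e10, dtype=np.float32)   # min squared distance to the selected set
        old = 0                                      # the kernel also seeds with point index 0
        for j in range(1, m):
            d = np.sum((xyz - xyz[old]) ** 2, axis=1)
            temp = np.minimum(temp, d)               # update distance-to-set with the last selected point
            old = int(np.argmax(temp))               # farthest remaining point
            idxs[j] = old
        return idxs

    # e.g. farthest_point_sampling_ref(np.random.rand(1024, 3).astype(np.float32), 64)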
/networks/pts_encoder/pointnet2_utils/pointnet2/src/sampling_gpu.h:
--------------------------------------------------------------------------------
1 | #ifndef _SAMPLING_GPU_H
2 | #define _SAMPLING_GPU_H
3 |
4 | #include <torch/serialize/tensor.h>
5 | #include <ATen/cuda/CUDAContext.h>
6 | #include <vector>
7 |
8 |
9 | int gather_points_wrapper_fast(int b, int c, int n, int npoints,
10 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
11 |
12 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
13 | const float *points, const int *idx, float *out, cudaStream_t stream);
14 |
15 |
16 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
18 |
19 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);
21 |
22 |
23 | int furthest_point_sampling_wrapper(int b, int n, int m,
24 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
25 |
26 | void furthest_point_sampling_kernel_launcher(int b, int n, int m,
27 | const float *dataset, float *temp, int *idxs, cudaStream_t stream);
28 |
29 | #endif
30 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/tools/_init_path.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../'))
3 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/tools/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch.utils.data as torch_data
4 | import kitti_utils
5 | import cv2
6 | from PIL import Image
7 |
8 |
9 | USE_INTENSITY = False
10 |
11 |
12 | class KittiDataset(torch_data.Dataset):
13 | def __init__(self, root_dir, split='train', mode='TRAIN'):
14 | self.split = split
15 | self.mode = mode
16 | self.classes = ['Car']
17 | is_test = self.split == 'test'
18 | self.imageset_dir = os.path.join(root_dir, 'KITTI', 'object', 'testing' if is_test else 'training')
19 |
20 | split_dir = os.path.join(root_dir, 'KITTI', 'ImageSets', split + '.txt')
21 | self.image_idx_list = [x.strip() for x in open(split_dir).readlines()]
22 | self.sample_id_list = [int(sample_id) for sample_id in self.image_idx_list]
23 | self.num_sample = self.image_idx_list.__len__()
24 |
25 | self.npoints = 16384
26 |
27 | self.image_dir = os.path.join(self.imageset_dir, 'image_2')
28 | self.lidar_dir = os.path.join(self.imageset_dir, 'velodyne')
29 | self.calib_dir = os.path.join(self.imageset_dir, 'calib')
30 | self.label_dir = os.path.join(self.imageset_dir, 'label_2')
31 | self.plane_dir = os.path.join(self.imageset_dir, 'planes')
32 |
33 | def get_image(self, idx):
34 | img_file = os.path.join(self.image_dir, '%06d.png' % idx)
35 | assert os.path.exists(img_file)
36 | return cv2.imread(img_file) # (H, W, 3) BGR mode
37 |
38 | def get_image_shape(self, idx):
39 | img_file = os.path.join(self.image_dir, '%06d.png' % idx)
40 | assert os.path.exists(img_file)
41 | im = Image.open(img_file)
42 | width, height = im.size
43 | return height, width, 3
44 |
45 | def get_lidar(self, idx):
46 | lidar_file = os.path.join(self.lidar_dir, '%06d.bin' % idx)
47 | assert os.path.exists(lidar_file)
48 | return np.fromfile(lidar_file, dtype=np.float32).reshape(-1, 4)
49 |
50 | def get_calib(self, idx):
51 | calib_file = os.path.join(self.calib_dir, '%06d.txt' % idx)
52 | assert os.path.exists(calib_file)
53 | return kitti_utils.Calibration(calib_file)
54 |
55 | def get_label(self, idx):
56 | label_file = os.path.join(self.label_dir, '%06d.txt' % idx)
57 | assert os.path.exists(label_file)
58 | return kitti_utils.get_objects_from_label(label_file)
59 |
60 | @staticmethod
61 | def get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape):
62 | val_flag_1 = np.logical_and(pts_img[:, 0] >= 0, pts_img[:, 0] < img_shape[1])
63 | val_flag_2 = np.logical_and(pts_img[:, 1] >= 0, pts_img[:, 1] < img_shape[0])
64 | val_flag_merge = np.logical_and(val_flag_1, val_flag_2)
65 | pts_valid_flag = np.logical_and(val_flag_merge, pts_rect_depth >= 0)
66 | return pts_valid_flag
67 |
68 | def filtrate_objects(self, obj_list):
69 | type_whitelist = self.classes
70 | if self.mode == 'TRAIN':
71 | type_whitelist = list(self.classes)
72 | if 'Car' in self.classes:
73 | type_whitelist.append('Van')
74 |
75 | valid_obj_list = []
76 | for obj in obj_list:
77 | if obj.cls_type not in type_whitelist:
78 | continue
79 |
80 | valid_obj_list.append(obj)
81 | return valid_obj_list
82 |
83 | def __len__(self):
84 | return len(self.sample_id_list)
85 |
86 | def __getitem__(self, index):
87 | sample_id = int(self.sample_id_list[index])
88 | calib = self.get_calib(sample_id)
89 | img_shape = self.get_image_shape(sample_id)
90 | pts_lidar = self.get_lidar(sample_id)
91 |
92 | # get valid point (projected points should be in image)
93 | pts_rect = calib.lidar_to_rect(pts_lidar[:, 0:3])
94 | pts_intensity = pts_lidar[:, 3]
95 |
96 | pts_img, pts_rect_depth = calib.rect_to_img(pts_rect)
97 | pts_valid_flag = self.get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape)
98 |
99 | pts_rect = pts_rect[pts_valid_flag][:, 0:3]
100 | pts_intensity = pts_intensity[pts_valid_flag]
101 |
102 | if self.npoints < len(pts_rect):
103 | pts_depth = pts_rect[:, 2]
104 | pts_near_flag = pts_depth < 40.0
105 | far_idxs_choice = np.where(pts_near_flag == 0)[0]
106 | near_idxs = np.where(pts_near_flag == 1)[0]
107 | near_idxs_choice = np.random.choice(near_idxs, self.npoints - len(far_idxs_choice), replace=False)
108 |
109 | choice = np.concatenate((near_idxs_choice, far_idxs_choice), axis=0) \
110 | if len(far_idxs_choice) > 0 else near_idxs_choice
111 | np.random.shuffle(choice)
112 | else:
113 | choice = np.arange(0, len(pts_rect), dtype=np.int32)
114 | if self.npoints > len(pts_rect):
115 | extra_choice = np.random.choice(choice, self.npoints - len(pts_rect), replace=False)
116 | choice = np.concatenate((choice, extra_choice), axis=0)
117 | np.random.shuffle(choice)
118 |
119 | ret_pts_rect = pts_rect[choice, :]
120 | ret_pts_intensity = pts_intensity[choice] - 0.5 # translate intensity to [-0.5, 0.5]
121 |
122 | pts_features = [ret_pts_intensity.reshape(-1, 1)]
123 | ret_pts_features = np.concatenate(pts_features, axis=1) if pts_features.__len__() > 1 else pts_features[0]
124 |
125 | sample_info = {'sample_id': sample_id}
126 |
127 | if self.mode == 'TEST':
128 | if USE_INTENSITY:
129 | pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1) # (N, C)
130 | else:
131 | pts_input = ret_pts_rect
132 | sample_info['pts_input'] = pts_input
133 | sample_info['pts_rect'] = ret_pts_rect
134 | sample_info['pts_features'] = ret_pts_features
135 | return sample_info
136 |
137 | gt_obj_list = self.filtrate_objects(self.get_label(sample_id))
138 |
139 | gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
140 |
141 | # prepare input
142 | if USE_INTENSITY:
143 | pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1) # (N, C)
144 | else:
145 | pts_input = ret_pts_rect
146 |
147 | # generate training labels
148 | cls_labels = self.generate_training_labels(ret_pts_rect, gt_boxes3d)
149 | sample_info['pts_input'] = pts_input
150 | sample_info['pts_rect'] = ret_pts_rect
151 | sample_info['cls_labels'] = cls_labels
152 | return sample_info
153 |
154 | @staticmethod
155 | def generate_training_labels(pts_rect, gt_boxes3d):
156 | cls_label = np.zeros((pts_rect.shape[0]), dtype=np.int32)
157 | gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, rotate=True)
158 | extend_gt_boxes3d = kitti_utils.enlarge_box3d(gt_boxes3d, extra_width=0.2)
159 | extend_gt_corners = kitti_utils.boxes3d_to_corners3d(extend_gt_boxes3d, rotate=True)
160 | for k in range(gt_boxes3d.shape[0]):
161 | box_corners = gt_corners[k]
162 | fg_pt_flag = kitti_utils.in_hull(pts_rect, box_corners)
163 | cls_label[fg_pt_flag] = 1
164 |
165 | # enlarge the bbox3d, ignore nearby points
166 | extend_box_corners = extend_gt_corners[k]
167 | fg_enlarge_flag = kitti_utils.in_hull(pts_rect, extend_box_corners)
168 | ignore_flag = np.logical_xor(fg_pt_flag, fg_enlarge_flag)
169 | cls_label[ignore_flag] = -1
170 |
171 | return cls_label
172 |
173 | def collate_batch(self, batch):
174 | batch_size = batch.__len__()
175 | ans_dict = {}
176 |
177 | for key in batch[0].keys():
178 | if isinstance(batch[0][key], np.ndarray):
179 | ans_dict[key] = np.concatenate([batch[k][key][np.newaxis, ...] for k in range(batch_size)], axis=0)
180 |
181 | else:
182 | ans_dict[key] = [batch[k][key] for k in range(batch_size)]
183 | if isinstance(batch[0][key], int):
184 | ans_dict[key] = np.array(ans_dict[key], dtype=np.int32)
185 | elif isinstance(batch[0][key], float):
186 | ans_dict[key] = np.array(ans_dict[key], dtype=np.float32)
187 |
188 | return ans_dict
189 |
--------------------------------------------------------------------------------
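In __getitem__ above, scenes with more than self.npoints points are reduced by keeping every point beyond 40 m and filling the remainder with a random draw from the nearby points; scenes with fewer points are padded by re-sampling existing indices. A small standalone sketch of the first branch (function name and threshold argument are illustrative, and it assumes the far points alone do not exceed npoints, as in the original code):

    import numpy as np

    def subsample_near_far(pts_rect, npoints, near_thresh=40.0):
        """Keep all far points (depth >= near_thresh) and randomly fill up to npoints from near points."""
        depth = pts_rect[:, 2]
        near_idxs = np.where(depth < near_thresh)[0]
        far_idxs = np.where(depth >= near_thresh)[0]
        near_choice = np.random.choice(near_idxs, npoints - len(far_idxs), replace=False)
        choice = np.concatenate((near_choice, far_idxs)) if len(far_idxs) > 0 else near_choice
        np.random.shuffle(choice)
        return pts_rect[choice, :]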
/networks/pts_encoder/pointnet2_utils/tools/kitti_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial import Delaunay
3 | import scipy
4 |
5 |
6 | def cls_type_to_id(cls_type):
7 | type_to_id = {'Car': 1, 'Pedestrian': 2, 'Cyclist': 3, 'Van': 4}
8 | if cls_type not in type_to_id.keys():
9 | return -1
10 | return type_to_id[cls_type]
11 |
12 |
13 | class Object3d(object):
14 | def __init__(self, line):
15 | label = line.strip().split(' ')
16 | self.src = line
17 | self.cls_type = label[0]
18 | self.cls_id = cls_type_to_id(self.cls_type)
19 | self.trucation = float(label[1])
20 | self.occlusion = float(label[2]) # 0:fully visible 1:partly occluded 2:largely occluded 3:unknown
21 | self.alpha = float(label[3])
22 | self.box2d = np.array((float(label[4]), float(label[5]), float(label[6]), float(label[7])), dtype=np.float32)
23 | self.h = float(label[8])
24 | self.w = float(label[9])
25 | self.l = float(label[10])
26 | self.pos = np.array((float(label[11]), float(label[12]), float(label[13])), dtype=np.float32)
27 | self.dis_to_cam = np.linalg.norm(self.pos)
28 | self.ry = float(label[14])
29 | self.score = float(label[15]) if label.__len__() == 16 else -1.0
30 | self.level_str = None
31 | self.level = self.get_obj_level()
32 |
33 | def get_obj_level(self):
34 | height = float(self.box2d[3]) - float(self.box2d[1]) + 1
35 |
36 | if height >= 40 and self.trucation <= 0.15 and self.occlusion <= 0:
37 | self.level_str = 'Easy'
38 | return 1 # Easy
39 | elif height >= 25 and self.trucation <= 0.3 and self.occlusion <= 1:
40 | self.level_str = 'Moderate'
41 | return 2 # Moderate
42 | elif height >= 25 and self.trucation <= 0.5 and self.occlusion <= 2:
43 | self.level_str = 'Hard'
44 | return 3 # Hard
45 | else:
46 | self.level_str = 'UnKnown'
47 | return 4
48 |
49 | def generate_corners3d(self):
50 | """
51 | generate corners3d representation for this object
52 | :return corners_3d: (8, 3) corners of box3d in camera coord
53 | """
54 | l, h, w = self.l, self.h, self.w
55 | x_corners = [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2]
56 | y_corners = [0, 0, 0, 0, -h, -h, -h, -h]
57 | z_corners = [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2]
58 |
59 | R = np.array([[np.cos(self.ry), 0, np.sin(self.ry)],
60 | [0, 1, 0],
61 | [-np.sin(self.ry), 0, np.cos(self.ry)]])
62 | corners3d = np.vstack([x_corners, y_corners, z_corners]) # (3, 8)
63 | corners3d = np.dot(R, corners3d).T
64 | corners3d = corners3d + self.pos
65 | return corners3d
66 |
67 | def to_str(self):
68 | print_str = '%s %.3f %.3f %.3f box2d: %s hwl: [%.3f %.3f %.3f] pos: %s ry: %.3f' \
69 | % (self.cls_type, self.trucation, self.occlusion, self.alpha, self.box2d, self.h, self.w, self.l,
70 | self.pos, self.ry)
71 | return print_str
72 |
73 | def to_kitti_format(self):
74 | kitti_str = '%s %.2f %d %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f' \
75 | % (self.cls_type, self.trucation, int(self.occlusion), self.alpha, self.box2d[0], self.box2d[1],
76 | self.box2d[2], self.box2d[3], self.h, self.w, self.l, self.pos[0], self.pos[1], self.pos[2],
77 | self.ry)
78 | return kitti_str
79 |
80 |
81 | def get_calib_from_file(calib_file):
82 | with open(calib_file) as f:
83 | lines = f.readlines()
84 |
85 | obj = lines[2].strip().split(' ')[1:]
86 | P2 = np.array(obj, dtype=np.float32)
87 | obj = lines[3].strip().split(' ')[1:]
88 | P3 = np.array(obj, dtype=np.float32)
89 | obj = lines[4].strip().split(' ')[1:]
90 | R0 = np.array(obj, dtype=np.float32)
91 | obj = lines[5].strip().split(' ')[1:]
92 | Tr_velo_to_cam = np.array(obj, dtype=np.float32)
93 |
94 | return {'P2': P2.reshape(3, 4),
95 | 'P3': P3.reshape(3, 4),
96 | 'R0': R0.reshape(3, 3),
97 | 'Tr_velo2cam': Tr_velo_to_cam.reshape(3, 4)}
98 |
99 |
100 | class Calibration(object):
101 | def __init__(self, calib_file):
102 | if isinstance(calib_file, str):
103 | calib = get_calib_from_file(calib_file)
104 | else:
105 | calib = calib_file
106 |
107 | self.P2 = calib['P2'] # 3 x 4
108 | self.R0 = calib['R0'] # 3 x 3
109 | self.V2C = calib['Tr_velo2cam'] # 3 x 4
110 |
111 | def cart_to_hom(self, pts):
112 | """
113 | :param pts: (N, 3 or 2)
114 | :return pts_hom: (N, 4 or 3)
115 | """
116 | pts_hom = np.hstack((pts, np.ones((pts.shape[0], 1), dtype=np.float32)))
117 | return pts_hom
118 |
119 | def lidar_to_rect(self, pts_lidar):
120 | """
121 | :param pts_lidar: (N, 3)
122 | :return pts_rect: (N, 3)
123 | """
124 | pts_lidar_hom = self.cart_to_hom(pts_lidar)
125 | pts_rect = np.dot(pts_lidar_hom, np.dot(self.V2C.T, self.R0.T))
126 | return pts_rect
127 |
128 | def rect_to_img(self, pts_rect):
129 | """
130 | :param pts_rect: (N, 3)
131 | :return pts_img: (N, 2)
132 | """
133 | pts_rect_hom = self.cart_to_hom(pts_rect)
134 | pts_2d_hom = np.dot(pts_rect_hom, self.P2.T)
135 | pts_img = (pts_2d_hom[:, 0:2].T / pts_rect_hom[:, 2]).T # (N, 2)
136 | pts_rect_depth = pts_2d_hom[:, 2] - self.P2.T[3, 2] # depth in rect camera coord
137 | return pts_img, pts_rect_depth
138 |
139 | def lidar_to_img(self, pts_lidar):
140 | """
141 | :param pts_lidar: (N, 3)
142 | :return pts_img: (N, 2)
143 | """
144 | pts_rect = self.lidar_to_rect(pts_lidar)
145 | pts_img, pts_depth = self.rect_to_img(pts_rect)
146 | return pts_img, pts_depth
147 |
148 |
149 | def get_objects_from_label(label_file):
150 | with open(label_file, 'r') as f:
151 | lines = f.readlines()
152 | objects = [Object3d(line) for line in lines]
153 | return objects
154 |
155 |
156 | def objs_to_boxes3d(obj_list):
157 | boxes3d = np.zeros((obj_list.__len__(), 7), dtype=np.float32)
158 | for k, obj in enumerate(obj_list):
159 | boxes3d[k, 0:3], boxes3d[k, 3], boxes3d[k, 4], boxes3d[k, 5], boxes3d[k, 6] \
160 | = obj.pos, obj.h, obj.w, obj.l, obj.ry
161 | return boxes3d
162 |
163 |
164 | def boxes3d_to_corners3d(boxes3d, rotate=True):
165 | """
166 | :param boxes3d: (N, 7) [x, y, z, h, w, l, ry]
167 | :param rotate:
168 | :return: corners3d: (N, 8, 3)
169 | """
170 | boxes_num = boxes3d.shape[0]
171 | h, w, l = boxes3d[:, 3], boxes3d[:, 4], boxes3d[:, 5]
172 | x_corners = np.array([l / 2., l / 2., -l / 2., -l / 2., l / 2., l / 2., -l / 2., -l / 2.], dtype=np.float32).T # (N, 8)
173 | z_corners = np.array([w / 2., -w / 2., -w / 2., w / 2., w / 2., -w / 2., -w / 2., w / 2.], dtype=np.float32).T # (N, 8)
174 |
175 | y_corners = np.zeros((boxes_num, 8), dtype=np.float32)
176 | y_corners[:, 4:8] = -h.reshape(boxes_num, 1).repeat(4, axis=1) # (N, 8)
177 |
178 | if rotate:
179 | ry = boxes3d[:, 6]
180 | zeros, ones = np.zeros(ry.size, dtype=np.float32), np.ones(ry.size, dtype=np.float32)
181 | rot_list = np.array([[np.cos(ry), zeros, -np.sin(ry)],
182 | [zeros, ones, zeros],
183 | [np.sin(ry), zeros, np.cos(ry)]]) # (3, 3, N)
184 | R_list = np.transpose(rot_list, (2, 0, 1)) # (N, 3, 3)
185 |
186 | temp_corners = np.concatenate((x_corners.reshape(-1, 8, 1), y_corners.reshape(-1, 8, 1),
187 | z_corners.reshape(-1, 8, 1)), axis=2) # (N, 8, 3)
188 | rotated_corners = np.matmul(temp_corners, R_list) # (N, 8, 3)
189 | x_corners, y_corners, z_corners = rotated_corners[:, :, 0], rotated_corners[:, :, 1], rotated_corners[:, :, 2]
190 |
191 | x_loc, y_loc, z_loc = boxes3d[:, 0], boxes3d[:, 1], boxes3d[:, 2]
192 |
193 | x = x_loc.reshape(-1, 1) + x_corners.reshape(-1, 8)
194 | y = y_loc.reshape(-1, 1) + y_corners.reshape(-1, 8)
195 | z = z_loc.reshape(-1, 1) + z_corners.reshape(-1, 8)
196 |
197 | corners = np.concatenate((x.reshape(-1, 8, 1), y.reshape(-1, 8, 1), z.reshape(-1, 8, 1)), axis=2)
198 |
199 | return corners.astype(np.float32)
200 |
201 |
202 | def enlarge_box3d(boxes3d, extra_width):
203 | """
204 | :param boxes3d: (N, 7) [x, y, z, h, w, l, ry]
205 | """
206 | if isinstance(boxes3d, np.ndarray):
207 | large_boxes3d = boxes3d.copy()
208 | else:
209 | large_boxes3d = boxes3d.clone()
210 | large_boxes3d[:, 3:6] += extra_width * 2
211 | large_boxes3d[:, 1] += extra_width
212 | return large_boxes3d
213 |
214 |
215 | def in_hull(p, hull):
216 | """
217 | :param p: (N, K) test points
218 | :param hull: (M, K) M corners of a box
219 | :return (N) bool
220 | """
221 | try:
222 | if not isinstance(hull, Delaunay):
223 | hull = Delaunay(hull)
224 | flag = hull.find_simplex(p) >= 0
225 | except scipy.spatial.qhull.QhullError:
226 | print('Warning: not a hull %s' % str(hull))
227 | flag = np.zeros(p.shape[0], dtype=bool)
228 |
229 | return flag
230 |
--------------------------------------------------------------------------------
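Calibration.lidar_to_img above first maps LiDAR points into the rectified camera frame (Tr_velo2cam followed by R0) and then projects them with P2. A short self-contained usage sketch follows; the calibration values are synthetic placeholders, not real KITTI calibration, and it assumes the snippet is run from the tools directory so kitti_utils is importable.

    import numpy as np
    import kitti_utils

    calib = kitti_utils.Calibration({
        'P2': np.array([[700., 0., 600., 0.],
                        [0., 700., 180., 0.],
                        [0., 0., 1., 0.]], dtype=np.float32),
        'R0': np.eye(3, dtype=np.float32),
        'Tr_velo2cam': np.hstack([np.eye(3, dtype=np.float32),
                                  np.zeros((3, 1), dtype=np.float32)]),
    })
    pts_lidar = np.random.rand(100, 3).astype(np.float32) * 20.0 + 1.0   # keep depths positive
    pts_img, pts_depth = calib.lidar_to_img(pts_lidar)                   # (100, 2) pixel coords, (100,) depths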
/networks/pts_encoder/pointnet2_utils/tools/pointnet2_msg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import sys
4 | sys.path.append('..')
5 | from pointnet2.pointnet2_modules import PointnetFPModule, PointnetSAModuleMSG
6 | import pointnet2.pytorch_utils as pt_utils
7 |
8 |
9 | def get_model(input_channels=0):
10 | return Pointnet2MSG(input_channels=input_channels)
11 |
12 |
13 | NPOINTS = [4096, 1024, 256, 64]
14 | RADIUS = [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]]
15 | NSAMPLE = [[16, 32], [16, 32], [16, 32], [16, 32]]
16 | MLPS = [[[16, 16, 32], [32, 32, 64]], [[64, 64, 128], [64, 96, 128]],
17 | [[128, 196, 256], [128, 196, 256]], [[256, 256, 512], [256, 384, 512]]]
18 | FP_MLPS = [[128, 128], [256, 256], [512, 512], [512, 512]]
19 | CLS_FC = [128]
20 | DP_RATIO = 0.5
21 |
22 |
23 | class Pointnet2MSG(nn.Module):
24 | def __init__(self, input_channels=6):
25 | super().__init__()
26 |
27 | self.SA_modules = nn.ModuleList()
28 | channel_in = input_channels
29 |
30 | skip_channel_list = [input_channels]
31 | for k in range(NPOINTS.__len__()):
32 | mlps = MLPS[k].copy()
33 | channel_out = 0
34 | for idx in range(mlps.__len__()):
35 | mlps[idx] = [channel_in] + mlps[idx]
36 | channel_out += mlps[idx][-1]
37 |
38 | self.SA_modules.append(
39 | PointnetSAModuleMSG(
40 | npoint=NPOINTS[k],
41 | radii=RADIUS[k],
42 | nsamples=NSAMPLE[k],
43 | mlps=mlps,
44 | use_xyz=True,
45 | bn=True
46 | )
47 | )
48 | skip_channel_list.append(channel_out)
49 | channel_in = channel_out
50 |
51 | self.FP_modules = nn.ModuleList()
52 |
53 | for k in range(FP_MLPS.__len__()):
54 | pre_channel = FP_MLPS[k + 1][-1] if k + 1 < len(FP_MLPS) else channel_out
55 | self.FP_modules.append(
56 | PointnetFPModule(mlp=[pre_channel + skip_channel_list[k]] + FP_MLPS[k])
57 | )
58 |
59 | cls_layers = []
60 | pre_channel = FP_MLPS[0][-1]
61 | for k in range(0, CLS_FC.__len__()):
62 | cls_layers.append(pt_utils.Conv1d(pre_channel, CLS_FC[k], bn=True))
63 | pre_channel = CLS_FC[k]
64 | cls_layers.append(pt_utils.Conv1d(pre_channel, 1, activation=None))
65 | cls_layers.insert(1, nn.Dropout(0.5))
66 | self.cls_layer = nn.Sequential(*cls_layers)
67 |
68 | def _break_up_pc(self, pc):
69 | xyz = pc[..., 0:3].contiguous()
70 | features = (
71 | pc[..., 3:].transpose(1, 2).contiguous()
72 | if pc.size(-1) > 3 else None
73 | )
74 |
75 | return xyz, features
76 |
77 | def forward(self, pointcloud: torch.cuda.FloatTensor):
78 | xyz, features = self._break_up_pc(pointcloud)
79 |
80 | l_xyz, l_features = [xyz], [features]
81 | for i in range(len(self.SA_modules)):
82 | li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
83 |
84 | print(li_xyz.shape, li_features.shape)
85 |
86 | l_xyz.append(li_xyz)
87 | l_features.append(li_features)
88 |
89 | for i in range(-1, -(len(self.FP_modules) + 1), -1):
90 | l_features[i - 1] = self.FP_modules[i](
91 | l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i]
92 | )
93 |
94 | pred_cls = self.cls_layer(l_features[0]).transpose(1, 2).contiguous() # (B, N, 1)
95 | return pred_cls
96 |
97 | if __name__ == '__main__':
98 | net = Pointnet2MSG(0).cuda()
99 | pts = torch.randn(2, 1024, 3).cuda()
100 |
101 | pre = net(pts)
102 | print(pre.shape)
103 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnet2_utils/tools/train_and_eval.py:
--------------------------------------------------------------------------------
1 | import _init_path
2 | import numpy as np
3 | import os
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch.optim.lr_scheduler as lr_sched
8 | from torch.nn.utils import clip_grad_norm_
9 | from torch.utils.data import DataLoader
10 | import tensorboard_logger as tb_log
11 | from dataset import KittiDataset
12 | import argparse
13 | import importlib
14 |
15 | parser = argparse.ArgumentParser(description="Arg parser")
16 | parser.add_argument("--batch_size", type=int, default=8)
17 | parser.add_argument("--epochs", type=int, default=100)
18 | parser.add_argument("--ckpt_save_interval", type=int, default=5)
19 | parser.add_argument('--workers', type=int, default=4)
20 | parser.add_argument("--mode", type=str, default='train')
21 | parser.add_argument("--ckpt", type=str, default='None')
22 |
23 | parser.add_argument("--net", type=str, default='pointnet2_msg')
24 |
25 | parser.add_argument('--lr', type=float, default=0.002)
26 | parser.add_argument('--lr_decay', type=float, default=0.2)
27 | parser.add_argument('--lr_clip', type=float, default=0.000001)
28 | parser.add_argument('--decay_step_list', type=list, default=[50, 70, 80, 90])
29 | parser.add_argument('--weight_decay', type=float, default=0.001)
30 |
31 | parser.add_argument("--output_dir", type=str, default='output')
32 | parser.add_argument("--extra_tag", type=str, default='default')
33 |
34 | args = parser.parse_args()
35 |
36 | FG_THRESH = 0.3
37 |
38 |
39 | def log_print(info, log_f=None):
40 | print(info)
41 | if log_f is not None:
42 | print(info, file=log_f)
43 |
44 |
45 | class DiceLoss(nn.Module):
46 | def __init__(self, ignore_target=-1):
47 | super().__init__()
48 | self.ignore_target = ignore_target
49 |
50 | def forward(self, input, target):
51 | """
52 | :param input: (N), logit
53 | :param target: (N), {0, 1}
54 | :return:
55 | """
56 | input = torch.sigmoid(input.view(-1))
57 | target = target.float().view(-1)
58 | mask = (target != self.ignore_target).float()
59 | return 1.0 - (torch.min(input, target) * mask).sum() / torch.clamp((torch.max(input, target) * mask).sum(), min=1.0)
60 |
61 |
62 | def train_one_epoch(model, train_loader, optimizer, epoch, lr_scheduler, total_it, tb_log, log_f):
63 | model.train()
64 | log_print('===============TRAIN EPOCH %d================' % epoch, log_f=log_f)
65 | loss_func = DiceLoss(ignore_target=-1)
66 |
67 | for it, batch in enumerate(train_loader):
68 | optimizer.zero_grad()
69 |
70 | pts_input, cls_labels = batch['pts_input'], batch['cls_labels']
71 | pts_input = torch.from_numpy(pts_input).cuda(non_blocking=True).float()
72 | cls_labels = torch.from_numpy(cls_labels).cuda(non_blocking=True).long().view(-1)
73 |
74 | pred_cls = model(pts_input)
75 | pred_cls = pred_cls.view(-1)
76 |
77 | loss = loss_func(pred_cls, cls_labels)
78 | loss.backward()
79 | clip_grad_norm_(model.parameters(), 1.0)
80 | optimizer.step()
81 |
82 | total_it += 1
83 |
84 | pred_class = (torch.sigmoid(pred_cls) > FG_THRESH)
85 | fg_mask = cls_labels > 0
86 | correct = ((pred_class.long() == cls_labels) & fg_mask).float().sum()
87 | union = fg_mask.sum().float() + (pred_class > 0).sum().float() - correct
88 | iou = correct / torch.clamp(union, min=1.0)
89 |
90 | cur_lr = lr_scheduler.get_lr()[0]
91 | if tb_log is not None:
92 | tb_log.log_value('learning_rate', cur_lr, epoch)
93 | tb_log.log_value('train_loss', loss, total_it)
94 | tb_log.log_value('train_fg_iou', iou, total_it)
95 |
96 | log_print('training epoch %d: it=%d/%d, total_it=%d, loss=%.5f, fg_iou=%.3f, lr=%f' %
97 | (epoch, it, len(train_loader), total_it, loss.item(), iou.item(), cur_lr), log_f=log_f)
98 |
99 | return total_it
100 |
101 |
102 | def eval_one_epoch(model, eval_loader, epoch, tb_log=None, log_f=None):
103 | model.train()
104 | log_print('===============EVAL EPOCH %d================' % epoch, log_f=log_f)
105 |
106 | iou_list = []
107 | for it, batch in enumerate(eval_loader):
108 | pts_input, cls_labels = batch['pts_input'], batch['cls_labels']
109 | pts_input = torch.from_numpy(pts_input).cuda(non_blocking=True).float()
110 | cls_labels = torch.from_numpy(cls_labels).cuda(non_blocking=True).long().view(-1)
111 |
112 | pred_cls = model(pts_input)
113 | pred_cls = pred_cls.view(-1)
114 |
115 | pred_class = (torch.sigmoid(pred_cls) > FG_THRESH)
116 | fg_mask = cls_labels > 0
117 | correct = ((pred_class.long() == cls_labels) & fg_mask).float().sum()
118 | union = fg_mask.sum().float() + (pred_class > 0).sum().float() - correct
119 | iou = correct / torch.clamp(union, min=1.0)
120 |
121 | iou_list.append(iou.item())
122 | log_print('EVAL: it=%d/%d, iou=%.3f' % (it, len(eval_loader), iou), log_f=log_f)
123 |
124 | iou_list = np.array(iou_list)
125 | avg_iou = iou_list.mean()
126 | if tb_log is not None:
127 | tb_log.log_value('eval_fg_iou', avg_iou, epoch)
128 |
129 | log_print('\nEpoch %d: Average IoU (samples=%d): %.6f' % (epoch, iou_list.__len__(), avg_iou), log_f=log_f)
130 | return avg_iou
131 |
132 |
133 | def save_checkpoint(model, epoch, ckpt_name):
134 | if isinstance(model, torch.nn.DataParallel):
135 | model_state = model.module.state_dict()
136 | else:
137 | model_state = model.state_dict()
138 |
139 | state = {'epoch': epoch, 'model_state': model_state}
140 | ckpt_name = '{}.pth'.format(ckpt_name)
141 | torch.save(state, ckpt_name)
142 |
143 |
144 | def load_checkpoint(model, filename):
145 | if os.path.isfile(filename):
146 | log_print("==> Loading from checkpoint %s" % filename)
147 | checkpoint = torch.load(filename)
148 | epoch = checkpoint['epoch']
149 | model.load_state_dict(checkpoint['model_state'])
150 | log_print("==> Done")
151 | else:
152 | raise FileNotFoundError
153 |
154 | return epoch
155 |
156 |
157 | def train_and_eval(model, train_loader, eval_loader, tb_log, ckpt_dir, log_f):
158 | model.cuda()
159 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
160 |
161 | def lr_lbmd(cur_epoch):
162 | cur_decay = 1
163 | for decay_step in args.decay_step_list:
164 | if cur_epoch >= decay_step:
165 | cur_decay = cur_decay * args.lr_decay
166 | return max(cur_decay, args.lr_clip / args.lr)
167 |
168 | lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd)
169 |
170 | total_it = 0
171 | for epoch in range(1, args.epochs + 1):
172 | lr_scheduler.step(epoch)
173 | total_it = train_one_epoch(model, train_loader, optimizer, epoch, lr_scheduler, total_it, tb_log, log_f)
174 |
175 | if epoch % args.ckpt_save_interval == 0:
176 | with torch.no_grad():
177 | avg_iou = eval_one_epoch(model, eval_loader, epoch, tb_log, log_f)
178 | ckpt_name = os.path.join(ckpt_dir, 'checkpoint_epoch_%d' % epoch)
179 | save_checkpoint(model, epoch, ckpt_name)
180 |
181 |
182 | if __name__ == '__main__':
183 | MODEL = importlib.import_module(args.net) # import network module
184 | model = MODEL.get_model(input_channels=0)
185 |
186 | eval_set = KittiDataset(root_dir='./data', mode='EVAL', split='val')
187 | eval_loader = DataLoader(eval_set, batch_size=args.batch_size, shuffle=False, pin_memory=True,
188 | num_workers=args.workers, collate_fn=eval_set.collate_batch)
189 |
190 | if args.mode == 'train':
191 | train_set = KittiDataset(root_dir='./data', mode='TRAIN', split='train')
192 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, pin_memory=True,
193 | num_workers=args.workers, collate_fn=train_set.collate_batch)
194 | # output dir config
195 | output_dir = os.path.join(args.output_dir, args.extra_tag)
196 | os.makedirs(output_dir, exist_ok=True)
197 | tb_log.configure(os.path.join(output_dir, 'tensorboard'))
198 | ckpt_dir = os.path.join(output_dir, 'ckpt')
199 | os.makedirs(ckpt_dir, exist_ok=True)
200 |
201 | log_file = os.path.join(output_dir, 'log.txt')
202 | log_f = open(log_file, 'w')
203 |
204 | for key, val in vars(args).items():
205 | log_print("{:16} {}".format(key, val), log_f=log_f)
206 |
207 | # train and eval
208 | train_and_eval(model, train_loader, eval_loader, tb_log, ckpt_dir, log_f)
209 | log_f.close()
210 | elif args.mode == 'eval':
211 | epoch = load_checkpoint(model, args.ckpt)
212 | model.cuda()
213 | with torch.no_grad():
214 | avg_iou = eval_one_epoch(model, eval_loader, epoch)
215 | else:
216 | raise NotImplementedError
217 |
218 |
--------------------------------------------------------------------------------
/networks/pts_encoder/pointnets.py:
--------------------------------------------------------------------------------
1 | """refer to https://github.com/fxia22/pointnet.pytorch/blob/f0c2430b0b1529e3f76fb5d6cd6ca14be763d975/pointnet/model.py."""
2 |
3 | from __future__ import print_function
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.parallel
7 | import torch.utils.data
8 | from torch.autograd import Variable
9 | from ipdb import set_trace
10 | import numpy as np
11 | import torch.nn.functional as F
12 |
13 |
14 | class STN3d(nn.Module):
15 | def __init__(self):
16 | super(STN3d, self).__init__()
17 | self.conv1 = torch.nn.Conv1d(3, 64, 1)
18 | self.conv2 = torch.nn.Conv1d(64, 128, 1)
19 | self.conv3 = torch.nn.Conv1d(128, 1024, 1)
20 | self.fc1 = nn.Linear(1024, 512)
21 | self.fc2 = nn.Linear(512, 256)
22 | self.fc3 = nn.Linear(256, 9)
23 | self.relu = nn.ReLU()
24 |
25 | def forward(self, x):
26 | batchsize = x.size()[0]
27 | x = F.relu(self.conv1(x))
28 | x = F.relu(self.conv2(x))
29 | x = F.relu(self.conv3(x))
30 | x = torch.max(x, 2, keepdim=True)[0]
31 | x = x.view(-1, 1024)
32 |
33 | x = F.relu(self.fc1(x))
34 | x = F.relu(self.fc2(x))
35 | x = self.fc3(x)
36 |
37 | iden = Variable(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1], dtype=torch.float32)).view(1, 9).repeat(batchsize, 1)
38 | if x.is_cuda:
39 | iden = iden.cuda()
40 | x = x + iden
41 | x = x.view(-1, 3, 3)
42 | return x
43 |
44 |
45 | class STNkd(nn.Module):
46 | def __init__(self, k=64):
47 | super(STNkd, self).__init__()
48 | self.conv1 = torch.nn.Conv1d(k, 64, 1)
49 | self.conv2 = torch.nn.Conv1d(64, 128, 1)
50 | self.conv3 = torch.nn.Conv1d(128, 1024, 1)
51 | self.fc1 = nn.Linear(1024, 512)
52 | self.fc2 = nn.Linear(512, 256)
53 | self.fc3 = nn.Linear(256, k * k)
54 | self.relu = nn.ReLU()
55 |
56 | self.k = k
57 |
58 | def forward(self, x):
59 | batchsize = x.size()[0]
60 | x = F.relu(self.conv1(x))
61 | x = F.relu(self.conv2(x))
62 | x = F.relu(self.conv3(x))
63 | x = torch.max(x, 2, keepdim=True)[0]
64 | x = x.view(-1, 1024)
65 |
66 | x = F.relu(self.fc1(x))
67 | x = F.relu(self.fc2(x))
68 | x = self.fc3(x)
69 |
70 | iden = (
71 | Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32)))
72 | .view(1, self.k * self.k)
73 | .repeat(batchsize, 1)
74 | )
75 | if x.is_cuda:
76 | iden = iden.cuda()
77 | x = x + iden
78 | x = x.view(-1, self.k, self.k)
79 | return x
80 |
81 |
82 | # NOTE: removed BN
83 | class PointNetfeat(nn.Module):
84 | def __init__(self, num_points, global_feat=True, in_dim=3, out_dim=1024, feature_transform=False, **args):
85 | super(PointNetfeat, self).__init__()
86 | self.num_points = num_points
87 | self.out_dim = out_dim
88 | self.feature_transform = feature_transform
89 | # self.stn = STN3d(in_dim=in_dim)
90 | self.stn = STNkd(k=in_dim)
91 | self.conv1 = torch.nn.Conv1d(in_dim, 64, 1)
92 | self.conv2 = torch.nn.Conv1d(64, 128, 1)
93 | self.conv3 = torch.nn.Conv1d(128, 512, 1)
94 | self.conv4 = torch.nn.Conv1d(512, out_dim, 1)
95 | self.global_feat = global_feat
96 | if self.feature_transform:
97 | self.fstn = STNkd(k=64)
98 |
99 | def forward(self, x, **args):
100 | n_pts = x.shape[2]
101 | trans = self.stn(x)
102 | x = x.transpose(2, 1)
103 | x = torch.bmm(x, trans)
104 | x = x.transpose(2, 1)
105 | x = F.relu(self.conv1(x))
106 |
107 | if self.feature_transform:
108 | trans_feat = self.fstn(x)
109 | x = x.transpose(2, 1)
110 | x = torch.bmm(x, trans_feat)
111 | x = x.transpose(2, 1)
112 |
113 | pointfeat = x
114 | x = F.relu(self.conv2(x))
115 | x = F.relu(self.conv3(x))
116 | x = self.conv4(x)
117 | x = torch.max(x, 2, keepdim=True)[0]
118 | x = x.view(-1, self.out_dim)
119 | if self.global_feat:
120 | return x
121 | else:
122 | x = x.view(-1, self.out_dim, 1).repeat(1, 1, n_pts)
123 | return torch.cat([x, pointfeat], 1)
124 |
125 |
126 | def feature_transform_regularizer(trans):
127 | d = trans.size()[1]
128 | batchsize = trans.size()[0]
129 | I = torch.eye(d)[None, :, :]
130 | if trans.is_cuda:
131 | I = I.cuda()
132 | loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2)))
133 | return loss
134 |
135 |
136 | if __name__ == "__main__":
137 | sim_data = Variable(torch.rand(32, 3, 2500))
138 | trans = STN3d()
139 | out = trans(sim_data)
140 | print("stn", out.size())
141 | print("loss", feature_transform_regularizer(out))
142 |
143 | sim_data_64d = Variable(torch.rand(32, 64, 2500))
144 | trans = STNkd(k=64)
145 | out = trans(sim_data_64d)
146 | print("stn64d", out.size())
147 | print("loss", feature_transform_regularizer(out))
148 |
149 | pointfeat_g = PointNetfeat(global_feat=True, num_points=2500)
150 | out = pointfeat_g(sim_data)
151 | print("global feat", out.size())
152 |
153 | pointfeat = PointNetfeat(global_feat=False, num_points=2500)
154 | out = pointfeat(sim_data)
155 | print("point feat", out.size())
156 |
157 |
--------------------------------------------------------------------------------
/networks/reward.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys
3 | import os
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | sys.path.append(os.getcwd())
8 |
9 | from ipdb import set_trace
10 | from utils.genpose_utils import get_pose_dim
11 | from utils.metrics import get_metrics
12 |
13 |
14 | class RewardModel(nn.Module):
15 | def __init__(self, pose_mode):
16 | """
17 | init func.
18 |
19 | Args:
20 | pose_mode (str): rotation representation of the pose (e.g. 'quat_wxyz'), used to determine the pose input dimension
21 | """
22 | super(RewardModel, self).__init__()
23 | pose_dim = get_pose_dim(pose_mode)
24 | self.act = nn.ReLU(True)
25 |
26 | ''' encode pose '''
27 | self.pose_encoder = nn.Sequential(
28 | nn.Linear(pose_dim, 256),
29 | self.act,
30 | nn.Linear(256, 256),
31 | self.act,
32 | )
33 |
34 | ''' decoder '''
35 | self.reward_layer = nn.Sequential(
36 | nn.Linear(1024+256, 256),
37 | self.act,
38 | nn.Linear(256, 2),
39 | )
40 |
41 | def forward(
42 | self,
43 | pts_feature,
44 | pose
45 | ):
46 | """
47 | calculate the score of every pose
48 |
49 | Args:
50 | pts_feature (torch.tensor): [batch, 1024]
51 | pred_pose (torch.tensor): [batch, pose_dim]
52 | Returns:
53 | reward (torch.tensor): [batch, 2], the score of the pose estimation results,
54 | the first item is rotation score and the second item is translation score.
55 | """
56 |
57 | pose_feature = self.pose_encoder(pose)
58 | feature = torch.cat((pts_feature, pose_feature), dim=-1) # [bs, 1024+256]
59 | reward = self.reward_layer(feature) # [batch, 2]
60 | return reward
61 |
62 |
63 | def sort_results(energy, metrics):
64 | """ Sorting the results according to the pose error (low to high)
65 |
66 | Args:
67 | energy (torch.tensor): [bs, repeat_num, 2]
68 | metrics (torch.tensor): [bs, repeat_num, 2]
69 |
70 | Return:
71 | sorted_energy (torch.tensor): [bs, repeat_num, 2]
72 | """
73 | rot_error = metrics[..., 0]
74 | trans_error = metrics[..., 1]
75 |
76 | rot_index = torch.argsort(rot_error, dim=1, descending=False)
77 | trans_index = torch.argsort(trans_error, dim=1, descending=False)
78 |
79 | sorted_energy = energy.clone()
80 | sorted_energy[..., 0] = energy[..., 0].gather(1, rot_index)
81 | sorted_energy[..., 1] = energy[..., 1].gather(1, trans_index)
82 |
83 | return sorted_energy
84 |
85 |
86 | # def ranking_loss(energy):
87 | # """ Calculate the ranking loss
88 |
89 | # Args:
90 | # energy (torch.tensor): [bs, repeat_num, 2]
91 |
92 | # Returns:
93 | # loss (torch.tensor)
94 | # """
95 | # loss, count = 0, 0
96 | # repeat_num = energy.shape[1]
97 |
98 | # for i in range(repeat_num - 1):
99 | # for j in range(i+1, repeat_num):
100 | # # diff = torch.log(torch.sigmoid(score[:, i, :] - score[:, j, :]))
101 | # diff = torch.sigmoid(-energy[:, i, :] + energy[:, j, :])
102 | # loss += torch.mean(diff)
103 | # count += 1
104 | # loss = loss / count
105 | # return loss
106 |
107 |
108 |
109 | def ranking_loss(energy):
110 | """ Calculate the ranking loss
111 |
112 | Args:
113 | energy (torch.tensor): [bs, repeat_num, 2]
114 |
115 | Returns:
116 | loss (torch.tensor)
117 | """
118 | loss, count = 0, 0
119 | repeat_num = energy.shape[1]
120 |
121 | for i in range(repeat_num - 1):
122 | for j in range(i+1, repeat_num):
123 | # diff = torch.log(torch.sigmoid(score[:, i, :] - score[:, j, :]))
124 | diff = 1 + (-energy[:, i, :] + energy[:, j, :]) / (torch.abs(energy[:, i, :] - energy[:, j, :]) + 1e-5)
125 | loss += torch.mean(diff)
126 | count += 1
127 | loss = loss / count
128 | return loss
129 |
130 |
131 | def sort_poses_by_energy(poses, energy):
132 | """ Rank the poses from highest to lowest energy
133 |
134 | Args:
135 | poses (torch.tensor): [bs, inference_num, pose_dim]
136 | energy (torch.tensor): [bs, inference_num, 2]
137 |
138 | Returns:
139 | sorted_poses (torch.tensor): [bs, inference_num, pose_dim]
140 | sorted_energy (torch.tensor): [bs, inference_num, 2]
141 | """
142 | # get the sorted energy
143 | bs = poses.shape[0]
144 | repeat_num= poses.shape[1]
145 | sorted_energy, indices_1 = torch.sort(energy, descending=True, dim=1)
146 | indices_0 = torch.arange(0, energy.shape[0]).view(1, -1).to(energy.device).repeat(1, repeat_num)
147 | indices_1_rot = indices_1.permute(2, 1, 0)[0].reshape(1, -1)
148 | indices_1_trans = indices_1.permute(2, 1, 0)[1].reshape(1, -1)
149 | rot_index = torch.cat((indices_0, indices_1_rot), dim=0).cpu().numpy().tolist()
150 | trans_index = torch.cat((indices_0, indices_1_trans), dim=0).cpu().numpy().tolist()
151 | sorted_poses = poses[rot_index]
152 | sorted_poses[:, -3:] = poses[trans_index][:, -3:]
153 | sorted_poses = sorted_poses.view(repeat_num, bs, -1).permute(1, 0, 2)
154 |
155 | return sorted_poses, sorted_energy
156 |
157 |
158 | def test_ranking_loss():
159 | energy = torch.tensor([[[100, 100],
160 | [9, 9],
161 | [8, 8],
162 | [10, 10]]])
163 | loss = ranking_loss(energy)
164 | print(loss)
165 |
166 | if __name__ == '__main__':
167 | test_ranking_loss()
168 | # bs = 3
169 | # repeat_num = 5
170 | # pts_feature = torch.randn(bs, 1024)
171 | # pred_pose = torch.randn(bs, repeat_num, 7)
172 | # metrics = torch.randn(bs, repeat_num, 2)
173 |
174 | # reward_model = RewardModel(pose_mode='quat_wxyz')
175 | # reward = reward_model(pts_feature.unsqueeze(1).repeat(1, repeat_num, 1), pred_pose)
176 | # sorted_reward = sort_results(reward, metrics)
177 | # loss = ranking_loss(sorted_reward)
178 |
179 |
--------------------------------------------------------------------------------
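The intended flow for the energy (reward) model above is: score a set of pose candidates with RewardModel, reorder the predicted energies by the true pose errors with sort_results, then penalize mis-ranked pairs with ranking_loss. Below is a minimal sketch with random tensors, following the commented-out example at the bottom of reward.py; it assumes the snippet is run from the repository root with the repo's dependencies installed so the networks/utils packages resolve.

    import torch
    from networks.reward import RewardModel, sort_results, ranking_loss

    bs, repeat_num = 3, 5
    pts_feature = torch.randn(bs, 1024)           # per-object point-cloud feature
    pred_pose = torch.randn(bs, repeat_num, 7)    # repeat_num pose candidates (quat_wxyz + translation)
    metrics = torch.rand(bs, repeat_num, 2)       # [rotation error, translation error] per candidate

    reward_model = RewardModel(pose_mode='quat_wxyz')
    energy = reward_model(pts_feature.unsqueeze(1).repeat(1, repeat_num, 1), pred_pose)  # (bs, repeat_num, 2)
    sorted_energy = sort_results(energy, metrics)  # energies re-ordered by increasing pose error
    loss = ranking_loss(sorted_energy)             # scalar; small when lower-error poses get higher energy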
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python==4.2.0.32
2 | scipy==1.4.1
3 | numpy==1.23.5
4 | tensorboardX==2.5.1
--------------------------------------------------------------------------------
/scripts/eval_single.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python runners/evaluation_single.py \
2 | --score_model_dir ScoreNet/ckpt_genpose.pth \
3 | --energy_model_dir EnergyNet/ckpt_genpose.pth \
4 | --data_path NOCS_DATASET_PATH \
5 | --sampler_mode ode \
6 | --max_eval_num 1000000 \
7 | --percentage_data_for_test 1.0 \
8 | --batch_size 256 \
9 | --seed 0 \
10 | --test_source real_test \
11 | --result_dir results \
12 | --eval_repeat_num 50 \
13 | --pooling_mode average \
14 | --ranker energy_ranker \
15 | --T0 0.55 \
16 | # --save_video \
17 |
--------------------------------------------------------------------------------
/scripts/eval_tracking.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python runners/evaluation_tracking.py \
2 | --score_model_dir ScoreNet/ckpt_genpose.pth \
3 | --energy_model_dir EnergyNet/ckpt_genpose.pth \
4 | --data_path NOCS_DATASET_PATH \
5 | --sampler_mode ode \
6 | --max_eval_num 1000000 \
7 | --percentage_data_for_test 1.0 \
8 | --batch_size 256 \
9 | --seed 0 \
10 | --test_source aligned_real_test \
11 | --result_dir results \
12 | --eval_repeat_num 50 \
13 | --pooling_mode average \
14 | --ranker energy_ranker \
15 | --T0 0.15 \
16 | # --save_video \
17 |
--------------------------------------------------------------------------------
/scripts/tensorboard.sh:
--------------------------------------------------------------------------------
1 | tensorboard --logdir ./results/logs/ --port 0505 --reload_interval 1 --samples_per_plugin images=999
--------------------------------------------------------------------------------
/scripts/train_energy.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=1 python runners/trainer.py \
2 | --data_path NOCS_DATASET_PATH \
3 | --log_dir EnergyNet \
4 | --agent_type energy \
5 | --sampler_mode ode \
6 | --batch_size 192 \
7 | --eval_freq 1 \
8 | --n_epochs 200 \
9 | --selected_classes bottle bowl camera can laptop mug \
10 | --percentage_data_for_train 1.0 \
11 | --percentage_data_for_test 1.0 \
12 | --percentage_data_for_val 1.0 \
13 | --seed 0 \
14 | --is_train \
15 |
--------------------------------------------------------------------------------
/scripts/train_score.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python runners/trainer.py \
2 | --data_path NOCS_DATASET_PATH \
3 | --log_dir ScoreNet \
4 | --agent_type score \
5 | --sampler_mode ode \
6 | --sampling_steps 500 \
7 | --eval_freq 1 \
8 | --n_epochs 1900 \
9 | --percentage_data_for_train 1.0 \
10 | --percentage_data_for_test 1.0 \
11 | --percentage_data_for_val 1.0 \
12 | --seed 0 \
13 | --is_train \
14 |
--------------------------------------------------------------------------------
/utils/datasets_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 | def get_2d_coord_np(width, height, low=0, high=1, fmt="CHW"):
5 | """
6 | Args:
7 | width:
8 | height:
9 | Returns:
10 | xy: (2, height, width)
11 | """
12 | # coords values are in [low, high] [0,1] or [-1,1]
13 | x = np.linspace(0, width-1, width, dtype=np.float32)
14 | y = np.linspace(0, height-1, height, dtype=np.float32)
15 | xy = np.asarray(np.meshgrid(x, y))
16 | if fmt == "HWC":
17 | xy = xy.transpose(1, 2, 0)
18 | elif fmt == "CHW":
19 | pass
20 | else:
21 | raise ValueError(f"Unknown format: {fmt}")
22 | return xy
23 |
24 |
25 | def aug_bbox_DZI(hyper_params, bbox_xyxy, im_H, im_W):
26 | """Used for DZI, the augmented box is a square (maybe enlarged)
27 | Args:
28 | bbox_xyxy (np.ndarray):
29 | Returns:
30 | center, scale
31 | """
32 | x1, y1, x2, y2 = bbox_xyxy.copy()
33 | cx = 0.5 * (x1 + x2)
34 | cy = 0.5 * (y1 + y2)
35 | bh = y2 - y1
36 | bw = x2 - x1
37 | if hyper_params['DZI_TYPE'].lower() == "uniform":
38 | scale_ratio = 1 + hyper_params['DZI_SCALE_RATIO'] * (2 * np.random.random_sample() - 1) # [1-0.25, 1+0.25]
39 | shift_ratio = hyper_params['DZI_SHIFT_RATIO'] * (2 * np.random.random_sample(2) - 1) # [-0.25, 0.25]
40 | bbox_center = np.array([cx + bw * shift_ratio[0], cy + bh * shift_ratio[1]]) # (h/2, w/2)
41 | scale = max(y2 - y1, x2 - x1) * scale_ratio * hyper_params['DZI_PAD_SCALE']
42 | elif hyper_params['DZI_TYPE'].lower() == "roi10d":
43 | # shift (x1,y1), (x2,y2) by 15% in each direction
44 | _a = -0.15
45 | _b = 0.15
46 | x1 += bw * (np.random.rand() * (_b - _a) + _a)
47 | x2 += bw * (np.random.rand() * (_b - _a) + _a)
48 | y1 += bh * (np.random.rand() * (_b - _a) + _a)
49 | y2 += bh * (np.random.rand() * (_b - _a) + _a)
50 | x1 = min(max(x1, 0), im_W)
51 | x2 = min(max(x2, 0), im_W)
52 | y1 = min(max(y1, 0), im_H)
53 | y2 = min(max(y2, 0), im_H)
54 | bbox_center = np.array([0.5 * (x1 + x2), 0.5 * (y1 + y2)])
55 | scale = max(y2 - y1, x2 - x1) * hyper_params['DZI_PAD_SCALE']
56 | elif hyper_params['DZI_TYPE'].lower() == "truncnorm":
57 | raise NotImplementedError("DZI truncnorm not implemented yet.")
58 | else:
59 | bbox_center = np.array([cx, cy]) # (w/2, h/2)
60 | scale = max(y2 - y1, x2 - x1)
61 | scale = min(scale, max(im_H, im_W)) * 1.0
62 | return bbox_center, scale
63 |
64 |
65 | def aug_bbox_eval(bbox_xyxy, im_H, im_W):
66 | """Used for DZI, the augmented box is a square (maybe enlarged)
67 | Args:
68 | bbox_xyxy (np.ndarray):
69 | Returns:
70 | center, scale
71 | """
72 | x1, y1, x2, y2 = bbox_xyxy.copy()
73 | cx = 0.5 * (x1 + x2)
74 | cy = 0.5 * (y1 + y2)
75 | bh = y2 - y1
76 | bw = x2 - x1
77 | bbox_center = np.array([cx, cy]) # (w/2, h/2)
78 | scale = max(y2 - y1, x2 - x1)
79 | scale = min(scale, max(im_H, im_W)) * 1.0
80 | return bbox_center, scale
81 |
82 | def crop_resize_by_warp_affine(img, center, scale, output_size, rot=0, interpolation=cv2.INTER_LINEAR):
83 | """
84 | output_size: int or (w, h)
85 | NOTE: if img is (h,w,1), the output will be (h,w)
86 | """
87 | if isinstance(scale, (int, float)):
88 | scale = (scale, scale)
89 | if isinstance(output_size, int):
90 | output_size = (output_size, output_size)
91 | trans = get_affine_transform(center, scale, rot, output_size)
92 |
93 | dst_img = cv2.warpAffine(img, trans, (int(output_size[0]), int(output_size[1])), flags=interpolation)
94 |
95 | return dst_img
96 |
97 | def get_affine_transform(center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=False):
98 | """
99 | adapted from CenterNet: https://github.com/xingyizhou/CenterNet/blob/master/src/lib/utils/image.py
100 | center: ndarray: (cx, cy)
101 | scale: (w, h)
102 | rot: angle in deg
103 | output_size: int or (w, h)
104 | """
105 | if isinstance(center, (tuple, list)):
106 | center = np.array(center, dtype=np.float32)
107 |
108 | if isinstance(scale, (int, float)):
109 | scale = np.array([scale, scale], dtype=np.float32)
110 |
111 | if isinstance(output_size, (int, float)):
112 | output_size = (output_size, output_size)
113 |
114 | scale_tmp = scale
115 | src_w = scale_tmp[0]
116 | dst_w = output_size[0]
117 | dst_h = output_size[1]
118 |
119 | rot_rad = np.pi * rot / 180
120 | src_dir = get_dir([0, src_w * -0.5], rot_rad)
121 | dst_dir = np.array([0, dst_w * -0.5], np.float32)
122 |
123 | src = np.zeros((3, 2), dtype=np.float32)
124 | dst = np.zeros((3, 2), dtype=np.float32)
125 | src[0, :] = center + scale_tmp * shift
126 | src[1, :] = center + src_dir + scale_tmp * shift
127 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
128 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir
129 |
130 | src[2:, :] = get_3rd_point(src[0, :], src[1, :])
131 | dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
132 |
133 | if inv:
134 | trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
135 | else:
136 | trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
137 |
138 | return trans
139 |
140 | def get_dir(src_point, rot_rad):
141 | sn, cs = np.sin(rot_rad), np.cos(rot_rad)
142 |
143 | src_result = [0, 0]
144 | src_result[0] = src_point[0] * cs - src_point[1] * sn
145 | src_result[1] = src_point[0] * sn + src_point[1] * cs
146 |
147 | return src_result
148 |
149 | def get_3rd_point(a, b):
150 | direct = a - b
151 | return b + np.array([-direct[1], direct[0]], dtype=np.float32)
152 |
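
A minimal usage sketch of the crop pipeline above (the image and box values are made up, not from the repo): `aug_bbox_eval` turns a detection box into a square center/scale pair, and `crop_resize_by_warp_affine` warps that region into a fixed-size patch.

```python
# Illustrative only: crop a 224x224 RoI around a synthetic bbox.
import numpy as np
import cv2

img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)      # stand-in RGB frame
bbox_xyxy = np.array([100.0, 120.0, 260.0, 300.0])               # (x1, y1, x2, y2)

center, scale = aug_bbox_eval(bbox_xyxy, im_H=480, im_W=640)     # square box, maybe enlarged
roi = crop_resize_by_warp_affine(img, center, scale, output_size=224)
print(roi.shape)                                                 # (224, 224, 3)
```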
--------------------------------------------------------------------------------
/utils/genpose_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 | from ipdb import set_trace
6 |
7 |
8 | def get_pose_dim(rot_mode):
9 | assert rot_mode in ['quat_wxyz', 'quat_xyzw', 'euler_xyz', 'euler_xyz_sx_cx', 'rot_matrix'], \
10 | f"the rotation mode {rot_mode} is not supported!"
11 |
12 | if rot_mode == 'quat_wxyz' or rot_mode == 'quat_xyzw':
13 | pose_dim = 7
14 | elif rot_mode == 'euler_xyz':
15 | pose_dim = 6
16 | elif rot_mode == 'euler_xyz_sx_cx' or rot_mode == 'rot_matrix':
17 | pose_dim = 9
18 | else:
19 | raise NotImplementedError
20 | return pose_dim
21 |
22 | '''
23 | def rot6d_to_mat_batch(d6):
24 | """
25 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix.
26 | Args:
27 | d6: 6D rotation representation, of size (*, 6)
28 | Returns:
29 | batch of rotation matrices of size (*, 3, 3)
30 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
31 | On the Continuity of Rotation Representations in Neural Networks. CVPR 2019.
32 | Retrieved from http://arxiv.org/abs/1812.07035
33 | """
34 | # poses
35 | x_raw = d6[..., 0:3] # bx3
36 | y_raw = d6[..., 3:6] # bx3
37 |
38 | x = F.normalize(x_raw, p=2, dim=-1) # bx3
39 | z = torch.cross(x, y_raw, dim=-1) # bx3
40 | z = F.normalize(z, p=2, dim=-1) # bx3
41 | y = torch.cross(z, x, dim=-1) # bx3
42 |
43 | # (*,3)x3 --> (*,3,3)
44 | return torch.stack((x, y, z), dim=-1) # (b,3,3)
45 | '''
46 |
47 | def rot6d_to_mat_batch(d6):
48 | """
49 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix.
50 | Args:
51 | d6: 6D rotation representation, of size (*, 6)
52 | Returns:
53 | batch of rotation matrices of size (*, 3, 3)
54 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
55 | On the Continuity of Rotation Representations in Neural Networks. CVPR 2019.
56 | Retrieved from http://arxiv.org/abs/1812.07035
57 | """
58 | # poses
59 | x_raw = d6[..., 0:3] # bx3
60 | y_raw = d6[..., 3:6] # bx3
61 |
62 | x = x_raw / np.linalg.norm(x_raw, axis=-1, keepdims=True) # b*3
63 | z = np.cross(x, y_raw) # b*3
64 | z = z / np.linalg.norm(z, axis=-1, keepdims=True) # b*3
65 | y = np.cross(z, x) # b*3
66 |
67 | return np.stack((x, y, z), axis=-1) # (b,3,3)
68 |
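
A quick sanity check of the NumPy implementation above (a sketch, not an original test): the stacked columns should form a proper rotation matrix.

```python
# Illustrative check: outputs of rot6d_to_mat_batch should be orthonormal with det = +1.
import numpy as np

d6 = np.random.randn(4, 6)                       # random 6D rotation representations
R = rot6d_to_mat_batch(d6)                       # (4, 3, 3)
print(np.allclose(R @ R.transpose(0, 2, 1), np.eye(3), atol=1e-6))  # True
print(np.allclose(np.linalg.det(R), 1.0, atol=1e-6))                # True
```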
69 |
70 | class TrainClock(object):
71 | """ Clock object to track epoch and step during training
72 | """
73 | def __init__(self):
74 | self.epoch = 1
75 | self.minibatch = 0
76 | self.step = 0
77 |
78 | def tick(self):
79 | self.minibatch += 1
80 | self.step += 1
81 |
82 | def tock(self):
83 | self.epoch += 1
84 | self.minibatch = 0
85 |
86 | def make_checkpoint(self):
87 | return {
88 | 'epoch': self.epoch,
89 | 'minibatch': self.minibatch,
90 | 'step': self.step
91 | }
92 |
93 | def restore_checkpoint(self, clock_dict):
94 | self.epoch = clock_dict['epoch']
95 | self.minibatch = clock_dict['minibatch']
96 | self.step = clock_dict['step']
97 |
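
A minimal sketch of the intended tick/tock protocol (the loop bounds below are illustrative placeholders, not repo settings):

```python
# Illustrative: tick() once per minibatch, tock() once per epoch.
clock = TrainClock()
for _ in range(2):              # 2 epochs (illustrative)
    for _ in range(5):          # 5 minibatches per epoch (illustrative)
        clock.tick()            # minibatch += 1, step += 1
    clock.tock()                # epoch += 1, minibatch reset to 0
print(clock.make_checkpoint())  # {'epoch': 3, 'minibatch': 0, 'step': 10}
```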
98 |
99 | def merge_results(results_ori, results_new):
100 | if len(results_ori.keys()) == 0:
101 | return results_new
102 | else:
103 | results = {
104 | 'pred_pose': torch.cat([results_ori['pred_pose'], results_new['pred_pose']], dim=0),
105 | 'gt_pose': torch.cat([results_ori['gt_pose'], results_new['gt_pose']], dim=0),
106 | 'cls_id': torch.cat([results_ori['cls_id'], results_new['cls_id']], dim=0),
107 | 'handle_visibility': torch.cat([results_ori['handle_visibility'], results_new['handle_visibility']], dim=0),
108 | # 'path': results_ori['path'] + results_new['path'],
109 | }
110 | return results
111 |
112 |
113 |
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 |
4 | import torch
5 | import numpy as np
6 | import pickle
7 |
8 | from utils.misc import get_rot_matrix, inverse_RT
9 | from utils.genpose_utils import get_pose_dim
10 | from ipdb import set_trace
11 |
12 | def rot_diff_rad(rot1, rot2, chosen_axis=None, flip_axis=False):
13 | if chosen_axis is not None:
14 |         axis = {'x': 0, 'y': 1, 'z': 2}[chosen_axis]
15 | y1, y2 = rot1[..., axis], rot2[..., axis] # [Bs, 3]
16 | diff = torch.sum(y1 * y2, dim=-1) # [Bs]
17 | diff = torch.clamp(diff, min=-1.0, max=1.0)
18 | rad = torch.acos(diff)
19 | if not flip_axis:
20 | return rad
21 | else:
22 | return torch.min(rad, np.pi - rad)
23 |
24 | else:
25 | mat_diff = torch.matmul(rot1, rot2.transpose(-1, -2))
26 | diff = mat_diff[..., 0, 0] + mat_diff[..., 1, 1] + mat_diff[..., 2, 2]
27 | diff = (diff - 1) / 2.0
28 | diff = torch.clamp(diff, min=-1.0, max=1.0)
29 | return torch.acos(diff)
30 |
31 |
32 | def rot_diff_degree(rot1, rot2, chosen_axis=None, flip_axis=False):
33 | return rot_diff_rad(rot1, rot2, chosen_axis=chosen_axis, flip_axis=flip_axis) / np.pi * 180.0
34 |
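
A worked check (a sketch, not part of the file): the full geodesic distance between a 90° rotation about z and the identity is 90°, while the single-axis distance for z is 0° because that axis is unchanged.

```python
# Illustrative: geodesic vs. single-axis rotation difference.
import torch

Rz = torch.tensor([[[0.0, -1.0, 0.0],
                    [1.0,  0.0, 0.0],
                    [0.0,  0.0, 1.0]]])           # 90 deg about z, [1, 3, 3]
I = torch.eye(3).unsqueeze(0)
print(rot_diff_degree(Rz, I))                     # ~tensor([90.])
print(rot_diff_degree(Rz, I, chosen_axis='z'))    # ~tensor([0.]) -- z column unchanged
```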
35 |
36 | def get_trans_error(trans_1, trans_2):
37 | diff = torch.norm(trans_1 - trans_2, dim=-1)
38 | return diff
39 |
40 |
41 | def get_rot_error(rot_1, rot_2, error_mode, chosen_axis=None, flip_axis=False):
42 | assert error_mode in ['radian', 'degree'], f"the rotation error mode {error_mode} is not supported!"
43 | if error_mode == 'radian':
44 | rot_error = rot_diff_rad(rot_1, rot_2, chosen_axis, flip_axis)
45 | else:
46 | rot_error = rot_diff_degree(rot_1, rot_2, chosen_axis, flip_axis)
47 | return rot_error
48 |
49 |
50 | def get_metrics_single_category(pose_1, pose_2, pose_mode, error_mode, chosen_axis=None, flip_axis=False, o2c_pose=False):
51 | assert pose_mode in ['quat_wxyz', 'quat_xyzw', 'euler_xyz', 'rot_matrix'],\
52 | f"the rotation mode {pose_mode} is not supported!"
53 |
54 | if pose_mode == 'rot_matrix':
55 | index = 6
56 | elif pose_mode == 'euler_xyz':
57 | index = 3
58 | else:
59 | index = 4
60 |
61 | rot_1 = pose_1[:, :index]
62 | rot_2 = pose_2[:, :index]
63 | trans_1 = pose_1[:, index:]
64 | trans_2 = pose_2[:, index:]
65 |
66 | rot_matrix_1 = get_rot_matrix(rot_1, pose_mode)
67 | rot_matrix_2 = get_rot_matrix(rot_2, pose_mode)
68 |
69 | if o2c_pose == False:
70 | rot_matrix_1, trans_1 = inverse_RT(rot_matrix_1, trans_1)
71 | rot_matrix_2, trans_2 = inverse_RT(rot_matrix_2, trans_2)
72 |
73 | rot_error = get_rot_error(rot_matrix_1, rot_matrix_2, error_mode, chosen_axis, flip_axis)
74 | trans_error = get_trans_error(trans_1, trans_2)
75 |
76 | return rot_error.cpu().numpy(), trans_error.cpu().numpy()
77 |
78 |
79 | def compute_RT_errors(RT_1, RT_2, class_id, handle_visibility, synset_names):
80 | """
81 | Args:
82 |         RT_1: [4, 4]. homogeneous affine transformation
83 |         RT_2: [4, 4]. homogeneous affine transformation
84 |
85 | Returns:
86 | theta: angle difference of R in degree
87 | shift: l2 difference of T in centimeter
88 | """
89 | # make sure the last row is [0, 0, 0, 1]
90 | if RT_1 is None or RT_2 is None:
91 | return -1
92 | try:
93 | assert np.array_equal(RT_1[3, :], RT_2[3, :])
94 | assert np.array_equal(RT_1[3, :], np.array([0, 0, 0, 1]))
95 | except AssertionError:
96 | print(RT_1[3, :], RT_2[3, :])
97 | exit()
98 |
99 | R1 = RT_1[:3, :3] / np.cbrt(np.linalg.det(RT_1[:3, :3]))
100 | T1 = RT_1[:3, 3]
101 | R2 = RT_2[:3, :3] / np.cbrt(np.linalg.det(RT_2[:3, :3]))
102 | T2 = RT_2[:3, 3]
103 | # symmetric when rotating around y-axis
104 | if synset_names[class_id] in ['bottle', 'can', 'bowl'] or \
105 | (synset_names[class_id] == 'mug' and handle_visibility == 0):
106 | y = np.array([0, 1, 0])
107 | y1 = R1 @ y
108 | y2 = R2 @ y
109 | cos_theta = y1.dot(y2) / (np.linalg.norm(y1) * np.linalg.norm(y2))
110 | else:
111 | R = R1 @ R2.transpose()
112 | cos_theta = (np.trace(R) - 1) / 2
113 |
114 | theta = np.arccos(np.clip(cos_theta, -1.0, 1.0)) * 180 / np.pi
115 | shift = np.linalg.norm(T1 - T2) * 100
116 | result = np.array([theta, shift])
117 |
118 | return result
119 |
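
A small worked example (the category list mirrors the usual NOCS ordering but is only illustrative here): identical rotations give 0°, and the translation error is reported in centimeters.

```python
# Illustrative: 1 cm translation offset, no rotation difference.
import numpy as np

synset_names = ['BG', 'bottle', 'bowl', 'camera', 'can', 'laptop', 'mug']  # assumed ordering
RT_gt = np.eye(4)
RT_pred = np.eye(4)
RT_pred[:3, 3] = [0.01, 0.0, 0.0]                 # 1 cm along x

theta_deg, shift_cm = compute_RT_errors(RT_pred, RT_gt, class_id=3,
                                        handle_visibility=1, synset_names=synset_names)
print(theta_deg, shift_cm)                        # ~0.0  ~1.0
```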
120 | '''
121 | def compute_RT_overlaps(gt_class_ids, gt_sRT, gt_handle_visibility, pred_class_ids, pred_sRT, synset_names):
122 | """ Finds overlaps between prediction and ground truth instances.
123 |
124 | Returns:
125 | overlaps:
126 |
127 | """
128 | num_pred = len(pred_class_ids)
129 | num_gt = len(gt_class_ids)
130 | overlaps = np.zeros((num_pred, num_gt, 2))
131 |
132 | for i in range(num_pred):
133 | for j in range(num_gt):
134 | overlaps[i, j, :] = compute_RT_errors(pred_sRT[i], gt_sRT[j], gt_class_ids[j],
135 | gt_handle_visibility[j], synset_names)
136 | return overlaps
137 |
138 | '''
139 |
140 |
141 | def compute_RT_overlaps(class_ids, gt_RT, pred_RT, gt_handle_visibility, synset_names):
142 |     """ Computes per-instance pose errors between predicted and ground-truth poses.
143 | 
144 |     Returns:
145 |         overlaps: [num, 2] array of (rotation error in degrees, translation error in cm)
146 |
147 | """
148 | num = len(class_ids)
149 | overlaps = np.zeros((num, 2))
150 |
151 | for i in range(num):
152 | overlaps[i, :] = compute_RT_errors(pred_RT[i], gt_RT[i], class_ids[i],
153 | gt_handle_visibility[i], synset_names)
154 | return overlaps
155 |
156 |
157 | def get_metrics(pose_1, pose_2, class_ids, synset_names, gt_handle_visibility, pose_mode, o2c_pose=False):
158 | assert pose_mode in ['quat_wxyz', 'quat_xyzw', 'euler_xyz', 'euler_xyz_sx_cx', 'rot_matrix'],\
159 | f"the rotation mode {pose_mode} is not supported!"
160 |
161 | index = get_pose_dim(pose_mode) - 3
162 |
163 | rot_1 = pose_1[:, :index]
164 | rot_2 = pose_2[:, :index]
165 | trans_1 = pose_1[:, index:]
166 | trans_2 = pose_2[:, index:]
167 |
168 | rot_matrix_1 = get_rot_matrix(rot_1, pose_mode)
169 | rot_matrix_2 = get_rot_matrix(rot_2, pose_mode)
170 |
171 | if o2c_pose == False:
172 | rot_matrix_1, trans_1 = inverse_RT(rot_matrix_1, trans_1)
173 | rot_matrix_2, trans_2 = inverse_RT(rot_matrix_2, trans_2)
174 |
175 | bs = pose_1.shape[0]
176 | RT_1 = torch.eye(4).unsqueeze(0).repeat([bs, 1, 1])
177 | RT_2 = torch.eye(4).unsqueeze(0).repeat([bs, 1, 1])
178 |
179 | RT_1[:, :3, :3] = rot_matrix_1
180 | RT_1[:, :3, 3] = trans_1
181 | RT_2[:, :3, :3] = rot_matrix_2
182 | RT_2[:, :3, 3] = trans_2
183 |
184 | error = compute_RT_overlaps(class_ids, RT_1.cpu().numpy(), RT_2.cpu().numpy(), gt_handle_visibility, synset_names)
185 | rot_error = error[:, 0]
186 | trans_error = error[:, 1]
187 | return rot_error, trans_error
188 |
189 |
190 | if __name__ == '__main__':
191 | gt_pose = torch.rand(8, 7)
192 | gt_pose[:, :4] /= torch.norm(gt_pose[:, :4], dim=-1, keepdim=True)
193 | noise_pose = gt_pose + torch.rand(8, 7) / 10
194 | noise_pose[:, :4] /= torch.norm(noise_pose[:, :4], dim=-1, keepdim=True)
195 |     rot_error = get_rot_error(get_rot_matrix(gt_pose[:, :4], 'quat_wxyz'), get_rot_matrix(noise_pose[:, :4], 'quat_wxyz'), 'degree')
196 | trans_error = get_trans_error(gt_pose[:, 4:], noise_pose[:, 4:])
197 | print(rot_error, trans_error)
198 |
199 |
200 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import copy
4 | import pytorch3d
5 | import pytorch3d.io
6 | import torch
7 | import torch.distributed as dist
8 | from ipdb import set_trace
9 | os.sys.path.append('..')
10 | from utils.genpose_utils import get_pose_dim
11 | from scipy.spatial.transform import Rotation as R
12 |
13 |
14 | def parallel_setup(rank, world_size, seed):
15 | os.environ['MASTER_ADDR'] = 'localhost'
16 | os.environ['MASTER_PORT'] = '12355'
17 |
18 | # initialize the process group
19 | dist.init_process_group("gloo", rank=rank, world_size=world_size)
20 |
21 | # Explicitly setting seed to make sure that models created in two processes
22 | # start from same random weights and biases.
23 | torch.manual_seed(seed)
24 |
25 |
26 | def parallel_cleanup():
27 | dist.destroy_process_group()
28 |
29 |
30 | def exists_or_mkdir(path):
31 | if not os.path.exists(path):
32 | os.makedirs(path)
33 | return False
34 | else:
35 | return True
36 |
37 |
38 | def depth2xyz(depth_img, camera_params):
39 | # scale camera parameters
40 | h, w = depth_img.shape
41 | scale_x = w / camera_params['xres']
42 | scale_y = h / camera_params['yres']
43 | fx = camera_params['fx'] * scale_x
44 | fy = camera_params['fy'] * scale_y
45 | x_offset = camera_params['cx'] * scale_x
46 | y_offset = camera_params['cy'] * scale_y
47 |
48 | indices = np.indices((h, w), dtype=np.float32).transpose(1,2,0)
49 | z_e = depth_img
50 | x_e = (indices[..., 1] - x_offset) * z_e / fx
51 | y_e = (indices[..., 0] - y_offset) * z_e / fy
52 | xyz_img = np.stack([x_e, y_e, z_e], axis=-1) # Shape: [H x W x 3]
53 | return xyz_img
54 |
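
A back-projection sketch (the intrinsics and depth map below are synthetic, not repo data): every pixel is lifted to a 3D point in the camera frame.

```python
# Illustrative: lift a flat 1 m depth map into an organized point cloud.
import numpy as np

camera_params = {'fx': 577.5, 'fy': 577.5, 'cx': 319.5, 'cy': 239.5,
                 'xres': 640, 'yres': 480}        # made-up pinhole intrinsics
depth = np.ones((480, 640), dtype=np.float32)     # 1 m everywhere
xyz = depth2xyz(depth, camera_params)
print(xyz.shape)          # (480, 640, 3)
print(xyz[240, 320])      # roughly [0, 0, 1] near the principal point
```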
55 |
56 | def fps_down_sample(vertices, num_point_sampled):
57 | # FPS down sample
58 | # vertices.shape = (N,3) or (N,2)
59 |
60 | N = len(vertices)
61 | n = num_point_sampled
62 | assert n <= N, "Num of sampled point should be less than or equal to the size of vertices."
63 | _G = np.mean(vertices, axis=0) # centroid of vertices
64 | _d = np.linalg.norm(vertices - _G, axis=1, ord=2)
65 | farthest = np.argmax(_d) # Select the point farthest from the center of gravity as the starting point
66 | distances = np.inf * np.ones((N,))
67 | flags = np.zeros((N,), np.bool_) # Whether the point is selected
68 | for i in range(n):
69 | flags[farthest] = True
70 | distances[farthest] = 0.
71 | p_farthest = vertices[farthest]
72 | dists = np.linalg.norm(vertices[~flags] - p_farthest, axis=1, ord=2)
73 | distances[~flags] = np.minimum(distances[~flags], dists)
74 | farthest = np.argmax(distances)
75 | return vertices[flags]
76 |
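
An illustrative call on random points: note that the sampled points come back in their original index order, not in selection order.

```python
# Illustrative: keep 128 of 1024 random points via farthest point sampling.
import numpy as np

pts = np.random.rand(1024, 3)
sampled = fps_down_sample(pts, num_point_sampled=128)
print(sampled.shape)      # (128, 3)
```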
77 |
78 | def sample_data(data, num_sample):
79 | """ data is in N x ...
80 | we want to keep num_samplexC of them.
81 | if N > num_sample, we will randomly keep num_sample of them.
82 | if N < num_sample, we will randomly duplicate samples.
83 | """
84 | N = data.shape[0]
85 | if (N == num_sample):
86 | return data, range(N)
87 | elif (N > num_sample):
88 | sample = np.random.choice(N, num_sample)
89 | return data[sample, ...], sample
90 | else:
91 | # print(N)
92 | sample = np.random.choice(N, num_sample-N)
93 | dup_data = data[sample, ...]
94 | return np.concatenate([data, dup_data], 0), list(range(N))+list(sample)
95 |
96 |
97 | def trans_form_quat_and_location(quaternion, location, quat_type='wxyz'):
98 | assert quat_type in ['wxyz', 'xyzw'], f"The type of quaternion {quat_type} is not supported!"
99 |
100 | if quat_type == 'xyzw':
101 | quaternion_xyzw = quaternion
102 | else:
103 | quaternion_xyzw = [quaternion[1], quaternion[2], quaternion[3], quaternion[0]]
104 |
105 | scipy_rot = R.from_quat(quaternion_xyzw)
106 | rot = scipy_rot.as_matrix()
107 |
108 | location = location[np.newaxis, :].T
109 | transformation = np.concatenate((rot, location), axis=1)
110 | transformation = np.concatenate((transformation, np.array([[0, 0, 0, 1]])), axis=0)
111 | return transformation
112 |
113 |
114 | def get_rot_matrix(batch_pose, pose_mode='quat_wxyz'):
115 | """
116 | pose_mode:
117 | 'quat_wxyz' -> batch_pose [B, 4]
118 | 'quat_xyzw' -> batch_pose [B, 4]
119 | 'euler_xyz' -> batch_pose [B, 3]
120 | 'rot_matrix' -> batch_pose [B, 6]
121 |
122 | Return: rot_matrix [B, 3, 3]
123 | """
124 | assert pose_mode in ['quat_wxyz', 'quat_xyzw', 'euler_xyz', 'euler_xyz_sx_cx', 'rot_matrix'],\
125 | f"the rotation mode {pose_mode} is not supported!"
126 |
127 | if pose_mode in ['quat_wxyz', 'quat_xyzw']:
128 | if pose_mode == 'quat_wxyz':
129 | quat_wxyz = batch_pose
130 | else:
131 | index = [3, 0, 1, 2]
132 | quat_wxyz = batch_pose[:, index]
133 | rot_mat = pytorch3d.transforms.quaternion_to_matrix(quat_wxyz)
134 |
135 | elif pose_mode == 'rot_matrix':
136 | rot_mat= pytorch3d.transforms.rotation_6d_to_matrix(batch_pose).permute(0, 2, 1)
137 |
138 | elif pose_mode == 'euler_xyz_sx_cx':
139 | rot_sin_theta = batch_pose[:, :3]
140 | rot_cos_theta = batch_pose[:, 3:6]
141 | theta = torch.atan2(rot_sin_theta, rot_cos_theta)
142 | rot_mat = pytorch3d.transforms.euler_angles_to_matrix(theta, 'ZYX')
143 | elif pose_mode == 'euler_xyz':
144 | rot_mat = pytorch3d.transforms.euler_angles_to_matrix(batch_pose, 'ZYX')
145 | else:
146 | raise NotImplementedError
147 |
148 | return rot_mat
149 |
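
A minimal example for the quaternion branch (pytorch3d is required, as imported above): the identity quaternion maps to the identity matrix.

```python
# Illustrative: convert a small batch of wxyz quaternions to rotation matrices.
import torch

quat_wxyz = torch.tensor([[1.0, 0.0, 0.0, 0.0],   # identity
                          [0.0, 1.0, 0.0, 0.0]])  # 180 deg about x
R = get_rot_matrix(quat_wxyz, pose_mode='quat_wxyz')
print(R.shape)            # torch.Size([2, 3, 3])
print(R[0])               # identity matrix
```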
150 |
151 | def transform_single_pts(pts, transformation):
152 | N = pts.shape[0]
153 | pts = np.concatenate((pts.T, np.ones(N)[np.newaxis, :]), axis=0)
154 | new_pts = transformation @ pts
155 | return new_pts.T[:, :3]
156 |
157 |
158 | def transform_batch_pts(batch_pts, batch_pose, pose_mode='quat_wxyz', inverse_pose=False):
159 | """
160 | Args:
161 | batch_pts [B, N, C], N is the number of points, and C [x, y, z, ...]
162 | batch_pose [B, C], [quat/rot_mat/euler, trans]
163 | pose_mode is from ['quat_wxyz', 'quat_xyzw', 'euler_xyz', 'rot_matrix']
164 | if inverse_pose is true, the transformation will be inversed
165 | Returns:
166 | new_pts [B, N, C]
167 | """
168 | assert pose_mode in ['quat_wxyz', 'quat_xyzw', 'euler_xyz', 'euler_xyz_sx_cx', 'rot_matrix'],\
169 | f"the rotation mode {pose_mode} is not supported!"
170 |
171 | B = batch_pts.shape[0]
172 | index = get_pose_dim(pose_mode) - 3
173 | rot = batch_pose[:, :index]
174 | loc = batch_pose[:, index:]
175 |
176 | rot_mat = get_rot_matrix(rot, pose_mode)
177 | if inverse_pose == True:
178 | rot_mat, loc = inverse_RT(rot_mat, loc)
179 | loc = loc[..., np.newaxis]
180 |
181 | trans_mat = torch.cat((rot_mat, loc), dim=2)
182 | trans_mat = torch.cat((trans_mat, torch.tile(torch.tensor([[0, 0, 0, 1]]).to(trans_mat.device), (B, 1, 1))), dim=1)
183 |
184 | new_pts = copy.deepcopy(batch_pts)
185 | padding = torch.ones([batch_pts.shape[0], batch_pts.shape[1], 1]).to(batch_pts.device)
186 | pts = torch.cat((batch_pts[:, :, :3], padding), dim=2)
187 | new_pts[:, :, :3] = torch.matmul(trans_mat.to(torch.float32), pts.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
188 |
189 | return new_pts
190 |
191 |
192 | def inverse_RT(batch_rot_mat, batch_trans):
193 | """
194 | Args:
195 | batch_rot_mat [B, 3, 3]
196 | batch_trans [B, 3]
197 | Return:
198 | inversed_rot_mat [B, 3, 3]
199 | inversed_trans [B, 3]
200 | """
201 | trans = batch_trans[..., np.newaxis]
202 | inversed_rot_mat = batch_rot_mat.permute(0, 2, 1)
203 | inversed_trans = - inversed_rot_mat @ trans
204 | return inversed_rot_mat, inversed_trans.squeeze(-1)
205 |
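
A round-trip sketch on random data: transforming a cloud by a pose and then by the same pose with `inverse_pose=True` (which goes through `inverse_RT`) should recover the original points.

```python
# Illustrative: forward + inverse transform is (numerically) the identity.
import torch

pts = torch.rand(2, 100, 3)                               # [B, N, 3]
quat = torch.randn(2, 4)
quat = quat / quat.norm(dim=-1, keepdim=True)             # wxyz, unit norm
pose = torch.cat([quat, torch.rand(2, 3)], dim=-1)        # [B, 7] = rotation + translation

pts_cam = transform_batch_pts(pts, pose, pose_mode='quat_wxyz')
pts_back = transform_batch_pts(pts_cam, pose, pose_mode='quat_wxyz', inverse_pose=True)
print(torch.allclose(pts, pts_back, atol=1e-5))           # True
```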
206 |
207 | """ https://arc.aiaa.org/doi/abs/10.2514/1.28949 """
208 | """ https://stackoverflow.com/questions/12374087/average-of-multiple-quaternions """
209 | """ http://tbirdal.blogspot.com/2019/10/i-allocate-this-post-to-providing.html """
210 | def average_quaternion_torch(Q, weights=None):
211 | if weights is None:
212 | weights = torch.ones(len(Q), device=Q.device) / len(Q)
213 | A = torch.zeros((4, 4), device=Q.device)
214 | weight_sum = torch.sum(weights)
215 |
216 | oriented_Q = ((Q[:, 0:1] > 0).float() - 0.5) * 2 * Q
217 | A = torch.einsum("bi,bk->bik", (oriented_Q, oriented_Q))
218 | A = torch.sum(torch.einsum("bij,b->bij", (A, weights)), 0)
219 | A /= weight_sum
220 |
221 | q_avg = torch.linalg.eigh(A)[1][:, -1]
222 | if q_avg[0] < 0:
223 | return -q_avg
224 | return q_avg
225 |
226 |
227 | def average_quaternion_batch(Q, weights=None):
228 | """calculate the average quaternion of the multiple quaternions
229 | Args:
230 | Q (tensor): [B, num_quaternions, 4]
231 | weights (tensor, optional): [B, num_quaternions]. Defaults to None.
232 |
233 | Returns:
234 | oriented_q_avg: average quaternion, [B, 4]
235 | """
236 |
237 | if weights is None:
238 | weights = torch.ones((Q.shape[0], Q.shape[1]), device=Q.device) / Q.shape[1]
239 | A = torch.zeros((Q.shape[0], 4, 4), device=Q.device)
240 | weight_sum = torch.sum(weights, axis=-1)
241 |
242 | oriented_Q = ((Q[:, :, 0:1] > 0).float() - 0.5) * 2 * Q
243 | A = torch.einsum("abi,abk->abik", (oriented_Q, oriented_Q))
244 | A = torch.sum(torch.einsum("abij,ab->abij", (A, weights)), 1)
245 | A /= weight_sum.reshape(A.shape[0], -1).unsqueeze(-1).repeat(1, 4, 4)
246 |
247 | q_avg = torch.linalg.eigh(A)[1][:, :, -1]
248 | oriented_q_avg = ((q_avg[:, 0:1] > 0).float() - 0.5) * 2 * q_avg
249 | return oriented_q_avg
250 |
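
A small check (sketch): q and -q encode the same rotation, so averaging them should return ±q.

```python
# Illustrative: sign-flipped duplicates average back to the original quaternion.
import torch

q = torch.tensor([[0.7071, 0.7071, 0.0, 0.0]])    # 90 deg about x, wxyz
Q = torch.stack([q, -q], dim=1)                   # [1, 2, 4]
print(average_quaternion_batch(Q))                # ~[[0.7071, 0.7071, 0, 0]]
```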
251 |
252 | def average_quaternion_numpy(Q, W=None):
253 | if W is not None:
254 | Q *= W[:, None]
255 | eigvals, eigvecs = np.linalg.eig(Q.T@Q)
256 | return eigvecs[:, eigvals.argmax()]
257 |
258 |
259 | def normalize_rotation(rotation, rotation_mode):
260 | if rotation_mode == 'quat_wxyz' or rotation_mode == 'quat_xyzw':
261 | rotation /= torch.norm(rotation, dim=-1, keepdim=True)
262 | elif rotation_mode == 'rot_matrix':
263 | rot_matrix = get_rot_matrix(rotation, rotation_mode)
264 | rotation[:, :3] = rot_matrix[:, :, 0]
265 | rotation[:, 3:6] = rot_matrix[:, :, 1]
266 | elif rotation_mode == 'euler_xyz_sx_cx':
267 | rot_sin_theta = rotation[:, :3]
268 | rot_cos_theta = rotation[:, 3:6]
269 | theta = torch.atan2(rot_sin_theta, rot_cos_theta)
270 | rotation[:, :3] = torch.sin(theta)
271 | rotation[:, 3:6] = torch.cos(theta)
272 | elif rotation_mode == 'euler_xyz':
273 | pass
274 | else:
275 | raise NotImplementedError
276 | return rotation
277 |
278 |
279 | if __name__ == '__main__':
280 | quat = torch.randn(2, 3, 4)
281 |     quat = quat / torch.linalg.norm(quat, dim=-1).unsqueeze(-1)
282 | quat = average_quaternion_batch(quat)
283 |
284 |
285 |
--------------------------------------------------------------------------------
/utils/so3_visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch
4 | import pytorch3d
5 | import cv2
6 |
7 | from matplotlib import rc
8 | from PIL import Image
9 | from pytorch3d import transforms
10 | from ipdb import set_trace
11 |
12 |
13 | rc("font", **{"family": "serif", "serif": ["Times New Roman"]})
14 | EYE = np.eye(3)
15 |
16 | def visualize_so3_probabilities(
17 | rotations,
18 | probabilities,
19 | rotations_gt=None,
20 | choosed_rotation=None,
21 | ax=None,
22 | fig=None,
23 | display_threshold_probability=0,
24 | to_image=True,
25 | show_color_wheel=True,
26 | canonical_rotation=EYE,
27 | gt_size=600,
28 | choosed_size=300,
29 | y_offset=-30,
30 | dpi=600,
31 | ):
32 | """
33 | Plot a single distribution on SO(3) using the tilt-colored method.
34 | Args:
35 | rotations: [N, 3, 3] tensor of rotation matrices
36 | probabilities: [N] tensor of probabilities
37 | rotations_gt: [N_gt, 3, 3] or [3, 3] ground truth rotation matrices
38 | ax: The matplotlib.pyplot.axis object to paint
39 | fig: The matplotlib.pyplot.figure object to paint
40 | display_threshold_probability: The probability threshold below which to omit
41 | the marker
42 | to_image: If True, return a tensor containing the pixels of the finished
43 | figure; if False return the figure itself
44 | show_color_wheel: If True, display the explanatory color wheel which matches
45 | color on the plot with tilt angle
46 | canonical_rotation: A [3, 3] rotation matrix representing the 'display
47 | rotation', to change the view of the distribution. It rotates the
48 | canonical axes so that the view of SO(3) on the plot is different, which
49 | can help obtain a more informative view.
50 | Returns:
51 | A matplotlib.pyplot.figure object, or a tensor of pixels if to_image=True.
52 | """
53 |
54 | def _show_single_marker(ax, rotation, marker, edgecolors=True, facecolors=False, s=gt_size):
55 | eulers = transforms.matrix_to_euler_angles(torch.tensor(rotation), "ZXY")
56 | eulers = eulers.numpy()
57 |
58 | tilt_angle = eulers[0]
59 | latitude = eulers[1]
60 | longitude = eulers[2]
61 |
62 | color = cmap(0.5 + tilt_angle / 2 / np.pi)
63 | ax.scatter(
64 | longitude,
65 | latitude,
66 | s=s,
67 | edgecolors=color if edgecolors else "none",
68 | facecolors=facecolors if facecolors else "none",
69 | marker=marker,
70 | linewidth=5,
71 | )
72 |
73 | if ax is None:
74 | fig = plt.figure(figsize=(4, 2), dpi=dpi)
75 | ax = fig.add_subplot(111, projection="mollweide")
76 | if rotations_gt is not None and len(rotations_gt.shape) == 2:
77 | rotations_gt = rotations_gt[None]
78 | if choosed_rotation is not None and len(choosed_rotation.shape) == 2:
79 | choosed_rotation = choosed_rotation[None]
80 | display_rotations = rotations @ canonical_rotation
81 | cmap = plt.cm.hsv
82 | scatterpoint_scaling = 4e3
83 | eulers_queries = transforms.matrix_to_euler_angles(
84 | torch.tensor(display_rotations), "ZXY"
85 | )
86 | eulers_queries = eulers_queries.numpy()
87 |
88 | tilt_angles = eulers_queries[:, 0]
89 | longitudes = eulers_queries[:, 2]
90 | latitudes = eulers_queries[:, 1]
91 |
92 | which_to_display = probabilities > display_threshold_probability
93 |
94 | if rotations_gt is not None:
95 | display_rotations_gt = rotations_gt @ canonical_rotation
96 |
97 | for rotation in display_rotations_gt:
98 | _show_single_marker(ax, rotation, "o")
99 | # Cover up the centers with white markers
100 | for rotation in display_rotations_gt:
101 | _show_single_marker(
102 | ax, rotation, "o", edgecolors=False, facecolors="#ffffff"
103 | )
104 |
105 | if choosed_rotation is not None:
106 | display_choosed_rotations = choosed_rotation @ canonical_rotation
107 |
108 | for rotation in display_choosed_rotations:
109 | _show_single_marker(ax, rotation, "o", s=choosed_size)
110 | # Cover up the centers with white markers
111 | for rotation in display_choosed_rotations:
112 | _show_single_marker(
113 | ax, rotation, "o", edgecolors=False, facecolors="#ffffff", s=choosed_size
114 | )
115 |
116 | # Display the distribution
117 | ax.scatter(
118 | longitudes[which_to_display],
119 | latitudes[which_to_display],
120 | s=scatterpoint_scaling * probabilities[which_to_display],
121 | c=cmap(0.5 + tilt_angles[which_to_display] / 2.0 / np.pi),
122 | marker='.'
123 | )
124 |
125 | yticks = np.array([-60, -30, 0, 30, 60])
126 | yticks_minor = np.arange(-75, 90, 15)
127 | ax.set_yticks(yticks_minor * np.pi / 180, minor=True)
128 | ax.set_yticks(yticks * np.pi / 180, [f"{y}°" for y in yticks], fontsize=10)
129 | xticks = np.array([-90, 0, 90])
130 | xticks_minor = np.arange(-150, 180, 30)
131 | ax.set_xticks(xticks * np.pi / 180, [])
132 | ax.set_xticks(xticks_minor * np.pi / 180, minor=True)
133 |
134 | for xtick in xticks:
135 | # Manually set xticks
136 | x = xtick * np.pi / 180
137 | y = y_offset * np.pi / 180
138 | ax.text(x, y, f"{xtick}°", ha="center", va="center", fontsize=10)
139 |
140 | ax.grid(which="minor")
141 | ax.grid(which="major")
142 |
143 | if show_color_wheel:
144 | # Add a color wheel showing the tilt angle to color conversion.
145 | ax = fig.add_axes([0.85, 0.12, 0.12, 0.12], projection="polar")
146 | theta = np.linspace(-3 * np.pi / 2, np.pi / 2, 200)
147 | radii = np.linspace(0.4, 0.5, 2)
148 | _, theta_grid = np.meshgrid(radii, theta)
149 | colormap_val = 0.5 + theta_grid / np.pi / 2.0
150 | ax.pcolormesh(theta, radii, colormap_val.T, cmap=cmap, shading="auto")
151 | ax.set_yticklabels([])
152 | ax.set_xticks(np.arange(0, 2 * np.pi, np.pi / 2))
153 | ax.set_xticklabels(
154 | [
155 | r"90$\degree$",
156 | r"180$\degree$",
157 | r"270$\degree$",
158 | r"0$\degree$",
159 | ],
160 | fontsize=6,
161 | )
162 | ax.spines["polar"].set_visible(False)
163 | ax.grid(False)
164 | plt.text(
165 | 0.5,
166 | 0.5,
167 | "Roll",
168 | fontsize=6,
169 | horizontalalignment="center",
170 | verticalalignment="center",
171 | transform=ax.transAxes,
172 | )
173 |
174 | if to_image:
175 | return plot_to_image(fig)
176 | else:
177 | return fig
178 |
179 |
180 | def plot_to_image(fig):
181 | fig.canvas.draw()
182 | image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
183 | image_from_plot = image_from_plot.reshape(
184 | fig.canvas.get_width_height()[::-1] + (3,)
185 | )
186 | plt.close(fig)
187 | return image_from_plot
188 |
189 |
190 | def antialias(image, level=1):
191 | is_numpy = isinstance(image, np.ndarray)
192 | if is_numpy:
193 | image = Image.fromarray(image)
194 | for _ in range(level):
195 | size = np.array(image.size) // 2
196 | image = image.resize(size, Image.LANCZOS)
197 | if is_numpy:
198 | image = np.array(image)
199 | return image
200 |
201 |
202 | def unnormalize_image(image):
203 | if isinstance(image, torch.Tensor):
204 | image = image.cpu().numpy()
205 | if image.shape[0] == 3:
206 | image = image.transpose(1, 2, 0)
207 | mean = np.array([0.485, 0.456, 0.406])
208 | std = np.array([0.229, 0.224, 0.225])
209 | image = image * std + mean
210 | return (image * 255.0).astype(np.uint8)
211 |
212 |
213 | def visualize_so3(save_path, pred_rotations, gt_rotation, pred_rotation=None, probabilities=None, image=None):
214 |     if image is None:
215 | fig = plt.figure(figsize=(5, 2), dpi=600)
216 | gs = fig.add_gridspec(1, 2)
217 | ax = fig.add_subplot(gs[0, :], projection="mollweide")
218 | else:
219 | fig = plt.figure(figsize=(5, 4), dpi=600)
220 | gs = fig.add_gridspec(2, 2)
221 | ax = fig.add_subplot(gs[1, :], projection="mollweide")
222 | bx = fig.add_subplot(gs[0, :])
223 | bx.imshow(image.permute(1, 2, 0).cpu().numpy())
224 | bx.axis("off")
225 | # rotations = np.concatenate((pred_rotations, pred_rotation, gt_rotation), axis=0)
226 | rotations = pred_rotations
227 | if probabilities is None:
228 | probabilities = np.ones(rotations.shape[0])/2000
229 | # probabilities[-2] = 0.002
230 | # probabilities[-1] = 0.003
231 |
232 | so3_vis = visualize_so3_probabilities(
233 | rotations_gt=gt_rotation,
234 | choosed_rotation=pred_rotation,
235 | rotations=rotations,
236 | probabilities=probabilities,
237 | to_image=False,
238 | display_threshold_probability=0.00001,
239 | show_color_wheel=True,
240 | fig=fig,
241 | ax=ax,
242 | )
243 | plt.savefig(save_path)
244 |
245 |
246 | if __name__ == '__main__':
247 | fig = plt.figure(figsize=(4, 4), dpi=300)
248 | gs = fig.add_gridspec(2, 2)
249 | ax1 = fig.add_subplot(gs[0, 0])
250 | # ax1.imshow(unnormalize_image(images[0].cpu().numpy().transpose(1, 2, 0)))
251 | ax1.axis("off")
252 | ax2 = fig.add_subplot(gs[0, 1])
253 | # ax2.imshow(unnormalize_image(images[1].cpu().numpy().transpose(1, 2, 0)))
254 | ax2.axis("off")
255 | ax3 = fig.add_subplot(gs[1, :], projection="mollweide")
256 |
257 | rotations = pytorch3d.transforms.random_rotations(10).cpu().numpy()
258 | probabilities = np.ones(10)/1000
259 | print(probabilities)
260 | so3_vis = visualize_so3_probabilities(
261 | rotations=rotations,
262 | probabilities=probabilities,
263 | to_image=False,
264 | display_threshold_probability=0.0005,
265 | show_color_wheel=True,
266 | fig=fig,
267 | ax=ax3,
268 | )
269 | plt.savefig('./test.jpg')
270 | # plt.show()
271 |
272 |
--------------------------------------------------------------------------------
/utils/tracking_utils.py:
--------------------------------------------------------------------------------
1 | ''' modified from CAPTRA https://github.com/HalfSummer11/CAPTRA/tree/5d7d088c3de49389a90b5fae280e96409e7246c6 '''
2 |
3 | import torch
4 | import copy
5 | import math
6 | from ipdb import set_trace
7 |
8 | def normalize(q):
9 | assert q.shape[-1] == 4
10 | norm = q.norm(dim=-1, keepdim=True)
11 | return q.div(norm)
12 |
13 |
14 | def matrix_to_unit_quaternion(matrix):
15 | assert matrix.shape[-1] == matrix.shape[-2] == 3
16 | if not isinstance(matrix, torch.Tensor):
17 | matrix = torch.tensor(matrix)
18 |
19 | trace = 1 + matrix[..., 0, 0] + matrix[..., 1, 1] + matrix[..., 2, 2]
20 | trace = torch.clamp(trace, min=0.)
21 | r = torch.sqrt(trace)
22 | s = 1.0 / (2 * r + 1e-7)
23 | w = 0.5 * r
24 | x = (matrix[..., 2, 1] - matrix[..., 1, 2])*s
25 | y = (matrix[..., 0, 2] - matrix[..., 2, 0])*s
26 | z = (matrix[..., 1, 0] - matrix[..., 0, 1])*s
27 |
28 | q = torch.stack((w, x, y, z), dim=-1)
29 |
30 | return normalize(q)
31 |
32 |
33 | def generate_random_quaternion(quaternion_shape):
34 | assert quaternion_shape[-1] == 4
35 | rand_norm = torch.randn(quaternion_shape)
36 | rand_q = normalize(rand_norm)
37 | return rand_q
38 |
39 |
40 | def jitter_quaternion(q, theta): #[Bs, 4], [Bs, 1]
41 | new_q = generate_random_quaternion(q.shape).to(q.device)
42 | dot_product = torch.sum(q*new_q, dim=-1, keepdim=True) #
43 | shape = (tuple(1 for _ in range(len(dot_product.shape) - 1)) + (4, ))
44 | q_orthogonal = normalize(new_q - q * dot_product.repeat(*shape))
45 | # theta = 2arccos(|p.dot(q)|)
46 | # |p.dot(q)| = cos(theta/2)
47 | tile_theta = theta.repeat(shape)
48 | jittered_q = q*torch.cos(tile_theta/2) + q_orthogonal*torch.sin(tile_theta/2)
49 |
50 | return jittered_q
51 |
52 |
53 | def assert_normalized(q, atol=1e-3):
54 | assert q.shape[-1] == 4
55 | norm = q.norm(dim=-1)
56 | norm_check = (norm - 1.0).abs()
57 | try:
58 | assert torch.max(norm_check) < atol
59 | except:
60 | print("normalization failure: {}.".format(torch.max(norm_check)))
61 | return -1
62 | return 0
63 |
64 |
65 | def unit_quaternion_to_matrix(q):
66 | assert_normalized(q)
67 | w, x, y, z= torch.unbind(q, dim=-1)
68 | matrix = torch.stack(( 1 - 2*y*y - 2*z*z, 2*x*y - 2*z*w, 2*x*z + 2*y* w,
69 | 2*x*y + 2*z*w, 1 - 2*x*x - 2*z*z, 2*y*z - 2*x*w,
70 | 2*x*z - 2*y*w, 2*y*z + 2*x*w, 1 - 2*x*x -2*y*y),
71 | dim=-1)
72 | matrix_shape = list(matrix.shape)[:-1]+[3,3]
73 | return matrix.view(matrix_shape).contiguous()
74 |
75 |
76 | def noisy_rot_matrix(matrix, rad, type='normal'):
77 | if type == 'normal':
78 | theta = torch.abs(torch.randn_like(matrix[..., 0, 0])) * rad
79 | elif type == 'uniform':
80 | theta = torch.rand_like(matrix[..., 0, 0]) * rad
81 | quater = matrix_to_unit_quaternion(matrix)
82 | new_quater = jitter_quaternion(quater, theta.unsqueeze(-1))
83 | new_mat = unit_quaternion_to_matrix(new_quater)
84 | return new_mat
85 |
86 |
87 | def add_noise_to_RT(RT, type='normal', r=5.0, t=0.03):
88 | rand_type = type # 'uniform' or 'normal' --> we use 'normal'
89 |
90 | def random_tensor(base):
91 | if rand_type == 'uniform':
92 | return torch.rand_like(base) * 2.0 - 1.0
93 | elif rand_type == 'normal':
94 | return torch.randn_like(base)
95 | new_RT = copy.deepcopy(RT)
96 | new_RT[:, :3, :3] = noisy_rot_matrix(RT[:, :3, :3], r/180*math.pi, type=rand_type).reshape(RT[:, :3, :3].shape)
97 | norm = random_tensor(RT[:, 0, 0]) * t # [B, P]
98 | direction = random_tensor(RT[:, :3, 3].squeeze(-1)) # [B, P, 3]
99 | direction = direction / torch.clamp(direction.norm(dim=-1, keepdim=True), min=1e-9) # [B, P, 3] unit vecs
100 | new_RT[:, :3, 3] = RT[:, :3, 3] + (direction * norm.unsqueeze(-1)) # [B, P, 3, 1]
101 |
102 | return new_RT
103 |
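
A usage sketch (values illustrative): perturb ground-truth poses, e.g. to simulate a noisy previous-frame estimate for tracking.

```python
# Illustrative: add ~5 deg rotation and ~3 cm translation noise to identity poses.
import torch

RT = torch.eye(4).unsqueeze(0).repeat(8, 1, 1)    # [8, 4, 4]
noisy_RT = add_noise_to_RT(RT, type='normal', r=5.0, t=0.03)
print(noisy_RT.shape)                             # torch.Size([8, 4, 4])
```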
104 |
--------------------------------------------------------------------------------