├── .gitignore ├── LICENSE ├── README.md ├── command_test.sh ├── command_train.sh ├── dataset ├── generate_graspness.py ├── graspnet_dataset.py ├── simplify_dataset.py └── vis_graspness.py ├── doc ├── example_data │ ├── color.png │ ├── demo_result.png │ ├── depth.png │ ├── meta.mat │ └── workspace_mask.png └── teaser.png ├── infer_vis_grasp.py ├── knn ├── knn_modules.py ├── setup.py └── src │ ├── cpu │ ├── knn_cpu.cpp │ └── vision.h │ ├── cuda │ ├── knn.cu │ └── vision.h │ ├── knn.h │ └── vision.cpp ├── models ├── backbone_resunet14.py ├── graspnet.py ├── loss.py ├── modules.py └── resnet.py ├── pointnet2 ├── _ext_src │ ├── include │ │ ├── ball_query.h │ │ ├── cuda_utils.h │ │ ├── cylinder_query.h │ │ ├── group_points.h │ │ ├── interpolate.h │ │ ├── sampling.h │ │ └── utils.h │ └── src │ │ ├── ball_query.cpp │ │ ├── ball_query_gpu.cu │ │ ├── bindings.cpp │ │ ├── cylinder_query.cpp │ │ ├── cylinder_query_gpu.cu │ │ ├── group_points.cpp │ │ ├── group_points_gpu.cu │ │ ├── interpolate.cpp │ │ ├── interpolate_gpu.cu │ │ ├── sampling.cpp │ │ └── sampling_gpu.cu ├── pointnet2_modules.py ├── pointnet2_utils.py ├── pytorch_utils.py └── setup.py ├── requirements.txt ├── test.py ├── train.py └── utils ├── collision_detector.py ├── data_utils.py ├── label_generation.py └── loss_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/** 2 | *.ipynb 3 | **/.ipynb_checkpoints/** 4 | *.npy 5 | *.npz 6 | **/.vscode/** 7 | **/grasp_label*/** 8 | **/log*/** 9 | **/dump*/** 10 | **/build/** 11 | *.o 12 | *.so 13 | *.egg 14 | **/*.egg-info/** 15 | logs 16 | dataset/tolerance 17 | **/.idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GRASPNET-BASELINE 2 | SOFTWARE LICENSE AGREEMENT 3 | ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY 4 | 5 | BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE. 6 | 7 | This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Shanghai Jiao Tong University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor. 8 | 9 | RESERVATION OF OWNERSHIP AND GRANT OF LICENSE: 10 | Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive, 11 | non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i). 
12 | 13 | CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication. 14 | 15 | PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto. 16 | 17 | DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement. 18 | 19 | BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies. 20 | 21 | USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “AlphaPose", "Shanghai Jiao Tong" or any renditions thereof without the prior written permission of Licensor. 22 | 23 | You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software. 24 | 25 | ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void. 26 | 27 | TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below. 28 | 29 | The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement. 30 | 31 | FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement. 32 | 33 | DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. 
LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS. 34 | 35 | SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement. 36 | 37 | EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage. 38 | 39 | EXPORT REGULATION: Licensee agrees to comply with any and all applicable 40 | U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control. 41 | 42 | SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby. 43 | 44 | NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor. 45 | 46 | ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto. 47 | 48 | 49 | 50 | ************************************************************************ 51 | 52 | THIRD-PARTY SOFTWARE NOTICES AND INFORMATION 53 | 54 | This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise. 55 | 56 | 1. PyTorch (https://github.com/pytorch/pytorch) 57 | 58 | THIRD-PARTY SOFTWARE NOTICES AND INFORMATION 59 | 60 | This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise. 61 | 62 | From PyTorch: 63 | 64 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 65 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 66 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 67 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 68 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 69 | Copyright (c) 2011-2013 NYU (Clement Farabet) 70 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 71 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 72 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 73 | 74 | From Caffe2: 75 | 76 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 77 | 78 | All contributions by Facebook: 79 | Copyright (c) 2016 Facebook Inc. 80 | 81 | All contributions by Google: 82 | Copyright (c) 2015 Google Inc. 83 | All rights reserved. 84 | 85 | All contributions by Yangqing Jia: 86 | Copyright (c) 2015 Yangqing Jia 87 | All rights reserved. 
88 | 89 | All contributions by Kakao Brain: 90 | Copyright 2019-2020 Kakao Brain 91 | 92 | All contributions from Caffe: 93 | Copyright(c) 2013, 2014, 2015, the respective contributors 94 | All rights reserved. 95 | 96 | All other contributions: 97 | Copyright(c) 2015, 2016 the respective contributors 98 | All rights reserved. 99 | 100 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 101 | copyright over their contributions to Caffe2. The project versioning records 102 | all such contribution and copyright details. If a contributor wants to further 103 | mark their specific copyright on a particular contribution, they should 104 | indicate their copyright solely in the commit message of the change when it is 105 | committed. 106 | 107 | All rights reserved. 108 | 109 | Redistribution and use in source and binary forms, with or without 110 | modification, are permitted provided that the following conditions are met: 111 | 112 | 1. Redistributions of source code must retain the above copyright 113 | notice, this list of conditions and the following disclaimer. 114 | 115 | 2. Redistributions in binary form must reproduce the above copyright 116 | notice, this list of conditions and the following disclaimer in the 117 | documentation and/or other materials provided with the distribution. 118 | 119 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 120 | and IDIAP Research Institute nor the names of its contributors may be 121 | used to endorse or promote products derived from this software without 122 | specific prior written permission. 123 | 124 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 125 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 126 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 127 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 128 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 129 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 130 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 131 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 132 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 133 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 134 | POSSIBILITY OF SUCH DAMAGE. 135 | 136 | 2. VoteNet (https://github.com/facebookresearch/votenet) 137 | 138 | MIT License 139 | 140 | Copyright (c) Facebook, Inc. and its affiliates. 141 | 142 | Permission is hereby granted, free of charge, to any person obtaining a copy 143 | of this software and associated documentation files (the "Software"), to deal 144 | in the Software without restriction, including without limitation the rights 145 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 146 | copies of the Software, and to permit persons to whom the Software is 147 | furnished to do so, subject to the following conditions: 148 | 149 | The above copyright notice and this permission notice shall be included in all 150 | copies or substantial portions of the Software. 151 | 152 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 153 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 154 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 155 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 156 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 157 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 158 | SOFTWARE. 159 | 160 | ************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION********** 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GraspNet graspness 2 | My implementation of the paper "Graspness Discovery in Clutters for Fast and Accurate Grasp Detection" (ICCV 2021). 3 | 4 | [[paper](https://openaccess.thecvf.com/content/ICCV2021/papers/Wang_Graspness_Discovery_in_Clutters_for_Fast_and_Accurate_Grasp_Detection_ICCV_2021_paper.pdf)] 5 | [[dataset](https://graspnet.net/)] 6 | [[API](https://github.com/graspnet/graspnetAPI)] 7 | 8 | 9 | ## Requirements 10 | - Python 3 11 | - PyTorch 1.8 12 | - Open3d 0.8 13 | - TensorBoard 2.3 14 | - NumPy 15 | - SciPy 16 | - Pillow 17 | - tqdm 18 | - MinkowskiEngine 19 | 20 | ## Installation 21 | Get the code. 22 | ```bash 23 | git clone https://github.com/rhett-chen/graspness_implementation.git 24 | cd graspnet-graspness 25 | ``` 26 | Install packages via Pip. 27 | ```bash 28 | pip install -r requirements.txt 29 | ``` 30 | Compile and install pointnet2 operators (code adapted from [votenet](https://github.com/facebookresearch/votenet)). 31 | ```bash 32 | cd pointnet2 33 | python setup.py install 34 | ``` 35 | Compile and install the knn operator (code adapted from [pytorch_knn_cuda](https://github.com/chrischoy/pytorch_knn_cuda)). 36 | ```bash 37 | cd knn 38 | python setup.py install 39 | ``` 40 | Install graspnetAPI for evaluation. 41 | ```bash 42 | git clone https://github.com/graspnet/graspnetAPI.git 43 | cd graspnetAPI 44 | pip install . 45 | ``` 46 | For MinkowskiEngine, please refer to https://github.com/NVIDIA/MinkowskiEngine 47 | ## Point level Graspness Generation 48 | Point-level graspness labels are not included in the original dataset and need to be generated separately. Make sure you have downloaded the original dataset from [GraspNet](https://graspnet.net/). The generation code is in [dataset/generate_graspness.py](dataset/generate_graspness.py). 49 | ```bash 50 | cd dataset 51 | python generate_graspness.py --dataset_root /data3/graspnet --camera_type kinect 52 | ``` 53 | 54 | ## Simplify dataset 55 | The original grasp_label files contain redundant data; simplifying them significantly reduces the memory cost. The code is in [dataset/simplify_dataset.py](dataset/simplify_dataset.py). 56 | ```bash 57 | cd dataset 58 | python simplify_dataset.py --dataset_root /data3/graspnet 59 | ``` 60 | 61 | ## Training and Testing 62 | Training examples are shown in [command_train.sh](command_train.sh). `--dataset_root`, `--camera` and `--log_dir` should be specified according to your settings. You can use TensorBoard to visualize the training process. 63 | 64 | Testing examples are shown in [command_test.sh](command_test.sh), which contains inference and result evaluation. `--dataset_root`, `--camera`, `--checkpoint_path` and `--dump_dir` should be specified according to your settings. Set `--collision_thresh` to -1 for fast inference. 65 |
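For reference, the sketch below shows how the pieces defined in `dataset/graspnet_dataset.py` (`GraspNetDataset`, `load_grasp_labels`, `minkowski_collate_fn`) can be wired into a standard PyTorch `DataLoader`. This is only a minimal illustration, not the actual training entry point — `train.py` is authoritative, and the dataset path, batch size and worker count below are placeholders.
```python
# Minimal sketch (assumptions: run from the repository root, graspness labels already
# generated and grasp labels simplified as described above; train.py may differ in details).
import os
import sys
from torch.utils.data import DataLoader

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(ROOT_DIR, 'utils'))  # graspnet_dataset.py imports data_utils directly

from dataset.graspnet_dataset import GraspNetDataset, load_grasp_labels, minkowski_collate_fn

dataset_root = '/data3/graspnet'                # placeholder path, same as in the commands above
grasp_labels = load_grasp_labels(dataset_root)  # reads grasp_label_simplified/*.npz
train_dataset = GraspNetDataset(dataset_root, grasp_labels=grasp_labels, camera='kinect',
                                split='train', num_points=20000, voxel_size=0.005,
                                remove_outlier=True, augment=True, load_label=True)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4,
                          collate_fn=minkowski_collate_fn)

for batch in train_loader:
    # 'coors', 'feats' and 'quantize2original' feed the MinkowskiEngine backbone;
    # the remaining keys (graspness/objectness labels, grasp label lists) supervise training.
    print(batch['coors'].shape, batch['feats'].shape)
    break
```
The collate function voxelizes each cloud with `ME.utils.sparse_quantize`, so a batch carries both the quantized coordinates for the sparse backbone and `quantize2original` for mapping per-voxel features back to the full sampled point set.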
66 | ## Results 67 | "In repo" rows report the performance of my trained model, evaluated without collision detection. 68 | 69 | Evaluation results on the Kinect camera: 70 | | | | Seen | | | Similar | | | Novel | | 71 | |:--------:|:------:|:----------------:|:----------------:|:------:|:----------------:|:----------------:|:------:|:----------------:|:----------------:| 72 | | | __AP__ | AP0.8 | AP0.4 | __AP__ | AP0.8 | AP0.4 | __AP__ | AP0.8 | AP0.4 | 73 | | In paper | 61.19 | 71.46 | 56.04 | 47.39 | 56.78 | 40.43 | 19.01 | 23.73 | 10.60 | 74 | | In repo | 61.83 | 73.28 | 54.14 | 51.13 | 62.53 | 41.57 | 19.94 | 24.90 | 11.02 | 75 | 76 | 77 | ## Troubleshooting 78 | If you encounter the torch.floor error in MinkowskiEngine, you can work around it by editing the MinkowskiEngine source: 79 | in `MinkowskiEngine/utils/quantization.py` (around line 262), change `discrete_coordinates = _auto_floor(coordinates)` to `discrete_coordinates = coordinates`. 80 | ## Acknowledgement 81 | My code is mainly based on [graspnet-baseline](https://github.com/graspnet/graspnet-baseline). 82 | -------------------------------------------------------------------------------- /command_test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=4 python test.py --camera kinect --dump_dir logs/log_kn/dump_epoch10 --checkpoint_path logs/log_kn/minkresunet_epoch10.tar --batch_size 1 --dataset_root /data3/graspnet --infer --eval --collision_thresh -1 -------------------------------------------------------------------------------- /command_train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=4 python train.py --camera kinect --log_dir logs/log_kn --batch_size 4 --learning_rate 0.001 --model_name minkuresunet --dataset_root /data3/graspnet -------------------------------------------------------------------------------- /dataset/generate_graspness.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from PIL import Image 4 | import scipy.io as scio 5 | import sys 6 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 7 | sys.path.append(ROOT_DIR) 8 | from utils.data_utils import get_workspace_mask, CameraInfo, create_point_cloud_from_depth_image 9 | from knn.knn_modules import knn 10 | import torch 11 | from graspnetAPI.utils.xmlhandler import xmlReader 12 | from graspnetAPI.utils.utils import get_obj_pose_list, transform_points 13 | import argparse 14 | 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--dataset_root', default=None, required=True) 18 | parser.add_argument('--camera_type', default='kinect', help='Camera split [realsense/kinect]') 19 | 20 | 21 | if __name__ == '__main__': 22 | cfgs = parser.parse_args() 23 | dataset_root = cfgs.dataset_root # set dataset root 24 | camera_type = cfgs.camera_type # kinect / realsense 25 | save_path_root = os.path.join(dataset_root, 'graspness') 26 | 27 | num_views, num_angles, num_depths = 300, 12, 4 28 | fric_coef_thresh = 0.8 29 | point_grasp_num = num_views * num_angles * num_depths 30 | for scene_id in range(100): 31 | save_path = os.path.join(save_path_root, 'scene_' + str(scene_id).zfill(4), camera_type) 32 | if not os.path.exists(save_path): 33 | os.makedirs(save_path) 34 | labels = np.load( 35 | os.path.join(dataset_root, 'collision_label', 'scene_' + str(scene_id).zfill(4), 'collision_labels.npz')) 36 | collision_dump = [] 37 | for j in range(len(labels)): 38 | collision_dump.append(labels['arr_{}'.format(j)]) 39 | 40 | for ann_id in range(256): 41 | # get scene point cloud 42 |
print('generating scene: {} ann: {}'.format(scene_id, ann_id)) 43 | depth = np.array(Image.open(os.path.join(dataset_root, 'scenes', 'scene_' + str(scene_id).zfill(4), 44 | camera_type, 'depth', str(ann_id).zfill(4) + '.png'))) 45 | seg = np.array(Image.open(os.path.join(dataset_root, 'scenes', 'scene_' + str(scene_id).zfill(4), 46 | camera_type, 'label', str(ann_id).zfill(4) + '.png'))) 47 | meta = scio.loadmat(os.path.join(dataset_root, 'scenes', 'scene_' + str(scene_id).zfill(4), 48 | camera_type, 'meta', str(ann_id).zfill(4) + '.mat')) 49 | intrinsic = meta['intrinsic_matrix'] 50 | factor_depth = meta['factor_depth'] 51 | camera = CameraInfo(1280.0, 720.0, intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2], 52 | factor_depth) 53 | cloud = create_point_cloud_from_depth_image(depth, camera, organized=True) 54 | 55 | # remove outlier and get objectness label 56 | depth_mask = (depth > 0) 57 | camera_poses = np.load(os.path.join(dataset_root, 'scenes', 'scene_' + str(scene_id).zfill(4), 58 | camera_type, 'camera_poses.npy')) 59 | camera_pose = camera_poses[ann_id] 60 | align_mat = np.load(os.path.join(dataset_root, 'scenes', 'scene_' + str(scene_id).zfill(4), 61 | camera_type, 'cam0_wrt_table.npy')) 62 | trans = np.dot(align_mat, camera_pose) 63 | workspace_mask = get_workspace_mask(cloud, seg, trans=trans, organized=True, outlier=0.02) 64 | mask = (depth_mask & workspace_mask) 65 | cloud_masked = cloud[mask] 66 | objectness_label = seg[mask] 67 | 68 | # get scene object and grasp info 69 | scene_reader = xmlReader(os.path.join(dataset_root, 'scenes', 'scene_' + str(scene_id).zfill(4), 70 | camera_type, 'annotations', '%04d.xml' % ann_id)) 71 | pose_vectors = scene_reader.getposevectorlist() 72 | obj_list, pose_list = get_obj_pose_list(camera_pose, pose_vectors) 73 | grasp_labels = {} 74 | for i in obj_list: 75 | file = np.load(os.path.join(dataset_root, 'grasp_label', '{}_labels.npz'.format(str(i).zfill(3)))) 76 | grasp_labels[i] = (file['points'].astype(np.float32), file['offsets'].astype(np.float32), 77 | file['scores'].astype(np.float32)) 78 | 79 | grasp_points = [] 80 | grasp_points_graspness = [] 81 | for i, (obj_idx, trans_) in enumerate(zip(obj_list, pose_list)): 82 | sampled_points, offsets, fric_coefs = grasp_labels[obj_idx] 83 | collision = collision_dump[i] # Npoints * num_views * num_angles * num_depths 84 | num_points = sampled_points.shape[0] 85 | 86 | valid_grasp_mask = ((fric_coefs <= fric_coef_thresh) & (fric_coefs > 0) & ~collision) 87 | valid_grasp_mask = valid_grasp_mask.reshape(num_points, -1) 88 | graspness = np.sum(valid_grasp_mask, axis=1) / point_grasp_num 89 | target_points = transform_points(sampled_points, trans_) 90 | target_points = transform_points(target_points, np.linalg.inv(camera_pose)) # fix bug 91 | grasp_points.append(target_points) 92 | grasp_points_graspness.append(graspness.reshape(num_points, 1)) 93 | grasp_points = np.vstack(grasp_points) 94 | grasp_points_graspness = np.vstack(grasp_points_graspness) 95 | 96 | grasp_points = torch.from_numpy(grasp_points).cuda() 97 | grasp_points_graspness = torch.from_numpy(grasp_points_graspness).cuda() 98 | grasp_points = grasp_points.transpose(0, 1).contiguous().unsqueeze(0) 99 | 100 | masked_points_num = cloud_masked.shape[0] 101 | cloud_masked_graspness = np.zeros((masked_points_num, 1)) 102 | part_num = int(masked_points_num / 10000) 103 | for i in range(1, part_num + 2): # lack of cuda memory 104 | if i == part_num + 1: 105 | cloud_masked_partial = cloud_masked[10000 * part_num:] 106 | if 
len(cloud_masked_partial) == 0: 107 | break 108 | else: 109 | cloud_masked_partial = cloud_masked[10000 * (i - 1):(i * 10000)] 110 | cloud_masked_partial = torch.from_numpy(cloud_masked_partial).cuda() 111 | cloud_masked_partial = cloud_masked_partial.transpose(0, 1).contiguous().unsqueeze(0) 112 | nn_inds = knn(grasp_points, cloud_masked_partial, k=1).squeeze() - 1 113 | cloud_masked_graspness[10000 * (i - 1):(i * 10000)] = torch.index_select( 114 | grasp_points_graspness, 0, nn_inds).cpu().numpy() 115 | 116 | max_graspness = np.max(cloud_masked_graspness) 117 | min_graspness = np.min(cloud_masked_graspness) 118 | cloud_masked_graspness = (cloud_masked_graspness - min_graspness) / (max_graspness - min_graspness) 119 | 120 | np.save(os.path.join(save_path, str(ann_id).zfill(4) + '.npy'), cloud_masked_graspness) 121 | -------------------------------------------------------------------------------- /dataset/graspnet_dataset.py: -------------------------------------------------------------------------------- 1 | """ GraspNet dataset processing. 2 | Author: chenxi-wang 3 | """ 4 | 5 | import os 6 | import numpy as np 7 | import scipy.io as scio 8 | from PIL import Image 9 | 10 | import torch 11 | import collections.abc as container_abcs 12 | from torch.utils.data import Dataset 13 | from tqdm import tqdm 14 | import MinkowskiEngine as ME 15 | from data_utils import CameraInfo, transform_point_cloud, create_point_cloud_from_depth_image, get_workspace_mask 16 | 17 | 18 | class GraspNetDataset(Dataset): 19 | def __init__(self, root, grasp_labels=None, camera='kinect', split='train', num_points=20000, 20 | voxel_size=0.005, remove_outlier=True, augment=False, load_label=True): 21 | assert (num_points <= 50000) 22 | self.root = root 23 | self.split = split 24 | self.voxel_size = voxel_size 25 | self.num_points = num_points 26 | self.remove_outlier = remove_outlier 27 | self.grasp_labels = grasp_labels 28 | self.camera = camera 29 | self.augment = augment 30 | self.load_label = load_label 31 | self.collision_labels = {} 32 | 33 | if split == 'train': 34 | self.sceneIds = list(range(100)) 35 | elif split == 'test': 36 | self.sceneIds = list(range(100, 190)) 37 | elif split == 'test_seen': 38 | self.sceneIds = list(range(100, 130)) 39 | elif split == 'test_similar': 40 | self.sceneIds = list(range(130, 160)) 41 | elif split == 'test_novel': 42 | self.sceneIds = list(range(160, 190)) 43 | self.sceneIds = ['scene_{}'.format(str(x).zfill(4)) for x in self.sceneIds] 44 | 45 | self.depthpath = [] 46 | self.labelpath = [] 47 | self.metapath = [] 48 | self.scenename = [] 49 | self.frameid = [] 50 | self.graspnesspath = [] 51 | for x in tqdm(self.sceneIds, desc='Loading data path and collision labels...'): 52 | for img_num in range(256): 53 | self.depthpath.append(os.path.join(root, 'scenes', x, camera, 'depth', str(img_num).zfill(4) + '.png')) 54 | self.labelpath.append(os.path.join(root, 'scenes', x, camera, 'label', str(img_num).zfill(4) + '.png')) 55 | self.metapath.append(os.path.join(root, 'scenes', x, camera, 'meta', str(img_num).zfill(4) + '.mat')) 56 | self.graspnesspath.append(os.path.join(root, 'graspness', x, camera, str(img_num).zfill(4) + '.npy')) 57 | self.scenename.append(x.strip()) 58 | self.frameid.append(img_num) 59 | if self.load_label: 60 | collision_labels = np.load(os.path.join(root, 'collision_label', x.strip(), 'collision_labels.npz')) 61 | self.collision_labels[x.strip()] = {} 62 | for i in range(len(collision_labels)): 63 | self.collision_labels[x.strip()][i] = 
collision_labels['arr_{}'.format(i)] 64 | 65 | def scene_list(self): 66 | return self.scenename 67 | 68 | def __len__(self): 69 | return len(self.depthpath) 70 | 71 | def augment_data(self, point_clouds, object_poses_list): 72 | # Flipping along the YZ plane 73 | if np.random.random() > 0.5: 74 | flip_mat = np.array([[-1, 0, 0], 75 | [0, 1, 0], 76 | [0, 0, 1]]) 77 | point_clouds = transform_point_cloud(point_clouds, flip_mat, '3x3') 78 | for i in range(len(object_poses_list)): 79 | object_poses_list[i] = np.dot(flip_mat, object_poses_list[i]).astype(np.float32) 80 | 81 | # Rotation along up-axis/Z-axis 82 | rot_angle = (np.random.random() * np.pi / 3) - np.pi / 6 # -30 ~ +30 degree 83 | c, s = np.cos(rot_angle), np.sin(rot_angle) 84 | rot_mat = np.array([[1, 0, 0], 85 | [0, c, -s], 86 | [0, s, c]]) 87 | point_clouds = transform_point_cloud(point_clouds, rot_mat, '3x3') 88 | for i in range(len(object_poses_list)): 89 | object_poses_list[i] = np.dot(rot_mat, object_poses_list[i]).astype(np.float32) 90 | 91 | return point_clouds, object_poses_list 92 | 93 | def __getitem__(self, index): 94 | if self.load_label: 95 | return self.get_data_label(index) 96 | else: 97 | return self.get_data(index) 98 | 99 | def get_data(self, index, return_raw_cloud=False): 100 | depth = np.array(Image.open(self.depthpath[index])) 101 | seg = np.array(Image.open(self.labelpath[index])) 102 | meta = scio.loadmat(self.metapath[index]) 103 | scene = self.scenename[index] 104 | try: 105 | intrinsic = meta['intrinsic_matrix'] 106 | factor_depth = meta['factor_depth'] 107 | except Exception as e: 108 | print(repr(e)) 109 | print(scene) 110 | camera = CameraInfo(1280.0, 720.0, intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2], 111 | factor_depth) 112 | 113 | # generate cloud 114 | cloud = create_point_cloud_from_depth_image(depth, camera, organized=True) 115 | 116 | # get valid points 117 | depth_mask = (depth > 0) 118 | if self.remove_outlier: 119 | camera_poses = np.load(os.path.join(self.root, 'scenes', scene, self.camera, 'camera_poses.npy')) 120 | align_mat = np.load(os.path.join(self.root, 'scenes', scene, self.camera, 'cam0_wrt_table.npy')) 121 | trans = np.dot(align_mat, camera_poses[self.frameid[index]]) 122 | workspace_mask = get_workspace_mask(cloud, seg, trans=trans, organized=True, outlier=0.02) 123 | mask = (depth_mask & workspace_mask) 124 | else: 125 | mask = depth_mask 126 | cloud_masked = cloud[mask] 127 | 128 | if return_raw_cloud: 129 | return cloud_masked 130 | # sample points random 131 | if len(cloud_masked) >= self.num_points: 132 | idxs = np.random.choice(len(cloud_masked), self.num_points, replace=False) 133 | else: 134 | idxs1 = np.arange(len(cloud_masked)) 135 | idxs2 = np.random.choice(len(cloud_masked), self.num_points - len(cloud_masked), replace=True) 136 | idxs = np.concatenate([idxs1, idxs2], axis=0) 137 | cloud_sampled = cloud_masked[idxs] 138 | 139 | ret_dict = {'point_clouds': cloud_sampled.astype(np.float32), 140 | 'coors': cloud_sampled.astype(np.float32) / self.voxel_size, 141 | 'feats': np.ones_like(cloud_sampled).astype(np.float32), 142 | } 143 | return ret_dict 144 | 145 | def get_data_label(self, index): 146 | depth = np.array(Image.open(self.depthpath[index])) 147 | seg = np.array(Image.open(self.labelpath[index])) 148 | meta = scio.loadmat(self.metapath[index]) 149 | graspness = np.load(self.graspnesspath[index]) # for each point in workspace masked point cloud 150 | scene = self.scenename[index] 151 | try: 152 | obj_idxs = 
meta['cls_indexes'].flatten().astype(np.int32) 153 | poses = meta['poses'] 154 | intrinsic = meta['intrinsic_matrix'] 155 | factor_depth = meta['factor_depth'] 156 | except Exception as e: 157 | print(repr(e)) 158 | print(scene) 159 | camera = CameraInfo(1280.0, 720.0, intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2], 160 | factor_depth) 161 | 162 | # generate cloud 163 | cloud = create_point_cloud_from_depth_image(depth, camera, organized=True) 164 | 165 | # get valid points 166 | depth_mask = (depth > 0) 167 | if self.remove_outlier: 168 | camera_poses = np.load(os.path.join(self.root, 'scenes', scene, self.camera, 'camera_poses.npy')) 169 | align_mat = np.load(os.path.join(self.root, 'scenes', scene, self.camera, 'cam0_wrt_table.npy')) 170 | trans = np.dot(align_mat, camera_poses[self.frameid[index]]) 171 | workspace_mask = get_workspace_mask(cloud, seg, trans=trans, organized=True, outlier=0.02) 172 | mask = (depth_mask & workspace_mask) 173 | else: 174 | mask = depth_mask 175 | cloud_masked = cloud[mask] 176 | seg_masked = seg[mask] 177 | 178 | # sample points 179 | if len(cloud_masked) >= self.num_points: 180 | idxs = np.random.choice(len(cloud_masked), self.num_points, replace=False) 181 | else: 182 | idxs1 = np.arange(len(cloud_masked)) 183 | idxs2 = np.random.choice(len(cloud_masked), self.num_points - len(cloud_masked), replace=True) 184 | idxs = np.concatenate([idxs1, idxs2], axis=0) 185 | cloud_sampled = cloud_masked[idxs] 186 | seg_sampled = seg_masked[idxs] 187 | graspness_sampled = graspness[idxs] 188 | objectness_label = seg_sampled.copy() 189 | 190 | objectness_label[objectness_label > 1] = 1 191 | 192 | object_poses_list = [] 193 | grasp_points_list = [] 194 | grasp_widths_list = [] 195 | grasp_scores_list = [] 196 | for i, obj_idx in enumerate(obj_idxs): 197 | if (seg_sampled == obj_idx).sum() < 50: 198 | continue 199 | object_poses_list.append(poses[:, :, i]) 200 | points, widths, scores = self.grasp_labels[obj_idx] 201 | collision = self.collision_labels[scene][i] # (Np, V, A, D) 202 | 203 | idxs = np.random.choice(len(points), min(max(int(len(points) / 4), 300), len(points)), replace=False) 204 | grasp_points_list.append(points[idxs]) 205 | grasp_widths_list.append(widths[idxs]) 206 | collision = collision[idxs].copy() 207 | scores = scores[idxs].copy() 208 | scores[collision] = 0 209 | grasp_scores_list.append(scores) 210 | 211 | if self.augment: 212 | cloud_sampled, object_poses_list = self.augment_data(cloud_sampled, object_poses_list) 213 | 214 | ret_dict = {'point_clouds': cloud_sampled.astype(np.float32), 215 | 'coors': cloud_sampled.astype(np.float32) / self.voxel_size, 216 | 'feats': np.ones_like(cloud_sampled).astype(np.float32), 217 | 'graspness_label': graspness_sampled.astype(np.float32), 218 | 'objectness_label': objectness_label.astype(np.int64), 219 | 'object_poses_list': object_poses_list, 220 | 'grasp_points_list': grasp_points_list, 221 | 'grasp_widths_list': grasp_widths_list, 222 | 'grasp_scores_list': grasp_scores_list} 223 | return ret_dict 224 | 225 | 226 | def load_grasp_labels(root): 227 | obj_names = list(range(1, 89)) 228 | grasp_labels = {} 229 | for obj_name in tqdm(obj_names, desc='Loading grasping labels...'): 230 | label = np.load(os.path.join(root, 'grasp_label_simplified', '{}_labels.npz'.format(str(obj_name - 1).zfill(3)))) 231 | grasp_labels[obj_name] = (label['points'].astype(np.float32), label['width'].astype(np.float32), 232 | label['scores'].astype(np.float32)) 233 | 234 | return grasp_labels 235 | 236 | 237 | 
def minkowski_collate_fn(list_data): 238 | coordinates_batch, features_batch = ME.utils.sparse_collate([d["coors"] for d in list_data], 239 | [d["feats"] for d in list_data]) 240 | coordinates_batch, features_batch, _, quantize2original = ME.utils.sparse_quantize( 241 | coordinates_batch, features_batch, return_index=True, return_inverse=True) 242 | res = { 243 | "coors": coordinates_batch, 244 | "feats": features_batch, 245 | "quantize2original": quantize2original 246 | } 247 | 248 | def collate_fn_(batch): 249 | if type(batch[0]).__module__ == 'numpy': 250 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 251 | elif isinstance(batch[0], container_abcs.Sequence): 252 | return [[torch.from_numpy(sample) for sample in b] for b in batch] 253 | elif isinstance(batch[0], container_abcs.Mapping): 254 | for key in batch[0]: 255 | if key == 'coors' or key == 'feats': 256 | continue 257 | res[key] = collate_fn_([d[key] for d in batch]) 258 | return res 259 | res = collate_fn_(list_data) 260 | 261 | return res 262 | -------------------------------------------------------------------------------- /dataset/simplify_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import argparse 4 | 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--dataset_root', default=None, required=True) 8 | 9 | 10 | def simplify_grasp_labels(root, save_path): 11 | """ 12 | original dataset grasp_label files have redundant data, We can significantly save the memory cost 13 | """ 14 | obj_names = list(range(88)) 15 | if not os.path.exists(save_path): 16 | os.makedirs(save_path) 17 | for i in obj_names: 18 | print('\nsimplifying object {}:'.format(i)) 19 | label = np.load(os.path.join(root, 'grasp_label', '{}_labels.npz'.format(str(i).zfill(3)))) 20 | # point_num = len(label['points']) 21 | print('original shape: ', label['points'].shape, label['offsets'].shape, label['scores'].shape) 22 | # if point_num > 4820: 23 | # idxs = np.random.choice(point_num, 4820, False) 24 | # points = label['points'][idxs] 25 | # offsets = label['offsets'][idxs] 26 | # scores = label['scores'][idxs] 27 | # print('Warning!!! 
down sample object {}'.format(i)) 28 | # else: 29 | points = label['points'] 30 | scores = label['scores'] 31 | offsets = label['offsets'] 32 | width = offsets[:, :, :, :, 2] 33 | print('after simplify, offset shape: ', points.shape, scores.shape, width.shape) 34 | np.savez(os.path.join(save_path, '{}_labels.npz'.format(str(i).zfill(3))), 35 | points=points, scores=scores, width=width) 36 | 37 | 38 | if __name__ == '__main__': 39 | cfgs = parser.parse_args() 40 | root = cfgs.dataset_root # set root and save path 41 | save_path = os.path.join(root, 'grasp_label_simplified') 42 | simplify_grasp_labels(root, save_path) 43 | 44 | -------------------------------------------------------------------------------- /dataset/vis_graspness.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import scipy.io as scio 3 | from PIL import Image 4 | import os 5 | import numpy as np 6 | import sys 7 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 8 | sys.path.append(ROOT_DIR) 9 | from utils.data_utils import get_workspace_mask, CameraInfo, create_point_cloud_from_depth_image 10 | 11 | data_path = '/media/bot/980A6F5E0A6F38801/datasets/graspnet/' 12 | scene_id = 'scene_0060' 13 | ann_id = '0000' 14 | camera_type = 'realsense' 15 | color = np.array(Image.open(os.path.join(data_path, 'scenes', scene_id, camera_type, 'rgb', ann_id + '.png')), dtype=np.float32) / 255.0 16 | depth = np.array(Image.open(os.path.join(data_path, 'scenes', scene_id, camera_type, 'depth', ann_id + '.png'))) 17 | seg = np.array(Image.open(os.path.join(data_path, 'scenes', scene_id, camera_type, 'label', ann_id + '.png'))) 18 | meta = scio.loadmat(os.path.join(data_path, 'scenes', scene_id, camera_type, 'meta', ann_id + '.mat')) 19 | intrinsic = meta['intrinsic_matrix'] 20 | factor_depth = meta['factor_depth'] 21 | camera = CameraInfo(1280.0, 720.0, intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2], factor_depth) 22 | point_cloud = create_point_cloud_from_depth_image(depth, camera, organized=True) 23 | depth_mask = (depth > 0) 24 | camera_poses = np.load(os.path.join(data_path, 'scenes', scene_id, camera_type, 'camera_poses.npy')) 25 | align_mat = np.load(os.path.join(data_path, 'scenes', scene_id, camera_type, 'cam0_wrt_table.npy')) 26 | trans = np.dot(align_mat, camera_poses[int(ann_id)]) 27 | workspace_mask = get_workspace_mask(point_cloud, seg, trans=trans, organized=True, outlier=0.02) 28 | mask = (depth_mask & workspace_mask) 29 | point_cloud = point_cloud[mask] 30 | color = color[mask] 31 | seg = seg[mask] 32 | 33 | graspness_full = np.load(os.path.join(data_path, 'graspness', scene_id, camera_type, ann_id + '.npy')).squeeze() 34 | graspness_full[seg == 0] = 0. 35 | print('graspness full scene: ', graspness_full.shape, (graspness_full > 0.1).sum()) 36 | color[graspness_full > 0.1] = [0., 1., 0.] 
37 | 38 | 39 | cloud = o3d.geometry.PointCloud() 40 | cloud.points = o3d.utility.Vector3dVector(point_cloud.astype(np.float32)) 41 | cloud.colors = o3d.utility.Vector3dVector(color.astype(np.float32)) 42 | o3d.visualization.draw_geometries([cloud]) 43 | -------------------------------------------------------------------------------- /doc/example_data/color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhett-chen/graspness_implementation/ff33da111e72db1b8697758c7863fcec2359280e/doc/example_data/color.png -------------------------------------------------------------------------------- /doc/example_data/demo_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhett-chen/graspness_implementation/ff33da111e72db1b8697758c7863fcec2359280e/doc/example_data/demo_result.png -------------------------------------------------------------------------------- /doc/example_data/depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhett-chen/graspness_implementation/ff33da111e72db1b8697758c7863fcec2359280e/doc/example_data/depth.png -------------------------------------------------------------------------------- /doc/example_data/meta.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhett-chen/graspness_implementation/ff33da111e72db1b8697758c7863fcec2359280e/doc/example_data/meta.mat -------------------------------------------------------------------------------- /doc/example_data/workspace_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhett-chen/graspness_implementation/ff33da111e72db1b8697758c7863fcec2359280e/doc/example_data/workspace_mask.png -------------------------------------------------------------------------------- /doc/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhett-chen/graspness_implementation/ff33da111e72db1b8697758c7863fcec2359280e/doc/teaser.png -------------------------------------------------------------------------------- /infer_vis_grasp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import argparse 5 | from PIL import Image 6 | import time 7 | import scipy.io as scio 8 | import torch 9 | import open3d as o3d 10 | from graspnetAPI.graspnet_eval import GraspGroup 11 | 12 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 13 | sys.path.append(ROOT_DIR) 14 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 15 | from models.graspnet import GraspNet, pred_decode 16 | from dataset.graspnet_dataset import minkowski_collate_fn 17 | from collision_detector import ModelFreeCollisionDetector 18 | from data_utils import CameraInfo, create_point_cloud_from_depth_image, get_workspace_mask 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--dataset_root', default='/data/datasets/graspnet') 22 | parser.add_argument('--checkpoint_path', default='/data/zibo/logs/graspness_kn.tar') 23 | parser.add_argument('--dump_dir', help='Dump dir to save outputs', default='/data/zibo/logs/') 24 | parser.add_argument('--seed_feat_dim', default=512, type=int, help='Point wise feature dim') 25 | parser.add_argument('--camera', default='kinect', 
help='Camera split [realsense/kinect]') 26 | parser.add_argument('--num_point', type=int, default=15000, help='Point Number [default: 15000]') 27 | parser.add_argument('--batch_size', type=int, default=1, help='Batch Size during inference [default: 1]') 28 | parser.add_argument('--voxel_size', type=float, default=0.005, help='Voxel Size for sparse convolution') 29 | parser.add_argument('--collision_thresh', type=float, default=-1, 30 | help='Collision Threshold in collision detection [default: 0.01]') 31 | parser.add_argument('--voxel_size_cd', type=float, default=0.01, help='Voxel Size for collision detection') 32 | parser.add_argument('--infer', action='store_true', default=False) 33 | parser.add_argument('--vis', action='store_true', default=False) 34 | parser.add_argument('--scene', type=str, default='0188') 35 | parser.add_argument('--index', type=str, default='0000') 36 | cfgs = parser.parse_args() 37 | 38 | # ------------------------------------------------------------------------- GLOBAL CONFIG BEG 39 | if not os.path.exists(cfgs.dump_dir): 40 | os.mkdir(cfgs.dump_dir) 41 | 42 | 43 | def data_process(): 44 | root = cfgs.dataset_root 45 | camera_type = cfgs.camera 46 | 47 | depth = np.array(Image.open(os.path.join(root, 'scenes', scene_id, camera_type, 'depth', index + '.png'))) 48 | seg = np.array(Image.open(os.path.join(root, 'scenes', scene_id, camera_type, 'label', index + '.png'))) 49 | meta = scio.loadmat(os.path.join(root, 'scenes', scene_id, camera_type, 'meta', index + '.mat')) 50 | try: 51 | intrinsic = meta['intrinsic_matrix'] 52 | factor_depth = meta['factor_depth'] 53 | except Exception as e: 54 | print(repr(e)) 55 | camera = CameraInfo(1280.0, 720.0, intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2], 56 | factor_depth) 57 | # generate cloud 58 | cloud = create_point_cloud_from_depth_image(depth, camera, organized=True) 59 | 60 | # get valid points 61 | depth_mask = (depth > 0) 62 | camera_poses = np.load(os.path.join(root, 'scenes', scene_id, camera_type, 'camera_poses.npy')) 63 | align_mat = np.load(os.path.join(root, 'scenes', scene_id, camera_type, 'cam0_wrt_table.npy')) 64 | trans = np.dot(align_mat, camera_poses[int(index)]) 65 | workspace_mask = get_workspace_mask(cloud, seg, trans=trans, organized=True, outlier=0.02) 66 | mask = (depth_mask & workspace_mask) 67 | 68 | cloud_masked = cloud[mask] 69 | 70 | # sample points random 71 | if len(cloud_masked) >= cfgs.num_point: 72 | idxs = np.random.choice(len(cloud_masked), cfgs.num_point, replace=False) 73 | else: 74 | idxs1 = np.arange(len(cloud_masked)) 75 | idxs2 = np.random.choice(len(cloud_masked), cfgs.num_point - len(cloud_masked), replace=True) 76 | idxs = np.concatenate([idxs1, idxs2], axis=0) 77 | cloud_sampled = cloud_masked[idxs] 78 | 79 | ret_dict = {'point_clouds': cloud_sampled.astype(np.float32), 80 | 'coors': cloud_sampled.astype(np.float32) / cfgs.voxel_size, 81 | 'feats': np.ones_like(cloud_sampled).astype(np.float32), 82 | } 83 | return ret_dict 84 | 85 | 86 | # Init datasets and dataloaders 87 | def my_worker_init_fn(worker_id): 88 | np.random.seed(np.random.get_state()[1][0] + worker_id) 89 | pass 90 | 91 | 92 | def inference(data_input): 93 | batch_data = minkowski_collate_fn([data_input]) 94 | net = GraspNet(seed_feat_dim=cfgs.seed_feat_dim, is_training=False) 95 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 96 | net.to(device) 97 | # Load checkpoint 98 | checkpoint = torch.load(cfgs.checkpoint_path) 99 | 
net.load_state_dict(checkpoint['model_state_dict']) 100 | start_epoch = checkpoint['epoch'] 101 | print("-> loaded checkpoint %s (epoch: %d)" % (cfgs.checkpoint_path, start_epoch)) 102 | 103 | net.eval() 104 | tic = time.time() 105 | 106 | for key in batch_data: 107 | if 'list' in key: 108 | for i in range(len(batch_data[key])): 109 | for j in range(len(batch_data[key][i])): 110 | batch_data[key][i][j] = batch_data[key][i][j].to(device) 111 | else: 112 | batch_data[key] = batch_data[key].to(device) 113 | # Forward pass 114 | with torch.no_grad(): 115 | end_points = net(batch_data) 116 | grasp_preds = pred_decode(end_points) 117 | 118 | preds = grasp_preds[0].detach().cpu().numpy() 119 | gg = GraspGroup(preds) 120 | # collision detection 121 | if cfgs.collision_thresh > 0: 122 | cloud = data_input['point_clouds'] 123 | mfcdetector = ModelFreeCollisionDetector(cloud, voxel_size=cfgs.voxel_size_cd) 124 | collision_mask = mfcdetector.detect(gg, approach_dist=0.05, collision_thresh=cfgs.collision_thresh) 125 | gg = gg[~collision_mask] 126 | 127 | # save grasps 128 | save_dir = os.path.join(cfgs.dump_dir, scene_id, cfgs.camera) 129 | save_path = os.path.join(save_dir, cfgs.index + '.npy') 130 | if not os.path.exists(save_dir): 131 | os.makedirs(save_dir) 132 | gg.save_npy(save_path) 133 | 134 | toc = time.time() 135 | print('inference time: %fs' % (toc - tic)) 136 | 137 | 138 | if __name__ == '__main__': 139 | scene_id = 'scene_' + cfgs.scene 140 | index = cfgs.index 141 | data_dict = data_process() 142 | 143 | if cfgs.infer: 144 | inference(data_dict) 145 | if cfgs.vis: 146 | pc = data_dict['point_clouds'] 147 | gg = np.load(os.path.join(cfgs.dump_dir, scene_id, cfgs.camera, cfgs.index + '.npy')) 148 | gg = GraspGroup(gg) 149 | gg = gg.nms() 150 | gg = gg.sort_by_score() 151 | if gg.__len__() > 30: 152 | gg = gg[:30] 153 | grippers = gg.to_open3d_geometry_list() 154 | cloud = o3d.geometry.PointCloud() 155 | cloud.points = o3d.utility.Vector3dVector(pc.astype(np.float32)) 156 | o3d.visualization.draw_geometries([cloud, *grippers]) 157 | -------------------------------------------------------------------------------- /knn/knn_modules.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import gc 3 | import operator as op 4 | import functools 5 | import torch 6 | from torch.autograd import Variable, Function 7 | from knn_pytorch import knn_pytorch 8 | # import knn_pytorch 9 | def knn(ref, query, k=1): 10 | """ Compute k nearest neighbors for each query point. 
11 | """ 12 | device = ref.device 13 | ref = ref.float().to(device) 14 | query = query.float().to(device) 15 | inds = torch.empty(query.shape[0], k, query.shape[2]).long().to(device) 16 | knn_pytorch.knn(ref, query, inds) 17 | return inds 18 | -------------------------------------------------------------------------------- /knn/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import glob 4 | import os 5 | 6 | import torch 7 | from setuptools import find_packages 8 | from setuptools import setup 9 | from torch.utils.cpp_extension import CUDA_HOME 10 | from torch.utils.cpp_extension import CppExtension 11 | from torch.utils.cpp_extension import CUDAExtension 12 | 13 | requirements = ["torch", "torchvision"] 14 | 15 | 16 | def get_extensions(): 17 | this_dir = os.path.dirname(os.path.abspath(__file__)) 18 | extensions_dir = os.path.join(this_dir, "src") 19 | 20 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 21 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 22 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 23 | 24 | sources = main_file + source_cpu 25 | extension = CppExtension 26 | 27 | extra_compile_args = {"cxx": []} 28 | define_macros = [] 29 | 30 | if torch.cuda.is_available() and CUDA_HOME is not None: 31 | extension = CUDAExtension 32 | sources += source_cuda 33 | define_macros += [("WITH_CUDA", None)] 34 | extra_compile_args["nvcc"] = [ 35 | "-DCUDA_HAS_FP16=1", 36 | "-D__CUDA_NO_HALF_OPERATORS__", 37 | "-D__CUDA_NO_HALF_CONVERSIONS__", 38 | "-D__CUDA_NO_HALF2_OPERATORS__", 39 | ] 40 | 41 | sources = [os.path.join(extensions_dir, s) for s in sources] 42 | 43 | include_dirs = [extensions_dir] 44 | 45 | ext_modules = [ 46 | extension( 47 | "knn_pytorch.knn_pytorch", 48 | sources, 49 | include_dirs=include_dirs, 50 | define_macros=define_macros, 51 | extra_compile_args=extra_compile_args, 52 | ) 53 | ] 54 | 55 | return ext_modules 56 | 57 | 58 | setup( 59 | name="knn_pytorch", 60 | version="0.1", 61 | author="foolyc", 62 | url="https://github.com/foolyc/torchKNN", 63 | description="KNN implement in Pytorch 1.0 including both cpu version and gpu version", 64 | ext_modules=get_extensions(), 65 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 66 | ) 67 | -------------------------------------------------------------------------------- /knn/src/cpu/knn_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "cpu/vision.h" 2 | 3 | 4 | void knn_cpu(float* ref_dev, int ref_width, float* query_dev, int query_width, 5 | int height, int k, float* dist_dev, long* ind_dev, long* ind_buf) 6 | { 7 | // Compute all the distances 8 | for(int query_idx = 0;query_idx dist_dev[query_idx * ref_width + j + 1]) 31 | { 32 | temp_value = dist_dev[query_idx * ref_width + j]; 33 | dist_dev[query_idx * ref_width + j] = dist_dev[query_idx * ref_width + j + 1]; 34 | dist_dev[query_idx * ref_width + j + 1] = temp_value; 35 | temp_idx = ind_buf[j]; 36 | ind_buf[j] = ind_buf[j + 1]; 37 | ind_buf[j + 1] = temp_idx; 38 | } 39 | 40 | } 41 | 42 | for(int i = 0;i < k;i++) 43 | ind_dev[query_idx + i * query_width] = ind_buf[i]; 44 | #if DEBUG 45 | for(int i = 0;i < ref_width;i++) 46 | printf("%d, ", ind_buf[i]); 47 | printf("\n"); 48 | #endif 49 | 50 | } 51 | 52 | 53 | 54 | 55 | 56 | } -------------------------------------------------------------------------------- /knn/src/cpu/vision.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | void knn_cpu(float* ref_dev, int ref_width, 5 | float* query_dev, int query_width, 6 | int height, int k, float* dist_dev, long* ind_dev, long* ind_buf); -------------------------------------------------------------------------------- /knn/src/cuda/knn.cu: -------------------------------------------------------------------------------- 1 | /** Modifed version of knn-CUDA from https://github.com/vincentfpgarcia/kNN-CUDA 2 | * The modifications are 3 | * removed texture memory usage 4 | * removed split query KNN computation 5 | * added feature extraction with bilinear interpolation 6 | * 7 | * Last modified by Christopher B. Choy 12/23/2016 8 | */ 9 | 10 | // Includes 11 | #include 12 | #include "cuda.h" 13 | 14 | #define IDX2D(i, j, dj) (dj * i + j) 15 | #define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk)) 16 | 17 | #define BLOCK 512 18 | #define MAX_STREAMS 512 19 | 20 | // Constants used by the program 21 | #define BLOCK_DIM 16 22 | #define DEBUG 0 23 | 24 | 25 | /** 26 | * Computes the distance between two matrix A (reference points) and 27 | * B (query points) containing respectively wA and wB points. 28 | * 29 | * @param A pointer on the matrix A 30 | * @param wA width of the matrix A = number of points in A 31 | * @param B pointer on the matrix B 32 | * @param wB width of the matrix B = number of points in B 33 | * @param dim dimension of points = height of matrices A and B 34 | * @param AB pointer on the matrix containing the wA*wB distances computed 35 | */ 36 | __global__ void cuComputeDistanceGlobal( float* A, int wA, 37 | float* B, int wB, int dim, float* AB){ 38 | 39 | // Declaration of the shared memory arrays As and Bs used to store the sub-matrix of A and B 40 | __shared__ float shared_A[BLOCK_DIM][BLOCK_DIM]; 41 | __shared__ float shared_B[BLOCK_DIM][BLOCK_DIM]; 42 | 43 | 44 | // Sub-matrix of A (begin, step, end) and Sub-matrix of B (begin, step) 45 | __shared__ int begin_A; 46 | __shared__ int begin_B; 47 | __shared__ int step_A; 48 | __shared__ int step_B; 49 | __shared__ int end_A; 50 | 51 | // Thread index 52 | int tx = threadIdx.x; 53 | int ty = threadIdx.y; 54 | 55 | // Other variables 56 | float tmp; 57 | float ssd = 0; 58 | 59 | // Loop parameters 60 | begin_A = BLOCK_DIM * blockIdx.y; 61 | begin_B = BLOCK_DIM * blockIdx.x; 62 | step_A = BLOCK_DIM * wA; 63 | step_B = BLOCK_DIM * wB; 64 | end_A = begin_A + (dim-1) * wA; 65 | 66 | // Conditions 67 | int cond0 = (begin_A + tx < wA); // used to write in shared memory 68 | int cond1 = (begin_B + tx < wB); // used to write in shared memory & to computations and to write in output matrix 69 | int cond2 = (begin_A + ty < wA); // used to computations and to write in output matrix 70 | 71 | // Loop over all the sub-matrices of A and B required to compute the block sub-matrix 72 | for (int a = begin_A, b = begin_B; a <= end_A; a += step_A, b += step_B) { 73 | // Load the matrices from device memory to shared memory; each thread loads one element of each matrix 74 | if (a/wA + ty < dim){ 75 | shared_A[ty][tx] = (cond0)? A[a + wA * ty + tx] : 0; 76 | shared_B[ty][tx] = (cond1)? 
B[b + wB * ty + tx] : 0; 77 | } 78 | else{ 79 | shared_A[ty][tx] = 0; 80 | shared_B[ty][tx] = 0; 81 | } 82 | 83 | // Synchronize to make sure the matrices are loaded 84 | __syncthreads(); 85 | 86 | // Compute the difference between the two matrixes; each thread computes one element of the block sub-matrix 87 | if (cond2 && cond1){ 88 | for (int k = 0; k < BLOCK_DIM; ++k){ 89 | tmp = shared_A[k][ty] - shared_B[k][tx]; 90 | ssd += tmp*tmp; 91 | } 92 | } 93 | 94 | // Synchronize to make sure that the preceding computation is done before loading two new sub-matrices of A and B in the next iteration 95 | __syncthreads(); 96 | } 97 | 98 | // Write the block sub-matrix to device memory; each thread writes one element 99 | if (cond2 && cond1) 100 | AB[(begin_A + ty) * wB + begin_B + tx] = ssd; 101 | } 102 | 103 | 104 | /** 105 | * Gathers k-th smallest distances for each column of the distance matrix in the top. 106 | * 107 | * @param dist distance matrix 108 | * @param ind index matrix 109 | * @param width width of the distance matrix and of the index matrix 110 | * @param height height of the distance matrix and of the index matrix 111 | * @param k number of neighbors to consider 112 | */ 113 | __global__ void cuInsertionSort(float *dist, long *ind, int width, int height, int k){ 114 | 115 | // Variables 116 | int l, i, j; 117 | float *p_dist; 118 | long *p_ind; 119 | float curr_dist, max_dist; 120 | long curr_row, max_row; 121 | unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x; 122 | 123 | if (xIndexcurr_dist){ 138 | i=a; 139 | break; 140 | } 141 | } 142 | for (j=l; j>i; j--){ 143 | p_dist[j*width] = p_dist[(j-1)*width]; 144 | p_ind[j*width] = p_ind[(j-1)*width]; 145 | } 146 | p_dist[i*width] = curr_dist; 147 | p_ind[i*width] = l+1; 148 | } else { 149 | p_ind[l*width] = l+1; 150 | } 151 | max_dist = p_dist[curr_row]; 152 | } 153 | 154 | // Part 2 : insert element in the k-th first lines 155 | max_row = (k-1)*width; 156 | for (l=k; lcurr_dist){ 162 | i=a; 163 | break; 164 | } 165 | } 166 | for (j=k-1; j>i; j--){ 167 | p_dist[j*width] = p_dist[(j-1)*width]; 168 | p_ind[j*width] = p_ind[(j-1)*width]; 169 | } 170 | p_dist[i*width] = curr_dist; 171 | p_ind[i*width] = l+1; 172 | max_dist = p_dist[max_row]; 173 | } 174 | } 175 | } 176 | } 177 | 178 | 179 | /** 180 | * Computes the square root of the first line (width-th first element) 181 | * of the distance matrix. 
182 | * 183 | * @param dist distance matrix 184 | * @param width width of the distance matrix 185 | * @param k number of neighbors to consider 186 | */ 187 | __global__ void cuParallelSqrt(float *dist, int width, int k){ 188 | unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x; 189 | unsigned int yIndex = blockIdx.y * blockDim.y + threadIdx.y; 190 | if (xIndex>>(ref_dev, ref_nb, query_dev, query_nb, dim, dist_dev); 237 | 238 | // Kernel 2: Sort each column 239 | cuInsertionSort<<>>(dist_dev, ind_dev, query_nb, ref_nb, k); 240 | 241 | // Kernel 3: Compute square root of k first elements 242 | // cuParallelSqrt<<>>(dist_dev, query_nb, k); 243 | 244 | #if DEBUG 245 | unsigned int size_of_float = sizeof(float); 246 | unsigned long size_of_long = sizeof(long); 247 | 248 | float* dist_host = new float[query_nb * k]; 249 | long* idx_host = new long[query_nb * k]; 250 | 251 | // Memory copy of output from device to host 252 | cudaMemcpy(&dist_host[0], dist_dev, 253 | query_nb * k *size_of_float, cudaMemcpyDeviceToHost); 254 | 255 | cudaMemcpy(&idx_host[0], ind_dev, 256 | query_nb * k * size_of_long, cudaMemcpyDeviceToHost); 257 | 258 | int i = 0; 259 | for(i = 0; i < 100; i++){ 260 | printf("IDX[%d]: %d\n", i, (int)idx_host[i]); 261 | } 262 | #endif 263 | } 264 | 265 | 266 | 267 | 268 | 269 | 270 | -------------------------------------------------------------------------------- /knn/src/cuda/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | void knn_device(float* ref_dev, int ref_width, 6 | float* query_dev, int query_width, 7 | int height, int k, float* dist_dev, long* ind_dev, cudaStream_t stream); -------------------------------------------------------------------------------- /knn/src/knn.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "cpu/vision.h" 3 | 4 | #ifdef WITH_CUDA 5 | #include "cuda/vision.h" 6 | #include 7 | extern THCState *state; 8 | #endif 9 | 10 | 11 | 12 | int knn(at::Tensor& ref, at::Tensor& query, at::Tensor& idx) 13 | { 14 | 15 | // TODO check dimensions 16 | long batch, ref_nb, query_nb, dim, k; 17 | batch = ref.size(0); 18 | dim = ref.size(1); 19 | k = idx.size(1); 20 | ref_nb = ref.size(2); 21 | query_nb = query.size(2); 22 | 23 | float *ref_dev = ref.data(); 24 | float *query_dev = query.data(); 25 | long *idx_dev = idx.data(); 26 | 27 | 28 | 29 | 30 | if (ref.type().is_cuda()) { 31 | #ifdef WITH_CUDA 32 | // TODO raise error if not compiled with CUDA 33 | float *dist_dev = (float*)THCudaMalloc(state, ref_nb * query_nb * sizeof(float)); 34 | 35 | for (int b = 0; b < batch; b++) 36 | { 37 | // knn_device(ref_dev + b * dim * ref_nb, ref_nb, query_dev + b * dim * query_nb, query_nb, dim, k, 38 | // dist_dev, idx_dev + b * k * query_nb, THCState_getCurrentStream(state)); 39 | knn_device(ref_dev + b * dim * ref_nb, ref_nb, query_dev + b * dim * query_nb, query_nb, dim, k, 40 | dist_dev, idx_dev + b * k * query_nb, c10::cuda::getCurrentCUDAStream()); 41 | } 42 | THCudaFree(state, dist_dev); 43 | cudaError_t err = cudaGetLastError(); 44 | if (err != cudaSuccess) 45 | { 46 | printf("error in knn: %s\n", cudaGetErrorString(err)); 47 | THError("aborting"); 48 | } 49 | return 1; 50 | #else 51 | AT_ERROR("Not compiled with GPU support"); 52 | #endif 53 | } 54 | 55 | 56 | float *dist_dev = (float*)malloc(ref_nb * query_nb * sizeof(float)); 57 | long *ind_buf = (long*)malloc(ref_nb * sizeof(long)); 58 | for (int b = 0; 
b < batch; b++) { 59 | knn_cpu(ref_dev + b * dim * ref_nb, ref_nb, query_dev + b * dim * query_nb, query_nb, dim, k, 60 | dist_dev, idx_dev + b * k * query_nb, ind_buf); 61 | } 62 | 63 | free(dist_dev); 64 | free(ind_buf); 65 | 66 | return 1; 67 | 68 | } 69 | -------------------------------------------------------------------------------- /knn/src/vision.cpp: -------------------------------------------------------------------------------- 1 | #include "knn.h" 2 | 3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 4 | m.def("knn", &knn, "k-nearest neighbors"); 5 | } 6 | -------------------------------------------------------------------------------- /models/backbone_resunet14.py: -------------------------------------------------------------------------------- 1 | import MinkowskiEngine as ME 2 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 3 | from models.resnet import ResNetBase 4 | 5 | 6 | class MinkUNetBase(ResNetBase): 7 | BLOCK = None 8 | PLANES = None 9 | DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1) 10 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 11 | PLANES = (32, 64, 128, 256, 256, 128, 96, 96) 12 | INIT_DIM = 32 13 | OUT_TENSOR_STRIDE = 1 14 | 15 | # To use the model, must call initialize_coords before forward pass. 16 | # Once data is processed, call clear to reset the model before calling 17 | # initialize_coords 18 | def __init__(self, in_channels, out_channels, D=3): 19 | ResNetBase.__init__(self, in_channels, out_channels, D) 20 | 21 | def network_initialization(self, in_channels, out_channels, D): 22 | # Output of the first conv concated to conv6 23 | self.inplanes = self.INIT_DIM 24 | self.conv0p1s1 = ME.MinkowskiConvolution( 25 | in_channels, self.inplanes, kernel_size=5, dimension=D) 26 | 27 | self.bn0 = ME.MinkowskiBatchNorm(self.inplanes) 28 | 29 | self.conv1p1s2 = ME.MinkowskiConvolution( 30 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 31 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) 32 | 33 | self.block1 = self._make_layer(self.BLOCK, self.PLANES[0], 34 | self.LAYERS[0]) 35 | 36 | self.conv2p2s2 = ME.MinkowskiConvolution( 37 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 38 | self.bn2 = ME.MinkowskiBatchNorm(self.inplanes) 39 | 40 | self.block2 = self._make_layer(self.BLOCK, self.PLANES[1], 41 | self.LAYERS[1]) 42 | 43 | self.conv3p4s2 = ME.MinkowskiConvolution( 44 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 45 | 46 | self.bn3 = ME.MinkowskiBatchNorm(self.inplanes) 47 | self.block3 = self._make_layer(self.BLOCK, self.PLANES[2], 48 | self.LAYERS[2]) 49 | 50 | self.conv4p8s2 = ME.MinkowskiConvolution( 51 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 52 | self.bn4 = ME.MinkowskiBatchNorm(self.inplanes) 53 | self.block4 = self._make_layer(self.BLOCK, self.PLANES[3], 54 | self.LAYERS[3]) 55 | 56 | self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose( 57 | self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=D) 58 | self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4]) 59 | 60 | self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion 61 | self.block5 = self._make_layer(self.BLOCK, self.PLANES[4], 62 | self.LAYERS[4]) 63 | self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose( 64 | self.inplanes, self.PLANES[5], kernel_size=2, stride=2, dimension=D) 65 | self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5]) 66 | 67 | self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion 68 | self.block6 = self._make_layer(self.BLOCK, self.PLANES[5], 
69 | self.LAYERS[5]) 70 | self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose( 71 | self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=D) 72 | self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6]) 73 | 74 | self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion 75 | self.block7 = self._make_layer(self.BLOCK, self.PLANES[6], 76 | self.LAYERS[6]) 77 | self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose( 78 | self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=D) 79 | self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7]) 80 | 81 | self.inplanes = self.PLANES[7] + self.INIT_DIM 82 | self.block8 = self._make_layer(self.BLOCK, self.PLANES[7], 83 | self.LAYERS[7]) 84 | 85 | self.final = ME.MinkowskiConvolution( 86 | self.PLANES[7] * self.BLOCK.expansion, 87 | out_channels, 88 | kernel_size=1, 89 | bias=True, 90 | dimension=D) 91 | self.relu = ME.MinkowskiReLU(inplace=True) 92 | 93 | def forward(self, x): 94 | out = self.conv0p1s1(x) 95 | out = self.bn0(out) 96 | out_p1 = self.relu(out) 97 | 98 | out = self.conv1p1s2(out_p1) 99 | out = self.bn1(out) 100 | out = self.relu(out) 101 | out_b1p2 = self.block1(out) 102 | 103 | out = self.conv2p2s2(out_b1p2) 104 | out = self.bn2(out) 105 | out = self.relu(out) 106 | out_b2p4 = self.block2(out) 107 | 108 | out = self.conv3p4s2(out_b2p4) 109 | out = self.bn3(out) 110 | out = self.relu(out) 111 | out_b3p8 = self.block3(out) 112 | 113 | # tensor_stride=16 114 | out = self.conv4p8s2(out_b3p8) 115 | out = self.bn4(out) 116 | out = self.relu(out) 117 | out = self.block4(out) 118 | 119 | # tensor_stride=8 120 | out = self.convtr4p16s2(out) 121 | out = self.bntr4(out) 122 | out = self.relu(out) 123 | 124 | out = ME.cat(out, out_b3p8) 125 | out = self.block5(out) 126 | 127 | # tensor_stride=4 128 | out = self.convtr5p8s2(out) 129 | out = self.bntr5(out) 130 | out = self.relu(out) 131 | 132 | out = ME.cat(out, out_b2p4) 133 | out = self.block6(out) 134 | 135 | # tensor_stride=2 136 | out = self.convtr6p4s2(out) 137 | out = self.bntr6(out) 138 | out = self.relu(out) 139 | 140 | out = ME.cat(out, out_b1p2) 141 | out = self.block7(out) 142 | 143 | # tensor_stride=1 144 | out = self.convtr7p2s2(out) 145 | out = self.bntr7(out) 146 | out = self.relu(out) 147 | 148 | out = ME.cat(out, out_p1) 149 | out = self.block8(out) 150 | 151 | return self.final(out) 152 | 153 | 154 | class MinkUNet14(MinkUNetBase): 155 | BLOCK = BasicBlock 156 | LAYERS = (1, 1, 1, 1, 1, 1, 1, 1) 157 | 158 | 159 | class MinkUNet18(MinkUNetBase): 160 | BLOCK = BasicBlock 161 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 162 | 163 | 164 | class MinkUNet34(MinkUNetBase): 165 | BLOCK = BasicBlock 166 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 167 | 168 | 169 | class MinkUNet50(MinkUNetBase): 170 | BLOCK = Bottleneck 171 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 172 | 173 | 174 | class MinkUNet101(MinkUNetBase): 175 | BLOCK = Bottleneck 176 | LAYERS = (2, 3, 4, 23, 2, 2, 2, 2) 177 | 178 | 179 | class MinkUNet14A(MinkUNet14): 180 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96) 181 | 182 | 183 | class MinkUNet14B(MinkUNet14): 184 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128) 185 | 186 | 187 | class MinkUNet14C(MinkUNet14): 188 | PLANES = (32, 64, 128, 256, 192, 192, 128, 128) 189 | 190 | 191 | class MinkUNet14Dori(MinkUNet14): 192 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384) 193 | 194 | 195 | class MinkUNet14E(MinkUNet14): 196 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384) 197 | 198 | 199 | class MinkUNet14D(MinkUNet14): 200 | PLANES = (32, 64, 128, 256, 192, 192, 192, 192) 201 
| 202 | 203 | class MinkUNet18A(MinkUNet18): 204 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96) 205 | 206 | 207 | class MinkUNet18B(MinkUNet18): 208 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128) 209 | 210 | 211 | class MinkUNet18D(MinkUNet18): 212 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384) 213 | 214 | 215 | class MinkUNet34A(MinkUNet34): 216 | PLANES = (32, 64, 128, 256, 256, 128, 64, 64) 217 | 218 | 219 | class MinkUNet34B(MinkUNet34): 220 | PLANES = (32, 64, 128, 256, 256, 128, 64, 32) 221 | 222 | 223 | class MinkUNet34C(MinkUNet34): 224 | PLANES = (32, 64, 128, 256, 256, 128, 96, 96) 225 | -------------------------------------------------------------------------------- /models/graspnet.py: -------------------------------------------------------------------------------- 1 | """ GraspNet baseline model definition. 2 | Author: chenxi-wang 3 | """ 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import MinkowskiEngine as ME 11 | 12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 13 | ROOT_DIR = os.path.dirname(BASE_DIR) 14 | sys.path.append(ROOT_DIR) 15 | 16 | from models.backbone_resunet14 import MinkUNet14D 17 | from models.modules import ApproachNet, GraspableNet, CloudCrop, SWADNet 18 | from loss_utils import GRASP_MAX_WIDTH, NUM_VIEW, NUM_ANGLE, NUM_DEPTH, GRASPNESS_THRESHOLD, M_POINT 19 | from label_generation import process_grasp_labels, match_grasp_view_and_label, batch_viewpoint_params_to_matrix 20 | from pointnet2.pointnet2_utils import furthest_point_sample, gather_operation 21 | 22 | 23 | class GraspNet(nn.Module): 24 | def __init__(self, cylinder_radius=0.05, seed_feat_dim=512, is_training=True): 25 | super().__init__() 26 | self.is_training = is_training 27 | self.seed_feature_dim = seed_feat_dim 28 | self.num_depth = NUM_DEPTH 29 | self.num_angle = NUM_ANGLE 30 | self.M_points = M_POINT 31 | self.num_view = NUM_VIEW 32 | 33 | self.backbone = MinkUNet14D(in_channels=3, out_channels=self.seed_feature_dim, D=3) 34 | self.graspable = GraspableNet(seed_feature_dim=self.seed_feature_dim) 35 | self.rotation = ApproachNet(self.num_view, seed_feature_dim=self.seed_feature_dim, is_training=self.is_training) 36 | self.crop = CloudCrop(nsample=16, cylinder_radius=cylinder_radius, seed_feature_dim=self.seed_feature_dim) 37 | self.swad = SWADNet(num_angle=self.num_angle, num_depth=self.num_depth) 38 | 39 | def forward(self, end_points): 40 | seed_xyz = end_points['point_clouds'] # use all sampled point cloud, B*Ns*3 41 | B, point_num, _ = seed_xyz.shape # batch _size 42 | # point-wise features 43 | coordinates_batch = end_points['coors'] 44 | features_batch = end_points['feats'] 45 | mink_input = ME.SparseTensor(features_batch, coordinates=coordinates_batch) 46 | seed_features = self.backbone(mink_input).F 47 | seed_features = seed_features[end_points['quantize2original']].view(B, point_num, -1).transpose(1, 2) 48 | 49 | end_points = self.graspable(seed_features, end_points) 50 | seed_features_flipped = seed_features.transpose(1, 2) # B*Ns*feat_dim 51 | objectness_score = end_points['objectness_score'] 52 | graspness_score = end_points['graspness_score'].squeeze(1) 53 | objectness_pred = torch.argmax(objectness_score, 1) 54 | objectness_mask = (objectness_pred == 1) 55 | graspness_mask = graspness_score > GRASPNESS_THRESHOLD 56 | graspable_mask = objectness_mask & graspness_mask 57 | 58 | seed_features_graspable = [] 59 | seed_xyz_graspable = [] 60 | graspable_num_batch = 0. 
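        # Per-sample selection of graspable seeds: keep points that are both on an
        # object (objectness == 1) and above GRASPNESS_THRESHOLD, then run furthest
        # point sampling to pick M_points of them and gather their xyz / features.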
61 | for i in range(B): 62 | cur_mask = graspable_mask[i] 63 | graspable_num_batch += cur_mask.sum() 64 | cur_feat = seed_features_flipped[i][cur_mask] # Ns*feat_dim 65 | cur_seed_xyz = seed_xyz[i][cur_mask] # Ns*3 66 | 67 | cur_seed_xyz = cur_seed_xyz.unsqueeze(0) # 1*Ns*3 68 | fps_idxs = furthest_point_sample(cur_seed_xyz, self.M_points) 69 | cur_seed_xyz_flipped = cur_seed_xyz.transpose(1, 2).contiguous() # 1*3*Ns 70 | cur_seed_xyz = gather_operation(cur_seed_xyz_flipped, fps_idxs).transpose(1, 2).squeeze(0).contiguous() # Ns*3 71 | cur_feat_flipped = cur_feat.unsqueeze(0).transpose(1, 2).contiguous() # 1*feat_dim*Ns 72 | cur_feat = gather_operation(cur_feat_flipped, fps_idxs).squeeze(0).contiguous() # feat_dim*Ns 73 | 74 | seed_features_graspable.append(cur_feat) 75 | seed_xyz_graspable.append(cur_seed_xyz) 76 | seed_xyz_graspable = torch.stack(seed_xyz_graspable, 0) # B*Ns*3 77 | seed_features_graspable = torch.stack(seed_features_graspable) # B*feat_dim*Ns 78 | end_points['xyz_graspable'] = seed_xyz_graspable 79 | end_points['graspable_count_stage1'] = graspable_num_batch / B 80 | 81 | end_points, res_feat = self.rotation(seed_features_graspable, end_points) 82 | seed_features_graspable = seed_features_graspable + res_feat 83 | 84 | if self.is_training: 85 | end_points = process_grasp_labels(end_points) 86 | grasp_top_views_rot, end_points = match_grasp_view_and_label(end_points) 87 | else: 88 | grasp_top_views_rot = end_points['grasp_top_view_rot'] 89 | 90 | group_features = self.crop(seed_xyz_graspable.contiguous(), seed_features_graspable.contiguous(), grasp_top_views_rot) 91 | end_points = self.swad(group_features, end_points) 92 | 93 | return end_points 94 | 95 | 96 | def pred_decode(end_points): 97 | batch_size = len(end_points['point_clouds']) 98 | grasp_preds = [] 99 | for i in range(batch_size): 100 | grasp_center = end_points['xyz_graspable'][i].float() 101 | 102 | grasp_score = end_points['grasp_score_pred'][i].float() 103 | grasp_score = grasp_score.view(M_POINT, NUM_ANGLE*NUM_DEPTH) 104 | grasp_score, grasp_score_inds = torch.max(grasp_score, -1) # [M_POINT] 105 | grasp_score = grasp_score.view(-1, 1) 106 | grasp_angle = (grasp_score_inds // NUM_DEPTH) * np.pi / 12 107 | grasp_depth = (grasp_score_inds % NUM_DEPTH + 1) * 0.01 108 | grasp_depth = grasp_depth.view(-1, 1) 109 | grasp_width = 1.2 * end_points['grasp_width_pred'][i] / 10. 
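        # grasp_angle / grasp_depth above come from splitting the flat argmax index
        # into an angle bin (pi/12 per bin) and a depth bin (0.01 m per bin); below,
        # the width predicted for that same bin is gathered, mapped back from the
        # x10 scale used during training (x1.2 / 10) and clamped to GRASP_MAX_WIDTH.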
110 | grasp_width = grasp_width.view(M_POINT, NUM_ANGLE*NUM_DEPTH) 111 | grasp_width = torch.gather(grasp_width, 1, grasp_score_inds.view(-1, 1)) 112 | grasp_width = torch.clamp(grasp_width, min=0., max=GRASP_MAX_WIDTH) 113 | 114 | approaching = -end_points['grasp_top_view_xyz'][i].float() 115 | grasp_rot = batch_viewpoint_params_to_matrix(approaching, grasp_angle) 116 | grasp_rot = grasp_rot.view(M_POINT, 9) 117 | 118 | # merge preds 119 | grasp_height = 0.02 * torch.ones_like(grasp_score) 120 | obj_ids = -1 * torch.ones_like(grasp_score) 121 | grasp_preds.append( 122 | torch.cat([grasp_score, grasp_width, grasp_height, grasp_depth, grasp_rot, grasp_center, obj_ids], axis=-1)) 123 | return grasp_preds 124 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | def get_loss(end_points): 6 | objectness_loss, end_points = compute_objectness_loss(end_points) 7 | graspness_loss, end_points = compute_graspness_loss(end_points) 8 | view_loss, end_points = compute_view_graspness_loss(end_points) 9 | score_loss, end_points = compute_score_loss(end_points) 10 | width_loss, end_points = compute_width_loss(end_points) 11 | loss = objectness_loss + 10 * graspness_loss + 100 * view_loss + 15 * score_loss + 10 * width_loss 12 | end_points['loss/overall_loss'] = loss 13 | return loss, end_points 14 | 15 | 16 | def compute_objectness_loss(end_points): 17 | criterion = nn.CrossEntropyLoss(reduction='mean') 18 | objectness_score = end_points['objectness_score'] 19 | objectness_label = end_points['objectness_label'] 20 | loss = criterion(objectness_score, objectness_label) 21 | end_points['loss/stage1_objectness_loss'] = loss 22 | 23 | objectness_pred = torch.argmax(objectness_score, 1) 24 | end_points['stage1_objectness_acc'] = (objectness_pred == objectness_label.long()).float().mean() 25 | end_points['stage1_objectness_prec'] = (objectness_pred == objectness_label.long())[ 26 | objectness_pred == 1].float().mean() 27 | end_points['stage1_objectness_recall'] = (objectness_pred == objectness_label.long())[ 28 | objectness_label == 1].float().mean() 29 | return loss, end_points 30 | 31 | 32 | def compute_graspness_loss(end_points): 33 | criterion = nn.SmoothL1Loss(reduction='none') 34 | graspness_score = end_points['graspness_score'].squeeze(1) 35 | graspness_label = end_points['graspness_label'].squeeze(-1) 36 | loss_mask = end_points['objectness_label'].bool() 37 | loss = criterion(graspness_score, graspness_label) 38 | loss = loss[loss_mask] 39 | loss = loss.mean() 40 | 41 | graspness_score_c = graspness_score.detach().clone()[loss_mask] 42 | graspness_label_c = graspness_label.detach().clone()[loss_mask] 43 | graspness_score_c = torch.clamp(graspness_score_c, 0., 0.99) 44 | graspness_label_c = torch.clamp(graspness_label_c, 0., 0.99) 45 | rank_error = (torch.abs(torch.trunc(graspness_score_c * 20) - torch.trunc(graspness_label_c * 20)) / 20.).mean() 46 | end_points['stage1_graspness_acc_rank_error'] = rank_error 47 | 48 | end_points['loss/stage1_graspness_loss'] = loss 49 | return loss, end_points 50 | 51 | 52 | def compute_view_graspness_loss(end_points): 53 | criterion = nn.SmoothL1Loss(reduction='mean') 54 | view_score = end_points['view_score'] 55 | view_label = end_points['batch_grasp_view_graspness'] 56 | loss = criterion(view_score, view_label) 57 | end_points['loss/stage2_view_loss'] = loss 58 | return loss, end_points 
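# The two terms below supervise the SWAD head: the score loss is a plain SmoothL1
# over all (angle, depth) bins, while the width loss is masked so that only bins
# with a positive ground-truth score contribute to the gradient.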
59 | 60 | 61 | def compute_score_loss(end_points): 62 | criterion = nn.SmoothL1Loss(reduction='mean') 63 | grasp_score_pred = end_points['grasp_score_pred'] 64 | grasp_score_label = end_points['batch_grasp_score'] 65 | loss = criterion(grasp_score_pred, grasp_score_label) 66 | 67 | end_points['loss/stage3_score_loss'] = loss 68 | return loss, end_points 69 | 70 | 71 | def compute_width_loss(end_points): 72 | criterion = nn.SmoothL1Loss(reduction='none') 73 | grasp_width_pred = end_points['grasp_width_pred'] 74 | grasp_width_label = end_points['batch_grasp_width'] * 10 75 | loss = criterion(grasp_width_pred, grasp_width_label) 76 | grasp_score_label = end_points['batch_grasp_score'] 77 | loss_mask = grasp_score_label > 0 78 | loss = loss[loss_mask].mean() 79 | end_points['loss/stage3_width_loss'] = loss 80 | return loss, end_points 81 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 8 | ROOT_DIR = os.path.dirname(BASE_DIR) 9 | sys.path.append(ROOT_DIR) 10 | 11 | import pointnet2.pytorch_utils as pt_utils 12 | from pointnet2.pointnet2_utils import CylinderQueryAndGroup 13 | from loss_utils import generate_grasp_views, batch_viewpoint_params_to_matrix 14 | 15 | 16 | class GraspableNet(nn.Module): 17 | def __init__(self, seed_feature_dim): 18 | super().__init__() 19 | self.in_dim = seed_feature_dim 20 | self.conv_graspable = nn.Conv1d(self.in_dim, 3, 1) 21 | 22 | def forward(self, seed_features, end_points): 23 | graspable_score = self.conv_graspable(seed_features) # (B, 3, num_seed) 24 | end_points['objectness_score'] = graspable_score[:, :2] 25 | end_points['graspness_score'] = graspable_score[:, 2] 26 | return end_points 27 | 28 | 29 | class ApproachNet(nn.Module): 30 | def __init__(self, num_view, seed_feature_dim, is_training=True): 31 | super().__init__() 32 | self.num_view = num_view 33 | self.in_dim = seed_feature_dim 34 | self.is_training = is_training 35 | self.conv1 = nn.Conv1d(self.in_dim, self.in_dim, 1) 36 | self.conv2 = nn.Conv1d(self.in_dim, self.num_view, 1) 37 | 38 | def forward(self, seed_features, end_points): 39 | B, _, num_seed = seed_features.size() 40 | res_features = F.relu(self.conv1(seed_features), inplace=True) 41 | features = self.conv2(res_features) 42 | view_score = features.transpose(1, 2).contiguous() # (B, num_seed, num_view) 43 | end_points['view_score'] = view_score 44 | 45 | if self.is_training: 46 | # normalize view graspness score to 0~1 47 | view_score_ = view_score.clone().detach() 48 | view_score_max, _ = torch.max(view_score_, dim=2) 49 | view_score_min, _ = torch.min(view_score_, dim=2) 50 | view_score_max = view_score_max.unsqueeze(-1).expand(-1, -1, self.num_view) 51 | view_score_min = view_score_min.unsqueeze(-1).expand(-1, -1, self.num_view) 52 | view_score_ = (view_score_ - view_score_min) / (view_score_max - view_score_min + 1e-8) 53 | 54 | top_view_inds = [] 55 | for i in range(B): 56 | top_view_inds_batch = torch.multinomial(view_score_[i], 1, replacement=False) 57 | top_view_inds.append(top_view_inds_batch) 58 | top_view_inds = torch.stack(top_view_inds, dim=0).squeeze(-1) # B, num_seed 59 | else: 60 | _, top_view_inds = torch.max(view_score, dim=2) # (B, num_seed) 61 | 62 | top_view_inds_ = top_view_inds.view(B, num_seed, 1, 1).expand(-1, 
-1, -1, 3).contiguous() 63 | template_views = generate_grasp_views(self.num_view).to(features.device) # (num_view, 3) 64 | template_views = template_views.view(1, 1, self.num_view, 3).expand(B, num_seed, -1, -1).contiguous() 65 | vp_xyz = torch.gather(template_views, 2, top_view_inds_).squeeze(2) # (B, num_seed, 3) 66 | vp_xyz_ = vp_xyz.view(-1, 3) 67 | batch_angle = torch.zeros(vp_xyz_.size(0), dtype=vp_xyz.dtype, device=vp_xyz.device) 68 | vp_rot = batch_viewpoint_params_to_matrix(-vp_xyz_, batch_angle).view(B, num_seed, 3, 3) 69 | end_points['grasp_top_view_xyz'] = vp_xyz 70 | end_points['grasp_top_view_rot'] = vp_rot 71 | 72 | end_points['grasp_top_view_inds'] = top_view_inds 73 | return end_points, res_features 74 | 75 | 76 | class CloudCrop(nn.Module): 77 | def __init__(self, nsample, seed_feature_dim, cylinder_radius=0.05, hmin=-0.02, hmax=0.04): 78 | super().__init__() 79 | self.nsample = nsample 80 | self.in_dim = seed_feature_dim 81 | self.cylinder_radius = cylinder_radius 82 | mlps = [3 + self.in_dim, 256, 256] # use xyz, so plus 3 83 | 84 | self.grouper = CylinderQueryAndGroup(radius=cylinder_radius, hmin=hmin, hmax=hmax, nsample=nsample, 85 | use_xyz=True, normalize_xyz=True) 86 | self.mlps = pt_utils.SharedMLP(mlps, bn=True) 87 | 88 | def forward(self, seed_xyz_graspable, seed_features_graspable, vp_rot): 89 | grouped_feature = self.grouper(seed_xyz_graspable, seed_xyz_graspable, vp_rot, 90 | seed_features_graspable) # B*3 + feat_dim*M*K 91 | new_features = self.mlps(grouped_feature) # (batch_size, mlps[-1], M, K) 92 | new_features = F.max_pool2d(new_features, kernel_size=[1, new_features.size(3)]) # (batch_size, mlps[-1], M, 1) 93 | new_features = new_features.squeeze(-1) # (batch_size, mlps[-1], M) 94 | return new_features 95 | 96 | 97 | class SWADNet(nn.Module): 98 | def __init__(self, num_angle, num_depth): 99 | super().__init__() 100 | self.num_angle = num_angle 101 | self.num_depth = num_depth 102 | 103 | self.conv1 = nn.Conv1d(256, 256, 1) # input feat dim need to be consistent with CloudCrop module 104 | self.conv_swad = nn.Conv1d(256, 2*num_angle*num_depth, 1) 105 | 106 | def forward(self, vp_features, end_points): 107 | B, _, num_seed = vp_features.size() 108 | vp_features = F.relu(self.conv1(vp_features), inplace=True) 109 | vp_features = self.conv_swad(vp_features) 110 | vp_features = vp_features.view(B, 2, self.num_angle, self.num_depth, num_seed) 111 | vp_features = vp_features.permute(0, 1, 4, 2, 3) 112 | 113 | # split prediction 114 | end_points['grasp_score_pred'] = vp_features[:, 0] # B * num_seed * num angle * num_depth 115 | end_points['grasp_width_pred'] = vp_features[:, 1] 116 | return end_points 117 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | try: 4 | import open3d as o3d 5 | except ImportError: 6 | raise ImportError("Please install open3d with `pip install open3d`.") 7 | 8 | import MinkowskiEngine as ME 9 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 10 | 11 | 12 | class ResNetBase(nn.Module): 13 | BLOCK = None 14 | LAYERS = () 15 | INIT_DIM = 64 16 | PLANES = (64, 128, 256, 512) 17 | 18 | def __init__(self, in_channels, out_channels, D=3): 19 | nn.Module.__init__(self) 20 | self.D = D 21 | assert self.BLOCK is not None 22 | 23 | self.network_initialization(in_channels, out_channels, D) 24 | self.weight_initialization() 25 | 26 | def 
network_initialization(self, in_channels, out_channels, D): 27 | 28 | self.inplanes = self.INIT_DIM 29 | self.conv1 = nn.Sequential( 30 | ME.MinkowskiConvolution( 31 | in_channels, self.inplanes, kernel_size=3, stride=2, dimension=D 32 | ), 33 | ME.MinkowskiInstanceNorm(self.inplanes), 34 | ME.MinkowskiReLU(inplace=True), 35 | ME.MinkowskiMaxPooling(kernel_size=2, stride=2, dimension=D), 36 | ) 37 | 38 | self.layer1 = self._make_layer( 39 | self.BLOCK, self.PLANES[0], self.LAYERS[0], stride=2 40 | ) 41 | self.layer2 = self._make_layer( 42 | self.BLOCK, self.PLANES[1], self.LAYERS[1], stride=2 43 | ) 44 | self.layer3 = self._make_layer( 45 | self.BLOCK, self.PLANES[2], self.LAYERS[2], stride=2 46 | ) 47 | self.layer4 = self._make_layer( 48 | self.BLOCK, self.PLANES[3], self.LAYERS[3], stride=2 49 | ) 50 | 51 | self.conv5 = nn.Sequential( 52 | ME.MinkowskiDropout(), 53 | ME.MinkowskiConvolution( 54 | self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D 55 | ), 56 | ME.MinkowskiInstanceNorm(self.inplanes), 57 | ME.MinkowskiGELU(), 58 | ) 59 | 60 | self.glob_pool = ME.MinkowskiGlobalMaxPooling() 61 | 62 | self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True) 63 | 64 | def weight_initialization(self): 65 | for m in self.modules(): 66 | if isinstance(m, ME.MinkowskiConvolution): 67 | ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu") 68 | 69 | if isinstance(m, ME.MinkowskiBatchNorm): 70 | nn.init.constant_(m.bn.weight, 1) 71 | nn.init.constant_(m.bn.bias, 0) 72 | 73 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1): 74 | downsample = None 75 | if stride != 1 or self.inplanes != planes * block.expansion: 76 | downsample = nn.Sequential( 77 | ME.MinkowskiConvolution( 78 | self.inplanes, 79 | planes * block.expansion, 80 | kernel_size=1, 81 | stride=stride, 82 | dimension=self.D, 83 | ), 84 | ME.MinkowskiBatchNorm(planes * block.expansion), 85 | ) 86 | layers = [] 87 | layers.append( 88 | block( 89 | self.inplanes, 90 | planes, 91 | stride=stride, 92 | dilation=dilation, 93 | downsample=downsample, 94 | dimension=self.D, 95 | ) 96 | ) 97 | self.inplanes = planes * block.expansion 98 | for i in range(1, blocks): 99 | layers.append( 100 | block( 101 | self.inplanes, planes, stride=1, dilation=dilation, dimension=self.D 102 | ) 103 | ) 104 | 105 | return nn.Sequential(*layers) 106 | 107 | def forward(self, x: ME.SparseTensor): 108 | x = self.conv1(x) 109 | x = self.layer1(x) 110 | x = self.layer2(x) 111 | x = self.layer3(x) 112 | x = self.layer4(x) 113 | x = self.conv5(x) 114 | x = self.glob_pool(x) 115 | return self.final(x) 116 | 117 | 118 | class ResNet14(ResNetBase): 119 | BLOCK = BasicBlock 120 | LAYERS = (1, 1, 1, 1) 121 | 122 | 123 | class ResNet18(ResNetBase): 124 | BLOCK = BasicBlock 125 | LAYERS = (2, 2, 2, 2) 126 | 127 | 128 | class ResNet34(ResNetBase): 129 | BLOCK = BasicBlock 130 | LAYERS = (3, 4, 6, 3) 131 | 132 | 133 | class ResNet50(ResNetBase): 134 | BLOCK = Bottleneck 135 | LAYERS = (3, 4, 6, 3) 136 | 137 | 138 | class ResNet101(ResNetBase): 139 | BLOCK = Bottleneck 140 | LAYERS = (3, 4, 23, 3) 141 | 142 | 143 | class ResFieldNetBase(ResNetBase): 144 | def network_initialization(self, in_channels, out_channels, D): 145 | field_ch = 32 146 | field_ch2 = 64 147 | self.field_network = nn.Sequential( 148 | ME.MinkowskiSinusoidal(in_channels, field_ch), 149 | ME.MinkowskiBatchNorm(field_ch), 150 | ME.MinkowskiReLU(inplace=True), 151 | ME.MinkowskiLinear(field_ch, field_ch), 152 | 
ME.MinkowskiBatchNorm(field_ch), 153 | ME.MinkowskiReLU(inplace=True), 154 | ME.MinkowskiToSparseTensor(), 155 | ) 156 | self.field_network2 = nn.Sequential( 157 | ME.MinkowskiSinusoidal(field_ch + in_channels, field_ch2), 158 | ME.MinkowskiBatchNorm(field_ch2), 159 | ME.MinkowskiReLU(inplace=True), 160 | ME.MinkowskiLinear(field_ch2, field_ch2), 161 | ME.MinkowskiBatchNorm(field_ch2), 162 | ME.MinkowskiReLU(inplace=True), 163 | ME.MinkowskiToSparseTensor(), 164 | ) 165 | 166 | ResNetBase.network_initialization(self, field_ch2, out_channels, D) 167 | 168 | def forward(self, x: ME.TensorField): 169 | otensor = self.field_network(x) 170 | otensor2 = self.field_network2(otensor.cat_slice(x)) 171 | return ResNetBase.forward(self, otensor2) 172 | 173 | 174 | class ResFieldNet14(ResFieldNetBase): 175 | BLOCK = BasicBlock 176 | LAYERS = (1, 1, 1, 1) 177 | 178 | 179 | class ResFieldNet18(ResFieldNetBase): 180 | BLOCK = BasicBlock 181 | LAYERS = (2, 2, 2, 2) 182 | 183 | 184 | class ResFieldNet34(ResFieldNetBase): 185 | BLOCK = BasicBlock 186 | LAYERS = (3, 4, 6, 3) 187 | 188 | 189 | class ResFieldNet50(ResFieldNetBase): 190 | BLOCK = Bottleneck 191 | LAYERS = (3, 4, 6, 3) 192 | 193 | 194 | class ResFieldNet101(ResFieldNetBase): 195 | BLOCK = Bottleneck 196 | LAYERS = (3, 4, 23, 3) 197 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 10 | const int nsample); 11 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
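// Shared helpers for the CUDA kernels in this extension: opt_n_threads() picks a
// power-of-two thread count capped at TOTAL_THREADS (512) for the given amount of
// work, opt_block_config() builds a 2-D block from it, and CUDA_CHECK_ERRORS()
// prints the error string of the last failed launch and exits.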
5 | 6 | #ifndef _CUDA_UTILS_H 7 | #define _CUDA_UTILS_H 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | #define TOTAL_THREADS 512 19 | 20 | inline int opt_n_threads(int work_size) { 21 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 22 | 23 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 24 | } 25 | 26 | inline dim3 opt_block_config(int x, int y) { 27 | const int x_threads = opt_n_threads(x); 28 | const int y_threads = 29 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 30 | dim3 block_config(x_threads, y_threads, 1); 31 | 32 | return block_config; 33 | } 34 | 35 | #define CUDA_CHECK_ERRORS() \ 36 | do { \ 37 | cudaError_t err = cudaGetLastError(); \ 38 | if (cudaSuccess != err) { \ 39 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 40 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 41 | __FILE__); \ 42 | exit(-1); \ 43 | } \ 44 | } while (0) 45 | 46 | #endif 47 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/cylinder_query.h: -------------------------------------------------------------------------------- 1 | // Author: chenxi-wang 2 | 3 | #pragma once 4 | #include 5 | 6 | at::Tensor cylinder_query(at::Tensor new_xyz, at::Tensor xyz, at::Tensor rot, const float radius, const float hmin, const float hmax, 7 | const int nsample); 8 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/group_points.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 12 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 13 | at::Tensor weight); 14 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 15 | at::Tensor weight, const int m); 16 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/sampling.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
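// Declarations for point sampling and gathering: furthest_point_sampling() returns
// an int32 tensor of nsamples indices per batch, and gather_points() /
// gather_points_grad() index point features by those indices (forward and backward).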
5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 12 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | #include 9 | 10 | #define CHECK_CUDA(x) \ 11 | do { \ 12 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_CONTIGUOUS(x) \ 16 | do { \ 17 | TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ 18 | } while (0) 19 | 20 | #define CHECK_IS_INT(x) \ 21 | do { \ 22 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ 23 | #x " must be an int tensor"); \ 24 | } while (0) 25 | 26 | #define CHECK_IS_FLOAT(x) \ 27 | do { \ 28 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \ 29 | #x " must be a float tensor"); \ 30 | } while (0) 31 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "ball_query.h" 7 | #include "utils.h" 8 | 9 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 10 | int nsample, const float *new_xyz, 11 | const float *xyz, int *idx); 12 | 13 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 14 | const int nsample) { 15 | CHECK_CONTIGUOUS(new_xyz); 16 | CHECK_CONTIGUOUS(xyz); 17 | CHECK_IS_FLOAT(new_xyz); 18 | CHECK_IS_FLOAT(xyz); 19 | 20 | if (new_xyz.type().is_cuda()) { 21 | CHECK_CUDA(xyz); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 26 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 27 | 28 | if (new_xyz.type().is_cuda()) { 29 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 30 | radius, nsample, new_xyz.data(), 31 | xyz.data(), idx.data()); 32 | } else { 33 | TORCH_CHECK(false, "CPU not supported"); 34 | } 35 | 36 | return idx; 37 | } 38 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
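// Ball query kernel: one CUDA block per batch element; threads stride over the m
// query points and collect up to nsample neighbor indices within `radius`. When the
// first neighbor is found, its index is used to pre-fill the whole row, so queries
// with fewer than nsample neighbors are padded with a valid index rather than zeros.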
5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_utils.h" 11 | 12 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 13 | // output: idx(b, m, nsample) 14 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 15 | int nsample, 16 | const float *__restrict__ new_xyz, 17 | const float *__restrict__ xyz, 18 | int *__restrict__ idx) { 19 | int batch_index = blockIdx.x; 20 | xyz += batch_index * n * 3; 21 | new_xyz += batch_index * m * 3; 22 | idx += m * nsample * batch_index; 23 | 24 | int index = threadIdx.x; 25 | int stride = blockDim.x; 26 | 27 | float radius2 = radius * radius; 28 | for (int j = index; j < m; j += stride) { 29 | float new_x = new_xyz[j * 3 + 0]; 30 | float new_y = new_xyz[j * 3 + 1]; 31 | float new_z = new_xyz[j * 3 + 2]; 32 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 33 | float x = xyz[k * 3 + 0]; 34 | float y = xyz[k * 3 + 1]; 35 | float z = xyz[k * 3 + 2]; 36 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 37 | (new_z - z) * (new_z - z); 38 | if (d2 < radius2) { 39 | if (cnt == 0) { 40 | for (int l = 0; l < nsample; ++l) { 41 | idx[j * nsample + l] = k; 42 | } 43 | } 44 | idx[j * nsample + cnt] = k; 45 | ++cnt; 46 | } 47 | } 48 | } 49 | } 50 | 51 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 52 | int nsample, const float *new_xyz, 53 | const float *xyz, int *idx) { 54 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 55 | query_ball_point_kernel<<>>( 56 | b, n, m, radius, nsample, new_xyz, xyz, idx); 57 | 58 | CUDA_CHECK_ERRORS(); 59 | } 60 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
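// Pybind11 entry point for the pointnet2 extension: every op declared in the headers
// below is exposed to Python under the extension name chosen in pointnet2/setup.py.
// Python-side use is roughly (a sketch; the exact module path depends on how
// setup.py names the extension, commonly something like `pointnet2._ext`):
//   import pointnet2._ext as _ext
//   idx = _ext.ball_query(new_xyz, xyz, radius, nsample)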
5 | 6 | #include "ball_query.h" 7 | #include "group_points.h" 8 | #include "interpolate.h" 9 | #include "sampling.h" 10 | #include "cylinder_query.h" 11 | 12 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 13 | m.def("gather_points", &gather_points); 14 | m.def("gather_points_grad", &gather_points_grad); 15 | m.def("furthest_point_sampling", &furthest_point_sampling); 16 | 17 | m.def("three_nn", &three_nn); 18 | m.def("three_interpolate", &three_interpolate); 19 | m.def("three_interpolate_grad", &three_interpolate_grad); 20 | 21 | m.def("ball_query", &ball_query); 22 | 23 | m.def("group_points", &group_points); 24 | m.def("group_points_grad", &group_points_grad); 25 | 26 | m.def("cylinder_query", &cylinder_query); 27 | } 28 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/cylinder_query.cpp: -------------------------------------------------------------------------------- 1 | // Author: chenxi-wang 2 | 3 | #include "cylinder_query.h" 4 | #include "utils.h" 5 | 6 | void query_cylinder_point_kernel_wrapper(int b, int n, int m, float radius, float hmin, float hmax, 7 | int nsample, const float *new_xyz, 8 | const float *xyz, const float *rot, int *idx); 9 | 10 | at::Tensor cylinder_query(at::Tensor new_xyz, at::Tensor xyz, at::Tensor rot, const float radius, const float hmin, const float hmax, 11 | const int nsample) { 12 | CHECK_CONTIGUOUS(new_xyz); 13 | CHECK_CONTIGUOUS(xyz); 14 | CHECK_CONTIGUOUS(rot); 15 | CHECK_IS_FLOAT(new_xyz); 16 | CHECK_IS_FLOAT(xyz); 17 | CHECK_IS_FLOAT(rot); 18 | 19 | if (new_xyz.type().is_cuda()) { 20 | CHECK_CUDA(xyz); 21 | CHECK_CUDA(rot); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 26 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 27 | 28 | if (new_xyz.type().is_cuda()) { 29 | query_cylinder_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 30 | radius, hmin, hmax, nsample, new_xyz.data(), 31 | xyz.data(), rot.data(), idx.data()); 32 | } else { 33 | TORCH_CHECK(false, "CPU not supported"); 34 | } 35 | 36 | return idx; 37 | } 38 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/cylinder_query_gpu.cu: -------------------------------------------------------------------------------- 1 | // Author: chenxi-wang 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "cuda_utils.h" 8 | 9 | __global__ void query_cylinder_point_kernel(int b, int n, int m, float radius, float hmin, float hmax, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | const float *__restrict__ rot, 14 | int *__restrict__ idx) { 15 | int batch_index = blockIdx.x; 16 | xyz += batch_index * n * 3; 17 | new_xyz += batch_index * m * 3; 18 | rot += batch_index * m * 9; 19 | idx += m * nsample * batch_index; 20 | 21 | int index = threadIdx.x; 22 | int stride = blockDim.x; 23 | 24 | float radius2 = radius * radius; 25 | for (int j = index; j < m; j += stride) { 26 | float new_x = new_xyz[j * 3 + 0]; 27 | float new_y = new_xyz[j * 3 + 1]; 28 | float new_z = new_xyz[j * 3 + 2]; 29 | float r0 = rot[j * 9 + 0]; 30 | float r1 = rot[j * 9 + 1]; 31 | float r2 = rot[j * 9 + 2]; 32 | float r3 = rot[j * 9 + 3]; 33 | float r4 = rot[j * 9 + 4]; 34 | float r5 = rot[j * 9 + 5]; 35 | float r6 = rot[j * 9 + 6]; 36 | float r7 = rot[j * 9 + 7]; 37 | float r8 = rot[j * 9 + 8]; 38 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 39 | float x = xyz[k * 3 + 0] - new_x; 40 | 
float y = xyz[k * 3 + 1] - new_y; 41 | float z = xyz[k * 3 + 2] - new_z; 42 | float x_rot = r0 * x + r3 * y + r6 * z; 43 | float y_rot = r1 * x + r4 * y + r7 * z; 44 | float z_rot = r2 * x + r5 * y + r8 * z; 45 | float d2 = y_rot * y_rot + z_rot * z_rot; 46 | if (d2 < radius2 && x_rot > hmin && x_rot < hmax) { 47 | if (cnt == 0) { 48 | for (int l = 0; l < nsample; ++l) { 49 | idx[j * nsample + l] = k; 50 | } 51 | } 52 | idx[j * nsample + cnt] = k; 53 | ++cnt; 54 | } 55 | } 56 | } 57 | } 58 | 59 | void query_cylinder_point_kernel_wrapper(int b, int n, int m, float radius, float hmin, float hmax, 60 | int nsample, const float *new_xyz, 61 | const float *xyz, const float *rot, int *idx) { 62 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 63 | query_cylinder_point_kernel<<>>( 64 | b, n, m, radius, hmin, hmax, nsample, new_xyz, xyz, rot, idx); 65 | 66 | CUDA_CHECK_ERRORS(); 67 | } 68 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "group_points.h" 7 | #include "utils.h" 8 | 9 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 10 | const float *points, const int *idx, 11 | float *out); 12 | 13 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 14 | int nsample, const float *grad_out, 15 | const int *idx, float *grad_points); 16 | 17 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 18 | CHECK_CONTIGUOUS(points); 19 | CHECK_CONTIGUOUS(idx); 20 | CHECK_IS_FLOAT(points); 21 | CHECK_IS_INT(idx); 22 | 23 | if (points.type().is_cuda()) { 24 | CHECK_CUDA(idx); 25 | } 26 | 27 | at::Tensor output = 28 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 29 | at::device(points.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (points.type().is_cuda()) { 32 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 33 | idx.size(1), idx.size(2), points.data(), 34 | idx.data(), output.data()); 35 | } else { 36 | TORCH_CHECK(false, "CPU not supported"); 37 | } 38 | 39 | return output; 40 | } 41 | 42 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 43 | CHECK_CONTIGUOUS(grad_out); 44 | CHECK_CONTIGUOUS(idx); 45 | CHECK_IS_FLOAT(grad_out); 46 | CHECK_IS_INT(idx); 47 | 48 | if (grad_out.type().is_cuda()) { 49 | CHECK_CUDA(idx); 50 | } 51 | 52 | at::Tensor output = 53 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 54 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 55 | 56 | if (grad_out.type().is_cuda()) { 57 | group_points_grad_kernel_wrapper( 58 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 59 | grad_out.data(), idx.data(), output.data()); 60 | } else { 61 | TORCH_CHECK(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
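// Grouping kernels: the forward pass copies features from points(b, c, n) into
// out(b, c, npoints, nsample) following idx, and the backward pass scatters the
// gradients back with atomicAdd so that duplicate indices accumulate correctly.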
5 | 6 | #include 7 | #include 8 | 9 | #include "cuda_utils.h" 10 | 11 | // input: points(b, c, n) idx(b, npoints, nsample) 12 | // output: out(b, c, npoints, nsample) 13 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 14 | int nsample, 15 | const float *__restrict__ points, 16 | const int *__restrict__ idx, 17 | float *__restrict__ out) { 18 | int batch_index = blockIdx.x; 19 | points += batch_index * n * c; 20 | idx += batch_index * npoints * nsample; 21 | out += batch_index * npoints * nsample * c; 22 | 23 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 24 | const int stride = blockDim.y * blockDim.x; 25 | for (int i = index; i < c * npoints; i += stride) { 26 | const int l = i / npoints; 27 | const int j = i % npoints; 28 | for (int k = 0; k < nsample; ++k) { 29 | int ii = idx[j * nsample + k]; 30 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 31 | } 32 | } 33 | } 34 | 35 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 36 | const float *points, const int *idx, 37 | float *out) { 38 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 39 | 40 | group_points_kernel<<>>( 41 | b, c, n, npoints, nsample, points, idx, out); 42 | 43 | CUDA_CHECK_ERRORS(); 44 | } 45 | 46 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 47 | // output: grad_points(b, c, n) 48 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 49 | int nsample, 50 | const float *__restrict__ grad_out, 51 | const int *__restrict__ idx, 52 | float *__restrict__ grad_points) { 53 | int batch_index = blockIdx.x; 54 | grad_out += batch_index * npoints * nsample * c; 55 | idx += batch_index * npoints * nsample; 56 | grad_points += batch_index * n * c; 57 | 58 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 59 | const int stride = blockDim.y * blockDim.x; 60 | for (int i = index; i < c * npoints; i += stride) { 61 | const int l = i / npoints; 62 | const int j = i % npoints; 63 | for (int k = 0; k < nsample; ++k) { 64 | int ii = idx[j * nsample + k]; 65 | atomicAdd(grad_points + l * n + ii, 66 | grad_out[(l * npoints + j) * nsample + k]); 67 | } 68 | } 69 | } 70 | 71 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 72 | int nsample, const float *grad_out, 73 | const int *idx, float *grad_points) { 74 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 75 | 76 | group_points_grad_kernel<<>>( 77 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 78 | 79 | CUDA_CHECK_ERRORS(); 80 | } 81 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include "interpolate.h" 7 | #include "utils.h" 8 | 9 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 10 | const float *known, float *dist2, int *idx); 11 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 12 | const float *points, const int *idx, 13 | const float *weight, float *out); 14 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 15 | const float *grad_out, 16 | const int *idx, const float *weight, 17 | float *grad_points); 18 | 19 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 20 | CHECK_CONTIGUOUS(unknowns); 21 | CHECK_CONTIGUOUS(knows); 22 | CHECK_IS_FLOAT(unknowns); 23 | CHECK_IS_FLOAT(knows); 24 | 25 | if (unknowns.type().is_cuda()) { 26 | CHECK_CUDA(knows); 27 | } 28 | 29 | at::Tensor idx = 30 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 31 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 32 | at::Tensor dist2 = 33 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 34 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 35 | 36 | if (unknowns.type().is_cuda()) { 37 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 38 | unknowns.data(), knows.data(), 39 | dist2.data(), idx.data()); 40 | } else { 41 | TORCH_CHECK(false, "CPU not supported"); 42 | } 43 | 44 | return {dist2, idx}; 45 | } 46 | 47 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 48 | at::Tensor weight) { 49 | CHECK_CONTIGUOUS(points); 50 | CHECK_CONTIGUOUS(idx); 51 | CHECK_CONTIGUOUS(weight); 52 | CHECK_IS_FLOAT(points); 53 | CHECK_IS_INT(idx); 54 | CHECK_IS_FLOAT(weight); 55 | 56 | if (points.type().is_cuda()) { 57 | CHECK_CUDA(idx); 58 | CHECK_CUDA(weight); 59 | } 60 | 61 | at::Tensor output = 62 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 63 | at::device(points.device()).dtype(at::ScalarType::Float)); 64 | 65 | if (points.type().is_cuda()) { 66 | three_interpolate_kernel_wrapper( 67 | points.size(0), points.size(1), points.size(2), idx.size(1), 68 | points.data(), idx.data(), weight.data(), 69 | output.data()); 70 | } else { 71 | TORCH_CHECK(false, "CPU not supported"); 72 | } 73 | 74 | return output; 75 | } 76 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 77 | at::Tensor weight, const int m) { 78 | CHECK_CONTIGUOUS(grad_out); 79 | CHECK_CONTIGUOUS(idx); 80 | CHECK_CONTIGUOUS(weight); 81 | CHECK_IS_FLOAT(grad_out); 82 | CHECK_IS_INT(idx); 83 | CHECK_IS_FLOAT(weight); 84 | 85 | if (grad_out.type().is_cuda()) { 86 | CHECK_CUDA(idx); 87 | CHECK_CUDA(weight); 88 | } 89 | 90 | at::Tensor output = 91 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 92 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 93 | 94 | if (grad_out.type().is_cuda()) { 95 | three_interpolate_grad_kernel_wrapper( 96 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 97 | grad_out.data(), idx.data(), weight.data(), 98 | output.data()); 99 | } else { 100 | TORCH_CHECK(false, "CPU not supported"); 101 | } 102 | 103 | return output; 104 | } 105 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
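// Feature-propagation kernels: three_nn finds the three nearest known points for
// every unknown point (squared distances plus indices), three_interpolate blends
// their features with caller-supplied weights, and the grad kernel scatters the
// incoming gradient back to the known points with atomicAdd.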
5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_utils.h" 11 | 12 | // input: unknown(b, n, 3) known(b, m, 3) 13 | // output: dist2(b, n, 3), idx(b, n, 3) 14 | __global__ void three_nn_kernel(int b, int n, int m, 15 | const float *__restrict__ unknown, 16 | const float *__restrict__ known, 17 | float *__restrict__ dist2, 18 | int *__restrict__ idx) { 19 | int batch_index = blockIdx.x; 20 | unknown += batch_index * n * 3; 21 | known += batch_index * m * 3; 22 | dist2 += batch_index * n * 3; 23 | idx += batch_index * n * 3; 24 | 25 | int index = threadIdx.x; 26 | int stride = blockDim.x; 27 | for (int j = index; j < n; j += stride) { 28 | float ux = unknown[j * 3 + 0]; 29 | float uy = unknown[j * 3 + 1]; 30 | float uz = unknown[j * 3 + 2]; 31 | 32 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 33 | int besti1 = 0, besti2 = 0, besti3 = 0; 34 | for (int k = 0; k < m; ++k) { 35 | float x = known[k * 3 + 0]; 36 | float y = known[k * 3 + 1]; 37 | float z = known[k * 3 + 2]; 38 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 39 | if (d < best1) { 40 | best3 = best2; 41 | besti3 = besti2; 42 | best2 = best1; 43 | besti2 = besti1; 44 | best1 = d; 45 | besti1 = k; 46 | } else if (d < best2) { 47 | best3 = best2; 48 | besti3 = besti2; 49 | best2 = d; 50 | besti2 = k; 51 | } else if (d < best3) { 52 | best3 = d; 53 | besti3 = k; 54 | } 55 | } 56 | dist2[j * 3 + 0] = best1; 57 | dist2[j * 3 + 1] = best2; 58 | dist2[j * 3 + 2] = best3; 59 | 60 | idx[j * 3 + 0] = besti1; 61 | idx[j * 3 + 1] = besti2; 62 | idx[j * 3 + 2] = besti3; 63 | } 64 | } 65 | 66 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 67 | const float *known, float *dist2, int *idx) { 68 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 69 | three_nn_kernel<<>>(b, n, m, unknown, known, 70 | dist2, idx); 71 | 72 | CUDA_CHECK_ERRORS(); 73 | } 74 | 75 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 76 | // output: out(b, c, n) 77 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 78 | const float *__restrict__ points, 79 | const int *__restrict__ idx, 80 | const float *__restrict__ weight, 81 | float *__restrict__ out) { 82 | int batch_index = blockIdx.x; 83 | points += batch_index * m * c; 84 | 85 | idx += batch_index * n * 3; 86 | weight += batch_index * n * 3; 87 | 88 | out += batch_index * n * c; 89 | 90 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 91 | const int stride = blockDim.y * blockDim.x; 92 | for (int i = index; i < c * n; i += stride) { 93 | const int l = i / n; 94 | const int j = i % n; 95 | float w1 = weight[j * 3 + 0]; 96 | float w2 = weight[j * 3 + 1]; 97 | float w3 = weight[j * 3 + 2]; 98 | 99 | int i1 = idx[j * 3 + 0]; 100 | int i2 = idx[j * 3 + 1]; 101 | int i3 = idx[j * 3 + 2]; 102 | 103 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 104 | points[l * m + i3] * w3; 105 | } 106 | } 107 | 108 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 109 | const float *points, const int *idx, 110 | const float *weight, float *out) { 111 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 112 | three_interpolate_kernel<<>>( 113 | b, c, m, n, points, idx, weight, out); 114 | 115 | CUDA_CHECK_ERRORS(); 116 | } 117 | 118 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 119 | // output: grad_points(b, c, m) 120 | 121 | __global__ void three_interpolate_grad_kernel( 122 | int b, int c, int n, int m, const float *__restrict__ grad_out, 123 | const int 
*__restrict__ idx, const float *__restrict__ weight,
124 | float *__restrict__ grad_points) {
125 | int batch_index = blockIdx.x;
126 | grad_out += batch_index * n * c;
127 | idx += batch_index * n * 3;
128 | weight += batch_index * n * 3;
129 | grad_points += batch_index * m * c;
130 | 
131 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
132 | const int stride = blockDim.y * blockDim.x;
133 | for (int i = index; i < c * n; i += stride) {
134 | const int l = i / n;
135 | const int j = i % n;
136 | float w1 = weight[j * 3 + 0];
137 | float w2 = weight[j * 3 + 1];
138 | float w3 = weight[j * 3 + 2];
139 | 
140 | int i1 = idx[j * 3 + 0];
141 | int i2 = idx[j * 3 + 1];
142 | int i3 = idx[j * 3 + 2];
143 | 
144 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
145 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
146 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
147 | }
148 | }
149 | 
150 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
151 | const float *grad_out,
152 | const int *idx, const float *weight,
153 | float *grad_points) {
154 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
155 | three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
156 | b, c, n, m, grad_out, idx, weight, grad_points);
157 | 
158 | CUDA_CHECK_ERRORS();
159 | }
160 | 
--------------------------------------------------------------------------------
/pointnet2/_ext_src/src/sampling.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 | 
6 | #include "sampling.h"
7 | #include "utils.h"
8 | 
9 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
10 | const float *points, const int *idx,
11 | float *out);
12 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
13 | const float *grad_out, const int *idx,
14 | float *grad_points);
15 | 
16 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
17 | const float *dataset, float *temp,
18 | int *idxs);
19 | 
20 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) {
21 | CHECK_CONTIGUOUS(points);
22 | CHECK_CONTIGUOUS(idx);
23 | CHECK_IS_FLOAT(points);
24 | CHECK_IS_INT(idx);
25 | 
26 | if (points.type().is_cuda()) {
27 | CHECK_CUDA(idx);
28 | }
29 | 
30 | at::Tensor output =
31 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
32 | at::device(points.device()).dtype(at::ScalarType::Float));
33 | 
34 | if (points.type().is_cuda()) {
35 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
36 | idx.size(1), points.data<float>(),
37 | idx.data<int>(), output.data<float>());
38 | } else {
39 | TORCH_CHECK(false, "CPU not supported");
40 | }
41 | 
42 | return output;
43 | }
44 | 
45 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx,
46 | const int n) {
47 | CHECK_CONTIGUOUS(grad_out);
48 | CHECK_CONTIGUOUS(idx);
49 | CHECK_IS_FLOAT(grad_out);
50 | CHECK_IS_INT(idx);
51 | 
52 | if (grad_out.type().is_cuda()) {
53 | CHECK_CUDA(idx);
54 | }
55 | 
56 | at::Tensor output =
57 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
58 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
59 | 
60 | if (grad_out.type().is_cuda()) {
61 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n,
62 | idx.size(1), grad_out.data<float>(),
63 | idx.data<int>(), output.data<float>());
64 | } else {
65 | TORCH_CHECK(false, "CPU not
supported"); 66 | } 67 | 68 | return output; 69 | } 70 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 71 | CHECK_CONTIGUOUS(points); 72 | CHECK_IS_FLOAT(points); 73 | 74 | at::Tensor output = 75 | torch::zeros({points.size(0), nsamples}, 76 | at::device(points.device()).dtype(at::ScalarType::Int)); 77 | 78 | at::Tensor tmp = 79 | torch::full({points.size(0), points.size(1)}, 1e10, 80 | at::device(points.device()).dtype(at::ScalarType::Float)); 81 | 82 | if (points.type().is_cuda()) { 83 | furthest_point_sampling_kernel_wrapper( 84 | points.size(0), points.size(1), nsamples, points.data(), 85 | tmp.data(), output.data()); 86 | } else { 87 | TORCH_CHECK(false, "CPU not supported"); 88 | } 89 | 90 | return output; 91 | } 92 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | #include 8 | 9 | #include "cuda_utils.h" 10 | 11 | // input: points(b, c, n) idx(b, m) 12 | // output: out(b, c, m) 13 | __global__ void gather_points_kernel(int b, int c, int n, int m, 14 | const float *__restrict__ points, 15 | const int *__restrict__ idx, 16 | float *__restrict__ out) { 17 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 18 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 19 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 20 | int a = idx[i * m + j]; 21 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 22 | } 23 | } 24 | } 25 | } 26 | 27 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 28 | const float *points, const int *idx, 29 | float *out) { 30 | gather_points_kernel<<>>(b, c, n, npoints, 32 | points, idx, out); 33 | 34 | CUDA_CHECK_ERRORS(); 35 | } 36 | 37 | // input: grad_out(b, c, m) idx(b, m) 38 | // output: grad_points(b, c, n) 39 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 40 | const float *__restrict__ grad_out, 41 | const int *__restrict__ idx, 42 | float *__restrict__ grad_points) { 43 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 44 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 45 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 46 | int a = idx[i * m + j]; 47 | atomicAdd(grad_points + (i * c + l) * n + a, 48 | grad_out[(i * c + l) * m + j]); 49 | } 50 | } 51 | } 52 | } 53 | 54 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 55 | const float *grad_out, const int *idx, 56 | float *grad_points) { 57 | gather_points_grad_kernel<<>>( 59 | b, c, n, npoints, grad_out, idx, grad_points); 60 | 61 | CUDA_CHECK_ERRORS(); 62 | } 63 | 64 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 65 | int idx1, int idx2) { 66 | const float v1 = dists[idx1], v2 = dists[idx2]; 67 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 68 | dists[idx1] = max(v1, v2); 69 | dists_i[idx1] = v2 > v1 ? 
i2 : i1;
70 | }
71 | 
72 | // Input dataset: (b, n, 3), tmp: (b, n)
73 | // Output idxs (b, m)
74 | template <unsigned int block_size>
75 | __global__ void furthest_point_sampling_kernel(
76 | int b, int n, int m, const float *__restrict__ dataset,
77 | float *__restrict__ temp, int *__restrict__ idxs) {
78 | if (m <= 0) return;
79 | __shared__ float dists[block_size];
80 | __shared__ int dists_i[block_size];
81 | 
82 | int batch_index = blockIdx.x;
83 | dataset += batch_index * n * 3;
84 | temp += batch_index * n;
85 | idxs += batch_index * m;
86 | 
87 | int tid = threadIdx.x;
88 | const int stride = block_size;
89 | 
90 | int old = 0;
91 | if (threadIdx.x == 0) idxs[0] = old;
92 | 
93 | __syncthreads();
94 | for (int j = 1; j < m; j++) {
95 | int besti = 0;
96 | float best = -1;
97 | float x1 = dataset[old * 3 + 0];
98 | float y1 = dataset[old * 3 + 1];
99 | float z1 = dataset[old * 3 + 2];
100 | for (int k = tid; k < n; k += stride) {
101 | float x2, y2, z2;
102 | x2 = dataset[k * 3 + 0];
103 | y2 = dataset[k * 3 + 1];
104 | z2 = dataset[k * 3 + 2];
105 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
106 | if (mag <= 1e-3) continue;
107 | 
108 | float d =
109 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
110 | 
111 | float d2 = min(d, temp[k]);
112 | temp[k] = d2;
113 | besti = d2 > best ? k : besti;
114 | best = d2 > best ? d2 : best;
115 | }
116 | dists[tid] = best;
117 | dists_i[tid] = besti;
118 | __syncthreads();
119 | 
120 | if (block_size >= 512) {
121 | if (tid < 256) {
122 | __update(dists, dists_i, tid, tid + 256);
123 | }
124 | __syncthreads();
125 | }
126 | if (block_size >= 256) {
127 | if (tid < 128) {
128 | __update(dists, dists_i, tid, tid + 128);
129 | }
130 | __syncthreads();
131 | }
132 | if (block_size >= 128) {
133 | if (tid < 64) {
134 | __update(dists, dists_i, tid, tid + 64);
135 | }
136 | __syncthreads();
137 | }
138 | if (block_size >= 64) {
139 | if (tid < 32) {
140 | __update(dists, dists_i, tid, tid + 32);
141 | }
142 | __syncthreads();
143 | }
144 | if (block_size >= 32) {
145 | if (tid < 16) {
146 | __update(dists, dists_i, tid, tid + 16);
147 | }
148 | __syncthreads();
149 | }
150 | if (block_size >= 16) {
151 | if (tid < 8) {
152 | __update(dists, dists_i, tid, tid + 8);
153 | }
154 | __syncthreads();
155 | }
156 | if (block_size >= 8) {
157 | if (tid < 4) {
158 | __update(dists, dists_i, tid, tid + 4);
159 | }
160 | __syncthreads();
161 | }
162 | if (block_size >= 4) {
163 | if (tid < 2) {
164 | __update(dists, dists_i, tid, tid + 2);
165 | }
166 | __syncthreads();
167 | }
168 | if (block_size >= 2) {
169 | if (tid < 1) {
170 | __update(dists, dists_i, tid, tid + 1);
171 | }
172 | __syncthreads();
173 | }
174 | 
175 | old = dists_i[0];
176 | if (tid == 0) idxs[j] = old;
177 | }
178 | }
179 | 
180 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
181 | const float *dataset, float *temp,
182 | int *idxs) {
183 | unsigned int n_threads = opt_n_threads(n);
184 | 
185 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
186 | 
187 | switch (n_threads) {
188 | case 512:
189 | furthest_point_sampling_kernel<512>
190 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
191 | break;
192 | case 256:
193 | furthest_point_sampling_kernel<256>
194 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
195 | break;
196 | case 128:
197 | furthest_point_sampling_kernel<128>
198 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
199 | break;
200 | case 64:
201 | furthest_point_sampling_kernel<64>
202 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
203 | break;
204 | case 32:
205 | 
furthest_point_sampling_kernel<32>
206 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
207 | break;
208 | case 16:
209 | furthest_point_sampling_kernel<16>
210 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
211 | break;
212 | case 8:
213 | furthest_point_sampling_kernel<8>
214 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
215 | break;
216 | case 4:
217 | furthest_point_sampling_kernel<4>
218 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
219 | break;
220 | case 2:
221 | furthest_point_sampling_kernel<2>
222 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
223 | break;
224 | case 1:
225 | furthest_point_sampling_kernel<1>
226 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
227 | break;
228 | default:
229 | furthest_point_sampling_kernel<512>
230 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
231 | }
232 | 
233 | CUDA_CHECK_ERRORS();
234 | }
235 | 
--------------------------------------------------------------------------------
/pointnet2/pointnet2_modules.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 
6 | ''' Pointnet2 layers.
7 | Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch
8 | Extended with the following:
9 | 1. Uniform sampling in each local region (sample_uniformly)
10 | 2. Return sampled points indices to support votenet.
11 | '''
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | 
16 | import os
17 | import sys
18 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
19 | sys.path.append(BASE_DIR)
20 | 
21 | import pointnet2_utils
22 | import pytorch_utils as pt_utils
23 | from typing import List
24 | 
25 | 
26 | class _PointnetSAModuleBase(nn.Module):
27 | 
28 | def __init__(self):
29 | super().__init__()
30 | self.npoint = None
31 | self.groupers = None
32 | self.mlps = None
33 | 
34 | def forward(self, xyz: torch.Tensor,
35 | features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
36 | r"""
37 | Parameters
38 | ----------
39 | xyz : torch.Tensor
40 | (B, N, 3) tensor of the xyz coordinates of the features
41 | features : torch.Tensor
42 | (B, N, C) tensor of the descriptors of the features
43 | 
44 | Returns
45 | -------
46 | new_xyz : torch.Tensor
47 | (B, npoint, 3) tensor of the new features' xyz
48 | new_features : torch.Tensor
49 | (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
50 | """
51 | 
52 | new_features_list = []
53 | 
54 | xyz_flipped = xyz.transpose(1, 2).contiguous()
55 | new_xyz = pointnet2_utils.gather_operation(
56 | xyz_flipped,
57 | pointnet2_utils.furthest_point_sample(xyz, self.npoint)
58 | ).transpose(1, 2).contiguous() if self.npoint is not None else None
59 | 
60 | for i in range(len(self.groupers)):
61 | new_features = self.groupers[i](
62 | xyz, new_xyz, features
63 | ) # (B, C, npoint, nsample)
64 | 
65 | new_features = self.mlps[i](
66 | new_features
67 | ) # (B, mlp[-1], npoint, nsample)
68 | new_features = F.max_pool2d(
69 | new_features, kernel_size=[1, new_features.size(3)]
70 | ) # (B, mlp[-1], npoint, 1)
71 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
72 | 
73 | new_features_list.append(new_features)
74 | 
75 | return new_xyz, torch.cat(new_features_list, dim=1)
76 | 
77 | 
78 | class PointnetSAModuleMSG(_PointnetSAModuleBase):
79 | r"""Pointnet set abstraction layer with multiscale grouping
80 | 
81 | Parameters
82 | ----------
83 | npoint : int
84 | Number of features
85 | radii :
list of float32 86 | list of radii to group with 87 | nsamples : list of int32 88 | Number of samples in each ball query 89 | mlps : list of list of int32 90 | Spec of the pointnet before the global max_pool for each scale 91 | bn : bool 92 | Use batchnorm 93 | """ 94 | 95 | def __init__( 96 | self, 97 | *, 98 | npoint: int, 99 | radii: List[float], 100 | nsamples: List[int], 101 | mlps: List[List[int]], 102 | bn: bool = True, 103 | use_xyz: bool = True, 104 | sample_uniformly: bool = False 105 | ): 106 | super().__init__() 107 | 108 | assert len(radii) == len(nsamples) == len(mlps) 109 | 110 | self.npoint = npoint 111 | self.groupers = nn.ModuleList() 112 | self.mlps = nn.ModuleList() 113 | for i in range(len(radii)): 114 | radius = radii[i] 115 | nsample = nsamples[i] 116 | self.groupers.append( 117 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz, sample_uniformly=sample_uniformly) 118 | if npoint is not None else pointnet2_utils.GroupAll(use_xyz) 119 | ) 120 | mlp_spec = mlps[i] 121 | if use_xyz: 122 | mlp_spec[0] += 3 123 | 124 | self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn)) 125 | 126 | 127 | class PointnetSAModule(PointnetSAModuleMSG): 128 | r"""Pointnet set abstrction layer 129 | 130 | Parameters 131 | ---------- 132 | npoint : int 133 | Number of features 134 | radius : float 135 | Radius of ball 136 | nsample : int 137 | Number of samples in the ball query 138 | mlp : list 139 | Spec of the pointnet before the global max_pool 140 | bn : bool 141 | Use batchnorm 142 | """ 143 | 144 | def __init__( 145 | self, 146 | *, 147 | mlp: List[int], 148 | npoint: int = None, 149 | radius: float = None, 150 | nsample: int = None, 151 | bn: bool = True, 152 | use_xyz: bool = True 153 | ): 154 | super().__init__( 155 | mlps=[mlp], 156 | npoint=npoint, 157 | radii=[radius], 158 | nsamples=[nsample], 159 | bn=bn, 160 | use_xyz=use_xyz 161 | ) 162 | 163 | 164 | class PointnetSAModuleVotes(nn.Module): 165 | ''' Modified based on _PointnetSAModuleBase and PointnetSAModuleMSG 166 | with extra support for returning point indices for getting their GT votes ''' 167 | 168 | def __init__( 169 | self, 170 | *, 171 | mlp: List[int], 172 | npoint: int = None, 173 | radius: float = None, 174 | nsample: int = None, 175 | bn: bool = True, 176 | use_xyz: bool = True, 177 | pooling: str = 'max', 178 | sigma: float = None, # for RBF pooling 179 | normalize_xyz: bool = False, # noramlize local XYZ with radius 180 | sample_uniformly: bool = False, 181 | ret_unique_cnt: bool = False 182 | ): 183 | super().__init__() 184 | 185 | self.npoint = npoint 186 | self.radius = radius 187 | self.nsample = nsample 188 | self.pooling = pooling 189 | self.mlp_module = None 190 | self.use_xyz = use_xyz 191 | self.sigma = sigma 192 | if self.sigma is None: 193 | self.sigma = self.radius/2 194 | self.normalize_xyz = normalize_xyz 195 | self.ret_unique_cnt = ret_unique_cnt 196 | 197 | if npoint is not None: 198 | self.grouper = pointnet2_utils.QueryAndGroup(radius, nsample, 199 | use_xyz=use_xyz, ret_grouped_xyz=True, normalize_xyz=normalize_xyz, 200 | sample_uniformly=sample_uniformly, ret_unique_cnt=ret_unique_cnt) 201 | else: 202 | self.grouper = pointnet2_utils.GroupAll(use_xyz, ret_grouped_xyz=True) 203 | 204 | mlp_spec = mlp 205 | if use_xyz and len(mlp_spec)>0: 206 | mlp_spec[0] += 3 207 | self.mlp_module = pt_utils.SharedMLP(mlp_spec, bn=bn) 208 | 209 | 210 | def forward(self, xyz: torch.Tensor, 211 | features: torch.Tensor = None, 212 | inds: torch.Tensor = None) -> (torch.Tensor, torch.Tensor): 213 | 
r""" 214 | Parameters 215 | ---------- 216 | xyz : torch.Tensor 217 | (B, N, 3) tensor of the xyz coordinates of the features 218 | features : torch.Tensor 219 | (B, C, N) tensor of the descriptors of the the features 220 | inds : torch.Tensor 221 | (B, npoint) tensor that stores index to the xyz points (values in 0-N-1) 222 | 223 | Returns 224 | ------- 225 | new_xyz : torch.Tensor 226 | (B, npoint, 3) tensor of the new features' xyz 227 | new_features : torch.Tensor 228 | (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors 229 | inds: torch.Tensor 230 | (B, npoint) tensor of the inds 231 | """ 232 | 233 | xyz_flipped = xyz.transpose(1, 2).contiguous() 234 | if inds is None: 235 | inds = pointnet2_utils.furthest_point_sample(xyz, self.npoint) 236 | else: 237 | assert(inds.shape[1] == self.npoint) 238 | new_xyz = pointnet2_utils.gather_operation( 239 | xyz_flipped, inds 240 | ).transpose(1, 2).contiguous() if self.npoint is not None else None 241 | 242 | if not self.ret_unique_cnt: 243 | grouped_features, grouped_xyz = self.grouper( 244 | xyz, new_xyz, features 245 | ) # (B, C, npoint, nsample) 246 | else: 247 | grouped_features, grouped_xyz, unique_cnt = self.grouper( 248 | xyz, new_xyz, features 249 | ) # (B, C, npoint, nsample), (B,3,npoint,nsample), (B,npoint) 250 | 251 | new_features = self.mlp_module( 252 | grouped_features 253 | ) # (B, mlp[-1], npoint, nsample) 254 | if self.pooling == 'max': 255 | new_features = F.max_pool2d( 256 | new_features, kernel_size=[1, new_features.size(3)] 257 | ) # (B, mlp[-1], npoint, 1) 258 | elif self.pooling == 'avg': 259 | new_features = F.avg_pool2d( 260 | new_features, kernel_size=[1, new_features.size(3)] 261 | ) # (B, mlp[-1], npoint, 1) 262 | elif self.pooling == 'rbf': 263 | # Use radial basis function kernel for weighted sum of features (normalized by nsample and sigma) 264 | # Ref: https://en.wikipedia.org/wiki/Radial_basis_function_kernel 265 | rbf = torch.exp(-1 * grouped_xyz.pow(2).sum(1,keepdim=False) / (self.sigma**2) / 2) # (B, npoint, nsample) 266 | new_features = torch.sum(new_features * rbf.unsqueeze(1), -1, keepdim=True) / float(self.nsample) # (B, mlp[-1], npoint, 1) 267 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) 268 | 269 | if not self.ret_unique_cnt: 270 | return new_xyz, new_features, inds 271 | else: 272 | return new_xyz, new_features, inds, unique_cnt 273 | 274 | class PointnetSAModuleMSGVotes(nn.Module): 275 | ''' Modified based on _PointnetSAModuleBase and PointnetSAModuleMSG 276 | with extra support for returning point indices for getting their GT votes ''' 277 | 278 | def __init__( 279 | self, 280 | *, 281 | mlps: List[List[int]], 282 | npoint: int, 283 | radii: List[float], 284 | nsamples: List[int], 285 | bn: bool = True, 286 | use_xyz: bool = True, 287 | sample_uniformly: bool = False 288 | ): 289 | super().__init__() 290 | 291 | assert(len(mlps) == len(nsamples) == len(radii)) 292 | 293 | self.npoint = npoint 294 | self.groupers = nn.ModuleList() 295 | self.mlps = nn.ModuleList() 296 | for i in range(len(radii)): 297 | radius = radii[i] 298 | nsample = nsamples[i] 299 | self.groupers.append( 300 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz, sample_uniformly=sample_uniformly) 301 | if npoint is not None else pointnet2_utils.GroupAll(use_xyz) 302 | ) 303 | mlp_spec = mlps[i] 304 | if use_xyz: 305 | mlp_spec[0] += 3 306 | 307 | self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn)) 308 | 309 | def forward(self, xyz: torch.Tensor, 310 | features: torch.Tensor = 
None, inds: torch.Tensor = None) -> (torch.Tensor, torch.Tensor): 311 | r""" 312 | Parameters 313 | ---------- 314 | xyz : torch.Tensor 315 | (B, N, 3) tensor of the xyz coordinates of the features 316 | features : torch.Tensor 317 | (B, C, C) tensor of the descriptors of the the features 318 | inds : torch.Tensor 319 | (B, npoint) tensor that stores index to the xyz points (values in 0-N-1) 320 | 321 | Returns 322 | ------- 323 | new_xyz : torch.Tensor 324 | (B, npoint, 3) tensor of the new features' xyz 325 | new_features : torch.Tensor 326 | (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors 327 | inds: torch.Tensor 328 | (B, npoint) tensor of the inds 329 | """ 330 | new_features_list = [] 331 | 332 | xyz_flipped = xyz.transpose(1, 2).contiguous() 333 | if inds is None: 334 | inds = pointnet2_utils.furthest_point_sample(xyz, self.npoint) 335 | new_xyz = pointnet2_utils.gather_operation( 336 | xyz_flipped, inds 337 | ).transpose(1, 2).contiguous() if self.npoint is not None else None 338 | 339 | for i in range(len(self.groupers)): 340 | new_features = self.groupers[i]( 341 | xyz, new_xyz, features 342 | ) # (B, C, npoint, nsample) 343 | new_features = self.mlps[i]( 344 | new_features 345 | ) # (B, mlp[-1], npoint, nsample) 346 | new_features = F.max_pool2d( 347 | new_features, kernel_size=[1, new_features.size(3)] 348 | ) # (B, mlp[-1], npoint, 1) 349 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) 350 | 351 | new_features_list.append(new_features) 352 | 353 | return new_xyz, torch.cat(new_features_list, dim=1), inds 354 | 355 | 356 | class PointnetFPModule(nn.Module): 357 | r"""Propigates the features of one set to another 358 | 359 | Parameters 360 | ---------- 361 | mlp : list 362 | Pointnet module parameters 363 | bn : bool 364 | Use batchnorm 365 | """ 366 | 367 | def __init__(self, *, mlp: List[int], bn: bool = True): 368 | super().__init__() 369 | self.mlp = pt_utils.SharedMLP(mlp, bn=bn) 370 | 371 | def forward( 372 | self, unknown: torch.Tensor, known: torch.Tensor, 373 | unknow_feats: torch.Tensor, known_feats: torch.Tensor 374 | ) -> torch.Tensor: 375 | r""" 376 | Parameters 377 | ---------- 378 | unknown : torch.Tensor 379 | (B, n, 3) tensor of the xyz positions of the unknown features 380 | known : torch.Tensor 381 | (B, m, 3) tensor of the xyz positions of the known features 382 | unknow_feats : torch.Tensor 383 | (B, C1, n) tensor of the features to be propigated to 384 | known_feats : torch.Tensor 385 | (B, C2, m) tensor of features to be propigated 386 | 387 | Returns 388 | ------- 389 | new_features : torch.Tensor 390 | (B, mlp[-1], n) tensor of the features of the unknown features 391 | """ 392 | 393 | if known is not None: 394 | dist, idx = pointnet2_utils.three_nn(unknown, known) 395 | dist_recip = 1.0 / (dist + 1e-8) 396 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 397 | weight = dist_recip / norm 398 | 399 | interpolated_feats = pointnet2_utils.three_interpolate( 400 | known_feats, idx, weight 401 | ) 402 | else: 403 | interpolated_feats = known_feats.expand( 404 | *known_feats.size()[0:2], unknown.size(1) 405 | ) 406 | 407 | if unknow_feats is not None: 408 | new_features = torch.cat([interpolated_feats, unknow_feats], 409 | dim=1) #(B, C2 + C1, n) 410 | else: 411 | new_features = interpolated_feats 412 | 413 | new_features = new_features.unsqueeze(-1) 414 | new_features = self.mlp(new_features) 415 | 416 | return new_features.squeeze(-1) 417 | 418 | class PointnetLFPModuleMSG(nn.Module): 419 | ''' Modified based on 
_PointnetSAModuleBase and PointnetSAModuleMSG 420 | learnable feature propagation layer.''' 421 | 422 | def __init__( 423 | self, 424 | *, 425 | mlps: List[List[int]], 426 | radii: List[float], 427 | nsamples: List[int], 428 | post_mlp: List[int], 429 | bn: bool = True, 430 | use_xyz: bool = True, 431 | sample_uniformly: bool = False 432 | ): 433 | super().__init__() 434 | 435 | assert(len(mlps) == len(nsamples) == len(radii)) 436 | 437 | self.post_mlp = pt_utils.SharedMLP(post_mlp, bn=bn) 438 | 439 | self.groupers = nn.ModuleList() 440 | self.mlps = nn.ModuleList() 441 | for i in range(len(radii)): 442 | radius = radii[i] 443 | nsample = nsamples[i] 444 | self.groupers.append( 445 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz, 446 | sample_uniformly=sample_uniformly) 447 | ) 448 | mlp_spec = mlps[i] 449 | if use_xyz: 450 | mlp_spec[0] += 3 451 | 452 | self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn)) 453 | 454 | def forward(self, xyz2: torch.Tensor, xyz1: torch.Tensor, 455 | features2: torch.Tensor, features1: torch.Tensor) -> torch.Tensor: 456 | r""" Propagate features from xyz1 to xyz2. 457 | Parameters 458 | ---------- 459 | xyz2 : torch.Tensor 460 | (B, N2, 3) tensor of the xyz coordinates of the features 461 | xyz1 : torch.Tensor 462 | (B, N1, 3) tensor of the xyz coordinates of the features 463 | features2 : torch.Tensor 464 | (B, C2, N2) tensor of the descriptors of the the features 465 | features1 : torch.Tensor 466 | (B, C1, N1) tensor of the descriptors of the the features 467 | 468 | Returns 469 | ------- 470 | new_features1 : torch.Tensor 471 | (B, \sum_k(mlps[k][-1]), N1) tensor of the new_features descriptors 472 | """ 473 | new_features_list = [] 474 | 475 | for i in range(len(self.groupers)): 476 | new_features = self.groupers[i]( 477 | xyz1, xyz2, features1 478 | ) # (B, C1, N2, nsample) 479 | new_features = self.mlps[i]( 480 | new_features 481 | ) # (B, mlp[-1], N2, nsample) 482 | new_features = F.max_pool2d( 483 | new_features, kernel_size=[1, new_features.size(3)] 484 | ) # (B, mlp[-1], N2, 1) 485 | new_features = new_features.squeeze(-1) # (B, mlp[-1], N2) 486 | 487 | if features2 is not None: 488 | new_features = torch.cat([new_features, features2], 489 | dim=1) #(B, mlp[-1] + C2, N2) 490 | 491 | new_features = new_features.unsqueeze(-1) 492 | new_features = self.post_mlp(new_features) 493 | 494 | new_features_list.append(new_features) 495 | 496 | return torch.cat(new_features_list, dim=1).squeeze(-1) 497 | 498 | 499 | if __name__ == "__main__": 500 | from torch.autograd import Variable 501 | torch.manual_seed(1) 502 | torch.cuda.manual_seed_all(1) 503 | xyz = Variable(torch.randn(2, 9, 3).cuda(), requires_grad=True) 504 | xyz_feats = Variable(torch.randn(2, 9, 6).cuda(), requires_grad=True) 505 | 506 | test_module = PointnetSAModuleMSG( 507 | npoint=2, radii=[5.0, 10.0], nsamples=[6, 3], mlps=[[9, 3], [9, 6]] 508 | ) 509 | test_module.cuda() 510 | print(test_module(xyz, xyz_feats)) 511 | 512 | for _ in range(1): 513 | _, new_features = test_module(xyz, xyz_feats) 514 | new_features.backward( 515 | torch.cuda.FloatTensor(*new_features.size()).fill_(1) 516 | ) 517 | print(new_features) 518 | print(xyz.grad) 519 | -------------------------------------------------------------------------------- /pointnet2/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
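# Minimal usage sketch of the CUDA ops wrapped in this module (illustrative only; it
# assumes the pointnet2._ext extension has been built and a GPU is available, and the
# shapes follow the docstrings below):
#
#   import torch
#   xyz = torch.rand(2, 1024, 3).cuda()                                  # (B, N, 3)
#   idx = furthest_point_sample(xyz, 128)                                # (B, 128) int32
#   new_xyz = gather_operation(xyz.transpose(1, 2).contiguous(), idx)    # (B, 3, 128)
#   dist, nn_idx = three_nn(xyz, new_xyz.transpose(1, 2).contiguous())   # (B, N, 3) each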
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | ''' Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch ''' 7 | from __future__ import ( 8 | division, 9 | absolute_import, 10 | with_statement, 11 | print_function, 12 | unicode_literals, 13 | ) 14 | import torch 15 | from torch.autograd import Function 16 | import torch.nn as nn 17 | import pytorch_utils as pt_utils 18 | import sys 19 | 20 | try: 21 | import builtins 22 | except: 23 | import __builtin__ as builtins 24 | 25 | try: 26 | import pointnet2._ext as _ext 27 | except ImportError: 28 | if not getattr(builtins, "__POINTNET2_SETUP__", False): 29 | raise ImportError( 30 | "Could not import _ext module.\n" 31 | "Please see the setup instructions in the README: " 32 | "https://github.com/erikwijmans/Pointnet2_PyTorch/blob/master/README.rst" 33 | ) 34 | 35 | if False: 36 | # Workaround for type hints without depending on the `typing` module 37 | from typing import * 38 | 39 | 40 | class RandomDropout(nn.Module): 41 | def __init__(self, p=0.5, inplace=False): 42 | super(RandomDropout, self).__init__() 43 | self.p = p 44 | self.inplace = inplace 45 | 46 | def forward(self, X): 47 | theta = torch.Tensor(1).uniform_(0, self.p)[0] 48 | return pt_utils.feature_dropout_no_scaling(X, theta, self.train, self.inplace) 49 | 50 | 51 | class FurthestPointSampling(Function): 52 | @staticmethod 53 | def forward(ctx, xyz, npoint): 54 | # type: (Any, torch.Tensor, int) -> torch.Tensor 55 | r""" 56 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 57 | minimum distance 58 | 59 | Parameters 60 | ---------- 61 | xyz : torch.Tensor 62 | (B, N, 3) tensor where N > npoint 63 | npoint : int32 64 | number of features in the sampled set 65 | 66 | Returns 67 | ------- 68 | torch.Tensor 69 | (B, npoint) tensor containing the set 70 | """ 71 | return _ext.furthest_point_sampling(xyz, npoint) 72 | 73 | @staticmethod 74 | def backward(xyz, a=None): 75 | return None, None 76 | 77 | 78 | furthest_point_sample = FurthestPointSampling.apply 79 | 80 | 81 | class GatherOperation(Function): 82 | @staticmethod 83 | def forward(ctx, features, idx): 84 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 85 | r""" 86 | 87 | Parameters 88 | ---------- 89 | features : torch.Tensor 90 | (B, C, N) tensor 91 | 92 | idx : torch.Tensor 93 | (B, npoint) tensor of the features to gather 94 | 95 | Returns 96 | ------- 97 | torch.Tensor 98 | (B, C, npoint) tensor 99 | """ 100 | 101 | _, C, N = features.size() 102 | 103 | ctx.for_backwards = (idx, C, N) 104 | 105 | return _ext.gather_points(features, idx) 106 | 107 | @staticmethod 108 | def backward(ctx, grad_out): 109 | idx, C, N = ctx.for_backwards 110 | 111 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N) 112 | return grad_features, None 113 | 114 | 115 | gather_operation = GatherOperation.apply 116 | 117 | 118 | class ThreeNN(Function): 119 | @staticmethod 120 | def forward(ctx, unknown, known): 121 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 122 | r""" 123 | Find the three nearest neighbors of unknown in known 124 | Parameters 125 | ---------- 126 | unknown : torch.Tensor 127 | (B, n, 3) tensor of known features 128 | known : torch.Tensor 129 | (B, m, 3) tensor of unknown features 130 | 131 | Returns 132 | ------- 133 | dist : torch.Tensor 134 | (B, n, 3) l2 distance to the three nearest neighbors 135 | idx : 
torch.Tensor 136 | (B, n, 3) index of 3 nearest neighbors 137 | """ 138 | dist2, idx = _ext.three_nn(unknown, known) 139 | 140 | return torch.sqrt(dist2), idx 141 | 142 | @staticmethod 143 | def backward(ctx, a=None, b=None): 144 | return None, None 145 | 146 | 147 | three_nn = ThreeNN.apply 148 | 149 | 150 | class ThreeInterpolate(Function): 151 | @staticmethod 152 | def forward(ctx, features, idx, weight): 153 | # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor 154 | r""" 155 | Performs weight linear interpolation on 3 features 156 | Parameters 157 | ---------- 158 | features : torch.Tensor 159 | (B, c, m) Features descriptors to be interpolated from 160 | idx : torch.Tensor 161 | (B, n, 3) three nearest neighbors of the target features in features 162 | weight : torch.Tensor 163 | (B, n, 3) weights 164 | 165 | Returns 166 | ------- 167 | torch.Tensor 168 | (B, c, n) tensor of the interpolated features 169 | """ 170 | B, c, m = features.size() 171 | n = idx.size(1) 172 | 173 | ctx.three_interpolate_for_backward = (idx, weight, m) 174 | 175 | return _ext.three_interpolate(features, idx, weight) 176 | 177 | @staticmethod 178 | def backward(ctx, grad_out): 179 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 180 | r""" 181 | Parameters 182 | ---------- 183 | grad_out : torch.Tensor 184 | (B, c, n) tensor with gradients of ouputs 185 | 186 | Returns 187 | ------- 188 | grad_features : torch.Tensor 189 | (B, c, m) tensor with gradients of features 190 | 191 | None 192 | 193 | None 194 | """ 195 | idx, weight, m = ctx.three_interpolate_for_backward 196 | 197 | grad_features = _ext.three_interpolate_grad( 198 | grad_out.contiguous(), idx, weight, m 199 | ) 200 | 201 | return grad_features, None, None 202 | 203 | 204 | three_interpolate = ThreeInterpolate.apply 205 | 206 | 207 | class GroupingOperation(Function): 208 | @staticmethod 209 | def forward(ctx, features, idx): 210 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 211 | r""" 212 | 213 | Parameters 214 | ---------- 215 | features : torch.Tensor 216 | (B, C, N) tensor of features to group 217 | idx : torch.Tensor 218 | (B, npoint, nsample) tensor containing the indicies of features to group with 219 | 220 | Returns 221 | ------- 222 | torch.Tensor 223 | (B, C, npoint, nsample) tensor 224 | """ 225 | B, nfeatures, nsample = idx.size() 226 | _, C, N = features.size() 227 | 228 | ctx.for_backwards = (idx, N) 229 | 230 | return _ext.group_points(features, idx) 231 | 232 | @staticmethod 233 | def backward(ctx, grad_out): 234 | # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor] 235 | r""" 236 | 237 | Parameters 238 | ---------- 239 | grad_out : torch.Tensor 240 | (B, C, npoint, nsample) tensor of the gradients of the output from forward 241 | 242 | Returns 243 | ------- 244 | torch.Tensor 245 | (B, C, N) gradient of the features 246 | None 247 | """ 248 | idx, N = ctx.for_backwards 249 | 250 | grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N) 251 | 252 | return grad_features, None 253 | 254 | 255 | grouping_operation = GroupingOperation.apply 256 | 257 | 258 | class BallQuery(Function): 259 | @staticmethod 260 | def forward(ctx, radius, nsample, xyz, new_xyz): 261 | # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor 262 | r""" 263 | 264 | Parameters 265 | ---------- 266 | radius : float 267 | radius of the balls 268 | nsample : int 269 | maximum number of features in the balls 270 | xyz : torch.Tensor 271 | (B, N, 3) xyz 
coordinates of the features 272 | new_xyz : torch.Tensor 273 | (B, npoint, 3) centers of the ball query 274 | 275 | Returns 276 | ------- 277 | torch.Tensor 278 | (B, npoint, nsample) tensor with the indicies of the features that form the query balls 279 | """ 280 | return _ext.ball_query(new_xyz, xyz, radius, nsample) 281 | 282 | @staticmethod 283 | def backward(ctx, a=None): 284 | return None, None, None, None 285 | 286 | 287 | ball_query = BallQuery.apply 288 | 289 | 290 | class QueryAndGroup(nn.Module): 291 | r""" 292 | Groups with a ball query of radius 293 | 294 | Parameters 295 | --------- 296 | radius : float32 297 | Radius of ball 298 | nsample : int32 299 | Maximum number of features to gather in the ball 300 | """ 301 | 302 | def __init__(self, radius, nsample, use_xyz=True, ret_grouped_xyz=False, normalize_xyz=False, sample_uniformly=False, ret_unique_cnt=False): 303 | # type: (QueryAndGroup, float, int, bool) -> None 304 | super(QueryAndGroup, self).__init__() 305 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 306 | self.ret_grouped_xyz = ret_grouped_xyz 307 | self.normalize_xyz = normalize_xyz 308 | self.sample_uniformly = sample_uniformly 309 | self.ret_unique_cnt = ret_unique_cnt 310 | if self.ret_unique_cnt: 311 | assert(self.sample_uniformly) 312 | 313 | def forward(self, xyz, new_xyz, features=None): 314 | # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] 315 | r""" 316 | Parameters 317 | ---------- 318 | xyz : torch.Tensor 319 | xyz coordinates of the features (B, N, 3) 320 | new_xyz : torch.Tensor 321 | centriods (B, npoint, 3) 322 | features : torch.Tensor 323 | Descriptors of the features (B, C, N) 324 | 325 | Returns 326 | ------- 327 | new_features : torch.Tensor 328 | (B, 3 + C, npoint, nsample) tensor 329 | """ 330 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 331 | 332 | if self.sample_uniformly: 333 | unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) 334 | for i_batch in range(idx.shape[0]): 335 | for i_region in range(idx.shape[1]): 336 | unique_ind = torch.unique(idx[i_batch, i_region, :]) 337 | num_unique = unique_ind.shape[0] 338 | unique_cnt[i_batch, i_region] = num_unique 339 | sample_ind = torch.randint(0, num_unique, (self.nsample - num_unique,), dtype=torch.long) 340 | all_ind = torch.cat((unique_ind, unique_ind[sample_ind])) 341 | idx[i_batch, i_region, :] = all_ind 342 | 343 | 344 | xyz_trans = xyz.transpose(1, 2).contiguous() 345 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 346 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 347 | if self.normalize_xyz: 348 | grouped_xyz /= self.radius 349 | 350 | if features is not None: 351 | grouped_features = grouping_operation(features, idx) 352 | if self.use_xyz: 353 | new_features = torch.cat( 354 | [grouped_xyz, grouped_features], dim=1 355 | ) # (B, C + 3, npoint, nsample) 356 | else: 357 | new_features = grouped_features 358 | else: 359 | assert ( 360 | self.use_xyz 361 | ), "Cannot have not features and not use xyz as a feature!" 
362 | new_features = grouped_xyz 363 | 364 | ret = [new_features] 365 | if self.ret_grouped_xyz: 366 | ret.append(grouped_xyz) 367 | if self.ret_unique_cnt: 368 | ret.append(unique_cnt) 369 | if len(ret) == 1: 370 | return ret[0] 371 | else: 372 | return tuple(ret) 373 | 374 | 375 | class GroupAll(nn.Module): 376 | r""" 377 | Groups all features 378 | 379 | Parameters 380 | --------- 381 | """ 382 | 383 | def __init__(self, use_xyz=True, ret_grouped_xyz=False): 384 | # type: (GroupAll, bool) -> None 385 | super(GroupAll, self).__init__() 386 | self.use_xyz = use_xyz 387 | 388 | def forward(self, xyz, new_xyz, features=None): 389 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] 390 | r""" 391 | Parameters 392 | ---------- 393 | xyz : torch.Tensor 394 | xyz coordinates of the features (B, N, 3) 395 | new_xyz : torch.Tensor 396 | Ignored 397 | features : torch.Tensor 398 | Descriptors of the features (B, C, N) 399 | 400 | Returns 401 | ------- 402 | new_features : torch.Tensor 403 | (B, C + 3, 1, N) tensor 404 | """ 405 | 406 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 407 | if features is not None: 408 | grouped_features = features.unsqueeze(2) 409 | if self.use_xyz: 410 | new_features = torch.cat( 411 | [grouped_xyz, grouped_features], dim=1 412 | ) # (B, 3 + C, 1, N) 413 | else: 414 | new_features = grouped_features 415 | else: 416 | new_features = grouped_xyz 417 | 418 | if self.ret_grouped_xyz: 419 | return new_features, grouped_xyz 420 | else: 421 | return new_features 422 | 423 | 424 | class CylinderQuery(Function): 425 | @staticmethod 426 | def forward(ctx, radius, hmin, hmax, nsample, xyz, new_xyz, rot): 427 | # type: (Any, float, float, float, int, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor 428 | r""" 429 | 430 | Parameters 431 | ---------- 432 | radius : float 433 | radius of the cylinders 434 | hmin, hmax : float 435 | endpoints of cylinder height in x-rotation axis 436 | nsample : int 437 | maximum number of features in the cylinders 438 | xyz : torch.Tensor 439 | (B, N, 3) xyz coordinates of the features 440 | new_xyz : torch.Tensor 441 | (B, npoint, 3) centers of the cylinder query 442 | rot: torch.Tensor 443 | (B, npoint, 9) flatten rotation matrices from 444 | cylinder frame to world frame 445 | 446 | Returns 447 | ------- 448 | torch.Tensor 449 | (B, npoint, nsample) tensor with the indicies of the features that form the query balls 450 | """ 451 | return _ext.cylinder_query(new_xyz, xyz, rot, radius, hmin, hmax, nsample) 452 | 453 | @staticmethod 454 | def backward(ctx, a=None): 455 | return None, None, None, None, None, None, None 456 | 457 | 458 | cylinder_query = CylinderQuery.apply 459 | 460 | 461 | class CylinderQueryAndGroup(nn.Module): 462 | r""" 463 | Groups with a cylinder query of radius and height 464 | 465 | Parameters 466 | --------- 467 | radius : float32 468 | Radius of cylinder 469 | hmin, hmax: float32 470 | endpoints of cylinder height in x-rotation axis 471 | nsample : int32 472 | Maximum number of features to gather in the ball 473 | """ 474 | 475 | def __init__(self, radius, hmin, hmax, nsample, use_xyz=True, ret_grouped_xyz=False, normalize_xyz=False, rotate_xyz=True, sample_uniformly=False, ret_unique_cnt=False): 476 | super(CylinderQueryAndGroup, self).__init__() 477 | self.radius, self.nsample, self.hmin, self.hmax, = radius, nsample, hmin, hmax 478 | self.use_xyz = use_xyz 479 | self.ret_grouped_xyz = ret_grouped_xyz 480 | self.normalize_xyz = normalize_xyz 481 | self.rotate_xyz = rotate_xyz 
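# Descriptive note: when rotate_xyz is True, forward() multiplies the grouped,
# center-relative coordinates by the per-center rotation matrices `rot`
# (the torch.matmul(grouped_xyz_, rot) step below), so each local neighborhood is
# expressed in its cylinder/grasp frame rather than in the world frame.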
482 | self.sample_uniformly = sample_uniformly 483 | self.ret_unique_cnt = ret_unique_cnt 484 | if self.ret_unique_cnt: 485 | assert(self.sample_uniformly) 486 | 487 | def forward(self, xyz, new_xyz, rot, features=None): 488 | r""" 489 | Parameters 490 | ---------- 491 | xyz : torch.Tensor 492 | xyz coordinates of the features (B, N, 3) 493 | new_xyz : torch.Tensor 494 | centriods (B, npoint, 3) 495 | rot : torch.Tensor 496 | rotation matrices (B, npoint, 3, 3) 497 | features : torch.Tensor 498 | Descriptors of the features (B, C, N) 499 | 500 | Returns 501 | ------- 502 | new_features : torch.Tensor 503 | (B, 3 + C, npoint, nsample) tensor 504 | """ 505 | B, npoint, _ = new_xyz.size() 506 | idx = cylinder_query(self.radius, self.hmin, self.hmax, self.nsample, xyz, new_xyz, rot.view(B, npoint, 9)) 507 | 508 | if self.sample_uniformly: 509 | unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) 510 | for i_batch in range(idx.shape[0]): 511 | for i_region in range(idx.shape[1]): 512 | unique_ind = torch.unique(idx[i_batch, i_region, :]) 513 | num_unique = unique_ind.shape[0] 514 | unique_cnt[i_batch, i_region] = num_unique 515 | sample_ind = torch.randint(0, num_unique, (self.nsample - num_unique,), dtype=torch.long) 516 | all_ind = torch.cat((unique_ind, unique_ind[sample_ind])) 517 | idx[i_batch, i_region, :] = all_ind 518 | 519 | 520 | xyz_trans = xyz.transpose(1, 2).contiguous() 521 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 522 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 523 | if self.normalize_xyz: 524 | grouped_xyz /= self.radius 525 | if self.rotate_xyz: 526 | grouped_xyz_ = grouped_xyz.permute(0, 2, 3, 1).contiguous() # (B, npoint, nsample, 3) 527 | grouped_xyz_ = torch.matmul(grouped_xyz_, rot) 528 | grouped_xyz = grouped_xyz_.permute(0, 3, 1, 2).contiguous() 529 | 530 | 531 | if features is not None: 532 | grouped_features = grouping_operation(features, idx) 533 | if self.use_xyz: 534 | new_features = torch.cat( 535 | [grouped_xyz, grouped_features], dim=1 536 | ) # (B, C + 3, npoint, nsample) 537 | else: 538 | new_features = grouped_features 539 | else: 540 | assert ( 541 | self.use_xyz 542 | ), "Cannot have not features and not use xyz as a feature!" 543 | new_features = grouped_xyz 544 | 545 | ret = [new_features] 546 | if self.ret_grouped_xyz: 547 | ret.append(grouped_xyz) 548 | if self.ret_unique_cnt: 549 | ret.append(unique_cnt) 550 | if len(ret) == 1: 551 | return ret[0] 552 | else: 553 | return tuple(ret) -------------------------------------------------------------------------------- /pointnet2/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
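# Minimal usage sketch (illustrative only, based on the classes defined below):
#
#   import torch
#   mlp = SharedMLP([3, 64, 128], bn=True)   # stack of 1x1 Conv2d -> BatchNorm2d -> ReLU
#   x = torch.rand(2, 3, 512, 32)            # (B, C_in, npoint, nsample)
#   out = mlp(x)                             # (2, 128, 512, 32)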
5 | 6 | ''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch ''' 7 | import torch 8 | import torch.nn as nn 9 | from typing import List, Tuple 10 | 11 | class SharedMLP(nn.Sequential): 12 | 13 | def __init__( 14 | self, 15 | args: List[int], 16 | *, 17 | bn: bool = False, 18 | activation=nn.ReLU(inplace=True), 19 | preact: bool = False, 20 | first: bool = False, 21 | name: str = "" 22 | ): 23 | super().__init__() 24 | 25 | for i in range(len(args) - 1): 26 | self.add_module( 27 | name + 'layer{}'.format(i), 28 | Conv2d( 29 | args[i], 30 | args[i + 1], 31 | bn=(not first or not preact or (i != 0)) and bn, 32 | activation=activation 33 | if (not first or not preact or (i != 0)) else None, 34 | preact=preact 35 | ) 36 | ) 37 | 38 | 39 | class _BNBase(nn.Sequential): 40 | 41 | def __init__(self, in_size, batch_norm=None, name=""): 42 | super().__init__() 43 | self.add_module(name + "bn", batch_norm(in_size)) 44 | 45 | nn.init.constant_(self[0].weight, 1.0) 46 | nn.init.constant_(self[0].bias, 0) 47 | 48 | 49 | class BatchNorm1d(_BNBase): 50 | 51 | def __init__(self, in_size: int, *, name: str = ""): 52 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 53 | 54 | 55 | class BatchNorm2d(_BNBase): 56 | 57 | def __init__(self, in_size: int, name: str = ""): 58 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 59 | 60 | 61 | class BatchNorm3d(_BNBase): 62 | 63 | def __init__(self, in_size: int, name: str = ""): 64 | super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name) 65 | 66 | 67 | class _ConvBase(nn.Sequential): 68 | 69 | def __init__( 70 | self, 71 | in_size, 72 | out_size, 73 | kernel_size, 74 | stride, 75 | padding, 76 | activation, 77 | bn, 78 | init, 79 | conv=None, 80 | batch_norm=None, 81 | bias=True, 82 | preact=False, 83 | name="" 84 | ): 85 | super().__init__() 86 | 87 | bias = bias and (not bn) 88 | conv_unit = conv( 89 | in_size, 90 | out_size, 91 | kernel_size=kernel_size, 92 | stride=stride, 93 | padding=padding, 94 | bias=bias 95 | ) 96 | init(conv_unit.weight) 97 | if bias: 98 | nn.init.constant_(conv_unit.bias, 0) 99 | 100 | if bn: 101 | if not preact: 102 | bn_unit = batch_norm(out_size) 103 | else: 104 | bn_unit = batch_norm(in_size) 105 | 106 | if preact: 107 | if bn: 108 | self.add_module(name + 'bn', bn_unit) 109 | 110 | if activation is not None: 111 | self.add_module(name + 'activation', activation) 112 | 113 | self.add_module(name + 'conv', conv_unit) 114 | 115 | if not preact: 116 | if bn: 117 | self.add_module(name + 'bn', bn_unit) 118 | 119 | if activation is not None: 120 | self.add_module(name + 'activation', activation) 121 | 122 | 123 | class Conv1d(_ConvBase): 124 | 125 | def __init__( 126 | self, 127 | in_size: int, 128 | out_size: int, 129 | *, 130 | kernel_size: int = 1, 131 | stride: int = 1, 132 | padding: int = 0, 133 | activation=nn.ReLU(inplace=True), 134 | bn: bool = False, 135 | init=nn.init.kaiming_normal_, 136 | bias: bool = True, 137 | preact: bool = False, 138 | name: str = "" 139 | ): 140 | super().__init__( 141 | in_size, 142 | out_size, 143 | kernel_size, 144 | stride, 145 | padding, 146 | activation, 147 | bn, 148 | init, 149 | conv=nn.Conv1d, 150 | batch_norm=BatchNorm1d, 151 | bias=bias, 152 | preact=preact, 153 | name=name 154 | ) 155 | 156 | 157 | class Conv2d(_ConvBase): 158 | 159 | def __init__( 160 | self, 161 | in_size: int, 162 | out_size: int, 163 | *, 164 | kernel_size: Tuple[int, int] = (1, 1), 165 | stride: Tuple[int, int] = (1, 1), 166 | padding: Tuple[int, int] = (0, 
0), 167 | activation=nn.ReLU(inplace=True), 168 | bn: bool = False, 169 | init=nn.init.kaiming_normal_, 170 | bias: bool = True, 171 | preact: bool = False, 172 | name: str = "" 173 | ): 174 | super().__init__( 175 | in_size, 176 | out_size, 177 | kernel_size, 178 | stride, 179 | padding, 180 | activation, 181 | bn, 182 | init, 183 | conv=nn.Conv2d, 184 | batch_norm=BatchNorm2d, 185 | bias=bias, 186 | preact=preact, 187 | name=name 188 | ) 189 | 190 | 191 | class Conv3d(_ConvBase): 192 | 193 | def __init__( 194 | self, 195 | in_size: int, 196 | out_size: int, 197 | *, 198 | kernel_size: Tuple[int, int, int] = (1, 1, 1), 199 | stride: Tuple[int, int, int] = (1, 1, 1), 200 | padding: Tuple[int, int, int] = (0, 0, 0), 201 | activation=nn.ReLU(inplace=True), 202 | bn: bool = False, 203 | init=nn.init.kaiming_normal_, 204 | bias: bool = True, 205 | preact: bool = False, 206 | name: str = "" 207 | ): 208 | super().__init__( 209 | in_size, 210 | out_size, 211 | kernel_size, 212 | stride, 213 | padding, 214 | activation, 215 | bn, 216 | init, 217 | conv=nn.Conv3d, 218 | batch_norm=BatchNorm3d, 219 | bias=bias, 220 | preact=preact, 221 | name=name 222 | ) 223 | 224 | 225 | class FC(nn.Sequential): 226 | 227 | def __init__( 228 | self, 229 | in_size: int, 230 | out_size: int, 231 | *, 232 | activation=nn.ReLU(inplace=True), 233 | bn: bool = False, 234 | init=None, 235 | preact: bool = False, 236 | name: str = "" 237 | ): 238 | super().__init__() 239 | 240 | fc = nn.Linear(in_size, out_size, bias=not bn) 241 | if init is not None: 242 | init(fc.weight) 243 | if not bn: 244 | nn.init.constant_(fc.bias, 0) 245 | 246 | if preact: 247 | if bn: 248 | self.add_module(name + 'bn', BatchNorm1d(in_size)) 249 | 250 | if activation is not None: 251 | self.add_module(name + 'activation', activation) 252 | 253 | self.add_module(name + 'fc', fc) 254 | 255 | if not preact: 256 | if bn: 257 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 258 | 259 | if activation is not None: 260 | self.add_module(name + 'activation', activation) 261 | 262 | def set_bn_momentum_default(bn_momentum): 263 | 264 | def fn(m): 265 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 266 | m.momentum = bn_momentum 267 | 268 | return fn 269 | 270 | 271 | class BNMomentumScheduler(object): 272 | 273 | def __init__( 274 | self, model, bn_lambda, last_epoch=-1, 275 | setter=set_bn_momentum_default 276 | ): 277 | if not isinstance(model, nn.Module): 278 | raise RuntimeError( 279 | "Class '{}' is not a PyTorch nn Module".format( 280 | type(model).__name__ 281 | ) 282 | ) 283 | 284 | self.model = model 285 | self.setter = setter 286 | self.lmbd = bn_lambda 287 | 288 | self.step(last_epoch + 1) 289 | self.last_epoch = last_epoch 290 | 291 | def step(self, epoch=None): 292 | if epoch is None: 293 | epoch = self.last_epoch + 1 294 | 295 | self.last_epoch = epoch 296 | self.model.apply(self.setter(self.lmbd(epoch))) 297 | 298 | 299 | -------------------------------------------------------------------------------- /pointnet2/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
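# Build note (a sketch, assuming the standard torch CUDAExtension workflow rather than
# anything stated in this repository): the extension is typically compiled with
# `pip install .` or `python setup.py install` run from the pointnet2/ directory, and it
# requires an NVCC/CUDA toolkit compatible with the installed PyTorch, since every op
# dispatches to CUDA kernels (the CPU branches raise "CPU not supported").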
5 | 6 | from setuptools import setup 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | import glob 9 | import os 10 | ROOT = os.path.dirname(os.path.abspath(__file__)) 11 | 12 | _ext_src_root = "_ext_src" 13 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 14 | "{}/src/*.cu".format(_ext_src_root) 15 | ) 16 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) 17 | 18 | setup( 19 | name='pointnet2', 20 | ext_modules=[ 21 | CUDAExtension( 22 | name='pointnet2._ext', 23 | sources=_ext_sources, 24 | extra_compile_args={ 25 | "cxx": ["-O2", "-I{}".format("{}/{}/include".format(ROOT, _ext_src_root))], 26 | "nvcc": ["-O2", "-I{}".format("{}/{}/include".format(ROOT, _ext_src_root))], 27 | }, 28 | ) 29 | ], 30 | cmdclass={ 31 | 'build_ext': BuildExtension 32 | } 33 | ) 34 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.8 2 | tensorboard==2.3 3 | numpy 4 | scipy 5 | open3d>=0.8 6 | Pillow 7 | tqdm 8 | MinkowskiEngine==0.5.4 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import argparse 5 | import time 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from graspnetAPI.graspnet_eval import GraspGroup, GraspNetEval 9 | 10 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 11 | sys.path.append(os.path.join(ROOT_DIR, 'pointnet2')) 12 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 13 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 14 | sys.path.append(os.path.join(ROOT_DIR, 'dataset')) 15 | 16 | from models.graspnet import GraspNet, pred_decode 17 | from dataset.graspnet_dataset import GraspNetDataset, minkowski_collate_fn 18 | from utils.collision_detector import ModelFreeCollisionDetector 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--dataset_root', default=None, required=True) 22 | parser.add_argument('--checkpoint_path', help='Model checkpoint path', default=None, required=True) 23 | parser.add_argument('--dump_dir', help='Dump dir to save outputs', default=None, required=True) 24 | parser.add_argument('--seed_feat_dim', default=512, type=int, help='Point wise feature dim') 25 | parser.add_argument('--camera', default='kinect', help='Camera split [realsense/kinect]') 26 | parser.add_argument('--num_point', type=int, default=15000, help='Point Number [default: 15000]') 27 | parser.add_argument('--batch_size', type=int, default=1, help='Batch Size during inference [default: 1]') 28 | parser.add_argument('--voxel_size', type=float, default=0.005, help='Voxel Size for sparse convolution') 29 | parser.add_argument('--collision_thresh', type=float, default=0.01, 30 | help='Collision Threshold in collision detection [default: 0.01]') 31 | parser.add_argument('--voxel_size_cd', type=float, default=0.01, help='Voxel Size for collision detection') 32 | parser.add_argument('--infer', action='store_true', default=False) 33 | parser.add_argument('--eval', action='store_true', default=False) 34 | cfgs = parser.parse_args() 35 | 36 | # ------------------------------------------------------------------------- GLOBAL CONFIG BEG 37 | if not os.path.exists(cfgs.dump_dir): 38 | os.mkdir(cfgs.dump_dir) 39 | 40 | 41 | # Init datasets and dataloaders 42 | def 
my_worker_init_fn(worker_id): 43 | np.random.seed(np.random.get_state()[1][0] + worker_id) 44 | pass 45 | 46 | 47 | def inference(): 48 | test_dataset = GraspNetDataset(cfgs.dataset_root, split='test_seen', camera=cfgs.camera, num_points=cfgs.num_point, 49 | voxel_size=cfgs.voxel_size, remove_outlier=True, augment=False, load_label=False) 50 | print('Test dataset length: ', len(test_dataset)) 51 | scene_list = test_dataset.scene_list() 52 | test_dataloader = DataLoader(test_dataset, batch_size=cfgs.batch_size, shuffle=False, 53 | num_workers=0, worker_init_fn=my_worker_init_fn, collate_fn=minkowski_collate_fn) 54 | print('Test dataloader length: ', len(test_dataloader)) 55 | # Init the model 56 | net = GraspNet(seed_feat_dim=cfgs.seed_feat_dim, is_training=False) 57 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 58 | net.to(device) 59 | # Load checkpoint 60 | checkpoint = torch.load(cfgs.checkpoint_path) 61 | net.load_state_dict(checkpoint['model_state_dict']) 62 | start_epoch = checkpoint['epoch'] 63 | print("-> loaded checkpoint %s (epoch: %d)" % (cfgs.checkpoint_path, start_epoch)) 64 | 65 | batch_interval = 100 66 | net.eval() 67 | tic = time.time() 68 | for batch_idx, batch_data in enumerate(test_dataloader): 69 | for key in batch_data: 70 | if 'list' in key: 71 | for i in range(len(batch_data[key])): 72 | for j in range(len(batch_data[key][i])): 73 | batch_data[key][i][j] = batch_data[key][i][j].to(device) 74 | else: 75 | batch_data[key] = batch_data[key].to(device) 76 | 77 | # Forward pass 78 | with torch.no_grad(): 79 | end_points = net(batch_data) 80 | grasp_preds = pred_decode(end_points) 81 | 82 | # Dump results for evaluation 83 | for i in range(cfgs.batch_size): 84 | data_idx = batch_idx * cfgs.batch_size + i 85 | preds = grasp_preds[i].detach().cpu().numpy() 86 | 87 | gg = GraspGroup(preds) 88 | # collision detection 89 | if cfgs.collision_thresh > 0: 90 | cloud = test_dataset.get_data(data_idx, return_raw_cloud=True) 91 | mfcdetector = ModelFreeCollisionDetector(cloud, voxel_size=cfgs.voxel_size_cd) 92 | collision_mask = mfcdetector.detect(gg, approach_dist=0.05, collision_thresh=cfgs.collision_thresh) 93 | gg = gg[~collision_mask] 94 | 95 | # save grasps 96 | save_dir = os.path.join(cfgs.dump_dir, scene_list[data_idx], cfgs.camera) 97 | save_path = os.path.join(save_dir, str(data_idx % 256).zfill(4) + '.npy') 98 | if not os.path.exists(save_dir): 99 | os.makedirs(save_dir) 100 | gg.save_npy(save_path) 101 | 102 | if (batch_idx + 1) % batch_interval == 0: 103 | toc = time.time() 104 | print('Eval batch: %d, time: %fs' % (batch_idx + 1, (toc - tic) / batch_interval)) 105 | tic = time.time() 106 | 107 | 108 | def evaluate(dump_dir): 109 | ge = GraspNetEval(root=cfgs.dataset_root, camera=cfgs.camera, split='test_seen') 110 | res, ap = ge.eval_seen(dump_folder=dump_dir, proc=6) 111 | save_dir = os.path.join(cfgs.dump_dir, 'ap_{}.npy'.format(cfgs.camera)) 112 | np.save(save_dir, res) 113 | 114 | 115 | if __name__ == '__main__': 116 | if cfgs.infer: 117 | inference() 118 | if cfgs.eval: 119 | evaluate(cfgs.dump_dir) 120 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | from datetime import datetime 5 | import argparse 6 | 7 | import torch 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | 
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
13 | sys.path.append(os.path.join(ROOT_DIR, 'pointnet2'))
14 | sys.path.append(os.path.join(ROOT_DIR, 'utils'))
15 | sys.path.append(os.path.join(ROOT_DIR, 'models'))
16 | sys.path.append(os.path.join(ROOT_DIR, 'dataset'))
17 | 
18 | from models.graspnet import GraspNet
19 | from models.loss import get_loss
20 | from dataset.graspnet_dataset import GraspNetDataset, minkowski_collate_fn, load_grasp_labels
21 | 
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--dataset_root', default=None, required=True)
24 | parser.add_argument('--camera', default='kinect', help='Camera split [realsense/kinect]')
25 | parser.add_argument('--checkpoint_path', help='Model checkpoint path', default=None)
26 | parser.add_argument('--model_name', type=str, default=None)
27 | parser.add_argument('--log_dir', default='logs/log')
28 | parser.add_argument('--num_point', type=int, default=15000, help='Point Number [default: 15000]')
29 | parser.add_argument('--seed_feat_dim', default=512, type=int, help='Point wise feature dim')
30 | parser.add_argument('--voxel_size', type=float, default=0.005, help='Voxel Size to process point clouds')
31 | parser.add_argument('--max_epoch', type=int, default=10, help='Epoch to run [default: 10]')
32 | parser.add_argument('--batch_size', type=int, default=4, help='Batch Size during training [default: 4]')
33 | parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]')
34 | parser.add_argument('--resume', action='store_true', default=False, help='Whether to resume from checkpoint')
35 | cfgs = parser.parse_args()
36 | # ------------------------------------------------------------------------- GLOBAL CONFIG BEG
37 | EPOCH_CNT = 0
38 | CHECKPOINT_PATH = cfgs.checkpoint_path if cfgs.checkpoint_path is not None and cfgs.resume else None
39 | if not os.path.exists(cfgs.log_dir):
40 |     os.makedirs(cfgs.log_dir)
41 | 
42 | LOG_FOUT = open(os.path.join(cfgs.log_dir, 'log_train.txt'), 'a')
43 | LOG_FOUT.write(str(cfgs) + '\n')
44 | 
45 | 
46 | def log_string(out_str):
47 |     LOG_FOUT.write(out_str + '\n')
48 |     LOG_FOUT.flush()
49 |     print(out_str)
50 | 
51 | 
52 | # Init datasets and dataloaders
53 | def my_worker_init_fn(worker_id):
54 |     np.random.seed(np.random.get_state()[1][0] + worker_id)
55 |     pass
56 | 
57 | 
58 | grasp_labels = load_grasp_labels(cfgs.dataset_root)
59 | TRAIN_DATASET = GraspNetDataset(cfgs.dataset_root, grasp_labels=grasp_labels, camera=cfgs.camera, split='train',
60 |                                 num_points=cfgs.num_point, voxel_size=cfgs.voxel_size,
61 |                                 remove_outlier=True, augment=True, load_label=True)
62 | print('train dataset length: ', len(TRAIN_DATASET))
63 | TRAIN_DATALOADER = DataLoader(TRAIN_DATASET, batch_size=cfgs.batch_size, shuffle=True,
64 |                               num_workers=0, worker_init_fn=my_worker_init_fn, collate_fn=minkowski_collate_fn)
65 | print('train dataloader length: ', len(TRAIN_DATALOADER))
66 | 
67 | net = GraspNet(seed_feat_dim=cfgs.seed_feat_dim, is_training=True)
68 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
69 | net.to(device)
70 | # Load the Adam optimizer
71 | optimizer = optim.Adam(net.parameters(), lr=cfgs.learning_rate)
72 | start_epoch = 0
73 | if CHECKPOINT_PATH is not None and os.path.isfile(CHECKPOINT_PATH):
74 |     checkpoint = torch.load(CHECKPOINT_PATH)
75 |     net.load_state_dict(checkpoint['model_state_dict'])
76 |     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
77 |     start_epoch = checkpoint['epoch']
78 |     log_string("-> 
loaded checkpoint %s (epoch: %d)" % (CHECKPOINT_PATH, start_epoch)) 79 | # TensorBoard Visualizers 80 | TRAIN_WRITER = SummaryWriter(os.path.join(cfgs.log_dir, 'train')) 81 | 82 | 83 | def get_current_lr(epoch): 84 | lr = cfgs.learning_rate 85 | lr = lr * (0.95 ** epoch) 86 | return lr 87 | 88 | 89 | def adjust_learning_rate(optimizer, epoch): 90 | lr = get_current_lr(epoch) 91 | for param_group in optimizer.param_groups: 92 | param_group['lr'] = lr 93 | 94 | 95 | def train_one_epoch(): 96 | stat_dict = {} # collect statistics 97 | adjust_learning_rate(optimizer, EPOCH_CNT) 98 | net.train() 99 | batch_interval = 20 100 | for batch_idx, batch_data_label in enumerate(TRAIN_DATALOADER): 101 | for key in batch_data_label: 102 | if 'list' in key: 103 | for i in range(len(batch_data_label[key])): 104 | for j in range(len(batch_data_label[key][i])): 105 | batch_data_label[key][i][j] = batch_data_label[key][i][j].to(device) 106 | else: 107 | batch_data_label[key] = batch_data_label[key].to(device) 108 | 109 | end_points = net(batch_data_label) 110 | loss, end_points = get_loss(end_points) 111 | loss.backward() 112 | optimizer.step() 113 | optimizer.zero_grad() 114 | 115 | for key in end_points: 116 | if 'loss' in key or 'acc' in key or 'prec' in key or 'recall' in key or 'count' in key: 117 | if key not in stat_dict: 118 | stat_dict[key] = 0 119 | stat_dict[key] += end_points[key].item() 120 | 121 | if (batch_idx + 1) % batch_interval == 0: 122 | log_string(' ----epoch: %03d ---- batch: %03d ----' % (EPOCH_CNT, batch_idx + 1)) 123 | for key in sorted(stat_dict.keys()): 124 | TRAIN_WRITER.add_scalar(key, stat_dict[key] / batch_interval, 125 | (EPOCH_CNT * len(TRAIN_DATALOADER) + batch_idx) * cfgs.batch_size) 126 | log_string('mean %s: %f' % (key, stat_dict[key] / batch_interval)) 127 | stat_dict[key] = 0 128 | 129 | 130 | def train(start_epoch): 131 | global EPOCH_CNT 132 | for epoch in range(start_epoch, cfgs.max_epoch): 133 | EPOCH_CNT = epoch 134 | log_string('**** EPOCH %03d ****' % epoch) 135 | log_string('Current learning rate: %f' % (get_current_lr(epoch))) 136 | log_string(str(datetime.now())) 137 | # Reset numpy seed. 138 | # REF: https://github.com/pytorch/pytorch/issues/5059 139 | np.random.seed() 140 | train_one_epoch() 141 | 142 | save_dict = {'epoch': epoch + 1, 'optimizer_state_dict': optimizer.state_dict(), 143 | 'model_state_dict': net.state_dict()} 144 | torch.save(save_dict, os.path.join(cfgs.log_dir, cfgs.model_name + '_epoch' + str(epoch + 1).zfill(2) + '.tar')) 145 | 146 | 147 | if __name__ == '__main__': 148 | train(start_epoch) 149 | -------------------------------------------------------------------------------- /utils/collision_detector.py: -------------------------------------------------------------------------------- 1 | """ Collision detection to remove collided grasp pose predictions. 2 | Author: chenxi-wang 3 | """ 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | import open3d as o3d 9 | 10 | class ModelFreeCollisionDetector(): 11 | """ Collision detection in scenes without object labels. Current finger width and length are fixed. 
12 | 
13 |     Input:
14 |         scene_points: [numpy.ndarray, (N,3), numpy.float32]
15 |             the scene points to detect
16 |         voxel_size: [float]
17 |             voxel size used for downsampling the scene points
18 | 
19 |     Example usage:
20 |         mfcdetector = ModelFreeCollisionDetector(scene_points, voxel_size=0.005)
21 |         collision_mask = mfcdetector.detect(grasp_group, approach_dist=0.03)
22 |         collision_mask, iou_list = mfcdetector.detect(grasp_group, approach_dist=0.03, collision_thresh=0.05, return_ious=True)
23 |         collision_mask, empty_mask = mfcdetector.detect(grasp_group, approach_dist=0.03, collision_thresh=0.05,
24 |                             return_empty_grasp=True, empty_thresh=0.01)
25 |         collision_mask, empty_mask, iou_list = mfcdetector.detect(grasp_group, approach_dist=0.03, collision_thresh=0.05,
26 |                             return_empty_grasp=True, empty_thresh=0.01, return_ious=True)
27 |     """
28 |     def __init__(self, scene_points, voxel_size=0.005):
29 |         self.finger_width = 0.01
30 |         self.finger_length = 0.06
31 |         self.voxel_size = voxel_size
32 |         scene_cloud = o3d.geometry.PointCloud()
33 |         scene_cloud.points = o3d.utility.Vector3dVector(scene_points)
34 |         scene_cloud = scene_cloud.voxel_down_sample(voxel_size)
35 |         self.scene_points = np.array(scene_cloud.points)
36 | 
37 |     def detect(self, grasp_group, approach_dist=0.03, collision_thresh=0.05, return_empty_grasp=False, empty_thresh=0.01, return_ious=False):
38 |         """ Detect collision of grasps.
39 | 
40 |         Input:
41 |             grasp_group: [GraspGroup, M grasps]
42 |                 the grasps to check
43 |             approach_dist: [float]
44 |                 the distance for the gripper to move along the approaching direction before grasping;
45 |                 this shifting space must be collision-free as well
46 |             collision_thresh: [float]
47 |                 if global collision iou is greater than this threshold,
48 |                 a collision is detected
49 |             return_empty_grasp: [bool]
50 |                 if True, also return a mask indicating whether each grasp is empty (no object points between the fingers)
51 |             empty_thresh: [float]
52 |                 if inner space iou is smaller than this threshold,
53 |                 the grasp is regarded as empty
54 |                 only used when [return_empty_grasp] is True
55 |             return_ious: [bool]
56 |                 if True, return global collision iou and part collision ious
57 | 
58 |         Output:
59 |             collision_mask: [numpy.ndarray, (M,), numpy.bool]
60 |                 True implies collision
61 |             [optional] empty_mask: [numpy.ndarray, (M,), numpy.bool]
62 |                 True implies empty grasp
63 |                 only returned when [return_empty_grasp] is True
64 |             [optional] iou_list: list of [numpy.ndarray, (M,), numpy.float32]
65 |                 global and part collision ious, containing
66 |                 [global_iou, left_iou, right_iou, bottom_iou, shifting_iou]
67 |                 only returned when [return_ious] is True
68 |         """
69 |         approach_dist = max(approach_dist, self.finger_width)
70 |         T = grasp_group.translations
71 |         R = grasp_group.rotation_matrices
72 |         heights = grasp_group.heights[:,np.newaxis]
73 |         depths = grasp_group.depths[:,np.newaxis]
74 |         widths = grasp_group.widths[:,np.newaxis]
75 |         targets = self.scene_points[np.newaxis,:,:] - T[:,np.newaxis,:]
76 |         targets = np.matmul(targets, R)
77 | 
78 |         ## collision detection
79 |         # height mask
80 |         mask1 = ((targets[:,:,2] > -heights/2) & (targets[:,:,2] < heights/2))
81 |         # left finger mask
82 |         mask2 = ((targets[:,:,0] > depths - self.finger_length) & (targets[:,:,0] < depths))
83 |         mask3 = (targets[:,:,1] > -(widths/2 + self.finger_width))
84 |         mask4 = (targets[:,:,1] < -widths/2)
85 |         # right finger mask
86 |         mask5 = (targets[:,:,1] < (widths/2 + self.finger_width))
87 |         mask6 = (targets[:,:,1] > widths/2)
88 |         # bottom mask
89 |         mask7 = ((targets[:,:,0] <= depths - self.finger_length)\
90 |             & (targets[:,:,0] > depths - self.finger_length - self.finger_width))
91 |         # shifting mask
92 |         mask8 = ((targets[:,:,0] <= depths - self.finger_length - self.finger_width)\
93 |             & (targets[:,:,0] > depths - self.finger_length - self.finger_width - approach_dist))
94 | 
95 |         # get collision mask of each point
96 |         left_mask = (mask1 & mask2 & mask3 & mask4)
97 |         right_mask = (mask1 & mask2 & mask5 & mask6)
98 |         bottom_mask = (mask1 & mask3 & mask5 & mask7)
99 |         shifting_mask = (mask1 & mask3 & mask5 & mask8)
100 |         global_mask = (left_mask | right_mask | bottom_mask | shifting_mask)
101 | 
102 |         # calculate equivalent volume of each part
103 |         left_right_volume = (heights * self.finger_length * self.finger_width / (self.voxel_size**3)).reshape(-1)
104 |         bottom_volume = (heights * (widths+2*self.finger_width) * self.finger_width / (self.voxel_size**3)).reshape(-1)
105 |         shifting_volume = (heights * (widths+2*self.finger_width) * approach_dist / (self.voxel_size**3)).reshape(-1)
106 |         volume = left_right_volume*2 + bottom_volume + shifting_volume
107 | 
108 |         # get collision iou of each part
109 |         global_iou = global_mask.sum(axis=1) / (volume+1e-6)
110 | 
111 |         # get collision mask
112 |         collision_mask = (global_iou > collision_thresh)
113 | 
114 |         if not (return_empty_grasp or return_ious):
115 |             return collision_mask
116 | 
117 |         ret_value = [collision_mask,]
118 |         if return_empty_grasp:
119 |             inner_mask = (mask1 & mask2 & (~mask4) & (~mask6))
120 |             inner_volume = (heights * self.finger_length * widths / (self.voxel_size**3)).reshape(-1)
121 |             empty_mask = (inner_mask.sum(axis=-1)/inner_volume < empty_thresh)
122 |             ret_value.append(empty_mask)
123 |         if return_ious:
124 |             left_iou = left_mask.sum(axis=1) / (left_right_volume+1e-6)
125 |             right_iou = right_mask.sum(axis=1) / (left_right_volume+1e-6)
126 |             bottom_iou = bottom_mask.sum(axis=1) / (bottom_volume+1e-6)
127 |             shifting_iou = shifting_mask.sum(axis=1) / (shifting_volume+1e-6)
128 |             ret_value.append([global_iou, left_iou, right_iou, bottom_iou, shifting_iou])
129 |         return ret_value
130 | 
--------------------------------------------------------------------------------
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | """ Tools for data processing.
2 |     Author: chenxi-wang
3 | """
4 | 
5 | import numpy as np
6 | 
7 | 
8 | class CameraInfo():
9 |     """ Camera intrinsics for point cloud creation. """
10 | 
11 |     def __init__(self, width, height, fx, fy, cx, cy, scale):
12 |         self.width = width
13 |         self.height = height
14 |         self.fx = fx
15 |         self.fy = fy
16 |         self.cx = cx
17 |         self.cy = cy
18 |         self.scale = scale
19 | 
20 | 
21 | def create_point_cloud_from_depth_image(depth, camera, organized=True):
22 |     """ Generate point cloud using depth image only.
23 | 24 | Input: 25 | depth: [numpy.ndarray, (H,W), numpy.float32] 26 | depth image 27 | camera: [CameraInfo] 28 | camera intrinsics 29 | organized: bool 30 | whether to keep the cloud in image shape (H,W,3) 31 | 32 | Output: 33 | cloud: [numpy.ndarray, (H,W,3)/(H*W,3), numpy.float32] 34 | generated cloud, (H,W,3) for organized=True, (H*W,3) for organized=False 35 | """ 36 | assert (depth.shape[0] == camera.height and depth.shape[1] == camera.width) 37 | xmap = np.arange(camera.width) 38 | ymap = np.arange(camera.height) 39 | xmap, ymap = np.meshgrid(xmap, ymap) 40 | points_z = depth / camera.scale 41 | points_x = (xmap - camera.cx) * points_z / camera.fx 42 | points_y = (ymap - camera.cy) * points_z / camera.fy 43 | cloud = np.stack([points_x, points_y, points_z], axis=-1) 44 | if not organized: 45 | cloud = cloud.reshape([-1, 3]) 46 | return cloud 47 | 48 | 49 | def transform_point_cloud(cloud, transform, format='4x4'): 50 | """ Transform points to new coordinates with transformation matrix. 51 | 52 | Input: 53 | cloud: [np.ndarray, (N,3), np.float32] 54 | points in original coordinates 55 | transform: [np.ndarray, (3,3)/(3,4)/(4,4), np.float32] 56 | transformation matrix, could be rotation only or rotation+translation 57 | format: [string, '3x3'/'3x4'/'4x4'] 58 | the shape of transformation matrix 59 | '3x3' --> rotation matrix 60 | '3x4'/'4x4' --> rotation matrix + translation matrix 61 | 62 | Output: 63 | cloud_transformed: [np.ndarray, (N,3), np.float32] 64 | points in new coordinates 65 | """ 66 | if not (format == '3x3' or format == '4x4' or format == '3x4'): 67 | raise ValueError('Unknown transformation format, only support \'3x3\' or \'4x4\' or \'3x4\'.') 68 | if format == '3x3': 69 | cloud_transformed = np.dot(transform, cloud.T).T 70 | elif format == '4x4' or format == '3x4': 71 | ones = np.ones(cloud.shape[0])[:, np.newaxis] 72 | cloud_ = np.concatenate([cloud, ones], axis=1) 73 | cloud_transformed = np.dot(transform, cloud_.T).T 74 | cloud_transformed = cloud_transformed[:, :3] 75 | return cloud_transformed 76 | 77 | 78 | def compute_point_dists(A, B): 79 | """ Compute pair-wise point distances in two matrices. 80 | 81 | Input: 82 | A: [np.ndarray, (N,3), np.float32] 83 | point cloud A 84 | B: [np.ndarray, (M,3), np.float32] 85 | point cloud B 86 | 87 | Output: 88 | dists: [np.ndarray, (N,M), np.float32] 89 | distance matrix 90 | """ 91 | A = A[:, np.newaxis, :] 92 | B = B[np.newaxis, :, :] 93 | dists = np.linalg.norm(A - B, axis=-1) 94 | return dists 95 | 96 | 97 | def remove_invisible_grasp_points(cloud, grasp_points, pose, th=0.01): 98 | """ Remove invisible part of object model according to scene point cloud. 
99 | 
100 |     Input:
101 |         cloud: [np.ndarray, (N,3), np.float32]
102 |             scene point cloud
103 |         grasp_points: [np.ndarray, (M,3), np.float32]
104 |             grasp point label in object coordinates
105 |         pose: [np.ndarray, (4,4), np.float32]
106 |             transformation matrix from object coordinates to world coordinates
107 |         th: [float]
108 |             if the minimum distance between a grasp point and the scene points is greater than th, the point will be removed
109 | 
110 |     Output:
111 |         visible_mask: [np.ndarray, (M,), np.bool]
112 |             mask to show the visible part of grasp points
113 |     """
114 |     grasp_points_trans = transform_point_cloud(grasp_points, pose)
115 |     dists = compute_point_dists(grasp_points_trans, cloud)
116 |     min_dists = dists.min(axis=1)
117 |     visible_mask = (min_dists < th)
118 |     return visible_mask
119 | 
120 | 
121 | def get_workspace_mask(cloud, seg, trans=None, organized=True, outlier=0):
122 |     """ Keep points in workspace as input.
123 | 
124 |     Input:
125 |         cloud: [np.ndarray, (H,W,3), np.float32]
126 |             scene point cloud
127 |         seg: [np.ndarray, (H,W,), np.uint8]
128 |             segmentation label of scene points
129 |         trans: [np.ndarray, (4,4), np.float32]
130 |             transformation matrix for scene points, default: None.
131 |         organized: [bool]
132 |             whether to keep the cloud in image shape (H,W,3)
133 |         outlier: [float]
134 |             if the distance between a point and workspace is greater than outlier, the point will be removed
135 | 
136 |     Output:
137 |         workspace_mask: [np.ndarray, (H,W)/(H*W,), np.bool]
138 |             mask to indicate whether scene points are in workspace
139 |     """
140 |     if organized:
141 |         h, w, _ = cloud.shape
142 |         cloud = cloud.reshape([h * w, 3])
143 |         seg = seg.reshape(h * w)
144 |     if trans is not None:
145 |         cloud = transform_point_cloud(cloud, trans)
146 |     foreground = cloud[seg > 0]
147 |     xmin, ymin, zmin = foreground.min(axis=0)
148 |     xmax, ymax, zmax = foreground.max(axis=0)
149 |     mask_x = ((cloud[:, 0] > xmin - outlier) & (cloud[:, 0] < xmax + outlier))
150 |     mask_y = ((cloud[:, 1] > ymin - outlier) & (cloud[:, 1] < ymax + outlier))
151 |     mask_z = ((cloud[:, 2] > zmin - outlier) & (cloud[:, 2] < zmax + outlier))
152 |     workspace_mask = (mask_x & mask_y & mask_z)
153 |     if organized:
154 |         workspace_mask = workspace_mask.reshape([h, w])
155 | 
156 |     return workspace_mask
157 | 
--------------------------------------------------------------------------------
/utils/label_generation.py:
--------------------------------------------------------------------------------
1 | """ Dynamically generate grasp labels during training.
2 |     Author: chenxi-wang
3 | """
4 | 
5 | import os
6 | import sys
7 | import torch
8 | 
9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
10 | ROOT_DIR = os.path.dirname(BASE_DIR)
11 | sys.path.append(ROOT_DIR)
12 | # sys.path.append(os.path.join(ROOT_DIR, 'knn'))
13 | 
14 | from knn.knn_modules import knn
15 | from loss_utils import GRASP_MAX_WIDTH, batch_viewpoint_params_to_matrix, \
16 |     transform_point_cloud, generate_grasp_views
17 | 
18 | 
19 | def process_grasp_labels(end_points):
20 |     """ Process labels according to scene points and object poses.
""" 21 | seed_xyzs = end_points['xyz_graspable'] # (B, M_point, 3) 22 | batch_size, num_samples, _ = seed_xyzs.size() 23 | 24 | batch_grasp_points = [] 25 | batch_grasp_views_rot = [] 26 | batch_grasp_scores = [] 27 | batch_grasp_widths = [] 28 | for i in range(batch_size): 29 | seed_xyz = seed_xyzs[i] # (Ns, 3) 30 | poses = end_points['object_poses_list'][i] # [(3, 4),] 31 | 32 | # get merged grasp points for label computation 33 | grasp_points_merged = [] 34 | grasp_views_rot_merged = [] 35 | grasp_scores_merged = [] 36 | grasp_widths_merged = [] 37 | for obj_idx, pose in enumerate(poses): 38 | grasp_points = end_points['grasp_points_list'][i][obj_idx] # (Np, 3) 39 | grasp_scores = end_points['grasp_scores_list'][i][obj_idx] # (Np, V, A, D) 40 | grasp_widths = end_points['grasp_widths_list'][i][obj_idx] # (Np, V, A, D) 41 | _, V, A, D = grasp_scores.size() 42 | num_grasp_points = grasp_points.size(0) 43 | # generate and transform template grasp views 44 | grasp_views = generate_grasp_views(V).to(pose.device) # (V, 3) 45 | grasp_points_trans = transform_point_cloud(grasp_points, pose, '3x4') 46 | grasp_views_trans = transform_point_cloud(grasp_views, pose[:3, :3], '3x3') 47 | # generate and transform template grasp view rotation 48 | angles = torch.zeros(grasp_views.size(0), dtype=grasp_views.dtype, device=grasp_views.device) 49 | grasp_views_rot = batch_viewpoint_params_to_matrix(-grasp_views, angles) # (V, 3, 3) 50 | grasp_views_rot_trans = torch.matmul(pose[:3, :3], grasp_views_rot) # (V, 3, 3) 51 | 52 | # assign views 53 | grasp_views_ = grasp_views.transpose(0, 1).contiguous().unsqueeze(0) 54 | grasp_views_trans_ = grasp_views_trans.transpose(0, 1).contiguous().unsqueeze(0) 55 | view_inds = knn(grasp_views_trans_, grasp_views_, k=1).squeeze() - 1 56 | grasp_views_rot_trans = torch.index_select(grasp_views_rot_trans, 0, view_inds) # (V, 3, 3) 57 | grasp_views_rot_trans = grasp_views_rot_trans.unsqueeze(0).expand(num_grasp_points, -1, -1, 58 | -1) # (Np, V, 3, 3) 59 | grasp_scores = torch.index_select(grasp_scores, 1, view_inds) # (Np, V, A, D) 60 | grasp_widths = torch.index_select(grasp_widths, 1, view_inds) # (Np, V, A, D) 61 | # add to list 62 | grasp_points_merged.append(grasp_points_trans) 63 | grasp_views_rot_merged.append(grasp_views_rot_trans) 64 | grasp_scores_merged.append(grasp_scores) 65 | grasp_widths_merged.append(grasp_widths) 66 | 67 | grasp_points_merged = torch.cat(grasp_points_merged, dim=0) # (Np', 3) 68 | grasp_views_rot_merged = torch.cat(grasp_views_rot_merged, dim=0) # (Np', V, 3, 3) 69 | grasp_scores_merged = torch.cat(grasp_scores_merged, dim=0) # (Np', V, A, D) 70 | grasp_widths_merged = torch.cat(grasp_widths_merged, dim=0) # (Np', V, A, D) 71 | 72 | # compute nearest neighbors 73 | seed_xyz_ = seed_xyz.transpose(0, 1).contiguous().unsqueeze(0) # (1, 3, Ns) 74 | grasp_points_merged_ = grasp_points_merged.transpose(0, 1).contiguous().unsqueeze(0) # (1, 3, Np') 75 | nn_inds = knn(grasp_points_merged_, seed_xyz_, k=1).squeeze() - 1 # (Ns) 76 | 77 | # assign anchor points to real points 78 | grasp_points_merged = torch.index_select(grasp_points_merged, 0, nn_inds) # (Ns, 3) 79 | grasp_views_rot_merged = torch.index_select(grasp_views_rot_merged, 0, nn_inds) # (Ns, V, 3, 3) 80 | grasp_scores_merged = torch.index_select(grasp_scores_merged, 0, nn_inds) # (Ns, V, A, D) 81 | grasp_widths_merged = torch.index_select(grasp_widths_merged, 0, nn_inds) # (Ns, V, A, D) 82 | 83 | # add to batch 84 | batch_grasp_points.append(grasp_points_merged) 85 | 
batch_grasp_views_rot.append(grasp_views_rot_merged) 86 | batch_grasp_scores.append(grasp_scores_merged) 87 | batch_grasp_widths.append(grasp_widths_merged) 88 | 89 | batch_grasp_points = torch.stack(batch_grasp_points, 0) # (B, Ns, 3) 90 | batch_grasp_views_rot = torch.stack(batch_grasp_views_rot, 0) # (B, Ns, V, 3, 3) 91 | batch_grasp_scores = torch.stack(batch_grasp_scores, 0) # (B, Ns, V, A, D) 92 | batch_grasp_widths = torch.stack(batch_grasp_widths, 0) # (B, Ns, V, A, D) 93 | 94 | # compute view graspness 95 | view_u_threshold = 0.6 96 | view_grasp_num = 48 97 | batch_grasp_view_valid_mask = (batch_grasp_scores <= view_u_threshold) & (batch_grasp_scores > 0) # (B, Ns, V, A, D) 98 | batch_grasp_view_valid = batch_grasp_view_valid_mask.float() 99 | batch_grasp_view_graspness = torch.sum(torch.sum(batch_grasp_view_valid, dim=-1), dim=-1) / view_grasp_num # (B, Ns, V) 100 | view_graspness_min, _ = torch.min(batch_grasp_view_graspness, dim=-1) # (B, Ns) 101 | view_graspness_max, _ = torch.max(batch_grasp_view_graspness, dim=-1) 102 | view_graspness_max = view_graspness_max.unsqueeze(-1).expand(-1, -1, 300) # (B, Ns, V) 103 | view_graspness_min = view_graspness_min.unsqueeze(-1).expand(-1, -1, 300) # same shape as batch_grasp_view_graspness 104 | batch_grasp_view_graspness = (batch_grasp_view_graspness - view_graspness_min) / (view_graspness_max - view_graspness_min + 1e-5) 105 | 106 | # process scores 107 | label_mask = (batch_grasp_scores > 0) & (batch_grasp_widths <= GRASP_MAX_WIDTH) # (B, Ns, V, A, D) 108 | batch_grasp_scores[~label_mask] = 0 109 | 110 | end_points['batch_grasp_point'] = batch_grasp_points 111 | end_points['batch_grasp_view_rot'] = batch_grasp_views_rot 112 | end_points['batch_grasp_score'] = batch_grasp_scores 113 | end_points['batch_grasp_width'] = batch_grasp_widths 114 | end_points['batch_grasp_view_graspness'] = batch_grasp_view_graspness 115 | 116 | return end_points 117 | 118 | 119 | def match_grasp_view_and_label(end_points): 120 | """ Slice grasp labels according to predicted views. 
""" 121 | top_view_inds = end_points['grasp_top_view_inds'] # (B, Ns) 122 | template_views_rot = end_points['batch_grasp_view_rot'] # (B, Ns, V, 3, 3) 123 | grasp_scores = end_points['batch_grasp_score'] # (B, Ns, V, A, D) 124 | grasp_widths = end_points['batch_grasp_width'] # (B, Ns, V, A, D, 3) 125 | 126 | B, Ns, V, A, D = grasp_scores.size() 127 | top_view_inds_ = top_view_inds.view(B, Ns, 1, 1, 1).expand(-1, -1, -1, 3, 3) 128 | top_template_views_rot = torch.gather(template_views_rot, 2, top_view_inds_).squeeze(2) 129 | top_view_inds_ = top_view_inds.view(B, Ns, 1, 1, 1).expand(-1, -1, -1, A, D) 130 | top_view_grasp_scores = torch.gather(grasp_scores, 2, top_view_inds_).squeeze(2) 131 | top_view_grasp_widths = torch.gather(grasp_widths, 2, top_view_inds_).squeeze(2) 132 | 133 | u_max = top_view_grasp_scores.max() 134 | po_mask = top_view_grasp_scores > 0 135 | po_mask_num = torch.sum(po_mask) 136 | if po_mask_num > 0: 137 | u_min = top_view_grasp_scores[po_mask].min() 138 | top_view_grasp_scores[po_mask] = torch.log(u_max / top_view_grasp_scores[po_mask]) / (torch.log(u_max / u_min) + 1e-6) 139 | 140 | end_points['batch_grasp_score'] = top_view_grasp_scores # (B, Ns, A, D) 141 | end_points['batch_grasp_width'] = top_view_grasp_widths # (B, Ns, A, D) 142 | 143 | return top_template_views_rot, end_points 144 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | """ Tools for loss computation. 2 | Author: chenxi-wang 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | 8 | GRASP_MAX_WIDTH = 0.1 9 | GRASPNESS_THRESHOLD = 0.1 10 | NUM_VIEW = 300 11 | NUM_ANGLE = 12 12 | NUM_DEPTH = 4 13 | M_POINT = 1024 14 | 15 | 16 | def transform_point_cloud(cloud, transform, format='4x4'): 17 | """ Transform points to new coordinates with transformation matrix. 18 | 19 | Input: 20 | cloud: [torch.FloatTensor, (N,3)] 21 | points in original coordinates 22 | transform: [torch.FloatTensor, (3,3)/(3,4)/(4,4)] 23 | transformation matrix, could be rotation only or rotation+translation 24 | format: [string, '3x3'/'3x4'/'4x4'] 25 | the shape of transformation matrix 26 | '3x3' --> rotation matrix 27 | '3x4'/'4x4' --> rotation matrix + translation matrix 28 | 29 | Output: 30 | cloud_transformed: [torch.FloatTensor, (N,3)] 31 | points in new coordinates 32 | """ 33 | if not (format == '3x3' or format == '4x4' or format == '3x4'): 34 | raise ValueError('Unknown transformation format, only support \'3x3\' or \'4x4\' or \'3x4\'.') 35 | if format == '3x3': 36 | cloud_transformed = torch.matmul(transform, cloud.T).T 37 | elif format == '4x4' or format == '3x4': 38 | ones = cloud.new_ones(cloud.size(0), device=cloud.device).unsqueeze(-1) 39 | cloud_ = torch.cat([cloud, ones], dim=1) 40 | cloud_transformed = torch.matmul(transform, cloud_.T).T 41 | cloud_transformed = cloud_transformed[:, :3] 42 | return cloud_transformed 43 | 44 | 45 | def generate_grasp_views(N=300, phi=(np.sqrt(5) - 1) / 2, center=np.zeros(3), r=1): 46 | """ View sampling on a unit sphere using Fibonacci lattices. 
47 | Ref: https://arxiv.org/abs/0912.4540 48 | 49 | Input: 50 | N: [int] 51 | number of sampled views 52 | phi: [float] 53 | constant for view coordinate calculation, different phi's bring different distributions, default: (sqrt(5)-1)/2 54 | center: [np.ndarray, (3,), np.float32] 55 | sphere center 56 | r: [float] 57 | sphere radius 58 | 59 | Output: 60 | views: [torch.FloatTensor, (N,3)] 61 | sampled view coordinates 62 | """ 63 | views = [] 64 | for i in range(N): 65 | zi = (2 * i + 1) / N - 1 66 | xi = np.sqrt(1 - zi ** 2) * np.cos(2 * i * np.pi * phi) 67 | yi = np.sqrt(1 - zi ** 2) * np.sin(2 * i * np.pi * phi) 68 | views.append([xi, yi, zi]) 69 | views = r * np.array(views) + center 70 | return torch.from_numpy(views.astype(np.float32)) 71 | 72 | 73 | def batch_viewpoint_params_to_matrix(batch_towards, batch_angle): 74 | """ Transform approach vectors and in-plane rotation angles to rotation matrices. 75 | 76 | Input: 77 | batch_towards: [torch.FloatTensor, (N,3)] 78 | approach vectors in batch 79 | batch_angle: [torch.floatTensor, (N,)] 80 | in-plane rotation angles in batch 81 | 82 | Output: 83 | batch_matrix: [torch.floatTensor, (N,3,3)] 84 | rotation matrices in batch 85 | """ 86 | axis_x = batch_towards 87 | ones = torch.ones(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device) 88 | zeros = torch.zeros(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device) 89 | axis_y = torch.stack([-axis_x[:, 1], axis_x[:, 0], zeros], dim=-1) 90 | mask_y = (torch.norm(axis_y, dim=-1) == 0) 91 | axis_y[mask_y, 1] = 1 92 | axis_x = axis_x / torch.norm(axis_x, dim=-1, keepdim=True) 93 | axis_y = axis_y / torch.norm(axis_y, dim=-1, keepdim=True) 94 | axis_z = torch.cross(axis_x, axis_y) 95 | sin = torch.sin(batch_angle) 96 | cos = torch.cos(batch_angle) 97 | R1 = torch.stack([ones, zeros, zeros, zeros, cos, -sin, zeros, sin, cos], dim=-1) 98 | R1 = R1.reshape([-1, 3, 3]) 99 | R2 = torch.stack([axis_x, axis_y, axis_z], dim=-1) 100 | batch_matrix = torch.matmul(R2, R1) 101 | return batch_matrix 102 | 103 | 104 | def huber_loss(error, delta=1.0): 105 | """ 106 | Args: 107 | error: Torch tensor (d1,d2,...,dk) 108 | Returns: 109 | loss: Torch tensor (d1,d2,...,dk) 110 | 111 | x = error = pred - gt or dist(pred,gt) 112 | 0.5 * |x|^2 if |x|<=d 113 | 0.5 * d^2 + d * (|x|-d) if |x|>d 114 | Author: Charles R. Qi 115 | Ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py 116 | """ 117 | abs_error = torch.abs(error) 118 | quadratic = torch.clamp(abs_error, max=delta) 119 | linear = (abs_error - quadratic) 120 | loss = 0.5 * quadratic ** 2 + delta * linear 121 | return loss 122 | --------------------------------------------------------------------------------
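Usage sketch for utils/loss_utils.py: a minimal example, not taken from the repository itself, showing how the template views sampled by generate_grasp_views are turned into gripper rotation matrices, mirroring the calls made inside process_grasp_labels in utils/label_generation.py. The sys.path handling is an assumption (it presumes the repository root is the working directory); adjust the import path to your setup.

    import sys
    import torch

    sys.path.append('utils')  # assumption: run from the repository root so loss_utils.py is importable
    from loss_utils import NUM_VIEW, generate_grasp_views, batch_viewpoint_params_to_matrix

    # Sample NUM_VIEW (=300) approach directions on the unit sphere via a Fibonacci lattice.
    views = generate_grasp_views(NUM_VIEW)  # torch.FloatTensor, (300, 3)

    # Zero in-plane rotation per view; as in label_generation.py, the negated view direction
    # is used as the approach vector passed to batch_viewpoint_params_to_matrix.
    angles = torch.zeros(views.size(0), dtype=views.dtype)
    rotations = batch_viewpoint_params_to_matrix(-views, angles)  # (300, 3, 3) rotation matrices

    print(views.shape, rotations.shape)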