├── doc ├── teaser.png └── example_data │ ├── color.png │ ├── depth.png │ ├── meta.mat │ ├── demo_result.png │ └── workspace_mask.png ├── requirements.txt ├── command_demo.sh ├── sync_to_docker.sh ├── dataset ├── command_generate_tolerance_label.sh ├── generate_tolerance_label.py ├── graspnet_dataset.py └── graspnet_dataset_lazy.py ├── knn ├── src │ ├── vision.cpp │ ├── cpu │ │ ├── vision.h │ │ └── knn_cpu.cpp │ ├── cuda │ │ ├── vision.h │ │ └── knn.cu │ └── knn.h ├── knn_modules.py └── setup.py ├── .gitignore ├── pointnet2 ├── _ext_src │ ├── include │ │ ├── cylinder_query.h │ │ ├── ball_query.h │ │ ├── group_points.h │ │ ├── sampling.h │ │ ├── interpolate.h │ │ ├── utils.h │ │ └── cuda_utils.h │ └── src │ │ ├── bindings.cpp │ │ ├── ball_query.cpp │ │ ├── cylinder_query.cpp │ │ ├── ball_query_gpu.cu │ │ ├── group_points.cpp │ │ ├── cylinder_query_gpu.cu │ │ ├── group_points_gpu.cu │ │ ├── sampling.cpp │ │ ├── interpolate.cpp │ │ ├── interpolate_gpu.cu │ │ └── sampling_gpu.cu ├── setup.py ├── pytorch_utils.py └── pointnet2_utils.py ├── command_train.sh ├── command_test.sh ├── prof └── __init__.py ├── docker-compose.yml ├── Dockerfile ├── utils ├── loss_utils.py ├── data_utils.py ├── collision_detector.py └── label_generation.py ├── models ├── backbone.py ├── graspnet.py ├── loss.py └── modules.py ├── demo.py ├── test.py ├── README.md ├── train.py └── LICENSE /doc/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dandelight/graspnet-baseline/HEAD/doc/teaser.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | tensorboard 3 | numpy 4 | scipy 5 | open3d>=0.8 6 | Pillow 7 | tqdm 8 | -------------------------------------------------------------------------------- /command_demo.sh: 
-------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python demo.py --checkpoint_path logs/log_kn/checkpoint.tar 2 | -------------------------------------------------------------------------------- /doc/example_data/color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dandelight/graspnet-baseline/HEAD/doc/example_data/color.png -------------------------------------------------------------------------------- /doc/example_data/depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dandelight/graspnet-baseline/HEAD/doc/example_data/depth.png -------------------------------------------------------------------------------- /doc/example_data/meta.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dandelight/graspnet-baseline/HEAD/doc/example_data/meta.mat -------------------------------------------------------------------------------- /doc/example_data/demo_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dandelight/graspnet-baseline/HEAD/doc/example_data/demo_result.png -------------------------------------------------------------------------------- /sync_to_docker.sh: -------------------------------------------------------------------------------- 1 | docker exec -t graspnet-baseline_pytorch_1 rsync --progress -r /nvme/grm/argss/graspnet-baseline /content 2 | -------------------------------------------------------------------------------- /doc/example_data/workspace_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dandelight/graspnet-baseline/HEAD/doc/example_data/workspace_mask.png 
-------------------------------------------------------------------------------- /dataset/command_generate_tolerance_label.sh: -------------------------------------------------------------------------------- 1 | python generate_tolerance_label.py --dataset_root /data/Benchmark/graspnet --num_workers 50 2 | -------------------------------------------------------------------------------- /knn/src/vision.cpp: -------------------------------------------------------------------------------- 1 | #include "knn.h" 2 | 3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 4 | m.def("knn", &knn, "k-nearest neighbors"); 5 | } 6 | -------------------------------------------------------------------------------- /knn/src/cpu/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | void knn_cpu(float* ref_dev, int ref_width, 5 | float* query_dev, int query_width, 6 | int height, int k, float* dist_dev, long* ind_dev, long* ind_buf); -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/** 2 | *.ipynb 3 | **/.ipynb_checkpoints/** 4 | *.npy 5 | *.npz 6 | **/.vscode/** 7 | **/grasp_label*/** 8 | **/log*/** 9 | **/dump*/** 10 | **/build/** 11 | *.o 12 | *.so 13 | *.egg 14 | **/*.egg-info/** 15 | logs 16 | dataset/tolerance -------------------------------------------------------------------------------- /knn/src/cuda/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | void knn_device(float* ref_dev, int ref_width, 6 | float* query_dev, int query_width, 7 | int height, int k, float* dist_dev, long* ind_dev, cudaStream_t stream); -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/cylinder_query.h: 
-------------------------------------------------------------------------------- 1 | // Author: chenxi-wang 2 | 3 | #pragma once 4 | #include 5 | 6 | at::Tensor cylinder_query(at::Tensor new_xyz, at::Tensor xyz, at::Tensor rot, const float radius, const float hmin, const float hmax, 7 | const int nsample); 8 | -------------------------------------------------------------------------------- /command_train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train.py --camera realsense --log_dir /nvme/grm/argss/graspnet-baseline/logs/log_rs --batch_size 2 --dataset_root /nvme/grm/data/graspnet 2 | # CUDA_VISIBLE_DEVICES=0 python train.py --camera kinect --log_dir logs/log_kn --batch_size 2 --dataset_root /data/Benchmark/graspnet 3 | -------------------------------------------------------------------------------- /command_test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python test.py --dump_dir logs/dump_rs --checkpoint_path logs/log_rs/checkpoint.tar --camera realsense --dataset_root /data/Benchmark/graspnet 2 | # CUDA_VISIBLE_DEVICES=0 python test.py --dump_dir logs/dump_kn --checkpoint_path logs/log_kn/checkpoint.tar --camera kinect --dataset_root /data/Benchmark/graspnet 3 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 10 | const int nsample); 11 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/group_points.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | -------------------------------------------------------------------------------- /prof/__init__.py: -------------------------------------------------------------------------------- 1 | def memstat(): 2 | with open('/proc/meminfo') as fd: 3 | for line in fd: 4 | if line.startswith('MemTotal'): 5 | MemTotal = line.split()[1] 6 | continue 7 | if line.startswith('MemFree'): 8 | MemFree = line.split()[1] 9 | break 10 | print( "总内存:%sM" % (int(MemTotal)/1024)) 11 | print( "剩余内存:%sM" % (int(MemFree)/1024)) 12 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | services: 3 | pytorch: # 记住改service的名字,或者在下面加一个name字段 4 | build: "." 
5 | ports: 6 | - "0:22" # 将容器内22映射到任意端口 7 | volumes: 8 | - $HOME:$HOME 9 | - /nvme:/nvme 10 | shm_size: "32gb" # PyTorch多线程加载数据 11 | stdin_open: true 12 | tty: true 13 | deploy: 14 | resources: 15 | reservations: 16 | devices: 17 | - capabilities: ["gpu"] # NVIDIA GPU支持 18 | entrypoint: "bash" 19 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/sampling.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 12 | -------------------------------------------------------------------------------- /knn/knn_modules.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import gc 3 | import operator as op 4 | import functools 5 | import torch 6 | from torch.autograd import Variable, Function 7 | from knn_pytorch import knn_pytorch 8 | # import knn_pytorch 9 | def knn(ref, query, k=1): 10 | """ Compute k nearest neighbors for each query point. 11 | """ 12 | device = ref.device 13 | ref = ref.float().to(device) 14 | query = query.float().to(device) 15 | inds = torch.empty(query.shape[0], k, query.shape[2]).long().to(device) 16 | knn_pytorch.knn(ref, query, inds) 17 | return inds 18 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 12 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 13 | at::Tensor weight); 14 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 15 | at::Tensor weight, const int m); 16 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "ball_query.h" 7 | #include "group_points.h" 8 | #include "interpolate.h" 9 | #include "sampling.h" 10 | #include "cylinder_query.h" 11 | 12 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 13 | m.def("gather_points", &gather_points); 14 | m.def("gather_points_grad", &gather_points_grad); 15 | m.def("furthest_point_sampling", &furthest_point_sampling); 16 | 17 | m.def("three_nn", &three_nn); 18 | m.def("three_interpolate", &three_interpolate); 19 | m.def("three_interpolate_grad", &three_interpolate_grad); 20 | 21 | m.def("ball_query", &ball_query); 22 | 23 | m.def("group_points", &group_points); 24 | m.def("group_points_grad", &group_points_grad); 25 | 26 | m.def("cylinder_query", &cylinder_query); 27 | } 28 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel 2 | # apt换源 3 | RUN sed -i "s@http://.*archive.ubuntu.com@http://repo.huaweicloud.com@g" /etc/apt/sources.list &&\ 4 | sed -i 
"s@http://.*security.ubuntu.com@http://repo.huaweicloud.com@g" /etc/apt/sources.list 5 | RUN apt-get update 6 | # pip换源 7 | RUN pip config set global.index-url https://mirrors.bfsu.edu.cn/pypi/web/simple 8 | RUN apt-get install git -y 9 | RUN apt-get install ninja-build 10 | RUN pip install torch tensorboard numpy scipy 'open3d>=0.8' Pillow tqdm 11 | RUN apt-get install libgl1-mesa-glx -y 12 | # COPY 之后的语句都会被重新执行一遍 13 | COPY . /content/graspnet-baseline 14 | WORKDIR /content 15 | RUN cd graspnet-baseline &&\ 16 | cd pointnet2 && TORCH_CUDA_ARCH_LIST="7.5" python setup.py install &&\ 17 | cd ../knn && TORCH_CUDA_ARCH_LIST="7.5" python setup.py install 18 | 19 | RUN git clone https://hub.fastgit.xyz/graspnet/graspnetAPI.git &&\ 20 | cd graspnetAPI && pip install . 21 | 22 | ENTRYPOINT bash 23 | -------------------------------------------------------------------------------- /pointnet2/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from setuptools import setup 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | import glob 9 | import os 10 | ROOT = os.path.dirname(os.path.abspath(__file__)) 11 | 12 | _ext_src_root = "_ext_src" 13 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 14 | "{}/src/*.cu".format(_ext_src_root) 15 | ) 16 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) 17 | 18 | setup( 19 | name='pointnet2', 20 | ext_modules=[ 21 | CUDAExtension( 22 | name='pointnet2._ext', 23 | sources=_ext_sources, 24 | extra_compile_args={ 25 | "cxx": ["-O2", "-I{}".format("{}/{}/include".format(ROOT, _ext_src_root))], 26 | "nvcc": ["-O2", "-I{}".format("{}/{}/include".format(ROOT, _ext_src_root))], 27 | }, 28 | ) 29 | ], 30 | cmdclass={ 31 | 'build_ext': BuildExtension 32 | } 33 | ) 34 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #pragma once 7 | #include 8 | #include 9 | 10 | #define CHECK_CUDA(x) \ 11 | do { \ 12 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_CONTIGUOUS(x) \ 16 | do { \ 17 | TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ 18 | } while (0) 19 | 20 | #define CHECK_IS_INT(x) \ 21 | do { \ 22 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ 23 | #x " must be an int tensor"); \ 24 | } while (0) 25 | 26 | #define CHECK_IS_FLOAT(x) \ 27 | do { \ 28 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \ 29 | #x " must be a float tensor"); \ 30 | } while (0) 31 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include "ball_query.h" 7 | #include "utils.h" 8 | 9 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 10 | int nsample, const float *new_xyz, 11 | const float *xyz, int *idx); 12 | 13 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 14 | const int nsample) { 15 | CHECK_CONTIGUOUS(new_xyz); 16 | CHECK_CONTIGUOUS(xyz); 17 | CHECK_IS_FLOAT(new_xyz); 18 | CHECK_IS_FLOAT(xyz); 19 | 20 | if (new_xyz.type().is_cuda()) { 21 | CHECK_CUDA(xyz); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 26 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 27 | 28 | if (new_xyz.type().is_cuda()) { 29 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 30 | radius, nsample, new_xyz.data(), 31 | xyz.data(), idx.data()); 32 | } else { 33 | TORCH_CHECK(false, "CPU not supported"); 34 | } 35 | 36 | return idx; 37 | } 38 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/cylinder_query.cpp: -------------------------------------------------------------------------------- 1 | // Author: chenxi-wang 2 | 3 | #include "cylinder_query.h" 4 | #include "utils.h" 5 | 6 | void query_cylinder_point_kernel_wrapper(int b, int n, int m, float radius, float hmin, float hmax, 7 | int nsample, const float *new_xyz, 8 | const float *xyz, const float *rot, int *idx); 9 | 10 | at::Tensor cylinder_query(at::Tensor new_xyz, at::Tensor xyz, at::Tensor rot, const float radius, const float hmin, const float hmax, 11 | const int nsample) { 12 | CHECK_CONTIGUOUS(new_xyz); 13 | CHECK_CONTIGUOUS(xyz); 14 | CHECK_CONTIGUOUS(rot); 15 | CHECK_IS_FLOAT(new_xyz); 16 | CHECK_IS_FLOAT(xyz); 17 | CHECK_IS_FLOAT(rot); 18 | 19 | if (new_xyz.type().is_cuda()) { 20 | CHECK_CUDA(xyz); 21 | CHECK_CUDA(rot); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 26 | 
at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 27 | 28 | if (new_xyz.type().is_cuda()) { 29 | query_cylinder_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 30 | radius, hmin, hmax, nsample, new_xyz.data(), 31 | xyz.data(), rot.data(), idx.data()); 32 | } else { 33 | TORCH_CHECK(false, "CPU not supported"); 34 | } 35 | 36 | return idx; 37 | } 38 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #ifndef _CUDA_UTILS_H 7 | #define _CUDA_UTILS_H 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | #define TOTAL_THREADS 512 19 | 20 | inline int opt_n_threads(int work_size) { 21 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 22 | 23 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 24 | } 25 | 26 | inline dim3 opt_block_config(int x, int y) { 27 | const int x_threads = opt_n_threads(x); 28 | const int y_threads = 29 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 30 | dim3 block_config(x_threads, y_threads, 1); 31 | 32 | return block_config; 33 | } 34 | 35 | #define CUDA_CHECK_ERRORS() \ 36 | do { \ 37 | cudaError_t err = cudaGetLastError(); \ 38 | if (cudaSuccess != err) { \ 39 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 40 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 41 | __FILE__); \ 42 | exit(-1); \ 43 | } \ 44 | } while (0) 45 | 46 | #endif 47 | -------------------------------------------------------------------------------- /knn/src/cpu/knn_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "cpu/vision.h" 2 | 
3 | 4 | void knn_cpu(float* ref_dev, int ref_width, float* query_dev, int query_width, 5 | int height, int k, float* dist_dev, long* ind_dev, long* ind_buf) 6 | { 7 | // Compute all the distances 8 | for(int query_idx = 0;query_idx dist_dev[query_idx * ref_width + j + 1]) 31 | { 32 | temp_value = dist_dev[query_idx * ref_width + j]; 33 | dist_dev[query_idx * ref_width + j] = dist_dev[query_idx * ref_width + j + 1]; 34 | dist_dev[query_idx * ref_width + j + 1] = temp_value; 35 | temp_idx = ind_buf[j]; 36 | ind_buf[j] = ind_buf[j + 1]; 37 | ind_buf[j + 1] = temp_idx; 38 | } 39 | 40 | } 41 | 42 | for(int i = 0;i < k;i++) 43 | ind_dev[query_idx + i * query_width] = ind_buf[i]; 44 | #if DEBUG 45 | for(int i = 0;i < ref_width;i++) 46 | printf("%d, ", ind_buf[i]); 47 | printf("\n"); 48 | #endif 49 | 50 | } 51 | 52 | 53 | 54 | 55 | 56 | } -------------------------------------------------------------------------------- /knn/src/knn.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "cpu/vision.h" 3 | 4 | #ifdef WITH_CUDA 5 | #include "cuda/vision.h" 6 | #include 7 | extern THCState *state; 8 | #endif 9 | 10 | 11 | 12 | int knn(at::Tensor& ref, at::Tensor& query, at::Tensor& idx) 13 | { 14 | 15 | // TODO check dimensions 16 | long batch, ref_nb, query_nb, dim, k; 17 | batch = ref.size(0); 18 | dim = ref.size(1); 19 | k = idx.size(1); 20 | ref_nb = ref.size(2); 21 | query_nb = query.size(2); 22 | 23 | float *ref_dev = ref.data(); 24 | float *query_dev = query.data(); 25 | long *idx_dev = idx.data(); 26 | 27 | 28 | 29 | 30 | if (ref.type().is_cuda()) { 31 | #ifdef WITH_CUDA 32 | // TODO raise error if not compiled with CUDA 33 | float *dist_dev = (float*)THCudaMalloc(state, ref_nb * query_nb * sizeof(float)); 34 | 35 | for (int b = 0; b < batch; b++) 36 | { 37 | // knn_device(ref_dev + b * dim * ref_nb, ref_nb, query_dev + b * dim * query_nb, query_nb, dim, k, 38 | // dist_dev, idx_dev + b * k * query_nb, 
THCState_getCurrentStream(state)); 39 | knn_device(ref_dev + b * dim * ref_nb, ref_nb, query_dev + b * dim * query_nb, query_nb, dim, k, 40 | dist_dev, idx_dev + b * k * query_nb, c10::cuda::getCurrentCUDAStream()); 41 | } 42 | THCudaFree(state, dist_dev); 43 | cudaError_t err = cudaGetLastError(); 44 | if (err != cudaSuccess) 45 | { 46 | printf("error in knn: %s\n", cudaGetErrorString(err)); 47 | THError("aborting"); 48 | } 49 | return 1; 50 | #else 51 | AT_ERROR("Not compiled with GPU support"); 52 | #endif 53 | } 54 | 55 | 56 | float *dist_dev = (float*)malloc(ref_nb * query_nb * sizeof(float)); 57 | long *ind_buf = (long*)malloc(ref_nb * sizeof(long)); 58 | for (int b = 0; b < batch; b++) { 59 | knn_cpu(ref_dev + b * dim * ref_nb, ref_nb, query_dev + b * dim * query_nb, query_nb, dim, k, 60 | dist_dev, idx_dev + b * k * query_nb, ind_buf); 61 | } 62 | 63 | free(dist_dev); 64 | free(ind_buf); 65 | 66 | return 1; 67 | 68 | } 69 | -------------------------------------------------------------------------------- /knn/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import glob 4 | import os 5 | 6 | import torch 7 | from setuptools import find_packages 8 | from setuptools import setup 9 | from torch.utils.cpp_extension import CUDA_HOME 10 | from torch.utils.cpp_extension import CppExtension 11 | from torch.utils.cpp_extension import CUDAExtension 12 | 13 | requirements = ["torch", "torchvision"] 14 | 15 | 16 | def get_extensions(): 17 | this_dir = os.path.dirname(os.path.abspath(__file__)) 18 | extensions_dir = os.path.join(this_dir, "src") 19 | 20 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 21 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 22 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 23 | 24 | sources = main_file + source_cpu 25 | extension = CppExtension 26 | 27 | extra_compile_args = {"cxx": []} 28 | define_macros = 
[] 29 | 30 | if torch.cuda.is_available() and CUDA_HOME is not None: 31 | extension = CUDAExtension 32 | sources += source_cuda 33 | define_macros += [("WITH_CUDA", None)] 34 | extra_compile_args["nvcc"] = [ 35 | "-DCUDA_HAS_FP16=1", 36 | "-D__CUDA_NO_HALF_OPERATORS__", 37 | "-D__CUDA_NO_HALF_CONVERSIONS__", 38 | "-D__CUDA_NO_HALF2_OPERATORS__", 39 | ] 40 | 41 | sources = [os.path.join(extensions_dir, s) for s in sources] 42 | 43 | include_dirs = [extensions_dir] 44 | 45 | ext_modules = [ 46 | extension( 47 | "knn_pytorch.knn_pytorch", 48 | sources, 49 | include_dirs=include_dirs, 50 | define_macros=define_macros, 51 | extra_compile_args=extra_compile_args, 52 | ) 53 | ] 54 | 55 | return ext_modules 56 | 57 | 58 | setup( 59 | name="knn_pytorch", 60 | version="0.1", 61 | author="foolyc", 62 | url="https://github.com/foolyc/torchKNN", 63 | description="KNN implement in Pytorch 1.0 including both cpu version and gpu version", 64 | ext_modules=get_extensions(), 65 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 66 | ) 67 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_utils.h" 11 | 12 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 13 | // output: idx(b, m, nsample) 14 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 15 | int nsample, 16 | const float *__restrict__ new_xyz, 17 | const float *__restrict__ xyz, 18 | int *__restrict__ idx) { 19 | int batch_index = blockIdx.x; 20 | xyz += batch_index * n * 3; 21 | new_xyz += batch_index * m * 3; 22 | idx += m * nsample * batch_index; 23 | 24 | int index = threadIdx.x; 25 | int stride = blockDim.x; 26 | 27 | float radius2 = radius * radius; 28 | for (int j = index; j < m; j += stride) { 29 | float new_x = new_xyz[j * 3 + 0]; 30 | float new_y = new_xyz[j * 3 + 1]; 31 | float new_z = new_xyz[j * 3 + 2]; 32 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 33 | float x = xyz[k * 3 + 0]; 34 | float y = xyz[k * 3 + 1]; 35 | float z = xyz[k * 3 + 2]; 36 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 37 | (new_z - z) * (new_z - z); 38 | if (d2 < radius2) { 39 | if (cnt == 0) { 40 | for (int l = 0; l < nsample; ++l) { 41 | idx[j * nsample + l] = k; 42 | } 43 | } 44 | idx[j * nsample + cnt] = k; 45 | ++cnt; 46 | } 47 | } 48 | } 49 | } 50 | 51 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 52 | int nsample, const float *new_xyz, 53 | const float *xyz, int *idx) { 54 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 55 | query_ball_point_kernel<<>>( 56 | b, n, m, radius, nsample, new_xyz, xyz, idx); 57 | 58 | CUDA_CHECK_ERRORS(); 59 | } 60 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include "group_points.h" 7 | #include "utils.h" 8 | 9 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 10 | const float *points, const int *idx, 11 | float *out); 12 | 13 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 14 | int nsample, const float *grad_out, 15 | const int *idx, float *grad_points); 16 | 17 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 18 | CHECK_CONTIGUOUS(points); 19 | CHECK_CONTIGUOUS(idx); 20 | CHECK_IS_FLOAT(points); 21 | CHECK_IS_INT(idx); 22 | 23 | if (points.type().is_cuda()) { 24 | CHECK_CUDA(idx); 25 | } 26 | 27 | at::Tensor output = 28 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 29 | at::device(points.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (points.type().is_cuda()) { 32 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 33 | idx.size(1), idx.size(2), points.data(), 34 | idx.data(), output.data()); 35 | } else { 36 | TORCH_CHECK(false, "CPU not supported"); 37 | } 38 | 39 | return output; 40 | } 41 | 42 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 43 | CHECK_CONTIGUOUS(grad_out); 44 | CHECK_CONTIGUOUS(idx); 45 | CHECK_IS_FLOAT(grad_out); 46 | CHECK_IS_INT(idx); 47 | 48 | if (grad_out.type().is_cuda()) { 49 | CHECK_CUDA(idx); 50 | } 51 | 52 | at::Tensor output = 53 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 54 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 55 | 56 | if (grad_out.type().is_cuda()) { 57 | group_points_grad_kernel_wrapper( 58 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 59 | grad_out.data(), idx.data(), output.data()); 60 | } else { 61 | TORCH_CHECK(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/cylinder_query_gpu.cu: 
-------------------------------------------------------------------------------- 1 | // Author: chenxi-wang 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "cuda_utils.h" 8 | 9 | __global__ void query_cylinder_point_kernel(int b, int n, int m, float radius, float hmin, float hmax, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | const float *__restrict__ rot, 14 | int *__restrict__ idx) { 15 | int batch_index = blockIdx.x; 16 | xyz += batch_index * n * 3; 17 | new_xyz += batch_index * m * 3; 18 | rot += batch_index * m * 9; 19 | idx += m * nsample * batch_index; 20 | 21 | int index = threadIdx.x; 22 | int stride = blockDim.x; 23 | 24 | float radius2 = radius * radius; 25 | for (int j = index; j < m; j += stride) { 26 | float new_x = new_xyz[j * 3 + 0]; 27 | float new_y = new_xyz[j * 3 + 1]; 28 | float new_z = new_xyz[j * 3 + 2]; 29 | float r0 = rot[j * 9 + 0]; 30 | float r1 = rot[j * 9 + 1]; 31 | float r2 = rot[j * 9 + 2]; 32 | float r3 = rot[j * 9 + 3]; 33 | float r4 = rot[j * 9 + 4]; 34 | float r5 = rot[j * 9 + 5]; 35 | float r6 = rot[j * 9 + 6]; 36 | float r7 = rot[j * 9 + 7]; 37 | float r8 = rot[j * 9 + 8]; 38 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 39 | float x = xyz[k * 3 + 0] - new_x; 40 | float y = xyz[k * 3 + 1] - new_y; 41 | float z = xyz[k * 3 + 2] - new_z; 42 | float x_rot = r0 * x + r3 * y + r6 * z; 43 | float y_rot = r1 * x + r4 * y + r7 * z; 44 | float z_rot = r2 * x + r5 * y + r8 * z; 45 | float d2 = y_rot * y_rot + z_rot * z_rot; 46 | if (d2 < radius2 && x_rot > hmin && x_rot < hmax) { 47 | if (cnt == 0) { 48 | for (int l = 0; l < nsample; ++l) { 49 | idx[j * nsample + l] = k; 50 | } 51 | } 52 | idx[j * nsample + cnt] = k; 53 | ++cnt; 54 | } 55 | } 56 | } 57 | } 58 | 59 | void query_cylinder_point_kernel_wrapper(int b, int n, int m, float radius, float hmin, float hmax, 60 | int nsample, const float *new_xyz, 61 | const float *xyz, const float *rot, int *idx) { 62 | 
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 63 | query_cylinder_point_kernel<<>>( 64 | b, n, m, radius, hmin, hmax, nsample, new_xyz, xyz, rot, idx); 65 | 66 | CUDA_CHECK_ERRORS(); 67 | } 68 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | #include 8 | 9 | #include "cuda_utils.h" 10 | 11 | // input: points(b, c, n) idx(b, npoints, nsample) 12 | // output: out(b, c, npoints, nsample) 13 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 14 | int nsample, 15 | const float *__restrict__ points, 16 | const int *__restrict__ idx, 17 | float *__restrict__ out) { 18 | int batch_index = blockIdx.x; 19 | points += batch_index * n * c; 20 | idx += batch_index * npoints * nsample; 21 | out += batch_index * npoints * nsample * c; 22 | 23 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 24 | const int stride = blockDim.y * blockDim.x; 25 | for (int i = index; i < c * npoints; i += stride) { 26 | const int l = i / npoints; 27 | const int j = i % npoints; 28 | for (int k = 0; k < nsample; ++k) { 29 | int ii = idx[j * nsample + k]; 30 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 31 | } 32 | } 33 | } 34 | 35 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 36 | const float *points, const int *idx, 37 | float *out) { 38 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 39 | 40 | group_points_kernel<<>>( 41 | b, c, n, npoints, nsample, points, idx, out); 42 | 43 | CUDA_CHECK_ERRORS(); 44 | } 45 | 46 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 47 | // output: grad_points(b, c, n) 48 | 
__global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 49 | int nsample, 50 | const float *__restrict__ grad_out, 51 | const int *__restrict__ idx, 52 | float *__restrict__ grad_points) { 53 | int batch_index = blockIdx.x; 54 | grad_out += batch_index * npoints * nsample * c; 55 | idx += batch_index * npoints * nsample; 56 | grad_points += batch_index * n * c; 57 | 58 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 59 | const int stride = blockDim.y * blockDim.x; 60 | for (int i = index; i < c * npoints; i += stride) { 61 | const int l = i / npoints; 62 | const int j = i % npoints; 63 | for (int k = 0; k < nsample; ++k) { 64 | int ii = idx[j * nsample + k]; 65 | atomicAdd(grad_points + l * n + ii, 66 | grad_out[(l * npoints + j) * nsample + k]); 67 | } 68 | } 69 | } 70 | 71 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 72 | int nsample, const float *grad_out, 73 | const int *idx, float *grad_points) { 74 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 75 | 76 | group_points_grad_kernel<<>>( 77 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 78 | 79 | CUDA_CHECK_ERRORS(); 80 | } 81 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include "sampling.h" 7 | #include "utils.h" 8 | 9 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 10 | const float *points, const int *idx, 11 | float *out); 12 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 13 | const float *grad_out, const int *idx, 14 | float *grad_points); 15 | 16 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 17 | const float *dataset, float *temp, 18 | int *idxs); 19 | 20 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 21 | CHECK_CONTIGUOUS(points); 22 | CHECK_CONTIGUOUS(idx); 23 | CHECK_IS_FLOAT(points); 24 | CHECK_IS_INT(idx); 25 | 26 | if (points.type().is_cuda()) { 27 | CHECK_CUDA(idx); 28 | } 29 | 30 | at::Tensor output = 31 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 32 | at::device(points.device()).dtype(at::ScalarType::Float)); 33 | 34 | if (points.type().is_cuda()) { 35 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 36 | idx.size(1), points.data(), 37 | idx.data(), output.data()); 38 | } else { 39 | TORCH_CHECK(false, "CPU not supported"); 40 | } 41 | 42 | return output; 43 | } 44 | 45 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 46 | const int n) { 47 | CHECK_CONTIGUOUS(grad_out); 48 | CHECK_CONTIGUOUS(idx); 49 | CHECK_IS_FLOAT(grad_out); 50 | CHECK_IS_INT(idx); 51 | 52 | if (grad_out.type().is_cuda()) { 53 | CHECK_CUDA(idx); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 58 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (grad_out.type().is_cuda()) { 61 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 62 | idx.size(1), grad_out.data(), 63 | idx.data(), output.data()); 64 | } else { 65 | TORCH_CHECK(false, "CPU not supported"); 66 | } 67 | 68 | return output; 69 | } 70 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 71 | 
CHECK_CONTIGUOUS(points); 72 | CHECK_IS_FLOAT(points); 73 | 74 | at::Tensor output = 75 | torch::zeros({points.size(0), nsamples}, 76 | at::device(points.device()).dtype(at::ScalarType::Int)); 77 | 78 | at::Tensor tmp = 79 | torch::full({points.size(0), points.size(1)}, 1e10, 80 | at::device(points.device()).dtype(at::ScalarType::Float)); 81 | 82 | if (points.type().is_cuda()) { 83 | furthest_point_sampling_kernel_wrapper( 84 | points.size(0), points.size(1), nsamples, points.data(), 85 | tmp.data(), output.data()); 86 | } else { 87 | TORCH_CHECK(false, "CPU not supported"); 88 | } 89 | 90 | return output; 91 | } 92 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include "interpolate.h" 7 | #include "utils.h" 8 | 9 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 10 | const float *known, float *dist2, int *idx); 11 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 12 | const float *points, const int *idx, 13 | const float *weight, float *out); 14 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 15 | const float *grad_out, 16 | const int *idx, const float *weight, 17 | float *grad_points); 18 | 19 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 20 | CHECK_CONTIGUOUS(unknowns); 21 | CHECK_CONTIGUOUS(knows); 22 | CHECK_IS_FLOAT(unknowns); 23 | CHECK_IS_FLOAT(knows); 24 | 25 | if (unknowns.type().is_cuda()) { 26 | CHECK_CUDA(knows); 27 | } 28 | 29 | at::Tensor idx = 30 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 31 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 32 | at::Tensor dist2 = 33 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 34 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 35 | 36 | if (unknowns.type().is_cuda()) { 37 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 38 | unknowns.data(), knows.data(), 39 | dist2.data(), idx.data()); 40 | } else { 41 | TORCH_CHECK(false, "CPU not supported"); 42 | } 43 | 44 | return {dist2, idx}; 45 | } 46 | 47 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 48 | at::Tensor weight) { 49 | CHECK_CONTIGUOUS(points); 50 | CHECK_CONTIGUOUS(idx); 51 | CHECK_CONTIGUOUS(weight); 52 | CHECK_IS_FLOAT(points); 53 | CHECK_IS_INT(idx); 54 | CHECK_IS_FLOAT(weight); 55 | 56 | if (points.type().is_cuda()) { 57 | CHECK_CUDA(idx); 58 | CHECK_CUDA(weight); 59 | } 60 | 61 | at::Tensor output = 62 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 63 | at::device(points.device()).dtype(at::ScalarType::Float)); 64 | 65 | if (points.type().is_cuda()) { 66 | three_interpolate_kernel_wrapper( 67 | points.size(0), 
points.size(1), points.size(2), idx.size(1), 68 | points.data(), idx.data(), weight.data(), 69 | output.data()); 70 | } else { 71 | TORCH_CHECK(false, "CPU not supported"); 72 | } 73 | 74 | return output; 75 | } 76 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 77 | at::Tensor weight, const int m) { 78 | CHECK_CONTIGUOUS(grad_out); 79 | CHECK_CONTIGUOUS(idx); 80 | CHECK_CONTIGUOUS(weight); 81 | CHECK_IS_FLOAT(grad_out); 82 | CHECK_IS_INT(idx); 83 | CHECK_IS_FLOAT(weight); 84 | 85 | if (grad_out.type().is_cuda()) { 86 | CHECK_CUDA(idx); 87 | CHECK_CUDA(weight); 88 | } 89 | 90 | at::Tensor output = 91 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 92 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 93 | 94 | if (grad_out.type().is_cuda()) { 95 | three_interpolate_grad_kernel_wrapper( 96 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 97 | grad_out.data(), idx.data(), weight.data(), 98 | output.data()); 99 | } else { 100 | TORCH_CHECK(false, "CPU not supported"); 101 | } 102 | 103 | return output; 104 | } 105 | -------------------------------------------------------------------------------- /dataset/generate_tolerance_label.py: -------------------------------------------------------------------------------- 1 | """ Tolerance label generation. 
2 | Author: chenxi-wang 3 | """ 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | import time 9 | import argparse 10 | import multiprocessing as mp 11 | 12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 13 | ROOT_DIR = os.path.dirname(BASE_DIR) 14 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 15 | from data_utils import compute_point_dists 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--dataset_root', required=True, help='Dataset root') 19 | parser.add_argument('--pos_ratio_thresh', type=float, default=0.8, help='Threshold of positive neighbor ratio[default: 0.8]') 20 | parser.add_argument('--mu_thresh', type=float, default=0.55, help='Threshold of friction coefficient[default: 0.55]') 21 | parser.add_argument('--num_workers', type=int, default=50, help='Worker number[default: 50]') 22 | cfgs = parser.parse_args() 23 | 24 | save_path = 'tolerance' 25 | 26 | V = 300 27 | A = 12 28 | D = 4 29 | radius_list = [0.001 * x for x in range(51)] 30 | 31 | def manager(obj_name, pool_size=8): 32 | # load models 33 | label_path = '{}_labels.npz'.format(obj_name) 34 | label = np.load(os.path.join(cfgs.dataset_root, 'grasp_label', label_path)) 35 | points = label['points'] 36 | scores = label['scores'] 37 | 38 | # create dict 39 | tolerance = mp.Manager().dict() 40 | dists = compute_point_dists(points, points) 41 | params = params = (scores, dists) 42 | 43 | # assign works 44 | pool = [] 45 | process_cnt = 0 46 | work_list = [x for x in range(len(points))] 47 | for _ in range(pool_size): 48 | point_ind = work_list.pop(0) 49 | pool.append(mp.Process(target=worker, args=(obj_name, point_ind, params, tolerance))) 50 | [p.start() for p in pool] 51 | 52 | # refill 53 | while len(work_list) > 0: 54 | for ind, p in enumerate(pool): 55 | if not p.is_alive(): 56 | pool.pop(ind) 57 | point_ind = work_list.pop(0) 58 | p = mp.Process(target=worker, args=(obj_name, point_ind, params, tolerance)) 59 | p.start() 60 | pool.append(p) 61 | 
process_cnt += 1 62 | print('{}/{}'.format(process_cnt, len(points))) 63 | break 64 | while len(pool) > 0: 65 | for ind, p in enumerate(pool): 66 | if not p.is_alive(): 67 | pool.pop(ind) 68 | process_cnt += 1 69 | print('{}/{}'.format(process_cnt, len(points))) 70 | break 71 | 72 | # save tolerance 73 | if not os.path.exists(save_path): 74 | os.mkdir(save_path) 75 | saved_tolerance = [None for _ in range(len(points))] 76 | for i in range(len(points)): 77 | saved_tolerance[i] = tolerance[i] 78 | saved_tolerance = np.array(saved_tolerance) 79 | np.save('{}/{}_tolerance.npy'.format(save_path, obj_name), saved_tolerance) 80 | 81 | def worker(obj_name, point_ind, params, tolerance): 82 | scores, dists = params 83 | tmp_tolerance = np.zeros([V, A, D], dtype=np.float32) 84 | tic = time.time() 85 | for r in radius_list: 86 | dist_mask = (dists[point_ind] <= r) 87 | scores_in_ball = scores[dist_mask] 88 | pos_ratio = ((scores_in_ball > 0) & (scores_in_ball <= cfgs.mu_thresh)).mean(axis=0) 89 | tolerance_mask = (pos_ratio >= cfgs.pos_ratio_thresh) 90 | if tolerance_mask.sum() == 0: 91 | break 92 | tmp_tolerance[tolerance_mask] = r 93 | tolerance[point_ind] = tmp_tolerance 94 | toc = time.time() 95 | print("{}: point {} time".format(obj_name, point_ind), toc - tic) 96 | 97 | if __name__ == '__main__': 98 | obj_list = ['%03d' % x for x in range(88)] 99 | for obj_name in obj_list: 100 | p = mp.Process(target=manager, args=(obj_name, cfgs.num_workers)) 101 | p.start() 102 | p.join() -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | """ Tools for loss computation. 
2 | Author: chenxi-wang 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | 8 | GRASP_MAX_WIDTH = 0.1 9 | GRASP_MAX_TOLERANCE = 0.05 10 | THRESH_GOOD = 0.7 11 | THRESH_BAD = 0.1 12 | 13 | def transform_point_cloud(cloud, transform, format='4x4'): 14 | """ Transform points to new coordinates with transformation matrix. 15 | 16 | Input: 17 | cloud: [torch.FloatTensor, (N,3)] 18 | points in original coordinates 19 | transform: [torch.FloatTensor, (3,3)/(3,4)/(4,4)] 20 | transformation matrix, could be rotation only or rotation+translation 21 | format: [string, '3x3'/'3x4'/'4x4'] 22 | the shape of transformation matrix 23 | '3x3' --> rotation matrix 24 | '3x4'/'4x4' --> rotation matrix + translation matrix 25 | 26 | Output: 27 | cloud_transformed: [torch.FloatTensor, (N,3)] 28 | points in new coordinates 29 | """ 30 | if not (format == '3x3' or format == '4x4' or format == '3x4'): 31 | raise ValueError('Unknown transformation format, only support \'3x3\' or \'4x4\' or \'3x4\'.') 32 | if format == '3x3': 33 | cloud_transformed = torch.matmul(transform, cloud.T).T 34 | elif format == '4x4' or format == '3x4': 35 | ones = cloud.new_ones(cloud.size(0), device=cloud.device).unsqueeze(-1) 36 | cloud_ = torch.cat([cloud, ones], dim=1) 37 | cloud_transformed = torch.matmul(transform, cloud_.T).T 38 | cloud_transformed = cloud_transformed[:, :3] 39 | return cloud_transformed 40 | 41 | def generate_grasp_views(N=300, phi=(np.sqrt(5)-1)/2, center=np.zeros(3), r=1): 42 | """ View sampling on a unit sphere using Fibonacci lattices. 
43 | Ref: https://arxiv.org/abs/0912.4540 44 | 45 | Input: 46 | N: [int] 47 | number of sampled views 48 | phi: [float] 49 | constant for view coordinate calculation, different phi's bring different distributions, default: (sqrt(5)-1)/2 50 | center: [np.ndarray, (3,), np.float32] 51 | sphere center 52 | r: [float] 53 | sphere radius 54 | 55 | Output: 56 | views: [torch.FloatTensor, (N,3)] 57 | sampled view coordinates 58 | """ 59 | views = [] 60 | for i in range(N): 61 | zi = (2 * i + 1) / N - 1 62 | xi = np.sqrt(1 - zi**2) * np.cos(2 * i * np.pi * phi) 63 | yi = np.sqrt(1 - zi**2) * np.sin(2 * i * np.pi * phi) 64 | views.append([xi, yi, zi]) 65 | views = r * np.array(views) + center 66 | return torch.from_numpy(views.astype(np.float32)) 67 | 68 | def batch_viewpoint_params_to_matrix(batch_towards, batch_angle): 69 | """ Transform approach vectors and in-plane rotation angles to rotation matrices. 70 | 71 | Input: 72 | batch_towards: [torch.FloatTensor, (N,3)] 73 | approach vectors in batch 74 | batch_angle: [torch.floatTensor, (N,)] 75 | in-plane rotation angles in batch 76 | 77 | Output: 78 | batch_matrix: [torch.floatTensor, (N,3,3)] 79 | rotation matrices in batch 80 | """ 81 | axis_x = batch_towards 82 | ones = torch.ones(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device) 83 | zeros = torch.zeros(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device) 84 | axis_y = torch.stack([-axis_x[:,1], axis_x[:,0], zeros], dim=-1) 85 | mask_y = (torch.norm(axis_y, dim=-1) == 0) 86 | axis_y[mask_y,1] = 1 87 | axis_x = axis_x / torch.norm(axis_x, dim=-1, keepdim=True) 88 | axis_y = axis_y / torch.norm(axis_y, dim=-1, keepdim=True) 89 | axis_z = torch.cross(axis_x, axis_y) 90 | sin = torch.sin(batch_angle) 91 | cos = torch.cos(batch_angle) 92 | R1 = torch.stack([ones, zeros, zeros, zeros, cos, -sin, zeros, sin, cos], dim=-1) 93 | R1 = R1.reshape([-1,3,3]) 94 | R2 = torch.stack([axis_x, axis_y, axis_z], dim=-1) 95 | batch_matrix = torch.matmul(R2, R1) 96 | return 
batch_matrix 97 | 98 | def huber_loss(error, delta=1.0): 99 | """ 100 | Args: 101 | error: Torch tensor (d1,d2,...,dk) 102 | Returns: 103 | loss: Torch tensor (d1,d2,...,dk) 104 | 105 | x = error = pred - gt or dist(pred,gt) 106 | 0.5 * |x|^2 if |x|<=d 107 | 0.5 * d^2 + d * (|x|-d) if |x|>d 108 | Author: Charles R. Qi 109 | Ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py 110 | """ 111 | abs_error = torch.abs(error) 112 | quadratic = torch.clamp(abs_error, max=delta) 113 | linear = (abs_error - quadratic) 114 | loss = 0.5 * quadratic**2 + delta * linear 115 | return loss -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | """ PointNet2 backbone for feature learning. 2 | Author: Charles R. Qi 3 | """ 4 | import os 5 | import sys 6 | import torch 7 | import torch.nn as nn 8 | 9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | ROOT_DIR = os.path.dirname(BASE_DIR) 11 | sys.path.append(ROOT_DIR) 12 | sys.path.append(os.path.join(ROOT_DIR, 'pointnet2')) 13 | 14 | from pointnet2_modules import PointnetSAModuleVotes, PointnetFPModule 15 | 16 | class Pointnet2Backbone(nn.Module): 17 | r""" 18 | Backbone network for point cloud feature learning. 19 | Based on Pointnet++ single-scale grouping network. 20 | 21 | Parameters 22 | ---------- 23 | input_feature_dim: int 24 | Number of input channels in the feature descriptor for each point. 25 | e.g. 3 for RGB. 
26 | """ 27 | def __init__(self, input_feature_dim=0): 28 | super().__init__() 29 | 30 | self.sa1 = PointnetSAModuleVotes( 31 | npoint=2048, 32 | radius=0.04, 33 | nsample=64, 34 | mlp=[input_feature_dim, 64, 64, 128], 35 | use_xyz=True, 36 | normalize_xyz=True 37 | ) 38 | 39 | self.sa2 = PointnetSAModuleVotes( 40 | npoint=1024, 41 | radius=0.1, 42 | nsample=32, 43 | mlp=[128, 128, 128, 256], 44 | use_xyz=True, 45 | normalize_xyz=True 46 | ) 47 | 48 | self.sa3 = PointnetSAModuleVotes( 49 | npoint=512, 50 | radius=0.2, 51 | nsample=16, 52 | mlp=[256, 128, 128, 256], 53 | use_xyz=True, 54 | normalize_xyz=True 55 | ) 56 | 57 | self.sa4 = PointnetSAModuleVotes( 58 | npoint=256, 59 | radius=0.3, 60 | nsample=16, 61 | mlp=[256, 128, 128, 256], 62 | use_xyz=True, 63 | normalize_xyz=True 64 | ) 65 | 66 | self.fp1 = PointnetFPModule(mlp=[256+256,256,256]) 67 | self.fp2 = PointnetFPModule(mlp=[256+256,256,256]) 68 | 69 | def _break_up_pc(self, pc): 70 | xyz = pc[..., 0:3].contiguous() 71 | features = ( 72 | pc[..., 3:].transpose(1, 2).contiguous() 73 | if pc.size(-1) > 3 else None 74 | ) 75 | 76 | return xyz, features 77 | 78 | def forward(self, pointcloud: torch.cuda.FloatTensor, end_points=None): 79 | r""" 80 | Forward pass of the network 81 | 82 | Parameters 83 | ---------- 84 | pointcloud: Variable(torch.cuda.FloatTensor) 85 | (B, N, 3 + input_feature_dim) tensor 86 | Point cloud to run predicts on 87 | Each point in the point-cloud MUST 88 | be formated as (x, y, z, features...) 
89 | 90 | Returns 91 | ---------- 92 | end_points: {XXX_xyz, XXX_features, XXX_inds} 93 | XXX_xyz: float32 Tensor of shape (B,K,3) 94 | XXX_features: float32 Tensor of shape (B,D,K) 95 | XXX_inds: int64 Tensor of shape (B,K) values in [0,N-1] 96 | """ 97 | if not end_points: end_points = {} 98 | batch_size = pointcloud.shape[0] 99 | 100 | xyz, features = self._break_up_pc(pointcloud) 101 | end_points['input_xyz'] = xyz 102 | end_points['input_features'] = features 103 | 104 | # --------- 4 SET ABSTRACTION LAYERS --------- 105 | xyz, features, fps_inds = self.sa1(xyz, features) 106 | end_points['sa1_inds'] = fps_inds 107 | end_points['sa1_xyz'] = xyz 108 | end_points['sa1_features'] = features 109 | 110 | xyz, features, fps_inds = self.sa2(xyz, features) # this fps_inds is just 0,1,...,1023 111 | end_points['sa2_inds'] = fps_inds 112 | end_points['sa2_xyz'] = xyz 113 | end_points['sa2_features'] = features 114 | 115 | xyz, features, fps_inds = self.sa3(xyz, features) # this fps_inds is just 0,1,...,511 116 | end_points['sa3_xyz'] = xyz 117 | end_points['sa3_features'] = features 118 | 119 | xyz, features, fps_inds = self.sa4(xyz, features) # this fps_inds is just 0,1,...,255 120 | end_points['sa4_xyz'] = xyz 121 | end_points['sa4_features'] = features 122 | 123 | # --------- 2 FEATURE UPSAMPLING LAYERS -------- 124 | features = self.fp1(end_points['sa3_xyz'], end_points['sa4_xyz'], end_points['sa3_features'], end_points['sa4_features']) 125 | features = self.fp2(end_points['sa2_xyz'], end_points['sa3_xyz'], end_points['sa2_features'], features) 126 | end_points['fp2_features'] = features 127 | end_points['fp2_xyz'] = end_points['sa2_xyz'] 128 | num_seed = end_points['fp2_xyz'].shape[1] 129 | end_points['fp2_inds'] = end_points['sa1_inds'][:,0:num_seed] # indices among the entire input point clouds 130 | 131 | return features, end_points['fp2_xyz'], end_points -------------------------------------------------------------------------------- /demo.py: 
-------------------------------------------------------------------------------- 1 | """ Demo to show prediction results. 2 | Author: chenxi-wang 3 | """ 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | import open3d as o3d 9 | import argparse 10 | import importlib 11 | import scipy.io as scio 12 | from PIL import Image 13 | 14 | import torch 15 | from graspnetAPI import GraspGroup 16 | 17 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 18 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 19 | sys.path.append(os.path.join(ROOT_DIR, 'dataset')) 20 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 21 | 22 | from graspnet import GraspNet, pred_decode 23 | from graspnet_dataset import GraspNetDataset 24 | from collision_detector import ModelFreeCollisionDetector 25 | from data_utils import CameraInfo, create_point_cloud_from_depth_image 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--checkpoint_path', required=True, help='Model checkpoint path') 29 | parser.add_argument('--num_point', type=int, default=20000, help='Point Number [default: 20000]') 30 | parser.add_argument('--num_view', type=int, default=300, help='View Number [default: 300]') 31 | parser.add_argument('--collision_thresh', type=float, default=0.01, help='Collision Threshold in collision detection [default: 0.01]') 32 | parser.add_argument('--voxel_size', type=float, default=0.01, help='Voxel Size to process point clouds before collision detection [default: 0.01]') 33 | cfgs = parser.parse_args() 34 | 35 | 36 | def get_net(): 37 | # Init the model 38 | net = GraspNet(input_feature_dim=0, num_view=cfgs.num_view, num_angle=12, num_depth=4, 39 | cylinder_radius=0.05, hmin=-0.02, hmax_list=[0.01,0.02,0.03,0.04], is_training=False) 40 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 41 | net.to(device) 42 | # Load checkpoint 43 | checkpoint = torch.load(cfgs.checkpoint_path) 44 | net.load_state_dict(checkpoint['model_state_dict']) 45 | 
start_epoch = checkpoint['epoch'] 46 | print("-> loaded checkpoint %s (epoch: %d)"%(cfgs.checkpoint_path, start_epoch)) 47 | # set model to eval mode 48 | net.eval() 49 | return net 50 | 51 | def get_and_process_data(data_dir): 52 | # load data 53 | color = np.array(Image.open(os.path.join(data_dir, 'color.png')), dtype=np.float32) / 255.0 54 | depth = np.array(Image.open(os.path.join(data_dir, 'depth.png'))) 55 | workspace_mask = np.array(Image.open(os.path.join(data_dir, 'workspace_mask.png'))) 56 | meta = scio.loadmat(os.path.join(data_dir, 'meta.mat')) 57 | intrinsic = meta['intrinsic_matrix'] 58 | factor_depth = meta['factor_depth'] 59 | 60 | # generate cloud 61 | camera = CameraInfo(1280.0, 720.0, intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2], factor_depth) 62 | cloud = create_point_cloud_from_depth_image(depth, camera, organized=True) 63 | 64 | # get valid points 65 | mask = (workspace_mask & (depth > 0)) 66 | cloud_masked = cloud[mask] 67 | color_masked = color[mask] 68 | 69 | # sample points 70 | if len(cloud_masked) >= cfgs.num_point: 71 | idxs = np.random.choice(len(cloud_masked), cfgs.num_point, replace=False) 72 | else: 73 | idxs1 = np.arange(len(cloud_masked)) 74 | idxs2 = np.random.choice(len(cloud_masked), cfgs.num_point-len(cloud_masked), replace=True) 75 | idxs = np.concatenate([idxs1, idxs2], axis=0) 76 | cloud_sampled = cloud_masked[idxs] 77 | color_sampled = color_masked[idxs] 78 | 79 | # convert data 80 | cloud = o3d.geometry.PointCloud() 81 | cloud.points = o3d.utility.Vector3dVector(cloud_masked.astype(np.float32)) 82 | cloud.colors = o3d.utility.Vector3dVector(color_masked.astype(np.float32)) 83 | end_points = dict() 84 | cloud_sampled = torch.from_numpy(cloud_sampled[np.newaxis].astype(np.float32)) 85 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 86 | cloud_sampled = cloud_sampled.to(device) 87 | end_points['point_clouds'] = cloud_sampled 88 | end_points['cloud_colors'] = color_sampled 89 | 
90 | return end_points, cloud 91 | 92 | def get_grasps(net, end_points): 93 | # Forward pass 94 | with torch.no_grad(): 95 | end_points = net(end_points) 96 | grasp_preds = pred_decode(end_points) 97 | gg_array = grasp_preds[0].detach().cpu().numpy() 98 | gg = GraspGroup(gg_array) 99 | return gg 100 | 101 | def collision_detection(gg, cloud): 102 | mfcdetector = ModelFreeCollisionDetector(cloud, voxel_size=cfgs.voxel_size) 103 | collision_mask = mfcdetector.detect(gg, approach_dist=0.05, collision_thresh=cfgs.collision_thresh) 104 | gg = gg[~collision_mask] 105 | return gg 106 | 107 | def vis_grasps(gg, cloud): 108 | gg.nms() 109 | gg.sort_by_score() 110 | gg = gg[:50] 111 | grippers = gg.to_open3d_geometry_list() 112 | o3d.visualization.draw_geometries([cloud, *grippers]) 113 | 114 | def demo(data_dir): 115 | net = get_net() 116 | end_points, cloud = get_and_process_data(data_dir) 117 | gg = get_grasps(net, end_points) 118 | if cfgs.collision_thresh > 0: 119 | gg = collision_detection(gg, np.array(cloud.points)) 120 | vis_grasps(gg, cloud) 121 | 122 | if __name__=='__main__': 123 | data_dir = 'doc/example_data' 124 | demo(data_dir) 125 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ Testing for GraspNet baseline model. 
""" 2 | 3 | import os 4 | import sys 5 | import numpy as np 6 | import argparse 7 | import time 8 | 9 | import torch 10 | from torch.utils.data import DataLoader 11 | from graspnetAPI import GraspGroup, GraspNetEval 12 | 13 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 14 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 15 | sys.path.append(os.path.join(ROOT_DIR, 'dataset')) 16 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 17 | 18 | from graspnet import GraspNet, pred_decode 19 | from graspnet_dataset import GraspNetDataset, collate_fn 20 | from collision_detector import ModelFreeCollisionDetector 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--dataset_root', required=True, help='Dataset root') 24 | parser.add_argument('--checkpoint_path', required=True, help='Model checkpoint path') 25 | parser.add_argument('--dump_dir', required=True, help='Dump dir to save outputs') 26 | parser.add_argument('--camera', required=True, help='Camera split [realsense/kinect]') 27 | parser.add_argument('--num_point', type=int, default=20000, help='Point Number [default: 20000]') 28 | parser.add_argument('--num_view', type=int, default=300, help='View Number [default: 300]') 29 | parser.add_argument('--batch_size', type=int, default=1, help='Batch Size during inference [default: 1]') 30 | parser.add_argument('--collision_thresh', type=float, default=0.01, help='Collision Threshold in collision detection [default: 0.01]') 31 | parser.add_argument('--voxel_size', type=float, default=0.01, help='Voxel Size to process point clouds before collision detection [default: 0.01]') 32 | parser.add_argument('--num_workers', type=int, default=30, help='Number of workers used in evaluation [default: 30]') 33 | cfgs = parser.parse_args() 34 | 35 | # ------------------------------------------------------------------------- GLOBAL CONFIG BEG 36 | if not os.path.exists(cfgs.dump_dir): os.mkdir(cfgs.dump_dir) 37 | 38 | # Init datasets and dataloaders 39 | def 
my_worker_init_fn(worker_id): 40 | np.random.seed(np.random.get_state()[1][0] + worker_id) 41 | pass 42 | 43 | # Create Dataset and Dataloader 44 | TEST_DATASET = GraspNetDataset(cfgs.dataset_root, valid_obj_idxs=None, grasp_labels=None, split='test', camera=cfgs.camera, num_points=cfgs.num_point, remove_outlier=True, augment=False, load_label=False) 45 | 46 | print(len(TEST_DATASET)) 47 | SCENE_LIST = TEST_DATASET.scene_list() 48 | TEST_DATALOADER = DataLoader(TEST_DATASET, batch_size=cfgs.batch_size, shuffle=False, 49 | num_workers=4, worker_init_fn=my_worker_init_fn, collate_fn=collate_fn) 50 | print(len(TEST_DATALOADER)) 51 | # Init the model 52 | net = GraspNet(input_feature_dim=0, num_view=cfgs.num_view, num_angle=12, num_depth=4, 53 | cylinder_radius=0.05, hmin=-0.02, hmax_list=[0.01,0.02,0.03,0.04], is_training=False) 54 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 55 | net.to(device) 56 | # Load checkpoint 57 | checkpoint = torch.load(cfgs.checkpoint_path) 58 | net.load_state_dict(checkpoint['model_state_dict']) 59 | start_epoch = checkpoint['epoch'] 60 | print("-> loaded checkpoint %s (epoch: %d)"%(cfgs.checkpoint_path, start_epoch)) 61 | 62 | 63 | # ------------------------------------------------------------------------- GLOBAL CONFIG END 64 | 65 | def inference(): 66 | batch_interval = 100 67 | stat_dict = {} # collect statistics 68 | # set model to eval mode (for bn and dp) 69 | net.eval() 70 | tic = time.time() 71 | for batch_idx, batch_data in enumerate(TEST_DATALOADER): 72 | for key in batch_data: 73 | if 'list' in key: 74 | for i in range(len(batch_data[key])): 75 | for j in range(len(batch_data[key][i])): 76 | batch_data[key][i][j] = batch_data[key][i][j].to(device) 77 | else: 78 | batch_data[key] = batch_data[key].to(device) 79 | 80 | # Forward pass 81 | with torch.no_grad(): 82 | end_points = net(batch_data) 83 | grasp_preds = pred_decode(end_points) 84 | 85 | # Dump results for evaluation 86 | for i in 
range(cfgs.batch_size): 87 | data_idx = batch_idx * cfgs.batch_size + i 88 | preds = grasp_preds[i].detach().cpu().numpy() 89 | gg = GraspGroup(preds) 90 | 91 | # collision detection 92 | if cfgs.collision_thresh > 0: 93 | cloud, _ = TEST_DATASET.get_data(data_idx, return_raw_cloud=True) 94 | mfcdetector = ModelFreeCollisionDetector(cloud, voxel_size=cfgs.voxel_size) 95 | collision_mask = mfcdetector.detect(gg, approach_dist=0.05, collision_thresh=cfgs.collision_thresh) 96 | gg = gg[~collision_mask] 97 | 98 | # save grasps 99 | save_dir = os.path.join(cfgs.dump_dir, SCENE_LIST[data_idx], cfgs.camera) 100 | save_path = os.path.join(save_dir, str(data_idx%256).zfill(4)+'.npy') 101 | if not os.path.exists(save_dir): 102 | os.makedirs(save_dir) 103 | gg.save_npy(save_path) 104 | 105 | if batch_idx % batch_interval == 0: 106 | toc = time.time() 107 | print('Eval batch: %d, time: %fs'%(batch_idx, (toc-tic)/batch_interval)) 108 | tic = time.time() 109 | 110 | def evaluate(): 111 | ge = GraspNetEval(root=cfgs.dataset_root, camera=cfgs.camera, split='test') 112 | res, ap = ge.eval_all(cfgs.dump_dir, proc=cfgs.num_workers) 113 | save_dir = os.path.join(cfgs.dump_dir, 'ap_{}.npy'.format(cfgs.camera)) 114 | np.save(save_dir, res) 115 | 116 | if __name__=='__main__': 117 | inference() 118 | evaluate() 119 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_utils.h" 11 | 12 | // input: unknown(b, n, 3) known(b, m, 3) 13 | // output: dist2(b, n, 3), idx(b, n, 3) 14 | __global__ void three_nn_kernel(int b, int n, int m, 15 | const float *__restrict__ unknown, 16 | const float *__restrict__ known, 17 | float *__restrict__ dist2, 18 | int *__restrict__ idx) { 19 | int batch_index = blockIdx.x; 20 | unknown += batch_index * n * 3; 21 | known += batch_index * m * 3; 22 | dist2 += batch_index * n * 3; 23 | idx += batch_index * n * 3; 24 | 25 | int index = threadIdx.x; 26 | int stride = blockDim.x; 27 | for (int j = index; j < n; j += stride) { 28 | float ux = unknown[j * 3 + 0]; 29 | float uy = unknown[j * 3 + 1]; 30 | float uz = unknown[j * 3 + 2]; 31 | 32 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 33 | int besti1 = 0, besti2 = 0, besti3 = 0; 34 | for (int k = 0; k < m; ++k) { 35 | float x = known[k * 3 + 0]; 36 | float y = known[k * 3 + 1]; 37 | float z = known[k * 3 + 2]; 38 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 39 | if (d < best1) { 40 | best3 = best2; 41 | besti3 = besti2; 42 | best2 = best1; 43 | besti2 = besti1; 44 | best1 = d; 45 | besti1 = k; 46 | } else if (d < best2) { 47 | best3 = best2; 48 | besti3 = besti2; 49 | best2 = d; 50 | besti2 = k; 51 | } else if (d < best3) { 52 | best3 = d; 53 | besti3 = k; 54 | } 55 | } 56 | dist2[j * 3 + 0] = best1; 57 | dist2[j * 3 + 1] = best2; 58 | dist2[j * 3 + 2] = best3; 59 | 60 | idx[j * 3 + 0] = besti1; 61 | idx[j * 3 + 1] = besti2; 62 | idx[j * 3 + 2] = besti3; 63 | } 64 | } 65 | 66 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 67 | const float *known, float *dist2, int *idx) { 68 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 69 | three_nn_kernel<<>>(b, n, m, unknown, known, 70 | dist2, idx); 71 | 72 | CUDA_CHECK_ERRORS(); 73 | } 74 | 75 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 76 | // output: out(b, c, 
n) 77 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 78 | const float *__restrict__ points, 79 | const int *__restrict__ idx, 80 | const float *__restrict__ weight, 81 | float *__restrict__ out) { 82 | int batch_index = blockIdx.x; 83 | points += batch_index * m * c; 84 | 85 | idx += batch_index * n * 3; 86 | weight += batch_index * n * 3; 87 | 88 | out += batch_index * n * c; 89 | 90 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 91 | const int stride = blockDim.y * blockDim.x; 92 | for (int i = index; i < c * n; i += stride) { 93 | const int l = i / n; 94 | const int j = i % n; 95 | float w1 = weight[j * 3 + 0]; 96 | float w2 = weight[j * 3 + 1]; 97 | float w3 = weight[j * 3 + 2]; 98 | 99 | int i1 = idx[j * 3 + 0]; 100 | int i2 = idx[j * 3 + 1]; 101 | int i3 = idx[j * 3 + 2]; 102 | 103 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 104 | points[l * m + i3] * w3; 105 | } 106 | } 107 | 108 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 109 | const float *points, const int *idx, 110 | const float *weight, float *out) { 111 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 112 | three_interpolate_kernel<<>>( 113 | b, c, m, n, points, idx, weight, out); 114 | 115 | CUDA_CHECK_ERRORS(); 116 | } 117 | 118 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 119 | // output: grad_points(b, c, m) 120 | 121 | __global__ void three_interpolate_grad_kernel( 122 | int b, int c, int n, int m, const float *__restrict__ grad_out, 123 | const int *__restrict__ idx, const float *__restrict__ weight, 124 | float *__restrict__ grad_points) { 125 | int batch_index = blockIdx.x; 126 | grad_out += batch_index * n * c; 127 | idx += batch_index * n * 3; 128 | weight += batch_index * n * 3; 129 | grad_points += batch_index * m * c; 130 | 131 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 132 | const int stride = blockDim.y * blockDim.x; 133 | for (int i = index; i < c * n; i += stride) { 
134 | const int l = i / n; 135 | const int j = i % n; 136 | float w1 = weight[j * 3 + 0]; 137 | float w2 = weight[j * 3 + 1]; 138 | float w3 = weight[j * 3 + 2]; 139 | 140 | int i1 = idx[j * 3 + 0]; 141 | int i2 = idx[j * 3 + 1]; 142 | int i3 = idx[j * 3 + 2]; 143 | 144 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 145 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 146 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 147 | } 148 | } 149 | 150 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 151 | const float *grad_out, 152 | const int *idx, const float *weight, 153 | float *grad_points) { 154 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 155 | three_interpolate_grad_kernel<<>>( 156 | b, c, n, m, grad_out, idx, weight, grad_points); 157 | 158 | CUDA_CHECK_ERRORS(); 159 | } 160 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """ Tools for data processing. 2 | Author: chenxi-wang 3 | """ 4 | 5 | import numpy as np 6 | 7 | class CameraInfo(): 8 | """ Camera intrisics for point cloud creation. """ 9 | def __init__(self, width, height, fx, fy, cx, cy, scale): 10 | self.width = width 11 | self.height = height 12 | self.fx = fx 13 | self.fy = fy 14 | self.cx = cx 15 | self.cy = cy 16 | self.scale = scale 17 | 18 | def create_point_cloud_from_depth_image(depth, camera, organized=True): 19 | """ Generate point cloud using depth image only. 
20 | 21 | Input: 22 | depth: [numpy.ndarray, (H,W), numpy.float32] 23 | depth image 24 | camera: [CameraInfo] 25 | camera intrinsics 26 | organized: bool 27 | whether to keep the cloud in image shape (H,W,3) 28 | 29 | Output: 30 | cloud: [numpy.ndarray, (H,W,3)/(H*W,3), numpy.float32] 31 | generated cloud, (H,W,3) for organized=True, (H*W,3) for organized=False 32 | """ 33 | assert(depth.shape[0] == camera.height and depth.shape[1] == camera.width) 34 | xmap = np.arange(camera.width) 35 | ymap = np.arange(camera.height) 36 | xmap, ymap = np.meshgrid(xmap, ymap) 37 | points_z = depth / camera.scale 38 | points_x = (xmap - camera.cx) * points_z / camera.fx 39 | points_y = (ymap - camera.cy) * points_z / camera.fy 40 | cloud = np.stack([points_x, points_y, points_z], axis=-1) 41 | if not organized: 42 | cloud = cloud.reshape([-1, 3]) 43 | return cloud 44 | 45 | def transform_point_cloud(cloud, transform, format='4x4'): 46 | """ Transform points to new coordinates with transformation matrix. 
47 | 48 | Input: 49 | cloud: [np.ndarray, (N,3), np.float32] 50 | points in original coordinates 51 | transform: [np.ndarray, (3,3)/(3,4)/(4,4), np.float32] 52 | transformation matrix, could be rotation only or rotation+translation 53 | format: [string, '3x3'/'3x4'/'4x4'] 54 | the shape of transformation matrix 55 | '3x3' --> rotation matrix 56 | '3x4'/'4x4' --> rotation matrix + translation matrix 57 | 58 | Output: 59 | cloud_transformed: [np.ndarray, (N,3), np.float32] 60 | points in new coordinates 61 | """ 62 | if not (format == '3x3' or format == '4x4' or format == '3x4'): 63 | raise ValueError('Unknown transformation format, only support \'3x3\' or \'4x4\' or \'3x4\'.') 64 | if format == '3x3': 65 | cloud_transformed = np.dot(transform, cloud.T).T 66 | elif format == '4x4' or format == '3x4': 67 | ones = np.ones(cloud.shape[0])[:, np.newaxis] 68 | cloud_ = np.concatenate([cloud, ones], axis=1) 69 | cloud_transformed = np.dot(transform, cloud_.T).T 70 | cloud_transformed = cloud_transformed[:, :3] 71 | return cloud_transformed 72 | 73 | def compute_point_dists(A, B): 74 | """ Compute pair-wise point distances in two matrices. 75 | 76 | Input: 77 | A: [np.ndarray, (N,3), np.float32] 78 | point cloud A 79 | B: [np.ndarray, (M,3), np.float32] 80 | point cloud B 81 | 82 | Output: 83 | dists: [np.ndarray, (N,M), np.float32] 84 | distance matrix 85 | """ 86 | A = A[:, np.newaxis, :] 87 | B = B[np.newaxis, :, :] 88 | dists = np.linalg.norm(A-B, axis=-1) 89 | return dists 90 | 91 | def remove_invisible_grasp_points(cloud, grasp_points, pose, th=0.01): 92 | """ Remove invisible part of object model according to scene point cloud. 
93 | 94 | Input: 95 | cloud: [np.ndarray, (N,3), np.float32] 96 | scene point cloud 97 | grasp_points: [np.ndarray, (M,3), np.float32] 98 | grasp point label in object coordinates 99 | pose: [np.ndarray, (4,4), np.float32] 100 | transformation matrix from object coordinates to world coordinates 101 | th: [float] 102 | if the minimum distance between a grasp point and the scene points is greater than outlier, the point will be removed 103 | 104 | Output: 105 | visible_mask: [np.ndarray, (M,), np.bool] 106 | mask to show the visible part of grasp points 107 | """ 108 | grasp_points_trans = transform_point_cloud(grasp_points, pose) 109 | dists = compute_point_dists(grasp_points_trans, cloud) 110 | min_dists = dists.min(axis=1) 111 | visible_mask = (min_dists < th) 112 | return visible_mask 113 | 114 | def get_workspace_mask(cloud, seg, trans=None, organized=True, outlier=0): 115 | """ Keep points in workspace as input. 116 | 117 | Input: 118 | cloud: [np.ndarray, (H,W,3), np.float32] 119 | scene point cloud 120 | seg: [np.ndarray, (H,W,), np.uint8] 121 | segmantation label of scene points 122 | trans: [np.ndarray, (4,4), np.float32] 123 | transformation matrix for scene points, default: None. 
124 | organized: [bool] 125 | whether to keep the cloud in image shape (H,W,3) 126 | outlier: [float] 127 | if the distance between a point and workspace is greater than outlier, the point will be removed 128 | 129 | Output: 130 | workspace_mask: [np.ndarray, (H,W)/(H*W,), np.bool] 131 | mask to indicate whether scene points are in workspace 132 | """ 133 | if organized: 134 | h, w, _ = cloud.shape 135 | cloud = cloud.reshape([h*w, 3]) 136 | seg = seg.reshape(h*w) 137 | if trans is not None: 138 | cloud = transform_point_cloud(cloud, trans) 139 | foreground = cloud[seg>0] 140 | xmin, ymin, zmin = foreground.min(axis=0) 141 | xmax, ymax, zmax = foreground.max(axis=0) 142 | mask_x = ((cloud[:,0] > xmin-outlier) & (cloud[:,0] < xmax+outlier)) 143 | mask_y = ((cloud[:,1] > ymin-outlier) & (cloud[:,1] < ymax+outlier)) 144 | mask_z = ((cloud[:,2] > zmin-outlier) & (cloud[:,2] < zmax+outlier)) 145 | workspace_mask = (mask_x & mask_y & mask_z) 146 | if organized: 147 | workspace_mask = workspace_mask.reshape([h, w]) 148 | 149 | return workspace_mask -------------------------------------------------------------------------------- /models/graspnet.py: -------------------------------------------------------------------------------- 1 | """ GraspNet baseline model definition. 
2 | Author: chenxi-wang 3 | """ 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 13 | ROOT_DIR = os.path.dirname(BASE_DIR) 14 | sys.path.append(ROOT_DIR) 15 | sys.path.append(os.path.join(ROOT_DIR, 'pointnet2')) 16 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 17 | 18 | from backbone import Pointnet2Backbone 19 | from modules import ApproachNet, CloudCrop, OperationNet, ToleranceNet 20 | from loss import get_loss 21 | from loss_utils import GRASP_MAX_WIDTH, GRASP_MAX_TOLERANCE 22 | from label_generation import process_grasp_labels, match_grasp_view_and_label, batch_viewpoint_params_to_matrix 23 | 24 | 25 | class GraspNetStage1(nn.Module): 26 | def __init__(self, input_feature_dim=0, num_view=300): 27 | super().__init__() 28 | self.backbone = Pointnet2Backbone(input_feature_dim) 29 | self.vpmodule = ApproachNet(num_view, 256) 30 | 31 | def forward(self, end_points): 32 | pointcloud = end_points['point_clouds'] 33 | seed_features, seed_xyz, end_points = self.backbone(pointcloud, end_points) 34 | end_points = self.vpmodule(seed_xyz, seed_features, end_points) 35 | return end_points 36 | 37 | 38 | class GraspNetStage2(nn.Module): 39 | def __init__(self, num_angle=12, num_depth=4, cylinder_radius=0.05, hmin=-0.02, hmax_list=[0.01,0.02,0.03,0.04], is_training=True): 40 | super().__init__() 41 | self.num_angle = num_angle 42 | self.num_depth = num_depth 43 | self.is_training = is_training 44 | self.crop = CloudCrop(64, 3, cylinder_radius, hmin, hmax_list) 45 | self.operation = OperationNet(num_angle, num_depth) 46 | self.tolerance = ToleranceNet(num_angle, num_depth) 47 | 48 | def forward(self, end_points): 49 | pointcloud = end_points['input_xyz'] 50 | if self.is_training: 51 | grasp_top_views_rot, _, _, _, end_points = match_grasp_view_and_label(end_points) 52 | seed_xyz = end_points['batch_grasp_point'] 53 | else: 
54 | grasp_top_views_rot = end_points['grasp_top_view_rot'] 55 | seed_xyz = end_points['fp2_xyz'] 56 | 57 | vp_features = self.crop(seed_xyz, pointcloud, grasp_top_views_rot) 58 | end_points = self.operation(vp_features, end_points) 59 | end_points = self.tolerance(vp_features, end_points) 60 | 61 | return end_points 62 | 63 | class GraspNet(nn.Module): 64 | def __init__(self, input_feature_dim=0, num_view=300, num_angle=12, num_depth=4, cylinder_radius=0.05, hmin=-0.02, hmax_list=[0.01,0.02,0.03,0.04], is_training=True): 65 | super().__init__() 66 | self.is_training = is_training 67 | self.view_estimator = GraspNetStage1(input_feature_dim, num_view) 68 | self.grasp_generator = GraspNetStage2(num_angle, num_depth, cylinder_radius, hmin, hmax_list, is_training) 69 | 70 | def forward(self, end_points): 71 | end_points = self.view_estimator(end_points) 72 | if self.is_training: 73 | end_points = process_grasp_labels(end_points) 74 | end_points = self.grasp_generator(end_points) 75 | return end_points 76 | 77 | def pred_decode(end_points): 78 | batch_size = len(end_points['point_clouds']) 79 | grasp_preds = [] 80 | for i in range(batch_size): 81 | ## load predictions 82 | objectness_score = end_points['objectness_score'][i].float() 83 | grasp_score = end_points['grasp_score_pred'][i].float() 84 | grasp_center = end_points['fp2_xyz'][i].float() 85 | approaching = -end_points['grasp_top_view_xyz'][i].float() 86 | grasp_angle_class_score = end_points['grasp_angle_cls_pred'][i] 87 | grasp_width = 1.2 * end_points['grasp_width_pred'][i] 88 | grasp_width = torch.clamp(grasp_width, min=0, max=GRASP_MAX_WIDTH) 89 | grasp_tolerance = end_points['grasp_tolerance_pred'][i] 90 | 91 | ## slice preds by angle 92 | # grasp angle 93 | grasp_angle_class = torch.argmax(grasp_angle_class_score, 0) 94 | grasp_angle = grasp_angle_class.float() / 12 * np.pi 95 | # grasp score & width & tolerance 96 | grasp_angle_class_ = grasp_angle_class.unsqueeze(0) 97 | grasp_score = 
torch.gather(grasp_score, 0, grasp_angle_class_).squeeze(0) 98 | grasp_width = torch.gather(grasp_width, 0, grasp_angle_class_).squeeze(0) 99 | grasp_tolerance = torch.gather(grasp_tolerance, 0, grasp_angle_class_).squeeze(0) 100 | 101 | ## slice preds by score/depth 102 | # grasp depth 103 | grasp_depth_class = torch.argmax(grasp_score, 1, keepdims=True) 104 | grasp_depth = (grasp_depth_class.float()+1) * 0.01 105 | # grasp score & angle & width & tolerance 106 | grasp_score = torch.gather(grasp_score, 1, grasp_depth_class) 107 | grasp_angle = torch.gather(grasp_angle, 1, grasp_depth_class) 108 | grasp_width = torch.gather(grasp_width, 1, grasp_depth_class) 109 | grasp_tolerance = torch.gather(grasp_tolerance, 1, grasp_depth_class) 110 | 111 | ## slice preds by objectness 112 | objectness_pred = torch.argmax(objectness_score, 0) 113 | objectness_mask = (objectness_pred==1) 114 | grasp_score = grasp_score[objectness_mask] 115 | grasp_width = grasp_width[objectness_mask] 116 | grasp_depth = grasp_depth[objectness_mask] 117 | approaching = approaching[objectness_mask] 118 | grasp_angle = grasp_angle[objectness_mask] 119 | grasp_center = grasp_center[objectness_mask] 120 | grasp_tolerance = grasp_tolerance[objectness_mask] 121 | grasp_score = grasp_score * grasp_tolerance / GRASP_MAX_TOLERANCE 122 | 123 | ## convert to rotation matrix 124 | Ns = grasp_angle.size(0) 125 | approaching_ = approaching.view(Ns, 3) 126 | grasp_angle_ = grasp_angle.view(Ns) 127 | rotation_matrix = batch_viewpoint_params_to_matrix(approaching_, grasp_angle_) 128 | rotation_matrix = rotation_matrix.view(Ns, 9) 129 | 130 | # merge preds 131 | grasp_height = 0.02 * torch.ones_like(grasp_score) 132 | obj_ids = -1 * torch.ones_like(grasp_score) 133 | grasp_preds.append(torch.cat([grasp_score, grasp_width, grasp_height, grasp_depth, rotation_matrix, grasp_center, obj_ids], axis=-1)) 134 | return grasp_preds -------------------------------------------------------------------------------- 
/utils/collision_detector.py: -------------------------------------------------------------------------------- 1 | """ Collision detection to remove collided grasp pose predictions. 2 | Author: chenxi-wang 3 | """ 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | import open3d as o3d 9 | 10 | class ModelFreeCollisionDetector(): 11 | """ Collision detection in scenes without object labels. Current finger width and length are fixed. 12 | 13 | Input: 14 | scene_points: [numpy.ndarray, (N,3), numpy.float32] 15 | the scene points to detect 16 | voxel_size: [float] 17 | used for downsample 18 | 19 | Example usage: 20 | mfcdetector = ModelFreeCollisionDetector(scene_points, voxel_size=0.005) 21 | collision_mask = mfcdetector.detect(grasp_group, approach_dist=0.03) 22 | collision_mask, iou_list = mfcdetector.detect(grasp_group, approach_dist=0.03, collision_thresh=0.05, return_ious=True) 23 | collision_mask, empty_mask = mfcdetector.detect(grasp_group, approach_dist=0.03, collision_thresh=0.05, 24 | return_empty_grasp=True, empty_thresh=0.01) 25 | collision_mask, empty_mask, iou_list = mfcdetector.detect(grasp_group, approach_dist=0.03, collision_thresh=0.05, 26 | return_empty_grasp=True, empty_thresh=0.01, return_ious=True) 27 | """ 28 | def __init__(self, scene_points, voxel_size=0.005): 29 | self.finger_width = 0.01 30 | self.finger_length = 0.06 31 | self.voxel_size = voxel_size 32 | scene_cloud = o3d.geometry.PointCloud() 33 | scene_cloud.points = o3d.utility.Vector3dVector(scene_points) 34 | scene_cloud = scene_cloud.voxel_down_sample(voxel_size) 35 | self.scene_points = np.array(scene_cloud.points) 36 | 37 | def detect(self, grasp_group, approach_dist=0.03, collision_thresh=0.05, return_empty_grasp=False, empty_thresh=0.01, return_ious=False): 38 | """ Detect collision of grasps. 
39 | 40 | Input: 41 | grasp_group: [GraspGroup, M grasps] 42 | the grasps to check 43 | approach_dist: [float] 44 | the distance for a gripper to move along approaching direction before grasping 45 | this shifting space requires no point either 46 | collision_thresh: [float] 47 | if global collision iou is greater than this threshold, 48 | a collision is detected 49 | return_empty_grasp: [bool] 50 | if True, return a mask to imply whether there are objects in a grasp 51 | empty_thresh: [float] 52 | if inner space iou is smaller than this threshold, 53 | a collision is detected 54 | only set when [return_empty_grasp] is True 55 | return_ious: [bool] 56 | if True, return global collision iou and part collision ious 57 | 58 | Output: 59 | collision_mask: [numpy.ndarray, (M,), numpy.bool] 60 | True implies collision 61 | [optional] empty_mask: [numpy.ndarray, (M,), numpy.bool] 62 | True implies empty grasp 63 | only returned when [return_empty_grasp] is True 64 | [optional] iou_list: list of [numpy.ndarray, (M,), numpy.float32] 65 | global and part collision ious, containing 66 | [global_iou, left_iou, right_iou, bottom_iou, shifting_iou] 67 | only returned when [return_ious] is True 68 | """ 69 | approach_dist = max(approach_dist, self.finger_width) 70 | T = grasp_group.translations 71 | R = grasp_group.rotation_matrices 72 | heights = grasp_group.heights[:,np.newaxis] 73 | depths = grasp_group.depths[:,np.newaxis] 74 | widths = grasp_group.widths[:,np.newaxis] 75 | targets = self.scene_points[np.newaxis,:,:] - T[:,np.newaxis,:] 76 | targets = np.matmul(targets, R) 77 | 78 | ## collision detection 79 | # height mask 80 | mask1 = ((targets[:,:,2] > -heights/2) & (targets[:,:,2] < heights/2)) 81 | # left finger mask 82 | mask2 = ((targets[:,:,0] > depths - self.finger_length) & (targets[:,:,0] < depths)) 83 | mask3 = (targets[:,:,1] > -(widths/2 + self.finger_width)) 84 | mask4 = (targets[:,:,1] < -widths/2) 85 | # right finger mask 86 | mask5 = (targets[:,:,1] < 
(widths/2 + self.finger_width)) 87 | mask6 = (targets[:,:,1] > widths/2) 88 | # bottom mask 89 | mask7 = ((targets[:,:,0] <= depths - self.finger_length)\ 90 | & (targets[:,:,0] > depths - self.finger_length - self.finger_width)) 91 | # shifting mask 92 | mask8 = ((targets[:,:,0] <= depths - self.finger_length - self.finger_width)\ 93 | & (targets[:,:,0] > depths - self.finger_length - self.finger_width - approach_dist)) 94 | 95 | # get collision mask of each point 96 | left_mask = (mask1 & mask2 & mask3 & mask4) 97 | right_mask = (mask1 & mask2 & mask5 & mask6) 98 | bottom_mask = (mask1 & mask3 & mask5 & mask7) 99 | shifting_mask = (mask1 & mask3 & mask5 & mask8) 100 | global_mask = (left_mask | right_mask | bottom_mask | shifting_mask) 101 | 102 | # calculate equivalant volume of each part 103 | left_right_volume = (heights * self.finger_length * self.finger_width / (self.voxel_size**3)).reshape(-1) 104 | bottom_volume = (heights * (widths+2*self.finger_width) * self.finger_width / (self.voxel_size**3)).reshape(-1) 105 | shifting_volume = (heights * (widths+2*self.finger_width) * approach_dist / (self.voxel_size**3)).reshape(-1) 106 | volume = left_right_volume*2 + bottom_volume + shifting_volume 107 | 108 | # get collision iou of each part 109 | global_iou = global_mask.sum(axis=1) / (volume+1e-6) 110 | 111 | # get collison mask 112 | collision_mask = (global_iou > collision_thresh) 113 | 114 | if not (return_empty_grasp or return_ious): 115 | return collision_mask 116 | 117 | ret_value = [collision_mask,] 118 | if return_empty_grasp: 119 | inner_mask = (mask1 & mask2 & (~mask4) & (~mask6)) 120 | inner_volume = (heights * self.finger_length * widths / (self.voxel_size**3)).reshape(-1) 121 | empty_mask = (inner_mask.sum(axis=-1)/inner_volume < empty_thresh) 122 | ret_value.append(empty_mask) 123 | if return_ious: 124 | left_iou = left_mask.sum(axis=1) / (left_right_volume+1e-6) 125 | right_iou = right_mask.sum(axis=1) / (left_right_volume+1e-6) 126 | bottom_iou = 
bottom_mask.sum(axis=1) / (bottom_volume+1e-6) 127 | shifting_iou = shifting_mask.sum(axis=1) / (shifting_volume+1e-6) 128 | ret_value.append([global_iou, left_iou, right_iou, bottom_iou, shifting_iou]) 129 | return ret_value 130 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | """ Loss functions for training. 2 | Author: chenxi-wang 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | import sys 10 | import os 11 | import time 12 | 13 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 14 | ROOT_DIR = os.path.dirname(BASE_DIR) 15 | sys.path.append(ROOT_DIR) 16 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 17 | 18 | from loss_utils import GRASP_MAX_WIDTH, GRASP_MAX_TOLERANCE, THRESH_GOOD, THRESH_BAD,\ 19 | transform_point_cloud, generate_grasp_views,\ 20 | batch_viewpoint_params_to_matrix, huber_loss 21 | 22 | def get_loss(end_points): 23 | objectness_loss, end_points = compute_objectness_loss(end_points) 24 | view_loss, end_points = compute_view_loss(end_points) 25 | grasp_loss, end_points = compute_grasp_loss(end_points) 26 | loss = objectness_loss + view_loss + 0.2 * grasp_loss 27 | end_points['loss/overall_loss'] = loss 28 | return loss, end_points 29 | 30 | def compute_objectness_loss(end_points): 31 | criterion = nn.CrossEntropyLoss(reduction='mean') 32 | objectness_score = end_points['objectness_score'] 33 | objectness_label = end_points['objectness_label'] 34 | fp2_inds = end_points['fp2_inds'].long() 35 | objectness_label = torch.gather(objectness_label, 1, fp2_inds) 36 | loss = criterion(objectness_score, objectness_label) 37 | 38 | end_points['loss/stage1_objectness_loss'] = loss 39 | objectness_pred = torch.argmax(objectness_score, 1) 40 | end_points['stage1_objectness_acc'] = (objectness_pred == objectness_label.long()).float().mean() 41 | 42 | 
end_points['stage1_objectness_prec'] = (objectness_pred == objectness_label.long())[objectness_pred == 1].float().mean() 43 | end_points['stage1_objectness_recall'] = (objectness_pred == objectness_label.long())[objectness_label == 1].float().mean() 44 | 45 | return loss, end_points 46 | 47 | def compute_view_loss(end_points): 48 | criterion = nn.MSELoss(reduction='none') 49 | view_score = end_points['view_score'] 50 | view_label = end_points['batch_grasp_view_label'] 51 | objectness_label = end_points['objectness_label'] 52 | fp2_inds = end_points['fp2_inds'].long() 53 | V = view_label.size(2) 54 | objectness_label = torch.gather(objectness_label, 1, fp2_inds) 55 | 56 | objectness_mask = (objectness_label > 0) 57 | objectness_mask = objectness_mask.unsqueeze(-1).repeat(1, 1, V) 58 | pos_view_pred_mask = ((view_score >= THRESH_GOOD) & objectness_mask) 59 | 60 | loss = criterion(view_score, view_label) 61 | loss = loss[objectness_mask].mean() 62 | 63 | end_points['loss/stage1_view_loss'] = loss 64 | end_points['stage1_pos_view_pred_count'] = pos_view_pred_mask.long().sum() 65 | 66 | return loss, end_points 67 | 68 | 69 | def compute_grasp_loss(end_points, use_template_in_training=True): 70 | top_view_inds = end_points['grasp_top_view_inds'] # (B, Ns) 71 | vp_rot = end_points['grasp_top_view_rot'] # (B, Ns, view_factor, 3, 3) 72 | objectness_label = end_points['objectness_label'] 73 | fp2_inds = end_points['fp2_inds'].long() 74 | objectness_mask = torch.gather(objectness_label, 1, fp2_inds).bool() # (B, Ns) 75 | 76 | # process labels 77 | batch_grasp_label = end_points['batch_grasp_label'] # (B, Ns, A, D) 78 | batch_grasp_offset = end_points['batch_grasp_offset'] # (B, Ns, A, D, 3) 79 | batch_grasp_tolerance = end_points['batch_grasp_tolerance'] # (B, Ns, A, D) 80 | B, Ns, A, D = batch_grasp_label.size() 81 | 82 | # pick the one with the highest angle score 83 | top_view_grasp_angles = batch_grasp_offset[:, :, :, :, 0] #(B, Ns, A, D) 84 | top_view_grasp_depths = 
batch_grasp_offset[:, :, :, :, 1] #(B, Ns, A, D) 85 | top_view_grasp_widths = batch_grasp_offset[:, :, :, :, 2] #(B, Ns, A, D) 86 | target_labels_inds = torch.argmax(batch_grasp_label, dim=2, keepdim=True) # (B, Ns, 1, D) 87 | target_labels = torch.gather(batch_grasp_label, 2, target_labels_inds).squeeze(2) # (B, Ns, D) 88 | target_angles = torch.gather(top_view_grasp_angles, 2, target_labels_inds).squeeze(2) # (B, Ns, D) 89 | target_depths = torch.gather(top_view_grasp_depths, 2, target_labels_inds).squeeze(2) # (B, Ns, D) 90 | target_widths = torch.gather(top_view_grasp_widths, 2, target_labels_inds).squeeze(2) # (B, Ns, D) 91 | target_tolerance = torch.gather(batch_grasp_tolerance, 2, target_labels_inds).squeeze(2) # (B, Ns, D) 92 | 93 | graspable_mask = (target_labels > THRESH_BAD) 94 | objectness_mask = objectness_mask.unsqueeze(-1).expand_as(graspable_mask) 95 | loss_mask = (objectness_mask & graspable_mask).float() 96 | 97 | # 1. grasp score loss 98 | target_labels_inds_ = target_labels_inds.transpose(1, 2) # (B, 1, Ns, D) 99 | grasp_score = torch.gather(end_points['grasp_score_pred'], 1, target_labels_inds_).squeeze(1) 100 | grasp_score_loss = huber_loss(grasp_score-target_labels, delta=1.0) 101 | grasp_score_loss = torch.sum(grasp_score_loss * loss_mask) / (loss_mask.sum() + 1e-6) 102 | end_points['loss/stage2_grasp_score_loss'] = grasp_score_loss 103 | 104 | # 2. 
inplane rotation cls loss 105 | target_angles_cls = target_labels_inds.squeeze(2) # (B, Ns, D) 106 | criterion_grasp_angle_class = nn.CrossEntropyLoss(reduction='none') 107 | grasp_angle_class_score = end_points['grasp_angle_cls_pred'] 108 | grasp_angle_class_loss = criterion_grasp_angle_class(grasp_angle_class_score, target_angles_cls) 109 | grasp_angle_class_loss = torch.sum(grasp_angle_class_loss * loss_mask) / (loss_mask.sum() + 1e-6) 110 | end_points['loss/stage2_grasp_angle_class_loss'] = grasp_angle_class_loss 111 | grasp_angle_class_pred = torch.argmax(grasp_angle_class_score, 1) 112 | end_points['stage2_grasp_angle_class_acc/0_degree'] = (grasp_angle_class_pred==target_angles_cls)[loss_mask.bool()].float().mean() 113 | acc_mask_15 = ((torch.abs(grasp_angle_class_pred-target_angles_cls)<=1) | (torch.abs(grasp_angle_class_pred-target_angles_cls)>=A-1)) 114 | end_points['stage2_grasp_angle_class_acc/15_degree'] = acc_mask_15[loss_mask.bool()].float().mean() 115 | acc_mask_30 = ((torch.abs(grasp_angle_class_pred-target_angles_cls)<=2) | (torch.abs(grasp_angle_class_pred-target_angles_cls)>=A-2)) 116 | end_points['stage2_grasp_angle_class_acc/30_degree'] = acc_mask_30[loss_mask.bool()].float().mean() 117 | 118 | # 3. width reg loss 119 | grasp_width_pred = torch.gather(end_points['grasp_width_pred'], 1, target_labels_inds_).squeeze(1) 120 | grasp_width_loss = huber_loss((grasp_width_pred-target_widths)/GRASP_MAX_WIDTH, delta=1) 121 | grasp_width_loss = torch.sum(grasp_width_loss * loss_mask) / (loss_mask.sum() + 1e-6) 122 | end_points['loss/stage2_grasp_width_loss'] = grasp_width_loss 123 | 124 | # 4. 
tolerance reg loss 125 | grasp_tolerance_pred = torch.gather(end_points['grasp_tolerance_pred'], 1, target_labels_inds_).squeeze(1) 126 | grasp_tolerance_loss = huber_loss((grasp_tolerance_pred-target_tolerance)/GRASP_MAX_TOLERANCE, delta=1) 127 | grasp_tolerance_loss = torch.sum(grasp_tolerance_loss * loss_mask) / (loss_mask.sum() + 1e-6) 128 | end_points['loss/stage2_grasp_tolerance_loss'] = grasp_tolerance_loss 129 | 130 | grasp_loss = grasp_score_loss + grasp_angle_class_loss\ 131 | + grasp_width_loss + grasp_tolerance_loss 132 | return grasp_loss, end_points -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This fork presents train-time memory improvement. Specificailly, instead of loading training labels at once as the original repo do, we load them just before they are fed into nerual networks. 2 | 3 | # GraspNet Baseline 4 | Baseline model for "GraspNet-1Billion: A Large-Scale Benchmark for General Object Grasping" (CVPR 2020). 5 | 6 | [[paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/Fang_GraspNet-1Billion_A_Large-Scale_Benchmark_for_General_Object_Grasping_CVPR_2020_paper.pdf)] 7 | [[dataset](https://graspnet.net/)] 8 | [[API](https://github.com/graspnet/graspnetAPI)] 9 | [[doc](https://graspnetapi.readthedocs.io/en/latest/index.html)] 10 | 11 |
12 | scene_0114 13 | scene_0116 14 | scene_0117 15 |
Top 50 grasps detected by our baseline model. 16 |
17 | 18 | ![teaser](doc/teaser.png) 19 | 20 | ## Requirements 21 | - Python 3 22 | - PyTorch 1.6 23 | - Open3D >= 0.8 24 | - TensorBoard 2.3 25 | - NumPy 26 | - SciPy 27 | - Pillow 28 | - tqdm 29 | 30 | ## Installation 31 | Get the code. 32 | ```bash 33 | git clone https://github.com/graspnet/graspnet-baseline.git 34 | cd graspnet-baseline 35 | ``` 36 | Install packages via Pip. 37 | ```bash 38 | pip install -r requirements.txt 39 | ``` 40 | Compile and install pointnet2 operators (code adapted from [votenet](https://github.com/facebookresearch/votenet)). 41 | ```bash 42 | cd pointnet2 43 | python setup.py install 44 | ``` 45 | Compile and install the knn operator (code adapted from [pytorch_knn_cuda](https://github.com/chrischoy/pytorch_knn_cuda)). 46 | ```bash 47 | cd ../knn 48 | python setup.py install 49 | ``` 50 | Install graspnetAPI for evaluation. 51 | ```bash 52 | git clone https://github.com/graspnet/graspnetAPI.git 53 | cd graspnetAPI 54 | pip install . 55 | ``` 56 | 57 | ## Tolerance Label Generation 58 | Tolerance labels are not included in the original dataset and must be generated separately. Make sure you have downloaded the original dataset from [GraspNet](https://graspnet.net/). The generation code is in [dataset/generate_tolerance_label.py](dataset/generate_tolerance_label.py). You can generate the tolerance labels by running the script below (`--dataset_root` and `--num_workers` should be specified according to your settings): 59 | ```bash 60 | cd dataset 61 | sh command_generate_tolerance_label.sh 62 | ``` 63 | 64 | Alternatively, you can download the tolerance labels from [Google Drive](https://drive.google.com/file/d/1DcjGGhZIJsxd61719N0iWA7L6vNEK0ci/view?usp=sharing)/[Baidu Pan](https://pan.baidu.com/s/1HN29P-csHavJF-R_wec6SQ) and run: 65 | ```bash 66 | mv tolerance.tar dataset/ 67 | cd dataset 68 | tar -xvf tolerance.tar 69 | ``` 70 | 71 | ## Training and Testing 72 | Training examples are shown in [command_train.sh](command_train.sh).
`--dataset_root`, `--camera` and `--log_dir` should be specified according to your settings. You can use TensorBoard to visualize the training process. 73 | 74 | Testing examples are shown in [command_test.sh](command_test.sh), which covers both inference and result evaluation. `--dataset_root`, `--camera`, `--checkpoint_path` and `--dump_dir` should be specified according to your settings. Set `--collision_thresh` to -1 to skip collision detection for faster inference. 75 | 76 | The pretrained weights can be downloaded from: 77 | 78 | - `checkpoint-rs.tar` 79 | [[Google Drive](https://drive.google.com/file/d/1hd0G8LN6tRpi4742XOTEisbTXNZ-1jmk/view?usp=sharing)] 80 | [[Baidu Pan](https://pan.baidu.com/s/1Eme60l39tTZrilF0I86R5A)] 81 | - `checkpoint-kn.tar` 82 | [[Google Drive](https://drive.google.com/file/d/1vK-d0yxwyJwXHYWOtH1bDMoe--uZ2oLX/view?usp=sharing)] 83 | [[Baidu Pan](https://pan.baidu.com/s/1QpYzzyID-aG5CgHjPFNB9g)] 84 | 85 | `checkpoint-rs.tar` and `checkpoint-kn.tar` were trained on RealSense data and Kinect data, respectively. 86 | 87 | ## Demo 88 | A demo program is provided for grasp detection and visualization using RGB-D images. You can refer to [command_demo.sh](command_demo.sh) to run the program. `--checkpoint_path` should be specified according to your settings (make sure you have downloaded the pretrained weights). The output should be similar to the following example: 89 | 
91 | <img src="doc/example_data/demo_result.png" alt="demo_result" /> 92 |
93 | 94 | __Try your own data__ by modifying `get_and_process_data()` in [demo.py](demo.py). Refer to [doc/example_data/](doc/example_data/) for data preparation. RGB-D images and camera intrinsics are required for inference. `factor_depth` is the scale factor used to convert raw depth values into meters. You can also add a workspace mask to get denser output. 95 | 96 | ## Results 97 | The "In repo" results report model performance with single-view collision detection applied as post-processing. In evaluation we set `--collision_thresh` to 0.01. 98 | 99 | Evaluation results on the RealSense camera: 100 | | | | Seen | | | Similar | | | Novel | | 101 | |:--------:|:------:|:----------------:|:----------------:|:------:|:----------------:|:----------------:|:------:|:----------------:|:----------------:| 102 | | | __AP__ | AP<sub>0.8</sub> | AP<sub>0.4</sub> | __AP__ | AP<sub>0.8</sub> | AP<sub>0.4</sub> | __AP__ | AP<sub>0.8</sub> | AP<sub>0.4</sub> | 103 | | In paper | 27.56 | 33.43 | 16.95 | 26.11 | 34.18 | 14.23 | 10.55 | 11.25 | 3.98 | 104 | | In repo | 47.47 | 55.90 | 41.33 | 42.27 | 51.01 | 35.40 | 16.61 | 20.84 | 8.30 | 105 | 106 | Evaluation results on the Kinect camera: 107 | | | | Seen | | | Similar | | | Novel | | 108 | |:--------:|:------:|:----------------:|:----------------:|:------:|:----------------:|:----------------:|:------:|:----------------:|:----------------:| 109 | | | __AP__ | AP<sub>0.8</sub> | AP<sub>0.4</sub> | __AP__ | AP<sub>0.8</sub> | AP<sub>0.4</sub> | __AP__ | AP<sub>0.8</sub> | AP<sub>0.4</sub> | 110 | | In paper | 29.88 | 36.19 | 19.31 | 27.84 | 33.19 | 16.62 | 11.51 | 12.92 | 3.56 | 111 | | In repo | 42.02 | 49.91 | 35.34 | 37.35 | 44.82 | 30.40 | 12.17 | 15.17 | 5.51 | 112 | 113 | ## Citation 114 | Please cite our paper in your publications if it helps your research: 115 | ``` 116 | @inproceedings{fang2020graspnet, 117 | title={GraspNet-1Billion: A Large-Scale Benchmark for General Object Grasping}, 118 | author={Fang, Hao-Shu and Wang, Chenxi and Gou, Minghao and Lu, Cewu}, 119 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
Recognition(CVPR)}, 120 | pages={11444--11453}, 121 | year={2020} 122 | } 123 | ``` 124 | 125 | ## License 126 | All data, labels, code and models belong to the graspnet team, MVIG, SJTU and are freely available for free non-commercial use, and may be redistributed under these conditions. For commercial queries, please drop an email at fhaoshu at gmail_dot_com and cc lucewu at sjtu.edu.cn . 127 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | #include 8 | 9 | #include "cuda_utils.h" 10 | 11 | // input: points(b, c, n) idx(b, m) 12 | // output: out(b, c, m) 13 | __global__ void gather_points_kernel(int b, int c, int n, int m, 14 | const float *__restrict__ points, 15 | const int *__restrict__ idx, 16 | float *__restrict__ out) { 17 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 18 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 19 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 20 | int a = idx[i * m + j]; 21 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 22 | } 23 | } 24 | } 25 | } 26 | 27 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 28 | const float *points, const int *idx, 29 | float *out) { 30 | gather_points_kernel<<>>(b, c, n, npoints, 32 | points, idx, out); 33 | 34 | CUDA_CHECK_ERRORS(); 35 | } 36 | 37 | // input: grad_out(b, c, m) idx(b, m) 38 | // output: grad_points(b, c, n) 39 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 40 | const float *__restrict__ grad_out, 41 | const int *__restrict__ idx, 42 | float *__restrict__ grad_points) { 43 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 44 | for (int l = 
blockIdx.y; l < c; l += gridDim.y) { 45 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 46 | int a = idx[i * m + j]; 47 | atomicAdd(grad_points + (i * c + l) * n + a, 48 | grad_out[(i * c + l) * m + j]); 49 | } 50 | } 51 | } 52 | } 53 | 54 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 55 | const float *grad_out, const int *idx, 56 | float *grad_points) { 57 | gather_points_grad_kernel<<>>( 59 | b, c, n, npoints, grad_out, idx, grad_points); 60 | 61 | CUDA_CHECK_ERRORS(); 62 | } 63 | 64 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 65 | int idx1, int idx2) { 66 | const float v1 = dists[idx1], v2 = dists[idx2]; 67 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 68 | dists[idx1] = max(v1, v2); 69 | dists_i[idx1] = v2 > v1 ? i2 : i1; 70 | } 71 | 72 | // Input dataset: (b, n, 3), tmp: (b, n) 73 | // Ouput idxs (b, m) 74 | template 75 | __global__ void furthest_point_sampling_kernel( 76 | int b, int n, int m, const float *__restrict__ dataset, 77 | float *__restrict__ temp, int *__restrict__ idxs) { 78 | if (m <= 0) return; 79 | __shared__ float dists[block_size]; 80 | __shared__ int dists_i[block_size]; 81 | 82 | int batch_index = blockIdx.x; 83 | dataset += batch_index * n * 3; 84 | temp += batch_index * n; 85 | idxs += batch_index * m; 86 | 87 | int tid = threadIdx.x; 88 | const int stride = block_size; 89 | 90 | int old = 0; 91 | if (threadIdx.x == 0) idxs[0] = old; 92 | 93 | __syncthreads(); 94 | for (int j = 1; j < m; j++) { 95 | int besti = 0; 96 | float best = -1; 97 | float x1 = dataset[old * 3 + 0]; 98 | float y1 = dataset[old * 3 + 1]; 99 | float z1 = dataset[old * 3 + 2]; 100 | for (int k = tid; k < n; k += stride) { 101 | float x2, y2, z2; 102 | x2 = dataset[k * 3 + 0]; 103 | y2 = dataset[k * 3 + 1]; 104 | z2 = dataset[k * 3 + 2]; 105 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 106 | if (mag <= 1e-3) continue; 107 | 108 | float d = 109 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - 
y1) + (z2 - z1) * (z2 - z1); 110 | 111 | float d2 = min(d, temp[k]); 112 | temp[k] = d2; 113 | besti = d2 > best ? k : besti; 114 | best = d2 > best ? d2 : best; 115 | } 116 | dists[tid] = best; 117 | dists_i[tid] = besti; 118 | __syncthreads(); 119 | 120 | if (block_size >= 512) { 121 | if (tid < 256) { 122 | __update(dists, dists_i, tid, tid + 256); 123 | } 124 | __syncthreads(); 125 | } 126 | if (block_size >= 256) { 127 | if (tid < 128) { 128 | __update(dists, dists_i, tid, tid + 128); 129 | } 130 | __syncthreads(); 131 | } 132 | if (block_size >= 128) { 133 | if (tid < 64) { 134 | __update(dists, dists_i, tid, tid + 64); 135 | } 136 | __syncthreads(); 137 | } 138 | if (block_size >= 64) { 139 | if (tid < 32) { 140 | __update(dists, dists_i, tid, tid + 32); 141 | } 142 | __syncthreads(); 143 | } 144 | if (block_size >= 32) { 145 | if (tid < 16) { 146 | __update(dists, dists_i, tid, tid + 16); 147 | } 148 | __syncthreads(); 149 | } 150 | if (block_size >= 16) { 151 | if (tid < 8) { 152 | __update(dists, dists_i, tid, tid + 8); 153 | } 154 | __syncthreads(); 155 | } 156 | if (block_size >= 8) { 157 | if (tid < 4) { 158 | __update(dists, dists_i, tid, tid + 4); 159 | } 160 | __syncthreads(); 161 | } 162 | if (block_size >= 4) { 163 | if (tid < 2) { 164 | __update(dists, dists_i, tid, tid + 2); 165 | } 166 | __syncthreads(); 167 | } 168 | if (block_size >= 2) { 169 | if (tid < 1) { 170 | __update(dists, dists_i, tid, tid + 1); 171 | } 172 | __syncthreads(); 173 | } 174 | 175 | old = dists_i[0]; 176 | if (tid == 0) idxs[j] = old; 177 | } 178 | } 179 | 180 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 181 | const float *dataset, float *temp, 182 | int *idxs) { 183 | unsigned int n_threads = opt_n_threads(n); 184 | 185 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 186 | 187 | switch (n_threads) { 188 | case 512: 189 | furthest_point_sampling_kernel<512> 190 | <<>>(b, n, m, dataset, temp, idxs); 191 | break; 192 | case 256: 193 | 
furthest_point_sampling_kernel<256> 194 | <<>>(b, n, m, dataset, temp, idxs); 195 | break; 196 | case 128: 197 | furthest_point_sampling_kernel<128> 198 | <<>>(b, n, m, dataset, temp, idxs); 199 | break; 200 | case 64: 201 | furthest_point_sampling_kernel<64> 202 | <<>>(b, n, m, dataset, temp, idxs); 203 | break; 204 | case 32: 205 | furthest_point_sampling_kernel<32> 206 | <<>>(b, n, m, dataset, temp, idxs); 207 | break; 208 | case 16: 209 | furthest_point_sampling_kernel<16> 210 | <<>>(b, n, m, dataset, temp, idxs); 211 | break; 212 | case 8: 213 | furthest_point_sampling_kernel<8> 214 | <<>>(b, n, m, dataset, temp, idxs); 215 | break; 216 | case 4: 217 | furthest_point_sampling_kernel<4> 218 | <<>>(b, n, m, dataset, temp, idxs); 219 | break; 220 | case 2: 221 | furthest_point_sampling_kernel<2> 222 | <<>>(b, n, m, dataset, temp, idxs); 223 | break; 224 | case 1: 225 | furthest_point_sampling_kernel<1> 226 | <<>>(b, n, m, dataset, temp, idxs); 227 | break; 228 | default: 229 | furthest_point_sampling_kernel<512> 230 | <<>>(b, n, m, dataset, temp, idxs); 231 | } 232 | 233 | CUDA_CHECK_ERRORS(); 234 | } 235 | -------------------------------------------------------------------------------- /pointnet2/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | ''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch ''' 7 | import torch 8 | import torch.nn as nn 9 | from typing import List, Tuple 10 | 11 | class SharedMLP(nn.Sequential): 12 | 13 | def __init__( 14 | self, 15 | args: List[int], 16 | *, 17 | bn: bool = False, 18 | activation=nn.ReLU(inplace=True), 19 | preact: bool = False, 20 | first: bool = False, 21 | name: str = "" 22 | ): 23 | super().__init__() 24 | 25 | for i in range(len(args) - 1): 26 | self.add_module( 27 | name + 'layer{}'.format(i), 28 | Conv2d( 29 | args[i], 30 | args[i + 1], 31 | bn=(not first or not preact or (i != 0)) and bn, 32 | activation=activation 33 | if (not first or not preact or (i != 0)) else None, 34 | preact=preact 35 | ) 36 | ) 37 | 38 | 39 | class _BNBase(nn.Sequential): 40 | 41 | def __init__(self, in_size, batch_norm=None, name=""): 42 | super().__init__() 43 | self.add_module(name + "bn", batch_norm(in_size)) 44 | 45 | nn.init.constant_(self[0].weight, 1.0) 46 | nn.init.constant_(self[0].bias, 0) 47 | 48 | 49 | class BatchNorm1d(_BNBase): 50 | 51 | def __init__(self, in_size: int, *, name: str = ""): 52 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 53 | 54 | 55 | class BatchNorm2d(_BNBase): 56 | 57 | def __init__(self, in_size: int, name: str = ""): 58 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 59 | 60 | 61 | class BatchNorm3d(_BNBase): 62 | 63 | def __init__(self, in_size: int, name: str = ""): 64 | super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name) 65 | 66 | 67 | class _ConvBase(nn.Sequential): 68 | 69 | def __init__( 70 | self, 71 | in_size, 72 | out_size, 73 | kernel_size, 74 | stride, 75 | padding, 76 | activation, 77 | bn, 78 | init, 79 | conv=None, 80 | batch_norm=None, 81 | bias=True, 82 | preact=False, 83 | name="" 84 | ): 85 | super().__init__() 86 | 87 | bias = bias and (not bn) 88 | conv_unit = conv( 89 | in_size, 90 | out_size, 91 | kernel_size=kernel_size, 92 | stride=stride, 93 
| padding=padding, 94 | bias=bias 95 | ) 96 | init(conv_unit.weight) 97 | if bias: 98 | nn.init.constant_(conv_unit.bias, 0) 99 | 100 | if bn: 101 | if not preact: 102 | bn_unit = batch_norm(out_size) 103 | else: 104 | bn_unit = batch_norm(in_size) 105 | 106 | if preact: 107 | if bn: 108 | self.add_module(name + 'bn', bn_unit) 109 | 110 | if activation is not None: 111 | self.add_module(name + 'activation', activation) 112 | 113 | self.add_module(name + 'conv', conv_unit) 114 | 115 | if not preact: 116 | if bn: 117 | self.add_module(name + 'bn', bn_unit) 118 | 119 | if activation is not None: 120 | self.add_module(name + 'activation', activation) 121 | 122 | 123 | class Conv1d(_ConvBase): 124 | 125 | def __init__( 126 | self, 127 | in_size: int, 128 | out_size: int, 129 | *, 130 | kernel_size: int = 1, 131 | stride: int = 1, 132 | padding: int = 0, 133 | activation=nn.ReLU(inplace=True), 134 | bn: bool = False, 135 | init=nn.init.kaiming_normal_, 136 | bias: bool = True, 137 | preact: bool = False, 138 | name: str = "" 139 | ): 140 | super().__init__( 141 | in_size, 142 | out_size, 143 | kernel_size, 144 | stride, 145 | padding, 146 | activation, 147 | bn, 148 | init, 149 | conv=nn.Conv1d, 150 | batch_norm=BatchNorm1d, 151 | bias=bias, 152 | preact=preact, 153 | name=name 154 | ) 155 | 156 | 157 | class Conv2d(_ConvBase): 158 | 159 | def __init__( 160 | self, 161 | in_size: int, 162 | out_size: int, 163 | *, 164 | kernel_size: Tuple[int, int] = (1, 1), 165 | stride: Tuple[int, int] = (1, 1), 166 | padding: Tuple[int, int] = (0, 0), 167 | activation=nn.ReLU(inplace=True), 168 | bn: bool = False, 169 | init=nn.init.kaiming_normal_, 170 | bias: bool = True, 171 | preact: bool = False, 172 | name: str = "" 173 | ): 174 | super().__init__( 175 | in_size, 176 | out_size, 177 | kernel_size, 178 | stride, 179 | padding, 180 | activation, 181 | bn, 182 | init, 183 | conv=nn.Conv2d, 184 | batch_norm=BatchNorm2d, 185 | bias=bias, 186 | preact=preact, 187 | name=name 188 | ) 
189 | 190 | 191 | class Conv3d(_ConvBase): 192 | 193 | def __init__( 194 | self, 195 | in_size: int, 196 | out_size: int, 197 | *, 198 | kernel_size: Tuple[int, int, int] = (1, 1, 1), 199 | stride: Tuple[int, int, int] = (1, 1, 1), 200 | padding: Tuple[int, int, int] = (0, 0, 0), 201 | activation=nn.ReLU(inplace=True), 202 | bn: bool = False, 203 | init=nn.init.kaiming_normal_, 204 | bias: bool = True, 205 | preact: bool = False, 206 | name: str = "" 207 | ): 208 | super().__init__( 209 | in_size, 210 | out_size, 211 | kernel_size, 212 | stride, 213 | padding, 214 | activation, 215 | bn, 216 | init, 217 | conv=nn.Conv3d, 218 | batch_norm=BatchNorm3d, 219 | bias=bias, 220 | preact=preact, 221 | name=name 222 | ) 223 | 224 | 225 | class FC(nn.Sequential): 226 | 227 | def __init__( 228 | self, 229 | in_size: int, 230 | out_size: int, 231 | *, 232 | activation=nn.ReLU(inplace=True), 233 | bn: bool = False, 234 | init=None, 235 | preact: bool = False, 236 | name: str = "" 237 | ): 238 | super().__init__() 239 | 240 | fc = nn.Linear(in_size, out_size, bias=not bn) 241 | if init is not None: 242 | init(fc.weight) 243 | if not bn: 244 | nn.init.constant_(fc.bias, 0) 245 | 246 | if preact: 247 | if bn: 248 | self.add_module(name + 'bn', BatchNorm1d(in_size)) 249 | 250 | if activation is not None: 251 | self.add_module(name + 'activation', activation) 252 | 253 | self.add_module(name + 'fc', fc) 254 | 255 | if not preact: 256 | if bn: 257 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 258 | 259 | if activation is not None: 260 | self.add_module(name + 'activation', activation) 261 | 262 | def set_bn_momentum_default(bn_momentum): 263 | 264 | def fn(m): 265 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 266 | m.momentum = bn_momentum 267 | 268 | return fn 269 | 270 | 271 | class BNMomentumScheduler(object): 272 | 273 | def __init__( 274 | self, model, bn_lambda, last_epoch=-1, 275 | setter=set_bn_momentum_default 276 | ): 277 | if not 
isinstance(model, nn.Module): 278 | raise RuntimeError( 279 | "Class '{}' is not a PyTorch nn Module".format( 280 | type(model).__name__ 281 | ) 282 | ) 283 | 284 | self.model = model 285 | self.setter = setter 286 | self.lmbd = bn_lambda 287 | 288 | self.step(last_epoch + 1) 289 | self.last_epoch = last_epoch 290 | 291 | def step(self, epoch=None): 292 | if epoch is None: 293 | epoch = self.last_epoch + 1 294 | 295 | self.last_epoch = epoch 296 | self.model.apply(self.setter(self.lmbd(epoch))) 297 | 298 | 299 | -------------------------------------------------------------------------------- /utils/label_generation.py: -------------------------------------------------------------------------------- 1 | """ Dynamically generate grasp labels during training. 2 | Author: chenxi-wang 3 | """ 4 | 5 | import os 6 | import sys 7 | import torch 8 | 9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | ROOT_DIR = os.path.dirname(BASE_DIR) 11 | sys.path.append(ROOT_DIR) 12 | sys.path.append(os.path.join(ROOT_DIR, 'knn')) 13 | 14 | from knn_modules import knn 15 | from loss_utils import GRASP_MAX_WIDTH, batch_viewpoint_params_to_matrix,\ 16 | transform_point_cloud, generate_grasp_views 17 | 18 | def process_grasp_labels(end_points): 19 | """ Process labels according to scene points and object poses. 
""" 20 | clouds = end_points['input_xyz'] #(B, N, 3) 21 | seed_xyzs = end_points['fp2_xyz'] #(B, Ns, 3) 22 | batch_size, num_samples, _ = seed_xyzs.size() 23 | 24 | batch_grasp_points = [] 25 | batch_grasp_views = [] 26 | batch_grasp_views_rot = [] 27 | batch_grasp_labels = [] 28 | batch_grasp_offsets = [] 29 | batch_grasp_tolerance = [] 30 | for i in range(len(clouds)): 31 | seed_xyz = seed_xyzs[i] #(Ns, 3) 32 | poses = end_points['object_poses_list'][i] #[(3, 4),] 33 | 34 | # get merged grasp points for label computation 35 | grasp_points_merged = [] 36 | grasp_views_merged = [] 37 | grasp_views_rot_merged = [] 38 | grasp_labels_merged = [] 39 | grasp_offsets_merged = [] 40 | grasp_tolerance_merged = [] 41 | for obj_idx, pose in enumerate(poses): 42 | grasp_points = end_points['grasp_points_list'][i][obj_idx] #(Np, 3) 43 | grasp_labels = end_points['grasp_labels_list'][i][obj_idx] #(Np, V, A, D) 44 | grasp_offsets = end_points['grasp_offsets_list'][i][obj_idx] #(Np, V, A, D, 3) 45 | grasp_tolerance = end_points['grasp_tolerance_list'][i][obj_idx] #(Np, V, A, D) 46 | _, V, A, D = grasp_labels.size() 47 | num_grasp_points = grasp_points.size(0) 48 | # generate and transform template grasp views 49 | grasp_views = generate_grasp_views(V).to(pose.device) #(V, 3) 50 | grasp_points_trans = transform_point_cloud(grasp_points, pose, '3x4') 51 | grasp_views_trans = transform_point_cloud(grasp_views, pose[:3,:3], '3x3') 52 | # generate and transform template grasp view rotation 53 | angles = torch.zeros(grasp_views.size(0), dtype=grasp_views.dtype, device=grasp_views.device) 54 | grasp_views_rot = batch_viewpoint_params_to_matrix(-grasp_views, angles) #(V, 3, 3) 55 | grasp_views_rot_trans = torch.matmul(pose[:3,:3], grasp_views_rot) #(V, 3, 3) 56 | 57 | # assign views 58 | grasp_views_ = grasp_views.transpose(0, 1).contiguous().unsqueeze(0) 59 | grasp_views_trans_ = grasp_views_trans.transpose(0, 1).contiguous().unsqueeze(0) 60 | view_inds = knn(grasp_views_trans_, 
grasp_views_, k=1).squeeze() - 1 61 | grasp_views_trans = torch.index_select(grasp_views_trans, 0, view_inds) #(V, 3) 62 | grasp_views_trans = grasp_views_trans.unsqueeze(0).expand(num_grasp_points, -1, -1) #(Np, V, 3) 63 | grasp_views_rot_trans = torch.index_select(grasp_views_rot_trans, 0, view_inds) #(V, 3, 3) 64 | grasp_views_rot_trans = grasp_views_rot_trans.unsqueeze(0).expand(num_grasp_points, -1, -1, -1) #(Np, V, 3, 3) 65 | grasp_labels = torch.index_select(grasp_labels, 1, view_inds) #(Np, V, A, D) 66 | grasp_offsets = torch.index_select(grasp_offsets, 1, view_inds) #(Np, V, A, D, 3) 67 | grasp_tolerance = torch.index_select(grasp_tolerance, 1, view_inds) #(Np, V, A, D) 68 | # add to list 69 | grasp_points_merged.append(grasp_points_trans) 70 | grasp_views_merged.append(grasp_views_trans) 71 | grasp_views_rot_merged.append(grasp_views_rot_trans) 72 | grasp_labels_merged.append(grasp_labels) 73 | grasp_offsets_merged.append(grasp_offsets) 74 | grasp_tolerance_merged.append(grasp_tolerance) 75 | 76 | grasp_points_merged = torch.cat(grasp_points_merged, dim=0) #(Np', 3) 77 | grasp_views_merged = torch.cat(grasp_views_merged, dim=0) #(Np', V, 3) 78 | grasp_views_rot_merged = torch.cat(grasp_views_rot_merged, dim=0) #(Np', V, 3, 3) 79 | grasp_labels_merged = torch.cat(grasp_labels_merged, dim=0) #(Np', V, A, D) 80 | grasp_offsets_merged = torch.cat(grasp_offsets_merged, dim=0) #(Np', V, A, D, 3) 81 | grasp_tolerance_merged = torch.cat(grasp_tolerance_merged, dim=0) #(Np', V, A, D) 82 | 83 | # compute nearest neighbors 84 | seed_xyz_ = seed_xyz.transpose(0, 1).contiguous().unsqueeze(0) #(1, 3, Ns) 85 | grasp_points_merged_ = grasp_points_merged.transpose(0, 1).contiguous().unsqueeze(0) #(1, 3, Np') 86 | nn_inds = knn(grasp_points_merged_, seed_xyz_, k=1).squeeze() - 1 #(Ns) 87 | 88 | # assign anchor points to real points 89 | grasp_points_merged = torch.index_select(grasp_points_merged, 0, nn_inds) # (Ns, 3) 90 | grasp_views_merged = 
torch.index_select(grasp_views_merged, 0, nn_inds) # (Ns, V, 3) 91 | grasp_views_rot_merged = torch.index_select(grasp_views_rot_merged, 0, nn_inds) #(Ns, V, 3, 3) 92 | grasp_labels_merged = torch.index_select(grasp_labels_merged, 0, nn_inds) # (Ns, V, A, D) 93 | grasp_offsets_merged = torch.index_select(grasp_offsets_merged, 0, nn_inds) # (Ns, V, A, D, 3) 94 | grasp_tolerance_merged = torch.index_select(grasp_tolerance_merged, 0, nn_inds) # (Ns, V, A, D) 95 | 96 | # add to batch 97 | batch_grasp_points.append(grasp_points_merged) 98 | batch_grasp_views.append(grasp_views_merged) 99 | batch_grasp_views_rot.append(grasp_views_rot_merged) 100 | batch_grasp_labels.append(grasp_labels_merged) 101 | batch_grasp_offsets.append(grasp_offsets_merged) 102 | batch_grasp_tolerance.append(grasp_tolerance_merged) 103 | 104 | batch_grasp_points = torch.stack(batch_grasp_points, 0) #(B, Ns, 3) 105 | batch_grasp_views = torch.stack(batch_grasp_views, 0) #(B, Ns, V, 3) 106 | batch_grasp_views_rot = torch.stack(batch_grasp_views_rot, 0) #(B, Ns, V, 3, 3) 107 | batch_grasp_labels = torch.stack(batch_grasp_labels, 0) #(B, Ns, V, A, D) 108 | batch_grasp_offsets = torch.stack(batch_grasp_offsets, 0) #(B, Ns, V, A, D, 3) 109 | batch_grasp_tolerance = torch.stack(batch_grasp_tolerance, 0) #(B, Ns, V, A, D) 110 | 111 | # process labels 112 | batch_grasp_widths = batch_grasp_offsets[:,:,:,:,:,2] 113 | label_mask = (batch_grasp_labels > 0) & (batch_grasp_widths <= GRASP_MAX_WIDTH) 114 | u_max = batch_grasp_labels.max() 115 | batch_grasp_labels[label_mask] = torch.log(u_max / batch_grasp_labels[label_mask]) 116 | batch_grasp_labels[~label_mask] = 0 117 | batch_grasp_view_scores, _ = batch_grasp_labels.view(batch_size, num_samples, V, A*D).max(dim=-1) 118 | 119 | end_points['batch_grasp_point'] = batch_grasp_points 120 | end_points['batch_grasp_view'] = batch_grasp_views 121 | end_points['batch_grasp_view_rot'] = batch_grasp_views_rot 122 | end_points['batch_grasp_label'] = batch_grasp_labels 
123 | end_points['batch_grasp_offset'] = batch_grasp_offsets 124 | end_points['batch_grasp_tolerance'] = batch_grasp_tolerance 125 | end_points['batch_grasp_view_label'] = batch_grasp_view_scores.float() 126 | 127 | return end_points 128 | 129 | def match_grasp_view_and_label(end_points): 130 | """ Slice grasp labels according to predicted views. """ 131 | top_view_inds = end_points['grasp_top_view_inds'] # (B, Ns) 132 | template_views_rot = end_points['batch_grasp_view_rot'] # (B, Ns, V, 3, 3) 133 | grasp_labels = end_points['batch_grasp_label'] # (B, Ns, V, A, D) 134 | grasp_offsets = end_points['batch_grasp_offset'] # (B, Ns, V, A, D, 3) 135 | grasp_tolerance = end_points['batch_grasp_tolerance'] # (B, Ns, V, A, D) 136 | 137 | B, Ns, V, A, D = grasp_labels.size() 138 | top_view_inds_ = top_view_inds.view(B, Ns, 1, 1, 1).expand(-1, -1, -1, 3, 3) 139 | top_template_views_rot = torch.gather(template_views_rot, 2, top_view_inds_).squeeze(2) 140 | top_view_inds_ = top_view_inds.view(B, Ns, 1, 1, 1).expand(-1, -1, -1, A, D) 141 | top_view_grasp_labels = torch.gather(grasp_labels, 2, top_view_inds_).squeeze(2) 142 | top_view_grasp_tolerance = torch.gather(grasp_tolerance, 2, top_view_inds_).squeeze(2) 143 | top_view_inds_ = top_view_inds.view(B, Ns, 1, 1, 1, 1).expand(-1, -1, -1, A, D, 3) 144 | top_view_grasp_offsets = torch.gather(grasp_offsets, 2, top_view_inds_).squeeze(2) 145 | 146 | end_points['batch_grasp_view_rot'] = top_template_views_rot 147 | end_points['batch_grasp_label'] = top_view_grasp_labels 148 | end_points['batch_grasp_offset'] = top_view_grasp_offsets 149 | end_points['batch_grasp_tolerance'] = top_view_grasp_tolerance 150 | 151 | return top_template_views_rot, top_view_grasp_labels, top_view_grasp_offsets, top_view_grasp_tolerance, end_points -------------------------------------------------------------------------------- /knn/src/cuda/knn.cu: -------------------------------------------------------------------------------- 1 | /** Modifed version of 
knn-CUDA from https://github.com/vincentfpgarcia/kNN-CUDA 2 | * The modifications are 3 | * removed texture memory usage 4 | * removed split query KNN computation 5 | * added feature extraction with bilinear interpolation 6 | * 7 | * Last modified by Christopher B. Choy 12/23/2016 8 | */ 9 | 10 | // Includes 11 | #include 12 | #include "cuda.h" 13 | 14 | #define IDX2D(i, j, dj) (dj * i + j) 15 | #define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk)) 16 | 17 | #define BLOCK 512 18 | #define MAX_STREAMS 512 19 | 20 | // Constants used by the program 21 | #define BLOCK_DIM 16 22 | #define DEBUG 0 23 | 24 | 25 | /** 26 | * Computes the distance between two matrix A (reference points) and 27 | * B (query points) containing respectively wA and wB points. 28 | * 29 | * @param A pointer on the matrix A 30 | * @param wA width of the matrix A = number of points in A 31 | * @param B pointer on the matrix B 32 | * @param wB width of the matrix B = number of points in B 33 | * @param dim dimension of points = height of matrices A and B 34 | * @param AB pointer on the matrix containing the wA*wB distances computed 35 | */ 36 | __global__ void cuComputeDistanceGlobal( float* A, int wA, 37 | float* B, int wB, int dim, float* AB){ 38 | 39 | // Declaration of the shared memory arrays As and Bs used to store the sub-matrix of A and B 40 | __shared__ float shared_A[BLOCK_DIM][BLOCK_DIM]; 41 | __shared__ float shared_B[BLOCK_DIM][BLOCK_DIM]; 42 | 43 | 44 | // Sub-matrix of A (begin, step, end) and Sub-matrix of B (begin, step) 45 | __shared__ int begin_A; 46 | __shared__ int begin_B; 47 | __shared__ int step_A; 48 | __shared__ int step_B; 49 | __shared__ int end_A; 50 | 51 | // Thread index 52 | int tx = threadIdx.x; 53 | int ty = threadIdx.y; 54 | 55 | // Other variables 56 | float tmp; 57 | float ssd = 0; 58 | 59 | // Loop parameters 60 | begin_A = BLOCK_DIM * blockIdx.y; 61 | begin_B = BLOCK_DIM * blockIdx.x; 62 | step_A = BLOCK_DIM * wA; 63 | step_B = BLOCK_DIM * wB; 64 | 
end_A = begin_A + (dim-1) * wA; 65 | 66 | // Conditions 67 | int cond0 = (begin_A + tx < wA); // used to write in shared memory 68 | int cond1 = (begin_B + tx < wB); // used to write in shared memory & to computations and to write in output matrix 69 | int cond2 = (begin_A + ty < wA); // used to computations and to write in output matrix 70 | 71 | // Loop over all the sub-matrices of A and B required to compute the block sub-matrix 72 | for (int a = begin_A, b = begin_B; a <= end_A; a += step_A, b += step_B) { 73 | // Load the matrices from device memory to shared memory; each thread loads one element of each matrix 74 | if (a/wA + ty < dim){ 75 | shared_A[ty][tx] = (cond0)? A[a + wA * ty + tx] : 0; 76 | shared_B[ty][tx] = (cond1)? B[b + wB * ty + tx] : 0; 77 | } 78 | else{ 79 | shared_A[ty][tx] = 0; 80 | shared_B[ty][tx] = 0; 81 | } 82 | 83 | // Synchronize to make sure the matrices are loaded 84 | __syncthreads(); 85 | 86 | // Compute the difference between the two matrixes; each thread computes one element of the block sub-matrix 87 | if (cond2 && cond1){ 88 | for (int k = 0; k < BLOCK_DIM; ++k){ 89 | tmp = shared_A[k][ty] - shared_B[k][tx]; 90 | ssd += tmp*tmp; 91 | } 92 | } 93 | 94 | // Synchronize to make sure that the preceding computation is done before loading two new sub-matrices of A and B in the next iteration 95 | __syncthreads(); 96 | } 97 | 98 | // Write the block sub-matrix to device memory; each thread writes one element 99 | if (cond2 && cond1) 100 | AB[(begin_A + ty) * wB + begin_B + tx] = ssd; 101 | } 102 | 103 | 104 | /** 105 | * Gathers k-th smallest distances for each column of the distance matrix in the top. 
106 | * 107 | * @param dist distance matrix 108 | * @param ind index matrix 109 | * @param width width of the distance matrix and of the index matrix 110 | * @param height height of the distance matrix and of the index matrix 111 | * @param k number of neighbors to consider 112 | */ 113 | __global__ void cuInsertionSort(float *dist, long *ind, int width, int height, int k){ 114 | 115 | // Variables 116 | int l, i, j; 117 | float *p_dist; 118 | long *p_ind; 119 | float curr_dist, max_dist; 120 | long curr_row, max_row; 121 | unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x; 122 | 123 | if (xIndexcurr_dist){ 138 | i=a; 139 | break; 140 | } 141 | } 142 | for (j=l; j>i; j--){ 143 | p_dist[j*width] = p_dist[(j-1)*width]; 144 | p_ind[j*width] = p_ind[(j-1)*width]; 145 | } 146 | p_dist[i*width] = curr_dist; 147 | p_ind[i*width] = l+1; 148 | } else { 149 | p_ind[l*width] = l+1; 150 | } 151 | max_dist = p_dist[curr_row]; 152 | } 153 | 154 | // Part 2 : insert element in the k-th first lines 155 | max_row = (k-1)*width; 156 | for (l=k; lcurr_dist){ 162 | i=a; 163 | break; 164 | } 165 | } 166 | for (j=k-1; j>i; j--){ 167 | p_dist[j*width] = p_dist[(j-1)*width]; 168 | p_ind[j*width] = p_ind[(j-1)*width]; 169 | } 170 | p_dist[i*width] = curr_dist; 171 | p_ind[i*width] = l+1; 172 | max_dist = p_dist[max_row]; 173 | } 174 | } 175 | } 176 | } 177 | 178 | 179 | /** 180 | * Computes the square root of the first line (width-th first element) 181 | * of the distance matrix. 
182 | * 183 | * @param dist distance matrix 184 | * @param width width of the distance matrix 185 | * @param k number of neighbors to consider 186 | */ 187 | __global__ void cuParallelSqrt(float *dist, int width, int k){ 188 | unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x; 189 | unsigned int yIndex = blockIdx.y * blockDim.y + threadIdx.y; 190 | if (xIndex>>(ref_dev, ref_nb, query_dev, query_nb, dim, dist_dev); 237 | 238 | // Kernel 2: Sort each column 239 | cuInsertionSort<<>>(dist_dev, ind_dev, query_nb, ref_nb, k); 240 | 241 | // Kernel 3: Compute square root of k first elements 242 | // cuParallelSqrt<<>>(dist_dev, query_nb, k); 243 | 244 | #if DEBUG 245 | unsigned int size_of_float = sizeof(float); 246 | unsigned long size_of_long = sizeof(long); 247 | 248 | float* dist_host = new float[query_nb * k]; 249 | long* idx_host = new long[query_nb * k]; 250 | 251 | // Memory copy of output from device to host 252 | cudaMemcpy(&dist_host[0], dist_dev, 253 | query_nb * k *size_of_float, cudaMemcpyDeviceToHost); 254 | 255 | cudaMemcpy(&idx_host[0], ind_dev, 256 | query_nb * k * size_of_long, cudaMemcpyDeviceToHost); 257 | 258 | int i = 0; 259 | for(i = 0; i < 100; i++){ 260 | printf("IDX[%d]: %d\n", i, (int)idx_host[i]); 261 | } 262 | #endif 263 | } 264 | 265 | 266 | 267 | 268 | 269 | 270 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | """ Modules for GraspNet baseline model. 
2 | Author: chenxi-wang 3 | """ 4 | 5 | import os 6 | import sys 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 12 | ROOT_DIR = os.path.dirname(BASE_DIR) 13 | sys.path.append(ROOT_DIR) 14 | sys.path.append(os.path.join(ROOT_DIR, 'pointnet2')) 15 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 16 | 17 | import pytorch_utils as pt_utils 18 | from pointnet2_utils import CylinderQueryAndGroup 19 | from loss_utils import generate_grasp_views, batch_viewpoint_params_to_matrix 20 | 21 | 22 | class ApproachNet(nn.Module): 23 | def __init__(self, num_view, seed_feature_dim): 24 | """ Approach vector estimation from seed point features. 25 | 26 | Input: 27 | num_view: [int] 28 | number of views generated from each each seed point 29 | seed_feature_dim: [int] 30 | number of channels of seed point features 31 | """ 32 | super().__init__() 33 | self.num_view = num_view 34 | self.in_dim = seed_feature_dim 35 | self.conv1 = nn.Conv1d(self.in_dim, self.in_dim, 1) 36 | self.conv2 = nn.Conv1d(self.in_dim, 2+self.num_view, 1) 37 | self.conv3 = nn.Conv1d(2+self.num_view, 2+self.num_view, 1) 38 | self.bn1 = nn.BatchNorm1d(self.in_dim) 39 | self.bn2 = nn.BatchNorm1d(2+self.num_view) 40 | 41 | def forward(self, seed_xyz, seed_features, end_points): 42 | """ Forward pass. 
43 | 44 | Input: 45 | seed_xyz: [torch.FloatTensor, (batch_size,num_seed,3)] 46 | coordinates of seed points 47 | seed_features: [torch.FloatTensor, (batch_size,feature_dim,num_seed) 48 | features of seed points 49 | end_points: [dict] 50 | 51 | Output: 52 | end_points: [dict] 53 | """ 54 | B, num_seed, _ = seed_xyz.size() 55 | features = F.relu(self.bn1(self.conv1(seed_features)), inplace=True) 56 | features = F.relu(self.bn2(self.conv2(features)), inplace=True) 57 | features = self.conv3(features) 58 | objectness_score = features[:, :2, :] # (B, 2, num_seed) 59 | view_score = features[:, 2:2+self.num_view, :].transpose(1,2).contiguous() # (B, num_seed, num_view) 60 | end_points['objectness_score'] = objectness_score 61 | end_points['view_score'] = view_score 62 | 63 | # print(view_score.min(), view_score.max(), view_score.mean()) 64 | top_view_scores, top_view_inds = torch.max(view_score, dim=2) # (B, num_seed) 65 | top_view_inds_ = top_view_inds.view(B, num_seed, 1, 1).expand(-1, -1, -1, 3).contiguous() 66 | template_views = generate_grasp_views(self.num_view).to(features.device) # (num_view, 3) 67 | template_views = template_views.view(1, 1, self.num_view, 3).expand(B, num_seed, -1, -1).contiguous() #(B, num_seed, num_view, 3) 68 | vp_xyz = torch.gather(template_views, 2, top_view_inds_).squeeze(2) #(B, num_seed, 3) 69 | vp_xyz_ = vp_xyz.view(-1, 3) 70 | batch_angle = torch.zeros(vp_xyz_.size(0), dtype=vp_xyz.dtype, device=vp_xyz.device) 71 | vp_rot = batch_viewpoint_params_to_matrix(-vp_xyz_, batch_angle).view(B, num_seed, 3, 3) 72 | end_points['grasp_top_view_inds'] = top_view_inds 73 | end_points['grasp_top_view_score'] = top_view_scores 74 | end_points['grasp_top_view_xyz'] = vp_xyz 75 | end_points['grasp_top_view_rot'] = vp_rot 76 | 77 | return end_points 78 | 79 | 80 | class CloudCrop(nn.Module): 81 | """ Cylinder group and align for grasp configure estimation. Return a list of grouped points with different cropping depths. 
82 | 83 | Input: 84 | nsample: [int] 85 | sample number in a group 86 | seed_feature_dim: [int] 87 | number of channels of grouped points 88 | cylinder_radius: [float] 89 | radius of the cylinder space 90 | hmin: [float] 91 | height of the bottom surface 92 | hmax_list: [list of float] 93 | list of heights of the upper surface 94 | """ 95 | def __init__(self, nsample, seed_feature_dim, cylinder_radius=0.05, hmin=-0.02, hmax_list=[0.01,0.02,0.03,0.04]): 96 | super().__init__() 97 | self.nsample = nsample 98 | self.in_dim = seed_feature_dim 99 | self.cylinder_radius = cylinder_radius 100 | mlps = [self.in_dim, 64, 128, 256] 101 | 102 | self.groupers = [] 103 | for hmax in hmax_list: 104 | self.groupers.append(CylinderQueryAndGroup( 105 | cylinder_radius, hmin, hmax, nsample, use_xyz=True 106 | )) 107 | self.mlps = pt_utils.SharedMLP(mlps, bn=True) 108 | 109 | def forward(self, seed_xyz, pointcloud, vp_rot): 110 | """ Forward pass. 111 | 112 | Input: 113 | seed_xyz: [torch.FloatTensor, (batch_size,num_seed,3)] 114 | coordinates of seed points 115 | pointcloud: [torch.FloatTensor, (batch_size,num_seed,3)] 116 | the points to be cropped 117 | vp_rot: [torch.FloatTensor, (batch_size,num_seed,3,3)] 118 | rotation matrices generated from approach vectors 119 | 120 | Output: 121 | vp_features: [torch.FloatTensor, (batch_size,num_features,num_seed,num_depth)] 122 | features of grouped points in different depths 123 | """ 124 | B, num_seed, _, _ = vp_rot.size() 125 | num_depth = len(self.groupers) 126 | grouped_features = [] 127 | for grouper in self.groupers: 128 | grouped_features.append(grouper( 129 | pointcloud, seed_xyz, vp_rot 130 | )) # (batch_size, feature_dim, num_seed, nsample) 131 | grouped_features = torch.stack(grouped_features, dim=3) # (batch_size, feature_dim, num_seed, num_depth, nsample) 132 | grouped_features = grouped_features.view(B, -1, num_seed*num_depth, self.nsample) # (batch_size, feature_dim, num_seed*num_depth, nsample) 133 | 134 | vp_features = 
self.mlps( 135 | grouped_features 136 | ) # (batch_size, mlps[-1], num_seed*num_depth, nsample) 137 | vp_features = F.max_pool2d( 138 | vp_features, kernel_size=[1, vp_features.size(3)] 139 | ) # (batch_size, mlps[-1], num_seed*num_depth, 1) 140 | vp_features = vp_features.view(B, -1, num_seed, num_depth) 141 | return vp_features 142 | 143 | 144 | class OperationNet(nn.Module): 145 | """ Grasp configure estimation. 146 | 147 | Input: 148 | num_angle: [int] 149 | number of in-plane rotation angle classes 150 | the value of the i-th class --> i*PI/num_angle (i=0,...,num_angle-1) 151 | num_depth: [int] 152 | number of gripper depth classes 153 | """ 154 | def __init__(self, num_angle, num_depth): 155 | # Output: 156 | # scores(num_angle) 157 | # angle class (num_angle) 158 | # width (num_angle) 159 | super().__init__() 160 | self.num_angle = num_angle 161 | self.num_depth = num_depth 162 | 163 | self.conv1 = nn.Conv1d(256, 128, 1) 164 | self.conv2 = nn.Conv1d(128, 128, 1) 165 | self.conv3 = nn.Conv1d(128, 3*num_angle, 1) 166 | self.bn1 = nn.BatchNorm1d(128) 167 | self.bn2 = nn.BatchNorm1d(128) 168 | 169 | def forward(self, vp_features, end_points): 170 | """ Forward pass. 
171 | 172 | Input: 173 | vp_features: [torch.FloatTensor, (batch_size,num_seed,3)] 174 | features of grouped points in different depths 175 | end_points: [dict] 176 | 177 | Output: 178 | end_points: [dict] 179 | """ 180 | B, _, num_seed, num_depth = vp_features.size() 181 | vp_features = vp_features.view(B, -1, num_seed*num_depth) 182 | vp_features = F.relu(self.bn1(self.conv1(vp_features)), inplace=True) 183 | vp_features = F.relu(self.bn2(self.conv2(vp_features)), inplace=True) 184 | vp_features = self.conv3(vp_features) 185 | vp_features = vp_features.view(B, -1, num_seed, num_depth) 186 | 187 | # split prediction 188 | end_points['grasp_score_pred'] = vp_features[:, 0:self.num_angle] 189 | end_points['grasp_angle_cls_pred'] = vp_features[:, self.num_angle:2*self.num_angle] 190 | end_points['grasp_width_pred'] = vp_features[:, 2*self.num_angle:3*self.num_angle] 191 | return end_points 192 | 193 | 194 | class ToleranceNet(nn.Module): 195 | """ Grasp tolerance prediction. 196 | 197 | Input: 198 | num_angle: [int] 199 | number of in-plane rotation angle classes 200 | the value of the i-th class --> i*PI/num_angle (i=0,...,num_angle-1) 201 | num_depth: [int] 202 | number of gripper depth classes 203 | """ 204 | def __init__(self, num_angle, num_depth): 205 | # Output: 206 | # tolerance (num_angle) 207 | super().__init__() 208 | self.conv1 = nn.Conv1d(256, 128, 1) 209 | self.conv2 = nn.Conv1d(128, 128, 1) 210 | self.conv3 = nn.Conv1d(128, num_angle, 1) 211 | self.bn1 = nn.BatchNorm1d(128) 212 | self.bn2 = nn.BatchNorm1d(128) 213 | 214 | def forward(self, vp_features, end_points): 215 | """ Forward pass. 
216 | 217 | Input: 218 | vp_features: [torch.FloatTensor, (batch_size,num_seed,3)] 219 | features of grouped points in different depths 220 | end_points: [dict] 221 | 222 | Output: 223 | end_points: [dict] 224 | """ 225 | B, _, num_seed, num_depth = vp_features.size() 226 | vp_features = vp_features.view(B, -1, num_seed*num_depth) 227 | vp_features = F.relu(self.bn1(self.conv1(vp_features)), inplace=True) 228 | vp_features = F.relu(self.bn2(self.conv2(vp_features)), inplace=True) 229 | vp_features = self.conv3(vp_features) 230 | vp_features = vp_features.view(B, -1, num_seed, num_depth) 231 | end_points['grasp_tolerance_pred'] = vp_features 232 | return end_points -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ Training routine for GraspNet baseline model. """ 2 | 3 | import os 4 | import sys 5 | import numpy as np 6 | from datetime import datetime 7 | import argparse 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from torch.optim import lr_scheduler 13 | from torch.utils.data import DataLoader 14 | from torch.utils.tensorboard import SummaryWriter 15 | 16 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 17 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 18 | sys.path.append(os.path.join(ROOT_DIR, 'pointnet2')) 19 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 20 | sys.path.append(os.path.join(ROOT_DIR, 'dataset')) 21 | sys.path.append(os.path.join(ROOT_DIR, 'prof')) 22 | from prof import memstat 23 | from graspnet import GraspNet, get_loss 24 | from pytorch_utils import BNMomentumScheduler 25 | # ~~~~~!!!!!! 
Modified 26 | from graspnet_dataset_lazy import GraspNetDataset, collate_fn, load_grasp_labels_list, load_grasp_labels 27 | from label_generation import process_grasp_labels 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--dataset_root', required=True, help='Dataset root') 31 | parser.add_argument('--camera', required=True, help='Camera split [realsense/kinect]') 32 | parser.add_argument('--checkpoint_path', default=None, help='Model checkpoint path [default: None]') 33 | parser.add_argument('--log_dir', default='log', help='Dump dir to save model checkpoint [default: log]') 34 | parser.add_argument('--num_point', type=int, default=20000, help='Point Number [default: 20000]') 35 | parser.add_argument('--num_view', type=int, default=300, help='View Number [default: 300]') 36 | parser.add_argument('--max_epoch', type=int, default=18, help='Epoch to run [default: 18]') 37 | parser.add_argument('--batch_size', type=int, default=2, help='Batch Size during training [default: 2]') 38 | parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]') 39 | parser.add_argument('--weight_decay', type=float, default=0, help='Optimization L2 weight decay [default: 0]') 40 | parser.add_argument('--bn_decay_step', type=int, default=2, help='Period of BN decay (in epochs) [default: 2]') 41 | parser.add_argument('--bn_decay_rate', type=float, default=0.5, help='Decay rate for BN decay [default: 0.5]') 42 | parser.add_argument('--lr_decay_steps', default='8,12,16', help='When to decay the learning rate (in epochs) [default: 8,12,16]') 43 | parser.add_argument('--lr_decay_rates', default='0.1,0.1,0.1', help='Decay rates for lr decay [default: 0.1,0.1,0.1]') 44 | cfgs = parser.parse_args() 45 | 46 | # ------------------------------------------------------------------------- GLOBAL CONFIG BEG 47 | EPOCH_CNT = 0 48 | LR_DECAY_STEPS = [int(x) for x in cfgs.lr_decay_steps.split(',')] 49 | LR_DECAY_RATES = [float(x) 
for x in cfgs.lr_decay_rates.split(',')] 50 | assert(len(LR_DECAY_STEPS)==len(LR_DECAY_RATES)) 51 | DEFAULT_CHECKPOINT_PATH = os.path.join(cfgs.log_dir, 'checkpoint.tar') 52 | CHECKPOINT_PATH = cfgs.checkpoint_path if cfgs.checkpoint_path is not None \ 53 | else DEFAULT_CHECKPOINT_PATH 54 | 55 | if not os.path.exists(cfgs.log_dir): 56 | os.makedirs(cfgs.log_dir) 57 | 58 | LOG_FOUT = open(os.path.join(cfgs.log_dir, 'log_train.txt'), 'a') 59 | LOG_FOUT.write(str(cfgs)+'\n') 60 | def log_string(out_str): 61 | LOG_FOUT.write(out_str+'\n') 62 | LOG_FOUT.flush() 63 | print(out_str) 64 | 65 | # Init datasets and dataloaders 66 | def my_worker_init_fn(worker_id): 67 | np.random.seed(np.random.get_state()[1][0] + worker_id) 68 | pass 69 | 70 | memstat() 71 | 72 | # Create Dataset and Dataloader 73 | valid_obj_idxs, grasp_labels_list = load_grasp_labels_list(cfgs.dataset_root) 74 | TRAIN_DATASET = GraspNetDataset(cfgs.dataset_root, valid_obj_idxs, None, grasp_labels_list, camera=cfgs.camera, split='train', num_points=cfgs.num_point, remove_outlier=True, augment=True) 75 | memstat() 76 | TEST_DATASET = GraspNetDataset(cfgs.dataset_root, valid_obj_idxs, None, grasp_labels_list, camera=cfgs.camera, split='test_seen', num_points=cfgs.num_point, remove_outlier=True, augment=False) 77 | memstat() 78 | print(len(TRAIN_DATASET), len(TEST_DATASET)) 79 | memstat() 80 | TRAIN_DATALOADER = DataLoader(TRAIN_DATASET, batch_size=cfgs.batch_size, shuffle=True, 81 | num_workers=4, worker_init_fn=my_worker_init_fn, collate_fn=collate_fn) 82 | memstat() 83 | TEST_DATALOADER = DataLoader(TEST_DATASET, batch_size=cfgs.batch_size, shuffle=False, 84 | num_workers=4, worker_init_fn=my_worker_init_fn, collate_fn=collate_fn) 85 | print(len(TRAIN_DATALOADER), len(TEST_DATALOADER)) 86 | # Init the model and optimzier 87 | net = GraspNet(input_feature_dim=0, num_view=cfgs.num_view, num_angle=12, num_depth=4, 88 | cylinder_radius=0.05, hmin=-0.02, hmax_list=[0.01,0.02,0.03,0.04]) 89 | device = 
torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 90 | net.to(device) 91 | # Load the Adam optimizer 92 | optimizer = optim.Adam(net.parameters(), lr=cfgs.learning_rate, weight_decay=cfgs.weight_decay) 93 | # Load checkpoint if there is any 94 | it = -1 # for the initialize value of `LambdaLR` and `BNMomentumScheduler` 95 | start_epoch = 0 96 | if CHECKPOINT_PATH is not None and os.path.isfile(CHECKPOINT_PATH): 97 | checkpoint = torch.load(CHECKPOINT_PATH) 98 | net.load_state_dict(checkpoint['model_state_dict']) 99 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 100 | start_epoch = checkpoint['epoch'] 101 | log_string("-> loaded checkpoint %s (epoch: %d)"%(CHECKPOINT_PATH, start_epoch)) 102 | # Decay Batchnorm momentum from 0.5 to 0.999 103 | # note: pytorch's BN momentum (default 0.1)= 1 - tensorflow's BN momentum 104 | BN_MOMENTUM_INIT = 0.5 105 | BN_MOMENTUM_MAX = 0.001 106 | bn_lbmd = lambda it: max(BN_MOMENTUM_INIT * cfgs.bn_decay_rate**(int(it / cfgs.bn_decay_step)), BN_MOMENTUM_MAX) 107 | bnm_scheduler = BNMomentumScheduler(net, bn_lambda=bn_lbmd, last_epoch=start_epoch-1) 108 | 109 | 110 | def get_current_lr(epoch): 111 | lr = cfgs.learning_rate 112 | for i,lr_decay_epoch in enumerate(LR_DECAY_STEPS): 113 | if epoch >= lr_decay_epoch: 114 | lr *= LR_DECAY_RATES[i] 115 | return lr 116 | 117 | def adjust_learning_rate(optimizer, epoch): 118 | lr = get_current_lr(epoch) 119 | for param_group in optimizer.param_groups: 120 | param_group['lr'] = lr 121 | 122 | # TensorBoard Visualizers 123 | TRAIN_WRITER = SummaryWriter(os.path.join(cfgs.log_dir, 'train')) 124 | TEST_WRITER = SummaryWriter(os.path.join(cfgs.log_dir, 'test')) 125 | 126 | # ------------------------------------------------------------------------- GLOBAL CONFIG END 127 | 128 | def train_one_epoch(): 129 | memstat() 130 | print("Start the training.") 131 | stat_dict = {} # collect statistics 132 | adjust_learning_rate(optimizer, EPOCH_CNT) 133 | bnm_scheduler.step() # 
decay BN momentum 134 | # set model to training mode 135 | net.train() 136 | for batch_idx, batch_data_label in enumerate(TRAIN_DATALOADER): 137 | for key in batch_data_label: 138 | if 'list' in key: 139 | for i in range(len(batch_data_label[key])): 140 | for j in range(len(batch_data_label[key][i])): 141 | batch_data_label[key][i][j] = batch_data_label[key][i][j].to(device) 142 | else: 143 | batch_data_label[key] = batch_data_label[key].to(device) 144 | 145 | # Forward pass 146 | end_points = net(batch_data_label) 147 | 148 | # Compute loss and gradients, update parameters. 149 | loss, end_points = get_loss(end_points) 150 | loss.backward() 151 | if (batch_idx+1) % 1 == 0: 152 | optimizer.step() 153 | optimizer.zero_grad() 154 | 155 | # Accumulate statistics and print out 156 | for key in end_points: 157 | if 'loss' in key or 'acc' in key or 'prec' in key or 'recall' in key or 'count' in key: 158 | if key not in stat_dict: stat_dict[key] = 0 159 | stat_dict[key] += end_points[key].item() 160 | 161 | batch_interval = 10 162 | if (batch_idx+1) % batch_interval == 0: 163 | log_string(' ---- batch: %03d ----' % (batch_idx+1)) 164 | for key in sorted(stat_dict.keys()): 165 | TRAIN_WRITER.add_scalar(key, stat_dict[key]/batch_interval, (EPOCH_CNT*len(TRAIN_DATALOADER)+batch_idx)*cfgs.batch_size) 166 | log_string('mean %s: %f'%(key, stat_dict[key]/batch_interval)) 167 | stat_dict[key] = 0 168 | 169 | def evaluate_one_epoch(): 170 | stat_dict = {} # collect statistics 171 | # set model to eval mode (for bn and dp) 172 | net.eval() 173 | for batch_idx, batch_data_label in enumerate(TEST_DATALOADER): 174 | if batch_idx % 10 == 0: 175 | print('Eval batch: %d'%(batch_idx)) 176 | for key in batch_data_label: 177 | if 'list' in key: 178 | for i in range(len(batch_data_label[key])): 179 | for j in range(len(batch_data_label[key][i])): 180 | batch_data_label[key][i][j] = batch_data_label[key][i][j].to(device) 181 | else: 182 | batch_data_label[key] = 
batch_data_label[key].to(device) 183 | 184 | # Forward pass 185 | with torch.no_grad(): 186 | end_points = net(batch_data_label) 187 | 188 | # Compute loss 189 | loss, end_points = get_loss(end_points) 190 | 191 | # Accumulate statistics and print out 192 | for key in end_points: 193 | if 'loss' in key or 'acc' in key or 'prec' in key or 'recall' in key or 'count' in key: 194 | if key not in stat_dict: stat_dict[key] = 0 195 | stat_dict[key] += end_points[key].item() 196 | 197 | for key in sorted(stat_dict.keys()): 198 | TEST_WRITER.add_scalar(key, stat_dict[key]/float(batch_idx+1), (EPOCH_CNT+1)*len(TRAIN_DATALOADER)*cfgs.batch_size) 199 | log_string('eval mean %s: %f'%(key, stat_dict[key]/(float(batch_idx+1)))) 200 | 201 | mean_loss = stat_dict['loss/overall_loss']/float(batch_idx+1) 202 | return mean_loss 203 | 204 | 205 | def train(start_epoch): 206 | memstat() 207 | global EPOCH_CNT 208 | min_loss = 1e10 209 | loss = 0 210 | for epoch in range(start_epoch, cfgs.max_epoch): 211 | EPOCH_CNT = epoch 212 | log_string('**** EPOCH %03d ****' % (epoch)) 213 | log_string('Current learning rate: %f'%(get_current_lr(epoch))) 214 | log_string('Current BN decay momentum: %f'%(bnm_scheduler.lmbd(bnm_scheduler.last_epoch))) 215 | log_string(str(datetime.now())) 216 | # Reset numpy seed. 
217 | # REF: https://github.com/pytorch/pytorch/issues/5059 218 | np.random.seed() 219 | train_one_epoch() 220 | loss = evaluate_one_epoch() 221 | # Save checkpoint 222 | save_dict = {'epoch': epoch+1, # after training one epoch, the start_epoch should be epoch+1 223 | 'optimizer_state_dict': optimizer.state_dict(), 224 | 'loss': loss, 225 | } 226 | try: # with nn.DataParallel() the net is added as a submodule of DataParallel 227 | save_dict['model_state_dict'] = net.module.state_dict() 228 | except: 229 | save_dict['model_state_dict'] = net.state_dict() 230 | torch.save(save_dict, os.path.join(cfgs.log_dir, 'checkpoint.tar')) 231 | 232 | if __name__=='__main__': 233 | train(start_epoch) 234 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GRASPNET-BASELINE 2 | SOFTWARE LICENSE AGREEMENT 3 | ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY 4 | 5 | BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE. 6 | 7 | This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Shanghai Jiao Tong University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor. 8 | 9 | RESERVATION OF OWNERSHIP AND GRANT OF LICENSE: 10 | Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive, 11 | non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. 
As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i). 12 | 13 | CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication. 14 | 15 | PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto. 16 | 17 | DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement. 
18 | 19 | BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies. 20 | 21 | USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “AlphaPose", "Shanghai Jiao Tong" or any renditions thereof without the prior written permission of Licensor. 22 | 23 | You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software. 24 | 25 | ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void. 26 | 27 | TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below. 28 | 29 | The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement. 
30 | 31 | FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement. 32 | 33 | DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS. 34 | 35 | SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement. 36 | 37 | EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage. 38 | 39 | EXPORT REGULATION: Licensee agrees to comply with any and all applicable 40 | U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control. 41 | 42 | SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby. 43 | 44 | NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor. 
45 | 46 | ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto. 47 | 48 | 49 | 50 | ************************************************************************ 51 | 52 | THIRD-PARTY SOFTWARE NOTICES AND INFORMATION 53 | 54 | This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserve all other rights not expressly granted, whether by implication, estoppel or otherwise. 55 | 56 | 1. PyTorch (https://github.com/pytorch/pytorch) 57 | 58 | THIRD-PARTY SOFTWARE NOTICES AND INFORMATION 59 | 60 | This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserve all other rights not expressly granted, whether by implication, estoppel or otherwise. 61 | 62 | From PyTorch: 63 | 64 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 65 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 66 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 67 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 68 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 69 | Copyright (c) 2011-2013 NYU (Clement Farabet) 70 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 71 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 72 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 73 | 74 | From Caffe2: 75 | 76 | Copyright (c) 2016-present, Facebook Inc. All rights reserved.
77 | 78 | All contributions by Facebook: 79 | Copyright (c) 2016 Facebook Inc. 80 | 81 | All contributions by Google: 82 | Copyright (c) 2015 Google Inc. 83 | All rights reserved. 84 | 85 | All contributions by Yangqing Jia: 86 | Copyright (c) 2015 Yangqing Jia 87 | All rights reserved. 88 | 89 | All contributions by Kakao Brain: 90 | Copyright 2019-2020 Kakao Brain 91 | 92 | All contributions from Caffe: 93 | Copyright(c) 2013, 2014, 2015, the respective contributors 94 | All rights reserved. 95 | 96 | All other contributions: 97 | Copyright(c) 2015, 2016 the respective contributors 98 | All rights reserved. 99 | 100 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 101 | copyright over their contributions to Caffe2. The project versioning records 102 | all such contribution and copyright details. If a contributor wants to further 103 | mark their specific copyright on a particular contribution, they should 104 | indicate their copyright solely in the commit message of the change when it is 105 | committed. 106 | 107 | All rights reserved. 108 | 109 | Redistribution and use in source and binary forms, with or without 110 | modification, are permitted provided that the following conditions are met: 111 | 112 | 1. Redistributions of source code must retain the above copyright 113 | notice, this list of conditions and the following disclaimer. 114 | 115 | 2. Redistributions in binary form must reproduce the above copyright 116 | notice, this list of conditions and the following disclaimer in the 117 | documentation and/or other materials provided with the distribution. 118 | 119 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 120 | and IDIAP Research Institute nor the names of its contributors may be 121 | used to endorse or promote products derived from this software without 122 | specific prior written permission. 
123 | 124 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 125 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 126 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 127 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 128 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 129 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 130 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 131 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 132 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 133 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 134 | POSSIBILITY OF SUCH DAMAGE. 135 | 136 | 2. VoteNet (https://github.com/facebookresearch/votenet) 137 | 138 | MIT License 139 | 140 | Copyright (c) Facebook, Inc. and its affiliates. 141 | 142 | Permission is hereby granted, free of charge, to any person obtaining a copy 143 | of this software and associated documentation files (the "Software"), to deal 144 | in the Software without restriction, including without limitation the rights 145 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 146 | copies of the Software, and to permit persons to whom the Software is 147 | furnished to do so, subject to the following conditions: 148 | 149 | The above copyright notice and this permission notice shall be included in all 150 | copies or substantial portions of the Software. 151 | 152 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 153 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 154 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 155 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 156 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 157 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 158 | SOFTWARE. 159 | 160 | ************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION********** 161 | -------------------------------------------------------------------------------- /dataset/graspnet_dataset.py: -------------------------------------------------------------------------------- 1 | """ GraspNet dataset processing. 2 | Author: chenxi-wang 3 | """ 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | import scipy.io as scio 9 | from PIL import Image 10 | 11 | import torch 12 | try: 13 | from torch._six import container_abcs 14 | except ImportError: 15 | import collections.abc as container_abcs 16 | from torch.utils.data import Dataset 17 | from tqdm import tqdm 18 | 19 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 20 | ROOT_DIR = os.path.dirname(BASE_DIR) 21 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 22 | from data_utils import CameraInfo, transform_point_cloud, create_point_cloud_from_depth_image,\ 23 | get_workspace_mask, remove_invisible_grasp_points 24 | 25 | class GraspNetDataset(Dataset): 26 | def __init__(self, root, valid_obj_idxs, grasp_labels, camera='kinect', split='train', num_points=20000, 27 | remove_outlier=False, remove_invisible=True, augment=False, load_label=True): 28 | assert(num_points<=50000) 29 | self.root = root 30 | self.split = split 31 | self.num_points = num_points 32 | self.remove_outlier = remove_outlier 33 | self.remove_invisible = remove_invisible 34 | self.valid_obj_idxs = valid_obj_idxs 35 | self.grasp_labels = grasp_labels 36 | print(self.grasp_labels.keys()) 37 | self.camera = camera 38 | self.augment = augment 39 | self.load_label = load_label 40 | self.collision_labels = {} 41 | 42 | if split == 'train': 43 | self.sceneIds 
= list( range(100) ) 44 | elif split == 'test': 45 | self.sceneIds = list( range(100,190) ) 46 | elif split == 'test_seen': 47 | self.sceneIds = list( range(100,130) ) 48 | elif split == 'test_similar': 49 | self.sceneIds = list( range(130,160) ) 50 | elif split == 'test_novel': 51 | self.sceneIds = list( range(160,190) ) 52 | self.sceneIds = ['scene_{}'.format(str(x).zfill(4)) for x in self.sceneIds] 53 | 54 | self.colorpath = [] 55 | self.depthpath = [] 56 | self.labelpath = [] 57 | self.metapath = [] 58 | self.scenename = [] 59 | self.frameid = [] 60 | for x in tqdm(self.sceneIds, desc = 'Loading data path and collision labels...'): 61 | for img_num in range(256): 62 | self.colorpath.append(os.path.join(root, 'scenes', x, camera, 'rgb', str(img_num).zfill(4)+'.png')) 63 | self.depthpath.append(os.path.join(root, 'scenes', x, camera, 'depth', str(img_num).zfill(4)+'.png')) 64 | self.labelpath.append(os.path.join(root, 'scenes', x, camera, 'label', str(img_num).zfill(4)+'.png')) 65 | self.metapath.append(os.path.join(root, 'scenes', x, camera, 'meta', str(img_num).zfill(4)+'.mat')) 66 | self.scenename.append(x.strip()) 67 | self.frameid.append(img_num) 68 | if self.load_label: 69 | collision_labels = np.load(os.path.join(root, 'collision_label', x.strip(), 'collision_labels.npz')) 70 | self.collision_labels[x.strip()] = {} 71 | for i in range(len(collision_labels)): 72 | self.collision_labels[x.strip()][i] = collision_labels['arr_{}'.format(i)] 73 | 74 | def scene_list(self): 75 | return self.scenename 76 | 77 | def __len__(self): 78 | return len(self.depthpath) 79 | 80 | def augment_data(self, point_clouds, object_poses_list): 81 | # Flipping along the YZ plane 82 | if np.random.random() > 0.5: 83 | flip_mat = np.array([[-1, 0, 0], 84 | [ 0, 1, 0], 85 | [ 0, 0, 1]]) 86 | point_clouds = transform_point_cloud(point_clouds, flip_mat, '3x3') 87 | for i in range(len(object_poses_list)): 88 | object_poses_list[i] = np.dot(flip_mat, 
object_poses_list[i]).astype(np.float32) 89 | 90 | # Rotation along up-axis/Z-axis 91 | rot_angle = (np.random.random()*np.pi/3) - np.pi/6 # -30 ~ +30 degree 92 | c, s = np.cos(rot_angle), np.sin(rot_angle) 93 | rot_mat = np.array([[1, 0, 0], 94 | [0, c,-s], 95 | [0, s, c]]) 96 | point_clouds = transform_point_cloud(point_clouds, rot_mat, '3x3') 97 | for i in range(len(object_poses_list)): 98 | object_poses_list[i] = np.dot(rot_mat, object_poses_list[i]).astype(np.float32) 99 | 100 | return point_clouds, object_poses_list 101 | 102 | def __getitem__(self, index): 103 | if self.load_label: 104 | return self.get_data_label(index) 105 | else: 106 | return self.get_data(index) 107 | 108 | def get_data(self, index, return_raw_cloud=False): 109 | color = np.array(Image.open(self.colorpath[index]), dtype=np.float32) / 255.0 110 | depth = np.array(Image.open(self.depthpath[index])) 111 | seg = np.array(Image.open(self.labelpath[index])) 112 | meta = scio.loadmat(self.metapath[index]) 113 | scene = self.scenename[index] 114 | try: 115 | intrinsic = meta['intrinsic_matrix'] 116 | factor_depth = meta['factor_depth'] 117 | except Exception as e: 118 | print(repr(e)) 119 | print(scene) 120 | camera = CameraInfo(1280.0, 720.0, intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2], factor_depth) 121 | 122 | # generate cloud 123 | cloud = create_point_cloud_from_depth_image(depth, camera, organized=True) 124 | 125 | # get valid points 126 | depth_mask = (depth > 0) 127 | seg_mask = (seg > 0) 128 | if self.remove_outlier: 129 | camera_poses = np.load(os.path.join(self.root, 'scenes', scene, self.camera, 'camera_poses.npy')) 130 | align_mat = np.load(os.path.join(self.root, 'scenes', scene, self.camera, 'cam0_wrt_table.npy')) 131 | trans = np.dot(align_mat, camera_poses[self.frameid[index]]) 132 | workspace_mask = get_workspace_mask(cloud, seg, trans=trans, organized=True, outlier=0.02) 133 | mask = (depth_mask & workspace_mask) 134 | else: 135 | mask = depth_mask 136 | 
cloud_masked = cloud[mask] 137 | color_masked = color[mask] 138 | seg_masked = seg[mask] 139 | if return_raw_cloud: 140 | return cloud_masked, color_masked 141 | 142 | # sample points 143 | if len(cloud_masked) >= self.num_points: 144 | idxs = np.random.choice(len(cloud_masked), self.num_points, replace=False) 145 | else: 146 | idxs1 = np.arange(len(cloud_masked)) 147 | idxs2 = np.random.choice(len(cloud_masked), self.num_points-len(cloud_masked), replace=True) 148 | idxs = np.concatenate([idxs1, idxs2], axis=0) 149 | cloud_sampled = cloud_masked[idxs] 150 | color_sampled = color_masked[idxs] 151 | 152 | ret_dict = {} 153 | ret_dict['point_clouds'] = cloud_sampled.astype(np.float32) 154 | ret_dict['cloud_colors'] = color_sampled.astype(np.float32) 155 | 156 | return ret_dict 157 | 158 | def get_data_label(self, index): 159 | color = np.array(Image.open(self.colorpath[index]), dtype=np.float32) / 255.0 160 | depth = np.array(Image.open(self.depthpath[index])) 161 | seg = np.array(Image.open(self.labelpath[index])) 162 | meta = scio.loadmat(self.metapath[index]) 163 | scene = self.scenename[index] 164 | try: 165 | obj_idxs = meta['cls_indexes'].flatten().astype(np.int32) 166 | poses = meta['poses'] 167 | intrinsic = meta['intrinsic_matrix'] 168 | factor_depth = meta['factor_depth'] 169 | except Exception as e: 170 | print(repr(e)) 171 | print(scene) 172 | camera = CameraInfo(1280.0, 720.0, intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2], factor_depth) 173 | 174 | # generate cloud 175 | cloud = create_point_cloud_from_depth_image(depth, camera, organized=True) 176 | 177 | # get valid points 178 | depth_mask = (depth > 0) 179 | seg_mask = (seg > 0) 180 | if self.remove_outlier: 181 | camera_poses = np.load(os.path.join(self.root, 'scenes', scene, self.camera, 'camera_poses.npy')) 182 | align_mat = np.load(os.path.join(self.root, 'scenes', scene, self.camera, 'cam0_wrt_table.npy')) 183 | trans = np.dot(align_mat, camera_poses[self.frameid[index]]) 184 
| workspace_mask = get_workspace_mask(cloud, seg, trans=trans, organized=True, outlier=0.02) 185 | mask = (depth_mask & workspace_mask) 186 | else: 187 | mask = depth_mask 188 | cloud_masked = cloud[mask] 189 | color_masked = color[mask] 190 | seg_masked = seg[mask] 191 | 192 | # sample points 193 | if len(cloud_masked) >= self.num_points: 194 | idxs = np.random.choice(len(cloud_masked), self.num_points, replace=False) 195 | else: 196 | idxs1 = np.arange(len(cloud_masked)) 197 | idxs2 = np.random.choice(len(cloud_masked), self.num_points-len(cloud_masked), replace=True) 198 | idxs = np.concatenate([idxs1, idxs2], axis=0) 199 | cloud_sampled = cloud_masked[idxs] 200 | color_sampled = color_masked[idxs] 201 | seg_sampled = seg_masked[idxs] 202 | objectness_label = seg_sampled.copy() 203 | objectness_label[objectness_label>1] = 1 204 | 205 | object_poses_list = [] 206 | grasp_points_list = [] 207 | grasp_offsets_list = [] 208 | grasp_scores_list = [] 209 | grasp_tolerance_list = [] 210 | for i, obj_idx in enumerate(obj_idxs): 211 | if obj_idx not in self.valid_obj_idxs: 212 | continue 213 | if (seg_sampled == obj_idx).sum() < 50: 214 | continue 215 | object_poses_list.append(poses[:, :, i]) 216 | points, offsets, scores, tolerance = self.grasp_labels[obj_idx] 217 | collision = self.collision_labels[scene][i] #(Np, V, A, D) 218 | 219 | # remove invisible grasp points 220 | if self.remove_invisible: 221 | visible_mask = remove_invisible_grasp_points(cloud_sampled[seg_sampled==obj_idx], points, poses[:,:,i], th=0.01) 222 | points = points[visible_mask] 223 | offsets = offsets[visible_mask] 224 | scores = scores[visible_mask] 225 | tolerance = tolerance[visible_mask] 226 | collision = collision[visible_mask] 227 | 228 | idxs = np.random.choice(len(points), min(max(int(len(points)/4),300),len(points)), replace=False) 229 | grasp_points_list.append(points[idxs]) 230 | grasp_offsets_list.append(offsets[idxs]) 231 | collision = collision[idxs].copy() 232 | scores = 
scores[idxs].copy() 233 | scores[collision] = 0 234 | grasp_scores_list.append(scores) 235 | tolerance = tolerance[idxs].copy() 236 | tolerance[collision] = 0 237 | grasp_tolerance_list.append(tolerance) 238 | 239 | if self.augment: 240 | cloud_sampled, object_poses_list = self.augment_data(cloud_sampled, object_poses_list) 241 | 242 | ret_dict = {} 243 | ret_dict['point_clouds'] = cloud_sampled.astype(np.float32) 244 | ret_dict['cloud_colors'] = color_sampled.astype(np.float32) 245 | ret_dict['objectness_label'] = objectness_label.astype(np.int64) 246 | ret_dict['object_poses_list'] = object_poses_list 247 | ret_dict['grasp_points_list'] = grasp_points_list 248 | ret_dict['grasp_offsets_list'] = grasp_offsets_list 249 | ret_dict['grasp_labels_list'] = grasp_scores_list 250 | ret_dict['grasp_tolerance_list'] = grasp_tolerance_list 251 | 252 | return ret_dict 253 | 254 | def load_grasp_labels(root): 255 | obj_names = list(range(88)) 256 | valid_obj_idxs = [] 257 | grasp_labels = {} 258 | for i, obj_name in enumerate(tqdm(obj_names, desc='Loading grasping labels...')): 259 | if i == 18: continue 260 | valid_obj_idxs.append(i + 1) #here align with label png 261 | label = np.load(os.path.join(root, 'grasp_label', '{}_labels.npz'.format(str(i).zfill(3)))) 262 | tolerance = np.load(os.path.join(BASE_DIR, 'tolerance', '{}_tolerance.npy'.format(str(i).zfill(3)))) 263 | grasp_labels[i + 1] = (label['points'].astype(np.float32), label['offsets'].astype(np.float32), 264 | label['scores'].astype(np.float32), tolerance) 265 | 266 | return valid_obj_idxs, grasp_labels 267 | 268 | def collate_fn(batch): 269 | if type(batch[0]).__module__ == 'numpy': 270 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 271 | elif isinstance(batch[0], container_abcs.Mapping): 272 | return {key:collate_fn([d[key] for d in batch]) for key in batch[0]} 273 | elif isinstance(batch[0], container_abcs.Sequence): 274 | return [[torch.from_numpy(sample) for sample in b] for b in batch] 275 | 
276 | raise TypeError("batch must contain tensors, dicts or lists; found {}".format(type(batch[0]))) 277 | 278 | if __name__ == "__main__": 279 | root = '/nvme/grm/data/graspnet' 280 | valid_obj_idxs, grasp_labels = load_grasp_labels(root) 281 | train_dataset = GraspNetDataset(root, valid_obj_idxs, grasp_labels, split='train', remove_outlier=True, remove_invisible=True, num_points=20000) 282 | print(len(train_dataset)) 283 | 284 | end_points = train_dataset[233] 285 | cloud = end_points['point_clouds'] 286 | seg = end_points['objectness_label'] 287 | print(cloud.shape) 288 | print(cloud.dtype) 289 | print(cloud[:,0].min(), cloud[:,0].max()) 290 | print(cloud[:,1].min(), cloud[:,1].max()) 291 | print(cloud[:,2].min(), cloud[:,2].max()) 292 | print(seg.shape) 293 | print((seg>0).sum()) 294 | print(seg.dtype) 295 | print(np.unique(seg)) 296 | -------------------------------------------------------------------------------- /dataset/graspnet_dataset_lazy.py: -------------------------------------------------------------------------------- 1 | """ GraspNet dataset processing with lazy loading to avoid memory hazard`. 
2 | Author: chenxi-wang and ruiming-guo 3 | """ 4 | print("### Using the modified version of graspnet_dataset") 5 | import os 6 | import sys 7 | import numpy as np 8 | import scipy.io as scio 9 | from PIL import Image 10 | import pdb 11 | 12 | import torch 13 | try: 14 | from torch._six import container_abcs 15 | except ImportError: 16 | import collections.abc as container_abcs 17 | from torch.utils.data import Dataset 18 | from tqdm import tqdm 19 | 20 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 21 | ROOT_DIR = os.path.dirname(BASE_DIR) 22 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 23 | from data_utils import CameraInfo, transform_point_cloud, create_point_cloud_from_depth_image,\ 24 | get_workspace_mask, remove_invisible_grasp_points 25 | 26 | class GraspNetDataset(Dataset): 27 | def __init__(self, root, valid_obj_idxs, grasp_labels, grasp_labels_list, camera='kinect', split='train', num_points=20000, 28 | remove_outlier=False, remove_invisible=True, augment=False, load_label=True): 29 | assert(num_points<=50000) 30 | self.root = root 31 | self.split = split 32 | self.num_points = num_points 33 | self.remove_outlier = remove_outlier 34 | self.remove_invisible = remove_invisible 35 | self.valid_obj_idxs = valid_obj_idxs # a list 36 | self.grasp_labels = grasp_labels # a dict -> None 37 | self.grasp_labels_list = grasp_labels_list 38 | print(type(valid_obj_idxs), type(grasp_labels)) 39 | # pdb.set_trace() 40 | self.camera = camera 41 | self.augment = augment 42 | self.load_label = load_label 43 | self.collision_labels = {} 44 | self.collision_labels_list = {} 45 | 46 | if split == 'train': 47 | self.sceneIds = list( range(100) ) 48 | elif split == 'test': 49 | self.sceneIds = list( range(100,190) ) 50 | elif split == 'test_seen': 51 | self.sceneIds = list( range(100,130) ) 52 | elif split == 'test_similar': 53 | self.sceneIds = list( range(130,160) ) 54 | elif split == 'test_novel': 55 | self.sceneIds = list( range(160,190) ) 56 | self.sceneIds 
= ['scene_{}'.format(str(x).zfill(4)) for x in self.sceneIds] 57 | 58 | self.colorpath = [] 59 | self.depthpath = [] 60 | self.labelpath = [] 61 | self.metapath = [] 62 | self.scenename = [] 63 | self.frameid = [] 64 | # GraspNetDataset 是懒加载,但label不是 65 | for x in tqdm(self.sceneIds, desc = 'Loading data path and collision labels...'): 66 | for img_num in range(256): 67 | self.colorpath.append(os.path.join(root, 'scenes', x, camera, 'rgb', str(img_num).zfill(4)+'.png')) 68 | self.depthpath.append(os.path.join(root, 'scenes', x, camera, 'depth', str(img_num).zfill(4)+'.png')) 69 | self.labelpath.append(os.path.join(root, 'scenes', x, camera, 'label', str(img_num).zfill(4)+'.png')) 70 | self.metapath.append(os.path.join(root, 'scenes', x, camera, 'meta', str(img_num).zfill(4)+'.mat')) 71 | self.scenename.append(x.strip()) 72 | self.frameid.append(img_num) 73 | if self.load_label: 74 | pass 75 | def scene_list(self): 76 | return self.scenename 77 | 78 | def __len__(self): 79 | return len(self.depthpath) 80 | 81 | def augment_data(self, point_clouds, object_poses_list): 82 | # Flipping along the YZ plane 83 | if np.random.random() > 0.5: 84 | flip_mat = np.array([[-1, 0, 0], 85 | [ 0, 1, 0], 86 | [ 0, 0, 1]]) 87 | point_clouds = transform_point_cloud(point_clouds, flip_mat, '3x3') 88 | for i in range(len(object_poses_list)): 89 | object_poses_list[i] = np.dot(flip_mat, object_poses_list[i]).astype(np.float32) 90 | 91 | # Rotation along up-axis/Z-axis 92 | rot_angle = (np.random.random()*np.pi/3) - np.pi/6 # -30 ~ +30 degree 93 | c, s = np.cos(rot_angle), np.sin(rot_angle) 94 | rot_mat = np.array([[1, 0, 0], 95 | [0, c,-s], 96 | [0, s, c]]) 97 | point_clouds = transform_point_cloud(point_clouds, rot_mat, '3x3') 98 | for i in range(len(object_poses_list)): 99 | object_poses_list[i] = np.dot(rot_mat, object_poses_list[i]).astype(np.float32) 100 | 101 | return point_clouds, object_poses_list 102 | 103 | def __getitem__(self, index): 104 | if self.load_label: 105 | return 
self.get_data_label(index) 106 | else: 107 | return self.get_data(index) 108 | 109 | def get_data(self, index, return_raw_cloud=False): 110 | color = np.array(Image.open(self.colorpath[index]), dtype=np.float32) / 255.0 111 | depth = np.array(Image.open(self.depthpath[index])) 112 | seg = np.array(Image.open(self.labelpath[index])) 113 | meta = scio.loadmat(self.metapath[index]) 114 | scene = self.scenename[index] 115 | try: 116 | intrinsic = meta['intrinsic_matrix'] 117 | factor_depth = meta['factor_depth'] 118 | except Exception as e: 119 | print(repr(e)) 120 | print(scene) 121 | camera = CameraInfo(1280.0, 720.0, intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2], factor_depth) 122 | 123 | # generate cloud 124 | cloud = create_point_cloud_from_depth_image(depth, camera, organized=True) 125 | 126 | # get valid points 127 | depth_mask = (depth > 0) 128 | seg_mask = (seg > 0) 129 | if self.remove_outlier: 130 | camera_poses = np.load(os.path.join(self.root, 'scenes', scene, self.camera, 'camera_poses.npy')) 131 | align_mat = np.load(os.path.join(self.root, 'scenes', scene, self.camera, 'cam0_wrt_table.npy')) 132 | trans = np.dot(align_mat, camera_poses[self.frameid[index]]) 133 | workspace_mask = get_workspace_mask(cloud, seg, trans=trans, organized=True, outlier=0.02) 134 | mask = (depth_mask & workspace_mask) 135 | else: 136 | mask = depth_mask 137 | cloud_masked = cloud[mask] 138 | color_masked = color[mask] 139 | seg_masked = seg[mask] 140 | if return_raw_cloud: 141 | return cloud_masked, color_masked 142 | 143 | # sample points 144 | if len(cloud_masked) >= self.num_points: 145 | idxs = np.random.choice(len(cloud_masked), self.num_points, replace=False) 146 | else: 147 | idxs1 = np.arange(len(cloud_masked)) 148 | idxs2 = np.random.choice(len(cloud_masked), self.num_points-len(cloud_masked), replace=True) 149 | idxs = np.concatenate([idxs1, idxs2], axis=0) 150 | cloud_sampled = cloud_masked[idxs] 151 | color_sampled = color_masked[idxs] 152 | 
153 | ret_dict = {} 154 | ret_dict['point_clouds'] = cloud_sampled.astype(np.float32) 155 | ret_dict['cloud_colors'] = color_sampled.astype(np.float32) 156 | 157 | return ret_dict 158 | 159 | def get_data_label(self, index): 160 | color = np.array(Image.open(self.colorpath[index]), dtype=np.float32) / 255.0 161 | depth = np.array(Image.open(self.depthpath[index])) 162 | seg = np.array(Image.open(self.labelpath[index])) 163 | meta = scio.loadmat(self.metapath[index]) 164 | scene = self.scenename[index] 165 | try: 166 | obj_idxs = meta['cls_indexes'].flatten().astype(np.int32) 167 | poses = meta['poses'] 168 | intrinsic = meta['intrinsic_matrix'] 169 | factor_depth = meta['factor_depth'] 170 | except Exception as e: 171 | print(repr(e)) 172 | print(scene) 173 | camera = CameraInfo(1280.0, 720.0, intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2], factor_depth) 174 | 175 | # generate cloud 176 | cloud = create_point_cloud_from_depth_image(depth, camera, organized=True) 177 | 178 | # get valid points 179 | depth_mask = (depth > 0) 180 | seg_mask = (seg > 0) 181 | if self.remove_outlier: 182 | camera_poses = np.load(os.path.join(self.root, 'scenes', scene, self.camera, 'camera_poses.npy')) 183 | align_mat = np.load(os.path.join(self.root, 'scenes', scene, self.camera, 'cam0_wrt_table.npy')) 184 | trans = np.dot(align_mat, camera_poses[self.frameid[index]]) 185 | workspace_mask = get_workspace_mask(cloud, seg, trans=trans, organized=True, outlier=0.02) 186 | mask = (depth_mask & workspace_mask) 187 | else: 188 | mask = depth_mask 189 | cloud_masked = cloud[mask] 190 | color_masked = color[mask] 191 | seg_masked = seg[mask] 192 | 193 | # sample points 194 | if len(cloud_masked) >= self.num_points: 195 | idxs = np.random.choice(len(cloud_masked), self.num_points, replace=False) 196 | else: 197 | idxs1 = np.arange(len(cloud_masked)) 198 | idxs2 = np.random.choice(len(cloud_masked), self.num_points-len(cloud_masked), replace=True) 199 | idxs = 
np.concatenate([idxs1, idxs2], axis=0) 200 | cloud_sampled = cloud_masked[idxs] 201 | color_sampled = color_masked[idxs] 202 | seg_sampled = seg_masked[idxs] 203 | objectness_label = seg_sampled.copy() 204 | objectness_label[objectness_label>1] = 1 205 | 206 | object_poses_list = [] 207 | grasp_points_list = [] 208 | grasp_offsets_list = [] 209 | grasp_scores_list = [] 210 | grasp_tolerance_list = [] 211 | 212 | collision_labels_per_scene = self._load_collision_labels(scene) 213 | 214 | for i, obj_idx in enumerate(obj_idxs): 215 | if obj_idx not in self.valid_obj_idxs: 216 | continue 217 | if (seg_sampled == obj_idx).sum() < 50: 218 | continue 219 | object_poses_list.append(poses[:, :, i]) 220 | points, offsets, scores, tolerance = self._load_post(obj_idx) 221 | collision = collision_labels_per_scene[i] #(Np, V, A, D) 222 | 223 | # remove invisible grasp points 224 | if self.remove_invisible: 225 | visible_mask = remove_invisible_grasp_points(cloud_sampled[seg_sampled==obj_idx], points, poses[:,:,i], th=0.01) 226 | points = points[visible_mask] 227 | offsets = offsets[visible_mask] 228 | scores = scores[visible_mask] 229 | tolerance = tolerance[visible_mask] 230 | collision = collision[visible_mask] 231 | 232 | idxs = np.random.choice(len(points), min(max(int(len(points)/4),300),len(points)), replace=False) 233 | grasp_points_list.append(points[idxs]) 234 | grasp_offsets_list.append(offsets[idxs]) 235 | collision = collision[idxs].copy() 236 | scores = scores[idxs].copy() 237 | scores[collision] = 0 238 | grasp_scores_list.append(scores) 239 | tolerance = tolerance[idxs].copy() 240 | tolerance[collision] = 0 241 | grasp_tolerance_list.append(tolerance) 242 | 243 | if self.augment: 244 | cloud_sampled, object_poses_list = self.augment_data(cloud_sampled, object_poses_list) 245 | 246 | ret_dict = {} 247 | ret_dict['point_clouds'] = cloud_sampled.astype(np.float32) 248 | ret_dict['cloud_colors'] = color_sampled.astype(np.float32) 249 | ret_dict['objectness_label'] = 
objectness_label.astype(np.int64) 250 | ret_dict['object_poses_list'] = object_poses_list 251 | ret_dict['grasp_points_list'] = grasp_points_list 252 | ret_dict['grasp_offsets_list'] = grasp_offsets_list 253 | ret_dict['grasp_labels_list'] = grasp_scores_list 254 | ret_dict['grasp_tolerance_list'] = grasp_tolerance_list 255 | 256 | return ret_dict 257 | 258 | def _load_post(self, obj_idx): 259 | label_path, tolerance_path = self.grasp_labels_list[obj_idx] 260 | 261 | label = np.load(label_path) 262 | tolerance = np.load(tolerance_path) 263 | 264 | return ( 265 | label['points'].astype(np.float32), # (3459, 3) 266 | label['offsets'].astype(np.float32),# (3459, 300, 12, 4, 3) 267 | label['scores'].astype(np.float32), # (3459, 300, 12, 4) 268 | tolerance # (3459, 300, 12, 4) 269 | ) 270 | 271 | def _load_collision_labels(self, scene): 272 | collision_labels = np.load(os.path.join(self.root, 'collision_label', scene.strip(), 'collision_labels.npz')) 273 | collision_labels_per_scene = {} 274 | 275 | for i in range(len(collision_labels)): 276 | collision_labels_per_scene[i] = collision_labels['arr_{}'.format(i)] 277 | 278 | return collision_labels_per_scene 279 | 280 | 281 | def load_grasp_labels(root): 282 | obj_names = list(range(88)) 283 | valid_obj_idxs = [] 284 | grasp_labels = {} # 非得整个以数字为下标的dict,是用js用多了吧? 
285 | for i, obj_name in enumerate(tqdm(obj_names, desc='Loading grasping labels...')): 286 | if i == 18: continue 287 | valid_obj_idxs.append(i + 1) #here align with label png 288 | label = np.load(os.path.join(root, 'grasp_label', '{}_labels.npz'.format(str(i).zfill(3)))) 289 | tolerance = np.load(os.path.join(BASE_DIR, 'tolerance', '{}_tolerance.npy'.format(str(i).zfill(3)))) 290 | grasp_labels[i + 1] = ( 291 | label['points'].astype(np.float32), # (3459, 3) 292 | label['offsets'].astype(np.float32),# (3459, 300, 12, 4, 3) 293 | label['scores'].astype(np.float32), # (3459, 300, 12, 4) 294 | tolerance # (3459, 300, 12, 4) 295 | ) 296 | 297 | return valid_obj_idxs, grasp_labels 298 | 299 | def load_grasp_labels_list(root): 300 | obj_names = list(range(88)) 301 | valid_obj_idxs = [] 302 | grasp_labels_list = {} # 非得整个以数字为下标的dict,是用js用多了吧? 303 | for i, obj_name in enumerate(tqdm(obj_names, desc='Loading grasping labels Paths...')): 304 | if i == 18: continue # ??? 305 | valid_obj_idxs.append(i + 1) # here align with label png 306 | label_path = os.path.join(root, 'grasp_label', '{}_labels.npz'.format(str(i).zfill(3))) 307 | tolerance_path = os.path.join(BASE_DIR, 'tolerance', '{}_tolerance.npy'.format(str(i).zfill(3))) 308 | grasp_labels_list[i + 1] = (label_path, tolerance_path) 309 | 310 | return valid_obj_idxs, grasp_labels_list 311 | 312 | def collate_fn(batch): 313 | if type(batch[0]).__module__ == 'numpy': 314 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 315 | elif isinstance(batch[0], container_abcs.Mapping): 316 | return {key:collate_fn([d[key] for d in batch]) for key in batch[0]} 317 | elif isinstance(batch[0], container_abcs.Sequence): 318 | return [[torch.from_numpy(sample) for sample in b] for b in batch] 319 | 320 | raise TypeError("batch must contain tensors, dicts or lists; found {}".format(type(batch[0]))) 321 | 322 | if __name__ == "__main__": 323 | root = '/nvme/grm/data/graspnet' 324 | valid_obj_idxs, grasp_labels = 
load_grasp_labels(root) 325 | train_dataset = GraspNetDataset(root, valid_obj_idxs, grasp_labels, split='train', remove_outlier=True, remove_invisible=True, num_points=20000) 326 | print(len(train_dataset)) 327 | 328 | end_points = train_dataset[233] 329 | cloud = end_points['point_clouds'] 330 | seg = end_points['objectness_label'] 331 | print(cloud.shape) 332 | print(cloud.dtype) 333 | print(cloud[:,0].min(), cloud[:,0].max()) 334 | print(cloud[:,1].min(), cloud[:,1].max()) 335 | print(cloud[:,2].min(), cloud[:,2].max()) 336 | print(seg.shape) 337 | print((seg>0).sum()) 338 | print(seg.dtype) 339 | print(np.unique(seg)) 340 | -------------------------------------------------------------------------------- /pointnet2/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | ''' Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch ''' 7 | from __future__ import ( 8 | division, 9 | absolute_import, 10 | with_statement, 11 | print_function, 12 | unicode_literals, 13 | ) 14 | import torch 15 | from torch.autograd import Function 16 | import torch.nn as nn 17 | import pytorch_utils as pt_utils 18 | import sys 19 | 20 | try: 21 | import builtins 22 | except: 23 | import __builtin__ as builtins 24 | 25 | try: 26 | import pointnet2._ext as _ext 27 | except ImportError: 28 | if not getattr(builtins, "__POINTNET2_SETUP__", False): 29 | raise ImportError( 30 | "Could not import _ext module.\n" 31 | "Please see the setup instructions in the README: " 32 | "https://github.com/erikwijmans/Pointnet2_PyTorch/blob/master/README.rst" 33 | ) 34 | 35 | if False: 36 | # Workaround for type hints without depending on the `typing` module 37 | from typing import * 38 | 39 | 40 | class RandomDropout(nn.Module): 41 | def __init__(self, p=0.5, inplace=False): 42 | super(RandomDropout, self).__init__() 43 | self.p = p 44 | self.inplace = inplace 45 | 46 | def forward(self, X): 47 | theta = torch.Tensor(1).uniform_(0, self.p)[0] 48 | return pt_utils.feature_dropout_no_scaling(X, theta, self.train, self.inplace) 49 | 50 | 51 | class FurthestPointSampling(Function): 52 | @staticmethod 53 | def forward(ctx, xyz, npoint): 54 | # type: (Any, torch.Tensor, int) -> torch.Tensor 55 | r""" 56 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 57 | minimum distance 58 | 59 | Parameters 60 | ---------- 61 | xyz : torch.Tensor 62 | (B, N, 3) tensor where N > npoint 63 | npoint : int32 64 | number of features in the sampled set 65 | 66 | Returns 67 | ------- 68 | torch.Tensor 69 | (B, npoint) tensor containing the set 70 | """ 71 | return _ext.furthest_point_sampling(xyz, npoint) 72 | 73 | @staticmethod 74 | def backward(xyz, a=None): 75 | return None, None 76 | 77 | 78 | furthest_point_sample = 
FurthestPointSampling.apply 79 | 80 | 81 | class GatherOperation(Function): 82 | @staticmethod 83 | def forward(ctx, features, idx): 84 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 85 | r""" 86 | 87 | Parameters 88 | ---------- 89 | features : torch.Tensor 90 | (B, C, N) tensor 91 | 92 | idx : torch.Tensor 93 | (B, npoint) tensor of the features to gather 94 | 95 | Returns 96 | ------- 97 | torch.Tensor 98 | (B, C, npoint) tensor 99 | """ 100 | 101 | _, C, N = features.size() 102 | 103 | ctx.for_backwards = (idx, C, N) 104 | 105 | return _ext.gather_points(features, idx) 106 | 107 | @staticmethod 108 | def backward(ctx, grad_out): 109 | idx, C, N = ctx.for_backwards 110 | 111 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N) 112 | return grad_features, None 113 | 114 | 115 | gather_operation = GatherOperation.apply 116 | 117 | 118 | class ThreeNN(Function): 119 | @staticmethod 120 | def forward(ctx, unknown, known): 121 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 122 | r""" 123 | Find the three nearest neighbors of unknown in known 124 | Parameters 125 | ---------- 126 | unknown : torch.Tensor 127 | (B, n, 3) tensor of known features 128 | known : torch.Tensor 129 | (B, m, 3) tensor of unknown features 130 | 131 | Returns 132 | ------- 133 | dist : torch.Tensor 134 | (B, n, 3) l2 distance to the three nearest neighbors 135 | idx : torch.Tensor 136 | (B, n, 3) index of 3 nearest neighbors 137 | """ 138 | dist2, idx = _ext.three_nn(unknown, known) 139 | 140 | return torch.sqrt(dist2), idx 141 | 142 | @staticmethod 143 | def backward(ctx, a=None, b=None): 144 | return None, None 145 | 146 | 147 | three_nn = ThreeNN.apply 148 | 149 | 150 | class ThreeInterpolate(Function): 151 | @staticmethod 152 | def forward(ctx, features, idx, weight): 153 | # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor 154 | r""" 155 | Performs weight linear interpolation on 3 features 156 | 
Parameters 157 | ---------- 158 | features : torch.Tensor 159 | (B, c, m) Features descriptors to be interpolated from 160 | idx : torch.Tensor 161 | (B, n, 3) three nearest neighbors of the target features in features 162 | weight : torch.Tensor 163 | (B, n, 3) weights 164 | 165 | Returns 166 | ------- 167 | torch.Tensor 168 | (B, c, n) tensor of the interpolated features 169 | """ 170 | B, c, m = features.size() 171 | n = idx.size(1) 172 | 173 | ctx.three_interpolate_for_backward = (idx, weight, m) 174 | 175 | return _ext.three_interpolate(features, idx, weight) 176 | 177 | @staticmethod 178 | def backward(ctx, grad_out): 179 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 180 | r""" 181 | Parameters 182 | ---------- 183 | grad_out : torch.Tensor 184 | (B, c, n) tensor with gradients of ouputs 185 | 186 | Returns 187 | ------- 188 | grad_features : torch.Tensor 189 | (B, c, m) tensor with gradients of features 190 | 191 | None 192 | 193 | None 194 | """ 195 | idx, weight, m = ctx.three_interpolate_for_backward 196 | 197 | grad_features = _ext.three_interpolate_grad( 198 | grad_out.contiguous(), idx, weight, m 199 | ) 200 | 201 | return grad_features, None, None 202 | 203 | 204 | three_interpolate = ThreeInterpolate.apply 205 | 206 | 207 | class GroupingOperation(Function): 208 | @staticmethod 209 | def forward(ctx, features, idx): 210 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 211 | r""" 212 | 213 | Parameters 214 | ---------- 215 | features : torch.Tensor 216 | (B, C, N) tensor of features to group 217 | idx : torch.Tensor 218 | (B, npoint, nsample) tensor containing the indicies of features to group with 219 | 220 | Returns 221 | ------- 222 | torch.Tensor 223 | (B, C, npoint, nsample) tensor 224 | """ 225 | B, nfeatures, nsample = idx.size() 226 | _, C, N = features.size() 227 | 228 | ctx.for_backwards = (idx, N) 229 | 230 | return _ext.group_points(features, idx) 231 | 232 | @staticmethod 233 | def 
backward(ctx, grad_out): 234 | # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor] 235 | r""" 236 | 237 | Parameters 238 | ---------- 239 | grad_out : torch.Tensor 240 | (B, C, npoint, nsample) tensor of the gradients of the output from forward 241 | 242 | Returns 243 | ------- 244 | torch.Tensor 245 | (B, C, N) gradient of the features 246 | None 247 | """ 248 | idx, N = ctx.for_backwards 249 | 250 | grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N) 251 | 252 | return grad_features, None 253 | 254 | 255 | grouping_operation = GroupingOperation.apply 256 | 257 | 258 | class BallQuery(Function): 259 | @staticmethod 260 | def forward(ctx, radius, nsample, xyz, new_xyz): 261 | # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor 262 | r""" 263 | 264 | Parameters 265 | ---------- 266 | radius : float 267 | radius of the balls 268 | nsample : int 269 | maximum number of features in the balls 270 | xyz : torch.Tensor 271 | (B, N, 3) xyz coordinates of the features 272 | new_xyz : torch.Tensor 273 | (B, npoint, 3) centers of the ball query 274 | 275 | Returns 276 | ------- 277 | torch.Tensor 278 | (B, npoint, nsample) tensor with the indicies of the features that form the query balls 279 | """ 280 | return _ext.ball_query(new_xyz, xyz, radius, nsample) 281 | 282 | @staticmethod 283 | def backward(ctx, a=None): 284 | return None, None, None, None 285 | 286 | 287 | ball_query = BallQuery.apply 288 | 289 | 290 | class QueryAndGroup(nn.Module): 291 | r""" 292 | Groups with a ball query of radius 293 | 294 | Parameters 295 | --------- 296 | radius : float32 297 | Radius of ball 298 | nsample : int32 299 | Maximum number of features to gather in the ball 300 | """ 301 | 302 | def __init__(self, radius, nsample, use_xyz=True, ret_grouped_xyz=False, normalize_xyz=False, sample_uniformly=False, ret_unique_cnt=False): 303 | # type: (QueryAndGroup, float, int, bool) -> None 304 | super(QueryAndGroup, self).__init__() 305 | 
self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 306 | self.ret_grouped_xyz = ret_grouped_xyz 307 | self.normalize_xyz = normalize_xyz 308 | self.sample_uniformly = sample_uniformly 309 | self.ret_unique_cnt = ret_unique_cnt 310 | if self.ret_unique_cnt: 311 | assert(self.sample_uniformly) 312 | 313 | def forward(self, xyz, new_xyz, features=None): 314 | # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] 315 | r""" 316 | Parameters 317 | ---------- 318 | xyz : torch.Tensor 319 | xyz coordinates of the features (B, N, 3) 320 | new_xyz : torch.Tensor 321 | centriods (B, npoint, 3) 322 | features : torch.Tensor 323 | Descriptors of the features (B, C, N) 324 | 325 | Returns 326 | ------- 327 | new_features : torch.Tensor 328 | (B, 3 + C, npoint, nsample) tensor 329 | """ 330 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 331 | 332 | if self.sample_uniformly: 333 | unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) 334 | for i_batch in range(idx.shape[0]): 335 | for i_region in range(idx.shape[1]): 336 | unique_ind = torch.unique(idx[i_batch, i_region, :]) 337 | num_unique = unique_ind.shape[0] 338 | unique_cnt[i_batch, i_region] = num_unique 339 | sample_ind = torch.randint(0, num_unique, (self.nsample - num_unique,), dtype=torch.long) 340 | all_ind = torch.cat((unique_ind, unique_ind[sample_ind])) 341 | idx[i_batch, i_region, :] = all_ind 342 | 343 | 344 | xyz_trans = xyz.transpose(1, 2).contiguous() 345 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 346 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 347 | if self.normalize_xyz: 348 | grouped_xyz /= self.radius 349 | 350 | if features is not None: 351 | grouped_features = grouping_operation(features, idx) 352 | if self.use_xyz: 353 | new_features = torch.cat( 354 | [grouped_xyz, grouped_features], dim=1 355 | ) # (B, C + 3, npoint, nsample) 356 | else: 357 | new_features = grouped_features 358 | else: 359 | 
assert ( 360 | self.use_xyz 361 | ), "Cannot have not features and not use xyz as a feature!" 362 | new_features = grouped_xyz 363 | 364 | ret = [new_features] 365 | if self.ret_grouped_xyz: 366 | ret.append(grouped_xyz) 367 | if self.ret_unique_cnt: 368 | ret.append(unique_cnt) 369 | if len(ret) == 1: 370 | return ret[0] 371 | else: 372 | return tuple(ret) 373 | 374 | 375 | class GroupAll(nn.Module): 376 | r""" 377 | Groups all features 378 | 379 | Parameters 380 | --------- 381 | """ 382 | 383 | def __init__(self, use_xyz=True, ret_grouped_xyz=False): 384 | # type: (GroupAll, bool) -> None 385 | super(GroupAll, self).__init__() 386 | self.use_xyz = use_xyz 387 | 388 | def forward(self, xyz, new_xyz, features=None): 389 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] 390 | r""" 391 | Parameters 392 | ---------- 393 | xyz : torch.Tensor 394 | xyz coordinates of the features (B, N, 3) 395 | new_xyz : torch.Tensor 396 | Ignored 397 | features : torch.Tensor 398 | Descriptors of the features (B, C, N) 399 | 400 | Returns 401 | ------- 402 | new_features : torch.Tensor 403 | (B, C + 3, 1, N) tensor 404 | """ 405 | 406 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 407 | if features is not None: 408 | grouped_features = features.unsqueeze(2) 409 | if self.use_xyz: 410 | new_features = torch.cat( 411 | [grouped_xyz, grouped_features], dim=1 412 | ) # (B, 3 + C, 1, N) 413 | else: 414 | new_features = grouped_features 415 | else: 416 | new_features = grouped_xyz 417 | 418 | if self.ret_grouped_xyz: 419 | return new_features, grouped_xyz 420 | else: 421 | return new_features 422 | 423 | 424 | class CylinderQuery(Function): 425 | @staticmethod 426 | def forward(ctx, radius, hmin, hmax, nsample, xyz, new_xyz, rot): 427 | # type: (Any, float, float, float, int, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor 428 | r""" 429 | 430 | Parameters 431 | ---------- 432 | radius : float 433 | radius of the cylinders 434 | hmin, hmax : 
float 435 | endpoints of cylinder height in x-rotation axis 436 | nsample : int 437 | maximum number of features in the cylinders 438 | xyz : torch.Tensor 439 | (B, N, 3) xyz coordinates of the features 440 | new_xyz : torch.Tensor 441 | (B, npoint, 3) centers of the cylinder query 442 | rot: torch.Tensor 443 | (B, npoint, 9) flatten rotation matrices from 444 | cylinder frame to world frame 445 | 446 | Returns 447 | ------- 448 | torch.Tensor 449 | (B, npoint, nsample) tensor with the indicies of the features that form the query balls 450 | """ 451 | return _ext.cylinder_query(new_xyz, xyz, rot, radius, hmin, hmax, nsample) 452 | 453 | @staticmethod 454 | def backward(ctx, a=None): 455 | return None, None, None, None, None, None, None 456 | 457 | 458 | cylinder_query = CylinderQuery.apply 459 | 460 | 461 | class CylinderQueryAndGroup(nn.Module): 462 | r""" 463 | Groups with a cylinder query of radius and height 464 | 465 | Parameters 466 | --------- 467 | radius : float32 468 | Radius of cylinder 469 | hmin, hmax: float32 470 | endpoints of cylinder height in x-rotation axis 471 | nsample : int32 472 | Maximum number of features to gather in the ball 473 | """ 474 | 475 | def __init__(self, radius, hmin, hmax, nsample, use_xyz=True, ret_grouped_xyz=False, normalize_xyz=False, rotate_xyz=True, sample_uniformly=False, ret_unique_cnt=False): 476 | # type: (CylinderQueryAndGroup, float, float, float, int, bool) -> None 477 | super(CylinderQueryAndGroup, self).__init__() 478 | self.radius, self.nsample, self.hmin, self.hmax, = radius, nsample, hmin, hmax 479 | self.use_xyz = use_xyz 480 | self.ret_grouped_xyz = ret_grouped_xyz 481 | self.normalize_xyz = normalize_xyz 482 | self.rotate_xyz = rotate_xyz 483 | self.sample_uniformly = sample_uniformly 484 | self.ret_unique_cnt = ret_unique_cnt 485 | if self.ret_unique_cnt: 486 | assert(self.sample_uniformly) 487 | 488 | def forward(self, xyz, new_xyz, rot, features=None): 489 | # type: (QueryAndGroup, torch.Tensor. 
torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] 490 | r""" 491 | Parameters 492 | ---------- 493 | xyz : torch.Tensor 494 | xyz coordinates of the features (B, N, 3) 495 | new_xyz : torch.Tensor 496 | centriods (B, npoint, 3) 497 | rot : torch.Tensor 498 | rotation matrices (B, npoint, 3, 3) 499 | features : torch.Tensor 500 | Descriptors of the features (B, C, N) 501 | 502 | Returns 503 | ------- 504 | new_features : torch.Tensor 505 | (B, 3 + C, npoint, nsample) tensor 506 | """ 507 | B, npoint, _ = new_xyz.size() 508 | idx = cylinder_query(self.radius, self.hmin, self.hmax, self.nsample, xyz, new_xyz, rot.view(B, npoint, 9)) 509 | 510 | if self.sample_uniformly: 511 | unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) 512 | for i_batch in range(idx.shape[0]): 513 | for i_region in range(idx.shape[1]): 514 | unique_ind = torch.unique(idx[i_batch, i_region, :]) 515 | num_unique = unique_ind.shape[0] 516 | unique_cnt[i_batch, i_region] = num_unique 517 | sample_ind = torch.randint(0, num_unique, (self.nsample - num_unique,), dtype=torch.long) 518 | all_ind = torch.cat((unique_ind, unique_ind[sample_ind])) 519 | idx[i_batch, i_region, :] = all_ind 520 | 521 | 522 | xyz_trans = xyz.transpose(1, 2).contiguous() 523 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 524 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 525 | if self.normalize_xyz: 526 | grouped_xyz /= self.radius 527 | if self.rotate_xyz: 528 | grouped_xyz_ = grouped_xyz.permute(0, 2, 3, 1).contiguous() # (B, npoint, nsample, 3) 529 | grouped_xyz_ = torch.matmul(grouped_xyz_, rot) 530 | grouped_xyz = grouped_xyz_.permute(0, 3, 1, 2).contiguous() 531 | 532 | 533 | if features is not None: 534 | grouped_features = grouping_operation(features, idx) 535 | if self.use_xyz: 536 | new_features = torch.cat( 537 | [grouped_xyz, grouped_features], dim=1 538 | ) # (B, C + 3, npoint, nsample) 539 | else: 540 | new_features = grouped_features 541 | else: 542 | assert ( 543 | 
self.use_xyz 544 | ), "Cannot have not features and not use xyz as a feature!" 545 | new_features = grouped_xyz 546 | 547 | ret = [new_features] 548 | if self.ret_grouped_xyz: 549 | ret.append(grouped_xyz) 550 | if self.ret_unique_cnt: 551 | ret.append(unique_cnt) 552 | if len(ret) == 1: 553 | return ret[0] 554 | else: 555 | return tuple(ret) --------------------------------------------------------------------------------