├── LICENSE ├── README.md ├── dataset └── README.md ├── doc └── teaser.jpg ├── kitti ├── image_sets │ ├── test.txt │ ├── train.txt │ ├── trainval.txt │ └── val.txt ├── kitti_object.py ├── kitti_util.py ├── prepare_data.py └── rgb_detections │ ├── rgb_detection_train.txt │ └── rgb_detection_val.txt ├── mayavi ├── kitti_sample_scan.txt ├── mayavi_install.sh ├── test_drawline.py └── viz_util.py ├── models ├── frustum_pointnets_v1.py ├── frustum_pointnets_v2.py ├── model_util.py ├── pointnet_util.py ├── tf_ops │ ├── 3d_interpolation │ │ ├── interpolate.cpp │ │ ├── tf_interpolate.cpp │ │ ├── tf_interpolate.py │ │ ├── tf_interpolate_compile.sh │ │ ├── tf_interpolate_op_test.py │ │ └── visu_interpolation.py │ ├── grouping │ │ ├── compile.sh │ │ ├── query_ball_point.cpp │ │ ├── query_ball_point.cu │ │ ├── query_ball_point_block.cu │ │ ├── query_ball_point_grid.cu │ │ ├── selection_sort.cpp │ │ ├── selection_sort.cu │ │ ├── selection_sort_const.cu │ │ ├── test_knn.py │ │ ├── tf_grouping.cpp │ │ ├── tf_grouping.py │ │ ├── tf_grouping_compile.sh │ │ ├── tf_grouping_g.cu │ │ └── tf_grouping_op_test.py │ └── sampling │ │ ├── tf_sampling.cpp │ │ ├── tf_sampling.py │ │ ├── tf_sampling_compile.sh │ │ └── tf_sampling_g.cu └── tf_util.py ├── scripts ├── command_prep_data.sh ├── command_test_v1.sh ├── command_test_v2.sh ├── command_train_v1.sh └── command_train_v2.sh ├── sunrgbd ├── README.md ├── sunrgbd_data │ ├── box3d_dimensions.pickle │ ├── cluster_box3d.py │ ├── matlab │ │ ├── README.txt │ │ ├── SUNRGBDtoolbox │ │ │ ├── README.txt │ │ │ ├── demo.m │ │ │ ├── draw │ │ │ │ ├── drawRoom.m │ │ │ │ ├── drawSeg.m │ │ │ │ ├── draw_square_3d.m │ │ │ │ ├── myObjectColor.m │ │ │ │ ├── plotcube.m │ │ │ │ ├── vis_cube.m │ │ │ │ ├── vis_line.m │ │ │ │ ├── vis_point_cloud.m │ │ │ │ └── visulize_wholeroom.m │ │ │ ├── extract_rgbd_data.m │ │ │ ├── getSequenceName.m │ │ │ ├── jsonlab │ │ │ │ ├── AUTHORS.txt │ │ │ │ ├── ChangeLog.txt │ │ │ │ ├── LICENSE_BSD.txt │ │ │ │ ├── README.txt │ │ │ │ ├── 
jsonopt.m │ │ │ │ ├── loadjson.m │ │ │ │ ├── loadubjson.m │ │ │ │ ├── mergestruct.m │ │ │ │ ├── savejson.m │ │ │ │ ├── saveubjson.m │ │ │ │ └── varargin2struct.m │ │ │ ├── mBB │ │ │ │ ├── PolygonClip.dll │ │ │ │ ├── PolygonClip.mexa64 │ │ │ │ ├── PolygonClip.mexmaci64 │ │ │ │ ├── bb3dOverlapCloseForm.m │ │ │ │ ├── create_bounding_box_3d.m │ │ │ │ ├── cuboidIntersectionVolume.c │ │ │ │ ├── cuboidIntersectionVolume.mexa64 │ │ │ │ ├── cuboidIntersectionVolume.mexmaci64 │ │ │ │ ├── cuboidVolume.m │ │ │ │ ├── get_corners_of_bb3d.m │ │ │ │ ├── project3dPtsTo2d.m │ │ │ │ └── projectStructBbsTo2d.m │ │ │ ├── order_basis.m │ │ │ ├── readData │ │ │ │ ├── read3dPoints.m │ │ │ │ └── read_3d_pts_general.m │ │ │ ├── readframeSUNRGBD.m │ │ │ └── utils │ │ │ │ ├── file2string.m │ │ │ │ └── findsubstring.m │ │ └── detection │ │ │ ├── benchmark_groundtruth.m │ │ │ ├── computePRCurve3D.m │ │ │ ├── extract_gt_boxes.m │ │ │ ├── get_average_precision.m │ │ │ └── script_3Deval.m │ ├── sunrgbd_data.py │ └── utils.py └── sunrgbd_detection │ ├── ap_curves │ ├── figure_1.png │ ├── figure_10.png │ ├── figure_2.png │ ├── figure_3.png │ ├── figure_4.png │ ├── figure_5.png │ ├── figure_6.png │ ├── figure_7.png │ ├── figure_8.png │ └── figure_9.png │ ├── compare_matlab_and_python_eval.py │ ├── eval_det.py │ ├── evaluate.py │ ├── frustum_pointnets_v1_sunrgbd.py │ ├── gt_boxes │ ├── bathtub_gt_boxes.dat │ ├── bathtub_gt_imgids.txt │ ├── bed_gt_boxes.dat │ ├── bed_gt_imgids.txt │ ├── bookshelf_gt_boxes.dat │ ├── bookshelf_gt_imgids.txt │ ├── chair_gt_boxes.dat │ ├── chair_gt_imgids.txt │ ├── desk_gt_boxes.dat │ ├── desk_gt_imgids.txt │ ├── dresser_gt_boxes.dat │ ├── dresser_gt_imgids.txt │ ├── night_stand_gt_boxes.dat │ ├── night_stand_gt_imgids.txt │ ├── sofa_gt_boxes.dat │ ├── sofa_gt_imgids.txt │ ├── table_gt_boxes.dat │ ├── table_gt_imgids.txt │ ├── toilet_gt_boxes.dat │ └── toilet_gt_imgids.txt │ ├── model_util_sunrgbd.py │ ├── roi_seg_box3d_dataset.py │ ├── test_one_hot.py │ ├── 
train_one_hot.py │ ├── train_util.py │ ├── viz.py │ └── viz_eval.py └── train ├── box_util.py ├── kitti_eval ├── README.md ├── compile.sh ├── evaluate_object_3d_offline ├── evaluate_object_3d_offline.cpp └── mail.h ├── provider.py ├── test.py ├── train.py └── train_util.py /dataset/README.md: -------------------------------------------------------------------------------- 1 | Download KITTI 3D object detection data and organize the folders as follows: 2 | 3 | dataset/KITTI/object/ 4 | 5 | training/ 6 | calib/ 7 | image_2/ 8 | label_2/ 9 | velodyne/ 10 | 11 | testing/ 12 | calib/ 13 | image_2/ 14 | velodyne/ 15 | -------------------------------------------------------------------------------- /doc/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charlesq34/frustum-pointnets/2ffdd345e1fce4775ecb508d207e0ad465bcca80/doc/teaser.jpg -------------------------------------------------------------------------------- /kitti/kitti_object.py: -------------------------------------------------------------------------------- 1 | ''' Helper class and functions for loading KITTI objects 2 | 3 | Author: Charles R. 
Qi 4 | Date: September 2017 5 | ''' 6 | from __future__ import print_function 7 | 8 | import os 9 | import sys 10 | import numpy as np 11 | import cv2 12 | from PIL import Image 13 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 14 | ROOT_DIR = os.path.dirname(BASE_DIR) 15 | sys.path.append(os.path.join(ROOT_DIR, 'mayavi')) 16 | import kitti_util as utils 17 | 18 | try: 19 | raw_input # Python 2 20 | except NameError: 21 | raw_input = input # Python 3 22 | 23 | 24 | class kitti_object(object): 25 | '''Load and parse object data into a usable format.''' 26 | 27 | def __init__(self, root_dir, split='training'): 28 | '''root_dir contains training and testing folders''' 29 | self.root_dir = root_dir 30 | self.split = split 31 | self.split_dir = os.path.join(root_dir, split) 32 | 33 | if split == 'training': 34 | self.num_samples = 7481 35 | elif split == 'testing': 36 | self.num_samples = 7518 37 | else: 38 | print('Unknown split: %s' % (split)) 39 | exit(-1) 40 | 41 | self.image_dir = os.path.join(self.split_dir, 'image_2') 42 | self.calib_dir = os.path.join(self.split_dir, 'calib') 43 | self.lidar_dir = os.path.join(self.split_dir, 'velodyne') 44 | self.label_dir = os.path.join(self.split_dir, 'label_2') 45 | 46 | def __len__(self): 47 | return self.num_samples 48 | 49 | def get_image(self, idx): 50 | assert(idx=xmin) & \ 143 | (pts_2d[:,1]=ymin) 144 | fov_inds = fov_inds & (pc_velo[:,0]>clip_distance) 145 | imgfov_pc_velo = pc_velo[fov_inds,:] 146 | if return_more: 147 | return imgfov_pc_velo, pts_2d, fov_inds 148 | else: 149 | return imgfov_pc_velo 150 | 151 | def show_lidar_with_boxes(pc_velo, objects, calib, 152 | img_fov=False, img_width=None, img_height=None): 153 | ''' Show all LiDAR points. 
154 | Draw 3d box in LiDAR point cloud (in velo coord system) ''' 155 | if 'mlab' not in sys.modules: import mayavi.mlab as mlab 156 | from viz_util import draw_lidar_simple, draw_lidar, draw_gt_boxes3d 157 | 158 | print(('All point num: ', pc_velo.shape[0])) 159 | fig = mlab.figure(figure=None, bgcolor=(0,0,0), 160 | fgcolor=None, engine=None, size=(1000, 500)) 161 | if img_fov: 162 | pc_velo = get_lidar_in_image_fov(pc_velo, calib, 0, 0, 163 | img_width, img_height) 164 | print(('FOV point num: ', pc_velo.shape[0])) 165 | draw_lidar(pc_velo, fig=fig) 166 | 167 | for obj in objects: 168 | if obj.type=='DontCare':continue 169 | # Draw 3d bounding box 170 | box3d_pts_2d, box3d_pts_3d = utils.compute_box_3d(obj, calib.P) 171 | box3d_pts_3d_velo = calib.project_rect_to_velo(box3d_pts_3d) 172 | # Draw heading arrow 173 | ori3d_pts_2d, ori3d_pts_3d = utils.compute_orientation_3d(obj, calib.P) 174 | ori3d_pts_3d_velo = calib.project_rect_to_velo(ori3d_pts_3d) 175 | x1,y1,z1 = ori3d_pts_3d_velo[0,:] 176 | x2,y2,z2 = ori3d_pts_3d_velo[1,:] 177 | draw_gt_boxes3d([box3d_pts_3d_velo], fig=fig) 178 | mlab.plot3d([x1, x2], [y1, y2], [z1,z2], color=(0.5,0.5,0.5), 179 | tube_radius=None, line_width=1, figure=fig) 180 | mlab.show(1) 181 | 182 | def show_lidar_on_image(pc_velo, img, calib, img_width, img_height): 183 | ''' Project LiDAR points to image ''' 184 | imgfov_pc_velo, pts_2d, fov_inds = get_lidar_in_image_fov(pc_velo, 185 | calib, 0, 0, img_width, img_height, True) 186 | imgfov_pts_2d = pts_2d[fov_inds,:] 187 | imgfov_pc_rect = calib.project_velo_to_rect(imgfov_pc_velo) 188 | 189 | import matplotlib.pyplot as plt 190 | cmap = plt.cm.get_cmap('hsv', 256) 191 | cmap = np.array([cmap(i) for i in range(256)])[:,:3]*255 192 | 193 | for i in range(imgfov_pts_2d.shape[0]): 194 | depth = imgfov_pc_rect[i,2] 195 | color = cmap[int(640.0/depth),:] 196 | cv2.circle(img, (int(np.round(imgfov_pts_2d[i,0])), 197 | int(np.round(imgfov_pts_2d[i,1]))), 198 | 2, color=tuple(color), 
thickness=-1) 199 | Image.fromarray(img).show() 200 | return img 201 | 202 | def dataset_viz(): 203 | dataset = kitti_object(os.path.join(ROOT_DIR, 'dataset/KITTI/object')) 204 | 205 | for data_idx in range(len(dataset)): 206 | # Load data from dataset 207 | objects = dataset.get_label_objects(data_idx) 208 | objects[0].print_object() 209 | img = dataset.get_image(data_idx) 210 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 211 | img_height, img_width, img_channel = img.shape 212 | print(('Image shape: ', img.shape)) 213 | pc_velo = dataset.get_lidar(data_idx)[:,0:3] 214 | calib = dataset.get_calibration(data_idx) 215 | 216 | # Draw 2d and 3d boxes on image 217 | show_image_with_boxes(img, objects, calib, False) 218 | raw_input() 219 | # Show all LiDAR points. Draw 3d box in LiDAR point cloud 220 | show_lidar_with_boxes(pc_velo, objects, calib, True, img_width, img_height) 221 | raw_input() 222 | 223 | if __name__=='__main__': 224 | import mayavi.mlab as mlab 225 | from viz_util import draw_lidar_simple, draw_lidar, draw_gt_boxes3d 226 | dataset_viz() 227 | -------------------------------------------------------------------------------- /mayavi/mayavi_install.sh: -------------------------------------------------------------------------------- 1 | #/bin/bash 2 | # Install Mayavi (scientific data visualization and plotting in Python) on Ubuntu 3 | # Ref: http://docs.enthought.com/mayavi/mayavi/installation.html 4 | sudo apt-get install python-vtk python-qt4 python-qt4-gl python-setuptools python-numpy python-configobj 5 | sudo pip install mayavi 6 | -------------------------------------------------------------------------------- /mayavi/test_drawline.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from mayavi.mlab import * 3 | 4 | def test_plot3d(): 5 | """Generates a pretty set of lines.""" 6 | n_mer, n_long = 6, 11 7 | pi = numpy.pi 8 | dphi = pi / 1000.0 9 | phi = numpy.arange(0.0, 2 * pi + 0.5 * dphi, dphi) 
10 | mu = phi * n_mer 11 | x = numpy.cos(mu) * (1 + numpy.cos(n_long * mu / n_mer) * 0.5) 12 | y = numpy.sin(mu) * (1 + numpy.cos(n_long * mu / n_mer) * 0.5) 13 | z = numpy.sin(n_long * mu / n_mer) * 0.5 14 | 15 | l = plot3d(x, y, z, numpy.sin(mu), tube_radius=0.025, colormap='Spectral') 16 | return l 17 | 18 | test_plot3d() 19 | raw_input() 20 | -------------------------------------------------------------------------------- /mayavi/viz_util.py: -------------------------------------------------------------------------------- 1 | ''' Visualization code for point clouds and 3D bounding boxes with mayavi. 2 | 3 | Modified by Charles R. Qi 4 | Date: September 2017 5 | 6 | Ref: https://github.com/hengck23/didi-udacity-2017/blob/master/baseline-04/kitti_data/draw.py 7 | ''' 8 | 9 | import numpy as np 10 | import mayavi.mlab as mlab 11 | 12 | try: 13 | raw_input # Python 2 14 | except NameError: 15 | raw_input = input # Python 3 16 | 17 | 18 | def draw_lidar_simple(pc, color=None): 19 | ''' Draw lidar points. simplest set up. 
''' 20 | fig = mlab.figure(figure=None, bgcolor=(0,0,0), fgcolor=None, engine=None, size=(1600, 1000)) 21 | if color is None: color = pc[:,2] 22 | #draw points 23 | mlab.points3d(pc[:,0], pc[:,1], pc[:,2], color, color=None, mode='point', colormap = 'gnuplot', scale_factor=1, figure=fig) 24 | #draw origin 25 | mlab.points3d(0, 0, 0, color=(1,1,1), mode='sphere', scale_factor=0.2) 26 | #draw axis 27 | axes=np.array([ 28 | [2.,0.,0.,0.], 29 | [0.,2.,0.,0.], 30 | [0.,0.,2.,0.], 31 | ],dtype=np.float64) 32 | mlab.plot3d([0, axes[0,0]], [0, axes[0,1]], [0, axes[0,2]], color=(1,0,0), tube_radius=None, figure=fig) 33 | mlab.plot3d([0, axes[1,0]], [0, axes[1,1]], [0, axes[1,2]], color=(0,1,0), tube_radius=None, figure=fig) 34 | mlab.plot3d([0, axes[2,0]], [0, axes[2,1]], [0, axes[2,2]], color=(0,0,1), tube_radius=None, figure=fig) 35 | mlab.view(azimuth=180, elevation=70, focalpoint=[ 12.0909996 , -1.04700089, -2.03249991], distance=62.0, figure=fig) 36 | return fig 37 | 38 | def draw_lidar(pc, color=None, fig=None, bgcolor=(0,0,0), pts_scale=1, pts_mode='point', pts_color=None): 39 | ''' Draw lidar points 40 | Args: 41 | pc: numpy array (n,3) of XYZ 42 | color: numpy array (n) of intensity or whatever 43 | fig: mayavi figure handler, if None create new one otherwise will use it 44 | Returns: 45 | fig: created or used fig 46 | ''' 47 | if fig is None: fig = mlab.figure(figure=None, bgcolor=bgcolor, fgcolor=None, engine=None, size=(1600, 1000)) 48 | if color is None: color = pc[:,2] 49 | mlab.points3d(pc[:,0], pc[:,1], pc[:,2], color, color=pts_color, mode=pts_mode, colormap = 'gnuplot', scale_factor=pts_scale, figure=fig) 50 | 51 | #draw origin 52 | mlab.points3d(0, 0, 0, color=(1,1,1), mode='sphere', scale_factor=0.2) 53 | 54 | #draw axis 55 | axes=np.array([ 56 | [2.,0.,0.,0.], 57 | [0.,2.,0.,0.], 58 | [0.,0.,2.,0.], 59 | ],dtype=np.float64) 60 | mlab.plot3d([0, axes[0,0]], [0, axes[0,1]], [0, axes[0,2]], color=(1,0,0), tube_radius=None, figure=fig) 61 | mlab.plot3d([0, 
axes[1,0]], [0, axes[1,1]], [0, axes[1,2]], color=(0,1,0), tube_radius=None, figure=fig) 62 | mlab.plot3d([0, axes[2,0]], [0, axes[2,1]], [0, axes[2,2]], color=(0,0,1), tube_radius=None, figure=fig) 63 | 64 | # draw fov (todo: update to real sensor spec.) 65 | fov=np.array([ # 45 degree 66 | [20., 20., 0.,0.], 67 | [20.,-20., 0.,0.], 68 | ],dtype=np.float64) 69 | 70 | mlab.plot3d([0, fov[0,0]], [0, fov[0,1]], [0, fov[0,2]], color=(1,1,1), tube_radius=None, line_width=1, figure=fig) 71 | mlab.plot3d([0, fov[1,0]], [0, fov[1,1]], [0, fov[1,2]], color=(1,1,1), tube_radius=None, line_width=1, figure=fig) 72 | 73 | # draw square region 74 | TOP_Y_MIN=-20 75 | TOP_Y_MAX=20 76 | TOP_X_MIN=0 77 | TOP_X_MAX=40 78 | TOP_Z_MIN=-2.0 79 | TOP_Z_MAX=0.4 80 | 81 | x1 = TOP_X_MIN 82 | x2 = TOP_X_MAX 83 | y1 = TOP_Y_MIN 84 | y2 = TOP_Y_MAX 85 | mlab.plot3d([x1, x1], [y1, y2], [0,0], color=(0.5,0.5,0.5), tube_radius=0.1, line_width=1, figure=fig) 86 | mlab.plot3d([x2, x2], [y1, y2], [0,0], color=(0.5,0.5,0.5), tube_radius=0.1, line_width=1, figure=fig) 87 | mlab.plot3d([x1, x2], [y1, y1], [0,0], color=(0.5,0.5,0.5), tube_radius=0.1, line_width=1, figure=fig) 88 | mlab.plot3d([x1, x2], [y2, y2], [0,0], color=(0.5,0.5,0.5), tube_radius=0.1, line_width=1, figure=fig) 89 | 90 | #mlab.orientation_axes() 91 | mlab.view(azimuth=180, elevation=70, focalpoint=[ 12.0909996 , -1.04700089, -2.03249991], distance=62.0, figure=fig) 92 | return fig 93 | 94 | def draw_gt_boxes3d(gt_boxes3d, fig, color=(1,1,1), line_width=1, draw_text=True, text_scale=(1,1,1), color_list=None): 95 | ''' Draw 3D bounding boxes 96 | Args: 97 | gt_boxes3d: numpy array (n,8,3) for XYZs of the box corners 98 | fig: mayavi figure handler 99 | color: RGB value tuple in range (0,1), box line color 100 | line_width: box line width 101 | draw_text: boolean, if true, write box indices beside boxes 102 | text_scale: three number tuple 103 | color_list: a list of RGB tuple, if not None, overwrite color. 
104 | Returns: 105 | fig: updated fig 106 | ''' 107 | num = len(gt_boxes3d) 108 | for n in range(num): 109 | b = gt_boxes3d[n] 110 | if color_list is not None: 111 | color = color_list[n] 112 | if draw_text: mlab.text3d(b[4,0], b[4,1], b[4,2], '%d'%n, scale=text_scale, color=color, figure=fig) 113 | for k in range(0,4): 114 | #http://docs.enthought.com/mayavi/mayavi/auto/mlab_helper_functions.html 115 | i,j=k,(k+1)%4 116 | mlab.plot3d([b[i,0], b[j,0]], [b[i,1], b[j,1]], [b[i,2], b[j,2]], color=color, tube_radius=None, line_width=line_width, figure=fig) 117 | 118 | i,j=k+4,(k+1)%4 + 4 119 | mlab.plot3d([b[i,0], b[j,0]], [b[i,1], b[j,1]], [b[i,2], b[j,2]], color=color, tube_radius=None, line_width=line_width, figure=fig) 120 | 121 | i,j=k,k+4 122 | mlab.plot3d([b[i,0], b[j,0]], [b[i,1], b[j,1]], [b[i,2], b[j,2]], color=color, tube_radius=None, line_width=line_width, figure=fig) 123 | #mlab.show(1) 124 | #mlab.view(azimuth=180, elevation=70, focalpoint=[ 12.0909996 , -1.04700089, -2.03249991], distance=62.0, figure=fig) 125 | return fig 126 | 127 | 128 | if __name__=='__main__': 129 | pc = np.loadtxt('mayavi/kitti_sample_scan.txt') 130 | fig = draw_lidar(pc) 131 | mlab.savefig('pc_view.jpg', figure=fig) 132 | raw_input() 133 | -------------------------------------------------------------------------------- /models/frustum_pointnets_v2.py: -------------------------------------------------------------------------------- 1 | ''' Frustum PointNets v2 Model. 
2 | ''' 3 | from __future__ import print_function 4 | 5 | import sys 6 | import os 7 | import tensorflow as tf 8 | import numpy as np 9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | ROOT_DIR = os.path.dirname(BASE_DIR) 11 | sys.path.append(BASE_DIR) 12 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 13 | import tf_util 14 | from pointnet_util import pointnet_sa_module, pointnet_sa_module_msg, pointnet_fp_module 15 | from model_util import NUM_HEADING_BIN, NUM_SIZE_CLUSTER, NUM_OBJECT_POINT 16 | from model_util import point_cloud_masking, get_center_regression_net 17 | from model_util import placeholder_inputs, parse_output_to_tensors, get_loss 18 | 19 | 20 | def get_instance_seg_v2_net(point_cloud, one_hot_vec, 21 | is_training, bn_decay, end_points): 22 | ''' 3D instance segmentation PointNet v2 network. 23 | Input: 24 | point_cloud: TF tensor in shape (B,N,4) 25 | frustum point clouds with XYZ and intensity in point channels 26 | XYZs are in frustum coordinate 27 | one_hot_vec: TF tensor in shape (B,3) 28 | length-3 vectors indicating predicted object type 29 | is_training: TF boolean scalar 30 | bn_decay: TF float scalar 31 | end_points: dict 32 | Output: 33 | logits: TF tensor in shape (B,N,2), scores for bkg/clutter and object 34 | end_points: dict 35 | ''' 36 | 37 | l0_xyz = tf.slice(point_cloud, [0,0,0], [-1,-1,3]) 38 | l0_points = tf.slice(point_cloud, [0,0,3], [-1,-1,1]) 39 | 40 | # Set abstraction layers 41 | l1_xyz, l1_points = pointnet_sa_module_msg(l0_xyz, l0_points, 42 | 128, [0.2,0.4,0.8], [32,64,128], 43 | [[32,32,64], [64,64,128], [64,96,128]], 44 | is_training, bn_decay, scope='layer1') 45 | l2_xyz, l2_points = pointnet_sa_module_msg(l1_xyz, l1_points, 46 | 32, [0.4,0.8,1.6], [64,64,128], 47 | [[64,64,128], [128,128,256], [128,128,256]], 48 | is_training, bn_decay, scope='layer2') 49 | l3_xyz, l3_points, _ = pointnet_sa_module(l2_xyz, l2_points, 50 | npoint=None, radius=None, nsample=None, mlp=[128,256,1024], 51 | mlp2=None, 
group_all=True, is_training=is_training, 52 | bn_decay=bn_decay, scope='layer3') 53 | 54 | # Feature Propagation layers 55 | l3_points = tf.concat([l3_points, tf.expand_dims(one_hot_vec, 1)], axis=2) 56 | l2_points = pointnet_fp_module(l2_xyz, l3_xyz, l2_points, l3_points, 57 | [128,128], is_training, bn_decay, scope='fa_layer1') 58 | l1_points = pointnet_fp_module(l1_xyz, l2_xyz, l1_points, l2_points, 59 | [128,128], is_training, bn_decay, scope='fa_layer2') 60 | l0_points = pointnet_fp_module(l0_xyz, l1_xyz, 61 | tf.concat([l0_xyz,l0_points],axis=-1), l1_points, 62 | [128,128], is_training, bn_decay, scope='fa_layer3') 63 | 64 | # FC layers 65 | net = tf_util.conv1d(l0_points, 128, 1, padding='VALID', bn=True, 66 | is_training=is_training, scope='conv1d-fc1', bn_decay=bn_decay) 67 | end_points['feats'] = net 68 | net = tf_util.dropout(net, keep_prob=0.7, 69 | is_training=is_training, scope='dp1') 70 | logits = tf_util.conv1d(net, 2, 1, 71 | padding='VALID', activation_fn=None, scope='conv1d-fc2') 72 | 73 | return logits, end_points 74 | 75 | def get_3d_box_estimation_v2_net(object_point_cloud, one_hot_vec, 76 | is_training, bn_decay, end_points): 77 | ''' 3D Box Estimation PointNet v2 network. 
78 | Input: 79 | object_point_cloud: TF tensor in shape (B,M,C) 80 | masked point clouds in object coordinate 81 | one_hot_vec: TF tensor in shape (B,3) 82 | length-3 vectors indicating predicted object type 83 | Output: 84 | output: TF tensor in shape (B,3+NUM_HEADING_BIN*2+NUM_SIZE_CLUSTER*4) 85 | including box centers, heading bin class scores and residuals, 86 | and size cluster scores and residuals 87 | ''' 88 | # Gather object points 89 | batch_size = object_point_cloud.get_shape()[0].value 90 | 91 | l0_xyz = object_point_cloud 92 | l0_points = None 93 | # Set abstraction layers 94 | l1_xyz, l1_points, l1_indices = pointnet_sa_module(l0_xyz, l0_points, 95 | npoint=128, radius=0.2, nsample=64, mlp=[64,64,128], 96 | mlp2=None, group_all=False, 97 | is_training=is_training, bn_decay=bn_decay, scope='ssg-layer1') 98 | l2_xyz, l2_points, l2_indices = pointnet_sa_module(l1_xyz, l1_points, 99 | npoint=32, radius=0.4, nsample=64, mlp=[128,128,256], 100 | mlp2=None, group_all=False, 101 | is_training=is_training, bn_decay=bn_decay, scope='ssg-layer2') 102 | l3_xyz, l3_points, l3_indices = pointnet_sa_module(l2_xyz, l2_points, 103 | npoint=None, radius=None, nsample=None, mlp=[256,256,512], 104 | mlp2=None, group_all=True, 105 | is_training=is_training, bn_decay=bn_decay, scope='ssg-layer3') 106 | 107 | # Fully connected layers 108 | net = tf.reshape(l3_points, [batch_size, -1]) 109 | net = tf.concat([net, one_hot_vec], axis=1) 110 | net = tf_util.fully_connected(net, 512, bn=True, 111 | is_training=is_training, scope='fc1', bn_decay=bn_decay) 112 | net = tf_util.fully_connected(net, 256, bn=True, 113 | is_training=is_training, scope='fc2', bn_decay=bn_decay) 114 | 115 | # The first 3 numbers: box center coordinates (cx,cy,cz), 116 | # the next NUM_HEADING_BIN*2: heading bin class scores and bin residuals 117 | # next NUM_SIZE_CLUSTER*4: box cluster scores and residuals 118 | output = tf_util.fully_connected(net, 119 | 3+NUM_HEADING_BIN*2+NUM_SIZE_CLUSTER*4, 
activation_fn=None, scope='fc3') 120 | return output, end_points 121 | 122 | 123 | def get_model(point_cloud, one_hot_vec, is_training, bn_decay=None): 124 | ''' Frustum PointNets model. The model predict 3D object masks and 125 | amodel bounding boxes for objects in frustum point clouds. 126 | 127 | Input: 128 | point_cloud: TF tensor in shape (B,N,4) 129 | frustum point clouds with XYZ and intensity in point channels 130 | XYZs are in frustum coordinate 131 | one_hot_vec: TF tensor in shape (B,3) 132 | length-3 vectors indicating predicted object type 133 | is_training: TF boolean scalar 134 | bn_decay: TF float scalar 135 | Output: 136 | end_points: dict (map from name strings to TF tensors) 137 | ''' 138 | end_points = {} 139 | 140 | # 3D Instance Segmentation PointNet 141 | logits, end_points = get_instance_seg_v2_net(\ 142 | point_cloud, one_hot_vec, 143 | is_training, bn_decay, end_points) 144 | end_points['mask_logits'] = logits 145 | 146 | # Masking 147 | # select masked points and translate to masked points' centroid 148 | object_point_cloud_xyz, mask_xyz_mean, end_points = \ 149 | point_cloud_masking(point_cloud, logits, end_points) 150 | 151 | # T-Net and coordinate translation 152 | center_delta, end_points = get_center_regression_net(\ 153 | object_point_cloud_xyz, one_hot_vec, 154 | is_training, bn_decay, end_points) 155 | stage1_center = center_delta + mask_xyz_mean # Bx3 156 | end_points['stage1_center'] = stage1_center 157 | # Get object point cloud in object coordinate 158 | object_point_cloud_xyz_new = \ 159 | object_point_cloud_xyz - tf.expand_dims(center_delta, 1) 160 | 161 | # Amodel Box Estimation PointNet 162 | output, end_points = get_3d_box_estimation_v2_net(\ 163 | object_point_cloud_xyz_new, one_hot_vec, 164 | is_training, bn_decay, end_points) 165 | 166 | # Parse output to 3D box parameters 167 | end_points = parse_output_to_tensors(output, end_points) 168 | end_points['center'] = end_points['center_boxnet'] + stage1_center # Bx3 169 | 
170 | return end_points 171 | 172 | if __name__=='__main__': 173 | with tf.Graph().as_default(): 174 | inputs = tf.zeros((32,1024,4)) 175 | outputs = get_model(inputs, tf.ones((32,3)), tf.constant(True)) 176 | for key in outputs: 177 | print((key, outputs[key])) 178 | loss = get_loss(tf.zeros((32,1024),dtype=tf.int32), 179 | tf.zeros((32,3)), tf.zeros((32,),dtype=tf.int32), 180 | tf.zeros((32,)), tf.zeros((32,),dtype=tf.int32), 181 | tf.zeros((32,3)), outputs) 182 | print(loss) 183 | -------------------------------------------------------------------------------- /models/tf_ops/3d_interpolation/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | 18 | // Find three nearest neigbors with square distance 19 | // input: xyz1 (b,n,3), xyz2(b,m,3) 20 | // output: dist (b,n,3), idx (b,n,3) 21 | void threenn_cpu(int b, int n, int m, const float *xyz1, const float *xyz2, float *dist, int *idx) { 22 | for (int i=0;i 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | // input: radius (1), nsample (1), xyz1 (b,n,3), xyz2 (b,m,3) 18 | // output: idx (b,m,nsample) 19 | void query_ball_point_cpu(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx) { 20 | for (int i=0;i 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | 
#include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | // input: radius (1), nsample (1), xyz1 (b,n,3), xyz2 (b,m,3) 18 | // output: idx (b,m,nsample) 19 | __global__ void query_ball_point_gpu(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx) { 20 | for (int i=0;i>>(b,n,m,radius,nsample,xyz1,xyz2,idx); 113 | cudaDeviceSynchronize(); 114 | printf("query_ball_point gpu time %f\n",get_time()-t0); 115 | 116 | t0=get_time(); 117 | group_point_gpu<<<1,1>>>(b,n,c,m,nsample,points,idx,out); 118 | cudaDeviceSynchronize(); 119 | printf("grou_point gpu time %f\n",get_time()-t0); 120 | 121 | t0=get_time(); 122 | group_point_grad_gpu<<<1,1>>>(b,n,c,m,nsample,grad_out,idx,grad_points); 123 | cudaDeviceSynchronize(); 124 | printf("grou_point_grad gpu time %f\n",get_time()-t0); 125 | 126 | cudaFree(xyz1); 127 | cudaFree(xyz2); 128 | cudaFree(points); 129 | cudaFree(idx); 130 | cudaFree(out); 131 | cudaFree(grad_out); 132 | cudaFree(grad_points); 133 | return 0; 134 | } 135 | -------------------------------------------------------------------------------- /models/tf_ops/grouping/query_ball_point_block.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | // input: radius (1), nsample (1), xyz1 (b,n,3), xyz2 (b,m,3) 18 | // output: idx (b,m,nsample) 19 | __global__ void query_ball_point_gpu(int b, int n, int m, float radius, int 
nsample, const float *xyz1, const float *xyz2, int *idx) { 20 | int index = threadIdx.x; 21 | xyz1 += n*3*index; 22 | xyz2 += m*3*index; 23 | idx += m*nsample*index; 24 | 25 | for (int j=0;j>>(b,n,m,radius,nsample,xyz1,xyz2,idx); 113 | cudaDeviceSynchronize(); 114 | printf("query_ball_point gpu time %f\n",get_time()-t0); 115 | 116 | t0=get_time(); 117 | group_point_gpu<<<1,b>>>(b,n,c,m,nsample,points,idx,out); 118 | cudaDeviceSynchronize(); 119 | printf("grou_point gpu time %f\n",get_time()-t0); 120 | 121 | t0=get_time(); 122 | group_point_grad_gpu<<<1,b>>>(b,n,c,m,nsample,grad_out,idx,grad_points); 123 | cudaDeviceSynchronize(); 124 | printf("grou_point_grad gpu time %f\n",get_time()-t0); 125 | 126 | cudaFree(xyz1); 127 | cudaFree(xyz2); 128 | cudaFree(points); 129 | cudaFree(idx); 130 | cudaFree(out); 131 | cudaFree(grad_out); 132 | cudaFree(grad_points); 133 | return 0; 134 | } 135 | -------------------------------------------------------------------------------- /models/tf_ops/grouping/query_ball_point_grid.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | // input: radius (1), nsample (1), xyz1 (b,n,3), xyz2 (b,m,3) 18 | // output: idx (b,m,nsample) 19 | __global__ void query_ball_point_gpu(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx) { 20 | int batch_index = blockIdx.x; 21 | xyz1 += n*3*batch_index; 22 | xyz2 += m*3*batch_index; 23 | idx += m*nsample*batch_index; 24 | 25 | int index = threadIdx.x; 26 | int stride = blockDim.x; 27 | 28 | for (int j=index;j>>(b,n,m,radius,nsample,xyz1,xyz2,idx); 123 | 
cudaDeviceSynchronize(); 124 | printf("query_ball_point gpu time %f\n",get_time()-t0); 125 | 126 | t0=get_time(); 127 | group_point_gpu<<>>(b,n,c,m,nsample,points,idx,out); 128 | cudaDeviceSynchronize(); 129 | printf("grou_point gpu time %f\n",get_time()-t0); 130 | 131 | t0=get_time(); 132 | group_point_grad_gpu<<>>(b,n,c,m,nsample,grad_out,idx,grad_points); 133 | cudaDeviceSynchronize(); 134 | printf("grou_point_grad gpu time %f\n",get_time()-t0); 135 | 136 | cudaFree(xyz1); 137 | cudaFree(xyz2); 138 | cudaFree(points); 139 | cudaFree(idx); 140 | cudaFree(out); 141 | cudaFree(grad_out); 142 | cudaFree(grad_points); 143 | return 0; 144 | } 145 | -------------------------------------------------------------------------------- /models/tf_ops/grouping/selection_sort.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | 18 | // input: k (1), distance matrix dist (b,m,n) 19 | // output: idx (b,m,n), val (b,m,n) 20 | void selection_sort_cpu(int b, int n, int m, int k, const float *dist, int *idx, float *val) { 21 | float *p_dist; 22 | float tmp; 23 | int tmpi; 24 | for (int i=0;i 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | 18 | // input: k (1), distance matrix dist (b,m,n) 19 | // output: idx (b,m,k), val (b,m,k) 20 | __global__ void selection_sort_gpu(int b, int n, int m, int k, 
float *dist, int *idx, float *val) { 21 | int batch_index = blockIdx.x; 22 | dist+=m*n*batch_index; 23 | idx+=m*k*batch_index; 24 | val+=m*k*batch_index; 25 | 26 | int index = threadIdx.x; 27 | int stride = blockDim.x; 28 | 29 | float *p_dist; 30 | for (int j=index;j>>(b,n,m,k,dist,idx,val); 68 | cudaDeviceSynchronize(); 69 | printf("selection sort cpu time %f\n",get_time()-t0); 70 | 71 | return 0; 72 | } 73 | -------------------------------------------------------------------------------- /models/tf_ops/grouping/selection_sort_const.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include // memset 4 | #include // rand, RAND_MAX 5 | #include // sqrtf 6 | #include 7 | #include 8 | using namespace std; 9 | float randomf(){ 10 | return (rand()+0.5)/(RAND_MAX+1.0); 11 | } 12 | static double get_time(){ 13 | timespec tp; 14 | clock_gettime(CLOCK_MONOTONIC,&tp); 15 | return tp.tv_sec+tp.tv_nsec*1e-9; 16 | } 17 | 18 | // input: k (1), distance matrix dist (b,m,n) 19 | // output: idx (b,m,n), dist_out (b,m,n) 20 | __global__ void selection_sort_gpu(int b, int n, int m, int k, const float *dist, int *outi, float *out) { 21 | int batch_index = blockIdx.x; 22 | dist+=m*n*batch_index; 23 | outi+=m*n*batch_index; 24 | out+=m*n*batch_index; 25 | 26 | int index = threadIdx.x; 27 | int stride = blockDim.x; 28 | 29 | // copy from dist to dist_out 30 | for (int j=index;j>>(b,n,m,k,dist,idx,dist_out); 84 | cudaDeviceSynchronize(); 85 | printf("selection sort cpu time %f\n",get_time()-t0); 86 | 87 | //for (int i=0;i>>(b,n,m,radius,nsample,xyz1,xyz2,idx,pts_cnt); 127 | //cudaDeviceSynchronize(); 128 | } 129 | void selectionSortLauncher(int b, int n, int m, int k, const float *dist, int *outi, float *out) { 130 | selection_sort_gpu<<>>(b,n,m,k,dist,outi,out); 131 | //cudaDeviceSynchronize(); 132 | } 133 | void groupPointLauncher(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out){ 
134 | group_point_gpu<<>>(b,n,c,m,nsample,points,idx,out); 135 | //cudaDeviceSynchronize(); 136 | } 137 | void groupPointGradLauncher(int b, int n, int c, int m, int nsample, const float *grad_out, const int *idx, float *grad_points){ 138 | group_point_grad_gpu<<>>(b,n,c,m,nsample,grad_out,idx,grad_points); 139 | //group_point_grad_gpu<<<1,1>>>(b,n,c,m,nsample,grad_out,idx,grad_points); 140 | //cudaDeviceSynchronize(); 141 | } 142 | -------------------------------------------------------------------------------- /models/tf_ops/grouping/tf_grouping_op_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import tensorflow as tf 3 | import numpy as np 4 | from tf_grouping import query_ball_point, group_point 5 | 6 | class GroupPointTest(tf.test.TestCase): 7 | def test(self): 8 | pass 9 | 10 | def test_grad(self): 11 | with tf.device('/gpu:0'): 12 | points = tf.constant(np.random.random((1,128,16)).astype('float32')) 13 | print(points) 14 | xyz1 = tf.constant(np.random.random((1,128,3)).astype('float32')) 15 | xyz2 = tf.constant(np.random.random((1,8,3)).astype('float32')) 16 | radius = 0.3 17 | nsample = 32 18 | idx, pts_cnt = query_ball_point(radius, nsample, xyz1, xyz2) 19 | grouped_points = group_point(points, idx) 20 | print(grouped_points) 21 | 22 | with self.test_session(): 23 | print("---- Going to compute gradient error") 24 | err = tf.test.compute_gradient_error(points, (1,128,16), grouped_points, (1,8,32,16)) 25 | print(err) 26 | self.assertLess(err, 1e-4) 27 | 28 | if __name__=='__main__': 29 | tf.test.main() 30 | -------------------------------------------------------------------------------- /models/tf_ops/sampling/tf_sampling.py: -------------------------------------------------------------------------------- 1 | ''' Furthest point sampling 2 | Original author: Haoqiang Fan 3 | Modified by Charles R. Qi 4 | All Rights Reserved. 2017. 
5 | ''' 6 | from __future__ import print_function 7 | import tensorflow as tf 8 | from tensorflow.python.framework import ops 9 | import sys 10 | import os 11 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 12 | sys.path.append(BASE_DIR) 13 | sampling_module=tf.load_op_library(os.path.join(BASE_DIR, 'tf_sampling_so.so')) 14 | def prob_sample(inp,inpr): 15 | ''' 16 | input: 17 | batch_size * ncategory float32 18 | batch_size * npoints float32 19 | returns: 20 | batch_size * npoints int32 21 | ''' 22 | return sampling_module.prob_sample(inp,inpr) 23 | ops.NoGradient('ProbSample') 24 | # TF1.0 API requires set shape in C++ 25 | #@tf.RegisterShape('ProbSample') 26 | #def _prob_sample_shape(op): 27 | # shape1=op.inputs[0].get_shape().with_rank(2) 28 | # shape2=op.inputs[1].get_shape().with_rank(2) 29 | # return [tf.TensorShape([shape2.dims[0],shape2.dims[1]])] 30 | def gather_point(inp,idx): 31 | ''' 32 | input: 33 | batch_size * ndataset * 3 float32 34 | batch_size * npoints int32 35 | returns: 36 | batch_size * npoints * 3 float32 37 | ''' 38 | return sampling_module.gather_point(inp,idx) 39 | #@tf.RegisterShape('GatherPoint') 40 | #def _gather_point_shape(op): 41 | # shape1=op.inputs[0].get_shape().with_rank(3) 42 | # shape2=op.inputs[1].get_shape().with_rank(2) 43 | # return [tf.TensorShape([shape1.dims[0],shape2.dims[1],shape1.dims[2]])] 44 | @tf.RegisterGradient('GatherPoint') 45 | def _gather_point_grad(op,out_g): 46 | inp=op.inputs[0] 47 | idx=op.inputs[1] 48 | return [sampling_module.gather_point_grad(inp,idx,out_g),None] 49 | def farthest_point_sample(npoint,inp): 50 | ''' 51 | input: 52 | int32 53 | batch_size * ndataset * 3 float32 54 | returns: 55 | batch_size * npoint int32 56 | ''' 57 | return sampling_module.farthest_point_sample(inp, npoint) 58 | ops.NoGradient('FarthestPointSample') 59 | 60 | 61 | if __name__=='__main__': 62 | import numpy as np 63 | np.random.seed(100) 64 | triangles=np.random.rand(1,5,3,3).astype('float32') 65 | with 
tf.device('/gpu:1'): 66 | inp=tf.constant(triangles) 67 | tria=inp[:,:,0,:] 68 | trib=inp[:,:,1,:] 69 | tric=inp[:,:,2,:] 70 | areas=tf.sqrt(tf.reduce_sum(tf.cross(trib-tria,tric-tria)**2,2)+1e-9) 71 | randomnumbers=tf.random_uniform((1,8192)) 72 | triids=prob_sample(areas,randomnumbers) 73 | tria_sample=gather_point(tria,triids) 74 | trib_sample=gather_point(trib,triids) 75 | tric_sample=gather_point(tric,triids) 76 | us=tf.random_uniform((1,8192)) 77 | vs=tf.random_uniform((1,8192)) 78 | uplusv=1-tf.abs(us+vs-1) 79 | uminusv=us-vs 80 | us=(uplusv+uminusv)*0.5 81 | vs=(uplusv-uminusv)*0.5 82 | pt_sample=tria_sample+(trib_sample-tria_sample)*tf.expand_dims(us,-1)+(tric_sample-tria_sample)*tf.expand_dims(vs,-1) 83 | print('pt_sample: ', pt_sample) 84 | reduced_sample=gather_point(pt_sample,farthest_point_sample(1024,pt_sample)) 85 | print(reduced_sample) 86 | with tf.Session('') as sess: 87 | ret=sess.run(reduced_sample) 88 | print(ret.shape,ret.dtype) 89 | import cPickle as pickle 90 | pickle.dump(ret,open('1.pkl','wb'),-1) 91 | -------------------------------------------------------------------------------- /models/tf_ops/sampling/tf_sampling_compile.sh: -------------------------------------------------------------------------------- 1 | #/bin/bash 2 | /usr/local/cuda-8.0/bin/nvcc tf_sampling_g.cu -o tf_sampling_g.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC 3 | 4 | # TF1.2 5 | #g++ -std=c++11 tf_sampling.cpp tf_sampling_g.cu.o -o tf_sampling_so.so -shared -fPIC -I /usr/local/lib/python2.7/dist-packages/tensorflow/include -I /usr/local/cuda-8.0/include -lcudart -L /usr/local/cuda-8.0/lib64/ -O2 -D_GLIBCXX_USE_CXX11_ABI=0 6 | 7 | # TF1.4 8 | g++ -std=c++11 tf_sampling.cpp tf_sampling_g.cu.o -o tf_sampling_so.so -shared -fPIC -I /usr/local/lib/python2.7/dist-packages/tensorflow/include -I /usr/local/cuda-8.0/include -I /usr/local/lib/python2.7/dist-packages/tensorflow/include/external/nsync/public -lcudart -L /usr/local/cuda-8.0/lib64/ 
-L/usr/local/lib/python2.7/dist-packages/tensorflow -ltensorflow_framework -O2 -D_GLIBCXX_USE_CXX11_ABI=0 9 | -------------------------------------------------------------------------------- /models/tf_ops/sampling/tf_sampling_g.cu: -------------------------------------------------------------------------------- 1 | /* Furthest point sampling GPU implementation 2 | * Original author: Haoqiang Fan 3 | * Modified by Charles R. Qi 4 | * All Rights Reserved. 2017. 5 | */ 6 | 7 | __global__ void cumsumKernel(int b,int n,const float * __restrict__ inp,float * __restrict__ out){ 8 | const int BlockSize=2048; 9 | const int paddingLevel=5; 10 | __shared__ float buffer4[BlockSize*4]; 11 | __shared__ float buffer[BlockSize+(BlockSize>>paddingLevel)]; 12 | for (int i=blockIdx.x;i>2; 18 | for (int k=threadIdx.x*4;k>2)+(k>>(2+paddingLevel))]=v4; 33 | }else{ 34 | float v=0; 35 | for (int k2=k;k2>2)+(k>>(2+paddingLevel))]=v; 43 | } 44 | } 45 | int u=0; 46 | for (;(2<>(u+1));k+=blockDim.x){ 49 | int i1=(((k<<1)+2)<>paddingLevel; 52 | i2+=i2>>paddingLevel; 53 | buffer[i1]+=buffer[i2]; 54 | } 55 | } 56 | u--; 57 | for (;u>=0;u--){ 58 | __syncthreads(); 59 | for (int k=threadIdx.x;k>(u+1));k+=blockDim.x){ 60 | int i1=(((k<<1)+3)<>paddingLevel; 63 | i2+=i2>>paddingLevel; 64 | buffer[i1]+=buffer[i2]; 65 | } 66 | } 67 | __syncthreads(); 68 | for (int k=threadIdx.x*4;k>2)-1)+(((k>>2)-1)>>paddingLevel); 71 | buffer4[k]+=buffer[k2]; 72 | buffer4[k+1]+=buffer[k2]; 73 | buffer4[k+2]+=buffer[k2]; 74 | buffer4[k+3]+=buffer[k2]; 75 | } 76 | } 77 | __syncthreads(); 78 | for (int k=threadIdx.x;k>paddingLevel)]+runningsum2; 82 | float r2=runningsum+t; 83 | runningsum2=t-(r2-runningsum); 84 | runningsum=r2; 85 | __syncthreads(); 86 | } 87 | } 88 | } 89 | 90 | __global__ void binarysearchKernel(int b,int n,int m,const float * __restrict__ dataset,const float * __restrict__ query, int * __restrict__ result){ 91 | int base=1; 92 | while (base=1;k>>=1) 99 | if (r>=k && dataset[i*n+r-k]>=q) 100 | r-=k; 
101 | result[i*m+j]=r; 102 | } 103 | } 104 | } 105 | __global__ void farthestpointsamplingKernel(int b,int n,int m,const float * __restrict__ dataset,float * __restrict__ temp,int * __restrict__ idxs){ 106 | if (m<=0) 107 | return; 108 | const int BlockSize=512; 109 | __shared__ float dists[BlockSize]; 110 | __shared__ int dists_i[BlockSize]; 111 | const int BufferSize=3072; 112 | __shared__ float buf[BufferSize*3]; 113 | for (int i=blockIdx.x;ibest){ 147 | best=d2; 148 | besti=k; 149 | } 150 | } 151 | dists[threadIdx.x]=best; 152 | dists_i[threadIdx.x]=besti; 153 | for (int u=0;(1<>(u+1))){ 156 | int i1=(threadIdx.x*2)<>>(b,n,inp,out); 196 | } 197 | //require b*n working space 198 | void probsampleLauncher(int b,int n,int m,const float * inp_p,const float * inp_r,float * temp,int * out){ 199 | cumsumKernel<<<32,512>>>(b,n,inp_p,temp); 200 | binarysearchKernel<<>>(b,n,m,temp,inp_r,out); 201 | } 202 | //require 32*n working space 203 | void farthestpointsamplingLauncher(int b,int n,int m,const float * inp,float * temp,int * out){ 204 | farthestpointsamplingKernel<<<32,512>>>(b,n,m,inp,temp,out); 205 | } 206 | void gatherpointLauncher(int b,int n,int m,const float * inp,const int * idx,float * out){ 207 | gatherpointKernel<<>>(b,n,m,inp,idx,out); 208 | } 209 | void scatteraddpointLauncher(int b,int n,int m,const float * out_g,const int * idx,float * inp_g){ 210 | scatteraddpointKernel<<>>(b,n,m,out_g,idx,inp_g); 211 | } 212 | 213 | -------------------------------------------------------------------------------- /scripts/command_prep_data.sh: -------------------------------------------------------------------------------- 1 | #/bin/bash 2 | python kitti/prepare_data.py --gen_train --gen_val --gen_val_rgb_detection 3 | -------------------------------------------------------------------------------- /scripts/command_test_v1.sh: -------------------------------------------------------------------------------- 1 | #/bin/bash 2 | python train/test.py --gpu 0 --num_point 
1024 --model frustum_pointnets_v1 --model_path train/log_v1/model.ckpt --output train/detection_results_v1 --data_path kitti/frustum_carpedcyc_val_rgb_detection.pickle --from_rgb_detection --idx_path kitti/image_sets/val.txt --from_rgb_detection 3 | train/kitti_eval/evaluate_object_3d_offline dataset/KITTI/object/training/label_2/ train/detection_results_v1 4 | -------------------------------------------------------------------------------- /scripts/command_test_v2.sh: -------------------------------------------------------------------------------- 1 | #/bin/bash 2 | python train/test.py --gpu 0 --num_point 1024 --model frustum_pointnets_v2 --model_path train/log_v2/model.ckpt --output train/detection_results_v2 --data_path kitti/frustum_carpedcyc_val_rgb_detection.pickle --from_rgb_detection --idx_path kitti/image_sets/val.txt --from_rgb_detection 3 | train/kitti_eval/evaluate_object_3d_offline dataset/KITTI/object/training/label_2/ train/detection_results_v2 4 | -------------------------------------------------------------------------------- /scripts/command_train_v1.sh: -------------------------------------------------------------------------------- 1 | #/bin/bash 2 | python train/train.py --gpu 0 --model frustum_pointnets_v1 --log_dir train/log_v1 --num_point 1024 --max_epoch 201 --batch_size 32 --decay_step 800000 --decay_rate 0.5 3 | -------------------------------------------------------------------------------- /scripts/command_train_v2.sh: -------------------------------------------------------------------------------- 1 | #/bin/bash 2 | python train/train.py --gpu 0 --model frustum_pointnets_v2 --log_dir train/log_v2 --num_point 1024 --max_epoch 201 --batch_size 24 --decay_step 800000 --decay_rate 0.5 3 | -------------------------------------------------------------------------------- /sunrgbd/README.md: -------------------------------------------------------------------------------- 1 | ### Data Preparation, Training and Evaluation of Frustum PointNets 
on SUN-RGBD data 2 | 3 | CLAIM: This is still a beta release of the code, with lots of things to clarify -- but could be useful for some of you who would like to start earlier. 4 | 5 | #### 1. Prepare SUN RGB-D data 6 | Download SUNRGBD V1 dataset and toolkit 7 | 8 | Run `extract_rgbd_data.m` in `sunrgbd_data/matlab/SUNRGBDtoolbox/` 9 | 10 | The generated data should be organized a bit in the supposed mysunrgbd foder by moving all subfolders into /mysunrgbd/training/ and creating a train/val file list etc. 11 | 12 | Prepare pickle files for TensorFlow training pipeline: 13 | run `sunrgbd_data/sunrgbd_data.py` 14 | 15 | This will prepare frustum point clouds and labels and save them to zipped pickle files. 16 | 17 | #### 2. Training 18 | 19 | Run `train_one_hot.py` with the following parameters: 20 | 21 | `batch_size=32, decay_rate=0.5, decay_step=800000, gpu=0, learning_rate=0.001, log_dir='log', max_epoch=151, model='frustum_pointnets_v1_sunrgbd', momentum=0.9, no_rgb=False, num_point=2048, optimizer='adam', restore_model_path=None` 22 | 23 | #### 3. Testing and evaluation 24 | 25 | To test the model on validation set you also need to prepare pickle files from detected 2D boxes in step 2 (last line in the main function of `sunrgbd_data.py`) -- the 2D detector should be trained to predict ``amodal'' 2D boxes. 26 | 27 | You can run `test_one_hot.py` to test a trained frustum pointnet model with `--dump_result` flag, which will dump a pickle file for test results. And then run `evaluate.py` to evaluate the 3D AP with the dumped pickle file. We wrote our own 3D detection evaluation script because the original MATLAB one is too slow. 
28 | 29 | A typical evaluation script is like: 30 | `python evaluate.py --data_path ../sunrgbd_data/fcn_det_val.zip.pickle --result_path test_results_v1_fcn_ft_val.pickle --from_rgb_detection` 31 | 32 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/cluster_box3d.py: -------------------------------------------------------------------------------- 1 | ''' Cluster and visualize distribution of box3d ''' 2 | 3 | import cPickle as pickle 4 | with open('box3d_dimensions.pickle','rb') as fp: 5 | type_list = pickle.load(fp) 6 | dimension_list = pickle.load(fp) # l,w,h 7 | ry_list = pickle.load(fp) 8 | 9 | import numpy as np 10 | box3d_pts = np.vstack(dimension_list) 11 | print box3d_pts.shape 12 | 13 | print set(type_list) 14 | raw_input() 15 | 16 | 17 | # Get average box size for different catgories 18 | median_box3d_list = [] 19 | for class_type in sorted(set(type_list)): 20 | cnt = 0 21 | box3d_list = [] 22 | for i in range(len(dimension_list)): 23 | if type_list[i]==class_type: 24 | cnt += 1 25 | box3d_list.append(dimension_list[i]) 26 | #print class_type, cnt, box3d/float(cnt) 27 | median_box3d = np.median(box3d_list,0) 28 | print "\'%s\': np.array([%f,%f,%f])," % (class_type, median_box3d[0]*2, median_box3d[1]*2, median_box3d[2]*2) 29 | median_box3d_list.append(median_box3d) 30 | raw_input() 31 | 32 | import mayavi.mlab as mlab 33 | fig = mlab.figure(figure=None, bgcolor=(0,0,0), fgcolor=None, engine=None, size=(1000, 500)) 34 | mlab.points3d(box3d_pts[:,0], box3d_pts[:,1], box3d_pts[:,2], mode='point', colormap='gnuplot', scale_factor=1, figure=fig) 35 | ##draw axis 36 | mlab.points3d(0, 0, 0, color=(1,1,1), mode='sphere', scale_factor=0.2) 37 | 38 | axes=np.array([ 39 | [2.,0.,0.,0.], 40 | [0.,2.,0.,0.], 41 | [0.,0.,2.,0.], 42 | ],dtype=np.float64) 43 | fov=np.array([ ## : now is 45 deg. use actual setting later ... 
44 | [20., 20., 0.,0.], 45 | [20.,-20., 0.,0.], 46 | ],dtype=np.float64) 47 | 48 | mlab.plot3d([0, axes[0,0]], [0, axes[0,1]], [0, axes[0,2]], color=(1,0,0), tube_radius=None, figure=fig) 49 | mlab.plot3d([0, axes[1,0]], [0, axes[1,1]], [0, axes[1,2]], color=(0,1,0), tube_radius=None, figure=fig) 50 | mlab.plot3d([0, axes[2,0]], [0, axes[2,1]], [0, axes[2,2]], color=(0,0,1), tube_radius=None, figure=fig) 51 | mlab.plot3d([0, fov[0,0]], [0, fov[0,1]], [0, fov[0,2]], color=(1,1,1), tube_radius=None, line_width=1, figure=fig) 52 | mlab.plot3d([0, fov[1,0]], [0, fov[1,1]], [0, fov[1,2]], color=(1,1,1), tube_radius=None, line_width=1, figure=fig) 53 | mlab.orientation_axes() 54 | 55 | for box in median_box3d_list: 56 | mlab.points3d(box[0], box[1], box[2], color=(1,0,1), mode='sphere', scale_factor=0.4) 57 | raw_input() 58 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/README.txt: -------------------------------------------------------------------------------- 1 | Compared with original SUNRGBD toolkits 2 | 3 | removed metadata and GT data -- large files. 4 | added detection/extract_gt_boxes.m to extract GT boxes 5 | Under SUNRGBDtoolkit added 6 | extract_data.m 7 | extract_data_dimension.m 8 | order_basis.m 9 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/README.txt: -------------------------------------------------------------------------------- 1 | **************************************************************************************** 2 | Data: Image depth and label data are in SUNRGBD.zip 3 | image: rgb image 4 | depth: depth image to read the depth see the code in SUNRGBDtoolbox/read3dPoints/. 
5 | extrinsics: the rotation matrix to align the point could with gravity 6 | fullres: full resolution depth and rgb image 7 | intrinsics.txt : sensor intrinsic 8 | scene.txt : scene type 9 | annotation2Dfinal : 2D segmentation 10 | annotation3Dfinal : 3D bounding box 11 | annotation3Dlayout : 3D room layout bounding box 12 | 13 | **************************************************************************************** 14 | Label: 15 | In SUNRGBDtoolbox/Metadata 16 | SUNRGBDMeta.mat: 2D,3D bounding box ground truth and image information for each frame. 17 | SUNRGBD2Dseg.mat: 2D segmetation ground truth. 18 | The index in "SUNRGBD2Dseg(imageId).seglabelall" mapping the name to "seglistall". 19 | The index in "SUNRGBD2Dseg(imageId).seglabel" are mapping the object name in "seg37list". 20 | 21 | **************************************************************************************** 22 | In SUNRGBDtoolbox/traintestsplit 23 | allsplit.mat: stores the training and testing split. 24 | 25 | **************************************************************************************** 26 | Code: 27 | SUNRGBDtoolbox/demo.m : Examples to load and visualize the data. 28 | SUNRGBDtoolbox/readframeSUNRGBD.m : Example code to read SUNRGBD annotation from ".json" file. 29 | 30 | ***************************************************************************************** 31 | Citation: 32 | Please cite our paper if you use this data: 33 | S. Song, S. Lichtenberg, and J. Xiao. 
34 | SUN RGB-D: A RGB-D Scene Understanding Benchmark Suite 35 | Proceedings of 28th IEEE Conference on Computer Vision and Pattern Recognition (CVPR2015) 36 | 37 | 38 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/demo.m: -------------------------------------------------------------------------------- 1 | addpath(genpath('.')) 2 | load('./Metadata/SUNRGBDMeta.mat') 3 | load('./Metadata/SUNRGBD2Dseg.mat') 4 | %% Read 5 | imageId = 1; 6 | data = SUNRGBDMeta(imageId); 7 | % data.depthpath = '/data/rqi/SUNRGBD/SUNRGBD/kv2/kinect2data/000037_2014-05-26_14-54-02_260595134347_rgbf000041-resize/depth/0000041.png'; 8 | % data.rgbpath = '/data/rqi/SUNRGBD/SUNRGBD/kv2/kinect2data/000037_2014-05-26_14-54-02_260595134347_rgbf000041-resize/image/0000041.jpg'; 9 | data.depthpath(1:16) = ''; 10 | data.depthpath = strcat('/data/rqi/SUNRGBD',data.depthpath); 11 | data.rgbpath(1:16) = ''; 12 | data.rgbpath = strcat('/data/rqi/SUNRGBD',data.rgbpath); 13 | [rgb,points3d,depthInpaint,imsize]=read3dPoints(data); 14 | %% draw 15 | figure, 16 | imshow(data.rgbpath); 17 | hold on; 18 | for kk =1:length(data.groundtruth3DBB) 19 | rectangle('Position', [data.groundtruth3DBB(kk).gtBb2D(1) data.groundtruth3DBB(kk).gtBb2D(2) data.groundtruth3DBB(kk).gtBb2D(3) data.groundtruth3DBB(kk).gtBb2D(4)],'edgecolor','y'); 20 | text(data.groundtruth3DBB(kk).gtBb2D(1),data.groundtruth3DBB(kk).gtBb2D(2),data.groundtruth3DBB(kk).classname,'BackgroundColor','y') 21 | end 22 | %% draw 3D 23 | figure, 24 | vis_point_cloud(points3d,rgb) 25 | hold on; 26 | for kk =1:length(data.groundtruth3DBB) 27 | vis_cube(data.groundtruth3DBB(kk),'r') 28 | end 29 | 30 | %% 31 | anno2d = SUNRGBD2Dseg(imageId); 32 | figure, 33 | imagesc(anno2d.seglabel); 34 | % category name in 37 categories list 35 | load('./Metadata/seg37list.mat'); 36 | objectname37 = unique(anno2d.seglabel(:)); 37 | objectname37 = seg37list(objectname37(objectname37~=0)); 38 | 39 
| figure, 40 | imagesc(anno2d.seglabelall); 41 | % category name of all categories 42 | objectnameall = anno2d.names 43 | 44 | 45 | %% Example to read single data 46 | data = readframeSUNRGBD('/n/fs/sun3d/data/SUNRGBD/kv2/kinect2data/000002_2014-05-26_14-23-37_260595134347_rgbf000103-resize/','/n/fs/sun3d/data/'); 47 | [rgb,points3d,depthInpaint,imsize]=read3dPoints(data); 48 | %% draw 49 | figure, 50 | imshow(data.rgbpath); 51 | hold on; 52 | for kk =1:length(data.groundtruth3DBB) 53 | rectangle('Position', [data.groundtruth3DBB(kk).gtBb2D(1) data.groundtruth3DBB(kk).gtBb2D(2) data.groundtruth3DBB(kk).gtBb2D(3) data.groundtruth3DBB(kk).gtBb2D(4)],'edgecolor','y'); 54 | text(data.groundtruth3DBB(kk).gtBb2D(1),data.groundtruth3DBB(kk).gtBb2D(2),data.groundtruth3DBB(kk).classname,'BackgroundColor','y') 55 | end 56 | %% draw 3D 57 | figure, 58 | vis_point_cloud(points3d,rgb) 59 | hold on; 60 | for kk =1:length(data.groundtruth3DBB) 61 | vis_cube(data.groundtruth3DBB(kk),'r') 62 | end 63 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/draw/drawRoom.m: -------------------------------------------------------------------------------- 1 | function drawRoom(roomLayout,color,lineWidth,maxhight) 2 | totalPoints = size(roomLayout,2); 3 | bottompoints= roomLayout(:,1:totalPoints/2); 4 | toppoints = roomLayout(:,totalPoints/2+1:totalPoints); 5 | toppoints(3,toppoints(3,:)>maxhight)=maxhight; 6 | bottompoints(3,bottompoints(3,:)>maxhight)=maxhight; 7 | [~,ind] = min(toppoints(2,:)); 8 | for i =1:length(toppoints)-1 9 | if i~=ind&&i+1~=ind 10 | vis_line(toppoints(:,i)', toppoints(:,i+1)', color, lineWidth); 11 | vis_line(bottompoints(:,i)', bottompoints(:,i+1)', color, lineWidth); 12 | vis_line(toppoints(:,i)', bottompoints(:,i)', color, lineWidth); 13 | end 14 | end 15 | if 1~=ind&&size(toppoints,2)~=ind 16 | vis_line(toppoints(:,end)', toppoints(:,1)', color, lineWidth); 17 | 
vis_line(bottompoints(:,end)', bottompoints(:,1)', color, lineWidth); 18 | vis_line(toppoints(:,end)', bottompoints(:,end)', color, lineWidth); 19 | end 20 | end -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/draw/drawSeg.m: -------------------------------------------------------------------------------- 1 | mask = 5*double(instances(:,:,1))+5*double(labels(:,:,1)); 2 | maskColor = zeros(size(image,1)*size(image,2),3); 3 | uniquemask = unique(mask(:)); 4 | for i =1:length(uniquemask) 5 | 6 | sel = mask(:)==uniquemask(i); 7 | 8 | maskColor(sel, :) = repmat(ObjectColor(uniquemask(i)),sum(sel),1); 9 | 10 | end 11 | maskColor(find(mask(:)==0),:) = 1; 12 | maskColor = reshape(maskColor,[size(mask,1),size(mask,2),3]); 13 | 14 | 15 | figure 16 | 17 | imshow(maskColor); 18 | imwrit(maskColor,'NYUmask.png') 19 | %% 20 | load('/n/fs/modelnet/SUN3DV2/prepareGT/cls.mat') 21 | 22 | addpath('/n/fs/modelnet/SUN3DV2/roomlayout/') 23 | fullname = '/n/fs/sun3d/data/rgbd_voc/000414_2014-06-04_19-49-13_260595134347_rgbf000044-resize' 24 | data = readframe(fullname); 25 | groundTruthBbs = data.groundtruth3DBB; 26 | 27 | sequenceName = getSequenceName(fullname); 28 | gtRoom3D = GroundTruthBox(sequenceName,0); 29 | cameraXYZ = data.anno_extrinsics'*gtRoom3D; 30 | cameraXYZ([2 3],:) = cameraXYZ([3 2],:); 31 | cameraXYZ(3,:) = - cameraXYZ(3,:); 32 | cameraXYZ = data.Rtilt * cameraXYZ; 33 | 34 | my_mhCorner3D = cameraXYZ; %data.Rtilt*data.anno_extrinsics'*gtCorner3D; 35 | visulize_wholeroom(groundTruthBbs,cls,fullname,my_mhCorner3D) 36 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/draw/draw_square_3d.m: -------------------------------------------------------------------------------- 1 | % Draws a square in 3D 2 | % 3 | % Args: 4 | % corners - 8x2 matrix of 2d corners. 5 | % color - matlab color code, a single character. 
6 | % lineWidth - the width of each line of the square. 7 | % 8 | % Author: Nathan Silberman (silberman@cs.nyu.edu) 9 | function draw_square_3d(corners, color, lineWidth) 10 | if nargin < 2 11 | color = 'r'; 12 | end 13 | 14 | if nargin < 3 15 | lineWidth = 0.5; 16 | end 17 | 18 | vis_line(corners(1,:), corners(2,:), color, lineWidth); 19 | vis_line(corners(2,:), corners(3,:), color, lineWidth); 20 | vis_line(corners(3,:), corners(4,:), color, lineWidth); 21 | vis_line(corners(4,:), corners(1,:), color, lineWidth); 22 | 23 | vis_line(corners(5,:), corners(6,:), color, lineWidth); 24 | vis_line(corners(6,:), corners(7,:), color, lineWidth); 25 | vis_line(corners(7,:), corners(8,:), color, lineWidth); 26 | vis_line(corners(8,:), corners(5,:), color, lineWidth); 27 | 28 | vis_line(corners(1,:), corners(5,:), color, lineWidth); 29 | vis_line(corners(2,:), corners(6,:), color, lineWidth); 30 | vis_line(corners(3,:), corners(7,:), color, lineWidth); 31 | vis_line(corners(4,:), corners(8,:), color, lineWidth); 32 | end 33 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/draw/myObjectColor.m: -------------------------------------------------------------------------------- 1 | function color = myObjectColor(objectID) 2 | 3 | % same color as the online annotator 4 | 5 | objectColors = {'#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5', '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f', '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5', '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#d9d9d9', '#bc80bd', '#ccebc5', '#ffed6f', '#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#ffff33', '#a65628', '#f781bf', '#999999', '#621e15', '#e59076', '#128dcd', '#083c52', '#64c5f2', '#61afaf', '#0f7369', '#9c9da1', '#365e96', '#983334', '#77973d', '#5d437c', '#36869f', '#d1702f', '#8197c5', 
'#c47f80', '#acc484', '#9887b0', '#2d588a', '#58954c', '#e9a044', '#c12f32', '#723e77', '#7d807f', '#9c9ede', '#7375b5', '#4a5584', '#cedb9c', '#b5cf6b', '#8ca252', '#637939', '#e7cb94', '#e7ba52', '#bd9e39', '#8c6d31', '#e7969c', '#d6616b', '#ad494a', '#843c39', '#de9ed6', '#ce6dbd', '#a55194', '#7b4173', '#000000', '#0000FF'}; 6 | %objectColors = loadjson(urlread('http://sun3d.cs.princeton.edu/player/ObjectColors.json')); 7 | 8 | objectID = mod(objectID,length(objectColors)) + 1; 9 | 10 | color = objectColors{objectID}; 11 | 12 | color = [hex2dec(color(2:3)) hex2dec(color(4:5)) hex2dec(color(6:7))]/255; 13 | 14 | color = color * 0.6; 15 | 16 | end -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/draw/plotcube.m: -------------------------------------------------------------------------------- 1 | function plotcube(varargin) 2 | % PLOTCUBE - Display a 3D-cube in the current axes 3 | % 4 | % PLOTCUBE(EDGES,ORIGIN,ALPHA,COLOR) displays a 3D-cube in the current axes 5 | % with the following properties: 6 | % * EDGES : 3-elements vector that defines the length of cube edges 7 | % * ORIGIN: 3-elements vector that defines the start point of the cube 8 | % * ALPHA : scalar that defines the transparency of the cube faces (from 0 9 | % to 1) 10 | % * COLOR : 3-elements vector that defines the faces color of the cube 11 | % 12 | % Example: 13 | % >> plotcube([5 5 5],[ 2 2 2],.8,[1 0 0]); 14 | % >> plotcube([5 5 5],[10 10 10],.8,[0 1 0]); 15 | % >> plotcube([5 5 5],[20 20 20],.8,[0 0 1]); 16 | 17 | % Default input arguments 18 | inArgs = { ... 19 | [10 56 100] , ... % Default edge sizes (x,y and z) 20 | [10 10 10] , ... % Default coordinates of the origin point of the cube 21 | .7 , ... % Default alpha value for the cube's faces 22 | [1 0 0] ... 
% Default Color for the cube 23 | }; 24 | 25 | % Replace default input arguments by input values 26 | inArgs(1:nargin) = varargin; 27 | 28 | % Create all variables 29 | [edges,origin,alpha,clr] = deal(inArgs{:}); 30 | 31 | XYZ = { ... 32 | [0 0 0 0] [0 0 1 1] [0 1 1 0] ; ... 33 | [1 1 1 1] [0 0 1 1] [0 1 1 0] ; ... 34 | [0 1 1 0] [0 0 0 0] [0 0 1 1] ; ... 35 | [0 1 1 0] [1 1 1 1] [0 0 1 1] ; ... 36 | [0 1 1 0] [0 0 1 1] [0 0 0 0] ; ... 37 | [0 1 1 0] [0 0 1 1] [1 1 1 1] ... 38 | }; 39 | 40 | XYZ = mat2cell(... 41 | cellfun( @(x,y,z) x*y+z , ... 42 | XYZ , ... 43 | repmat(mat2cell(edges,1,[1 1 1]),6,1) , ... 44 | repmat(mat2cell(origin,1,[1 1 1]),6,1) , ... 45 | 'UniformOutput',false), ... 46 | 6,[1 1 1]); 47 | 48 | 49 | cellfun(@patch,XYZ{1},XYZ{2},XYZ{3},... 50 | repmat({clr},6,1),... 51 | repmat({'FaceAlpha'},6,1),... 52 | repmat({alpha},6,1)... 53 | ); 54 | 55 | view(3); 56 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/draw/vis_cube.m: -------------------------------------------------------------------------------- 1 | % Visualizes a 3D bounding box. 
2 | % 3 | % Args: 4 | % bb3d - 3D bounding box struct 5 | % color - matlab color code, a single character 6 | % lineWidth - the width of each line of the square 7 | % 8 | % See: 9 | % create_bounding_box_3d.m 10 | % 11 | % Author: 12 | % Nathan Silberman (silberman@cs.nyu.edu) 13 | function vis_cube(bb3d, color, lineWidth) 14 | if nargin < 3 15 | lineWidth = 0.5; 16 | end 17 | corners = get_corners_of_bb3d(bb3d); 18 | draw_square_3d(corners, color, lineWidth); 19 | end -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/draw/vis_line.m: -------------------------------------------------------------------------------- 1 | % Visualizes a line in 2D or 3D space 2 | % 3 | % Args: 4 | % p1 - 1x2 or 1x3 point 5 | % p2 - 1x2 or 1x3 point 6 | % color - matlab color code, a single character 7 | % lineWidth - the width of the drawn line 8 | % 9 | % Author: Nathan Silberman (silberman@cs.nyu.edu) 10 | function vis_line(p1, p2, color, lineWidth) 11 | if nargin < 3 12 | color = 'b'; 13 | end 14 | 15 | if nargin < 4 16 | lineWidth = 0.5; 17 | end 18 | 19 | % Make sure theyre the same size. 20 | assert(ndims(p1) == ndims(p2), 'Vectors are of different dimensions'); 21 | assert(all(size(p1) == size(p2)), 'Vectors are of different dimensions'); 22 | 23 | switch numel(p1) 24 | case 2 25 | line([p1(1) p2(1)], [p1(2) p2(2)], 'Color', color); 26 | case 3 27 | line([p1(1) p2(1)], [p1(2) p2(2)], [p1(3) p2(3)], 'Color', color, 'LineWidth', lineWidth); 28 | otherwise 29 | error('vectors must be either 2 or 3 dimensional'); 30 | end 31 | end 32 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/draw/vis_point_cloud.m: -------------------------------------------------------------------------------- 1 | % Visualizes a 3D point cloud. 2 | % 3 | % Args: 4 | % points3d - Nx3 or Nx2 point cloud where N is the number of points. 
5 | % colors - (optional) Nx3 vector of colors or Nx1 vector of values which will 6 | % be scaled for visualization. 7 | % sizes - (optional) Nx1 vector of point sizes or a scalar value which is applied 8 | % to every point in the cloud. 9 | % sampleSize - (optional) the maximum number of points to show. Note that since matlab 10 | % is slow, around 5000 points is a good sampleSize in practice. 11 | % 12 | % Author: Nathan Silberman (silberman@cs.nyu.edu) 13 | function vis_point_cloud(points, colors, sizes, sampleSize) 14 | removeNaN = sum(isnan(points),2)>0; % rows with any NaN coordinate are dropped 15 | points(removeNaN,:)=[]; 16 | [N, D] = size(points); % N counts only the points that survived NaN removal 17 | assert(D == 2 || D == 3, 'points must be Nx2 or Nx3'); 18 | 19 | if ~exist('colors', 'var') || isempty(colors) 20 | norms = sqrt(sum(points.^2, 2)); % default coloring: Euclidean distance from the origin 21 | colors = values2colors(norms); 22 | else 23 | colors(removeNaN,:)=[]; % keep caller-supplied colors aligned with the filtered points 24 | 25 | end 26 | 27 | if ~exist('sizes', 'var') || isempty(sizes) 28 | sizes = ones(N, 1) * 10; 29 | elseif numel(sizes) == 1 30 | sizes = sizes * ones(N, 1); 31 | elseif numel(sizes) ~= N % NOTE(review): unlike colors, per-point sizes are not NaN-filtered, so an Nx1 sizes vector for a cloud containing NaN rows fails this check 32 | error('sizes:size', 'sizes must be Nx1'); 33 | end 34 | 35 | if ~exist('sampleSize', 'var') || isempty(sampleSize) 36 | sampleSize = 5000; 37 | end 38 | 39 | % Sample the points, colors and sizes.
40 | N = size(points, 1); 41 | if N > sampleSize 42 | seq = randperm(size(points, 1)); 43 | seq = seq(1:sampleSize); 44 | 45 | points = points(seq, :); 46 | colors = colors(seq, :); 47 | sizes = sizes(seq); 48 | end 49 | 50 | switch size(points, 2) 51 | case 2 52 | vis_2d(points, colors, sizes); 53 | case 3 54 | vis_3d(points, colors, sizes); 55 | otherwise 56 | error('Points must be either 2 or 3d'); 57 | end 58 | 59 | axis equal; 60 | %view(0,90); 61 | %s view(0,8); 62 | end 63 | 64 | function colors = values2colors(values) 65 | values = scale_values(values); % scaled into [0, 1] so the indices below fall in 1..256 66 | inds = ceil(values * 255) + 1; 67 | h = colormap(jet(256)); 68 | colors = h(inds, :); 69 | end 70 | 71 | function values = scale_values(values) 72 | % Linearly map values onto [0, 1]. A zero range (a single value, or all values equal) maps everything to 0, instead of leaving the raw value or dividing 0 by 0; either of those produces an invalid colormap index in values2colors. 73 | values = values - min(values(:)); 74 | valueRange = max(values(:)); 75 | if valueRange > 0, values = values ./ valueRange; end 76 | end 77 | 78 | function vis_3d(points, colors, sizes) 79 | X = points(:,1); 80 | Y = points(:,2); 81 | Z = points(:,3); 82 | scatter3(X, Y, Z, sizes, colors, 'filled'); 83 | end 84 | 85 | function vis_2d(points, colors, sizes) 86 | X = points(:,1); 87 | Y = points(:,2); 88 | scatter(X, Y, sizes, colors, 'filled'); 89 | 90 | xlabel('x'); 91 | ylabel('y'); 92 | end -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/draw/visulize_wholeroom.m: -------------------------------------------------------------------------------- 1 | function visulize_wholeroom(BbsTight,cls,fullname,roomLayout,numofpoints2plot,savepath) 2 | if ~exist('numofpoints2plot','var'), 3 | numofpoints2plot = 5000; 4 | end 5 | sizeofpoint = max(round(200000/numofpoints2plot),3); 6 | Linewidth =5; 7 | maxhight = 
1.2; 8 | vis = 'on'; 9 | f = figure; 10 | set(f, 'Position', [100, 100, 1149, 1249]); 11 | if exist('fullname','var')&&~isempty(fullname); 12 | data = readframe(fullname); 13 | [rgb,points3d,~,imsize]=read3dPoints(data); 14 | vis_point_cloud(points3d,double(rgb),sizeofpoint,numofpoints2plot); 15 |
maxhight = min(maxhight,max(points3d(:,3))); 16 | hold on; 17 | end 18 | if ~isempty(BbsTight) 19 | if exist('cls','var')&&~isempty(cls)&&~isfield(BbsTight,'classid') 20 | [~,classid] = ismember({BbsTight.classname},cls); 21 | else 22 | classid = [BbsTight.classid]; 23 | end 24 | end 25 | for i =1:length(BbsTight) 26 | vis_cube(BbsTight(i), myObjectColor(classid(i)),Linewidth); 27 | end 28 | hold on; 29 | if exist('roomLayout','var')&&~isempty(roomLayout) 30 | drawRoom(roomLayout,'b',Linewidth,maxhight); 31 | end 32 | axis equal; 33 | axis tight; 34 | axis off; 35 | view(16,32); 36 | if exist('savepath','var')&&~isempty(savepath) 37 | saveas(f,savepath); 38 | close(f) 39 | end 40 | 41 | end -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/extract_rgbd_data.m: -------------------------------------------------------------------------------- 1 | %% Dump SUNRGBD data to our format 2 | % for each sample, we have RGB image, 2d boxes, point cloud (in camera 3 | % coordinate), calibration and 3d boxes 4 | % 5 | % Author: Charles R. 
Qi 6 | % Date: 09/27/2017 7 | % 8 | clear; close all; clc; 9 | addpath(genpath('.')) 10 | % load('./Metadata/SUNRGBDMeta.mat') 11 | % load('./Metadata/SUNRGBD2Dseg.mat') 12 | %load('../SUNRGBDMeta3DBB_v2.mat'); % SUNRGBDMeta2DBB 13 | %load('../SUNRGBDMeta2DBB_v2.mat'); % SUNRGBDMeta 14 | load('./Metadata/SUNRGBDMeta.mat'); % SUNRGBDMeta 15 | depth_folder = 'mysunrgbd/depth/'; 16 | image_folder = 'mysunrgbd/image/'; 17 | calib_folder = 'mysunrgbd/calib/'; 18 | %label_folder = 'mysunrgbd/label/'; 19 | label_folder = 'mysunrgbd/label_dimension/'; 20 | mkdir(depth_folder); 21 | mkdir(image_folder); 22 | mkdir(calib_folder); 23 | mkdir(label_folder); 24 | %% Read 25 | parfor imageId = 1:10335 26 | imageId 27 | % if imageId == 10 28 | % break 29 | % end 30 | try 31 | data = SUNRGBDMeta(imageId); 32 | data.depthpath(1:16) = ''; 33 | data.depthpath = strcat('/data/rqi/SUNRGBD',data.depthpath); 34 | data.rgbpath(1:16) = ''; 35 | data.rgbpath = strcat('/data/rqi/SUNRGBD',data.rgbpath); 36 | 37 | % Write point cloud in depth map 38 | [rgb,points3d,depthInpaint,imsize]=read3dPoints(data); 39 | rgb(isnan(points3d(:,1)),:) = []; 40 | points3d(isnan(points3d(:,1)),:) = []; 41 | points3d_rgb = [points3d, rgb]; 42 | filename = strcat(num2str(imageId,'%06d'), '.txt'); 43 | dlmwrite(strcat(depth_folder, filename), points3d_rgb, 'delimiter', ' '); 44 | 45 | % Write images 46 | copyfile(data.rgbpath, sprintf('%s/%06d.jpg', image_folder, imageId)); 47 | 48 | % Write calibration 49 | dlmwrite(strcat(calib_folder, filename), data.Rtilt(:)', 'delimiter', ' '); 50 | dlmwrite(strcat(calib_folder, filename), data.K(:)', 'delimiter', ' ', '-append'); 51 | 52 | % Write 2D and 3D box label 53 | %data2d = SUNRGBDMeta2DBB(imageId); 54 | data2d = data; 55 | fid = fopen(strcat(label_folder, filename), 'w'); 56 | for j = 1:length(data.groundtruth3DBB) 57 | %if data2d.groundtruth2DBB(j).has3dbox == 0 58 | % continue 59 | %end 60 | centroid = data.groundtruth3DBB(j).centroid; 61 | classname = 
data.groundtruth3DBB(j).classname; 62 | orientation = data.groundtruth3DBB(j).orientation; 63 | coeffs = abs(data.groundtruth3DBB(j).coeffs); 64 | [new_basis, new_coeffs] = order_basis(data.groundtruth3DBB(j).basis, coeffs, centroid); % NOTE(review): new_coeffs is unused -- the label line below writes coeffs in their original order next to the reordered new_basis; confirm this is what the downstream loader expects 65 | box2d = data2d.groundtruth2DBB(j).gtBb2D; 66 | %assert(strcmp(data2d.groundtruth2DBB(j).classname, classname)); 67 | fprintf(fid, '%s %d %d %d %d %f %f %f %f %f %f %f %f %f %f %f %f\n', classname, box2d(1), box2d(2), box2d(3), box2d(4), centroid(1), centroid(2), centroid(3), coeffs(1), coeffs(2), coeffs(3), new_basis(1,1), new_basis(1,2), new_basis(2,1), new_basis(2,2), orientation(1), orientation(2)); 68 | end 69 | fclose(fid); 70 | catch ME 71 | fprintf('Skipping image %d: %s\n', imageId, ME.message); % best-effort: report the failure instead of silently dropping the sample, then continue 72 | end 73 | 74 | end 75 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/getSequenceName.m: -------------------------------------------------------------------------------- 1 | function sequenceName = getSequenceName(thispath,dataRoot) 2 | if ~exist('dataRoot','var'), 3 | dataRoot = '/n/fs/sun3d/data/'; 4 | end 5 | sequenceName = thispath(length(dataRoot):end); 6 | while ~isempty(sequenceName) && sequenceName(1)=='/',sequenceName =sequenceName(2:end);end 7 | while ~isempty(sequenceName) && sequenceName(end)=='/',sequenceName =sequenceName(1:end-1);end 8 | end -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/jsonlab/AUTHORS.txt: -------------------------------------------------------------------------------- 1 | The author of "jsonlab" toolbox is Qianqian Fang. 
Qianqian 2 | is currently an Assistant Professor at Massachusetts General Hospital, 3 | Harvard Medical School.
4 | 5 | Address: Martinos Center for Biomedical Imaging, 6 | Massachusetts General Hospital, 7 | Harvard Medical School 8 | Bldg 149, 13th St, Charlestown, MA 02129, USA 9 | URL: http://nmr.mgh.harvard.edu/~fangq/ 10 | Email: or 11 | 12 | 13 | The script loadjson.m was built upon previous works by 14 | 15 | - Nedialko Krouchev: http://www.mathworks.com/matlabcentral/fileexchange/25713 16 | date: 2009/11/02 17 | - François Glineur: http://www.mathworks.com/matlabcentral/fileexchange/23393 18 | date: 2009/03/22 19 | - Joel Feenstra: http://www.mathworks.com/matlabcentral/fileexchange/20565 20 | date: 2008/07/03 21 | 22 | 23 | This toolbox contains patches submitted by the following contributors: 24 | 25 | - Blake Johnson 26 | part of revision 341 27 | 28 | - Niclas Borlin 29 | various fixes in revision 394, including 30 | - loadjson crashes for all-zero sparse matrix. 31 | - loadjson crashes for empty sparse matrix. 32 | - Non-zero size of 0-by-N and N-by-0 empty matrices is lost after savejson/loadjson. 33 | - loadjson crashes for sparse real column vector. 34 | - loadjson crashes for sparse complex column vector. 35 | - Data is corrupted by savejson for sparse real row vector. 36 | - savejson crashes for sparse complex row vector. 37 | 38 | - Yul Kang 39 | patches for svn revision 415. 
40 | - savejson saves an empty cell array as [] instead of null 41 | - loadjson differentiates an empty struct from an empty array 42 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/jsonlab/ChangeLog.txt: -------------------------------------------------------------------------------- 1 | ============================================================================ 2 | 3 | JSONlab - a toolbox to encode/decode JSON/UBJSON files in MATLAB/Octave 4 | 5 | ---------------------------------------------------------------------------- 6 | 7 | JSONlab ChangeLog (key features marked by *): 8 | 9 | == JSONlab 1.0.0-RC1 (codename: Optimus - RC1), FangQ == 10 | 11 | 2014/09/17 fix several compatibility issues when running on octave versions 3.2-3.8 12 | 2014/09/17 support 2D cell and struct arrays in both savejson and saveubjson 13 | 2014/08/04 escape special characters in a JSON string 14 | 2014/02/16 fix a bug when saving ubjson files 15 | 16 | == JSONlab 0.9.9 (codename: Optimus - beta), FangQ == 17 | 18 | 2014/01/22 use binary read and write in saveubjson and loadubjson 19 | 20 | == JSONlab 0.9.8-1 (codename: Optimus - alpha update 1), FangQ == 21 | 22 | 2013/10/07 better round-trip conservation for empty arrays and structs (patch submitted by Yul Kang) 23 | 24 | == JSONlab 0.9.8 (codename: Optimus - alpha), FangQ == 25 | 2013/08/23 *universal Binary JSON (UBJSON) support, including both saveubjson and loadubjson 26 | 27 | == JSONlab 0.9.1 (codename: Rodimus, update 1), FangQ == 28 | 2012/12/18 *handling of various empty and sparse matrices (fixes submitted by Niclas Borlin) 29 | 30 | == JSONlab 0.9.0 (codename: Rodimus), FangQ == 31 | 32 | 2012/06/17 *new format for an invalid leading char, unpacking hex code in savejson 33 | 2012/06/01 support JSONP in savejson 34 | 2012/05/25 fix the empty cell bug (reported by Cyril Davin) 35 | 2012/04/05 savejson can save to a file (suggested by Patrick 
Rapin) 36 | 37 | == JSONlab 0.8.1 (codename: Sentiel, Update 1), FangQ == 38 | 39 | 2012/02/28 loadjson quotation mark escape bug, see http://bit.ly/yyk1nS 40 | 2012/01/25 patch to handle root-less objects, contributed by Blake Johnson 41 | 42 | == JSONlab 0.8.0 (codename: Sentiel), FangQ == 43 | 44 | 2012/01/13 *speed up loadjson by 20 fold when parsing large data arrays in matlab 45 | 2012/01/11 remove row bracket if an array has 1 element, suggested by Mykel Kochenderfer 46 | 2011/12/22 *accept sequence of 'param',value input in savejson and loadjson 47 | 2011/11/18 fix struct array bug reported by Mykel Kochenderfer 48 | 49 | == JSONlab 0.5.1 (codename: Nexus Update 1), FangQ == 50 | 51 | 2011/10/21 fix a bug in loadjson, previous code does not use any of the acceleration 52 | 2011/10/20 loadjson supports JSON collections - concatenated JSON objects 53 | 54 | == JSONlab 0.5.0 (codename: Nexus), FangQ == 55 | 56 | 2011/10/16 package and release jsonlab 0.5.0 57 | 2011/10/15 *add json demo and regression test, support cpx numbers, fix double quote bug 58 | 2011/10/11 *speed up readjson dramatically, interpret _Array* tags, show data in root level 59 | 2011/10/10 create jsonlab project, start jsonlab website, add online documentation 60 | 2011/10/07 *speed up savejson by 25x using sprintf instead of mat2str, add options support 61 | 2011/10/06 *savejson works for structs, cells and arrays 62 | 2011/09/09 derive loadjson from JSON parser from MATLAB Central, draft savejson.m 63 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/jsonlab/LICENSE_BSD.txt: -------------------------------------------------------------------------------- 1 | Copyright 2011-2014 Qianqian Fang . All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are 4 | permitted provided that the following conditions are met: 5 | 6 | 1. 
Redistributions of source code must retain the above copyright notice, this list of 7 | conditions and the following disclaimer. 8 | 9 | 2. Redistributions in binary form must reproduce the above copyright notice, this list 10 | of conditions and the following disclaimer in the documentation and/or other materials 11 | provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ''AS IS'' AND ANY EXPRESS OR IMPLIED 14 | WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 15 | FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS 16 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 18 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 19 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 20 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 21 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 22 | 23 | The views and conclusions contained in the software and documentation are those of the 24 | authors and should not be interpreted as representing official policies, either expressed 25 | or implied, of the copyright holders. 
26 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/jsonlab/README.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charlesq34/frustum-pointnets/2ffdd345e1fce4775ecb508d207e0ad465bcca80/sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/jsonlab/README.txt -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/jsonlab/jsonopt.m: -------------------------------------------------------------------------------- 1 | function val=jsonopt(key,default,varargin) 2 | % 3 | % val=jsonopt(key,default,optstruct) 4 | % 5 | % setting options based on a struct. The struct can be produced 6 | % by varargin2struct from a list of 'param','value' pairs 7 | % 8 | % authors:Qianqian Fang (fangq nmr.mgh.harvard.edu) 9 | % 10 | % $Id: loadjson.m 371 2012-06-20 12:43:06Z fangq $ 11 | % 12 | % input: 13 | % key: a string with which one look up a value from a struct 14 | % default: if the key does not exist, return default 15 | % optstruct: a struct where each sub-field is a key 16 | % 17 | % output: 18 | % val: if key exists, val=optstruct.key; otherwise val=default 19 | % 20 | % license: 21 | % Simplified BSD License 22 | % 23 | % -- this function is part of jsonlab toolbox (http://iso2mesh.sf.net/cgi-bin/index.cgi?jsonlab) 24 | % 25 | 26 | val=default; 27 | if(nargin<=2) return; end 28 | opt=varargin{1}; 29 | if(isstruct(opt) && isfield(opt,key)) 30 | val=getfield(opt,key); 31 | end 32 | 33 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/jsonlab/mergestruct.m: -------------------------------------------------------------------------------- 1 | function s=mergestruct(s1,s2) 2 | % 3 | % s=mergestruct(s1,s2) 4 | % 5 | % merge two struct objects into one 6 | % 7 | % authors:Qianqian Fang (fangq 
nmr.mgh.harvard.edu) 8 | % date: 2012/12/22 9 | % 10 | % input: 11 | % s1,s2: a struct object, s1 and s2 can not be arrays 12 | % 13 | % output: 14 | % s: the merged struct object. fields in s1 and s2 will be combined in s. 15 | % 16 | % license: 17 | % Simplified BSD License 18 | % 19 | % -- this function is part of jsonlab toolbox (http://iso2mesh.sf.net/cgi-bin/index.cgi?jsonlab) 20 | % 21 | 22 | if(~isstruct(s1) || ~isstruct(s2)) 23 | error('input parameters contain non-struct'); 24 | end 25 | if(length(s1)>1 || length(s2)>1) 26 | error('can not merge struct arrays'); 27 | end 28 | fn=fieldnames(s2); 29 | s=s1; 30 | for i=1:length(fn) 31 | s=setfield(s,fn{i},getfield(s2,fn{i})); 32 | end 33 | 34 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/jsonlab/varargin2struct.m: -------------------------------------------------------------------------------- 1 | function opt=varargin2struct(varargin) 2 | % 3 | % opt=varargin2struct('param1',value1,'param2',value2,...) 4 | % or 5 | % opt=varargin2struct(...,optstruct,...) 6 | % 7 | % convert a series of input parameters into a structure 8 | % 9 | % authors:Qianqian Fang (fangq nmr.mgh.harvard.edu) 10 | % date: 2012/12/22 11 | % 12 | % input: 13 | % 'param', value: the input parameters should be pairs of a string and a value 14 | % optstruct: if a parameter is a struct, the fields will be merged to the output struct 15 | % 16 | % output: 17 | % opt: a struct where opt.param1=value1, opt.param2=value2 ... 
18 | % 19 | % license: 20 | % Simplified BSD License 21 | % 22 | % -- this function is part of jsonlab toolbox (http://iso2mesh.sf.net/cgi-bin/index.cgi?jsonlab) 23 | % 24 | 25 | len=length(varargin); 26 | opt=struct; 27 | if(len==0) return; end 28 | i=1; 29 | while(i<=len) 30 | if(isstruct(varargin{i})) 31 | opt=mergestruct(opt,varargin{i}); 32 | elseif(ischar(varargin{i}) && i=6&&size(bb1input,2)<10 13 | bb1 = bb1input; 14 | bb1(:,4:6) = bb1input(:,4:6) + bb1input(:,1:3); 15 | xMax = bb1(:,4); 16 | yMax = bb1(:,5); 17 | bb1 = [bb1(:,1) bb1(:,2) xMax bb1(:,2) xMax yMax bb1(:,1) yMax bb1(:,3) bb1(:,6)]; 18 | elseif size(bb1input,2)==1 19 | for i = 1:nBb1 20 | corners = get_corners_of_bb3d(bb1input(i)); 21 | bb1(i,:) = [reshape([corners(1:4,1) corners(1:4,2)]',1,[]) min(corners([1 end],3)) max(corners([1 end],3))]; 22 | end 23 | elseif size(bb1input,2)>=10 24 | bb1 = bb1input(:,1:10); 25 | end 26 | 27 | for i = 1:nBb2 28 | corners = get_corners_of_bb3d(bb2struct(i)); 29 | bb2(i,:) = [reshape([corners(1:4,1) corners(1:4,2)]',1,[]) min(corners([1 end],3)) max(corners([1 end],3))]; 30 | end 31 | 32 | bb1 = bb1'; 33 | bb2 = bb2'; 34 | 35 | % a ha, we are done with dirty format conversion 36 | 37 | 38 | nBb1 = size(bb1,2); 39 | nBb2 = size(bb2,2); 40 | 41 | volume1 = cuboidVolume(bb1); 42 | volume2 = cuboidVolume(bb2); 43 | intersection = cuboidIntersectionVolume(double(bb1),double(bb2)); 44 | 45 | %{ 46 | volume1(6818) 47 | volume2(1) 48 | intersection(6818,1) 49 | cuboidVolume(bb1(:,6818)) 50 | cuboidVolume(bb2(:,1)) 51 | cuboidIntersectionVolume(bb1(:,6818),bb2(:,1)) 52 | cuboidDraw(bb1(:,6818)) 53 | cuboidDraw(bb2(:,1)) 54 | %} 55 | 56 | union = repmat(volume1',1,nBb2)+repmat(volume2,nBb1,1)-intersection; 57 | 58 | scoreMatrix = intersection ./ union; 59 | end -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/mBB/create_bounding_box_3d.m: 
-------------------------------------------------------------------------------- 1 | % Helper method for quickly creating bounding boxes. 2 | % 3 | % Args: 4 | % basis2d - 2x2 matrix for the basis in the XY plane 5 | % centroid - 1x3 vector for the 3D centroid of the bounding box. 6 | % coeffs - 1x3 vector for the radii in each dimension (x, y, and z) 7 | % 8 | % Returns: 9 | % bb - a bounding box struct. 10 | % 11 | % Author: Nathan Silberman (silberman@cs.nyu.edu) 12 | function bb = create_bounding_box_3d(basis2d, centroid, coeffs) 13 | assert(all(size(basis2d) == [2, 2])); 14 | assert(numel(centroid) == 3); 15 | assert(numel(coeffs) == 3); 16 | 17 | centroid = centroid(:)'; 18 | coeffs = coeffs(:)'; 19 | 20 | bb = struct(); 21 | bb.basis = zeros(3,3); 22 | bb.basis(3,:) = [0 0 1]; 23 | bb.basis(1:2,1:2) = basis2d; 24 | 25 | bb.centroid = centroid; 26 | bb.coeffs = coeffs; 27 | % bb.volume = prod(2 * bb.coeffs); 28 | end -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/mBB/cuboidIntersectionVolume.c: -------------------------------------------------------------------------------- 1 | /* 2 | mex gpc.c cuboidIntersectionVolume.c -O -output cuboidIntersectionVolume % optimized 3 | mex gpc.c cuboidIntersectionVolume.c -argcheck -output cuboidIntersectionVolume % with argument checking 4 | mex gpc.c cuboidIntersectionVolume.c -g -output cuboidIntersectionVolume % for debugging 5 | */ 6 | 7 | #include "mex.h" 8 | #include "gpc.h" 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | /* =============================== 15 | Constants 16 | ===============================*/ 17 | 18 | #define MAX(x, y) (((x) > (y)) ? (x) : (y)) 19 | #define MIN(x, y) (((x) < (y)) ? 
(x) : (y)) 20 | 21 | /* ================================= 22 | GATEWAY ROUTINE TO MATLAB 23 | =================================*/ 24 | 25 | void mexFunction(int nlhs, mxArray *plhs[], 26 | int nrhs, const mxArray *prhs[]) 27 | { 28 | unsigned int i,j,n1,n2,c,v,m; 29 | double* volume; 30 | double* b1; 31 | double* b2; 32 | unsigned int joffset; 33 | double zOverlap; 34 | double areaOverlap; 35 | double* result_vertex; 36 | gpc_polygon subject, clip, result; 37 | gpc_vertex_list subject_contour; 38 | gpc_vertex_list clip_contour; 39 | int hole = 0; 40 | 41 | subject.num_contours = 1; 42 | subject.hole = &hole; 43 | subject.contour = &subject_contour; 44 | subject.contour[0].num_vertices = 4; 45 | 46 | clip.num_contours = 1; 47 | clip.hole = &hole; 48 | clip.contour = &clip_contour; 49 | clip.contour[0].num_vertices = 4; 50 | 51 | n2 = mxGetN(prhs[0]); 52 | n1 = mxGetN(prhs[1]); 53 | b2 = mxGetPr(prhs[0]); 54 | b1 = mxGetPr(prhs[1]); 55 | 56 | plhs[0] = mxCreateNumericMatrix(n2, n1, mxDOUBLE_CLASS, mxREAL); 57 | volume = (double*) mxGetData(plhs[0]); 58 | 59 | for (i=0; i0){ 67 | /* get intersection */ 68 | clip.contour[0].vertex = (gpc_vertex *)(b2+joffset); 69 | 70 | gpc_polygon_clip(1, &subject, &clip, &result); 71 | 72 | 73 | if (result.num_contours>0 && result.contour[0].num_vertices > 2) { 74 | /* compute area of intersection */ 75 | 76 | /* 77 | * http://www.mathopenref.com/coordpolygonarea.html 78 | * Green's theorem for the functions -y and x; 79 | http://stackoverflow.com/questions/451426/how-do-i-calculate-the-surface-area-of-a-2d-polygon 80 | */ 81 | result_vertex = (double*)(result.contour[0].vertex); 82 | m = result.contour[0].num_vertices; 83 | areaOverlap = (result_vertex[2*m-2]*result_vertex[1]-result_vertex[2*m-1]*result_vertex[0]); 84 | for (v= 1; v < m; v++) 85 | { 86 | areaOverlap += (result_vertex[v*2-2]*result_vertex[v*2+1]-result_vertex[v*2-1]*result_vertex[v*2]); 87 | } 88 | *volume = zOverlap * 0.5 * fabs(areaOverlap); 89 | 90 | 91 | } 92 | 
gpc_free_polygon(&result); 93 | } 94 | ++volume; 95 | } 96 | b1+=10; 97 | } 98 | /* 99 | gpc_free_polygon(&subject); 100 | gpc_free_polygon(&clip); 101 | 102 | mxFree(subject.hole); 103 | mxFree(subject.contour); 104 | mxFree(clip.hole); 105 | mxFree(clip.contour); 106 | **/ 107 | 108 | } 109 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/mBB/cuboidIntersectionVolume.mexa64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charlesq34/frustum-pointnets/2ffdd345e1fce4775ecb508d207e0ad465bcca80/sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/mBB/cuboidIntersectionVolume.mexa64 -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/mBB/cuboidIntersectionVolume.mexmaci64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charlesq34/frustum-pointnets/2ffdd345e1fce4775ecb508d207e0ad465bcca80/sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/mBB/cuboidIntersectionVolume.mexmaci64 -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/mBB/cuboidVolume.m: -------------------------------------------------------------------------------- 1 | function volume = cuboidVolume(bb) 2 | 3 | dis = (bb([1 2 5 6],:)-bb([3 4 3 4],:)).^2; 4 | 5 | volume = (bb(10,:)-bb(9,:)).*sqrt((dis(1,:)+dis(2,:)).*(dis(3,:)+dis(4,:))); -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/mBB/get_corners_of_bb3d.m: -------------------------------------------------------------------------------- 1 | % Gets the 3D coordinates of the corners of a 3D bounding box. 2 | % 3 | % Args: 4 | % bb3d - 3D bounding box struct. 5 | % 6 | % Returns: 7 | % corners - 8x3 matrix of 3D coordinates. 
8 | % 9 | % See: 10 | % create_bounding_box_3d.m 11 | % 12 | % Author: Nathan Silberman (silberman@cs.nyu.edu) 13 | function corners = get_corners_of_bb3d(bb3d) 14 | corners = zeros(8, 3); 15 | 16 | % Order the bases. 17 | [~, inds] = sort(abs(bb3d.basis(:,1)), 'descend'); 18 | basis = bb3d.basis(inds, :); 19 | coeffs = bb3d.coeffs(inds); 20 | 21 | [~, inds] = sort(abs(basis(2:3,2)), 'descend'); 22 | if inds(1) == 2 23 | basis(2:3,:) = flip(basis(2:3,:), 1); 24 | coeffs(2:3) = flip(coeffs(2:3)); % flips along the first non-singleton dim, so both row and column coeffs get swapped 25 | end 26 | 27 | % Now, we know the basis vectors are ordered X, Y, Z. Next, flip the basis 28 | % vectors towards the viewer. 29 | basis = flip_towards_viewer(basis, repmat(bb3d.centroid, [3 1])); 30 | 31 | coeffs = abs(coeffs); 32 | 33 | corners(1,:) = -basis(1,:) * coeffs(1) + basis(2,:) * coeffs(2) + basis(3,:) * coeffs(3); 34 | corners(2,:) = basis(1,:) * coeffs(1) + basis(2,:) * coeffs(2) + basis(3,:) * coeffs(3); 35 | corners(3,:) = basis(1,:) * coeffs(1) + -basis(2,:) * coeffs(2) + basis(3,:) * coeffs(3); 36 | corners(4,:) = -basis(1,:) * coeffs(1) + -basis(2,:) * coeffs(2) + basis(3,:) * coeffs(3); 37 | 38 | corners(5,:) = -basis(1,:) * coeffs(1) + basis(2,:) * coeffs(2) + -basis(3,:) * coeffs(3); 39 | corners(6,:) = basis(1,:) * coeffs(1) + basis(2,:) * coeffs(2) + -basis(3,:) * coeffs(3); 40 | corners(7,:) = basis(1,:) * coeffs(1) + -basis(2,:) * coeffs(2) + -basis(3,:) * coeffs(3); 41 | corners(8,:) = -basis(1,:) * coeffs(1) + -basis(2,:) * coeffs(2) + -basis(3,:) * coeffs(3); 42 | 43 | corners = corners + repmat(bb3d.centroid, [8 1]); 44 | end 45 | 46 | function normals = flip_towards_viewer(normals, points) 47 | points = points ./ repmat(sqrt(sum(points.^2, 2)), [1, 3]); 48 | 49 | proj = sum(points .* normals, 2); 50 | 51 | toFlip = proj > 0; % rows with a positive dot product against the unit centroid direction are negated 52 | normals(toFlip, :) = -normals(toFlip, :); 53 | end 54 | 
-------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/mBB/project3dPtsTo2d.m:
-------------------------------------------------------------------------------- 1 | function [points2d,z3] = project3dPtsTo2d(points3d,Rtilt,crop,K) 2 | %% inverse of get_aligned_point_cloud 3 | points3d =[Rtilt'*points3d']'; 4 | 5 | %% inverse rgb_plane2rgb_world 6 | 7 | 8 | % Now, swap Y and Z. 9 | points3d(:, [2, 3]) = points3d(:,[3, 2]); 10 | 11 | % Make the original consistent with the camera location: 12 | x3 = points3d(:,1); 13 | y3 = -points3d(:,2); 14 | z3 = points3d(:,3); 15 | 16 | xx = x3 * K(1,1) ./ z3 + K(1,3); 17 | yy = y3 * K(2,2) ./ z3 + K(2,3); 18 | 19 | 20 | xx = xx - crop(2) + 1; 21 | yy = yy - crop(1) + 1; 22 | 23 | points2d = [xx yy]; 24 | end -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/mBB/projectStructBbsTo2d.m: -------------------------------------------------------------------------------- 1 | function [bb2d,bb2dDraw] = projectStructBbsTo2d(bb,Rtilt,crop,K) 2 | if isempty(crop) 3 | crop =[1,1]; 4 | end 5 | if isempty(bb), 6 | bb2d =[]; 7 | bb2dDraw =[]; 8 | else 9 | nBbs = numel(bb); 10 | if isfield(bb,'confidence'), 11 | conf = [bb.confidence]; 12 | conf = conf(:); 13 | else 14 | conf = ones(nBbs,1); 15 | end 16 | points3d = zeros(8*nBbs,3); 17 | for i = 1:nBbs, 18 | corners = get_corners_of_bb3d(bb(i)); 19 | points3d((i-1)*8+(1:8),:) = corners([8 4 5 1 7 3 6 2],:); 20 | %points3d((i-1)*8+(1:8),:) = corners([5,1,8,4,6,2,3,7],:); 21 | end 22 | 23 | points2d = project3dPtsTo2d(points3d,Rtilt,crop,K); 24 | 25 | bb2d = zeros(nBbs,5); 26 | bb2d(:,1) = min(reshape(points2d(:,1),[8,nBbs]),[],1); 27 | bb2d(:,2) = min(reshape(points2d(:,2),[8,nBbs]),[],1); 28 | bb2d(:,3) = max(reshape(points2d(:,1),[8,nBbs]),[],1); 29 | bb2d(:,4) = max(reshape(points2d(:,2),[8,nBbs]),[],1); 30 | bb2d(:,3) = bb2d(:,3) - bb2d(:,1); 31 | bb2d(:,4) = bb2d(:,4) - bb2d(:,2); 32 | bb2d(:,5) = conf; 33 | 34 | bb2dDraw = zeros(nBbs,17); 35 | pts = points2d'; 36 | pts = pts(:); 37 | pts = 
reshape(pts,[16,nBbs]); 38 | bb2dDraw(:,1:16) = pts'; 39 | bb2dDraw(:,17) = conf; 40 | end 41 | end -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/order_basis.m: -------------------------------------------------------------------------------- 1 | function [new_basis, new_coeffs] = order_basis(basis, coeffs, centroid) 2 | % Order the bases (rows of basis) as X, Y, Z, permuting coeffs alongside, then flip the rows towards the viewer. 3 | [~, inds] = sort(abs(basis(:,1)), 'descend'); 4 | basis = basis(inds, :); 5 | coeffs = coeffs(inds); 6 | 7 | [~, inds] = sort(abs(basis(2:3,2)), 'descend'); 8 | if inds(1) == 2 9 | basis(2:3,:) = flip(basis(2:3,:), 1); 10 | coeffs(2:3) = flip(coeffs(2:3)); % flips along the first non-singleton dim, so both row and column coeffs get swapped 11 | end 12 | 13 | % Now, we know the basis vectors are ordered X, Y, Z. Next, flip the basis 14 | % vectors towards the viewer. 15 | new_basis = flip_towards_viewer(basis, repmat(centroid, [3 1])); 16 | new_coeffs = coeffs; 17 | end 18 | 19 | function normals = flip_towards_viewer(normals, points) 20 | points = points ./ repmat(sqrt(sum(points.^2, 2)), [1, 3]); 21 | 22 | proj = sum(points .* normals, 2); 23 | 24 | toFlip = proj > 0; % rows with a positive dot product against the unit centroid direction are negated 25 | normals(toFlip, :) = -normals(toFlip, :); 26 | end 27 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/readData/read3dPoints.m: -------------------------------------------------------------------------------- 1 | function [rgb,points3d,depthInpaint,imsize]=read3dPoints(data) 2 | depthVis = imread(data.depthpath); 3 | imsize = size(depthVis); 4 | depthInpaint = bitor(bitshift(depthVis,-3), bitshift(depthVis,16-3)); 5 | depthInpaint = single(depthInpaint)/1000; 6 | depthInpaint(depthInpaint >8)=8; 7 | 
[rgb,points3d]=read_3d_pts_general(depthInpaint,data.K,size(depthInpaint),data.rgbpath); 8 | points3d = (data.Rtilt*points3d')'; 9 | end --------------------------------------------------------------------------------
/sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/readData/read_3d_pts_general.m: -------------------------------------------------------------------------------- 1 | function [rgb,points3d,points3dMatrix]=read_3d_pts_general(depthInpaint,K,depthInpaintsize,imageName,crop) 2 | % K is [fx 0 cx; 0 fy cy; 0 0 1]; 3 | % for uncrop image crop =[1,1]; (note: crop is accepted but not used by this function) 4 | % imageName is the full path to image; if empty, every pixel is colored green [0 1 0]. Pixels with zero depth come back as NaN points. 5 | cx = K(1,3); cy = K(2,3); 6 | fx = K(1,1); fy = K(2,2); 7 | invalid = depthInpaint==0; 8 | if ~isempty(imageName) 9 | im = imread(imageName); 10 | rgb = im2double(im); if size(rgb,3) == 1, rgb = repmat(rgb, [1 1 3]); end % grayscale image: replicate to 3 channels so the reshape below yields one RGB row per pixel 11 | else 12 | rgb =double(cat(3,zeros(depthInpaintsize(1),depthInpaintsize(2)),... 13 | ones(depthInpaintsize(1),depthInpaintsize(2)),... 14 | zeros(depthInpaintsize(1),depthInpaintsize(2)))); 15 | end 16 | rgb = reshape(rgb, [], 3); 17 | %3D points 18 | [x,y] = meshgrid(1:depthInpaintsize(2), 1:depthInpaintsize(1)); 19 | x3 = (x-cx).*depthInpaint*1/fx; 20 | y3 = (y-cy).*depthInpaint*1/fy; 21 | z3 = depthInpaint; 22 | points3dMatrix =cat(3,x3,z3,-y3); 23 | points3dMatrix(cat(3,invalid,invalid,invalid))=NaN; 24 | points3d = [x3(:) z3(:) -y3(:)]; 25 | points3d(invalid(:),:) =NaN; 26 | end -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/readframeSUNRGBD.m: -------------------------------------------------------------------------------- 1 | function data = readframeSUNRGBD(thispath,dataRoot,cls,bbmode) 2 | % example code to read annotation from ".json" file. 3 | % thispath: full path to the data folder. 4 | % dataRoot: root directory of all data folder. 5 | % cls : object category of ground truth to load. 
If not specified, the code will load all ground truth.
6 | if ~exist('cls','var') 7 | cls =[]; 8 | end 9 | if ~exist('bbmode','var') 10 | bbmode ='2Dbb'; 11 | end 12 | 13 | if ~exist('dataRoot','var')||isempty(dataRoot) 14 | dataRoot = '/n/fs/sun3d/data/'; 15 | end 16 | sequenceName = getSequenceName(thispath,dataRoot); 17 | if ~exist(thispath,'dir') 18 | data.sequenceName = sequenceName; 19 | data.valid = 0; 20 | return; 21 | end 22 | indd = find(sequenceName=='/'); 23 | sensorType = sequenceName(indd(1)+1:indd(2)-1); 24 | % get K 25 | fID = fopen([thispath '/intrinsics.txt'],'r'); 26 | K = reshape(fscanf(fID,'%f'),[3,3])'; 27 | fclose(fID); 28 | 29 | % get image and depth path 30 | depthpath = dir([thispath '/depth/' '/*.png']); 31 | depthname = depthpath(1).name; 32 | depthpath = [thispath '/depth/' depthpath(1).name]; 33 | 34 | rgbpath = dir([thispath '/image/' '/*.jpg']); 35 | rgbname = rgbpath(1).name; 36 | rgbpath = [thispath '/image/' rgbpath(1).name]; 37 | 38 | if exist(sprintf('%s/annotation3Dfinal/index.json',thispath),'file') 39 | annoteImage =loadjson(sprintf('%s/annotation3Dfinal/index.json',thispath)); 40 | % get Box 41 | filename = dir([fullfile(thispath,'extrinsics') '/*.txt']); 42 | Rtilt = dlmread([fullfile(thispath,'extrinsics') '/' filename(end).name]); 43 | Rtilt = Rtilt(1:3,1:3); 44 | anno_extrinsics = Rtilt; 45 | % convert it into matlab coordinate 46 | Rtilt = [1 0 0; 0 0 1 ;0 -1 0]*Rtilt*[1 0 0; 0 0 -1 ;0 1 0]; 47 | 48 | 49 | cnt =1; 50 | for obji =1:length(annoteImage.objects) 51 | annoteobject =annoteImage.objects(obji); 52 | if ~isempty(annoteobject)&&~isempty(annoteobject{1})&&~isempty(annoteobject{1}.polygon) 53 | annoteobject =annoteobject{1}; 54 | box = annoteobject.polygon{1}; 55 | 56 | % class name and label 57 | ind = find(annoteobject.name==':'); 58 | if isempty(ind) 59 | classname = annoteobject.name; 60 | labelname =''; 61 | else 62 | if ismember(annoteobject.name(ind-1),{'_',' '}), 63 | clname = annoteobject.name(1:ind-2); 64 | else 65 | clname = annoteobject.name(1:ind-1); 66 | 
end 67 | %[~,classId]= ismember(clname,classNames); 68 | classname = clname; 69 | labelname = annoteobject.name(ind+2:end); 70 | %[~,label]= ismember(Labelname,labelNames); 71 | end 72 | if ismember(classname,{'wall','floor','ceiling'})||(~isempty(cls)&&~(sum(ismember(cls,{classname}))>0)), 73 | continue; 74 | end 75 | 76 | 77 | x =box.X; 78 | y =box.Z; 79 | vector1 =[x(2)-x(1),y(2)-y(1),0]; 80 | coeff1 =norm(vector1); 81 | vector1 =vector1/norm(vector1); 82 | vector2 =[x(3)-x(2),y(3)-y(2),0]; 83 | coeff2 = norm(vector2); 84 | vector2 =vector2/norm(vector2); 85 | up = cross(vector1,vector2); 86 | vector1 = vector1*up(3)/up(3); 87 | vector2 = vector2*up(3)/up(3); 88 | zmax =-box.Ymax; 89 | zmin =-box.Ymin; 90 | centroid2D = [0.5*(x(1)+x(3)); 0.5*(y(1)+y(3))]; 91 | 92 | thisbb.basis = [vector1;vector2; 0 0 1]; % one row is one basis 93 | thisbb.coeffs = abs([coeff1, coeff2, zmax-zmin])/2; 94 | thisbb.centroid = [centroid2D(1), centroid2D(2), 0.5*(zmin+zmax)]; 95 | thisbb.classname = classname; 96 | thisbb.labelname = labelname; 97 | thisbb.sequenceName = sequenceName; 98 | orientation = [([0.5*(x(2)+x(1)),0.5*(y(2)+y(1))] - centroid2D(:)'), 0]; 99 | thisbb.orientation = orientation/norm(orientation); 100 | 101 | if strcmp(bbmode,'2Dbb'), 102 | [bb2d,bb2dDraw] = projectStructBbsTo2d(thisbb,Rtilt,[],K); 103 | %gtBb2D = crop2DBB(gtBb2D,427,561); 104 | thisbb.gtBb2D = bb2d(1:4); 105 | end 106 | groundtruth3DBB(cnt) =thisbb; 107 | cnt=cnt+1; 108 | end 109 | end 110 | if cnt==1,groundtruth3DBB =[];end 111 | else 112 | groundtruth3DBB =[]; 113 | filename = dir([fullfile(thispath,'extrinsics') '/*.txt']); 114 | Rtilt = dlmread([fullfile(thispath,'extrinsics') '/' filename(end).name]); 115 | Rtilt = Rtilt(1:3,1:3); 116 | anno_extrinsics = Rtilt; 117 | Rtilt = [1 0 0; 0 0 1 ;0 -1 0]*Rtilt*[1 0 0; 0 0 -1 ;0 1 0]; 118 | 119 | end 120 | % read in room 121 | gtCorner3D =[]; 122 | if exist([thispath '/annotation3Dlayout/index.json'],'file') 123 | json=loadjson([thispath 
'/annotation3Dlayout/index.json']); 124 | for objectID=1:length(json.objects) 125 | try 126 | groundTruth = json.objects{objectID}.polygon{1}; 127 | numCorners = length(groundTruth.X); 128 | 129 | gtCorner3D(1,:) = [groundTruth.X groundTruth.X]; 130 | gtCorner3D(2,:) = [repmat(groundTruth.Ymin,[1 numCorners]) repmat(groundTruth.Ymax,[1 numCorners])]; 131 | gtCorner3D(3,:) = [groundTruth.Z groundTruth.Z]; 132 | gtCorner3D = anno_extrinsics'*gtCorner3D; 133 | gtCorner3D = gtCorner3D([1,3,2],:); 134 | gtCorner3D(3,:) = -1*gtCorner3D(3,:); 135 | gtCorner3D = Rtilt*gtCorner3D; 136 | break; 137 | catch 138 | end 139 | end 140 | 141 | end 142 | 143 | data =struct('sequenceName',sequenceName,'groundtruth3DBB',... 144 | groundtruth3DBB,'Rtilt',Rtilt,'K',K,... 145 | 'depthpath',depthpath,'rgbpath',rgbpath,'anno_extrinsics',anno_extrinsics,'depthname',depthname,... 146 | 'rgbname',rgbname,'sensorType',sensorType,'valid',1,'gtCorner3D',gtCorner3D); 147 | 148 | end -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/utils/file2string.m: -------------------------------------------------------------------------------- 1 | function fileStr = file2string(fname) 2 | fileStr = ''; 3 | fid = fopen(fname,'r'); 4 | tline = fgetl(fid); 5 | while ischar(tline) 6 | fileStr = [fileStr sprintf('\n') tline]; 7 | tline = fgetl(fid); 8 | end 9 | fclose(fid); 10 | end -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/SUNRGBDtoolbox/utils/findsubstring.m: -------------------------------------------------------------------------------- 1 | function sustr = findsubstring(str,strartstr,endstr) 2 | [ind1,ind2]=regexp(str, strartstr); 3 | str = str(ind2+1:end); 4 | [ind1,ind2]=regexp(str, endstr); 5 | sustr = str(1:ind1-1); 6 | end -------------------------------------------------------------------------------- 
/sunrgbd/sunrgbd_data/matlab/detection/benchmark_groundtruth.m: -------------------------------------------------------------------------------- 1 | function [groundtruthall,all_sequenceName] = benchmark_groundtruth(cls,path2gt,path2testAll) 2 | % get the ground truth box for this class 3 | try 4 | a = load(fullfile(path2gt,'groundtruth.mat')); 5 | catch 6 | path2gt = '/n/fs/modelnet/SUN3DV2/prepareGT/Metadata/'; 7 | a = load(fullfile(path2gt,'groundtruth.mat')); 8 | end 9 | 10 | if ~isempty(cls) 11 | pick = ismember({a.groundtruth.classname},cls); 12 | groundtruthall = a.groundtruth(pick); 13 | else 14 | groundtruthall = a.groundtruth; 15 | end 16 | 17 | if exist('path2testAll','var')&&~isempty(path2testAll) 18 | all_sequenceName = cell(1,length(path2testAll)); 19 | for i =1:length(all_sequenceName) 20 | all_sequenceName{i} = getSequenceName(path2testAll{i},'/data/rqi/SUNRGBD/'); 21 | end 22 | % hash ground truth image id and get the valid GT 23 | [validGT,GTimageid] = ismember({groundtruthall.sequenceName},all_sequenceName); 24 | groundtruthall = groundtruthall(validGT); 25 | GTimageid = GTimageid(validGT); 26 | for i =1:length(groundtruthall) 27 | groundtruthall(i).imageNum = GTimageid(i); 28 | end 29 | 30 | end 31 | end 32 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/detection/computePRCurve3D.m: -------------------------------------------------------------------------------- 1 | function [apScore,precision,reccall,isTp,isFp,isMissed,gtAssignment,maxOverlaps,numOfgt,gtIdxAll,allOverlaps] ... 
2 | = computePRCurveTightBB(classname,predictedBbsTight,imageIds,groundTruthBbs,isDifficult) 3 | if length(groundTruthBbs) ~= length(isDifficult), error('inconsistent difficulty size.'); end 4 | if size(groundTruthBbs,1) ==1,groundTruthBbs= groundTruthBbs';end 5 | if size(isDifficult,1) ==1,isDifficult= isDifficult';end 6 | 7 | gtAssignment = zeros(length(predictedBbsTight),1); 8 | 9 | % pick ground truth of current class 10 | isMissed = false(numel(groundTruthBbs),1); 11 | isSameClass = ismember({groundTruthBbs.classname},classname); 12 | numOfgt =sum(isSameClass); 13 | groundTruthBbs = groundTruthBbs(isSameClass); 14 | isDifficult = logical(isDifficult(isSameClass)); 15 | 16 | P = size(predictedBbsTight,1); 17 | G = numel(groundTruthBbs); 18 | 19 | % sort all detections 20 | [~,sortIdx] = sort([predictedBbsTight.confidence],'descend'); 21 | %predictedBbs = predictedBbs(sortIdx,:); 22 | predictedBbsTight =predictedBbsTight(sortIdx,:); 23 | imageIds = imageIds(sortIdx); 24 | % threshold overlap 25 | tic; 26 | %predictedBbs =predictedBbs(:,1:6); 27 | %allOverlaps = bb3dOverlapCloseForm(predictedBbs,groundTruthBbs); 28 | allOverlaps = bb3dOverlapCloseForm(predictedBbsTight,groundTruthBbs); 29 | toc; 30 | onSameImage = bsxfun(@eq,imageIds(:),[groundTruthBbs.imageNum]); 31 | allOverlaps(~onSameImage) = 0; 32 | [maxOverlaps,gtIdx] = max(allOverlaps,[],2); 33 | gtIdx(maxOverlaps= 0.25; 36 | gtIdx(~isOverlapping) = 0; 37 | 38 | % Assign ground truth to the best matched detection 39 | [uniqueGtIdx,firstAssignment,~] = unique(gtIdx,'first'); 40 | isFirstAssignment = false(P,1); 41 | isFirstAssignment(firstAssignment) = true; 42 | isFirstAssignment = isFirstAssignment & gtIdx>0; 43 | 44 | % Get GT assignment of each TP detection 45 | tmpGtAssignment = zeros(P,1); 46 | tmpGtAssignment(firstAssignment) = uniqueGtIdx; 47 | tmp = find(isSameClass); 48 | tmpGtAssignment(tmpGtAssignment>0) = tmp(tmpGtAssignment(tmpGtAssignment>0)); 49 | gtAssignment(sortIdx) = tmpGtAssignment; 50 
| 51 | % % Assign best matched conf to ground truth 52 | % tmpGtConf = -1e10 * ones(G,1); 53 | % if uniqueGtIdx(1) == 0, 54 | % tmpGtConf(uniqueGtIdx(2:end)) = predictedBbs(firstAssignment(2:end),7); 55 | % else 56 | % tmpGtConf(uniqueGtIdx) = predictedBbs(firstAssignment,7); 57 | % end 58 | % gtConf(isSameClass) = tmpGtConf; 59 | % isMissed = gtConf < lowConfThresh; 60 | 61 | % get unassigned ground truth bounding boxes 62 | tmpIsMissed = true(G,1); 63 | tmpIsMissed(setdiff(uniqueGtIdx,0)) = false; 64 | isMissed(isSameClass) = tmpIsMissed; 65 | 66 | % assign detection labels 67 | tp = isFirstAssignment & isOverlapping; 68 | fp = ~tp; 69 | dc = gtIdx ~= 0 & isDifficult(max(1,gtIdx)); 70 | tp(dc) = false; 71 | fp(dc) = false; 72 | isTp = false(numel(tp),1); 73 | isTp(sortIdx) = tp; 74 | isFp = false(numel(fp),1); 75 | isFp(sortIdx) = fp; 76 | 77 | % Compute precision/recall 78 | sumFp = cumsum(double(fp)); 79 | sumTp = cumsum(double(tp)); 80 | reccall = sumTp / sum(~isDifficult); 81 | precision = sumTp ./ (sumFp + sumTp); 82 | 83 | apScore = get_average_precision(precision, reccall); 84 | end -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/detection/extract_gt_boxes.m: -------------------------------------------------------------------------------- 1 | %% Extract GT boxes 2 | 3 | clear all 4 | toolboxpath = '/afs/cs.stanford.edu/u/rqi/Data/SUNRGBD/SUNRGBDtoolbox'; 5 | addpath(genpath(toolboxpath)); 6 | 7 | 8 | %load('./chair_demo.mat','allTestImgIds','allBb3dtight') 9 | for className = {'bed','table','sofa','chair','toilet','desk','dresser','night_stand','bookshelf','bathtub'} 10 | clear('bbs'); 11 | clear('imgids'); 12 | split = load(fullfile(toolboxpath,'/traintestSUNRGBD/allsplit.mat')); 13 | testset_path = split.alltest; 14 | 15 | for i = 1:length(testset_path) 16 | testset_path{i}(1:16) = ''; 17 | testset_path{i} = strcat('/data/rqi/SUNRGBD', testset_path{i}); 18 | end 19 | 
[groundTruthBbs,all_sequenceName] = benchmark_groundtruth(className,fullfile(toolboxpath,'Metadata/'),testset_path); 20 | 21 | nBb = length(groundTruthBbs); 22 | for i = 1:nBb 23 | corners = get_corners_of_bb3d(groundTruthBbs(i)); 24 | bbs(i,:) = [reshape([corners(1:4,1) corners(1:4,2)]',1,[]) min(corners([1 end],3)) max(corners([1 end],3))]; 25 | imgids(i) = groundTruthBbs(i).imageNum; 26 | end 27 | 28 | dlmwrite(strcat(className{1}, '_gt_boxes.dat'), bbs, 'delimiter', ' '); 29 | dlmwrite(strcat(className{1}, '_gt_imgids.txt'), imgids, 'delimiter', ' '); 30 | end 31 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/detection/get_average_precision.m: -------------------------------------------------------------------------------- 1 | % Returns the average precision score given precision and recall vectors. The AP is computed by 2 | % numerically integrating the area under the PR curve. 3 | % 4 | % Note: this code was taken from the VOC2011 toolkit. 5 | % 6 | % Args: 7 | % precision - Px1 vector of precision scores, where P is the number of predictions. Note that the 8 | % precision scores must be monotonically decreasing. 9 | % recall - Px1 vector of recall scores, where P is the number of predictions. 10 | % 11 | % Returns: 12 | % ap - the average precision. 
13 | function ap = get_average_precision(precision, recall) 14 | 15 | mrec = [0; recall; 1]; 16 | mpre = [0; precision; 0]; 17 | 18 | for ii = numel(mpre) - 1 : -1 : 1 19 | mpre(ii) = max(mpre(ii), mpre(ii+1)); 20 | end 21 | 22 | ii = find(mrec(2:end) ~= mrec(1:end-1)) + 1; 23 | ap = sum((mrec(ii) - mrec(ii-1)) .* mpre(ii)); 24 | end -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_data/matlab/detection/script_3Deval.m: -------------------------------------------------------------------------------- 1 | % The result file should contains following feild 2 | % allTestImgIds: Nx1 array of testing image's id in "alltest" for each box 3 | % allBb3dtight : Nx1 cell 3D bounding box strcture 4 | clear all 5 | toolboxpath = '/afs/cs.stanford.edu/u/rqi/Data/SUNRGBD/SUNRGBDtoolbox'; 6 | addpath(genpath(toolboxpath)); 7 | 8 | 9 | %load('./chair_demo.mat','allTestImgIds','allBb3dtight') 10 | className ='chair'; 11 | load('./chair_demo.mat','allTestImgIds','allBb3dtight') 12 | %load('../exampleresult_bathtub.mat'); 13 | %className ='bathtub'; 14 | split = load(fullfile(toolboxpath,'/traintestSUNRGBD/allsplit.mat')); 15 | testset_path = split.alltest; 16 | 17 | for i = 1:length(testset_path) 18 | testset_path{i}(1:16) = ''; 19 | testset_path{i} = strcat('/data/rqi/SUNRGBD', testset_path{i}); 20 | end 21 | [groundTruthBbs,all_sequenceName] = benchmark_groundtruth(className,fullfile(toolboxpath,'Metadata/'),testset_path); 22 | [apScore,precision,recall,isTp,isFp,isMissed,gtAssignment,maxOverlaps] = computePRCurve3D(className,allBb3dtight,allTestImgIds,groundTruthBbs,zeros(length(groundTruthBbs),1)); 23 | result_all = struct('apScore',apScore,'precision',precision,'recall',recall,'isTp',isTp,'isFp',isFp,'isMissed',isMissed,'gtAssignment',gtAssignment); 24 | 25 | figure, 26 | plot(recall,precision) 27 | title(className) -------------------------------------------------------------------------------- 
/sunrgbd/sunrgbd_detection/ap_curves/figure_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charlesq34/frustum-pointnets/2ffdd345e1fce4775ecb508d207e0ad465bcca80/sunrgbd/sunrgbd_detection/ap_curves/figure_1.png -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/ap_curves/figure_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charlesq34/frustum-pointnets/2ffdd345e1fce4775ecb508d207e0ad465bcca80/sunrgbd/sunrgbd_detection/ap_curves/figure_10.png -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/ap_curves/figure_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charlesq34/frustum-pointnets/2ffdd345e1fce4775ecb508d207e0ad465bcca80/sunrgbd/sunrgbd_detection/ap_curves/figure_2.png -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/ap_curves/figure_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charlesq34/frustum-pointnets/2ffdd345e1fce4775ecb508d207e0ad465bcca80/sunrgbd/sunrgbd_detection/ap_curves/figure_3.png -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/ap_curves/figure_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charlesq34/frustum-pointnets/2ffdd345e1fce4775ecb508d207e0ad465bcca80/sunrgbd/sunrgbd_detection/ap_curves/figure_4.png -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/ap_curves/figure_5.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/charlesq34/frustum-pointnets/2ffdd345e1fce4775ecb508d207e0ad465bcca80/sunrgbd/sunrgbd_detection/ap_curves/figure_5.png -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/ap_curves/figure_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charlesq34/frustum-pointnets/2ffdd345e1fce4775ecb508d207e0ad465bcca80/sunrgbd/sunrgbd_detection/ap_curves/figure_6.png -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/ap_curves/figure_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charlesq34/frustum-pointnets/2ffdd345e1fce4775ecb508d207e0ad465bcca80/sunrgbd/sunrgbd_detection/ap_curves/figure_7.png -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/ap_curves/figure_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charlesq34/frustum-pointnets/2ffdd345e1fce4775ecb508d207e0ad465bcca80/sunrgbd/sunrgbd_detection/ap_curves/figure_8.png -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/ap_curves/figure_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charlesq34/frustum-pointnets/2ffdd345e1fce4775ecb508d207e0ad465bcca80/sunrgbd/sunrgbd_detection/ap_curves/figure_9.png -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/compare_matlab_and_python_eval.py: -------------------------------------------------------------------------------- 1 | """ Compare MATLAB and Python eval code on AP computation """ 2 | import cPickle as pickle 3 | import numpy as np 4 | import argparse 5 | from PIL import 
Image 6 | import cv2 7 | import sys 8 | import os 9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | sys.path.append(os.path.join(BASE_DIR, '../sunrgbd_data')) 11 | from sunrgbd_data import sunrgbd_object 12 | from utils import rotz, compute_box_3d, load_zipped_pickle 13 | sys.path.append(os.path.join(BASE_DIR, '../../train')) 14 | from box_util import box3d_iou, is_clockwise 15 | import roi_seg_box3d_dataset 16 | from roi_seg_box3d_dataset import rotate_pc_along_y, NUM_HEADING_BIN 17 | from eval_det import eval_det_cls 18 | 19 | root_dir = '/home/rqi/Data/detection' 20 | gt_boxes_dir = '/home/rqi/Projects/kitti-challenge/sunrgbd_detection/gt_boxes' 21 | 22 | def flip_axis_to_camera(pc): 23 | ''' Flip X-right,Y-forward,Z-up to X-right,Y-down,Z-forward 24 | Input and output are both (N,3) array 25 | ''' 26 | pc2 = np.copy(pc) 27 | pc2[:,[0,1,2]] = pc2[:,[0,2,1]] # cam X,Y,Z = depth X,-Z,Y 28 | pc2[:,1] *= -1 29 | return pc2 30 | 31 | def box_conversion(bbox): 32 | """ In upright depth camera coord """ 33 | bbox3d = np.zeros((8,3)) 34 | # Make clockwise 35 | # NOTE: in box3d IoU evaluation we require the polygon vertices in 36 | # counter clockwise order. However, from dumped data in MATLAB 37 | # some of the polygons are in clockwise, some others are counter clockwise 38 | # so we need to inspect each box and make them consistent.. 
39 | xy = np.reshape(bbox[0:8], (4,2)) 40 | if is_clockwise(xy): 41 | bbox3d[0:4,0:2] = xy 42 | bbox3d[4:,0:2] = xy 43 | else: 44 | bbox3d[0:4,0:2] = xy[::-1,:] 45 | bbox3d[4:,0:2] = xy[::-1,:] 46 | bbox3d[0:4,2] = bbox[9] # zmax 47 | bbox3d[4:,2] = bbox[8] # zmin 48 | return bbox3d 49 | 50 | def wrapper(bbox): 51 | bbox3d = box_conversion(bbox) 52 | bbox3d = flip_axis_to_camera(bbox3d) 53 | bbox3d_flipped = np.copy(bbox3d) 54 | bbox3d_flipped[0:4,:] = bbox3d[4:,:] 55 | bbox3d_flipped[4:,:] = bbox3d[0:4,:] 56 | return bbox3d_flipped 57 | 58 | def get_gt_cls(classname): 59 | gt = {} 60 | gt_boxes = np.loadtxt(os.path.join(gt_boxes_dir, '%s_gt_boxes.dat'%(classname))) 61 | gt_imgids = np.loadtxt(os.path.join(gt_boxes_dir, '%s_gt_imgids.txt'%(classname))) 62 | print gt_boxes.shape 63 | print gt_imgids.shape 64 | for i in range(len(gt_imgids)): 65 | imgid = gt_imgids[i] 66 | bbox = gt_boxes[i] 67 | bbox3d = wrapper(bbox) 68 | 69 | if imgid not in gt: 70 | gt[imgid] = [] 71 | gt[imgid].append(bbox3d) 72 | return gt 73 | 74 | if __name__=='__main__': 75 | #gt_boxes = np.loadtxt(os.path.join(gt_boxes_dir, 'chair_gt_boxes.dat')) 76 | #gt_imgids = np.loadtxt(os.path.join(gt_boxes_dir, 'chair_gt_imgids.txt')) 77 | pred_boxes = np.transpose(np.loadtxt(os.path.join(root_dir, 'chair_pred_boxes.dat'))) 78 | pred_imgids = np.loadtxt(os.path.join(root_dir, 'chair_pred_imgids.txt')) 79 | pred_confidence = np.loadtxt(os.path.join(root_dir, 'chair_pred_confidence.txt')) 80 | 81 | pred = {} 82 | ovthresh = 0.25 83 | 84 | print pred_boxes.shape 85 | 86 | for i in range(0,10000): 87 | imgid = pred_imgids[i] 88 | score = pred_confidence[i] 89 | bbox = pred_boxes[i] 90 | bbox3d = wrapper(bbox) 91 | 92 | if imgid not in pred: 93 | pred[imgid] = [] 94 | pred[imgid].append((bbox3d, score)) 95 | 96 | gt = get_gt_cls('chair') 97 | 98 | # ================================================================================= 99 | """ 100 | import cPickle as pickle 101 | from PIL import Image 102 | 
import cv2 103 | import roi_seg_box3d_dataset 104 | sys.path.append('../sunrgbd_data') 105 | from sunrgbd_data import sunrgbd_object 106 | from utils import rotz, compute_box_3d, load_zipped_pickle 107 | 108 | IMG_DIR = '/home/rqi/Data/mysunrgbd/training/image' 109 | TEST_DATASET = roi_seg_box3d_dataset.ROISegBoxDataset(npoints=2048, split='val', rotate_to_center=True, overwritten_data_path='val_1002.zip.pickle', from_rgb_detection=False) 110 | dataset = sunrgbd_object('/home/rqi/Data/mysunrgbd', 'training') 111 | 112 | # For detection evaluation 113 | gt = {} 114 | 115 | # Get GT boxes 116 | print 'Construct GT boxes...' 117 | for i in range(len(TEST_DATASET)): 118 | img_id = TEST_DATASET.id_list[i] 119 | if img_id in gt: continue # All ready counted.. 120 | gt[img_id] = [] 121 | 122 | objects = dataset.get_label_objects(img_id) 123 | calib = dataset.get_calibration(img_id) 124 | for obj in objects: 125 | if obj.classname != 'chair': continue 126 | box3d_pts_2d, box3d_pts_3d = compute_box_3d(obj, calib) 127 | box3d_pts_3d = calib.project_upright_depth_to_upright_camera(box3d_pts_3d) 128 | box3d_pts_3d_flipped = np.copy(box3d_pts_3d) 129 | box3d_pts_3d_flipped[0:4,:] = box3d_pts_3d[4:,:] 130 | box3d_pts_3d_flipped[4:,:] = box3d_pts_3d[0:4,:] 131 | gt[img_id].append(box3d_pts_3d_flipped) 132 | """ 133 | # ==================================================================================== 134 | 135 | 136 | import matplotlib.pyplot as plt 137 | rec, prec, ap = eval_det_cls(pred, gt, ovthresh) 138 | print prec[0:100] 139 | print rec[0:100] 140 | 141 | plt.plot(rec, prec, lw=2) 142 | fig = plt.gcf() 143 | fig.subplots_adjust(bottom=0.25) 144 | plt.xlim([0.0, 0.16]) 145 | plt.ylim([0.0, 1.05]) 146 | plt.xlabel('Recall') 147 | plt.ylabel('Precision') 148 | plt.show() 149 | 150 | print ap 151 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/eval_det.py: 
-------------------------------------------------------------------------------- 1 | """ Generic Code for Object Detection Evaluation 2 | 3 | Input: 4 | For each class: 5 | For each image: 6 | Predictions: box, score 7 | Groundtruths: box 8 | 9 | Output: 10 | For each class: 11 | precision-recal and average precision 12 | 13 | Author: Charles R. Qi 14 | Date: Oct 4th 2017 15 | 16 | Ref: https://raw.githubusercontent.com/rbgirshick/py-faster-rcnn/master/lib/datasets/voc_eval.py 17 | 18 | Author: Charles R. Qi 19 | Date: October, 2017 20 | """ 21 | import numpy as np 22 | 23 | def voc_ap(rec, prec, use_07_metric=False): 24 | """ ap = voc_ap(rec, prec, [use_07_metric]) 25 | Compute VOC AP given precision and recall. 26 | If use_07_metric is true, uses the 27 | VOC 07 11 point method (default:False). 28 | """ 29 | if use_07_metric: 30 | # 11 point metric 31 | ap = 0. 32 | for t in np.arange(0., 1.1, 0.1): 33 | if np.sum(rec >= t) == 0: 34 | p = 0 35 | else: 36 | p = np.max(prec[rec >= t]) 37 | ap = ap + p / 11. 38 | else: 39 | # correct AP calculation 40 | # first append sentinel values at the end 41 | mrec = np.concatenate(([0.], rec, [1.])) 42 | mpre = np.concatenate(([0.], prec, [0.])) 43 | 44 | # compute the precision envelope 45 | for i in range(mpre.size - 1, 0, -1): 46 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 47 | 48 | # to calculate area under PR curve, look for points 49 | # where X axis (recall) changes value 50 | i = np.where(mrec[1:] != mrec[:-1])[0] 51 | 52 | # and sum (\Delta recall) * prec 53 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 54 | return ap 55 | 56 | import os 57 | import sys 58 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 59 | sys.path.append(os.path.join(BASE_DIR, '../../train')) 60 | from box_util import box3d_iou 61 | def get_iou(bb1, bb2): 62 | """ Compute IoU of two bounding boxes. 
63 | ** Define your bod IoU function HERE ** 64 | """ 65 | #pass 66 | iou3d, iou2d = box3d_iou(bb1, bb2) 67 | return iou3d 68 | 69 | def eval_det_cls(pred, gt, ovthresh=0.25, use_07_metric=False): 70 | """ Generic functions to compute precision/recall for object detection 71 | for a single class. 72 | Input: 73 | pred: map of {img_id: [(bbox, score)]} where bbox is numpy array 74 | gt: map of {img_id: [bbox]} 75 | ovthresh: scalar, iou threshold 76 | use_07_metric: bool, if True use VOC07 11 point method 77 | Output: 78 | rec: numpy array of length nd 79 | prec: numpy array of length nd 80 | ap: scalar, average precision 81 | """ 82 | 83 | # construct gt objects 84 | class_recs = {} # {img_id: {'bbox': bbox list, 'det': matched list}} 85 | npos = 0 86 | for img_id in gt.keys(): 87 | bbox = np.array(gt[img_id]) 88 | det = [False] * len(bbox) 89 | npos += len(bbox) 90 | class_recs[img_id] = {'bbox': bbox, 'det': det} 91 | # pad empty list to all other imgids 92 | for img_id in pred.keys(): 93 | if img_id not in gt: 94 | class_recs[img_id] = {'bbox': np.array([]), 'det': []} 95 | 96 | # construct dets 97 | image_ids = [] 98 | confidence = [] 99 | BB = [] 100 | for img_id in pred.keys(): 101 | for box,score in pred[img_id]: 102 | image_ids.append(img_id) 103 | confidence.append(score) 104 | BB.append(box) 105 | confidence = np.array(confidence) 106 | BB = np.array(BB) # (nd,4 or 8,3) 107 | 108 | # sort by confidence 109 | sorted_ind = np.argsort(-confidence) 110 | sorted_scores = np.sort(-confidence) 111 | BB = BB[sorted_ind, ...] 
112 | image_ids = [image_ids[x] for x in sorted_ind] 113 | 114 | # go down dets and mark TPs and FPs 115 | nd = len(image_ids) 116 | tp = np.zeros(nd) 117 | fp = np.zeros(nd) 118 | for d in range(nd): 119 | if d%100==0: print d 120 | R = class_recs[image_ids[d]] 121 | bb = BB[d,:].astype(float) 122 | ovmax = -np.inf 123 | BBGT = R['bbox'].astype(float) 124 | 125 | if BBGT.size > 0: 126 | # compute overlaps 127 | for j in range(BBGT.shape[0]): 128 | iou = get_iou(bb, BBGT[j,...]) 129 | if iou > ovmax: 130 | ovmax = iou 131 | jmax = j 132 | 133 | #print d, ovmax 134 | if ovmax > ovthresh: 135 | if not R['det'][jmax]: 136 | tp[d] = 1. 137 | R['det'][jmax] = 1 138 | else: 139 | fp[d] = 1. 140 | else: 141 | fp[d] = 1. 142 | 143 | # compute precision recall 144 | fp = np.cumsum(fp) 145 | tp = np.cumsum(tp) 146 | rec = tp / float(npos) 147 | print 'NPOS: ', npos 148 | # avoid divide by zero in case the first detection matches a difficult 149 | # ground truth 150 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 151 | ap = voc_ap(rec, prec, use_07_metric) 152 | 153 | return rec, prec, ap 154 | 155 | def eval_det(pred_all, gt_all, ovthresh=0.25, use_07_metric=False): 156 | """ Generic functions to compute precision/recall for object detection 157 | for multiple classes. 
158 | Input: 159 | pred_all: map of {img_id: [(classname, bbox, score)]} 160 | gt_all: map of {img_id: [(classname, bbox)]} 161 | ovthresh: scalar, iou threshold 162 | use_07_metric: bool, if true use VOC07 11 point method 163 | Output: 164 | rec: {classname: rec} 165 | prec: {classname: prec_all} 166 | ap: {classname: scalar} 167 | """ 168 | pred = {} # map {classname: pred} 169 | gt = {} # map {classname: gt} 170 | for img_id in pred_all.keys(): 171 | for classname, bbox, score in pred_all[img_id]: 172 | if classname not in pred: pred[classname] = {} 173 | if img_id not in pred[classname]: 174 | pred[classname][img_id] = [] 175 | if classname not in gt: gt[classname] = {} 176 | if img_id not in gt[classname]: 177 | gt[classname][img_id] = [] 178 | pred[classname][img_id].append((bbox,score)) 179 | for img_id in gt_all.keys(): 180 | for classname, bbox in gt_all[img_id]: 181 | if classname not in gt: gt[classname] = {} 182 | if img_id not in gt[classname]: 183 | gt[classname][img_id] = [] 184 | gt[classname][img_id].append(bbox) 185 | 186 | rec = {} 187 | prec = {} 188 | ap = {} 189 | for classname in gt.keys(): 190 | print 'Computing AP for class: ', classname 191 | rec[classname], prec[classname], ap[classname] = eval_det_cls(pred[classname], gt[classname], ovthresh, use_07_metric) 192 | print classname, ap[classname] 193 | 194 | return rec, prec, ap 195 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/evaluate.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | import numpy as np 3 | import argparse 4 | from PIL import Image 5 | import cv2 6 | import sys 7 | import os 8 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 9 | sys.path.append(os.path.join(BASE_DIR, '../sunrgbd_data')) 10 | from sunrgbd_data import sunrgbd_object 11 | from utils import rotz, compute_box_3d, load_zipped_pickle 12 | sys.path.append(os.path.join(BASE_DIR, 
'../../train')) 13 | from box_util import box3d_iou 14 | import roi_seg_box3d_dataset 15 | from roi_seg_box3d_dataset import rotate_pc_along_y, NUM_HEADING_BIN 16 | from eval_det import eval_det 17 | from compare_matlab_and_python_eval import get_gt_cls 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--data_path', default=None, help='data path for .pickle file, the one used for val in train.py [default: None]') 21 | parser.add_argument('--result_path', default=None, help='result path for .pickle file from test.py [default: None]') 22 | parser.add_argument('--from_rgb_detection', action='store_true', help='test from data file from rgb detection.') 23 | FLAGS = parser.parse_args() 24 | 25 | 26 | IMG_DIR = '/home/rqi/Data/mysunrgbd/training/image' 27 | TEST_DATASET = roi_seg_box3d_dataset.ROISegBoxDataset(npoints=2048, split='val', rotate_to_center=True, overwritten_data_path=FLAGS.data_path, from_rgb_detection=FLAGS.from_rgb_detection) 28 | dataset = sunrgbd_object('/home/rqi/Data/mysunrgbd', 'training') 29 | 30 | ps_list, segp_list, center_list, heading_cls_list, heading_res_list, size_cls_list, size_res_list, rot_angle_list, score_list = load_zipped_pickle(FLAGS.result_path) 31 | 32 | # For detection evaluation 33 | pred_all = {} 34 | gt_all = {} 35 | ovthresh = 0.25 36 | 37 | print len(segp_list), len(TEST_DATASET) 38 | raw_input() 39 | 40 | # Get GT boxes 41 | print 'Construct GT boxes...' 42 | classname_list = ['bed','table','sofa','chair','toilet','desk','dresser','night_stand','bookshelf','bathtub'] 43 | """ 44 | for i in range(len(TEST_DATASET)): 45 | img_id = TEST_DATASET.id_list[i] 46 | if img_id in gt_all: continue # All ready counted.. 
47 | gt_all[img_id] = [] 48 | 49 | objects = dataset.get_label_objects(img_id) 50 | calib = dataset.get_calibration(img_id) 51 | for obj in objects: 52 | if obj.classname not in classname_list: continue 53 | box3d_pts_2d, box3d_pts_3d = compute_box_3d(obj, calib) 54 | box3d_pts_3d = calib.project_upright_depth_to_upright_camera(box3d_pts_3d) 55 | box3d_pts_3d_flipped = np.copy(box3d_pts_3d) 56 | box3d_pts_3d_flipped[0:4,:] = box3d_pts_3d[4:,:] 57 | box3d_pts_3d_flipped[4:,:] = box3d_pts_3d[0:4,:] 58 | gt_all[img_id].append((obj.classname, box3d_pts_3d_flipped)) 59 | """ 60 | 61 | #gt_all2 = {} 62 | gt_cls = {} 63 | for classname in classname_list: 64 | gt_cls[classname] = get_gt_cls(classname) 65 | for img_id in gt_cls[classname]: 66 | if img_id not in gt_all: 67 | gt_all[img_id] = [] 68 | for box in gt_cls[classname][img_id]: 69 | gt_all[img_id].append((classname, box)) 70 | #print gt_all[1] 71 | #print gt_all2[1] 72 | raw_input() 73 | 74 | # Get PRED boxes 75 | print 'Construct PRED boxes...' 
76 | for i in range(len(TEST_DATASET)): 77 | img_id = TEST_DATASET.id_list[i] 78 | classname = TEST_DATASET.type_list[i] 79 | 80 | center = center_list[i].squeeze() 81 | ret = TEST_DATASET[i] 82 | if FLAGS.from_rgb_detection: 83 | rot_angle = ret[1] 84 | else: 85 | rot_angle = ret[7] 86 | 87 | # Get heading angle and size 88 | #print heading_cls_list[i], heading_res_list[i], size_cls_list[i], size_res_list[i] 89 | heading_angle = roi_seg_box3d_dataset.class2angle(heading_cls_list[i], heading_res_list[i], NUM_HEADING_BIN) 90 | box_size = roi_seg_box3d_dataset.class2size(size_cls_list[i], size_res_list[i]) 91 | corners_3d_pred = roi_seg_box3d_dataset.get_3d_box(box_size, heading_angle, center) 92 | corners_3d_pred = rotate_pc_along_y(corners_3d_pred, -rot_angle) 93 | 94 | if img_id not in pred_all: 95 | pred_all[img_id] = [] 96 | pred_all[img_id].append((classname, corners_3d_pred, score_list[i])) 97 | print pred_all[1] 98 | raw_input() 99 | 100 | import matplotlib.pyplot as plt 101 | import matplotlib as mpl 102 | mpl.rc('axes', linewidth=2) 103 | print 'Computing AP...' 
104 | rec, prec, ap = eval_det(pred_all, gt_all, ovthresh) 105 | for classname in ap.keys(): 106 | print '%015s: %f' % (classname, ap[classname]) 107 | plt.plot(rec[classname], prec[classname], lw=3) 108 | fig = plt.gcf() 109 | fig.subplots_adjust(bottom=0.25) 110 | plt.xlim([0.0, 1.0]) 111 | plt.ylim([0.0, 1.05]) 112 | plt.xlabel('Recall', fontsize=24) 113 | plt.ylabel('Precision', fontsize=24) 114 | plt.title(classname, fontsize=24) 115 | plt.show() 116 | raw_input() 117 | print 'mean AP: ', np.mean([ap[classname] for classname in ap]) 118 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/gt_boxes/bathtub_gt_boxes.dat: -------------------------------------------------------------------------------- 1 | -0.6028 1.6455 0.97657 1.107 1.25 1.9091 -0.32937 2.4475 -1.264 -0.74716 2 | -1.4303 0.84867 0.036112 0.67893 0.13266 1.513 -1.3338 1.6828 -1.4641 -1.1551 3 | -1.2318 0.39091 0.26818 0.81818 0.031818 1.6545 -1.4682 1.2273 -1.3561 -0.88636 4 | 1.3227 0.54545 0.040909 0.85455 0.23182 1.6636 1.5136 1.3545 -1.2 -0.67727 5 | -1.1409 1.2727 -0.39545 0.77273 0.42273 2 -0.32273 2.5 -1.2 -0.81364 6 | 0.47727 1.9818 -1.0409 1.6727 -1.1955 2.4182 0.32273 2.7273 -1.2 -0.76818 7 | 0.64091 1.3909 -0.75 1.1273 -0.89545 1.8727 0.49545 2.1364 -1.2 -0.77727 8 | 0.31364 1.9727 -0.80455 1.3818 -1.1409 2.0182 -0.022727 2.6091 -1.2 -0.84091 9 | 0.077273 1.5182 -1.0318 0.80909 -1.45 1.4636 -0.34091 2.1727 -1.2 -0.79545 10 | -2.2322 1.857 -1.0415 1.2167 -0.27131 2.6489 -1.4621 3.2892 -1.3627 -0.59874 11 | 1.5003 1.873 -0.62539 1.3116 -0.95636 2.5647 1.1693 3.1261 -1.3708 -0.57456 12 | 1.6421 0.41441 -0.22832 1.2906 0.20288 2.2111 2.0733 1.3349 -1.1325 -0.56243 13 | 0.15532 2.5669 -1.122 1.9543 -1.4617 2.6628 -0.1844 3.2753 -1.3842 -0.95377 14 | 0.67246 1.7836 -0.79325 1.5421 -0.91977 2.3098 0.54594 2.5513 -1.4806 -0.98738 15 | 0.99555 1.5295 -0.45505 1.7106 -0.41164 2.0583 1.039 1.8772 -1.3136 -1.0769 16 | -0.52519 1.8725 0.53909 
1.5869 0.71522 2.2431 -0.34906 2.5288 -0.98849 -0.68564 17 | 0.43278 1.8082 -0.92906 1.3072 -1.0288 1.5784 0.33306 2.0793 -1.0704 1.0313 18 | -1.9769 1.2149 -1.0728 0.75162 -0.32685 2.2075 -1.231 2.6708 -1.3688 -0.66097 19 | 1.0805 1.8314 -0.61295 1.2308 -0.94417 2.1648 0.74932 2.7654 -1.3313 -0.64375 20 | -2.718 2.5595 -1.8948 2.0325 -1.4059 2.7961 -2.2291 3.3231 -1.3988 -0.97059 21 | -0.20414 3.0689 -1.4741 1.9801 -2.0008 2.5945 -0.73086 3.6833 -1.4213 -0.9626 22 | 1.1033 2.7588 -0.38229 2.4939 -0.52668 3.304 0.95892 3.5688 -1.3629 -0.89375 23 | 1.6191 2.005 0.57142 1.153 -0.46222 2.4239 0.58546 3.276 -1.5116 -0.72326 24 | 0.89785 0.99508 0.23885 0.89966 0.0019348 2.5359 0.66093 2.6313 -1.2117 -0.84009 25 | -1.2796 1.3932 0.98284 1.1486 1.126 2.4724 -1.1365 2.7171 -1.2409 -0.50455 26 | 2.2276 0.9889 1.549 0.56928 0.71038 1.9254 1.3889 2.345 -1.2 -0.3 27 | -1.0566 2.6266 -0.44071 2.5273 -0.23119 3.8269 -0.84705 3.9262 -1.2188 -0.86322 28 | -0.23854 1.791 0.41843 1.1923 1.4063 2.2762 0.7493 2.8749 -1.1113 -0.74974 29 | -0.9759 2.6333 0.51329 2.5791 0.55022 3.593 -0.93898 3.6472 -1.1187 -0.68125 30 | -1.227 1.3677 0.10826 1.7668 -0.15126 2.6352 -1.4865 2.2362 -1.1901 -0.72728 31 | 1.4661 1.7646 0.90908 1.334 0.081296 2.4048 0.63832 2.8354 -1.2063 -0.85432 32 | 0.42139 2.6172 1.3473 1.9713 1.8122 2.6377 0.88629 3.2837 -1.3938 -0.91875 33 | -1.1788 1.5625 0.19375 1.5875 0.17875 2.4125 -1.1938 2.3875 -1.3188 -0.91875 34 | 1.0045 0.79091 0.54961 0.40597 -0.15303 1.2364 0.3019 1.6213 -1.0652 -0.38333 35 | 0.99138 0.83345 0.46818 0.86364 0.51569 1.8233 1.0389 1.7931 -0.72273 -0.3 36 | 0.26212 2.1364 -1.2306 1.4981 -1.5515 2.2488 -0.058865 2.8871 -1.4742 -1.047 37 | 0.80758 2.1273 -0.73984 1.8661 -0.87149 2.6463 0.67592 2.9074 -1.2 -0.65303 38 | 0.56113 1.4591 -0.82273 0.83636 -1.1437 1.5497 0.24015 2.1724 -1.4227 -1.0591 39 | 0.53997 1.6936 -0.50906 1.5042 -0.62766 2.1609 0.42137 2.3504 -1.3854 -0.91263 40 | -1.0864 0.81818 -0.80382 0.57096 -0.15732 1.3098 -0.43986 1.557 
-1.3864 -1.0136 41 | 0.64091 0.8 -0.62273 0.45455 -0.83182 1.2091 0.43182 1.5545 -1.4682 -1.0955 42 | 0.46818 1.1364 -0.77018 0.48987 -1.1409 1.2 0.097451 1.8465 -1.4318 -1.0227 43 | -0.97799 0.66974 -0.26818 0.5 0.0099138 1.6629 -0.69989 1.8327 -1.3682 -0.95909 44 | -0.97229 0.33172 -0.41893 -0.09964 0.49134 1.0681 -0.062018 1.4994 -1.4591 -0.97727 45 | 0.98636 0.21818 0.41364 0.18182 0.35 1.2545 0.92273 1.2909 -1.3773 -1.0136 46 | 0.80455 0.8 -0.37759 0.02624 -0.76269 0.6112 0.41945 1.385 -1.3864 -1.05 47 | -1.2005 0.47517 -0.44727 0.51727 -0.53182 1.9545 -1.285 1.9124 -1.3955 -1.0591 48 | 1.3136 0.25455 0.52273 0.28182 0.56818 1.7091 1.3591 1.6818 -1.3682 -0.95909 49 | 1.2562 0.044653 0.43565 -0.037665 0.31434 1.1715 1.1349 1.2538 -1.4591 -1.0591 50 | -0.95229 0.4912 0.41765 0.23551 0.53069 0.83096 -0.83926 1.0867 -1.4126 -0.93081 51 | 0.73182 0.2 -0.37727 0.29091 -0.32273 0.98182 0.78636 0.89091 -1.3672 -0.92172 52 | -0.93182 0.42727 -0.20455 0.20909 0.086364 1.1818 -0.64091 1.4 -1.4672 -1.0308 53 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/gt_boxes/bathtub_gt_imgids.txt: -------------------------------------------------------------------------------- 1 | 104 108 115 121 122 124 199 393 414 2054 2055 2056 2057 2058 2059 2127 2128 2134 2135 2137 2138 2139 2146 2150 2154 2155 2156 2157 2159 2165 2168 2175 2176 2615 2619 2761 2784 3972 3974 4007 4009 4043 4118 4170 4188 4213 4216 4239 4240 4242 4261 4279 2 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/gt_boxes/bed_gt_imgids.txt: -------------------------------------------------------------------------------- 1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 27 28 29 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 109 112 113 119 120 190 192 193 195 197 401 403 406 407 408 409 410 412 678 679 680 681 682 682 682 683 684 713 714 715 716 717 718 
719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 742 743 744 745 747 748 749 750 768 769 770 771 772 773 774 776 777 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 803 806 809 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 856 856 856 858 858 858 859 859 860 861 862 862 863 863 864 865 866 866 867 868 869 870 871 872 873 874 875 875 876 877 878 878 879 880 881 908 909 913 915 917 918 919 919 920 921 922 922 923 923 925 983 984 985 986 1048 1049 1052 1111 1112 1116 1131 1132 1133 1134 1135 1136 1138 1139 1140 1141 1142 1143 1144 1145 1151 1153 1154 1155 1156 1157 1158 1675 1676 1680 1693 1694 1694 1695 1696 1697 1697 1699 1700 1702 1704 1705 1708 1708 1888 1891 1892 1895 1896 1898 1926 1927 1928 1929 1931 1932 1934 1935 1936 1942 1961 1962 1963 1970 1971 1973 2060 2061 2061 2062 2063 2064 2065 2066 2067 2070 2071 2072 2073 2075 2076 2248 2249 2250 2251 2252 2253 2254 2255 2256 2258 2259 2260 2261 2262 2263 2264 2265 2266 2269 2270 2271 2272 2273 2274 2275 2276 2277 2278 2279 2281 2282 2284 2288 2289 2290 2291 2292 2293 2294 2296 2297 2298 2300 2300 2302 2303 2304 2305 2306 2307 2308 2312 2313 2314 2315 2316 2317 2318 2319 2320 2321 2322 2323 2324 2325 2326 2327 2328 2329 2330 2331 2332 2332 2333 2333 2334 2335 2336 2337 2338 2339 2339 2340 2340 2341 2343 2344 2344 2345 2346 2349 2350 2352 2353 2354 2355 2358 2359 2360 2362 2363 2364 2365 2366 2368 2369 2370 2371 2372 2373 2374 2375 2376 2377 2378 2379 2381 2382 2554 2598 2603 2620 2635 2636 2671 2672 2680 2707 2708 2713 2745 2746 2779 2781 3422 3423 3425 3426 3427 3428 3429 3432 3433 3439 3448 3450 3451 3454 3456 3457 3459 3460 3471 3479 3483 3484 3484 3485 3485 3488 3489 3489 3490 3490 3491 3498 3499 3500 3502 3538 3542 3546 3588 3589 3593 3598 3600 3603 3604 3605 3606 3606 3607 3608 3612 3613 3614 3615 3616 3617 3967 3968 3969 3970 3975 3981 4004 4005 4027 4028 4029 4038 4039 4047 4053 4054 4110 4111 4112 4113 4115 
4172 4174 4175 4177 4190 4190 4191 4191 4196 4197 4199 4201 4202 4205 4206 4207 4210 4228 4232 4234 4235 4255 4256 4257 4258 4259 4273 4274 4277 4277 4304 4321 4504 2 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/gt_boxes/bookshelf_gt_imgids.txt: -------------------------------------------------------------------------------- 1 | 132 137 139 139 140 141 146 148 149 156 157 158 158 201 208 210 256 276 292 293 294 300 300 332 384 509 509 526 535 542 543 544 544 545 546 546 547 548 563 570 570 573 573 574 574 574 575 575 576 578 578 581 582 583 583 584 584 585 587 587 597 597 599 615 615 637 637 637 638 639 639 639 650 657 657 657 658 671 676 981 987 988 990 1055 1059 1088 1091 1095 1105 1109 1110 1119 1124 1719 1747 1748 1749 1865 1866 1899 1900 1901 1901 1901 1901 1901 1902 1902 1902 1902 1904 1904 1904 1904 1905 1907 1907 1908 1926 1927 1943 1953 1955 1971 1972 1974 1977 1977 1978 1979 1979 2009 2012 2019 2020 2048 2051 2052 2053 2062 2064 2078 2081 2081 2098 2106 2108 2111 2125 2126 2126 2181 2257 2280 2283 2308 2309 2311 2316 2322 2325 2326 2336 2347 2348 2357 2358 2367 2367 2368 2382 2395 2395 2395 2396 2398 2398 2399 2430 2452 2458 2496 2498 2501 2568 2569 2572 2579 2580 2585 2586 2586 2590 2591 2594 2608 2609 2612 2621 2634 2744 3010 3011 3072 3149 3166 3166 3166 3168 3169 3169 3188 3194 3194 3195 3259 3260 3432 3435 3441 3442 3458 3459 3460 3493 3493 3494 3496 3504 3505 3506 3511 3610 3617 3618 3633 3633 3634 3977 3978 3982 3992 4003 4036 4037 4041 4041 4082 4083 4085 4086 4131 4131 4503 4504 4533 4619 4619 4621 4622 4623 4624 4626 4626 4626 4678 4678 4688 4700 4708 4716 4716 4739 4750 4753 4765 4803 4803 4804 4804 4816 4818 4846 4848 4850 4879 4879 4880 4881 4882 4920 4923 4924 4924 4925 5000 5000 5001 5001 5002 5002 5003 5003 2 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/gt_boxes/dresser_gt_imgids.txt: 
-------------------------------------------------------------------------------- 1 | 1 3 6 11 13 14 15 15 16 18 25 26 33 34 34 35 36 37 39 40 41 42 43 45 46 47 656 656 656 656 679 679 680 681 681 681 682 683 683 684 684 684 714 718 721 722 724 726 726 727 728 729 729 730 731 731 731 732 732 733 733 734 734 735 736 736 736 739 740 741 742 742 743 744 745 747 749 750 768 769 770 771 772 774 777 779 779 780 781 782 783 784 785 786 787 789 790 792 794 794 795 796 798 806 807 820 840 841 845 845 847 848 851 853 854 855 857 858 860 864 867 867 875 876 877 877 877 878 879 1116 1132 1138 1140 1893 1894 1926 1927 1929 1930 1933 1934 1939 1940 1944 1945 1962 2065 2071 2072 2074 2248 2249 2250 2265 2267 2268 2271 2277 2277 2284 2288 2294 2296 2299 2308 2309 2315 2317 2326 2327 2330 2348 2350 2351 2352 2355 2358 2361 2364 2366 2369 2370 2376 2378 2379 2379 2381 2382 2595 2596 2597 2598 2602 2606 2701 3539 3540 3546 3585 3586 3591 3592 3594 3595 3596 3599 3602 3605 3609 3612 3615 3616 4028 4029 4030 4113 4115 4206 4210 4305 4505 2 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/gt_boxes/night_stand_gt_imgids.txt: -------------------------------------------------------------------------------- 1 | 1 2 3 4 6 9 10 10 11 11 12 13 14 15 16 17 18 19 20 21 22 24 25 26 26 27 28 29 30 31 32 33 35 37 38 39 40 41 42 43 44 44 46 47 50 50 109 112 112 119 120 120 190 192 193 195 197 401 403 407 408 409 412 412 656 656 677 678 679 679 680 680 681 681 683 684 713 714 715 719 720 721 722 723 723 724 725 725 726 726 727 728 729 729 730 731 732 732 732 733 733 734 735 736 738 739 740 740 741 742 744 744 747 747 748 750 765 769 770 771 772 773 776 777 779 780 782 784 786 787 789 791 791 792 793 796 797 798 799 800 800 806 809 818 820 820 820 840 841 842 843 843 845 846 847 847 849 849 851 852 852 853 856 858 860 860 862 863 864 869 869 870 872 874 875 876 878 880 908 913 917 918 919 920 921 922 923 1116 1140 1141 1141 1153 1155 1156 
1157 1158 1158 1898 1970 2063 2066 2067 2075 2076 2255 2256 2259 2261 2261 2263 2264 2265 2270 2271 2272 2276 2277 2279 2280 2282 2296 2297 2303 2312 2312 2318 2320 2321 2329 2333 2335 2338 2341 2343 2349 2354 2360 2361 2362 2365 2366 2368 2369 2371 2378 2549 2672 2680 2707 2737 2742 2745 3498 3968 4035 4110 4113 4273 2 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/gt_boxes/sofa_gt_imgids.txt: -------------------------------------------------------------------------------- 1 | 19 20 57 58 61 63 64 65 67 67 68 69 69 80 83 84 87 91 132 133 134 135 136 137 137 137 138 141 143 145 146 148 150 150 151 152 153 153 154 155 156 207 208 210 211 213 214 258 259 265 265 266 267 268 270 270 272 272 275 275 275 275 276 276 277 278 278 278 279 279 282 282 282 283 283 284 284 284 285 285 285 285 285 286 286 286 289 289 297 298 305 306 307 308 315 315 315 317 317 317 322 325 326 328 328 328 330 331 332 334 336 336 360 361 370 370 370 370 450 453 454 457 462 462 462 492 498 498 498 516 542 545 561 563 576 577 601 602 610 627 647 751 752 754 756 758 760 761 761 762 762 763 763 763 763 764 764 765 765 766 766 767 778 778 804 804 805 805 806 808 808 810 810 811 813 815 816 825 826 828 828 829 829 829 830 831 832 882 883 884 885 886 887 887 888 888 888 889 889 890 890 891 892 893 899 900 903 904 907 907 910 911 912 942 943 944 967 968 968 969 970 970 971 975 976 1000 1001 1003 1003 1004 1004 1005 1011 1012 1013 1014 1021 1021 1022 1022 1023 1023 1024 1031 1032 1032 1034 1035 1035 1040 1040 1047 1065 1071 1079 1088 1088 1089 1089 1090 1090 1095 1095 1096 1096 1102 1117 1118 1118 1119 1119 1124 1130 1165 1166 1167 1550 1720 1757 1761 1764 1779 1779 1780 1790 1798 1800 1864 1877 1877 1878 1885 1920 1926 1927 1953 1954 1957 1958 1972 1982 1989 1989 2003 2004 2006 2010 2011 2011 2011 2014 2026 2027 2031 2040 2041 2042 2043 2044 2044 2044 2050 2051 2053 2098 2098 2099 2100 2101 2101 2102 2102 2103 2105 2106 2108 2109 2110 2115 
2116 2125 2126 2205 2256 2268 2330 2375 2384 2385 2386 2386 2387 2387 2388 2388 2389 2390 2392 2393 2394 2396 2398 2399 2400 2401 2402 2403 2404 2405 2406 2407 2409 2409 2410 2411 2411 2412 2412 2413 2413 2415 2416 2416 2417 2418 2424 2425 2425 2426 2427 2428 2429 2430 2431 2432 2433 2434 2434 2435 2435 2436 2437 2438 2440 2442 2443 2444 2445 2445 2446 2447 2450 2451 2452 2453 2455 2455 2457 2458 2459 2460 2461 2462 2492 2509 2569 2579 2580 2583 2584 2643 2644 2646 2648 2649 2650 2651 2659 2660 2665 2668 2669 2670 2730 2731 2732 2733 2750 2768 2769 2770 2786 2799 2800 2805 2805 2810 2889 2890 2893 2936 2937 2938 2954 2954 2954 2955 2956 2957 2957 2958 2970 2971 2972 2980 2980 2981 2983 2984 2990 2990 3003 3078 3152 3153 3160 3162 3163 3177 3189 3198 3199 3205 3206 3233 3235 3236 3270 3392 3394 3395 3408 3409 3415 3416 3417 3419 3420 3509 3510 3515 3517 3522 3526 3527 3528 3528 3529 3531 3534 3548 3549 3575 3575 3576 3583 3584 3734 3748 3754 3758 3901 3902 3903 3904 3924 3956 3957 3957 3958 3959 3978 3979 3979 3980 3990 3993 3994 4015 4022 4023 4057 4058 4060 4098 4099 4102 4103 4104 4105 4122 4123 4130 4143 4144 4146 4147 4166 4167 4244 4246 4302 4313 4313 4314 4314 4315 4324 4339 4375 4379 4382 4390 4392 4429 4430 4440 4484 4489 4490 4490 4496 4512 4513 4513 4514 4515 4516 4518 4518 4524 4525 4527 4530 4530 4535 4538 4549 4549 4550 4550 4551 4552 4552 4563 4704 4749 4755 4757 4758 4783 4817 4817 4845 4872 4875 4884 4894 4896 4898 4919 4924 4925 4929 4929 4930 4976 4991 4992 2 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/gt_boxes/toilet_gt_imgids.txt: -------------------------------------------------------------------------------- 1 | 103 105 106 107 111 114 115 121 123 124 199 389 391 393 394 414 415 536 926 934 935 937 938 939 948 949 950 958 959 960 961 962 963 966 993 1043 1044 1045 1160 1678 1733 1734 1840 1841 1886 1887 2054 2057 2058 2127 2128 2129 2130 2140 2141 2142 2144 2148 2149 2151 2152 
2156 2157 2159 2165 2166 2168 2169 2170 2171 2172 2173 2175 2176 2556 2558 2559 2560 2614 2616 2617 2618 2619 2747 2748 2870 2871 2872 2910 2913 2914 2915 2916 2917 2917 3174 3175 3218 3243 3564 3972 3973 4007 4032 4033 4044 4088 4090 4116 4118 4168 4169 4186 4187 4188 4211 4214 4215 4216 4217 4239 4240 4241 4260 4281 4294 4295 4351 4352 4458 4459 4472 4501 4502 4509 4546 4562 4580 4584 4584 4618 4648 4658 4702 4703 4868 4955 4985 4987 5005 5026 2 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/model_util_sunrgbd.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import math 4 | import sys 5 | import os 6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 7 | sys.path.append(BASE_DIR) 8 | from roi_seg_box3d_dataset import NUM_HEADING_BIN, NUM_SIZE_CLUSTER, compute_box3d_iou, class2type, type_mean_size 9 | mean_size_arr = np.zeros((NUM_SIZE_CLUSTER, 3)) 10 | for i in range(NUM_SIZE_CLUSTER): 11 | mean_size_arr[i,:] = type_mean_size[class2type[i]] 12 | 13 | def huber_loss(error, delta): 14 | abs_error = tf.abs(error) 15 | quadratic = tf.minimum(abs_error, delta) 16 | linear = (abs_error - quadratic) 17 | losses = 0.5 * quadratic**2 + delta * linear 18 | return tf.reduce_mean(losses) 19 | 20 | 21 | def get_box3d_corners_helper(centers, headings, sizes): 22 | """ TF layer. 
Input: (N,3), (N,), (N,3), Output: (N,8,3) """ 23 | print '-----', centers 24 | N = centers.get_shape()[0].value 25 | l = tf.slice(sizes, [0,0], [-1,1]) # (N,1) 26 | w = tf.slice(sizes, [0,1], [-1,1]) # (N,1) 27 | h = tf.slice(sizes, [0,2], [-1,1]) # (N,1) 28 | print l,w,h 29 | x_corners = tf.concat([l/2,l/2,-l/2,-l/2,l/2,l/2,-l/2,-l/2], axis=1) # (N,8) 30 | y_corners = tf.concat([h/2,h/2,h/2,h/2,-h/2,-h/2,-h/2,-h/2], axis=1) # (N,8) 31 | z_corners = tf.concat([w/2,-w/2,-w/2,w/2,w/2,-w/2,-w/2,w/2], axis=1) # (N,8) 32 | corners = tf.concat([tf.expand_dims(x_corners,1), tf.expand_dims(y_corners,1), tf.expand_dims(z_corners,1)], axis=1) # (N,3,8) 33 | print x_corners, y_corners, z_corners 34 | c = tf.cos(headings) 35 | s = tf.sin(headings) 36 | ones = tf.ones([N], dtype=tf.float32) 37 | zeros = tf.zeros([N], dtype=tf.float32) 38 | row1 = tf.stack([c,zeros,s], axis=1) # (N,3) 39 | row2 = tf.stack([zeros,ones,zeros], axis=1) 40 | row3 = tf.stack([-s,zeros,c], axis=1) 41 | R = tf.concat([tf.expand_dims(row1,1), tf.expand_dims(row2,1), tf.expand_dims(row3,1)], axis=1) # (N,3,3) 42 | print row1, row2, row3, R, N 43 | corners_3d = tf.matmul(R, corners) # (N,3,8) 44 | corners_3d += tf.tile(tf.expand_dims(centers,2), [1,1,8]) # (N,3,8) 45 | corners_3d = tf.transpose(corners_3d, perm=[0,2,1]) # (N,8,3) 46 | return corners_3d 47 | 48 | def get_box3d_corners(center, heading_residuals, size_residuals): 49 | """ TF layer. 
50 | Inputs: 51 | center: (B,3) 52 | heading_residuals: (B,NH) 53 | size_residuals: (B,NS,3) 54 | Outputs: 55 | box3d_corners: (B,NH,NS,8,3) tensor 56 | """ 57 | batch_size = center.get_shape()[0].value 58 | heading_bin_centers = tf.constant(np.arange(0,2*np.pi,2*np.pi/NUM_HEADING_BIN), dtype=tf.float32) # (NH,) 59 | headings = heading_residuals + tf.expand_dims(heading_bin_centers, 0) # (B,NH) 60 | 61 | mean_sizes = tf.expand_dims(tf.constant(mean_size_arr, dtype=tf.float32), 0) + size_residuals # (B,NS,1) 62 | sizes = mean_sizes + size_residuals # (B,NS,3) 63 | sizes = tf.tile(tf.expand_dims(sizes,1), [1,NUM_HEADING_BIN,1,1]) # (B,NH,NS,3) 64 | headings = tf.tile(tf.expand_dims(headings,-1), [1,1,NUM_SIZE_CLUSTER]) # (B,NH,NS) 65 | centers = tf.tile(tf.expand_dims(tf.expand_dims(center,1),1), [1,NUM_HEADING_BIN, NUM_SIZE_CLUSTER,1]) # (B,NH,NS,3) 66 | 67 | N = batch_size*NUM_HEADING_BIN*NUM_SIZE_CLUSTER 68 | corners_3d = get_box3d_corners_helper(tf.reshape(centers, [N,3]), tf.reshape(headings, [N]), tf.reshape(sizes, [N,3])) 69 | 70 | return tf.reshape(corners_3d, [batch_size, NUM_HEADING_BIN, NUM_SIZE_CLUSTER, 8, 3]) 71 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/train_util.py: -------------------------------------------------------------------------------- 1 | ''' Utils for training. 2 | 3 | Author: Charles R. 
Qi 4 | Date: October 2017 5 | ''' 6 | 7 | import numpy as np 8 | from roi_seg_box3d_dataset import NUM_CLASS 9 | # Stack samples dataset[idxs[start_idx:end_idx]] into numpy batch arrays (bsize = end_idx-start_idx). 10 | def get_batch(dataset, idxs, start_idx, end_idx, num_point, num_channel, from_rgb_detection=False): 11 | if from_rgb_detection: 12 | return get_batch_from_rgb_detection(dataset, idxs, start_idx, end_idx, num_point, num_channel) 13 | # GT-box samples: returns (data, seg label, center, heading class/residual, size class/residual, rot angle[, one-hot]) 14 | bsize = end_idx-start_idx 15 | batch_data = np.zeros((bsize, num_point, num_channel)) 16 | batch_label = np.zeros((bsize, num_point), dtype=np.int32) 17 | batch_center = np.zeros((bsize, 3)) 18 | batch_heading_class = np.zeros((bsize,), dtype=np.int32) 19 | batch_heading_residual = np.zeros((bsize,)) 20 | batch_size_class = np.zeros((bsize,), dtype=np.int32) 21 | batch_size_residual = np.zeros((bsize, 3)) 22 | batch_rot_angle = np.zeros((bsize,)) 23 | if dataset.one_hot: batch_one_hot_vec = np.zeros((bsize,NUM_CLASS)) # (bsize,NUM_CLASS) one-hot class vectors, only when dataset.one_hot 24 | for i in range(bsize): 25 | if dataset.one_hot: 26 | ps,seg,center,hclass,hres,sclass,sres,rotangle,onehotvec = dataset[idxs[i+start_idx]] 27 | batch_one_hot_vec[i] = onehotvec 28 | else: 29 | ps,seg,center,hclass,hres,sclass,sres,rotangle = dataset[idxs[i+start_idx]] 30 | batch_data[i,...]
= ps[:,0:num_channel] 31 | batch_label[i,:] = seg 32 | batch_center[i,:] = center 33 | batch_heading_class[i] = hclass 34 | batch_heading_residual[i] = hres 35 | batch_size_class[i] = sclass 36 | batch_size_residual[i] = sres 37 | batch_rot_angle[i] = rotangle 38 | if dataset.one_hot: 39 | return batch_data, batch_label, batch_center, batch_heading_class, batch_heading_residual, batch_size_class, batch_size_residual, batch_rot_angle, batch_one_hot_vec 40 | else: 41 | return batch_data, batch_label, batch_center, batch_heading_class, batch_heading_residual, batch_size_class, batch_size_residual, batch_rot_angle 42 | # RGB-detection samples carry no box labels: returns (data, rot angle, prob[, one-hot]) 43 | def get_batch_from_rgb_detection(dataset, idxs, start_idx, end_idx, num_point, num_channel): 44 | bsize = end_idx-start_idx 45 | batch_data = np.zeros((bsize, num_point, num_channel)) 46 | batch_rot_angle = np.zeros((bsize,)) 47 | batch_prob = np.zeros((bsize,)) 48 | if dataset.one_hot: batch_one_hot_vec = np.zeros((bsize,NUM_CLASS)) # (bsize,NUM_CLASS) one-hot class vectors, only when dataset.one_hot 49 | for i in range(bsize): 50 | if dataset.one_hot: 51 | ps,rotangle,prob,onehotvec = dataset[idxs[i+start_idx]] 52 | batch_one_hot_vec[i] = onehotvec 53 | else: 54 | ps,rotangle,prob = dataset[idxs[i+start_idx]] 55 | batch_data[i,...] = ps[:,0:num_channel] 56 | batch_rot_angle[i] = rotangle 57 | batch_prob[i] = prob 58 | if dataset.one_hot: 59 | return batch_data, batch_rot_angle, batch_prob, batch_one_hot_vec 60 | else: 61 | return batch_data, batch_rot_angle, batch_prob 62 | 63 | 64 | -------------------------------------------------------------------------------- /sunrgbd/sunrgbd_detection/viz_eval.py: -------------------------------------------------------------------------------- 1 | ''' Example usage: 2 | python viz_eval.py --data_path roi_seg_box3d_caronly_val_0911.pickle --result_path test_results_caronly_aug5x.pickle 3 | 4 | Take GT box2d, eval 3D box estimation accuracy. Also able to visualize 3D predictions.
5 | ''' 6 | import cPickle as pickle 7 | import numpy as np 8 | import argparse 9 | from PIL import Image 10 | import cv2 11 | import sys 12 | import os 13 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 14 | import roi_seg_box3d_dataset 15 | sys.path.append(os.path.join(BASE_DIR, '../sunrgbd_data')) 16 | from sunrgbd_data import sunrgbd_object 17 | from utils import load_zipped_pickle 18 | sys.path.append(os.path.join(BASE_DIR, '../../train')) 19 | sys.path.append(os.path.join(BASE_DIR, '../../mayavi')) 20 | from box_util import box3d_iou 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--data_path', default=None, help='data path for .pickle file, the one used for val in train.py [default: None]') 24 | parser.add_argument('--result_path', default=None, help='result path for .pickle file from test.py [default: None]') 25 | parser.add_argument('--viz', action='store_true', help='to visualize error result.') 26 | parser.add_argument('--from_rgb_detection', action='store_true', help='test from data file from rgb detection.') 27 | FLAGS = parser.parse_args() 28 | 29 | IMG_DIR = '/home/rqi/Data/mysunrgbd/training/image' 30 | TEST_DATASET = roi_seg_box3d_dataset.ROISegBoxDataset(npoints=2048, split='val', rotate_to_center=True, overwritten_data_path=FLAGS.data_path, from_rgb_detection=FLAGS.from_rgb_detection) 31 | dataset = sunrgbd_object('/home/rqi/Data/mysunrgbd', 'training') 32 | VISU = FLAGS.viz 33 | if VISU: 34 | import mayavi.mlab as mlab 35 | from viz_util import draw_lidar, draw_gt_boxes3d 36 | 37 | #with open(FLAGS.result_path, 'rb') as fp: 38 | # ps_list = pickle.load(fp) 39 | # segp_list = pickle.load(fp) 40 | # center_list = pickle.load(fp) 41 | # heading_cls_list = pickle.load(fp) 42 | # heading_res_list = pickle.load(fp) 43 | # size_cls_list = pickle.load(fp) 44 | # size_res_list = pickle.load(fp) 45 | # rot_angle_list = pickle.load(fp) 46 | # score_list = pickle.load(fp) 47 | ps_list, segp_list, center_list, heading_cls_list, 
heading_res_list, size_cls_list, size_res_list, rot_angle_list, score_list = load_zipped_pickle(FLAGS.result_path) 48 | 49 | total_cnt = 0 50 | correct_cnt = 0 51 | type_whitelist=['bed','table','sofa','chair','toilet','desk','dresser','night_stand','bookshelf','bathtub'] 52 | class_correct_cnt = {classname:0 for classname in type_whitelist} 53 | class_total_cnt = {classname:0 for classname in type_whitelist} 54 | for i in range(len(segp_list)): 55 | print " ---- %d/%d"%(i,len(segp_list)) 56 | img_id = TEST_DATASET.id_list[i] 57 | box2d = TEST_DATASET.box2d_list[i] 58 | classname = TEST_DATASET.type_list[i] 59 | 60 | objects = dataset.get_label_objects(img_id) 61 | target_obj = None 62 | for obj in objects: # **Assuming we use GT box2d for 3D box estimation evaluation** 63 | if np.sum(np.abs(obj.box2d-box2d))<1e-3: 64 | target_obj = obj 65 | break 66 | assert(target_obj is not None) 67 | 68 | box3d = TEST_DATASET.get_center_view_box3d(i) 69 | ps = ps_list[i] 70 | segp = segp_list[i].squeeze() 71 | center = center_list[i].squeeze() 72 | ret = TEST_DATASET[i] 73 | rot_angle = ret[7] 74 | 75 | # Get heading angle and size 76 | print heading_cls_list[i], heading_res_list[i], size_cls_list[i], size_res_list[i] 77 | heading_angle = roi_seg_box3d_dataset.class2angle(heading_cls_list[i], heading_res_list[i], 12) 78 | box_size = roi_seg_box3d_dataset.class2size(size_cls_list[i], size_res_list[i]) 79 | corners_3d_pred = roi_seg_box3d_dataset.get_3d_box(box_size, heading_angle, center) 80 | 81 | # NOTE: fix this, box3d (projected from upright_depth coord) has flipped ymin,ymax as that in corners_3d_pred 82 | box3d_new = np.copy(box3d) 83 | box3d_new[0:4,:] = box3d[4:,:] 84 | box3d_new[4:,:] = box3d[0:4,:] 85 | iou_3d, iou_2d = box3d_iou(corners_3d_pred, box3d_new) 86 | print corners_3d_pred 87 | print box3d_new 88 | print 'Ground/3D IoU: ', iou_2d, iou_3d 89 | correct = int(iou_3d >= 0.25) 90 | total_cnt += 1 91 | correct_cnt += correct 92 | class_total_cnt[classname] += 1 93 
| class_correct_cnt[classname] += correct 94 | 95 | if VISU: #and iou_3d<0.7: 96 | img_filename = os.path.join(IMG_DIR, '%06d.jpg'%(img_id)) 97 | img = cv2.imread(img_filename) 98 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 99 | cv2.rectangle(img, (int(box2d[0]),int(box2d[1])), (int(box2d[2]),int(box2d[3])), (0,255,0), 3) 100 | Image.fromarray(img).show() 101 | 102 | # Draw figures 103 | fig = mlab.figure(figure=None, bgcolor=(0.6,0.6,0.6), fgcolor=None, engine=None, size=(1000, 500)) 104 | mlab.points3d(0, 0, 0, color=(1,1,1), mode='sphere', scale_factor=0.2, figure=fig) 105 | mlab.points3d(ps[:,0], ps[:,1], ps[:,2], segp, mode='point', colormap='gnuplot', scale_factor=1, figure=fig) 106 | draw_gt_boxes3d([box3d], fig, color = (0,0,1), draw_text=False) 107 | draw_gt_boxes3d([corners_3d_pred], fig, color = (0,1,0), draw_text=False) 108 | mlab.points3d(center[0], center[1], center[2], color=(0,1,0), mode='sphere', scale_factor=0.4, figure=fig) 109 | mlab.orientation_axes() 110 | raw_input() 111 | 112 | print '-----------------------' 113 | print 'Total cnt: %d, acuracy: %f' % (total_cnt, correct_cnt/float(total_cnt)) 114 | for classname in type_whitelist: 115 | print 'Class: %s\tcnt: %d\taccuracy: %f' % (classname.ljust(15), class_total_cnt[classname], class_correct_cnt[classname]/float(class_total_cnt[classname])) 116 | 117 | -------------------------------------------------------------------------------- /train/box_util.py: -------------------------------------------------------------------------------- 1 | """ Helper functions for calculating 2D and 3D bounding box IoU. 2 | 3 | Collected by Charles R. Qi 4 | Date: September 2017 5 | """ 6 | from __future__ import print_function 7 | 8 | import numpy as np 9 | from scipy.spatial import ConvexHull 10 | 11 | def polygon_clip(subjectPolygon, clipPolygon): 12 | """ Clip a polygon with another polygon. 
13 | 14 | Ref: https://rosettacode.org/wiki/Sutherland-Hodgman_polygon_clipping#Python 15 | 16 | Args: 17 | subjectPolygon: a list of (x,y) 2d points, any polygon. 18 | clipPolygon: a list of (x,y) 2d points, has to be *convex* 19 | Note: 20 | **points have to be counter-clockwise ordered** 21 | 22 | Return: 23 | a list of (x,y) vertex point for the intersection polygon. 24 | """ 25 | def inside(p): 26 | return(cp2[0]-cp1[0])*(p[1]-cp1[1]) > (cp2[1]-cp1[1])*(p[0]-cp1[0]) 27 | 28 | def computeIntersection(): 29 | dc = [ cp1[0] - cp2[0], cp1[1] - cp2[1] ] 30 | dp = [ s[0] - e[0], s[1] - e[1] ] 31 | n1 = cp1[0] * cp2[1] - cp1[1] * cp2[0] 32 | n2 = s[0] * e[1] - s[1] * e[0] 33 | n3 = 1.0 / (dc[0] * dp[1] - dc[1] * dp[0]) 34 | return [(n1*dp[0] - n2*dc[0]) * n3, (n1*dp[1] - n2*dc[1]) * n3] 35 | 36 | outputList = subjectPolygon 37 | cp1 = clipPolygon[-1] 38 | 39 | for clipVertex in clipPolygon: 40 | cp2 = clipVertex 41 | inputList = outputList 42 | outputList = [] 43 | s = inputList[-1] 44 | 45 | for subjectVertex in inputList: 46 | e = subjectVertex 47 | if inside(e): 48 | if not inside(s): 49 | outputList.append(computeIntersection()) 50 | outputList.append(e) 51 | elif inside(s): 52 | outputList.append(computeIntersection()) 53 | s = e 54 | cp1 = cp2 55 | if len(outputList) == 0: 56 | return None 57 | return(outputList) 58 | 59 | def poly_area(x,y): 60 | """ Ref: http://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates """ 61 | return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1))) 62 | 63 | def convex_hull_intersection(p1, p2): 64 | """ Compute area of two convex hull's intersection area. 65 | p1,p2 are a list of (x,y) tuples of hull vertices. 
66 | return a list of (x,y) for the intersection and its volume 67 | """ 68 | inter_p = polygon_clip(p1,p2) 69 | if inter_p is not None: 70 | hull_inter = ConvexHull(inter_p) 71 | return inter_p, hull_inter.volume 72 | else: 73 | return None, 0.0 74 | 75 | def box3d_vol(corners): 76 | ''' corners: (8,3) no assumption on axis direction ''' 77 | a = np.sqrt(np.sum((corners[0,:] - corners[1,:])**2)) 78 | b = np.sqrt(np.sum((corners[1,:] - corners[2,:])**2)) 79 | c = np.sqrt(np.sum((corners[0,:] - corners[4,:])**2)) 80 | return a*b*c 81 | 82 | def is_clockwise(p): 83 | x = p[:,0] 84 | y = p[:,1] 85 | return np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1)) > 0 86 | 87 | def box3d_iou(corners1, corners2): 88 | ''' Compute 3D bounding box IoU. 89 | 90 | Input: 91 | corners1: numpy array (8,3), assume up direction is negative Y 92 | corners2: numpy array (8,3), assume up direction is negative Y 93 | Output: 94 | iou: 3D bounding box IoU 95 | iou_2d: bird's eye view 2D bounding box IoU 96 | 97 | todo (rqi): add more description on corner points' orders. 98 | ''' 99 | # corner points are in counter clockwise order 100 | rect1 = [(corners1[i,0], corners1[i,2]) for i in range(3,-1,-1)] 101 | rect2 = [(corners2[i,0], corners2[i,2]) for i in range(3,-1,-1)] 102 | area1 = poly_area(np.array(rect1)[:,0], np.array(rect1)[:,1]) 103 | area2 = poly_area(np.array(rect2)[:,0], np.array(rect2)[:,1]) 104 | inter, inter_area = convex_hull_intersection(rect1, rect2) 105 | iou_2d = inter_area/(area1+area2-inter_area) 106 | ymax = min(corners1[0,1], corners2[0,1]) 107 | ymin = max(corners1[4,1], corners2[4,1]) 108 | inter_vol = inter_area * max(0.0, ymax-ymin) 109 | vol1 = box3d_vol(corners1) 110 | vol2 = box3d_vol(corners2) 111 | iou = inter_vol / (vol1 + vol2 - inter_vol) 112 | return iou, iou_2d 113 | 114 | 115 | def get_iou(bb1, bb2): 116 | """ 117 | Calculate the Intersection over Union (IoU) of two 2D bounding boxes. 
118 | 119 | Parameters 120 | ---------- 121 | bb1 : dict 122 | Keys: {'x1', 'x2', 'y1', 'y2'} 123 | The (x1, y1) position is at the top left corner, 124 | the (x2, y2) position is at the bottom right corner 125 | bb2 : dict 126 | Keys: {'x1', 'x2', 'y1', 'y2'} 127 | The (x, y) position is at the top left corner, 128 | the (x2, y2) position is at the bottom right corner 129 | 130 | Returns 131 | ------- 132 | float 133 | in [0, 1] 134 | """ 135 | assert bb1['x1'] < bb1['x2'] 136 | assert bb1['y1'] < bb1['y2'] 137 | assert bb2['x1'] < bb2['x2'] 138 | assert bb2['y1'] < bb2['y2'] 139 | 140 | # determine the coordinates of the intersection rectangle 141 | x_left = max(bb1['x1'], bb2['x1']) 142 | y_top = max(bb1['y1'], bb2['y1']) 143 | x_right = min(bb1['x2'], bb2['x2']) 144 | y_bottom = min(bb1['y2'], bb2['y2']) 145 | 146 | if x_right < x_left or y_bottom < y_top: 147 | return 0.0 148 | 149 | # The intersection of two axis-aligned bounding boxes is always an 150 | # axis-aligned bounding box 151 | intersection_area = (x_right - x_left) * (y_bottom - y_top) 152 | 153 | # compute the area of both AABBs 154 | bb1_area = (bb1['x2'] - bb1['x1']) * (bb1['y2'] - bb1['y1']) 155 | bb2_area = (bb2['x2'] - bb2['x1']) * (bb2['y2'] - bb2['y1']) 156 | 157 | # compute the intersection over union by taking the intersection 158 | # area and dividing it by the sum of prediction + ground-truth 159 | # areas - the interesection area 160 | iou = intersection_area / float(bb1_area + bb2_area - intersection_area) 161 | assert iou >= 0.0 162 | assert iou <= 1.0 163 | return iou 164 | 165 | def box2d_iou(box1, box2): 166 | ''' Compute 2D bounding box IoU. 
167 | 168 | Input: 169 | box1: tuple of (xmin,ymin,xmax,ymax) 170 | box2: tuple of (xmin,ymin,xmax,ymax) 171 | Output: 172 | iou: 2D IoU scalar 173 | ''' 174 | return get_iou({'x1':box1[0], 'y1':box1[1], 'x2':box1[2], 'y2':box1[3]}, \ 175 | {'x1':box2[0], 'y1':box2[1], 'x2':box2[2], 'y2':box2[3]}) 176 | 177 | 178 | if __name__=='__main__': 179 | 180 | # Function for polygon ploting 181 | import matplotlib 182 | from matplotlib.patches import Polygon 183 | from matplotlib.collections import PatchCollection 184 | import matplotlib.pyplot as plt 185 | def plot_polys(plist,scale=500.0): 186 | fig, ax = plt.subplots() 187 | patches = [] 188 | for p in plist: 189 | poly = Polygon(np.array(p)/scale, True) 190 | patches.append(poly) 191 | 192 | pc = PatchCollection(patches, cmap=matplotlib.cm.jet, alpha=0.5) 193 | colors = 100*np.random.rand(len(patches)) 194 | pc.set_array(np.array(colors)) 195 | ax.add_collection(pc) 196 | plt.show() 197 | 198 | # Demo on ConvexHull 199 | points = np.random.rand(30, 2) # 30 random points in 2-D 200 | hull = ConvexHull(points) 201 | # **In 2D "volume" is is area, "area" is perimeter 202 | print(('Hull area: ', hull.volume)) 203 | for simplex in hull.simplices: 204 | print(simplex) 205 | 206 | # Demo on convex hull overlaps 207 | sub_poly = [(0,0),(300,0),(300,300),(0,300)] 208 | clip_poly = [(150,150),(300,300),(150,450),(0,300)] 209 | inter_poly = polygon_clip(sub_poly, clip_poly) 210 | print(poly_area(np.array(inter_poly)[:,0], np.array(inter_poly)[:,1])) 211 | 212 | # Test convex hull interaction function 213 | rect1 = [(50,0),(50,300),(300,300),(300,0)] 214 | rect2 = [(150,150),(300,300),(150,450),(0,300)] 215 | plot_polys([rect1, rect2]) 216 | inter, area = convex_hull_intersection(rect1, rect2) 217 | print((inter, area)) 218 | if inter is not None: 219 | print(poly_area(np.array(inter)[:,0], np.array(inter)[:,1])) 220 | 221 | print('------------------') 222 | rect1 = [(0.30026005199835404, 8.9408694211408424), \ 223 | 
(-1.1571105364358421, 9.4686676477075533), \ 224 | (0.1777082043006144, 13.154404877812102), \ 225 | (1.6350787927348105, 12.626606651245391)] 226 | rect1 = [rect1[0], rect1[3], rect1[2], rect1[1]] 227 | rect2 = [(0.23908745901608636, 8.8551095691132886), \ 228 | (-1.2771419487733995, 9.4269062966181956), \ 229 | (0.13138836963152717, 13.161896351296868), \ 230 | (1.647617777421013, 12.590099623791961)] 231 | rect2 = [rect2[0], rect2[3], rect2[2], rect2[1]] 232 | plot_polys([rect1, rect2]) 233 | inter, area = convex_hull_intersection(rect1, rect2) 234 | print((inter, area)) 235 | -------------------------------------------------------------------------------- /train/kitti_eval/README.md: -------------------------------------------------------------------------------- 1 | Reference: https://github.com/prclibo/kitti_eval 2 | 3 | # kitti_eval 4 | 5 | `evaluate_object_3d_offline.cpp`evaluates your KITTI detection locally on your own computer using your validation data selected from KITTI training dataset, with the following metrics: 6 | 7 | - overlap on image (AP) 8 | - oriented overlap on image (AOS) 9 | - overlap on ground-plane (AP) 10 | - overlap in 3D (AP) 11 | 12 | Compile `evaluate_object_3d_offline.cpp` with dependency of Boost and Linux `dirent.h` (You should already have it under most Linux). 13 | 14 | Run the evalutaion by: 15 | 16 | ./evaluate_object_3d_offline groundtruth_dir result_dir 17 | 18 | Note that you don't have to detect over all KITTI training data. The evaluator only evaluates samples whose result files exist. 19 | 20 | 21 | ### Updates 22 | 23 | - June, 2017: 24 | * Fixed the bug of detection box filtering based on min height according to KITTI's note on 25.04.2017. 
25 | -------------------------------------------------------------------------------- /train/kitti_eval/compile.sh: -------------------------------------------------------------------------------- 1 | #/bin/bash 2 | g++ -o evaluate_object_3d_offline evaluate_object_3d_offline.cpp 3 | -------------------------------------------------------------------------------- /train/kitti_eval/evaluate_object_3d_offline: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charlesq34/frustum-pointnets/2ffdd345e1fce4775ecb508d207e0ad465bcca80/train/kitti_eval/evaluate_object_3d_offline -------------------------------------------------------------------------------- /train/kitti_eval/mail.h: -------------------------------------------------------------------------------- 1 | #ifndef MAIL_H 2 | #define MAIL_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | class Mail { 9 | 10 | public: 11 | 12 | Mail (std::string email = "") { 13 | if (email.compare("")) { 14 | mail = popen("/usr/lib/sendmail -t -f noreply@cvlibs.net","w"); 15 | fprintf(mail,"To: %s\n", email.c_str()); 16 | fprintf(mail,"From: noreply@cvlibs.net\n"); 17 | fprintf(mail,"Subject: KITTI Evaluation Benchmark\n"); 18 | fprintf(mail,"\n\n"); 19 | } else { 20 | mail = 0; 21 | } 22 | } 23 | 24 | ~Mail() { 25 | if (mail) { 26 | pclose(mail); 27 | } 28 | } 29 | 30 | void msg (const char *format, ...) { 31 | va_list args; 32 | va_start(args,format); 33 | if (mail) { 34 | vfprintf(mail,format,args); 35 | fprintf(mail,"\n"); 36 | } 37 | vprintf(format,args); 38 | printf("\n"); 39 | va_end(args); 40 | } 41 | 42 | private: 43 | 44 | FILE *mail; 45 | 46 | }; 47 | 48 | #endif 49 | -------------------------------------------------------------------------------- /train/train_util.py: -------------------------------------------------------------------------------- 1 | ''' Util functions for training and evaluation. 2 | 3 | Author: Charles R. 
Qi 4 | Date: September 2017 5 | ''' 6 | 7 | import numpy as np 8 | 9 | def get_batch(dataset, idxs, start_idx, end_idx, 10 | num_point, num_channel, 11 | from_rgb_detection=False): 12 | ''' Prepare batch data for training/evaluation. 13 | batch size is determined by start_idx-end_idx 14 | 15 | Input: 16 | dataset: an instance of FrustumDataset class 17 | idxs: a list of data element indices 18 | start_idx: int scalar, start position in idxs 19 | end_idx: int scalar, end position in idxs 20 | num_point: int scalar 21 | num_channel: int scalar 22 | from_rgb_detection: bool 23 | Output: 24 | batched data and label 25 | ''' 26 | if from_rgb_detection: 27 | return get_batch_from_rgb_detection(dataset, idxs, start_idx, end_idx, 28 | num_point, num_channel) 29 | 30 | bsize = end_idx-start_idx 31 | batch_data = np.zeros((bsize, num_point, num_channel)) 32 | batch_label = np.zeros((bsize, num_point), dtype=np.int32) 33 | batch_center = np.zeros((bsize, 3)) 34 | batch_heading_class = np.zeros((bsize,), dtype=np.int32) 35 | batch_heading_residual = np.zeros((bsize,)) 36 | batch_size_class = np.zeros((bsize,), dtype=np.int32) 37 | batch_size_residual = np.zeros((bsize, 3)) 38 | batch_rot_angle = np.zeros((bsize,)) 39 | if dataset.one_hot: 40 | batch_one_hot_vec = np.zeros((bsize,3)) # for car,ped,cyc 41 | for i in range(bsize): 42 | if dataset.one_hot: 43 | ps,seg,center,hclass,hres,sclass,sres,rotangle,onehotvec = \ 44 | dataset[idxs[i+start_idx]] 45 | batch_one_hot_vec[i] = onehotvec 46 | else: 47 | ps,seg,center,hclass,hres,sclass,sres,rotangle = \ 48 | dataset[idxs[i+start_idx]] 49 | batch_data[i,...] 
= ps[:,0:num_channel] 50 | batch_label[i,:] = seg 51 | batch_center[i,:] = center 52 | batch_heading_class[i] = hclass 53 | batch_heading_residual[i] = hres 54 | batch_size_class[i] = sclass 55 | batch_size_residual[i] = sres 56 | batch_rot_angle[i] = rotangle 57 | if dataset.one_hot: 58 | return batch_data, batch_label, batch_center, \ 59 | batch_heading_class, batch_heading_residual, \ 60 | batch_size_class, batch_size_residual, \ 61 | batch_rot_angle, batch_one_hot_vec 62 | else: 63 | return batch_data, batch_label, batch_center, \ 64 | batch_heading_class, batch_heading_residual, \ 65 | batch_size_class, batch_size_residual, batch_rot_angle 66 | 67 | def get_batch_from_rgb_detection(dataset, idxs, start_idx, end_idx, 68 | num_point, num_channel): 69 | bsize = end_idx-start_idx 70 | batch_data = np.zeros((bsize, num_point, num_channel)) 71 | batch_rot_angle = np.zeros((bsize,)) 72 | batch_prob = np.zeros((bsize,)) 73 | if dataset.one_hot: 74 | batch_one_hot_vec = np.zeros((bsize,3)) # for car,ped,cyc 75 | for i in range(bsize): 76 | if dataset.one_hot: 77 | ps,rotangle,prob,onehotvec = dataset[idxs[i+start_idx]] 78 | batch_one_hot_vec[i] = onehotvec 79 | else: 80 | ps,rotangle,prob = dataset[idxs[i+start_idx]] 81 | batch_data[i,...] = ps[:,0:num_channel] 82 | batch_rot_angle[i] = rotangle 83 | batch_prob[i] = prob 84 | if dataset.one_hot: 85 | return batch_data, batch_rot_angle, batch_prob, batch_one_hot_vec 86 | else: 87 | return batch_data, batch_rot_angle, batch_prob 88 | 89 | 90 | --------------------------------------------------------------------------------