├── README.md
├── dataset_utils
│   ├── calibration_waymo.py
│   └── object3d.py
├── example_usage.py
└── waymo_pytorch_dataset.py

/README.md:
--------------------------------------------------------------------------------
 1 | 
 2 | # Waymo PyTorch Dataloader
 3 | 
 4 | PyTorch dataloader for object detection tasks on the Waymo Open Dataset
 5 | - Drops straight into existing KITTI training pipelines
 6 | - Provides KITTI-format calibration and label objects
 7 | - Uses only the top lidar and all 5 camera images
 8 | - Serial (single-worker) data loading only
 9 | Pull requests with fixes or improvements are welcome.
10 | 
11 | ## Installation
12 | 
13 | ```bash
14 | git clone Waymo-pytorch-dataloader
15 | git clone https://github.com/gdlg/simple-waymo-open-dataset-reader
16 | ```
17 | or clone the repository together with its subrepository:
18 | ```bash
19 | git clone Waymo-pytorch-dataloader --recursive
20 | ```
21 | 
22 | Directly use the dataloader in your script like:
23 | ```python
24 | DATA_PATH = '/home/jupyter/waymo-od/waymo_dataset'
25 | LOCATIONS = ['location_sf']
26 | 
27 | dataset = WaymoDataset(DATA_PATH, LOCATIONS, 'train', True, "new_waymo")
28 | 
29 | frame, idx = dataset.data, dataset.count
30 | calib = dataset.get_calib(frame, idx)
31 | pts = dataset.get_lidar(frame, idx)
32 | target = dataset.get_label(frame, idx)
33 | ```
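34 | 
35 | The calibration and label objects follow KITTI conventions, so existing KITTI tooling keeps working. For example (a sketch — it assumes `get_label` returns a list of the KITTI-style `Object3d` instances defined in `dataset_utils/object3d.py`):
36 | 
37 | ```python
38 | # project the top-lidar points into the front camera
39 | pts_img, pts_depth = calib.lidar_to_img(pts[:, :3])
40 | 
41 | # dump the targets as KITTI label lines
42 | for obj in target:
43 |     print(obj.to_kitti_format())
44 | ```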
45 | 
46 | ## License
47 | 
48 | This code is released under the Apache License, version 2.0. This project incorporates some parts of the [Waymo Open Dataset code](https://github.com/waymo-research/waymo-open-dataset/blob/master/README.md) (the files `simple_waymo_open_dataset_reader/*.proto`), which are licensed to you under their original license terms. See the `LICENSE` file for details.
49 | 
--------------------------------------------------------------------------------
/dataset_utils/calibration_waymo.py:
--------------------------------------------------------------------------------
  1 | import numpy as np
  2 | import os
  3 | import json
  4 | 
  5 | def get_calib_from_file(frame, frame_num):
  6 | 
  7 |     waymo_cam_RT = np.array([0,-1,0,0, 0,0,-1,0, 1,0,0,0, 0,0,0,1]).reshape(4,4)
  8 |     camera_calib = []
  9 |     R0_rect = ["%e" % i for i in np.eye(3).flatten()]
 10 |     Tr_velo_to_cam = []
 11 |     calib_context = ''
 12 | 
 13 |     for camera in frame.context.camera_calibrations:
 14 |         tmp = np.array(camera.extrinsic.transform).reshape(4,4)
 15 |         tmp = np.linalg.inv(tmp).reshape((16,))
 16 |         Tr_velo_to_cam.append(["%e" % i for i in tmp])
 17 | 
 18 |     for cam in frame.context.camera_calibrations:
 19 |         tmp = np.zeros((3,4))
 20 |         tmp[0,0] = cam.intrinsic[0]
 21 |         tmp[1,1] = cam.intrinsic[1]
 22 |         tmp[0,2] = cam.intrinsic[2]
 23 |         tmp[1,2] = cam.intrinsic[3]
 24 |         tmp[2,2] = 1
 25 |         tmp = (tmp @ waymo_cam_RT)
 26 |         tmp = list(tmp.reshape(12))
 27 |         tmp = ["%e" % i for i in tmp]
 28 |         camera_calib.append(tmp)
 29 | 
 30 |     for i in range(5):
 31 |         calib_context += "P" + str(i) + ": " + " ".join(camera_calib[i]) + '\n'
 32 |     calib_context += "R0_rect" + ": " + " ".join(R0_rect) + '\n'
 33 |     for i in range(5):
 34 |         calib_context += "Tr_velo_to_cam_" + str(i) + ": " + " ".join(Tr_velo_to_cam[i]) + '\n'
 35 | 
 36 |     lines = calib_context.split('\n')
 37 |     obj = lines[2].strip().split(' ')[1:]   # P2
 38 |     P2 = np.array(obj, dtype=np.float32)
 39 |     obj = lines[3].strip().split(' ')[1:]   # P3
 40 |     P3 = np.array(obj, dtype=np.float32)
 41 |     obj = lines[5].strip().split(' ')[1:]   # R0_rect (identity for Waymo)
 42 |     R0 = np.array(obj, dtype=np.float32)
 43 |     obj = lines[6].strip().split(' ')[1:]   # Tr_velo_to_cam_0 (Waymo front camera)
 44 |     Tr_velo_to_cam = np.array(obj, dtype=np.float32).reshape(4,4)[:3, :4]  # Recheck this
 45 |     return {'P2': P2.reshape(3, 4),
 46 |             'P3': P3.reshape(3, 4),
 47 |             'R0': R0.reshape(3, 3),
 48 |             'Tr_velo2cam': Tr_velo_to_cam.reshape(3, 4)}
 49 | 
 50 | class Calibration(object):
 51 |     def __init__(self, frame, frame_num=None):
 52 |         if frame_num is not None:
 53 |             calib = get_calib_from_file(frame, frame_num)
 54 |         else:
 55 |             calib = frame  # `frame` is assumed to already be a parsed calib dict with P2/R0/Tr_velo2cam
 56 |         self.P2 = calib['P2']  # 3 x 4
 57 |         self.R0 = calib['R0']  # 3 x 3
 58 |         self.V2C = calib['Tr_velo2cam']  # 3 x 4
 59 | 
 60 |         # Camera intrinsics and extrinsics
 61 |         self.cu = self.P2[0, 2]
 62 |         self.cv = self.P2[1, 2]
 63 |         self.fu = self.P2[0, 0]
 64 |         self.fv = self.P2[1, 1]
 65 |         self.tx = self.P2[0, 3] / (-self.fu)
 66 |         self.ty = self.P2[1, 3] / (-self.fv)
 67 | 
 68 |     def cart_to_hom(self, pts):
 69 |         """
 70 |         :param pts: (N, 3 or 2)
 71 |         :return pts_hom: (N, 4 or 3)
 72 |         """
 73 |         pts_hom = np.hstack((pts, np.ones((pts.shape[0], 1), dtype=np.float32)))
 74 |         return pts_hom
 75 | 
 76 |     def lidar_to_rect(self, pts_lidar):
 77 |         """
 78 |         :param pts_lidar: (N, 3)
 79 |         :return pts_rect: (N, 3)
 80 |         """
 81 |         pts_lidar_hom = self.cart_to_hom(pts_lidar)
 82 |         pts_rect = np.dot(pts_lidar_hom, np.dot(self.V2C.T, self.R0.T))
 83 |         # pts_rect = reduce(np.dot, (pts_lidar_hom, self.V2C.T, self.R0.T))
 84 |         return pts_rect
 85 | 
 86 |     def rect_to_img(self, pts_rect):
 87 |         """
 88 |         :param pts_rect: (N, 3)
 89 |         :return pts_img: (N, 2)
 90 |         """
 91 |         pts_rect_hom = self.cart_to_hom(pts_rect)
 92 |         pts_2d_hom = np.dot(pts_rect_hom, self.P2.T)
 93 |         pts_img = (pts_2d_hom[:, 0:2].T / pts_rect_hom[:, 2]).T  # (N, 2)
 94 |         pts_rect_depth = pts_2d_hom[:, 2] - self.P2.T[3, 2]  # depth in rect camera coord
 95 |         return pts_img, pts_rect_depth
 96 | 
 97 |     def lidar_to_img(self, pts_lidar):
 98 |         """
 99 |         :param pts_lidar: (N, 3)
100 |         :return pts_img: (N, 2)
101 |         """
102 |         pts_rect = self.lidar_to_rect(pts_lidar)
103 |         pts_img, pts_depth = self.rect_to_img(pts_rect)
104 |         return pts_img, pts_depth
105 | 
106 |     def img_to_rect(self, u, v, depth_rect):
107 |         """
108 |         :param u: (N)
109 |         :param v: (N)
110 |         :param depth_rect: (N)
111 |         :return:
112 |         """
113 |         x = ((u - self.cu) * depth_rect) / self.fu + self.tx
114 |         y = ((v - self.cv) * depth_rect) / self.fv + self.ty
115 |         pts_rect = np.concatenate((x.reshape(-1, 1), y.reshape(-1, 1), depth_rect.reshape(-1, 1)), axis=1)
116 |         return pts_rect
117 | 
118 |     def depthmap_to_rect(self, depth_map):
119 |         """
120 |         :param depth_map: (H, W), depth_map
121 |         :return:
122 |         """
123 |         x_range = np.arange(0, depth_map.shape[1])
124 |         y_range = np.arange(0, depth_map.shape[0])
125 |         x_idxs, y_idxs = np.meshgrid(x_range, y_range)
126 |         x_idxs, y_idxs = x_idxs.reshape(-1), y_idxs.reshape(-1)
127 |         depth = depth_map[y_idxs, x_idxs]
128 |         pts_rect = self.img_to_rect(x_idxs, y_idxs, depth)
129 |         return pts_rect, x_idxs, y_idxs
130 | 
131 |     def corners3d_to_img_boxes(self, corners3d):
132 |         """
133 |         :param corners3d: (N, 8, 3) corners in rect coordinate
134 |         :return: boxes: (N, 4) [x1, y1, x2, y2] in rgb coordinate
135 |         :return: boxes_corner: (N, 8, 2) [xi, yi] in rgb coordinate
136 |         """
137 |         sample_num = corners3d.shape[0]
138 |         corners3d_hom = np.concatenate((corners3d, np.ones((sample_num, 8, 1))), axis=2)  # (N, 8, 4)
139 | 
140 |         img_pts = np.matmul(corners3d_hom, self.P2.T)  # (N, 8, 3)
141 | 
142 |         x, y = img_pts[:, :, 0] / img_pts[:, :, 2], img_pts[:, :, 1] / img_pts[:, :, 2]
143 |         x1, y1 = np.min(x, axis=1), np.min(y, axis=1)
144 |         x2, y2 = np.max(x, axis=1), np.max(y, axis=1)
145 | 
146 |         boxes = np.concatenate((x1.reshape(-1, 1), y1.reshape(-1, 1), x2.reshape(-1, 1), y2.reshape(-1, 1)), axis=1)
147 |         boxes_corner = np.concatenate((x.reshape(-1, 8, 1), y.reshape(-1, 8, 1)), axis=2)
148 | 
149 |         return boxes, boxes_corner
150 | 
151 |     def camera_dis_to_rect(self, u, v, d):
152 |         """
153 |         Only valid u, v, d can be processed: u, v must lie inside the image bounds; reprojection error ~0.02
154 |         :param u: (N)
155 |         :param v: (N)
156 |         :param d: (N), the distance between camera and 3d points, d^2 = x^2 + y^2 + z^2
157 |         :return:
158 |         """
159 |         assert self.fu == self.fv, '%.8f != %.8f' % (self.fu, self.fv)
160 |         fd = np.sqrt((u - self.cu)**2 + (v - self.cv)**2 + self.fu**2)
161 |         x = ((u - self.cu) * d) / fd + self.tx
162 |         y = ((v - self.cv) * d) / fd + self.ty
163 |         z = np.sqrt(d**2 - x**2 - y**2)
164 |         pts_rect = np.concatenate((x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)), axis=1)
165 |         return pts_rect
166 | 
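167 | 
168 | if __name__ == '__main__':
169 |     # Quick self-check sketch: the calib dict below is hand-made (illustrative
170 |     # values, not real Waymo intrinsics) and is passed straight to Calibration,
171 |     # which accepts a pre-parsed dict when frame_num is None.
172 |     demo_calib = {'P2': np.array([[700., 0., 620., 0.],
173 |                                   [0., 700., 480., 0.],
174 |                                   [0., 0., 1., 0.]], dtype=np.float32),
175 |                   'R0': np.eye(3, dtype=np.float32),
176 |                   'Tr_velo2cam': np.array([[0., -1., 0., 0.],
177 |                                            [0., 0., -1., 0.],
178 |                                            [1., 0., 0., 0.]], dtype=np.float32)}
179 |     calib = Calibration(demo_calib)
180 |     pts = np.random.rand(5, 3) * 20.0 + np.array([5.0, 0.0, 0.0])  # lidar points in front of the sensor
181 |     uv, depth = calib.lidar_to_img(pts)
182 |     print(uv.shape, depth.shape)  # expect (5, 2) (5,)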
--------------------------------------------------------------------------------
/dataset_utils/object3d.py:
--------------------------------------------------------------------------------
  1 | import numpy as np
  2 | 
  3 | 
  4 | def cls_type_to_id(cls_type):
  5 |     type_to_id = {'VEHICLE': 1, 'PEDESTRIAN': 2, 'CYCLIST': 3, 'UNKNOWN': 4}
  6 |     if cls_type not in type_to_id.keys():
  7 |         return -1
  8 |     return type_to_id[cls_type]
  9 | 
 10 | 
 11 | class Object3d(object):
 12 |     def __init__(self, frame=None, frame_num=None):
 13 |         if frame:
 14 |             line = frame
 15 |             label = line.strip().split(' ')
 16 |             self.src = line
 17 |             self.cls_type = label[0]
 18 |             self.cls_id = cls_type_to_id(self.cls_type)
 19 |             self.trucation = float(label[1])
 20 |             self.occlusion = float(label[2])  # 0:fully visible 1:partly occluded 2:largely occluded 3:unknown
 21 |             self.alpha = float(label[3])
 22 |             self.box2d = np.array((float(label[4]), float(label[5]), float(label[6]), float(label[7])), dtype=np.float32)
 23 |             self.h = float(label[8])
 24 |             self.w = float(label[9])
 25 |             self.l = float(label[10])
 26 |             self.pos = np.array((float(label[11]), float(label[12]), float(label[13])), dtype=np.float32)
 27 |             self.dis_to_cam = np.linalg.norm(self.pos)
 28 |             self.ry = float(label[14])
 29 |             self.score = float(label[15]) if label.__len__() == 16 else -1.0
 30 |             self.level_str = None
 31 |             self.level = self.get_obj_level()
 32 |         else:
 33 |             self.src = None
 34 |             self.cls_type = None
 35 |             self.cls_id = None
 36 |             self.trucation = None
 37 |             self.occlusion = None
 38 |             self.alpha = None
 39 |             self.box2d = None
 40 |             self.h = None
 41 |             self.w = None
 42 |             self.l = None
 43 |             self.pos = None
 44 |             self.dis_to_cam = None
 45 |             self.ry = None
 46 |             self.score = None
 47 |             self.level_str = None
 48 |             self.level = None
 49 | 
 50 |     def get_obj_level(self):
 51 |         height = float(self.box2d[3]) - float(self.box2d[1]) + 1
 52 | 
 53 |         if height >= 40 and self.trucation <= 0.15 and self.occlusion <= 0:
 54 |             self.level_str = 'Easy'
 55 |             return 1  # Easy
 56 |         elif height >= 25 and self.trucation <= 0.3 and self.occlusion <= 1:
 57 |             self.level_str = 'Moderate'
 58 |             return 2  # Moderate
 59 |         elif height >= 25 and self.trucation <= 0.5 and self.occlusion <= 2:
 60 |             self.level_str = 'Hard'
 61 |             return 3  # Hard
 62 |         else:
 63 |             self.level_str = 'UnKnown'
 64 |             return 4
 65 | 
 66 |     def generate_corners3d(self):
 67 |         """
 68 |         generate corners3d representation for this object
 69 |         :return corners_3d: (8, 3) corners of box3d in camera coord
 70 |         """
 71 |         l, h, w = self.l, self.h, self.w
 72 |         x_corners = [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2]
 73 |         y_corners = [0, 0, 0, 0, -h, -h, -h, -h]
 74 |         z_corners = [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2]
 75 | 
 76 |         R = np.array([[np.cos(self.ry), 0, np.sin(self.ry)],
 77 |                       [0, 1, 0],
 78 |                       [-np.sin(self.ry), 0, np.cos(self.ry)]])
 79 |         corners3d = np.vstack([x_corners, y_corners, z_corners])  # (3, 8)
 80 |         corners3d = np.dot(R, corners3d).T
 81 |         corners3d = corners3d + self.pos
 82 |         return corners3d
 83 | 
 84 |     def to_bev_box2d(self, oblique=True, voxel_size=0.1):
 85 |         """
 86 |         :param oblique:
 87 |         :param voxel_size: float, 0.1m
 88 |         :return: box2d (4, 2) / (4) in bev image coordinate, bev shape (h, w) => (y_max, x_max)
 89 |         NOTE: Object3d.MIN_XZ and Object3d.BEV_SHAPE are not defined in this file and must be assigned by the caller first
 90 |         """
 91 |         if oblique:
 92 |             corners3d = self.generate_corners3d()
 93 |             xz_corners = corners3d[0:4, [0, 2]]
 94 |             box2d = np.zeros((4, 2), dtype=np.int32)
 95 |             box2d[:, 0] = ((xz_corners[:, 0] - Object3d.MIN_XZ[0]) / voxel_size).astype(np.int32)
 96 |             box2d[:, 1] = Object3d.BEV_SHAPE[0] - 1 - ((xz_corners[:, 1] - Object3d.MIN_XZ[1]) / voxel_size).astype(np.int32)
 97 |             box2d[:, 0] = np.clip(box2d[:, 0], 0, Object3d.BEV_SHAPE[1])
 98 |             box2d[:, 1] = np.clip(box2d[:, 1], 0, Object3d.BEV_SHAPE[0])
 99 |         else:
100 |             box2d = np.zeros(4, dtype=np.int32)
101 |             # discrete_center = np.floor((self.pos / voxel_size)).astype(np.int32)
102 |             cu = np.floor((self.pos[0] - Object3d.MIN_XZ[0]) / voxel_size).astype(np.int32)
103 |             cv = Object3d.BEV_SHAPE[0] - 1 - ((self.pos[2] - Object3d.MIN_XZ[1]) / voxel_size).astype(np.int32)
104 |             half_l, half_w = int(self.l / voxel_size / 2), int(self.w / voxel_size / 2)
105 |             box2d[0], box2d[1] = cu - half_l, cv - half_w
106 |             box2d[2], box2d[3] = cu + half_l, cv + half_w
107 | 
108 |         return box2d
109 | 
110 |     def to_str(self):
111 |         print_str = '%s %.3f %.3f %.3f box2d: %s hwl: [%.3f %.3f %.3f] pos: %s ry: %.3f' \
112 |                     % (self.cls_type, self.trucation, self.occlusion, self.alpha, self.box2d, self.h, self.w, self.l,
113 |                        self.pos, self.ry)
114 |         return print_str
115 | 
116 |     def to_kitti_format(self):
117 |         kitti_str = '%s %.2f %d %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f' \
118 |                     % (self.cls_type, self.trucation, int(self.occlusion), self.alpha, self.box2d[0], self.box2d[1],
119 |                        self.box2d[2], self.box2d[3], self.h, self.w, self.l, self.pos[0], self.pos[1], self.pos[2],
120 |                        self.ry)
121 |         return kitti_str
122 | 
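123 | 
124 | if __name__ == '__main__':
125 |     # Quick self-check sketch: parse a hand-made KITTI-style label line
126 |     # (the values below are illustrative only) and generate its 3D box corners.
127 |     line = 'VEHICLE 0.00 0 -1.57 100.0 120.0 200.0 250.0 1.6 1.8 4.2 2.0 1.5 15.0 -1.57'
128 |     obj = Object3d(line)
129 |     print(obj.to_kitti_format())
130 |     print(obj.generate_corners3d().shape)  # expect (8, 3)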
--------------------------------------------------------------------------------
/example_usage.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python
  2 | '''
  3 | Example usage of the Waymo PyTorch dataloader
  4 | Inspired by the "PointRCNN" RCNN dataloader
  5 | '''
  6 | 
  7 | import numpy as np
  8 | import os
  9 | import pickle
 10 | import torch
 11 | 
 12 | from torch.utils.data import Dataset
 13 | from waymo_pytorch_dataset import WaymoDataset
 14 | import lib.utils.kitti_utils as kitti_utils
 15 | import lib.utils.roipool3d.roipool3d_utils as roipool3d_utils
 16 | from lib.config import cfg
 17 | 
 18 | 
 19 | class WaymoRCNNDataset(Dataset):
 20 |     def __init__(self, root_dir, npoints=16384, split='train', classes='VEHICLE', mode='TRAIN', random_select=True,
 21 |                  logger=None, rcnn_training_roi_dir=None, rcnn_training_feature_dir=None, rcnn_eval_roi_dir=None,
 22 |                  rcnn_eval_feature_dir=None, gt_database_dir=None):
 23 |         self.dataset = WaymoDataset(root_dir=root_dir, locations=None, split=split)  # `locations` is currently unused by WaymoDataset
 24 | 
 25 |         if classes == 'VEHICLE':
 26 |             self.classes = ('Background', 'VEHICLE')
 27 |             aug_scene_root_dir = os.path.join(root_dir, 'KITTI', 'aug_scene')
 28 |         elif classes == 'PEOPLE':
 29 |             self.classes = ('Background', 'PEDESTRIAN', 'CYCLIST'); aug_scene_root_dir = os.path.join(root_dir, 'KITTI', 'aug_scene')  # aug path for PEOPLE is an assumed default
 30 |         elif classes == 'PEDESTRIAN':
 31 |             self.classes = ('Background', 'PEDESTRIAN')
 32 |             aug_scene_root_dir = os.path.join(root_dir, 'KITTI', 'aug_scene_ped')
 33 |         elif classes == 'CYCLIST':
 34 |             self.classes = ('Background', 'CYCLIST')
 35 |             aug_scene_root_dir = os.path.join(root_dir, 'KITTI', 'aug_scene_cyclist')
 36 |         else:
 37 |             assert False, "Invalid classes: %s" % classes
 38 | 
 39 |         self.num_class = self.classes.__len__()
 40 |         self.npoints = npoints
 41 |         self.sample_id_list = []
 42 |         self.random_select = random_select
 43 |         self.logger = logger
 44 | 
 45 |         # the same aug-scene directories are currently used for every split, including 'train_aug'
 46 |         self.aug_label_dir = os.path.join(aug_scene_root_dir, 'training', 'aug_label')
 47 |         self.aug_pts_dir = os.path.join(aug_scene_root_dir, 'training', 'rectified_data')
 48 | 
 49 | 
 50 | 
 51 | 
 52 |         # for rcnn training
 53 |         self.rcnn_training_bbox_list = []
 54 |         self.rpn_feature_list = {}
 55 |         self.pos_bbox_list = []
 56 |         self.neg_bbox_list = []
 57 |         self.far_neg_bbox_list = []
 58 |         self.rcnn_eval_roi_dir = rcnn_eval_roi_dir
 59 |         self.rcnn_eval_feature_dir = rcnn_eval_feature_dir
 60 |         self.rcnn_training_roi_dir = rcnn_training_roi_dir
 61 |         self.rcnn_training_feature_dir = rcnn_training_feature_dir
 62 | 
 63 |         self.gt_database = None
 64 | 
 65 |         if not self.random_select:
 66 |             self.logger.warning('random select is False')
 67 | 
 68 |         assert mode in ['TRAIN', 'EVAL', 'TEST'], 'Invalid mode: %s' % mode
 69 |         self.mode = mode
 70 | 
 71 |         if cfg.RPN.ENABLED:
 72 |             if gt_database_dir is not None:
 73 |                 self.gt_database = pickle.load(open(gt_database_dir, 'rb'))
 74 |                 if cfg.GT_AUG_HARD_RATIO > 0:
 75 |                     easy_list, hard_list = [], []
 76 |                     for k in range(self.gt_database.__len__()):
 77 |                         obj = self.gt_database[k]
 78 |                         if obj['points'].shape[0] > 100:
 79 |                             easy_list.append(obj)
 80 |                         else:
 81 |                             hard_list.append(obj)
 82 |                     self.gt_database = [easy_list, hard_list]
 83 |                     logger.info('Loading gt_database(easy(pt_num>100): %d, hard(pt_num<=100): %d) from %s'
 84 |                                 % (len(easy_list), len(hard_list), gt_database_dir))
 85 |                 else:
 86 |                     logger.info('Loading gt_database(%d) from %s' % (len(self.gt_database), gt_database_dir))
 87 | 
 88 |             if mode == 'TRAIN':
 89 |                 # self.preprocess_rpn_training_data()
 90 |                 pass
 91 |             else:
 92 |                 # self.sample_id_list = self.framenum_to_idx
 93 |                 num_samples = 190 * self.dataset.num_files  # approximate total frames (currently unused)
 94 |                 self.logger.info('Load testing samples from %s' % root_dir + " " + split)
 95 |                 self.logger.info('Done: total test samples %d' % len(self.sample_id_list))
 96 |         elif cfg.RCNN.ENABLED:
 97 |             for idx in range(0, self.num_sample):  # NOTE: num_sample / image_idx_list must come from a KITTI-style split; they are not set in this example
 98 |                 sample_id = int(self.image_idx_list[idx])
 99 |                 obj_list = self.filtrate_objects(self.get_label(sample_id))
100 |                 if len(obj_list) == 0:
101 |                     # logger.info('No gt classes: %06d' % sample_id)
102 |                     logger.info("No classes found")
103 |                     continue
104 |                 self.sample_id_list.append(sample_id)
105 | 
106 |             print('Done: filter %s results for rcnn training: %d / %d\n' %
107 |                   (self.mode, len(self.sample_id_list), len(self.image_idx_list)))
108 | 
109 |     @staticmethod
110 |     def get_rpn_features(rpn_feature_dir, idx):
111 |         rpn_feature_file = os.path.join(rpn_feature_dir, '%06d.npy' % idx)
112 |         rpn_xyz_file = os.path.join(rpn_feature_dir, '%06d_xyz.npy' % idx)
113 |         rpn_intensity_file = os.path.join(rpn_feature_dir, '%06d_intensity.npy' % idx)
114 |         if cfg.RCNN.USE_SEG_SCORE:
115 |             rpn_seg_file = os.path.join(rpn_feature_dir, '%06d_rawscore.npy' % idx)
116 | 
rpn_seg_score = np.load(rpn_seg_file).reshape(-1) 117 | rpn_seg_score = torch.sigmoid(torch.from_numpy(rpn_seg_score)).numpy() 118 | else: 119 | rpn_seg_file = os.path.join(rpn_feature_dir, '%06d_seg.npy' % idx) 120 | rpn_seg_score = np.load(rpn_seg_file).reshape(-1) 121 | return np.load(rpn_xyz_file), np.load(rpn_feature_file), np.load(rpn_intensity_file).reshape(-1), rpn_seg_score 122 | 123 | def filtrate_objects(self, obj_list): 124 | """ 125 | Discard objects which are not in self.classes (or its similar classes) 126 | :param obj_list: list 127 | :return: list 128 | """ 129 | type_whitelist = self.classes 130 | if self.mode == 'TRAIN' and cfg.INCLUDE_SIMILAR_TYPE: 131 | type_whitelist = list(self.classes) 132 | if 'VEHICLE' in self.classes: 133 | type_whitelist.append('VAN') 134 | if 'PEDESTRIAN' in self.classes: # or 'Cyclist' in self.classes: 135 | type_whitelist.append('CYCLIST') 136 | 137 | valid_obj_list = [] 138 | for obj in obj_list: 139 | if obj.cls_type not in type_whitelist: # rm Van, 20180928 140 | continue 141 | if self.mode == 'TRAIN' and cfg.PC_REDUCE_BY_RANGE and (self.check_pc_range(obj.pos) is False): 142 | continue 143 | valid_obj_list.append(obj) 144 | return valid_obj_list 145 | 146 | @staticmethod 147 | def filtrate_dc_objects(obj_list): 148 | valid_obj_list = [] 149 | for obj in obj_list: 150 | if obj.cls_type in ['DontCare']: 151 | continue 152 | valid_obj_list.append(obj) 153 | 154 | return valid_obj_list 155 | 156 | @staticmethod 157 | def check_pc_range(xyz): 158 | """ 159 | :param xyz: [x, y, z] 160 | :return: 161 | """ 162 | x_range, y_range, z_range = cfg.PC_AREA_SCOPE 163 | if (x_range[0] <= xyz[0] <= x_range[1]) and (y_range[0] <= xyz[1] <= y_range[1]) and \ 164 | (z_range[0] <= xyz[2] <= z_range[1]): 165 | return True 166 | return False 167 | 168 | @staticmethod 169 | def get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape): 170 | """ 171 | Valid point should be in the image (and in the PC_AREA_SCOPE) 172 | :param pts_rect: 173 | :param pts_img: 174 | :param pts_rect_depth: 175 | :param img_shape: 176 | :return: 177 | """ 178 | val_flag_1 = np.logical_and(pts_img[:, 0] >= 0, pts_img[:, 0] < img_shape[1]) 179 | val_flag_2 = np.logical_and(pts_img[:, 1] >= 0, pts_img[:, 1] < img_shape[0]) 180 | val_flag_merge = np.logical_and(val_flag_1, val_flag_2) 181 | pts_valid_flag = np.logical_and(val_flag_merge, pts_rect_depth >= 0) 182 | 183 | if cfg.PC_REDUCE_BY_RANGE: 184 | x_range, y_range, z_range = cfg.PC_AREA_SCOPE 185 | pts_x, pts_y, pts_z = pts_rect[:, 0], pts_rect[:, 1], pts_rect[:, 2] 186 | range_flag = (pts_x >= x_range[0]) & (pts_x <= x_range[1]) \ 187 | & (pts_y >= y_range[0]) & (pts_y <= y_range[1]) \ 188 | & (pts_z >= z_range[0]) & (pts_z <= z_range[1]) 189 | pts_valid_flag = pts_valid_flag & range_flag 190 | return pts_valid_flag 191 | 192 | def __len__(self): 193 | if cfg.RPN.ENABLED: 194 | return self.dataset.total_frames#190 * self.dataset.num_files 195 | elif cfg.RCNN.ENABLED: 196 | if self.mode == 'TRAIN': 197 | return len(self.sample_id_list) 198 | else: 199 | return len(self.image_idx_list) 200 | else: 201 | raise NotImplementedError 202 | 203 | def __getitem__(self, index): 204 | if cfg.RPN.ENABLED: 205 | return self.get_rpn_sample(index) 206 | elif cfg.RCNN.ENABLED: 207 | if self.mode == 'TRAIN': 208 | if cfg.RCNN.ROI_SAMPLE_JIT: 209 | return self.get_rcnn_sample_jit(index) 210 | else: 211 | return self.get_rcnn_training_sample_batch(index) 212 | else: 213 | return self.get_proposal_from_file(index) 214 | else: 215 | raise 
NotImplementedError
216 | 
217 |     def get_rpn_sample(self, index):
218 | 
219 |         frame, idx = self.dataset.data, self.dataset.count
220 |         calib = self.dataset.get_calib(frame, idx)
221 |         pts_lidar = self.dataset.get_lidar(frame, idx)
222 |         curr_frame_labels = self.dataset.get_label(frame, idx)
223 |         index = idx
224 |         sample_id = idx
225 | 
226 |         pts_rect = pts_lidar[:, 0:3]
227 |         if pts_lidar.shape[1] > 3:  # intensity channel present
228 |             pts_intensity = pts_lidar[:, 3]
229 |         else:
230 |             pts_intensity = np.ones(len(pts_lidar))
231 | 
232 |         if cfg.GT_AUG_ENABLED and self.mode == 'TRAIN':
233 |             # all labels for checking overlapping
234 |             all_gt_obj_list = self.filtrate_dc_objects(curr_frame_labels)
235 |             all_gt_boxes3d = kitti_utils.objs_to_boxes3d(all_gt_obj_list)
236 | 
237 |             gt_aug_flag = False
238 |             if np.random.rand() < cfg.GT_AUG_APPLY_PROB:
239 |                 # augment one scene
240 |                 gt_aug_flag, pts_rect, pts_intensity, extra_gt_boxes3d, extra_gt_obj_list = \
241 |                     self.apply_gt_aug_to_one_scene(sample_id, pts_rect, pts_intensity, all_gt_boxes3d)
242 | 
243 |         # generate inputs
244 |         if self.mode == 'TRAIN' or self.random_select:
245 |             if self.npoints < len(pts_rect):
246 |                 pts_depth = pts_rect[:, 2]
247 |                 pts_near_flag = pts_depth < 40.0
248 |                 far_idxs_choice = np.where(pts_near_flag == 0)[0]
249 |                 near_idxs = np.where(pts_near_flag == 1)[0]
250 |                 near_idxs_choice = np.random.choice(near_idxs, self.npoints - len(far_idxs_choice), replace=False)
251 | 
252 |                 choice = np.concatenate((near_idxs_choice, far_idxs_choice), axis=0) \
253 |                     if len(far_idxs_choice) > 0 else near_idxs_choice
254 |                 np.random.shuffle(choice)
255 |             else:
256 |                 choice = np.arange(0, len(pts_rect), dtype=np.int32)
257 |                 if self.npoints > len(pts_rect):
258 |                     extra_choice = np.random.choice(choice, self.npoints - len(pts_rect), replace=False)
259 |                     choice = np.concatenate((choice, extra_choice), axis=0)
260 |                 np.random.shuffle(choice)
261 | 
262 | 
263 |             ret_pts_rect = pts_rect[choice, :]
264 |             # Waymo
265 |             ret_pts_intensity = pts_intensity[choice]
266 |             # ret_pts_intensity = pts_intensity[choice] - np.mean(pts_intensity[choice])  # translate intensity to [-0.5, 0.5]
267 |             # ret_pts_intensity = -0.5 + (ret_pts_intensity - np.min(ret_pts_intensity))/(np.max(ret_pts_intensity) - np.min(ret_pts_intensity))*1
268 |         else:
269 |             ret_pts_rect = pts_rect
270 |             ret_pts_intensity = pts_intensity
271 |             # ret_pts_intensity = pts_intensity - 0.5
272 | 
273 |         pts_features = [ret_pts_intensity.reshape(-1, 1)]
274 |         ret_pts_features = np.concatenate(pts_features, axis=1) if pts_features.__len__() > 1 else pts_features[0]
275 | 
276 |         sample_info = {'sample_id': sample_id, 'random_select': self.random_select}
277 | 
278 |         if self.mode == 'TEST':
279 |             if cfg.RPN.USE_INTENSITY:
280 |                 pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1)  # (N, C)
281 |             else:
282 |                 pts_input = ret_pts_rect
283 |             sample_info['pts_input'] = pts_input
284 |             sample_info['pts_rect'] = ret_pts_rect
285 |             sample_info['pts_features'] = ret_pts_features
286 | 
287 |             sample_info['frame_ts'] = frame.timestamp_micros
288 |             return sample_info
289 | 
290 |         gt_obj_list = self.filtrate_objects(curr_frame_labels)
291 |         if cfg.GT_AUG_ENABLED and self.mode == 'TRAIN' and gt_aug_flag:
292 |             gt_obj_list.extend(extra_gt_obj_list)
293 |         gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
294 | 
295 |         gt_alpha = np.zeros((gt_obj_list.__len__()), dtype=np.float32)
296 |         for k, obj in enumerate(gt_obj_list):
297 |             gt_alpha[k] = obj.alpha
298 | 
299 |         # data augmentation
300 |         aug_pts_rect = ret_pts_rect.copy()
301 |         aug_gt_boxes3d = gt_boxes3d.copy()
302 |         if cfg.AUG_DATA and self.mode == 'TRAIN':
303 |             aug_pts_rect, aug_gt_boxes3d, aug_method = self.data_augmentation(aug_pts_rect, aug_gt_boxes3d, gt_alpha,
304 |                                                                               sample_id)
305 |             sample_info['aug_method'] = aug_method
306 | 
307 |         # prepare input
308 |         if cfg.RPN.USE_INTENSITY:
309 |             pts_input = np.concatenate((aug_pts_rect, ret_pts_features), axis=1)  # (N, C)
310 |         else:
311 |             pts_input = aug_pts_rect
312 | 
313 |         if cfg.RPN.FIXED:
314 |             sample_info['pts_input'] = pts_input
315 |             sample_info['pts_rect'] = aug_pts_rect
316 |             sample_info['pts_features'] = ret_pts_features
317 |             sample_info['gt_boxes3d'] = aug_gt_boxes3d
318 | 
319 |             sample_info['frame_ts'] = frame.timestamp_micros
320 |             return sample_info
321 | 
322 |         # generate training labels
323 |         rpn_cls_label, rpn_reg_label = self.generate_rpn_training_labels(aug_pts_rect, aug_gt_boxes3d)
324 |         sample_info['pts_input'] = pts_input
325 |         sample_info['pts_rect'] = aug_pts_rect
326 |         sample_info['pts_features'] = ret_pts_features
327 |         sample_info['rpn_cls_label'] = rpn_cls_label
328 |         sample_info['rpn_reg_label'] = rpn_reg_label
329 |         sample_info['gt_boxes3d'] = aug_gt_boxes3d
330 | 
331 |         sample_info['frame_ts'] = frame.timestamp_micros
332 |         return sample_info
333 | 
334 |     @staticmethod
335 |     def generate_rpn_training_labels(pts_rect, gt_boxes3d):
336 |         cls_label = np.zeros((pts_rect.shape[0]), dtype=np.int32)
337 |         reg_label = np.zeros((pts_rect.shape[0], 7), dtype=np.float32)  # dx, dy, dz, ry, h, w, l
338 |         gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, rotate=True)
339 |         extend_gt_boxes3d = kitti_utils.enlarge_box3d(gt_boxes3d, extra_width=0.2)
340 |         extend_gt_corners = kitti_utils.boxes3d_to_corners3d(extend_gt_boxes3d, rotate=True)
341 |         for k in range(gt_boxes3d.shape[0]):
342 |             box_corners = gt_corners[k]
343 |             fg_pt_flag = kitti_utils.in_hull(pts_rect, box_corners)
344 |             fg_pts_rect = pts_rect[fg_pt_flag]
345 |             cls_label[fg_pt_flag] = 1
346 | 
347 |             # enlarge the bbox3d, ignore nearby points
348 |             extend_box_corners = extend_gt_corners[k]
349 |             fg_enlarge_flag = kitti_utils.in_hull(pts_rect, extend_box_corners)
350 |             ignore_flag = np.logical_xor(fg_pt_flag, fg_enlarge_flag)
351 |             cls_label[ignore_flag] = -1
352 | 
353 |             # pixel offset of object center
354 |             center3d = gt_boxes3d[k][0:3].copy()  # (x, y, z)
355 |             center3d[1] -= gt_boxes3d[k][3] / 2
356 |             reg_label[fg_pt_flag, 0:3] = center3d - fg_pts_rect  # Now y is the true center of 3d box 20180928
357 | 
358 |             # size and angle encoding
359 |             reg_label[fg_pt_flag, 3] = gt_boxes3d[k][3]  # h
360 |             reg_label[fg_pt_flag, 4] = gt_boxes3d[k][4]  # w
361 |             reg_label[fg_pt_flag, 5] = gt_boxes3d[k][5]  # l
362 |             reg_label[fg_pt_flag, 6] = gt_boxes3d[k][6]  # ry
363 | 
364 |         return cls_label, reg_label
365 | 
366 |     def rotate_box3d_along_y(self, box3d, rot_angle):
367 |         old_x, old_z, ry = box3d[0], box3d[2], box3d[6]
368 |         old_beta = np.arctan2(old_z, old_x)
369 |         alpha = -np.sign(old_beta) * np.pi / 2 + old_beta + ry
370 | 
371 |         box3d = kitti_utils.rotate_pc_along_y(box3d.reshape(1, 7), rot_angle=rot_angle)[0]
372 |         new_x, new_z = box3d[0], box3d[2]
373 |         new_beta = np.arctan2(new_z, new_x)
374 |         box3d[6] = np.sign(new_beta) * np.pi / 2 + alpha - new_beta
375 |         return box3d
376 | 
377 |     def apply_gt_aug_to_one_scene(self, sample_id, pts_rect, pts_intensity, all_gt_boxes3d):
378 |         """
379 |         :param pts_rect: (N, 3)
380 |         :param all_gt_boxes3d: (M2, 7)
381 |         :return:
382 |         """
383 |         assert self.gt_database is not 
None 384 | # extra_gt_num = np.random.randint(10, 15) 385 | # try_times = 50 386 | if cfg.GT_AUG_RAND_NUM: 387 | extra_gt_num = np.random.randint(10, cfg.GT_EXTRA_NUM) 388 | else: 389 | extra_gt_num = cfg.GT_EXTRA_NUM 390 | try_times = 100 391 | cnt = 0 392 | cur_gt_boxes3d = all_gt_boxes3d.copy() 393 | cur_gt_boxes3d[:, 4] += 0.5 # TODO: consider different objects 394 | cur_gt_boxes3d[:, 5] += 0.5 # enlarge new added box to avoid too nearby boxes 395 | cur_gt_corners = kitti_utils.boxes3d_to_corners3d(cur_gt_boxes3d) 396 | 397 | extra_gt_obj_list = [] 398 | extra_gt_boxes3d_list = [] 399 | new_pts_list, new_pts_intensity_list = [], [] 400 | src_pts_flag = np.ones(pts_rect.shape[0], dtype=np.int32) 401 | 402 | road_plane = self.get_road_plane(sample_id) 403 | a, b, c, d = road_plane 404 | 405 | while try_times > 0: 406 | if cnt > extra_gt_num: 407 | break 408 | 409 | try_times -= 1 410 | if cfg.GT_AUG_HARD_RATIO > 0: 411 | p = np.random.rand() 412 | if p > cfg.GT_AUG_HARD_RATIO: 413 | # use easy sample 414 | rand_idx = np.random.randint(0, len(self.gt_database[0])) 415 | new_gt_dict = self.gt_database[0][rand_idx] 416 | else: 417 | # use hard sample 418 | rand_idx = np.random.randint(0, len(self.gt_database[1])) 419 | new_gt_dict = self.gt_database[1][rand_idx] 420 | else: 421 | rand_idx = np.random.randint(0, self.gt_database.__len__()) 422 | new_gt_dict = self.gt_database[rand_idx] 423 | 424 | new_gt_box3d = new_gt_dict['gt_box3d'].copy() 425 | new_gt_points = new_gt_dict['points'].copy() 426 | new_gt_intensity = new_gt_dict['intensity'].copy() 427 | new_gt_obj = new_gt_dict['obj'] 428 | center = new_gt_box3d[0:3] 429 | if cfg.PC_REDUCE_BY_RANGE and (self.check_pc_range(center) is False): 430 | continue 431 | 432 | if new_gt_points.__len__() < 5: # too few points 433 | continue 434 | 435 | # put it on the road plane 436 | cur_height = (-d - a * center[0] - c * center[2]) / b 437 | move_height = new_gt_box3d[1] - cur_height 438 | new_gt_box3d[1] -= move_height 439 | new_gt_points[:, 1] -= move_height 440 | new_gt_obj.pos[1] -= move_height 441 | 442 | new_enlarged_box3d = new_gt_box3d.copy() 443 | new_enlarged_box3d[4] += 0.5 444 | new_enlarged_box3d[5] += 0.5 # enlarge new added box to avoid too nearby boxes 445 | 446 | cnt += 1 447 | new_corners = kitti_utils.boxes3d_to_corners3d(new_enlarged_box3d.reshape(1, 7)) 448 | iou3d = kitti_utils.get_iou3d(new_corners, cur_gt_corners) 449 | valid_flag = iou3d.max() < 1e-8 450 | if not valid_flag: 451 | continue 452 | 453 | enlarged_box3d = new_gt_box3d.copy() 454 | enlarged_box3d[3] += 2 # remove the points above and below the object 455 | 456 | boxes_pts_mask_list = roipool3d_utils.pts_in_boxes3d_cpu( 457 | torch.from_numpy(pts_rect), torch.from_numpy(enlarged_box3d.reshape(1, 7))) 458 | pt_mask_flag = (boxes_pts_mask_list[0].numpy() == 1) 459 | src_pts_flag[pt_mask_flag] = 0 # remove the original points which are inside the new box 460 | 461 | new_pts_list.append(new_gt_points) 462 | new_pts_intensity_list.append(new_gt_intensity) 463 | cur_gt_boxes3d = np.concatenate((cur_gt_boxes3d, new_enlarged_box3d.reshape(1, 7)), axis=0) 464 | cur_gt_corners = np.concatenate((cur_gt_corners, new_corners), axis=0) 465 | extra_gt_boxes3d_list.append(new_gt_box3d.reshape(1, 7)) 466 | extra_gt_obj_list.append(new_gt_obj) 467 | 468 | if new_pts_list.__len__() == 0: 469 | return False, pts_rect, pts_intensity, None, None 470 | 471 | extra_gt_boxes3d = np.concatenate(extra_gt_boxes3d_list, axis=0) 472 | # remove original points and add new points 473 | pts_rect 
= pts_rect[src_pts_flag == 1] 474 | pts_intensity = pts_intensity[src_pts_flag == 1] 475 | new_pts_rect = np.concatenate(new_pts_list, axis=0) 476 | new_pts_intensity = np.concatenate(new_pts_intensity_list, axis=0) 477 | pts_rect = np.concatenate((pts_rect, new_pts_rect), axis=0) 478 | pts_intensity = np.concatenate((pts_intensity, new_pts_intensity), axis=0) 479 | 480 | return True, pts_rect, pts_intensity, extra_gt_boxes3d, extra_gt_obj_list 481 | 482 | def data_augmentation(self, aug_pts_rect, aug_gt_boxes3d, gt_alpha, sample_id=None, mustaug=False, stage=1): 483 | """ 484 | :param aug_pts_rect: (N, 3) 485 | :param aug_gt_boxes3d: (N, 7) 486 | :param gt_alpha: (N) 487 | :return: 488 | """ 489 | aug_list = cfg.AUG_METHOD_LIST 490 | aug_enable = 1 - np.random.rand(3) 491 | if mustaug is True: 492 | aug_enable[0] = -1 493 | aug_enable[1] = -1 494 | aug_method = [] 495 | if 'rotation' in aug_list and aug_enable[0] < cfg.AUG_METHOD_PROB[0]: 496 | angle = np.random.uniform(-np.pi / cfg.AUG_ROT_RANGE, np.pi / cfg.AUG_ROT_RANGE) 497 | aug_pts_rect = kitti_utils.rotate_pc_along_y(aug_pts_rect, rot_angle=angle) 498 | if stage == 1: 499 | # xyz change, hwl unchange 500 | aug_gt_boxes3d = kitti_utils.rotate_pc_along_y(aug_gt_boxes3d, rot_angle=angle) 501 | 502 | # calculate the ry after rotation 503 | x, z = aug_gt_boxes3d[:, 0], aug_gt_boxes3d[:, 2] 504 | beta = np.arctan2(z, x) 505 | new_ry = np.sign(beta) * np.pi / 2 + gt_alpha - beta 506 | aug_gt_boxes3d[:, 6] = new_ry # TODO: not in [-np.pi / 2, np.pi / 2] 507 | elif stage == 2: 508 | # for debug stage-2, this implementation has little float precision difference with the above one 509 | assert aug_gt_boxes3d.shape[0] == 2 510 | aug_gt_boxes3d[0] = self.rotate_box3d_along_y(aug_gt_boxes3d[0], angle) 511 | aug_gt_boxes3d[1] = self.rotate_box3d_along_y(aug_gt_boxes3d[1], angle) 512 | else: 513 | raise NotImplementedError 514 | 515 | aug_method.append(['rotation', angle]) 516 | 517 | if 'scaling' in aug_list and aug_enable[1] < cfg.AUG_METHOD_PROB[1]: 518 | scale = np.random.uniform(0.95, 1.05) 519 | aug_pts_rect = aug_pts_rect * scale 520 | aug_gt_boxes3d[:, 0:6] = aug_gt_boxes3d[:, 0:6] * scale 521 | aug_method.append(['scaling', scale]) 522 | 523 | if 'flip' in aug_list and aug_enable[2] < cfg.AUG_METHOD_PROB[2]: 524 | # flip horizontal 525 | aug_pts_rect[:, 0] = -aug_pts_rect[:, 0] 526 | aug_gt_boxes3d[:, 0] = -aug_gt_boxes3d[:, 0] 527 | # flip orientation: ry > 0: pi - ry, ry < 0: -pi - ry 528 | if stage == 1: 529 | aug_gt_boxes3d[:, 6] = np.sign(aug_gt_boxes3d[:, 6]) * np.pi - aug_gt_boxes3d[:, 6] 530 | elif stage == 2: 531 | assert aug_gt_boxes3d.shape[0] == 2 532 | aug_gt_boxes3d[0, 6] = np.sign(aug_gt_boxes3d[0, 6]) * np.pi - aug_gt_boxes3d[0, 6] 533 | aug_gt_boxes3d[1, 6] = np.sign(aug_gt_boxes3d[1, 6]) * np.pi - aug_gt_boxes3d[1, 6] 534 | else: 535 | raise NotImplementedError 536 | 537 | aug_method.append('flip') 538 | 539 | return aug_pts_rect, aug_gt_boxes3d, aug_method 540 | 541 | def get_rcnn_sample_info(self, roi_info): 542 | sample_id, gt_box3d = roi_info['sample_id'], roi_info['gt_box3d'] 543 | rpn_xyz, rpn_features, rpn_intensity, seg_mask = self.rpn_feature_list[sample_id] 544 | 545 | # augmentation original roi by adding noise 546 | roi_box3d = self.aug_roi_by_noise(roi_info) 547 | 548 | # point cloud pooling based on roi_box3d 549 | pooled_boxes3d = kitti_utils.enlarge_box3d(roi_box3d.reshape(1, 7), cfg.RCNN.POOL_EXTRA_WIDTH) 550 | 551 | boxes_pts_mask_list = roipool3d_utils.pts_in_boxes3d_cpu(torch.from_numpy(rpn_xyz), 552 
| torch.from_numpy(pooled_boxes3d)) 553 | pt_mask_flag = (boxes_pts_mask_list[0].numpy() == 1) 554 | cur_pts = rpn_xyz[pt_mask_flag].astype(np.float32) 555 | 556 | # data augmentation 557 | aug_pts = cur_pts.copy() 558 | aug_gt_box3d = gt_box3d.copy().astype(np.float32) 559 | aug_roi_box3d = roi_box3d.copy() 560 | if cfg.AUG_DATA and self.mode == 'TRAIN': 561 | # calculate alpha by ry 562 | temp_boxes3d = np.concatenate([aug_roi_box3d.reshape(1, 7), aug_gt_box3d.reshape(1, 7)], axis=0) 563 | temp_x, temp_z, temp_ry = temp_boxes3d[:, 0], temp_boxes3d[:, 2], temp_boxes3d[:, 6] 564 | temp_beta = np.arctan2(temp_z, temp_x).astype(np.float64) 565 | temp_alpha = -np.sign(temp_beta) * np.pi / 2 + temp_beta + temp_ry 566 | 567 | # data augmentation 568 | aug_pts, aug_boxes3d, aug_method = self.data_augmentation(aug_pts, temp_boxes3d, temp_alpha, mustaug=True, stage=2) 569 | aug_roi_box3d, aug_gt_box3d = aug_boxes3d[0], aug_boxes3d[1] 570 | aug_gt_box3d = aug_gt_box3d.astype(gt_box3d.dtype) 571 | 572 | # Pool input points 573 | valid_mask = 1 # whether the input is valid 574 | 575 | if aug_pts.shape[0] == 0: 576 | pts_features = np.zeros((1, 128), dtype=np.float32) 577 | input_channel = 3 + int(cfg.RCNN.USE_INTENSITY) + int(cfg.RCNN.USE_MASK) + int(cfg.RCNN.USE_DEPTH) 578 | pts_input = np.zeros((1, input_channel), dtype=np.float32) 579 | valid_mask = 0 580 | else: 581 | pts_features = rpn_features[pt_mask_flag].astype(np.float32) 582 | pts_intensity = rpn_intensity[pt_mask_flag].astype(np.float32) 583 | 584 | pts_input_list = [aug_pts, pts_intensity.reshape(-1, 1)] 585 | if cfg.RCNN.USE_INTENSITY: 586 | pts_input_list = [aug_pts, pts_intensity.reshape(-1, 1)] 587 | else: 588 | pts_input_list = [aug_pts] 589 | 590 | if cfg.RCNN.USE_MASK: 591 | if cfg.RCNN.MASK_TYPE == 'seg': 592 | pts_mask = seg_mask[pt_mask_flag].astype(np.float32) 593 | elif cfg.RCNN.MASK_TYPE == 'roi': 594 | pts_mask = roipool3d_utils.pts_in_boxes3d_cpu(torch.from_numpy(aug_pts), 595 | torch.from_numpy(aug_roi_box3d.reshape(1, 7))) 596 | pts_mask = (pts_mask[0].numpy() == 1).astype(np.float32) 597 | else: 598 | raise NotImplementedError 599 | 600 | pts_input_list.append(pts_mask.reshape(-1, 1)) 601 | 602 | if cfg.RCNN.USE_DEPTH: 603 | pts_depth = np.linalg.norm(aug_pts, axis=1, ord=2) 604 | pts_depth_norm = (pts_depth / 70.0) - 0.5 605 | pts_input_list.append(pts_depth_norm.reshape(-1, 1)) 606 | 607 | pts_input = np.concatenate(pts_input_list, axis=1) # (N, C) 608 | 609 | aug_gt_corners = kitti_utils.boxes3d_to_corners3d(aug_gt_box3d.reshape(-1, 7)) 610 | aug_roi_corners = kitti_utils.boxes3d_to_corners3d(aug_roi_box3d.reshape(-1, 7)) 611 | iou3d = kitti_utils.get_iou3d(aug_roi_corners, aug_gt_corners) 612 | cur_iou = iou3d[0][0] 613 | 614 | # regression valid mask 615 | reg_valid_mask = 1 if cur_iou >= cfg.RCNN.REG_FG_THRESH and valid_mask == 1 else 0 616 | 617 | # classification label 618 | cls_label = 1 if cur_iou > cfg.RCNN.CLS_FG_THRESH else 0 619 | if cfg.RCNN.CLS_BG_THRESH < cur_iou < cfg.RCNN.CLS_FG_THRESH or valid_mask == 0: 620 | cls_label = -1 621 | 622 | # canonical transform and sampling 623 | pts_input_ct, gt_box3d_ct = self.canonical_transform(pts_input, aug_roi_box3d, aug_gt_box3d) 624 | pts_input_ct, pts_features = self.rcnn_input_sample(pts_input_ct, pts_features) 625 | 626 | sample_info = {'sample_id': sample_id, 627 | 'pts_input': pts_input_ct, 628 | 'pts_features': pts_features, 629 | 'cls_label': cls_label, 630 | 'reg_valid_mask': reg_valid_mask, 631 | 'gt_boxes3d_ct': gt_box3d_ct, 632 | 'roi_boxes3d': 
aug_roi_box3d, 633 | 'roi_size': aug_roi_box3d[3:6], 634 | 'gt_boxes3d': aug_gt_box3d} 635 | 636 | return sample_info 637 | 638 | @staticmethod 639 | def canonical_transform(pts_input, roi_box3d, gt_box3d): 640 | roi_ry = roi_box3d[6] % (2 * np.pi) # 0 ~ 2pi 641 | roi_center = roi_box3d[0:3] 642 | # shift to center 643 | pts_input[:, [0, 1, 2]] = pts_input[:, [0, 1, 2]] - roi_center 644 | gt_box3d_ct = np.copy(gt_box3d) 645 | gt_box3d_ct[0:3] = gt_box3d_ct[0:3] - roi_center 646 | # rotate to the direction of head 647 | gt_box3d_ct = kitti_utils.rotate_pc_along_y(gt_box3d_ct.reshape(1, 7), roi_ry).reshape(7) 648 | gt_box3d_ct[6] = gt_box3d_ct[6] - roi_ry 649 | pts_input = kitti_utils.rotate_pc_along_y(pts_input, roi_ry) 650 | 651 | return pts_input, gt_box3d_ct 652 | 653 | @staticmethod 654 | def canonical_transform_batch(pts_input, roi_boxes3d, gt_boxes3d): 655 | """ 656 | :param pts_input: (N, npoints, 3 + C) 657 | :param roi_boxes3d: (N, 7) 658 | :param gt_boxes3d: (N, 7) 659 | :return: 660 | """ 661 | roi_ry = roi_boxes3d[:, 6] % (2 * np.pi) # 0 ~ 2pi 662 | roi_center = roi_boxes3d[:, 0:3] 663 | # shift to center 664 | pts_input[:, :, [0, 1, 2]] = pts_input[:, :, [0, 1, 2]] - roi_center.reshape(-1, 1, 3) 665 | gt_boxes3d_ct = np.copy(gt_boxes3d) 666 | gt_boxes3d_ct[:, 0:3] = gt_boxes3d_ct[:, 0:3] - roi_center 667 | # rotate to the direction of head 668 | gt_boxes3d_ct = kitti_utils.rotate_pc_along_y_torch(torch.from_numpy(gt_boxes3d_ct.reshape(-1, 1, 7)), 669 | torch.from_numpy(roi_ry)).numpy().reshape(-1, 7) 670 | gt_boxes3d_ct[:, 6] = gt_boxes3d_ct[:, 6] - roi_ry 671 | pts_input = kitti_utils.rotate_pc_along_y_torch(torch.from_numpy(pts_input), torch.from_numpy(roi_ry)).numpy() 672 | 673 | return pts_input, gt_boxes3d_ct 674 | 675 | @staticmethod 676 | def rcnn_input_sample(pts_input, pts_features): 677 | choice = np.random.choice(pts_input.shape[0], cfg.RCNN.NUM_POINTS, replace=True) 678 | 679 | if pts_input.shape[0] < cfg.RCNN.NUM_POINTS: 680 | choice[:pts_input.shape[0]] = np.arange(pts_input.shape[0]) 681 | np.random.shuffle(choice) 682 | pts_input = pts_input[choice] 683 | pts_features = pts_features[choice] 684 | 685 | return pts_input, pts_features 686 | 687 | def aug_roi_by_noise(self, roi_info): 688 | """ 689 | add noise to original roi to get aug_box3d 690 | :param roi_info: 691 | :return: 692 | """ 693 | roi_box3d, gt_box3d = roi_info['roi_box3d'], roi_info['gt_box3d'] 694 | original_iou = roi_info['iou3d'] 695 | temp_iou = cnt = 0 696 | pos_thresh = min(cfg.RCNN.REG_FG_THRESH, cfg.RCNN.CLS_FG_THRESH) 697 | gt_corners = kitti_utils.boxes3d_to_corners3d(gt_box3d.reshape(-1, 7)) 698 | aug_box3d = roi_box3d 699 | while temp_iou < pos_thresh and cnt < 10: 700 | if roi_info['type'] == 'gt': 701 | aug_box3d = self.random_aug_box3d(roi_box3d) # GT, must random 702 | else: 703 | if np.random.rand() < 0.2: 704 | aug_box3d = roi_box3d # p=0.2 to keep the original roi box 705 | else: 706 | aug_box3d = self.random_aug_box3d(roi_box3d) 707 | aug_corners = kitti_utils.boxes3d_to_corners3d(aug_box3d.reshape(-1, 7)) 708 | iou3d = kitti_utils.get_iou3d(aug_corners, gt_corners) 709 | temp_iou = iou3d[0][0] 710 | cnt += 1 711 | if original_iou < pos_thresh: # original bg, break 712 | break 713 | return aug_box3d 714 | 715 | @staticmethod 716 | def random_aug_box3d(box3d): 717 | """ 718 | :param box3d: (7) [x, y, z, h, w, l, ry] 719 | random shift, scale, orientation 720 | """ 721 | if cfg.RCNN.REG_AUG_METHOD == 'single': 722 | pos_shift = (np.random.rand(3) - 0.5) # [-0.5 ~ 0.5] 723 | 
hwl_scale = (np.random.rand(3) - 0.5) / (0.5 / 0.15) + 1.0 # 724 | angle_rot = (np.random.rand(1) - 0.5) / (0.5 / (np.pi / 12)) # [-pi/12 ~ pi/12] 725 | 726 | aug_box3d = np.concatenate([box3d[0:3] + pos_shift, box3d[3:6] * hwl_scale, 727 | box3d[6:7] + angle_rot]) 728 | return aug_box3d 729 | elif cfg.RCNN.REG_AUG_METHOD == 'multiple': 730 | # pos_range, hwl_range, angle_range, mean_iou 731 | range_config = [[0.2, 0.1, np.pi / 12, 0.7], 732 | [0.3, 0.15, np.pi / 12, 0.6], 733 | [0.5, 0.15, np.pi / 9, 0.5], 734 | [0.8, 0.15, np.pi / 6, 0.3], 735 | [1.0, 0.15, np.pi / 3, 0.2]] 736 | idx = np.random.randint(len(range_config)) 737 | 738 | pos_shift = ((np.random.rand(3) - 0.5) / 0.5) * range_config[idx][0] 739 | hwl_scale = ((np.random.rand(3) - 0.5) / 0.5) * range_config[idx][1] + 1.0 740 | angle_rot = ((np.random.rand(1) - 0.5) / 0.5) * range_config[idx][2] 741 | 742 | aug_box3d = np.concatenate([box3d[0:3] + pos_shift, box3d[3:6] * hwl_scale, box3d[6:7] + angle_rot]) 743 | return aug_box3d 744 | elif cfg.RCNN.REG_AUG_METHOD == 'normal': 745 | x_shift = np.random.normal(loc=0, scale=0.3) 746 | y_shift = np.random.normal(loc=0, scale=0.2) 747 | z_shift = np.random.normal(loc=0, scale=0.3) 748 | h_shift = np.random.normal(loc=0, scale=0.25) 749 | w_shift = np.random.normal(loc=0, scale=0.15) 750 | l_shift = np.random.normal(loc=0, scale=0.5) 751 | ry_shift = ((np.random.rand() - 0.5) / 0.5) * np.pi / 12 752 | 753 | aug_box3d = np.array([box3d[0] + x_shift, box3d[1] + y_shift, box3d[2] + z_shift, box3d[3] + h_shift, 754 | box3d[4] + w_shift, box3d[5] + l_shift, box3d[6] + ry_shift]) 755 | return aug_box3d 756 | else: 757 | raise NotImplementedError 758 | 759 | def get_proposal_from_file(self, index): 760 | sample_id = int(self.image_idx_list[index]) 761 | proposal_file = os.path.join(self.rcnn_eval_roi_dir, '%06d.txt' % sample_id) 762 | roi_obj_list = kitti_utils.get_objects_from_label(proposal_file) 763 | 764 | rpn_xyz, rpn_features, rpn_intensity, seg_mask = self.get_rpn_features(self.rcnn_eval_feature_dir, sample_id) 765 | pts_rect, pts_rpn_features, pts_intensity = rpn_xyz, rpn_features, rpn_intensity 766 | 767 | roi_box3d_list, roi_scores = [], [] 768 | for obj in roi_obj_list: 769 | box3d = np.array([obj.pos[0], obj.pos[1], obj.pos[2], obj.h, obj.w, obj.l, obj.ry], dtype=np.float32) 770 | roi_box3d_list.append(box3d.reshape(1, 7)) 771 | roi_scores.append(obj.score) 772 | 773 | roi_boxes3d = np.concatenate(roi_box3d_list, axis=0) # (N, 7) 774 | roi_scores = np.array(roi_scores, dtype=np.float32) # (N) 775 | 776 | if cfg.RCNN.ROI_SAMPLE_JIT: 777 | sample_dict = {'sample_id': sample_id, 778 | 'rpn_xyz': rpn_xyz, 779 | 'rpn_features': rpn_features, 780 | 'seg_mask': seg_mask, 781 | 'roi_boxes3d': roi_boxes3d, 782 | 'roi_scores': roi_scores, 783 | 'pts_depth': np.linalg.norm(rpn_xyz, ord=2, axis=1)} 784 | 785 | if self.mode != 'TEST': 786 | gt_obj_list = self.filtrate_objects(self.get_label(sample_id)) 787 | gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list) 788 | 789 | roi_corners = kitti_utils.boxes3d_to_corners3d(roi_boxes3d) 790 | gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d) 791 | iou3d = kitti_utils.get_iou3d(roi_corners, gt_corners) 792 | if gt_boxes3d.shape[0] > 0: 793 | gt_iou = iou3d.max(axis=1) 794 | else: 795 | gt_iou = np.zeros(roi_boxes3d.shape[0]).astype(np.float32) 796 | 797 | sample_dict['gt_boxes3d'] = gt_boxes3d 798 | sample_dict['gt_iou'] = gt_iou 799 | return sample_dict 800 | 801 | if cfg.RCNN.USE_INTENSITY: 802 | pts_extra_input_list = 
[pts_intensity.reshape(-1, 1), seg_mask.reshape(-1, 1)] 803 | else: 804 | pts_extra_input_list = [seg_mask.reshape(-1, 1)] 805 | 806 | if cfg.RCNN.USE_DEPTH: 807 | cur_depth = np.linalg.norm(pts_rect, axis=1, ord=2) 808 | cur_depth_norm = (cur_depth / 70.0) - 0.5 809 | pts_extra_input_list.append(cur_depth_norm.reshape(-1, 1)) 810 | 811 | pts_extra_input = np.concatenate(pts_extra_input_list, axis=1) 812 | pts_input, pts_features = roipool3d_utils.roipool3d_cpu(roi_boxes3d, pts_rect, pts_rpn_features, 813 | pts_extra_input, cfg.RCNN.POOL_EXTRA_WIDTH, 814 | sampled_pt_num=cfg.RCNN.NUM_POINTS) 815 | 816 | sample_dict = {'sample_id': sample_id, 817 | 'pts_input': pts_input, 818 | 'pts_features': pts_features, 819 | 'roi_boxes3d': roi_boxes3d, 820 | 'roi_scores': roi_scores, 821 | 'roi_size': roi_boxes3d[:, 3:6]} 822 | 823 | if self.mode == 'TEST': 824 | return sample_dict 825 | 826 | gt_obj_list = self.filtrate_objects(self.get_label(sample_id)) 827 | gt_boxes3d = np.zeros((gt_obj_list.__len__(), 7), dtype=np.float32) 828 | 829 | for k, obj in enumerate(gt_obj_list): 830 | gt_boxes3d[k, 0:3], gt_boxes3d[k, 3], gt_boxes3d[k, 4], gt_boxes3d[k, 5], gt_boxes3d[k, 6] \ 831 | = obj.pos, obj.h, obj.w, obj.l, obj.ry 832 | 833 | if gt_boxes3d.__len__() == 0: 834 | gt_iou = np.zeros((roi_boxes3d.shape[0]), dtype=np.float32) 835 | else: 836 | roi_corners = kitti_utils.boxes3d_to_corners3d(roi_boxes3d) 837 | gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d) 838 | iou3d = kitti_utils.get_iou3d(roi_corners, gt_corners) 839 | gt_iou = iou3d.max(axis=1) 840 | sample_dict['gt_boxes3d'] = gt_boxes3d 841 | sample_dict['gt_iou'] = gt_iou 842 | 843 | return sample_dict 844 | 845 | def get_rcnn_training_sample_batch(self, index): 846 | sample_id = int(self.sample_id_list[index]) 847 | rpn_xyz, rpn_features, rpn_intensity, seg_mask = \ 848 | self.get_rpn_features(self.rcnn_training_feature_dir, sample_id) 849 | 850 | # load rois and gt_boxes3d for this sample 851 | roi_file = os.path.join(self.rcnn_training_roi_dir, '%06d.txt' % sample_id) 852 | roi_obj_list = kitti_utils.get_objects_from_label(roi_file) 853 | roi_boxes3d = kitti_utils.objs_to_boxes3d(roi_obj_list) 854 | # roi_scores = kitti_utils.objs_to_scores(roi_obj_list) 855 | 856 | gt_obj_list = self.filtrate_objects(self.get_label(sample_id)) 857 | gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list) 858 | 859 | # calculate original iou 860 | iou3d = kitti_utils.get_iou3d(kitti_utils.boxes3d_to_corners3d(roi_boxes3d), 861 | kitti_utils.boxes3d_to_corners3d(gt_boxes3d)) 862 | max_overlaps, gt_assignment = iou3d.max(axis=1), iou3d.argmax(axis=1) 863 | max_iou_of_gt, roi_assignment = iou3d.max(axis=0), iou3d.argmax(axis=0) 864 | roi_assignment = roi_assignment[max_iou_of_gt > 0].reshape(-1) 865 | 866 | # sample fg, easy_bg, hard_bg 867 | fg_rois_per_image = int(np.round(cfg.RCNN.FG_RATIO * cfg.RCNN.ROI_PER_IMAGE)) 868 | fg_thresh = min(cfg.RCNN.REG_FG_THRESH, cfg.RCNN.CLS_FG_THRESH) 869 | fg_inds = np.nonzero(max_overlaps >= fg_thresh)[0] 870 | fg_inds = np.concatenate((fg_inds, roi_assignment), axis=0) # consider the roi which has max_overlaps with gt as fg 871 | 872 | easy_bg_inds = np.nonzero((max_overlaps < cfg.RCNN.CLS_BG_THRESH_LO))[0] 873 | hard_bg_inds = np.nonzero((max_overlaps < cfg.RCNN.CLS_BG_THRESH) & 874 | (max_overlaps >= cfg.RCNN.CLS_BG_THRESH_LO))[0] 875 | 876 | fg_num_rois = fg_inds.size 877 | bg_num_rois = hard_bg_inds.size + easy_bg_inds.size 878 | 879 | if fg_num_rois > 0 and bg_num_rois > 0: 880 | # sampling fg 881 | 
fg_rois_per_this_image = min(fg_rois_per_image, fg_num_rois)
882 |             rand_num = np.random.permutation(fg_num_rois)
883 |             fg_inds = fg_inds[rand_num[:fg_rois_per_this_image]]
884 | 
885 |             # sampling bg
886 |             bg_rois_per_this_image = cfg.RCNN.ROI_PER_IMAGE - fg_rois_per_this_image
887 |             bg_inds = self.sample_bg_inds(hard_bg_inds, easy_bg_inds, bg_rois_per_this_image)
888 | 
889 |         elif fg_num_rois > 0 and bg_num_rois == 0:
890 |             # sampling fg
891 |             rand_num = np.floor(np.random.rand(cfg.RCNN.ROI_PER_IMAGE) * fg_num_rois)
892 |             rand_num = rand_num.astype(np.int64)  # integer indices into fg_inds (gt_boxes3d is a numpy array here)
893 |             fg_inds = fg_inds[rand_num]
894 |             fg_rois_per_this_image = cfg.RCNN.ROI_PER_IMAGE
895 |             bg_rois_per_this_image = 0
896 |         elif bg_num_rois > 0 and fg_num_rois == 0:
897 |             # sampling bg
898 |             bg_rois_per_this_image = cfg.RCNN.ROI_PER_IMAGE
899 |             bg_inds = self.sample_bg_inds(hard_bg_inds, easy_bg_inds, bg_rois_per_this_image)
900 |             fg_rois_per_this_image = 0
901 |         else:
902 | 
903 | 
904 |             raise NotImplementedError
905 | 
906 |         # augment the rois by noise
907 |         roi_list, roi_iou_list, roi_gt_list = [], [], []
908 |         if fg_rois_per_this_image > 0:
909 |             fg_rois_src = roi_boxes3d[fg_inds].copy()
910 |             gt_of_fg_rois = gt_boxes3d[gt_assignment[fg_inds]]
911 |             fg_rois, fg_iou3d = self.aug_roi_by_noise_batch(fg_rois_src, gt_of_fg_rois, aug_times=10)
912 |             roi_list.append(fg_rois)
913 |             roi_iou_list.append(fg_iou3d)
914 |             roi_gt_list.append(gt_of_fg_rois)
915 | 
916 |         if bg_rois_per_this_image > 0:
917 |             bg_rois_src = roi_boxes3d[bg_inds].copy()
918 |             gt_of_bg_rois = gt_boxes3d[gt_assignment[bg_inds]]
919 |             bg_rois, bg_iou3d = self.aug_roi_by_noise_batch(bg_rois_src, gt_of_bg_rois, aug_times=1)
920 |             roi_list.append(bg_rois)
921 |             roi_iou_list.append(bg_iou3d)
922 |             roi_gt_list.append(gt_of_bg_rois)
923 | 
924 |         rois = np.concatenate(roi_list, axis=0)
925 |         iou_of_rois = np.concatenate(roi_iou_list, axis=0)
926 |         gt_of_rois = np.concatenate(roi_gt_list, axis=0)
927 | 
928 |         # collect extra features for point cloud pooling
929 |         if cfg.RCNN.USE_INTENSITY:
930 |             pts_extra_input_list = [rpn_intensity.reshape(-1, 1), seg_mask.reshape(-1, 1)]
931 |         else:
932 |             pts_extra_input_list = [seg_mask.reshape(-1, 1)]
933 | 
934 |         if cfg.RCNN.USE_DEPTH:
935 |             pts_depth = (np.linalg.norm(rpn_xyz, ord=2, axis=1) / 70.0) - 0.5
936 |             pts_extra_input_list.append(pts_depth.reshape(-1, 1))
937 |         pts_extra_input = np.concatenate(pts_extra_input_list, axis=1)
938 | 
939 |         pts_input, pts_features, pts_empty_flag = roipool3d_utils.roipool3d_cpu(rois, rpn_xyz, rpn_features,
940 |                                                                                 pts_extra_input,
941 |                                                                                 cfg.RCNN.POOL_EXTRA_WIDTH,
942 |                                                                                 sampled_pt_num=cfg.RCNN.NUM_POINTS,
943 |                                                                                 canonical_transform=False)
944 | 
945 |         # data augmentation
946 |         if cfg.AUG_DATA and self.mode == 'TRAIN':
947 |             for k in range(rois.__len__()):
948 |                 aug_pts = pts_input[k, :, 0:3].copy()
949 |                 aug_gt_box3d = gt_of_rois[k].copy()
950 |                 aug_roi_box3d = rois[k].copy()
951 | 
952 |                 # calculate alpha by ry
953 |                 temp_boxes3d = np.concatenate([aug_roi_box3d.reshape(1, 7), aug_gt_box3d.reshape(1, 7)], axis=0)
954 |                 temp_x, temp_z, temp_ry = temp_boxes3d[:, 0], temp_boxes3d[:, 2], temp_boxes3d[:, 6]
955 |                 temp_beta = np.arctan2(temp_z, temp_x).astype(np.float64)
956 |                 temp_alpha = -np.sign(temp_beta) * np.pi / 2 + temp_beta + temp_ry
957 | 
958 |                 # data augmentation
959 |                 aug_pts, aug_boxes3d, aug_method = self.data_augmentation(aug_pts, temp_boxes3d, temp_alpha,
960 |                                                                           mustaug=True, stage=2)
961 | 
962 |                 # assign to original data
963 |                 pts_input[k, :, 0:3] = 
aug_pts 964 | rois[k] = aug_boxes3d[0] 965 | gt_of_rois[k] = aug_boxes3d[1] 966 | 967 | valid_mask = (pts_empty_flag == 0).astype(np.int32) 968 | 969 | # regression valid mask 970 | reg_valid_mask = (iou_of_rois > cfg.RCNN.REG_FG_THRESH).astype(np.int32) & valid_mask 971 | 972 | # classification label 973 | cls_label = (iou_of_rois > cfg.RCNN.CLS_FG_THRESH).astype(np.int32) 974 | invalid_mask = (iou_of_rois > cfg.RCNN.CLS_BG_THRESH) & (iou_of_rois < cfg.RCNN.CLS_FG_THRESH) 975 | cls_label[invalid_mask] = -1 976 | cls_label[valid_mask == 0] = -1 977 | 978 | # canonical transform and sampling 979 | pts_input_ct, gt_boxes3d_ct = self.canonical_transform_batch(pts_input, rois, gt_of_rois) 980 | 981 | sample_info = {'sample_id': sample_id, 982 | 'pts_input': pts_input_ct, 983 | 'pts_features': pts_features, 984 | 'cls_label': cls_label, 985 | 'reg_valid_mask': reg_valid_mask, 986 | 'gt_boxes3d_ct': gt_boxes3d_ct, 987 | 'roi_boxes3d': rois, 988 | 'roi_size': rois[:, 3:6], 989 | 'gt_boxes3d': gt_of_rois} 990 | 991 | return sample_info 992 | 993 | def sample_bg_inds(self, hard_bg_inds, easy_bg_inds, bg_rois_per_this_image): 994 | if hard_bg_inds.size > 0 and easy_bg_inds.size > 0: 995 | hard_bg_rois_num = int(bg_rois_per_this_image * cfg.RCNN.HARD_BG_RATIO) 996 | easy_bg_rois_num = bg_rois_per_this_image - hard_bg_rois_num 997 | 998 | # sampling hard bg 999 | rand_num = np.floor(np.random.rand(hard_bg_rois_num) * hard_bg_inds.size).astype(np.int32) 1000 | hard_bg_inds = hard_bg_inds[rand_num] 1001 | # sampling easy bg 1002 | rand_num = np.floor(np.random.rand(easy_bg_rois_num) * easy_bg_inds.size).astype(np.int32) 1003 | easy_bg_inds = easy_bg_inds[rand_num] 1004 | 1005 | bg_inds = np.concatenate([hard_bg_inds, easy_bg_inds], axis=0) 1006 | elif hard_bg_inds.size > 0 and easy_bg_inds.size == 0: 1007 | hard_bg_rois_num = bg_rois_per_this_image 1008 | # sampling hard bg 1009 | rand_num = np.floor(np.random.rand(hard_bg_rois_num) * hard_bg_inds.size).astype(np.int32) 1010 | bg_inds = hard_bg_inds[rand_num] 1011 | elif hard_bg_inds.size == 0 and easy_bg_inds.size > 0: 1012 | easy_bg_rois_num = bg_rois_per_this_image 1013 | # sampling easy bg 1014 | rand_num = np.floor(np.random.rand(easy_bg_rois_num) * easy_bg_inds.size).astype(np.int32) 1015 | bg_inds = easy_bg_inds[rand_num] 1016 | else: 1017 | raise NotImplementedError 1018 | 1019 | return bg_inds 1020 | 1021 | def aug_roi_by_noise_batch(self, roi_boxes3d, gt_boxes3d, aug_times=10): 1022 | """ 1023 | :param roi_boxes3d: (N, 7) 1024 | :param gt_boxes3d: (N, 7) 1025 | :return: 1026 | """ 1027 | iou_of_rois = np.zeros(roi_boxes3d.shape[0], dtype=np.float32) 1028 | for k in range(roi_boxes3d.__len__()): 1029 | temp_iou = cnt = 0 1030 | roi_box3d = roi_boxes3d[k] 1031 | gt_box3d = gt_boxes3d[k] 1032 | pos_thresh = min(cfg.RCNN.REG_FG_THRESH, cfg.RCNN.CLS_FG_THRESH) 1033 | gt_corners = kitti_utils.boxes3d_to_corners3d(gt_box3d.reshape(1, 7)) 1034 | aug_box3d = roi_box3d 1035 | while temp_iou < pos_thresh and cnt < aug_times: 1036 | if np.random.rand() < 0.2: 1037 | aug_box3d = roi_box3d # p=0.2 to keep the original roi box 1038 | else: 1039 | aug_box3d = self.random_aug_box3d(roi_box3d) 1040 | aug_corners = kitti_utils.boxes3d_to_corners3d(aug_box3d.reshape(1, 7)) 1041 | iou3d = kitti_utils.get_iou3d(aug_corners, gt_corners) 1042 | temp_iou = iou3d[0][0] 1043 | cnt += 1 1044 | roi_boxes3d[k] = aug_box3d 1045 | iou_of_rois[k] = temp_iou 1046 | return roi_boxes3d, iou_of_rois 1047 | 1048 | def get_rcnn_sample_jit(self, index): 1049 | sample_id = 
int(self.sample_id_list[index]) 1050 | rpn_xyz, rpn_features, rpn_intensity, seg_mask = \ 1051 | self.get_rpn_features(self.rcnn_training_feature_dir, sample_id) 1052 | 1053 | # load rois and gt_boxes3d for this sample 1054 | roi_file = os.path.join(self.rcnn_training_roi_dir, '%06d.txt' % sample_id) 1055 | roi_obj_list = kitti_utils.get_objects_from_label(roi_file) 1056 | roi_boxes3d = kitti_utils.objs_to_boxes3d(roi_obj_list) 1057 | # roi_scores = kitti_utils.objs_to_scores(roi_obj_list) 1058 | 1059 | gt_obj_list = self.filtrate_objects(self.get_label(sample_id)) 1060 | gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list) 1061 | 1062 | sample_info = {'sample_id': sample_id, 1063 | 'rpn_xyz': rpn_xyz, 1064 | 'rpn_features': rpn_features, 1065 | 'rpn_intensity': rpn_intensity, 1066 | 'seg_mask': seg_mask, 1067 | 'roi_boxes3d': roi_boxes3d, 1068 | 'gt_boxes3d': gt_boxes3d, 1069 | 'pts_depth': np.linalg.norm(rpn_xyz, ord=2, axis=1)} 1070 | 1071 | return sample_info 1072 | 1073 | def collate_batch(self, batch): 1074 | if self.mode != 'TRAIN' and cfg.RCNN.ENABLED and not cfg.RPN.ENABLED: 1075 | assert batch.__len__() == 1 1076 | return batch[0] 1077 | 1078 | batch_size = batch.__len__() 1079 | ans_dict = {} 1080 | 1081 | for key in batch[0].keys(): 1082 | if cfg.RPN.ENABLED and key == 'gt_boxes3d' or \ 1083 | (cfg.RCNN.ENABLED and cfg.RCNN.ROI_SAMPLE_JIT and key in ['gt_boxes3d', 'roi_boxes3d']): 1084 | max_gt = 0 1085 | for k in range(batch_size): 1086 | max_gt = max(max_gt, batch[k][key].__len__()) 1087 | batch_gt_boxes3d = np.zeros((batch_size, max_gt, 7), dtype=np.float32) 1088 | for i in range(batch_size): 1089 | batch_gt_boxes3d[i, :batch[i][key].__len__(), :] = batch[i][key] 1090 | ans_dict[key] = batch_gt_boxes3d 1091 | continue 1092 | 1093 | if isinstance(batch[0][key], np.ndarray): 1094 | if batch_size == 1: 1095 | ans_dict[key] = batch[0][key][np.newaxis, ...] 1096 | else: 1097 | ans_dict[key] = np.concatenate([batch[k][key][np.newaxis, ...] 

--------------------------------------------------------------------------------
/waymo_pytorch_dataset.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from torch.utils.data import Dataset
import dataset_utils.calibration_waymo as calibration
import dataset_utils.object3d as object3d
from PIL import Image
import tqdm
import math

from simple_waymo_open_dataset_reader import WaymoDataFileReader
from simple_waymo_open_dataset_reader import dataset_pb2, label_pb2
from simple_waymo_open_dataset_reader import utils


class WaymoDataset(Dataset):
    '''Waymo dataset for pytorch
    CURRENT:
        V Serialized data feeding
    TODO:
        X Implement shuffling
        X Implement IterableDataset/BatchSampler
        X Make Cache

    USAGE:
        DATA_PATH = '/home/jupyter/waymo-od/waymo_dataset'
        LOCATIONS = ['location_sf']

        dataset = WaymoDataset(DATA_PATH, LOCATIONS, 'train', True, "new_waymo")

        frame, idx = dataset.data, dataset.count
        calib = dataset.get_calib(frame, idx)
        pts = dataset.get_lidar(frame, idx)
        target = dataset.get_label(frame, idx)

    :param root_dir: Root directory of the data
    :param locations: Location segments to load (e.g. ['location_sf'])
    :param split: Select if train/test/val
    :param use_cache: Select if you need to save a pkl file of the dataset for easy access
    :param name: Name tag for this dataset instance
    '''
    def __init__(self, root_dir, locations, split='train', use_cache=False, name="Waymo"):
        self._name = name
        self.split = split
        is_test = self.split == 'test'

        self._root_dir = root_dir  # tfrecords are expected directly under this directory
        self._dataset_dir = os.path.join(root_dir, 'kitti_dataset',
                                         'testing' if is_test else 'training')

        self.__lidar_list = ['_FRONT', '_FRONT_RIGHT', '_FRONT_LEFT', '_SIDE_RIGHT', '_SIDE_LEFT']
        self.__type_list = ['UNKNOWN', 'VEHICLE', 'PEDESTRIAN', 'SIGN', 'CYCLIST']
        self._classes = self.__type_list  # assumption: detection classes == the Waymo label types above
        self.get_file_names()  # Storing file names in object

        self._image = None
        self._num_files = len(self.__file_names)
        self._curr_counter = 0
        self._num_frames = 0
        self._total_frames = 0
        self._idx_to_frame = []
        self._sample_list = []
        self._frame_counter = -1   # Count the number of frames used per file
        self._file_counter = -1    # Count the number of files used
        self._dataset_nums = []    # Number of frames to be considered from each file (records+files)
        self._dataset_itr = None   # tfrecord iterator over the current file
        self.num_sample = self._num_frames  # frames served so far (0 before iteration)

        if use_cache:
            self.make_cache()
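
    # Bookkeeping example (illustrative): after reading through two tfrecord
    # files holding 3 and 2 frames respectively, the internal state is
    #   _dataset_nums == [3, 2]
    #   _idx_to_frame == [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
    # i.e. each served frame maps to a (file_counter, frame_counter) pair,
    # which is what __getitem__ below appends as it walks the records.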

    @property
    def name(self):
        return self._name

    @property
    def num_classes(self):
        return len(self._classes)

    @property
    def classes(self):
        return self._classes

    @property
    def count(self):
        return self._curr_counter

    @property
    def data(self):
        # Convenience accessor: advances the counter and returns the next frame
        self._curr_counter += 1
        return self.__getitem__(self._curr_counter)

    @property
    def frame_count(self):
        return self._frame_counter

    @property
    def record_table(self):
        return self._sample_list

    @property
    def image_shape(self):
        if self._image is None:
            return None
        height, width = self._image.shape[:2]
        return height, width, 3

    def __len__(self):
        if not self._total_frames:
            self.count_frames()
        return self._total_frames

    def __getitem__(self, idx):
        self._curr_counter = idx
        # Move to the next file when starting out or when the current file is exhausted
        if self._frame_counter == -1 or not len(self._dataset_nums) or \
                self._frame_counter >= self._dataset_nums[self._file_counter] - 1:
            self.current_file = self.__file_names.pop()        # get one filename
            dataset = WaymoDataFileReader(self.current_file)   # get Dataloader
            self._sample_list = dataset.get_record_table()     # get record offset table
            self._dataset_itr = iter(dataset)                  # get record iterator
            self._file_counter += 1
            self._dataset_nums.append(len(self._sample_list))
            self._frame_counter = 0
        else:
            self._frame_counter += 1
        self._num_frames += 1
        self._idx_to_frame.append((self._file_counter, self._frame_counter))
        return next(self._dataset_itr)  # send next frame from record

    def count_frames(self):
        # Count total frames across every file (requires a full pass)
        for file_name in self.__file_names:
            dataset = WaymoDataFileReader(file_name)
            for frame in tqdm.tqdm(dataset):
                self._total_frames += 1
        print("[LOG] Total frames: ", self._total_frames)

    def get_file_names(self):
        self.__file_names = []
        # scan the dataset root (the DATA_PATH/root_dir from the usage above)
        for i in os.listdir(self._root_dir):
            if i.split('.')[-1] == 'tfrecord':
                self.__file_names.append(self._root_dir + '/' + i)
        print("[LOG] Number of files found {}".format(len(self.__file_names)))

    def get_lidar(self, frame, idx, all_points=False):
        '''Get lidar pointcloud
        TODO: Append the returns of the other four lidars as well
        :return pcl: (N, 3) points in (x, y, z)
        '''
        laser_name = dataset_pb2.LaserName.TOP  # use the top lidar only
        laser = utils.get(frame.lasers, laser_name)
        laser_calibration = utils.get(frame.context.laser_calibrations, laser_name)
        ri, camera_projection, range_image_pose = utils.parse_range_image_and_camera_projection(laser)
        pcl, pcl_attr = utils.project_to_pointcloud(frame, ri, camera_projection, range_image_pose, laser_calibration)
        return pcl

    def get_image(self, frame, idx):
        '''Get the front camera image
        '''
        camera_name = dataset_pb2.CameraName.FRONT
        camera_calibration = utils.get(frame.context.camera_calibrations, camera_name)
        camera = utils.get(frame.images, camera_name)
        vehicle_to_image = utils.get_image_transform(camera_calibration)  # vehicle-to-image transform (currently unused)
        img = utils.decode_image(camera)
        self._image = img
        return img

    def get_calib(self, frame, idx):
        '''Get a KITTI-style calibration object for this frame
        '''
        return calibration.Calibration(frame, idx)
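
    # Example (sketch, using the accessors above): project the top-lidar
    # cloud into the front image with the per-frame calibration.
    #
    #     frame, idx = dataset.data, dataset.count
    #     calib = dataset.get_calib(frame, idx)
    #     pts = dataset.get_lidar(frame, idx)           # (N, 3)
    #     pts_img, pts_depth = calib.lidar_to_img(pts)  # (N, 2), (N,)
    #
    # Points with pts_depth > 0 project in front of the camera; clip pts_img
    # to the bounds from `image_shape` before drawing.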

    def get_label(self, frame, idx):
        '''Get labels as a list of object3d.Object3d with fields:
        {
            cls_type: Object class
            trucation: If truncated or not in image
            occlusion: If occluded or not in image
            box2d: 2d (x1, y1, x2, y2)
            h: box height
            w: box width
            l: box length
            pos: box center position in (x, y, z)
            ry: Heading theta about y axis
            score: Target score
            alpha: 3D rotation azimuth angle
            level: hard/medium/easy
            dis_to_cam: range distance of point
        }
        '''
        # preprocess bounding box data: map projected-label ids to 2D boxes
        id_to_bbox = dict()
        id_to_name = dict()
        for labels in frame.projected_lidar_labels:
            name = labels.name
            for label in labels.labels:
                bbox = [label.box.center_x - label.box.length / 2, label.box.center_y - label.box.width / 2,
                        label.box.center_x + label.box.length / 2, label.box.center_y + label.box.width / 2]
                id_to_bbox[label.id] = bbox
                id_to_name[label.id] = name - 1  # camera name enum is 1-based; store a 0-based index

        object_list = []
        for obj in frame.laser_labels:
            # calculate the 2D bounding box from the matching projected label
            bounding_box = None
            name = None
            obj_id = obj.id
            for lidar in self.__lidar_list:
                if obj_id + lidar in id_to_bbox:
                    bounding_box = id_to_bbox.get(obj_id + lidar)
                    name = str(id_to_name.get(obj_id + lidar))
                    break
            if bounding_box is None or name is None:
                continue

            kitti_obj = object3d.Object3d()
            kitti_obj.cls_type = self.__type_list[obj.type]
            kitti_obj.trucation = 0
            kitti_obj.occlusion = 0
            kitti_obj.box2d = np.array((float(bounding_box[0]), float(bounding_box[1]),
                                        float(bounding_box[2]), float(bounding_box[3])), dtype=np.float32)
            kitti_obj.h = obj.box.height
            kitti_obj.w = obj.box.width
            kitti_obj.l = obj.box.length
            x = obj.box.center_x
            y = obj.box.center_y
            z = obj.box.center_z
            kitti_obj.pos = np.array((float(x), float(y), float(z)), dtype=np.float32)
            kitti_obj.ry = obj.box.heading
            kitti_obj.score = 1
            beta = math.atan2(x, z)  # azimuth of the box center
            kitti_obj.alpha = (kitti_obj.ry + beta - math.pi / 2) % (2 * math.pi)
            kitti_obj.level = kitti_obj.get_obj_level()
            kitti_obj.dis_to_cam = np.linalg.norm(kitti_obj.pos)
            object_list.append(kitti_obj)
        return object_list

    def make_cache(self):
        raise NotImplementedError

--------------------------------------------------------------------------------
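
A minimal end-to-end sketch of consuming `get_label`: it writes one frame's labels as KITTI-format text lines. The script name and output path are hypothetical; the field order follows the standard KITTI label layout using the `object3d` attributes documented above.

```python
# dump_kitti_labels.py -- illustrative sketch, not part of the repository
from waymo_pytorch_dataset import WaymoDataset

DATA_PATH = '/home/jupyter/waymo-od/waymo_dataset'  # dataset root, as in the docstring usage
dataset = WaymoDataset(DATA_PATH, ['location_sf'], 'train', False, "new_waymo")

frame, idx = dataset.data, dataset.count
with open('%06d.txt' % idx, 'w') as f:
    for obj in dataset.get_label(frame, idx):
        # KITTI label order: type, truncated, occluded, alpha, bbox(4), h w l, x y z, ry
        f.write('%s %.2f %d %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f\n' % (
            obj.cls_type, obj.trucation, obj.occlusion, obj.alpha,
            obj.box2d[0], obj.box2d[1], obj.box2d[2], obj.box2d[3],
            obj.h, obj.w, obj.l,
            obj.pos[0], obj.pos[1], obj.pos[2], obj.ry))
```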