├── model ├── __init__.py ├── __pycache__ │ └── model.cpython-35.pyc └── model.py ├── utils ├── __init__.py ├── __pycache__ │ ├── kitti_utils.cpython-35.pyc │ └── model_utils.cpython-35.pyc ├── visualize_augumented_data.py ├── make_image_dataset.py ├── model_utils.py ├── kitti_eval.py └── kitti_utils.py ├── .gitignore ├── dataset ├── __init__.py ├── __pycache__ │ ├── augument.cpython-35.pyc │ └── dataset.cpython-35.pyc ├── dataset.py └── augument.py ├── examples ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── car_detection_ground.png └── cyclist_detection_ground.png ├── config ├── kitti_anchors.txt └── class_flag.txt ├── weights └── .gitkeep ├── kitti └── training │ ├── calib │ └── .gitkeep │ ├── label_2 │ └── .gitkeep │ └── velodyne │ └── .gitkeep ├── predict.py ├── README.md └── train.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/examples/1.png -------------------------------------------------------------------------------- /examples/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/examples/2.png -------------------------------------------------------------------------------- /examples/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/examples/3.png -------------------------------------------------------------------------------- /examples/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/examples/4.png -------------------------------------------------------------------------------- /config/kitti_anchors.txt: -------------------------------------------------------------------------------- 1 | 1.08 1.19 2 | 3.42 4.41 3 | 6.63 11.38 4 | 9.42 5.11 5 | 16.62 10.52 6 | -------------------------------------------------------------------------------- /weights/.gitkeep: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file !.gitkeep 4 | -------------------------------------------------------------------------------- /kitti/training/calib/.gitkeep: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file !.gitkeep 4 | -------------------------------------------------------------------------------- /kitti/training/label_2/.gitkeep: -------------------------------------------------------------------------------- 1 | # Ignore everything in 
this directory 2 | * 3 | # Except this file !.gitkeep 4 | -------------------------------------------------------------------------------- /kitti/training/velodyne/.gitkeep: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file !.gitkeep 4 | -------------------------------------------------------------------------------- /examples/car_detection_ground.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/examples/car_detection_ground.png -------------------------------------------------------------------------------- /examples/cyclist_detection_ground.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/examples/cyclist_detection_ground.png -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/model/__pycache__/model.cpython-35.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/augument.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/dataset/__pycache__/augument.cpython-35.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/dataset.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/dataset/__pycache__/dataset.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/kitti_utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/utils/__pycache__/kitti_utils.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/model_utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/utils/__pycache__/model_utils.cpython-35.pyc -------------------------------------------------------------------------------- /config/class_flag.txt: -------------------------------------------------------------------------------- 1 | 0 Car (255,50,0) 2 | 1 Van (255,0,130) 3 | 2 Truck (0,0,255) 4 | 3 Pedestrian (0,100,255) 5 | 4 Person_sitting (255,255,0) 6 | 5 Cyclist (0,255,0) 7 | 6 Tram (0,255,255) 8 | 7 Misc (255,0,255) 9 | -------------------------------------------------------------------------------- /utils/visualize_augumented_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import cv2 3 | import os 4 | import sys 5 | sys.path.append('.') 6 | from dataset.dataset import ImageDataSet 7 | from utils.kitti_utils import draw_rotated_box, get_corner_gtbox 8 | img_h, img_w = 768, 1024 9 | gt_box_color = (255, 255, 255) 10 | class_list = [ 11 | 'Car', 'Van', 'Truck', 'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram', 12 | 'Misc' 13 | ] 14 | dataset =
ImageDataSet(data_set='test', 15 | mode='visualize', 16 | flip=True, 17 | random_scale=True, 18 | load_to_memory=False) 19 | 20 | 21 | def make_dir(directory): 22 | if not os.path.exists(directory): 23 | os.makedirs(directory) 24 | 25 | 26 | make_dir('./tmp') 27 | for img_idx, img, target in dataset.data_generator(): 28 | # draw gt bbox 29 | print("process data: {}, saved in ./tmp".format(img_idx)) 30 | for i in range(target.shape[0]): 31 | cx = int(target[i][1]) 32 | cy = int(target[i][2]) 33 | w = int(target[i][3]) 34 | h = int(target[i][4]) 35 | rz = target[i][5] 36 | draw_rotated_box(img, cx, cy, w, h, rz, gt_box_color) 37 | label = class_list[int(target[i][0])] 38 | box = get_corner_gtbox([cx, cy, w, h]) 39 | cv2.putText(img, label, (box[0], box[1]), cv2.FONT_HERSHEY_PLAIN, 1.0, 40 | gt_box_color, 1) 41 | cv2.imwrite('./tmp/{}.png'.format(img_idx), img[:, :, ::-1]) 42 | -------------------------------------------------------------------------------- /utils/make_image_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import os 4 | import cv2 5 | import numpy as np 6 | import sys 7 | sys.path.append('.') 8 | from dataset.dataset import PointCloudDataset 9 | from utils.model_utils import make_dir 10 | image_dataset_dir = 'kitti/image_dataset/' 11 | class_list = [ 12 | 'Car', 'Van', 'Truck', 'Pedestrian', 13 | 'Person_sitting', 'Cyclist', 'Tram', 'Misc' 14 | ] 15 | img_h, img_w = 768, 1024 16 | # dataset 17 | train_dataset = PointCloudDataset(root='kitti/', data_set='train') 18 | test_dataset = PointCloudDataset(root='kitti/', data_set='test') 19 | 20 | 21 | def delete_file_folder(src): 22 | if os.path.isfile(src): 23 | os.remove(src) 24 | elif os.path.isdir(src): 25 | for item in os.listdir(src): 26 | item_src = os.path.join(src, item) 27 | delete_file_folder(item_src) 28 | 29 | 30 | def preprocess_dataset(dataset): 31 | """ 32 | Convert point cloud data to image while 33 | filtering out image without objects. 
34 | param: dataset: (PointCloudDataset) 35 | return: None 36 | """ 37 | for img_idx, rgb_map, target in dataset.getitem(): 38 | rgb_map = np.array(rgb_map * 255, np.uint8) 39 | target = np.array(target) 40 | print('process image: {}'.format(img_idx)) 41 | for i in range(target.shape[0]): 42 | if target[i].sum() == 0: 43 | break 44 | with open("kitti/image_dataset/labels/{}.txt".format(img_idx), 'a+') as f: 45 | label = class_list[int(target[i][0])] 46 | cx = target[i][1] * img_w 47 | cy = target[i][2] * img_h 48 | w = target[i][3] * img_w 49 | h = target[i][4] * img_h 50 | rz = target[i][5] 51 | line = label + ' ' + '{} {} {} {} {}\n'.format(cx, cy, w, h, rz) 52 | f.write(line) 53 | cv2.imwrite('kitti/image_dataset/images/{}.png'.format(img_idx), rgb_map[:, :, ::-1]) 54 | print('make image dataset done!') 55 | 56 | 57 | def make_train_test_list(): 58 | name_list = os.listdir(image_dataset_dir + 'labels') 59 | name_list.sort() 60 | with open('config/test_image_list.txt', 'w') as f: 61 | for name in name_list[0:1000]: 62 | f.write(name.split('.')[0]) 63 | f.write('\n') 64 | with open('config/train_image_list.txt', 'w') as f: 65 | for name in name_list[1000:]: 66 | f.write(name.split('.')[0]) 67 | f.write('\n') 68 | 69 | 70 | if __name__ == "__main__": 71 | make_dir(image_dataset_dir + 'images') 72 | make_dir(image_dataset_dir + 'labels') 73 | delete_file_folder(image_dataset_dir + 'labels') 74 | preprocess_dataset(train_dataset) 75 | preprocess_dataset(test_dataset) 76 | make_train_test_list() 77 | 78 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import tensorflow as tf 7 | from dataset.dataset import PointCloudDataset 8 | from utils.model_utils import preprocess_data, non_max_supression, filter_bbox, make_dir 9 | from utils.kitti_utils import draw_rotated_box, calculate_angle, get_corner_gtbox, \ 10 | read_anchors_from_file, read_class_flag 11 | gt_box_color = (255, 255, 255) 12 | prob_th = 0.3 13 | nms_iou_th = 0.4 14 | n_anchors = 5 15 | n_classes = 8 16 | net_scale = 32 17 | img_h, img_w = 768, 1024 18 | grid_w, grid_h = 32, 24 19 | class_list = [ 20 | 'Car', 'Van', 'Truck', 'Pedestrian', 21 | 'Person_sitting', 'Cyclist', 'Tram', 'Misc' 22 | ] 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--draw_gt_box", type=str, default='True', help="Whether to draw_gtbox, True or False") 26 | parser.add_argument("--weights_path", type=str, default='./weights/yolo_tloss_1.185166835784912_vloss_2.9397876932621-220800', 27 | help="set the weights_path") 28 | args = parser.parse_args() 29 | weights_path = args.weights_path 30 | 31 | # dataset 32 | dataset = PointCloudDataset(root='./kitti/', data_set='test') 33 | make_dir('./predict_result') 34 | 35 | 36 | def predict(draw_gt_box='False'): 37 | 38 | important_classes, names, color = read_class_flag('config/class_flag.txt') 39 | anchors = read_anchors_from_file('config/kitti_anchors.txt') 40 | sess = tf.Session() 41 | saver = tf.train.import_meta_graph(weights_path + '.meta') 42 | saver.restore(sess, weights_path) 43 | graph = tf.get_default_graph() 44 | image = graph.get_tensor_by_name("image_placeholder:0") 45 | train_flag = graph.get_tensor_by_name("flag_placeholder:0") 46 | y = graph.get_tensor_by_name("net/y:0") 47 | for img_idx, rgb_map, target in dataset.getitem(): 48 | 
print("process data: {}, saved in ./predict_result/".format(img_idx)) 49 | img = np.array(rgb_map * 255, np.uint8) 50 | target = np.array(target) 51 | # draw gt bbox 52 | if draw_gt_box == 'True': 53 | for i in range(target.shape[0]): 54 | if target[i].sum() == 0: 55 | break 56 | cx = int(target[i][1] * img_w) 57 | cy = int(target[i][2] * img_h) 58 | w = int(target[i][3] * img_w) 59 | h = int(target[i][4] * img_h) 60 | rz = target[i][5] 61 | draw_rotated_box(img, cx, cy, w, h, rz, gt_box_color) 62 | label = class_list[int(target[i][0])] 63 | box = get_corner_gtbox([cx, cy, w, h]) 64 | cv2.putText(img, label, (box[0], box[1]), 65 | cv2.FONT_HERSHEY_PLAIN, 1.0, gt_box_color, 1) 66 | data = sess.run(y, feed_dict={image: [rgb_map], train_flag: False}) 67 | classes, rois = preprocess_data(data, anchors, important_classes, 68 | grid_w, grid_h, net_scale) 69 | classes, index = non_max_supression(classes, rois, prob_th, nms_iou_th) 70 | all_boxes = filter_bbox(classes, rois, index) 71 | for box in all_boxes: 72 | class_idx = box[0] 73 | corner_box = get_corner_gtbox(box[1:5]) 74 | angle = calculate_angle(box[6], box[5]) 75 | class_prob = box[7] 76 | draw_rotated_box(img, box[1], box[2], box[3], box[4], 77 | angle, color[class_idx]) 78 | cv2.putText(img, 79 | class_list[class_idx] + ' : {:.2f}'.format(class_prob), 80 | (corner_box[0], corner_box[1]), cv2.FONT_HERSHEY_PLAIN, 81 | 0.7, color[class_idx], 1, cv2.LINE_AA) 82 | cv2.imwrite('./predict_result/{}.png'.format(img_idx), img) 83 | 84 | 85 | if __name__ == '__main__': 86 | predict(draw_gt_box=args.draw_gt_box) 87 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | 5 | def make_dir(directory): 6 | if not os.path.exists(directory): 7 | os.makedirs(directory) 8 | 9 | 10 | def bbox_iou(box1, box2, x1y1x2y2=True): 11 | if x1y1x2y2: 12 | mx = min(box1[0], box2[0]) 13 | Mx = max(box1[2], box2[2]) 14 | my = min(box1[1], box2[1]) 15 | My = max(box1[3], box2[3]) 16 | w1 = box1[2] - box1[0] 17 | h1 = box1[3] - box1[1] 18 | w2 = box2[2] - box2[0] 19 | h2 = box2[3] - box2[1] 20 | else: 21 | mx = min(box1[0] - box1[2] / 2.0, box2[0] - box2[2] / 2.0) 22 | Mx = max(box1[0] + box1[2] / 2.0, box2[0] + box2[2] / 2.0) 23 | my = min(box1[1] - box1[3] / 2.0, box2[1] - box2[3] / 2.0) 24 | My = max(box1[1] + box1[3] / 2.0, box2[1] + box2[3] / 2.0) 25 | w1 = box1[2] 26 | h1 = box1[3] 27 | w2 = box2[2] 28 | h2 = box2[3] 29 | uw = Mx - mx 30 | uh = My - my 31 | cw = w1 + w2 - uw 32 | ch = h1 + h2 - uh 33 | carea = 0 34 | if cw <= 0 or ch <= 0: 35 | return 0.0 36 | 37 | area1 = w1 * h1 38 | area2 = w2 * h2 39 | carea = cw * ch 40 | uarea = area1 + area2 - carea 41 | return carea / uarea 42 | 43 | 44 | def iou(r1, r2): 45 | intersect_w = np.maximum( 46 | np.minimum(r1[0] + r1[2], r2[0] + r2[2]) - np.maximum(r1[0], r2[0]), 0) 47 | intersect_h = np.maximum( 48 | np.minimum(r1[1] + r1[3], r2[1] + r2[3]) - np.maximum(r1[1], r2[1]), 0) 49 | area_r1 = r1[2] * r1[3] 50 | area_r2 = r2[2] * r2[3] 51 | intersect = intersect_w * intersect_h 52 | union = area_r1 + area_r2 - intersect 53 | return intersect / union 54 | 55 | 56 | def softmax(x): 57 | e_x = np.exp(x) 58 | return e_x / np.sum(e_x) 59 | 60 | 61 | def sigmoid(x): 62 | return 1.0 / (1.0 + np.exp(-x)) 63 | 64 | 65 | def non_max_supression(classes, locations, prob_th, iou_th): 66 | """ 67 | Filter out some overlapping boxes by non-maximum suppression 
68 | """ 69 | classes = np.transpose(classes) 70 | indxs = np.argsort(-classes, axis=1) 71 | 72 | for i in range(classes.shape[0]): 73 | classes[i] = classes[i][indxs[i]] 74 | 75 | for class_idx, class_vec in enumerate(classes): 76 | for roi_idx, roi_prob in enumerate(class_vec): 77 | if roi_prob < prob_th: 78 | classes[class_idx][roi_idx] = 0 79 | 80 | for class_idx, class_vec in enumerate(classes): 81 | for roi_idx, roi_prob in enumerate(class_vec): 82 | if roi_prob == 0: 83 | continue 84 | roi = locations[indxs[class_idx][roi_idx]][0:4] 85 | for roi_ref_idx, roi_ref_prob in enumerate(class_vec): 86 | if roi_ref_prob == 0 or roi_ref_idx <= roi_idx: 87 | continue 88 | roi_ref = locations[indxs[class_idx][roi_ref_idx]][0:4] 89 | if bbox_iou(roi, roi_ref, False) > iou_th: 90 | classes[class_idx][roi_ref_idx] = 0 91 | return classes, indxs 92 | 93 | 94 | def filter_bbox(classes, rois, indxs): 95 | """ 96 | Pick out bounding boxes that are retained after non-maximum suppression 97 | """ 98 | all_bboxs = [] 99 | for class_idx, c in enumerate(classes): 100 | for loc_idx, class_prob in enumerate(c): 101 | if class_prob > 0: 102 | x = int(rois[indxs[class_idx][loc_idx]][0]) 103 | y = int(rois[indxs[class_idx][loc_idx]][1]) 104 | w = int(rois[indxs[class_idx][loc_idx]][2]) 105 | h = int(rois[indxs[class_idx][loc_idx]][3]) 106 | re = rois[indxs[class_idx][loc_idx]][4] 107 | im = rois[indxs[class_idx][loc_idx]][5] 108 | all_bboxs.append([class_idx, x, y, w, h, re, im, class_prob]) 109 | return all_bboxs 110 | 111 | 112 | def preprocess_data(data, anchors, important_classes, grid_w, grid_h, net_scale): 113 | """ 114 | Decode the data output by the model, obtain the center coordinates 115 | x, y and width and height of the bounding box in the image, 116 | and the category, the real and imaginary parts of the complex. 117 | 118 | """ 119 | locations = [] 120 | classes = [] 121 | n_anchors = np.shape(anchors)[0] 122 | for i in range(grid_h): 123 | for j in range(grid_w): 124 | for k in range(n_anchors): 125 | class_vec = softmax(data[0, i, j, k, 7:]) 126 | object_conf = sigmoid(data[0, i, j, k, 6]) 127 | class_prob = object_conf * class_vec 128 | w = np.exp(data[0, i, j, k, 2] 129 | ) * anchors[k][0] / 80 * grid_w * net_scale 130 | h = np.exp(data[0, i, j, k, 3] 131 | ) * anchors[k][1] / 60 * grid_h * net_scale 132 | dx = sigmoid(data[0, i, j, k, 0]) 133 | dy = sigmoid(data[0, i, j, k, 1]) 134 | re = 2 * sigmoid(data[0, i, j, k, 4]) - 1 135 | im = 2 * sigmoid(data[0, i, j, k, 5]) - 1 136 | y = (i + dy) * net_scale 137 | x = (j + dx) * net_scale 138 | classes.append(class_prob[important_classes]) 139 | locations.append([x, y, w, h, re, im]) 140 | classes = np.array(classes) 141 | locations = np.array(locations) 142 | return classes, locations 143 | 144 | -------------------------------------------------------------------------------- /utils/kitti_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | This script is mainly used to generate the prediction result 4 | of each RGB-map of the test set, convert the bounding box coordinates 5 | in the image to the lidar coordinate system, and then convert the 6 | coordinates to the camera coordinate system using the coordinate 7 | transformation matrix provided by the kitti data set. 
8 | The coordinate transformation requires (x, y, z), since the complex-yolo 9 | does not predict the height of the objects, the height is set to a fixed value 10 | of 1.5, which does not affect the bird's eye view benchmark evaluation. 11 | 12 | """ 13 | from __future__ import division 14 | import numpy as np 15 | import argparse 16 | import tensorflow as tf 17 | import cv2 18 | from model_utils import preprocess_data, non_max_supression, filter_bbox, make_dir 19 | from kitti_utils import load_kitti_calib, calculate_angle,\ 20 | read_anchors_from_file, read_class_flag 21 | from kitti_utils import angle_rz_to_ry, coord_image_to_velo, coord_velo_to_cam 22 | prob_th = 0.3 23 | iou_th = 0.4 24 | n_anchors = 5 25 | n_classes = 8 26 | net_scale = 32 27 | img_h, img_w = 768, 1024 28 | grid_w, grid_h = 32, 24 29 | test_image_path = "kitti/image_dataset/images/" 30 | class_list = [ 31 | 'Car', 'Van', 'Truck', 'Pedestrian', 32 | 'Person_sitting', 'Cyclist', 'Tram', 'Misc' 33 | ] 34 | train_list = 'config/train_image_list.txt' 35 | test_list = 'config/test_image_list.txt' 36 | calib_dir = 'kitti/training/calib/' 37 | 38 | # kitti_static_cylist = 'cyclist_detection_ground.txt' 39 | # kitti_static_car = 'car_detection_ground.txt' 40 | # kitti_static_pedestrian = 'pedestrian_detection_ground.txt' 41 | 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--weights_path", type=str, default='./weights/yolo_tloss_1.185166835784912_vloss_2.9397876932621-220800', help="set the weights_path") 44 | args = parser.parse_args() 45 | weights_path = args.weights_path 46 | 47 | make_dir("./eval_results") 48 | 49 | 50 | def kitti_eval(): 51 | important_classes, names, colors = read_class_flag('config/class_flag.txt') 52 | anchors = read_anchors_from_file('config/kitti_anchors.txt') 53 | sess = tf.Session() 54 | saver = tf.train.import_meta_graph(weights_path + '.meta') 55 | saver.restore(sess, weights_path) 56 | graph = tf.get_default_graph() 57 | image = graph.get_tensor_by_name("image_placeholder:0") 58 | train_flag = graph.get_tensor_by_name("flag_placeholder:0") 59 | y = graph.get_tensor_by_name("net/y:0") 60 | for test_file_index in range(1000): 61 | print('process data: {}, saved in ./eval_results'.format(test_file_index)) 62 | calib_file = calib_dir + str(test_file_index).zfill(6) + '.txt' 63 | calib = load_kitti_calib(calib_file) 64 | result_file = "./eval_results/" + str(test_file_index).zfill(6) + ".txt" 65 | img_path = test_image_path + str(test_file_index).zfill(6) + '.png' 66 | rgb_map = cv2.imread(img_path)[:, :, ::-1] 67 | img_for_net = rgb_map / 255.0 68 | data = sess.run(y, 69 | feed_dict={image: [img_for_net], 70 | train_flag: False}) 71 | classes, rois = preprocess_data(data, 72 | anchors, 73 | important_classes, 74 | grid_w, 75 | grid_h, 76 | net_scale) 77 | classes, index = non_max_supression(classes, 78 | rois, 79 | prob_th, 80 | iou_th) 81 | all_boxes = filter_bbox(classes, rois, index) 82 | with open(result_file, "w") as f: 83 | for box in all_boxes: 84 | pred_img_y = box[2] 85 | pred_img_x = box[1] 86 | velo_x, velo_y = coord_image_to_velo(pred_img_y, pred_img_x) 87 | cam_x, cam_z = coord_velo_to_cam(velo_x, velo_y, calib['Tr_velo2cam']) 88 | pred_width = box[3] * 80.0 / img_w 89 | pred_height = box[4] * 60.0 / img_h 90 | pred_cls = class_list[box[0]] 91 | pred_conf = box[7] 92 | angle_rz = calculate_angle(box[6], box[5]) 93 | angle_ry = angle_rz_to_ry(angle_rz) 94 | pred_line = pred_cls + " -1 -1 -10 -1 -1 -1 -1 -1" + \ 95 | " {:.2f} {:.2f}".format(pred_width, pred_height) + \ 
96 | " {:.2f} {:.2f} {:.2f}".format(cam_x, -1000, cam_z) + \ 97 | " {:.2f} {:.2f}".format(angle_ry, pred_conf) 98 | f.write(pred_line) 99 | f.write("\n") 100 | 101 | 102 | def cal_ap(kitti_statics_results): 103 | """ 104 | Calculate the ap approximately. 105 | param kitti_statics_results(str): Kitti evaluation script output statistics result file. 106 | return: 107 | """ 108 | with open(kitti_statics_results, 'r') as f: 109 | lines = f.readlines() 110 | all_lines = [] 111 | for line in lines: 112 | pr = list(map(float, line.strip().split(' '))) 113 | all_lines.append(pr) 114 | all_lines = np.array(all_lines) 115 | ap = np.zeros([all_lines.shape[0], 3]) 116 | ap[1:, 0] = 0.025 * all_lines[1:, 1] 117 | ap[1:, 1] = 0.025 * all_lines[1:, 2] 118 | ap[1:, 2] = 0.025 * all_lines[1:, 3] 119 | result = np.sum(ap, 0) 120 | return result 121 | 122 | 123 | if __name__ == '__main__': 124 | kitti_eval() 125 | # print("car ap: {}".format(cal_ap(kitti_static_car))) 126 | # print("cyclist ap: {}".format(cal_ap(kitti_static_cylist))) 127 | # print("pedestrian ap: {}".format(cal_ap(kitti_static_pedestrian))) 128 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Complex-YOLO implementation in tensorflow 2 | --- 3 | ### Contents 4 | 5 | [Overview](#overview)
[Examples](#examples)
[Dependencies](#dependencies)
[How to use it](#how-to-use-it)
[Others](#others)
[ToDo](#todo)

### Overview

The project is an unofficial implementation of [Complex-YOLO: Real-time 3D Object Detection on Point Clouds](https://arxiv.org/abs/1803.06199); the model structure differs slightly from what the paper describes. [AI-liu/Complex-YOLO](https://github.com/AI-liu/Complex-YOLO) has the most stars, but it appears to have some bugs: the model has no yaw-angle prediction and shows no generalization ability on the test set. This project therefore only reuses its point cloud preprocessing, while the model structure follows [WojciechMormul/yolo2](https://github.com/WojciechMormul/yolo2). On this basis, a complete Complex-YOLO algorithm is implemented. The model reaches high precision, converges easily, and does not require careful tuning of many parameters.

Complex-YOLO takes point cloud data as input and encodes the point cloud into a bird's-eye-view RGB-map in order to predict the position and yaw angle of objects in 3D space. To make training more efficient, the point cloud dataset is first converted into an RGB-map dataset. The experiments are based on the KITTI dataset, which contains 7481 labeled samples: the first 1000 samples are used as the test set and the remaining samples as the training set.

### Examples

Below are some prediction examples of Complex-YOLO; the predictions were made on the split test set. The IoU thresholds for Car and Cyclist are set to 0.5 and 0.3, respectively.

(Example detections: examples/1.png through examples/4.png; evaluation curves: examples/car_detection_ground.png and examples/cyclist_detection_ground.png.)
### Dependencies

* Python 3.x
* NumPy
* TensorFlow 1.x
* OpenCV

### How to use it

Clone this repo:

```bash
git clone https://github.com/wwooo/tensorflow_complex_yolo
```

```bash
cd tensorflow_complex_yolo
```

How to prepare data:

1. Download the data from the official KITTI website.

* [data_object_velodyne.zip](http://www.cvlibs.net/download.php?file=data_object_velodyne.zip)
* [data_object_label_2.zip](http://www.cvlibs.net/download.php?file=data_object_label_2.zip)
* [data_object_calib.zip](http://www.cvlibs.net/download.php?file=data_object_calib.zip)

2. Create the following folder structure in the current working directory:

```
tensorflow_complex_yolo
    kitti
        training
            calib
            label_2
            velodyne
```

3. Unzip the downloaded KITTI archives and place the files in the corresponding folders created above:

* data_object_velodyne/training/\*.bin -> velodyne
* data_object_label_2/training/label_2/\*.txt -> label_2
* data_object_calib/training/calib/\*.txt -> calib

Then create the RGB-image dataset:

```bash
python utils/make_image_dataset.py
```

This script converts the point cloud data into image data, saves it automatically in ./kitti/image_dataset/, and generates test_image_list.txt and train_image_list.txt in the ./config folder.

Note: the model only predicts the 60x80 area in front of the car and encodes the point cloud in this area into a 768x1024 RGB-map. In the KITTI dataset, not all samples have objects in this area, so while making the image dataset the script automatically filters out samples that have no objects in the area.

How to train a model:

```bash
python train.py
--load_weights
--weights_path
--batch_size
--num_iter
--save_dir
--save_interval
--gpu_id
```

All parameters have default values, so you can run the script directly. If you want to load model weights, you must provide --weights\_path and set --load\_weights=True (the default is False). --batch_size defaults to 8; adjust it according to the memory of your GPU card. --num_iter sets the number of iterations. --save_interval is the number of epochs between model saves, default 2. --save\_dir is where the model is saved, default ./weights/. --gpu_id specifies which card to use for training, default 0.

How to predict:

```bash
python predict.py --weights_path=./weights_path/... --draw_gt_box=True
```

When running predict.py, point cloud data is used directly as input to the model, and the script saves the predicted results in the predict\_result folder. You can set draw\_gt_box=True or False to decide whether the ground-truth boxes are drawn on the predicted results.

How to eval:

```bash
python utils/kitti_eval.py
```

This script saves the prediction results in a format consistent with the KITTI label format; they can then be evaluated with KITTI's official evaluation script, which is worth studying.

### Others

You can run utils/visualize_augumented_data.py to visualize the transformed data and labels; the results are saved in ./tmp.
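For example:

```bash
python utils/visualize_augumented_data.py
```

Related to the "How to eval" step above: the statistics files written by the KITTI devkit can be summarized with `cal_ap` from utils/kitti_eval.py. Below is a minimal, self-contained sketch of the same approximation. The file name follows the commented-out example in utils/kitti_eval.py and is an assumption about where the devkit writes its output, and the three precision columns are assumed to correspond to the easy/moderate/hard splits.

```python
import numpy as np


def approximate_ap(stats_file):
    # Mirrors cal_ap in utils/kitti_eval.py: each line of the devkit statistics
    # file holds a recall step followed by precision values; summing
    # precision * 0.025 over the recall steps (skipping the first line, as
    # cal_ap does) approximates the area under the precision-recall curve
    # for each difficulty split.
    lines = np.loadtxt(stats_file)
    return np.sum(0.025 * lines[1:, 1:4], axis=0)


# Hypothetical usage, assuming the devkit wrote car_detection_ground.txt:
easy, moderate, hard = approximate_ap('car_detection_ground.txt')
print('car AP (easy/moderate/hard): {:.2f} {:.2f} {:.2f}'.format(easy, moderate, hard))
```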
114 | 115 | ### ToDo 116 | 117 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import os 4 | import argparse 5 | import tensorflow as tf 6 | from utils.model_utils import make_dir 7 | from dataset.dataset import ImageDataSet 8 | from model.model import yolo_net, yolo_loss 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--load_weights", type=str, default='False', help="Whether to load weights, True or False") 13 | parser.add_argument("--batch_size", type=int, default=8, help="Set the batch_size") 14 | parser.add_argument("--weights_path", type=str, default="./weights/...", help="Set the weights_path") 15 | parser.add_argument("--save_dir", type=str, default="./weights/", help="Dir to save weights") 16 | parser.add_argument("--gpu_id", type=str, default='0', help="Specify GPU device") 17 | parser.add_argument("--num_iter", type=int, default=16000, help="num_max_iter") 18 | parser.add_argument("--save_interval", type=int, default=2, help="Save once every two epochs") 19 | args = parser.parse_args() 20 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 21 | 22 | SCALE = 32 23 | GRID_W, GRID_H = 32, 24 24 | N_CLASSES = 8 25 | N_ANCHORS = 5 26 | IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH = GRID_H * SCALE, GRID_W * SCALE, 3 27 | batch_size = args.batch_size 28 | 29 | train_dataset = ImageDataSet(data_set='train', 30 | mode='train', 31 | load_to_memory=False) 32 | test_dataset = ImageDataSet(data_set='test', 33 | mode='test', 34 | flip=False, 35 | aug_hsv=False, 36 | random_scale=False, 37 | load_to_memory=False) 38 | num_val_step = int(test_dataset.num_samples / args.batch_size) 39 | save_steps = int(train_dataset.num_samples / args.batch_size * args.save_interval) 40 | 41 | 42 | def print_info(): 43 | print("train samples: {}".format(train_dataset.num_samples)) 44 | print("test samples: {}".format(test_dataset.num_samples)) 45 | print("batch_size: {}".format(args.batch_size)) 46 | print("iter steps: {}".format(args.num_iter)) 47 | 48 | 49 | def train(load_weights='False'): 50 | make_dir(args.save_dir) 51 | max_val_loss = 99999999.0 52 | global_step = tf.Variable(0, trainable=False) 53 | learning_rate = tf.train.exponential_decay(0.001, 54 | global_step, 55 | 1500, 56 | 0.96, 57 | staircase=True) 58 | 59 | image = tf.placeholder( 60 | shape=[None, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH], 61 | dtype=tf.float32, 62 | name='image_placeholder') 63 | label = tf.placeholder(shape=[None, GRID_H, GRID_W, N_ANCHORS, 8], 64 | dtype=tf.float32, 65 | name='label_placeholder') 66 | train_flag = tf.placeholder(dtype=tf.bool, name='flag_placeholder') 67 | 68 | with tf.variable_scope('net'): 69 | y = yolo_net(image, train_flag) 70 | with tf.name_scope('loss'): 71 | loss, loss_xy, loss_wh, loss_re, loss_im, loss_obj, loss_no_obj, loss_c = yolo_loss( 72 | y, label, batch_size) 73 | 74 | loss_xy_sum = tf.summary.scalar("loss_xy_sum", loss_xy) 75 | loss_wh_sum = tf.summary.scalar("loss_wh_sum", loss_wh) 76 | loss_re_sum = tf.summary.scalar("loss_re_sum", loss_re) 77 | loss_im_sum = tf.summary.scalar("loss_im_sum", loss_im) 78 | loss_obj_sum = tf.summary.scalar("loss_obj_sum", loss_obj) 79 | loss_no_obj_sum = tf.summary.scalar("loss_no_obj_sum", loss_no_obj) 80 | loss_c_sum = tf.summary.scalar("loss_c", loss_c) 81 | loss_sum = tf.summary.scalar("loss", loss) 82 | loss_tensorboard_sum = 
tf.summary.merge([ 83 | loss_xy_sum, loss_wh_sum, loss_re_sum, loss_im_sum, 84 | loss_obj_sum, loss_no_obj_sum, loss_c_sum, loss_sum 85 | ]) 86 | opt = tf.train.AdamOptimizer(learning_rate=learning_rate) 87 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 88 | with tf.control_dependencies(update_ops): 89 | train_step = opt.minimize(loss, global_step=global_step) 90 | sess = tf.Session() 91 | sess.run(tf.global_variables_initializer()) 92 | saver = tf.train.Saver() 93 | writer = tf.summary.FileWriter("./logs", sess.graph) 94 | 95 | if load_weights == 'True': 96 | print("load weights from {}".format(args.weights_path)) 97 | saver = tf.train.import_meta_graph(args.weights_path + '.meta') 98 | saver.restore(sess, args.weights_path) 99 | print('load weights done!') 100 | 101 | for step, (train_image_data, train_label_data) in enumerate( 102 | train_dataset.get_batch(batch_size)): 103 | _, lr, train_loss, data, summary_str = sess.run( 104 | [train_step, learning_rate, loss, y, loss_tensorboard_sum], 105 | feed_dict={ 106 | train_flag: True, 107 | image: train_image_data, 108 | label: train_label_data 109 | }) 110 | writer.add_summary(summary_str, step) 111 | 112 | if step % 10 == 0: 113 | print('iter: %i, loss: %f, lr: %f' % (step, train_loss, lr)) 114 | if (step + 1) % save_steps == 0: 115 | print("val...") 116 | val_loss = 0.0 117 | for val_step, (val_image_data, val_label_data) in enumerate( 118 | test_dataset.get_batch(batch_size)): 119 | val_loss += sess.run(loss, 120 | feed_dict={ 121 | train_flag: False, 122 | image: val_image_data, 123 | label: val_label_data 124 | }) 125 | if val_step + 1 == num_val_step: 126 | break 127 | val_loss /= num_val_step 128 | print("iter: {} val_loss: {:.2f}".format(step, val_loss)) 129 | if val_loss < max_val_loss: 130 | saver.save(sess, 131 | os.path.join( 132 | args.save_dir, 133 | 'Complex_YOLO_train_loss_{:.2f}_val_loss_{:.2f}'.format( 134 | train_loss, val_loss)), 135 | global_step=global_step) 136 | max_val_loss = val_loss 137 | if step + 1 == args.num_iter: 138 | saver.save(sess, 139 | os.path.join( 140 | args.save_dir, 141 | 'Complex_YOLO_final_train_loss_{:.2f}'.format( 142 | train_loss)), 143 | global_step=global_step) 144 | print("training done!") 145 | break 146 | 147 | 148 | if __name__ == "__main__": 149 | print_info() 150 | train(load_weights=args.load_weights) 151 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import numpy as np 4 | import cv2 5 | import os 6 | from utils.kitti_utils import read_anchors_from_file, read_label_from_txt, \ 7 | load_kitti_calib, get_target, remove_points, make_bv_feature 8 | from dataset.augument import RandomScaleAugmentation 9 | from model.model import encode_label 10 | img_h, img_w = 768, 1024 11 | grid_h, grid_w = 24, 32 12 | iou_th = 0.5 13 | boundary = { 14 | 'minX': 0, 15 | 'maxX': 80, 16 | 'minY': -40, 17 | 'maxY': 40, 18 | 'minZ': -2, 19 | 'maxZ': 1.25 20 | } 21 | 22 | 23 | class PointCloudDataset(object): 24 | def __init__(self, 25 | root='./kitti/', 26 | data_set='train'): 27 | self.root = root 28 | self.data_path = os.path.join(root, 'training') 29 | self.lidar_path = os.path.join(self.data_path, "velodyne") 30 | self.calib_path = os.path.join(self.data_path, "calib") 31 | self.label_path = os.path.join(self.data_path, "label_2") 32 | self.index_list = [str(i) for i in range(1000)] if 
data_set == "test" \ 33 | else [str(i) for i in range(1000, 7481)] 34 | 35 | def getitem(self): 36 | """ 37 | Encode single-frame point cloud data into RGB-map and get the label 38 | """ 39 | for index in self.index_list: 40 | index = index.zfill(6) 41 | lidar_file = self.lidar_path + '/' + index + '.bin' 42 | calib_file = self.calib_path + '/' + index + '.txt' 43 | label_file = self.label_path + '/' + index + '.txt' 44 | calib = load_kitti_calib(calib_file) 45 | target = get_target(label_file, calib['Tr_velo2cam']) 46 | # load point cloud data 47 | point_cloud = np.fromfile(lidar_file, 48 | dtype=np.float32).reshape(-1, 4) 49 | b = remove_points(point_cloud, boundary) 50 | rgb_map = make_bv_feature(b) # (768, 1024, 3) 51 | 52 | yield index, rgb_map, target 53 | 54 | 55 | class ImageDataSet(object): 56 | """ 57 | If there is enough memory, set load_to_memory=True, 58 | load the data into memory to improve training efficiency. 59 | """ 60 | def __init__(self, 61 | data_set='train', 62 | mode='train', 63 | flip=True, 64 | random_scale=True, 65 | aug_hsv=False, 66 | load_to_memory=False): 67 | self.mode = mode 68 | self.flip = flip 69 | self.aug_hsv = aug_hsv 70 | self.random_scale = random_scale 71 | self.anchors_path = 'config/kitti_anchors.txt' 72 | self.labels_dir = 'kitti/image_dataset/labels/' 73 | self.images_dir = 'kitti/image_dataset/images/' 74 | self.all_image_index = 'config/' + data_set + '_image_list.txt' 75 | self.load_to_memory = load_to_memory 76 | self.anchors = read_anchors_from_file(self.anchors_path) 77 | self.rand_scale_transform = RandomScaleAugmentation(img_h, img_w) 78 | self.label = None 79 | self.img = None 80 | self.img_index = None 81 | self.label_encoded = None 82 | self.index_list = self.read_index_list() 83 | self.num_samples = len(self.index_list) 84 | 85 | def read_index_list(self): 86 | with open(self.all_image_index, 'r') as f: 87 | index_list = f.readlines() 88 | return index_list 89 | 90 | def horizontal_flip(self, image, target): # target: class,x,y,w,l,angle 91 | image = np.flip(image, 1) # image = image[:, ::-1, :] 92 | image_w = image.shape[1] 93 | target[:, 1] = image_w - target[:, 1] 94 | target[:, 5] = -target[:, 5] 95 | return image, target 96 | 97 | def augment_hsv(self, img): 98 | fraction = 0.30 # must be < 1.0 99 | img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) # hue, sat, val 100 | s = img_hsv[:, :, 1].astype(np.float32) # saturation 101 | v = img_hsv[:, :, 2].astype(np.float32) # value 102 | a = (np.random.random() * 2 - 1) * fraction + 1 103 | b = (np.random.random() * 2 - 1) * fraction + 1 104 | s *= a 105 | v *= b 106 | img_hsv[:, :, 1] = s if a < 1 else s.clip(None, 255) 107 | img_hsv[:, :, 2] = v if b < 1 else v.clip(None, 255) 108 | img = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB) 109 | return img 110 | 111 | def label_box_center_to_corner(self, label): 112 | """ 113 | param label: class, cx, cy, w, l, angle 114 | return: class, x_min, y_min, x_max, y_max, angle 115 | """ 116 | label_ = np.copy(label) 117 | cx = label_[:, 1] 118 | cy = label_[:, 2] 119 | w = label_[:, 3] 120 | l = label_[:, 4] 121 | label[:, 1] = cx - w / 2.0 122 | label[:, 2] = cy - l / 2.0 123 | label[:, 3] = cx + w / 2.0 124 | label[:, 4] = cy + l / 2.0 125 | return label 126 | 127 | def label_box_corner_to_center(self, label): 128 | """ 129 | param label: class, x_min, y_min, x_max, y_max, angle 130 | return: class, cx, cy, w, l, angle 131 | """ 132 | cx = (label[:, 1] + label[:, 3]) / 2.0 133 | cy = (label[:, 2] + label[:, 4]) / 2.0 134 | w = label[:, 3] - label[:, 
1] 135 | l = label[:, 4] - label[:, 2] 136 | label[:, 1] = cx 137 | label[:, 2] = cy 138 | label[:, 3] = w 139 | label[:, 4] = l 140 | return label 141 | 142 | def data_generator(self): 143 | 144 | if self.load_to_memory: 145 | all_images = [] 146 | all_labels = [] 147 | all_index = [] 148 | for index in self.index_list: 149 | index = index.strip() 150 | label_file = self.labels_dir + index + '.txt' 151 | label = read_label_from_txt(label_file) 152 | image_path = self.images_dir + index + '.png' 153 | img = cv2.imread(image_path) 154 | if img is None: 155 | print('failed to load image:' + image_path) 156 | continue 157 | img = np.flip(img, 2) 158 | all_index.append(index) 159 | all_images.append(img) 160 | all_labels.append(label) 161 | sample_index = [i for i in range(len(all_index))] 162 | while True: 163 | np.random.shuffle(sample_index) 164 | for i in sample_index: 165 | self.img_index = np.copy(all_index[i]) 166 | self.label = np.copy(all_labels[i]) 167 | self.img = np.copy(all_images[i]) 168 | if self.aug_hsv: 169 | if np.random.random() > 0.5: 170 | self.img = self.augment_hsv(self.img) 171 | if self.flip: 172 | if np.random.random() > 0.5: 173 | self.img, self.label = self.horizontal_flip( 174 | self.img, self.label) 175 | if self.random_scale: 176 | self.label = self.label_box_center_to_corner( 177 | self.label) 178 | self.img, self.label = self.rand_scale_transform( 179 | self.img, self.label) 180 | self.label = self.label_box_corner_to_center( 181 | self.label) 182 | self.label_encoded = encode_label( 183 | self.label, self.anchors, img_w, img_h, grid_w, 184 | grid_h, iou_th) 185 | if self.mode == 'visualize': # Generate data for visualization 186 | yield self.img_index, self.img, self.label 187 | else: 188 | yield self.img_index, self.img / 255.0, self.label_encoded # Generate data for net 189 | 190 | else: 191 | while True: 192 | np.random.shuffle(self.index_list) 193 | for index in self.index_list: 194 | self.img_index = index.strip() 195 | label_file = self.labels_dir + self.img_index + '.txt' 196 | self.label = read_label_from_txt(label_file) 197 | image_path = self.images_dir + self.img_index + '.png' 198 | self.img = cv2.imread(image_path) 199 | if self.img is None: 200 | print('failed to load image:' + image_path) 201 | continue 202 | self.img = np.flip(self.img, 2) 203 | 204 | if self.aug_hsv: 205 | if np.random.random() > 0.5: 206 | self.img = self.augment_hsv(self.img) 207 | if self.flip: 208 | if np.random.random() > 0.5: 209 | self.img, self.label = self.horizontal_flip( 210 | self.img, self.label) 211 | if self.random_scale: 212 | self.label = self.label_box_center_to_corner( 213 | self.label) 214 | self.img, self.label = self.rand_scale_transform( 215 | self.img, self.label) 216 | self.label = self.label_box_corner_to_center( 217 | self.label) 218 | self.label_encoded = encode_label( 219 | self.label, self.anchors, img_w, img_h, grid_w, 220 | grid_h, iou_th) 221 | if self.mode == 'visualize': # Generate data for visualization 222 | yield self.img_index, self.img, self.label 223 | else: 224 | yield self.img_index, self.img / 255.0, self.label_encoded # Generate data for net 225 | 226 | def get_batch(self, batch_size): 227 | """ 228 | Generate a batch of data for model training 229 | param batch_size (int): 230 | 231 | """ 232 | img_batch = [] 233 | label_batch = [] 234 | i = 0 235 | for img_idx, img, label_encoded in self.data_generator(): 236 | i += 1 237 | img_batch.append(img) 238 | label_batch.append(label_encoded) 239 | if i % batch_size == 0: 240 | yield 
np.array(img_batch), np.array(label_batch) 241 | i = 0 242 | img_batch = [] 243 | label_batch = [] 244 | 245 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | SCALE = 32 4 | GRID_W, GRID_H = 32, 24 5 | N_CLASSES = 8 6 | N_ANCHORS = 5 7 | IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH = GRID_H * SCALE, GRID_W * SCALE, 3 8 | class_dict = { 9 | 'Car': 0, 10 | 'Van': 1, 11 | 'Truck': 2, 12 | 'Pedestrian': 3, 13 | 'Person_sitting': 4, 14 | 'Cyclist': 5, 15 | 'Tram': 6, 16 | 'Misc': 7 17 | } 18 | 19 | 20 | def leak_relu(x, leak): 21 | return tf.maximum(x, leak * x, name='relu') 22 | 23 | 24 | def max_pool_layer(x, size, stride, name): 25 | with tf.name_scope(name): 26 | x = tf.layers.max_pooling2d(x, size, stride, padding='SAME') 27 | return x 28 | 29 | 30 | def conv_layer(x, kernel, depth, train_logical, name): 31 | with tf.variable_scope(name): 32 | x = tf.layers.conv2d( 33 | x, 34 | depth, 35 | kernel, 36 | padding='SAME', 37 | kernel_initializer=tf.contrib.layers.xavier_initializer_conv2d(), 38 | bias_initializer=tf.zeros_initializer()) 39 | x = tf.layers.batch_normalization(x, 40 | training=train_logical, 41 | momentum=0.9, 42 | epsilon=0.001, 43 | center=True, 44 | scale=True) 45 | x = leak_relu(x, 0.2) 46 | # x = tf.nn.relu(x) 47 | return x 48 | 49 | 50 | def passthrough_layer(a, b, kernel, depth, size, train_logical, name): 51 | b = conv_layer(b, kernel, depth, train_logical, name) 52 | b = tf.space_to_depth(b, size) 53 | y = tf.concat([a, b], axis=3) 54 | return y 55 | 56 | 57 | def slice_tensor(x, start, end=None): 58 | """ 59 | Get tensor slices 60 | param x (array): 61 | param start (int): 62 | param end (int): 63 | return (array): 64 | """ 65 | if end < 0: 66 | y = x[..., start:] 67 | else: 68 | if end is None: 69 | end = start 70 | y = x[..., start:end + 1] 71 | return y 72 | 73 | 74 | def iou_wh(box1_wh, box2_wh): 75 | """ 76 | param box1_wh (list, tuple): Width and height of a box 77 | param box2_wh (list, tuple): Width and height of a box 78 | return (float): iou 79 | """ 80 | min_w = min(box1_wh[0], box2_wh[0]) 81 | min_h = min(box1_wh[1], box2_wh[1]) 82 | area_r1 = box1_wh[0] * box1_wh[1] 83 | area_r2 = box2_wh[0] * box2_wh[1] 84 | intersect = min_w * min_h 85 | union = area_r1 + area_r2 - intersect 86 | return intersect / union 87 | 88 | 89 | def get_grid_cell(roi, img_w, img_h, grid_w, grid_h): # roi[x, y, w, h, rz] 90 | """ 91 | Get the grid cell into which the object falls 92 | param roi : [x, y, w, h, rz] 93 | param img_w: The width of images 94 | param img_h: The height of images 95 | param grid_w: 96 | param grid_h: 97 | return (int, int): 98 | """ 99 | x_center = roi[0] 100 | y_center = roi[1] 101 | grid_x = np.minimum(int(grid_w * x_center / img_w), grid_w-1) 102 | grid_y = np.minimum(int(grid_h * y_center / img_h), grid_h-1) 103 | return grid_x, grid_y 104 | 105 | 106 | def get_active_anchors(box_w_h, anchors, iou_th): 107 | """ 108 | Get the index of the anchor that matches the ground truth box 109 | param box_w_h (list, tuple): Width and height of a box 110 | param anchors (array): anchors 111 | param iou_th: Match threshold 112 | return (list): 113 | """ 114 | index = [] 115 | iou_max, index_max = 0, 0 116 | for i, a in enumerate(anchors): 117 | iou = iou_wh(box_w_h, a) 118 | if iou > iou_th: 119 | index.append(i) 120 | if iou > iou_max: 121 | iou_max, index_max = iou, i 122 | if 
len(index) == 0: 123 | index.append(index_max) 124 | return index 125 | 126 | 127 | def roi2label(roi, anchor, img_w, img_h, grid_w, grid_h): 128 | """ 129 | Encode the label to match the model output format 130 | param roi: x, y, w, h, angle 131 | 132 | return: encoded label 133 | """ 134 | x_center = roi[0] 135 | y_center = roi[1] 136 | w = grid_w * roi[2] / img_w 137 | h = grid_h * roi[3] / img_h 138 | anchor_w = grid_w * anchor[0] / img_w 139 | anchor_h = grid_h * anchor[1] / img_h 140 | grid_x = grid_w * x_center / img_w 141 | grid_y = grid_h * y_center / img_h 142 | grid_x_offset = grid_x - int(grid_x) 143 | grid_y_offset = grid_y - int(grid_y) 144 | roi_w_scale = np.log(w / anchor_w + 1e-16) 145 | roi_h_scale = np.log(h / anchor_h + 1e-16) 146 | re = np.cos(roi[4]) 147 | im = np.sin(roi[4]) 148 | label = [grid_x_offset, grid_y_offset, roi_w_scale, roi_h_scale, re, im] 149 | return label 150 | 151 | 152 | def encode_label(labels, anchors, img_w, img_h, grid_w, grid_h, iou_th): 153 | """ 154 | Encode the label to match the model output format 155 | param labels (array): x, y, w, h, angle 156 | param anchors (array): anchors 157 | return: encoded label 158 | """ 159 | anchors_on_image = np.array([img_w, img_h]) * anchors / np.array([80, 60]) 160 | n_anchors = np.shape(anchors_on_image)[0] 161 | label_encoded = np.zeros([grid_h, grid_w, n_anchors, (6 + 1 + 1)], 162 | dtype=np.float32) 163 | for i in range(labels.shape[0]): 164 | rois = labels[i][1:] 165 | classes = np.array(labels[i][0], dtype=np.int32) 166 | active_indexes = get_active_anchors(rois[2:4], anchors_on_image, iou_th) 167 | grid_x, grid_y = get_grid_cell(rois, img_w, img_h, grid_w, grid_h) 168 | for active_index in active_indexes: 169 | anchor_label = roi2label(rois, anchors_on_image[active_index], 170 | img_w, img_h, grid_w, grid_h) 171 | label_encoded[grid_y, grid_x, active_index] = np.concatenate( 172 | (anchor_label, [classes], [1.0])) 173 | return label_encoded 174 | 175 | 176 | def yolo_net(x, train_logical): 177 | """darknet""" 178 | x = conv_layer(x, (3, 3), 24, train_logical, 'conv1') 179 | x = max_pool_layer(x, (2, 2), (2, 2), 'maxpool1') 180 | x = conv_layer(x, (3, 3), 48, train_logical, 'conv2') 181 | x = max_pool_layer(x, (2, 2), (2, 2), 'maxpool2') 182 | 183 | x = conv_layer(x, (3, 3), 64, train_logical, 'conv3') 184 | x = conv_layer(x, (1, 1), 32, train_logical, 'conv4') 185 | x = conv_layer(x, (3, 3), 64, train_logical, 'conv5') 186 | x = max_pool_layer(x, (2, 2), (2, 2), 'maxpool5') 187 | 188 | x = conv_layer(x, (3, 3), 128, train_logical, 'conv6') 189 | x = conv_layer(x, (1, 1), 64, train_logical, 'conv7') 190 | x = conv_layer(x, (3, 3), 128, train_logical, 'conv8') 191 | x = max_pool_layer(x, (2, 2), (2, 2), 'maxpool8') 192 | 193 | # x = conv_layer(x, (3, 3), 512, train_logical, 'conv9') 194 | # x = conv_layer(x, (1, 1), 256, train_logical, 'conv10') 195 | x = conv_layer(x, (3, 3), 512, train_logical, 'conv11') 196 | x = conv_layer(x, (1, 1), 256, train_logical, 'conv12') 197 | passthrough = conv_layer(x, (3, 3), 512, train_logical, 'conv13') 198 | x = max_pool_layer(passthrough, (2, 2), (2, 2), 'maxpool13') 199 | 200 | # x = conv_layer(x, (3, 3), 1024, train_logical, 'conv14') 201 | # x = conv_layer(x, (1, 1), 512, train_logical, 'conv15') 202 | x = conv_layer(x, (3, 3), 1024, train_logical, 'conv16') 203 | x = conv_layer(x, (1, 1), 512, train_logical, 'conv17') 204 | x = conv_layer(x, (3, 3), 1024, train_logical, 'conv18') 205 | 206 | x = passthrough_layer(x, passthrough, (3, 3), 64, 2, train_logical, 
207 | 'conv21') 208 | x = conv_layer(x, (3, 3), 1024, train_logical, 'conv19') 209 | x = conv_layer(x, (1, 1), N_ANCHORS * (7 + N_CLASSES), train_logical, 210 | 'conv20') # x,y,w,l,re,im,conf + 8 class 211 | y = tf.reshape(x, 212 | shape=(-1, GRID_H, GRID_W, N_ANCHORS, 7 + N_CLASSES), 213 | name='y') 214 | return y 215 | 216 | 217 | def yolo_loss(pred, label, batch_size): 218 | mask = slice_tensor(label, 7, 7) 219 | label = slice_tensor(label, 0, 6) 220 | mask = tf.cast(tf.reshape(mask, shape=(-1, GRID_H, GRID_W, N_ANCHORS)), 221 | tf.bool) 222 | with tf.name_scope('mask'): 223 | masked_label = tf.boolean_mask(label, mask) 224 | masked_pred = tf.boolean_mask(pred, mask) 225 | neg_masked_pred = tf.boolean_mask(pred, tf.logical_not(mask)) 226 | with tf.name_scope('pred'): 227 | masked_pred_xy = tf.sigmoid(slice_tensor(masked_pred, 0, 1)) 228 | masked_pred_wh = slice_tensor(masked_pred, 2, 3) 229 | masked_pred_re = 2 * tf.sigmoid(slice_tensor(masked_pred, 4, 4)) - 1 230 | masked_pred_im = 2 * tf.sigmoid(slice_tensor(masked_pred, 5, 5)) - 1 231 | masked_pred_o = tf.sigmoid(slice_tensor(masked_pred, 6, 6)) 232 | 233 | masked_pred_no_o = tf.sigmoid(slice_tensor(neg_masked_pred, 6, 6)) 234 | # masked_pred_c = tf.nn.sigmoid(slice_tensor(masked_pred, 7, -1)) 235 | masked_pred_c = tf.nn.softmax(slice_tensor(masked_pred, 7, -1)) 236 | # masked_pred_no_c = tf.nn.sigmoid(slice_tensor(neg_masked_pred, 7, -1)) 237 | # print (masked_pred_c, masked_pred_o, masked_pred_no_o) 238 | 239 | with tf.name_scope('lab'): 240 | masked_label_xy = slice_tensor(masked_label, 0, 1) 241 | masked_label_wh = slice_tensor(masked_label, 2, 3) 242 | masked_label_re = slice_tensor(masked_label, 4, 4) 243 | masked_label_im = slice_tensor(masked_label, 5, 5) 244 | masked_label_class = slice_tensor(masked_label, 6, 6) 245 | masked_label_class_vec = tf.reshape(tf.one_hot(tf.cast( 246 | masked_label_class, tf.int32), 247 | depth=N_CLASSES), 248 | shape=(-1, N_CLASSES)) 249 | with tf.name_scope('merge'): 250 | with tf.name_scope('loss_xy'): 251 | loss_xy = tf.reduce_sum( 252 | tf.square(masked_pred_xy - masked_label_xy)) / batch_size 253 | with tf.name_scope('loss_wh'): 254 | loss_wh = tf.reduce_sum( 255 | tf.square(masked_pred_wh - masked_label_wh)) / batch_size 256 | with tf.name_scope('loss_re'): 257 | loss_re = tf.reduce_sum( 258 | tf.square(masked_pred_re - masked_label_re)) / batch_size 259 | with tf.name_scope('loss_im'): 260 | loss_im = tf.reduce_sum( 261 | tf.square(masked_pred_im - masked_label_im)) / batch_size 262 | with tf.name_scope('loss_obj'): 263 | loss_obj = tf.reduce_sum(tf.square(masked_pred_o - 1)) / batch_size 264 | # loss_obj = tf.reduce_sum(-tf.log(masked_pred_o+0.000001))*10 265 | with tf.name_scope('loss_no_obj'): 266 | loss_no_obj = tf.reduce_sum( 267 | tf.square(masked_pred_no_o)) * 0.5 / batch_size 268 | # loss_no_obj = tf.reduce_sum(-tf.log(1-masked_pred_no_o+0.000001)) 269 | with tf.name_scope('loss_class'): 270 | # loss_c = tf.reduce_sum(tf.square(masked_pred_c - masked_label_c_vec)) 271 | loss_c = (tf.reduce_sum(-tf.log(masked_pred_c + 0.000001) * masked_label_class_vec) 272 | + tf.reduce_sum(-tf.log(1 - masked_pred_c + 0.000001) * (1 - masked_label_class_vec))) / batch_size 273 | # + tf.reduce_sum(-tf.log(1 - masked_pred_no_c+0.000001)) * 0.1 274 | # loss = (loss_xy + loss_wh+ loss_re + loss_im+ lambda_coord*loss_obj) + lambda_no_obj*loss_no_obj + loss_c 275 | loss = (loss_xy + loss_wh + loss_re + 276 | loss_im) * 5 + loss_obj + loss_no_obj + loss_c 277 | return loss, loss_xy, loss_wh, loss_re, 
loss_im, loss_obj, loss_no_obj, loss_c 278 | -------------------------------------------------------------------------------- /utils/kitti_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import numpy as np 4 | import cv2 5 | 6 | # classes 7 | class_list = [ 8 | 'Car', 'Van', 'Truck', 'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram', 9 | 'Misc' 10 | ] 11 | class_dict = { 12 | 'Car': 0, 13 | 'Van': 1, 14 | 'Truck': 2, 15 | 'Pedestrian': 3, 16 | 'Person_sitting': 4, 17 | 'Cyclist': 5, 18 | 'Tram': 6, 19 | 'Misc': 7 20 | } 21 | 22 | 23 | def read_label_from_txt(label_path): 24 | """ 25 | Read label from txt file. 26 | label: class, cx, cy, w, l, rz(rotation angle) 27 | """ 28 | bounding_box = [] 29 | with open(label_path, "r") as f: 30 | labels = f.readlines() 31 | for label in labels: 32 | if not label: 33 | continue 34 | label = label.strip().split(' ') 35 | label[0] = class_dict[label[0]] 36 | bounding_box.append(label) 37 | if bounding_box: 38 | return np.array(bounding_box, dtype=np.float32) 39 | else: 40 | return None 41 | 42 | 43 | def read_anchors_from_file(file_path): 44 | """ 45 | Read anchors from the configuration file 46 | """ 47 | anchors = [] 48 | with open(file_path, 'r') as file: 49 | for line in file.read().splitlines(): 50 | anchors.append(list(map(float, line.split()))) 51 | return np.array(anchors) 52 | 53 | 54 | def read_class_flag(file_path): 55 | """ 56 | Read class flags for visualization 57 | """ 58 | classes, names, colors = [], [], [] 59 | with open(file_path, 'r') as file: 60 | lines = file.read().splitlines() 61 | for line in lines: 62 | cls, name, color = line.split() 63 | classes.append(int(cls)) 64 | names.append(name) 65 | colors.append(eval(color)) 66 | return classes, names, colors 67 | 68 | 69 | def calculate_angle(im, re): 70 | """ 71 | param: im(float): imaginary parts of the plural 72 | param: re(float): real parts of the plural 73 | return: The angle at which the objects rotate 74 | around the Z axis in the velodyne coordinate system 75 | """ 76 | if re > 0: 77 | return np.arctan(im / re) 78 | elif im < 0: 79 | return -np.pi + np.arctan(im / re) 80 | else: 81 | return np.pi + np.arctan(im / re) 82 | 83 | 84 | def draw_rotated_box(img, cy, cx, w, h, angle, color): 85 | """ 86 | param: img(array): RGB image 87 | param: cy(int, float): Here cy is cx in the image coordinate system 88 | param: cx(int, float): Here cx is cy in the image coordinate system 89 | param: w(int, float): box's width 90 | param: h(int, float): box's height 91 | param: angle(float): rz 92 | param: color(tuple, list): the color of box, (R, G, B) 93 | """ 94 | left = int(cy - w / 2) 95 | top = int(cx - h / 2) 96 | right = int(cx + h / 2) 97 | bottom = int(cy + h / 2) 98 | ro = np.sqrt(pow(left - cy, 2) + pow(top - cx, 2)) 99 | a1 = np.arctan((w / 2) / (h / 2)) 100 | a2 = -np.arctan((w / 2) / (h / 2)) 101 | a3 = -np.pi + a1 102 | a4 = np.pi - a1 103 | rotated_p1_y = cy + int(ro * np.sin(angle + a1)) 104 | rotated_p1_x = cx + int(ro * np.cos(angle + a1)) 105 | rotated_p2_y = cy + int(ro * np.sin(angle + a2)) 106 | rotated_p2_x = cx + int(ro * np.cos(angle + a2)) 107 | rotated_p3_y = cy + int(ro * np.sin(angle + a3)) 108 | rotated_p3_x = cx + int(ro * np.cos(angle + a3)) 109 | rotated_p4_y = cy + int(ro * np.sin(angle + a4)) 110 | rotated_p4_x = cx + int(ro * np.cos(angle + a4)) 111 | center_p1p2y = int((rotated_p1_y + rotated_p2_y) * 0.5) 112 | center_p1p2x = int((rotated_p1_x + 
rotated_p2_x) * 0.5) 113 | cv2.line(img, (rotated_p1_y, rotated_p1_x), (rotated_p2_y, rotated_p2_x), 114 | color, 1) 115 | cv2.line(img, (rotated_p2_y, rotated_p2_x), (rotated_p3_y, rotated_p3_x), 116 | color, 1) 117 | cv2.line(img, (rotated_p3_y, rotated_p3_x), (rotated_p4_y, rotated_p4_x), 118 | color, 1) 119 | cv2.line(img, (rotated_p4_y, rotated_p4_x), (rotated_p1_y, rotated_p1_x), 120 | color, 1) 121 | cv2.line(img, (center_p1p2y, center_p1p2x), (cy, cx), color, 1) 122 | 123 | 124 | def get_corner_gtbox(box): 125 | """ 126 | param: box(tuple, list): cx, cy, w, l 127 | return: (tuple): x_min, y_min, x_max, y_max 128 | """ 129 | bx = box[0] 130 | by = box[1] 131 | bw = box[2] 132 | bl = box[3] 133 | top = int((by - bl / 2.0)) 134 | left = int((bx - bw / 2.0)) 135 | right = int((bx + bw / 2.0)) 136 | bottom = int((by + bl / 2.0)) 137 | return left, top, right, bottom 138 | 139 | 140 | def remove_points(point_cloud, boundary_condition): 141 | """ 142 | param point_cloud(array): Original point cloud data 143 | param boundary_condition(dict): The boundary of the area of interest 144 | return (array): Point cloud data within the area of interest 145 | """ 146 | # Boundary condition 147 | min_x = boundary_condition['minX'] 148 | max_x = boundary_condition['maxX'] 149 | min_y = boundary_condition['minY'] 150 | max_y = boundary_condition['maxY'] 151 | min_z = boundary_condition['minZ'] 152 | max_z = boundary_condition['maxZ'] 153 | # Remove the point out of range x,y,z 154 | mask = np.where((point_cloud[:, 0] >= min_x) & (point_cloud[:, 0] <= max_x) 155 | & (point_cloud[:, 1] >= min_y) & (point_cloud[:, 1] <= max_y) 156 | & (point_cloud[:, 2] >= min_z) & (point_cloud[:, 2] <= max_z)) 157 | point_cloud = point_cloud[mask] 158 | point_cloud[:, 2] = point_cloud[:, 2] + 2 159 | return point_cloud 160 | 161 | 162 | def make_bv_feature(point_cloud_): 163 | """ 164 | param point_cloud_ (array): Point cloud data within the area of interest 165 | return (array): RGB map 166 | """ 167 | # 1024 x 1024 x 3 168 | Height = 1024 + 1 169 | Width = 1024 + 1 170 | # Discretize Feature Map 171 | point_cloud = np.copy(point_cloud_) 172 | point_cloud[:, 0] = np.int_(np.floor(point_cloud[:, 0] / 60.0 * 768)) 173 | point_cloud[:, 1] = np.int_( 174 | np.floor(point_cloud[:, 1] / 40.0 * 512) + Width / 2) 175 | # sort-3times 176 | indices = np.lexsort( 177 | (-point_cloud[:, 2], point_cloud[:, 1], point_cloud[:, 0])) 178 | point_cloud = point_cloud[indices] 179 | # Height Map 180 | height_map = np.zeros((Height, Width)) 181 | _, indices = np.unique(point_cloud[:, 0:2], axis=0, return_index=True) 182 | point_cloud_frac = point_cloud[indices] 183 | # some important problem is image coordinate is (y,x), not (x,y) 184 | height_map[np.int_(point_cloud_frac[:, 0]), 185 | np.int_(point_cloud_frac[:, 1])] = point_cloud_frac[:, 2] 186 | # Intensity Map & DensityMap 187 | intensity_map = np.zeros((Height, Width)) 188 | density_map = np.zeros((Height, Width)) 189 | _, indices, counts = np.unique(point_cloud[:, 0:2], 190 | axis=0, 191 | return_index=True, 192 | return_counts=True) 193 | point_cloud_top = point_cloud[indices] 194 | normalized_counts = np.minimum(1.0, np.log(counts + 1) / np.log(64)) 195 | intensity_map[np.int_(point_cloud_top[:, 0]), 196 | np.int_(point_cloud_top[:, 1])] = point_cloud_top[:, 3] 197 | density_map[np.int_(point_cloud_top[:, 0]), 198 | np.int_(point_cloud_top[:, 1])] = normalized_counts 199 | 200 | rgb_map = np.zeros((Height, Width, 3)) 201 | rgb_map[:, :, 0] = density_map # r_map 202 | rgb_map[:, :, 
1] = height_map / 3.26 # g_map 203 | rgb_map[:, :, 2] = intensity_map # b_map 204 | 205 | save = np.zeros((768, 1024, 3)) 206 | save = rgb_map[0:768, 0:1024, :] 207 | return save 208 | 209 | 210 | def get_target(label_file, transform): 211 | """ 212 | 213 | param label_file (str): The kitti label path 214 | param transform (array): Coordinate transformation matrix 215 | return (array): label 216 | """ 217 | target = np.zeros([50, 6], dtype=np.float32) 218 | with open(label_file, 'r') as f: 219 | lines = f.readlines() 220 | num_obj = len(lines) 221 | index = 0 222 | for j in range(num_obj): 223 | obj = lines[j].strip().split(' ') 224 | obj_class = obj[0].strip() 225 | if obj_class in class_list: 226 | t_lidar, box3d_corner, rz = box3d_cam_to_velo( 227 | obj[8:], transform) # get target 3D object location x,y 228 | location_x = t_lidar[0][0] 229 | location_y = t_lidar[0][1] 230 | if (location_x > 0) & (location_x < 60) & (location_y > -40) & ( 231 | location_y < 40): 232 | target[index][2] = t_lidar[0][0] / 60.0 # make sure target inside the covering area (0,1) 233 | target[index][1] = (t_lidar[0][1] + 40) / 80.0 # we should put this in [0,1], so divide max_size 80 m 234 | obj_width = obj[9].strip() 235 | obj_length = obj[10].strip() 236 | target[index][3] = float(obj_width) / 80.0 237 | target[index][4] = float(obj_length) / 60.0 # get target width ,length 238 | target[index][5] = rz 239 | for i in range(len(class_list)): 240 | if obj_class == class_list[i]: # get target class 241 | target[index][0] = i 242 | index = index + 1 243 | return target 244 | 245 | 246 | def box3d_cam_to_velo(box3d, transform): 247 | def project_cam2velo(cam, transform): 248 | T = np.zeros([4, 4], dtype=np.float32) 249 | T[:3, :] = transform 250 | T[3, 3] = 1 251 | T_inv = np.linalg.inv(T) 252 | lidar_loc_ = np.dot(T_inv, cam) 253 | lidar_loc = lidar_loc_[:3] 254 | return lidar_loc.reshape(1, 3) 255 | 256 | def ry_to_rz(ry): 257 | """ 258 | param ry (float): yaw angle in cam coordinate system 259 | return: (flaot): yaw angle in velodyne coordinate system 260 | """ 261 | angle = -ry - np.pi / 2 262 | if angle >= np.pi: 263 | angle -= np.pi 264 | if angle < -np.pi: 265 | angle = 2 * np.pi + angle 266 | return angle 267 | 268 | h, w, l, tx, ty, tz, ry = [float(i) for i in box3d] 269 | cam = np.ones([4, 1]) 270 | cam[0] = tx 271 | cam[1] = ty 272 | cam[2] = tz 273 | t_lidar = project_cam2velo(cam, transform) 274 | box = np.array( 275 | [[-l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2], 276 | [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2], 277 | [0, 0, 0, 0, h, h, h, h]]) 278 | 279 | rz = ry_to_rz(ry) 280 | rot_mat = np.array([[np.cos(rz), -np.sin(rz), 0.0], 281 | [np.sin(rz), np.cos(rz), 0.0], [0.0, 0.0, 1.0]]) 282 | velo_box = np.dot(rot_mat, box) 283 | corner_pos_in_velo = velo_box + np.tile(t_lidar, (8, 1)).T 284 | box3d_corner = corner_pos_in_velo.transpose() 285 | return t_lidar, box3d_corner.astype(np.float32), rz 286 | 287 | 288 | def load_kitti_calib(calib_file): 289 | """ 290 | load projection matrix 291 | """ 292 | with open(calib_file) as fi: 293 | lines = fi.readlines() 294 | assert (len(lines) == 8) 295 | 296 | obj = lines[0].strip().split(' ')[1:] 297 | P0 = np.array(obj, dtype=np.float32) 298 | obj = lines[1].strip().split(' ')[1:] 299 | P1 = np.array(obj, dtype=np.float32) 300 | obj = lines[2].strip().split(' ')[1:] 301 | P2 = np.array(obj, dtype=np.float32) 302 | obj = lines[3].strip().split(' ')[1:] 303 | P3 = np.array(obj, dtype=np.float32) 304 | obj = 
lines[4].strip().split(' ')[1:] 305 | R0 = np.array(obj, dtype=np.float32) 306 | obj = lines[5].strip().split(' ')[1:] 307 | Tr_velo_to_cam = np.array(obj, dtype=np.float32) 308 | obj = lines[6].strip().split(' ')[1:] 309 | Tr_imu_to_velo = np.array(obj, dtype=np.float32) 310 | 311 | return { 312 | 'P2': P2.reshape(3, 4), 313 | 'R0': R0.reshape(3, 3), 314 | 'Tr_velo2cam': Tr_velo_to_cam.reshape(3, 4) 315 | } 316 | 317 | 318 | def angle_rz_to_ry(rz): 319 | """ 320 | param rz (float): yaw angle in velodyne coordinate system 321 | return (float): yaw angle in cam coordinate system 322 | """ 323 | angle = -rz - np.pi / 2 324 | if angle < -np.pi: 325 | angle = 2 * np.pi + angle 326 | return angle 327 | 328 | 329 | def coord_image_to_velo(hy, wx): 330 | """ 331 | Convert image coordinates to velodyne coordinates 332 | """ 333 | velo_x = hy * 60 / 768.0 334 | velo_y = (wx - 512) * 40 / 512.0 335 | return velo_x, velo_y 336 | 337 | 338 | def coord_velo_to_cam(velo_x, velo_y, transform): 339 | """ 340 | Convert velodyne coordinates to image coordinates 341 | """ 342 | T = np.zeros([4, 4], dtype=np.float32) 343 | T[:3, :] = transform 344 | T[3, 3] = 1 345 | velo_coord = np.array([[velo_x], [velo_y], [1.5], [1]]) 346 | cam_coord = np.dot(T, velo_coord) 347 | return cam_coord[0][0], cam_coord[2][0] 348 | 349 | 350 | # anchors = [1.08, 1.19, 3.42, 4.41, 6.63, 11.38, 9.42, 5.11, 16.62, 10.52] 351 | -------------------------------------------------------------------------------- /dataset/augument.py: -------------------------------------------------------------------------------- 1 | """ 2 | The data augmentation operations of the original SSD implementation. 3 | Copyright (C) 2018 Pierluigi Ferrari 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | Unless required by applicable law or agreed to in writing, software 9 | distributed under the License is distributed on an "AS IS" BASIS, 10 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | See the License for the specific language governing permissions and 12 | limitations under the License. 13 | """ 14 | 15 | import cv2 16 | import inspect 17 | import numpy as np 18 | 19 | 20 | class RandomScaleAugmentation: 21 | """ 22 | Reproduces the data augmentation pipeline used in the training of the original 23 | Caffe implementation of SSD. 24 | """ 25 | 26 | def __init__(self, 27 | img_height=768, 28 | img_width=1024, 29 | labels_format={ 30 | 'class_id': 0, 31 | 'xmin': 1, 32 | 'ymin': 2, 33 | 'xmax': 3, 34 | 'ymax': 4 35 | }): 36 | """ 37 | Arguments: 38 | height (int): The desired height of the output images in pixels. 39 | width (int): The desired width of the output images in pixels. 40 | background (list/tuple, optional): A 3-tuple specifying the RGB color value of the 41 | background pixels of the translated images. 42 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels 43 | of an image contains which bounding box coordinate. The dictionary maps at least the keywords 44 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array. 45 | """ 46 | 47 | self.labels_format = labels_format 48 | self.random_crop = SSDRandomCrop(labels_format=self.labels_format) 49 | # This box filter makes sure that the resized images don't contain any degenerate boxes. 
50 | # Resizing the images could lead the boxes to becomes smaller. For boxes that are already 51 | # pretty small, that might result in boxes with height and/or width zero, which we obviously 52 | # cannot allow. 53 | self.box_filter = BoxFilter(check_overlap=False, 54 | check_min_area=False, 55 | check_degenerate=True, 56 | labels_format=self.labels_format) 57 | 58 | self.resize = ResizeRandomInterp(height=img_height, 59 | width=img_width, 60 | interpolation_modes=[ 61 | cv2.INTER_NEAREST, 62 | cv2.INTER_LINEAR, cv2.INTER_CUBIC, 63 | cv2.INTER_AREA, cv2.INTER_LANCZOS4 64 | ], 65 | box_filter=self.box_filter, 66 | labels_format=self.labels_format) 67 | 68 | self.sequence = [self.random_crop, self.resize] 69 | 70 | def __call__(self, image, labels, return_inverter=False): 71 | self.random_crop.labels_format = self.labels_format 72 | self.resize.labels_format = self.labels_format 73 | inverters = [] 74 | for transform in self.sequence: 75 | if return_inverter and ('return_inverter' in inspect.signature( 76 | transform).parameters): 77 | image, labels, inverter = transform(image, 78 | labels, 79 | return_inverter=True) 80 | inverters.append(inverter) 81 | else: 82 | image, labels = transform(image, labels) 83 | 84 | if return_inverter: 85 | return image, labels, inverters[::-1] 86 | else: 87 | return image, labels 88 | 89 | 90 | class SSDRandomCrop: 91 | """ 92 | Performs the same random crops as defined by the `batch_sampler` instructions 93 | of the original Caffe implementation of SSD. A description of this random cropping 94 | strategy can also be found in the data augmentation section of the paper: 95 | https://arxiv.org/abs/1512.02325 96 | """ 97 | 98 | def __init__(self, 99 | labels_format={ 100 | 'class_id': 0, 101 | 'xmin': 1, 102 | 'ymin': 2, 103 | 'xmax': 3, 104 | 'ymax': 4 105 | }): 106 | """ 107 | Arguments: 108 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels 109 | of an image contains which bounding box coordinate. The dictionary maps at least the keywords 110 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array. 111 | """ 112 | 113 | self.labels_format = labels_format 114 | 115 | # This randomly samples one of the lower IoU bounds defined 116 | # by the `sample_space` every time it is called. 117 | self.bound_generator = BoundGenerator( 118 | sample_space=((None, None), (0.1, None), (0.3, None), (0.5, None), 119 | (0.7, None), (0.9, None)), 120 | weights=None) 121 | 122 | # Produces coordinates for candidate patches such that the height 123 | # and width of the patches are between 0.3 and 1.0 of the height 124 | # and width of the respective image and the aspect ratio of the 125 | # patches is between 0.5 and 2.0. 126 | self.patch_coord_generator = PatchCoordinateGenerator( 127 | must_match='h_w', 128 | min_scale=0.6, 129 | max_scale=1.0, 130 | scale_uniformly=False, 131 | min_aspect_ratio=0.5, 132 | max_aspect_ratio=2.0) 133 | 134 | # Filters out boxes whose center point does not lie within the 135 | # chosen patches. 136 | self.box_filter = BoxFilter(check_overlap=True, 137 | check_min_area=False, 138 | check_degenerate=False, 139 | overlap_criterion='center_point', 140 | labels_format=self.labels_format) 141 | 142 | # Determines whether a given patch is considered a valid patch. 
143 | # Defines a patch to be valid if at least one ground truth bounding box 144 | # (n_boxes_min == 1) has an IoU overlap with the patch that 145 | # meets the requirements defined by `bound_generator`. 146 | self.image_validator = ImageValidator(overlap_criterion='iou', 147 | n_boxes_min=1, 148 | labels_format=self.labels_format, 149 | border_pixels='half') 150 | 151 | # Performs crops according to the parameters set in the objects above. 152 | # Runs until either a valid patch is found or the original input image 153 | # is returned unaltered. Runs a maximum of 50 trials to find a valid 154 | # patch for each new sampled IoU threshold. Every 50 trials, the original 155 | # image is returned as is with probability (1 - prob) = 0.143. 156 | self.random_crop = RandomPatchInf( 157 | patch_coord_generator=self.patch_coord_generator, 158 | box_filter=self.box_filter, 159 | image_validator=self.image_validator, 160 | bound_generator=self.bound_generator, 161 | n_trials_max=50, 162 | clip_boxes=True, 163 | prob=0.857, 164 | labels_format=self.labels_format) 165 | 166 | def __call__(self, image, labels=None, return_inverter=False): 167 | self.random_crop.labels_format = self.labels_format 168 | return self.random_crop(image, labels, return_inverter) 169 | 170 | 171 | class BoundGenerator: 172 | """ 173 | Generates pairs of floating point values that represent lower and upper bounds 174 | from a given sample space. 175 | """ 176 | 177 | def __init__(self, 178 | sample_space=((0.1, None), (0.3, None), (0.5, None), 179 | (0.7, None), (0.9, None), (None, None)), 180 | weights=None): 181 | """ 182 | Arguments: 183 | sample_space (list or tuple): A list, tuple, or array-like object of shape 184 | `(n, 2)` that contains `n` samples to choose from, where each sample 185 | is a 2-tuple of scalars and/or `None` values. 186 | weights (list or tuple, optional): A list or tuple representing the distribution 187 | over the sample space. If `None`, a uniform distribution will be assumed. 188 | """ 189 | 190 | if (not (weights is None)) and len(weights) != len(sample_space): 191 | raise ValueError( 192 | "`weights` must either be `None` for uniform distribution or have the same length as `sample_space`." 193 | ) 194 | 195 | self.sample_space = [] 196 | for bound_pair in sample_space: 197 | if len(bound_pair) != 2: 198 | raise ValueError( 199 | "All elements of the sample space must be 2-tuples.") 200 | bound_pair = list(bound_pair) 201 | if bound_pair[0] is None: bound_pair[0] = 0.0 202 | if bound_pair[1] is None: bound_pair[1] = 1.0 203 | if bound_pair[0] > bound_pair[1]: 204 | raise ValueError( 205 | "For all sample space elements, the lower bound " 206 | "cannot be greater than the upper bound.") 207 | self.sample_space.append(bound_pair) 208 | 209 | self.sample_space_size = len(self.sample_space) 210 | 211 | if weights is None: 212 | self.weights = [1.0 / self.sample_space_size 213 | ] * self.sample_space_size 214 | else: 215 | self.weights = weights 216 | 217 | def __call__(self): 218 | """ 219 | Returns: 220 | An item of the sample space, i.e. a 2-tuple of scalars. 221 | """ 222 | i = np.random.choice(self.sample_space_size, p=self.weights) 223 | return self.sample_space[i] 224 | 225 | 226 | class BoxFilter: 227 | """ 228 | Returns all bounding boxes that are valid with respect to a the defined criteria. 
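Illustrative usage (added here; not part of the original code): a filter that only removes degenerate boxes can be built as `BoxFilter(check_overlap=False, check_min_area=False, check_degenerate=True)` and then applied as `box_filter(labels, image_height=768, image_width=1024)`; it returns only the rows of `labels` that pass the enabled checks (the image size is only consulted when `check_overlap=True`).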
229 | """ 230 | 231 | def __init__(self, 232 | check_overlap=True, 233 | check_min_area=True, 234 | check_degenerate=True, 235 | overlap_criterion='center_point', 236 | overlap_bounds=(0.3, 1.0), 237 | min_area=10, 238 | labels_format={ 239 | 'class_id': 0, 240 | 'xmin': 1, 241 | 'ymin': 2, 242 | 'xmax': 3, 243 | 'ymax': 4 244 | }, 245 | border_pixels='half'): 246 | """ 247 | Arguments: 248 | check_overlap (bool, optional): Whether or not to enforce the overlap requirements defined by 249 | `overlap_criterion` and `overlap_bounds`. Sometimes you might want to use the box filter only 250 | to enforce a certain minimum area for all boxes (see next argument), in such cases you can 251 | turn the overlap requirements off. 252 | check_min_area (bool, optional): Whether or not to enforce the minimum area requirement defined 253 | by `min_area`. If `True`, any boxes that have an area (in pixels) that is smaller than `min_area` 254 | will be removed from the labels of an image. Bounding boxes below a certain area aren't useful 255 | training examples. An object that takes up only, say, 5 pixels in an image is probably not 256 | recognizable anymore, neither for a human, nor for an object detection model. It makes sense 257 | to remove such boxes. 258 | check_degenerate (bool, optional): Whether or not to check for and remove degenerate bounding boxes. 259 | Degenerate bounding boxes are boxes that have `xmax <= xmin` and/or `ymax <= ymin`. In particular, 260 | boxes with a width and/or height of zero are degenerate. It is obviously important to filter out 261 | such boxes, so you should only set this option to `False` if you are certain that degenerate 262 | boxes are not possible in your data and processing chain. 263 | overlap_criterion (str, optional): Can be either of 'center_point', 'iou', or 'area'. Determines 264 | which boxes are considered valid with respect to a given image. If set to 'center_point', 265 | a given bounding box is considered valid if its center point lies within the image. 266 | If set to 'area', a given bounding box is considered valid if the quotient of its intersection 267 | area with the image and its own area is within the given `overlap_bounds`. If set to 'iou', a given 268 | bounding box is considered valid if its IoU with the image is within the given `overlap_bounds`. 269 | overlap_bounds (list or BoundGenerator, optional): Only relevant if `overlap_criterion` is 'area' or 'iou'. 270 | Determines the lower and upper bounds for `overlap_criterion`. Can be either a 2-tuple of scalars 271 | representing a lower bound and an upper bound, or a `BoundGenerator` object, which provides 272 | the possibility to generate bounds randomly. 273 | min_area (int, optional): Only relevant if `check_min_area` is `True`. Defines the minimum area in 274 | pixels that a bounding box must have in order to be valid. Boxes with an area smaller than this 275 | will be removed. 276 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels 277 | of an image contains which bounding box coordinate. The dictionary maps at least the keywords 278 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array. 279 | border_pixels (str, optional): How to treat the border pixels of the bounding boxes. 280 | Can be 'include', 'exclude', or 'half'. If 'include', the border pixels belong 281 | to the boxes. If 'exclude', the border pixels do not belong to the boxes. 
282 | If 'half', then one of each of the two horizontal and vertical borders belong 283 | to the boxex, but not the other. 284 | """ 285 | if not isinstance(overlap_bounds, (list, tuple, BoundGenerator)): 286 | raise ValueError( 287 | "`overlap_bounds` must be either a 2-tuple of scalars or a `BoundGenerator` object." 288 | ) 289 | if isinstance( 290 | overlap_bounds, 291 | (list, tuple)) and (overlap_bounds[0] > overlap_bounds[1]): 292 | raise ValueError( 293 | "The lower bound must not be greater than the upper bound.") 294 | if not (overlap_criterion in {'iou', 'area', 'center_point'}): 295 | raise ValueError( 296 | "`overlap_criterion` must be one of 'iou', 'area', or 'center_point'." 297 | ) 298 | self.overlap_criterion = overlap_criterion 299 | self.overlap_bounds = overlap_bounds 300 | self.min_area = min_area 301 | self.check_overlap = check_overlap 302 | self.check_min_area = check_min_area 303 | self.check_degenerate = check_degenerate 304 | self.labels_format = labels_format 305 | self.border_pixels = border_pixels 306 | 307 | def __call__(self, labels, image_height=None, image_width=None): 308 | """ 309 | Arguments: 310 | labels (array): The labels to be filtered. This is an array with shape `(m,n)`, where 311 | `m` is the number of bounding boxes and `n` is the number of elements that defines 312 | each bounding box (box coordinates, class ID, etc.). The box coordinates are expected 313 | to be in the image's coordinate system. 314 | image_height (int): Only relevant if `check_overlap == True`. The height of the image 315 | (in pixels) to compare the box coordinates to. 316 | image_width (int): `check_overlap == True`. The width of the image (in pixels) to compare 317 | the box coordinates to. 318 | 319 | Returns: 320 | An array containing the labels of all boxes that are valid. 321 | """ 322 | 323 | labels = np.copy(labels) 324 | 325 | xmin = self.labels_format['xmin'] 326 | ymin = self.labels_format['ymin'] 327 | xmax = self.labels_format['xmax'] 328 | ymax = self.labels_format['ymax'] 329 | 330 | # Record the boxes that pass all checks here. 331 | requirements_met = np.ones(shape=labels.shape[0], dtype=np.bool) 332 | 333 | if self.check_degenerate: 334 | 335 | non_degenerate = (labels[:, xmax] > labels[:, xmin]) * ( 336 | labels[:, ymax] > labels[:, ymin]) 337 | requirements_met *= non_degenerate 338 | 339 | if self.check_min_area: 340 | 341 | min_area_met = (labels[:, xmax] - labels[:, xmin]) * ( 342 | labels[:, ymax] - labels[:, ymin]) >= self.min_area 343 | requirements_met *= min_area_met 344 | 345 | if self.check_overlap: 346 | 347 | # Get the lower and upper bounds. 348 | if isinstance(self.overlap_bounds, BoundGenerator): 349 | lower, upper = self.overlap_bounds() 350 | else: 351 | lower, upper = self.overlap_bounds 352 | 353 | # Compute which boxes are valid. 354 | 355 | if self.overlap_criterion == 'iou': 356 | # Compute the patch coordinates. 357 | image_coords = np.array([0, 0, image_width, image_height]) 358 | # Compute the IoU between the patch and all of the ground truth boxes. 
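# Note (added for clarity): the `iou` helper called below is not defined in the code shown here;
# it is assumed to come from the accompanying SSD bounding-box utilities, where `mode='element-wise'`
# yields one IoU value per ground truth box, each measured against the single `image_coords` rectangle.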
359 | image_boxes_iou = iou(image_coords, 360 | labels[:, [xmin, ymin, xmax, ymax]], 361 | coords='corners', 362 | mode='element-wise', 363 | border_pixels=self.border_pixels) 364 | requirements_met *= (image_boxes_iou > 365 | lower) * (image_boxes_iou <= upper) 366 | 367 | elif self.overlap_criterion == 'area': 368 | if self.border_pixels == 'half': 369 | d = 0 370 | elif self.border_pixels == 'include': 371 | d = 1 # If border pixels are supposed to belong to the bounding boxes, 372 | # we have to add one pixel to any difference `xmax - xmin` or `ymax - ymin`. 373 | elif self.border_pixels == 'exclude': 374 | d = -1 # If border pixels are not supposed to belong to the bounding boxes, 375 | # we have to subtract one pixel from any difference `xmax - xmin` or `ymax - ymin`. 376 | # Compute the areas of the boxes. 377 | box_areas = (labels[:, xmax] - labels[:, xmin] + 378 | d) * (labels[:, ymax] - labels[:, ymin] + d) 379 | # Compute the intersection area between the patch and all of the ground truth boxes. 380 | clipped_boxes = np.copy(labels) 381 | clipped_boxes[:, [ymin, ymax]] = np.clip( 382 | labels[:, [ymin, ymax]], a_min=0, a_max=image_height - 1) 383 | clipped_boxes[:, [xmin, xmax]] = np.clip( 384 | labels[:, [xmin, xmax]], a_min=0, a_max=image_width - 1) 385 | intersection_areas = ( 386 | clipped_boxes[:, xmax] - clipped_boxes[:, xmin] + d) * ( 387 | clipped_boxes[:, ymax] - clipped_boxes[:, ymin] + d 388 | ) # +1 because the border pixels belong to the box areas. 389 | # Check which boxes meet the overlap requirements. 390 | if lower == 0.0: 391 | mask_lower = intersection_areas > lower * box_areas # If `self.lower == 0`, we want to 392 | # make sure that boxes with area 0 don't count, hence the ">" sign instead of the ">=" sign. 393 | else: 394 | mask_lower = intersection_areas >= lower * box_areas # Especially for the case `self.lower == 1` 395 | # we want the ">=" sign, otherwise no boxes would count at all. 396 | mask_upper = intersection_areas <= upper * box_areas 397 | requirements_met *= mask_lower * mask_upper 398 | 399 | elif self.overlap_criterion == 'center_point': 400 | # Compute the center points of the boxes. 401 | cy = (labels[:, ymin] + labels[:, ymax]) / 2 402 | cx = (labels[:, xmin] + labels[:, xmax]) / 2 403 | # Check which of the boxes have center points within the cropped patch remove those that don't. 404 | requirements_met *= (cy >= 0.0) * (cy <= image_height - 1) * ( 405 | cx >= 0.0) * (cx <= image_width - 1) 406 | 407 | return labels[requirements_met] 408 | 409 | 410 | class ImageValidator: 411 | """ 412 | Returns `True` if a given minimum number of bounding boxes meets given overlap 413 | requirements with an image of a given height and width. 414 | """ 415 | 416 | def __init__(self, 417 | overlap_criterion='center_point', 418 | bounds=(0.3, 1.0), 419 | n_boxes_min=1, 420 | labels_format={ 421 | 'class_id': 0, 422 | 'xmin': 1, 423 | 'ymin': 2, 424 | 'xmax': 3, 425 | 'ymax': 4 426 | }, 427 | border_pixels='half'): 428 | """ 429 | Arguments: 430 | overlap_criterion (str, optional): Can be either of 'center_point', 'iou', or 'area'. Determines 431 | which boxes are considered valid with respect to a given image. If set to 'center_point', 432 | a given bounding box is considered valid if its center point lies within the image. 433 | If set to 'area', a given bounding box is considered valid if the quotient of its intersection 434 | area with the image and its own area is within `lower` and `upper`. 
If set to 'iou', a given 435 | bounding box is considered valid if its IoU with the image is within `lower` and `upper`. 436 | bounds (list or BoundGenerator, optional): Only relevant if `overlap_criterion` is 'area' or 'iou'. 437 | Determines the lower and upper bounds for `overlap_criterion`. Can be either a 2-tuple of scalars 438 | representing a lower bound and an upper bound, or a `BoundGenerator` object, which provides 439 | the possibility to generate bounds randomly. 440 | n_boxes_min (int or str, optional): Either a non-negative integer or the string 'all'. 441 | Determines the minimum number of boxes that must meet the `overlap_criterion` with respect to 442 | an image of the given height and width in order for the image to be a valid image. 443 | If set to 'all', an image is considered valid if all given boxes meet the `overlap_criterion`. 444 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels 445 | of an image contains which bounding box coordinate. The dictionary maps at least the keywords 446 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array. 447 | border_pixels (str, optional): How to treat the border pixels of the bounding boxes. 448 | Can be 'include', 'exclude', or 'half'. If 'include', the border pixels belong 449 | to the boxes. If 'exclude', the border pixels do not belong to the boxes. 450 | If 'half', then one of each of the two horizontal and vertical borders belong 451 | to the boxex, but not the other. 452 | """ 453 | if not ((isinstance(n_boxes_min, int) and n_boxes_min > 0) 454 | or n_boxes_min == 'all'): 455 | raise ValueError( 456 | "`n_boxes_min` must be a positive integer or 'all'.") 457 | self.overlap_criterion = overlap_criterion 458 | self.bounds = bounds 459 | self.n_boxes_min = n_boxes_min 460 | self.labels_format = labels_format 461 | self.border_pixels = border_pixels 462 | self.box_filter = BoxFilter(check_overlap=True, 463 | check_min_area=False, 464 | check_degenerate=False, 465 | overlap_criterion=self.overlap_criterion, 466 | overlap_bounds=self.bounds, 467 | labels_format=self.labels_format, 468 | border_pixels=self.border_pixels) 469 | 470 | def __call__(self, labels, image_height, image_width): 471 | """ 472 | Arguments: 473 | labels (array): The labels to be tested. The box coordinates are expected 474 | to be in the image's coordinate system. 475 | image_height (int): The height of the image to compare the box coordinates to. 476 | image_width (int): The width of the image to compare the box coordinates to. 477 | 478 | Returns: 479 | A boolean indicating whether an imgae of the given height and width is 480 | valid with respect to the given bounding boxes. 481 | """ 482 | 483 | self.box_filter.overlap_bounds = self.bounds 484 | self.box_filter.labels_format = self.labels_format 485 | 486 | # Get all boxes that meet the overlap requirements. 487 | valid_labels = self.box_filter(labels=labels, 488 | image_height=image_height, 489 | image_width=image_width) 490 | 491 | # Check whether enough boxes meet the requirements. 492 | if isinstance(self.n_boxes_min, int): 493 | # The image is valid if at least `self.n_boxes_min` ground truth boxes meet the requirements. 494 | if len(valid_labels) >= self.n_boxes_min: 495 | return True 496 | else: 497 | return False 498 | elif self.n_boxes_min == 'all': 499 | # The image is valid if all ground truth boxes meet the requirements. 
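# Note (added for clarity): an empty `labels` array passes this check trivially,
# since zero boxes are given and zero boxes meet the requirements.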
500 | if len(valid_labels) == len(labels): 501 | return True 502 | else: 503 | return False 504 | 505 | 506 | class RandomPatchInf: 507 | """ 508 | Randomly samples a patch from an image. The randomness refers to whatever 509 | randomness may be introduced by the patch coordinate generator, the box filter, 510 | and the patch validator. 511 | 512 | Input images may be cropped and/or padded along either or both of the two 513 | spatial dimensions as necessary in order to obtain the required patch. 514 | 515 | This operation is very similar to `RandomPatch`, except that: 516 | 1. This operation runs indefinitely until either a valid patch is found or 517 | the input image is returned unaltered, i.e. it cannot fail. 518 | 2. If a bound generator is given, a new pair of bounds will be generated 519 | every `n_trials_max` iterations. 520 | """ 521 | 522 | def __init__(self, 523 | patch_coord_generator, 524 | box_filter=None, 525 | image_validator=None, 526 | bound_generator=None, 527 | n_trials_max=50, 528 | clip_boxes=True, 529 | prob=0.857, 530 | background=(0, 0, 0), 531 | labels_format={ 532 | 'class_id': 0, 533 | 'xmin': 1, 534 | 'ymin': 2, 535 | 'xmax': 3, 536 | 'ymax': 4 537 | }): 538 | """ 539 | Arguments: 540 | patch_coord_generator (PatchCoordinateGenerator): A `PatchCoordinateGenerator` object 541 | to generate the positions and sizes of the patches to be sampled from the input images. 542 | box_filter (BoxFilter, optional): Only relevant if ground truth bounding boxes are given. 543 | A `BoxFilter` object to filter out bounding boxes that don't meet the given criteria 544 | after the transformation. Refer to the `BoxFilter` documentation for details. If `None`, 545 | the validity of the bounding boxes is not checked. 546 | image_validator (ImageValidator, optional): Only relevant if ground truth bounding boxes are given. 547 | An `ImageValidator` object to determine whether a sampled patch is valid. If `None`, 548 | any outcome is valid. 549 | bound_generator (BoundGenerator, optional): A `BoundGenerator` object to generate upper and 550 | lower bound values for the patch validator. Every `n_trials_max` trials, a new pair of 551 | upper and lower bounds will be generated until a valid patch is found or the original image 552 | is returned. This bound generator overrides the bound generator of the patch validator. 553 | n_trials_max (int, optional): Only relevant if ground truth bounding boxes are given. 554 | The sampler will run indefinitely until either a valid patch is found or the original image 555 | is returned, but this determines the maxmial number of trials to sample a valid patch for each 556 | selected pair of lower and upper bounds before a new pair is picked. 557 | clip_boxes (bool, optional): Only relevant if ground truth bounding boxes are given. 558 | If `True`, any ground truth bounding boxes will be clipped to lie entirely within the 559 | sampled patch. 560 | prob (float, optional): `(1 - prob)` determines the probability with which the original, 561 | unaltered image is returned. 562 | background (list/tuple, optional): A 3-tuple specifying the RGB color value of the potential 563 | background pixels of the scaled images. In the case of single-channel images, 564 | the first element of `background` will be used as the background pixel value. 565 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels 566 | of an image contains which bounding box coordinate. 
The dictionary maps at least the keywords 567 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array. 568 | """ 569 | 570 | if not isinstance(patch_coord_generator, PatchCoordinateGenerator): 571 | raise ValueError( 572 | "`patch_coord_generator` must be an instance of `PatchCoordinateGenerator`." 573 | ) 574 | if not (isinstance(image_validator, ImageValidator) 575 | or image_validator is None): 576 | raise ValueError( 577 | "`image_validator` must be either `None` or an `ImageValidator` object." 578 | ) 579 | if not (isinstance(bound_generator, BoundGenerator) 580 | or bound_generator is None): 581 | raise ValueError( 582 | "`bound_generator` must be either `None` or a `BoundGenerator` object." 583 | ) 584 | self.patch_coord_generator = patch_coord_generator 585 | self.box_filter = box_filter 586 | self.image_validator = image_validator 587 | self.bound_generator = bound_generator 588 | self.n_trials_max = n_trials_max 589 | self.clip_boxes = clip_boxes 590 | self.prob = prob 591 | self.background = background 592 | self.labels_format = labels_format 593 | self.sample_patch = CropPad(patch_ymin=None, 594 | patch_xmin=None, 595 | patch_height=None, 596 | patch_width=None, 597 | clip_boxes=self.clip_boxes, 598 | box_filter=self.box_filter, 599 | background=self.background, 600 | labels_format=self.labels_format) 601 | 602 | def __call__(self, image, labels=None, return_inverter=False): 603 | 604 | img_height, img_width = image.shape[:2] 605 | self.patch_coord_generator.img_height = img_height 606 | self.patch_coord_generator.img_width = img_width 607 | 608 | xmin = self.labels_format['xmin'] 609 | ymin = self.labels_format['ymin'] 610 | xmax = self.labels_format['xmax'] 611 | ymax = self.labels_format['ymax'] 612 | 613 | # Override the preset labels format. 614 | if not self.image_validator is None: 615 | self.image_validator.labels_format = self.labels_format 616 | self.sample_patch.labels_format = self.labels_format 617 | 618 | while True: # Keep going until we either find a valid patch or return the original image. 619 | 620 | p = np.random.uniform(0, 1) 621 | if p >= (1.0 - self.prob): 622 | 623 | # In case we have a bound generator, pick a lower and upper bound for the patch validator. 624 | if not ((self.image_validator is None) or 625 | (self.bound_generator is None)): 626 | self.image_validator.bounds = self.bound_generator() 627 | 628 | # Use at most `self.n_trials_max` attempts to find a crop 629 | # that meets our requirements. 630 | for _ in range(max(1, self.n_trials_max)): 631 | 632 | # Generate patch coordinates. 633 | patch_ymin, patch_xmin, patch_height, patch_width = self.patch_coord_generator( 634 | ) 635 | 636 | self.sample_patch.patch_ymin = patch_ymin 637 | self.sample_patch.patch_xmin = patch_xmin 638 | self.sample_patch.patch_height = patch_height 639 | self.sample_patch.patch_width = patch_width 640 | 641 | # Check if the resulting patch meets the aspect ratio requirements. 642 | aspect_ratio = patch_width / patch_height 643 | if not (self.patch_coord_generator.min_aspect_ratio <= 644 | aspect_ratio <= 645 | self.patch_coord_generator.max_aspect_ratio): 646 | continue 647 | 648 | if (labels is None) or (self.image_validator is None): 649 | # We either don't have any boxes or if we do, we will accept any outcome as valid. 650 | return self.sample_patch(image, labels, 651 | return_inverter) 652 | else: 653 | # Translate the box coordinates to the patch's coordinate system. 
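# Note (added for clarity): this shifted copy is only used to decide whether the candidate
# patch is valid; the actual crop and the final label translation/clipping happen in
# `self.sample_patch` (a `CropPad` instance) once a patch is accepted.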
654 | new_labels = np.copy(labels) 655 | new_labels[:, [ymin, ymax]] -= patch_ymin 656 | new_labels[:, [xmin, xmax]] -= patch_xmin 657 | # Check if the patch contains the minimum number of boxes we require. 658 | if self.image_validator(labels=new_labels, 659 | image_height=patch_height, 660 | image_width=patch_width): 661 | return self.sample_patch(image, labels, 662 | return_inverter) 663 | else: 664 | if return_inverter: 665 | 666 | def inverter(labels): 667 | return labels 668 | 669 | if labels is None: 670 | if return_inverter: 671 | return image, inverter 672 | else: 673 | return image 674 | else: 675 | if return_inverter: 676 | return image, labels, inverter 677 | else: 678 | return image, labels 679 | 680 | 681 | class PatchCoordinateGenerator: 682 | """ 683 | Generates random patch coordinates that meet specified requirements. 684 | """ 685 | 686 | def __init__(self, 687 | img_height=None, 688 | img_width=None, 689 | must_match='h_w', 690 | min_scale=0.3, 691 | max_scale=1.0, 692 | scale_uniformly=False, 693 | min_aspect_ratio=0.5, 694 | max_aspect_ratio=2.0, 695 | patch_ymin=None, 696 | patch_xmin=None, 697 | patch_height=None, 698 | patch_width=None, 699 | patch_aspect_ratio=None): 700 | """ 701 | Arguments: 702 | img_height (int): The height of the image for which the patch coordinates 703 | shall be generated. Doesn't have to be known upon construction. 704 | img_width (int): The width of the image for which the patch coordinates 705 | shall be generated. Doesn't have to be known upon construction. 706 | must_match (str, optional): Can be either of 'h_w', 'h_ar', and 'w_ar'. 707 | Specifies which two of the three quantities height, width, and aspect 708 | ratio determine the shape of the generated patch. The respective third 709 | quantity will be computed from the other two. For example, 710 | if `must_match == 'h_w'`, then the patch's height and width will be 711 | set to lie within [min_scale, max_scale] of the image size or to 712 | `patch_height` and/or `patch_width`, if given. The patch's aspect ratio 713 | is the dependent variable in this case, it will be computed from the 714 | height and width. Any given values for `patch_aspect_ratio`, 715 | `min_aspect_ratio`, or `max_aspect_ratio` will be ignored. 716 | min_scale (float, optional): The minimum size of a dimension of the patch 717 | as a fraction of the respective dimension of the image. Can be greater 718 | than 1. For example, if the image width is 200 and `min_scale == 0.5`, 719 | then the width of the generated patch will be at least 100. If `min_scale == 1.5`, 720 | the width of the generated patch will be at least 300. 721 | max_scale (float, optional): The maximum size of a dimension of the patch 722 | as a fraction of the respective dimension of the image. Can be greater 723 | than 1. For example, if the image width is 200 and `max_scale == 1.0`, 724 | then the width of the generated patch will be at most 200. If `max_scale == 1.5`, 725 | the width of the generated patch will be at most 300. Must be greater than 726 | `min_scale`. 727 | scale_uniformly (bool, optional): If `True` and if `must_match == 'h_w'`, 728 | the patch height and width will be scaled uniformly, otherwise they will 729 | be scaled independently. 730 | min_aspect_ratio (float, optional): Determines the minimum aspect ratio 731 | for the generated patches. 732 | max_aspect_ratio (float, optional): Determines the maximum aspect ratio 733 | for the generated patches. 
734 | patch_ymin (int, optional): `None` or the vertical coordinate of the top left 735 | corner of the generated patches. If this is not `None`, the position of the 736 | patches along the vertical axis is fixed. If this is `None`, then the 737 | vertical position of generated patches will be chosen randomly such that 738 | the overlap of a patch and the image along the vertical dimension is 739 | always maximal. 740 | patch_xmin (int, optional): `None` or the horizontal coordinate of the top left 741 | corner of the generated patches. If this is not `None`, the position of the 742 | patches along the horizontal axis is fixed. If this is `None`, then the 743 | horizontal position of generated patches will be chosen randomly such that 744 | the overlap of a patch and the image along the horizontal dimension is 745 | always maximal. 746 | patch_height (int, optional): `None` or the fixed height of the generated patches. 747 | patch_width (int, optional): `None` or the fixed width of the generated patches. 748 | patch_aspect_ratio (float, optional): `None` or the fixed aspect ratio of the 749 | generated patches. 750 | """ 751 | 752 | if not (must_match in {'h_w', 'h_ar', 'w_ar'}): 753 | raise ValueError( 754 | "`must_match` must be either of 'h_w', 'h_ar' and 'w_ar'.") 755 | if min_scale >= max_scale: 756 | raise ValueError("It must be `min_scale < max_scale`.") 757 | if min_aspect_ratio >= max_aspect_ratio: 758 | raise ValueError( 759 | "It must be `min_aspect_ratio < max_aspect_ratio`.") 760 | if scale_uniformly and not ((patch_height is None) and 761 | (patch_width is None)): 762 | raise ValueError( 763 | "If `scale_uniformly == True`, `patch_height` and `patch_width` must both be `None`." 764 | ) 765 | self.img_height = img_height 766 | self.img_width = img_width 767 | self.must_match = must_match 768 | self.min_scale = min_scale 769 | self.max_scale = max_scale 770 | self.scale_uniformly = scale_uniformly 771 | self.min_aspect_ratio = min_aspect_ratio 772 | self.max_aspect_ratio = max_aspect_ratio 773 | self.patch_ymin = patch_ymin 774 | self.patch_xmin = patch_xmin 775 | self.patch_height = patch_height 776 | self.patch_width = patch_width 777 | self.patch_aspect_ratio = patch_aspect_ratio 778 | 779 | def __call__(self): 780 | """ 781 | Returns: 782 | A 4-tuple `(ymin, xmin, height, width)` that represents the coordinates 783 | of the generated patch. 784 | """ 785 | 786 | # Get the patch height and width. 787 | 788 | if self.must_match == 'h_w': # Aspect is the dependent variable. 789 | if not self.scale_uniformly: 790 | # Get the height. 791 | if self.patch_height is None: 792 | patch_height = int( 793 | np.random.uniform(self.min_scale, self.max_scale) * 794 | self.img_height) 795 | else: 796 | patch_height = self.patch_height 797 | # Get the width. 798 | if self.patch_width is None: 799 | patch_width = int( 800 | np.random.uniform(self.min_scale, self.max_scale) * 801 | self.img_width) 802 | else: 803 | patch_width = self.patch_width 804 | else: 805 | scaling_factor = np.random.uniform(self.min_scale, 806 | self.max_scale) 807 | patch_height = int(scaling_factor * self.img_height) 808 | patch_width = int(scaling_factor * self.img_width) 809 | 810 | elif self.must_match == 'h_ar': # Width is the dependent variable. 811 | # Get the height. 812 | if self.patch_height is None: 813 | patch_height = int( 814 | np.random.uniform(self.min_scale, self.max_scale) * 815 | self.img_height) 816 | else: 817 | patch_height = self.patch_height 818 | # Get the aspect ratio. 
819 | if self.patch_aspect_ratio is None: 820 | patch_aspect_ratio = np.random.uniform(self.min_aspect_ratio, 821 | self.max_aspect_ratio) 822 | else: 823 | patch_aspect_ratio = self.patch_aspect_ratio 824 | # Get the width. 825 | patch_width = int(patch_height * patch_aspect_ratio) 826 | 827 | elif self.must_match == 'w_ar': # Height is the dependent variable. 828 | # Get the width. 829 | if self.patch_width is None: 830 | patch_width = int( 831 | np.random.uniform(self.min_scale, self.max_scale) * 832 | self.img_width) 833 | else: 834 | patch_width = self.patch_width 835 | # Get the aspect ratio. 836 | if self.patch_aspect_ratio is None: 837 | patch_aspect_ratio = np.random.uniform(self.min_aspect_ratio, 838 | self.max_aspect_ratio) 839 | else: 840 | patch_aspect_ratio = self.patch_aspect_ratio 841 | # Get the height. 842 | patch_height = int(patch_width / patch_aspect_ratio) 843 | 844 | # Get the top left corner coordinates of the patch. 845 | 846 | if self.patch_ymin is None: 847 | # Compute how much room we have along the vertical axis to place the patch. 848 | # A negative number here means that we want to sample a patch that is larger than the original image 849 | # in the vertical dimension, in which case the patch will be placed such that it fully contains the 850 | # image in the vertical dimension. 851 | y_range = self.img_height - patch_height 852 | # Select a random top left corner for the sample position from the possible positions. 853 | if y_range >= 0: 854 | patch_ymin = np.random.randint( 855 | 0, y_range + 1 856 | ) # There are y_range + 1 possible positions for the crop in the vertical dimension. 857 | else: 858 | patch_ymin = np.random.randint( 859 | y_range, 1 860 | ) # The possible positions for the image on the background canvas in the vertical dimension. 861 | else: 862 | patch_ymin = self.patch_ymin 863 | 864 | if self.patch_xmin is None: 865 | # Compute how much room we have along the horizontal axis to place the patch. 866 | # A negative number here means that we want to sample a patch that is larger than the original image 867 | # in the horizontal dimension, in which case the patch will be placed such that it fully contains the 868 | # image in the horizontal dimension. 869 | x_range = self.img_width - patch_width 870 | # Select a random top left corner for the sample position from the possible positions. 871 | if x_range >= 0: 872 | patch_xmin = np.random.randint( 873 | 0, x_range + 1 874 | ) # There are x_range + 1 possible positions for the crop in the horizontal dimension. 875 | else: 876 | patch_xmin = np.random.randint( 877 | x_range, 1 878 | ) # The possible positions for the image on the background canvas in the horizontal dimension. 879 | else: 880 | patch_xmin = self.patch_xmin 881 | 882 | return patch_ymin, patch_xmin, patch_height, patch_width 883 | 884 | 885 | class CropPad: 886 | """ 887 | Crops and/or pads an image deterministically. 888 | 889 | Depending on the given output patch size and the position (top left corner) relative 890 | to the input image, the image will be cropped and/or padded along one or both spatial 891 | dimensions. 892 | 893 | For example, if the output patch lies entirely within the input image, this will result 894 | in a regular crop. If the input image lies entirely within the output patch, this will 895 | result in the image being padded in every direction. All other cases are mixed cases 896 | where the image might be cropped in some directions and padded in others. 
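As an illustration (using this project's 768x1024 bird's-eye-view maps; not an example from the original code): `CropPad(patch_ymin=-50, patch_xmin=100, patch_height=768, patch_width=1024)` pads 50 rows of background at the top of the output canvas (so the bottom 50 image rows are cut off) and crops the leftmost 100 image columns (so the rightmost 100 canvas columns remain background).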
897 | 898 | The output patch can be arbitrary in both size and position as long as it overlaps 899 | with the input image. 900 | """ 901 | 902 | def __init__(self, 903 | patch_ymin, 904 | patch_xmin, 905 | patch_height, 906 | patch_width, 907 | clip_boxes=True, 908 | box_filter=None, 909 | background=(0, 0, 0), 910 | labels_format={ 911 | 'class_id': 0, 912 | 'xmin': 1, 913 | 'ymin': 2, 914 | 'xmax': 3, 915 | 'ymax': 4 916 | }): 917 | """ 918 | Arguments: 919 | patch_ymin (int, optional): The vertical coordinate of the top left corner of the output 920 | patch relative to the image coordinate system. Can be negative (i.e. lie outside the image) 921 | as long as the resulting patch still overlaps with the image. 922 | patch_ymin (int, optional): The horizontal coordinate of the top left corner of the output 923 | patch relative to the image coordinate system. Can be negative (i.e. lie outside the image) 924 | as long as the resulting patch still overlaps with the image. 925 | patch_height (int): The height of the patch to be sampled from the image. Can be greater 926 | than the height of the input image. 927 | patch_width (int): The width of the patch to be sampled from the image. Can be greater 928 | than the width of the input image. 929 | clip_boxes (bool, optional): Only relevant if ground truth bounding boxes are given. 930 | If `True`, any ground truth bounding boxes will be clipped to lie entirely within the 931 | sampled patch. 932 | box_filter (BoxFilter, optional): Only relevant if ground truth bounding boxes are given. 933 | A `BoxFilter` object to filter out bounding boxes that don't meet the given criteria 934 | after the transformation. Refer to the `BoxFilter` documentation for details. If `None`, 935 | the validity of the bounding boxes is not checked. 936 | background (list/tuple, optional): A 3-tuple specifying the RGB color value of the potential 937 | background pixels of the scaled images. In the case of single-channel images, 938 | the first element of `background` will be used as the background pixel value. 939 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels 940 | of an image contains which bounding box coordinate. The dictionary maps at least the keywords 941 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array. 
942 | """ 943 | # if (patch_height <= 0) or (patch_width <= 0): 944 | # raise ValueError("Patch height and width must both be positive.") 945 | # if (patch_ymin + patch_height < 0) or (patch_xmin + patch_width < 0): 946 | # raise ValueError("A patch with the given coordinates cannot overlap with an input image.") 947 | if not (isinstance(box_filter, BoxFilter) or box_filter is None): 948 | raise ValueError( 949 | "`box_filter` must be either `None` or a `BoxFilter` object.") 950 | self.patch_height = patch_height 951 | self.patch_width = patch_width 952 | self.patch_ymin = patch_ymin 953 | self.patch_xmin = patch_xmin 954 | self.clip_boxes = clip_boxes 955 | self.box_filter = box_filter 956 | self.background = background 957 | self.labels_format = labels_format 958 | 959 | def __call__(self, image, labels=None, return_inverter=False): 960 | 961 | img_height, img_width = image.shape[:2] 962 | 963 | if (self.patch_ymin > img_height) or (self.patch_xmin > img_width): 964 | raise ValueError( 965 | "The given patch doesn't overlap with the input image.") 966 | 967 | labels = np.copy(labels) 968 | 969 | xmin = self.labels_format['xmin'] 970 | ymin = self.labels_format['ymin'] 971 | xmax = self.labels_format['xmax'] 972 | ymax = self.labels_format['ymax'] 973 | 974 | # Top left corner of the patch relative to the image coordinate system: 975 | patch_ymin = self.patch_ymin 976 | patch_xmin = self.patch_xmin 977 | 978 | # Create a canvas of the size of the patch we want to end up with. 979 | if image.ndim == 3: 980 | canvas = np.zeros(shape=(self.patch_height, self.patch_width, 3), 981 | dtype=np.uint8) 982 | canvas[:, :] = self.background 983 | elif image.ndim == 2: 984 | canvas = np.zeros(shape=(self.patch_height, self.patch_width), 985 | dtype=np.uint8) 986 | canvas[:, :] = self.background[0] 987 | 988 | # Perform the crop. 989 | if patch_ymin < 0 and patch_xmin < 0: # Pad the image at the top and on the left. 990 | image_crop_height = min( 991 | img_height, self.patch_height + patch_ymin 992 | ) # The number of pixels of the image that will end up on the canvas in the vertical direction. 993 | image_crop_width = min( 994 | img_width, self.patch_width + patch_xmin 995 | ) # The number of pixels of the image that will end up on the canvas in the horizontal direction. 996 | canvas[-patch_ymin:-patch_ymin + 997 | image_crop_height, -patch_xmin:-patch_xmin + 998 | image_crop_width] = image[:image_crop_height, : 999 | image_crop_width] 1000 | 1001 | elif patch_ymin < 0 and patch_xmin >= 0: # Pad the image at the top and crop it on the left. 1002 | image_crop_height = min( 1003 | img_height, self.patch_height + patch_ymin 1004 | ) # The number of pixels of the image that will end up on the canvas in the vertical direction. 1005 | image_crop_width = min( 1006 | self.patch_width, img_width - patch_xmin 1007 | ) # The number of pixels of the image that will end up on the canvas in the horizontal direction. 1008 | canvas[-patch_ymin:-patch_ymin + image_crop_height, : 1009 | image_crop_width] = image[:image_crop_height, patch_xmin: 1010 | patch_xmin + image_crop_width] 1011 | 1012 | elif patch_ymin >= 0 and patch_xmin < 0: # Crop the image at the top and pad it on the left. 1013 | image_crop_height = min( 1014 | self.patch_height, img_height - patch_ymin 1015 | ) # The number of pixels of the image that will end up on the canvas in the vertical direction. 
1016 | image_crop_width = min( 1017 | img_width, self.patch_width + patch_xmin 1018 | ) # The number of pixels of the image that will end up on the canvas in the horizontal direction. 1019 | canvas[:image_crop_height, -patch_xmin:-patch_xmin + 1020 | image_crop_width] = image[patch_ymin:patch_ymin + 1021 | image_crop_height, : 1022 | image_crop_width] 1023 | 1024 | elif patch_ymin >= 0 and patch_xmin >= 0: # Crop the image at the top and on the left. 1025 | image_crop_height = min( 1026 | self.patch_height, img_height - patch_ymin 1027 | ) # The number of pixels of the image that will end up on the canvas in the vertical direction. 1028 | image_crop_width = min( 1029 | self.patch_width, img_width - patch_xmin 1030 | ) # The number of pixels of the image that will end up on the canvas in the horizontal direction. 1031 | canvas[:image_crop_height, :image_crop_width] = image[ 1032 | patch_ymin:patch_ymin + 1033 | image_crop_height, patch_xmin:patch_xmin + image_crop_width] 1034 | 1035 | image = canvas 1036 | 1037 | if return_inverter: 1038 | 1039 | def inverter(labels): 1040 | labels = np.copy(labels) 1041 | labels[:, [ymin + 1, ymax + 1]] += patch_ymin 1042 | labels[:, [xmin + 1, xmax + 1]] += patch_xmin 1043 | return labels 1044 | 1045 | if not (labels is None): 1046 | 1047 | # Translate the box coordinates to the patch's coordinate system. 1048 | labels[:, [ymin, ymax]] -= patch_ymin 1049 | labels[:, [xmin, xmax]] -= patch_xmin 1050 | 1051 | # Compute all valid boxes for this patch. 1052 | if not (self.box_filter is None): 1053 | self.box_filter.labels_format = self.labels_format 1054 | labels = self.box_filter(labels=labels, 1055 | image_height=self.patch_height, 1056 | image_width=self.patch_width) 1057 | 1058 | if self.clip_boxes: 1059 | labels[:, [ymin, ymax]] = np.clip(labels[:, [ymin, ymax]], 1060 | a_min=0, 1061 | a_max=self.patch_height - 1) 1062 | labels[:, [xmin, xmax]] = np.clip(labels[:, [xmin, xmax]], 1063 | a_min=0, 1064 | a_max=self.patch_width - 1) 1065 | 1066 | if return_inverter: 1067 | return image, labels, inverter 1068 | else: 1069 | return image, labels 1070 | 1071 | else: 1072 | if return_inverter: 1073 | return image, inverter 1074 | else: 1075 | return image 1076 | 1077 | 1078 | class Resize: 1079 | """ 1080 | Resizes images to a specified height and width in pixels. 1081 | """ 1082 | 1083 | def __init__(self, 1084 | height, 1085 | width, 1086 | interpolation_mode=cv2.INTER_LINEAR, 1087 | box_filter=None, 1088 | labels_format={ 1089 | 'class_id': 0, 1090 | 'xmin': 1, 1091 | 'ymin': 2, 1092 | 'xmax': 3, 1093 | 'ymax': 4 1094 | }): 1095 | """ 1096 | Arguments: 1097 | height (int): The desired height of the output images in pixels. 1098 | width (int): The desired width of the output images in pixels. 1099 | interpolation_mode (int, optional): An integer that denotes a valid 1100 | OpenCV interpolation mode. For example, integers 0 through 5 are 1101 | valid interpolation modes. 1102 | box_filter (BoxFilter, optional): Only relevant if ground truth bounding boxes are given. 1103 | A `BoxFilter` object to filter out bounding boxes that don't meet the given criteria 1104 | after the transformation. Refer to the `BoxFilter` documentation for details. If `None`, 1105 | the validity of the bounding boxes is not checked. 1106 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels 1107 | of an image contains which bounding box coordinate. 
The dictionary maps at least the keywords 1108 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array. 1109 | """ 1110 | if not (isinstance(box_filter, BoxFilter) or box_filter is None): 1111 | raise ValueError( 1112 | "`box_filter` must be either `None` or a `BoxFilter` object.") 1113 | self.out_height = height 1114 | self.out_width = width 1115 | self.interpolation_mode = interpolation_mode 1116 | self.box_filter = box_filter 1117 | self.labels_format = labels_format 1118 | 1119 | def __call__(self, image, labels=None, return_inverter=False): 1120 | 1121 | img_height, img_width = image.shape[:2] 1122 | 1123 | xmin = self.labels_format['xmin'] 1124 | ymin = self.labels_format['ymin'] 1125 | xmax = self.labels_format['xmax'] 1126 | ymax = self.labels_format['ymax'] 1127 | 1128 | image = cv2.resize(image, 1129 | dsize=(self.out_width, self.out_height), 1130 | interpolation=self.interpolation_mode) 1131 | 1132 | if return_inverter: 1133 | 1134 | def inverter(labels): 1135 | labels = np.copy(labels) 1136 | labels[:, [ymin + 1, ymax + 1137 | 1]] = np.round(labels[:, [ymin + 1, ymax + 1]] * 1138 | (img_height / self.out_height), 1139 | decimals=0) 1140 | labels[:, [xmin + 1, xmax + 1141 | 1]] = np.round(labels[:, [xmin + 1, xmax + 1]] * 1142 | (img_width / self.out_width), 1143 | decimals=0) 1144 | return labels 1145 | 1146 | if labels is None: 1147 | if return_inverter: 1148 | return image, inverter 1149 | else: 1150 | return image 1151 | else: 1152 | labels = np.copy(labels) 1153 | labels[:, [ymin, ymax]] = np.round(labels[:, [ymin, ymax]] * 1154 | (self.out_height / img_height), 1155 | decimals=0) 1156 | labels[:, [xmin, xmax]] = np.round(labels[:, [xmin, xmax]] * 1157 | (self.out_width / img_width), 1158 | decimals=0) 1159 | # labels[:, [ymin, ymax]] = labels[:, [ymin, ymax]] * (self.out_height / img_height) 1160 | # labels[:, [xmin, xmax]] = labels[:, [xmin, xmax]] * (self.out_width / img_width) 1161 | if not (self.box_filter is None): 1162 | self.box_filter.labels_format = self.labels_format 1163 | labels = self.box_filter(labels=labels, 1164 | image_height=self.out_height, 1165 | image_width=self.out_width) 1166 | 1167 | if return_inverter: 1168 | return image, labels, inverter 1169 | else: 1170 | return image, labels 1171 | 1172 | 1173 | class ResizeRandomInterp: 1174 | """ 1175 | Resizes images to a specified height and width in pixels using a randomly 1176 | selected interpolation mode. 1177 | """ 1178 | 1179 | def __init__(self, 1180 | height, 1181 | width, 1182 | interpolation_modes=[ 1183 | cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, 1184 | cv2.INTER_AREA, cv2.INTER_LANCZOS4 1185 | ], 1186 | box_filter=None, 1187 | labels_format={ 1188 | 'class_id': 0, 1189 | 'xmin': 1, 1190 | 'ymin': 2, 1191 | 'xmax': 3, 1192 | 'ymax': 4 1193 | }): 1194 | """ 1195 | Arguments: 1196 | height (int): The desired height of the output image in pixels. 1197 | width (int): The desired width of the output image in pixels. 1198 | interpolation_modes (list/tuple, optional): A list/tuple of integers 1199 | that represent valid OpenCV interpolation modes. For example, 1200 | integers 0 through 5 are valid interpolation modes. 1201 | box_filter (BoxFilter, optional): Only relevant if ground truth bounding boxes are given. 1202 | A `BoxFilter` object to filter out bounding boxes that don't meet the given criteria 1203 | after the transformation. Refer to the `BoxFilter` documentation for details. 
If `None`, 1204 | the validity of the bounding boxes is not checked. 1205 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels 1206 | of an image contains which bounding box coordinate. The dictionary maps at least the keywords 1207 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array. 1208 | """ 1209 | if not (isinstance(interpolation_modes, (list, tuple))): 1210 | raise ValueError("`interpolation_modes` must be a list or tuple.") 1211 | self.height = height 1212 | self.width = width 1213 | self.interpolation_modes = interpolation_modes 1214 | self.box_filter = box_filter 1215 | self.labels_format = labels_format 1216 | self.resize = Resize(height=self.height, 1217 | width=self.width, 1218 | box_filter=self.box_filter, 1219 | labels_format=self.labels_format) 1220 | 1221 | def __call__(self, image, labels=None, return_inverter=False): 1222 | self.resize.interpolation_mode = np.random.choice( 1223 | self.interpolation_modes) 1224 | self.resize.labels_format = self.labels_format 1225 | return self.resize(image, labels, return_inverter) 1226 | 1227 | 1228 | def convert_coordinates(tensor, start_index, conversion, border_pixels='half'): 1229 | """ 1230 | Convert coordinates for axis-aligned 2D boxes between two coordinate formats. 1231 | 1232 | Creates a copy of `tensor`, i.e. does not operate in place. Currently there are 1233 | three supported coordinate formats that can be converted from and to each other: 1234 | 1) (xmin, xmax, ymin, ymax) - the 'minmax' format 1235 | 2) (xmin, ymin, xmax, ymax) - the 'corners' format 1236 | 3) (cx, cy, w, h) - the 'centroids' format 1237 | 1238 | Arguments: 1239 | tensor (array): A Numpy nD array containing the four consecutive coordinates 1240 | to be converted somewhere in the last axis. 1241 | start_index (int): The index of the first coordinate in the last axis of `tensor`. 1242 | conversion (str, optional): The conversion direction. Can be 'minmax2centroids', 1243 | 'centroids2minmax', 'corners2centroids', 'centroids2corners', 'minmax2corners', 1244 | or 'corners2minmax'. 1245 | border_pixels (str, optional): How to treat the border pixels of the bounding boxes. 1246 | Can be 'include', 'exclude', or 'half'. If 'include', the border pixels belong 1247 | to the boxes. If 'exclude', the border pixels do not belong to the boxes. 1248 | If 'half', then one of each of the two horizontal and vertical borders belong 1249 | to the boxes, but not the other. 1250 | 1251 | Returns: 1252 | A Numpy nD array, a copy of the input tensor with the converted coordinates 1253 | in place of the original coordinates and the unaltered elements of the original 1254 | tensor elsewhere. 
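    Example (illustrative note added to this docstring for clarity): with `start_index=0` and the default `border_pixels='half'`, the single 'corners' box `np.array([10., 20., 30., 60.])` is converted by 'corners2centroids' to `[20., 40., 20., 40.]` (cx, cy, w, h), and 'centroids2corners' maps that result back to `[10., 20., 30., 60.]`.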
1255 | """ 1256 | if border_pixels == 'half': 1257 | d = 0 1258 | elif border_pixels == 'include': 1259 | d = 1 1260 | elif border_pixels == 'exclude': 1261 | d = -1 1262 | 1263 | ind = start_index 1264 | tensor1 = np.copy(tensor).astype(np.float) 1265 | if conversion == 'minmax2centroids': 1266 | tensor1[..., ind] = (tensor[..., ind] + 1267 | tensor[..., ind + 1]) / 2.0 # Set cx 1268 | tensor1[..., ind + 1] = (tensor[..., ind + 2] + 1269 | tensor[..., ind + 3]) / 2.0 # Set cy 1270 | tensor1[..., ind + 1271 | 2] = tensor[..., ind + 1] - tensor[..., ind] + d # Set w 1272 | tensor1[..., ind + 1273 | 3] = tensor[..., ind + 3] - tensor[..., ind + 2] + d # Set h 1274 | elif conversion == 'centroids2minmax': 1275 | tensor1[..., ind] = tensor[..., ind] - tensor[..., ind + 1276 | 2] / 2.0 # Set xmin 1277 | tensor1[..., ind + 1278 | 1] = tensor[..., ind] + tensor[..., ind + 2] / 2.0 # Set xmax 1279 | tensor1[..., ind + 1280 | 2] = tensor[..., ind + 1281 | 1] - tensor[..., ind + 3] / 2.0 # Set ymin 1282 | tensor1[..., ind + 1283 | 3] = tensor[..., ind + 1284 | 1] + tensor[..., ind + 3] / 2.0 # Set ymax 1285 | elif conversion == 'corners2centroids': 1286 | tensor1[..., ind] = (tensor[..., ind] + 1287 | tensor[..., ind + 2]) / 2.0 # Set cx 1288 | tensor1[..., ind + 1] = (tensor[..., ind + 1] + 1289 | tensor[..., ind + 3]) / 2.0 # Set cy 1290 | tensor1[..., ind + 1291 | 2] = tensor[..., ind + 2] - tensor[..., ind] + d # Set w 1292 | tensor1[..., ind + 1293 | 3] = tensor[..., ind + 3] - tensor[..., ind + 1] + d # Set h 1294 | elif conversion == 'centroids2corners': 1295 | tensor1[..., ind] = tensor[..., ind] - tensor[..., ind + 1296 | 2] / 2.0 # Set xmin 1297 | tensor1[..., ind + 1298 | 1] = tensor[..., ind + 1299 | 1] - tensor[..., ind + 3] / 2.0 # Set ymin 1300 | tensor1[..., ind + 1301 | 2] = tensor[..., ind] + tensor[..., ind + 2] / 2.0 # Set xmax 1302 | tensor1[..., ind + 1303 | 3] = tensor[..., ind + 1304 | 1] + tensor[..., ind + 3] / 2.0 # Set ymax 1305 | elif (conversion == 'minmax2corners') or (conversion == 'corners2minmax'): 1306 | tensor1[..., ind + 1] = tensor[..., ind + 2] 1307 | tensor1[..., ind + 2] = tensor[..., ind + 1] 1308 | else: 1309 | raise ValueError( 1310 | "Unexpected conversion value. Supported values are 'minmax2centroids', 'centroids2minmax', 'corners2centroids', 'centroids2corners', 'minmax2corners', and 'corners2minmax'." 1311 | ) 1312 | 1313 | return tensor1 1314 | 1315 | 1316 | def intersection_area_(boxes1, 1317 | boxes2, 1318 | coords='corners', 1319 | mode='outer_product', 1320 | border_pixels='half'): 1321 | """ 1322 | The same as 'intersection_area()' but for internal use, i.e. without all the safety checks. 1323 | """ 1324 | 1325 | m = boxes1.shape[0] # The number of boxes in `boxes1` 1326 | n = boxes2.shape[0] # The number of boxes in `boxes2` 1327 | 1328 | # Set the correct coordinate indices for the respective formats. 1329 | if coords == 'corners': 1330 | xmin = 0 1331 | ymin = 1 1332 | xmax = 2 1333 | ymax = 3 1334 | elif coords == 'minmax': 1335 | xmin = 0 1336 | xmax = 1 1337 | ymin = 2 1338 | ymax = 3 1339 | 1340 | if border_pixels == 'half': 1341 | d = 0 1342 | elif border_pixels == 'include': 1343 | d = 1 # If border pixels are supposed to belong to the bounding boxes, we have to add one pixel to any difference `xmax - xmin` or `ymax - ymin`. 1344 | elif border_pixels == 'exclude': 1345 | d = -1 # If border pixels are not supposed to belong to the bounding boxes, we have to subtract one pixel from any difference `xmax - xmin` or `ymax - ymin`. 
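    # Added illustrative comment: with border_pixels='include' (d = 1), a box with
    # xmin=2 and xmax=5 is treated as 5 - 2 + 1 = 4 pixels wide; with 'half' (d = 0)
    # its width is 3, and with 'exclude' (d = -1) it is 2.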
1346 | 1347 | # Compute the intersection areas. 1348 | 1349 | if mode == 'outer_product': 1350 | 1351 | # For all possible box combinations, get the greater xmin and ymin values. 1352 | # This is a tensor of shape (m,n,2). 1353 | min_xy = np.maximum( 1354 | np.tile(np.expand_dims(boxes1[:, [xmin, ymin]], axis=1), 1355 | reps=(1, n, 1)), 1356 | np.tile(np.expand_dims(boxes2[:, [xmin, ymin]], axis=0), 1357 | reps=(m, 1, 1))) 1358 | 1359 | # For all possible box combinations, get the smaller xmax and ymax values. 1360 | # This is a tensor of shape (m,n,2). 1361 | max_xy = np.minimum( 1362 | np.tile(np.expand_dims(boxes1[:, [xmax, ymax]], axis=1), 1363 | reps=(1, n, 1)), 1364 | np.tile(np.expand_dims(boxes2[:, [xmax, ymax]], axis=0), 1365 | reps=(m, 1, 1))) 1366 | 1367 | # Compute the side lengths of the intersection rectangles. 1368 | side_lengths = np.maximum(0, max_xy - min_xy + d) 1369 | 1370 | return side_lengths[:, :, 0] * side_lengths[:, :, 1] 1371 | 1372 | elif mode == 'element-wise': 1373 | 1374 | min_xy = np.maximum(boxes1[:, [xmin, ymin]], boxes2[:, [xmin, ymin]]) 1375 | max_xy = np.minimum(boxes1[:, [xmax, ymax]], boxes2[:, [xmax, ymax]]) 1376 | 1377 | # Compute the side lengths of the intersection rectangles. 1378 | side_lengths = np.maximum(0, max_xy - min_xy + d) 1379 | 1380 | return side_lengths[:, 0] * side_lengths[:, 1] 1381 | 1382 | 1383 | def iou(boxes1, 1384 | boxes2, 1385 | coords='centroids', 1386 | mode='outer_product', 1387 | border_pixels='half'): 1388 | """ 1389 | Computes the intersection-over-union similarity (also known as Jaccard similarity) 1390 | of two sets of axis-aligned 2D rectangular boxes. 1391 | 1392 | Let `boxes1` and `boxes2` contain `m` and `n` boxes, respectively. 1393 | 1394 | In 'outer_product' mode, returns an `(m,n)` matrix with the IoUs for all possible 1395 | combinations of the boxes in `boxes1` and `boxes2`. 1396 | 1397 | In 'element-wise' mode, `m` and `n` must be broadcast-compatible. Refer to the explanation 1398 | of the `mode` argument for details. 1399 | 1400 | Arguments: 1401 | boxes1 (array): Either a 1D Numpy array of shape `(4, )` containing the coordinates for one box in the 1402 | format specified by `coords` or a 2D Numpy array of shape `(m, 4)` containing the coordinates for `m` boxes. 1403 | If `mode` is set to 'element-wise', the shape must be broadcast-compatible with `boxes2`. 1404 | boxes2 (array): Either a 1D Numpy array of shape `(4, )` containing the coordinates for one box in the 1405 | format specified by `coords` or a 2D Numpy array of shape `(n, 4)` containing the coordinates for `n` boxes. 1406 | If `mode` is set to 'element-wise', the shape must be broadcast-compatible with `boxes1`. 1407 | coords (str, optional): The coordinate format in the input arrays. Can be either 'centroids' for the format 1408 | `(cx, cy, w, h)`, 'minmax' for the format `(xmin, xmax, ymin, ymax)`, or 'corners' for the format 1409 | `(xmin, ymin, xmax, ymax)`. 1410 | mode (str, optional): Can be one of 'outer_product' and 'element-wise'. In 'outer_product' mode, returns an 1411 | `(m,n)` matrix with the IoU overlaps for all possible combinations of the `m` boxes in `boxes1` with the 1412 | `n` boxes in `boxes2`. In 'element-wise' mode, returns a 1D array and the shapes of `boxes1` and `boxes2` 1413 | must be broadcast-compatible. If both `boxes1` and `boxes2` have `m` boxes, then this returns an array of 1414 | length `m` where the i-th position contains the IoU overlap of `boxes1[i]` with `boxes2[i]`. 
1415 | border_pixels (str, optional): How to treat the border pixels of the bounding boxes. 1416 | Can be 'include', 'exclude', or 'half'. If 'include', the border pixels belong 1417 | to the boxes. If 'exclude', the border pixels do not belong to the boxes. 1418 | If 'half', then one of each of the two horizontal and vertical borders belong 1419 | to the boxes, but not the other. 1420 | 1421 | Returns: 1422 | A 1D or 2D Numpy array (refer to the `mode` argument for details) of dtype float containing values in [0,1], 1423 | the Jaccard similarity of the boxes in `boxes1` and `boxes2`. 0 means there is no overlap between two given 1424 | boxes, 1 means their coordinates are identical. 1425 | """ 1426 | 1427 | # Make sure the boxes have the right shapes. 1428 | if boxes1.ndim > 2: 1429 | raise ValueError( 1430 | "boxes1 must have rank either 1 or 2, but has rank {}.".format( 1431 | boxes1.ndim)) 1432 | if boxes2.ndim > 2: 1433 | raise ValueError( 1434 | "boxes2 must have rank either 1 or 2, but has rank {}.".format( 1435 | boxes2.ndim)) 1436 | 1437 | if boxes1.ndim == 1: boxes1 = np.expand_dims(boxes1, axis=0) 1438 | if boxes2.ndim == 1: boxes2 = np.expand_dims(boxes2, axis=0) 1439 | 1440 | if not (boxes1.shape[1] == boxes2.shape[1] == 4): 1441 | raise ValueError( 1442 | "All boxes must consist of 4 coordinates, but the boxes in `boxes1` and `boxes2` have {} and {} coordinates, respectively." 1443 | .format(boxes1.shape[1], boxes2.shape[1])) 1444 | if not mode in {'outer_product', 'element-wise'}: 1445 | raise ValueError( 1446 | "`mode` must be one of 'outer_product' and 'element-wise', but got '{}'." 1447 | .format(mode)) 1448 | 1449 | # Convert the coordinates if necessary. 1450 | if coords == 'centroids': 1451 | boxes1 = convert_coordinates(boxes1, 1452 | start_index=0, 1453 | conversion='centroids2corners') 1454 | boxes2 = convert_coordinates(boxes2, 1455 | start_index=0, 1456 | conversion='centroids2corners') 1457 | coords = 'corners' 1458 | elif not (coords in {'minmax', 'corners'}): 1459 | raise ValueError( 1460 | "Unexpected value for `coords`. Supported values are 'minmax', 'corners' and 'centroids'." 1461 | ) 1462 | 1463 | # Compute the IoU. 1464 | 1465 | # Compute the intersection areas. 1466 | 1467 | intersection_areas = intersection_area_(boxes1, 1468 | boxes2, 1469 | coords=coords, 1470 | mode=mode) 1471 | 1472 | m = boxes1.shape[0] # The number of boxes in `boxes1` 1473 | n = boxes2.shape[0] # The number of boxes in `boxes2` 1474 | 1475 | # Compute the union areas. 1476 | 1477 | # Set the correct coordinate indices for the respective formats. 1478 | if coords == 'corners': 1479 | xmin = 0 1480 | ymin = 1 1481 | xmax = 2 1482 | ymax = 3 1483 | elif coords == 'minmax': 1484 | xmin = 0 1485 | xmax = 1 1486 | ymin = 2 1487 | ymax = 3 1488 | 1489 | if border_pixels == 'half': 1490 | d = 0 1491 | elif border_pixels == 'include': 1492 | d = 1 # If border pixels are supposed to belong to the bounding boxes, we have to add one pixel to any difference `xmax - xmin` or `ymax - ymin`. 1493 | elif border_pixels == 'exclude': 1494 | d = -1 # If border pixels are not supposed to belong to the bounding boxes, we have to subtract one pixel from any difference `xmax - xmin` or `ymax - ymin`. 
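    # Added explanatory comment: the union area below is area(box1) + area(box2) - intersection.
    # For example, the 'corners' boxes [0, 0, 4, 4] and [2, 0, 6, 4] (with border_pixels='half')
    # each have area 16 and overlap in a 2 x 4 region, so their IoU is 8 / (16 + 16 - 8) = 1/3.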
1495 | 1496 | if mode == 'outer_product': 1497 | 1498 | boxes1_areas = np.tile(np.expand_dims( 1499 | (boxes1[:, xmax] - boxes1[:, xmin] + d) * 1500 | (boxes1[:, ymax] - boxes1[:, ymin] + d), 1501 | axis=1), 1502 | reps=(1, n)) 1503 | boxes2_areas = np.tile(np.expand_dims( 1504 | (boxes2[:, xmax] - boxes2[:, xmin] + d) * 1505 | (boxes2[:, ymax] - boxes2[:, ymin] + d), 1506 | axis=0), 1507 | reps=(m, 1)) 1508 | 1509 | elif mode == 'element-wise': 1510 | 1511 | boxes1_areas = (boxes1[:, xmax] - boxes1[:, xmin] + 1512 | d) * (boxes1[:, ymax] - boxes1[:, ymin] + d) 1513 | boxes2_areas = (boxes2[:, xmax] - boxes2[:, xmin] + 1514 | d) * (boxes2[:, ymax] - boxes2[:, ymin] + d) 1515 | 1516 | union_areas = boxes1_areas + boxes2_areas - intersection_areas 1517 | 1518 | return intersection_areas / union_areas 1519 | --------------------------------------------------------------------------------
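A minimal usage sketch for the resize and IoU utilities above (added for illustration; the import path `dataset.augument` is an assumption based on the repository layout, and the example values are arbitrary):

import numpy as np
from dataset.augument import Resize, iou  # assumed module path

# Two boxes in 'corners' format (xmin, ymin, xmax, ymax) vs. one query box.
boxes_a = np.array([[0., 0., 4., 4.], [10., 10., 20., 20.]])
boxes_b = np.array([[2., 0., 6., 4.]])
# 'outer_product' mode returns an (m, n) = (2, 1) matrix of pairwise IoUs.
print(iou(boxes_a, boxes_b, coords='corners', mode='outer_product'))
# -> [[0.33333333]
#     [0.        ]]

# Resize an image together with its labels (class_id, xmin, ymin, xmax, ymax).
resize = Resize(height=768, width=1024)
image = np.zeros((375, 1242, 3), dtype=np.uint8)      # dummy KITTI-sized image
labels = np.array([[0, 100, 50, 300, 200]])
image_resized, labels_resized = resize(image, labels)  # boxes rescaled to 1024x768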