├── model
│   ├── __init__.py
│   ├── __pycache__
│   │   └── model.cpython-35.pyc
│   └── model.py
├── utils
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── kitti_utils.cpython-35.pyc
│   │   └── model_utils.cpython-35.pyc
│   ├── visualize_augumented_data.py
│   ├── make_image_dataset.py
│   ├── model_utils.py
│   ├── kitti_eval.py
│   └── kitti_utils.py
├── .gitignore
├── dataset
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── augument.cpython-35.pyc
│   │   └── dataset.cpython-35.pyc
│   ├── dataset.py
│   └── augument.py
├── examples
│   ├── 1.png
│   ├── 2.png
│   ├── 3.png
│   ├── 4.png
│   ├── car_detection_ground.png
│   └── cyclist_detection_ground.png
├── config
│   ├── kitti_anchors.txt
│   └── class_flag.txt
├── weights
│   └── .gitkeep
├── kitti
│   └── training
│       ├── calib
│       │   └── .gitkeep
│       ├── label_2
│       │   └── .gitkeep
│       └── velodyne
│           └── .gitkeep
├── predict.py
├── README.md
└── train.py
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/examples/1.png
--------------------------------------------------------------------------------
/examples/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/examples/2.png
--------------------------------------------------------------------------------
/examples/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/examples/3.png
--------------------------------------------------------------------------------
/examples/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/examples/4.png
--------------------------------------------------------------------------------
/config/kitti_anchors.txt:
--------------------------------------------------------------------------------
1 | 1.08 1.19
2 | 3.42 4.41
3 | 6.63 11.38
4 | 9.42 5.11
5 | 16.62 10.52
6 |
--------------------------------------------------------------------------------
/weights/.gitkeep:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file !.gitkeep
4 |
--------------------------------------------------------------------------------
/kitti/training/calib/.gitkeep:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file !.gitkeep
4 |
--------------------------------------------------------------------------------
/kitti/training/label_2/.gitkeep:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file !.gitkeep
4 |
--------------------------------------------------------------------------------
/kitti/training/velodyne/.gitkeep:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file !.gitkeep
4 |
--------------------------------------------------------------------------------
/examples/car_detection_ground.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/examples/car_detection_ground.png
--------------------------------------------------------------------------------
/examples/cyclist_detection_ground.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/examples/cyclist_detection_ground.png
--------------------------------------------------------------------------------
/model/__pycache__/model.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/model/__pycache__/model.cpython-35.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/augument.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/dataset/__pycache__/augument.cpython-35.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/dataset/__pycache__/dataset.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/kitti_utils.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/utils/__pycache__/kitti_utils.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/model_utils.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwooo/tensorflow_complex_yolo/HEAD/utils/__pycache__/model_utils.cpython-35.pyc
--------------------------------------------------------------------------------
/config/class_flag.txt:
--------------------------------------------------------------------------------
1 | 0 Car (255,50,0)
2 | 1 Van (255,0,130)
3 | 2 Truck (0,0,255)
4 | 3 Pedestrian (0,100,255)
5 | 4 person_sitting (255,255,0)
6 | 5 Cyclist (0,255,0)
7 | 6 Tram (0,255,255)
8 | 7 Misc (255,0,255)
9 |
--------------------------------------------------------------------------------
/utils/visualize_augumented_data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import cv2
3 | import os
4 | import sys
5 | sys.path.append('.')
6 | from dataset.dataset import ImageDataSet
7 | from utils.kitti_utils import draw_rotated_box, get_corner_gtbox
8 | img_h, img_w = 768, 1024
9 | gt_box_color = (255, 255, 255)
10 | class_list = [
11 | 'Car', 'Van', 'Truck', 'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram',
12 | 'Misc'
13 | ]
14 | dataset = ImageDataSet(data_set='test',
15 | mode='visualize',
16 | flip=True,
17 | random_scale=True,
18 | load_to_memory=False)
19 |
20 |
21 | def make_dir(directory):
22 | if not os.path.exists(directory):
23 | os.makedirs(directory)
24 |
25 |
26 | make_dir('./tmp')
27 | for img_idx, img, target in dataset.data_generator():
28 | # draw gt bbox
29 | print("process data: {}, saved in ./tmp".format(img_idx))
30 | for i in range(target.shape[0]):
31 | cx = int(target[i][1])
32 | cy = int(target[i][2])
33 | w = int(target[i][3])
34 | h = int(target[i][4])
35 | rz = target[i][5]
36 | draw_rotated_box(img, cx, cy, w, h, rz, gt_box_color)
37 | label = class_list[int(target[i][0])]
38 | box = get_corner_gtbox([cx, cy, w, h])
39 | cv2.putText(img, label, (box[0], box[1]), cv2.FONT_HERSHEY_PLAIN, 1.0,
40 | gt_box_color, 1)
41 | cv2.imwrite('./tmp/{}.png'.format(img_idx), img[:, :, ::-1])
42 |
--------------------------------------------------------------------------------
/utils/make_image_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import division
3 | import os
4 | import cv2
5 | import numpy as np
6 | import sys
7 | sys.path.append('.')
8 | from dataset.dataset import PointCloudDataset
9 | from utils.model_utils import make_dir
10 | image_dataset_dir = 'kitti/image_dataset/'
11 | class_list = [
12 | 'Car', 'Van', 'Truck', 'Pedestrian',
13 | 'Person_sitting', 'Cyclist', 'Tram', 'Misc'
14 | ]
15 | img_h, img_w = 768, 1024
16 | # dataset
17 | train_dataset = PointCloudDataset(root='kitti/', data_set='train')
18 | test_dataset = PointCloudDataset(root='kitti/', data_set='test')
19 |
20 |
21 | def delete_file_folder(src):
22 | if os.path.isfile(src):
23 | os.remove(src)
24 | elif os.path.isdir(src):
25 | for item in os.listdir(src):
26 | item_src = os.path.join(src, item)
27 | delete_file_folder(item_src)
28 |
29 |
30 | def preprocess_dataset(dataset):
31 | """
32 |     Convert point cloud data to images while
33 |     filtering out frames that contain no objects.
34 | param: dataset: (PointCloudDataset)
35 | return: None
36 | """
37 | for img_idx, rgb_map, target in dataset.getitem():
38 | rgb_map = np.array(rgb_map * 255, np.uint8)
39 | target = np.array(target)
40 | print('process image: {}'.format(img_idx))
41 | for i in range(target.shape[0]):
42 | if target[i].sum() == 0:
43 | break
44 | with open("kitti/image_dataset/labels/{}.txt".format(img_idx), 'a+') as f:
45 | label = class_list[int(target[i][0])]
46 | cx = target[i][1] * img_w
47 | cy = target[i][2] * img_h
48 | w = target[i][3] * img_w
49 | h = target[i][4] * img_h
50 | rz = target[i][5]
51 | line = label + ' ' + '{} {} {} {} {}\n'.format(cx, cy, w, h, rz)
52 | f.write(line)
53 | cv2.imwrite('kitti/image_dataset/images/{}.png'.format(img_idx), rgb_map[:, :, ::-1])
54 | print('make image dataset done!')
55 |
56 |
57 | def make_train_test_list():
58 | name_list = os.listdir(image_dataset_dir + 'labels')
59 | name_list.sort()
60 | with open('config/test_image_list.txt', 'w') as f:
61 | for name in name_list[0:1000]:
62 | f.write(name.split('.')[0])
63 | f.write('\n')
64 | with open('config/train_image_list.txt', 'w') as f:
65 | for name in name_list[1000:]:
66 | f.write(name.split('.')[0])
67 | f.write('\n')
68 |
69 |
70 | if __name__ == "__main__":
71 | make_dir(image_dataset_dir + 'images')
72 | make_dir(image_dataset_dir + 'labels')
73 | delete_file_folder(image_dataset_dir + 'labels')
74 | preprocess_dataset(train_dataset)
75 | preprocess_dataset(test_dataset)
76 | make_train_test_list()
77 |
78 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import division
3 | import cv2
4 | import argparse
5 | import numpy as np
6 | import tensorflow as tf
7 | from dataset.dataset import PointCloudDataset
8 | from utils.model_utils import preprocess_data, non_max_supression, filter_bbox, make_dir
9 | from utils.kitti_utils import draw_rotated_box, calculate_angle, get_corner_gtbox, \
10 | read_anchors_from_file, read_class_flag
11 | gt_box_color = (255, 255, 255)
12 | prob_th = 0.3
13 | nms_iou_th = 0.4
14 | n_anchors = 5
15 | n_classes = 8
16 | net_scale = 32
17 | img_h, img_w = 768, 1024
18 | grid_w, grid_h = 32, 24
19 | class_list = [
20 | 'Car', 'Van', 'Truck', 'Pedestrian',
21 | 'Person_sitting', 'Cyclist', 'Tram', 'Misc'
22 | ]
23 |
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument("--draw_gt_box", type=str, default='True', help="Whether to draw_gtbox, True or False")
26 | parser.add_argument("--weights_path", type=str, default='./weights/yolo_tloss_1.185166835784912_vloss_2.9397876932621-220800',
27 | help="set the weights_path")
28 | args = parser.parse_args()
29 | weights_path = args.weights_path
30 |
31 | # dataset
32 | dataset = PointCloudDataset(root='./kitti/', data_set='test')
33 | make_dir('./predict_result')
34 |
35 |
36 | def predict(draw_gt_box='False'):
37 |
38 | important_classes, names, color = read_class_flag('config/class_flag.txt')
39 | anchors = read_anchors_from_file('config/kitti_anchors.txt')
40 | sess = tf.Session()
41 | saver = tf.train.import_meta_graph(weights_path + '.meta')
42 | saver.restore(sess, weights_path)
43 | graph = tf.get_default_graph()
44 | image = graph.get_tensor_by_name("image_placeholder:0")
45 | train_flag = graph.get_tensor_by_name("flag_placeholder:0")
46 | y = graph.get_tensor_by_name("net/y:0")
47 | for img_idx, rgb_map, target in dataset.getitem():
48 | print("process data: {}, saved in ./predict_result/".format(img_idx))
49 | img = np.array(rgb_map * 255, np.uint8)
50 | target = np.array(target)
51 | # draw gt bbox
52 | if draw_gt_box == 'True':
53 | for i in range(target.shape[0]):
54 | if target[i].sum() == 0:
55 | break
56 | cx = int(target[i][1] * img_w)
57 | cy = int(target[i][2] * img_h)
58 | w = int(target[i][3] * img_w)
59 | h = int(target[i][4] * img_h)
60 | rz = target[i][5]
61 | draw_rotated_box(img, cx, cy, w, h, rz, gt_box_color)
62 | label = class_list[int(target[i][0])]
63 | box = get_corner_gtbox([cx, cy, w, h])
64 | cv2.putText(img, label, (box[0], box[1]),
65 | cv2.FONT_HERSHEY_PLAIN, 1.0, gt_box_color, 1)
66 | data = sess.run(y, feed_dict={image: [rgb_map], train_flag: False})
67 | classes, rois = preprocess_data(data, anchors, important_classes,
68 | grid_w, grid_h, net_scale)
69 | classes, index = non_max_supression(classes, rois, prob_th, nms_iou_th)
70 | all_boxes = filter_bbox(classes, rois, index)
71 | for box in all_boxes:
72 | class_idx = box[0]
73 | corner_box = get_corner_gtbox(box[1:5])
74 | angle = calculate_angle(box[6], box[5])
75 | class_prob = box[7]
76 | draw_rotated_box(img, box[1], box[2], box[3], box[4],
77 | angle, color[class_idx])
78 | cv2.putText(img,
79 | class_list[class_idx] + ' : {:.2f}'.format(class_prob),
80 | (corner_box[0], corner_box[1]), cv2.FONT_HERSHEY_PLAIN,
81 | 0.7, color[class_idx], 1, cv2.LINE_AA)
82 | cv2.imwrite('./predict_result/{}.png'.format(img_idx), img)
83 |
84 |
85 | if __name__ == '__main__':
86 | predict(draw_gt_box=args.draw_gt_box)
87 |
--------------------------------------------------------------------------------
/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 |
4 |
5 | def make_dir(directory):
6 | if not os.path.exists(directory):
7 | os.makedirs(directory)
8 |
9 |
10 | def bbox_iou(box1, box2, x1y1x2y2=True):
11 | if x1y1x2y2:
12 | mx = min(box1[0], box2[0])
13 | Mx = max(box1[2], box2[2])
14 | my = min(box1[1], box2[1])
15 | My = max(box1[3], box2[3])
16 | w1 = box1[2] - box1[0]
17 | h1 = box1[3] - box1[1]
18 | w2 = box2[2] - box2[0]
19 | h2 = box2[3] - box2[1]
20 | else:
21 | mx = min(box1[0] - box1[2] / 2.0, box2[0] - box2[2] / 2.0)
22 | Mx = max(box1[0] + box1[2] / 2.0, box2[0] + box2[2] / 2.0)
23 | my = min(box1[1] - box1[3] / 2.0, box2[1] - box2[3] / 2.0)
24 | My = max(box1[1] + box1[3] / 2.0, box2[1] + box2[3] / 2.0)
25 | w1 = box1[2]
26 | h1 = box1[3]
27 | w2 = box2[2]
28 | h2 = box2[3]
29 | uw = Mx - mx
30 | uh = My - my
31 | cw = w1 + w2 - uw
32 | ch = h1 + h2 - uh
33 | carea = 0
34 | if cw <= 0 or ch <= 0:
35 | return 0.0
36 |
37 | area1 = w1 * h1
38 | area2 = w2 * h2
39 | carea = cw * ch
40 | uarea = area1 + area2 - carea
41 | return carea / uarea
42 |
43 |
44 | def iou(r1, r2):
45 | intersect_w = np.maximum(
46 | np.minimum(r1[0] + r1[2], r2[0] + r2[2]) - np.maximum(r1[0], r2[0]), 0)
47 | intersect_h = np.maximum(
48 | np.minimum(r1[1] + r1[3], r2[1] + r2[3]) - np.maximum(r1[1], r2[1]), 0)
49 | area_r1 = r1[2] * r1[3]
50 | area_r2 = r2[2] * r2[3]
51 | intersect = intersect_w * intersect_h
52 | union = area_r1 + area_r2 - intersect
53 | return intersect / union
54 |
55 |
56 | def softmax(x):
57 |     e_x = np.exp(x - np.max(x))  # subtract the max for numerical stability
58 | return e_x / np.sum(e_x)
59 |
60 |
61 | def sigmoid(x):
62 | return 1.0 / (1.0 + np.exp(-x))
63 |
64 |
65 | def non_max_supression(classes, locations, prob_th, iou_th):
66 | """
67 | Filter out some overlapping boxes by non-maximum suppression
68 | """
69 | classes = np.transpose(classes)
70 | indxs = np.argsort(-classes, axis=1)
71 |
72 | for i in range(classes.shape[0]):
73 | classes[i] = classes[i][indxs[i]]
74 |
75 | for class_idx, class_vec in enumerate(classes):
76 | for roi_idx, roi_prob in enumerate(class_vec):
77 | if roi_prob < prob_th:
78 | classes[class_idx][roi_idx] = 0
79 |
80 | for class_idx, class_vec in enumerate(classes):
81 | for roi_idx, roi_prob in enumerate(class_vec):
82 | if roi_prob == 0:
83 | continue
84 | roi = locations[indxs[class_idx][roi_idx]][0:4]
85 | for roi_ref_idx, roi_ref_prob in enumerate(class_vec):
86 | if roi_ref_prob == 0 or roi_ref_idx <= roi_idx:
87 | continue
88 | roi_ref = locations[indxs[class_idx][roi_ref_idx]][0:4]
89 | if bbox_iou(roi, roi_ref, False) > iou_th:
90 | classes[class_idx][roi_ref_idx] = 0
91 | return classes, indxs
92 |
93 |
94 | def filter_bbox(classes, rois, indxs):
95 | """
96 | Pick out bounding boxes that are retained after non-maximum suppression
97 | """
98 | all_bboxs = []
99 | for class_idx, c in enumerate(classes):
100 | for loc_idx, class_prob in enumerate(c):
101 | if class_prob > 0:
102 | x = int(rois[indxs[class_idx][loc_idx]][0])
103 | y = int(rois[indxs[class_idx][loc_idx]][1])
104 | w = int(rois[indxs[class_idx][loc_idx]][2])
105 | h = int(rois[indxs[class_idx][loc_idx]][3])
106 | re = rois[indxs[class_idx][loc_idx]][4]
107 | im = rois[indxs[class_idx][loc_idx]][5]
108 | all_bboxs.append([class_idx, x, y, w, h, re, im, class_prob])
109 | return all_bboxs
110 |
111 |
112 | def preprocess_data(data, anchors, important_classes, grid_w, grid_h, net_scale):
113 | """
114 | Decode the data output by the model, obtain the center coordinates
115 | x, y and width and height of the bounding box in the image,
116 | and the category, the real and imaginary parts of the complex.
117 |
118 | """
119 | locations = []
120 | classes = []
121 | n_anchors = np.shape(anchors)[0]
122 | for i in range(grid_h):
123 | for j in range(grid_w):
124 | for k in range(n_anchors):
125 | class_vec = softmax(data[0, i, j, k, 7:])
126 | object_conf = sigmoid(data[0, i, j, k, 6])
127 | class_prob = object_conf * class_vec
128 | w = np.exp(data[0, i, j, k, 2]
129 | ) * anchors[k][0] / 80 * grid_w * net_scale
130 | h = np.exp(data[0, i, j, k, 3]
131 | ) * anchors[k][1] / 60 * grid_h * net_scale
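                # The anchor sizes are given in metres: the bird's-eye view covers an
                # 80 m wide x 60 m deep area mapped onto a
                # (grid_w * net_scale) x (grid_h * net_scale) pixel image, hence the
                # /80 and /60 scale factors above.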
132 | dx = sigmoid(data[0, i, j, k, 0])
133 | dy = sigmoid(data[0, i, j, k, 1])
134 | re = 2 * sigmoid(data[0, i, j, k, 4]) - 1
135 | im = 2 * sigmoid(data[0, i, j, k, 5]) - 1
136 | y = (i + dy) * net_scale
137 | x = (j + dx) * net_scale
138 | classes.append(class_prob[important_classes])
139 | locations.append([x, y, w, h, re, im])
140 | classes = np.array(classes)
141 | locations = np.array(locations)
142 | return classes, locations
143 |
144 |
--------------------------------------------------------------------------------
/utils/kitti_eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | This script generates the prediction results for each RGB-map of the test
4 | set, converts the bounding box coordinates from the image coordinate system
5 | to the lidar coordinate system, and then converts them to the camera
6 | coordinate system using the transformation matrix provided by the KITTI
7 | dataset. The coordinate transformation requires (x, y, z); since Complex-YOLO
8 | does not predict the height of objects, the height is set to a fixed value
9 | of 1.5, which does not affect the bird's-eye-view benchmark evaluation.
10 | 
11 |
12 | """
13 | from __future__ import division
14 | import numpy as np
15 | import argparse
16 | import tensorflow as tf
17 | import cv2
18 | from model_utils import preprocess_data, non_max_supression, filter_bbox, make_dir
19 | from kitti_utils import load_kitti_calib, calculate_angle,\
20 | read_anchors_from_file, read_class_flag
21 | from kitti_utils import angle_rz_to_ry, coord_image_to_velo, coord_velo_to_cam
22 | prob_th = 0.3
23 | iou_th = 0.4
24 | n_anchors = 5
25 | n_classes = 8
26 | net_scale = 32
27 | img_h, img_w = 768, 1024
28 | grid_w, grid_h = 32, 24
29 | test_image_path = "kitti/image_dataset/images/"
30 | class_list = [
31 | 'Car', 'Van', 'Truck', 'Pedestrian',
32 | 'Person_sitting', 'Cyclist', 'Tram', 'Misc'
33 | ]
34 | train_list = 'config/train_image_list.txt'
35 | test_list = 'config/test_image_list.txt'
36 | calib_dir = 'kitti/training/calib/'
37 |
38 | # kitti_static_cylist = 'cyclist_detection_ground.txt'
39 | # kitti_static_car = 'car_detection_ground.txt'
40 | # kitti_static_pedestrian = 'pedestrian_detection_ground.txt'
41 |
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument("--weights_path", type=str, default='./weights/yolo_tloss_1.185166835784912_vloss_2.9397876932621-220800', help="set the weights_path")
44 | args = parser.parse_args()
45 | weights_path = args.weights_path
46 |
47 | make_dir("./eval_results")
48 |
49 |
50 | def kitti_eval():
51 | important_classes, names, colors = read_class_flag('config/class_flag.txt')
52 | anchors = read_anchors_from_file('config/kitti_anchors.txt')
53 | sess = tf.Session()
54 | saver = tf.train.import_meta_graph(weights_path + '.meta')
55 | saver.restore(sess, weights_path)
56 | graph = tf.get_default_graph()
57 | image = graph.get_tensor_by_name("image_placeholder:0")
58 | train_flag = graph.get_tensor_by_name("flag_placeholder:0")
59 | y = graph.get_tensor_by_name("net/y:0")
60 | for test_file_index in range(1000):
61 | print('process data: {}, saved in ./eval_results'.format(test_file_index))
62 | calib_file = calib_dir + str(test_file_index).zfill(6) + '.txt'
63 | calib = load_kitti_calib(calib_file)
64 | result_file = "./eval_results/" + str(test_file_index).zfill(6) + ".txt"
65 | img_path = test_image_path + str(test_file_index).zfill(6) + '.png'
66 | rgb_map = cv2.imread(img_path)[:, :, ::-1]
67 | img_for_net = rgb_map / 255.0
68 | data = sess.run(y,
69 | feed_dict={image: [img_for_net],
70 | train_flag: False})
71 | classes, rois = preprocess_data(data,
72 | anchors,
73 | important_classes,
74 | grid_w,
75 | grid_h,
76 | net_scale)
77 | classes, index = non_max_supression(classes,
78 | rois,
79 | prob_th,
80 | iou_th)
81 | all_boxes = filter_bbox(classes, rois, index)
82 | with open(result_file, "w") as f:
83 | for box in all_boxes:
84 | pred_img_y = box[2]
85 | pred_img_x = box[1]
86 | velo_x, velo_y = coord_image_to_velo(pred_img_y, pred_img_x)
87 | cam_x, cam_z = coord_velo_to_cam(velo_x, velo_y, calib['Tr_velo2cam'])
88 | pred_width = box[3] * 80.0 / img_w
89 | pred_height = box[4] * 60.0 / img_h
90 | pred_cls = class_list[box[0]]
91 | pred_conf = box[7]
92 | angle_rz = calculate_angle(box[6], box[5])
93 | angle_ry = angle_rz_to_ry(angle_rz)
94 | pred_line = pred_cls + " -1 -1 -10 -1 -1 -1 -1 -1" + \
95 | " {:.2f} {:.2f}".format(pred_width, pred_height) + \
96 | " {:.2f} {:.2f} {:.2f}".format(cam_x, -1000, cam_z) + \
97 | " {:.2f} {:.2f}".format(angle_ry, pred_conf)
98 | f.write(pred_line)
99 | f.write("\n")
100 |
101 |
102 | def cal_ap(kitti_statics_results):
103 | """
104 |     Approximate the AP by summing precision values sampled at recall steps of 0.025.
105 |     param kitti_statics_results (str): statistics result file output by the KITTI evaluation script.
106 |     return: array with the approximate AP for the easy, moderate and hard settings.
107 | """
108 | with open(kitti_statics_results, 'r') as f:
109 | lines = f.readlines()
110 | all_lines = []
111 | for line in lines:
112 | pr = list(map(float, line.strip().split(' ')))
113 | all_lines.append(pr)
114 | all_lines = np.array(all_lines)
115 | ap = np.zeros([all_lines.shape[0], 3])
116 | ap[1:, 0] = 0.025 * all_lines[1:, 1]
117 | ap[1:, 1] = 0.025 * all_lines[1:, 2]
118 | ap[1:, 2] = 0.025 * all_lines[1:, 3]
119 | result = np.sum(ap, 0)
120 | return result
121 |
122 |
123 | if __name__ == '__main__':
124 | kitti_eval()
125 | # print("car ap: {}".format(cal_ap(kitti_static_car)))
126 | # print("cyclist ap: {}".format(cal_ap(kitti_static_cylist)))
127 | # print("pedestrian ap: {}".format(cal_ap(kitti_static_pedestrian)))
128 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Complex-YOLO implementation in TensorFlow
2 | ---
3 | ### Contents
4 |
5 | [Overview](#overview) [Examples](#examples) [Dependencies](#dependencies) [How to use it](#how-to-use-it) [Others](#others) [ToDo](#todo)
6 |
7 | ### Overview
8 |
9 | This project is an unofficial implementation of [Complex-YOLO: Real-time 3D Object Detection on Point Clouds](https://arxiv.org/abs/1803.06199); the model structure is slightly different from what the paper describes. [AI-liu/Complex-YOLO](https://github.com/AI-liu/Complex-YOLO) has the most stars, but it appears to have some bugs: that model has no yaw-angle prediction and does not generalize to the test set, so this project only reuses its point cloud preprocessing. The model structure follows [WojciechMormul/yolo2](https://github.com/WojciechMormul/yolo2). On this basis, a complete Complex-YOLO algorithm is implemented. The model converges easily and reaches good precision without much careful parameter tuning.
10 |
11 | Complex-YOLO takes point cloud data as input and encodes the point cloud into a bird's-eye-view RGB-map in order to predict the position and yaw angle of objects in 3D space. To make training more efficient, the point cloud dataset is first converted into an RGB-map image dataset. The experiments are based on the KITTI dataset, which contains 7481 labeled samples; the first 1000 samples are used as the test set and the remaining samples as the training set.
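A minimal sketch of this encoding step, using the repository's own `PointCloudDataset` (it assumes the KITTI files are already placed under `./kitti/training/` as described below; the output file name is just an example):

```python
import cv2
import numpy as np
from dataset.dataset import PointCloudDataset

# Iterate over the test split and save the first bird's-eye-view RGB-map to disk.
dataset = PointCloudDataset(root='kitti/', data_set='test')
for img_idx, rgb_map, target in dataset.getitem():
    img = np.array(rgb_map * 255, np.uint8)        # rgb_map values are in [0, 1]
    cv2.imwrite('{}_bev.png'.format(img_idx), img[:, :, ::-1])  # RGB -> BGR for OpenCV
    break
```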
12 |
13 | ### Examples
14 |
15 | Below are some prediction examples of Complex-YOLO; the predictions were made on the split test set. The IoU thresholds for car and cyclist are set to 0.5 and 0.3, respectively.
16 |
17 | | | |
18 | |---|---|
19 | | ![](examples/1.png) | ![](examples/2.png) |
20 | | ![](examples/3.png) | ![](examples/4.png) |
21 | | ![](examples/car_detection_ground.png) | ![](examples/cyclist_detection_ground.png) |
22 |
23 | ### Dependencies
24 |
25 | * Python 3.x
26 | * Numpy
27 | * TensorFlow 1.x
28 | * OpenCV
29 |
30 | ### How to use it
31 |
32 | Clone this repo
33 |
34 | ```bash
35 | git clone https://github.com/wwooo/tensorflow_complex_yolo
36 | ```
37 |
38 |
39 | ```bash
40 | cd tensorflow_complex_yolo
41 | ```
42 | How to prepare data:
43 |
44 | 1. Download the data from the official KITTI website:
45 |
46 | * [data_object_velodyne.zip](http://www.cvlibs.net/download.php?file=data_object_velodyne.zip)
47 | * [data_object_label_2.zip](http://www.cvlibs.net/download.php?file=data_object_label_2.zip)
48 | * [data_object_calib.zip](http://www.cvlibs.net/download.php?file=data_object_calib.zip)
49 |
50 | 2. Create the following folder structure in the current working directory:
51 |
52 | ```
53 | tensorflow_complex_yolo
54 | kitti
55 | training
56 | calib
57 | label_2
58 | velodyne
59 | ```
60 |
61 |
62 | 3. Unzip the downloaded KITTI archives and place the extracted data in the corresponding folders created above:
63 |
64 |
65 | data_object_velodyne/training/velodyne/\*.bin -> velodyne
66 | 
67 | data_object_label_2/training/label_2/\*.txt -> label_2
68 | 
69 | data_object_calib/training/calib/\*.txt -> calib
70 |
71 |
72 | Then create RGB-image data set:
73 |
74 | ```bash
75 | python utils/make_image_dataset.py
76 | ```
77 |
78 | This script converts the point cloud data into image data, which is automatically saved in ./kitti/image_dataset/, and generates test_image_list.txt and train_image_list.txt in the ./config folder.
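After the script finishes, the generated files look roughly like this (names are the zero-padded KITTI sample indices; frames without objects in the predicted area are skipped):

```
kitti/image_dataset/
    images/000000.png, 000001.png, ...
    labels/000000.txt, 000001.txt, ...
config/
    test_image_list.txt
    train_image_list.txt
```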
79 |
80 | Note: the model only predicts the 60 m x 80 m area in front of the car and encodes the point cloud in this area into a 768x1024 RGB-map, i.e. roughly 0.08 m per pixel. In the KITTI dataset, not all samples have objects in this area, so while building the image dataset the script automatically filters out samples that have no objects in the area.
81 |
82 | How to train a model:
83 | ```bash
84 | python train.py
85 | --load_weights
86 | --weights_path
87 | --batch_size
88 | --num_iter
89 | --save_dir
90 | --save_interval
91 | --gpu_id
92 | ```
93 | All parameters have default values, so you can run the script directly. If you want to load model weights, set --load\_weights=True (default False) and provide --weights\_path. --batch_size defaults to 8; adjust it to the memory of your GPU. --num_iter sets the number of training iterations. --save_interval sets how many epochs pass between model saves, default 2. --save\_dir is the directory where the model is saved, default ./weights/. --gpu_id specifies which GPU to use for training, default 0.
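For example, to resume training from an existing checkpoint with a smaller batch size (the checkpoint prefix below is only a placeholder):

```bash
python train.py --load_weights=True --weights_path=./weights/<checkpoint-prefix> --batch_size=4 --num_iter=20000 --save_interval=2 --gpu_id=0
```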
94 |
95 | How to predict:
96 |
97 | ```bash
98 | python predict.py --weights_path=./weights/... --draw_gt_box=True
99 | ```
100 |
101 | predict.py uses the point cloud data directly as input to the model and saves the predicted results in the predict\_result folder. You can set --draw\_gt_box=True or False to decide whether the ground-truth boxes are drawn on the predicted results.
102 |
103 | How to eval:
104 |
105 | ```bash
106 | python utils/kitti_eval.py
107 | ```
108 |
109 | This script saves the prediction results in the KITTI label format; you can then evaluate them with KITTI's official evaluation script, which you should study before use.
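Each result line follows the KITTI label format with a confidence score appended; fields the model does not predict are written as placeholders (-1, -10, -1000) by kitti_eval.py. A hypothetical example line (the numbers are made up):

```
Car -1 -1 -10 -1 -1 -1 -1 -1 1.66 3.98 5.32 -1000.00 20.15 -1.58 0.92
```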
110 |
111 | ### Others
112 |
113 | You can run utils/visualize_augumented_data.py to visualize the augmented data and labels; the results are saved in ./tmp.
114 |
115 | ### ToDo
116 |
117 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import division
3 | import os
4 | import argparse
5 | import tensorflow as tf
6 | from utils.model_utils import make_dir
7 | from dataset.dataset import ImageDataSet
8 | from model.model import yolo_net, yolo_loss
9 |
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("--load_weights", type=str, default='False', help="Whether to load weights, True or False")
13 | parser.add_argument("--batch_size", type=int, default=8, help="Set the batch_size")
14 | parser.add_argument("--weights_path", type=str, default="./weights/...", help="Set the weights_path")
15 | parser.add_argument("--save_dir", type=str, default="./weights/", help="Dir to save weights")
16 | parser.add_argument("--gpu_id", type=str, default='0', help="Specify GPU device")
17 | parser.add_argument("--num_iter", type=int, default=16000, help="num_max_iter")
18 | parser.add_argument("--save_interval", type=int, default=2, help="Save once every two epochs")
19 | args = parser.parse_args()
20 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
21 |
22 | SCALE = 32
23 | GRID_W, GRID_H = 32, 24
24 | N_CLASSES = 8
25 | N_ANCHORS = 5
26 | IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH = GRID_H * SCALE, GRID_W * SCALE, 3
27 | batch_size = args.batch_size
28 |
29 | train_dataset = ImageDataSet(data_set='train',
30 | mode='train',
31 | load_to_memory=False)
32 | test_dataset = ImageDataSet(data_set='test',
33 | mode='test',
34 | flip=False,
35 | aug_hsv=False,
36 | random_scale=False,
37 | load_to_memory=False)
38 | num_val_step = int(test_dataset.num_samples / args.batch_size)
39 | save_steps = int(train_dataset.num_samples / args.batch_size * args.save_interval)
40 |
41 |
42 | def print_info():
43 | print("train samples: {}".format(train_dataset.num_samples))
44 | print("test samples: {}".format(test_dataset.num_samples))
45 | print("batch_size: {}".format(args.batch_size))
46 | print("iter steps: {}".format(args.num_iter))
47 |
48 |
49 | def train(load_weights='False'):
50 | make_dir(args.save_dir)
51 | max_val_loss = 99999999.0
52 | global_step = tf.Variable(0, trainable=False)
53 | learning_rate = tf.train.exponential_decay(0.001,
54 | global_step,
55 | 1500,
56 | 0.96,
57 | staircase=True)
58 |
59 | image = tf.placeholder(
60 | shape=[None, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH],
61 | dtype=tf.float32,
62 | name='image_placeholder')
63 | label = tf.placeholder(shape=[None, GRID_H, GRID_W, N_ANCHORS, 8],
64 | dtype=tf.float32,
65 | name='label_placeholder')
66 | train_flag = tf.placeholder(dtype=tf.bool, name='flag_placeholder')
67 |
68 | with tf.variable_scope('net'):
69 | y = yolo_net(image, train_flag)
70 | with tf.name_scope('loss'):
71 | loss, loss_xy, loss_wh, loss_re, loss_im, loss_obj, loss_no_obj, loss_c = yolo_loss(
72 | y, label, batch_size)
73 |
74 | loss_xy_sum = tf.summary.scalar("loss_xy_sum", loss_xy)
75 | loss_wh_sum = tf.summary.scalar("loss_wh_sum", loss_wh)
76 | loss_re_sum = tf.summary.scalar("loss_re_sum", loss_re)
77 | loss_im_sum = tf.summary.scalar("loss_im_sum", loss_im)
78 | loss_obj_sum = tf.summary.scalar("loss_obj_sum", loss_obj)
79 | loss_no_obj_sum = tf.summary.scalar("loss_no_obj_sum", loss_no_obj)
80 | loss_c_sum = tf.summary.scalar("loss_c", loss_c)
81 | loss_sum = tf.summary.scalar("loss", loss)
82 | loss_tensorboard_sum = tf.summary.merge([
83 | loss_xy_sum, loss_wh_sum, loss_re_sum, loss_im_sum,
84 | loss_obj_sum, loss_no_obj_sum, loss_c_sum, loss_sum
85 | ])
86 | opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
87 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
88 | with tf.control_dependencies(update_ops):
89 | train_step = opt.minimize(loss, global_step=global_step)
90 | sess = tf.Session()
91 | sess.run(tf.global_variables_initializer())
92 | saver = tf.train.Saver()
93 | writer = tf.summary.FileWriter("./logs", sess.graph)
94 |
95 | if load_weights == 'True':
96 | print("load weights from {}".format(args.weights_path))
97 | saver = tf.train.import_meta_graph(args.weights_path + '.meta')
98 | saver.restore(sess, args.weights_path)
99 | print('load weights done!')
100 |
101 | for step, (train_image_data, train_label_data) in enumerate(
102 | train_dataset.get_batch(batch_size)):
103 | _, lr, train_loss, data, summary_str = sess.run(
104 | [train_step, learning_rate, loss, y, loss_tensorboard_sum],
105 | feed_dict={
106 | train_flag: True,
107 | image: train_image_data,
108 | label: train_label_data
109 | })
110 | writer.add_summary(summary_str, step)
111 |
112 | if step % 10 == 0:
113 | print('iter: %i, loss: %f, lr: %f' % (step, train_loss, lr))
114 | if (step + 1) % save_steps == 0:
115 | print("val...")
116 | val_loss = 0.0
117 | for val_step, (val_image_data, val_label_data) in enumerate(
118 | test_dataset.get_batch(batch_size)):
119 | val_loss += sess.run(loss,
120 | feed_dict={
121 | train_flag: False,
122 | image: val_image_data,
123 | label: val_label_data
124 | })
125 | if val_step + 1 == num_val_step:
126 | break
127 | val_loss /= num_val_step
128 | print("iter: {} val_loss: {:.2f}".format(step, val_loss))
129 | if val_loss < max_val_loss:
130 | saver.save(sess,
131 | os.path.join(
132 | args.save_dir,
133 | 'Complex_YOLO_train_loss_{:.2f}_val_loss_{:.2f}'.format(
134 | train_loss, val_loss)),
135 | global_step=global_step)
136 | max_val_loss = val_loss
137 | if step + 1 == args.num_iter:
138 | saver.save(sess,
139 | os.path.join(
140 | args.save_dir,
141 | 'Complex_YOLO_final_train_loss_{:.2f}'.format(
142 | train_loss)),
143 | global_step=global_step)
144 | print("training done!")
145 | break
146 |
147 |
148 | if __name__ == "__main__":
149 | print_info()
150 | train(load_weights=args.load_weights)
151 |
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import division
3 | import numpy as np
4 | import cv2
5 | import os
6 | from utils.kitti_utils import read_anchors_from_file, read_label_from_txt, \
7 | load_kitti_calib, get_target, remove_points, make_bv_feature
8 | from dataset.augument import RandomScaleAugmentation
9 | from model.model import encode_label
10 | img_h, img_w = 768, 1024
11 | grid_h, grid_w = 24, 32
12 | iou_th = 0.5
13 | boundary = {
14 | 'minX': 0,
15 | 'maxX': 80,
16 | 'minY': -40,
17 | 'maxY': 40,
18 | 'minZ': -2,
19 | 'maxZ': 1.25
20 | }
21 |
22 |
23 | class PointCloudDataset(object):
24 | def __init__(self,
25 | root='./kitti/',
26 | data_set='train'):
27 | self.root = root
28 | self.data_path = os.path.join(root, 'training')
29 | self.lidar_path = os.path.join(self.data_path, "velodyne")
30 | self.calib_path = os.path.join(self.data_path, "calib")
31 | self.label_path = os.path.join(self.data_path, "label_2")
32 | self.index_list = [str(i) for i in range(1000)] if data_set == "test" \
33 | else [str(i) for i in range(1000, 7481)]
34 |
35 | def getitem(self):
36 | """
37 | Encode single-frame point cloud data into RGB-map and get the label
38 | """
39 | for index in self.index_list:
40 | index = index.zfill(6)
41 | lidar_file = self.lidar_path + '/' + index + '.bin'
42 | calib_file = self.calib_path + '/' + index + '.txt'
43 | label_file = self.label_path + '/' + index + '.txt'
44 | calib = load_kitti_calib(calib_file)
45 | target = get_target(label_file, calib['Tr_velo2cam'])
46 | # load point cloud data
47 | point_cloud = np.fromfile(lidar_file,
48 | dtype=np.float32).reshape(-1, 4)
49 | b = remove_points(point_cloud, boundary)
50 | rgb_map = make_bv_feature(b) # (768, 1024, 3)
51 |
52 | yield index, rgb_map, target
53 |
54 |
55 | class ImageDataSet(object):
56 | """
57 | If there is enough memory, set load_to_memory=True,
58 | load the data into memory to improve training efficiency.
59 | """
60 | def __init__(self,
61 | data_set='train',
62 | mode='train',
63 | flip=True,
64 | random_scale=True,
65 | aug_hsv=False,
66 | load_to_memory=False):
67 | self.mode = mode
68 | self.flip = flip
69 | self.aug_hsv = aug_hsv
70 | self.random_scale = random_scale
71 | self.anchors_path = 'config/kitti_anchors.txt'
72 | self.labels_dir = 'kitti/image_dataset/labels/'
73 | self.images_dir = 'kitti/image_dataset/images/'
74 | self.all_image_index = 'config/' + data_set + '_image_list.txt'
75 | self.load_to_memory = load_to_memory
76 | self.anchors = read_anchors_from_file(self.anchors_path)
77 | self.rand_scale_transform = RandomScaleAugmentation(img_h, img_w)
78 | self.label = None
79 | self.img = None
80 | self.img_index = None
81 | self.label_encoded = None
82 | self.index_list = self.read_index_list()
83 | self.num_samples = len(self.index_list)
84 |
85 | def read_index_list(self):
86 | with open(self.all_image_index, 'r') as f:
87 | index_list = f.readlines()
88 | return index_list
89 |
90 | def horizontal_flip(self, image, target): # target: class,x,y,w,l,angle
91 | image = np.flip(image, 1) # image = image[:, ::-1, :]
92 | image_w = image.shape[1]
93 | target[:, 1] = image_w - target[:, 1]
94 | target[:, 5] = -target[:, 5]
95 | return image, target
96 |
97 | def augment_hsv(self, img):
98 | fraction = 0.30 # must be < 1.0
99 | img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) # hue, sat, val
100 | s = img_hsv[:, :, 1].astype(np.float32) # saturation
101 | v = img_hsv[:, :, 2].astype(np.float32) # value
102 | a = (np.random.random() * 2 - 1) * fraction + 1
103 | b = (np.random.random() * 2 - 1) * fraction + 1
104 | s *= a
105 | v *= b
106 | img_hsv[:, :, 1] = s if a < 1 else s.clip(None, 255)
107 | img_hsv[:, :, 2] = v if b < 1 else v.clip(None, 255)
108 | img = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB)
109 | return img
110 |
111 | def label_box_center_to_corner(self, label):
112 | """
113 | param label: class, cx, cy, w, l, angle
114 | return: class, x_min, y_min, x_max, y_max, angle
115 | """
116 | label_ = np.copy(label)
117 | cx = label_[:, 1]
118 | cy = label_[:, 2]
119 | w = label_[:, 3]
120 | l = label_[:, 4]
121 | label[:, 1] = cx - w / 2.0
122 | label[:, 2] = cy - l / 2.0
123 | label[:, 3] = cx + w / 2.0
124 | label[:, 4] = cy + l / 2.0
125 | return label
126 |
127 | def label_box_corner_to_center(self, label):
128 | """
129 | param label: class, x_min, y_min, x_max, y_max, angle
130 | return: class, cx, cy, w, l, angle
131 | """
132 | cx = (label[:, 1] + label[:, 3]) / 2.0
133 | cy = (label[:, 2] + label[:, 4]) / 2.0
134 | w = label[:, 3] - label[:, 1]
135 | l = label[:, 4] - label[:, 2]
136 | label[:, 1] = cx
137 | label[:, 2] = cy
138 | label[:, 3] = w
139 | label[:, 4] = l
140 | return label
141 |
142 | def data_generator(self):
143 |
144 | if self.load_to_memory:
145 | all_images = []
146 | all_labels = []
147 | all_index = []
148 | for index in self.index_list:
149 | index = index.strip()
150 | label_file = self.labels_dir + index + '.txt'
151 | label = read_label_from_txt(label_file)
152 | image_path = self.images_dir + index + '.png'
153 | img = cv2.imread(image_path)
154 | if img is None:
155 | print('failed to load image:' + image_path)
156 | continue
157 | img = np.flip(img, 2)
158 | all_index.append(index)
159 | all_images.append(img)
160 | all_labels.append(label)
161 | sample_index = [i for i in range(len(all_index))]
162 | while True:
163 | np.random.shuffle(sample_index)
164 | for i in sample_index:
165 | self.img_index = np.copy(all_index[i])
166 | self.label = np.copy(all_labels[i])
167 | self.img = np.copy(all_images[i])
168 | if self.aug_hsv:
169 | if np.random.random() > 0.5:
170 | self.img = self.augment_hsv(self.img)
171 | if self.flip:
172 | if np.random.random() > 0.5:
173 | self.img, self.label = self.horizontal_flip(
174 | self.img, self.label)
175 | if self.random_scale:
176 | self.label = self.label_box_center_to_corner(
177 | self.label)
178 | self.img, self.label = self.rand_scale_transform(
179 | self.img, self.label)
180 | self.label = self.label_box_corner_to_center(
181 | self.label)
182 | self.label_encoded = encode_label(
183 | self.label, self.anchors, img_w, img_h, grid_w,
184 | grid_h, iou_th)
185 | if self.mode == 'visualize': # Generate data for visualization
186 | yield self.img_index, self.img, self.label
187 | else:
188 | yield self.img_index, self.img / 255.0, self.label_encoded # Generate data for net
189 |
190 | else:
191 | while True:
192 | np.random.shuffle(self.index_list)
193 | for index in self.index_list:
194 | self.img_index = index.strip()
195 | label_file = self.labels_dir + self.img_index + '.txt'
196 | self.label = read_label_from_txt(label_file)
197 | image_path = self.images_dir + self.img_index + '.png'
198 | self.img = cv2.imread(image_path)
199 | if self.img is None:
200 | print('failed to load image:' + image_path)
201 | continue
202 | self.img = np.flip(self.img, 2)
203 |
204 | if self.aug_hsv:
205 | if np.random.random() > 0.5:
206 | self.img = self.augment_hsv(self.img)
207 | if self.flip:
208 | if np.random.random() > 0.5:
209 | self.img, self.label = self.horizontal_flip(
210 | self.img, self.label)
211 | if self.random_scale:
212 | self.label = self.label_box_center_to_corner(
213 | self.label)
214 | self.img, self.label = self.rand_scale_transform(
215 | self.img, self.label)
216 | self.label = self.label_box_corner_to_center(
217 | self.label)
218 | self.label_encoded = encode_label(
219 | self.label, self.anchors, img_w, img_h, grid_w,
220 | grid_h, iou_th)
221 | if self.mode == 'visualize': # Generate data for visualization
222 | yield self.img_index, self.img, self.label
223 | else:
224 | yield self.img_index, self.img / 255.0, self.label_encoded # Generate data for net
225 |
226 | def get_batch(self, batch_size):
227 | """
228 | Generate a batch of data for model training
229 | param batch_size (int):
230 |
231 | """
232 | img_batch = []
233 | label_batch = []
234 | i = 0
235 | for img_idx, img, label_encoded in self.data_generator():
236 | i += 1
237 | img_batch.append(img)
238 | label_batch.append(label_encoded)
239 | if i % batch_size == 0:
240 | yield np.array(img_batch), np.array(label_batch)
241 | i = 0
242 | img_batch = []
243 | label_batch = []
244 |
245 |
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | SCALE = 32
4 | GRID_W, GRID_H = 32, 24
5 | N_CLASSES = 8
6 | N_ANCHORS = 5
7 | IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH = GRID_H * SCALE, GRID_W * SCALE, 3
8 | class_dict = {
9 | 'Car': 0,
10 | 'Van': 1,
11 | 'Truck': 2,
12 | 'Pedestrian': 3,
13 | 'Person_sitting': 4,
14 | 'Cyclist': 5,
15 | 'Tram': 6,
16 | 'Misc': 7
17 | }
18 |
19 |
20 | def leak_relu(x, leak):
21 | return tf.maximum(x, leak * x, name='relu')
22 |
23 |
24 | def max_pool_layer(x, size, stride, name):
25 | with tf.name_scope(name):
26 | x = tf.layers.max_pooling2d(x, size, stride, padding='SAME')
27 | return x
28 |
29 |
30 | def conv_layer(x, kernel, depth, train_logical, name):
31 | with tf.variable_scope(name):
32 | x = tf.layers.conv2d(
33 | x,
34 | depth,
35 | kernel,
36 | padding='SAME',
37 | kernel_initializer=tf.contrib.layers.xavier_initializer_conv2d(),
38 | bias_initializer=tf.zeros_initializer())
39 | x = tf.layers.batch_normalization(x,
40 | training=train_logical,
41 | momentum=0.9,
42 | epsilon=0.001,
43 | center=True,
44 | scale=True)
45 | x = leak_relu(x, 0.2)
46 | # x = tf.nn.relu(x)
47 | return x
48 |
49 |
50 | def passthrough_layer(a, b, kernel, depth, size, train_logical, name):
51 | b = conv_layer(b, kernel, depth, train_logical, name)
52 | b = tf.space_to_depth(b, size)
53 | y = tf.concat([a, b], axis=3)
54 | return y
55 |
56 |
57 | def slice_tensor(x, start, end=None):
58 |     """
59 |     Slice the last dimension of a tensor.
60 |     param x (tensor):
61 |     param start (int): first channel to keep
62 |     param end (int): last channel to keep; None means only `start`,
63 |                      a negative value means "up to the last channel"
64 |     """
65 |     if end is None:
66 |         end = start
67 |     if end < 0:
68 |         y = x[..., start:]
69 |     else:
70 |         y = x[..., start:end + 1]
71 |     return y
72 |
73 |
74 | def iou_wh(box1_wh, box2_wh):
75 | """
76 | param box1_wh (list, tuple): Width and height of a box
77 | param box2_wh (list, tuple): Width and height of a box
78 | return (float): iou
79 | """
80 | min_w = min(box1_wh[0], box2_wh[0])
81 | min_h = min(box1_wh[1], box2_wh[1])
82 | area_r1 = box1_wh[0] * box1_wh[1]
83 | area_r2 = box2_wh[0] * box2_wh[1]
84 | intersect = min_w * min_h
85 | union = area_r1 + area_r2 - intersect
86 | return intersect / union
87 |
88 |
89 | def get_grid_cell(roi, img_w, img_h, grid_w, grid_h): # roi[x, y, w, h, rz]
90 | """
91 | Get the grid cell into which the object falls
92 | param roi : [x, y, w, h, rz]
93 | param img_w: The width of images
94 | param img_h: The height of images
95 | param grid_w:
96 | param grid_h:
97 | return (int, int):
98 | """
99 | x_center = roi[0]
100 | y_center = roi[1]
101 | grid_x = np.minimum(int(grid_w * x_center / img_w), grid_w-1)
102 | grid_y = np.minimum(int(grid_h * y_center / img_h), grid_h-1)
103 | return grid_x, grid_y
104 |
105 |
106 | def get_active_anchors(box_w_h, anchors, iou_th):
107 | """
108 | Get the index of the anchor that matches the ground truth box
109 | param box_w_h (list, tuple): Width and height of a box
110 | param anchors (array): anchors
111 | param iou_th: Match threshold
112 | return (list):
113 | """
114 | index = []
115 | iou_max, index_max = 0, 0
116 | for i, a in enumerate(anchors):
117 | iou = iou_wh(box_w_h, a)
118 | if iou > iou_th:
119 | index.append(i)
120 | if iou > iou_max:
121 | iou_max, index_max = iou, i
122 | if len(index) == 0:
123 | index.append(index_max)
124 | return index
125 |
126 |
127 | def roi2label(roi, anchor, img_w, img_h, grid_w, grid_h):
128 | """
129 | Encode the label to match the model output format
130 | param roi: x, y, w, h, angle
131 |
132 | return: encoded label
133 | """
134 | x_center = roi[0]
135 | y_center = roi[1]
136 | w = grid_w * roi[2] / img_w
137 | h = grid_h * roi[3] / img_h
138 | anchor_w = grid_w * anchor[0] / img_w
139 | anchor_h = grid_h * anchor[1] / img_h
140 | grid_x = grid_w * x_center / img_w
141 | grid_y = grid_h * y_center / img_h
142 | grid_x_offset = grid_x - int(grid_x)
143 | grid_y_offset = grid_y - int(grid_y)
144 | roi_w_scale = np.log(w / anchor_w + 1e-16)
145 | roi_h_scale = np.log(h / anchor_h + 1e-16)
146 | re = np.cos(roi[4])
147 | im = np.sin(roi[4])
148 | label = [grid_x_offset, grid_y_offset, roi_w_scale, roi_h_scale, re, im]
149 | return label
150 |
151 |
152 | def encode_label(labels, anchors, img_w, img_h, grid_w, grid_h, iou_th):
153 | """
154 | Encode the label to match the model output format
155 | param labels (array): x, y, w, h, angle
156 | param anchors (array): anchors
157 | return: encoded label
158 | """
159 | anchors_on_image = np.array([img_w, img_h]) * anchors / np.array([80, 60])
160 | n_anchors = np.shape(anchors_on_image)[0]
161 | label_encoded = np.zeros([grid_h, grid_w, n_anchors, (6 + 1 + 1)],
162 | dtype=np.float32)
163 | for i in range(labels.shape[0]):
164 | rois = labels[i][1:]
165 | classes = np.array(labels[i][0], dtype=np.int32)
166 | active_indexes = get_active_anchors(rois[2:4], anchors_on_image, iou_th)
167 | grid_x, grid_y = get_grid_cell(rois, img_w, img_h, grid_w, grid_h)
168 | for active_index in active_indexes:
169 | anchor_label = roi2label(rois, anchors_on_image[active_index],
170 | img_w, img_h, grid_w, grid_h)
171 | label_encoded[grid_y, grid_x, active_index] = np.concatenate(
172 | (anchor_label, [classes], [1.0]))
173 | return label_encoded
174 |
175 |
176 | def yolo_net(x, train_logical):
177 | """darknet"""
178 | x = conv_layer(x, (3, 3), 24, train_logical, 'conv1')
179 | x = max_pool_layer(x, (2, 2), (2, 2), 'maxpool1')
180 | x = conv_layer(x, (3, 3), 48, train_logical, 'conv2')
181 | x = max_pool_layer(x, (2, 2), (2, 2), 'maxpool2')
182 |
183 | x = conv_layer(x, (3, 3), 64, train_logical, 'conv3')
184 | x = conv_layer(x, (1, 1), 32, train_logical, 'conv4')
185 | x = conv_layer(x, (3, 3), 64, train_logical, 'conv5')
186 | x = max_pool_layer(x, (2, 2), (2, 2), 'maxpool5')
187 |
188 | x = conv_layer(x, (3, 3), 128, train_logical, 'conv6')
189 | x = conv_layer(x, (1, 1), 64, train_logical, 'conv7')
190 | x = conv_layer(x, (3, 3), 128, train_logical, 'conv8')
191 | x = max_pool_layer(x, (2, 2), (2, 2), 'maxpool8')
192 |
193 | # x = conv_layer(x, (3, 3), 512, train_logical, 'conv9')
194 | # x = conv_layer(x, (1, 1), 256, train_logical, 'conv10')
195 | x = conv_layer(x, (3, 3), 512, train_logical, 'conv11')
196 | x = conv_layer(x, (1, 1), 256, train_logical, 'conv12')
197 | passthrough = conv_layer(x, (3, 3), 512, train_logical, 'conv13')
198 | x = max_pool_layer(passthrough, (2, 2), (2, 2), 'maxpool13')
199 |
200 | # x = conv_layer(x, (3, 3), 1024, train_logical, 'conv14')
201 | # x = conv_layer(x, (1, 1), 512, train_logical, 'conv15')
202 | x = conv_layer(x, (3, 3), 1024, train_logical, 'conv16')
203 | x = conv_layer(x, (1, 1), 512, train_logical, 'conv17')
204 | x = conv_layer(x, (3, 3), 1024, train_logical, 'conv18')
205 |
206 | x = passthrough_layer(x, passthrough, (3, 3), 64, 2, train_logical,
207 | 'conv21')
208 | x = conv_layer(x, (3, 3), 1024, train_logical, 'conv19')
209 | x = conv_layer(x, (1, 1), N_ANCHORS * (7 + N_CLASSES), train_logical,
210 | 'conv20') # x,y,w,l,re,im,conf + 8 class
211 | y = tf.reshape(x,
212 | shape=(-1, GRID_H, GRID_W, N_ANCHORS, 7 + N_CLASSES),
213 | name='y')
214 | return y
215 |
216 |
217 | def yolo_loss(pred, label, batch_size):
218 | mask = slice_tensor(label, 7, 7)
219 | label = slice_tensor(label, 0, 6)
220 | mask = tf.cast(tf.reshape(mask, shape=(-1, GRID_H, GRID_W, N_ANCHORS)),
221 | tf.bool)
222 | with tf.name_scope('mask'):
223 | masked_label = tf.boolean_mask(label, mask)
224 | masked_pred = tf.boolean_mask(pred, mask)
225 | neg_masked_pred = tf.boolean_mask(pred, tf.logical_not(mask))
226 | with tf.name_scope('pred'):
227 | masked_pred_xy = tf.sigmoid(slice_tensor(masked_pred, 0, 1))
228 | masked_pred_wh = slice_tensor(masked_pred, 2, 3)
229 | masked_pred_re = 2 * tf.sigmoid(slice_tensor(masked_pred, 4, 4)) - 1
230 | masked_pred_im = 2 * tf.sigmoid(slice_tensor(masked_pred, 5, 5)) - 1
231 | masked_pred_o = tf.sigmoid(slice_tensor(masked_pred, 6, 6))
232 |
233 | masked_pred_no_o = tf.sigmoid(slice_tensor(neg_masked_pred, 6, 6))
234 | # masked_pred_c = tf.nn.sigmoid(slice_tensor(masked_pred, 7, -1))
235 | masked_pred_c = tf.nn.softmax(slice_tensor(masked_pred, 7, -1))
236 | # masked_pred_no_c = tf.nn.sigmoid(slice_tensor(neg_masked_pred, 7, -1))
237 | # print (masked_pred_c, masked_pred_o, masked_pred_no_o)
238 |
239 | with tf.name_scope('lab'):
240 | masked_label_xy = slice_tensor(masked_label, 0, 1)
241 | masked_label_wh = slice_tensor(masked_label, 2, 3)
242 | masked_label_re = slice_tensor(masked_label, 4, 4)
243 | masked_label_im = slice_tensor(masked_label, 5, 5)
244 | masked_label_class = slice_tensor(masked_label, 6, 6)
245 | masked_label_class_vec = tf.reshape(tf.one_hot(tf.cast(
246 | masked_label_class, tf.int32),
247 | depth=N_CLASSES),
248 | shape=(-1, N_CLASSES))
249 | with tf.name_scope('merge'):
250 | with tf.name_scope('loss_xy'):
251 | loss_xy = tf.reduce_sum(
252 | tf.square(masked_pred_xy - masked_label_xy)) / batch_size
253 | with tf.name_scope('loss_wh'):
254 | loss_wh = tf.reduce_sum(
255 | tf.square(masked_pred_wh - masked_label_wh)) / batch_size
256 | with tf.name_scope('loss_re'):
257 | loss_re = tf.reduce_sum(
258 | tf.square(masked_pred_re - masked_label_re)) / batch_size
259 | with tf.name_scope('loss_im'):
260 | loss_im = tf.reduce_sum(
261 | tf.square(masked_pred_im - masked_label_im)) / batch_size
262 | with tf.name_scope('loss_obj'):
263 | loss_obj = tf.reduce_sum(tf.square(masked_pred_o - 1)) / batch_size
264 | # loss_obj = tf.reduce_sum(-tf.log(masked_pred_o+0.000001))*10
265 | with tf.name_scope('loss_no_obj'):
266 | loss_no_obj = tf.reduce_sum(
267 | tf.square(masked_pred_no_o)) * 0.5 / batch_size
268 | # loss_no_obj = tf.reduce_sum(-tf.log(1-masked_pred_no_o+0.000001))
269 | with tf.name_scope('loss_class'):
270 | # loss_c = tf.reduce_sum(tf.square(masked_pred_c - masked_label_c_vec))
271 | loss_c = (tf.reduce_sum(-tf.log(masked_pred_c + 0.000001) * masked_label_class_vec)
272 | + tf.reduce_sum(-tf.log(1 - masked_pred_c + 0.000001) * (1 - masked_label_class_vec))) / batch_size
273 | # + tf.reduce_sum(-tf.log(1 - masked_pred_no_c+0.000001)) * 0.1
274 | # loss = (loss_xy + loss_wh+ loss_re + loss_im+ lambda_coord*loss_obj) + lambda_no_obj*loss_no_obj + loss_c
275 | loss = (loss_xy + loss_wh + loss_re +
276 | loss_im) * 5 + loss_obj + loss_no_obj + loss_c
277 | return loss, loss_xy, loss_wh, loss_re, loss_im, loss_obj, loss_no_obj, loss_c
278 |
--------------------------------------------------------------------------------
/utils/kitti_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import division
3 | import numpy as np
4 | import cv2
5 |
6 | # classes
7 | class_list = [
8 | 'Car', 'Van', 'Truck', 'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram',
9 | 'Misc'
10 | ]
11 | class_dict = {
12 | 'Car': 0,
13 | 'Van': 1,
14 | 'Truck': 2,
15 | 'Pedestrian': 3,
16 | 'Person_sitting': 4,
17 | 'Cyclist': 5,
18 | 'Tram': 6,
19 | 'Misc': 7
20 | }
21 |
22 |
23 | def read_label_from_txt(label_path):
24 | """
25 | Read label from txt file.
26 | label: class, cx, cy, w, l, rz(rotation angle)
27 | """
28 | bounding_box = []
29 | with open(label_path, "r") as f:
30 | labels = f.readlines()
31 | for label in labels:
32 | if not label:
33 | continue
34 | label = label.strip().split(' ')
35 | label[0] = class_dict[label[0]]
36 | bounding_box.append(label)
37 | if bounding_box:
38 | return np.array(bounding_box, dtype=np.float32)
39 | else:
40 | return None
41 |
42 |
43 | def read_anchors_from_file(file_path):
44 | """
45 | Read anchors from the configuration file
46 | """
47 | anchors = []
48 | with open(file_path, 'r') as file:
49 | for line in file.read().splitlines():
50 | anchors.append(list(map(float, line.split())))
51 | return np.array(anchors)
52 |
53 |
54 | def read_class_flag(file_path):
55 | """
56 | Read class flags for visualization
57 | """
58 | classes, names, colors = [], [], []
59 | with open(file_path, 'r') as file:
60 | lines = file.read().splitlines()
61 | for line in lines:
62 | cls, name, color = line.split()
63 | classes.append(int(cls))
64 | names.append(name)
65 | colors.append(eval(color))
66 | return classes, names, colors
67 |
68 |
69 | def calculate_angle(im, re):
70 | """
71 |     param: im(float): imaginary part of the complex number
72 |     param: re(float): real part of the complex number
73 |     return: the angle (rz) at which the object is rotated
74 |     around the Z axis in the velodyne coordinate system
75 | """
76 | if re > 0:
77 | return np.arctan(im / re)
78 | elif im < 0:
79 | return -np.pi + np.arctan(im / re)
80 | else:
81 | return np.pi + np.arctan(im / re)
82 |
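Note (editorial): away from re == 0 the branching above is equivalent to np.arctan2(im, re). A quick self-check under that reading, assuming this module is imported (not part of the original file):

    rz = 2.5
    assert np.isclose(calculate_angle(np.sin(rz), np.cos(rz)), rz)
    assert np.isclose(calculate_angle(-0.5, -0.5), np.arctan2(-0.5, -0.5))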
83 |
84 | def draw_rotated_box(img, cy, cx, w, h, angle, color):
85 | """
86 | param: img(array): RGB image
87 | param: cy(int, float): box center; corresponds to the x (column) coordinate in the image
88 | param: cx(int, float): box center; corresponds to the y (row) coordinate in the image
89 | param: w(int, float): box's width
90 | param: h(int, float): box's height
91 | param: angle(float): rz
92 | param: color(tuple, list): the color of the box, (R, G, B)
93 | """
94 | left = int(cy - w / 2)
95 | top = int(cx - h / 2)
96 | right = int(cx + h / 2)
97 | bottom = int(cy + h / 2)
98 | ro = np.sqrt(pow(left - cy, 2) + pow(top - cx, 2))
99 | a1 = np.arctan((w / 2) / (h / 2))
100 | a2 = -np.arctan((w / 2) / (h / 2))
101 | a3 = -np.pi + a1
102 | a4 = np.pi - a1
103 | rotated_p1_y = cy + int(ro * np.sin(angle + a1))
104 | rotated_p1_x = cx + int(ro * np.cos(angle + a1))
105 | rotated_p2_y = cy + int(ro * np.sin(angle + a2))
106 | rotated_p2_x = cx + int(ro * np.cos(angle + a2))
107 | rotated_p3_y = cy + int(ro * np.sin(angle + a3))
108 | rotated_p3_x = cx + int(ro * np.cos(angle + a3))
109 | rotated_p4_y = cy + int(ro * np.sin(angle + a4))
110 | rotated_p4_x = cx + int(ro * np.cos(angle + a4))
111 | center_p1p2y = int((rotated_p1_y + rotated_p2_y) * 0.5)
112 | center_p1p2x = int((rotated_p1_x + rotated_p2_x) * 0.5)
113 | cv2.line(img, (rotated_p1_y, rotated_p1_x), (rotated_p2_y, rotated_p2_x),
114 | color, 1)
115 | cv2.line(img, (rotated_p2_y, rotated_p2_x), (rotated_p3_y, rotated_p3_x),
116 | color, 1)
117 | cv2.line(img, (rotated_p3_y, rotated_p3_x), (rotated_p4_y, rotated_p4_x),
118 | color, 1)
119 | cv2.line(img, (rotated_p4_y, rotated_p4_x), (rotated_p1_y, rotated_p1_x),
120 | color, 1)
121 | cv2.line(img, (center_p1p2y, center_p1p2x), (cy, cx), color, 1)
122 |
123 |
124 | def get_corner_gtbox(box):
125 | """
126 | param: box(tuple, list): cx, cy, w, l
127 | return: (tuple): x_min, y_min, x_max, y_max
128 | """
129 | bx = box[0]
130 | by = box[1]
131 | bw = box[2]
132 | bl = box[3]
133 | top = int((by - bl / 2.0))
134 | left = int((bx - bw / 2.0))
135 | right = int((bx + bw / 2.0))
136 | bottom = int((by + bl / 2.0))
137 | return left, top, right, bottom
138 |
139 |
140 | def remove_points(point_cloud, boundary_condition):
141 | """
142 | param point_cloud(array): Original point cloud data
143 | param boundary_condition(dict): The boundary of the area of interest
144 | return (array): Point cloud data within the area of interest
145 | """
146 | # Boundary condition
147 | min_x = boundary_condition['minX']
148 | max_x = boundary_condition['maxX']
149 | min_y = boundary_condition['minY']
150 | max_y = boundary_condition['maxY']
151 | min_z = boundary_condition['minZ']
152 | max_z = boundary_condition['maxZ']
153 | # Remove the point out of range x,y,z
154 | mask = np.where((point_cloud[:, 0] >= min_x) & (point_cloud[:, 0] <= max_x)
155 | & (point_cloud[:, 1] >= min_y) & (point_cloud[:, 1] <= max_y)
156 | & (point_cloud[:, 2] >= min_z) & (point_cloud[:, 2] <= max_z))
157 | point_cloud = point_cloud[mask]
158 | point_cloud[:, 2] = point_cloud[:, 2] + 2
159 | return point_cloud
160 |
161 |
162 | def make_bv_feature(point_cloud_):
163 | """
164 | param point_cloud_ (array): Point cloud data within the area of interest
165 | return (array): RGB map
166 | """
167 | # 1024 x 1024 x 3
168 | Height = 1024 + 1
169 | Width = 1024 + 1
170 | # Discretize Feature Map
171 | point_cloud = np.copy(point_cloud_)
172 | point_cloud[:, 0] = np.int_(np.floor(point_cloud[:, 0] / 60.0 * 768))
173 | point_cloud[:, 1] = np.int_(
174 | np.floor(point_cloud[:, 1] / 40.0 * 512) + Width / 2)
175 | # sort-3times
176 | indices = np.lexsort(
177 | (-point_cloud[:, 2], point_cloud[:, 1], point_cloud[:, 0]))
178 | point_cloud = point_cloud[indices]
179 | # Height Map
180 | height_map = np.zeros((Height, Width))
181 | _, indices = np.unique(point_cloud[:, 0:2], axis=0, return_index=True)
182 | point_cloud_frac = point_cloud[indices]
183 | # some important problem is image coordinate is (y,x), not (x,y)
184 | height_map[np.int_(point_cloud_frac[:, 0]),
185 | np.int_(point_cloud_frac[:, 1])] = point_cloud_frac[:, 2]
186 | # Intensity Map & DensityMap
187 | intensity_map = np.zeros((Height, Width))
188 | density_map = np.zeros((Height, Width))
189 | _, indices, counts = np.unique(point_cloud[:, 0:2],
190 | axis=0,
191 | return_index=True,
192 | return_counts=True)
193 | point_cloud_top = point_cloud[indices]
194 | normalized_counts = np.minimum(1.0, np.log(counts + 1) / np.log(64))
195 | intensity_map[np.int_(point_cloud_top[:, 0]),
196 | np.int_(point_cloud_top[:, 1])] = point_cloud_top[:, 3]
197 | density_map[np.int_(point_cloud_top[:, 0]),
198 | np.int_(point_cloud_top[:, 1])] = normalized_counts
199 |
200 | rgb_map = np.zeros((Height, Width, 3))
201 | rgb_map[:, :, 0] = density_map # r_map
202 | rgb_map[:, :, 1] = height_map / 3.26 # g_map
203 | rgb_map[:, :, 2] = intensity_map # b_map
204 |
205 | save = np.zeros((768, 1024, 3))
206 | save = rgb_map[0:768, 0:1024, :]
207 | return save
208 |
209 |
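Together, remove_points and make_bv_feature turn a raw velodyne scan into the 768 x 1024 x 3 bird's-eye-view map (density, height, intensity) that the network consumes. A minimal usage sketch; the file name is hypothetical and the Z bounds are an assumption consistent with the +2 shift and the /3.26 normalization above:

    points = np.fromfile('kitti/training/velodyne/000000.bin',
                         dtype=np.float32).reshape(-1, 4)
    boundary = {'minX': 0, 'maxX': 60, 'minY': -40, 'maxY': 40,
                'minZ': -2, 'maxZ': 1.25}
    roi_points = remove_points(points, boundary)
    rgb_map = make_bv_feature(roi_points)  # shape (768, 1024, 3)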
210 | def get_target(label_file, transform):
211 | """
212 |
213 | param label_file (str): path to the KITTI label file
214 | param transform (array): velodyne-to-camera calibration matrix (inverted internally)
215 | return (array): encoded target array of shape (50, 6)
216 | """
217 | target = np.zeros([50, 6], dtype=np.float32)
218 | with open(label_file, 'r') as f:
219 | lines = f.readlines()
220 | num_obj = len(lines)
221 | index = 0
222 | for j in range(num_obj):
223 | obj = lines[j].strip().split(' ')
224 | obj_class = obj[0].strip()
225 | if obj_class in class_list:
226 | t_lidar, box3d_corner, rz = box3d_cam_to_velo(
227 | obj[8:], transform) # get target 3D object location x,y
228 | location_x = t_lidar[0][0]
229 | location_y = t_lidar[0][1]
230 | if (location_x > 0) & (location_x < 60) & (location_y > -40) & (
231 | location_y < 40):
232 | target[index][2] = t_lidar[0][0] / 60.0 # normalize x into (0, 1) over the 60 m covered range
233 | target[index][1] = (t_lidar[0][1] + 40) / 80.0 # shift and normalize y into (0, 1) over the 80 m covered range
234 | obj_width = obj[9].strip()
235 | obj_length = obj[10].strip()
236 | target[index][3] = float(obj_width) / 80.0
237 | target[index][4] = float(obj_length) / 60.0 # get target width ,length
238 | target[index][5] = rz
239 | for i in range(len(class_list)):
240 | if obj_class == class_list[i]: # get target class
241 | target[index][0] = i
242 | index = index + 1
243 | return target
244 |
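For reference, reading the assignments above, each populated row of the returned target is [class_id, (y + 40) / 80, x / 60, w / 80, l / 60, rz] in velodyne coordinates. Decoding one row back to metric values (illustration only):

    cls_id, y_n, x_n, w_n, l_n, rz = target[0]
    velo_x = x_n * 60.0           # forward position in metres
    velo_y = y_n * 80.0 - 40.0    # lateral position in metres
    width, length = w_n * 80.0, l_n * 60.0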
245 |
246 | def box3d_cam_to_velo(box3d, transform):
247 | def project_cam2velo(cam, transform):
248 | T = np.zeros([4, 4], dtype=np.float32)
249 | T[:3, :] = transform
250 | T[3, 3] = 1
251 | T_inv = np.linalg.inv(T)
252 | lidar_loc_ = np.dot(T_inv, cam)
253 | lidar_loc = lidar_loc_[:3]
254 | return lidar_loc.reshape(1, 3)
255 |
256 | def ry_to_rz(ry):
257 | """
258 | param ry (float): yaw angle in cam coordinate system
259 | return: (float): yaw angle in velodyne coordinate system
260 | """
261 | angle = -ry - np.pi / 2
262 | if angle >= np.pi:
263 | angle -= np.pi
264 | if angle < -np.pi:
265 | angle = 2 * np.pi + angle
266 | return angle
267 |
268 | h, w, l, tx, ty, tz, ry = [float(i) for i in box3d]
269 | cam = np.ones([4, 1])
270 | cam[0] = tx
271 | cam[1] = ty
272 | cam[2] = tz
273 | t_lidar = project_cam2velo(cam, transform)
274 | box = np.array(
275 | [[-l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2],
276 | [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2],
277 | [0, 0, 0, 0, h, h, h, h]])
278 |
279 | rz = ry_to_rz(ry)
280 | rot_mat = np.array([[np.cos(rz), -np.sin(rz), 0.0],
281 | [np.sin(rz), np.cos(rz), 0.0], [0.0, 0.0, 1.0]])
282 | velo_box = np.dot(rot_mat, box)
283 | corner_pos_in_velo = velo_box + np.tile(t_lidar, (8, 1)).T
284 | box3d_corner = corner_pos_in_velo.transpose()
285 | return t_lidar, box3d_corner.astype(np.float32), rz
286 |
287 |
288 | def load_kitti_calib(calib_file):
289 | """
290 | load projection matrix
291 | """
292 | with open(calib_file) as fi:
293 | lines = fi.readlines()
294 | assert (len(lines) == 8)
295 |
296 | obj = lines[0].strip().split(' ')[1:]
297 | P0 = np.array(obj, dtype=np.float32)
298 | obj = lines[1].strip().split(' ')[1:]
299 | P1 = np.array(obj, dtype=np.float32)
300 | obj = lines[2].strip().split(' ')[1:]
301 | P2 = np.array(obj, dtype=np.float32)
302 | obj = lines[3].strip().split(' ')[1:]
303 | P3 = np.array(obj, dtype=np.float32)
304 | obj = lines[4].strip().split(' ')[1:]
305 | R0 = np.array(obj, dtype=np.float32)
306 | obj = lines[5].strip().split(' ')[1:]
307 | Tr_velo_to_cam = np.array(obj, dtype=np.float32)
308 | obj = lines[6].strip().split(' ')[1:]
309 | Tr_imu_to_velo = np.array(obj, dtype=np.float32)
310 |
311 | return {
312 | 'P2': P2.reshape(3, 4),
313 | 'R0': R0.reshape(3, 3),
314 | 'Tr_velo2cam': Tr_velo_to_cam.reshape(3, 4)
315 | }
316 |
317 |
318 | def angle_rz_to_ry(rz):
319 | """
320 | param rz (float): yaw angle in velodyne coordinate system
321 | return (float): yaw angle in cam coordinate system
322 | """
323 | angle = -rz - np.pi / 2
324 | if angle < -np.pi:
325 | angle = 2 * np.pi + angle
326 | return angle
327 |
328 |
329 | def coord_image_to_velo(hy, wx):
330 | """
331 | Convert image coordinates to velodyne coordinates
332 | """
333 | velo_x = hy * 60 / 768.0
334 | velo_y = (wx - 512) * 40 / 512.0
335 | return velo_x, velo_y
336 |
337 |
338 | def coord_velo_to_cam(velo_x, velo_y, transform):
339 | """
340 | Convert velodyne coordinates to camera coordinates
341 | """
342 | T = np.zeros([4, 4], dtype=np.float32)
343 | T[:3, :] = transform
344 | T[3, 3] = 1
345 | velo_coord = np.array([[velo_x], [velo_y], [1.5], [1]])
346 | cam_coord = np.dot(T, velo_coord)
347 | return cam_coord[0][0], cam_coord[2][0]
348 |
349 |
350 | # anchors = [1.08, 1.19, 3.42, 4.41, 6.63, 11.38, 9.42, 5.11, 16.62, 10.52]
351 |
--------------------------------------------------------------------------------
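At prediction time these helpers compose in the opposite direction: a box decoded on the bird's-eye-view map is mapped back to velodyne and then camera coordinates, and the predicted (im, re) pair becomes a KITTI-style yaw. A hedged sketch; the variable names, the sample index and the network outputs are illustrative, not taken from predict.py:

    calib = load_kitti_calib('kitti/training/calib/000000.txt')
    # hypothetical network output for one box, in image (pixel) coordinates
    img_row, img_col, pred_im, pred_re = 400.0, 520.0, 0.6, 0.8

    velo_x, velo_y = coord_image_to_velo(img_row, img_col)
    cam_x, cam_z = coord_velo_to_cam(velo_x, velo_y, calib['Tr_velo2cam'])
    rz = calculate_angle(pred_im, pred_re)
    ry = angle_rz_to_ry(rz)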
/dataset/augument.py:
--------------------------------------------------------------------------------
1 | """
2 | The data augmentation operations of the original SSD implementation.
3 | Copyright (C) 2018 Pierluigi Ferrari
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | http://www.apache.org/licenses/LICENSE-2.0
8 | Unless required by applicable law or agreed to in writing, software
9 | distributed under the License is distributed on an "AS IS" BASIS,
10 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | See the License for the specific language governing permissions and
12 | limitations under the License.
13 | """
14 |
15 | import cv2
16 | import inspect
17 | import numpy as np
18 |
19 |
20 | class RandomScaleAugmentation:
21 | """
22 | Reproduces the data augmentation pipeline used in the training of the original
23 | Caffe implementation of SSD.
24 | """
25 |
26 | def __init__(self,
27 | img_height=768,
28 | img_width=1024,
29 | labels_format={
30 | 'class_id': 0,
31 | 'xmin': 1,
32 | 'ymin': 2,
33 | 'xmax': 3,
34 | 'ymax': 4
35 | }):
36 | """
37 | Arguments:
38 | img_height (int): The desired height of the output images in pixels.
39 | img_width (int): The desired width of the output images in pixels.
40 | (The `background` color argument of the original SSD chain is not exposed
41 | here; cropped or padded regions default to black.)
42 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels
43 | of an image contains which bounding box coordinate. The dictionary maps at least the keywords
44 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array.
45 | """
46 |
47 | self.labels_format = labels_format
48 | self.random_crop = SSDRandomCrop(labels_format=self.labels_format)
49 | # This box filter makes sure that the resized images don't contain any degenerate boxes.
50 | # Resizing the images could cause the boxes to become smaller. For boxes that are already
51 | # pretty small, that might result in boxes with height and/or width zero, which we obviously
52 | # cannot allow.
53 | self.box_filter = BoxFilter(check_overlap=False,
54 | check_min_area=False,
55 | check_degenerate=True,
56 | labels_format=self.labels_format)
57 |
58 | self.resize = ResizeRandomInterp(height=img_height,
59 | width=img_width,
60 | interpolation_modes=[
61 | cv2.INTER_NEAREST,
62 | cv2.INTER_LINEAR, cv2.INTER_CUBIC,
63 | cv2.INTER_AREA, cv2.INTER_LANCZOS4
64 | ],
65 | box_filter=self.box_filter,
66 | labels_format=self.labels_format)
67 |
68 | self.sequence = [self.random_crop, self.resize]
69 |
70 | def __call__(self, image, labels, return_inverter=False):
71 | self.random_crop.labels_format = self.labels_format
72 | self.resize.labels_format = self.labels_format
73 | inverters = []
74 | for transform in self.sequence:
75 | if return_inverter and ('return_inverter' in inspect.signature(
76 | transform).parameters):
77 | image, labels, inverter = transform(image,
78 | labels,
79 | return_inverter=True)
80 | inverters.append(inverter)
81 | else:
82 | image, labels = transform(image, labels)
83 |
84 | if return_inverter:
85 | return image, labels, inverters[::-1]
86 | else:
87 | return image, labels
88 |
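In sketch form, the augmentation pipeline is used like this (the image and the corner-format labels are hypothetical; the exact call site in dataset.py may differ):

    augment = RandomScaleAugmentation(img_height=768, img_width=1024)
    image = np.zeros((768, 1024, 3), dtype=np.uint8)
    # one row per box: [class_id, xmin, ymin, xmax, ymax] in pixels
    labels = np.array([[0, 300, 200, 340, 260]], dtype=np.float32)
    aug_image, aug_labels = augment(image, labels)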
89 |
90 | class SSDRandomCrop:
91 | """
92 | Performs the same random crops as defined by the `batch_sampler` instructions
93 | of the original Caffe implementation of SSD. A description of this random cropping
94 | strategy can also be found in the data augmentation section of the paper:
95 | https://arxiv.org/abs/1512.02325
96 | """
97 |
98 | def __init__(self,
99 | labels_format={
100 | 'class_id': 0,
101 | 'xmin': 1,
102 | 'ymin': 2,
103 | 'xmax': 3,
104 | 'ymax': 4
105 | }):
106 | """
107 | Arguments:
108 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels
109 | of an image contains which bounding box coordinate. The dictionary maps at least the keywords
110 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array.
111 | """
112 |
113 | self.labels_format = labels_format
114 |
115 | # This randomly samples one of the lower IoU bounds defined
116 | # by the `sample_space` every time it is called.
117 | self.bound_generator = BoundGenerator(
118 | sample_space=((None, None), (0.1, None), (0.3, None), (0.5, None),
119 | (0.7, None), (0.9, None)),
120 | weights=None)
121 |
122 | # Produces coordinates for candidate patches such that the height
123 | # and width of the patches are between 0.6 and 1.0 of the height
124 | # and width of the respective image and the aspect ratio of the
125 | # patches is between 0.5 and 2.0.
126 | self.patch_coord_generator = PatchCoordinateGenerator(
127 | must_match='h_w',
128 | min_scale=0.6,
129 | max_scale=1.0,
130 | scale_uniformly=False,
131 | min_aspect_ratio=0.5,
132 | max_aspect_ratio=2.0)
133 |
134 | # Filters out boxes whose center point does not lie within the
135 | # chosen patches.
136 | self.box_filter = BoxFilter(check_overlap=True,
137 | check_min_area=False,
138 | check_degenerate=False,
139 | overlap_criterion='center_point',
140 | labels_format=self.labels_format)
141 |
142 | # Determines whether a given patch is considered a valid patch.
143 | # Defines a patch to be valid if at least one ground truth bounding box
144 | # (n_boxes_min == 1) has an IoU overlap with the patch that
145 | # meets the requirements defined by `bound_generator`.
146 | self.image_validator = ImageValidator(overlap_criterion='iou',
147 | n_boxes_min=1,
148 | labels_format=self.labels_format,
149 | border_pixels='half')
150 |
151 | # Performs crops according to the parameters set in the objects above.
152 | # Runs until either a valid patch is found or the original input image
153 | # is returned unaltered. Runs a maximum of 50 trials to find a valid
154 | # patch for each new sampled IoU threshold. Every 50 trials, the original
155 | # image is returned as is with probability (1 - prob) = 0.143.
156 | self.random_crop = RandomPatchInf(
157 | patch_coord_generator=self.patch_coord_generator,
158 | box_filter=self.box_filter,
159 | image_validator=self.image_validator,
160 | bound_generator=self.bound_generator,
161 | n_trials_max=50,
162 | clip_boxes=True,
163 | prob=0.857,
164 | labels_format=self.labels_format)
165 |
166 | def __call__(self, image, labels=None, return_inverter=False):
167 | self.random_crop.labels_format = self.labels_format
168 | return self.random_crop(image, labels, return_inverter)
169 |
170 |
171 | class BoundGenerator:
172 | """
173 | Generates pairs of floating point values that represent lower and upper bounds
174 | from a given sample space.
175 | """
176 |
177 | def __init__(self,
178 | sample_space=((0.1, None), (0.3, None), (0.5, None),
179 | (0.7, None), (0.9, None), (None, None)),
180 | weights=None):
181 | """
182 | Arguments:
183 | sample_space (list or tuple): A list, tuple, or array-like object of shape
184 | `(n, 2)` that contains `n` samples to choose from, where each sample
185 | is a 2-tuple of scalars and/or `None` values.
186 | weights (list or tuple, optional): A list or tuple representing the distribution
187 | over the sample space. If `None`, a uniform distribution will be assumed.
188 | """
189 |
190 | if (not (weights is None)) and len(weights) != len(sample_space):
191 | raise ValueError(
192 | "`weights` must either be `None` for uniform distribution or have the same length as `sample_space`."
193 | )
194 |
195 | self.sample_space = []
196 | for bound_pair in sample_space:
197 | if len(bound_pair) != 2:
198 | raise ValueError(
199 | "All elements of the sample space must be 2-tuples.")
200 | bound_pair = list(bound_pair)
201 | if bound_pair[0] is None: bound_pair[0] = 0.0
202 | if bound_pair[1] is None: bound_pair[1] = 1.0
203 | if bound_pair[0] > bound_pair[1]:
204 | raise ValueError(
205 | "For all sample space elements, the lower bound "
206 | "cannot be greater than the upper bound.")
207 | self.sample_space.append(bound_pair)
208 |
209 | self.sample_space_size = len(self.sample_space)
210 |
211 | if weights is None:
212 | self.weights = [1.0 / self.sample_space_size
213 | ] * self.sample_space_size
214 | else:
215 | self.weights = weights
216 |
217 | def __call__(self):
218 | """
219 | Returns:
220 | An item of the sample space, i.e. a 2-tuple of scalars.
221 | """
222 | i = np.random.choice(self.sample_space_size, p=self.weights)
223 | return self.sample_space[i]
224 |
225 |
226 | class BoxFilter:
227 | """
228 | Returns all bounding boxes that are valid with respect to the defined criteria.
229 | """
230 |
231 | def __init__(self,
232 | check_overlap=True,
233 | check_min_area=True,
234 | check_degenerate=True,
235 | overlap_criterion='center_point',
236 | overlap_bounds=(0.3, 1.0),
237 | min_area=10,
238 | labels_format={
239 | 'class_id': 0,
240 | 'xmin': 1,
241 | 'ymin': 2,
242 | 'xmax': 3,
243 | 'ymax': 4
244 | },
245 | border_pixels='half'):
246 | """
247 | Arguments:
248 | check_overlap (bool, optional): Whether or not to enforce the overlap requirements defined by
249 | `overlap_criterion` and `overlap_bounds`. Sometimes you might want to use the box filter only
250 | to enforce a certain minimum area for all boxes (see next argument), in such cases you can
251 | turn the overlap requirements off.
252 | check_min_area (bool, optional): Whether or not to enforce the minimum area requirement defined
253 | by `min_area`. If `True`, any boxes that have an area (in pixels) that is smaller than `min_area`
254 | will be removed from the labels of an image. Bounding boxes below a certain area aren't useful
255 | training examples. An object that takes up only, say, 5 pixels in an image is probably not
256 | recognizable anymore, neither for a human, nor for an object detection model. It makes sense
257 | to remove such boxes.
258 | check_degenerate (bool, optional): Whether or not to check for and remove degenerate bounding boxes.
259 | Degenerate bounding boxes are boxes that have `xmax <= xmin` and/or `ymax <= ymin`. In particular,
260 | boxes with a width and/or height of zero are degenerate. It is obviously important to filter out
261 | such boxes, so you should only set this option to `False` if you are certain that degenerate
262 | boxes are not possible in your data and processing chain.
263 | overlap_criterion (str, optional): Can be either of 'center_point', 'iou', or 'area'. Determines
264 | which boxes are considered valid with respect to a given image. If set to 'center_point',
265 | a given bounding box is considered valid if its center point lies within the image.
266 | If set to 'area', a given bounding box is considered valid if the quotient of its intersection
267 | area with the image and its own area is within the given `overlap_bounds`. If set to 'iou', a given
268 | bounding box is considered valid if its IoU with the image is within the given `overlap_bounds`.
269 | overlap_bounds (list or BoundGenerator, optional): Only relevant if `overlap_criterion` is 'area' or 'iou'.
270 | Determines the lower and upper bounds for `overlap_criterion`. Can be either a 2-tuple of scalars
271 | representing a lower bound and an upper bound, or a `BoundGenerator` object, which provides
272 | the possibility to generate bounds randomly.
273 | min_area (int, optional): Only relevant if `check_min_area` is `True`. Defines the minimum area in
274 | pixels that a bounding box must have in order to be valid. Boxes with an area smaller than this
275 | will be removed.
276 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels
277 | of an image contains which bounding box coordinate. The dictionary maps at least the keywords
278 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array.
279 | border_pixels (str, optional): How to treat the border pixels of the bounding boxes.
280 | Can be 'include', 'exclude', or 'half'. If 'include', the border pixels belong
281 | to the boxes. If 'exclude', the border pixels do not belong to the boxes.
282 | If 'half', then one of each of the two horizontal and vertical borders belongs
283 | to the boxes, but not the other.
284 | """
285 | if not isinstance(overlap_bounds, (list, tuple, BoundGenerator)):
286 | raise ValueError(
287 | "`overlap_bounds` must be either a 2-tuple of scalars or a `BoundGenerator` object."
288 | )
289 | if isinstance(
290 | overlap_bounds,
291 | (list, tuple)) and (overlap_bounds[0] > overlap_bounds[1]):
292 | raise ValueError(
293 | "The lower bound must not be greater than the upper bound.")
294 | if not (overlap_criterion in {'iou', 'area', 'center_point'}):
295 | raise ValueError(
296 | "`overlap_criterion` must be one of 'iou', 'area', or 'center_point'."
297 | )
298 | self.overlap_criterion = overlap_criterion
299 | self.overlap_bounds = overlap_bounds
300 | self.min_area = min_area
301 | self.check_overlap = check_overlap
302 | self.check_min_area = check_min_area
303 | self.check_degenerate = check_degenerate
304 | self.labels_format = labels_format
305 | self.border_pixels = border_pixels
306 |
307 | def __call__(self, labels, image_height=None, image_width=None):
308 | """
309 | Arguments:
310 | labels (array): The labels to be filtered. This is an array with shape `(m,n)`, where
311 | `m` is the number of bounding boxes and `n` is the number of elements that defines
312 | each bounding box (box coordinates, class ID, etc.). The box coordinates are expected
313 | to be in the image's coordinate system.
314 | image_height (int): Only relevant if `check_overlap == True`. The height of the image
315 | (in pixels) to compare the box coordinates to.
316 | image_width (int): Only relevant if `check_overlap == True`. The width of the image (in pixels) to compare
317 | the box coordinates to.
318 |
319 | Returns:
320 | An array containing the labels of all boxes that are valid.
321 | """
322 |
323 | labels = np.copy(labels)
324 |
325 | xmin = self.labels_format['xmin']
326 | ymin = self.labels_format['ymin']
327 | xmax = self.labels_format['xmax']
328 | ymax = self.labels_format['ymax']
329 |
330 | # Record the boxes that pass all checks here.
331 | requirements_met = np.ones(shape=labels.shape[0], dtype=np.bool)
332 |
333 | if self.check_degenerate:
334 |
335 | non_degenerate = (labels[:, xmax] > labels[:, xmin]) * (
336 | labels[:, ymax] > labels[:, ymin])
337 | requirements_met *= non_degenerate
338 |
339 | if self.check_min_area:
340 |
341 | min_area_met = (labels[:, xmax] - labels[:, xmin]) * (
342 | labels[:, ymax] - labels[:, ymin]) >= self.min_area
343 | requirements_met *= min_area_met
344 |
345 | if self.check_overlap:
346 |
347 | # Get the lower and upper bounds.
348 | if isinstance(self.overlap_bounds, BoundGenerator):
349 | lower, upper = self.overlap_bounds()
350 | else:
351 | lower, upper = self.overlap_bounds
352 |
353 | # Compute which boxes are valid.
354 |
355 | if self.overlap_criterion == 'iou':
356 | # Compute the patch coordinates.
357 | image_coords = np.array([0, 0, image_width, image_height])
358 | # Compute the IoU between the patch and all of the ground truth boxes.
359 | image_boxes_iou = iou(image_coords,
360 | labels[:, [xmin, ymin, xmax, ymax]],
361 | coords='corners',
362 | mode='element-wise',
363 | border_pixels=self.border_pixels)
364 | requirements_met *= (image_boxes_iou >
365 | lower) * (image_boxes_iou <= upper)
366 |
367 | elif self.overlap_criterion == 'area':
368 | if self.border_pixels == 'half':
369 | d = 0
370 | elif self.border_pixels == 'include':
371 | d = 1 # If border pixels are supposed to belong to the bounding boxes,
372 | # we have to add one pixel to any difference `xmax - xmin` or `ymax - ymin`.
373 | elif self.border_pixels == 'exclude':
374 | d = -1 # If border pixels are not supposed to belong to the bounding boxes,
375 | # we have to subtract one pixel from any difference `xmax - xmin` or `ymax - ymin`.
376 | # Compute the areas of the boxes.
377 | box_areas = (labels[:, xmax] - labels[:, xmin] +
378 | d) * (labels[:, ymax] - labels[:, ymin] + d)
379 | # Compute the intersection area between the patch and all of the ground truth boxes.
380 | clipped_boxes = np.copy(labels)
381 | clipped_boxes[:, [ymin, ymax]] = np.clip(
382 | labels[:, [ymin, ymax]], a_min=0, a_max=image_height - 1)
383 | clipped_boxes[:, [xmin, xmax]] = np.clip(
384 | labels[:, [xmin, xmax]], a_min=0, a_max=image_width - 1)
385 | intersection_areas = (
386 | clipped_boxes[:, xmax] - clipped_boxes[:, xmin] + d) * (
387 | clipped_boxes[:, ymax] - clipped_boxes[:, ymin] + d
388 | ) # +1 because the border pixels belong to the box areas.
389 | # Check which boxes meet the overlap requirements.
390 | if lower == 0.0:
391 | mask_lower = intersection_areas > lower * box_areas # If `self.lower == 0`, we want to
392 | # make sure that boxes with area 0 don't count, hence the ">" sign instead of the ">=" sign.
393 | else:
394 | mask_lower = intersection_areas >= lower * box_areas # Especially for the case `self.lower == 1`
395 | # we want the ">=" sign, otherwise no boxes would count at all.
396 | mask_upper = intersection_areas <= upper * box_areas
397 | requirements_met *= mask_lower * mask_upper
398 |
399 | elif self.overlap_criterion == 'center_point':
400 | # Compute the center points of the boxes.
401 | cy = (labels[:, ymin] + labels[:, ymax]) / 2
402 | cx = (labels[:, xmin] + labels[:, xmax]) / 2
403 | # Check which of the boxes have center points within the cropped patch and remove those that don't.
404 | requirements_met *= (cy >= 0.0) * (cy <= image_height - 1) * (
405 | cx >= 0.0) * (cx <= image_width - 1)
406 |
407 | return labels[requirements_met]
408 |
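A small stand-alone usage sketch of the filter with hypothetical boxes in the default corner-format layout:

    box_filter = BoxFilter(check_overlap=False, check_min_area=False,
                           check_degenerate=True)
    boxes = np.array([[0, 10, 10, 50, 60],    # valid
                      [1, 30, 40, 30, 80]])   # degenerate: xmax <= xmin
    print(box_filter(boxes))                  # only the first row survives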
409 |
410 | class ImageValidator:
411 | """
412 | Returns `True` if a given minimum number of bounding boxes meets given overlap
413 | requirements with an image of a given height and width.
414 | """
415 |
416 | def __init__(self,
417 | overlap_criterion='center_point',
418 | bounds=(0.3, 1.0),
419 | n_boxes_min=1,
420 | labels_format={
421 | 'class_id': 0,
422 | 'xmin': 1,
423 | 'ymin': 2,
424 | 'xmax': 3,
425 | 'ymax': 4
426 | },
427 | border_pixels='half'):
428 | """
429 | Arguments:
430 | overlap_criterion (str, optional): Can be either of 'center_point', 'iou', or 'area'. Determines
431 | which boxes are considered valid with respect to a given image. If set to 'center_point',
432 | a given bounding box is considered valid if its center point lies within the image.
433 | If set to 'area', a given bounding box is considered valid if the quotient of its intersection
434 | area with the image and its own area is within `lower` and `upper`. If set to 'iou', a given
435 | bounding box is considered valid if its IoU with the image is within `lower` and `upper`.
436 | bounds (list or BoundGenerator, optional): Only relevant if `overlap_criterion` is 'area' or 'iou'.
437 | Determines the lower and upper bounds for `overlap_criterion`. Can be either a 2-tuple of scalars
438 | representing a lower bound and an upper bound, or a `BoundGenerator` object, which provides
439 | the possibility to generate bounds randomly.
440 | n_boxes_min (int or str, optional): Either a non-negative integer or the string 'all'.
441 | Determines the minimum number of boxes that must meet the `overlap_criterion` with respect to
442 | an image of the given height and width in order for the image to be a valid image.
443 | If set to 'all', an image is considered valid if all given boxes meet the `overlap_criterion`.
444 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels
445 | of an image contains which bounding box coordinate. The dictionary maps at least the keywords
446 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array.
447 | border_pixels (str, optional): How to treat the border pixels of the bounding boxes.
448 | Can be 'include', 'exclude', or 'half'. If 'include', the border pixels belong
449 | to the boxes. If 'exclude', the border pixels do not belong to the boxes.
450 | If 'half', then one of each of the two horizontal and vertical borders belongs
451 | to the boxes, but not the other.
452 | """
453 | if not ((isinstance(n_boxes_min, int) and n_boxes_min > 0)
454 | or n_boxes_min == 'all'):
455 | raise ValueError(
456 | "`n_boxes_min` must be a positive integer or 'all'.")
457 | self.overlap_criterion = overlap_criterion
458 | self.bounds = bounds
459 | self.n_boxes_min = n_boxes_min
460 | self.labels_format = labels_format
461 | self.border_pixels = border_pixels
462 | self.box_filter = BoxFilter(check_overlap=True,
463 | check_min_area=False,
464 | check_degenerate=False,
465 | overlap_criterion=self.overlap_criterion,
466 | overlap_bounds=self.bounds,
467 | labels_format=self.labels_format,
468 | border_pixels=self.border_pixels)
469 |
470 | def __call__(self, labels, image_height, image_width):
471 | """
472 | Arguments:
473 | labels (array): The labels to be tested. The box coordinates are expected
474 | to be in the image's coordinate system.
475 | image_height (int): The height of the image to compare the box coordinates to.
476 | image_width (int): The width of the image to compare the box coordinates to.
477 |
478 | Returns:
479 | A boolean indicating whether an image of the given height and width is
480 | valid with respect to the given bounding boxes.
481 | """
482 |
483 | self.box_filter.overlap_bounds = self.bounds
484 | self.box_filter.labels_format = self.labels_format
485 |
486 | # Get all boxes that meet the overlap requirements.
487 | valid_labels = self.box_filter(labels=labels,
488 | image_height=image_height,
489 | image_width=image_width)
490 |
491 | # Check whether enough boxes meet the requirements.
492 | if isinstance(self.n_boxes_min, int):
493 | # The image is valid if at least `self.n_boxes_min` ground truth boxes meet the requirements.
494 | if len(valid_labels) >= self.n_boxes_min:
495 | return True
496 | else:
497 | return False
498 | elif self.n_boxes_min == 'all':
499 | # The image is valid if all ground truth boxes meet the requirements.
500 | if len(valid_labels) == len(labels):
501 | return True
502 | else:
503 | return False
504 |
505 |
506 | class RandomPatchInf:
507 | """
508 | Randomly samples a patch from an image. The randomness refers to whatever
509 | randomness may be introduced by the patch coordinate generator, the box filter,
510 | and the patch validator.
511 |
512 | Input images may be cropped and/or padded along either or both of the two
513 | spatial dimensions as necessary in order to obtain the required patch.
514 |
515 | This operation is very similar to `RandomPatch`, except that:
516 | 1. This operation runs indefinitely until either a valid patch is found or
517 | the input image is returned unaltered, i.e. it cannot fail.
518 | 2. If a bound generator is given, a new pair of bounds will be generated
519 | every `n_trials_max` iterations.
520 | """
521 |
522 | def __init__(self,
523 | patch_coord_generator,
524 | box_filter=None,
525 | image_validator=None,
526 | bound_generator=None,
527 | n_trials_max=50,
528 | clip_boxes=True,
529 | prob=0.857,
530 | background=(0, 0, 0),
531 | labels_format={
532 | 'class_id': 0,
533 | 'xmin': 1,
534 | 'ymin': 2,
535 | 'xmax': 3,
536 | 'ymax': 4
537 | }):
538 | """
539 | Arguments:
540 | patch_coord_generator (PatchCoordinateGenerator): A `PatchCoordinateGenerator` object
541 | to generate the positions and sizes of the patches to be sampled from the input images.
542 | box_filter (BoxFilter, optional): Only relevant if ground truth bounding boxes are given.
543 | A `BoxFilter` object to filter out bounding boxes that don't meet the given criteria
544 | after the transformation. Refer to the `BoxFilter` documentation for details. If `None`,
545 | the validity of the bounding boxes is not checked.
546 | image_validator (ImageValidator, optional): Only relevant if ground truth bounding boxes are given.
547 | An `ImageValidator` object to determine whether a sampled patch is valid. If `None`,
548 | any outcome is valid.
549 | bound_generator (BoundGenerator, optional): A `BoundGenerator` object to generate upper and
550 | lower bound values for the patch validator. Every `n_trials_max` trials, a new pair of
551 | upper and lower bounds will be generated until a valid patch is found or the original image
552 | is returned. This bound generator overrides the bound generator of the patch validator.
553 | n_trials_max (int, optional): Only relevant if ground truth bounding boxes are given.
554 | The sampler will run indefinitely until either a valid patch is found or the original image
555 | is returned, but this determines the maximal number of trials to sample a valid patch for each
556 | selected pair of lower and upper bounds before a new pair is picked.
557 | clip_boxes (bool, optional): Only relevant if ground truth bounding boxes are given.
558 | If `True`, any ground truth bounding boxes will be clipped to lie entirely within the
559 | sampled patch.
560 | prob (float, optional): `(1 - prob)` determines the probability with which the original,
561 | unaltered image is returned.
562 | background (list/tuple, optional): A 3-tuple specifying the RGB color value of the potential
563 | background pixels of the scaled images. In the case of single-channel images,
564 | the first element of `background` will be used as the background pixel value.
565 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels
566 | of an image contains which bounding box coordinate. The dictionary maps at least the keywords
567 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array.
568 | """
569 |
570 | if not isinstance(patch_coord_generator, PatchCoordinateGenerator):
571 | raise ValueError(
572 | "`patch_coord_generator` must be an instance of `PatchCoordinateGenerator`."
573 | )
574 | if not (isinstance(image_validator, ImageValidator)
575 | or image_validator is None):
576 | raise ValueError(
577 | "`image_validator` must be either `None` or an `ImageValidator` object."
578 | )
579 | if not (isinstance(bound_generator, BoundGenerator)
580 | or bound_generator is None):
581 | raise ValueError(
582 | "`bound_generator` must be either `None` or a `BoundGenerator` object."
583 | )
584 | self.patch_coord_generator = patch_coord_generator
585 | self.box_filter = box_filter
586 | self.image_validator = image_validator
587 | self.bound_generator = bound_generator
588 | self.n_trials_max = n_trials_max
589 | self.clip_boxes = clip_boxes
590 | self.prob = prob
591 | self.background = background
592 | self.labels_format = labels_format
593 | self.sample_patch = CropPad(patch_ymin=None,
594 | patch_xmin=None,
595 | patch_height=None,
596 | patch_width=None,
597 | clip_boxes=self.clip_boxes,
598 | box_filter=self.box_filter,
599 | background=self.background,
600 | labels_format=self.labels_format)
601 |
602 | def __call__(self, image, labels=None, return_inverter=False):
603 |
604 | img_height, img_width = image.shape[:2]
605 | self.patch_coord_generator.img_height = img_height
606 | self.patch_coord_generator.img_width = img_width
607 |
608 | xmin = self.labels_format['xmin']
609 | ymin = self.labels_format['ymin']
610 | xmax = self.labels_format['xmax']
611 | ymax = self.labels_format['ymax']
612 |
613 | # Override the preset labels format.
614 | if not self.image_validator is None:
615 | self.image_validator.labels_format = self.labels_format
616 | self.sample_patch.labels_format = self.labels_format
617 |
618 | while True: # Keep going until we either find a valid patch or return the original image.
619 |
620 | p = np.random.uniform(0, 1)
621 | if p >= (1.0 - self.prob):
622 |
623 | # In case we have a bound generator, pick a lower and upper bound for the patch validator.
624 | if not ((self.image_validator is None) or
625 | (self.bound_generator is None)):
626 | self.image_validator.bounds = self.bound_generator()
627 |
628 | # Use at most `self.n_trials_max` attempts to find a crop
629 | # that meets our requirements.
630 | for _ in range(max(1, self.n_trials_max)):
631 |
632 | # Generate patch coordinates.
633 | patch_ymin, patch_xmin, patch_height, patch_width = self.patch_coord_generator(
634 | )
635 |
636 | self.sample_patch.patch_ymin = patch_ymin
637 | self.sample_patch.patch_xmin = patch_xmin
638 | self.sample_patch.patch_height = patch_height
639 | self.sample_patch.patch_width = patch_width
640 |
641 | # Check if the resulting patch meets the aspect ratio requirements.
642 | aspect_ratio = patch_width / patch_height
643 | if not (self.patch_coord_generator.min_aspect_ratio <=
644 | aspect_ratio <=
645 | self.patch_coord_generator.max_aspect_ratio):
646 | continue
647 |
648 | if (labels is None) or (self.image_validator is None):
649 | # We either don't have any boxes or if we do, we will accept any outcome as valid.
650 | return self.sample_patch(image, labels,
651 | return_inverter)
652 | else:
653 | # Translate the box coordinates to the patch's coordinate system.
654 | new_labels = np.copy(labels)
655 | new_labels[:, [ymin, ymax]] -= patch_ymin
656 | new_labels[:, [xmin, xmax]] -= patch_xmin
657 | # Check if the patch contains the minimum number of boxes we require.
658 | if self.image_validator(labels=new_labels,
659 | image_height=patch_height,
660 | image_width=patch_width):
661 | return self.sample_patch(image, labels,
662 | return_inverter)
663 | else:
664 | if return_inverter:
665 |
666 | def inverter(labels):
667 | return labels
668 |
669 | if labels is None:
670 | if return_inverter:
671 | return image, inverter
672 | else:
673 | return image
674 | else:
675 | if return_inverter:
676 | return image, labels, inverter
677 | else:
678 | return image, labels
679 |
680 |
681 | class PatchCoordinateGenerator:
682 | """
683 | Generates random patch coordinates that meet specified requirements.
684 | """
685 |
686 | def __init__(self,
687 | img_height=None,
688 | img_width=None,
689 | must_match='h_w',
690 | min_scale=0.3,
691 | max_scale=1.0,
692 | scale_uniformly=False,
693 | min_aspect_ratio=0.5,
694 | max_aspect_ratio=2.0,
695 | patch_ymin=None,
696 | patch_xmin=None,
697 | patch_height=None,
698 | patch_width=None,
699 | patch_aspect_ratio=None):
700 | """
701 | Arguments:
702 | img_height (int): The height of the image for which the patch coordinates
703 | shall be generated. Doesn't have to be known upon construction.
704 | img_width (int): The width of the image for which the patch coordinates
705 | shall be generated. Doesn't have to be known upon construction.
706 | must_match (str, optional): Can be either of 'h_w', 'h_ar', and 'w_ar'.
707 | Specifies which two of the three quantities height, width, and aspect
708 | ratio determine the shape of the generated patch. The respective third
709 | quantity will be computed from the other two. For example,
710 | if `must_match == 'h_w'`, then the patch's height and width will be
711 | set to lie within [min_scale, max_scale] of the image size or to
712 | `patch_height` and/or `patch_width`, if given. The patch's aspect ratio
713 | is the dependent variable in this case, it will be computed from the
714 | height and width. Any given values for `patch_aspect_ratio`,
715 | `min_aspect_ratio`, or `max_aspect_ratio` will be ignored.
716 | min_scale (float, optional): The minimum size of a dimension of the patch
717 | as a fraction of the respective dimension of the image. Can be greater
718 | than 1. For example, if the image width is 200 and `min_scale == 0.5`,
719 | then the width of the generated patch will be at least 100. If `min_scale == 1.5`,
720 | the width of the generated patch will be at least 300.
721 | max_scale (float, optional): The maximum size of a dimension of the patch
722 | as a fraction of the respective dimension of the image. Can be greater
723 | than 1. For example, if the image width is 200 and `max_scale == 1.0`,
724 | then the width of the generated patch will be at most 200. If `max_scale == 1.5`,
725 | the width of the generated patch will be at most 300. Must be greater than
726 | `min_scale`.
727 | scale_uniformly (bool, optional): If `True` and if `must_match == 'h_w'`,
728 | the patch height and width will be scaled uniformly, otherwise they will
729 | be scaled independently.
730 | min_aspect_ratio (float, optional): Determines the minimum aspect ratio
731 | for the generated patches.
732 | max_aspect_ratio (float, optional): Determines the maximum aspect ratio
733 | for the generated patches.
734 | patch_ymin (int, optional): `None` or the vertical coordinate of the top left
735 | corner of the generated patches. If this is not `None`, the position of the
736 | patches along the vertical axis is fixed. If this is `None`, then the
737 | vertical position of generated patches will be chosen randomly such that
738 | the overlap of a patch and the image along the vertical dimension is
739 | always maximal.
740 | patch_xmin (int, optional): `None` or the horizontal coordinate of the top left
741 | corner of the generated patches. If this is not `None`, the position of the
742 | patches along the horizontal axis is fixed. If this is `None`, then the
743 | horizontal position of generated patches will be chosen randomly such that
744 | the overlap of a patch and the image along the horizontal dimension is
745 | always maximal.
746 | patch_height (int, optional): `None` or the fixed height of the generated patches.
747 | patch_width (int, optional): `None` or the fixed width of the generated patches.
748 | patch_aspect_ratio (float, optional): `None` or the fixed aspect ratio of the
749 | generated patches.
750 | """
751 |
752 | if not (must_match in {'h_w', 'h_ar', 'w_ar'}):
753 | raise ValueError(
754 | "`must_match` must be one of 'h_w', 'h_ar', or 'w_ar'.")
755 | if min_scale >= max_scale:
756 | raise ValueError("It must be `min_scale < max_scale`.")
757 | if min_aspect_ratio >= max_aspect_ratio:
758 | raise ValueError(
759 | "It must be `min_aspect_ratio < max_aspect_ratio`.")
760 | if scale_uniformly and not ((patch_height is None) and
761 | (patch_width is None)):
762 | raise ValueError(
763 | "If `scale_uniformly == True`, `patch_height` and `patch_width` must both be `None`."
764 | )
765 | self.img_height = img_height
766 | self.img_width = img_width
767 | self.must_match = must_match
768 | self.min_scale = min_scale
769 | self.max_scale = max_scale
770 | self.scale_uniformly = scale_uniformly
771 | self.min_aspect_ratio = min_aspect_ratio
772 | self.max_aspect_ratio = max_aspect_ratio
773 | self.patch_ymin = patch_ymin
774 | self.patch_xmin = patch_xmin
775 | self.patch_height = patch_height
776 | self.patch_width = patch_width
777 | self.patch_aspect_ratio = patch_aspect_ratio
778 |
779 | def __call__(self):
780 | """
781 | Returns:
782 | A 4-tuple `(ymin, xmin, height, width)` that represents the coordinates
783 | of the generated patch.
784 | """
785 |
786 | # Get the patch height and width.
787 |
788 | if self.must_match == 'h_w': # Aspect is the dependent variable.
789 | if not self.scale_uniformly:
790 | # Get the height.
791 | if self.patch_height is None:
792 | patch_height = int(
793 | np.random.uniform(self.min_scale, self.max_scale) *
794 | self.img_height)
795 | else:
796 | patch_height = self.patch_height
797 | # Get the width.
798 | if self.patch_width is None:
799 | patch_width = int(
800 | np.random.uniform(self.min_scale, self.max_scale) *
801 | self.img_width)
802 | else:
803 | patch_width = self.patch_width
804 | else:
805 | scaling_factor = np.random.uniform(self.min_scale,
806 | self.max_scale)
807 | patch_height = int(scaling_factor * self.img_height)
808 | patch_width = int(scaling_factor * self.img_width)
809 |
810 | elif self.must_match == 'h_ar': # Width is the dependent variable.
811 | # Get the height.
812 | if self.patch_height is None:
813 | patch_height = int(
814 | np.random.uniform(self.min_scale, self.max_scale) *
815 | self.img_height)
816 | else:
817 | patch_height = self.patch_height
818 | # Get the aspect ratio.
819 | if self.patch_aspect_ratio is None:
820 | patch_aspect_ratio = np.random.uniform(self.min_aspect_ratio,
821 | self.max_aspect_ratio)
822 | else:
823 | patch_aspect_ratio = self.patch_aspect_ratio
824 | # Get the width.
825 | patch_width = int(patch_height * patch_aspect_ratio)
826 |
827 | elif self.must_match == 'w_ar': # Height is the dependent variable.
828 | # Get the width.
829 | if self.patch_width is None:
830 | patch_width = int(
831 | np.random.uniform(self.min_scale, self.max_scale) *
832 | self.img_width)
833 | else:
834 | patch_width = self.patch_width
835 | # Get the aspect ratio.
836 | if self.patch_aspect_ratio is None:
837 | patch_aspect_ratio = np.random.uniform(self.min_aspect_ratio,
838 | self.max_aspect_ratio)
839 | else:
840 | patch_aspect_ratio = self.patch_aspect_ratio
841 | # Get the height.
842 | patch_height = int(patch_width / patch_aspect_ratio)
843 |
844 | # Get the top left corner coordinates of the patch.
845 |
846 | if self.patch_ymin is None:
847 | # Compute how much room we have along the vertical axis to place the patch.
848 | # A negative number here means that we want to sample a patch that is larger than the original image
849 | # in the vertical dimension, in which case the patch will be placed such that it fully contains the
850 | # image in the vertical dimension.
851 | y_range = self.img_height - patch_height
852 | # Select a random top left corner for the sample position from the possible positions.
853 | if y_range >= 0:
854 | patch_ymin = np.random.randint(
855 | 0, y_range + 1
856 | ) # There are y_range + 1 possible positions for the crop in the vertical dimension.
857 | else:
858 | patch_ymin = np.random.randint(
859 | y_range, 1
860 | ) # The possible positions for the image on the background canvas in the vertical dimension.
861 | else:
862 | patch_ymin = self.patch_ymin
863 |
864 | if self.patch_xmin is None:
865 | # Compute how much room we have along the horizontal axis to place the patch.
866 | # A negative number here means that we want to sample a patch that is larger than the original image
867 | # in the horizontal dimension, in which case the patch will be placed such that it fully contains the
868 | # image in the horizontal dimension.
869 | x_range = self.img_width - patch_width
870 | # Select a random top left corner for the sample position from the possible positions.
871 | if x_range >= 0:
872 | patch_xmin = np.random.randint(
873 | 0, x_range + 1
874 | ) # There are x_range + 1 possible positions for the crop in the horizontal dimension.
875 | else:
876 | patch_xmin = np.random.randint(
877 | x_range, 1
878 | ) # The possible positions for the image on the background canvas in the horizontal dimension.
879 | else:
880 | patch_xmin = self.patch_xmin
881 |
882 | return patch_ymin, patch_xmin, patch_height, patch_width
883 |
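For illustration, generating one candidate patch for a 768 x 1024 image with the same settings SSDRandomCrop uses above (the returned coordinates are random):

    gen = PatchCoordinateGenerator(img_height=768, img_width=1024,
                                   must_match='h_w',
                                   min_scale=0.6, max_scale=1.0,
                                   min_aspect_ratio=0.5, max_aspect_ratio=2.0)
    patch_ymin, patch_xmin, patch_height, patch_width = gen()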
884 |
885 | class CropPad:
886 | """
887 | Crops and/or pads an image deterministically.
888 |
889 | Depending on the given output patch size and the position (top left corner) relative
890 | to the input image, the image will be cropped and/or padded along one or both spatial
891 | dimensions.
892 |
893 | For example, if the output patch lies entirely within the input image, this will result
894 | in a regular crop. If the input image lies entirely within the output patch, this will
895 | result in the image being padded in every direction. All other cases are mixed cases
896 | where the image might be cropped in some directions and padded in others.
897 |
898 | The output patch can be arbitrary in both size and position as long as it overlaps
899 | with the input image.
900 | """
901 |
902 | def __init__(self,
903 | patch_ymin,
904 | patch_xmin,
905 | patch_height,
906 | patch_width,
907 | clip_boxes=True,
908 | box_filter=None,
909 | background=(0, 0, 0),
910 | labels_format={
911 | 'class_id': 0,
912 | 'xmin': 1,
913 | 'ymin': 2,
914 | 'xmax': 3,
915 | 'ymax': 4
916 | }):
917 | """
918 | Arguments:
919 | patch_ymin (int, optional): The vertical coordinate of the top left corner of the output
920 | patch relative to the image coordinate system. Can be negative (i.e. lie outside the image)
921 | as long as the resulting patch still overlaps with the image.
922 | patch_xmin (int, optional): The horizontal coordinate of the top left corner of the output
923 | patch relative to the image coordinate system. Can be negative (i.e. lie outside the image)
924 | as long as the resulting patch still overlaps with the image.
925 | patch_height (int): The height of the patch to be sampled from the image. Can be greater
926 | than the height of the input image.
927 | patch_width (int): The width of the patch to be sampled from the image. Can be greater
928 | than the width of the input image.
929 | clip_boxes (bool, optional): Only relevant if ground truth bounding boxes are given.
930 | If `True`, any ground truth bounding boxes will be clipped to lie entirely within the
931 | sampled patch.
932 | box_filter (BoxFilter, optional): Only relevant if ground truth bounding boxes are given.
933 | A `BoxFilter` object to filter out bounding boxes that don't meet the given criteria
934 | after the transformation. Refer to the `BoxFilter` documentation for details. If `None`,
935 | the validity of the bounding boxes is not checked.
936 | background (list/tuple, optional): A 3-tuple specifying the RGB color value of the potential
937 | background pixels of the scaled images. In the case of single-channel images,
938 | the first element of `background` will be used as the background pixel value.
939 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels
940 | of an image contains which bounding box coordinate. The dictionary maps at least the keywords
941 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array.
942 | """
943 | # if (patch_height <= 0) or (patch_width <= 0):
944 | # raise ValueError("Patch height and width must both be positive.")
945 | # if (patch_ymin + patch_height < 0) or (patch_xmin + patch_width < 0):
946 | # raise ValueError("A patch with the given coordinates cannot overlap with an input image.")
947 | if not (isinstance(box_filter, BoxFilter) or box_filter is None):
948 | raise ValueError(
949 | "`box_filter` must be either `None` or a `BoxFilter` object.")
950 | self.patch_height = patch_height
951 | self.patch_width = patch_width
952 | self.patch_ymin = patch_ymin
953 | self.patch_xmin = patch_xmin
954 | self.clip_boxes = clip_boxes
955 | self.box_filter = box_filter
956 | self.background = background
957 | self.labels_format = labels_format
958 |
959 | def __call__(self, image, labels=None, return_inverter=False):
960 |
961 | img_height, img_width = image.shape[:2]
962 |
963 | if (self.patch_ymin > img_height) or (self.patch_xmin > img_width):
964 | raise ValueError(
965 | "The given patch doesn't overlap with the input image.")
966 |
967 | labels = np.copy(labels)
968 |
969 | xmin = self.labels_format['xmin']
970 | ymin = self.labels_format['ymin']
971 | xmax = self.labels_format['xmax']
972 | ymax = self.labels_format['ymax']
973 |
974 | # Top left corner of the patch relative to the image coordinate system:
975 | patch_ymin = self.patch_ymin
976 | patch_xmin = self.patch_xmin
977 |
978 | # Create a canvas of the size of the patch we want to end up with.
979 | if image.ndim == 3:
980 | canvas = np.zeros(shape=(self.patch_height, self.patch_width, 3),
981 | dtype=np.uint8)
982 | canvas[:, :] = self.background
983 | elif image.ndim == 2:
984 | canvas = np.zeros(shape=(self.patch_height, self.patch_width),
985 | dtype=np.uint8)
986 | canvas[:, :] = self.background[0]
987 |
988 | # Perform the crop.
989 | if patch_ymin < 0 and patch_xmin < 0: # Pad the image at the top and on the left.
990 | image_crop_height = min(
991 | img_height, self.patch_height + patch_ymin
992 | ) # The number of pixels of the image that will end up on the canvas in the vertical direction.
993 | image_crop_width = min(
994 | img_width, self.patch_width + patch_xmin
995 | ) # The number of pixels of the image that will end up on the canvas in the horizontal direction.
996 | canvas[-patch_ymin:-patch_ymin +
997 | image_crop_height, -patch_xmin:-patch_xmin +
998 | image_crop_width] = image[:image_crop_height, :
999 | image_crop_width]
1000 |
1001 | elif patch_ymin < 0 and patch_xmin >= 0: # Pad the image at the top and crop it on the left.
1002 | image_crop_height = min(
1003 | img_height, self.patch_height + patch_ymin
1004 | ) # The number of pixels of the image that will end up on the canvas in the vertical direction.
1005 | image_crop_width = min(
1006 | self.patch_width, img_width - patch_xmin
1007 | ) # The number of pixels of the image that will end up on the canvas in the horizontal direction.
1008 | canvas[-patch_ymin:-patch_ymin + image_crop_height, :
1009 | image_crop_width] = image[:image_crop_height, patch_xmin:
1010 | patch_xmin + image_crop_width]
1011 |
1012 | elif patch_ymin >= 0 and patch_xmin < 0: # Crop the image at the top and pad it on the left.
1013 | image_crop_height = min(
1014 | self.patch_height, img_height - patch_ymin
1015 | ) # The number of pixels of the image that will end up on the canvas in the vertical direction.
1016 | image_crop_width = min(
1017 | img_width, self.patch_width + patch_xmin
1018 | ) # The number of pixels of the image that will end up on the canvas in the horizontal direction.
1019 | canvas[:image_crop_height, -patch_xmin:-patch_xmin +
1020 | image_crop_width] = image[patch_ymin:patch_ymin +
1021 | image_crop_height, :
1022 | image_crop_width]
1023 |
1024 | elif patch_ymin >= 0 and patch_xmin >= 0: # Crop the image at the top and on the left.
1025 | image_crop_height = min(
1026 | self.patch_height, img_height - patch_ymin
1027 | ) # The number of pixels of the image that will end up on the canvas in the vertical direction.
1028 | image_crop_width = min(
1029 | self.patch_width, img_width - patch_xmin
1030 | ) # The number of pixels of the image that will end up on the canvas in the horizontal direction.
1031 | canvas[:image_crop_height, :image_crop_width] = image[
1032 | patch_ymin:patch_ymin +
1033 | image_crop_height, patch_xmin:patch_xmin + image_crop_width]
1034 |
1035 | image = canvas
1036 |
1037 | if return_inverter:
1038 |
1039 | def inverter(labels):
1040 | labels = np.copy(labels)
1041 | labels[:, [ymin + 1, ymax + 1]] += patch_ymin
1042 | labels[:, [xmin + 1, xmax + 1]] += patch_xmin
1043 | return labels
1044 |
1045 | if not (labels is None):
1046 |
1047 | # Translate the box coordinates to the patch's coordinate system.
1048 | labels[:, [ymin, ymax]] -= patch_ymin
1049 | labels[:, [xmin, xmax]] -= patch_xmin
1050 |
1051 | # Compute all valid boxes for this patch.
1052 | if not (self.box_filter is None):
1053 | self.box_filter.labels_format = self.labels_format
1054 | labels = self.box_filter(labels=labels,
1055 | image_height=self.patch_height,
1056 | image_width=self.patch_width)
1057 |
1058 | if self.clip_boxes:
1059 | labels[:, [ymin, ymax]] = np.clip(labels[:, [ymin, ymax]],
1060 | a_min=0,
1061 | a_max=self.patch_height - 1)
1062 | labels[:, [xmin, xmax]] = np.clip(labels[:, [xmin, xmax]],
1063 | a_min=0,
1064 | a_max=self.patch_width - 1)
1065 |
1066 | if return_inverter:
1067 | return image, labels, inverter
1068 | else:
1069 | return image, labels
1070 |
1071 | else:
1072 | if return_inverter:
1073 | return image, inverter
1074 | else:
1075 | return image
1076 |
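# Illustrative sketch (not part of the original file; names below are made up for
# demonstration). The four cases above place a shifted crop of the image onto a
# fixed-size canvas: a negative patch_ymin/patch_xmin means the patch starts
# above/left of the image, so the image is padded on that side; a non-negative
# value means the image is cropped there. This standalone numpy example reproduces
# the "pad at the top, crop on the left" case with a tiny array.
def _example_patch_placement():
    img = np.arange(6 * 8, dtype=np.uint8).reshape(6, 8)  # a tiny 6 x 8 "image"
    patch_height, patch_width = 5, 5
    patch_ymin, patch_xmin = -2, 3  # pad 2 rows at the top, crop 3 columns on the left
    canvas = np.zeros((patch_height, patch_width), dtype=np.uint8)
    crop_h = min(img.shape[0], patch_height + patch_ymin)  # 3 image rows fit below the padding
    crop_w = min(patch_width, img.shape[1] - patch_xmin)   # 5 image columns fit right of the crop
    canvas[-patch_ymin:-patch_ymin + crop_h, :crop_w] = img[:crop_h, patch_xmin:patch_xmin + crop_w]
    return canvas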
1077 |
1078 | class Resize:
1079 | """
1080 | Resizes images to a specified height and width in pixels.
1081 | """
1082 |
1083 | def __init__(self,
1084 | height,
1085 | width,
1086 | interpolation_mode=cv2.INTER_LINEAR,
1087 | box_filter=None,
1088 | labels_format={
1089 | 'class_id': 0,
1090 | 'xmin': 1,
1091 | 'ymin': 2,
1092 | 'xmax': 3,
1093 | 'ymax': 4
1094 | }):
1095 | """
1096 | Arguments:
1097 | height (int): The desired height of the output images in pixels.
1098 | width (int): The desired width of the output images in pixels.
1099 | interpolation_mode (int, optional): An integer that denotes a valid
1100 | OpenCV interpolation mode. For example, integers 0 through 5 are
1101 | valid interpolation modes.
1102 | box_filter (BoxFilter, optional): Only relevant if ground truth bounding boxes are given.
1103 | A `BoxFilter` object to filter out bounding boxes that don't meet the given criteria
1104 | after the transformation. Refer to the `BoxFilter` documentation for details. If `None`,
1105 | the validity of the bounding boxes is not checked.
1106 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels
1107 | of an image contains which bounding box coordinate. The dictionary maps at least the keywords
1108 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array.
1109 | """
1110 | if not (isinstance(box_filter, BoxFilter) or box_filter is None):
1111 | raise ValueError(
1112 | "`box_filter` must be either `None` or a `BoxFilter` object.")
1113 | self.out_height = height
1114 | self.out_width = width
1115 | self.interpolation_mode = interpolation_mode
1116 | self.box_filter = box_filter
1117 | self.labels_format = labels_format
1118 |
1119 | def __call__(self, image, labels=None, return_inverter=False):
1120 |
1121 | img_height, img_width = image.shape[:2]
1122 |
1123 | xmin = self.labels_format['xmin']
1124 | ymin = self.labels_format['ymin']
1125 | xmax = self.labels_format['xmax']
1126 | ymax = self.labels_format['ymax']
1127 |
1128 | image = cv2.resize(image,
1129 | dsize=(self.out_width, self.out_height),
1130 | interpolation=self.interpolation_mode)
1131 |
1132 | if return_inverter:
1133 |
1134 | def inverter(labels):
1135 | labels = np.copy(labels)
1136 |                 labels[:, [ymin + 1, ymax + 1]] = np.round(
1137 |                     labels[:, [ymin + 1, ymax + 1]] *
1138 |                     (img_height / self.out_height),
1139 |                     decimals=0)
1140 |                 labels[:, [xmin + 1, xmax + 1]] = np.round(
1141 |                     labels[:, [xmin + 1, xmax + 1]] *
1142 |                     (img_width / self.out_width),
1143 |                     decimals=0)
1144 | return labels
1145 |
1146 | if labels is None:
1147 | if return_inverter:
1148 | return image, inverter
1149 | else:
1150 | return image
1151 | else:
1152 | labels = np.copy(labels)
1153 | labels[:, [ymin, ymax]] = np.round(labels[:, [ymin, ymax]] *
1154 | (self.out_height / img_height),
1155 | decimals=0)
1156 | labels[:, [xmin, xmax]] = np.round(labels[:, [xmin, xmax]] *
1157 | (self.out_width / img_width),
1158 | decimals=0)
1159 | # labels[:, [ymin, ymax]] = labels[:, [ymin, ymax]] * (self.out_height / img_height)
1160 | # labels[:, [xmin, xmax]] = labels[:, [xmin, xmax]] * (self.out_width / img_width)
1161 | if not (self.box_filter is None):
1162 | self.box_filter.labels_format = self.labels_format
1163 | labels = self.box_filter(labels=labels,
1164 | image_height=self.out_height,
1165 | image_width=self.out_width)
1166 |
1167 | if return_inverter:
1168 | return image, labels, inverter
1169 | else:
1170 | return image, labels
1171 |
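# Illustrative usage sketch (not part of the original file; the dummy names below are
# assumptions). It shows how `Resize` rescales an image together with its labels and
# how the optional inverter maps boxes back to the original image size afterwards.
def _example_resize_usage():
    dummy_image = np.zeros((375, 1242, 3), dtype=np.uint8)  # a KITTI-sized dummy frame
    dummy_labels = np.array([[0, 100, 120, 300, 220]])      # (class_id, xmin, ymin, xmax, ymax)
    resize = Resize(height=300, width=480)
    image_out, labels_out, inverter = resize(dummy_image,
                                             dummy_labels,
                                             return_inverter=True)
    # The inverter adds 1 to each coordinate index, i.e. it expects decoded predictions
    # of the form (class_id, confidence, xmin, ymin, xmax, ymax) and rescales them back
    # to the original 1242 x 375 image.
    return image_out, labels_out, inverter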
1172 |
1173 | class ResizeRandomInterp:
1174 | """
1175 |     Resizes images to a specified height and width in pixels using a randomly
1176 | selected interpolation mode.
1177 | """
1178 |
1179 | def __init__(self,
1180 | height,
1181 | width,
1182 | interpolation_modes=[
1183 | cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC,
1184 | cv2.INTER_AREA, cv2.INTER_LANCZOS4
1185 | ],
1186 | box_filter=None,
1187 | labels_format={
1188 | 'class_id': 0,
1189 | 'xmin': 1,
1190 | 'ymin': 2,
1191 | 'xmax': 3,
1192 | 'ymax': 4
1193 | }):
1194 | """
1195 | Arguments:
1196 | height (int): The desired height of the output image in pixels.
1197 | width (int): The desired width of the output image in pixels.
1198 | interpolation_modes (list/tuple, optional): A list/tuple of integers
1199 | that represent valid OpenCV interpolation modes. For example,
1200 | integers 0 through 5 are valid interpolation modes.
1201 | box_filter (BoxFilter, optional): Only relevant if ground truth bounding boxes are given.
1202 | A `BoxFilter` object to filter out bounding boxes that don't meet the given criteria
1203 | after the transformation. Refer to the `BoxFilter` documentation for details. If `None`,
1204 | the validity of the bounding boxes is not checked.
1205 | labels_format (dict, optional): A dictionary that defines which index in the last axis of the labels
1206 | of an image contains which bounding box coordinate. The dictionary maps at least the keywords
1207 | 'xmin', 'ymin', 'xmax', and 'ymax' to their respective indices within last axis of the labels array.
1208 | """
1209 | if not (isinstance(interpolation_modes, (list, tuple))):
1210 |             raise ValueError("`interpolation_modes` must be a list or tuple.")
1211 | self.height = height
1212 | self.width = width
1213 | self.interpolation_modes = interpolation_modes
1214 | self.box_filter = box_filter
1215 | self.labels_format = labels_format
1216 | self.resize = Resize(height=self.height,
1217 | width=self.width,
1218 | box_filter=self.box_filter,
1219 | labels_format=self.labels_format)
1220 |
1221 | def __call__(self, image, labels=None, return_inverter=False):
1222 | self.resize.interpolation_mode = np.random.choice(
1223 | self.interpolation_modes)
1224 | self.resize.labels_format = self.labels_format
1225 | return self.resize(image, labels, return_inverter)
1226 |
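# Illustrative sketch (not part of the original file): `ResizeRandomInterp` behaves like
# `Resize`, but draws a new OpenCV interpolation mode on every call, which adds slight
# variation when it is used as a training-time augmentation step.
def _example_resize_random_interp_usage():
    dummy_image = np.zeros((375, 1242, 3), dtype=np.uint8)
    random_resize = ResizeRandomInterp(height=300, width=480)
    # Each call may use a different mode from `interpolation_modes`.
    return random_resize(dummy_image)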
1227 |
1228 | def convert_coordinates(tensor, start_index, conversion, border_pixels='half'):
1229 | """
1230 | Convert coordinates for axis-aligned 2D boxes between two coordinate formats.
1231 |
1232 | Creates a copy of `tensor`, i.e. does not operate in place. Currently there are
1233 | three supported coordinate formats that can be converted from and to each other:
1234 | 1) (xmin, xmax, ymin, ymax) - the 'minmax' format
1235 | 2) (xmin, ymin, xmax, ymax) - the 'corners' format
1236 |         3) (cx, cy, w, h) - the 'centroids' format
1237 |
1238 | Arguments:
1239 | tensor (array): A Numpy nD array containing the four consecutive coordinates
1240 | to be converted somewhere in the last axis.
1241 | start_index (int): The index of the first coordinate in the last axis of `tensor`.
1242 | conversion (str, optional): The conversion direction. Can be 'minmax2centroids',
1243 | 'centroids2minmax', 'corners2centroids', 'centroids2corners', 'minmax2corners',
1244 | or 'corners2minmax'.
1245 | border_pixels (str, optional): How to treat the border pixels of the bounding boxes.
1246 | Can be 'include', 'exclude', or 'half'. If 'include', the border pixels belong
1247 | to the boxes. If 'exclude', the border pixels do not belong to the boxes.
1248 | If 'half', then one of each of the two horizontal and vertical borders belong
1249 |             to the boxes, but not the other.
1250 |
1251 | Returns:
1252 | A Numpy nD array, a copy of the input tensor with the converted coordinates
1253 | in place of the original coordinates and the unaltered elements of the original
1254 | tensor elsewhere.
1255 | """
1256 | if border_pixels == 'half':
1257 | d = 0
1258 | elif border_pixels == 'include':
1259 | d = 1
1260 | elif border_pixels == 'exclude':
1261 | d = -1
1262 |
1263 | ind = start_index
1264 |     tensor1 = np.copy(tensor).astype(np.float64)
1265 | if conversion == 'minmax2centroids':
1266 | tensor1[..., ind] = (tensor[..., ind] +
1267 | tensor[..., ind + 1]) / 2.0 # Set cx
1268 | tensor1[..., ind + 1] = (tensor[..., ind + 2] +
1269 | tensor[..., ind + 3]) / 2.0 # Set cy
1270 | tensor1[..., ind +
1271 | 2] = tensor[..., ind + 1] - tensor[..., ind] + d # Set w
1272 | tensor1[..., ind +
1273 | 3] = tensor[..., ind + 3] - tensor[..., ind + 2] + d # Set h
1274 | elif conversion == 'centroids2minmax':
1275 | tensor1[..., ind] = tensor[..., ind] - tensor[..., ind +
1276 | 2] / 2.0 # Set xmin
1277 | tensor1[..., ind +
1278 | 1] = tensor[..., ind] + tensor[..., ind + 2] / 2.0 # Set xmax
1279 | tensor1[..., ind +
1280 | 2] = tensor[..., ind +
1281 | 1] - tensor[..., ind + 3] / 2.0 # Set ymin
1282 | tensor1[..., ind +
1283 | 3] = tensor[..., ind +
1284 | 1] + tensor[..., ind + 3] / 2.0 # Set ymax
1285 | elif conversion == 'corners2centroids':
1286 | tensor1[..., ind] = (tensor[..., ind] +
1287 | tensor[..., ind + 2]) / 2.0 # Set cx
1288 | tensor1[..., ind + 1] = (tensor[..., ind + 1] +
1289 | tensor[..., ind + 3]) / 2.0 # Set cy
1290 | tensor1[..., ind +
1291 | 2] = tensor[..., ind + 2] - tensor[..., ind] + d # Set w
1292 | tensor1[..., ind +
1293 | 3] = tensor[..., ind + 3] - tensor[..., ind + 1] + d # Set h
1294 | elif conversion == 'centroids2corners':
1295 | tensor1[..., ind] = tensor[..., ind] - tensor[..., ind +
1296 | 2] / 2.0 # Set xmin
1297 | tensor1[..., ind +
1298 | 1] = tensor[..., ind +
1299 | 1] - tensor[..., ind + 3] / 2.0 # Set ymin
1300 | tensor1[..., ind +
1301 | 2] = tensor[..., ind] + tensor[..., ind + 2] / 2.0 # Set xmax
1302 | tensor1[..., ind +
1303 | 3] = tensor[..., ind +
1304 | 1] + tensor[..., ind + 3] / 2.0 # Set ymax
1305 | elif (conversion == 'minmax2corners') or (conversion == 'corners2minmax'):
1306 | tensor1[..., ind + 1] = tensor[..., ind + 2]
1307 | tensor1[..., ind + 2] = tensor[..., ind + 1]
1308 | else:
1309 | raise ValueError(
1310 | "Unexpected conversion value. Supported values are 'minmax2centroids', 'centroids2minmax', 'corners2centroids', 'centroids2corners', 'minmax2corners', and 'corners2minmax'."
1311 | )
1312 |
1313 | return tensor1
1314 |
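# Illustrative sketch (not part of the original file): converting one box from the
# 'centroids' format to the 'corners' format and back. With cx=50, cy=40, w=20, h=10,
# 'centroids2corners' yields (40, 35, 60, 45), and 'corners2centroids' with the default
# border_pixels='half' (d=0) recovers the original (50, 40, 20, 10).
def _example_convert_coordinates():
    box_centroids = np.array([[50.0, 40.0, 20.0, 10.0]])  # (cx, cy, w, h)
    box_corners = convert_coordinates(box_centroids,
                                      start_index=0,
                                      conversion='centroids2corners')
    box_back = convert_coordinates(box_corners,
                                   start_index=0,
                                   conversion='corners2centroids')
    return box_corners, box_back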
1315 |
1316 | def intersection_area_(boxes1,
1317 | boxes2,
1318 | coords='corners',
1319 | mode='outer_product',
1320 | border_pixels='half'):
1321 | """
1322 | The same as 'intersection_area()' but for internal use, i.e. without all the safety checks.
1323 | """
1324 |
1325 | m = boxes1.shape[0] # The number of boxes in `boxes1`
1326 | n = boxes2.shape[0] # The number of boxes in `boxes2`
1327 |
1328 | # Set the correct coordinate indices for the respective formats.
1329 | if coords == 'corners':
1330 | xmin = 0
1331 | ymin = 1
1332 | xmax = 2
1333 | ymax = 3
1334 | elif coords == 'minmax':
1335 | xmin = 0
1336 | xmax = 1
1337 | ymin = 2
1338 | ymax = 3
1339 |
1340 | if border_pixels == 'half':
1341 | d = 0
1342 | elif border_pixels == 'include':
1343 | d = 1 # If border pixels are supposed to belong to the bounding boxes, we have to add one pixel to any difference `xmax - xmin` or `ymax - ymin`.
1344 | elif border_pixels == 'exclude':
1345 | d = -1 # If border pixels are not supposed to belong to the bounding boxes, we have to subtract one pixel from any difference `xmax - xmin` or `ymax - ymin`.
1346 |
1347 | # Compute the intersection areas.
1348 |
1349 | if mode == 'outer_product':
1350 |
1351 | # For all possible box combinations, get the greater xmin and ymin values.
1352 | # This is a tensor of shape (m,n,2).
1353 | min_xy = np.maximum(
1354 | np.tile(np.expand_dims(boxes1[:, [xmin, ymin]], axis=1),
1355 | reps=(1, n, 1)),
1356 | np.tile(np.expand_dims(boxes2[:, [xmin, ymin]], axis=0),
1357 | reps=(m, 1, 1)))
1358 |
1359 | # For all possible box combinations, get the smaller xmax and ymax values.
1360 | # This is a tensor of shape (m,n,2).
1361 | max_xy = np.minimum(
1362 | np.tile(np.expand_dims(boxes1[:, [xmax, ymax]], axis=1),
1363 | reps=(1, n, 1)),
1364 | np.tile(np.expand_dims(boxes2[:, [xmax, ymax]], axis=0),
1365 | reps=(m, 1, 1)))
1366 |
1367 | # Compute the side lengths of the intersection rectangles.
1368 | side_lengths = np.maximum(0, max_xy - min_xy + d)
1369 |
1370 | return side_lengths[:, :, 0] * side_lengths[:, :, 1]
1371 |
1372 | elif mode == 'element-wise':
1373 |
1374 | min_xy = np.maximum(boxes1[:, [xmin, ymin]], boxes2[:, [xmin, ymin]])
1375 | max_xy = np.minimum(boxes1[:, [xmax, ymax]], boxes2[:, [xmax, ymax]])
1376 |
1377 | # Compute the side lengths of the intersection rectangles.
1378 | side_lengths = np.maximum(0, max_xy - min_xy + d)
1379 |
1380 | return side_lengths[:, 0] * side_lengths[:, 1]
1381 |
1382 |
1383 | def iou(boxes1,
1384 | boxes2,
1385 | coords='centroids',
1386 | mode='outer_product',
1387 | border_pixels='half'):
1388 | """
1389 | Computes the intersection-over-union similarity (also known as Jaccard similarity)
1390 | of two sets of axis-aligned 2D rectangular boxes.
1391 |
1392 | Let `boxes1` and `boxes2` contain `m` and `n` boxes, respectively.
1393 |
1394 | In 'outer_product' mode, returns an `(m,n)` matrix with the IoUs for all possible
1395 | combinations of the boxes in `boxes1` and `boxes2`.
1396 |
1397 | In 'element-wise' mode, `m` and `n` must be broadcast-compatible. Refer to the explanation
1398 | of the `mode` argument for details.
1399 |
1400 | Arguments:
1401 | boxes1 (array): Either a 1D Numpy array of shape `(4, )` containing the coordinates for one box in the
1402 | format specified by `coords` or a 2D Numpy array of shape `(m, 4)` containing the coordinates for `m` boxes.
1403 |             If `mode` is set to 'element-wise', the shape must be broadcast-compatible with `boxes2`.
1404 | boxes2 (array): Either a 1D Numpy array of shape `(4, )` containing the coordinates for one box in the
1405 | format specified by `coords` or a 2D Numpy array of shape `(n, 4)` containing the coordinates for `n` boxes.
1406 |             If `mode` is set to 'element-wise', the shape must be broadcast-compatible with `boxes1`.
1407 | coords (str, optional): The coordinate format in the input arrays. Can be either 'centroids' for the format
1408 | `(cx, cy, w, h)`, 'minmax' for the format `(xmin, xmax, ymin, ymax)`, or 'corners' for the format
1409 | `(xmin, ymin, xmax, ymax)`.
1410 | mode (str, optional): Can be one of 'outer_product' and 'element-wise'. In 'outer_product' mode, returns an
1411 | `(m,n)` matrix with the IoU overlaps for all possible combinations of the `m` boxes in `boxes1` with the
1412 | `n` boxes in `boxes2`. In 'element-wise' mode, returns a 1D array and the shapes of `boxes1` and `boxes2`
1413 |             must be broadcast-compatible. If both `boxes1` and `boxes2` have `m` boxes, then this returns an array of
1414 | length `m` where the i-th position contains the IoU overlap of `boxes1[i]` with `boxes2[i]`.
1415 | border_pixels (str, optional): How to treat the border pixels of the bounding boxes.
1416 | Can be 'include', 'exclude', or 'half'. If 'include', the border pixels belong
1417 | to the boxes. If 'exclude', the border pixels do not belong to the boxes.
1418 | If 'half', then one of each of the two horizontal and vertical borders belong
1419 |             to the boxes, but not the other.
1420 |
1421 | Returns:
1422 | A 1D or 2D Numpy array (refer to the `mode` argument for details) of dtype float containing values in [0,1],
1423 | the Jaccard similarity of the boxes in `boxes1` and `boxes2`. 0 means there is no overlap between two given
1424 | boxes, 1 means their coordinates are identical.
1425 | """
1426 |
1427 | # Make sure the boxes have the right shapes.
1428 | if boxes1.ndim > 2:
1429 | raise ValueError(
1430 | "boxes1 must have rank either 1 or 2, but has rank {}.".format(
1431 | boxes1.ndim))
1432 | if boxes2.ndim > 2:
1433 | raise ValueError(
1434 | "boxes2 must have rank either 1 or 2, but has rank {}.".format(
1435 | boxes2.ndim))
1436 |
1437 | if boxes1.ndim == 1: boxes1 = np.expand_dims(boxes1, axis=0)
1438 | if boxes2.ndim == 1: boxes2 = np.expand_dims(boxes2, axis=0)
1439 |
1440 | if not (boxes1.shape[1] == boxes2.shape[1] == 4):
1441 | raise ValueError(
1442 | "All boxes must consist of 4 coordinates, but the boxes in `boxes1` and `boxes2` have {} and {} coordinates, respectively."
1443 | .format(boxes1.shape[1], boxes2.shape[1]))
1444 | if not mode in {'outer_product', 'element-wise'}:
1445 | raise ValueError(
1446 | "`mode` must be one of 'outer_product' and 'element-wise', but got '{}'."
1447 | .format(mode))
1448 |
1449 | # Convert the coordinates if necessary.
1450 | if coords == 'centroids':
1451 | boxes1 = convert_coordinates(boxes1,
1452 | start_index=0,
1453 | conversion='centroids2corners')
1454 | boxes2 = convert_coordinates(boxes2,
1455 | start_index=0,
1456 | conversion='centroids2corners')
1457 | coords = 'corners'
1458 | elif not (coords in {'minmax', 'corners'}):
1459 | raise ValueError(
1460 | "Unexpected value for `coords`. Supported values are 'minmax', 'corners' and 'centroids'."
1461 | )
1462 |
1463 | # Compute the IoU.
1464 |
1465 |     # Compute the intersection areas.
1466 |
1467 | intersection_areas = intersection_area_(boxes1,
1468 | boxes2,
1469 | coords=coords,
1470 | mode=mode)
1471 |
1472 | m = boxes1.shape[0] # The number of boxes in `boxes1`
1473 | n = boxes2.shape[0] # The number of boxes in `boxes2`
1474 |
1475 | # Compute the union areas.
1476 |
1477 | # Set the correct coordinate indices for the respective formats.
1478 | if coords == 'corners':
1479 | xmin = 0
1480 | ymin = 1
1481 | xmax = 2
1482 | ymax = 3
1483 | elif coords == 'minmax':
1484 | xmin = 0
1485 | xmax = 1
1486 | ymin = 2
1487 | ymax = 3
1488 |
1489 | if border_pixels == 'half':
1490 | d = 0
1491 | elif border_pixels == 'include':
1492 | d = 1 # If border pixels are supposed to belong to the bounding boxes, we have to add one pixel to any difference `xmax - xmin` or `ymax - ymin`.
1493 | elif border_pixels == 'exclude':
1494 | d = -1 # If border pixels are not supposed to belong to the bounding boxes, we have to subtract one pixel from any difference `xmax - xmin` or `ymax - ymin`.
1495 |
1496 | if mode == 'outer_product':
1497 |
1498 | boxes1_areas = np.tile(np.expand_dims(
1499 | (boxes1[:, xmax] - boxes1[:, xmin] + d) *
1500 | (boxes1[:, ymax] - boxes1[:, ymin] + d),
1501 | axis=1),
1502 | reps=(1, n))
1503 | boxes2_areas = np.tile(np.expand_dims(
1504 | (boxes2[:, xmax] - boxes2[:, xmin] + d) *
1505 | (boxes2[:, ymax] - boxes2[:, ymin] + d),
1506 | axis=0),
1507 | reps=(m, 1))
1508 |
1509 | elif mode == 'element-wise':
1510 |
1511 | boxes1_areas = (boxes1[:, xmax] - boxes1[:, xmin] +
1512 | d) * (boxes1[:, ymax] - boxes1[:, ymin] + d)
1513 | boxes2_areas = (boxes2[:, xmax] - boxes2[:, xmin] +
1514 | d) * (boxes2[:, ymax] - boxes2[:, ymin] + d)
1515 |
1516 | union_areas = boxes1_areas + boxes2_areas - intersection_areas
1517 |
1518 | return intersection_areas / union_areas
1519 |
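# Illustrative sketch (not part of the original file): IoU of two boxes given in the
# 'corners' format. Boxes (0, 0, 10, 10) and (5, 5, 15, 15) overlap in a 5 x 5 patch
# with the default border_pixels='half', so intersection = 25, union = 100 + 100 - 25
# = 175, and the IoU is 25 / 175 = 1/7.
def _example_iou():
    boxes_a = np.array([[0.0, 0.0, 10.0, 10.0]])
    boxes_b = np.array([[5.0, 5.0, 15.0, 15.0]])
    return iou(boxes_a, boxes_b, coords='corners', mode='element-wise')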
--------------------------------------------------------------------------------