├── .gitignore ├── LICENSE ├── README.md ├── eval.py ├── evaluate.py ├── kitti_common.py └── rotate_iou.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # kitti-object-eval-python 2 | Fast kitti object detection eval in python(finish eval in less than 10 second), support 2d/bev/3d/aos. , support coco-style AP. If you use command line interface, numba need some time to compile jit functions. 3 | 4 | _WARNING_: The "coco" isn't official metrics. Only "AP(Average Precision)" is. 5 | ## Dependencies 6 | Only support python 3.6+, need `numpy`, `skimage`, `numba`, `fire`, `scipy`. If you have Anaconda, just install `cudatoolkit` in anaconda. Otherwise, please reference to this [page](https://github.com/numba/numba#custom-python-environments) to set up llvm and cuda for numba. 7 | * Install by conda: 8 | ``` 9 | conda install -c numba cudatoolkit=x.x (8.0, 9.0, 10.0, depend on your environment) 10 | ``` 11 | ## Usage 12 | * commandline interface: 13 | ``` 14 | python evaluate.py evaluate --label_path=/path/to/your_gt_label_folder --result_path=/path/to/your_result_folder --label_split_file=/path/to/val.txt --current_class=0 --coco=False 15 | ``` 16 | * python interface: 17 | ```Python 18 | import kitti_common as kitti 19 | from eval import get_official_eval_result, get_coco_eval_result 20 | def _read_imageset_file(path): 21 | with open(path, 'r') as f: 22 | lines = f.readlines() 23 | return [int(line) for line in lines] 24 | det_path = "/path/to/your_result_folder" 25 | dt_annos = kitti.get_label_annos(det_path) 26 | gt_path = "/path/to/your_gt_label_folder" 27 | gt_split_file = "/path/to/val.txt" # from https://xiaozhichen.github.io/files/mv3d/imagesets.tar.gz 28 | val_image_ids = _read_imageset_file(gt_split_file) 29 | gt_annos = kitti.get_label_annos(gt_path, val_image_ids) 30 | print(get_official_eval_result(gt_annos, dt_annos, 0)) # 6s in my computer 31 | print(get_coco_eval_result(gt_annos, dt_annos, 0)) # 18s in my computer 32 | ``` 33 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import io as sysio 2 | import time 3 | 4 | import numba 5 | import numpy as np 6 | from scipy.interpolate import interp1d 7 | 8 | from second.core.non_max_suppression.nms_gpu import rotate_iou_gpu_eval 9 | 10 | 11 | def get_mAP(prec): 12 | sums = 0 13 | for i in range(0, len(prec), 4): 14 | sums += prec[i] 15 | return sums / 11 * 100 16 | 17 | 18 | @numba.jit 19 | def get_thresholds(scores: np.ndarray, num_gt, num_sample_pts=41): 20 | scores.sort() 21 | scores = scores[::-1] 22 | current_recall = 0 23 | thresholds = [] 24 | for i, score in enumerate(scores): 25 | l_recall = (i + 1) / num_gt 26 | if i < (len(scores) - 1): 27 | r_recall = (i + 2) / num_gt 28 | else: 29 | r_recall = l_recall 30 | if (((r_recall - current_recall) < (current_recall - l_recall)) 31 | and (i < (len(scores) - 1))): 32 | continue 33 | # recall = l_recall 34 | thresholds.append(score) 35 | current_recall += 1 / (num_sample_pts - 1.0) 36 | # print(len(thresholds), len(scores), num_gt) 37 | return thresholds 38 | 39 | 40 | def clean_data(gt_anno, dt_anno, current_class, difficulty): 41 | CLASS_NAMES = [ 42 | 'car', 'pedestrian', 'cyclist', 'van', 'person_sitting', 'car', 43 | 'tractor', 'trailer' 44 | ] 45 | MIN_HEIGHT = [40, 25, 25] 46 | MAX_OCCLUSION = [0, 1, 2] 47 | MAX_TRUNCATION = [0.15, 0.3, 0.5] 48 | dc_bboxes, ignored_gt, ignored_dt = [], [], [] 49 | current_cls_name = CLASS_NAMES[current_class].lower() 50 | num_gt = len(gt_anno["name"]) 51 | num_dt = len(dt_anno["name"]) 52 | num_valid_gt = 0 53 | for i in range(num_gt): 54 | bbox = gt_anno["bbox"][i] 55 | gt_name = gt_anno["name"][i].lower() 56 | height = bbox[3] - bbox[1] 57 | valid_class = -1 58 | if (gt_name == current_cls_name): 59 | valid_class = 1 60 | elif (current_cls_name == "Pedestrian".lower() 61 | and "Person_sitting".lower() == gt_name): 62 | valid_class = 0 63 | elif (current_cls_name == "Car".lower() and "Van".lower() == gt_name): 64 | valid_class = 0 65 | else: 66 | valid_class = -1 67 | ignore = False 68 | if ((gt_anno["occluded"][i] > MAX_OCCLUSION[difficulty]) 69 | or (gt_anno["truncated"][i] > MAX_TRUNCATION[difficulty]) 70 | or (height <= MIN_HEIGHT[difficulty])): 71 | # if gt_anno["difficulty"][i] > difficulty or gt_anno["difficulty"][i] == -1: 72 | ignore = True 73 | if valid_class == 1 and not ignore: 74 | ignored_gt.append(0) 75 | num_valid_gt += 1 76 | elif (valid_class == 0 or (ignore and (valid_class == 1))): 77 | ignored_gt.append(1) 78 | else: 79 | ignored_gt.append(-1) 80 | # for i in range(num_gt): 81 | if gt_anno["name"][i] == "DontCare": 82 | dc_bboxes.append(gt_anno["bbox"][i]) 83 | for i in range(num_dt): 84 | if (dt_anno["name"][i].lower() == current_cls_name): 85 | valid_class = 1 86 | else: 87 | valid_class = -1 88 | height = abs(dt_anno["bbox"][i, 3] - dt_anno["bbox"][i, 1]) 89 | if height < MIN_HEIGHT[difficulty]: 90 | ignored_dt.append(1) 91 | elif valid_class == 1: 92 | ignored_dt.append(0) 93 | else: 94 | ignored_dt.append(-1) 95 | 96 | return num_valid_gt, ignored_gt, ignored_dt, dc_bboxes 97 | 98 | 99 | @numba.jit(nopython=True) 100 | def image_box_overlap(boxes, query_boxes, criterion=-1): 101 | N = boxes.shape[0] 102 | K = query_boxes.shape[0] 103 | overlaps = np.zeros((N, K), dtype=boxes.dtype) 104 | for k in range(K): 105 | qbox_area = ((query_boxes[k, 2] - query_boxes[k, 0]) * 106 | (query_boxes[k, 3] - query_boxes[k, 1])) 107 | for n in range(N): 108 | iw = (min(boxes[n, 2], query_boxes[k, 2]) - max( 109 | boxes[n, 0], query_boxes[k, 0])) 110 | if iw > 0: 111 | ih = (min(boxes[n, 3], query_boxes[k, 3]) - max( 112 | boxes[n, 1], query_boxes[k, 1])) 113 | if ih > 0: 114 | if criterion == -1: 115 | ua = ( 116 | (boxes[n, 2] - boxes[n, 0]) * 117 | (boxes[n, 3] - boxes[n, 1]) + qbox_area - iw * ih) 118 | elif criterion == 0: 119 | ua = ((boxes[n, 2] - boxes[n, 0]) * 120 | (boxes[n, 3] - boxes[n, 1])) 121 | elif criterion == 1: 122 | ua = qbox_area 123 | else: 124 | ua = 1.0 125 | overlaps[n, k] = iw * ih / ua 126 | return overlaps 127 | 128 | 129 | def bev_box_overlap(boxes, qboxes, criterion=-1): 130 | riou = rotate_iou_gpu_eval(boxes, qboxes, criterion) 131 | return riou 132 | 133 | 134 | @numba.jit(nopython=True, parallel=True) 135 | def d3_box_overlap_kernel(boxes, 136 | qboxes, 137 | rinc, 138 | criterion=-1, 139 | z_axis=1, 140 | z_center=1.0): 141 | """ 142 | z_axis: the z (height) axis. 143 | z_center: unified z (height) center of box. 144 | """ 145 | N, K = boxes.shape[0], qboxes.shape[0] 146 | for i in range(N): 147 | for j in range(K): 148 | if rinc[i, j] > 0: 149 | min_z = min( 150 | boxes[i, z_axis] + boxes[i, z_axis + 3] * (1 - z_center), 151 | qboxes[j, z_axis] + qboxes[j, z_axis + 3] * (1 - z_center)) 152 | max_z = max( 153 | boxes[i, z_axis] - boxes[i, z_axis + 3] * z_center, 154 | qboxes[j, z_axis] - qboxes[j, z_axis + 3] * z_center) 155 | iw = min_z - max_z 156 | if iw > 0: 157 | area1 = boxes[i, 3] * boxes[i, 4] * boxes[i, 5] 158 | area2 = qboxes[j, 3] * qboxes[j, 4] * qboxes[j, 5] 159 | inc = iw * rinc[i, j] 160 | if criterion == -1: 161 | ua = (area1 + area2 - inc) 162 | elif criterion == 0: 163 | ua = area1 164 | elif criterion == 1: 165 | ua = area2 166 | else: 167 | ua = 1.0 168 | rinc[i, j] = inc / ua 169 | else: 170 | rinc[i, j] = 0.0 171 | 172 | 173 | def d3_box_overlap(boxes, qboxes, criterion=-1, z_axis=1, z_center=1.0): 174 | """kitti camera format z_axis=1. 175 | """ 176 | bev_axes = list(range(7)) 177 | bev_axes.pop(z_axis + 3) 178 | bev_axes.pop(z_axis) 179 | rinc = rotate_iou_gpu_eval(boxes[:, bev_axes], qboxes[:, bev_axes], 2) 180 | d3_box_overlap_kernel(boxes, qboxes, rinc, criterion, z_axis, z_center) 181 | return rinc 182 | 183 | 184 | @numba.jit(nopython=True) 185 | def compute_statistics_jit(overlaps, 186 | gt_datas, 187 | dt_datas, 188 | ignored_gt, 189 | ignored_det, 190 | dc_bboxes, 191 | metric, 192 | min_overlap, 193 | thresh=0, 194 | compute_fp=False, 195 | compute_aos=False): 196 | 197 | det_size = dt_datas.shape[0] 198 | gt_size = gt_datas.shape[0] 199 | dt_scores = dt_datas[:, -1] 200 | dt_alphas = dt_datas[:, 4] 201 | gt_alphas = gt_datas[:, 4] 202 | dt_bboxes = dt_datas[:, :4] 203 | # gt_bboxes = gt_datas[:, :4] 204 | 205 | assigned_detection = [False] * det_size 206 | ignored_threshold = [False] * det_size 207 | if compute_fp: 208 | for i in range(det_size): 209 | if (dt_scores[i] < thresh): 210 | ignored_threshold[i] = True 211 | NO_DETECTION = -10000000 212 | tp, fp, fn, similarity = 0, 0, 0, 0 213 | # thresholds = [0.0] 214 | # delta = [0.0] 215 | thresholds = np.zeros((gt_size, )) 216 | thresh_idx = 0 217 | delta = np.zeros((gt_size, )) 218 | delta_idx = 0 219 | for i in range(gt_size): 220 | if ignored_gt[i] == -1: 221 | continue 222 | det_idx = -1 223 | valid_detection = NO_DETECTION 224 | max_overlap = 0 225 | assigned_ignored_det = False 226 | 227 | for j in range(det_size): 228 | if (ignored_det[j] == -1): 229 | continue 230 | if (assigned_detection[j]): 231 | continue 232 | if (ignored_threshold[j]): 233 | continue 234 | overlap = overlaps[j, i] 235 | dt_score = dt_scores[j] 236 | if (not compute_fp and (overlap > min_overlap) 237 | and dt_score > valid_detection): 238 | det_idx = j 239 | valid_detection = dt_score 240 | elif (compute_fp and (overlap > min_overlap) 241 | and (overlap > max_overlap or assigned_ignored_det) 242 | and ignored_det[j] == 0): 243 | max_overlap = overlap 244 | det_idx = j 245 | valid_detection = 1 246 | assigned_ignored_det = False 247 | elif (compute_fp and (overlap > min_overlap) 248 | and (valid_detection == NO_DETECTION) 249 | and ignored_det[j] == 1): 250 | det_idx = j 251 | valid_detection = 1 252 | assigned_ignored_det = True 253 | 254 | if (valid_detection == NO_DETECTION) and ignored_gt[i] == 0: 255 | fn += 1 256 | elif ((valid_detection != NO_DETECTION) 257 | and (ignored_gt[i] == 1 or ignored_det[det_idx] == 1)): 258 | assigned_detection[det_idx] = True 259 | elif valid_detection != NO_DETECTION: 260 | # only a tp add a threshold. 261 | tp += 1 262 | # thresholds.append(dt_scores[det_idx]) 263 | thresholds[thresh_idx] = dt_scores[det_idx] 264 | thresh_idx += 1 265 | if compute_aos: 266 | # delta.append(gt_alphas[i] - dt_alphas[det_idx]) 267 | delta[delta_idx] = gt_alphas[i] - dt_alphas[det_idx] 268 | delta_idx += 1 269 | 270 | assigned_detection[det_idx] = True 271 | if compute_fp: 272 | for i in range(det_size): 273 | if (not (assigned_detection[i] or ignored_det[i] == -1 274 | or ignored_det[i] == 1 or ignored_threshold[i])): 275 | fp += 1 276 | nstuff = 0 277 | if metric == 0: 278 | overlaps_dt_dc = image_box_overlap(dt_bboxes, dc_bboxes, 0) 279 | for i in range(dc_bboxes.shape[0]): 280 | for j in range(det_size): 281 | if (assigned_detection[j]): 282 | continue 283 | if (ignored_det[j] == -1 or ignored_det[j] == 1): 284 | continue 285 | if (ignored_threshold[j]): 286 | continue 287 | if overlaps_dt_dc[j, i] > min_overlap: 288 | assigned_detection[j] = True 289 | nstuff += 1 290 | fp -= nstuff 291 | if compute_aos: 292 | tmp = np.zeros((fp + delta_idx, )) 293 | # tmp = [0] * fp 294 | for i in range(delta_idx): 295 | tmp[i + fp] = (1.0 + np.cos(delta[i])) / 2.0 296 | # tmp.append((1.0 + np.cos(delta[i])) / 2.0) 297 | # assert len(tmp) == fp + tp 298 | # assert len(delta) == tp 299 | if tp > 0 or fp > 0: 300 | similarity = np.sum(tmp) 301 | else: 302 | similarity = -1 303 | return tp, fp, fn, similarity, thresholds[:thresh_idx] 304 | 305 | 306 | def get_split_parts(num, num_part): 307 | same_part = num // num_part 308 | remain_num = num % num_part 309 | if remain_num == 0: 310 | return [same_part] * num_part 311 | else: 312 | return [same_part] * num_part + [remain_num] 313 | 314 | 315 | @numba.jit(nopython=True) 316 | def fused_compute_statistics(overlaps, 317 | pr, 318 | gt_nums, 319 | dt_nums, 320 | dc_nums, 321 | gt_datas, 322 | dt_datas, 323 | dontcares, 324 | ignored_gts, 325 | ignored_dets, 326 | metric, 327 | min_overlap, 328 | thresholds, 329 | compute_aos=False): 330 | gt_num = 0 331 | dt_num = 0 332 | dc_num = 0 333 | for i in range(gt_nums.shape[0]): 334 | for t, thresh in enumerate(thresholds): 335 | overlap = overlaps[dt_num:dt_num + dt_nums[i], gt_num:gt_num + 336 | gt_nums[i]] 337 | 338 | gt_data = gt_datas[gt_num:gt_num + gt_nums[i]] 339 | dt_data = dt_datas[dt_num:dt_num + dt_nums[i]] 340 | ignored_gt = ignored_gts[gt_num:gt_num + gt_nums[i]] 341 | ignored_det = ignored_dets[dt_num:dt_num + dt_nums[i]] 342 | dontcare = dontcares[dc_num:dc_num + dc_nums[i]] 343 | tp, fp, fn, similarity, _ = compute_statistics_jit( 344 | overlap, 345 | gt_data, 346 | dt_data, 347 | ignored_gt, 348 | ignored_det, 349 | dontcare, 350 | metric, 351 | min_overlap=min_overlap, 352 | thresh=thresh, 353 | compute_fp=True, 354 | compute_aos=compute_aos) 355 | pr[t, 0] += tp 356 | pr[t, 1] += fp 357 | pr[t, 2] += fn 358 | if similarity != -1: 359 | pr[t, 3] += similarity 360 | gt_num += gt_nums[i] 361 | dt_num += dt_nums[i] 362 | dc_num += dc_nums[i] 363 | 364 | 365 | def calculate_iou_partly(gt_annos, 366 | dt_annos, 367 | metric, 368 | num_parts=50, 369 | z_axis=1, 370 | z_center=1.0): 371 | """fast iou algorithm. this function can be used independently to 372 | do result analysis. 373 | Args: 374 | gt_annos: dict, must from get_label_annos() in kitti_common.py 375 | dt_annos: dict, must from get_label_annos() in kitti_common.py 376 | metric: eval type. 0: bbox, 1: bev, 2: 3d 377 | num_parts: int. a parameter for fast calculate algorithm 378 | z_axis: height axis. kitti camera use 1, lidar use 2. 379 | """ 380 | assert len(gt_annos) == len(dt_annos) 381 | total_dt_num = np.stack([len(a["name"]) for a in dt_annos], 0) 382 | total_gt_num = np.stack([len(a["name"]) for a in gt_annos], 0) 383 | num_examples = len(gt_annos) 384 | split_parts = get_split_parts(num_examples, num_parts) 385 | parted_overlaps = [] 386 | example_idx = 0 387 | bev_axes = list(range(3)) 388 | bev_axes.pop(z_axis) 389 | for num_part in split_parts: 390 | gt_annos_part = gt_annos[example_idx:example_idx + num_part] 391 | dt_annos_part = dt_annos[example_idx:example_idx + num_part] 392 | if metric == 0: 393 | gt_boxes = np.concatenate([a["bbox"] for a in gt_annos_part], 0) 394 | dt_boxes = np.concatenate([a["bbox"] for a in dt_annos_part], 0) 395 | overlap_part = image_box_overlap(gt_boxes, dt_boxes) 396 | elif metric == 1: 397 | loc = np.concatenate( 398 | [a["location"][:, bev_axes] for a in gt_annos_part], 0) 399 | dims = np.concatenate( 400 | [a["dimensions"][:, bev_axes] for a in gt_annos_part], 0) 401 | rots = np.concatenate([a["rotation_y"] for a in gt_annos_part], 0) 402 | gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]], 403 | axis=1) 404 | loc = np.concatenate( 405 | [a["location"][:, bev_axes] for a in dt_annos_part], 0) 406 | dims = np.concatenate( 407 | [a["dimensions"][:, bev_axes] for a in dt_annos_part], 0) 408 | rots = np.concatenate([a["rotation_y"] for a in dt_annos_part], 0) 409 | dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]], 410 | axis=1) 411 | overlap_part = bev_box_overlap(gt_boxes, 412 | dt_boxes).astype(np.float64) 413 | elif metric == 2: 414 | loc = np.concatenate([a["location"] for a in gt_annos_part], 0) 415 | dims = np.concatenate([a["dimensions"] for a in gt_annos_part], 0) 416 | rots = np.concatenate([a["rotation_y"] for a in gt_annos_part], 0) 417 | gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]], 418 | axis=1) 419 | loc = np.concatenate([a["location"] for a in dt_annos_part], 0) 420 | dims = np.concatenate([a["dimensions"] for a in dt_annos_part], 0) 421 | rots = np.concatenate([a["rotation_y"] for a in dt_annos_part], 0) 422 | dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]], 423 | axis=1) 424 | overlap_part = d3_box_overlap( 425 | gt_boxes, dt_boxes, z_axis=z_axis, 426 | z_center=z_center).astype(np.float64) 427 | else: 428 | raise ValueError("unknown metric") 429 | parted_overlaps.append(overlap_part) 430 | example_idx += num_part 431 | overlaps = [] 432 | example_idx = 0 433 | for j, num_part in enumerate(split_parts): 434 | gt_annos_part = gt_annos[example_idx:example_idx + num_part] 435 | dt_annos_part = dt_annos[example_idx:example_idx + num_part] 436 | gt_num_idx, dt_num_idx = 0, 0 437 | for i in range(num_part): 438 | gt_box_num = total_gt_num[example_idx + i] 439 | dt_box_num = total_dt_num[example_idx + i] 440 | overlaps.append( 441 | parted_overlaps[j][gt_num_idx:gt_num_idx + 442 | gt_box_num, dt_num_idx:dt_num_idx + 443 | dt_box_num]) 444 | gt_num_idx += gt_box_num 445 | dt_num_idx += dt_box_num 446 | example_idx += num_part 447 | 448 | return overlaps, parted_overlaps, total_gt_num, total_dt_num 449 | 450 | 451 | def _prepare_data(gt_annos, dt_annos, current_class, difficulty): 452 | gt_datas_list = [] 453 | dt_datas_list = [] 454 | total_dc_num = [] 455 | ignored_gts, ignored_dets, dontcares = [], [], [] 456 | total_num_valid_gt = 0 457 | for i in range(len(gt_annos)): 458 | rets = clean_data(gt_annos[i], dt_annos[i], current_class, difficulty) 459 | num_valid_gt, ignored_gt, ignored_det, dc_bboxes = rets 460 | ignored_gts.append(np.array(ignored_gt, dtype=np.int64)) 461 | ignored_dets.append(np.array(ignored_det, dtype=np.int64)) 462 | if len(dc_bboxes) == 0: 463 | dc_bboxes = np.zeros((0, 4)).astype(np.float64) 464 | else: 465 | dc_bboxes = np.stack(dc_bboxes, 0).astype(np.float64) 466 | total_dc_num.append(dc_bboxes.shape[0]) 467 | dontcares.append(dc_bboxes) 468 | total_num_valid_gt += num_valid_gt 469 | gt_datas = np.concatenate( 470 | [gt_annos[i]["bbox"], gt_annos[i]["alpha"][..., np.newaxis]], 1) 471 | dt_datas = np.concatenate([ 472 | dt_annos[i]["bbox"], dt_annos[i]["alpha"][..., np.newaxis], 473 | dt_annos[i]["score"][..., np.newaxis] 474 | ], 1) 475 | gt_datas_list.append(gt_datas) 476 | dt_datas_list.append(dt_datas) 477 | total_dc_num = np.stack(total_dc_num, axis=0) 478 | return (gt_datas_list, dt_datas_list, ignored_gts, ignored_dets, dontcares, 479 | total_dc_num, total_num_valid_gt) 480 | 481 | 482 | def eval_class(gt_annos, 483 | dt_annos, 484 | current_classes, 485 | difficultys, 486 | metric, 487 | min_overlaps, 488 | compute_aos=False, 489 | z_axis=1, 490 | z_center=1.0, 491 | num_parts=50): 492 | """Kitti eval. support 2d/bev/3d/aos eval. support 0.5:0.05:0.95 coco AP. 493 | Args: 494 | gt_annos: dict, must from get_label_annos() in kitti_common.py 495 | dt_annos: dict, must from get_label_annos() in kitti_common.py 496 | current_class: int, 0: car, 1: pedestrian, 2: cyclist 497 | difficulty: int. eval difficulty, 0: easy, 1: normal, 2: hard 498 | metric: eval type. 0: bbox, 1: bev, 2: 3d 499 | min_overlap: float, min overlap. official: 500 | [[0.7, 0.5, 0.5], [0.7, 0.5, 0.5], [0.7, 0.5, 0.5]] 501 | format: [metric, class]. choose one from matrix above. 502 | num_parts: int. a parameter for fast calculate algorithm 503 | 504 | Returns: 505 | dict of recall, precision and aos 506 | """ 507 | assert len(gt_annos) == len(dt_annos) 508 | num_examples = len(gt_annos) 509 | split_parts = get_split_parts(num_examples, num_parts) 510 | 511 | rets = calculate_iou_partly( 512 | dt_annos, 513 | gt_annos, 514 | metric, 515 | num_parts, 516 | z_axis=z_axis, 517 | z_center=z_center) 518 | overlaps, parted_overlaps, total_dt_num, total_gt_num = rets 519 | N_SAMPLE_PTS = 41 520 | num_minoverlap = len(min_overlaps) 521 | num_class = len(current_classes) 522 | num_difficulty = len(difficultys) 523 | precision = np.zeros( 524 | [num_class, num_difficulty, num_minoverlap, N_SAMPLE_PTS]) 525 | recall = np.zeros( 526 | [num_class, num_difficulty, num_minoverlap, N_SAMPLE_PTS]) 527 | aos = np.zeros([num_class, num_difficulty, num_minoverlap, N_SAMPLE_PTS]) 528 | all_thresholds = np.zeros([num_class, num_difficulty, num_minoverlap, N_SAMPLE_PTS]) 529 | for m, current_class in enumerate(current_classes): 530 | for l, difficulty in enumerate(difficultys): 531 | rets = _prepare_data(gt_annos, dt_annos, current_class, difficulty) 532 | (gt_datas_list, dt_datas_list, ignored_gts, ignored_dets, 533 | dontcares, total_dc_num, total_num_valid_gt) = rets 534 | for k, min_overlap in enumerate(min_overlaps[:, metric, m]): 535 | thresholdss = [] 536 | for i in range(len(gt_annos)): 537 | rets = compute_statistics_jit( 538 | overlaps[i], 539 | gt_datas_list[i], 540 | dt_datas_list[i], 541 | ignored_gts[i], 542 | ignored_dets[i], 543 | dontcares[i], 544 | metric, 545 | min_overlap=min_overlap, 546 | thresh=0.0, 547 | compute_fp=False) 548 | tp, fp, fn, similarity, thresholds = rets 549 | thresholdss += thresholds.tolist() 550 | thresholdss = np.array(thresholdss) 551 | thresholds = get_thresholds(thresholdss, total_num_valid_gt) 552 | thresholds = np.array(thresholds) 553 | all_thresholds[m, l, k, :len(thresholds)] = thresholds 554 | pr = np.zeros([len(thresholds), 4]) 555 | idx = 0 556 | for j, num_part in enumerate(split_parts): 557 | gt_datas_part = np.concatenate( 558 | gt_datas_list[idx:idx + num_part], 0) 559 | dt_datas_part = np.concatenate( 560 | dt_datas_list[idx:idx + num_part], 0) 561 | dc_datas_part = np.concatenate( 562 | dontcares[idx:idx + num_part], 0) 563 | ignored_dets_part = np.concatenate( 564 | ignored_dets[idx:idx + num_part], 0) 565 | ignored_gts_part = np.concatenate( 566 | ignored_gts[idx:idx + num_part], 0) 567 | fused_compute_statistics( 568 | parted_overlaps[j], 569 | pr, 570 | total_gt_num[idx:idx + num_part], 571 | total_dt_num[idx:idx + num_part], 572 | total_dc_num[idx:idx + num_part], 573 | gt_datas_part, 574 | dt_datas_part, 575 | dc_datas_part, 576 | ignored_gts_part, 577 | ignored_dets_part, 578 | metric, 579 | min_overlap=min_overlap, 580 | thresholds=thresholds, 581 | compute_aos=compute_aos) 582 | idx += num_part 583 | for i in range(len(thresholds)): 584 | precision[m, l, k, i] = pr[i, 0] / (pr[i, 0] + pr[i, 1]) 585 | if compute_aos: 586 | aos[m, l, k, i] = pr[i, 3] / (pr[i, 0] + pr[i, 1]) 587 | for i in range(len(thresholds)): 588 | precision[m, l, k, i] = np.max( 589 | precision[m, l, k, i:], axis=-1) 590 | if compute_aos: 591 | aos[m, l, k, i] = np.max(aos[m, l, k, i:], axis=-1) 592 | 593 | ret_dict = { 594 | # "recall": recall, # [num_class, num_difficulty, num_minoverlap, N_SAMPLE_PTS] 595 | "precision": precision, 596 | "orientation": aos, 597 | "thresholds": all_thresholds, 598 | "min_overlaps": min_overlaps, 599 | } 600 | return ret_dict 601 | 602 | 603 | def get_mAP_v2(prec): 604 | sums = 0 605 | for i in range(0, prec.shape[-1], 4): 606 | sums = sums + prec[..., i] 607 | return sums / 11 * 100 608 | 609 | 610 | def do_eval_v2(gt_annos, 611 | dt_annos, 612 | current_classes, 613 | min_overlaps, 614 | compute_aos=False, 615 | difficultys=(0, 1, 2), 616 | z_axis=1, 617 | z_center=1.0): 618 | # min_overlaps: [num_minoverlap, metric, num_class] 619 | ret = eval_class( 620 | gt_annos, 621 | dt_annos, 622 | current_classes, 623 | difficultys, 624 | 0, 625 | min_overlaps, 626 | compute_aos, 627 | z_axis=z_axis, 628 | z_center=z_center) 629 | # ret: [num_class, num_diff, num_minoverlap, num_sample_points] 630 | mAP_bbox = get_mAP_v2(ret["precision"]) 631 | mAP_aos = None 632 | if compute_aos: 633 | mAP_aos = get_mAP_v2(ret["orientation"]) 634 | ret = eval_class( 635 | gt_annos, 636 | dt_annos, 637 | current_classes, 638 | difficultys, 639 | 1, 640 | min_overlaps, 641 | z_axis=z_axis, 642 | z_center=z_center) 643 | mAP_bev = get_mAP_v2(ret["precision"]) 644 | ret = eval_class( 645 | gt_annos, 646 | dt_annos, 647 | current_classes, 648 | difficultys, 649 | 2, 650 | min_overlaps, 651 | z_axis=z_axis, 652 | z_center=z_center) 653 | mAP_3d = get_mAP_v2(ret["precision"]) 654 | return mAP_bbox, mAP_bev, mAP_3d, mAP_aos 655 | 656 | def do_eval_v3(gt_annos, 657 | dt_annos, 658 | current_classes, 659 | min_overlaps, 660 | compute_aos=False, 661 | difficultys=(0, 1, 2), 662 | z_axis=1, 663 | z_center=1.0): 664 | # min_overlaps: [num_minoverlap, metric, num_class] 665 | types = ["bbox", "bev", "3d"] 666 | metrics = {} 667 | for i in range(3): 668 | ret = eval_class( 669 | gt_annos, 670 | dt_annos, 671 | current_classes, 672 | difficultys, 673 | i, 674 | min_overlaps, 675 | compute_aos, 676 | z_axis=z_axis, 677 | z_center=z_center) 678 | metrics[types[i]] = ret 679 | return metrics 680 | 681 | 682 | def do_coco_style_eval(gt_annos, 683 | dt_annos, 684 | current_classes, 685 | overlap_ranges, 686 | compute_aos, 687 | z_axis=1, 688 | z_center=1.0): 689 | # overlap_ranges: [range, metric, num_class] 690 | min_overlaps = np.zeros([10, *overlap_ranges.shape[1:]]) 691 | for i in range(overlap_ranges.shape[1]): 692 | for j in range(overlap_ranges.shape[2]): 693 | min_overlaps[:, i, j] = np.linspace(*overlap_ranges[:, i, j]) 694 | mAP_bbox, mAP_bev, mAP_3d, mAP_aos = do_eval_v2( 695 | gt_annos, 696 | dt_annos, 697 | current_classes, 698 | min_overlaps, 699 | compute_aos, 700 | z_axis=z_axis, 701 | z_center=z_center) 702 | # ret: [num_class, num_diff, num_minoverlap] 703 | mAP_bbox = mAP_bbox.mean(-1) 704 | mAP_bev = mAP_bev.mean(-1) 705 | mAP_3d = mAP_3d.mean(-1) 706 | if mAP_aos is not None: 707 | mAP_aos = mAP_aos.mean(-1) 708 | return mAP_bbox, mAP_bev, mAP_3d, mAP_aos 709 | 710 | 711 | def print_str(value, *arg, sstream=None): 712 | if sstream is None: 713 | sstream = sysio.StringIO() 714 | sstream.truncate(0) 715 | sstream.seek(0) 716 | print(value, *arg, file=sstream) 717 | return sstream.getvalue() 718 | 719 | def get_official_eval_result(gt_annos, 720 | dt_annos, 721 | current_classes, 722 | difficultys=[0, 1, 2], 723 | z_axis=1, 724 | z_center=1.0): 725 | """ 726 | gt_annos and dt_annos must contains following keys: 727 | [bbox, location, dimensions, rotation_y, score] 728 | """ 729 | overlap_mod = np.array([[0.7, 0.5, 0.5, 0.7, 0.5, 0.7, 0.7, 0.7], 730 | [0.7, 0.5, 0.5, 0.7, 0.5, 0.7, 0.7, 0.7], 731 | [0.7, 0.5, 0.5, 0.7, 0.5, 0.7, 0.7, 0.7]]) 732 | overlap_easy = np.array([[0.7, 0.5, 0.5, 0.7, 0.5, 0.5, 0.5, 0.5], 733 | [0.5, 0.25, 0.25, 0.5, 0.25, 0.5, 0.5, 0.5], 734 | [0.5, 0.25, 0.25, 0.5, 0.25, 0.5, 0.5, 0.5]]) 735 | min_overlaps = np.stack([overlap_mod, overlap_easy], axis=0) # [2, 3, 5] 736 | class_to_name = { 737 | 0: 'Car', 738 | 1: 'Pedestrian', 739 | 2: 'Cyclist', 740 | 3: 'Van', 741 | 4: 'Person_sitting', 742 | 5: 'car', 743 | 6: 'tractor', 744 | 7: 'trailer', 745 | } 746 | name_to_class = {v: n for n, v in class_to_name.items()} 747 | if not isinstance(current_classes, (list, tuple)): 748 | current_classes = [current_classes] 749 | current_classes_int = [] 750 | for curcls in current_classes: 751 | if isinstance(curcls, str): 752 | current_classes_int.append(name_to_class[curcls]) 753 | else: 754 | current_classes_int.append(curcls) 755 | current_classes = current_classes_int 756 | min_overlaps = min_overlaps[:, :, current_classes] 757 | result = '' 758 | # check whether alpha is valid 759 | compute_aos = False 760 | for anno in dt_annos: 761 | if anno['alpha'].shape[0] != 0: 762 | if anno['alpha'][0] != -10: 763 | compute_aos = True 764 | break 765 | metrics = do_eval_v3( 766 | gt_annos, 767 | dt_annos, 768 | current_classes, 769 | min_overlaps, 770 | compute_aos, 771 | difficultys, 772 | z_axis=z_axis, 773 | z_center=z_center) 774 | for j, curcls in enumerate(current_classes): 775 | # mAP threshold array: [num_minoverlap, metric, class] 776 | # mAP result: [num_class, num_diff, num_minoverlap] 777 | for i in range(min_overlaps.shape[0]): 778 | mAPbbox = get_mAP_v2(metrics["bbox"]["precision"][j, :, i]) 779 | mAPbbox = ", ".join(f"{v:.2f}" for v in mAPbbox) 780 | mAPbev = get_mAP_v2(metrics["bev"]["precision"][j, :, i]) 781 | mAPbev = ", ".join(f"{v:.2f}" for v in mAPbev) 782 | mAP3d = get_mAP_v2(metrics["3d"]["precision"][j, :, i]) 783 | mAP3d = ", ".join(f"{v:.2f}" for v in mAP3d) 784 | result += print_str( 785 | (f"{class_to_name[curcls]} " 786 | "AP(Average Precision)@{:.2f}, {:.2f}, {:.2f}:".format(*min_overlaps[i, :, j]))) 787 | result += print_str(f"bbox AP:{mAPbbox}") 788 | result += print_str(f"bev AP:{mAPbev}") 789 | result += print_str(f"3d AP:{mAP3d}") 790 | if compute_aos: 791 | mAPaos = get_mAP_v2(metrics["bbox"]["orientation"][j, :, i]) 792 | mAPaos = ", ".join(f"{v:.2f}" for v in mAPaos) 793 | result += print_str(f"aos AP:{mAPaos}") 794 | 795 | 796 | return result 797 | 798 | 799 | def get_coco_eval_result(gt_annos, 800 | dt_annos, 801 | current_classes, 802 | z_axis=1, 803 | z_center=1.0): 804 | class_to_name = { 805 | 0: 'Car', 806 | 1: 'Pedestrian', 807 | 2: 'Cyclist', 808 | 3: 'Van', 809 | 4: 'Person_sitting', 810 | 5: 'car', 811 | 6: 'tractor', 812 | 7: 'trailer', 813 | } 814 | class_to_range = { 815 | 0: [0.5, 1.0, 0.05], 816 | 1: [0.25, 0.75, 0.05], 817 | 2: [0.25, 0.75, 0.05], 818 | 3: [0.5, 1.0, 0.05], 819 | 4: [0.25, 0.75, 0.05], 820 | 5: [0.5, 1.0, 0.05], 821 | 6: [0.5, 1.0, 0.05], 822 | 7: [0.5, 1.0, 0.05], 823 | } 824 | class_to_range = { 825 | 0: [0.5, 0.95, 10], 826 | 1: [0.25, 0.7, 10], 827 | 2: [0.25, 0.7, 10], 828 | 3: [0.5, 0.95, 10], 829 | 4: [0.25, 0.7, 10], 830 | 5: [0.5, 0.95, 10], 831 | 6: [0.5, 0.95, 10], 832 | 7: [0.5, 0.95, 10], 833 | } 834 | 835 | name_to_class = {v: n for n, v in class_to_name.items()} 836 | if not isinstance(current_classes, (list, tuple)): 837 | current_classes = [current_classes] 838 | current_classes_int = [] 839 | for curcls in current_classes: 840 | if isinstance(curcls, str): 841 | current_classes_int.append(name_to_class[curcls]) 842 | else: 843 | current_classes_int.append(curcls) 844 | current_classes = current_classes_int 845 | overlap_ranges = np.zeros([3, 3, len(current_classes)]) 846 | for i, curcls in enumerate(current_classes): 847 | overlap_ranges[:, :, i] = np.array( 848 | class_to_range[curcls])[:, np.newaxis] 849 | result = '' 850 | # check whether alpha is valid 851 | compute_aos = False 852 | for anno in dt_annos: 853 | if anno['alpha'].shape[0] != 0: 854 | if anno['alpha'][0] != -10: 855 | compute_aos = True 856 | break 857 | mAPbbox, mAPbev, mAP3d, mAPaos = do_coco_style_eval( 858 | gt_annos, 859 | dt_annos, 860 | current_classes, 861 | overlap_ranges, 862 | compute_aos, 863 | z_axis=z_axis, 864 | z_center=z_center) 865 | for j, curcls in enumerate(current_classes): 866 | # mAP threshold array: [num_minoverlap, metric, class] 867 | # mAP result: [num_class, num_diff, num_minoverlap] 868 | o_range = np.array(class_to_range[curcls])[[0, 2, 1]] 869 | o_range[1] = (o_range[2] - o_range[0]) / (o_range[1] - 1) 870 | result += print_str((f"{class_to_name[curcls]} " 871 | "coco AP@{:.2f}:{:.2f}:{:.2f}:".format(*o_range))) 872 | result += print_str((f"bbox AP:{mAPbbox[j, 0]:.2f}, " 873 | f"{mAPbbox[j, 1]:.2f}, " 874 | f"{mAPbbox[j, 2]:.2f}")) 875 | result += print_str((f"bev AP:{mAPbev[j, 0]:.2f}, " 876 | f"{mAPbev[j, 1]:.2f}, " 877 | f"{mAPbev[j, 2]:.2f}")) 878 | result += print_str((f"3d AP:{mAP3d[j, 0]:.2f}, " 879 | f"{mAP3d[j, 1]:.2f}, " 880 | f"{mAP3d[j, 2]:.2f}")) 881 | if compute_aos: 882 | result += print_str((f"aos AP:{mAPaos[j, 0]:.2f}, " 883 | f"{mAPaos[j, 1]:.2f}, " 884 | f"{mAPaos[j, 2]:.2f}")) 885 | return result -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import time 2 | import fire 3 | import kitti_common as kitti 4 | from eval import get_official_eval_result, get_coco_eval_result 5 | 6 | 7 | def _read_imageset_file(path): 8 | with open(path, 'r') as f: 9 | lines = f.readlines() 10 | return [int(line) for line in lines] 11 | 12 | 13 | def evaluate(label_path, 14 | result_path, 15 | label_split_file, 16 | current_class=0, 17 | coco=False, 18 | score_thresh=-1): 19 | dt_annos = kitti.get_label_annos(result_path) 20 | if score_thresh > 0: 21 | dt_annos = kitti.filter_annos_low_score(dt_annos, score_thresh) 22 | val_image_ids = _read_imageset_file(label_split_file) 23 | gt_annos = kitti.get_label_annos(label_path, val_image_ids) 24 | if coco: 25 | print(get_coco_eval_result(gt_annos, dt_annos, current_class)) 26 | else: 27 | print(get_official_eval_result(gt_annos, dt_annos, current_class)) 28 | 29 | 30 | if __name__ == '__main__': 31 | fire.Fire() 32 | -------------------------------------------------------------------------------- /kitti_common.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures as futures 2 | import os 3 | import pathlib 4 | import re 5 | from collections import OrderedDict 6 | 7 | import numpy as np 8 | from skimage import io 9 | 10 | def get_image_index_str(img_idx): 11 | return "{:06d}".format(img_idx) 12 | 13 | 14 | def get_kitti_info_path(idx, 15 | prefix, 16 | info_type='image_2', 17 | file_tail='.png', 18 | training=True, 19 | relative_path=True): 20 | img_idx_str = get_image_index_str(idx) 21 | img_idx_str += file_tail 22 | prefix = pathlib.Path(prefix) 23 | if training: 24 | file_path = pathlib.Path('training') / info_type / img_idx_str 25 | else: 26 | file_path = pathlib.Path('testing') / info_type / img_idx_str 27 | if not (prefix / file_path).exists(): 28 | raise ValueError("file not exist: {}".format(file_path)) 29 | if relative_path: 30 | return str(file_path) 31 | else: 32 | return str(prefix / file_path) 33 | 34 | 35 | def get_image_path(idx, prefix, training=True, relative_path=True): 36 | return get_kitti_info_path(idx, prefix, 'image_2', '.png', training, 37 | relative_path) 38 | 39 | 40 | def get_label_path(idx, prefix, training=True, relative_path=True): 41 | return get_kitti_info_path(idx, prefix, 'label_2', '.txt', training, 42 | relative_path) 43 | 44 | 45 | def get_velodyne_path(idx, prefix, training=True, relative_path=True): 46 | return get_kitti_info_path(idx, prefix, 'velodyne', '.bin', training, 47 | relative_path) 48 | 49 | 50 | def get_calib_path(idx, prefix, training=True, relative_path=True): 51 | return get_kitti_info_path(idx, prefix, 'calib', '.txt', training, 52 | relative_path) 53 | 54 | 55 | def _extend_matrix(mat): 56 | mat = np.concatenate([mat, np.array([[0., 0., 0., 1.]])], axis=0) 57 | return mat 58 | 59 | 60 | def get_kitti_image_info(path, 61 | training=True, 62 | label_info=True, 63 | velodyne=False, 64 | calib=False, 65 | image_ids=7481, 66 | extend_matrix=True, 67 | num_worker=8, 68 | relative_path=True, 69 | with_imageshape=True): 70 | # image_infos = [] 71 | root_path = pathlib.Path(path) 72 | if not isinstance(image_ids, list): 73 | image_ids = list(range(image_ids)) 74 | 75 | def map_func(idx): 76 | image_info = {'image_idx': idx} 77 | annotations = None 78 | if velodyne: 79 | image_info['velodyne_path'] = get_velodyne_path( 80 | idx, path, training, relative_path) 81 | image_info['img_path'] = get_image_path(idx, path, training, 82 | relative_path) 83 | if with_imageshape: 84 | img_path = image_info['img_path'] 85 | if relative_path: 86 | img_path = str(root_path / img_path) 87 | image_info['img_shape'] = np.array( 88 | io.imread(img_path).shape[:2], dtype=np.int32) 89 | if label_info: 90 | label_path = get_label_path(idx, path, training, relative_path) 91 | if relative_path: 92 | label_path = str(root_path / label_path) 93 | annotations = get_label_anno(label_path) 94 | if calib: 95 | calib_path = get_calib_path( 96 | idx, path, training, relative_path=False) 97 | with open(calib_path, 'r') as f: 98 | lines = f.readlines() 99 | P0 = np.array( 100 | [float(info) for info in lines[0].split(' ')[1:13]]).reshape( 101 | [3, 4]) 102 | P1 = np.array( 103 | [float(info) for info in lines[1].split(' ')[1:13]]).reshape( 104 | [3, 4]) 105 | P2 = np.array( 106 | [float(info) for info in lines[2].split(' ')[1:13]]).reshape( 107 | [3, 4]) 108 | P3 = np.array( 109 | [float(info) for info in lines[3].split(' ')[1:13]]).reshape( 110 | [3, 4]) 111 | if extend_matrix: 112 | P0 = _extend_matrix(P0) 113 | P1 = _extend_matrix(P1) 114 | P2 = _extend_matrix(P2) 115 | P3 = _extend_matrix(P3) 116 | image_info['calib/P0'] = P0 117 | image_info['calib/P1'] = P1 118 | image_info['calib/P2'] = P2 119 | image_info['calib/P3'] = P3 120 | R0_rect = np.array([ 121 | float(info) for info in lines[4].split(' ')[1:10] 122 | ]).reshape([3, 3]) 123 | if extend_matrix: 124 | rect_4x4 = np.zeros([4, 4], dtype=R0_rect.dtype) 125 | rect_4x4[3, 3] = 1. 126 | rect_4x4[:3, :3] = R0_rect 127 | else: 128 | rect_4x4 = R0_rect 129 | image_info['calib/R0_rect'] = rect_4x4 130 | Tr_velo_to_cam = np.array([ 131 | float(info) for info in lines[5].split(' ')[1:13] 132 | ]).reshape([3, 4]) 133 | Tr_imu_to_velo = np.array([ 134 | float(info) for info in lines[6].split(' ')[1:13] 135 | ]).reshape([3, 4]) 136 | if extend_matrix: 137 | Tr_velo_to_cam = _extend_matrix(Tr_velo_to_cam) 138 | Tr_imu_to_velo = _extend_matrix(Tr_imu_to_velo) 139 | image_info['calib/Tr_velo_to_cam'] = Tr_velo_to_cam 140 | image_info['calib/Tr_imu_to_velo'] = Tr_imu_to_velo 141 | if annotations is not None: 142 | image_info['annos'] = annotations 143 | add_difficulty_to_annos(image_info) 144 | return image_info 145 | 146 | with futures.ThreadPoolExecutor(num_worker) as executor: 147 | image_infos = executor.map(map_func, image_ids) 148 | return list(image_infos) 149 | 150 | 151 | def filter_kitti_anno(image_anno, 152 | used_classes, 153 | used_difficulty=None, 154 | dontcare_iou=None): 155 | if not isinstance(used_classes, (list, tuple)): 156 | used_classes = [used_classes] 157 | img_filtered_annotations = {} 158 | relevant_annotation_indices = [ 159 | i for i, x in enumerate(image_anno['name']) if x in used_classes 160 | ] 161 | for key in image_anno.keys(): 162 | img_filtered_annotations[key] = ( 163 | image_anno[key][relevant_annotation_indices]) 164 | if used_difficulty is not None: 165 | relevant_annotation_indices = [ 166 | i for i, x in enumerate(img_filtered_annotations['difficulty']) 167 | if x in used_difficulty 168 | ] 169 | for key in image_anno.keys(): 170 | img_filtered_annotations[key] = ( 171 | img_filtered_annotations[key][relevant_annotation_indices]) 172 | 173 | if 'DontCare' in used_classes and dontcare_iou is not None: 174 | dont_care_indices = [ 175 | i for i, x in enumerate(img_filtered_annotations['name']) 176 | if x == 'DontCare' 177 | ] 178 | # bounding box format [y_min, x_min, y_max, x_max] 179 | all_boxes = img_filtered_annotations['bbox'] 180 | ious = iou(all_boxes, all_boxes[dont_care_indices]) 181 | 182 | # Remove all bounding boxes that overlap with a dontcare region. 183 | if ious.size > 0: 184 | boxes_to_remove = np.amax(ious, axis=1) > dontcare_iou 185 | for key in image_anno.keys(): 186 | img_filtered_annotations[key] = (img_filtered_annotations[key][ 187 | np.logical_not(boxes_to_remove)]) 188 | return img_filtered_annotations 189 | 190 | def filter_annos_low_score(image_annos, thresh): 191 | new_image_annos = [] 192 | for anno in image_annos: 193 | img_filtered_annotations = {} 194 | relevant_annotation_indices = [ 195 | i for i, s in enumerate(anno['score']) if s >= thresh 196 | ] 197 | for key in anno.keys(): 198 | img_filtered_annotations[key] = ( 199 | anno[key][relevant_annotation_indices]) 200 | new_image_annos.append(img_filtered_annotations) 201 | return new_image_annos 202 | 203 | def kitti_result_line(result_dict, precision=4): 204 | prec_float = "{" + ":.{}f".format(precision) + "}" 205 | res_line = [] 206 | all_field_default = OrderedDict([ 207 | ('name', None), 208 | ('truncated', -1), 209 | ('occluded', -1), 210 | ('alpha', -10), 211 | ('bbox', None), 212 | ('dimensions', [-1, -1, -1]), 213 | ('location', [-1000, -1000, -1000]), 214 | ('rotation_y', -10), 215 | ('score', None), 216 | ]) 217 | res_dict = [(key, None) for key, val in all_field_default.items()] 218 | res_dict = OrderedDict(res_dict) 219 | for key, val in result_dict.items(): 220 | if all_field_default[key] is None and val is None: 221 | raise ValueError("you must specify a value for {}".format(key)) 222 | res_dict[key] = val 223 | 224 | for key, val in res_dict.items(): 225 | if key == 'name': 226 | res_line.append(val) 227 | elif key in ['truncated', 'alpha', 'rotation_y', 'score']: 228 | if val is None: 229 | res_line.append(str(all_field_default[key])) 230 | else: 231 | res_line.append(prec_float.format(val)) 232 | elif key == 'occluded': 233 | if val is None: 234 | res_line.append(str(all_field_default[key])) 235 | else: 236 | res_line.append('{}'.format(val)) 237 | elif key in ['bbox', 'dimensions', 'location']: 238 | if val is None: 239 | res_line += [str(v) for v in all_field_default[key]] 240 | else: 241 | res_line += [prec_float.format(v) for v in val] 242 | else: 243 | raise ValueError("unknown key. supported key:{}".format( 244 | res_dict.keys())) 245 | return ' '.join(res_line) 246 | 247 | 248 | def add_difficulty_to_annos(info): 249 | min_height = [40, 25, 250 | 25] # minimum height for evaluated groundtruth/detections 251 | max_occlusion = [ 252 | 0, 1, 2 253 | ] # maximum occlusion level of the groundtruth used for evaluation 254 | max_trunc = [ 255 | 0.15, 0.3, 0.5 256 | ] # maximum truncation level of the groundtruth used for evaluation 257 | annos = info['annos'] 258 | dims = annos['dimensions'] # lhw format 259 | bbox = annos['bbox'] 260 | height = bbox[:, 3] - bbox[:, 1] 261 | occlusion = annos['occluded'] 262 | truncation = annos['truncated'] 263 | diff = [] 264 | easy_mask = np.ones((len(dims), ), dtype=np.bool) 265 | moderate_mask = np.ones((len(dims), ), dtype=np.bool) 266 | hard_mask = np.ones((len(dims), ), dtype=np.bool) 267 | i = 0 268 | for h, o, t in zip(height, occlusion, truncation): 269 | if o > max_occlusion[0] or h <= min_height[0] or t > max_trunc[0]: 270 | easy_mask[i] = False 271 | if o > max_occlusion[1] or h <= min_height[1] or t > max_trunc[1]: 272 | moderate_mask[i] = False 273 | if o > max_occlusion[2] or h <= min_height[2] or t > max_trunc[2]: 274 | hard_mask[i] = False 275 | i += 1 276 | is_easy = easy_mask 277 | is_moderate = np.logical_xor(easy_mask, moderate_mask) 278 | is_hard = np.logical_xor(hard_mask, moderate_mask) 279 | 280 | for i in range(len(dims)): 281 | if is_easy[i]: 282 | diff.append(0) 283 | elif is_moderate[i]: 284 | diff.append(1) 285 | elif is_hard[i]: 286 | diff.append(2) 287 | else: 288 | diff.append(-1) 289 | annos["difficulty"] = np.array(diff, np.int32) 290 | return diff 291 | 292 | 293 | def get_label_anno(label_path): 294 | annotations = {} 295 | annotations.update({ 296 | 'name': [], 297 | 'truncated': [], 298 | 'occluded': [], 299 | 'alpha': [], 300 | 'bbox': [], 301 | 'dimensions': [], 302 | 'location': [], 303 | 'rotation_y': [] 304 | }) 305 | with open(label_path, 'r') as f: 306 | lines = f.readlines() 307 | # if len(lines) == 0 or len(lines[0]) < 15: 308 | # content = [] 309 | # else: 310 | content = [line.strip().split(' ') for line in lines] 311 | annotations['name'] = np.array([x[0] for x in content]) 312 | annotations['truncated'] = np.array([float(x[1]) for x in content]) 313 | annotations['occluded'] = np.array([int(x[2]) for x in content]) 314 | annotations['alpha'] = np.array([float(x[3]) for x in content]) 315 | annotations['bbox'] = np.array( 316 | [[float(info) for info in x[4:8]] for x in content]).reshape(-1, 4) 317 | # dimensions will convert hwl format to standard lhw(camera) format. 318 | annotations['dimensions'] = np.array( 319 | [[float(info) for info in x[8:11]] for x in content]).reshape( 320 | -1, 3)[:, [2, 0, 1]] 321 | annotations['location'] = np.array( 322 | [[float(info) for info in x[11:14]] for x in content]).reshape(-1, 3) 323 | annotations['rotation_y'] = np.array( 324 | [float(x[14]) for x in content]).reshape(-1) 325 | if len(content) != 0 and len(content[0]) == 16: # have score 326 | annotations['score'] = np.array([float(x[15]) for x in content]) 327 | else: 328 | annotations['score'] = np.zeros([len(annotations['bbox'])]) 329 | return annotations 330 | 331 | def get_label_annos(label_folder, image_ids=None): 332 | if image_ids is None: 333 | filepaths = pathlib.Path(label_folder).glob('*.txt') 334 | prog = re.compile(r'^\d{6}.txt$') 335 | filepaths = filter(lambda f: prog.match(f.name), filepaths) 336 | image_ids = [int(p.stem) for p in filepaths] 337 | image_ids = sorted(image_ids) 338 | if not isinstance(image_ids, list): 339 | image_ids = list(range(image_ids)) 340 | annos = [] 341 | label_folder = pathlib.Path(label_folder) 342 | for idx in image_ids: 343 | image_idx = get_image_index_str(idx) 344 | label_filename = label_folder / (image_idx + '.txt') 345 | annos.append(get_label_anno(label_filename)) 346 | return annos 347 | 348 | def area(boxes, add1=False): 349 | """Computes area of boxes. 350 | 351 | Args: 352 | boxes: Numpy array with shape [N, 4] holding N boxes 353 | 354 | Returns: 355 | a numpy array with shape [N*1] representing box areas 356 | """ 357 | if add1: 358 | return (boxes[:, 2] - boxes[:, 0] + 1.0) * ( 359 | boxes[:, 3] - boxes[:, 1] + 1.0) 360 | else: 361 | return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 362 | 363 | 364 | def intersection(boxes1, boxes2, add1=False): 365 | """Compute pairwise intersection areas between boxes. 366 | 367 | Args: 368 | boxes1: a numpy array with shape [N, 4] holding N boxes 369 | boxes2: a numpy array with shape [M, 4] holding M boxes 370 | 371 | Returns: 372 | a numpy array with shape [N*M] representing pairwise intersection area 373 | """ 374 | [y_min1, x_min1, y_max1, x_max1] = np.split(boxes1, 4, axis=1) 375 | [y_min2, x_min2, y_max2, x_max2] = np.split(boxes2, 4, axis=1) 376 | 377 | all_pairs_min_ymax = np.minimum(y_max1, np.transpose(y_max2)) 378 | all_pairs_max_ymin = np.maximum(y_min1, np.transpose(y_min2)) 379 | if add1: 380 | all_pairs_min_ymax += 1.0 381 | intersect_heights = np.maximum( 382 | np.zeros(all_pairs_max_ymin.shape), 383 | all_pairs_min_ymax - all_pairs_max_ymin) 384 | 385 | all_pairs_min_xmax = np.minimum(x_max1, np.transpose(x_max2)) 386 | all_pairs_max_xmin = np.maximum(x_min1, np.transpose(x_min2)) 387 | if add1: 388 | all_pairs_min_xmax += 1.0 389 | intersect_widths = np.maximum( 390 | np.zeros(all_pairs_max_xmin.shape), 391 | all_pairs_min_xmax - all_pairs_max_xmin) 392 | return intersect_heights * intersect_widths 393 | 394 | 395 | def iou(boxes1, boxes2, add1=False): 396 | """Computes pairwise intersection-over-union between box collections. 397 | 398 | Args: 399 | boxes1: a numpy array with shape [N, 4] holding N boxes. 400 | boxes2: a numpy array with shape [M, 4] holding N boxes. 401 | 402 | Returns: 403 | a numpy array with shape [N, M] representing pairwise iou scores. 404 | """ 405 | intersect = intersection(boxes1, boxes2, add1) 406 | area1 = area(boxes1, add1) 407 | area2 = area(boxes2, add1) 408 | union = np.expand_dims( 409 | area1, axis=1) + np.expand_dims( 410 | area2, axis=0) - intersect 411 | return intersect / union 412 | -------------------------------------------------------------------------------- /rotate_iou.py: -------------------------------------------------------------------------------- 1 | ##################### 2 | # Based on https://github.com/hongzhenwang/RRPN-revise 3 | # Licensed under The MIT License 4 | # Author: yanyan, scrin@foxmail.com 5 | ##################### 6 | import math 7 | 8 | import numba 9 | import numpy as np 10 | from numba import cuda 11 | 12 | @numba.jit(nopython=True) 13 | def div_up(m, n): 14 | return m // n + (m % n > 0) 15 | 16 | @cuda.jit('(float32[:], float32[:], float32[:])', device=True, inline=True) 17 | def trangle_area(a, b, c): 18 | return ((a[0] - c[0]) * (b[1] - c[1]) - (a[1] - c[1]) * 19 | (b[0] - c[0])) / 2.0 20 | 21 | 22 | @cuda.jit('(float32[:], int32)', device=True, inline=True) 23 | def area(int_pts, num_of_inter): 24 | area_val = 0.0 25 | for i in range(num_of_inter - 2): 26 | area_val += abs( 27 | trangle_area(int_pts[:2], int_pts[2 * i + 2:2 * i + 4], 28 | int_pts[2 * i + 4:2 * i + 6])) 29 | return area_val 30 | 31 | 32 | @cuda.jit('(float32[:], int32)', device=True, inline=True) 33 | def sort_vertex_in_convex_polygon(int_pts, num_of_inter): 34 | if num_of_inter > 0: 35 | center = cuda.local.array((2, ), dtype=numba.float32) 36 | center[:] = 0.0 37 | for i in range(num_of_inter): 38 | center[0] += int_pts[2 * i] 39 | center[1] += int_pts[2 * i + 1] 40 | center[0] /= num_of_inter 41 | center[1] /= num_of_inter 42 | v = cuda.local.array((2, ), dtype=numba.float32) 43 | vs = cuda.local.array((16, ), dtype=numba.float32) 44 | for i in range(num_of_inter): 45 | v[0] = int_pts[2 * i] - center[0] 46 | v[1] = int_pts[2 * i + 1] - center[1] 47 | d = math.sqrt(v[0] * v[0] + v[1] * v[1]) 48 | v[0] = v[0] / d 49 | v[1] = v[1] / d 50 | if v[1] < 0: 51 | v[0] = -2 - v[0] 52 | vs[i] = v[0] 53 | j = 0 54 | temp = 0 55 | for i in range(1, num_of_inter): 56 | if vs[i - 1] > vs[i]: 57 | temp = vs[i] 58 | tx = int_pts[2 * i] 59 | ty = int_pts[2 * i + 1] 60 | j = i 61 | while j > 0 and vs[j - 1] > temp: 62 | vs[j] = vs[j - 1] 63 | int_pts[j * 2] = int_pts[j * 2 - 2] 64 | int_pts[j * 2 + 1] = int_pts[j * 2 - 1] 65 | j -= 1 66 | 67 | vs[j] = temp 68 | int_pts[j * 2] = tx 69 | int_pts[j * 2 + 1] = ty 70 | 71 | 72 | @cuda.jit( 73 | '(float32[:], float32[:], int32, int32, float32[:])', 74 | device=True, 75 | inline=True) 76 | def line_segment_intersection(pts1, pts2, i, j, temp_pts): 77 | A = cuda.local.array((2, ), dtype=numba.float32) 78 | B = cuda.local.array((2, ), dtype=numba.float32) 79 | C = cuda.local.array((2, ), dtype=numba.float32) 80 | D = cuda.local.array((2, ), dtype=numba.float32) 81 | 82 | A[0] = pts1[2 * i] 83 | A[1] = pts1[2 * i + 1] 84 | 85 | B[0] = pts1[2 * ((i + 1) % 4)] 86 | B[1] = pts1[2 * ((i + 1) % 4) + 1] 87 | 88 | C[0] = pts2[2 * j] 89 | C[1] = pts2[2 * j + 1] 90 | 91 | D[0] = pts2[2 * ((j + 1) % 4)] 92 | D[1] = pts2[2 * ((j + 1) % 4) + 1] 93 | BA0 = B[0] - A[0] 94 | BA1 = B[1] - A[1] 95 | DA0 = D[0] - A[0] 96 | CA0 = C[0] - A[0] 97 | DA1 = D[1] - A[1] 98 | CA1 = C[1] - A[1] 99 | acd = DA1 * CA0 > CA1 * DA0 100 | bcd = (D[1] - B[1]) * (C[0] - B[0]) > (C[1] - B[1]) * (D[0] - B[0]) 101 | if acd != bcd: 102 | abc = CA1 * BA0 > BA1 * CA0 103 | abd = DA1 * BA0 > BA1 * DA0 104 | if abc != abd: 105 | DC0 = D[0] - C[0] 106 | DC1 = D[1] - C[1] 107 | ABBA = A[0] * B[1] - B[0] * A[1] 108 | CDDC = C[0] * D[1] - D[0] * C[1] 109 | DH = BA1 * DC0 - BA0 * DC1 110 | Dx = ABBA * DC0 - BA0 * CDDC 111 | Dy = ABBA * DC1 - BA1 * CDDC 112 | temp_pts[0] = Dx / DH 113 | temp_pts[1] = Dy / DH 114 | return True 115 | return False 116 | 117 | 118 | @cuda.jit( 119 | '(float32[:], float32[:], int32, int32, float32[:])', 120 | device=True, 121 | inline=True) 122 | def line_segment_intersection_v1(pts1, pts2, i, j, temp_pts): 123 | a = cuda.local.array((2, ), dtype=numba.float32) 124 | b = cuda.local.array((2, ), dtype=numba.float32) 125 | c = cuda.local.array((2, ), dtype=numba.float32) 126 | d = cuda.local.array((2, ), dtype=numba.float32) 127 | 128 | a[0] = pts1[2 * i] 129 | a[1] = pts1[2 * i + 1] 130 | 131 | b[0] = pts1[2 * ((i + 1) % 4)] 132 | b[1] = pts1[2 * ((i + 1) % 4) + 1] 133 | 134 | c[0] = pts2[2 * j] 135 | c[1] = pts2[2 * j + 1] 136 | 137 | d[0] = pts2[2 * ((j + 1) % 4)] 138 | d[1] = pts2[2 * ((j + 1) % 4) + 1] 139 | 140 | area_abc = trangle_area(a, b, c) 141 | area_abd = trangle_area(a, b, d) 142 | 143 | if area_abc * area_abd >= 0: 144 | return False 145 | 146 | area_cda = trangle_area(c, d, a) 147 | area_cdb = area_cda + area_abc - area_abd 148 | 149 | if area_cda * area_cdb >= 0: 150 | return False 151 | t = area_cda / (area_abd - area_abc) 152 | 153 | dx = t * (b[0] - a[0]) 154 | dy = t * (b[1] - a[1]) 155 | temp_pts[0] = a[0] + dx 156 | temp_pts[1] = a[1] + dy 157 | return True 158 | 159 | 160 | @cuda.jit('(float32, float32, float32[:])', device=True, inline=True) 161 | def point_in_quadrilateral(pt_x, pt_y, corners): 162 | ab0 = corners[2] - corners[0] 163 | ab1 = corners[3] - corners[1] 164 | 165 | ad0 = corners[6] - corners[0] 166 | ad1 = corners[7] - corners[1] 167 | 168 | ap0 = pt_x - corners[0] 169 | ap1 = pt_y - corners[1] 170 | 171 | abab = ab0 * ab0 + ab1 * ab1 172 | abap = ab0 * ap0 + ab1 * ap1 173 | adad = ad0 * ad0 + ad1 * ad1 174 | adap = ad0 * ap0 + ad1 * ap1 175 | 176 | return abab >= abap and abap >= 0 and adad >= adap and adap >= 0 177 | 178 | 179 | @cuda.jit('(float32[:], float32[:], float32[:])', device=True, inline=True) 180 | def quadrilateral_intersection(pts1, pts2, int_pts): 181 | num_of_inter = 0 182 | for i in range(4): 183 | if point_in_quadrilateral(pts1[2 * i], pts1[2 * i + 1], pts2): 184 | int_pts[num_of_inter * 2] = pts1[2 * i] 185 | int_pts[num_of_inter * 2 + 1] = pts1[2 * i + 1] 186 | num_of_inter += 1 187 | if point_in_quadrilateral(pts2[2 * i], pts2[2 * i + 1], pts1): 188 | int_pts[num_of_inter * 2] = pts2[2 * i] 189 | int_pts[num_of_inter * 2 + 1] = pts2[2 * i + 1] 190 | num_of_inter += 1 191 | temp_pts = cuda.local.array((2, ), dtype=numba.float32) 192 | for i in range(4): 193 | for j in range(4): 194 | has_pts = line_segment_intersection(pts1, pts2, i, j, temp_pts) 195 | if has_pts: 196 | int_pts[num_of_inter * 2] = temp_pts[0] 197 | int_pts[num_of_inter * 2 + 1] = temp_pts[1] 198 | num_of_inter += 1 199 | 200 | return num_of_inter 201 | 202 | 203 | @cuda.jit('(float32[:], float32[:])', device=True, inline=True) 204 | def rbbox_to_corners(corners, rbbox): 205 | # generate clockwise corners and rotate it clockwise 206 | angle = rbbox[4] 207 | a_cos = math.cos(angle) 208 | a_sin = math.sin(angle) 209 | center_x = rbbox[0] 210 | center_y = rbbox[1] 211 | x_d = rbbox[2] 212 | y_d = rbbox[3] 213 | corners_x = cuda.local.array((4, ), dtype=numba.float32) 214 | corners_y = cuda.local.array((4, ), dtype=numba.float32) 215 | corners_x[0] = -x_d / 2 216 | corners_x[1] = -x_d / 2 217 | corners_x[2] = x_d / 2 218 | corners_x[3] = x_d / 2 219 | corners_y[0] = -y_d / 2 220 | corners_y[1] = y_d / 2 221 | corners_y[2] = y_d / 2 222 | corners_y[3] = -y_d / 2 223 | for i in range(4): 224 | corners[2 * 225 | i] = a_cos * corners_x[i] + a_sin * corners_y[i] + center_x 226 | corners[2 * i 227 | + 1] = -a_sin * corners_x[i] + a_cos * corners_y[i] + center_y 228 | 229 | 230 | @cuda.jit('(float32[:], float32[:])', device=True, inline=True) 231 | def inter(rbbox1, rbbox2): 232 | corners1 = cuda.local.array((8, ), dtype=numba.float32) 233 | corners2 = cuda.local.array((8, ), dtype=numba.float32) 234 | intersection_corners = cuda.local.array((16, ), dtype=numba.float32) 235 | 236 | rbbox_to_corners(corners1, rbbox1) 237 | rbbox_to_corners(corners2, rbbox2) 238 | 239 | num_intersection = quadrilateral_intersection(corners1, corners2, 240 | intersection_corners) 241 | sort_vertex_in_convex_polygon(intersection_corners, num_intersection) 242 | # print(intersection_corners.reshape([-1, 2])[:num_intersection]) 243 | 244 | return area(intersection_corners, num_intersection) 245 | 246 | 247 | @cuda.jit('(float32[:], float32[:], int32)', device=True, inline=True) 248 | def devRotateIoUEval(rbox1, rbox2, criterion=-1): 249 | area1 = rbox1[2] * rbox1[3] 250 | area2 = rbox2[2] * rbox2[3] 251 | area_inter = inter(rbox1, rbox2) 252 | if criterion == -1: 253 | return area_inter / (area1 + area2 - area_inter) 254 | elif criterion == 0: 255 | return area_inter / area1 256 | elif criterion == 1: 257 | return area_inter / area2 258 | else: 259 | return area_inter 260 | 261 | @cuda.jit('(int64, int64, float32[:], float32[:], float32[:], int32)', fastmath=False) 262 | def rotate_iou_kernel_eval(N, K, dev_boxes, dev_query_boxes, dev_iou, criterion=-1): 263 | threadsPerBlock = 8 * 8 264 | row_start = cuda.blockIdx.x 265 | col_start = cuda.blockIdx.y 266 | tx = cuda.threadIdx.x 267 | row_size = min(N - row_start * threadsPerBlock, threadsPerBlock) 268 | col_size = min(K - col_start * threadsPerBlock, threadsPerBlock) 269 | block_boxes = cuda.shared.array(shape=(64 * 5, ), dtype=numba.float32) 270 | block_qboxes = cuda.shared.array(shape=(64 * 5, ), dtype=numba.float32) 271 | 272 | dev_query_box_idx = threadsPerBlock * col_start + tx 273 | dev_box_idx = threadsPerBlock * row_start + tx 274 | if (tx < col_size): 275 | block_qboxes[tx * 5 + 0] = dev_query_boxes[dev_query_box_idx * 5 + 0] 276 | block_qboxes[tx * 5 + 1] = dev_query_boxes[dev_query_box_idx * 5 + 1] 277 | block_qboxes[tx * 5 + 2] = dev_query_boxes[dev_query_box_idx * 5 + 2] 278 | block_qboxes[tx * 5 + 3] = dev_query_boxes[dev_query_box_idx * 5 + 3] 279 | block_qboxes[tx * 5 + 4] = dev_query_boxes[dev_query_box_idx * 5 + 4] 280 | if (tx < row_size): 281 | block_boxes[tx * 5 + 0] = dev_boxes[dev_box_idx * 5 + 0] 282 | block_boxes[tx * 5 + 1] = dev_boxes[dev_box_idx * 5 + 1] 283 | block_boxes[tx * 5 + 2] = dev_boxes[dev_box_idx * 5 + 2] 284 | block_boxes[tx * 5 + 3] = dev_boxes[dev_box_idx * 5 + 3] 285 | block_boxes[tx * 5 + 4] = dev_boxes[dev_box_idx * 5 + 4] 286 | cuda.syncthreads() 287 | if tx < row_size: 288 | for i in range(col_size): 289 | offset = row_start * threadsPerBlock * K + col_start * threadsPerBlock + tx * K + i 290 | dev_iou[offset] = devRotateIoUEval(block_qboxes[i * 5:i * 5 + 5], 291 | block_boxes[tx * 5:tx * 5 + 5], criterion) 292 | 293 | 294 | def rotate_iou_gpu_eval(boxes, query_boxes, criterion=-1, device_id=0): 295 | """rotated box iou running in gpu. 500x faster than cpu version 296 | (take 5ms in one example with numba.cuda code). 297 | convert from [this project]( 298 | https://github.com/hongzhenwang/RRPN-revise/tree/master/lib/rotation). 299 | 300 | Args: 301 | boxes (float tensor: [N, 5]): rbboxes. format: centers, dims, 302 | angles(clockwise when positive) 303 | query_boxes (float tensor: [K, 5]): [description] 304 | device_id (int, optional): Defaults to 0. [description] 305 | 306 | Returns: 307 | [type]: [description] 308 | """ 309 | box_dtype = boxes.dtype 310 | boxes = boxes.astype(np.float32) 311 | query_boxes = query_boxes.astype(np.float32) 312 | N = boxes.shape[0] 313 | K = query_boxes.shape[0] 314 | iou = np.zeros((N, K), dtype=np.float32) 315 | if N == 0 or K == 0: 316 | return iou 317 | threadsPerBlock = 8 * 8 318 | cuda.select_device(device_id) 319 | blockspergrid = (div_up(N, threadsPerBlock), div_up(K, threadsPerBlock)) 320 | 321 | stream = cuda.stream() 322 | with stream.auto_synchronize(): 323 | boxes_dev = cuda.to_device(boxes.reshape([-1]), stream) 324 | query_boxes_dev = cuda.to_device(query_boxes.reshape([-1]), stream) 325 | iou_dev = cuda.to_device(iou.reshape([-1]), stream) 326 | rotate_iou_kernel_eval[blockspergrid, threadsPerBlock, stream]( 327 | N, K, boxes_dev, query_boxes_dev, iou_dev, criterion) 328 | iou_dev.copy_to_host(iou.reshape([-1]), stream=stream) 329 | return iou.astype(boxes.dtype) 330 | --------------------------------------------------------------------------------