├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── TrackEval ├── .gitignore ├── LICENSE ├── Readme.md ├── docs │ ├── How_To │ │ └── Add_a_new_metric.md │ ├── MOTChallenge-Official │ │ └── Readme.md │ ├── OpenWorldTracking-Official │ │ └── Readme.md │ └── RobMOTS-Official │ │ └── Readme.md ├── pyproject.toml ├── scripts │ ├── comparison_plots.py │ ├── run_bdd.py │ ├── run_burst.py │ ├── run_burst_ow.py │ ├── run_davis.py │ ├── run_headtracking_challenge.py │ ├── run_kitti.py │ ├── run_kitti_mots.py │ ├── run_mot_challenge.py │ ├── run_mots_challenge.py │ ├── run_person_path_22.py │ ├── run_rob_mots.py │ ├── run_tao.py │ ├── run_tao_ow.py │ └── run_youtube_vis.py ├── setup.cfg ├── setup.py ├── tests │ ├── test_all_quick.py │ ├── test_davis.py │ ├── test_metrics.py │ ├── test_mot17.py │ └── test_mots.py └── trackeval │ ├── __init__.py │ ├── _timing.py │ ├── baselines │ ├── __init__.py │ ├── baseline_utils.py │ ├── non_overlap.py │ ├── pascal_colormap.py │ ├── stp.py │ ├── thresholder.py │ └── vizualize.py │ ├── datasets │ ├── __init__.py │ ├── _base_dataset.py │ ├── bdd100k.py │ ├── burst.py │ ├── burst_helpers │ │ ├── BURST_SPECIFIC_ISSUES.md │ │ ├── __init__.py │ │ ├── burst_base.py │ │ ├── burst_ow_base.py │ │ ├── convert_burst_format_to_tao_format.py │ │ ├── format_converter.py │ │ └── tao_categories.json │ ├── burst_ow.py │ ├── davis.py │ ├── head_tracking_challenge.py │ ├── kitti_2d_box.py │ ├── kitti_mots.py │ ├── mot_challenge_2d_box.py │ ├── mots_challenge.py │ ├── person_path_22.py │ ├── rob_mots.py │ ├── rob_mots_classmap.py │ ├── run_rob_mots.py │ ├── tao.py │ ├── tao_ow.py │ └── youtube_vis.py │ ├── eval.py │ ├── metrics │ ├── __init__.py │ ├── _base_metric.py │ ├── clear.py │ ├── count.py │ ├── hota.py │ ├── identity.py │ ├── ideucl.py │ ├── j_and_f.py │ ├── track_map.py │ └── vace.py │ ├── plotting.py │ └── utils.py ├── configs ├── nusc.ini ├── virconv │ └── default.ini ├── voxel │ └── default.ini └── voxel_tta │ └── default.ini ├── data_processing ├── check_det_num.py ├── crop_det_images.py ├── crop_points.py ├── crop_seg_images.py ├── dataset_tracking2object.py ├── masks2boxes.py ├── object2tracking.py ├── save_img_shapes.py └── tracking2object.py ├── datasets ├── kitti_dataset.py └── voxelization.py ├── delete_far_objects.py ├── detection ├── kitti_object_eval_python │ ├── __init__.py │ ├── eval.py │ └── rotate_iou.py ├── rtmdet-ins_x_8xb16-300e_coco.py └── voxel_rcnn │ ├── __init__.py │ ├── anchor_generator.py │ ├── anchor_head_single.py │ ├── base_bev_backbone.py │ ├── box_coder.py │ ├── height_compression.py │ ├── iou3d_nms │ ├── __init__.py │ ├── iou3d_nms_utils.py │ └── src │ │ ├── iou3d_nms.cpp │ │ ├── iou3d_nms.h │ │ ├── iou3d_nms_api.cpp │ │ └── iou3d_nms_kernel.cu │ ├── mean_vfe.py │ ├── pointnet2_stack │ ├── __init__.py │ ├── pointnet2_modules.py │ ├── pointnet2_utils.py │ ├── src │ │ ├── ball_query.cpp │ │ ├── ball_query_gpu.cu │ │ ├── ball_query_gpu.h │ │ ├── cuda_utils.h │ │ ├── group_points.cpp │ │ ├── group_points_gpu.cu │ │ ├── group_points_gpu.h │ │ ├── interpolate.cpp │ │ ├── interpolate_gpu.cu │ │ ├── interpolate_gpu.h │ │ ├── pointnet2_api.cpp │ │ ├── sampling.cpp │ │ ├── sampling_gpu.cu │ │ ├── sampling_gpu.h │ │ ├── vector_pool.cpp │ │ ├── vector_pool_gpu.cu │ │ ├── vector_pool_gpu.h │ │ ├── voxel_query.cpp │ │ ├── voxel_query_gpu.cu │ │ └── voxel_query_gpu.h │ ├── voxel_pool_modules.py │ └── voxel_query_utils.py │ ├── spconv_backbone.py │ ├── utils.py │ ├── voxel_rcnn.py │ └── voxel_rcnn_head.py ├── eval_kitti_detection.py ├── 
eval_nusc_detection.py ├── eval_nusc_tracking.py ├── kitti_2d_3d_det_fusion.py ├── kitti_2d_mots.py ├── kitti_3d_detection.py ├── kitti_3d_tracking.py ├── kitti_trajectory_refinement.py ├── mmdet_kitti_inference.py ├── mots_tools ├── LICENSE ├── README.md ├── eval.py ├── mots_common │ ├── images_to_txt.py │ └── io.py ├── mots_eval │ └── MOTS_metrics.py ├── mots_vis │ └── visualize_mots.py ├── test.seqmap ├── training.seqmap └── val.seqmap ├── nusc_3d_tracking.py ├── nusc_filter_code └── main.py ├── nusc_trajectory_refinement.py ├── scripts ├── eval_kitti_mots.ps1 ├── eval_kitti_tracking.ps1 ├── run_and_eval_kitti_backward.ps1 ├── run_and_eval_kitti_forward.ps1 └── run_and_eval_kitti_merge.ps1 ├── segmentation ├── point_track │ └── point_track.py └── spatial_embeddings │ └── spatial_embeddings.py ├── setup.py ├── tracking ├── __init__.py ├── association.py ├── detections │ ├── __init__.py │ ├── detections.py │ └── kitti_detections.py ├── motion_filters.py ├── mots_util.py ├── track.py ├── tracker.py ├── trajectory_clustering_split_and_recombination.py ├── trajectory_completion.py └── trajectory_refinement.py ├── utils.py └── visualization ├── visualize_embeddings.py ├── visualize_fn_fp.py ├── visualize_frame_bev.py ├── visualize_instances.py └── visualize_trajectories.py /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | output 3 | *.pth 4 | *.txt 5 | *.pyc 6 | *.obj 7 | build 8 | *.pyd 9 | *.egg-info 10 | __pycache__ 11 | *.pkl -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.watcherExclude": { 3 | "**/data": true, 4 | "**/output": true, 5 | "**/*.egg-info": true, 6 | "**/build": true, 7 | "**/__pycache__": true, 8 | // "**/TrackEval": true, 9 | }, 10 | "search.exclude": { 11 | "**/data": true, 12 | "**/output": true, 13 | "**/*.egg-info": true, 14 | "**/build": true, 15 | "**/__pycache__": true, 16 | // "**/TrackEval": true, 17 | }, 18 | "files.exclude": { 19 | "**/data": true, 20 | "**/output": true, 21 | "**/*.egg-info": true, 22 | "**/build": true, 23 | "**/__pycache__": true, 24 | // "**/TrackEval": true, 25 | } 26 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Kemiao Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BiTrack 2 | 3 | [BiTrack: Bidirectional Offline 3D Multi-Object Tracking Using Camera-LiDAR Data](https://arxiv.org/abs/2406.18414) 4 | 5 | ## Get Started 6 | 7 | Put the KITTI tracking data in the "data/kitti/tracking" directory (symbolic links are recommended). 8 | 9 | - If you want to generate 2D-3D detection results using VoxelRCNN and SpatialEmbedding, follow the [KITTI detection](#KITTI-Detection) instructions. 10 | 11 | - If you have already generated detection files with other models, skip to the [2D-3D fusion](#KITTI-2D-3D-fusion) instructions. Put the 3D detection results in the "data/kitti/tracking/$split/det3d_out/$det3d_name" directory and the 2D segmentation results in the "data/kitti/tracking/$split/seg_out/$seg_name" directory. You can also use other file paths, but the configuration file must be changed accordingly. 12 | 13 | ### KITTI Detection 14 | 15 | #### KITTI 3D Object Detection 16 | 17 | 1. Build the CUDA operators for 3D-IoU and PointNet++. 18 | 19 | ```shell 20 | python setup.py develop 21 | ``` 22 | 23 | 2. Download the converted model weight file from [Google Drive](https://drive.google.com/drive/folders/1OBJPBAAJPf3pEXHRlNywmAERPnHXt7tt?usp=sharing) and put it at "detection/voxel_rcnn/voxel_rcnn.pth". 24 | 25 | 3. Run network inference with a specific configuration file under the "configs" directory. 26 | 27 | ```shell 28 | python kitti_3d_detection.py $config_path $split --inference 29 | ``` 30 | 31 | 4. (Optional) Average precision evaluation for cars: (1) convert the tracking labels to the detection format, (2) convert the tracking-format results to the detection format, and (3) evaluate the converted results against the converted labels. 32 | 33 | ```shell 34 | python data_processing/dataset_tracking2object.py 35 | python data_processing/tracking2object.py $result_src $result_dst 36 | python eval_kitti_detection.py ./data/kitti/detection/training/label_2 $result_dst 37 | ``` 38 | 39 | #### KITTI 2D Instance Segmentation 40 | 41 | 1. Download the converted model weight file from [Google Drive](https://drive.google.com/drive/folders/1OBJPBAAJPf3pEXHRlNywmAERPnHXt7tt?usp=sharing) and put it at "segmentation/spatial_embeddings/spatial_embeddings.pth". 42 | 43 | 2. Run network inference with a specific configuration file under the "configs" directory. 44 | 45 | ```shell 46 | python kitti_2d_mots.py $config_path $split 47 | ``` 48 | 49 | #### KITTI 2D-3D Fusion 50 | 51 | 1. Crop the LiDAR points that lie inside the 3D bounding boxes. 52 | 53 | ```shell 54 | python data_processing/crop_points.py $config_path $split 55 | ``` 56 | 57 | 2. Save the image shapes to JSON (not all KITTI images have the same shape). 58 | 59 | ```shell 60 | python data_processing/save_img_shapes.py $config_path $split 61 | ``` 62 | 63 | 3. Run the detection fusion script. 64 | 65 | ```shell 66 | python kitti_2d_3d_det_fusion.py $config_path $split 67 | ``` 68 | 69 | ### KITTI Tracking 70 | 71 | 1. Forward tracking. 72 | 73 | ```shell 74 | python kitti_3d_tracking.py $config_path $forward_tag $split 75 | ``` 76 | 77 | 2. Backward tracking.
78 | 79 | ```shell 80 | python kitti_3d_tracking.py $config_path $backward_tag $split --backward 81 | ``` 82 | 83 | 3. Trajectory fusion and refinement. 84 | 85 | ```shell 86 | python kitti_trajectory_refinement.py $config_path $final_tag $split $forward_tag $backward_tag 87 | ``` 88 | 89 | 4. Evaluation for cars. 90 | 91 | ```shell 92 | python TrackEval/scripts/run_kitti.py --TIME_PROGRESS False --PRINT_CONFIG False --GT_FOLDER data/kitti/tracking/training --TRACKERS_FOLDER output/kitti/$split --CLASSES_TO_EVAL car --TRACKERS_TO_EVAL $tag --SPLIT_TO_EVAL $split 93 | ``` 94 | -------------------------------------------------------------------------------- /TrackEval/.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | **.pyc 4 | **__pycache__ 5 | gt_data/* 6 | !gt_data/Readme.md 7 | tracker_output/* 8 | !tracker_output/Readme.md 9 | output/* 10 | data 11 | !output/Readme.md 12 | **/__pycache__ 13 | .idea 14 | error_log.txt 15 | -------------------------------------------------------------------------------- /TrackEval/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jonathon Luiten 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /TrackEval/docs/How_To/Add_a_new_metric.md: -------------------------------------------------------------------------------- 1 | # How to add a new or custom family of evaluation metrics to TrackEval 2 | 3 | - Create your metric code in ```trackeval/metrics/<your_metric>.py```. 4 | - It's probably easiest to start by copying the code of an existing metric and editing it, e.g. ```trackeval/metrics/identity.py``` is probably the simplest. 5 | - Your metric should be a class, and it should inherit from the ```trackeval.metrics._base_metric._BaseMetric``` class. 6 | - Define an ```__init__``` function that defines the different ```fields``` (values) that your metric will calculate. See ```trackeval/metrics/_base_metric.py``` for a list of currently used field types. Feel free to add new types. 7 | - Define the code that actually calculates your metric for a single sequence and a single class in a function called ```eval_sequence```, which takes a data dictionary as input and returns a results dictionary as output (a minimal skeleton is sketched below).
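As a concrete illustration of the structure described in this list, here is a minimal sketch of a deliberately trivial metric. It is not part of TrackEval: the field names, the data-dictionary keys (`num_gt_dets`, `num_tracker_dets`), and the ```combine_*``` signatures are assumptions modelled on the bundled ```identity.py```, so check them against ```_base_metric.py``` in your version.

```python
import numpy as np

from trackeval.metrics._base_metric import _BaseMetric


class ToyCount(_BaseMetric):
    """Illustrative metric: compares detection counts (not a real tracking metric)."""

    def __init__(self):
        super().__init__()
        # Field lists are assumptions modelled on identity.py; check
        # _base_metric.py for the field types your TrackEval version uses.
        self.integer_fields = ['Toy_GT_Dets', 'Toy_Tracker_Dets']
        self.float_fields = ['Toy_Det_Ratio']
        self.fields = self.integer_fields + self.float_fields
        self.summary_fields = self.fields

    def eval_sequence(self, data):
        # Single sequence, single class: read the counts from the data dict.
        res = {'Toy_GT_Dets': data['num_gt_dets'],
               'Toy_Tracker_Dets': data['num_tracker_dets']}
        return self._compute_final_fields(res)

    def combine_sequences(self, all_res):
        res = {f: sum(r[f] for r in all_res.values()) for f in self.integer_fields}
        return self._compute_final_fields(res)

    def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False):
        res = {f: int(np.mean([r[f] for r in all_res.values()])) for f in self.integer_fields}
        return self._compute_final_fields(res)

    def combine_classes_det_averaged(self, all_res):
        return self.combine_sequences(all_res)

    @staticmethod
    def _compute_final_fields(res):
        # Derived fields are recomputed after every combination step.
        res['Toy_Det_Ratio'] = res['Toy_Tracker_Dets'] / max(1, res['Toy_GT_Dets'])
        return res
```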
8 | - Define functions for how to combine your metric field values a) over sequences (```combine_sequences```), b) over classes (```combine_classes_class_averaged```), and c) over classes weighted by the number of detections (```combine_classes_det_averaged```). 9 | - We find a helper such as the ```_compute_final_fields``` function used in the current metrics convenient, because it is typically needed both for the per-sequence calculation and for the different metric combinations; however, this is not required. 10 | - Register your new metric by adding it to ```trackeval/metrics/__init__.py```. 11 | - Your new metric can be used by including the metric class in the list of metrics that is passed to the evaluator (see the files in ```scripts/*```). 12 | -------------------------------------------------------------------------------- /TrackEval/docs/OpenWorldTracking-Official/Readme.md: -------------------------------------------------------------------------------- 1 | ![owt](https://user-images.githubusercontent.com/23000532/160293694-6fc0a3da-c177-4776-8472-49ff6ff375a3.jpg) 2 | # Opening Up Open-World Tracking - Official Evaluation Code 3 | 4 | TrackEval now contains the official evaluation code for the task of **Open World Tracking**. 5 | 6 | This is the official code from the following paper: 7 | 8 |
Opening up Open-World Tracking
 9 | Yang Liu*, Idil Esen Zulfikar*, Jonathon Luiten*, Achal Dave*, Deva Ramanan, Bastian Leibe, Aljoša Ošep, Laura Leal-Taixé
10 | *Equal contribution
11 | CVPR 2022
12 | 13 | [Paper](https://arxiv.org/abs/2104.11221) 14 | 15 | [Website](https://openworldtracking.github.io) 16 | 17 | ## Running and understanding the code 18 | 19 | The code can be run with the following script (see the script for arguments and how to run it): 20 | [TAO-OW run script](https://github.com/JonathonLuiten/TrackEval/blob/master/scripts/run_tao_ow.py) 21 | 22 | To understand how the data is read and used, see the TAO-OW dataset class: 23 | [TAO-OW dataset class](https://github.com/JonathonLuiten/TrackEval/blob/master/trackeval/datasets/tao_ow.py) 24 | 25 | The implementation of the 'Open World Tracking Accuracy' (OWTA) metric proposed in the paper can be found here: 26 | [OWTA metric](https://github.com/JonathonLuiten/TrackEval/blob/master/trackeval/metrics/hota.py) 27 | 28 | ## Citation 29 | If you work with the code and the benchmark, please cite: 30 | 31 | ***Opening Up Open-World Tracking*** 32 | ``` 33 | @inproceedings{liu2022opening, 34 | title={Opening up Open-World Tracking}, 35 | author={Liu, Yang and Zulfikar, Idil Esen and Luiten, Jonathon and Dave, Achal and Ramanan, Deva and Leibe, Bastian and O{\v{s}}ep, Aljo{\v{s}}a and Leal-Taix{\'e}, Laura}, 36 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 37 | year={2022} 38 | } 39 | ``` 40 | 41 | ***TrackEval*** 42 | ``` 43 | @misc{luiten2020trackeval, 44 | author = {Jonathon Luiten and Arne Hoffhues}, 45 | title = {TrackEval}, 46 | howpublished = {\url{https://github.com/JonathonLuiten/TrackEval}}, 47 | year = {2020} 48 | } 49 | ``` 50 | -------------------------------------------------------------------------------- /TrackEval/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /TrackEval/scripts/comparison_plots.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 5 | import trackeval # noqa: E402 6 | 7 | plots_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'data', 'plots')) 8 | tracker_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'data', 'trackers')) 9 | 10 | # dataset = os.path.join('kitti', 'kitti_2d_box_train') 11 | # classes = ['cars', 'pedestrian'] 12 | 13 | dataset = os.path.join('mot_challenge', 'MOT17-train') 14 | classes = ['pedestrian'] 15 | 16 | data_fol = os.path.join(tracker_folder, dataset) 17 | trackers = os.listdir(data_fol) 18 | out_loc = os.path.join(plots_folder, dataset) 19 | for cls in classes: 20 | trackeval.plotting.plot_compare_trackers(data_fol, trackers, cls, out_loc) 21 | -------------------------------------------------------------------------------- /TrackEval/scripts/run_bdd.py: -------------------------------------------------------------------------------- 1 | 2 | """ run_bdd.py 3 | 4 | Run example: 5 | run_bdd.py --USE_PARALLEL False --METRICS Hota --TRACKERS_TO_EVAL qdtrack 6 | 7 | Command Line Arguments: Defaults, # Comments 8 | Eval arguments: 9 | 'USE_PARALLEL': False, 10 | 'NUM_PARALLEL_CORES': 8, 11 | 'BREAK_ON_ERROR': True, 12 | 'PRINT_RESULTS': True, 13 | 'PRINT_ONLY_COMBINED': False, 14 | 'PRINT_CONFIG': True, 15 | 'TIME_PROGRESS': True, 16 | 'OUTPUT_SUMMARY': True, 17 |
'OUTPUT_DETAILED': True, 18 | 'PLOT_CURVES': True, 19 | Dataset arguments: 20 | 'GT_FOLDER': os.path.join(code_path, 'data/gt/bdd100k/bdd100k_val'), # Location of GT data 21 | 'TRACKERS_FOLDER': os.path.join(code_path, 'data/trackers/bdd100k/bdd100k_val'), # Trackers location 22 | 'OUTPUT_FOLDER': None, # Where to save eval results (if None, same as TRACKERS_FOLDER) 23 | 'TRACKERS_TO_EVAL': None, # Filenames of trackers to eval (if None, all in folder) 24 | 'CLASSES_TO_EVAL': ['pedestrian', 'rider', 'car', 'bus', 'truck', 'train', 'motorcycle', 'bicycle'], 25 | # Valid: ['pedestrian', 'rider', 'car', 'bus', 'truck', 'train', 'motorcycle', 'bicycle'] 26 | 'SPLIT_TO_EVAL': 'val', # Valid: 'training', 'val', 27 | 'INPUT_AS_ZIP': False, # Whether tracker input files are zipped 28 | 'PRINT_CONFIG': True, # Whether to print current config 29 | 'TRACKER_SUB_FOLDER': 'data', # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER 30 | 'OUTPUT_SUB_FOLDER': '', # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER 31 | 'TRACKER_DISPLAY_NAMES': None, # Names of trackers to display, if None: TRACKERS_TO_EVAL 32 | Metric arguments: 33 | 'METRICS': ['Hota','Clear', 'ID', 'Count'] 34 | """ 35 | 36 | import sys 37 | import os 38 | import argparse 39 | from multiprocessing import freeze_support 40 | 41 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 42 | import trackeval # noqa: E402 43 | 44 | if __name__ == '__main__': 45 | freeze_support() 46 | 47 | # Command line interface: 48 | default_eval_config = trackeval.Evaluator.get_default_eval_config() 49 | default_eval_config['PRINT_ONLY_COMBINED'] = True 50 | default_dataset_config = trackeval.datasets.BDD100K.get_default_dataset_config() 51 | default_metrics_config = {'METRICS': ['HOTA', 'CLEAR', 'Identity']} 52 | config = {**default_eval_config, **default_dataset_config, **default_metrics_config} # Merge default configs 53 | parser = argparse.ArgumentParser() 54 | for setting in config.keys(): 55 | if type(config[setting]) == list or type(config[setting]) == type(None): 56 | parser.add_argument("--" + setting, nargs='+') 57 | else: 58 | parser.add_argument("--" + setting) 59 | args = parser.parse_args().__dict__ 60 | for setting in args.keys(): 61 | if args[setting] is not None: 62 | if type(config[setting]) == type(True): 63 | if args[setting] == 'True': 64 | x = True 65 | elif args[setting] == 'False': 66 | x = False 67 | else: 68 | raise Exception('Command line parameter ' + setting + ' must be True or False') 69 | elif type(config[setting]) == type(1): 70 | x = int(args[setting]) 71 | elif type(args[setting]) == type(None): 72 | x = None 73 | else: 74 | x = args[setting] 75 | config[setting] = x 76 | eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()} 77 | dataset_config = {k: v for k, v in config.items() if k in default_dataset_config.keys()} 78 | metrics_config = {k: v for k, v in config.items() if k in default_metrics_config.keys()} 79 | 80 | # Run code 81 | evaluator = trackeval.Evaluator(eval_config) 82 | dataset_list = [trackeval.datasets.BDD100K(dataset_config)] 83 | metrics_list = [] 84 | for metric in [trackeval.metrics.HOTA, trackeval.metrics.CLEAR, trackeval.metrics.Identity]: 85 | if metric.get_name() in metrics_config['METRICS']: 86 | metrics_list.append(metric()) 87 | if len(metrics_list) == 0: 88 | raise Exception('No metrics selected for evaluation') 89 | evaluator.evaluate(dataset_list, metrics_list)
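run_bdd.py above and the remaining run_*.py scripts below all repeat the same command-line pattern: merge the default eval/dataset/metric configs into one dict, expose each key as a --KEY flag, and coerce the string arguments back to the type of the default value. The following standalone sketch distils that pattern; it is not part of the repository, and `override_config` is a hypothetical helper name introduced here for illustration:

```python
import argparse


def override_config(config, argv=None):
    """Hypothetical distillation of the override loop shared by the run_* scripts:
    every default config key becomes a CLI flag, and string arguments are
    coerced back to the type of the default value."""
    parser = argparse.ArgumentParser()
    for key, default in config.items():
        # list/None defaults may take several values; everything else takes one
        nargs = '+' if isinstance(default, (list, type(None))) else None
        parser.add_argument('--' + key, nargs=nargs)
    args = vars(parser.parse_args(argv))
    for key, value in args.items():
        if value is None:
            continue  # flag not given on the command line; keep the default
        default = config[key]
        if isinstance(default, bool):  # check bool before int (bool is an int subclass)
            if value not in ('True', 'False'):
                raise ValueError(key + ' must be True or False')
            config[key] = (value == 'True')
        elif isinstance(default, int):
            config[key] = int(value)
        else:
            config[key] = value
    return config
```

Keeping one such helper would let the scripts share the parsing logic instead of duplicating it in each file.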
-------------------------------------------------------------------------------- /TrackEval/scripts/run_davis.py: -------------------------------------------------------------------------------- 1 | """ run_davis.py 2 | 3 | Run example: 4 | run_davis.py --USE_PARALLEL False --METRICS HOTA --TRACKERS_TO_EVAL ags 5 | 6 | Command Line Arguments: Defaults, # Comments 7 | Eval arguments: 8 | 'USE_PARALLEL': False, 9 | 'NUM_PARALLEL_CORES': 8, 10 | 'BREAK_ON_ERROR': True, 11 | 'PRINT_RESULTS': True, 12 | 'PRINT_ONLY_COMBINED': False, 13 | 'PRINT_CONFIG': True, 14 | 'TIME_PROGRESS': True, 15 | 'OUTPUT_SUMMARY': True, 16 | 'OUTPUT_DETAILED': True, 17 | 'PLOT_CURVES': True, 18 | Dataset arguments: 19 | 'GT_FOLDER': os.path.join(code_path, 'data/gt/davis/'), # Location of GT data 20 | 'TRACKERS_FOLDER': os.path.join(code_path, 'data/trackers/davis/davis_val'), # Trackers location 21 | 'OUTPUT_FOLDER': None, # Where to save eval results (if None, same as TRACKERS_FOLDER) 22 | 'TRACKERS_TO_EVAL': None, # Filenames of trackers to eval (if None, all in folder) 23 | 'SPLIT_TO_EVAL': 'val', # Valid: 'val', 'train' 24 | 'PRINT_CONFIG': True, # Whether to print current config 25 | 'TRACKER_SUB_FOLDER': 'data', # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER 26 | 'OUTPUT_SUB_FOLDER': '', # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER 27 | 'TRACKER_DISPLAY_NAMES': None, # Names of trackers to display, if None: TRACKERS_TO_EVAL 28 | 'SEQMAP_FOLDER': None, # Where seqmaps are found (if None, GT_FOLDER/ImageSets/2017) 29 | 'SEQMAP_FILE': None, # Directly specify seqmap file (if none use seqmap_folder/split-to-eval.txt) 30 | 'SEQ_INFO': None, # If not None, directly specify sequences to eval and their number of timesteps 31 | 'GT_LOC_FORMAT': '{gt_folder}/Annotations_unsupervised/480p/{seq}', 32 | # '{gt_folder}/Annotations_unsupervised/480p/{seq}' 33 | 'MAX_DETECTIONS': 0 # Maximum number of allowed detections per sequence (0 for no threshold) 34 | Metric arguments: 35 | 'METRICS': ['HOTA', 'CLEAR', 'Identity', 'JAndF'] 36 | """ 37 | 38 | import sys 39 | import os 40 | import argparse 41 | from multiprocessing import freeze_support 42 | 43 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 44 | import trackeval # noqa: E402 45 | 46 | if __name__ == '__main__': 47 | freeze_support() 48 | 49 | # Command line interface: 50 | default_eval_config = trackeval.Evaluator.get_default_eval_config() 51 | default_dataset_config = trackeval.datasets.DAVIS.get_default_dataset_config() 52 | default_metrics_config = {'METRICS': ['HOTA', 'CLEAR', 'Identity', 'JAndF']} 53 | config = {**default_eval_config, **default_dataset_config, **default_metrics_config} # Merge default configs 54 | parser = argparse.ArgumentParser() 55 | for setting in config.keys(): 56 | if type(config[setting]) == list or type(config[setting]) == type(None): 57 | parser.add_argument("--" + setting, nargs='+') 58 | else: 59 | parser.add_argument("--" + setting) 60 | args = parser.parse_args().__dict__ 61 | for setting in args.keys(): 62 | if args[setting] is not None: 63 | if type(config[setting]) == type(True): 64 | if args[setting] == 'True': 65 | x = True 66 | elif args[setting] == 'False': 67 | x = False 68 | else: 69 | raise Exception('Command line parameter ' + setting + ' must be True or False') 70 | elif type(config[setting]) == type(1): 71 | x = int(args[setting]) 72 | elif type(args[setting]) == type(None): 73 | x = None 74 | else: 75 | x = args[setting] 76 |
config[setting] = x 77 | eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()} 78 | dataset_config = {k: v for k, v in config.items() if k in default_dataset_config.keys()} 79 | metrics_config = {k: v for k, v in config.items() if k in default_metrics_config.keys()} 80 | 81 | # Run code 82 | evaluator = trackeval.Evaluator(eval_config) 83 | dataset_list = [trackeval.datasets.DAVIS(dataset_config)] 84 | metrics_list = [] 85 | for metric in [trackeval.metrics.HOTA, trackeval.metrics.CLEAR, trackeval.metrics.Identity, trackeval.metrics.JAndF]: 86 | if metric.get_name() in metrics_config['METRICS']: 87 | metrics_list.append(metric()) 88 | if len(metrics_list) == 0: 89 | raise Exception('No metrics selected for evaluation') 90 | evaluator.evaluate(dataset_list, metrics_list) -------------------------------------------------------------------------------- /TrackEval/scripts/run_headtracking_challenge.py: -------------------------------------------------------------------------------- 1 | 2 | """ run_mot_challenge.py 3 | 4 | Run example: 5 | run_mot_challenge.py --USE_PARALLEL False --METRICS Hota --TRACKERS_TO_EVAL Lif_T 6 | 7 | Command Line Arguments: Defaults, # Comments 8 | Eval arguments: 9 | 'USE_PARALLEL': False, 10 | 'NUM_PARALLEL_CORES': 8, 11 | 'BREAK_ON_ERROR': True, 12 | 'PRINT_RESULTS': True, 13 | 'PRINT_ONLY_COMBINED': False, 14 | 'PRINT_CONFIG': True, 15 | 'TIME_PROGRESS': True, 16 | 'OUTPUT_SUMMARY': True, 17 | 'OUTPUT_DETAILED': True, 18 | 'PLOT_CURVES': True, 19 | Dataset arguments: 20 | 'GT_FOLDER': os.path.join(code_path, 'data/gt/mot_challenge/'), # Location of GT data 21 | 'TRACKERS_FOLDER': os.path.join(code_path, 'data/trackers/mot_challenge/'), # Trackers location 22 | 'OUTPUT_FOLDER': None, # Where to save eval results (if None, same as TRACKERS_FOLDER) 23 | 'TRACKERS_TO_EVAL': None, # Filenames of trackers to eval (if None, all in folder) 24 | 'CLASSES_TO_EVAL': ['pedestrian'], # Valid: ['pedestrian'] 25 | 'BENCHMARK': 'MOT17', # Valid: 'MOT17', 'MOT16', 'MOT20', 'MOT15' 26 | 'SPLIT_TO_EVAL': 'train', # Valid: 'train', 'test', 'all' 27 | 'INPUT_AS_ZIP': False, # Whether tracker input files are zipped 28 | 'PRINT_CONFIG': True, # Whether to print current config 29 | 'DO_PREPROC': True, # Whether to perform preprocessing (never done for 2D_MOT_2015) 30 | 'TRACKER_SUB_FOLDER': 'data', # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER 31 | 'OUTPUT_SUB_FOLDER': '', # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER 32 | Metric arguments: 33 | 'METRICS': ['HOTA', 'CLEAR', 'Identity', 'IDEucl'] 34 | """ 35 | 36 | import sys 37 | import os 38 | import argparse 39 | from multiprocessing import freeze_support 40 | 41 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 42 | import trackeval # noqa: E402 43 | 44 | if __name__ == '__main__': 45 | freeze_support() 46 | 47 | # Command line interface: 48 | default_eval_config = trackeval.Evaluator.get_default_eval_config() 49 | default_eval_config['DISPLAY_LESS_PROGRESS'] = False 50 | default_dataset_config = trackeval.datasets.HeadTrackingChallenge.get_default_dataset_config() 51 | default_metrics_config = {'METRICS': ['HOTA', 'CLEAR', 'Identity', 'IDEucl'], 'THRESHOLD': 0.4} 52 | config = {**default_eval_config, **default_dataset_config, **default_metrics_config} # Merge default configs 53 | parser = argparse.ArgumentParser() 54 | for setting in config.keys(): 55 | if type(config[setting]) == list or type(config[setting]) == 
type(None): 56 | parser.add_argument("--" + setting, nargs='+') 57 | else: 58 | parser.add_argument("--" + setting) 59 | args = parser.parse_args().__dict__ 60 | for setting in args.keys(): 61 | if args[setting] is not None: 62 | if type(config[setting]) == type(True): 63 | if args[setting] == 'True': 64 | x = True 65 | elif args[setting] == 'False': 66 | x = False 67 | else: 68 | raise Exception('Command line parameter ' + setting + 'must be True or False') 69 | elif type(config[setting]) == type(1): 70 | x = int(args[setting]) 71 | elif type(args[setting]) == type(None): 72 | x = None 73 | elif setting == 'SEQ_INFO': 74 | x = dict(zip(args[setting], [None]*len(args[setting]))) 75 | else: 76 | x = args[setting] 77 | config[setting] = x 78 | eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()} 79 | dataset_config = {k: v for k, v in config.items() if k in default_dataset_config.keys()} 80 | metrics_config = {k: v for k, v in config.items() if k in default_metrics_config.keys()} 81 | 82 | # Run code 83 | evaluator = trackeval.Evaluator(eval_config) 84 | dataset_list = [trackeval.datasets.HeadTrackingChallenge(dataset_config)] 85 | metrics_list = [] 86 | for metric in [trackeval.metrics.HOTA, trackeval.metrics.CLEAR, trackeval.metrics.Identity, trackeval.metrics.IDEucl]: 87 | if metric.get_name() in metrics_config['METRICS']: 88 | metrics_list.append(metric(metrics_config)) 89 | if len(metrics_list) == 0: 90 | raise Exception('No metrics selected for evaluation') 91 | evaluator.evaluate(dataset_list, metrics_list) 92 | -------------------------------------------------------------------------------- /TrackEval/scripts/run_kitti.py: -------------------------------------------------------------------------------- 1 | 2 | """ run_kitti.py 3 | 4 | Run example: 5 | run_kitti.py --USE_PARALLEL False --METRICS Hota --TRACKERS_TO_EVAL CIWT 6 | 7 | Command Line Arguments: Defaults, # Comments 8 | Eval arguments: 9 | 'USE_PARALLEL': False, 10 | 'NUM_PARALLEL_CORES': 8, 11 | 'BREAK_ON_ERROR': True, 12 | 'PRINT_RESULTS': True, 13 | 'PRINT_ONLY_COMBINED': False, 14 | 'PRINT_CONFIG': True, 15 | 'TIME_PROGRESS': True, 16 | 'OUTPUT_SUMMARY': True, 17 | 'OUTPUT_DETAILED': True, 18 | 'PLOT_CURVES': True, 19 | Dataset arguments: 20 | 'GT_FOLDER': os.path.join(code_path, 'data/gt/kitti/kitti_2d_box_train'), # Location of GT data 21 | 'TRACKERS_FOLDER': os.path.join(code_path, 'data/trackers/kitti/kitti_2d_box_train/'), # Trackers location 22 | 'OUTPUT_FOLDER': None, # Where to save eval results (if None, same as TRACKERS_FOLDER) 23 | 'TRACKERS_TO_EVAL': None, # Filenames of trackers to eval (if None, all in folder) 24 | 'CLASSES_TO_EVAL': ['car', 'pedestrian'], # Valid: ['car', 'pedestrian'] 25 | 'SPLIT_TO_EVAL': 'training', # Valid: 'training', 'val', 'training_minus_val', 'test' 26 | 'INPUT_AS_ZIP': False, # Whether tracker input files are zipped 27 | 'PRINT_CONFIG': True, # Whether to print current config 28 | 'TRACKER_SUB_FOLDER': 'data', # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER 29 | 'OUTPUT_SUB_FOLDER': '' # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER 30 | Metric arguments: 31 | 'METRICS': ['Hota','Clear', 'ID', 'Count'] 32 | """ 33 | 34 | import sys 35 | import os 36 | import argparse 37 | from multiprocessing import freeze_support 38 | 39 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 40 | import trackeval # noqa: E402 41 | 42 | if __name__ == '__main__': 43 | freeze_support() 44 | 
45 | # Command line interface: 46 | default_eval_config = trackeval.Evaluator.get_default_eval_config() 47 | default_eval_config['DISPLAY_LESS_PROGRESS'] = False 48 | default_dataset_config = trackeval.datasets.Kitti2DBox.get_default_dataset_config() 49 | default_metrics_config = {'METRICS': ['HOTA', 'CLEAR', 'Identity']} 50 | config = {**default_eval_config, **default_dataset_config, **default_metrics_config} # Merge default configs 51 | parser = argparse.ArgumentParser() 52 | for setting in config.keys(): 53 | if type(config[setting]) == list or type(config[setting]) == type(None): 54 | parser.add_argument("--" + setting, nargs='+') 55 | else: 56 | parser.add_argument("--" + setting) 57 | args = parser.parse_args().__dict__ 58 | for setting in args.keys(): 59 | if args[setting] is not None: 60 | if type(config[setting]) == type(True): 61 | if args[setting] == 'True': 62 | x = True 63 | elif args[setting] == 'False': 64 | x = False 65 | else: 66 | raise Exception('Command line parameter ' + setting + 'must be True or False') 67 | elif type(config[setting]) == type(1): 68 | x = int(args[setting]) 69 | elif type(args[setting]) == type(None): 70 | x = None 71 | else: 72 | x = args[setting] 73 | config[setting] = x 74 | eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()} 75 | dataset_config = {k: v for k, v in config.items() if k in default_dataset_config.keys()} 76 | metrics_config = {k: v for k, v in config.items() if k in default_metrics_config.keys()} 77 | 78 | # Run code 79 | evaluator = trackeval.Evaluator(eval_config) 80 | dataset_list = [trackeval.datasets.Kitti2DBox(dataset_config)] 81 | metrics_list = [] 82 | for metric in [trackeval.metrics.HOTA, trackeval.metrics.CLEAR, trackeval.metrics.Identity]: 83 | if metric.get_name() in metrics_config['METRICS']: 84 | metrics_list.append(metric()) 85 | if len(metrics_list) == 0: 86 | raise Exception('No metrics selected for evaluation') 87 | evaluator.evaluate(dataset_list, metrics_list) 88 | -------------------------------------------------------------------------------- /TrackEval/scripts/run_kitti_mots.py: -------------------------------------------------------------------------------- 1 | 2 | """ run_kitti_mots.py 3 | 4 | Run example: 5 | run_kitti_mots.py --USE_PARALLEL False --METRICS HOTA --TRACKERS_TO_EVAL trackrcnn 6 | 7 | Command Line Arguments: Defaults, # Comments 8 | Eval arguments: 9 | 'USE_PARALLEL': False, 10 | 'NUM_PARALLEL_CORES': 8, 11 | 'BREAK_ON_ERROR': True, 12 | 'PRINT_RESULTS': True, 13 | 'PRINT_ONLY_COMBINED': False, 14 | 'PRINT_CONFIG': True, 15 | 'TIME_PROGRESS': True, 16 | 'OUTPUT_SUMMARY': True, 17 | 'OUTPUT_DETAILED': True, 18 | 'PLOT_CURVES': True, 19 | Dataset arguments: 20 | 'GT_FOLDER': os.path.join(code_path, 'data/gt/kitti/kitti_mots'), # Location of GT data 21 | 'TRACKERS_FOLDER': os.path.join(code_path, 'data/trackers/kitti/kitti_mots_val'), # Location of all 22 | # trackers 23 | 'OUTPUT_FOLDER': None, # Where to save eval results (if None, same as TRACKERS_FOLDER) 24 | 'TRACKERS_TO_EVAL': None, # Filenames of trackers to eval (if None, all in folder) 25 | 'CLASSES_TO_EVAL': ['car', 'pedestrian'], # Valid: ['car', 'pedestrian'] 26 | 'SPLIT_TO_EVAL': 'val', # Valid: 'training', 'val' 27 | 'INPUT_AS_ZIP': False, # Whether tracker input files are zipped 28 | 'PRINT_CONFIG': True, # Whether to print current config 29 | 'TRACKER_SUB_FOLDER': 'data', # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER 30 | 'OUTPUT_SUB_FOLDER': '', # Output files are saved in 
OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER 31 | 'SEQMAP_FOLDER': None, # Where seqmaps are found (if None, GT_FOLDER) 32 | 'SEQMAP_FILE': None, # Directly specify seqmap file (if none use seqmap_folder/split_to_eval.seqmap) 33 | 'SEQ_INFO': None, # If not None, directly specify sequences to eval and their number of timesteps 34 | 'GT_LOC_FORMAT': '{gt_folder}/instances_txt/{seq}.txt', # format of gt localization 35 | Metric arguments: 36 | 'METRICS': ['HOTA', 'CLEAR', 'Identity'] 37 | """ 38 | 39 | import sys 40 | import os 41 | import argparse 42 | from multiprocessing import freeze_support 43 | 44 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 45 | import trackeval # noqa: E402 46 | 47 | if __name__ == '__main__': 48 | freeze_support() 49 | 50 | # Command line interface: 51 | default_eval_config = trackeval.Evaluator.get_default_eval_config() 52 | default_eval_config['DISPLAY_LESS_PROGRESS'] = False 53 | default_dataset_config = trackeval.datasets.KittiMOTS.get_default_dataset_config() 54 | default_metrics_config = {'METRICS': ['HOTA', 'CLEAR', 'Identity']} 55 | config = {**default_eval_config, **default_dataset_config, **default_metrics_config} # Merge default configs 56 | parser = argparse.ArgumentParser() 57 | for setting in config.keys(): 58 | if type(config[setting]) == list or type(config[setting]) == type(None): 59 | parser.add_argument("--" + setting, nargs='+') 60 | else: 61 | parser.add_argument("--" + setting) 62 | args = parser.parse_args().__dict__ 63 | for setting in args.keys(): 64 | if args[setting] is not None: 65 | if type(config[setting]) == type(True): 66 | if args[setting] == 'True': 67 | x = True 68 | elif args[setting] == 'False': 69 | x = False 70 | else: 71 | raise Exception('Command line parameter ' + setting + 'must be True or False') 72 | elif type(config[setting]) == type(1): 73 | x = int(args[setting]) 74 | elif type(args[setting]) == type(None): 75 | x = None 76 | else: 77 | x = args[setting] 78 | config[setting] = x 79 | eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()} 80 | dataset_config = {k: v for k, v in config.items() if k in default_dataset_config.keys()} 81 | metrics_config = {k: v for k, v in config.items() if k in default_metrics_config.keys()} 82 | 83 | # Run code 84 | evaluator = trackeval.Evaluator(eval_config) 85 | dataset_list = [trackeval.datasets.KittiMOTS(dataset_config)] 86 | metrics_list = [] 87 | for metric in [trackeval.metrics.HOTA, trackeval.metrics.CLEAR, trackeval.metrics.Identity, trackeval.metrics.JAndF]: 88 | if metric.get_name() in metrics_config['METRICS']: 89 | metrics_list.append(metric()) 90 | if len(metrics_list) == 0: 91 | raise Exception('No metrics selected for evaluation') 92 | evaluator.evaluate(dataset_list, metrics_list) 93 | -------------------------------------------------------------------------------- /TrackEval/scripts/run_mot_challenge.py: -------------------------------------------------------------------------------- 1 | 2 | """ run_mot_challenge.py 3 | 4 | Run example: 5 | run_mot_challenge.py --USE_PARALLEL False --METRICS Hota --TRACKERS_TO_EVAL Lif_T 6 | 7 | Command Line Arguments: Defaults, # Comments 8 | Eval arguments: 9 | 'USE_PARALLEL': False, 10 | 'NUM_PARALLEL_CORES': 8, 11 | 'BREAK_ON_ERROR': True, 12 | 'PRINT_RESULTS': True, 13 | 'PRINT_ONLY_COMBINED': False, 14 | 'PRINT_CONFIG': True, 15 | 'TIME_PROGRESS': True, 16 | 'OUTPUT_SUMMARY': True, 17 | 'OUTPUT_DETAILED': True, 18 | 'PLOT_CURVES': True, 19 | Dataset arguments: 20 
| 'GT_FOLDER': os.path.join(code_path, 'data/gt/mot_challenge/'), # Location of GT data 21 | 'TRACKERS_FOLDER': os.path.join(code_path, 'data/trackers/mot_challenge/'), # Trackers location 22 | 'OUTPUT_FOLDER': None, # Where to save eval results (if None, same as TRACKERS_FOLDER) 23 | 'TRACKERS_TO_EVAL': None, # Filenames of trackers to eval (if None, all in folder) 24 | 'CLASSES_TO_EVAL': ['pedestrian'], # Valid: ['pedestrian'] 25 | 'BENCHMARK': 'MOT17', # Valid: 'MOT17', 'MOT16', 'MOT20', 'MOT15' 26 | 'SPLIT_TO_EVAL': 'train', # Valid: 'train', 'test', 'all' 27 | 'INPUT_AS_ZIP': False, # Whether tracker input files are zipped 28 | 'PRINT_CONFIG': True, # Whether to print current config 29 | 'DO_PREPROC': True, # Whether to perform preprocessing (never done for 2D_MOT_2015) 30 | 'TRACKER_SUB_FOLDER': 'data', # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER 31 | 'OUTPUT_SUB_FOLDER': '', # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER 32 | Metric arguments: 33 | 'METRICS': ['HOTA', 'CLEAR', 'Identity', 'VACE'] 34 | """ 35 | 36 | import sys 37 | import os 38 | import argparse 39 | from multiprocessing import freeze_support 40 | 41 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 42 | import trackeval # noqa: E402 43 | 44 | if __name__ == '__main__': 45 | freeze_support() 46 | 47 | # Command line interface: 48 | default_eval_config = trackeval.Evaluator.get_default_eval_config() 49 | default_eval_config['DISPLAY_LESS_PROGRESS'] = False 50 | default_dataset_config = trackeval.datasets.MotChallenge2DBox.get_default_dataset_config() 51 | default_metrics_config = {'METRICS': ['HOTA', 'CLEAR', 'Identity'], 'THRESHOLD': 0.5} 52 | config = {**default_eval_config, **default_dataset_config, **default_metrics_config} # Merge default configs 53 | parser = argparse.ArgumentParser() 54 | for setting in config.keys(): 55 | if type(config[setting]) == list or type(config[setting]) == type(None): 56 | parser.add_argument("--" + setting, nargs='+') 57 | else: 58 | parser.add_argument("--" + setting) 59 | args = parser.parse_args().__dict__ 60 | for setting in args.keys(): 61 | if args[setting] is not None: 62 | if type(config[setting]) == type(True): 63 | if args[setting] == 'True': 64 | x = True 65 | elif args[setting] == 'False': 66 | x = False 67 | else: 68 | raise Exception('Command line parameter ' + setting + 'must be True or False') 69 | elif type(config[setting]) == type(1): 70 | x = int(args[setting]) 71 | elif type(args[setting]) == type(None): 72 | x = None 73 | elif setting == 'SEQ_INFO': 74 | x = dict(zip(args[setting], [None]*len(args[setting]))) 75 | else: 76 | x = args[setting] 77 | config[setting] = x 78 | eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()} 79 | dataset_config = {k: v for k, v in config.items() if k in default_dataset_config.keys()} 80 | metrics_config = {k: v for k, v in config.items() if k in default_metrics_config.keys()} 81 | 82 | # Run code 83 | evaluator = trackeval.Evaluator(eval_config) 84 | dataset_list = [trackeval.datasets.MotChallenge2DBox(dataset_config)] 85 | metrics_list = [] 86 | for metric in [trackeval.metrics.HOTA, trackeval.metrics.CLEAR, trackeval.metrics.Identity, trackeval.metrics.VACE]: 87 | if metric.get_name() in metrics_config['METRICS']: 88 | metrics_list.append(metric(metrics_config)) 89 | if len(metrics_list) == 0: 90 | raise Exception('No metrics selected for evaluation') 91 | evaluator.evaluate(dataset_list, metrics_list) 92 | 
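Stripped of the argument parsing, each of these runner scripts reduces to the same three-object API: an Evaluator, a list of dataset objects, and a list of metric objects. A minimal programmatic equivalent of run_mot_challenge.py follows; the tracker name is a placeholder, and the data must already be laid out under the default GT_FOLDER/TRACKERS_FOLDER paths:

```python
import trackeval

# Start from the defaults and override with plain dict entries.
eval_config = trackeval.Evaluator.get_default_eval_config()
dataset_config = trackeval.datasets.MotChallenge2DBox.get_default_dataset_config()
dataset_config['TRACKERS_TO_EVAL'] = ['MPNTrack']  # placeholder tracker folder name

evaluator = trackeval.Evaluator(eval_config)
dataset_list = [trackeval.datasets.MotChallenge2DBox(dataset_config)]
metrics_list = [trackeval.metrics.HOTA(), trackeval.metrics.CLEAR(), trackeval.metrics.Identity()]

# evaluate() returns the raw per-sequence results and status messages
# (tests/test_all_quick.py uses the same call).
raw_results, messages = evaluator.evaluate(dataset_list, metrics_list)
```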
-------------------------------------------------------------------------------- /TrackEval/scripts/run_mots_challenge.py: -------------------------------------------------------------------------------- 1 | """ run_mots.py 2 | 3 | Run example: 4 | run_mots.py --USE_PARALLEL False --METRICS Hota --TRACKERS_TO_EVAL TrackRCNN 5 | 6 | Command Line Arguments: Defaults, # Comments 7 | Eval arguments: 8 | 'USE_PARALLEL': False, 9 | 'NUM_PARALLEL_CORES': 8, 10 | 'BREAK_ON_ERROR': True, 11 | 'PRINT_RESULTS': True, 12 | 'PRINT_ONLY_COMBINED': False, 13 | 'PRINT_CONFIG': True, 14 | 'TIME_PROGRESS': True, 15 | 'OUTPUT_SUMMARY': True, 16 | 'OUTPUT_DETAILED': True, 17 | 'PLOT_CURVES': True, 18 | Dataset arguments: 19 | 'GT_FOLDER': os.path.join(code_path, 'data/gt/mot_challenge/'), # Location of GT data 20 | 'TRACKERS_FOLDER': os.path.join(code_path, 'data/trackers/mot_challenge/'), # Trackers location 21 | 'OUTPUT_FOLDER': None, # Where to save eval results (if None, same as TRACKERS_FOLDER) 22 | 'TRACKERS_TO_EVAL': None, # Filenames of trackers to eval (if None, all in folder) 23 | 'CLASSES_TO_EVAL': ['pedestrian'], # Valid: ['pedestrian'] 24 | 'SPLIT_TO_EVAL': 'train', # Valid: 'train', 'test' 25 | 'INPUT_AS_ZIP': False, # Whether tracker input files are zipped 26 | 'PRINT_CONFIG': True, # Whether to print current config 27 | 'TRACKER_SUB_FOLDER': 'data', # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER 28 | 'OUTPUT_SUB_FOLDER': '', # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER 29 | 'SEQMAP_FOLDER': None, # Where seqmaps are found (if None, GT_FOLDER/seqmaps) 30 | 'SEQMAP_FILE': None, # Directly specify seqmap file (if none use seqmap_folder/MOTS-split_to_eval) 31 | 'SEQ_INFO': None, # If not None, directly specify sequences to eval and their number of timesteps 32 | 'GT_LOC_FORMAT': '{gt_folder}/{seq}/gt/gt.txt', # '{gt_folder}/{seq}/gt/gt.txt' 33 | 'SKIP_SPLIT_FOL': False, # If False, data is in GT_FOLDER/MOTS-SPLIT_TO_EVAL/ and in 34 | # TRACKERS_FOLDER/MOTS-SPLIT_TO_EVAL/tracker/ 35 | # If True, then the middle 'MOTS-split' folder is skipped for both. 
36 | Metric arguments: 37 | 'METRICS': ['HOTA','CLEAR', 'Identity', 'VACE', 'JAndF'] 38 | """ 39 | 40 | import sys 41 | import os 42 | import argparse 43 | from multiprocessing import freeze_support 44 | 45 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 46 | import trackeval # noqa: E402 47 | 48 | if __name__ == '__main__': 49 | freeze_support() 50 | 51 | # Command line interface: 52 | default_eval_config = trackeval.Evaluator.get_default_eval_config() 53 | default_eval_config['DISPLAY_LESS_PROGRESS'] = False 54 | default_dataset_config = trackeval.datasets.MOTSChallenge.get_default_dataset_config() 55 | default_metrics_config = {'METRICS': ['HOTA', 'CLEAR', 'Identity']} 56 | config = {**default_eval_config, **default_dataset_config, **default_metrics_config} # Merge default configs 57 | parser = argparse.ArgumentParser() 58 | for setting in config.keys(): 59 | if type(config[setting]) == list or type(config[setting]) == type(None): 60 | parser.add_argument("--" + setting, nargs='+') 61 | else: 62 | parser.add_argument("--" + setting) 63 | args = parser.parse_args().__dict__ 64 | for setting in args.keys(): 65 | if args[setting] is not None: 66 | if type(config[setting]) == type(True): 67 | if args[setting] == 'True': 68 | x = True 69 | elif args[setting] == 'False': 70 | x = False 71 | else: 72 | raise Exception('Command line parameter ' + setting + 'must be True or False') 73 | elif type(config[setting]) == type(1): 74 | x = int(args[setting]) 75 | elif type(args[setting]) == type(None): 76 | x = None 77 | elif setting == 'SEQ_INFO': 78 | x = dict(zip(args[setting], [None]*len(args[setting]))) 79 | else: 80 | x = args[setting] 81 | config[setting] = x 82 | eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()} 83 | dataset_config = {k: v for k, v in config.items() if k in default_dataset_config.keys()} 84 | metrics_config = {k: v for k, v in config.items() if k in default_metrics_config.keys()} 85 | 86 | # Run code 87 | evaluator = trackeval.Evaluator(eval_config) 88 | dataset_list = [trackeval.datasets.MOTSChallenge(dataset_config)] 89 | metrics_list = [] 90 | for metric in [trackeval.metrics.HOTA, trackeval.metrics.CLEAR, trackeval.metrics.Identity, trackeval.metrics.VACE, 91 | trackeval.metrics.JAndF]: 92 | if metric.get_name() in metrics_config['METRICS']: 93 | metrics_list.append(metric()) 94 | if len(metrics_list) == 0: 95 | raise Exception('No metrics selected for evaluation') 96 | evaluator.evaluate(dataset_list, metrics_list) 97 | -------------------------------------------------------------------------------- /TrackEval/scripts/run_tao.py: -------------------------------------------------------------------------------- 1 | """ run_tao.py 2 | 3 | Run example: 4 | run_tao.py --USE_PARALLEL False --METRICS HOTA --TRACKERS_TO_EVAL Tracktor++ 5 | 6 | Command Line Arguments: Defaults, # Comments 7 | Eval arguments: 8 | 'USE_PARALLEL': False, 9 | 'NUM_PARALLEL_CORES': 8, 10 | 'BREAK_ON_ERROR': True, 11 | 'PRINT_RESULTS': True, 12 | 'PRINT_ONLY_COMBINED': False, 13 | 'PRINT_CONFIG': True, 14 | 'TIME_PROGRESS': True, 15 | 'OUTPUT_SUMMARY': True, 16 | 'OUTPUT_DETAILED': True, 17 | 'PLOT_CURVES': True, 18 | Dataset arguments: 19 | 'GT_FOLDER': os.path.join(code_path, 'data/gt/tao/tao_training'), # Location of GT data 20 | 'TRACKERS_FOLDER': os.path.join(code_path, 'data/trackers/tao/tao_training'), # Trackers location 21 | 'OUTPUT_FOLDER': None, # Where to save eval results (if None, same as TRACKERS_FOLDER) 22 | 
'TRACKERS_TO_EVAL': None, # Filenames of trackers to eval (if None, all in folder) 23 | 'CLASSES_TO_EVAL': None, # Classes to eval (if None, all classes) 24 | 'SPLIT_TO_EVAL': 'training', # Valid: 'training', 'val' 25 | 'PRINT_CONFIG': True, # Whether to print current config 26 | 'TRACKER_SUB_FOLDER': 'data', # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER 27 | 'OUTPUT_SUB_FOLDER': '', # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER 28 | 'TRACKER_DISPLAY_NAMES': None, # Names of trackers to display, if None: TRACKERS_TO_EVAL 29 | 'MAX_DETECTIONS': 300, # Number of maximal allowed detections per image (0 for unlimited) 30 | Metric arguments: 31 | 'METRICS': ['HOTA', 'CLEAR', 'Identity', 'TrackMAP'] 32 | """ 33 | 34 | import sys 35 | import os 36 | import argparse 37 | from multiprocessing import freeze_support 38 | 39 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 40 | import trackeval # noqa: E402 41 | 42 | if __name__ == '__main__': 43 | freeze_support() 44 | 45 | # Command line interface: 46 | default_eval_config = trackeval.Evaluator.get_default_eval_config() 47 | # print only combined since TrackMAP is undefined for per sequence breakdowns 48 | default_eval_config['PRINT_ONLY_COMBINED'] = True 49 | default_eval_config['DISPLAY_LESS_PROGRESS'] = True 50 | default_dataset_config = trackeval.datasets.TAO.get_default_dataset_config() 51 | default_metrics_config = {'METRICS': ['HOTA', 'CLEAR', 'Identity', 'TrackMAP']} 52 | config = {**default_eval_config, **default_dataset_config, **default_metrics_config} # Merge default configs 53 | parser = argparse.ArgumentParser() 54 | for setting in config.keys(): 55 | if type(config[setting]) == list or type(config[setting]) == type(None): 56 | parser.add_argument("--" + setting, nargs='+') 57 | else: 58 | parser.add_argument("--" + setting) 59 | args = parser.parse_args().__dict__ 60 | for setting in args.keys(): 61 | if args[setting] is not None: 62 | if type(config[setting]) == type(True): 63 | if args[setting] == 'True': 64 | x = True 65 | elif args[setting] == 'False': 66 | x = False 67 | else: 68 | raise Exception('Command line parameter ' + setting + 'must be True or False') 69 | elif type(config[setting]) == type(1): 70 | x = int(args[setting]) 71 | elif type(args[setting]) == type(None): 72 | x = None 73 | else: 74 | x = args[setting] 75 | config[setting] = x 76 | eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()} 77 | dataset_config = {k: v for k, v in config.items() if k in default_dataset_config.keys()} 78 | metrics_config = {k: v for k, v in config.items() if k in default_metrics_config.keys()} 79 | 80 | # Run code 81 | evaluator = trackeval.Evaluator(eval_config) 82 | dataset_list = [trackeval.datasets.TAO(dataset_config)] 83 | metrics_list = [] 84 | for metric in [trackeval.metrics.TrackMAP, trackeval.metrics.CLEAR, trackeval.metrics.Identity, 85 | trackeval.metrics.HOTA]: 86 | if metric.get_name() in metrics_config['METRICS']: 87 | metrics_list.append(metric()) 88 | if len(metrics_list) == 0: 89 | raise Exception('No metrics selected for evaluation') 90 | evaluator.evaluate(dataset_list, metrics_list) -------------------------------------------------------------------------------- /TrackEval/scripts/run_tao_ow.py: -------------------------------------------------------------------------------- 1 | """ run_tao.py 2 | 3 | Run example: 4 | run_tao_ow.py --USE_PARALLEL False --METRICS HOTA --TRACKERS_TO_EVAL Tracktor++ 5 | 
6 | Command Line Arguments: Defaults, # Comments 7 | Eval arguments: 8 | 'USE_PARALLEL': False, 9 | 'NUM_PARALLEL_CORES': 8, 10 | 'BREAK_ON_ERROR': True, 11 | 'PRINT_RESULTS': True, 12 | 'PRINT_ONLY_COMBINED': False, 13 | 'PRINT_CONFIG': True, 14 | 'TIME_PROGRESS': True, 15 | 'OUTPUT_SUMMARY': True, 16 | 'OUTPUT_DETAILED': True, 17 | 'PLOT_CURVES': True, 18 | Dataset arguments: 19 | 'GT_FOLDER': os.path.join(code_path, 'data/gt/tao/tao_training'), # Location of GT data 20 | 'TRACKERS_FOLDER': os.path.join(code_path, 'data/trackers/tao/tao_training'), # Trackers location 21 | 'OUTPUT_FOLDER': None, # Where to save eval results (if None, same as TRACKERS_FOLDER) 22 | 'TRACKERS_TO_EVAL': None, # Filenames of trackers to eval (if None, all in folder) 23 | 'CLASSES_TO_EVAL': None, # Classes to eval (if None, all classes) 24 | 'SPLIT_TO_EVAL': 'training', # Valid: 'training', 'val' 25 | 'PRINT_CONFIG': True, # Whether to print current config 26 | 'TRACKER_SUB_FOLDER': 'data', # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER 27 | 'OUTPUT_SUB_FOLDER': '', # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER 28 | 'TRACKER_DISPLAY_NAMES': None, # Names of trackers to display, if None: TRACKERS_TO_EVAL 29 | 'MAX_DETECTIONS': 300, # Number of maximal allowed detections per image (0 for unlimited) 30 | 'SUBSET': 'unknown', # Evaluate on the following subsets ['all', 'known', 'unknown', 'distractor'] 31 | Metric arguments: 32 | 'METRICS': ['HOTA', 'CLEAR', 'Identity', 'TrackMAP'] 33 | """ 34 | 35 | import sys 36 | import os 37 | import argparse 38 | from multiprocessing import freeze_support 39 | 40 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 41 | import trackeval # noqa: E402 42 | 43 | if __name__ == '__main__': 44 | freeze_support() 45 | 46 | # Command line interface: 47 | default_eval_config = trackeval.Evaluator.get_default_eval_config() 48 | # print only combined since TrackMAP is undefined for per sequence breakdowns 49 | default_eval_config['PRINT_ONLY_COMBINED'] = True 50 | default_eval_config['DISPLAY_LESS_PROGRESS'] = True 51 | default_dataset_config = trackeval.datasets.TAO_OW.get_default_dataset_config() 52 | default_metrics_config = {'METRICS': ['HOTA', 'CLEAR', 'Identity', 'TrackMAP']} 53 | config = {**default_eval_config, **default_dataset_config, **default_metrics_config} # Merge default configs 54 | parser = argparse.ArgumentParser() 55 | for setting in config.keys(): 56 | if type(config[setting]) == list or type(config[setting]) == type(None): 57 | parser.add_argument("--" + setting, nargs='+') 58 | else: 59 | parser.add_argument("--" + setting) 60 | args = parser.parse_args().__dict__ 61 | for setting in args.keys(): 62 | if args[setting] is not None: 63 | if type(config[setting]) == type(True): 64 | if args[setting] == 'True': 65 | x = True 66 | elif args[setting] == 'False': 67 | x = False 68 | else: 69 | raise Exception('Command line parameter ' + setting + 'must be True or False') 70 | elif type(config[setting]) == type(1): 71 | x = int(args[setting]) 72 | elif type(args[setting]) == type(None): 73 | x = None 74 | else: 75 | x = args[setting] 76 | config[setting] = x 77 | eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()} 78 | dataset_config = {k: v for k, v in config.items() if k in default_dataset_config.keys()} 79 | metrics_config = {k: v for k, v in config.items() if k in default_metrics_config.keys()} 80 | 81 | # Run code 82 | evaluator = 
trackeval.Evaluator(eval_config) 83 | dataset_list = [trackeval.datasets.TAO_OW(dataset_config)] 84 | metrics_list = [] 85 | # for metric in [trackeval.metrics.TrackMAP, trackeval.metrics.CLEAR, trackeval.metrics.Identity, 86 | # trackeval.metrics.HOTA]: 87 | for metric in [trackeval.metrics.HOTA]: 88 | if metric.get_name() in metrics_config['METRICS']: 89 | metrics_list.append(metric()) 90 | if len(metrics_list) == 0: 91 | raise Exception('No metrics selected for evaluation') 92 | evaluator.evaluate(dataset_list, metrics_list) -------------------------------------------------------------------------------- /TrackEval/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = trackeval 3 | version = 1.0.dev1 4 | author = Jonathon Luiten, Arne Hoffhues 5 | author_email = jonoluiten@gmail.com 6 | description = Code for evaluating object tracking 7 | long_description = file: Readme.md 8 | long_description_content_type = text/markdown 9 | url = https://github.com/JonathonLuiten/TrackEval 10 | project_urls = 11 | Bug Tracker = https://github.com/JonathonLuiten/TrackEval/issues 12 | classifiers = 13 | Programming Language :: Python :: 3 14 | Programming Language :: Python :: 3 :: Only 15 | License :: OSI Approved :: MIT License 16 | Operating System :: OS Independent 17 | Topic :: Scientific/Engineering 18 | license_files = LICENSE 19 | 20 | [options] 21 | install_requires = 22 | numpy 23 | scipy 24 | packages = find: 25 | 26 | [options.packages.find] 27 | include = trackeval* 28 | -------------------------------------------------------------------------------- /TrackEval/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | -------------------------------------------------------------------------------- /TrackEval/tests/test_all_quick.py: -------------------------------------------------------------------------------- 1 | """ Test to ensure that the code is working correctly. 2 | Should test ALL metrics across all datasets and splits currently supported. 3 | Only tests one tracker per dataset/split to give a quick test result. 
4 | """ 5 | 6 | import sys 7 | import os 8 | import numpy as np 9 | from multiprocessing import freeze_support 10 | 11 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 12 | import trackeval # noqa: E402 13 | 14 | # Fixes multiprocessing on windows, does nothing otherwise 15 | if __name__ == '__main__': 16 | freeze_support() 17 | 18 | eval_config = {'USE_PARALLEL': False, 19 | 'NUM_PARALLEL_CORES': 8, 20 | } 21 | evaluator = trackeval.Evaluator(eval_config) 22 | metrics_list = [trackeval.metrics.HOTA(), trackeval.metrics.CLEAR(), trackeval.metrics.Identity()] 23 | 24 | tests = [ 25 | {'DATASET': 'Kitti2DBox', 'SPLIT_TO_EVAL': 'training', 'TRACKERS_TO_EVAL': ['CIWT']}, 26 | {'DATASET': 'MotChallenge2DBox', 'BENCHMARK': 'MOT15', 'SPLIT_TO_EVAL': 'train', 'TRACKERS_TO_EVAL': ['MPNTrack']}, 27 | {'DATASET': 'MotChallenge2DBox', 'BENCHMARK': 'MOT16', 'SPLIT_TO_EVAL': 'train', 'TRACKERS_TO_EVAL': ['MPNTrack']}, 28 | {'DATASET': 'MotChallenge2DBox', 'BENCHMARK': 'MOT17', 'SPLIT_TO_EVAL': 'train', 'TRACKERS_TO_EVAL': ['MPNTrack']}, 29 | {'DATASET': 'MotChallenge2DBox', 'BENCHMARK': 'MOT20', 'SPLIT_TO_EVAL': 'train', 'TRACKERS_TO_EVAL': ['MPNTrack']}, 30 | ] 31 | 32 | for dataset_config in tests: 33 | 34 | dataset_name = dataset_config.pop('DATASET') 35 | if dataset_name == 'MotChallenge2DBox': 36 | dataset_list = [trackeval.datasets.MotChallenge2DBox(dataset_config)] 37 | file_loc = os.path.join('mot_challenge', dataset_config['BENCHMARK'] + '-' + dataset_config['SPLIT_TO_EVAL']) 38 | elif dataset_name == 'Kitti2DBox': 39 | dataset_list = [trackeval.datasets.Kitti2DBox(dataset_config)] 40 | file_loc = os.path.join('kitti', 'kitti_2d_box_train') 41 | else: 42 | raise Exception('Dataset %s does not exist.' % dataset_name) 43 | 44 | raw_results, messages = evaluator.evaluate(dataset_list, metrics_list) 45 | 46 | classes = dataset_list[0].config['CLASSES_TO_EVAL'] 47 | tracker = dataset_config['TRACKERS_TO_EVAL'][0] 48 | test_data_loc = os.path.join(os.path.dirname(__file__), '..', 'data', 'tests', file_loc) 49 | 50 | for cls in classes: 51 | results = {seq: raw_results[dataset_name][tracker][seq][cls] for seq in raw_results[dataset_name][tracker].keys()} 52 | current_metrics_list = metrics_list + [trackeval.metrics.Count()] 53 | metric_names = trackeval.utils.validate_metrics_list(current_metrics_list) 54 | 55 | # Load expected results: 56 | test_data = trackeval.utils.load_detail(os.path.join(test_data_loc, tracker, cls + '_detailed.csv')) 57 | 58 | # Do checks 59 | for seq in test_data.keys(): 60 | assert len(test_data[seq].keys()) > 250, len(test_data[seq].keys()) 61 | 62 | details = [] 63 | for metric, metric_name in zip(current_metrics_list, metric_names): 64 | table_res = {seq_key: seq_value[metric_name] for seq_key, seq_value in results.items()} 65 | details.append(metric.detailed_results(table_res)) 66 | res_fields = sum([list(s['COMBINED_SEQ'].keys()) for s in details], []) 67 | res_values = sum([list(s[seq].values()) for s in details], []) 68 | res_dict = dict(zip(res_fields, res_values)) 69 | 70 | for field in test_data[seq].keys(): 71 | assert np.isclose(res_dict[field], test_data[seq][field]), seq + ': ' + cls + ': ' + field 72 | 73 | print('Tracker %s tests passed' % tracker) 74 | print('All tests passed') 75 | 76 | -------------------------------------------------------------------------------- /TrackEval/tests/test_davis.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy 
as np 4 | from multiprocessing import freeze_support 5 | 6 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 7 | import trackeval # noqa: E402 8 | 9 | # Fixes multiprocessing on windows, does nothing otherwise 10 | if __name__ == '__main__': 11 | freeze_support() 12 | 13 | 14 | eval_config = {'USE_PARALLEL': False, 15 | 'NUM_PARALLEL_CORES': 8, 16 | 'PRINT_RESULTS': False, 17 | 'PRINT_CONFIG': True, 18 | 'TIME_PROGRESS': True, 19 | 'DISPLAY_LESS_PROGRESS': True, 20 | 'OUTPUT_SUMMARY': False, 21 | 'OUTPUT_EMPTY_CLASSES': False, 22 | 'OUTPUT_DETAILED': False, 23 | 'PLOT_CURVES': False, 24 | } 25 | evaluator = trackeval.Evaluator(eval_config) 26 | metrics_list = [trackeval.metrics.HOTA(), trackeval.metrics.CLEAR(), trackeval.metrics.Identity(), 27 | trackeval.metrics.JAndF()] 28 | 29 | tests = [ 30 | {'SPLIT_TO_EVAL': 'val', 'TRACKERS_TO_EVAL': ['ags']}, 31 | ] 32 | 33 | for dataset_config in tests: 34 | 35 | dataset_list = [trackeval.datasets.DAVIS(dataset_config)] 36 | file_loc = os.path.join('davis', 'davis_unsupervised_' + dataset_config['SPLIT_TO_EVAL']) 37 | 38 | raw_results, messages = evaluator.evaluate(dataset_list, metrics_list) 39 | 40 | classes = dataset_list[0].config['CLASSES_TO_EVAL'] 41 | tracker = dataset_config['TRACKERS_TO_EVAL'][0] 42 | test_data_loc = os.path.join(os.path.dirname(__file__), '..', 'data', 'tests', file_loc) 43 | 44 | for cls in classes: 45 | results = {seq: raw_results['DAVIS'][tracker][seq][cls] for seq in raw_results['DAVIS'][tracker].keys()} 46 | current_metrics_list = metrics_list + [trackeval.metrics.Count()] 47 | metric_names = trackeval.utils.validate_metrics_list(current_metrics_list) 48 | 49 | # Load expected results: 50 | test_data = trackeval.utils.load_detail(os.path.join(test_data_loc, tracker, cls + '_detailed.csv')) 51 | 52 | # Do checks 53 | for seq in test_data.keys(): 54 | assert len(test_data[seq].keys()) > 250, len(test_data[seq].keys()) 55 | 56 | details = [] 57 | for metric, metric_name in zip(current_metrics_list, metric_names): 58 | table_res = {seq_key: seq_value[metric_name] for seq_key, seq_value in results.items()} 59 | details.append(metric.detailed_results(table_res)) 60 | res_fields = sum([list(s['COMBINED_SEQ'].keys()) for s in details], []) 61 | res_values = sum([list(s[seq].values()) for s in details], []) 62 | res_dict = dict(zip(res_fields, res_values)) 63 | 64 | for field in test_data[seq].keys(): 65 | assert np.isclose(res_dict[field], test_data[seq][field]), seq + ': ' + cls + ': ' + field 66 | 67 | print('Tracker %s tests passed' % tracker) 68 | print('All tests passed') -------------------------------------------------------------------------------- /TrackEval/tests/test_mot17.py: -------------------------------------------------------------------------------- 1 | """ Test to ensure that the code is working correctly. 2 | Runs all metrics on 14 trackers for the MOT Challenge MOT17 benchmark. 
3 | """ 4 | 5 | 6 | import sys 7 | import os 8 | import numpy as np 9 | from multiprocessing import freeze_support 10 | 11 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 12 | import trackeval # noqa: E402 13 | 14 | # Fixes multiprocessing on windows, does nothing otherwise 15 | if __name__ == '__main__': 16 | freeze_support() 17 | 18 | eval_config = {'USE_PARALLEL': False, 19 | 'NUM_PARALLEL_CORES': 8, 20 | } 21 | evaluator = trackeval.Evaluator(eval_config) 22 | metrics_list = [trackeval.metrics.HOTA(), trackeval.metrics.CLEAR(), trackeval.metrics.Identity()] 23 | test_data_loc = os.path.join(os.path.dirname(__file__), '..', 'data', 'tests', 'mot_challenge', 'MOT17-train') 24 | trackers = [ 25 | 'DPMOT', 26 | 'GNNMatch', 27 | 'IA', 28 | 'ISE_MOT17R', 29 | 'Lif_T', 30 | 'Lif_TsimInt', 31 | 'LPC_MOT', 32 | 'MAT', 33 | 'MIFTv2', 34 | 'MPNTrack', 35 | 'SSAT', 36 | 'TracktorCorr', 37 | 'Tracktorv2', 38 | 'UnsupTrack', 39 | ] 40 | 41 | for tracker in trackers: 42 | # Run code on tracker 43 | dataset_config = {'TRACKERS_TO_EVAL': [tracker], 44 | 'BENCHMARK': 'MOT17'} 45 | dataset_list = [trackeval.datasets.MotChallenge2DBox(dataset_config)] 46 | raw_results, messages = evaluator.evaluate(dataset_list, metrics_list) 47 | 48 | results = {seq: raw_results['MotChallenge2DBox'][tracker][seq]['pedestrian'] for seq in 49 | raw_results['MotChallenge2DBox'][tracker].keys()} 50 | current_metrics_list = metrics_list + [trackeval.metrics.Count()] 51 | metric_names = trackeval.utils.validate_metrics_list(current_metrics_list) 52 | 53 | # Load expected results: 54 | test_data = trackeval.utils.load_detail(os.path.join(test_data_loc, tracker, 'pedestrian_detailed.csv')) 55 | assert len(test_data.keys()) == 22, len(test_data.keys()) 56 | 57 | # Do checks 58 | for seq in test_data.keys(): 59 | assert len(test_data[seq].keys()) > 250, len(test_data[seq].keys()) 60 | 61 | details = [] 62 | for metric, metric_name in zip(current_metrics_list, metric_names): 63 | table_res = {seq_key: seq_value[metric_name] for seq_key, seq_value in results.items()} 64 | details.append(metric.detailed_results(table_res)) 65 | res_fields = sum([list(s['COMBINED_SEQ'].keys()) for s in details], []) 66 | res_values = sum([list(s[seq].values()) for s in details], []) 67 | res_dict = dict(zip(res_fields, res_values)) 68 | 69 | for field in test_data[seq].keys(): 70 | if not np.isclose(res_dict[field], test_data[seq][field]): 71 | print(tracker, seq, res_dict[field], test_data[seq][field], field) 72 | raise AssertionError 73 | 74 | print('Tracker %s tests passed' % tracker) 75 | print('All tests passed') 76 | 77 | -------------------------------------------------------------------------------- /TrackEval/tests/test_mots.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np 4 | from multiprocessing import freeze_support 5 | 6 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 7 | import trackeval # noqa: E402 8 | 9 | # Fixes multiprocessing on windows, does nothing otherwise 10 | if __name__ == '__main__': 11 | freeze_support() 12 | 13 | eval_config = {'USE_PARALLEL': False, 14 | 'NUM_PARALLEL_CORES': 8, 15 | } 16 | evaluator = trackeval.Evaluator(eval_config) 17 | metrics_list = [trackeval.metrics.HOTA(), trackeval.metrics.CLEAR(), trackeval.metrics.Identity()] 18 | 19 | tests = [ 20 | {'DATASET': 'KittiMOTS', 'SPLIT_TO_EVAL': 'val', 'TRACKERS_TO_EVAL': ['trackrcnn']}, 21 | {'DATASET': 
'MOTSChallenge', 'SPLIT_TO_EVAL': 'train', 'TRACKERS_TO_EVAL': ['TrackRCNN']} 22 | ] 23 | 24 | for dataset_config in tests: 25 | 26 | dataset_name = dataset_config.pop('DATASET') 27 | if dataset_name == 'MOTSChallenge': 28 | dataset_list = [trackeval.datasets.MOTSChallenge(dataset_config)] 29 | file_loc = os.path.join('mot_challenge', 'MOTS-' + dataset_config['SPLIT_TO_EVAL']) 30 | elif dataset_name == 'KittiMOTS': 31 | dataset_list = [trackeval.datasets.KittiMOTS(dataset_config)] 32 | file_loc = os.path.join('kitti', 'kitti_mots_val') 33 | else: 34 | raise Exception('Dataset %s does not exist.' % dataset_name) 35 | 36 | raw_results, messages = evaluator.evaluate(dataset_list, metrics_list) 37 | 38 | classes = dataset_list[0].config['CLASSES_TO_EVAL'] 39 | tracker = dataset_config['TRACKERS_TO_EVAL'][0] 40 | test_data_loc = os.path.join(os.path.dirname(__file__), '..', 'data', 'tests', file_loc) 41 | 42 | for cls in classes: 43 | results = {seq: raw_results[dataset_name][tracker][seq][cls] for seq in raw_results[dataset_name][tracker].keys()} 44 | current_metrics_list = metrics_list + [trackeval.metrics.Count()] 45 | metric_names = trackeval.utils.validate_metrics_list(current_metrics_list) 46 | 47 | # Load expected results: 48 | test_data = trackeval.utils.load_detail(os.path.join(test_data_loc, tracker, cls + '_detailed.csv')) 49 | 50 | # Do checks 51 | for seq in test_data.keys(): 52 | assert len(test_data[seq].keys()) > 250, len(test_data[seq].keys()) 53 | 54 | details = [] 55 | for metric, metric_name in zip(current_metrics_list, metric_names): 56 | table_res = {seq_key: seq_value[metric_name] for seq_key, seq_value in results.items()} 57 | details.append(metric.detailed_results(table_res)) 58 | res_fields = sum([list(s['COMBINED_SEQ'].keys()) for s in details], []) 59 | res_values = sum([list(s[seq].values()) for s in details], []) 60 | res_dict = dict(zip(res_fields, res_values)) 61 | 62 | for field in test_data[seq].keys(): 63 | assert np.isclose(res_dict[field], test_data[seq][field]), seq + ': ' + cls + ': ' + field 64 | 65 | print('Tracker %s tests passed' % tracker) 66 | print('All tests passed') -------------------------------------------------------------------------------- /TrackEval/trackeval/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval import Evaluator 2 | from . import datasets 3 | from . import metrics 4 | from . import plotting 5 | from . import utils 6 | -------------------------------------------------------------------------------- /TrackEval/trackeval/_timing.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from time import perf_counter 3 | import inspect 4 | 5 | DO_TIMING = False 6 | DISPLAY_LESS_PROGRESS = False 7 | timer_dict = {} 8 | counter = 0 9 | 10 | 11 | def time(f): 12 | @wraps(f) 13 | def wrap(*args, **kw): 14 | if DO_TIMING: 15 | # Run function with timing 16 | ts = perf_counter() 17 | result = f(*args, **kw) 18 | te = perf_counter() 19 | tt = te-ts 20 | 21 | # Get function name 22 | arg_names = inspect.getfullargspec(f)[0] 23 | if arg_names[0] == 'self' and DISPLAY_LESS_PROGRESS: 24 | return result 25 | elif arg_names[0] == 'self': 26 | method_name = type(args[0]).__name__ + '.' 
+ f.__name__ 27 | else: 28 | method_name = f.__name__ 29 | 30 | # Record cumulative time spent in each function for analysis 31 | if method_name in timer_dict.keys(): 32 | timer_dict[method_name] += tt 33 | else: 34 | timer_dict[method_name] = tt 35 | 36 | # If code is finished, display timing summary 37 | if method_name == "Evaluator.evaluate": 38 | print("") 39 | print("Timing analysis:") 40 | for key, value in timer_dict.items(): 41 | print('%-70s %2.4f sec' % (key, value)) 42 | else: 43 | # Get function argument values for printing special arguments of interest 44 | arg_titles = ['tracker', 'seq', 'cls'] 45 | arg_vals = [] 46 | for i, a in enumerate(arg_names): 47 | if a in arg_titles: 48 | arg_vals.append(args[i]) 49 | arg_text = '(' + ', '.join(arg_vals) + ')' 50 | 51 | # Display methods and functions with different indentation. 52 | if arg_names[0] == 'self': 53 | print('%-74s %2.4f sec' % (' '*4 + method_name + arg_text, tt)) 54 | elif arg_names[0] == 'test': 55 | pass 56 | else: 57 | global counter 58 | counter += 1 59 | print('%i %-70s %2.4f sec' % (counter, method_name + arg_text, tt)) 60 | 61 | return result 62 | else: 63 | # If config["TIME_PROGRESS"] is false, or config["USE_PARALLEL"] is true, run functions normally without timing. 64 | return f(*args, **kw) 65 | return wrap 66 | -------------------------------------------------------------------------------- /TrackEval/trackeval/baselines/__init__.py: -------------------------------------------------------------------------------- 1 | from . import baseline_utils 2 | from . import stp 3 | from . import non_overlap 4 | from . import pascal_colormap 5 | from . import thresholder 6 | from . import vizualize -------------------------------------------------------------------------------- /TrackEval/trackeval/baselines/non_overlap.py: -------------------------------------------------------------------------------- 1 | """ 2 | Non-Overlap: Code to take in a set of raw detections and produce a set of non-overlapping detections from it. 3 | 4 | Author: Jonathon Luiten 5 | """ 6 | 7 | import os 8 | import sys 9 | from multiprocessing.pool import Pool 10 | from multiprocessing import freeze_support 11 | 12 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) 13 | from trackeval.baselines import baseline_utils as butils 14 | from trackeval.utils import get_code_path 15 | 16 | code_path = get_code_path() 17 | config = { 18 | 'INPUT_FOL': os.path.join(code_path, 'data/detections/rob_mots/{split}/raw_supplied/data/'), 19 | 'OUTPUT_FOL': os.path.join(code_path, 'data/detections/rob_mots/{split}/non_overlap_supplied/data/'), 20 | 'SPLIT': 'train', # valid: 'train', 'val', 'test'. 21 | 'Benchmarks': None, # If None, all benchmarks in SPLIT. 22 | 23 | 'Num_Parallel_Cores': None, # If None, run without parallel. 24 | 25 | 'THRESHOLD_NMS_MASK_IOU': 0.5, 26 | } 27 | 28 | 29 | def do_sequence(seq_file): 30 | 31 | # Load input data from file (e.g. provided detections) 32 | # data format: data['cls'][t] = {'ids', 'scores', 'im_hs', 'im_ws', 'mask_rles'} 33 | data = butils.load_seq(seq_file) 34 | 35 | # Converts data from a class-separated to a class-combined format. 36 | # data[t] = {'ids', 'scores', 'im_hs', 'im_ws', 'mask_rles', 'cls'} 37 | data = butils.combine_classes(data) 38 | 39 | # Where to accumulate output data for writing out 40 | output_data = [] 41 | 42 | # Run for each timestep. 
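# NOTE (hedged summary; the exact behavior lives in baseline_utils.py): mask_NMS sorts the masks by descending score and drops any mask whose IoU with an already-kept mask exceeds THRESHOLD_NMS_MASK_IOU (0.5 here); non_overlap(..., already_sorted=True) then resolves the remaining partial overlaps so that each pixel belongs to at most one mask, giving higher-scoring masks priority. # For example (hypothetical numbers): two masks with IoU 0.7 and scores 0.9/0.6 -> the 0.6 mask is suppressed by NMS; with IoU 0.3 both survive, and any shared pixels go to the 0.9 mask.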
43 | for timestep, t_data in enumerate(data): 44 | 45 | # Remove redundant masks by performing non-maximum suppression (NMS) 46 | t_data = butils.mask_NMS(t_data, nms_threshold=config['THRESHOLD_NMS_MASK_IOU']) 47 | 48 | # Perform non-overlap, to get non_overlapping masks. 49 | t_data = butils.non_overlap(t_data, already_sorted=True) 50 | 51 | # Save result in output format to write to file later. 52 | # Output Format = [timestep ID class score im_h im_w mask_RLE] 53 | for i in range(len(t_data['ids'])): 54 | row = [timestep, int(t_data['ids'][i]), t_data['cls'][i], t_data['scores'][i], t_data['im_hs'][i], 55 | t_data['im_ws'][i], t_data['mask_rles'][i]] 56 | output_data.append(row) 57 | 58 | # Write results to file 59 | out_file = seq_file.replace(config['INPUT_FOL'].format(split=config['SPLIT']), 60 | config['OUTPUT_FOL'].format(split=config['SPLIT'])) 61 | butils.write_seq(output_data, out_file) 62 | 63 | print('DONE:', seq_file) 64 | 65 | 66 | if __name__ == '__main__': 67 | 68 | # Required to fix bug in multiprocessing on Windows. 69 | freeze_support() 70 | 71 | # Obtain list of sequences to run tracker for. 72 | if config['Benchmarks']: 73 | benchmarks = config['Benchmarks'] 74 | else: 75 | benchmarks = ['davis_unsupervised', 'kitti_mots', 'youtube_vis', 'ovis', 'bdd_mots', 'tao'] 76 | if config['SPLIT'] != 'train': 77 | benchmarks += ['waymo', 'mots_challenge'] 78 | seqs_todo = [] 79 | for bench in benchmarks: 80 | bench_fol = os.path.join(config['INPUT_FOL'].format(split=config['SPLIT']), bench) 81 | seqs_todo += [os.path.join(bench_fol, seq) for seq in os.listdir(bench_fol)] 82 | 83 | # Run in parallel 84 | if config['Num_Parallel_Cores']: 85 | with Pool(config['Num_Parallel_Cores']) as pool: 86 | results = pool.map(do_sequence, seqs_todo) 87 | 88 | # Run in series 89 | else: 90 | for seq_todo in seqs_todo: 91 | do_sequence(seq_todo) 92 | 93 | -------------------------------------------------------------------------------- /TrackEval/trackeval/baselines/thresholder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Thresholder 3 | 4 | Author: Jonathon Luiten 5 | 6 | Simply reads in a set of detections, thresholds them at a certain score threshold, and writes them out again. 7 | """ 8 | 9 | import os 10 | import sys 11 | from multiprocessing.pool import Pool 12 | from multiprocessing import freeze_support 13 | 14 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) 15 | from trackeval.baselines import baseline_utils as butils 16 | from trackeval.utils import get_code_path 17 | 18 | THRESHOLD = 0.2 19 | 20 | code_path = get_code_path() 21 | config = { 22 | 'INPUT_FOL': os.path.join(code_path, 'data/detections/rob_mots/{split}/non_overlap_supplied/data/'), 23 | 'OUTPUT_FOL': os.path.join(code_path, 'data/detections/rob_mots/{split}/threshold_' + str(100*THRESHOLD) + '/data/'), 24 | 'SPLIT': 'train', # valid: 'train', 'val', 'test'. 25 | 'Benchmarks': None, # If None, all benchmarks in SPLIT. 26 | 27 | 'Num_Parallel_Cores': None, # If None, run without parallel. 28 | 29 | 'DETECTION_THRESHOLD': THRESHOLD, 30 | } 31 | 32 | 33 | def do_sequence(seq_file): 34 | 35 | # Load input data from file (e.g. provided detections) 36 | # data format: data['cls'][t] = {'ids', 'scores', 'im_hs', 'im_ws', 'mask_rles'} 37 | data = butils.load_seq(seq_file) 38 | 39 | # Where to accumulate output data for writing out 40 | output_data = [] 41 | 42 | # Run for each class. 
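# NOTE (hedged summary): data is keyed by class id, and data[cls][t] holds one timestep's detections; butils.threshold presumably keeps only the entries whose score is >= DETECTION_THRESHOLD (0.2 here). # For example (hypothetical numbers): t_data = {'ids': [3, 7], 'scores': [0.9, 0.1], ...} -> after thresholding, only the id-3 entry remains.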
43 | for cls, cls_data in data.items(): 44 | 45 | # Run for each timestep. 46 | for timestep, t_data in enumerate(cls_data): 47 | 48 | # Threshold detections. 49 | t_data = butils.threshold(t_data, config['DETECTION_THRESHOLD']) 50 | 51 | # Save result in output format to write to file later. 52 | # Output Format = [timestep ID class score im_h im_w mask_RLE] 53 | for i in range(len(t_data['ids'])): 54 | row = [timestep, int(t_data['ids'][i]), cls, t_data['scores'][i], t_data['im_hs'][i], 55 | t_data['im_ws'][i], t_data['mask_rles'][i]] 56 | output_data.append(row) 57 | 58 | # Write results to file 59 | out_file = seq_file.replace(config['INPUT_FOL'].format(split=config['SPLIT']), 60 | config['OUTPUT_FOL'].format(split=config['SPLIT'])) 61 | butils.write_seq(output_data, out_file) 62 | 63 | print('DONE:', seq_file) 64 | 65 | 66 | if __name__ == '__main__': 67 | 68 | # Required to fix bug in multiprocessing on Windows. 69 | freeze_support() 70 | 71 | # Obtain list of sequences to run tracker for. 72 | if config['Benchmarks']: 73 | benchmarks = config['Benchmarks'] 74 | else: 75 | benchmarks = ['davis_unsupervised', 'kitti_mots', 'youtube_vis', 'ovis', 'bdd_mots', 'tao'] 76 | if config['SPLIT'] != 'train': 77 | benchmarks += ['waymo', 'mots_challenge'] 78 | seqs_todo = [] 79 | for bench in benchmarks: 80 | bench_fol = os.path.join(config['INPUT_FOL'].format(split=config['SPLIT']), bench) 81 | seqs_todo += [os.path.join(bench_fol, seq) for seq in os.listdir(bench_fol)] 82 | 83 | # Run in parallel 84 | if config['Num_Parallel_Cores']: 85 | with Pool(config['Num_Parallel_Cores']) as pool: 86 | results = pool.map(do_sequence, seqs_todo) 87 | 88 | # Run in series 89 | else: 90 | for seq_todo in seqs_todo: 91 | do_sequence(seq_todo) 92 | 93 | -------------------------------------------------------------------------------- /TrackEval/trackeval/baselines/vizualize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vizualize: Code which converts .txt rle tracking results into a visual .png format. 3 | 4 | Author: Jonathon Luiten 5 | """ 6 | 7 | import os 8 | import sys 9 | from multiprocessing.pool import Pool 10 | from multiprocessing import freeze_support 11 | 12 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) 13 | from trackeval.baselines import baseline_utils as butils 14 | from trackeval.utils import get_code_path 15 | from trackeval.datasets.rob_mots_classmap import cls_id_to_name 16 | 17 | code_path = get_code_path() 18 | config = { 19 | # Tracker format: 20 | 'INPUT_FOL': os.path.join(code_path, 'data/trackers/rob_mots/{split}/STP/data/{bench}'), 21 | 'OUTPUT_FOL': os.path.join(code_path, 'data/viz/rob_mots/{split}/STP/data/{bench}'), 22 | # GT format: 23 | # 'INPUT_FOL': os.path.join(code_path, 'data/gt/rob_mots/{split}/{bench}/data/'), 24 | # 'OUTPUT_FOL': os.path.join(code_path, 'data/gt_viz/rob_mots/{split}/{bench}/'), 25 | 'SPLIT': 'train', # valid: 'train', 'val', 'test'. 26 | 'Benchmarks': None, # If None, all benchmarks in SPLIT. 27 | 'Num_Parallel_Cores': None, # If None, run without parallel. 28 | } 29 | 30 | 31 | def do_sequence(seq_file): 32 | # Folder to save resulting visualization in (the benchmark name is recovered from the sequence path so that parallel workers do not rely on a global) 33 | bench = os.path.basename(os.path.dirname(seq_file)) 34 | out_fol = seq_file.replace(config['INPUT_FOL'].format(split=config['SPLIT'], bench=bench), config['OUTPUT_FOL'].format(split=config['SPLIT'], bench=bench)).replace('.txt', '') 35 | 36 | # Load input data from file (e.g. 
provided detections) 37 | # data format: data['cls'][t] = {'ids', 'scores', 'im_hs', 'im_ws', 'mask_rles'} 38 | data = butils.load_seq(seq_file) 39 | 40 | # Get frame size for visualizing empty frames 41 | im_h, im_w = butils.get_frame_size(data) 42 | 43 | # First run for each class. 44 | for cls, cls_data in data.items(): 45 | 46 | if cls >= 100: 47 | continue 48 | 49 | # Run for each timestep. 50 | for timestep, t_data in enumerate(cls_data): 51 | # Save out visualization 52 | out_file = os.path.join(out_fol, cls_id_to_name[cls], str(timestep).zfill(5) + '.png') 53 | butils.save_as_png(t_data, out_file, im_h, im_w) 54 | 55 | 56 | # Then run for all classes combined 57 | # Converts data from a class-separated to a class-combined format. 58 | data = butils.combine_classes(data) 59 | 60 | # Run for each timestep. 61 | for timestep, t_data in enumerate(data): 62 | # Save out visualization 63 | out_file = os.path.join(out_fol, 'all_classes', str(timestep).zfill(5) + '.png') 64 | butils.save_as_png(t_data, out_file, im_h, im_w) 65 | 66 | print('DONE:', seq_file) 67 | 68 | 69 | if __name__ == '__main__': 70 | 71 | # Required to fix bug in multiprocessing on windows. 72 | freeze_support() 73 | 74 | # Obtain list of sequences to run tracker for. 75 | if config['Benchmarks']: 76 | benchmarks = config['Benchmarks'] 77 | else: 78 | benchmarks = ['davis_unsupervised', 'kitti_mots', 'youtube_vis', 'ovis', 'bdd_mots', 'tao'] 79 | if config['SPLIT'] != 'train': 80 | benchmarks += ['waymo', 'mots_challenge'] 81 | seqs_todo = [] 82 | for bench in benchmarks: 83 | bench_fol = config['INPUT_FOL'].format(split=config['SPLIT'], bench=bench) 84 | seqs_todo += [os.path.join(bench_fol, seq) for seq in os.listdir(bench_fol)] 85 | 86 | # Run in parallel 87 | if config['Num_Parallel_Cores']: 88 | with Pool(config['Num_Parallel_Cores']) as pool: 89 | results = pool.map(do_sequence, seqs_todo) 90 | 91 | # Run in series 92 | else: 93 | for seq_todo in seqs_todo: 94 | do_sequence(seq_todo) 95 | -------------------------------------------------------------------------------- /TrackEval/trackeval/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .kitti_2d_box import Kitti2DBox 2 | from .kitti_mots import KittiMOTS 3 | from .mot_challenge_2d_box import MotChallenge2DBox 4 | from .mots_challenge import MOTSChallenge 5 | from .bdd100k import BDD100K 6 | from .davis import DAVIS 7 | from .tao import TAO 8 | from .tao_ow import TAO_OW 9 | try: 10 | from .burst import BURST 11 | from .burst_ow import BURST_OW 12 | except ImportError as err: 13 | print(f"Error importing BURST due to missing underlying dependency: {err}") 14 | from .youtube_vis import YouTubeVIS 15 | from .head_tracking_challenge import HeadTrackingChallenge 16 | from .rob_mots import RobMOTS 17 | from .person_path_22 import PersonPath22 18 | -------------------------------------------------------------------------------- /TrackEval/trackeval/datasets/burst.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .burst_helpers.burst_base import BURSTBase 3 | from .burst_helpers.format_converter import GroundTruthBURSTFormatToTAOFormatConverter, PredictionBURSTFormatToTAOFormatConverter 4 | from .. 
import utils 5 | 6 | 7 | class BURST(BURSTBase): 8 | """Dataset class for TAO tracking""" 9 | 10 | @staticmethod 11 | def get_default_dataset_config(): 12 | tao_config = BURSTBase.get_default_dataset_config() 13 | code_path = utils.get_code_path() 14 | 15 | # e.g. 'data/gt/tsunami/exemplar_guided/' 16 | tao_config['GT_FOLDER'] = os.path.join( 17 | code_path, 'data/gt/burst/val/') # Location of GT data 18 | # e.g. 'data/trackers/tsunami/exemplar_guided/mask_guided/validation/' 19 | tao_config['TRACKERS_FOLDER'] = os.path.join( 20 | code_path, 'data/trackers/burst/class-guided/') # Trackers location 21 | # set to True or False 22 | tao_config['EXEMPLAR_GUIDED'] = False 23 | return tao_config 24 | 25 | def _iou_type(self): 26 | return 'mask' 27 | 28 | def _box_or_mask_from_det(self, det): 29 | return det['segmentation'] 30 | 31 | def _calculate_area_for_ann(self, ann): 32 | import pycocotools.mask as cocomask 33 | return cocomask.area(ann["segmentation"]) 34 | 35 | def _calculate_similarities(self, gt_dets_t, tracker_dets_t): 36 | similarity_scores = self._calculate_mask_ious(gt_dets_t, tracker_dets_t, is_encoded=True, do_ioa=False) 37 | return similarity_scores 38 | 39 | def _is_exemplar_guided(self): 40 | exemplar_guided = self.config['EXEMPLAR_GUIDED'] 41 | return exemplar_guided 42 | 43 | def _postproc_ground_truth_data(self, data): 44 | return GroundTruthBURSTFormatToTAOFormatConverter(data).convert() 45 | 46 | def _postproc_prediction_data(self, data): 47 | return PredictionBURSTFormatToTAOFormatConverter( 48 | self.gt_data, data, 49 | exemplar_guided=self._is_exemplar_guided()).convert() 50 | -------------------------------------------------------------------------------- /TrackEval/trackeval/datasets/burst_helpers/BURST_SPECIFIC_ISSUES.md: -------------------------------------------------------------------------------- 1 | The track ids in both ground truth and predictions are not globally unique, but 2 | start from 1 for each video. At the moment when converting from Ali format to 3 | TAO format, we remap the ids to be globally unique. It would be better to 4 | directly have this in the data though. 5 | 6 | 7 | Improve setting of EXEMPLAR_GUIDED flag, maybe this can be done automatically. 
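For illustration, the remapping could look roughly like the sketch below. This is a minimal, hypothetical helper (names and structure assumed, not the converter's actual code): it assigns a fresh globally unique id to each (video id, per-video track id) pair.

    # Hypothetical sketch of making per-video track ids globally unique.
    def remap_track_ids(annotations):
        mapping = {}
        next_free_id = 1
        for ann in annotations:  # each ann carries 'video_id' and a per-video 'track_id'
            key = (ann['video_id'], ann['track_id'])
            if key not in mapping:
                mapping[key] = next_free_id
                next_free_id += 1
            ann['track_id'] = mapping[key]
        return annotations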
8 | -------------------------------------------------------------------------------- /TrackEval/trackeval/datasets/burst_helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kemo-Huang/BiTrack/055cd6c1252adae1be4bcb9016a15d08031723e0/TrackEval/trackeval/datasets/burst_helpers/__init__.py -------------------------------------------------------------------------------- /TrackEval/trackeval/datasets/burst_helpers/convert_burst_format_to_tao_format.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from .format_converter import GroundTruthBURSTFormatToTAOFormatConverter, PredictionBURSTFormatToTAOFormatConverter 4 | 5 | 6 | def main(args): 7 | with open(args.gt_input_file) as f: 8 | ali_format_gt = json.load(f) 9 | tao_format_gt = GroundTruthBURSTFormatToTAOFormatConverter( 10 | ali_format_gt, args.split).convert() 11 | with open(args.gt_output_file, 'w') as f: 12 | json.dump(tao_format_gt, f) 13 | 14 | if args.pred_input_file is None: 15 | return 16 | with open(args.pred_input_file) as f: 17 | ali_format_pred = json.load(f) 18 | tao_format_pred = PredictionBURSTFormatToTAOFormatConverter( 19 | tao_format_gt, ali_format_pred, args.split, 20 | args.exemplar_guided).convert() 21 | with open(args.pred_output_file, 'w') as f: 22 | json.dump(tao_format_pred, f) 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument( 28 | '--gt_input_file', type=str, 29 | default='../data/gt/tsunami/exemplar_guided/validation_all_annotations.json') 30 | parser.add_argument('--gt_output_file', type=str, 31 | default='/tmp/val_gt.json') 32 | parser.add_argument('--pred_input_file', type=str, 33 | default='../data/trackers/tsunami/exemplar_guided/STCN_off_the_shelf/data/results.json') 34 | parser.add_argument('--pred_output_file', type=str, 35 | default='/tmp/pred.json') 36 | parser.add_argument('--split', type=str, default='validation') 37 | parser.add_argument('--exemplar_guided', type=lambda x: x.lower() == 'true', default=True) # parse 'True'/'False' strings correctly; type=bool would treat any non-empty string as True 38 | args_ = parser.parse_args() 39 | main(args_) 40 | -------------------------------------------------------------------------------- /TrackEval/trackeval/datasets/burst_ow.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from .burst_helpers.burst_ow_base import BURST_OW_Base 4 | from .burst_helpers.format_converter import GroundTruthBURSTFormatToTAOFormatConverter, PredictionBURSTFormatToTAOFormatConverter 5 | from .. 
import utils 6 | 7 | 8 | class BURST_OW(BURST_OW_Base): 9 | """Dataset class for TAO tracking""" 10 | 11 | @staticmethod 12 | def get_default_dataset_config(): 13 | tao_config = BURST_OW_Base.get_default_dataset_config() 14 | code_path = utils.get_code_path() 15 | tao_config['GT_FOLDER'] = os.path.join( 16 | code_path, 'data/gt/burst/all_classes/val/') # Location of GT data 17 | tao_config['TRACKERS_FOLDER'] = os.path.join( 18 | code_path, 'data/trackers/burst/open-world/val/') # Trackers location 19 | return tao_config 20 | 21 | def _iou_type(self): 22 | return 'mask' 23 | 24 | def _box_or_mask_from_det(self, det): 25 | if "segmentation" in det: 26 | return det["segmentation"] 27 | else: 28 | return det["mask"] 29 | 30 | def _calculate_area_for_ann(self, ann): 31 | import pycocotools.mask as cocomask 32 | seg = self._box_or_mask_from_det(ann) 33 | return cocomask.area(seg) 34 | 35 | def _calculate_similarities(self, gt_dets_t, tracker_dets_t): 36 | similarity_scores = self._calculate_mask_ious(gt_dets_t, tracker_dets_t, is_encoded=True, do_ioa=False) 37 | return similarity_scores 38 | 39 | def _postproc_ground_truth_data(self, data): 40 | return GroundTruthBURSTFormatToTAOFormatConverter(data).convert() 41 | 42 | def _postproc_prediction_data(self, data): 43 | # if it's a list, it's already in TAO format and not in Ali format 44 | # however the image ids do not match and need to be remapped 45 | if isinstance(data, list): 46 | _remap_image_ids(data, self.gt_data) 47 | return data 48 | 49 | return PredictionBURSTFormatToTAOFormatConverter( 50 | self.gt_data, data, 51 | exemplar_guided=False).convert() 52 | 53 | 54 | def _remap_image_ids(pred_data, ali_gt_data): 55 | code_path = utils.get_code_path() 56 | if 'split' in ali_gt_data: 57 | split = ali_gt_data['split'] 58 | else: 59 | split = 'val' 60 | 61 | if split in ('val', 'validation'): 62 | tao_gt_path = os.path.join( 63 | code_path, 'data/gt/tao/tao_validation/gt.json') 64 | else: 65 | tao_gt_path = os.path.join( 66 | code_path, 'data/gt/tao/tao_test/test_without_annotations.json') 67 | 68 | with open(tao_gt_path) as f: 69 | tao_gt = json.load(f) 70 | 71 | tao_img_by_id = {} 72 | for img in tao_gt['images']: 73 | img_id = img['id'] 74 | tao_img_by_id[img_id] = img 75 | 76 | ali_img_id_by_filename = {} 77 | for ali_img in ali_gt_data['images']: 78 | ali_img_id = ali_img['id'] 79 | file_name = ali_img['file_name'].replace("validation", "val") 80 | ali_img_id_by_filename[file_name] = ali_img_id 81 | 82 | ali_img_id_by_tao_img_id = {} 83 | for tao_img_id, tao_img in tao_img_by_id.items(): 84 | file_name = tao_img['file_name'] 85 | ali_img_id = ali_img_id_by_filename[file_name] 86 | ali_img_id_by_tao_img_id[tao_img_id] = ali_img_id 87 | 88 | for det in pred_data: 89 | tao_img_id = det['image_id'] 90 | ali_img_id = ali_img_id_by_tao_img_id[tao_img_id] 91 | det['image_id'] = ali_img_id 92 | -------------------------------------------------------------------------------- /TrackEval/trackeval/datasets/rob_mots_classmap.py: -------------------------------------------------------------------------------- 1 | cls_id_to_name = { 2 | 1: 'person', 3 | 2: 'bicycle', 4 | 3: 'car', 5 | 4: 'motorcycle', 6 | 5: 'airplane', 7 | 6: 'bus', 8 | 7: 'train', 9 | 8: 'truck', 10 | 9: 'boat', 11 | 10: 'traffic light', 12 | 11: 'fire hydrant', 13 | 12: 'stop sign', 14 | 13: 'parking meter', 15 | 14: 'bench', 16 | 15: 'bird', 17 | 16: 'cat', 18 | 17: 'dog', 19 | 18: 'horse', 20 | 19: 'sheep', 21 | 20: 'cow', 22 | 21: 'elephant', 23 | 22: 'bear', 24 | 23: 'zebra', 25 
| 24: 'giraffe', 26 | 25: 'backpack', 27 | 26: 'umbrella', 28 | 27: 'handbag', 29 | 28: 'tie', 30 | 29: 'suitcase', 31 | 30: 'frisbee', 32 | 31: 'skis', 33 | 32: 'snowboard', 34 | 33: 'sports ball', 35 | 34: 'kite', 36 | 35: 'baseball bat', 37 | 36: 'baseball glove', 38 | 37: 'skateboard', 39 | 38: 'surfboard', 40 | 39: 'tennis racket', 41 | 40: 'bottle', 42 | 41: 'wine glass', 43 | 42: 'cup', 44 | 43: 'fork', 45 | 44: 'knife', 46 | 45: 'spoon', 47 | 46: 'bowl', 48 | 47: 'banana', 49 | 48: 'apple', 50 | 49: 'sandwich', 51 | 50: 'orange', 52 | 51: 'broccoli', 53 | 52: 'carrot', 54 | 53: 'hot dog', 55 | 54: 'pizza', 56 | 55: 'donut', 57 | 56: 'cake', 58 | 57: 'chair', 59 | 58: 'couch', 60 | 59: 'potted plant', 61 | 60: 'bed', 62 | 61: 'dining table', 63 | 62: 'toilet', 64 | 63: 'tv', 65 | 64: 'laptop', 66 | 65: 'mouse', 67 | 66: 'remote', 68 | 67: 'keyboard', 69 | 68: 'cell phone', 70 | 69: 'microwave', 71 | 70: 'oven', 72 | 71: 'toaster', 73 | 72: 'sink', 74 | 73: 'refrigerator', 75 | 74: 'book', 76 | 75: 'clock', 77 | 76: 'vase', 78 | 77: 'scissors', 79 | 78: 'teddy bear', 80 | 79: 'hair drier', 81 | 80: 'toothbrush'} -------------------------------------------------------------------------------- /TrackEval/trackeval/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .hota import HOTA 2 | from .clear import CLEAR 3 | from .identity import Identity 4 | from .count import Count 5 | from .j_and_f import JAndF 6 | from .track_map import TrackMAP 7 | from .vace import VACE 8 | from .ideucl import IDEucl -------------------------------------------------------------------------------- /TrackEval/trackeval/metrics/count.py: -------------------------------------------------------------------------------- 1 | 2 | from ._base_metric import _BaseMetric 3 | from .. 
import _timing 4 | 5 | 6 | class Count(_BaseMetric): 7 | """Class which simply counts the number of tracker and gt detections and ids.""" 8 | def __init__(self, config=None): 9 | super().__init__() 10 | self.integer_fields = ['Dets', 'GT_Dets', 'IDs', 'GT_IDs'] 11 | self.fields = self.integer_fields 12 | self.summary_fields = self.fields 13 | 14 | @_timing.time 15 | def eval_sequence(self, data): 16 | """Returns counts for one sequence""" 17 | # Get results 18 | res = {'Dets': data['num_tracker_dets'], 19 | 'GT_Dets': data['num_gt_dets'], 20 | 'IDs': data['num_tracker_ids'], 21 | 'GT_IDs': data['num_gt_ids'], 22 | 'Frames': data['num_timesteps']} 23 | return res 24 | 25 | def combine_sequences(self, all_res): 26 | """Combines metrics across all sequences""" 27 | res = {} 28 | for field in self.integer_fields: 29 | res[field] = self._combine_sum(all_res, field) 30 | return res 31 | 32 | def combine_classes_class_averaged(self, all_res, ignore_empty_classes=None): 33 | """Combines counts across all classes by summing them (counts are additive, so no averaging is performed)""" 34 | res = {} 35 | for field in self.integer_fields: 36 | res[field] = self._combine_sum(all_res, field) 37 | return res 38 | 39 | def combine_classes_det_averaged(self, all_res): 40 | """Combines counts across all classes by summing them (counts are additive, so no averaging is performed)""" 41 | res = {} 42 | for field in self.integer_fields: 43 | res[field] = self._combine_sum(all_res, field) 44 | return res 45 | -------------------------------------------------------------------------------- /configs/nusc.ini: -------------------------------------------------------------------------------- 1 | [data] 2 | root_dir = data/nuscenes 3 | 4 | [tracking] 5 | ; sim_metric = IoU 6 | ; sim_metric = CD 7 | sim_metric = NCD 8 | ang_vel = False 9 | vel_reinit = True 10 | t_miss = 5 11 | t_miss_new = 2 12 | t_hit = 0 13 | offline = True 14 | ; t_hit = 6 15 | ; offline = False 16 | 17 | match_algorithm = HA 18 | ; for IoU 19 | ; dis_thresh = 0.01 20 | ; for CD 21 | ; dis_thresh = -4.5 22 | ; for NCD 23 | dis_thresh = 0.5 24 | app_thresh = 0.5 25 | ang_thresh = 0 26 | ent_ex_score = 0.25 27 | app_m = 0.9 28 | p = 10 29 | q = 2 30 | 31 | [refinement] 32 | merge = True 33 | box_size_fusion = True 34 | interp = True 35 | smooth = True 36 | exponent = 45 37 | interp_max_interval = 4 38 | ignore_thresh = 0.35 39 | score_thresh = 0.15 40 | nms_thresh = 0 41 | pred_len = 0 42 | tau = 5.4 43 | ; for virconvtrack 44 | ; tau = 5.5 45 | 46 | [visualization] 47 | trajectory = False 48 | det_noise = False 49 | contradiction = False 50 | interpolation = False -------------------------------------------------------------------------------- /configs/virconv/default.ini: -------------------------------------------------------------------------------- 1 | [data] 2 | root_dir = data/kitti/tracking 3 | 4 | [detection] 5 | det3d_name = virconv 6 | raw_score = True 7 | det2d_name = spatial_embeddings 8 | det2d_emb_name = BoT 9 | seg_name = spatial_embeddings 10 | seg_ckpt = segmentation/spatial_embeddings/spatial_embeddings.pth 11 | seg_emb_name = point_track 12 | seg_emb_ckpt = segmentation/point_track/point_track.pth 13 | 14 | score_thresh = 0.1 15 | use_pose = False 16 | use_lidar = True 17 | use_inst = True 18 | use_det2d = False 19 | use_embed = False 20 | min_corr_pts = 1 21 | min_corr_iou = 0.1 22 | recover_score_thresh = 0.85 23 | det3d_save_name = virconv_point_fusion 24 | 25 | [tracking] 26 | ; sim_metric = IoU 27 | ; sim_metric = CD 28 | sim_metric = NCD 29 | ang_vel = False 30 | vel_reinit = True 31 
| t_miss = 28 32 | t_miss_new = 5 33 | t_hit = 6 34 | offline = True 35 | ; t_hit = 0 36 | ; offline = False 37 | 38 | match_algorithm = HA 39 | ; for IoU 40 | ; dis_thresh = 0.01 41 | ; for CD 42 | ; dis_thresh = -4.5 43 | ; for NCD 44 | dis_thresh = 0.5 45 | app_thresh = 0.5 46 | ang_thresh = 0 47 | ent_ex_score = 0.25 48 | app_m = 0.9 49 | p = 10 50 | q = 2 51 | 52 | [refinement] 53 | merge = True 54 | box_size_fusion = True 55 | interp = True 56 | smooth = True 57 | exponent = 45 58 | interp_max_interval = 4 59 | ignore_thresh = 0.35 60 | score_thresh = 0.15 61 | nms_thresh = 0 62 | pred_len = 0 63 | tau = 5.4 64 | ; for virconvtrack 65 | ; tau = 5.5 66 | 67 | [visualization] 68 | trajectory = False 69 | det_noise = False 70 | contradiction = False 71 | interpolation = False -------------------------------------------------------------------------------- /configs/voxel/default.ini: -------------------------------------------------------------------------------- 1 | [data] 2 | root_dir = data/kitti/tracking 3 | 4 | [detection] 5 | det3d_name = voxel_rcnn 6 | raw_score = False 7 | det3d_ckpt = detection/voxel_rcnn/voxel_rcnn.pth 8 | batch_size = 8 9 | tta = False 10 | det2d_name = spatial_embeddings 11 | det2d_emb_name = BoT 12 | seg_name = spatial_embeddings 13 | seg_ckpt = segmentation/spatial_embeddings/spatial_embeddings.pth 14 | seg_emb_name = point_track 15 | seg_emb_ckpt = segmentation/point_track/point_track.pth 16 | 17 | score_thresh = 0.2 18 | use_pose = False 19 | use_lidar = True 20 | use_inst = True 21 | use_det2d = False 22 | use_embed = False 23 | min_corr_pts = 1 24 | min_corr_iou = 0 25 | recover_score_thresh = 0.85 26 | det3d_save_name = voxel_point_fusion 27 | 28 | [tracking] 29 | ; sim_metric = IoU 30 | ; sim_metric = CD 31 | sim_metric = NCD 32 | ang_vel = False 33 | vel_reinit = True 34 | t_miss = 28 35 | t_miss_new = 5 36 | t_hit = 6 37 | offline = True 38 | ; t_hit = 0 39 | ; offline = False 40 | 41 | match_algorithm = HA 42 | ; for IoU 43 | ; dis_thresh = 0.01 44 | ; for CD 45 | ; dis_thresh = -4.5 46 | ; for NCD 47 | dis_thresh = 0.5 48 | app_thresh = 0.5 49 | ang_thresh = 0 50 | ent_ex_score = 0.25 51 | app_m = 0.9 52 | p = 10 53 | q = 2 54 | 55 | [refinement] 56 | merge = True 57 | box_size_fusion = True 58 | interp = True 59 | smooth = True 60 | exponent = 45 61 | interp_max_interval = 4 62 | ignore_thresh = 0.35 63 | score_thresh = 0 64 | nms_thresh = 0 65 | pred_len = 0 66 | tau = 8 67 | 68 | [visualization] 69 | trajectory = False 70 | det_noise = False 71 | contradiction = False 72 | interpolation = False -------------------------------------------------------------------------------- /configs/voxel_tta/default.ini: -------------------------------------------------------------------------------- 1 | [data] 2 | root_dir = data/kitti/tracking 3 | 4 | [detection] 5 | det3d_name = voxel_rcnn_tta 6 | raw_score = False 7 | det3d_ckpt = detection/voxel_rcnn/voxel_rcnn.pth 8 | batch_size = 8 9 | tta = True 10 | det2d_name = spatial_embeddings 11 | det2d_emb_name = BoT 12 | seg_name = spatial_embeddings 13 | seg_ckpt = segmentation/spatial_embeddings/spatial_embeddings.pth 14 | seg_emb_name = point_track 15 | seg_emb_ckpt = segmentation/point_track/point_track.pth 16 | 17 | score_thresh = 0.2 18 | use_pose = False 19 | use_lidar = True 20 | use_inst = True 21 | use_det2d = False 22 | use_embed = False 23 | min_corr_pts = 1 24 | min_corr_iou = 0 25 | recover_score_thresh = 0.95 26 | det3d_save_name = voxel_tta_point_fusion 27 | 28 | [tracking] 29 | ; sim_metric = 
IoU 30 | ; sim_metric = CD 31 | sim_metric = NCD 32 | ang_vel = False 33 | vel_reinit = True 34 | t_miss = 28 35 | t_miss_new = 5 36 | t_hit = 6 37 | offline = True 38 | ; t_hit = 0 39 | ; offline = False 40 | 41 | match_algorithm = HA 42 | ; for IoU 43 | ; dis_thresh = 0.01 44 | ; for CD 45 | ; dis_thresh = -4.5 46 | ; for NCD 47 | dis_thresh = 0.5 48 | app_thresh = 0.5 49 | ang_thresh = 0 50 | ent_ex_score = 0.25 51 | app_m = 0.9 52 | p = 10 53 | q = 2 54 | 55 | [refinement] 56 | merge = True 57 | box_size_fusion = True 58 | interp = True 59 | smooth = True 60 | exponent = 45 61 | interp_max_interval = 4 62 | ignore_thresh = 0.35 63 | score_thresh = 0 64 | nms_thresh = 0 65 | pred_len = 0 66 | tau = 8 67 | 68 | [visualization] 69 | trajectory = False 70 | det_noise = False 71 | contradiction = False 72 | interpolation = False -------------------------------------------------------------------------------- /data_processing/check_det_num.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | 4 | def main(): 5 | det_path = Path("data/tracking/virconv/") 6 | n = 0 7 | max_score = float("-inf")  # running maximum over all detection scores 8 | min_score = float("inf")  # running minimum over all detection scores 9 | for seq in det_path.iterdir(): 10 | for file_path in seq.iterdir(): 11 | with open(file_path) as f: 12 | lines = f.readlines() 13 | n += len(lines) 14 | if len(lines) > 0: 15 | max_score = max( 16 | max_score, max([float(line.split()[-1]) for line in lines]) 17 | ) 18 | min_score = min( 19 | min_score, min([float(line.split()[-1]) for line in lines]) 20 | ) 21 | print(n, max_score, min_score) 22 | 23 | 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /data_processing/crop_det_images.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from PIL import Image 4 | from tqdm import tqdm 5 | 6 | 7 | def crop_images(label_dir: Path, img_dir: Path, out_dir: Path): 8 | out_dir.mkdir(parents=True, exist_ok=True) 9 | for seq in range(21): 10 | seq = str(seq).zfill(4) 11 | seq_label_dir = label_dir / seq 12 | pbar = tqdm(list(seq_label_dir.iterdir())) 13 | pbar.set_description(seq) 14 | for label_file in pbar: 15 | with open(label_file) as f: 16 | lines = f.readlines() 17 | objs = [line.split() for line in lines] 18 | 19 | frame = label_file.stem 20 | image = Image.open(img_dir / seq / f"{frame}.png") 21 | cur_out_dir = out_dir / seq / frame 22 | for idx, obj in enumerate(objs): 23 | cur_out_dir.mkdir(exist_ok=True, parents=True) 24 | box2d = list(map(int, map(float, obj[1:5]))) 25 | cropped_img = image.crop(box2d) 26 | cropped_img.save(cur_out_dir / f"{idx}.png") 27 | 28 | 29 | def main(): 30 | img_dir = Path("data/kitti/tracking/training/image_02") 31 | out_dir = Path("data/kitti/tracking/training/det2d_emb_out/yolox") 32 | label_dir = Path("data/kitti/tracking/training/det2d_out/yolox") 33 | crop_images(label_dir, img_dir, out_dir) 34 | 35 | 36 | if __name__ == "__main__": 37 | main() 38 | -------------------------------------------------------------------------------- /data_processing/crop_points.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from configparser import ConfigParser 3 | from multiprocessing import Pool 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | 8 | from utils import (Calibration, crop_points_from_boxes, 9 | get_lidar_boxes_from_objs, get_objects_from_label, 10 | 
points_inside_boxes) 11 | 12 | 13 | def crop_points( 14 | label_dir: Path, lidar_dir: Path, calib: Calibration, out_dir: Path, is_gt=False 15 | ): 16 | out_dir.mkdir(parents=True, exist_ok=True) 17 | print(label_dir.name) 18 | for label_file in label_dir.iterdir(): 19 | objs = get_objects_from_label(label_file, track=False) 20 | if is_gt: 21 | objs = [ 22 | obj for obj in objs if obj.cls_type == "Car" or obj.cls_type == "Van" 23 | ] 24 | 25 | frame = label_file.stem 26 | cur_out_dir = out_dir / frame 27 | if len(objs) > 0: 28 | cur_out_dir.mkdir(exist_ok=True) 29 | boxes = get_lidar_boxes_from_objs(objs, calib) 30 | inside_points = crop_points_from_boxes( 31 | np.fromfile(lidar_dir / f"{frame}.bin", dtype=np.float32).reshape( 32 | -1, 4 33 | ), 34 | boxes, 35 | front_only=True, 36 | ) 37 | for idx, cur_points in enumerate(inside_points): 38 | cur_points.tofile(cur_out_dir / f"{idx}.bin") 39 | 40 | 41 | def test(label_dir: Path, lidar_dir: Path, calib: Calibration, is_gt=False): 42 | for label_file in label_dir.iterdir(): 43 | objs = get_objects_from_label(label_file, track=False) 44 | if is_gt: 45 | objs = [ 46 | obj for obj in objs if obj.cls_type == "Car" or obj.cls_type == "Van" 47 | ] 48 | if len(objs) > 0: 49 | boxes = get_lidar_boxes_from_objs(objs, calib) 50 | for i in range(len(boxes)): 51 | points = np.fromfile( 52 | lidar_dir / label_file.stem / f"{i}.bin", dtype=np.float32 53 | ).reshape(-1, 4) 54 | assert np.all(points_inside_boxes(boxes[i : i + 1], points[:, :3])) 55 | 56 | 57 | def mp_func(args): 58 | label_dir, lidar_dir, calib, out_dir, is_gt = args 59 | crop_points(label_dir, lidar_dir, calib, out_dir, is_gt) 60 | 61 | 62 | def main(): 63 | parser = ArgumentParser() 64 | parser.add_argument("config", type=str) 65 | parser.add_argument("split", type=str) 66 | parser.add_argument("--gt", action="store_true") 67 | args = parser.parse_args() 68 | config = ConfigParser() 69 | config.read(args.config) 70 | root_dir = Path(config["data"]["root_dir"]) / args.split 71 | det_name = config["detection"]["det3d_name"] 72 | 73 | lidar_dir = root_dir / "velodyne" 74 | calib_dir = root_dir / "calib" 75 | out_dir = root_dir / "cropped_points" 76 | detection_dir = root_dir / "det3d_out" / det_name 77 | 78 | is_gt = args.gt 79 | 80 | with Pool(8) as p: 81 | p.map( 82 | mp_func, 83 | [ 84 | ( 85 | detection_dir / seq.name, 86 | seq, 87 | Calibration(calib_dir / f"{seq.name}.txt"), 88 | out_dir / det_name / seq.name, 89 | is_gt, 90 | ) 91 | for seq in lidar_dir.iterdir() 92 | ], 93 | ) 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /data_processing/crop_seg_images.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | from PIL import Image 5 | from tqdm import tqdm 6 | 7 | 8 | def crop_images(label_dir: Path, img_dir: Path, out_dir: Path): 9 | out_dir.mkdir(parents=True, exist_ok=True) 10 | for seq in range(21): 11 | seq = str(seq).zfill(4) 12 | seq_label_dir = label_dir / seq 13 | pbar = tqdm(list(seq_label_dir.iterdir())) 14 | pbar.set_description(seq) 15 | for label_file in pbar: 16 | frame = label_file.stem 17 | image = Image.open(img_dir / seq / f"{frame}.png") 18 | cur_out_dir = out_dir / seq / frame 19 | 20 | seg_image = np.array(Image.open(label_file)) 21 | inst_ids = np.unique(seg_image)[1:] 22 | for idx, inst_id in enumerate(inst_ids): 23 | cur_out_dir.mkdir(exist_ok=True, parents=True) 24 | row_inds, 
col_inds = np.nonzero(seg_image == inst_id) 25 | box2d = [ 26 | np.min(col_inds), 27 | np.min(row_inds), 28 | np.max(col_inds), 29 | np.max(row_inds), 30 | ] 31 | cropped_img = image.crop(box2d) 32 | cropped_img.save(cur_out_dir / f"{idx}.png") 33 | 34 | 35 | def main(): 36 | img_dir = Path("data/kitti/tracking/training/image_02") 37 | out_dir = Path("data/kitti/tracking/training/seg_emb_out/spatial_embeddings_img") 38 | label_dir = Path("data/kitti/tracking/training/seg_out/spatial_embeddings") 39 | crop_images(label_dir, img_dir, out_dir) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /data_processing/masks2boxes.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import tqdm 5 | from PIL import Image 6 | 7 | 8 | def convert_gt_instances_to_imgs(): 9 | instance_dir = Path("data/kitti/tracking/training/instances") 10 | seg_out_dir = Path("data/kitti/tracking/training/seg_out/gt") 11 | seg_out_dir.mkdir(exist_ok=True, parents=True) 12 | for seq in range(21): 13 | seq = str(seq).zfill(4) 14 | seq_dir = instance_dir / seq 15 | seq_out_dir = seg_out_dir / seq 16 | seq_out_dir.mkdir(exist_ok=True) 17 | bar = tqdm.tqdm(list(seq_dir.iterdir())) 18 | bar.set_description(seq) 19 | for img_file in bar: 20 | img = np.asarray(Image.open(img_file)) # uint16 21 | new_img = np.zeros(img.shape[:2], dtype=np.uint8) 22 | obj_ids = np.unique(img) 23 | for obj_id in obj_ids: 24 | class_id = obj_id // 1000 25 | if class_id == 1: 26 | inst_id = obj_id % 1000 + 1 27 | assert inst_id <= 255 28 | new_img[img == obj_id] = inst_id 29 | Image.fromarray(new_img).save(seq_out_dir / img_file.name) 30 | 31 | 32 | def convert_to_boxes(): 33 | seg_out_dir = Path("data/kitti/tracking/training/seg_out/spatial_embeddings") 34 | det_out_dir = Path("data/kitti/tracking/training/det2d_out/spatial_embeddings") 35 | det_out_dir.mkdir(exist_ok=True, parents=True) 36 | for seq in range(21): 37 | seq = str(seq).zfill(4) 38 | seq_dir = seg_out_dir / seq 39 | seq_out_dir = det_out_dir / seq 40 | seq_out_dir.mkdir(exist_ok=True) 41 | bar = tqdm.tqdm(list(seq_dir.iterdir())) 42 | bar.set_description(seq) 43 | for img_file in bar: 44 | img = np.asarray(Image.open(img_file)) 45 | obj_ids = np.unique(img)[1:] 46 | lines = [] 47 | for obj_id in obj_ids: 48 | row_inds, col_inds = np.nonzero(img == obj_id) 49 | box2d = [ 50 | np.min(col_inds), 51 | np.min(row_inds), 52 | np.max(col_inds), 53 | np.max(row_inds), 54 | ] 55 | lines.append(" ".join([str(x) for x in box2d]) + f" 1.0 {obj_id}\n") 56 | with open(seq_out_dir / f"{img_file.stem}.txt", "w") as f: 57 | f.writelines(lines) 58 | 59 | 60 | if __name__ == "__main__": 61 | convert_to_boxes() 62 | -------------------------------------------------------------------------------- /data_processing/object2tracking.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import sys 3 | from pathlib import Path 4 | 5 | 6 | def main(): 7 | dir_name = sys.argv[1] 8 | obj_dir = Path(f"data/kitti/detection/{dir_name}") 9 | trk_dir = Path(f"data/kitti/tracking/training/det2d_out/{dir_name}") 10 | trk_dir.mkdir(exist_ok=True, parents=True) 11 | 12 | with open("data/kitti/detection/sample2seq.txt") as f: 13 | lines = f.readlines() 14 | split_lines = [line.split() for line in lines] 15 | sample2seq = {s[0]: s[1:] for s in split_lines} 16 | 17 | for sample_file in 
obj_dir.iterdir(): 18 | seq, frame = sample2seq[sample_file.stem] 19 | cur_dir = trk_dir / seq 20 | cur_dir.mkdir(exist_ok=True) 21 | shutil.copyfile(sample_file, cur_dir / f"{frame}.txt") 22 | 23 | 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /data_processing/save_img_shapes.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | from configparser import ConfigParser 4 | from pathlib import Path 5 | 6 | import tqdm 7 | from PIL import Image 8 | 9 | 10 | def save_img_hw_json(root_dir: Path): 11 | img_hw_dict = {} 12 | img_dir = root_dir / "image_02" 13 | for seq_dir in img_dir.iterdir(): 14 | cur_seq_img_hw_dict = {} 15 | pbar = tqdm.tqdm(list(seq_dir.iterdir())) 16 | pbar.set_description(seq_dir.name) 17 | for img_path in pbar: 18 | img = Image.open(img_path) 19 | cur_seq_img_hw_dict[img_path.stem] = (img.height, img.width) 20 | img_hw_dict[seq_dir.name] = cur_seq_img_hw_dict 21 | 22 | with open(root_dir / "img_hw.json", "w") as f: 23 | json.dump(img_hw_dict, f) 24 | 25 | 26 | def main(): 27 | parser = ArgumentParser() 28 | parser.add_argument("config", type=str) 29 | parser.add_argument("split", type=str) 30 | args = parser.parse_args() 31 | config = ConfigParser() 32 | config.read(args.config) 33 | root_dir = Path(config["data"]["root_dir"]) / args.split 34 | save_img_hw_json(root_dir) 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /data_processing/tracking2object.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from argparse import ArgumentParser 3 | from pathlib import Path 4 | 5 | import tqdm 6 | 7 | 8 | def main(): 9 | parser = ArgumentParser() 10 | parser.add_argument("trk_dir", type=str) 11 | parser.add_argument("obj_dir", type=str) 12 | args = parser.parse_args() 13 | trk_dir = Path(args.trk_dir) 14 | obj_dir = Path(args.obj_dir) 15 | obj_dir.mkdir(exist_ok=True, parents=True) 16 | 17 | with open("data/kitti/detection/training/seq2sample.txt") as f: 18 | lines = f.readlines() 19 | split_lines = [line.split() for line in lines] 20 | seq2samples = {s[0]: s[1:] for s in split_lines} 21 | 22 | for seq, samples in tqdm.tqdm(seq2samples.items()): 23 | for i, file in enumerate(sorted((trk_dir / seq).iterdir())): 24 | shutil.copyfile(file, obj_dir / f"{samples[i]}.txt") 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /datasets/voxelization.py: -------------------------------------------------------------------------------- 1 | tv = None 2 | try: 3 | import cumm.tensorview as tv 4 | except: 5 | pass 6 | 7 | 8 | class VoxelGeneratorWrapper: 9 | def __init__( 10 | self, 11 | voxel_size, 12 | point_cloud_range, 13 | num_point_features, 14 | max_num_points_per_voxel, 15 | max_num_voxels, 16 | ): 17 | try: 18 | from spconv.utils import VoxelGeneratorV2 as VoxelGenerator 19 | 20 | self.spconv_ver = 1 21 | except: 22 | try: 23 | from spconv.utils import VoxelGenerator 24 | 25 | self.spconv_ver = 1 26 | except: 27 | from spconv.utils import Point2VoxelCPU3d as VoxelGenerator 28 | 29 | self.spconv_ver = 2 30 | 31 | if self.spconv_ver == 1: 32 | self._voxel_generator = VoxelGenerator( 33 | voxel_size=voxel_size, 34 | point_cloud_range=point_cloud_range, 35 | 
max_num_points=max_num_points_per_voxel, 36 | max_voxels=max_num_voxels, 37 | ) 38 | else: 39 | self._voxel_generator = VoxelGenerator( 40 | vsize_xyz=voxel_size, 41 | coors_range_xyz=point_cloud_range, 42 | num_point_features=num_point_features, 43 | max_num_points_per_voxel=max_num_points_per_voxel, 44 | max_num_voxels=max_num_voxels, 45 | ) 46 | 47 | def generate(self, points): 48 | if self.spconv_ver == 1: 49 | voxel_output = self._voxel_generator.generate(points) 50 | if isinstance(voxel_output, dict): 51 | voxels, coordinates, num_points = ( 52 | voxel_output["voxels"], 53 | voxel_output["coordinates"], 54 | voxel_output["num_points_per_voxel"], 55 | ) 56 | else: 57 | voxels, coordinates, num_points = voxel_output 58 | else: 59 | assert ( 60 | tv is not None 61 | ), f"Unexpected error, library: 'cumm' wasn't imported properly." 62 | voxel_output = self._voxel_generator.point_to_voxel(tv.from_numpy(points)) 63 | tv_voxels, tv_coordinates, tv_num_points = voxel_output 64 | # make copy with numpy(), since numpy_view() will disappear as soon as the generator is deleted 65 | voxels = tv_voxels.numpy() 66 | coordinates = tv_coordinates.numpy() 67 | num_points = tv_num_points.numpy() 68 | return voxels, coordinates, num_points 69 | -------------------------------------------------------------------------------- /delete_far_objects.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | 5 | from utils import Calibration, KittiTrack3d, get_objects_from_label 6 | 7 | 8 | def main(): 9 | max_distance = 50 10 | # original_label_dir = Path('data/kitti/tracking/training/label_02') 11 | original_label_dir = Path( 12 | "F://Github/3D-Multi-Object-Tracker/evaluation/results/virconv/data" 13 | ) 14 | calib_dir = Path("data/kitti/tracking/training/calib") 15 | # my_label_dir = Path(f'data/kitti/tracking/training/label_near_{max_distance}') 16 | my_label_dir = Path( 17 | f"F://Github/3D-Multi-Object-Tracker/evaluation/results/virconv_{max_distance}/data" 18 | ) 19 | my_label_dir.mkdir(exist_ok=True, parents=True) 20 | for seq in range(21): 21 | seq = str(seq).zfill(4) 22 | calib = Calibration(calib_dir / f"{seq}.txt") 23 | tracks = get_objects_from_label(f"{original_label_dir / seq}.txt", track=True) 24 | lines = [] 25 | for track in tracks: 26 | track: KittiTrack3d 27 | box = track.to_lidar_box(calib) 28 | if np.linalg.norm(box[:2]) <= max_distance: 29 | lines.append(track.serialize()) 30 | with open(my_label_dir / f"{seq}.txt", "w") as f: 31 | f.writelines(lines) 32 | 33 | 34 | if __name__ == "__main__": 35 | main() 36 | -------------------------------------------------------------------------------- /detection/kitti_object_eval_python/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kemo-Huang/BiTrack/055cd6c1252adae1be4bcb9016a15d08031723e0/detection/kitti_object_eval_python/__init__.py -------------------------------------------------------------------------------- /detection/voxel_rcnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kemo-Huang/BiTrack/055cd6c1252adae1be4bcb9016a15d08031723e0/detection/voxel_rcnn/__init__.py -------------------------------------------------------------------------------- /detection/voxel_rcnn/anchor_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | 
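# AnchorGenerator densely tiles 3D anchors over the BEV grid, one anchor set per
# class. generate_anchors() returns, for each class, a tensor of shape
# [n_z, n_y, n_x, num_sizes, num_rotations, 7] with rows [x, y, z, dx, dy, dz, rot];
# the final "anchors[..., 2] += anchors[..., 5] / 2" converts the configured
# bottom heights into vertical box centers.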
class AnchorGenerator: 5 | def __init__(self, anchor_range, anchor_generator_config): 6 | super().__init__() 7 | self.anchor_generator_cfg = anchor_generator_config 8 | self.anchor_range = anchor_range 9 | self.anchor_sizes = [ 10 | config["anchor_sizes"] for config in anchor_generator_config 11 | ] 12 | self.anchor_rotations = [ 13 | config["anchor_rotations"] for config in anchor_generator_config 14 | ] 15 | self.anchor_heights = [ 16 | config["anchor_bottom_heights"] for config in anchor_generator_config 17 | ] 18 | self.align_center = [ 19 | config.get("align_center", False) for config in anchor_generator_config 20 | ] 21 | 22 | assert ( 23 | len(self.anchor_sizes) 24 | == len(self.anchor_rotations) 25 | == len(self.anchor_heights) 26 | ) 27 | self.num_of_anchor_sets = len(self.anchor_sizes) 28 | 29 | def generate_anchors(self, grid_sizes): 30 | assert len(grid_sizes) == self.num_of_anchor_sets 31 | all_anchors = [] 32 | num_anchors_per_location = [] 33 | for grid_size, anchor_size, anchor_rotation, anchor_height, align_center in zip( 34 | grid_sizes, 35 | self.anchor_sizes, 36 | self.anchor_rotations, 37 | self.anchor_heights, 38 | self.align_center, 39 | ): 40 | num_anchors_per_location.append( 41 | len(anchor_rotation) * len(anchor_size) * len(anchor_height) 42 | ) 43 | if align_center: 44 | x_stride = (self.anchor_range[3] - self.anchor_range[0]) / grid_size[0] 45 | y_stride = (self.anchor_range[4] - self.anchor_range[1]) / grid_size[1] 46 | x_offset, y_offset = x_stride / 2, y_stride / 2 47 | else: 48 | x_stride = (self.anchor_range[3] - self.anchor_range[0]) / ( 49 | grid_size[0] - 1 50 | ) 51 | y_stride = (self.anchor_range[4] - self.anchor_range[1]) / ( 52 | grid_size[1] - 1 53 | ) 54 | x_offset, y_offset = 0, 0 55 | 56 | x_shifts = torch.arange( 57 | self.anchor_range[0] + x_offset, 58 | self.anchor_range[3] + 1e-5, 59 | step=x_stride, 60 | dtype=torch.float32, 61 | ).cuda() 62 | y_shifts = torch.arange( 63 | self.anchor_range[1] + y_offset, 64 | self.anchor_range[4] + 1e-5, 65 | step=y_stride, 66 | dtype=torch.float32, 67 | ).cuda() 68 | z_shifts = x_shifts.new_tensor(anchor_height) 69 | 70 | num_anchor_size, num_anchor_rotation = ( 71 | anchor_size.__len__(), 72 | anchor_rotation.__len__(), 73 | ) 74 | anchor_rotation = x_shifts.new_tensor(anchor_rotation) 75 | anchor_size = x_shifts.new_tensor(anchor_size) 76 | x_shifts, y_shifts, z_shifts = torch.meshgrid( 77 | [x_shifts, y_shifts, z_shifts] 78 | ) # [x_grid, y_grid, z_grid] 79 | anchors = torch.stack( 80 | (x_shifts, y_shifts, z_shifts), dim=-1 81 | ) # [x, y, z, 3] 82 | anchors = anchors[:, :, :, None, :].repeat(1, 1, 1, anchor_size.shape[0], 1) 83 | anchor_size = anchor_size.view(1, 1, 1, -1, 3).repeat( 84 | [*anchors.shape[0:3], 1, 1] 85 | ) 86 | anchors = torch.cat((anchors, anchor_size), dim=-1) 87 | anchors = anchors[:, :, :, :, None, :].repeat( 88 | 1, 1, 1, 1, num_anchor_rotation, 1 89 | ) 90 | anchor_rotation = anchor_rotation.view(1, 1, 1, 1, -1, 1).repeat( 91 | [*anchors.shape[0:3], num_anchor_size, 1, 1] 92 | ) 93 | anchors = torch.cat( 94 | (anchors, anchor_rotation), dim=-1 95 | ) # [x, y, z, num_size, num_rot, 7] 96 | 97 | anchors = anchors.permute(2, 1, 0, 3, 4, 5).contiguous() 98 | # anchors = anchors.view(-1, anchors.shape[-1]) 99 | anchors[..., 2] += anchors[..., 5] / 2 # shift to box centers 100 | all_anchors.append(anchors) 101 | return all_anchors, num_anchors_per_location 102 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/box_coder.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ResidualCoder: 5 | def __init__(self, code_size=7, encode_angle_by_sincos=False, **kwargs): 6 | super().__init__() 7 | self.code_size = code_size 8 | self.encode_angle_by_sincos = encode_angle_by_sincos 9 | if self.encode_angle_by_sincos: 10 | self.code_size += 1 11 | 12 | def encode_torch(self, boxes, anchors): 13 | """ 14 | Args: 15 | boxes: (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...] 16 | anchors: (N, 7 + C) [x, y, z, dx, dy, dz, heading or *[cos, sin], ...] 17 | 18 | Returns: 19 | 20 | """ 21 | anchors[:, 3:6] = torch.clamp_min(anchors[:, 3:6], min=1e-5) 22 | boxes[:, 3:6] = torch.clamp_min(boxes[:, 3:6], min=1e-5) 23 | 24 | xa, ya, za, dxa, dya, dza, ra, *cas = torch.split(anchors, 1, dim=-1) 25 | xg, yg, zg, dxg, dyg, dzg, rg, *cgs = torch.split(boxes, 1, dim=-1) 26 | 27 | diagonal = torch.sqrt(dxa**2 + dya**2) 28 | xt = (xg - xa) / diagonal 29 | yt = (yg - ya) / diagonal 30 | zt = (zg - za) / dza 31 | dxt = torch.log(dxg / dxa) 32 | dyt = torch.log(dyg / dya) 33 | dzt = torch.log(dzg / dza) 34 | if self.encode_angle_by_sincos: 35 | rt_cos = torch.cos(rg) - torch.cos(ra) 36 | rt_sin = torch.sin(rg) - torch.sin(ra) 37 | rts = [rt_cos, rt_sin] 38 | else: 39 | rts = [rg - ra] 40 | 41 | cts = [g - a for g, a in zip(cgs, cas)] 42 | return torch.cat([xt, yt, zt, dxt, dyt, dzt, *rts, *cts], dim=-1) 43 | 44 | def decode_torch(self, box_encodings, anchors): 45 | """ 46 | Args: 47 | box_encodings: (B, N, 7 + C) or (N, 7 + C) [x, y, z, dx, dy, dz, heading or *[cos, sin], ...] 48 | anchors: (B, N, 7 + C) or (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...] 49 | 50 | Returns: 51 | 52 | """ 53 | xa, ya, za, dxa, dya, dza, ra, *cas = torch.split(anchors, 1, dim=-1) 54 | if not self.encode_angle_by_sincos: 55 | xt, yt, zt, dxt, dyt, dzt, rt, *cts = torch.split(box_encodings, 1, dim=-1) 56 | else: 57 | xt, yt, zt, dxt, dyt, dzt, cost, sint, *cts = torch.split( 58 | box_encodings, 1, dim=-1 59 | ) 60 | 61 | diagonal = torch.sqrt(dxa**2 + dya**2) 62 | xg = xt * diagonal + xa 63 | yg = yt * diagonal + ya 64 | zg = zt * dza + za 65 | 66 | dxg = torch.exp(dxt) * dxa 67 | dyg = torch.exp(dyt) * dya 68 | dzg = torch.exp(dzt) * dza 69 | 70 | if self.encode_angle_by_sincos: 71 | rg_cos = cost + torch.cos(ra) 72 | rg_sin = sint + torch.sin(ra) 73 | rg = torch.atan2(rg_sin, rg_cos) 74 | else: 75 | rg = rt + ra 76 | 77 | cgs = [t + a for t, a in zip(cts, cas)] 78 | return torch.cat([xg, yg, zg, dxg, dyg, dzg, rg, *cgs], dim=-1) 79 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/height_compression.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | 3 | 4 | class HeightCompression(Module): 5 | def __init__(self): 6 | super().__init__() 7 | 8 | def forward(self, batch_dict): 9 | encoded_spconv_tensor = batch_dict["encoded_spconv_tensor"] 10 | spatial_features = encoded_spconv_tensor.dense() 11 | N, C, D, H, W = spatial_features.shape 12 | spatial_features = spatial_features.view(N, C * D, H, W) 13 | batch_dict["spatial_features"] = spatial_features 14 | batch_dict["spatial_features_stride"] = batch_dict[ 15 | "encoded_spconv_tensor_stride" 16 | ] 17 | return batch_dict 18 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/iou3d_nms/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Kemo-Huang/BiTrack/055cd6c1252adae1be4bcb9016a15d08031723e0/detection/voxel_rcnn/iou3d_nms/__init__.py -------------------------------------------------------------------------------- /detection/voxel_rcnn/iou3d_nms/iou3d_nms_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3D IoU Calculation and Rotated NMS 3 | Written by Shaoshuai Shi 4 | All Rights Reserved 2019-2020. 5 | """ 6 | 7 | import torch 8 | 9 | from . import iou3d_nms_cuda 10 | 11 | 12 | def boxes_iou_bev(boxes_a, boxes_b): 13 | """ 14 | Args: 15 | boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading] 16 | boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading] 17 | 18 | Returns: 19 | ans_iou: (N, M) 20 | """ 21 | assert boxes_a.shape[1] == boxes_b.shape[1] == 7 22 | ans_iou = torch.zeros( 23 | ((boxes_a.shape[0], boxes_b.shape[0])), dtype=torch.float32, device="cuda" 24 | ) 25 | 26 | iou3d_nms_cuda.boxes_iou_bev_gpu( 27 | boxes_a.contiguous(), boxes_b.contiguous(), ans_iou 28 | ) 29 | 30 | return ans_iou 31 | 32 | 33 | def boxes_iou3d_gpu(boxes_a, boxes_b): 34 | """ 35 | Args: 36 | boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading] 37 | boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading] 38 | 39 | Returns: 40 | ans_iou: (N, M) 41 | """ 42 | assert boxes_a.shape[1] == boxes_b.shape[1] == 7 43 | 44 | # height overlap 45 | boxes_a_height_max = (boxes_a[:, 2] + boxes_a[:, 5] / 2).view(-1, 1) 46 | boxes_a_height_min = (boxes_a[:, 2] - boxes_a[:, 5] / 2).view(-1, 1) 47 | boxes_b_height_max = (boxes_b[:, 2] + boxes_b[:, 5] / 2).view(1, -1) 48 | boxes_b_height_min = (boxes_b[:, 2] - boxes_b[:, 5] / 2).view(1, -1) 49 | 50 | # bev overlap 51 | overlaps_bev = torch.zeros( 52 | (boxes_a.shape[0], boxes_b.shape[0]), dtype=torch.float32, device="cuda" 53 | ) # (N, M) 54 | iou3d_nms_cuda.boxes_overlap_bev_gpu( 55 | boxes_a.contiguous(), boxes_b.contiguous(), overlaps_bev 56 | ) 57 | 58 | max_of_min = torch.max(boxes_a_height_min, boxes_b_height_min) 59 | min_of_max = torch.min(boxes_a_height_max, boxes_b_height_max) 60 | overlaps_h = torch.clamp(min_of_max - max_of_min, min=0) 61 | 62 | # 3d iou 63 | overlaps_3d = overlaps_bev * overlaps_h 64 | 65 | vol_a = (boxes_a[:, 3] * boxes_a[:, 4] * boxes_a[:, 5]).view(-1, 1) 66 | vol_b = (boxes_b[:, 3] * boxes_b[:, 4] * boxes_b[:, 5]).view(1, -1) 67 | 68 | iou3d = overlaps_3d / torch.clamp(vol_a + vol_b - overlaps_3d, min=1e-6) 69 | 70 | return iou3d 71 | 72 | 73 | def nms_gpu(boxes, scores, thresh, pre_maxsize=None, **kwargs): 74 | """ 75 | :param boxes: (N, 7) [x, y, z, dx, dy, dz, heading] 76 | :param scores: (N) 77 | :param thresh: 78 | :return: 79 | """ 80 | assert boxes.shape[1] == 7 81 | order = scores.sort(0, descending=True)[1] 82 | if pre_maxsize is not None: 83 | order = order[:pre_maxsize] 84 | 85 | boxes = boxes[order].contiguous() 86 | keep = torch.LongTensor(boxes.size(0)) 87 | num_out = iou3d_nms_cuda.nms_gpu(boxes, keep, thresh) 88 | return order[keep[:num_out].cuda()].contiguous() 89 | 90 | 91 | def nms_normal_gpu(boxes, scores, thresh, **kwargs): 92 | """ 93 | :param boxes: (N, 7) [x, y, z, dx, dy, dz, heading] 94 | :param scores: (N) 95 | :param thresh: 96 | :return: 97 | """ 98 | assert boxes.shape[1] == 7 99 | order = scores.sort(0, descending=True)[1] 100 | 101 | boxes = boxes[order].contiguous() 102 | 103 | keep = torch.LongTensor(boxes.size(0)) 104 | num_out = iou3d_nms_cuda.nms_normal_gpu(boxes, keep, thresh) 105 | return order[keep[:num_out].cuda()].contiguous() 106 | 
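# Usage sketch (illustrative values): every helper above expects CUDA float
# tensors of shape (N, 7) = [x, y, z, dx, dy, dz, heading], e.g.
#   boxes = torch.tensor([[0.0, 0.0, 0.0, 4.0, 2.0, 1.5, 0.0],
#                         [0.5, 0.0, 0.0, 4.0, 2.0, 1.5, 0.1]], device="cuda")
#   scores = torch.tensor([0.9, 0.8], device="cuda")
#   iou = boxes_iou3d_gpu(boxes, boxes)        # (2, 2) pairwise 3D IoU
#   keep = nms_gpu(boxes, scores, thresh=0.1)  # indices kept by rotated NMS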
-------------------------------------------------------------------------------- /detection/voxel_rcnn/iou3d_nms/src/iou3d_nms.h: -------------------------------------------------------------------------------- 1 | #ifndef IOU3D_NMS_H 2 | #define IOU3D_NMS_H 3 | 4 | #include <torch/serialize/tensor.h> 5 | #include <vector> 6 | #include <cuda.h> 7 | #include <cuda_runtime_api.h> 8 | 9 | int boxes_overlap_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_overlap); 10 | int boxes_iou_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_iou); 11 | int nms_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh); 12 | int nms_normal_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh); 13 | 14 | #endif 15 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/iou3d_nms/src/iou3d_nms_api.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/serialize/tensor.h> 2 | #include <torch/extension.h> 3 | #include <vector> 4 | #include <cuda.h> 5 | #include <cuda_runtime_api.h> 6 | 7 | // #include "iou3d_cpu.h" 8 | #include "iou3d_nms.h" 9 | 10 | 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("boxes_overlap_bev_gpu", &boxes_overlap_bev_gpu, "oriented boxes overlap"); 13 | m.def("boxes_iou_bev_gpu", &boxes_iou_bev_gpu, "oriented boxes iou"); 14 | m.def("nms_gpu", &nms_gpu, "oriented nms gpu"); 15 | m.def("nms_normal_gpu", &nms_normal_gpu, "nms gpu"); 16 | // m.def("boxes_iou_bev_cpu", &boxes_iou_bev_cpu, "oriented boxes iou"); 17 | } 18 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/mean_vfe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class MeanVFE(torch.nn.Module): 5 | def __init__(self): 6 | super().__init__() 7 | 8 | def forward(self, batch_dict): 9 | voxel_features, voxel_num_points = ( 10 | batch_dict["voxels"], 11 | batch_dict["voxel_num_points"], 12 | ) 13 | points_mean = voxel_features[:, :, :].sum(dim=1, keepdim=False) 14 | normalizer = torch.clamp_min(voxel_num_points.view(-1, 1), min=1.0).type_as( 15 | voxel_features 16 | ) 17 | points_mean = points_mean / normalizer # (num_voxels, C) 18 | batch_dict["voxel_features"] = points_mean.contiguous() 19 | return batch_dict 20 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kemo-Huang/BiTrack/055cd6c1252adae1be4bcb9016a15d08031723e0/detection/voxel_rcnn/pointnet2_stack/__init__.py -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Stacked-batch-data version of ball query, modified from the original implementation of official PointNet++ codes. 3 | Written by Shaoshuai Shi 4 | All Rights Reserved 2019-2020.
5 | */ 6 | 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include "ball_query_gpu.h" 13 | 14 | #define CHECK_CUDA(x) do { \ 15 | if (!x.device().is_cuda()) { \ 16 | fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ 17 | exit(-1); \ 18 | } \ 19 | } while (0) 20 | #define CHECK_CONTIGUOUS(x) do { \ 21 | if (!x.is_contiguous()) { \ 22 | fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \ 23 | exit(-1); \ 24 | } \ 25 | } while (0) 26 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 27 | 28 | 29 | int ball_query_wrapper_stack(int B, int M, float radius, int nsample, 30 | at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor, 31 | at::Tensor xyz_tensor, at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor) { 32 | CHECK_INPUT(new_xyz_tensor); 33 | CHECK_INPUT(xyz_tensor); 34 | CHECK_INPUT(new_xyz_batch_cnt_tensor); 35 | CHECK_INPUT(xyz_batch_cnt_tensor); 36 | 37 | const float *new_xyz = new_xyz_tensor.data_ptr(); 38 | const float *xyz = xyz_tensor.data_ptr(); 39 | const int *new_xyz_batch_cnt = new_xyz_batch_cnt_tensor.data_ptr(); 40 | const int *xyz_batch_cnt = xyz_batch_cnt_tensor.data_ptr(); 41 | int *idx = idx_tensor.data_ptr(); 42 | 43 | ball_query_kernel_launcher_stack(B, M, radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx); 44 | return 1; 45 | } 46 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Stacked-batch-data version of ball query, modified from the original implementation of official PointNet++ codes. 3 | Written by Shaoshuai Shi 4 | All Rights Reserved 2019-2020. 5 | */ 6 | 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #include "ball_query_gpu.h" 13 | #include "cuda_utils.h" 14 | 15 | 16 | __global__ void ball_query_kernel_stack(int B, int M, float radius, int nsample, \ 17 | const float *new_xyz, const int *new_xyz_batch_cnt, const float *xyz, const int *xyz_batch_cnt, int *idx) { 18 | // :param xyz: (N1 + N2 ..., 3) xyz coordinates of the features 19 | // :param xyz_batch_cnt: (batch_size), [N1, N2, ...] 20 | // :param new_xyz: (M1 + M2 ..., 3) centers of the ball query 21 | // :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...] 
22 | // output: 23 | // idx: (M, nsample) 24 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 25 | if (pt_idx >= M) return; 26 | 27 | int bs_idx = 0, pt_cnt = new_xyz_batch_cnt[0]; 28 | for (int k = 1; k < B; k++){ 29 | if (pt_idx < pt_cnt) break; 30 | pt_cnt += new_xyz_batch_cnt[k]; 31 | bs_idx = k; 32 | } 33 | 34 | int xyz_batch_start_idx = 0; 35 | for (int k = 0; k < bs_idx; k++) xyz_batch_start_idx += xyz_batch_cnt[k]; 36 | // for (int k = 0; k < bs_idx; k++) new_xyz_batch_start_idx += new_xyz_batch_cnt[k]; 37 | 38 | new_xyz += pt_idx * 3; 39 | xyz += xyz_batch_start_idx * 3; 40 | idx += pt_idx * nsample; 41 | 42 | float radius2 = radius * radius; 43 | float new_x = new_xyz[0]; 44 | float new_y = new_xyz[1]; 45 | float new_z = new_xyz[2]; 46 | int n = xyz_batch_cnt[bs_idx]; 47 | 48 | int cnt = 0; 49 | for (int k = 0; k < n; ++k) { 50 | float x = xyz[k * 3 + 0]; 51 | float y = xyz[k * 3 + 1]; 52 | float z = xyz[k * 3 + 2]; 53 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); 54 | if (d2 < radius2){ 55 | if (cnt == 0){ 56 | for (int l = 0; l < nsample; ++l) { 57 | idx[l] = k; 58 | } 59 | } 60 | idx[cnt] = k; 61 | ++cnt; 62 | if (cnt >= nsample) break; 63 | } 64 | } 65 | if (cnt == 0) idx[0] = -1; 66 | } 67 | 68 | 69 | void ball_query_kernel_launcher_stack(int B, int M, float radius, int nsample, 70 | const float *new_xyz, const int *new_xyz_batch_cnt, const float *xyz, const int *xyz_batch_cnt, int *idx){ 71 | // :param xyz: (N1 + N2 ..., 3) xyz coordinates of the features 72 | // :param xyz_batch_cnt: (batch_size), [N1, N2, ...] 73 | // :param new_xyz: (M1 + M2 ..., 3) centers of the ball query 74 | // :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...] 75 | // output: 76 | // idx: (M, nsample) 77 | 78 | cudaError_t err; 79 | 80 | dim3 blocks(DIVUP(M, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row) 81 | dim3 threads(THREADS_PER_BLOCK); 82 | 83 | ball_query_kernel_stack<<>>(B, M, radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx); 84 | // cudaDeviceSynchronize(); // for using printf in kernel function 85 | err = cudaGetLastError(); 86 | if (cudaSuccess != err) { 87 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 88 | exit(-1); 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/src/ball_query_gpu.h: -------------------------------------------------------------------------------- 1 | /* 2 | Stacked-batch-data version of ball query, modified from the original implementation of official PointNet++ codes. 3 | Written by Shaoshuai Shi 4 | All Rights Reserved 2019-2020. 
5 | */ 6 | 7 | 8 | #ifndef _STACK_BALL_QUERY_GPU_H 9 | #define _STACK_BALL_QUERY_GPU_H 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | int ball_query_wrapper_stack(int B, int M, float radius, int nsample, 17 | at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor, 18 | at::Tensor xyz_tensor, at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor); 19 | 20 | 21 | void ball_query_kernel_launcher_stack(int B, int M, float radius, int nsample, 22 | const float *new_xyz, const int *new_xyz_batch_cnt, const float *xyz, const int *xyz_batch_cnt, int *idx); 23 | 24 | 25 | #endif 26 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _STACK_CUDA_UTILS_H 2 | #define _STACK_CUDA_UTILS_H 3 | 4 | #include 5 | 6 | #define THREADS_PER_BLOCK 256 7 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 8 | 9 | #endif 10 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Stacked-batch-data version of point grouping, modified from the original implementation of official PointNet++ codes. 3 | Written by Shaoshuai Shi 4 | All Rights Reserved 2019-2020. 5 | */ 6 | 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include "group_points_gpu.h" 13 | 14 | #define CHECK_CUDA(x) do { \ 15 | if (!x.device().is_cuda()) { \ 16 | fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ 17 | exit(-1); \ 18 | } \ 19 | } while (0) 20 | #define CHECK_CONTIGUOUS(x) do { \ 21 | if (!x.is_contiguous()) { \ 22 | fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \ 23 | exit(-1); \ 24 | } \ 25 | } while (0) 26 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 27 | 28 | 29 | int group_points_grad_wrapper_stack(int B, int M, int C, int N, int nsample, 30 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor idx_batch_cnt_tensor, 31 | at::Tensor features_batch_cnt_tensor, at::Tensor grad_features_tensor) { 32 | 33 | CHECK_INPUT(grad_out_tensor); 34 | CHECK_INPUT(idx_tensor); 35 | CHECK_INPUT(idx_batch_cnt_tensor); 36 | CHECK_INPUT(features_batch_cnt_tensor); 37 | CHECK_INPUT(grad_features_tensor); 38 | 39 | const float *grad_out = grad_out_tensor.data_ptr(); 40 | const int *idx = idx_tensor.data_ptr(); 41 | const int *idx_batch_cnt = idx_batch_cnt_tensor.data_ptr(); 42 | const int *features_batch_cnt = features_batch_cnt_tensor.data_ptr(); 43 | float *grad_features = grad_features_tensor.data_ptr(); 44 | 45 | group_points_grad_kernel_launcher_stack(B, M, C, N, nsample, grad_out, idx, idx_batch_cnt, features_batch_cnt, grad_features); 46 | return 1; 47 | } 48 | 49 | 50 | int group_points_wrapper_stack(int B, int M, int C, int nsample, 51 | at::Tensor features_tensor, at::Tensor features_batch_cnt_tensor, 52 | at::Tensor idx_tensor, at::Tensor idx_batch_cnt_tensor, at::Tensor out_tensor) { 53 | 54 | CHECK_INPUT(features_tensor); 55 | CHECK_INPUT(features_batch_cnt_tensor); 56 | CHECK_INPUT(idx_tensor); 57 | CHECK_INPUT(idx_batch_cnt_tensor); 58 | CHECK_INPUT(out_tensor); 59 | 60 | const float *features = features_tensor.data_ptr(); 61 | const int *idx = idx_tensor.data_ptr(); 62 | const int *features_batch_cnt = features_batch_cnt_tensor.data_ptr(); 63 | const 
int *idx_batch_cnt = idx_batch_cnt_tensor.data_ptr(); 64 | float *out = out_tensor.data_ptr(); 65 | 66 | group_points_kernel_launcher_stack(B, M, C, nsample, features, features_batch_cnt, idx, idx_batch_cnt, out); 67 | return 1; 68 | } -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/src/group_points_gpu.h: -------------------------------------------------------------------------------- 1 | /* 2 | Stacked-batch-data version of point grouping, modified from the original implementation of official PointNet++ codes. 3 | Written by Shaoshuai Shi 4 | All Rights Reserved 2019-2020. 5 | */ 6 | 7 | 8 | #ifndef _STACK_GROUP_POINTS_GPU_H 9 | #define _STACK_GROUP_POINTS_GPU_H 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | 17 | int group_points_wrapper_stack(int B, int M, int C, int nsample, 18 | at::Tensor features_tensor, at::Tensor features_batch_cnt_tensor, 19 | at::Tensor idx_tensor, at::Tensor idx_batch_cnt_tensor, at::Tensor out_tensor); 20 | 21 | void group_points_kernel_launcher_stack(int B, int M, int C, int nsample, 22 | const float *features, const int *features_batch_cnt, const int *idx, const int *idx_batch_cnt, float *out); 23 | 24 | int group_points_grad_wrapper_stack(int B, int M, int C, int N, int nsample, 25 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor idx_batch_cnt_tensor, 26 | at::Tensor features_batch_cnt_tensor, at::Tensor grad_features_tensor); 27 | 28 | void group_points_grad_kernel_launcher_stack(int B, int M, int C, int N, int nsample, 29 | const float *grad_out, const int *idx, const int *idx_batch_cnt, const int *features_batch_cnt, float *grad_features); 30 | 31 | #endif 32 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Stacked-batch-data version of point interpolation, modified from the original implementation of official PointNet++ codes. 3 | Written by Shaoshuai Shi 4 | All Rights Reserved 2019-2020. 5 | */ 6 | 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include "interpolate_gpu.h" 16 | 17 | #define CHECK_CUDA(x) do { \ 18 | if (!x.device().is_cuda()) { \ 19 | fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ 20 | exit(-1); \ 21 | } \ 22 | } while (0) 23 | #define CHECK_CONTIGUOUS(x) do { \ 24 | if (!x.is_contiguous()) { \ 25 | fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \ 26 | exit(-1); \ 27 | } \ 28 | } while (0) 29 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 30 | 31 | 32 | void three_nn_wrapper_stack(at::Tensor unknown_tensor, 33 | at::Tensor unknown_batch_cnt_tensor, at::Tensor known_tensor, 34 | at::Tensor known_batch_cnt_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor){ 35 | // unknown: (N1 + N2 ..., 3) 36 | // unknown_batch_cnt: (batch_size), [N1, N2, ...] 37 | // known: (M1 + M2 ..., 3) 38 | // known_batch_cnt: (batch_size), [M1, M2, ...] 
39 | // Return: 40 | // dist: (N1 + N2 ..., 3) l2 distance to the three nearest neighbors 41 | // idx: (N1 + N2 ..., 3) index of the three nearest neighbors 42 | CHECK_INPUT(unknown_tensor); 43 | CHECK_INPUT(unknown_batch_cnt_tensor); 44 | CHECK_INPUT(known_tensor); 45 | CHECK_INPUT(known_batch_cnt_tensor); 46 | CHECK_INPUT(dist2_tensor); 47 | CHECK_INPUT(idx_tensor); 48 | 49 | int batch_size = unknown_batch_cnt_tensor.size(0); 50 | int N = unknown_tensor.size(0); 51 | int M = known_tensor.size(0); 52 | const float *unknown = unknown_tensor.data_ptr(); 53 | const int *unknown_batch_cnt = unknown_batch_cnt_tensor.data_ptr(); 54 | const float *known = known_tensor.data_ptr(); 55 | const int *known_batch_cnt = known_batch_cnt_tensor.data_ptr(); 56 | float *dist2 = dist2_tensor.data_ptr(); 57 | int *idx = idx_tensor.data_ptr(); 58 | 59 | three_nn_kernel_launcher_stack(batch_size, N, M, unknown, unknown_batch_cnt, known, known_batch_cnt, dist2, idx); 60 | } 61 | 62 | 63 | void three_interpolate_wrapper_stack(at::Tensor features_tensor, 64 | at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor) { 65 | // features_tensor: (M1 + M2 ..., C) 66 | // idx_tensor: [N1 + N2 ..., 3] 67 | // weight_tensor: [N1 + N2 ..., 3] 68 | // Return: 69 | // out_tensor: (N1 + N2 ..., C) 70 | CHECK_INPUT(features_tensor); 71 | CHECK_INPUT(idx_tensor); 72 | CHECK_INPUT(weight_tensor); 73 | CHECK_INPUT(out_tensor); 74 | 75 | int N = out_tensor.size(0); 76 | int channels = features_tensor.size(1); 77 | const float *features = features_tensor.data_ptr(); 78 | const float *weight = weight_tensor.data_ptr(); 79 | const int *idx = idx_tensor.data_ptr(); 80 | float *out = out_tensor.data_ptr(); 81 | 82 | three_interpolate_kernel_launcher_stack(N, channels, features, idx, weight, out); 83 | } 84 | 85 | 86 | void three_interpolate_grad_wrapper_stack(at::Tensor grad_out_tensor, at::Tensor idx_tensor, 87 | at::Tensor weight_tensor, at::Tensor grad_features_tensor) { 88 | // grad_out_tensor: (N1 + N2 ..., C) 89 | // idx_tensor: [N1 + N2 ..., 3] 90 | // weight_tensor: [N1 + N2 ..., 3] 91 | // Return: 92 | // grad_features_tensor: (M1 + M2 ..., C) 93 | CHECK_INPUT(grad_out_tensor); 94 | CHECK_INPUT(idx_tensor); 95 | CHECK_INPUT(weight_tensor); 96 | CHECK_INPUT(grad_features_tensor); 97 | 98 | int N = grad_out_tensor.size(0); 99 | int channels = grad_out_tensor.size(1); 100 | const float *grad_out = grad_out_tensor.data_ptr(); 101 | const float *weight = weight_tensor.data_ptr(); 102 | const int *idx = idx_tensor.data_ptr(); 103 | float *grad_features = grad_features_tensor.data_ptr(); 104 | 105 | // printf("N=%d, channels=%d\n", N, channels); 106 | three_interpolate_grad_kernel_launcher_stack(N, channels, grad_out, idx, weight, grad_features); 107 | } -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/src/interpolate_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _INTERPOLATE_GPU_H 2 | #define _INTERPOLATE_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | void three_nn_wrapper_stack(at::Tensor unknown_tensor, 11 | at::Tensor unknown_batch_cnt_tensor, at::Tensor known_tensor, 12 | at::Tensor known_batch_cnt_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor); 13 | 14 | 15 | void three_interpolate_wrapper_stack(at::Tensor features_tensor, 16 | at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor); 17 | 18 | 19 | 20 | void 
three_interpolate_grad_wrapper_stack(at::Tensor grad_out_tensor, at::Tensor idx_tensor, 21 | at::Tensor weight_tensor, at::Tensor grad_features_tensor); 22 | 23 | 24 | void three_nn_kernel_launcher_stack(int batch_size, int N, int M, const float *unknown, 25 | const int *unknown_batch_cnt, const float *known, const int *known_batch_cnt, 26 | float *dist2, int *idx); 27 | 28 | 29 | void three_interpolate_kernel_launcher_stack(int N, int channels, 30 | const float *features, const int *idx, const float *weight, float *out); 31 | 32 | 33 | 34 | void three_interpolate_grad_kernel_launcher_stack(int N, int channels, const float *grad_out, 35 | const int *idx, const float *weight, float *grad_features); 36 | 37 | 38 | 39 | #endif -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/src/pointnet2_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "ball_query_gpu.h" 5 | #include "group_points_gpu.h" 6 | #include "sampling_gpu.h" 7 | #include "interpolate_gpu.h" 8 | #include "voxel_query_gpu.h" 9 | #include "vector_pool_gpu.h" 10 | 11 | 12 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 13 | m.def("ball_query_wrapper", &ball_query_wrapper_stack, "ball_query_wrapper_stack"); 14 | m.def("voxel_query_wrapper", &voxel_query_wrapper_stack, "voxel_query_wrapper_stack"); 15 | 16 | m.def("farthest_point_sampling_wrapper", &farthest_point_sampling_wrapper, "farthest_point_sampling_wrapper"); 17 | m.def("stack_farthest_point_sampling_wrapper", &stack_farthest_point_sampling_wrapper, "stack_farthest_point_sampling_wrapper"); 18 | 19 | m.def("group_points_wrapper", &group_points_wrapper_stack, "group_points_wrapper_stack"); 20 | m.def("group_points_grad_wrapper", &group_points_grad_wrapper_stack, "group_points_grad_wrapper_stack"); 21 | 22 | m.def("three_nn_wrapper", &three_nn_wrapper_stack, "three_nn_wrapper_stack"); 23 | m.def("three_interpolate_wrapper", &three_interpolate_wrapper_stack, "three_interpolate_wrapper_stack"); 24 | m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_stack, "three_interpolate_grad_wrapper_stack"); 25 | 26 | m.def("query_stacked_local_neighbor_idxs_wrapper_stack", &query_stacked_local_neighbor_idxs_wrapper_stack, "query_stacked_local_neighbor_idxs_wrapper_stack"); 27 | m.def("query_three_nn_by_stacked_local_idxs_wrapper_stack", &query_three_nn_by_stacked_local_idxs_wrapper_stack, "query_three_nn_by_stacked_local_idxs_wrapper_stack"); 28 | 29 | m.def("vector_pool_wrapper", &vector_pool_wrapper_stack, "vector_pool_grad_wrapper_stack"); 30 | m.def("vector_pool_grad_wrapper", &vector_pool_grad_wrapper_stack, "vector_pool_grad_wrapper_stack"); 31 | } 32 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "sampling_gpu.h" 5 | 6 | #define CHECK_CUDA(x) do { \ 7 | if (!x.device().is_cuda()) { \ 8 | fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ 9 | exit(-1); \ 10 | } \ 11 | } while (0) 12 | #define CHECK_CONTIGUOUS(x) do { \ 13 | if (!x.is_contiguous()) { \ 14 | fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \ 15 | exit(-1); \ 16 | } \ 17 | } while (0) 18 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 19 | 20 | 21 | int 
farthest_point_sampling_wrapper(int b, int n, int m, 22 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) { 23 | 24 | CHECK_INPUT(points_tensor); 25 | CHECK_INPUT(temp_tensor); 26 | CHECK_INPUT(idx_tensor); 27 | 28 | const float *points = points_tensor.data_ptr(); 29 | float *temp = temp_tensor.data_ptr(); 30 | int *idx = idx_tensor.data_ptr(); 31 | 32 | farthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx); 33 | return 1; 34 | } 35 | 36 | 37 | int stack_farthest_point_sampling_wrapper(at::Tensor points_tensor, 38 | at::Tensor temp_tensor, at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor, 39 | at::Tensor num_sampled_points_tensor) { 40 | 41 | CHECK_INPUT(points_tensor); 42 | CHECK_INPUT(temp_tensor); 43 | CHECK_INPUT(idx_tensor); 44 | CHECK_INPUT(xyz_batch_cnt_tensor); 45 | CHECK_INPUT(num_sampled_points_tensor); 46 | 47 | int batch_size = xyz_batch_cnt_tensor.size(0); 48 | int N = points_tensor.size(0); 49 | const float *points = points_tensor.data_ptr(); 50 | float *temp = temp_tensor.data_ptr(); 51 | int *xyz_batch_cnt = xyz_batch_cnt_tensor.data_ptr(); 52 | int *idx = idx_tensor.data_ptr(); 53 | int *num_sampled_points = num_sampled_points_tensor.data_ptr(); 54 | 55 | stack_farthest_point_sampling_kernel_launcher(N, batch_size, points, temp, xyz_batch_cnt, idx, num_sampled_points); 56 | return 1; 57 | } -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/src/sampling_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_GPU_H 2 | #define _SAMPLING_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | 9 | int farthest_point_sampling_wrapper(int b, int n, int m, 10 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); 11 | 12 | void farthest_point_sampling_kernel_launcher(int b, int n, int m, 13 | const float *dataset, float *temp, int *idxs); 14 | 15 | int stack_farthest_point_sampling_wrapper( 16 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor xyz_batch_cnt_tensor, 17 | at::Tensor idx_tensor, at::Tensor num_sampled_points_tensor); 18 | 19 | 20 | void stack_farthest_point_sampling_kernel_launcher(int N, int batch_size, 21 | const float *dataset, float *temp, int *xyz_batch_cnt, int *idxs, int *num_sampled_points); 22 | 23 | #endif 24 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/src/vector_pool_gpu.h: -------------------------------------------------------------------------------- 1 | /* 2 | Vector-pool aggregation based local feature aggregation for point cloud. 3 | PV-RCNN++: Point-Voxel Feature Set Abstraction With Local Vector Representation for 3D Object Detection 4 | https://arxiv.org/abs/2102.00463 5 | 6 | Written by Shaoshuai Shi 7 | All Rights Reserved 2020. 
8 | */ 9 | 10 | 11 | #ifndef _STACK_VECTOR_POOL_GPU_H 12 | #define _STACK_VECTOR_POOL_GPU_H 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | 20 | int query_stacked_local_neighbor_idxs_kernel_launcher_stack( 21 | const float *support_xyz, const int *xyz_batch_cnt, const float *new_xyz, const int *new_xyz_batch_cnt, 22 | int *stack_neighbor_idxs, int *start_len, int *cumsum, int avg_length_of_neighbor_idxs, 23 | float max_neighbour_distance, int batch_size, int M, int nsample, int neighbor_type); 24 | 25 | int query_stacked_local_neighbor_idxs_wrapper_stack(at::Tensor support_xyz_tensor, at::Tensor xyz_batch_cnt_tensor, 26 | at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor, 27 | at::Tensor stack_neighbor_idxs_tensor, at::Tensor start_len_tensor, at::Tensor cumsum_tensor, 28 | int avg_length_of_neighbor_idxs, float max_neighbour_distance, int nsample, int neighbor_type); 29 | 30 | 31 | int query_three_nn_by_stacked_local_idxs_kernel_launcher_stack( 32 | const float *support_xyz, const float *new_xyz, const float *new_xyz_grid_centers, 33 | int *new_xyz_grid_idxs, float *new_xyz_grid_dist2, 34 | const int *stack_neighbor_idxs, const int *start_len, 35 | int M, int num_total_grids); 36 | 37 | int query_three_nn_by_stacked_local_idxs_wrapper_stack(at::Tensor support_xyz_tensor, 38 | at::Tensor new_xyz_tensor, at::Tensor new_xyz_grid_centers_tensor, 39 | at::Tensor new_xyz_grid_idxs_tensor, at::Tensor new_xyz_grid_dist2_tensor, 40 | at::Tensor stack_neighbor_idxs_tensor, at::Tensor start_len_tensor, 41 | int M, int num_total_grids); 42 | 43 | 44 | int vector_pool_wrapper_stack(at::Tensor support_xyz_tensor, at::Tensor xyz_batch_cnt_tensor, 45 | at::Tensor support_features_tensor, at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor, 46 | at::Tensor new_features_tensor, at::Tensor new_local_xyz, 47 | at::Tensor point_cnt_of_grid_tensor, at::Tensor grouped_idxs_tensor, 48 | int num_grid_x, int num_grid_y, int num_grid_z, float max_neighbour_distance, int use_xyz, 49 | int num_max_sum_points, int nsample, int neighbor_type, int pooling_type); 50 | 51 | 52 | int vector_pool_kernel_launcher_stack( 53 | const float *support_xyz, const float *support_features, const int *xyz_batch_cnt, 54 | const float *new_xyz, float *new_features, float * new_local_xyz, const int *new_xyz_batch_cnt, 55 | int *point_cnt_of_grid, int *grouped_idxs, 56 | int num_grid_x, int num_grid_y, int num_grid_z, float max_neighbour_distance, 57 | int batch_size, int N, int M, int num_c_in, int num_c_out, int num_total_grids, int use_xyz, 58 | int num_max_sum_points, int nsample, int neighbor_type, int pooling_type); 59 | 60 | 61 | int vector_pool_grad_wrapper_stack(at::Tensor grad_new_features_tensor, 62 | at::Tensor point_cnt_of_grid_tensor, at::Tensor grouped_idxs_tensor, 63 | at::Tensor grad_support_features_tensor); 64 | 65 | 66 | void vector_pool_grad_kernel_launcher_stack( 67 | const float *grad_new_features, const int *point_cnt_of_grid, const int *grouped_idxs, 68 | float *grad_support_features, int N, int M, int num_c_out, int num_c_in, int num_total_grids, 69 | int num_max_sum_points); 70 | 71 | #endif 72 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/src/voxel_query.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "voxel_query_gpu.h" 9 | 10 | #define CHECK_CUDA(x) 
do { \ 11 | if (!x.device().is_cuda()) { \ 12 | fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ 13 | exit(-1); \ 14 | } \ 15 | } while (0) 16 | #define CHECK_CONTIGUOUS(x) do { \ 17 | if (!x.is_contiguous()) { \ 18 | fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \ 19 | exit(-1); \ 20 | } \ 21 | } while (0) 22 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 23 | 24 | 25 | int voxel_query_wrapper_stack(int M, int R1, int R2, int R3, int nsample, float radius, 26 | int z_range, int y_range, int x_range, at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, 27 | at::Tensor new_coords_tensor, at::Tensor point_indices_tensor, at::Tensor idx_tensor) { 28 | CHECK_INPUT(new_coords_tensor); 29 | CHECK_INPUT(point_indices_tensor); 30 | CHECK_INPUT(new_xyz_tensor); 31 | CHECK_INPUT(xyz_tensor); 32 | 33 | const float *new_xyz = new_xyz_tensor.data_ptr(); 34 | const float *xyz = xyz_tensor.data_ptr(); 35 | const int *new_coords = new_coords_tensor.data_ptr(); 36 | const int *point_indices = point_indices_tensor.data_ptr(); 37 | int *idx = idx_tensor.data_ptr(); 38 | 39 | voxel_query_kernel_launcher_stack(M, R1, R2, R3, nsample, radius, z_range, y_range, x_range, new_xyz, xyz, new_coords, point_indices, idx); 40 | return 1; 41 | } 42 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/src/voxel_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "voxel_query_gpu.h" 7 | #include "cuda_utils.h" 8 | 9 | 10 | __global__ void voxel_query_kernel_stack(int M, int R1, int R2, int R3, int nsample, 11 | float radius, int z_range, int y_range, int x_range, const float *new_xyz, 12 | const float *xyz, const int *new_coords, const int *point_indices, int *idx) { 13 | // :param new_coords: (M1 + M2 ..., 4) centers of the ball query 14 | // :param point_indices: (B, Z, Y, X) 15 | // output: 16 | // idx: (M1 + M2, nsample) 17 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 18 | if (pt_idx >= M) return; 19 | 20 | new_xyz += pt_idx * 3; 21 | new_coords += pt_idx * 4; 22 | idx += pt_idx * nsample; 23 | 24 | curandState state; 25 | curand_init(pt_idx, 0, 0, &state); 26 | 27 | float radius2 = radius * radius; 28 | float new_x = new_xyz[0]; 29 | float new_y = new_xyz[1]; 30 | float new_z = new_xyz[2]; 31 | 32 | int batch_idx = new_coords[0]; 33 | int new_coords_z = new_coords[1]; 34 | int new_coords_y = new_coords[2]; 35 | int new_coords_x = new_coords[3]; 36 | 37 | int cnt = 0; 38 | int cnt2 = 0; 39 | // for (int dz = -1*z_range; dz <= z_range; ++dz) { 40 | for (int dz = -1*z_range; dz <= z_range; ++dz) { 41 | int z_coord = new_coords_z + dz; 42 | if (z_coord < 0 || z_coord >= R1) continue; 43 | 44 | for (int dy = -1*y_range; dy <= y_range; ++dy) { 45 | int y_coord = new_coords_y + dy; 46 | if (y_coord < 0 || y_coord >= R2) continue; 47 | 48 | for (int dx = -1*x_range; dx <= x_range; ++dx) { 49 | int x_coord = new_coords_x + dx; 50 | if (x_coord < 0 || x_coord >= R3) continue; 51 | 52 | int index = batch_idx * R1 * R2 * R3 + \ 53 | z_coord * R2 * R3 + \ 54 | y_coord * R3 + \ 55 | x_coord; 56 | int neighbor_idx = point_indices[index]; 57 | if (neighbor_idx < 0) continue; 58 | 59 | float x_per = xyz[neighbor_idx*3 + 0]; 60 | float y_per = xyz[neighbor_idx*3 + 1]; 61 | float z_per = xyz[neighbor_idx*3 + 2]; 62 | 63 | float dist2 = (x_per - new_x) * (x_per - new_x) + 
(y_per - new_y) * (y_per - new_y) + (z_per - new_z) * (z_per - new_z); 64 | 65 | if (dist2 > radius2) continue; 66 | 67 | ++cnt2; 68 | 69 | if (cnt < nsample) { 70 | if (cnt == 0) { 71 | for (int l = 0; l < nsample; ++l) { 72 | idx[l] = neighbor_idx; 73 | } 74 | } 75 | idx[cnt] = neighbor_idx; 76 | ++cnt; 77 | } 78 | // else { 79 | // float rnd = curand_uniform(&state); 80 | // if (rnd < (float(nsample) / cnt2)) { 81 | // int insertidx = ceilf(curand_uniform(&state) * nsample) - 1; 82 | // idx[insertidx] = neighbor_idx; 83 | // } 84 | // } 85 | } 86 | } 87 | } 88 | if (cnt == 0) idx[0] = -1; 89 | } 90 | 91 | 92 | void voxel_query_kernel_launcher_stack(int M, int R1, int R2, int R3, int nsample, 93 | float radius, int z_range, int y_range, int x_range, const float *new_xyz, 94 | const float *xyz, const int *new_coords, const int *point_indices, int *idx) { 95 | // :param new_coords: (M1 + M2 ..., 4) centers of the voxel query 96 | // :param point_indices: (B, Z, Y, X) 97 | // output: 98 | // idx: (M1 + M2, nsample) 99 | 100 | cudaError_t err; 101 | 102 | dim3 blocks(DIVUP(M, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row) 103 | dim3 threads(THREADS_PER_BLOCK); 104 | 105 | voxel_query_kernel_stack<<>>(M, R1, R2, R3, nsample, radius, z_range, y_range, x_range, new_xyz, xyz, new_coords, point_indices, idx); 106 | // cudaDeviceSynchronize(); // for using printf in kernel function 107 | 108 | err = cudaGetLastError(); 109 | if (cudaSuccess != err) { 110 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 111 | exit(-1); 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /detection/voxel_rcnn/pointnet2_stack/src/voxel_query_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _STACK_VOXEL_QUERY_GPU_H 2 | #define _STACK_VOXEL_QUERY_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | int voxel_query_wrapper_stack(int M, int R1, int R2, int R3, int nsample, float radius, 10 | int z_range, int y_range, int x_range, at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, 11 | at::Tensor new_coords_tensor, at::Tensor point_indices_tensor, at::Tensor idx_tensor); 12 | 13 | 14 | void voxel_query_kernel_launcher_stack(int M, int R1, int R2, int R3, int nsample, 15 | float radius, int z_range, int y_range, int x_range, const float *new_xyz, 16 | const float *xyz, const int *new_coords, const int *point_indices, int *idx); 17 | 18 | 19 | #endif 20 | -------------------------------------------------------------------------------- /eval_kitti_detection.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | 4 | from detection.kitti_object_eval_python.eval import get_official_eval_result 5 | from utils import get_objects_from_label 6 | 7 | 8 | def get_label_annos(label_dir, sample_ids=None): 9 | label_dir = Path(label_dir) 10 | if sample_ids is None: 11 | label_files = label_dir.glob("*.txt") 12 | else: 13 | label_files = [ 14 | (label_dir / f"{str(sample_id).zfill(6)}.txt") for sample_id in sample_ids 15 | ] 16 | return [ 17 | get_objects_from_label(label_file, as_dict=True) for label_file in label_files 18 | ] 19 | 20 | 21 | def evaluate(label_path, result_path, label_split_file, classes=(0,)): 22 | dt_annos = get_label_annos(result_path) 23 | if label_split_file is None: 24 | gt_sample_ids = None 25 | else: 26 | with open(label_split_file, "r") as f: 27 | lines = 
f.readlines() 28 | gt_sample_ids = [int(line) for line in lines] 29 | gt_annos = get_label_annos(label_path, gt_sample_ids) 30 | res_str = get_official_eval_result(gt_annos, dt_annos, classes) 31 | return res_str 32 | 33 | 34 | def main(): 35 | parser = ArgumentParser(description="arg parser") 36 | parser.add_argument("label_path", type=str) 37 | parser.add_argument("result_path", type=str) 38 | parser.add_argument("--label_split", type=str) 39 | parser.add_argument("--classes", type=int, default=0) 40 | args = parser.parse_args() 41 | print(args.label_path, args.result_path) 42 | print(evaluate(args.label_path, args.result_path, args.label_split, args.classes)) 43 | 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /eval_nusc_detection.py: -------------------------------------------------------------------------------- 1 | from nuscenes import NuScenes 2 | from nuscenes.eval.detection.config import config_factory 3 | from nuscenes.eval.detection.evaluate import DetectionEval 4 | 5 | 6 | def main(): 7 | nusc = NuScenes("v1.0-mini", dataroot="data/nuscenes_mini", verbose=False) 8 | # nusc = NuScenes("v1.0-trainval", dataroot="data/nuscenes", verbose=True) 9 | nusc_eval = DetectionEval( 10 | nusc, 11 | config=config_factory("detection_cvpr_2019"), 12 | # result_path="data/nuscenes/results_nusc_val.json", 13 | # result_path="data/nuscenes/mini_val_results.json", 14 | result_path="output/nuscenes/mini/mini/results.json", 15 | eval_set="mini_val", 16 | # eval_set="val", 17 | output_dir="output/nuscenes_det/det", 18 | # output_dir="output/nuscenes_det/val", 19 | verbose=True, 20 | ) 21 | nusc_eval.main(render_curves=False) 22 | 23 | 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /eval_nusc_tracking.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | 4 | from nuscenes.eval.common.config import config_factory 5 | from nuscenes.eval.tracking.evaluate import TrackingEval 6 | 7 | TRK_CLASSES = [ 8 | "bicycle", 9 | "motorcycle", 10 | "pedestrian", 11 | "bus", 12 | "car", 13 | "trailer", 14 | "truck", 15 | ] 16 | 17 | 18 | def main(): 19 | parser = ArgumentParser() 20 | parser.add_argument("split", type=str) 21 | parser.add_argument("tag", type=str) 22 | args = parser.parse_args() 23 | 24 | with open(f"output/nuscenes/{args.split}/{args.tag}/results.json", "r") as f: 25 | data = json.load(f) 26 | results = data["results"] 27 | for sample_token, objs in results.items(): 28 | for obj in objs: 29 | if "detection_name" in obj: 30 | obj["tracking_name"] = obj["detection_name"] 31 | if "detection_score" in obj: 32 | obj["tracking_score"] = obj["detection_score"] 33 | results[sample_token] = [ 34 | obj for obj in objs if obj["tracking_name"] in TRK_CLASSES 35 | ] 36 | data["results"] = results 37 | with open( 38 | f"output/nuscenes/{args.split}/{args.tag}/tracking_results.json", "w" 39 | ) as f: 40 | json.dump(data, f) 41 | 42 | track_eval = TrackingEval( 43 | config=config_factory("tracking_nips_2019"), 44 | result_path=f"output/nuscenes/{args.split}/{args.tag}/tracking_results.json", 45 | eval_set=args.split, 46 | output_dir=f"output/nuscenes/{args.split}/{args.tag}", 47 | nusc_version="v1.0-trainval", 48 | nusc_dataroot="data/nuscenes", 49 | ) 50 | track_eval.main(render_curves=False) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | 
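# Usage sketch (tag name illustrative):
#   python eval_nusc_tracking.py val my_tag
# reads output/nuscenes/val/my_tag/results.json, copies each object's detection
# name and score into the tracking fields, drops non-tracking classes, writes
# tracking_results.json next to it, and runs the nuScenes tracking evaluation
# into the same directory.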
-------------------------------------------------------------------------------- /kitti_2d_3d_det_fusion.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | from argparse import ArgumentParser 4 | from configparser import ConfigParser 5 | from pathlib import Path 6 | 7 | import tqdm 8 | 9 | from tracking.detections.kitti_detections import get_detection_data 10 | from utils import Calibration, get_poses_from_file, read_seqmap_file 11 | 12 | 13 | def main(): 14 | parser = ArgumentParser() 15 | parser.add_argument("config", type=str) 16 | parser.add_argument("split", type=str) 17 | args = parser.parse_args() 18 | 19 | config = ConfigParser() 20 | config.read(args.config) 21 | 22 | root_dir = Path(config["data"]["root_dir"]) / args.split 23 | 24 | calib_dir = root_dir / "calib" 25 | oxts_dir = root_dir / "oxts" 26 | img_hw_dict = json.load(open(root_dir / "img_hw.json")) 27 | 28 | detection_cfg = config["detection"] 29 | 30 | det3d_name = detection_cfg["det3d_name"] 31 | det3d_dir = root_dir / "det3d_out" / det3d_name 32 | crop_dir = root_dir / "cropped_points" / det3d_name 33 | det2d_dir = root_dir / "det2d_out" / detection_cfg["det2d_name"] 34 | det2d_emb_dir = root_dir / "det2d_emb_out" / detection_cfg["det2d_emb_name"] 35 | seg_out_dir = root_dir / "seg_out" / detection_cfg["seg_name"] 36 | seg_emb_dir = root_dir / "seg_emb_out" / detection_cfg["seg_emb_name"] 37 | det3d_save_name = detection_cfg["det3d_save_name"] 38 | det3d_save_dir = root_dir / "det3d_out" / det3d_save_name 39 | 40 | use_lidar = detection_cfg.getboolean("use_lidar") 41 | use_inst = detection_cfg.getboolean("use_inst") 42 | use_det2d = detection_cfg.getboolean("use_det2d") 43 | assert not (use_inst and use_det2d) 44 | use_embed = detection_cfg.getboolean("use_embed") 45 | 46 | if use_inst: 47 | emb_dir = seg_emb_dir 48 | elif use_det2d: 49 | emb_dir = det2d_emb_dir 50 | else: 51 | use_embed = False 52 | 53 | seqmap_file = root_dir / f"evaluate_tracking.seqmap.{args.split}" 54 | frame_num_dict = read_seqmap_file(seqmap_file) 55 | 56 | for seq in frame_num_dict: 57 | # (seq 0001: missing 177 178 179 180) 58 | seq_det3d_dir = det3d_dir / seq 59 | frames = [f.stem for f in seq_det3d_dir.iterdir()] 60 | num_frames = len(frames) 61 | seq_det3d_save_dir: Path = det3d_save_dir / seq 62 | seq_det3d_save_dir.mkdir(parents=True, exist_ok=True) 63 | 64 | if detection_cfg.getboolean("use_pose"): 65 | # Reads imu data and converts to poses 66 | poses = get_poses_from_file(oxts_dir / f"{seq}.txt") 67 | else: 68 | poses = [None for _ in range(num_frames)] 69 | 70 | # seq calibration data 71 | calib = Calibration(calib_dir / f"{seq}.txt") 72 | 73 | seq_img_hw_dict = img_hw_dict[seq] 74 | 75 | good_dets_3d = {} 76 | bad_dets_3d = {} 77 | 78 | pbar = tqdm.tqdm(range(num_frames)) 79 | pbar.set_description(seq) 80 | for idx in pbar: 81 | # Retrieves the current frame info 82 | frame = frames[idx] 83 | 84 | # Gets detections from files 85 | cur_good_dets, cur_bad_dets_3d, _ = get_detection_data( 86 | det3d_file=seq_det3d_dir / f"{frame}.txt", 87 | calib=calib, 88 | pose=poses[idx], 89 | img_hw=seq_img_hw_dict[frame], 90 | lidar_dir=crop_dir / seq / frame if use_lidar else None, 91 | det2d_file=det2d_dir / seq / f"{frame}.txt" if use_det2d else None, 92 | seg_file=seg_out_dir / seq / f"{frame}.png" if use_inst else None, 93 | embed_dir=emb_dir / seq / frame if use_embed else None, 94 | min_corr_pts=detection_cfg.getfloat("min_corr_pts"), 95 | 
min_corr_iou=detection_cfg.getfloat("min_corr_iou"), 96 | raw_score=detection_cfg.getboolean("raw_score"), 97 | score_thresh=detection_cfg.getfloat("score_thresh"), 98 | recover_score_thresh=detection_cfg.getfloat("recover_score_thresh"), 99 | ) 100 | 101 | with open(seq_det3d_save_dir / f"{frame}.txt", "w") as f: 102 | lines = [] 103 | if cur_good_dets is not None: 104 | lines = [trk.to_obj().serialize() for trk in cur_good_dets.objs] 105 | f.writelines(lines) 106 | 107 | good_dets_3d[frame] = cur_good_dets 108 | bad_dets_3d[frame] = cur_bad_dets_3d 109 | 110 | with open(seq_det3d_save_dir / "good_dets_3d.pkl", "wb") as f: 111 | pickle.dump(good_dets_3d, f) 112 | 113 | with open(seq_det3d_save_dir / "bad_dets_3d.pkl", "wb") as f: 114 | pickle.dump(bad_dets_3d, f) 115 | 116 | 117 | if __name__ == "__main__": 118 | main() 119 | -------------------------------------------------------------------------------- /kitti_trajectory_refinement.py: -------------------------------------------------------------------------------- 1 | import json 2 | import shutil 3 | from argparse import ArgumentParser 4 | from configparser import ConfigParser 5 | from pathlib import Path 6 | 7 | import tqdm 8 | 9 | from tracking.trajectory_clustering_split_and_recombination import \ 10 | merge_forward_backward_trajectories 11 | from tracking.trajectory_completion import linear_interpolation 12 | from tracking.trajectory_refinement import (box_size_weighted_mean, 13 | gaussian_smoothing) 14 | from utils import (Calibration, read_kitti_trajectories_from_file, 15 | read_seqmap_file, write_kitti_trajectories_to_file) 16 | 17 | 18 | def main(): 19 | parser = ArgumentParser() 20 | parser.add_argument("config", type=str) 21 | parser.add_argument("tag", type=str) 22 | parser.add_argument("split", type=str) 23 | parser.add_argument("forward_tag", type=str) 24 | parser.add_argument("backward_tag", type=str) 25 | args = parser.parse_args() 26 | 27 | config = ConfigParser() 28 | config.read(args.config) 29 | refinement_cfg = config["refinement"] 30 | visualization_cfg = config["visualization"] 31 | 32 | out_dir = Path(f"output/kitti/{args.split}/{args.tag}") 33 | if out_dir.exists(): 34 | shutil.rmtree(out_dir) 35 | out_data_dir = out_dir / "data" 36 | out_data_dir.mkdir(parents=True) 37 | forward_dir = Path(f"output/kitti/{args.split}/{args.forward_tag}") 38 | backward_dir = Path(f"output/kitti/{args.split}/{args.backward_tag}") 39 | forward_data_dir = forward_dir / "data" 40 | backward_data_dir = backward_dir / "data" 41 | 42 | root_dir = Path(config["data"]["root_dir"]) 43 | split_dir = root_dir / ("testing" if "test" in args.split else "training") 44 | calib_dir = split_dir / "calib" 45 | img_hw_dict = json.load(open(split_dir / "img_hw.json")) 46 | 47 | seqmap_file = split_dir / f"evaluate_tracking.seqmap.{args.split}" 48 | frame_num_dict = read_seqmap_file(seqmap_file) 49 | raw_score = config["detection"].getboolean("raw_score") 50 | 51 | for seq in tqdm.tqdm(frame_num_dict): 52 | calib = Calibration(calib_dir / f"{seq}.txt") 53 | cur_img_hw_dict = img_hw_dict[seq] 54 | trajectories = read_kitti_trajectories_from_file( 55 | seq, forward_data_dir, calib, cur_img_hw_dict, raw_score=raw_score 56 | ) 57 | if refinement_cfg.getboolean("merge"): 58 | backward_trajectories = read_kitti_trajectories_from_file( 59 | seq, backward_data_dir, calib, cur_img_hw_dict, raw_score=raw_score 60 | ) 61 | trajectories = merge_forward_backward_trajectories( 62 | trajectories, 63 | backward_trajectories, 64 | 
visualize_contradictions=visualization_cfg.getboolean("contradiction"), 65 | ) 66 | if refinement_cfg.getboolean("box_size_fusion"): 67 | trajectories = box_size_weighted_mean( 68 | trajectories, 69 | calib, 70 | cur_img_hw_dict, 71 | exponent=refinement_cfg.getfloat("exponent"), 72 | ) 73 | if refinement_cfg.getboolean("interp"): 74 | trajectories = linear_interpolation( 75 | trajectories, 76 | calib, 77 | img_hw_dict=cur_img_hw_dict, 78 | interp_max_interval=refinement_cfg.getint("interp_max_interval"), 79 | score_thresh=refinement_cfg.getfloat("score_thresh"), 80 | ignore_thresh=refinement_cfg.getfloat("ignore_thresh"), 81 | nms_thresh=refinement_cfg.getfloat("nms_thresh"), 82 | visualize=visualization_cfg.getboolean("interpolation"), 83 | ) 84 | if refinement_cfg.getboolean("smooth"): 85 | trajectories = gaussian_smoothing( 86 | trajectories, calib, cur_img_hw_dict, refinement_cfg.getfloat("tau") 87 | ) 88 | write_kitti_trajectories_to_file(seq, trajectories, out_data_dir) 89 | 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /mmdet_kitti_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from argparse import ArgumentParser 3 | from pathlib import Path 4 | 5 | import mmcv 6 | import numpy as np 7 | import torch 8 | import tqdm 9 | from mmdet.apis import inference_detector, init_detector 10 | from mmdet.registry import VISUALIZERS 11 | from PIL import Image 12 | 13 | 14 | def parse_args(): 15 | parser = ArgumentParser() 16 | parser.add_argument("root_dir", help="Root dir") 17 | parser.add_argument("name", help="model name") 18 | parser.add_argument("config", help="Config file") 19 | parser.add_argument("checkpoint", help="Checkpoint file") 20 | parser.add_argument("--device", default="cuda:0", help="Device used for inference") 21 | parser.add_argument("--visualize", action="store_true") 22 | args = parser.parse_args() 23 | return args 24 | 25 | 26 | def main(args): 27 | # build the model from a config file and a checkpoint file 28 | model = init_detector(args.config, args.checkpoint, device=args.device) 29 | root_dir = Path(args.root_dir) 30 | img_dir = root_dir / "image_02" 31 | assert img_dir.exists() 32 | det_out_dir = root_dir / "det2d_out" / args.name 33 | seg_out_dir = root_dir / "seg_out" / args.name 34 | 35 | if args.visualize: 36 | visualizer = VISUALIZERS.build(model.cfg.visualizer) 37 | visualizer.dataset_meta = model.dataset_meta 38 | 39 | for seq in range(21): 40 | seq = str(seq).zfill(4) 41 | pbar = tqdm.tqdm(list((img_dir / seq).iterdir())) 42 | pbar.set_description(seq) 43 | for img_file in pbar: 44 | image = mmcv.imread(img_file, channel_order="rgb") 45 | result = inference_detector(model, image) 46 | if args.visualize: 47 | visualizer.add_datasample( 48 | "result", 49 | image, 50 | data_sample=result, 51 | draw_gt=None, 52 | wait_time=0, 53 | ) 54 | visualizer.show() 55 | boxes = result.pred_instances.bboxes.cpu() 56 | labels = result.pred_instances.labels.cpu() 57 | scores = result.pred_instances.scores.cpu() 58 | masks = result.pred_instances.masks.cpu() 59 | 60 | class_mask = torch.isin( 61 | labels, torch.tensor([2, 5, 7]) 62 | ) # cars, buses, trucks 63 | boxes = boxes[class_mask].tolist() 64 | scores = scores[class_mask].tolist() 65 | labels = labels[class_mask].tolist() 66 | masks = masks[class_mask].numpy() 67 | 68 | lines = [] 69 | seg_img = np.zeros(image.shape[:2], 
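# per-pixel instance map: 0 = background, i + 1 = index of the i-th kept detection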
dtype=np.uint8) 70 | for i in range(len(boxes)): 71 | lines.append( 72 | " ".join([str(x) for x in boxes[i]]) + f" {scores[i]} {labels[i]}\n" 73 | ) 74 | seg_img[masks[i]] = i + 1 75 | 76 | seq_det_out_dir = det_out_dir / seq 77 | seq_det_out_dir.mkdir(exist_ok=True, parents=True) 78 | 79 | seq_seg_out_dir = seg_out_dir / seq 80 | seq_seg_out_dir.mkdir(exist_ok=True, parents=True) 81 | 82 | with open(seq_det_out_dir / f"{img_file.stem}.txt", "w") as f: 83 | f.writelines(lines) 84 | 85 | Image.fromarray(seg_img).save(seq_seg_out_dir / f"{img_file.stem}.png") 86 | 87 | 88 | if __name__ == "__main__": 89 | args = parse_args() 90 | main(args) 91 | -------------------------------------------------------------------------------- /mots_tools/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Visual Computing Institute 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /mots_tools/README.md: -------------------------------------------------------------------------------- 1 | # mots_tools 2 | Tools for evaluating and visualizing results for the Multi Object Tracking and Segmentation (MOTS) task. 3 | 4 | For the TrackR-CNN code please visit https://github.com/VisualComputingInstitute/TrackR-CNN 5 | 6 | ## Project website (including annotations) 7 | https://www.vision.rwth-aachen.de/page/mots 8 | 9 | ## Paper 10 | https://www.vision.rwth-aachen.de/media/papers/mots-multi-object-tracking-and-segmentation/MOTS.pdf 11 | 12 | ## Using the mots_tools 13 | Please install the cocotools (https://github.com/cocodataset/cocoapi), which we use with run-length encoded binary masks. If you want to visualize your results using this script, please also install FFmpeg. 14 | 15 | In order to evaluate or visualize the results of your MOTS method, please export them in one of the two formats we use for the ground truth annotations: png or txt (see https://www.vision.rwth-aachen.de/page/mots). When using png, we expect the result images to be in subfolders corresponding to the sequences (e.g. tracking_results/0002/000000.png, tracking_results/0002/000001.png, ...). When using txt, we expect filenames corresponding to the sequences (e.g. tracking_results/0002.txt, tracking_results/0006.txt, ...). 
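For reference, each line of a txt result file describes one object mask as `time_frame id class_id img_height img_width rle`, where the rle field is a pycocotools run-length encoding and the id by convention encodes the class as `class_id * 1000 + instance_id`. A hypothetical car mask in frame 52 of a 375x1242 sequence would look like this (RLE string shortened):
```
52 1005 1 375 1242 WSV:2d;1O10000O10000O1O100O...
```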
16 | 17 | ### Evaluating a tracking result 18 | Clone this repository, navigate to the mots_tools directory and make sure it is in your Python path. 19 | Now suppose your tracking results are located in a folder "tracking_results". Suppose further the ground truth annotations are located in a folder "gt_folder". Then you can evaluate your results using the commands 20 | ``` 21 | python mots_eval/eval.py tracking_results gt_folder seqmap 22 | ``` 23 | where "seqmap" is a textfile containing the sequences which you want to evaluate on. Several seqmaps are already provided in the mots_eval repository: val.seqmap, train.seqmap, fulltrain.seqmap, val_MOTSchallenge.seqmap which correspond to the KITTI MOTS validation set, the KITTI MOTS training set, both KITTI MOTS sets combined and the four annotated MOTSChallenge sequences respectively. 24 | 25 | Parts of the evaluation logic are built upon the KITTI 2D tracking evaluation devkit from http://www.cvlibs.net/datasets/kitti/eval_tracking.php 26 | 27 | ### Visualizing a tracking result 28 | Similarly to evaluating tracking results, you can also create visualizations using 29 | ``` 30 | python mots_vis/visualize_mots.py tracking_results img_folder output_folder seqmap 31 | ``` 32 | where "img_folder" is a folder containing the original KITTI tracking images (http://www.cvlibs.net/download.php?file=data_tracking_image_2.zip) and "output_folder" is a folder where the resulting visualization will be created. 33 | ## Citation 34 | If you use this code, please cite: 35 | ``` 36 | @inproceedings{Voigtlaender19CVPR_MOTS, 37 | author = {Paul Voigtlaender and Michael Krause and Aljo\u{s}a O\u{s}ep and Jonathon Luiten and Berin Balachandar Gnana Sekar and Andreas Geiger and Bastian Leibe}, 38 | title = {{MOTS}: Multi-Object Tracking and Segmentation}, 39 | booktitle = {CVPR}, 40 | year = {2019}, 41 | } 42 | ``` 43 | 44 | ## License 45 | MIT License 46 | 47 | ## Contact 48 | If you find a problem in the code, please open an issue. 
49 | 50 | For general questions, please contact Paul Voigtlaender (voigtlaender@vision.rwth-aachen.de) or Michael Krause (michael.krause@rwth-aachen.de) 51 | -------------------------------------------------------------------------------- /mots_tools/eval.py: -------------------------------------------------------------------------------- 1 | import pycocotools.mask as rletools 2 | import sys 3 | from mots_common.io import load_seqmap, load_sequences 4 | from mots_eval.MOTS_metrics import compute_MOTS_metrics 5 | 6 | 7 | IGNORE_CLASS = 10 8 | 9 | 10 | def mask_iou(a, b, criterion="union"): 11 | is_crowd = criterion != "union" 12 | return rletools.iou([a.mask], [b.mask], [is_crowd])[0][0] 13 | 14 | 15 | def evaluate_class(gt, results, max_frames, class_id): 16 | _, results_obj = compute_MOTS_metrics(gt, results, max_frames, class_id, IGNORE_CLASS, mask_iou) 17 | return results_obj 18 | 19 | 20 | def run_eval(results_folder, gt_folder, seqmap_filename): 21 | seqmap, max_frames = load_seqmap(seqmap_filename) 22 | print("Loading ground truth...") 23 | gt = load_sequences(gt_folder, seqmap) 24 | print("Loading results...") 25 | results = load_sequences(results_folder, seqmap) 26 | print("Compute KITTI tracking eval with simplified matching and MOTSA") 27 | print("Evaluate class: Cars") 28 | results_cars = evaluate_class(gt, results, max_frames, 1) 29 | print("Evaluate class: Pedestrians") 30 | results_ped = evaluate_class(gt, results, max_frames, 2) 31 | 32 | 33 | 34 | if __name__ == "__main__": 35 | if len(sys.argv) != 4: 36 | print("Usage: python eval.py results_folder gt_folder seqmap") 37 | sys.exit(1) 38 | 39 | results_folder = sys.argv[1] 40 | gt_folder = sys.argv[2] 41 | seqmap_filename = sys.argv[3] 42 | 43 | run_eval(results_folder, gt_folder, seqmap_filename) 44 | -------------------------------------------------------------------------------- /mots_tools/mots_common/images_to_txt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from mots_common.io import load_sequences, load_seqmap, write_sequences 3 | 4 | 5 | if __name__ == "__main__": 6 | if len(sys.argv) != 4: 7 | print("Usage: python images_to_txt.py gt_img_folder gt_txt_output_folder seqmap") 8 | sys.exit(1) 9 | 10 | gt_img_folder = sys.argv[1] 11 | gt_txt_output_folder = sys.argv[2] 12 | seqmap_filename = sys.argv[3] 13 | 14 | seqmap, _ = load_seqmap(seqmap_filename) 15 | print("Loading ground truth images...") 16 | gt = load_sequences(gt_img_folder, seqmap) 17 | print("Writing ground truth txts...") 18 | write_sequences(gt, gt_txt_output_folder) 19 | -------------------------------------------------------------------------------- /mots_tools/mots_common/io.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import numpy as np 5 | import PIL.Image as Image 6 | import pycocotools.mask as rletools 7 | 8 | 9 | class SegmentedObject: 10 | def __init__(self, mask, class_id, track_id): 11 | self.mask = mask 12 | self.class_id = class_id 13 | self.track_id = track_id 14 | 15 | 16 | def load_sequences(path, seqmap): 17 | objects_per_frame_per_sequence = {} 18 | for seq in seqmap: 19 | seq_path_folder = os.path.join(path, seq) 20 | seq_path_txt = os.path.join(path, seq + ".txt") 21 | if os.path.isdir(seq_path_folder): 22 | objects_per_frame_per_sequence[seq] = load_images_for_folder(seq_path_folder) 23 | elif os.path.exists(seq_path_txt): 24 | objects_per_frame_per_sequence[seq] = load_txt(seq_path_txt) 25 | 
else: 26 | assert False, "Can't find data in directory " + path 27 | 28 | return objects_per_frame_per_sequence 29 | 30 | 31 | def load_txt(path): 32 | objects_per_frame = {} 33 | track_ids_per_frame = {} # To check that no frame contains two objects with same id 34 | combined_mask_per_frame = {} # To check that no frame contains overlapping masks 35 | with open(path, "r") as f: 36 | for line in f: 37 | line = line.strip() 38 | fields = line.split(" ") 39 | 40 | frame = int(fields[0]) 41 | if frame not in objects_per_frame: 42 | objects_per_frame[frame] = [] 43 | if frame not in track_ids_per_frame: 44 | track_ids_per_frame[frame] = set() 45 | if int(fields[1]) in track_ids_per_frame[frame]: 46 | assert False, "Multiple objects with track id " + fields[1] + " in frame " + fields[0] 47 | else: 48 | track_ids_per_frame[frame].add(int(fields[1])) 49 | 50 | class_id = int(fields[2]) 51 | if not(class_id == 1 or class_id == 2 or class_id == 10): 52 | assert False, "Unknown object class " + fields[2] 53 | 54 | mask = {'size': [int(fields[3]), int(fields[4])], 'counts': fields[5].encode(encoding='UTF-8')} 55 | if frame not in combined_mask_per_frame: 56 | combined_mask_per_frame[frame] = mask 57 | elif rletools.area(rletools.merge([combined_mask_per_frame[frame], mask], intersect=True)) > 0.0: 58 | assert False, "Objects with overlapping masks in frame " + fields[0] 59 | else: 60 | combined_mask_per_frame[frame] = rletools.merge([combined_mask_per_frame[frame], mask], intersect=False) 61 | objects_per_frame[frame].append(SegmentedObject( 62 | mask, 63 | class_id, 64 | int(fields[1]) 65 | )) 66 | 67 | return objects_per_frame 68 | 69 | 70 | def load_images_for_folder(path): 71 | files = sorted(glob.glob(os.path.join(path, "*.png"))) 72 | 73 | objects_per_frame = {} 74 | for file in files: 75 | objects = load_image(file) 76 | frame = filename_to_frame_nr(os.path.basename(file)) 77 | objects_per_frame[frame] = objects 78 | 79 | return objects_per_frame 80 | 81 | 82 | def filename_to_frame_nr(filename): 83 | assert len(filename) == 10, "Expect filenames to have format 000000.png, 000001.png, ..." 
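# e.g. "000123.png" -> 123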
84 | return int(filename.split('.')[0]) 85 | 86 | 87 | def load_image(filename, id_divisor=1000): 88 | img = np.array(Image.open(filename)) 89 | obj_ids = np.unique(img) 90 | 91 | objects = [] 92 | mask = np.zeros(img.shape, dtype=np.uint8, order="F") # Fortran order needed for pycocos RLE tools 93 | for idx, obj_id in enumerate(obj_ids): 94 | if obj_id == 0: # background 95 | continue 96 | mask.fill(0) 97 | pixels_of_elem = np.where(img == obj_id) 98 | mask[pixels_of_elem] = 1 99 | objects.append(SegmentedObject( 100 | rletools.encode(mask), 101 | obj_id // id_divisor, 102 | obj_id 103 | )) 104 | 105 | return objects 106 | 107 | 108 | def load_seqmap(seqmap_filename): 109 | print("Loading seqmap...") 110 | seqmap = [] 111 | max_frames = {} 112 | with open(seqmap_filename, "r") as fh: 113 | for i, l in enumerate(fh): 114 | fields = l.split(" ") 115 | seq = "%04d" % int(fields[0]) 116 | seqmap.append(seq) 117 | max_frames[seq] = int(fields[3]) 118 | return seqmap, max_frames 119 | 120 | 121 | def write_sequences(gt, output_folder): 122 | os.makedirs(output_folder, exist_ok=True) 123 | for seq, seq_frames in gt.items(): 124 | write_sequence(seq_frames, os.path.join(output_folder, seq + ".txt")) 125 | return 126 | 127 | 128 | def write_sequence(frames, path): 129 | with open(path, "w") as f: 130 | for t, objects in frames.items(): 131 | for obj in objects: 132 | print(t, obj.track_id, obj.class_id, obj.mask["size"][0], obj.mask["size"][1], 133 | obj.mask["counts"].decode(encoding='UTF-8'), file=f) 134 | -------------------------------------------------------------------------------- /mots_tools/test.seqmap: -------------------------------------------------------------------------------- 1 | 0000 empty 000000 000464 2 | 0001 empty 000000 000146 3 | 0002 empty 000000 000242 4 | 0003 empty 000000 000256 5 | 0004 empty 000000 000420 6 | 0005 empty 000000 000808 7 | 0006 empty 000000 000113 8 | 0007 empty 000000 000214 9 | 0008 empty 000000 000164 10 | 0009 empty 000000 000348 11 | 0010 empty 000000 001175 12 | 0011 empty 000000 000773 13 | 0012 empty 000000 000693 14 | 0013 empty 000000 000151 15 | 0014 empty 000000 000849 16 | 0015 empty 000000 000700 17 | 0016 empty 000000 000509 18 | 0017 empty 000000 000304 19 | 0018 empty 000000 000179 20 | 0019 empty 000000 000403 21 | 0020 empty 000000 000172 22 | 0021 empty 000000 000202 23 | 0022 empty 000000 000435 24 | 0023 empty 000000 000429 25 | 0024 empty 000000 000315 26 | 0025 empty 000000 000175 27 | 0026 empty 000000 000169 28 | 0027 empty 000000 000084 29 | 0028 empty 000000 000174 30 | -------------------------------------------------------------------------------- /mots_tools/training.seqmap: -------------------------------------------------------------------------------- 1 | 0000 empty 000000 000154 2 | 0001 empty 000000 000447 3 | 0002 empty 000000 000233 4 | 0003 empty 000000 000144 5 | 0004 empty 000000 000314 6 | 0005 empty 000000 000297 7 | 0006 empty 000000 000270 8 | 0007 empty 000000 000800 9 | 0008 empty 000000 000390 10 | 0009 empty 000000 000803 11 | 0010 empty 000000 000294 12 | 0011 empty 000000 000373 13 | 0012 empty 000000 000078 14 | 0013 empty 000000 000340 15 | 0014 empty 000000 000106 16 | 0015 empty 000000 000376 17 | 0016 empty 000000 000209 18 | 0017 empty 000000 000145 19 | 0018 empty 000000 000339 20 | 0019 empty 000000 001059 21 | 0020 empty 000000 000837 22 | -------------------------------------------------------------------------------- /mots_tools/val.seqmap: 
-------------------------------------------------------------------------------- 1 | 0002 empty 000000 000233 2 | 0006 empty 000000 000270 3 | 0007 empty 000000 000800 4 | 0008 empty 000000 000390 5 | 0010 empty 000000 000294 6 | 0013 empty 000000 000340 7 | 0014 empty 000000 000106 8 | 0016 empty 000000 000209 9 | 0018 empty 000000 000339 10 | -------------------------------------------------------------------------------- /scripts/eval_kitti_mots.ps1: -------------------------------------------------------------------------------- 1 | python ./TrackEval/scripts/run_kitti_mots.py --TIME_PROGRESS False --PRINT_CONFIG False --GT_FOLDER ./data/mots --TRACKERS_FOLDER ./output/kitti --CLASSES_TO_EVAL car --TRACKERS_TO_EVAL $args[0] --SPLIT_TO_EVAL $args[1]; 2 | python ./mots_tools/eval.py ./output/kitti/$($args[0])/data ./data/mots/label_02 ./mots_tools/$($args[1]).seqmap -------------------------------------------------------------------------------- /scripts/eval_kitti_tracking.ps1: -------------------------------------------------------------------------------- 1 | python ./TrackEval/scripts/run_kitti.py --TIME_PROGRESS False --PRINT_CONFIG False --GT_FOLDER ./data/kitti/tracking/training --TRACKERS_FOLDER ./output/kitti/$($args[1]) --CLASSES_TO_EVAL car --TRACKERS_TO_EVAL $args[0] --SPLIT_TO_EVAL $args[1] -------------------------------------------------------------------------------- /scripts/run_and_eval_kitti_backward.ps1: -------------------------------------------------------------------------------- 1 | python ./kitti_3d_tracking.py $args[0] $args[1] $args[2] --backward; 2 | python ./TrackEval/scripts/run_kitti.py --TIME_PROGRESS False --PRINT_CONFIG False --GT_FOLDER ./data/kitti/tracking/training --TRACKERS_FOLDER ./output/kitti/$($args[2]) --CLASSES_TO_EVAL car --TRACKERS_TO_EVAL $args[1] --SPLIT_TO_EVAL $args[2] -------------------------------------------------------------------------------- /scripts/run_and_eval_kitti_forward.ps1: -------------------------------------------------------------------------------- 1 | python ./kitti_3d_tracking.py $args[0] $args[1] $args[2]; 2 | python ./TrackEval/scripts/run_kitti.py --TIME_PROGRESS False --PRINT_CONFIG False --GT_FOLDER ./data/kitti/tracking/training --TRACKERS_FOLDER ./output/kitti/$($args[2]) --CLASSES_TO_EVAL car --TRACKERS_TO_EVAL $args[1] --SPLIT_TO_EVAL $args[2] -------------------------------------------------------------------------------- /scripts/run_and_eval_kitti_merge.ps1: -------------------------------------------------------------------------------- 1 | python ./kitti_trajectory_refinement.py $args[0] $args[1] $args[2] $args[3] $args[4]; 2 | python ./TrackEval/scripts/run_kitti.py --TIME_PROGRESS False --PRINT_CONFIG False --GT_FOLDER ./data/kitti/tracking/training --TRACKERS_FOLDER ./output/kitti/$($args[2]) --CLASSES_TO_EVAL car --TRACKERS_TO_EVAL $args[1] --SPLIT_TO_EVAL $args[2] -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import setup 4 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 5 | 6 | 7 | def make_cuda_ext(name, module, sources): 8 | cuda_ext = CUDAExtension( 9 | name="%s.%s" % (module, name), 10 | sources=[os.path.join(*module.split("."), src) for src in sources], 11 | ) 12 | return cuda_ext 13 | 14 | 15 | if __name__ == "__main__": 16 | setup( 17 | name="BiTrack", 18 | cmdclass={ 19 | "build_ext": BuildExtension, 20 
| }, 21 | ext_modules=[ 22 | make_cuda_ext( 23 | name="iou3d_nms_cuda", 24 | module="detection.voxel_rcnn.iou3d_nms", 25 | sources=[ 26 | "src/iou3d_nms_api.cpp", 27 | "src/iou3d_nms.cpp", 28 | "src/iou3d_nms_kernel.cu", 29 | ], 30 | ), 31 | make_cuda_ext( 32 | name="pointnet2_stack_cuda", 33 | module="detection.voxel_rcnn.pointnet2_stack", 34 | sources=[ 35 | "src/pointnet2_api.cpp", 36 | "src/ball_query.cpp", 37 | "src/ball_query_gpu.cu", 38 | "src/group_points.cpp", 39 | "src/group_points_gpu.cu", 40 | "src/sampling.cpp", 41 | "src/sampling_gpu.cu", 42 | "src/interpolate.cpp", 43 | "src/interpolate_gpu.cu", 44 | "src/voxel_query.cpp", 45 | "src/voxel_query_gpu.cu", 46 | "src/vector_pool.cpp", 47 | "src/vector_pool_gpu.cu", 48 | ], 49 | ), 50 | ], 51 | ) 52 | -------------------------------------------------------------------------------- /tracking/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kemo-Huang/BiTrack/055cd6c1252adae1be4bcb9016a15d08031723e0/tracking/__init__.py -------------------------------------------------------------------------------- /tracking/association.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ortools.linear_solver import pywraplp 3 | from scipy.optimize import linear_sum_assignment 4 | 5 | 6 | class Matcher: 7 | def __init__(self, algorithm="HA"): 8 | self.algorithm = algorithm 9 | if algorithm == "HA": 10 | # Hungarian algorithm 11 | self.solver = None 12 | elif algorithm == "MCF": 13 | # min-cost flow 14 | self.solver = pywraplp.Solver.CreateSolver("SCIP") 15 | else: 16 | raise NotImplementedError 17 | 18 | def match( 19 | self, 20 | aff_matrix, 21 | aff_thresh=0, 22 | det_scores=None, 23 | trk_scores=None, 24 | entry_scores=None, 25 | exit_scores=None, 26 | post_valid_mask=None, 27 | unused=None, 28 | ): 29 | num_dets, num_trks = aff_matrix.shape 30 | 31 | if self.algorithm == "HA": 32 | row_ind, col_ind = linear_sum_assignment(aff_matrix, maximize=True) 33 | valid_mask = aff_matrix[row_ind, col_ind] >= aff_thresh 34 | if post_valid_mask is not None: 35 | valid_mask &= post_valid_mask[row_ind, col_ind] 36 | row_ind = row_ind[valid_mask] 37 | col_ind = col_ind[valid_mask] 38 | match_matrix = np.zeros((num_dets, num_trks), dtype=bool) 39 | match_matrix[row_ind, col_ind] = True 40 | entry_vec = np.logical_not(np.any(match_matrix, axis=1)) 41 | exit_vec = np.logical_not(np.any(match_matrix, axis=0)) 42 | 43 | false_det_vec = np.zeros(num_dets, dtype=bool) 44 | false_trk_vec = np.zeros(num_trks, dtype=bool) 45 | 46 | elif self.algorithm == "MCF": 47 | assert len(entry_scores) == num_dets 48 | assert len(exit_scores) == num_trks 49 | 50 | self.solver.Clear() 51 | 52 | # Variables 53 | y_det_cls = [self.solver.BoolVar(f"y_det_cls_{i}") for i in range(num_dets)] 54 | y_trk_cls = [self.solver.BoolVar(f"y_trk_cls_{i}") for i in range(num_trks)] 55 | y_entry = [self.solver.BoolVar(f"y_entry_{i}") for i in range(num_dets)] 56 | y_exit = [self.solver.BoolVar(f"y_exit_{i}") for i in range(num_trks)] 57 | y_link = [ 58 | [self.solver.BoolVar(f"y_link_{i}_{j}") for j in range(num_trks)] 59 | for i in range(num_dets) 60 | ] 61 | 62 | # Constraints 63 | # det = link + entry 64 | for i in range(num_dets): 65 | self.solver.Add( 66 | self.solver.Sum( 67 | [-y_det_cls[i]] 68 | + [y_link[i][j] for j in range(num_trks)] 69 | + [y_entry[i]] 70 | ) 71 | == 0 72 | ) 73 | # trk = link + exit 74 | for j in range(num_trks): 75 |
self.solver.Add( 76 | self.solver.Sum( 77 | [-y_trk_cls[j]] 78 | + [y_link[i][j] for i in range(num_dets)] 79 | + [y_exit[j]] 80 | ) 81 | == 0 82 | ) 83 | 84 | # Objective 85 | w_det_cls = [(det_scores[i] - 1) * y_det_cls[i] for i in range(num_dets)] 86 | w_trk_cls = [(trk_scores[i] - 1) * y_trk_cls[i] for i in range(num_trks)] 87 | w_entry = [entry_scores[i] * y_entry[i] for i in range(num_dets)] 88 | w_exit = [exit_scores[i] * y_exit[i] for i in range(num_trks)] 89 | w_link = [ 90 | [aff_matrix[i, j] * y_link[i][j] for j in range(num_trks)] 91 | for i in range(num_dets) 92 | ] 93 | 94 | self.solver.Maximize( 95 | self.solver.Sum( 96 | w_det_cls + w_trk_cls + w_entry + w_exit + sum(w_link, []) 97 | ) 98 | ) 99 | 100 | self.solver.Solve() 101 | 102 | match_matrix = np.array( 103 | [ 104 | [y_link[i][j].solution_value() for j in range(num_trks)] 105 | for i in range(num_dets) 106 | ], 107 | dtype=bool, 108 | ) 109 | entry_vec = np.array([y.solution_value() for y in y_entry], dtype=bool) 110 | exit_vec = np.array([y.solution_value() for y in y_exit], dtype=bool) 111 | 112 | if aff_thresh > 0: 113 | invalid_mask = np.logical_or(aff_matrix < aff_thresh, ~post_valid_mask) 114 | match_matrix[invalid_mask] = False 115 | 116 | false_det_vec = np.logical_not( 117 | np.logical_or(entry_vec, np.any(match_matrix, axis=1)) 118 | ) 119 | false_trk_vec = np.logical_not( 120 | np.logical_or(exit_vec, np.any(match_matrix, axis=0)) 121 | ) 122 | 123 | return match_matrix, entry_vec, exit_vec, false_det_vec, false_trk_vec 124 | -------------------------------------------------------------------------------- /tracking/detections/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .detections import Detections 3 | -------------------------------------------------------------------------------- /tracking/detections/detections.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Detections: 5 | def __init__(self, boxes3d: np.ndarray, objs, similarity, embeds, coor_2d_inds): 6 | self.boxes3d = boxes3d 7 | self.objs = objs 8 | self.similarity = similarity 9 | self.embeds = embeds 10 | self.corr_2d_inds = coor_2d_inds 11 | if boxes3d is not None: 12 | assert len(boxes3d) == len(objs) 13 | if similarity is not None: 14 | assert len(similarity) == len(boxes3d) 15 | if embeds is not None: 16 | assert len(embeds) == len(coor_2d_inds) 17 | 18 | def __len__(self): 19 | if self.boxes3d is not None: 20 | return len(self.boxes3d) 21 | else: 22 | return len(self.corr_2d_inds) 23 | 24 | def append_3d(self, b): 25 | self.boxes3d = np.concatenate((self.boxes3d, b.boxes3d)) 26 | a_objs = self.objs if isinstance(self.objs, np.ndarray) else np.array(self.objs) 27 | b_objs = b.objs if isinstance(b.objs, np.ndarray) else np.array(b.objs) 28 | self.objs = np.concatenate((a_objs, b_objs)) 29 | 30 | assert len(self.boxes3d) == len(self.objs) 31 | 32 | if self.similarity is not None and b.similarity is not None: 33 | self.similarity = np.concatenate((self.similarity, b.similarity)) 34 | if self.embeds is not None and b.embeds is not None: 35 | self.embeds = np.concatenate((self.embeds, b.embeds)) 36 | if self.corr_2d_inds is not None and b.corr_2d_inds is not None: 37 | self.corr_2d_inds = np.concatenate((self.corr_2d_inds, b.corr_2d_inds)) 38 | 39 | def delete_2d(self, b): 40 | remain_inds = np.setdiff1d(self.corr_2d_inds, b.corr_2d_inds) 41 | remain_mask = np.isin(self.corr_2d_inds, remain_inds) 42 | 
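# keep only the surviving 2D indices and drop their embedding rows so corr_2d_inds and embeds stay aligned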
self.corr_2d_inds = remain_inds 43 | if self.embeds is not None: 44 | self.embeds = self.embeds[remain_mask] 45 | -------------------------------------------------------------------------------- /tracking/motion_filters.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from filterpy.kalman import KalmanFilter 3 | 4 | from utils import angle_in_range 5 | 6 | 7 | class CVKalmanFilter: 8 | def __init__(self, box, p=10, q=2, r=1, ang_vel=True, vel_reinit=True): 9 | """Constant velocity (CV)-based Kalman filter. 10 | Args: 11 | box (np.ndarray): [x, y, z, dx, dy, dz, heading] 12 | """ 13 | assert len(box) == 7 14 | box[6] = angle_in_range(box[6]) 15 | 16 | self.ang_vel = ang_vel 17 | if ang_vel: 18 | # dim_x: [x, y, z, dx, dy, dz, heading, vx, vy, vz, vr] 19 | self.kf = KalmanFilter(dim_x=11, dim_z=7) 20 | self.kf.F = np.array( 21 | [ # state transition matrix 22 | [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], 23 | [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], 24 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0], 25 | [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], 26 | [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], 27 | [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], 28 | [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1], 29 | [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], 30 | [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], 31 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], 32 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 33 | ] 34 | ) 35 | 36 | else: 37 | # dim_x: [x, y, z, dx, dy, dz, heading, vx, vy, vz] 38 | self.kf = KalmanFilter(dim_x=10, dim_z=7) 39 | self.kf.F = np.array( 40 | [ # state transition matrix 41 | [1, 0, 0, 0, 0, 0, 0, 1, 0, 0], 42 | [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], 43 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 1], 44 | [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], 45 | [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], 46 | [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], 47 | [0, 0, 0, 0, 0, 0, 1, 0, 0, 0], 48 | [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], 49 | [0, 0, 0, 0, 0, 0, 0, 0, 1, 0], 50 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 51 | ] 52 | ) 53 | 54 | self.kf.x[:7] = box[:, np.newaxis] 55 | self.kf.H[:, :7] = np.eye(7) 56 | 57 | self.kf.P *= p # initial state uncertainty 58 | self.kf.Q *= q # process uncertainty 59 | self.kf.R *= r # measurement uncertainty 60 | 61 | self.vel_initialized = False 62 | self.vel_reinit = vel_reinit 63 | 64 | def update(self, box): 65 | """ 66 | Args: 67 | box (np.ndarray): [x, y, z, dx, dy, dz, heading] 68 | """ 69 | box[6] = angle_in_range(box[6]) 70 | if abs(box[6] - self.kf.x[6]) > np.pi: 71 | if box[6] > self.kf.x[6]: 72 | box[6] -= 2 * np.pi 73 | else: 74 | box[6] += 2 * np.pi 75 | if abs(box[6] - self.kf.x[6]) > np.pi / 2: 76 | if box[6] > self.kf.x[6]: 77 | box[6] -= np.pi 78 | else: 79 | box[6] += np.pi 80 | 81 | if self.vel_reinit and not self.vel_initialized: 82 | self.kf.x[7:10, 0] = box[:3] - self.kf.x[:3, 0] 83 | if self.ang_vel: 84 | self.kf.x[10, 0] = box[6] - self.kf.x[6, 0] 85 | self.kf.x[:7, 0] = box 86 | self.vel_initialized = True 87 | 88 | self.kf.update(box) 89 | self.kf.x[6] = angle_in_range(self.kf.x[6]) 90 | 91 | def predict(self, t=1) -> np.ndarray: 92 | """ 93 | Advances the state vectors and returns the predicted bounding box estimate.
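Args: t (int): number of time steps to advance the filter (must be > 0). Returns: np.ndarray: predicted box [x, y, z, dx, dy, dz, heading].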
94 | """ 95 | assert t > 0 96 | for _ in range(t): 97 | self.kf.predict() 98 | self.kf.x[6] = angle_in_range(self.kf.x[6]) 99 | return self.kf.x[:7, 0] 100 | 101 | @property 102 | def x(self): 103 | return self.kf.x 104 | 105 | 106 | class MAFilter: 107 | def __init__(self, box: np.ndarray): 108 | """Moving Average Filter 109 | 110 | Args: 111 | box (np.ndarray): [x, y, z, dx, dy, dz, heading] 112 | """ 113 | assert len(box) == 7 114 | box[-1] = angle_in_range(box[-1]) 115 | self.x = box 116 | self.v = np.zeros(7) 117 | self.step = 0 118 | 119 | def update(self, box): 120 | box[-1] = angle_in_range(box[-1]) 121 | if self.step == 0: 122 | self.v = box - self.x 123 | else: 124 | self.v = (self.v + box - self.x) / 2 125 | self.v[-1] = angle_in_range(self.v[-1]) 126 | self.x = box 127 | self.step += 1 128 | 129 | def predict(self, t=1): 130 | box = self.x + t * self.v 131 | box[-1] = angle_in_range(box[-1]) 132 | return box 133 | -------------------------------------------------------------------------------- /tracking/track.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .motion_filters import CVKalmanFilter 4 | 5 | 6 | class Track: 7 | global_cur_id = 1 8 | 9 | def __init__( 10 | self, 11 | box: np.ndarray, 12 | obj, 13 | offline=True, 14 | embed=None, 15 | momentum=0.9, 16 | p=10, 17 | q=2, 18 | ang_vel=True, 19 | vel_reinit=True, 20 | ): 21 | self.id = Track.global_cur_id 22 | Track.global_cur_id += 1 23 | self.filter = CVKalmanFilter( 24 | box, p=p, q=q, ang_vel=ang_vel, vel_reinit=vel_reinit 25 | ) 26 | self.embed = embed 27 | self.misses = 0 28 | self.hits = 0 29 | self.offline = offline 30 | self.obj = obj 31 | self.new = True 32 | self.momentum = momentum 33 | if self.offline: 34 | self.max_hits = 0 35 | self.boxes = [box] 36 | self.objs = [obj] 37 | 38 | def update(self, box, obj, embed=None): 39 | self.filter.update(box) 40 | self.obj = obj 41 | if self.offline: 42 | self.boxes.append(box) 43 | self.objs.append(obj) 44 | 45 | if embed is not None: 46 | self.embed = self.momentum * self.embed + (1 - self.momentum) * embed 47 | self.embed /= np.linalg.norm(self.embed) 48 | 49 | def predict(self, t=1): 50 | return self.filter.predict(t) 51 | -------------------------------------------------------------------------------- /tracking/trajectory_refinement.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | import numpy as np 4 | from sklearn.gaussian_process import GaussianProcessRegressor 5 | from sklearn.gaussian_process.kernels import RBF 6 | 7 | from utils import KittiTrack3d, NuscenesObject 8 | 9 | 10 | def box_size_weighted_mean( 11 | trajectories: Dict[int, Tuple[List[np.ndarray], List[KittiTrack3d]]], 12 | calib, 13 | img_hw_dict, 14 | exponent=45, 15 | ): 16 | for trk_id, trajectory in trajectories.items(): 17 | boxes, objs = trajectory 18 | if isinstance(boxes, list): 19 | boxes = np.array(boxes) 20 | scores = np.array([x.tracking_score for x in objs]) 21 | scores = scores**exponent 22 | boxes[:, 3:6] = np.sum( 23 | np.tile(scores[:, np.newaxis], (1, 3)) * boxes[:, 3:6], axis=0 24 | ) / np.sum(scores) 25 | 26 | new_objs = [] 27 | for box, obj in zip(boxes, objs): 28 | new_objs.append( 29 | KittiTrack3d( 30 | sample_id=obj.sample_id, 31 | tracking_id=obj.tracking_id, 32 | img_hw=img_hw_dict[str(obj.sample_id).zfill(6)], 33 | ).from_lidar_box(box, calib, obj.cls_type, obj.tracking_score) 34 | ) 35 | trajectories[trk_id] = 
(boxes, new_objs) 36 | return trajectories 37 | 38 | 39 | def box_size_weighted_mean_nusc( 40 | trajectories: Dict[int, Tuple[List[np.ndarray], List[NuscenesObject]]], 41 | exponent=45, 42 | ): 43 | for trk_id, trajectory in trajectories.items(): 44 | boxes, objs = trajectory 45 | if isinstance(boxes, list): 46 | boxes = np.array(boxes) 47 | scores = np.array([x.tracking_score for x in objs]) 48 | scores = scores**exponent 49 | boxes[:, 3:6] = np.sum( 50 | np.tile(scores[:, np.newaxis], (1, 3)) * boxes[:, 3:6], axis=0 51 | ) / np.sum(scores) 52 | 53 | new_objs = [] 54 | for box, obj in zip(boxes, objs): 55 | new_objs.append( 56 | NuscenesObject(None).from_box( 57 | box, 58 | obj.data["sample_token"], 59 | obj.data["velocity"], 60 | obj.tracking_id, 61 | obj.data["tracking_name"], 62 | obj.tracking_score, 63 | ) 64 | ) 65 | trajectories[trk_id] = (boxes, new_objs) 66 | return trajectories 67 | 68 | 69 | def gaussian_smoothing( 70 | trajectories: Dict[int, Tuple[List[np.ndarray], List[KittiTrack3d]]], 71 | calib, 72 | img_hw_dict, 73 | tau, 74 | ): 75 | for trk_id, trajectory in trajectories.items(): 76 | boxes, objs = trajectory 77 | if isinstance(boxes, list): 78 | boxes = np.array(boxes) 79 | len_scale = np.clip(tau * np.log(tau**3 / len(objs)), tau**-1, tau**2) 80 | gpr = GaussianProcessRegressor(RBF(len_scale, "fixed")) 81 | t = np.array([obj.sample_id for obj in objs])[:, np.newaxis] 82 | boxes[:, 0] = gpr.fit(t, boxes[:, 0:1]).predict(t) 83 | boxes[:, 1] = gpr.fit(t, boxes[:, 1:2]).predict(t) 84 | boxes[:, 2] = gpr.fit(t, boxes[:, 2:3]).predict(t) 85 | # boxes[:, 6] = gpr.fit(t, boxes[:, 6:7]).predict(t) 86 | 87 | new_objs = [] 88 | for box, obj in zip(boxes, objs): 89 | new_objs.append( 90 | KittiTrack3d( 91 | sample_id=obj.sample_id, 92 | tracking_id=obj.tracking_id, 93 | img_hw=img_hw_dict[str(obj.sample_id).zfill(6)], 94 | ).from_lidar_box(box, calib, obj.cls_type, obj.tracking_score) 95 | ) 96 | trajectories[trk_id] = (boxes, new_objs) 97 | 98 | return trajectories 99 | 100 | 101 | def gaussian_smoothing_nusc( 102 | trajectories: Dict[int, Tuple[List[np.ndarray], List[NuscenesObject]]], 103 | tau, 104 | ): 105 | for trk_id, trajectory in trajectories.items(): 106 | boxes, objs = trajectory 107 | if isinstance(boxes, list): 108 | boxes = np.array(boxes) 109 | len_scale = np.clip(tau * np.log(tau**3 / len(objs)), tau**-1, tau**2) 110 | gpr = GaussianProcessRegressor(RBF(len_scale, "fixed")) 111 | t = np.array([obj.sample_id for obj in objs])[:, np.newaxis] 112 | boxes[:, 0] = gpr.fit(t, boxes[:, 0:1]).predict(t) 113 | boxes[:, 1] = gpr.fit(t, boxes[:, 1:2]).predict(t) 114 | boxes[:, 2] = gpr.fit(t, boxes[:, 2:3]).predict(t) 115 | # boxes[:, 6] = gpr.fit(t, boxes[:, 6:7]).predict(t) 116 | 117 | new_objs = [] 118 | for box, obj in zip(boxes, objs): 119 | new_objs.append( 120 | NuscenesObject(None).from_box(box, obj.data["sample_token"], obj.data["velocity"], obj.tracking_id, obj.data["tracking_name"], obj.tracking_score) 121 | ) 122 | trajectories[trk_id] = (boxes, new_objs) 123 | 124 | return trajectories -------------------------------------------------------------------------------- /visualization/visualize_embeddings.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | from PIL import Image 6 | from sklearn.manifold import TSNE 7 | 8 | from utils import Calibration, get_instance_ids_of_boxes_hungarian 9 | 10 | 11 | def main(): 12 | min_n_pts = 3 13 | 14 |
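# collects one appearance embedding per detected box over all sequences; each embedding is labeled with its matched GT track id, offset by 1000 * sequence index to keep ids globally unique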
all_embeds = [] 15 | corres_gt_ids = [] 16 | 17 | for i in range(21): 18 | seq = str(i).zfill(4) 19 | embed_dir = Path(f"data/tracking/embeddings/{seq}") 20 | instance_dir = Path(f"data/tracking/seg_instances/{seq}") 21 | lidar_dir = Path(f"data/tracking/cropped_points/gt/{seq}") 22 | calib = Calibration(f"data/tracking/calib/{seq}.txt") 23 | 24 | for instance_file in sorted(instance_dir.iterdir()): 25 | embed_file = embed_dir / f"{instance_file.stem}.npy" 26 | if embed_file.exists(): 27 | cur_lidar_dir = lidar_dir / instance_file.stem 28 | if cur_lidar_dir.exists(): 29 | gt_ids = [] 30 | inside_points = [] 31 | for lidar_file in cur_lidar_dir.iterdir(): 32 | gt_ids.append(int(lidar_file.stem)) 33 | inside_points.append( 34 | np.fromfile(lidar_file, dtype=np.float32).reshape(-1, 4)[ 35 | :, :3 36 | ] 37 | ) 38 | instance_img = np.array(Image.open(instance_file)) 39 | instance_ids, _ = get_instance_ids_of_boxes_hungarian( 40 | instance_img, inside_points, calib, min_n_pts 41 | ) 42 | embeds = np.load(embed_file) 43 | 44 | good_box_mask = instance_ids > 0 45 | all_embeds.append(embeds[instance_ids[good_box_mask] - 1]) 46 | corres_gt_ids.append( 47 | np.array(gt_ids, dtype=int)[good_box_mask] + 1000 * i 48 | ) 49 | 50 | all_embeds = np.concatenate(all_embeds) 51 | corres_gt_ids = np.concatenate(corres_gt_ids) 52 | print(all_embeds.shape, corres_gt_ids.shape) 53 | 54 | tsne = TSNE(n_components=2, init="pca", learning_rate="auto") 55 | X_new = tsne.fit_transform(all_embeds) 56 | plt.scatter(X_new[:, 0], X_new[:, 1], c=corres_gt_ids) 57 | plt.show() 58 | 59 | 60 | main() 61 | -------------------------------------------------------------------------------- /visualization/visualize_frame_bev.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | from matplotlib import pyplot as plt 6 | 7 | from utils import (Calibration, boxes_to_corners_bev, draw_boxes_bev, 8 | get_lidar_boxes_from_objs, get_objects_from_label, 9 | map_tracks_by_frames) 10 | 11 | 12 | def main(): 13 | parser = ArgumentParser() 14 | parser.add_argument('tag', type=str) 15 | parser.add_argument('split', type=str) 16 | parser.add_argument('seq', type=int) 17 | parser.add_argument('frame', type=int) 18 | args = parser.parse_args() 19 | trk_out_dir = Path(f'output/kitti/{args.split}/{args.tag}/data') 20 | gt_dir = Path(f'data/kitti/tracking/{args.split}/label_02') 21 | calib_dir = Path(f'data/kitti/tracking/{args.split}/calib') 22 | lidar_dir = Path(f'data/kitti/tracking/{args.split}/velodyne') 23 | assert trk_out_dir.exists() 24 | 25 | calib = Calibration(calib_dir / f'{str(args.seq).zfill(4)}.txt') 26 | seq_out_tracks = map_tracks_by_frames( 27 | get_objects_from_label(trk_out_dir / f'{str(args.seq).zfill(4)}.txt', track=True) 28 | ) 29 | out_tracks = seq_out_tracks.get(args.frame, []) 30 | out_boxes_bev = boxes_to_corners_bev(get_lidar_boxes_from_objs(out_tracks, calib)) 31 | out_ids = [x.tracking_id for x in out_tracks] 32 | 33 | seq_gt_tracks = map_tracks_by_frames([ 34 | x for x in get_objects_from_label(gt_dir / f'{str(args.seq).zfill(4)}.txt', track=True) if x.cls_type == 'Car' or x.cls_type == 'Van' 35 | ]) 36 | gt_tracks = seq_gt_tracks.get(args.frame, []) 37 | gt_boxes_bev = boxes_to_corners_bev(get_lidar_boxes_from_objs(gt_tracks, calib)) 38 | gt_ids = [x.tracking_id for x in gt_tracks] 39 | 40 | points = np.fromfile(lidar_dir / str(args.seq).zfill(4) / f'{str(args.frame).zfill(6)}.bin', 
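# KITTI velodyne scans are float32 [x, y, z, intensity]; only x and y are kept for the BEV plot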
dtype=np.float32).reshape(-1, 4)[:, :2] 41 | 42 | fig, ax = plt.subplots() 43 | ax.axis('equal') 44 | 45 | draw_boxes_bev(ax, gt_boxes_bev, 'r', gt_ids) 46 | draw_boxes_bev(ax, out_boxes_bev, 'b', out_ids) 47 | xlim = ax.get_xlim() 48 | ylim = ax.get_ylim() 49 | ax.scatter(points[:,0], points[:,1], s=1, c='gray') 50 | ax.set_xlim(xlim) 51 | ax.set_ylim(ylim) 52 | 53 | plt.show() 54 | 55 | if __name__ == '__main__': 56 | main() -------------------------------------------------------------------------------- /visualization/visualize_trajectories.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | 5 | from utils import (Calibration, get_global_boxes_from_lidar, 6 | get_poses_from_file, read_kitti_trajectories_from_file, 7 | visualize_trajectories) 8 | 9 | 10 | def main(): 11 | seq = "0006" 12 | root_dir = Path("data/kitti/tracking/training") 13 | calib = Calibration(root_dir / f"calib/{seq}.txt") 14 | poses = get_poses_from_file(root_dir / f"oxts/{seq}.txt") 15 | tracks = read_kitti_trajectories_from_file( 16 | seq, Path("output/kitti/training/bitrack/data"), calib 17 | ) 18 | all_boxes = [] 19 | for boxes, objs in tracks.values(): 20 | global_boxes = [] 21 | for box, obj in zip(boxes, objs): 22 | frame = int(obj.sample_id) 23 | if frame > 100: 24 | continue 25 | global_boxes.append( 26 | get_global_boxes_from_lidar(box[np.newaxis], poses[frame])[0] 27 | ) 28 | all_boxes.append(np.array(global_boxes).reshape(-1, 7)) 29 | visualize_trajectories(all_boxes) 30 | 31 | 32 | if __name__ == "__main__": 33 | main() 34 | --------------------------------------------------------------------------------