├── .gitignore ├── LICENSE ├── README.md ├── configs ├── config_library.md ├── nu_configs │ ├── giou.yaml │ └── iou.yaml └── waymo_configs │ ├── pd_kf_giou.yaml │ └── vc_kf_giou.yaml ├── data_loader ├── __init__.py ├── nuscenes_loader.py └── waymo_loader.py ├── docs ├── config.md ├── data_preprocess.md ├── demo.png ├── nuScenes.md ├── output_format.md ├── simpletrack_gif.gif └── waymo.md ├── mot_3d ├── __init__.py ├── association.py ├── data_protos │ ├── __init__.py │ ├── bbox.py │ └── validity.py ├── frame_data.py ├── life │ ├── __init__.py │ └── hit_manager.py ├── mot.py ├── motion_model │ ├── __init__.py │ └── kalman_filter.py ├── preprocessing │ ├── __init__.py │ ├── bbox_coarse_hash.py │ └── nms.py ├── redundancy │ ├── __init__.py │ └── redundancy.py ├── tracklet │ ├── __init__.py │ └── tracklet.py ├── update_info_data.py ├── utils │ ├── __init__.py │ ├── data_utils.py │ └── geometry.py └── visualization │ ├── __init__.py │ └── visualizer2d.py ├── preprocessing ├── nuscenes_data │ ├── detection.py │ ├── ego_pose.py │ ├── gt_info.py │ ├── nuscenes_preprocess.sh │ ├── raw_pc.py │ ├── sensor_calibration.py │ ├── time_stamp.py │ └── token_info.py └── waymo_data │ ├── detection.py │ ├── ego_info.py │ ├── gt_bin_decode.py │ ├── raw_pc.py │ ├── time_stamp.py │ └── waymo_preprocess.sh ├── requirements.txt ├── setup.py └── tools ├── demo.py ├── main_nuscenes.py ├── main_nuscenes_10hz.py ├── main_waymo.py ├── nuscenes_result_creation.py ├── nuscenes_result_creation_10hz.py ├── nuscenes_type_merge.py └── waymo_pred_bin.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/* 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | *.npy 7 | *.zip 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 tusen-ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SimpleTrack: Simple yet Effective 3D Multi-object Tracking 2 | 3 | This is the repository for our paper [SimpleTrack: Understanding and Rethinking 3D Multi-object Tracking](https://arxiv.org/abs/2111.09621). We are still working on writing the documentation and cleaning up the code, but the following parts are sufficient for you to replicate the results in our paper. For more variants of the model, we have already moved all of our code onto the `dev` branch, so please feel free to check it out if you need to dig deeper right away. We will try our best to get everything ready as soon as possible. 4 | 5 | If you find our paper or code useful, please consider citing us with: 6 | ``` 7 | @article{pang2021simpletrack, 8 | title={SimpleTrack: Understanding and Rethinking 3D Multi-object Tracking}, 9 | author={Pang, Ziqi and Li, Zhichao and Wang, Naiyan}, 10 | journal={arXiv preprint arXiv:2111.09621}, 11 | year={2021} 12 | } 13 | ``` 14 | 15 | 16 | 17 | - [ ] Accelerate the code: make the IoU/GIoU computation parallel. 18 | - [ ] Add documentation for the codebase. 19 | 20 | ## Installation 21 | 22 | ### Environment Requirements 23 | 24 | `SimpleTrack` requires `python>=3.6` and the packages listed in `requirements.txt` (install them with `pip install -r requirements.txt`). For the experiments on Waymo Open Dataset, please install the devkit following the instructions at [waymo open dataset devkit](https://github.com/waymo-research/waymo-open-dataset).
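For a clean environment, a minimal setup could look like the sketch below. The conda-based workflow and the devkit package name/version are assumptions on our side; follow the official devkit instructions for the build matching your TensorFlow version.

```bash
# A minimal environment sketch (assumed conda workflow; adjust versions to your system).
conda create -n simpletrack python=3.8 -y
conda activate simpletrack
pip install -r requirements.txt
# The Waymo devkit package below is one example build, not a hard requirement of ours.
pip install waymo-open-dataset-tf-2-4-0
```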
25 | 26 | ### Installation 27 | 28 | We implement the `SimpleTrack` algorithm as a library `mot_3d`. Please run `pip install -e ./` to install it locally. 29 | 30 | ## Demo and API Example 31 | 32 | ### Demo 33 | 34 | We provide a demo based on the first sequence with ID `10203656353524179475_7625_000_7645_000` from the validation set of [Waymo Open Dataset](https://waymo.com/open/) and the detection from [CenterPoint](https://github.com/tianweiy/CenterPoint). (We are thankful and hope that we are not violating any terms here.) 35 | 36 | First, download the [demo_data](https://www.dropbox.com/s/m8vt7t7tqofaoq2/demo_data.zip?dl=0) and extract it locally. It contains the necessary information and is already preprocessed according to [our preprocessing programs](./docs/data_preprocess.md). To run the demo, please run the following command. It provides interactive visualization with `matplotlib.pyplot`, so it is recommended to run this demo locally. 37 | 38 | ```bash 39 | python tools/demo.py \ 40 | --name demo \ 41 | --det_name cp \ 42 | --obj_type vehicle \ 43 | --config_path configs/waymo_configs/vc_kf_giou.yaml \ 44 | --data_folder ./demo_data/ \ 45 | --visualize 46 | ``` 47 | 48 | An example output of the visualization is the following figure. 49 | 50 | 51 | 52 | In the visualization, the red bounding boxes are the output tracking results with their IDs. The blue ones are tracking results that are not output due to low confidence scores. The green ones are the detection bounding boxes with scores. The black ones are the ground truth bounding boxes. 53 | 54 | ### API Example 55 | 56 | The most important function is `tracker.frame_mot()`. An object of `MOTModel` iteratively digests the per-frame information wrapped in `FrameData` and infers the tracking results on each frame. 57 | 58 | ## Inference on Waymo Open Dataset and nuScenes 59 | 60 | Refer to the documentation of [Waymo Open Dataset Inference](./docs/waymo.md) and [nuScenes Inference](./docs/nuScenes.md). **Important: please rigorously follow the config file path and instructions in the documentation to reproduce the results.** 61 | 62 | The detailed metrics and files are at this [Dropbox Link](https://www.dropbox.com/sh/2zcpf7wyho7x21q/AABRtGP75KHghr0wyoVqzFFXa?dl=0): the Waymo Open Dataset results are in this [Dropbox link](https://www.dropbox.com/sh/u6o8dcwmzya04uk/AAAUsNvTt7ubXul9dx5Xnp4xa?dl=0), and the nuScenes results are in this [Dropbox Link](https://www.dropbox.com/sh/8906exnes0u5e89/AAD0xLwW1nq_QiuUBaYDrQVna?dl=0). 63 | 64 | For the metrics on the test sets, please refer to our paper or the leaderboards. 65 | 66 | ## Related Documentations 67 | 68 | To enable better usage of our `mot_3d` library, we provide a list of useful documentation pages, and will add more in the future. 69 | 70 | * [Read and Use the Configurations](./docs/config.md). We explain how to specify the behaviors of trackers in this documentation, such as two-stage association, the thresholds for association, etc. 71 | * [Format of Output](./docs/output_format.md). We explain the output format for the APIs in `SimpleTrack`, so that you may directly use the functions provided.
(in progress) 72 | * Visualization with `mot_3d` (in progress) 73 | * Structure of `mot_3d` (in progress) 74 | -------------------------------------------------------------------------------- /configs/config_library.md: -------------------------------------------------------------------------------- 1 | # Configuration Library -------------------------------------------------------------------------------- /configs/nu_configs/giou.yaml: -------------------------------------------------------------------------------- 1 | running: 2 | covariance: default 3 | score_threshold: 0.01 4 | max_age_since_update: 2 5 | min_hits_to_birth: 0 6 | match_type: bipartite 7 | asso: giou 8 | has_velo: true 9 | nms_thres: 0.1 10 | motion_model: kf 11 | asso_thres: 12 | giou: 1.5 13 | iou: 0.9 14 | 15 | redundancy: 16 | mode: default 17 | det_score_threshold: 18 | iou: 0.01 19 | giou: 0.01 20 | det_dist_threshold: 21 | iou: 0.1 22 | giou: -0.5 23 | 24 | data_loader: 25 | pc: true 26 | nms: true 27 | nms_thres: 0.1 -------------------------------------------------------------------------------- /configs/nu_configs/iou.yaml: -------------------------------------------------------------------------------- 1 | running: 2 | covariance: default 3 | score_threshold: 0.01 4 | max_age_since_update: 2 5 | min_hits_to_birth: 1 6 | match_type: bipartite 7 | asso: iou 8 | has_velo: false 9 | motion_model: kf 10 | asso_thres: 11 | giou: 1.5 12 | iou: 0.9 13 | 14 | redundancy: 15 | mode: default 16 | det_score_threshold: 17 | iou: 0.01 18 | giou: 0.01 19 | m_dis: 0.01 20 | det_dist_threshold: 21 | iou: 0.1 22 | giou: -0.5 23 | 24 | data_loader: 25 | pc: false -------------------------------------------------------------------------------- /configs/waymo_configs/pd_kf_giou.yaml: -------------------------------------------------------------------------------- 1 | running: 2 | covariance: default 3 | score_threshold: 0.5 4 | max_age_since_update: 2 5 | min_hits_to_birth: 3 6 | match_type: bipartite 7 | asso: giou 8 | has_velo: false 9 | motion_model: kf 10 | asso_thres: 11 | iou: 0.9 12 | giou: 1.5 13 | 14 | redundancy: 15 | mode: mm 16 | det_score_threshold: 17 | iou: 0.1 18 | giou: 0.1 19 | det_dist_threshold: 20 | iou: 0.1 21 | giou: -0.5 22 | 23 | data_loader: 24 | pc: true 25 | nms: true 26 | nms_thres: 0.25 -------------------------------------------------------------------------------- /configs/waymo_configs/vc_kf_giou.yaml: -------------------------------------------------------------------------------- 1 | running: 2 | covariance: default 3 | score_threshold: 0.7 4 | max_age_since_update: 2 5 | min_hits_to_birth: 3 6 | match_type: bipartite 7 | asso: giou 8 | has_velo: false 9 | motion_model: kf 10 | asso_thres: 11 | iou: 0.9 12 | giou: 1.5 13 | 14 | redundancy: 15 | mode: mm 16 | det_score_threshold: 17 | iou: 0.1 18 | giou: 0.1 19 | det_dist_threshold: 20 | iou: 0.1 21 | giou: -0.5 22 | 23 | data_loader: 24 | pc: true 25 | nms: true 26 | nms_thres: 0.25 -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | from data_loader.waymo_loader import WaymoLoader 2 | from data_loader.nuscenes_loader import NuScenesLoader, NuScenesLoader10Hz -------------------------------------------------------------------------------- /data_loader/nuscenes_loader.py: -------------------------------------------------------------------------------- 1 | """ Example of data loader: 2 | The data loader has to 
be an iterator: 3 | Return a dict of frame data 4 | Users may create the logic of your own data loader 5 | """ 6 | import os, numpy as np, json 7 | from pyquaternion import Quaternion 8 | from nuscenes.utils.data_classes import Box 9 | from mot_3d.data_protos import BBox 10 | import mot_3d.utils as utils 11 | from mot_3d.preprocessing import nms 12 | 13 | 14 | def transform_matrix(translation: np.ndarray = np.array([0, 0, 0]), 15 | rotation: np.ndarray = np.array([1, 0, 0, 0]), 16 | inverse: bool = False) -> np.ndarray: 17 | tm = np.eye(4) 18 | rotation = Quaternion(rotation) 19 | 20 | if inverse: 21 | rot_inv = rotation.rotation_matrix.T 22 | trans = np.transpose(-np.array(translation)) 23 | tm[:3, :3] = rot_inv 24 | tm[:3, 3] = rot_inv.dot(trans) 25 | else: 26 | tm[:3, :3] = rotation.rotation_matrix 27 | tm[:3, 3] = np.transpose(np.array(translation)) 28 | return tm 29 | 30 | 31 | def nu_array2mot_bbox(b): 32 | nu_box = Box(b[:3], b[3:6], Quaternion(b[6:10])) 33 | mot_bbox = BBox( 34 | x=nu_box.center[0], y=nu_box.center[1], z=nu_box.center[2], 35 | w=nu_box.wlh[0], l=nu_box.wlh[1], h=nu_box.wlh[2], 36 | o=nu_box.orientation.yaw_pitch_roll[0] 37 | ) 38 | if len(b) == 11: 39 | mot_bbox.s = b[-1] 40 | return mot_bbox 41 | 42 | 43 | class NuScenesLoader: 44 | def __init__(self, configs, type_token, segment_name, data_folder, det_data_folder, start_frame): 45 | """ initialize with the path to data 46 | Args: 47 | data_folder (str): root path to your data 48 | """ 49 | self.configs = configs 50 | self.segment = segment_name 51 | self.data_loader = data_folder 52 | self.det_data_folder = det_data_folder 53 | self.type_token = type_token 54 | 55 | self.ts_info = json.load(open(os.path.join(data_folder, 'ts_info', '{:}.json'.format(segment_name)), 'r')) 56 | self.ego_info = np.load(os.path.join(data_folder, 'ego_info', '{:}.npz'.format(segment_name)), 57 | allow_pickle=True) 58 | self.calib_info = np.load(os.path.join(data_folder, 'calib_info', '{:}.npz'.format(segment_name)), 59 | allow_pickle=True) 60 | self.dets = np.load(os.path.join(det_data_folder, 'dets', '{:}.npz'.format(segment_name)), 61 | allow_pickle=True) 62 | self.det_type_filter = True 63 | 64 | self.use_pc = configs['data_loader']['pc'] 65 | if self.use_pc: 66 | self.pcs = np.load(os.path.join(data_folder, 'pc', 'raw_pc', '{:}.npz'.format(segment_name)), 67 | allow_pickle=True) 68 | 69 | self.max_frame = len(self.dets['bboxes']) 70 | self.cur_frame = start_frame 71 | 72 | def __iter__(self): 73 | return self 74 | 75 | def __next__(self): 76 | if self.cur_frame >= self.max_frame: 77 | raise StopIteration 78 | 79 | result = dict() 80 | result['time_stamp'] = self.ts_info[self.cur_frame] * 1e-6 81 | ego = self.ego_info[str(self.cur_frame)] 82 | ego_matrix = transform_matrix(ego[:3], ego[3:]) 83 | result['ego'] = ego_matrix 84 | 85 | bboxes = self.dets['bboxes'][self.cur_frame] 86 | inst_types = self.dets['types'][self.cur_frame] 87 | frame_bboxes = [bboxes[i] for i in range(len(bboxes)) if inst_types[i] in self.type_token] 88 | result['det_types'] = [inst_types[i] for i in range(len(inst_types)) if inst_types[i] in self.type_token] 89 | result['dets'] = [nu_array2mot_bbox(b) for b in frame_bboxes] 90 | result['aux_info'] = dict() 91 | if 'velos' in list(self.dets.keys()): 92 | cur_velos = self.dets['velos'][self.cur_frame] 93 | result['aux_info']['velos'] = [cur_velos[i] for i in range(len(cur_velos)) 94 | if inst_types[i] in self.type_token] 95 | else: 96 | result['aux_info']['velos'] = None 97 | 98 | result['dets'], 
result['det_types'], result['aux_info']['velos'] = \ 99 | self.frame_nms(result['dets'], result['det_types'], result['aux_info']['velos'], 0.1) 100 | result['dets'] = [BBox.bbox2array(d) for d in result['dets']] 101 | 102 | result['pc'] = None 103 | if self.use_pc: 104 | pc = self.pcs[str(self.cur_frame)][:, :3] 105 | calib = self.calib_info[str(self.cur_frame)] 106 | calib_trans, calib_rot = np.asarray(calib[:3]), Quaternion(np.asarray(calib[3:])) 107 | pc = np.dot(pc, calib_rot.rotation_matrix.T) 108 | pc += calib_trans 109 | result['pc'] = utils.pc2world(ego_matrix, pc) 110 | 111 | # if 'velos' in list(self.dets.keys()): 112 | # cur_frame_velos = self.dets['velos'][self.cur_frame] 113 | # result['aux_info']['velos'] = [cur_frame_velos[i] 114 | # for i in range(len(bboxes)) if inst_types[i] in self.type_token] 115 | result['aux_info']['is_key_frame'] = True 116 | 117 | self.cur_frame += 1 118 | return result 119 | 120 | def __len__(self): 121 | return self.max_frame 122 | 123 | def frame_nms(self, dets, det_types, velos, thres): 124 | frame_indexes, frame_types = nms(dets, det_types, thres) 125 | result_dets = [dets[i] for i in frame_indexes] 126 | result_velos = None 127 | if velos is not None: 128 | result_velos = [velos[i] for i in frame_indexes] 129 | return result_dets, frame_types, result_velos 130 | 131 | 132 | class NuScenesLoader10Hz: 133 | def __init__(self, configs, type_token, segment_name, data_folder, det_data_folder, start_frame): 134 | """ initialize with the path to data 135 | Args: 136 | data_folder (str): root path to your data 137 | """ 138 | self.configs = configs 139 | self.segment = segment_name 140 | self.data_loader = data_folder 141 | self.det_data_folder = det_data_folder 142 | self.type_token = type_token 143 | 144 | self.ts_info = json.load(open(os.path.join(data_folder, 'ts_info', '{:}.json'.format(segment_name)), 'r')) 145 | self.time_stamps = [t[0] for t in self.ts_info] 146 | self.is_key_frames = [t[1] for t in self.ts_info] 147 | 148 | self.token_info = json.load(open(os.path.join(data_folder, 'token_info', '{:}.json'.format(segment_name)), 'r')) 149 | self.ego_info = np.load(os.path.join(data_folder, 'ego_info', '{:}.npz'.format(segment_name)), 150 | allow_pickle=True) 151 | self.calib_info = np.load(os.path.join(data_folder, 'calib_info', '{:}.npz'.format(segment_name)), 152 | allow_pickle=True) 153 | self.dets = np.load(os.path.join(det_data_folder, 'dets', '{:}.npz'.format(segment_name)), 154 | allow_pickle=True) 155 | self.det_type_filter = True 156 | 157 | self.use_pc = configs['data_loader']['pc'] 158 | if self.use_pc: 159 | self.pcs = np.load(os.path.join(data_folder, 'pc', 'raw_pc', '{:}.npz'.format(segment_name)), 160 | allow_pickle=True) 161 | 162 | self.max_frame = len(self.dets['bboxes']) 163 | self.selected_frames = [i for i in range(self.max_frame) if self.token_info[i][3]] 164 | self.cur_selected_index = 0 165 | self.cur_frame = start_frame 166 | 167 | def __iter__(self): 168 | return self 169 | 170 | def __next__(self): 171 | if self.cur_selected_index >= len(self.selected_frames): 172 | raise StopIteration 173 | self.cur_frame = self.selected_frames[self.cur_selected_index] 174 | 175 | result = dict() 176 | result['time_stamp'] = self.time_stamps[self.cur_frame] * 1e-6 177 | ego = self.ego_info[str(self.cur_frame)] 178 | ego_matrix = transform_matrix(ego[:3], ego[3:]) 179 | result['ego'] = ego_matrix 180 | 181 | bboxes = self.dets['bboxes'][self.cur_frame] 182 | inst_types = self.dets['types'][self.cur_frame] 183 | frame_bboxes = 
[bboxes[i] for i in range(len(bboxes)) if inst_types[i] in self.type_token] 184 | result['det_types'] = [inst_types[i] for i in range(len(inst_types)) if inst_types[i] in self.type_token] 185 | 186 | result['dets'] = [nu_array2mot_bbox(b) for b in frame_bboxes] 187 | result['aux_info'] = dict() 188 | if 'velos' in list(self.dets.keys()): 189 | cur_velos = self.dets['velos'][self.cur_frame] 190 | result['aux_info']['velos'] = [cur_velos[i] for i in range(len(cur_velos)) 191 | if inst_types[i] in self.type_token] 192 | else: 193 | result['aux_info']['velos'] = None 194 | result['dets'], result['det_types'], result['aux_info']['velos'] = \ 195 | self.frame_nms(result['dets'], result['det_types'], result['aux_info']['velos'], 0.1) 196 | result['dets'] = [BBox.bbox2array(d) for d in result['dets']] 197 | 198 | result['pc'] = None 199 | if self.use_pc: 200 | pc = self.pcs[str(self.cur_frame)][:, :3] 201 | calib = self.calib_info[str(self.cur_frame)] 202 | calib_trans, calib_rot = np.asarray(calib[:3]), Quaternion(np.asarray(calib[3:])) 203 | pc = np.dot(pc, calib_rot.rotation_matrix.T) 204 | pc += calib_trans 205 | result['pc'] = utils.pc2world(ego_matrix, pc) 206 | 207 | # if 'velos' in list(self.dets.keys()): 208 | # cur_frame_velos = self.dets['velos'][self.cur_frame] 209 | # result['aux_info']['velos'] = [cur_frame_velos[i] 210 | # for i in range(len(bboxes)) if inst_types[i] in self.type_token] 211 | # print(result['aux_info']['velos']) 212 | result['aux_info']['is_key_frame'] = self.is_key_frames[self.cur_frame] 213 | 214 | self.cur_selected_index += 1 215 | return result 216 | 217 | def __len__(self): 218 | return len(self.selected_frames) 219 | 220 | def frame_nms(self, dets, det_types, velos, thres): 221 | frame_indexes, frame_types = nms(dets, det_types, thres) 222 | result_dets = [dets[i] for i in frame_indexes] 223 | result_velos = None 224 | if velos is not None: 225 | result_velos = [velos[i] for i in frame_indexes] 226 | return result_dets, frame_types, result_velos 227 | -------------------------------------------------------------------------------- /data_loader/waymo_loader.py: -------------------------------------------------------------------------------- 1 | """ Example of data loader: 2 | The data loader has to be an iterator: 3 | Return a dict of frame data 4 | Users may create the logic of your own data loader 5 | """ 6 | import os, numpy as np, json 7 | import mot_3d.utils as utils 8 | from mot_3d.data_protos import BBox 9 | from mot_3d.preprocessing import nms 10 | 11 | 12 | class WaymoLoader: 13 | def __init__(self, configs, type_token, segment_name, data_folder, det_data_folder, start_frame): 14 | """ initialize with the path to data 15 | Args: 16 | data_folder (str): root path to your data 17 | """ 18 | self.configs = configs 19 | self.segment = segment_name 20 | self.data_loader = data_folder 21 | self.det_data_folder = det_data_folder 22 | self.type_token = type_token 23 | 24 | self.nms = configs['data_loader']['nms'] 25 | self.nms_thres = configs['data_loader']['nms_thres'] 26 | 27 | self.ts_info = json.load(open(os.path.join(data_folder, 'ts_info', '{:}.json'.format(segment_name)), 'r')) 28 | self.ego_info = np.load(os.path.join(data_folder, 'ego_info', '{:}.npz'.format(segment_name)), 29 | allow_pickle=True) 30 | self.dets = np.load(os.path.join(det_data_folder, 'dets', '{:}.npz'.format(segment_name)), 31 | allow_pickle=True) 32 | self.det_type_filter = True 33 | 34 | self.use_pc = configs['data_loader']['pc'] 35 | if self.use_pc: 36 | self.pcs = 
np.load(os.path.join(data_folder, 'pc', 'raw_pc', '{:}.npz'.format(segment_name)), 37 | allow_pickle=True) 38 | 39 | self.max_frame = len(self.dets['bboxes']) 40 | self.cur_frame = start_frame 41 | 42 | def __iter__(self): 43 | return self 44 | 45 | def __next__(self): 46 | if self.cur_frame >= self.max_frame: 47 | raise StopIteration 48 | 49 | result = dict() 50 | result['time_stamp'] = self.ts_info[self.cur_frame] * 1e-6 51 | result['ego'] = self.ego_info[str(self.cur_frame)] 52 | 53 | bboxes = self.dets['bboxes'][self.cur_frame] 54 | inst_types = self.dets['types'][self.cur_frame] 55 | selected_dets = [bboxes[i] for i in range(len(bboxes)) if inst_types[i] in self.type_token] 56 | result['det_types'] = [inst_types[i] for i in range(len(bboxes)) if inst_types[i] in self.type_token] 57 | result['dets'] = [BBox.bbox2world(result['ego'], BBox.array2bbox(b)) 58 | for b in selected_dets] 59 | 60 | result['pc'] = None 61 | if self.use_pc: 62 | pc = self.pcs[str(self.cur_frame)] 63 | result['pc'] = utils.pc2world(result['ego'], pc) 64 | 65 | result['aux_info'] = {'is_key_frame': True} 66 | if 'velos' in self.dets.keys(): 67 | cur_frame_velos = self.dets['velos'][self.cur_frame] 68 | result['aux_info']['velos'] = [np.array(cur_frame_velos[i]) 69 | for i in range(len(bboxes)) if inst_types[i] in self.type_token] 70 | result['aux_info']['velos'] = [utils.velo2world(result['ego'], v) 71 | for v in result['aux_info']['velos']] 72 | else: 73 | result['aux_info']['velos'] = None 74 | 75 | if self.nms: 76 | result['dets'], result['det_types'], result['aux_info']['velos'] = \ 77 | self.frame_nms(result['dets'], result['det_types'], result['aux_info']['velos'], self.nms_thres) 78 | result['dets'] = [BBox.bbox2array(d) for d in result['dets']] 79 | 80 | self.cur_frame += 1 81 | return result 82 | 83 | def __len__(self): 84 | return self.max_frame 85 | 86 | def frame_nms(self, dets, det_types, velos, thres): 87 | frame_indexes, frame_types = nms(dets, det_types, thres) 88 | result_dets = [dets[i] for i in frame_indexes] 89 | result_velos = None 90 | if velos is not None: 91 | result_velos = [velos[i] for i in frame_indexes] 92 | return result_dets, frame_types, result_velos 93 | -------------------------------------------------------------------------------- /docs/config.md: -------------------------------------------------------------------------------- 1 | # Configurations 2 | 3 | This documentation provides a brief guide for writing and reading the configuration files. We specify the tracker behaviors with a single `.yaml` file, as in the folder `./configs/`. 4 | 5 | ## Explanation with the Example from Waymo Open Dataset 6 | 7 | Taking the configs for Waymo Open Dataset (vehicles) as an example, we explain `configs/waymo_configs/vc_kf_giou.yaml`. 8 | 9 | We provide a thorough annotation below, but it may be confusing if you dive in directly, so the following are the key points for specifying the model. 10 | 11 | * **Data Preprocessing** is controlled by the `nms` and `nms_thres` arguments in the `data_loader` field. They specify whether to use NMS and the IoU threshold for NMS. 12 | 13 | * **Association Metric** is specified in `running-asso`, where you can choose from IoU and GIoU for now. To specify the threshold for successful association, we treat `1-IoU` or `1-GIoU` as the distance, and it must be smaller than the corresponding numbers in `asso_thres`. If you are interested in L2 distance or Mahalanobis distance, please refer to the `dev` branch.
We are still cleaning it up and adding documentation. 14 | * **Motion Model** is by default the Kalman filter. We are adding more options, such as a velocity model. Please refer to the `dev` branch if you wish to explore on your own. 15 | * **Two-stage Association** is specified by the `redundancy` module. `mm` denotes two-stage and `default` denotes the conventional single-stage association. 16 | * NOTE on **nuScenes 10Hz two-stage association**. We are trying to incorporate this part into the config file instead of hard-coding everything. Please temporarily go to the function `non_key_frame_mot` in `mot_3d/mot.py` for how we deal with the tracking on non-key frames. 17 | 18 | 19 | ```yaml 20 | running: 21 | covariance: default # not used 22 | score_threshold: 0.7 # detection score threshold for first-stage association and output 23 | max_age_since_update: 2 # count for death, same as AB3DMOT 24 | min_hits_to_birth: 3 # count for birth, same as AB3DMOT 25 | match_type: bipartite # use the Hungarian algorithm (bipartite) or the greedy algorithm (greedy) for association 26 | asso: giou # association metric, we support GIoU (giou) and IoU (iou) for now 27 | has_velo: false 28 | motion_model: kf # Kalman filter (kf) as motion model 29 | asso_thres: 30 | iou: 0.9 # association threshold, 1 - IoU has to be smaller than it. 31 | giou: 1.5 # association threshold, 1 - GIoU has to be smaller than it. 32 | 33 | redundancy: 34 | mode: mm # (mm) denotes two-stage association, (default) denotes one-stage association (see nuScenes configs) 35 | det_score_threshold: # detection score threshold for two-stage association 36 | iou: 0.1 37 | giou: 0.1 38 | det_dist_threshold: # association threshold for two-stage association 39 | iou: 0.1 # IoU has to be greater than this 40 | giou: -0.5 # GIoU has to be greater than this 41 | 42 | data_loader: 43 | pc: true # load point clouds for visualization 44 | nms: true # apply NMS for data preprocessing, True or False 45 | nms_thres: 0.25 # IoU-3D threshold for NMS 46 | ``` -------------------------------------------------------------------------------- /docs/data_preprocess.md: -------------------------------------------------------------------------------- 1 | # Data Preprocessing 2 | 3 | ## Waymo Open Dataset 4 | 5 | We decompose the preprocessing of Waymo Open Dataset into the following steps. 6 | 7 | ### 1. Raw Data 8 | 9 | This step converts the information needed from the `tf_records` into more handy forms. Suppose the folder storing the `tf_records` is `raw_data_folder` and the target location is `data_folder`; then run the following command. (You can specify `process_num` to be an integer greater than 1 for faster preprocessing.) 10 | 11 | ```bash 12 | cd preprocessing/waymo_data 13 | bash waymo_preprocess.sh ${raw_data_folder} ${data_folder} ${process_num} 14 | ``` 15 | 16 | ### 2. Ground Truth Information 17 | 18 | The ground truth for 3D MOT and 3D detection is the same. **You have to download a .bin file from Waymo Open Dataset for the ground truth, which we have no right to share according to the license.** 19 | 20 | To decode the ground truth information, suppose `bin_path` is the path to the ground truth file and `data_folder` is the target location of the data preprocessing. Eventually, we store the ground truth information in `${data_folder}/detection/gt/dets/`. 21 | 22 | ```bash 23 | cd preprocessing/waymo_data 24 | python gt_bin_decode.py --data_folder ${data_dir} --file_path ${bin_path} 25 | ``` 26 | 27 | ### 3. Detection
28 | 29 | To infer 3D MOT on your detection file, we need `bin_path`, the path to the detection results; name your detection `name` for future convenience. The preprocessing of the detection follows the script below. (Only use `--metadata` if you want to save the velocity / acceleration contained in the detection file.) 30 | 31 | ```bash 32 | cd preprocessing/waymo_data 33 | python detection.py --name ${name} --data_folder ${data_dir} --file_path ${bin_path} --metadata --id 34 | ``` 35 | 36 | ## nuScenes 37 | 38 | ### 1. Preprocessing 39 | 40 | To preprocess the raw data from nuScenes, suppose you have put the raw data of nuScenes at `raw_data_dir`. We provide two modes of preprocessing: 41 | * Only the data on the key frames (2Hz) is extracted; the target location is `data_dir_2hz`. 42 | * All the data (20Hz) is extracted to the location `data_dir_20hz`. 43 | 44 | Run the following commands. 45 | 46 | ```bash 47 | cd preprocessing/nuscenes_data 48 | bash nuscenes_preprocess.sh ${raw_data_dir} ${data_dir_2hz} ${data_dir_20hz} 49 | ``` 50 | 51 | ### 2. Detection 52 | 53 | To infer 3D MOT on your detection file, we convert the JSON-format detection files at `file_path` into `.npz` files, similar to our approach on Waymo Open Dataset. Please name your detection `name` for future convenience. The preprocessing of the detection follows the scripts below. (Only use `--velo` if you want to save the velocity contained in the detection file.) 54 | 55 | ```bash 56 | cd preprocessing/nuscenes_data 57 | 58 | # for 2Hz detection file 59 | python detection.py --raw_data_folder ${raw_data_dir} --data_folder ${data_dir_2hz} --det_name ${name} --file_path ${file_path} --mode 2hz --velo 60 | 61 | # for 20Hz detection file 62 | python detection.py --raw_data_folder ${raw_data_dir} --data_folder ${data_dir_20hz} --det_name ${name} --file_path ${file_path} --mode 20hz --velo 63 | ``` 64 | -------------------------------------------------------------------------------- /docs/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tusen-ai/SimpleTrack/05c96bb7ed98fc179856f327544612a66c839b5e/docs/demo.png -------------------------------------------------------------------------------- /docs/nuScenes.md: -------------------------------------------------------------------------------- 1 | # nuScenes Inference 2 | 3 | ## Data Preprocessing 4 | 5 | Please follow the commands at the [Data Preprocessing Documentation](./data_preprocess.md). After data preprocessing, the directories to the preprocessed data are `nuscenes2hz_data_dir` and `nuscenes20hz_data_dir` for the key-frame and all-frame data of nuScenes, respectively. 6 | 7 | ## Inference with SimpleTrack 8 | 9 | **Important: Please strictly follow the config file we use.** 10 | 11 | For the setting of inference only on key frames (2Hz), run the following command. The per-sequence results are then saved in `${nuscenes_result_dir}/SimpleTrack2Hz/summary`, with subfolders for the object types containing the results for each type. 12 | 13 | ```bash 14 | python tools/main_nuscenes.py \ 15 | --name SimpleTrack2Hz \ 16 | --det_name ${det_name} \ 17 | --config_path configs/nu_configs/giou.yaml \ 18 | --result_folder ${nuscenes_result_dir} \ 19 | --data_folder ${nuscenes2hz_data_dir} \ 20 | --process ${proc_num} 21 | ``` 22 | 23 | For the 10Hz setting proposed in the paper, run the following commands.
The per-sequence results are then saved in `${nuscenes_result_dir}/SimpleTrack10Hz/summary`. 24 | 25 | ```bash 26 | python tools/main_nuscenes_10hz.py \ 27 | --name SimpleTrack10Hz \ 28 | --det_name ${det_name} \ 29 | --config_path configs/nu_configs/giou.yaml \ 30 | --result_folder ${nuscenes_result_dir} \ 31 | --data_folder ${nuscenes20hz_data_dir} \ 32 | --process ${proc_num} 33 | ``` 34 | 35 | We use 150 processes in our experiments, which is the same as the number of sequences in the nuScenes validation set. 36 | 37 | ## Output Format 38 | 39 | In the folder `${nuscenes_result_dir}/SimpleTrack2Hz/summary/` (and similarly for `SimpleTrack10Hz`), there are sub-folders corresponding to each object type in nuScenes. Inside each sub-folder, there are 150 `.npz` files, matching the 150 sequences in the nuScenes validation set. For the format in each `.npz` file, please refer to [Output Format](output_format.md). 40 | 41 | ## Converting to nuScenes Format 42 | 43 | Use the following commands to convert the output results from the SimpleTrack format into the `.json` format specified by the nuScenes officials. After running them, the tracking results in `.json` format will be in `${nuscenes_result_dir}/SimpleTrack2Hz/results` and `${nuscenes_result_dir}/SimpleTrack10Hz/results`. 44 | 45 | For the 2Hz setting, which runs inference only on the key frames, run the following commands. 46 | 47 | ```bash 48 | python tools/nuscenes_result_creation.py \ 49 | --name SimpleTrack2Hz \ 50 | --result_folder ${nuscenes_result_dir} \ 51 | --data_folder ${nuscenes2hz_data_dir} 52 | 53 | python tools/nuscenes_type_merge.py \ 54 | --name SimpleTrack2Hz \ 55 | --result_folder ${nuscenes_result_dir} 56 | ``` 57 | 58 | For the 10Hz setting, run the following commands. 59 | 60 | ```bash 61 | python tools/nuscenes_result_creation_10hz.py \ 62 | --name SimpleTrack10Hz \ 63 | --result_folder ${nuscenes_result_dir} \ 64 | --data_folder ${nuscenes20hz_data_dir} 65 | 66 | python tools/nuscenes_type_merge.py \ 67 | --name SimpleTrack10Hz \ 68 | --result_folder ${nuscenes_result_dir} 69 | ``` 70 | 71 | ## Files and Detailed Metrics 72 | 73 | Please see [Dropbox Link](https://www.dropbox.com/sh/8906exnes0u5e89/AAD0xLwW1nq_QiuUBaYDrQVna?dl=0). -------------------------------------------------------------------------------- /docs/output_format.md: -------------------------------------------------------------------------------- 1 | # Output Format -------------------------------------------------------------------------------- /docs/simpletrack_gif.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tusen-ai/SimpleTrack/05c96bb7ed98fc179856f327544612a66c839b5e/docs/simpletrack_gif.gif -------------------------------------------------------------------------------- /docs/waymo.md: -------------------------------------------------------------------------------- 1 | # Waymo Open Dataset Inference 2 | 3 | This documentation describes the steps to run inference on the validation set of Waymo Open Dataset. 4 | 5 | ## Data Preprocessing 6 | 7 | We provide instructions for both Waymo Open Dataset and nuScenes; please follow the commands at the [Data Preprocessing Documentation](data_preprocess.md). After data preprocessing, the directory of the preprocessed data is `waymo_data_dir` and the name of the detection is `det_name`.
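Before launching the tracker, it can help to sanity-check the preprocessed Waymo data. The sketch below reads the same fields that `data_loader/waymo_loader.py` consumes; the folder and segment names are placeholders.

```python
# Sketch: inspect one preprocessed segment, mirroring the reads in WaymoLoader.
import os, json
import numpy as np

data_folder = 'path/to/waymo_data_dir'   # placeholder
det_folder = 'path/to/det_data_dir'      # placeholder for the folder of the preprocessed detection
segment = 'segment_name'                 # placeholder

ts_info = json.load(open(os.path.join(data_folder, 'ts_info', '{:}.json'.format(segment)), 'r'))
ego_info = np.load(os.path.join(data_folder, 'ego_info', '{:}.npz'.format(segment)), allow_pickle=True)
dets = np.load(os.path.join(det_folder, 'dets', '{:}.npz'.format(segment)), allow_pickle=True)

print(len(ts_info), 'frames')            # micro-second time stamps, one per frame
print(ego_info['0'])                     # ego pose of frame 0, keyed by the frame index as a string
print(len(dets['bboxes'][0]), 'boxes')   # per-frame lists; also dets['types'] and optionally dets['velos']
```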
8 | 9 | ## Inference with SimpleTrack 10 | 11 | To run a standard SimpleTrack experiment, run the following commands. The per-sequence results are then saved in `${waymo_result_dir}/SimpleTrack/summary`, with subfolders named `vehicle/`, `pedestrian/`, and `cyclist/`. 12 | 13 | In the experiments, 14 | * `${det_name}` denotes the name of the preprocessed detection; 15 | * `${waymo_result_dir}` is for saving tracking results; 16 | * `${waymo_data_dir}` is the folder of the preprocessed Waymo data; 17 | * `${proc_num}` sets the number of processes for running inference on different sequences in parallel (for your information, we generally use 202 processes, so that each process runs one sequence of the validation set). 18 | * **Please look out for the different config files we use for each type of object.** 19 | 20 | ```bash 21 | # for vehicle 22 | python tools/main_waymo.py \ 23 | --name SimpleTrack \ 24 | --det_name ${det_name} \ 25 | --obj_type vehicle \ 26 | --config_path configs/waymo_configs/vc_kf_giou.yaml \ 27 | --data_folder ${waymo_data_dir} \ 28 | --result_folder ${waymo_result_dir} \ 29 | --gt_folder ${gt_dets_dir} \ 30 | --process ${proc_num} 31 | 32 | # for pedestrian 33 | python tools/main_waymo.py \ 34 | --name SimpleTrack \ 35 | --det_name ${det_name} \ 36 | --obj_type pedestrian \ 37 | --config_path configs/waymo_configs/pd_kf_giou.yaml \ 38 | --data_folder ${waymo_data_dir} \ 39 | --result_folder ${waymo_result_dir} \ 40 | --process ${proc_num} 41 | 42 | # for cyclist 43 | python tools/main_waymo.py \ 44 | --name SimpleTrack \ 45 | --det_name ${det_name} \ 46 | --obj_type cyclist \ 47 | --config_path configs/waymo_configs/vc_kf_giou.yaml \ 48 | --data_folder ${waymo_data_dir} \ 49 | --result_folder ${waymo_result_dir} \ 50 | --process ${proc_num} 51 | ``` 52 | 53 | ## Output Format 54 | 55 | In the folder `${waymo_result_dir}/SimpleTrack/summary/`, there are three sub-folders `vehicle`, `pedestrian`, and `cyclist`. Each of them contains 202 `.npz` files, corresponding to the sequences in the validation set of Waymo Open Dataset. For the format in each `.npz` file, please refer to [Output Format](output_format.md). 56 | 57 | ## Converting to Waymo Open Dataset Format 58 | 59 | To work with the official formats of Waymo Open Dataset, the following commands convert the results in SimpleTrack format to the `.bin` format. 60 | 61 | After running the command, we will have four `pred.bin` files in `${waymo_result_dir}/SimpleTrack/bin/`: `pred.bin` (all objects), `vehicle/pred.bin` (vehicles only), `pedestrian/pred.bin` (pedestrians only), and `cyclist/pred.bin` (cyclists only). 62 | 63 | ```bash 64 | python tools/waymo_pred_bin.py \ 65 | --name SimpleTrack \ 66 | --result_folder ${waymo_result_dir} \ 67 | --data_folder ${waymo_data_dir} 68 | ``` 69 | 70 | Eventually, use the evaluation provided by the Waymo officials by following the [Quick Guide to Waymo Open Dataset](https://github.com/waymo-research/waymo-open-dataset/blob/master/docs/quick_start.md). 71 | 72 | ## Files and Detailed Metrics 73 | 74 | Please see [Dropbox link](https://www.dropbox.com/sh/u6o8dcwmzya04uk/AAAUsNvTt7ubXul9dx5Xnp4xa?dl=0).
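If you prefer to drive the tracker programmatically instead of through `tools/main_waymo.py`, the core loop is small. Below is a minimal sketch: the `FrameData`, `MOTModel`, and `Validity` usage mirrors the code in `mot_3d/`, while the YAML config loading and the `WaymoLoader` arguments (type token, folders, start frame) are placeholders you should take from `tools/main_waymo.py`.

```python
# Minimal sketch of running SimpleTrack frame by frame through the mot_3d API.
import yaml
from data_loader import WaymoLoader
from mot_3d.mot import MOTModel
from mot_3d.frame_data import FrameData
from mot_3d.data_protos import Validity

configs = yaml.safe_load(open('configs/waymo_configs/vc_kf_giou.yaml', 'r'))
# type_token, folders, and segment name below are placeholders
loader = WaymoLoader(configs, [1], 'segment_name', 'waymo_data_dir', 'det_data_dir', 0)
tracker = MOTModel(configs)

for frame in loader:
    frame_data = FrameData(dets=frame['dets'], ego=frame['ego'], pc=frame['pc'],
                           det_types=frame['det_types'], time_stamp=frame['time_stamp'],
                           aux_info=frame['aux_info'])
    # each result item is (bbox, track_id, state_string, det_type)
    for bbox, track_id, state_string, det_type in tracker.frame_mot(frame_data):
        if Validity.valid(state_string):  # keep only the tracks that qualify for output
            print(track_id, bbox)
```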
75 | -------------------------------------------------------------------------------- /mot_3d/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tusen-ai/SimpleTrack/05c96bb7ed98fc179856f327544612a66c839b5e/mot_3d/__init__.py -------------------------------------------------------------------------------- /mot_3d/association.py: -------------------------------------------------------------------------------- 1 | import numpy as np, mot_3d.tracklet as tracklet 2 | from . import utils 3 | from scipy.optimize import linear_sum_assignment 4 | from .frame_data import FrameData 5 | from .update_info_data import UpdateInfoData 6 | from .data_protos import BBox, Validity 7 | 8 | 9 | def associate_dets_to_tracks(dets, tracks, mode, asso, 10 | dist_threshold=0.9, trk_innovation_matrix=None): 11 | """ associate the tracks with detections 12 | """ 13 | if mode == 'bipartite': 14 | matched_indices, dist_matrix = \ 15 | bipartite_matcher(dets, tracks, asso, dist_threshold, trk_innovation_matrix) 16 | elif mode == 'greedy': 17 | matched_indices, dist_matrix = \ 18 | greedy_matcher(dets, tracks, asso, dist_threshold, trk_innovation_matrix) 19 | unmatched_dets = list() 20 | for d, det in enumerate(dets): 21 | if d not in matched_indices[:, 0]: 22 | unmatched_dets.append(d) 23 | 24 | unmatched_tracks = list() 25 | for t, trk in enumerate(tracks): 26 | if t not in matched_indices[:, 1]: 27 | unmatched_tracks.append(t) 28 | 29 | matches = list() 30 | for m in matched_indices: 31 | if dist_matrix[m[0], m[1]] > dist_threshold: 32 | unmatched_dets.append(m[0]) 33 | unmatched_tracks.append(m[1]) 34 | else: 35 | matches.append(m.reshape(2)) 36 | return matches, np.array(unmatched_dets), np.array(unmatched_tracks) 37 | 38 | 39 | def bipartite_matcher(dets, tracks, asso, dist_threshold, trk_innovation_matrix): 40 | if asso == 'iou': 41 | dist_matrix = compute_iou_distance(dets, tracks, asso) 42 | elif asso == 'giou': 43 | dist_matrix = compute_iou_distance(dets, tracks, asso) 44 | elif asso == 'm_dis': 45 | dist_matrix = compute_m_distance(dets, tracks, trk_innovation_matrix) 46 | elif asso == 'euler': 47 | dist_matrix = compute_m_distance(dets, tracks, None) 48 | row_ind, col_ind = linear_sum_assignment(dist_matrix) 49 | matched_indices = np.stack([row_ind, col_ind], axis=1) 50 | return matched_indices, dist_matrix 51 | 52 | 53 | def greedy_matcher(dets, tracks, asso, dist_threshold, trk_innovation_matrix): 54 | """ it's ok to use iou in bipartite 55 | but greedy is only for m_distance 56 | """ 57 | matched_indices = list() 58 | 59 | # compute the distance matrix 60 | if asso == 'm_dis': 61 | distance_matrix = compute_m_distance(dets, tracks, trk_innovation_matrix) 62 | elif asso == 'euler': 63 | distance_matrix = compute_m_distance(dets, tracks, None) 64 | elif asso == 'iou': 65 | distance_matrix = compute_iou_distance(dets, tracks, asso) 66 | elif asso == 'giou': 67 | distance_matrix = compute_iou_distance(dets, tracks, asso) 68 | num_dets, num_trks = distance_matrix.shape 69 | 70 | # association in the greedy manner 71 | # refer to https://github.com/eddyhkchiu/mahalanobis_3d_multi_object_tracking/blob/master/main.py 72 | distance_1d = distance_matrix.reshape(-1) 73 | index_1d = np.argsort(distance_1d) 74 | index_2d = np.stack([index_1d // num_trks, index_1d % num_trks], axis=1) 75 | detection_id_matches_to_tracking_id = [-1] * num_dets 76 | tracking_id_matches_to_detection_id = [-1] * num_trks 77 | for sort_i in 
range(index_2d.shape[0]): 78 | detection_id = int(index_2d[sort_i][0]) 79 | tracking_id = int(index_2d[sort_i][1]) 80 | if tracking_id_matches_to_detection_id[tracking_id] == -1 and detection_id_matches_to_tracking_id[detection_id] == -1: 81 | tracking_id_matches_to_detection_id[tracking_id] = detection_id 82 | detection_id_matches_to_tracking_id[detection_id] = tracking_id 83 | matched_indices.append([detection_id, tracking_id]) 84 | if len(matched_indices) == 0: 85 | matched_indices = np.empty((0, 2)) 86 | else: 87 | matched_indices = np.asarray(matched_indices) 88 | return matched_indices, distance_matrix 89 | 90 | 91 | def compute_m_distance(dets, tracks, trk_innovation_matrix): 92 | """ compute the L2 or Mahalanobis distance 93 | when the input trk_innovation_matrix is None, compute the L2 distance (euler) 94 | otherwise compute the Mahalanobis distance 95 | return dist_matrix: numpy array [len(dets), len(tracks)] 96 | """ 97 | euler_dis = (trk_innovation_matrix is None) # whether to use the L2 (euler) distance 98 | if not euler_dis: 99 | trk_inv_inn_matrices = [np.linalg.inv(m) for m in trk_innovation_matrix] 100 | dist_matrix = np.empty((len(dets), len(tracks))) 101 | 102 | for i, det in enumerate(dets): 103 | for j, trk in enumerate(tracks): 104 | if euler_dis: 105 | dist_matrix[i, j] = utils.m_distance(det, trk) 106 | else: 107 | dist_matrix[i, j] = utils.m_distance(det, trk, trk_inv_inn_matrices[j]) 108 | return dist_matrix 109 | 110 | 111 | def compute_iou_distance(dets, tracks, asso='iou'): 112 | iou_matrix = np.zeros((len(dets), len(tracks))) 113 | for d, det in enumerate(dets): 114 | for t, trk in enumerate(tracks): 115 | if asso == 'iou': 116 | iou_matrix[d, t] = utils.iou3d(det, trk)[1] 117 | elif asso == 'giou': 118 | iou_matrix[d, t] = utils.giou3d(det, trk) 119 | dist_matrix = 1 - iou_matrix 120 | return dist_matrix 121 | -------------------------------------------------------------------------------- /mot_3d/data_protos/__init__.py: -------------------------------------------------------------------------------- 1 | from mot_3d.data_protos.bbox import BBox 2 | from mot_3d.data_protos.validity import Validity -------------------------------------------------------------------------------- /mot_3d/data_protos/bbox.py: -------------------------------------------------------------------------------- 1 | """ The definition and basic methods of bbox 2 | """ 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | 7 | class BBox: 8 | def __init__(self, x=None, y=None, z=None, h=None, w=None, l=None, o=None): 9 | self.x = x # center x 10 | self.y = y # center y 11 | self.z = z # center z 12 | self.h = h # height 13 | self.w = w # width 14 | self.l = l # length 15 | self.o = o # orientation 16 | self.s = None # detection score 17 | 18 | def __str__(self): 19 | return 'x: {}, y: {}, z: {}, heading: {}, length: {}, width: {}, height: {}, score: {}'.format( 20 | self.x, self.y, self.z, self.o, self.l, self.w, self.h, self.s) 21 | 22 | @classmethod 23 | def bbox2dict(cls, bbox): 24 | return { 25 | 'center_x': bbox.x, 'center_y': bbox.y, 'center_z': bbox.z, 26 | 'height': bbox.h, 'width': bbox.w, 'length': bbox.l, 'heading': bbox.o} 27 | 28 | @classmethod 29 | def bbox2array(cls, bbox): 30 | if bbox.s is None: 31 | return np.array([bbox.x, bbox.y, bbox.z, bbox.o, bbox.l, bbox.w, bbox.h]) 32 | else: 33 | return np.array([bbox.x, bbox.y, bbox.z, bbox.o, bbox.l, bbox.w, bbox.h, bbox.s]) 34 | 35 | @classmethod 36 | def array2bbox(cls, data): 37 | bbox = BBox() 38 | bbox.x, bbox.y, bbox.z, bbox.o, bbox.l, bbox.w,
bbox.h = data[:7] 39 | if len(data) == 8: 40 | bbox.s = data[-1] 41 | return bbox 42 | 43 | @classmethod 44 | def dict2bbox(cls, data): 45 | bbox = BBox() 46 | bbox.x = data['center_x'] 47 | bbox.y = data['center_y'] 48 | bbox.z = data['center_z'] 49 | bbox.h = data['height'] 50 | bbox.w = data['width'] 51 | bbox.l = data['length'] 52 | bbox.o = data['heading'] 53 | if 'score' in data.keys(): 54 | bbox.s = data['score'] 55 | return bbox 56 | 57 | @classmethod 58 | def copy_bbox(cls, bboxa, bboxb): 59 | bboxa.x = bboxb.x 60 | bboxa.y = bboxb.y 61 | bboxa.z = bboxb.z 62 | bboxa.l = bboxb.l 63 | bboxa.w = bboxb.w 64 | bboxa.h = bboxb.h 65 | bboxa.o = bboxb.o 66 | bboxa.s = bboxb.s 67 | return 68 | 69 | @classmethod 70 | def box2corners2d(cls, bbox): 71 | """ the coordinates for bottom corners 72 | """ 73 | bottom_center = np.array([bbox.x, bbox.y, bbox.z - bbox.h / 2]) 74 | cos, sin = np.cos(bbox.o), np.sin(bbox.o) 75 | pc0 = np.array([bbox.x + cos * bbox.l / 2 + sin * bbox.w / 2, 76 | bbox.y + sin * bbox.l / 2 - cos * bbox.w / 2, 77 | bbox.z - bbox.h / 2]) 78 | pc1 = np.array([bbox.x + cos * bbox.l / 2 - sin * bbox.w / 2, 79 | bbox.y + sin * bbox.l / 2 + cos * bbox.w / 2, 80 | bbox.z - bbox.h / 2]) 81 | pc2 = 2 * bottom_center - pc0 82 | pc3 = 2 * bottom_center - pc1 83 | 84 | return [pc0.tolist(), pc1.tolist(), pc2.tolist(), pc3.tolist()] 85 | 86 | @classmethod 87 | def box2corners3d(cls, bbox): 88 | """ the coordinates for bottom corners 89 | """ 90 | center = np.array([bbox.x, bbox.y, bbox.z]) 91 | bottom_corners = np.array(BBox.box2corners2d(bbox)) 92 | up_corners = 2 * center - bottom_corners 93 | corners = np.concatenate([up_corners, bottom_corners], axis=0) 94 | return corners.tolist() 95 | 96 | @classmethod 97 | def motion2bbox(cls, bbox, motion): 98 | result = deepcopy(bbox) 99 | result.x += motion[0] 100 | result.y += motion[1] 101 | result.z += motion[2] 102 | result.o += motion[3] 103 | return result 104 | 105 | @classmethod 106 | def set_bbox_size(cls, bbox, size_array): 107 | result = deepcopy(bbox) 108 | result.l, result.w, result.h = size_array 109 | return result 110 | 111 | @classmethod 112 | def set_bbox_with_states(cls, prev_bbox, state_array): 113 | prev_array = BBox.bbox2array(prev_bbox) 114 | prev_array[:4] += state_array[:4] 115 | prev_array[4:] = state_array[4:] 116 | bbox = BBox.array2bbox(prev_array) 117 | return bbox 118 | 119 | @classmethod 120 | def box_pts2world(cls, ego_matrix, pcs): 121 | new_pcs = np.concatenate((pcs, 122 | np.ones(pcs.shape[0])[:, np.newaxis]), 123 | axis=1) 124 | new_pcs = ego_matrix @ new_pcs.T 125 | new_pcs = new_pcs.T[:, :3] 126 | return new_pcs 127 | 128 | @classmethod 129 | def edge2yaw(cls, center, edge): 130 | vec = edge - center 131 | yaw = np.arccos(vec[0] / np.linalg.norm(vec)) 132 | if vec[1] < 0: 133 | yaw = -yaw 134 | return yaw 135 | 136 | @classmethod 137 | def bbox2world(cls, ego_matrix, box): 138 | # center and corners 139 | corners = np.array(BBox.box2corners2d(box)) 140 | center = BBox.bbox2array(box)[:3][np.newaxis, :] 141 | center = BBox.box_pts2world(ego_matrix, center)[0] 142 | corners = BBox.box_pts2world(ego_matrix, corners) 143 | # heading 144 | edge_mid_point = (corners[0] + corners[1]) / 2 145 | yaw = BBox.edge2yaw(center[:2], edge_mid_point[:2]) 146 | 147 | result = deepcopy(box) 148 | result.x, result.y, result.z = center 149 | result.o = yaw 150 | return result -------------------------------------------------------------------------------- /mot_3d/data_protos/validity.py: 
-------------------------------------------------------------------------------- 1 | class Validity: 2 | TYPES = ['birth', 'alive', 'death'] 3 | def __init__(self): 4 | return 5 | 6 | @classmethod 7 | def valid(cls, state_string): 8 | tokens = state_string.split('_') 9 | if tokens[0] == 'birth': 10 | return True 11 | if len(tokens) < 3: 12 | return False 13 | if tokens[0] == 'alive' and int(tokens[1]) == 1: 14 | return True 15 | return False 16 | 17 | @classmethod 18 | def notoutput(cls, state_string): 19 | tokens = state_string.split('_') 20 | if len(tokens) < 3: 21 | return False 22 | if tokens[0] == 'alive' and int(tokens[1]) != 1: 23 | return True 24 | return False 25 | 26 | @classmethod 27 | def predicted(cls, state_string): 28 | state, token = state_string.split('_') 29 | if state not in Validity.TYPES: 30 | raise ValueError('type name does not exist') 31 | 32 | if state == 'alive' and int(token) != 0: 33 | return True 34 | return False 35 | 36 | @classmethod 37 | def modify_string(cls, state_string, mode): 38 | tokens = state_string.split('_') 39 | tokens[1] = str(mode) 40 | return '{:}_{:}_{:}'.format(tokens[0], tokens[1], tokens[2]) -------------------------------------------------------------------------------- /mot_3d/frame_data.py: -------------------------------------------------------------------------------- 1 | """ input form of the data in each frame 2 | """ 3 | from .data_protos import BBox 4 | import numpy as np, mot_3d.utils as utils 5 | 6 | 7 | class FrameData: 8 | def __init__(self, dets, ego, time_stamp=None, pc=None, det_types=None, aux_info=None): 9 | self.dets = dets # detections for each frame 10 | self.ego = ego # ego matrix information 11 | self.pc = pc 12 | self.det_types = det_types 13 | self.time_stamp = time_stamp 14 | self.aux_info = aux_info 15 | 16 | for i, det in enumerate(self.dets): 17 | self.dets[i] = BBox.array2bbox(det) 18 | 19 | # if not aux_info['is_key_frame']: 20 | # self.dets = [d for d in self.dets if d.s >= 0.5] -------------------------------------------------------------------------------- /mot_3d/life/__init__.py: -------------------------------------------------------------------------------- 1 | from ..life.hit_manager import HitManager -------------------------------------------------------------------------------- /mot_3d/life/hit_manager.py: -------------------------------------------------------------------------------- 1 | """ a finite state machine to manage the life cycle 2 | states: 3 | - birth: first found 4 | - alive: alive 5 | - no_asso: without high-score association, about to die 6 | - dead: may it rest in eternal peace 7 | """ 8 | import numpy as np 9 | from ..data_protos import Validity 10 | from ..update_info_data import UpdateInfoData 11 | from ..
import utils 12 | 13 | 14 | class HitManager: 15 | def __init__(self, configs, frame_index): 16 | self.time_since_update = 0 17 | self.hits = 1 # number of total hits including the first detection 18 | self.hit_streak = 1 # number of continuous hits, counting the first detection 19 | self.first_continuing_hit = 1 20 | self.still_first = True 21 | self.age = 0 22 | self.recent_state = None 23 | 24 | self.max_age = configs['running']['max_age_since_update'] 25 | self.min_hits = configs['running']['min_hits_to_birth'] 26 | 27 | self.state = 'birth' 28 | self.recent_state = 1 29 | self.no_asso = False 30 | if frame_index <= self.min_hits or self.min_hits == 0: 31 | self.state = 'alive' 32 | self.recent_state = 1 33 | 34 | def predict(self, is_key_frame): 35 | # only on key frames: 36 | # an unassociated prediction affects the life-cycle management 37 | if not is_key_frame: 38 | return 39 | 40 | self.age += 1 41 | if self.time_since_update > 0: 42 | self.hit_streak = 0 43 | self.still_first = False 44 | self.time_since_update += 1 45 | self.fall = True 46 | return 47 | 48 | def if_valid(self, update_info): 49 | self.recent_state = update_info.mode 50 | return update_info.mode 51 | 52 | def update(self, update_info: UpdateInfoData, is_key_frame=True): 53 | # an update happening on a non-key frame 54 | # can extend the life of the tracklet 55 | association = self.if_valid(update_info) 56 | self.recent_state = association 57 | if association != 0: 58 | self.fall = False 59 | self.time_since_update = 0 60 | self.history = [] 61 | self.hits += 1 62 | self.hit_streak += 1 # number of continuing hits 63 | if self.still_first: 64 | self.first_continuing_hit += 1 # number of continuing hits in the first streak 65 | if is_key_frame: 66 | self.state_transition(association, update_info.frame_index) 67 | 68 | def state_transition(self, mode, frame_index): 69 | # if just created 70 | if self.state == 'birth': 71 | if (self.hits >= self.min_hits) or (frame_index <= self.min_hits): 72 | self.state = 'alive' 73 | self.recent_state = mode 74 | elif self.time_since_update >= self.max_age: 75 | self.state = 'dead' 76 | # already alive 77 | elif self.state == 'alive': 78 | if self.time_since_update >= self.max_age: 79 | self.state = 'dead' 80 | 81 | def alive(self, frame_index): 82 | return self.state == 'alive' 83 | 84 | def death(self, frame_index): 85 | return self.state == 'dead' 86 | 87 | def valid_output(self, frame_index): 88 | return (self.state == 'alive') and (self.no_asso == False) 89 | 90 | def state_string(self, frame_index): 91 | """ Each tracklet uses a state string to represent its state 92 | This string is used for determining output, etc.
93 | """ 94 | if self.state == 'birth': 95 | return '{:}_{:}'.format(self.state, self.hits) 96 | elif self.state == 'alive': 97 | return '{:}_{:}_{:}'.format(self.state, self.recent_state, self.time_since_update) 98 | elif self.state == 'dead': 99 | return '{:}_{:}'.format(self.state, self.time_since_update) -------------------------------------------------------------------------------- /mot_3d/mot.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import numpy as np, mot_3d.tracklet as tracklet, mot_3d.utils as utils 3 | from .redundancy import RedundancyModule 4 | from scipy.optimize import linear_sum_assignment 5 | from .frame_data import FrameData 6 | from .update_info_data import UpdateInfoData 7 | from .data_protos import BBox, Validity 8 | from .association import associate_dets_to_tracks 9 | from . import visualization 10 | from mot_3d import redundancy 11 | import pdb, os 12 | 13 | 14 | class MOTModel: 15 | def __init__(self, configs): 16 | self.trackers = list() # tracker for each single tracklet 17 | self.frame_count = 0 # record for the frames 18 | self.count = 0 # record the obj number to assign ids 19 | self.time_stamp = None # the previous time stamp 20 | self.redundancy = RedundancyModule(configs) # module for no detection cases 21 | 22 | non_key_redundancy_config = deepcopy(configs) 23 | non_key_redundancy_config['redundancy'] = { 24 | 'mode': 'mm', 25 | 'det_score_threshold': {'giou': 0.1, 'iou': 0.1, 'euler': 0.1}, 26 | 'det_dist_threshold': {'giou': -0.5, 'iou': 0.1, 'euler': 4} 27 | } 28 | self.non_key_redundancy = RedundancyModule(non_key_redundancy_config) 29 | 30 | self.configs = configs 31 | self.match_type = configs['running']['match_type'] 32 | self.score_threshold = configs['running']['score_threshold'] 33 | self.asso = configs['running']['asso'] 34 | self.asso_thres = configs['running']['asso_thres'][self.asso] 35 | self.motion_model = configs['running']['motion_model'] 36 | 37 | self.max_age = configs['running']['max_age_since_update'] 38 | self.min_hits = configs['running']['min_hits_to_birth'] 39 | 40 | @property 41 | def has_velo(self): 42 | return not (self.motion_model == 'kf' or self.motion_model == 'fbkf' or self.motion_model == 'ma') 43 | 44 | def frame_mot(self, input_data: FrameData): 45 | """ For each frame input, generate the latest mot results 46 | Args: 47 | input_data (FrameData): input data, including detection bboxes and ego information 48 | Returns: 49 | tracks on this frame: [(bbox0, id0), (bbox1, id1), ...] 
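(more precisely, each entry is the 4-tuple (bbox, track_id, state_string, det_type) collected in the result list at the end of this function)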
50 | """ 51 | self.frame_count += 1 52 | 53 | # initialize the time stamp on frame 0 54 | if self.time_stamp is None: 55 | self.time_stamp = input_data.time_stamp 56 | 57 | if not input_data.aux_info['is_key_frame']: 58 | result = self.non_key_frame_mot(input_data) 59 | return result 60 | 61 | if 'kf' in self.motion_model: 62 | matched, unmatched_dets, unmatched_trks = self.forward_step_trk(input_data) 63 | 64 | time_lag = input_data.time_stamp - self.time_stamp 65 | # update the matched tracks 66 | for t, trk in enumerate(self.trackers): 67 | if t not in unmatched_trks: 68 | for k in range(len(matched)): 69 | if matched[k][1] == t: 70 | d = matched[k][0] 71 | break 72 | if self.has_velo: 73 | aux_info = { 74 | 'velo': list(input_data.aux_info['velos'][d]), 75 | 'is_key_frame': input_data.aux_info['is_key_frame']} 76 | else: 77 | aux_info = {'is_key_frame': input_data.aux_info['is_key_frame']} 78 | update_info = UpdateInfoData(mode=1, bbox=input_data.dets[d], ego=input_data.ego, 79 | frame_index=self.frame_count, pc=input_data.pc, 80 | dets=input_data.dets, aux_info=aux_info) 81 | trk.update(update_info) 82 | else: 83 | result_bbox, update_mode, aux_info = self.redundancy.infer(trk, input_data, time_lag) 84 | aux_info = {'is_key_frame': input_data.aux_info['is_key_frame']} 85 | update_info = UpdateInfoData(mode=update_mode, bbox=result_bbox, 86 | ego=input_data.ego, frame_index=self.frame_count, 87 | pc=input_data.pc, dets=input_data.dets, aux_info=aux_info) 88 | trk.update(update_info) 89 | 90 | # create new tracks for unmatched detections 91 | for index in unmatched_dets: 92 | if self.has_velo: 93 | aux_info = { 94 | 'velo': list(input_data.aux_info['velos'][index]), 95 | 'is_key_frame': input_data.aux_info['is_key_frame']} 96 | else: 97 | aux_info = {'is_key_frame': input_data.aux_info['is_key_frame']} 98 | 99 | track = tracklet.Tracklet(self.configs, self.count, input_data.dets[index], input_data.det_types[index], 100 | self.frame_count, aux_info=aux_info, time_stamp=input_data.time_stamp) 101 | self.trackers.append(track) 102 | self.count += 1 103 | 104 | # remove dead tracks 105 | track_num = len(self.trackers) 106 | for index, trk in enumerate(reversed(self.trackers)): 107 | if trk.death(self.frame_count): 108 | self.trackers.pop(track_num - 1 - index) 109 | 110 | # output the results 111 | result = list() 112 | for trk in self.trackers: 113 | state_string = trk.state_string(self.frame_count) 114 | result.append((trk.get_state(), trk.id, state_string, trk.det_type)) 115 | 116 | # wrap up and update the information about the mot trackers 117 | self.time_stamp = input_data.time_stamp 118 | for trk in self.trackers: 119 | trk.sync_time_stamp(self.time_stamp) 120 | 121 | return result 122 | 123 | def forward_step_trk(self, input_data: FrameData): 124 | dets = input_data.dets 125 | det_indexes = [i for i, det in enumerate(dets) if det.s >= self.score_threshold] 126 | dets = [dets[i] for i in det_indexes] 127 | 128 | # prediction and association 129 | trk_preds = list() 130 | for trk in self.trackers: 131 | trk_preds.append(trk.predict(input_data.time_stamp, input_data.aux_info['is_key_frame'])) 132 | 133 | # for m-distance association 134 | trk_innovation_matrix = None 135 | if self.asso == 'm_dis': 136 | trk_innovation_matrix = [trk.compute_innovation_matrix() for trk in self.trackers] 137 | 138 | matched, unmatched_dets, unmatched_trks = associate_dets_to_tracks(dets, trk_preds, 139 | self.match_type, self.asso, self.asso_thres, trk_innovation_matrix) 140 | 141 | for k in 
range(len(matched)): 142 | matched[k][0] = det_indexes[matched[k][0]] 143 | for k in range(len(unmatched_dets)): 144 | unmatched_dets[k] = det_indexes[unmatched_dets[k]] 145 | return matched, unmatched_dets, unmatched_trks 146 | 147 | def non_key_forward_step_trk(self, input_data: FrameData): 148 | """ tracking on non-key frames (for nuScenes) 149 | """ 150 | dets = input_data.dets 151 | det_indexes = [i for i, det in enumerate(dets) if det.s >= 0.5] 152 | dets = [dets[i] for i in det_indexes] 153 | 154 | # prediction and association 155 | trk_preds = list() 156 | for trk in self.trackers: 157 | trk_preds.append(trk.predict(input_data.time_stamp, input_data.aux_info['is_key_frame'])) 158 | 159 | # for m-distance association 160 | trk_innovation_matrix = None 161 | if self.asso == 'm_dis': 162 | trk_innovation_matrix = [trk.compute_innovation_matrix() for trk in self.trackers] 163 | 164 | matched, unmatched_dets, unmatched_trks = associate_dets_to_tracks(dets, trk_preds, 165 | self.match_type, self.asso, self.asso_thres, trk_innovation_matrix) 166 | 167 | for k in range(len(matched)): 168 | matched[k][0] = det_indexes[matched[k][0]] 169 | for k in range(len(unmatched_dets)): 170 | unmatched_dets[k] = det_indexes[unmatched_dets[k]] 171 | return matched, unmatched_dets, unmatched_trks 172 | 173 | def non_key_frame_mot(self, input_data: FrameData): 174 | """ tracking on non-key frames (for nuScenes) 175 | """ 176 | self.frame_count += 1 177 | # initialize the time stamp on frame 0 178 | if self.time_stamp is None: 179 | self.time_stamp = input_data.time_stamp 180 | 181 | if 'kf' in self.motion_model: 182 | matched, unmatched_dets, unmatched_trks = self.non_key_forward_step_trk(input_data) 183 | time_lag = input_data.time_stamp - self.time_stamp 184 | 185 | redundancy_bboxes, update_modes = self.non_key_redundancy.bipartite_infer(input_data, self.trackers) 186 | # update the matched tracks 187 | for t, trk in enumerate(self.trackers): 188 | if t not in unmatched_trks: 189 | for k in range(len(matched)): 190 | if matched[k][1] == t: 191 | d = matched[k][0] 192 | break 193 | if self.has_velo: 194 | aux_info = { 195 | 'velo': list(input_data.aux_info['velos'][d]), 196 | 'is_key_frame': input_data.aux_info['is_key_frame']} 197 | else: 198 | aux_info = {'is_key_frame': input_data.aux_info['is_key_frame']} 199 | update_info = UpdateInfoData(mode=1, bbox=input_data.dets[d], ego=input_data.ego, 200 | frame_index=self.frame_count, pc=input_data.pc, 201 | dets=input_data.dets, aux_info=aux_info) 202 | trk.update(update_info) 203 | else: 204 | aux_info = {'is_key_frame': input_data.aux_info['is_key_frame']} 205 | update_info = UpdateInfoData(mode=update_modes[t], bbox=redundancy_bboxes[t], 206 | ego=input_data.ego, frame_index=self.frame_count, 207 | pc=input_data.pc, dets=input_data.dets, aux_info=aux_info) 208 | trk.update(update_info) 209 | 210 | # output the results 211 | result = list() 212 | for trk in self.trackers: 213 | state_string = trk.state_string(self.frame_count) 214 | result.append((trk.get_state(), trk.id, state_string, trk.det_type)) 215 | 216 | # wrap up and update the information about the mot trackers 217 | self.time_stamp = input_data.time_stamp 218 | for trk in self.trackers: 219 | trk.sync_time_stamp(self.time_stamp) 220 | 221 | return result -------------------------------------------------------------------------------- /mot_3d/motion_model/__init__.py: -------------------------------------------------------------------------------- 1 | from .kalman_filter import 
KalmanFilterMotionModel 2 | -------------------------------------------------------------------------------- /mot_3d/motion_model/kalman_filter.py: -------------------------------------------------------------------------------- 1 | """ Many parts are borrowed from https://github.com/xinshuoweng/AB3DMOT 2 | """ 3 | 4 | import numpy as np 5 | from ..data_protos import BBox 6 | from filterpy.kalman import KalmanFilter 7 | 8 | 9 | class KalmanFilterMotionModel: 10 | def __init__(self, bbox: BBox, inst_type, time_stamp, covariance='default'): 11 | # the time stamp of last observation 12 | self.prev_time_stamp = time_stamp 13 | self.latest_time_stamp = time_stamp 14 | # define constant velocity model 15 | self.score = bbox.s 16 | self.inst_type = inst_type 17 | 18 | self.kf = KalmanFilter(dim_x=10, dim_z=7) 19 | self.kf.x[:7] = BBox.bbox2array(bbox)[:7].reshape((7, 1)) 20 | self.kf.F = np.array([[1,0,0,0,0,0,0,1,0,0], # state transition matrix 21 | [0,1,0,0,0,0,0,0,1,0], 22 | [0,0,1,0,0,0,0,0,0,1], 23 | [0,0,0,1,0,0,0,0,0,0], 24 | [0,0,0,0,1,0,0,0,0,0], 25 | [0,0,0,0,0,1,0,0,0,0], 26 | [0,0,0,0,0,0,1,0,0,0], 27 | [0,0,0,0,0,0,0,1,0,0], 28 | [0,0,0,0,0,0,0,0,1,0], 29 | [0,0,0,0,0,0,0,0,0,1]]) 30 | 31 | self.kf.H = np.array([[1,0,0,0,0,0,0,0,0,0], # measurement function, 32 | [0,1,0,0,0,0,0,0,0,0], 33 | [0,0,1,0,0,0,0,0,0,0], 34 | [0,0,0,1,0,0,0,0,0,0], 35 | [0,0,0,0,1,0,0,0,0,0], 36 | [0,0,0,0,0,1,0,0,0,0], 37 | [0,0,0,0,0,0,1,0,0,0]]) 38 | 39 | self.kf.B = np.zeros((10, 1)) # dummy control transition matrix 40 | 41 | # # with angular velocity 42 | # self.kf = KalmanFilter(dim_x=11, dim_z=7) 43 | # self.kf.F = np.array([[1,0,0,0,0,0,0,1,0,0,0], # state transition matrix 44 | # [0,1,0,0,0,0,0,0,1,0,0], 45 | # [0,0,1,0,0,0,0,0,0,1,0], 46 | # [0,0,0,1,0,0,0,0,0,0,1], 47 | # [0,0,0,0,1,0,0,0,0,0,0], 48 | # [0,0,0,0,0,1,0,0,0,0,0], 49 | # [0,0,0,0,0,0,1,0,0,0,0], 50 | # [0,0,0,0,0,0,0,1,0,0,0], 51 | # [0,0,0,0,0,0,0,0,1,0,0], 52 | # [0,0,0,0,0,0,0,0,0,1,0], 53 | # [0,0,0,0,0,0,0,0,0,0,1]]) 54 | 55 | # self.kf.H = np.array([[1,0,0,0,0,0,0,0,0,0,0], # measurement function, 56 | # [0,1,0,0,0,0,0,0,0,0,0], 57 | # [0,0,1,0,0,0,0,0,0,0,0], 58 | # [0,0,0,1,0,0,0,0,0,0,0], 59 | # [0,0,0,0,1,0,0,0,0,0,0], 60 | # [0,0,0,0,0,1,0,0,0,0,0], 61 | # [0,0,0,0,0,0,1,0,0,0,0]]) 62 | 63 | self.covariance_type = covariance 64 | # self.kf.R[0:,0:] *= 10. # measurement uncertainty 65 | self.kf.P[7:, 7:] *= 1000. # state uncertainty, give high uncertainty to the unobservable initial velocities, covariance matrix 66 | self.kf.P *= 10. 67 | 68 | # self.kf.Q[-1,-1] *= 0.01 # process uncertainty 69 | # self.kf.Q[7:, 7:] *= 0.01 70 | 71 | self.history = [bbox] 72 | 73 | def predict(self, time_stamp=None): 74 | """ For the motion prediction, use the get_prediction function. 75 | """ 76 | self.kf.predict() 77 | if self.kf.x[3] >= np.pi: self.kf.x[3] -= np.pi * 2 78 | if self.kf.x[3] < -np.pi: self.kf.x[3] += np.pi * 2 79 | return 80 | 81 | def update(self, det_bbox: BBox, aux_info=None): 82 | """ 83 | Updates the state vector with observed bbox. 
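Internally this first runs a predict step, then flips the predicted heading by pi whenever its difference from the observed heading is obtuse, so the Kalman update always sees an acute angle difference.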
84 | """ 85 | bbox = BBox.bbox2array(det_bbox)[:7] 86 | 87 | # full pipeline of kf, first predict, then update 88 | self.predict() 89 | 90 | ######################### orientation correction 91 | if self.kf.x[3] >= np.pi: self.kf.x[3] -= np.pi * 2 # make the theta still in the range 92 | if self.kf.x[3] < -np.pi: self.kf.x[3] += np.pi * 2 93 | 94 | new_theta = bbox[3] 95 | if new_theta >= np.pi: new_theta -= np.pi * 2 # make the theta still in the range 96 | if new_theta < -np.pi: new_theta += np.pi * 2 97 | bbox[3] = new_theta 98 | 99 | predicted_theta = self.kf.x[3] 100 | if np.abs(new_theta - predicted_theta) > np.pi / 2.0 and np.abs(new_theta - predicted_theta) < np.pi * 3 / 2.0: # if the angle between the two thetas is not acute 101 | self.kf.x[3] += np.pi 102 | if self.kf.x[3] > np.pi: self.kf.x[3] -= np.pi * 2 # make the theta still in the range 103 | if self.kf.x[3] < -np.pi: self.kf.x[3] += np.pi * 2 104 | 105 | # now the angle is acute: < 90 or > 270, convert the case of > 270 to < 90 106 | if np.abs(new_theta - self.kf.x[3]) >= np.pi * 3 / 2.0: 107 | if new_theta > 0: self.kf.x[3] += np.pi * 2 108 | else: self.kf.x[3] -= np.pi * 2 109 | 110 | ######################### end of orientation correction (flip) 111 | 112 | self.kf.update(bbox) 113 | self.prev_time_stamp = self.latest_time_stamp 114 | 115 | if self.kf.x[3] >= np.pi: self.kf.x[3] -= np.pi * 2 # make the theta still in the range 116 | if self.kf.x[3] < -np.pi: self.kf.x[3] += np.pi * 2 117 | 118 | if det_bbox.s is None: 119 | self.score = self.score * 0.01 120 | else: 121 | self.score = det_bbox.s 122 | 123 | cur_bbox = self.kf.x[:7].reshape(-1).tolist() 124 | cur_bbox = BBox.array2bbox(cur_bbox + [self.score]) 125 | self.history[-1] = cur_bbox 126 | return 127 | 128 | def get_prediction(self, time_stamp=None): 129 | """ 130 | Advances the state vector and returns the predicted bounding box estimate. 131 | """ 132 | time_lag = time_stamp - self.prev_time_stamp 133 | self.latest_time_stamp = time_stamp 134 | self.kf.F = np.array([[1,0,0,0,0,0,0,time_lag,0,0], # state transition matrix 135 | [0,1,0,0,0,0,0,0,time_lag,0], 136 | [0,0,1,0,0,0,0,0,0,time_lag], 137 | [0,0,0,1,0,0,0,0,0,0], 138 | [0,0,0,0,1,0,0,0,0,0], 139 | [0,0,0,0,0,1,0,0,0,0], 140 | [0,0,0,0,0,0,1,0,0,0], 141 | [0,0,0,0,0,0,0,1,0,0], 142 | [0,0,0,0,0,0,0,0,1,0], 143 | [0,0,0,0,0,0,0,0,0,1]]) 144 | pred_x = self.kf.get_prediction()[0] 145 | if pred_x[3] >= np.pi: pred_x[3] -= np.pi * 2 146 | if pred_x[3] < -np.pi: pred_x[3] += np.pi * 2 147 | pred_bbox = BBox.array2bbox(pred_x[:7].reshape(-1)) 148 | 149 | self.history.append(pred_bbox) 150 | return pred_bbox 151 | 152 | def get_state(self): 153 | """ 154 | Returns the current bounding box estimate. 
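This is simply the last entry of self.history: the Kalman-updated box on frames with an association, otherwise the latest motion-model prediction.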
155 | """ 156 | return self.history[-1] 157 | 158 | def compute_innovation_matrix(self): 159 | """ compute the innovation matrix for association with mahalanobis distance 160 | """ 161 | return np.matmul(np.matmul(self.kf.H, self.kf.P), self.kf.H.T) + self.kf.R 162 | 163 | def sync_time_stamp(self, time_stamp): 164 | self.time_stamp = time_stamp 165 | return 166 | -------------------------------------------------------------------------------- /mot_3d/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .nms import nms -------------------------------------------------------------------------------- /mot_3d/preprocessing/bbox_coarse_hash.py: -------------------------------------------------------------------------------- 1 | """ Split the area into grid boxes 2 | BBoxes located in grid boxes that share no cell cannot overlap 3 | """ 4 | import numpy as np 5 | from ..data_protos import BBox 6 | 7 | 8 | class BBoxCoarseFilter: 9 | def __init__(self, grid_size, scaler=100): 10 | self.gsize = grid_size 11 | self.scaler = scaler # honor the argument instead of hard-coding 100 12 | self.bbox_dict = dict() 13 | 14 | def bboxes2dict(self, bboxes): 15 | for i, bbox in enumerate(bboxes): 16 | grid_keys = self.compute_bbox_key(bbox) 17 | for key in grid_keys: 18 | if key not in self.bbox_dict.keys(): 19 | self.bbox_dict[key] = set([i]) 20 | else: 21 | self.bbox_dict[key].add(i) 22 | return 23 | 24 | def compute_bbox_key(self, bbox): 25 | corners = np.asarray(BBox.box2corners2d(bbox)) 26 | min_keys = np.floor(np.min(corners, axis=0) / self.gsize).astype(int) # int instead of the removed np.int alias 27 | max_keys = np.floor(np.max(corners, axis=0) / self.gsize).astype(int) 28 | 29 | # enumerate all the corners 30 | grid_keys = [ 31 | self.scaler * min_keys[0] + min_keys[1], 32 | self.scaler * min_keys[0] + max_keys[1], 33 | self.scaler * max_keys[0] + min_keys[1], 34 | self.scaler * max_keys[0] + max_keys[1] 35 | ] 36 | return grid_keys 37 | 38 | def related_bboxes(self, bbox): 39 | """ return the list of related bboxes 40 | """ 41 | result = set() 42 | grid_keys = self.compute_bbox_key(bbox) 43 | for key in grid_keys: 44 | if key in self.bbox_dict.keys(): 45 | result.update(self.bbox_dict[key]) 46 | return list(result) 47 | 48 | def clear(self): 49 | self.bbox_dict = dict() -------------------------------------------------------------------------------- /mot_3d/preprocessing/nms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .. 
import utils 3 | from ..data_protos import BBox 4 | from .bbox_coarse_hash import BBoxCoarseFilter 5 | 6 | 7 | def weird_bbox(bbox): 8 | if bbox.l <= 0 or bbox.w <= 0 or bbox.h <= 0: 9 | return True 10 | else: 11 | return False 12 | 13 | 14 | def nms(dets, inst_types, threshold_low=0.1, threshold_high=1.0, threshold_yaw=0.3): 15 | """ keep the bboxes with overlap <= threshold 16 | """ 17 | dets_coarse_filter = BBoxCoarseFilter(grid_size=100, scaler=100) 18 | dets_coarse_filter.bboxes2dict(dets) 19 | scores = np.asarray([det.s for det in dets]) 20 | yaws = np.asarray([det.o for det in dets]) 21 | order = np.argsort(scores)[::-1] 22 | 23 | result_indexes = list() 24 | result_types = list() 25 | while order.size > 0: 26 | index = order[0] 27 | 28 | if weird_bbox(dets[index]): 29 | order = order[1:] 30 | continue 31 | 32 | # locate the related bboxes 33 | filter_indexes = dets_coarse_filter.related_bboxes(dets[index]) 34 | in_mask = np.isin(order, filter_indexes) 35 | related_idxes = order[in_mask] 36 | related_idxes = np.asarray([i for i in related_idxes if inst_types[i] == inst_types[index]]) 37 | 38 | # compute the ious 39 | bbox_num = len(related_idxes) 40 | ious = np.zeros(bbox_num) 41 | for i, idx in enumerate(related_idxes): 42 | ious[i] = utils.iou3d(dets[index], dets[idx])[1] 43 | related_inds = np.where(ious > threshold_low) 44 | related_inds_vote = np.where(ious > threshold_high) 45 | order_vote = related_idxes[related_inds_vote] 46 | 47 | if len(order_vote) >= 2: 48 | # keep the bboxes with similar yaw 49 | if order_vote.shape[0] <= 2: 50 | score_index = np.argmax(scores[order_vote]) 51 | median_yaw = yaws[order_vote][score_index] 52 | elif order_vote.shape[0] % 2 == 0: 53 | tmp_yaw = yaws[order_vote].copy() 54 | tmp_yaw = np.append(tmp_yaw, yaws[order_vote][0]) 55 | median_yaw = np.median(tmp_yaw) 56 | else: 57 | median_yaw = np.median(yaws[order_vote]) 58 | yaw_vote = np.where(np.abs(yaws[order_vote] - median_yaw) % (2 * np.pi) < threshold_yaw)[0] 59 | order_vote = order_vote[yaw_vote] 60 | 61 | # start weighted voting 62 | vote_score_sum = np.sum(scores[order_vote]) 63 | det_arrays = list() 64 | for idx in order_vote: 65 | det_arrays.append(BBox.bbox2array(dets[idx])[np.newaxis, :]) 66 | det_arrays = np.vstack(det_arrays) 67 | avg_bbox_array = np.sum(scores[order_vote][:, np.newaxis] * det_arrays, axis=0) / vote_score_sum 68 | bbox = BBox.array2bbox(avg_bbox_array) 69 | bbox.s = scores[index] 70 | result_indexes.append(index) 71 | result_types.append(inst_types[index]) 72 | else: 73 | result_indexes.append(index) 74 | result_types.append(inst_types[index]) 75 | 76 | # delete the overlapped bboxes 77 | delete_idxes = related_idxes[related_inds] 78 | in_mask = np.isin(order, delete_idxes, invert=True) 79 | order = order[in_mask] 80 | 81 | return result_indexes, result_types 82 | -------------------------------------------------------------------------------- /mot_3d/redundancy/__init__.py: -------------------------------------------------------------------------------- 1 | """ the redundancy system interface, in case of no high score detections 2 | this is the two-stage association. 
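tracklets that miss in the first-stage (high-score) association get a second chance against the low-score detections around the motion model prediction, instead of being marked unassociated right away.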
3 | """ 4 | from .redundancy import RedundancyModule -------------------------------------------------------------------------------- /mot_3d/redundancy/redundancy.py: -------------------------------------------------------------------------------- 1 | from ..frame_data import FrameData 2 | from ..update_info_data import UpdateInfoData 3 | from ..data_protos import BBox, Validity 4 | import numpy as np, mot_3d.utils as utils 5 | from ..tracklet import Tracklet 6 | from ..association import associate_dets_to_tracks 7 | 8 | 9 | class RedundancyModule: 10 | def __init__(self, configs): 11 | self.configs = configs 12 | self.mode = configs['redundancy']['mode'] 13 | self.asso = configs['running']['asso'] 14 | self.det_score = configs['redundancy']['det_score_threshold'][self.asso] 15 | self.det_threshold = configs['redundancy']['det_dist_threshold'][self.asso] 16 | self.motion_model_type = configs['running']['motion_model'] 17 | 18 | def infer(self, trk: Tracklet, input_data: FrameData, time_lag=None): 19 | if self.mode == 'bbox': 20 | return self.bbox_redundancy(trk, input_data) 21 | elif self.mode == 'mm': 22 | return self.motion_model_redundancy(trk, input_data, time_lag) 23 | else: 24 | return self.default_redundancy(trk, input_data) 25 | 26 | def default_redundancy(self, trk: Tracklet, input_data: FrameData): 27 | """ Return the supposed state, association string, and auxiliary information 28 | """ 29 | pred_bbox = trk.get_state() 30 | return pred_bbox, 0, None 31 | 32 | def motion_model_redundancy(self, trk: Tracklet, input_data: FrameData, time_lag): 33 | # get the motion model prediction / current state 34 | pred_bbox = trk.get_state() 35 | 36 | # associate to low-score detections 37 | dists = list() 38 | dets = input_data.dets 39 | related_indexes = [i for i, det in enumerate(dets) if det.s > self.det_score] 40 | candidate_dets = [dets[i] for i in related_indexes] 41 | 42 | # association 43 | for i, det in enumerate(candidate_dets): 44 | pd_det = det 45 | 46 | if self.asso == 'iou': 47 | dists.append(utils.iou3d(pd_det, pred_bbox)[1]) 48 | elif self.asso == 'giou': 49 | dists.append(utils.giou3d(pd_det, pred_bbox)) 50 | elif self.asso == 'm_dis': 51 | trk_innovation_matrix = trk.compute_innovation_matrix() 52 | inv_innovation_matrix = np.linalg.inv(trk_innovation_matrix) 53 | dists.append(utils.m_distance(pd_det, pred_bbox, inv_innovation_matrix)) 54 | elif self.asso == 'euler': 55 | dists.append(utils.m_distance(pd_det, pred_bbox)) 56 | 57 | if self.asso in ['iou', 'giou'] and (len(dists) == 0 or np.max(dists) < self.det_threshold): 58 | result_bbox = pred_bbox 59 | update_mode = 0 # not associated in the two-stage association 60 | elif self.asso in ['m_dis', 'euler'] and (len(dists) == 0 or np.min(dists) > self.det_threshold): 61 | result_bbox = pred_bbox 62 | update_mode = 0 # not associated in the two-stage association 63 | else: 64 | result_bbox = pred_bbox 65 | update_mode = 3 # associated 66 | return result_bbox, update_mode, {'velo': np.zeros(2)} 67 | 68 | def bipartite_infer(self, input_data: FrameData, tracklets): 69 | dets = input_data.dets 70 | det_indexes = [i for i, det in enumerate(dets) if det.s >= self.det_score] 71 | dets = [dets[i] for i in det_indexes] 72 | 73 | # prediction and association 74 | trk_preds = list() 75 | for trk in tracklets: 76 | trk_preds.append(trk.predict(input_data.time_stamp, input_data.aux_info['is_key_frame'])) 77 | 78 | matched, unmatched_dets, unmatched_trks = associate_dets_to_tracks(dets, trk_preds, 79 | 'bipartite', 'giou', 1 - self.det_threshold, None) 80 | for k in 
range(len(matched)): 81 | matched[k][0] = det_indexes[matched[k][0]] 82 | for k in range(len(unmatched_dets)): 83 | unmatched_dets[k] = det_indexes[unmatched_dets[k]] 84 | 85 | result_bboxes, update_modes = [], [] 86 | for t, trk in enumerate(tracklets): 87 | if t not in unmatched_trks: 88 | for k in range(len(matched)): 89 | if matched[k][1] == t: 90 | d = matched[k][0] 91 | break 92 | result_bboxes.append(trk_preds[t]) 93 | update_modes.append(4) # associated 94 | else: 95 | result_bboxes.append(trk_preds[t]) 96 | update_modes.append(0) # not associated 97 | return result_bboxes, update_modes 98 | -------------------------------------------------------------------------------- /mot_3d/tracklet/__init__.py: -------------------------------------------------------------------------------- 1 | from ..tracklet.tracklet import Tracklet -------------------------------------------------------------------------------- /mot_3d/tracklet/tracklet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .. import motion_model 3 | from .. import life as life_manager 4 | from ..update_info_data import UpdateInfoData 5 | from ..frame_data import FrameData 6 | from ..data_protos import BBox 7 | 8 | 9 | class Tracklet: 10 | def __init__(self, configs, id, bbox: BBox, det_type, frame_index, time_stamp=None, aux_info=None): 11 | self.id = id 12 | self.time_stamp = time_stamp 13 | self.asso = configs['running']['asso'] 14 | 15 | self.configs = configs 16 | self.det_type = det_type 17 | self.aux_info = aux_info 18 | 19 | # initialize different types of motion model 20 | self.motion_model_type = configs['running']['motion_model'] 21 | # simple kalman filter 22 | if self.motion_model_type == 'kf': 23 | self.motion_model = motion_model.KalmanFilterMotionModel( 24 | bbox=bbox, inst_type=self.det_type, time_stamp=time_stamp, covariance=configs['running']['covariance']) 25 | 26 | # life and death management 27 | self.life_manager = life_manager.HitManager(configs, frame_index) 28 | # store the score for the latest bbox 29 | self.latest_score = bbox.s 30 | 31 | def predict(self, time_stamp=None, is_key_frame=True): 32 | """ in the prediction step, the motion model predicts the state of the bbox 33 | the other components have to be synced 34 | the result is a BBox 35 | 36 | the usage of time_stamp is optional, only needed if you use velocities 37 | """ 38 | result = self.motion_model.get_prediction(time_stamp=time_stamp) 39 | self.life_manager.predict(is_key_frame=is_key_frame) 40 | self.latest_score = self.latest_score * 0.01 41 | result.s = self.latest_score 42 | return result 43 | 44 | def update(self, update_info: UpdateInfoData): 45 | """ update the state of the tracklet 46 | """ 47 | self.latest_score = update_info.bbox.s 48 | is_key_frame = update_info.aux_info['is_key_frame'] 49 | 50 | # only the direct associations update the motion model 51 | if update_info.mode == 1 or update_info.mode == 3: 52 | self.motion_model.update(update_info.bbox, update_info.aux_info) 53 | else: 54 | pass 55 | self.life_manager.update(update_info, is_key_frame) 56 | return 57 | 58 | def get_state(self): 59 | """ current state of the tracklet 60 | """ 61 | result = self.motion_model.get_state() 62 | result.s = self.latest_score 63 | return result 64 | 65 | def valid_output(self, frame_index): 66 | return self.life_manager.valid_output(frame_index) 67 | 68 | def death(self, frame_index): 69 | return self.life_manager.death(frame_index) 70 | 71 | def state_string(self, frame_index): 72 
| """ the string describes how we get the bbox (e.g. by detection or motion model prediction) 73 | """ 74 | return self.life_manager.state_string(frame_index) 75 | 76 | def compute_innovation_matrix(self): 77 | """ compute the innovation matrix for association with mahalanobis distance 78 | """ 79 | return self.motion_model.compute_innovation_matrix() 80 | 81 | def sync_time_stamp(self, time_stamp): 82 | """ sync the time stamp for motion model 83 | """ 84 | self.motion_model.sync_time_stamp(time_stamp) 85 | return 86 | -------------------------------------------------------------------------------- /mot_3d/update_info_data.py: -------------------------------------------------------------------------------- 1 | """ a general interface for arranging the things inside a single tracklet 2 | data structure for updating the life cycles and states of a tracklet 3 | """ 4 | from .data_protos import BBox 5 | from . import utils 6 | import numpy as np 7 | 8 | 9 | class UpdateInfoData: 10 | def __init__(self, mode, bbox: BBox, frame_index, ego, dets=None, pc=None, aux_info=None): 11 | self.mode = mode # association state 12 | self.bbox = bbox 13 | self.ego = ego 14 | self.frame_index = frame_index 15 | self.pc = pc 16 | self.dets = dets 17 | self.aux_info = aux_info 18 | -------------------------------------------------------------------------------- /mot_3d/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .geometry import * 2 | from .data_utils import * -------------------------------------------------------------------------------- /mot_3d/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Selecting the sequences according to types 2 | # Transfer the ID from string into int if needed 3 | from ..data_protos import BBox 4 | import numpy as np 5 | 6 | 7 | __all__ = ['inst_filter', 'str2int', 'box_wrapper', 'type_filter', 'id_transform'] 8 | 9 | 10 | def str2int(strs): 11 | result = [int(s) for s in strs] 12 | return result 13 | 14 | 15 | def box_wrapper(bboxes, ids): 16 | frame_num = len(ids) 17 | result = list() 18 | for _i in range(frame_num): 19 | frame_result = list() 20 | num = len(ids[_i]) 21 | for _j in range(num): 22 | frame_result.append((ids[_i][_j], bboxes[_i][_j])) 23 | result.append(frame_result) 24 | return result 25 | 26 | 27 | def id_transform(ids): 28 | frame_num = len(ids) 29 | 30 | id_list = list() 31 | for _i in range(frame_num): 32 | id_list += ids[_i] 33 | id_list = sorted(list(set(id_list))) 34 | 35 | id_mapping = dict() 36 | for _i, id in enumerate(id_list): 37 | id_mapping[id] = _i 38 | 39 | result = list() 40 | for _i in range(frame_num): 41 | frame_ids = list() 42 | frame_id_num = len(ids[_i]) 43 | for _j in range(frame_id_num): 44 | frame_ids.append(id_mapping[ids[_i][_j]]) 45 | result.append(frame_ids) 46 | return result 47 | 48 | 49 | def inst_filter(ids, bboxes, types, type_field=[1], id_trans=False): 50 | """ filter the bboxes according to types 51 | """ 52 | frame_num = len(ids) 53 | if id_trans: 54 | ids = id_transform(ids) 55 | id_result, bbox_result = [], [] 56 | for _i in range(frame_num): 57 | frame_ids = list() 58 | frame_bboxes = list() 59 | frame_id_num = len(ids[_i]) 60 | for _j in range(frame_id_num): 61 | obj_type = types[_i][_j] 62 | matched = False 63 | for type_name in type_field: 64 | if str(type_name) in str(obj_type): 65 | matched = True 66 | if matched: 67 | frame_ids.append(ids[_i][_j]) 68 | 
frame_bboxes.append(BBox.array2bbox(bboxes[_i][_j])) 69 | id_result.append(frame_ids) 70 | bbox_result.append(frame_bboxes) 71 | return id_result, bbox_result 72 | 73 | 74 | def type_filter(contents, types, type_field=[1]): 75 | frame_num = len(types) 76 | content_result = [list() for i in range(len(type_field))] 77 | for _k, inst_type in enumerate(type_field): 78 | for _i in range(frame_num): 79 | frame_contents = list() 80 | frame_id_num = len(contents[_i]) 81 | for _j in range(frame_id_num): 82 | if types[_i][_j] != inst_type: 83 | continue 84 | frame_contents.append(contents[_i][_j]) 85 | content_result[_k].append(frame_contents) 86 | return content_result -------------------------------------------------------------------------------- /mot_3d/utils/geometry.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | from shapely.geometry import Polygon 4 | from scipy.spatial import ConvexHull, convex_hull_plot_2d 5 | import numba 6 | from ..data_protos import BBox 7 | 8 | 9 | __all__ = ['pc_in_box', 'downsample', 'pc_in_box_2D', 10 | 'apply_motion_to_points', 'make_transformation_matrix', 11 | 'iou2d', 'iou3d', 'pc2world', 'giou2d', 'giou3d', 12 | 'back_step_det', 'm_distance', 'velo2world', 'score_rectification'] 13 | 14 | 15 | def velo2world(ego_matrix, velo): 16 | """ transform local velocity [x, y] to global 17 | """ 18 | new_velo = velo[:, np.newaxis] 19 | new_velo = ego_matrix[:2, :2] @ new_velo 20 | return new_velo[:, 0] 21 | 22 | 23 | def apply_motion_to_points(points, motion, pre_move=0): 24 | transformation_matrix = make_transformation_matrix(motion) 25 | points = deepcopy(points) 26 | points = points + pre_move 27 | new_points = np.concatenate((points, 28 | np.ones(points.shape[0])[:, np.newaxis]), 29 | axis=1) 30 | 31 | new_points = transformation_matrix @ new_points.T 32 | new_points = new_points.T[:, :3] 33 | new_points -= pre_move 34 | return new_points 35 | 36 | 37 | @numba.njit 38 | def downsample(points, voxel_size=0.05): 39 | sample_dict = dict() 40 | for i in range(points.shape[0]): 41 | point_coord = np.floor(points[i] / voxel_size) 42 | sample_dict[(int(point_coord[0]), int(point_coord[1]), int(point_coord[2]))] = True 43 | res = np.zeros((len(sample_dict), 3), dtype=np.float32) 44 | idx = 0 45 | for k, v in sample_dict.items(): 46 | res[idx, 0] = k[0] * voxel_size + voxel_size / 2 47 | res[idx, 1] = k[1] * voxel_size + voxel_size / 2 48 | res[idx, 2] = k[2] * voxel_size + voxel_size / 2 49 | idx += 1 50 | return res 51 | 52 | 53 | # def pc_in_box(box, pc, box_scaling=1.5): 54 | # center_x, center_y, length, width = \ 55 | # box['center_x'], box['center_y'], box['length'], box['width'] 56 | # center_z, height = box['center_z'], box['height'] 57 | # yaw = box['heading'] 58 | 59 | # rx = np.abs((pc[:, 0] - center_x) * np.cos(yaw) + (pc[:, 1] - center_y) * np.sin(yaw)) 60 | # ry = np.abs((pc[:, 0] - center_x) * -(np.sin(yaw)) + (pc[:, 1] - center_y) * np.cos(yaw)) 61 | # rz = np.abs(pc[:, 2] - center_z) 62 | 63 | # mask_x = (rx < (length * box_scaling / 2)) 64 | # mask_y = (ry < (width * box_scaling / 2)) 65 | # mask_z = (rz < (height / 2)) 66 | 67 | # mask = mask_x * mask_y * mask_z 68 | # indices = np.argwhere(mask == 1).reshape(-1) 69 | # return pc[indices, :] 70 | 71 | 72 | # def pc_in_box_2D(box, pc, box_scaling=1.0): 73 | # center_x, center_y, length, width = \ 74 | # box['center_x'], box['center_y'], box['length'], box['width'] 75 | # yaw = box['heading'] 76 | 77 | # cos = 
np.cos(yaw) 78 | # sin = np.sin(yaw) 79 | # rx = np.abs((pc[:, 0] - center_x) * cos + (pc[:, 1] - center_y) * sin) 80 | # ry = np.abs((pc[:, 0] - center_x) * -(sin) + (pc[:, 1] - center_y) * cos) 81 | 82 | # mask_x = (rx < (length * box_scaling / 2)) 83 | # mask_y = (ry < (width * box_scaling / 2)) 84 | 85 | # mask = mask_x * mask_y 86 | # indices = np.argwhere(mask == 1).reshape(-1) 87 | # return pc[indices, :] 88 | 89 | 90 | def pc_in_box(box, pc, box_scaling=1.5): 91 | center_x, center_y, length, width = \ 92 | box.x, box.y, box.l, box.w 93 | center_z, height = box.z, box.h 94 | yaw = box.o 95 | return pc_in_box_inner(center_x, center_y, center_z, length, width, height, yaw, pc, box_scaling) 96 | 97 | 98 | @numba.njit 99 | def pc_in_box_inner(center_x, center_y, center_z, length, width, height, yaw, pc, box_scaling=1.5): 100 | mask = np.zeros(pc.shape[0], dtype=np.int32) 101 | yaw_cos, yaw_sin = np.cos(yaw), np.sin(yaw) 102 | for i in range(pc.shape[0]): 103 | rx = np.abs((pc[i, 0] - center_x) * yaw_cos + (pc[i, 1] - center_y) * yaw_sin) 104 | ry = np.abs((pc[i, 0] - center_x) * -yaw_sin + (pc[i, 1] - center_y) * yaw_cos) 105 | rz = np.abs(pc[i, 2] - center_z) 106 | 107 | if rx < (length * box_scaling / 2) and ry < (width * box_scaling / 2) and rz < (height * box_scaling / 2): 108 | mask[i] = 1 109 | indices = np.argwhere(mask == 1) 110 | result = np.zeros((indices.shape[0], 3), dtype=np.float64) 111 | for i in range(indices.shape[0]): 112 | result[i, :] = pc[indices[i], :] 113 | return result 114 | 115 | 116 | def pc_in_box_2D(box, pc, box_scaling=1.0): 117 | center_x, center_y, length, width = \ 118 | box.x, box.y, box.l, box.w 119 | center_z, height = box.z, box.h 120 | yaw = box.o 121 | return pc_in_box_2D_inner(center_x, center_y, center_z, length, width, height, yaw, pc, box_scaling) 122 | 123 | 124 | @numba.njit 125 | def pc_in_box_2D_inner(center_x, center_y, center_z, length, width, height, yaw, pc, box_scaling=1.0): 126 | mask = np.zeros(pc.shape[0], dtype=np.int32) 127 | yaw_cos, yaw_sin = np.cos(yaw), np.sin(yaw) 128 | for i in range(pc.shape[0]): 129 | rx = np.abs((pc[i, 0] - center_x) * yaw_cos + (pc[i, 1] - center_y) * yaw_sin) 130 | ry = np.abs((pc[i, 0] - center_x) * -yaw_sin + (pc[i, 1] - center_y) * yaw_cos) 131 | 132 | if rx < (length * box_scaling / 2) and ry < (width * box_scaling / 2): 133 | mask[i] = 1 134 | indices = np.argwhere(mask == 1) 135 | result = np.zeros((indices.shape[0], 3), dtype=np.float64) 136 | for i in range(indices.shape[0]): 137 | result[i, :] = pc[indices[i], :] 138 | return result 139 | 140 | 141 | def make_transformation_matrix(motion): 142 | x, y, z, theta = motion 143 | transformation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0, x], 144 | [np.sin(theta), np.cos(theta), 0, y], 145 | [0 , 0 , 1, z], 146 | [0 , 0 , 0, 1]]) 147 | return transformation_matrix 148 | 149 | 150 | def iou2d(box_a, box_b): 151 | boxa_corners = np.array(BBox.box2corners2d(box_a))[:, :2] 152 | boxb_corners = np.array(BBox.box2corners2d(box_b))[:, :2] 153 | reca, recb = Polygon(boxa_corners), Polygon(boxb_corners) 154 | overlap = reca.intersection(recb).area 155 | area_a = reca.area 156 | area_b = recb.area 157 | iou = overlap / (area_a + area_b - overlap + 1e-10) 158 | return iou 159 | 160 | 161 | def iou3d(box_a, box_b): 162 | boxa_corners = np.array(BBox.box2corners2d(box_a))[:, :2] # keep x-y only, consistent with boxb below 163 | boxb_corners = np.array(BBox.box2corners2d(box_b))[:, :2] 164 | reca, recb = Polygon(boxa_corners), Polygon(boxb_corners) 165 | overlap_area = reca.intersection(recb).area 166 | 
iou_2d = overlap_area / (reca.area + recb.area - overlap_area) 167 | 168 | ha, hb = box_a.h, box_b.h 169 | za, zb = box_a.z, box_b.z 170 | overlap_height = max(0, min((za + ha / 2) - (zb - hb / 2), (zb + hb / 2) - (za - ha / 2))) 171 | overlap_volume = overlap_area * overlap_height 172 | union_volume = box_a.w * box_a.l * ha + box_b.w * box_b.l * hb - overlap_volume 173 | iou_3d = overlap_volume / (union_volume + 1e-5) 174 | 175 | return iou_2d, iou_3d 176 | 177 | 178 | def pc2world(ego_matrix, pcs): 179 | new_pcs = np.concatenate((pcs, 180 | np.ones(pcs.shape[0])[:, np.newaxis]), 181 | axis=1) 182 | new_pcs = ego_matrix @ new_pcs.T 183 | new_pcs = new_pcs.T[:, :3] 184 | return new_pcs 185 | 186 | 187 | def giou2d(box_a: BBox, box_b: BBox): 188 | boxa_corners = np.array(BBox.box2corners2d(box_a)) 189 | boxb_corners = np.array(BBox.box2corners2d(box_b)) 190 | reca, recb = Polygon(boxa_corners), Polygon(boxb_corners) 191 | 192 | # compute intersection and union 193 | I = reca.intersection(recb).area 194 | U = box_a.w * box_a.l + box_b.w * box_b.l - I 195 | 196 | # compute the convex area 197 | all_corners = np.vstack((boxa_corners, boxb_corners)) 198 | C = ConvexHull(all_corners) 199 | convex_corners = all_corners[C.vertices] 200 | convex_area = PolyArea2D(convex_corners) 201 | C = convex_area 202 | 203 | # compute giou 204 | return I / U - (C - U) / C 205 | 206 | 207 | def giou3d(box_a: BBox, box_b: BBox): 208 | boxa_corners = np.array(BBox.box2corners2d(box_a))[:, :2] 209 | boxb_corners = np.array(BBox.box2corners2d(box_b))[:, :2] 210 | reca, recb = Polygon(boxa_corners), Polygon(boxb_corners) 211 | ha, hb = box_a.h, box_b.h 212 | za, zb = box_a.z, box_b.z 213 | overlap_height = max(0, min([(za + ha / 2) - (zb - hb / 2), (zb + hb / 2) - (za - ha / 2), ha, hb])) 214 | union_height = max([(za + ha / 2) - (zb - hb / 2), (zb + hb / 2) - (za - ha / 2), ha, hb]) 215 | 216 | # compute intersection and union 217 | I = reca.intersection(recb).area * overlap_height 218 | U = box_a.w * box_a.l * ha + box_b.w * box_b.l * hb - I 219 | 220 | # compute the convex area 221 | all_corners = np.vstack((boxa_corners, boxb_corners)) 222 | C = ConvexHull(all_corners) 223 | convex_corners = all_corners[C.vertices] 224 | convex_area = PolyArea2D(convex_corners) 225 | C = convex_area * union_height 226 | 227 | # compute giou 228 | giou = I / U - (C - U) / C 229 | return giou 230 | 231 | 232 | def PolyArea2D(pts): 233 | roll_pts = np.roll(pts, -1, axis=0) 234 | area = np.abs(np.sum((pts[:, 0] * roll_pts[:, 1] - pts[:, 1] * roll_pts[:, 0]))) * 0.5 235 | return area 236 | 237 | 238 | def back_step_det(det: BBox, velo, time_lag): 239 | result = BBox() 240 | BBox.copy_bbox(result, det) 241 | result.x -= (time_lag * velo[0]) 242 | result.y -= (time_lag * velo[1]) 243 | return result 244 | 245 | 246 | def diff_orientation_correction(diff): 247 | """ 248 | return the angle diff = det - trk 249 | if angle diff > 90 or < -90, rotate trk and update the angle diff 250 | """ 251 | if diff > np.pi / 2: 252 | diff -= np.pi 253 | if diff < -np.pi / 2: 254 | diff += np.pi 255 | return diff 256 | 257 | 258 | def m_distance(det, trk, trk_inv_innovation_matrix=None): 259 | det_array = BBox.bbox2array(det)[:7] 260 | trk_array = BBox.bbox2array(trk)[:7] 261 | 262 | diff = np.expand_dims(det_array - trk_array, axis=1) 263 | corrected_yaw_diff = diff_orientation_correction(diff[3]) 264 | diff[3] = corrected_yaw_diff 265 | 266 | if trk_inv_innovation_matrix is not None: 267 | result = \ 268 | np.sqrt(np.matmul(np.matmul(diff.T, 
trk_inv_innovation_matrix), diff)[0][0]) 269 | else: 270 | result = np.sqrt(np.dot(diff.T, diff)) 271 | return result 272 | 273 | 274 | def score_rectification(dets, gts): 275 | """ rectify the scores of detections according to their 3d iou with gts 276 | """ 277 | result = deepcopy(dets) 278 | 279 | if len(gts) == 0: 280 | for i, _ in enumerate(dets): 281 | result[i].s = 0.0 282 | return result 283 | 284 | if len(dets) == 0: 285 | return result 286 | 287 | iou_matrix = np.zeros((len(dets), len(gts))) 288 | for i, d in enumerate(dets): 289 | for j, g in enumerate(gts): 290 | iou_matrix[i, j] = iou3d(d, g)[1] 291 | max_index = np.argmax(iou_matrix, axis=1) 292 | max_iou = np.max(iou_matrix, axis=1) 293 | index = list(reversed(sorted(range(len(dets)), key=lambda k:max_iou[k]))) 294 | 295 | matched_gt = [] 296 | for i in index: 297 | if max_iou[i] >= 0.1 and max_index[i] not in matched_gt: 298 | result[i].s = max_iou[i] 299 | matched_gt.append(max_index[i]) 300 | elif max_iou[i] >= 0.1 and max_index[i] in matched_gt: 301 | result[i].s = 0.2 302 | else: 303 | result[i].s = 0.05 304 | 305 | return result 306 | -------------------------------------------------------------------------------- /mot_3d/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualizer2d import Visualizer2D -------------------------------------------------------------------------------- /mot_3d/visualization/visualizer2d.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt, numpy as np 2 | from ..data_protos import BBox 3 | 4 | 5 | class Visualizer2D: 6 | def __init__(self, name='', figsize=(8, 8)): 7 | self.figure = plt.figure(name, figsize=figsize) 8 | plt.axis('equal') 9 | self.COLOR_MAP = { 10 | 'gray': np.array([140, 140, 136]) / 256, 11 | 'light_blue': np.array([4, 157, 217]) / 256, 12 | 'red': np.array([191, 4, 54]) / 256, 13 | 'black': np.array([0, 0, 0]) / 256, 14 | 'purple': np.array([224, 133, 250]) / 256, 15 | 'dark_green': np.array([32, 64, 40]) / 256, 16 | 'green': np.array([77, 115, 67]) / 256 17 | } 18 | 19 | def show(self): 20 | plt.show() 21 | 22 | def close(self): 23 | plt.close() 24 | 25 | def save(self, path): 26 | plt.savefig(path) 27 | 28 | def handler_pc(self, pc, color='gray'): 29 | vis_pc = np.asarray(pc) 30 | plt.scatter(vis_pc[:, 0], vis_pc[:, 1], marker='o', color=self.COLOR_MAP[color], s=0.01) 31 | 32 | def handler_box(self, box: BBox, message: str='', color='red', linestyle='solid'): 33 | corners = np.array(BBox.box2corners2d(box))[:, :2] 34 | corners = np.concatenate([corners, corners[0:1, :2]]) 35 | plt.plot(corners[:, 0], corners[:, 1], color=self.COLOR_MAP[color], linestyle=linestyle) 36 | corner_index = np.random.randint(0, 4, 1) 37 | plt.text(corners[corner_index, 0] - 1, corners[corner_index, 1] - 1, message, color=self.COLOR_MAP[color]) -------------------------------------------------------------------------------- /preprocessing/nuscenes_data/detection.py: -------------------------------------------------------------------------------- 1 | import os, argparse, numpy as np, json 2 | from tqdm import tqdm 3 | 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--raw_data_folder', type=str, default='../../../raw/nuscenes/data/sets/nuscenes/') 7 | parser.add_argument('--data_folder', type=str, default='../../../datasets/nuscenes/') 8 | parser.add_argument('--det_name', type=str, default='cp') 9 | parser.add_argument('--file_path', 
type=str, default='val.json') 10 | parser.add_argument('--velo', action='store_true', default=False) 11 | parser.add_argument('--mode', type=str, default='2hz', choices=['20hz', '2hz']) 12 | args = parser.parse_args() 13 | 14 | 15 | def get_sample_tokens(data_folder, mode): 16 | token_folder = os.path.join(data_folder, 'token_info') 17 | file_names = sorted(os.listdir(token_folder)) 18 | result = dict() 19 | for i, file_name in enumerate(file_names): 20 | file_path = os.path.join(token_folder, file_name) 21 | scene_name = file_name.split('.')[0] 22 | tokens = json.load(open(file_path, 'r')) 23 | 24 | if mode == '2hz': 25 | result[scene_name] = tokens 26 | elif mode == '20hz': 27 | result[scene_name] = [t[0] for t in tokens] 28 | return result 29 | 30 | 31 | def sample_result2bbox_array(sample): 32 | trans, size, rot, score = \ 33 | sample['translation'], sample['size'],sample['rotation'], sample['detection_score'] 34 | return trans + size + rot + [score] 35 | 36 | 37 | def main(det_name, file_path, detection_folder, data_folder, mode): 38 | # dealing with the paths 39 | detection_folder = os.path.join(detection_folder, det_name) 40 | output_folder = os.path.join(detection_folder, 'dets') 41 | os.makedirs(output_folder, exist_ok=True) 42 | 43 | # load the detection file 44 | print('LOADING RAW FILE') 45 | f = open(file_path, 'r') 46 | det_data = json.load(f)['results'] 47 | f.close() 48 | 49 | # prepare the scene names and all the related tokens 50 | tokens = get_sample_tokens(data_folder, mode) 51 | scene_names = sorted(list(tokens.keys())) 52 | bboxes, inst_types, velos = dict(), dict(), dict() 53 | for scene_name in scene_names: 54 | frame_num = len(tokens[scene_name]) 55 | bboxes[scene_name], inst_types[scene_name] = \ 56 | [[] for i in range(frame_num)], [[] for i in range(frame_num)] 57 | if args.velo: 58 | velos[scene_name] = [[] for i in range(frame_num)] 59 | 60 | # enumerate through all the frames 61 | sample_keys = list(det_data.keys()) 62 | print('PROCESSING...') 63 | pbar = tqdm(total=len(sample_keys)) 64 | for sample_index, sample_key in enumerate(sample_keys): 65 | # find the corresponding scene and frame index 66 | scene_name, frame_index = None, None 67 | for scene_name in scene_names: 68 | if sample_key in tokens[scene_name]: 69 | frame_index = tokens[scene_name].index(sample_key) 70 | break 71 | 72 | # extract the bboxes and types 73 | sample_results = det_data[sample_key] 74 | for sample in sample_results: 75 | bbox, inst_type = sample_result2bbox_array(sample), sample['detection_name'] 76 | inst_velo = sample['velocity'] 77 | bboxes[scene_name][frame_index] += [bbox] 78 | inst_types[scene_name][frame_index] += [inst_type] 79 | 80 | if args.velo: 81 | velos[scene_name][frame_index] += [inst_velo] 82 | pbar.update(1) 83 | pbar.close() 84 | 85 | # save the results 86 | print('SAVING...') 87 | pbar = tqdm(total=len(scene_names)) 88 | for scene_name in scene_names: 89 | if args.velo: 90 | np.savez_compressed(os.path.join(output_folder, '{:}.npz'.format(scene_name)), 91 | bboxes=bboxes[scene_name], types=inst_types[scene_name], velos=velos[scene_name]) 92 | else: 93 | np.savez_compressed(os.path.join(output_folder, '{:}.npz'.format(scene_name)), 94 | bboxes=bboxes[scene_name], types=inst_types[scene_name]) 95 | pbar.update(1) 96 | pbar.close() 97 | return 98 | 99 | 100 | if __name__ == '__main__': 101 | detection_folder = os.path.join(args.data_folder, 'detection') 102 | os.makedirs(detection_folder, exist_ok=True) 103 | 104 | main(args.det_name, args.file_path, 
detection_folder, args.data_folder, args.mode) -------------------------------------------------------------------------------- /preprocessing/nuscenes_data/ego_pose.py: -------------------------------------------------------------------------------- 1 | import os, numpy as np, nuscenes, argparse 2 | from nuscenes.nuscenes import NuScenes 3 | from nuscenes.utils import splits 4 | from copy import deepcopy 5 | from tqdm import tqdm 6 | 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--raw_data_folder', type=str, default='../../../raw/nuscenes/data/sets/nuscenes/') 10 | parser.add_argument('--data_folder', type=str, default='../../../datasets/nuscenes/') 11 | parser.add_argument('--mode', type=str, default='2hz', choices=['20hz', '2hz']) 12 | args = parser.parse_args() 13 | 14 | 15 | def main(nusc, scene_names, root_path, ego_folder, mode): 16 | pbar = tqdm(total=len(scene_names)) 17 | for scene_index, scene_info in enumerate(nusc.scene): 18 | scene_name = scene_info['name'] 19 | if scene_name not in scene_names: 20 | continue 21 | first_sample_token = scene_info['first_sample_token'] 22 | last_sample_token = scene_info['last_sample_token'] 23 | frame_data = nusc.get('sample', first_sample_token) 24 | if args.mode == '20hz': 25 | cur_sample_token = frame_data['data']['LIDAR_TOP'] 26 | elif args.mode == '2hz': 27 | cur_sample_token = deepcopy(first_sample_token) 28 | frame_index = 0 29 | ego_data = dict() 30 | while True: 31 | if mode == '2hz': 32 | frame_data = nusc.get('sample', cur_sample_token) 33 | lidar_token = frame_data['data']['LIDAR_TOP'] 34 | lidar_data = nusc.get('sample_data', lidar_token) 35 | ego_token = lidar_data['ego_pose_token'] 36 | ego_pose = nusc.get('ego_pose', ego_token) 37 | elif mode == '20hz': 38 | frame_data = nusc.get('sample_data', cur_sample_token) 39 | ego_token = frame_data['ego_pose_token'] 40 | ego_pose = nusc.get('ego_pose', ego_token) 41 | 42 | # translation + rotation 43 | ego_data[str(frame_index)] = ego_pose['translation'] + ego_pose['rotation'] 44 | 45 | # clean up and prepare for the next 46 | cur_sample_token = frame_data['next'] 47 | if cur_sample_token == '': 48 | break 49 | frame_index += 1 50 | 51 | np.savez_compressed(os.path.join(ego_folder, '{:}.npz'.format(scene_name)), **ego_data) 52 | pbar.update(1) 53 | pbar.close() 54 | return 55 | 56 | 57 | if __name__ == '__main__': 58 | print('ego info') 59 | ego_folder = os.path.join(args.data_folder, 'ego_info') 60 | os.makedirs(ego_folder, exist_ok=True) 61 | 62 | val_scene_names = splits.create_splits_scenes()['val'] 63 | nusc = NuScenes(version='v1.0-trainval', dataroot=args.raw_data_folder, verbose=True) 64 | main(nusc, val_scene_names, args.raw_data_folder, ego_folder, args.mode) 65 | -------------------------------------------------------------------------------- /preprocessing/nuscenes_data/gt_info.py: -------------------------------------------------------------------------------- 1 | import os, numpy as np, nuscenes, argparse 2 | from nuscenes.nuscenes import NuScenes 3 | from nuscenes.utils import splits 4 | from copy import deepcopy 5 | from tqdm import tqdm 6 | import pdb 7 | 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--raw_data_folder', type=str, default='../../../raw/nuscenes/data/sets/nuscenes/') 11 | parser.add_argument('--data_folder', type=str, default='../../../datasets/nuscenes/') 12 | parser.add_argument('--mode', type=str, default='2hz', choices=['20hz', '2hz']) 13 | args = parser.parse_args() 14 | 15 | 16 | def 
instance_info2bbox_array(info): 17 | translation = info.center.tolist() 18 | size = info.wlh.tolist() 19 | rotation = info.orientation.q.tolist() 20 | return translation + size + rotation 21 | 22 | 23 | def main(nusc, scene_names, root_path, gt_folder): 24 | pbar = tqdm(total=len(scene_names)) 25 | for scene_index, scene_info in enumerate(nusc.scene): 26 | scene_name = scene_info['name'] 27 | if scene_name not in scene_names: 28 | continue 29 | 30 | first_sample_token = scene_info['first_sample_token'] 31 | last_sample_token = scene_info['last_sample_token'] 32 | frame_data = nusc.get('sample', first_sample_token) 33 | if args.mode == '20hz': 34 | cur_sample_token = frame_data['data']['LIDAR_TOP'] 35 | elif args.mode == '2hz': 36 | cur_sample_token = deepcopy(first_sample_token) 37 | 38 | frame_index = 0 39 | IDS, inst_types, bboxes = list(), list(), list() 40 | while True: 41 | frame_ids, frame_types, frame_bboxes = list(), list(), list() 42 | if args.mode == '2hz': 43 | frame_data = nusc.get('sample', cur_sample_token) 44 | lidar_token = frame_data['data']['LIDAR_TOP'] 45 | instances = nusc.get_boxes(lidar_token) 46 | for inst in instances: 47 | frame_ids.append(inst.token) 48 | frame_types.append(inst.name) 49 | frame_bboxes.append(instance_info2bbox_array(inst)) 50 | 51 | elif args.mode == '20hz': 52 | frame_data = nusc.get('sample_data', cur_sample_token) 53 | lidar_data = nusc.get('sample_data', cur_sample_token) 54 | instances = nusc.get_boxes(lidar_data['token']) 55 | for inst in instances: 56 | frame_ids.append(inst.token) 57 | frame_types.append(inst.name) 58 | frame_bboxes.append(instance_info2bbox_array(inst)) 59 | 60 | IDS.append(frame_ids) 61 | inst_types.append(frame_types) 62 | bboxes.append(frame_bboxes) 63 | 64 | # clean up and prepare for the next 65 | cur_sample_token = frame_data['next'] 66 | if cur_sample_token == '': 67 | break 68 | 69 | np.savez_compressed(os.path.join(gt_folder, '{:}.npz'.format(scene_name)), 70 | ids=IDS, types=inst_types, bboxes=bboxes) 71 | pbar.update(1) 72 | pbar.close() 73 | return 74 | 75 | 76 | if __name__ == '__main__': 77 | print('gt info') 78 | gt_folder = os.path.join(args.data_folder, 'gt_info') 79 | os.makedirs(gt_folder, exist_ok=True) 80 | 81 | val_scene_names = splits.create_splits_scenes()['val'] 82 | nusc = NuScenes(version='v1.0-trainval', dataroot=args.raw_data_folder, verbose=True) 83 | main(nusc, val_scene_names, args.raw_data_folder, gt_folder) 84 | -------------------------------------------------------------------------------- /preprocessing/nuscenes_data/nuscenes_preprocess.sh: -------------------------------------------------------------------------------- 1 | raw_data_dir=$1 2 | data_dir_2hz=$2 3 | data_dir_20hz=$3 4 | 5 | # token information 6 | python token_info.py --raw_data_folder $raw_data_dir --data_folder $data_dir_2hz --mode 2hz 7 | python token_info.py --raw_data_folder $raw_data_dir --data_folder $data_dir_20hz --mode 20hz 8 | 9 | # time stamp information 10 | python time_stamp.py --raw_data_folder $raw_data_dir --data_folder $data_dir_2hz --mode 2hz 11 | python time_stamp.py --raw_data_folder $raw_data_dir --data_folder $data_dir_20hz --mode 20hz 12 | 13 | # sensor calibration information 14 | python sensor_calibration.py --raw_data_folder $raw_data_dir --data_folder $data_dir_2hz --mode 2hz 15 | python sensor_calibration.py --raw_data_folder $raw_data_dir --data_folder $data_dir_20hz --mode 20hz 16 | 17 | # ego pose 18 | python ego_pose.py --raw_data_folder $raw_data_dir --data_folder $data_dir_2hz --mode 
2hz 19 | python ego_pose.py --raw_data_folder $raw_data_dir --data_folder $data_dir_20hz --mode 20hz 20 | 21 | # gt information 22 | python gt_info.py --raw_data_folder $raw_data_dir --data_folder $data_dir_2hz --mode 2hz 23 | python gt_info.py --raw_data_folder $raw_data_dir --data_folder $data_dir_20hz --mode 20hz 24 | 25 | # point cloud, useful for visualization 26 | python raw_pc.py --raw_data_folder $raw_data_dir --data_folder $data_dir_2hz --mode 2hz 27 | python raw_pc.py --raw_data_folder $raw_data_dir --data_folder $data_dir_20hz --mode 20hz 28 | 29 | -------------------------------------------------------------------------------- /preprocessing/nuscenes_data/raw_pc.py: -------------------------------------------------------------------------------- 1 | import os, numpy as np, nuscenes, argparse 2 | from nuscenes.nuscenes import NuScenes 3 | from nuscenes.utils import splits 4 | from copy import deepcopy 5 | import matplotlib.pyplot as plt 6 | import multiprocessing 7 | 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--raw_data_folder', type=str, default='../../../raw/nuscenes/data/sets/nuscenes/') 11 | parser.add_argument('--data_folder', type=str, default='../../../datasets/nuscenes/') 12 | parser.add_argument('--mode', type=str, default='2hz', choices=['20hz', '2hz']) 13 | parser.add_argument('--process', type=int, default=1) 14 | args = parser.parse_args() 15 | 16 | 17 | def load_pc(path): 18 | pc = np.fromfile(path, dtype=np.float32) 19 | pc = pc.reshape((-1, 5))[:, :4] 20 | return pc 21 | 22 | 23 | def main(nusc, scene_names, root_path, pc_folder, mode, pid=0, process=1): 24 | for scene_index, scene_info in enumerate(nusc.scene): 25 | if scene_index % process != pid: 26 | continue 27 | scene_name = scene_info['name'] 28 | if scene_name not in scene_names: 29 | continue 30 | print('PROCESSING {:} / {:}'.format(scene_index + 1, len(nusc.scene))) 31 | 32 | first_sample_token = scene_info['first_sample_token'] 33 | frame_data = nusc.get('sample', first_sample_token) 34 | if mode == '20hz': 35 | cur_sample_token = frame_data['data']['LIDAR_TOP'] 36 | elif mode == '2hz': 37 | cur_sample_token = deepcopy(first_sample_token) 38 | frame_index = 0 39 | pc_data = dict() 40 | while True: 41 | # find the path to lidar data 42 | if mode == '2hz': 43 | lidar_data = nusc.get('sample', cur_sample_token) 44 | lidar_path = nusc.get_sample_data_path(lidar_data['data']['LIDAR_TOP']) 45 | elif args.mode == '20hz': 46 | lidar_data = nusc.get('sample_data', cur_sample_token) 47 | lidar_path = lidar_data['filename'] 48 | 49 | # load and store the data 50 | point_cloud = np.fromfile(os.path.join(root_path, lidar_path), dtype=np.float32) 51 | point_cloud = np.reshape(point_cloud, (-1, 5))[:, :4] 52 | pc_data[str(frame_index)] = point_cloud 53 | 54 | # clean up and prepare for the next 55 | cur_sample_token = lidar_data['next'] 56 | if cur_sample_token == '': 57 | break 58 | frame_index += 1 59 | 60 | if frame_index % 10 == 0: 61 | print('PROCESSING ', scene_index, ' , ', frame_index) 62 | 63 | np.savez_compressed(os.path.join(pc_folder, '{:}.npz'.format(scene_name)), **pc_data) 64 | return 65 | 66 | 67 | if __name__ == '__main__': 68 | print('point cloud') 69 | pc_folder = os.path.join(args.data_folder, 'pc', 'raw_pc') 70 | os.makedirs(pc_folder, exist_ok=True) 71 | 72 | val_scene_names = splits.create_splits_scenes()['val'] 73 | nusc = NuScenes(version='v1.0-trainval', dataroot=args.raw_data_folder, verbose=True) 74 | main(nusc, val_scene_names, args.raw_data_folder, pc_folder, 
args.mode, 0, 1) 75 | -------------------------------------------------------------------------------- /preprocessing/nuscenes_data/sensor_calibration.py: -------------------------------------------------------------------------------- 1 | import os, numpy as np, nuscenes, argparse 2 | from nuscenes.nuscenes import NuScenes 3 | from nuscenes.utils import splits 4 | from copy import deepcopy 5 | from tqdm import tqdm 6 | 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--raw_data_folder', type=str, default='../../../raw/nuscenes/data/sets/nuscenes/') 10 | parser.add_argument('--data_folder', type=str, default='../../../datasets/nuscenes/') 11 | parser.add_argument('--mode', type=str, default='2hz', choices=['20hz', '2hz']) 12 | args = parser.parse_args() 13 | 14 | 15 | def main(nusc, scene_names, root_path, calib_folder, mode): 16 | pbar = tqdm(total=len(scene_names)) 17 | for scene_index, scene_info in enumerate(nusc.scene): 18 | scene_name = scene_info['name'] 19 | if scene_name not in scene_names: 20 | continue 21 | 22 | first_sample_token = scene_info['first_sample_token'] 23 | last_sample_token = scene_info['last_sample_token'] 24 | frame_data = nusc.get('sample', first_sample_token) 25 | if mode == '20hz': 26 | cur_sample_token = frame_data['data']['LIDAR_TOP'] 27 | elif mode == '2hz': 28 | cur_sample_token = deepcopy(first_sample_token) 29 | frame_index = 0 30 | calib_data = dict() 31 | while True: 32 | if mode == '2hz': 33 | frame_data = nusc.get('sample', cur_sample_token) 34 | lidar_token = frame_data['data']['LIDAR_TOP'] 35 | lidar_data = nusc.get('sample_data', lidar_token) 36 | calib_token = lidar_data['calibrated_sensor_token'] 37 | calib_pose = nusc.get('calibrated_sensor', calib_token) 38 | elif mode == '20hz': 39 | frame_data = nusc.get('sample_data', cur_sample_token) 40 | calib_token = frame_data['calibrated_sensor_token'] 41 | calib_pose = nusc.get('calibrated_sensor', calib_token) 42 | 43 | # translation + rotation 44 | calib_data[str(frame_index)] = calib_pose['translation'] + calib_pose['rotation'] 45 | 46 | # clean up and prepare for the next 47 | cur_sample_token = frame_data['next'] 48 | if cur_sample_token == '': 49 | break 50 | frame_index += 1 51 | 52 | np.savez_compressed(os.path.join(calib_folder, '{:}.npz'.format(scene_name)), **calib_data) 53 | pbar.update(1) 54 | pbar.close() 55 | return 56 | 57 | 58 | if __name__ == '__main__': 59 | print('sensor calib') 60 | calib_folder = os.path.join(args.data_folder, 'calib_info') 61 | os.makedirs(calib_folder, exist_ok=True) 62 | 63 | val_scene_names = splits.create_splits_scenes()['val'] 64 | nusc = NuScenes(version='v1.0-trainval', dataroot=args.raw_data_folder, verbose=True) 65 | main(nusc, val_scene_names, args.raw_data_folder, calib_folder, args.mode) 66 | -------------------------------------------------------------------------------- /preprocessing/nuscenes_data/time_stamp.py: -------------------------------------------------------------------------------- 1 | import os, numpy as np, nuscenes, argparse, json 2 | from nuscenes.nuscenes import NuScenes 3 | from nuscenes.utils import splits 4 | from copy import deepcopy 5 | from tqdm import tqdm 6 | 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--raw_data_folder', type=str, default='../../../raw/nuscenes/data/sets/nuscenes/') 10 | parser.add_argument('--data_folder', type=str, default='../../../datasets/nuscenes/') 11 | parser.add_argument('--mode', type=str, default='2hz', choices=['20hz', '2hz']) 12 | args = 
parser.parse_args() 13 | 14 | 15 | def main(nusc, scene_names, root_path, ts_folder, mode): 16 | pbar = tqdm(total=len(scene_names)) 17 | for scene_index, scene_info in enumerate(nusc.scene): 18 | scene_name = scene_info['name'] 19 | if scene_name not in scene_names: 20 | continue 21 | 22 | first_sample_token = scene_info['first_sample_token'] 23 | last_sample_token = scene_info['last_sample_token'] 24 | frame_data = nusc.get('sample', first_sample_token) 25 | if mode == '20hz': 26 | cur_sample_token = frame_data['data']['LIDAR_TOP'] 27 | elif mode == '2hz': 28 | cur_sample_token = deepcopy(first_sample_token) 29 | time_stamps = list() 30 | 31 | while True: 32 | if mode == '2hz': 33 | frame_data = nusc.get('sample', cur_sample_token) 34 | time_stamps.append(frame_data['timestamp']) 35 | elif mode == '20hz': 36 | frame_data = nusc.get('sample_data', cur_sample_token) 37 | # time stamp and whether this is a key frame 38 | time_stamps.append((frame_data['timestamp'], frame_data['is_key_frame'])) 39 | 40 | # clean up and prepare for the next 41 | cur_sample_token = frame_data['next'] 42 | if cur_sample_token == '': 43 | break 44 | f = open(os.path.join(ts_folder, '{:}.json'.format(scene_name)), 'w') 45 | json.dump(time_stamps, f) 46 | f.close() 47 | pbar.update(1) 48 | pbar.close() 49 | return 50 | 51 | 52 | if __name__ == '__main__': 53 | print('time stamp') 54 | ts_folder = os.path.join(args.data_folder, 'ts_info') 55 | os.makedirs(ts_folder, exist_ok=True) 56 | 57 | val_scene_names = splits.create_splits_scenes()['val'] 58 | nusc = NuScenes(version='v1.0-trainval', dataroot=args.raw_data_folder, verbose=True) 59 | main(nusc, val_scene_names, args.raw_data_folder, ts_folder, args.mode) 60 | -------------------------------------------------------------------------------- /preprocessing/nuscenes_data/token_info.py: -------------------------------------------------------------------------------- 1 | import os, numpy as np, nuscenes, argparse, json 2 | from nuscenes.nuscenes import NuScenes 3 | from nuscenes.utils import splits 4 | from copy import deepcopy 5 | from tqdm import tqdm 6 | 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--raw_data_folder', type=str, default='../../../raw/nuscenes/data/sets/nuscenes/') 10 | parser.add_argument('--data_folder', type=str, default='../../../datasets/nuscenes/') 11 | parser.add_argument('--mode', type=str, default='2hz', choices=['20hz', '2hz']) 12 | args = parser.parse_args() 13 | 14 | 15 | def set_selected_or_not(frame_tokens): 16 | """ under the 20hz setting, 17 | we have to set whether to use a certain frame 18 | 1. select every other frame, downsampling the 20hz stream to roughly 10hz 19 | 2. when a key frame is met, keep it and reset the counter 20 | """
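    # Illustrative walk-through (an annotation, not behavior of its own): for a
    # 20hz clip whose frames have is_key_frame = [True, False, False, False, False],
    # the loop below yields selected = [True, False, True, False, True] --
    # key frames are always kept, and every other intermediate frame is kept
    # afterwards, so the 20hz stream is downsampled to roughly 10hz.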
21 | counter = -1 22 | selected = list() 23 | frame_num = len(frame_tokens) 24 | for tokens in frame_tokens: 25 | is_key_frame = tokens[1] 26 | counter += 1 27 | if is_key_frame: 28 | selected.append(True) 29 | counter = 0 30 | continue 31 | else: 32 | if counter % 2 == 0: 33 | selected.append(True) 34 | else: 35 | selected.append(False) 36 | result_tokens = [(list(frame_tokens[i]) + [selected[i]]) for i in range(frame_num)] 37 | return result_tokens 38 | 39 | 40 | def main(nusc, scene_names, root_path, token_folder, mode): 41 | pbar = tqdm(total=len(scene_names)) 42 | for scene_index, scene_info in enumerate(nusc.scene): 43 | scene_name = scene_info['name'] 44 | if scene_name not in scene_names: 45 | continue 46 | 47 | first_sample_token = scene_info['first_sample_token'] 48 | last_sample_token = scene_info['last_sample_token'] 49 | frame_data = nusc.get('sample', first_sample_token) 50 | if mode == '20hz': 51 | cur_sample_token = frame_data['data']['LIDAR_TOP'] 52 | elif mode == '2hz': 53 | cur_sample_token = deepcopy(first_sample_token) 54 | frame_tokens = list() 55 | 56 | while True: 57 | # find the path to lidar data 58 | if mode == '2hz': 59 | frame_data = nusc.get('sample', cur_sample_token) 60 | frame_tokens.append(cur_sample_token) 61 | elif mode == '20hz': 62 | frame_data = nusc.get('sample_data', cur_sample_token) 63 | frame_tokens.append((cur_sample_token, frame_data['is_key_frame'], frame_data['sample_token'])) 64 | 65 | # clean up and prepare for the next 66 | cur_sample_token = frame_data['next'] 67 | if cur_sample_token == '': 68 | break 69 | 70 | if mode == '20hz': 71 | frame_tokens = set_selected_or_not(frame_tokens) 72 | f = open(os.path.join(token_folder, '{:}.json'.format(scene_name)), 'w') 73 | json.dump(frame_tokens, f) 74 | f.close() 75 | 76 | pbar.update(1) 77 | pbar.close() 78 | return 79 | 80 | 81 | if __name__ == '__main__': 82 | print('token info') 83 | os.makedirs(args.data_folder, exist_ok=True) 84 | 85 | token_folder = os.path.join(args.data_folder, 'token_info') 86 | os.makedirs(token_folder, exist_ok=True) 87 | 88 | val_scene_names = splits.create_splits_scenes()['val'] 89 | nusc = NuScenes(version='v1.0-trainval', dataroot=args.raw_data_folder, verbose=True) 90 | main(nusc, val_scene_names, args.raw_data_folder, token_folder, args.mode) 91 | -------------------------------------------------------------------------------- /preprocessing/waymo_data/detection.py: -------------------------------------------------------------------------------- 1 | """ Extract the detections from .bin files 2 | Each sequence is a .npz file with fields bboxes and types (plus velos/accels with --metadata and ids with --id). 3 | bboxes, types, and ids follow the same per-frame format: 4 | [[bboxes in frame 0], 5 | [bboxes in frame 1], 6 | ... 
7 | [bboxes in the last frame]] 8 | """ 9 | import os, numpy as np, argparse, json 10 | import tensorflow.compat.v1 as tf 11 | from tqdm import tqdm 12 | from google.protobuf.descriptor import FieldDescriptor as FD 13 | tf.enable_eager_execution() 14 | from waymo_open_dataset import label_pb2 15 | from waymo_open_dataset.protos import metrics_pb2 16 | 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--name', type=str, default='public') 20 | parser.add_argument('--file_path', type=str, default='validation.bin') 21 | parser.add_argument('--data_folder', type=str, default='../../../datasets/waymo/mot/') 22 | parser.add_argument('--metadata', action='store_true', default=False) 23 | parser.add_argument('--id', action='store_true', default=False) 24 | args = parser.parse_args() 25 | 26 | 27 | def bbox_dict2array(box_dict): 28 | """transform box dict in waymo_open_format to array 29 | Args: 30 | box_dict ([dict]): waymo_open_dataset formatted bbox 31 | """ 32 | result = np.array([ 33 | box_dict['center_x'], 34 | box_dict['center_y'], 35 | box_dict['center_z'], 36 | box_dict['heading'], 37 | box_dict['length'], 38 | box_dict['width'], 39 | box_dict['height'], 40 | box_dict['score'] 41 | ]) 42 | return result 43 | 44 | 45 | def str_list_to_int(lst): 46 | result = [] 47 | for t in lst: 48 | try: 49 | t = int(t) 50 | result.append(t) 51 | except ValueError: 52 | continue 53 | return result 54 | 55 | 56 | def main(name, data_folder, det_folder, file_path, out_folder): 57 | # load timestamp and segment names 58 | ts_info_folder = os.path.join(data_folder, 'ts_info') 59 | ts_files = os.listdir(ts_info_folder) 60 | ts_info = dict() 61 | segment_name_list = list() 62 | for ts_file_name in ts_files: 63 | ts = json.load(open(os.path.join(ts_info_folder, ts_file_name), 'r')) 64 | segment_name = ts_file_name.split('.')[0] 65 | ts_info[segment_name] = ts 66 | segment_name_list.append(segment_name) 67 | 68 | # load detection file 69 | det_folder = os.path.join(det_folder, name) 70 | f = open(os.path.join(det_folder, file_path), 'rb') 71 | objects = metrics_pb2.Objects() 72 | objects.ParseFromString(f.read()) 73 | f.close() 74 | 75 | # parse and aggregate detections 76 | objects = objects.objects 77 | object_num = len(objects) 78 | 79 | result_bbox, result_type, result_velo, result_accel, result_ids = dict(), dict(), dict(), dict(), dict() 80 | for seg_name in ts_info.keys(): 81 | result_bbox[seg_name] = dict() 82 | result_type[seg_name] = dict() 83 | result_velo[seg_name] = dict() 84 | result_accel[seg_name] = dict() 85 | result_ids[seg_name] = dict() 86 | 87 | print('Converting') 88 | pbar = tqdm(total=object_num) 89 | for _i in range(object_num): 90 | instance = objects[_i] 91 | segment = instance.context_name 92 | time_stamp = instance.frame_timestamp_micros 93 | 94 | box = instance.object.box 95 | bbox_dict = { 96 | 'center_x': box.center_x, 97 | 'center_y': box.center_y, 98 | 'center_z': box.center_z, 99 | 'width': box.width, 100 | 'length': box.length, 101 | 'height': box.height, 102 | 'heading': box.heading, 103 | 'score': instance.score 104 | } 105 | bbox_array = bbox_dict2array(bbox_dict) 106 | obj_type = instance.object.type 107 | 108 | if args.metadata: 109 | meta_data = instance.object.metadata 110 | velo = (meta_data.speed_x, meta_data.speed_y) 111 | accel = (meta_data.accel_x, meta_data.accel_y) 112 | 113 | if args.id: 114 | obj_id = instance.object.id 115 | 116 | val_index = None 117 | for _j in range(len(segment_name_list)): 118 | if segment in segment_name_list[_j]: 119 | val_index = _j
120 | break 121 | segment_name = segment_name_list[val_index] 122 | 123 | frame_number = None 124 | for _j in range(len(ts_info[segment_name])): 125 | if ts_info[segment_name][_j] == time_stamp: 126 | frame_number = _j 127 | break 128 | 129 | if str(frame_number) not in result_bbox[segment_name].keys(): 130 | result_bbox[segment_name][str(frame_number)] = list() 131 | result_type[segment_name][str(frame_number)] = list() 132 | if args.metadata: 133 | result_velo[segment_name][str(frame_number)] = list() 134 | result_accel[segment_name][str(frame_number)] = list() 135 | if args.id: 136 | result_ids[segment_name][str(frame_number)] = list() 137 | 138 | result_bbox[segment_name][str(frame_number)].append(bbox_array) 139 | result_type[segment_name][str(frame_number)].append(obj_type) 140 | if args.metadata: 141 | result_velo[segment_name][str(frame_number)].append(velo) 142 | result_accel[segment_name][str(frame_number)].append(accel) 143 | if args.id: 144 | result_ids[segment_name][str(frame_number)].append(obj_id) 145 | 146 | pbar.update(1) 147 | pbar.close() 148 | 149 | print('Saving') 150 | pbar = tqdm(total=len(segment_name_list)) 151 | # store in files 152 | for _i, segment_name in enumerate(segment_name_list): 153 | dets = result_bbox[segment_name] 154 | types = result_type[segment_name] 155 | if args.metadata: 156 | velos = result_velo[segment_name] 157 | accels = result_accel[segment_name] 158 | if args.id: 159 | ids = result_ids[segment_name] 160 | 161 | frame_keys = sorted(str_list_to_int(dets.keys())) 162 | max_frame = max(frame_keys) 163 | bboxes = list() 164 | obj_types = list() 165 | velocities, accelerations, id_names = list(), list(), list() 166 | for key in range(max_frame + 1): 167 | if str(key) in dets.keys(): 168 | bboxes.append(dets[str(key)]) 169 | obj_types.append(types[str(key)]) 170 | if args.metadata: 171 | velocities.append(velos[str(key)]) 172 | accelerations.append(accels[str(key)]) 173 | if args.id: 174 | id_names.append(ids[str(key)]) 175 | else: 176 | bboxes.append([]) 177 | obj_types.append([]) 178 | if args.metadata: 179 | velocities.append([]) 180 | accelerations.append([]) 181 | if args.id: 182 | id_names.append([]) 183 | result = {'bboxes': bboxes, 'types': obj_types} 184 | if args.metadata: 185 | result['velos'] = velocities 186 | result['accels'] = accelerations 187 | if args.id: 188 | result['ids'] = id_names 189 | 190 | np.savez_compressed(os.path.join(out_folder, "{:}.npz".format(segment_name)), **result) 191 | pbar.update(1) 192 | pbar.close() 193 | 194 | 195 | if __name__ == '__main__': 196 | det_folder = os.path.join(args.data_folder, 'detection') 197 | os.makedirs(det_folder, exist_ok=True) 198 | output_folder = os.path.join(det_folder, args.name, 'dets') 199 | os.makedirs(output_folder, exist_ok=True) 200 | 201 | main(args.name, args.data_folder, det_folder, args.file_path, output_folder) 202 |
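As a sanity check on the format described in the module docstring above, a minimal sketch of reading one converted detection file back; the folder and segment name are placeholders, and the 8-value box layout follows bbox_dict2array, i.e. [center_x, center_y, center_z, heading, length, width, height, score]:

import os, numpy as np

out_folder = '../../../datasets/waymo/mot/detection/public/dets/'   # hypothetical
segment_name = 'segment-XXXX'   # placeholder
dets = np.load(os.path.join(out_folder, '{:}.npz'.format(segment_name)), allow_pickle=True)
bboxes, types = dets['bboxes'], dets['types']   # one list of boxes per frame
# 'velos' / 'accels' exist only when converted with --metadata, 'ids' only with --id
for frame_boxes in bboxes:
    for box in frame_boxes:
        x, y, z, heading, length, width, height, score = box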
-------------------------------------------------------------------------------- /preprocessing/waymo_data/ego_info.py: -------------------------------------------------------------------------------- 1 | """ Extract the ego location information from tfrecords 2 | Output file format: dict compressed in .npz files 3 | { 4 | str(frame_num): ego_info (4 * 4 matrix) 5 | } 6 | """ 7 | import argparse 8 | import numpy as np 9 | import os 10 | from google.protobuf.descriptor import FieldDescriptor as FD 11 | import tensorflow.compat.v1 as tf 12 | tf.enable_eager_execution() 13 | import multiprocessing 14 | from waymo_open_dataset import dataset_pb2 as open_dataset 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--raw_data_folder', type=str, default='../../../datasets/waymo/validation/', 19 | help='location of tfrecords') 20 | parser.add_argument('--data_folder', type=str, default='../../../datasets/waymo/mot/', 21 | help='output folder') 22 | parser.add_argument('--process', type=int, default=1, help='use multiprocessing for acceleration') 23 | args = parser.parse_args() 24 | 25 | 26 | def pb2dict(obj): 27 | """ 28 | Takes a ProtoBuf Message obj and converts it to a dict. 29 | """ 30 | adict = {} 31 | # if not obj.IsInitialized(): 32 | # return None 33 | for field in obj.DESCRIPTOR.fields: 34 | if not getattr(obj, field.name): 35 | continue 36 | if not field.label == FD.LABEL_REPEATED: 37 | if not field.type == FD.TYPE_MESSAGE: 38 | adict[field.name] = getattr(obj, field.name) 39 | else: 40 | value = pb2dict(getattr(obj, field.name)) 41 | if value: 42 | adict[field.name] = value 43 | else: 44 | if field.type == FD.TYPE_MESSAGE: 45 | adict[field.name] = [pb2dict(v) for v in getattr(obj, field.name)] 46 | else: 47 | adict[field.name] = [v for v in getattr(obj, field.name)] 48 | return adict 49 | 50 | 51 | def main(raw_data_folder, data_folder, process_num=1, token=0): 52 | """ The process with index "token" processes the ego pose information. 53 | """ 54 | tf_records = os.listdir(raw_data_folder) 55 | tf_records = [x for x in tf_records if 'tfrecord' in x] 56 | tf_records = sorted(tf_records) 57 | for record_index, tf_record_name in enumerate(tf_records): 58 | if record_index % process_num != token: 59 | continue 60 | print('starting for ego info ', record_index + 1, ' / ', len(tf_records), ' ', tf_record_name) 61 | FILE_NAME = os.path.join(raw_data_folder, tf_record_name) 62 | dataset = tf.data.TFRecordDataset(FILE_NAME, compression_type='') 63 | segment_name = tf_record_name.split('.')[0] 64 | 65 | frame_num = 0 66 | ego_infos = dict() 67 | 68 | for data in dataset: 69 | frame = open_dataset.Frame() 70 | frame.ParseFromString(bytearray(data.numpy())) 71 | 72 | ego_info = np.reshape(np.array(frame.pose.transform), [4, 4]) 73 | ego_infos[str(frame_num)] = ego_info 74 | 75 | frame_num += 1 76 | if frame_num % 10 == 0: 77 | print('ego record {:} / {:} frame number {:}'.format(record_index + 1, len(tf_records), frame_num)) 78 | print('{:} frames in total'.format(frame_num)) 79 | 80 | np.savez_compressed(os.path.join(data_folder, "{}.npz".format(segment_name)), **ego_infos) 81 | 82 | 83 | if __name__ == '__main__': 84 | args.data_folder = os.path.join(args.data_folder, 'ego_info') 85 | os.makedirs(args.data_folder, exist_ok=True) 86 | 87 | if args.process > 1: 88 | pool = multiprocessing.Pool(args.process) 89 | for token in range(args.process): 90 | result = pool.apply_async(main, args=(args.raw_data_folder, args.data_folder, args.process, token)) 91 | pool.close() 92 | pool.join() 93 | else: 94 | main(args.raw_data_folder, args.data_folder) 95 | -------------------------------------------------------------------------------- /preprocessing/waymo_data/gt_bin_decode.py: -------------------------------------------------------------------------------- 1 | """ Process the .bin file of ground truth, and save it in our detection format. 2 | Each sequence is a .npz file containing three fields: bboxes, types, ids. 3 | bboxes, types, and ids follow the same format: 4 | [[bboxes in frame 0], 5 | [bboxes in frame 1], 6 | ... 
7 | [bboxes in the last frame]] 8 | """ 9 | import os, numpy as np, argparse, json 10 | from mot_3d.data_protos import BBox 11 | import mot_3d.utils as utils 12 | import tensorflow.compat.v1 as tf 13 | from google.protobuf.descriptor import FieldDescriptor as FD 14 | tf.enable_eager_execution() 15 | from waymo_open_dataset.protos import metrics_pb2 16 | 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--data_folder', type=str, default='../../../datasets/waymo/mot/') 20 | parser.add_argument('--file_path', type=str, default='../datasets/waymo/mot/info/validation_gt.bin') 21 | args = parser.parse_args() 22 | 23 | 24 | def main(file_path, out_folder, data_folder): 25 | # load time stamp 26 | ts_info_folder = os.path.join(data_folder, 'ts_info') 27 | time_stamp_info = dict() 28 | ts_files = os.listdir(ts_info_folder) 29 | for ts_file_name in ts_files: 30 | segment_name = ts_file_name.split('.')[0] 31 | ts_path = os.path.join(ts_info_folder, '{}.json'.format(segment_name)) 32 | f = open(ts_path, 'r') 33 | ts = json.load(f) 34 | f.close() 35 | time_stamp_info[segment_name] = ts 36 | 37 | # segment name list 38 | segment_name_list = list() 39 | for ts_file_name in ts_files: 40 | segment_name = ts_file_name.split('.')[0] 41 | segment_name_list.append(segment_name) 42 | 43 | # load gt.bin file 44 | f = open(file_path, 'rb') 45 | objects = metrics_pb2.Objects() 46 | objects.ParseFromString(f.read()) 47 | f.close() 48 | 49 | # parse and aggregate detections 50 | objects = objects.objects 51 | object_num = len(objects) 52 | 53 | result_bbox, result_type, result_id = dict(), dict(), dict() 54 | for seg_name in time_stamp_info.keys(): 55 | result_bbox[seg_name] = dict() 56 | result_type[seg_name] = dict() 57 | result_id[seg_name] = dict() 58 | 59 | for i in range(object_num): 60 | instance = objects[i] 61 | segment = instance.context_name 62 | time_stamp = instance.frame_timestamp_micros 63 | 64 | box = instance.object.box 65 | box_dict = { 66 | 'center_x': box.center_x, 67 | 'center_y': box.center_y, 68 | 'center_z': box.center_z, 69 | 'width': box.width, 70 | 'length': box.length, 71 | 'height': box.height, 72 | 'heading': box.heading, 73 | 'score': instance.score 74 | } 75 | 76 | val_index = None 77 | for _j in range(len(segment_name_list)): 78 | if segment in segment_name_list[_j]: 79 | val_index = _j 80 | break 81 | segment_name = segment_name_list[val_index] 82 | 83 | frame_number = None 84 | for _j in range(len(time_stamp_info[segment_name])): 85 | if time_stamp_info[segment_name][_j] == time_stamp: 86 | frame_number = _j 87 | break 88 | 89 | if str(frame_number) not in result_bbox[segment_name].keys(): 90 | result_bbox[segment_name][str(frame_number)] = list() 91 | result_type[segment_name][str(frame_number)] = list() 92 | result_id[segment_name][str(frame_number)] = list() 93 | 94 | result_bbox[segment_name][str(frame_number)].append(BBox.bbox2array(BBox.dict2bbox(box_dict))) 95 | result_type[segment_name][str(frame_number)].append(instance.object.type) 96 | result_id[segment_name][str(frame_number)].append(instance.object.id) 97 | 98 | if (i + 1) % 10000 == 0: 99 | print(i + 1, ' / ', object_num) 100 | 101 | # store in files 102 | for _i, segment_name in enumerate(time_stamp_info.keys()): 103 | dets = result_bbox[segment_name] 104 | types = result_type[segment_name] 105 | ids = result_id[segment_name] 106 | print('{} / {}'.format(_i + 1, len(time_stamp_info.keys()))) 107 | 108 | frame_keys = sorted(utils.str2int(dets.keys())) 109 | max_frame = max(frame_keys) 110 | 
obj_ids, bboxes, obj_types = list(), list(), list() 111 | 112 | for key in range(max_frame + 1): 113 | if str(key) in dets.keys(): 114 | bboxes.append(dets[str(key)]) 115 | obj_types.append(types[str(key)]) 116 | obj_ids.append(ids[str(key)]) 117 | else: 118 | bboxes.append([]) 119 | obj_types.append([]) 120 | obj_ids.append([]) 121 | 122 | np.savez_compressed(os.path.join(out_folder, "{}.npz".format(segment_name)), 123 | bboxes=bboxes, types=obj_types, ids=obj_ids) 124 | 125 | 126 | if __name__ == '__main__': 127 | out_folder = os.path.join(args.data_folder, 'detection', 'gt', 'dets') 128 | os.makedirs(out_folder, exist_ok=True) 129 | main(args.file_path, out_folder, args.data_folder) 130 | -------------------------------------------------------------------------------- /preprocessing/waymo_data/raw_pc.py: -------------------------------------------------------------------------------- 1 | """ Extract the point cloud sequences from the tfrecords 2 | output format: a compressed dict stored in an npz file 3 | { 4 | str(frame_number): pc (N * 3 numpy array) 5 | } 6 | """ 7 | import argparse 8 | import numpy as np 9 | import os 10 | import multiprocessing 11 | from google.protobuf.descriptor import FieldDescriptor as FD 12 | import tensorflow.compat.v1 as tf 13 | tf.enable_eager_execution() 14 | from waymo_open_dataset.utils import frame_utils 15 | from waymo_open_dataset import dataset_pb2 as open_dataset 16 | 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--raw_data_folder', type=str, default='../../../datasets/waymo/validation/', 20 | help='the location of tfrecords') 21 | parser.add_argument('--data_folder', type=str, default='../../../datasets/waymo/mot/', 22 | help='the location of raw pcs') 23 | parser.add_argument('--process', type=int, default=1) 24 | args = parser.parse_args() 25 | 26 | 27 | def pb2dict(obj): 28 | """ 29 | Takes a ProtoBuf Message obj and converts it to a dict. 
30 | """ 31 | adict = {} 32 | # if not obj.IsInitialized(): 33 | # return None 34 | for field in obj.DESCRIPTOR.fields: 35 | if not getattr(obj, field.name): 36 | continue 37 | if not field.label == FD.LABEL_REPEATED: 38 | if not field.type == FD.TYPE_MESSAGE: 39 | adict[field.name] = getattr(obj, field.name) 40 | else: 41 | value = pb2dict(getattr(obj, field.name)) 42 | if value: 43 | adict[field.name] = value 44 | else: 45 | if field.type == FD.TYPE_MESSAGE: 46 | adict[field.name] = [pb2dict(v) for v in getattr(obj, field.name)] 47 | else: 48 | adict[field.name] = [v for v in getattr(obj, field.name)] 49 | return adict 50 | 51 | 52 | def main(raw_data_folder, data_folder, process=1, token=0): 53 | tf_records = os.listdir(raw_data_folder) 54 | tf_records = [x for x in tf_records if 'tfrecord' in x] 55 | tf_records = sorted(tf_records) 56 | for record_index, tf_record_name in enumerate(tf_records): 57 | if record_index % process != token: 58 | continue 59 | print('starting for raw pc', record_index + 1, ' / ', len(tf_records), ' ', tf_record_name) 60 | FILE_NAME = os.path.join(raw_data_folder, tf_record_name) 61 | dataset = tf.data.TFRecordDataset(FILE_NAME, compression_type='') 62 | segment_name = tf_record_name.split('.')[0] 63 | 64 | frame_num = 0 65 | pcs = dict() 66 | 67 | for data in dataset: 68 | frame = open_dataset.Frame() 69 | frame.ParseFromString(bytearray(data.numpy())) 70 | 71 | # extract the points 72 | (range_images, camera_projections, range_image_top_pose) = \ 73 | frame_utils.parse_range_image_and_camera_projection(frame) 74 | points, cp_points = frame_utils.convert_range_image_to_point_cloud( 75 | frame, range_images, camera_projections, range_image_top_pose, ri_index=0) 76 | points_all = np.concatenate(points, axis=0) 77 | pcs[str(frame_num)] = points_all 78 | 79 | frame_num += 1 80 | if frame_num % 10 == 0: 81 | print('Point Cloud Record {} / {} FNumber {:}'.format(record_index + 1, len(tf_records), frame_num)) 82 | print('{:} frames in total'.format(frame_num)) 83 | 84 | np.savez_compressed(os.path.join(data_folder, "{}.npz".format(segment_name)), **pcs) 85 | 86 | 87 | if __name__ == '__main__': 88 | args.data_folder = os.path.join(args.data_folder, 'pc', 'raw_pc') 89 | os.makedirs(args.data_folder, exist_ok=True) 90 | 91 | if args.process > 1: 92 | # multiprocessing accelerate the speed 93 | pool = multiprocessing.Pool(args.process) 94 | for token in range(args.process): 95 | result = pool.apply_async(main, args=(args.raw_data_folder, args.data_folder, args.process, token)) 96 | pool.close() 97 | pool.join() 98 | else: 99 | main(args.raw_data_folder, args.data_folder) 100 | -------------------------------------------------------------------------------- /preprocessing/waymo_data/time_stamp.py: -------------------------------------------------------------------------------- 1 | """ Extract the time stamp information about each frame. 2 | Each sequence has a json file containing a list of timestamps. 
3 | """ 4 | import os 5 | import tensorflow.compat.v1 as tf 6 | import numpy as np 7 | import argparse 8 | import json 9 | from google.protobuf.descriptor import FieldDescriptor as FD 10 | tf.enable_eager_execution() 11 | import multiprocessing 12 | from waymo_open_dataset import dataset_pb2 as open_dataset 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--raw_data_folder', type=str, default='../../../datasets/waymo/validation/', 17 | help='location of tfrecords') 18 | parser.add_argument('--data_folder', type=str, default='../../../datasets/waymo/mot/', 19 | help='output folder') 20 | parser.add_argument('--process', type=int, default=1, help='use multiprocessing for acceleration') 21 | args = parser.parse_args() 22 | 23 | 24 | def main(raw_data_folder, data_folder, process=1, token=0): 25 | tf_records = os.listdir(raw_data_folder) 26 | tf_records = [x for x in tf_records if 'tfrecord' in x] 27 | tf_records = sorted(tf_records) 28 | 29 | for record_index, tf_record_name in enumerate(tf_records): 30 | if record_index % process != token: 31 | continue 32 | print('starting for time stamp: ', record_index + 1, ' / ', len(tf_records), ' ', tf_record_name) 33 | FILE_NAME = os.path.join(raw_data_folder, tf_record_name) 34 | dataset = tf.data.TFRecordDataset(FILE_NAME, compression_type='') 35 | 36 | time_stamps = list() 37 | for data in dataset: 38 | frame = open_dataset.Frame() 39 | frame.ParseFromString(bytearray(data.numpy())) 40 | time_stamps.append(frame.timestamp_micros) 41 | 42 | file_name = tf_record_name.split('.')[0] 43 | print(file_name) 44 | f = open(os.path.join(data_folder, '{:}.json'.format(file_name)), 'w') 45 | json.dump(time_stamps, f) 46 | f.close() 47 | 48 | 49 | if __name__ == '__main__': 50 | args.data_folder = os.path.join(args.data_folder, 'ts_info') 51 | os.makedirs(args.data_folder, exist_ok=True) 52 | 53 | if args.process > 1: 54 | pool = multiprocessing.Pool(args.process) 55 | for token in range(args.process): 56 | result = pool.apply_async(main, args=(args.raw_data_folder, args.data_folder, args.process, token)) 57 | pool.close() 58 | pool.join() 59 | else: 60 | main(args.raw_data_folder, args.data_folder) 61 | -------------------------------------------------------------------------------- /preprocessing/waymo_data/waymo_preprocess.sh: -------------------------------------------------------------------------------- 1 | raw_data_dir=$1 2 | data_dir=$2 3 | proc_num=$3 4 | 5 | # ego pose 6 | python ego_info.py --raw_data_folder $raw_data_dir --data_folder $data_dir --process $proc_num 7 | 8 | # time stamp 9 | python time_stamp.py --raw_data_folder $raw_data_dir --data_folder $data_dir --process $proc_num 10 | 11 | # point cloud, useful for visualization 12 | python raw_pc.py --raw_data_folder $raw_data_dir --data_folder $data_dir --process $proc_num -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | filterpy 3 | shapely 4 | numba 5 | pyquaternion 6 | nuscenes-devkit -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | 4 | setup( 5 | name='mot_3d', 6 | version='0.1', 7 | packages=['mot_3d'] 8 | ) -------------------------------------------------------------------------------- /tools/demo.py: 
-------------------------------------------------------------------------------- 1 | import os, numpy as np, argparse, json, sys, numba, yaml, multiprocessing, shutil 2 | import mot_3d.visualization as visualization, mot_3d.utils as utils 3 | from mot_3d.data_protos import BBox, Validity 4 | from mot_3d.mot import MOTModel 5 | from mot_3d.frame_data import FrameData 6 | from data_loader import WaymoLoader 7 | 8 | 9 | parser = argparse.ArgumentParser() 10 | # running configurations 11 | parser.add_argument('--name', type=str, default='demo') 12 | parser.add_argument('--det_name', type=str, default='cp') 13 | parser.add_argument('--process', type=int, default=1) 14 | parser.add_argument('--visualize', action='store_true', default=False) 15 | parser.add_argument('--start_frame', type=int, default=0, help='start at a middle frame for debug') 16 | parser.add_argument('--obj_type', type=str, default='vehicle', choices=['vehicle', 'pedestrian', 'cyclist']) 17 | # paths 18 | parser.add_argument('--config_path', type=str, default='configs/waymo_configs/vc_kf_giou.yaml') 19 | parser.add_argument('--result_folder', type=str, default='./mot_results/') 20 | parser.add_argument('--data_folder', type=str, default='./demo_data/') 21 | parser.add_argument('--gt_folder', type=str, default='./demo_data/detection/gt/dets/') 22 | args = parser.parse_args() 23 | 24 | 25 | def load_gt_bboxes(gt_folder, data_folder, segment_name, type_token): 26 | gt_info = np.load(os.path.join(gt_folder, '{:}.npz'.format(segment_name)), 27 | allow_pickle=True) 28 | ego_info = np.load(os.path.join(data_folder, 'ego_info', '{:}.npz'.format(segment_name)), 29 | allow_pickle=True) 30 | bboxes, ids, inst_types = gt_info['bboxes'], gt_info['ids'], gt_info['types'] 31 | gt_ids, gt_bboxes = utils.inst_filter(ids, bboxes, inst_types, type_field=[type_token], id_trans=True) 32 | 33 | ego_keys = sorted(utils.str2int(ego_info.keys())) 34 | egos = [ego_info[str(key)] for key in ego_keys] 35 | gt_bboxes = gt_bbox2world(gt_bboxes, egos) 36 | return gt_bboxes, gt_ids 37 | 38 | 39 | def gt_bbox2world(bboxes, egos): 40 | frame_num = len(egos) 41 | for i in range(frame_num): 42 | ego = egos[i] 43 | bbox_num = len(bboxes[i]) 44 | for j in range(bbox_num): 45 | bboxes[i][j] = BBox.bbox2world(ego, bboxes[i][j]) 46 | return bboxes 47 | 48 | 49 | def frame_visualization(bboxes, ids, states, gt_bboxes=None, gt_ids=None, pc=None, dets=None, name=''): 50 | visualizer = visualization.Visualizer2D(name=name, figsize=(12, 12)) 51 | if pc is not None: 52 | visualizer.handler_pc(pc) 53 | for _, bbox in enumerate(gt_bboxes): 54 | visualizer.handler_box(bbox, message='', color='black') 55 | dets = [d for d in dets if d.s >= 0.1] 56 | for det in dets: 57 | visualizer.handler_box(det, message='%.2f' % det.s, color='green', linestyle='dashed') 58 | for _, (bbox, id, state) in enumerate(zip(bboxes, ids, states)): 59 | if Validity.valid(state): 60 | visualizer.handler_box(bbox, message=str(id), color='red') 61 | else: 62 | visualizer.handler_box(bbox, message=str(id), color='light_blue') 63 | visualizer.show() 64 | visualizer.close() 65 | 66 | 67 | def sequence_mot(configs, data_loader: WaymoLoader, sequence_id, gt_bboxes=None, gt_ids=None, visualize=False): 68 | tracker = MOTModel(configs) 69 | frame_num = len(data_loader) 70 | IDs, bboxes, states, types = list(), list(), list(), list() 71 | for frame_index in range(data_loader.cur_frame, frame_num): 72 | print('TYPE {:} SEQ {:} Frame {:} / {:}'.format(data_loader.type_token, sequence_id + 1, frame_index + 1, 
frame_num)) 73 | 74 | # input data 75 | frame_data = next(data_loader) 76 | frame_data = FrameData(dets=frame_data['dets'], ego=frame_data['ego'], pc=frame_data['pc'], 77 | det_types=frame_data['det_types'], aux_info=frame_data['aux_info'], time_stamp=frame_data['time_stamp']) 78 | 79 | # mot 80 | results = tracker.frame_mot(frame_data) 81 | result_pred_bboxes = [trk[0] for trk in results] 82 | result_pred_ids = [trk[1] for trk in results] 83 | result_pred_states = [trk[2] for trk in results] 84 | result_types = [trk[3] for trk in results] 85 | 86 | # visualization 87 | if visualize: 88 | frame_visualization(result_pred_bboxes, result_pred_ids, result_pred_states, 89 | gt_bboxes[frame_index], gt_ids[frame_index], frame_data.pc, dets=frame_data.dets, name='{:}_{:}_{:}'.format(args.name, sequence_id, frame_index)) 90 | 91 | # wrap for output 92 | IDs.append(result_pred_ids) 93 | result_pred_bboxes = [BBox.bbox2array(bbox) for bbox in result_pred_bboxes] 94 | bboxes.append(result_pred_bboxes) 95 | states.append(result_pred_states) 96 | types.append(result_types) 97 | return IDs, bboxes, states, types 98 | 99 | 100 | def main(name, obj_type, config_path, data_folder, det_data_folder, result_folder, gt_folder, start_frame=0, token=0, process=1): 101 | summary_folder = os.path.join(result_folder, 'summary', obj_type) 102 | # simply knowing about all the segments 103 | file_names = sorted(os.listdir(os.path.join(data_folder, 'ego_info'))) 104 | print(file_names[0]) 105 | 106 | # load model configs 107 | configs = yaml.load(open(config_path, 'r'), Loader=yaml.Loader) 108 | 109 | if obj_type == 'vehicle': 110 | type_token = 1 111 | elif obj_type == 'pedestrian': 112 | type_token = 2 113 | elif obj_type == 'cyclist': 114 | type_token = 4 115 | for file_index, file_name in enumerate(file_names[:]): 116 | if file_index % process != token: 117 | continue 118 | print('START TYPE {:} SEQ {:} / {:}'.format(obj_type, file_index + 1, len(file_names))) 119 | segment_name = file_name.split('.')[0] 120 | data_loader = WaymoLoader(configs, [type_token], segment_name, data_folder, det_data_folder, start_frame) 121 | gt_bboxes, gt_ids = load_gt_bboxes(gt_folder, data_folder, segment_name, type_token) 122 | 123 | # real mot happens here 124 | ids, bboxes, states, types = sequence_mot(configs, data_loader, file_index, gt_bboxes, gt_ids, args.visualize) 125 | np.savez_compressed(os.path.join(summary_folder, '{}.npz'.format(segment_name)), 126 | ids=ids, bboxes=bboxes, states=states) 127 | 128 | 129 | if __name__ == '__main__': 130 | result_folder = os.path.join(args.result_folder, args.name) 131 | os.makedirs(result_folder, exist_ok=True) 132 | 133 | summary_folder = os.path.join(result_folder, 'summary') 134 | os.makedirs(summary_folder, exist_ok=True) 135 | 136 | summary_folder = os.path.join(summary_folder, args.obj_type) 137 | os.makedirs(summary_folder, exist_ok=True) 138 | 139 | det_data_folder = os.path.join(args.data_folder, 'detection', args.det_name) 140 | 141 | if args.process > 1: 142 | pool = multiprocessing.Pool(args.process) 143 | for token in range(args.process): 144 | result = pool.apply_async(main, args=(args.name, args.obj_type, args.config_path, args.data_folder, det_data_folder, 145 | result_folder, args.gt_folder, 0, token, args.process)) 146 | pool.close() 147 | pool.join() 148 | else: 149 | main(args.name, args.obj_type, args.config_path, args.data_folder, det_data_folder, result_folder, 150 | args.gt_folder, args.start_frame, 0, 1) 151 | 152 | 
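A minimal sketch of consuming the per-sequence tracking summary written above (and by main_waymo.py below); the folder and segment name are placeholders, and the per-box layout [x, y, z, yaw, l, w, h, score] is inferred from how nuscenes_result_creation.py unpacks the BBox.bbox2array output:

import os, numpy as np

summary_folder = './mot_results/demo/summary/vehicle/'   # hypothetical
segment_name = 'segment-XXXX'   # placeholder
summary = np.load(os.path.join(summary_folder, '{:}.npz'.format(segment_name)), allow_pickle=True)
ids, bboxes, states = summary['ids'], summary['bboxes'], summary['states']
for frame_index in range(len(ids)):
    # parallel per-frame lists: tracklet ids, box arrays, and validity state strings
    for obj_id, box, state in zip(ids[frame_index], bboxes[frame_index], states[frame_index]):
        x, y, z, yaw, l, w, h, score = box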
-------------------------------------------------------------------------------- /tools/main_nuscenes.py: -------------------------------------------------------------------------------- 1 | """ inference on the nuscenes dataset 2 | """ 3 | import os, numpy as np, argparse, json, sys, numba, yaml, multiprocessing, shutil 4 | import mot_3d.visualization as visualization, mot_3d.utils as utils 5 | from mot_3d.data_protos import BBox, Validity 6 | from mot_3d.mot import MOTModel 7 | from mot_3d.frame_data import FrameData 8 | from data_loader import NuScenesLoader 9 | from pyquaternion import Quaternion 10 | from nuscenes.utils.data_classes import Box 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | # running configurations 15 | parser.add_argument('--name', type=str, default='debug') 16 | parser.add_argument('--det_name', type=str, default='cp') 17 | parser.add_argument('--process', type=int, default=1) 18 | parser.add_argument('--visualize', action='store_true', default=False) 19 | parser.add_argument('--start_frame', type=int, default=0, help='start at a middle frame for debug') 20 | parser.add_argument('--obj_types', default='car,bus,trailer,truck,pedestrian,bicycle,motorcycle') 21 | # paths 22 | parser.add_argument('--config_path', type=str, default='configs/config.yaml', help='config file path, follow the path in the documentation') 23 | parser.add_argument('--result_folder', type=str, default='../nu_mot_results/') 24 | parser.add_argument('--data_folder', type=str, default='../datasets/nuscenes/') 25 | args = parser.parse_args() 26 | 27 | 28 | def nu_array2mot_bbox(b): 29 | nu_box = Box(b[:3], b[3:6], Quaternion(b[6:10])) 30 | mot_bbox = BBox( 31 | x=nu_box.center[0], y=nu_box.center[1], z=nu_box.center[2], 32 | w=nu_box.wlh[0], l=nu_box.wlh[1], h=nu_box.wlh[2], 33 | o=nu_box.orientation.yaw_pitch_roll[0] 34 | ) 35 | if len(b) == 11: 36 | mot_bbox.s = b[-1] 37 | return mot_bbox 38 | 39 | 40 | def load_gt_bboxes(data_folder, type_token, segment_name): 41 | gt_info = np.load(os.path.join(data_folder, 'gt_info', '{:}.npz'.format(segment_name)), allow_pickle=True) 42 | ids, inst_types, bboxes = gt_info['ids'], gt_info['types'], gt_info['bboxes'] 43 | 44 | mot_bboxes = list() 45 | for _, frame_bboxes in enumerate(bboxes): 46 | mot_bboxes.append([]) 47 | for _, b in enumerate(frame_bboxes): 48 | mot_bboxes[-1].append(BBox.bbox2array(nu_array2mot_bbox(b))) 49 | gt_ids, gt_bboxes = utils.inst_filter(ids, mot_bboxes, inst_types, 50 | type_field=type_token, id_trans=True) 51 | return gt_bboxes, gt_ids 52 | 53 | 54 | def frame_visualization(bboxes, ids, states, gt_bboxes=None, gt_ids=None, pc=None, dets=None, name=''): 55 | visualizer = visualization.Visualizer2D(name=name, figsize=(12, 12)) 56 | if pc is not None: 57 | visualizer.handler_pc(pc) 58 | 59 | if gt_bboxes is not None: 60 | for _, bbox in enumerate(gt_bboxes): 61 | visualizer.handler_box(bbox, message='', color='black') 62 | dets = [d for d in dets if d.s >= 0.01] 63 | for det in dets: 64 | visualizer.handler_box(det, message='%.2f' % det.s, color='gray', linestyle='dashed') 65 | for _, (bbox, id, state_string) in enumerate(zip(bboxes, ids, states)): 66 | if Validity.valid(state_string): 67 | visualizer.handler_box(bbox, message='%.2f %s'%(bbox.s, id), color='red') 68 | else: 69 | visualizer.handler_box(bbox, message='%.2f %s'%(bbox.s, id), color='light_blue') 70 | visualizer.show() 71 | visualizer.close() 72 | 73 | 74 | def sequence_mot(configs, data_loader, obj_type, sequence_id, gt_bboxes=None, gt_ids=None, visualize=False): 
75 | tracker = MOTModel(configs) 76 | frame_num = len(data_loader) 77 | IDs, bboxes, states, types = list(), list(), list(), list() 78 | 79 | for frame_index in range(data_loader.cur_frame, frame_num): 80 | if frame_index % 10 == 0: 81 | print('TYPE {:} SEQ {:} Frame {:} / {:}'.format(obj_type, sequence_id, frame_index + 1, frame_num)) 82 | 83 | # input data 84 | frame_data = next(data_loader) 85 | frame_data = FrameData(dets=frame_data['dets'], ego=frame_data['ego'], pc=frame_data['pc'], 86 | det_types=frame_data['det_types'], aux_info=frame_data['aux_info'], time_stamp=frame_data['time_stamp']) 87 | 88 | # mot 89 | results = tracker.frame_mot(frame_data) 90 | result_pred_bboxes = [trk[0] for trk in results] 91 | result_pred_ids = [trk[1] for trk in results] 92 | result_pred_states = [trk[2] for trk in results] 93 | result_types = [trk[3] for trk in results] 94 | 95 | # visualization 96 | if visualize: 97 | frame_visualization(result_pred_bboxes, result_pred_ids, result_pred_states, 98 | gt_bboxes[frame_index], gt_ids[frame_index], frame_data.pc, dets=frame_data.dets, name='{:}_{:}'.format(args.name, frame_index)) 99 | 100 | # wrap for output 101 | IDs.append(result_pred_ids) 102 | result_pred_bboxes = [BBox.bbox2array(bbox) for bbox in result_pred_bboxes] 103 | bboxes.append(result_pred_bboxes) 104 | states.append(result_pred_states) 105 | types.append(result_types) 106 | 107 | return IDs, bboxes, states, types 108 | 109 | 110 | def main(name, obj_types, config_path, data_folder, det_data_folder, result_folder, start_frame=0, token=0, process=1): 111 | for obj_type in obj_types: 112 | summary_folder = os.path.join(result_folder, 'summary', obj_type) 113 | # simply knowing about all the segments 114 | file_names = sorted(os.listdir(os.path.join(data_folder, 'ego_info'))) 115 | 116 | # load model configs 117 | configs = yaml.load(open(config_path, 'r'), Loader=yaml.Loader) 118 | 119 | for file_index, file_name in enumerate(file_names[:]): 120 | if file_index % process != token: 121 | continue 122 | print('START TYPE {:} SEQ {:} / {:}'.format(obj_type, file_index + 1, len(file_names))) 123 | segment_name = file_name.split('.')[0] 124 | 125 | data_loader = NuScenesLoader(configs, [obj_type], segment_name, data_folder, det_data_folder, start_frame) 126 | 127 | gt_bboxes, gt_ids = load_gt_bboxes(data_folder, [obj_type], segment_name) 128 | ids, bboxes, states, types = sequence_mot(configs, data_loader, obj_type, file_index, gt_bboxes, gt_ids, args.visualize) 129 | 130 | frame_num = len(ids) 131 | for frame_index in range(frame_num): 132 | id_num = len(ids[frame_index]) 133 | for i in range(id_num): 134 | ids[frame_index][i] = '{:}_{:}'.format(file_index, ids[frame_index][i]) 135 | 136 | np.savez_compressed(os.path.join(summary_folder, '{}.npz'.format(segment_name)), 137 | ids=ids, bboxes=bboxes, states=states, types=types) 138 | 139 | 140 | if __name__ == '__main__': 141 | result_folder = os.path.join(args.result_folder, args.name) 142 | os.makedirs(result_folder, exist_ok=True) 143 | summary_folder = os.path.join(result_folder, 'summary') 144 | os.makedirs(summary_folder, exist_ok=True) 145 | det_data_folder = os.path.join(args.data_folder, 'detection', args.det_name) 146 | 147 | obj_types = args.obj_types.split(',') 148 | for obj_type in obj_types: 149 | tmp_summary_folder = os.path.join(summary_folder, obj_type) 150 | os.makedirs(tmp_summary_folder, exist_ok=True) 151 | 152 | if args.process > 1: 153 | pool = multiprocessing.Pool(args.process) 154 | for token in range(args.process): 155 | 
result = pool.apply_async(main, args=(args.name, obj_types, args.config_path, args.data_folder, det_data_folder, 156 | result_folder, 0, token, args.process)) 157 | pool.close() 158 | pool.join() 159 | else: 160 | main(args.name, obj_types, args.config_path, args.data_folder, det_data_folder, 161 | result_folder, args.start_frame, 0, 1) -------------------------------------------------------------------------------- /tools/main_nuscenes_10hz.py: -------------------------------------------------------------------------------- 1 | """ inference on the nuscenes dataset 2 | """ 3 | import os, numpy as np, argparse, json, sys, numba, yaml, multiprocessing, shutil 4 | import mot_3d.visualization as visualization, mot_3d.utils as utils 5 | from mot_3d.data_protos import BBox, Validity 6 | from mot_3d.mot import MOTModel 7 | from mot_3d.frame_data import FrameData 8 | from data_loader import NuScenesLoader10Hz 9 | from pyquaternion import Quaternion 10 | from nuscenes.utils.data_classes import Box 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | # running configurations 15 | parser.add_argument('--name', type=str, default='debug') 16 | parser.add_argument('--det_name', type=str, default='cp') 17 | parser.add_argument('--process', type=int, default=1) 18 | parser.add_argument('--visualize', action='store_true', default=False) 19 | parser.add_argument('--start_frame', type=int, default=0, help='start at a middle frame for debug') 20 | parser.add_argument('--obj_types', default='car,bus,trailer,truck,pedestrian,bicycle,motorcycle') 21 | # paths 22 | parser.add_argument('--config_path', type=str, default='configs/config.yaml', help='config file path, follow the path in the documentation') 23 | parser.add_argument('--result_folder', type=str, default='/mnt/truenas/scratch/ziqi.pang/10hz_exps/') 24 | parser.add_argument('--data_folder', type=str, default='/mnt/truenas/scratch/ziqi.pang/datasets/nuscenes/') 25 | args = parser.parse_args() 26 | 27 | 28 | def nu_array2mot_bbox(b): 29 | nu_box = Box(b[:3], b[3:6], Quaternion(b[6:10])) 30 | mot_bbox = BBox( 31 | x=nu_box.center[0], y=nu_box.center[1], z=nu_box.center[2], 32 | w=nu_box.wlh[0], l=nu_box.wlh[1], h=nu_box.wlh[2], 33 | o=nu_box.orientation.yaw_pitch_roll[0] 34 | ) 35 | if len(b) == 11: 36 | mot_bbox.s = b[-1] 37 | return mot_bbox 38 | 39 | 40 | def load_gt_bboxes(data_folder, type_token, segment_name): 41 | gt_info = np.load(os.path.join(data_folder, 'gt_info', '{:}.npz'.format(segment_name)), allow_pickle=True) 42 | ids, inst_types, bboxes = gt_info['ids'], gt_info['types'], gt_info['bboxes'] 43 | 44 | mot_bboxes = list() 45 | for _, frame_bboxes in enumerate(bboxes): 46 | mot_bboxes.append([]) 47 | for _, b in enumerate(frame_bboxes): 48 | mot_bboxes[-1].append(BBox.bbox2array(nu_array2mot_bbox(b))) 49 | gt_ids, gt_bboxes = utils.inst_filter(ids, mot_bboxes, inst_types, 50 | type_field=type_token, id_trans=True) 51 | return gt_bboxes, gt_ids 52 | 53 | 54 | def frame_visualization(bboxes, ids, states, gt_bboxes=None, gt_ids=None, pc=None, is_key_frame=True, dets=None, name=''): 55 | visualizer = visualization.Visualizer2D(name=name, figsize=(12, 12)) 56 | if pc is not None: 57 | visualizer.handler_pc(pc) 58 | for _, bbox in enumerate(gt_bboxes): 59 | visualizer.handler_box(bbox, message='', color='black') 60 | dets = [d for d in dets if d.s >= 0.01] 61 | 62 | if is_key_frame: 63 | line_style = 'solid' 64 | else: 65 | line_style = 'dashed' 66 | 67 | for det in dets: 68 | visualizer.handler_box(det, message='%.2f' % det.s, color='gray', 
linestyle='dashed') 69 | for _, (bbox, id, state_string) in enumerate(zip(bboxes, ids, states)): 70 | if Validity.valid(state_string): 71 | visualizer.handler_box(bbox, message='%.2f %s'%(bbox.s, id), color='red', linestyle=line_style) 72 | else: 73 | visualizer.handler_box(bbox, message='%.2f %s'%(bbox.s, id), color='light_blue', linestyle=line_style) 74 | visualizer.show() 75 | visualizer.close() 76 | 77 | 78 | def sequence_mot(configs, data_loader, obj_type, sequence_id, gt_bboxes=None, gt_ids=None, visualize=False): 79 | tracker = MOTModel(configs) 80 | frame_num = len(data_loader) 81 | IDs, bboxes, states, types = list(), list(), list(), list() 82 | 83 | if gt_bboxes is None: 84 | gt_bboxes = [[] for i in range(frame_num)] 85 | gt_ids = [[] for i in range(frame_num)] 86 | 87 | for frame_index in range(data_loader.cur_frame, frame_num): 88 | if frame_index % 10 == 0: 89 | print('TYPE {:} SEQ {:} Frame {:} / {:}'.format(obj_type, sequence_id, frame_index + 1, frame_num)) 90 | 91 | # input data 92 | frame_data = next(data_loader) 93 | frame_data = FrameData(dets=frame_data['dets'], ego=frame_data['ego'], pc=frame_data['pc'], 94 | det_types=frame_data['det_types'], aux_info=frame_data['aux_info'], time_stamp=frame_data['time_stamp']) 95 | 96 | # mot 97 | results = tracker.frame_mot(frame_data) 98 | result_pred_bboxes = [trk[0] for trk in results] 99 | result_pred_ids = [trk[1] for trk in results] 100 | result_pred_states = [trk[2] for trk in results] 101 | result_types = [trk[3] for trk in results] 102 | 103 | # visualization 104 | if visualize: 105 | frame_visualization(result_pred_bboxes, result_pred_ids, result_pred_states, 106 | gt_bboxes[frame_index], gt_ids[frame_index], frame_data.pc, 107 | is_key_frame=frame_data.aux_info['is_key_frame'], dets=frame_data.dets, name='{:}_{:}'.format(args.name, frame_index)) 108 | 109 | # wrap for output 110 | IDs.append(result_pred_ids) 111 | result_pred_bboxes = [BBox.bbox2array(bbox) for bbox in result_pred_bboxes] 112 | bboxes.append(result_pred_bboxes) 113 | states.append(result_pred_states) 114 | types.append(result_types) 115 | 116 | return IDs, bboxes, states, types 117 | 118 | 119 | def main(name, obj_types, config_path, data_folder, det_data_folder, result_folder, start_frame=0, token=0, process=1): 120 | for obj_type in obj_types: 121 | summary_folder = os.path.join(result_folder, 'summary', obj_type) 122 | # simply knowing about all the segments 123 | file_names = sorted(os.listdir(os.path.join(data_folder, 'ego_info'))) 124 | 125 | # load model configs 126 | configs = yaml.load(open(config_path, 'r'), Loader=yaml.Loader) 127 | 128 | for file_index, file_name in enumerate(file_names[:]): 129 | if file_index % process != token: 130 | continue 131 | print('START TYPE {:} SEQ {:} / {:}'.format(obj_type, file_index + 1, len(file_names))) 132 | segment_name = file_name.split('.')[0] 133 | 134 | data_loader = NuScenesLoader10Hz(configs, [obj_type], segment_name, data_folder, det_data_folder, start_frame) 135 | 136 | gt_bboxes, gt_ids = load_gt_bboxes(data_folder, [obj_type], segment_name) 137 | ids, bboxes, states, types = sequence_mot(configs, data_loader, obj_type, file_index, gt_bboxes, gt_ids, args.visualize) 138 | 139 | frame_num = len(ids) 140 | for frame_index in range(frame_num): 141 | id_num = len(ids[frame_index]) 142 | for i in range(id_num): 143 | ids[frame_index][i] = '{:}_{:}'.format(file_index, ids[frame_index][i]) 144 | 145 | np.savez_compressed(os.path.join(summary_folder, 
'{}.npz'.format(segment_name)), 146 | ids=ids, bboxes=bboxes, states=states, types=types) 147 | 148 | 149 | if __name__ == '__main__': 150 | result_folder = os.path.join(args.result_folder, args.name) 151 | os.makedirs(result_folder, exist_ok=True) 152 | summary_folder = os.path.join(result_folder, 'summary') 153 | os.makedirs(summary_folder, exist_ok=True) 154 | 155 | det_data_folder = os.path.join(args.data_folder, 'detection', args.det_name) 156 | 157 | obj_types = args.obj_types.split(',') 158 | for obj_type in obj_types: 159 | tmp_summary_folder = os.path.join(summary_folder, obj_type) 160 | os.makedirs(tmp_summary_folder, exist_ok=True) 161 | 162 | if args.process > 1: 163 | pool = multiprocessing.Pool(args.process) 164 | for token in range(args.process): 165 | result = pool.apply_async(main, args=(args.name, obj_types, args.config_path, args.data_folder, det_data_folder, 166 | result_folder, 0, token, args.process)) 167 | pool.close() 168 | pool.join() 169 | else: 170 | main(args.name, obj_types, args.config_path, args.data_folder, det_data_folder, 171 | result_folder, args.start_frame, 0, 1) 172 | -------------------------------------------------------------------------------- /tools/main_waymo.py: -------------------------------------------------------------------------------- 1 | import os, numpy as np, argparse, json, sys, numba, yaml, multiprocessing, shutil 2 | import mot_3d.visualization as visualization, mot_3d.utils as utils 3 | from mot_3d.data_protos import BBox, Validity 4 | from mot_3d.mot import MOTModel 5 | from mot_3d.frame_data import FrameData 6 | from data_loader import WaymoLoader 7 | 8 | 9 | parser = argparse.ArgumentParser() 10 | # running configurations 11 | parser.add_argument('--name', type=str, default='debug') 12 | parser.add_argument('--det_name', type=str, default='public') 13 | parser.add_argument('--process', type=int, default=1) 14 | parser.add_argument('--visualize', action='store_true', default=False) 15 | parser.add_argument('--start_frame', type=int, default=0, help='start at a middle frame for debug') 16 | parser.add_argument('--obj_type', type=str, default='vehicle', choices=['vehicle', 'pedestrian', 'cyclist']) 17 | # paths 18 | parser.add_argument('--config_path', type=str, default='configs/config.yaml', help='config file path, follow the path in the documentation') 19 | parser.add_argument('--result_folder', type=str, default='../mot_results/') 20 | parser.add_argument('--data_folder', type=str, default='../datasets/waymo/mot/') 21 | parser.add_argument('--gt_folder', type=str, default='../datasets/waymo/mot/detection/gt/dets/') 22 | args = parser.parse_args() 23 | 24 | 25 | def load_gt_bboxes(gt_folder, data_folder, segment_name, type_token): 26 | gt_info = np.load(os.path.join(gt_folder, '{:}.npz'.format(segment_name)), 27 | allow_pickle=True) 28 | ego_info = np.load(os.path.join(data_folder, 'ego_info', '{:}.npz'.format(segment_name)), 29 | allow_pickle=True) 30 | bboxes, ids, inst_types = gt_info['bboxes'], gt_info['ids'], gt_info['types'] 31 | gt_ids, gt_bboxes = utils.inst_filter(ids, bboxes, inst_types, type_field=[type_token], id_trans=True) 32 | 33 | ego_keys = sorted(utils.str2int(ego_info.keys())) 34 | egos = [ego_info[str(key)] for key in ego_keys] 35 | gt_bboxes = gt_bbox2world(gt_bboxes, egos) 36 | return gt_bboxes, gt_ids 37 | 38 | 39 | def gt_bbox2world(bboxes, egos): 40 | frame_num = len(egos) 41 | for i in range(frame_num): 42 | ego = egos[i] 43 | bbox_num = len(bboxes[i]) 44 | for j in range(bbox_num): 45 | bboxes[i][j] = 
BBox.bbox2world(ego, bboxes[i][j]) 46 | return bboxes 47 | 48 | 49 | def frame_visualization(bboxes, ids, states, gt_bboxes=None, gt_ids=None, pc=None, dets=None, name=''): 50 | visualizer = visualization.Visualizer2D(name=name, figsize=(12, 12)) 51 | if pc is not None: 52 | visualizer.handler_pc(pc) 53 | for _, bbox in enumerate(gt_bboxes): 54 | visualizer.handler_box(bbox, message='', color='black') 55 | dets = [d for d in dets if d.s >= 0.1] 56 | for det in dets: 57 | visualizer.handler_box(det, message='%.2f' % det.s, color='purple', linestyle='dashed') 58 | for _, (bbox, id, state) in enumerate(zip(bboxes, ids, states)): 59 | if Validity.valid(state): 60 | visualizer.handler_box(bbox, message=str(id), color='red') 61 | else: 62 | visualizer.handler_box(bbox, message=str(id), color='light_blue') 63 | os.makedirs('imgs', exist_ok=True) 64 | visualizer.save('imgs/{:}.png'.format(name)) 65 | visualizer.close() 66 | 67 | 68 | def sequence_mot(configs, data_loader: WaymoLoader, sequence_id, gt_bboxes=None, gt_ids=None, visualize=False): 69 | tracker = MOTModel(configs) 70 | frame_num = len(data_loader) 71 | IDs, bboxes, states, types = list(), list(), list(), list() 72 | for frame_index in range(data_loader.cur_frame, frame_num): 73 | print('TYPE {:} SEQ {:} Frame {:} / {:}'.format(data_loader.type_token, sequence_id + 1, frame_index + 1, frame_num)) 74 | 75 | # input data 76 | frame_data = next(data_loader) 77 | frame_data = FrameData(dets=frame_data['dets'], ego=frame_data['ego'], pc=frame_data['pc'], 78 | det_types=frame_data['det_types'], aux_info=frame_data['aux_info'], time_stamp=frame_data['time_stamp']) 79 | 80 | # mot 81 | results = tracker.frame_mot(frame_data) 82 | result_pred_bboxes = [trk[0] for trk in results] 83 | result_pred_ids = [trk[1] for trk in results] 84 | result_pred_states = [trk[2] for trk in results] 85 | result_types = [trk[3] for trk in results] 86 | 87 | # visualization 88 | if visualize: 89 | frame_visualization(result_pred_bboxes, result_pred_ids, result_pred_states, 90 | gt_bboxes[frame_index], gt_ids[frame_index], frame_data.pc, dets=frame_data.dets, name='{:}_{:}_{:}'.format(args.name, sequence_id, frame_index)) 91 | 92 | # wrap for output 93 | IDs.append(result_pred_ids) 94 | result_pred_bboxes = [BBox.bbox2array(bbox) for bbox in result_pred_bboxes] 95 | bboxes.append(result_pred_bboxes) 96 | states.append(result_pred_states) 97 | types.append(result_types) 98 | return IDs, bboxes, states, types 99 | 100 | 101 | def main(name, obj_type, config_path, data_folder, det_data_folder, result_folder, gt_folder, start_frame=0, token=0, process=1): 102 | summary_folder = os.path.join(result_folder, 'summary', obj_type) 103 | # simply knowing about all the segments 104 | file_names = sorted(os.listdir(os.path.join(data_folder, 'ego_info'))) 105 | print(file_names[0]) 106 | 107 | # load model configs 108 | configs = yaml.load(open(config_path, 'r'), Loader=yaml.Loader) 109 | 110 | if obj_type == 'vehicle': 111 | type_token = 1 112 | elif obj_type == 'pedestrian': 113 | type_token = 2 114 | elif obj_type == 'cyclist': 115 | type_token = 4 116 | for file_index, file_name in enumerate(file_names[:]): 117 | if file_index % process != token: 118 | continue 119 | print('START TYPE {:} SEQ {:} / {:}'.format(obj_type, file_index + 1, len(file_names))) 120 | segment_name = file_name.split('.')[0] 121 | data_loader = WaymoLoader(configs, [type_token], segment_name, data_folder, det_data_folder, start_frame) 122 | gt_bboxes, gt_ids = load_gt_bboxes(gt_folder, data_folder, segment_name, 
def main(name, obj_type, config_path, data_folder, det_data_folder, result_folder, gt_folder, start_frame=0, token=0, process=1):
    summary_folder = os.path.join(result_folder, 'summary', obj_type)
    # simply knowing about all the segments
    file_names = sorted(os.listdir(os.path.join(data_folder, 'ego_info')))
    print(file_names[0])

    # load model configs
    configs = yaml.load(open(config_path, 'r'), Loader=yaml.Loader)

    if obj_type == 'vehicle':
        type_token = 1
    elif obj_type == 'pedestrian':
        type_token = 2
    elif obj_type == 'cyclist':
        type_token = 4
    for file_index, file_name in enumerate(file_names[:]):
        if file_index % process != token:
            continue
        print('START TYPE {:} SEQ {:} / {:}'.format(obj_type, file_index + 1, len(file_names)))
        segment_name = file_name.split('.')[0]
        data_loader = WaymoLoader(configs, [type_token], segment_name, data_folder, det_data_folder, start_frame)
        gt_bboxes, gt_ids = load_gt_bboxes(gt_folder, data_folder, segment_name, type_token)

        # real mot happens here
        ids, bboxes, states, types = sequence_mot(configs, data_loader, file_index, gt_bboxes, gt_ids, args.visualize)
        np.savez_compressed(os.path.join(summary_folder, '{}.npz'.format(segment_name)),
                            ids=ids, bboxes=bboxes, states=states)


if __name__ == '__main__':
    result_folder = os.path.join(args.result_folder, args.name)
    os.makedirs(result_folder, exist_ok=True)

    summary_folder = os.path.join(result_folder, 'summary')
    os.makedirs(summary_folder, exist_ok=True)

    summary_folder = os.path.join(summary_folder, args.obj_type)
    os.makedirs(summary_folder, exist_ok=True)

    det_data_folder = os.path.join(args.data_folder, 'detection', args.det_name)

    if args.process > 1:
        pool = multiprocessing.Pool(args.process)
        for token in range(args.process):
            result = pool.apply_async(main, args=(args.name, args.obj_type, args.config_path, args.data_folder, det_data_folder,
                                                  result_folder, args.gt_folder, 0, token, args.process))
        pool.close()
        pool.join()
    else:
        main(args.name, args.obj_type, args.config_path, args.data_folder, det_data_folder, result_folder,
             args.gt_folder, args.start_frame, 0, 1)
--------------------------------------------------------------------------------
/tools/nuscenes_result_creation.py:
--------------------------------------------------------------------------------
import os, argparse, json, numpy as np
from pyquaternion import Quaternion
from mot_3d.data_protos import BBox, Validity
from tqdm import tqdm


parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str, default='debug')
parser.add_argument('--obj_types', type=str, default='car,bus,trailer,truck,pedestrian,bicycle,motorcycle')
parser.add_argument('--result_folder', type=str, default='../nu_mot_results/')
parser.add_argument('--data_folder', type=str, default='../datasets/nuscenes/')
args = parser.parse_args()


def bbox_array2nuscenes_format(bbox_array):
    translation = bbox_array[:3].tolist()
    size = bbox_array[4:7].tolist()
    size = [size[1], size[0], size[2]]  # internal (l, w, h) -> nuScenes (w, l, h)
    velocity = [0.0, 0.0]
    score = bbox_array[-1]

    yaw = bbox_array[3]
    rot_matrix = np.array([[np.cos(yaw), -np.sin(yaw), 0, 0],
                           [np.sin(yaw),  np.cos(yaw), 0, 0],
                           [0,            0,           1, 0],
                           [0,            0,           0, 1]])  # homogeneous rotation about z
    q = Quaternion(matrix=rot_matrix)
    rotation = q.q.tolist()

    sample_result = {
        'translation': translation,
        'size': size,
        'velocity': velocity,
        'rotation': rotation,
        'tracking_score': score
    }
    return sample_result
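# A sanity check for the rotation handling above: a pure rotation about the z-axis
# by `yaw` corresponds to the quaternion [cos(yaw/2), 0, 0, sin(yaw/2)], so an
# equivalent and simpler pyquaternion construction would be (a sketch, not how
# this script does it):
#
#   q = Quaternion(axis=[0.0, 0.0, 1.0], angle=yaw)  # angle in radians
#
# which yields the same quaternion as Quaternion(matrix=rot_matrix) for the
# matrix built above.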
def main(name, obj_types, data_folder, result_folder, output_folder):
    for obj_type in obj_types:
        print('CONVERTING {:}'.format(obj_type))
        summary_folder = os.path.join(result_folder, 'summary', obj_type)
        file_names = sorted(os.listdir(os.path.join(data_folder, 'ego_info')))
        token_info_folder = os.path.join(data_folder, 'token_info')

        results = dict()
        pbar = tqdm(total=len(file_names))
        for file_index, file_name in enumerate(file_names):
            segment_name = file_name.split('.')[0]
            token_info = json.load(open(os.path.join(token_info_folder, '{:}.json'.format(segment_name)), 'r'))
            mot_results = np.load(os.path.join(summary_folder, '{:}.npz'.format(segment_name)), allow_pickle=True)

            ids, bboxes, states, types = \
                mot_results['ids'], mot_results['bboxes'], mot_results['states'], mot_results['types']
            frame_num = len(ids)
            for frame_index in range(frame_num):
                sample_token = token_info[frame_index]
                results[sample_token] = list()
                frame_bboxes, frame_ids, frame_types, frame_states = \
                    bboxes[frame_index], ids[frame_index], types[frame_index], states[frame_index]

                frame_obj_num = len(frame_ids)
                for i in range(frame_obj_num):
                    sample_result = bbox_array2nuscenes_format(frame_bboxes[i])
                    sample_result['sample_token'] = sample_token
                    sample_result['tracking_id'] = frame_types[i] + '_' + str(frame_ids[i])
                    sample_result['tracking_name'] = frame_types[i]
                    results[sample_token].append(sample_result)
            pbar.update(1)
        pbar.close()
        submission_file = {
            'meta': {
                'use_camera': False, 'use_lidar': True, 'use_radar': False, 'use_map': False, 'use_external': False
            },
            'results': results
        }

        f = open(os.path.join(output_folder, obj_type, 'results.json'), 'w')
        json.dump(submission_file, f)
        f.close()
    return
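# The file written above follows the nuScenes tracking submission layout, roughly:
#
#   {
#     "meta":    {"use_lidar": true, ...},
#     "results": {<sample_token>: [{"sample_token": ..., "translation": [x, y, z],
#                                   "size": [w, l, h], "rotation": [w, x, y, z],
#                                   "velocity": [vx, vy], "tracking_id": "car_17",
#                                   "tracking_name": "car", "tracking_score": 0.9}, ...]}
#   }
#
# "car_17" and 0.9 are made-up illustrative values; the per-type files produced
# here are concatenated later by tools/nuscenes_type_merge.py.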
if __name__ == '__main__':
    result_folder = os.path.join(args.result_folder, args.name)
    obj_types = args.obj_types.split(',')
    output_folder = os.path.join(result_folder, 'results')
    for obj_type in obj_types:
        tmp_output_folder = os.path.join(result_folder, 'results', obj_type)
        os.makedirs(tmp_output_folder, exist_ok=True)

    main(args.name, obj_types, args.data_folder, result_folder, output_folder)
--------------------------------------------------------------------------------
/tools/nuscenes_result_creation_10hz.py:
--------------------------------------------------------------------------------
import os, argparse, json, numpy as np
from pyquaternion import Quaternion
from mot_3d.data_protos import BBox, Validity
from tqdm import tqdm


parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str, default='debug')
parser.add_argument('--obj_types', type=str, default='car,bus,trailer,truck,pedestrian,bicycle,motorcycle')
parser.add_argument('--result_folder', type=str, default='/mnt/truenas/scratch/ziqi.pang/10hz_exps/')
parser.add_argument('--data_folder', type=str, default='/mnt/truenas/scratch/ziqi.pang/datasets/nuscenes/validation_20hz/')
args = parser.parse_args()


def bbox_array2nuscenes_format(bbox_array):
    translation = bbox_array[:3].tolist()
    size = bbox_array[4:7].tolist()
    size = [size[1], size[0], size[2]]  # internal (l, w, h) -> nuScenes (w, l, h)
    velocity = [0.0, 0.0]
    score = bbox_array[-1]

    yaw = bbox_array[3]
    rot_matrix = np.array([[np.cos(yaw), -np.sin(yaw), 0, 0],
                           [np.sin(yaw),  np.cos(yaw), 0, 0],
                           [0,            0,           1, 0],
                           [0,            0,           0, 1]])  # homogeneous rotation about z
    q = Quaternion(matrix=rot_matrix)
    rotation = q.q.tolist()

    sample_result = {
        'translation': translation,
        'size': size,
        'velocity': velocity,
        'rotation': rotation,
        'tracking_score': score
    }
    return sample_result


def main(name, obj_types, data_folder, result_folder, output_folder):
    for obj_type in obj_types:
        print('CONVERTING {:}'.format(obj_type))
        summary_folder = os.path.join(result_folder, 'summary', obj_type)
        file_names = sorted(os.listdir(os.path.join(data_folder, 'ego_info')))
        token_info_folder = os.path.join(data_folder, 'token_info')

        results = dict()
        pbar = tqdm(total=len(file_names))
        for file_index, file_name in enumerate(file_names):
            segment_name = file_name.split('.')[0]
            token_info = json.load(open(os.path.join(token_info_folder, '{:}.json'.format(segment_name)), 'r'))
            mot_results = np.load(os.path.join(summary_folder, '{:}.npz'.format(segment_name)), allow_pickle=True)

            ids, bboxes, states, types = \
                mot_results['ids'], mot_results['bboxes'], mot_results['states'], mot_results['types']
            frame_num = len(ids)
            token_info = [t for t in token_info if t[3]]
            for frame_index in range(frame_num):
                frame_token = token_info[frame_index]
                is_key_frame = frame_token[1]
                if not is_key_frame:
                    continue

                sample_token = frame_token[2]
                results[sample_token] = list()
                frame_bboxes, frame_ids, frame_types, frame_states = \
                    bboxes[frame_index], ids[frame_index], types[frame_index], states[frame_index]

                frame_obj_num = len(frame_ids)
                for i in range(frame_obj_num):
                    sample_result = bbox_array2nuscenes_format(frame_bboxes[i])
                    sample_result['sample_token'] = sample_token
                    sample_result['tracking_id'] = frame_types[i] + '_' + str(frame_ids[i])
                    sample_result['tracking_name'] = frame_types[i]
                    results[sample_token].append(sample_result)
            pbar.update(1)
        pbar.close()
        submission_file = {
            'meta': {
                'use_camera': False, 'use_lidar': True, 'use_radar': False, 'use_map': False, 'use_external': False
            },
            'results': results
        }

        f = open(os.path.join(output_folder, obj_type, 'results.json'), 'w')
        json.dump(submission_file, f)
        f.close()
    return


if __name__ == '__main__':
    result_folder = os.path.join(args.result_folder, args.name)
    obj_types = args.obj_types.split(',')
    output_folder = os.path.join(result_folder, 'results')
    for obj_type in obj_types:
        tmp_output_folder = os.path.join(result_folder, 'results', obj_type)
        if not os.path.exists(tmp_output_folder):
            os.makedirs(tmp_output_folder)

    main(args.name, obj_types, args.data_folder, result_folder, output_folder)
--------------------------------------------------------------------------------
/tools/nuscenes_type_merge.py:
--------------------------------------------------------------------------------
import os, argparse, json, numpy as np
from pyquaternion import Quaternion
from mot_3d.data_protos import BBox, Validity

parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str, default='debug')
parser.add_argument('--obj_types', default='car,bus,trailer,truck,pedestrian,bicycle,motorcycle')
parser.add_argument('--result_folder', type=str, default='/mnt/truenas/scratch/ziqi.pang/nu_mot_results/')
args = parser.parse_args()


def main(name, obj_types, result_folder):
    raw_results = list()
    for type_name in obj_types:
        path = os.path.join(result_folder, type_name, 'results.json')
        f = open(path, 'r')
        raw_results.append(json.load(f)['results'])
        f.close()

    results = raw_results[0]
    sample_tokens = list(results.keys())
    for token in sample_tokens:
        for i in range(1, len(obj_types)):
            results[token] += raw_results[i][token]

    submission_file = {
        'meta': {
            'use_camera': False, 'use_lidar': True, 'use_radar': False, 'use_map': False, 'use_external': False
        },
        'results': results
    }

    f = open(os.path.join(result_folder, 'results.json'), 'w')
    json.dump(submission_file, f)
    f.close()
    return
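# Note on the merge above: the first type's results dict serves as the frame index,
# and every other per-type file is assumed to contain an entry for exactly the same
# set of sample tokens (which holds when all types were produced over the same
# split, since each per-type script writes one list per frame); a token missing
# from any per-type file would raise a KeyError here.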
if __name__ == '__main__':
    result_folder = os.path.join(args.result_folder, args.name, 'results')
    obj_types = args.obj_types.split(',')
    main(args.name, obj_types, result_folder)
--------------------------------------------------------------------------------
/tools/waymo_pred_bin.py:
--------------------------------------------------------------------------------
from waymo_open_dataset import dataset_pb2
from waymo_open_dataset import label_pb2
from waymo_open_dataset.protos import metrics_pb2
import os, time, numpy as np, sys, pickle as pkl
import argparse, json
from copy import deepcopy
from mot_3d.data_protos import BBox, Validity
import mot_3d.utils as utils
from tqdm import tqdm


parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str, default='baseline')
parser.add_argument('--obj_types', type=str, default='vehicle,pedestrian,cyclist')
parser.add_argument('--result_folder', type=str, default='../mot_results/')
parser.add_argument('--data_folder', type=str, default='../datasets/waymo/mot/')
args = parser.parse_args()


def get_context_name(file_name: str):
    context = file_name.split('.')[0]   # strip the extension
    context = context.split('-')[1]     # keep what follows 'segment-'
    context = context.split('w')[0]     # cut before 'with'
    context = context[:-1]              # drop the trailing underscore
    return context

def pred_content_filter(pred_contents, pred_states):
    # keep only the per-frame entries whose state is valid for output
    result_contents = list()
    for contents, states in zip(pred_contents, pred_states):
        indices = [i for i in range(len(states)) if Validity.valid(states[i])]
        frame_contents = [contents[i] for i in indices]
        result_contents.append(frame_contents)
    return result_contents
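# Worked example for get_context_name, assuming the usual Waymo naming scheme
# 'segment-<context_name>_with_camera_labels.<ext>' (the file name below is
# illustrative, not taken from this repo):
#
#   get_context_name('segment-10203656353524179475_7625_000_7645_000_with_camera_labels.npz')
#   # '.'-split -> 'segment-10203656353524179475_7625_000_7645_000_with_camera_labels'
#   # '-'-split -> '10203656353524179475_7625_000_7645_000_with_camera_labels'
#   # 'w'-split -> '10203656353524179475_7625_000_7645_000_'  (cut at the first 'w')
#   # [:-1]     -> '10203656353524179475_7625_000_7645_000'
#
# This relies on the context name itself containing no 'w', which holds for the
# digit-and-underscore names Waymo uses.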
def main(name, obj_type, result_folder, raw_data_folder, output_folder, output_file_name):
    summary_folder = os.path.join(result_folder, 'summary', obj_type)
    file_names = sorted(os.listdir(summary_folder))[:]

    if obj_type == 'vehicle':
        type_token = 1
    elif obj_type == 'pedestrian':
        type_token = 2
    elif obj_type == 'cyclist':
        type_token = 4

    ts_info_folder = os.path.join(raw_data_folder, 'ts_info')
    ego_info_folder = os.path.join(raw_data_folder, 'ego_info')
    obj_list = list()

    print('Converting TYPE {:} into WAYMO Format'.format(obj_type))
    pbar = tqdm(total=len(file_names))
    for file_index, file_name in enumerate(file_names[:]):
        file_name_prefix = file_name.split('.')[0]
        context_name = get_context_name(file_name)

        ts_path = os.path.join(ts_info_folder, '{}.json'.format(file_name_prefix))
        ts_data = json.load(open(ts_path, 'r'))  # list of time stamps by order of frame

        # load ego motions
        ego_motions = np.load(os.path.join(ego_info_folder, '{:}.npz'.format(file_name_prefix)), allow_pickle=True)

        pred_result = np.load(os.path.join(summary_folder, file_name), allow_pickle=True)
        pred_ids, pred_bboxes, pred_states = pred_result['ids'], pred_result['bboxes'], pred_result['states']
        pred_bboxes = pred_content_filter(pred_bboxes, pred_states)
        pred_ids = pred_content_filter(pred_ids, pred_states)
        pred_velos, pred_accels = None, None

        obj_list += create_sequence(pred_ids, pred_bboxes, type_token, context_name,
                                    ts_data, ego_motions, pred_velos, pred_accels)
        pbar.update(1)
    pbar.close()
    objects = metrics_pb2.Objects()
    for obj in obj_list:
        objects.objects.append(obj)

    output_folder = os.path.join(output_folder, obj_type)
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
    output_path = os.path.join(output_folder, '{:}.bin'.format(output_file_name))
    f = open(output_path, 'wb')
    f.write(objects.SerializeToString())
    f.close()
    return


def create_single_pred_bbox(id, bbox, type_token, time_stamp, context_name, inv_ego_motion, velo, accel):
    o = metrics_pb2.Object()
    o.context_name = context_name
    o.frame_timestamp_micros = time_stamp
    box = label_pb2.Label.Box()

    # tracking runs in world frame; the inverse ego pose maps the box back to the vehicle frame
    proto_box = BBox.array2bbox(bbox)
    proto_box = BBox.bbox2world(inv_ego_motion, proto_box)
    bbox = BBox.bbox2array(proto_box)

    box.center_x, box.center_y, box.center_z, box.heading = bbox[:4]
    box.length, box.width, box.height = bbox[4:7]
    o.object.box.CopyFrom(box)
    o.score = bbox[-1]

    meta_data = label_pb2.Label.Metadata()
    o.object.metadata.CopyFrom(meta_data)

    o.object.id = '{:}_{:}'.format(type_token, id)
    o.object.type = type_token
    return o


def create_sequence(pred_ids, pred_bboxes, type_token, context_name, time_stamps, ego_motions, pred_velos, pred_accels):
    frame_num = len(pred_ids)
    sequence_objects = list()
    for frame_index in range(frame_num):
        time_stamp = time_stamps[frame_index]
        frame_obj_num = len(pred_ids[frame_index])
        ego_motion = ego_motions[str(frame_index)]
        inv_ego_motion = np.linalg.inv(ego_motion)
        for obj_index in range(frame_obj_num):
            pred_id = pred_ids[frame_index][obj_index]
            pred_bbox = pred_bboxes[frame_index][obj_index]
            pred_velo, pred_accel = None, None
            sequence_objects.append(create_single_pred_bbox(
                pred_id, pred_bbox, type_token, time_stamp, context_name, inv_ego_motion, pred_velo, pred_accel))
    return sequence_objects


def merge_results(output_folder, obj_types, output_file_name):
    print('Merging different object types')
    result_objs = list()
    for obj_type in obj_types:
        bin_path = os.path.join(output_folder, obj_type, '{:}.bin'.format(output_file_name))
        f = open(bin_path, 'rb')
        objects = metrics_pb2.Objects()
        objects.ParseFromString(f.read())
        f.close()
        objects = objects.objects
        result_objs += objects

    output_objs = metrics_pb2.Objects()
    for obj in result_objs:
        output_objs.objects.append(obj)

    output_path = os.path.join(output_folder, '{:}.bin'.format(output_file_name))
    f = open(output_path, 'wb')
    f.write(output_objs.SerializeToString())
    f.close()


if __name__ == '__main__':
    result_folder = os.path.join(args.result_folder, args.name)
    output_folder = os.path.join(result_folder, 'bin')
    os.makedirs(output_folder, exist_ok=True)

    obj_types = args.obj_types.split(',')
    for obj_type in obj_types:
        main(args.name, obj_type, result_folder, args.data_folder, output_folder, 'pred')

    merge_results(output_folder, obj_types, 'pred')
--------------------------------------------------------------------------------