├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── config_library.md
│   ├── nu_configs
│   │   ├── giou.yaml
│   │   └── iou.yaml
│   └── waymo_configs
│       ├── pd_kf_giou.yaml
│       └── vc_kf_giou.yaml
├── data_loader
│   ├── __init__.py
│   ├── nuscenes_loader.py
│   └── waymo_loader.py
├── docs
│   ├── config.md
│   ├── data_preprocess.md
│   ├── demo.png
│   ├── nuScenes.md
│   ├── output_format.md
│   ├── simpletrack_gif.gif
│   └── waymo.md
├── mot_3d
│   ├── __init__.py
│   ├── association.py
│   ├── data_protos
│   │   ├── __init__.py
│   │   ├── bbox.py
│   │   └── validity.py
│   ├── frame_data.py
│   ├── life
│   │   ├── __init__.py
│   │   └── hit_manager.py
│   ├── mot.py
│   ├── motion_model
│   │   ├── __init__.py
│   │   └── kalman_filter.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── bbox_coarse_hash.py
│   │   └── nms.py
│   ├── redundancy
│   │   ├── __init__.py
│   │   └── redundancy.py
│   ├── tracklet
│   │   ├── __init__.py
│   │   └── tracklet.py
│   ├── update_info_data.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── data_utils.py
│   │   └── geometry.py
│   └── visualization
│       ├── __init__.py
│       └── visualizer2d.py
├── preprocessing
│   ├── nuscenes_data
│   │   ├── detection.py
│   │   ├── ego_pose.py
│   │   ├── gt_info.py
│   │   ├── nuscenes_preprocess.sh
│   │   ├── raw_pc.py
│   │   ├── sensor_calibration.py
│   │   ├── time_stamp.py
│   │   └── token_info.py
│   └── waymo_data
│       ├── detection.py
│       ├── ego_info.py
│       ├── gt_bin_decode.py
│       ├── raw_pc.py
│       ├── time_stamp.py
│       └── waymo_preprocess.sh
├── requirements.txt
├── setup.py
└── tools
    ├── demo.py
    ├── main_nuscenes.py
    ├── main_nuscenes_10hz.py
    ├── main_waymo.py
    ├── nuscenes_result_creation.py
    ├── nuscenes_result_creation_10hz.py
    ├── nuscenes_type_merge.py
    └── waymo_pred_bin.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/*
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 | *.npy
7 | *.zip
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 tusen-ai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SimpleTrack: Simple yet Effective 3D Multi-object Tracking
2 |
3 | This is the repository for our paper [SimpleTrack: Understanding and Rethinking 3D Multi-object Tracking](https://arxiv.org/abs/2111.09621). We are still writing the documentation and cleaning up the code, but the following parts are sufficient for replicating the results in our paper. For more variants of the model, we have already moved all of our code onto the `dev` branch, so please feel free to check it out if you need to delve deeper right away. We will try our best to get everything ready as soon as possible.
4 |
5 | If you find our paper or code useful, please consider citing us:
6 | ```
7 | @article{pang2021simpletrack,
8 | title={SimpleTrack: Understanding and Rethinking 3D Multi-object Tracking},
9 | author={Pang, Ziqi and Li, Zhichao and Wang, Naiyan},
10 | journal={arXiv preprint arXiv:2111.09621},
11 | year={2021}
12 | }
13 | ```
14 |
15 |
16 | ![](./docs/simpletrack_gif.gif)
17 | - [ ] Accelerate the code: make the IoU/GIoU computation parallel.
18 | - [ ] Add documentation for the codebase.
19 |
20 | ## Installation
21 |
22 | ### Environment Requirements
23 |
24 | `SimpleTrack` requires `python>=3.6` and the packages listed in `requirements.txt` (install them with `pip install -r requirements.txt`). For the experiments on Waymo Open Dataset, please install the devkit following the instructions at [waymo open dataset devkit](https://github.com/waymo-research/waymo-open-dataset).
25 |
26 | ### Installation
27 |
28 | We implement the `SimpleTrack` algorithm as a library `mot_3d`. Please run `pip install -e ./` to install it locally.
29 |
30 | ## Demo and API Example
31 |
32 | ### Demo
33 |
34 | We provide a demo based on the first sequence with ID `10203656353524179475_7625_000_7645_000` from the validation set of [Waymo Open Dataset](https://waymo.com/open/) and the detection from [CenterPoint](https://github.com/tianweiy/CenterPoint). (We are thankful and hope that we are not violating any terms here.)
35 |
36 | First, download the [demo_data](https://www.dropbox.com/s/m8vt7t7tqofaoq2/demo_data.zip?dl=0) and extract it locally. It contains the necessary information and is already preprocessed according to [our preprocessing programs](./docs/data_preprocess.md). To run the demo, run the following command. It provides interactive visualization with `matplotlib.pyplot`, so it is recommended to run this demo locally.
37 |
38 | ```bash
39 | python tools/demo.py \
40 | --name demo \
41 | --det_name cp \
42 | --obj_type vehicle \
43 | --config_path configs/waymo_configs/vc_kf_giou.yaml \
44 | --data_folder ./demo_data/ \
45 | --visualize
46 | ```
47 |
48 | An example output for the visualization is the following figure.
49 |
50 | ![](./docs/demo.png)
51 |
52 | In the visualization, the red bounding boxes are the output tracking results with their IDs. The blue ones are tracking results that are not output due to low confidence scores. The green ones are the detection bounding boxes with their scores. The black ones are the ground-truth bounding boxes.
53 |
54 | ### API Example
55 |
56 | The most important function is `tracker.frame_mot()`. An object of `MOTModel` iteratively digests the `FrameData` of each frame and infers the tracking result on that frame.
57 |
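Below is a minimal sketch of that loop, assuming a config loaded from one of the provided `.yaml` files and a `data_loader` that yields per-frame dicts (e.g. `WaymoLoader`); see `tools/demo.py` for the full argument handling.

```python
import yaml
from mot_3d.mot import MOTModel
from mot_3d.frame_data import FrameData

configs = yaml.safe_load(open('configs/waymo_configs/vc_kf_giou.yaml'))
tracker = MOTModel(configs)
for frame in data_loader:  # assumed: an iterator such as WaymoLoader
    frame_data = FrameData(dets=frame['dets'], ego=frame['ego'],
                           pc=frame['pc'], det_types=frame['det_types'],
                           time_stamp=frame['time_stamp'],
                           aux_info=frame['aux_info'])
    # tracking results on this frame: [(bbox0, id0), (bbox1, id1), ...]
    results = tracker.frame_mot(frame_data)
```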
58 | ## Inference on Waymo Open Dataset and nuScenes
59 |
60 | Refer to the documentation of [Waymo Open Dataset Inference](./docs/waymo.md) and [nuScenes Inference](./docs/nuScenes.md). **Important: please rigorously follow the config file path and instructions in the documentation to reproduce the results.**
61 |
62 | The detailed metrics and files are at this [Dropbox Link](https://www.dropbox.com/sh/2zcpf7wyho7x21q/AABRtGP75KHghr0wyoVqzFFXa?dl=0); the Waymo Open Dataset results are in this [Dropbox link](https://www.dropbox.com/sh/u6o8dcwmzya04uk/AAAUsNvTt7ubXul9dx5Xnp4xa?dl=0) and the nuScenes results in this [Dropbox Link](https://www.dropbox.com/sh/8906exnes0u5e89/AAD0xLwW1nq_QiuUBaYDrQVna?dl=0).
63 |
64 | For the metrics on the test set, please refer to our paper or the leaderboard.
65 |
66 | ## Related Documentations
67 |
68 | To enable better usage of our `mot_3d` library, we provide a list of useful documents, and will add more in the future.
69 |
70 | * [Read and Use the Configurations](./docs/config.md). This document explains how to specify the behaviors of trackers, such as two-stage association, the thresholds for association, etc.
71 | * [Format of Output](./docs/output_format.md). We explain the output format for the APIs in `SimpleTrack`, so that you may directly use the functions provided. (in progress)
72 | * Visualization with `mot_3d` (in progress)
73 | * Structure of `mot_3d` (in progress)
74 |
--------------------------------------------------------------------------------
/configs/config_library.md:
--------------------------------------------------------------------------------
1 | # Configuration Library
--------------------------------------------------------------------------------
/configs/nu_configs/giou.yaml:
--------------------------------------------------------------------------------
1 | running:
2 | covariance: default
3 | score_threshold: 0.01
4 | max_age_since_update: 2
5 | min_hits_to_birth: 0
6 | match_type: bipartite
7 | asso: giou
8 | has_velo: true
9 | nms_thres: 0.1
10 | motion_model: kf
11 | asso_thres:
12 | giou: 1.5
13 | iou: 0.9
14 |
15 | redundancy:
16 | mode: default
17 | det_score_threshold:
18 | iou: 0.01
19 | giou: 0.01
20 | det_dist_threshold:
21 | iou: 0.1
22 | giou: -0.5
23 |
24 | data_loader:
25 | pc: true
26 | nms: true
27 | nms_thres: 0.1
--------------------------------------------------------------------------------
/configs/nu_configs/iou.yaml:
--------------------------------------------------------------------------------
1 | running:
2 | covariance: default
3 | score_threshold: 0.01
4 | max_age_since_update: 2
5 | min_hits_to_birth: 1
6 | match_type: bipartite
7 | asso: iou
8 | has_velo: false
9 | motion_model: kf
10 | asso_thres:
11 | giou: 1.5
12 | iou: 0.9
13 |
14 | redundancy:
15 | mode: default
16 | det_score_threshold:
17 | iou: 0.01
18 | giou: 0.01
19 | m_dis: 0.01
20 | det_dist_threshold:
21 | iou: 0.1
22 | giou: -0.5
23 |
24 | data_loader:
25 | pc: false
--------------------------------------------------------------------------------
/configs/waymo_configs/pd_kf_giou.yaml:
--------------------------------------------------------------------------------
1 | running:
2 | covariance: default
3 | score_threshold: 0.5
4 | max_age_since_update: 2
5 | min_hits_to_birth: 3
6 | match_type: bipartite
7 | asso: giou
8 | has_velo: false
9 | motion_model: kf
10 | asso_thres:
11 | iou: 0.9
12 | giou: 1.5
13 |
14 | redundancy:
15 | mode: mm
16 | det_score_threshold:
17 | iou: 0.1
18 | giou: 0.1
19 | det_dist_threshold:
20 | iou: 0.1
21 | giou: -0.5
22 |
23 | data_loader:
24 | pc: true
25 | nms: true
26 | nms_thres: 0.25
--------------------------------------------------------------------------------
/configs/waymo_configs/vc_kf_giou.yaml:
--------------------------------------------------------------------------------
1 | running:
2 | covariance: default
3 | score_threshold: 0.7
4 | max_age_since_update: 2
5 | min_hits_to_birth: 3
6 | match_type: bipartite
7 | asso: giou
8 | has_velo: false
9 | motion_model: kf
10 | asso_thres:
11 | iou: 0.9
12 | giou: 1.5
13 |
14 | redundancy:
15 | mode: mm
16 | det_score_threshold:
17 | iou: 0.1
18 | giou: 0.1
19 | det_dist_threshold:
20 | iou: 0.1
21 | giou: -0.5
22 |
23 | data_loader:
24 | pc: true
25 | nms: true
26 | nms_thres: 0.25
--------------------------------------------------------------------------------
/data_loader/__init__.py:
--------------------------------------------------------------------------------
1 | from data_loader.waymo_loader import WaymoLoader
2 | from data_loader.nuscenes_loader import NuScenesLoader, NuScenesLoader10Hz
--------------------------------------------------------------------------------
/data_loader/nuscenes_loader.py:
--------------------------------------------------------------------------------
1 | """ Example of data loader:
2 | The data loader has to be an iterator:
3 | Return a dict of frame data
4 | Users may create the logic of your own data loader
5 | """
6 | import os, numpy as np, json
7 | from pyquaternion import Quaternion
8 | from nuscenes.utils.data_classes import Box
9 | from mot_3d.data_protos import BBox
10 | import mot_3d.utils as utils
11 | from mot_3d.preprocessing import nms
12 |
13 |
14 | def transform_matrix(translation: np.ndarray = np.array([0, 0, 0]),
15 | rotation: np.ndarray = np.array([1, 0, 0, 0]),
16 | inverse: bool = False) -> np.ndarray:
17 | tm = np.eye(4)
18 | rotation = Quaternion(rotation)
19 |
20 | if inverse:
21 | rot_inv = rotation.rotation_matrix.T
22 | trans = np.transpose(-np.array(translation))
23 | tm[:3, :3] = rot_inv
24 | tm[:3, 3] = rot_inv.dot(trans)
25 | else:
26 | tm[:3, :3] = rotation.rotation_matrix
27 | tm[:3, 3] = np.transpose(np.array(translation))
28 | return tm
29 |
30 |
31 | def nu_array2mot_bbox(b):
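    # b: [x, y, z, w, l, h, qw, qx, qy, qz, (score)] following the nuScenes box convention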
32 | nu_box = Box(b[:3], b[3:6], Quaternion(b[6:10]))
33 | mot_bbox = BBox(
34 | x=nu_box.center[0], y=nu_box.center[1], z=nu_box.center[2],
35 | w=nu_box.wlh[0], l=nu_box.wlh[1], h=nu_box.wlh[2],
36 | o=nu_box.orientation.yaw_pitch_roll[0]
37 | )
38 | if len(b) == 11:
39 | mot_bbox.s = b[-1]
40 | return mot_bbox
41 |
42 |
43 | class NuScenesLoader:
44 | def __init__(self, configs, type_token, segment_name, data_folder, det_data_folder, start_frame):
45 | """ initialize with the path to data
46 | Args:
47 | data_folder (str): root path to your data
48 | """
49 | self.configs = configs
50 | self.segment = segment_name
51 | self.data_loader = data_folder
52 | self.det_data_folder = det_data_folder
53 | self.type_token = type_token
54 |
55 | self.ts_info = json.load(open(os.path.join(data_folder, 'ts_info', '{:}.json'.format(segment_name)), 'r'))
56 | self.ego_info = np.load(os.path.join(data_folder, 'ego_info', '{:}.npz'.format(segment_name)),
57 | allow_pickle=True)
58 | self.calib_info = np.load(os.path.join(data_folder, 'calib_info', '{:}.npz'.format(segment_name)),
59 | allow_pickle=True)
60 | self.dets = np.load(os.path.join(det_data_folder, 'dets', '{:}.npz'.format(segment_name)),
61 | allow_pickle=True)
62 | self.det_type_filter = True
63 |
64 | self.use_pc = configs['data_loader']['pc']
65 | if self.use_pc:
66 | self.pcs = np.load(os.path.join(data_folder, 'pc', 'raw_pc', '{:}.npz'.format(segment_name)),
67 | allow_pickle=True)
68 |
69 | self.max_frame = len(self.dets['bboxes'])
70 | self.cur_frame = start_frame
71 |
72 | def __iter__(self):
73 | return self
74 |
75 | def __next__(self):
76 | if self.cur_frame >= self.max_frame:
77 | raise StopIteration
78 |
79 | result = dict()
80 | result['time_stamp'] = self.ts_info[self.cur_frame] * 1e-6
81 | ego = self.ego_info[str(self.cur_frame)]
82 | ego_matrix = transform_matrix(ego[:3], ego[3:])
83 | result['ego'] = ego_matrix
84 |
85 | bboxes = self.dets['bboxes'][self.cur_frame]
86 | inst_types = self.dets['types'][self.cur_frame]
87 | frame_bboxes = [bboxes[i] for i in range(len(bboxes)) if inst_types[i] in self.type_token]
88 | result['det_types'] = [inst_types[i] for i in range(len(inst_types)) if inst_types[i] in self.type_token]
89 | result['dets'] = [nu_array2mot_bbox(b) for b in frame_bboxes]
90 | result['aux_info'] = dict()
91 | if 'velos' in list(self.dets.keys()):
92 | cur_velos = self.dets['velos'][self.cur_frame]
93 | result['aux_info']['velos'] = [cur_velos[i] for i in range(len(cur_velos))
94 | if inst_types[i] in self.type_token]
95 | else:
96 | result['aux_info']['velos'] = None
97 |
98 | result['dets'], result['det_types'], result['aux_info']['velos'] = \
99 | self.frame_nms(result['dets'], result['det_types'], result['aux_info']['velos'], 0.1)
100 | result['dets'] = [BBox.bbox2array(d) for d in result['dets']]
101 |
102 | result['pc'] = None
103 | if self.use_pc:
104 | pc = self.pcs[str(self.cur_frame)][:, :3]
105 | calib = self.calib_info[str(self.cur_frame)]
106 | calib_trans, calib_rot = np.asarray(calib[:3]), Quaternion(np.asarray(calib[3:]))
107 | pc = np.dot(pc, calib_rot.rotation_matrix.T)
108 | pc += calib_trans
109 | result['pc'] = utils.pc2world(ego_matrix, pc)
110 |
111 | # if 'velos' in list(self.dets.keys()):
112 | # cur_frame_velos = self.dets['velos'][self.cur_frame]
113 | # result['aux_info']['velos'] = [cur_frame_velos[i]
114 | # for i in range(len(bboxes)) if inst_types[i] in self.type_token]
115 | result['aux_info']['is_key_frame'] = True
116 |
117 | self.cur_frame += 1
118 | return result
119 |
120 | def __len__(self):
121 | return self.max_frame
122 |
123 | def frame_nms(self, dets, det_types, velos, thres):
124 | frame_indexes, frame_types = nms(dets, det_types, thres)
125 | result_dets = [dets[i] for i in frame_indexes]
126 | result_velos = None
127 | if velos is not None:
128 | result_velos = [velos[i] for i in frame_indexes]
129 | return result_dets, frame_types, result_velos
130 |
131 |
132 | class NuScenesLoader10Hz:
133 | def __init__(self, configs, type_token, segment_name, data_folder, det_data_folder, start_frame):
134 | """ initialize with the path to data
135 | Args:
136 | data_folder (str): root path to your data
137 | """
138 | self.configs = configs
139 | self.segment = segment_name
140 | self.data_loader = data_folder
141 | self.det_data_folder = det_data_folder
142 | self.type_token = type_token
143 |
144 | self.ts_info = json.load(open(os.path.join(data_folder, 'ts_info', '{:}.json'.format(segment_name)), 'r'))
145 | self.time_stamps = [t[0] for t in self.ts_info]
146 | self.is_key_frames = [t[1] for t in self.ts_info]
147 |
148 | self.token_info = json.load(open(os.path.join(data_folder, 'token_info', '{:}.json'.format(segment_name)), 'r'))
149 | self.ego_info = np.load(os.path.join(data_folder, 'ego_info', '{:}.npz'.format(segment_name)),
150 | allow_pickle=True)
151 | self.calib_info = np.load(os.path.join(data_folder, 'calib_info', '{:}.npz'.format(segment_name)),
152 | allow_pickle=True)
153 | self.dets = np.load(os.path.join(det_data_folder, 'dets', '{:}.npz'.format(segment_name)),
154 | allow_pickle=True)
155 | self.det_type_filter = True
156 |
157 | self.use_pc = configs['data_loader']['pc']
158 | if self.use_pc:
159 | self.pcs = np.load(os.path.join(data_folder, 'pc', 'raw_pc', '{:}.npz'.format(segment_name)),
160 | allow_pickle=True)
161 |
162 | self.max_frame = len(self.dets['bboxes'])
163 | self.selected_frames = [i for i in range(self.max_frame) if self.token_info[i][3]]
164 | self.cur_selected_index = 0
165 | self.cur_frame = start_frame
166 |
167 | def __iter__(self):
168 | return self
169 |
170 | def __next__(self):
171 | if self.cur_selected_index >= len(self.selected_frames):
172 | raise StopIteration
173 | self.cur_frame = self.selected_frames[self.cur_selected_index]
174 |
175 | result = dict()
176 | result['time_stamp'] = self.time_stamps[self.cur_frame] * 1e-6
177 | ego = self.ego_info[str(self.cur_frame)]
178 | ego_matrix = transform_matrix(ego[:3], ego[3:])
179 | result['ego'] = ego_matrix
180 |
181 | bboxes = self.dets['bboxes'][self.cur_frame]
182 | inst_types = self.dets['types'][self.cur_frame]
183 | frame_bboxes = [bboxes[i] for i in range(len(bboxes)) if inst_types[i] in self.type_token]
184 | result['det_types'] = [inst_types[i] for i in range(len(inst_types)) if inst_types[i] in self.type_token]
185 |
186 | result['dets'] = [nu_array2mot_bbox(b) for b in frame_bboxes]
187 | result['aux_info'] = dict()
188 | if 'velos' in list(self.dets.keys()):
189 | cur_velos = self.dets['velos'][self.cur_frame]
190 | result['aux_info']['velos'] = [cur_velos[i] for i in range(len(cur_velos))
191 | if inst_types[i] in self.type_token]
192 | else:
193 | result['aux_info']['velos'] = None
194 | result['dets'], result['det_types'], result['aux_info']['velos'] = \
195 | self.frame_nms(result['dets'], result['det_types'], result['aux_info']['velos'], 0.1)
196 | result['dets'] = [BBox.bbox2array(d) for d in result['dets']]
197 |
198 | result['pc'] = None
199 | if self.use_pc:
200 | pc = self.pcs[str(self.cur_frame)][:, :3]
201 | calib = self.calib_info[str(self.cur_frame)]
202 | calib_trans, calib_rot = np.asarray(calib[:3]), Quaternion(np.asarray(calib[3:]))
203 | pc = np.dot(pc, calib_rot.rotation_matrix.T)
204 | pc += calib_trans
205 | result['pc'] = utils.pc2world(ego_matrix, pc)
206 |
207 | # if 'velos' in list(self.dets.keys()):
208 | # cur_frame_velos = self.dets['velos'][self.cur_frame]
209 | # result['aux_info']['velos'] = [cur_frame_velos[i]
210 | # for i in range(len(bboxes)) if inst_types[i] in self.type_token]
211 | # print(result['aux_info']['velos'])
212 | result['aux_info']['is_key_frame'] = self.is_key_frames[self.cur_frame]
213 |
214 | self.cur_selected_index += 1
215 | return result
216 |
217 | def __len__(self):
218 | return len(self.selected_frames)
219 |
220 | def frame_nms(self, dets, det_types, velos, thres):
221 | frame_indexes, frame_types = nms(dets, det_types, thres)
222 | result_dets = [dets[i] for i in frame_indexes]
223 | result_velos = None
224 | if velos is not None:
225 | result_velos = [velos[i] for i in frame_indexes]
226 | return result_dets, frame_types, result_velos
227 |
--------------------------------------------------------------------------------
/data_loader/waymo_loader.py:
--------------------------------------------------------------------------------
1 | """ Example of data loader:
2 | The data loader has to be an iterator:
3 | Return a dict of frame data
4 | Users may create the logic of your own data loader
5 | """
6 | import os, numpy as np, json
7 | import mot_3d.utils as utils
8 | from mot_3d.data_protos import BBox
9 | from mot_3d.preprocessing import nms
10 |
11 |
12 | class WaymoLoader:
13 | def __init__(self, configs, type_token, segment_name, data_folder, det_data_folder, start_frame):
14 | """ initialize with the path to data
15 | Args:
16 | data_folder (str): root path to your data
17 | """
18 | self.configs = configs
19 | self.segment = segment_name
20 | self.data_loader = data_folder
21 | self.det_data_folder = det_data_folder
22 | self.type_token = type_token
23 |
24 | self.nms = configs['data_loader']['nms']
25 | self.nms_thres = configs['data_loader']['nms_thres']
26 |
27 | self.ts_info = json.load(open(os.path.join(data_folder, 'ts_info', '{:}.json'.format(segment_name)), 'r'))
28 | self.ego_info = np.load(os.path.join(data_folder, 'ego_info', '{:}.npz'.format(segment_name)),
29 | allow_pickle=True)
30 | self.dets = np.load(os.path.join(det_data_folder, 'dets', '{:}.npz'.format(segment_name)),
31 | allow_pickle=True)
32 | self.det_type_filter = True
33 |
34 | self.use_pc = configs['data_loader']['pc']
35 | if self.use_pc:
36 | self.pcs = np.load(os.path.join(data_folder, 'pc', 'raw_pc', '{:}.npz'.format(segment_name)),
37 | allow_pickle=True)
38 |
39 | self.max_frame = len(self.dets['bboxes'])
40 | self.cur_frame = start_frame
41 |
42 | def __iter__(self):
43 | return self
44 |
45 | def __next__(self):
46 | if self.cur_frame >= self.max_frame:
47 | raise StopIteration
48 |
49 | result = dict()
50 | result['time_stamp'] = self.ts_info[self.cur_frame] * 1e-6
51 | result['ego'] = self.ego_info[str(self.cur_frame)]
52 |
53 | bboxes = self.dets['bboxes'][self.cur_frame]
54 | inst_types = self.dets['types'][self.cur_frame]
55 | selected_dets = [bboxes[i] for i in range(len(bboxes)) if inst_types[i] in self.type_token]
56 | result['det_types'] = [inst_types[i] for i in range(len(bboxes)) if inst_types[i] in self.type_token]
57 | result['dets'] = [BBox.bbox2world(result['ego'], BBox.array2bbox(b))
58 | for b in selected_dets]
59 |
60 | result['pc'] = None
61 | if self.use_pc:
62 | pc = self.pcs[str(self.cur_frame)]
63 | result['pc'] = utils.pc2world(result['ego'], pc)
64 |
65 | result['aux_info'] = {'is_key_frame': True}
66 | if 'velos' in self.dets.keys():
67 | cur_frame_velos = self.dets['velos'][self.cur_frame]
68 | result['aux_info']['velos'] = [np.array(cur_frame_velos[i])
69 | for i in range(len(bboxes)) if inst_types[i] in self.type_token]
70 | result['aux_info']['velos'] = [utils.velo2world(result['ego'], v)
71 | for v in result['aux_info']['velos']]
72 | else:
73 | result['aux_info']['velos'] = None
74 |
75 | if self.nms:
76 | result['dets'], result['det_types'], result['aux_info']['velos'] = \
77 | self.frame_nms(result['dets'], result['det_types'], result['aux_info']['velos'], self.nms_thres)
78 | result['dets'] = [BBox.bbox2array(d) for d in result['dets']]
79 |
80 | self.cur_frame += 1
81 | return result
82 |
83 | def __len__(self):
84 | return self.max_frame
85 |
86 | def frame_nms(self, dets, det_types, velos, thres):
87 | frame_indexes, frame_types = nms(dets, det_types, thres)
88 | result_dets = [dets[i] for i in frame_indexes]
89 | result_velos = None
90 | if velos is not None:
91 | result_velos = [velos[i] for i in frame_indexes]
92 | return result_dets, frame_types, result_velos
93 |
--------------------------------------------------------------------------------
/docs/config.md:
--------------------------------------------------------------------------------
1 | # Configurations
2 |
3 | This documentation provides a brief guide for writing and reading the configuration files. We specify the tracker behaviors with a single `.yaml` file as in the folder of `./configs/`.
4 |
5 | ## Explanation with the Example from Waymo Open Dataset
6 |
7 | Taking the config for Waymo Open Dataset vehicles as an example, we explain `configs/waymo_configs/vc_kf_giou.yaml`.
8 |
9 | We provide a thorough annotation below, but it may be confusing if you delve into it directly, so here are the key points for specifying the model.
10 |
11 | * **Data Preprocessing** is controlled by the `nms` and `nms_thres` arguments in the `data_loader` field. They specify whether to use NMS and the IoU threshold for NMS.
12 |
13 | * **Association Metric** is specified in `running-asso`, where you can currently choose between IoU and GIoU. To specify the threshold for a successful association, we treat `1 - IoU` or `1 - GIoU` as the distance, which must be smaller than the corresponding number in `asso_thres` (see the sketch after this list). If you are interested in L2 distance or Mahalanobis distance, please refer to the `dev` branch. We are still cleaning it up and adding documentation.
14 | * **Motion Model** is by default a Kalman filter. We are adding more options, such as a velocity model. Please refer to the `dev` branch if you wish to explore on your own.
15 | * **Two-stage Association** is specified by the `redundancy` module. `mm` denotes two-stage and `default` denotes the conventional single-stage association.
16 | * NOTE on **nuScenes 10Hz two-stage association**: we are trying to incorporate this part into the config file instead of hard-coding everything. For now, please refer to the function `non_key_frame_mot` in `mot_3d/mot.py` for how we handle tracking on non-key frames.
17 |
18 |
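As a concrete sketch of the association gate described above (the two boxes are hypothetical; `utils.giou3d` and `BBox` come from the `mot_3d` library, and the computation mirrors `compute_iou_distance` in `mot_3d/association.py`):

```python
from mot_3d import utils
from mot_3d.data_protos import BBox

# two hypothetical boxes in the (x, y, z, heading, l, w, h) layout of BBox.array2bbox
det = BBox.array2bbox([0.0, 0.0, 1.0, 0.00, 4.0, 2.0, 1.5])
trk = BBox.array2bbox([0.5, 0.2, 1.0, 0.05, 4.0, 2.0, 1.5])

# association distance: 1 - GIoU, as in compute_iou_distance
dist = 1.0 - utils.giou3d(det, trk)

# the pair stays a match candidate only if the distance is below asso_thres
giou_thres = 1.5  # running-asso_thres-giou in vc_kf_giou.yaml
print('match candidate' if dist <= giou_thres else 'gated out')
```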
19 | ```yaml
20 | running:
21 | covariance: default # not used
22 | score_threshold: 0.7 # detection score threshold for first-stage association and output
23 | max_age_since_update: 2 # count for death, same as AB3DMOT
24 | min_hits_to_birth: 3 # count for birth, same as AB3DMOT
25 | match_type: bipartite # use the Hungarian (bipartite) or greedy (greedy) algorithm for association
26 | asso: giou # association metric, we support GIoU (giou) and IoU (iou) for now
27 | has_velo: false
28 | motion_model: kf # Kalman filter (kf) as motion model
29 | asso_thres:
30 | iou: 0.9 # association threshold, 1 - IoU has to be smaller than it.
31 | giou: 1.5 # association threshold, 1 - GIoU has to be smaller than it.
32 |
33 | redundancy:
34 | mode: mm # (mm) denotes two-stage association, (default) denotes one-stage association (see the nuScenes configs)
35 | det_score_threshold: # detection score threshold for two-stage association
36 | iou: 0.1
37 | giou: 0.1
38 | det_dist_threshold: # association threshold for two-stage association
39 | iou: 0.1 # IoU has to be greater than this
40 | giou: -0.5 # GIoU has to be greater than this
41 |
42 | data_loader:
43 | pc: true # load point clouds for visualization
44 | nms: true # apply NMS for data preprocessing, True or False
45 | nms_thres: 0.25 # IoU-3D threshold for NMS
46 | ```
--------------------------------------------------------------------------------
/docs/data_preprocess.md:
--------------------------------------------------------------------------------
1 | # Data Preprocessing
2 |
3 | ## Waymo Open Dataset
4 |
5 | We decompose the preprocessing of Waymo Open Dataset into the following steps.
6 |
7 | ### 1. Raw Data
8 |
9 | This step converts the information needed from `tf_records` into more handy forms. Suppose the folder storing the `tf_records` is `raw_data_folder` and the target location is `data_folder`; then run the following command. (You can set `process_num` to an integer greater than 1 for faster preprocessing.)
10 |
11 | ```bash
12 | cd preprocessing/waymo_data
13 | bash waymo_preprocess.sh ${raw_data_folder} ${data_folder} ${process_num}
14 | ```
15 |
16 | ### 2. Ground Truth Information
17 |
18 | The ground truth for 3D MOT and 3D detection is the same. **You have to download a .bin file from Waymo Open Dataset for the ground truth, which we have no right to share according to the license.**
19 |
20 | To decode the ground truth information, suppose `bin_path` is the path to the ground truth file and `data_dir` is the target location of the preprocessed data. Eventually, we store the ground truth information in `${data_dir}/detection/gt/dets/`.
21 |
22 | ```bash
23 | cd preprocessing/waymo_data
24 | python gt_bin_decode.py --data_folder ${data_dir} --file_path ${bin_path}
25 | ```
26 |
27 | ### 3. Detection
28 |
29 | To infer 3D MOT on your detection file, we again need `bin_path`, now indicating the path to the detection results; please name your detection `name` for future convenience. Preprocess the detection with the script below. (Only use `--metadata` if you want to save the velocity / acceleration contained in the detection file.)
30 |
31 | ```bash
32 | cd preprocessing/waymo_data
33 | python detection.py --name ${name} --data_folder ${data_dir} --file_path ${bin_path} --metadata --id
34 | ```
35 |
36 | ## nuScenes
37 |
38 | ### 1. Preprocessing
39 |
40 | To preprocess the raw data from nuScenes, suppose you have put the raw data of nuScenes at `raw_data_dir`. We provide two modes of preprocessing:
41 | * Only the data on the key frames (2Hz) is extracted; the target location is `data_dir_2hz`.
42 | * All the data (20Hz) is extracted; the target location is `data_dir_20hz`.
43 |
44 | Run the following commands.
45 |
46 | ```bash
47 | cd preprocessing/nuscenes_data
48 | bash nuscenes_preprocess.sh ${raw_data_dir} ${data_dir_2hz} ${data_dir_20hz}
49 | ```
50 |
51 | ### 2. Detection
52 |
53 | To infer 3D MOT on your detection file, we convert the json-format detection files at `file_path` into .npz files, similar to our approach on Waymo Open Dataset. Please name your detection `name` for future convenience. Preprocess the detection with the scripts below. (Only use `--velo` if you want to save the velocity contained in the detection file.)
54 |
55 | ```bash
56 | cd preprocessing/nuscenes_data
57 |
58 | # for 2Hz detection file
59 | python detection.py --raw_data_folder ${raw_data_dir} --data_folder ${data_dir_2hz} --det_name ${name} --file_path ${file_path} --mode 2hz --velo
60 |
61 | # for 20Hz detection file
62 | python detection.py --raw_data_folder ${raw_data_dir} --data_folder ${data_dir_20hz} --det_name ${name} --file_path ${file_path} --mode 20hz --velo
63 | ```
64 |
--------------------------------------------------------------------------------
/docs/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tusen-ai/SimpleTrack/05c96bb7ed98fc179856f327544612a66c839b5e/docs/demo.png
--------------------------------------------------------------------------------
/docs/nuScenes.md:
--------------------------------------------------------------------------------
1 | # nuScenes Inference
2 |
3 | ## Data Preprocessing
4 |
5 | Please follow the command at [Data Preprocessing Documentations](./data_preprocess.md). After data preprocessing, the directories to the preprocessed data are `nuscenes2hz_data_dir` and `nuscenes20hz_data_dir` for the key-frame and all-frame data of nuScenes.
6 |
7 | ## Inference with SimpleTrack
8 |
9 | **Important: Please strictly follow the config file we use.**
10 |
11 | For the setting that runs inference only on key frames (2Hz), use the following command. The per-sequence results are then saved in `${nuscenes_result_dir}/SimpleTrack2Hz/summary`, with one subfolder per object type containing that type's results.
12 |
13 | ```bash
14 | python tools/main_nuscenes.py \
15 | --name SimpleTrack2Hz \
16 | --det_name ${det_name} \
17 | --config_path configs/nu_configs/giou.yaml \
18 | --result_folder ${nuscenes_result_dir} \
19 | --data_folder ${nuscenes2hz_data_dir} \
20 | --process ${proc_num}
21 | ```
22 |
23 | For the 10Hz setting proposed in the paper, run the following commands. The per-sequence results are then saved in `${nuscenes_result_dir}/SimpleTrack10Hz/summary`.
24 |
25 | ```bash
26 | python tools/main_nuscenes_10hz.py \
27 | --name SimpleTrack10Hz \
28 | --det_name ${det_name} \
29 | --config_path configs/nu_configs/giou.yaml \
30 | --result_folder ${nuscenes_result_dir} \
31 | --data_folder ${nuscenes20hz_data_dir} \
32 | --process ${proc_num}
33 | ```
34 |
35 | I use a process number of 150 in my experiments, which equals the number of sequences in the nuScenes validation set.
36 |
37 | ## Output Format
38 |
39 | In the folder of `${nuscenes_result_dir}/SimpleTrack2Hz/summary/` (and likewise for `SimpleTrack10Hz`), there are sub-folders corresponding to each object type in nuScenes. Inside each sub-folder, there are 150 `.npz` files, matching the 150 sequences in the nuScenes validation set. For the format of each `.npz` file, please refer to [Output Format](output_format.md).
40 |
41 | ## Converting to nuScenes Format
42 |
43 | Use the following commands to convert the output results in the SimpleTrack format into the `.json` format specified by nuScenes. After running them, the tracking results in `.json` format will be in `${nuscenes_result_dir}/SimpleTrack2Hz/results` and `${nuscenes_result_dir}/SimpleTrack10Hz/results`.
44 |
45 | For the 2Hz setting, which runs inference only on the key frames, run the following commands.
46 |
47 | ```bash
48 | python tools/nuscenes_result_creation.py \
49 | --name SimpleTrack2Hz \
50 | --result_folder ${nuscenes_result_dir} \
51 | --data_folder ${nuscenes2hz_data_dir}
52 |
53 | python tools/nuscenes_type_merge.py \
54 | --name SimpleTrack2Hz \
55 | --result_folder ${nuscenes_result_dir}
56 | ```
57 |
58 | For the setting of 10Hz, run the following commands.
59 |
60 | ```bash
61 | python tools/nuscenes_result_creation_10hz.py \
62 | --name SimpleTrack10Hz \
63 | --result_folder ${nuscenes_result_dir} \
64 | --data_folder ${nuscenes20hz_data_dir}
65 |
66 | python tools/nuscenes_type_merge.py \
67 | --name SimpleTrack10Hz \
68 | --result_folder ${nuscenes_result_dir}
69 | ```
70 |
71 | ## Files and Detailed Metrics
72 |
73 | Please see [Dropbox Link](https://www.dropbox.com/sh/8906exnes0u5e89/AAD0xLwW1nq_QiuUBaYDrQVna?dl=0).
--------------------------------------------------------------------------------
/docs/output_format.md:
--------------------------------------------------------------------------------
1 | # Output Format
--------------------------------------------------------------------------------
/docs/simpletrack_gif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tusen-ai/SimpleTrack/05c96bb7ed98fc179856f327544612a66c839b5e/docs/simpletrack_gif.gif
--------------------------------------------------------------------------------
/docs/waymo.md:
--------------------------------------------------------------------------------
1 | # Waymo Open Dataset Inference
2 |
3 | This documentation describes the steps to infer on the validation set of Waymo Open Dataset.
4 |
5 | ## Data Preprocessing
6 |
7 | We provide instructions for both Waymo Open Dataset and nuScenes; please follow the commands in the [Data Preprocessing Documentation](data_preprocess.md). After data preprocessing, the directory of the preprocessed data is `waymo_data_dir`, and the name of the detection is `det_name`.
8 |
9 | ## Inference with SimpleTrack
10 |
11 | To run a standard SimpleTrack experiment, use the following commands. The per-sequence results are then saved in `${waymo_result_dir}/SimpleTrack/summary`, with subfolders named `vehicle/`, `pedestrian/`, and `cyclist/`.
12 |
13 | In the experiments:
14 | * `${det_name}` denotes the name of the preprocessed detection;
15 | * `${waymo_result_dir}` is for saving tracking results;
16 | * `${waymo_data_dir}` is the folder of the preprocessed Waymo data;
17 | * `${proc_num}` sets the number of processes used to run inference on different sequences in parallel (for your information, I generally use 202 processes, so that each process runs one sequence of the validation set);
18 | * **Please look out for the different config files we use for each type of object.**
19 |
20 | ```bash
21 | # for vehicle
22 | python tools/main_waymo.py \
23 | --name SimpleTrack \
24 | --det_name ${det_name} \
25 | --obj_type vehicle \
26 | --config_path configs/waymo_configs/vc_kf_giou.yaml \
27 | --data_folder ${waymo_data_dir} \
28 | --result_folder ${waymo_result_dir} \
29 | --gt_folder ${gt_dets_dir} \
30 | --process ${proc_num}
31 |
32 | # for pedestrian
33 | python tools/main_waymo.py \
34 | --name SimpleTrack \
35 | --det_name ${det_name} \
36 | --obj_type pedestrian \
37 | --config_path configs/waymo_configs/pd_kf_giou.yaml \
38 | --data_folder ${waymo_data_dir} \
39 | --result_folder ${waymo_result_dir} \
40 | --process ${proc_num}
41 |
42 | # for cyclist
43 | python tools/main_waymo.py \
44 | --name SimpleTrack \
45 | --det_name ${det_name} \
46 | --obj_type cyclist \
47 | --config_path configs/waymo_configs/vc_kf_giou.yaml \
48 | --data_folder ${waymo_data_dir} \
49 | --result_folder ${waymo_result_dir} \
50 | --process ${proc_num}
51 | ```
52 |
53 | ## Output Format
54 |
55 | In the folder of `${waymo_result_dir}/SimpleTrack/summary/`, there are three sub-folders `vehicle`, `pedestrian`, and `cyclist`. Each of them contains 202 `.npz` files, corresponding to the sequences in the validation set of Waymo Open Dataset. For the format in each `.npz` file, please refer to [Output Format](output_format.md).
56 |
57 | ## Converting to Waymo Open Dataset Format
58 |
59 | To work with the official formats of Waymo Open Dataset, the following commands convert the results in SimpleTrack format to the `.bin` format.
60 |
61 | After running the command, we will have four `pred.bin` files under `${waymo_result_dir}/SimpleTrack/bin/`: `pred.bin` (all objects), `vehicle/pred.bin` (vehicles only), `pedestrian/pred.bin` (pedestrians only), and `cyclist/pred.bin` (cyclists only).
62 |
63 | ```bash
64 | python tools/waymo_pred_bin.py \
65 | --name SimpleTrack \
66 | --result_folder ${waymo_result_dir} \
67 | --data_folder ${waymo_data_dir}
68 | ```
69 |
70 | Eventually, use the official evaluation provided by Waymo by following the [Quick Guide to Waymo Open Dataset](https://github.com/waymo-research/waymo-open-dataset/blob/master/docs/quick_start.md).
71 |
72 | ## Files and Detailed Metrics
73 |
74 | Please see [Dropbox link](https://www.dropbox.com/sh/u6o8dcwmzya04uk/AAAUsNvTt7ubXul9dx5Xnp4xa?dl=0).
75 |
--------------------------------------------------------------------------------
/mot_3d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tusen-ai/SimpleTrack/05c96bb7ed98fc179856f327544612a66c839b5e/mot_3d/__init__.py
--------------------------------------------------------------------------------
/mot_3d/association.py:
--------------------------------------------------------------------------------
1 | import numpy as np, mot_3d.tracklet as tracklet
2 | from . import utils
3 | from scipy.optimize import linear_sum_assignment
4 | from .frame_data import FrameData
5 | from .update_info_data import UpdateInfoData
6 | from .data_protos import BBox, Validity
7 |
8 |
9 | def associate_dets_to_tracks(dets, tracks, mode, asso,
10 | dist_threshold=0.9, trk_innovation_matrix=None):
11 | """ associate the tracks with detections
12 | """
13 | if mode == 'bipartite':
14 | matched_indices, dist_matrix = \
15 | bipartite_matcher(dets, tracks, asso, dist_threshold, trk_innovation_matrix)
16 | elif mode == 'greedy':
17 | matched_indices, dist_matrix = \
18 | greedy_matcher(dets, tracks, asso, dist_threshold, trk_innovation_matrix)
19 | unmatched_dets = list()
20 | for d, det in enumerate(dets):
21 | if d not in matched_indices[:, 0]:
22 | unmatched_dets.append(d)
23 |
24 | unmatched_tracks = list()
25 | for t, trk in enumerate(tracks):
26 | if t not in matched_indices[:, 1]:
27 | unmatched_tracks.append(t)
28 |
29 | matches = list()
30 | for m in matched_indices:
31 | if dist_matrix[m[0], m[1]] > dist_threshold:
32 | unmatched_dets.append(m[0])
33 | unmatched_tracks.append(m[1])
34 | else:
35 | matches.append(m.reshape(2))
36 | return matches, np.array(unmatched_dets), np.array(unmatched_tracks)
37 |
38 |
39 | def bipartite_matcher(dets, tracks, asso, dist_threshold, trk_innovation_matrix):
40 | if asso == 'iou':
41 | dist_matrix = compute_iou_distance(dets, tracks, asso)
42 | elif asso == 'giou':
43 | dist_matrix = compute_iou_distance(dets, tracks, asso)
44 | elif asso == 'm_dis':
45 | dist_matrix = compute_m_distance(dets, tracks, trk_innovation_matrix)
46 | elif asso == 'euler':
47 | dist_matrix = compute_m_distance(dets, tracks, None)
48 | row_ind, col_ind = linear_sum_assignment(dist_matrix)
49 | matched_indices = np.stack([row_ind, col_ind], axis=1)
50 | return matched_indices, dist_matrix
51 |
52 |
53 | def greedy_matcher(dets, tracks, asso, dist_threshold, trk_innovation_matrix):
54 | """ it's ok to use iou in bipartite
55 | but greedy is only for m_distance
56 | """
57 | matched_indices = list()
58 |
59 | # compute the distance matrix
60 | if asso == 'm_dis':
61 | distance_matrix = compute_m_distance(dets, tracks, trk_innovation_matrix)
62 | elif asso == 'euler':
63 | distance_matrix = compute_m_distance(dets, tracks, None)
64 | elif asso == 'iou':
65 | distance_matrix = compute_iou_distance(dets, tracks, asso)
66 | elif asso == 'giou':
67 | distance_matrix = compute_iou_distance(dets, tracks, asso)
68 | num_dets, num_trks = distance_matrix.shape
69 |
70 | # association in the greedy manner
71 | # refer to https://github.com/eddyhkchiu/mahalanobis_3d_multi_object_tracking/blob/master/main.py
72 | distance_1d = distance_matrix.reshape(-1)
73 | index_1d = np.argsort(distance_1d)
74 | index_2d = np.stack([index_1d // num_trks, index_1d % num_trks], axis=1)
75 | detection_id_matches_to_tracking_id = [-1] * num_dets
76 | tracking_id_matches_to_detection_id = [-1] * num_trks
77 | for sort_i in range(index_2d.shape[0]):
78 | detection_id = int(index_2d[sort_i][0])
79 | tracking_id = int(index_2d[sort_i][1])
80 | if tracking_id_matches_to_detection_id[tracking_id] == -1 and detection_id_matches_to_tracking_id[detection_id] == -1:
81 | tracking_id_matches_to_detection_id[tracking_id] = detection_id
82 | detection_id_matches_to_tracking_id[detection_id] = tracking_id
83 | matched_indices.append([detection_id, tracking_id])
84 | if len(matched_indices) == 0:
85 | matched_indices = np.empty((0, 2))
86 | else:
87 | matched_indices = np.asarray(matched_indices)
88 | return matched_indices, distance_matrix
89 |
90 |
91 | def compute_m_distance(dets, tracks, trk_innovation_matrix):
92 | """ compute l2 or mahalanobis distance
93 | when the input trk_innovation_matrix is None, compute L2 distance (euler)
94 | else compute mahalanobis distance
95 | return dist_matrix: numpy array [len(dets), len(tracks)]
96 | """
97 | euler_dis = (trk_innovation_matrix is None) # use Euclidean (euler) distance when no innovation matrix is given
98 | if not euler_dis:
99 | trk_inv_inn_matrices = [np.linalg.inv(m) for m in trk_innovation_matrix]
100 | dist_matrix = np.empty((len(dets), len(tracks)))
101 |
102 | for i, det in enumerate(dets):
103 | for j, trk in enumerate(tracks):
104 | if euler_dis:
105 | dist_matrix[i, j] = utils.m_distance(det, trk)
106 | else:
107 | dist_matrix[i, j] = utils.m_distance(det, trk, trk_inv_inn_matrices[j])
108 | return dist_matrix
109 |
110 |
111 | def compute_iou_distance(dets, tracks, asso='iou'):
112 | iou_matrix = np.zeros((len(dets), len(tracks)))
113 | for d, det in enumerate(dets):
114 | for t, trk in enumerate(tracks):
115 | if asso == 'iou':
116 | iou_matrix[d, t] = utils.iou3d(det, trk)[1]
117 | elif asso == 'giou':
118 | iou_matrix[d, t] = utils.giou3d(det, trk)
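    # convert similarity to distance; the asso_thres values in the configs gate this 1 - IoU / 1 - GIoU quantity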
119 | dist_matrix = 1 - iou_matrix
120 | return dist_matrix
121 |
--------------------------------------------------------------------------------
/mot_3d/data_protos/__init__.py:
--------------------------------------------------------------------------------
1 | from mot_3d.data_protos.bbox import BBox
2 | from mot_3d.data_protos.validity import Validity
--------------------------------------------------------------------------------
/mot_3d/data_protos/bbox.py:
--------------------------------------------------------------------------------
1 | """ The defination and basic methods of bbox
2 | """
3 | import numpy as np
4 | from copy import deepcopy
5 |
6 |
7 | class BBox:
8 | def __init__(self, x=None, y=None, z=None, h=None, w=None, l=None, o=None):
9 | self.x = x # center x
10 | self.y = y # center y
11 | self.z = z # center z
12 | self.h = h # height
13 | self.w = w # width
14 | self.l = l # length
15 | self.o = o # orientation
16 | self.s = None # detection score
17 |
18 | def __str__(self):
19 | return 'x: {}, y: {}, z: {}, heading: {}, length: {}, width: {}, height: {}, score: {}'.format(
20 | self.x, self.y, self.z, self.o, self.l, self.w, self.h, self.s)
21 |
22 | @classmethod
23 | def bbox2dict(cls, bbox):
24 | return {
25 | 'center_x': bbox.x, 'center_y': bbox.y, 'center_z': bbox.z,
26 | 'height': bbox.h, 'width': bbox.w, 'length': bbox.l, 'heading': bbox.o}
27 |
28 | @classmethod
29 | def bbox2array(cls, bbox):
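        # array layout: [x, y, z, o (heading), l, w, h], plus the detection score s when available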
30 | if bbox.s is None:
31 | return np.array([bbox.x, bbox.y, bbox.z, bbox.o, bbox.l, bbox.w, bbox.h])
32 | else:
33 | return np.array([bbox.x, bbox.y, bbox.z, bbox.o, bbox.l, bbox.w, bbox.h, bbox.s])
34 |
35 | @classmethod
36 | def array2bbox(cls, data):
37 | bbox = BBox()
38 | bbox.x, bbox.y, bbox.z, bbox.o, bbox.l, bbox.w, bbox.h = data[:7]
39 | if len(data) == 8:
40 | bbox.s = data[-1]
41 | return bbox
42 |
43 | @classmethod
44 | def dict2bbox(cls, data):
45 | bbox = BBox()
46 | bbox.x = data['center_x']
47 | bbox.y = data['center_y']
48 | bbox.z = data['center_z']
49 | bbox.h = data['height']
50 | bbox.w = data['width']
51 | bbox.l = data['length']
52 | bbox.o = data['heading']
53 | if 'score' in data.keys():
54 | bbox.s = data['score']
55 | return bbox
56 |
57 | @classmethod
58 | def copy_bbox(cls, bboxa, bboxb):
59 | bboxa.x = bboxb.x
60 | bboxa.y = bboxb.y
61 | bboxa.z = bboxb.z
62 | bboxa.l = bboxb.l
63 | bboxa.w = bboxb.w
64 | bboxa.h = bboxb.h
65 | bboxa.o = bboxb.o
66 | bboxa.s = bboxb.s
67 | return
68 |
69 | @classmethod
70 | def box2corners2d(cls, bbox):
71 | """ the coordinates for bottom corners
72 | """
73 | bottom_center = np.array([bbox.x, bbox.y, bbox.z - bbox.h / 2])
74 | cos, sin = np.cos(bbox.o), np.sin(bbox.o)
75 | pc0 = np.array([bbox.x + cos * bbox.l / 2 + sin * bbox.w / 2,
76 | bbox.y + sin * bbox.l / 2 - cos * bbox.w / 2,
77 | bbox.z - bbox.h / 2])
78 | pc1 = np.array([bbox.x + cos * bbox.l / 2 - sin * bbox.w / 2,
79 | bbox.y + sin * bbox.l / 2 + cos * bbox.w / 2,
80 | bbox.z - bbox.h / 2])
81 | pc2 = 2 * bottom_center - pc0
82 | pc3 = 2 * bottom_center - pc1
83 |
84 | return [pc0.tolist(), pc1.tolist(), pc2.tolist(), pc3.tolist()]
85 |
86 | @classmethod
87 | def box2corners3d(cls, bbox):
88 | """ the coordinates for bottom corners
89 | """
90 | center = np.array([bbox.x, bbox.y, bbox.z])
91 | bottom_corners = np.array(BBox.box2corners2d(bbox))
92 | up_corners = 2 * center - bottom_corners
93 | corners = np.concatenate([up_corners, bottom_corners], axis=0)
94 | return corners.tolist()
95 |
96 | @classmethod
97 | def motion2bbox(cls, bbox, motion):
98 | result = deepcopy(bbox)
99 | result.x += motion[0]
100 | result.y += motion[1]
101 | result.z += motion[2]
102 | result.o += motion[3]
103 | return result
104 |
105 | @classmethod
106 | def set_bbox_size(cls, bbox, size_array):
107 | result = deepcopy(bbox)
108 | result.l, result.w, result.h = size_array
109 | return result
110 |
111 | @classmethod
112 | def set_bbox_with_states(cls, prev_bbox, state_array):
113 | prev_array = BBox.bbox2array(prev_bbox)
114 | prev_array[:4] += state_array[:4]
115 | prev_array[4:] = state_array[4:]
116 | bbox = BBox.array2bbox(prev_array)
117 | return bbox
118 |
119 | @classmethod
120 | def box_pts2world(cls, ego_matrix, pcs):
121 | new_pcs = np.concatenate((pcs,
122 | np.ones(pcs.shape[0])[:, np.newaxis]),
123 | axis=1)
124 | new_pcs = ego_matrix @ new_pcs.T
125 | new_pcs = new_pcs.T[:, :3]
126 | return new_pcs
127 |
128 | @classmethod
129 | def edge2yaw(cls, center, edge):
130 | vec = edge - center
131 | yaw = np.arccos(vec[0] / np.linalg.norm(vec))
132 | if vec[1] < 0:
133 | yaw = -yaw
134 | return yaw
135 |
136 | @classmethod
137 | def bbox2world(cls, ego_matrix, box):
138 | # center and corners
139 | corners = np.array(BBox.box2corners2d(box))
140 | center = BBox.bbox2array(box)[:3][np.newaxis, :]
141 | center = BBox.box_pts2world(ego_matrix, center)[0]
142 | corners = BBox.box_pts2world(ego_matrix, corners)
143 | # heading
144 | edge_mid_point = (corners[0] + corners[1]) / 2
145 | yaw = BBox.edge2yaw(center[:2], edge_mid_point[:2])
146 |
147 | result = deepcopy(box)
148 | result.x, result.y, result.z = center
149 | result.o = yaw
150 | return result
--------------------------------------------------------------------------------
/mot_3d/data_protos/validity.py:
--------------------------------------------------------------------------------
1 | class Validity:
2 | TYPES = ['birth', 'alive', 'death']
3 | def __init__(self):
4 | return
5 |
6 | @classmethod
7 | def valid(cls, state_string):
8 | tokens = state_string.split('_')
9 | if tokens[0] == 'birth':
10 | return True
11 | if len(tokens) < 3:
12 | return False
13 | if tokens[0] == 'alive' and int(tokens[1]) == 1:
14 | return True
15 | return False
16 |
17 | @classmethod
18 | def notoutput(cls, state_string):
19 | tokens = state_string.split('_')
20 | if len(tokens) < 3:
21 | return False
22 | if tokens[0] == 'alive' and int(tokens[1]) != 1:
23 | return True
24 | return False
25 |
26 | @classmethod
27 | def predicted(cls, state_string):
28 | state, token = state_string.split('_')[:2]  # 'alive' strings carry three tokens; unpack only the first two to avoid a ValueError
29 | if state not in Validity.TYPES:
30 | raise ValueError('type name not existed')
31 |
32 | if state == 'alive' and int(token) != 0:
33 | return True
34 | return False
35 |
36 | @classmethod
37 | def modify_string(cls, state_string, mode):
38 | tokens = state_string.split('_')
39 | tokens[1] = str(mode)
40 | return '{:}_{:}_{:}'.format(tokens[0], tokens[1], tokens[2])
--------------------------------------------------------------------------------
/mot_3d/frame_data.py:
--------------------------------------------------------------------------------
1 | """ input form of the data in each frame
2 | """
3 | from .data_protos import BBox
4 | import numpy as np, mot_3d.utils as utils
5 |
6 |
7 | class FrameData:
8 | def __init__(self, dets, ego, time_stamp=None, pc=None, det_types=None, aux_info=None):
9 | self.dets = dets # detections for each frame
10 | self.ego = ego # ego matrix information
11 | self.pc = pc
12 | self.det_types = det_types
13 | self.time_stamp = time_stamp
14 | self.aux_info = aux_info
15 |
16 | for i, det in enumerate(self.dets):
17 | self.dets[i] = BBox.array2bbox(det)
18 |
19 | # if not aux_info['is_key_frame']:
20 | # self.dets = [d for d in self.dets if d.s >= 0.5]
--------------------------------------------------------------------------------
/mot_3d/life/__init__.py:
--------------------------------------------------------------------------------
1 | from ..life.hit_manager import HitManager
--------------------------------------------------------------------------------
/mot_3d/life/hit_manager.py:
--------------------------------------------------------------------------------
1 | """ a finite state machine to manage the life cycle
2 | states:
3 | - birth: first founded
4 | - alive: alive
5 | - no_asso: without high score association, about to die
6 | - dead: may it eternal peace
7 | """
8 | import numpy as np
9 | from ..data_protos import Validity
10 | from ..update_info_data import UpdateInfoData
11 | from .. import utils
12 |
13 |
14 | class HitManager:
15 | def __init__(self, configs, frame_index):
16 | self.time_since_update = 0
17 | self.hits = 1 # number of total hits including the first detection
18 | self.hit_streak = 1 # number of continuing hits, counting the first detection
19 | self.first_continuing_hit = 1
20 | self.still_first = True
21 | self.age = 0
22 | self.recent_state = None
23 |
24 | self.max_age = configs['running']['max_age_since_update']
25 | self.min_hits = configs['running']['min_hits_to_birth']
26 |
27 | self.state = 'birth'
28 | self.recent_state = 1
29 | self.no_asso = False
30 | if frame_index <= self.min_hits or self.min_hits == 0:
31 | self.state = 'alive'
32 | self.recent_state = 1
33 |
34 | def predict(self, is_key_frame):
35 | # only on key frame
36 | # unassociated prediction affects the life-cycle management
37 | if not is_key_frame:
38 | return
39 |
40 | self.age += 1
41 | if self.time_since_update > 0:
42 | self.hit_streak = 0
43 | self.still_first = False
44 | self.time_since_update += 1
45 | self.fall = True
46 | return
47 |
48 | def if_valid(self, update_info):
49 | self.recent_state = update_info.mode
50 | return update_info.mode
51 |
52 | def update(self, update_info: UpdateInfoData, is_key_frame=True):
53 | # the update happening during the non-key-frame
54 | # can extend the life of tracklet
55 | association = self.if_valid(update_info)
56 | self.recent_state = association
57 | if association != 0:
58 | self.fall = False
59 | self.time_since_update = 0
60 | self.history = []
61 | self.hits += 1
62 | self.hit_streak += 1 # number of continuing hit
63 | if self.still_first:
64 | self.first_continuing_hit += 1 # number of continuing hits counted from the first detection
65 | if is_key_frame:
66 | self.state_transition(association, update_info.frame_index)
67 |
68 | def state_transition(self, mode, frame_index):
69 | # if just born
70 | if self.state == 'birth':
71 | if (self.hits >= self.min_hits) or (frame_index <= self.min_hits):
72 | self.state = 'alive'
73 | self.recent_state = mode
74 | elif self.time_since_update >= self.max_age:
75 | self.state = 'dead'
76 | # already alive
77 | elif self.state == 'alive':
78 | if self.time_since_update >= self.max_age:
79 | self.state = 'dead'
80 |
81 | def alive(self, frame_index):
82 | return self.state == 'alive'
83 |
84 | def death(self, frame_index):
85 | return self.state == 'dead'
86 |
87 | def valid_output(self, frame_index):
88 | return (self.state == 'alive') and (self.no_asso == False)
89 |
90 | def state_string(self, frame_index):
91 |         """ Each tracklet uses a state string to represent its state.
92 | This string is used for determining output, etc.
93 | """
94 | if self.state == 'birth':
95 | return '{:}_{:}'.format(self.state, self.hits)
96 | elif self.state == 'alive':
97 | return '{:}_{:}_{:}'.format(self.state, self.recent_state, self.time_since_update)
98 | elif self.state == 'dead':
99 | return '{:}_{:}'.format(self.state, self.time_since_update)
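100 | 
101 | 
102 | # A minimal usage sketch (illustrative; the config values are assumptions, see
103 | # configs/ for the real ones). It walks the machine from 'birth' to 'alive'.
104 | # Run as `python -m mot_3d.life.hit_manager` so the relative imports resolve.
105 | if __name__ == '__main__':
106 |     configs = {'running': {'max_age_since_update': 2, 'min_hits_to_birth': 3}}
107 |     manager = HitManager(configs, frame_index=10)
108 |     print(manager.state)            # 'birth': frame 10 is past the early-bird window
109 |     for frame in range(11, 14):     # three consecutive associated key frames
110 |         manager.predict(is_key_frame=True)
111 |         info = UpdateInfoData(mode=1, bbox=None, frame_index=frame, ego=None)
112 |         manager.update(info, is_key_frame=True)
113 |     print(manager.state)            # 'alive' once hits >= min_hits_to_birth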
--------------------------------------------------------------------------------
/mot_3d/mot.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import numpy as np, mot_3d.tracklet as tracklet, mot_3d.utils as utils
3 | from .redundancy import RedundancyModule
4 | from scipy.optimize import linear_sum_assignment
5 | from .frame_data import FrameData
6 | from .update_info_data import UpdateInfoData
7 | from .data_protos import BBox, Validity
8 | from .association import associate_dets_to_tracks
9 | from . import visualization
10 | from mot_3d import redundancy
11 | import pdb, os
12 |
13 |
14 | class MOTModel:
15 | def __init__(self, configs):
16 | self.trackers = list() # tracker for each single tracklet
17 | self.frame_count = 0 # record for the frames
18 | self.count = 0 # record the obj number to assign ids
19 | self.time_stamp = None # the previous time stamp
20 |         self.redundancy = RedundancyModule(configs) # fallback module when a track has no matched detection
21 |
22 | non_key_redundancy_config = deepcopy(configs)
23 | non_key_redundancy_config['redundancy'] = {
24 | 'mode': 'mm',
25 | 'det_score_threshold': {'giou': 0.1, 'iou': 0.1, 'euler': 0.1},
26 | 'det_dist_threshold': {'giou': -0.5, 'iou': 0.1, 'euler': 4}
27 | }
28 | self.non_key_redundancy = RedundancyModule(non_key_redundancy_config)
29 |
30 | self.configs = configs
31 | self.match_type = configs['running']['match_type']
32 | self.score_threshold = configs['running']['score_threshold']
33 | self.asso = configs['running']['asso']
34 | self.asso_thres = configs['running']['asso_thres'][self.asso]
35 | self.motion_model = configs['running']['motion_model']
36 |
37 | self.max_age = configs['running']['max_age_since_update']
38 | self.min_hits = configs['running']['min_hits_to_birth']
39 |
40 | @property
41 | def has_velo(self):
42 | return not (self.motion_model == 'kf' or self.motion_model == 'fbkf' or self.motion_model == 'ma')
43 |
44 | def frame_mot(self, input_data: FrameData):
45 | """ For each frame input, generate the latest mot results
46 | Args:
47 | input_data (FrameData): input data, including detection bboxes and ego information
48 | Returns:
49 |             tracks on this frame: [(bbox0, id0, state_string0, type0), ...]
50 | """
51 | self.frame_count += 1
52 |
53 | # initialize the time stamp on frame 0
54 | if self.time_stamp is None:
55 | self.time_stamp = input_data.time_stamp
56 |
57 | if not input_data.aux_info['is_key_frame']:
58 | result = self.non_key_frame_mot(input_data)
59 | return result
60 |
61 | if 'kf' in self.motion_model:
62 | matched, unmatched_dets, unmatched_trks = self.forward_step_trk(input_data)
63 |
64 | time_lag = input_data.time_stamp - self.time_stamp
65 | # update the matched tracks
66 | for t, trk in enumerate(self.trackers):
67 | if t not in unmatched_trks:
68 | for k in range(len(matched)):
69 | if matched[k][1] == t:
70 | d = matched[k][0]
71 | break
72 | if self.has_velo:
73 | aux_info = {
74 | 'velo': list(input_data.aux_info['velos'][d]),
75 | 'is_key_frame': input_data.aux_info['is_key_frame']}
76 | else:
77 | aux_info = {'is_key_frame': input_data.aux_info['is_key_frame']}
78 | update_info = UpdateInfoData(mode=1, bbox=input_data.dets[d], ego=input_data.ego,
79 | frame_index=self.frame_count, pc=input_data.pc,
80 | dets=input_data.dets, aux_info=aux_info)
81 | trk.update(update_info)
82 | else:
83 | result_bbox, update_mode, aux_info = self.redundancy.infer(trk, input_data, time_lag)
84 | aux_info = {'is_key_frame': input_data.aux_info['is_key_frame']}
85 | update_info = UpdateInfoData(mode=update_mode, bbox=result_bbox,
86 | ego=input_data.ego, frame_index=self.frame_count,
87 | pc=input_data.pc, dets=input_data.dets, aux_info=aux_info)
88 | trk.update(update_info)
89 |
90 | # create new tracks for unmatched detections
91 | for index in unmatched_dets:
92 | if self.has_velo:
93 | aux_info = {
94 | 'velo': list(input_data.aux_info['velos'][index]),
95 | 'is_key_frame': input_data.aux_info['is_key_frame']}
96 | else:
97 | aux_info = {'is_key_frame': input_data.aux_info['is_key_frame']}
98 |
99 | track = tracklet.Tracklet(self.configs, self.count, input_data.dets[index], input_data.det_types[index],
100 | self.frame_count, aux_info=aux_info, time_stamp=input_data.time_stamp)
101 | self.trackers.append(track)
102 | self.count += 1
103 |
104 | # remove dead tracks
105 | track_num = len(self.trackers)
106 | for index, trk in enumerate(reversed(self.trackers)):
107 | if trk.death(self.frame_count):
108 | self.trackers.pop(track_num - 1 - index)
109 |
110 | # output the results
111 | result = list()
112 | for trk in self.trackers:
113 | state_string = trk.state_string(self.frame_count)
114 | result.append((trk.get_state(), trk.id, state_string, trk.det_type))
115 |
116 | # wrap up and update the information about the mot trackers
117 | self.time_stamp = input_data.time_stamp
118 | for trk in self.trackers:
119 | trk.sync_time_stamp(self.time_stamp)
120 |
121 | return result
122 |
123 | def forward_step_trk(self, input_data: FrameData):
124 | dets = input_data.dets
125 | det_indexes = [i for i, det in enumerate(dets) if det.s >= self.score_threshold]
126 | dets = [dets[i] for i in det_indexes]
127 |
128 | # prediction and association
129 | trk_preds = list()
130 | for trk in self.trackers:
131 | trk_preds.append(trk.predict(input_data.time_stamp, input_data.aux_info['is_key_frame']))
132 |
133 | # for m-distance association
134 | trk_innovation_matrix = None
135 | if self.asso == 'm_dis':
136 | trk_innovation_matrix = [trk.compute_innovation_matrix() for trk in self.trackers]
137 |
138 | matched, unmatched_dets, unmatched_trks = associate_dets_to_tracks(dets, trk_preds,
139 | self.match_type, self.asso, self.asso_thres, trk_innovation_matrix)
140 |
141 | for k in range(len(matched)):
142 | matched[k][0] = det_indexes[matched[k][0]]
143 | for k in range(len(unmatched_dets)):
144 | unmatched_dets[k] = det_indexes[unmatched_dets[k]]
145 | return matched, unmatched_dets, unmatched_trks
146 |
147 | def non_key_forward_step_trk(self, input_data: FrameData):
148 | """ tracking on non-key frames (for nuScenes)
149 | """
150 | dets = input_data.dets
151 | det_indexes = [i for i, det in enumerate(dets) if det.s >= 0.5]
152 | dets = [dets[i] for i in det_indexes]
153 |
154 | # prediction and association
155 | trk_preds = list()
156 | for trk in self.trackers:
157 | trk_preds.append(trk.predict(input_data.time_stamp, input_data.aux_info['is_key_frame']))
158 |
159 | # for m-distance association
160 | trk_innovation_matrix = None
161 | if self.asso == 'm_dis':
162 | trk_innovation_matrix = [trk.compute_innovation_matrix() for trk in self.trackers]
163 |
164 | matched, unmatched_dets, unmatched_trks = associate_dets_to_tracks(dets, trk_preds,
165 | self.match_type, self.asso, self.asso_thres, trk_innovation_matrix)
166 |
167 | for k in range(len(matched)):
168 | matched[k][0] = det_indexes[matched[k][0]]
169 | for k in range(len(unmatched_dets)):
170 | unmatched_dets[k] = det_indexes[unmatched_dets[k]]
171 | return matched, unmatched_dets, unmatched_trks
172 |
173 | def non_key_frame_mot(self, input_data: FrameData):
174 | """ tracking on non-key frames (for nuScenes)
175 | """
176 | self.frame_count += 1
177 | # initialize the time stamp on frame 0
178 | if self.time_stamp is None:
179 | self.time_stamp = input_data.time_stamp
180 |
181 | if 'kf' in self.motion_model:
182 | matched, unmatched_dets, unmatched_trks = self.non_key_forward_step_trk(input_data)
183 | time_lag = input_data.time_stamp - self.time_stamp
184 |
185 | redundancy_bboxes, update_modes = self.non_key_redundancy.bipartite_infer(input_data, self.trackers)
186 | # update the matched tracks
187 | for t, trk in enumerate(self.trackers):
188 | if t not in unmatched_trks:
189 | for k in range(len(matched)):
190 | if matched[k][1] == t:
191 | d = matched[k][0]
192 | break
193 | if self.has_velo:
194 | aux_info = {
195 | 'velo': list(input_data.aux_info['velos'][d]),
196 | 'is_key_frame': input_data.aux_info['is_key_frame']}
197 | else:
198 | aux_info = {'is_key_frame': input_data.aux_info['is_key_frame']}
199 | update_info = UpdateInfoData(mode=1, bbox=input_data.dets[d], ego=input_data.ego,
200 | frame_index=self.frame_count, pc=input_data.pc,
201 | dets=input_data.dets, aux_info=aux_info)
202 | trk.update(update_info)
203 | else:
204 | aux_info = {'is_key_frame': input_data.aux_info['is_key_frame']}
205 | update_info = UpdateInfoData(mode=update_modes[t], bbox=redundancy_bboxes[t],
206 | ego=input_data.ego, frame_index=self.frame_count,
207 | pc=input_data.pc, dets=input_data.dets, aux_info=aux_info)
208 | trk.update(update_info)
209 |
210 | # output the results
211 | result = list()
212 | for trk in self.trackers:
213 | state_string = trk.state_string(self.frame_count)
214 | result.append((trk.get_state(), trk.id, state_string, trk.det_type))
215 |
216 | # wrap up and update the information about the mot trackers
217 | self.time_stamp = input_data.time_stamp
218 | for trk in self.trackers:
219 | trk.sync_time_stamp(self.time_stamp)
220 |
221 | return result
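222 | 
223 | 
224 | # A minimal usage sketch (illustrative). The config dict below is a hand-written
225 | # assumption mirroring the structure of configs/waymo_configs/*.yaml; consult
226 | # those files for the real values. Run as `python -m mot_3d.mot`.
227 | if __name__ == '__main__':
228 |     example_configs = {
229 |         'running': {'match_type': 'bipartite', 'score_threshold': 0.7,
230 |                     'asso': 'giou', 'asso_thres': {'giou': 1.5},
231 |                     'motion_model': 'kf', 'covariance': 'default',
232 |                     'max_age_since_update': 2, 'min_hits_to_birth': 3},
233 |         'redundancy': {'mode': 'mm',
234 |                        'det_score_threshold': {'giou': 0.1},
235 |                        'det_dist_threshold': {'giou': -0.5}}}
236 |     tracker = MOTModel(example_configs)
237 |     # feed frames one by one; each result entry is (bbox, id, state_string, type):
238 |     # frame = FrameData(dets=..., ego=..., time_stamp=..., det_types=...,
239 |     #                   aux_info={'is_key_frame': True})
240 |     # results = tracker.frame_mot(frame)
241 |     print(tracker.has_velo)  # False for the plain Kalman filter model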
--------------------------------------------------------------------------------
/mot_3d/motion_model/__init__.py:
--------------------------------------------------------------------------------
1 | from .kalman_filter import KalmanFilterMotionModel
2 |
--------------------------------------------------------------------------------
/mot_3d/motion_model/kalman_filter.py:
--------------------------------------------------------------------------------
1 | """ Many parts are borrowed from https://github.com/xinshuoweng/AB3DMOT
2 | """
3 |
4 | import numpy as np
5 | from ..data_protos import BBox
6 | from filterpy.kalman import KalmanFilter
7 |
8 |
9 | class KalmanFilterMotionModel:
10 | def __init__(self, bbox: BBox, inst_type, time_stamp, covariance='default'):
11 | # the time stamp of last observation
12 | self.prev_time_stamp = time_stamp
13 | self.latest_time_stamp = time_stamp
14 | # define constant velocity model
15 | self.score = bbox.s
16 | self.inst_type = inst_type
17 |
18 | self.kf = KalmanFilter(dim_x=10, dim_z=7)
19 | self.kf.x[:7] = BBox.bbox2array(bbox)[:7].reshape((7, 1))
20 | self.kf.F = np.array([[1,0,0,0,0,0,0,1,0,0], # state transition matrix
21 | [0,1,0,0,0,0,0,0,1,0],
22 | [0,0,1,0,0,0,0,0,0,1],
23 | [0,0,0,1,0,0,0,0,0,0],
24 | [0,0,0,0,1,0,0,0,0,0],
25 | [0,0,0,0,0,1,0,0,0,0],
26 | [0,0,0,0,0,0,1,0,0,0],
27 | [0,0,0,0,0,0,0,1,0,0],
28 | [0,0,0,0,0,0,0,0,1,0],
29 | [0,0,0,0,0,0,0,0,0,1]])
30 |
31 | self.kf.H = np.array([[1,0,0,0,0,0,0,0,0,0], # measurement function,
32 | [0,1,0,0,0,0,0,0,0,0],
33 | [0,0,1,0,0,0,0,0,0,0],
34 | [0,0,0,1,0,0,0,0,0,0],
35 | [0,0,0,0,1,0,0,0,0,0],
36 | [0,0,0,0,0,1,0,0,0,0],
37 | [0,0,0,0,0,0,1,0,0,0]])
38 |
39 | self.kf.B = np.zeros((10, 1)) # dummy control transition matrix
40 |
41 | # # with angular velocity
42 | # self.kf = KalmanFilter(dim_x=11, dim_z=7)
43 | # self.kf.F = np.array([[1,0,0,0,0,0,0,1,0,0,0], # state transition matrix
44 | # [0,1,0,0,0,0,0,0,1,0,0],
45 | # [0,0,1,0,0,0,0,0,0,1,0],
46 | # [0,0,0,1,0,0,0,0,0,0,1],
47 | # [0,0,0,0,1,0,0,0,0,0,0],
48 | # [0,0,0,0,0,1,0,0,0,0,0],
49 | # [0,0,0,0,0,0,1,0,0,0,0],
50 | # [0,0,0,0,0,0,0,1,0,0,0],
51 | # [0,0,0,0,0,0,0,0,1,0,0],
52 | # [0,0,0,0,0,0,0,0,0,1,0],
53 | # [0,0,0,0,0,0,0,0,0,0,1]])
54 |
55 | # self.kf.H = np.array([[1,0,0,0,0,0,0,0,0,0,0], # measurement function,
56 | # [0,1,0,0,0,0,0,0,0,0,0],
57 | # [0,0,1,0,0,0,0,0,0,0,0],
58 | # [0,0,0,1,0,0,0,0,0,0,0],
59 | # [0,0,0,0,1,0,0,0,0,0,0],
60 | # [0,0,0,0,0,1,0,0,0,0,0],
61 | # [0,0,0,0,0,0,1,0,0,0,0]])
62 |
63 | self.covariance_type = covariance
64 | # self.kf.R[0:,0:] *= 10. # measurement uncertainty
65 | self.kf.P[7:, 7:] *= 1000. # state uncertainty, give high uncertainty to the unobservable initial velocities, covariance matrix
66 | self.kf.P *= 10.
67 |
68 | # self.kf.Q[-1,-1] *= 0.01 # process uncertainty
69 | # self.kf.Q[7:, 7:] *= 0.01
70 |
71 | self.history = [bbox]
72 |
73 | def predict(self, time_stamp=None):
74 |         """ Internal KF predict step. For cross-frame motion prediction, use get_prediction instead.
75 | """
76 | self.kf.predict()
77 | if self.kf.x[3] >= np.pi: self.kf.x[3] -= np.pi * 2
78 | if self.kf.x[3] < -np.pi: self.kf.x[3] += np.pi * 2
79 | return
80 |
81 | def update(self, det_bbox: BBox, aux_info=None):
82 | """
83 | Updates the state vector with observed bbox.
84 | """
85 | bbox = BBox.bbox2array(det_bbox)[:7]
86 |
87 | # full pipeline of kf, first predict, then update
88 | self.predict()
89 |
90 | ######################### orientation correction
91 | if self.kf.x[3] >= np.pi: self.kf.x[3] -= np.pi * 2 # make the theta still in the range
92 | if self.kf.x[3] < -np.pi: self.kf.x[3] += np.pi * 2
93 |
94 | new_theta = bbox[3]
95 | if new_theta >= np.pi: new_theta -= np.pi * 2 # make the theta still in the range
96 | if new_theta < -np.pi: new_theta += np.pi * 2
97 | bbox[3] = new_theta
98 |
99 | predicted_theta = self.kf.x[3]
100 |         if np.abs(new_theta - predicted_theta) > np.pi / 2.0 and np.abs(new_theta - predicted_theta) < np.pi * 3 / 2.0:     # the two headings differ by 90~270 degrees, so flip the track heading by pi
101 | self.kf.x[3] += np.pi
102 | if self.kf.x[3] > np.pi: self.kf.x[3] -= np.pi * 2 # make the theta still in the range
103 | if self.kf.x[3] < -np.pi: self.kf.x[3] += np.pi * 2
104 |
105 |         # now the angle difference is acute (< 90 or > 270 degrees); fold the > 270 case back into < 90
106 | if np.abs(new_theta - self.kf.x[3]) >= np.pi * 3 / 2.0:
107 | if new_theta > 0: self.kf.x[3] += np.pi * 2
108 | else: self.kf.x[3] -= np.pi * 2
109 |
110 | ######################### # flip
111 |
112 | self.kf.update(bbox)
113 | self.prev_time_stamp = self.latest_time_stamp
114 |
115 |         if self.kf.x[3] >= np.pi: self.kf.x[3] -= np.pi * 2    # make the theta still in the range
116 | if self.kf.x[3] < -np.pi: self.kf.x[3] += np.pi * 2
117 |
118 | if det_bbox.s is None:
119 | self.score = self.score * 0.01
120 | else:
121 | self.score = det_bbox.s
122 |
123 | cur_bbox = self.kf.x[:7].reshape(-1).tolist()
124 | cur_bbox = BBox.array2bbox(cur_bbox + [self.score])
125 | self.history[-1] = cur_bbox
126 | return
127 |
128 | def get_prediction(self, time_stamp=None):
129 | """
130 | Advances the state vector and returns the predicted bounding box estimate.
131 | """
132 | time_lag = time_stamp - self.prev_time_stamp
133 | self.latest_time_stamp = time_stamp
134 | self.kf.F = np.array([[1,0,0,0,0,0,0,time_lag,0,0], # state transition matrix
135 | [0,1,0,0,0,0,0,0,time_lag,0],
136 | [0,0,1,0,0,0,0,0,0,time_lag],
137 | [0,0,0,1,0,0,0,0,0,0],
138 | [0,0,0,0,1,0,0,0,0,0],
139 | [0,0,0,0,0,1,0,0,0,0],
140 | [0,0,0,0,0,0,1,0,0,0],
141 | [0,0,0,0,0,0,0,1,0,0],
142 | [0,0,0,0,0,0,0,0,1,0],
143 | [0,0,0,0,0,0,0,0,0,1]])
144 | pred_x = self.kf.get_prediction()[0]
145 | if pred_x[3] >= np.pi: pred_x[3] -= np.pi * 2
146 | if pred_x[3] < -np.pi: pred_x[3] += np.pi * 2
147 | pred_bbox = BBox.array2bbox(pred_x[:7].reshape(-1))
148 |
149 | self.history.append(pred_bbox)
150 | return pred_bbox
151 |
152 | def get_state(self):
153 | """
154 | Returns the current bounding box estimate.
155 | """
156 | return self.history[-1]
157 |
158 | def compute_innovation_matrix(self):
159 |         """ compute the innovation matrix for association with Mahalanobis distance
160 | """
161 | return np.matmul(np.matmul(self.kf.H, self.kf.P), self.kf.H.T) + self.kf.R
162 |
163 | def sync_time_stamp(self, time_stamp):
164 | self.time_stamp = time_stamp
165 | return
166 |
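167 | 
168 | # A minimal usage sketch (illustrative; the box values are made up). Timestamps
169 | # are in seconds: get_prediction() writes the time lag into F before predicting,
170 | # and update() fuses a new detection. Run as `python -m mot_3d.motion_model.kalman_filter`.
171 | if __name__ == '__main__':
172 |     start = BBox.array2bbox([0.0, 0.0, 0.0, 0.0, 4.0, 2.0, 1.5, 0.9])
173 |     model = KalmanFilterMotionModel(bbox=start, inst_type=1, time_stamp=0.0)
174 |     pred = model.get_prediction(time_stamp=0.1)  # extrapolate 0.1s ahead
175 |     obs = BBox.array2bbox([0.2, 0.0, 0.0, 0.0, 4.0, 2.0, 1.5, 0.8])
176 |     model.update(obs)                            # fuse the new detection
177 |     state = model.get_state()
178 |     print(state.x, state.y, state.s)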
--------------------------------------------------------------------------------
/mot_3d/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | from .nms import nms
--------------------------------------------------------------------------------
/mot_3d/preprocessing/bbox_coarse_hash.py:
--------------------------------------------------------------------------------
1 | """ Split the area into grid boxes
2 | BBoxes in different grid boxes without overlap cannot have overlap
3 | """
4 | import numpy as np
5 | from ..data_protos import BBox
6 |
7 |
8 | class BBoxCoarseFilter:
9 | def __init__(self, grid_size, scaler=100):
10 | self.gsize = grid_size
11 |         self.scaler = scaler
12 | self.bbox_dict = dict()
13 |
14 | def bboxes2dict(self, bboxes):
15 | for i, bbox in enumerate(bboxes):
16 | grid_keys = self.compute_bbox_key(bbox)
17 | for key in grid_keys:
18 | if key not in self.bbox_dict.keys():
19 | self.bbox_dict[key] = set([i])
20 | else:
21 | self.bbox_dict[key].add(i)
22 | return
23 |
24 | def compute_bbox_key(self, bbox):
25 | corners = np.asarray(BBox.box2corners2d(bbox))
26 |         min_keys = np.floor(np.min(corners, axis=0) / self.gsize).astype(np.int64)
27 |         max_keys = np.floor(np.max(corners, axis=0) / self.gsize).astype(np.int64)
28 |
29 | # enumerate all the corners
30 | grid_keys = [
31 | self.scaler * min_keys[0] + min_keys[1],
32 | self.scaler * min_keys[0] + max_keys[1],
33 | self.scaler * max_keys[0] + min_keys[1],
34 | self.scaler * max_keys[0] + max_keys[1]
35 | ]
36 | return grid_keys
37 |
38 | def related_bboxes(self, bbox):
39 | """ return the list of related bboxes
40 | """
41 | result = set()
42 | grid_keys = self.compute_bbox_key(bbox)
43 | for key in grid_keys:
44 | if key in self.bbox_dict.keys():
45 | result.update(self.bbox_dict[key])
46 | return list(result)
47 |
48 | def clear(self):
49 | self.bbox_dict = dict()
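50 | 
51 | 
52 | # A minimal usage sketch (illustrative; the box layout follows BBox.array2bbox).
53 | # Boxes landing in the same grid cell are returned as mutual candidates. Run as
54 | # `python -m mot_3d.preprocessing.bbox_coarse_hash`.
55 | if __name__ == '__main__':
56 |     boxes = [BBox.array2bbox([i * 2.0, 0.0, 0.0, 0.0, 4.0, 2.0, 1.5, 0.9])
57 |              for i in range(3)]
58 |     coarse = BBoxCoarseFilter(grid_size=100, scaler=100)
59 |     coarse.bboxes2dict(boxes)
60 |     print(coarse.related_bboxes(boxes[0]))  # candidate indexes, here [0, 1, 2]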
--------------------------------------------------------------------------------
/mot_3d/preprocessing/nms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .. import utils
3 | from ..data_protos import BBox
4 | from .bbox_coarse_hash import BBoxCoarseFilter
5 |
6 |
7 | def weird_bbox(bbox):
8 | if bbox.l <= 0 or bbox.w <= 0 or bbox.h <= 0:
9 | return True
10 | else:
11 | return False
12 |
13 |
14 | def nms(dets, inst_types, threshold_low=0.1, threshold_high=1.0, threshold_yaw=0.3):
15 |     """ greedy NMS: suppress same-type bboxes whose IoU with a kept bbox exceeds threshold_low
16 | """
17 | dets_coarse_filter = BBoxCoarseFilter(grid_size=100, scaler=100)
18 | dets_coarse_filter.bboxes2dict(dets)
19 | scores = np.asarray([det.s for det in dets])
20 | yaws = np.asarray([det.o for det in dets])
21 | order = np.argsort(scores)[::-1]
22 |
23 | result_indexes = list()
24 | result_types = list()
25 | while order.size > 0:
26 | index = order[0]
27 |
28 | if weird_bbox(dets[index]):
29 | order = order[1:]
30 | continue
31 |
32 | # locate the related bboxes
33 | filter_indexes = dets_coarse_filter.related_bboxes(dets[index])
34 | in_mask = np.isin(order, filter_indexes)
35 | related_idxes = order[in_mask]
36 | related_idxes = np.asarray([i for i in related_idxes if inst_types[i] == inst_types[index]])
37 |
38 | # compute the ious
39 | bbox_num = len(related_idxes)
40 | ious = np.zeros(bbox_num)
41 | for i, idx in enumerate(related_idxes):
42 | ious[i] = utils.iou3d(dets[index], dets[idx])[1]
43 | related_inds = np.where(ious > threshold_low)
44 | related_inds_vote = np.where(ious > threshold_high)
45 | order_vote = related_idxes[related_inds_vote]
46 |
47 | if len(order_vote) >= 2:
48 | # keep the bboxes with similar yaw
49 | if order_vote.shape[0] <= 2:
50 | score_index = np.argmax(scores[order_vote])
51 | median_yaw = yaws[order_vote][score_index]
52 | elif order_vote.shape[0] % 2 == 0:
53 | tmp_yaw = yaws[order_vote].copy()
54 | tmp_yaw = np.append(tmp_yaw, yaws[order_vote][0])
55 | median_yaw = np.median(tmp_yaw)
56 | else:
57 | median_yaw = np.median(yaws[order_vote])
58 | yaw_vote = np.where(np.abs(yaws[order_vote] - median_yaw) % (2 * np.pi) < threshold_yaw)[0]
59 | order_vote = order_vote[yaw_vote]
60 |
61 | # start weighted voting
62 | vote_score_sum = np.sum(scores[order_vote])
63 | det_arrays = list()
64 | for idx in order_vote:
65 | det_arrays.append(BBox.bbox2array(dets[idx])[np.newaxis, :])
66 | det_arrays = np.vstack(det_arrays)
67 | avg_bbox_array = np.sum(scores[order_vote][:, np.newaxis] * det_arrays, axis=0) / vote_score_sum
68 | bbox = BBox.array2bbox(avg_bbox_array)
69 | bbox.s = scores[index]
70 | result_indexes.append(index)
71 | result_types.append(inst_types[index])
72 | else:
73 | result_indexes.append(index)
74 | result_types.append(inst_types[index])
75 |
76 | # delete the overlapped bboxes
77 | delete_idxes = related_idxes[related_inds]
78 | in_mask = np.isin(order, delete_idxes, invert=True)
79 | order = order[in_mask]
80 |
81 | return result_indexes, result_types
82 |
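83 | 
84 | # A minimal usage sketch (illustrative; the two detections are made up). The
85 | # lower-scoring overlapping box of the same type gets suppressed. Run as
86 | # `python -m mot_3d.preprocessing.nms`.
87 | if __name__ == '__main__':
88 |     dets = [BBox.array2bbox([0.0, 0.0, 0.0, 0.0, 4.0, 2.0, 1.5, 0.9]),
89 |             BBox.array2bbox([0.2, 0.0, 0.0, 0.0, 4.0, 2.0, 1.5, 0.6])]
90 |     keep, types = nms(dets, inst_types=[1, 1])
91 |     print(keep, types)  # -> [0] [1]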
--------------------------------------------------------------------------------
/mot_3d/redundancy/__init__.py:
--------------------------------------------------------------------------------
1 | """ the redundancy system interface, in case of no high score detections
2 | this is the two-stage association.
3 | """
4 | from .redundancy import RedundancyModule
--------------------------------------------------------------------------------
/mot_3d/redundancy/redundancy.py:
--------------------------------------------------------------------------------
1 | from ..frame_data import FrameData
2 | from ..update_info_data import UpdateInfoData
3 | from ..data_protos import BBox, Validity
4 | import numpy as np, mot_3d.utils as utils
5 | from ..tracklet import Tracklet
6 | from ..association import associate_dets_to_tracks
7 |
8 |
9 | class RedundancyModule:
10 | def __init__(self, configs):
11 | self.configs = configs
12 | self.mode = configs['redundancy']['mode']
13 | self.asso = configs['running']['asso']
14 | self.det_score = configs['redundancy']['det_score_threshold'][self.asso]
15 | self.det_threshold = configs['redundancy']['det_dist_threshold'][self.asso]
16 | self.motion_model_type = configs['running']['motion_model']
17 |
18 | def infer(self, trk: Tracklet, input_data: FrameData, time_lag=None):
19 | if self.mode == 'bbox':
20 | return self.bbox_redundancy(trk, input_data)
21 | elif self.mode == 'mm':
22 | return self.motion_model_redundancy(trk, input_data, time_lag)
23 | else:
24 | return self.default_redundancy(trk, input_data)
25 |
26 | def default_redundancy(self, trk: Tracklet, input_data: FrameData):
27 | """ Return the supposed state, association string, and auxiliary information
28 | """
29 | pred_bbox = trk.get_state()
30 | return pred_bbox, 0, None
31 |
32 | def motion_model_redundancy(self, trk: Tracklet, input_data: FrameData, time_lag):
33 | # get the motion model prediction / current state
34 | pred_bbox = trk.get_state()
35 |
36 | # associate to low-score detections
37 | dists = list()
38 | dets = input_data.dets
39 | related_indexes = [i for i, det in enumerate(dets) if det.s > self.det_score]
40 | candidate_dets = [dets[i] for i in related_indexes]
41 |
42 | # association
43 | for i, det in enumerate(candidate_dets):
44 | pd_det = det
45 |
46 | if self.asso == 'iou':
47 | dists.append(utils.iou3d(pd_det, pred_bbox)[1])
48 | elif self.asso == 'giou':
49 | dists.append(utils.giou3d(pd_det, pred_bbox))
50 | elif self.asso == 'm_dis':
51 | trk_innovation_matrix = trk.compute_innovation_matrix()
52 | inv_innovation_matrix = np.linalg.inv(trk_innovation_matrix)
53 | dists.append(utils.m_distance(pd_det, pred_bbox, inv_innovation_matrix))
54 | elif self.asso == 'euler':
55 | dists.append(utils.m_distance(pd_det, pred_bbox))
56 |
57 | if self.asso in ['iou', 'giou'] and (len(dists) == 0 or np.max(dists) < self.det_threshold):
58 | result_bbox = pred_bbox
59 |             update_mode = 0 # not associated in the second stage
60 | elif self.asso in ['m_dis', 'euler'] and (len(dists) == 0 or np.min(dists) > self.det_threshold):
61 | result_bbox = pred_bbox
62 |             update_mode = 0 # not associated in the second stage
63 | else:
64 | result_bbox = pred_bbox
65 | update_mode = 3 # associated
66 | return result_bbox, update_mode, {'velo': np.zeros(2)}
67 |
68 | def bipartite_infer(self, input_data: FrameData, tracklets):
69 | dets = input_data.dets
70 | det_indexes = [i for i, det in enumerate(dets) if det.s >= self.det_score]
71 | dets = [dets[i] for i in det_indexes]
72 |
73 | # prediction and association
74 | trk_preds = list()
75 | for trk in tracklets:
76 | trk_preds.append(trk.predict(input_data.time_stamp, input_data.aux_info['is_key_frame']))
77 |
78 | matched, unmatched_dets, unmatched_trks = associate_dets_to_tracks(dets, trk_preds,
79 | 'bipartite', 'giou', 1 - self.det_threshold, None)
80 | for k in range(len(matched)):
81 | matched[k][0] = det_indexes[matched[k][0]]
82 | for k in range(len(unmatched_dets)):
83 | unmatched_dets[k] = det_indexes[unmatched_dets[k]]
84 |
85 | result_bboxes, update_modes = [], []
86 | for t, trk in enumerate(tracklets):
87 | if t not in unmatched_trks:
88 | for k in range(len(matched)):
89 | if matched[k][1] == t:
90 | d = matched[k][0]
91 | break
92 | result_bboxes.append(trk_preds[t])
93 | update_modes.append(4) # associated
94 | else:
95 | result_bboxes.append(trk_preds[t])
96 | update_modes.append(0) # not associated
97 | return result_bboxes, update_modes
98 |
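99 | 
100 | # A minimal usage sketch (illustrative; the thresholds are assumptions, see the
101 | # yaml configs for real values). In 'mm' mode the module keeps the motion-model
102 | # prediction and only checks whether some low-score detection is close enough
103 | # to keep the tracklet alive. Run as `python -m mot_3d.redundancy.redundancy`.
104 | if __name__ == '__main__':
105 |     configs = {'running': {'asso': 'giou', 'motion_model': 'kf'},
106 |                'redundancy': {'mode': 'mm',
107 |                               'det_score_threshold': {'giou': 0.1},
108 |                               'det_dist_threshold': {'giou': -0.5}}}
109 |     module = RedundancyModule(configs)
110 |     # module.infer(trk, input_data, time_lag) -> (bbox, update_mode, aux_info)
111 |     print(module.mode, module.det_score, module.det_threshold)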
--------------------------------------------------------------------------------
/mot_3d/tracklet/__init__.py:
--------------------------------------------------------------------------------
1 | from ..tracklet.tracklet import Tracklet
--------------------------------------------------------------------------------
/mot_3d/tracklet/tracklet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .. import motion_model
3 | from .. import life as life_manager
4 | from ..update_info_data import UpdateInfoData
5 | from ..frame_data import FrameData
6 | from ..data_protos import BBox
7 |
8 |
9 | class Tracklet:
10 | def __init__(self, configs, id, bbox: BBox, det_type, frame_index, time_stamp=None, aux_info=None):
11 | self.id = id
12 | self.time_stamp = time_stamp
13 | self.asso = configs['running']['asso']
14 |
15 | self.configs = configs
16 | self.det_type = det_type
17 | self.aux_info = aux_info
18 |
19 | # initialize different types of motion model
20 | self.motion_model_type = configs['running']['motion_model']
21 | # simple kalman filter
22 | if self.motion_model_type == 'kf':
23 | self.motion_model = motion_model.KalmanFilterMotionModel(
24 | bbox=bbox, inst_type=self.det_type, time_stamp=time_stamp, covariance=configs['running']['covariance'])
25 |
26 | # life and death management
27 | self.life_manager = life_manager.HitManager(configs, frame_index)
28 | # store the score for the latest bbox
29 | self.latest_score = bbox.s
30 |
31 | def predict(self, time_stamp=None, is_key_frame=True):
32 |         """ In the prediction step, the motion model predicts the state of the bbox
33 |             and the other components have to be synced accordingly.
34 |             The result is a BBox.
35 | 
36 |             time_stamp is optional; it is only needed when velocities are used.
37 | """
38 | result = self.motion_model.get_prediction(time_stamp=time_stamp)
39 | self.life_manager.predict(is_key_frame=is_key_frame)
40 | self.latest_score = self.latest_score * 0.01
41 | result.s = self.latest_score
42 | return result
43 |
44 | def update(self, update_info: UpdateInfoData):
45 | """ update the state of the tracklet
46 | """
47 | self.latest_score = update_info.bbox.s
48 | is_key_frame = update_info.aux_info['is_key_frame']
49 |
50 | # only the direct association update the motion model
51 | if update_info.mode == 1 or update_info.mode == 3:
52 | self.motion_model.update(update_info.bbox, update_info.aux_info)
53 | else:
54 | pass
55 | self.life_manager.update(update_info, is_key_frame)
56 | return
57 |
58 | def get_state(self):
59 | """ current state of the tracklet
60 | """
61 | result = self.motion_model.get_state()
62 | result.s = self.latest_score
63 | return result
64 |
65 | def valid_output(self, frame_index):
66 | return self.life_manager.valid_output(frame_index)
67 |
68 | def death(self, frame_index):
69 | return self.life_manager.death(frame_index)
70 |
71 | def state_string(self, frame_index):
72 | """ the string describes how we get the bbox (e.g. by detection or motion model prediction)
73 | """
74 | return self.life_manager.state_string(frame_index)
75 |
76 | def compute_innovation_matrix(self):
77 |         """ compute the innovation matrix for association with Mahalanobis distance
78 | """
79 | return self.motion_model.compute_innovation_matrix()
80 |
81 | def sync_time_stamp(self, time_stamp):
82 | """ sync the time stamp for motion model
83 | """
84 | self.motion_model.sync_time_stamp(time_stamp)
85 | return
86 |
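87 | 
88 | # A minimal usage sketch (illustrative; config values are assumptions). A
89 | # tracklet couples one motion model with one life-cycle manager. Run as
90 | # `python -m mot_3d.tracklet.tracklet`.
91 | if __name__ == '__main__':
92 |     configs = {'running': {'asso': 'giou', 'motion_model': 'kf',
93 |                            'covariance': 'default',
94 |                            'max_age_since_update': 2, 'min_hits_to_birth': 3}}
95 |     det = BBox.array2bbox([0.0, 0.0, 0.0, 0.0, 4.0, 2.0, 1.5, 0.9])
96 |     trk = Tracklet(configs, id=0, bbox=det, det_type=1, frame_index=1,
97 |                    time_stamp=0.0)
98 |     pred = trk.predict(time_stamp=0.1, is_key_frame=True)  # a predicted BBox
99 |     print(trk.state_string(frame_index=1))  # e.g. 'alive_1_1'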
--------------------------------------------------------------------------------
/mot_3d/update_info_data.py:
--------------------------------------------------------------------------------
1 | """ a general interface for aranging the things inside a single tracklet
2 | data structure for updating the life cycles and states of a tracklet
3 | """
4 | from .data_protos import BBox
5 | from . import utils
6 | import numpy as np
7 |
8 |
9 | class UpdateInfoData:
10 | def __init__(self, mode, bbox: BBox, frame_index, ego, dets=None, pc=None, aux_info=None):
11 | self.mode = mode # association state
12 | self.bbox = bbox
13 | self.ego = ego
14 | self.frame_index = frame_index
15 | self.pc = pc
16 | self.dets = dets
17 | self.aux_info = aux_info
18 |
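19 | 
20 | # A minimal usage sketch (illustrative). `mode` encodes how the bbox was
21 | # obtained: 1 for a first-stage detection match, 3 or 4 for second-stage
22 | # (redundancy) matches, and 0 for no association.
23 | if __name__ == '__main__':
24 |     info = UpdateInfoData(mode=1, bbox=None, frame_index=5, ego=None)
25 |     print(info.mode, info.frame_index)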
--------------------------------------------------------------------------------
/mot_3d/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .geometry import *
2 | from .data_utils import *
--------------------------------------------------------------------------------
/mot_3d/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | # Utilities for selecting sequences/boxes according to object types
2 | # and for transforming string IDs into ints when needed
3 | from ..data_protos import BBox
4 | import numpy as np
5 |
6 |
7 | __all__ = ['inst_filter', 'str2int', 'box_wrapper', 'type_filter', 'id_transform']
8 |
9 |
10 | def str2int(strs):
11 | result = [int(s) for s in strs]
12 | return result
13 |
14 |
15 | def box_wrapper(bboxes, ids):
16 | frame_num = len(ids)
17 | result = list()
18 | for _i in range(frame_num):
19 | frame_result = list()
20 | num = len(ids[_i])
21 | for _j in range(num):
22 | frame_result.append((ids[_i][_j], bboxes[_i][_j]))
23 | result.append(frame_result)
24 | return result
25 |
26 |
27 | def id_transform(ids):
28 | frame_num = len(ids)
29 |
30 | id_list = list()
31 | for _i in range(frame_num):
32 | id_list += ids[_i]
33 | id_list = sorted(list(set(id_list)))
34 |
35 | id_mapping = dict()
36 | for _i, id in enumerate(id_list):
37 | id_mapping[id] = _i
38 |
39 | result = list()
40 | for _i in range(frame_num):
41 | frame_ids = list()
42 | frame_id_num = len(ids[_i])
43 | for _j in range(frame_id_num):
44 | frame_ids.append(id_mapping[ids[_i][_j]])
45 | result.append(frame_ids)
46 | return result
47 |
48 |
49 | def inst_filter(ids, bboxes, types, type_field=[1], id_trans=False):
50 | """ filter the bboxes according to types
51 | """
52 | frame_num = len(ids)
53 | if id_trans:
54 | ids = id_transform(ids)
55 | id_result, bbox_result = [], []
56 | for _i in range(frame_num):
57 | frame_ids = list()
58 | frame_bboxes = list()
59 | frame_id_num = len(ids[_i])
60 | for _j in range(frame_id_num):
61 | obj_type = types[_i][_j]
62 | matched = False
63 | for type_name in type_field:
64 | if str(type_name) in str(obj_type):
65 | matched = True
66 | if matched:
67 | frame_ids.append(ids[_i][_j])
68 | frame_bboxes.append(BBox.array2bbox(bboxes[_i][_j]))
69 | id_result.append(frame_ids)
70 | bbox_result.append(frame_bboxes)
71 | return id_result, bbox_result
72 |
73 |
74 | def type_filter(contents, types, type_field=[1]):
75 | frame_num = len(types)
76 | content_result = [list() for i in range(len(type_field))]
77 | for _k, inst_type in enumerate(type_field):
78 | for _i in range(frame_num):
79 | frame_contents = list()
80 | frame_id_num = len(contents[_i])
81 | for _j in range(frame_id_num):
82 | if types[_i][_j] != inst_type:
83 | continue
84 | frame_contents.append(contents[_i][_j])
85 | content_result[_k].append(frame_contents)
86 | return content_result
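87 | 
88 | 
89 | # A minimal usage sketch (illustrative; the bbox arrays follow BBox.array2bbox's
90 | # [x, y, z, yaw, l, w, h, score] layout). Run as `python -m mot_3d.utils.data_utils`.
91 | if __name__ == '__main__':
92 |     ids = [['a', 'b']]   # one frame with two objects
93 |     types = [[1, 2]]
94 |     bboxes = [[[0.0, 0.0, 0.0, 0.0, 4.0, 2.0, 1.5, 0.9],
95 |                [5.0, 0.0, 0.0, 0.0, 4.0, 2.0, 1.5, 0.8]]]
96 |     kept_ids, kept_bboxes = inst_filter(ids, bboxes, types, type_field=[1])
97 |     print(kept_ids)  # -> [['a']]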
--------------------------------------------------------------------------------
/mot_3d/utils/geometry.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 | from shapely.geometry import Polygon
4 | from scipy.spatial import ConvexHull, convex_hull_plot_2d
5 | import numba
6 | from ..data_protos import BBox
7 |
8 |
9 | __all__ = ['pc_in_box', 'downsample', 'pc_in_box_2D',
10 | 'apply_motion_to_points', 'make_transformation_matrix',
11 | 'iou2d', 'iou3d', 'pc2world', 'giou2d', 'giou3d',
12 | 'back_step_det', 'm_distance', 'velo2world', 'score_rectification']
13 |
14 |
15 | def velo2world(ego_matrix, velo):
16 | """ transform local velocity [x, y] to global
17 | """
18 | new_velo = velo[:, np.newaxis]
19 | new_velo = ego_matrix[:2, :2] @ new_velo
20 | return new_velo[:, 0]
21 |
22 |
23 | def apply_motion_to_points(points, motion, pre_move=0):
24 | transformation_matrix = make_transformation_matrix(motion)
25 | points = deepcopy(points)
26 | points = points + pre_move
27 | new_points = np.concatenate((points,
28 | np.ones(points.shape[0])[:, np.newaxis]),
29 | axis=1)
30 |
31 | new_points = transformation_matrix @ new_points.T
32 | new_points = new_points.T[:, :3]
33 | new_points -= pre_move
34 | return new_points
35 |
36 |
37 | @numba.njit
38 | def downsample(points, voxel_size=0.05):
39 | sample_dict = dict()
40 | for i in range(points.shape[0]):
41 | point_coord = np.floor(points[i] / voxel_size)
42 | sample_dict[(int(point_coord[0]), int(point_coord[1]), int(point_coord[2]))] = True
43 | res = np.zeros((len(sample_dict), 3), dtype=np.float32)
44 | idx = 0
45 | for k, v in sample_dict.items():
46 | res[idx, 0] = k[0] * voxel_size + voxel_size / 2
47 | res[idx, 1] = k[1] * voxel_size + voxel_size / 2
48 | res[idx, 2] = k[2] * voxel_size + voxel_size / 2
49 | idx += 1
50 | return res
51 |
52 |
53 | # def pc_in_box(box, pc, box_scaling=1.5):
54 | # center_x, center_y, length, width = \
55 | # box['center_x'], box['center_y'], box['length'], box['width']
56 | # center_z, height = box['center_z'], box['height']
57 | # yaw = box['heading']
58 |
59 | # rx = np.abs((pc[:, 0] - center_x) * np.cos(yaw) + (pc[:, 1] - center_y) * np.sin(yaw))
60 | # ry = np.abs((pc[:, 0] - center_x) * -(np.sin(yaw)) + (pc[:, 1] - center_y) * np.cos(yaw))
61 | # rz = np.abs(pc[:, 2] - center_z)
62 |
63 | # mask_x = (rx < (length * box_scaling / 2))
64 | # mask_y = (ry < (width * box_scaling / 2))
65 | # mask_z = (rz < (height / 2))
66 |
67 | # mask = mask_x * mask_y * mask_z
68 | # indices = np.argwhere(mask == 1).reshape(-1)
69 | # return pc[indices, :]
70 |
71 |
72 | # def pc_in_box_2D(box, pc, box_scaling=1.0):
73 | # center_x, center_y, length, width = \
74 | # box['center_x'], box['center_y'], box['length'], box['width']
75 | # yaw = box['heading']
76 |
77 | # cos = np.cos(yaw)
78 | # sin = np.sin(yaw)
79 | # rx = np.abs((pc[:, 0] - center_x) * cos + (pc[:, 1] - center_y) * sin)
80 | # ry = np.abs((pc[:, 0] - center_x) * -(sin) + (pc[:, 1] - center_y) * cos)
81 |
82 | # mask_x = (rx < (length * box_scaling / 2))
83 | # mask_y = (ry < (width * box_scaling / 2))
84 |
85 | # mask = mask_x * mask_y
86 | # indices = np.argwhere(mask == 1).reshape(-1)
87 | # return pc[indices, :]
88 |
89 |
90 | def pc_in_box(box, pc, box_scaling=1.5):
91 | center_x, center_y, length, width = \
92 | box.x, box.y, box.l, box.w
93 | center_z, height = box.z, box.h
94 | yaw = box.o
95 | return pc_in_box_inner(center_x, center_y, center_z, length, width, height, yaw, pc, box_scaling)
96 |
97 |
98 | @numba.njit
99 | def pc_in_box_inner(center_x, center_y, center_z, length, width, height, yaw, pc, box_scaling=1.5):
100 | mask = np.zeros(pc.shape[0], dtype=np.int32)
101 | yaw_cos, yaw_sin = np.cos(yaw), np.sin(yaw)
102 | for i in range(pc.shape[0]):
103 | rx = np.abs((pc[i, 0] - center_x) * yaw_cos + (pc[i, 1] - center_y) * yaw_sin)
104 | ry = np.abs((pc[i, 0] - center_x) * -yaw_sin + (pc[i, 1] - center_y) * yaw_cos)
105 | rz = np.abs(pc[i, 2] - center_z)
106 |
107 | if rx < (length * box_scaling / 2) and ry < (width * box_scaling / 2) and rz < (height * box_scaling / 2):
108 | mask[i] = 1
109 | indices = np.argwhere(mask == 1)
110 | result = np.zeros((indices.shape[0], 3), dtype=np.float64)
111 | for i in range(indices.shape[0]):
112 | result[i, :] = pc[indices[i], :]
113 | return result
114 |
115 |
116 | def pc_in_box_2D(box, pc, box_scaling=1.0):
117 | center_x, center_y, length, width = \
118 | box.x, box.y, box.l, box.w
119 | center_z, height = box.z, box.h
120 | yaw = box.o
121 | return pc_in_box_2D_inner(center_x, center_y, center_z, length, width, height, yaw, pc, box_scaling)
122 |
123 |
124 | @numba.njit
125 | def pc_in_box_2D_inner(center_x, center_y, center_z, length, width, height, yaw, pc, box_scaling=1.0):
126 | mask = np.zeros(pc.shape[0], dtype=np.int32)
127 | yaw_cos, yaw_sin = np.cos(yaw), np.sin(yaw)
128 | for i in range(pc.shape[0]):
129 | rx = np.abs((pc[i, 0] - center_x) * yaw_cos + (pc[i, 1] - center_y) * yaw_sin)
130 | ry = np.abs((pc[i, 0] - center_x) * -yaw_sin + (pc[i, 1] - center_y) * yaw_cos)
131 |
132 | if rx < (length * box_scaling / 2) and ry < (width * box_scaling / 2):
133 | mask[i] = 1
134 | indices = np.argwhere(mask == 1)
135 | result = np.zeros((indices.shape[0], 3), dtype=np.float64)
136 | for i in range(indices.shape[0]):
137 | result[i, :] = pc[indices[i], :]
138 | return result
139 |
140 |
141 | def make_transformation_matrix(motion):
142 | x, y, z, theta = motion
143 | transformation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0, x],
144 | [np.sin(theta), np.cos(theta), 0, y],
145 | [0 , 0 , 1, z],
146 | [0 , 0 , 0, 1]])
147 | return transformation_matrix
148 |
149 |
150 | def iou2d(box_a, box_b):
151 | boxa_corners = np.array(BBox.box2corners2d(box_a))[:, :2]
152 | boxb_corners = np.array(BBox.box2corners2d(box_b))[:, :2]
153 | reca, recb = Polygon(boxa_corners), Polygon(boxb_corners)
154 | overlap = reca.intersection(recb).area
155 | area_a = reca.area
156 | area_b = recb.area
157 | iou = overlap / (area_a + area_b - overlap + 1e-10)
158 | return iou
159 |
160 |
161 | def iou3d(box_a, box_b):
162 |     boxa_corners = np.array(BBox.box2corners2d(box_a))[:, :2]
163 | boxb_corners = np.array(BBox.box2corners2d(box_b))[:, :2]
164 | reca, recb = Polygon(boxa_corners), Polygon(boxb_corners)
165 | overlap_area = reca.intersection(recb).area
166 | iou_2d = overlap_area / (reca.area + recb.area - overlap_area)
167 |
168 | ha, hb = box_a.h, box_b.h
169 | za, zb = box_a.z, box_b.z
170 | overlap_height = max(0, min((za + ha / 2) - (zb - hb / 2), (zb + hb / 2) - (za - ha / 2)))
171 | overlap_volume = overlap_area * overlap_height
172 | union_volume = box_a.w * box_a.l * ha + box_b.w * box_b.l * hb - overlap_volume
173 | iou_3d = overlap_volume / (union_volume + 1e-5)
174 |
175 | return iou_2d, iou_3d
176 |
177 |
178 | def pc2world(ego_matrix, pcs):
179 | new_pcs = np.concatenate((pcs,
180 | np.ones(pcs.shape[0])[:, np.newaxis]),
181 | axis=1)
182 | new_pcs = ego_matrix @ new_pcs.T
183 | new_pcs = new_pcs.T[:, :3]
184 | return new_pcs
185 |
186 |
187 | def giou2d(box_a: BBox, box_b: BBox):
188 | boxa_corners = np.array(BBox.box2corners2d(box_a))
189 | boxb_corners = np.array(BBox.box2corners2d(box_b))
190 | reca, recb = Polygon(boxa_corners), Polygon(boxb_corners)
191 |
192 | # compute intersection and union
193 | I = reca.intersection(recb).area
194 | U = box_a.w * box_a.l + box_b.w * box_b.l - I
195 |
196 | # compute the convex area
197 | all_corners = np.vstack((boxa_corners, boxb_corners))
198 | C = ConvexHull(all_corners)
199 | convex_corners = all_corners[C.vertices]
200 | convex_area = PolyArea2D(convex_corners)
201 | C = convex_area
202 |
203 | # compute giou
204 | return I / U - (C - U) / C
205 |
206 |
207 | def giou3d(box_a: BBox, box_b: BBox):
208 | boxa_corners = np.array(BBox.box2corners2d(box_a))[:, :2]
209 | boxb_corners = np.array(BBox.box2corners2d(box_b))[:, :2]
210 | reca, recb = Polygon(boxa_corners), Polygon(boxb_corners)
211 | ha, hb = box_a.h, box_b.h
212 | za, zb = box_a.z, box_b.z
213 | overlap_height = max(0, min([(za + ha / 2) - (zb - hb / 2), (zb + hb / 2) - (za - ha / 2), ha, hb]))
214 | union_height = max([(za + ha / 2) - (zb - hb / 2), (zb + hb / 2) - (za - ha / 2), ha, hb])
215 |
216 | # compute intersection and union
217 | I = reca.intersection(recb).area * overlap_height
218 | U = box_a.w * box_a.l * ha + box_b.w * box_b.l * hb - I
219 |
220 | # compute the convex area
221 | all_corners = np.vstack((boxa_corners, boxb_corners))
222 | C = ConvexHull(all_corners)
223 | convex_corners = all_corners[C.vertices]
224 | convex_area = PolyArea2D(convex_corners)
225 | C = convex_area * union_height
226 |
227 | # compute giou
228 | giou = I / U - (C - U) / C
229 | return giou
230 |
231 |
232 | def PolyArea2D(pts):
233 | roll_pts = np.roll(pts, -1, axis=0)
234 | area = np.abs(np.sum((pts[:, 0] * roll_pts[:, 1] - pts[:, 1] * roll_pts[:, 0]))) * 0.5
235 | return area
236 |
237 |
238 | def back_step_det(det: BBox, velo, time_lag):
239 | result = BBox()
240 | BBox.copy_bbox(result, det)
241 | result.x -= (time_lag * velo[0])
242 | result.y -= (time_lag * velo[1])
243 | return result
244 |
245 |
246 | def diff_orientation_correction(diff):
247 | """
248 | return the angle diff = det - trk
249 | if angle diff > 90 or < -90, rotate trk and update the angle diff
250 | """
251 | if diff > np.pi / 2:
252 | diff -= np.pi
253 | if diff < -np.pi / 2:
254 | diff += np.pi
255 | return diff
256 |
257 |
258 | def m_distance(det, trk, trk_inv_innovation_matrix=None):
259 | det_array = BBox.bbox2array(det)[:7]
260 | trk_array = BBox.bbox2array(trk)[:7]
261 |
262 | diff = np.expand_dims(det_array - trk_array, axis=1)
263 | corrected_yaw_diff = diff_orientation_correction(diff[3])
264 | diff[3] = corrected_yaw_diff
265 |
266 | if trk_inv_innovation_matrix is not None:
267 | result = \
268 | np.sqrt(np.matmul(np.matmul(diff.T, trk_inv_innovation_matrix), diff)[0][0])
269 | else:
270 | result = np.sqrt(np.dot(diff.T, diff))
271 | return result
272 |
273 |
274 | def score_rectification(dets, gts):
275 | """ rectify the scores of detections according to their 3d iou with gts
276 | """
277 | result = deepcopy(dets)
278 |
279 | if len(gts) == 0:
280 | for i, _ in enumerate(dets):
281 | result[i].s = 0.0
282 | return result
283 |
284 | if len(dets) == 0:
285 | return result
286 |
287 | iou_matrix = np.zeros((len(dets), len(gts)))
288 | for i, d in enumerate(dets):
289 | for j, g in enumerate(gts):
290 | iou_matrix[i, j] = iou3d(d, g)[1]
291 | max_index = np.argmax(iou_matrix, axis=1)
292 | max_iou = np.max(iou_matrix, axis=1)
293 | index = list(reversed(sorted(range(len(dets)), key=lambda k:max_iou[k])))
294 |
295 | matched_gt = []
296 | for i in index:
297 | if max_iou[i] >= 0.1 and max_index[i] not in matched_gt:
298 | result[i].s = max_iou[i]
299 | matched_gt.append(max_index[i])
300 | elif max_iou[i] >= 0.1 and max_index[i] in matched_gt:
301 | result[i].s = 0.2
302 | else:
303 | result[i].s = 0.05
304 |
305 | return result
306 |
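307 | 
308 | # A minimal usage sketch (illustrative; the two boxes are made up). GIoU can go
309 | # negative for disjoint boxes, which is why the association thresholds in the
310 | # configs can be negative. Run as `python -m mot_3d.utils.geometry`.
311 | if __name__ == '__main__':
312 |     box_a = BBox.array2bbox([0.0, 0.0, 0.0, 0.0, 4.0, 2.0, 1.5, 1.0])
313 |     box_b = BBox.array2bbox([1.0, 0.0, 0.0, 0.0, 4.0, 2.0, 1.5, 1.0])
314 |     print(iou2d(box_a, box_b), iou3d(box_a, box_b)[1], giou3d(box_a, box_b))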
--------------------------------------------------------------------------------
/mot_3d/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | from .visualizer2d import Visualizer2D
--------------------------------------------------------------------------------
/mot_3d/visualization/visualizer2d.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt, numpy as np
2 | from ..data_protos import BBox
3 |
4 |
5 | class Visualizer2D:
6 | def __init__(self, name='', figsize=(8, 8)):
7 | self.figure = plt.figure(name, figsize=figsize)
8 | plt.axis('equal')
9 | self.COLOR_MAP = {
10 | 'gray': np.array([140, 140, 136]) / 256,
11 | 'light_blue': np.array([4, 157, 217]) / 256,
12 | 'red': np.array([191, 4, 54]) / 256,
13 | 'black': np.array([0, 0, 0]) / 256,
14 | 'purple': np.array([224, 133, 250]) / 256,
15 | 'dark_green': np.array([32, 64, 40]) / 256,
16 | 'green': np.array([77, 115, 67]) / 256
17 | }
18 |
19 | def show(self):
20 | plt.show()
21 |
22 | def close(self):
23 | plt.close()
24 |
25 | def save(self, path):
26 | plt.savefig(path)
27 |
28 | def handler_pc(self, pc, color='gray'):
29 | vis_pc = np.asarray(pc)
30 | plt.scatter(vis_pc[:, 0], vis_pc[:, 1], marker='o', color=self.COLOR_MAP[color], s=0.01)
31 |
32 | def handler_box(self, box: BBox, message: str='', color='red', linestyle='solid'):
33 | corners = np.array(BBox.box2corners2d(box))[:, :2]
34 | corners = np.concatenate([corners, corners[0:1, :2]])
35 | plt.plot(corners[:, 0], corners[:, 1], color=self.COLOR_MAP[color], linestyle=linestyle)
36 | corner_index = np.random.randint(0, 4, 1)
37 | plt.text(corners[corner_index, 0] - 1, corners[corner_index, 1] - 1, message, color=self.COLOR_MAP[color])
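38 | 
39 | 
40 | # A minimal usage sketch (illustrative). Run as
41 | # `python -m mot_3d.visualization.visualizer2d` so the relative import resolves.
42 | if __name__ == '__main__':
43 |     vis = Visualizer2D(name='demo', figsize=(8, 8))
44 |     box = BBox.array2bbox([0.0, 0.0, 0.0, 0.0, 4.0, 2.0, 1.5, 0.9])
45 |     vis.handler_box(box, message='id 0', color='light_blue')
46 |     vis.show()
47 |     vis.close()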
--------------------------------------------------------------------------------
/preprocessing/nuscenes_data/detection.py:
--------------------------------------------------------------------------------
1 | import os, argparse, numpy as np, json
2 | from tqdm import tqdm
3 |
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--raw_data_folder', type=str, default='../../../raw/nuscenes/data/sets/nuscenes/')
7 | parser.add_argument('--data_folder', type=str, default='../../../datasets/nuscenes/')
8 | parser.add_argument('--det_name', type=str, default='cp')
9 | parser.add_argument('--file_path', type=str, default='val.json')
10 | parser.add_argument('--velo', action='store_true', default=False)
11 | parser.add_argument('--mode', type=str, default='2hz', choices=['20hz', '2hz'])
12 | args = parser.parse_args()
13 |
14 |
15 | def get_sample_tokens(data_folder, mode):
16 | token_folder = os.path.join(data_folder, 'token_info')
17 | file_names = sorted(os.listdir(token_folder))
18 | result = dict()
19 | for i, file_name in enumerate(file_names):
20 | file_path = os.path.join(token_folder, file_name)
21 | scene_name = file_name.split('.')[0]
22 | tokens = json.load(open(file_path, 'r'))
23 |
24 | if mode == '2hz':
25 | result[scene_name] = tokens
26 | elif mode == '20hz':
27 | result[scene_name] = [t[0] for t in tokens]
28 | return result
29 |
30 |
31 | def sample_result2bbox_array(sample):
32 | trans, size, rot, score = \
33 |         sample['translation'], sample['size'], sample['rotation'], sample['detection_score']
34 | return trans + size + rot + [score]
35 |
36 |
37 | def main(det_name, file_path, detection_folder, data_folder, mode):
38 | # dealing with the paths
39 | detection_folder = os.path.join(detection_folder, det_name)
40 | output_folder = os.path.join(detection_folder, 'dets')
41 | os.makedirs(output_folder, exist_ok=True)
42 |
43 | # load the detection file
44 | print('LOADING RAW FILE')
45 | f = open(file_path, 'r')
46 | det_data = json.load(f)['results']
47 | f.close()
48 |
49 | # prepare the scene names and all the related tokens
50 | tokens = get_sample_tokens(data_folder, mode)
51 | scene_names = sorted(list(tokens.keys()))
52 | bboxes, inst_types, velos = dict(), dict(), dict()
53 | for scene_name in scene_names:
54 | frame_num = len(tokens[scene_name])
55 | bboxes[scene_name], inst_types[scene_name] = \
56 | [[] for i in range(frame_num)], [[] for i in range(frame_num)]
57 | if args.velo:
58 | velos[scene_name] = [[] for i in range(frame_num)]
59 |
60 | # enumerate through all the frames
61 | sample_keys = list(det_data.keys())
62 | print('PROCESSING...')
63 | pbar = tqdm(total=len(sample_keys))
64 | for sample_index, sample_key in enumerate(sample_keys):
65 | # find the corresponding scene and frame index
66 | scene_name, frame_index = None, None
67 | for scene_name in scene_names:
68 | if sample_key in tokens[scene_name]:
69 | frame_index = tokens[scene_name].index(sample_key)
70 | break
71 |
72 | # extract the bboxes and types
73 | sample_results = det_data[sample_key]
74 | for sample in sample_results:
75 | bbox, inst_type = sample_result2bbox_array(sample), sample['detection_name']
76 | inst_velo = sample['velocity']
77 | bboxes[scene_name][frame_index] += [bbox]
78 | inst_types[scene_name][frame_index] += [inst_type]
79 |
80 | if args.velo:
81 | velos[scene_name][frame_index] += [inst_velo]
82 | pbar.update(1)
83 | pbar.close()
84 |
85 | # save the results
86 | print('SAVING...')
87 | pbar = tqdm(total=len(scene_names))
88 | for scene_name in scene_names:
89 | if args.velo:
90 | np.savez_compressed(os.path.join(output_folder, '{:}.npz'.format(scene_name)),
91 | bboxes=bboxes[scene_name], types=inst_types[scene_name], velos=velos[scene_name])
92 | else:
93 | np.savez_compressed(os.path.join(output_folder, '{:}.npz'.format(scene_name)),
94 | bboxes=bboxes[scene_name], types=inst_types[scene_name])
95 | pbar.update(1)
96 | pbar.close()
97 | return
98 |
99 |
100 | if __name__ == '__main__':
101 | detection_folder = os.path.join(args.data_folder, 'detection')
102 | os.makedirs(detection_folder, exist_ok=True)
103 |
104 | main(args.det_name, args.file_path, detection_folder, args.data_folder, args.mode)
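105 | 
106 | 
107 | # Note on the output format (illustrative; the scene name below is hypothetical).
108 | # Each scene becomes <scene_name>.npz holding ragged per-frame lists, so reading
109 | # them back requires allow_pickle=True:
110 | #
111 | #     data = np.load('dets/scene-0001.npz', allow_pickle=True)
112 | #     bboxes, types = data['bboxes'], data['types']
113 | #
114 | # Per sample_result2bbox_array above, each bbox is
115 | # [tx, ty, tz, w, l, h, qw, qx, qy, qz, score] (translation, size, quaternion, score).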
--------------------------------------------------------------------------------
/preprocessing/nuscenes_data/ego_pose.py:
--------------------------------------------------------------------------------
1 | import os, numpy as np, nuscenes, argparse
2 | from nuscenes.nuscenes import NuScenes
3 | from nuscenes.utils import splits
4 | from copy import deepcopy
5 | from tqdm import tqdm
6 |
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--raw_data_folder', type=str, default='../../../raw/nuscenes/data/sets/nuscenes/')
10 | parser.add_argument('--data_folder', type=str, default='../../../datasets/nuscenes/')
11 | parser.add_argument('--mode', type=str, default='2hz', choices=['20hz', '2hz'])
12 | args = parser.parse_args()
13 |
14 |
15 | def main(nusc, scene_names, root_path, ego_folder, mode):
16 | pbar = tqdm(total=len(scene_names))
17 | for scene_index, scene_info in enumerate(nusc.scene):
18 | scene_name = scene_info['name']
19 | if scene_name not in scene_names:
20 | continue
21 | first_sample_token = scene_info['first_sample_token']
22 | last_sample_token = scene_info['last_sample_token']
23 | frame_data = nusc.get('sample', first_sample_token)
24 |         if mode == '20hz':
25 |             cur_sample_token = frame_data['data']['LIDAR_TOP']
26 |         elif mode == '2hz':
27 | cur_sample_token = deepcopy(first_sample_token)
28 | frame_index = 0
29 | ego_data = dict()
30 | while True:
31 | if mode == '2hz':
32 | frame_data = nusc.get('sample', cur_sample_token)
33 | lidar_token = frame_data['data']['LIDAR_TOP']
34 | lidar_data = nusc.get('sample_data', lidar_token)
35 | ego_token = lidar_data['ego_pose_token']
36 | ego_pose = nusc.get('ego_pose', ego_token)
37 | elif mode == '20hz':
38 | frame_data = nusc.get('sample_data', cur_sample_token)
39 | ego_token = frame_data['ego_pose_token']
40 | ego_pose = nusc.get('ego_pose', ego_token)
41 |
42 | # translation + rotation
43 | ego_data[str(frame_index)] = ego_pose['translation'] + ego_pose['rotation']
44 |
45 | # clean up and prepare for the next
46 | cur_sample_token = frame_data['next']
47 | if cur_sample_token == '':
48 | break
49 | frame_index += 1
50 |
51 | np.savez_compressed(os.path.join(ego_folder, '{:}.npz'.format(scene_name)), **ego_data)
52 | pbar.update(1)
53 | pbar.close()
54 | return
55 |
56 |
57 | if __name__ == '__main__':
58 | print('ego info')
59 | ego_folder = os.path.join(args.data_folder, 'ego_info')
60 | os.makedirs(ego_folder, exist_ok=True)
61 |
62 | val_scene_names = splits.create_splits_scenes()['val']
63 | nusc = NuScenes(version='v1.0-trainval', dataroot=args.raw_data_folder, verbose=True)
64 | main(nusc, val_scene_names, args.raw_data_folder, ego_folder, args.mode)
65 |
--------------------------------------------------------------------------------
/preprocessing/nuscenes_data/gt_info.py:
--------------------------------------------------------------------------------
1 | import os, numpy as np, nuscenes, argparse
2 | from nuscenes.nuscenes import NuScenes
3 | from nuscenes.utils import splits
4 | from copy import deepcopy
5 | from tqdm import tqdm
6 | import pdb
7 |
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--raw_data_folder', type=str, default='../../../raw/nuscenes/data/sets/nuscenes/')
11 | parser.add_argument('--data_folder', type=str, default='../../../datasets/nuscenes/')
12 | parser.add_argument('--mode', type=str, default='2hz', choices=['20hz', '2hz'])
13 | args = parser.parse_args()
14 |
15 |
16 | def instance_info2bbox_array(info):
17 | translation = info.center.tolist()
18 | size = info.wlh.tolist()
19 | rotation = info.orientation.q.tolist()
20 | return translation + size + rotation
21 |
22 |
23 | def main(nusc, scene_names, root_path, gt_folder):
24 | pbar = tqdm(total=len(scene_names))
25 | for scene_index, scene_info in enumerate(nusc.scene):
26 | scene_name = scene_info['name']
27 | if scene_name not in scene_names:
28 | continue
29 |
30 | first_sample_token = scene_info['first_sample_token']
31 | last_sample_token = scene_info['last_sample_token']
32 | frame_data = nusc.get('sample', first_sample_token)
33 | if args.mode == '20hz':
34 | cur_sample_token = frame_data['data']['LIDAR_TOP']
35 | elif args.mode == '2hz':
36 | cur_sample_token = deepcopy(first_sample_token)
37 |
38 | frame_index = 0
39 | IDS, inst_types, bboxes = list(), list(), list()
40 | while True:
41 | frame_ids, frame_types, frame_bboxes = list(), list(), list()
42 | if args.mode == '2hz':
43 | frame_data = nusc.get('sample', cur_sample_token)
44 | lidar_token = frame_data['data']['LIDAR_TOP']
45 | instances = nusc.get_boxes(lidar_token)
46 | for inst in instances:
47 | frame_ids.append(inst.token)
48 | frame_types.append(inst.name)
49 | frame_bboxes.append(instance_info2bbox_array(inst))
50 |
51 | elif args.mode == '20hz':
52 | frame_data = nusc.get('sample_data', cur_sample_token)
53 | lidar_data = nusc.get('sample_data', cur_sample_token)
54 | instances = nusc.get_boxes(lidar_data['token'])
55 | for inst in instances:
56 | frame_ids.append(inst.token)
57 | frame_types.append(inst.name)
58 | frame_bboxes.append(instance_info2bbox_array(inst))
59 |
60 | IDS.append(frame_ids)
61 | inst_types.append(frame_types)
62 | bboxes.append(frame_bboxes)
63 |
64 | # clean up and prepare for the next
65 | cur_sample_token = frame_data['next']
66 | if cur_sample_token == '':
67 | break
68 |
69 | np.savez_compressed(os.path.join(gt_folder, '{:}.npz'.format(scene_name)),
70 | ids=IDS, types=inst_types, bboxes=bboxes)
71 | pbar.update(1)
72 | pbar.close()
73 | return
74 |
75 |
76 | if __name__ == '__main__':
77 | print('gt info')
78 | gt_folder = os.path.join(args.data_folder, 'gt_info')
79 | os.makedirs(gt_folder, exist_ok=True)
80 |
81 | val_scene_names = splits.create_splits_scenes()['val']
82 | nusc = NuScenes(version='v1.0-trainval', dataroot=args.raw_data_folder, verbose=True)
83 | main(nusc, val_scene_names, args.raw_data_folder, gt_folder)
84 |
--------------------------------------------------------------------------------
/preprocessing/nuscenes_data/nuscenes_preprocess.sh:
--------------------------------------------------------------------------------
1 | raw_data_dir=$1
2 | data_dir_2hz=$2
3 | data_dir_20hz=$3
4 |
5 | # token information
6 | python token_info.py --raw_data_folder $raw_data_dir --data_folder $data_dir_2hz --mode 2hz
7 | python token_info.py --raw_data_folder $raw_data_dir --data_folder $data_dir_20hz --mode 20hz
8 |
9 | # time stamp information
10 | python time_stamp.py --raw_data_folder $raw_data_dir --data_folder $data_dir_2hz --mode 2hz
11 | python time_stamp.py --raw_data_folder $raw_data_dir --data_folder $data_dir_20hz --mode 20hz
12 |
13 | # sensor calibration information
14 | python sensor_calibration.py --raw_data_folder $raw_data_dir --data_folder $data_dir_2hz --mode 2hz
15 | python sensor_calibration.py --raw_data_folder $raw_data_dir --data_folder $data_dir_20hz --mode 20hz
16 |
17 | # ego pose
18 | python ego_pose.py --raw_data_folder $raw_data_dir --data_folder $data_dir_2hz --mode 2hz
19 | python ego_pose.py --raw_data_folder $raw_data_dir --data_folder $data_dir_20hz --mode 20hz
20 |
21 | # gt information
22 | python gt_info.py --raw_data_folder $raw_data_dir --data_folder $data_dir_2hz --mode 2hz
23 | python gt_info.py --raw_data_folder $raw_data_dir --data_folder $data_dir_20hz --mode 20hz
24 |
25 | # point cloud, useful for visualization
26 | python raw_pc.py --raw_data_folder $raw_data_dir --data_folder $data_dir_2hz --mode 2hz
27 | python raw_pc.py --raw_data_folder $raw_data_dir --data_folder $data_dir_20hz --mode 20hz
28 |
29 |
--------------------------------------------------------------------------------
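
Usage sketch for the script above (output paths illustrative): the three positional arguments are the raw nuScenes root, the 2hz output folder, and the 20hz output folder, e.g.

    bash nuscenes_preprocess.sh ../../../raw/nuscenes/data/sets/nuscenes/ ../../../datasets/nuscenes_2hz/ ../../../datasets/nuscenes_20hz/
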
/preprocessing/nuscenes_data/raw_pc.py:
--------------------------------------------------------------------------------
1 | import os, numpy as np, nuscenes, argparse
2 | from nuscenes.nuscenes import NuScenes
3 | from nuscenes.utils import splits
4 | from copy import deepcopy
5 | import multiprocessing
7 |
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--raw_data_folder', type=str, default='../../../raw/nuscenes/data/sets/nuscenes/')
11 | parser.add_argument('--data_folder', type=str, default='../../../datasets/nuscenes/')
12 | parser.add_argument('--mode', type=str, default='2hz', choices=['20hz', '2hz'])
13 | parser.add_argument('--process', type=int, default=1)
14 | args = parser.parse_args()
15 |
16 |
17 | def load_pc(path):
18 | pc = np.fromfile(path, dtype=np.float32)
19 | pc = pc.reshape((-1, 5))[:, :4]
20 | return pc
21 |
22 |
23 | def main(nusc, scene_names, root_path, pc_folder, mode, pid=0, process=1):
24 | for scene_index, scene_info in enumerate(nusc.scene):
25 | if scene_index % process != pid:
26 | continue
27 | scene_name = scene_info['name']
28 | if scene_name not in scene_names:
29 | continue
30 | print('PROCESSING {:} / {:}'.format(scene_index + 1, len(nusc.scene)))
31 |
32 | first_sample_token = scene_info['first_sample_token']
33 | frame_data = nusc.get('sample', first_sample_token)
34 | if mode == '20hz':
35 | cur_sample_token = frame_data['data']['LIDAR_TOP']
36 | elif mode == '2hz':
37 | cur_sample_token = deepcopy(first_sample_token)
38 | frame_index = 0
39 | pc_data = dict()
40 | while True:
41 | # find the path to lidar data
42 | if mode == '2hz':
43 | lidar_data = nusc.get('sample', cur_sample_token)
44 | lidar_path = nusc.get_sample_data_path(lidar_data['data']['LIDAR_TOP'])
45 |             elif mode == '20hz':
46 | lidar_data = nusc.get('sample_data', cur_sample_token)
47 | lidar_path = lidar_data['filename']
48 |
49 | # load and store the data
50 | point_cloud = np.fromfile(os.path.join(root_path, lidar_path), dtype=np.float32)
51 | point_cloud = np.reshape(point_cloud, (-1, 5))[:, :4]
52 | pc_data[str(frame_index)] = point_cloud
53 |
54 | # clean up and prepare for the next
55 | cur_sample_token = lidar_data['next']
56 | if cur_sample_token == '':
57 | break
58 | frame_index += 1
59 |
60 | if frame_index % 10 == 0:
61 | print('PROCESSING ', scene_index, ' , ', frame_index)
62 |
63 | np.savez_compressed(os.path.join(pc_folder, '{:}.npz'.format(scene_name)), **pc_data)
64 | return
65 |
66 |
67 | if __name__ == '__main__':
68 |     print('point cloud')
69 |     pc_folder = os.path.join(args.data_folder, 'pc', 'raw_pc')
70 |     os.makedirs(pc_folder, exist_ok=True)
71 | 
72 |     val_scene_names = splits.create_splits_scenes()['val']
73 |     nusc = NuScenes(version='v1.0-trainval', dataroot=args.raw_data_folder, verbose=True)
74 |     # honor the --process flag instead of always running single-process
75 |     if args.process > 1:
76 |         pool = multiprocessing.Pool(args.process)
77 |         for pid in range(args.process):
78 |             pool.apply_async(main, args=(nusc, val_scene_names, args.raw_data_folder, pc_folder, args.mode, pid, args.process))
79 |         pool.close()
80 |         pool.join()
81 |     else:
82 |         main(nusc, val_scene_names, args.raw_data_folder, pc_folder, args.mode, 0, 1)
75 |
--------------------------------------------------------------------------------
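
The npz written above maps str(frame_index) to an (N, 4) float32 array of x, y, z, intensity. A minimal reading sketch (scene name illustrative):

    import numpy as np
    pcs = np.load('pc/raw_pc/scene-0103.npz')
    frame0 = pcs['0']        # (N, 4) array for the first frame
    xyz = frame0[:, :3]      # drop the intensity channel
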
/preprocessing/nuscenes_data/sensor_calibration.py:
--------------------------------------------------------------------------------
1 | import os, numpy as np, nuscenes, argparse
2 | from nuscenes.nuscenes import NuScenes
3 | from nuscenes.utils import splits
4 | from copy import deepcopy
5 | from tqdm import tqdm
6 |
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--raw_data_folder', type=str, default='../../../raw/nuscenes/data/sets/nuscenes/')
10 | parser.add_argument('--data_folder', type=str, default='../../../datasets/nuscenes/')
11 | parser.add_argument('--mode', type=str, default='2hz', choices=['20hz', '2hz'])
12 | args = parser.parse_args()
13 |
14 |
15 | def main(nusc, scene_names, root_path, calib_folder, mode):
16 | pbar = tqdm(total=len(scene_names))
17 | for scene_index, scene_info in enumerate(nusc.scene):
18 | scene_name = scene_info['name']
19 | if scene_name not in scene_names:
20 | continue
21 |
22 | first_sample_token = scene_info['first_sample_token']
23 | last_sample_token = scene_info['last_sample_token']
24 | frame_data = nusc.get('sample', first_sample_token)
25 | if mode == '20hz':
26 | cur_sample_token = frame_data['data']['LIDAR_TOP']
27 | elif mode == '2hz':
28 | cur_sample_token = deepcopy(first_sample_token)
29 | frame_index = 0
30 | calib_data = dict()
31 | while True:
32 | if mode == '2hz':
33 | frame_data = nusc.get('sample', cur_sample_token)
34 | lidar_token = frame_data['data']['LIDAR_TOP']
35 | lidar_data = nusc.get('sample_data', lidar_token)
36 | calib_token = lidar_data['calibrated_sensor_token']
37 | calib_pose = nusc.get('calibrated_sensor', calib_token)
38 | elif mode == '20hz':
39 | frame_data = nusc.get('sample_data', cur_sample_token)
40 | calib_token = frame_data['calibrated_sensor_token']
41 | calib_pose = nusc.get('calibrated_sensor', calib_token)
42 |
43 | # translation + rotation
44 | calib_data[str(frame_index)] = calib_pose['translation'] + calib_pose['rotation']
45 |
46 | # clean up and prepare for the next
47 | cur_sample_token = frame_data['next']
48 | if cur_sample_token == '':
49 | break
50 | frame_index += 1
51 |
52 | np.savez_compressed(os.path.join(calib_folder, '{:}.npz'.format(scene_name)), **calib_data)
53 | pbar.update(1)
54 | pbar.close()
55 | return
56 |
57 |
58 | if __name__ == '__main__':
59 | print('sensor calib')
60 | calib_folder = os.path.join(args.data_folder, 'calib_info')
61 | os.makedirs(calib_folder, exist_ok=True)
62 |
63 | val_scene_names = splits.create_splits_scenes()['val']
64 | nusc = NuScenes(version='v1.0-trainval', dataroot=args.raw_data_folder, verbose=True)
65 | main(nusc, val_scene_names, args.raw_data_folder, calib_folder, args.mode)
66 |
--------------------------------------------------------------------------------
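
Each per-frame entry saved above is a 7-vector: translation (x, y, z) followed by the rotation quaternion (w, x, y, z) from the lidar's calibrated_sensor record. A hedged decoding sketch using pyquaternion (scene name illustrative):

    import numpy as np
    from pyquaternion import Quaternion
    calib = np.load('calib_info/scene-0103.npz')
    vec = calib['0']
    translation, rotation = vec[:3], Quaternion(vec[3:])
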
/preprocessing/nuscenes_data/time_stamp.py:
--------------------------------------------------------------------------------
1 | import os, numpy as np, nuscenes, argparse, json
2 | from nuscenes.nuscenes import NuScenes
3 | from nuscenes.utils import splits
4 | from copy import deepcopy
5 | from tqdm import tqdm
6 |
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--raw_data_folder', type=str, default='../../../raw/nuscenes/data/sets/nuscenes/')
10 | parser.add_argument('--data_folder', type=str, default='../../../datasets/nuscenes/')
11 | parser.add_argument('--mode', type=str, default='2hz', choices=['20hz', '2hz'])
12 | args = parser.parse_args()
13 |
14 |
15 | def main(nusc, scene_names, root_path, ts_folder, mode):
16 | pbar = tqdm(total=len(scene_names))
17 | for scene_index, scene_info in enumerate(nusc.scene):
18 | scene_name = scene_info['name']
19 | if scene_name not in scene_names:
20 | continue
21 |
22 | first_sample_token = scene_info['first_sample_token']
23 | last_sample_token = scene_info['last_sample_token']
24 | frame_data = nusc.get('sample', first_sample_token)
25 | if mode == '20hz':
26 | cur_sample_token = frame_data['data']['LIDAR_TOP']
27 | elif mode == '2hz':
28 | cur_sample_token = deepcopy(first_sample_token)
29 | time_stamps = list()
30 |
31 | while True:
32 | if mode == '2hz':
33 | frame_data = nusc.get('sample', cur_sample_token)
34 | time_stamps.append(frame_data['timestamp'])
35 | elif mode == '20hz':
36 | frame_data = nusc.get('sample_data', cur_sample_token)
37 | # time stamp and if key frame
38 | time_stamps.append((frame_data['timestamp'], frame_data['is_key_frame']))
39 |
40 | # clean up and prepare for the next
41 | cur_sample_token = frame_data['next']
42 | if cur_sample_token == '':
43 | break
44 | f = open(os.path.join(ts_folder, '{:}.json'.format(scene_name)), 'w')
45 | json.dump(time_stamps, f)
46 | f.close()
47 | pbar.update(1)
48 | pbar.close()
49 | return
50 |
51 |
52 | if __name__ == '__main__':
53 | print('time stamp')
54 | ts_folder = os.path.join(args.data_folder, 'ts_info')
55 | os.makedirs(ts_folder, exist_ok=True)
56 |
57 | val_scene_names = splits.create_splits_scenes()['val']
58 | nusc = NuScenes(version='v1.0-trainval', dataroot=args.raw_data_folder, verbose=True)
59 | main(nusc, val_scene_names, args.raw_data_folder, ts_folder, args.mode)
60 |
--------------------------------------------------------------------------------
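
nuScenes timestamps are integer microseconds; in 20hz mode each JSON entry above is a (timestamp, is_key_frame) pair, serialized as a list. A small sketch recovering inter-frame gaps in seconds, handling both modes:

    import json
    ts = json.load(open('ts_info/scene-0103.json'))
    stamps = [t[0] if isinstance(t, list) else t for t in ts]
    gaps = [(b - a) / 1e6 for a, b in zip(stamps, stamps[1:])]
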
/preprocessing/nuscenes_data/token_info.py:
--------------------------------------------------------------------------------
1 | import os, numpy as np, nuscenes, argparse, json
2 | from nuscenes.nuscenes import NuScenes
3 | from nuscenes.utils import splits
4 | from copy import deepcopy
5 | from tqdm import tqdm
6 |
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--raw_data_folder', type=str, default='../../../raw/nuscenes/data/sets/nuscenes/')
10 | parser.add_argument('--data_folder', type=str, default='../../../datasets/nuscenes/')
11 | parser.add_argument('--mode', type=str, default='2hz', choices=['20hz', '2hz'])
12 | args = parser.parse_args()
13 |
14 |
15 | def set_selected_or_not(frame_tokens):
16 | """ under the 20hz setting,
17 | we have to set whether to use a certain frame
18 | 1. select at the interval of 1 frames
19 | 2. if meet key frame, reset the counter
20 | """
21 | counter = -1
22 | selected = list()
23 | frame_num = len(frame_tokens)
24 | for _, tokens in enumerate(frame_tokens):
25 | is_key_frame = tokens[1]
26 | counter += 1
27 | if is_key_frame:
28 | selected.append(True)
29 | counter = 0
30 | continue
31 | else:
32 | if counter % 2 == 0:
33 | selected.append(True)
34 | else:
35 | selected.append(False)
36 | result_tokens = [(list(frame_tokens[i]) + [selected[i]]) for i in range(frame_num)]
37 | return result_tokens
38 |
39 |
40 | def main(nusc, scene_names, root_path, token_folder, mode):
41 | pbar = tqdm(total=len(scene_names))
42 | for scene_index, scene_info in enumerate(nusc.scene):
43 | scene_name = scene_info['name']
44 | if scene_name not in scene_names:
45 | continue
46 |
47 | first_sample_token = scene_info['first_sample_token']
48 | last_sample_token = scene_info['last_sample_token']
49 | frame_data = nusc.get('sample', first_sample_token)
50 | if mode == '20hz':
51 | cur_sample_token = frame_data['data']['LIDAR_TOP']
52 | elif mode == '2hz':
53 | cur_sample_token = deepcopy(first_sample_token)
54 | frame_tokens = list()
55 |
56 | while True:
57 | # find the path to lidar data
58 | if mode == '2hz':
59 | frame_data = nusc.get('sample', cur_sample_token)
60 | frame_tokens.append(cur_sample_token)
61 | elif mode == '20hz':
62 | frame_data = nusc.get('sample_data', cur_sample_token)
63 | frame_tokens.append((cur_sample_token, frame_data['is_key_frame'], frame_data['sample_token']))
64 |
65 | # clean up and prepare for the next
66 | cur_sample_token = frame_data['next']
67 | if cur_sample_token == '':
68 | break
69 |
70 | if mode == '20hz':
71 | frame_tokens = set_selected_or_not(frame_tokens)
72 | f = open(os.path.join(token_folder, '{:}.json'.format(scene_name)), 'w')
73 | json.dump(frame_tokens, f)
74 | f.close()
75 |
76 | pbar.update(1)
77 | pbar.close()
78 | return
79 |
80 |
81 | if __name__ == '__main__':
82 | print('token info')
83 | os.makedirs(args.data_folder, exist_ok=True)
84 |
85 | token_folder = os.path.join(args.data_folder, 'token_info')
86 | os.makedirs(token_folder, exist_ok=True)
87 |
88 | val_scene_names = splits.create_splits_scenes()['val']
89 | nusc = NuScenes(version='v1.0-trainval', dataroot=args.raw_data_folder, verbose=True)
90 | main(nusc, val_scene_names, args.raw_data_folder, token_folder, args.mode)
91 |
--------------------------------------------------------------------------------
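
In 2hz mode the JSON above is a flat list of sample tokens; in 20hz mode each entry is [sample_data_token, is_key_frame, sample_token, selected]. A sketch filtering down to the selected (roughly 10hz) subset:

    import json
    tokens = json.load(open('token_info/scene-0103.json'))
    selected = [t[0] for t in tokens if isinstance(t, list) and t[-1]]
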
/preprocessing/waymo_data/detection.py:
--------------------------------------------------------------------------------
1 | """ Extract the detections from .bin files
2 | Each sequence is a .npz file containing three fields: bboxes, types, ids.
3 | bboxes, types, and ids follow the same format:
4 | [[bboxes in frame 0],
5 | [bboxes in frame 1],
6 | ...
7 | [bboxes in the last frame]]
8 | """
9 | import os, numpy as np, argparse, json
10 | import tensorflow.compat.v1 as tf
11 | from tqdm import tqdm
12 | from google.protobuf.descriptor import FieldDescriptor as FD
13 | tf.enable_eager_execution()
14 | from waymo_open_dataset import label_pb2
15 | from waymo_open_dataset.protos import metrics_pb2
16 |
17 |
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--name', type=str, default='public')
20 | parser.add_argument('--file_path', type=str, default='validation.bin')
21 | parser.add_argument('--data_folder', type=str, default='../../../datasets/waymo/mot/')
22 | parser.add_argument('--metadata', action='store_true', default=False)
23 | parser.add_argument('--id', action='store_true', default=False)
24 | args = parser.parse_args()
25 |
26 |
27 | def bbox_dict2array(box_dict):
28 | """transform box dict in waymo_open_format to array
29 | Args:
30 | box_dict ([dict]): waymo_open_dataset formatted bbox
31 | """
32 | result = np.array([
33 | box_dict['center_x'],
34 | box_dict['center_y'],
35 | box_dict['center_z'],
36 | box_dict['heading'],
37 | box_dict['length'],
38 | box_dict['width'],
39 | box_dict['height'],
40 | box_dict['score']
41 | ])
42 | return result
43 |
44 |
45 | def str_list_to_int(lst):
46 |     result = []
47 |     for t in lst:
48 |         try:
49 |             result.append(int(t))
50 |         except (TypeError, ValueError):
51 |             # skip keys that are not valid integers (e.g. 'None')
52 |             continue
53 |     return result
54 |
55 |
56 | def main(name, data_folder, det_folder, file_path, out_folder):
57 | # load timestamp and segment names
58 | ts_info_folder = os.path.join(data_folder, 'ts_info')
59 | ts_files = os.listdir(ts_info_folder)
60 | ts_info = dict()
61 | segment_name_list = list()
62 | for ts_file_name in ts_files:
63 | ts = json.load(open(os.path.join(ts_info_folder, ts_file_name), 'r'))
64 | segment_name = ts_file_name.split('.')[0]
65 | ts_info[segment_name] = ts
66 | segment_name_list.append(segment_name)
67 |
68 | # load detection file
69 | det_folder = os.path.join(det_folder, name)
70 | f = open(os.path.join(det_folder, file_path), 'rb')
71 | objects = metrics_pb2.Objects()
72 | objects.ParseFromString(f.read())
73 | f.close()
74 |
75 | # parse and aggregate detections
76 | objects = objects.objects
77 | object_num = len(objects)
78 |
79 | result_bbox, result_type, result_velo, result_accel, result_ids = dict(), dict(), dict(), dict(), dict()
80 | for seg_name in ts_info.keys():
81 | result_bbox[seg_name] = dict()
82 | result_type[seg_name] = dict()
83 | result_velo[seg_name] = dict()
84 | result_accel[seg_name] = dict()
85 | result_ids[seg_name] = dict()
86 |
87 | print('Converting')
88 | pbar = tqdm(total=object_num)
89 | for _i in range(object_num):
90 | instance = objects[_i]
91 | segment = instance.context_name
92 | time_stamp = instance.frame_timestamp_micros
93 |
94 | box = instance.object.box
95 | bbox_dict = {
96 | 'center_x': box.center_x,
97 | 'center_y': box.center_y,
98 | 'center_z': box.center_z,
99 | 'width': box.width,
100 | 'length': box.length,
101 | 'height': box.height,
102 | 'heading': box.heading,
103 | 'score': instance.score
104 | }
105 | bbox_array = bbox_dict2array(bbox_dict)
106 | obj_type = instance.object.type
107 |
108 | if args.metadata:
109 | meta_data = instance.object.metadata
110 | velo = (meta_data.speed_x, meta_data.speed_y)
111 | accel = (meta_data.accel_x, meta_data.accel_y)
112 |
113 |         if args.id:
114 |             obj_id = instance.object.id
115 |
116 | val_index = None
117 | for _j in range(len(segment_name_list)):
118 | if segment in segment_name_list[_j]:
119 | val_index = _j
120 | break
121 | segment_name = segment_name_list[val_index]
122 |
123 | frame_number = None
124 | for _j in range(len(ts_info[segment_name])):
125 | if ts_info[segment_name][_j] == time_stamp:
126 | frame_number = _j
127 | break
128 |
129 | if str(frame_number) not in result_bbox[segment_name].keys():
130 | result_bbox[segment_name][str(frame_number)] = list()
131 | result_type[segment_name][str(frame_number)] = list()
132 | if args.metadata:
133 | result_velo[segment_name][str(frame_number)] = list()
134 | result_accel[segment_name][str(frame_number)] = list()
135 | if args.id:
136 | result_ids[segment_name][str(frame_number)] = list()
137 |
138 | result_bbox[segment_name][str(frame_number)].append(bbox_array)
139 | result_type[segment_name][str(frame_number)].append(obj_type)
140 | if args.metadata:
141 | result_velo[segment_name][str(frame_number)].append(velo)
142 | result_accel[segment_name][str(frame_number)].append(accel)
143 |         if args.id:
144 |             result_ids[segment_name][str(frame_number)].append(obj_id)
145 |
146 | pbar.update(1)
147 | pbar.close()
148 |
149 | print('Saving')
150 | pbar = tqdm(total=len(segment_name_list))
151 | # store in files
152 | for _i, segment_name in enumerate(segment_name_list):
153 | dets = result_bbox[segment_name]
154 | types = result_type[segment_name]
155 | if args.metadata:
156 | velos = result_velo[segment_name]
157 | accels = result_accel[segment_name]
158 | if args.id:
159 | ids = result_ids[segment_name]
160 |
161 | frame_keys = sorted(str_list_to_int(dets.keys()))
162 | max_frame = max(frame_keys)
163 | bboxes = list()
164 | obj_types = list()
165 | velocities, accelerations, id_names = list(), list(), list()
166 | for key in range(max_frame + 1):
167 | if str(key) in dets.keys():
168 | bboxes.append(dets[str(key)])
169 | obj_types.append(types[str(key)])
170 | if args.metadata:
171 | velocities.append(velos[str(key)])
172 | accelerations.append(accels[str(key)])
173 | if args.id:
174 | id_names.append(ids[str(key)])
175 | else:
176 | bboxes.append([])
177 | obj_types.append([])
178 | if args.metadata:
179 | velocities.append([])
180 | accelerations.append([])
181 | if args.id:
182 | id_names.append([])
183 | result = {'bboxes': bboxes, 'types': obj_types}
184 | if args.metadata:
185 | result['velos'] = velocities
186 | result['accels'] = accelerations
187 | if args.id:
188 | result['ids'] = id_names
189 |
190 | np.savez_compressed(os.path.join(out_folder, "{:}.npz".format(segment_name)), **result)
191 | pbar.update(1)
192 | pbar.close()
193 |
194 |
195 | if __name__ == '__main__':
196 | det_folder = os.path.join(args.data_folder, 'detection')
197 | os.makedirs(det_folder, exist_ok=True)
198 | output_folder = os.path.join(det_folder, args.name, 'dets')
199 | os.makedirs(output_folder, exist_ok=True)
200 |
201 | main(args.name, args.data_folder, det_folder, args.file_path, output_folder)
202 |
--------------------------------------------------------------------------------
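
The per-sequence npz written above mirrors the gt format: 'bboxes' is a list over frames, and each detection is an 8-vector [center_x, center_y, center_z, heading, length, width, height, score]. Loading sketch (segment name illustrative):

    import numpy as np
    dets = np.load('dets/segment-xxxx.npz', allow_pickle=True)
    frame0 = dets['bboxes'][0]          # list of 8-element arrays
    scores = [b[-1] for b in frame0]    # detection confidences
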
/preprocessing/waymo_data/ego_info.py:
--------------------------------------------------------------------------------
1 | """ Extract the ego location information from tfrecords
2 | Output file format: dict compressed in .npz files
3 | {
4 |     str(frame_num): ego pose (4 x 4 matrix)
5 | }
6 | """
7 | import argparse
8 | import numpy as np
9 | import os
10 | from google.protobuf.descriptor import FieldDescriptor as FD
11 | import tensorflow.compat.v1 as tf
12 | tf.enable_eager_execution()
13 | import multiprocessing
14 | from waymo_open_dataset import dataset_pb2 as open_dataset
15 |
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--raw_data_folder', type=str, default='../../../datasets/waymo/validation/',
19 | help='location of tfrecords')
20 | parser.add_argument('--data_folder', type=str, default='../../../datasets/waymo/mot/',
21 | help='output folder')
22 | parser.add_argument('--process', type=int, default=1, help='use multiprocessing for acceleration')
23 | args = parser.parse_args()
24 |
25 |
26 | def pb2dict(obj):
27 | """
28 |     Takes a ProtoBuf Message obj and converts it to a dict.
29 | """
30 | adict = {}
31 | # if not obj.IsInitialized():
32 | # return None
33 | for field in obj.DESCRIPTOR.fields:
34 | if not getattr(obj, field.name):
35 | continue
36 | if not field.label == FD.LABEL_REPEATED:
37 | if not field.type == FD.TYPE_MESSAGE:
38 | adict[field.name] = getattr(obj, field.name)
39 | else:
40 | value = pb2dict(getattr(obj, field.name))
41 | if value:
42 | adict[field.name] = value
43 | else:
44 | if field.type == FD.TYPE_MESSAGE:
45 | adict[field.name] = [pb2dict(v) for v in getattr(obj, field.name)]
46 | else:
47 | adict[field.name] = [v for v in getattr(obj, field.name)]
48 | return adict
49 |
50 |
51 | def main(raw_data_folder, data_folder, process_num=1, token=0):
52 | """ The process with index "token" process the ego pose information.
53 | """
54 | tf_records = os.listdir(raw_data_folder)
55 | tf_records = [x for x in tf_records if 'tfrecord' in x]
56 | tf_records = sorted(tf_records)
57 | for record_index, tf_record_name in enumerate(tf_records):
58 | if record_index % process_num != token:
59 | continue
60 | print('starting for ego info ', record_index + 1, ' / ', len(tf_records), ' ', tf_record_name)
61 | FILE_NAME = os.path.join(raw_data_folder, tf_record_name)
62 | dataset = tf.data.TFRecordDataset(FILE_NAME, compression_type='')
63 | segment_name = tf_record_name.split('.')[0]
64 |
65 | frame_num = 0
66 | ego_infos = dict()
67 |
68 | for data in dataset:
69 | frame = open_dataset.Frame()
70 | frame.ParseFromString(bytearray(data.numpy()))
71 |
72 | ego_info = np.reshape(np.array(frame.pose.transform), [4, 4])
73 | ego_infos[str(frame_num)] = ego_info
74 |
75 | frame_num += 1
76 | if frame_num % 10 == 0:
77 | print('ego record {:} / {:} frame number {:}'.format(record_index + 1, len(tf_records), frame_num))
78 | print('{:} frames in total'.format(frame_num))
79 |
80 | np.savez_compressed(os.path.join(data_folder, "{}.npz".format(segment_name)), **ego_infos)
81 |
82 |
83 | if __name__ == '__main__':
84 | args.data_folder = os.path.join(args.data_folder, 'ego_info')
85 | os.makedirs(args.data_folder, exist_ok=True)
86 |
87 | if args.process > 1:
88 | pool = multiprocessing.Pool(args.process)
89 | for token in range(args.process):
90 | result = pool.apply_async(main, args=(args.raw_data_folder, args.data_folder, args.process, token))
91 | pool.close()
92 | pool.join()
93 | else:
94 | main(args.raw_data_folder, args.data_folder)
95 |
--------------------------------------------------------------------------------
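
Each stored pose is the 4 x 4 vehicle-to-world transform taken from frame.pose.transform. A minimal sketch (segment name illustrative) mapping vehicle-frame points into the world frame:

    import numpy as np
    egos = np.load('ego_info/segment-xxxx.npz')
    ego = egos['0']                              # (4, 4) transform of frame 0
    pts = np.array([[1.0, 2.0, 0.5]])            # points in the vehicle frame
    world = (ego @ np.c_[pts, np.ones(len(pts))].T).T[:, :3]
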
/preprocessing/waymo_data/gt_bin_decode.py:
--------------------------------------------------------------------------------
1 | """ Process the .bin file of ground truth, and save it in our detection format.
2 | Each sequence is a .npz file containing three fields: bboxes, types, ids.
3 | bboxes, types, and ids follow the same format:
4 | [[bboxes in frame 0],
5 | [bboxes in frame 1],
6 | ...
7 | [bboxes in the last frame]]
8 | """
9 | import os, numpy as np, argparse, json
10 | from mot_3d.data_protos import BBox
11 | import mot_3d.utils as utils
12 | import tensorflow.compat.v1 as tf
13 | from google.protobuf.descriptor import FieldDescriptor as FD
14 | tf.enable_eager_execution()
15 | from waymo_open_dataset.protos import metrics_pb2
16 |
17 |
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--data_folder', type=str, default='../../../datasets/waymo/mot/')
20 | parser.add_argument('--file_path', type=str, default='../datasets/waymo/mot/info/validation_gt.bin')
21 | args = parser.parse_args()
22 |
23 |
24 | def main(file_path, out_folder, data_folder):
25 | # load time stamp
26 | ts_info_folder = os.path.join(data_folder, 'ts_info')
27 | time_stamp_info = dict()
28 | ts_files = os.listdir(ts_info_folder)
29 | for ts_file_name in ts_files:
30 | segment_name = ts_file_name.split('.')[0]
31 | ts_path = os.path.join(ts_info_folder, '{}.json'.format(segment_name))
32 | f = open(ts_path, 'r')
33 | ts = json.load(f)
34 | f.close()
35 | time_stamp_info[segment_name] = ts
36 |
37 | # segment name list
38 | segment_name_list = list()
39 | for ts_file_name in ts_files:
40 | segment_name = ts_file_name.split('.')[0]
41 | segment_name_list.append(segment_name)
42 |
43 | # load gt.bin file
44 | f = open(file_path, 'rb')
45 | objects = metrics_pb2.Objects()
46 | objects.ParseFromString(f.read())
47 | f.close()
48 |
49 | # parse and aggregate detections
50 | objects = objects.objects
51 | object_num = len(objects)
52 |
53 | result_bbox, result_type, result_id = dict(), dict(), dict()
54 | for seg_name in time_stamp_info.keys():
55 | result_bbox[seg_name] = dict()
56 | result_type[seg_name] = dict()
57 | result_id[seg_name] = dict()
58 |
59 | for i in range(object_num):
60 | instance = objects[i]
61 | segment = instance.context_name
62 | time_stamp = instance.frame_timestamp_micros
63 |
64 | box = instance.object.box
65 | box_dict = {
66 | 'center_x': box.center_x,
67 | 'center_y': box.center_y,
68 | 'center_z': box.center_z,
69 | 'width': box.width,
70 | 'length': box.length,
71 | 'height': box.height,
72 | 'heading': box.heading,
73 | 'score': instance.score
74 | }
75 |
76 | val_index = None
77 | for _j in range(len(segment_name_list)):
78 | if segment in segment_name_list[_j]:
79 | val_index = _j
80 | break
81 | segment_name = segment_name_list[val_index]
82 |
83 | frame_number = None
84 | for _j in range(len(time_stamp_info[segment_name])):
85 | if time_stamp_info[segment_name][_j] == time_stamp:
86 | frame_number = _j
87 | break
88 |
89 | if str(frame_number) not in result_bbox[segment_name].keys():
90 | result_bbox[segment_name][str(frame_number)] = list()
91 | result_type[segment_name][str(frame_number)] = list()
92 | result_id[segment_name][str(frame_number)] = list()
93 |
94 | result_bbox[segment_name][str(frame_number)].append(BBox.bbox2array(BBox.dict2bbox(box_dict)))
95 | result_type[segment_name][str(frame_number)].append(instance.object.type)
96 | result_id[segment_name][str(frame_number)].append(instance.object.id)
97 |
98 | if (i + 1) % 10000 == 0:
99 | print(i + 1, ' / ', object_num)
100 |
101 | # store in files
102 | for _i, segment_name in enumerate(time_stamp_info.keys()):
103 | dets = result_bbox[segment_name]
104 | types = result_type[segment_name]
105 | ids = result_id[segment_name]
106 | print('{} / {}'.format(_i + 1, len(time_stamp_info.keys())))
107 |
108 | frame_keys = sorted(utils.str2int(dets.keys()))
109 | max_frame = max(frame_keys)
110 | obj_ids, bboxes, obj_types = list(), list(), list()
111 |
112 | for key in range(max_frame + 1):
113 | if str(key) in dets.keys():
114 | bboxes.append(dets[str(key)])
115 | obj_types.append(types[str(key)])
116 | obj_ids.append(ids[str(key)])
117 | else:
118 | bboxes.append([])
119 | obj_types.append([])
120 | obj_ids.append([])
121 |
122 | np.savez_compressed(os.path.join(out_folder, "{}.npz".format(segment_name)),
123 | bboxes=bboxes, types=obj_types, ids=obj_ids)
124 |
125 |
126 | if __name__ == '__main__':
127 | out_folder = os.path.join(args.data_folder, 'detection', 'gt', 'dets')
128 |     os.makedirs(out_folder, exist_ok=True)
129 | main(args.file_path, out_folder, args.data_folder)
130 |
--------------------------------------------------------------------------------
/preprocessing/waymo_data/raw_pc.py:
--------------------------------------------------------------------------------
1 | """ Extract the point cloud sequences from the tfrecords
2 | output format: a compressed dict stored in an npz file
3 | {
4 | str(frame_number): pc (N * 3 numpy array)
5 | }
6 | """
7 | import argparse
8 | import numpy as np
9 | import os
10 | import multiprocessing
11 | from google.protobuf.descriptor import FieldDescriptor as FD
12 | import tensorflow.compat.v1 as tf
13 | tf.enable_eager_execution()
14 | from waymo_open_dataset.utils import frame_utils
15 | from waymo_open_dataset import dataset_pb2 as open_dataset
16 |
17 |
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--raw_data_folder', type=str, default='../../../datasets/waymo/validation/',
20 | help='the location of tfrecords')
21 | parser.add_argument('--data_folder', type=str, default='../../../datasets/waymo/mot/',
22 | help='the location of raw pcs')
23 | parser.add_argument('--process', type=int, default=1)
24 | args = parser.parse_args()
25 |
26 |
27 | def pb2dict(obj):
28 | """
29 |     Takes a ProtoBuf Message obj and converts it to a dict.
30 | """
31 | adict = {}
32 | # if not obj.IsInitialized():
33 | # return None
34 | for field in obj.DESCRIPTOR.fields:
35 | if not getattr(obj, field.name):
36 | continue
37 | if not field.label == FD.LABEL_REPEATED:
38 | if not field.type == FD.TYPE_MESSAGE:
39 | adict[field.name] = getattr(obj, field.name)
40 | else:
41 | value = pb2dict(getattr(obj, field.name))
42 | if value:
43 | adict[field.name] = value
44 | else:
45 | if field.type == FD.TYPE_MESSAGE:
46 | adict[field.name] = [pb2dict(v) for v in getattr(obj, field.name)]
47 | else:
48 | adict[field.name] = [v for v in getattr(obj, field.name)]
49 | return adict
50 |
51 |
52 | def main(raw_data_folder, data_folder, process=1, token=0):
53 | tf_records = os.listdir(raw_data_folder)
54 | tf_records = [x for x in tf_records if 'tfrecord' in x]
55 | tf_records = sorted(tf_records)
56 | for record_index, tf_record_name in enumerate(tf_records):
57 | if record_index % process != token:
58 | continue
59 | print('starting for raw pc', record_index + 1, ' / ', len(tf_records), ' ', tf_record_name)
60 | FILE_NAME = os.path.join(raw_data_folder, tf_record_name)
61 | dataset = tf.data.TFRecordDataset(FILE_NAME, compression_type='')
62 | segment_name = tf_record_name.split('.')[0]
63 |
64 | frame_num = 0
65 | pcs = dict()
66 |
67 | for data in dataset:
68 | frame = open_dataset.Frame()
69 | frame.ParseFromString(bytearray(data.numpy()))
70 |
71 | # extract the points
72 | (range_images, camera_projections, range_image_top_pose) = \
73 | frame_utils.parse_range_image_and_camera_projection(frame)
74 | points, cp_points = frame_utils.convert_range_image_to_point_cloud(
75 | frame, range_images, camera_projections, range_image_top_pose, ri_index=0)
76 | points_all = np.concatenate(points, axis=0)
77 | pcs[str(frame_num)] = points_all
78 |
79 | frame_num += 1
80 | if frame_num % 10 == 0:
81 | print('Point Cloud Record {} / {} FNumber {:}'.format(record_index + 1, len(tf_records), frame_num))
82 | print('{:} frames in total'.format(frame_num))
83 |
84 | np.savez_compressed(os.path.join(data_folder, "{}.npz".format(segment_name)), **pcs)
85 |
86 |
87 | if __name__ == '__main__':
88 | args.data_folder = os.path.join(args.data_folder, 'pc', 'raw_pc')
89 | os.makedirs(args.data_folder, exist_ok=True)
90 |
91 | if args.process > 1:
92 | # multiprocessing accelerate the speed
93 | pool = multiprocessing.Pool(args.process)
94 | for token in range(args.process):
95 | result = pool.apply_async(main, args=(args.raw_data_folder, args.data_folder, args.process, token))
96 | pool.close()
97 | pool.join()
98 | else:
99 | main(args.raw_data_folder, args.data_folder)
100 |
--------------------------------------------------------------------------------
/preprocessing/waymo_data/time_stamp.py:
--------------------------------------------------------------------------------
1 | """ Extract the time stamp information about each frame.
2 | Each sequence has a json file containing a list of timestamps.
3 | """
4 | import os
5 | import tensorflow.compat.v1 as tf
6 | import numpy as np
7 | import argparse
8 | import json
9 | from google.protobuf.descriptor import FieldDescriptor as FD
10 | tf.enable_eager_execution()
11 | import multiprocessing
12 | from waymo_open_dataset import dataset_pb2 as open_dataset
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--raw_data_folder', type=str, default='../../../datasets/waymo/validation/',
17 | help='location of tfrecords')
18 | parser.add_argument('--data_folder', type=str, default='../../../datasets/waymo/mot/',
19 | help='output folder')
20 | parser.add_argument('--process', type=int, default=1, help='use multiprocessing for acceleration')
21 | args = parser.parse_args()
22 |
23 |
24 | def main(raw_data_folder, data_folder, process=1, token=0):
25 | tf_records = os.listdir(raw_data_folder)
26 | tf_records = [x for x in tf_records if 'tfrecord' in x]
27 | tf_records = sorted(tf_records)
28 |
29 | for record_index, tf_record_name in enumerate(tf_records):
30 | if record_index % process != token:
31 | continue
32 | print('starting for time stamp: ', record_index + 1, ' / ', len(tf_records), ' ', tf_record_name)
33 | FILE_NAME = os.path.join(raw_data_folder, tf_record_name)
34 | dataset = tf.data.TFRecordDataset(FILE_NAME, compression_type='')
35 |
36 | time_stamps = list()
37 | for data in dataset:
38 | frame = open_dataset.Frame()
39 | frame.ParseFromString(bytearray(data.numpy()))
40 | time_stamps.append(frame.timestamp_micros)
41 |
42 | file_name = tf_record_name.split('.')[0]
43 | print(file_name)
44 | f = open(os.path.join(data_folder, '{:}.json'.format(file_name)), 'w')
45 | json.dump(time_stamps, f)
46 | f.close()
47 |
48 |
49 | if __name__ == '__main__':
50 | args.data_folder = os.path.join(args.data_folder, 'ts_info')
51 | os.makedirs(args.data_folder, exist_ok=True)
52 |
53 | if args.process > 1:
54 | pool = multiprocessing.Pool(args.process)
55 | for token in range(args.process):
56 | result = pool.apply_async(main, args=(args.raw_data_folder, args.data_folder, args.process, token))
57 | pool.close()
58 | pool.join()
59 | else:
60 | main(args.raw_data_folder, args.data_folder)
61 |
--------------------------------------------------------------------------------
/preprocessing/waymo_data/waymo_preprocess.sh:
--------------------------------------------------------------------------------
1 | raw_data_dir=$1
2 | data_dir=$2
3 | proc_num=$3
4 |
5 | # ego pose
6 | python ego_info.py --raw_data_folder $raw_data_dir --data_folder $data_dir --process $proc_num
7 |
8 | # time stamp
9 | python time_stamp.py --raw_data_folder $raw_data_dir --data_folder $data_dir --process $proc_num
10 |
11 | # point cloud, useful for visualization
12 | python raw_pc.py --raw_data_folder $raw_data_dir --data_folder $data_dir --process $proc_num
--------------------------------------------------------------------------------
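
Usage sketch (values illustrative): the arguments are the tfrecord folder, the output folder, and the worker-process count, e.g.

    bash waymo_preprocess.sh ../../../datasets/waymo/validation/ ../../../datasets/waymo/mot/ 8
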
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | filterpy
3 | shapely
4 | numba
5 | pyquaternion
6 | nuscenes-devkit
7 | tqdm
8 | matplotlib
9 | PyYAML
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 |
4 | setup(
5 | name='mot_3d',
6 | version='0.1',
7 |     packages=find_packages(include=['mot_3d', 'mot_3d.*'])
8 | )
--------------------------------------------------------------------------------
/tools/demo.py:
--------------------------------------------------------------------------------
1 | import os, numpy as np, argparse, json, sys, numba, yaml, multiprocessing, shutil
2 | import mot_3d.visualization as visualization, mot_3d.utils as utils
3 | from mot_3d.data_protos import BBox, Validity
4 | from mot_3d.mot import MOTModel
5 | from mot_3d.frame_data import FrameData
6 | from data_loader import WaymoLoader
7 |
8 |
9 | parser = argparse.ArgumentParser()
10 | # running configurations
11 | parser.add_argument('--name', type=str, default='demo')
12 | parser.add_argument('--det_name', type=str, default='cp')
13 | parser.add_argument('--process', type=int, default=1)
14 | parser.add_argument('--visualize', action='store_true', default=False)
15 | parser.add_argument('--start_frame', type=int, default=0, help='start at a middle frame for debug')
16 | parser.add_argument('--obj_type', type=str, default='vehicle', choices=['vehicle', 'pedestrian', 'cyclist'])
17 | # paths
18 | parser.add_argument('--config_path', type=str, default='configs/waymo_configs/vc_kf_giou.yaml')
19 | parser.add_argument('--result_folder', type=str, default='./mot_results/')
20 | parser.add_argument('--data_folder', type=str, default='./demo_data/')
21 | parser.add_argument('--gt_folder', type=str, default='./demo_data/detection/gt/dets/')
22 | args = parser.parse_args()
23 |
24 |
25 | def load_gt_bboxes(gt_folder, data_folder, segment_name, type_token):
26 | gt_info = np.load(os.path.join(gt_folder, '{:}.npz'.format(segment_name)),
27 | allow_pickle=True)
28 | ego_info = np.load(os.path.join(data_folder, 'ego_info', '{:}.npz'.format(segment_name)),
29 | allow_pickle=True)
30 | bboxes, ids, inst_types = gt_info['bboxes'], gt_info['ids'], gt_info['types']
31 | gt_ids, gt_bboxes = utils.inst_filter(ids, bboxes, inst_types, type_field=[type_token], id_trans=True)
32 |
33 | ego_keys = sorted(utils.str2int(ego_info.keys()))
34 | egos = [ego_info[str(key)] for key in ego_keys]
35 | gt_bboxes = gt_bbox2world(gt_bboxes, egos)
36 | return gt_bboxes, gt_ids
37 |
38 |
39 | def gt_bbox2world(bboxes, egos):
40 | frame_num = len(egos)
41 | for i in range(frame_num):
42 | ego = egos[i]
43 | bbox_num = len(bboxes[i])
44 | for j in range(bbox_num):
45 | bboxes[i][j] = BBox.bbox2world(ego, bboxes[i][j])
46 | return bboxes
47 |
48 |
49 | def frame_visualization(bboxes, ids, states, gt_bboxes=None, gt_ids=None, pc=None, dets=None, name=''):
50 | visualizer = visualization.Visualizer2D(name=name, figsize=(12, 12))
51 | if pc is not None:
52 | visualizer.handler_pc(pc)
53 | for _, bbox in enumerate(gt_bboxes):
54 | visualizer.handler_box(bbox, message='', color='black')
55 | dets = [d for d in dets if d.s >= 0.1]
56 | for det in dets:
57 | visualizer.handler_box(det, message='%.2f' % det.s, color='green', linestyle='dashed')
58 | for _, (bbox, id, state) in enumerate(zip(bboxes, ids, states)):
59 | if Validity.valid(state):
60 | visualizer.handler_box(bbox, message=str(id), color='red')
61 | else:
62 | visualizer.handler_box(bbox, message=str(id), color='light_blue')
63 | visualizer.show()
64 | visualizer.close()
65 |
66 |
67 | def sequence_mot(configs, data_loader: WaymoLoader, sequence_id, gt_bboxes=None, gt_ids=None, visualize=False):
68 | tracker = MOTModel(configs)
69 | frame_num = len(data_loader)
70 | IDs, bboxes, states, types = list(), list(), list(), list()
71 | for frame_index in range(data_loader.cur_frame, frame_num):
72 | print('TYPE {:} SEQ {:} Frame {:} / {:}'.format(data_loader.type_token, sequence_id + 1, frame_index + 1, frame_num))
73 |
74 | # input data
75 | frame_data = next(data_loader)
76 | frame_data = FrameData(dets=frame_data['dets'], ego=frame_data['ego'], pc=frame_data['pc'],
77 | det_types=frame_data['det_types'], aux_info=frame_data['aux_info'], time_stamp=frame_data['time_stamp'])
78 |
79 | # mot
80 | results = tracker.frame_mot(frame_data)
81 | result_pred_bboxes = [trk[0] for trk in results]
82 | result_pred_ids = [trk[1] for trk in results]
83 | result_pred_states = [trk[2] for trk in results]
84 | result_types = [trk[3] for trk in results]
85 |
86 | # visualization
87 | if visualize:
88 | frame_visualization(result_pred_bboxes, result_pred_ids, result_pred_states,
89 | gt_bboxes[frame_index], gt_ids[frame_index], frame_data.pc, dets=frame_data.dets, name='{:}_{:}_{:}'.format(args.name, sequence_id, frame_index))
90 |
91 | # wrap for output
92 | IDs.append(result_pred_ids)
93 | result_pred_bboxes = [BBox.bbox2array(bbox) for bbox in result_pred_bboxes]
94 | bboxes.append(result_pred_bboxes)
95 | states.append(result_pred_states)
96 | types.append(result_types)
97 | return IDs, bboxes, states, types
98 |
99 |
100 | def main(name, obj_type, config_path, data_folder, det_data_folder, result_folder, gt_folder, start_frame=0, token=0, process=1):
101 | summary_folder = os.path.join(result_folder, 'summary', obj_type)
102 | # simply knowing about all the segments
103 | file_names = sorted(os.listdir(os.path.join(data_folder, 'ego_info')))
104 | print(file_names[0])
105 |
106 | # load model configs
107 | configs = yaml.load(open(config_path, 'r'), Loader=yaml.Loader)
108 |
109 | if obj_type == 'vehicle':
110 | type_token = 1
111 | elif obj_type == 'pedestrian':
112 | type_token = 2
113 | elif obj_type == 'cyclist':
114 | type_token = 4
115 | for file_index, file_name in enumerate(file_names[:]):
116 | if file_index % process != token:
117 | continue
118 | print('START TYPE {:} SEQ {:} / {:}'.format(obj_type, file_index + 1, len(file_names)))
119 | segment_name = file_name.split('.')[0]
120 | data_loader = WaymoLoader(configs, [type_token], segment_name, data_folder, det_data_folder, start_frame)
121 | gt_bboxes, gt_ids = load_gt_bboxes(gt_folder, data_folder, segment_name, type_token)
122 |
123 | # real mot happens here
124 | ids, bboxes, states, types = sequence_mot(configs, data_loader, file_index, gt_bboxes, gt_ids, args.visualize)
125 | np.savez_compressed(os.path.join(summary_folder, '{}.npz'.format(segment_name)),
126 |                             ids=ids, bboxes=bboxes, states=states, types=types)
127 |
128 |
129 | if __name__ == '__main__':
130 | result_folder = os.path.join(args.result_folder, args.name)
131 | os.makedirs(result_folder, exist_ok=True)
132 |
133 | summary_folder = os.path.join(result_folder, 'summary')
134 | os.makedirs(summary_folder, exist_ok=True)
135 |
136 | summary_folder = os.path.join(summary_folder, args.obj_type)
137 | os.makedirs(summary_folder, exist_ok=True)
138 |
139 | det_data_folder = os.path.join(args.data_folder, 'detection', args.det_name)
140 |
141 | if args.process > 1:
142 | pool = multiprocessing.Pool(args.process)
143 | for token in range(args.process):
144 | result = pool.apply_async(main, args=(args.name, args.obj_type, args.config_path, args.data_folder, det_data_folder,
145 | result_folder, args.gt_folder, 0, token, args.process))
146 | pool.close()
147 | pool.join()
148 | else:
149 | main(args.name, args.obj_type, args.config_path, args.data_folder, det_data_folder, result_folder,
150 | args.gt_folder, args.start_frame, 0, 1)
151 |
152 |
--------------------------------------------------------------------------------
/tools/main_nuscenes.py:
--------------------------------------------------------------------------------
1 | """ inference on the nuscenes dataset
2 | """
3 | import os, numpy as np, argparse, json, sys, numba, yaml, multiprocessing, shutil
4 | import mot_3d.visualization as visualization, mot_3d.utils as utils
5 | from mot_3d.data_protos import BBox, Validity
6 | from mot_3d.mot import MOTModel
7 | from mot_3d.frame_data import FrameData
8 | from data_loader import NuScenesLoader
9 | from pyquaternion import Quaternion
10 | from nuscenes.utils.data_classes import Box
11 |
12 |
13 | parser = argparse.ArgumentParser()
14 | # running configurations
15 | parser.add_argument('--name', type=str, default='debug')
16 | parser.add_argument('--det_name', type=str, default='cp')
17 | parser.add_argument('--process', type=int, default=1)
18 | parser.add_argument('--visualize', action='store_true', default=False)
19 | parser.add_argument('--start_frame', type=int, default=0, help='start at a middle frame for debug')
20 | parser.add_argument('--obj_types', default='car,bus,trailer,truck,pedestrian,bicycle,motorcycle')
21 | # paths
22 | parser.add_argument('--config_path', type=str, default='configs/config.yaml', help='config file path, follow the path in the documentation')
23 | parser.add_argument('--result_folder', type=str, default='../nu_mot_results/')
24 | parser.add_argument('--data_folder', type=str, default='../datasets/nuscenes/')
25 | args = parser.parse_args()
26 |
27 |
28 | def nu_array2mot_bbox(b):
29 | nu_box = Box(b[:3], b[3:6], Quaternion(b[6:10]))
30 | mot_bbox = BBox(
31 | x=nu_box.center[0], y=nu_box.center[1], z=nu_box.center[2],
32 | w=nu_box.wlh[0], l=nu_box.wlh[1], h=nu_box.wlh[2],
33 | o=nu_box.orientation.yaw_pitch_roll[0]
34 | )
35 | if len(b) == 11:
36 | mot_bbox.s = b[-1]
37 | return mot_bbox
38 |
39 |
40 | def load_gt_bboxes(data_folder, type_token, segment_name):
41 | gt_info = np.load(os.path.join(data_folder, 'gt_info', '{:}.npz'.format(segment_name)), allow_pickle=True)
42 | ids, inst_types, bboxes = gt_info['ids'], gt_info['types'], gt_info['bboxes']
43 |
44 | mot_bboxes = list()
45 | for _, frame_bboxes in enumerate(bboxes):
46 | mot_bboxes.append([])
47 | for _, b in enumerate(frame_bboxes):
48 | mot_bboxes[-1].append(BBox.bbox2array(nu_array2mot_bbox(b)))
49 | gt_ids, gt_bboxes = utils.inst_filter(ids, mot_bboxes, inst_types,
50 | type_field=type_token, id_trans=True)
51 | return gt_bboxes, gt_ids
52 |
53 |
54 | def frame_visualization(bboxes, ids, states, gt_bboxes=None, gt_ids=None, pc=None, dets=None, name=''):
55 | visualizer = visualization.Visualizer2D(name=name, figsize=(12, 12))
56 | if pc is not None:
57 | visualizer.handler_pc(pc)
58 |
59 | if gt_bboxes is not None:
60 | for _, bbox in enumerate(gt_bboxes):
61 | visualizer.handler_box(bbox, message='', color='black')
62 | dets = [d for d in dets if d.s >= 0.01]
63 | for det in dets:
64 | visualizer.handler_box(det, message='%.2f' % det.s, color='gray', linestyle='dashed')
65 | for _, (bbox, id, state_string) in enumerate(zip(bboxes, ids, states)):
66 | if Validity.valid(state_string):
67 | visualizer.handler_box(bbox, message='%.2f %s'%(bbox.s, id), color='red')
68 | else:
69 | visualizer.handler_box(bbox, message='%.2f %s'%(bbox.s, id), color='light_blue')
70 | visualizer.show()
71 | visualizer.close()
72 |
73 |
74 | def sequence_mot(configs, data_loader, obj_type, sequence_id, gt_bboxes=None, gt_ids=None, visualize=False):
75 | tracker = MOTModel(configs)
76 | frame_num = len(data_loader)
77 | IDs, bboxes, states, types = list(), list(), list(), list()
78 |
79 | for frame_index in range(data_loader.cur_frame, frame_num):
80 | if frame_index % 10 == 0:
81 | print('TYPE {:} SEQ {:} Frame {:} / {:}'.format(obj_type, sequence_id, frame_index + 1, frame_num))
82 |
83 | # input data
84 | frame_data = next(data_loader)
85 | frame_data = FrameData(dets=frame_data['dets'], ego=frame_data['ego'], pc=frame_data['pc'],
86 | det_types=frame_data['det_types'], aux_info=frame_data['aux_info'], time_stamp=frame_data['time_stamp'])
87 |
88 | # mot
89 | results = tracker.frame_mot(frame_data)
90 | result_pred_bboxes = [trk[0] for trk in results]
91 | result_pred_ids = [trk[1] for trk in results]
92 | result_pred_states = [trk[2] for trk in results]
93 | result_types = [trk[3] for trk in results]
94 |
95 | # visualization
96 | if visualize:
97 | frame_visualization(result_pred_bboxes, result_pred_ids, result_pred_states,
98 | gt_bboxes[frame_index], gt_ids[frame_index], frame_data.pc, dets=frame_data.dets, name='{:}_{:}'.format(args.name, frame_index))
99 |
100 | # wrap for output
101 | IDs.append(result_pred_ids)
102 | result_pred_bboxes = [BBox.bbox2array(bbox) for bbox in result_pred_bboxes]
103 | bboxes.append(result_pred_bboxes)
104 | states.append(result_pred_states)
105 | types.append(result_types)
106 |
107 | return IDs, bboxes, states, types
108 |
109 |
110 | def main(name, obj_types, config_path, data_folder, det_data_folder, result_folder, start_frame=0, token=0, process=1):
111 | for obj_type in obj_types:
112 | summary_folder = os.path.join(result_folder, 'summary', obj_type)
113 | # simply knowing about all the segments
114 | file_names = sorted(os.listdir(os.path.join(data_folder, 'ego_info')))
115 |
116 | # load model configs
117 | configs = yaml.load(open(config_path, 'r'), Loader=yaml.Loader)
118 |
119 | for file_index, file_name in enumerate(file_names[:]):
120 | if file_index % process != token:
121 | continue
122 | print('START TYPE {:} SEQ {:} / {:}'.format(obj_type, file_index + 1, len(file_names)))
123 | segment_name = file_name.split('.')[0]
124 |
125 | data_loader = NuScenesLoader(configs, [obj_type], segment_name, data_folder, det_data_folder, start_frame)
126 |
127 | gt_bboxes, gt_ids = load_gt_bboxes(data_folder, [obj_type], segment_name)
128 | ids, bboxes, states, types = sequence_mot(configs, data_loader, obj_type, file_index, gt_bboxes, gt_ids, args.visualize)
129 |
130 | frame_num = len(ids)
131 | for frame_index in range(frame_num):
132 | id_num = len(ids[frame_index])
133 | for i in range(id_num):
134 | ids[frame_index][i] = '{:}_{:}'.format(file_index, ids[frame_index][i])
135 |
136 | np.savez_compressed(os.path.join(summary_folder, '{}.npz'.format(segment_name)),
137 | ids=ids, bboxes=bboxes, states=states, types=types)
138 |
139 |
140 | if __name__ == '__main__':
141 | result_folder = os.path.join(args.result_folder, args.name)
142 | os.makedirs(result_folder, exist_ok=True)
143 | summary_folder = os.path.join(result_folder, 'summary')
144 | os.makedirs(summary_folder, exist_ok=True)
145 | det_data_folder = os.path.join(args.data_folder, 'detection', args.det_name)
146 |
147 | obj_types = args.obj_types.split(',')
148 | for obj_type in obj_types:
149 | tmp_summary_folder = os.path.join(summary_folder, obj_type)
150 | os.makedirs(tmp_summary_folder, exist_ok=True)
151 |
152 | if args.process > 1:
153 | pool = multiprocessing.Pool(args.process)
154 | for token in range(args.process):
155 | result = pool.apply_async(main, args=(args.name, obj_types, args.config_path, args.data_folder, det_data_folder,
156 | result_folder, 0, token, args.process))
157 | pool.close()
158 | pool.join()
159 | else:
160 | main(args.name, obj_types, args.config_path, args.data_folder, det_data_folder,
161 | result_folder, args.start_frame, 0, 1)
--------------------------------------------------------------------------------
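
Each summary npz written by main() above stores per-frame ids, bboxes, states, and types, with every id prefixed by its sequence index. Loading sketch (paths illustrative):

    import numpy as np
    res = np.load('summary/car/scene-0103.npz', allow_pickle=True)
    ids, bboxes, states = res['ids'], res['bboxes'], res['states']
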
/tools/main_nuscenes_10hz.py:
--------------------------------------------------------------------------------
1 | """ inference on the nuscenes dataset
2 | """
3 | import os, numpy as np, argparse, json, sys, numba, yaml, multiprocessing, shutil
4 | import mot_3d.visualization as visualization, mot_3d.utils as utils
5 | from mot_3d.data_protos import BBox, Validity
6 | from mot_3d.mot import MOTModel
7 | from mot_3d.frame_data import FrameData
8 | from data_loader import NuScenesLoader10Hz
9 | from pyquaternion import Quaternion
10 | from nuscenes.utils.data_classes import Box
11 |
12 |
13 | parser = argparse.ArgumentParser()
14 | # running configurations
15 | parser.add_argument('--name', type=str, default='debug')
16 | parser.add_argument('--det_name', type=str, default='cp')
17 | parser.add_argument('--process', type=int, default=1)
18 | parser.add_argument('--visualize', action='store_true', default=False)
19 | parser.add_argument('--start_frame', type=int, default=0, help='start at a middle frame for debug')
20 | parser.add_argument('--obj_types', default='car,bus,trailer,truck,pedestrian,bicycle,motorcycle')
21 | # paths
22 | parser.add_argument('--config_path', type=str, default='configs/config.yaml', help='config file path, follow the path in the documentation')
23 | parser.add_argument('--result_folder', type=str, default='/mnt/truenas/scratch/ziqi.pang/10hz_exps/')
24 | parser.add_argument('--data_folder', type=str, default='/mnt/truenas/scratch/ziqi.pang/datasets/nuscenes/')
25 | args = parser.parse_args()
26 |
27 |
28 | def nu_array2mot_bbox(b):
29 | nu_box = Box(b[:3], b[3:6], Quaternion(b[6:10]))
30 | mot_bbox = BBox(
31 | x=nu_box.center[0], y=nu_box.center[1], z=nu_box.center[2],
32 | w=nu_box.wlh[0], l=nu_box.wlh[1], h=nu_box.wlh[2],
33 | o=nu_box.orientation.yaw_pitch_roll[0]
34 | )
35 | if len(b) == 11:
36 | mot_bbox.s = b[-1]
37 | return mot_bbox
38 |
39 |
40 | def load_gt_bboxes(data_folder, type_token, segment_name):
41 | gt_info = np.load(os.path.join(data_folder, 'gt_info', '{:}.npz'.format(segment_name)), allow_pickle=True)
42 | ids, inst_types, bboxes = gt_info['ids'], gt_info['types'], gt_info['bboxes']
43 |
44 | mot_bboxes = list()
45 | for _, frame_bboxes in enumerate(bboxes):
46 | mot_bboxes.append([])
47 | for _, b in enumerate(frame_bboxes):
48 | mot_bboxes[-1].append(BBox.bbox2array(nu_array2mot_bbox(b)))
49 | gt_ids, gt_bboxes = utils.inst_filter(ids, mot_bboxes, inst_types,
50 | type_field=type_token, id_trans=True)
51 | return gt_bboxes, gt_ids
52 |
53 |
54 | def frame_visualization(bboxes, ids, states, gt_bboxes=None, gt_ids=None, pc=None, is_key_frame=True, dets=None, name=''):
55 | visualizer = visualization.Visualizer2D(name=name, figsize=(12, 12))
56 | if pc is not None:
57 | visualizer.handler_pc(pc)
58 | for _, bbox in enumerate(gt_bboxes):
59 | visualizer.handler_box(bbox, message='', color='black')
60 | dets = [d for d in dets if d.s >= 0.01]
61 |
62 | if is_key_frame:
63 | line_style = 'solid'
64 | else:
65 | line_style = 'dashed'
66 |
67 | for det in dets:
68 | visualizer.handler_box(det, message='%.2f' % det.s, color='gray', linestyle='dashed')
69 | for _, (bbox, id, state_string) in enumerate(zip(bboxes, ids, states)):
70 | if Validity.valid(state_string):
71 | visualizer.handler_box(bbox, message='%.2f %s'%(bbox.s, id), color='red', linestyle=line_style)
72 | else:
73 | visualizer.handler_box(bbox, message='%.2f %s'%(bbox.s, id), color='light_blue', linestyle=line_style)
74 | visualizer.show()
75 | visualizer.close()
76 |
77 |
78 | def sequence_mot(configs, data_loader, obj_type, sequence_id, gt_bboxes=None, gt_ids=None, visualize=False):
79 | tracker = MOTModel(configs)
80 | frame_num = len(data_loader)
81 | IDs, bboxes, states, types = list(), list(), list(), list()
82 |
83 | if gt_bboxes is None:
84 | gt_bboxes = [[] for i in range(data_loader.cur_frame, frame_num)]
85 | gt_ids = [[] for i in range(data_loader.cur_frame, frame_num)]
86 |
87 | for frame_index in range(data_loader.cur_frame, frame_num):
88 | if frame_index % 10 == 0:
89 | print('TYPE {:} SEQ {:} Frame {:} / {:}'.format(obj_type, sequence_id, frame_index + 1, frame_num))
90 |
91 | # input data
92 | frame_data = next(data_loader)
93 | frame_data = FrameData(dets=frame_data['dets'], ego=frame_data['ego'], pc=frame_data['pc'],
94 | det_types=frame_data['det_types'], aux_info=frame_data['aux_info'], time_stamp=frame_data['time_stamp'])
95 |
96 | # mot
97 | results = tracker.frame_mot(frame_data)
98 | result_pred_bboxes = [trk[0] for trk in results]
99 | result_pred_ids = [trk[1] for trk in results]
100 | result_pred_states = [trk[2] for trk in results]
101 | result_types = [trk[3] for trk in results]
102 |
103 | # visualization
104 | if visualize:
105 | frame_visualization(result_pred_bboxes, result_pred_ids, result_pred_states,
106 | gt_bboxes[frame_index], gt_ids[frame_index], frame_data.pc,
107 | is_key_frame=frame_data.aux_info['is_key_frame'], dets=frame_data.dets, name='{:}_{:}'.format(args.name, frame_index))
108 |
109 | # wrap for output
110 | IDs.append(result_pred_ids)
111 | result_pred_bboxes = [BBox.bbox2array(bbox) for bbox in result_pred_bboxes]
112 | bboxes.append(result_pred_bboxes)
113 | states.append(result_pred_states)
114 | types.append(result_types)
115 |
116 | return IDs, bboxes, states, types
117 |
118 |
119 | def main(name, obj_types, config_path, data_folder, det_data_folder, result_folder, start_frame=0, token=0, process=1):
120 | for obj_type in obj_types:
121 | summary_folder = os.path.join(result_folder, 'summary', obj_type)
122 | # simply knowing about all the segments
123 | file_names = sorted(os.listdir(os.path.join(data_folder, 'ego_info')))
124 |
125 | # load model configs
126 | configs = yaml.load(open(config_path, 'r'), Loader=yaml.Loader)
127 |
128 | for file_index, file_name in enumerate(file_names[:]):
129 | if file_index % process != token:
130 | continue
131 | print('START TYPE {:} SEQ {:} / {:}'.format(obj_type, file_index + 1, len(file_names)))
132 | segment_name = file_name.split('.')[0]
133 |
134 | data_loader = NuScenesLoader10Hz(configs, [obj_type], segment_name, data_folder, det_data_folder, start_frame)
135 |
136 | gt_bboxes, gt_ids = load_gt_bboxes(data_folder, [obj_type], segment_name)
137 | ids, bboxes, states, types = sequence_mot(configs, data_loader, obj_type, file_index, gt_bboxes, gt_ids, args.visualize)
138 |
139 | frame_num = len(ids)
140 | for frame_index in range(frame_num):
141 | id_num = len(ids[frame_index])
142 | for i in range(id_num):
143 | ids[frame_index][i] = '{:}_{:}'.format(file_index, ids[frame_index][i])
144 |
145 | np.savez_compressed(os.path.join(summary_folder, '{}.npz'.format(segment_name)),
146 | ids=ids, bboxes=bboxes, states=states, types=types)
147 |
148 |
149 | if __name__ == '__main__':
150 | result_folder = os.path.join(args.result_folder, args.name)
151 | os.makedirs(result_folder, exist_ok=True)
152 | summary_folder = os.path.join(result_folder, 'summary')
153 | os.makedirs(summary_folder, exist_ok=True)
154 |
155 | det_data_folder = os.path.join(args.data_folder, 'detection', args.det_name)
156 |
157 | obj_types = args.obj_types.split(',')
158 | for obj_type in obj_types:
159 | tmp_summary_folder = os.path.join(summary_folder, obj_type)
160 | os.makedirs(tmp_summary_folder, exist_ok=True)
161 |
162 | if args.process > 1:
163 | pool = multiprocessing.Pool(args.process)
164 | for token in range(args.process):
165 | result = pool.apply_async(main, args=(args.name, obj_types, args.config_path, args.data_folder, det_data_folder,
166 | result_folder, 0, token, args.process))
167 | pool.close()
168 | pool.join()
169 | else:
170 | main(args.name, obj_types, args.config_path, args.data_folder, det_data_folder,
171 | result_folder, args.start_frame, 0, 1)
172 |
--------------------------------------------------------------------------------
/tools/main_waymo.py:
--------------------------------------------------------------------------------
1 | import os, numpy as np, argparse, json, sys, numba, yaml, multiprocessing, shutil
2 | import mot_3d.visualization as visualization, mot_3d.utils as utils
3 | from mot_3d.data_protos import BBox, Validity
4 | from mot_3d.mot import MOTModel
5 | from mot_3d.frame_data import FrameData
6 | from data_loader import WaymoLoader
7 |
8 |
9 | parser = argparse.ArgumentParser()
10 | # running configurations
11 | parser.add_argument('--name', type=str, default='debug')
12 | parser.add_argument('--det_name', type=str, default='public')
13 | parser.add_argument('--process', type=int, default=1)
14 | parser.add_argument('--visualize', action='store_true', default=False)
15 | parser.add_argument('--start_frame', type=int, default=0, help='start at a middle frame for debug')
16 | parser.add_argument('--obj_type', type=str, default='vehicle', choices=['vehicle', 'pedestrian', 'cyclist'])
17 | # paths
18 | parser.add_argument('--config_path', type=str, default='configs/config.yaml', help='config file path, follow the path in the documentation')
19 | parser.add_argument('--result_folder', type=str, default='../mot_results/')
20 | parser.add_argument('--data_folder', type=str, default='../datasets/waymo/mot/')
21 | parser.add_argument('--gt_folder', type=str, default='../datasets/waymo/mot/detection/gt/dets/')
22 | args = parser.parse_args()
23 |
24 |
25 | def load_gt_bboxes(gt_folder, data_folder, segment_name, type_token):
26 | gt_info = np.load(os.path.join(gt_folder, '{:}.npz'.format(segment_name)),
27 | allow_pickle=True)
28 | ego_info = np.load(os.path.join(data_folder, 'ego_info', '{:}.npz'.format(segment_name)),
29 | allow_pickle=True)
30 | bboxes, ids, inst_types = gt_info['bboxes'], gt_info['ids'], gt_info['types']
31 | gt_ids, gt_bboxes = utils.inst_filter(ids, bboxes, inst_types, type_field=[type_token], id_trans=True)
32 |
33 | ego_keys = sorted(utils.str2int(ego_info.keys()))
34 | egos = [ego_info[str(key)] for key in ego_keys]
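    # tracking runs in the world frame, so move the per-frame gt boxes
    # from the ego frame into world coordinates with the ego poses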
35 | gt_bboxes = gt_bbox2world(gt_bboxes, egos)
36 | return gt_bboxes, gt_ids
37 |
38 |
39 | def gt_bbox2world(bboxes, egos):
40 | frame_num = len(egos)
41 | for i in range(frame_num):
42 | ego = egos[i]
43 | bbox_num = len(bboxes[i])
44 | for j in range(bbox_num):
45 | bboxes[i][j] = BBox.bbox2world(ego, bboxes[i][j])
46 | return bboxes
47 |
48 |
49 | def frame_visualization(bboxes, ids, states, gt_bboxes=None, gt_ids=None, pc=None, dets=None, name=''):
50 | visualizer = visualization.Visualizer2D(name=name, figsize=(12, 12))
51 | if pc is not None:
52 | visualizer.handler_pc(pc)
53 | for _, bbox in enumerate(gt_bboxes):
54 | visualizer.handler_box(bbox, message='', color='black')
55 | dets = [d for d in dets if d.s >= 0.1]
56 | for det in dets:
57 | visualizer.handler_box(det, message='%.2f' % det.s, color='purple', linestyle='dashed')
58 | for _, (bbox, id, state) in enumerate(zip(bboxes, ids, states)):
59 | if Validity.valid(state):
60 | visualizer.handler_box(bbox, message=str(id), color='red')
61 | else:
62 | visualizer.handler_box(bbox, message=str(id), color='light_blue')
63 | # visualizer.show()
64 | visualizer.save('imgs/{:}.png'.format(name))
65 | visualizer.close()
66 |
67 |
68 | def sequence_mot(configs, data_loader: WaymoLoader, sequence_id, gt_bboxes=None, gt_ids=None, visualize=False):
69 | tracker = MOTModel(configs)
70 | frame_num = len(data_loader)
71 | IDs, bboxes, states, types = list(), list(), list(), list()
72 | for frame_index in range(data_loader.cur_frame, frame_num):
73 | print('TYPE {:} SEQ {:} Frame {:} / {:}'.format(data_loader.type_token, sequence_id + 1, frame_index + 1, frame_num))
74 |
75 | # input data
76 | frame_data = next(data_loader)
77 | frame_data = FrameData(dets=frame_data['dets'], ego=frame_data['ego'], pc=frame_data['pc'],
78 | det_types=frame_data['det_types'], aux_info=frame_data['aux_info'], time_stamp=frame_data['time_stamp'])
79 |
80 | # mot
81 | results = tracker.frame_mot(frame_data)
82 | result_pred_bboxes = [trk[0] for trk in results]
83 | result_pred_ids = [trk[1] for trk in results]
84 | result_pred_states = [trk[2] for trk in results]
85 | result_types = [trk[3] for trk in results]
86 |
87 | # visualization
88 | if visualize:
89 | frame_visualization(result_pred_bboxes, result_pred_ids, result_pred_states,
90 | gt_bboxes[frame_index], gt_ids[frame_index], frame_data.pc, dets=frame_data.dets, name='{:}_{:}_{:}'.format(args.name, sequence_id, frame_index))
91 |
92 | # wrap for output
93 | IDs.append(result_pred_ids)
94 | result_pred_bboxes = [BBox.bbox2array(bbox) for bbox in result_pred_bboxes]
95 | bboxes.append(result_pred_bboxes)
96 | states.append(result_pred_states)
97 | types.append(result_types)
98 | return IDs, bboxes, states, types
99 |
100 |
101 | def main(name, obj_type, config_path, data_folder, det_data_folder, result_folder, gt_folder, start_frame=0, token=0, process=1):
102 | summary_folder = os.path.join(result_folder, 'summary', obj_type)
103 |     # enumerate all the segments via the ego_info file listing
104 | file_names = sorted(os.listdir(os.path.join(data_folder, 'ego_info')))
105 | print(file_names[0])
106 |
107 | # load model configs
108 | configs = yaml.load(open(config_path, 'r'), Loader=yaml.Loader)
109 |
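    # Waymo label enum: TYPE_VEHICLE = 1, TYPE_PEDESTRIAN = 2, TYPE_CYCLIST = 4
    # (TYPE_SIGN = 3 is not tracked)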
110 | if obj_type == 'vehicle':
111 | type_token = 1
112 | elif obj_type == 'pedestrian':
113 | type_token = 2
114 | elif obj_type == 'cyclist':
115 | type_token = 4
116 | for file_index, file_name in enumerate(file_names[:]):
117 | if file_index % process != token:
118 | continue
119 | print('START TYPE {:} SEQ {:} / {:}'.format(obj_type, file_index + 1, len(file_names)))
120 | segment_name = file_name.split('.')[0]
121 | data_loader = WaymoLoader(configs, [type_token], segment_name, data_folder, det_data_folder, start_frame)
122 | gt_bboxes, gt_ids = load_gt_bboxes(gt_folder, data_folder, segment_name, type_token)
123 |
124 | # real mot happens here
125 | ids, bboxes, states, types = sequence_mot(configs, data_loader, file_index, gt_bboxes, gt_ids, args.visualize)
126 |         np.savez_compressed(os.path.join(summary_folder, '{}.npz'.format(segment_name)),
127 |             ids=ids, bboxes=bboxes, states=states, types=types)
128 |
129 |
130 | if __name__ == '__main__':
131 | result_folder = os.path.join(args.result_folder, args.name)
132 | os.makedirs(result_folder, exist_ok=True)
133 |
134 | summary_folder = os.path.join(result_folder, 'summary')
135 | os.makedirs(summary_folder, exist_ok=True)
136 |
137 | summary_folder = os.path.join(summary_folder, args.obj_type)
138 | os.makedirs(summary_folder, exist_ok=True)
139 |
140 | det_data_folder = os.path.join(args.data_folder, 'detection', args.det_name)
141 |
142 | if args.process > 1:
143 | pool = multiprocessing.Pool(args.process)
144 | for token in range(args.process):
145 | result = pool.apply_async(main, args=(args.name, args.obj_type, args.config_path, args.data_folder, det_data_folder,
146 | result_folder, args.gt_folder, 0, token, args.process))
147 | pool.close()
148 | pool.join()
149 | else:
150 | main(args.name, args.obj_type, args.config_path, args.data_folder, det_data_folder, result_folder,
151 | args.gt_folder, args.start_frame, 0, 1)
152 |
153 |
--------------------------------------------------------------------------------
/tools/nuscenes_result_creation.py:
--------------------------------------------------------------------------------
1 | import os, argparse, json, numpy as np
2 | from pyquaternion import Quaternion
3 | from mot_3d.data_protos import BBox, Validity
4 | from tqdm import tqdm
5 |
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--name', type=str, default='debug')
9 | parser.add_argument('--obj_types', type=str, default='car,bus,trailer,truck,pedestrian,bicycle,motorcycle')
10 | parser.add_argument('--result_folder', type=str, default='../nu_mot_results/')
11 | parser.add_argument('--data_folder', type=str, default='../datasets/nuscenes/')
12 | args = parser.parse_args()
13 |
14 |
15 | def bbox_array2nuscenes_format(bbox_array):
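    # bbox_array layout (cf. BBox.bbox2array): [x, y, z, yaw, l, w, h, score];
    # nuScenes stores box size as [w, l, h], hence the swap below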
16 | translation = bbox_array[:3].tolist()
17 | size = bbox_array[4:7].tolist()
18 | size = [size[1], size[0], size[2]]
19 | velocity = [0.0, 0.0]
20 | score = bbox_array[-1]
21 |
22 | yaw = bbox_array[3]
23 | rot_matrix = np.array([[np.cos(yaw), -np.sin(yaw), 0, 0],
24 | [np.sin(yaw), np.cos(yaw), 0, 0],
25 | [0, 0, 1, 0],
26 |                            [0, 0, 0, 1]])
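    # (for a pure yaw rotation this is equivalent to
    #  Quaternion(axis=[0, 0, 1], radians=yaw))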
27 | q = Quaternion(matrix=rot_matrix)
28 | rotation = q.q.tolist()
29 |
30 | sample_result = {
31 | 'translation': translation,
32 | 'size': size,
33 | 'velocity': velocity,
34 | 'rotation': rotation,
35 | 'tracking_score': score
36 | }
37 | return sample_result
38 |
39 |
40 | def main(name, obj_types, data_folder, result_folder, output_folder):
41 | for obj_type in obj_types:
42 | print('CONVERTING {:}'.format(obj_type))
43 | summary_folder = os.path.join(result_folder, 'summary', obj_type)
44 | file_names = sorted(os.listdir(os.path.join(data_folder, 'ego_info')))
45 | token_info_folder = os.path.join(data_folder, 'token_info')
46 |
47 | results = dict()
48 | pbar = tqdm(total=len(file_names))
49 | for file_index, file_name in enumerate(file_names):
50 | segment_name = file_name.split('.')[0]
51 | token_info = json.load(open(os.path.join(token_info_folder, '{:}.json'.format(segment_name)), 'r'))
52 | mot_results = np.load(os.path.join(summary_folder, '{:}.npz'.format(segment_name)), allow_pickle=True)
53 |
54 | ids, bboxes, states, types = \
55 | mot_results['ids'], mot_results['bboxes'], mot_results['states'], mot_results['types']
56 | frame_num = len(ids)
57 | for frame_index in range(frame_num):
58 | sample_token = token_info[frame_index]
59 | results[sample_token] = list()
60 | frame_bboxes, frame_ids, frame_types, frame_states = \
61 | bboxes[frame_index], ids[frame_index], types[frame_index], states[frame_index]
62 |
63 | frame_obj_num = len(frame_ids)
64 | for i in range(frame_obj_num):
65 | sample_result = bbox_array2nuscenes_format(frame_bboxes[i])
66 | sample_result['sample_token'] = sample_token
67 | sample_result['tracking_id'] = frame_types[i] + '_' + str(frame_ids[i])
68 | sample_result['tracking_name'] = frame_types[i]
69 | results[sample_token].append(sample_result)
70 | pbar.update(1)
71 | pbar.close()
72 | submission_file = {
73 | 'meta': {
74 | 'use_camera': False, 'use_lidar': True, 'use_radar': False, 'use_map': False, 'use_external': False
75 | },
76 | 'results': results
77 | }
78 |
79 | f = open(os.path.join(output_folder, obj_type, 'results.json'), 'w')
80 | json.dump(submission_file, f)
81 | f.close()
82 | return
83 |
84 |
85 | if __name__ == '__main__':
86 | result_folder = os.path.join(args.result_folder, args.name)
87 | obj_types = args.obj_types.split(',')
88 | output_folder = os.path.join(result_folder, 'results')
89 | for obj_type in obj_types:
90 | tmp_output_folder = os.path.join(result_folder, 'results', obj_type)
91 | os.makedirs(tmp_output_folder, exist_ok=True)
92 |
93 | main(args.name, obj_types, args.data_folder, result_folder, output_folder)
94 |
--------------------------------------------------------------------------------
/tools/nuscenes_result_creation_10hz.py:
--------------------------------------------------------------------------------
1 | import os, argparse, json, numpy as np
2 | from pyquaternion import Quaternion
3 | from mot_3d.data_protos import BBox, Validity
4 | from tqdm import tqdm
5 |
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--name', type=str, default='debug')
9 | parser.add_argument('--obj_types', type=str, default='car,bus,trailer,truck,pedestrian,bicycle,motorcycle')
10 | parser.add_argument('--result_folder', type=str, default='../nu_mot_results/')
11 | parser.add_argument('--data_folder', type=str, default='../datasets/nuscenes/validation_20hz/')
12 | args = parser.parse_args()
13 |
14 |
15 | def bbox_array2nuscenes_format(bbox_array):
16 | translation = bbox_array[:3].tolist()
17 | size = bbox_array[4:7].tolist()
18 | size = [size[1], size[0], size[2]]
19 | velocity = [0.0, 0.0]
20 | score = bbox_array[-1]
21 |
22 | yaw = bbox_array[3]
23 | rot_matrix = np.array([[np.cos(yaw), -np.sin(yaw), 0, 0],
24 | [np.sin(yaw), np.cos(yaw), 0, 0],
25 | [0, 0, 1, 0],
26 |                            [0, 0, 0, 1]])
27 | q = Quaternion(matrix=rot_matrix)
28 | rotation = q.q.tolist()
29 |
30 | sample_result = {
31 | 'translation': translation,
32 | 'size': size,
33 | 'velocity': velocity,
34 | 'rotation': rotation,
35 | 'tracking_score': score
36 | }
37 | return sample_result
38 |
39 |
40 | def main(name, obj_types, data_folder, result_folder, output_folder):
41 | for obj_type in obj_types:
42 | print('CONVERTING {:}'.format(obj_type))
43 | summary_folder = os.path.join(result_folder, 'summary', obj_type)
44 | file_names = sorted(os.listdir(os.path.join(data_folder, 'ego_info')))
45 | token_info_folder = os.path.join(data_folder, 'token_info')
46 |
47 | results = dict()
48 | pbar = tqdm(total=len(file_names))
49 | for file_index, file_name in enumerate(file_names):
50 | segment_name = file_name.split('.')[0]
51 | token_info = json.load(open(os.path.join(token_info_folder, '{:}.json'.format(segment_name)), 'r'))
52 | mot_results = np.load(os.path.join(summary_folder, '{:}.npz'.format(segment_name)), allow_pickle=True)
53 |
54 | ids, bboxes, states, types = \
55 | mot_results['ids'], mot_results['bboxes'], mot_results['states'], mot_results['types']
56 | frame_num = len(ids)
57 | token_info = [t for t in token_info if t[3]]
58 | for frame_index in range(frame_num):
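                # tracking ran on every sweep, but the nuScenes benchmark only
                # evaluates the 2Hz key frames, so the rest are skipped here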
59 | frame_token = token_info[frame_index]
60 | is_key_frame = frame_token[1]
61 | if not is_key_frame:
62 | continue
63 |
64 | sample_token = frame_token[2]
65 | results[sample_token] = list()
66 | frame_bboxes, frame_ids, frame_types, frame_states = \
67 | bboxes[frame_index], ids[frame_index], types[frame_index], states[frame_index]
68 |
69 | frame_obj_num = len(frame_ids)
70 | for i in range(frame_obj_num):
71 | sample_result = bbox_array2nuscenes_format(frame_bboxes[i])
72 | sample_result['sample_token'] = sample_token
73 | sample_result['tracking_id'] = frame_types[i] + '_' + str(frame_ids[i])
74 | sample_result['tracking_name'] = frame_types[i]
75 | results[sample_token].append(sample_result)
76 | pbar.update(1)
77 | pbar.close()
78 | submission_file = {
79 | 'meta': {
80 | 'use_camera': False, 'use_lidar': True, 'use_radar': False, 'use_map': False, 'use_external': False
81 | },
82 | 'results': results
83 | }
84 |
85 | f = open(os.path.join(output_folder, obj_type, 'results.json'), 'w')
86 | json.dump(submission_file, f)
87 | f.close()
88 | return
89 |
90 |
91 | if __name__ == '__main__':
92 | result_folder = os.path.join(args.result_folder, args.name)
93 | obj_types = args.obj_types.split(',')
94 | output_folder = os.path.join(result_folder, 'results')
95 | for obj_type in obj_types:
96 | tmp_output_folder = os.path.join(result_folder, 'results', obj_type)
97 |         os.makedirs(tmp_output_folder, exist_ok=True)
99 |
100 | main(args.name, obj_types, args.data_folder, result_folder, output_folder)
101 |
--------------------------------------------------------------------------------
/tools/nuscenes_type_merge.py:
--------------------------------------------------------------------------------
1 | import os, argparse, json, numpy as np
2 | from pyquaternion import Quaternion
3 | from mot_3d.data_protos import BBox, Validity
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--name', type=str, default='debug')
7 | parser.add_argument('--obj_types', type=str, default='car,bus,trailer,truck,pedestrian,bicycle,motorcycle')
8 | parser.add_argument('--result_folder', type=str, default='../nu_mot_results/')
9 | args = parser.parse_args()
10 |
11 |
12 | def main(name, obj_types, result_folder):
13 | raw_results = list()
14 | for type_name in obj_types:
15 | path = os.path.join(result_folder, type_name, 'results.json')
16 | f = open(path, 'r')
17 | raw_results.append(json.load(f)['results'])
18 | f.close()
19 |
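    # every per-type results.json covers the same sample tokens, so merging
    # amounts to concatenating the per-type box lists token by token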
20 | results = raw_results[0]
21 | sample_tokens = list(results.keys())
22 | for token in sample_tokens:
23 | for i in range(1, len(obj_types)):
24 | results[token] += raw_results[i][token]
25 |
26 | submission_file = {
27 | 'meta': {
28 | 'use_camera': False, 'use_lidar': True, 'use_radar': False, 'use_map': False, 'use_external': False
29 | },
30 | 'results': results
31 | }
32 |
33 | f = open(os.path.join(result_folder, 'results.json'), 'w')
34 | json.dump(submission_file, f)
35 | f.close()
36 | return
37 |
38 |
39 | if __name__ == '__main__':
40 | result_folder = os.path.join(args.result_folder, args.name, 'results')
41 | obj_types = args.obj_types.split(',')
42 | main(args.name, obj_types, result_folder)
--------------------------------------------------------------------------------
/tools/waymo_pred_bin.py:
--------------------------------------------------------------------------------
1 | from waymo_open_dataset import dataset_pb2
2 | from waymo_open_dataset import label_pb2
3 | from waymo_open_dataset.protos import metrics_pb2
4 | import os, time, numpy as np, sys, pickle as pkl
5 | import argparse, json
6 | from copy import deepcopy
7 | from mot_3d.data_protos import BBox, Validity
8 | import mot_3d.utils as utils
9 | from tqdm import tqdm
10 |
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--name', type=str, default='baseline')
14 | parser.add_argument('--obj_types', type=str, default='vehicle,pedestrian,cyclist')
15 | parser.add_argument('--result_folder', type=str, default='../mot_results/')
16 | parser.add_argument('--data_folder', type=str, default='../datasets/waymo/mot/')
17 | args = parser.parse_args()
18 |
19 |
20 | def get_context_name(file_name: str):
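    # e.g. 'segment-1234567890_1234_000_with_camera_labels.npz' -> '1234567890_1234_000'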
21 | context = file_name.split('.')[0] # file name
22 | context = context.split('-')[1] # after segment
23 | context = context.split('w')[0] # before with
24 | context = context[:-1]
25 | return context
26 |
27 | def pred_content_filter(pred_contents, pred_states):
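    # keep only the outputs whose state passes Validity.valid, i.e. the
    # tracklet counts as alive in that frame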
28 | result_contents = list()
29 | for contents, states in zip(pred_contents, pred_states):
30 | indices = [i for i in range(len(states)) if Validity.valid(states[i])]
31 | frame_contents = [contents[i] for i in indices]
32 | result_contents.append(frame_contents)
33 | return result_contents
34 |
35 | def main(name, obj_type, result_folder, raw_data_folder, output_folder, output_file_name):
36 | summary_folder = os.path.join(result_folder, 'summary', obj_type)
37 | file_names = sorted(os.listdir(summary_folder))[:]
38 |
39 | if obj_type == 'vehicle':
40 | type_token = 1
41 | elif obj_type == 'pedestrian':
42 | type_token = 2
43 | elif obj_type == 'cyclist':
44 | type_token = 4
45 |
46 | ts_info_folder = os.path.join(raw_data_folder, 'ts_info')
47 | ego_info_folder = os.path.join(raw_data_folder, 'ego_info')
48 | obj_list = list()
49 |
50 | print('Converting TYPE {:} into WAYMO Format'.format(obj_type))
51 | pbar = tqdm(total=len(file_names))
52 | for file_index, file_name in enumerate(file_names[:]):
53 | file_name_prefix = file_name.split('.')[0]
54 | context_name = get_context_name(file_name)
55 |
56 | ts_path = os.path.join(ts_info_folder, '{}.json'.format(file_name_prefix))
57 | ts_data = json.load(open(ts_path, 'r')) # list of time stamps by order of frame
58 |
59 | # load ego motions
60 | ego_motions = np.load(os.path.join(ego_info_folder, '{:}.npz'.format(file_name_prefix)), allow_pickle=True)
61 |
62 | pred_result = np.load(os.path.join(summary_folder, file_name), allow_pickle=True)
63 | pred_ids, pred_bboxes, pred_states = pred_result['ids'], pred_result['bboxes'], pred_result['states']
64 | pred_bboxes = pred_content_filter(pred_bboxes, pred_states)
65 | pred_ids = pred_content_filter(pred_ids, pred_states)
66 | pred_velos, pred_accels = None, None
67 |
68 | obj_list += create_sequence(pred_ids, pred_bboxes, type_token, context_name,
69 | ts_data, ego_motions, pred_velos, pred_accels)
70 | pbar.update(1)
71 | pbar.close()
72 | objects = metrics_pb2.Objects()
73 | for obj in obj_list:
74 | objects.objects.append(obj)
75 |
76 | output_folder = os.path.join(output_folder, obj_type)
77 |     os.makedirs(output_folder, exist_ok=True)
79 | output_path = os.path.join(output_folder, '{:}.bin'.format(output_file_name))
80 | f = open(output_path, 'wb')
81 | f.write(objects.SerializeToString())
82 | f.close()
83 | return
84 |
85 |
86 | def create_single_pred_bbox(id, bbox, type_token, time_stamp, context_name, inv_ego_motion, velo, accel):
87 | o = metrics_pb2.Object()
88 | o.context_name = context_name
89 | o.frame_timestamp_micros = time_stamp
90 | box = label_pb2.Label.Box()
91 |
92 | proto_box = BBox.array2bbox(bbox)
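    # tracking results live in the world frame; the Waymo submission expects
    # vehicle-frame boxes, so apply the inverse ego pose here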
93 | proto_box = BBox.bbox2world(inv_ego_motion, proto_box)
94 | bbox = BBox.bbox2array(proto_box)
95 |
96 | box.center_x, box.center_y, box.center_z, box.heading = bbox[:4]
97 | box.length, box.width, box.height = bbox[4:7]
98 | o.object.box.CopyFrom(box)
99 | o.score = bbox[-1]
100 |
101 | meta_data = label_pb2.Label.Metadata()
102 | o.object.metadata.CopyFrom(meta_data)
103 |
104 | o.object.id = '{:}_{:}'.format(type_token, id)
105 | o.object.type = type_token
106 | return o
107 |
108 |
109 | def create_sequence(pred_ids, pred_bboxes, type_token, context_name, time_stamps, ego_motions, pred_velos, pred_accels):
110 | frame_num = len(pred_ids)
111 | sequence_objects = list()
112 | for frame_index in range(frame_num):
113 | time_stamp = time_stamps[frame_index]
114 | frame_obj_num = len(pred_ids[frame_index])
115 | ego_motion = ego_motions[str(frame_index)]
116 | inv_ego_motion = np.linalg.inv(ego_motion)
117 | for obj_index in range(frame_obj_num):
118 | pred_id = pred_ids[frame_index][obj_index]
119 | pred_bbox = pred_bboxes[frame_index][obj_index]
120 | pred_velo, pred_accel = None, None
121 | sequence_objects.append(create_single_pred_bbox(
122 | pred_id, pred_bbox, type_token, time_stamp, context_name, inv_ego_motion, pred_velo, pred_accel))
123 | return sequence_objects
124 |
125 |
126 | def merge_results(output_folder, obj_types, output_file_name):
127 | print('Merging different object types')
128 | result_objs = list()
129 | for obj_type in obj_types:
130 | bin_path = os.path.join(output_folder, obj_type, '{:}.bin'.format(output_file_name))
131 | f = open(bin_path, 'rb')
132 | objects = metrics_pb2.Objects()
133 | objects.ParseFromString(f.read())
134 | f.close()
135 | objects = objects.objects
136 | result_objs += objects
137 |
138 | output_objs = metrics_pb2.Objects()
139 | for obj in result_objs:
140 | output_objs.objects.append(obj)
141 |
142 | output_path = os.path.join(output_folder, '{:}.bin'.format(output_file_name))
143 | f = open(output_path, 'wb')
144 | f.write(output_objs.SerializeToString())
145 | f.close()
146 |
147 |
148 | if __name__ == '__main__':
149 | result_folder = os.path.join(args.result_folder, args.name)
150 | output_folder = os.path.join(result_folder, 'bin')
151 | os.makedirs(output_folder, exist_ok=True)
152 |
153 | obj_types = args.obj_types.split(',')
154 | for obj_type in obj_types:
155 | main(args.name, obj_type, result_folder, args.data_folder, output_folder, 'pred')
156 |
157 | merge_results(output_folder, obj_types, 'pred')
158 |
--------------------------------------------------------------------------------