├── .gitignore
├── ByteTrack_Convert2ONNX.ipynb
├── LICENSE
├── README.md
├── byte_tracker
│   ├── byte_tracker_onnx.py
│   ├── model
│   │   ├── bytetrack_s.onnx
│   │   └── bytetrack_s_mot17.pth.tar
│   ├── tracker
│   │   ├── basetrack.py
│   │   ├── byte_tracker.py
│   │   ├── kalman_filter.py
│   │   └── matching.py
│   └── utils
│       └── yolox_utils.py
├── demo_video_onnx.py
├── demo_webcam_onnx.py
├── requirements.txt
└── sample.mp4
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # bat
132 | *.bat
--------------------------------------------------------------------------------
/ByteTrack_Convert2ONNX.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "ByteTrack-Convert2ONNX.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "h7ZxXiORz5xD"
23 | },
24 | "source": [
25 | "# ByteTrack リポジトリクローン"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "metadata": {
31 | "colab": {
32 | "base_uri": "https://localhost:8080/"
33 | },
34 | "id": "ldhQVvDGz2Aa",
35 | "outputId": "65d61468-fd80-40fe-d212-c36c9e332e62"
36 | },
37 | "source": [
38 | "!git clone https://github.com/ifzhang/ByteTrack"
39 | ],
40 | "execution_count": 1,
41 | "outputs": [
42 | {
43 | "output_type": "stream",
44 | "name": "stdout",
45 | "text": [
46 | "Cloning into 'ByteTrack'...\n",
47 | "remote: Enumerating objects: 1857, done.\u001b[K\n",
48 | "remote: Counting objects: 100% (1171/1171), done.\u001b[K\n",
49 | "remote: Compressing objects: 100% (530/530), done.\u001b[K\n",
50 | "remote: Total 1857 (delta 666), reused 1080 (delta 613), pack-reused 686\u001b[K\n",
51 | "Receiving objects: 100% (1857/1857), 78.15 MiB | 26.27 MiB/s, done.\n",
52 | "Resolving deltas: 100% (1040/1040), done.\n"
53 | ]
54 | }
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "metadata": {
60 | "colab": {
61 | "base_uri": "https://localhost:8080/"
62 | },
63 | "id": "bZWU7zk00Aix",
64 | "outputId": "d7ef096d-d4be-442b-c4b3-6e466b93ab53"
65 | },
66 | "source": [
67 | "%cd ByteTrack"
68 | ],
69 | "execution_count": 2,
70 | "outputs": [
71 | {
72 | "output_type": "stream",
73 | "name": "stdout",
74 | "text": [
75 | "/content/ByteTrack\n"
76 | ]
77 | }
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "metadata": {
83 | "id": "L-eeDkfQ05jM"
84 | },
85 | "source": [
86 | "!mkdir pretrained"
87 | ],
88 | "execution_count": 3,
89 | "outputs": []
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "metadata": {
94 | "id": "XWCWW-rF4l-X"
95 | },
96 | "source": [
97 | "# 訓練済みモデルダウンロード"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "metadata": {
103 | "colab": {
104 | "base_uri": "https://localhost:8080/",
105 | "height": 375
106 | },
107 | "id": "JDw91Szr083y",
108 | "outputId": "dc228724-78c0-4c11-8148-f1b1859029c1"
109 | },
110 | "source": [
111 | "import gdown\n",
112 | "\n",
113 | "# bytetrack_s_mot17\n",
114 | "gdown.download('https://drive.google.com/uc?id=1uSmhXzyV1Zvb4TJJCzpsZOIcw7CCJLxj', 'bytetrack_s_mot17.pth.tar', quiet=False)\n",
115 | "# bytetrack_m_mot17 \n",
116 | "gdown.download('https://drive.google.com/uc?id=11Zb0NN_Uu7JwUd9e6Nk8o2_EUfxWqsun', 'bytetrack_m_mot17.pth.tar', quiet=False)\n",
117 | "# bytetrack_l_mot17\n",
118 | "gdown.download('https://drive.google.com/uc?id=1XwfUuCBF4IgWBWK2H7oOhQgEj9Mrb3rz', 'bytetrack_l_mot17.pth.tar', quiet=False)\n",
119 | "# bytetrack_x_mot17\n",
120 | "gdown.download('https://drive.google.com/uc?id=1P4mY0Yyd3PPTybgZkjMYhFri88nTmJX5', 'bytetrack_x_mot17.pth.tar', quiet=False)\n",
121 | "\n",
122 | "# bytetrack_x_mot20\n",
123 | "gdown.download('https://drive.google.com/uc?id=1HX2_JpMOjOIj1Z9rJjoet9XNy_cCAs5U', 'bytetrack_x_mot20.pth.tar', quiet=False)"
124 | ],
125 | "execution_count": 4,
126 | "outputs": [
127 | {
128 | "output_type": "stream",
129 | "name": "stderr",
130 | "text": [
131 | "Downloading...\n",
132 | "From: https://drive.google.com/uc?id=1uSmhXzyV1Zvb4TJJCzpsZOIcw7CCJLxj\n",
133 | "To: /content/ByteTrack/bytetrack_s_mot17.pth.tar\n",
134 | "100%|██████████| 71.8M/71.8M [00:00<00:00, 126MB/s]\n",
135 | "Downloading...\n",
136 | "From: https://drive.google.com/uc?id=11Zb0NN_Uu7JwUd9e6Nk8o2_EUfxWqsun\n",
137 | "To: /content/ByteTrack/bytetrack_m_mot17.pth.tar\n",
138 | "100%|██████████| 203M/203M [00:01<00:00, 163MB/s]\n",
139 | "Downloading...\n",
140 | "From: https://drive.google.com/uc?id=1XwfUuCBF4IgWBWK2H7oOhQgEj9Mrb3rz\n",
141 | "To: /content/ByteTrack/bytetrack_l_mot17.pth.tar\n",
142 | "100%|██████████| 434M/434M [00:02<00:00, 154MB/s]\n",
143 | "Downloading...\n",
144 | "From: https://drive.google.com/uc?id=1P4mY0Yyd3PPTybgZkjMYhFri88nTmJX5\n",
145 | "To: /content/ByteTrack/bytetrack_x_mot17.pth.tar\n",
146 | "100%|██████████| 793M/793M [00:06<00:00, 127MB/s]\n",
147 | "Downloading...\n",
148 | "From: https://drive.google.com/uc?id=1HX2_JpMOjOIj1Z9rJjoet9XNy_cCAs5U\n",
149 | "To: /content/ByteTrack/bytetrack_x_mot20.pth.tar\n",
150 | "100%|██████████| 793M/793M [00:12<00:00, 62.2MB/s]\n"
151 | ]
152 | },
153 | {
154 | "output_type": "execute_result",
155 | "data": {
156 | "application/vnd.google.colaboratory.intrinsic+json": {
157 | "type": "string"
158 | },
159 | "text/plain": [
160 | "'bytetrack_x_mot20.pth.tar'"
161 | ]
162 | },
163 | "metadata": {},
164 | "execution_count": 4
165 | }
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "metadata": {
171 | "id": "HNInyvrk1X68"
172 | },
173 | "source": [
174 | "!mv bytetrack_s_mot17.pth.tar pretrained\n",
175 | "!mv bytetrack_m_mot17.pth.tar pretrained\n",
176 | "!mv bytetrack_l_mot17.pth.tar pretrained\n",
177 | "!mv bytetrack_x_mot17.pth.tar pretrained\n",
178 | "!mv bytetrack_x_mot20.pth.tar pretrained"
179 | ],
180 | "execution_count": 5,
181 | "outputs": []
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "metadata": {
186 | "id": "_Eqotubg4te8"
187 | },
188 | "source": [
189 | "# 必要パッケージインストール"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "metadata": {
195 | "colab": {
196 | "base_uri": "https://localhost:8080/"
197 | },
198 | "id": "V8sgVre62vSp",
199 | "outputId": "1ea02318-668e-494e-e43a-c98312e0e884"
200 | },
201 | "source": [
202 | "!pip install loguru\n",
203 | "!pip install thop\n",
204 | "!pip install onnx\n",
205 | "!pip install onnx-simplifier"
206 | ],
207 | "execution_count": 6,
208 | "outputs": [
209 | {
210 | "output_type": "stream",
211 | "name": "stdout",
212 | "text": [
213 | "Collecting loguru\n",
214 | " Downloading loguru-0.5.3-py3-none-any.whl (57 kB)\n",
215 | "\u001b[?25l\r\u001b[K |█████▊ | 10 kB 23.5 MB/s eta 0:00:01\r\u001b[K |███████████▌ | 20 kB 10.0 MB/s eta 0:00:01\r\u001b[K |█████████████████▏ | 30 kB 8.1 MB/s eta 0:00:01\r\u001b[K |███████████████████████ | 40 kB 7.5 MB/s eta 0:00:01\r\u001b[K |████████████████████████████▋ | 51 kB 4.2 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 57 kB 2.5 MB/s \n",
216 | "\u001b[?25hInstalling collected packages: loguru\n",
217 | "Successfully installed loguru-0.5.3\n",
218 | "Collecting thop\n",
219 | " Downloading thop-0.0.31.post2005241907-py3-none-any.whl (8.7 kB)\n",
220 | "Requirement already satisfied: torch>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from thop) (1.10.0+cu111)\n",
221 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.0.0->thop) (3.10.0.2)\n",
222 | "Installing collected packages: thop\n",
223 | "Successfully installed thop-0.0.31.post2005241907\n",
224 | "Collecting onnx\n",
225 | " Downloading onnx-1.10.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (12.7 MB)\n",
226 | "\u001b[K |████████████████████████████████| 12.7 MB 4.3 MB/s \n",
227 | "\u001b[?25hRequirement already satisfied: numpy>=1.16.6 in /usr/local/lib/python3.7/dist-packages (from onnx) (1.19.5)\n",
228 | "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from onnx) (1.15.0)\n",
229 | "Requirement already satisfied: protobuf in /usr/local/lib/python3.7/dist-packages (from onnx) (3.17.3)\n",
230 | "Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.7/dist-packages (from onnx) (3.10.0.2)\n",
231 | "Installing collected packages: onnx\n",
232 | "Successfully installed onnx-1.10.2\n",
233 | "Collecting onnx-simplifier\n",
234 | " Downloading onnx-simplifier-0.3.6.tar.gz (13 kB)\n",
235 | "Requirement already satisfied: onnx in /usr/local/lib/python3.7/dist-packages (from onnx-simplifier) (1.10.2)\n",
236 | "Collecting onnxoptimizer>=0.2.5\n",
237 | " Downloading onnxoptimizer-0.2.6-cp37-cp37m-manylinux2014_x86_64.whl (466 kB)\n",
238 | "\u001b[K |████████████████████████████████| 466 kB 4.2 MB/s \n",
239 | "\u001b[?25hCollecting onnxruntime>=1.6.0\n",
240 | " Downloading onnxruntime-1.9.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.8 MB)\n",
241 | "\u001b[K |████████████████████████████████| 4.8 MB 32.7 MB/s \n",
242 | "\u001b[?25hRequirement already satisfied: protobuf>=3.7.0 in /usr/local/lib/python3.7/dist-packages (from onnx-simplifier) (3.17.3)\n",
243 | "Requirement already satisfied: numpy>=1.16.6 in /usr/local/lib/python3.7/dist-packages (from onnxruntime>=1.6.0->onnx-simplifier) (1.19.5)\n",
244 | "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.7/dist-packages (from onnxruntime>=1.6.0->onnx-simplifier) (2.0)\n",
245 | "Requirement already satisfied: six>=1.9 in /usr/local/lib/python3.7/dist-packages (from protobuf>=3.7.0->onnx-simplifier) (1.15.0)\n",
246 | "Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.7/dist-packages (from onnx->onnx-simplifier) (3.10.0.2)\n",
247 | "Building wheels for collected packages: onnx-simplifier\n",
248 | " Building wheel for onnx-simplifier (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
249 | " Created wheel for onnx-simplifier: filename=onnx_simplifier-0.3.6-py3-none-any.whl size=12873 sha256=486587eb3cf7d3fbf4358732c9239be2ed04fe9e6d58de51c5f93a96bd45bfb2\n",
250 | " Stored in directory: /root/.cache/pip/wheels/0c/47/80/8eb21098e22c19d60b1c14021ee67442b4ad2d7991fdad46ba\n",
251 | "Successfully built onnx-simplifier\n",
252 | "Installing collected packages: onnxruntime, onnxoptimizer, onnx-simplifier\n",
253 | "Successfully installed onnx-simplifier-0.3.6 onnxoptimizer-0.2.6 onnxruntime-1.9.0\n"
254 | ]
255 | }
256 | ]
257 | },
258 | {
259 | "cell_type": "markdown",
260 | "metadata": {
261 | "id": "c_VaJ7x_4xRf"
262 | },
263 | "source": [
264 | "# ONNXエクスポート"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "metadata": {
270 | "id": "IoyLhN-v3wtN"
271 | },
272 | "source": [
273 | "!cp tools/export_onnx.py ./"
274 | ],
275 | "execution_count": 7,
276 | "outputs": []
277 | },
278 | {
279 | "cell_type": "code",
280 | "metadata": {
281 | "colab": {
282 | "base_uri": "https://localhost:8080/"
283 | },
284 | "id": "AoWy7DGg0La7",
285 | "outputId": "d78f03bb-7e19-4d6c-f2db-8eeb981d31f7"
286 | },
287 | "source": [
288 | "!python export_onnx.py \\\n",
289 | " --output-name bytetrack_s.onnx \\\n",
290 | " -f exps/example/mot/yolox_s_mix_det.py \\\n",
291 | " -c pretrained/bytetrack_s_mot17.pth.tar"
292 | ],
293 | "execution_count": 8,
294 | "outputs": [
295 | {
296 | "output_type": "stream",
297 | "name": "stdout",
298 | "text": [
299 | "\u001b[32m2021-11-18 02:34:16.311\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mmain\u001b[0m:\u001b[36m52\u001b[0m - \u001b[1margs value: Namespace(ckpt='pretrained/bytetrack_s_mot17.pth.tar', exp_file='exps/example/mot/yolox_s_mix_det.py', experiment_name=None, input='images', name=None, no_onnxsim=False, opset=11, opts=[], output='output', output_name='bytetrack_s.onnx')\u001b[0m\n",
300 | "\u001b[32m2021-11-18 02:34:17.162\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mmain\u001b[0m:\u001b[36m76\u001b[0m - \u001b[1mloading checkpoint done.\u001b[0m\n",
301 | "\u001b[32m2021-11-18 02:34:26.008\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mmain\u001b[0m:\u001b[36m86\u001b[0m - \u001b[1mgenerated onnx model named bytetrack_s.onnx\u001b[0m\n",
302 | "\u001b[32m2021-11-18 02:34:29.160\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mmain\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1mgenerated simplified onnx model named bytetrack_s.onnx\u001b[0m\n"
303 | ]
304 | }
305 | ]
306 | }
307 | ]
308 | }
--------------------------------------------------------------------------------
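The notebook above exports the "s" model. A hedged sketch for exporting one of the other downloaded checkpoints, assuming the matching experiment file (here exps/example/mot/yolox_m_mix_det.py) exists in the cloned ByteTrack repository:

```bash
# Assumed experiment file name; adjust to the checkpoint being exported.
python export_onnx.py \
    --output-name bytetrack_m.onnx \
    -f exps/example/mot/yolox_m_mix_det.py \
    -c pretrained/bytetrack_m_mot17.pth.tar
```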
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 KazuhitoTakahashi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | > **Warning**
2 | > This repository simply converts the [ifzhang/ByteTrack](https://github.com/ifzhang/ByteTrack) model files to ONNX and tidies up the source code.
3 | > Unless you have a specific reason to use it, the more general-purpose [yolox-bytetrack-sample](https://github.com/Kazuhito00/yolox-bytetrack-sample) and [yolox-bytetrack-mcmot-sample](https://github.com/Kazuhito00/yolox-bytetrack-mcmot-sample) are recommended instead.
4 | 
5 | # ByteTrack-ONNX-Sample
6 | This is a Python ONNX inference sample for [ByteTrack (Multi-Object Tracking by Associating Every Detection Box)](https://github.com/ifzhang/ByteTrack).
7 | The model converted to ONNX is also included.
8 | If you want to try the conversion yourself, use [ByteTrack_Convert2ONNX.ipynb](ByteTrack_Convert2ONNX.ipynb).
9 | [ByteTrack_Convert2ONNX.ipynb](ByteTrack_Convert2ONNX.ipynb) is intended to be run on Colaboratory.
10 | The following video is an example of running on Windows.
11 |
12 | https://user-images.githubusercontent.com/37477845/142617492-7fef3f6e-5725-480c-b059-0f2dee1606bc.mp4
13 |
14 | # Requirement
15 | opencv-python 4.5.3.56 or later
16 | onnx 1.9.0 or later
17 | onnxruntime-gpu 1.9.0 or later
18 | Cython 0.29.24 or later
19 | torch 1.8.1 or later
20 | torchvision 0.9.1 or later
21 | pycocotools 2.0.2 or later
22 | scipy 1.6.3 or later
23 | loguru 0.5.3 or later
24 | thop 0.0.31.post2005241907 or later
25 | lap 0.4.0 or later
26 | cython_bbox 0.1.3 or later
27 |
28 | ※The sample also runs with the CPU onnxruntime package instead of onnxruntime-gpu, but inference takes longer, so a GPU is recommended
29 | ※If installing cython_bbox fails on Windows, try installing it from GitHub (as of 2021/11/19)
30 | `pip install -e git+https://github.com/samson-wang/cython_bbox.git#egg=cython-bbox`
31 |
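A minimal install sketch (not part of the original README), assuming the packages above are pinned in the bundled requirements.txt; the GitHub install of cython_bbox is only needed as the Windows fallback described above:

```bash
pip install -r requirements.txt
# Windows fallback for cython_bbox (see the note above)
pip install -e git+https://github.com/samson-wang/cython_bbox.git#egg=cython-bbox
```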
32 | # Demo
33 | Run the demos as follows.
34 | #### Video: tracks people in a video with ByteTrack and writes the result to a video file
35 | ```bash
36 | python demo_video_onnx.py
37 | ```
38 | 
39 | Command-line options (an example invocation is shown after this list)
40 | 
41 | * --use_debug_window
42 | Whether to also display each frame in a GUI window while writing the video
43 | Default: not specified
44 | * --model
45 | Path to the ByteTrack ONNX model
46 | Default: byte_tracker/model/bytetrack_s.onnx
47 | * --video
48 | Path to the input video
49 | Default: sample.mp4
50 | * --output_dir
51 | Output directory for the video
52 | Default: output
53 | * --score_th
54 | Score threshold for person detection
55 | Default: 0.1
56 | * --nms_th
57 | NMS threshold for person detection
58 | Default: 0.7
59 | * --input_shape
60 | Input size at inference time
61 | Default: 608,1088
62 | * --with_p6
63 | Whether the YOLOX model uses p6 in its FPN/PAN
64 | Default: not specified
65 | * --track_thresh
66 | Score threshold for tracking
67 | Default: 0.5
68 | * --track_buffer
69 | Number of frames a lost target is kept before it is removed
70 | Default: 30
71 | * --match_thresh
72 | Matching score threshold for tracking
73 | Default: 0.8
74 | * --min-box-area
75 | Minimum bounding box area threshold
76 | Default: 10
77 | * --mot20
78 | Whether a MOT20 model is used
79 | Default: not specified
80 |
81 |
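An example invocation overriding a few of the options above (a sketch, not part of the original README; the paths shown are the repository defaults):

```bash
python demo_video_onnx.py \
    --model byte_tracker/model/bytetrack_s.onnx \
    --video sample.mp4 \
    --output_dir output \
    --score_th 0.1 \
    --track_thresh 0.5 \
    --use_debug_window
```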
82 | #### Webcam: tracks people in webcam images with ByteTrack and shows the result in a GUI window
83 | ```bash
84 | python demo_webcam_onnx.py
85 | ```
86 | 
87 | Command-line options (an example invocation is shown after this list)
88 | 
89 | * --model
90 | Path to the ByteTrack ONNX model
91 | Default: byte_tracker/model/bytetrack_s.onnx
92 | * --device
93 | Camera device number
94 | Default: 0
95 | * --width
96 | Capture width
97 | Default: 960
98 | * --height
99 | Capture height
100 | Default: 540
101 | * --score_th
102 | Score threshold for person detection
103 | Default: 0.1
104 | * --nms_th
105 | NMS threshold for person detection
106 | Default: 0.7
107 | * --input_shape
108 | Input size at inference time
109 | Default: 608,1088
110 | * --with_p6
111 | Whether the YOLOX model uses p6 in its FPN/PAN
112 | Default: not specified
113 | * --track_thresh
114 | Score threshold for tracking
115 | Default: 0.5
116 | * --track_buffer
117 | Number of frames a lost target is kept before it is removed
118 | Default: 30
119 | * --match_thresh
120 | Matching score threshold for tracking
121 | Default: 0.8
122 | * --min-box-area
123 | Minimum bounding box area threshold
124 | Default: 10
125 | * --mot20
126 | Whether a MOT20 model is used
127 | Default: not specified
128 |
129 |
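An example invocation (a sketch, not part of the original README) capturing from camera device 0 at 960x540:

```bash
python demo_webcam_onnx.py \
    --device 0 \
    --width 960 \
    --height 540 \
    --score_th 0.1 \
    --track_thresh 0.5
```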
130 | # Reference
131 | * [ifzhang/ByteTrack](https://github.com/ifzhang/ByteTrack)
132 |
133 | # Author
134 | Kazuhito Takahashi (https://twitter.com/KzhtTkhs)
135 |
136 | # License
137 | ByteTrack-ONNX-Sample is under [MIT License](LICENSE).
138 |
139 | # License(Movie)
140 | The sample video is [the statue of Elgar in Worcester, England](https://www2.nhk.or.jp/archives/creative/material/view.cgi?m=D0002011239_00000) from the [NHK Creative Library](https://www.nhk.or.jp/archives/creative/).
141 |
--------------------------------------------------------------------------------
/byte_tracker/byte_tracker_onnx.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import copy
4 |
5 | import numpy as np
6 | import onnxruntime
7 |
8 | from byte_tracker.utils.yolox_utils import (
9 | pre_process,
10 | post_process,
11 | multiclass_nms,
12 | )
13 | from byte_tracker.tracker.byte_tracker import BYTETracker
14 |
15 |
16 | class ByteTrackerONNX(object):
17 | def __init__(self, args):
18 | self.args = args
19 |
20 | self.rgb_means = (0.485, 0.456, 0.406)
21 | self.std = (0.229, 0.224, 0.225)
22 |
23 | # Load the ByteTrack ONNX model
24 | self.session = onnxruntime.InferenceSession(args.model)
25 | self.input_shape = tuple(map(int, args.input_shape.split(',')))
26 |
27 | # Create the BYTETracker instance
28 | self.tracker = BYTETracker(args, frame_rate=30)
29 |
30 | def _pre_process(self, image):
31 | image_info = {'id': 0}
32 |
33 | image_info['image'] = copy.deepcopy(image)
34 | image_info['width'] = image.shape[1]
35 | image_info['height'] = image.shape[0]
36 |
37 | preprocessed_image, ratio = pre_process(
38 | image,
39 | self.input_shape,
40 | self.rgb_means,
41 | self.std,
42 | )
43 | image_info['ratio'] = ratio
44 |
45 | return preprocessed_image, image_info
46 |
47 | def inference(self, image):
48 | # Pre-processing
49 | image, image_info = self._pre_process(image)
50 |
51 | # Inference
52 | input_name = self.session.get_inputs()[0].name
53 | result = self.session.run(None, {input_name: image[None, :, :, :]})
54 |
55 | # Post-processing
56 | dets = self._post_process(result, image_info)
57 |
58 | # Update the tracker
59 | bboxes, ids, scores = self._tracker_update(
60 | dets,
61 | image_info,
62 | )
63 |
64 | return image_info, bboxes, ids, scores
65 |
66 | def _post_process(self, result, image_info):
67 | # Compute bounding boxes and scores
68 | predictions = post_process(
69 | result[0],
70 | self.input_shape,
71 | p6=self.args.with_p6,
72 | )
73 | predictions = predictions[0]
74 | boxes = predictions[:, :4]
75 | scores = predictions[:, 4:5] * predictions[:, 5:]
76 |
77 | # Convert relative coordinates to absolute coordinates
78 | boxes_xyxy = np.ones_like(boxes)
79 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
80 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
81 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
82 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
83 | boxes_xyxy /= image_info['ratio']
84 |
85 | # Apply NMS
86 | dets = multiclass_nms(
87 | boxes_xyxy,
88 | scores,
89 | nms_thr=self.args.nms_th,
90 | score_thr=self.args.score_th,
91 | )
92 |
93 | return dets
94 |
95 | def _tracker_update(self, dets, image_info):
96 | # Update the tracker
97 | online_targets = []
98 | if dets is not None:
99 | online_targets = self.tracker.update(
100 | dets[:, :-1],
101 | [image_info['height'], image_info['width']],
102 | [image_info['height'], image_info['width']],
103 | )
104 |
105 | online_tlwhs = []
106 | online_ids = []
107 | online_scores = []
108 | for online_target in online_targets:
109 | tlwh = online_target.tlwh
110 | track_id = online_target.track_id
111 | vertical = tlwh[2] / tlwh[3] > 1.6
112 | if tlwh[2] * tlwh[3] > self.args.min_box_area and not vertical:
113 | online_tlwhs.append(tlwh)
114 | online_ids.append(track_id)
115 | online_scores.append(online_target.score)
116 |
117 | return online_tlwhs, online_ids, online_scores
118 |
--------------------------------------------------------------------------------
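Below is a minimal usage sketch (not part of the repository) showing how ByteTrackerONNX can be driven directly from Python. The demo scripts build the same arguments with argparse; here a SimpleNamespace stands in for them, and the default model/video paths from the README are assumed:

```python
from types import SimpleNamespace

import cv2

from byte_tracker.byte_tracker_onnx import ByteTrackerONNX

# The attribute names mirror the argparse options used by the demo scripts.
args = SimpleNamespace(
    model='byte_tracker/model/bytetrack_s.onnx',
    input_shape='608,1088',
    with_p6=False,
    score_th=0.1,
    nms_th=0.7,
    track_thresh=0.5,
    track_buffer=30,
    match_thresh=0.8,
    min_box_area=10,
    mot20=False,
)

tracker = ByteTrackerONNX(args)

cap = cv2.VideoCapture('sample.mp4')
ret, frame = cap.read()
if ret:
    # bboxes are (top-left x, top-left y, width, height) per tracked person
    image_info, bboxes, ids, scores = tracker.inference(frame)
    print(len(bboxes), 'tracks in the first frame')
cap.release()
```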
/byte_tracker/model/bytetrack_s.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kazuhito00/ByteTrack-ONNX-Sample/70fd73903decc2071213c181398eb8de23cc2b8e/byte_tracker/model/bytetrack_s.onnx
--------------------------------------------------------------------------------
/byte_tracker/model/bytetrack_s_mot17.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kazuhito00/ByteTrack-ONNX-Sample/70fd73903decc2071213c181398eb8de23cc2b8e/byte_tracker/model/bytetrack_s_mot17.pth.tar
--------------------------------------------------------------------------------
/byte_tracker/tracker/basetrack.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import OrderedDict
3 |
4 |
5 | class TrackState(object):
6 | New = 0
7 | Tracked = 1
8 | Lost = 2
9 | Removed = 3
10 |
11 |
12 | class BaseTrack(object):
13 | _count = 0
14 |
15 | track_id = 0
16 | is_activated = False
17 | state = TrackState.New
18 |
19 | history = OrderedDict()
20 | features = []
21 | curr_feature = None
22 | score = 0
23 | start_frame = 0
24 | frame_id = 0
25 | time_since_update = 0
26 |
27 | # multi-camera
28 | location = (np.inf, np.inf)
29 |
30 | @property
31 | def end_frame(self):
32 | return self.frame_id
33 |
34 | @staticmethod
35 | def next_id():
36 | BaseTrack._count += 1
37 | return BaseTrack._count
38 |
39 | def activate(self, *args):
40 | raise NotImplementedError
41 |
42 | def predict(self):
43 | raise NotImplementedError
44 |
45 | def update(self, *args, **kwargs):
46 | raise NotImplementedError
47 |
48 | def mark_lost(self):
49 | self.state = TrackState.Lost
50 |
51 | def mark_removed(self):
52 | self.state = TrackState.Removed
--------------------------------------------------------------------------------
/byte_tracker/tracker/byte_tracker.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import deque
3 | import os
4 | import os.path as osp
5 | import copy
6 | import torch
7 | import torch.nn.functional as F
8 |
9 | from .kalman_filter import KalmanFilter
10 | from byte_tracker.tracker import matching
11 | from .basetrack import BaseTrack, TrackState
12 |
13 | class STrack(BaseTrack):
14 | shared_kalman = KalmanFilter()
15 | def __init__(self, tlwh, score):
16 |
17 | # wait activate
18 | self._tlwh = np.asarray(tlwh, dtype=float)
19 | self.kalman_filter = None
20 | self.mean, self.covariance = None, None
21 | self.is_activated = False
22 |
23 | self.score = score
24 | self.tracklet_len = 0
25 |
26 | def predict(self):
27 | mean_state = self.mean.copy()
28 | if self.state != TrackState.Tracked:
29 | mean_state[7] = 0
30 | self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
31 |
32 | @staticmethod
33 | def multi_predict(stracks):
34 | if len(stracks) > 0:
35 | multi_mean = np.asarray([st.mean.copy() for st in stracks])
36 | multi_covariance = np.asarray([st.covariance for st in stracks])
37 | for i, st in enumerate(stracks):
38 | if st.state != TrackState.Tracked:
39 | multi_mean[i][7] = 0
40 | multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
41 | for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
42 | stracks[i].mean = mean
43 | stracks[i].covariance = cov
44 |
45 | def activate(self, kalman_filter, frame_id):
46 | """Start a new tracklet"""
47 | self.kalman_filter = kalman_filter
48 | self.track_id = self.next_id()
49 | self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))
50 |
51 | self.tracklet_len = 0
52 | self.state = TrackState.Tracked
53 | if frame_id == 1:
54 | self.is_activated = True
55 | # self.is_activated = True
56 | self.frame_id = frame_id
57 | self.start_frame = frame_id
58 |
59 | def re_activate(self, new_track, frame_id, new_id=False):
60 | self.mean, self.covariance = self.kalman_filter.update(
61 | self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
62 | )
63 | self.tracklet_len = 0
64 | self.state = TrackState.Tracked
65 | self.is_activated = True
66 | self.frame_id = frame_id
67 | if new_id:
68 | self.track_id = self.next_id()
69 | self.score = new_track.score
70 |
71 | def update(self, new_track, frame_id):
72 | """
73 | Update a matched track
74 | :type new_track: STrack
75 | :type frame_id: int
76 | :type update_feature: bool
77 | :return:
78 | """
79 | self.frame_id = frame_id
80 | self.tracklet_len += 1
81 |
82 | new_tlwh = new_track.tlwh
83 | self.mean, self.covariance = self.kalman_filter.update(
84 | self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
85 | self.state = TrackState.Tracked
86 | self.is_activated = True
87 |
88 | self.score = new_track.score
89 |
90 | @property
91 | # @jit(nopython=True)
92 | def tlwh(self):
93 | """Get current position in bounding box format `(top left x, top left y,
94 | width, height)`.
95 | """
96 | if self.mean is None:
97 | return self._tlwh.copy()
98 | ret = self.mean[:4].copy()
99 | ret[2] *= ret[3]
100 | ret[:2] -= ret[2:] / 2
101 | return ret
102 |
103 | @property
104 | # @jit(nopython=True)
105 | def tlbr(self):
106 | """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
107 | `(top left, bottom right)`.
108 | """
109 | ret = self.tlwh.copy()
110 | ret[2:] += ret[:2]
111 | return ret
112 |
113 | @staticmethod
114 | # @jit(nopython=True)
115 | def tlwh_to_xyah(tlwh):
116 | """Convert bounding box to format `(center x, center y, aspect ratio,
117 | height)`, where the aspect ratio is `width / height`.
118 | """
119 | ret = np.asarray(tlwh).copy()
120 | ret[:2] += ret[2:] / 2
121 | ret[2] /= ret[3]
122 | return ret
123 |
124 | def to_xyah(self):
125 | return self.tlwh_to_xyah(self.tlwh)
126 |
127 | @staticmethod
128 | # @jit(nopython=True)
129 | def tlbr_to_tlwh(tlbr):
130 | ret = np.asarray(tlbr).copy()
131 | ret[2:] -= ret[:2]
132 | return ret
133 |
134 | @staticmethod
135 | # @jit(nopython=True)
136 | def tlwh_to_tlbr(tlwh):
137 | ret = np.asarray(tlwh).copy()
138 | ret[2:] += ret[:2]
139 | return ret
140 |
141 | def __repr__(self):
142 | return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
143 |
144 |
145 | class BYTETracker(object):
146 | def __init__(self, args, frame_rate=30):
147 | self.tracked_stracks = [] # type: list[STrack]
148 | self.lost_stracks = [] # type: list[STrack]
149 | self.removed_stracks = [] # type: list[STrack]
150 |
151 | self.frame_id = 0
152 | self.args = args
153 | #self.det_thresh = args.track_thresh
154 | self.det_thresh = args.track_thresh + 0.1
155 | self.buffer_size = int(frame_rate / 30.0 * args.track_buffer)
156 | self.max_time_lost = self.buffer_size
157 | self.kalman_filter = KalmanFilter()
158 |
159 | def update(self, output_results, img_info, img_size):
160 | self.frame_id += 1
161 | activated_starcks = []
162 | refind_stracks = []
163 | lost_stracks = []
164 | removed_stracks = []
165 |
166 | if output_results.shape[1] == 5:
167 | scores = output_results[:, 4]
168 | bboxes = output_results[:, :4]
169 | else:
170 | output_results = output_results.cpu().numpy()
171 | scores = output_results[:, 4] * output_results[:, 5]
172 | bboxes = output_results[:, :4] # x1y1x2y2
173 | img_h, img_w = img_info[0], img_info[1]
174 | scale = min(img_size[0] / float(img_h), img_size[1] / float(img_w))
175 | bboxes /= scale
176 |
177 | remain_inds = scores > self.args.track_thresh
178 | inds_low = scores > 0.1
179 | inds_high = scores < self.args.track_thresh
180 |
181 | inds_second = np.logical_and(inds_low, inds_high)
182 | dets_second = bboxes[inds_second]
183 | dets = bboxes[remain_inds]
184 | scores_keep = scores[remain_inds]
185 | scores_second = scores[inds_second]
186 |
187 | if len(dets) > 0:
188 | '''Detections'''
189 | detections = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for
190 | (tlbr, s) in zip(dets, scores_keep)]
191 | else:
192 | detections = []
193 |
194 | ''' Add newly detected tracklets to tracked_stracks'''
195 | unconfirmed = []
196 | tracked_stracks = [] # type: list[STrack]
197 | for track in self.tracked_stracks:
198 | if not track.is_activated:
199 | unconfirmed.append(track)
200 | else:
201 | tracked_stracks.append(track)
202 |
203 | ''' Step 2: First association, with high score detection boxes'''
204 | strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
205 | # Predict the current location with KF
206 | STrack.multi_predict(strack_pool)
207 | dists = matching.iou_distance(strack_pool, detections)
208 | if not self.args.mot20:
209 | dists = matching.fuse_score(dists, detections)
210 | matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh)
211 |
212 | for itracked, idet in matches:
213 | track = strack_pool[itracked]
214 | det = detections[idet]
215 | if track.state == TrackState.Tracked:
216 | track.update(detections[idet], self.frame_id)
217 | activated_starcks.append(track)
218 | else:
219 | track.re_activate(det, self.frame_id, new_id=False)
220 | refind_stracks.append(track)
221 |
222 | ''' Step 3: Second association, with low score detection boxes'''
223 | # association the untrack to the low score detections
224 | if len(dets_second) > 0:
225 | '''Detections'''
226 | detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for
227 | (tlbr, s) in zip(dets_second, scores_second)]
228 | else:
229 | detections_second = []
230 | r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
231 | dists = matching.iou_distance(r_tracked_stracks, detections_second)
232 | matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
233 | for itracked, idet in matches:
234 | track = r_tracked_stracks[itracked]
235 | det = detections_second[idet]
236 | if track.state == TrackState.Tracked:
237 | track.update(det, self.frame_id)
238 | activated_starcks.append(track)
239 | else:
240 | track.re_activate(det, self.frame_id, new_id=False)
241 | refind_stracks.append(track)
242 |
243 | for it in u_track:
244 | track = r_tracked_stracks[it]
245 | if not track.state == TrackState.Lost:
246 | track.mark_lost()
247 | lost_stracks.append(track)
248 |
249 | '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
250 | detections = [detections[i] for i in u_detection]
251 | dists = matching.iou_distance(unconfirmed, detections)
252 | if not self.args.mot20:
253 | dists = matching.fuse_score(dists, detections)
254 | matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
255 | for itracked, idet in matches:
256 | unconfirmed[itracked].update(detections[idet], self.frame_id)
257 | activated_starcks.append(unconfirmed[itracked])
258 | for it in u_unconfirmed:
259 | track = unconfirmed[it]
260 | track.mark_removed()
261 | removed_stracks.append(track)
262 |
263 | """ Step 4: Init new stracks"""
264 | for inew in u_detection:
265 | track = detections[inew]
266 | if track.score < self.det_thresh:
267 | continue
268 | track.activate(self.kalman_filter, self.frame_id)
269 | activated_starcks.append(track)
270 | """ Step 5: Update state"""
271 | for track in self.lost_stracks:
272 | if self.frame_id - track.end_frame > self.max_time_lost:
273 | track.mark_removed()
274 | removed_stracks.append(track)
275 |
276 | # print('Ramained match {} s'.format(t4-t3))
277 |
278 | self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
279 | self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
280 | self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
281 | self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
282 | self.lost_stracks.extend(lost_stracks)
283 | self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
284 | self.removed_stracks.extend(removed_stracks)
285 | self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
286 | # get scores of lost tracks
287 | output_stracks = [track for track in self.tracked_stracks if track.is_activated]
288 |
289 | return output_stracks
290 |
291 |
292 | def joint_stracks(tlista, tlistb):
293 | exists = {}
294 | res = []
295 | for t in tlista:
296 | exists[t.track_id] = 1
297 | res.append(t)
298 | for t in tlistb:
299 | tid = t.track_id
300 | if not exists.get(tid, 0):
301 | exists[tid] = 1
302 | res.append(t)
303 | return res
304 |
305 |
306 | def sub_stracks(tlista, tlistb):
307 | stracks = {}
308 | for t in tlista:
309 | stracks[t.track_id] = t
310 | for t in tlistb:
311 | tid = t.track_id
312 | if stracks.get(tid, 0):
313 | del stracks[tid]
314 | return list(stracks.values())
315 |
316 |
317 | def remove_duplicate_stracks(stracksa, stracksb):
318 | pdist = matching.iou_distance(stracksa, stracksb)
319 | pairs = np.where(pdist < 0.15)
320 | dupa, dupb = list(), list()
321 | for p, q in zip(*pairs):
322 | timep = stracksa[p].frame_id - stracksa[p].start_frame
323 | timeq = stracksb[q].frame_id - stracksb[q].start_frame
324 | if timep > timeq:
325 | dupb.append(q)
326 | else:
327 | dupa.append(p)
328 | resa = [t for i, t in enumerate(stracksa) if not i in dupa]
329 | resb = [t for i, t in enumerate(stracksb) if not i in dupb]
330 | return resa, resb
331 |
--------------------------------------------------------------------------------
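A short sketch (not from the repository) of how BYTETracker.update() is fed when the detections are already an (N, 5) NumPy array of [x1, y1, x2, y2, score] in image coordinates, which is the branch ByteTrackerONNX uses; the args namespace is a stand-in for the demo scripts' argparse options:

```python
from types import SimpleNamespace

import numpy as np

from byte_tracker.tracker.byte_tracker import BYTETracker

args = SimpleNamespace(track_thresh=0.5, track_buffer=30, match_thresh=0.8, mot20=False)
tracker = BYTETracker(args, frame_rate=30)

# One high-score and one low-score detection: [x1, y1, x2, y2, score]
dets = np.array([[100., 120., 180., 360., 0.92],
                 [400., 100., 470., 330., 0.35]])
img_h, img_w = 540, 960
online_targets = tracker.update(dets, [img_h, img_w], [img_h, img_w])
for t in online_targets:
    print(t.track_id, t.tlwh, t.score)
```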
/byte_tracker/tracker/kalman_filter.py:
--------------------------------------------------------------------------------
1 | # vim: expandtab:ts=4:sw=4
2 | import numpy as np
3 | import scipy.linalg
4 |
5 |
6 | """
7 | Table for the 0.95 quantile of the chi-square distribution with N degrees of
8 | freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
9 | function and used as Mahalanobis gating threshold.
10 | """
11 | chi2inv95 = {
12 | 1: 3.8415,
13 | 2: 5.9915,
14 | 3: 7.8147,
15 | 4: 9.4877,
16 | 5: 11.070,
17 | 6: 12.592,
18 | 7: 14.067,
19 | 8: 15.507,
20 | 9: 16.919}
21 |
22 |
23 | class KalmanFilter(object):
24 | """
25 | A simple Kalman filter for tracking bounding boxes in image space.
26 |
27 | The 8-dimensional state space
28 |
29 | x, y, a, h, vx, vy, va, vh
30 |
31 | contains the bounding box center position (x, y), aspect ratio a, height h,
32 | and their respective velocities.
33 |
34 | Object motion follows a constant velocity model. The bounding box location
35 | (x, y, a, h) is taken as direct observation of the state space (linear
36 | observation model).
37 |
38 | """
39 |
40 | def __init__(self):
41 | ndim, dt = 4, 1.
42 |
43 | # Create Kalman filter model matrices.
44 | self._motion_mat = np.eye(2 * ndim, 2 * ndim)
45 | for i in range(ndim):
46 | self._motion_mat[i, ndim + i] = dt
47 | self._update_mat = np.eye(ndim, 2 * ndim)
48 |
49 | # Motion and observation uncertainty are chosen relative to the current
50 | # state estimate. These weights control the amount of uncertainty in
51 | # the model. This is a bit hacky.
52 | self._std_weight_position = 1. / 20
53 | self._std_weight_velocity = 1. / 160
54 |
55 | def initiate(self, measurement):
56 | """Create track from unassociated measurement.
57 |
58 | Parameters
59 | ----------
60 | measurement : ndarray
61 | Bounding box coordinates (x, y, a, h) with center position (x, y),
62 | aspect ratio a, and height h.
63 |
64 | Returns
65 | -------
66 | (ndarray, ndarray)
67 | Returns the mean vector (8 dimensional) and covariance matrix (8x8
68 | dimensional) of the new track. Unobserved velocities are initialized
69 | to 0 mean.
70 |
71 | """
72 | mean_pos = measurement
73 | mean_vel = np.zeros_like(mean_pos)
74 | mean = np.r_[mean_pos, mean_vel]
75 |
76 | std = [
77 | 2 * self._std_weight_position * measurement[3],
78 | 2 * self._std_weight_position * measurement[3],
79 | 1e-2,
80 | 2 * self._std_weight_position * measurement[3],
81 | 10 * self._std_weight_velocity * measurement[3],
82 | 10 * self._std_weight_velocity * measurement[3],
83 | 1e-5,
84 | 10 * self._std_weight_velocity * measurement[3]]
85 | covariance = np.diag(np.square(std))
86 | return mean, covariance
87 |
88 | def predict(self, mean, covariance):
89 | """Run Kalman filter prediction step.
90 |
91 | Parameters
92 | ----------
93 | mean : ndarray
94 | The 8 dimensional mean vector of the object state at the previous
95 | time step.
96 | covariance : ndarray
97 | The 8x8 dimensional covariance matrix of the object state at the
98 | previous time step.
99 |
100 | Returns
101 | -------
102 | (ndarray, ndarray)
103 | Returns the mean vector and covariance matrix of the predicted
104 | state. Unobserved velocities are initialized to 0 mean.
105 |
106 | """
107 | std_pos = [
108 | self._std_weight_position * mean[3],
109 | self._std_weight_position * mean[3],
110 | 1e-2,
111 | self._std_weight_position * mean[3]]
112 | std_vel = [
113 | self._std_weight_velocity * mean[3],
114 | self._std_weight_velocity * mean[3],
115 | 1e-5,
116 | self._std_weight_velocity * mean[3]]
117 | motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
118 |
119 | #mean = np.dot(self._motion_mat, mean)
120 | mean = np.dot(mean, self._motion_mat.T)
121 | covariance = np.linalg.multi_dot((
122 | self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
123 |
124 | return mean, covariance
125 |
126 | def project(self, mean, covariance):
127 | """Project state distribution to measurement space.
128 |
129 | Parameters
130 | ----------
131 | mean : ndarray
132 | The state's mean vector (8 dimensional array).
133 | covariance : ndarray
134 | The state's covariance matrix (8x8 dimensional).
135 |
136 | Returns
137 | -------
138 | (ndarray, ndarray)
139 | Returns the projected mean and covariance matrix of the given state
140 | estimate.
141 |
142 | """
143 | std = [
144 | self._std_weight_position * mean[3],
145 | self._std_weight_position * mean[3],
146 | 1e-1,
147 | self._std_weight_position * mean[3]]
148 | innovation_cov = np.diag(np.square(std))
149 |
150 | mean = np.dot(self._update_mat, mean)
151 | covariance = np.linalg.multi_dot((
152 | self._update_mat, covariance, self._update_mat.T))
153 | return mean, covariance + innovation_cov
154 |
155 | def multi_predict(self, mean, covariance):
156 | """Run Kalman filter prediction step (Vectorized version).
157 | Parameters
158 | ----------
159 | mean : ndarray
160 | The Nx8 dimensional mean matrix of the object states at the previous
161 | time step.
162 | covariance : ndarray
163 | The Nx8x8 dimensional covariance matrices of the object states at the
164 | previous time step.
165 | Returns
166 | -------
167 | (ndarray, ndarray)
168 | Returns the mean vector and covariance matrix of the predicted
169 | state. Unobserved velocities are initialized to 0 mean.
170 | """
171 | std_pos = [
172 | self._std_weight_position * mean[:, 3],
173 | self._std_weight_position * mean[:, 3],
174 | 1e-2 * np.ones_like(mean[:, 3]),
175 | self._std_weight_position * mean[:, 3]]
176 | std_vel = [
177 | self._std_weight_velocity * mean[:, 3],
178 | self._std_weight_velocity * mean[:, 3],
179 | 1e-5 * np.ones_like(mean[:, 3]),
180 | self._std_weight_velocity * mean[:, 3]]
181 | sqr = np.square(np.r_[std_pos, std_vel]).T
182 |
183 | motion_cov = []
184 | for i in range(len(mean)):
185 | motion_cov.append(np.diag(sqr[i]))
186 | motion_cov = np.asarray(motion_cov)
187 |
188 | mean = np.dot(mean, self._motion_mat.T)
189 | left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
190 | covariance = np.dot(left, self._motion_mat.T) + motion_cov
191 |
192 | return mean, covariance
193 |
194 | def update(self, mean, covariance, measurement):
195 | """Run Kalman filter correction step.
196 |
197 | Parameters
198 | ----------
199 | mean : ndarray
200 | The predicted state's mean vector (8 dimensional).
201 | covariance : ndarray
202 | The state's covariance matrix (8x8 dimensional).
203 | measurement : ndarray
204 | The 4 dimensional measurement vector (x, y, a, h), where (x, y)
205 | is the center position, a the aspect ratio, and h the height of the
206 | bounding box.
207 |
208 | Returns
209 | -------
210 | (ndarray, ndarray)
211 | Returns the measurement-corrected state distribution.
212 |
213 | """
214 | projected_mean, projected_cov = self.project(mean, covariance)
215 |
216 | chol_factor, lower = scipy.linalg.cho_factor(
217 | projected_cov, lower=True, check_finite=False)
218 | kalman_gain = scipy.linalg.cho_solve(
219 | (chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
220 | check_finite=False).T
221 | innovation = measurement - projected_mean
222 |
223 | new_mean = mean + np.dot(innovation, kalman_gain.T)
224 | new_covariance = covariance - np.linalg.multi_dot((
225 | kalman_gain, projected_cov, kalman_gain.T))
226 | return new_mean, new_covariance
227 |
228 | def gating_distance(self, mean, covariance, measurements,
229 | only_position=False, metric='maha'):
230 | """Compute gating distance between state distribution and measurements.
231 | A suitable distance threshold can be obtained from `chi2inv95`. If
232 | `only_position` is False, the chi-square distribution has 4 degrees of
233 | freedom, otherwise 2.
234 | Parameters
235 | ----------
236 | mean : ndarray
237 | Mean vector over the state distribution (8 dimensional).
238 | covariance : ndarray
239 | Covariance of the state distribution (8x8 dimensional).
240 | measurements : ndarray
241 | An Nx4 dimensional matrix of N measurements, each in
242 | format (x, y, a, h) where (x, y) is the bounding box center
243 | position, a the aspect ratio, and h the height.
244 | only_position : Optional[bool]
245 | If True, distance computation is done with respect to the bounding
246 | box center position only.
247 | Returns
248 | -------
249 | ndarray
250 | Returns an array of length N, where the i-th element contains the
251 | squared Mahalanobis distance between (mean, covariance) and
252 | `measurements[i]`.
253 | """
254 | mean, covariance = self.project(mean, covariance)
255 | if only_position:
256 | mean, covariance = mean[:2], covariance[:2, :2]
257 | measurements = measurements[:, :2]
258 |
259 | d = measurements - mean
260 | if metric == 'gaussian':
261 | return np.sum(d * d, axis=1)
262 | elif metric == 'maha':
263 | cholesky_factor = np.linalg.cholesky(covariance)
264 | z = scipy.linalg.solve_triangular(
265 | cholesky_factor, d.T, lower=True, check_finite=False,
266 | overwrite_b=True)
267 | squared_maha = np.sum(z * z, axis=0)
268 | return squared_maha
269 | else:
270 | raise ValueError('invalid distance metric')
--------------------------------------------------------------------------------
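A short sketch (not from the repository) of the initiate/predict/update cycle in the (x, y, a, h) measurement space used by STrack:

```python
import numpy as np

from byte_tracker.tracker.kalman_filter import KalmanFilter

kf = KalmanFilter()

# center x, center y, aspect ratio (w/h), height
measurement = np.array([320., 240., 0.5, 200.])
mean, covariance = kf.initiate(measurement)

# Predict one frame ahead, then correct with the next measurement.
mean, covariance = kf.predict(mean, covariance)
mean, covariance = kf.update(mean, covariance, np.array([324., 242., 0.5, 202.]))
print(mean[:4])  # corrected (x, y, a, h)
```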
/byte_tracker/tracker/matching.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import scipy
4 | import lap
5 | from scipy.spatial.distance import cdist
6 |
7 | from cython_bbox import bbox_overlaps as bbox_ious
8 | from byte_tracker.tracker import kalman_filter
9 | import time
10 |
11 | def merge_matches(m1, m2, shape):
12 | O,P,Q = shape
13 | m1 = np.asarray(m1)
14 | m2 = np.asarray(m2)
15 |
16 | M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))
17 | M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))
18 |
19 | mask = M1*M2
20 | match = mask.nonzero()
21 | match = list(zip(match[0], match[1]))
22 | unmatched_O = tuple(set(range(O)) - set([i for i, j in match]))
23 | unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match]))
24 |
25 | return match, unmatched_O, unmatched_Q
26 |
27 |
28 | def _indices_to_matches(cost_matrix, indices, thresh):
29 | matched_cost = cost_matrix[tuple(zip(*indices))]
30 | matched_mask = (matched_cost <= thresh)
31 |
32 | matches = indices[matched_mask]
33 | unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
34 | unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))
35 |
36 | return matches, unmatched_a, unmatched_b
37 |
38 |
39 | def linear_assignment(cost_matrix, thresh):
40 | if cost_matrix.size == 0:
41 | return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
42 | matches, unmatched_a, unmatched_b = [], [], []
43 | cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
44 | for ix, mx in enumerate(x):
45 | if mx >= 0:
46 | matches.append([ix, mx])
47 | unmatched_a = np.where(x < 0)[0]
48 | unmatched_b = np.where(y < 0)[0]
49 | matches = np.asarray(matches)
50 | return matches, unmatched_a, unmatched_b
51 |
52 |
53 | def ious(atlbrs, btlbrs):
54 | """
55 | Compute cost based on IoU
56 | :type atlbrs: list[tlbr] | np.ndarray
57 | :type btlbrs: list[tlbr] | np.ndarray
58 |
59 | :rtype ious np.ndarray
60 | """
61 | ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=float)
62 | if ious.size == 0:
63 | return ious
64 |
65 | ious = bbox_ious(
66 | np.ascontiguousarray(atlbrs, dtype=float),
67 | np.ascontiguousarray(btlbrs, dtype=float)
68 | )
69 |
70 | return ious
71 |
72 |
73 | def iou_distance(atracks, btracks):
74 | """
75 | Compute cost based on IoU
76 | :type atracks: list[STrack]
77 | :type btracks: list[STrack]
78 |
79 | :rtype cost_matrix np.ndarray
80 | """
81 |
82 | if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
83 | atlbrs = atracks
84 | btlbrs = btracks
85 | else:
86 | atlbrs = [track.tlbr for track in atracks]
87 | btlbrs = [track.tlbr for track in btracks]
88 | _ious = ious(atlbrs, btlbrs)
89 | cost_matrix = 1 - _ious
90 |
91 | return cost_matrix
92 |
93 | def v_iou_distance(atracks, btracks):
94 | """
95 | Compute cost based on IoU
96 | :type atracks: list[STrack]
97 | :type btracks: list[STrack]
98 |
99 | :rtype cost_matrix np.ndarray
100 | """
101 |
102 | if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
103 | atlbrs = atracks
104 | btlbrs = btracks
105 | else:
106 | atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks]
107 | btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks]
108 | _ious = ious(atlbrs, btlbrs)
109 | cost_matrix = 1 - _ious
110 |
111 | return cost_matrix
112 |
113 | def embedding_distance(tracks, detections, metric='cosine'):
114 | """
115 | :param tracks: list[STrack]
116 | :param detections: list[BaseTrack]
117 | :param metric:
118 | :return: cost_matrix np.ndarray
119 | """
120 |
121 | cost_matrix = np.zeros((len(tracks), len(detections)), dtype=float)
122 | if cost_matrix.size == 0:
123 | return cost_matrix
124 | det_features = np.asarray([track.curr_feat for track in detections], dtype=float)
125 | #for i, track in enumerate(tracks):
126 | #cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric))
127 | track_features = np.asarray([track.smooth_feat for track in tracks], dtype=float)
128 | cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Normalized features
129 | return cost_matrix
130 |
131 |
132 | def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
133 | if cost_matrix.size == 0:
134 | return cost_matrix
135 | gating_dim = 2 if only_position else 4
136 | gating_threshold = kalman_filter.chi2inv95[gating_dim]
137 | measurements = np.asarray([det.to_xyah() for det in detections])
138 | for row, track in enumerate(tracks):
139 | gating_distance = kf.gating_distance(
140 | track.mean, track.covariance, measurements, only_position)
141 | cost_matrix[row, gating_distance > gating_threshold] = np.inf
142 | return cost_matrix
143 |
144 |
145 | def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
146 | if cost_matrix.size == 0:
147 | return cost_matrix
148 | gating_dim = 2 if only_position else 4
149 | gating_threshold = kalman_filter.chi2inv95[gating_dim]
150 | measurements = np.asarray([det.to_xyah() for det in detections])
151 | for row, track in enumerate(tracks):
152 | gating_distance = kf.gating_distance(
153 | track.mean, track.covariance, measurements, only_position, metric='maha')
154 | cost_matrix[row, gating_distance > gating_threshold] = np.inf
155 | cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
156 | return cost_matrix
157 |
158 |
159 | def fuse_iou(cost_matrix, tracks, detections):
160 | if cost_matrix.size == 0:
161 | return cost_matrix
162 | reid_sim = 1 - cost_matrix
163 | iou_dist = iou_distance(tracks, detections)
164 | iou_sim = 1 - iou_dist
165 | fuse_sim = reid_sim * (1 + iou_sim) / 2
166 | det_scores = np.array([det.score for det in detections])
167 | det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
168 | #fuse_sim = fuse_sim * (1 + det_scores) / 2
169 | fuse_cost = 1 - fuse_sim
170 | return fuse_cost
171 |
172 |
173 | def fuse_score(cost_matrix, detections):
174 | if cost_matrix.size == 0:
175 | return cost_matrix
176 | iou_sim = 1 - cost_matrix
177 | det_scores = np.array([det.score for det in detections])
178 | det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
179 | fuse_sim = iou_sim * det_scores
180 | fuse_cost = 1 - fuse_sim
181 | return fuse_cost
--------------------------------------------------------------------------------
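A tiny sketch (not from the repository) of linear_assignment() on an IoU-style cost matrix (1 - IoU); it relies on the lap package from the requirements, and importing this module also needs cython_bbox installed:

```python
import numpy as np

from byte_tracker.tracker.matching import linear_assignment

# Rows are existing tracks, columns are detections; entries are 1 - IoU.
cost = np.array([[0.1, 0.9],
                 [0.8, 0.2],
                 [0.7, 0.6]])
matches, unmatched_tracks, unmatched_dets = linear_assignment(cost, thresh=0.8)
print(matches)           # pairs of (track index, detection index)
print(unmatched_tracks)  # track indices left without a detection
```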
/byte_tracker/utils/yolox_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import cv2
4 | import numpy as np
5 |
6 |
7 | def nms(boxes, scores, nms_thr):
8 | """Single class NMS implemented in Numpy."""
9 | x1 = boxes[:, 0]
10 | y1 = boxes[:, 1]
11 | x2 = boxes[:, 2]
12 | y2 = boxes[:, 3]
13 |
14 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
15 | order = scores.argsort()[::-1]
16 |
17 | keep = []
18 | while order.size > 0:
19 | i = order[0]
20 | keep.append(i)
21 | xx1 = np.maximum(x1[i], x1[order[1:]])
22 | yy1 = np.maximum(y1[i], y1[order[1:]])
23 | xx2 = np.minimum(x2[i], x2[order[1:]])
24 | yy2 = np.minimum(y2[i], y2[order[1:]])
25 |
26 | w = np.maximum(0.0, xx2 - xx1 + 1)
27 | h = np.maximum(0.0, yy2 - yy1 + 1)
28 | inter = w * h
29 | ovr = inter / (areas[i] + areas[order[1:]] - inter)
30 |
31 | inds = np.where(ovr <= nms_thr)[0]
32 | order = order[inds + 1]
33 |
34 | return keep
35 |
36 |
37 | def multiclass_nms(boxes, scores, nms_thr, score_thr):
38 | """Multiclass NMS implemented in Numpy"""
39 | final_dets = []
40 | num_classes = scores.shape[1]
41 | for cls_ind in range(num_classes):
42 | cls_scores = scores[:, cls_ind]
43 | valid_score_mask = cls_scores > score_thr
44 | if valid_score_mask.sum() == 0:
45 | continue
46 | else:
47 | valid_scores = cls_scores[valid_score_mask]
48 | valid_boxes = boxes[valid_score_mask]
49 | keep = nms(valid_boxes, valid_scores, nms_thr)
50 | if len(keep) > 0:
51 | cls_inds = np.ones((len(keep), 1)) * cls_ind
52 | dets = np.concatenate(
53 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1)
54 | final_dets.append(dets)
55 | if len(final_dets) == 0:
56 | return None
57 | return np.concatenate(final_dets, 0)
58 |
59 |
60 | def pre_process(image, input_size, mean, std, swap=(2, 0, 1)):
61 | if len(image.shape) == 3:
62 | padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0
63 | else:
64 | padded_img = np.ones(input_size) * 114.0
65 | img = np.array(image)
66 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
67 | resized_img = cv2.resize(
68 | img,
69 | (int(img.shape[1] * r), int(img.shape[0] * r)),
70 | interpolation=cv2.INTER_LINEAR,
71 | ).astype(np.float32)
72 | padded_img[:int(img.shape[0] * r), :int(img.shape[1] * r)] = resized_img
73 |
74 | padded_img = padded_img[:, :, ::-1]
75 | padded_img /= 255.0
76 | if mean is not None:
77 | padded_img -= mean
78 | if std is not None:
79 | padded_img /= std
80 | padded_img = padded_img.transpose(swap)
81 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
82 | return padded_img, r
83 |
84 |
85 | def post_process(outputs, img_size, p6=False):
86 | grids = []
87 | expanded_strides = []
88 |
89 | if not p6:
90 | strides = [8, 16, 32]
91 | else:
92 | strides = [8, 16, 32, 64]
93 |
94 | hsizes = [img_size[0] // stride for stride in strides]
95 | wsizes = [img_size[1] // stride for stride in strides]
96 |
97 | for hsize, wsize, stride in zip(hsizes, wsizes, strides):
98 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
99 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
100 | grids.append(grid)
101 | shape = grid.shape[:2]
102 | expanded_strides.append(np.full((*shape, 1), stride))
103 |
104 | grids = np.concatenate(grids, 1)
105 | expanded_strides = np.concatenate(expanded_strides, 1)
106 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
107 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
108 |
109 | return outputs
110 |
--------------------------------------------------------------------------------
/demo_video_onnx.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import time
4 | import argparse
5 |
6 | import cv2
7 | from loguru import logger
8 |
9 | from byte_tracker.byte_tracker_onnx import ByteTrackerONNX
10 |
11 |
12 | def get_args():
13 | parser = argparse.ArgumentParser()
14 |
15 | parser.add_argument(
16 | '--use_debug_window',
17 | action='store_true',
18 | )
19 |
20 | parser.add_argument(
21 | '--model',
22 | type=str,
23 | default='byte_tracker/model/bytetrack_s.onnx',
24 | )
25 | parser.add_argument(
26 | '--video',
27 | type=str,
28 | default='sample.mp4',
29 | )
30 | parser.add_argument(
31 | '--output_dir',
32 | type=str,
33 | default='output',
34 | )
35 | parser.add_argument(
36 | '--score_th',
37 | type=float,
38 | default=0.1,
39 | )
40 | parser.add_argument(
41 | '--nms_th',
42 | type=float,
43 | default=0.7,
44 | )
45 | parser.add_argument(
46 | '--input_shape',
47 | type=str,
48 | default='608,1088',
49 | )
50 | parser.add_argument(
51 | '--with_p6',
52 | action='store_true',
53 | help='Whether your model uses p6 in FPN/PAN.',
54 | )
55 |
56 | # tracking args
57 | parser.add_argument(
58 | '--track_thresh',
59 | type=float,
60 | default=0.5,
61 | help='tracking confidence threshold',
62 | )
63 | parser.add_argument(
64 | '--track_buffer',
65 | type=int,
66 | default=30,
67 |         help='number of frames to keep lost tracks',
68 | )
69 | parser.add_argument(
70 | '--match_thresh',
71 | type=float,
72 | default=0.8,
73 | help='matching threshold for tracking',
74 | )
75 | parser.add_argument(
76 | '--min-box-area',
77 | type=float,
78 | default=10,
79 | help='filter out tiny boxes',
80 | )
81 | parser.add_argument(
82 | '--mot20',
83 | dest='mot20',
84 | default=False,
85 | action='store_true',
86 |         help='test on MOT20.',
87 | )
88 |
89 | args = parser.parse_args()
90 |
91 | return args
92 |
93 |
94 | def main():
95 |     # Parse arguments
96 | args = get_args()
97 |
98 | use_debug_window = args.use_debug_window
99 |
100 | video_path = args.video
101 | output_dir = args.output_dir
102 |
103 |     # Create a ByteTracker instance
104 | byte_tracker = ByteTrackerONNX(args)
105 |
106 |     # Open the video
107 | cap = cv2.VideoCapture(video_path)
108 | width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
109 | height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
110 | fps = cap.get(cv2.CAP_PROP_FPS)
111 | frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
112 |
113 |     # Create the output directory for the video
114 | os.makedirs(output_dir, exist_ok=True)
115 | save_path = os.path.join(output_dir, video_path.split("/")[-1])
116 | logger.info(f"video save path is {save_path}")
117 |
118 |     # Create the video writer
119 | video_writer = cv2.VideoWriter(
120 | save_path,
121 | cv2.VideoWriter_fourcc(*"mp4v"),
122 | fps,
123 | (int(width), int(height)),
124 | )
125 |
126 | frame_id = 1
127 |
128 | while True:
129 | start_time = time.time()
130 |
131 |         # Read a frame
132 | ret, frame = cap.read()
133 | if not ret:
134 | break
135 | debug_image = copy.deepcopy(frame)
136 |
137 |         # Byte Tracker inference
138 | _, bboxes, ids, scores = byte_tracker.inference(frame)
139 |
140 | elapsed_time = time.time() - start_time
141 |
142 |         # Draw tracking results
143 | debug_image = draw_tracking_info(
144 | debug_image,
145 | bboxes,
146 | ids,
147 | scores,
148 | frame_id,
149 | elapsed_time,
150 | )
151 |
152 | if use_debug_window:
153 |             # Key handling (ESC: quit)
154 | key = cv2.waitKey(1)
155 | if key == 27: # ESC
156 | break
157 |
158 |             # Show the frame
159 | cv2.imshow('ByteTrack ONNX Sample', debug_image)
160 |
161 |         # Write the frame to the output video
162 | video_writer.write(debug_image)
163 |
164 |         logger.info(
165 |             'frame {}/{} ({:.2f} ms)'.format(
166 |                 frame_id, int(frame_count), elapsed_time * 1000))
167 | frame_id += 1
168 |
169 |     cap.release()
170 |     video_writer.release()
171 |     cv2.destroyAllWindows()
172 |
173 |
174 | def get_id_color(index):
175 | temp_index = abs(int(index)) * 3
176 | color = ((37 * temp_index) % 255, (17 * temp_index) % 255,
177 | (29 * temp_index) % 255)
178 | return color
179 |
180 |
181 | def draw_tracking_info(
182 | image,
183 | tlwhs,
184 | ids,
185 | scores,
186 | frame_id=0,
187 | elapsed_time=0.,
188 | ):
189 | text_scale = 1.5
190 | text_thickness = 2
191 | line_thickness = 2
192 |
193 |     # Frame count, elapsed time, number of tracks
194 | text = 'frame: %d ' % (frame_id)
195 | text += 'elapsed time: %.0fms ' % (elapsed_time * 1000)
196 | text += 'num: %d' % (len(tlwhs))
197 | cv2.putText(
198 | image,
199 | text,
200 | (0, int(15 * text_scale)),
201 | cv2.FONT_HERSHEY_PLAIN,
202 | 2,
203 | (0, 255, 0),
204 | thickness=text_thickness,
205 | )
206 |
207 | for index, tlwh in enumerate(tlwhs):
208 | x1, y1 = int(tlwh[0]), int(tlwh[1])
209 | x2, y2 = x1 + int(tlwh[2]), y1 + int(tlwh[3])
210 |
211 |         # Bounding box
212 | color = get_id_color(ids[index])
213 | cv2.rectangle(image, (x1, y1), (x2, y2), color, line_thickness)
214 |
215 |         # ID and score
216 | # text = str(ids[index]) + ':%.2f' % (scores[index])
217 | text = str(ids[index])
218 | cv2.putText(image, text, (x1, y1 - 5), cv2.FONT_HERSHEY_PLAIN,
219 | text_scale, (0, 0, 0), text_thickness + 3)
220 | cv2.putText(image, text, (x1, y1 - 5), cv2.FONT_HERSHEY_PLAIN,
221 | text_scale, (255, 255, 255), text_thickness)
222 | return image
223 |
224 |
225 | if __name__ == '__main__':
226 | main()
227 |
--------------------------------------------------------------------------------
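demo_video_onnx.py wires ByteTrackerONNX to argparse and OpenCV drawing, but the tracker itself only needs an object carrying the same fields the parser defines. Below is a minimal sketch of calling it directly on a video; the argparse.Namespace values simply mirror the parser defaults above, and whether ByteTrackerONNX reads every one of those fields is not visible in this file, so treat the exact set as an assumption.

import argparse

import cv2

from byte_tracker.byte_tracker_onnx import ByteTrackerONNX

# Mirror the argparse defaults from demo_video_onnx.py
args = argparse.Namespace(
    model='byte_tracker/model/bytetrack_s.onnx',
    score_th=0.1,
    nms_th=0.7,
    input_shape='608,1088',
    with_p6=False,
    track_thresh=0.5,
    track_buffer=30,
    match_thresh=0.8,
    min_box_area=10,
    mot20=False,
)
tracker = ByteTrackerONNX(args)

cap = cv2.VideoCapture('sample.mp4')
while True:
    ret, frame = cap.read()
    if not ret:
        break
    # Same call as in main(): four return values, of which the demo uses
    # the tlwh boxes, track ids, and scores
    _, bboxes, ids, scores = tracker.inference(frame)
    for tlwh, track_id in zip(bboxes, ids):
        print(track_id, [int(v) for v in tlwh])
cap.release()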
/demo_webcam_onnx.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import time
4 | import argparse
5 |
6 | import cv2
7 |
8 | from byte_tracker.byte_tracker_onnx import ByteTrackerONNX
9 |
10 |
11 | def get_args():
12 | parser = argparse.ArgumentParser()
13 |
14 | parser.add_argument(
15 | '--model',
16 | type=str,
17 | default='byte_tracker/model/bytetrack_s.onnx',
18 | )
19 | parser.add_argument('--device', type=int, default=0)
20 | parser.add_argument("--width", help='cap width', type=int, default=960)
21 | parser.add_argument("--height", help='cap height', type=int, default=540)
22 | parser.add_argument(
23 | '--score_th',
24 | type=float,
25 | default=0.1,
26 | )
27 | parser.add_argument(
28 | '--nms_th',
29 | type=float,
30 | default=0.7,
31 | )
32 | parser.add_argument(
33 | '--input_shape',
34 | type=str,
35 | default='608,1088',
36 | )
37 | parser.add_argument(
38 | '--with_p6',
39 | action='store_true',
40 | help='Whether your model uses p6 in FPN/PAN.',
41 | )
42 |
43 | # tracking args
44 | parser.add_argument(
45 | '--track_thresh',
46 | type=float,
47 | default=0.5,
48 | help='tracking confidence threshold',
49 | )
50 | parser.add_argument(
51 | '--track_buffer',
52 | type=int,
53 | default=30,
54 |         help='number of frames to keep lost tracks',
55 | )
56 | parser.add_argument(
57 | '--match_thresh',
58 | type=float,
59 | default=0.8,
60 | help='matching threshold for tracking',
61 | )
62 | parser.add_argument(
63 | '--min-box-area',
64 | type=float,
65 | default=10,
66 | help='filter out tiny boxes',
67 | )
68 | parser.add_argument(
69 | '--mot20',
70 | dest='mot20',
71 | default=False,
72 | action='store_true',
73 |         help='test on MOT20.',
74 | )
75 |
76 | args = parser.parse_args()
77 |
78 | return args
79 |
80 |
81 | def main():
82 |     # Parse arguments
83 | args = get_args()
84 |
85 | cap_device = args.device
86 | cap_width = args.width
87 | cap_height = args.height
88 |
89 |     # Create a ByteTracker instance
90 | byte_tracker = ByteTrackerONNX(args)
91 |
92 |     # Prepare the camera
93 | cap = cv2.VideoCapture(cap_device)
94 | cap.set(cv2.CAP_PROP_FRAME_WIDTH, cap_width)
95 | cap.set(cv2.CAP_PROP_FRAME_HEIGHT, cap_height)
96 |
97 | frame_id = 1
98 |
99 | while True:
100 | start_time = time.time()
101 |
102 |         # Read a frame
103 | ret, frame = cap.read()
104 | if not ret:
105 | break
106 | debug_image = copy.deepcopy(frame)
107 |
108 |         # Byte Tracker inference
109 | _, bboxes, ids, scores = byte_tracker.inference(frame)
110 |
111 | elapsed_time = time.time() - start_time
112 |
113 |         # Draw tracking results
114 | debug_image = draw_tracking_info(
115 | debug_image,
116 | bboxes,
117 | ids,
118 | scores,
119 | frame_id,
120 | elapsed_time,
121 | )
122 |
123 |         # Key handling (ESC: quit)
124 | key = cv2.waitKey(1)
125 | if key == 27: # ESC
126 | break
127 |
128 |         # Show the frame
129 | cv2.imshow('ByteTrack ONNX Sample', debug_image)
130 |
131 | frame_id += 1
132 |
133 | cap.release()
134 | cv2.destroyAllWindows()
135 |
136 |
137 | def get_id_color(index):
138 | temp_index = abs(int(index)) * 3
139 | color = ((37 * temp_index) % 255, (17 * temp_index) % 255,
140 | (29 * temp_index) % 255)
141 | return color
142 |
143 |
144 | def draw_tracking_info(
145 | image,
146 | tlwhs,
147 | ids,
148 | scores,
149 | frame_id=0,
150 | elapsed_time=0.,
151 | ):
152 | text_scale = 1.5
153 | text_thickness = 2
154 | line_thickness = 2
155 |
156 |     # Frame count, elapsed time, number of tracks
157 | text = 'frame: %d ' % (frame_id)
158 | text += 'elapsed time: %.0fms ' % (elapsed_time * 1000)
159 | text += 'num: %d' % (len(tlwhs))
160 | cv2.putText(
161 | image,
162 | text,
163 | (0, int(15 * text_scale)),
164 | cv2.FONT_HERSHEY_PLAIN,
165 | 2,
166 | (0, 255, 0),
167 | thickness=text_thickness,
168 | )
169 |
170 | for index, tlwh in enumerate(tlwhs):
171 | x1, y1 = int(tlwh[0]), int(tlwh[1])
172 | x2, y2 = x1 + int(tlwh[2]), y1 + int(tlwh[3])
173 |
174 |         # Bounding box
175 | color = get_id_color(ids[index])
176 | cv2.rectangle(image, (x1, y1), (x2, y2), color, line_thickness)
177 |
178 |         # ID and score
179 | # text = str(ids[index]) + ':%.2f' % (scores[index])
180 | text = str(ids[index])
181 | cv2.putText(image, text, (x1, y1 - 5), cv2.FONT_HERSHEY_PLAIN,
182 | text_scale, (0, 0, 0), text_thickness + 3)
183 | cv2.putText(image, text, (x1, y1 - 5), cv2.FONT_HERSHEY_PLAIN,
184 | text_scale, (255, 255, 255), text_thickness)
185 | return image
186 |
187 |
188 | if __name__ == '__main__':
189 | main()
190 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | onnx
3 | onnxruntime-gpu
4 | Cython
5 | torch
6 | torchvision
7 | pycocotools
8 | scipy
9 | loguru
10 | thop
11 | lap
12 | cython_bbox
13 |
--------------------------------------------------------------------------------
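requirements.txt pulls in the GPU build of ONNX Runtime; on a machine without CUDA, swapping onnxruntime-gpu for the plain onnxruntime package is the likely substitute (an assumption, not something the file states). A quick way to check which execution providers the installed runtime actually exposes:

import onnxruntime

# Prints e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider'] on a working GPU build,
# or just ['CPUExecutionProvider'] on a CPU-only install.
print(onnxruntime.get_available_providers())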
/sample.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kazuhito00/ByteTrack-ONNX-Sample/70fd73903decc2071213c181398eb8de23cc2b8e/sample.mp4
--------------------------------------------------------------------------------