├── .gitignore
├── ByteTrack_Convert2ONNX.ipynb
├── LICENSE
├── README.md
├── byte_tracker
│   ├── byte_tracker_onnx.py
│   ├── model
│   │   ├── bytetrack_s.onnx
│   │   └── bytetrack_s_mot17.pth.tar
│   ├── tracker
│   │   ├── basetrack.py
│   │   ├── byte_tracker.py
│   │   ├── kalman_filter.py
│   │   └── matching.py
│   └── utils
│       └── yolox_utils.py
├── demo_video_onnx.py
├── demo_webcam_onnx.py
├── requirements.txt
└── sample.mp4
/.gitignore: --------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # bat 132 | *.bat -------------------------------------------------------------------------------- /ByteTrack_Convert2ONNX.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "ByteTrack-Convert2ONNX.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "h7ZxXiORz5xD" 23 | }, 24 | "source": [ 25 | "# ByteTrack リポジトリクローン" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "metadata": { 31 | "colab": { 32 | "base_uri": "https://localhost:8080/" 33 | }, 34 | "id": "ldhQVvDGz2Aa", 35 | "outputId": "65d61468-fd80-40fe-d212-c36c9e332e62" 36 | }, 37 | "source": [ 38 | "!git clone https://github.com/ifzhang/ByteTrack" 39 | ], 40 | "execution_count": 1, 41 | "outputs": [ 42 | { 43 | "output_type": "stream", 44 | "name": "stdout", 45 | "text": [ 46 | "Cloning into 'ByteTrack'...\n", 47 | "remote: Enumerating objects: 1857, done.\u001b[K\n", 48 | "remote: Counting objects: 100% (1171/1171), done.\u001b[K\n", 49 | "remote: Compressing objects: 100% (530/530), done.\u001b[K\n", 50 | "remote: Total 1857 (delta 666), reused 1080 (delta 613), pack-reused 686\u001b[K\n", 51 | "Receiving objects: 100% (1857/1857), 78.15 MiB | 26.27 MiB/s, done.\n", 52 | "Resolving deltas: 100% (1040/1040), done.\n" 53 | ] 54 | } 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "metadata": { 60 | "colab": { 61 | "base_uri": "https://localhost:8080/" 62 | }, 63 | "id": "bZWU7zk00Aix", 64 | "outputId": "d7ef096d-d4be-442b-c4b3-6e466b93ab53" 65 | }, 66 | "source": [ 67 | "%cd ByteTrack" 68 | ], 69 | "execution_count": 2, 70 | "outputs": [ 71 | { 72 | "output_type": "stream", 73 | "name": "stdout", 74 | "text": [ 75 | "/content/ByteTrack\n" 76 | ] 77 | } 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "metadata": { 83 | "id": "L-eeDkfQ05jM" 84 | }, 85 | "source": [ 86 | "!mkdir pretrained" 87 | ], 88 | "execution_count": 3, 89 | "outputs": [] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": { 94 | "id": "XWCWW-rF4l-X" 95 | }, 96 | "source": [ 97 | "# 訓練済みモデルダウンロード" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "metadata": { 103 | "colab": { 104 | "base_uri": "https://localhost:8080/", 105 | "height": 375 106 | }, 107 | "id": "JDw91Szr083y", 108 | "outputId": "dc228724-78c0-4c11-8148-f1b1859029c1" 109 | }, 110 | "source": [ 111 | "import gdown\n", 112 | "\n", 113 | "# bytetrack_s_mot17\n", 114 | "gdown.download('https://drive.google.com/uc?id=1uSmhXzyV1Zvb4TJJCzpsZOIcw7CCJLxj', 'bytetrack_s_mot17.pth.tar', quiet=False)\n", 115 | "# bytetrack_m_mot17 \n", 116 | "gdown.download('https://drive.google.com/uc?id=11Zb0NN_Uu7JwUd9e6Nk8o2_EUfxWqsun', 
'bytetrack_m_mot17.pth.tar', quiet=False)\n", 117 | "# bytetrack_l_mot17\n", 118 | "gdown.download('https://drive.google.com/uc?id=1XwfUuCBF4IgWBWK2H7oOhQgEj9Mrb3rz', 'bytetrack_l_mot17.pth.tar', quiet=False)\n", 119 | "# bytetrack_x_mot17\n", 120 | "gdown.download('https://drive.google.com/uc?id=1P4mY0Yyd3PPTybgZkjMYhFri88nTmJX5', 'bytetrack_x_mot17.pth.tar', quiet=False)\n", 121 | "\n", 122 | "# bytetrack_x_mot20\n", 123 | "gdown.download('https://drive.google.com/uc?id=1HX2_JpMOjOIj1Z9rJjoet9XNy_cCAs5U', 'bytetrack_x_mot20.pth.tar', quiet=False)" 124 | ], 125 | "execution_count": 4, 126 | "outputs": [ 127 | { 128 | "output_type": "stream", 129 | "name": "stderr", 130 | "text": [ 131 | "Downloading...\n", 132 | "From: https://drive.google.com/uc?id=1uSmhXzyV1Zvb4TJJCzpsZOIcw7CCJLxj\n", 133 | "To: /content/ByteTrack/bytetrack_s_mot17.pth.tar\n", 134 | "100%|██████████| 71.8M/71.8M [00:00<00:00, 126MB/s]\n", 135 | "Downloading...\n", 136 | "From: https://drive.google.com/uc?id=11Zb0NN_Uu7JwUd9e6Nk8o2_EUfxWqsun\n", 137 | "To: /content/ByteTrack/bytetrack_m_mot17.pth.tar\n", 138 | "100%|██████████| 203M/203M [00:01<00:00, 163MB/s]\n", 139 | "Downloading...\n", 140 | "From: https://drive.google.com/uc?id=1XwfUuCBF4IgWBWK2H7oOhQgEj9Mrb3rz\n", 141 | "To: /content/ByteTrack/bytetrack_l_mot17.pth.tar\n", 142 | "100%|██████████| 434M/434M [00:02<00:00, 154MB/s]\n", 143 | "Downloading...\n", 144 | "From: https://drive.google.com/uc?id=1P4mY0Yyd3PPTybgZkjMYhFri88nTmJX5\n", 145 | "To: /content/ByteTrack/bytetrack_x_mot17.pth.tar\n", 146 | "100%|██████████| 793M/793M [00:06<00:00, 127MB/s]\n", 147 | "Downloading...\n", 148 | "From: https://drive.google.com/uc?id=1HX2_JpMOjOIj1Z9rJjoet9XNy_cCAs5U\n", 149 | "To: /content/ByteTrack/bytetrack_x_mot20.pth.tar\n", 150 | "100%|██████████| 793M/793M [00:12<00:00, 62.2MB/s]\n" 151 | ] 152 | }, 153 | { 154 | "output_type": "execute_result", 155 | "data": { 156 | "application/vnd.google.colaboratory.intrinsic+json": { 157 | "type": "string" 158 | }, 159 | "text/plain": [ 160 | "'bytetrack_x_mot20.pth.tar'" 161 | ] 162 | }, 163 | "metadata": {}, 164 | "execution_count": 4 165 | } 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "metadata": { 171 | "id": "HNInyvrk1X68" 172 | }, 173 | "source": [ 174 | "!mv bytetrack_s_mot17.pth.tar pretrained\n", 175 | "!mv bytetrack_m_mot17.pth.tar pretrained\n", 176 | "!mv bytetrack_l_mot17.pth.tar pretrained\n", 177 | "!mv bytetrack_x_mot17.pth.tar pretrained\n", 178 | "!mv bytetrack_x_mot20.pth.tar pretrained" 179 | ], 180 | "execution_count": 5, 181 | "outputs": [] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": { 186 | "id": "_Eqotubg4te8" 187 | }, 188 | "source": [ 189 | "# 必要パッケージインストール" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "metadata": { 195 | "colab": { 196 | "base_uri": "https://localhost:8080/" 197 | }, 198 | "id": "V8sgVre62vSp", 199 | "outputId": "1ea02318-668e-494e-e43a-c98312e0e884" 200 | }, 201 | "source": [ 202 | "!pip install loguru\n", 203 | "!pip install thop\n", 204 | "!pip install onnx\n", 205 | "!pip install onnx-simplifier" 206 | ], 207 | "execution_count": 6, 208 | "outputs": [ 209 | { 210 | "output_type": "stream", 211 | "name": "stdout", 212 | "text": [ 213 | "Collecting loguru\n", 214 | " Downloading loguru-0.5.3-py3-none-any.whl (57 kB)\n", 215 | "\u001b[?25l\r\u001b[K |█████▊ | 10 kB 23.5 MB/s eta 0:00:01\r\u001b[K |███████████▌ | 20 kB 10.0 MB/s eta 0:00:01\r\u001b[K |█████████████████▏ | 30 kB 8.1 MB/s eta 0:00:01\r\u001b[K |███████████████████████ | 
40 kB 7.5 MB/s eta 0:00:01\r\u001b[K |████████████████████████████▋ | 51 kB 4.2 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 57 kB 2.5 MB/s \n", 216 | "\u001b[?25hInstalling collected packages: loguru\n", 217 | "Successfully installed loguru-0.5.3\n", 218 | "Collecting thop\n", 219 | " Downloading thop-0.0.31.post2005241907-py3-none-any.whl (8.7 kB)\n", 220 | "Requirement already satisfied: torch>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from thop) (1.10.0+cu111)\n", 221 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.0.0->thop) (3.10.0.2)\n", 222 | "Installing collected packages: thop\n", 223 | "Successfully installed thop-0.0.31.post2005241907\n", 224 | "Collecting onnx\n", 225 | " Downloading onnx-1.10.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (12.7 MB)\n", 226 | "\u001b[K |████████████████████████████████| 12.7 MB 4.3 MB/s \n", 227 | "\u001b[?25hRequirement already satisfied: numpy>=1.16.6 in /usr/local/lib/python3.7/dist-packages (from onnx) (1.19.5)\n", 228 | "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from onnx) (1.15.0)\n", 229 | "Requirement already satisfied: protobuf in /usr/local/lib/python3.7/dist-packages (from onnx) (3.17.3)\n", 230 | "Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.7/dist-packages (from onnx) (3.10.0.2)\n", 231 | "Installing collected packages: onnx\n", 232 | "Successfully installed onnx-1.10.2\n", 233 | "Collecting onnx-simplifier\n", 234 | " Downloading onnx-simplifier-0.3.6.tar.gz (13 kB)\n", 235 | "Requirement already satisfied: onnx in /usr/local/lib/python3.7/dist-packages (from onnx-simplifier) (1.10.2)\n", 236 | "Collecting onnxoptimizer>=0.2.5\n", 237 | " Downloading onnxoptimizer-0.2.6-cp37-cp37m-manylinux2014_x86_64.whl (466 kB)\n", 238 | "\u001b[K |████████████████████████████████| 466 kB 4.2 MB/s \n", 239 | "\u001b[?25hCollecting onnxruntime>=1.6.0\n", 240 | " Downloading onnxruntime-1.9.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.8 MB)\n", 241 | "\u001b[K |████████████████████████████████| 4.8 MB 32.7 MB/s \n", 242 | "\u001b[?25hRequirement already satisfied: protobuf>=3.7.0 in /usr/local/lib/python3.7/dist-packages (from onnx-simplifier) (3.17.3)\n", 243 | "Requirement already satisfied: numpy>=1.16.6 in /usr/local/lib/python3.7/dist-packages (from onnxruntime>=1.6.0->onnx-simplifier) (1.19.5)\n", 244 | "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.7/dist-packages (from onnxruntime>=1.6.0->onnx-simplifier) (2.0)\n", 245 | "Requirement already satisfied: six>=1.9 in /usr/local/lib/python3.7/dist-packages (from protobuf>=3.7.0->onnx-simplifier) (1.15.0)\n", 246 | "Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.7/dist-packages (from onnx->onnx-simplifier) (3.10.0.2)\n", 247 | "Building wheels for collected packages: onnx-simplifier\n", 248 | " Building wheel for onnx-simplifier (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 249 | " Created wheel for onnx-simplifier: filename=onnx_simplifier-0.3.6-py3-none-any.whl size=12873 sha256=486587eb3cf7d3fbf4358732c9239be2ed04fe9e6d58de51c5f93a96bd45bfb2\n", 250 | " Stored in directory: /root/.cache/pip/wheels/0c/47/80/8eb21098e22c19d60b1c14021ee67442b4ad2d7991fdad46ba\n", 251 | "Successfully built onnx-simplifier\n", 252 | "Installing collected packages: onnxruntime, onnxoptimizer, onnx-simplifier\n", 253 | "Successfully installed onnx-simplifier-0.3.6 onnxoptimizer-0.2.6 onnxruntime-1.9.0\n" 254 | ] 255 | } 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "metadata": { 261 | "id": "c_VaJ7x_4xRf" 262 | }, 263 | "source": [ 264 | "# ONNXエクスポート" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "metadata": { 270 | "id": "IoyLhN-v3wtN" 271 | }, 272 | "source": [ 273 | "!cp tools/export_onnx.py ./" 274 | ], 275 | "execution_count": 7, 276 | "outputs": [] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "metadata": { 281 | "colab": { 282 | "base_uri": "https://localhost:8080/" 283 | }, 284 | "id": "AoWy7DGg0La7", 285 | "outputId": "d78f03bb-7e19-4d6c-f2db-8eeb981d31f7" 286 | }, 287 | "source": [ 288 | "!python export_onnx.py \\\n", 289 | " --output-name bytetrack_s.onnx \\\n", 290 | " -f exps/example/mot/yolox_s_mix_det.py \\\n", 291 | " -c pretrained/bytetrack_s_mot17.pth.tar" 292 | ], 293 | "execution_count": 8, 294 | "outputs": [ 295 | { 296 | "output_type": "stream", 297 | "name": "stdout", 298 | "text": [ 299 | "\u001b[32m2021-11-18 02:34:16.311\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mmain\u001b[0m:\u001b[36m52\u001b[0m - \u001b[1margs value: Namespace(ckpt='pretrained/bytetrack_s_mot17.pth.tar', exp_file='exps/example/mot/yolox_s_mix_det.py', experiment_name=None, input='images', name=None, no_onnxsim=False, opset=11, opts=[], output='output', output_name='bytetrack_s.onnx')\u001b[0m\n", 300 | "\u001b[32m2021-11-18 02:34:17.162\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mmain\u001b[0m:\u001b[36m76\u001b[0m - \u001b[1mloading checkpoint done.\u001b[0m\n", 301 | "\u001b[32m2021-11-18 02:34:26.008\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mmain\u001b[0m:\u001b[36m86\u001b[0m - \u001b[1mgenerated onnx model named bytetrack_s.onnx\u001b[0m\n", 302 | "\u001b[32m2021-11-18 02:34:29.160\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mmain\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1mgenerated simplified onnx model named bytetrack_s.onnx\u001b[0m\n" 303 | ] 304 | } 305 | ] 306 | } 307 | ] 308 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 KazuhitoTakahashi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | > **Warning** 2 | >
This repository simply converts the model file of [ifzhang/ByteTrack](https://github.com/ifzhang/ByteTrack) to ONNX and tidies up the source code around it.
3 | > Unless you have a specific reason to use this repository, the more general-purpose [yolox-bytetrack-sample](https://github.com/Kazuhito00/yolox-bytetrack-sample) and [yolox-bytetrack-mcmot-sample](https://github.com/Kazuhito00/yolox-bytetrack-mcmot-sample) are recommended instead. 4 | 5 | # ByteTrack-ONNX-Sample 6 | A Python ONNX inference sample for [ByteTrack (Multi-Object Tracking by Associating Every Detection Box)](https://github.com/ifzhang/ByteTrack).
7 | A model already converted to ONNX is also included.
8 | If you want to try the conversion itself, use [ByteTrack_Convert2ONNX.ipynb](ByteTrack_Convert2ONNX.ipynb).
9 | [ByteTrack_Convert2ONNX.ipynb](ByteTrack_Convert2ONNX.ipynb) is intended to be run on Google Colaboratory.
10 | The video below is an example run on Windows. 11 | 12 | https://user-images.githubusercontent.com/37477845/142617492-7fef3f6e-5725-480c-b059-0f2dee1606bc.mp4 13 | 14 | # Requirement 15 | opencv-python 4.5.3.56 or later
16 | onnx 1.9.0 or later
17 | onnxruntime-gpu 1.9.0 or later
18 | Cython 0.29.24 or later
19 | torch 1.8.1 or later
20 | torchvision 0.9.1 or later
21 | pycocotools 2.0.2 or later
22 | scipy 1.6.3 or later
23 | loguru 0.5.3 or later
24 | thop 0.0.31.post2005241907 or later
25 | lap 0.4.0 or later
26 | cython_bbox 0.1.3 or later
27 |
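A minimal setup sketch (not part of the original instructions): the bundled `requirements.txt` lists the same packages without version pins, so installing it into a fresh virtual environment is usually enough.
```bash
# Optional: isolate the dependencies in a virtual environment.
python -m venv venv
source venv/bin/activate   # on Windows: venv\Scripts\activate
# Install the packages listed above (unpinned in requirements.txt).
pip install -r requirements.txt
```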
28 | ※ The sample also works with the CPU-only onnxruntime package instead of onnxruntime-gpu, but inference takes noticeably longer, so a GPU is recommended.
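If you do want to run on CPU only, one way to do that (shown as a sketch, not an official step of this repository) is to swap the GPU package for the CPU-only one:
```bash
# Replace the GPU runtime with the CPU-only runtime.
pip uninstall -y onnxruntime-gpu
pip install onnxruntime
```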
29 | ※ If installing cython_bbox fails on Windows, try installing it from GitHub instead (as of 2021/11/19):
`pip install -e git+https://github.com/samson-wang/cython_bbox.git#egg=cython-bbox` 31 | 32 | # Demo 33 | The demos can be run as follows. 34 | #### Video: runs ByteTrack tracking on a video file and writes the result to an output video 35 | ```bash 36 | python demo_video_onnx.py 37 | ``` 38 |
39 | Run-time options (an example invocation combining several of them follows this list) 40 | 41 | * --use_debug_window
42 | Whether to show each frame in a GUI window while it is written to the output video
43 | Default: not specified 44 | * --model
45 | Path to the ByteTrack ONNX model
46 | Default: byte_tracker/model/bytetrack_s.onnx 47 | * --video
48 | Path to the input video
49 | Default: sample.mp4 50 | * --output_dir
51 | Directory to which the result video is written
52 | Default: output 53 | * --score_th
54 | Score threshold for person detection
55 | Default: 0.1 56 | * --nms_th
57 | NMS threshold for person detection
58 | Default: 0.7 59 | * --input_shape
60 | Input size at inference time
61 | Default: 608,1088 62 | * --with_p6
63 | Whether the YOLOX model uses p6 in its FPN/PAN
64 | Default: not specified 65 | * --track_thresh
66 | Score threshold for tracking
67 | Default: 0.5 68 | * --track_buffer
69 | Number of frames a lost target is kept before it is removed
70 | Default: 30 71 | * --match_thresh
72 | Matching score threshold for tracking
73 | Default: 0.8 74 | * --min-box-area
75 | Minimum bounding-box area; smaller boxes are filtered out
76 | Default: 10 77 | * --mot20
78 | Whether to use MOT20 settings
79 | Default: not specified 80 |
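For example, a run that previews frames in a GUI window while writing the output from a user-supplied video could look like this (a sketch only; `my_video.mp4` is a placeholder, and the threshold values shown are simply the defaults listed above):
```bash
# Track people in my_video.mp4, preview frames, and write the result to ./output
python demo_video_onnx.py \
    --use_debug_window \
    --video my_video.mp4 \
    --output_dir output \
    --score_th 0.1 \
    --track_thresh 0.5
```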
81 | 82 | #### Webcam: runs ByteTrack tracking on webcam frames and shows the result in a GUI window 83 | ```bash 84 | python demo_webcam_onnx.py 85 | ``` 86 |
87 | Run-time options (an example invocation follows this list) 88 | 89 | * --model
90 | Path to the ByteTrack ONNX model
91 | Default: byte_tracker/model/bytetrack_s.onnx 92 | * --device
93 | Camera device index
94 | Default: 0 95 | * --width
96 | Capture width
97 | Default: 960 98 | * --height
99 | Capture height
100 | Default: 540 101 | * --score_th
102 | Score threshold for person detection
103 | Default: 0.1 104 | * --nms_th
105 | NMS threshold for person detection
106 | Default: 0.7 107 | * --input_shape
108 | Input size at inference time
109 | Default: 608,1088 110 | * --with_p6
111 | Whether the YOLOX model uses p6 in its FPN/PAN
112 | Default: not specified 113 | * --track_thresh
114 | Score threshold for tracking
115 | Default: 0.5 116 | * --track_buffer
117 | Number of frames a lost target is kept before it is removed
118 | Default: 30 119 | * --match_thresh
120 | Matching score threshold for tracking
121 | Default: 0.8 122 | * --min-box-area
123 | Minimum bounding-box area; smaller boxes are filtered out
124 | Default: 10 125 | * --mot20
126 | Whether to use MOT20 settings
127 | Default: not specified 128 |
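For example, to track from a second camera at a 1280x720 capture size (the device index and resolution here are placeholders, not values recommended by this repository):
```bash
# Use camera index 1 and a larger capture size than the defaults
python demo_webcam_onnx.py \
    --device 1 \
    --width 1280 \
    --height 720
```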
129 | 130 | # Reference 131 | * [ifzhang/ByteTrack](https://github.com/ifzhang/ByteTrack) 132 | 133 | # Author 134 | 高橋かずひと(https://twitter.com/KzhtTkhs) 135 | 136 | # License 137 | ByteTrack-ONNX-Sample is under [MIT License](LICENSE). 138 | 139 | # License(Movie) 140 | サンプル動画は[NHKクリエイティブ・ライブラリー](https://www.nhk.or.jp/archives/creative/)の[イギリス ウースターのエルガー像](https://www2.nhk.or.jp/archives/creative/material/view.cgi?m=D0002011239_00000)を使用しています。 141 | -------------------------------------------------------------------------------- /byte_tracker/byte_tracker_onnx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import copy 4 | 5 | import numpy as np 6 | import onnxruntime 7 | 8 | from byte_tracker.utils.yolox_utils import ( 9 | pre_process, 10 | post_process, 11 | multiclass_nms, 12 | ) 13 | from byte_tracker.tracker.byte_tracker import BYTETracker 14 | 15 | 16 | class ByteTrackerONNX(object): 17 | def __init__(self, args): 18 | self.args = args 19 | 20 | self.rgb_means = (0.485, 0.456, 0.406) 21 | self.std = (0.229, 0.224, 0.225) 22 | 23 | # ByteTracker ONNX読み込み 24 | self.session = onnxruntime.InferenceSession(args.model) 25 | self.input_shape = tuple(map(int, args.input_shape.split(','))) 26 | 27 | # ByteTrackerインスタンス生成 28 | self.tracker = BYTETracker(args, frame_rate=30) 29 | 30 | def _pre_process(self, image): 31 | image_info = {'id': 0} 32 | 33 | image_info['image'] = copy.deepcopy(image) 34 | image_info['width'] = image.shape[1] 35 | image_info['height'] = image.shape[0] 36 | 37 | preprocessed_image, ratio = pre_process( 38 | image, 39 | self.input_shape, 40 | self.rgb_means, 41 | self.std, 42 | ) 43 | image_info['ratio'] = ratio 44 | 45 | return preprocessed_image, image_info 46 | 47 | def inference(self, image): 48 | # 前処理 49 | image, image_info = self._pre_process(image) 50 | 51 | # 推論 52 | input_name = self.session.get_inputs()[0].name 53 | result = self.session.run(None, {input_name: image[None, :, :, :]}) 54 | 55 | # 後処理 56 | dets = self._post_process(result, image_info) 57 | 58 | # トラッカー更新 59 | bboxes, ids, scores = self._tracker_update( 60 | dets, 61 | image_info, 62 | ) 63 | 64 | return image_info, bboxes, ids, scores 65 | 66 | def _post_process(self, result, image_info): 67 | # バウンディングボックス、スコア算出 68 | predictions = post_process( 69 | result[0], 70 | self.input_shape, 71 | p6=self.args.with_p6, 72 | ) 73 | predictions = predictions[0] 74 | boxes = predictions[:, :4] 75 | scores = predictions[:, 4:5] * predictions[:, 5:] 76 | 77 | # 相対座標を絶対座標へ変換 78 | boxes_xyxy = np.ones_like(boxes) 79 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2. 80 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2. 81 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2. 82 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2. 
83 | boxes_xyxy /= image_info['ratio'] 84 | 85 | # NMS実施 86 | dets = multiclass_nms( 87 | boxes_xyxy, 88 | scores, 89 | nms_thr=self.args.nms_th, 90 | score_thr=self.args.score_th, 91 | ) 92 | 93 | return dets 94 | 95 | def _tracker_update(self, dets, image_info): 96 | # トラッカー更新 97 | online_targets = [] 98 | if dets is not None: 99 | online_targets = self.tracker.update( 100 | dets[:, :-1], 101 | [image_info['height'], image_info['width']], 102 | [image_info['height'], image_info['width']], 103 | ) 104 | 105 | online_tlwhs = [] 106 | online_ids = [] 107 | online_scores = [] 108 | for online_target in online_targets: 109 | tlwh = online_target.tlwh 110 | track_id = online_target.track_id 111 | vertical = tlwh[2] / tlwh[3] > 1.6 112 | if tlwh[2] * tlwh[3] > self.args.min_box_area and not vertical: 113 | online_tlwhs.append(tlwh) 114 | online_ids.append(track_id) 115 | online_scores.append(online_target.score) 116 | 117 | return online_tlwhs, online_ids, online_scores 118 | -------------------------------------------------------------------------------- /byte_tracker/model/bytetrack_s.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kazuhito00/ByteTrack-ONNX-Sample/70fd73903decc2071213c181398eb8de23cc2b8e/byte_tracker/model/bytetrack_s.onnx -------------------------------------------------------------------------------- /byte_tracker/model/bytetrack_s_mot17.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kazuhito00/ByteTrack-ONNX-Sample/70fd73903decc2071213c181398eb8de23cc2b8e/byte_tracker/model/bytetrack_s_mot17.pth.tar -------------------------------------------------------------------------------- /byte_tracker/tracker/basetrack.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import OrderedDict 3 | 4 | 5 | class TrackState(object): 6 | New = 0 7 | Tracked = 1 8 | Lost = 2 9 | Removed = 3 10 | 11 | 12 | class BaseTrack(object): 13 | _count = 0 14 | 15 | track_id = 0 16 | is_activated = False 17 | state = TrackState.New 18 | 19 | history = OrderedDict() 20 | features = [] 21 | curr_feature = None 22 | score = 0 23 | start_frame = 0 24 | frame_id = 0 25 | time_since_update = 0 26 | 27 | # multi-camera 28 | location = (np.inf, np.inf) 29 | 30 | @property 31 | def end_frame(self): 32 | return self.frame_id 33 | 34 | @staticmethod 35 | def next_id(): 36 | BaseTrack._count += 1 37 | return BaseTrack._count 38 | 39 | def activate(self, *args): 40 | raise NotImplementedError 41 | 42 | def predict(self): 43 | raise NotImplementedError 44 | 45 | def update(self, *args, **kwargs): 46 | raise NotImplementedError 47 | 48 | def mark_lost(self): 49 | self.state = TrackState.Lost 50 | 51 | def mark_removed(self): 52 | self.state = TrackState.Removed -------------------------------------------------------------------------------- /byte_tracker/tracker/byte_tracker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import deque 3 | import os 4 | import os.path as osp 5 | import copy 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from .kalman_filter import KalmanFilter 10 | from byte_tracker.tracker import matching 11 | from .basetrack import BaseTrack, TrackState 12 | 13 | class STrack(BaseTrack): 14 | shared_kalman = KalmanFilter() 15 | def __init__(self, tlwh, score): 16 | 17 | # wait 
activate 18 | self._tlwh = np.asarray(tlwh, dtype=np.float) 19 | self.kalman_filter = None 20 | self.mean, self.covariance = None, None 21 | self.is_activated = False 22 | 23 | self.score = score 24 | self.tracklet_len = 0 25 | 26 | def predict(self): 27 | mean_state = self.mean.copy() 28 | if self.state != TrackState.Tracked: 29 | mean_state[7] = 0 30 | self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) 31 | 32 | @staticmethod 33 | def multi_predict(stracks): 34 | if len(stracks) > 0: 35 | multi_mean = np.asarray([st.mean.copy() for st in stracks]) 36 | multi_covariance = np.asarray([st.covariance for st in stracks]) 37 | for i, st in enumerate(stracks): 38 | if st.state != TrackState.Tracked: 39 | multi_mean[i][7] = 0 40 | multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance) 41 | for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): 42 | stracks[i].mean = mean 43 | stracks[i].covariance = cov 44 | 45 | def activate(self, kalman_filter, frame_id): 46 | """Start a new tracklet""" 47 | self.kalman_filter = kalman_filter 48 | self.track_id = self.next_id() 49 | self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh)) 50 | 51 | self.tracklet_len = 0 52 | self.state = TrackState.Tracked 53 | if frame_id == 1: 54 | self.is_activated = True 55 | # self.is_activated = True 56 | self.frame_id = frame_id 57 | self.start_frame = frame_id 58 | 59 | def re_activate(self, new_track, frame_id, new_id=False): 60 | self.mean, self.covariance = self.kalman_filter.update( 61 | self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh) 62 | ) 63 | self.tracklet_len = 0 64 | self.state = TrackState.Tracked 65 | self.is_activated = True 66 | self.frame_id = frame_id 67 | if new_id: 68 | self.track_id = self.next_id() 69 | self.score = new_track.score 70 | 71 | def update(self, new_track, frame_id): 72 | """ 73 | Update a matched track 74 | :type new_track: STrack 75 | :type frame_id: int 76 | :type update_feature: bool 77 | :return: 78 | """ 79 | self.frame_id = frame_id 80 | self.tracklet_len += 1 81 | 82 | new_tlwh = new_track.tlwh 83 | self.mean, self.covariance = self.kalman_filter.update( 84 | self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh)) 85 | self.state = TrackState.Tracked 86 | self.is_activated = True 87 | 88 | self.score = new_track.score 89 | 90 | @property 91 | # @jit(nopython=True) 92 | def tlwh(self): 93 | """Get current position in bounding box format `(top left x, top left y, 94 | width, height)`. 95 | """ 96 | if self.mean is None: 97 | return self._tlwh.copy() 98 | ret = self.mean[:4].copy() 99 | ret[2] *= ret[3] 100 | ret[:2] -= ret[2:] / 2 101 | return ret 102 | 103 | @property 104 | # @jit(nopython=True) 105 | def tlbr(self): 106 | """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., 107 | `(top left, bottom right)`. 108 | """ 109 | ret = self.tlwh.copy() 110 | ret[2:] += ret[:2] 111 | return ret 112 | 113 | @staticmethod 114 | # @jit(nopython=True) 115 | def tlwh_to_xyah(tlwh): 116 | """Convert bounding box to format `(center x, center y, aspect ratio, 117 | height)`, where the aspect ratio is `width / height`. 
118 | """ 119 | ret = np.asarray(tlwh).copy() 120 | ret[:2] += ret[2:] / 2 121 | ret[2] /= ret[3] 122 | return ret 123 | 124 | def to_xyah(self): 125 | return self.tlwh_to_xyah(self.tlwh) 126 | 127 | @staticmethod 128 | # @jit(nopython=True) 129 | def tlbr_to_tlwh(tlbr): 130 | ret = np.asarray(tlbr).copy() 131 | ret[2:] -= ret[:2] 132 | return ret 133 | 134 | @staticmethod 135 | # @jit(nopython=True) 136 | def tlwh_to_tlbr(tlwh): 137 | ret = np.asarray(tlwh).copy() 138 | ret[2:] += ret[:2] 139 | return ret 140 | 141 | def __repr__(self): 142 | return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame) 143 | 144 | 145 | class BYTETracker(object): 146 | def __init__(self, args, frame_rate=30): 147 | self.tracked_stracks = [] # type: list[STrack] 148 | self.lost_stracks = [] # type: list[STrack] 149 | self.removed_stracks = [] # type: list[STrack] 150 | 151 | self.frame_id = 0 152 | self.args = args 153 | #self.det_thresh = args.track_thresh 154 | self.det_thresh = args.track_thresh + 0.1 155 | self.buffer_size = int(frame_rate / 30.0 * args.track_buffer) 156 | self.max_time_lost = self.buffer_size 157 | self.kalman_filter = KalmanFilter() 158 | 159 | def update(self, output_results, img_info, img_size): 160 | self.frame_id += 1 161 | activated_starcks = [] 162 | refind_stracks = [] 163 | lost_stracks = [] 164 | removed_stracks = [] 165 | 166 | if output_results.shape[1] == 5: 167 | scores = output_results[:, 4] 168 | bboxes = output_results[:, :4] 169 | else: 170 | output_results = output_results.cpu().numpy() 171 | scores = output_results[:, 4] * output_results[:, 5] 172 | bboxes = output_results[:, :4] # x1y1x2y2 173 | img_h, img_w = img_info[0], img_info[1] 174 | scale = min(img_size[0] / float(img_h), img_size[1] / float(img_w)) 175 | bboxes /= scale 176 | 177 | remain_inds = scores > self.args.track_thresh 178 | inds_low = scores > 0.1 179 | inds_high = scores < self.args.track_thresh 180 | 181 | inds_second = np.logical_and(inds_low, inds_high) 182 | dets_second = bboxes[inds_second] 183 | dets = bboxes[remain_inds] 184 | scores_keep = scores[remain_inds] 185 | scores_second = scores[inds_second] 186 | 187 | if len(dets) > 0: 188 | '''Detections''' 189 | detections = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for 190 | (tlbr, s) in zip(dets, scores_keep)] 191 | else: 192 | detections = [] 193 | 194 | ''' Add newly detected tracklets to tracked_stracks''' 195 | unconfirmed = [] 196 | tracked_stracks = [] # type: list[STrack] 197 | for track in self.tracked_stracks: 198 | if not track.is_activated: 199 | unconfirmed.append(track) 200 | else: 201 | tracked_stracks.append(track) 202 | 203 | ''' Step 2: First association, with high score detection boxes''' 204 | strack_pool = joint_stracks(tracked_stracks, self.lost_stracks) 205 | # Predict the current location with KF 206 | STrack.multi_predict(strack_pool) 207 | dists = matching.iou_distance(strack_pool, detections) 208 | if not self.args.mot20: 209 | dists = matching.fuse_score(dists, detections) 210 | matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh) 211 | 212 | for itracked, idet in matches: 213 | track = strack_pool[itracked] 214 | det = detections[idet] 215 | if track.state == TrackState.Tracked: 216 | track.update(detections[idet], self.frame_id) 217 | activated_starcks.append(track) 218 | else: 219 | track.re_activate(det, self.frame_id, new_id=False) 220 | refind_stracks.append(track) 221 | 222 | ''' Step 3: Second association, with low score detection boxes''' 223 | # 
association the untrack to the low score detections 224 | if len(dets_second) > 0: 225 | '''Detections''' 226 | detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for 227 | (tlbr, s) in zip(dets_second, scores_second)] 228 | else: 229 | detections_second = [] 230 | r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked] 231 | dists = matching.iou_distance(r_tracked_stracks, detections_second) 232 | matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5) 233 | for itracked, idet in matches: 234 | track = r_tracked_stracks[itracked] 235 | det = detections_second[idet] 236 | if track.state == TrackState.Tracked: 237 | track.update(det, self.frame_id) 238 | activated_starcks.append(track) 239 | else: 240 | track.re_activate(det, self.frame_id, new_id=False) 241 | refind_stracks.append(track) 242 | 243 | for it in u_track: 244 | track = r_tracked_stracks[it] 245 | if not track.state == TrackState.Lost: 246 | track.mark_lost() 247 | lost_stracks.append(track) 248 | 249 | '''Deal with unconfirmed tracks, usually tracks with only one beginning frame''' 250 | detections = [detections[i] for i in u_detection] 251 | dists = matching.iou_distance(unconfirmed, detections) 252 | if not self.args.mot20: 253 | dists = matching.fuse_score(dists, detections) 254 | matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7) 255 | for itracked, idet in matches: 256 | unconfirmed[itracked].update(detections[idet], self.frame_id) 257 | activated_starcks.append(unconfirmed[itracked]) 258 | for it in u_unconfirmed: 259 | track = unconfirmed[it] 260 | track.mark_removed() 261 | removed_stracks.append(track) 262 | 263 | """ Step 4: Init new stracks""" 264 | for inew in u_detection: 265 | track = detections[inew] 266 | if track.score < self.det_thresh: 267 | continue 268 | track.activate(self.kalman_filter, self.frame_id) 269 | activated_starcks.append(track) 270 | """ Step 5: Update state""" 271 | for track in self.lost_stracks: 272 | if self.frame_id - track.end_frame > self.max_time_lost: 273 | track.mark_removed() 274 | removed_stracks.append(track) 275 | 276 | # print('Ramained match {} s'.format(t4-t3)) 277 | 278 | self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked] 279 | self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks) 280 | self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks) 281 | self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks) 282 | self.lost_stracks.extend(lost_stracks) 283 | self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks) 284 | self.removed_stracks.extend(removed_stracks) 285 | self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks) 286 | # get scores of lost tracks 287 | output_stracks = [track for track in self.tracked_stracks if track.is_activated] 288 | 289 | return output_stracks 290 | 291 | 292 | def joint_stracks(tlista, tlistb): 293 | exists = {} 294 | res = [] 295 | for t in tlista: 296 | exists[t.track_id] = 1 297 | res.append(t) 298 | for t in tlistb: 299 | tid = t.track_id 300 | if not exists.get(tid, 0): 301 | exists[tid] = 1 302 | res.append(t) 303 | return res 304 | 305 | 306 | def sub_stracks(tlista, tlistb): 307 | stracks = {} 308 | for t in tlista: 309 | stracks[t.track_id] = t 310 | for t in tlistb: 311 | tid = t.track_id 312 | if stracks.get(tid, 0): 313 | del stracks[tid] 314 | return 
list(stracks.values()) 315 | 316 | 317 | def remove_duplicate_stracks(stracksa, stracksb): 318 | pdist = matching.iou_distance(stracksa, stracksb) 319 | pairs = np.where(pdist < 0.15) 320 | dupa, dupb = list(), list() 321 | for p, q in zip(*pairs): 322 | timep = stracksa[p].frame_id - stracksa[p].start_frame 323 | timeq = stracksb[q].frame_id - stracksb[q].start_frame 324 | if timep > timeq: 325 | dupb.append(q) 326 | else: 327 | dupa.append(p) 328 | resa = [t for i, t in enumerate(stracksa) if not i in dupa] 329 | resb = [t for i, t in enumerate(stracksb) if not i in dupb] 330 | return resa, resb 331 | -------------------------------------------------------------------------------- /byte_tracker/tracker/kalman_filter.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import numpy as np 3 | import scipy.linalg 4 | 5 | 6 | """ 7 | Table for the 0.95 quantile of the chi-square distribution with N degrees of 8 | freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv 9 | function and used as Mahalanobis gating threshold. 10 | """ 11 | chi2inv95 = { 12 | 1: 3.8415, 13 | 2: 5.9915, 14 | 3: 7.8147, 15 | 4: 9.4877, 16 | 5: 11.070, 17 | 6: 12.592, 18 | 7: 14.067, 19 | 8: 15.507, 20 | 9: 16.919} 21 | 22 | 23 | class KalmanFilter(object): 24 | """ 25 | A simple Kalman filter for tracking bounding boxes in image space. 26 | 27 | The 8-dimensional state space 28 | 29 | x, y, a, h, vx, vy, va, vh 30 | 31 | contains the bounding box center position (x, y), aspect ratio a, height h, 32 | and their respective velocities. 33 | 34 | Object motion follows a constant velocity model. The bounding box location 35 | (x, y, a, h) is taken as direct observation of the state space (linear 36 | observation model). 37 | 38 | """ 39 | 40 | def __init__(self): 41 | ndim, dt = 4, 1. 42 | 43 | # Create Kalman filter model matrices. 44 | self._motion_mat = np.eye(2 * ndim, 2 * ndim) 45 | for i in range(ndim): 46 | self._motion_mat[i, ndim + i] = dt 47 | self._update_mat = np.eye(ndim, 2 * ndim) 48 | 49 | # Motion and observation uncertainty are chosen relative to the current 50 | # state estimate. These weights control the amount of uncertainty in 51 | # the model. This is a bit hacky. 52 | self._std_weight_position = 1. / 20 53 | self._std_weight_velocity = 1. / 160 54 | 55 | def initiate(self, measurement): 56 | """Create track from unassociated measurement. 57 | 58 | Parameters 59 | ---------- 60 | measurement : ndarray 61 | Bounding box coordinates (x, y, a, h) with center position (x, y), 62 | aspect ratio a, and height h. 63 | 64 | Returns 65 | ------- 66 | (ndarray, ndarray) 67 | Returns the mean vector (8 dimensional) and covariance matrix (8x8 68 | dimensional) of the new track. Unobserved velocities are initialized 69 | to 0 mean. 70 | 71 | """ 72 | mean_pos = measurement 73 | mean_vel = np.zeros_like(mean_pos) 74 | mean = np.r_[mean_pos, mean_vel] 75 | 76 | std = [ 77 | 2 * self._std_weight_position * measurement[3], 78 | 2 * self._std_weight_position * measurement[3], 79 | 1e-2, 80 | 2 * self._std_weight_position * measurement[3], 81 | 10 * self._std_weight_velocity * measurement[3], 82 | 10 * self._std_weight_velocity * measurement[3], 83 | 1e-5, 84 | 10 * self._std_weight_velocity * measurement[3]] 85 | covariance = np.diag(np.square(std)) 86 | return mean, covariance 87 | 88 | def predict(self, mean, covariance): 89 | """Run Kalman filter prediction step. 
90 | 91 | Parameters 92 | ---------- 93 | mean : ndarray 94 | The 8 dimensional mean vector of the object state at the previous 95 | time step. 96 | covariance : ndarray 97 | The 8x8 dimensional covariance matrix of the object state at the 98 | previous time step. 99 | 100 | Returns 101 | ------- 102 | (ndarray, ndarray) 103 | Returns the mean vector and covariance matrix of the predicted 104 | state. Unobserved velocities are initialized to 0 mean. 105 | 106 | """ 107 | std_pos = [ 108 | self._std_weight_position * mean[3], 109 | self._std_weight_position * mean[3], 110 | 1e-2, 111 | self._std_weight_position * mean[3]] 112 | std_vel = [ 113 | self._std_weight_velocity * mean[3], 114 | self._std_weight_velocity * mean[3], 115 | 1e-5, 116 | self._std_weight_velocity * mean[3]] 117 | motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) 118 | 119 | #mean = np.dot(self._motion_mat, mean) 120 | mean = np.dot(mean, self._motion_mat.T) 121 | covariance = np.linalg.multi_dot(( 122 | self._motion_mat, covariance, self._motion_mat.T)) + motion_cov 123 | 124 | return mean, covariance 125 | 126 | def project(self, mean, covariance): 127 | """Project state distribution to measurement space. 128 | 129 | Parameters 130 | ---------- 131 | mean : ndarray 132 | The state's mean vector (8 dimensional array). 133 | covariance : ndarray 134 | The state's covariance matrix (8x8 dimensional). 135 | 136 | Returns 137 | ------- 138 | (ndarray, ndarray) 139 | Returns the projected mean and covariance matrix of the given state 140 | estimate. 141 | 142 | """ 143 | std = [ 144 | self._std_weight_position * mean[3], 145 | self._std_weight_position * mean[3], 146 | 1e-1, 147 | self._std_weight_position * mean[3]] 148 | innovation_cov = np.diag(np.square(std)) 149 | 150 | mean = np.dot(self._update_mat, mean) 151 | covariance = np.linalg.multi_dot(( 152 | self._update_mat, covariance, self._update_mat.T)) 153 | return mean, covariance + innovation_cov 154 | 155 | def multi_predict(self, mean, covariance): 156 | """Run Kalman filter prediction step (Vectorized version). 157 | Parameters 158 | ---------- 159 | mean : ndarray 160 | The Nx8 dimensional mean matrix of the object states at the previous 161 | time step. 162 | covariance : ndarray 163 | The Nx8x8 dimensional covariance matrics of the object states at the 164 | previous time step. 165 | Returns 166 | ------- 167 | (ndarray, ndarray) 168 | Returns the mean vector and covariance matrix of the predicted 169 | state. Unobserved velocities are initialized to 0 mean. 170 | """ 171 | std_pos = [ 172 | self._std_weight_position * mean[:, 3], 173 | self._std_weight_position * mean[:, 3], 174 | 1e-2 * np.ones_like(mean[:, 3]), 175 | self._std_weight_position * mean[:, 3]] 176 | std_vel = [ 177 | self._std_weight_velocity * mean[:, 3], 178 | self._std_weight_velocity * mean[:, 3], 179 | 1e-5 * np.ones_like(mean[:, 3]), 180 | self._std_weight_velocity * mean[:, 3]] 181 | sqr = np.square(np.r_[std_pos, std_vel]).T 182 | 183 | motion_cov = [] 184 | for i in range(len(mean)): 185 | motion_cov.append(np.diag(sqr[i])) 186 | motion_cov = np.asarray(motion_cov) 187 | 188 | mean = np.dot(mean, self._motion_mat.T) 189 | left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2)) 190 | covariance = np.dot(left, self._motion_mat.T) + motion_cov 191 | 192 | return mean, covariance 193 | 194 | def update(self, mean, covariance, measurement): 195 | """Run Kalman filter correction step. 
196 | 197 | Parameters 198 | ---------- 199 | mean : ndarray 200 | The predicted state's mean vector (8 dimensional). 201 | covariance : ndarray 202 | The state's covariance matrix (8x8 dimensional). 203 | measurement : ndarray 204 | The 4 dimensional measurement vector (x, y, a, h), where (x, y) 205 | is the center position, a the aspect ratio, and h the height of the 206 | bounding box. 207 | 208 | Returns 209 | ------- 210 | (ndarray, ndarray) 211 | Returns the measurement-corrected state distribution. 212 | 213 | """ 214 | projected_mean, projected_cov = self.project(mean, covariance) 215 | 216 | chol_factor, lower = scipy.linalg.cho_factor( 217 | projected_cov, lower=True, check_finite=False) 218 | kalman_gain = scipy.linalg.cho_solve( 219 | (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, 220 | check_finite=False).T 221 | innovation = measurement - projected_mean 222 | 223 | new_mean = mean + np.dot(innovation, kalman_gain.T) 224 | new_covariance = covariance - np.linalg.multi_dot(( 225 | kalman_gain, projected_cov, kalman_gain.T)) 226 | return new_mean, new_covariance 227 | 228 | def gating_distance(self, mean, covariance, measurements, 229 | only_position=False, metric='maha'): 230 | """Compute gating distance between state distribution and measurements. 231 | A suitable distance threshold can be obtained from `chi2inv95`. If 232 | `only_position` is False, the chi-square distribution has 4 degrees of 233 | freedom, otherwise 2. 234 | Parameters 235 | ---------- 236 | mean : ndarray 237 | Mean vector over the state distribution (8 dimensional). 238 | covariance : ndarray 239 | Covariance of the state distribution (8x8 dimensional). 240 | measurements : ndarray 241 | An Nx4 dimensional matrix of N measurements, each in 242 | format (x, y, a, h) where (x, y) is the bounding box center 243 | position, a the aspect ratio, and h the height. 244 | only_position : Optional[bool] 245 | If True, distance computation is done with respect to the bounding 246 | box center position only. 247 | Returns 248 | ------- 249 | ndarray 250 | Returns an array of length N, where the i-th element contains the 251 | squared Mahalanobis distance between (mean, covariance) and 252 | `measurements[i]`. 
253 | """ 254 | mean, covariance = self.project(mean, covariance) 255 | if only_position: 256 | mean, covariance = mean[:2], covariance[:2, :2] 257 | measurements = measurements[:, :2] 258 | 259 | d = measurements - mean 260 | if metric == 'gaussian': 261 | return np.sum(d * d, axis=1) 262 | elif metric == 'maha': 263 | cholesky_factor = np.linalg.cholesky(covariance) 264 | z = scipy.linalg.solve_triangular( 265 | cholesky_factor, d.T, lower=True, check_finite=False, 266 | overwrite_b=True) 267 | squared_maha = np.sum(z * z, axis=0) 268 | return squared_maha 269 | else: 270 | raise ValueError('invalid distance metric') -------------------------------------------------------------------------------- /byte_tracker/tracker/matching.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import scipy 4 | import lap 5 | from scipy.spatial.distance import cdist 6 | 7 | from cython_bbox import bbox_overlaps as bbox_ious 8 | from byte_tracker.tracker import kalman_filter 9 | import time 10 | 11 | def merge_matches(m1, m2, shape): 12 | O,P,Q = shape 13 | m1 = np.asarray(m1) 14 | m2 = np.asarray(m2) 15 | 16 | M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P)) 17 | M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q)) 18 | 19 | mask = M1*M2 20 | match = mask.nonzero() 21 | match = list(zip(match[0], match[1])) 22 | unmatched_O = tuple(set(range(O)) - set([i for i, j in match])) 23 | unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match])) 24 | 25 | return match, unmatched_O, unmatched_Q 26 | 27 | 28 | def _indices_to_matches(cost_matrix, indices, thresh): 29 | matched_cost = cost_matrix[tuple(zip(*indices))] 30 | matched_mask = (matched_cost <= thresh) 31 | 32 | matches = indices[matched_mask] 33 | unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0])) 34 | unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1])) 35 | 36 | return matches, unmatched_a, unmatched_b 37 | 38 | 39 | def linear_assignment(cost_matrix, thresh): 40 | if cost_matrix.size == 0: 41 | return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) 42 | matches, unmatched_a, unmatched_b = [], [], [] 43 | cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh) 44 | for ix, mx in enumerate(x): 45 | if mx >= 0: 46 | matches.append([ix, mx]) 47 | unmatched_a = np.where(x < 0)[0] 48 | unmatched_b = np.where(y < 0)[0] 49 | matches = np.asarray(matches) 50 | return matches, unmatched_a, unmatched_b 51 | 52 | 53 | def ious(atlbrs, btlbrs): 54 | """ 55 | Compute cost based on IoU 56 | :type atlbrs: list[tlbr] | np.ndarray 57 | :type atlbrs: list[tlbr] | np.ndarray 58 | 59 | :rtype ious np.ndarray 60 | """ 61 | ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float) 62 | if ious.size == 0: 63 | return ious 64 | 65 | ious = bbox_ious( 66 | np.ascontiguousarray(atlbrs, dtype=np.float), 67 | np.ascontiguousarray(btlbrs, dtype=np.float) 68 | ) 69 | 70 | return ious 71 | 72 | 73 | def iou_distance(atracks, btracks): 74 | """ 75 | Compute cost based on IoU 76 | :type atracks: list[STrack] 77 | :type btracks: list[STrack] 78 | 79 | :rtype cost_matrix np.ndarray 80 | """ 81 | 82 | if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)): 83 | atlbrs = atracks 84 | btlbrs = btracks 85 | else: 86 | atlbrs = [track.tlbr for track in atracks] 87 | 
btlbrs = [track.tlbr for track in btracks] 88 | _ious = ious(atlbrs, btlbrs) 89 | cost_matrix = 1 - _ious 90 | 91 | return cost_matrix 92 | 93 | def v_iou_distance(atracks, btracks): 94 | """ 95 | Compute cost based on IoU 96 | :type atracks: list[STrack] 97 | :type btracks: list[STrack] 98 | 99 | :rtype cost_matrix np.ndarray 100 | """ 101 | 102 | if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)): 103 | atlbrs = atracks 104 | btlbrs = btracks 105 | else: 106 | atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks] 107 | btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks] 108 | _ious = ious(atlbrs, btlbrs) 109 | cost_matrix = 1 - _ious 110 | 111 | return cost_matrix 112 | 113 | def embedding_distance(tracks, detections, metric='cosine'): 114 | """ 115 | :param tracks: list[STrack] 116 | :param detections: list[BaseTrack] 117 | :param metric: 118 | :return: cost_matrix np.ndarray 119 | """ 120 | 121 | cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float) 122 | if cost_matrix.size == 0: 123 | return cost_matrix 124 | det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float) 125 | #for i, track in enumerate(tracks): 126 | #cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric)) 127 | track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float) 128 | cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Nomalized features 129 | return cost_matrix 130 | 131 | 132 | def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False): 133 | if cost_matrix.size == 0: 134 | return cost_matrix 135 | gating_dim = 2 if only_position else 4 136 | gating_threshold = kalman_filter.chi2inv95[gating_dim] 137 | measurements = np.asarray([det.to_xyah() for det in detections]) 138 | for row, track in enumerate(tracks): 139 | gating_distance = kf.gating_distance( 140 | track.mean, track.covariance, measurements, only_position) 141 | cost_matrix[row, gating_distance > gating_threshold] = np.inf 142 | return cost_matrix 143 | 144 | 145 | def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98): 146 | if cost_matrix.size == 0: 147 | return cost_matrix 148 | gating_dim = 2 if only_position else 4 149 | gating_threshold = kalman_filter.chi2inv95[gating_dim] 150 | measurements = np.asarray([det.to_xyah() for det in detections]) 151 | for row, track in enumerate(tracks): 152 | gating_distance = kf.gating_distance( 153 | track.mean, track.covariance, measurements, only_position, metric='maha') 154 | cost_matrix[row, gating_distance > gating_threshold] = np.inf 155 | cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance 156 | return cost_matrix 157 | 158 | 159 | def fuse_iou(cost_matrix, tracks, detections): 160 | if cost_matrix.size == 0: 161 | return cost_matrix 162 | reid_sim = 1 - cost_matrix 163 | iou_dist = iou_distance(tracks, detections) 164 | iou_sim = 1 - iou_dist 165 | fuse_sim = reid_sim * (1 + iou_sim) / 2 166 | det_scores = np.array([det.score for det in detections]) 167 | det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) 168 | #fuse_sim = fuse_sim * (1 + det_scores) / 2 169 | fuse_cost = 1 - fuse_sim 170 | return fuse_cost 171 | 172 | 173 | def fuse_score(cost_matrix, detections): 174 | if cost_matrix.size == 0: 175 | return cost_matrix 176 | iou_sim = 1 - cost_matrix 177 | 
det_scores = np.array([det.score for det in detections]) 178 | det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) 179 | fuse_sim = iou_sim * det_scores 180 | fuse_cost = 1 - fuse_sim 181 | return fuse_cost -------------------------------------------------------------------------------- /byte_tracker/utils/yolox_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import cv2 4 | import numpy as np 5 | 6 | 7 | def nms(boxes, scores, nms_thr): 8 | """Single class NMS implemented in Numpy.""" 9 | x1 = boxes[:, 0] 10 | y1 = boxes[:, 1] 11 | x2 = boxes[:, 2] 12 | y2 = boxes[:, 3] 13 | 14 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 15 | order = scores.argsort()[::-1] 16 | 17 | keep = [] 18 | while order.size > 0: 19 | i = order[0] 20 | keep.append(i) 21 | xx1 = np.maximum(x1[i], x1[order[1:]]) 22 | yy1 = np.maximum(y1[i], y1[order[1:]]) 23 | xx2 = np.minimum(x2[i], x2[order[1:]]) 24 | yy2 = np.minimum(y2[i], y2[order[1:]]) 25 | 26 | w = np.maximum(0.0, xx2 - xx1 + 1) 27 | h = np.maximum(0.0, yy2 - yy1 + 1) 28 | inter = w * h 29 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 30 | 31 | inds = np.where(ovr <= nms_thr)[0] 32 | order = order[inds + 1] 33 | 34 | return keep 35 | 36 | 37 | def multiclass_nms(boxes, scores, nms_thr, score_thr): 38 | """Multiclass NMS implemented in Numpy""" 39 | final_dets = [] 40 | num_classes = scores.shape[1] 41 | for cls_ind in range(num_classes): 42 | cls_scores = scores[:, cls_ind] 43 | valid_score_mask = cls_scores > score_thr 44 | if valid_score_mask.sum() == 0: 45 | continue 46 | else: 47 | valid_scores = cls_scores[valid_score_mask] 48 | valid_boxes = boxes[valid_score_mask] 49 | keep = nms(valid_boxes, valid_scores, nms_thr) 50 | if len(keep) > 0: 51 | cls_inds = np.ones((len(keep), 1)) * cls_ind 52 | dets = np.concatenate( 53 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1) 54 | final_dets.append(dets) 55 | if len(final_dets) == 0: 56 | return None 57 | return np.concatenate(final_dets, 0) 58 | 59 | 60 | def pre_process(image, input_size, mean, std, swap=(2, 0, 1)): 61 | if len(image.shape) == 3: 62 | padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0 63 | else: 64 | padded_img = np.ones(input_size) * 114.0 65 | img = np.array(image) 66 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) 67 | resized_img = cv2.resize( 68 | img, 69 | (int(img.shape[1] * r), int(img.shape[0] * r)), 70 | interpolation=cv2.INTER_LINEAR, 71 | ).astype(np.float32) 72 | padded_img[:int(img.shape[0] * r), :int(img.shape[1] * r)] = resized_img 73 | 74 | padded_img = padded_img[:, :, ::-1] 75 | padded_img /= 255.0 76 | if mean is not None: 77 | padded_img -= mean 78 | if std is not None: 79 | padded_img /= std 80 | padded_img = padded_img.transpose(swap) 81 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) 82 | return padded_img, r 83 | 84 | 85 | def post_process(outputs, img_size, p6=False): 86 | grids = [] 87 | expanded_strides = [] 88 | 89 | if not p6: 90 | strides = [8, 16, 32] 91 | else: 92 | strides = [8, 16, 32, 64] 93 | 94 | hsizes = [img_size[0] // stride for stride in strides] 95 | wsizes = [img_size[1] // stride for stride in strides] 96 | 97 | for hsize, wsize, stride in zip(hsizes, wsizes, strides): 98 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) 99 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2) 100 | grids.append(grid) 101 | shape = grid.shape[:2] 102 | 
expanded_strides.append(np.full((*shape, 1), stride)) 103 | 104 | grids = np.concatenate(grids, 1) 105 | expanded_strides = np.concatenate(expanded_strides, 1) 106 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides 107 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides 108 | 109 | return outputs 110 | -------------------------------------------------------------------------------- /demo_video_onnx.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import time 4 | import argparse 5 | 6 | import cv2 7 | from loguru import logger 8 | 9 | from byte_tracker.byte_tracker_onnx import ByteTrackerONNX 10 | 11 | 12 | def get_args(): 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument( 16 | '--use_debug_window', 17 | action='store_true', 18 | ) 19 | 20 | parser.add_argument( 21 | '--model', 22 | type=str, 23 | default='byte_tracker/model/bytetrack_s.onnx', 24 | ) 25 | parser.add_argument( 26 | '--video', 27 | type=str, 28 | default='sample.mp4', 29 | ) 30 | parser.add_argument( 31 | '--output_dir', 32 | type=str, 33 | default='output', 34 | ) 35 | parser.add_argument( 36 | '--score_th', 37 | type=float, 38 | default=0.1, 39 | ) 40 | parser.add_argument( 41 | '--nms_th', 42 | type=float, 43 | default=0.7, 44 | ) 45 | parser.add_argument( 46 | '--input_shape', 47 | type=str, 48 | default='608,1088', 49 | ) 50 | parser.add_argument( 51 | '--with_p6', 52 | action='store_true', 53 | help='Whether your model uses p6 in FPN/PAN.', 54 | ) 55 | 56 | # tracking args 57 | parser.add_argument( 58 | '--track_thresh', 59 | type=float, 60 | default=0.5, 61 | help='tracking confidence threshold', 62 | ) 63 | parser.add_argument( 64 | '--track_buffer', 65 | type=int, 66 | default=30, 67 | help='the frames for keep lost tracks', 68 | ) 69 | parser.add_argument( 70 | '--match_thresh', 71 | type=float, 72 | default=0.8, 73 | help='matching threshold for tracking', 74 | ) 75 | parser.add_argument( 76 | '--min-box-area', 77 | type=float, 78 | default=10, 79 | help='filter out tiny boxes', 80 | ) 81 | parser.add_argument( 82 | '--mot20', 83 | dest='mot20', 84 | default=False, 85 | action='store_true', 86 | help='test mot20.', 87 | ) 88 | 89 | args = parser.parse_args() 90 | 91 | return args 92 | 93 | 94 | def main(): 95 | # 引数取得 96 | args = get_args() 97 | 98 | use_debug_window = args.use_debug_window 99 | 100 | video_path = args.video 101 | output_dir = args.output_dir 102 | 103 | # ByteTrackerインスタンス生成 104 | byte_tracker = ByteTrackerONNX(args) 105 | 106 | # 動画読み込み 107 | cap = cv2.VideoCapture(video_path) 108 | width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) 109 | height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) 110 | fps = cap.get(cv2.CAP_PROP_FPS) 111 | frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT) 112 | 113 | # 動画出力ディレクトリ作成 114 | os.makedirs(output_dir, exist_ok=True) 115 | save_path = os.path.join(output_dir, video_path.split("/")[-1]) 116 | logger.info(f"video save path is {save_path}") 117 | 118 | # ビデオライター生成 119 | video_writer = cv2.VideoWriter( 120 | save_path, 121 | cv2.VideoWriter_fourcc(*"mp4v"), 122 | fps, 123 | (int(width), int(height)), 124 | ) 125 | 126 | frame_id = 1 127 | 128 | while True: 129 | start_time = time.time() 130 | 131 | # フレーム読み出し 132 | ret, frame = cap.read() 133 | if not ret: 134 | break 135 | debug_image = copy.deepcopy(frame) 136 | 137 | # Byte Tracker推論 138 | _, bboxes, ids, scores = byte_tracker.inference(frame) 139 | 140 | elapsed_time = time.time() - start_time 141 | 142 | # 
143 | debug_image = draw_tracking_info( 144 | debug_image, 145 | bboxes, 146 | ids, 147 | scores, 148 | frame_id, 149 | elapsed_time, 150 | ) 151 | 152 | if use_debug_window: 153 | # Key handling (ESC: quit) 154 | key = cv2.waitKey(1) 155 | if key == 27: # ESC 156 | break 157 | 158 | # Show frame 159 | cv2.imshow('ByteTrack ONNX Sample', debug_image) 160 | 161 | # Write frame to output video 162 | video_writer.write(debug_image) 163 | 164 | logger.info( 165 | 'frame {}/{} ({:.2f} ms)'.format(frame_id, int(frame_count), 166 | elapsed_time * 1000), ) 167 | frame_id += 1 168 | 169 | video_writer.release() 170 | cap.release() 171 | cv2.destroyAllWindows() 172 | 173 | 174 | def get_id_color(index): 175 | temp_index = abs(int(index)) * 3 176 | color = ((37 * temp_index) % 255, (17 * temp_index) % 255, 177 | (29 * temp_index) % 255) 178 | return color 179 | 180 | 181 | def draw_tracking_info( 182 | image, 183 | tlwhs, 184 | ids, 185 | scores, 186 | frame_id=0, 187 | elapsed_time=0., 188 | ): 189 | text_scale = 1.5 190 | text_thickness = 2 191 | line_thickness = 2 192 | 193 | # Frame number, elapsed time, detection count 194 | text = 'frame: %d ' % (frame_id) 195 | text += 'elapsed time: %.0fms ' % (elapsed_time * 1000) 196 | text += 'num: %d' % (len(tlwhs)) 197 | cv2.putText( 198 | image, 199 | text, 200 | (0, int(15 * text_scale)), 201 | cv2.FONT_HERSHEY_PLAIN, 202 | 2, 203 | (0, 255, 0), 204 | thickness=text_thickness, 205 | ) 206 | 207 | for index, tlwh in enumerate(tlwhs): 208 | x1, y1 = int(tlwh[0]), int(tlwh[1]) 209 | x2, y2 = x1 + int(tlwh[2]), y1 + int(tlwh[3]) 210 | 211 | # Bounding box 212 | color = get_id_color(ids[index]) 213 | cv2.rectangle(image, (x1, y1), (x2, y2), color, line_thickness) 214 | 215 | # ID and score 216 | # text = str(ids[index]) + ':%.2f' % (scores[index]) 217 | text = str(ids[index]) 218 | cv2.putText(image, text, (x1, y1 - 5), cv2.FONT_HERSHEY_PLAIN, 219 | text_scale, (0, 0, 0), text_thickness + 3) 220 | cv2.putText(image, text, (x1, y1 - 5), cv2.FONT_HERSHEY_PLAIN, 221 | text_scale, (255, 255, 255), text_thickness) 222 | return image 223 | 224 | 225 | if __name__ == '__main__': 226 | main() 227 | -------------------------------------------------------------------------------- /demo_webcam_onnx.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import time 4 | import argparse 5 | 6 | import cv2 7 | 8 | from byte_tracker.byte_tracker_onnx import ByteTrackerONNX 9 | 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument( 15 | '--model', 16 | type=str, 17 | default='byte_tracker/model/bytetrack_s.onnx', 18 | ) 19 | parser.add_argument('--device', type=int, default=0) 20 | parser.add_argument("--width", help='cap width', type=int, default=960) 21 | parser.add_argument("--height", help='cap height', type=int, default=540) 22 | parser.add_argument( 23 | '--score_th', 24 | type=float, 25 | default=0.1, 26 | ) 27 | parser.add_argument( 28 | '--nms_th', 29 | type=float, 30 | default=0.7, 31 | ) 32 | parser.add_argument( 33 | '--input_shape', 34 | type=str, 35 | default='608,1088', 36 | ) 37 | parser.add_argument( 38 | '--with_p6', 39 | action='store_true', 40 | help='Whether your model uses p6 in FPN/PAN.', 41 | ) 42 | 43 | # tracking args 44 | parser.add_argument( 45 | '--track_thresh', 46 | type=float, 47 | default=0.5, 48 | help='tracking confidence threshold', 49 | ) 50 | parser.add_argument( 51 | '--track_buffer', 52 | type=int, 53 | default=30, 54 | help='number of frames to keep lost tracks', 55 | ) 56 | parser.add_argument( 57 |
'--match_thresh', 58 | type=float, 59 | default=0.8, 60 | help='matching threshold for tracking', 61 | ) 62 | parser.add_argument( 63 | '--min-box-area', 64 | type=float, 65 | default=10, 66 | help='filter out tiny boxes', 67 | ) 68 | parser.add_argument( 69 | '--mot20', 70 | dest='mot20', 71 | default=False, 72 | action='store_true', 73 | help='use MOT20 settings.', 74 | ) 75 | 76 | args = parser.parse_args() 77 | 78 | return args 79 | 80 | 81 | def main(): 82 | # Parse arguments 83 | args = get_args() 84 | 85 | cap_device = args.device 86 | cap_width = args.width 87 | cap_height = args.height 88 | 89 | # Create ByteTracker instance 90 | byte_tracker = ByteTrackerONNX(args) 91 | 92 | # Prepare camera 93 | cap = cv2.VideoCapture(cap_device) 94 | cap.set(cv2.CAP_PROP_FRAME_WIDTH, cap_width) 95 | cap.set(cv2.CAP_PROP_FRAME_HEIGHT, cap_height) 96 | 97 | frame_id = 1 98 | 99 | while True: 100 | start_time = time.time() 101 | 102 | # Read frame 103 | ret, frame = cap.read() 104 | if not ret: 105 | break 106 | debug_image = copy.deepcopy(frame) 107 | 108 | # Byte Tracker inference 109 | _, bboxes, ids, scores = byte_tracker.inference(frame) 110 | 111 | elapsed_time = time.time() - start_time 112 | 113 | # Draw tracking info 114 | debug_image = draw_tracking_info( 115 | debug_image, 116 | bboxes, 117 | ids, 118 | scores, 119 | frame_id, 120 | elapsed_time, 121 | ) 122 | 123 | # Key handling (ESC: quit) 124 | key = cv2.waitKey(1) 125 | if key == 27: # ESC 126 | break 127 | 128 | # Show frame 129 | cv2.imshow('ByteTrack ONNX Sample', debug_image) 130 | 131 | frame_id += 1 132 | 133 | cap.release() 134 | cv2.destroyAllWindows() 135 | 136 | 137 | def get_id_color(index): 138 | temp_index = abs(int(index)) * 3 139 | color = ((37 * temp_index) % 255, (17 * temp_index) % 255, 140 | (29 * temp_index) % 255) 141 | return color 142 | 143 | 144 | def draw_tracking_info( 145 | image, 146 | tlwhs, 147 | ids, 148 | scores, 149 | frame_id=0, 150 | elapsed_time=0., 151 | ): 152 | text_scale = 1.5 153 | text_thickness = 2 154 | line_thickness = 2 155 | 156 | # Frame number, elapsed time, detection count 157 | text = 'frame: %d ' % (frame_id) 158 | text += 'elapsed time: %.0fms ' % (elapsed_time * 1000) 159 | text += 'num: %d' % (len(tlwhs)) 160 | cv2.putText( 161 | image, 162 | text, 163 | (0, int(15 * text_scale)), 164 | cv2.FONT_HERSHEY_PLAIN, 165 | 2, 166 | (0, 255, 0), 167 | thickness=text_thickness, 168 | ) 169 | 170 | for index, tlwh in enumerate(tlwhs): 171 | x1, y1 = int(tlwh[0]), int(tlwh[1]) 172 | x2, y2 = x1 + int(tlwh[2]), y1 + int(tlwh[3]) 173 | 174 | # Bounding box 175 | color = get_id_color(ids[index]) 176 | cv2.rectangle(image, (x1, y1), (x2, y2), color, line_thickness) 177 | 178 | # ID and score 179 | # text = str(ids[index]) + ':%.2f' % (scores[index]) 180 | text = str(ids[index]) 181 | cv2.putText(image, text, (x1, y1 - 5), cv2.FONT_HERSHEY_PLAIN, 182 | text_scale, (0, 0, 0), text_thickness + 3) 183 | cv2.putText(image, text, (x1, y1 - 5), cv2.FONT_HERSHEY_PLAIN, 184 | text_scale, (255, 255, 255), text_thickness) 185 | return image 186 | 187 | 188 | if __name__ == '__main__': 189 | main() 190 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | onnx 3 | onnxruntime-gpu 4 | Cython 5 | torch 6 | torchvision 7 | pycocotools 8 | scipy 9 | loguru 10 | thop 11 | lap 12 | cython_bbox 13 | -------------------------------------------------------------------------------- /sample.mp4: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kazuhito00/ByteTrack-ONNX-Sample/70fd73903decc2071213c181398eb8de23cc2b8e/sample.mp4 --------------------------------------------------------------------------------