├── .gitignore ├── LICENSE ├── README.md ├── application_util ├── __init__.py ├── image_viewer.py ├── preprocessing.py └── visualization.py ├── deep_sort ├── __init__.py ├── detection.py ├── iou_matching.py ├── kalman_filter.py ├── linear_assignment.py ├── nn_matching.py ├── track.py └── tracker.py ├── deep_sort_app.py ├── evaluate_motchallenge.py ├── generate_videos.py ├── requirements-gpu.txt ├── requirements.txt ├── show_results.py └── tools ├── freeze_model.py └── generate_detections.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price.
Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 
88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 
146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 
209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. 
This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. 
But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 
387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. 
You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. 
"Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 
564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 
628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | {one line to give the program's name and a brief idea of what it does.} 635 | Copyright (C) {year} {name of author} 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail. 651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | {project} Copyright (C) {year} {fullname} 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>. 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep SORT 2 | 3 | ## Introduction 4 | 5 | This repository contains code for *Simple Online and Realtime Tracking with a Deep Association Metric* (Deep SORT). 6 | We extend the original [SORT](https://github.com/abewley/sort) algorithm to 7 | integrate appearance information based on a deep appearance descriptor. 8 | See the [arXiv preprint](https://arxiv.org/abs/1703.07402) for more information. 9 | 10 | ## Installation 11 | 12 | First, clone the repository and install dependencies: 13 | ``` 14 | git clone https://github.com/nwojke/deep_sort.git 15 | cd deep_sort 16 | 17 | # The following command installs all the dependencies required to run the 18 | # tracker and regenerate detections. If you only need to run the tracker with 19 | # existing detections, you can use pip install -r requirements.txt instead.
20 | pip install -r requirements-gpu.txt 21 | ``` 22 | Then, download pre-generated detections and the CNN checkpoint file from 23 | [here](https://drive.google.com/open?id=18fKzfqnqhqW3s9zwsCbnVJ5XF2JFeqMp). 24 | 25 | *NOTE:* The candidate object locations of our pre-generated detections are 26 | taken from the following paper: 27 | ``` 28 | F. Yu, W. Li, Q. Li, Y. Liu, X. Shi, J. Yan. POI: Multiple Object Tracking with 29 | High Performance Detection and Appearance Feature. In BMTT, SenseTime Group 30 | Limited, 2016. 31 | ``` 32 | We have replaced the appearance descriptor with a custom deep convolutional 33 | neural network (see below). 34 | 35 | ## Running the tracker 36 | 37 | The following example starts the tracker on one of the 38 | [MOT16 benchmark](https://motchallenge.net/data/MOT16/) 39 | sequences. 40 | We assume resources have been extracted to the repository root directory and 41 | the MOT16 benchmark data is in `./MOT16`: 42 | ``` 43 | python deep_sort_app.py \ 44 | --sequence_dir=./MOT16/test/MOT16-06 \ 45 | --detection_file=./resources/detections/MOT16_POI_test/MOT16-06.npy \ 46 | --min_confidence=0.3 \ 47 | --nn_budget=100 \ 48 | --display=True 49 | ``` 50 | Check `python deep_sort_app.py -h` for an overview of available options. 51 | There are also scripts in the repository to visualize results, generate videos, 52 | and evaluate the MOT challenge benchmark. 53 | 54 | ## Generating detections 55 | 56 | Besides the main tracking application, this repository contains a script to 57 | generate features for person re-identification, suitable for comparing the visual 58 | appearance of pedestrian bounding boxes using cosine similarity. 59 | The following example generates these features from standard MOT challenge 60 | detections. Again, we assume resources have been extracted to the repository 61 | root directory and MOT16 data is in `./MOT16`: 62 | ``` 63 | python tools/generate_detections.py \ 64 | --model=resources/networks/mars-small128.pb \ 65 | --mot_dir=./MOT16/train \ 66 | --output_dir=./resources/detections/MOT16_train 67 | ``` 68 | The model has been generated with TensorFlow 1.5. If you run into 69 | compatibility issues, re-export the frozen inference graph to obtain a new 70 | `mars-small128.pb` that is compatible with your version: 71 | ``` 72 | python tools/freeze_model.py 73 | ``` 74 | The ``generate_detections.py`` script stores a separate binary file in NumPy 75 | native format for each sequence of the MOT16 dataset. Each file contains an array of 76 | shape `Nx138`, where N is the number of detections in the corresponding MOT 77 | sequence. The first 10 columns of this array contain the raw MOT detection 78 | copied over from the input file. The remaining 128 columns store the appearance 79 | descriptor. The files generated by this command can be used as input for 80 | `deep_sort_app.py`. 81 | 82 | **NOTE**: If ``python tools/generate_detections.py`` raises a TensorFlow error, 83 | try passing an absolute path to the ``--model`` argument. This might help in 84 | some cases. 85 | 86 | ## Training the model 87 | 88 | To train the deep association metric model, we used a novel [cosine metric learning](https://github.com/nwojke/cosine_metric_learning) approach, which is provided as a separate repository. 89 | 90 | ## High-level overview of source files 91 | 92 | In the top-level directory are executable scripts to run, evaluate, and 93 | visualize the tracker. The main entry point is in `deep_sort_app.py`. 94 | This file runs the tracker on a MOTChallenge sequence.
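In essence, `deep_sort_app.py` wires the modules listed below together in a simple per-frame loop. The following condensed sketch illustrates that flow; it is not the full application (display, minimum-confidence filtering, and non-maximum suppression are omitted), the detection file is the one from the example above, and the cosine matching threshold of 0.2 is merely a typical value:
```
import numpy as np

from deep_sort import nn_matching
from deep_sort.detection import Detection
from deep_sort.tracker import Tracker

# N x 138 detection matrix: columns 0-9 hold the raw MOT detection
# (frame, id, x, y, w, h, confidence, ...), columns 10-137 the descriptor.
data = np.load("./resources/detections/MOT16_POI_test/MOT16-06.npy")
frame_indices = data[:, 0].astype(int)

metric = nn_matching.NearestNeighborDistanceMetric("cosine", 0.2, budget=100)
tracker = Tracker(metric)

for frame_idx in range(frame_indices.min(), frame_indices.max() + 1):
    rows = data[frame_indices == frame_idx]
    detections = [Detection(row[2:6], row[6], row[10:]) for row in rows]

    tracker.predict()   # advance the Kalman state of every track
    tracker.update(detections)  # associate detections, manage track life cycle

    for track in tracker.tracks:
        if not track.is_confirmed() or track.time_since_update > 0:
            continue
        x, y, w, h = track.to_tlwh()
        print(frame_idx, track.track_id, x, y, w, h)
```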
95 | 96 | In package `deep_sort` is the main tracking code: 97 | 98 | * `detection.py`: Detection base class. 99 | * `kalman_filter.py`: A Kalman filter implementation and concrete 100 | parametrization for image space filtering. 101 | * `linear_assignment.py`: This module contains code for min-cost matching and 102 | the matching cascade. 103 | * `iou_matching.py`: This module contains the IOU matching metric. 104 | * `nn_matching.py`: A module for a nearest neighbor matching metric. 105 | * `track.py`: The track class contains single-target track data such as Kalman 106 | state, number of hits, misses, hit streak, associated feature vectors, etc. 107 | * `tracker.py`: This is the multi-target tracker class. 108 | 109 | The `deep_sort_app.py` script expects detections in a custom format, stored in .npy 110 | files. These can be computed from MOTChallenge detections using 111 | `generate_detections.py`. We also provide 112 | [pre-generated detections](https://drive.google.com/open?id=1VVqtL0klSUvLnmBKS89il1EKC3IxUBVK). 113 | 114 | ## Citing DeepSORT 115 | 116 | If you find this repo useful in your research, please consider citing the following papers: 117 | 118 | @inproceedings{Wojke2017simple, 119 | title={Simple Online and Realtime Tracking with a Deep Association Metric}, 120 | author={Wojke, Nicolai and Bewley, Alex and Paulus, Dietrich}, 121 | booktitle={2017 IEEE International Conference on Image Processing (ICIP)}, 122 | year={2017}, 123 | pages={3645--3649}, 124 | organization={IEEE}, 125 | doi={10.1109/ICIP.2017.8296962} 126 | } 127 | 128 | @inproceedings{Wojke2018deep, 129 | title={Deep Cosine Metric Learning for Person Re-identification}, 130 | author={Wojke, Nicolai and Bewley, Alex}, 131 | booktitle={2018 IEEE Winter Conference on Applications of Computer Vision (WACV)}, 132 | year={2018}, 133 | pages={748--756}, 134 | organization={IEEE}, 135 | doi={10.1109/WACV.2018.00087} 136 | } 137 | -------------------------------------------------------------------------------- /application_util/__init__.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 -------------------------------------------------------------------------------- /application_util/image_viewer.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | """ 3 | This module contains an image viewer and drawing routines based on OpenCV. 4 | """ 5 | import numpy as np 6 | import cv2 7 | import time 8 | 9 | 10 | def is_in_bounds(mat, roi): 11 | """Check if ROI is fully contained in the image. 12 | 13 | Parameters 14 | ---------- 15 | mat : ndarray 16 | An ndarray of ndim>=2. 17 | roi : (int, int, int, int) 18 | Region of interest (x, y, width, height) where (x, y) is the top-left 19 | corner. 20 | 21 | Returns 22 | ------- 23 | bool 24 | Returns true if the ROI is contained in mat. 25 | 26 | """ 27 | if roi[0] < 0 or roi[0] + roi[2] > mat.shape[1]:  # slicing is exclusive, so the ROI may touch the right edge 28 | return False 29 | if roi[1] < 0 or roi[1] + roi[3] > mat.shape[0]:  # likewise for the bottom edge 30 | return False 31 | return True 32 | 33 | 34 | def view_roi(mat, roi): 35 | """Get sub-array. 36 | 37 | The ROI must be valid, i.e., fully contained in the image. 38 | 39 | Parameters 40 | ---------- 41 | mat : ndarray 42 | An ndarray of ndim=2 or ndim=3. 43 | roi : (int, int, int, int) 44 | Region of interest (x, y, width, height) where (x, y) is the top-left 45 | corner. 46 | 47 | Returns 48 | ------- 49 | ndarray 50 | A view of the roi.
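Examples
--------
A small illustrative doctest (added for clarity): the ROI below is
(x=1, y=0, w=2, h=3), and the result is a view, so writing to it
modifies `mat`:

>>> import numpy as np
>>> mat = np.arange(16).reshape(4, 4)
>>> view_roi(mat, (1, 0, 2, 3))
array([[ 1,  2],
       [ 5,  6],
       [ 9, 10]])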
51 | 52 | """ 53 | sx, ex = roi[0], roi[0] + roi[2] 54 | sy, ey = roi[1], roi[1] + roi[3] 55 | if mat.ndim == 2: 56 | return mat[sy:ey, sx:ex] 57 | else: 58 | return mat[sy:ey, sx:ex, :] 59 | 60 | 61 | class ImageViewer(object): 62 | """An image viewer with drawing routines and video capture capabilities. 63 | 64 | Key Bindings: 65 | 66 | * 'SPACE' : pause / resume 67 | * 's' : step to the next frame (while paused) 68 | * 'ESC' : quit 69 | Parameters 70 | ---------- 71 | update_ms : int 72 | Number of milliseconds between frames (1000 / frames per second). 73 | window_shape : (int, int) 74 | Shape of the window (width, height). 75 | caption : Optional[str] 76 | Title of the window. 77 | 78 | Attributes 79 | ---------- 80 | image : ndarray 81 | Color image of shape (height, width, 3). You may directly manipulate 82 | this image to change the view. Otherwise, you may call any of the 83 | drawing routines of this class. Internally, the image is treated as 84 | being in BGR color space. 85 | 86 | Note that the image is resized to the image viewer's window_shape 87 | just prior to visualization. Therefore, you may pass differently sized 88 | images and call drawing routines with the appropriate, original point 89 | coordinates. 90 | color : (int, int, int) 91 | Current BGR color code that applies to all drawing routines. 92 | Values are in range [0-255]. 93 | text_color : (int, int, int) 94 | Current BGR text color code that applies to all text rendering 95 | routines. Values are in range [0-255]. 96 | thickness : int 97 | Stroke width in pixels that applies to all drawing routines. 98 | 99 | """ 100 | 101 | def __init__(self, update_ms, window_shape=(640, 480), caption="Figure 1"): 102 | self._window_shape = window_shape 103 | self._caption = caption 104 | self._update_ms = update_ms 105 | self._video_writer = None 106 | self._user_fun = lambda: None 107 | self._terminate = False 108 | 109 | self.image = np.zeros(self._window_shape + (3, ), dtype=np.uint8) 110 | self._color = (0, 0, 0) 111 | self.text_color = (255, 255, 255) 112 | self.thickness = 1 113 | 114 | @property 115 | def color(self): 116 | return self._color 117 | 118 | @color.setter 119 | def color(self, value): 120 | if len(value) != 3: 121 | raise ValueError("color must be tuple of 3") 122 | self._color = tuple(int(c) for c in value) 123 | 124 | def rectangle(self, x, y, w, h, label=None): 125 | """Draw a rectangle. 126 | 127 | Parameters 128 | ---------- 129 | x : float | int 130 | Top left corner of the rectangle (x-axis). 131 | y : float | int 132 | Top left corner of the rectangle (y-axis). 133 | w : float | int 134 | Width of the rectangle. 135 | h : float | int 136 | Height of the rectangle. 137 | label : Optional[str] 138 | A text label that is placed at the top left corner of the 139 | rectangle. 140 | 141 | """ 142 | pt1 = int(x), int(y) 143 | pt2 = int(x + w), int(y + h) 144 | cv2.rectangle(self.image, pt1, pt2, self._color, self.thickness) 145 | if label is not None: 146 | text_size = cv2.getTextSize( 147 | label, cv2.FONT_HERSHEY_PLAIN, 1, self.thickness) 148 | 149 | center = pt1[0] + 5, pt1[1] + 5 + text_size[0][1] 150 | pt2 = pt1[0] + 10 + text_size[0][0], pt1[1] + 10 + \ 151 | text_size[0][1] 152 | cv2.rectangle(self.image, pt1, pt2, self._color, -1) 153 | cv2.putText(self.image, label, center, cv2.FONT_HERSHEY_PLAIN, 154 | 1, (255, 255, 255), self.thickness) 155 | 156 | def circle(self, x, y, radius, label=None): 157 | """Draw a circle. 158 | 159 | Parameters 160 | ---------- 161 | x : float | int 162 | Center of the circle (x-axis).
163 | y : float | int 164 | Center of the circle (y-axis). 165 | radius : float | int 166 | Radius of the circle in pixels. 167 | label : Optional[str] 168 | A text label that is placed at the center of the circle. 169 | 170 | """ 171 | image_size = int(radius + self.thickness + 1.5) # actually half size 172 | roi = int(x - image_size), int(y - image_size), \ 173 | int(2 * image_size), int(2 * image_size) 174 | if not is_in_bounds(self.image, roi): 175 | return 176 | 177 | image = view_roi(self.image, roi) 178 | center = image.shape[1] // 2, image.shape[0] // 2 179 | cv2.circle( 180 | image, center, int(radius + .5), self._color, self.thickness) 181 | if label is not None: 182 | # Draw the label at the circle's center in image coordinates; `center` above is relative to the ROI view and must not be reused here. 183 | cv2.putText(self.image, label, (int(x + .5), int(y + .5)), 184 | cv2.FONT_HERSHEY_PLAIN, 2, self.text_color, 2) 185 | 186 | def gaussian(self, mean, covariance, label=None): 187 | """Draw 95% confidence ellipse of a 2-D Gaussian distribution. 188 | 189 | Parameters 190 | ---------- 191 | mean : array_like 192 | The mean vector of the Gaussian distribution (ndim=1). 193 | covariance : array_like 194 | The 2x2 covariance matrix of the Gaussian distribution. 195 | label : Optional[str] 196 | A text label that is placed at the center of the ellipse. 197 | 198 | """ 199 | # chi2inv(0.95, 2) = 5.9915 200 | vals, vecs = np.linalg.eigh(5.9915 * covariance) 201 | indices = vals.argsort()[::-1] 202 | vals, vecs = np.sqrt(vals[indices]), vecs[:, indices] 203 | 204 | center = int(mean[0] + .5), int(mean[1] + .5) 205 | axes = int(vals[0] + .5), int(vals[1] + .5) 206 | angle = int(180. * np.arctan2(vecs[1, 0], vecs[0, 0]) / np.pi) 207 | cv2.ellipse( 208 | self.image, center, axes, angle, 0, 360, self._color, 2) 209 | if label is not None: 210 | cv2.putText(self.image, label, center, cv2.FONT_HERSHEY_PLAIN, 211 | 2, self.text_color, 2) 212 | 213 | def annotate(self, x, y, text): 214 | """Draw a text string at a given location. 215 | 216 | Parameters 217 | ---------- 218 | x : int | float 219 | Bottom-left corner of the text in the image (x-axis). 220 | y : int | float 221 | Bottom-left corner of the text in the image (y-axis). 222 | text : str 223 | The text to be drawn. 224 | 225 | """ 226 | cv2.putText(self.image, text, (int(x), int(y)), cv2.FONT_HERSHEY_PLAIN, 227 | 2, self.text_color, 2) 228 | 229 | def colored_points(self, points, colors=None, skip_index_check=False): 230 | """Draw a collection of points. 231 | 232 | The point size is fixed to 1. 233 | 234 | Parameters 235 | ---------- 236 | points : ndarray 237 | The Nx2 array of image locations, where the first column is 238 | the x-coordinate and the second column is the y-coordinate. 239 | colors : Optional[ndarray] 240 | The Nx3 array of colors (dtype=np.uint8). If None, the current 241 | color attribute is used. 242 | skip_index_check : Optional[bool] 243 | If True, index range checks are skipped. This is faster, but 244 | requires all points to lie within the image dimensions.
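Examples
--------
A minimal usage sketch (the `update_ms` value is arbitrary; the points
are drawn into `self.image`, and no window is opened by this call):

>>> import numpy as np
>>> viewer = ImageViewer(update_ms=40)
>>> viewer.color = (0, 0, 255)
>>> points = np.random.uniform(0, 400, size=(50, 2))
>>> viewer.colored_points(points)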
245 | 246 | """ 247 | if not skip_index_check: 248 | cond1, cond2 = points[:, 0] >= 0, points[:, 0] < self.image.shape[1]  # x within image width 249 | cond3, cond4 = points[:, 1] >= 0, points[:, 1] < self.image.shape[0]  # y within image height 250 | indices = np.logical_and.reduce((cond1, cond2, cond3, cond4)) 251 | points = points[indices, :] 252 | if colors is None: 253 | colors = np.repeat( 254 | self._color, len(points)).reshape(3, len(points)).T 255 | indices = (points + .5).astype(np.int64) 256 | self.image[indices[:, 1], indices[:, 0], :] = colors 257 | 258 | def enable_videowriter(self, output_filename, fourcc_string="MJPG", 259 | fps=None): 260 | """ Write images to video file. 261 | 262 | Parameters 263 | ---------- 264 | output_filename : str 265 | Output filename. 266 | fourcc_string : str 267 | The OpenCV FOURCC code that defines the video codec (check OpenCV 268 | documentation for more information). 269 | fps : Optional[float] 270 | Frames per second. If None, configured according to current 271 | parameters. 272 | 273 | """ 274 | fourcc = cv2.VideoWriter_fourcc(*fourcc_string) 275 | if fps is None: 276 | fps = int(1000. / self._update_ms) 277 | self._video_writer = cv2.VideoWriter( 278 | output_filename, fourcc, fps, self._window_shape) 279 | 280 | def disable_videowriter(self): 281 | """ Disable writing videos. 282 | """ 283 | self._video_writer = None 284 | 285 | def run(self, update_fun=None): 286 | """Start the image viewer. 287 | 288 | This method blocks until the user requests to close the window. 289 | 290 | Parameters 291 | ---------- 292 | update_fun : Optional[Callable[] -> None] 293 | An optional callable that is invoked at each frame. May be used 294 | to play an animation/a video sequence. 295 | 296 | """ 297 | if update_fun is not None: 298 | self._user_fun = update_fun 299 | 300 | self._terminate, is_paused = False, False 301 | # print("ImageViewer is paused, press space to start.") 302 | while not self._terminate: 303 | t0 = time.time() 304 | if not is_paused: 305 | self._terminate = not self._user_fun() 306 | if self._video_writer is not None: 307 | self._video_writer.write( 308 | cv2.resize(self.image, self._window_shape)) 309 | t1 = time.time() 310 | remaining_time = max(1, int(self._update_ms - 1e3*(t1-t0))) 311 | cv2.imshow( 312 | self._caption, cv2.resize(self.image, self._window_shape[:2])) 313 | key = cv2.waitKey(remaining_time) 314 | if key & 255 == 27: # ESC 315 | print("terminating") 316 | self._terminate = True 317 | elif key & 255 == 32: # ' ' 318 | print("toggling pause: " + str(not is_paused)) 319 | is_paused = not is_paused 320 | elif key & 255 == 115: # 's' 321 | print("stepping") 322 | self._terminate = not self._user_fun() 323 | is_paused = True 324 | 325 | # Due to a bug in OpenCV we must call imshow after destroying the 326 | # window. This will make the window appear again as soon as waitKey 327 | # is called. 328 | # 329 | # see https://github.com/Itseez/opencv/issues/4535 330 | self.image[:] = 0 331 | cv2.destroyWindow(self._caption) 332 | cv2.waitKey(1) 333 | cv2.imshow(self._caption, self.image) 334 | 335 | def stop(self): 336 | """Stop the control loop. 337 | 338 | After calling this method, the viewer will stop execution before the 339 | next frame and hand over control flow to the user.
340 | 341 | Parameters 342 | ---------- 343 | 344 | """ 345 | self._terminate = True 346 | -------------------------------------------------------------------------------- /application_util/preprocessing.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | def non_max_suppression(boxes, max_bbox_overlap, scores=None): 7 | """Suppress overlapping detections. 8 | 9 | Original code from [1]_ has been adapted to include confidence score. 10 | 11 | .. [1] http://www.pyimagesearch.com/2015/02/16/ 12 | faster-non-maximum-suppression-python/ 13 | 14 | Examples 15 | -------- 16 | 17 | >>> boxes = np.array([d.tlwh for d in detections]) 18 | >>> scores = np.array([d.confidence for d in detections]) 19 | >>> indices = non_max_suppression(boxes, max_bbox_overlap, scores) 20 | >>> detections = [detections[i] for i in indices] 21 | 22 | Parameters 23 | ---------- 24 | boxes : ndarray 25 | Array of ROIs (x, y, width, height). 26 | max_bbox_overlap : float 27 | ROIs that overlap more than this value are suppressed. 28 | scores : Optional[array_like] 29 | Detector confidence score. 30 | 31 | Returns 32 | ------- 33 | List[int] 34 | Returns indices of detections that have survived non-maxima suppression. 35 | 36 | """ 37 | if len(boxes) == 0: 38 | return [] 39 | 40 | boxes = boxes.astype(np.float64) 41 | pick = [] 42 | 43 | x1 = boxes[:, 0] 44 | y1 = boxes[:, 1] 45 | x2 = boxes[:, 2] + boxes[:, 0] 46 | y2 = boxes[:, 3] + boxes[:, 1] 47 | 48 | area = (x2 - x1 + 1) * (y2 - y1 + 1) 49 | if scores is not None: 50 | idxs = np.argsort(scores) 51 | else: 52 | idxs = np.argsort(y2) 53 | 54 | while len(idxs) > 0: 55 | last = len(idxs) - 1 56 | i = idxs[last] 57 | pick.append(i) 58 | 59 | xx1 = np.maximum(x1[i], x1[idxs[:last]]) 60 | yy1 = np.maximum(y1[i], y1[idxs[:last]]) 61 | xx2 = np.minimum(x2[i], x2[idxs[:last]]) 62 | yy2 = np.minimum(y2[i], y2[idxs[:last]]) 63 | 64 | w = np.maximum(0, xx2 - xx1 + 1) 65 | h = np.maximum(0, yy2 - yy1 + 1) 66 | 67 | overlap = (w * h) / area[idxs[:last]] 68 | 69 | idxs = np.delete( 70 | idxs, np.concatenate( 71 | ([last], np.where(overlap > max_bbox_overlap)[0]))) 72 | 73 | return pick 74 | -------------------------------------------------------------------------------- /application_util/visualization.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import numpy as np 3 | import colorsys 4 | from .image_viewer import ImageViewer 5 | 6 | 7 | def create_unique_color_float(tag, hue_step=0.41): 8 | """Create a unique RGB color code for a given track id (tag). 9 | 10 | The color code is generated in HSV color space by moving along the 11 | hue angle and gradually changing the value (brightness). 12 | 13 | Parameters 14 | ---------- 15 | tag : int 16 | The unique target identifying tag. 17 | hue_step : float 18 | Difference between two neighboring color codes in HSV space (more 19 | specifically, the distance in hue channel). 20 | 21 | Returns 22 | ------- 23 | (float, float, float) 24 | RGB color code in range [0, 1] 25 | 26 | """ 27 | h, v = (tag * hue_step) % 1, 1. - (int(tag * hue_step) % 4) / 5. 28 | r, g, b = colorsys.hsv_to_rgb(h, 1., v) 29 | return r, g, b 30 | 31 | 32 | def create_unique_color_uchar(tag, hue_step=0.41): 33 | """Create a unique RGB color code for a given track id (tag). 34 | 35 | The color code is generated in HSV color space by moving along the 36 | hue angle and gradually changing the value (brightness).
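# Example: stable per-track colors. A minimal sketch; the same track id always
# maps to the same RGB triple, and consecutive ids land on well-separated hues.
for track_id in (1, 2, 3):
    r, g, b = create_unique_color_uchar(track_id)
    print(track_id, (r, g, b))     # e.g. track id 1 -> (0, 255, 117)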
37 | 38 | Parameters 39 | ---------- 40 | tag : int 41 | The unique target identifying tag. 42 | hue_step : float 43 | Difference between two neighboring color codes in HSV space (more 44 | specifically, the distance in hue channel). 45 | 46 | Returns 47 | ------- 48 | (int, int, int) 49 | RGB color code in range [0, 255] 50 | 51 | """ 52 | r, g, b = create_unique_color_float(tag, hue_step) 53 | return int(255*r), int(255*g), int(255*b) 54 | 55 | 56 | class NoVisualization(object): 57 | """ 58 | A dummy visualization object that loops through all frames in a given 59 | sequence to update the tracker without performing any visualization. 60 | """ 61 | 62 | def __init__(self, seq_info): 63 | self.frame_idx = seq_info["min_frame_idx"] 64 | self.last_idx = seq_info["max_frame_idx"] 65 | 66 | def set_image(self, image): 67 | pass 68 | 69 | def draw_groundtruth(self, track_ids, boxes): 70 | pass 71 | 72 | def draw_detections(self, detections): 73 | pass 74 | 75 | def draw_trackers(self, trackers): 76 | pass 77 | 78 | def run(self, frame_callback): 79 | while self.frame_idx <= self.last_idx: 80 | frame_callback(self, self.frame_idx) 81 | self.frame_idx += 1 82 | 83 | 84 | class Visualization(object): 85 | """ 86 | This class shows tracking output in an OpenCV image viewer. 87 | """ 88 | 89 | def __init__(self, seq_info, update_ms): 90 | image_shape = seq_info["image_size"][::-1] 91 | aspect_ratio = float(image_shape[1]) / image_shape[0] 92 | image_shape = 1024, int(aspect_ratio * 1024) 93 | self.viewer = ImageViewer( 94 | update_ms, image_shape, "Figure %s" % seq_info["sequence_name"]) 95 | self.viewer.thickness = 2 96 | self.frame_idx = seq_info["min_frame_idx"] 97 | self.last_idx = seq_info["max_frame_idx"] 98 | 99 | def run(self, frame_callback): 100 | self.viewer.run(lambda: self._update_fun(frame_callback)) 101 | 102 | def _update_fun(self, frame_callback): 103 | if self.frame_idx > self.last_idx: 104 | return False # Terminate 105 | frame_callback(self, self.frame_idx) 106 | self.frame_idx += 1 107 | return True 108 | 109 | def set_image(self, image): 110 | self.viewer.image = image 111 | 112 | def draw_groundtruth(self, track_ids, boxes): 113 | self.viewer.thickness = 2 114 | for track_id, box in zip(track_ids, boxes): 115 | self.viewer.color = create_unique_color_uchar(track_id) 116 | self.viewer.rectangle(*box.astype(np.int64), label=str(track_id)) 117 | 118 | def draw_detections(self, detections): 119 | self.viewer.thickness = 2 120 | self.viewer.color = 0, 0, 255 121 | for i, detection in enumerate(detections): 122 | self.viewer.rectangle(*detection.tlwh) 123 | 124 | def draw_trackers(self, tracks): 125 | self.viewer.thickness = 2 126 | for track in tracks: 127 | if not track.is_confirmed() or track.time_since_update > 0: 128 | continue 129 | self.viewer.color = create_unique_color_uchar(track.track_id) 130 | self.viewer.rectangle( 131 | *track.to_tlwh().astype(np.int64), label=str(track.track_id)) 132 | # self.viewer.gaussian(track.mean[:2], track.covariance[:2, :2], 133 | # label="%d" % track.track_id) 134 | # 135 | -------------------------------------------------------------------------------- /deep_sort/__init__.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | -------------------------------------------------------------------------------- /deep_sort/detection.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import numpy as np 3 | 4 | 5 | 
class Detection(object): 6 | """ 7 | This class represents a bounding box detection in a single image. 8 | 9 | Parameters 10 | ---------- 11 | tlwh : array_like 12 | Bounding box in format `(x, y, w, h)`. 13 | confidence : float 14 | Detector confidence score. 15 | feature : array_like 16 | A feature vector that describes the object contained in this image. 17 | 18 | Attributes 19 | ---------- 20 | tlwh : ndarray 21 | Bounding box in format `(top left x, top left y, width, height)`. 22 | confidence : float 23 | Detector confidence score. 24 | feature : ndarray | NoneType 25 | A feature vector that describes the object contained in this image. 26 | 27 | """ 28 | 29 | def __init__(self, tlwh, confidence, feature): 30 | self.tlwh = np.asarray(tlwh, dtype=np.float64) 31 | self.confidence = float(confidence) 32 | self.feature = np.asarray(feature, dtype=np.float32) 33 | 34 | def to_tlbr(self): 35 | """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., 36 | `(top left, bottom right)`. 37 | """ 38 | ret = self.tlwh.copy() 39 | ret[2:] += ret[:2] 40 | return ret 41 | 42 | def to_xyah(self): 43 | """Convert bounding box to format `(center x, center y, aspect ratio, 44 | height)`, where the aspect ratio is `width / height`. 45 | """ 46 | ret = self.tlwh.copy() 47 | ret[:2] += ret[2:] / 2 48 | ret[2] /= ret[3] 49 | return ret 50 | -------------------------------------------------------------------------------- /deep_sort/iou_matching.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | from __future__ import absolute_import 3 | import numpy as np 4 | from . import linear_assignment 5 | 6 | 7 | def iou(bbox, candidates): 8 | """Compute intersection over union. 9 | 10 | Parameters 11 | ---------- 12 | bbox : ndarray 13 | A bounding box in format `(top left x, top left y, width, height)`. 14 | candidates : ndarray 15 | A matrix of candidate bounding boxes (one per row) in the same format 16 | as `bbox`. 17 | 18 | Returns 19 | ------- 20 | ndarray 21 | The intersection over union in [0, 1] between the `bbox` and each 22 | candidate. A higher score means a larger fraction of the `bbox` is 23 | occluded by the candidate. 24 | 25 | """ 26 | bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:] 27 | candidates_tl = candidates[:, :2] 28 | candidates_br = candidates[:, :2] + candidates[:, 2:] 29 | 30 | tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis], 31 | np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]] 32 | br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis], 33 | np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]] 34 | wh = np.maximum(0., br - tl) 35 | 36 | area_intersection = wh.prod(axis=1) 37 | area_bbox = bbox[2:].prod() 38 | area_candidates = candidates[:, 2:].prod(axis=1) 39 | return area_intersection / (area_bbox + area_candidates - area_intersection) 40 | 41 | 42 | def iou_cost(tracks, detections, track_indices=None, 43 | detection_indices=None): 44 | """An intersection over union distance metric. 45 | 46 | Parameters 47 | ---------- 48 | tracks : List[deep_sort.track.Track] 49 | A list of tracks. 50 | detections : List[deep_sort.detection.Detection] 51 | A list of detections. 52 | track_indices : Optional[List[int]] 53 | A list of indices to tracks that should be matched. Defaults to 54 | all `tracks`. 55 | detection_indices : Optional[List[int]] 56 | A list of indices to detections that should be matched. Defaults 57 | to all `detections`.
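# Example: the IoU computation on concrete boxes. A minimal sketch; boxes are
# given as (top left x, top left y, width, height).
import numpy as np

bbox = np.array([0., 0., 10., 10.])
candidates = np.array([[5., 5., 10., 10.],      # overlaps bbox in a 5x5 patch
                       [20., 20., 10., 10.]])   # disjoint
print(iou(bbox, candidates))                    # ~[0.1429, 0.0]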
58 | 59 | Returns 60 | ------- 61 | ndarray 62 | Returns a cost matrix of shape 63 | len(track_indices), len(detection_indices) where entry (i, j) is 64 | `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`. 65 | 66 | """ 67 | if track_indices is None: 68 | track_indices = np.arange(len(tracks)) 69 | if detection_indices is None: 70 | detection_indices = np.arange(len(detections)) 71 | 72 | cost_matrix = np.zeros((len(track_indices), len(detection_indices))) 73 | for row, track_idx in enumerate(track_indices): 74 | if tracks[track_idx].time_since_update > 1: 75 | cost_matrix[row, :] = linear_assignment.INFTY_COST 76 | continue 77 | 78 | bbox = tracks[track_idx].to_tlwh() 79 | candidates = np.asarray([detections[i].tlwh for i in detection_indices]) 80 | cost_matrix[row, :] = 1. - iou(bbox, candidates) 81 | return cost_matrix 82 | -------------------------------------------------------------------------------- /deep_sort/kalman_filter.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import numpy as np 3 | import scipy.linalg 4 | 5 | 6 | """ 7 | Table for the 0.95 quantile of the chi-square distribution with N degrees of 8 | freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv 9 | function and used as Mahalanobis gating threshold. 10 | """ 11 | chi2inv95 = { 12 | 1: 3.8415, 13 | 2: 5.9915, 14 | 3: 7.8147, 15 | 4: 9.4877, 16 | 5: 11.070, 17 | 6: 12.592, 18 | 7: 14.067, 19 | 8: 15.507, 20 | 9: 16.919} 21 | 22 | 23 | class KalmanFilter(object): 24 | """ 25 | A simple Kalman filter for tracking bounding boxes in image space. 26 | 27 | The 8-dimensional state space 28 | 29 | x, y, a, h, vx, vy, va, vh 30 | 31 | contains the bounding box center position (x, y), aspect ratio a, height h, 32 | and their respective velocities. 33 | 34 | Object motion follows a constant velocity model. The bounding box location 35 | (x, y, a, h) is taken as direct observation of the state space (linear 36 | observation model). 37 | 38 | """ 39 | 40 | def __init__(self): 41 | ndim, dt = 4, 1. 42 | 43 | # Create Kalman filter model matrices. 44 | self._motion_mat = np.eye(2 * ndim, 2 * ndim) 45 | for i in range(ndim): 46 | self._motion_mat[i, ndim + i] = dt 47 | self._update_mat = np.eye(ndim, 2 * ndim) 48 | 49 | # Motion and observation uncertainty are chosen relative to the current 50 | # state estimate. These weights control the amount of uncertainty in 51 | # the model. This is a bit hacky. 52 | self._std_weight_position = 1. / 20 53 | self._std_weight_velocity = 1. / 160 54 | 55 | def initiate(self, measurement): 56 | """Create track from unassociated measurement. 57 | 58 | Parameters 59 | ---------- 60 | measurement : ndarray 61 | Bounding box coordinates (x, y, a, h) with center position (x, y), 62 | aspect ratio a, and height h. 63 | 64 | Returns 65 | ------- 66 | (ndarray, ndarray) 67 | Returns the mean vector (8 dimensional) and covariance matrix (8x8 68 | dimensional) of the new track. Unobserved velocities are initialized 69 | to 0 mean. 
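# Example: one full filter cycle. A minimal sketch; measurements are in
# (center x, center y, aspect ratio, height) format, as `initiate` expects.
import numpy as np

kf = KalmanFilter()
mean, covariance = kf.initiate(np.array([320., 240., 0.5, 120.]))
mean, covariance = kf.predict(mean, covariance)   # constant-velocity step
mean, covariance = kf.update(                     # fold in a new measurement
    mean, covariance, np.array([322., 243., 0.5, 121.]))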
70 | 71 | """ 72 | mean_pos = measurement 73 | mean_vel = np.zeros_like(mean_pos) 74 | mean = np.r_[mean_pos, mean_vel] 75 | 76 | std = [ 77 | 2 * self._std_weight_position * measurement[3], 78 | 2 * self._std_weight_position * measurement[3], 79 | 1e-2, 80 | 2 * self._std_weight_position * measurement[3], 81 | 10 * self._std_weight_velocity * measurement[3], 82 | 10 * self._std_weight_velocity * measurement[3], 83 | 1e-5, 84 | 10 * self._std_weight_velocity * measurement[3]] 85 | covariance = np.diag(np.square(std)) 86 | return mean, covariance 87 | 88 | def predict(self, mean, covariance): 89 | """Run Kalman filter prediction step. 90 | 91 | Parameters 92 | ---------- 93 | mean : ndarray 94 | The 8 dimensional mean vector of the object state at the previous 95 | time step. 96 | covariance : ndarray 97 | The 8x8 dimensional covariance matrix of the object state at the 98 | previous time step. 99 | 100 | Returns 101 | ------- 102 | (ndarray, ndarray) 103 | Returns the mean vector and covariance matrix of the predicted 104 | state. 105 | 106 | """ 107 | std_pos = [ 108 | self._std_weight_position * mean[3], 109 | self._std_weight_position * mean[3], 110 | 1e-2, 111 | self._std_weight_position * mean[3]] 112 | std_vel = [ 113 | self._std_weight_velocity * mean[3], 114 | self._std_weight_velocity * mean[3], 115 | 1e-5, 116 | self._std_weight_velocity * mean[3]] 117 | motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) 118 | 119 | mean = np.dot(self._motion_mat, mean) 120 | covariance = np.linalg.multi_dot(( 121 | self._motion_mat, covariance, self._motion_mat.T)) + motion_cov 122 | 123 | return mean, covariance 124 | 125 | def project(self, mean, covariance): 126 | """Project state distribution to measurement space. 127 | 128 | Parameters 129 | ---------- 130 | mean : ndarray 131 | The state's mean vector (8 dimensional array). 132 | covariance : ndarray 133 | The state's covariance matrix (8x8 dimensional). 134 | 135 | Returns 136 | ------- 137 | (ndarray, ndarray) 138 | Returns the projected mean and covariance matrix of the given state 139 | estimate. 140 | 141 | """ 142 | std = [ 143 | self._std_weight_position * mean[3], 144 | self._std_weight_position * mean[3], 145 | 1e-1, 146 | self._std_weight_position * mean[3]] 147 | innovation_cov = np.diag(np.square(std)) 148 | 149 | mean = np.dot(self._update_mat, mean) 150 | covariance = np.linalg.multi_dot(( 151 | self._update_mat, covariance, self._update_mat.T)) 152 | return mean, covariance + innovation_cov 153 | 154 | def update(self, mean, covariance, measurement): 155 | """Run Kalman filter correction step. 156 | 157 | Parameters 158 | ---------- 159 | mean : ndarray 160 | The predicted state's mean vector (8 dimensional). 161 | covariance : ndarray 162 | The state's covariance matrix (8x8 dimensional). 163 | measurement : ndarray 164 | The 4 dimensional measurement vector (x, y, a, h), where (x, y) 165 | is the center position, a the aspect ratio, and h the height of the 166 | bounding box. 167 | 168 | Returns 169 | ------- 170 | (ndarray, ndarray) 171 | Returns the measurement-corrected state distribution.
172 | 173 | """ 174 | projected_mean, projected_cov = self.project(mean, covariance) 175 | 176 | chol_factor, lower = scipy.linalg.cho_factor( 177 | projected_cov, lower=True, check_finite=False) 178 | kalman_gain = scipy.linalg.cho_solve( 179 | (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, 180 | check_finite=False).T 181 | innovation = measurement - projected_mean 182 | 183 | new_mean = mean + np.dot(innovation, kalman_gain.T) 184 | new_covariance = covariance - np.linalg.multi_dot(( 185 | kalman_gain, projected_cov, kalman_gain.T)) 186 | return new_mean, new_covariance 187 | 188 | def gating_distance(self, mean, covariance, measurements, 189 | only_position=False): 190 | """Compute gating distance between state distribution and measurements. 191 | 192 | A suitable distance threshold can be obtained from `chi2inv95`. If 193 | `only_position` is False, the chi-square distribution has 4 degrees of 194 | freedom, otherwise 2. 195 | 196 | Parameters 197 | ---------- 198 | mean : ndarray 199 | Mean vector over the state distribution (8 dimensional). 200 | covariance : ndarray 201 | Covariance of the state distribution (8x8 dimensional). 202 | measurements : ndarray 203 | An Nx4 dimensional matrix of N measurements, each in 204 | format (x, y, a, h) where (x, y) is the bounding box center 205 | position, a the aspect ratio, and h the height. 206 | only_position : Optional[bool] 207 | If True, distance computation is done with respect to the bounding 208 | box center position only. 209 | 210 | Returns 211 | ------- 212 | ndarray 213 | Returns an array of length N, where the i-th element contains the 214 | squared Mahalanobis distance between (mean, covariance) and 215 | `measurements[i]`. 216 | 217 | """ 218 | mean, covariance = self.project(mean, covariance) 219 | if only_position: 220 | mean, covariance = mean[:2], covariance[:2, :2] 221 | measurements = measurements[:, :2] 222 | 223 | cholesky_factor = np.linalg.cholesky(covariance) 224 | d = measurements - mean 225 | z = scipy.linalg.solve_triangular( 226 | cholesky_factor, d.T, lower=True, check_finite=False, 227 | overwrite_b=True) 228 | squared_maha = np.sum(z * z, axis=0) 229 | return squared_maha 230 | -------------------------------------------------------------------------------- /deep_sort/linear_assignment.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | from __future__ import absolute_import 3 | import numpy as np 4 | from scipy.optimize import linear_sum_assignment 5 | from . import kalman_filter 6 | 7 | 8 | INFTY_COST = 1e+5 9 | 10 | 11 | def min_cost_matching( 12 | distance_metric, max_distance, tracks, detections, track_indices=None, 13 | detection_indices=None): 14 | """Solve linear assignment problem. 15 | 16 | Parameters 17 | ---------- 18 | distance_metric : Callable[List[Track], List[Detection], List[int], List[int]] -> ndarray 19 | The distance metric is given a list of tracks and detections as well as 20 | a list of N track indices and M detection indices. The metric should 21 | return the NxM dimensional cost matrix, where element (i, j) is the 22 | association cost between the i-th track in the given track indices and 23 | the j-th detection in the given detection_indices. 24 | max_distance : float 25 | Gating threshold. Associations with cost larger than this value are 26 | disregarded. 27 | tracks : List[track.Track] 28 | A list of predicted tracks at the current time step.
29 | detections : List[detection.Detection] 30 | A list of detections at the current time step. 31 | track_indices : List[int] 32 | List of track indices that maps rows in `cost_matrix` to tracks in 33 | `tracks` (see description above). 34 | detection_indices : List[int] 35 | List of detection indices that maps columns in `cost_matrix` to 36 | detections in `detections` (see description above). 37 | 38 | Returns 39 | ------- 40 | (List[(int, int)], List[int], List[int]) 41 | Returns a tuple with the following three entries: 42 | * A list of matched track and detection indices. 43 | * A list of unmatched track indices. 44 | * A list of unmatched detection indices. 45 | 46 | """ 47 | if track_indices is None: 48 | track_indices = np.arange(len(tracks)) 49 | if detection_indices is None: 50 | detection_indices = np.arange(len(detections)) 51 | 52 | if len(detection_indices) == 0 or len(track_indices) == 0: 53 | return [], track_indices, detection_indices # Nothing to match. 54 | 55 | cost_matrix = distance_metric( 56 | tracks, detections, track_indices, detection_indices) 57 | cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5 58 | indices = np.asarray(linear_sum_assignment(cost_matrix)).T 59 | 60 | matches, unmatched_tracks, unmatched_detections = [], [], [] 61 | for col, detection_idx in enumerate(detection_indices): 62 | if col not in indices[:, 1]: 63 | unmatched_detections.append(detection_idx) 64 | for row, track_idx in enumerate(track_indices): 65 | if row not in indices[:, 0]: 66 | unmatched_tracks.append(track_idx) 67 | for row, col in indices: 68 | track_idx = track_indices[row] 69 | detection_idx = detection_indices[col] 70 | if cost_matrix[row, col] > max_distance: 71 | unmatched_tracks.append(track_idx) 72 | unmatched_detections.append(detection_idx) 73 | else: 74 | matches.append((track_idx, detection_idx)) 75 | return matches, unmatched_tracks, unmatched_detections 76 | 77 | 78 | def matching_cascade( 79 | distance_metric, max_distance, cascade_depth, tracks, detections, 80 | track_indices=None, detection_indices=None): 81 | """Run matching cascade. 82 | 83 | Parameters 84 | ---------- 85 | distance_metric : Callable[List[Track], List[Detection], List[int], List[int]] -> ndarray 86 | The distance metric is given a list of tracks and detections as well as 87 | a list of N track indices and M detection indices. The metric should 88 | return the NxM dimensional cost matrix, where element (i, j) is the 89 | association cost between the i-th track in the given track indices and 90 | the j-th detection in the given detection indices. 91 | max_distance : float 92 | Gating threshold. Associations with cost larger than this value are 93 | disregarded. 94 | cascade_depth: int 95 | The cascade depth, should be set to the maximum track age. 96 | tracks : List[track.Track] 97 | A list of predicted tracks at the current time step. 98 | detections : List[detection.Detection] 99 | A list of detections at the current time step. 100 | track_indices : Optional[List[int]] 101 | List of track indices that maps rows in `cost_matrix` to tracks in 102 | `tracks` (see description above). Defaults to all tracks. 103 | detection_indices : Optional[List[int]] 104 | List of detection indices that maps columns in `cost_matrix` to 105 | detections in `detections` (see description above). Defaults to all 106 | detections.
107 | 108 | Returns 109 | ------- 110 | (List[(int, int)], List[int], List[int]) 111 | Returns a tuple with the following three entries: 112 | * A list of matched track and detection indices. 113 | * A list of unmatched track indices. 114 | * A list of unmatched detection indices. 115 | 116 | """ 117 | if track_indices is None: 118 | track_indices = list(range(len(tracks))) 119 | if detection_indices is None: 120 | detection_indices = list(range(len(detections))) 121 | 122 | unmatched_detections = detection_indices 123 | matches = [] 124 | for level in range(cascade_depth): 125 | if len(unmatched_detections) == 0: # No detections left 126 | break 127 | 128 | track_indices_l = [ 129 | k for k in track_indices 130 | if tracks[k].time_since_update == 1 + level 131 | ] 132 | if len(track_indices_l) == 0: # Nothing to match at this level 133 | continue 134 | 135 | matches_l, _, unmatched_detections = \ 136 | min_cost_matching( 137 | distance_metric, max_distance, tracks, detections, 138 | track_indices_l, unmatched_detections) 139 | matches += matches_l 140 | unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches)) 141 | return matches, unmatched_tracks, unmatched_detections 142 | 143 | 144 | def gate_cost_matrix( 145 | kf, cost_matrix, tracks, detections, track_indices, detection_indices, 146 | gated_cost=INFTY_COST, only_position=False): 147 | """Invalidate infeasible entries in cost matrix based on the state 148 | distributions obtained by Kalman filtering. 149 | 150 | Parameters 151 | ---------- 152 | kf : The Kalman filter. 153 | cost_matrix : ndarray 154 | The NxM dimensional cost matrix, where N is the number of track indices 155 | and M is the number of detection indices, such that entry (i, j) is the 156 | association cost between `tracks[track_indices[i]]` and 157 | `detections[detection_indices[j]]`. 158 | tracks : List[track.Track] 159 | A list of predicted tracks at the current time step. 160 | detections : List[detection.Detection] 161 | A list of detections at the current time step. 162 | track_indices : List[int] 163 | List of track indices that maps rows in `cost_matrix` to tracks in 164 | `tracks` (see description above). 165 | detection_indices : List[int] 166 | List of detection indices that maps columns in `cost_matrix` to 167 | detections in `detections` (see description above). 168 | gated_cost : Optional[float] 169 | Entries in the cost matrix corresponding to infeasible associations are 170 | set to this value. Defaults to a very large value. 171 | only_position : Optional[bool] 172 | If True, only the x, y position of the state distribution is considered 173 | during gating. Defaults to False. 174 | 175 | Returns 176 | ------- 177 | ndarray 178 | Returns the modified cost matrix.
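# Example: the gating rule in isolation. A minimal sketch; an association is
# declared infeasible when its squared Mahalanobis distance exceeds the 0.95
# chi-square quantile for the measurement dimensionality (4 here).
import numpy as np

gating_threshold = kalman_filter.chi2inv95[4]            # 9.4877
squared_maha = np.array([3.2, 12.7])                     # e.g. from kf.gating_distance
cost_row = np.array([0.10, 0.05])
cost_row[squared_maha > gating_threshold] = INFTY_COST   # second pairing disallowed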
179 | 180 | """ 181 | gating_dim = 2 if only_position else 4 182 | gating_threshold = kalman_filter.chi2inv95[gating_dim] 183 | measurements = np.asarray( 184 | [detections[i].to_xyah() for i in detection_indices]) 185 | for row, track_idx in enumerate(track_indices): 186 | track = tracks[track_idx] 187 | gating_distance = kf.gating_distance( 188 | track.mean, track.covariance, measurements, only_position) 189 | cost_matrix[row, gating_distance > gating_threshold] = gated_cost 190 | return cost_matrix 191 | -------------------------------------------------------------------------------- /deep_sort/nn_matching.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import numpy as np 3 | 4 | 5 | def _pdist(a, b): 6 | """Compute pair-wise squared distance between points in `a` and `b`. 7 | 8 | Parameters 9 | ---------- 10 | a : array_like 11 | An NxM matrix of N samples of dimensionality M. 12 | b : array_like 13 | An LxM matrix of L samples of dimensionality M. 14 | 15 | Returns 16 | ------- 17 | ndarray 18 | Returns a matrix of size len(a), len(b) such that element (i, j) 19 | contains the squared distance between `a[i]` and `b[j]`. 20 | 21 | """ 22 | a, b = np.asarray(a), np.asarray(b) 23 | if len(a) == 0 or len(b) == 0: 24 | return np.zeros((len(a), len(b))) 25 | a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1) 26 | r2 = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :] 27 | r2 = np.clip(r2, 0., float(np.inf)) 28 | return r2 29 | 30 | 31 | def _cosine_distance(a, b, data_is_normalized=False): 32 | """Compute pair-wise cosine distance between points in `a` and `b`. 33 | 34 | Parameters 35 | ---------- 36 | a : array_like 37 | An NxM matrix of N samples of dimensionality M. 38 | b : array_like 39 | An LxM matrix of L samples of dimensionality M. 40 | data_is_normalized : Optional[bool] 41 | If True, assumes rows in a and b are unit length vectors. 42 | Otherwise, a and b are explicitly normalized to length 1. 43 | 44 | Returns 45 | ------- 46 | ndarray 47 | Returns a matrix of size len(a), len(b) such that element (i, j) 48 | contains the cosine distance between `a[i]` and `b[j]`. 49 | 50 | """ 51 | if not data_is_normalized: 52 | a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True) 53 | b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True) 54 | return 1. - np.dot(a, b.T) 55 | 56 | 57 | def _nn_euclidean_distance(x, y): 58 | """ Helper function for nearest neighbor distance metric (Euclidean). 59 | 60 | Parameters 61 | ---------- 62 | x : ndarray 63 | A matrix of N row-vectors (sample points). 64 | y : ndarray 65 | A matrix of M row-vectors (query points). 66 | 67 | Returns 68 | ------- 69 | ndarray 70 | A vector of length M that contains for each entry in `y` the 71 | smallest squared Euclidean distance to a sample in `x`. 72 | 73 | """ 74 | distances = _pdist(x, y) 75 | return np.maximum(0.0, distances.min(axis=0)) 76 | 77 | 78 | def _nn_cosine_distance(x, y): 79 | """ Helper function for nearest neighbor distance metric (cosine). 80 | 81 | Parameters 82 | ---------- 83 | x : ndarray 84 | A matrix of N row-vectors (sample points). 85 | y : ndarray 86 | A matrix of M row-vectors (query points). 87 | 88 | Returns 89 | ------- 90 | ndarray 91 | A vector of length M that contains for each entry in `y` the 92 | smallest cosine distance to a sample in `x`.
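# Example: cosine distance on toy vectors. A minimal sketch; identical
# directions give distance 0 and orthogonal directions give distance 1.
import numpy as np

a = np.array([[1., 0.], [0., 1.]])
b = np.array([[1., 0.]])
print(_cosine_distance(a, b))      # [[0.], [1.]]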
93 | 94 | """ 95 | distances = _cosine_distance(x, y) 96 | return distances.min(axis=0) 97 | 98 | 99 | class NearestNeighborDistanceMetric(object): 100 | """ 101 | A nearest neighbor distance metric that, for each target, returns 102 | the closest distance to any sample that has been observed so far. 103 | 104 | Parameters 105 | ---------- 106 | metric : str 107 | Either "euclidean" or "cosine". 108 | matching_threshold: float 109 | The matching threshold. Samples with larger distance are considered an 110 | invalid match. 111 | budget : Optional[int] 112 | If not None, fix samples per class to at most this number. Removes 113 | the oldest samples when the budget is reached. 114 | 115 | Attributes 116 | ---------- 117 | samples : Dict[int -> List[ndarray]] 118 | A dictionary that maps from target identities to the list of samples 119 | that have been observed so far. 120 | 121 | """ 122 | 123 | def __init__(self, metric, matching_threshold, budget=None): 124 | 125 | 126 | if metric == "euclidean": 127 | self._metric = _nn_euclidean_distance 128 | elif metric == "cosine": 129 | self._metric = _nn_cosine_distance 130 | else: 131 | raise ValueError( 132 | "Invalid metric; must be either 'euclidean' or 'cosine'") 133 | self.matching_threshold = matching_threshold 134 | self.budget = budget 135 | self.samples = {} 136 | 137 | def partial_fit(self, features, targets, active_targets): 138 | """Update the distance metric with new data. 139 | 140 | Parameters 141 | ---------- 142 | features : ndarray 143 | An NxM matrix of N features of dimensionality M. 144 | targets : ndarray 145 | An integer array of associated target identities. 146 | active_targets : List[int] 147 | A list of targets that are currently present in the scene. 148 | 149 | """ 150 | for feature, target in zip(features, targets): 151 | self.samples.setdefault(target, []).append(feature) 152 | if self.budget is not None: 153 | self.samples[target] = self.samples[target][-self.budget:] 154 | self.samples = {k: self.samples[k] for k in active_targets} 155 | 156 | def distance(self, features, targets): 157 | """Compute distance between features and targets. 158 | 159 | Parameters 160 | ---------- 161 | features : ndarray 162 | An NxM matrix of N features of dimensionality M. 163 | targets : List[int] 164 | A list of targets to match the given `features` against. 165 | 166 | Returns 167 | ------- 168 | ndarray 169 | Returns a cost matrix of shape len(targets), len(features), where 170 | element (i, j) contains the smallest distance (under the configured 171 | metric) between `targets[i]` and `features[j]`. 172 | 173 | """ 174 | cost_matrix = np.zeros((len(targets), len(features))) 175 | for i, target in enumerate(targets): 176 | cost_matrix[i, :] = self._metric(self.samples[target], features) 177 | return cost_matrix 178 | -------------------------------------------------------------------------------- /deep_sort/track.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | 3 | 4 | class TrackState: 5 | """ 6 | Enumeration type for the single target track state. Newly created tracks are 7 | classified as `tentative` until enough evidence has been collected. Then, 8 | the track state is changed to `confirmed`. Tracks that are no longer alive 9 | are classified as `deleted` to mark them for removal from the set of active 10 | tracks.
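# Example: feeding the appearance metric. A minimal sketch with dummy
# descriptors; features are row vectors, targets are track ids.
import numpy as np

metric = NearestNeighborDistanceMetric("cosine", matching_threshold=0.2,
                                       budget=100)
features = np.eye(3, 128)                         # three dummy 128-d descriptors
metric.partial_fit(features, np.array([7, 7, 9]), active_targets=[7, 9])
cost = metric.distance(features[:1], targets=[7, 9])   # shape (2, 1)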
11 | 12 | """ 13 | 14 | Tentative = 1 15 | Confirmed = 2 16 | Deleted = 3 17 | 18 | 19 | class Track: 20 | """ 21 | A single target track with state space `(x, y, a, h)` and associated 22 | velocities, where `(x, y)` is the center of the bounding box, `a` is the 23 | aspect ratio and `h` is the height. 24 | 25 | Parameters 26 | ---------- 27 | mean : ndarray 28 | Mean vector of the initial state distribution. 29 | covariance : ndarray 30 | Covariance matrix of the initial state distribution. 31 | track_id : int 32 | A unique track identifier. 33 | n_init : int 34 | Number of consecutive detections before the track is confirmed. The 35 | track state is set to `Deleted` if a miss occurs within the first 36 | `n_init` frames. 37 | max_age : int 38 | The maximum number of consecutive misses before the track state is 39 | set to `Deleted`. 40 | feature : Optional[ndarray] 41 | Feature vector of the detection this track originates from. If not None, 42 | this feature is added to the `features` cache. 43 | 44 | Attributes 45 | ---------- 46 | mean : ndarray 47 | Mean vector of the initial state distribution. 48 | covariance : ndarray 49 | Covariance matrix of the initial state distribution. 50 | track_id : int 51 | A unique track identifier. 52 | hits : int 53 | Total number of measurement updates. 54 | age : int 55 | Total number of frames since first occurrence. 56 | time_since_update : int 57 | Total number of frames since last measurement update. 58 | state : TrackState 59 | The current track state. 60 | features : List[ndarray] 61 | A cache of features. On each measurement update, the associated feature 62 | vector is added to this list. 63 | 64 | """ 65 | 66 | def __init__(self, mean, covariance, track_id, n_init, max_age, 67 | feature=None): 68 | self.mean = mean 69 | self.covariance = covariance 70 | self.track_id = track_id 71 | self.hits = 1 72 | self.age = 1 73 | self.time_since_update = 0 74 | 75 | self.state = TrackState.Tentative 76 | self.features = [] 77 | if feature is not None: 78 | self.features.append(feature) 79 | 80 | self._n_init = n_init 81 | self._max_age = max_age 82 | 83 | def to_tlwh(self): 84 | """Get current position in bounding box format `(top left x, top left y, 85 | width, height)`. 86 | 87 | Returns 88 | ------- 89 | ndarray 90 | The bounding box. 91 | 92 | """ 93 | ret = self.mean[:4].copy() 94 | ret[2] *= ret[3] 95 | ret[:2] -= ret[2:] / 2 96 | return ret 97 | 98 | def to_tlbr(self): 99 | """Get current position in bounding box format `(min x, min y, max x, 100 | max y)`. 101 | 102 | Returns 103 | ------- 104 | ndarray 105 | The bounding box. 106 | 107 | """ 108 | ret = self.to_tlwh() 109 | ret[2:] = ret[:2] + ret[2:] 110 | return ret 111 | 112 | def predict(self, kf): 113 | """Propagate the state distribution to the current time step using a 114 | Kalman filter prediction step. 115 | 116 | Parameters 117 | ---------- 118 | kf : kalman_filter.KalmanFilter 119 | The Kalman filter. 120 | 121 | """ 122 | self.mean, self.covariance = kf.predict(self.mean, self.covariance) 123 | self.age += 1 124 | self.time_since_update += 1 125 | 126 | def update(self, kf, detection): 127 | """Perform Kalman filter measurement update step and update the feature 128 | cache. 129 | 130 | Parameters 131 | ---------- 132 | kf : kalman_filter.KalmanFilter 133 | The Kalman filter. 134 | detection : Detection 135 | The associated detection.
136 | 137 | """ 138 | self.mean, self.covariance = kf.update( 139 | self.mean, self.covariance, detection.to_xyah()) 140 | self.features.append(detection.feature) 141 | 142 | self.hits += 1 143 | self.time_since_update = 0 144 | if self.state == TrackState.Tentative and self.hits >= self._n_init: 145 | self.state = TrackState.Confirmed 146 | 147 | def mark_missed(self): 148 | """Mark this track as missed (no association at the current time step). 149 | """ 150 | if self.state == TrackState.Tentative: 151 | self.state = TrackState.Deleted 152 | elif self.time_since_update > self._max_age: 153 | self.state = TrackState.Deleted 154 | 155 | def is_tentative(self): 156 | """Returns True if this track is tentative (unconfirmed). 157 | """ 158 | return self.state == TrackState.Tentative 159 | 160 | def is_confirmed(self): 161 | """Returns True if this track is confirmed.""" 162 | return self.state == TrackState.Confirmed 163 | 164 | def is_deleted(self): 165 | """Returns True if this track is dead and should be deleted.""" 166 | return self.state == TrackState.Deleted 167 | -------------------------------------------------------------------------------- /deep_sort/tracker.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | from __future__ import absolute_import 3 | import numpy as np 4 | from . import kalman_filter 5 | from . import linear_assignment 6 | from . import iou_matching 7 | from .track import Track 8 | 9 | 10 | class Tracker: 11 | """ 12 | This is the multi-target tracker. 13 | 14 | Parameters 15 | ---------- 16 | metric : nn_matching.NearestNeighborDistanceMetric 17 | A distance metric for measurement-to-track association. 18 | max_age : int 19 | Maximum number of consecutive misses before a track is deleted. 20 | n_init : int 21 | Number of consecutive detections before the track is confirmed. The 22 | track state is set to `Deleted` if a miss occurs within the first 23 | `n_init` frames. 24 | 25 | Attributes 26 | ---------- 27 | metric : nn_matching.NearestNeighborDistanceMetric 28 | The distance metric used for measurement to track association. 29 | max_age : int 30 | Maximum number of consecutive misses before a track is deleted. 31 | n_init : int 32 | Number of frames that a track remains in initialization phase. 33 | kf : kalman_filter.KalmanFilter 34 | A Kalman filter to filter target trajectories in image space. 35 | tracks : List[Track] 36 | The list of active tracks at the current time step. 37 | 38 | """ 39 | 40 | def __init__(self, metric, max_iou_distance=0.7, max_age=30, n_init=3): 41 | self.metric = metric 42 | self.max_iou_distance = max_iou_distance 43 | self.max_age = max_age 44 | self.n_init = n_init 45 | 46 | self.kf = kalman_filter.KalmanFilter() 47 | self.tracks = [] 48 | self._next_id = 1 49 | 50 | def predict(self): 51 | """Propagate track state distributions one time step forward. 52 | 53 | This function should be called once every time step, before `update`. 54 | """ 55 | for track in self.tracks: 56 | track.predict(self.kf) 57 | 58 | def update(self, detections): 59 | """Perform measurement update and track management. 60 | 61 | Parameters 62 | ---------- 63 | detections : List[deep_sort.detection.Detection] 64 | A list of detections at the current time step. 65 | 66 | """ 67 | # Run matching cascade. 68 | matches, unmatched_tracks, unmatched_detections = \ 69 | self._match(detections) 70 | 71 | # Update track set.
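# Matched tracks fold their detection into the Kalman state and feature
# cache; unmatched tracks register a miss (tentative tracks are deleted
# immediately, confirmed ones only after max_age misses); each unmatched
# detection seeds a new tentative track. Deleted tracks are pruned below.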
72 | for track_idx, detection_idx in matches: 73 | self.tracks[track_idx].update( 74 | self.kf, detections[detection_idx]) 75 | for track_idx in unmatched_tracks: 76 | self.tracks[track_idx].mark_missed() 77 | for detection_idx in unmatched_detections: 78 | self._initiate_track(detections[detection_idx]) 79 | self.tracks = [t for t in self.tracks if not t.is_deleted()] 80 | 81 | # Update distance metric. 82 | active_targets = [t.track_id for t in self.tracks if t.is_confirmed()] 83 | features, targets = [], [] 84 | for track in self.tracks: 85 | if not track.is_confirmed(): 86 | continue 87 | features += track.features 88 | targets += [track.track_id for _ in track.features] 89 | track.features = [] 90 | self.metric.partial_fit( 91 | np.asarray(features), np.asarray(targets), active_targets) 92 | 93 | def _match(self, detections): 94 | 95 | def gated_metric(tracks, dets, track_indices, detection_indices): 96 | features = np.array([dets[i].feature for i in detection_indices]) 97 | targets = np.array([tracks[i].track_id for i in track_indices]) 98 | cost_matrix = self.metric.distance(features, targets) 99 | cost_matrix = linear_assignment.gate_cost_matrix( 100 | self.kf, cost_matrix, tracks, dets, track_indices, 101 | detection_indices) 102 | 103 | return cost_matrix 104 | 105 | # Split track set into confirmed and unconfirmed tracks. 106 | confirmed_tracks = [ 107 | i for i, t in enumerate(self.tracks) if t.is_confirmed()] 108 | unconfirmed_tracks = [ 109 | i for i, t in enumerate(self.tracks) if not t.is_confirmed()] 110 | 111 | # Associate confirmed tracks using appearance features. 112 | matches_a, unmatched_tracks_a, unmatched_detections = \ 113 | linear_assignment.matching_cascade( 114 | gated_metric, self.metric.matching_threshold, self.max_age, 115 | self.tracks, detections, confirmed_tracks) 116 | 117 | # Associate remaining tracks together with unconfirmed tracks using IOU. 
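# Only unconfirmed tracks and tracks that were matched in the previous
# frame (time_since_update == 1) take part in IoU matching; for tracks
# that have coasted longer, the predicted box is too unreliable for a
# plain overlap test.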
118 | iou_track_candidates = unconfirmed_tracks + [ 119 | k for k in unmatched_tracks_a if 120 | self.tracks[k].time_since_update == 1] 121 | unmatched_tracks_a = [ 122 | k for k in unmatched_tracks_a if 123 | self.tracks[k].time_since_update != 1] 124 | matches_b, unmatched_tracks_b, unmatched_detections = \ 125 | linear_assignment.min_cost_matching( 126 | iou_matching.iou_cost, self.max_iou_distance, self.tracks, 127 | detections, iou_track_candidates, unmatched_detections) 128 | 129 | matches = matches_a + matches_b 130 | unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b)) 131 | return matches, unmatched_tracks, unmatched_detections 132 | 133 | def _initiate_track(self, detection): 134 | mean, covariance = self.kf.initiate(detection.to_xyah()) 135 | self.tracks.append(Track( 136 | mean, covariance, self._next_id, self.n_init, self.max_age, 137 | detection.feature)) 138 | self._next_id += 1 139 | -------------------------------------------------------------------------------- /deep_sort_app.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | from __future__ import division, print_function, absolute_import 3 | 4 | import argparse 5 | import os 6 | 7 | import cv2 8 | import numpy as np 9 | 10 | from application_util import preprocessing 11 | from application_util import visualization 12 | from deep_sort import nn_matching 13 | from deep_sort.detection import Detection 14 | from deep_sort.tracker import Tracker 15 | 16 | 17 | def gather_sequence_info(sequence_dir, detection_file): 18 | """Gather sequence information, such as image filenames, detections, 19 | groundtruth (if available). 20 | 21 | Parameters 22 | ---------- 23 | sequence_dir : str 24 | Path to the MOTChallenge sequence directory. 25 | detection_file : str 26 | Path to the detection file. 27 | 28 | Returns 29 | ------- 30 | Dict 31 | A dictionary of the following sequence information: 32 | 33 | * sequence_name: Name of the sequence 34 | * image_filenames: A dictionary that maps frame indices to image 35 | filenames. 36 | * detections: A numpy array of detections in MOTChallenge format. 37 | * groundtruth: A numpy array of ground truth in MOTChallenge format. 38 | * image_size: Image size (height, width). 39 | * min_frame_idx: Index of the first frame. 40 | * max_frame_idx: Index of the last frame. 
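# Example: the expected layout of one row in a custom detection file. A
# minimal sketch; the first 10 columns follow the MOTChallenge detection
# format, the remaining columns hold the appearance descriptor (e.g. 128-d):
#   frame, id, x, y, w, h, confidence, -1, -1, -1, feat_0, ..., feat_127
import numpy as np

row = np.r_[[1, -1, 100., 200., 50., 120., 0.9, -1, -1, -1], np.zeros(128)]
detections = np.tile(row, (5, 1))     # five detections -> shape (5, 138)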
41 | 42 | """ 43 | image_dir = os.path.join(sequence_dir, "img1") 44 | image_filenames = { 45 | int(os.path.splitext(f)[0]): os.path.join(image_dir, f) 46 | for f in os.listdir(image_dir)} 47 | groundtruth_file = os.path.join(sequence_dir, "gt/gt.txt") 48 | 49 | detections = None 50 | if detection_file is not None: 51 | detections = np.load(detection_file) 52 | groundtruth = None 53 | if os.path.exists(groundtruth_file): 54 | groundtruth = np.loadtxt(groundtruth_file, delimiter=',') 55 | 56 | if len(image_filenames) > 0: 57 | image = cv2.imread(next(iter(image_filenames.values())), 58 | cv2.IMREAD_GRAYSCALE) 59 | image_size = image.shape 60 | else: 61 | image_size = None 62 | 63 | if len(image_filenames) > 0: 64 | min_frame_idx = min(image_filenames.keys()) 65 | max_frame_idx = max(image_filenames.keys()) 66 | else: 67 | min_frame_idx = int(detections[:, 0].min()) 68 | max_frame_idx = int(detections[:, 0].max()) 69 | 70 | info_filename = os.path.join(sequence_dir, "seqinfo.ini") 71 | if os.path.exists(info_filename): 72 | with open(info_filename, "r") as f: 73 | line_splits = [l.split('=') for l in f.read().splitlines()[1:]] 74 | info_dict = dict( 75 | s for s in line_splits if isinstance(s, list) and len(s) == 2) 76 | 77 | update_ms = 1000 / int(info_dict["frameRate"]) 78 | else: 79 | update_ms = None 80 | 81 | feature_dim = detections.shape[1] - 10 if detections is not None else 0 82 | seq_info = { 83 | "sequence_name": os.path.basename(sequence_dir), 84 | "image_filenames": image_filenames, 85 | "detections": detections, 86 | "groundtruth": groundtruth, 87 | "image_size": image_size, 88 | "min_frame_idx": min_frame_idx, 89 | "max_frame_idx": max_frame_idx, 90 | "feature_dim": feature_dim, 91 | "update_ms": update_ms 92 | } 93 | return seq_info 94 | 95 | 96 | def create_detections(detection_mat, frame_idx, min_height=0): 97 | """Create detections for given frame index from the raw detection matrix. 98 | 99 | Parameters 100 | ---------- 101 | detection_mat : ndarray 102 | Matrix of detections. The first 10 columns of the detection matrix are 103 | in the standard MOTChallenge detection format. The remaining columns 104 | store the feature vector associated with each detection. 105 | frame_idx : int 106 | The frame index. 107 | min_height : Optional[int] 108 | A minimum detection bounding box height. Detections that are smaller 109 | than this value are disregarded. 110 | 111 | Returns 112 | ------- 113 | List[detection.Detection] 114 | Returns detection responses at given frame index. 115 | 116 | """ 117 | frame_indices = detection_mat[:, 0].astype(np.int64) 118 | mask = frame_indices == frame_idx 119 | 120 | detection_list = [] 121 | for row in detection_mat[mask]: 122 | bbox, confidence, feature = row[2:6], row[6], row[10:] 123 | if bbox[3] < min_height: 124 | continue 125 | detection_list.append(Detection(bbox, confidence, feature)) 126 | return detection_list 127 | 128 | 129 | def run(sequence_dir, detection_file, output_file, min_confidence, 130 | nms_max_overlap, min_detection_height, max_cosine_distance, 131 | nn_budget, display): 132 | """Run multi-target tracker on a particular sequence. 133 | 134 | Parameters 135 | ---------- 136 | sequence_dir : str 137 | Path to the MOTChallenge sequence directory. 138 | detection_file : str 139 | Path to the detections file. 140 | output_file : str 141 | Path to the tracking output file. This file will contain the tracking 142 | results on completion. 143 | min_confidence : float 144 | Detection confidence threshold.
Disregard all detections that have 145 | a confidence lower than this value. 146 | nms_max_overlap: float 147 | Maximum detection overlap (non-maximum suppression threshold). 148 | min_detection_height : int 149 | Detection height threshold. Disregard all detections that have 150 | a height lower than this value. 151 | max_cosine_distance : float 152 | Gating threshold for cosine distance metric (object appearance). 153 | nn_budget : Optional[int] 154 | Maximum size of the appearance descriptor gallery. If None, no budget 155 | is enforced. 156 | display : bool 157 | If True, show visualization of intermediate tracking results. 158 | 159 | """ 160 | seq_info = gather_sequence_info(sequence_dir, detection_file) 161 | metric = nn_matching.NearestNeighborDistanceMetric( 162 | "cosine", max_cosine_distance, nn_budget) 163 | tracker = Tracker(metric) 164 | results = [] 165 | 166 | def frame_callback(vis, frame_idx): 167 | print("Processing frame %05d" % frame_idx) 168 | 169 | # Load image and generate detections. 170 | detections = create_detections( 171 | seq_info["detections"], frame_idx, min_detection_height) 172 | detections = [d for d in detections if d.confidence >= min_confidence] 173 | 174 | # Run non-maximum suppression. 175 | boxes = np.array([d.tlwh for d in detections]) 176 | scores = np.array([d.confidence for d in detections]) 177 | indices = preprocessing.non_max_suppression( 178 | boxes, nms_max_overlap, scores) 179 | detections = [detections[i] for i in indices] 180 | 181 | # Update tracker. 182 | tracker.predict() 183 | tracker.update(detections) 184 | 185 | # Update visualization. 186 | if display: 187 | image = cv2.imread( 188 | seq_info["image_filenames"][frame_idx], cv2.IMREAD_COLOR) 189 | vis.set_image(image.copy()) 190 | vis.draw_detections(detections) 191 | vis.draw_trackers(tracker.tracks) 192 | 193 | # Store results. 194 | for track in tracker.tracks: 195 | if not track.is_confirmed() or track.time_since_update > 1: 196 | continue 197 | bbox = track.to_tlwh() 198 | results.append([ 199 | frame_idx, track.track_id, bbox[0], bbox[1], bbox[2], bbox[3]]) 200 | 201 | # Run tracker. 202 | if display: 203 | visualizer = visualization.Visualization(seq_info, update_ms=5) 204 | else: 205 | visualizer = visualization.NoVisualization(seq_info) 206 | visualizer.run(frame_callback) 207 | 208 | # Store results. 209 | with open(output_file, 'w') as f: 210 | for row in results: 211 | print('%d,%d,%.2f,%.2f,%.2f,%.2f,1,-1,-1,-1' % ( 212 | row[0], row[1], row[2], row[3], row[4], row[5]), file=f) 213 | 214 | 215 | def bool_string(input_string): 216 | if input_string not in {"True", "False"}: 217 | raise ValueError("Please enter a valid True/False choice") 218 | else: 219 | return input_string == "True" 220 | 221 | def parse_args(): 222 | """ Parse command line arguments. 223 | """ 224 | parser = argparse.ArgumentParser(description="Deep SORT") 225 | parser.add_argument( 226 | "--sequence_dir", help="Path to MOTChallenge sequence directory", 227 | default=None, required=True) 228 | parser.add_argument( 229 | "--detection_file", help="Path to custom detections.", default=None, 230 | required=True) 231 | parser.add_argument( 232 | "--output_file", help="Path to the tracking output file. This file will" 233 | " contain the tracking results on completion.", 234 | default="/tmp/hypotheses.txt") 235 | parser.add_argument( 236 | "--min_confidence", help="Detection confidence threshold.
Disregard " 237 | "all detections that have a confidence lower than this value.", 238 | default=0.8, type=float) 239 | parser.add_argument( 240 | "--min_detection_height", help="Threshold on the detection bounding " 241 | "box height. Detections with height smaller than this value are " 242 | "disregarded", default=0, type=int) 243 | parser.add_argument( 244 | "--nms_max_overlap", help="Non-maximum suppression threshold: Maximum " 245 | "detection overlap.", default=1.0, type=float) 246 | parser.add_argument( 247 | "--max_cosine_distance", help="Gating threshold for cosine distance " 248 | "metric (object appearance).", type=float, default=0.2) 249 | parser.add_argument( 250 | "--nn_budget", help="Maximum size of the appearance descriptors " 251 | "gallery. If None, no budget is enforced.", type=int, default=None) 252 | parser.add_argument( 253 | "--display", help="Show intermediate tracking results", 254 | default=True, type=bool_string) 255 | return parser.parse_args() 256 | 257 | 258 | if __name__ == "__main__": 259 | args = parse_args() 260 | run( 261 | args.sequence_dir, args.detection_file, args.output_file, 262 | args.min_confidence, args.nms_max_overlap, args.min_detection_height, 263 | args.max_cosine_distance, args.nn_budget, args.display) 264 | -------------------------------------------------------------------------------- /evaluate_motchallenge.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import argparse 3 | import os 4 | import deep_sort_app 5 | 6 | 7 | def parse_args(): 8 | """ Parse command line arguments. 9 | """ 10 | parser = argparse.ArgumentParser(description="MOTChallenge evaluation") 11 | parser.add_argument( 12 | "--mot_dir", help="Path to MOTChallenge directory (train or test)", 13 | required=True) 14 | parser.add_argument( 15 | "--detection_dir", help="Path to detections.", default="detections", 16 | required=True) 17 | parser.add_argument( 18 | "--output_dir", help="Folder in which the results will be stored. Will " 19 | "be created if it does not exist.", default="results") 20 | parser.add_argument( 21 | "--min_confidence", help="Detection confidence threshold. Disregard " 22 | "all detections that have a confidence lower than this value. Set to " 23 | "0.3 to reproduce results in the paper.", 24 | default=0.3, type=float) 25 | parser.add_argument( 26 | "--min_detection_height", help="Threshold on the detection bounding " 27 | "box height. Detections with height smaller than this value are " 28 | "disregarded", default=0, type=int) 29 | parser.add_argument( 30 | "--nms_max_overlap", help="Non-maximum suppression threshold: Maximum " 31 | "detection overlap.", default=1.0, type=float) 32 | parser.add_argument( 33 | "--max_cosine_distance", help="Gating threshold for cosine distance " 34 | "metric (object appearance).", type=float, default=0.2) 35 | parser.add_argument( 36 | "--nn_budget", help="Maximum size of the appearance descriptors " 37 | "gallery. 
If None, no budget is enforced.", type=int, default=100) 38 | return parser.parse_args() 39 | 40 | 41 | if __name__ == "__main__": 42 | args = parse_args() 43 | 44 | os.makedirs(args.output_dir, exist_ok=True) 45 | sequences = os.listdir(args.mot_dir) 46 | for sequence in sequences: 47 | print("Running sequence %s" % sequence) 48 | sequence_dir = os.path.join(args.mot_dir, sequence) 49 | detection_file = os.path.join(args.detection_dir, "%s.npy" % sequence) 50 | output_file = os.path.join(args.output_dir, "%s.txt" % sequence) 51 | deep_sort_app.run( 52 | sequence_dir, detection_file, output_file, args.min_confidence, 53 | args.nms_max_overlap, args.min_detection_height, 54 | args.max_cosine_distance, args.nn_budget, display=False) 55 | -------------------------------------------------------------------------------- /generate_videos.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import os 3 | import argparse 4 | import show_results 5 | 6 | 7 | def convert(filename_in, filename_out, ffmpeg_executable="ffmpeg"): 8 | import subprocess 9 | command = [ffmpeg_executable, "-i", filename_in, "-c:v", "libx264", 10 | "-preset", "slow", "-crf", "21", filename_out] 11 | subprocess.call(command) 12 | 13 | 14 | def parse_args(): 15 | """ Parse command line arguments. 16 | """ 17 | parser = argparse.ArgumentParser(description="Generate tracking videos") 18 | parser.add_argument( 19 | "--mot_dir", help="Path to MOTChallenge directory (train or test)", 20 | required=True) 21 | parser.add_argument( 22 | "--result_dir", help="Path to the folder with tracking output.", 23 | required=True) 24 | parser.add_argument( 25 | "--output_dir", help="Folder to store the videos in. Will be created " 26 | "if it does not exist.", 27 | required=True) 28 | parser.add_argument( 29 | "--convert_h264", help="If true, convert videos to libx264 (requires " 30 | "FFMPEG).", default=False) 31 | parser.add_argument( 32 | "--update_ms", help="Time between consecutive frames in milliseconds. " 33 | "Defaults to the frame_rate specified in seqinfo.ini, if available.", 34 | default=None) 35 | return parser.parse_args() 36 | 37 | 38 | if __name__ == "__main__": 39 | args = parse_args() 40 | 41 | os.makedirs(args.output_dir, exist_ok=True) 42 | for sequence_txt in os.listdir(args.result_dir): 43 | sequence = os.path.splitext(sequence_txt)[0] 44 | sequence_dir = os.path.join(args.mot_dir, sequence) 45 | if not os.path.exists(sequence_dir): 46 | continue 47 | result_file = os.path.join(args.result_dir, sequence_txt) 48 | update_ms = args.update_ms 49 | video_filename = os.path.join(args.output_dir, "%s.avi" % sequence) 50 | 51 | print("Saving %s to %s."
52 |         show_results.run(
53 |             sequence_dir, result_file, False, None, update_ms, video_filename)
54 | 
55 |     if not args.convert_h264:
56 |         import sys
57 |         sys.exit()
58 |     for sequence_txt in os.listdir(args.result_dir):
59 |         sequence = os.path.splitext(sequence_txt)[0]
60 |         sequence_dir = os.path.join(args.mot_dir, sequence)
61 |         if not os.path.exists(sequence_dir):
62 |             continue
63 |         filename_in = os.path.join(args.output_dir, "%s.avi" % sequence)
64 |         filename_out = os.path.join(args.output_dir, "%s.mp4" % sequence)
65 |         convert(filename_in, filename_out)
66 | 
--------------------------------------------------------------------------------
/requirements-gpu.txt:
--------------------------------------------------------------------------------
1 | numpy<2.0.0
2 | opencv-python
3 | scipy
4 | tensorflow[and-cuda]==2.10.0
5 | tf-slim
6 | tf-keras
7 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | opencv-python
3 | scipy
4 | 
--------------------------------------------------------------------------------
/show_results.py:
--------------------------------------------------------------------------------
1 | # vim: expandtab:ts=4:sw=4
2 | import argparse
3 | 
4 | import cv2
5 | import numpy as np
6 | 
7 | import deep_sort_app
8 | from deep_sort.iou_matching import iou
9 | from application_util import visualization
10 | 
11 | 
12 | DEFAULT_UPDATE_MS = 20
13 | 
14 | 
15 | def run(sequence_dir, result_file, show_false_alarms=False, detection_file=None,
16 |         update_ms=None, video_filename=None):
17 |     """Run tracking result visualization.
18 | 
19 |     Parameters
20 |     ----------
21 |     sequence_dir : str
22 |         Path to the MOTChallenge sequence directory.
23 |     result_file : str
24 |         Path to the tracking output file in MOTChallenge ground truth format.
25 |     show_false_alarms : Optional[bool]
26 |         If True, false alarms are highlighted as red boxes.
27 |     detection_file : Optional[str]
28 |         Path to the detection file.
29 |     update_ms : Optional[int]
30 |         Number of milliseconds between consecutive frames. Defaults to the
31 |         frame rate specified in the seqinfo.ini file, or to DEFAULT_UPDATE_MS
32 |         if seqinfo.ini is not available.
33 |     video_filename : Optional[str]
34 |         If not None, a video of the tracking results is written to this file.
35 | 
36 |     """
37 |     seq_info = deep_sort_app.gather_sequence_info(sequence_dir, detection_file)
38 |     results = np.loadtxt(result_file, delimiter=',')
39 | 
40 |     if show_false_alarms and seq_info["groundtruth"] is None:
41 |         raise ValueError("No groundtruth available. Cannot show false alarms.")
42 | 
43 |     def frame_callback(vis, frame_idx):
44 |         print("Frame idx", frame_idx)
45 |         image = cv2.imread(
46 |             seq_info["image_filenames"][frame_idx], cv2.IMREAD_COLOR)
47 | 
48 |         vis.set_image(image.copy())
49 | 
50 |         if seq_info["detections"] is not None:
51 |             detections = deep_sort_app.create_detections(
52 |                 seq_info["detections"], frame_idx)
53 |             vis.draw_detections(detections)
54 | 
55 |         mask = results[:, 0].astype(np.int64) == frame_idx
56 |         track_ids = results[mask, 1].astype(np.int64)
57 |         boxes = results[mask, 2:6]
58 |         vis.draw_groundtruth(track_ids, boxes)
59 | 
60 |         if show_false_alarms:
61 |             groundtruth = seq_info["groundtruth"]
62 |             mask = groundtruth[:, 0].astype(np.int64) == frame_idx
63 |             gt_boxes = groundtruth[mask, 2:6]
64 |             for box in boxes:
65 |                 # NOTE(nwojke): This is not strictly correct, because we don't
66 |                 # solve the assignment problem here.
67 |                 min_iou_overlap = 0.5
68 |                 if iou(box, gt_boxes).max() < min_iou_overlap:
69 |                     vis.viewer.color = 0, 0, 255
70 |                     vis.viewer.thickness = 4
71 |                     vis.viewer.rectangle(*box.astype(np.int64))
72 | 
73 |     if update_ms is None:
74 |         update_ms = seq_info["update_ms"]
75 |     if update_ms is None:
76 |         update_ms = DEFAULT_UPDATE_MS
77 |     visualizer = visualization.Visualization(seq_info, update_ms)
78 |     if video_filename is not None:
79 |         visualizer.viewer.enable_videowriter(video_filename)
80 |     visualizer.run(frame_callback)
81 | 
82 | 
83 | def parse_args():
84 |     """ Parse command line arguments.
85 |     """
86 |     parser = argparse.ArgumentParser(description="Siamese Tracking")
87 |     parser.add_argument(
88 |         "--sequence_dir", help="Path to the MOTChallenge sequence directory.",
89 |         required=True)
90 |     parser.add_argument(
91 |         "--result_file", help="Tracking output in MOTChallenge file format.",
92 |         required=True)
93 |     parser.add_argument(
94 |         "--detection_file", help="Path to custom detections (optional).",
95 |         default=None)
96 |     parser.add_argument(
97 |         "--update_ms", help="Time between consecutive frames in milliseconds. "
" 98 | "Defaults to the frame_rate specified in seqinfo.ini, if available.", 99 | default=None) 100 | parser.add_argument( 101 | "--output_file", help="Filename of the (optional) output video.", 102 | default=None) 103 | parser.add_argument( 104 | "--show_false_alarms", help="Show false alarms as red bounding boxes.", 105 | type=bool, default=False) 106 | return parser.parse_args() 107 | 108 | 109 | if __name__ == "__main__": 110 | args = parse_args() 111 | run( 112 | args.sequence_dir, args.result_file, args.show_false_alarms, 113 | args.detection_file, args.update_ms, args.output_file) 114 | -------------------------------------------------------------------------------- /tools/freeze_model.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import argparse 3 | import tensorflow as tf 4 | import tf_slim 5 | 6 | 7 | def _batch_norm_fn(x, scope=None): 8 | if scope is None: 9 | scope = tf.compat.v1.get_variable_scope().name + "/bn" 10 | return tf_slim.batch_norm(x, scope=scope) 11 | 12 | 13 | def create_link( 14 | incoming, network_builder, scope, nonlinearity=tf.nn.elu, 15 | weights_initializer=tf.compat.v1.truncated_normal_initializer(stddev=1e-3), 16 | regularizer=None, is_first=False, summarize_activations=True): 17 | if is_first: 18 | network = incoming 19 | else: 20 | network = _batch_norm_fn(incoming, scope=scope + "/bn") 21 | network = nonlinearity(network) 22 | if summarize_activations: 23 | tf.summary.histogram(scope+"/activations", network) 24 | 25 | pre_block_network = network 26 | post_block_network = network_builder(pre_block_network, scope) 27 | 28 | incoming_dim = pre_block_network.get_shape().as_list()[-1] 29 | outgoing_dim = post_block_network.get_shape().as_list()[-1] 30 | if incoming_dim != outgoing_dim: 31 | assert outgoing_dim == 2 * incoming_dim, \ 32 | "%d != %d" % (outgoing_dim, 2 * incoming) 33 | projection = tf_slim.conv2d( 34 | incoming, outgoing_dim, 1, 2, padding="SAME", activation_fn=None, 35 | scope=scope+"/projection", weights_initializer=weights_initializer, 36 | biases_initializer=None, weights_regularizer=regularizer) 37 | network = projection + post_block_network 38 | else: 39 | network = incoming + post_block_network 40 | return network 41 | 42 | 43 | def create_inner_block( 44 | incoming, scope, nonlinearity=tf.nn.elu, 45 | weights_initializer=tf.compat.v1.truncated_normal_initializer(1e-3), 46 | bias_initializer=tf.zeros_initializer(), regularizer=None, 47 | increase_dim=False, summarize_activations=True): 48 | n = incoming.get_shape().as_list()[-1] 49 | stride = 1 50 | if increase_dim: 51 | n *= 2 52 | stride = 2 53 | 54 | incoming = tf_slim.conv2d( 55 | incoming, n, [3, 3], stride, activation_fn=nonlinearity, padding="SAME", 56 | normalizer_fn=_batch_norm_fn, weights_initializer=weights_initializer, 57 | biases_initializer=bias_initializer, weights_regularizer=regularizer, 58 | scope=scope + "/1") 59 | if summarize_activations: 60 | tf.summary.histogram(incoming.name + "/activations", incoming) 61 | 62 | incoming = tf_slim.dropout(incoming, keep_prob=0.6) 63 | 64 | incoming = tf_slim.conv2d( 65 | incoming, n, [3, 3], 1, activation_fn=None, padding="SAME", 66 | normalizer_fn=None, weights_initializer=weights_initializer, 67 | biases_initializer=bias_initializer, weights_regularizer=regularizer, 68 | scope=scope + "/2") 69 | return incoming 70 | 71 | 72 | def residual_block(incoming, scope, nonlinearity=tf.nn.elu, 73 | 
74 |                    bias_initializer=tf.zeros_initializer(), regularizer=None,
75 |                    increase_dim=False, is_first=False,
76 |                    summarize_activations=True):
77 | 
78 |     def network_builder(x, s):
79 |         return create_inner_block(
80 |             x, s, nonlinearity, weights_initializer, bias_initializer,
81 |             regularizer, increase_dim, summarize_activations)
82 | 
83 |     return create_link(
84 |         incoming, network_builder, scope, nonlinearity, weights_initializer,
85 |         regularizer, is_first, summarize_activations)
86 | 
87 | 
88 | def _create_network(incoming, reuse=None, weight_decay=1e-8):
89 |     nonlinearity = tf.nn.elu
90 |     conv_weight_init = tf.compat.v1.truncated_normal_initializer(stddev=1e-3)
91 |     conv_bias_init = tf.zeros_initializer()
92 |     conv_regularizer = tf_slim.l2_regularizer(weight_decay)
93 |     fc_weight_init = tf.compat.v1.truncated_normal_initializer(stddev=1e-3)
94 |     fc_bias_init = tf.zeros_initializer()
95 |     fc_regularizer = tf_slim.l2_regularizer(weight_decay)
96 | 
97 |     def batch_norm_fn(x):
98 |         return tf_slim.batch_norm(x, scope=tf.compat.v1.get_variable_scope().name + "/bn")
99 | 
100 |     network = incoming
101 |     network = tf_slim.conv2d(
102 |         network, 32, [3, 3], stride=1, activation_fn=nonlinearity,
103 |         padding="SAME", normalizer_fn=batch_norm_fn, scope="conv1_1",
104 |         weights_initializer=conv_weight_init, biases_initializer=conv_bias_init,
105 |         weights_regularizer=conv_regularizer)
106 |     network = tf_slim.conv2d(
107 |         network, 32, [3, 3], stride=1, activation_fn=nonlinearity,
108 |         padding="SAME", normalizer_fn=batch_norm_fn, scope="conv1_2",
109 |         weights_initializer=conv_weight_init, biases_initializer=conv_bias_init,
110 |         weights_regularizer=conv_regularizer)
111 | 
112 |     # NOTE(nwojke): This is missing a padding="SAME" to match the CNN
113 |     # architecture in Table 1 of the paper. Information on how this affects
114 |     # performance on MOT 16 training sequences can be found in
115 |     # issue 10 https://github.com/nwojke/deep_sort/issues/10
116 |     network = tf_slim.max_pool2d(network, [3, 3], [2, 2], scope="pool1")
117 | 
118 |     network = residual_block(
119 |         network, "conv2_1", nonlinearity, conv_weight_init, conv_bias_init,
120 |         conv_regularizer, increase_dim=False, is_first=True)
121 |     network = residual_block(
122 |         network, "conv2_3", nonlinearity, conv_weight_init, conv_bias_init,
123 |         conv_regularizer, increase_dim=False)
124 | 
125 |     network = residual_block(
126 |         network, "conv3_1", nonlinearity, conv_weight_init, conv_bias_init,
127 |         conv_regularizer, increase_dim=True)
128 |     network = residual_block(
129 |         network, "conv3_3", nonlinearity, conv_weight_init, conv_bias_init,
130 |         conv_regularizer, increase_dim=False)
131 | 
132 |     network = residual_block(
133 |         network, "conv4_1", nonlinearity, conv_weight_init, conv_bias_init,
134 |         conv_regularizer, increase_dim=True)
135 |     network = residual_block(
136 |         network, "conv4_3", nonlinearity, conv_weight_init, conv_bias_init,
137 |         conv_regularizer, increase_dim=False)
138 | 
139 |     feature_dim = network.get_shape().as_list()[-1]
140 |     network = tf_slim.flatten(network)
141 | 
142 |     network = tf_slim.dropout(network, keep_prob=0.6)
143 |     network = tf_slim.fully_connected(
144 |         network, feature_dim, activation_fn=nonlinearity,
145 |         normalizer_fn=batch_norm_fn, weights_regularizer=fc_regularizer,
146 |         scope="fc1", weights_initializer=fc_weight_init,
147 |         biases_initializer=fc_bias_init)
148 | 
149 |     features = network
150 | 
151 |     # Features in rows, normalize axis 1.
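# (Added annotation, not in the original file.) The normalization below
# divides each row of `features` by its Euclidean norm; the 1e-8 term only
# guards against division by zero. Every appearance descriptor therefore
# leaves the network with unit length, which is what makes the
# --max_cosine_distance gate in deep_sort_app.py meaningful: for unit
# vectors, cosine distance reduces to 1 - dot(a, b).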
152 |     features = tf_slim.batch_norm(features, scope="ball", reuse=reuse)
153 |     feature_norm = tf.sqrt(
154 |         tf.constant(1e-8, tf.float32) +
155 |         tf.reduce_sum(tf.square(features), [1], keepdims=True))
156 |     features = features / feature_norm
157 |     return features, None
158 | 
159 | 
160 | def _network_factory(weight_decay=1e-8):
161 | 
162 |     def factory_fn(image, reuse):
163 |         with tf_slim.arg_scope([tf_slim.batch_norm, tf_slim.dropout],
164 |                                is_training=False):
165 |             with tf_slim.arg_scope([tf_slim.conv2d, tf_slim.fully_connected,
166 |                                     tf_slim.batch_norm, tf_slim.layer_norm],
167 |                                    reuse=reuse):
168 |                 features, logits = _create_network(
169 |                     image, reuse=reuse, weight_decay=weight_decay)
170 |                 return features, logits
171 | 
172 |     return factory_fn
173 | 
174 | 
175 | def _preprocess(image):
176 |     image = image[:, :, ::-1]  # BGR to RGB
177 |     return image
178 | 
179 | 
180 | def parse_args():
181 |     """Parse command line arguments.
182 |     """
183 |     parser = argparse.ArgumentParser(description="Freeze old model")
184 |     parser.add_argument(
185 |         "--checkpoint_in",
186 |         default="resources/networks/mars-small128.ckpt-68577",
187 |         help="Path to checkpoint file")
188 |     parser.add_argument(
189 |         "--graphdef_out",
190 |         default="resources/networks/mars-small128.pb")
191 |     return parser.parse_args()
192 | 
193 | 
194 | def main():
195 |     args = parse_args()
196 | 
197 |     with tf.compat.v1.Session(graph=tf.Graph()) as session:
198 |         input_var = tf.compat.v1.placeholder(
199 |             tf.uint8, (None, 128, 64, 3), name="images")
200 |         image_var = tf.map_fn(
201 |             lambda x: _preprocess(x), tf.cast(input_var, tf.float32))
202 | 
203 |         factory_fn = _network_factory()
204 |         features, _ = factory_fn(image_var, reuse=None)
205 |         features = tf.identity(features, name="features")
206 | 
207 |         saver = tf.compat.v1.train.Saver(tf_slim.get_variables_to_restore())
208 |         saver.restore(session, args.checkpoint_in)
209 | 
210 |         output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
211 |             session, tf.compat.v1.get_default_graph().as_graph_def(),
212 |             [features.name.split(":")[0]])
213 |         with tf.compat.v1.gfile.GFile(args.graphdef_out, "wb") as file_handle:
214 |             file_handle.write(output_graph_def.SerializeToString())
215 | 
216 | 
217 | if __name__ == "__main__":
218 |     main()
219 | 
--------------------------------------------------------------------------------
/tools/generate_detections.py:
--------------------------------------------------------------------------------
1 | # vim: expandtab:ts=4:sw=4
2 | import os
3 | import errno
4 | import argparse
5 | import numpy as np
6 | import cv2
7 | import tensorflow as tf
8 | 
9 | 
10 | def _run_in_batches(f, data_dict, out, batch_size):
11 |     data_len = len(out)
12 |     num_batches = int(data_len / batch_size)
13 | 
14 |     s, e = 0, 0
15 |     for i in range(num_batches):
16 |         s, e = i * batch_size, (i + 1) * batch_size
17 |         batch_data_dict = {k: v[s:e] for k, v in data_dict.items()}
18 |         out[s:e] = f(batch_data_dict)
19 |     if e < len(out):
20 |         batch_data_dict = {k: v[e:] for k, v in data_dict.items()}
21 |         out[e:] = f(batch_data_dict)
22 | 
23 | 
24 | def extract_image_patch(image, bbox, patch_shape):
25 |     """Extract image patch from bounding box.
26 | 
27 |     Parameters
28 |     ----------
29 |     image : ndarray
30 |         The full image.
31 |     bbox : array_like
32 |         The bounding box in format (x, y, width, height).
33 |     patch_shape : Optional[array_like]
34 |         This parameter can be used to enforce a desired patch shape
35 |         (height, width). First, the `bbox` is adapted to the aspect ratio
36 |         of the patch shape, then it is clipped at the image boundaries.
37 |         If None, the shape is computed from :arg:`bbox`.
38 | 
39 |     Returns
40 |     -------
41 |     ndarray | NoneType
42 |         An image patch showing the :arg:`bbox`, optionally reshaped to
43 |         :arg:`patch_shape`.
44 |         Returns None if the bounding box is empty or fully outside of the image
45 |         boundaries.
46 | 
47 |     """
48 |     bbox = np.array(bbox)
49 |     if patch_shape is not None:
50 |         # correct aspect ratio to patch shape
51 |         target_aspect = float(patch_shape[1]) / patch_shape[0]
52 |         new_width = target_aspect * bbox[3]
53 |         bbox[0] -= (new_width - bbox[2]) / 2
54 |         bbox[2] = new_width
55 | 
56 |     # convert to top left, bottom right
57 |     bbox[2:] += bbox[:2]
58 |     bbox = bbox.astype(np.int64)
59 | 
60 |     # clip at image boundaries
61 |     bbox[:2] = np.maximum(0, bbox[:2])
62 |     bbox[2:] = np.minimum(np.asarray(image.shape[:2][::-1]) - 1, bbox[2:])
63 |     if np.any(bbox[:2] >= bbox[2:]):
64 |         return None
65 |     sx, sy, ex, ey = bbox
66 |     image = image[sy:ey, sx:ex]
67 |     image = cv2.resize(image, tuple(patch_shape[::-1]))
68 |     return image
69 | 
70 | 
71 | class ImageEncoder(object):
72 | 
73 |     def __init__(self, checkpoint_filename, input_name="images",
74 |                  output_name="features"):
75 |         self.session = tf.compat.v1.Session()
76 |         with tf.compat.v1.gfile.GFile(checkpoint_filename, "rb") as file_handle:
77 |             graph_def = tf.compat.v1.GraphDef()
78 |             graph_def.ParseFromString(file_handle.read())
79 |         tf.import_graph_def(graph_def, name="net")
80 | 
81 |         self.input_var = tf.compat.v1.get_default_graph().get_tensor_by_name(
82 |             "%s:0" % input_name)
83 |         self.output_var = tf.compat.v1.get_default_graph().get_tensor_by_name(
84 |             "%s:0" % output_name)
85 | 
86 |         assert len(self.output_var.get_shape()) == 2
87 |         assert len(self.input_var.get_shape()) == 4
88 |         self.feature_dim = self.output_var.get_shape().as_list()[-1]
89 |         self.image_shape = self.input_var.get_shape().as_list()[1:]
90 | 
91 |     def __call__(self, data_x, batch_size=32):
92 |         out = np.zeros((len(data_x), self.feature_dim), np.float32)
93 |         _run_in_batches(
94 |             lambda x: self.session.run(self.output_var, feed_dict=x),
95 |             {self.input_var: data_x}, out, batch_size)
96 |         return out
97 | 
98 | 
99 | def create_box_encoder(model_filename, input_name="images",
100 |                        output_name="features", batch_size=32):
101 |     image_encoder = ImageEncoder(model_filename, input_name, output_name)
102 |     image_shape = image_encoder.image_shape
103 | 
104 |     def encoder(image, boxes):
105 |         image_patches = []
106 |         for box in boxes:
107 |             patch = extract_image_patch(image, box, image_shape[:2])
108 |             if patch is None:
109 |                 print("WARNING: Failed to extract image patch: %s." % str(box))
110 |                 patch = np.random.uniform(
111 |                     0., 255., image_shape).astype(np.uint8)
112 |             image_patches.append(patch)
113 |         image_patches = np.asarray(image_patches)
114 |         return image_encoder(image_patches, batch_size)
115 | 
116 |     return encoder
117 | 
118 | 
119 | def generate_detections(encoder, mot_dir, output_dir, detection_dir=None):
120 |     """Generate detections with features.
121 | 
122 |     Parameters
123 |     ----------
124 |     encoder : Callable[image, ndarray] -> ndarray
125 |         The encoder function takes as input a BGR color image and a matrix of
126 |         bounding boxes in format `(x, y, w, h)` and returns a matrix of
127 |         corresponding feature vectors.
128 |     mot_dir : str
129 |         Path to the MOTChallenge directory (can be either train or test).
130 |     output_dir
131 |         Path to the output directory. Will be created if it does not exist.
132 |     detection_dir
133 |         Path to custom detections. The directory structure should be the default
134 |         MOTChallenge structure: `[sequence]/det/det.txt`. If None, uses the
135 |         standard MOTChallenge detections.
136 | 
137 |     """
138 |     if detection_dir is None:
139 |         detection_dir = mot_dir
140 |     try:
141 |         os.makedirs(output_dir)
142 |     except OSError as exception:
143 |         if exception.errno == errno.EEXIST and os.path.isdir(output_dir):
144 |             pass
145 |         else:
146 |             raise ValueError(
147 |                 "Failed to create output directory '%s'" % output_dir)
148 | 
149 |     for sequence in os.listdir(mot_dir):
150 |         print("Processing %s" % sequence)
151 |         sequence_dir = os.path.join(mot_dir, sequence)
152 | 
153 |         image_dir = os.path.join(sequence_dir, "img1")
154 |         image_filenames = {
155 |             int(os.path.splitext(f)[0]): os.path.join(image_dir, f)
156 |             for f in os.listdir(image_dir)}
157 | 
158 |         detection_file = os.path.join(
159 |             detection_dir, sequence, "det/det.txt")
160 |         detections_in = np.loadtxt(detection_file, delimiter=',')
161 |         detections_out = []
162 | 
163 |         frame_indices = detections_in[:, 0].astype(np.int64)
164 |         min_frame_idx = frame_indices.astype(np.int64).min()
165 |         max_frame_idx = frame_indices.astype(np.int64).max()
166 |         for frame_idx in range(min_frame_idx, max_frame_idx + 1):
167 |             print("Frame %05d/%05d" % (frame_idx, max_frame_idx))
168 |             mask = frame_indices == frame_idx
169 |             rows = detections_in[mask]
170 | 
171 |             if frame_idx not in image_filenames:
172 |                 print("WARNING could not find image for frame %d" % frame_idx)
173 |                 continue
174 |             bgr_image = cv2.imread(
175 |                 image_filenames[frame_idx], cv2.IMREAD_COLOR)
176 |             features = encoder(bgr_image, rows[:, 2:6].copy())
177 |             detections_out += [np.r_[(row, feature)] for row, feature
178 |                                in zip(rows, features)]
179 | 
180 |         output_filename = os.path.join(output_dir, "%s.npy" % sequence)
181 |         np.save(
182 |             output_filename, np.asarray(detections_out), allow_pickle=False)
183 | 
184 | 
185 | def parse_args():
186 |     """Parse command line arguments.
187 |     """
188 |     parser = argparse.ArgumentParser(description="Re-ID feature extractor")
189 |     parser.add_argument(
190 |         "--model",
191 |         default="resources/networks/mars-small128.pb",
192 |         help="Path to frozen inference graph protobuf.")
193 |     parser.add_argument(
194 |         "--mot_dir", help="Path to MOTChallenge directory (train or test)",
195 |         required=True)
196 |     parser.add_argument(
197 |         "--detection_dir", help="Path to custom detections. Defaults to "
198 |         "standard MOT detections. Directory structure should be the default "
199 |         "MOTChallenge structure: [sequence]/det/det.txt", default=None)
200 |     parser.add_argument(
201 |         "--output_dir", help="Output directory. Will be created if it does not"
202 |         " exist.", default="detections")
203 |     return parser.parse_args()
204 | 
205 | 
206 | def main():
207 |     args = parse_args()
208 |     encoder = create_box_encoder(args.model, batch_size=32)
209 |     generate_detections(encoder, args.mot_dir, args.output_dir,
210 |                         args.detection_dir)
211 | 
212 | 
213 | if __name__ == "__main__":
214 |     main()
215 | 
--------------------------------------------------------------------------------
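
A minimal end-to-end sketch of how the scripts above fit together when driven from Python rather than the command line. This is not a file in the repository; the MOT16 paths and the sequence name "MOT16-02" are assumptions for illustration, and it should be run from the repository root so the imports resolve.

# Hypothetical usage sketch -- paths and sequence name are assumptions.
import os

import deep_sort_app
from tools.generate_detections import create_box_encoder, generate_detections

# 1. Encode an appearance descriptor for every detection in every train
#    sequence, using the frozen model produced by tools/freeze_model.py.
encoder = create_box_encoder(
    "resources/networks/mars-small128.pb", batch_size=32)
generate_detections(
    encoder, "MOT16/train", "resources/detections/MOT16_train")

# 2. Track one sequence; the arguments mirror the positional call and the
#    defaults used in evaluate_motchallenge.py.
os.makedirs("results", exist_ok=True)
deep_sort_app.run(
    "MOT16/train/MOT16-02",                           # sequence_dir
    "resources/detections/MOT16_train/MOT16-02.npy",  # detection_file
    "results/MOT16-02.txt",                           # output_file
    0.3,     # min_confidence
    1.0,     # nms_max_overlap
    0,       # min_detection_height
    0.2,     # max_cosine_distance
    100,     # nn_budget
    False)   # display

The resulting results/MOT16-02.txt is in MOTChallenge format and can be rendered with show_results.py or generate_videos.py as shown above.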