├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── 000032.jpg
│   └── 000232.jpg
├── average_precision.py
├── data_queue.py
├── detect.py
├── export_model.py
├── infer.py
├── pascal-voc
│   └── download-data.sh
├── pascal_summary.py
├── process_dataset.py
├── source_pascal_voc.py
├── ssdutils.py
├── ssdvgg.py
├── train.py
├── training_data.py
├── transforms.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *~
3 | tb-*
4 | *test*
5 | *.zip
6 | pascal-voc*
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software.
The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 
117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 
174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |     <one line to give the program's name and a brief idea of what it does.>
635 |     Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 |     <program>  Copyright (C) <year>  <name of author>
656 |     This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 |     This is free software, and you are welcome to redistribute it
658 |     under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | SSD-TensorFlow
3 | ==============
4 |
5 | Overview
6 | --------
7 |
8 | The programs in this repository train and use a Single Shot MultiBox Detector
9 | to take an image and draw bounding boxes around objects of certain classes
10 | contained in this image. The network is based on the VGG-16 model and uses
11 | the approach described in [this paper][1] by Wei Liu et al. The software is
12 | generic and easily extensible to any dataset, although I have only tried it
13 | with [Pascal VOC][2] so far. All you need to do to introduce a new dataset is
14 | to create a new `source_xxxxxx.py` file defining it.
15 |
16 | Go [here][4] for more info.
17 |
18 | Pascal VOC Results
19 | ------------------
20 |
21 | Images and numbers speak louder than a thousand words, so here they are:
22 |
23 | ![Example #1][img1]
24 | ![Example #2][img2]
25 |
26 | | Model  | Training data                    | mAP Train | mAP VOC12 test |Reference|
27 | |:------:|:--------------------------------:|:---------:|:--------------:|:-------:|
28 | | vgg300 | VOC07+12 trainval and VOC07 Test | 79.5%     | [72.3%][3]     | 72.4%   |
29 | | vgg512 | VOC07+12 trainval and VOC07 Test | 82.3%     | [75.0%][5]     | 74.9%   |
30 |
31 | Usage
32 | -----
33 |
34 | To train the model on the Pascal VOC data, go to the `pascal-voc` directory
35 | and download the dataset:
36 |
37 |     cd pascal-voc
38 |     ./download-data.sh
39 |     cd ..
40 |
41 | You then need to preprocess the dataset before you can train the model on it.
42 | The default settings are fine, but if you want to customize anything, it's
43 | always good to try the `--help` parameter.
44 |
45 |     ./process_dataset.py
46 |
47 | You can then train the whole thing. It will take around 150 to 200 epochs to get
48 | good results. Again, you can try `--help` if you want to do something custom.
49 |
50 |     ./train.py
51 |
52 | You can annotate images, dump raw predictions, print the AP stats, or export the
53 | results in a Pascal VOC-compatible format using the inference script:
54 |
55 |     ./infer.py --help
56 |
57 | To export the model to an inference-optimized graph, run (use `result/result`
58 | as the name of the output tensor):
59 |
60 |     ./export_model.py
61 |
62 | If you want to run detection based on the exported inference model, check out:
63 |
64 |     ./detect.py
65 |
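For example, a complete export-and-detect round trip might look like this. The
file names are purely illustrative; the flag names and defaults are taken from
`export_model.py` and `detect.py` further down in this repository:

    ./export_model.py --metagraph-file final.ckpt.meta --checkpoint-file final.ckpt \
                      --output-file model300.pb --output-tensors result/result
    ./detect.py --model model300.pb --training-data training-data-300.pkl \
                --output-dir test-out my-image-1.jpg my-image-2.jpg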
66 |
67 | Have Fun!
68 |
69 | [1]: https://arxiv.org/pdf/1512.02325.pdf
70 | [2]: http://host.robots.ox.ac.uk/pascal/VOC/
71 | [3]: http://host.robots.ox.ac.uk:8080/anonymous/NEIZIN.html
72 | [4]: http://jany.st/post/2017-11-05-single-shot-detector-ssd-from-scratch-in-tensorflow.html
73 | [5]: http://host.robots.ox.ac.uk:8080/anonymous/FYP60C.html
74 |
75 | [img1]: assets/000232.jpg
76 | [img2]: assets/000032.jpg
77 |
--------------------------------------------------------------------------------
/assets/000032.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljanyst/ssd-tensorflow/e9c1ee5ccb90130c34cbc64bed62b6bc8b3880d2/assets/000032.jpg
--------------------------------------------------------------------------------
/assets/000232.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljanyst/ssd-tensorflow/e9c1ee5ccb90130c34cbc64bed62b6bc8b3880d2/assets/000232.jpg
--------------------------------------------------------------------------------
/average_precision.py:
--------------------------------------------------------------------------------
1 | #-------------------------------------------------------------------------------
2 | # Author: Lukasz Janyst
3 | # Date: 13.09.2017
4 | #-------------------------------------------------------------------------------
5 | # This file is part of SSD-TensorFlow.
6 | #
7 | # SSD-TensorFlow is free software: you can redistribute it and/or modify
8 | # it under the terms of the GNU General Public License as published by
9 | # the Free Software Foundation, either version 3 of the License, or
10 | # (at your option) any later version.
11 | #
12 | # SSD-TensorFlow is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | # GNU General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU General Public License
18 | # along with SSD-TensorFlow. If not, see <http://www.gnu.org/licenses/>.
19 | #-------------------------------------------------------------------------------
20 |
21 | import numpy as np
22 |
23 | from collections import defaultdict
24 | from ssdutils import jaccard_overlap
25 | from utils import Size, prop2abs
26 |
27 | IMG_SIZE = Size(1000, 1000)
28 |
29 | #-------------------------------------------------------------------------------
30 | def APs2mAP(aps):
31 |     """
32 |     Take a mean of APs over all classes to compute mAP
33 |     """
34 |     num_classes = 0.
35 |     sum_ap = 0.
36 |     for _, v in aps.items():
37 |         sum_ap += v
38 |         num_classes += 1
39 |
40 |     if num_classes == 0:
41 |         return 0
42 |     return sum_ap/num_classes
43 |
44 | #-------------------------------------------------------------------------------
45 | class APCalculator:
46 |     """
47 |     Compute average precision of object detection as used in PASCAL VOC
48 |     Challenges. It is a peculiar measure because of the way it calculates the
49 |     precision-recall curve. It's highly sensitive to the sorting order of the
50 |     predictions in different images. I.e., the exact same resulting bounding
51 |     boxes in all images may get a different AP score depending on the way
52 |     the boxes are sorted globally by confidence.
53 |     Reference: http://homepages.inf.ed.ac.uk/ckiw/postscript/ijcv_voc09.pdf
54 |     Reference: http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCdevkit_08-Jun-2007.tar
55 |     """
56 |     #---------------------------------------------------------------------------
57 |     def __init__(self, minoverlap=0.5):
58 |         """
59 |         Initialize the calculator.
60 |         """
61 |         self.minoverlap = minoverlap
62 |         self.clear()
63 |
64 |     #---------------------------------------------------------------------------
65 |     def add_detections(self, gt_boxes, boxes):
66 |         """
67 |         Add new detections to the calculator.
68 |         :param gt_boxes: a list of ground truth boxes for the sample
69 |         :param boxes:    a list of (float, Box) tuples representing
70 |                          detections and their confidences, the detections
71 |                          must have a correctly set label
72 |         """
73 |
74 |         sample_id = len(self.gt_boxes)
75 |         self.gt_boxes.append(gt_boxes)
76 |
77 |         for conf, box in boxes:
78 |             arr = np.array(prop2abs(box.center, box.size, IMG_SIZE))
79 |             self.det_params[box.label].append(arr)
80 |             self.det_confidence[box.label].append(conf)
81 |             self.det_sample_ids[box.label].append(sample_id)
82 |
83 |     #---------------------------------------------------------------------------
84 |     def compute_aps(self):
85 |         """
86 |         Compute the average precision per class as well as mAP.
87 |         """
88 |
89 |         #-----------------------------------------------------------------------
90 |         # Split the ground truth samples by class and sample
91 |         #-----------------------------------------------------------------------
92 |         counts = defaultdict(lambda: 0)
93 |         gt_map = defaultdict(dict)
94 |
95 |         for sample_id, boxes in enumerate(self.gt_boxes):
96 |             boxes_by_class = defaultdict(list)
97 |             for box in boxes:
98 |                 counts[box.label] += 1
99 |                 boxes_by_class[box.label].append(box)
100 |
101 |             for k, v in boxes_by_class.items():
102 |                 arr = np.zeros((len(v), 4))
103 |                 match = np.zeros((len(v)), dtype=bool)
104 |                 for i, box in enumerate(v):
105 |                     arr[i] = np.array(prop2abs(box.center, box.size, IMG_SIZE))
106 |                 gt_map[k][sample_id] = (arr, match)
107 |
108 |         #-----------------------------------------------------------------------
109 |         # Compare predictions to ground truth
110 |         #-----------------------------------------------------------------------
111 |         aps = {}
112 |         for k in gt_map:
113 |             #-------------------------------------------------------------------
114 |             # Create numpy arrays of detection parameters and sort them
115 |             # in descending order
116 |             #-------------------------------------------------------------------
117 |             params = np.array(self.det_params[k], dtype=np.float32)
118 |             confs = np.array(self.det_confidence[k], dtype=np.float32)
119 |             sample_ids = np.array(self.det_sample_ids[k], dtype=np.int64)
120 |             idxs_max = np.argsort(-confs)
121 |             params = params[idxs_max]
122 |             confs = confs[idxs_max]
123 |             sample_ids = sample_ids[idxs_max]
124 |
125 |             #-------------------------------------------------------------------
126 |             # Loop over the detections and count true and false positives
127 |             #-------------------------------------------------------------------
128 |             tps = np.zeros((params.shape[0])) # true positives
129 |             fps = np.zeros((params.shape[0])) # false positives
130 |             for i in range(params.shape[0]):
131 |                 sample_id = sample_ids[i]
132 |                 box = params[i]
133 |
134 |                 #---------------------------------------------------------------
135 |                 # The image this detection comes from contains no objects
136 |                 # of this class
137 |                 #---------------------------------------------------------------
138 |                 if sample_id not in gt_map[k]:
139 |                     fps[i] = 1
140 |                     continue
141 |
142 |                 #---------------------------------------------------------------
143 |                 # Compute the jaccard overlap and see if it's over the threshold
144 |                 #---------------------------------------------------------------
145 |                 gt = gt_map[k][sample_id][0]
146 |                 matched = gt_map[k][sample_id][1]
147 |
148 |                 iou = jaccard_overlap(box, gt)
149 |                 max_idx = np.argmax(iou)
150 |
151 |                 if iou[max_idx] < self.minoverlap:
152 |                     fps[i] = 1
153 |                     continue
154 |
155 |                 #---------------------------------------------------------------
156 |                 # Check if the max overlap ground truth box is already matched
157 |                 #---------------------------------------------------------------
158 |                 if matched[max_idx]:
159 |                     fps[i] = 1
160 |                     continue
161 |
162 |                 tps[i] = 1
163 |                 matched[max_idx] = True
164 |
165 |             #-------------------------------------------------------------------
166 |             # Compute the precision and recall
167 |             #-------------------------------------------------------------------
168 |             fps = np.cumsum(fps)
169 |             tps = np.cumsum(tps)
170 |             recall = tps/counts[k]
171 |             prec = tps/(tps+fps)
172 |             ap = 0
173 |             for r_tilde in np.arange(0, 1.1, 0.1):
174 |                 prec_rec = prec[recall>=r_tilde]
175 |                 if len(prec_rec) > 0:
176 |                     ap += np.amax(prec_rec)
177 |
178 |             ap /= 11.
179 |             aps[k] = ap
180 |
181 |         return aps
182 |
183 |     #---------------------------------------------------------------------------
184 |     def clear(self):
185 |         """
186 |         Clear the current detection cache. Useful for restarting the calculation
187 |         for a new batch of data.
188 |         """
189 |         self.det_params = defaultdict(list)
190 |         self.det_confidence = defaultdict(list)
191 |         self.det_sample_ids = defaultdict(list)
192 |         self.gt_boxes = []
193 |
--------------------------------------------------------------------------------
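As a usage sketch of the calculator above, which implements the 11-point
interpolated PASCAL VOC metric (illustrative only — `samples` is a hypothetical
iterable of per-image data, and the `Box` objects come from the rest of this
pipeline):

    from average_precision import APCalculator, APs2mAP

    calc = APCalculator(minoverlap=0.5)        # PASCAL VOC IoU threshold
    for gt_boxes, det_boxes in samples:        # one entry per image
        # det_boxes: list of (confidence, Box) tuples with labels set
        calc.add_detections(gt_boxes, det_boxes)
    aps = calc.compute_aps()                   # maps class label to its AP
    print('mAP: {:.4f}'.format(APs2mAP(aps)))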
/data_queue.py:
--------------------------------------------------------------------------------
1 | #-------------------------------------------------------------------------------
2 | # Author: Lukasz Janyst
3 | # Date: 17.09.2017
4 | #-------------------------------------------------------------------------------
5 | # This file is part of SSD-TensorFlow.
6 | #
7 | # SSD-TensorFlow is free software: you can redistribute it and/or modify
8 | # it under the terms of the GNU General Public License as published by
9 | # the Free Software Foundation, either version 3 of the License, or
10 | # (at your option) any later version.
11 | #
12 | # SSD-TensorFlow is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | # GNU General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU General Public License
18 | # along with SSD-TensorFlow. If not, see <http://www.gnu.org/licenses/>.
19 | #-------------------------------------------------------------------------------
20 |
21 | import queue as q
22 | import numpy as np
23 | import multiprocessing as mp
24 |
25 | #-------------------------------------------------------------------------------
26 | class DataQueue:
27 |     #---------------------------------------------------------------------------
28 |     def __init__(self, img_template, label_template, maxsize):
29 |         #-----------------------------------------------------------------------
30 |         # Figure out the data types, sizes, and shapes of both arrays
31 |         #-----------------------------------------------------------------------
32 |         self.img_dtype = img_template.dtype
33 |         self.img_shape = img_template.shape
34 |         self.img_bc = len(img_template.tobytes())
35 |         self.label_dtype = label_template.dtype
36 |         self.label_shape = label_template.shape
37 |         self.label_bc = len(label_template.tobytes())
38 |
39 |         #-----------------------------------------------------------------------
40 |         # Make an array pool and queue
41 |         #-----------------------------------------------------------------------
42 |         self.array_pool = []
43 |         self.array_queue = mp.Queue(maxsize)
44 |         for i in range(maxsize):
45 |             img_buff = mp.Array('c', self.img_bc, lock=False)
46 |             img_arr = np.frombuffer(img_buff, dtype=self.img_dtype)
47 |             img_arr = img_arr.reshape(self.img_shape)
48 |
49 |             label_buff = mp.Array('c', self.label_bc, lock=False)
50 |             label_arr = np.frombuffer(label_buff, dtype=self.label_dtype)
51 |             label_arr = label_arr.reshape(self.label_shape)
52 |
53 |             self.array_pool.append((img_arr, label_arr))
54 |             self.array_queue.put(i)
55 |
56 |         self.queue = mp.Queue(maxsize)
57 |
58 |     #---------------------------------------------------------------------------
59 |     def put(self, img, label, boxes, *args, **kwargs):
60 |         #-----------------------------------------------------------------------
61 |         # Check whether the params are consistent with the data we can store
62 |         #-----------------------------------------------------------------------
63 |         def check_consistency(name, arr, dtype, shape, byte_count):
64 |             if type(arr) is not np.ndarray:
65 |                 raise ValueError(name + ' needs to be a numpy array')
66 |             if arr.dtype != dtype:
67 |                 raise ValueError('{}\'s elements need to be of type {} but are {}' \
68 |                                  .format(name, str(dtype), str(arr.dtype)))
69 |             if arr.shape != shape:
70 |                 raise ValueError('{}\'s shape needs to be {} but is {}' \
71 |                                  .format(name, shape, arr.shape))
72 |             if len(arr.tobytes()) != byte_count:
73 |                 raise ValueError('{}\'s byte count needs to be {} but is {}' \
74 |                                  .format(name, byte_count, len(arr.tobytes())))
75 |
76 |         check_consistency('img', img, self.img_dtype, self.img_shape,
77 |                           self.img_bc)
78 |         check_consistency('label', label, self.label_dtype, self.label_shape,
79 |                           self.label_bc)
80 |
81 |         #-----------------------------------------------------------------------
82 |         # If we cannot get a slot within the timeout, we are actually full,
83 |         # not empty
84 |         #-----------------------------------------------------------------------
85 |         try:
86 |             arr_id = self.array_queue.get(*args, **kwargs)
87 |         except q.Empty:
88 |             raise q.Full()
89 |
90 |         #-----------------------------------------------------------------------
91 |         # Copy the arrays into the shared pool
92 |         #-----------------------------------------------------------------------
93 |         self.array_pool[arr_id][0][:] = img
94 |         self.array_pool[arr_id][1][:] = label
95 |         self.queue.put((arr_id, boxes), *args, **kwargs)
96 |
97 |     #---------------------------------------------------------------------------
98 |     def get(self, *args, **kwargs):
99 |         item = self.queue.get(*args, **kwargs)
100 |         arr_id = item[0]
101 |         boxes = item[1]
102 |
103 |         img = np.copy(self.array_pool[arr_id][0])
104 |         label = np.copy(self.array_pool[arr_id][1])
105 |
106 |         self.array_queue.put(arr_id)
107 |
108 |         return img, label, boxes
109 |
110 |     #---------------------------------------------------------------------------
111 |     def empty(self):
112 |         return self.queue.empty()
113 |
--------------------------------------------------------------------------------
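The queue above moves large numpy arrays between producer and consumer
processes through a fixed pool of shared-memory buffers. A minimal usage
sketch (the template shapes below are hypothetical — they merely define the
dtype/shape contract that every array pushed into the queue must match):

    import numpy as np
    from data_queue import DataQueue

    img_template = np.zeros((300, 300, 3), dtype=np.float32)
    label_template = np.zeros((8732, 25), dtype=np.float32)
    dq = DataQueue(img_template, label_template, maxsize=10)

    # In a producer process (img/label must match the templates):
    dq.put(img, label, boxes, timeout=1)     # raises queue.Full on timeout
    # In a consumer process:
    img, label, boxes = dq.get(timeout=1)    # copies data out of the pool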
/detect.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #-------------------------------------------------------------------------------
3 | # Author: Lukasz Janyst
4 | # Date: 05.02.2018
5 | #-------------------------------------------------------------------------------
6 | # This file is part of SSD-TensorFlow.
7 | #
8 | # SSD-TensorFlow is free software: you can redistribute it and/or modify
9 | # it under the terms of the GNU General Public License as published by
10 | # the Free Software Foundation, either version 3 of the License, or
11 | # (at your option) any later version.
12 | #
13 | # SSD-TensorFlow is distributed in the hope that it will be useful,
14 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | # GNU General Public License for more details.
17 | #
18 | # You should have received a copy of the GNU General Public License
19 | # along with SSD-TensorFlow. If not, see <http://www.gnu.org/licenses/>.
20 | #-------------------------------------------------------------------------------
21 |
22 | import tensorflow as tf
23 | import argparse
24 | import pickle
25 | import numpy as np
26 | import sys
27 | import cv2
28 | import os
29 |
30 | from ssdutils import get_anchors_for_preset, decode_boxes, suppress_overlaps
31 | from utils import draw_box
32 | from tqdm import tqdm
33 |
34 | if sys.version_info[0] < 3:
35 |     print("This is a Python 3 program. Use Python 3 or higher.")
36 |     sys.exit(1)
37 |
38 | #-------------------------------------------------------------------------------
39 | # Start the show
40 | #-------------------------------------------------------------------------------
41 | def main():
42 |     #---------------------------------------------------------------------------
43 |     # Parse the commandline
44 |     #---------------------------------------------------------------------------
45 |     parser = argparse.ArgumentParser(description='SSD inference')
46 |     parser.add_argument("files", nargs="*")
47 |     parser.add_argument('--model', default='model300.pb',
48 |                         help='model file')
49 |     parser.add_argument('--training-data', default='training-data-300.pkl',
50 |                         help='training data')
51 |     parser.add_argument('--output-dir', default='test-out',
52 |                         help='output directory')
53 |     parser.add_argument('--batch-size', type=int, default=32,
54 |                         help='batch size')
55 |     args = parser.parse_args()
56 |
57 |     #---------------------------------------------------------------------------
58 |     # Print parameters
59 |     #---------------------------------------------------------------------------
60 |     print('[i] Model:         ', args.model)
61 |     print('[i] Training data: ', args.training_data)
62 |     print('[i] Output dir:    ', args.output_dir)
63 |     print('[i] Batch size:    ', args.batch_size)
64 |
65 |     #---------------------------------------------------------------------------
66 |     # Load the graph and the training data
67 |     #---------------------------------------------------------------------------
68 |     graph_def = tf.GraphDef()
69 |     with open(args.model, 'rb') as f:
70 |         serialized = f.read()
71 |         graph_def.ParseFromString(serialized)
72 |
73 |     with open(args.training_data, 'rb') as f:
74 |         data = pickle.load(f)
75 |         preset = data['preset']
76 |         colors = data['colors']
77 |         lid2name = data['lid2name']
78 |         anchors = get_anchors_for_preset(preset)
79 |
80 |     #---------------------------------------------------------------------------
81 |     # Create the output directory
82 |     #---------------------------------------------------------------------------
83 |     if not os.path.exists(args.output_dir):
84 |         os.makedirs(args.output_dir)
85 |
86 |     #---------------------------------------------------------------------------
87 |     # Run the detections in batches
88 |     #---------------------------------------------------------------------------
89 |     with tf.Session() as sess:
90 |         tf.import_graph_def(graph_def, name='detector')
91 |         img_input = sess.graph.get_tensor_by_name('detector/image_input:0')
92 |         result = sess.graph.get_tensor_by_name('detector/result/result:0')
93 |
94 |         files = args.files
95 |
96 |         for i in tqdm(range(0, len(files), args.batch_size)):
97 |             batch_names = files[i:i+args.batch_size]
98 |             batch_imgs = []
99 |             batch = []
100 |             for f in batch_names:
101 |                 img = cv2.imread(f)
102 |                 batch_imgs.append(img)
103 |                 img = cv2.resize(img, (300, 300))
104 |                 batch.append(img)
105 |
106 |             batch = np.array(batch)
107 |             feed = {img_input: batch}
108 |             enc_boxes = sess.run(result, feed_dict=feed)
109 |
110 |             for j in range(len(batch_names)):
111 |                 boxes = decode_boxes(enc_boxes[j], anchors, 0.5, lid2name, None)
112 |                 boxes = suppress_overlaps(boxes)[:200]
113 |                 name = os.path.basename(batch_names[j])
114 |
115 |                 with open(os.path.join(args.output_dir, name+'.txt'), 'w') as f:
116 |                     for box in boxes:
117 |                         draw_box(batch_imgs[j], box[1], colors[box[1].label])
118 |
119 |                         box_data = '{} {} {} {} {} {}\n'.format(box[1].label,
120 |                             box[1].labelid, box[1].center.x, box[1].center.y,
box[1].center.y, 121 | box[1].size.w, box[1].size.h) 122 | f.write(box_data) 123 | 124 | cv2.imwrite(os.path.join(args.output_dir, name), 125 | batch_imgs[i]) 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /export_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #------------------------------------------------------------------------------- 3 | # Author: Lukasz Janyst 4 | # Date: 27.09.2017 5 | #------------------------------------------------------------------------------- 6 | # This file is part of SSD-TensorFlow. 7 | # 8 | # SSD-TensorFlow is free software: you can redistribute it and/or modify 9 | # it under the terms of the GNU General Public License as published by 10 | # the Free Software Foundation, either version 3 of the License, or 11 | # (at your option) any later version. 12 | # 13 | # SSD-TensorFlow is distributed in the hope that it will be useful, 14 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | # GNU General Public License for more details. 17 | # 18 | # You should have received a copy of the GNU General Public License 19 | # along with SSD-Tensorflow. If not, see . 20 | #------------------------------------------------------------------------------- 21 | 22 | import argparse 23 | import sys 24 | import os 25 | 26 | import tensorflow as tf 27 | 28 | from tensorflow.python.framework import graph_util 29 | 30 | if sys.version_info[0] < 3: 31 | print("This is a Python 3 program. Use Python 3 or higher.") 32 | sys.exit(1) 33 | 34 | #--------------------------------------------------------------------------- 35 | # Parse the commandline 36 | #--------------------------------------------------------------------------- 37 | parser = argparse.ArgumentParser(description='Export a tensorflow model') 38 | parser.add_argument('--metagraph-file', default='final.ckpt.meta', 39 | help='name of the metagraph file') 40 | parser.add_argument('--checkpoint-file', default='final.ckpt', 41 | help='name of the checkpoint file') 42 | parser.add_argument('--output-file', default='model.pb', 43 | help='name of the output file') 44 | parser.add_argument('--output-tensors', nargs='+', 45 | required=True, 46 | help='names of the output tensors') 47 | args = parser.parse_args() 48 | 49 | print('[i] Matagraph file: ', args.metagraph_file) 50 | print('[i] Checkpoint file: ', args.checkpoint_file) 51 | print('[i] Output file: ', args.output_file) 52 | print('[i] Output tensors: ', args.output_tensors) 53 | 54 | for f in [args.checkpoint_file+'.index', args.metagraph_file]: 55 | if not os.path.exists(f): 56 | print('[!] 
57 |         sys.exit(1)
58 | 
59 | #-------------------------------------------------------------------------------
60 | # Export the graph
61 | #-------------------------------------------------------------------------------
62 | with tf.Session() as sess:
63 |     saver = tf.train.import_meta_graph(args.metagraph_file)
64 |     saver.restore(sess, args.checkpoint_file)
65 | 
66 |     graph = tf.get_default_graph()
67 |     input_graph_def = graph.as_graph_def()
68 |     output_graph_def = graph_util.convert_variables_to_constants(
69 |         sess, input_graph_def, args.output_tensors)
70 | 
71 |     with open(args.output_file, "wb") as f:
72 |         f.write(output_graph_def.SerializeToString())
73 | 
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #-------------------------------------------------------------------------------
3 | # Author: Lukasz Janyst
4 | # Date: 09.09.2017
5 | #-------------------------------------------------------------------------------
6 | # This file is part of SSD-TensorFlow.
7 | #
8 | # SSD-TensorFlow is free software: you can redistribute it and/or modify
9 | # it under the terms of the GNU General Public License as published by
10 | # the Free Software Foundation, either version 3 of the License, or
11 | # (at your option) any later version.
12 | #
13 | # SSD-TensorFlow is distributed in the hope that it will be useful,
14 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | # GNU General Public License for more details.
17 | #
18 | # You should have received a copy of the GNU General Public License
19 | # along with SSD-Tensorflow. If not, see <http://www.gnu.org/licenses/>.
20 | #-------------------------------------------------------------------------------
21 | 
22 | import argparse
23 | import pickle
24 | import math
25 | import sys
26 | import cv2
27 | import os
28 | 
29 | import tensorflow as tf
30 | import numpy as np
31 | 
32 | from average_precision import APCalculator, APs2mAP
33 | from pascal_summary import PascalSummary
34 | from ssdutils import get_anchors_for_preset, decode_boxes, suppress_overlaps
35 | from ssdvgg import SSDVGG
36 | from utils import str2bool, load_data_source, draw_box
37 | from tqdm import tqdm
38 | 
39 | if sys.version_info[0] < 3:
40 |     print("This is a Python 3 program. Use Python 3 or higher.")
41 |     sys.exit(1)
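#-------------------------------------------------------------------------------
# Example invocation, assuming training left its checkpoints in a directory
# named after the project ('test' by default) and that process_dataset.py has
# produced pascal-voc/training-data.pkl; all values shown are the defaults:
#
#   ./infer.py --name test --data-source pascal_voc --sample test \
#              --annotate True --compute-stats True
#-------------------------------------------------------------------------------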
42 | 
43 | #-------------------------------------------------------------------------------
44 | def sample_generator(samples, image_size, batch_size):
45 |     image_size = (image_size.w, image_size.h)
46 |     for offset in range(0, len(samples), batch_size):
47 |         files = samples[offset:offset+batch_size]
48 |         images = []
49 |         idxs = []
50 |         for i, image_file in enumerate(files):
51 |             image = cv2.resize(cv2.imread(image_file), image_size)
52 |             images.append(image.astype(np.float32))
53 |             idxs.append(offset+i)
54 |         yield np.array(images), idxs
55 | 
56 | #-------------------------------------------------------------------------------
57 | def main():
58 |     #---------------------------------------------------------------------------
59 |     # Parse commandline
60 |     #---------------------------------------------------------------------------
61 |     parser = argparse.ArgumentParser(description='SSD inference')
62 |     parser.add_argument("files", nargs="*")
63 |     parser.add_argument('--name', default='test',
64 |                         help='project name')
65 |     parser.add_argument('--checkpoint', type=int, default=-1,
66 |                         help='checkpoint to restore; -1 is the most recent')
67 |     parser.add_argument('--training-data',
68 |                         default='pascal-voc/training-data.pkl',
69 |                         help='Information about parameters used for training')
70 |     parser.add_argument('--output-dir', default='test-output',
71 |                         help='directory for the resulting images')
72 |     parser.add_argument('--annotate', type=str2bool, default='False',
73 |                         help="Annotate the data samples")
74 |     parser.add_argument('--dump-predictions', type=str2bool, default='False',
75 |                         help="Dump raw predictions")
76 |     parser.add_argument('--compute-stats', type=str2bool, default='True',
77 |                         help="Compute the mAP stats")
78 |     parser.add_argument('--data-source', default=None,
79 |                         help='Use test files from the data source')
80 |     parser.add_argument('--data-dir', default='pascal-voc',
81 |                         help='data directory')
82 |     parser.add_argument('--batch-size', type=int, default=32,
83 |                         help='batch size')
84 |     parser.add_argument('--sample', default='test',
85 |                         choices=['test', 'trainval'], help='sample to run on')
86 |     parser.add_argument('--threshold', type=float, default=0.5,
87 |                         help='confidence threshold')
88 |     parser.add_argument('--pascal-summary', type=str2bool, default='False',
89 |                         help='dump the detections in Pascal VOC format')
90 | 
91 |     args = parser.parse_args()
92 | 
93 |     #---------------------------------------------------------------------------
94 |     # Print parameters
95 |     #---------------------------------------------------------------------------
96 |     print('[i] Project name: ', args.name)
97 |     print('[i] Training data: ', args.training_data)
98 |     print('[i] Batch size: ', args.batch_size)
99 |     print('[i] Data source: ', args.data_source)
100 |     print('[i] Data directory: ', args.data_dir)
101 |     print('[i] Output directory: ', args.output_dir)
102 |     print('[i] Annotate: ', args.annotate)
103 |     print('[i] Dump predictions: ', args.dump_predictions)
104 |     print('[i] Sample: ', args.sample)
105 |     print('[i] Threshold: ', args.threshold)
106 |     print('[i] Pascal summary: ', args.pascal_summary)
107 | 
108 |     #---------------------------------------------------------------------------
109 |     # Check if we can get the checkpoint
110 |     #---------------------------------------------------------------------------
111 |     state = tf.train.get_checkpoint_state(args.name)
112 |     if state is None:
113 |         print('[!] No network state found in ' + args.name)
114 |         return 1
115 | 
116 |     try:
117 |         checkpoint_file = state.all_model_checkpoint_paths[args.checkpoint]
118 |     except IndexError:
119 |         print('[!] Cannot find checkpoint ' + str(args.checkpoint))
120 |         return 1
121 | 
122 |     metagraph_file = checkpoint_file + '.meta'
123 | 
124 |     if not os.path.exists(metagraph_file):
125 |         print('[!] Cannot find metagraph ' + metagraph_file)
126 |         return 1
127 | 
128 |     #---------------------------------------------------------------------------
129 |     # Load the training data
130 |     #---------------------------------------------------------------------------
131 |     try:
132 |         with open(args.training_data, 'rb') as f:
133 |             data = pickle.load(f)
134 |         preset = data['preset']
135 |         colors = data['colors']
136 |         lid2name = data['lid2name']
137 |         num_classes = data['num-classes']
138 |         image_size = preset.image_size
139 |         anchors = get_anchors_for_preset(preset)
140 |     except (FileNotFoundError, IOError, KeyError) as e:
141 |         print('[!] Unable to load training data:', str(e))
142 |         return 1
143 | 
144 |     #---------------------------------------------------------------------------
145 |     # Load the data source if defined
146 |     #---------------------------------------------------------------------------
147 |     compute_stats = False
148 |     source = None
149 |     if args.data_source:
150 |         print('[i] Configuring the data source...')
151 |         try:
152 |             source = load_data_source(args.data_source)
153 |             if args.sample == 'test':
154 |                 source.load_test_data(args.data_dir)
155 |                 num_samples = source.num_test
156 |                 samples = source.test_samples
157 |             else:
158 |                 source.load_trainval_data(args.data_dir, 0)
159 |                 num_samples = source.num_train
160 |                 samples = source.train_samples
161 |             print('[i] # samples: ', num_samples)
162 |             print('[i] # classes: ', source.num_classes)
163 |         except (ImportError, AttributeError, RuntimeError) as e:
164 |             print('[!] Unable to load data source:', str(e))
165 |             return 1
166 | 
167 |         if args.compute_stats:
168 |             compute_stats = True
169 | 
170 |     #---------------------------------------------------------------------------
171 |     # Create a list of files to analyse and make sure that the output directory
172 |     # exists
173 |     #---------------------------------------------------------------------------
174 |     files = []
175 | 
176 |     if source:
177 |         for sample in samples:
178 |             files.append(sample.filename)
179 | 
180 |     if not source:
181 |         if args.files:
182 |             files = args.files
183 | 
184 |     if not files:
185 |         print('[!] 
No files specified') 186 | return 1 187 | 188 | files = list(filter(lambda x: os.path.exists(x), files)) 189 | if files: 190 | if not os.path.exists(args.output_dir): 191 | os.makedirs(args.output_dir) 192 | 193 | #--------------------------------------------------------------------------- 194 | # Print model and dataset stats 195 | #--------------------------------------------------------------------------- 196 | print('[i] Compute stats: ', compute_stats) 197 | print('[i] Network checkpoint:', checkpoint_file) 198 | print('[i] Metagraph file: ', metagraph_file) 199 | print('[i] Image size: ', image_size) 200 | print('[i] Number of files: ', len(files)) 201 | 202 | #--------------------------------------------------------------------------- 203 | # Create the network 204 | #--------------------------------------------------------------------------- 205 | if compute_stats: 206 | ap_calc = APCalculator() 207 | 208 | if args.pascal_summary: 209 | pascal_summary = PascalSummary() 210 | 211 | with tf.Session() as sess: 212 | print('[i] Creating the model...') 213 | net = SSDVGG(sess, preset) 214 | net.build_from_metagraph(metagraph_file, checkpoint_file) 215 | 216 | #----------------------------------------------------------------------- 217 | # Process the images 218 | #----------------------------------------------------------------------- 219 | generator = sample_generator(files, image_size, args.batch_size) 220 | n_sample_batches = int(math.ceil(len(files)/args.batch_size)) 221 | description = '[i] Processing samples' 222 | 223 | for x, idxs in tqdm(generator, total=n_sample_batches, 224 | desc=description, unit='batches'): 225 | feed = {net.image_input: x, 226 | net.keep_prob: 1} 227 | enc_boxes = sess.run(net.result, feed_dict=feed) 228 | 229 | #------------------------------------------------------------------- 230 | # Process the predictions 231 | #------------------------------------------------------------------- 232 | for i in range(enc_boxes.shape[0]): 233 | boxes = decode_boxes(enc_boxes[i], anchors, args.threshold, 234 | lid2name, None) 235 | boxes = suppress_overlaps(boxes)[:200] 236 | filename = files[idxs[i]] 237 | basename = os.path.basename(filename) 238 | 239 | #--------------------------------------------------------------- 240 | # Annotate samples 241 | #--------------------------------------------------------------- 242 | if args.annotate: 243 | img = cv2.imread(filename) 244 | for box in boxes: 245 | draw_box(img, box[1], colors[box[1].label]) 246 | fn = args.output_dir+'/'+basename 247 | cv2.imwrite(fn, img) 248 | 249 | #--------------------------------------------------------------- 250 | # Dump the predictions 251 | #--------------------------------------------------------------- 252 | if args.dump_predictions: 253 | raw_fn = args.output_dir+'/'+basename+'.npy' 254 | np.save(raw_fn, enc_boxes[i]) 255 | 256 | #--------------------------------------------------------------- 257 | # Add predictions to the stats calculator and to the Pascal 258 | # summary 259 | #--------------------------------------------------------------- 260 | if compute_stats: 261 | ap_calc.add_detections(samples[idxs[i]].boxes, boxes) 262 | 263 | if args.pascal_summary: 264 | pascal_summary.add_detections(filename, boxes) 265 | 266 | #--------------------------------------------------------------------------- 267 | # Compute and print the stats 268 | #--------------------------------------------------------------------------- 269 | if compute_stats: 270 | aps = ap_calc.compute_aps() 271 | for k, v 
in aps.items(): 272 | print('[i] AP [{0}]: {1:.3f}'.format(k, v)) 273 | print('[i] mAP: {0:.3f}'.format(APs2mAP(aps))) 274 | 275 | #--------------------------------------------------------------------------- 276 | # Write the pascal summary files 277 | #--------------------------------------------------------------------------- 278 | if args.pascal_summary: 279 | pascal_summary.write_summary(args.output_dir) 280 | 281 | print('[i] All done.') 282 | return 0 283 | 284 | if __name__ == '__main__': 285 | sys.exit(main()) 286 | -------------------------------------------------------------------------------- /pascal-voc/download-data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget -c http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar 4 | wget -c http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar 5 | wget -c http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 6 | 7 | mkdir -p trainval 8 | mkdir -p test 9 | 10 | (cd trainval && tar xf ../VOCtrainval_06-Nov-2007.tar) 11 | (cd trainval && tar xf ../VOCtrainval_11-May-2012.tar) 12 | (cd test && tar xf ../VOCtest_06-Nov-2007.tar) 13 | -------------------------------------------------------------------------------- /pascal_summary.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------- 2 | # Author: Lukasz Janyst 3 | # Date: 31.10.2017 4 | #------------------------------------------------------------------------------- 5 | # This file is part of SSD-TensorFlow. 6 | # 7 | # SSD-TensorFlow is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 11 | # 12 | # SSD-TensorFlow is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU General Public License 18 | # along with SSD-Tensorflow. If not, see . 
19 | #------------------------------------------------------------------------------- 20 | 21 | import cv2 22 | import os 23 | 24 | from collections import defaultdict, namedtuple 25 | from utils import prop2abs, Size 26 | 27 | #------------------------------------------------------------------------------- 28 | Detection = namedtuple('Detection', ['fileid', 'confidence', 'left', 'top', 29 | 'right', 'bottom']) 30 | 31 | #------------------------------------------------------------------------------- 32 | class PascalSummary: 33 | #--------------------------------------------------------------------------- 34 | def __init__(self): 35 | self.boxes = defaultdict(list) 36 | 37 | #--------------------------------------------------------------------------- 38 | def add_detections(self, filename, boxes): 39 | fileid = os.path.basename(filename) 40 | fileid = ''.join(fileid.split('.')[:-1]) 41 | img = cv2.imread(filename) 42 | img_size = Size(img.shape[1], img.shape[0]) 43 | for conf, box in boxes: 44 | xmin, xmax, ymin, ymax = prop2abs(box.center, box.size, img_size) 45 | if xmin < 0: xmin = 0 46 | if xmin >= img_size.w: xmin = img_size.w-1 47 | if xmax < 0: xmax = 0 48 | if xmax >= img_size.w: xmax = img_size.w-1 49 | if ymin < 0: ymin = 0 50 | if ymin >= img_size.h: ymin = img_size.h-1 51 | if ymax < 0: ymax = 0 52 | if ymax >= img_size.h: ymax = img_size.h-1 53 | det = Detection(fileid, conf, float(xmin+1), float(ymin+1), float(xmax+1), float(ymax+1)) 54 | self.boxes[box.label].append(det) 55 | 56 | #--------------------------------------------------------------------------- 57 | def write_summary(self, target_dir): 58 | for k, v in self.boxes.items(): 59 | filename = target_dir+'/comp4_det_test_'+k+'.txt' 60 | with open(filename, 'w') as f: 61 | for det in v: 62 | line = "{} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f}\n" \ 63 | .format(det.fileid, det.confidence, det.left, det.top, 64 | det.right, det.bottom) 65 | f.write(line) 66 | -------------------------------------------------------------------------------- /process_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #------------------------------------------------------------------------------- 3 | # Author: Lukasz Janyst 4 | # Date: 29.08.2017 5 | #------------------------------------------------------------------------------- 6 | # This file is part of SSD-TensorFlow. 7 | # 8 | # SSD-TensorFlow is free software: you can redistribute it and/or modify 9 | # it under the terms of the GNU General Public License as published by 10 | # the Free Software Foundation, either version 3 of the License, or 11 | # (at your option) any later version. 12 | # 13 | # SSD-TensorFlow is distributed in the hope that it will be useful, 14 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | # GNU General Public License for more details. 17 | # 18 | # You should have received a copy of the GNU General Public License 19 | # along with SSD-Tensorflow. If not, see . 20 | #------------------------------------------------------------------------------- 21 | 22 | import argparse 23 | import pickle 24 | import sys 25 | import cv2 26 | import os 27 | 28 | import numpy as np 29 | 30 | from transforms import * 31 | from ssdutils import get_preset_by_name 32 | from utils import load_data_source, str2bool, draw_box 33 | from tqdm import tqdm 34 | 35 | if sys.version_info[0] < 3: 36 | print("This is a Python 3 program. 
Use Python 3 or higher.")
37 |     sys.exit(1)
38 | 
39 | #-------------------------------------------------------------------------------
40 | def annotate(data_dir, samples, colors, sample_name):
41 |     """
42 |     Draw the bounding boxes on the sample images
43 |     :param data_dir:    the directory where the dataset's files are stored
44 |     :param samples:     samples to be processed
45 |     :param colors:      a dictionary mapping class name to a BGR color tuple
46 |     :param sample_name: name of the sample
47 |     """
48 |     result_dir = data_dir+'/annotated/'+sample_name.strip()+'/'
49 |     if not os.path.exists(result_dir):
50 |         os.makedirs(result_dir)
51 | 
52 |     for sample in tqdm(samples, desc=sample_name, unit='samples'):
53 |         img = cv2.imread(sample.filename)
54 |         basefn = os.path.basename(sample.filename)
55 |         for box in sample.boxes:
56 |             draw_box(img, box, colors[box.label])
57 |         cv2.imwrite(result_dir+basefn, img)
58 | 
59 | #-------------------------------------------------------------------------------
60 | def build_sampler(overlap, trials):
61 |     return SamplerTransform(sample=True, min_scale=0.3, max_scale=1.0,
62 |                             min_aspect_ratio=0.5, max_aspect_ratio=2.0,
63 |                             min_jaccard_overlap=overlap, max_trials=trials)
64 | 
65 | #-------------------------------------------------------------------------------
66 | def build_train_transforms(preset, num_classes, sampler_trials, expand_prob):
67 |     #---------------------------------------------------------------------------
68 |     # Resizing
69 |     #---------------------------------------------------------------------------
70 |     tf_resize = ResizeTransform(width=preset.image_size.w,
71 |                                 height=preset.image_size.h,
72 |                                 algorithms=[cv2.INTER_LINEAR,
73 |                                             cv2.INTER_AREA,
74 |                                             cv2.INTER_NEAREST,
75 |                                             cv2.INTER_CUBIC,
76 |                                             cv2.INTER_LANCZOS4])
77 | 
78 |     #---------------------------------------------------------------------------
79 |     # Image distortions
80 |     #---------------------------------------------------------------------------
81 |     tf_brightness = BrightnessTransform(delta=32)
82 |     tf_rnd_brightness = RandomTransform(prob=0.5, transform=tf_brightness)
83 | 
84 |     tf_contrast = ContrastTransform(lower=0.5, upper=1.5)
85 |     tf_rnd_contrast = RandomTransform(prob=0.5, transform=tf_contrast)
86 | 
87 |     tf_hue = HueTransform(delta=18)
88 |     tf_rnd_hue = RandomTransform(prob=0.5, transform=tf_hue)
89 | 
90 |     tf_saturation = SaturationTransform(lower=0.5, upper=1.5)
91 |     tf_rnd_saturation = RandomTransform(prob=0.5, transform=tf_saturation)
92 | 
93 |     tf_reorder_channels = ReorderChannelsTransform()
94 |     tf_rnd_reorder_channels = RandomTransform(prob=0.5,
95 |                                               transform=tf_reorder_channels)
96 | 
97 |     #---------------------------------------------------------------------------
98 |     # Compositions of image distortions: contrast is applied either before
99 |     # (first composition) or after (second) the saturation/hue adjustments
100 |     tf_distort_lst = [
101 |         tf_rnd_contrast,
102 |         tf_rnd_saturation,
103 |         tf_rnd_hue,
104 |         tf_rnd_contrast
105 |     ]
106 |     tf_distort_1 = ComposeTransform(transforms=tf_distort_lst[:-1])
107 |     tf_distort_2 = ComposeTransform(transforms=tf_distort_lst[1:])
108 |     tf_distort_comp = [tf_distort_1, tf_distort_2]
109 |     tf_distort = TransformPickerTransform(transforms=tf_distort_comp)
110 | 
111 |     #---------------------------------------------------------------------------
112 |     # Expand sample
113 |     #---------------------------------------------------------------------------
114 |     tf_expand = ExpandTransform(max_ratio=4.0, mean_value=[104, 117, 123])
115 |     tf_rnd_expand = RandomTransform(prob=expand_prob, transform=tf_expand)
116 | 
117 |     #---------------------------------------------------------------------------
118 |     # Samplers
119 |     #---------------------------------------------------------------------------
120 |     samplers = [
121 |         SamplerTransform(sample=False),
122 |         build_sampler(0.1, sampler_trials),
123 |         build_sampler(0.3, sampler_trials),
124 |         build_sampler(0.5, sampler_trials),
125 |         build_sampler(0.7, sampler_trials),
126 |         build_sampler(0.9, sampler_trials),
127 |         build_sampler(1.0, sampler_trials)
128 |     ]
129 |     tf_sample_picker = SamplePickerTransform(samplers=samplers)
130 | 
131 |     #---------------------------------------------------------------------------
132 |     # Horizontal flip
133 |     #---------------------------------------------------------------------------
134 |     tf_flip = HorizontalFlipTransform()
135 |     tf_rnd_flip = RandomTransform(prob=0.5, transform=tf_flip)
136 | 
137 |     #---------------------------------------------------------------------------
138 |     # Transform list
139 |     #---------------------------------------------------------------------------
140 |     transforms = [
141 |         ImageLoaderTransform(),
142 |         tf_rnd_brightness,
143 |         tf_distort,
144 |         tf_rnd_reorder_channels,
145 |         tf_rnd_expand,
146 |         tf_sample_picker,
147 |         tf_rnd_flip,
148 |         LabelCreatorTransform(preset=preset, num_classes=num_classes),
149 |         tf_resize
150 |     ]
151 |     return transforms
152 | 
153 | #-------------------------------------------------------------------------------
154 | def build_valid_transforms(preset, num_classes):
155 |     tf_resize = ResizeTransform(width=preset.image_size.w,
156 |                                 height=preset.image_size.h,
157 |                                 algorithms=[cv2.INTER_LINEAR])
158 |     transforms = [
159 |         ImageLoaderTransform(),
160 |         LabelCreatorTransform(preset=preset, num_classes=num_classes),
161 |         tf_resize
162 |     ]
163 |     return transforms
164 | 
165 | #-------------------------------------------------------------------------------
166 | def main():
167 |     #---------------------------------------------------------------------------
168 |     # Parse the commandline
169 |     #---------------------------------------------------------------------------
170 |     parser = argparse.ArgumentParser(description='Process a dataset for SSD')
171 |     parser.add_argument('--data-source', default='pascal_voc',
172 |                         help='data source')
173 |     parser.add_argument('--data-dir', default='pascal-voc',
174 |                         help='data directory')
175 |     parser.add_argument('--validation-fraction', type=float, default=0.025,
176 |                         help='fraction of the data to be used for validation')
177 |     parser.add_argument('--expand-probability', type=float, default=0.5,
178 |                         help='probability of running sample expander')
179 |     parser.add_argument('--sampler-trials', type=int, default=50,
180 |                         help='number of times a sampler tries to find a sample')
181 |     parser.add_argument('--annotate', type=str2bool, default='False',
182 |                         help="Annotate the data samples")
183 |     parser.add_argument('--compute-td', type=str2bool, default='True',
184 |                         help="Compute training data")
185 |     parser.add_argument('--preset', default='vgg300',
186 |                         choices=['vgg300', 'vgg512'],
187 |                         help="The neural network preset")
188 |     parser.add_argument('--process-test', type=str2bool, default='False',
189 |                         help="process the test dataset")
190 |     args = parser.parse_args()
191 | 
192 |     print('[i] Data source: ', args.data_source)
193 |     print('[i] Data directory: ', args.data_dir)
194 |     print('[i] Validation fraction: ', args.validation_fraction)
195 |     print('[i] Expand 
probability: ', args.expand_probability) 196 | print('[i] Sampler trials: ', args.sampler_trials) 197 | print('[i] Annotate: ', args.annotate) 198 | print('[i] Compute training data:', args.compute_td) 199 | print('[i] Preset: ', args.preset) 200 | print('[i] Process test dataset: ', args.process_test) 201 | 202 | #--------------------------------------------------------------------------- 203 | # Load the data source 204 | #--------------------------------------------------------------------------- 205 | print('[i] Configuring the data source...') 206 | try: 207 | source = load_data_source(args.data_source) 208 | source.load_trainval_data(args.data_dir, args.validation_fraction) 209 | if args.process_test: 210 | source.load_test_data(args.data_dir) 211 | print('[i] # training samples: ', source.num_train) 212 | print('[i] # validation samples: ', source.num_valid) 213 | print('[i] # testing samples: ', source.num_test) 214 | print('[i] # classes: ', source.num_classes) 215 | except (ImportError, AttributeError, RuntimeError) as e: 216 | print('[!] Unable to load data source:', str(e)) 217 | return 1 218 | 219 | #--------------------------------------------------------------------------- 220 | # Annotate samples 221 | #--------------------------------------------------------------------------- 222 | if args.annotate: 223 | print('[i] Annotating samples...') 224 | annotate(args.data_dir, source.train_samples, source.colors, 'train') 225 | annotate(args.data_dir, source.valid_samples, source.colors, 'valid') 226 | if args.process_test: 227 | annotate(args.data_dir, source.test_samples, source.colors, 'test ') 228 | 229 | #--------------------------------------------------------------------------- 230 | # Compute the training data 231 | #--------------------------------------------------------------------------- 232 | if args.compute_td: 233 | preset = get_preset_by_name(args.preset) 234 | with open(args.data_dir+'/train-samples.pkl', 'wb') as f: 235 | pickle.dump(source.train_samples, f) 236 | with open(args.data_dir+'/valid-samples.pkl', 'wb') as f: 237 | pickle.dump(source.valid_samples, f) 238 | 239 | with open(args.data_dir+'/training-data.pkl', 'wb') as f: 240 | data = { 241 | 'preset': preset, 242 | 'num-classes': source.num_classes, 243 | 'colors': source.colors, 244 | 'lid2name': source.lid2name, 245 | 'lname2id': source.lname2id, 246 | 'train-transforms': build_train_transforms(preset, 247 | source.num_classes, args.sampler_trials, 248 | args.expand_probability ), 249 | 'valid-transforms': build_valid_transforms(preset, 250 | source.num_classes) 251 | } 252 | pickle.dump(data, f) 253 | 254 | return 0 255 | 256 | if __name__ == '__main__': 257 | sys.exit(main()) 258 | -------------------------------------------------------------------------------- /source_pascal_voc.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------- 2 | # Author: Lukasz Janyst 3 | # Date: 30.08.2017 4 | #------------------------------------------------------------------------------- 5 | # This file is part of SSD-TensorFlow. 6 | # 7 | # SSD-TensorFlow is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 
11 | # 12 | # SSD-TensorFlow is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU General Public License 18 | # along with SSD-Tensorflow. If not, see . 19 | #------------------------------------------------------------------------------- 20 | 21 | import lxml.etree 22 | import random 23 | import math 24 | import cv2 25 | import os 26 | 27 | import numpy as np 28 | 29 | from utils import Label, Box, Sample, Size 30 | from utils import rgb2bgr, abs2prop 31 | from glob import glob 32 | from tqdm import tqdm 33 | 34 | #------------------------------------------------------------------------------- 35 | # Labels 36 | #------------------------------------------------------------------------------- 37 | label_defs = [ 38 | Label('aeroplane', rgb2bgr((0, 0, 0))), 39 | Label('bicycle', rgb2bgr((111, 74, 0))), 40 | Label('bird', rgb2bgr(( 81, 0, 81))), 41 | Label('boat', rgb2bgr((128, 64, 128))), 42 | Label('bottle', rgb2bgr((244, 35, 232))), 43 | Label('bus', rgb2bgr((230, 150, 140))), 44 | Label('car', rgb2bgr(( 70, 70, 70))), 45 | Label('cat', rgb2bgr((102, 102, 156))), 46 | Label('chair', rgb2bgr((190, 153, 153))), 47 | Label('cow', rgb2bgr((150, 120, 90))), 48 | Label('diningtable', rgb2bgr((153, 153, 153))), 49 | Label('dog', rgb2bgr((250, 170, 30))), 50 | Label('horse', rgb2bgr((220, 220, 0))), 51 | Label('motorbike', rgb2bgr((107, 142, 35))), 52 | Label('person', rgb2bgr(( 52, 151, 52))), 53 | Label('pottedplant', rgb2bgr(( 70, 130, 180))), 54 | Label('sheep', rgb2bgr((220, 20, 60))), 55 | Label('sofa', rgb2bgr(( 0, 0, 142))), 56 | Label('train', rgb2bgr(( 0, 0, 230))), 57 | Label('tvmonitor', rgb2bgr((119, 11, 32)))] 58 | 59 | #------------------------------------------------------------------------------- 60 | class PascalVOCSource: 61 | #--------------------------------------------------------------------------- 62 | def __init__(self): 63 | self.num_classes = len(label_defs) 64 | self.colors = {l.name: l.color for l in label_defs} 65 | self.lid2name = {i: l.name for i, l in enumerate(label_defs)} 66 | self.lname2id = {l.name: i for i, l in enumerate(label_defs)} 67 | self.num_train = 0 68 | self.num_valid = 0 69 | self.num_test = 0 70 | self.train_samples = [] 71 | self.valid_samples = [] 72 | self.test_samples = [] 73 | 74 | #--------------------------------------------------------------------------- 75 | def __build_annotation_list(self, root, dataset_type): 76 | """ 77 | Build a list of samples for the VOC dataset (either trainval or test) 78 | """ 79 | annot_root = root + '/Annotations/' 80 | annot_files = [] 81 | with open(root + '/ImageSets/Main/' + dataset_type + '.txt') as f: 82 | for line in f: 83 | annot_file = annot_root + line.strip() + '.xml' 84 | if os.path.exists(annot_file): 85 | annot_files.append(annot_file) 86 | return annot_files 87 | 88 | #--------------------------------------------------------------------------- 89 | def __build_sample_list(self, root, annot_files, dataset_name): 90 | """ 91 | Build a list of samples for the VOC dataset (either trainval or test) 92 | """ 93 | image_root = root + '/JPEGImages/' 94 | samples = [] 95 | 96 | #----------------------------------------------------------------------- 97 | # Process each annotated sample 98 | #----------------------------------------------------------------------- 
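        # Each annotation file contributes at most one Sample: images that are
        # missing on disk or that contain no object boxes are skipped, and every
        # box is converted from absolute pixel corners to proportional
        # center/size coordinates (abs2prop), so the rest of the pipeline is
        # independent of the image resolution.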
99 |         for fn in tqdm(annot_files, desc=dataset_name, unit='samples'):
100 |             with open(fn, 'r') as f:
101 |                 doc = lxml.etree.parse(f)
102 |                 filename = image_root+doc.xpath('/annotation/filename')[0].text
103 | 
104 |                 #---------------------------------------------------------------
105 |                 # Get the file dimensions
106 |                 #---------------------------------------------------------------
107 |                 if not os.path.exists(filename):
108 |                     continue
109 | 
110 |                 img = cv2.imread(filename)
111 |                 imgsize = Size(img.shape[1], img.shape[0])
112 | 
113 |                 #---------------------------------------------------------------
114 |                 # Get boxes for all the objects
115 |                 #---------------------------------------------------------------
116 |                 boxes = []
117 |                 objects = doc.xpath('/annotation/object')
118 |                 for obj in objects:
119 |                     #-----------------------------------------------------------
120 |                     # Get the properties of the box and convert them to the
121 |                     # proportional terms
122 |                     #-----------------------------------------------------------
123 |                     label = obj.xpath('name')[0].text
124 |                     xmin = int(float(obj.xpath('bndbox/xmin')[0].text))
125 |                     xmax = int(float(obj.xpath('bndbox/xmax')[0].text))
126 |                     ymin = int(float(obj.xpath('bndbox/ymin')[0].text))
127 |                     ymax = int(float(obj.xpath('bndbox/ymax')[0].text))
128 |                     center, size = abs2prop(xmin, xmax, ymin, ymax, imgsize)
129 |                     box = Box(label, self.lname2id[label], center, size)
130 |                     boxes.append(box)
131 |                 if not boxes:
132 |                     continue
133 |                 sample = Sample(filename, boxes, imgsize)
134 |                 samples.append(sample)
135 | 
136 |         return samples
137 | 
138 |     #---------------------------------------------------------------------------
139 |     def load_trainval_data(self, data_dir, valid_fraction):
140 |         """
141 |         Load the training and validation data
142 |         :param data_dir:       the directory where the dataset's files are stored
143 |         :param valid_fraction: what fraction of the dataset should be used
144 |                                as a validation sample
145 |         """
146 | 
147 |         #-----------------------------------------------------------------------
148 |         # Process the samples defined in the relevant file lists
149 |         #-----------------------------------------------------------------------
150 |         train_annot = []
151 |         train_samples = []
152 |         for vocid in ['VOC2007', 'VOC2012']:
153 |             root = data_dir + '/trainval/VOCdevkit/'+vocid
154 |             name = 'trainval_'+vocid
155 |             annot = self.__build_annotation_list(root, 'trainval')
156 |             train_annot += annot
157 |             train_samples += self.__build_sample_list(root, annot, name)
158 | 
159 |         root = data_dir + '/test/VOCdevkit/VOC2007'
160 |         annot = self.__build_annotation_list(root, 'test')
161 |         train_samples += self.__build_sample_list(root, annot, 'test_VOC2007')
162 | 
163 |         #-----------------------------------------------------------------------
164 |         # We have some 5.5k annotated samples that are not on these lists, so
165 |         # we can use them for validation
166 |         #-----------------------------------------------------------------------
167 |         root = data_dir + '/trainval/VOCdevkit/VOC2012'
168 |         all_annot = set(glob(root + '/Annotations/*.xml'))
169 |         valid_annot = all_annot - set(train_annot)
170 |         valid_samples = self.__build_sample_list(root, valid_annot,
171 |                                                  'valid_VOC2012')
172 | 
173 |         #-----------------------------------------------------------------------
174 |         # Final set up and sanity check
175 |         #-----------------------------------------------------------------------
176 |         self.valid_samples = valid_samples
177 |         self.train_samples = train_samples
178 | 
179 |         if len(self.train_samples) == 0:
180 |             raise RuntimeError('No training samples found in ' + data_dir)
181 | 
182 |         if valid_fraction > 0:
183 |             if len(self.valid_samples) == 0:
184 |                 raise RuntimeError('No validation samples found in ' + data_dir)
185 | 
186 |         self.num_train = len(self.train_samples)
187 |         self.num_valid = len(self.valid_samples)
188 | 
189 |     #---------------------------------------------------------------------------
190 |     def load_test_data(self, data_dir):
191 |         """
192 |         Load the test data
193 |         :param data_dir: the directory where the dataset's files are stored
194 |         """
195 |         root = data_dir + '/test/VOCdevkit/VOC2012'
196 |         annot = self.__build_annotation_list(root, 'test')
197 |         self.test_samples = self.__build_sample_list(root, annot,
198 |                                                      'test_VOC2012')
199 | 
200 |         if len(self.test_samples) == 0:
201 |             raise RuntimeError('No testing samples found in ' + data_dir)
202 | 
203 |         self.num_test = len(self.test_samples)
204 | 
205 | #-------------------------------------------------------------------------------
206 | def get_source():
207 |     return PascalVOCSource()
208 | 
--------------------------------------------------------------------------------
/ssdutils.py:
--------------------------------------------------------------------------------
1 | #-------------------------------------------------------------------------------
2 | # Author: Lukasz Janyst
3 | # Date: 29.08.2017
4 | #-------------------------------------------------------------------------------
5 | # This file is part of SSD-TensorFlow.
6 | #
7 | # SSD-TensorFlow is free software: you can redistribute it and/or modify
8 | # it under the terms of the GNU General Public License as published by
9 | # the Free Software Foundation, either version 3 of the License, or
10 | # (at your option) any later version.
11 | #
12 | # SSD-TensorFlow is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | # GNU General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU General Public License
18 | # along with SSD-Tensorflow. If not, see <http://www.gnu.org/licenses/>.
19 | #-------------------------------------------------------------------------------
20 | 
21 | import numpy as np
22 | 
23 | from utils import Size, Point, Overlap, Score, Box, prop2abs, normalize_box
24 | from collections import namedtuple, defaultdict
25 | from math import sqrt, log, exp
26 | 
27 | #-------------------------------------------------------------------------------
28 | # Define the flavors of SSD that we're going to use and their various
29 | # properties. It's done so that we don't have to build the whole network in
30 | # memory in order to pre-process the datasets.
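#
# For the vgg300 preset below, the anchor count can be checked by hand: every
# feature-map cell gets one box per aspect ratio (ratio 1 is implicit) plus
# one extra box at the interpolated scale s', i.e. 4 boxes for the [2, 0.5]
# maps and 6 for the [2, 3, 0.5, 1/3] maps:
#
#   38*38*4 + 19*19*6 + 10*10*6 + 5*5*6 + 3*3*4 + 1*1*4
#     = 5776 + 2166 + 600 + 150 + 36 + 4 = 8732
#
# which is exactly the num_anchors value stored in the preset.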
31 | #------------------------------------------------------------------------------- 32 | SSDMap = namedtuple('SSDMap', ['size', 'scale', 'aspect_ratios']) 33 | SSDPreset = namedtuple('SSDPreset', ['name', 'image_size', 'maps', 34 | 'extra_scale', 'num_anchors']) 35 | 36 | SSD_PRESETS = { 37 | 'vgg300': SSDPreset(name = 'vgg300', 38 | image_size = Size(300, 300), 39 | maps = [ 40 | SSDMap(Size(38, 38), 0.1, [2, 0.5]), 41 | SSDMap(Size(19, 19), 0.2, [2, 3, 0.5, 1./3.]), 42 | SSDMap(Size(10, 10), 0.375, [2, 3, 0.5, 1./3.]), 43 | SSDMap(Size( 5, 5), 0.55, [2, 3, 0.5, 1./3.]), 44 | SSDMap(Size( 3, 3), 0.725, [2, 0.5]), 45 | SSDMap(Size( 1, 1), 0.9, [2, 0.5]) 46 | ], 47 | extra_scale = 1.075, 48 | num_anchors = 8732), 49 | 'vgg512': SSDPreset(name = 'vgg512', 50 | image_size = Size(512, 512), 51 | maps = [ 52 | SSDMap(Size(64, 64), 0.07, [2, 0.5]), 53 | SSDMap(Size(32, 32), 0.15, [2, 3, 0.5, 1./3.]), 54 | SSDMap(Size(16, 16), 0.3, [2, 3, 0.5, 1./3.]), 55 | SSDMap(Size( 8, 8), 0.45, [2, 3, 0.5, 1./3.]), 56 | SSDMap(Size( 4, 4), 0.6, [2, 3, 0.5, 1./3.]), 57 | SSDMap(Size( 2, 2), 0.75, [2, 0.5]), 58 | SSDMap(Size( 1, 1), 0.9, [2, 0.5]) 59 | ], 60 | extra_scale = 1.05, 61 | num_anchors = 24564) 62 | } 63 | 64 | #------------------------------------------------------------------------------- 65 | # Default box parameters both in terms proportional to image dimensions 66 | #------------------------------------------------------------------------------- 67 | Anchor = namedtuple('Anchor', ['center', 'size', 'x', 'y', 'scale', 'map']) 68 | 69 | #------------------------------------------------------------------------------- 70 | def get_preset_by_name(pname): 71 | if not pname in SSD_PRESETS: 72 | raise RuntimeError('No such preset: '+pname) 73 | return SSD_PRESETS[pname] 74 | 75 | #------------------------------------------------------------------------------- 76 | def get_anchors_for_preset(preset): 77 | """ 78 | Compute the default (anchor) boxes for the given SSD preset 79 | """ 80 | #--------------------------------------------------------------------------- 81 | # Compute the width and heights of the anchor boxes for every scale 82 | #--------------------------------------------------------------------------- 83 | box_sizes = [] 84 | for i in range(len(preset.maps)): 85 | map_params = preset.maps[i] 86 | s = map_params.scale 87 | aspect_ratios = [1] + map_params.aspect_ratios 88 | aspect_ratios = list(map(lambda x: sqrt(x), aspect_ratios)) 89 | 90 | sizes = [] 91 | for ratio in aspect_ratios: 92 | w = s * ratio 93 | h = s / ratio 94 | sizes.append((w, h)) 95 | if i < len(preset.maps)-1: 96 | s_prime = sqrt(s*preset.maps[i+1].scale) 97 | else: 98 | s_prime = sqrt(s*preset.extra_scale) 99 | sizes.append((s_prime, s_prime)) 100 | box_sizes.append(sizes) 101 | 102 | #--------------------------------------------------------------------------- 103 | # Compute the actual boxes for every scale and feature map 104 | #--------------------------------------------------------------------------- 105 | anchors = [] 106 | for k in range(len(preset.maps)): 107 | fk = preset.maps[k].size[0] 108 | s = preset.maps[k].scale 109 | for size in box_sizes[k]: 110 | for j in range(fk): 111 | y = (j+0.5)/float(fk) 112 | for i in range(fk): 113 | x = (i+0.5)/float(fk) 114 | box = Anchor(Point(x, y), Size(size[0], size[1]), 115 | i, j, s, k) 116 | anchors.append(box) 117 | return anchors 118 | 119 | #------------------------------------------------------------------------------- 120 | def anchors2array(anchors, img_size): 
121 | """ 122 | Computes a numpy array out of absolute anchor params (img_size is needed 123 | as a reference) 124 | """ 125 | arr = np.zeros((len(anchors), 4)) 126 | for i in range(len(anchors)): 127 | anchor = anchors[i] 128 | xmin, xmax, ymin, ymax = prop2abs(anchor.center, anchor.size, img_size) 129 | arr[i] = np.array([xmin, xmax, ymin, ymax]) 130 | return arr 131 | 132 | #------------------------------------------------------------------------------- 133 | def box2array(box, img_size): 134 | xmin, xmax, ymin, ymax = prop2abs(box.center, box.size, img_size) 135 | return np.array([xmin, xmax, ymin, ymax]) 136 | 137 | #------------------------------------------------------------------------------- 138 | def jaccard_overlap(box_arr, anchors_arr): 139 | areaa = (anchors_arr[:, 1]-anchors_arr[:, 0]+1) * \ 140 | (anchors_arr[:, 3]-anchors_arr[:, 2]+1) 141 | areab = (box_arr[1]-box_arr[0]+1) * (box_arr[3]-box_arr[2]+1) 142 | 143 | xxmin = np.maximum(box_arr[0], anchors_arr[:, 0]) 144 | xxmax = np.minimum(box_arr[1], anchors_arr[:, 1]) 145 | yymin = np.maximum(box_arr[2], anchors_arr[:, 2]) 146 | yymax = np.minimum(box_arr[3], anchors_arr[:, 3]) 147 | 148 | w = np.maximum(0, xxmax-xxmin+1) 149 | h = np.maximum(0, yymax-yymin+1) 150 | intersection = w*h 151 | union = areab+areaa-intersection 152 | return intersection/union 153 | 154 | #------------------------------------------------------------------------------- 155 | def compute_overlap(box_arr, anchors_arr, threshold): 156 | iou = jaccard_overlap(box_arr, anchors_arr) 157 | overlap = iou > threshold 158 | 159 | good_idxs = np.nonzero(overlap)[0] 160 | best_idx = np.argmax(iou) 161 | best = None 162 | good = [] 163 | 164 | if iou[best_idx] > threshold: 165 | best = Score(best_idx, iou[best_idx]) 166 | 167 | for idx in good_idxs: 168 | good.append(Score(idx, iou[idx])) 169 | 170 | return Overlap(best, good) 171 | 172 | #------------------------------------------------------------------------------- 173 | def compute_location(box, anchor): 174 | arr = np.zeros((4)) 175 | arr[0] = (box.center.x-anchor.center.x)/anchor.size.w*10 176 | arr[1] = (box.center.y-anchor.center.y)/anchor.size.h*10 177 | arr[2] = log(box.size.w/anchor.size.w)*5 178 | arr[3] = log(box.size.h/anchor.size.h)*5 179 | return arr 180 | 181 | #------------------------------------------------------------------------------- 182 | def decode_location(box, anchor): 183 | box[box > 100] = 100 # only happens early training 184 | 185 | x = box[0]/10 * anchor.size.w + anchor.center.x 186 | y = box[1]/10 * anchor.size.h + anchor.center.y 187 | w = exp(box[2]/5) * anchor.size.w 188 | h = exp(box[3]/5) * anchor.size.h 189 | return Point(x, y), Size(w, h) 190 | 191 | #------------------------------------------------------------------------------- 192 | def decode_boxes(pred, anchors, confidence_threshold = 0.01, lid2name = {}, 193 | detections_cap=200): 194 | """ 195 | Decode boxes from the neural net predictions. 196 | Label names are decoded using the lid2name dictionary - the id to name 197 | translation is not done if the corresponding key does not exist. 
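    Returns a list of (confidence, Box) tuples sorted from the most to the
    least confident detection; decoding stops at the first box whose
    confidence falls below confidence_threshold.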
198 |     """
199 | 
200 |     #---------------------------------------------------------------------------
201 |     # Find the detections
202 |     #---------------------------------------------------------------------------
203 |     num_classes = pred.shape[1]-4
204 |     bg_class = num_classes-1
205 |     box_class = np.argmax(pred[:, :num_classes-1], axis=1)
206 |     confidence = pred[np.arange(len(pred)), box_class]
207 |     if detections_cap is not None:
208 |         detections = np.argsort(confidence)[::-1][:detections_cap]
209 |     else:
210 |         detections = np.argsort(confidence)[::-1]
211 | 
212 |     #---------------------------------------------------------------------------
213 |     # Decode coordinates of each box with confidence over a threshold
214 |     #---------------------------------------------------------------------------
215 |     boxes = []
216 |     for idx in detections:
217 |         confidence = pred[idx, box_class[idx]]
218 |         if confidence < confidence_threshold:
219 |             break
220 | 
221 |         center, size = decode_location(pred[idx, num_classes:], anchors[idx])
222 |         cid = box_class[idx]
223 |         cname = None
224 |         if cid in lid2name:
225 |             cname = lid2name[cid]
226 |         det = (confidence, normalize_box(Box(cname, cid, center, size)))
227 |         boxes.append(det)
228 | 
229 |     return boxes
230 | 
231 | #-------------------------------------------------------------------------------
232 | def non_maximum_suppression(boxes, overlap_threshold):
233 |     #---------------------------------------------------------------------------
234 |     # Convert to absolute coordinates and to a more convenient format
235 |     #---------------------------------------------------------------------------
236 |     xmin = []
237 |     xmax = []
238 |     ymin = []
239 |     ymax = []
240 |     conf = []
241 |     img_size = Size(1000, 1000)
242 | 
243 |     for box in boxes:
244 |         params = prop2abs(box[1].center, box[1].size, img_size)
245 |         xmin.append(params[0])
246 |         xmax.append(params[1])
247 |         ymin.append(params[2])
248 |         ymax.append(params[3])
249 |         conf.append(box[0])
250 | 
251 |     xmin = np.array(xmin)
252 |     xmax = np.array(xmax)
253 |     ymin = np.array(ymin)
254 |     ymax = np.array(ymax)
255 |     conf = np.array(conf)
256 | 
257 |     #---------------------------------------------------------------------------
258 |     # Compute the area of each box and sort the indices by confidence level
259 |     # (lowest confidence first).
260 |     #---------------------------------------------------------------------------
261 |     area = (xmax-xmin+1) * (ymax-ymin+1)
262 |     idxs = np.argsort(conf)
263 |     pick = []
264 | 
265 |     #---------------------------------------------------------------------------
266 |     # Loop while we still have indices to process
267 |     #---------------------------------------------------------------------------
268 |     while len(idxs) > 0:
269 |         #-----------------------------------------------------------------------
270 |         # Grab the last index (ie. 
the most confident detection), remove it from 271 | # the list of indices to process, and put it on the list of picks 272 | #----------------------------------------------------------------------- 273 | last = idxs.shape[0]-1 274 | i = idxs[last] 275 | idxs = np.delete(idxs, last) 276 | pick.append(i) 277 | suppress = [] 278 | 279 | #----------------------------------------------------------------------- 280 | # Figure out the intersection with the remaining windows 281 | #----------------------------------------------------------------------- 282 | xxmin = np.maximum(xmin[i], xmin[idxs]) 283 | xxmax = np.minimum(xmax[i], xmax[idxs]) 284 | yymin = np.maximum(ymin[i], ymin[idxs]) 285 | yymax = np.minimum(ymax[i], ymax[idxs]) 286 | 287 | w = np.maximum(0, xxmax-xxmin+1) 288 | h = np.maximum(0, yymax-yymin+1) 289 | intersection = w*h 290 | 291 | #----------------------------------------------------------------------- 292 | # Compute IOU and suppress indices with IOU higher than a threshold 293 | #----------------------------------------------------------------------- 294 | union = area[i]+area[idxs]-intersection 295 | iou = intersection/union 296 | overlap = iou > overlap_threshold 297 | suppress = np.nonzero(overlap)[0] 298 | idxs = np.delete(idxs, suppress) 299 | 300 | #--------------------------------------------------------------------------- 301 | # Return the selected boxes 302 | #--------------------------------------------------------------------------- 303 | selected = [] 304 | for i in pick: 305 | selected.append(boxes[i]) 306 | 307 | return selected 308 | 309 | #------------------------------------------------------------------------------- 310 | def suppress_overlaps(boxes): 311 | class_boxes = defaultdict(list) 312 | selected_boxes = [] 313 | for box in boxes: 314 | class_boxes[box[1].labelid].append(box) 315 | 316 | for k, v in class_boxes.items(): 317 | selected_boxes += non_maximum_suppression(v, 0.45) 318 | return selected_boxes 319 | -------------------------------------------------------------------------------- /ssdvgg.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------- 2 | # Author: Lukasz Janyst 3 | # Date: 27.08.2017 4 | #------------------------------------------------------------------------------- 5 | # This file is part of SSD-TensorFlow. 6 | # 7 | # SSD-TensorFlow is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 11 | # 12 | # SSD-TensorFlow is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU General Public License 18 | # along with SSD-Tensorflow. If not, see . 
19 | #-------------------------------------------------------------------------------
20 | 
21 | import zipfile
22 | import shutil
23 | import os
24 | 
25 | import tensorflow as tf
26 | import numpy as np
27 | 
28 | from urllib.request import urlretrieve
29 | from tqdm import tqdm
30 | 
31 | #-------------------------------------------------------------------------------
32 | class DLProgress(tqdm):
33 |     last_block = 0
34 | 
35 |     #---------------------------------------------------------------------------
36 |     def hook(self, block_num=1, block_size=1, total_size=None):
37 |         self.total = total_size
38 |         self.update((block_num - self.last_block) * block_size)
39 |         self.last_block = block_num
40 | 
41 | #-------------------------------------------------------------------------------
42 | def conv_map(x, size, shape, stride, name, padding='SAME'):
43 |     with tf.variable_scope(name):
44 |         w = tf.get_variable("filter",
45 |                             shape=[shape, shape, x.get_shape()[3], size],
46 |                             initializer=tf.contrib.layers.xavier_initializer())
47 |         b = tf.Variable(tf.zeros(size), name='biases')
48 |         x = tf.nn.conv2d(x, w, strides=[1, stride, stride, 1], padding=padding)
49 |         x = tf.nn.bias_add(x, b)
50 |         x = tf.nn.relu(x)
51 |         l2 = tf.nn.l2_loss(w)
52 |     return x, l2
53 | 
54 | #-------------------------------------------------------------------------------
55 | def classifier(x, size, mapsize, name):
56 |     with tf.variable_scope(name):
57 |         w = tf.get_variable("filter",
58 |                             shape=[3, 3, x.get_shape()[3], size],
59 |                             initializer=tf.contrib.layers.xavier_initializer())
60 |         b = tf.Variable(tf.zeros(size), name='biases')
61 |         x = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')
62 |         x = tf.nn.bias_add(x, b)
63 |         x = tf.reshape(x, [-1, mapsize.w*mapsize.h, size])
64 |         l2 = tf.nn.l2_loss(w)
65 |     return x, l2
66 | 
67 | #-------------------------------------------------------------------------------
68 | def smooth_l1_loss(x):
69 |     square_loss = 0.5*x**2
70 |     absolute_loss = tf.abs(x)
71 |     return tf.where(tf.less(absolute_loss, 1.), square_loss, absolute_loss-0.5)
72 | 
73 | #-------------------------------------------------------------------------------
74 | def array2tensor(x, name):
75 |     init = tf.constant_initializer(value=x, dtype=tf.float32)
76 |     tensor = tf.get_variable(name=name, initializer=init, shape=x.shape)
77 |     return tensor
78 | 
79 | #-------------------------------------------------------------------------------
80 | def l2_normalization(x, initial_scale, channels, name):
81 |     with tf.variable_scope(name):
82 |         scale = array2tensor(initial_scale*np.ones(channels), 'scale')
83 |         x = scale*tf.nn.l2_normalize(x, axis=-1)
84 |     return x
85 | 
86 | #-------------------------------------------------------------------------------
87 | class SSDVGG:
88 |     #---------------------------------------------------------------------------
89 |     def __init__(self, session, preset):
90 |         self.preset = preset
91 |         self.session = session
92 |         self.__built = False
93 |         self.__build_names()
94 | 
95 |     #---------------------------------------------------------------------------
96 |     def build_from_vgg(self, vgg_dir, num_classes, a_trous=True,
97 |                        progress_hook='tqdm'):
98 |         """
99 |         Build the model for training based on a pre-defined vgg16 model.
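        The fully connected layers fc6 and fc7 of the original VGG are turned
        into convolutions; with a_trous=True their weights are additionally
        subsampled and conv6 becomes a dilated (a trous) convolution, which
        approximates the original layers with far fewer parameters.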
100 | :param vgg_dir: directory where the vgg model should be stored 101 | :param num_classes: number of classes 102 | :param progress_hook: a hook to show download progress of vgg16; 103 | the value may be a callable for urlretrieve 104 | or string "tqdm" 105 | """ 106 | self.num_classes = num_classes+1 # +1 for the background class 107 | self.num_vars = num_classes+5 # classes + background + 4 box offsets 108 | self.l2_loss = 0 109 | self.__download_vgg(vgg_dir, progress_hook) 110 | self.__load_vgg(vgg_dir) 111 | if a_trous: self.__build_vgg_mods_a_trous() 112 | else: self.__build_vgg_mods() 113 | self.__build_ssd_layers() 114 | self.__build_norms() 115 | self.__select_feature_maps() 116 | self.__build_classifiers() 117 | self.__built = True 118 | 119 | #--------------------------------------------------------------------------- 120 | def build_from_metagraph(self, metagraph_file, checkpoint_file): 121 | """ 122 | Build the model for inference from a metagraph snapshot and weights 123 | checkpoint. 124 | """ 125 | sess = self.session 126 | saver = tf.train.import_meta_graph(metagraph_file) 127 | saver.restore(sess, checkpoint_file) 128 | self.image_input = sess.graph.get_tensor_by_name('image_input:0') 129 | self.keep_prob = sess.graph.get_tensor_by_name('keep_prob:0') 130 | self.result = sess.graph.get_tensor_by_name('result/result:0') 131 | 132 | #--------------------------------------------------------------------------- 133 | def build_optimizer_from_metagraph(self): 134 | """ 135 | Get the optimizer and the loss tensors from the metagraph 136 | """ 137 | sess = self.session 138 | self.loss = sess.graph.get_tensor_by_name('total_loss/loss:0') 139 | self.localization_loss = sess.graph.get_tensor_by_name('localization_loss/localization_loss:0') 140 | self.confidence_loss = sess.graph.get_tensor_by_name('confidence_loss/confidence_loss:0') 141 | self.l2_loss = sess.graph.get_tensor_by_name('total_loss/l2_loss:0') 142 | self.optimizer = sess.graph.get_operation_by_name('optimizer/optimizer') 143 | self.labels = sess.graph.get_tensor_by_name('labels:0') 144 | 145 | self.losses = { 146 | 'total': self.loss, 147 | 'localization': self.localization_loss, 148 | 'confidence': self.confidence_loss, 149 | 'l2': self.l2_loss 150 | } 151 | 152 | #--------------------------------------------------------------------------- 153 | def __download_vgg(self, vgg_dir, progress_hook): 154 | #----------------------------------------------------------------------- 155 | # Check if the model needs to be downloaded 156 | #----------------------------------------------------------------------- 157 | vgg_archive = 'vgg.zip' 158 | vgg_files = [ # the zip extracts into a top-level 'vgg' directory (cf. __load_vgg) 159 | vgg_dir + '/vgg/variables/variables.data-00000-of-00001', 160 | vgg_dir + '/vgg/variables/variables.index', 161 | vgg_dir + '/vgg/saved_model.pb'] 162 | 163 | missing_vgg_files = [vgg_file for vgg_file in vgg_files \ 164 | if not os.path.exists(vgg_file)] 165 | 166 | if missing_vgg_files: 167 | if os.path.exists(vgg_dir): 168 | shutil.rmtree(vgg_dir) 169 | os.makedirs(vgg_dir) 170 | 171 | #------------------------------------------------------------------- 172 | # Download vgg 173 | #------------------------------------------------------------------- 174 | url = 'https://s3-us-west-1.amazonaws.com/udacity-selfdrivingcar/vgg.zip' 175 | if not os.path.exists(vgg_archive): 176 | if callable(progress_hook): 177 | urlretrieve(url, vgg_archive, progress_hook) 178 | else: 179 | with DLProgress(unit='B', unit_scale=True, miniters=1) as pbar: 180 | urlretrieve(url, vgg_archive, pbar.hook) 181 | 182 | 
#------------------------------------------------------------------- 183 | # Extract vgg 184 | #------------------------------------------------------------------- 185 | zip_archive = zipfile.ZipFile(vgg_archive, 'r') 186 | zip_archive.extractall(vgg_dir) 187 | zip_archive.close() 188 | 189 | #--------------------------------------------------------------------------- 190 | def __load_vgg(self, vgg_dir): 191 | sess = self.session 192 | graph = tf.saved_model.loader.load(sess, ['vgg16'], vgg_dir+'/vgg') 193 | self.image_input = sess.graph.get_tensor_by_name('image_input:0') 194 | self.keep_prob = sess.graph.get_tensor_by_name('keep_prob:0') 195 | self.vgg_conv4_3 = sess.graph.get_tensor_by_name('conv4_3/Relu:0') 196 | self.vgg_conv5_3 = sess.graph.get_tensor_by_name('conv5_3/Relu:0') 197 | self.vgg_fc6_w = sess.graph.get_tensor_by_name('fc6/weights:0') 198 | self.vgg_fc6_b = sess.graph.get_tensor_by_name('fc6/biases:0') 199 | self.vgg_fc7_w = sess.graph.get_tensor_by_name('fc7/weights:0') 200 | self.vgg_fc7_b = sess.graph.get_tensor_by_name('fc7/biases:0') 201 | 202 | layers = ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 203 | 'conv3_2', 'conv3_3', 'conv4_1', 'conv4_2', 'conv4_3', 204 | 'conv5_1', 'conv5_2', 'conv5_3'] 205 | 206 | for l in layers: 207 | self.l2_loss += sess.graph.get_tensor_by_name(l+'/L2Loss:0') 208 | 209 | #--------------------------------------------------------------------------- 210 | def __build_vgg_mods(self): 211 | self.mod_pool5 = tf.nn.max_pool(self.vgg_conv5_3, ksize=[1, 3, 3, 1], 212 | strides=[1, 1, 1, 1], padding='SAME', 213 | name='mod_pool5') 214 | 215 | with tf.variable_scope('mod_conv6'): 216 | x = tf.nn.conv2d(self.mod_pool5, self.vgg_fc6_w, 217 | strides=[1, 1, 1, 1], padding='SAME') 218 | x = tf.nn.bias_add(x, self.vgg_fc6_b) 219 | self.mod_conv6 = tf.nn.relu(x) 220 | self.l2_loss += tf.nn.l2_loss(self.vgg_fc6_w) 221 | 222 | with tf.variable_scope('mod_conv7'): 223 | x = tf.nn.conv2d(self.mod_conv6, self.vgg_fc7_w, 224 | strides=[1, 1, 1, 1], padding='SAME') 225 | x = tf.nn.bias_add(x, self.vgg_fc7_b) 226 | x = tf.nn.relu(x) 227 | self.mod_conv7 = x 228 | self.l2_loss += tf.nn.l2_loss(self.vgg_fc7_w) 229 | 230 | #--------------------------------------------------------------------------- 231 | def __build_vgg_mods_a_trous(self): 232 | sess = self.session 233 | 234 | self.mod_pool5 = tf.nn.max_pool(self.vgg_conv5_3, ksize=[1, 3, 3, 1], 235 | strides=[1, 1, 1, 1], padding='SAME', 236 | name='mod_pool5') 237 | 238 | #----------------------------------------------------------------------- 239 | # Modified conv6 240 | #----------------------------------------------------------------------- 241 | with tf.variable_scope('mod_conv6'): 242 | #------------------------------------------------------------------- 243 | # Decimate the weights 244 | #------------------------------------------------------------------- 245 | orig_w, orig_b = sess.run([self.vgg_fc6_w, self.vgg_fc6_b]) 246 | mod_w = np.zeros((3, 3, 512, 1024)) 247 | mod_b = np.zeros(1024) 248 | 249 | for i in range(1024): 250 | mod_b[i] = orig_b[4*i] 251 | for h in range(3): 252 | for w in range(3): 253 | mod_w[h, w, :, i] = orig_w[3*h, 3*w, :, 4*i] 254 | 255 | #------------------------------------------------------------------- 256 | # Build the feature map 257 | #------------------------------------------------------------------- 258 | w = array2tensor(mod_w, 'filter') 259 | b = array2tensor(mod_b, 'biases') 260 | x = tf.nn.atrous_conv2d(self.mod_pool5, w, rate=6, padding='SAME') 261 | x = 
tf.nn.bias_add(x, b) 262 | x = tf.nn.relu(x) 263 | self.mod_conv6 = x 264 | self.l2_loss += tf.nn.l2_loss(w) 265 | 266 | #----------------------------------------------------------------------- 267 | # Modified conv7 268 | #----------------------------------------------------------------------- 269 | with tf.variable_scope('mod_conv7'): 270 | #------------------------------------------------------------------- 271 | # Decimate the weights 272 | #------------------------------------------------------------------- 273 | orig_w, orig_b = sess.run([self.vgg_fc7_w, self.vgg_fc7_b]) 274 | mod_w = np.zeros((1, 1, 1024, 1024)) 275 | mod_b = np.zeros(1024) 276 | 277 | for i in range(1024): 278 | mod_b[i] = orig_b[4*i] 279 | for j in range(1024): 280 | mod_w[:, :, j, i] = orig_w[:, :, 4*j, 4*i] 281 | 282 | #------------------------------------------------------------------- 283 | # Build the feature map 284 | #------------------------------------------------------------------- 285 | w = array2tensor(mod_w, 'filter') 286 | b = array2tensor(mod_b, 'biases') 287 | x = tf.nn.conv2d(self.mod_conv6, w, strides=[1, 1, 1, 1], 288 | padding='SAME') 289 | x = tf.nn.bias_add(x, b) 290 | x = tf.nn.relu(x) 291 | self.mod_conv7 = x 292 | self.l2_loss += tf.nn.l2_loss(w) 293 | 294 | #--------------------------------------------------------------------------- 295 | def __with_loss(self, x, l2_loss): 296 | self.l2_loss += l2_loss 297 | return x 298 | 299 | #--------------------------------------------------------------------------- 300 | def __build_ssd_layers(self): 301 | stride10 = 1 302 | padding10 = 'VALID' 303 | if len(self.preset.maps) >= 7: 304 | stride10 = 2 305 | padding10 = 'SAME' 306 | 307 | x, l2 = conv_map(self.mod_conv7, 256, 1, 1, 'conv8_1') 308 | self.ssd_conv8_1 = self.__with_loss(x, l2) 309 | x, l2 = conv_map(self.ssd_conv8_1, 512, 3, 2, 'conv8_2') 310 | self.ssd_conv8_2 = self.__with_loss(x, l2) 311 | x, l2 = conv_map(self.ssd_conv8_2, 128, 1, 1, 'conv9_1') 312 | self.ssd_conv9_1 = self.__with_loss(x, l2) 313 | x, l2 = conv_map(self.ssd_conv9_1, 256, 3, 2, 'conv9_2') 314 | self.ssd_conv9_2 = self.__with_loss(x, l2) 315 | x, l2 = conv_map(self.ssd_conv9_2, 128, 1, 1, 'conv10_1') 316 | self.ssd_conv10_1 = self.__with_loss(x, l2) 317 | x, l2 = conv_map(self.ssd_conv10_1, 256, 3, stride10, 'conv10_2', padding10) 318 | self.ssd_conv10_2 = self.__with_loss(x, l2) 319 | x, l2 = conv_map(self.ssd_conv10_2, 128, 1, 1, 'conv11_1') 320 | self.ssd_conv11_1 = self.__with_loss(x, l2) 321 | x, l2 = conv_map(self.ssd_conv11_1, 256, 3, 1, 'conv11_2', 'VALID') 322 | self.ssd_conv11_2 = self.__with_loss(x, l2) 323 | 324 | if len(self.preset.maps) < 7: 325 | return 326 | 327 | x, l2 = conv_map(self.ssd_conv11_2, 128, 1, 1, 'conv12_1') 328 | paddings = [[0, 0], [0, 1], [0, 1], [0, 0]] 329 | x = tf.pad(x, paddings, "CONSTANT") 330 | self.ssd_conv12_1 = self.__with_loss(x, l2) 331 | x, l2 = conv_map(self.ssd_conv12_1, 256, 3, 1, 'conv12_2', 'VALID') 332 | self.ssd_conv12_2 = self.__with_loss(x, l2) 333 | 334 | #--------------------------------------------------------------------------- 335 | def __build_norms(self): 336 | x = l2_normalization(self.vgg_conv4_3, 20, 512, 'l2_norm_conv4_3') 337 | self.norm_conv4_3 = x 338 | 339 | #--------------------------------------------------------------------------- 340 | def __select_feature_maps(self): 341 | self.__maps = [ 342 | self.norm_conv4_3, 343 | self.mod_conv7, 344 | self.ssd_conv8_2, 345 | self.ssd_conv9_2, 346 | self.ssd_conv10_2, 347 | self.ssd_conv11_2] 348 | 349 | if 
len(self.preset.maps) == 7: 350 | self.__maps.append(self.ssd_conv12_2) 351 | 352 | #--------------------------------------------------------------------------- 353 | def __build_classifiers(self): 354 | with tf.variable_scope('classifiers'): 355 | self.__classifiers = [] 356 | for i in range(len(self.__maps)): 357 | fmap = self.__maps[i] 358 | map_size = self.preset.maps[i].size 359 | for j in range(2+len(self.preset.maps[i].aspect_ratios)): 360 | name = 'classifier{}_{}'.format(i, j) 361 | clsfier, l2 = classifier(fmap, self.num_vars, map_size, name) 362 | self.__classifiers.append(self.__with_loss(clsfier, l2)) 363 | 364 | with tf.variable_scope('output'): 365 | output = tf.concat(self.__classifiers, axis=1, name='output') 366 | self.logits = output[:,:,:self.num_classes] 367 | 368 | with tf.variable_scope('result'): 369 | self.classifier = tf.nn.softmax(self.logits) 370 | self.locator = output[:,:,self.num_classes:] 371 | self.result = tf.concat([self.classifier, self.locator], 372 | axis=-1, name='result') 373 | 374 | #--------------------------------------------------------------------------- 375 | def build_optimizer(self, learning_rate=0.001, weight_decay=0.0005, 376 | momentum=0.9, global_step=None): 377 | 378 | self.labels = tf.placeholder(tf.float32, name='labels', 379 | shape=[None, None, self.num_vars]) 380 | 381 | with tf.variable_scope('ground_truth'): 382 | #------------------------------------------------------------------- 383 | # Split the ground truth tensor 384 | #------------------------------------------------------------------- 385 | # Classification ground truth tensor 386 | # Shape: (batch_size, num_anchors, num_classes) 387 | gt_cl = self.labels[:,:,:self.num_classes] 388 | 389 | # Localization ground truth tensor 390 | # Shape: (batch_size, num_anchors, 4) 391 | gt_loc = self.labels[:,:,self.num_classes:] 392 | 393 | # Batch size 394 | # Shape: scalar 395 | batch_size = tf.shape(gt_cl)[0] 396 | 397 | #----------------------------------------------------------------------- 398 | # Compute match counters 399 | #----------------------------------------------------------------------- 400 | with tf.variable_scope('match_counters'): 401 | # Number of anchors per sample 402 | # Shape: (batch_size) 403 | total_num = tf.ones([batch_size], dtype=tf.int64) * \ 404 | tf.to_int64(self.preset.num_anchors) 405 | 406 | # Number of negative (not-matched) anchors per sample, computed 407 | # by counting boxes of the background class in each sample. 
408 | # Shape: (batch_size) 409 | negatives_num = tf.count_nonzero(gt_cl[:,:,-1], axis=1) 410 | 411 | # Number of positive (matched) anchors per sample 412 | # Shape: (batch_size) 413 | positives_num = total_num-negatives_num 414 | 415 | # Division-safe number of positives per sample 416 | # Shape: (batch_size) 417 | positives_num_safe = tf.where(tf.equal(positives_num, 0), 418 | tf.ones([batch_size])*10e-15, 419 | tf.to_float(positives_num)) 420 | 421 | #----------------------------------------------------------------------- 422 | # Compute masks 423 | #----------------------------------------------------------------------- 424 | with tf.variable_scope('match_masks'): 425 | # Boolean tensor determining whether an anchor is a positive 426 | # Shape: (batch_size, num_anchors) 427 | positives_mask = tf.equal(gt_cl[:,:,-1], 0) 428 | 429 | # Boolean tensor determining whether an anchor is a negative 430 | # Shape: (batch_size, num_anchors) 431 | negatives_mask = tf.logical_not(positives_mask) 432 | 433 | #----------------------------------------------------------------------- 434 | # Compute the confidence loss 435 | #----------------------------------------------------------------------- 436 | with tf.variable_scope('confidence_loss'): 437 | # Cross-entropy tensor - all of the values are non-negative 438 | # Shape: (batch_size, num_anchors) 439 | ce = tf.nn.softmax_cross_entropy_with_logits_v2(labels=gt_cl, 440 | logits=self.logits) 441 | 442 | #------------------------------------------------------------------- 443 | # Sum up the loss of all the positive anchors 444 | #------------------------------------------------------------------- 445 | # Positives - the loss of negative anchors is zeroed out 446 | # Shape: (batch_size, num_anchors) 447 | positives = tf.where(positives_mask, ce, tf.zeros_like(ce)) 448 | 449 | # Total loss of positive anchors 450 | # Shape: (batch_size) 451 | positives_sum = tf.reduce_sum(positives, axis=-1) 452 | 453 | #------------------------------------------------------------------- 454 | # Figure out which negative anchors have the highest confidence 455 | # loss 456 | #------------------------------------------------------------------- 457 | # Negatives - the loss of positive anchors is zeroed out 458 | # Shape: (batch_size, num_anchors) 459 | negatives = tf.where(negatives_mask, ce, tf.zeros_like(ce)) 460 | 461 | # Top negatives - sorted confidence loss with the highest one first 462 | # Shape: (batch_size, num_anchors) 463 | negatives_top = tf.nn.top_k(negatives, self.preset.num_anchors)[0] 464 | 465 | #------------------------------------------------------------------- 466 | # Figure out how many negatives we want to keep 467 | #------------------------------------------------------------------- 468 | # Maximum number of negatives to keep per sample - we keep at most 469 | # 3 times as many as we have positive anchors in the sample 470 | # Shape: (batch_size) 471 | negatives_num_max = tf.minimum(negatives_num, 3*positives_num) 472 | 473 | #------------------------------------------------------------------- 474 | # Mask out superfluous negatives and compute the sum of the loss 475 | #------------------------------------------------------------------- 476 | # Transposed vector of maximum negatives per sample 477 | # Shape: (batch_size, 1) 478 | negatives_num_max_t = tf.expand_dims(negatives_num_max, 1) 479 | 480 | # Range tensor: [0, 1, 2, ..., num_anchors-1] 481 | # Shape: (num_anchors) 482 | rng = tf.range(0, self.preset.num_anchors, 1)
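            # Illustrative worked example (added note): with num_anchors = 6
            # and negatives_num_max = [2, 4], the comparison below,
            # tf.less(range_row, negatives_num_max_t), broadcasts to
            #   [[ True,  True, False, False, False, False],
            #    [ True,  True,  True,  True, False, False]]
            # so only the 2 (respectively 4) largest negative losses
            # survive the masking of negatives_top.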
483 | 484 | # Row range, the same as above, but int64 and a row of a matrix 485 | # Shape: (1, num_anchors) 486 | range_row = tf.to_int64(tf.expand_dims(rng, 0)) 487 | 488 | # Mask of maximum negatives - the first `negatives_num_max` elements 489 | # in the corresponding row are `True`, the rest are `False` 490 | # Shape: (batch_size, num_anchors) 491 | negatives_max_mask = tf.less(range_row, negatives_num_max_t) 492 | 493 | # Max negatives - all the positives and superfluous negatives are 494 | # zeroed out. 495 | # Shape: (batch_size, num_anchors) 496 | negatives_max = tf.where(negatives_max_mask, negatives_top, 497 | tf.zeros_like(negatives_top)) 498 | 499 | # Sum of max negatives for each sample 500 | # Shape: (batch_size) 501 | negatives_max_sum = tf.reduce_sum(negatives_max, axis=-1) 502 | 503 | #------------------------------------------------------------------- 504 | # Compute the confidence loss for each sample 505 | #------------------------------------------------------------------- 506 | # Total confidence loss for each sample 507 | # Shape: (batch_size) 508 | confidence_loss = tf.add(positives_sum, negatives_max_sum) 509 | 510 | # Total confidence loss normalized by the number of positives 511 | # per sample 512 | # Shape: (batch_size) 513 | confidence_loss = tf.where(tf.equal(positives_num, 0), 514 | tf.zeros([batch_size]), 515 | tf.div(confidence_loss, 516 | positives_num_safe)) 517 | 518 | # Mean confidence loss for the batch 519 | # Shape: scalar 520 | self.confidence_loss = tf.reduce_mean(confidence_loss, 521 | name='confidence_loss') 522 | 523 | #----------------------------------------------------------------------- 524 | # Compute the localization loss 525 | #----------------------------------------------------------------------- 526 | with tf.variable_scope('localization_loss'): 527 | # Element-wise difference between the predicted box offsets 528 | # and the ground truth offsets 529 | # Shape: (batch_size, num_anchors, 4) 530 | loc_diff = tf.subtract(self.locator, gt_loc) 531 | 532 | # Smooth L1 loss 533 | # Shape: (batch_size, num_anchors, 4) 534 | loc_loss = smooth_l1_loss(loc_diff) 535 | 536 | # Sum of localization losses for each anchor 537 | # Shape: (batch_size, num_anchors) 538 | loc_loss_sum = tf.reduce_sum(loc_loss, axis=-1) 539 | 540 | # Positive locs - the loss of negative anchors is zeroed out 541 | # Shape: (batch_size, num_anchors) 542 | positive_locs = tf.where(positives_mask, loc_loss_sum, 543 | tf.zeros_like(loc_loss_sum)) 544 | 545 | # Total loss of positive anchors 546 | # Shape: (batch_size) 547 | localization_loss = tf.reduce_sum(positive_locs, axis=-1) 548 | 549 | # Total localization loss normalized by the number of positives 550 | # per sample 551 | # Shape: (batch_size) 552 | localization_loss = tf.where(tf.equal(positives_num, 0), 553 | tf.zeros([batch_size]), 554 | tf.div(localization_loss, 555 | positives_num_safe)) 556 | 557 | # Mean localization loss for the batch 558 | # Shape: scalar 559 | self.localization_loss = tf.reduce_mean(localization_loss, 560 | name='localization_loss') 561 | 562 | #----------------------------------------------------------------------- 563 | # Compute total loss 564 | #----------------------------------------------------------------------- 565 | with tf.variable_scope('total_loss'): 566 | # Sum of the localization and confidence losses 567 | # Shape: scalar 568 | self.conf_and_loc_loss = tf.add(self.confidence_loss, 569 | self.localization_loss, 570 | name='sum_losses') 571 | 572 | # L2 loss 573 | # Shape: scalar 574 
| self.l2_loss = tf.multiply(weight_decay, self.l2_loss, 575 | name='l2_loss') 576 | 577 | # Final loss 578 | # Shape: scalar 579 | self.loss = tf.add(self.conf_and_loc_loss, self.l2_loss, 580 | name='loss') 581 | 582 | #----------------------------------------------------------------------- 583 | # Build the optimizer 584 | #----------------------------------------------------------------------- 585 | with tf.variable_scope('optimizer'): 586 | optimizer = tf.train.MomentumOptimizer(learning_rate, momentum) 587 | optimizer = optimizer.minimize(self.loss, global_step=global_step, 588 | name='optimizer') 589 | 590 | #----------------------------------------------------------------------- 591 | # Store the tensors 592 | #----------------------------------------------------------------------- 593 | self.optimizer = optimizer 594 | self.losses = { 595 | 'total': self.loss, 596 | 'localization': self.localization_loss, 597 | 'confidence': self.confidence_loss, 598 | 'l2': self.l2_loss 599 | } 600 | 601 | #--------------------------------------------------------------------------- 602 | def __build_names(self): 603 | #----------------------------------------------------------------------- 604 | # Names of the original and new scopes 605 | #----------------------------------------------------------------------- 606 | self.original_scopes = [ 607 | 'conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 608 | 'conv3_3', 'conv4_1', 'conv4_2', 'conv4_3', 'conv5_1', 'conv5_2', 609 | 'conv5_3', 'mod_conv6', 'mod_conv7' 610 | ] 611 | 612 | self.new_scopes = [ 613 | 'conv8_1', 'conv8_2', 'conv9_1', 'conv9_2', 'conv10_1', 'conv10_2', 614 | 'conv11_1', 'conv11_2' 615 | ] 616 | 617 | if len(self.preset.maps) == 7: 618 | self.new_scopes += ['conv12_1', 'conv12_2'] 619 | 620 | for i in range(len(self.preset.maps)): 621 | for j in range(2+len(self.preset.maps[i].aspect_ratios)): 622 | self.new_scopes.append('classifiers/classifier{}_{}'.format(i, j)) 623 | 624 | #--------------------------------------------------------------------------- 625 | def build_summaries(self, restore): 626 | if restore: 627 | return self.session.graph.get_tensor_by_name('net_summaries/net_summaries:0') 628 | 629 | #----------------------------------------------------------------------- 630 | # Build the filter summaries 631 | #----------------------------------------------------------------------- 632 | names = self.original_scopes + self.new_scopes 633 | sess = self.session 634 | with tf.variable_scope('filter_summaries'): 635 | summaries = [] 636 | for name in names: 637 | tensor = sess.graph.get_tensor_by_name(name+'/filter:0') 638 | summary = tf.summary.histogram(name, tensor) 639 | summaries.append(summary) 640 | 641 | #----------------------------------------------------------------------- 642 | # Scale summary 643 | #----------------------------------------------------------------------- 644 | with tf.variable_scope('scale_summary'): 645 | tensor = sess.graph.get_tensor_by_name('l2_norm_conv4_3/scale:0') 646 | summary = tf.summary.histogram('l2_norm_conv4_3', tensor) 647 | summaries.append(summary) 648 | 649 | return tf.summary.merge(summaries, name='net_summaries') 650 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #------------------------------------------------------------------------------- 3 | # Author: Lukasz Janyst 4 | # Date: 07.09.2017 5 | 
#------------------------------------------------------------------------------- 6 | # This file is part of SSD-TensorFlow. 7 | # 8 | # SSD-TensorFlow is free software: you can redistribute it and/or modify 9 | # it under the terms of the GNU General Public License as published by 10 | # the Free Software Foundation, either version 3 of the License, or 11 | # (at your option) any later version. 12 | # 13 | # SSD-TensorFlow is distributed in the hope that it will be useful, 14 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | # GNU General Public License for more details. 17 | # 18 | # You should have received a copy of the GNU General Public License 19 | # along with SSD-TensorFlow. If not, see <http://www.gnu.org/licenses/>. 20 | #------------------------------------------------------------------------------- 21 | 22 | import argparse 23 | import math 24 | import sys 25 | import os 26 | 27 | import multiprocessing as mp 28 | import tensorflow as tf 29 | import numpy as np 30 | 31 | from average_precision import APCalculator, APs2mAP 32 | from training_data import TrainingData 33 | from ssdutils import get_anchors_for_preset, decode_boxes, suppress_overlaps 34 | from ssdvgg import SSDVGG 35 | from utils import * 36 | from tqdm import tqdm 37 | 38 | if sys.version_info[0] < 3: 39 | print("This is a Python 3 program. Use Python 3 or higher.") 40 | sys.exit(1) 41 | 42 | #------------------------------------------------------------------------------- 43 | def compute_lr(lr_values, lr_boundaries): 44 | with tf.variable_scope('learning_rate'): 45 | global_step = tf.Variable(0, trainable=False, name='global_step') 46 | lr = tf.train.piecewise_constant(global_step, lr_boundaries, lr_values) 47 | return lr, global_step 48 | 49 | #------------------------------------------------------------------------------- 50 | def main(): 51 | #--------------------------------------------------------------------------- 52 | # Parse the commandline 53 | #--------------------------------------------------------------------------- 54 | parser = argparse.ArgumentParser(description='Train the SSD') 55 | parser.add_argument('--name', default='test', 56 | help='project name') 57 | parser.add_argument('--data-dir', default='pascal-voc', 58 | help='data directory') 59 | parser.add_argument('--vgg-dir', default='vgg_graph', 60 | help='directory for the VGG-16 model') 61 | parser.add_argument('--epochs', type=int, default=200, 62 | help='number of training epochs') 63 | parser.add_argument('--batch-size', type=int, default=8, 64 | help='batch size') 65 | parser.add_argument('--tensorboard-dir', default="tb", 66 | help='name of the tensorboard data directory') 67 | parser.add_argument('--checkpoint-interval', type=int, default=5, 68 | help='checkpoint interval') 69 | parser.add_argument('--lr-values', type=str, default='0.00075;0.0001;0.00001', 70 | help='learning rate values') 71 | parser.add_argument('--lr-boundaries', type=str, default='320000;400000', 72 | help='learning rate change boundaries (in batches)') 73 | parser.add_argument('--momentum', type=float, default=0.9, 74 | help='momentum for the optimizer') 75 | parser.add_argument('--weight-decay', type=float, default=0.0005, 76 | help='L2 normalization factor') 77 | parser.add_argument('--continue-training', type=str2bool, default='False', 78 | help='continue training from the latest checkpoint') 79 | parser.add_argument('--num-workers', type=int, default=mp.cpu_count(), 80 | help='number of parallel 
generators') 81 | 82 | args = parser.parse_args() 83 | 84 | print('[i] Project name: ', args.name) 85 | print('[i] Data directory: ', args.data_dir) 86 | print('[i] VGG directory: ', args.vgg_dir) 87 | print('[i] # epochs: ', args.epochs) 88 | print('[i] Batch size: ', args.batch_size) 89 | print('[i] Tensorboard directory:', args.tensorboard_dir) 90 | print('[i] Checkpoint interval: ', args.checkpoint_interval) 91 | print('[i] Learning rate values: ', args.lr_values) 92 | print('[i] Learning rate boundaries: ', args.lr_boundaries) 93 | print('[i] Momentum: ', args.momentum) 94 | print('[i] Weight decay: ', args.weight_decay) 95 | print('[i] Continue: ', args.continue_training) 96 | print('[i] Number of workers: ', args.num_workers) 97 | 98 | #--------------------------------------------------------------------------- 99 | # Find an existing checkpoint 100 | #--------------------------------------------------------------------------- 101 | start_epoch = 0 102 | if args.continue_training: 103 | state = tf.train.get_checkpoint_state(args.name) 104 | if state is None: 105 | print('[!] No network state found in ' + args.name) 106 | return 1 107 | 108 | ckpt_paths = state.all_model_checkpoint_paths 109 | if not ckpt_paths: 110 | print('[!] No network state found in ' + args.name) 111 | return 1 112 | 113 | last_epoch = None 114 | checkpoint_file = None 115 | for ckpt in ckpt_paths: 116 | ckpt_num = os.path.basename(ckpt).split('.')[0][1:] 117 | try: 118 | ckpt_num = int(ckpt_num) 119 | except ValueError: 120 | continue 121 | if last_epoch is None or last_epoch < ckpt_num: 122 | last_epoch = ckpt_num 123 | checkpoint_file = ckpt 124 | 125 | if checkpoint_file is None: 126 | print('[!] No checkpoints found, cannot continue!') 127 | return 1 128 | 129 | metagraph_file = checkpoint_file + '.meta' 130 | 131 | if not os.path.exists(metagraph_file): 132 | print('[!] Cannot find metagraph', metagraph_file) 133 | return 1 134 | start_epoch = last_epoch 135 | 136 | #--------------------------------------------------------------------------- 137 | # Create a project directory 138 | #--------------------------------------------------------------------------- 139 | else: 140 | try: 141 | print('[i] Creating directory {}...'.format(args.name)) 142 | os.makedirs(args.name) 143 | except (IOError) as e: 144 | print('[!]', str(e)) 145 | return 1 146 | 147 | print('[i] Starting at epoch: ', start_epoch+1) 148 | 149 | #--------------------------------------------------------------------------- 150 | # Configure the training data 151 | #--------------------------------------------------------------------------- 152 | print('[i] Configuring the training data...') 153 | try: 154 | td = TrainingData(args.data_dir) 155 | print('[i] # training samples: ', td.num_train) 156 | print('[i] # validation samples: ', td.num_valid) 157 | print('[i] # classes: ', td.num_classes) 158 | print('[i] Image size: ', td.preset.image_size) 159 | except (AttributeError, RuntimeError) as e: 160 | print('[!] 
Unable to load training data:', str(e)) 161 | return 1 162 | 163 | #--------------------------------------------------------------------------- 164 | # Create the network 165 | #--------------------------------------------------------------------------- 166 | with tf.Session() as sess: 167 | print('[i] Creating the model...') 168 | n_train_batches = int(math.ceil(td.num_train/args.batch_size)) 169 | n_valid_batches = int(math.ceil(td.num_valid/args.batch_size)) 170 | 171 | global_step = None 172 | if start_epoch == 0: 173 | lr_values = args.lr_values.split(';') 174 | try: 175 | lr_values = [float(x) for x in lr_values] 176 | except ValueError: 177 | print('[!] Learning rate values must be floats') 178 | sys.exit(1) 179 | 180 | lr_boundaries = args.lr_boundaries.split(';') 181 | try: 182 | lr_boundaries = [int(x) for x in lr_boundaries] 183 | except ValueError: 184 | print('[!] Learning rate boundaries must be ints') 185 | sys.exit(1) 186 | 187 | ret = compute_lr(lr_values, lr_boundaries) 188 | learning_rate, global_step = ret 189 | 190 | net = SSDVGG(sess, td.preset) 191 | if start_epoch != 0: 192 | net.build_from_metagraph(metagraph_file, checkpoint_file) 193 | net.build_optimizer_from_metagraph() 194 | else: 195 | net.build_from_vgg(args.vgg_dir, td.num_classes) 196 | net.build_optimizer(learning_rate=learning_rate, 197 | global_step=global_step, 198 | weight_decay=args.weight_decay, 199 | momentum=args.momentum) 200 | 201 | initialize_uninitialized_variables(sess) 202 | 203 | #----------------------------------------------------------------------- 204 | # Create various helpers 205 | #----------------------------------------------------------------------- 206 | summary_writer = tf.summary.FileWriter(args.tensorboard_dir, 207 | sess.graph) 208 | saver = tf.train.Saver(max_to_keep=20) 209 | 210 | anchors = get_anchors_for_preset(td.preset) 211 | training_ap_calc = APCalculator() 212 | validation_ap_calc = APCalculator() 213 | 214 | #----------------------------------------------------------------------- 215 | # Summaries 216 | #----------------------------------------------------------------------- 217 | restore = start_epoch != 0 218 | 219 | training_ap = PrecisionSummary(sess, summary_writer, 'training', 220 | td.lname2id.keys(), restore) 221 | validation_ap = PrecisionSummary(sess, summary_writer, 'validation', 222 | td.lname2id.keys(), restore) 223 | 224 | training_imgs = ImageSummary(sess, summary_writer, 'training', 225 | td.label_colors, restore) 226 | validation_imgs = ImageSummary(sess, summary_writer, 'validation', 227 | td.label_colors, restore) 228 | 229 | training_loss = LossSummary(sess, summary_writer, 'training', 230 | td.num_train, restore) 231 | validation_loss = LossSummary(sess, summary_writer, 'validation', 232 | td.num_valid, restore) 233 | 234 | #----------------------------------------------------------------------- 235 | # Get the initial snapshot of the network 236 | #----------------------------------------------------------------------- 237 | net_summary_ops = net.build_summaries(restore) 238 | if start_epoch == 0: 239 | net_summary = sess.run(net_summary_ops) 240 | summary_writer.add_summary(net_summary, 0) 241 | summary_writer.flush() 242 | 243 | #----------------------------------------------------------------------- 244 | # Cycle through the epoch 245 | #----------------------------------------------------------------------- 246 | print('[i] Training...') 247 | for e in range(start_epoch, args.epochs): 248 | training_imgs_samples = [] 249 | 
validation_imgs_samples = [] 250 | 251 | #------------------------------------------------------------------- 252 | # Train 253 | #------------------------------------------------------------------- 254 | generator = td.train_generator(args.batch_size, args.num_workers) 255 | description = '[i] Train {:>2}/{}'.format(e+1, args.epochs) 256 | for x, y, gt_boxes in tqdm(generator, total=n_train_batches, 257 | desc=description, unit='batches'): 258 | 259 | if len(training_imgs_samples) < 3: 260 | saved_images = np.copy(x[:3]) 261 | 262 | feed = {net.image_input: x, 263 | net.labels: y} 264 | result, loss_batch, _ = sess.run([net.result, net.losses, 265 | net.optimizer], 266 | feed_dict=feed) 267 | 268 | if math.isnan(loss_batch['confidence']): 269 | print('[!] Confidence loss is NaN.') 270 | 271 | training_loss.add(loss_batch, x.shape[0]) 272 | 273 | if e == 0: continue # skip AP and image summaries on the first epoch 274 | 275 | for i in range(result.shape[0]): 276 | boxes = decode_boxes(result[i], anchors, 0.5, td.lid2name) 277 | boxes = suppress_overlaps(boxes) 278 | training_ap_calc.add_detections(gt_boxes[i], boxes) 279 | 280 | if len(training_imgs_samples) < 3: 281 | training_imgs_samples.append((saved_images[i], boxes)) 282 | 283 | #------------------------------------------------------------------- 284 | # Validate 285 | #------------------------------------------------------------------- 286 | generator = td.valid_generator(args.batch_size, args.num_workers) 287 | description = '[i] Valid {:>2}/{}'.format(e+1, args.epochs) 288 | 289 | for x, y, gt_boxes in tqdm(generator, total=n_valid_batches, 290 | desc=description, unit='batches'): 291 | feed = {net.image_input: x, 292 | net.labels: y} 293 | result, loss_batch = sess.run([net.result, net.losses], 294 | feed_dict=feed) 295 | 296 | validation_loss.add(loss_batch, x.shape[0]) 297 | 298 | if e == 0: continue # skip AP and image summaries on the first epoch 299 | 300 | for i in range(result.shape[0]): 301 | boxes = decode_boxes(result[i], anchors, 0.5, td.lid2name) 302 | boxes = suppress_overlaps(boxes) 303 | validation_ap_calc.add_detections(gt_boxes[i], boxes) 304 | 305 | if len(validation_imgs_samples) < 3: 306 | validation_imgs_samples.append((np.copy(x[i]), boxes)) 307 | 308 | #------------------------------------------------------------------- 309 | # Write summaries 310 | #------------------------------------------------------------------- 311 | training_loss.push(e+1) 312 | validation_loss.push(e+1) 313 | 314 | net_summary = sess.run(net_summary_ops) 315 | summary_writer.add_summary(net_summary, e+1) 316 | 317 | APs = training_ap_calc.compute_aps() 318 | mAP = APs2mAP(APs) 319 | training_ap.push(e+1, mAP, APs) 320 | 321 | APs = validation_ap_calc.compute_aps() 322 | mAP = APs2mAP(APs) 323 | validation_ap.push(e+1, mAP, APs) 324 | 325 | training_ap_calc.clear() 326 | validation_ap_calc.clear() 327 | 328 | training_imgs.push(e+1, training_imgs_samples) 329 | validation_imgs.push(e+1, validation_imgs_samples) 330 | 331 | summary_writer.flush() 332 | 333 | #------------------------------------------------------------------- 334 | # Save a checkpoint 335 | #------------------------------------------------------------------- 336 | if (e+1) % args.checkpoint_interval == 0: 337 | checkpoint = '{}/e{}.ckpt'.format(args.name, e+1) 338 | saver.save(sess, checkpoint) 339 | print('[i] Checkpoint saved:', checkpoint) 340 | 341 | checkpoint = '{}/final.ckpt'.format(args.name) 342 | saver.save(sess, checkpoint) 343 | print('[i] Checkpoint saved:', checkpoint) 344 | 345 | return 0 346 | 347 | if __name__ == '__main__': 348 | 
sys.exit(main()) 349 | -------------------------------------------------------------------------------- /training_data.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------- 2 | # Author: Lukasz Janyst 3 | # Date: 09.09.2017 4 | #------------------------------------------------------------------------------- 5 | # This file is part of SSD-TensorFlow. 6 | # 7 | # SSD-TensorFlow is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 11 | # 12 | # SSD-TensorFlow is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU General Public License 18 | # along with SSD-Tensorflow. If not, see . 19 | #------------------------------------------------------------------------------- 20 | 21 | import pickle 22 | import random 23 | import math 24 | import cv2 25 | import os 26 | 27 | import multiprocessing as mp 28 | import numpy as np 29 | import queue as q 30 | 31 | from data_queue import DataQueue 32 | from copy import copy 33 | 34 | #------------------------------------------------------------------------------- 35 | class TrainingData: 36 | #--------------------------------------------------------------------------- 37 | def __init__(self, data_dir): 38 | #----------------------------------------------------------------------- 39 | # Read the dataset info 40 | #----------------------------------------------------------------------- 41 | try: 42 | with open(data_dir+'/training-data.pkl', 'rb') as f: 43 | data = pickle.load(f) 44 | with open(data_dir+'/train-samples.pkl', 'rb') as f: 45 | train_samples = pickle.load(f) 46 | with open(data_dir+'/valid-samples.pkl', 'rb') as f: 47 | valid_samples = pickle.load(f) 48 | except (FileNotFoundError, IOError) as e: 49 | raise RuntimeError(str(e)) 50 | 51 | nones = [None] * len(train_samples) 52 | train_samples = list(zip(nones, nones, train_samples)) 53 | nones = [None] * len(valid_samples) 54 | valid_samples = list(zip(nones, nones, valid_samples)) 55 | 56 | #----------------------------------------------------------------------- 57 | # Set the attributes up 58 | #----------------------------------------------------------------------- 59 | self.preset = data['preset'] 60 | self.num_classes = data['num-classes'] 61 | self.label_colors = data['colors'] 62 | self.lid2name = data['lid2name'] 63 | self.lname2id = data['lname2id'] 64 | self.train_tfs = data['train-transforms'] 65 | self.valid_tfs = data['valid-transforms'] 66 | self.train_generator = self.__batch_generator(train_samples, 67 | self.train_tfs) 68 | self.valid_generator = self.__batch_generator(valid_samples, 69 | self.valid_tfs) 70 | self.num_train = len(train_samples) 71 | self.num_valid = len(valid_samples) 72 | self.train_samples = list(map(lambda x: x[2], train_samples)) 73 | self.valid_samples = list(map(lambda x: x[2], valid_samples)) 74 | 75 | #--------------------------------------------------------------------------- 76 | def __batch_generator(self, sample_list_, transforms): 77 | image_size = (self.preset.image_size.w, self.preset.image_size.h) 78 | 79 | 
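        # Added note: this method returns the gen_batch closure defined
        # below, so self.train_generator and self.valid_generator are
        # callables invoked as generator(batch_size, num_workers); every
        # call reshuffles the sample list and yields one epoch worth of
        # (images, labels, gt_boxes) batches.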
#----------------------------------------------------------------------- 80 | def run_transforms(sample): 81 | args = sample 82 | for t in transforms: 83 | args = t(*args) 84 | return args 85 | 86 | #----------------------------------------------------------------------- 87 | def process_samples(samples): 88 | images = [] 89 | labels = [] 90 | gt_boxes = [] 91 | for s in samples: 92 | done = False 93 | counter = 0 94 | while not done and counter < 50: 95 | image, label, gt = run_transforms(s) 96 | num_bg = np.count_nonzero(label[:, self.num_classes]) 97 | done = num_bg < label.shape[0] 98 | counter += 1 99 | 100 | images.append(image.astype(np.float32)) 101 | labels.append(label.astype(np.float32)) 102 | gt_boxes.append(gt.boxes) 103 | 104 | images = np.array(images, dtype=np.float32) 105 | labels = np.array(labels, dtype=np.float32) 106 | return images, labels, gt_boxes 107 | 108 | #----------------------------------------------------------------------- 109 | def batch_producer(sample_queue, batch_queue): 110 | while True: 111 | #--------------------------------------------------------------- 112 | # Process the sample 113 | #--------------------------------------------------------------- 114 | try: 115 | samples = sample_queue.get(timeout=1) 116 | except q.Empty: 117 | break 118 | 119 | images, labels, gt_boxes = process_samples(samples) 120 | 121 | #--------------------------------------------------------------- 122 | # Pad the result in the case where we don't have enough samples 123 | # to fill the entire batch 124 | #--------------------------------------------------------------- 125 | if images.shape[0] < batch_queue.img_shape[0]: 126 | images_norm = np.zeros(batch_queue.img_shape, 127 | dtype=np.float32) 128 | labels_norm = np.zeros(batch_queue.label_shape, 129 | dtype=np.float32) 130 | images_norm[:images.shape[0]] = images 131 | labels_norm[:images.shape[0]] = labels 132 | batch_queue.put(images_norm, labels_norm, gt_boxes) 133 | else: 134 | batch_queue.put(images, labels, gt_boxes) 135 | 136 | #----------------------------------------------------------------------- 137 | def gen_batch(batch_size, num_workers=0): 138 | sample_list = copy(sample_list_) 139 | random.shuffle(sample_list) 140 | 141 | #------------------------------------------------------------------- 142 | # Set up the parallel generator 143 | #------------------------------------------------------------------- 144 | if num_workers > 0: 145 | #--------------------------------------------------------------- 146 | # Set up the queues 147 | #--------------------------------------------------------------- 148 | img_template = np.zeros((batch_size, self.preset.image_size.h, 149 | self.preset.image_size.w, 3), 150 | dtype=np.float32) 151 | label_template = np.zeros((batch_size, self.preset.num_anchors, 152 | self.num_classes+5), 153 | dtype=np.float32) 154 | max_size = num_workers*5 155 | n_batches = int(math.ceil(len(sample_list_)/batch_size)) 156 | sample_queue = mp.Queue(n_batches) 157 | batch_queue = DataQueue(img_template, label_template, max_size) 158 | 159 | #--------------------------------------------------------------- 160 | # Set up the workers. Make sure we can fork safely even if 161 | # OpenCV has been compiled with CUDA and multi-threading 162 | # support. 
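                # Added note: hiding the GPUs and clamping OpenCV to a
                # single thread for the duration of the fork() calls keeps
                # the children from inheriting CUDA contexts and thread
                # pools; both settings are restored right after the
                # workers are started.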
163 | #--------------------------------------------------------------- 164 | workers = [] 165 | os.environ['CUDA_VISIBLE_DEVICES'] = "" 166 | cv2_num_threads = cv2.getNumThreads() 167 | cv2.setNumThreads(1) 168 | for i in range(num_workers): 169 | args = (sample_queue, batch_queue) 170 | w = mp.Process(target=batch_producer, args=args) 171 | workers.append(w) 172 | w.start() 173 | del os.environ['CUDA_VISIBLE_DEVICES'] 174 | cv2.setNumThreads(cv2_num_threads) 175 | 176 | #--------------------------------------------------------------- 177 | # Fill the sample queue with data 178 | #--------------------------------------------------------------- 179 | for offset in range(0, len(sample_list), batch_size): 180 | samples = sample_list[offset:offset+batch_size] 181 | sample_queue.put(samples) 182 | 183 | #--------------------------------------------------------------- 184 | # Return the data 185 | #--------------------------------------------------------------- 186 | for offset in range(0, len(sample_list), batch_size): 187 | images, labels, gt_boxes = batch_queue.get() 188 | num_items = len(gt_boxes) 189 | yield images[:num_items], labels[:num_items], gt_boxes 190 | 191 | #--------------------------------------------------------------- 192 | # Join the workers 193 | #--------------------------------------------------------------- 194 | for w in workers: 195 | w.join() 196 | 197 | #------------------------------------------------------------------- 198 | # Return a serial generator 199 | #------------------------------------------------------------------- 200 | else: 201 | for offset in range(0, len(sample_list), batch_size): 202 | samples = sample_list[offset:offset+batch_size] 203 | images, labels, gt_boxes = process_samples(samples) 204 | yield images, labels, gt_boxes 205 | 206 | return gen_batch 207 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------- 2 | # Author: Lukasz Janyst 3 | # Date: 18.09.2017 4 | #------------------------------------------------------------------------------- 5 | # This file is part of SSD-TensorFlow. 6 | # 7 | # SSD-TensorFlow is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 11 | # 12 | # SSD-TensorFlow is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU General Public License 18 | # along with SSD-Tensorflow. If not, see . 
19 | #------------------------------------------------------------------------------- 20 | 21 | import cv2 22 | import random 23 | 24 | import numpy as np 25 | 26 | from ssdutils import get_anchors_for_preset, get_preset_by_name, anchors2array 27 | from ssdutils import box2array, compute_overlap, compute_location 28 | from utils import Size, Sample, Point, Box, abs2prop, prop2abs 29 | from math import sqrt 30 | 31 | #------------------------------------------------------------------------------- 32 | class Transform: 33 | def __init__(self, **kwargs): 34 | for arg, val in kwargs.items(): 35 | setattr(self, arg, val) 36 | self.initialized = False 37 | 38 | #------------------------------------------------------------------------------- 39 | class ImageLoaderTransform(Transform): 40 | """ 41 | Load an image from the file specified in the Sample object 42 | """ 43 | def __call__(self, data, label, gt): 44 | return cv2.imread(gt.filename), label, gt 45 | 46 | #------------------------------------------------------------------------------- 47 | def process_overlap(overlap, box, anchor, matches, num_classes, vec): 48 | if overlap.idx in matches and matches[overlap.idx] >= overlap.score: 49 | return 50 | 51 | matches[overlap.idx] = overlap.score 52 | vec[overlap.idx, 0:num_classes+1] = 0 53 | vec[overlap.idx, box.labelid] = 1 54 | vec[overlap.idx, num_classes+1:] = compute_location(box, anchor) 55 | 56 | #------------------------------------------------------------------------------- 57 | class LabelCreatorTransform(Transform): 58 | """ 59 | Create a label vector out of a ground truth sample 60 | Parameters: preset, num_classes 61 | """ 62 | #--------------------------------------------------------------------------- 63 | def initialize(self): 64 | self.anchors = get_anchors_for_preset(self.preset) 65 | self.vheight = len(self.anchors) 66 | self.vwidth = self.num_classes+5 # background class + location offsets 67 | self.img_size = Size(1000, 1000) 68 | self.anchors_arr = anchors2array(self.anchors, self.img_size) 69 | self.initialized = True 70 | 71 | #--------------------------------------------------------------------------- 72 | def __call__(self, data, label, gt): 73 | #----------------------------------------------------------------------- 74 | # Initialize the data vector and other variables 75 | #----------------------------------------------------------------------- 76 | if not self.initialized: 77 | self.initialize() 78 | 79 | vec = np.zeros((self.vheight, self.vwidth), dtype=np.float32) 80 | 81 | #----------------------------------------------------------------------- 82 | # For every box compute the best match and all the matches above 0.5 83 | # Jaccard overlap 84 | #----------------------------------------------------------------------- 85 | overlaps = {} 86 | for box in gt.boxes: 87 | box_arr = box2array(box, self.img_size) 88 | overlaps[box] = compute_overlap(box_arr, self.anchors_arr, 0.5) 89 | 90 | #----------------------------------------------------------------------- 91 | # Set up the training vector, resolving conflicts in favor of the 92 | # better match 93 | #----------------------------------------------------------------------- 94 | vec[:, self.num_classes] = 1 # background class 95 | vec[:, self.num_classes+1] = 0 # x offset 96 | vec[:, self.num_classes+2] = 0 # y offset 97 | vec[:, self.num_classes+3] = 0 # log width scale 98 | vec[:, self.num_classes+4] = 0 # log height scale 99 | 100 | matches = {} 101 | for box in gt.boxes: 102 | for overlap in overlaps[box].good:
103 | anchor = self.anchors[overlap.idx] 104 | process_overlap(overlap, box, anchor, matches, self.num_classes, vec) 105 | 106 | matches = {} 107 | for box in gt.boxes: 108 | overlap = overlaps[box].best 109 | if not overlap: 110 | continue 111 | anchor = self.anchors[overlap.idx] 112 | process_overlap(overlap, box, anchor, matches, self.num_classes, vec) 113 | 114 | return data, vec, gt 115 | 116 | #------------------------------------------------------------------------------- 117 | class ResizeTransform(Transform): 118 | """ 119 | Resize an image 120 | Parameters: width, height, algorithms 121 | """ 122 | def __call__(self, data, label, gt): 123 | alg = random.choice(self.algorithms) 124 | resized = cv2.resize(data, (self.width, self.height), interpolation=alg) 125 | return resized, label, gt 126 | 127 | #------------------------------------------------------------------------------- 128 | class RandomTransform(Transform): 129 | """ 130 | Call another transform with a given probability 131 | Parameters: prob, transform 132 | """ 133 | def __call__(self, data, label, gt): 134 | p = random.uniform(0, 1) 135 | if p < self.prob: 136 | return self.transform(data, label, gt) 137 | return data, label, gt 138 | 139 | #------------------------------------------------------------------------------- 140 | class ComposeTransform(Transform): 141 | """ 142 | Call a bunch of transforms serially 143 | Parameters: transforms 144 | """ 145 | def __call__(self, data, label, gt): 146 | args = (data, label, gt) 147 | for t in self.transforms: 148 | args = t(*args) 149 | return args 150 | 151 | #------------------------------------------------------------------------------- 152 | class TransformPickerTransform(Transform): 153 | """ 154 | Call a randomly chosen transform from the list 155 | Parameters: transforms 156 | """ 157 | def __call__(self, data, label, gt): 158 | pick = random.randint(0, len(self.transforms)-1) 159 | return self.transforms[pick](data, label, gt) 160 | 161 | #------------------------------------------------------------------------------- 162 | class BrightnessTransform(Transform): 163 | """ 164 | Transform brightness 165 | Parameters: delta 166 | """ 167 | def __call__(self, data, label, gt): 168 | data = data.astype(np.float32) 169 | delta = random.randint(-self.delta, self.delta) 170 | data += delta 171 | data[data>255] = 255 172 | data[data<0] = 0 173 | data = data.astype(np.uint8) 174 | return data, label, gt 175 | 176 | #------------------------------------------------------------------------------- 177 | class ContrastTransform(Transform): 178 | """ 179 | Transform contrast 180 | Parameters: lower, upper 181 | """ 182 | def __call__(self, data, label, gt): 183 | data = data.astype(np.float32) 184 | delta = random.uniform(self.lower, self.upper) 185 | data *= delta 186 | data[data>255] = 255 187 | data[data<0] = 0 188 | data = data.astype(np.uint8) 189 | return data, label, gt 190 | 191 | #------------------------------------------------------------------------------- 192 | class HueTransform(Transform): 193 | """ 194 | Transform hue 195 | Parameters: delta 196 | """ 197 | def __call__(self, data, label, gt): 198 | data = cv2.cvtColor(data, cv2.COLOR_BGR2HSV) 199 | data = data.astype(np.float32) 200 | delta = random.randint(-self.delta, self.delta) 201 | data[:, :, 0] += delta # hue is channel 0; OpenCV hue range is [0, 180) 202 | data[:, :, 0][data[:, :, 0] > 180] -= 180 203 | data[:, :, 0][data[:, :, 0] < 0] += 180 204 | data = data.astype(np.uint8) 205 | data = cv2.cvtColor(data, cv2.COLOR_HSV2BGR) 206 | return data, label, gt 207 | 
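#-------------------------------------------------------------------------------
# Illustrative sketch (added; not part of the original pipeline, and the
# prob/delta/lower/upper values below are arbitrary): the photometric
# transforms above are meant to be wrapped in RandomTransform and chained
# with ComposeTransform, along the lines of:
#
#   photometric = ComposeTransform(transforms=[
#       RandomTransform(prob=0.5, transform=BrightnessTransform(delta=32)),
#       RandomTransform(prob=0.5, transform=ContrastTransform(lower=0.5,
#                                                             upper=1.5))])
#   data, label, gt = photometric(data, label, gt)
#-------------------------------------------------------------------------------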
208 | #------------------------------------------------------------------------------- 209 | class SaturationTransform(Transform): 210 | """ 211 | Transform saturation 212 | Parameters: lower, upper 213 | """ 214 | def __call__(self, data, label, gt): 215 | data = cv2.cvtColor(data, cv2.COLOR_BGR2HSV) 216 | data = data.astype(np.float32) 217 | delta = random.uniform(self.lower, self.upper) 218 | data[:, :, 1] *= delta # saturation is channel 1 219 | data[:, :, 1][data[:, :, 1] > 255] = 255 220 | data[:, :, 1][data[:, :, 1] < 0] = 0 221 | data = data.astype(np.uint8) 222 | data = cv2.cvtColor(data, cv2.COLOR_HSV2BGR) 223 | return data, label, gt 224 | 225 | #------------------------------------------------------------------------------- 226 | class ReorderChannelsTransform(Transform): 227 | """ 228 | Reorder Image Channels 229 | """ 230 | def __call__(self, data, label, gt): 231 | channels = [0, 1, 2] 232 | random.shuffle(channels) 233 | return data[:, :, channels], label, gt 234 | 235 | #------------------------------------------------------------------------------- 236 | def transform_box(box, orig_size, new_size, h_off, w_off): 237 | #--------------------------------------------------------------------------- 238 | # Compute the new coordinates of the box 239 | #--------------------------------------------------------------------------- 240 | xmin, xmax, ymin, ymax = prop2abs(box.center, box.size, orig_size) 241 | xmin += w_off 242 | xmax += w_off 243 | ymin += h_off 244 | ymax += h_off 245 | 246 | #--------------------------------------------------------------------------- 247 | # Check if the center falls within the image 248 | #--------------------------------------------------------------------------- 249 | width = xmax - xmin 250 | height = ymax - ymin 251 | new_cx = xmin + int(width/2) 252 | new_cy = ymin + int(height/2) 253 | if new_cx < 0 or new_cx >= new_size.w: 254 | return None 255 | if new_cy < 0 or new_cy >= new_size.h: 256 | return None 257 | 258 | center, size = abs2prop(xmin, xmax, ymin, ymax, new_size) 259 | return Box(box.label, box.labelid, center, size) 260 | 261 | #------------------------------------------------------------------------------- 262 | def transform_gt(gt, new_size, h_off, w_off): 263 | boxes = [] 264 | for box in gt.boxes: 265 | box = transform_box(box, gt.imgsize, new_size, h_off, w_off) 266 | if box is None: 267 | continue 268 | boxes.append(box) 269 | return Sample(gt.filename, boxes, new_size) 270 | 271 | #------------------------------------------------------------------------------- 272 | class ExpandTransform(Transform): 273 | """ 274 | Expand the image and fill the empty space with the mean value 275 | Parameters: max_ratio, mean_value 276 | """ 277 | def __call__(self, data, label, gt): 278 | #----------------------------------------------------------------------- 279 | # Calculate sizes and offsets 280 | #----------------------------------------------------------------------- 281 | ratio = random.uniform(1, self.max_ratio) 282 | orig_size = gt.imgsize 283 | new_size = Size(int(orig_size.w*ratio), int(orig_size.h*ratio)) 284 | h_off = random.randint(0, new_size.h-orig_size.h) 285 | w_off = random.randint(0, new_size.w-orig_size.w) 286 | 287 | #----------------------------------------------------------------------- 288 | # Create the new image and place the input image in it 289 | #----------------------------------------------------------------------- 290 | img = np.zeros((new_size.h, new_size.w, 3)) 291 | img[:, :] = np.array(self.mean_value) 292 | img[h_off:h_off+orig_size.h, w_off:w_off+orig_size.w, :] = data 293 
294 |         #-----------------------------------------------------------------------
295 |         # Transform the ground truth
296 |         #-----------------------------------------------------------------------
297 |         gt = transform_gt(gt, new_size, h_off, w_off)
298 | 
299 |         return img, label, gt
300 | 
301 | #-------------------------------------------------------------------------------
302 | class SamplerTransform(Transform):
303 |     """
304 |     Sample a fraction of the image according to given parameters
305 |     Parameters: min_scale, max_scale, min_aspect_ratio, max_aspect_ratio,
306 |                 min_jaccard_overlap
307 |     """
308 |     def __call__(self, data, label, gt):
309 |         #-----------------------------------------------------------------------
310 |         # Check whether to sample or not
311 |         #-----------------------------------------------------------------------
312 |         if not self.sample:
313 |             return data, label, gt
314 | 
315 |         #-----------------------------------------------------------------------
316 |         # Retry sampling a couple of times
317 |         #-----------------------------------------------------------------------
318 |         source_boxes = anchors2array(gt.boxes, gt.imgsize)
319 |         box = None
320 |         box_arr = None
321 |         for _ in range(self.max_trials):
322 |             #-------------------------------------------------------------------
323 |             # Sample a bounding box
324 |             #-------------------------------------------------------------------
325 |             scale = random.uniform(self.min_scale, self.max_scale)
326 |             aspect_ratio = random.uniform(self.min_aspect_ratio,
327 |                                           self.max_aspect_ratio)
328 | 
329 |             # make sure width and height will not be larger than 1
330 |             aspect_ratio = max(aspect_ratio, scale**2)
331 |             aspect_ratio = min(aspect_ratio, 1/(scale**2))
332 | 
333 |             width = scale*sqrt(aspect_ratio)
334 |             height = scale/sqrt(aspect_ratio)
335 |             cx = 0.5*width + random.uniform(0, 1-width)
336 |             cy = 0.5*height + random.uniform(0, 1-height)
337 |             center = Point(cx, cy)
338 |             size = Size(width, height)
339 | 
340 |             #-------------------------------------------------------------------
341 |             # Check if the box satisfies the jaccard overlap constraint
342 |             #-------------------------------------------------------------------
343 |             box_arr = np.array(prop2abs(center, size, gt.imgsize))
344 |             overlap = compute_overlap(box_arr, source_boxes, 0)
345 |             if overlap.best and overlap.best.score >= self.min_jaccard_overlap:
346 |                 box = Box(None, None, center, size)
347 |                 break
348 | 
349 |         if box is None:
350 |             return None
351 | 
352 |         #-----------------------------------------------------------------------
353 |         # Crop the box and adjust the ground truth
354 |         #-----------------------------------------------------------------------
355 |         new_size = Size(box_arr[1]-box_arr[0], box_arr[3]-box_arr[2])
356 |         w_off = -box_arr[0]
357 |         h_off = -box_arr[2]
358 |         data = data[box_arr[2]:box_arr[3], box_arr[0]:box_arr[1]]
359 |         gt = transform_gt(gt, new_size, h_off, w_off)
360 | 
361 |         return data, label, gt
362 | 
363 | #-------------------------------------------------------------------------------
364 | class SamplePickerTransform(Transform):
365 |     """
366 |     Run a bunch of sample transforms and return one of the produced samples
367 |     Parameters: samplers
368 |     """
369 |     def __call__(self, data, label, gt):
370 |         samples = []
371 |         for sampler in self.samplers:
372 |             sample = sampler(data, label, gt)
373 |             if sample is not None:
374 |                 samples.append(sample)
375 |         return random.choice(samples)
376 | 
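#-------------------------------------------------------------------------------
# Illustrative sketch (an assumption, not code from this repository): the
# sampler classes above are meant to be composed the way the SSD paper does it,
# one pass-through sampler plus several jaccard-constrained croppers, with
# SamplePickerTransform picking one of the produced candidates at random.
# Assuming the Transform base class stores its keyword arguments as attributes
# (as the "Parameters:" docstrings suggest), the composition could look roughly
# like this; the parameter values below are made up for the example:
#
#     samplers = [SamplerTransform(sample=False)]
#     for overlap in [0.1, 0.3, 0.5, 0.7, 0.9]:
#         samplers.append(SamplerTransform(sample=True, max_trials=50,
#                                          min_scale=0.3, max_scale=1.0,
#                                          min_aspect_ratio=0.5,
#                                          max_aspect_ratio=2.0,
#                                          min_jaccard_overlap=overlap))
#     picker = SamplePickerTransform(samplers=samplers)
#     data, label, gt = picker(data, label, gt)
#
# The pass-through sampler never returns None, so the candidate list handed to
# random.choice() in SamplePickerTransform is never empty.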
377 | #-------------------------------------------------------------------------------
378 | class HorizontalFlipTransform(Transform):
379 |     """
380 |     Horizontally flip the image
381 |     """
382 |     def __call__(self, data, label, gt):
383 |         data = cv2.flip(data, 1)
384 |         boxes = []
385 |         for box in gt.boxes:
386 |             center = Point(1-box.center.x, box.center.y)
387 |             box = Box(box.label, box.labelid, center, box.size)
388 |             boxes.append(box)
389 |         gt = Sample(gt.filename, boxes, gt.imgsize)
390 | 
391 |         return data, label, gt
392 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #-------------------------------------------------------------------------------
2 | # Author: Lukasz Janyst
3 | # Date: 29.08.2017
4 | #-------------------------------------------------------------------------------
5 | # This file is part of SSD-TensorFlow.
6 | #
7 | # SSD-TensorFlow is free software: you can redistribute it and/or modify
8 | # it under the terms of the GNU General Public License as published by
9 | # the Free Software Foundation, either version 3 of the License, or
10 | # (at your option) any later version.
11 | #
12 | # SSD-TensorFlow is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | # GNU General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU General Public License
18 | # along with SSD-TensorFlow. If not, see <http://www.gnu.org/licenses/>.
19 | #-------------------------------------------------------------------------------
20 | 
21 | import argparse
22 | import math
23 | import cv2
24 | 
25 | import tensorflow as tf
26 | import numpy as np
27 | 
28 | from collections import namedtuple
29 | 
30 | #-------------------------------------------------------------------------------
31 | def initialize_uninitialized_variables(sess):
32 |     """
33 |     Only initialize the weights that have not yet been initialized by other
34 |     means, such as importing a metagraph and a checkpoint. It's useful when
35 |     extending an existing model.
36 | """ 37 | uninit_vars = [] 38 | uninit_tensors = [] 39 | for var in tf.global_variables(): 40 | uninit_vars.append(var) 41 | uninit_tensors.append(tf.is_variable_initialized(var)) 42 | uninit_bools = sess.run(uninit_tensors) 43 | uninit = zip(uninit_bools, uninit_vars) 44 | uninit = [var for init, var in uninit if not init] 45 | sess.run(tf.variables_initializer(uninit)) 46 | 47 | #------------------------------------------------------------------------------- 48 | def load_data_source(data_source): 49 | """ 50 | Load a data source given it's name 51 | """ 52 | source_module = __import__('source_'+data_source) 53 | get_source = getattr(source_module, 'get_source') 54 | return get_source() 55 | 56 | #------------------------------------------------------------------------------- 57 | def rgb2bgr(tpl): 58 | """ 59 | Convert RGB color tuple to BGR 60 | """ 61 | return (tpl[2], tpl[1], tpl[0]) 62 | 63 | #------------------------------------------------------------------------------- 64 | Label = namedtuple('Label', ['name', 'color']) 65 | Size = namedtuple('Size', ['w', 'h']) 66 | Point = namedtuple('Point', ['x', 'y']) 67 | Sample = namedtuple('Sample', ['filename', 'boxes', 'imgsize']) 68 | Box = namedtuple('Box', ['label', 'labelid', 'center', 'size']) 69 | Score = namedtuple('Score', ['idx', 'score']) 70 | Overlap = namedtuple('Overlap', ['best', 'good']) 71 | 72 | #------------------------------------------------------------------------------- 73 | def str2bool(v): 74 | """ 75 | Convert a string to a boolean 76 | """ 77 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 78 | return True 79 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 80 | return False 81 | else: 82 | raise argparse.ArgumentTypeError('Boolean value expected.') 83 | 84 | #------------------------------------------------------------------------------- 85 | def abs2prop(xmin, xmax, ymin, ymax, imgsize): 86 | """ 87 | Convert the absolute min-max box bound to proportional center-width bounds 88 | """ 89 | width = float(xmax-xmin) 90 | height = float(ymax-ymin) 91 | cx = float(xmin)+width/2 92 | cy = float(ymin)+height/2 93 | width /= imgsize.w 94 | height /= imgsize.h 95 | cx /= imgsize.w 96 | cy /= imgsize.h 97 | return Point(cx, cy), Size(width, height) 98 | 99 | #------------------------------------------------------------------------------- 100 | def prop2abs(center, size, imgsize): 101 | """ 102 | Convert proportional center-width bounds to absolute min-max bounds 103 | """ 104 | width2 = size.w*imgsize.w/2 105 | height2 = size.h*imgsize.h/2 106 | cx = center.x*imgsize.w 107 | cy = center.y*imgsize.h 108 | return int(cx-width2), int(cx+width2), int(cy-height2), int(cy+height2) 109 | 110 | #------------------------------------------------------------------------------- 111 | def box_is_valid(box): 112 | for x in [box.center.x, box.center.y, box.size.w, box.size.h]: 113 | if math.isnan(x) or math.isinf(x): 114 | return False 115 | return True 116 | 117 | #------------------------------------------------------------------------------- 118 | def normalize_box(box): 119 | if not box_is_valid(box): 120 | return box 121 | 122 | img_size = Size(1000, 1000) 123 | xmin, xmax, ymin, ymax = prop2abs(box.center, box.size, img_size) 124 | xmin = max(xmin, 0) 125 | xmax = min(xmax, img_size.w-1) 126 | ymin = max(ymin, 0) 127 | ymax = min(ymax, img_size.h-1) 128 | 129 | # this happens early in the training when box min and max are outside 130 | # of the image 131 | xmin = min(xmin, xmax) 132 | ymin = min(ymin, ymax) 133 | 134 | 
134 |     center, size = abs2prop(xmin, xmax, ymin, ymax, img_size)
135 |     return Box(box.label, box.labelid, center, size)
136 | 
137 | #-------------------------------------------------------------------------------
138 | def draw_box(img, box, color):
139 |     img_size = Size(img.shape[1], img.shape[0])
140 |     xmin, xmax, ymin, ymax = prop2abs(box.center, box.size, img_size)
141 |     img_box = np.copy(img)
142 |     cv2.rectangle(img_box, (xmin, ymin), (xmax, ymax), color, 2)
143 |     cv2.rectangle(img_box, (xmin-1, ymin), (xmax+1, ymin-20), color, cv2.FILLED)
144 |     font = cv2.FONT_HERSHEY_SIMPLEX
145 |     cv2.putText(img_box, box.label, (xmin+5, ymin-5), font, 0.5,
146 |                 (255, 255, 255), 1, cv2.LINE_AA)
147 |     alpha = 0.8
148 |     cv2.addWeighted(img_box, alpha, img, 1.-alpha, 0, img)
149 | 
150 | #-------------------------------------------------------------------------------
151 | class PrecisionSummary:
152 |     #---------------------------------------------------------------------------
153 |     def __init__(self, session, writer, sample_name, labels, restore=False):
154 |         self.session = session
155 |         self.writer = writer
156 |         self.labels = labels
157 | 
158 |         sess = session
159 |         ph_name = sample_name+'_mAP_ph'
160 |         sum_name = sample_name+'_mAP'
161 | 
162 |         if restore:
163 |             self.mAP_placeholder = sess.graph.get_tensor_by_name(ph_name+':0')
164 |             self.mAP_summary_op = sess.graph.get_tensor_by_name(sum_name+':0')
165 |         else:
166 |             self.mAP_placeholder = tf.placeholder(tf.float32, name=ph_name)
167 |             self.mAP_summary_op = tf.summary.scalar(sum_name,
168 |                                                     self.mAP_placeholder)
169 | 
170 |         self.placeholders = {}
171 |         self.summary_ops = {}
172 | 
173 |         for label in labels:
174 |             sum_name = sample_name+'_AP_'+label
175 |             ph_name = sample_name+'_AP_ph_'+label
176 |             if restore:
177 |                 placeholder = sess.graph.get_tensor_by_name(ph_name+':0')
178 |                 summary_op = sess.graph.get_tensor_by_name(sum_name+':0')
179 |             else:
180 |                 placeholder = tf.placeholder(tf.float32, name=ph_name)
181 |                 summary_op = tf.summary.scalar(sum_name, placeholder)
182 |             self.placeholders[label] = placeholder
183 |             self.summary_ops[label] = summary_op
184 | 
185 |     #---------------------------------------------------------------------------
186 |     def push(self, epoch, mAP, APs):
187 |         if not APs: return
188 | 
189 |         feed = {self.mAP_placeholder: mAP}
190 |         tensors = [self.mAP_summary_op]
191 |         for label in self.labels:
192 |             feed[self.placeholders[label]] = APs[label]
193 |             tensors.append(self.summary_ops[label])
194 | 
195 |         summaries = self.session.run(tensors, feed_dict=feed)
196 | 
197 |         for summary in summaries:
198 |             self.writer.add_summary(summary, epoch)
199 | 
200 | #-------------------------------------------------------------------------------
201 | class ImageSummary:
202 |     #---------------------------------------------------------------------------
203 |     def __init__(self, session, writer, sample_name, colors, restore=False):
204 |         self.session = session
205 |         self.writer = writer
206 |         self.colors = colors
207 | 
208 |         sess = session
209 |         sum_name = sample_name+'_img'
210 |         ph_name = sample_name+'_img_ph'
211 |         if restore:
212 |             self.img_placeholder = sess.graph.get_tensor_by_name(ph_name+':0')
213 |             self.img_summary_op = sess.graph.get_tensor_by_name(sum_name+':0')
214 |         else:
215 |             self.img_placeholder = tf.placeholder(tf.float32, name=ph_name,
216 |                                                   shape=[None, None, None, 3])
217 |             self.img_summary_op = tf.summary.image(sum_name,
218 |                                                    self.img_placeholder)
219 | 
220 |     #---------------------------------------------------------------------------
221 |     def push(self, epoch, samples):
222 |         imgs = np.zeros((3, 512, 512, 3))
223 |         for i, sample in enumerate(samples):
224 |             img = cv2.resize(sample[0], (512, 512))
225 |             for _, box in sample[1]:  # entries are (score, box) pairs
226 |                 draw_box(img, box, self.colors[box.label])
227 |             img[img>255] = 255
228 |             img[img<0] = 0
229 |             imgs[i] = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
230 | 
231 |         feed = {self.img_placeholder: imgs}
232 |         summary = self.session.run(self.img_summary_op, feed_dict=feed)
233 |         self.writer.add_summary(summary, epoch)
234 | 
235 | #-------------------------------------------------------------------------------
236 | class LossSummary:
237 |     #---------------------------------------------------------------------------
238 |     def __init__(self, session, writer, sample_name, num_samples,
239 |                  restore=False):
240 |         self.session = session
241 |         self.writer = writer
242 |         self.num_samples = num_samples
243 |         self.loss_names = ['total', 'localization', 'confidence', 'l2']
244 |         self.loss_values = {}
245 |         self.placeholders = {}
246 | 
247 |         sess = session
248 | 
249 |         summary_ops = []
250 |         for loss in self.loss_names:
251 |             sum_name = sample_name+'_'+loss+'_loss'
252 |             ph_name = sample_name+'_'+loss+'_loss_ph'
253 | 
254 |             if restore:
255 |                 placeholder = sess.graph.get_tensor_by_name(ph_name+':0')
256 |                 summary_op = sess.graph.get_tensor_by_name(sum_name+':0')
257 |             else:
258 |                 placeholder = tf.placeholder(tf.float32, name=ph_name)
259 |                 summary_op = tf.summary.scalar(sum_name, placeholder)
260 | 
261 |             self.loss_values[loss] = float(0)
262 |             self.placeholders[loss] = placeholder
263 |             summary_ops.append(summary_op)
264 | 
265 |         self.summary_ops = tf.summary.merge(summary_ops)
266 | 
267 |     #---------------------------------------------------------------------------
268 |     def add(self, values, num_samples):
269 |         for loss in self.loss_names:
270 |             self.loss_values[loss] += values[loss]*num_samples
271 | 
272 |     #---------------------------------------------------------------------------
273 |     def push(self, epoch):
274 |         feed = {}
275 |         for loss in self.loss_names:
276 |             feed[self.placeholders[loss]] = \
277 |                 self.loss_values[loss]/self.num_samples
278 | 
279 |         summary = self.session.run(self.summary_ops, feed_dict=feed)
280 |         self.writer.add_summary(summary, epoch)
281 | 
282 |         for loss in self.loss_names:
283 |             self.loss_values[loss] = float(0)
284 | 
--------------------------------------------------------------------------------
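A note on the coordinate helpers in utils.py: abs2prop and prop2abs define the
box convention used throughout the project. Boxes are stored as a proportional
center and size, so they are independent of image resolution; converting back
to absolute pixel bounds loses precision only through the final int()
truncation. A small worked example (illustrative only; the numbers are made
up):

    from utils import Size, abs2prop, prop2abs

    imgsize = Size(640, 480)
    # a 100x100-pixel box whose top-left corner sits at (100, 100)
    center, size = abs2prop(100, 200, 100, 200, imgsize)
    # center == Point(x=0.234375, y=0.3125)
    # size   == Size(w=0.15625, h=0.20833...)
    print(prop2abs(center, size, imgsize))   # -> (100, 200, 100, 200)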
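The summary classes at the end of utils.py wrap the TensorFlow 1.x
placeholder-and-scalar-summary plumbing so that the training loop only has to
feed plain Python numbers. A minimal usage sketch for LossSummary, assuming a
loop that yields a dict with the 'total', 'localization', 'confidence', and
'l2' keys; sess, writer, batches, num_train, and num_epochs are hypothetical
stand-ins for this example:

    summary = LossSummary(sess, writer, 'train', num_train)
    for epoch in range(num_epochs):
        for images, losses in batches:            # hypothetical data iterator
            summary.add(losses, images.shape[0])  # accumulate, weighted by batch size
        summary.push(epoch)                       # write epoch averages, then reset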