├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── data ├── coco.names ├── dog.jpg ├── dog_416416.png ├── eagle.jpg ├── giraffe.jpg ├── horses.jpg ├── person.jpg ├── scream.jpg ├── voc.data └── voc.names ├── dataset_factory ├── VOCDataset.py ├── VOCDataset.pyc ├── __init__.py └── __init__.pyc ├── detect.py ├── eval.py ├── models ├── __init__.py ├── __init__.pyc ├── yolo_v2.py ├── yolo_v2.pyc ├── yolo_v2_loss.py ├── yolo_v2_loss.pyc ├── yolo_v2_resnet.py └── yolo_v2_resnet.pyc ├── predictions.jpg ├── scripts ├── demo_detect.sh ├── demo_eval.sh ├── demo_start_tensorboard.sh ├── demo_train.sh └── demo_valid.sh ├── train.py ├── utils ├── __init__.py ├── __init__.pyc ├── cfg_loader.py ├── cfg_loader.pyc ├── iou.py ├── iou.pyc ├── logger.py ├── logger.pyc ├── nms.py └── nms.pyc └── valid.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | weights/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. 
By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. 
For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 
83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 
117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 
146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 
180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 
216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 
246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. 
If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 
309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 
336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 
360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. 
If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 
428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 
486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. 
If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | {one line to give the program's name and a brief idea of what it does.} 635 | Copyright (C) 2017 {name of author} 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | pytorch_detection Copyright (C) 2017 liaoyuhua 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <http://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <http://www.gnu.org/philosophy/why-not-lgpl.html>. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### pytorch_detection 2 | This is a PyTorch implementation of YOLO v2 which attempts to reproduce the results of the [project](https://pjreddie.com/darknet/yolo) and the [paper](https://arxiv.org/abs/1612.08242): YOLO9000: Better, Faster, Stronger by Joseph Redmon and Ali Farhadi.
3 | 4 | This project is based on the project [pytorch-yolo2](https://github.com/marvis/pytorch-yolo2) 5 | 6 | This repository tries to achieve the following goals 7 | - [x] implement yolo v2 forward network using config yolo-voc.cfg 8 | - [x] implement loading of darknet's [yolo-voc.weights](http://pjreddie.com/media/files/yolo-voc.weights) 9 | - [x] implement detect.py 10 | - [x] implement valid.py. This script produces results in the PASCAL VOC evaluation format. 11 | - [x] implement eval.py. 12 | - [x] implement darknet loss 13 | - [x] implement train.py. 14 | - [x] save as darknet weights 15 | - [x] support log to tensorboard 16 | - [x] support multi-gpu training 17 | - [x] add an image preprocessing step to boost model accuracy (mAP 0.7303 @20171106) 18 | - [ ] optimize code in yolo-v2 loss to reduce training time 19 | 20 | **NOTE:** 21 | This is still an experimental project. Model trained on VOC0712 train+val 22 | 23 | VOC07 test mAP is 0.5630 @20171019
24 | 25 | VOC07 test mAP is 0.7303 @20171106
26 | AP for aeroplane = 0.784
27 | AP for bicycle = 0.783
28 | AP for bird = 0.754
29 | AP for boat = 0.648
30 | AP for bottle = 0.481
31 | AP for bus = 0.777
32 | AP for car = 0.824
33 | AP for cat = 0.841
34 | AP for chair = 0.56
35 | AP for cow = 0.772
36 | AP for diningtable = 0.719
37 | AP for dog = 0.79
38 | AP for horse = 0.807
39 | AP for motorbike = 0.784
40 | AP for person = 0.753
41 | AP for pottedplant = 0.53
42 | AP for sheep = 0.765
43 | AP for sofa = 0.708
44 | AP for train = 0.818
45 | AP for tvmonitor = 0.709
46 | 47 | ### Detection Using a Pretrained Model 48 | ``` 49 | mkdir weights && cd weights 50 | wget http://pjreddie.com/media/files/yolo-voc.weights 51 | cd .. 52 | ./scripts/demo_detect.sh 53 | ``` 54 | 55 | ### Training YOLOv2 56 | You can train YOLOv2 on any dataset. Here we train on VOC2007/2012 train+val 57 | 1. Get the PASCAL VOC Data(2007trainval+2012trainval+2007test) 58 | ``` 59 | mkdir dataSet && cd dataSet 60 | wget https://pjreddie.com/media/files/VOCtrainval_11-May-2012.tar 61 | wget https://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar 62 | wget https://pjreddie.com/media/files/VOCtest_06-Nov-2007.tar 63 | tar xf VOCtrainval_11-May-2012.tar 64 | tar xf VOCtrainval_06-Nov-2007.tar 65 | tar xf VOCtest_06-Nov-2007.tar 66 | cd .. 67 | ``` 68 | 2. Generate Labels for VOC 69 | ``` 70 | cd dataSet 71 | wget http://pjreddie.com/media/files/voc_label.py 72 | python voc_label.py 73 | cat 2007_train.txt 2007_val.txt 2012_*.txt > voc_train.txt 74 | ``` 75 | 3. Modify data/voc.data for Pascal Data 76 | ``` 77 | train = dataSet/train.txt 78 | valid = dataSet/2007_test.txt 79 | names = data/voc.names 80 | backup = backup 81 | ``` 82 | 4. Download Pretrained Convolutional Weights 83 | ``` 84 | cd weights 85 | wget http://pjreddie.com/media/files/darknet19_448.conv.23 86 | cd .. 87 | ``` 88 | 5. Train The Model 89 | ``` 90 | ./scripts/demo_train.sh 91 | ``` 92 | 6. 
Evaluate The Model 93 | if you want to eval the model, please modify the result directory in demo_eval.sh after running demo_valid 94 | ``` 95 | ./scripts/demo_valid.sh 96 | ./scripts/demo_eval.sh 97 | ``` 98 | -------------------------------------------------------------------------------- /data/coco.names: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /data/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pluto16/pytorch_detection/4e8ebbba266fefaf78e66e6e9fef083ebd751a7b/data/dog.jpg -------------------------------------------------------------------------------- /data/dog_416416.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pluto16/pytorch_detection/4e8ebbba266fefaf78e66e6e9fef083ebd751a7b/data/dog_416416.png -------------------------------------------------------------------------------- /data/eagle.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pluto16/pytorch_detection/4e8ebbba266fefaf78e66e6e9fef083ebd751a7b/data/eagle.jpg -------------------------------------------------------------------------------- /data/giraffe.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pluto16/pytorch_detection/4e8ebbba266fefaf78e66e6e9fef083ebd751a7b/data/giraffe.jpg -------------------------------------------------------------------------------- /data/horses.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pluto16/pytorch_detection/4e8ebbba266fefaf78e66e6e9fef083ebd751a7b/data/horses.jpg -------------------------------------------------------------------------------- /data/person.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pluto16/pytorch_detection/4e8ebbba266fefaf78e66e6e9fef083ebd751a7b/data/person.jpg -------------------------------------------------------------------------------- /data/scream.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pluto16/pytorch_detection/4e8ebbba266fefaf78e66e6e9fef083ebd751a7b/data/scream.jpg -------------------------------------------------------------------------------- /data/voc.data: -------------------------------------------------------------------------------- 1 | train = ../dataSet/voc_train.txt 2 | valid = ../dataSet/2007_test.txt 3 | names = ./data/voc.names 4 | backup = backup 5 | gpus = 0 6 | 
def random_distort_image(img,hue,saturation,exposure):
    """Distort ``img`` with randomly drawn hue, saturation and exposure values.

    The hue shift is sampled uniformly from ``[-hue, hue]``; the saturation
    and exposure factors come from ``rand_scale``. The three values are
    handed to ``distort_image`` and its result is returned unchanged.
    """
    # Draw order (hue, saturation, exposure) is kept so that a seeded RNG
    # produces the same sequence of augmentations.
    hue_shift = npr.uniform(-hue, hue)
    sat_factor = rand_scale(saturation)
    exp_factor = rand_scale(exposure)
    return distort_image(img, hue_shift, sat_factor, exp_factor)
def fill_truth_detection(labpath,w,h,flip,dx,dy,sx,sy):
    """Load a label file and map its boxes into the augmented image frame.

    Each row of the label file holds five numbers; columns 1-4 are read as
    a box centre (x, y) and size (w, h), column 0 is copied through
    unchanged (presumably the class id -- confirm against the label writer).

    Args:
        labpath: path of the label text file.
        w, h: accepted for interface compatibility; not used here.
        flip: truthy when the image was mirrored horizontally.
        dx, dy: offsets subtracted from the scaled corner coordinates.
        sx, sy: scale factors applied to the corner coordinates.

    Returns:
        A flat float array of length ``30 * 5``. Rows that survive the
        transform are packed at the front, the rest stays zero. When
        ``labpath`` does not exist a message is printed and the array is
        all zeros.
    """
    max_boxes_per_img = 30
    label = np.zeros((max_boxes_per_img, 5))

    if not os.path.exists(labpath):
        # print() with a single argument behaves the same on Python 2 and 3.
        print("label path not exist!")
        return np.reshape(label, (-1))

    # reshape(-1, 5) also normalises the single-box (1-D) and empty cases.
    gts = np.reshape(np.loadtxt(labpath), (-1, 5))
    npr.shuffle(gts)

    kept = 0
    for gt in gts:
        # Centre/size -> corners, then apply the crop/scale and clamp.
        x1 = gt[1] - gt[3]/2
        y1 = gt[2] - gt[4]/2
        x2 = gt[1] + gt[3]/2
        y2 = gt[2] + gt[4]/2

        x1 = min(0.999, max(0, x1*sx - dx))
        y1 = min(0.999, max(0, y1*sy - dy))
        x2 = min(0.999, max(0, x2*sx - dx))
        y2 = min(0.999, max(0, y2*sy - dy))

        # Corners -> centre/size again (written back into the row view).
        gt[1] = (x1 + x2)/2
        gt[2] = (y1 + y2)/2
        gt[3] = x2 - x1
        gt[4] = y2 - y1

        if flip:
            gt[1] = 0.999 - gt[1]
        # Drop boxes that collapsed after clamping.
        if gt[3] < 0.002 or gt[4] < 0.002:
            continue

        label[kept] = gt
        kept += 1
        if kept >= max_boxes_per_img:
            break

    return np.reshape(label, (-1))
# h = label[i][4]*img.height 207 | # new_loc = [cx-w/2,cy-h/2,cx+w/2,cy+h/2] 208 | # out_im_draw.rectangle(new_loc,outline=(0,0,255)) 209 | # img.save('load_test_1.PNG','PNG') 210 | # label = np.reshape(label, (-1)) 211 | label = torch.from_numpy(label) 212 | 213 | 214 | 215 | else: 216 | if self.shape: 217 | img = img.resize(self.shape) 218 | label = torch.zeros(50*5) 219 | 220 | truths = read_truth_args(labelpath,8.0/img.width) 221 | 222 | 223 | #print "returned turthes {}".format(truths) 224 | tmp = torch.from_numpy(truths) 225 | 226 | tmp = tmp.view(-1) 227 | tsz = tmp.numel() 228 | 229 | if tsz >50*5: 230 | print ("warning labeled object morn than %d" %(50)) 231 | label = tmp[0:50*5] 232 | else: 233 | label[0:tsz] = tmp 234 | 235 | 236 | if self.transform is not None: 237 | img = self.transform(img) 238 | 239 | if self.target_transform is not None: 240 | label = self.target_transform(label) 241 | 242 | return (img,label) 243 | 244 | 245 | 246 | 247 | 248 | -------------------------------------------------------------------------------- /dataset_factory/VOCDataset.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pluto16/pytorch_detection/4e8ebbba266fefaf78e66e6e9fef083ebd751a7b/dataset_factory/VOCDataset.pyc -------------------------------------------------------------------------------- /dataset_factory/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pluto16/pytorch_detection/4e8ebbba266fefaf78e66e6e9fef083ebd751a7b/dataset_factory/__init__.py -------------------------------------------------------------------------------- /dataset_factory/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pluto16/pytorch_detection/4e8ebbba266fefaf78e66e6e9fef083ebd751a7b/dataset_factory/__init__.pyc 
def plot_boxes(img,boxes,savename=None,class_names=None):
    """Draw detection boxes on ``img`` in place and optionally save it.

    Every box is indexed as ``(cx, cy, w, h, det_conf, cls_conf, cls_ind)``;
    the first four entries are multiplied by the image width/height to get
    pixel corners. When ``class_names`` is given, the class name is drawn
    at the top-left corner and the confidences are printed. When
    ``savename`` is given, the annotated image is written to that path.
    """
    img_w, img_h = img.width, img.height
    canvas = ImageDraw.Draw(img)
    box_color = (255,0,0)
    label_color = (0,0,255)

    for box in boxes:
        left = (box[0] - box[2]/2.0)*img_w
        top = (box[1] - box[3]/2.0)*img_h
        right = (box[0] + box[2]/2.0)*img_w
        bottom = (box[1] + box[3]/2.0)*img_h

        if class_names:
            det_conf, cls_conf, cls_ind = box[4], box[5], box[6]
            score = det_conf*cls_conf
            print ('%12s:cls_conf=%8.5f det_conf=%8.5f thr=%8.5f' %(class_names[cls_ind],cls_conf,det_conf,score))
            canvas.text((left,top),class_names[cls_ind],fill=label_color)
        canvas.rectangle([left,top,right,bottom],outline=box_color)

    if not savename:
        return
    print("save plot results to {}".format(savename))
    img.save(savename)
img.view(1, 3, model.height, model.width) 60 | img = img.float().div(255.0) 61 | if torch.cuda.is_available(): 62 | img =img.cuda() 63 | img = Variable(img) 64 | 65 | start = time.time() 66 | output = model(img) 67 | output = output.data 68 | finish = time.time() 69 | boxes = model.get_region_boxes(output, conf_thresh)[0] 70 | 71 | #print("before nms") 72 | #print(boxes) 73 | boxes = nms(boxes, nms_thresh) 74 | #print("after nms") 75 | #print(boxes) 76 | print("{}: Predicted in {} seconds.".format(imgfile, (finish-start))) 77 | class_names = load_class_names(namesfile) 78 | plot_boxes(img_orig,boxes, 'predictions.jpg',class_names) 79 | 80 | 81 | if __name__=='__main__': 82 | if len(sys.argv) == 4: 83 | namesfile = sys.argv[1] 84 | weightfile = sys.argv[2] 85 | imgfile = sys.argv[3] 86 | detect(namesfile,weightfile,imgfile) 87 | else: 88 | print("Usage: ") 89 | print("python detect.py namesfile weightfile imgfile") 90 | print("Please use yolo-voc.weights") 91 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import xml.etree.ElementTree as ET 5 | from utils.cfg_loader import load_class_names 6 | from utils.cfg_loader import read_data_cfg 7 | from utils.iou import bbox_iou 8 | import cPickle 9 | import numpy as np 10 | import cPickle 11 | 12 | def parse_anno(anno_file): 13 | tree = ET.parse(anno_file) 14 | objects = [] 15 | for obj in tree.findall('object'): 16 | obj_struct = {} 17 | obj_struct['name'] = obj.find('name').text 18 | obj_struct['pose'] = obj.find('pose').text 19 | obj_struct['truncated'] = int(obj.find('truncated').text) 20 | obj_struct['difficult'] = int(obj.find('difficult').text) 21 | 22 | bbox = obj.find('bndbox') 23 | obj_struct['bbox'] =[ int(bbox.find('xmin').text), 24 | int(bbox.find('ymin').text), 25 | int(bbox.find('xmax').text), 26 | int(bbox.find('ymax').text)] 27 | 
def voc_eval(imageset_file,det_file,cls_name,cachedir,ovthresh=0.5):
    """Compute recall, precision and 11-point average precision for a class.

    Args:
        imageset_file: text file listing one image path per line.
        det_file: detections for ``cls_name``; every non-blank line is
            ``<image path> <confidence> <box values...>``. The box values
            are passed to ``bbox_iou`` with ``x1y1x2y2=True``.
        cls_name: class whose ground-truth objects are matched.
        cachedir: directory holding the pickled annotation cache
            ``voc_2007_test_annos.pkl`` (created on first use).
        ovthresh: a detection is a candidate hit only when its best
            overlap is strictly greater than this value.

    Returns:
        ``(rec, prec, ap)``. ``rec`` and ``prec`` are cumulative arrays
        ordered by descending confidence and ``ap`` is the mean of the
        best precision at recall levels 0.0, 0.1, ..., 1.0. When
        ``det_file`` does not exist a message is printed and
        ``(0, 0, 0)`` is returned.
    """
    cachefile = os.path.join(cachedir, 'voc_2007_test_annos.pkl')
    with open(imageset_file, 'r') as f:
        image_files = [line.rstrip() for line in f]

    if not os.path.isfile(cachefile):
        recs = {}
        for i, imgname in enumerate(image_files):
            anno_name = imgname.replace('.jpg', '.xml').replace('JPEGImages', 'Annotations')
            recs[imgname] = parse_anno(anno_name)
            if i % 100 == 0:
                print('Reading annotation for {:d}/{:d}'.format(i + 1, len(image_files)))
        # Pickle data is a byte stream, so the cache is opened in binary mode.
        with open(cachefile, 'wb') as f:
            cPickle.dump(recs, f)
    else:
        # NOTE(security): unpickling can execute arbitrary code; cachedir
        # must only contain files written by this script.
        with open(cachefile, 'rb') as f:
            recs = cPickle.load(f)

    # Collect the ground truths of this class, per image.
    gts = {}
    npos = 0
    for imgname in image_files:
        gts_per_img = [obj for obj in recs[imgname] if obj['name'] == cls_name]
        bboxes = np.array([obj['bbox'] for obj in gts_per_img])
        # Builtin bool/float replace the np.bool/np.float aliases that
        # newer NumPy releases no longer provide.
        difficult = np.array([obj['difficult'] for obj in gts_per_img]).astype(bool)
        npos = npos + np.sum(~difficult)
        gts[imgname] = {'bbs': bboxes,
                        'difficult': difficult,
                        'det': [False] * len(gts_per_img)}

    if not os.path.isfile(det_file):
        print('detfile %s not exist' % (det_file))
        return 0, 0, 0

    # Blank lines are skipped so a trailing newline cannot break parsing.
    with open(det_file, 'r') as f:
        rows = [line.split() for line in f if line.strip()]

    img_names = [row[0] for row in rows]
    confidence = np.array([float(row[1]) for row in rows])
    if rows:
        detBndBoxes = np.array([[float(v) for v in row[2:]] for row in rows])
    else:
        # Keep the array 2-D so the fancy indexing below works with no detections.
        detBndBoxes = np.zeros((0, 4))

    # Process detections from most to least confident.
    sorted_ind = np.argsort(-confidence)
    detBndBoxes = detBndBoxes[sorted_ind, :]
    img_names = [img_names[ind] for ind in sorted_ind]

    detnum = len(img_names)
    tp = np.zeros(detnum)
    fp = np.zeros(detnum)

    for detid in range(detnum):
        gts_for_img = gts[img_names[detid]]
        det_bb = detBndBoxes[detid]

        ovmax = -np.inf
        gt_bbs = gts_for_img['bbs'].astype(float)

        if gt_bbs.size > 0:
            # bbox_iou receives the ground truths with one box per column.
            overlaps = bbox_iou(np.transpose(gt_bbs), det_bb, x1y1x2y2=True)
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)

        if ovmax > ovthresh:
            # Matches against "difficult" objects count neither way.
            if not gts_for_img['difficult'][jmax]:
                if not gts_for_img['det'][jmax]:
                    tp[detid] = 1
                    gts_for_img['det'][jmax] = True
                else:
                    # Ground truth already claimed by a more confident detection.
                    fp[detid] = 1
        else:
            fp[detid] = 1

    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # eps guards against 0/0 before the first detection is counted.
    prec = tp / (np.maximum(tp + fp, np.finfo(np.float64).eps))

    # 11-point interpolated average precision.
    ap = 0.0
    for t in np.arange(0.0, 1.1, 0.1):
        if np.sum(rec >= t) == 0:
            p = 0
        else:
            p = np.max(prec[rec >= t])
        ap = ap + p / 11.0
    return rec, prec, ap
class yolo_v2_reorg(nn.Module):
    """Space-to-depth layer: fold each ``stride x stride`` patch into channels.

    An input of size ``(B, C, H, W)`` becomes
    ``(B, stride*stride*C, H/stride, W/stride)``. Output channel
    ``(a*stride + b)*C + c`` holds input channel ``c`` sampled at rows
    ``i*stride + a`` and columns ``j*stride + b``.
    """

    def __init__(self,stride=2):
        super(yolo_v2_reorg,self).__init__()
        self.stride = stride

    def forward(self,x):
        """Rearrange a 4-D tensor whose H and W are multiples of the stride."""
        stride = self.stride
        assert(x.data.dim()==4)
        B = x.data.size(0)
        C = x.data.size(1)
        H = x.data.size(2)
        W = x.data.size(3)
        assert(H % stride == 0)
        assert(W % stride == 0)
        ws = stride
        hs = stride
        # Floor division keeps the sizes integral on Python 3 as well;
        # with plain "/" view() receives floats and raises.
        out_h = H // hs
        out_w = W // ws
        x = x.view(B, C, out_h, hs, out_w, ws).transpose(3,4).contiguous()
        x = x.view(B, C, out_h*out_w, hs*ws).transpose(2,3).contiguous()
        x = x.view(B, C, hs*ws, out_h, out_w).transpose(1,2).contiguous()
        x = x.view(B, hs*ws*C, out_h, out_w)
        return x
def __init__(self):
    """Build the layer list and set the detection hyper-parameters.

    ``anchors`` is a flat list of numbers read ``anchor_step`` at a time;
    ``get_region_boxes`` uses entry 0 of each group as the anchor width and
    entry 1 as the anchor height.
    """
    super(yolo_v2, self).__init__()
    # input resolution; train.py passes (width, height) to the dataset
    self.width = 416
    self.height = 416
    self.models, self.layerInd_has_no_weights = self.create_model()
    self.header = torch.IntTensor([0, 0, 0, 0])
    # number of samples seen so far; overwritten by load_weights()
    self.seen = 0
    self.anchors_str = "1.3221, 1.73145, 3.19275, 4.00944, 5.05587, 8.09892, 9.47112, 4.84053, 11.2364, 10.0071"
    self.num_classes = 20
    self.anchor_step = 2
    self.anchors = [float(i) for i in self.anchors_str.strip().split(',')]
    # Floor division keeps this an int; it is used in range() and view().
    self.num_anchors = len(self.anchors) // self.anchor_step
    # train.py switches image transforms and checkpoint extension on this
    self.network_name = "yolo_v2"
conv4.add_module('bn4',nn.BatchNorm2d(128)) 90 | conv4.add_module('leaky4',nn.LeakyReLU(0.1,inplace=True)) 91 | models.append(conv4) 92 | #max pool 2 ind =7 93 | models.append(nn.MaxPool2d(2,2)) 94 | layerInd_has_no_weights.append(7) 95 | #256 96 | conv5 = nn.Sequential() 97 | conv5.add_module('conv5',nn.Conv2d(128,256,3,1,1,bias=False)) 98 | conv5.add_module('bn5',nn.BatchNorm2d(256)) 99 | conv5.add_module('leaky5',nn.LeakyReLU(0.1,inplace=True)) 100 | models.append(conv5) 101 | conv6 = nn.Sequential() 102 | conv6.add_module('conv6',nn.Conv2d(256,128,1,1,0,bias=False)) 103 | conv6.add_module('bn6',nn.BatchNorm2d(128)) 104 | conv6.add_module('leaky6',nn.LeakyReLU(0.1,inplace=True)) 105 | models.append(conv6) 106 | conv7 = nn.Sequential() 107 | conv7.add_module('conv7',nn.Conv2d(128,256,3,1,1,bias=False)) 108 | conv7.add_module('bn7',nn.BatchNorm2d(256)) 109 | conv7.add_module('leaky7',nn.LeakyReLU(0.1,inplace=True)) 110 | models.append(conv7) 111 | #max pool 3 ind=11 112 | models.append(nn.MaxPool2d(2,2)) 113 | layerInd_has_no_weights.append(11) 114 | #512 115 | conv8 = nn.Sequential() 116 | conv8.add_module('conv8',nn.Conv2d(256,512,3,1,1,bias=False)) 117 | conv8.add_module('bn8',nn.BatchNorm2d(512)) 118 | conv8.add_module('leaky8',nn.LeakyReLU(0.1,inplace=True)) 119 | models.append(conv8) 120 | conv9 = nn.Sequential() 121 | conv9.add_module('conv9',nn.Conv2d(512,256,1,1,0,bias=False)) 122 | conv9.add_module('bn9',nn.BatchNorm2d(256)) 123 | conv9.add_module('leaky9',nn.LeakyReLU(0.1,inplace=True)) 124 | models.append(conv9) 125 | conv10 = nn.Sequential() 126 | conv10.add_module('conv10',nn.Conv2d(256,512,3,1,1,bias=False)) 127 | conv10.add_module('bn10',nn.BatchNorm2d(512)) 128 | conv10.add_module('leaky10',nn.LeakyReLU(0.1,inplace=True)) 129 | models.append(conv10) 130 | conv11 = nn.Sequential() 131 | conv11.add_module('conv11',nn.Conv2d(512,256,1,1,0,bias=False)) 132 | conv11.add_module('bn11',nn.BatchNorm2d(256)) 133 | 
conv11.add_module('leaky11',nn.LeakyReLU(0.1,inplace=True)) 134 | models.append(conv11) 135 | #keep result ind=16 136 | conv12 = nn.Sequential() 137 | conv12.add_module('conv12',nn.Conv2d(256,512,3,1,1,bias=False)) 138 | conv12.add_module('bn12',nn.BatchNorm2d(512)) 139 | conv12.add_module('leaky12',nn.LeakyReLU(0.1,inplace=True)) 140 | models.append(conv12) 141 | #max pool 4 ind=17 142 | models.append(nn.MaxPool2d(2,2)) 143 | layerInd_has_no_weights.append(17) 144 | #1024 145 | conv13 = nn.Sequential() 146 | conv13.add_module('conv13',nn.Conv2d(512,1024,3,1,1,bias=False)) 147 | conv13.add_module('bn13',nn.BatchNorm2d(1024)) 148 | conv13.add_module('leaky13',nn.LeakyReLU(0.1,inplace=True)) 149 | models.append(conv13) 150 | conv14 = nn.Sequential() 151 | conv14.add_module('conv14',nn.Conv2d(1024,512,1,1,0,bias=False)) 152 | conv14.add_module('bn14',nn.BatchNorm2d(512)) 153 | conv14.add_module('leaky14',nn.LeakyReLU(0.1,inplace=True)) 154 | models.append(conv14) 155 | conv15 = nn.Sequential() 156 | conv15.add_module('conv15',nn.Conv2d(512,1024,3,1,1,bias=False)) 157 | conv15.add_module('bn15',nn.BatchNorm2d(1024)) 158 | conv15.add_module('leaky15',nn.LeakyReLU(0.1,inplace=True)) 159 | models.append(conv15) 160 | conv16 = nn.Sequential() 161 | conv16.add_module('conv16',nn.Conv2d(1024,512,1,1,0,bias=False)) 162 | conv16.add_module('bn16',nn.BatchNorm2d(512)) 163 | conv16.add_module('leaky16',nn.LeakyReLU(0.1,inplace=True)) 164 | models.append(conv16) 165 | conv17 = nn.Sequential() 166 | conv17.add_module('conv17',nn.Conv2d(512,1024,3,1,1,bias=False)) 167 | conv17.add_module('bn17',nn.BatchNorm2d(1024)) 168 | conv17.add_module('leaky17',nn.LeakyReLU(0.1,inplace=True)) 169 | models.append(conv17) 170 | ################################## 171 | conv18 = nn.Sequential() 172 | conv18.add_module('conv18',nn.Conv2d(1024,1024,3,1,1,bias=False)) 173 | conv18.add_module('bn18',nn.BatchNorm2d(1024)) 174 | conv18.add_module('leaky18',nn.LeakyReLU(0.1,inplace=True)) 175 | 
def get_region_boxes(self, output, conf_thresh):
    """Decode raw region-layer output into candidate boxes per image.

    ``output`` is ``(batch, (5+num_classes)*num_anchors, h, w)``; a single
    image given as a 3-D tensor gets a batch axis added.  For every cell and
    anchor whose ``objectness * best class probability`` exceeds
    ``conf_thresh`` a box
    ``[cx/w, cy/h, bw/w, bh/h, objectness, class_prob, class_id]`` is
    produced.  Returns one list of boxes per image in the batch.
    """
    anchor_step = self.anchor_step
    num_classes = self.num_classes
    num_anchors = int(self.num_anchors)
    anchors = self.anchors
    if output.dim() == 3:
        # was "unsequence", which is not a tensor method
        output = output.unsqueeze(0)
    batch = output.size(0)
    assert(output.size(1) == (5+num_classes)*num_anchors)
    h = output.size(2)
    w = output.size(3)
    total = batch*num_anchors*h*w
    on_gpu = output.is_cuda

    def same_device(tensor):
        # helper tensors follow the input instead of always going to CUDA
        return tensor.cuda() if on_gpu else tensor

    # rows: x, y, w, h, objectness, class scores; columns: (b, anchor, row, col)
    output = output.view(batch*num_anchors, 5+num_classes, h*w).transpose(0, 1).contiguous().view(5+num_classes, total)

    grid_x = same_device(torch.linspace(0, w-1, w).repeat(h, 1).repeat(batch*num_anchors, 1, 1).view(total))
    grid_y = same_device(torch.linspace(0, h-1, h).repeat(w, 1).t().repeat(batch*num_anchors, 1, 1).view(total))

    cx = torch.sigmoid(output[0]) + grid_x
    cy = torch.sigmoid(output[1]) + grid_y

    anchor_table = torch.Tensor(anchors).view(num_anchors, anchor_step)
    anchor_w = anchor_table.index_select(1, torch.LongTensor([0]))
    anchor_h = anchor_table.index_select(1, torch.LongTensor([1]))
    anchor_w = same_device(anchor_w.repeat(batch, 1).repeat(1, 1, h*w).view(total))
    anchor_h = same_device(anchor_h.repeat(batch, 1).repeat(1, 1, h*w).view(total))
    ws = torch.exp(output[2])*anchor_w
    hs = torch.exp(output[3])*anchor_h

    def_confs = torch.sigmoid(output[4])

    nnSoftmax = torch.nn.Softmax()
    cls_confs = nnSoftmax(Variable(output[5:5+num_classes].transpose(0, 1))).data
    cls_max_confs, cls_max_ids = torch.max(cls_confs, 1)

    # Pull everything to host memory once, as plain Python numbers, so the
    # per-cell loop below does no tensor indexing.
    def_confs = def_confs.cpu().tolist()
    cls_max_confs = cls_max_confs.view(-1).cpu().tolist()
    cls_max_ids = cls_max_ids.view(-1).cpu().tolist()
    cx = cx.cpu().tolist()
    cy = cy.cpu().tolist()
    ws = ws.cpu().tolist()
    hs = hs.cpu().tolist()

    all_boxes = []
    for b in range(batch):
        boxes = []
        for row in range(h):
            for col in range(w):
                for i in range(num_anchors):
                    ind = b*h*w*num_anchors + i*h*w + row*w + col
                    conf = def_confs[ind]*cls_max_confs[ind]
                    if conf > conf_thresh:
                        # float() keeps the normalisation a true division
                        box = [cx[ind]/float(w), cy[ind]/float(h),
                               ws[ind]/float(w), hs[ind]/float(h),
                               def_confs[ind], cls_max_confs[ind], cls_max_ids[ind]]
                        boxes.append(box)
        all_boxes.append(boxes)
    return all_boxes
def load_conv_bn(self, buf, start, conv_model, bn_model):
    """Fill a conv + batch-norm pair from the flat weight buffer.

    Starting at ``buf[start]`` the values are consumed in this order:
    batch-norm bias, batch-norm weight, running mean, running variance
    (each ``bn_model.bias.numel()`` long), then the convolution kernel.
    Returns the offset of the first value that was not consumed.
    """
    per_channel = bn_model.bias.numel()
    kernel_len = conv_model.weight.numel()

    offset = start
    # running_mean / running_var are plain tensors, hence no ".data" on them
    destinations = (
        bn_model.bias.data,
        bn_model.weight.data,
        bn_model.running_mean,
        bn_model.running_var,
    )
    for dest in destinations:
        dest.copy_(torch.from_numpy(buf[offset:offset + per_channel]))
        offset += per_channel

    kernel = torch.from_numpy(buf[offset:offset + kernel_len])
    conv_model.weight.data.copy_(kernel.view(conv_model.weight.size()))
    return offset + kernel_len
def save_weights(self, weight_file):
    """Write the network to ``weight_file`` in the layout load_weights() reads.

    The file starts with four int32 values ``[0, 0, 0, seen]``, followed by
    the parameters of every layer that has weights, in layer order.  Layer
    30 is stored as a plain convolution (bias + kernel); every other
    weighted layer as a conv + batch-norm pair.
    """
    print("save weight to file {}".format(weight_file))
    # load_weights() leaves self.seen as a 1-element NumPy array; reduce it
    # to a scalar so the header is always a flat array of four integers.
    seen = int(np.asarray(self.seen).reshape(-1)[0])
    header = np.asarray([0, 0, 0, seen], dtype=np.int32)

    # "with" guarantees the file is flushed and closed, also on errors.
    with open(weight_file, 'wb') as fp:
        header.tofile(fp)
        for ind, model in enumerate(self.models):
            if ind in self.layerInd_has_no_weights:
                continue
            if ind != 30:
                self.save_conv_bn(fp, model[0], model[1])
            else:
                self.save_conv(fp, model[0])
    print("save weights finished")
def __init__(self, num_classes, anchors_str, anchor_step):
    """Set up the region loss.

    ``anchors_str`` is a comma-separated list of numbers that is read
    ``anchor_step`` values per anchor; ``num_classes`` is the number of
    object classes.
    """
    super(yolo_v2_loss, self).__init__()
    self.anchors_str = anchors_str
    self.anchors = [float(i) for i in anchors_str.split(',')]
    self.anchor_step = anchor_step
    self.num_classes = num_classes
    # Floor division keeps the anchor count an int; it is used as a tensor
    # dimension and as a loop bound in forward().
    self.num_anchors = len(self.anchors) // anchor_step

    # weights of the individual loss terms
    self.object_scale = 5
    self.noobject_scale = 1
    self.class_scale = 1
    self.coord_scale = 1
    # progress counters; train.py overwrites them before every forward()
    self.seen = 0
    self.epoch = 0
    self.lr = 0
    self.seenbatches = 0
    # predictions whose best IoU with a ground truth exceeds this value get
    # a zero confidence mask in forward()
    self.thresh = 0.6
    # optional logger; train.py assigns one, forward() skips logging if None
    self.tf_logger = None
    self.mse_loss = nn.MSELoss(size_average=False)
torch.FloatTensor(gpu_matrix.size()).copy_(gpu_matrix) 35 | 36 | def forward( self, output, target): 37 | #output: 38 | nB = output.data.size(0) 39 | nA = self.num_anchors 40 | nC = self.num_classes 41 | nH = output.data.size(2) 42 | nW = output.data.size(3) 43 | target = target.data 44 | nAnchors = nA*nH*nW 45 | nPixels = nH*nW 46 | 47 | output = output.view(nB, nA, (5+nC), nH, nW) 48 | 49 | tx_pred = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([0]))).view(nB, nA, nH, nW)) 50 | ty_pred = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([1]))).view(nB, nA, nH, nW)) 51 | tw_pred = output.index_select(2, Variable(torch.cuda.LongTensor([2]))).view(nB, nA, nH, nW) 52 | th_pred = output.index_select(2, Variable(torch.cuda.LongTensor([3]))).view(nB, nA, nH, nW) 53 | 54 | conf_pred = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([4]))).view(nB, nA, nH, nW)) 55 | conf_pred_cpu = self.convert2cpu(conf_pred.data) 56 | 57 | cls_preds = output.index_select(2, Variable(torch.linspace(5,5+nC-1,nC).long().cuda())) 58 | cls_preds = cls_preds.view(nB*nA, nC, nH*nW).transpose(1,2).contiguous().view(nB*nA*nH*nW, nC) 59 | 60 | #generate pred_bboxes 61 | pred_boxes = torch.cuda.FloatTensor(4, nB*nA*nH*nW) 62 | grid_x = torch.linspace(0, nW-1, nW).repeat(nH,1).repeat(nB*nA, 1, 1).view(nB*nA*nH*nW).cuda() 63 | grid_y = torch.linspace(0, nH-1, nH).repeat(nW,1).t().repeat(nB*nA, 1, 1).view(nB*nA*nH*nW).cuda() 64 | anchor_w = torch.Tensor(self.anchors).view(nA, self.anchor_step).index_select(1, torch.LongTensor([0])).cuda() 65 | anchor_h = torch.Tensor(self.anchors).view(nA, self.anchor_step).index_select(1, torch.LongTensor([1])).cuda() 66 | anchor_w = anchor_w.repeat(nB, 1).repeat(1, 1, nH*nW).view(nB*nA*nH*nW) 67 | anchor_h = anchor_h.repeat(nB, 1).repeat(1, 1, nH*nW).view(nB*nA*nH*nW) 68 | pred_boxes[0] = tx_pred.data.view(nB*nA*nH*nW) + grid_x 69 | pred_boxes[1] = ty_pred.data.view(nB*nA*nH*nW) + grid_y 70 | pred_boxes[2] = 
torch.exp(tw_pred.data.view(nB*nA*nH*nW)) * anchor_w 71 | pred_boxes[3] = torch.exp(th_pred.data.view(nB*nA*nH*nW)) * anchor_h 72 | pred_boxes = pred_boxes.transpose(0,1).contiguous().view(-1,4) 73 | pred_boxes_cpu = self.convert2cpu(pred_boxes) 74 | 75 | tx_target = torch.zeros(nB, nA, nH, nW) 76 | ty_target = torch.zeros(nB, nA, nH, nW) 77 | tw_target = torch.zeros(nB, nA, nH, nW) 78 | th_target = torch.zeros(nB, nA, nH, nW) 79 | coord_mask = torch.zeros(nB, nA, nH, nW) 80 | 81 | tconf_target = torch.zeros(nB, nA, nH, nW) 82 | conf_mask = torch.ones(nB, nA, nH, nW)*self.noobject_scale 83 | tcls_target = torch.zeros(nB, nA, nH, nW) 84 | cls_mask = torch.zeros(nB, nA, nH, nW) 85 | 86 | avg_anyobj = 0 87 | for b in xrange(nB): 88 | for j in xrange(nH): 89 | for i in xrange(nW): 90 | for n in xrange(nA): 91 | cur_pred_box = pred_boxes_cpu[b*nAnchors+n*nPixels+j*nW+i] 92 | best_iou = 0 93 | for t in xrange(30): 94 | if target[b][t*5+1] == 0: 95 | break 96 | gx = target[b][t*5+1]*nW 97 | gy = target[b][t*5+2]*nH 98 | gw = target[b][t*5+3]*nW 99 | gh = target[b][t*5+4]*nH 100 | cur_gt_box = np.array([gx,gy,gw,gh]) 101 | iou = bbox_iou(cur_pred_box.numpy(), cur_gt_box, x1y1x2y2=False) 102 | if iou > best_iou: 103 | best_iou = iou 104 | 105 | if best_iou > self.thresh: 106 | conf_mask[b][n][j][i] = 0 107 | #avg_anyobj += conf_pred_cpu.data[b][n][j][i] 108 | 109 | if self.seen < 12800: 110 | tx_target.fill_(0.5) 111 | ty_target.fill_(0.5) 112 | tw_target.zero_() 113 | th_target.zero_() 114 | coord_mask.fill_(0.01) 115 | 116 | nGT = 0 117 | nCorrect = 0 118 | avg_iou= 0 119 | avg_obj= 0 120 | ncount = 0 121 | for b in xrange(nB): 122 | for t in xrange(30): 123 | if target[b][t*5+1] == 0: 124 | break 125 | nGT = nGT + 1 126 | best_iou = 0.0 127 | best_n = -1 128 | min_dist = 10000 129 | gx = target[b][t*5+1] * nW 130 | gy = target[b][t*5+2] * nH 131 | gi = int(gx) 132 | gj = int(gy) 133 | gw = target[b][t*5+3] * nW 134 | gh = target[b][t*5+4] * nH 135 | gt_box = [0, 0, gw, 
gh] 136 | for n in xrange(nA): 137 | aw = self.anchors[self.anchor_step*n] 138 | ah = self.anchors[self.anchor_step*n+1] 139 | anchor_box = [0, 0, aw, ah] 140 | iou = bbox_iou(anchor_box, gt_box, x1y1x2y2=False) 141 | if iou > best_iou: 142 | best_iou = iou 143 | best_n = n 144 | 145 | gt_box = [gx, gy, gw, gh] 146 | pred_box = pred_boxes_cpu[b*nAnchors+best_n*nPixels+gj*nW+gi] 147 | #print "pred_box",pred_box 148 | 149 | tx_target[b][best_n][gj][gi] = target[b][t*5+1] * nW - gi 150 | ty_target[b][best_n][gj][gi] = target[b][t*5+2] * nH - gj 151 | tw_target[b][best_n][gj][gi] = math.log(target[b][t*5+3]* nW /self.anchors[self.anchor_step*best_n]) 152 | th_target[b][best_n][gj][gi] = math.log(target[b][t*5+4]* nH /self.anchors[self.anchor_step*best_n+1]) 153 | coord_mask[b][best_n][gj][gi] = self.coord_scale*(2-target[b][t*5+3]*target[b][t*5+4]) 154 | 155 | iou = bbox_iou(gt_box, pred_box, x1y1x2y2=False) # best_iou 156 | tconf_target[b][best_n][gj][gi] = iou 157 | conf_mask[b][best_n][gj][gi] = self.object_scale 158 | 159 | cls_mask[b][best_n][gj][gi] = 1 160 | tcls_target[b][best_n][gj][gi] = target[b][t*5] 161 | 162 | #print "b {} t {} iou {}".format(b,t,iou) 163 | if iou > 0.5: 164 | nCorrect = nCorrect + 1 165 | avg_iou += iou 166 | ncount +=1 167 | avg_obj += conf_pred_cpu[b][best_n][gj][gi] 168 | 169 | 170 | coord_mask_gpu = Variable(coord_mask.cuda()) 171 | tx_target_gpu = Variable(tx_target.cuda()) 172 | ty_target_gpu = Variable(ty_target.cuda()) 173 | tw_target_gpu = Variable(tw_target.cuda()) 174 | th_target_gpu = Variable(th_target.cuda()) 175 | 176 | loss_x = self.mse_loss(tx_pred*coord_mask_gpu, tx_target_gpu*coord_mask_gpu)/2.0 177 | loss_y = self.mse_loss(ty_pred*coord_mask_gpu, ty_target_gpu*coord_mask_gpu)/2.0 178 | loss_w = self.mse_loss(tw_pred*coord_mask_gpu, tw_target_gpu*coord_mask_gpu)/2.0 179 | loss_h = self.mse_loss(th_pred*coord_mask_gpu, th_target_gpu*coord_mask_gpu)/2.0 180 | 181 | conf_mask_gpu = Variable(conf_mask.cuda()) 182 | 
tconf_target_gpu = Variable(tconf_target.cuda()) 183 | loss_conf = self.mse_loss(conf_pred*conf_mask_gpu, tconf_target_gpu*conf_mask_gpu)/2.0 184 | 185 | cls_mask = (cls_mask == 1) 186 | tcls_target_gpu = Variable(tcls_target.view(-1)[cls_mask].long().cuda()) 187 | cls_mask = Variable(cls_mask.view(-1, 1).repeat(1,nC).cuda()) 188 | cls_preds = cls_preds[cls_mask].view(-1, nC) 189 | loss_cls = self.class_scale * nn.CrossEntropyLoss(size_average=False)(cls_preds, tcls_target_gpu) 190 | loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls 191 | 192 | nProposals = int((conf_pred_cpu > 0.25).sum()) 193 | print 'epoch: %d,seenB: %d,seenS: %d, nProposal %d, GtAvgIOU: %f, AvgObj: %f, AvgRecall: %f, count: %d'%(self.epoch,self.seenbatches,self.seen,nProposals,avg_iou/ncount,avg_obj/ncount,nCorrect*1.0/ncount,ncount) 194 | print('---->lr: %f,loss: x %f, y %f, w %f, h %f, conf %f, cls %f, total %f' % (self.lr,loss_x.data[0], loss_y.data[0], loss_w.data[0], loss_h.data[0], loss_conf.data[0], loss_cls.data[0], loss.data[0])) 195 | if self.tf_logger is not None: 196 | self.tf_logger.scalar_summary("loss_x", loss_x.data[0], self.seenbatches) 197 | self.tf_logger.scalar_summary("loss_y", loss_y.data[0], self.seenbatches) 198 | self.tf_logger.scalar_summary("loss_w", loss_w.data[0], self.seenbatches) 199 | self.tf_logger.scalar_summary("loss_h", loss_h.data[0], self.seenbatches) 200 | self.tf_logger.scalar_summary("loss_conf", loss_conf.data[0], self.seenbatches) 201 | self.tf_logger.scalar_summary("loss_cls", loss_cls.data[0], self.seenbatches) 202 | self.tf_logger.scalar_summary("loss_total", loss.data[0], self.seenbatches) 203 | 204 | 205 | self.tf_logger.scalar_summary("GtAvgIOU", avg_iou/ncount, self.seenbatches) 206 | self.tf_logger.scalar_summary("AvgObj", avg_obj/ncount, self.seenbatches) 207 | self.tf_logger.scalar_summary("AvgRecall", nCorrect*1.0/ncount, self.seenbatches) 208 | self.tf_logger.scalar_summary("count", ncount, self.seenbatches) 209 | 210 | 211 | 
def __init__(self):
    """Load a pretrained ResNet-50 trunk and attach a 1x1 region layer."""
    super(yolo_v2_resnet, self).__init__()
    resnet = models.resnet50(pretrained=True)
    # drop the last two children (average pool and fc), keep the conv trunk
    modules = list(resnet.children())[:-2]
    self.resnet = nn.Sequential(*modules)
    # 125 output channels = (5 + num_classes) * num_anchors = (5 + 20) * 5,
    # which is what get_region_boxes() asserts on
    self.region_layer = nn.Conv2d(2048, 125, 1, 1, 0)
    self.width = 416
    self.height = 416
    self.header = torch.IntTensor([0, 0, 0, 0])
    self.seen = 0
    self.anchors_str = "1.3221, 1.73145, 3.19275, 4.00944, 5.05587, 8.09892, 9.47112, 4.84053, 11.2364, 10.0071"
    self.num_classes = 20
    self.anchor_step = 2
    self.anchors = [float(i) for i in self.anchors_str.strip().split(',')]
    # Floor division keeps this an int; it is used in range() and view().
    self.num_anchors = len(self.anchors) // self.anchor_step
    self.network_name = "yolo_v2_resnet"
def save_weights(self, weight_dir):
    """Save the module's state dict to the file path ``weight_dir``.

    Despite its name the argument is handed to ``torch.save`` as the target
    file.  The previous guard only wrote when that path already existed, so
    a new checkpoint was silently skipped; now the parent directory is
    created if needed and the file is always written.
    """
    parent = os.path.dirname(weight_dir)
    if parent and not os.path.isdir(parent):
        os.makedirs(parent)
    torch.save(self.state_dict(), weight_dir)
cls_max_confs,cls_max_ids = torch.max(cls_confs,1) 87 | cls_max_confs = cls_max_confs.view(-1) 88 | cls_max_ids = cls_max_ids.view(-1) 89 | 90 | def_confs = self.convert2cpu(def_confs) 91 | cls_max_confs = self.convert2cpu(cls_max_confs) 92 | cls_max_ids = self.convert2cpu_long(cls_max_ids) 93 | cx = self.convert2cpu(cx) 94 | cy = self.convert2cpu(cy) 95 | ws = self.convert2cpu(ws) 96 | hs = self.convert2cpu(hs) 97 | 98 | all_boxes = [] 99 | for b in range(batch): 100 | boxes = [] 101 | for row in range(h): 102 | for col in range(w): 103 | for i in range(num_anchors): 104 | ind = b*h*w*num_anchors + i*h*w + row*w + col 105 | conf = def_confs[ind]*cls_max_confs[ind] 106 | if conf >conf_thresh: 107 | bcx = cx[ind] 108 | bcy = cy[ind] 109 | bw = ws[ind] 110 | bh = hs[ind] 111 | #print "bbox {} {} {} {}".format(bcx,bcy,bw,bh) 112 | box = [bcx/w,bcy/h,bw/w,bh/h,def_confs[ind],cls_max_confs[ind],cls_max_ids[ind]] 113 | boxes.append(box) 114 | all_boxes.append(boxes) 115 | return all_boxes -------------------------------------------------------------------------------- /models/yolo_v2_resnet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pluto16/pytorch_detection/4e8ebbba266fefaf78e66e6e9fef083ebd751a7b/models/yolo_v2_resnet.pyc -------------------------------------------------------------------------------- /predictions.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pluto16/pytorch_detection/4e8ebbba266fefaf78e66e6e9fef083ebd751a7b/predictions.jpg -------------------------------------------------------------------------------- /scripts/demo_detect.sh: -------------------------------------------------------------------------------- 1 | 2 | #python detect.py data/voc.names backup/yolo_v2_v2_000100.weights data/eagle.jpg 3 | #python detect.py data/voc.names weights/yolo-voc.weights data/eagle.jpg 4 | 5 | python detect.py 
data/voc.names weights/save_test.weights data/eagle.jpg -------------------------------------------------------------------------------- /scripts/demo_eval.sh: -------------------------------------------------------------------------------- 1 | #python eval.py data/voc.data results/voc_yolo-voc_20170928_133944 2 | python eval.py data/voc.data results/voc_yolo_v2_v2_000100_20171106_100352 3 | -------------------------------------------------------------------------------- /scripts/demo_start_tensorboard.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | tensorboard --logdir='./logs/log_20171104_112628' --port=6006 4 | -------------------------------------------------------------------------------- /scripts/demo_train.sh: -------------------------------------------------------------------------------- 1 | 2 | #tray yolo_v2 3 | #python train.py data/voc.data weights/darknet19_448.conv.23 4 | python train.py data/voc.data weights/000100.weights 5 | #test valid part in train.py 6 | #python train.py data/voc.data weights/yolo-voc.weights 7 | 8 | #train yolo_v2_resnet 9 | #python train.py data/voc.data weights/darknet19_448.conv.23.123 -------------------------------------------------------------------------------- /scripts/demo_valid.sh: -------------------------------------------------------------------------------- 1 | #python valid.py data/voc.data weights/yolo-voc.weights 2 | #python valid.py data/voc.data weights/000090.weights 3 | python valid.py data/voc.data backup/yolo_v2_v2_000100.weights 4 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import time 5 | 6 | from models.yolo_v2 import yolo_v2 7 | from models.yolo_v2_resnet import yolo_v2_resnet 8 | from models.yolo_v2_loss import yolo_v2_loss 9 | 10 | from utils.cfg_loader import read_data_cfg 11 | from 
def file_lines(filepath):
    """Return the number of lines in the text file at ``filepath``."""
    count = 0
    with open(filepath) as handle:
        for _ in handle:
            count += 1
    return count
#os.environ['CUDA_VISIBLE_DEVICES']='0' 80 | torch.cuda.manual_seed(int(time.time())) 81 | old_model.cuda() 82 | model = old_model 83 | if ngpus >1: 84 | model = nn.DataParallel(model,device_ids=[0,2]) 85 | 86 | optimizer = optim.SGD(model.parameters(), lr=learning_rate/batch_size, momentum=momentum, dampening=0, weight_decay=decay*batch_size) 87 | 88 | #scheduler = lr_scheduler.StepLR(optimizer,step_size = 30 ,gamma=0.1) 89 | def adjust_learning_rate(optimizer,batchid,learning_rate,batch_size): 90 | for i in range(len(steps)): 91 | if batchid == steps[i]: 92 | learning_rate = learning_rate*scales[i] 93 | for param_group in optimizer.param_groups: 94 | param_group['lr'] = learning_rate/batch_size 95 | return learning_rate 96 | 97 | 98 | #load train image set 99 | with open(train_set_file,'r') as fp: 100 | train_image_files = fp.readlines() 101 | train_image_files = [file.rstrip() for file in train_image_files] 102 | 103 | #load valid image set 104 | with open(valid_set_file,'r') as fp: 105 | valid_image_files = fp.readlines() 106 | valid_image_files = [file.rstrip() for file in valid_image_files] 107 | 108 | 109 | if old_model.network_name == 'yolo_v2': 110 | print 'img_trans has no transforms.Normalize' 111 | img_trans = transforms.Compose([transforms.ToTensor(),]) 112 | 113 | elif old_model.network_name == 'yolo_v2_resnet': 114 | print 'img_trans has transforms.Normalize' 115 | img_trans = transforms.Compose([transforms.ToTensor(), 116 | transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])]) 117 | 118 | #construct valid data loader 119 | train_dataset = VOCDataset(train_image_files,shape=(old_model.width,old_model.height),shuffle=True,train_phase=True,transform=img_trans) 120 | train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=batch_size,pin_memory=True) 121 | #logging('training with %d samples' % (len(train_loader.dataset))) 122 | # 123 | valid_dataset = 
VOCDataset(valid_image_files,shape=(old_model.width,old_model.height),shuffle=False,transform=img_trans) 124 | valid_loader = torch.utils.data.DataLoader(valid_dataset,batch_size=batch_size,shuffle=False,num_workers=40,pin_memory=True) 125 | 126 | def truths_length(truths): 127 | for i in range(50): 128 | if truths[i][1] == 0: 129 | return i 130 | 131 | loss_func = yolo_v2_loss(old_model.num_classes,old_model.anchors_str,old_model.anchor_step) 132 | loss_func.tf_logger = tf_logger 133 | for epoch in range(init_epoch, max_epochs): 134 | #for epoch in range(0,1): 135 | if 1: 136 | model.train() 137 | for batch_idx, (images, target) in enumerate(train_loader): 138 | if processed_batches in steps: 139 | 140 | learning_rate = adjust_learning_rate(optimizer,processed_batches,learning_rate,batch_size) 141 | print "learning rate changed to {}".format(learning_rate) 142 | #scheduler.step() 143 | #logging('epoch %d,all_batches %d, batch_size %d, lr %f' % (epoch+1,processed_batches+1,batch_size, learning_rate)) 144 | if torch.cuda.is_available(): 145 | images = images.cuda() 146 | images_var = Variable(images) 147 | target_var = Variable(target) 148 | optimizer.zero_grad() 149 | output = model(images_var) 150 | seen_samples = seen_samples + images_var.data.size(0) 151 | loss_func.seen = seen_samples 152 | loss_func.epoch = epoch 153 | loss_func.lr = learning_rate 154 | loss_func.seenbatches = processed_batches 155 | loss = loss_func(output, target_var) 156 | 157 | loss.backward() 158 | optimizer.step() 159 | processed_batches = processed_batches + 1 160 | 161 | if (epoch+1) % save_interval == 0: 162 | extension = 'tmp' 163 | if old_model.network_name == 'yolo_v2': 164 | extension = 'weights' 165 | elif old_model.network_name == 'yolo_v2_resnet': 166 | extension = 'pth' 167 | logging('save weights to %s/%s_v2_%06d.%s' % (backupdir,old_model.network_name,epoch+1,extension)) 168 | model.module.seen = seen_samples #(epoch + 1) * len(train_loader.dataset) 169 | 
model.module.save_weights('%s/%s_v2_%06d.%s' % (backupdir,old_model.network_name,epoch+1,extension)) 170 | 171 | #valid process 172 | if 1: 173 | if epoch % valid_interval == 0 : 174 | total = 0.0 175 | proposals = 0.0 176 | correct = 0.0 177 | eps = 1e-5 178 | model.eval() 179 | for batch_idx, (data, target) in enumerate(valid_loader): 180 | if torch.cuda.is_available(): 181 | data = data.cuda() 182 | data = Variable(data, volatile=True) 183 | output = model(data).data 184 | all_boxes = old_model.get_region_boxes(output, conf_thresh) 185 | for i in range(output.size(0)): 186 | boxes = all_boxes[i] 187 | boxes = nms(boxes, nms_thresh) 188 | truths = target[i].view(-1, 5) 189 | num_gts = truths_length(truths) 190 | total = total + num_gts 191 | for i in range(len(boxes)): 192 | if boxes[i][4]*boxes[i][5] > conf_thresh: 193 | proposals = proposals+1 194 | for i in range(num_gts): 195 | box_gt = [truths[i][1], truths[i][2], truths[i][3], truths[i][4], 1.0, 1.0, truths[i][0]] 196 | for j in range(len(boxes)): 197 | iou = bbox_iou(box_gt, boxes[j], x1y1x2y2=False) 198 | if iou > iou_thresh and boxes[j][6] == box_gt[6]: 199 | correct = correct+1 200 | precision = 1.0*correct/(proposals+eps) 201 | recall = 1.0*correct/(total+eps) 202 | fscore = 2.0*precision*recall/(precision+recall+eps) 203 | logging("valid process precision: %f, recall: %f, fscore: %f" % (precision, recall, fscore)) 204 | tf_logger.scalar_summary("valid precision",precision,epoch) 205 | tf_logger.scalar_summary("valid recall",recall,epoch) 206 | else: 207 | print("Usage:") 208 | print("python train.py datacfg weightfile") -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pluto16/pytorch_detection/4e8ebbba266fefaf78e66e6e9fef083ebd751a7b/utils/__init__.py -------------------------------------------------------------------------------- 
def load_class_names(namesfile):
    """Return the class names listed one per line in ``namesfile``.

    Trailing whitespace (including the newline) is removed from every line.
    """
    with open(namesfile, 'r') as fp:
        return [line.rstrip() for line in fp]


def read_data_cfg(datacfg):
    """Parse a ``key = value`` data config file into a dict of strings.

    Blank lines and comment lines (starting with ``#`` or ``;``) are skipped
    rather than ending the parse, so options placed after an empty line are
    no longer silently dropped. Only the first ``=`` separates key from value,
    so values may themselves contain ``=``.

    Raises:
        ValueError: if a non-blank, non-comment line contains no ``=``.
    """
    options = dict()
    with open(datacfg, 'r') as fp:
        for line in fp:
            line = line.strip()
            if not line or line[0] in '#;':
                continue
            key, value = line.split('=', 1)
            options[key.strip()] = value.strip()
    return options
def bbox_iou(box1, box2, x1y1x2y2=True):
    """Return the intersection-over-union of ``box1`` and ``box2``.

    Args:
        box1, box2: indexable boxes. With ``x1y1x2y2=True`` elements 0..3 are
            read as (x1, y1, x2, y2); otherwise as (center x, center y,
            width, height). Elements may be scalars or numpy arrays; with
            arrays the IoU is computed element-wise.
        x1y1x2y2: selects the coordinate convention described above.

    Returns:
        ``0.0`` for non-overlapping scalar boxes, otherwise ``carea / uarea``
        (an array when the inputs are arrays).
    """
    if x1y1x2y2:
        mx = np.minimum(box1[0], box2[0])
        Mx = np.maximum(box1[2], box2[2])
        my = np.minimum(box1[1], box2[1])
        My = np.maximum(box1[3], box2[3])
        w1 = box1[2] - box1[0]
        h1 = box1[3] - box1[1]
        w2 = box2[2] - box2[0]
        h2 = box2[3] - box2[1]
    else:
        mx = np.minimum(box1[0] - box1[2]/2.0, box2[0] - box2[2]/2.0)
        Mx = np.maximum(box1[0] + box1[2]/2.0, box2[0] + box2[2]/2.0)
        my = np.minimum(box1[1] - box1[3]/2.0, box2[1] - box2[3]/2.0)
        My = np.maximum(box1[1] + box1[3]/2.0, box2[1] + box2[3]/2.0)
        w1 = box1[2]
        h1 = box1[3]
        w2 = box2[2]
        h2 = box2[3]

    # Extent of the enclosing box; the overlap along an axis is the sum of
    # the two sizes minus that extent (non-positive when the boxes are apart).
    uw = Mx - mx
    uh = My - my
    cw = w1 + w2 - uw
    ch = h1 + h2 - uh

    elementwise = isinstance(cw, np.ndarray)
    if not elementwise and (cw <= 0 or ch <= 0):
        return 0.0

    area1 = w1 * h1
    area2 = w2 * h2
    carea = cw * ch
    if elementwise:
        # Zero the intersection when the boxes are disjoint on EITHER axis;
        # checking only cw left a negative area for pairs separated along y.
        carea[(cw <= 0) | (ch <= 0)] = 0
    uarea = area1 + area2 - carea

    return carea / uarea
class Logger(object):
    """Thin wrapper that writes scalar, image and histogram summaries to a log directory."""

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images; each one is stored under the tag ``<tag>/<index>``."""
        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to an in-memory buffer. Only one of StringIO /
            # BytesIO is imported depending on the Python version, so the
            # missing name is the one failure expected here.
            try:
                s = StringIO()
            except NameError:
                s = BytesIO()
            scipy.misc.toimage(img).save(s, format="png")

            # Create an Image object
            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            # Create a Summary value
            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        # Create and write Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""
        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values**2))

        # Drop the start of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()
def nms(boxes,nms_thresh):
    """Greedy non-maximum suppression over a list of boxes.

    Each box is indexable; element 4 and element 5 are multiplied to rank the
    boxes, and boxes are compared with ``bbox_iou(..., x1y1x2y2=False)``.
    A box whose overlap with an already kept box exceeds ``nms_thresh`` is
    suppressed by setting its element 4 to 0 in place. The kept boxes are
    returned in ranking order; an empty input is returned unchanged.
    """
    num_boxes = len(boxes)
    if num_boxes == 0:
        return boxes

    # Sorting (1 - score) ascending visits the highest-scoring boxes first.
    inverted_scores = torch.zeros(num_boxes)
    for idx, candidate in enumerate(boxes):
        inverted_scores[idx] = 1 - candidate[4]*candidate[5]
    order = torch.sort(inverted_scores)[1]

    kept = []
    for rank in range(num_boxes):
        current = boxes[order[rank]]
        if not (current[4] > 0):
            # already suppressed by a better box
            continue
        kept.append(current)
        for later in range(rank + 1, num_boxes):
            other = boxes[order[later]]
            if bbox_iou(current, other, x1y1x2y2=False) > nms_thresh:
                other[4] = 0
    return kept
def valid(datacfg, weight_file, outfile_prefix):
    """Run the detector over the validation list and write per-class result files.

    Args:
        datacfg: data config file providing the ``valid`` image list and the
            ``names`` class-name file.
        weight_file: weights loaded into the ``yolo_v2`` model.
        outfile_prefix: prefix of each ``<prefix>_<class>.txt`` result file.

    One line ``image score x1 y1 x2 y2`` is written per detection, with the
    box scaled to the original image size, into a time-stamped directory
    under ``results/``.
    """
    options = read_data_cfg(datacfg)
    valid_images_set_file = options['valid']
    namesfile = options['names']

    # load class names
    class_names = load_class_names(namesfile)
    # load valid image list
    with open(valid_images_set_file, 'r') as fp:
        image_files = [line.rstrip() for line in fp]

    model = yolo_v2()
    model.load_weights(weight_file)
    print("weights %s loaded" % (weight_file))

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()
    model.eval()

    # result directory (makedirs also creates "results" when it is missing)
    dir_name = 'results/%s_%s_%s' % (namesfile.split('/')[-1].split('.')[0],
                                     weight_file.split('/')[-1].split('.')[0],
                                     time.strftime("%Y%m%d_%H%M%S", time.localtime()))
    print('save results to %s' % (dir_name))
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)

    # construct data loader
    valid_dataset = VOCDataset(image_files, shape=(model.width, model.height), shuffle=False,
                               transform=transforms.Compose([transforms.ToTensor(), ]))
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=4, shuffle=False,
                                               num_workers=4, pin_memory=True)

    conf_thresh = 0.005
    nms_thresh = 0.45

    fps = []
    try:
        # one result file per class, indexed by class id
        for i in range(model.num_classes):
            fps.append(open("%s/%s_%s.txt" % (dir_name, outfile_prefix, class_names[i]), 'w'))

        line_id = -1
        for batch_index, (data, target) in enumerate(valid_loader):
            # only move data to the GPU when the model was moved there too
            if use_cuda:
                data = data.cuda()
            data = Variable(data, volatile=True)
            output = model(data).data
            batch_boxes = model.get_region_boxes(output, conf_thresh)
            for image_boxes in batch_boxes:
                boxes = nms(image_boxes, nms_thresh)

                # the loader is not shuffled, so detections map to the image list by position
                line_id = line_id + 1
                image_name = image_files[line_id]
                print("[Batch_index:%d] [%d/%d] file:%s " % (batch_index, line_id+1, len(image_files), image_name))

                with Image.open(image_name) as img_orig:
                    height, width = img_orig.height, img_orig.width
                print(" height %d, width %d, bbox num %d" % (height, width, len(boxes)))
                for box in boxes:
                    # box is (cx, cy, w, h, det_conf, cls_conf, cls_id) in relative coordinates
                    x1 = (box[0] - box[2]/2.0)*width
                    y1 = (box[1] - box[3]/2.0)*height
                    x2 = (box[0] + box[2]/2.0)*width
                    y2 = (box[1] + box[3]/2.0)*height
                    det_conf = box[4]
                    cls_conf = box[5]
                    cls_id = int(box[6])
                    fps[cls_id].write("%s %f %f %f %f %f\n" % (image_name, det_conf*cls_conf, x1, y1, x2, y2))
    finally:
        # close the result files even when detection fails midway
        for fp in fps:
            fp.close()


if __name__ == "__main__":
    if len(sys.argv) == 3:
        datacfg = sys.argv[1]
        weightfile = sys.argv[2]
        outfile = 'comp4_det_test'
        valid(datacfg, weightfile, outfile)
    else:
        print("Usage:")
        print("python valid.py datacfg weightfile")