├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── data ├── acc_curve.png ├── lsp │ └── images │ │ ├── lsp_dataset │ │ └── lspet_dataset ├── mpii │ ├── mean.pth.tar │ └── mpii_annotations.json └── mscoco │ └── README.md ├── evaluation ├── data │ ├── detections.mat │ └── detections_our_format.mat ├── eval_PCKh.m ├── eval_PCKh.py ├── showskeletons_joints.m └── utils.py ├── example ├── lsp.py ├── mpii.py └── mscoco.py ├── miscs ├── cocoScale.m ├── gen_coco.m ├── gen_lsp.m └── gen_mpii.m ├── pose ├── __init__.py ├── datasets │ ├── __init__.py │ ├── lsp.py │ ├── mpii.py │ └── mscoco.py ├── models │ ├── __init__.py │ ├── hourglass.py │ └── preresnet.py └── utils │ ├── __init__.py │ ├── evaluation.py │ ├── imutils.py │ ├── logger.py │ ├── misc.py │ ├── osutils.py │ └── transforms.py ├── requirements.txt └── tools ├── mpii_demo.py └── mpii_export_to_onxx.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | checkpoint 3 | dev 4 | *.pth.tar 5 | data/mpii/images 6 | !data/mpii/mean.pth.tar 7 | *.json 8 | *debug* 9 | *.idea/* 10 | test_transforms.py 11 | experiments 12 | data/mscoco/coco 13 | data/mscoco/keypoint 14 | *.m~ 15 | miscs/posetrack 16 | miscs/h36m 17 | data/h36m 18 | 19 | __pycache__/ 20 | *.py[cod] 21 | *$py.class 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | env/ 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *,cover 63 | .hypothesis/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # IPython Notebook 87 | .ipynb_checkpoints 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # dotenv 96 | .env 97 | 98 | # virtualenv 99 | venv/ 100 | ENV/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | 105 | # Rope project settings 106 | .ropeproject 107 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pose/progress"] 2 | path = pose/progress 3 | url = https://github.com/verigak/progress.git 4 | [submodule "miscs/jsonlab"] 5 | path = miscs/jsonlab 6 | url = https://github.com/fangq/jsonlab 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 
12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. 
For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 
83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 
117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 
146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 
180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 
216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 
246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. 
If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 
309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 
336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 
360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. 
If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 
428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 
486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. 
If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | {one line to give the program's name and a brief idea of what it does.} 635 | Copyright (C) {year} {name of author} 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | {project} Copyright (C) {year} {fullname} 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch-Pose 2 | 3 | PyTorch-Pose is a PyTorch implementation of the general pipeline for 2D single human pose estimation. The aim is to provide the interface of the training/inference/evaluation, and the dataloader with various data augmentation options for the most popular human pose databases (e.g., [the MPII human pose](http://human-pose.mpi-inf.mpg.de), [LSP](http://www.comp.leeds.ac.uk/mat4saj/lsp.html) and [FLIC](http://bensapp.github.io/flic-dataset.html)). 
4 | 5 | Some of the code for data preparation and augmentation is borrowed from the [Stacked hourglass network](https://github.com/anewell/pose-hg-train). Thanks to the original author. 6 | 7 | ## Models 8 | | Model|in_res |features| # of Weights |Head|Shoulder| Elbow| Wrist| Hip |Knee| Ankle| Mean|Link| 9 | | --- |---| ----|----------- | ----| ----| ---| ---| ---| ---| ---| ---|----| 10 | | hg_s2_b1|256|128|6.73m| 95.74| 94.51| 87.68| 81.70| 87.81| 80.88 |76.83| 86.58|[GoogleDrive](https://drive.google.com/open?id=1c_YR0NKmRfRvLcNB5wFpm75VOkC9Y1n4) 11 | | hg_s2_b1_mobile|256|128|2.31m|95.80| 93.61| 85.50| 79.63| 86.13| 77.82| 73.62| 84.69|[GoogleDrive](https://drive.google.com/open?id=1FxTRhiw6_dS8X1jBBUw_bxHX6RoBJaJO) 12 | | hg_s2_b1_tiny|192|128|2.31m|94.95| 92.87|84.59| 78.19| 84.68| 77.70| 73.07| 83.88|[GoogleDrive](https://drive.google.com/open?id=1qrkaUDPbHwdSBozRbN150O4Mu9HMWIOG) 13 | 14 | 15 | ## Installation 16 | 1. Create a virtualenv 17 | ``` 18 | virtualenv -p /usr/bin/python2.7 posevenv 19 | ``` 20 | 2. Install all dependencies in the virtualenv (`requirements.txt` is in the repository root, so run this after cloning in step 3) 21 | ``` 22 | source posevenv/bin/activate 23 | pip install -r requirements.txt 24 | ``` 25 | 3. Clone the repository with its submodules 26 | ``` 27 | git clone --recursive https://github.com/yuanyuanli85/pytorch-pose.git 28 | ``` 29 | 30 | 4. Create a symbolic link to the `images` directory of the MPII dataset: 31 | ``` 32 | ln -s PATH_TO_MPII_IMAGES_DIR data/mpii/images 33 | ``` 34 | 35 | 5.
Disable cuDNN for the batchnorm layer to work around a bug in PyTorch 0.4.0 36 | ``` 37 | sed -i "1194s/torch\.backends\.cudnn\.enabled/False/g" ./posevenv/lib/python2.7/site-packages/torch/nn/functional.py 38 | ``` 39 | ## Training 40 | 41 | * Normal network configuration, in_res 256, features 128 42 | ```sh 43 | python example/mpii.py -a hg --stacks 2 --blocks 1 --checkpoint checkpoint/hg_s2_b1/ --in_res 256 --features 256 44 | ``` 45 | 46 | * Mobile network configuration, in_res 256, features 128 47 | ```sh 48 | python example/mpii.py -a hg --stacks 2 --blocks 1 --checkpoint checkpoint/hg_s2_b1_mobile/ --mobile True --in_res 256 --features 256 49 | ``` 50 | 51 | * Tiny network configuration, in_res 192, features 128 52 | ```sh 53 | python example/mpii.py -a hg --stacks 2 --blocks 1 --checkpoint checkpoint/hg_s2_b1_tiny/ --mobile True --in_res 192 --features 128 54 | ``` 55 | 56 | ## Evaluation 57 | 58 | Run evaluation to generate the predictions `.mat` file 59 | ```sh 60 | python example/mpii.py -a hg --stacks 2 --blocks 1 --checkpoint checkpoint/hg_s2_b1/ --resume checkpoint/hg_s2_b1/model_best.pth.tar -e 61 | ``` 62 | * `--resume` is the checkpoint to evaluate 63 | 64 | Run `evaluation/eval_PCKh.py` to compute the PCKh score on the validation set 65 | 66 | ## Export a PyTorch checkpoint to ONNX 67 | ```sh 68 | python tools/mpii_export_to_onxx.py -a hg -s 2 -b 1 --num-classes 16 --mobile True --in_res 256 --checkpoint checkpoint/model_best.pth.tar \ 69 | --out_onnx checkpoint/model_best.onnx 70 | ``` 71 | Where 72 | * `--checkpoint` is the checkpoint to export 73 | * `--out_onnx` is the path of the exported ONNX file 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /data/acc_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanyuanli85/pytorch-pose/e0651e77e7832fa603921022d5456b59e435cd91/data/acc_curve.png -------------------------------------------------------------------------------- 
/data/lsp/images/lsp_dataset: -------------------------------------------------------------------------------- 1 | /home/wyang/Data/dataset/LSP -------------------------------------------------------------------------------- /data/lsp/images/lspet_dataset: -------------------------------------------------------------------------------- 1 | /home/wyang/Data/dataset/LSP_ext/ -------------------------------------------------------------------------------- /data/mpii/mean.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanyuanli85/pytorch-pose/e0651e77e7832fa603921022d5456b59e435cd91/data/mpii/mean.pth.tar -------------------------------------------------------------------------------- /data/mscoco/README.md: -------------------------------------------------------------------------------- 1 | ## Directory structure 2 | 3 | - `coco`: [coco API](https://github.com/pdollar/coco) 4 | - `keypoint`: COCO keypoint dataset 5 | - `images`: train and val datasets 6 | - `train2014` 7 | - `val2014` 8 | - `person_keypoints_train+val5k2014`: annotations (JSON files) 9 | - `coco_annotations.json`: reformatted annotations generated by `./miscs/gen_coco.m` 10 | -------------------------------------------------------------------------------- /evaluation/data/detections.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanyuanli85/pytorch-pose/e0651e77e7832fa603921022d5456b59e435cd91/evaluation/data/detections.mat -------------------------------------------------------------------------------- /evaluation/data/detections_our_format.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanyuanli85/pytorch-pose/e0651e77e7832fa603921022d5456b59e435cd91/evaluation/data/detections_our_format.mat -------------------------------------------------------------------------------- 
/evaluation/eval_PCKh.m: -------------------------------------------------------------------------------- 1 | % wei 2 | addpath('./utils'); 3 | 4 | % set `debug = true` if you want to visualize the skeletons 5 | % You also need to download the MPII dataset and specify the path of 6 | % annopath = `mpii_human_pose_v1_u12_1.mat` 7 | debug = false; 8 | annopath = 'path_to/mpii_human_pose_v1_u12_1.mat'; 9 | 10 | load('data/detections.mat'); 11 | tompson_i = RELEASE_img_index; 12 | 13 | threshold = 0.5; 14 | SC_BIAS = 0.6; % THIS IS DEFINED IN util_get_head_size.m 15 | 16 | pa = [2, 3, 7, 7, 4, 5, 8, 9, 10, 0, 12, 13, 8, 8, 14, 15]; 17 | 18 | load('data/detections_our_format.mat', 'dataset_joints', 'jnt_missing', 'pos_pred_src', 'pos_gt_src', 'headboxes_src'); 19 | 20 | % predictions 21 | predfile = '/home/wyang/code/pose/pytorch-pose/checkpoint/mpii/hg_s2_b1_mean/preds_valid.mat'; 22 | preds = load(predfile,'preds'); 23 | pos_pred_src = permute(preds.preds, [2, 3, 1]); 24 | 25 | % DEBUG 26 | if debug 27 | mat = load(annopath); 28 | 29 | for i = 1:length(tompson_i) 30 | imname = mat.RELEASE.annolist(tompson_i(i)).image.name; 31 | fprintf('%s\n', imname); 32 | im = imread(['/home/wyang/Data/dataset/mpii/images/' imname]); 33 | pred = pos_pred_src(:, :, i); 34 | showskeletons_joints(im, pred, pa); 35 | pause; clf; 36 | end 37 | end 38 | 39 | head = find(ismember(dataset_joints, 'head')); 40 | lsho = find(ismember(dataset_joints, 'lsho')); 41 | lelb = find(ismember(dataset_joints, 'lelb')); 42 | lwri = find(ismember(dataset_joints, 'lwri')); 43 | lhip = find(ismember(dataset_joints, 'lhip')); 44 | lkne = find(ismember(dataset_joints, 'lkne')); 45 | lank = find(ismember(dataset_joints, 'lank')); 46 | 47 | rsho = find(ismember(dataset_joints, 'rsho')); 48 | relb = find(ismember(dataset_joints, 'relb')); 49 | rwri = find(ismember(dataset_joints, 'rwri')); 50 | rhip = find(ismember(dataset_joints, 'rhip')); 51 | rkne = find(ismember(dataset_joints, 'rkne')); 52 | rank = 
find(ismember(dataset_joints, 'rank')); 53 | 54 | % Calculate PCKh again for a few joints just to make sure our evaluation 55 | % matches Leonid's... 56 | jnt_visible = 1 - jnt_missing; 57 | uv_err = pos_pred_src - pos_gt_src; 58 | uv_err = sqrt(sum(uv_err .* uv_err, 2)); 59 | headsizes = headboxes_src(2,:,:) - headboxes_src(1,:,:); 60 | headsizes = sqrt(sum(headsizes .* headsizes, 2)); 61 | headsizes = headsizes * SC_BIAS; 62 | scaled_uv_err = squeeze(uv_err ./ repmat(headsizes, size(uv_err, 1), 1, 1)); 63 | 64 | % Zero the contribution of joints that are missing 65 | scaled_uv_err = scaled_uv_err .* jnt_visible; 66 | jnt_count = squeeze(sum(jnt_visible, 2)); 67 | less_than_threshold = (scaled_uv_err < threshold) .* jnt_visible; 68 | PCKh = 100 * squeeze(sum(less_than_threshold, 2)) ./ jnt_count; 69 | 70 | % save PCK all 71 | range = (0:0.01:0.5); 72 | pckAll = zeros(length(range),16); 73 | for r = 1:length( range) 74 | threshold = range(r); 75 | less_than_threshold = (scaled_uv_err < threshold) .* jnt_visible; 76 | pckAll(r, :) = 100 * squeeze(sum(less_than_threshold, 2)) ./ jnt_count; 77 | 78 | end 79 | 80 | [~, name, ~] = fileparts(predfile); 81 | 82 | % Uncomment if you want to save the result 83 | % save(sprintf('pckAll-%s.mat', name), 'scaled_uv_err', 'pos_pred_src'); 84 | 85 | clc; 86 | fprintf(' Head , Shoulder , Elbow , Wrist , Hip , Knee , Ankle , Mean , \n'); 87 | fprintf('name , %.2f , %.2f , %.2f , %.2f , %.2f , %.2f , %.2f , %.2f% , \n',... 88 | PCKh(head), (PCKh(lsho)+PCKh(rsho))/2, (PCKh(lelb)+PCKh(relb))/2,... 89 | (PCKh(lwri)+PCKh(rwri))/2, (PCKh(lhip)+PCKh(rhip))/2, ... 
90 | (PCKh(lkne)+PCKh(rkne))/2, (PCKh(lank)+PCKh(rank))/2, mean(PCKh([1:6, 9:16]))); 91 | fprintf('\n'); 92 | 93 | -------------------------------------------------------------------------------- /evaluation/eval_PCKh.py: -------------------------------------------------------------------------------- 1 | from scipy.io import loadmat 2 | from numpy import transpose 3 | import skimage.io as sio 4 | from utils import visualize 5 | import numpy as np 6 | import os 7 | 8 | 9 | detection = loadmat('evaluation/data/detections.mat') 10 | det_idxs = detection['RELEASE_img_index'] 11 | debug = 0 12 | threshold = 0.5 13 | SC_BIAS = 0.6 14 | 15 | pa = [2, 3, 7, 7, 4, 5, 8, 9, 10, 0, 12, 13, 8, 8, 14, 15] 16 | 17 | dict = loadmat('evaluation/data/detections_our_format.mat') 18 | dataset_joints = dict['dataset_joints'] 19 | jnt_missing = dict['jnt_missing'] 20 | pos_pred_src = dict['pos_pred_src'] 21 | pos_gt_src = dict['pos_gt_src'] 22 | headboxes_src = dict['headboxes_src'] 23 | 24 | 25 | 26 | #predictions 27 | model_name = 'hg4' 28 | predfile = 'checkpoint/mpii/' + model_name + '/preds_valid.mat' 29 | preds = loadmat(predfile)['preds'] 30 | pos_pred_src = transpose(preds, [1, 2, 0]) 31 | 32 | 33 | if debug: 34 | 35 | for i in range(len(det_idxs[0])): 36 | anno = mat['RELEASE']['annolist'][0, 0][0][det_idxs[0][i] - 1] 37 | fn = anno['image']['name'][0, 0][0] 38 | imagePath = 'data/mpii/images/' + fn 39 | oriImg = sio.imread(imagePath) 40 | pred = pos_pred_src[:, :, i] 41 | visualize(oriImg, pred, pa) 42 | 43 | 44 | head = np.where(dataset_joints == 'head')[1][0] 45 | lsho = np.where(dataset_joints == 'lsho')[1][0] 46 | lelb = np.where(dataset_joints == 'lelb')[1][0] 47 | lwri = np.where(dataset_joints == 'lwri')[1][0] 48 | lhip = np.where(dataset_joints == 'lhip')[1][0] 49 | lkne = np.where(dataset_joints == 'lkne')[1][0] 50 | lank = np.where(dataset_joints == 'lank')[1][0] 51 | 52 | rsho = np.where(dataset_joints == 'rsho')[1][0] 53 | relb = np.where(dataset_joints == 
'relb')[1][0] 54 | rwri = np.where(dataset_joints == 'rwri')[1][0] 55 | rkne = np.where(dataset_joints == 'rkne')[1][0] 56 | rank = np.where(dataset_joints == 'rank')[1][0] 57 | rhip = np.where(dataset_joints == 'rhip')[1][0] 58 | 59 | jnt_visible = 1 - jnt_missing 60 | uv_error = pos_pred_src - pos_gt_src 61 | uv_err = np.linalg.norm(uv_error, axis=1) 62 | headsizes = headboxes_src[1, :, :] - headboxes_src[0, :, :] 63 | headsizes = np.linalg.norm(headsizes, axis=0) 64 | headsizes *= SC_BIAS 65 | scale = np.multiply(headsizes, np.ones((len(uv_err), 1))) 66 | scaled_uv_err = np.divide(uv_err, scale) 67 | scaled_uv_err = np.multiply(scaled_uv_err, jnt_visible) 68 | jnt_count = np.sum(jnt_visible, axis=1) 69 | less_than_threshold = np.multiply((scaled_uv_err < threshold), jnt_visible) 70 | PCKh = np.divide(100. * np.sum(less_than_threshold, axis=1), jnt_count) 71 | 72 | 73 | # save 74 | rng = np.arange(0, 0.5, 0.01) 75 | pckAll = np.zeros((len(rng), 16)) 76 | 77 | for r in range(len(rng)): 78 | threshold = rng[r] 79 | less_than_threshold = np.multiply(scaled_uv_err < threshold, jnt_visible) 80 | pckAll[r, :] = np.divide(100.*np.sum(less_than_threshold, axis=1), jnt_count) 81 | 82 | name = predfile.split(os.sep)[-1] 83 | PCKh = np.ma.array(PCKh, mask=False) 84 | PCKh.mask[6:8] = True 85 | print("Model, Head, Shoulder, Elbow, Wrist, Hip , Knee , Ankle , Mean") 86 | print('{:s} {:.2f} {:.2f} {:.2f} {:.2f} {:.2f} {:.2f} {:.2f} {:.2f}'.format(model_name, PCKh[head], 0.5 * (PCKh[lsho] + PCKh[rsho])\ 87 | , 0.5 * (PCKh[lelb] + PCKh[relb]),0.5 * (PCKh[lwri] + PCKh[rwri]), 0.5 * (PCKh[lhip] + PCKh[rhip]), 0.5 * (PCKh[lkne] + PCKh[rkne]) \ 88 | , 0.5 * (PCKh[lank] + PCKh[rank]), np.mean(PCKh))) -------------------------------------------------------------------------------- /evaluation/showskeletons_joints.m: -------------------------------------------------------------------------------- 1 | function h = showskeletons_joints(im, points, pa, msize, torsobox) 2 | if nargin < 4 3 
| msize = 4; 4 | end 5 | if nargin < 5 6 | torsobox = []; 7 | end 8 | p_no = numel(pa); 9 | 10 | switch p_no 11 | case 26 12 | partcolor = {'g','g','y','r','r','r','r','y','y','y','m','m','m','m','y','b','b','b','b','y','y','y','c','c','c','c'}; 13 | case 14 14 | partcolor = {'g','g','y','r','r','y','m','m','y','b','b','y','c','c'}; 15 | case 10 16 | partcolor = {'g','g','y','y','y','r','m','m','m','b','b','b','y','c','c'}; 17 | case 18 18 | partcolor = {'g','g','y','r','r','r','r','y','y','y','y','b','b','b','b','y','y','y'}; 19 | case 16 20 | partcolor = {'g','g','g','r','r','r','y','y','y','b','b','b','c','c','m','m'}; 21 | otherwise 22 | error('showboxes: not supported'); 23 | end 24 | h = imshow(im); hold on; 25 | if ~isempty(points) 26 | x = points(:,1); 27 | y = points(:,2); 28 | for n = 1:size(x,1) 29 | for child = 1:p_no 30 | if child == 0 || pa(child) == 0 31 | continue; 32 | end 33 | x1 = x(pa(child)); 34 | y1 = y(pa(child)); 35 | x2 = x(child); 36 | y2 = y(child); 37 | 38 | plot(x1, y1, 'o', 'color', partcolor{child}, ... 39 | 'MarkerSize',msize, 'MarkerFaceColor', partcolor{child}); 40 | plot(x2, y2, 'o', 'color', partcolor{child}, ... 
41 | 'MarkerSize',msize, 'MarkerFaceColor', partcolor{child}); 42 | line([x1 x2],[y1 y2],'color',partcolor{child},'linewidth',round(msize/2)); 43 | end 44 | end 45 | end 46 | if ~isempty(torsobox) 47 | plotbox(torsobox,'w--'); 48 | end 49 | drawnow; hold off; 50 | -------------------------------------------------------------------------------- /evaluation/utils.py: -------------------------------------------------------------------------------- 1 | 2 | def visualize(oriImg, points, pa): 3 | import matplotlib 4 | import cv2 as cv 5 | import matplotlib.pyplot as plt 6 | import math 7 | 8 | fig = matplotlib.pyplot.gcf() 9 | # fig.set_size_inches(12, 12) 10 | 11 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 12 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 13 | [170,0,255],[255,0,255]] 14 | canvas = oriImg 15 | stickwidth = 4 16 | x = points[:, 0] 17 | y = points[:, 1] 18 | 19 | for n in range(len(x)): 20 | for child in range(len(pa)): 21 | if pa[child] is 0: 22 | continue 23 | 24 | x1 = x[pa[child] - 1] 25 | y1 = y[pa[child] - 1] 26 | x2 = x[child] 27 | y2 = y[child] 28 | 29 | cv.line(canvas, (x1, y1), (x2, y2), colors[child], 8) 30 | 31 | 32 | plt.imshow(canvas[:, :, [2, 1, 0]]) 33 | fig = matplotlib.pyplot.gcf() 34 | fig.set_size_inches(12, 12) 35 | 36 | from time import gmtime, strftime 37 | import os 38 | directory = 'data/mpii/result/test_images' 39 | if not os.path.exists(directory): 40 | os.makedirs(directory) 41 | 42 | fn = os.path.join(directory, strftime("%Y-%m-%d-%H_%M_%S", gmtime()) + '.jpg') 43 | 44 | plt.savefig(fn) -------------------------------------------------------------------------------- /example/lsp.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import os 4 | import argparse 5 | import time 6 | import matplotlib.pyplot as 
plt 7 | 8 | import torch 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | import torchvision.datasets as datasets 13 | 14 | from pose import Bar 15 | from pose.utils.logger import Logger, savefig 16 | from pose.utils.evaluation import accuracy, AverageMeter, final_preds 17 | from pose.utils.misc import save_checkpoint, save_pred, adjust_learning_rate 18 | from pose.utils.osutils import mkdir_p, isfile, isdir, join 19 | from pose.utils.imutils import batch_with_heatmap 20 | from pose.utils.transforms import fliplr, flip_back 21 | import pose.models as models 22 | import pose.datasets as datasets 23 | 24 | model_names = sorted(name for name in models.__dict__ 25 | if name.islower() and not name.startswith("__") 26 | and callable(models.__dict__[name])) 27 | 28 | idx = [1,2,3,4,5,6,11,12,15,16] 29 | 30 | best_acc = 0 31 | 32 | 33 | def main(args): 34 | global best_acc 35 | 36 | # create checkpoint dir 37 | if not isdir(args.checkpoint): 38 | mkdir_p(args.checkpoint) 39 | 40 | # create model 41 | print("==> creating model '{}', stacks={}, blocks={}".format(args.arch, args.stacks, args.blocks)) 42 | model = models.__dict__[args.arch](num_stacks=args.stacks, num_blocks=args.blocks, num_classes=args.num_classes) 43 | 44 | model = torch.nn.DataParallel(model).cuda() 45 | 46 | # define loss function (criterion) and optimizer 47 | criterion = torch.nn.MSELoss(size_average=True).cuda() 48 | 49 | optimizer = torch.optim.RMSprop(model.parameters(), 50 | lr=args.lr, 51 | momentum=args.momentum, 52 | weight_decay=args.weight_decay) 53 | 54 | # optionally resume from a checkpoint 55 | title = 'LSP-' + args.arch 56 | if args.resume: 57 | if isfile(args.resume): 58 | print("=> loading checkpoint '{}'".format(args.resume)) 59 | checkpoint = torch.load(args.resume) 60 | args.start_epoch = checkpoint['epoch'] 61 | best_acc = checkpoint['best_acc'] 62 | model.load_state_dict(checkpoint['state_dict']) 63 | 
optimizer.load_state_dict(checkpoint['optimizer']) 64 | print("=> loaded checkpoint '{}' (epoch {})" 65 | .format(args.resume, checkpoint['epoch'])) 66 | logger = Logger(join(args.checkpoint, 'log.txt'), title=title, resume=True) 67 | else: 68 | print("=> no checkpoint found at '{}'".format(args.resume)) 69 | else: 70 | logger = Logger(join(args.checkpoint, 'log.txt'), title=title) 71 | logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc']) 72 | 73 | cudnn.benchmark = True 74 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 75 | 76 | # Data loading code 77 | train_loader = torch.utils.data.DataLoader( 78 | datasets.LSP('data/lsp/LEEDS_annotations.json', 'data/lsp/images', 79 | sigma=args.sigma, label_type=args.label_type), 80 | batch_size=args.train_batch, shuffle=True, 81 | num_workers=args.workers, pin_memory=True) 82 | 83 | val_loader = torch.utils.data.DataLoader( 84 | datasets.LSP('data/lsp/LEEDS_annotations.json', 'data/lsp/images', 85 | sigma=args.sigma, label_type=args.label_type, train=False), 86 | batch_size=args.test_batch, shuffle=False, 87 | num_workers=args.workers, pin_memory=True) 88 | 89 | if args.evaluate: 90 | print('\nEvaluation only') 91 | loss, acc, predictions = validate(val_loader, model, criterion, args.num_classes, args.debug, args.flip) 92 | save_pred(predictions, checkpoint=args.checkpoint) 93 | return 94 | 95 | lr = args.lr 96 | for epoch in range(args.start_epoch, args.epochs): 97 | lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma) 98 | print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr)) 99 | 100 | # decay sigma 101 | if args.sigma_decay > 0: 102 | train_loader.dataset.sigma *= args.sigma_decay 103 | val_loader.dataset.sigma *= args.sigma_decay 104 | 105 | # train for one epoch 106 | train_loss, train_acc = train(train_loader, model, criterion, optimizer, args.debug, args.flip) 107 | 108 | # evaluate on validation set 109 | valid_loss, 
valid_acc, predictions = validate(val_loader, model, criterion, args.num_classes, 110 | args.debug, args.flip) 111 | 112 | # append logger file 113 | logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc]) 114 | 115 | # remember best acc and save checkpoint 116 | is_best = valid_acc > best_acc 117 | best_acc = max(valid_acc, best_acc) 118 | save_checkpoint({ 119 | 'epoch': epoch + 1, 120 | 'arch': args.arch, 121 | 'state_dict': model.state_dict(), 122 | 'best_acc': best_acc, 123 | 'optimizer' : optimizer.state_dict(), 124 | }, predictions, is_best, checkpoint=args.checkpoint) 125 | 126 | logger.close() 127 | logger.plot(['Train Acc', 'Val Acc']) 128 | savefig(os.path.join(args.checkpoint, 'log.eps')) 129 | 130 | 131 | def train(train_loader, model, criterion, optimizer, debug=False, flip=True): 132 | batch_time = AverageMeter() 133 | data_time = AverageMeter() 134 | losses = AverageMeter() 135 | acces = AverageMeter() 136 | 137 | # switch to train mode 138 | model.train() 139 | 140 | end = time.time() 141 | 142 | gt_win, pred_win = None, None 143 | bar = Bar('Processing', max=len(train_loader)) 144 | for i, (inputs, target, meta) in enumerate(train_loader): 145 | # measure data loading time 146 | data_time.update(time.time() - end) 147 | 148 | input_var = torch.autograd.Variable(inputs.cuda()) 149 | target_var = torch.autograd.Variable(target.cuda(async=True)) 150 | 151 | # compute output 152 | output = model(input_var) 153 | score_map = output[-1].data.cpu() 154 | 155 | loss = criterion(output[0], target_var) 156 | for j in range(1, len(output)): 157 | loss += criterion(output[j], target_var) 158 | acc = accuracy(score_map, target, idx) 159 | 160 | if debug: # visualize groundtruth and predictions 161 | gt_batch_img = batch_with_heatmap(inputs, target) 162 | pred_batch_img = batch_with_heatmap(inputs, score_map) 163 | if not gt_win or not pred_win: 164 | ax1 = plt.subplot(121) 165 | ax1.title.set_text('Groundtruth') 166 | gt_win = 
plt.imshow(gt_batch_img) 167 | ax2 = plt.subplot(122) 168 | ax2.title.set_text('Prediction') 169 | pred_win = plt.imshow(pred_batch_img) 170 | else: 171 | gt_win.set_data(gt_batch_img) 172 | pred_win.set_data(pred_batch_img) 173 | plt.pause(.05) 174 | plt.draw() 175 | 176 | # measure accuracy and record loss 177 | losses.update(loss.data[0], inputs.size(0)) 178 | acces.update(acc[0], inputs.size(0)) 179 | 180 | # compute gradient and do SGD step 181 | optimizer.zero_grad() 182 | loss.backward() 183 | optimizer.step() 184 | 185 | # measure elapsed time 186 | batch_time.update(time.time() - end) 187 | end = time.time() 188 | 189 | # plot progress 190 | bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Acc: {acc: .4f}'.format( 191 | batch=i + 1, 192 | size=len(train_loader), 193 | data=data_time.val, 194 | bt=batch_time.val, 195 | total=bar.elapsed_td, 196 | eta=bar.eta_td, 197 | loss=losses.avg, 198 | acc=acces.avg 199 | ) 200 | bar.next() 201 | 202 | bar.finish() 203 | return losses.avg, acces.avg 204 | 205 | 206 | def validate(val_loader, model, criterion, num_classes, debug=False, flip=True): 207 | batch_time = AverageMeter() 208 | data_time = AverageMeter() 209 | losses = AverageMeter() 210 | acces = AverageMeter() 211 | 212 | # predictions 213 | predictions = torch.Tensor(val_loader.dataset.__len__(), num_classes, 2) 214 | 215 | # switch to evaluate mode 216 | model.eval() 217 | 218 | gt_win, pred_win = None, None 219 | end = time.time() 220 | bar = Bar('Processing', max=len(val_loader)) 221 | for i, (inputs, target, meta) in enumerate(val_loader): 222 | # measure data loading time 223 | data_time.update(time.time() - end) 224 | 225 | target = target.cuda(async=True) 226 | 227 | input_var = torch.autograd.Variable(inputs.cuda(), volatile=True) 228 | target_var = torch.autograd.Variable(target, volatile=True) 229 | 230 | # compute output 231 | output = model(input_var) 232 | score_map = 
output[-1].data.cpu() 233 | if flip: 234 | flip_input_var = torch.autograd.Variable( 235 | torch.from_numpy(fliplr(inputs.clone().numpy())).float().cuda(), 236 | volatile=True 237 | ) 238 | flip_output_var = model(flip_input_var) 239 | flip_output = flip_back(flip_output_var[-1].data.cpu()) 240 | score_map += flip_output 241 | 242 | 243 | 244 | loss = 0 245 | for o in output: 246 | loss += criterion(o, target_var) 247 | acc = accuracy(score_map, target.cpu(), idx) 248 | 249 | # generate predictions 250 | preds = final_preds(score_map, meta['center'], meta['scale'], [64, 64]) 251 | for n in range(score_map.size(0)): 252 | predictions[meta['index'][n], :, :] = preds[n, :, :] 253 | 254 | 255 | if debug: 256 | gt_batch_img = batch_with_heatmap(inputs, target) 257 | pred_batch_img = batch_with_heatmap(inputs, score_map) 258 | if not gt_win or not pred_win: 259 | plt.subplot(121) 260 | gt_win = plt.imshow(gt_batch_img) 261 | plt.subplot(122) 262 | pred_win = plt.imshow(pred_batch_img) 263 | else: 264 | gt_win.set_data(gt_batch_img) 265 | pred_win.set_data(pred_batch_img) 266 | plt.pause(.05) 267 | plt.draw() 268 | 269 | # measure accuracy and record loss 270 | losses.update(loss.data[0], inputs.size(0)) 271 | acces.update(acc[0], inputs.size(0)) 272 | 273 | # measure elapsed time 274 | batch_time.update(time.time() - end) 275 | end = time.time() 276 | 277 | # plot progress 278 | bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Acc: {acc: .4f}'.format( 279 | batch=i + 1, 280 | size=len(val_loader), 281 | data=data_time.val, 282 | bt=batch_time.avg, 283 | total=bar.elapsed_td, 284 | eta=bar.eta_td, 285 | loss=losses.avg, 286 | acc=acces.avg 287 | ) 288 | bar.next() 289 | 290 | bar.finish() 291 | return losses.avg, acces.avg, predictions 292 | 293 | if __name__ == '__main__': 294 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 295 | # Model structure 296 | 
parser.add_argument('--arch', '-a', metavar='ARCH', default='hg', 297 | choices=model_names, 298 | help='model architecture: ' + 299 | ' | '.join(model_names) + 300 | ' (default: resnet18)') 301 | parser.add_argument('-s', '--stacks', default=8, type=int, metavar='N', 302 | help='Number of hourglasses to stack') 303 | parser.add_argument('--features', default=256, type=int, metavar='N', 304 | help='Number of features in the hourglass') 305 | parser.add_argument('-b', '--blocks', default=1, type=int, metavar='N', 306 | help='Number of residual modules at each location in the hourglass') 307 | parser.add_argument('--num-classes', default=16, type=int, metavar='N', 308 | help='Number of keypoints') 309 | # Training strategy 310 | parser.add_argument('-j', '--workers', default=1, type=int, metavar='N', 311 | help='number of data loading workers (default: 4)') 312 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 313 | help='number of total epochs to run') 314 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 315 | help='manual epoch number (useful on restarts)') 316 | parser.add_argument('--train-batch', default=6, type=int, metavar='N', 317 | help='train batchsize') 318 | parser.add_argument('--test-batch', default=6, type=int, metavar='N', 319 | help='test batchsize') 320 | parser.add_argument('--lr', '--learning-rate', default=2.5e-4, type=float, 321 | metavar='LR', help='initial learning rate') 322 | parser.add_argument('--momentum', default=0, type=float, metavar='M', 323 | help='momentum') 324 | parser.add_argument('--weight-decay', '--wd', default=0, type=float, 325 | metavar='W', help='weight decay (default: 0)') 326 | parser.add_argument('--schedule', type=int, nargs='+', default=[60, 90], 327 | help='Decrease learning rate at these epochs.') 328 | parser.add_argument('--gamma', type=float, default=0.1, 329 | help='LR is multiplied by gamma on schedule.') 330 | # Data processing 331 | parser.add_argument('-f', 
'--flip', dest='flip', action='store_true', 332 | help='flip the input during validation') 333 | parser.add_argument('--sigma', type=float, default=1, 334 | help='Groundtruth Gaussian sigma.') 335 | parser.add_argument('--sigma-decay', type=float, default=0, 336 | help='Sigma decay rate for each epoch.') 337 | parser.add_argument('--label-type', metavar='LABELTYPE', default='Gaussian', 338 | choices=['Gaussian', 'Cauchy'], 339 | help='Labelmap dist type: (default=Gaussian)') 340 | # Miscs 341 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', 342 | help='path to save checkpoint (default: checkpoint)') 343 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 344 | help='path to latest checkpoint (default: none)') 345 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 346 | help='evaluate model on validation set') 347 | parser.add_argument('-d', '--debug', dest='debug', action='store_true', 348 | help='show intermediate results') 349 | 350 | 351 | main(parser.parse_args()) -------------------------------------------------------------------------------- /example/mpii.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import os 4 | import argparse 5 | import time 6 | import matplotlib.pyplot as plt 7 | 8 | import torch 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | import torchvision.datasets as datasets 13 | 14 | from pose import Bar 15 | from pose.utils.logger import Logger, savefig 16 | from pose.utils.evaluation import accuracy, AverageMeter, final_preds 17 | from pose.utils.misc import save_checkpoint, save_pred, adjust_learning_rate 18 | from pose.utils.osutils import mkdir_p, isfile, isdir, join 19 | from pose.utils.imutils import batch_with_heatmap 20 | from pose.utils.transforms import fliplr, flip_back 21 | import 
pose.models as models 22 | import pose.datasets as datasets 23 | 24 | model_names = sorted(name for name in models.__dict__ 25 | if name.islower() and not name.startswith("__") 26 | and callable(models.__dict__[name])) 27 | 28 | idx = [1,2,3,4,5,6,11,12,15,16] 29 | 30 | best_acc = 0 31 | 32 | 33 | def main(args): 34 | global best_acc 35 | 36 | # create checkpoint dir 37 | if not isdir(args.checkpoint): 38 | mkdir_p(args.checkpoint) 39 | 40 | # create model 41 | print("==> creating model '{}', stacks={}, blocks={}".format(args.arch, args.stacks, args.blocks)) 42 | model = models.__dict__[args.arch](num_stacks=args.stacks, num_blocks=args.blocks, num_classes=args.num_classes, mobile=args.mobile) 43 | 44 | model = torch.nn.DataParallel(model).cuda() 45 | 46 | # define loss function (criterion) and optimizer 47 | criterion = torch.nn.MSELoss(size_average=True).cuda() 48 | 49 | optimizer = torch.optim.RMSprop(model.parameters(), 50 | lr=args.lr, 51 | momentum=args.momentum, 52 | weight_decay=args.weight_decay) 53 | 54 | # optionally resume from a checkpoint 55 | title = 'mpii-' + args.arch 56 | if args.resume: 57 | if isfile(args.resume): 58 | print("=> loading checkpoint '{}'".format(args.resume)) 59 | checkpoint = torch.load(args.resume) 60 | args.start_epoch = checkpoint['epoch'] 61 | best_acc = checkpoint['best_acc'] 62 | model.load_state_dict(checkpoint['state_dict']) 63 | optimizer.load_state_dict(checkpoint['optimizer']) 64 | print("=> loaded checkpoint '{}' (epoch {})" 65 | .format(args.resume, checkpoint['epoch'])) 66 | logger = Logger(join(args.checkpoint, 'log.txt'), title=title, resume=True) 67 | else: 68 | print("=> no checkpoint found at '{}'".format(args.resume)) 69 | else: 70 | logger = Logger(join(args.checkpoint, 'log.txt'), title=title) 71 | logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc']) 72 | 73 | cudnn.benchmark = True 74 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 75 | 
76 | # Data loading code 77 | train_loader = torch.utils.data.DataLoader( 78 | datasets.Mpii('data/mpii/mpii_annotations.json', 'data/mpii/images', 79 | sigma=args.sigma, label_type=args.label_type, 80 | inp_res=args.in_res, out_res=args.in_res//4), 81 | batch_size=args.train_batch, shuffle=True, 82 | num_workers=args.workers, pin_memory=True) 83 | 84 | val_loader = torch.utils.data.DataLoader( 85 | datasets.Mpii('data/mpii/mpii_annotations.json', 'data/mpii/images', 86 | sigma=args.sigma, label_type=args.label_type, train=False, 87 | inp_res=args.in_res, out_res=args.in_res // 4), 88 | batch_size=args.test_batch, shuffle=False, 89 | num_workers=args.workers, pin_memory=True) 90 | 91 | if args.evaluate: 92 | print('\nEvaluation only') 93 | loss, acc, predictions = validate(val_loader, model, criterion, args.num_classes, args.in_res//4, args.debug, args.flip) 94 | save_pred(predictions, checkpoint=args.checkpoint) 95 | return 96 | 97 | lr = args.lr 98 | for epoch in range(args.start_epoch, args.epochs): 99 | lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma) 100 | print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr)) 101 | 102 | # decay sigma 103 | if args.sigma_decay > 0: 104 | train_loader.dataset.sigma *= args.sigma_decay 105 | val_loader.dataset.sigma *= args.sigma_decay 106 | 107 | # train for one epoch 108 | train_loss, train_acc = train(train_loader, model, criterion, optimizer, args.debug, args.flip) 109 | 110 | # evaluate on validation set 111 | valid_loss, valid_acc, predictions = validate(val_loader, model, criterion, args.num_classes, 112 | args.in_res//4, args.debug, args.flip) 113 | 114 | # append logger file 115 | logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc]) 116 | 117 | # remember best acc and save checkpoint 118 | is_best = valid_acc > best_acc 119 | best_acc = max(valid_acc, best_acc) 120 | save_checkpoint({ 121 | 'epoch': epoch + 1, 122 | 'arch': args.arch, 123 | 'state_dict': model.state_dict(), 
124 | 'best_acc': best_acc, 125 | 'optimizer' : optimizer.state_dict(), 126 | }, predictions, is_best, checkpoint=args.checkpoint) 127 | 128 | logger.close() 129 | logger.plot(['Train Acc', 'Val Acc']) 130 | savefig(os.path.join(args.checkpoint, 'log.eps')) 131 | 132 | 133 | def train(train_loader, model, criterion, optimizer, debug=False, flip=True): 134 | batch_time = AverageMeter() 135 | data_time = AverageMeter() 136 | losses = AverageMeter() 137 | acces = AverageMeter() 138 | 139 | # switch to train mode 140 | model.train() 141 | 142 | end = time.time() 143 | 144 | gt_win, pred_win = None, None 145 | bar = Bar('Processing', max=len(train_loader)) 146 | for i, (inputs, target, meta) in enumerate(train_loader): 147 | # measure data loading time 148 | data_time.update(time.time() - end) 149 | 150 | input_var = torch.autograd.Variable(inputs.cuda()) 151 | target_var = torch.autograd.Variable(target.cuda(async=True)) 152 | 153 | # compute output 154 | output = model(input_var) 155 | score_map = output[-1].data.cpu() 156 | 157 | loss = criterion(output[0], target_var) 158 | for j in range(1, len(output)): 159 | loss += criterion(output[j], target_var) 160 | acc = accuracy(score_map, target, idx) 161 | 162 | if debug: # visualize groundtruth and predictions 163 | gt_batch_img = batch_with_heatmap(inputs, target) 164 | pred_batch_img = batch_with_heatmap(inputs, score_map) 165 | if not gt_win or not pred_win: 166 | ax1 = plt.subplot(121) 167 | ax1.title.set_text('Groundtruth') 168 | gt_win = plt.imshow(gt_batch_img) 169 | ax2 = plt.subplot(122) 170 | ax2.title.set_text('Prediction') 171 | pred_win = plt.imshow(pred_batch_img) 172 | else: 173 | gt_win.set_data(gt_batch_img) 174 | pred_win.set_data(pred_batch_img) 175 | plt.pause(.05) 176 | plt.draw() 177 | 178 | # measure accuracy and record loss 179 | losses.update(loss.item(), inputs.size(0)) 180 | acces.update(acc[0], inputs.size(0)) 181 | 182 | # compute gradient and do SGD step 183 | optimizer.zero_grad() 184 | 
loss.backward() 185 | optimizer.step() 186 | 187 | # measure elapsed time 188 | batch_time.update(time.time() - end) 189 | end = time.time() 190 | 191 | # plot progress 192 | bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Acc: {acc: .4f}'.format( 193 | batch=i + 1, 194 | size=len(train_loader), 195 | data=data_time.val, 196 | bt=batch_time.val, 197 | total=bar.elapsed_td, 198 | eta=bar.eta_td, 199 | loss=losses.avg, 200 | acc=acces.avg 201 | ) 202 | bar.next() 203 | 204 | bar.finish() 205 | return losses.avg, acces.avg 206 | 207 | 208 | def validate(val_loader, model, criterion, num_classes, out_res, debug=False, flip=True): 209 | batch_time = AverageMeter() 210 | data_time = AverageMeter() 211 | losses = AverageMeter() 212 | acces = AverageMeter() 213 | 214 | # predictions 215 | predictions = torch.Tensor(val_loader.dataset.__len__(), num_classes, 2) 216 | 217 | # switch to evaluate mode 218 | model.eval() 219 | 220 | gt_win, pred_win = None, None 221 | end = time.time() 222 | bar = Bar('Processing', max=len(val_loader)) 223 | for i, (inputs, target, meta) in enumerate(val_loader): 224 | # measure data loading time 225 | data_time.update(time.time() - end) 226 | 227 | target = target.cuda(async=True) 228 | 229 | input_var = torch.autograd.Variable(inputs.cuda(), volatile=True) 230 | target_var = torch.autograd.Variable(target, volatile=True) 231 | 232 | # compute output 233 | output = model(input_var) 234 | score_map = output[-1].data.cpu() 235 | if flip: 236 | flip_input_var = torch.autograd.Variable( 237 | torch.from_numpy(fliplr(inputs.clone().numpy())).float().cuda(), 238 | volatile=True 239 | ) 240 | flip_output_var = model(flip_input_var) 241 | flip_output = flip_back(flip_output_var[-1].data.cpu()) 242 | score_map += flip_output 243 | 244 | 245 | 246 | loss = 0 247 | for o in output: 248 | loss += criterion(o, target_var) 249 | acc = accuracy(score_map, target.cpu(), idx) 250 | 251 | 
# generate predictions 252 | preds = final_preds(score_map, meta['center'], meta['scale'], [out_res, out_res]) 253 | for n in range(score_map.size(0)): 254 | predictions[meta['index'][n], :, :] = preds[n, :, :] 255 | 256 | 257 | if debug: 258 | gt_batch_img = batch_with_heatmap(inputs, target) 259 | pred_batch_img = batch_with_heatmap(inputs, score_map) 260 | if not gt_win or not pred_win: 261 | plt.subplot(121) 262 | gt_win = plt.imshow(gt_batch_img) 263 | plt.subplot(122) 264 | pred_win = plt.imshow(pred_batch_img) 265 | else: 266 | gt_win.set_data(gt_batch_img) 267 | pred_win.set_data(pred_batch_img) 268 | plt.pause(.05) 269 | plt.draw() 270 | 271 | # measure accuracy and record loss 272 | losses.update(loss.item(), inputs.size(0)) 273 | acces.update(acc[0], inputs.size(0)) 274 | 275 | # measure elapsed time 276 | batch_time.update(time.time() - end) 277 | end = time.time() 278 | 279 | # plot progress 280 | bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Acc: {acc: .4f}'.format( 281 | batch=i + 1, 282 | size=len(val_loader), 283 | data=data_time.val, 284 | bt=batch_time.avg, 285 | total=bar.elapsed_td, 286 | eta=bar.eta_td, 287 | loss=losses.avg, 288 | acc=acces.avg 289 | ) 290 | bar.next() 291 | 292 | bar.finish() 293 | return losses.avg, acces.avg, predictions 294 | 295 | if __name__ == '__main__': 296 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 297 | # Model structure 298 | parser.add_argument('--arch', '-a', metavar='ARCH', default='hg', 299 | choices=model_names, 300 | help='model architecture: ' + 301 | ' | '.join(model_names) + 302 | ' (default: resnet18)') 303 | parser.add_argument('-s', '--stacks', default=8, type=int, metavar='N', 304 | help='Number of hourglasses to stack') 305 | parser.add_argument('--features', default=256, type=int, metavar='N', 306 | help='Number of features in the hourglass') 307 | parser.add_argument('-b', '--blocks', 
default=1, type=int, metavar='N', 308 | help='Number of residual modules at each location in the hourglass') 309 | parser.add_argument('--num-classes', default=16, type=int, metavar='N', 310 | help='Number of keypoints') 311 | parser.add_argument('--mobile', default=False, type=bool, metavar='N', 312 | help='use depthwise convolution in bottneck-block') 313 | # Training strategy 314 | parser.add_argument('-j', '--workers', default=1, type=int, metavar='N', 315 | help='number of data loading workers (default: 4)') 316 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 317 | help='number of total epochs to run') 318 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 319 | help='manual epoch number (useful on restarts)') 320 | parser.add_argument('--train-batch', default=6, type=int, metavar='N', 321 | help='train batchsize') 322 | parser.add_argument('--test-batch', default=6, type=int, metavar='N', 323 | help='test batchsize') 324 | parser.add_argument('--lr', '--learning-rate', default=2.5e-4, type=float, 325 | metavar='LR', help='initial learning rate') 326 | parser.add_argument('--momentum', default=0, type=float, metavar='M', 327 | help='momentum') 328 | parser.add_argument('--weight-decay', '--wd', default=0, type=float, 329 | metavar='W', help='weight decay (default: 0)') 330 | parser.add_argument('--schedule', type=int, nargs='+', default=[60, 90], 331 | help='Decrease learning rate at these epochs.') 332 | parser.add_argument('--gamma', type=float, default=0.1, 333 | help='LR is multiplied by gamma on schedule.') 334 | # Data processing 335 | parser.add_argument('-f', '--flip', dest='flip', action='store_true', 336 | help='flip the input during validation') 337 | parser.add_argument('--sigma', type=float, default=1, 338 | help='Groundtruth Gaussian sigma.') 339 | parser.add_argument('--sigma-decay', type=float, default=0, 340 | help='Sigma decay rate for each epoch.') 341 | parser.add_argument('--label-type', 
metavar='LABELTYPE', default='Gaussian', 342 | choices=['Gaussian', 'Cauchy'], 343 | help='Labelmap dist type: (default=Gaussian)') 344 | parser.add_argument('--in_res', default=256, type=int, 345 | choices=[256, 192], 346 | help='input resolution for network') 347 | # Miscs 348 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', 349 | help='path to save checkpoint (default: checkpoint)') 350 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 351 | help='path to latest checkpoint (default: none)') 352 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 353 | help='evaluate model on validation set') 354 | parser.add_argument('-d', '--debug', dest='debug', action='store_true', 355 | help='show intermediate results') 356 | 357 | main(parser.parse_args()) -------------------------------------------------------------------------------- /example/mscoco.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import os 4 | import argparse 5 | import time 6 | import matplotlib.pyplot as plt 7 | 8 | import torch 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | import torchvision.datasets as datasets 13 | 14 | from pose import Bar 15 | from pose.utils.logger import Logger, savefig 16 | from pose.utils.evaluation import accuracy, AverageMeter, final_preds 17 | from pose.utils.misc import save_checkpoint, save_pred, adjust_learning_rate 18 | from pose.utils.osutils import mkdir_p, isfile, isdir, join 19 | from pose.utils.imutils import batch_with_heatmap 20 | from pose.utils.transforms import fliplr, flip_back 21 | import pose.models as models 22 | import pose.datasets as datasets 23 | 24 | model_names = sorted(name for name in models.__dict__ 25 | if name.islower() and not name.startswith("__") 26 | and callable(models.__dict__[name])) 27 | 28 | 
idx = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17] 29 | 30 | best_acc = 0 31 | 32 | 33 | def main(args): 34 | global best_acc 35 | 36 | # create checkpoint dir 37 | if not isdir(args.checkpoint): 38 | mkdir_p(args.checkpoint) 39 | 40 | # create model 41 | print("==> creating model '{}', stacks={}, blocks={}".format(args.arch, args.stacks, args.blocks)) 42 | model = models.__dict__[args.arch](num_stacks=args.stacks, num_blocks=args.blocks, num_classes=args.num_classes) 43 | 44 | model = torch.nn.DataParallel(model).cuda() 45 | 46 | # define loss function (criterion) and optimizer 47 | criterion = torch.nn.MSELoss(size_average=True).cuda() 48 | 49 | optimizer = torch.optim.RMSprop(model.parameters(), 50 | lr=args.lr, 51 | momentum=args.momentum, 52 | weight_decay=args.weight_decay) 53 | 54 | # optionally resume from a checkpoint 55 | title = 'MSCOCO-' + args.arch 56 | if args.resume: 57 | if isfile(args.resume): 58 | print("=> loading checkpoint '{}'".format(args.resume)) 59 | checkpoint = torch.load(args.resume) 60 | args.start_epoch = checkpoint['epoch'] 61 | best_acc = checkpoint['best_acc'] 62 | model.load_state_dict(checkpoint['state_dict']) 63 | optimizer.load_state_dict(checkpoint['optimizer']) 64 | print("=> loaded checkpoint '{}' (epoch {})" 65 | .format(args.resume, checkpoint['epoch'])) 66 | logger = Logger(join(args.checkpoint, 'log.txt'), title=title, resume=True) 67 | else: 68 | print("=> no checkpoint found at '{}'".format(args.resume)) 69 | else: 70 | logger = Logger(join(args.checkpoint, 'log.txt'), title=title) 71 | logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc']) 72 | 73 | cudnn.benchmark = True 74 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 75 | 76 | # Data loading code 77 | train_loader = torch.utils.data.DataLoader( 78 | datasets.Mscoco('data/mscoco/coco_annotations.json', 'data/mscoco/keypoint/images/train2014', 79 | sigma=args.sigma, label_type=args.label_type), 80 
| batch_size=args.train_batch, shuffle=True, 81 | num_workers=args.workers, pin_memory=True) 82 | 83 | val_loader = torch.utils.data.DataLoader( 84 | datasets.Mscoco('data/mscoco/coco_annotations.json', 'data/mscoco/keypoint/images/val2014', 85 | sigma=args.sigma, label_type=args.label_type, train=False), 86 | batch_size=args.test_batch, shuffle=False, 87 | num_workers=args.workers, pin_memory=True) 88 | 89 | if args.evaluate: 90 | print('\nEvaluation only') 91 | loss, acc, predictions = validate(val_loader, model, criterion, args.num_classes, args.debug, args.flip) 92 | save_pred(predictions, checkpoint=args.checkpoint) 93 | return 94 | 95 | lr = args.lr 96 | for epoch in range(args.start_epoch, args.epochs): 97 | lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma) 98 | print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr)) 99 | 100 | # train for one epoch 101 | train_loss, train_acc = train(train_loader, model, criterion, optimizer, args.debug, args.flip) 102 | 103 | # evaluate on validation set 104 | valid_loss, valid_acc, predictions = validate(val_loader, model, criterion, args.num_classes, args.debug, args.flip) 105 | 106 | # append logger file 107 | logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc]) 108 | 109 | # remember best acc and save checkpoint 110 | is_best = valid_acc > best_acc 111 | best_acc = max(valid_acc, best_acc) 112 | save_checkpoint({ 113 | 'epoch': epoch + 1, 114 | 'arch': args.arch, 115 | 'state_dict': model.state_dict(), 116 | 'best_acc': best_acc, 117 | 'optimizer' : optimizer.state_dict(), 118 | }, predictions, is_best, checkpoint=args.checkpoint, snapshot=args.snapshot) 119 | 120 | logger.close() 121 | logger.plot(['Train Acc', 'Val Acc']) 122 | savefig(os.path.join(args.checkpoint, 'log.eps')) 123 | 124 | 125 | def train(train_loader, model, criterion, optimizer, debug=False, flip=True): 126 | batch_time = AverageMeter() 127 | data_time = AverageMeter() 128 | losses = AverageMeter() 129 | 
acces = AverageMeter() 130 | 131 | # switch to train mode 132 | model.train() 133 | 134 | end = time.time() 135 | 136 | gt_win, pred_win = None, None 137 | bar = Bar('Processing', max=len(train_loader)) 138 | for i, (inputs, target, meta) in enumerate(train_loader): 139 | # measure data loading time 140 | data_time.update(time.time() - end) 141 | 142 | input_var = torch.autograd.Variable(inputs.cuda()) 143 | target_var = torch.autograd.Variable(target.cuda(async=True)) 144 | 145 | # compute output 146 | output = model(input_var) 147 | score_map = output[-1].data.cpu() 148 | 149 | loss = criterion(output[0], target_var) 150 | for j in range(1, len(output)): 151 | loss += criterion(output[j], target_var) 152 | acc = accuracy(score_map, target, idx) 153 | 154 | if debug: # visualize groundtruth and predictions 155 | gt_batch_img = batch_with_heatmap(inputs, target) 156 | pred_batch_img = batch_with_heatmap(inputs, score_map) 157 | if not gt_win or not pred_win: 158 | ax1 = plt.subplot(121) 159 | ax1.title.set_text('Groundtruth') 160 | gt_win = plt.imshow(gt_batch_img) 161 | ax2 = plt.subplot(122) 162 | ax2.title.set_text('Prediction') 163 | pred_win = plt.imshow(pred_batch_img) 164 | else: 165 | gt_win.set_data(gt_batch_img) 166 | pred_win.set_data(pred_batch_img) 167 | plt.pause(.05) 168 | plt.draw() 169 | 170 | # measure accuracy and record loss 171 | losses.update(loss.data[0], inputs.size(0)) 172 | acces.update(acc[0], inputs.size(0)) 173 | 174 | # compute gradient and do SGD step 175 | optimizer.zero_grad() 176 | loss.backward() 177 | optimizer.step() 178 | 179 | # measure elapsed time 180 | batch_time.update(time.time() - end) 181 | end = time.time() 182 | 183 | # plot progress 184 | bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Acc: {acc: .4f}'.format( 185 | batch=i + 1, 186 | size=len(train_loader), 187 | data=data_time.val, 188 | bt=batch_time.val, 189 | total=bar.elapsed_td, 190 | 
eta=bar.eta_td, 191 | loss=losses.avg, 192 | acc=acces.avg 193 | ) 194 | bar.next() 195 | 196 | bar.finish() 197 | return losses.avg, acces.avg 198 | 199 | 200 | def validate(val_loader, model, criterion, num_classes, debug=False, flip=True): 201 | batch_time = AverageMeter() 202 | data_time = AverageMeter() 203 | losses = AverageMeter() 204 | acces = AverageMeter() 205 | 206 | # predictions 207 | predictions = torch.Tensor(val_loader.dataset.__len__(), num_classes, 2) 208 | 209 | # switch to evaluate mode 210 | model.eval() 211 | 212 | gt_win, pred_win = None, None 213 | end = time.time() 214 | bar = Bar('Processing', max=len(val_loader)) 215 | for i, (inputs, target, meta) in enumerate(val_loader): 216 | # measure data loading time 217 | data_time.update(time.time() - end) 218 | 219 | target = target.cuda(async=True) 220 | 221 | input_var = torch.autograd.Variable(inputs.cuda(), volatile=True) 222 | target_var = torch.autograd.Variable(target, volatile=True) 223 | 224 | # compute output 225 | output = model(input_var) 226 | score_map = output[-1].data.cpu() 227 | if flip: 228 | flip_input_var = torch.autograd.Variable( 229 | torch.from_numpy(fliplr(inputs.clone().numpy())).float().cuda(), 230 | volatile=True 231 | ) 232 | flip_output_var = model(flip_input_var) 233 | flip_output = flip_back(flip_output_var[-1].data.cpu()) 234 | score_map += flip_output 235 | 236 | 237 | 238 | loss = 0 239 | for o in output: 240 | loss += criterion(o, target_var) 241 | acc = accuracy(score_map, target.cpu(), idx) 242 | 243 | # generate predictions 244 | preds = final_preds(score_map, meta['center'], meta['scale'], [64, 64]) 245 | for n in range(score_map.size(0)): 246 | predictions[meta['index'][n], :, :] = preds[n, :, :] 247 | 248 | 249 | if debug: 250 | gt_batch_img = batch_with_heatmap(inputs, target) 251 | pred_batch_img = batch_with_heatmap(inputs, score_map) 252 | if not gt_win or not pred_win: 253 | plt.subplot(121) 254 | gt_win = plt.imshow(gt_batch_img) 255 | 
plt.subplot(122) 256 | pred_win = plt.imshow(pred_batch_img) 257 | else: 258 | gt_win.set_data(gt_batch_img) 259 | pred_win.set_data(pred_batch_img) 260 | plt.pause(.5) 261 | plt.draw() 262 | 263 | # measure accuracy and record loss 264 | losses.update(loss.data[0], inputs.size(0)) 265 | acces.update(acc[0], inputs.size(0)) 266 | 267 | # measure elapsed time 268 | batch_time.update(time.time() - end) 269 | end = time.time() 270 | 271 | # plot progress 272 | bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Acc: {acc: .4f}'.format( 273 | batch=i + 1, 274 | size=len(val_loader), 275 | data=data_time.val, 276 | bt=batch_time.avg, 277 | total=bar.elapsed_td, 278 | eta=bar.eta_td, 279 | loss=losses.avg, 280 | acc=acces.avg 281 | ) 282 | bar.next() 283 | 284 | bar.finish() 285 | return losses.avg, acces.avg, predictions 286 | 287 | if __name__ == '__main__': 288 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 289 | parser.add_argument('--arch', '-a', metavar='ARCH', default='hg', 290 | choices=model_names, 291 | help='model architecture: ' + 292 | ' | '.join(model_names) + 293 | ' (default: resnet18)') 294 | parser.add_argument('--num-classes', default=17, type=int, metavar='N', 295 | help='Number of keypoints') 296 | parser.add_argument('-j', '--workers', default=1, type=int, metavar='N', 297 | help='number of data loading workers (default: 4)') 298 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 299 | help='number of total epochs to run') 300 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 301 | help='manual epoch number (useful on restarts)') 302 | parser.add_argument('--snapshot', default=0, type=int, metavar='N', 303 | help='How often to take a snapshot of the model (0 = never)') 304 | parser.add_argument('--train-batch', default=6, type=int, metavar='N', 305 | help='train batchsize') 306 | 
parser.add_argument('--test-batch', default=6, type=int, metavar='N', 307 | help='test batchsize') 308 | parser.add_argument('--lr', '--learning-rate', default=2.5e-4, type=float, 309 | metavar='LR', help='initial learning rate') 310 | parser.add_argument('--momentum', default=0, type=float, metavar='M', 311 | help='momentum') 312 | parser.add_argument('--weight-decay', '--wd', default=0, type=float, 313 | metavar='W', help='weight decay (default: 0)') 314 | parser.add_argument('--schedule', type=int, nargs='+', default=[60, 90], 315 | help='Decrease learning rate at these epochs.') 316 | parser.add_argument('--gamma', type=float, default=0.1, 317 | help='LR is multiplied by gamma on schedule.') 318 | parser.add_argument('--sigma', type=float, default=1, 319 | help='Sigma to generate Gaussian groundtruth map.') 320 | parser.add_argument('--label-type', metavar='LABELTYPE', default='Gaussian', 321 | choices=['Gaussian', 'Cauchy'], 322 | help='Labelmap dist type: (default=Gaussian)') 323 | parser.add_argument('--print-freq', '-p', default=10, type=int,  # NOTE(review): args.print_freq is not read by main/train/validate in this script 324 | metavar='N', help='print frequency (default: 10)') 325 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', 326 | help='path to save checkpoint (default: checkpoint)') 327 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 328 | help='path to latest checkpoint (default: none)') 329 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 330 | help='evaluate model on validation set') 331 | parser.add_argument('-d', '--debug', dest='debug', action='store_true', 332 | help='show intermediate results') 333 | parser.add_argument('-f', '--flip', dest='flip', action='store_true', 334 | help='flip the input during validation') 335 | # Model structure 336 | parser.add_argument('-s', '--stacks', default=8, type=int, metavar='N', 337 | help='Number of hourglasses to stack') 338 | parser.add_argument('--features', default=256, type=int,
metavar='N', 339 | help='Number of features in the hourglass') 340 | parser.add_argument('-b', '--blocks', default=1, type=int, metavar='N', 341 | help='Number of residual modules at each location in the hourglass') 342 | 343 | 344 | main(parser.parse_args()) -------------------------------------------------------------------------------- /miscs/cocoScale.m: -------------------------------------------------------------------------------- 1 | function scale = cocoScale(x, y, v) % Person scale = length of the first limb in sk whose two keypoints are both visible (v > 0), divided by that limb's mean distance below; -1 if no such limb 2 | % Mean distance on MPII dataset for each limb, in this order: 3 | % rtorso, ltorso, rlleg, ruleg, luleg, llleg, 4 | % rlarm, ruarm, luarm, llarm, head (head is unused: sk lists only 10 limbs) 5 | meandist = [59.3535, 60.4532, 52.1800, 53.7957, 54.4153, 58.0402, ... 6 | 27.0043, 32.8498, 33.1757, 27.0978, 33.3005]; 7 | 8 | sk = {[13, 7], [6, 12], [17, 15], [15, 13], [12, 14], [14, 16], ... 9 | [11, 9], [9, 7], [6, 8], [8, 10]}; 10 | 11 | scale = -1; % returned when no limb in sk has both endpoints visible 12 | for i=1:length(sk), 13 | s=sk{i}; 14 | if(all(v(s)>0)), 15 | scale = norm([x(s(1))-x(s(2)), y(s(1))-y(s(2))])/meandist(i); 16 | break; 17 | end; 18 | end 19 | -------------------------------------------------------------------------------- /miscs/gen_coco.m: -------------------------------------------------------------------------------- 1 | %% Generate JSON file for MSCOCO keypoint data 2 | clear all; close all; 3 | addpath('jsonlab/') 4 | addpath('../data/mscoco/coco/MatlabAPI/'); 5 | trainval = [1, 0]; % 1 = validation split (val5k2014), 0 = train2014; both end up in one JSON 6 | personCnt = 0; 7 | DEBUG = false; 8 | 9 | for isv = trainval 10 | isValidation = isv; 11 | %% initialize COCO api (please specify dataType/annType below) 12 | annTypes = {'person_keypoints' }; 13 | if isValidation 14 | dataType='val5k2014'; annType=annTypes{1}; % specify dataType/annType 15 | else 16 | dataType='train2014'; annType=annTypes{1}; % specify dataType/annType 17 | end 18 | 19 | 20 | annFile=sprintf('../data/mscoco/keypoint/person_keypoints_train+val5k2014/%s_%s.json',annType,dataType); 21 | coco=CocoApi(annFile); 22 | 23 | %% display COCO categories and supercategories 24 | if(
~strcmp(annType,'captions') ) 25 | cats = coco.loadCats(coco.getCatIds()); 26 | sk = cats.skeleton; % get skeleton 27 | skc = {'m', 'm', 'g', 'g', 'y', 'r', 'b', 'y', ... % 1-8 28 | 'r', 'b', 'r', 'b', 'c', 'c', 'c', 'c', 'c', 'y', 'y'}; 29 | nms={cats.name}; fprintf('COCO categories: '); 30 | fprintf('%s, ',nms{:}); fprintf('\n'); 31 | nms=unique({cats.supercategory}); fprintf('COCO supercategories: '); 32 | fprintf('%s, ',nms{:}); fprintf('\n'); 33 | end 34 | 35 | %% get all images containing given categories, select one at random 36 | catIds = coco.getCatIds('catNms',{'person'}); 37 | imgIds = coco.getImgIds('catIds',catIds); 38 | imgId = imgIds(randi(length(imgIds))); 39 | 40 | imageCnt = 0; 41 | keypointsCnt = zeros(17, 1); 42 | fullBodyCnt = 0; 43 | meanArea = 0; 44 | 45 | for i = 1:length(imgIds) 46 | fprintf('%d | %d\n', i, length(imgIds)); 47 | imgId = imgIds(i); 48 | img = coco.loadImgs(imgId); 49 | 50 | %% load and display annotations 51 | annIds = coco.getAnnIds('imgIds',imgId,'catIds',catIds,'iscrowd',[]); 52 | anns = coco.loadAnns(annIds); 53 | n=length(anns); 54 | hasKeypoints = false; 55 | for j=1:n 56 | a=anns(j); if(a.iscrowd), continue; end; hold on; 57 | if a.num_keypoints > 0 58 | hasKeypoints = true; 59 | 60 | kp=a.keypoints; 61 | x=kp(1:3:end)+1; y=kp(2:3:end)+1; v=kp(3:3:end); 62 | vi = find(v > 0); 63 | keypointsCnt(vi) = keypointsCnt(vi) + 1; 64 | meanArea = meanArea + a.area; 65 | 66 | scale = cocoScale(x, y, v); 67 | % if scale == -1 % connot compute scale 68 | if scale <= 0 % connot compute scale 69 | continue; 70 | end 71 | 72 | assert(scale ~= 0); 73 | personCnt = personCnt + 1; 74 | 75 | % write to json 76 | joint_all(personCnt).dataset = 'coco'; 77 | joint_all(personCnt).isValidation = isValidation; 78 | joint_all(personCnt).isValidation = isValidation; 79 | 80 | joint_all(personCnt).img_paths = img.file_name; 81 | joint_all(personCnt).objpos = [mean(x(v>0)), mean(y(v>0))]; 82 | joint_all(personCnt).joint_self = [x; y; v]'; 83 | 
joint_all(personCnt).scale_provided = scale; 84 | 85 | if DEBUG 86 | if isValidation 87 | datadir ='val2014'; annType=annTypes{1}; % specify dataType/annType 88 | else 89 | datadir='train2014'; annType=annTypes{1}; % specify dataType/annType 90 | end 91 | I = imread(sprintf('../data/mscoco/keypoint/images/%s/%s',datadir,joint_all(personCnt).img_paths)); 92 | I = imresize(I, 1/joint_all(personCnt).scale_provided); 93 | imshow(I); hold on; 94 | x1 = x/scale; 95 | y1 = y/scale; 96 | objpos = joint_all(personCnt).objpos/scale; 97 | show_skeleton(x1, y1, v, sk, skc); 98 | viscircles(objpos,5) 99 | pause;close; 100 | end 101 | end 102 | 103 | if a.num_keypoints == 17 104 | fullBodyCnt = fullBodyCnt + 1; 105 | end 106 | end 107 | 108 | if hasKeypoints 109 | imageCnt = imageCnt + 1; 110 | end 111 | end 112 | end 113 | fprintf('save %d person\n', personCnt); 114 | 115 | opt.FileName = '../data/mscoco/coco_annotations.json'; 116 | opt.FloatFormat = '%.3f'; 117 | opt.Compact = 1; 118 | savejson('', joint_all, opt); 119 | 120 | 121 | % 122 | % clc; 123 | % 124 | % fprintf('validation: images: %d | persons: %d\n', imageCnt, personCnt); 125 | % 126 | % fprintf('%s\n', strjoin(cats.keypoints,', ')) 127 | % for i = 1:length(cats.keypoints) 128 | % fprintf('%d, ', keypointsCnt(i)); 129 | % end 130 | % 131 | % fprintf('\nFull body cnt: %d\n', fullBodyCnt); 132 | % fprintf('mean area: %.4f\n', meanArea/personCnt); -------------------------------------------------------------------------------- /miscs/gen_lsp.m: -------------------------------------------------------------------------------- 1 | % Dataset link 2 | % LSP: http://sam.johnson.io/research/lsp.html 3 | % LSP extend: http://sam.johnson.io/research/lspet.html 4 | function gen_lsp 5 | addpath('jsonlab/') 6 | % in cpp: real scale = param_.target_dist()/meta.scale_self = (41/35)/scale_input 7 | targetDist = 41/35; % in caffe cpp file 41/35 8 | oriTrTe = load('/home/wyang/Data/dataset/LSP/joints.mat'); 9 | extTrain = 
load('/home/wyang/Data/dataset/lspet_dataset/joints.mat'); 10 | 11 | % in LEEDS: 12 | % 1 Right ankle 13 | % 2 Right knee 14 | % 3 Right hip 15 | % 4 Left hip 16 | % 5 Left knee 17 | % 6 Left ankle 18 | % 7 Right wrist 19 | % 8 Right elbow 20 | % 9 Right shoulder 21 | % 10 Left shoulder 22 | % 11 Left elbow 23 | % 12 Left wrist 24 | % 13 Neck 25 | % 14 Head top 26 | % 15,16 DUMMY 27 | % We want to comply to MPII: (1 - r ankle, 2 - r knee, 3 - r hip, 4 - l hip, 5 - l knee, 6 - l ankle, .. 28 | % 7 - pelvis, 8 - thorax, 9 - upper neck, 10 - head top, 29 | % 11 - r wrist, 12 - r elbow, 13 - r shoulder, 14 - l shoulder, 15 - l elbow, 16 - l wrist) 30 | ordering = [1 2 3, 4 5 6, 15 16, 13 14, 7 8 9, 10 11 12]; % should follow MPI 16 parts..? 31 | oriTrTe.joints(:,[15 16],:) = 0; 32 | oriTrTe.joints = oriTrTe.joints(:,ordering,:); 33 | oriTrTe.joints(3,:,:) = 1 - oriTrTe.joints(3,:,:); 34 | oriTrTe.joints = permute(oriTrTe.joints, [2 1 3]); 35 | 36 | % pelvis 37 | oriTrTe.joints(7, 1:2, :) = mean(oriTrTe.joints(3:4,1:2,:)); 38 | v1 = oriTrTe.joints(3,3,:) > 0; 39 | v2 = oriTrTe.joints(4,3,:) > 0; 40 | v = find(v1 .* v2 == 1); 41 | oriTrTe.joints(7, 3, v) = 1; 42 | 43 | % thorax 44 | oriTrTe.joints(8, 1:2, :) = mean(oriTrTe.joints(13:14,1:2,:)); 45 | v1 = oriTrTe.joints(13,3,:) > 0; 46 | v2 = oriTrTe.joints(14,3,:) > 0; 47 | v = find(v1 .* v2 == 1); 48 | oriTrTe.joints(8, 3, v) = 1; 49 | 50 | extTrain.joints([15 16],:,:) = 0; 51 | extTrain.joints = extTrain.joints(ordering,:,:); 52 | 53 | 54 | % pelvis 55 | extTrain.joints(7, 1:2, :) = mean(extTrain.joints(3:4,1:2,:)); 56 | extTrain.joints(7, 3, :) = 1; 57 | 58 | % thorax 59 | extTrain.joints(8, 1:2, :) = mean(extTrain.joints(13:14,1:2,:)); 60 | extTrain.joints(8, 3, :) = 1; 61 | 62 | count = 1; 63 | 64 | path = {'lspet_dataset/images/im%05d.jpg', 'lsp_dataset/images/im%04d.jpg'}; 65 | local_path = {'/home/wyang/Data/dataset/lspet_dataset/images/im%05d.jpg', '/home/wyang/Data/dataset/LSP/images/im%04d.jpg'}; 66 | 
num_image = [10000, 1000]; %[10000, 2000]; 67 | 68 | for dataset = 1:2 69 | for im = 1:num_image(dataset) 70 | % trivial stuff for LEEDS 71 | joint_all(count).dataset = 'LEEDS'; 72 | joint_all(count).isValidation = 0; 73 | joint_all(count).img_paths = sprintf(path{dataset}, im); 74 | joint_all(count).numOtherPeople = 0; 75 | joint_all(count).annolist_index = count; 76 | joint_all(count).people_index = 1; 77 | % joints and w, h 78 | if(dataset == 1) 79 | joint_this = extTrain.joints(:,:,im); 80 | else 81 | joint_this = oriTrTe.joints(:,:,im); 82 | end 83 | path_this = sprintf(local_path{dataset}, im); 84 | [h,w,~] = size(imread(path_this)); 85 | 86 | joint_all(count).img_width = w; 87 | joint_all(count).img_height = h; 88 | joint_all(count).joint_self = joint_this; 89 | % infer objpos 90 | invisible = (joint_all(count).joint_self(:,3) == 0); 91 | if(dataset == 1) %lspet is not tightly cropped 92 | joint_all(count).objpos(1) = (min(joint_all(count).joint_self(~invisible, 1)) + max(joint_all(count).joint_self(~invisible, 1))) / 2; 93 | joint_all(count).objpos(2) = (min(joint_all(count).joint_self(~invisible, 2)) + max(joint_all(count).joint_self(~invisible, 2))) / 2; 94 | else 95 | joint_all(count).objpos(1) = w/2; 96 | joint_all(count).objpos(2) = h/2; 97 | end 98 | 99 | count = count + 1; 100 | fprintf('processing %s\n', path_this); 101 | end 102 | end 103 | 104 | % ---- test data 105 | dataset = 2; 106 | for im = 1001:2000 107 | % trivial stuff for LEEDS 108 | joint_all(count).dataset = 'LEEDS'; 109 | joint_all(count).isValidation = 1; 110 | joint_all(count).img_paths = sprintf(path{dataset}, im); 111 | joint_all(count).numOtherPeople = 0; 112 | joint_all(count).annolist_index = count; 113 | joint_all(count).people_index = 1; 114 | % joints and w, h 115 | if(dataset == 1) 116 | joint_this = extTrain.joints(:,:,im); 117 | else 118 | joint_this = oriTrTe.joints(:,:,im); 119 | end 120 | path_this = sprintf(local_path{dataset}, im); 121 | [h,w,~] = 
size(imread(path_this)); 122 | 123 | joint_all(count).img_width = w; 124 | joint_all(count).img_height = h; 125 | joint_all(count).joint_self = joint_this; 126 | % infer objpos 127 | invisible = (joint_all(count).joint_self(:,3) == 0); 128 | if(dataset == 1) %lspet is not tightly cropped 129 | joint_all(count).objpos(1) = (min(joint_all(count).joint_self(~invisible, 1)) + max(joint_all(count).joint_self(~invisible, 1))) / 2; 130 | joint_all(count).objpos(2) = (min(joint_all(count).joint_self(~invisible, 2)) + max(joint_all(count).joint_self(~invisible, 2))) / 2; 131 | else 132 | joint_all(count).objpos(1) = w/2; 133 | joint_all(count).objpos(2) = h/2; 134 | end 135 | 136 | count = count + 1; 137 | fprintf('processing %s\n', path_this); 138 | end 139 | 140 | 141 | 142 | joint_all = insertMPILikeScale(joint_all, targetDist); 143 | 144 | 145 | opt.FileName = '../data/lsp/LEEDS_annotations.json'; 146 | opt.FloatFormat = '%.3f'; 147 | opt.Compact = 1; 148 | savejson('', joint_all, opt); 149 | 150 | 151 | function joint_all = insertMPILikeScale(joint_all, targetDist) 152 | % calculate scales for each image first 153 | joints = cat(3, joint_all.joint_self); 154 | joints([7 8],:,:) = []; 155 | pa = [2 3 7, 5 4 7, 8 0, 10 11 7, 13 12 7]; 156 | x = permute(joints(:,1,:), [3 1 2]); 157 | y = permute(joints(:,2,:), [3 1 2]); 158 | vis = permute(joints(:,3,:), [3 1 2]); 159 | validLimb = 1:14-1; 160 | 161 | x_diff = x(:, [1:7,9:14]) - x(:, pa([1:7,9:14])); 162 | y_diff = y(:, [1:7,9:14]) - y(:, pa([1:7,9:14])); 163 | limb_vis = vis(:, [1:7,9:14]) .* vis(:, pa([1:7,9:14])); 164 | l = sqrt(x_diff.^2 + y_diff.^2); 165 | 166 | for p = 1:14-1 % for each limb. 
reference: 7th limb, which is 7 to pa(7) (neck to head) 167 | valid_compare = limb_vis(:,7) .* limb_vis(:,p); 168 | ratio = l(valid_compare==1, p) ./ l(valid_compare==1, 7); 169 | r(p) = median(ratio(~isnan(ratio), 1)); 170 | end 171 | 172 | numFiles = size(x_diff, 1); 173 | all_scales = zeros(numFiles, 1); 174 | 175 | boxSize = 368; 176 | psize = 64; 177 | nSqueezed = 0; 178 | 179 | for file = 1:numFiles %numFiles 180 | l_update = l(file, validLimb) ./ r(validLimb); 181 | l_update = l_update(limb_vis(file,:)==1); 182 | distToObserve = quantile(l_update, 0.75); 183 | scale_in_lmdb = distToObserve/35; % can't get too small. 35 is a magic number to balance to MPI 184 | scale_in_cpp = targetDist/scale_in_lmdb; % can't get too large to be cropped 185 | 186 | visibleParts = joints(:, 3, file); 187 | visibleParts = joints(visibleParts==1, 1:2, file); 188 | x_range = max(visibleParts(:,1)) - min(visibleParts(:,1)); 189 | y_range = max(visibleParts(:,2)) - min(visibleParts(:,2)); 190 | scale_x_ub = (boxSize - psize)/x_range; 191 | scale_y_ub = (boxSize - psize)/y_range; 192 | 193 | scale_shrink = min(min(scale_x_ub, scale_y_ub), scale_in_cpp); 194 | 195 | if scale_shrink ~= scale_in_cpp 196 | nSqueezed = nSqueezed + 1; 197 | fprintf('img %d: scale = %f %f %f shrink %d\n', file, scale_in_cpp, scale_shrink, min(scale_x_ub, scale_y_ub), nSqueezed); 198 | else 199 | fprintf('img %d: scale = %f %f %f\n', file, scale_in_cpp, scale_shrink, min(scale_x_ub, scale_y_ub)); 200 | end 201 | 202 | joint_all(file).scale_provided = targetDist/scale_shrink; % back to lmdb unit 203 | end 204 | 205 | fprintf('total %d squeezed!\n', nSqueezed); 206 | -------------------------------------------------------------------------------- /miscs/gen_mpii.m: -------------------------------------------------------------------------------- 1 | % Generate MPII train/validation split (Tompson et al. 
CVPR 2015) 2 | % Code ported from 3 | % https://github.com/shihenw/convolutional-pose-machines-release/blob/master/training/genJSON.m 4 | % 5 | % in MPI: (0 - r ankle, 1 - r knee, 2 - r hip, 3 - l hip, 4 - l knee, 6 | % 5 - l ankle, 6 - pelvis, 7 - thorax, 8 - upper neck, 9 - head top, 7 | % 10 - r wrist, 11 - r elbow, 12 - r shoulder, 13 - l shoulder, 8 | % 14 - l elbow, 15 - l wrist)" 9 | 10 | 11 | addpath('jsonlab/') 12 | 13 | % Download MPII http://human-pose.mpi-inf.mpg.de/#download 14 | MPIIROOT = '/home/wyang/Data/dataset/mpii'; 15 | 16 | % Download Tompson split from 17 | % http://www.cims.nyu.edu/~tompson/data/mpii_valid_pred.zip 18 | TOMPSONROOT = '/home/wyang/Data/dataset/mpii/Tompson_valid'; 19 | 20 | mat = load(fullfile(MPIIROOT, '/mpii_human_pose_v1_u12_1/mpii_human_pose_v1_u12_1.mat')); 21 | RELEASE = mat.RELEASE; 22 | trainIdx = find(RELEASE.img_train); 23 | 24 | tompson = load(fullfile(TOMPSONROOT, '/mpii_predictions/data/detections')); 25 | tompson_i_p = [tompson.RELEASE_img_index; tompson.RELEASE_person_index]; 26 | 27 | count = 1; 28 | validationCount = 0; 29 | trainCount = 0; 30 | 31 | makeFigure = 0; % Set as 1 for visualizing annotations 32 | 33 | for i = trainIdx 34 | numPeople = length(RELEASE.annolist(i).annorect); 35 | fprintf('image: %d (numPeople: %d) last: %d\n', i, numPeople, trainIdx(end)); 36 | 37 | for p = 1:numPeople 38 | loc = find(sum(~bsxfun(@minus, tompson_i_p, [i;p]))==2, 1); 39 | loc2 = find(tompson.RELEASE_img_index == i); 40 | if(~isempty(loc)) 41 | validationCount = validationCount + 1; 42 | isValidation = 1; 43 | elseif (isempty(loc2)) 44 | trainCount = trainCount + 1; 45 | isValidation = 0; 46 | else 47 | continue; 48 | end 49 | joint_all(count).dataset = 'MPI'; 50 | joint_all(count).isValidation = isValidation; 51 | 52 | try % sometimes no annotation at all.... 
53 | anno = RELEASE.annolist(i).annorect(p).annopoints.point; 54 | catch 55 | continue; 56 | end 57 | 58 | % set image path 59 | joint_all(count).img_paths = RELEASE.annolist(i).image.name; 60 | [h,w,~] = size(imread(fullfile(MPIIROOT, '/images/', joint_all(count).img_paths))); 61 | joint_all(count).img_width = w; 62 | joint_all(count).img_height = h; 63 | joint_all(count).objpos = [RELEASE.annolist(i).annorect(p).objpos.x, RELEASE.annolist(i).annorect(p).objpos.y]; 64 | % set part label: joint_all is (np-3-nTrain) 65 | 66 | 67 | % for this very center person 68 | for part = 1:length(anno) 69 | joint_all(count).joint_self(anno(part).id+1, 1) = anno(part).x; 70 | joint_all(count).joint_self(anno(part).id+1, 2) = anno(part).y; 71 | try % sometimes no is_visible... 72 | if(anno(part).is_visible == 0 || anno(part).is_visible == '0') 73 | joint_all(count).joint_self(anno(part).id+1, 3) = 0; 74 | else 75 | joint_all(count).joint_self(anno(part).id+1, 3) = 1; 76 | end 77 | catch 78 | joint_all(count).joint_self(anno(part).id+1, 3) = 1; 79 | end 80 | end 81 | 82 | % pad it into 16x3 83 | dim_1 = size(joint_all(count).joint_self, 1); 84 | dim_3 = size(joint_all(count).joint_self, 3); 85 | pad_dim = 16 - dim_1; 86 | joint_all(count).joint_self = [joint_all(count).joint_self; zeros(pad_dim, 3, dim_3)]; 87 | 88 | % set scale 89 | joint_all(count).scale_provided = RELEASE.annolist(i).annorect(p).scale; 90 | 91 | % for other person on the same image 92 | count_other = 1; 93 | joint_others = cell(0,0); 94 | for op = 1:numPeople 95 | if(op == p), continue; end 96 | try % sometimes no annotation at all.... 
97 | anno = RELEASE.annolist(i).annorect(op).annopoints.point; 98 | catch 99 | continue; 100 | end 101 | joint_others{count_other} = zeros(16,3); 102 | for part = 1:length(anno) 103 | joint_all(count).joint_others{count_other}(anno(part).id+1, 1) = anno(part).x; 104 | joint_all(count).joint_others{count_other}(anno(part).id+1, 2) = anno(part).y; 105 | try % sometimes no is_visible... 106 | if(anno(part).is_visible == 0 || anno(part).is_visible == '0') 107 | joint_all(count).joint_others{count_other}(anno(part).id+1, 3) = 0; 108 | else 109 | joint_all(count).joint_others{count_other}(anno(part).id+1, 3) = 1; 110 | end 111 | catch 112 | joint_all(count).joint_others{count_other}(anno(part).id+1, 3) = 1; 113 | end 114 | % pad it into 16x3 115 | dim_1 = size(joint_all(count).joint_others{count_other}, 1); 116 | dim_3 = size(joint_all(count).joint_others{count_other}, 3); 117 | pad_dim = 16 - dim_1; 118 | joint_all(count).joint_others{count_other} = [joint_all(count).joint_others{count_other}; zeros(pad_dim, 3, dim_3)]; 119 | end 120 | 121 | joint_all(count).scale_provided_other(count_other) = RELEASE.annolist(i).annorect(op).scale; 122 | joint_all(count).objpos_other{count_other} = [RELEASE.annolist(i).annorect(op).objpos.x RELEASE.annolist(i).annorect(op).objpos.y]; 123 | 124 | count_other = count_other + 1; 125 | end 126 | 127 | if(makeFigure) % visualizing to debug 128 | imshow(imread(fullfile(MPIIROOT, '/images/', joint_all(count).img_paths))); 129 | hold on; 130 | visiblePart = joint_all(count).joint_self(:,3) == 1; 131 | invisiblePart = joint_all(count).joint_self(:,3) == 0; 132 | plot(joint_all(count).joint_self(visiblePart, 1), joint_all(count).joint_self(visiblePart,2), 'gx', 'MarkerSize', 10); 133 | plot(joint_all(count).joint_self(invisiblePart,1), joint_all(count).joint_self(invisiblePart,2), 'rx', 'MarkerSize', 10); 134 | plot(joint_all(count).objpos(1), joint_all(count).objpos(2), 'cs'); 135 | if(~isempty(joint_all(count).joint_others)) 136 | for op = 
1:size(joint_all(count).joint_others, 3) 137 | visiblePart = joint_all(count).joint_others{op}(:,3) == 1; 138 | invisiblePart = joint_all(count).joint_others{op}(:,3) == 0; 139 | plot(joint_all(count).joint_others{op}(visiblePart,1), joint_all(count).joint_others{op}(visiblePart,2), 'mx', 'MarkerSize', 10); 140 | plot(joint_all(count).joint_others{op}(invisiblePart,1), joint_all(count).joint_others{op}(invisiblePart,2), 'cx', 'MarkerSize', 10); 141 | end 142 | end 143 | pause; 144 | close all; 145 | end 146 | joint_all(count).annolist_index = i; 147 | joint_all(count).people_index = p; 148 | joint_all(count).numOtherPeople = length(joint_all(count).joint_others); 149 | count = count + 1; 150 | end 151 | end 152 | 153 | opt.FileName = '../data/mpii/mpii_annotations.json'; 154 | opt.FloatFormat = '%.3f'; 155 | opt.Compact = 1; 156 | savejson('', joint_all, opt); -------------------------------------------------------------------------------- /pose/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import models 5 | from . 
import utils 6 | 7 | import os, sys 8 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 9 | from progress.bar import Bar as Bar 10 | 11 | __version__ = '0.1.0' -------------------------------------------------------------------------------- /pose/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .mpii import Mpii 2 | from .mscoco import Mscoco 3 | from .lsp import LSP 4 | 5 | __all__ = ('Mpii', 'Mscoco', 'LSP') -------------------------------------------------------------------------------- /pose/datasets/lsp.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import os 4 | import numpy as np 5 | import json 6 | import random 7 | import math 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | from pose.utils.osutils import * 13 | from pose.utils.imutils import * 14 | from pose.utils.transforms import * 15 | 16 | 17 | class LSP(data.Dataset): 18 | """ 19 | LSP extended dataset (11,000 train, 1000 test) 20 | Original datasets contain 14 keypoints. We interpolate mid-hip and mid-shoulder and change the indices to match 21 | the MPII dataset (16 keypoints). 
22 | 23 | Wei Yang (bearpaw@GitHub) 24 | 2017-09-28 25 | """ 26 | def __init__(self, jsonfile, img_folder, inp_res=256, out_res=64, train=True, sigma=1, 27 | scale_factor=0.25, rot_factor=30, label_type='Gaussian'): 28 | self.img_folder = img_folder # root image folders 29 | self.is_train = train # training set or test set 30 | self.inp_res = inp_res 31 | self.out_res = out_res 32 | self.sigma = sigma 33 | self.scale_factor = scale_factor 34 | self.rot_factor = rot_factor 35 | self.label_type = label_type 36 | 37 | # create train/val split 38 | with open(jsonfile) as anno_file: 39 | self.anno = json.load(anno_file) 40 | 41 | self.train, self.valid = [], [] 42 | for idx, val in enumerate(self.anno): 43 | if val['isValidation'] == True: 44 | self.valid.append(idx) 45 | else: 46 | self.train.append(idx) 47 | self.mean, self.std = self._compute_mean() 48 | 49 | def _compute_mean(self): 50 | meanstd_file = './data/lsp/mean.pth.tar' 51 | if isfile(meanstd_file): 52 | meanstd = torch.load(meanstd_file) 53 | else: 54 | mean = torch.zeros(3) 55 | std = torch.zeros(3) 56 | for index in self.train: 57 | a = self.anno[index] 58 | img_path = os.path.join(self.img_folder, a['img_paths']) 59 | img = load_image(img_path) # CxHxW 60 | mean += img.view(img.size(0), -1).mean(1) 61 | std += img.view(img.size(0), -1).std(1) 62 | mean /= len(self.train) 63 | std /= len(self.train) 64 | meanstd = { 65 | 'mean': mean, 66 | 'std': std, 67 | } 68 | torch.save(meanstd, meanstd_file) 69 | if self.is_train: 70 | print(' Mean: %.4f, %.4f, %.4f' % (meanstd['mean'][0], meanstd['mean'][1], meanstd['mean'][2])) 71 | print(' Std: %.4f, %.4f, %.4f' % (meanstd['std'][0], meanstd['std'][1], meanstd['std'][2])) 72 | 73 | return meanstd['mean'], meanstd['std'] 74 | 75 | def __getitem__(self, index): 76 | sf = self.scale_factor 77 | rf = self.rot_factor 78 | if self.is_train: 79 | a = self.anno[self.train[index]] 80 | else: 81 | a = self.anno[self.valid[index]] 82 | 83 | img_path = 
os.path.join(self.img_folder, a['img_paths']) 84 | pts = torch.Tensor(a['joint_self']) 85 | # pts[:, 0:2] -= 1 # Convert pts to zero based 86 | 87 | # c = torch.Tensor(a['objpos']) - 1 88 | c = torch.Tensor(a['objpos']) 89 | s = a['scale_provided'] 90 | 91 | # Adjust center/scale slightly to avoid cropping limbs 92 | if c[0] != -1: 93 | # c[1] = c[1] + 15 * s 94 | s = s * 1.4375 95 | 96 | # For single-person pose estimation with a centered/scaled figure 97 | nparts = pts.size(0) 98 | img = load_image(img_path) # CxHxW 99 | 100 | r = 0 101 | # if self.is_train: 102 | # s = s*torch.randn(1).mul_(sf).add_(1).clamp(1-sf, 1+sf)[0] 103 | # r = torch.randn(1).mul_(rf).clamp(-2*rf, 2*rf)[0] if random.random() <= 0.6 else 0 104 | # 105 | # # # Flip 106 | # # if random.random() <= 0.5: 107 | # # img = torch.from_numpy(fliplr(img.numpy())).float() 108 | # # pts = shufflelr(pts, width=img.size(2), dataset='mpii') 109 | # # c[0] = img.size(2) - c[0] 110 | # 111 | # # Color 112 | # img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) 113 | # img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) 114 | # img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) 115 | 116 | # Prepare image and groundtruth map 117 | inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r) 118 | inp = color_normalize(inp, self.mean, self.std) 119 | 120 | # Generate ground truth 121 | tpts = pts.clone() 122 | target = torch.zeros(nparts, self.out_res, self.out_res) 123 | for i in range(nparts): 124 | # if tpts[i, 2] > 0: # This is evil!! 
125 | if tpts[i, 0] > 0: 126 | tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2]+1, c, s, [self.out_res, self.out_res], rot=r)) 127 | target[i] = draw_labelmap(target[i], tpts[i]-1, self.sigma, type=self.label_type) 128 | 129 | # Meta info 130 | meta = {'index' : index, 'center' : c, 'scale' : s, 131 | 'pts' : pts, 'tpts' : tpts} 132 | 133 | return inp, target, meta 134 | 135 | def __len__(self): 136 | if self.is_train: 137 | return len(self.train) 138 | else: 139 | return len(self.valid) -------------------------------------------------------------------------------- /pose/datasets/mpii.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import os 4 | import numpy as np 5 | import json 6 | import random 7 | import math 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | from pose.utils.osutils import * 13 | from pose.utils.imutils import * 14 | from pose.utils.transforms import * 15 | 16 | 17 | class Mpii(data.Dataset): 18 | def __init__(self, jsonfile, img_folder, inp_res=256, out_res=64, train=True, sigma=1, 19 | scale_factor=0.25, rot_factor=30, label_type='Gaussian'): 20 | self.img_folder = img_folder # root image folders 21 | self.is_train = train # training set or test set 22 | self.inp_res = inp_res 23 | self.out_res = out_res 24 | self.sigma = sigma 25 | self.scale_factor = scale_factor 26 | self.rot_factor = rot_factor 27 | self.label_type = label_type 28 | 29 | # create train/val split 30 | with open(jsonfile) as anno_file: 31 | self.anno = json.load(anno_file) 32 | 33 | self.train, self.valid = [], [] 34 | for idx, val in enumerate(self.anno): 35 | if val['isValidation'] == True: 36 | self.valid.append(idx) 37 | else: 38 | self.train.append(idx) 39 | self.mean, self.std = self._compute_mean() 40 | 41 | def _compute_mean(self): 42 | meanstd_file = './data/mpii/mean.pth.tar' 43 | if isfile(meanstd_file): 44 | meanstd = torch.load(meanstd_file) 
45 | else: 46 | mean = torch.zeros(3) 47 | std = torch.zeros(3) 48 | for index in self.train: 49 | a = self.anno[index] 50 | img_path = os.path.join(self.img_folder, a['img_paths']) 51 | img = load_image(img_path) # CxHxW 52 | mean += img.view(img.size(0), -1).mean(1) 53 | std += img.view(img.size(0), -1).std(1) 54 | mean /= len(self.train) 55 | std /= len(self.train) 56 | meanstd = { 57 | 'mean': mean, 58 | 'std': std, 59 | } 60 | torch.save(meanstd, meanstd_file) 61 | if self.is_train: 62 | print(' Mean: %.4f, %.4f, %.4f' % (meanstd['mean'][0], meanstd['mean'][1], meanstd['mean'][2])) 63 | print(' Std: %.4f, %.4f, %.4f' % (meanstd['std'][0], meanstd['std'][1], meanstd['std'][2])) 64 | 65 | return meanstd['mean'], meanstd['std'] 66 | 67 | def __getitem__(self, index): 68 | sf = self.scale_factor 69 | rf = self.rot_factor 70 | if self.is_train: 71 | a = self.anno[self.train[index]] 72 | else: 73 | a = self.anno[self.valid[index]] 74 | 75 | img_path = os.path.join(self.img_folder, a['img_paths']) 76 | pts = torch.Tensor(a['joint_self']) 77 | # pts[:, 0:2] -= 1 # Convert pts to zero based 78 | 79 | # c = torch.Tensor(a['objpos']) - 1 80 | c = torch.Tensor(a['objpos']) 81 | s = a['scale_provided'] 82 | 83 | # Adjust center/scale slightly to avoid cropping limbs 84 | if c[0] != -1: 85 | c[1] = c[1] + 15 * s 86 | s = s * 1.25 87 | 88 | # For single-person pose estimation with a centered/scaled figure 89 | nparts = pts.size(0) 90 | img = load_image(img_path) # CxHxW 91 | 92 | r = 0 93 | if self.is_train: 94 | s = s*torch.randn(1).mul_(sf).add_(1).clamp(1-sf, 1+sf)[0] 95 | r = torch.randn(1).mul_(rf).clamp(-2*rf, 2*rf)[0] if random.random() <= 0.6 else 0 96 | 97 | # Flip 98 | if random.random() <= 0.5: 99 | img = torch.from_numpy(fliplr(img.numpy())).float() 100 | pts = shufflelr(pts, width=img.size(2), dataset='mpii') 101 | c[0] = img.size(2) - c[0] 102 | 103 | # Color 104 | img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) 105 | img[1, :, 
:].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) 106 | img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) 107 | 108 | # Prepare image and groundtruth map 109 | inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r) 110 | inp = color_normalize(inp, self.mean, self.std) 111 | 112 | # Generate ground truth 113 | tpts = pts.clone() 114 | target = torch.zeros(nparts, self.out_res, self.out_res) 115 | for i in range(nparts): 116 | # if tpts[i, 2] > 0: # This is evil!! 117 | if tpts[i, 1] > 0: 118 | tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2]+1, c, s, [self.out_res, self.out_res], rot=r)) 119 | target[i] = draw_labelmap(target[i], tpts[i]-1, self.sigma, type=self.label_type) 120 | 121 | # Meta info 122 | meta = {'index' : index, 'center' : c, 'scale' : s, 123 | 'pts' : pts, 'tpts' : tpts} 124 | 125 | return inp, target, meta 126 | 127 | def __len__(self): 128 | if self.is_train: 129 | return len(self.train) 130 | else: 131 | return len(self.valid) 132 | -------------------------------------------------------------------------------- /pose/datasets/mscoco.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import os 4 | import numpy as np 5 | import json 6 | import random 7 | import math 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | from pose.utils.osutils import * 13 | from pose.utils.imutils import * 14 | from pose.utils.transforms import * 15 | 16 | 17 | class Mscoco(data.Dataset): 18 | def __init__(self, jsonfile, img_folder, inp_res=256, out_res=64, train=True, sigma=1, 19 | scale_factor=0.25, rot_factor=30, label_type='Gaussian'): 20 | self.img_folder = img_folder # root image folders 21 | self.is_train = train # training set or test set 22 | self.inp_res = inp_res 23 | self.out_res = out_res 24 | self.sigma = sigma 25 | self.scale_factor = scale_factor 26 | self.rot_factor = rot_factor 27 | self.label_type = label_type 28 | 29 | # 
create train/val split 30 | with open(jsonfile) as anno_file: 31 | self.anno = json.load(anno_file) 32 | 33 | self.train, self.valid = [], [] 34 | for idx, val in enumerate(self.anno): 35 | if val['isValidation'] == True: 36 | self.valid.append(idx) 37 | else: 38 | self.train.append(idx) 39 | self.mean, self.std = self._compute_mean() 40 | 41 | def _compute_mean(self): 42 | meanstd_file = './data/mscoco/mean.pth.tar' 43 | if isfile(meanstd_file): 44 | meanstd = torch.load(meanstd_file) 45 | else: 46 | print('==> compute mean') 47 | mean = torch.zeros(3) 48 | std = torch.zeros(3) 49 | cnt = 0 50 | for index in self.train: 51 | cnt += 1 52 | print( '{} | {}'.format(cnt, len(self.train))) 53 | a = self.anno[index] 54 | img_path = os.path.join(self.img_folder, a['img_paths']) 55 | img = load_image(img_path) # CxHxW 56 | mean += img.view(img.size(0), -1).mean(1) 57 | std += img.view(img.size(0), -1).std(1) 58 | mean /= len(self.train) 59 | std /= len(self.train) 60 | meanstd = { 61 | 'mean': mean, 62 | 'std': std, 63 | } 64 | torch.save(meanstd, meanstd_file) 65 | if self.is_train: 66 | print(' Mean: %.4f, %.4f, %.4f' % (meanstd['mean'][0], meanstd['mean'][1], meanstd['mean'][2])) 67 | print(' Std: %.4f, %.4f, %.4f' % (meanstd['std'][0], meanstd['std'][1], meanstd['std'][2])) 68 | 69 | return meanstd['mean'], meanstd['std'] 70 | 71 | def __getitem__(self, index): 72 | sf = self.scale_factor 73 | rf = self.rot_factor 74 | if self.is_train: 75 | a = self.anno[self.train[index]] 76 | else: 77 | a = self.anno[self.valid[index]] 78 | 79 | img_path = os.path.join(self.img_folder, a['img_paths']) 80 | pts = torch.Tensor(a['joint_self']) 81 | # pts[:, 0:2] -= 1 # Convert pts to zero based 82 | 83 | # c = torch.Tensor(a['objpos']) - 1 84 | c = torch.Tensor(a['objpos']) 85 | s = a['scale_provided'] 86 | 87 | # Adjust center/scale slightly to avoid cropping limbs 88 | if c[0] != -1: 89 | c[1] = c[1] + 15 * s 90 | s = s * 1.25 91 | 92 | # For single-person pose estimation with a 
centered/scaled figure 93 | nparts = pts.size(0) 94 | img = load_image(img_path) # CxHxW 95 | 96 | r = 0 97 | if self.is_train: 98 | s = s*torch.randn(1).mul_(sf).add_(1).clamp(1-sf, 1+sf)[0] 99 | r = torch.randn(1).mul_(rf).clamp(-2*rf, 2*rf)[0] if random.random() <= 0.6 else 0 100 | 101 | # Flip 102 | if random.random() <= 0.5: 103 | img = torch.from_numpy(fliplr(img.numpy())).float() 104 | pts = shufflelr(pts, width=img.size(2), dataset='mpii') 105 | c[0] = img.size(2) - c[0] 106 | 107 | # Color 108 | img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) 109 | img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) 110 | img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1) 111 | 112 | # Prepare image and groundtruth map 113 | inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r) 114 | inp = color_normalize(inp, self.mean, self.std) 115 | 116 | # Generate ground truth 117 | tpts = pts.clone() 118 | target = torch.zeros(nparts, self.out_res, self.out_res) 119 | for i in range(nparts): 120 | if tpts[i, 2] > 0: # COCO visible: 0-no label, 1-label + invisible, 2-label + visible 121 | tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2]+1, c, s, [self.out_res, self.out_res], rot=r)) 122 | target[i] = draw_labelmap(target[i], tpts[i]-1, self.sigma, type=self.label_type) 123 | 124 | # Meta info 125 | meta = {'index' : index, 'center' : c, 'scale' : s, 126 | 'pts' : pts, 'tpts' : tpts} 127 | 128 | return inp, target, meta 129 | 130 | def __len__(self): 131 | if self.is_train: 132 | return len(self.train) 133 | else: 134 | return len(self.valid) -------------------------------------------------------------------------------- /pose/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .hourglass import * 2 | from .preresnet import * -------------------------------------------------------------------------------- /pose/models/hourglass.py: 
# ------------------------------------------------------------------------------
'''
Hourglass network inserted in the pre-activated Resnet
Use lr=0.01 for current version
(c) YANG, Wei
'''
import torch.nn as nn
import torch.nn.functional as F

# from .preresnet import BasicBlock, Bottleneck


__all__ = ['HourglassNet', 'hg']


class Bottleneck(nn.Module):
    """Pre-activation bottleneck residual unit: (BN-ReLU-Conv) x 3 plus skip.

    Args:
        inplanes: number of input channels.
        planes: bottleneck width; the block outputs ``planes * expansion``
            channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input on the skip path
            (required whenever input and output channel counts differ).
        mobile: if True, the 3x3 convolution is depthwise (``groups=planes``).
    """
    expansion = 2

    def __init__(self, inplanes, planes, stride=1, downsample=None, mobile=False):
        super(Bottleneck, self).__init__()

        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True)
        self.bn2 = nn.BatchNorm2d(planes)

        # groups=1 is the nn.Conv2d default, i.e. a dense 3x3 convolution;
        # groups=planes makes it depthwise for the "mobile" variant.
        groups = planes if mobile else 1
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=True, groups=groups)
        self.bn3 = nn.BatchNorm2d(planes)
        # Tie the output width to ``expansion`` so that callers sizing layers
        # with ``planes * block.expansion`` always agree with this block.
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        return out


class Hourglass(nn.Module):
    """One recursive hourglass module of the given ``depth``.

    Level ``i`` of ``self.hg`` holds residual stacks ``[skip, down, up]``.
    The innermost level (index 0) holds a fourth stack that is applied at the
    lowest resolution. The skip branch and the upsampled low branch are
    summed, so H and W presumably must be divisible by ``2 ** depth`` --
    TODO confirm.
    """
    def __init__(self, block, num_blocks, planes, depth, mobile):
        super(Hourglass, self).__init__()
        self.mobile = mobile
        self.depth = depth
        self.block = block
        self.upsample = nn.Upsample(scale_factor=2)
        self.hg = self._make_hour_glass(block, num_blocks, planes, depth)

    def _make_residual(self, block, num_blocks, planes):
        # Channel count is preserved: every block maps planes*expansion -> planes*expansion.
        layers = []
        for i in range(0, num_blocks):
            layers.append(block(planes*block.expansion, planes, mobile=self.mobile))
        return nn.Sequential(*layers)

    def _make_hour_glass(self, block, num_blocks, planes, depth):
        hg = []
        for i in range(depth):
            res = []
            for j in range(3):
                res.append(self._make_residual(block, num_blocks, planes))
            if i == 0:
                # Extra stack used at the bottom of the recursion.
                res.append(self._make_residual(block, num_blocks, planes))
            hg.append(nn.ModuleList(res))
        return nn.ModuleList(hg)

    def _hour_glass_forward(self, n, x):
        up1 = self.hg[n-1][0](x)
        low1 = F.max_pool2d(x, 2, stride=2)
        low1 = self.hg[n-1][1](low1)

        if n > 1:
            low2 = self._hour_glass_forward(n-1, low1)
        else:
            low2 = self.hg[n-1][3](low1)
        low3 = self.hg[n-1][2](low2)
        up2 = self.upsample(low3)
        out = up1 + up2
        return out

    def forward(self, x):
        return self._hour_glass_forward(self.depth, x)


class HourglassNet(nn.Module):
    '''Hourglass model from Newell et al ECCV 2016

    Args:
        block: residual block class (must expose ``expansion``).
        num_stacks: number of stacked hourglass modules.
        num_blocks: residual blocks per hourglass stage.
        num_classes: number of output heatmaps per stack.
        mobile: forwarded to every block (depthwise 3x3 convs).

    ``forward`` returns a list with one score tensor per stack (intermediate
    supervision); the last element is the final prediction.
    '''
    def __init__(self, block, num_stacks=2, num_blocks=4, num_classes=16, mobile=False):
        super(HourglassNet, self).__init__()

        self.mobile = mobile
        self.inplanes = 64
        self.num_feats = 128
        self.num_stacks = num_stacks
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=True)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_residual(block, self.inplanes, 1)
        self.layer2 = self._make_residual(block, self.inplanes, 1)
        self.layer3 = self._make_residual(block, self.num_feats, 1)
        self.maxpool = nn.MaxPool2d(2, stride=2)

        # build hourglass modules
        ch = self.num_feats*block.expansion
        hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
        for i in range(num_stacks):
            hg.append(Hourglass(block, num_blocks, self.num_feats, 4, self.mobile))
            res.append(self._make_residual(block, self.num_feats, num_blocks))
            fc.append(self._make_fc(ch, ch))
            score.append(nn.Conv2d(ch, num_classes, kernel_size=1, bias=True))
            if i < num_stacks-1:
                # Re-injection layers that feed features/scores into the next stack.
                fc_.append(nn.Conv2d(ch, ch, kernel_size=1, bias=True))
                score_.append(nn.Conv2d(num_classes, ch, kernel_size=1, bias=True))
        self.hg = nn.ModuleList(hg)
        self.res = nn.ModuleList(res)
        self.fc = nn.ModuleList(fc)
        self.score = nn.ModuleList(score)
        self.fc_ = nn.ModuleList(fc_)
        self.score_ = nn.ModuleList(score_)

    def _make_residual(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=True),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.mobile))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, mobile=self.mobile))

        return nn.Sequential(*layers)

    def _make_fc(self, inplanes, outplanes):
        # Conv -> BN -> ReLU. The BN normalizes the conv's *output*, so it must
        # be sized by ``outplanes`` (the old code used ``inplanes``, which only
        # worked because every caller passes equal values).
        bn = nn.BatchNorm2d(outplanes)
        conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=True)
        return nn.Sequential(
                conv,
                bn,
                self.relu,
            )

    def forward(self, x):
        out = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.maxpool(x)
        x = self.layer2(x)
        x = self.layer3(x)

        for i in range(self.num_stacks):
            y = self.hg[i](x)
            y = self.res[i](y)
            y = self.fc[i](y)
            score = self.score[i](y)
            out.append(score)
            if i < self.num_stacks-1:
                fc_ = self.fc_[i](y)
                score_ = self.score_[i](score)
                x = x + fc_ + score_

        return out

| 189 | def hg(**kwargs): 190 | model = HourglassNet(Bottleneck, num_stacks=kwargs['num_stacks'], num_blocks=kwargs['num_blocks'], 191 | num_classes=kwargs['num_classes'], mobile=kwargs['mobile']) 192 | return model 193 | -------------------------------------------------------------------------------- /pose/models/preresnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activated Resnet for cifar dataset. 2 | Ported form https://github.com/facebook/fb.resnet.torch/blob/master/models/preresnet.lua 3 | (c) YANG, Wei 4 | ''' 5 | import torch.nn as nn 6 | import math 7 | import torch.utils.model_zoo as model_zoo 8 | 9 | 10 | __all__ = ['PreResNet', 'preresnet20', 'preresnet32', 'preresnet44', 'preresnet56', 11 | 'preresnet110', 'preresnet1202'] 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | "3x3 convolution with padding" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | self.bn1 = nn.BatchNorm2d(inplanes) 25 | self.conv1 = conv3x3(inplanes, planes, stride) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.conv2 = conv3x3(planes, planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.bn1(x) 36 | out = self.relu(out) 37 | out = self.conv1(out) 38 | 39 | out = self.bn2(out) 40 | out = self.relu(out) 41 | out = self.conv2(out) 42 | 43 | if self.downsample is not None: 44 | residual = self.downsample(x) 45 | 46 | out += residual 47 | 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, inplanes, planes, stride=1, downsample=None): 55 | super(Bottleneck, self).__init__() 56 | self.bn1 = 
nn.BatchNorm2d(inplanes) 57 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 58 | self.bn2 = nn.BatchNorm2d(planes) 59 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 60 | padding=1, bias=False) 61 | self.bn3 = nn.BatchNorm2d(planes) 62 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x): 68 | residual = x 69 | 70 | out = self.bn1(x) 71 | out = self.relu(out) 72 | out = self.conv1(out) 73 | 74 | out = self.bn2(out) 75 | out = self.relu(out) 76 | out = self.conv2(out) 77 | 78 | out = self.bn3(out) 79 | out = self.relu(out) 80 | out = self.conv3(out) 81 | 82 | if self.downsample is not None: 83 | residual = self.downsample(x) 84 | 85 | out += residual 86 | 87 | return out 88 | 89 | 90 | class PreResNet(nn.Module): 91 | 92 | def __init__(self, block, layers, num_classes=1000): 93 | self.inplanes = 16 94 | super(PreResNet, self).__init__() 95 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 96 | bias=False) 97 | self.layer1 = self._make_layer(block, 16, layers[0]) 98 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 99 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 100 | self.bn1 = nn.BatchNorm2d(64*block.expansion) 101 | self.relu = nn.ReLU(inplace=True) 102 | self.fc1 = nn.Conv2d(64*block.expansion, 64*block.expansion, kernel_size=1, bias=False) 103 | self.bn2 = nn.BatchNorm2d(64*block.expansion) 104 | self.fc2 = nn.Conv2d(64*block.expansion, num_classes, kernel_size=1) 105 | # self.avgpool = nn.AvgPool2d(8) 106 | # self.fc = nn.Linear(64*block.expansion, num_classes) 107 | 108 | # for m in self.modules(): 109 | # if isinstance(m, nn.Conv2d): 110 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 111 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 112 | # elif isinstance(m, nn.BatchNorm2d): 113 | # m.weight.data.fill_(1) 114 | # m.bias.data.zero_() 115 | 116 | def _make_layer(self, block, planes, blocks, stride=1): 117 | downsample = None 118 | if stride != 1 or self.inplanes != planes * block.expansion: 119 | downsample = nn.Sequential( 120 | nn.Conv2d(self.inplanes, planes * block.expansion, 121 | kernel_size=1, stride=stride, bias=False), 122 | # nn.BatchNorm2d(planes * block.expansion), 123 | ) 124 | 125 | layers = [] 126 | layers.append(block(self.inplanes, planes, stride, downsample)) 127 | self.inplanes = planes * block.expansion 128 | for i in range(1, blocks): 129 | layers.append(block(self.inplanes, planes)) 130 | 131 | return nn.Sequential(*layers) 132 | 133 | def forward(self, x): 134 | x = self.conv1(x) 135 | 136 | x = self.layer1(x) 137 | x = self.layer2(x) 138 | x = self.layer3(x) 139 | x = self.fc1(self.relu(self.bn1(x))) 140 | x = self.fc2(self.relu(self.bn2(x))) 141 | # x = self.sigmoid(x) 142 | # x = self.avgpool(x) 143 | # x = x.view(x.size(0), -1) 144 | 145 | return [x] 146 | 147 | 148 | def preresnet20(**kwargs): 149 | """Constructs a PreResNet-20 model. 150 | """ 151 | model = PreResNet(BasicBlock, [3, 3, 3], **kwargs) 152 | return model 153 | 154 | 155 | def preresnet32(**kwargs): 156 | """Constructs a PreResNet-32 model. 157 | """ 158 | model = PreResNet(BasicBlock, [5, 5, 5], **kwargs) 159 | return model 160 | 161 | 162 | def preresnet44(**kwargs): 163 | """Constructs a PreResNet-44 model. 164 | """ 165 | model = PreResNet(Bottleneck, [7, 7, 7], **kwargs) 166 | return model 167 | 168 | 169 | def preresnet56(**kwargs): 170 | """Constructs a PreResNet-56 model. 171 | """ 172 | model = PreResNet(Bottleneck, [9, 9, 9], **kwargs) 173 | return model 174 | 175 | 176 | def preresnet110(**kwargs): 177 | """Constructs a PreResNet-110 model. 
178 | """ 179 | model = PreResNet(Bottleneck, [18, 18, 18], **kwargs) 180 | return model 181 | 182 | def preresnet1202(**kwargs): 183 | """Constructs a PreResNet-1202 model. 184 | """ 185 | model = PreResNet(Bottleneck, [200, 200, 200], **kwargs) 186 | return model -------------------------------------------------------------------------------- /pose/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .evaluation import * 4 | from .imutils import * 5 | from .logger import * 6 | from .misc import * 7 | from .osutils import * 8 | from .transforms import * 9 | -------------------------------------------------------------------------------- /pose/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import math 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from random import randint 7 | 8 | from .misc import * 9 | from .transforms import transform, transform_preds 10 | 11 | __all__ = ['accuracy', 'AverageMeter'] 12 | 13 | def get_preds(scores): 14 | ''' get predictions from score maps in torch Tensor 15 | return type: torch.LongTensor 16 | ''' 17 | assert scores.dim() == 4, 'Score maps should be 4-dim' 18 | maxval, idx = torch.max(scores.view(scores.size(0), scores.size(1), -1), 2) 19 | 20 | maxval = maxval.view(scores.size(0), scores.size(1), 1) 21 | idx = idx.view(scores.size(0), scores.size(1), 1) + 1 22 | 23 | preds = idx.repeat(1, 1, 2).float() 24 | 25 | preds[:,:,0] = (preds[:,:,0] - 1) % scores.size(3) + 1 26 | preds[:,:,1] = torch.floor((preds[:,:,1] - 1) / scores.size(3)) + 1 27 | 28 | pred_mask = maxval.gt(0).repeat(1, 1, 2).float() 29 | preds *= pred_mask 30 | return preds 31 | 32 | def calc_dists(preds, target, normalize): 33 | preds = preds.float() 34 | target = target.float() 35 | dists = torch.zeros(preds.size(1), preds.size(0)) 36 | for 
n in range(preds.size(0)): 37 | for c in range(preds.size(1)): 38 | if target[n,c,0] > 1 and target[n, c, 1] > 1: 39 | dists[c, n] = torch.dist(preds[n,c,:], target[n,c,:])/normalize[n] 40 | else: 41 | dists[c, n] = -1 42 | return dists 43 | 44 | def dist_acc(dists, thr=0.5): 45 | ''' Return percentage below threshold while ignoring values with a -1 ''' 46 | if dists.ne(-1).sum() > 0: 47 | return float(dists.le(thr).eq(dists.ne(-1)).sum()) / float(dists.ne(-1).sum()) 48 | else: 49 | return -1 50 | 51 | def accuracy(output, target, idxs, thr=0.5): 52 | ''' Calculate accuracy according to PCK, but uses ground truth heatmap rather than x,y locations 53 | First value to be returned is average accuracy across 'idxs', followed by individual accuracies 54 | ''' 55 | preds = get_preds(output) 56 | gts = get_preds(target) 57 | norm = torch.ones(preds.size(0))*output.size(3)/10 58 | dists = calc_dists(preds, gts, norm) 59 | 60 | acc = torch.zeros(len(idxs)+1) 61 | avg_acc = 0 62 | cnt = 0 63 | 64 | for i in range(len(idxs)): 65 | acc[i+1] = dist_acc(dists[idxs[i]-1]) 66 | if acc[i+1] >= 0: 67 | avg_acc = avg_acc + acc[i+1] 68 | cnt += 1 69 | 70 | if cnt != 0: 71 | acc[0] = avg_acc / cnt 72 | return acc 73 | 74 | def final_preds(output, center, scale, res): 75 | coords = get_preds(output) # float type 76 | 77 | # pose-processing 78 | for n in range(coords.size(0)): 79 | for p in range(coords.size(1)): 80 | hm = output[n][p] 81 | px = int(math.floor(coords[n][p][0])) 82 | py = int(math.floor(coords[n][p][1])) 83 | if px > 1 and px < res[0] and py > 1 and py < res[1]: 84 | diff = torch.Tensor([hm[py - 1][px] - hm[py - 1][px - 2], hm[py][px - 1]-hm[py - 2][px - 1]]) 85 | coords[n][p] += diff.sign() * .25 86 | coords += 0.5 87 | preds = coords.clone() 88 | 89 | # Transform back 90 | for i in range(coords.size(0)): 91 | preds[i] = transform_preds(coords[i], center[i], scale[i], res) 92 | 93 | if preds.dim() < 3: 94 | preds = preds.view(1, preds.size()) 95 | 96 | return preds 97 | 
98 | 99 | class AverageMeter(object): 100 | """Computes and stores the average and current value""" 101 | def __init__(self): 102 | self.reset() 103 | 104 | def reset(self): 105 | self.val = 0 106 | self.avg = 0 107 | self.sum = 0 108 | self.count = 0 109 | 110 | def update(self, val, n=1): 111 | self.val = val 112 | self.sum += val * n 113 | self.count += n 114 | self.avg = self.sum / self.count 115 | -------------------------------------------------------------------------------- /pose/utils/imutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import scipy.misc 7 | 8 | from .misc import * 9 | 10 | def im_to_numpy(img): 11 | img = to_numpy(img) 12 | img = np.transpose(img, (1, 2, 0)) # H*W*C 13 | return img 14 | 15 | def im_to_torch(img): 16 | img = np.transpose(img, (2, 0, 1)) # C*H*W 17 | img = to_torch(img).float() 18 | if img.max() > 1: 19 | img /= 255 20 | return img 21 | 22 | def load_image(img_path): 23 | # H x W x C => C x H x W 24 | return im_to_torch(scipy.misc.imread(img_path, mode='RGB')) 25 | 26 | def resize(img, owidth, oheight): 27 | img = im_to_numpy(img) 28 | print('%f %f' % (img.min(), img.max())) 29 | img = scipy.misc.imresize( 30 | img, 31 | (oheight, owidth) 32 | ) 33 | img = im_to_torch(img) 34 | print('%f %f' % (img.min(), img.max())) 35 | return img 36 | 37 | # ============================================================================= 38 | # Helpful functions generating groundtruth labelmap 39 | # ============================================================================= 40 | 41 | def gaussian(shape=(7,7),sigma=1): 42 | """ 43 | 2D gaussian mask - should give the same result as MATLAB's 44 | fspecial('gaussian',[shape],[sigma]) 45 | """ 46 | m,n = [(ss-1.)/2. 
for ss in shape] 47 | y,x = np.ogrid[-m:m+1,-n:n+1] 48 | h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) ) 49 | h[ h < np.finfo(h.dtype).eps*h.max() ] = 0 50 | return to_torch(h).float() 51 | 52 | def draw_labelmap(img, pt, sigma, type='Gaussian'): 53 | # Draw a 2D gaussian 54 | # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py 55 | img = to_numpy(img) 56 | 57 | # Check that any part of the gaussian is in-bounds 58 | ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)] 59 | br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)] 60 | if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or 61 | br[0] < 0 or br[1] < 0): 62 | # If not, just return the image as is 63 | return to_torch(img) 64 | 65 | # Generate gaussian 66 | size = 6 * sigma + 1 67 | x = np.arange(0, size, 1, float) 68 | y = x[:, np.newaxis] 69 | x0 = y0 = size // 2 70 | # The gaussian is not normalized, we want the center value to equal 1 71 | if type == 'Gaussian': 72 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) 73 | elif type == 'Cauchy': 74 | g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5) 75 | 76 | 77 | # Usable gaussian range 78 | g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0] 79 | g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1] 80 | # Image range 81 | img_x = max(0, ul[0]), min(br[0], img.shape[1]) 82 | img_y = max(0, ul[1]), min(br[1], img.shape[0]) 83 | 84 | img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]] 85 | return to_torch(img) 86 | 87 | # ============================================================================= 88 | # Helpful display functions 89 | # ============================================================================= 90 | 91 | def gauss(x, a, b, c, d=0): 92 | return a * np.exp(-(x - b)**2 / (2 * c**2)) + d 93 | 94 | def color_heatmap(x): 95 | x = to_numpy(x) 96 | color = np.zeros((x.shape[0],x.shape[1],3)) 97 | color[:,:,0] = gauss(x, .5, .6, .2) + gauss(x, 
1, .8, .3) 98 | color[:,:,1] = gauss(x, 1, .5, .3) 99 | color[:,:,2] = gauss(x, 1, .2, .3) 100 | color[color > 1] = 1 101 | color = (color * 255).astype(np.uint8) 102 | return color 103 | 104 | def imshow(img): 105 | npimg = im_to_numpy(img*255).astype(np.uint8) 106 | plt.imshow(npimg) 107 | plt.axis('off') 108 | 109 | def show_joints(img, pts): 110 | imshow(img) 111 | 112 | for i in range(pts.size(0)): 113 | if pts[i, 2] > 0: 114 | plt.plot(pts[i, 0], pts[i, 1], 'yo') 115 | plt.axis('off') 116 | 117 | def show_sample(inputs, target): 118 | num_sample = inputs.size(0) 119 | num_joints = target.size(1) 120 | height = target.size(2) 121 | width = target.size(3) 122 | 123 | for n in range(num_sample): 124 | inp = resize(inputs[n], width, height) 125 | out = inp 126 | for p in range(num_joints): 127 | tgt = inp*0.5 + color_heatmap(target[n,p,:,:])*0.5 128 | out = torch.cat((out, tgt), 2) 129 | 130 | imshow(out) 131 | plt.show() 132 | 133 | def sample_with_heatmap(inp, out, num_rows=2, parts_to_show=None): 134 | inp = to_numpy(inp * 255) 135 | out = to_numpy(out) 136 | 137 | img = np.zeros((inp.shape[1], inp.shape[2], inp.shape[0])) 138 | for i in range(3): 139 | img[:, :, i] = inp[i, :, :] 140 | 141 | if parts_to_show is None: 142 | parts_to_show = np.arange(out.shape[0]) 143 | 144 | # Generate a single image to display input/output pair 145 | num_cols = int(np.ceil(float(len(parts_to_show)) / num_rows)) 146 | size = img.shape[0] // num_rows 147 | 148 | full_img = np.zeros((img.shape[0], size * (num_cols + num_rows), 3), np.uint8) 149 | full_img[:img.shape[0], :img.shape[1]] = img 150 | 151 | inp_small = scipy.misc.imresize(img, [size, size]) 152 | 153 | # Set up heatmap display for each part 154 | for i, part in enumerate(parts_to_show): 155 | part_idx = part 156 | out_resized = scipy.misc.imresize(out[part_idx], [size, size]) 157 | out_resized = out_resized.astype(float)/255 158 | out_img = inp_small.copy() * .3 159 | color_hm = color_heatmap(out_resized) 160 | 
out_img += color_hm * .7 161 | 162 | col_offset = (i % num_cols + num_rows) * size 163 | row_offset = (i // num_cols) * size 164 | full_img[row_offset:row_offset + size, col_offset:col_offset + size] = out_img 165 | 166 | return full_img 167 | 168 | def batch_with_heatmap(inputs, outputs, mean=torch.Tensor([0.5, 0.5, 0.5]), num_rows=2, parts_to_show=None): 169 | batch_img = [] 170 | for n in range(min(inputs.size(0), 4)): 171 | inp = inputs[n] + mean.view(3, 1, 1).expand_as(inputs[n]) 172 | batch_img.append( 173 | sample_with_heatmap(inp.clamp(0, 1), outputs[n], num_rows=num_rows, parts_to_show=parts_to_show) 174 | ) 175 | return np.concatenate(batch_img) 176 | -------------------------------------------------------------------------------- /pose/utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 11 | 12 | def savefig(fname, dpi=None): 13 | dpi = 150 if dpi == None else dpi 14 | plt.savefig(fname, dpi=dpi) 15 | 16 | def plot_overlap(logger, names=None): 17 | names = logger.names if names == None else names 18 | numbers = logger.numbers 19 | for _, name in enumerate(names): 20 | x = np.arange(len(numbers[name])) 21 | plt.plot(x, np.asarray(numbers[name])) 22 | return [logger.title + '(' + name + ')' for name in names] 23 | 24 | class Logger(object): 25 | '''Save training process to log file with simple plot function.''' 26 | def __init__(self, fpath, title=None, resume=False): 27 | self.file = None 28 | self.resume = resume 29 | self.title = '' if title == None else title 30 | if fpath is not None: 31 | if resume: 32 | self.file = open(fpath, 'r') 33 | name = self.file.readline() 34 | self.names = name.rstrip().split('\t') 35 | self.numbers = {} 36 | for _, name in 
enumerate(self.names): 37 | self.numbers[name] = [] 38 | 39 | for numbers in self.file: 40 | numbers = numbers.rstrip().split('\t') 41 | for i in range(0, len(numbers)): 42 | self.numbers[self.names[i]].append(numbers[i]) 43 | self.file.close() 44 | self.file = open(fpath, 'a') 45 | else: 46 | self.file = open(fpath, 'w') 47 | 48 | def set_names(self, names): 49 | if self.resume: 50 | pass 51 | # initialize numbers as empty list 52 | self.numbers = {} 53 | self.names = names 54 | for _, name in enumerate(self.names): 55 | self.file.write(name) 56 | self.file.write('\t') 57 | self.numbers[name] = [] 58 | self.file.write('\n') 59 | self.file.flush() 60 | 61 | 62 | def append(self, numbers): 63 | assert len(self.names) == len(numbers), 'Numbers do not match names' 64 | for index, num in enumerate(numbers): 65 | self.file.write("{0:.6f}".format(num)) 66 | self.file.write('\t') 67 | self.numbers[self.names[index]].append(num) 68 | self.file.write('\n') 69 | self.file.flush() 70 | 71 | def plot(self, names=None): 72 | names = self.names if names == None else names 73 | numbers = self.numbers 74 | for _, name in enumerate(names): 75 | x = np.arange(len(numbers[name])) 76 | plt.plot(x, np.asarray(numbers[name])) 77 | plt.legend([self.title + '(' + name + ')' for name in names]) 78 | plt.grid(True) 79 | 80 | def close(self): 81 | if self.file is not None: 82 | self.file.close() 83 | 84 | class LoggerMonitor(object): 85 | '''Load and visualize multiple logs.''' 86 | def __init__ (self, paths): 87 | '''paths is a distionary with {name:filepath} pair''' 88 | self.loggers = [] 89 | for title, path in paths.items(): 90 | logger = Logger(path, title=title, resume=True) 91 | self.loggers.append(logger) 92 | 93 | def plot(self, names=None): 94 | plt.figure() 95 | plt.subplot(121) 96 | legend_text = [] 97 | for logger in self.loggers: 98 | legend_text += plot_overlap(logger, names) 99 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
100 | plt.grid(True) 101 | 102 | if __name__ == '__main__': 103 | # # Example 104 | # logger = Logger('test.txt') 105 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 106 | 107 | # length = 100 108 | # t = np.arange(length) 109 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 110 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 111 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 112 | 113 | # for i in range(0, length): 114 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 115 | # logger.plot() 116 | 117 | # Example: logger monitor 118 | paths = { 119 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 120 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 121 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 122 | } 123 | 124 | field = ['Valid Acc.'] 125 | 126 | monitor = LoggerMonitor(paths) 127 | monitor.plot(names=field) 128 | savefig('test.eps') -------------------------------------------------------------------------------- /pose/utils/misc.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import shutil 5 | import torch 6 | import math 7 | import numpy as np 8 | import scipy.io 9 | import matplotlib.pyplot as plt 10 | 11 | def to_numpy(tensor): 12 | if torch.is_tensor(tensor): 13 | return tensor.cpu().numpy() 14 | elif type(tensor).__module__ != 'numpy': 15 | raise ValueError("Cannot convert {} to numpy array" 16 | .format(type(tensor))) 17 | return tensor 18 | 19 | 20 | def to_torch(ndarray): 21 | if type(ndarray).__module__ == 'numpy': 22 | return torch.from_numpy(ndarray) 23 | elif not torch.is_tensor(ndarray): 24 | raise ValueError("Cannot convert {} to torch tensor" 25 | .format(type(ndarray))) 26 | return ndarray 27 | 28 | 29 | def 
save_checkpoint(state, preds, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar', snapshot=None): 30 | preds = to_numpy(preds) 31 | filepath = os.path.join(checkpoint, filename) 32 | torch.save(state, filepath) 33 | scipy.io.savemat(os.path.join(checkpoint, 'preds.mat'), mdict={'preds' : preds}) 34 | 35 | if snapshot and state.epoch % snapshot == 0: 36 | shutil.copyfile(filepath, os.path.join(checkpoint, 'checkpoint_{}.pth.tar'.format(state.epoch))) 37 | 38 | if is_best: 39 | shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) 40 | scipy.io.savemat(os.path.join(checkpoint, 'preds_best.mat'), mdict={'preds' : preds}) 41 | 42 | 43 | def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'): 44 | preds = to_numpy(preds) 45 | filepath = os.path.join(checkpoint, filename) 46 | scipy.io.savemat(filepath, mdict={'preds' : preds}) 47 | 48 | 49 | def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma): 50 | """Sets the learning rate to the initial LR decayed by schedule""" 51 | if epoch in schedule: 52 | lr *= gamma 53 | for param_group in optimizer.param_groups: 54 | param_group['lr'] = lr 55 | return lr -------------------------------------------------------------------------------- /pose/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import errno 5 | 6 | def mkdir_p(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | 13 | def isfile(fname): 14 | return os.path.isfile(fname) 15 | 16 | def isdir(dirname): 17 | return os.path.isdir(dirname) 18 | 19 | def join(path, *paths): 20 | return os.path.join(path, *paths) 21 | -------------------------------------------------------------------------------- /pose/utils/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
absolute_import 2 | 3 | import os 4 | import numpy as np 5 | import scipy.misc 6 | import matplotlib.pyplot as plt 7 | import torch 8 | 9 | from .misc import * 10 | from .imutils import * 11 | 12 | 13 | def color_normalize(x, mean, std): 14 | if x.size(0) == 1: 15 | x = x.repeat(3, 1, 1) 16 | 17 | for t, m, s in zip(x, mean, std): 18 | t.sub_(m) 19 | return x 20 | 21 | 22 | def flip_back(flip_output, dataset='mpii'): 23 | """ 24 | flip output map 25 | """ 26 | if dataset == 'mpii': 27 | matchedParts = ( 28 | [0,5], [1,4], [2,3], 29 | [10,15], [11,14], [12,13] 30 | ) 31 | else: 32 | print('Not supported dataset: ' + dataset) 33 | 34 | # flip output horizontally 35 | flip_output = fliplr(flip_output.numpy()) 36 | 37 | # Change left-right parts 38 | for pair in matchedParts: 39 | tmp = np.copy(flip_output[:, pair[0], :, :]) 40 | flip_output[:, pair[0], :, :] = flip_output[:, pair[1], :, :] 41 | flip_output[:, pair[1], :, :] = tmp 42 | 43 | return torch.from_numpy(flip_output).float() 44 | 45 | 46 | def shufflelr(x, width, dataset='mpii'): 47 | """ 48 | flip coords 49 | """ 50 | if dataset == 'mpii': 51 | matchedParts = ( 52 | [0,5], [1,4], [2,3], 53 | [10,15], [11,14], [12,13] 54 | ) 55 | else: 56 | print('Not supported dataset: ' + dataset) 57 | 58 | # Flip horizontal 59 | x[:, 0] = width - x[:, 0] 60 | 61 | # Change left-right parts 62 | for pair in matchedParts: 63 | tmp = x[pair[0], :].clone() 64 | x[pair[0], :] = x[pair[1], :] 65 | x[pair[1], :] = tmp 66 | 67 | return x 68 | 69 | 70 | def fliplr(x): 71 | if x.ndim == 3: 72 | x = np.transpose(np.fliplr(np.transpose(x, (0, 2, 1))), (0, 2, 1)) 73 | elif x.ndim == 4: 74 | for i in range(x.shape[0]): 75 | x[i] = np.transpose(np.fliplr(np.transpose(x[i], (0, 2, 1))), (0, 2, 1)) 76 | return x.astype(float) 77 | 78 | 79 | def get_transform(center, scale, res, rot=0): 80 | """ 81 | General image processing functions 82 | """ 83 | # Generate transformation matrix 84 | h = 200 * scale 85 | t = np.zeros((3, 3)) 86 | t[0, 0] 
= float(res[1]) / h 87 | t[1, 1] = float(res[0]) / h 88 | t[0, 2] = res[1] * (-float(center[0]) / h + .5) 89 | t[1, 2] = res[0] * (-float(center[1]) / h + .5) 90 | t[2, 2] = 1 91 | if not rot == 0: 92 | rot = -rot # To match direction of rotation from cropping 93 | rot_mat = np.zeros((3,3)) 94 | rot_rad = rot * np.pi / 180 95 | sn,cs = np.sin(rot_rad), np.cos(rot_rad) 96 | rot_mat[0,:2] = [cs, -sn] 97 | rot_mat[1,:2] = [sn, cs] 98 | rot_mat[2,2] = 1 99 | # Need to rotate around center 100 | t_mat = np.eye(3) 101 | t_mat[0,2] = -res[1]/2 102 | t_mat[1,2] = -res[0]/2 103 | t_inv = t_mat.copy() 104 | t_inv[:2,2] *= -1 105 | t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t))) 106 | return t 107 | 108 | 109 | def transform(pt, center, scale, res, invert=0, rot=0): 110 | # Transform pixel location to different reference 111 | t = get_transform(center, scale, res, rot=rot) 112 | if invert: 113 | t = np.linalg.inv(t) 114 | new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T 115 | new_pt = np.dot(t, new_pt) 116 | return new_pt[:2].astype(int) + 1 117 | 118 | 119 | def transform_preds(coords, center, scale, res): 120 | # size = coords.size() 121 | # coords = coords.view(-1, coords.size(-1)) 122 | # print(coords.size()) 123 | for p in range(coords.size(0)): 124 | coords[p, 0:2] = to_torch(transform(coords[p, 0:2], center, scale, res, 1, 0)) 125 | return coords 126 | 127 | 128 | def crop(img, center, scale, res, rot=0): 129 | img = im_to_numpy(img) 130 | 131 | # Preprocessing for efficient cropping 132 | ht, wd = img.shape[0], img.shape[1] 133 | sf = scale * 200.0 / res[0] 134 | if sf < 2: 135 | sf = 1 136 | else: 137 | new_size = int(np.math.floor(max(ht, wd) / sf)) 138 | new_ht = int(np.math.floor(ht / sf)) 139 | new_wd = int(np.math.floor(wd / sf)) 140 | if new_size < 2: 141 | return torch.zeros(res[0], res[1], img.shape[2]) \ 142 | if len(img.shape) > 2 else torch.zeros(res[0], res[1]) 143 | else: 144 | img = scipy.misc.imresize(img, [new_ht, new_wd]) 145 | center = center * 1.0 
/ sf 146 | scale = scale / sf 147 | 148 | # Upper left point 149 | ul = np.array(transform([0, 0], center, scale, res, invert=1)) 150 | # Bottom right point 151 | br = np.array(transform(res, center, scale, res, invert=1)) 152 | 153 | # Padding so that when rotated proper amount of context is included 154 | pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) 155 | if not rot == 0: 156 | ul -= pad 157 | br += pad 158 | 159 | new_shape = [br[1] - ul[1], br[0] - ul[0]] 160 | if len(img.shape) > 2: 161 | new_shape += [img.shape[2]] 162 | new_img = np.zeros(new_shape) 163 | 164 | # Range to fill new array 165 | new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] 166 | new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] 167 | # Range to sample from original image 168 | old_x = max(0, ul[0]), min(len(img[0]), br[0]) 169 | old_y = max(0, ul[1]), min(len(img), br[1]) 170 | new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] 171 | 172 | if not rot == 0: 173 | # Remove padding 174 | new_img = scipy.misc.imrotate(new_img, rot) 175 | new_img = new_img[pad:-pad, pad:-pad] 176 | 177 | new_img = im_to_torch(scipy.misc.imresize(new_img, res)) 178 | return new_img 179 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cloudpickle==0.6.1 2 | cycler==0.10.0 3 | dask==1.0.0 4 | decorator==4.3.0 5 | functools32==3.2.3.post2 6 | matplotlib==2.0.2 7 | networkx==2.2 8 | numpy==1.15.4 9 | Pillow==5.3.0 10 | pyparsing==2.3.0 11 | python-dateutil==2.7.5 12 | pytz==2018.7 13 | PyWavelets==1.0.1 14 | scikit-image==0.14.1 15 | scipy==1.2.0 16 | six==1.12.0 17 | subprocess32==3.5.3 18 | toolz==0.9.0 19 | torch==0.4.0 20 | torchvision==0.2.1 21 | -------------------------------------------------------------------------------- /tools/mpii_demo.py: 
-------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | import numpy as np 4 | import torch 5 | import torch.nn.parallel 6 | import torch.backends.cudnn as cudnn 7 | import torch.optim 8 | from pose.utils.osutils import mkdir_p, isfile, isdir, join 9 | import pose.models as models 10 | from scipy.ndimage import gaussian_filter, maximum_filter 11 | import cv2 12 | import numpy as np 13 | 14 | def load_image(imgfile, w, h ): 15 | image = cv2.imread(imgfile) 16 | image = cv2.resize(image, (w, h)) 17 | image = image[:, :, ::-1] # BGR -> RGB 18 | image = image / 255.0 19 | image = image - np.array([[[0.4404, 0.4440, 0.4327]]]) # Extract mean RGB 20 | image = image.transpose((2, 0, 1)) # Change data layout from HWC to CHW 21 | image = image[np.newaxis, :, :, :] 22 | return image 23 | 24 | def load_model(arch='hg', stacks=2, blocks=1, num_classes=16, mobile=True, 25 | resume='checkpoint/pytorch-pose/mpii_hg_s2_b1_mobile/checkpoint.pth.tar'): 26 | # create model 27 | model = models.__dict__[arch](num_stacks=stacks, num_blocks=blocks, num_classes=num_classes, mobile=mobile) 28 | model = torch.nn.DataParallel(model).cuda() 29 | 30 | # optionally resume from a checkpoint 31 | if isfile(resume): 32 | print("=> loading checkpoint '{}'".format(resume)) 33 | checkpoint = torch.load(resume) 34 | model.load_state_dict(checkpoint['state_dict']) 35 | print("=> loaded checkpoint '{}' (epoch {})" 36 | .format(resume, checkpoint['epoch'])) 37 | else: 38 | print("=> no checkpoint found at '{}'".format(resume)) 39 | 40 | cudnn.benchmark = True 41 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 42 | model.eval() 43 | return model 44 | 45 | def inference(model, image): 46 | model.eval() 47 | input_tensor = torch.from_numpy(image).float().cuda() 48 | output = model(input_tensor) 49 | output = output[-1] 50 | output = output.data.cpu() 51 | print(output.shape) 52 | kps = 
post_process_heatmap(output[0,:,:,:]) 53 | return kps 54 | 55 | 56 | def post_process_heatmap(heatMap, kpConfidenceTh=0.2): 57 | kplst = list() 58 | for i in range(heatMap.shape[0]): 59 | _map = heatMap[i, :, :] 60 | _map = gaussian_filter(_map, sigma=1) 61 | _nmsPeaks = non_max_supression(_map, windowSize=3, threshold=1e-6) 62 | 63 | y, x = np.where(_nmsPeaks == _nmsPeaks.max()) 64 | if len(x) > 0 and len(y) > 0: 65 | kplst.append((int(x[0]), int(y[0]), _nmsPeaks[y[0], x[0]])) 66 | else: 67 | kplst.append((0, 0, 0)) 68 | 69 | kp = np.array(kplst) 70 | return kp 71 | 72 | 73 | def non_max_supression(plain, windowSize=3, threshold=1e-6): 74 | # clear value less than threshold 75 | under_th_indices = plain < threshold 76 | plain[under_th_indices] = 0 77 | return plain * (plain == maximum_filter(plain, footprint=np.ones((windowSize, windowSize)))) 78 | 79 | def render_kps(cvmat, kps, scale_x, scale_y): 80 | for _kp in kps: 81 | _x, _y, _conf = _kp 82 | if _conf > 0.2: 83 | cv2.circle(cvmat, center=(int(_x*4*scale_x), int(_y*4*scale_y)), color=(0,0,255), radius=5) 84 | 85 | return cvmat 86 | 87 | 88 | def main(): 89 | model = load_model() 90 | in_res_h , in_res_w = 192, 192 91 | 92 | imgfile = "/home/yli150/sample.jpg" 93 | image = load_image(imgfile, in_res_w, in_res_h) 94 | print(image.shape) 95 | 96 | kps = inference(model, image) 97 | 98 | cvmat = cv2.imread(imgfile) 99 | scale_x = cvmat.shape[1]*1.0/in_res_w 100 | scale_y = cvmat.shape[0]*1.0/in_res_h 101 | render_kps(cvmat, kps, scale_x, scale_y) 102 | print(kps) 103 | cv2.imshow('x', cvmat) 104 | cv2.waitKey(0) 105 | 106 | if __name__ == '__main__': 107 | main() -------------------------------------------------------------------------------- /tools/mpii_export_to_onxx.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import argparse 4 | import torch 5 | import torch.nn.parallel 6 | import torch.backends.cudnn as 
cudnn 7 | import torch.optim 8 | 9 | from pose.utils.logger import Logger, savefig 10 | from pose.utils.osutils import mkdir_p, isfile, isdir, join 11 | import pose.models as models 12 | 13 | model_names = sorted(name for name in models.__dict__ 14 | if name.islower() and not name.startswith("__") 15 | and callable(models.__dict__[name])) 16 | 17 | def main(args): 18 | 19 | # create checkpoint dir 20 | if not isdir(args.checkpoint): 21 | mkdir_p(args.checkpoint) 22 | 23 | # create model 24 | print("==> creating model '{}', stacks={}, blocks={}".format(args.arch, args.stacks, args.blocks)) 25 | model = models.__dict__[args.arch](num_stacks=args.stacks, num_blocks=args.blocks, num_classes=args.num_classes, 26 | mobile=args.mobile) 27 | model.eval() 28 | 29 | # optionally resume from a checkpoint 30 | title = 'mpii-' + args.arch 31 | if args.checkpoint: 32 | if isfile(args.checkpoint): 33 | print("=> loading checkpoint '{}'".format(args.checkpoint)) 34 | checkpoint = torch.load(args.checkpoint) 35 | args.start_epoch = checkpoint['epoch'] 36 | 37 | # create new OrderedDict that does not contain `module.` 38 | from collections import OrderedDict 39 | new_state_dict = OrderedDict() 40 | for k, v in checkpoint['state_dict'].items(): 41 | name = k[7:] # remove `module.` 42 | new_state_dict[name] = v 43 | # load params 44 | model.load_state_dict(new_state_dict) 45 | 46 | print("=> loaded checkpoint '{}' (epoch {})" 47 | .format(args.checkpoint, checkpoint['epoch'])) 48 | else: 49 | print("=> no checkpoint found at '{}'".format(args.checkpoint)) 50 | else: 51 | logger = Logger(join(args.checkpoint, 'log.txt'), title=title) 52 | logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc']) 53 | 54 | cudnn.benchmark = True 55 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0)) 56 | 57 | dummy_input = torch.randn(1, 3, args.in_res, args.in_res) 58 | torch.onnx.export(model, dummy_input, args.out_onnx) 59 | 60 | if 
__name__ == '__main__': 61 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 62 | # Model structure 63 | parser.add_argument('--arch', '-a', metavar='ARCH', default='hg', 64 | choices=model_names, 65 | help='model architecture: ' + 66 | ' | '.join(model_names) + 67 | ' (default: resnet18)') 68 | parser.add_argument('-s', '--stacks', default=8, type=int, metavar='N', 69 | help='Number of hourglasses to stack') 70 | parser.add_argument('-b', '--blocks', default=1, type=int, metavar='N', 71 | help='Number of residual modules at each location in the hourglass') 72 | parser.add_argument('--num-classes', default=16, type=int, metavar='N', 73 | help='Number of keypoints') 74 | parser.add_argument('--mobile', default=False, type=bool, metavar='N', 75 | help='use depthwise convolution in bottneck-block') 76 | parser.add_argument('--out_onnx', required=True, type=str, metavar='N', 77 | help='exported onnx file') 78 | parser.add_argument('--checkpoint', required=True, type=str, metavar='N', 79 | help='pre-trained model checkpoint') 80 | parser.add_argument('--in_res', required=True, type=int, metavar='N', 81 | help='input shape 128 or 256') 82 | main(parser.parse_args()) --------------------------------------------------------------------------------