├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── data
│   ├── acc_curve.png
│   ├── lsp
│   │   └── images
│   │       ├── lsp_dataset
│   │       └── lspet_dataset
│   ├── mpii
│   │   ├── mean.pth.tar
│   │   └── mpii_annotations.json
│   └── mscoco
│       └── README.md
├── evaluation
│   ├── data
│   │   ├── detections.mat
│   │   └── detections_our_format.mat
│   ├── eval_PCKh.m
│   ├── eval_PCKh.py
│   ├── showskeletons_joints.m
│   └── utils.py
├── example
│   ├── lsp.py
│   ├── mpii.py
│   └── mscoco.py
├── miscs
│   ├── cocoScale.m
│   ├── gen_coco.m
│   ├── gen_lsp.m
│   └── gen_mpii.m
├── pose
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── lsp.py
│   │   ├── mpii.py
│   │   └── mscoco.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── hourglass.py
│   │   └── preresnet.py
│   └── utils
│       ├── __init__.py
│       ├── evaluation.py
│       ├── imutils.py
│       ├── logger.py
│       ├── misc.py
│       ├── osutils.py
│       └── transforms.py
├── requirements.txt
└── tools
    ├── mpii_demo.py
    └── mpii_export_to_onxx.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | checkpoint
3 | dev
4 | *.pth.tar
5 | data/mpii/images
6 | !data/mpii/mean.pth.tar
7 | *.json
8 | *debug*
9 | *.idea/*
10 | test_transforms.py
11 | experiments
12 | data/mscoco/coco
13 | data/mscoco/keypoint
14 | *.m~
15 | miscs/posetrack
16 | miscs/h36m
17 | data/h36m
18 |
19 | __pycache__/
20 | *.py[cod]
21 | *$py.class
22 |
23 | # C extensions
24 | *.so
25 |
26 | # Distribution / packaging
27 | .Python
28 | env/
29 | build/
30 | develop-eggs/
31 | dist/
32 | downloads/
33 | eggs/
34 | .eggs/
35 | lib/
36 | lib64/
37 | parts/
38 | sdist/
39 | var/
40 | *.egg-info/
41 | .installed.cfg
42 | *.egg
43 |
44 | # PyInstaller
45 | # Usually these files are written by a python script from a template
46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest
48 | *.spec
49 |
50 | # Installer logs
51 | pip-log.txt
52 | pip-delete-this-directory.txt
53 |
54 | # Unit test / coverage reports
55 | htmlcov/
56 | .tox/
57 | .coverage
58 | .coverage.*
59 | .cache
60 | nosetests.xml
61 | coverage.xml
62 | *,cover
63 | .hypothesis/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | target/
85 |
86 | # IPython Notebook
87 | .ipynb_checkpoints
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # celery beat schedule file
93 | celerybeat-schedule
94 |
95 | # dotenv
96 | .env
97 |
98 | # virtualenv
99 | venv/
100 | ENV/
101 |
102 | # Spyder project settings
103 | .spyderproject
104 |
105 | # Rope project settings
106 | .ropeproject
107 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "pose/progress"]
2 | path = pose/progress
3 | url = https://github.com/verigak/progress.git
4 | [submodule "miscs/jsonlab"]
5 | path = miscs/jsonlab
6 | url = https://github.com/fangq/jsonlab
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 |  Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | {one line to give the program's name and a brief idea of what it does.}
635 | Copyright (C) {year} {name of author}
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 |     along with this program.  If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | {project} Copyright (C) {year} {fullname}
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/philosophy/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch-Pose
2 |
3 | PyTorch-Pose is a PyTorch implementation of the general pipeline for 2D single-person human pose estimation. The aim is to provide interfaces for training, inference, and evaluation, along with dataloaders offering various data augmentation options for the most popular human pose datasets (e.g., [the MPII human pose dataset](http://human-pose.mpi-inf.mpg.de), [LSP](http://www.comp.leeds.ac.uk/mat4saj/lsp.html) and [FLIC](http://bensapp.github.io/flic-dataset.html)).
4 |
5 | Some of the code for data preparation and augmentation is borrowed from the [Stacked Hourglass Network](https://github.com/anewell/pose-hg-train). Thanks to the original author.
6 |
7 | ## Models
8 | | Model | in_res | features | # of weights | Head | Shoulder | Elbow | Wrist | Hip | Knee | Ankle | Mean | Link |
9 | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
10 | | hg_s2_b1 | 256 | 128 | 6.73M | 95.74 | 94.51 | 87.68 | 81.70 | 87.81 | 80.88 | 76.83 | 86.58 | [GoogleDrive](https://drive.google.com/open?id=1c_YR0NKmRfRvLcNB5wFpm75VOkC9Y1n4) |
11 | | hg_s2_b1_mobile | 256 | 128 | 2.31M | 95.80 | 93.61 | 85.50 | 79.63 | 86.13 | 77.82 | 73.62 | 84.69 | [GoogleDrive](https://drive.google.com/open?id=1FxTRhiw6_dS8X1jBBUw_bxHX6RoBJaJO) |
12 | | hg_s2_b1_tiny | 192 | 128 | 2.31M | 94.95 | 92.87 | 84.59 | 78.19 | 84.68 | 77.70 | 73.07 | 83.88 | [GoogleDrive](https://drive.google.com/open?id=1qrkaUDPbHwdSBozRbN150O4Mu9HMWIOG) |
13 |
14 |
15 | ## Installation
16 | 1. Create a virtualenv
17 | ```
18 | virtualenv -p /usr/bin/python2.7 posevenv
19 | ```
20 | 2. Clone the repository with its submodules
21 | ```
22 | git clone --recursive https://github.com/yuanyuanli85/pytorch-pose.git
23 | ```
24 | 3. Install all dependencies in the virtualenv
25 | ```
26 | source posevenv/bin/activate
27 | pip install -r requirements.txt
28 | ```
29 |
30 | 4. Create a symbolic link to the `images` directory of the MPII dataset:
31 | ```
32 | ln -s PATH_TO_MPII_IMAGES_DIR data/mpii/images
33 | ```
34 |
35 | 5. Disable cuDNN for the batchnorm layer to work around a bug in PyTorch 0.4.0:
36 | ```
37 | sed -i "1194s/torch\.backends\.cudnn\.enabled/False/g" ./posevenv/lib/python2.7/site-packages/torch/nn/functional.py
38 | ```
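
Alternatively, cuDNN can be disabled at runtime instead of patching `functional.py` inside site-packages. This is a coarser workaround (it turns cuDNN off for all layers, not only batchnorm); a minimal sketch, assuming it runs once at the top of your training script:

```python
import torch.backends.cudnn as cudnn

# Coarse-grained alternative to the sed patch above: disabling cuDNN
# globally also sidesteps the PyTorch 0.4.0 cuDNN batchnorm bug, at the
# cost of slower convolutions.
cudnn.enabled = False
```
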
39 | ## Training
40 |
41 | * Normal network configuration, in_res 256, features 128
42 | ```sh
43 | python example/mpii.py -a hg --stacks 2 --blocks 1 --checkpoint checkpoint/hg_s2_b1/ --in_res 256 --features 128
44 | ```
45 |
46 | * Mobile network configuration, in_res 256, features 128
47 | ```sh
48 | python example/mpii.py -a hg --stacks 2 --blocks 1 --checkpoint checkpoint/hg_s2_b1_mobile/ --mobile True --in_res 256 --features 128
49 | ```
50 |
51 | * Tiny network configuration, in_res 192, features 128
52 | ```sh
53 | python example/mpii.py -a hg --stacks 2 --blocks 1 --checkpoint checkpoint/hg_s2_b1_tiny/ --mobile True --in_res 192 --features 128
54 | ```
55 |
56 | ## Evaluation
57 |
58 | Run the evaluation to generate a mat file of predictions:
59 | ```sh
60 | python example/mpii.py -a hg --stacks 2 --blocks 1 --checkpoint checkpoint/hg_s2_b1/ --resume checkpoint/hg_s2_b1/model_best.pth.tar -e
61 | ```
62 | * `--resume` is the checkpoint you want to evaluate
63 |
64 | Run `evaluation/eval_PCKh.py` to get the validation score.
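
The generated predictions can also be inspected directly; a minimal sketch, assuming the checkpoint directory from the command above:

```python
from scipy.io import loadmat

# preds_valid.mat holds one (x, y) prediction per joint per validation image;
# evaluation/eval_PCKh.py transposes this to (16, 2, N) before scoring.
preds = loadmat('checkpoint/hg_s2_b1/preds_valid.mat')['preds']
print(preds.shape)  # expected: (num_val_images, 16, 2)
```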
65 |
66 | ## Export a PyTorch checkpoint to ONNX
67 | ```sh
68 | python tools/mpii_export_to_onxx.py -a hg -s 2 -b 1 --num-classes 16 --mobile True --in_res 256 --checkpoint checkpoint/model_best.pth.tar \
69 |     --out_onnx checkpoint/model_best.onnx
70 | ```
71 | Here:
72 | * `--checkpoint` is the checkpoint to export
73 | * `--out_onnx` is the path of the exported ONNX file
74 |
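To sanity-check the exported model, it can be loaded with `onnxruntime` (an assumed extra dependency, not listed in `requirements.txt`) and fed a dummy input:

```python
import numpy as np
import onnxruntime

# Query the input signature instead of hard-coding the tensor name,
# which depends on how the export script names the graph inputs.
sess = onnxruntime.InferenceSession('checkpoint/model_best.onnx')
inp = sess.get_inputs()[0]
print(inp.name, inp.shape)  # e.g. a (1, 3, 256, 256) image for in_res 256

dummy = np.random.randn(1, 3, 256, 256).astype(np.float32)
outputs = sess.run(None, {inp.name: dummy})
print([o.shape for o in outputs])  # heatmap outputs; shapes depend on the exported graph
```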
75 |
76 |
77 |
--------------------------------------------------------------------------------
/data/acc_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanyuanli85/pytorch-pose/e0651e77e7832fa603921022d5456b59e435cd91/data/acc_curve.png
--------------------------------------------------------------------------------
/data/lsp/images/lsp_dataset:
--------------------------------------------------------------------------------
1 | /home/wyang/Data/dataset/LSP
--------------------------------------------------------------------------------
/data/lsp/images/lspet_dataset:
--------------------------------------------------------------------------------
1 | /home/wyang/Data/dataset/LSP_ext/
--------------------------------------------------------------------------------
/data/mpii/mean.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanyuanli85/pytorch-pose/e0651e77e7832fa603921022d5456b59e435cd91/data/mpii/mean.pth.tar
--------------------------------------------------------------------------------
/data/mscoco/README.md:
--------------------------------------------------------------------------------
1 | ## Directory structure
2 |
3 | - `coco`: [coco API](https://github.com/pdollar/coco)
4 | - `keypoint`: COCO keypoint dataset
5 |   - `images`: train and val datasets
6 |     - `train2014`
7 |     - `val2014`
8 |   - `person_keypoints_train+val5k2014`: annotations (JSON files)
9 | - `coco_annotations.json`: reformatted annotations generated by `./miscs/gen_coco.m`
10 |
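Assuming `pycocotools` is installed from the `coco` directory above, a minimal sketch of reading the keypoint annotations (the JSON path below is hypothetical; adjust it to your layout):

```python
from pycocotools.coco import COCO

# Hypothetical path following the directory structure described above.
coco = COCO('keypoint/person_keypoints_train2014.json')

person_ids = coco.getCatIds(catNms=['person'])
img_ids = coco.getImgIds(catIds=person_ids)
anns = coco.loadAnns(coco.getAnnIds(imgIds=img_ids[0]))
print(anns[0]['keypoints'][:9])  # flat [x, y, visibility] triples, one per joint
```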
--------------------------------------------------------------------------------
/evaluation/data/detections.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanyuanli85/pytorch-pose/e0651e77e7832fa603921022d5456b59e435cd91/evaluation/data/detections.mat
--------------------------------------------------------------------------------
/evaluation/data/detections_our_format.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanyuanli85/pytorch-pose/e0651e77e7832fa603921022d5456b59e435cd91/evaluation/data/detections_our_format.mat
--------------------------------------------------------------------------------
/evaluation/eval_PCKh.m:
--------------------------------------------------------------------------------
1 | % wei
2 | addpath('./utils');
3 |
4 | % set `debug = true` if you want to visualize the skeletons
5 | % You also need to download the MPII dataset and specify the path of
6 | % annopath = `mpii_human_pose_v1_u12_1.mat`
7 | debug = false;
8 | annopath = 'path_to/mpii_human_pose_v1_u12_1.mat';
9 |
10 | load('data/detections.mat');
11 | tompson_i = RELEASE_img_index;
12 |
13 | threshold = 0.5;
14 | SC_BIAS = 0.6; % THIS IS DEFINED IN util_get_head_size.m
15 |
16 | pa = [2, 3, 7, 7, 4, 5, 8, 9, 10, 0, 12, 13, 8, 8, 14, 15];
17 |
18 | load('data/detections_our_format.mat', 'dataset_joints', 'jnt_missing', 'pos_pred_src', 'pos_gt_src', 'headboxes_src');
19 |
20 | % predictions
21 | predfile = '/home/wyang/code/pose/pytorch-pose/checkpoint/mpii/hg_s2_b1_mean/preds_valid.mat';
22 | preds = load(predfile,'preds');
23 | pos_pred_src = permute(preds.preds, [2, 3, 1]);
24 |
25 | % DEBUG
26 | if debug
27 | mat = load(annopath);
28 |
29 | for i = 1:length(tompson_i)
30 | imname = mat.RELEASE.annolist(tompson_i(i)).image.name;
31 | fprintf('%s\n', imname);
32 | im = imread(['/home/wyang/Data/dataset/mpii/images/' imname]);
33 | pred = pos_pred_src(:, :, i);
34 | showskeletons_joints(im, pred, pa);
35 | pause; clf;
36 | end
37 | end
38 |
39 | head = find(ismember(dataset_joints, 'head'));
40 | lsho = find(ismember(dataset_joints, 'lsho'));
41 | lelb = find(ismember(dataset_joints, 'lelb'));
42 | lwri = find(ismember(dataset_joints, 'lwri'));
43 | lhip = find(ismember(dataset_joints, 'lhip'));
44 | lkne = find(ismember(dataset_joints, 'lkne'));
45 | lank = find(ismember(dataset_joints, 'lank'));
46 |
47 | rsho = find(ismember(dataset_joints, 'rsho'));
48 | relb = find(ismember(dataset_joints, 'relb'));
49 | rwri = find(ismember(dataset_joints, 'rwri'));
50 | rhip = find(ismember(dataset_joints, 'rhip'));
51 | rkne = find(ismember(dataset_joints, 'rkne'));
52 | rank = find(ismember(dataset_joints, 'rank'));
53 |
54 | % Calculate PCKh again for a few joints just to make sure our evaluation
55 | % matches Leonid's...
56 | jnt_visible = 1 - jnt_missing;
57 | uv_err = pos_pred_src - pos_gt_src;
58 | uv_err = sqrt(sum(uv_err .* uv_err, 2));
59 | headsizes = headboxes_src(2,:,:) - headboxes_src(1,:,:);
60 | headsizes = sqrt(sum(headsizes .* headsizes, 2));
61 | headsizes = headsizes * SC_BIAS;
62 | scaled_uv_err = squeeze(uv_err ./ repmat(headsizes, size(uv_err, 1), 1, 1));
63 |
64 | % Zero the contribution of joints that are missing
65 | scaled_uv_err = scaled_uv_err .* jnt_visible;
66 | jnt_count = squeeze(sum(jnt_visible, 2));
67 | less_than_threshold = (scaled_uv_err < threshold) .* jnt_visible;
68 | PCKh = 100 * squeeze(sum(less_than_threshold, 2)) ./ jnt_count;
69 |
70 | % save PCK all
71 | range = (0:0.01:0.5);
72 | pckAll = zeros(length(range),16);
73 | for r = 1:length( range)
74 | threshold = range(r);
75 | less_than_threshold = (scaled_uv_err < threshold) .* jnt_visible;
76 | pckAll(r, :) = 100 * squeeze(sum(less_than_threshold, 2)) ./ jnt_count;
77 |
78 | end
79 |
80 | [~, name, ~] = fileparts(predfile);
81 |
82 | % Uncomment if you want to save the result
83 | % save(sprintf('pckAll-%s.mat', name), 'scaled_uv_err', 'pos_pred_src');
84 |
85 | clc;
86 | fprintf(' Head , Shoulder , Elbow , Wrist , Hip , Knee , Ankle , Mean , \n');
88 | fprintf('name , %.2f , %.2f , %.2f , %.2f , %.2f , %.2f , %.2f , %.2f%% , \n',...
88 | PCKh(head), (PCKh(lsho)+PCKh(rsho))/2, (PCKh(lelb)+PCKh(relb))/2,...
89 | (PCKh(lwri)+PCKh(rwri))/2, (PCKh(lhip)+PCKh(rhip))/2, ...
90 | (PCKh(lkne)+PCKh(rkne))/2, (PCKh(lank)+PCKh(rank))/2, mean(PCKh([1:6, 9:16])));
91 | fprintf('\n');
92 |
93 |
--------------------------------------------------------------------------------
/evaluation/eval_PCKh.py:
--------------------------------------------------------------------------------
1 | from scipy.io import loadmat
2 | from numpy import transpose
3 | import skimage.io as sio
4 | from utils import visualize
5 | import numpy as np
6 | import os
7 |
8 |
9 | detection = loadmat('evaluation/data/detections.mat')
10 | det_idxs = detection['RELEASE_img_index']
11 | debug = 0
12 | threshold = 0.5
13 | SC_BIAS = 0.6
14 |
15 | pa = [2, 3, 7, 7, 4, 5, 8, 9, 10, 0, 12, 13, 8, 8, 14, 15]
16 |
17 | dets = loadmat('evaluation/data/detections_our_format.mat')
18 | dataset_joints = dets['dataset_joints']
19 | jnt_missing = dets['jnt_missing']
20 | pos_pred_src = dets['pos_pred_src']
21 | pos_gt_src = dets['pos_gt_src']
22 | headboxes_src = dets['headboxes_src']
23 |
24 |
25 |
26 | #predictions
27 | model_name = 'hg4'
28 | predfile = 'checkpoint/mpii/' + model_name + '/preds_valid.mat'
29 | preds = loadmat(predfile)['preds']
30 | pos_pred_src = transpose(preds, [1, 2, 0])
31 |
32 |
33 | if debug:
34 |     mat = loadmat('path_to/mpii_human_pose_v1_u12_1.mat')  # MPII annotations; same file as `annopath` in eval_PCKh.m
35 | for i in range(len(det_idxs[0])):
36 | anno = mat['RELEASE']['annolist'][0, 0][0][det_idxs[0][i] - 1]
37 | fn = anno['image']['name'][0, 0][0]
38 | imagePath = 'data/mpii/images/' + fn
39 | oriImg = sio.imread(imagePath)
40 | pred = pos_pred_src[:, :, i]
41 | visualize(oriImg, pred, pa)
42 |
43 |
44 | head = np.where(dataset_joints == 'head')[1][0]
45 | lsho = np.where(dataset_joints == 'lsho')[1][0]
46 | lelb = np.where(dataset_joints == 'lelb')[1][0]
47 | lwri = np.where(dataset_joints == 'lwri')[1][0]
48 | lhip = np.where(dataset_joints == 'lhip')[1][0]
49 | lkne = np.where(dataset_joints == 'lkne')[1][0]
50 | lank = np.where(dataset_joints == 'lank')[1][0]
51 |
52 | rsho = np.where(dataset_joints == 'rsho')[1][0]
53 | relb = np.where(dataset_joints == 'relb')[1][0]
54 | rwri = np.where(dataset_joints == 'rwri')[1][0]
55 | rkne = np.where(dataset_joints == 'rkne')[1][0]
56 | rank = np.where(dataset_joints == 'rank')[1][0]
57 | rhip = np.where(dataset_joints == 'rhip')[1][0]
58 |
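# PCKh: a predicted joint counts as correct when its distance to the ground
# truth, normalized by the head size (headbox diagonal scaled by SC_BIAS),
# falls below `threshold`. Joints marked missing are masked out of both the
# error terms and the per-joint counts below.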
59 | jnt_visible = 1 - jnt_missing
60 | uv_error = pos_pred_src - pos_gt_src
61 | uv_err = np.linalg.norm(uv_error, axis=1)
62 | headsizes = headboxes_src[1, :, :] - headboxes_src[0, :, :]
63 | headsizes = np.linalg.norm(headsizes, axis=0)
64 | headsizes *= SC_BIAS
65 | scale = np.multiply(headsizes, np.ones((len(uv_err), 1)))
66 | scaled_uv_err = np.divide(uv_err, scale)
67 | scaled_uv_err = np.multiply(scaled_uv_err, jnt_visible)
68 | jnt_count = np.sum(jnt_visible, axis=1)
69 | less_than_threshold = np.multiply((scaled_uv_err < threshold), jnt_visible)
70 | PCKh = np.divide(100. * np.sum(less_than_threshold, axis=1), jnt_count)
71 |
72 |
73 | # save
74 | rng = np.arange(0, 0.5, 0.01)
75 | pckAll = np.zeros((len(rng), 16))
76 |
77 | for r in range(len(rng)):
78 | threshold = rng[r]
79 | less_than_threshold = np.multiply(scaled_uv_err < threshold, jnt_visible)
80 | pckAll[r, :] = np.divide(100.*np.sum(less_than_threshold, axis=1), jnt_count)
81 |
82 | name = predfile.split(os.sep)[-1]
83 | PCKh = np.ma.array(PCKh, mask=False)
84 | PCKh.mask[6:8] = True
85 | print("Model, Head, Shoulder, Elbow, Wrist, Hip , Knee , Ankle , Mean")
86 | print('{:s} {:.2f} {:.2f} {:.2f} {:.2f} {:.2f} {:.2f} {:.2f} {:.2f}'.format(model_name, PCKh[head], 0.5 * (PCKh[lsho] + PCKh[rsho])\
87 | , 0.5 * (PCKh[lelb] + PCKh[relb]),0.5 * (PCKh[lwri] + PCKh[rwri]), 0.5 * (PCKh[lhip] + PCKh[rhip]), 0.5 * (PCKh[lkne] + PCKh[rkne]) \
88 | , 0.5 * (PCKh[lank] + PCKh[rank]), np.mean(PCKh)))
--------------------------------------------------------------------------------
/evaluation/showskeletons_joints.m:
--------------------------------------------------------------------------------
1 | function h = showskeletons_joints(im, points, pa, msize, torsobox)
2 | if nargin < 4
3 | msize = 4;
4 | end
5 | if nargin < 5
6 | torsobox = [];
7 | end
8 | p_no = numel(pa);
9 |
10 | switch p_no
11 | case 26
12 | partcolor = {'g','g','y','r','r','r','r','y','y','y','m','m','m','m','y','b','b','b','b','y','y','y','c','c','c','c'};
13 | case 14
14 | partcolor = {'g','g','y','r','r','y','m','m','y','b','b','y','c','c'};
15 | case 10
16 | partcolor = {'g','g','y','y','y','r','m','m','m','b','b','b','y','c','c'};
17 | case 18
18 | partcolor = {'g','g','y','r','r','r','r','y','y','y','y','b','b','b','b','y','y','y'};
19 | case 16
20 | partcolor = {'g','g','g','r','r','r','y','y','y','b','b','b','c','c','m','m'};
21 | otherwise
22 | error('showboxes: not supported');
23 | end
24 | h = imshow(im); hold on;
25 | if ~isempty(points)
26 | x = points(:,1);
27 | y = points(:,2);
28 | for n = 1:size(x,1)
29 | for child = 1:p_no
30 | if child == 0 || pa(child) == 0
31 | continue;
32 | end
33 | x1 = x(pa(child));
34 | y1 = y(pa(child));
35 | x2 = x(child);
36 | y2 = y(child);
37 |
38 | plot(x1, y1, 'o', 'color', partcolor{child}, ...
39 | 'MarkerSize',msize, 'MarkerFaceColor', partcolor{child});
40 | plot(x2, y2, 'o', 'color', partcolor{child}, ...
41 | 'MarkerSize',msize, 'MarkerFaceColor', partcolor{child});
42 | line([x1 x2],[y1 y2],'color',partcolor{child},'linewidth',round(msize/2));
43 | end
44 | end
45 | end
46 | if ~isempty(torsobox)
47 | plotbox(torsobox,'w--');
48 | end
49 | drawnow; hold off;
50 |
--------------------------------------------------------------------------------
/evaluation/utils.py:
--------------------------------------------------------------------------------
1 |
2 | def visualize(oriImg, points, pa):
3 | import matplotlib
4 | import cv2 as cv
5 | import matplotlib.pyplot as plt
6 | import math
7 |
8 | fig = matplotlib.pyplot.gcf()
9 | # fig.set_size_inches(12, 12)
10 |
11 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
12 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
13 | [170,0,255],[255,0,255]]
14 | canvas = oriImg
15 | stickwidth = 4
16 | x = points[:, 0]
17 | y = points[:, 1]
18 |
19 | for n in range(len(x)):
20 | for child in range(len(pa)):
21 |             if pa[child] == 0:  # root joint has no parent to connect
22 | continue
23 |
24 | x1 = x[pa[child] - 1]
25 | y1 = y[pa[child] - 1]
26 | x2 = x[child]
27 | y2 = y[child]
28 |
29 |             # cv.line expects integer pixel coordinates
30 |             cv.line(canvas, (int(x1), int(y1)), (int(x2), int(y2)), colors[child], 8)
30 |
31 |
32 | plt.imshow(canvas[:, :, [2, 1, 0]])
33 | fig = matplotlib.pyplot.gcf()
34 | fig.set_size_inches(12, 12)
35 |
36 | from time import gmtime, strftime
37 | import os
38 | directory = 'data/mpii/result/test_images'
39 | if not os.path.exists(directory):
40 | os.makedirs(directory)
41 |
42 | fn = os.path.join(directory, strftime("%Y-%m-%d-%H_%M_%S", gmtime()) + '.jpg')
43 |
44 | plt.savefig(fn)
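45 | 
46 | # Usage sketch (names illustrative): points is an N-by-2 array of joint
47 | # coordinates and pa a 1-based parent index per joint (0 for the root):
48 | #   visualize(cv2.imread('im0001.jpg'), points, pa)
49 | # The rendered skeleton is saved under data/mpii/result/test_images/.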
--------------------------------------------------------------------------------
/example/lsp.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | import os
4 | import argparse
5 | import time
6 | import matplotlib.pyplot as plt
7 |
8 | import torch
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim
12 | import torch.utils.data
13 |
14 | from pose import Bar
15 | from pose.utils.logger import Logger, savefig
16 | from pose.utils.evaluation import accuracy, AverageMeter, final_preds
17 | from pose.utils.misc import save_checkpoint, save_pred, adjust_learning_rate
18 | from pose.utils.osutils import mkdir_p, isfile, isdir, join
19 | from pose.utils.imutils import batch_with_heatmap
20 | from pose.utils.transforms import fliplr, flip_back
21 | import pose.models as models
22 | import pose.datasets as datasets
23 |
24 | model_names = sorted(name for name in models.__dict__
25 | if name.islower() and not name.startswith("__")
26 | and callable(models.__dict__[name]))
27 |
28 | idx = [1,2,3,4,5,6,11,12,15,16]
29 |
30 | best_acc = 0
31 |
32 |
33 | def main(args):
34 | global best_acc
35 |
36 | # create checkpoint dir
37 | if not isdir(args.checkpoint):
38 | mkdir_p(args.checkpoint)
39 |
40 | # create model
41 | print("==> creating model '{}', stacks={}, blocks={}".format(args.arch, args.stacks, args.blocks))
42 | model = models.__dict__[args.arch](num_stacks=args.stacks, num_blocks=args.blocks, num_classes=args.num_classes)
43 |
44 | model = torch.nn.DataParallel(model).cuda()
45 |
46 | # define loss function (criterion) and optimizer
47 |     criterion = torch.nn.MSELoss(reduction='mean').cuda()
48 |
49 | optimizer = torch.optim.RMSprop(model.parameters(),
50 | lr=args.lr,
51 | momentum=args.momentum,
52 | weight_decay=args.weight_decay)
53 |
54 | # optionally resume from a checkpoint
55 | title = 'LSP-' + args.arch
56 | if args.resume:
57 | if isfile(args.resume):
58 | print("=> loading checkpoint '{}'".format(args.resume))
59 | checkpoint = torch.load(args.resume)
60 | args.start_epoch = checkpoint['epoch']
61 | best_acc = checkpoint['best_acc']
62 | model.load_state_dict(checkpoint['state_dict'])
63 | optimizer.load_state_dict(checkpoint['optimizer'])
64 | print("=> loaded checkpoint '{}' (epoch {})"
65 | .format(args.resume, checkpoint['epoch']))
66 | logger = Logger(join(args.checkpoint, 'log.txt'), title=title, resume=True)
67 | else:
68 | print("=> no checkpoint found at '{}'".format(args.resume))
69 | else:
70 | logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
71 | logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])
72 |
73 | cudnn.benchmark = True
74 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))
75 |
76 | # Data loading code
77 | train_loader = torch.utils.data.DataLoader(
78 | datasets.LSP('data/lsp/LEEDS_annotations.json', 'data/lsp/images',
79 | sigma=args.sigma, label_type=args.label_type),
80 | batch_size=args.train_batch, shuffle=True,
81 | num_workers=args.workers, pin_memory=True)
82 |
83 | val_loader = torch.utils.data.DataLoader(
84 | datasets.LSP('data/lsp/LEEDS_annotations.json', 'data/lsp/images',
85 | sigma=args.sigma, label_type=args.label_type, train=False),
86 | batch_size=args.test_batch, shuffle=False,
87 | num_workers=args.workers, pin_memory=True)
88 |
89 | if args.evaluate:
90 | print('\nEvaluation only')
91 | loss, acc, predictions = validate(val_loader, model, criterion, args.num_classes, args.debug, args.flip)
92 | save_pred(predictions, checkpoint=args.checkpoint)
93 | return
94 |
95 | lr = args.lr
96 | for epoch in range(args.start_epoch, args.epochs):
97 | lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma)
98 | print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))
99 |
100 | # decay sigma
101 | if args.sigma_decay > 0:
102 | train_loader.dataset.sigma *= args.sigma_decay
103 | val_loader.dataset.sigma *= args.sigma_decay
104 |
105 | # train for one epoch
106 | train_loss, train_acc = train(train_loader, model, criterion, optimizer, args.debug, args.flip)
107 |
108 | # evaluate on validation set
109 | valid_loss, valid_acc, predictions = validate(val_loader, model, criterion, args.num_classes,
110 | args.debug, args.flip)
111 |
112 | # append logger file
113 | logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])
114 |
115 | # remember best acc and save checkpoint
116 | is_best = valid_acc > best_acc
117 | best_acc = max(valid_acc, best_acc)
118 | save_checkpoint({
119 | 'epoch': epoch + 1,
120 | 'arch': args.arch,
121 | 'state_dict': model.state_dict(),
122 | 'best_acc': best_acc,
123 | 'optimizer' : optimizer.state_dict(),
124 | }, predictions, is_best, checkpoint=args.checkpoint)
125 |
126 | logger.close()
127 | logger.plot(['Train Acc', 'Val Acc'])
128 | savefig(os.path.join(args.checkpoint, 'log.eps'))
129 |
130 |
131 | def train(train_loader, model, criterion, optimizer, debug=False, flip=True):
132 | batch_time = AverageMeter()
133 | data_time = AverageMeter()
134 | losses = AverageMeter()
135 | acces = AverageMeter()
136 |
137 | # switch to train mode
138 | model.train()
139 |
140 | end = time.time()
141 |
142 | gt_win, pred_win = None, None
143 | bar = Bar('Processing', max=len(train_loader))
144 | for i, (inputs, target, meta) in enumerate(train_loader):
145 | # measure data loading time
146 | data_time.update(time.time() - end)
147 |
148 |         input_var = inputs.cuda()
149 |         target_var = target.cuda(non_blocking=True)
150 |
151 | # compute output
152 | output = model(input_var)
153 | score_map = output[-1].data.cpu()
154 |
155 |         loss = criterion(output[0], target_var)  # intermediate supervision: sum the loss over all stacks
156 | for j in range(1, len(output)):
157 | loss += criterion(output[j], target_var)
158 | acc = accuracy(score_map, target, idx)
159 |
160 | if debug: # visualize groundtruth and predictions
161 | gt_batch_img = batch_with_heatmap(inputs, target)
162 | pred_batch_img = batch_with_heatmap(inputs, score_map)
163 | if not gt_win or not pred_win:
164 | ax1 = plt.subplot(121)
165 | ax1.title.set_text('Groundtruth')
166 | gt_win = plt.imshow(gt_batch_img)
167 | ax2 = plt.subplot(122)
168 | ax2.title.set_text('Prediction')
169 | pred_win = plt.imshow(pred_batch_img)
170 | else:
171 | gt_win.set_data(gt_batch_img)
172 | pred_win.set_data(pred_batch_img)
173 | plt.pause(.05)
174 | plt.draw()
175 |
176 | # measure accuracy and record loss
177 |         losses.update(loss.item(), inputs.size(0))
178 | acces.update(acc[0], inputs.size(0))
179 |
180 | # compute gradient and do SGD step
181 | optimizer.zero_grad()
182 | loss.backward()
183 | optimizer.step()
184 |
185 | # measure elapsed time
186 | batch_time.update(time.time() - end)
187 | end = time.time()
188 |
189 | # plot progress
190 | bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Acc: {acc: .4f}'.format(
191 | batch=i + 1,
192 | size=len(train_loader),
193 | data=data_time.val,
194 | bt=batch_time.val,
195 | total=bar.elapsed_td,
196 | eta=bar.eta_td,
197 | loss=losses.avg,
198 | acc=acces.avg
199 | )
200 | bar.next()
201 |
202 | bar.finish()
203 | return losses.avg, acces.avg
204 |
205 |
206 | def validate(val_loader, model, criterion, num_classes, debug=False, flip=True):
207 | batch_time = AverageMeter()
208 | data_time = AverageMeter()
209 | losses = AverageMeter()
210 | acces = AverageMeter()
211 |
212 | # predictions
213 | predictions = torch.Tensor(val_loader.dataset.__len__(), num_classes, 2)
214 |
215 | # switch to evaluate mode
216 | model.eval()
217 |
218 | gt_win, pred_win = None, None
219 | end = time.time()
220 | bar = Bar('Processing', max=len(val_loader))
221 | for i, (inputs, target, meta) in enumerate(val_loader):
222 | # measure data loading time
223 | data_time.update(time.time() - end)
224 |
225 |         target = target.cuda(non_blocking=True)
226 | 
227 |         input_var = inputs.cuda()
228 |         target_var = target
229 | 
230 |         # compute output without tracking gradients
231 |         with torch.no_grad():
232 |             output = model(input_var)
233 |             score_map = output[-1].data.cpu()
234 |             if flip:
235 |                 # also score the horizontally flipped input and flip it back
236 |                 flip_input_var = torch.from_numpy(
237 |                     fliplr(inputs.clone().numpy())).float().cuda()
238 |                 flip_output_var = model(flip_input_var)
239 |                 flip_output = flip_back(flip_output_var[-1].data.cpu())
240 |                 score_map += flip_output
241 |
242 |
243 |
244 | loss = 0
245 | for o in output:
246 | loss += criterion(o, target_var)
247 | acc = accuracy(score_map, target.cpu(), idx)
248 |
249 | # generate predictions
250 | preds = final_preds(score_map, meta['center'], meta['scale'], [64, 64])
251 | for n in range(score_map.size(0)):
252 | predictions[meta['index'][n], :, :] = preds[n, :, :]
253 |
254 |
255 | if debug:
256 | gt_batch_img = batch_with_heatmap(inputs, target)
257 | pred_batch_img = batch_with_heatmap(inputs, score_map)
258 | if not gt_win or not pred_win:
259 | plt.subplot(121)
260 | gt_win = plt.imshow(gt_batch_img)
261 | plt.subplot(122)
262 | pred_win = plt.imshow(pred_batch_img)
263 | else:
264 | gt_win.set_data(gt_batch_img)
265 | pred_win.set_data(pred_batch_img)
266 | plt.pause(.05)
267 | plt.draw()
268 |
269 | # measure accuracy and record loss
270 |         losses.update(loss.item(), inputs.size(0))
271 | acces.update(acc[0], inputs.size(0))
272 |
273 | # measure elapsed time
274 | batch_time.update(time.time() - end)
275 | end = time.time()
276 |
277 | # plot progress
278 | bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Acc: {acc: .4f}'.format(
279 | batch=i + 1,
280 | size=len(val_loader),
281 | data=data_time.val,
282 | bt=batch_time.avg,
283 | total=bar.elapsed_td,
284 | eta=bar.eta_td,
285 | loss=losses.avg,
286 | acc=acces.avg
287 | )
288 | bar.next()
289 |
290 | bar.finish()
291 | return losses.avg, acces.avg, predictions
292 |
293 | if __name__ == '__main__':
294 |     parser = argparse.ArgumentParser(description='PyTorch LSP Pose Training')
295 | # Model structure
296 | parser.add_argument('--arch', '-a', metavar='ARCH', default='hg',
297 | choices=model_names,
298 | help='model architecture: ' +
299 | ' | '.join(model_names) +
300 |                             ' (default: hg)')
301 | parser.add_argument('-s', '--stacks', default=8, type=int, metavar='N',
302 | help='Number of hourglasses to stack')
303 | parser.add_argument('--features', default=256, type=int, metavar='N',
304 | help='Number of features in the hourglass')
305 | parser.add_argument('-b', '--blocks', default=1, type=int, metavar='N',
306 | help='Number of residual modules at each location in the hourglass')
307 | parser.add_argument('--num-classes', default=16, type=int, metavar='N',
308 | help='Number of keypoints')
309 | # Training strategy
310 | parser.add_argument('-j', '--workers', default=1, type=int, metavar='N',
311 |                         help='number of data loading workers (default: 1)')
312 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
313 | help='number of total epochs to run')
314 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
315 | help='manual epoch number (useful on restarts)')
316 | parser.add_argument('--train-batch', default=6, type=int, metavar='N',
317 | help='train batchsize')
318 | parser.add_argument('--test-batch', default=6, type=int, metavar='N',
319 | help='test batchsize')
320 | parser.add_argument('--lr', '--learning-rate', default=2.5e-4, type=float,
321 | metavar='LR', help='initial learning rate')
322 | parser.add_argument('--momentum', default=0, type=float, metavar='M',
323 | help='momentum')
324 | parser.add_argument('--weight-decay', '--wd', default=0, type=float,
325 | metavar='W', help='weight decay (default: 0)')
326 | parser.add_argument('--schedule', type=int, nargs='+', default=[60, 90],
327 | help='Decrease learning rate at these epochs.')
328 | parser.add_argument('--gamma', type=float, default=0.1,
329 | help='LR is multiplied by gamma on schedule.')
330 | # Data processing
331 | parser.add_argument('-f', '--flip', dest='flip', action='store_true',
332 | help='flip the input during validation')
333 | parser.add_argument('--sigma', type=float, default=1,
334 | help='Groundtruth Gaussian sigma.')
335 | parser.add_argument('--sigma-decay', type=float, default=0,
336 | help='Sigma decay rate for each epoch.')
337 | parser.add_argument('--label-type', metavar='LABELTYPE', default='Gaussian',
338 | choices=['Gaussian', 'Cauchy'],
339 | help='Labelmap dist type: (default=Gaussian)')
340 | # Miscs
341 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
342 | help='path to save checkpoint (default: checkpoint)')
343 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
344 | help='path to latest checkpoint (default: none)')
345 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
346 | help='evaluate model on validation set')
347 | parser.add_argument('-d', '--debug', dest='debug', action='store_true',
348 | help='show intermediate results')
349 |
350 |
351 | main(parser.parse_args())
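352 | 
353 | # Usage sketch (flag values illustrative; the best-model filename follows
354 | # save_checkpoint's convention, assumed to be model_best.pth.tar):
355 | #   python example/lsp.py -a hg -s 2 -b 1 -c checkpoint/lsp/hg2
356 | #   python example/lsp.py -a hg -s 2 -b 1 -c checkpoint/lsp/hg2 -e -f \
357 | #       --resume checkpoint/lsp/hg2/model_best.pth.tar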
--------------------------------------------------------------------------------
/example/mpii.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | import os
4 | import argparse
5 | import time
6 | import matplotlib.pyplot as plt
7 |
8 | import torch
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim
12 | import torch.utils.data
13 |
14 | from pose import Bar
15 | from pose.utils.logger import Logger, savefig
16 | from pose.utils.evaluation import accuracy, AverageMeter, final_preds
17 | from pose.utils.misc import save_checkpoint, save_pred, adjust_learning_rate
18 | from pose.utils.osutils import mkdir_p, isfile, isdir, join
19 | from pose.utils.imutils import batch_with_heatmap
20 | from pose.utils.transforms import fliplr, flip_back
21 | import pose.models as models
22 | import pose.datasets as datasets
23 |
24 | model_names = sorted(name for name in models.__dict__
25 | if name.islower() and not name.startswith("__")
26 | and callable(models.__dict__[name]))
27 |
28 | idx = [1,2,3,4,5,6,11,12,15,16]
29 |
30 | best_acc = 0
31 |
32 |
33 | def main(args):
34 | global best_acc
35 |
36 | # create checkpoint dir
37 | if not isdir(args.checkpoint):
38 | mkdir_p(args.checkpoint)
39 |
40 | # create model
41 | print("==> creating model '{}', stacks={}, blocks={}".format(args.arch, args.stacks, args.blocks))
42 | model = models.__dict__[args.arch](num_stacks=args.stacks, num_blocks=args.blocks, num_classes=args.num_classes, mobile=args.mobile)
43 |
44 | model = torch.nn.DataParallel(model).cuda()
45 |
46 | # define loss function (criterion) and optimizer
47 |     criterion = torch.nn.MSELoss(reduction='mean').cuda()
48 |
49 | optimizer = torch.optim.RMSprop(model.parameters(),
50 | lr=args.lr,
51 | momentum=args.momentum,
52 | weight_decay=args.weight_decay)
53 |
54 | # optionally resume from a checkpoint
55 | title = 'mpii-' + args.arch
56 | if args.resume:
57 | if isfile(args.resume):
58 | print("=> loading checkpoint '{}'".format(args.resume))
59 | checkpoint = torch.load(args.resume)
60 | args.start_epoch = checkpoint['epoch']
61 | best_acc = checkpoint['best_acc']
62 | model.load_state_dict(checkpoint['state_dict'])
63 | optimizer.load_state_dict(checkpoint['optimizer'])
64 | print("=> loaded checkpoint '{}' (epoch {})"
65 | .format(args.resume, checkpoint['epoch']))
66 | logger = Logger(join(args.checkpoint, 'log.txt'), title=title, resume=True)
67 | else:
68 | print("=> no checkpoint found at '{}'".format(args.resume))
69 | else:
70 | logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
71 | logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])
72 |
73 | cudnn.benchmark = True
74 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))
75 |
76 | # Data loading code
77 | train_loader = torch.utils.data.DataLoader(
78 | datasets.Mpii('data/mpii/mpii_annotations.json', 'data/mpii/images',
79 | sigma=args.sigma, label_type=args.label_type,
80 | inp_res=args.in_res, out_res=args.in_res//4),
81 | batch_size=args.train_batch, shuffle=True,
82 | num_workers=args.workers, pin_memory=True)
83 |
84 | val_loader = torch.utils.data.DataLoader(
85 | datasets.Mpii('data/mpii/mpii_annotations.json', 'data/mpii/images',
86 | sigma=args.sigma, label_type=args.label_type, train=False,
87 | inp_res=args.in_res, out_res=args.in_res // 4),
88 | batch_size=args.test_batch, shuffle=False,
89 | num_workers=args.workers, pin_memory=True)
90 |
91 | if args.evaluate:
92 | print('\nEvaluation only')
93 | loss, acc, predictions = validate(val_loader, model, criterion, args.num_classes, args.in_res//4, args.debug, args.flip)
94 | save_pred(predictions, checkpoint=args.checkpoint)
95 | return
96 |
97 | lr = args.lr
98 | for epoch in range(args.start_epoch, args.epochs):
99 | lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma)
100 | print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))
101 |
102 | # decay sigma
103 | if args.sigma_decay > 0:
104 | train_loader.dataset.sigma *= args.sigma_decay
105 | val_loader.dataset.sigma *= args.sigma_decay
106 |
107 | # train for one epoch
108 | train_loss, train_acc = train(train_loader, model, criterion, optimizer, args.debug, args.flip)
109 |
110 | # evaluate on validation set
111 | valid_loss, valid_acc, predictions = validate(val_loader, model, criterion, args.num_classes,
112 | args.in_res//4, args.debug, args.flip)
113 |
114 | # append logger file
115 | logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])
116 |
117 | # remember best acc and save checkpoint
118 | is_best = valid_acc > best_acc
119 | best_acc = max(valid_acc, best_acc)
120 | save_checkpoint({
121 | 'epoch': epoch + 1,
122 | 'arch': args.arch,
123 | 'state_dict': model.state_dict(),
124 | 'best_acc': best_acc,
125 | 'optimizer' : optimizer.state_dict(),
126 | }, predictions, is_best, checkpoint=args.checkpoint)
127 |
128 | logger.close()
129 | logger.plot(['Train Acc', 'Val Acc'])
130 | savefig(os.path.join(args.checkpoint, 'log.eps'))
131 |
132 |
133 | def train(train_loader, model, criterion, optimizer, debug=False, flip=True):
134 | batch_time = AverageMeter()
135 | data_time = AverageMeter()
136 | losses = AverageMeter()
137 | acces = AverageMeter()
138 |
139 | # switch to train mode
140 | model.train()
141 |
142 | end = time.time()
143 |
144 | gt_win, pred_win = None, None
145 | bar = Bar('Processing', max=len(train_loader))
146 | for i, (inputs, target, meta) in enumerate(train_loader):
147 | # measure data loading time
148 | data_time.update(time.time() - end)
149 |
150 |         input_var = inputs.cuda()
151 |         target_var = target.cuda(non_blocking=True)
152 |
153 | # compute output
154 | output = model(input_var)
155 | score_map = output[-1].data.cpu()
156 |
157 |         loss = criterion(output[0], target_var)  # intermediate supervision: sum the loss over all stacks
158 | for j in range(1, len(output)):
159 | loss += criterion(output[j], target_var)
160 | acc = accuracy(score_map, target, idx)
161 |
162 | if debug: # visualize groundtruth and predictions
163 | gt_batch_img = batch_with_heatmap(inputs, target)
164 | pred_batch_img = batch_with_heatmap(inputs, score_map)
165 | if not gt_win or not pred_win:
166 | ax1 = plt.subplot(121)
167 | ax1.title.set_text('Groundtruth')
168 | gt_win = plt.imshow(gt_batch_img)
169 | ax2 = plt.subplot(122)
170 | ax2.title.set_text('Prediction')
171 | pred_win = plt.imshow(pred_batch_img)
172 | else:
173 | gt_win.set_data(gt_batch_img)
174 | pred_win.set_data(pred_batch_img)
175 | plt.pause(.05)
176 | plt.draw()
177 |
178 | # measure accuracy and record loss
179 | losses.update(loss.item(), inputs.size(0))
180 | acces.update(acc[0], inputs.size(0))
181 |
182 | # compute gradient and do SGD step
183 | optimizer.zero_grad()
184 | loss.backward()
185 | optimizer.step()
186 |
187 | # measure elapsed time
188 | batch_time.update(time.time() - end)
189 | end = time.time()
190 |
191 | # plot progress
192 | bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Acc: {acc: .4f}'.format(
193 | batch=i + 1,
194 | size=len(train_loader),
195 | data=data_time.val,
196 | bt=batch_time.val,
197 | total=bar.elapsed_td,
198 | eta=bar.eta_td,
199 | loss=losses.avg,
200 | acc=acces.avg
201 | )
202 | bar.next()
203 |
204 | bar.finish()
205 | return losses.avg, acces.avg
206 |
207 |
208 | def validate(val_loader, model, criterion, num_classes, out_res, debug=False, flip=True):
209 | batch_time = AverageMeter()
210 | data_time = AverageMeter()
211 | losses = AverageMeter()
212 | acces = AverageMeter()
213 |
214 | # predictions
215 | predictions = torch.Tensor(val_loader.dataset.__len__(), num_classes, 2)
216 |
217 | # switch to evaluate mode
218 | model.eval()
219 |
220 | gt_win, pred_win = None, None
221 | end = time.time()
222 | bar = Bar('Processing', max=len(val_loader))
223 | for i, (inputs, target, meta) in enumerate(val_loader):
224 | # measure data loading time
225 | data_time.update(time.time() - end)
226 |
227 |         target = target.cuda(non_blocking=True)
228 | 
229 |         input_var = inputs.cuda()
230 |         target_var = target
231 | 
232 |         # compute output without tracking gradients
233 |         with torch.no_grad():
234 |             output = model(input_var)
235 |             score_map = output[-1].data.cpu()
236 |             if flip:
237 |                 # also score the horizontally flipped input and flip it back
238 |                 flip_input_var = torch.from_numpy(
239 |                     fliplr(inputs.clone().numpy())).float().cuda()
240 |                 flip_output_var = model(flip_input_var)
241 |                 flip_output = flip_back(flip_output_var[-1].data.cpu())
242 |                 score_map += flip_output
243 |
244 |
245 |
246 | loss = 0
247 | for o in output:
248 | loss += criterion(o, target_var)
249 | acc = accuracy(score_map, target.cpu(), idx)
250 |
251 | # generate predictions
252 | preds = final_preds(score_map, meta['center'], meta['scale'], [out_res, out_res])
253 | for n in range(score_map.size(0)):
254 | predictions[meta['index'][n], :, :] = preds[n, :, :]
255 |
256 |
257 | if debug:
258 | gt_batch_img = batch_with_heatmap(inputs, target)
259 | pred_batch_img = batch_with_heatmap(inputs, score_map)
260 | if not gt_win or not pred_win:
261 | plt.subplot(121)
262 | gt_win = plt.imshow(gt_batch_img)
263 | plt.subplot(122)
264 | pred_win = plt.imshow(pred_batch_img)
265 | else:
266 | gt_win.set_data(gt_batch_img)
267 | pred_win.set_data(pred_batch_img)
268 | plt.pause(.05)
269 | plt.draw()
270 |
271 | # measure accuracy and record loss
272 | losses.update(loss.item(), inputs.size(0))
273 | acces.update(acc[0], inputs.size(0))
274 |
275 | # measure elapsed time
276 | batch_time.update(time.time() - end)
277 | end = time.time()
278 |
279 | # plot progress
280 | bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Acc: {acc: .4f}'.format(
281 | batch=i + 1,
282 | size=len(val_loader),
283 | data=data_time.val,
284 | bt=batch_time.avg,
285 | total=bar.elapsed_td,
286 | eta=bar.eta_td,
287 | loss=losses.avg,
288 | acc=acces.avg
289 | )
290 | bar.next()
291 |
292 | bar.finish()
293 | return losses.avg, acces.avg, predictions
294 |
295 | if __name__ == '__main__':
296 |     parser = argparse.ArgumentParser(description='PyTorch MPII Pose Training')
297 | # Model structure
298 | parser.add_argument('--arch', '-a', metavar='ARCH', default='hg',
299 | choices=model_names,
300 | help='model architecture: ' +
301 | ' | '.join(model_names) +
302 |                             ' (default: hg)')
303 | parser.add_argument('-s', '--stacks', default=8, type=int, metavar='N',
304 | help='Number of hourglasses to stack')
305 | parser.add_argument('--features', default=256, type=int, metavar='N',
306 | help='Number of features in the hourglass')
307 | parser.add_argument('-b', '--blocks', default=1, type=int, metavar='N',
308 | help='Number of residual modules at each location in the hourglass')
309 | parser.add_argument('--num-classes', default=16, type=int, metavar='N',
310 | help='Number of keypoints')
311 |     parser.add_argument('--mobile', action='store_true',
312 |                         help='use depthwise convolutions in the bottleneck blocks')
313 | # Training strategy
314 | parser.add_argument('-j', '--workers', default=1, type=int, metavar='N',
315 |                         help='number of data loading workers (default: 1)')
316 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
317 | help='number of total epochs to run')
318 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
319 | help='manual epoch number (useful on restarts)')
320 | parser.add_argument('--train-batch', default=6, type=int, metavar='N',
321 | help='train batchsize')
322 | parser.add_argument('--test-batch', default=6, type=int, metavar='N',
323 | help='test batchsize')
324 | parser.add_argument('--lr', '--learning-rate', default=2.5e-4, type=float,
325 | metavar='LR', help='initial learning rate')
326 | parser.add_argument('--momentum', default=0, type=float, metavar='M',
327 | help='momentum')
328 | parser.add_argument('--weight-decay', '--wd', default=0, type=float,
329 | metavar='W', help='weight decay (default: 0)')
330 | parser.add_argument('--schedule', type=int, nargs='+', default=[60, 90],
331 | help='Decrease learning rate at these epochs.')
332 | parser.add_argument('--gamma', type=float, default=0.1,
333 | help='LR is multiplied by gamma on schedule.')
334 | # Data processing
335 | parser.add_argument('-f', '--flip', dest='flip', action='store_true',
336 | help='flip the input during validation')
337 | parser.add_argument('--sigma', type=float, default=1,
338 | help='Groundtruth Gaussian sigma.')
339 | parser.add_argument('--sigma-decay', type=float, default=0,
340 | help='Sigma decay rate for each epoch.')
341 | parser.add_argument('--label-type', metavar='LABELTYPE', default='Gaussian',
342 | choices=['Gaussian', 'Cauchy'],
343 | help='Labelmap dist type: (default=Gaussian)')
344 | parser.add_argument('--in_res', default=256, type=int,
345 | choices=[256, 192],
346 | help='input resolution for network')
347 | # Miscs
348 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
349 | help='path to save checkpoint (default: checkpoint)')
350 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
351 | help='path to latest checkpoint (default: none)')
352 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
353 | help='evaluate model on validation set')
354 | parser.add_argument('-d', '--debug', dest='debug', action='store_true',
355 | help='show intermediate results')
356 |
357 | main(parser.parse_args())
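358 | 
359 | # Usage sketch (flag values illustrative): train an 8-stack hourglass at
360 | # 256x256 input, then evaluate it with horizontal-flip testing:
361 | #   python example/mpii.py -a hg -s 8 -b 1 --in_res 256 -c checkpoint/mpii/hg8
362 | #   python example/mpii.py -a hg -s 8 -b 1 -c checkpoint/mpii/hg8 -e -f \
363 | #       --resume checkpoint/mpii/hg8/model_best.pth.tar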
--------------------------------------------------------------------------------
/example/mscoco.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | import os
4 | import argparse
5 | import time
6 | import matplotlib.pyplot as plt
7 |
8 | import torch
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim
12 | import torch.utils.data
13 |
14 | from pose import Bar
15 | from pose.utils.logger import Logger, savefig
16 | from pose.utils.evaluation import accuracy, AverageMeter, final_preds
17 | from pose.utils.misc import save_checkpoint, save_pred, adjust_learning_rate
18 | from pose.utils.osutils import mkdir_p, isfile, isdir, join
19 | from pose.utils.imutils import batch_with_heatmap
20 | from pose.utils.transforms import fliplr, flip_back
21 | import pose.models as models
22 | import pose.datasets as datasets
23 |
24 | model_names = sorted(name for name in models.__dict__
25 | if name.islower() and not name.startswith("__")
26 | and callable(models.__dict__[name]))
27 |
28 | idx = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17]
29 |
30 | best_acc = 0
31 |
32 |
33 | def main(args):
34 | global best_acc
35 |
36 | # create checkpoint dir
37 | if not isdir(args.checkpoint):
38 | mkdir_p(args.checkpoint)
39 |
40 | # create model
41 | print("==> creating model '{}', stacks={}, blocks={}".format(args.arch, args.stacks, args.blocks))
42 | model = models.__dict__[args.arch](num_stacks=args.stacks, num_blocks=args.blocks, num_classes=args.num_classes)
43 |
44 | model = torch.nn.DataParallel(model).cuda()
45 |
46 | # define loss function (criterion) and optimizer
47 |     criterion = torch.nn.MSELoss(reduction='mean').cuda()
48 |
49 | optimizer = torch.optim.RMSprop(model.parameters(),
50 | lr=args.lr,
51 | momentum=args.momentum,
52 | weight_decay=args.weight_decay)
53 |
54 | # optionally resume from a checkpoint
55 | title = 'MSCOCO-' + args.arch
56 | if args.resume:
57 | if isfile(args.resume):
58 | print("=> loading checkpoint '{}'".format(args.resume))
59 | checkpoint = torch.load(args.resume)
60 | args.start_epoch = checkpoint['epoch']
61 | best_acc = checkpoint['best_acc']
62 | model.load_state_dict(checkpoint['state_dict'])
63 | optimizer.load_state_dict(checkpoint['optimizer'])
64 | print("=> loaded checkpoint '{}' (epoch {})"
65 | .format(args.resume, checkpoint['epoch']))
66 | logger = Logger(join(args.checkpoint, 'log.txt'), title=title, resume=True)
67 | else:
68 | print("=> no checkpoint found at '{}'".format(args.resume))
69 | else:
70 | logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
71 | logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])
72 |
73 | cudnn.benchmark = True
74 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))
75 |
76 | # Data loading code
77 | train_loader = torch.utils.data.DataLoader(
78 | datasets.Mscoco('data/mscoco/coco_annotations.json', 'data/mscoco/keypoint/images/train2014',
79 | sigma=args.sigma, label_type=args.label_type),
80 | batch_size=args.train_batch, shuffle=True,
81 | num_workers=args.workers, pin_memory=True)
82 |
83 | val_loader = torch.utils.data.DataLoader(
84 | datasets.Mscoco('data/mscoco/coco_annotations.json', 'data/mscoco/keypoint/images/val2014',
85 | sigma=args.sigma, label_type=args.label_type, train=False),
86 | batch_size=args.test_batch, shuffle=False,
87 | num_workers=args.workers, pin_memory=True)
88 |
89 | if args.evaluate:
90 | print('\nEvaluation only')
91 | loss, acc, predictions = validate(val_loader, model, criterion, args.num_classes, args.debug, args.flip)
92 | save_pred(predictions, checkpoint=args.checkpoint)
93 | return
94 |
95 | lr = args.lr
96 | for epoch in range(args.start_epoch, args.epochs):
97 | lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma)
98 | print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))
99 |
100 | # train for one epoch
101 | train_loss, train_acc = train(train_loader, model, criterion, optimizer, args.debug, args.flip)
102 |
103 | # evaluate on validation set
104 | valid_loss, valid_acc, predictions = validate(val_loader, model, criterion, args.num_classes, args.debug, args.flip)
105 |
106 | # append logger file
107 | logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])
108 |
109 | # remember best acc and save checkpoint
110 | is_best = valid_acc > best_acc
111 | best_acc = max(valid_acc, best_acc)
112 | save_checkpoint({
113 | 'epoch': epoch + 1,
114 | 'arch': args.arch,
115 | 'state_dict': model.state_dict(),
116 | 'best_acc': best_acc,
117 | 'optimizer' : optimizer.state_dict(),
118 | }, predictions, is_best, checkpoint=args.checkpoint, snapshot=args.snapshot)
119 |
120 | logger.close()
121 | logger.plot(['Train Acc', 'Val Acc'])
122 | savefig(os.path.join(args.checkpoint, 'log.eps'))
123 |
124 |
125 | def train(train_loader, model, criterion, optimizer, debug=False, flip=True):
126 | batch_time = AverageMeter()
127 | data_time = AverageMeter()
128 | losses = AverageMeter()
129 | acces = AverageMeter()
130 |
131 | # switch to train mode
132 | model.train()
133 |
134 | end = time.time()
135 |
136 | gt_win, pred_win = None, None
137 | bar = Bar('Processing', max=len(train_loader))
138 | for i, (inputs, target, meta) in enumerate(train_loader):
139 | # measure data loading time
140 | data_time.update(time.time() - end)
141 |
142 |         input_var = inputs.cuda()
143 |         target_var = target.cuda(non_blocking=True)
144 |
145 | # compute output
146 | output = model(input_var)
147 | score_map = output[-1].data.cpu()
148 |
149 |         loss = criterion(output[0], target_var)  # intermediate supervision: sum the loss over all stacks
150 | for j in range(1, len(output)):
151 | loss += criterion(output[j], target_var)
152 | acc = accuracy(score_map, target, idx)
153 |
154 | if debug: # visualize groundtruth and predictions
155 | gt_batch_img = batch_with_heatmap(inputs, target)
156 | pred_batch_img = batch_with_heatmap(inputs, score_map)
157 | if not gt_win or not pred_win:
158 | ax1 = plt.subplot(121)
159 | ax1.title.set_text('Groundtruth')
160 | gt_win = plt.imshow(gt_batch_img)
161 | ax2 = plt.subplot(122)
162 | ax2.title.set_text('Prediction')
163 | pred_win = plt.imshow(pred_batch_img)
164 | else:
165 | gt_win.set_data(gt_batch_img)
166 | pred_win.set_data(pred_batch_img)
167 | plt.pause(.05)
168 | plt.draw()
169 |
170 | # measure accuracy and record loss
171 |         losses.update(loss.item(), inputs.size(0))
172 | acces.update(acc[0], inputs.size(0))
173 |
174 | # compute gradient and do SGD step
175 | optimizer.zero_grad()
176 | loss.backward()
177 | optimizer.step()
178 |
179 | # measure elapsed time
180 | batch_time.update(time.time() - end)
181 | end = time.time()
182 |
183 | # plot progress
184 | bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Acc: {acc: .4f}'.format(
185 | batch=i + 1,
186 | size=len(train_loader),
187 | data=data_time.val,
188 | bt=batch_time.val,
189 | total=bar.elapsed_td,
190 | eta=bar.eta_td,
191 | loss=losses.avg,
192 | acc=acces.avg
193 | )
194 | bar.next()
195 |
196 | bar.finish()
197 | return losses.avg, acces.avg
198 |
199 |
200 | def validate(val_loader, model, criterion, num_classes, debug=False, flip=True):
201 | batch_time = AverageMeter()
202 | data_time = AverageMeter()
203 | losses = AverageMeter()
204 | acces = AverageMeter()
205 |
206 | # predictions
207 | predictions = torch.Tensor(val_loader.dataset.__len__(), num_classes, 2)
208 |
209 | # switch to evaluate mode
210 | model.eval()
211 |
212 | gt_win, pred_win = None, None
213 | end = time.time()
214 | bar = Bar('Processing', max=len(val_loader))
215 | for i, (inputs, target, meta) in enumerate(val_loader):
216 | # measure data loading time
217 | data_time.update(time.time() - end)
218 |
219 |         target = target.cuda(non_blocking=True)
220 | 
221 |         input_var = inputs.cuda()
222 |         target_var = target
223 | 
224 |         # compute output without tracking gradients
225 |         with torch.no_grad():
226 |             output = model(input_var)
227 |             score_map = output[-1].data.cpu()
228 |             if flip:
229 |                 # also score the horizontally flipped input and flip it back
230 |                 flip_input_var = torch.from_numpy(
231 |                     fliplr(inputs.clone().numpy())).float().cuda()
232 |                 flip_output_var = model(flip_input_var)
233 |                 flip_output = flip_back(flip_output_var[-1].data.cpu())
234 |                 score_map += flip_output
235 |
236 |
237 |
238 | loss = 0
239 | for o in output:
240 | loss += criterion(o, target_var)
241 | acc = accuracy(score_map, target.cpu(), idx)
242 |
243 | # generate predictions
244 | preds = final_preds(score_map, meta['center'], meta['scale'], [64, 64])
245 | for n in range(score_map.size(0)):
246 | predictions[meta['index'][n], :, :] = preds[n, :, :]
247 |
248 |
249 | if debug:
250 | gt_batch_img = batch_with_heatmap(inputs, target)
251 | pred_batch_img = batch_with_heatmap(inputs, score_map)
252 | if not gt_win or not pred_win:
253 | plt.subplot(121)
254 | gt_win = plt.imshow(gt_batch_img)
255 | plt.subplot(122)
256 | pred_win = plt.imshow(pred_batch_img)
257 | else:
258 | gt_win.set_data(gt_batch_img)
259 | pred_win.set_data(pred_batch_img)
260 | plt.pause(.5)
261 | plt.draw()
262 |
263 | # measure accuracy and record loss
264 |         losses.update(loss.item(), inputs.size(0))
265 | acces.update(acc[0], inputs.size(0))
266 |
267 | # measure elapsed time
268 | batch_time.update(time.time() - end)
269 | end = time.time()
270 |
271 | # plot progress
272 | bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Acc: {acc: .4f}'.format(
273 | batch=i + 1,
274 | size=len(val_loader),
275 | data=data_time.val,
276 | bt=batch_time.avg,
277 | total=bar.elapsed_td,
278 | eta=bar.eta_td,
279 | loss=losses.avg,
280 | acc=acces.avg
281 | )
282 | bar.next()
283 |
284 | bar.finish()
285 | return losses.avg, acces.avg, predictions
286 |
287 | if __name__ == '__main__':
288 |     parser = argparse.ArgumentParser(description='PyTorch MSCOCO Keypoint Training')
289 | parser.add_argument('--arch', '-a', metavar='ARCH', default='hg',
290 | choices=model_names,
291 | help='model architecture: ' +
292 | ' | '.join(model_names) +
293 |                             ' (default: hg)')
294 | parser.add_argument('--num-classes', default=17, type=int, metavar='N',
295 | help='Number of keypoints')
296 | parser.add_argument('-j', '--workers', default=1, type=int, metavar='N',
297 |                         help='number of data loading workers (default: 1)')
298 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
299 | help='number of total epochs to run')
300 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
301 | help='manual epoch number (useful on restarts)')
302 | parser.add_argument('--snapshot', default=0, type=int, metavar='N',
303 | help='How often to take a snapshot of the model (0 = never)')
304 | parser.add_argument('--train-batch', default=6, type=int, metavar='N',
305 | help='train batchsize')
306 | parser.add_argument('--test-batch', default=6, type=int, metavar='N',
307 | help='test batchsize')
308 | parser.add_argument('--lr', '--learning-rate', default=2.5e-4, type=float,
309 | metavar='LR', help='initial learning rate')
310 | parser.add_argument('--momentum', default=0, type=float, metavar='M',
311 | help='momentum')
312 | parser.add_argument('--weight-decay', '--wd', default=0, type=float,
313 | metavar='W', help='weight decay (default: 0)')
314 | parser.add_argument('--schedule', type=int, nargs='+', default=[60, 90],
315 | help='Decrease learning rate at these epochs.')
316 | parser.add_argument('--gamma', type=float, default=0.1,
317 | help='LR is multiplied by gamma on schedule.')
318 | parser.add_argument('--sigma', type=float, default=1,
319 | help='Sigma to generate Gaussian groundtruth map.')
320 | parser.add_argument('--label-type', metavar='LABELTYPE', default='Gaussian',
321 | choices=['Gaussian', 'Cauchy'],
322 | help='Labelmap dist type: (default=Gaussian)')
323 | parser.add_argument('--print-freq', '-p', default=10, type=int,
324 | metavar='N', help='print frequency (default: 10)')
325 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
326 | help='path to save checkpoint (default: checkpoint)')
327 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
328 | help='path to latest checkpoint (default: none)')
329 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
330 | help='evaluate model on validation set')
331 | parser.add_argument('-d', '--debug', dest='debug', action='store_true',
332 | help='show intermediate results')
333 | parser.add_argument('-f', '--flip', dest='flip', action='store_true',
334 | help='flip the input during validation')
335 | # Model structure
336 | parser.add_argument('-s', '--stacks', default=8, type=int, metavar='N',
337 | help='Number of hourglasses to stack')
338 | parser.add_argument('--features', default=256, type=int, metavar='N',
339 | help='Number of features in the hourglass')
340 | parser.add_argument('-b', '--blocks', default=1, type=int, metavar='N',
341 | help='Number of residual modules at each location in the hourglass')
342 |
343 |
344 | main(parser.parse_args())
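345 | 
346 | # Usage sketch (flag values illustrative): --snapshot N writes an extra
347 | # checkpoint every N epochs (0 disables snapshots):
348 | #   python example/mscoco.py -a hg -s 8 -b 1 --snapshot 5 -c checkpoint/coco/hg8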
--------------------------------------------------------------------------------
/miscs/cocoScale.m:
--------------------------------------------------------------------------------
1 | function scale = cocoScale(x, y, v)
2 | % Mean distance on MPII dataset
3 | % rtorso, ltorso, rlleg, ruleg, luleg, llleg,
4 | % rlarm, ruarm, luarm, llarm, head
5 | meandist = [59.3535, 60.4532, 52.1800, 53.7957, 54.4153, 58.0402, ...
6 | 27.0043, 32.8498, 33.1757, 27.0978, 33.3005];
7 |
8 | sk = {[13, 7], [6, 12], [17, 15], [15, 13], [12, 14], [14, 16], ...
9 | [11, 9], [9, 7], [6, 8], [8, 10]};
10 |
11 | scale = -1;
12 | for i=1:length(sk),
13 | s=sk{i};
14 | if(all(v(s)>0)),
15 | scale = norm([x(s(1))-x(s(2)), y(s(1))-y(s(2))])/meandist(i);
16 | break;
17 | end;
18 | end
19 |
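20 | % Returned value: the ratio between the first visible limb's pixel length
21 | % and that limb's mean length on MPII (a rough person scale); -1 means no
22 | % limb had both endpoints labeled, so the scale cannot be computed.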
--------------------------------------------------------------------------------
/miscs/gen_coco.m:
--------------------------------------------------------------------------------
1 | %% Generate JSON file for MSCOCO keypoint data
2 | clear all; close all;
3 | addpath('jsonlab/')
4 | addpath('../data/mscoco/coco/MatlabAPI/');
5 | trainval = [1, 0];
6 | personCnt = 0;
7 | DEBUG = false;
8 |
9 | for isv = trainval
10 | isValidation = isv;
11 | %% initialize COCO api (please specify dataType/annType below)
12 | annTypes = {'person_keypoints' };
13 | if isValidation
14 | dataType='val5k2014'; annType=annTypes{1}; % specify dataType/annType
15 | else
16 | dataType='train2014'; annType=annTypes{1}; % specify dataType/annType
17 | end
18 |
19 |
20 | annFile=sprintf('../data/mscoco/keypoint/person_keypoints_train+val5k2014/%s_%s.json',annType,dataType);
21 | coco=CocoApi(annFile);
22 |
23 | %% display COCO categories and supercategories
24 | if( ~strcmp(annType,'captions') )
25 | cats = coco.loadCats(coco.getCatIds());
26 | sk = cats.skeleton; % get skeleton
27 | skc = {'m', 'm', 'g', 'g', 'y', 'r', 'b', 'y', ... % 1-8
28 | 'r', 'b', 'r', 'b', 'c', 'c', 'c', 'c', 'c', 'y', 'y'};
29 | nms={cats.name}; fprintf('COCO categories: ');
30 | fprintf('%s, ',nms{:}); fprintf('\n');
31 | nms=unique({cats.supercategory}); fprintf('COCO supercategories: ');
32 | fprintf('%s, ',nms{:}); fprintf('\n');
33 | end
34 |
35 | %% get all images containing given categories, select one at random
36 | catIds = coco.getCatIds('catNms',{'person'});
37 | imgIds = coco.getImgIds('catIds',catIds);
38 | imgId = imgIds(randi(length(imgIds)));
39 |
40 | imageCnt = 0;
41 | keypointsCnt = zeros(17, 1);
42 | fullBodyCnt = 0;
43 | meanArea = 0;
44 |
45 | for i = 1:length(imgIds)
46 | fprintf('%d | %d\n', i, length(imgIds));
47 | imgId = imgIds(i);
48 | img = coco.loadImgs(imgId);
49 |
50 | %% load and display annotations
51 | annIds = coco.getAnnIds('imgIds',imgId,'catIds',catIds,'iscrowd',[]);
52 | anns = coco.loadAnns(annIds);
53 | n=length(anns);
54 | hasKeypoints = false;
55 | for j=1:n
56 | a=anns(j); if(a.iscrowd), continue; end; hold on;
57 | if a.num_keypoints > 0
58 | hasKeypoints = true;
59 |
60 | kp=a.keypoints;
61 | x=kp(1:3:end)+1; y=kp(2:3:end)+1; v=kp(3:3:end);
62 | vi = find(v > 0);
63 | keypointsCnt(vi) = keypointsCnt(vi) + 1;
64 | meanArea = meanArea + a.area;
65 |
66 | scale = cocoScale(x, y, v);
67 |                 % if scale == -1 % cannot compute scale
68 |                 if scale <= 0 % cannot compute scale
69 | continue;
70 | end
71 |
72 | assert(scale ~= 0);
73 | personCnt = personCnt + 1;
74 |
75 | % write to json
76 | joint_all(personCnt).dataset = 'coco';
77 | joint_all(personCnt).isValidation = isValidation;
78 | 
79 |
80 | joint_all(personCnt).img_paths = img.file_name;
81 | joint_all(personCnt).objpos = [mean(x(v>0)), mean(y(v>0))];
82 | joint_all(personCnt).joint_self = [x; y; v]';
83 | joint_all(personCnt).scale_provided = scale;
84 |
85 | if DEBUG
86 | if isValidation
87 | datadir ='val2014'; annType=annTypes{1}; % specify dataType/annType
88 | else
89 | datadir='train2014'; annType=annTypes{1}; % specify dataType/annType
90 | end
91 | I = imread(sprintf('../data/mscoco/keypoint/images/%s/%s',datadir,joint_all(personCnt).img_paths));
92 | I = imresize(I, 1/joint_all(personCnt).scale_provided);
93 | imshow(I); hold on;
94 | x1 = x/scale;
95 | y1 = y/scale;
96 | objpos = joint_all(personCnt).objpos/scale;
97 | show_skeleton(x1, y1, v, sk, skc);
98 | viscircles(objpos,5)
99 | pause;close;
100 | end
101 | end
102 |
103 | if a.num_keypoints == 17
104 | fullBodyCnt = fullBodyCnt + 1;
105 | end
106 | end
107 |
108 | if hasKeypoints
109 | imageCnt = imageCnt + 1;
110 | end
111 | end
112 | end
113 | fprintf('save %d person\n', personCnt);
114 |
115 | opt.FileName = '../data/mscoco/coco_annotations.json';
116 | opt.FloatFormat = '%.3f';
117 | opt.Compact = 1;
118 | savejson('', joint_all, opt);
119 |
120 |
121 | %
122 | % clc;
123 | %
124 | % fprintf('validation: images: %d | persons: %d\n', imageCnt, personCnt);
125 | %
126 | % fprintf('%s\n', strjoin(cats.keypoints,', '))
127 | % for i = 1:length(cats.keypoints)
128 | % fprintf('%d, ', keypointsCnt(i));
129 | % end
130 | %
131 | % fprintf('\nFull body cnt: %d\n', fullBodyCnt);
132 | % fprintf('mean area: %.4f\n', meanArea/personCnt);
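133 | 
134 | % The JSON written above is what example/mscoco.py loads from
135 | % data/mscoco/coco_annotations.json; each entry carries img_paths,
136 | % objpos, joint_self ([x y v] rows) and scale_provided.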
--------------------------------------------------------------------------------
/miscs/gen_lsp.m:
--------------------------------------------------------------------------------
1 | % Dataset link
2 | % LSP: http://sam.johnson.io/research/lsp.html
3 | % LSP extend: http://sam.johnson.io/research/lspet.html
4 | function gen_lsp
5 | addpath('jsonlab/')
6 | % in cpp: real scale = param_.target_dist()/meta.scale_self = (41/35)/scale_input
7 | targetDist = 41/35; % in caffe cpp file 41/35
8 | oriTrTe = load('/home/wyang/Data/dataset/LSP/joints.mat');
9 | extTrain = load('/home/wyang/Data/dataset/lspet_dataset/joints.mat');
10 |
11 | % in LEEDS:
12 | % 1 Right ankle
13 | % 2 Right knee
14 | % 3 Right hip
15 | % 4 Left hip
16 | % 5 Left knee
17 | % 6 Left ankle
18 | % 7 Right wrist
19 | % 8 Right elbow
20 | % 9 Right shoulder
21 | % 10 Left shoulder
22 | % 11 Left elbow
23 | % 12 Left wrist
24 | % 13 Neck
25 | % 14 Head top
26 | % 15,16 DUMMY
27 | % We want to comply to MPII: (1 - r ankle, 2 - r knee, 3 - r hip, 4 - l hip, 5 - l knee, 6 - l ankle, ..
28 | % 7 - pelvis, 8 - thorax, 9 - upper neck, 10 - head top,
29 | % 11 - r wrist, 12 - r elbow, 13 - r shoulder, 14 - l shoulder, 15 - l elbow, 16 - l wrist)
30 | ordering = [1 2 3, 4 5 6, 15 16, 13 14, 7 8 9, 10 11 12]; % should follow MPI 16 parts..?
31 | oriTrTe.joints(:,[15 16],:) = 0;
32 | oriTrTe.joints = oriTrTe.joints(:,ordering,:);
33 | oriTrTe.joints(3,:,:) = 1 - oriTrTe.joints(3,:,:);
34 | oriTrTe.joints = permute(oriTrTe.joints, [2 1 3]);
35 |
36 | % pelvis
37 | oriTrTe.joints(7, 1:2, :) = mean(oriTrTe.joints(3:4,1:2,:));
38 | v1 = oriTrTe.joints(3,3,:) > 0;
39 | v2 = oriTrTe.joints(4,3,:) > 0;
40 | v = find(v1 .* v2 == 1);
41 | oriTrTe.joints(7, 3, v) = 1;
42 |
43 | % thorax
44 | oriTrTe.joints(8, 1:2, :) = mean(oriTrTe.joints(13:14,1:2,:));
45 | v1 = oriTrTe.joints(13,3,:) > 0;
46 | v2 = oriTrTe.joints(14,3,:) > 0;
47 | v = find(v1 .* v2 == 1);
48 | oriTrTe.joints(8, 3, v) = 1;
49 |
50 | extTrain.joints([15 16],:,:) = 0;
51 | extTrain.joints = extTrain.joints(ordering,:,:);
52 |
53 |
54 | % pelvis
55 | extTrain.joints(7, 1:2, :) = mean(extTrain.joints(3:4,1:2,:));
56 | extTrain.joints(7, 3, :) = 1;
57 |
58 | % thorax
59 | extTrain.joints(8, 1:2, :) = mean(extTrain.joints(13:14,1:2,:));
60 | extTrain.joints(8, 3, :) = 1;
61 |
62 | count = 1;
63 |
64 | path = {'lspet_dataset/images/im%05d.jpg', 'lsp_dataset/images/im%04d.jpg'};
65 | local_path = {'/home/wyang/Data/dataset/lspet_dataset/images/im%05d.jpg', '/home/wyang/Data/dataset/LSP/images/im%04d.jpg'};
66 | num_image = [10000, 1000]; %[10000, 2000];
67 |
68 | for dataset = 1:2
69 | for im = 1:num_image(dataset)
70 | % trivial stuff for LEEDS
71 | joint_all(count).dataset = 'LEEDS';
72 | joint_all(count).isValidation = 0;
73 | joint_all(count).img_paths = sprintf(path{dataset}, im);
74 | joint_all(count).numOtherPeople = 0;
75 | joint_all(count).annolist_index = count;
76 | joint_all(count).people_index = 1;
77 | % joints and w, h
78 | if(dataset == 1)
79 | joint_this = extTrain.joints(:,:,im);
80 | else
81 | joint_this = oriTrTe.joints(:,:,im);
82 | end
83 | path_this = sprintf(local_path{dataset}, im);
84 | [h,w,~] = size(imread(path_this));
85 |
86 | joint_all(count).img_width = w;
87 | joint_all(count).img_height = h;
88 | joint_all(count).joint_self = joint_this;
89 | % infer objpos
90 | invisible = (joint_all(count).joint_self(:,3) == 0);
91 | if(dataset == 1) %lspet is not tightly cropped
92 | joint_all(count).objpos(1) = (min(joint_all(count).joint_self(~invisible, 1)) + max(joint_all(count).joint_self(~invisible, 1))) / 2;
93 | joint_all(count).objpos(2) = (min(joint_all(count).joint_self(~invisible, 2)) + max(joint_all(count).joint_self(~invisible, 2))) / 2;
94 | else
95 | joint_all(count).objpos(1) = w/2;
96 | joint_all(count).objpos(2) = h/2;
97 | end
98 |
99 | count = count + 1;
100 | fprintf('processing %s\n', path_this);
101 | end
102 | end
103 |
104 | % ---- test data
105 | dataset = 2;
106 | for im = 1001:2000
107 | % trivial stuff for LEEDS
108 | joint_all(count).dataset = 'LEEDS';
109 | joint_all(count).isValidation = 1;
110 | joint_all(count).img_paths = sprintf(path{dataset}, im);
111 | joint_all(count).numOtherPeople = 0;
112 | joint_all(count).annolist_index = count;
113 | joint_all(count).people_index = 1;
114 | % joints and w, h
115 | if(dataset == 1)
116 | joint_this = extTrain.joints(:,:,im);
117 | else
118 | joint_this = oriTrTe.joints(:,:,im);
119 | end
120 | path_this = sprintf(local_path{dataset}, im);
121 | [h,w,~] = size(imread(path_this));
122 |
123 | joint_all(count).img_width = w;
124 | joint_all(count).img_height = h;
125 | joint_all(count).joint_self = joint_this;
126 | % infer objpos
127 | invisible = (joint_all(count).joint_self(:,3) == 0);
128 | if(dataset == 1) %lspet is not tightly cropped
129 | joint_all(count).objpos(1) = (min(joint_all(count).joint_self(~invisible, 1)) + max(joint_all(count).joint_self(~invisible, 1))) / 2;
130 | joint_all(count).objpos(2) = (min(joint_all(count).joint_self(~invisible, 2)) + max(joint_all(count).joint_self(~invisible, 2))) / 2;
131 | else
132 | joint_all(count).objpos(1) = w/2;
133 | joint_all(count).objpos(2) = h/2;
134 | end
135 |
136 | count = count + 1;
137 | fprintf('processing %s\n', path_this);
138 | end
139 |
140 |
141 |
142 | joint_all = insertMPILikeScale(joint_all, targetDist);
143 |
144 |
145 | opt.FileName = '../data/lsp/LEEDS_annotations.json';
146 | opt.FloatFormat = '%.3f';
147 | opt.Compact = 1;
148 | savejson('', joint_all, opt);
149 |
150 |
151 | function joint_all = insertMPILikeScale(joint_all, targetDist)
152 | % calculate scales for each image first
153 | joints = cat(3, joint_all.joint_self);
154 | joints([7 8],:,:) = [];
155 | pa = [2 3 7, 5 4 7, 8 0, 10 11 7, 13 12 7];
156 | x = permute(joints(:,1,:), [3 1 2]);
157 | y = permute(joints(:,2,:), [3 1 2]);
158 | vis = permute(joints(:,3,:), [3 1 2]);
159 | validLimb = 1:14-1;
160 |
161 | x_diff = x(:, [1:7,9:14]) - x(:, pa([1:7,9:14]));
162 | y_diff = y(:, [1:7,9:14]) - y(:, pa([1:7,9:14]));
163 | limb_vis = vis(:, [1:7,9:14]) .* vis(:, pa([1:7,9:14]));
164 | l = sqrt(x_diff.^2 + y_diff.^2);
165 |
166 | for p = 1:14-1 % for each limb. reference: 7th limb, which is 7 to pa(7) (neck to head)
167 | valid_compare = limb_vis(:,7) .* limb_vis(:,p);
168 | ratio = l(valid_compare==1, p) ./ l(valid_compare==1, 7);
169 | r(p) = median(ratio(~isnan(ratio), 1));
170 | end
171 |
172 | numFiles = size(x_diff, 1);
173 | all_scales = zeros(numFiles, 1);
174 |
175 | boxSize = 368;
176 | psize = 64;
177 | nSqueezed = 0;
178 |
179 | for file = 1:numFiles %numFiles
180 | l_update = l(file, validLimb) ./ r(validLimb);
181 | l_update = l_update(limb_vis(file,:)==1);
182 | distToObserve = quantile(l_update, 0.75);
183 | scale_in_lmdb = distToObserve/35; % can't get too small. 35 is a magic number to balance to MPI
184 | scale_in_cpp = targetDist/scale_in_lmdb; % can't get too large to be cropped
185 |
186 | visibleParts = joints(:, 3, file);
187 | visibleParts = joints(visibleParts==1, 1:2, file);
188 | x_range = max(visibleParts(:,1)) - min(visibleParts(:,1));
189 | y_range = max(visibleParts(:,2)) - min(visibleParts(:,2));
190 | scale_x_ub = (boxSize - psize)/x_range;
191 | scale_y_ub = (boxSize - psize)/y_range;
192 |
193 | scale_shrink = min(min(scale_x_ub, scale_y_ub), scale_in_cpp);
194 |
195 | if scale_shrink ~= scale_in_cpp
196 | nSqueezed = nSqueezed + 1;
197 | fprintf('img %d: scale = %f %f %f shrink %d\n', file, scale_in_cpp, scale_shrink, min(scale_x_ub, scale_y_ub), nSqueezed);
198 | else
199 | fprintf('img %d: scale = %f %f %f\n', file, scale_in_cpp, scale_shrink, min(scale_x_ub, scale_y_ub));
200 | end
201 |
202 | joint_all(file).scale_provided = targetDist/scale_shrink; % back to lmdb unit
203 | end
204 |
205 | fprintf('total %d squeezed!\n', nSqueezed);
206 |
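207 | % Note: scale_provided is stored in the CPM "lmdb" unit
208 | % (targetDist/scale_shrink), matching the convention in the comment at
209 | % the top of this file, so the crop factor is recovered downstream as
210 | % targetDist/scale_provided.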
--------------------------------------------------------------------------------
/miscs/gen_mpii.m:
--------------------------------------------------------------------------------
1 | % Generate MPII train/validation split (Tompson et al. CVPR 2015)
2 | % Code ported from
3 | % https://github.com/shihenw/convolutional-pose-machines-release/blob/master/training/genJSON.m
4 | %
5 | % in MPI: (0 - r ankle, 1 - r knee, 2 - r hip, 3 - l hip, 4 - l knee,
6 | % 5 - l ankle, 6 - pelvis, 7 - thorax, 8 - upper neck, 9 - head top,
7 | % 10 - r wrist, 11 - r elbow, 12 - r shoulder, 13 - l shoulder,
8 | % 14 - l elbow, 15 - l wrist)"
9 |
10 |
11 | addpath('jsonlab/')
12 |
13 | % Download MPII http://human-pose.mpi-inf.mpg.de/#download
14 | MPIIROOT = '/home/wyang/Data/dataset/mpii';
15 |
16 | % Download Tompson split from
17 | % http://www.cims.nyu.edu/~tompson/data/mpii_valid_pred.zip
18 | TOMPSONROOT = '/home/wyang/Data/dataset/mpii/Tompson_valid';
19 |
20 | mat = load(fullfile(MPIIROOT, '/mpii_human_pose_v1_u12_1/mpii_human_pose_v1_u12_1.mat'));
21 | RELEASE = mat.RELEASE;
22 | trainIdx = find(RELEASE.img_train);
23 |
24 | tompson = load(fullfile(TOMPSONROOT, '/mpii_predictions/data/detections'));
25 | tompson_i_p = [tompson.RELEASE_img_index; tompson.RELEASE_person_index];
26 |
27 | count = 1;
28 | validationCount = 0;
29 | trainCount = 0;
30 |
31 | makeFigure = 0; % Set as 1 for visualizing annotations
32 |
33 | for i = trainIdx
34 | numPeople = length(RELEASE.annolist(i).annorect);
35 | fprintf('image: %d (numPeople: %d) last: %d\n', i, numPeople, trainIdx(end));
36 |
37 | for p = 1:numPeople
38 | loc = find(sum(~bsxfun(@minus, tompson_i_p, [i;p]))==2, 1);
39 | loc2 = find(tompson.RELEASE_img_index == i);
40 | if(~isempty(loc))
41 | validationCount = validationCount + 1;
42 | isValidation = 1;
43 | elseif (isempty(loc2))
44 | trainCount = trainCount + 1;
45 | isValidation = 0;
46 | else
47 | continue;
48 | end
49 | joint_all(count).dataset = 'MPI';
50 | joint_all(count).isValidation = isValidation;
51 |
52 | try % sometimes no annotation at all....
53 | anno = RELEASE.annolist(i).annorect(p).annopoints.point;
54 | catch
55 | continue;
56 | end
57 |
58 | % set image path
59 | joint_all(count).img_paths = RELEASE.annolist(i).image.name;
60 | [h,w,~] = size(imread(fullfile(MPIIROOT, '/images/', joint_all(count).img_paths)));
61 | joint_all(count).img_width = w;
62 | joint_all(count).img_height = h;
63 | joint_all(count).objpos = [RELEASE.annolist(i).annorect(p).objpos.x, RELEASE.annolist(i).annorect(p).objpos.y];
64 | % set part label: joint_all is (np-3-nTrain)
65 |
66 |
67 | % for this very center person
68 | for part = 1:length(anno)
69 | joint_all(count).joint_self(anno(part).id+1, 1) = anno(part).x;
70 | joint_all(count).joint_self(anno(part).id+1, 2) = anno(part).y;
71 | try % sometimes no is_visible...
72 | if(anno(part).is_visible == 0 || anno(part).is_visible == '0')
73 | joint_all(count).joint_self(anno(part).id+1, 3) = 0;
74 | else
75 | joint_all(count).joint_self(anno(part).id+1, 3) = 1;
76 | end
77 | catch
78 | joint_all(count).joint_self(anno(part).id+1, 3) = 1;
79 | end
80 | end
81 |
82 | % pad it into 16x3
83 | dim_1 = size(joint_all(count).joint_self, 1);
84 | dim_3 = size(joint_all(count).joint_self, 3);
85 | pad_dim = 16 - dim_1;
86 | joint_all(count).joint_self = [joint_all(count).joint_self; zeros(pad_dim, 3, dim_3)];
87 |
88 | % set scale
89 | joint_all(count).scale_provided = RELEASE.annolist(i).annorect(p).scale;
90 |
91 | % for other person on the same image
92 | count_other = 1;
93 | joint_all(count).joint_others = cell(0,0);
94 | for op = 1:numPeople
95 | if(op == p), continue; end
96 | try % sometimes no annotation at all....
97 | anno = RELEASE.annolist(i).annorect(op).annopoints.point;
98 | catch
99 | continue;
100 | end
101 | joint_all(count).joint_others{count_other} = zeros(16,3);
102 | for part = 1:length(anno)
103 | joint_all(count).joint_others{count_other}(anno(part).id+1, 1) = anno(part).x;
104 | joint_all(count).joint_others{count_other}(anno(part).id+1, 2) = anno(part).y;
105 | try % sometimes no is_visible...
106 | if(anno(part).is_visible == 0 || anno(part).is_visible == '0')
107 | joint_all(count).joint_others{count_other}(anno(part).id+1, 3) = 0;
108 | else
109 | joint_all(count).joint_others{count_other}(anno(part).id+1, 3) = 1;
110 | end
111 | catch
112 | joint_all(count).joint_others{count_other}(anno(part).id+1, 3) = 1;
113 | end
114 | end
115 | % pad it into 16x3
116 | dim_1 = size(joint_all(count).joint_others{count_other}, 1);
117 | dim_3 = size(joint_all(count).joint_others{count_other}, 3);
118 | pad_dim = 16 - dim_1;
119 | joint_all(count).joint_others{count_other} = [joint_all(count).joint_others{count_other}; zeros(pad_dim, 3, dim_3)];
120 |
121 | joint_all(count).scale_provided_other(count_other) = RELEASE.annolist(i).annorect(op).scale;
122 | joint_all(count).objpos_other{count_other} = [RELEASE.annolist(i).annorect(op).objpos.x RELEASE.annolist(i).annorect(op).objpos.y];
123 |
124 | count_other = count_other + 1;
125 | end
126 |
127 | if(makeFigure) % visualizing to debug
128 | imshow(imread(fullfile(MPIIROOT, '/images/', joint_all(count).img_paths)));
129 | hold on;
130 | visiblePart = joint_all(count).joint_self(:,3) == 1;
131 | invisiblePart = joint_all(count).joint_self(:,3) == 0;
132 | plot(joint_all(count).joint_self(visiblePart, 1), joint_all(count).joint_self(visiblePart,2), 'gx', 'MarkerSize', 10);
133 | plot(joint_all(count).joint_self(invisiblePart,1), joint_all(count).joint_self(invisiblePart,2), 'rx', 'MarkerSize', 10);
134 | plot(joint_all(count).objpos(1), joint_all(count).objpos(2), 'cs');
135 | if(~isempty(joint_all(count).joint_others))
136 | for op = 1:length(joint_all(count).joint_others)
137 | visiblePart = joint_all(count).joint_others{op}(:,3) == 1;
138 | invisiblePart = joint_all(count).joint_others{op}(:,3) == 0;
139 | plot(joint_all(count).joint_others{op}(visiblePart,1), joint_all(count).joint_others{op}(visiblePart,2), 'mx', 'MarkerSize', 10);
140 | plot(joint_all(count).joint_others{op}(invisiblePart,1), joint_all(count).joint_others{op}(invisiblePart,2), 'cx', 'MarkerSize', 10);
141 | end
142 | end
143 | pause;
144 | close all;
145 | end
146 | joint_all(count).annolist_index = i;
147 | joint_all(count).people_index = p;
148 | joint_all(count).numOtherPeople = length(joint_all(count).joint_others);
149 | count = count + 1;
150 | end
151 | end
152 |
153 | opt.FileName = '../data/mpii/mpii_annotations.json';
154 | opt.FloatFormat = '%.3f';
155 | opt.Compact = 1;
156 | savejson('', joint_all, opt);
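157 |
158 | % The resulting JSON is consumed by pose/datasets/mpii.py; each entry carries
159 | % dataset, isValidation, img_paths, img_width, img_height, objpos, joint_self,
160 | % scale_provided, the *_other fields for the remaining people, annolist_index,
161 | % people_index and numOtherPeople.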
--------------------------------------------------------------------------------
/pose/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from . import datasets
4 | from . import models
5 | from . import utils
6 |
7 | import os, sys
8 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
9 | from progress.bar import Bar as Bar
10 |
11 | __version__ = '0.1.0'
--------------------------------------------------------------------------------
/pose/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .mpii import Mpii
2 | from .mscoco import Mscoco
3 | from .lsp import LSP
4 |
5 | __all__ = ('Mpii', 'Mscoco', 'LSP')
--------------------------------------------------------------------------------
/pose/datasets/lsp.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | import os
4 | import numpy as np
5 | import json
6 | import random
7 | import math
8 |
9 | import torch
10 | import torch.utils.data as data
11 |
12 | from pose.utils.osutils import *
13 | from pose.utils.imutils import *
14 | from pose.utils.transforms import *
15 |
16 |
17 | class LSP(data.Dataset):
18 | """
19 | LSP extended dataset (11,000 train, 1000 test)
20 | Original datasets contain 14 keypoints. We interpolate mid-hip and mid-shoulder and change the indices to match
21 | the MPII dataset (16 keypoints).
22 |
23 | Wei Yang (bearpaw@GitHub)
24 | 2017-09-28
25 | """
26 | def __init__(self, jsonfile, img_folder, inp_res=256, out_res=64, train=True, sigma=1,
27 | scale_factor=0.25, rot_factor=30, label_type='Gaussian'):
28 | self.img_folder = img_folder # root image folders
29 | self.is_train = train # training set or test set
30 | self.inp_res = inp_res
31 | self.out_res = out_res
32 | self.sigma = sigma
33 | self.scale_factor = scale_factor
34 | self.rot_factor = rot_factor
35 | self.label_type = label_type
36 |
37 | # create train/val split
38 | with open(jsonfile) as anno_file:
39 | self.anno = json.load(anno_file)
40 |
41 | self.train, self.valid = [], []
42 | for idx, val in enumerate(self.anno):
43 | if val['isValidation'] == True:
44 | self.valid.append(idx)
45 | else:
46 | self.train.append(idx)
47 | self.mean, self.std = self._compute_mean()
48 |
49 | def _compute_mean(self):
50 | meanstd_file = './data/lsp/mean.pth.tar'
51 | if isfile(meanstd_file):
52 | meanstd = torch.load(meanstd_file)
53 | else:
54 | mean = torch.zeros(3)
55 | std = torch.zeros(3)
56 | for index in self.train:
57 | a = self.anno[index]
58 | img_path = os.path.join(self.img_folder, a['img_paths'])
59 | img = load_image(img_path) # CxHxW
60 | mean += img.view(img.size(0), -1).mean(1)
61 | std += img.view(img.size(0), -1).std(1)
62 | mean /= len(self.train)
63 | std /= len(self.train)
64 | meanstd = {
65 | 'mean': mean,
66 | 'std': std,
67 | }
68 | torch.save(meanstd, meanstd_file)
69 | if self.is_train:
70 | print(' Mean: %.4f, %.4f, %.4f' % (meanstd['mean'][0], meanstd['mean'][1], meanstd['mean'][2]))
71 | print(' Std: %.4f, %.4f, %.4f' % (meanstd['std'][0], meanstd['std'][1], meanstd['std'][2]))
72 |
73 | return meanstd['mean'], meanstd['std']
74 |
75 | def __getitem__(self, index):
76 | sf = self.scale_factor
77 | rf = self.rot_factor
78 | if self.is_train:
79 | a = self.anno[self.train[index]]
80 | else:
81 | a = self.anno[self.valid[index]]
82 |
83 | img_path = os.path.join(self.img_folder, a['img_paths'])
84 | pts = torch.Tensor(a['joint_self'])
85 | # pts[:, 0:2] -= 1 # Convert pts to zero based
86 |
87 | # c = torch.Tensor(a['objpos']) - 1
88 | c = torch.Tensor(a['objpos'])
89 | s = a['scale_provided']
90 |
91 | # Adjust center/scale slightly to avoid cropping limbs
92 | if c[0] != -1:
93 | # c[1] = c[1] + 15 * s
94 | s = s * 1.4375
95 |
96 | # For single-person pose estimation with a centered/scaled figure
97 | nparts = pts.size(0)
98 | img = load_image(img_path) # CxHxW
99 |
100 | r = 0
101 | # if self.is_train:
102 | # s = s*torch.randn(1).mul_(sf).add_(1).clamp(1-sf, 1+sf)[0]
103 | # r = torch.randn(1).mul_(rf).clamp(-2*rf, 2*rf)[0] if random.random() <= 0.6 else 0
104 | #
105 | # # # Flip
106 | # # if random.random() <= 0.5:
107 | # # img = torch.from_numpy(fliplr(img.numpy())).float()
108 | # # pts = shufflelr(pts, width=img.size(2), dataset='mpii')
109 | # # c[0] = img.size(2) - c[0]
110 | #
111 | # # Color
112 | # img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
113 | # img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
114 | # img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
115 |
116 | # Prepare image and groundtruth map
117 | inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r)
118 | inp = color_normalize(inp, self.mean, self.std)
119 |
120 | # Generate ground truth
121 | tpts = pts.clone()
122 | target = torch.zeros(nparts, self.out_res, self.out_res)
123 | for i in range(nparts):
124 |             # Use the x coordinate, not the visibility flag tpts[i, 2], to decide whether to draw a target
125 |             if tpts[i, 0] > 0:
126 | tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2]+1, c, s, [self.out_res, self.out_res], rot=r))
127 | target[i] = draw_labelmap(target[i], tpts[i]-1, self.sigma, type=self.label_type)
128 |
129 | # Meta info
130 | meta = {'index' : index, 'center' : c, 'scale' : s,
131 | 'pts' : pts, 'tpts' : tpts}
132 |
133 | return inp, target, meta
134 |
135 | def __len__(self):
136 | if self.is_train:
137 | return len(self.train)
138 | else:
139 | return len(self.valid)
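140 |
141 | # Usage sketch (paths are illustrative; point them at your own annotation JSON
142 | # and LSP image root):
143 | #   dataset = LSP('data/lsp/lsp_annotations.json', 'data/lsp/images', train=True)
144 | #   inp, target, meta = dataset[0]  # inp: 3x256x256, target: 16x64x64 heatmaps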
--------------------------------------------------------------------------------
/pose/datasets/mpii.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | import os
4 | import numpy as np
5 | import json
6 | import random
7 | import math
8 |
9 | import torch
10 | import torch.utils.data as data
11 |
12 | from pose.utils.osutils import *
13 | from pose.utils.imutils import *
14 | from pose.utils.transforms import *
15 |
16 |
17 | class Mpii(data.Dataset):
18 | def __init__(self, jsonfile, img_folder, inp_res=256, out_res=64, train=True, sigma=1,
19 | scale_factor=0.25, rot_factor=30, label_type='Gaussian'):
20 | self.img_folder = img_folder # root image folders
21 | self.is_train = train # training set or test set
22 | self.inp_res = inp_res
23 | self.out_res = out_res
24 | self.sigma = sigma
25 | self.scale_factor = scale_factor
26 | self.rot_factor = rot_factor
27 | self.label_type = label_type
28 |
29 | # create train/val split
30 | with open(jsonfile) as anno_file:
31 | self.anno = json.load(anno_file)
32 |
33 | self.train, self.valid = [], []
34 | for idx, val in enumerate(self.anno):
35 | if val['isValidation'] == True:
36 | self.valid.append(idx)
37 | else:
38 | self.train.append(idx)
39 | self.mean, self.std = self._compute_mean()
40 |
41 | def _compute_mean(self):
42 | meanstd_file = './data/mpii/mean.pth.tar'
43 | if isfile(meanstd_file):
44 | meanstd = torch.load(meanstd_file)
45 | else:
46 | mean = torch.zeros(3)
47 | std = torch.zeros(3)
48 | for index in self.train:
49 | a = self.anno[index]
50 | img_path = os.path.join(self.img_folder, a['img_paths'])
51 | img = load_image(img_path) # CxHxW
52 | mean += img.view(img.size(0), -1).mean(1)
53 | std += img.view(img.size(0), -1).std(1)
54 | mean /= len(self.train)
55 | std /= len(self.train)
56 | meanstd = {
57 | 'mean': mean,
58 | 'std': std,
59 | }
60 | torch.save(meanstd, meanstd_file)
61 | if self.is_train:
62 | print(' Mean: %.4f, %.4f, %.4f' % (meanstd['mean'][0], meanstd['mean'][1], meanstd['mean'][2]))
63 | print(' Std: %.4f, %.4f, %.4f' % (meanstd['std'][0], meanstd['std'][1], meanstd['std'][2]))
64 |
65 | return meanstd['mean'], meanstd['std']
66 |
67 | def __getitem__(self, index):
68 | sf = self.scale_factor
69 | rf = self.rot_factor
70 | if self.is_train:
71 | a = self.anno[self.train[index]]
72 | else:
73 | a = self.anno[self.valid[index]]
74 |
75 | img_path = os.path.join(self.img_folder, a['img_paths'])
76 | pts = torch.Tensor(a['joint_self'])
77 | # pts[:, 0:2] -= 1 # Convert pts to zero based
78 |
79 | # c = torch.Tensor(a['objpos']) - 1
80 | c = torch.Tensor(a['objpos'])
81 | s = a['scale_provided']
82 |
83 | # Adjust center/scale slightly to avoid cropping limbs
84 | if c[0] != -1:
85 | c[1] = c[1] + 15 * s
86 | s = s * 1.25
87 |
88 | # For single-person pose estimation with a centered/scaled figure
89 | nparts = pts.size(0)
90 | img = load_image(img_path) # CxHxW
91 |
92 | r = 0
93 | if self.is_train:
94 | s = s*torch.randn(1).mul_(sf).add_(1).clamp(1-sf, 1+sf)[0]
95 | r = torch.randn(1).mul_(rf).clamp(-2*rf, 2*rf)[0] if random.random() <= 0.6 else 0
96 |
97 | # Flip
98 | if random.random() <= 0.5:
99 | img = torch.from_numpy(fliplr(img.numpy())).float()
100 | pts = shufflelr(pts, width=img.size(2), dataset='mpii')
101 | c[0] = img.size(2) - c[0]
102 |
103 | # Color
104 | img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
105 | img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
106 | img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
107 |
108 | # Prepare image and groundtruth map
109 | inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r)
110 | inp = color_normalize(inp, self.mean, self.std)
111 |
112 | # Generate ground truth
113 | tpts = pts.clone()
114 | target = torch.zeros(nparts, self.out_res, self.out_res)
115 | for i in range(nparts):
116 |             # Use the y coordinate, not the visibility flag tpts[i, 2], to decide whether to draw a target
117 |             if tpts[i, 1] > 0:
118 | tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2]+1, c, s, [self.out_res, self.out_res], rot=r))
119 | target[i] = draw_labelmap(target[i], tpts[i]-1, self.sigma, type=self.label_type)
120 |
121 | # Meta info
122 | meta = {'index' : index, 'center' : c, 'scale' : s,
123 | 'pts' : pts, 'tpts' : tpts}
124 |
125 | return inp, target, meta
126 |
127 | def __len__(self):
128 | if self.is_train:
129 | return len(self.train)
130 | else:
131 | return len(self.valid)
132 |
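133 | # Usage sketch (data/mpii/mpii_annotations.json ships with this repo; the image
134 | # folder path is illustrative):
135 | #   train_set = Mpii('data/mpii/mpii_annotations.json', 'data/mpii/images', train=True)
136 | #   loader = torch.utils.data.DataLoader(train_set, batch_size=6, shuffle=True)
137 | #   for inp, target, meta in loader:  # inp: Bx3x256x256, target: Bx16x64x64
138 | #       pass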
--------------------------------------------------------------------------------
/pose/datasets/mscoco.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | import os
4 | import numpy as np
5 | import json
6 | import random
7 | import math
8 |
9 | import torch
10 | import torch.utils.data as data
11 |
12 | from pose.utils.osutils import *
13 | from pose.utils.imutils import *
14 | from pose.utils.transforms import *
15 |
16 |
17 | class Mscoco(data.Dataset):
18 | def __init__(self, jsonfile, img_folder, inp_res=256, out_res=64, train=True, sigma=1,
19 | scale_factor=0.25, rot_factor=30, label_type='Gaussian'):
20 | self.img_folder = img_folder # root image folders
21 | self.is_train = train # training set or test set
22 | self.inp_res = inp_res
23 | self.out_res = out_res
24 | self.sigma = sigma
25 | self.scale_factor = scale_factor
26 | self.rot_factor = rot_factor
27 | self.label_type = label_type
28 |
29 | # create train/val split
30 | with open(jsonfile) as anno_file:
31 | self.anno = json.load(anno_file)
32 |
33 | self.train, self.valid = [], []
34 | for idx, val in enumerate(self.anno):
35 | if val['isValidation'] == True:
36 | self.valid.append(idx)
37 | else:
38 | self.train.append(idx)
39 | self.mean, self.std = self._compute_mean()
40 |
41 | def _compute_mean(self):
42 | meanstd_file = './data/mscoco/mean.pth.tar'
43 | if isfile(meanstd_file):
44 | meanstd = torch.load(meanstd_file)
45 | else:
46 | print('==> compute mean')
47 | mean = torch.zeros(3)
48 | std = torch.zeros(3)
49 | cnt = 0
50 | for index in self.train:
51 | cnt += 1
52 | print( '{} | {}'.format(cnt, len(self.train)))
53 | a = self.anno[index]
54 | img_path = os.path.join(self.img_folder, a['img_paths'])
55 | img = load_image(img_path) # CxHxW
56 | mean += img.view(img.size(0), -1).mean(1)
57 | std += img.view(img.size(0), -1).std(1)
58 | mean /= len(self.train)
59 | std /= len(self.train)
60 | meanstd = {
61 | 'mean': mean,
62 | 'std': std,
63 | }
64 | torch.save(meanstd, meanstd_file)
65 | if self.is_train:
66 | print(' Mean: %.4f, %.4f, %.4f' % (meanstd['mean'][0], meanstd['mean'][1], meanstd['mean'][2]))
67 | print(' Std: %.4f, %.4f, %.4f' % (meanstd['std'][0], meanstd['std'][1], meanstd['std'][2]))
68 |
69 | return meanstd['mean'], meanstd['std']
70 |
71 | def __getitem__(self, index):
72 | sf = self.scale_factor
73 | rf = self.rot_factor
74 | if self.is_train:
75 | a = self.anno[self.train[index]]
76 | else:
77 | a = self.anno[self.valid[index]]
78 |
79 | img_path = os.path.join(self.img_folder, a['img_paths'])
80 | pts = torch.Tensor(a['joint_self'])
81 | # pts[:, 0:2] -= 1 # Convert pts to zero based
82 |
83 | # c = torch.Tensor(a['objpos']) - 1
84 | c = torch.Tensor(a['objpos'])
85 | s = a['scale_provided']
86 |
87 | # Adjust center/scale slightly to avoid cropping limbs
88 | if c[0] != -1:
89 | c[1] = c[1] + 15 * s
90 | s = s * 1.25
91 |
92 | # For single-person pose estimation with a centered/scaled figure
93 | nparts = pts.size(0)
94 | img = load_image(img_path) # CxHxW
95 |
96 | r = 0
97 | if self.is_train:
98 | s = s*torch.randn(1).mul_(sf).add_(1).clamp(1-sf, 1+sf)[0]
99 | r = torch.randn(1).mul_(rf).clamp(-2*rf, 2*rf)[0] if random.random() <= 0.6 else 0
100 |
101 | # Flip
102 | if random.random() <= 0.5:
103 | img = torch.from_numpy(fliplr(img.numpy())).float()
104 | pts = shufflelr(pts, width=img.size(2), dataset='mpii')
105 | c[0] = img.size(2) - c[0]
106 |
107 | # Color
108 | img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
109 | img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
110 | img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
111 |
112 | # Prepare image and groundtruth map
113 | inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r)
114 | inp = color_normalize(inp, self.mean, self.std)
115 |
116 | # Generate ground truth
117 | tpts = pts.clone()
118 | target = torch.zeros(nparts, self.out_res, self.out_res)
119 | for i in range(nparts):
120 | if tpts[i, 2] > 0: # COCO visible: 0-no label, 1-label + invisible, 2-label + visible
121 | tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2]+1, c, s, [self.out_res, self.out_res], rot=r))
122 | target[i] = draw_labelmap(target[i], tpts[i]-1, self.sigma, type=self.label_type)
123 |
124 | # Meta info
125 | meta = {'index' : index, 'center' : c, 'scale' : s,
126 | 'pts' : pts, 'tpts' : tpts}
127 |
128 | return inp, target, meta
129 |
130 | def __len__(self):
131 | if self.is_train:
132 | return len(self.train)
133 | else:
134 | return len(self.valid)
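135 |
136 | # Note: unlike the MPII/LSP loaders, __getitem__ draws a target only for
137 | # keypoints whose COCO visibility flag is > 0 (0 = no label,
138 | # 1 = labelled but occluded, 2 = labelled and visible).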
--------------------------------------------------------------------------------
/pose/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .hourglass import *
2 | from .preresnet import *
--------------------------------------------------------------------------------
/pose/models/hourglass.py:
--------------------------------------------------------------------------------
1 | '''
2 | Stacked hourglass network built from pre-activation ResNet blocks
3 | Use lr=0.01 for the current version
4 | (c) YANG, Wei
5 | '''
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | # from .preresnet import BasicBlock, Bottleneck
10 |
11 |
12 | __all__ = ['HourglassNet', 'hg']
13 |
14 | class Bottleneck(nn.Module):
15 | expansion = 2
16 |
17 | def __init__(self, inplanes, planes, stride=1, downsample=None, mobile=False):
18 | super(Bottleneck, self).__init__()
19 |
20 | self.bn1 = nn.BatchNorm2d(inplanes)
21 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 |
24 | if mobile:
25 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
26 | padding=1, bias=True, groups=planes)
27 | else:
28 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
29 | padding=1, bias=True)
30 | self.bn3 = nn.BatchNorm2d(planes)
31 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=True)
32 | self.relu = nn.ReLU(inplace=True)
33 | self.downsample = downsample
34 | self.stride = stride
35 |
36 | def forward(self, x):
37 | residual = x
38 |
39 | out = self.bn1(x)
40 | out = self.relu(out)
41 | out = self.conv1(out)
42 |
43 | out = self.bn2(out)
44 | out = self.relu(out)
45 | out = self.conv2(out)
46 |
47 | out = self.bn3(out)
48 | out = self.relu(out)
49 | out = self.conv3(out)
50 |
51 | if self.downsample is not None:
52 | residual = self.downsample(x)
53 |
54 | out += residual
55 |
56 | return out
57 |
58 |
59 | class Hourglass(nn.Module):
60 | def __init__(self, block, num_blocks, planes, depth, mobile):
61 | super(Hourglass, self).__init__()
62 | self.mobile = mobile
63 | self.depth = depth
64 | self.block = block
65 | self.upsample = nn.Upsample(scale_factor=2)
66 | self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
67 |
68 | def _make_residual(self, block, num_blocks, planes):
69 | layers = []
70 | for i in range(0, num_blocks):
71 | layers.append(block(planes*block.expansion, planes, mobile=self.mobile))
72 | return nn.Sequential(*layers)
73 |
74 | def _make_hour_glass(self, block, num_blocks, planes, depth):
75 | hg = []
76 | for i in range(depth):
77 | res = []
78 | for j in range(3):
79 | res.append(self._make_residual(block, num_blocks, planes))
80 | if i == 0:
81 | res.append(self._make_residual(block, num_blocks, planes))
82 | hg.append(nn.ModuleList(res))
83 | return nn.ModuleList(hg)
84 |
85 | def _hour_glass_forward(self, n, x):
86 | up1 = self.hg[n-1][0](x)
87 | low1 = F.max_pool2d(x, 2, stride=2)
88 | low1 = self.hg[n-1][1](low1)
89 |
90 | if n > 1:
91 | low2 = self._hour_glass_forward(n-1, low1)
92 | else:
93 | low2 = self.hg[n-1][3](low1)
94 | low3 = self.hg[n-1][2](low2)
95 | up2 = self.upsample(low3)
96 | out = up1 + up2
97 | return out
98 |
99 | def forward(self, x):
100 | return self._hour_glass_forward(self.depth, x)
101 |
102 |
103 | class HourglassNet(nn.Module):
104 | '''Hourglass model from Newell et al ECCV 2016'''
105 | def __init__(self, block, num_stacks=2, num_blocks=4, num_classes=16, mobile=False):
106 | super(HourglassNet, self).__init__()
107 |
108 | self.mobile = mobile
109 | self.inplanes = 64
110 | self.num_feats = 128
111 | self.num_stacks = num_stacks
112 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
113 | bias=True)
114 | self.bn1 = nn.BatchNorm2d(self.inplanes)
115 | self.relu = nn.ReLU(inplace=True)
116 | self.layer1 = self._make_residual(block, self.inplanes, 1)
117 | self.layer2 = self._make_residual(block, self.inplanes, 1)
118 | self.layer3 = self._make_residual(block, self.num_feats, 1)
119 | self.maxpool = nn.MaxPool2d(2, stride=2)
120 |
121 | # build hourglass modules
122 | ch = self.num_feats*block.expansion
123 | hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
124 | for i in range(num_stacks):
125 | hg.append(Hourglass(block, num_blocks, self.num_feats, 4, self.mobile))
126 | res.append(self._make_residual(block, self.num_feats, num_blocks))
127 | fc.append(self._make_fc(ch, ch))
128 | score.append(nn.Conv2d(ch, num_classes, kernel_size=1, bias=True))
129 | if i < num_stacks-1:
130 | fc_.append(nn.Conv2d(ch, ch, kernel_size=1, bias=True))
131 | score_.append(nn.Conv2d(num_classes, ch, kernel_size=1, bias=True))
132 | self.hg = nn.ModuleList(hg)
133 | self.res = nn.ModuleList(res)
134 | self.fc = nn.ModuleList(fc)
135 | self.score = nn.ModuleList(score)
136 | self.fc_ = nn.ModuleList(fc_)
137 | self.score_ = nn.ModuleList(score_)
138 |
139 | def _make_residual(self, block, planes, blocks, stride=1):
140 | downsample = None
141 | if stride != 1 or self.inplanes != planes * block.expansion:
142 | downsample = nn.Sequential(
143 | nn.Conv2d(self.inplanes, planes * block.expansion,
144 | kernel_size=1, stride=stride, bias=True),
145 | )
146 |
147 | layers = []
148 | layers.append(block(self.inplanes, planes, stride, downsample, self.mobile))
149 | self.inplanes = planes * block.expansion
150 | for i in range(1, blocks):
151 | layers.append(block(self.inplanes, planes, mobile=self.mobile))
152 |
153 | return nn.Sequential(*layers)
154 |
155 | def _make_fc(self, inplanes, outplanes):
156 | bn = nn.BatchNorm2d(inplanes)
157 | conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=True)
158 | return nn.Sequential(
159 | conv,
160 | bn,
161 | self.relu,
162 | )
163 |
164 | def forward(self, x):
165 | out = []
166 | x = self.conv1(x)
167 | x = self.bn1(x)
168 | x = self.relu(x)
169 |
170 | x = self.layer1(x)
171 | x = self.maxpool(x)
172 | x = self.layer2(x)
173 | x = self.layer3(x)
174 |
175 | for i in range(self.num_stacks):
176 | y = self.hg[i](x)
177 | y = self.res[i](y)
178 | y = self.fc[i](y)
179 | score = self.score[i](y)
180 | out.append(score)
181 | if i < self.num_stacks-1:
182 | fc_ = self.fc_[i](y)
183 | score_ = self.score_[i](score)
184 | x = x + fc_ + score_
185 |
186 | return out
187 |
188 |
189 | def hg(**kwargs):
190 | model = HourglassNet(Bottleneck, num_stacks=kwargs['num_stacks'], num_blocks=kwargs['num_blocks'],
191 | num_classes=kwargs['num_classes'], mobile=kwargs['mobile'])
192 | return model
193 |
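194 | # Usage sketch (hg() requires all four keyword arguments; values illustrative):
195 | #   model = hg(num_stacks=2, num_blocks=1, num_classes=16, mobile=False)
196 | #   out = model(torch.randn(1, 3, 256, 256))
197 | #   # out is a list of num_stacks score maps, each of shape 1x16x64x64
198 | #   # (the stem downsamples the input by 4x)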
--------------------------------------------------------------------------------
/pose/models/preresnet.py:
--------------------------------------------------------------------------------
1 | '''Pre-activation ResNet for the CIFAR dataset.
2 | Ported from https://github.com/facebook/fb.resnet.torch/blob/master/models/preresnet.lua
3 | (c) YANG, Wei
4 | '''
5 | import torch.nn as nn
6 | import math
7 | import torch.utils.model_zoo as model_zoo
8 |
9 |
10 | __all__ = ['PreResNet', 'preresnet20', 'preresnet32', 'preresnet44', 'preresnet56',
11 | 'preresnet110', 'preresnet1202']
12 |
13 | def conv3x3(in_planes, out_planes, stride=1):
14 | "3x3 convolution with padding"
15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
16 | padding=1, bias=False)
17 |
18 |
19 | class BasicBlock(nn.Module):
20 | expansion = 1
21 |
22 | def __init__(self, inplanes, planes, stride=1, downsample=None):
23 | super(BasicBlock, self).__init__()
24 | self.bn1 = nn.BatchNorm2d(inplanes)
25 | self.conv1 = conv3x3(inplanes, planes, stride)
26 | self.relu = nn.ReLU(inplace=True)
27 | self.bn2 = nn.BatchNorm2d(planes)
28 | self.conv2 = conv3x3(planes, planes)
29 | self.downsample = downsample
30 | self.stride = stride
31 |
32 | def forward(self, x):
33 | residual = x
34 |
35 | out = self.bn1(x)
36 | out = self.relu(out)
37 | out = self.conv1(out)
38 |
39 | out = self.bn2(out)
40 | out = self.relu(out)
41 | out = self.conv2(out)
42 |
43 | if self.downsample is not None:
44 | residual = self.downsample(x)
45 |
46 | out += residual
47 |
48 | return out
49 |
50 |
51 | class Bottleneck(nn.Module):
52 | expansion = 4
53 |
54 | def __init__(self, inplanes, planes, stride=1, downsample=None):
55 | super(Bottleneck, self).__init__()
56 | self.bn1 = nn.BatchNorm2d(inplanes)
57 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
58 | self.bn2 = nn.BatchNorm2d(planes)
59 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
60 | padding=1, bias=False)
61 | self.bn3 = nn.BatchNorm2d(planes)
62 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
63 | self.relu = nn.ReLU(inplace=True)
64 | self.downsample = downsample
65 | self.stride = stride
66 |
67 | def forward(self, x):
68 | residual = x
69 |
70 | out = self.bn1(x)
71 | out = self.relu(out)
72 | out = self.conv1(out)
73 |
74 | out = self.bn2(out)
75 | out = self.relu(out)
76 | out = self.conv2(out)
77 |
78 | out = self.bn3(out)
79 | out = self.relu(out)
80 | out = self.conv3(out)
81 |
82 | if self.downsample is not None:
83 | residual = self.downsample(x)
84 |
85 | out += residual
86 |
87 | return out
88 |
89 |
90 | class PreResNet(nn.Module):
91 |
92 | def __init__(self, block, layers, num_classes=1000):
93 | self.inplanes = 16
94 | super(PreResNet, self).__init__()
95 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
96 | bias=False)
97 | self.layer1 = self._make_layer(block, 16, layers[0])
98 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
99 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
100 | self.bn1 = nn.BatchNorm2d(64*block.expansion)
101 | self.relu = nn.ReLU(inplace=True)
102 | self.fc1 = nn.Conv2d(64*block.expansion, 64*block.expansion, kernel_size=1, bias=False)
103 | self.bn2 = nn.BatchNorm2d(64*block.expansion)
104 | self.fc2 = nn.Conv2d(64*block.expansion, num_classes, kernel_size=1)
105 | # self.avgpool = nn.AvgPool2d(8)
106 | # self.fc = nn.Linear(64*block.expansion, num_classes)
107 |
108 | # for m in self.modules():
109 | # if isinstance(m, nn.Conv2d):
110 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
111 | # m.weight.data.normal_(0, math.sqrt(2. / n))
112 | # elif isinstance(m, nn.BatchNorm2d):
113 | # m.weight.data.fill_(1)
114 | # m.bias.data.zero_()
115 |
116 | def _make_layer(self, block, planes, blocks, stride=1):
117 | downsample = None
118 | if stride != 1 or self.inplanes != planes * block.expansion:
119 | downsample = nn.Sequential(
120 | nn.Conv2d(self.inplanes, planes * block.expansion,
121 | kernel_size=1, stride=stride, bias=False),
122 | # nn.BatchNorm2d(planes * block.expansion),
123 | )
124 |
125 | layers = []
126 | layers.append(block(self.inplanes, planes, stride, downsample))
127 | self.inplanes = planes * block.expansion
128 | for i in range(1, blocks):
129 | layers.append(block(self.inplanes, planes))
130 |
131 | return nn.Sequential(*layers)
132 |
133 | def forward(self, x):
134 | x = self.conv1(x)
135 |
136 | x = self.layer1(x)
137 | x = self.layer2(x)
138 | x = self.layer3(x)
139 | x = self.fc1(self.relu(self.bn1(x)))
140 | x = self.fc2(self.relu(self.bn2(x)))
141 | # x = self.sigmoid(x)
142 | # x = self.avgpool(x)
143 | # x = x.view(x.size(0), -1)
144 |
145 | return [x]
146 |
147 |
148 | def preresnet20(**kwargs):
149 | """Constructs a PreResNet-20 model.
150 | """
151 | model = PreResNet(BasicBlock, [3, 3, 3], **kwargs)
152 | return model
153 |
154 |
155 | def preresnet32(**kwargs):
156 | """Constructs a PreResNet-32 model.
157 | """
158 | model = PreResNet(BasicBlock, [5, 5, 5], **kwargs)
159 | return model
160 |
161 |
162 | def preresnet44(**kwargs):
163 | """Constructs a PreResNet-44 model.
164 | """
165 | model = PreResNet(Bottleneck, [7, 7, 7], **kwargs)
166 | return model
167 |
168 |
169 | def preresnet56(**kwargs):
170 | """Constructs a PreResNet-56 model.
171 | """
172 | model = PreResNet(Bottleneck, [9, 9, 9], **kwargs)
173 | return model
174 |
175 |
176 | def preresnet110(**kwargs):
177 | """Constructs a PreResNet-110 model.
178 | """
179 | model = PreResNet(Bottleneck, [18, 18, 18], **kwargs)
180 | return model
181 |
182 | def preresnet1202(**kwargs):
183 | """Constructs a PreResNet-1202 model.
184 | """
185 | model = PreResNet(Bottleneck, [200, 200, 200], **kwargs)
186 | return model
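187 |
188 | # Usage sketch (CIFAR-sized input; illustrative). The head is fully
189 | # convolutional here, so the model returns a list with one score map:
190 | #   model = preresnet20(num_classes=10)
191 | #   out, = model(torch.randn(1, 3, 32, 32))  # out: 1x10x8x8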
--------------------------------------------------------------------------------
/pose/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .evaluation import *
4 | from .imutils import *
5 | from .logger import *
6 | from .misc import *
7 | from .osutils import *
8 | from .transforms import *
9 |
--------------------------------------------------------------------------------
/pose/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import math
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from random import randint
7 |
8 | from .misc import *
9 | from .transforms import transform, transform_preds
10 |
11 | __all__ = ['accuracy', 'AverageMeter']
12 |
13 | def get_preds(scores):
14 | ''' get predictions from score maps in torch Tensor
15 | return type: torch.LongTensor
16 | '''
17 | assert scores.dim() == 4, 'Score maps should be 4-dim'
18 | maxval, idx = torch.max(scores.view(scores.size(0), scores.size(1), -1), 2)
19 |
20 | maxval = maxval.view(scores.size(0), scores.size(1), 1)
21 | idx = idx.view(scores.size(0), scores.size(1), 1) + 1
22 |
23 | preds = idx.repeat(1, 1, 2).float()
24 |
25 | preds[:,:,0] = (preds[:,:,0] - 1) % scores.size(3) + 1
26 | preds[:,:,1] = torch.floor((preds[:,:,1] - 1) / scores.size(3)) + 1
27 |
28 | pred_mask = maxval.gt(0).repeat(1, 1, 2).float()
29 | preds *= pred_mask
30 | return preds
31 |
32 | def calc_dists(preds, target, normalize):
33 | preds = preds.float()
34 | target = target.float()
35 | dists = torch.zeros(preds.size(1), preds.size(0))
36 | for n in range(preds.size(0)):
37 | for c in range(preds.size(1)):
38 | if target[n,c,0] > 1 and target[n, c, 1] > 1:
39 | dists[c, n] = torch.dist(preds[n,c,:], target[n,c,:])/normalize[n]
40 | else:
41 | dists[c, n] = -1
42 | return dists
43 |
44 | def dist_acc(dists, thr=0.5):
45 | ''' Return percentage below threshold while ignoring values with a -1 '''
46 | if dists.ne(-1).sum() > 0:
47 | return float(dists.le(thr).eq(dists.ne(-1)).sum()) / float(dists.ne(-1).sum())
48 | else:
49 | return -1
50 |
51 | def accuracy(output, target, idxs, thr=0.5):
52 |     ''' Calculate accuracy according to PCK, using the ground-truth heatmap rather than x,y locations.
53 |         The first value returned is the average accuracy across 'idxs'; the rest are individual accuracies.
54 | '''
55 | preds = get_preds(output)
56 | gts = get_preds(target)
57 | norm = torch.ones(preds.size(0))*output.size(3)/10
58 | dists = calc_dists(preds, gts, norm)
59 |
60 | acc = torch.zeros(len(idxs)+1)
61 | avg_acc = 0
62 | cnt = 0
63 |
64 | for i in range(len(idxs)):
65 | acc[i+1] = dist_acc(dists[idxs[i]-1])
66 | if acc[i+1] >= 0:
67 | avg_acc = avg_acc + acc[i+1]
68 | cnt += 1
69 |
70 | if cnt != 0:
71 | acc[0] = avg_acc / cnt
72 | return acc
73 |
74 | def final_preds(output, center, scale, res):
75 | coords = get_preds(output) # float type
76 |
77 |     # post-processing
78 | for n in range(coords.size(0)):
79 | for p in range(coords.size(1)):
80 | hm = output[n][p]
81 | px = int(math.floor(coords[n][p][0]))
82 | py = int(math.floor(coords[n][p][1]))
83 | if px > 1 and px < res[0] and py > 1 and py < res[1]:
84 | diff = torch.Tensor([hm[py - 1][px] - hm[py - 1][px - 2], hm[py][px - 1]-hm[py - 2][px - 1]])
85 | coords[n][p] += diff.sign() * .25
86 | coords += 0.5
87 | preds = coords.clone()
88 |
89 | # Transform back
90 | for i in range(coords.size(0)):
91 | preds[i] = transform_preds(coords[i], center[i], scale[i], res)
92 |
93 | if preds.dim() < 3:
94 |         preds = preds.view(1, preds.size(0), preds.size(1))
95 |
96 | return preds
97 |
98 |
99 | class AverageMeter(object):
100 | """Computes and stores the average and current value"""
101 | def __init__(self):
102 | self.reset()
103 |
104 | def reset(self):
105 | self.val = 0
106 | self.avg = 0
107 | self.sum = 0
108 | self.count = 0
109 |
110 | def update(self, val, n=1):
111 | self.val = val
112 | self.sum += val * n
113 | self.count += n
114 | self.avg = self.sum / self.count
115 |
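116 | # Usage sketch (illustrative): PCK-style accuracy on predicted heatmaps, with
117 | # 1-based joint ids in idxs, accumulated over batches:
118 | #   acc = accuracy(output, target, idxs=range(1, 17))  # acc[0] is the average
119 | #   meter = AverageMeter()
120 | #   meter.update(acc[0], n=output.size(0))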
--------------------------------------------------------------------------------
/pose/utils/imutils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import scipy.misc
7 |
8 | from .misc import *
9 |
10 | def im_to_numpy(img):
11 | img = to_numpy(img)
12 | img = np.transpose(img, (1, 2, 0)) # H*W*C
13 | return img
14 |
15 | def im_to_torch(img):
16 | img = np.transpose(img, (2, 0, 1)) # C*H*W
17 | img = to_torch(img).float()
18 | if img.max() > 1:
19 | img /= 255
20 | return img
21 |
22 | def load_image(img_path):
23 | # H x W x C => C x H x W
24 | return im_to_torch(scipy.misc.imread(img_path, mode='RGB'))
25 |
26 | def resize(img, owidth, oheight):
27 | img = im_to_numpy(img)
28 | print('%f %f' % (img.min(), img.max()))
29 | img = scipy.misc.imresize(
30 | img,
31 | (oheight, owidth)
32 | )
33 | img = im_to_torch(img)
34 | print('%f %f' % (img.min(), img.max()))
35 | return img
36 |
37 | # =============================================================================
38 | # Helpful functions generating groundtruth labelmap
39 | # =============================================================================
40 |
41 | def gaussian(shape=(7,7),sigma=1):
42 | """
43 | 2D gaussian mask - should give the same result as MATLAB's
44 | fspecial('gaussian',[shape],[sigma])
45 | """
46 | m,n = [(ss-1.)/2. for ss in shape]
47 | y,x = np.ogrid[-m:m+1,-n:n+1]
48 | h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )
49 | h[ h < np.finfo(h.dtype).eps*h.max() ] = 0
50 | return to_torch(h).float()
51 |
52 | def draw_labelmap(img, pt, sigma, type='Gaussian'):
53 | # Draw a 2D gaussian
54 | # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py
55 | img = to_numpy(img)
56 |
57 | # Check that any part of the gaussian is in-bounds
58 | ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
59 | br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
60 | if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
61 | br[0] < 0 or br[1] < 0):
62 | # If not, just return the image as is
63 | return to_torch(img)
64 |
65 | # Generate gaussian
66 | size = 6 * sigma + 1
67 | x = np.arange(0, size, 1, float)
68 | y = x[:, np.newaxis]
69 | x0 = y0 = size // 2
70 |     # The gaussian is not normalized; we want the center value to equal 1
71 | if type == 'Gaussian':
72 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
73 | elif type == 'Cauchy':
74 |         g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
75 |     else:
76 |         raise ValueError('Unknown label type: ' + type)
77 | # Usable gaussian range
78 | g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
79 | g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
80 | # Image range
81 | img_x = max(0, ul[0]), min(br[0], img.shape[1])
82 | img_y = max(0, ul[1]), min(br[1], img.shape[0])
83 |
84 | img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
85 | return to_torch(img)
86 |
87 | # =============================================================================
88 | # Helpful display functions
89 | # =============================================================================
90 |
91 | def gauss(x, a, b, c, d=0):
92 | return a * np.exp(-(x - b)**2 / (2 * c**2)) + d
93 |
94 | def color_heatmap(x):
95 | x = to_numpy(x)
96 | color = np.zeros((x.shape[0],x.shape[1],3))
97 | color[:,:,0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
98 | color[:,:,1] = gauss(x, 1, .5, .3)
99 | color[:,:,2] = gauss(x, 1, .2, .3)
100 | color[color > 1] = 1
101 | color = (color * 255).astype(np.uint8)
102 | return color
103 |
104 | def imshow(img):
105 | npimg = im_to_numpy(img*255).astype(np.uint8)
106 | plt.imshow(npimg)
107 | plt.axis('off')
108 |
109 | def show_joints(img, pts):
110 | imshow(img)
111 |
112 | for i in range(pts.size(0)):
113 | if pts[i, 2] > 0:
114 | plt.plot(pts[i, 0], pts[i, 1], 'yo')
115 | plt.axis('off')
116 |
117 | def show_sample(inputs, target):
118 | num_sample = inputs.size(0)
119 | num_joints = target.size(1)
120 | height = target.size(2)
121 | width = target.size(3)
122 |
123 | for n in range(num_sample):
124 | inp = resize(inputs[n], width, height)
125 | out = inp
126 | for p in range(num_joints):
127 | tgt = inp*0.5 + color_heatmap(target[n,p,:,:])*0.5
128 | out = torch.cat((out, tgt), 2)
129 |
130 | imshow(out)
131 | plt.show()
132 |
133 | def sample_with_heatmap(inp, out, num_rows=2, parts_to_show=None):
134 | inp = to_numpy(inp * 255)
135 | out = to_numpy(out)
136 |
137 | img = np.zeros((inp.shape[1], inp.shape[2], inp.shape[0]))
138 | for i in range(3):
139 | img[:, :, i] = inp[i, :, :]
140 |
141 | if parts_to_show is None:
142 | parts_to_show = np.arange(out.shape[0])
143 |
144 | # Generate a single image to display input/output pair
145 | num_cols = int(np.ceil(float(len(parts_to_show)) / num_rows))
146 | size = img.shape[0] // num_rows
147 |
148 | full_img = np.zeros((img.shape[0], size * (num_cols + num_rows), 3), np.uint8)
149 | full_img[:img.shape[0], :img.shape[1]] = img
150 |
151 | inp_small = scipy.misc.imresize(img, [size, size])
152 |
153 | # Set up heatmap display for each part
154 | for i, part in enumerate(parts_to_show):
155 | part_idx = part
156 | out_resized = scipy.misc.imresize(out[part_idx], [size, size])
157 | out_resized = out_resized.astype(float)/255
158 | out_img = inp_small.copy() * .3
159 | color_hm = color_heatmap(out_resized)
160 | out_img += color_hm * .7
161 |
162 | col_offset = (i % num_cols + num_rows) * size
163 | row_offset = (i // num_cols) * size
164 | full_img[row_offset:row_offset + size, col_offset:col_offset + size] = out_img
165 |
166 | return full_img
167 |
168 | def batch_with_heatmap(inputs, outputs, mean=torch.Tensor([0.5, 0.5, 0.5]), num_rows=2, parts_to_show=None):
169 | batch_img = []
170 | for n in range(min(inputs.size(0), 4)):
171 | inp = inputs[n] + mean.view(3, 1, 1).expand_as(inputs[n])
172 | batch_img.append(
173 | sample_with_heatmap(inp.clamp(0, 1), outputs[n], num_rows=num_rows, parts_to_show=parts_to_show)
174 | )
175 | return np.concatenate(batch_img)
176 |
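177 | # Usage sketch (illustrative): draw a unit-peak Gaussian at a heatmap location:
178 | #   hm = draw_labelmap(torch.zeros(64, 64), pt=[32, 32], sigma=1)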
--------------------------------------------------------------------------------
/pose/utils/logger.py:
--------------------------------------------------------------------------------
1 | # A simple torch style logger
2 | # (C) Wei YANG 2017
3 | from __future__ import absolute_import
4 |
5 | import os
6 | import sys
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 |
10 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
11 |
12 | def savefig(fname, dpi=None):
13 |     dpi = 150 if dpi is None else dpi
14 | plt.savefig(fname, dpi=dpi)
15 |
16 | def plot_overlap(logger, names=None):
17 |     names = logger.names if names is None else names
18 | numbers = logger.numbers
19 | for _, name in enumerate(names):
20 | x = np.arange(len(numbers[name]))
21 | plt.plot(x, np.asarray(numbers[name]))
22 | return [logger.title + '(' + name + ')' for name in names]
23 |
24 | class Logger(object):
25 | '''Save training process to log file with simple plot function.'''
26 | def __init__(self, fpath, title=None, resume=False):
27 | self.file = None
28 | self.resume = resume
29 |         self.title = '' if title is None else title
30 | if fpath is not None:
31 | if resume:
32 | self.file = open(fpath, 'r')
33 | name = self.file.readline()
34 | self.names = name.rstrip().split('\t')
35 | self.numbers = {}
36 | for _, name in enumerate(self.names):
37 | self.numbers[name] = []
38 |
39 | for numbers in self.file:
40 | numbers = numbers.rstrip().split('\t')
41 | for i in range(0, len(numbers)):
42 |                     self.numbers[self.names[i]].append(float(numbers[i]))
43 | self.file.close()
44 | self.file = open(fpath, 'a')
45 | else:
46 | self.file = open(fpath, 'w')
47 |
48 | def set_names(self, names):
49 |         if self.resume:
50 |             return  # names and numbers were already restored from the existing log
51 |         # initialize numbers as empty list
52 | self.numbers = {}
53 | self.names = names
54 | for _, name in enumerate(self.names):
55 | self.file.write(name)
56 | self.file.write('\t')
57 | self.numbers[name] = []
58 | self.file.write('\n')
59 | self.file.flush()
60 |
61 |
62 | def append(self, numbers):
63 | assert len(self.names) == len(numbers), 'Numbers do not match names'
64 | for index, num in enumerate(numbers):
65 | self.file.write("{0:.6f}".format(num))
66 | self.file.write('\t')
67 | self.numbers[self.names[index]].append(num)
68 | self.file.write('\n')
69 | self.file.flush()
70 |
71 | def plot(self, names=None):
72 |         names = self.names if names is None else names
73 | numbers = self.numbers
74 | for _, name in enumerate(names):
75 | x = np.arange(len(numbers[name]))
76 | plt.plot(x, np.asarray(numbers[name]))
77 | plt.legend([self.title + '(' + name + ')' for name in names])
78 | plt.grid(True)
79 |
80 | def close(self):
81 | if self.file is not None:
82 | self.file.close()
83 |
84 | class LoggerMonitor(object):
85 | '''Load and visualize multiple logs.'''
86 | def __init__ (self, paths):
87 |         '''paths is a dictionary with {name: filepath} pairs'''
88 | self.loggers = []
89 | for title, path in paths.items():
90 | logger = Logger(path, title=title, resume=True)
91 | self.loggers.append(logger)
92 |
93 | def plot(self, names=None):
94 | plt.figure()
95 | plt.subplot(121)
96 | legend_text = []
97 | for logger in self.loggers:
98 | legend_text += plot_overlap(logger, names)
99 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
100 | plt.grid(True)
101 |
102 | if __name__ == '__main__':
103 | # # Example
104 | # logger = Logger('test.txt')
105 | # logger.set_names(['Train loss', 'Valid loss','Test loss'])
106 |
107 | # length = 100
108 | # t = np.arange(length)
109 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
110 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
111 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
112 |
113 | # for i in range(0, length):
114 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]])
115 | # logger.plot()
116 |
117 | # Example: logger monitor
118 | paths = {
119 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
120 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
121 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
122 | }
123 |
124 | field = ['Valid Acc.']
125 |
126 | monitor = LoggerMonitor(paths)
127 | monitor.plot(names=field)
128 | savefig('test.eps')
--------------------------------------------------------------------------------
/pose/utils/misc.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import os
4 | import shutil
5 | import torch
6 | import math
7 | import numpy as np
8 | import scipy.io
9 | import matplotlib.pyplot as plt
10 |
11 | def to_numpy(tensor):
12 | if torch.is_tensor(tensor):
13 | return tensor.cpu().numpy()
14 | elif type(tensor).__module__ != 'numpy':
15 | raise ValueError("Cannot convert {} to numpy array"
16 | .format(type(tensor)))
17 | return tensor
18 |
19 |
20 | def to_torch(ndarray):
21 | if type(ndarray).__module__ == 'numpy':
22 | return torch.from_numpy(ndarray)
23 | elif not torch.is_tensor(ndarray):
24 | raise ValueError("Cannot convert {} to torch tensor"
25 | .format(type(ndarray)))
26 | return ndarray
27 |
28 |
29 | def save_checkpoint(state, preds, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar', snapshot=None):
30 | preds = to_numpy(preds)
31 | filepath = os.path.join(checkpoint, filename)
32 | torch.save(state, filepath)
33 | scipy.io.savemat(os.path.join(checkpoint, 'preds.mat'), mdict={'preds' : preds})
34 |
35 |     if snapshot and state['epoch'] % snapshot == 0:
36 |         shutil.copyfile(filepath, os.path.join(checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))
37 |
38 | if is_best:
39 | shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))
40 | scipy.io.savemat(os.path.join(checkpoint, 'preds_best.mat'), mdict={'preds' : preds})
41 |
42 |
43 | def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'):
44 | preds = to_numpy(preds)
45 | filepath = os.path.join(checkpoint, filename)
46 | scipy.io.savemat(filepath, mdict={'preds' : preds})
47 |
48 |
49 | def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
50 | """Sets the learning rate to the initial LR decayed by schedule"""
51 | if epoch in schedule:
52 | lr *= gamma
53 | for param_group in optimizer.param_groups:
54 | param_group['lr'] = lr
55 | return lr
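56 |
57 | # Usage sketch (illustrative): decay the learning rate by gamma at scheduled epochs:
58 | #   lr = adjust_learning_rate(optimizer, epoch, lr, schedule=[60, 90], gamma=0.1)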
--------------------------------------------------------------------------------
/pose/utils/osutils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import os
4 | import errno
5 |
6 | def mkdir_p(dir_path):
7 | try:
8 | os.makedirs(dir_path)
9 | except OSError as e:
10 | if e.errno != errno.EEXIST:
11 | raise
12 |
13 | def isfile(fname):
14 | return os.path.isfile(fname)
15 |
16 | def isdir(dirname):
17 | return os.path.isdir(dirname)
18 |
19 | def join(path, *paths):
20 | return os.path.join(path, *paths)
21 |
--------------------------------------------------------------------------------
/pose/utils/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import os
4 | import numpy as np
5 | import scipy.misc
6 | import matplotlib.pyplot as plt
7 | import torch
8 |
9 | from .misc import *
10 | from .imutils import *
11 |
12 |
13 | def color_normalize(x, mean, std):
14 | if x.size(0) == 1:
15 | x = x.repeat(3, 1, 1)
16 |
17 | for t, m, s in zip(x, mean, std):
18 | t.sub_(m)
19 | return x
20 |
21 |
22 | def flip_back(flip_output, dataset='mpii'):
23 | """
24 | flip output map
25 | """
26 | if dataset == 'mpii':
27 | matchedParts = (
28 | [0,5], [1,4], [2,3],
29 | [10,15], [11,14], [12,13]
30 | )
31 |     else:
32 |         raise ValueError('Not supported dataset: ' + dataset)
33 |
34 | # flip output horizontally
35 | flip_output = fliplr(flip_output.numpy())
36 |
37 | # Change left-right parts
38 | for pair in matchedParts:
39 | tmp = np.copy(flip_output[:, pair[0], :, :])
40 | flip_output[:, pair[0], :, :] = flip_output[:, pair[1], :, :]
41 | flip_output[:, pair[1], :, :] = tmp
42 |
43 | return torch.from_numpy(flip_output).float()
44 |
45 |
46 | def shufflelr(x, width, dataset='mpii'):
47 | """
48 | flip coords
49 | """
50 | if dataset == 'mpii':
51 | matchedParts = (
52 | [0,5], [1,4], [2,3],
53 | [10,15], [11,14], [12,13]
54 | )
55 |     else:
56 |         raise ValueError('Not supported dataset: ' + dataset)
57 |
58 | # Flip horizontal
59 | x[:, 0] = width - x[:, 0]
60 |
61 | # Change left-right parts
62 | for pair in matchedParts:
63 | tmp = x[pair[0], :].clone()
64 | x[pair[0], :] = x[pair[1], :]
65 | x[pair[1], :] = tmp
66 |
67 | return x
68 |
69 |
70 | def fliplr(x):
71 | if x.ndim == 3:
72 | x = np.transpose(np.fliplr(np.transpose(x, (0, 2, 1))), (0, 2, 1))
73 | elif x.ndim == 4:
74 | for i in range(x.shape[0]):
75 | x[i] = np.transpose(np.fliplr(np.transpose(x[i], (0, 2, 1))), (0, 2, 1))
76 | return x.astype(float)
77 |
78 |
79 | def get_transform(center, scale, res, rot=0):
80 | """
81 | General image processing functions
82 | """
83 | # Generate transformation matrix
84 | h = 200 * scale
85 | t = np.zeros((3, 3))
86 | t[0, 0] = float(res[1]) / h
87 | t[1, 1] = float(res[0]) / h
88 | t[0, 2] = res[1] * (-float(center[0]) / h + .5)
89 | t[1, 2] = res[0] * (-float(center[1]) / h + .5)
90 | t[2, 2] = 1
91 | if not rot == 0:
92 | rot = -rot # To match direction of rotation from cropping
93 | rot_mat = np.zeros((3,3))
94 | rot_rad = rot * np.pi / 180
95 | sn,cs = np.sin(rot_rad), np.cos(rot_rad)
96 | rot_mat[0,:2] = [cs, -sn]
97 | rot_mat[1,:2] = [sn, cs]
98 | rot_mat[2,2] = 1
99 | # Need to rotate around center
100 | t_mat = np.eye(3)
101 | t_mat[0,2] = -res[1]/2
102 | t_mat[1,2] = -res[0]/2
103 | t_inv = t_mat.copy()
104 | t_inv[:2,2] *= -1
105 | t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
106 | return t
107 |
108 |
109 | def transform(pt, center, scale, res, invert=0, rot=0):
110 | # Transform pixel location to different reference
111 | t = get_transform(center, scale, res, rot=rot)
112 | if invert:
113 | t = np.linalg.inv(t)
114 | new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
115 | new_pt = np.dot(t, new_pt)
116 | return new_pt[:2].astype(int) + 1
117 |
118 |
119 | def transform_preds(coords, center, scale, res):
120 | # size = coords.size()
121 | # coords = coords.view(-1, coords.size(-1))
122 | # print(coords.size())
123 | for p in range(coords.size(0)):
124 | coords[p, 0:2] = to_torch(transform(coords[p, 0:2], center, scale, res, 1, 0))
125 | return coords
126 |
127 |
128 | def crop(img, center, scale, res, rot=0):
129 | img = im_to_numpy(img)
130 |
131 | # Preprocessing for efficient cropping
132 | ht, wd = img.shape[0], img.shape[1]
133 | sf = scale * 200.0 / res[0]
134 | if sf < 2:
135 | sf = 1
136 | else:
137 | new_size = int(np.math.floor(max(ht, wd) / sf))
138 | new_ht = int(np.math.floor(ht / sf))
139 | new_wd = int(np.math.floor(wd / sf))
140 | if new_size < 2:
141 | return torch.zeros(res[0], res[1], img.shape[2]) \
142 | if len(img.shape) > 2 else torch.zeros(res[0], res[1])
143 | else:
144 | img = scipy.misc.imresize(img, [new_ht, new_wd])
145 | center = center * 1.0 / sf
146 | scale = scale / sf
147 |
148 | # Upper left point
149 | ul = np.array(transform([0, 0], center, scale, res, invert=1))
150 | # Bottom right point
151 | br = np.array(transform(res, center, scale, res, invert=1))
152 |
153 | # Padding so that when rotated proper amount of context is included
154 | pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
155 | if not rot == 0:
156 | ul -= pad
157 | br += pad
158 |
159 | new_shape = [br[1] - ul[1], br[0] - ul[0]]
160 | if len(img.shape) > 2:
161 | new_shape += [img.shape[2]]
162 | new_img = np.zeros(new_shape)
163 |
164 | # Range to fill new array
165 | new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
166 | new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
167 | # Range to sample from original image
168 | old_x = max(0, ul[0]), min(len(img[0]), br[0])
169 | old_y = max(0, ul[1]), min(len(img), br[1])
170 | new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
171 |
172 | if not rot == 0:
173 | # Remove padding
174 | new_img = scipy.misc.imrotate(new_img, rot)
175 | new_img = new_img[pad:-pad, pad:-pad]
176 |
177 | new_img = im_to_torch(scipy.misc.imresize(new_img, res))
178 | return new_img
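179 | # Note: throughout this module a person scale s corresponds to a reference box
180 | # of 200*s pixels in the original image (see get_transform and crop), and
181 | # transform() maps 1-based pixel coordinates between the image and the
182 | # cropped/heatmap reference frames (invert=1 goes back to image coordinates).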
179 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cloudpickle==0.6.1
2 | cycler==0.10.0
3 | dask==1.0.0
4 | decorator==4.3.0
5 | functools32==3.2.3.post2
6 | matplotlib==2.0.2
7 | networkx==2.2
8 | numpy==1.15.4
9 | Pillow==5.3.0
10 | pyparsing==2.3.0
11 | python-dateutil==2.7.5
12 | pytz==2018.7
13 | PyWavelets==1.0.1
14 | scikit-image==0.14.1
15 | scipy==1.2.0
16 | six==1.12.0
17 | subprocess32==3.5.3
18 | toolz==0.9.0
19 | torch==0.4.0
20 | torchvision==0.2.1
21 |
--------------------------------------------------------------------------------
/tools/mpii_demo.py:
--------------------------------------------------------------------------------
1 |
2 | import cv2
3 | import numpy as np
4 | import torch
5 | import torch.nn.parallel
6 | import torch.backends.cudnn as cudnn
7 | import torch.optim
8 | from pose.utils.osutils import mkdir_p, isfile, isdir, join
9 | import pose.models as models
10 | from scipy.ndimage import gaussian_filter, maximum_filter
13 |
14 | def load_image(imgfile, w, h ):
15 | image = cv2.imread(imgfile)
16 | image = cv2.resize(image, (w, h))
17 | image = image[:, :, ::-1] # BGR -> RGB
18 | image = image / 255.0
19 |     image = image - np.array([[[0.4404, 0.4440, 0.4327]]]) # Subtract mean RGB
20 | image = image.transpose((2, 0, 1)) # Change data layout from HWC to CHW
21 | image = image[np.newaxis, :, :, :]
22 | return image
23 |
24 | def load_model(arch='hg', stacks=2, blocks=1, num_classes=16, mobile=True,
25 | resume='checkpoint/pytorch-pose/mpii_hg_s2_b1_mobile/checkpoint.pth.tar'):
26 | # create model
27 | model = models.__dict__[arch](num_stacks=stacks, num_blocks=blocks, num_classes=num_classes, mobile=mobile)
28 | model = torch.nn.DataParallel(model).cuda()
29 |
30 | # optionally resume from a checkpoint
31 | if isfile(resume):
32 | print("=> loading checkpoint '{}'".format(resume))
33 | checkpoint = torch.load(resume)
34 | model.load_state_dict(checkpoint['state_dict'])
35 | print("=> loaded checkpoint '{}' (epoch {})"
36 | .format(resume, checkpoint['epoch']))
37 | else:
38 | print("=> no checkpoint found at '{}'".format(resume))
39 |
40 | cudnn.benchmark = True
41 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))
42 | model.eval()
43 | return model
44 |
45 | def inference(model, image):
46 | model.eval()
47 | input_tensor = torch.from_numpy(image).float().cuda()
48 | output = model(input_tensor)
49 | output = output[-1]
50 | output = output.data.cpu()
51 | print(output.shape)
52 | kps = post_process_heatmap(output[0,:,:,:])
53 | return kps
54 |
55 |
56 | def post_process_heatmap(heatMap, kpConfidenceTh=0.2):
57 | kplst = list()
58 | for i in range(heatMap.shape[0]):
59 | _map = heatMap[i, :, :]
60 | _map = gaussian_filter(_map, sigma=1)
61 |         _nmsPeaks = non_max_suppression(_map, windowSize=3, threshold=1e-6)
62 |
63 | y, x = np.where(_nmsPeaks == _nmsPeaks.max())
64 | if len(x) > 0 and len(y) > 0:
65 | kplst.append((int(x[0]), int(y[0]), _nmsPeaks[y[0], x[0]]))
66 | else:
67 | kplst.append((0, 0, 0))
68 |
69 | kp = np.array(kplst)
70 | return kp
71 |
72 |
73 | def non_max_suppression(plain, windowSize=3, threshold=1e-6):
74 |     # zero out responses below the threshold, then keep only local maxima
75 |     under_th_indices = plain < threshold
76 |     plain[under_th_indices] = 0
77 |     return plain * (plain == maximum_filter(plain, footprint=np.ones((windowSize, windowSize))))
78 |
79 | def render_kps(cvmat, kps, scale_x, scale_y):
80 | for _kp in kps:
81 | _x, _y, _conf = _kp
82 |         if _conf > 0.2:
83 |             cv2.circle(cvmat, center=(int(_x * 4 * scale_x), int(_y * 4 * scale_y)), color=(0, 0, 255), radius=5)  # 4x: heatmaps are 1/4 of the input resolution
84 |
85 | return cvmat
86 |
87 |
88 | def main():
89 | model = load_model()
90 |     in_res_h, in_res_w = 192, 192
91 |
92 |     imgfile = "/home/yli150/sample.jpg"  # machine-specific sample path; point this at a local image
93 | image = load_image(imgfile, in_res_w, in_res_h)
94 | print(image.shape)
95 |
96 | kps = inference(model, image)
97 |
98 | cvmat = cv2.imread(imgfile)
99 | scale_x = cvmat.shape[1]*1.0/in_res_w
100 | scale_y = cvmat.shape[0]*1.0/in_res_h
101 | render_kps(cvmat, kps, scale_x, scale_y)
102 | print(kps)
103 | cv2.imshow('x', cvmat)
104 | cv2.waitKey(0)
105 |
106 | if __name__ == '__main__':
107 | main()
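108 |
109 | # For reference, the 16 rows of `kps` follow the MPII joint order used by
110 | # this repo's training data (assumed; verify against pose/datasets/mpii.py):
111 | #   0 r-ankle    1 r-knee    2 r-hip       3 l-hip        4 l-knee    5 l-ankle
112 | #   6 pelvis     7 thorax    8 upper-neck  9 head-top
113 | #  10 r-wrist   11 r-elbow  12 r-shoulder 13 l-shoulder  14 l-elbow  15 l-wrist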
--------------------------------------------------------------------------------
/tools/mpii_export_to_onxx.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | import argparse
4 | import torch
5 | import torch.nn.parallel
6 | import torch.backends.cudnn as cudnn
7 | import torch.optim
8 |
9 | from pose.utils.osutils import isfile
10 |
11 | import pose.models as models
12 |
13 | model_names = sorted(name for name in models.__dict__
14 | if name.islower() and not name.startswith("__")
15 | and callable(models.__dict__[name]))
16 |
17 | def main(args):
18 |
19 |     # --checkpoint points at the checkpoint *file* to export; creating a
20 |     # directory at that path here would break the isfile() check below
21 |
22 |
23 | # create model
24 | print("==> creating model '{}', stacks={}, blocks={}".format(args.arch, args.stacks, args.blocks))
25 | model = models.__dict__[args.arch](num_stacks=args.stacks, num_blocks=args.blocks, num_classes=args.num_classes,
26 | mobile=args.mobile)
27 | model.eval()
28 |
29 | # optionally resume from a checkpoint
30 |
31 | if args.checkpoint:
32 | if isfile(args.checkpoint):
33 | print("=> loading checkpoint '{}'".format(args.checkpoint))
34 | checkpoint = torch.load(args.checkpoint)
35 | args.start_epoch = checkpoint['epoch']
36 |
37 |             # strip the `module.` prefix that nn.DataParallel adds to parameter names
38 | from collections import OrderedDict
39 | new_state_dict = OrderedDict()
40 | for k, v in checkpoint['state_dict'].items():
41 | name = k[7:] # remove `module.`
42 | new_state_dict[name] = v
43 | # load params
44 | model.load_state_dict(new_state_dict)
45 |
46 | print("=> loaded checkpoint '{}' (epoch {})"
47 | .format(args.checkpoint, checkpoint['epoch']))
48 | else:
49 | print("=> no checkpoint found at '{}'".format(args.checkpoint))
50 |     # (dead else branch removed: --checkpoint is required, so it is always set)
51 |
52 |
53 |
54 | cudnn.benchmark = True
55 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))
56 |
57 |     dummy_input = torch.randn(1, 3, args.in_res, args.in_res)  # NCHW input traced during export
58 |     torch.onnx.export(model, dummy_input, args.out_onnx)
59 |
60 | if __name__ == '__main__':
61 |     parser = argparse.ArgumentParser(description='Export a pre-trained MPII pose model to ONNX')
62 | # Model structure
63 |     parser.add_argument('--arch', '-a', metavar='ARCH', default='hg',
64 |                         choices=model_names,
65 |                         help='model architecture: ' +
66 |                              ' | '.join(model_names) +
67 |                              ' (default: hg)')
68 | parser.add_argument('-s', '--stacks', default=8, type=int, metavar='N',
69 | help='Number of hourglasses to stack')
70 | parser.add_argument('-b', '--blocks', default=1, type=int, metavar='N',
71 | help='Number of residual modules at each location in the hourglass')
72 | parser.add_argument('--num-classes', default=16, type=int, metavar='N',
73 | help='Number of keypoints')
74 |     parser.add_argument('--mobile', action='store_true',  # `type=bool` would treat any non-empty string as True
75 |                         help='use depthwise convolutions in the bottleneck blocks')
76 |     parser.add_argument('--out_onnx', required=True, type=str, metavar='PATH',
77 |                         help='path of the exported ONNX file')
78 |     parser.add_argument('--checkpoint', required=True, type=str, metavar='PATH',
79 |                         help='pre-trained model checkpoint to export')
80 |     parser.add_argument('--in_res', required=True, type=int, metavar='N',
81 |                         help='input resolution (128 or 256)')
82 | main(parser.parse_args())
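83 |
84 | # Example invocation (all paths are placeholders):
85 | #   python tools/mpii_export_to_onxx.py -a hg -s 2 -b 1 --mobile \
86 | #       --checkpoint checkpoint/mpii_hg_s2_b1_mobile/model_best.pth.tar \
87 | #       --in_res 256 --out_onnx mpii_hg_s2_b1_mobile.onnx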
--------------------------------------------------------------------------------