├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── bvh.py ├── finetune.sh ├── inlib ├── __init__.py ├── models.py └── ops.py ├── m2m_config.py ├── model ├── __init__.py ├── base_model.py ├── cl_model.py ├── discriminator_model.py └── gan_model.py ├── pretrain.sh ├── train_gan.py └── utils ├── __init__.py ├── capg_exp_skel.pkl ├── exp_loss.py ├── gan_loss.py ├── plot_loss.py ├── reader.py └── tf_expsdk.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | result_tmp 4 | dataset/* 5 | output -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. 
This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 
117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 
174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/philosophy/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepDance: Music-to-Dance Motion Choreography with Adversarial Learning
2 | This repo contains the training code for our paper on music-to-dance generation: "[DeepDance: Music-to-Dance Motion Choreography with Adversarial Learning](https://ieeexplore.ieee.org/abstract/document/9042236/)". [Project Page](http://zju-capg.org/research_en_music_deepdance.html)
3 |
4 | ## Requirements
5 | - A CUDA-compatible GPU
6 | - Ubuntu >= 14.04
7 |
8 | ## Usage
9 |
10 | Clone this repo and create a new environment with the following commands:
11 | ```
12 | git clone https://github.com/computer-animation-perception-group/DeepDance_train.git
13 | conda create -n music_dance python==3.5
14 | pip install -r requirement.txt
15 | ```
16 | Download the processed training data ([fold_json](https://drive.google.com/file/d/18YhFlqkwU6akfjSBgcmywJu_BtfjmAZz/view?usp=sharing), [motion_feature](https://drive.google.com/file/d/18Hk5jEW8DV_AXzWZcvdLkUvlTiVdZ0Sp/view?usp=sharing) and [music_feature](https://drive.google.com/file/d/1VMt_fhG2livx1keh9o9Vu6zwwZPgB3ZZ/view?usp=sharing)), extract the archives into "./dataset", and run the following scripts:
17 | ```
18 | bash pretrain.sh
19 | bash finetune.sh
20 | ```
21 | Once training has completed, you can generate novel dances with the trained models using our [demo code](https://github.com/computer-animation-perception-group/DeepDance).
22 |
23 | ## License
24 | Licensed under the GPL v3.0 License, for research purposes only.
25 |
26 | ## Bibtex
27 | ```
28 | @article{sun2020deepdance,
29 | author={G. {Sun} and Y. {Wong} and Z. {Cheng} and M. S. {Kankanhalli} and W. {Geng} and X.
{Li}}, 30 | journal={IEEE Transactions on Multimedia}, 31 | title={DeepDance: Music-to-Dance Motion Choreography with Adversarial Learning}, 32 | year={2021}, 33 | volume={23}, 34 | number={}, 35 | pages={497-509},} 36 | ``` 37 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computer-animation-perception-group/DeepDance_train/9f47de0b9f599b2590608465ffb96c9eb0344502/__init__.py -------------------------------------------------------------------------------- /bvh.py: -------------------------------------------------------------------------------- 1 | class Node: 2 | def __init__(self, root=False): 3 | self.name = None 4 | self.channels = [] 5 | self.offset = (0, 0, 0) 6 | self.children = [] 7 | self._is_root = root 8 | self.order = "" 9 | self.pos_idx = [] 10 | self.exp_idx = [] 11 | self.rot_idx = [] 12 | self.quat_idx = [] 13 | self.parent = [] 14 | -------------------------------------------------------------------------------- /finetune.sh: -------------------------------------------------------------------------------- 1 | gpu=0 2 | dis_type='DisSegGraph' 3 | loss_mode='gan' 4 | fold_idx=0 5 | seg_len=90 6 | loss_type=2 7 | if [ $loss_type == 1 ]; then 8 | loss_arr=(1.0 0.05 0.0) 9 | elif [ $loss_type == 2 ]; then 10 | loss_arr=(1.0 0.1 0.1) 11 | else 12 | loss_arr=(1.0 0.0 0.0) 13 | fi 14 | mus_ebd_dim=72 15 | dis_name='time_cond_cnn' 16 | kernel_size=(1 3) 17 | stride=(1 2) 18 | cond_axis=1 19 | model_path=./output/pretrain/all-f4/model/cnn-erd_19_model.ckpt.meta 20 | CUDA_VISIBLE_DEVICES=$gpu \ 21 | python3 train_gan.py --learning_rate 1e-4 \ 22 | --dis_learning_rate 2e-5 \ 23 | --mse_rate 1 \ 24 | --dis_rate 0.01 \ 25 | --loss_mode $loss_mode \ 26 | --is_load_model True \ 27 | --is_reg True \ 28 | --reg_scale 5e-5 \ 29 | --rnn_keep_list 0.95 0.9 1.0\ 30 | --dis_type $dis_type \ 31 | --dis_name $dis_name \ 32 | --loss_rate_list ${loss_arr[0]} ${loss_arr[1]} ${loss_arr[2]}\ 33 | --kernel_size ${kernel_size[0]} ${kernel_size[1]} \ 34 | --stride ${stride[0]} ${stride[1]}\ 35 | --act_type lrelu \ 36 | --optimizer Adam \ 37 | --cond_axis $cond_axis \ 38 | --seg_list $seg_len \ 39 | --seq_shift 1 \ 40 | --gen_hop $seg_len \ 41 | --fold_list $fold_idx \ 42 | --type_list gudianwu \ 43 | --model_path ${model_path%.*} \ 44 | --max_max_epoch 15 \ 45 | --save_data_epoch 5 \ 46 | --save_model_epoch 5 \ 47 | --is_save_train False \ 48 | --mot_scale 100. \ 49 | --norm_way zscore \ 50 | --teacher_forcing_ratio 0. \ 51 | --tf_decay 1. 
\ 52 | --batch_size 128 \
53 | --mus_ebd_dim $mus_ebd_dim \
54 | --has_random_seed False \
55 | --is_all_norm True \
56 | --add_info ./output/finetune
--------------------------------------------------------------------------------
/inlib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computer-animation-perception-group/DeepDance_train/9f47de0b9f599b2590608465ffb96c9eb0344502/inlib/__init__.py
--------------------------------------------------------------------------------
/inlib/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .ops import *
3 | import tensorflow.contrib as tf_contrib
4 |
5 |
6 | def convolution(value, output_num, kernel_size=[3, 3], strides=[1, 1], name='conv', padding='SAME',
7 | activate_type='relu'):
8 | x = conv2d(value, output_num, kernel_size[0], kernel_size[1], [1, strides[0], strides[1], 1], name, padding)
9 | if activate_type == 'relu':
10 | x = xelu(x, 'relu_' + name, activate_type)
11 | else:
12 | x = xelu(x, 'lrelu_' + name, activate_type)
13 | return x
14 |
15 |
16 | class vgg19:
17 | def __init__(self, model_path):
18 | self.data_dict = np.load(model_path).item()
19 | print('load vgg19 weight complete.')
20 |
21 | def get_feature(self, x, reuse=False):
22 | self.channels = x.get_shape()[-1]
23 | with tf.variable_scope('vgg19') as scope:
24 | if reuse:
25 | scope.reuse_variables()
26 | # conv1
27 | self.conv1_1 = convolution(x, 64, [3, 3], [1, 1], name='conv1_1')
28 | self.conv1_2 = convolution(self.conv1_1, 64, [3, 3], [1, 1], name='conv1_2')
29 | self.conv1_2 = pool2D(self.conv1_2, 2, 2, name='max_pool1')
30 |
31 | # conv2
32 | self.conv2_1 = convolution(self.conv1_2, 128, [3, 3], [1, 1], name='conv2_1')
33 | self.conv2_2 = convolution(self.conv2_1, 128, [3, 3], [1, 1], name='conv2_2')
34 | self.conv2_2 = pool2D(self.conv2_2, 2, 2, name='max_pool2')
35 |
36 | # conv3
37 | self.conv3_1 = convolution(self.conv2_2, 256, [3, 3], [1, 1], name='conv3_1')
38 | self.conv3_2 = convolution(self.conv3_1, 256, [3, 3], [1, 1], name='conv3_2')
39 | self.conv3_3 = convolution(self.conv3_2, 256, [3, 3], [1, 1], name='conv3_3')
40 | self.conv3_4 = convolution(self.conv3_3, 256, [3, 3], [1, 1], name='conv3_4')
41 | self.conv3_4 = pool2D(self.conv3_4, 2, 2, name='max_pool3')
42 |
43 | # conv4
44 | self.conv4_1 = convolution(self.conv3_4, 512, [3, 3], [1, 1], name='conv4_1')
45 | self.conv4_2 = convolution(self.conv4_1, 512, [3, 3], [1, 1], name='conv4_2')
46 | self.conv4_3 = convolution(self.conv4_2, 512, [3, 3], [1, 1], name='conv4_3')
47 | self.conv4_4 = convolution(self.conv4_3, 512, [3, 3], [1, 1], name='conv4_4')
48 | self.conv4_4 = pool2D(self.conv4_4, 2, 2, name='max_pool4')
49 |
50 | # conv5
51 | self.conv5_1 = convolution(self.conv4_4, 512, [3, 3], [1, 1], name='conv5_1')
52 | self.conv5_2 = convolution(self.conv5_1, 512, [3, 3], [1, 1], name='conv5_2')
53 | self.conv5_3 = convolution(self.conv5_2, 512, [3, 3], [1, 1], name='conv5_3')
54 | self.conv5_4 = convolution(self.conv5_3, 512, [3, 3], [1, 1], name='conv5_4')
55 |
56 | # flatten
57 | self.feature5_4 = tf.reshape(self.conv5_4, [x.get_shape()[0], -1], name='feature5_4')
58 |
59 | return self.feature5_4
60 |
61 | def load_weights(self, sess):
62 | vars = tf.trainable_variables(scope='vgg19')
63 | loaded_vars = [var for var in vars if 'conv' in var.name]
64 | keys = sorted(self.data_dict)
65 | for i in range(len(keys)):
66 | print(loaded_vars[i * 2], keys[i] +
'_weight') 67 | sess.run(loaded_vars[i * 2].assign(self.data_dict[keys[i]][0])) 68 | print(loaded_vars[i * 2 + 1], keys[i] + '_bias') 69 | sess.run(loaded_vars[i * 2 + 1].assign(self.data_dict[keys[i]][1])) 70 | 71 | 72 | def mlp(x, dim_list, name='mlp', reuse=False): 73 | # dim_list=[[output_num1, activation1],...,[output_numk, activationk]] 74 | with tf.variable_scope(name) as scope: 75 | if reuse: 76 | scope.reuse_variables() 77 | outputs = [x] 78 | for i in range(len(dim_list)): 79 | out = fc(outputs[-1], dim_list[i][0], name='fc' + str(i + 1)) 80 | if i < len(dim_list) - 1: 81 | out = xelu(out, name=dim_list[i][1] + str(i + 1), activate_type=dim_list[i][1]) 82 | outputs.append(out) 83 | return outputs[-1] 84 | 85 | 86 | def cnn(x, conv_list, fc_list, name='cnn', is_training=True, reuse=False): 87 | # conv_list=[[output_num1, kernels1, strides1, padding1, activation1], 88 | # ...[output_numk, kernelsk, stridesk, paddingk, activationk]] 89 | # fc_list=[[output_num1, activation1],...,[output_numk,activationk] 90 | with tf.variable_scope(name) as scope: 91 | if reuse: 92 | scope.reuse_variables() 93 | outputs = [x] 94 | for i in range(len(conv_list)): 95 | out = convolution(outputs[-1], conv_list[i][0], conv_list[i][1], conv_list[i][2], name='conv' + str(i + 1), 96 | padding=conv_list[i][3], activate_type=conv_list[i][4]) 97 | if 'bn' in conv_list[-1]: 98 | out = bn(out, is_training, scope='bn'+str(i+1)) 99 | outputs.append(out) 100 | # print('outputs: ', outputs[-1]) 101 | outputs.append(tf.reshape(outputs[-1], [int(x.get_shape()[0]), -1])) 102 | for i in range(len(fc_list)): 103 | out = fc(outputs[-1], fc_list[i][0], name='fc' + str(i + 1)) 104 | if i < len(fc_list) - 1: 105 | out = xelu(out, name=fc_list[i][1] + str(i + 1), activate_type=fc_list[i][1]) 106 | outputs.append(out) 107 | return outputs[-1] 108 | -------------------------------------------------------------------------------- /inlib/ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.layers import pooling 3 | import tensorflow.contrib as tf_contrib 4 | 5 | 6 | def weight(shape, stddev=0.02, name='weight', trainable=True): 7 | dtype = tf.float32 8 | var = tf.get_variable(name, shape, dtype, initializer=tf.random_normal_initializer(0.0, stddev, dtype=dtype), 9 | trainable=trainable) 10 | return var 11 | 12 | 13 | def bias(dim, bias_start=0.0, name='bias', trainable=True): 14 | dtype = tf.float32 15 | var = tf.get_variable(name, dim, dtype, initializer=tf.constant_initializer(value=bias_start, dtype=dtype), 16 | trainable=trainable) 17 | 18 | return var 19 | 20 | 21 | def xelu(value, name='relu', activate_type='relu', para=0.2): 22 | with tf.variable_scope(name): 23 | if activate_type == 'relu': 24 | # relu 25 | return tf.nn.relu(value) 26 | elif activate_type == 'lrelu': 27 | # leaky relu 28 | return tf.maximum(value, value * para) 29 | else: 30 | return value 31 | 32 | 33 | def pool2D(value, k_h=3, k_w=3, strides=[1, 2, 2, 1], name='max_pool', padding='VALID'): 34 | kernel_size = [1, k_h, k_w, 1] 35 | with tf.variable_scope(name + '_2d'): 36 | if name == 'max_pool': 37 | # max pooling 38 | return tf.nn.max_pool(value, kernel_size, strides, padding) 39 | elif name == 'avg_pool': 40 | # average pooling 41 | return tf.nn.avg_pool(value, kernel_size, strides, padding) 42 | else: 43 | # default: max pooling 44 | return tf.nn.max_pool(value, kernel_size, strides, padding) 45 | 46 | 47 | def pool1D(value, ksize=3, strides=[1, 2, 1], 
name='max_pool', padding='VALID'):
48 | kernel_size = [1, ksize, 1]
49 | with tf.variable_scope(name + '_1d'):
50 | if name == 'max_pool':
51 | # max pooling
52 | return pooling.max_pooling1d(value, kernel_size, strides, padding)
53 | elif name == 'avg_pool':
54 | # average pooling
55 | return pooling.average_pooling1d(value, kernel_size, strides, padding)
56 | else:
57 | # default: max pooling
58 | return pooling.max_pooling1d(value, kernel_size, strides, padding)
59 |
60 |
61 | def fc(value, output_num, name='fc', with_weight=False, with_bias=True):
62 | input_shape = value.get_shape().as_list()
63 | with tf.variable_scope(name):
64 | weights = weight([input_shape[1], output_num])
65 | output = tf.matmul(value, weights)
66 | if with_bias:
67 | biases = bias(output_num)
68 | output = output + biases
69 | if with_weight:
70 | if with_bias:
71 | return output, weights, biases
72 | else:
73 | return output, weights
74 | else:
75 | return output
76 |
77 |
78 | def conv1d(value, output_num, ksize=3, strides=[1, 1, 1], name='conv', padding='SAME', with_weight=False,
79 | with_bias=True):
80 | with tf.variable_scope(name + '_1d'):
81 | weights = weight([ksize, value.get_shape()[-1], output_num])
82 | conv = tf.nn.conv1d(value, weights, strides, padding, use_cudnn_on_gpu=True)
83 | if with_bias:
84 | biases = bias(output_num)
85 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
86 | if with_weight:
87 | if with_bias:
88 | return conv, weights, biases
89 | else:
90 | return conv, weights
91 | else:
92 | return conv
93 |
94 |
95 | def conv2d(value, output_num, k_h=3, k_w=3, strides=[1, 1, 1, 1], name='conv', padding='SAME', with_weight=False,
96 | with_bias=True):
97 | with tf.variable_scope(name + '_2d'):
98 | weights = weight([k_h, k_w, value.get_shape()[-1], output_num])
99 | conv = tf.nn.conv2d(value, weights, strides, padding, use_cudnn_on_gpu=True)
100 | if with_bias:
101 | biases = bias(output_num)
102 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
103 | if with_weight:
104 | if with_bias:
105 | return conv, weights, biases
106 | else:
107 | return conv, weights
108 | else:
109 | return conv
110 |
111 |
112 | def bn(x, is_training, scope='bn'):
113 | return tf.layers.batch_normalization(x,
114 | axis=-1,
115 | training=is_training,
116 | name=scope)
117 |
--------------------------------------------------------------------------------
/m2m_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 |
5 |
6 | class CAPGConfig(object):
7 | """capg config"""
8 | def __init__(self, m_type, fold_num, seg_len):
9 | self.mot_data_dir = './dataset/motion_feature/exp/'
10 | self.mus_data_dir = './dataset/music_feature/librosa/'
11 | self.json_dir = './dataset/fold_json/'
12 | self.all_json_path = os.path.join(self.json_dir, 'all-f4', fold_num, 'train_list.json')
13 | self.train_json_path = os.path.join(self.json_dir, m_type, fold_num, 'train_list.json')
14 | self.test_json_path = os.path.join(self.json_dir, m_type, fold_num, 'test_list.json')
15 | self.hidden_size = 512
16 | self.mot_hidden_size = 1024
17 |
18 | self.is_save_model = True
19 | self.is_load_model = False
20 | self.save_epoch = 0
21 | self.test_epoch = 5
22 | self.gen_hop = 10
23 | self.seq_shift = 1 # 15 for beat
24 | self.use_mus_rnn = True
25 | self.mus_rnn_layers = 1
26 | self.max_max_epoch = 15
27 | self.is_reg = False
28 | self.reg_scale = 5e-4
29 | self.rnn_keep_prob = 1
30 |
31 | self.is_shuffle = True
32 |
self.has_random_seed = False 33 | 34 | self.is_align = True 35 | self.mus_delay = 0 # 1 for beat 36 | self.mot_ignore_dims = [18, 19, 20, 33, 34, 35, 48, 49, 50, 60, 61, 62, 72, 73, 74] 37 | self.mot_dim = 60 38 | 39 | self.is_z_score = True 40 | self.is_all_norm = False 41 | self.mus_dim = 201 42 | self.mus_kernel_size = 51 43 | self.batch_size = 32 44 | self.num_steps = int(seg_len) 45 | self.test_num_steps = int(seg_len) 46 | self.max_epoch = 20 47 | self.lr_decay = 1 48 | self.max_grad_norm = 25 49 | self.val_data_len = 150 50 | 51 | self.rnn_layers = 3 52 | self.mot_rnn_layers = 2 53 | 54 | self.info = "gan_gt" 55 | self.val_batch_size = 1 56 | self.test_batch_size = 1 57 | self.is_use_pre_mot = False 58 | 59 | self.use_noise = False 60 | self.noise_schedule = ['2:0.05', '6:0.1', '12:0.2', '16:0.3', '22:0.5', '30:0.8', '36:1.0'] 61 | self.start_idx = 1 62 | 63 | def save_config(self, path): 64 | config_dict = dict() 65 | for name, value in vars(self).items(): 66 | if isinstance(value, list): 67 | value = np.asarray(value).tolist() 68 | config_dict[name] = value 69 | json.dump(config_dict, open(path, 'w'), indent=4, sort_keys=True) 70 | 71 | 72 | def get_config(m_type, fold_num, seg_len): 73 | config = CAPGConfig(m_type, fold_num, seg_len) 74 | return config 75 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computer-animation-perception-group/DeepDance_train/9f47de0b9f599b2590608465ffb96c9eb0344502/model/__init__.py -------------------------------------------------------------------------------- /model/base_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tf_contrib 3 | 4 | 5 | def data_type(): 6 | return tf.float32 7 | 8 | 9 | class BaseModel(object): 10 | """The Base model.""" 11 | def __init__(self, train_type, config): 12 | if train_type == 0: 13 | print("---init training graph---") 14 | is_training = True 15 | elif train_type == 1: 16 | print('---init validate graph---') 17 | is_training = False 18 | else: 19 | print("---init test graph---") 20 | is_training = False 21 | 22 | self.is_training = is_training 23 | config.is_training = is_training 24 | self.mot_dim = config.mot_dim 25 | self.mus_ebd_dim = config.mus_ebd_dim 26 | self.batch_size = config.batch_size 27 | self.num_steps = config.num_steps 28 | self.input_x = tf.placeholder(shape=[self.batch_size, self.num_steps, None], 29 | dtype=data_type(), name="input_x") 30 | # mot_input 31 | self.input_y = tf.placeholder(shape=[self.batch_size, self.num_steps, self.mot_dim], 32 | dtype=data_type(), name="input_y") 33 | 34 | self.init_step_mot = tf.placeholder(data_type(), [self.batch_size, self.mot_dim], name="init_step_mot") 35 | self.tf_mask = tf.placeholder(shape=[self.num_steps], dtype=tf.bool, name='tf_mask') 36 | 37 | mot_predictions, mus_ebd_outputs, mot_state, mus_state = \ 38 | self._build_mot_rnn_graph(mus_inputs=self.input_x, 39 | config=config, 40 | train_type=train_type) 41 | 42 | self.mot_final_state = mot_state 43 | self.mus_final_state = mus_state 44 | self.mot_predictions = mot_predictions 45 | self.mot_truth = self.input_y 46 | self.mus_ebd_outputs = mus_ebd_outputs 47 | 48 | def _build_mus_graph(self, time_step, mus_cell, mus_state, inputs, config, is_training): 49 | print("mus_graph") 50 | # outputs = [] 51 | with tf.variable_scope("mus_rnn"): 
52 | fc_weights = tf.get_variable('fc', [config.hidden_size, self.mus_ebd_dim], 53 | initializer=tf.truncated_normal_initializer()) 54 | fc_biases = tf.get_variable('bias', [self.mus_ebd_dim], 55 | initializer=tf.zeros_initializer()) 56 | if time_step > 0: 57 | tf.get_variable_scope().reuse_variables() 58 | 59 | mus_input = self._build_mus_conv_graph(inputs, config, is_training) 60 | (cell_output, mus_state) = mus_cell(mus_input, mus_state) 61 | # outputs.append(cell_output) 62 | output = tf.reshape(cell_output, [-1, config.hidden_size]) 63 | fc_output = tf.nn.xw_plus_b(output, fc_weights, fc_biases) 64 | 65 | return fc_output, mus_state 66 | 67 | @staticmethod 68 | def _get_lstm_cell(rnn_layer_idx, hidden_size, config, is_training): 69 | lstm_cell = tf_contrib.rnn.BasicLSTMCell( 70 | hidden_size, forget_bias=0.0, state_is_tuple=True, 71 | reuse=tf.get_variable_scope().reuse) 72 | print('rnn_layer: ', rnn_layer_idx, config.rnn_keep_list[rnn_layer_idx]) 73 | if is_training and config.rnn_keep_list[rnn_layer_idx] < 1: 74 | lstm_cell = tf_contrib.rnn.DropoutWrapper(lstm_cell, 75 | output_keep_prob=config.rnn_keep_list[rnn_layer_idx]) 76 | return lstm_cell 77 | 78 | def _build_mot_rnn_graph(self, mus_inputs, config, train_type): 79 | if train_type == 0: 80 | is_training = True 81 | else: 82 | is_training = False 83 | 84 | rnn_layer_idx = 0 85 | mus_cell = tf_contrib.rnn.MultiRNNCell( 86 | [self._get_lstm_cell(i, config.hidden_size, config, is_training) 87 | for i in range(rnn_layer_idx, rnn_layer_idx + config.mus_rnn_layers)], state_is_tuple=True) 88 | 89 | rnn_layer_idx += config.mus_rnn_layers 90 | mot_cell = tf_contrib.rnn.MultiRNNCell( 91 | [self._get_lstm_cell(i, config.mot_hidden_size, config, is_training) 92 | for i in range(rnn_layer_idx, rnn_layer_idx + config.mot_rnn_layers)], state_is_tuple=True) 93 | 94 | self.mot_initial_state = mot_cell.zero_state(config.batch_size, data_type()) 95 | mot_state = self.mot_initial_state 96 | 97 | self.mus_initial_state = mus_cell.zero_state(config.batch_size, data_type()) 98 | mus_state = self.mus_initial_state 99 | 100 | last_step_mot = self.init_step_mot 101 | outputs = [] 102 | mus_ebd_outputs = [] 103 | 104 | with tf.variable_scope("generator/mot_rnn"): 105 | 106 | for time_step in range(self.num_steps): 107 | if time_step > 0: 108 | tf.get_variable_scope().reuse_variables() 109 | last_step_mot = tf.cond(tf.equal(self.tf_mask[time_step], tf.constant(True)), 110 | lambda: self.input_y[:, time_step-1, :], 111 | lambda: self.last_step_mot) 112 | 113 | mus_input = mus_inputs[:, time_step, :] 114 | print("mot_rnn: ", time_step) 115 | if not config.use_mus_rnn: 116 | with tf.variable_scope("mus_rnn"): 117 | mus_fea = self._build_mus_conv_graph(mus_input, config, is_training) 118 | else: 119 | mus_fea, mus_state = self._build_mus_graph(time_step, mus_cell, 120 | mus_state, mus_input, config, is_training) 121 | mot_input = last_step_mot 122 | mot_input = tf.reshape(mot_input, [-1, self.mot_dim]) 123 | # mus_fea = tf.zeros(tf.shape(mus_fea)) 124 | all_input = tf.concat([mus_fea, mot_input], 1, name='mus_mot_input') 125 | 126 | # fc1 127 | fc1_weights = tf.get_variable('fc1', [self.mus_ebd_dim + self.mot_dim, 500], dtype=data_type()) 128 | fc1_biases = tf.get_variable('bias1', [500], dtype=data_type()) 129 | fc1_linear = tf.nn.xw_plus_b(all_input, fc1_weights, fc1_biases, name='fc1_linear') 130 | fc1_relu = tf.nn.relu(fc1_linear, name='fc1_relu') 131 | 132 | # fc2 133 | fc2_weights = tf.get_variable('fc2', [500, 500], dtype=data_type()) 134 | fc2_biases 
= tf.get_variable('bias2', [500], dtype=data_type()) 135 | fc2_linear = tf.nn.xw_plus_b(fc1_relu, fc2_weights, fc2_biases, name='fc2_linear') 136 | 137 | (cell_output, mot_state) = mot_cell(fc2_linear, mot_state) 138 | output = tf.reshape(cell_output, [-1, config.mot_hidden_size]) 139 | 140 | # fc3 141 | fc3_weights = tf.get_variable('fc3', [config.mot_hidden_size, 500], dtype=data_type()) 142 | fc3_biases = tf.get_variable('bias3', [500], dtype=data_type()) 143 | fc3_linear = tf.nn.xw_plus_b(output, fc3_weights, fc3_biases, name='fc3_linear') 144 | fc3_relu = tf.nn.relu(fc3_linear, name='fc3_relu') 145 | 146 | fc4_weights = tf.get_variable('fc4', [500, 100], dtype=data_type()) 147 | fc4_biases = tf.get_variable('bias4', [100], dtype=data_type()) 148 | fc4_linear = tf.nn.xw_plus_b(fc3_relu, fc4_weights, fc4_biases, name='fc4_linear') 149 | fc4_relu = tf.nn.relu(fc4_linear, name='fc4_relu') 150 | 151 | fc5_weights = tf.get_variable('fc5', [100, self.mot_dim], dtype=data_type()) 152 | fc5_biases = tf.get_variable('bias5', [self.mot_dim], dtype=data_type()) 153 | fc5_linear = tf.nn.xw_plus_b(fc4_relu, fc5_weights, fc5_biases, name='fc5_linear') 154 | self.last_step_mot = fc5_linear 155 | 156 | outputs.append(fc5_linear) 157 | mus_ebd_outputs.append(mus_fea) 158 | 159 | outputs = tf.reshape(tf.concat(outputs, 1), [self.batch_size, self.num_steps, self.mot_dim]) 160 | mus_ebd_outputs = tf.reshape(tf.concat(mus_ebd_outputs, 1), [self.batch_size, self.num_steps, self.mus_ebd_dim]) 161 | 162 | return outputs, mus_ebd_outputs, mot_state, mus_state 163 | 164 | @staticmethod 165 | def _mus_conv(inputs, kernel_shape, bias_shape, is_training): 166 | conv_weights = tf.get_variable('conv', kernel_shape, 167 | initializer=tf.truncated_normal_initializer()) 168 | # tf.summary.histogram("conv weights", conv_weights) 169 | conv_biases = tf.get_variable('bias', bias_shape, 170 | initializer=tf.zeros_initializer()) 171 | conv = tf.nn.conv2d(inputs, 172 | conv_weights, 173 | strides=[1, 1, 1, 1], 174 | padding='VALID') 175 | bias = tf.nn.bias_add(conv, conv_biases) 176 | norm = tf.layers.batch_normalization(bias, axis=3, 177 | training=is_training) 178 | 179 | elu = tf.nn.elu(norm) 180 | return elu 181 | 182 | def _build_mus_conv_graph(self, inputs, config, is_training): 183 | """Build music graph""" 184 | 185 | print("mus_conv_graph") 186 | mus_dim = config.mus_dim 187 | mus_input = tf.reshape(inputs, [-1, mus_dim, 5, 1]) 188 | 189 | with tf.variable_scope('conv1'): 190 | elu1 = self._mus_conv(mus_input, 191 | kernel_shape=[mus_dim, 2, 1, 64], 192 | bias_shape=[64], 193 | is_training=is_training) 194 | with tf.variable_scope('conv2'): 195 | elu2 = self._mus_conv(elu1, 196 | kernel_shape=[1, 2, 64, 128], 197 | bias_shape=[128], 198 | is_training=is_training) 199 | 200 | with tf.variable_scope('conv3'): 201 | elu3 = self._mus_conv(elu2, 202 | kernel_shape=[1, 2, 128, 256], 203 | bias_shape=[256], 204 | is_training=is_training) 205 | 206 | with tf.variable_scope('conv4'): 207 | elu4 = self._mus_conv(elu3, 208 | kernel_shape=[1, 2, 256, 512], 209 | bias_shape=[512], 210 | is_training=is_training) 211 | mus_conv_output = tf.reshape(elu4, [-1, 512]) 212 | return mus_conv_output 213 | -------------------------------------------------------------------------------- /model/cl_model.py: -------------------------------------------------------------------------------- 1 | from model.base_model import * 2 | from utils import exp_loss as es 3 | 4 | 5 | class CLModel(BaseModel): 6 | def __init__(self, train_type, config): 7 
| super(CLModel, self).__init__(train_type, config) 8 | 9 | mot_predictions = self.mot_predictions 10 | mot_truth = self.mot_truth 11 | 12 | tru_pos, pre_pos = es.get_pos_chls(mot_predictions, mot_truth, config) 13 | 14 | # generator loss 15 | g_loss, loss_list = es.loss_impl(mot_predictions, mot_truth, pre_pos, tru_pos, config) 16 | self.g_loss = loss_list 17 | 18 | # if test, return 19 | if not self.is_training: 20 | return 21 | 22 | tvars = tf.trainable_variables() 23 | g_vars = [v for v in tvars if 'generator' in v.name] 24 | 25 | # add reg 26 | if config.is_reg: 27 | reg_cost = tf.reduce_sum([tf.nn.l2_loss(v) for v in g_vars 28 | if 'bias' not in v.name]) * config.reg_scale 29 | g_loss = g_loss + reg_cost 30 | 31 | gen_learning_rate = config.learning_rate 32 | 33 | if config.optimizer.lower() == 'adam': 34 | print('Adam optimizer') 35 | g_optimizer = tf.train.AdamOptimizer(learning_rate=gen_learning_rate) 36 | else: 37 | print('Rmsprop optimizer') 38 | g_optimizer = tf.train.RMSPropOptimizer(learning_rate=gen_learning_rate) 39 | 40 | # for batch_norm op 41 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 42 | g_grads = tf.gradients(g_loss, g_vars, aggregation_method=2) 43 | with tf.control_dependencies(update_ops): 44 | self.train_g_op = g_optimizer.apply_gradients(zip(g_grads, g_vars)) 45 | print('train_g_op') 46 | -------------------------------------------------------------------------------- /model/discriminator_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import inlib.models as md 3 | import numpy as np 4 | 5 | 6 | class DisGraph(object): 7 | def __init__(self, inputs, cond_inputs, config, name, is_reuse): 8 | self.inputs = inputs 9 | self.cond_inputs = cond_inputs 10 | self.name = name 11 | self.is_reuse = is_reuse 12 | self.act_type = config.act_type 13 | self.kernel_size = config.kernel_size 14 | self.cond_axis = config.cond_axis 15 | self.stride = config.stride 16 | self.mus_ebd_dim = config.mus_ebd_dim 17 | self.batch_size = config.batch_size 18 | self.num_steps = config.num_steps 19 | self.is_training = config.is_training 20 | self.is_shuffle = config.is_shuffle 21 | 22 | def build_dis_graph(self): 23 | if self.name == 'mlp': 24 | outputs = self._build_dis_mlp_graph() 25 | elif self.name == 'cnn': 26 | outputs = self._build_dis_cnn_graph() 27 | elif self.name == 'sig_cnn': 28 | outputs = self._build_dis_sig_cnn_graph() 29 | elif self.name == 'cond_cnn': 30 | outputs = self._build_dis_cond_cnn_graph() 31 | elif self.name == 'time_cond_cnn': 32 | outputs = self._build_dis_time_cond_cnn_graph() 33 | elif self.name == 'tgan_cond_cnn': 34 | outputs = self._build_dis_tgan_cond_cnn_graph() 35 | elif self.name == 'time_tgan_cond_cnn': 36 | outputs = self._build_dis_time_tgan_cond_cnn_graph() 37 | else: 38 | raise ValueError('Not valid discriminator name') 39 | 40 | return outputs 41 | 42 | def _build_dis_mlp_graph(self): 43 | return [] 44 | 45 | def _build_dis_cnn_graph(self): 46 | return [] 47 | 48 | def _build_dis_cond_cnn_graph(self): 49 | return [] 50 | 51 | def _build_dis_time_cond_cnn_graph(self): 52 | return [] 53 | 54 | def _build_dis_tgan_cond_cnn_graph(self): 55 | return [] 56 | 57 | def _build_dis_time_tgan_cond_cnn_graph(self): 58 | return [] 59 | 60 | def _build_dis_sig_cnn_graph(self): 61 | return [] 62 | 63 | 64 | class DisFrameGraph(DisGraph): 65 | def __init__(self, inputs, cond_inputs, config, name='cnn', is_reuse=False): 66 | super(DisFrameGraph, self).__init__(inputs, 
cond_inputs, config, name, is_reuse) 67 | 68 | def _build_dis_mlp_graph(self): 69 | fc_list_d = [[100, self.act_type], [256, self.act_type], [500, self.act_type], [1, '']] 70 | # [batch_size*num_steps, mus_ebd_dim] 71 | mot_input = tf.reshape(self.inputs, [-1, 60]) 72 | outputs = md.mlp(mot_input, fc_list_d, 'discriminator', reuse=self.is_reuse) 73 | return outputs 74 | 75 | def _build_dis_cnn_graph(self): 76 | print('frame_cnn_graph') 77 | mot_input = tf.reshape(self.inputs, [-1, 20, 1, 3]) 78 | conv_list_d = [[64, self.kernel_size, self.stride, 'SAME', self.act_type], 79 | [128, self.kernel_size, self.stride, 'SAME', self.act_type]] 80 | fc_list_d = [[1, '']] 81 | outputs = md.cnn(mot_input, conv_list_d, fc_list_d, name='discriminator', reuse=self.is_reuse) 82 | return outputs 83 | 84 | def _build_dis_cond_cnn_graph(self): 85 | print('frame_cond_cnn_graph') 86 | cond_input = tf.reshape(self.cond_inputs, [self.batch_size, self.num_steps, self.mus_ebd_dim, 1]) 87 | mot_input = tf.reshape(self.inputs, [self.batch_size, self.num_steps, self.mus_ebd_dim, 1]) 88 | # bs * mus_ebd_dim * num_steps * 1 89 | mot_input = tf.transpose(mot_input, [0, 2, 1, 3]) 90 | cond_input = tf.transpose(cond_input, [0, 2, 1, 3]) 91 | all_input = tf.concat([mot_input, cond_input], axis=self.cond_axis, name='concat_cond') 92 | 93 | [batch_size, m_dim, num_steps, chl] = all_input.get_shape() 94 | all_input = tf.transpose(all_input, [0, 2, 1, 3]) 95 | all_input = tf.reshape(all_input, [int(batch_size)*int(num_steps), int(m_dim), 1, int(chl)]) 96 | print('all_input: ', all_input) 97 | 98 | conv_list_d = [[64, self.kernel_size, self.stride, 'SAME', self.act_type], 99 | [128, self.kernel_size, self.stride, 'SAME', self.act_type]] 100 | fc_list_d = [[1, '']] 101 | outputs = md.cnn(mot_input, conv_list_d, fc_list_d, name='discriminator', reuse=self.is_reuse) 102 | return outputs 103 | 104 | def _build_dis_sig_cnn_graph(self): 105 | print('frame_sig_cnn_graph') 106 | inputs = tf.reshape(self.inputs, [-1, 20, 1, 3]) 107 | idx_lists = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 108 | 18, 19, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 2, 4, 6, 8, 10, 109 | 12, 14, 16, 18, 1, 4, 7, 10, 13, 16, 19, 3, 6, 9, 12, 15, 18, 110 | 2, 5, 8, 11, 14, 17, 1, 5, 9, 13, 17, 2, 6, 10, 14, 18, 3, 111 | 7, 11, 15, 19, 4, 8, 12, 16, 1, 6, 11, 16, 2, 7, 12, 17, 3, 112 | 8, 13, 18, 4, 9, 14, 19, 5, 10, 15, 1, 7, 13, 19, 6, 12, 18, 113 | 5, 11, 17, 4, 10, 16, 3, 9, 15, 2, 8, 14, 1, 8, 15, 3, 10, 114 | 17, 5, 12, 19, 7, 14, 2, 9, 16, 4, 11, 18, 6, 13, 1, 9, 17, 115 | 6, 14, 3, 11, 19, 8, 16, 5, 13, 2, 10, 18, 7, 15, 4, 12, 1, 116 | 10, 19, 9, 18, 8, 17, 7, 16, 6, 15, 5, 14, 4, 13, 3, 12, 2, 117 | 11, 1] 118 | 119 | # TODO: need to check 120 | mot_input = [] 121 | for i, idx in enumerate(idx_lists): 122 | mot_input.append(inputs[:, idx, :, :]) 123 | mot_input = tf.reshape(tf.concat(mot_input, axis=1), [-1, 173, 1, 3]) 124 | # [3, 1], [2, 1] 125 | conv_list_d = [[64, self.kernel_size, self.stride, 'SAME', self.act_type], 126 | [128, self.kernel_size, self.stride, 'SAME', self.act_type], 127 | [256, self.kernel_size, self.stride, 'SAME', self.act_type], 128 | [512, self.kernel_size, self.stride, 'SAME', self.act_type]] 129 | 130 | fc_list_d = [[1, '']] 131 | outputs = md.cnn(mot_input, conv_list_d, fc_list_d, name='discriminator', reuse=self.is_reuse) 132 | return outputs 133 | 134 | 135 | class DisSegGraph(DisGraph): 136 | def __init__(self, inputs, cond_inputs, config, name='mlp', is_reuse=False): 137 | super(DisSegGraph, 
self).__init__(inputs, cond_inputs, config, name, is_reuse) 138 | 139 | def _build_dis_mlp_graph(self): 140 | outputs = [] 141 | return outputs 142 | 143 | def _build_dis_cnn_graph(self): 144 | print('seg_cnn_graph') 145 | mot_input = tf.reshape(self.inputs, [self.batch_size, self.num_steps, 20, 3]) 146 | # bs * 20 * num_steps * 3 147 | tf.transpose(mot_input, [0, 2, 1, 3]) 148 | # [3, 3] [2, 2] 149 | conv_list_d = [[64, self.kernel_size, self.stride, 'SAME', self.act_type], 150 | [128, self.kernel_size, self.stride, 'SAME', self.act_type], 151 | [256, self.kernel_size, self.stride, 'SAME', self.act_type], 152 | [512, self.kernel_size, self.stride, 'SAME', self.act_type]] 153 | fc_list_d = [[1, '']] 154 | outputs = md.cnn(mot_input, conv_list_d, fc_list_d, name='discriminator', reuse=self.is_reuse) 155 | return outputs 156 | 157 | def _build_dis_cond_cnn_graph(self): 158 | print('seg_cond_cnn_graph') 159 | cond_input = tf.reshape(self.cond_inputs, [self.batch_size, self.num_steps, self.mus_ebd_dim, 1]) 160 | mot_input = tf.reshape(self.inputs, [self.batch_size, self.num_steps, self.mus_ebd_dim, 1]) 161 | # bs * mus_ebd_dim * num_steps * 1 162 | # cond_input = tf.transpose(cond_input, [0, 2, 1, 3]) 163 | # mot_input = tf.transpose(mot_input, [0, 2, 1, 3]) 164 | all_input = tf.concat([mot_input, cond_input], axis=self.cond_axis, name='concat_cond') 165 | if self.is_shuffle: 166 | original_shape = all_input.get_shape().as_list() 167 | np.random.seed(1234567890) 168 | shuffle_list = list(np.random.permutation(original_shape[0])) 169 | all_inputs = [] 170 | for i, idx in enumerate(shuffle_list): 171 | all_inputs.append(all_input[idx:idx+1, :, :, :]) 172 | all_input = tf.concat(all_inputs, axis=0) 173 | print('all_input: ', all_input) 174 | # [3, 3] [2, 2] 175 | conv_list_d = [[32, self.kernel_size, self.stride, 'SAME', self.act_type], 176 | [64, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'], 177 | [128, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'], 178 | [256, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn']] 179 | fc_list_d = [[1, '']] 180 | outputs = md.cnn(all_input, conv_list_d, fc_list_d, name='discriminator', reuse=self.is_reuse) 181 | return outputs 182 | 183 | def _build_dis_time_cond_cnn_graph(self): 184 | print('seg_time_cond_cnn_graph') 185 | # bs * 1 * num_steps * 72 186 | cond_input = tf.reshape(self.cond_inputs, [self.batch_size, 1, self.num_steps, self.mus_ebd_dim]) 187 | mot_input = tf.reshape(self.inputs, [self.batch_size, 1, self.num_steps, self.mus_ebd_dim]) 188 | 189 | all_input = tf.concat([mot_input, cond_input], axis=self.cond_axis, name='concat_cond') 190 | if self.is_shuffle: 191 | original_shape = all_input.get_shape().as_list() 192 | np.random.seed(1234567890) 193 | shuffle_list = list(np.random.permutation(original_shape[0])) 194 | all_inputs = [] 195 | for i, idx in enumerate(shuffle_list): 196 | all_inputs.append(all_input[idx:idx+1, :, :, :]) 197 | all_input = tf.concat(all_inputs, axis=0) 198 | print('all_input: ', all_input) 199 | # [1, 3] [1, 2] 200 | conv_list_d = [[32, self.kernel_size, self.stride, 'SAME', self.act_type], 201 | [64, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'], 202 | [128, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'], 203 | [256, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn']] 204 | fc_list_d = [[1, '']] 205 | outputs = md.cnn(all_input, conv_list_d, fc_list_d, name='discriminator', reuse=self.is_reuse) 206 | return outputs 207 | 208 | def 
_build_dis_tgan_cond_cnn_graph(self): 209 | print('tgan_cond_cnn_graph') 210 | cond_input = tf.reshape(self.cond_inputs, [self.batch_size, self.num_steps, self.mus_ebd_dim, 1]) 211 | mot_input = tf.reshape(self.inputs, [self.batch_size, self.num_steps, self.mus_ebd_dim, 1]) 212 | all_input = tf.concat([mot_input, cond_input], axis=self.cond_axis, name='concat_cond') 213 | if self.is_shuffle: 214 | print('shuffle') 215 | original_shape = all_input.get_shape().as_list() 216 | np.random.seed(1234567890) 217 | shuffle_list = list(np.random.permutation(original_shape[0])) 218 | all_inputs = [] 219 | for i, idx in enumerate(shuffle_list): 220 | all_inputs.append(all_input[idx:idx+1, :, :, :]) 221 | all_input = tf.concat(all_inputs, axis=0) 222 | print('all_input: ', all_input) 223 | # [3, 3] [2, 2] 224 | conv_list_d = [[32, self.kernel_size, self.stride, 'SAME', self.act_type], 225 | [64, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'], 226 | [128, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'], 227 | [256, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn']] 228 | outputs = md.cnn(all_input, conv_list_d, [], name='discriminator', 229 | is_training=self.is_training, reuse=self.is_reuse) 230 | return outputs 231 | 232 | def _build_dis_time_tgan_cond_cnn_graph(self): 233 | print('time_tgan_cond_cnn_graph') 234 | # bs * 1 * num_steps * 72 235 | cond_input = tf.reshape(self.cond_inputs, [self.batch_size, 1, self.num_steps, self.mus_ebd_dim]) 236 | mot_input = tf.reshape(self.inputs, [self.batch_size, 1, self.num_steps, self.mus_ebd_dim]) 237 | 238 | all_input = tf.concat([mot_input, cond_input], axis=self.cond_axis, name='concat_cond') 239 | if self.is_shuffle: 240 | original_shape = all_input.get_shape().as_list() 241 | np.random.seed(1234567890) 242 | shuffle_list = list(np.random.permutation(original_shape[0])) 243 | all_inputs = [] 244 | for i, idx in enumerate(shuffle_list): 245 | all_inputs.append(all_input[idx:idx+1, :, :, :]) 246 | all_input = tf.concat(all_inputs, axis=0) 247 | print('all_input: ', all_input) 248 | # [1, 3] [1, 2] 249 | conv_list_d = [[32, self.kernel_size, self.stride, 'SAME', self.act_type], 250 | [64, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'], 251 | [128, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'], 252 | [256, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn']] 253 | outputs = md.cnn(all_input, conv_list_d, [], name='discriminator', reuse=self.is_reuse) 254 | return outputs 255 | 256 | 257 | -------------------------------------------------------------------------------- /model/gan_model.py: -------------------------------------------------------------------------------- 1 | from model.base_model import * 2 | from model import discriminator_model as dm 3 | from utils import exp_loss as es, gan_loss as gls 4 | 5 | 6 | class GanModel(BaseModel): 7 | """The Generative adversarial model""" 8 | def __init__(self, train_type, config): 9 | super(GanModel, self).__init__(train_type, config) 10 | 11 | mot_predictions = self.mot_predictions 12 | mot_truth = self.mot_truth 13 | mus_ebd_outputs = self.mus_ebd_outputs 14 | 15 | tru_pos, pre_pos = es.get_pos_chls(mot_predictions, mot_truth, config) 16 | 17 | dis_name = config.dis_name 18 | dis_graph = getattr(dm, config.dis_type) 19 | 20 | if self.mus_ebd_dim == 60: 21 | real_data = mot_truth 22 | fake_data = mot_predictions 23 | elif self.mus_ebd_dim == 72: 24 | real_data = tru_pos 25 | fake_data = pre_pos 26 | else: 27 | real_data = tf.concat([mot_truth, 
tru_pos], axis=-1) 28 | fake_data = tf.concat([mot_predictions, pre_pos], axis=-1) 29 | 30 | print('real_data:', real_data) 31 | print('fake_data:', fake_data) 32 | 33 | g_sig_loss, d_loss, clip_d_weights = \ 34 | gls.gan_loss(dis_graph, dis_name, real_data=real_data, fake_data=fake_data, 35 | cond_inputs=mus_ebd_outputs, config=config) 36 | 37 | # generator loss 38 | g_loss, loss_list = es.loss_impl(mot_predictions, mot_truth, pre_pos, tru_pos, config) 39 | # g_mse_loss = tf.reduce_mean(tf.squared_difference(mot_predictions, mot_truth), 40 | # name='mean_square_loss') 41 | g_loss = config.mse_rate * g_loss + config.dis_rate * g_sig_loss 42 | self.g_loss = [loss_list, g_sig_loss] 43 | self.d_loss = d_loss 44 | 45 | # if test, return 46 | if not self.is_training: 47 | return 48 | 49 | tvars = tf.trainable_variables() 50 | d_vars = [v for v in tvars if 'discriminator' in v.name] 51 | g_vars = [v for v in tvars if 'generator' in v.name] 52 | 53 | # add reg 54 | if config.is_reg: 55 | reg_cost = tf.reduce_sum([tf.nn.l2_loss(v) for v in g_vars 56 | if 'bias' not in v.name]) * config.reg_scale 57 | g_loss = g_loss + reg_cost 58 | 59 | gen_learning_rate = config.learning_rate 60 | dis_learning_rate = config.dis_learning_rate 61 | 62 | if config.optimizer.lower() == 'adam': 63 | print('Adam optimizer') 64 | g_optimizer = tf.train.AdamOptimizer(learning_rate=gen_learning_rate) 65 | d_optimizer = tf.train.AdamOptimizer(learning_rate=dis_learning_rate) 66 | else: 67 | print('Rmsprop optimizer') 68 | g_optimizer = tf.train.RMSPropOptimizer(learning_rate=gen_learning_rate) 69 | d_optimizer = tf.train.RMSPropOptimizer(learning_rate=dis_learning_rate) 70 | 71 | # for batch_norm op 72 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 73 | g_grads = tf.gradients(g_loss, g_vars, aggregation_method=2) 74 | d_grads = tf.gradients(d_loss, d_vars, aggregation_method=2) 75 | with tf.control_dependencies(update_ops): 76 | self.train_g_op = g_optimizer.apply_gradients(zip(g_grads, g_vars)) 77 | print('train_g_op') 78 | 79 | if clip_d_weights: 80 | with tf.control_dependencies([clip_d_weights, update_ops]): 81 | self.train_d_op = d_optimizer.apply_gradients(zip(d_grads, d_vars)) 82 | # self._train_d_op = optimizer.minimize(d_loss, var_list=d_vars) 83 | else: 84 | with tf.control_dependencies(update_ops): 85 | self.train_d_op = d_optimizer.apply_gradients(zip(d_grads, d_vars)) 86 | print('train_d_op') -------------------------------------------------------------------------------- /pretrain.sh: -------------------------------------------------------------------------------- 1 | gpu=0 2 | dis_type='DisSegGraph' 3 | loss_mode='gan' 4 | seg_len=90 5 | loss_type=2 6 | if [ $loss_type == 1 ]; then 7 | loss_arr=(1.0 0.1 0.0) 8 | elif [ $loss_type == 2 ]; then 9 | loss_arr=(1.0 0.1 0.1) 10 | else 11 | loss_arr=(1.0 0.0 0.0) 12 | fi 13 | mus_ebd_dim=72 14 | dis_name='time_cond_cnn' 15 | kernel_size=(1 3) 16 | stride=(1 2) 17 | cond_axis=1 18 | CUDA_VISIBLE_DEVICES=$gpu \ 19 | python3 train_gan.py --learning_rate 1e-4 \ 20 | --dis_learning_rate 2e-5 \ 21 | --mse_rate 1 \ 22 | --dis_rate 0.01 \ 23 | --loss_mode $loss_mode \ 24 | --is_load_model False \ 25 | --is_reg False \ 26 | --reg_scale 5e-5 \ 27 | --rnn_keep_list 1.0 1.0 1.0\ 28 | --dis_type $dis_type \ 29 | --dis_name $dis_name \ 30 | --loss_rate_list ${loss_arr[0]} ${loss_arr[1]} ${loss_arr[2]}\ 31 | --kernel_size ${kernel_size[0]} ${kernel_size[1]} \ 32 | --stride ${stride[0]} ${stride[1]}\ 33 | --act_type lrelu \ 34 | --optimizer Adam \ 35 | --cond_axis 
$cond_axis \ 36 | --seg_list $seg_len \ 37 | --seq_shift 1 \ 38 | --gen_hop $seg_len \ 39 | --fold_list 0 \ 40 | --type_list all-f4 \ 41 | --model_path '' \ 42 | --max_max_epoch 20 \ 43 | --save_data_epoch 5 \ 44 | --save_model_epoch 5 \ 45 | --is_save_train False \ 46 | --mot_scale 100. \ 47 | --norm_way zscore \ 48 | --teacher_forcing_ratio 0. \ 49 | --tf_decay 1. \ 50 | --batch_size 128 \ 51 | --mus_ebd_dim $mus_ebd_dim \ 52 | --has_random_seed False \ 53 | --is_all_norm False \ 54 | --add_info ./output/pretrain -------------------------------------------------------------------------------- /train_gan.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import json 4 | import os 5 | import time 6 | from collections import OrderedDict 7 | from datetime import datetime 8 | 9 | import numpy as np 10 | from model.gan_model import * 11 | 12 | import m2m_config as cfg 13 | from utils import reader as rd 14 | 15 | 16 | def run_epoch(session, model, data_info, config, teacher_forcing_ratio, 17 | path=None, train_type=0, verbose=False, epoch=0): 18 | """Runs the model on the given data""" 19 | d_losses = 0.0 20 | g_exp_losses = 0.0 21 | g_losses = 0.0 22 | 23 | mot_state = session.run(model.mot_initial_state) 24 | mus_state = session.run(model.mus_initial_state) 25 | 26 | step = 0 27 | np.random.seed(123456789) 28 | 29 | for batch_x, batch_y, batch_f in rd.capg_seq_generator(epoch, train_type, data_info, config): 30 | feed_dict = dict() 31 | if config.is_use_pre_mot: 32 | for i, (c, h) in enumerate(model.mot_initial_state): 33 | feed_dict[c] = mot_state[i].c 34 | feed_dict[h] = mot_state[i].h 35 | 36 | for i, (c, h) in enumerate(model.mus_initial_state): 37 | feed_dict[c] = mus_state[i].c 38 | feed_dict[h] = mus_state[i].h 39 | 40 | tf_mask = np.random.uniform(size=config.num_steps) < teacher_forcing_ratio 41 | # print(tf_mask) 42 | feed_dict[model.tf_mask] = tf_mask 43 | 44 | last_step_mot = copy.deepcopy(batch_f) 45 | last_step_mot[:, :6] = 0 46 | 47 | feed_dict[model.init_step_mot] = last_step_mot 48 | feed_dict[model.input_x] = batch_x 49 | feed_dict[model.input_y] = batch_y 50 | 51 | g_fetches = { 52 | "last_step_mot": model.last_step_mot, 53 | "g_loss": model.g_loss, 54 | "eval_op": model.train_g_op 55 | } 56 | 57 | d_fetches = { 58 | "d_loss": model.d_loss, 59 | "eval_op": model.train_d_op 60 | } 61 | 62 | d_vals = session.run(d_fetches, feed_dict) 63 | g_vals = session.run(g_fetches, feed_dict) 64 | 65 | d_loss = d_vals["d_loss"] 66 | g_loss = g_vals["g_loss"] 67 | 68 | d_losses += d_loss 69 | g_exp_losses += g_loss[0][-1] 70 | g_losses += g_loss[1] 71 | step += 1 72 | 73 | if verbose: 74 | info = "Epoch {0}: {1} d_loss: {2} g_loss: {3}, exp_loss: {4}\n".format( 75 | epoch, step, d_loss, g_loss[1], g_loss[0]) 76 | print(info) 77 | with open(path, 'a') as fh: 78 | fh.write(info) 79 | 80 | return [d_losses/step, g_losses/step, g_exp_losses/step] 81 | 82 | 83 | def generate_motion(session, model, data_info, gen_str, test_config, hop, epoch=0, 84 | time_dir=None, use_pre_mot=True, prefix='test', is_save=True): 85 | """Runs the model on the given data""" 86 | g_exp_losses = 0.0 87 | g_losses = 0.0 88 | d_losses = 0.0 89 | 90 | fetches = { 91 | "prediction": model.mot_predictions, 92 | "last_step_mot": model.last_step_mot, 93 | "g_loss": model.g_loss, 94 | "d_loss": model.d_loss, 95 | "mot_final_state": model.mot_final_state, 96 | "mus_final_state": model.mus_final_state, 97 | } 98 | 99 | step = 0 100 | num_steps = 
test_config.num_steps 101 | pre_mot = [] 102 | mus_data = data_info[gen_str][0] 103 | mot_data = copy.deepcopy(data_info[gen_str][1]) 104 | 105 | seq_keys = list(mus_data.keys()) 106 | seq_keys.sort() 107 | mus_delay = test_config.mus_delay 108 | 109 | for file_name in seq_keys: 110 | predictions = [] 111 | mus_file_data = mus_data[file_name] 112 | mot_file_data = mot_data[file_name] 113 | test_len = min(mus_file_data.shape[1]+mus_delay, mot_file_data.shape[1]) 114 | test_num = int((test_len - 1 - num_steps) / hop + 1) 115 | 116 | mot_state = session.run(model.mot_initial_state) 117 | mus_state = session.run(model.mus_initial_state) 118 | 119 | for t in range(test_num): 120 | batch_x = mus_file_data[:, t * hop + 1 - mus_delay: t * hop + num_steps + 1 - mus_delay, :] 121 | batch_y = mot_file_data[:, t * hop + 1: t * hop + num_steps + 1, :] 122 | batch_f = mot_file_data[:, t * hop, :] # first frame 123 | 124 | feed_dict = dict() 125 | if use_pre_mot: 126 | for i, (c, h) in enumerate(model.mot_initial_state): 127 | feed_dict[c] = mot_state[i].c 128 | feed_dict[h] = mot_state[i].h 129 | 130 | for i, (c, h) in enumerate(model.mus_initial_state): 131 | feed_dict[c] = mus_state[i].c 132 | feed_dict[h] = mus_state[i].h 133 | 134 | if t > 0 and use_pre_mot: 135 | last_step_mot = copy.deepcopy(pre_mot) 136 | else: 137 | last_step_mot = copy.deepcopy(batch_f) 138 | last_step_mot[:, :6] = 0 139 | 140 | feed_dict[model.init_step_mot] = last_step_mot 141 | feed_dict[model.input_x] = batch_x 142 | feed_dict[model.input_y] = batch_y 143 | feed_dict[model.tf_mask] = [False] * test_config.num_steps 144 | 145 | vals = session.run(fetches, feed_dict) 146 | 147 | prediction = vals["prediction"] 148 | g_loss = vals["g_loss"] 149 | d_loss = vals["d_loss"] 150 | mot_state = vals["mot_final_state"] 151 | mus_state = vals["mus_final_state"] 152 | pre_mot = vals["last_step_mot"] 153 | 154 | d_losses += d_loss 155 | g_exp_losses += g_loss[0][-1] 156 | g_losses += g_loss[1] 157 | 158 | step += 1 159 | prediction = np.reshape(prediction, [test_config.num_steps, test_config.mot_dim]) 160 | predictions.append(prediction) 161 | 162 | if is_save and ((epoch+1) % test_config.save_data_epoch == 0 or epoch == 0): 163 | test_pred_path = os.path.join(time_dir, prefix, str(epoch+1), file_name + ".csv") 164 | if len(predictions): 165 | predictions = np.concatenate(predictions, 0) 166 | rd.save_predict_data(predictions, test_pred_path, data_info, 167 | test_config.norm_way, test_config.mot_ignore_dims, 168 | test_config.mot_scale) 169 | 170 | return [d_losses/step, g_losses/step, g_exp_losses/step] 171 | 172 | 173 | def save_arg(config, path): 174 | config_dict = dict() 175 | for name, value in vars(config).items(): 176 | config_dict[name] = value 177 | json.dump(config_dict, open(path, 'w'), indent=4, sort_keys=True) 178 | 179 | 180 | def run_main(config, test_config, data_info): 181 | 182 | with tf.Graph().as_default(): 183 | with tf.name_scope("Train"): 184 | with tf.variable_scope("Model", reuse=None): 185 | train_model = GanModel(config=config, 186 | train_type=0) 187 | 188 | with tf.name_scope("Test"): 189 | with tf.variable_scope("Model", reuse=True): 190 | test_model = GanModel(config=test_config, 191 | train_type=2) 192 | 193 | # allowing gpu memory growth 194 | gpu_config = tf.ConfigProto() 195 | saver = tf.train.Saver(max_to_keep=20) 196 | gpu_config.gpu_options.allow_growth = True 197 | 198 | with tf.Session(config=gpu_config) as session: 199 | 200 | # initialize all variables 201 | if config.is_load_model: 202 | 
saver.restore(session, config.model_path) 203 | else: 204 | session.run(tf.global_variables_initializer()) 205 | 206 | # start queue 207 | coord = tf.train.Coordinator() 208 | tf.train.start_queue_runners(sess=session, coord=coord) 209 | 210 | time_str = datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 211 | save_dir = config.save_dir 212 | if not os.path.exists(save_dir): 213 | os.makedirs(save_dir) 214 | model_save_dir = os.path.join(save_dir, 'model') 215 | if not os.path.exists(save_dir): 216 | os.makedirs(model_save_dir) 217 | train_loss_dict = OrderedDict() 218 | test_loss_dict = OrderedDict() 219 | start_time = time.time() 220 | train_loss_path = os.path.join(save_dir, "train_loss.txt") 221 | train_step_loss_path = os.path.join(save_dir, "train_step_loss.txt") 222 | config_path = os.path.join(save_dir, "config.txt") 223 | time_path = os.path.join(save_dir, "time.txt") 224 | config.save_config(config_path) 225 | arg_path = os.path.join(save_dir, "args.txt") 226 | save_arg(args, arg_path) 227 | 228 | teacher_forcing_ratio = config.teacher_forcing_ratio 229 | 230 | for i in range(config.max_max_epoch): 231 | train_loss = \ 232 | run_epoch(session, train_model, 233 | data_info, config, 234 | teacher_forcing_ratio, 235 | path=train_step_loss_path, 236 | train_type=0, 237 | epoch=i, 238 | verbose=True) 239 | 240 | print("---Epoch {0} train_loss: {1}\n".format(i, train_loss)) 241 | train_loss_dict[str(i+1)] = train_loss 242 | json.dump(train_loss_dict, open(train_loss_path, 'w'), indent=4) 243 | 244 | if (i + 1) % test_config.save_data_epoch == 0: 245 | _ = \ 246 | generate_motion(session, test_model, 247 | data_info, 'test', test_config, hop=test_config.num_steps, 248 | epoch=i, time_dir=save_dir, 249 | use_pre_mot=True, prefix='seq') 250 | 251 | if test_config.is_save_train: 252 | _ = \ 253 | generate_motion(session, test_model, 254 | data_info, 'train', test_config, hop=test_config.num_steps, 255 | epoch=i, time_dir=save_dir, 256 | use_pre_mot=True, prefix='seq_train') 257 | 258 | if (i == 0 or (i + 1) % config.save_model_epoch == 0) and config.is_save_model: 259 | model_save_path = os.path.join(model_save_dir, 'cnn-erd_'+str(i)+'_model.ckpt') 260 | saver.save(session, model_save_path) 261 | 262 | time_info = "Epoch: {0} Elapsed Time : {1}\n".format(i + 1, time.time()-start_time) 263 | print(time_info) 264 | with open(time_path, 'a') as fh: 265 | fh.write(time_info) 266 | 267 | teacher_forcing_ratio *= config.tf_decay 268 | 269 | coord.request_stop() 270 | coord.join() 271 | 272 | 273 | def main(_): 274 | type_list = args.type_list 275 | fold_list = args.fold_list 276 | seg_list = args.seg_list 277 | 278 | for seg_len in seg_list: 279 | for fold_idx in fold_list: 280 | for i, m_type in enumerate(type_list): 281 | seg_str = str(seg_len) 282 | fold_str = 'fold_' + str(fold_idx) 283 | print(m_type, seg_str, fold_str) 284 | if fold_idx != 0 and m_type in ['hiphop', 'salsa']: 285 | continue 286 | if fold_idx == 3 and m_type == 'groovenet': 287 | continue 288 | config = cfg.get_config(m_type, fold_str, seg_str) 289 | cfg_list = [] 290 | care_list = ['add_info', 'mse_rate', 'dis_rate', 'dis_learning_rate', 291 | 'reg_scale', 'rnn_keep_list', 'is_reg', 'cond_axis'] 292 | for k, v in sorted(vars(args).items()): 293 | print(k, v) 294 | setattr(config, k, v) 295 | if k in care_list: 296 | v_str = str(v) 297 | if isinstance(v, bool): 298 | v_str = v_str[0] 299 | cfg_list.append(v_str) 300 | config.save_dir = os.path.join(args.add_info, m_type) 301 | 302 | args.care_list = care_list 303 | test_config = 
copy.deepcopy(config) 304 | test_config.batch_size = config.test_batch_size 305 | test_config.num_steps = config.test_num_steps 306 | 307 | print(config.save_dir) 308 | data_info = rd.run_all(config) 309 | config.mot_data_info = data_info['mot'] 310 | test_config.mot_data_info = data_info['mot'] 311 | run_main(config, test_config, data_info) 312 | 313 | 314 | if __name__ == "__main__": 315 | 316 | parser = argparse.ArgumentParser() 317 | parser.add_argument('--add_info', type=str, default='') 318 | parser.add_argument('--learning_rate', type=float, default=1e-4) 319 | parser.add_argument('--is_load_model', type=lambda x: (str(x).lower() == 'true')) 320 | parser.add_argument('--optimizer', type=str, default='Adam') 321 | parser.add_argument('--fold_list', nargs='+', type=int, help='0, 1, 2, 3') 322 | parser.add_argument('--seg_list', nargs='+', type=int, help='150, 90') 323 | parser.add_argument('--type_list', nargs='+', type=str, help='music type') 324 | parser.add_argument('--model_path', type=str, help='model_path') 325 | parser.add_argument('--max_max_epoch', type=int, help='training epoch number') 326 | parser.add_argument('--save_model_epoch', type=int, help='save_model_epoch_number') 327 | parser.add_argument('--save_data_epoch', type=int, help='save_data_epoch_number') 328 | parser.add_argument('--is_reg', type=lambda x: (str(x).lower() == 'true'), help='if add regularization') 329 | parser.add_argument('--reg_scale', type=float, help='5e-4') 330 | parser.add_argument('--rnn_keep_list', nargs='+', type=float, help='rnn_keep_probability list, [1.0, 1.0, 1.0]') 331 | parser.add_argument('--batch_size', type=int, help='32 or 64') 332 | parser.add_argument('--has_random_seed', type=lambda x: (str(x).lower() == 'true'), help='') 333 | parser.add_argument('--teacher_forcing_ratio', type=float, help='') 334 | parser.add_argument('--tf_decay', type=float, help='') 335 | parser.add_argument('--norm_way', type=str, help='zscore, maxmin, no') 336 | parser.add_argument('--seq_shift', type=int, help='seq_shift') 337 | parser.add_argument('--gen_hop', type=int, help='gen_hop') 338 | parser.add_argument('--mot_scale', type=float, help='motion scale') 339 | parser.add_argument('--is_save_train', type=lambda x: (str(x).lower() == 'true')) 340 | parser.add_argument('--cond_axis', type=int, help='1: height, 3: channel', default=3) 341 | parser.add_argument('--act_type', type=str, default='lrelu') 342 | parser.add_argument('--kernel_size', nargs='+', type=int) 343 | parser.add_argument('--stride', nargs='+', type=int) 344 | parser.add_argument('--dis_learning_rate', type=float, default=1e-4) 345 | parser.add_argument('--dis_type', type=str, help='DisFrameGraph or DisSegGraph') 346 | parser.add_argument('--dis_name', type=str, default='cond_cnn') 347 | parser.add_argument('--mse_rate', type=float, default=0.99) 348 | parser.add_argument('--dis_rate', type=float, default=0.01) 349 | parser.add_argument('--loss_mode', type=str, default='gan') 350 | parser.add_argument('--clip_value', type=float, default=0.01) 351 | parser.add_argument('--pen_lambda', type=float, default=10) 352 | parser.add_argument('--mus_ebd_dim', type=int) 353 | parser.add_argument('--is_all_norm', type=lambda x: (str(x).lower() == 'true'), default=False) 354 | parser.add_argument('--loss_rate_list', nargs='+', type=float, default=[1., 0., 0.]) 355 | 356 | args = parser.parse_args() 357 | 358 | tf.app.run() 359 | -------------------------------------------------------------------------------- /utils/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/computer-animation-perception-group/DeepDance_train/9f47de0b9f599b2590608465ffb96c9eb0344502/utils/__init__.py -------------------------------------------------------------------------------- /utils/capg_exp_skel.pkl: -------------------------------------------------------------------------------- [binary pickle, raw bytes omitted: a serialized list of bvh.Node objects defining the CAPG exponential-map skeleton hierarchy Hips, Chest, Chest2, Neck, Head, Left/Right Collar-Shoulder-Elbow-Wrist, Left/Right Hip-Knee-Ankle, plus End Site leaves; each node stores its parent index, offset, ZXY channel order, and quat/rot/pos/exp index lists]
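Since the skeleton ships only as a pickle, a quick way to see what it contains is to load it and walk the node tree. The snippet below is a minimal standalone sketch, not a repository file; it assumes the attribute names visible in the serialized bvh.Node objects (name, offset, children, exp_idx) and that bvh.py from this repository is on the import path so pickle can reconstruct the nodes.

import pickle

import bvh  # noqa: F401 -- the pickle references bvh.Node, so the module must be importable


def walk(node, depth=0):
    # exp_idx holds the exponential-map channel indices assigned to this joint
    print('  ' * depth + f'{node.name} offset={node.offset} exp_idx={node.exp_idx}')
    for child in node.children:
        walk(child, depth + 1)


if __name__ == '__main__':
    with open('./utils/capg_exp_skel.pkl', 'rb') as fh:
        skel = pickle.load(fh)  # list of bvh.Node objects; skel[0] is the Hips root
    walk(skel[0])

This is the same object that get_pos_chls() in utils/exp_loss.py unpickles and hands to tf_expsdk.exp2xyz when converting exponential-map channels to joint positions.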
-------------------------------------------------------------------------------- /utils/exp_loss.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import tensorflow as tf 4 | 5 | from utils import tf_expsdk as tk 6 | 7 | 8 | def unnorm_chls(chls, config): 9 | mot_data_info = config.mot_data_info 10 | norm_way = config.norm_way 11 | mot_ignore_dims = config.mot_ignore_dims 12 | nd = mot_data_info[0].shape[0] 13 | all_dims = list(range(nd)) 14 | useful_dims = [d for d in all_dims if d not in mot_ignore_dims] 15 | original_shape = chls.get_shape().as_list() 16 | 17 | if norm_way == 'no': 18 | print('unnorm no') 19 | unnorm_data = chls 20 | else: 21 | print('unnorm zscore') 22 | chls = tf.reshape(chls, [-1, original_shape[-1]]) 23 | data_mean = tf.constant(mot_data_info[2][useful_dims], dtype=tf.float32, shape=[1, len(useful_dims)]) 24 | data_std = tf.constant(mot_data_info[3][useful_dims], dtype=tf.float32, shape=[1, len(useful_dims)]) 25 | unnorm_data = tf.multiply(chls, data_std) + data_mean 26 | unnorm_data = tf.reshape(unnorm_data, original_shape) 27 | 28 | return unnorm_data 29 | 30 | 31 | def norm_chls(chls, config, eps=1e-6): 32 | mot_data_info = config.mot_data_info 33 | norm_way = config.norm_way 34 | mot_ignore_dims = config.mot_ignore_dims 35 | nd = mot_data_info[0].shape[0] 36 | all_dims = list(range(nd)) 37 | useful_dims = [d for d in all_dims if d not in mot_ignore_dims] 38 | original_shape = chls.get_shape().as_list() 39 | 40 | if norm_way == 'no': 41 | print('norm no') 42 | norm_data = chls 43 | else: 44 | print('norm zscore') 45 | chls = tf.reshape(chls, [-1, original_shape[-1]]) 46 | data_mean = tf.constant(mot_data_info[2][useful_dims], dtype=tf.float32, shape=[1, len(useful_dims)]) 47 | data_std = tf.constant(mot_data_info[3][useful_dims], dtype=tf.float32, shape=[1, len(useful_dims)]) 48 | norm_data = (chls - data_mean) / (data_std + eps) 49 | norm_data = tf.reshape(norm_data, original_shape) 50 | 51 | return norm_data 52 | 53 | 54 | def _normalize(chls, axis, eps=1e-6): 55 | unit_chls = tf.nn.l2_normalize(chls, dim=axis, epsilon=eps) 56 | return unit_chls 57 | 58 | 59 | def _dot2ang(dot_product, eps=1e-6): 60 | dot_product = tf.clip_by_value(dot_product, -1.+eps, 1.-eps) 61 | angle = tf.acos(dot_product) 62 | # angle = 1 - dot_product 63 | return angle 64 | 65 | 66 | def mse_loss_impl(pre_chls, tru_chls): 67 | return tf.reduce_mean(tf.squared_difference(pre_chls, tru_chls)) 68 | 69 | 70 | def mse_trans_loss_impl(pre_chls, tru_chls): 71 | pre_trans_chls = pre_chls[:, :, :3] 72 | tru_trans_chls = tru_chls[:, :, :3] 73 | 74 | return tf.reduce_mean(tf.squared_difference(pre_trans_chls, tru_trans_chls)) 75 | 76 | 77 | def mse_exp_loss_impl(pre_chls, tru_chls): 78 | pre_exp_chls = pre_chls[:, :, 3:] 79 | tru_exp_chls = tru_chls[:, :, 3:] 80 | return tf.reduce_mean(tf.squared_difference(pre_exp_chls, tru_exp_chls)) 81 | 82 | 83 | def path_loss_impl(pre_chls, tru_chls): 84 | """ 85 | path loss 86 | :param pre_chls: batch_size * num_steps * 79 87 | :param tru_chls: batch_size * num_steps * 79 88 | :return: loss value 89 | """ 90 | print('path_loss') 91 | pre_root_positions = tf.stack([pre_chls[:, :, 0], pre_chls[:, :, 2]], axis=-1) 92 | tru_root_positions = tf.stack([tru_chls[:, :, 0], tru_chls[:, :, 2]], axis=-1) 93 | pre_path = pre_root_positions[:, 1:, :] - pre_root_positions[:, :-1, :] 94 | tru_path = tru_root_positions[:, 1:, :] - tru_root_positions[:, :-1, :] 95 | 96 | # pre_path = pre_root_positions - 
pre_root_positions[:, 0:1, :] 97 | # tru_path = tru_root_positions - tru_root_positions[:, 0:1, :] 98 | 99 | # pre_path = tf.norm(pre_path, axis=-1) 100 | # tru_path = tf.norm(tru_path, axis=-1) 101 | 102 | # path_loss = tf.squared_difference(pre_path, tru_path) 103 | path_loss = tf.norm(pre_path-tru_path, axis=-1) 104 | path_loss = tf.reduce_mean(path_loss) 105 | 106 | return path_loss 107 | 108 | 109 | def height_loss_impl(pre_chls, tru_chls): 110 | pre_height = pre_chls[:, :, 1] 111 | tru_height = tru_chls[:, :, 1] 112 | height_loss = tf.reduce_mean(tf.abs(pre_height-tru_height)) 113 | 114 | return height_loss 115 | 116 | 117 | def pos_loss_impl(pre_chls, tru_chls): 118 | 119 | chls_shape = pre_chls.get_shape().as_list() 120 | num_joints = int((chls_shape[2] - 3) / 3) 121 | pre_pos_chls = tf.reshape(pre_chls[:, :, 3:], [chls_shape[0], chls_shape[1], num_joints, 3]) 122 | tru_pos_chls = tf.reshape(tru_chls[:, :, 3:], [chls_shape[0], chls_shape[1], num_joints, 3]) 123 | pos_loss = tf.norm(pre_pos_chls-tru_pos_chls, axis=-1) 124 | pos_loss = tf.reduce_mean(pos_loss) 125 | 126 | pre_delta_chls = pre_pos_chls[:, 1:, :] - pre_pos_chls[:, :-1, :] 127 | tru_delta_chls = tru_pos_chls[:, 1:, :] - tru_pos_chls[:, :-1, :] 128 | pos_delta_loss = tf.norm(pre_delta_chls-tru_delta_chls, axis=-1) 129 | pos_delta_loss = tf.reduce_mean(pos_delta_loss) 130 | 131 | return pos_loss, pos_delta_loss 132 | 133 | 134 | def pos_delta_loss_impl(pre_chls, tru_chls): 135 | chls_shape = pre_chls.get_shape().as_list() 136 | num_joints = int((chls_shape[2] - 3) / 3) 137 | pre_pos_chls = tf.reshape(pre_chls[:, :, 3:], [chls_shape[0], chls_shape[1], num_joints, 3]) 138 | tru_pos_chls = tf.reshape(tru_chls[:, :, 3:], [chls_shape[0], chls_shape[1], num_joints, 3]) 139 | 140 | pre_delta_chls = pre_pos_chls[:, 1:, :] - pre_pos_chls[:, :-1, :] 141 | tru_delta_chls = tru_pos_chls[:, 1:, :] - tru_pos_chls[:, :-1, :] 142 | 143 | pos_loss = tf.norm(pre_delta_chls-tru_delta_chls, axis=-1) 144 | pos_loss = tf.reduce_mean(pos_loss) 145 | 146 | return pos_loss 147 | 148 | 149 | def _get_angle(root_positions): 150 | 151 | path_former = root_positions[:, 1:-1, :] - root_positions[:, :-2, :] 152 | unit_former = _normalize(path_former, axis=-1) 153 | 154 | path_latter = root_positions[:, 2:, :] - root_positions[:, 1:-1, :] 155 | unit_latter = _normalize(path_latter, axis=-1) 156 | 157 | dot_product = tf.reduce_sum(tf.multiply(unit_former, unit_latter), axis=-1) 158 | angle = _dot2ang(dot_product) 159 | 160 | return angle 161 | 162 | 163 | def dir_loss_impl(pre_chls, tru_chls): 164 | """ 165 | direction loss 166 | :param pre_chls: batch_size * num_steps * 79 167 | :param tru_chls: batch_size * num_steps * 79 168 | :return: loss value 169 | """ 170 | pre_root_positions = tf.stack([pre_chls[:, :, 0], pre_chls[:, :, 2]], axis=-1) 171 | tru_root_positions = tf.stack([tru_chls[:, :, 0], tru_chls[:, :, 2]], axis=-1) 172 | 173 | pre_angle = _get_angle(pre_root_positions) 174 | tru_angle = _get_angle(tru_root_positions) 175 | angle_dist = pre_angle - tru_angle 176 | dir_loss = tf.abs(angle_dist) 177 | dir_loss = tf.reduce_mean(dir_loss) 178 | return dir_loss 179 | 180 | 181 | def _get_root_ori(chls): 182 | """ 183 | CHip(Chest) 184 | /\ 185 | / \ 186 | Rhip /____\ Lhip 187 | :param chls: 188 | :return: 189 | """ 190 | c_hip_chl = chls[:, :, 1, :] 191 | l_hip_chl = chls[:, :, 16, :] 192 | r_hip_chl = chls[:, :, 20, :] 193 | 194 | c_ori = _normalize(tf.cross(r_hip_chl-c_hip_chl, l_hip_chl-c_hip_chl), axis=-1) 195 | # r_ori = 
tf.nn.l2_normalize(tf.cross(l_hip_chl-r_hip_chl, c_hip_chl-r_hip_chl), dim=-1) 196 | # l_ori = tf.nn.l2_normalize(tf.cross(c_hip_chl-l_hip_chl, r_hip_chl-l_hip_chl), dim=-1) 197 | # root_ori = tf.nn.l2_normalize((c_ori + r_ori + l_ori) / 3, dim=-1) 198 | 199 | return c_ori 200 | 201 | 202 | def ori_loss_impl(pre_chls, tru_chls): 203 | """ 204 | orientation loss 205 | :param pre_chls: batch_size * num_steps * 72 206 | :param tru_chls: batch_size * num_steps * 72 207 | :return: loss value 208 | """ 209 | chls_shape = pre_chls.get_shape().as_list() 210 | num_joints = int((chls_shape[2]) / 3) 211 | pre_chls = tf.reshape(pre_chls, [chls_shape[0], chls_shape[1], num_joints, 3]) 212 | tru_chls = tf.reshape(tru_chls, [chls_shape[0], chls_shape[1], num_joints, 3]) 213 | 214 | pre_ori = _get_root_ori(pre_chls) 215 | tru_ori = _get_root_ori(tru_chls) 216 | 217 | dot_product = tf.reduce_sum(tf.multiply(pre_ori, tru_ori), axis=-1) 218 | angle_dist = _dot2ang(dot_product) 219 | ori_loss = tf.reduce_mean(tf.abs(angle_dist)) 220 | 221 | pre_dot_product = tf.reduce_sum(tf.multiply(pre_ori[:, 1:, :], pre_ori[:, :-1, :]), axis=-1) 222 | pre_angle = _dot2ang(pre_dot_product) 223 | tru_dot_product = tf.reduce_sum(tf.multiply(tru_ori[:, 1:, :], tru_ori[:, :-1, :]), axis=-1) 224 | tru_angle = _dot2ang(tru_dot_product) 225 | ori_delta_loss = tf.reduce_mean(tf.abs(pre_angle - tru_angle)) 226 | 227 | return ori_loss, ori_delta_loss 228 | 229 | 230 | def get_pos_chls(pre_chls, tru_chls, config): 231 | mot_scale = config.mot_scale 232 | 233 | pre_chls = unnorm_chls(pre_chls, config) 234 | tru_chls = unnorm_chls(tru_chls, config) 235 | 236 | init_t = tf.constant([0, 0, 0], dtype=tf.float32) 237 | init_r = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32) 238 | pre_chls = tk.revert_coordinate_space(pre_chls, init_t, init_r) 239 | tru_chls = tk.revert_coordinate_space(tru_chls, init_t, init_r) 240 | 241 | with open('./utils/capg_exp_skel.pkl', 'rb') as fh: 242 | skel = pickle.load(fh) 243 | 244 | chls_shape = pre_chls.get_shape().as_list() 245 | num_joints = int((chls_shape[2] - 3) / 3) 246 | 247 | pre_exp_chls = tf.reshape(pre_chls[:, :, 3:], [chls_shape[0], chls_shape[1], num_joints, 3]) 248 | # if add height in trans, will result in nan loss 249 | pre_trans_chls = pre_chls[:, :, :3] 250 | 251 | tru_exp_chls = tf.reshape(tru_chls[:, :, 3:], [chls_shape[0], chls_shape[1], num_joints, 3]) 252 | tru_trans_chls = tru_chls[:, :, :3] 253 | 254 | pre_positions = tk.exp2xyz(skel, pre_exp_chls, pre_trans_chls, mot_scale) 255 | tru_positions = tk.exp2xyz(skel, tru_exp_chls, tru_trans_chls, mot_scale) 256 | pre_positions = tf.reshape(pre_positions, [chls_shape[0], chls_shape[1], -1]) 257 | tru_positions = tf.reshape(tru_positions, [chls_shape[0], chls_shape[1], -1]) 258 | 259 | return pre_positions, tru_positions 260 | 261 | 262 | def loss_impl(pre_chls, tru_chls, pre_pos_chls, tru_pos_chls, config): 263 | rate = config.loss_rate_list 264 | 265 | mse_loss = mse_loss_impl(pre_chls, tru_chls) 266 | # pos_loss = pos_loss_impl(pre_chls, tru_chls, mot_scale) 267 | pos_loss, pos_delta_loss = pos_loss_impl(pre_pos_chls, tru_pos_chls) 268 | path_loss = path_loss_impl(pre_pos_chls, tru_pos_chls) 269 | height_loss = height_loss_impl(pre_pos_chls, tru_pos_chls) 270 | dir_loss = dir_loss_impl(pre_pos_chls, tru_pos_chls) 271 | ori_loss, ori_delta_loss = ori_loss_impl(pre_pos_chls, tru_pos_chls) 272 | 273 | path_loss = path_loss + height_loss 274 | loss_list = [mse_loss, pos_loss, path_loss, dir_loss, ori_loss, pos_delta_loss, 
ori_delta_loss] 275 | loss_res_list = [] 276 | loss = tf.constant(0, dtype=tf.float32, name='loss') 277 | for i in range(len(loss_list)): 278 | if i == 0: 279 | rate_idx = 0 280 | elif 1 <= i <= 4: 281 | rate_idx = 1 282 | else: 283 | rate_idx = 2 284 | loss_res_list.append(loss_list[i]) 285 | if rate[rate_idx] != 0: 286 | loss += rate[rate_idx] * loss_list[i] 287 | loss_res_list.append(loss) 288 | 289 | return loss, loss_res_list 290 | 291 | 292 | -------------------------------------------------------------------------------- /utils/gan_loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import inlib.models as md 3 | 4 | 5 | def gan_loss(dis_graph, dis_name, real_data, fake_data, cond_inputs, config): 6 | config.is_shuffle = False 7 | d_real_model = dis_graph(real_data, cond_inputs, name=dis_name, is_reuse=False, config=config) 8 | config.is_shuffle = True 9 | d_fake_model = dis_graph(fake_data, cond_inputs, name=dis_name, is_reuse=True, config=config) 10 | real_logits = d_real_model.build_dis_graph() 11 | fake_logits = d_fake_model.build_dis_graph() 12 | 13 | mode = config.loss_mode 14 | clip_disc_weights = [] 15 | 16 | if mode == 'wgan': 17 | gen_loss = -tf.reduce_mean(fake_logits) 18 | disc_loss = tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits) 19 | 20 | clip_ops = [] 21 | disc_vars = [var for var in tf.trainable_variables() if 'discriminator' in var.name] 22 | for var in disc_vars: 23 | if not hasattr(config, 'clip_value'): 24 | raise ValueError('wgan must set the clip_value argument!') 25 | clip_value = config.clip_value 26 | clip_bounds = [-clip_value, clip_value] 27 | clip_ops.append( 28 | tf.assign(var, tf.clip_by_value( 29 | var, clip_bounds[0], clip_bounds[1])) 30 | ) 31 | clip_disc_weights = tf.group(*clip_ops) 32 | 33 | elif mode == 'wgan-gp': 34 | gen_loss = -tf.reduce_mean(fake_logits) 35 | disc_loss = tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits) 36 | 37 | alpha = tf.random_uniform( 38 | shape=[real_data.get_shape()[0].value, 1, 1], minval=0., maxval=1.) 
39 | differences = fake_data - real_data 40 | interpolates = real_data + (alpha*differences) 41 | gradients = tf.gradients(dis_graph(interpolates, cond_inputs, config=config, 42 | name=dis_name, is_reuse=True).build_dis_graph(), 43 | [interpolates])[0] 44 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2])) 45 | gradient_penalty = tf.reduce_mean((slopes-1.)**2) 46 | if not hasattr(config, 'pen_lambda'): 47 | raise ValueError('wgan-gp must have lambda argument') 48 | disc_loss += config.pen_lambda * gradient_penalty 49 | 50 | elif mode == 'gan': 51 | gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 52 | logits=fake_logits, labels=tf.ones_like(fake_logits))) 53 | disc_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 54 | logits=fake_logits, labels=tf.zeros_like(fake_logits))) 55 | disc_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 56 | logits=real_logits, labels=tf.ones_like(real_logits))) 57 | disc_loss /= 2.0 58 | 59 | elif mode == 'rsgan': 60 | gen_loss = tf.reduce_mean(-tf.log(tf.sigmoid(fake_logits - real_logits) + 1e-9)) 61 | disc_loss = tf.reduce_mean(-tf.log(tf.sigmoid(real_logits - fake_logits) + 1e-9)) 62 | 63 | elif 'tgan' in mode: 64 | x_real_fake = tf.subtract(real_logits, fake_logits) 65 | x_fake_real = tf.subtract(fake_logits, real_logits) 66 | fc_list = [[1, '']] 67 | x_real_fake_score = md.mlp(x_real_fake, fc_list, 'discriminator', reuse=False) 68 | x_fake_real_score = md.mlp(x_fake_real, fc_list, 'discriminator', reuse=True) 69 | loss_type = mode.split('-')[1] 70 | gen_loss = tgan_gen_loss(loss_type, real=x_real_fake_score, fake=x_fake_real_score) 71 | disc_loss = tgan_diss_loss(loss_type, real=x_real_fake_score, fake=x_fake_real_score) 72 | else: 73 | raise ValueError('Not implemented loss mode.') 74 | 75 | return gen_loss, disc_loss, clip_disc_weights 76 | 77 | 78 | def tgan_diss_loss(loss_type, real, fake): 79 | real_loss = 0 80 | fake_loss = 0 81 | 82 | if loss_type == 'wgan': 83 | real_loss = -tf.reduce_mean(real) 84 | fake_loss = tf.reduce_mean(fake) 85 | 86 | if loss_type == 'lsgan': 87 | real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0)) 88 | fake_loss = tf.reduce_mean(tf.square(fake)) 89 | 90 | if loss_type == 'sgan' or loss_type == 'dragan': 91 | print('sgan') 92 | real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real)+1e-9) 93 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake)+1e-9) 94 | 95 | if loss_type == 'hinge': 96 | real_loss = tf.reduce_mean(tf.nn.relu(1.0 - real)) 97 | fake_loss = tf.reduce_mean(tf.nn.relu(1.0 + fake)) 98 | 99 | loss = real_loss + fake_loss 100 | 101 | return loss 102 | 103 | 104 | def tgan_gen_loss(loss_type, real, fake): 105 | real_loss = 0 106 | fake_loss = 0 107 | 108 | if loss_type == 'wgan': 109 | real_loss = tf.reduce_mean(real) 110 | fake_loss = -tf.reduce_mean(fake) 111 | 112 | if loss_type == 'lsgan': 113 | real_loss = tf.reduce_mean(tf.square(real)) 114 | fake_loss = tf.reduce_mean(tf.squared_difference(fake, 1.0)) 115 | 116 | if loss_type == 'sgan' or loss_type == 'dragan': 117 | print('sgan') 118 | real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(real), logits=real)+1e-9) 119 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake)+1e-9) 120 | 121 | if loss_type == 'hinge': 122 | fake_loss = -tf.reduce_mean(fake) 123 | 124 | loss = real_loss + fake_loss 
125 | 126 | return loss 127 | 128 | -------------------------------------------------------------------------------- /utils/plot_loss.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | mpl.use('Agg') 3 | import matplotlib.pyplot as plt 4 | from collections import OrderedDict 5 | import os 6 | 7 | 8 | 9 | def run(): 10 | a = OrderedDict() 11 | a[0] = 0.9 12 | a[1] = 0.8 13 | a[2] = 0.6 14 | save_plot_loss('', a, a, a) 15 | plt.show() 16 | 17 | 18 | def save_plot_loss(save_path, d_loss_dict, g_mse_loss_dict, g_loss_dict): 19 | d_loss = [k for k in d_loss_dict.values()] 20 | g_mse_loss = [k for k in g_mse_loss_dict.values()] 21 | g_loss = [k for k in g_loss_dict.values()] 22 | x = list(range(len(d_loss))) 23 | 24 | plt.plot(x, d_loss, 'g-', label='d_loss') 25 | plt.plot(x, g_mse_loss, 'r-', label='g_mse_loss') 26 | plt.plot(x, g_loss, 'b-', label='g_loss') 27 | plt.legend() 28 | plt.xlabel('epochs') 29 | plt.ylabel('loss value') 30 | plt.savefig(save_path) 31 | plt.close() 32 | -------------------------------------------------------------------------------- /utils/reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import pandas as pd 5 | import numpy as np 6 | import m2m_config as cfg 7 | 8 | 9 | # config = cfg.get_config() 10 | 11 | 12 | def get_normalize_info(input_data): 13 | data_mean = np.mean(input_data, axis=0) 14 | data_std = np.std(input_data, axis=0) 15 | return data_mean, data_std 16 | 17 | 18 | def load_data(data_dir, json_path, mot_scale): 19 | all_data = [] 20 | seq_data = dict() 21 | with open(json_path, 'r') as f: 22 | list_csv = json.load(f) 23 | for i, csv_name in enumerate(list_csv): 24 | csv_no_ext_name = os.path.splitext(os.path.basename(csv_name))[0] 25 | print(i, csv_no_ext_name) 26 | csv_path = os.path.join(data_dir, csv_name) 27 | csv_data = pd.read_csv(csv_path, header=None, dtype=np.float32).values 28 | csv_data[:, :3] = csv_data[:, :3] / mot_scale 29 | seq_data[csv_no_ext_name] = csv_data 30 | if 0 == len(all_data): 31 | all_data = copy.deepcopy(csv_data) 32 | else: 33 | all_data = np.append(all_data, csv_data, axis=0) 34 | 35 | return seq_data, all_data 36 | 37 | 38 | def load_tv_data(data_dir, train_json_path, val_json_path, num_steps=150): 39 | all_data = [] 40 | val_all_data = [] 41 | seq_data = dict() 42 | val_seq_data = dict() 43 | 44 | with open(train_json_path, 'r') as trf: 45 | train_list = json.load(trf) 46 | 47 | with open(val_json_path, 'r') as vf: 48 | val_list = json.load(vf) 49 | 50 | for i, file_name in enumerate(train_list): 51 | base_name = os.path.splitext(os.path.basename(file_name))[0] 52 | print(i, base_name) 53 | [mus_type, mus_idx, mus_sub_idx] = base_name.split('.') 54 | file_path = os.path.join(data_dir, file_name) 55 | file_data = pd.read_csv(file_path, header=None, dtype=np.float32).as_matrix() 56 | train_data = file_data 57 | val_data = [] 58 | for j, val_sub_idx in enumerate(val_list[mus_type][mus_idx]): 59 | if mus_sub_idx == val_sub_idx and j == 0: 60 | val_data = file_data[-num_steps:, :] 61 | train_data = file_data[:-num_steps, :] 62 | elif mus_sub_idx == val_sub_idx and j == 1: 63 | val_data = file_data[:num_steps, :] 64 | train_data = file_data[num_steps:, :] 65 | 66 | seq_data[base_name] = train_data 67 | if len(val_data) > 0: 68 | val_seq_data[file_name] = val_data 69 | 70 | if 0 == len(all_data): 71 | all_data = copy.deepcopy(train_data) 72 | else: 73 | all_data = 
np.append(all_data, train_data, axis=0) 74 | 75 | if 0 == len(val_all_data): 76 | val_all_data = copy.deepcopy(val_data) 77 | elif len(val_data) > 0: 78 | val_all_data = np.append(val_all_data, val_data, axis=0) 79 | 80 | return seq_data, all_data, val_seq_data, val_all_data 81 | 82 | 83 | def load_trip_data(data_dir, train_json_path, val_json_path, val_rate): 84 | all_data = [] 85 | val_all_data = [] 86 | seq_data = dict() 87 | val_seq_data = dict() 88 | 89 | with open(train_json_path, 'r') as trf: 90 | train_list = json.load(trf) 91 | 92 | with open(val_json_path, 'r') as vf: 93 | val_list = json.load(vf) 94 | 95 | for i, file_name in enumerate(train_list): 96 | base_name = os.path.splitext(os.path.basename(file_name))[0] 97 | print(i, base_name) 98 | file_path = os.path.join(data_dir, file_name) 99 | file_data = pd.read_csv(file_path, header=None, dtype=np.float32).as_matrix() 100 | train_data = file_data 101 | val_data = [] 102 | val_len = int(val_rate * file_data.shape[0])  # integer frame count; float slice bounds raise in numpy 103 | 104 | if file_name in val_list: 105 | val_data = file_data[-val_len:, :] 106 | train_data = file_data[:-val_len, :] 107 | 108 | seq_data[base_name] = train_data 109 | if len(val_data) > 0: 110 | val_seq_data[file_name] = val_data 111 | 112 | if 0 == len(all_data): 113 | all_data = copy.deepcopy(train_data) 114 | else: 115 | all_data = np.append(all_data, train_data, axis=0) 116 | 117 | if 0 == len(val_all_data): 118 | val_all_data = copy.deepcopy(val_data) 119 | elif len(val_data) > 0: 120 | val_all_data = np.append(val_all_data, val_data, axis=0) 121 | 122 | return seq_data, all_data, val_seq_data, val_all_data 123 | 124 | 125 | def sample_seq(mus_seq, mus_data_max, mus_data_min, mot_seq, mot_data_max, mot_data_min, 126 | mot_data_mean, mot_data_std, config): 127 | 128 | wlen = config.num_steps 129 | hop = config.seq_shift 130 | is_z_score = config.is_z_score 131 | start_idx = config.start_idx 132 | mus_dim = config.mus_dim 133 | ignore_dims = config.mot_ignore_dims 134 | mus_delay = config.mus_delay 135 | 136 | mus_x = [] 137 | mot_x = [] 138 | mot_y = [] 139 | ns = 0 # N sequence 140 | 141 | seq_keys = list(mus_seq.keys()) 142 | seq_keys.sort() 143 | # loop each action sequence, i.e. 
each file 144 | for k in seq_keys: 145 | print("sample_seq: ", k) 146 | mus_seq_data = mus_seq[k] 147 | mot_seq_data = mot_seq[k] 148 | start = start_idx 149 | end = start + wlen 150 | seq_len = min(mot_seq_data.shape[0], mus_seq_data.shape[0]+mus_delay) 151 | while end <= seq_len: 152 | mus_x.append(mus_seq_data[start-mus_delay:end-mus_delay, :]) 153 | mot_x.append(mot_seq_data[start-1:end-1, :]) 154 | mot_y.append(mot_seq_data[start:end, :]) 155 | ns += 1 156 | start += hop 157 | end += hop 158 | 159 | # initialize tensors 160 | mus_nd = mus_x[0].shape[1] 161 | mot_nd = mot_x[0].shape[1] 162 | mus_x_tensor = np.zeros((wlen, ns, mus_nd), dtype=np.float32) 163 | mot_x_tensor = np.zeros((wlen, ns, mot_nd), dtype=np.float32) 164 | mot_y_tensor = np.zeros((wlen, ns, mot_nd), dtype=np.float32) 165 | 166 | count = 0 167 | for _mus_x, _mot_x, _mot_y in zip(mus_x, mot_x, mot_y): 168 | mus_x_tensor[:, count, :] = _mus_x 169 | mot_x_tensor[:, count, :] = _mot_x 170 | mot_y_tensor[:, count, :] = _mot_y 171 | count += 1 172 | 173 | mus_x_tensor = norm_mus(mus_x_tensor, mus_data_max, mus_data_min, mus_dim) 174 | if is_z_score: 175 | mot_x_tensor = norm_mot_std(mot_x_tensor, mot_data_mean, mot_data_std, ignore_dims) 176 | mot_y_tensor = norm_mot_std(mot_y_tensor, mot_data_mean, mot_data_std, ignore_dims) 177 | else: 178 | mot_x_tensor = norm_mot(mot_x_tensor, mot_data_max, mot_data_min, ignore_dims) 179 | mot_y_tensor = norm_mot(mot_y_tensor, mot_data_max, mot_data_min, ignore_dims) 180 | 181 | return mus_x_tensor, mot_x_tensor, mot_y_tensor 182 | 183 | 184 | def sample_seq_idx(mus_seq, mot_seq, config): 185 | wlen = config.num_steps 186 | hop = config.seq_shift 187 | start_idx = config.start_idx 188 | mus_delay = config.mus_delay 189 | 190 | mus_x = [] 191 | mot_x = [] 192 | mot_y = [] 193 | ns = 0 # N sequence 194 | 195 | seq_keys = list(mus_seq.keys()) 196 | seq_keys.sort() 197 | # loop each action sequence, i.e. each file 198 | for k in seq_keys: 199 | print("sample_seq: ", k) 200 | mus_seq_data = mus_seq[k] 201 | mot_seq_data = mot_seq[k] 202 | start = start_idx 203 | end = start + wlen 204 | seq_len = min(mot_seq_data.shape[0], mus_seq_data.shape[0]+mus_delay) 205 | while end <= seq_len: 206 | mus_x.append([k, start - mus_delay, end - mus_delay]) 207 | mot_x.append([k, start-1, end-1]) 208 | mot_y.append([k, start, end]) 209 | ns += 1 210 | start += hop 211 | end += hop 212 | return mus_x, mot_x, mot_y 213 | 214 | 215 | def get_seq_tensor(mus_seq, mus_data_max, mus_data_min, mot_seq, mot_data_max, mot_data_min, 216 | mot_data_mean, mot_data_std, config): 217 | 218 | norm_way = config.norm_way 219 | mus_dim = config.mus_dim 220 | ignore_dims = config.mot_ignore_dims 221 | # mot_scale = config.mot_scale 222 | 223 | mus_seq_tensor = dict() 224 | mot_seq_tensor = dict() 225 | 226 | seq_keys = list(mus_seq.keys()) 227 | seq_keys.sort() 228 | # loop each action sequence, i.e. 
each file 229 | for k in seq_keys: 230 | print("sample_seq: ", k) 231 | mus_seq_data = mus_seq[k] 232 | mus_seq_data = np.reshape(mus_seq_data, [1, mus_seq_data.shape[0], mus_seq_data.shape[1]]) 233 | mot_seq_data = mot_seq[k] 234 | # mot_seq_data[:, :3] = mot_seq_data[:, :3] / mot_scale 235 | mot_seq_data = np.reshape(mot_seq_data, [1, mot_seq_data.shape[0], mot_seq_data.shape[1]]) 236 | 237 | mus_seq_tensor[k] = norm_mus(mus_seq_data, mus_data_max, mus_data_min, mus_dim) 238 | if norm_way == 'zscore': 239 | mot_seq_tensor[k] = norm_mot_std(mot_seq_data, mot_data_mean, mot_data_std, ignore_dims) 240 | elif norm_way == 'maxmin': 241 | mot_seq_tensor[k] = norm_mot(mot_seq_data, mot_data_max, mot_data_min, ignore_dims) 242 | else: 243 | mot_seq_tensor[k] = get_useful_dim(mot_seq_data, ignore_dims) 244 | 245 | return mus_seq_tensor, mot_seq_tensor 246 | 247 | 248 | def get_init_frame(mot_data, batch_size, num_steps): 249 | data_len = int(mot_data.shape[0]) 250 | batch_len = data_len // batch_size 251 | mot_data = mot_data[0: batch_size * batch_len, :] 252 | mot_data = np.reshape(mot_data, [batch_size, batch_len, -1]) 253 | init_mot = mot_data[0: batch_size, 0 * num_steps, :] 254 | print("init_mot shape: ", init_mot.shape) 255 | return init_mot 256 | 257 | 258 | def save_data_info(data_mean, data_std, save_dir): 259 | data_mean_path = os.path.join(save_dir, 'data_mean.csv') 260 | pd.DataFrame(data_mean).to_csv(data_mean_path, index=False, header=False) 261 | 262 | data_std_path = os.path.join(save_dir, 'data_std.csv') 263 | pd.DataFrame(data_std).to_csv(data_std_path, index=False, header=False) 264 | 265 | 266 | def get_useful_dim(input_data, ignore_dims): 267 | [_, _, nd] = input_data.shape 268 | useful_list = [] 269 | for i in range(nd): 270 | if i not in ignore_dims: 271 | useful_list.append(i) 272 | 273 | return input_data[:, :, useful_list] 274 | 275 | 276 | def unnorm_data(norm_data, data_mean, data_std, ignore_dims): 277 | sl = norm_data.shape[0] 278 | nd = data_mean.shape[0] 279 | 280 | org_data = np.zeros((sl, nd), dtype=np.float32) 281 | use_dimensions = [] 282 | for i in range(nd): 283 | if i in ignore_dims: 284 | continue 285 | use_dimensions.append(i) 286 | use_dimensions = np.array(use_dimensions) 287 | org_data[:, use_dimensions] = norm_data 288 | 289 | std_mat = data_std.reshape((1, nd)) 290 | std_mat = np.repeat(std_mat, sl, axis=0) 291 | mean_mat = data_mean.reshape((1, nd)) 292 | mean_mat = np.repeat(mean_mat, sl, axis=0) 293 | org_data = np.multiply(org_data, std_mat) + mean_mat 294 | return org_data 295 | 296 | 297 | def add_ignore_dims(input_data, ignore_dims): 298 | sl = input_data.shape[0] 299 | nd = input_data.shape[1] + len(ignore_dims) 300 | org_data = np.zeros((sl, nd), dtype=np.float32) 301 | use_dimensions = [] 302 | for i in range(nd): 303 | if i in ignore_dims: 304 | continue 305 | use_dimensions.append(i) 306 | use_dimensions = np.array(use_dimensions) 307 | org_data[:, use_dimensions] = input_data 308 | 309 | return org_data 310 | 311 | 312 | def save_predict_data(input_data, save_path, data_info, norm_way, ignore_dims, mot_scale=1.0): 313 | save_dir = os.path.dirname(save_path) 314 | if not os.path.exists(save_dir): 315 | os.makedirs(save_dir) 316 | 317 | data_max = data_info['mot'][0] 318 | data_min = data_info['mot'][1] 319 | data_mean = data_info['mot'][2] 320 | data_std = data_info['mot'][3] 321 | 322 | if norm_way == 'zscore': 323 | input_data = unnorm_data(input_data, data_mean, data_std, ignore_dims) 324 | elif norm_way == 'maxmin': 325 | input_data = 
unnorm_mot(input_data, data_max, data_min, ignore_dims) 326 | else: 327 | input_data = add_ignore_dims(input_data, ignore_dims) 328 | 329 | input_data[:, :3] = input_data[:, :3] * mot_scale 330 | 331 | input_data_frame = pd.DataFrame(input_data) 332 | input_data_frame.to_csv(save_path, index=False, header=False) 333 | 334 | 335 | def save_mus_data(input_data, data_info, mus_dim, save_path): 336 | data_max = data_info['mus'][0] 337 | data_min = data_info['mus'][1] 338 | input_data = unnorm_mus(input_data, data_max, data_min, mus_dim) 339 | input_data_frame = pd.DataFrame(input_data) 340 | input_data_frame.to_csv(save_path, index=False, header=False) 341 | 342 | 343 | def norm_mot(seq_data, data_max, data_min, ignore_dims): 344 | [sl, nb, nd] = seq_data.shape 345 | data_max = data_max.reshape((1, 1, nd)) 346 | data_max = np.repeat(data_max, sl, axis=0) 347 | data_max = np.repeat(data_max, nb, axis=1) 348 | 349 | data_min = data_min.reshape((1, 1, nd)) 350 | data_min = np.repeat(data_min, sl, axis=0) 351 | data_min = np.repeat(data_min, nb, axis=1) 352 | 353 | eps = 1e-12 354 | norm_data = np.divide((seq_data - data_min), (data_max - data_min) + eps) 355 | norm_data = np.multiply(norm_data, 1.8) 356 | norm_data = np.subtract(norm_data, 0.9) 357 | 358 | return get_useful_dim(norm_data, ignore_dims) 359 | 360 | 361 | def norm_mot_std(input_tensor, data_mean, data_std, ignore_dims, eps=1e-12): 362 | mean_tensor = data_mean.reshape((1, 1, input_tensor.shape[2])) 363 | mean_tensor = np.repeat(mean_tensor, input_tensor.shape[0], axis=0) 364 | mean_tensor = np.repeat(mean_tensor, input_tensor.shape[1], axis=1) 365 | data_std = np.add(data_std, eps) 366 | std_tensor = np.reshape(data_std, [1, 1, input_tensor.shape[2]]) 367 | std_tensor = np.repeat(std_tensor, input_tensor.shape[0], axis=0) 368 | std_tensor = np.repeat(std_tensor, input_tensor.shape[1], axis=1) 369 | norm_tensor = np.divide((input_tensor - mean_tensor), std_tensor) 370 | 371 | return get_useful_dim(norm_tensor, ignore_dims) 372 | 373 | 374 | def norm_mus(seq_data, data_max, data_min, mus_dim): 375 | [sl, nb, nd] = seq_data.shape 376 | seq_data = np.reshape(seq_data, [sl, nb, mus_dim, 5]) 377 | 378 | data_max = data_max.reshape((1, 1, mus_dim, 1)) 379 | data_max = np.repeat(data_max, sl, axis=0) 380 | data_max = np.repeat(data_max, nb, axis=1) 381 | data_max = np.repeat(data_max, 5, axis=3) 382 | 383 | data_min = data_min.reshape((1, 1, mus_dim, 1)) 384 | data_min = np.repeat(data_min, sl, axis=0) 385 | data_min = np.repeat(data_min, nb, axis=1) 386 | data_min = np.repeat(data_min, 5, axis=3) 387 | 388 | eps = 1e-12 389 | norm_data = np.divide((seq_data - data_min), (data_max - data_min) + eps) 390 | norm_data = np.multiply(norm_data, 0.8) 391 | norm_data = np.add(norm_data, 0.1) 392 | norm_data = np.reshape(norm_data, [sl, nb, nd]) 393 | 394 | return norm_data 395 | 396 | 397 | def unnorm_mus(norm_tensor, data_max, data_min, mus_dim): 398 | [sl, nd] = norm_tensor.shape 399 | norm_tensor = np.reshape(norm_tensor, [sl, mus_dim, 5]) 400 | data_max = data_max.reshape((1, mus_dim, 1)) 401 | data_max = np.repeat(data_max, sl, axis=0) 402 | data_max = np.repeat(data_max, 5, axis=2) 403 | 404 | data_min = data_min.reshape((1, mus_dim, 1)) 405 | data_min = np.repeat(data_min, sl, axis=0) 406 | data_min = np.repeat(data_min, 5, axis=2) 407 | 408 | unnorm_tensor = np.subtract(norm_tensor, 0.1) 409 | unnorm_tensor = np.divide(unnorm_tensor, 0.8) 410 | unnorm_tensor = np.multiply(unnorm_tensor, (data_max - data_min)) + data_min 411 | 412 | 
unnorm_tensor = np.reshape(unnorm_tensor, [sl, nd]) 413 | 414 | return unnorm_tensor 415 | 416 | 417 | def unnorm_mot(norm_tensor, data_max, data_min, ignore_dims): 418 | sl = norm_tensor.shape[0] 419 | nd = data_max.shape[0] 420 | 421 | data_max = data_max.reshape((1, nd)) 422 | data_max = np.repeat(data_max, sl, axis=0) 423 | data_min = data_min.reshape((1, nd)) 424 | data_min = np.repeat(data_min, sl, axis=0) 425 | 426 | org_tensor = np.zeros((sl, nd), dtype=np.float32) 427 | use_dimensions = [] 428 | for i in range(nd): 429 | if i in ignore_dims: 430 | continue 431 | use_dimensions.append(i) 432 | use_dimensions = np.array(use_dimensions) 433 | org_tensor[:, use_dimensions] = norm_tensor 434 | 435 | unnorm_tensor = np.add(org_tensor, 0.9) 436 | unnorm_tensor = np.divide(unnorm_tensor, 1.8) 437 | unnorm_tensor = np.multiply(unnorm_tensor, (data_max - data_min)) + data_min 438 | 439 | return unnorm_tensor 440 | 441 | 442 | def add_noise(x, noise=1e-5): 443 | """ 444 | :param x: 445 | :param noise: np.random.normal, mean = 0, sigma = noise 446 | :return: 447 | """ 448 | rng = np.random.RandomState(1234567890) 449 | [sl, ns, nd] = x.shape 450 | 451 | # sl * ns samples are drawn 452 | binomial_prob = rng.binomial(1, 0.5, size=(sl, ns, 1)) 453 | noise_to_add = rng.normal(scale=noise, size=x.shape) 454 | noise_sample = np.repeat(binomial_prob, nd, axis=2) * noise_to_add 455 | x += noise_sample 456 | 457 | return x 458 | 459 | 460 | def normalize_info(input_data, full_skel=True, eps=1e-4): 461 | """ 462 | :param input_data: 463 | :param full_skel: 464 | :param eps: 1e-4 default 465 | :return: data_mean, data_std, ignore_dimensions, new_idx 466 | """ 467 | data_mean = np.mean(input_data, axis=0) 468 | data_std = np.std(input_data, axis=0) 469 | ignore_dimensions = [] 470 | if not full_skel: 471 | ignore_dimensions = [0, 1, 2, 3, 4, 5] 472 | 473 | ignore_dimensions.extend(list(np.where(data_std < eps)[0])) 474 | # not_ignore_dims = cfg.get_config().not_ignore_dim # left_arm 475 | # for not_ignore_dim in not_ignore_dims: 476 | # ignore_dimensions.remove(not_ignore_dim) 477 | # print('ignore_dimensions: ', ignore_dimensions) 478 | 479 | new_idx = [] 480 | count = 0 481 | for i in range(input_data.shape[1]): 482 | if i in ignore_dimensions: 483 | new_idx.append(-1) 484 | else: 485 | new_idx.append(count) 486 | count += 1 487 | 488 | return data_mean, data_std, ignore_dimensions, np.array(new_idx) 489 | 490 | 491 | def run_all(config): 492 | mot_scale = config.mot_scale 493 | 494 | print('--Loading music data..') 495 | [mus_seq, mus_all] = load_data(config.mus_data_dir, config.train_json_path, mot_scale) 496 | 497 | [test_mus_seq, test_mus_all] = load_data(config.mus_data_dir, config.test_json_path, mot_scale) 498 | 499 | print('--Loading motion data..') 500 | 501 | [mot_seq, mot_all] = load_data(config.mot_data_dir, config.train_json_path, mot_scale) 502 | [test_mot_seq, test_mot_all] = load_data(config.mot_data_dir, config.test_json_path, mot_scale) 503 | 504 | if config.is_all_norm: 505 | print('--Loading all norm motion info data..') 506 | [_, mot_all_data] = load_data(config.mot_data_dir, config.all_json_path, mot_scale) 507 | [_, mus_all_data] = load_data(config.mus_data_dir, config.all_json_path, mot_scale) 508 | else: 509 | mot_all_data = mot_all 510 | mus_all_data = mus_all 511 | 512 | mus_all_len = mus_all_data.shape[0] 513 | mus_all_data = np.reshape(mus_all_data, [mus_all_len, config.mus_dim, 5]) 514 | mus_all_data = mus_all_data.swapaxes(1, 2) 515 | mus_all_data = 
np.reshape(mus_all_data, [mus_all_len*5, config.mus_dim]) 516 | mus_data_max = np.max(mus_all_data, axis=0) 517 | mus_data_min = np.min(mus_all_data, axis=0) 518 | 519 | mot_data_max = np.max(mot_all_data, axis=0) 520 | mot_data_min = np.min(mot_all_data, axis=0) 521 | 522 | [mot_data_mean, mot_data_std, ignore_dimensions, new_idx] = normalize_info(mot_all_data) 523 | 524 | mus_seq_tensor, mot_seq_tensor = \ 525 | get_seq_tensor(mus_seq, mus_data_max, mus_data_min, 526 | mot_seq, mot_data_max, mot_data_min, 527 | mot_data_mean, mot_data_std, config) 528 | 529 | [mus_x_tensor, mot_x_tensor, mot_y_tensor] = \ 530 | sample_seq_idx(mus_seq, mot_seq, config) 531 | 532 | test_mus_seq_tensor, test_mot_seq_tensor = \ 533 | get_seq_tensor(test_mus_seq, mus_data_max, mus_data_min, 534 | test_mot_seq, mot_data_max, mot_data_min, 535 | mot_data_mean, mot_data_std, config) 536 | 537 | data_info = dict() 538 | data_info['ignore_dimensions'] = ignore_dimensions 539 | data_info['new_idx'] = new_idx 540 | data_info['mus'] = [mus_data_max, mus_data_min] 541 | data_info['mot'] = [mot_data_max, mot_data_min, mot_data_mean, mot_data_std] 542 | data_info['train'] = [mus_seq_tensor, mot_seq_tensor, mus_x_tensor, mot_x_tensor, mot_y_tensor] 543 | data_info['test'] = [test_mus_seq_tensor, test_mot_seq_tensor] 544 | 545 | return data_info 546 | 547 | 548 | def capg_seq_generator(epoch, train_type, data_info, config): 549 | if train_type == 0: 550 | mus_data = data_info['train'][0] 551 | mot_data = data_info['train'][1] 552 | mus_idx = data_info['train'][2] 553 | mot_x_idx = data_info['train'][3] 554 | mot_y_idx = data_info['train'][4] 555 | batch_size = config.batch_size 556 | else: 557 | mus_data = data_info['val'][0] 558 | mot_data = data_info['val'][1] 559 | mus_idx = data_info['val'][2] 560 | mot_x_idx = data_info['val'][3] 561 | mot_y_idx = data_info['val'][4] 562 | batch_size = config.batch_size 563 | 564 | std = 1e-12 565 | 566 | if train_type == 0 and config.use_noise: 567 | noise_schedule = config.noise_schedule 568 | for j in range(len(noise_schedule)): 569 | if epoch >= int(noise_schedule[j].split(':')[0]): 570 | std = float(noise_schedule[j].split(':')[1]) 571 | 572 | epoch_size = int(np.floor(len(mot_x_idx) / batch_size)) 573 | 574 | if config.is_shuffle and train_type == 0: 575 | if config.has_random_seed: 576 | np.random.seed(1234567890) 577 | shuffle_list = list(np.random.permutation(len(mot_x_idx))) 578 | mus_idx = [mus_idx[i] for i in shuffle_list] 579 | mot_x_idx = [mot_x_idx[i] for i in shuffle_list] 580 | mot_y_idx = [mot_y_idx[i] for i in shuffle_list] 581 | 582 | for i in range(epoch_size): 583 | x = [] 584 | y = [] 585 | f = [] 586 | for j in range(batch_size): 587 | file_name = mus_idx[i*batch_size+j][0] 588 | x_start = mus_idx[i*batch_size+j][1] 589 | x_end = mus_idx[i*batch_size+j][2] 590 | 591 | y_start = mot_y_idx[i*batch_size+j][1] 592 | y_end = mot_y_idx[i*batch_size+j][2] 593 | 594 | f_start = mot_x_idx[i*batch_size+j][1] 595 | 596 | x_seq = mus_data[file_name][0, x_start:x_end, :] 597 | y_seq = copy.deepcopy(mot_data[file_name][0, y_start:y_end, :]) 598 | f_seq = copy.deepcopy(mot_data[file_name][0, f_start, :]) 599 | x.append(x_seq) 600 | y.append(y_seq) 601 | f.append(f_seq) 602 | 603 | x = np.stack(x, axis=0) 604 | y = np.stack(y, axis=0) 605 | f = np.stack(f, axis=0) 606 | 607 | yield x, y, f 608 | -------------------------------------------------------------------------------- /utils/tf_expsdk.py: -------------------------------------------------------------------------------- 
1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def exp_mul(e0, e1): 6 | original_shape = e0.get_shape().as_list() 7 | e0 = tf.reshape(e0, [-1, 3]) 8 | e1 = tf.reshape(e1, [-1, 3]) 9 | r0 = exp2rot(e0) 10 | r1 = exp2rot(e1) 11 | rm = tf.matmul(r0, r1) 12 | em = rot2exp(rm) 13 | em = tf.reshape(em, original_shape) 14 | 15 | return em 16 | 17 | 18 | def rot2exp(r): 19 | return quat2exp(rot2quat(r)) 20 | 21 | 22 | def rot2quat(r, eps=1e-8): 23 | original_shape = r.get_shape().as_list() 24 | out_shape = original_shape[:-1] 25 | out_shape[-1] = 4 26 | r = tf.reshape(r, [-1, 3, 3]) 27 | d = r - tf.transpose(r, [0, 2, 1]) 28 | r_ = tf.stack([-d[:, 1, 2], d[:, 0, 2], -d[:, 0, 1]], axis=-1) 29 | sin_theta = tf.norm(r_, axis=-1) / 2 30 | r0 = tf.divide(r_, tf.norm(r_, axis=-1, keep_dims=True) + eps) 31 | cos_theta = (tf.trace(r) - 1) / 2 32 | theta = tf.atan2(sin_theta, cos_theta) 33 | theta = tf.reshape(theta, [-1, 1]) 34 | w = tf.cos(theta / 2) 35 | v = r0 * tf.sin(theta / 2) 36 | q = tf.concat([w, v], axis=-1) 37 | q = tf.reshape(q, out_shape) 38 | return q 39 | 40 | 41 | def revert_coordinate_space(chls, init_t, init_r): 42 | org_shape = chls.get_shape().as_list() 43 | 44 | init_t = tf.reshape(init_t, [1, 3]) 45 | init_r = tf.reshape(init_r, [1, 3, 3]) 46 | init_r = tf.tile(init_r, [org_shape[0], 1, 1]) 47 | r_prev = init_r 48 | t_prev = init_t 49 | rec_chls = [] 50 | for i in range(org_shape[1]): 51 | print(i) 52 | r_diff = exp2rot(chls[:, i, 3:6]) 53 | r = tf.matmul(r_diff, r_prev) 54 | rec_r = rot2exp(r) 55 | rec_t = t_prev + tf.squeeze(tf.matmul(tf.transpose(r_prev, [0, 2, 1]), 56 | tf.transpose(chls[:, i:i+1, :3], [0, 2, 1]))) 57 | rec_frame = tf.concat([rec_t, rec_r, chls[:, i, 6:]], axis=-1) 58 | rec_chls.append(rec_frame) 59 | 60 | t_prev = rec_t 61 | r_prev = r 62 | 63 | rec_chls = tf.stack(rec_chls, axis=1) 64 | print('revert_coordinate_space') 65 | return rec_chls 66 | 67 | 68 | def quat2exp(q, eps=1e-8): 69 | original_shape = q.get_shape().as_list() 70 | out_shape = original_shape.copy() 71 | out_shape[-1] = int(out_shape[-1] / 4 * 3) 72 | 73 | q = tf.reshape(q, [-1, 4]) 74 | sin_half_theta = tf.norm(q[:, 1:], axis=-1) 75 | cos_half_theta = q[:, 0] 76 | r0 = q[:, 1:] / (tf.norm(q[:, 1:], axis=-1, keep_dims=True) + eps) 77 | theta = 2 * tf.atan2(sin_half_theta, cos_half_theta) 78 | theta = tf.mod(theta + 2 * np.pi, 2 * np.pi) 79 | pi = tf.constant(np.pi, dtype=tf.float32, shape=theta.get_shape().as_list()) 80 | theta = tf.where(theta > pi, 2 * np.pi - theta, theta) 81 | r0 = tf.where(theta > pi, -r0, r0) 82 | 83 | e = r0 * tf.reshape(theta, [-1, 1]) 84 | e = tf.reshape(e, out_shape) 85 | return e 86 | 87 | 88 | def exp2rot(e, eps=1e-32): 89 | original_shape = e.get_shape().as_list() 90 | out_shape = original_shape.copy() 91 | out_shape.extend([3]) 92 | 93 | e = tf.reshape(e, [-1, 3]) 94 | theta = tf.norm(e, axis=-1, keep_dims=True) 95 | r0 = tf.divide(e, theta + eps) 96 | 97 | c_0 = tf.constant(0, dtype=tf.float32, shape=[e.get_shape().as_list()[0]]) 98 | # row0 = tf.stack([c_0, -r0[:, 2], r0[:, 1]], axis=1) 99 | # row1 = tf.stack([c_0, c_0, -r0[:, 0]], axis=1) 100 | # row2 = tf.stack([c_0, c_0, c_0], axis=1) 101 | # r0x = tf.stack([row0, row1, row2], axis=1) 102 | 103 | c0 = tf.stack([c_0, c_0, c_0], axis=1) 104 | c1 = tf.stack([-r0[:, 2], c_0, c_0], axis=1) 105 | c2 = tf.stack([r0[:, 1], -r0[:, 0], c_0], axis=1) 106 | r0x = tf.stack([c0, c1, c2], axis=2) 107 | 108 | r0x = r0x - tf.transpose(r0x, perm=[0, 2, 1]) 109 | 110 | eye_matrix = tf.constant([[1, 0, 0], [0, 1, 
0], [0, 0, 1]], dtype=tf.float32) 111 | rot_0 = tf.reshape(eye_matrix, [1, 3, 3]) 112 | rot_1 = tf.reshape(tf.sin(theta), [-1, 1, 1]) * r0x 113 | rot_2 = tf.matmul(tf.reshape(1-tf.cos(theta), [-1, 1, 1]) * r0x, r0x) 114 | rot = rot_0 + rot_1 + rot_2 115 | 116 | rot = tf.reshape(rot, out_shape) 117 | return rot 118 | 119 | 120 | def rotation_matrix(x_angle, y_angle, z_angle, order='zxy'): 121 | # TODO: only 'zxy' implementation for now 122 | # order = order.lower() 123 | original_shape = x_angle.get_shape().as_list() 124 | out_shape = original_shape.copy() 125 | out_shape.extend([3, 3]) 126 | 127 | x_angle = tf.reshape(x_angle, [-1]) 128 | y_angle = tf.reshape(y_angle, [-1]) 129 | z_angle = tf.reshape(z_angle, [-1]) 130 | 131 | c1 = tf.cos(x_angle) 132 | c2 = tf.cos(y_angle) 133 | c3 = tf.cos(z_angle) 134 | s1 = tf.sin(x_angle) 135 | s2 = tf.sin(y_angle) 136 | s3 = tf.sin(z_angle) 137 | 138 | r0 = tf.stack([c2*c3-s1*s2*s3, c2*s3+s1*s2*c3, -s2*c1], axis=1) 139 | r1 = tf.stack([-c1*s3, c1*c3, s1], axis=1) 140 | r2 = tf.stack([s2*c3+c2*s1*s3, s2*s3-c2*s1*c3, c2*c1], axis=1) 141 | rm = tf.stack([r0, r1, r2], axis=1) 142 | 143 | rm = tf.reshape(rm, out_shape)  # reshape to [..., 3, 3]; original_shape lacks the matrix dims 144 | return rm 145 | 146 | 147 | def rot_vector(r, v): 148 | original_shape = v.get_shape().as_list() 149 | r = tf.reshape(r, [-1, 3, 3]) 150 | v = tf.reshape(v, [-1, 1, 3]) 151 | rv = tf.matmul(v, r) 152 | return tf.reshape(rv, original_shape) 153 | 154 | 155 | def rot_mul(r0, r1): 156 | original_shape = r0.get_shape().as_list() 157 | r0 = tf.reshape(r0, [-1, 3, 3]) 158 | r1 = tf.reshape(r1, [-1, 3, 3]) 159 | 160 | r = tf.matmul(r0, r1) 161 | return tf.reshape(r, original_shape) 162 | 163 | 164 | def exp2xyz(skel, rotations, root_positions, scale): 165 | """ 166 | :param skel: capg skel 167 | :param rotations: batch_size * num_steps * num_joints * 3 168 | :param root_positions: batch_size * num_steps * 3 169 | :param scale: meter, scale = 100.0 170 | :return: positions: batch_size * num_steps * num_joints * 3 171 | """ 172 | positions_world = [] 173 | rotations_world = [] 174 | rot_shape = rotations.get_shape().as_list() 175 | 176 | for i in range(len(skel)): 177 | if i == 0: 178 | this_pos = root_positions 179 | this_rot = exp2rot(rotations[:, :, 0, :]) 180 | else: 181 | parent = skel[i].parent 182 | offset = tf.constant(np.asarray(skel[i].offset) / scale, dtype=tf.float32) 183 | offset = tf.expand_dims(offset, 0) 184 | offset = tf.expand_dims(offset, 0) 185 | offset = tf.tile(offset, [rot_shape[0], rot_shape[1], 1]) 186 | this_pos = rot_vector(rotations_world[parent], offset) + positions_world[parent] 187 | if skel[i].quat_idx: 188 | # print(skel[i].quat_idx) 189 | this_rot = exp2rot(rotations[:, :, skel[i].quat_idx, :]) 190 | this_rot = rot_mul(this_rot, rotations_world[parent]) 191 | else: 192 | this_rot = None 193 | 194 | positions_world.append(this_pos) 195 | rotations_world.append(this_rot) 196 | 197 | points = tf.stack(positions_world, axis=3) 198 | points = tf.transpose(points, [0, 1, 3, 2]) 199 | 200 | return points 201 | --------------------------------------------------------------------------------
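As a quick sanity check of the exponential-map helpers in utils/tf_expsdk.py, the snippet below (illustrative, not part of the repository) round-trips a small batch of axis-angle vectors through exp2rot and rot2exp; it assumes a TensorFlow 1.x environment and that it is run from the repository root so that utils is importable.

import numpy as np
import tensorflow as tf

from utils import tf_expsdk as tk

# Three arbitrary exponential-map (axis * angle) vectors with rotation angles below pi.
exp_in = tf.constant([[0.1, 0.0, 0.0],
                      [0.0, 0.5, -0.2],
                      [0.3, 0.3, 0.3]], dtype=tf.float32)

rot = tk.exp2rot(exp_in)     # [3, 3, 3] rotation matrices via the Rodrigues formula
exp_back = tk.rot2exp(rot)   # back to exponential maps through quaternions

with tf.Session() as sess:
    e_in, e_out = sess.run([exp_in, exp_back])
    # the round trip should reproduce the inputs up to numerical error
    print(np.max(np.abs(e_in - e_out)))

The same exp2rot/rot2exp pair underlies revert_coordinate_space above, which accumulates the per-frame root rotation and translation to restore global motion before the position-based losses are computed.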