├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── bvh.py
├── finetune.sh
├── inlib
│   ├── __init__.py
│   ├── models.py
│   └── ops.py
├── m2m_config.py
├── model
│   ├── __init__.py
│   ├── base_model.py
│   ├── cl_model.py
│   ├── discriminator_model.py
│   └── gan_model.py
├── pretrain.sh
├── train_gan.py
└── utils
    ├── __init__.py
    ├── capg_exp_skel.pkl
    ├── exp_loss.py
    ├── gan_loss.py
    ├── plot_loss.py
    ├── reader.py
    └── tf_expsdk.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.pyc
3 | result_tmp
4 | dataset/*
5 | output
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |
635 | Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepDance: Music-to-Dance Motion Choreography with Adversarial Learning
2 | This repo contains the training code for our paper on music-to-dance generation: "[DeepDance: Music-to-Dance Motion Choreography with Adversarial Learning](https://ieeexplore.ieee.org/abstract/document/9042236/)". [Project Page](http://zju-capg.org/research_en_music_deepdance.html)
3 |
4 | ## Requirements
5 | - A CUDA compatible GPU
6 | - Ubuntu >= 14.04
7 |
8 | ## Usage
9 |
10 | Clone this repo and create a new environment with the following commands:
11 | ```
12 | git clone https://github.com/computer-animation-perception-group/DeepDance_train.git
13 | conda create -n music_dance python=3.5
14 | conda activate music_dance && pip install -r requirement.txt
15 | ```
16 | Download the processed training data ([fold_json](https://drive.google.com/file/d/18YhFlqkwU6akfjSBgcmywJu_BtfjmAZz/view?usp=sharing), [motion_feature](https://drive.google.com/file/d/18Hk5jEW8DV_AXzWZcvdLkUvlTiVdZ0Sp/view?usp=sharing) and [music_feature](https://drive.google.com/file/d/1VMt_fhG2livx1keh9o9Vu6zwwZPgB3ZZ/view?usp=sharing)), extract them into `./dataset`, and run the following scripts:
17 | ```
18 | bash pretrain.sh
19 | bash finetune.sh
20 | ```
21 | Once training is complete, you can generate novel dances from the trained models using our [demo code](https://github.com/computer-animation-perception-group/DeepDance).
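For reference, the default loader paths in `m2m_config.py` imply the following layout under `./dataset` (exact subfolder contents depend on the downloaded archives; `<type>` and `<fold>` stand for the dance type and fold index passed on the command line):

```
dataset/
├── fold_json/
│   └── <type>/<fold>/{train_list.json, test_list.json}
├── motion_feature/
│   └── exp/
└── music_feature/
    └── librosa/
```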
22 |
23 | ## License
24 | Licensed under the GPL v3.0 license, for research purposes only.
25 |
26 | ## Bibtex
27 | ```
28 | @article{sun2020deepdance,
29 | author={G. {Sun} and Y. {Wong} and Z. {Cheng} and M. S. {Kankanhalli} and W. {Geng} and X. {Li}},
30 | journal={IEEE Transactions on Multimedia},
31 | title={DeepDance: Music-to-Dance Motion Choreography with Adversarial Learning},
32 | year={2021},
33 | volume={23},
34 | number={},
35 | pages={497-509},}
36 | ```
37 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computer-animation-perception-group/DeepDance_train/9f47de0b9f599b2590608465ffb96c9eb0344502/__init__.py
--------------------------------------------------------------------------------
/bvh.py:
--------------------------------------------------------------------------------
1 | class Node:
2 | def __init__(self, root=False):
3 | self.name = None
4 | self.channels = []
5 | self.offset = (0, 0, 0)
6 | self.children = []
7 | self._is_root = root
8 | self.order = ""
9 | self.pos_idx = []
10 | self.exp_idx = []
11 | self.rot_idx = []
12 | self.quat_idx = []
13 | self.parent = []
14 |
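For reference, a minimal sketch of how this `Node` class might be used to assemble a small joint hierarchy. The joint names, channels, and offset below are illustrative, not taken from the repo; `Node` is reproduced (trimmed) so the snippet is self-contained.

```python
class Node:
    # trimmed copy of bvh.Node, keeping only the fields used below
    def __init__(self, root=False):
        self.name = None
        self.channels = []
        self.offset = (0, 0, 0)
        self.children = []
        self._is_root = root
        self.parent = []  # bvh.py stores the parent in a list


# root joint with the usual 6 BVH channels
hips = Node(root=True)
hips.name = "Hips"
hips.channels = ["Xposition", "Yposition", "Zposition",
                 "Zrotation", "Xrotation", "Yrotation"]

# one child joint, offset relative to its parent
spine = Node()
spine.name = "Spine"
spine.offset = (0.0, 10.0, 0.0)
spine.parent = [hips]
hips.children.append(spine)
```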
--------------------------------------------------------------------------------
/finetune.sh:
--------------------------------------------------------------------------------
1 | gpu=0
2 | dis_type='DisSegGraph'
3 | loss_mode='gan'
4 | fold_idx=0
5 | seg_len=90
6 | loss_type=2
7 | if [ $loss_type == 1 ]; then
8 | loss_arr=(1.0 0.05 0.0)
9 | elif [ $loss_type == 2 ]; then
10 | loss_arr=(1.0 0.1 0.1)
11 | else
12 | loss_arr=(1.0 0.0 0.0)
13 | fi
14 | mus_ebd_dim=72
15 | dis_name='time_cond_cnn'
16 | kernel_size=(1 3)
17 | stride=(1 2)
18 | cond_axis=1
19 | model_path=./output/pretrain/all-f4/model/cnn-erd_19_model.ckpt.meta
20 | CUDA_VISIBLE_DEVICES=$gpu \
21 | python3 train_gan.py --learning_rate 1e-4 \
22 | --dis_learning_rate 2e-5 \
23 | --mse_rate 1 \
24 | --dis_rate 0.01 \
25 | --loss_mode $loss_mode \
26 | --is_load_model True \
27 | --is_reg True \
28 | --reg_scale 5e-5 \
29 | --rnn_keep_list 0.95 0.9 1.0 \
30 | --dis_type $dis_type \
31 | --dis_name $dis_name \
32 | --loss_rate_list ${loss_arr[0]} ${loss_arr[1]} ${loss_arr[2]} \
33 | --kernel_size ${kernel_size[0]} ${kernel_size[1]} \
34 | --stride ${stride[0]} ${stride[1]} \
35 | --act_type lrelu \
36 | --optimizer Adam \
37 | --cond_axis $cond_axis \
38 | --seg_list $seg_len \
39 | --seq_shift 1 \
40 | --gen_hop $seg_len \
41 | --fold_list $fold_idx \
42 | --type_list gudianwu \
43 | --model_path ${model_path%.*} \
44 | --max_max_epoch 15 \
45 | --save_data_epoch 5 \
46 | --save_model_epoch 5 \
47 | --is_save_train False \
48 | --mot_scale 100. \
49 | --norm_way zscore \
50 | --teacher_forcing_ratio 0. \
51 | --tf_decay 1. \
52 | --batch_size 128 \
53 | --mus_ebd_dim $mus_ebd_dim \
54 | --has_random_seed False \
55 | --is_all_norm True \
56 | --add_info ./output/finetune
--------------------------------------------------------------------------------
/inlib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computer-animation-perception-group/DeepDance_train/9f47de0b9f599b2590608465ffb96c9eb0344502/inlib/__init__.py
--------------------------------------------------------------------------------
/inlib/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .ops import *
3 | import tensorflow.contrib as tf_contrib
4 |
5 |
6 | def convolution(value, output_num, kernel_size=[3, 3], strides=[1, 1], name='conv', padding='SAME',
7 | activate_type='relu'):
8 | x = conv2d(value, output_num, kernel_size[0], kernel_size[1], [1, strides[0], strides[1], 1], name, padding)
9 | if activate_type == 'relu':
10 | x = xelu(x, 'relu_' + name, activate_type)
11 | else:
12 | x = xelu(x, 'lrelu_' + name, activate_type)
13 | return x
14 |
15 |
16 | class vgg19:
17 | def __init__(self, model_path):
18 | self.data_dict = np.load(model_path, allow_pickle=True).item()
19 | print('load vgg19 weight complete.')
20 |
21 | def get_feature(self, x, reuse=False):
22 | self.channels = x.get_shape()[-1]
23 | with tf.variable_scope('vgg19') as scope:
24 | if reuse:
25 | scope.reuse_variables()
26 | # conv1
27 | self.conv1_1 = convolution(x, 64, [3, 3], [1, 1], name='conv1_1')
28 | self.conv1_2 = convolution(self.conv1_1, 64, [3, 3], [1, 1], name='conv1_2')
29 | self.conv1_2 = pool2D(self.conv1_2, 2, 2, name='max_pool1')
30 |
31 | # conv2
32 | self.conv2_1 = convolution(self.conv1_2, 128, [3, 3], [1, 1], name='conv2_1')
33 | self.conv2_2 = convolution(self.conv2_1, 128, [3, 3], [1, 1], name='conv2_2')
34 | self.conv2_2 = pool2D(self.conv2_2, 2, 2, name='max_pool2')
35 |
36 | # conv3
37 | self.conv3_1 = convolution(self.conv2_2, 256, [3, 3], [1, 1], name='conv3_1')
38 | self.conv3_2 = convolution(self.conv3_1, 256, [3, 3], [1, 1], name='conv3_2')
39 | self.conv3_3 = convolution(self.conv3_2, 256, [3, 3], [1, 1], name='conv3_3')
40 | self.conv3_4 = convolution(self.conv3_3, 256, [3, 3], [1, 1], name='conv3_4')
41 | self.conv3_4 = pool2D(self.conv3_4, 2, 2, name='max_pool3')
42 |
43 | # conv4
44 | self.conv4_1 = convolution(self.conv3_4, 512, [3, 3], [1, 1], name='conv4_1')
45 | self.conv4_2 = convolution(self.conv4_1, 512, [3, 3], [1, 1], name='conv4_2')
46 | self.conv4_3 = convolution(self.conv4_2, 512, [3, 3], [1, 1], name='conv4_3')
47 | self.conv4_4 = convolution(self.conv4_3, 512, [3, 3], [1, 1], name='conv4_4')
48 | self.conv4_4 = pool2D(self.conv4_4, 2, 2, name='max_pool4')
49 |
50 | # conv5
51 | self.conv5_1 = convolution(self.conv4_4, 512, [3, 3], [1, 1], name='conv5_1')
52 | self.conv5_2 = convolution(self.conv5_1, 512, [3, 3], [1, 1], name='conv5_2')
53 | self.conv5_3 = convolution(self.conv5_2, 512, [3, 3], [1, 1], name='conv5_3')
54 | self.conv5_4 = convolution(self.conv5_3, 512, [3, 3], [1, 1], name='conv5_4')
55 |
56 | # flatten
57 | self.feature5_4 = tf.reshape(self.conv5_4, [x.get_shape()[0], -1], name='feature5_4')
58 |
59 | return self.feature5_4
60 |
61 | def load_weights(self, sess):
62 | vars = tf.trainable_variables(scope='vgg19')
63 | loaded_vars = [var for var in vars if 'conv' in var.name]
64 | keys = sorted(self.data_dict)
65 | for i in range(len(keys)):
66 | print(loaded_vars[i * 2], keys[i] + '_weight')
67 | sess.run(loaded_vars[i * 2].assign(self.data_dict[keys[i]][0]))
68 | print(loaded_vars[i * 2 + 1], keys[i] + '_bias')
69 | sess.run(loaded_vars[i * 2 + 1].assign(self.data_dict[keys[i]][1]))
70 |
71 |
72 | def mlp(x, dim_list, name='mlp', reuse=False):
73 | # dim_list=[[output_num1, activation1],...,[output_numk, activationk]]
74 | with tf.variable_scope(name) as scope:
75 | if reuse:
76 | scope.reuse_variables()
77 | outputs = [x]
78 | for i in range(len(dim_list)):
79 | out = fc(outputs[-1], dim_list[i][0], name='fc' + str(i + 1))
80 | if i < len(dim_list) - 1:
81 | out = xelu(out, name=dim_list[i][1] + str(i + 1), activate_type=dim_list[i][1])
82 | outputs.append(out)
83 | return outputs[-1]
84 |
85 |
86 | def cnn(x, conv_list, fc_list, name='cnn', is_training=True, reuse=False):
87 | # conv_list=[[output_num1, kernels1, strides1, padding1, activation1],
88 | # ...[output_numk, kernelsk, stridesk, paddingk, activationk]]
89 | # fc_list=[[output_num1, activation1],...,[output_numk,activationk]
90 | with tf.variable_scope(name) as scope:
91 | if reuse:
92 | scope.reuse_variables()
93 | outputs = [x]
94 | for i in range(len(conv_list)):
95 | out = convolution(outputs[-1], conv_list[i][0], conv_list[i][1], conv_list[i][2], name='conv' + str(i + 1),
96 | padding=conv_list[i][3], activate_type=conv_list[i][4])
97 | if 'bn' in conv_list[i]:
98 | out = bn(out, is_training, scope='bn'+str(i+1))
99 | outputs.append(out)
100 | # print('outputs: ', outputs[-1])
101 | outputs.append(tf.reshape(outputs[-1], [int(x.get_shape()[0]), -1]))
102 | for i in range(len(fc_list)):
103 | out = fc(outputs[-1], fc_list[i][0], name='fc' + str(i + 1))
104 | if i < len(fc_list) - 1:
105 | out = xelu(out, name=fc_list[i][1] + str(i + 1), activate_type=fc_list[i][1])
106 | outputs.append(out)
107 | return outputs[-1]
108 |
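The comments in `mlp()` and `cnn()` describe the nested-list specs these helpers consume. A small data-only sketch of such specs (layer sizes here are illustrative, not from the repo):

```python
# conv_list entry: [output_num, kernel_size, strides, padding, activation],
# with an optional trailing 'bn' flag that cnn() checks to insert batch norm
conv_list = [
    [64, [3, 3], [1, 1], 'SAME', 'relu'],
    [128, [3, 3], [2, 2], 'SAME', 'lrelu', 'bn'],
]

# fc_list entry: [output_num, activation]; note mlp()/cnn() skip the
# activation on the final layer (i == len(list) - 1)
fc_list = [
    [256, 'relu'],
    [1, 'none'],
]
```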
--------------------------------------------------------------------------------
/inlib/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.layers import pooling
3 | import tensorflow.contrib as tf_contrib
4 |
5 |
6 | def weight(shape, stddev=0.02, name='weight', trainable=True):
7 | dtype = tf.float32
8 | var = tf.get_variable(name, shape, dtype, initializer=tf.random_normal_initializer(0.0, stddev, dtype=dtype),
9 | trainable=trainable)
10 | return var
11 |
12 |
13 | def bias(dim, bias_start=0.0, name='bias', trainable=True):
14 | dtype = tf.float32
15 | var = tf.get_variable(name, dim, dtype, initializer=tf.constant_initializer(value=bias_start, dtype=dtype),
16 | trainable=trainable)
17 |
18 | return var
19 |
20 |
21 | def xelu(value, name='relu', activate_type='relu', para=0.2):
22 | with tf.variable_scope(name):
23 | if activate_type == 'relu':
24 | # relu
25 | return tf.nn.relu(value)
26 | elif activate_type == 'lrelu':
27 | # leaky relu
28 | return tf.maximum(value, value * para)
29 | else:
30 | return value
31 |
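`xelu`'s `'lrelu'` branch, `tf.maximum(value, value * para)`, is an elementwise `max(x, para * x)`, i.e. a leaky ReLU with slope `para` on the negative side. A plain-NumPy restatement:

```python
import numpy as np

def leaky_relu(x, para=0.2):
    # same math as xelu's 'lrelu' branch: identity for x >= 0,
    # slope `para` for x < 0
    return np.maximum(x, para * x)

x = np.array([-2.0, -0.5, 0.0, 3.0])
y = leaky_relu(x)   # [-0.4, -0.1, 0.0, 3.0]
```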
32 |
33 | def pool2D(value, k_h=3, k_w=3, strides=[1, 2, 2, 1], name='max_pool', padding='VALID'):
34 | kernel_size = [1, k_h, k_w, 1]
35 | with tf.variable_scope(name + '_2d'):
36 | if name == 'max_pool':
37 | # max pooling
38 | return tf.nn.max_pool(value, kernel_size, strides, padding)
39 | elif name == 'avg_pool':
40 | # average pooling
41 | return tf.nn.avg_pool(value, kernel_size, strides, padding)
42 | else:
43 | # default: max pooling
44 | return tf.nn.max_pool(value, kernel_size, strides, padding)
45 |
46 |
47 | def pool1D(value, ksize=3, strides=[1, 2, 1], name='max_pool', padding='VALID'):
48 | # tf.layers 1-D pooling expects a scalar pool_size and stride, not lists
49 | with tf.variable_scope(name + '_1d'):
50 | if name == 'max_pool':
51 | # max pooling
52 | return pooling.max_pooling1d(value, ksize, strides[1], padding)
53 | elif name == 'avg_pool':
54 | # average pooling
55 | return pooling.average_pooling1d(value, ksize, strides[1], padding)
56 | else:
57 | # default: max pooling
58 | return pooling.max_pooling1d(value, ksize, strides[1], padding)
59 |
60 |
61 | def fc(value, output_num, name='fc', with_weight=False, with_bias=True):
62 | input_shape = value.get_shape().as_list()
63 | with tf.variable_scope(name):
64 | weights = weight([input_shape[1], output_num])
65 | output = tf.matmul(value, weights)
66 | if with_bias:
67 | biases = bias(output_num)
68 | output = output + biases
69 | if with_weight:
70 | if with_bias:
71 | return output, weights, biases
72 | else:
73 | return output, weights
74 | else:
75 | return output
76 |
77 |
78 | def conv1d(value, output_num, ksize=3, strides=[1, 1, 1], name='conv', padding='SAME', with_weight=False,
79 | with_bias=True):
80 | with tf.variable_scope(name + '_1d'):
81 | weights = weight([ksize, value.get_shape()[-1], output_num])
82 | conv = tf.nn.conv1d(value, weights, strides, padding, use_cudnn_on_gpu=True)
83 | if with_bias:
84 | biases = bias(output_num)
85 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
86 | if with_weight:
87 | if with_bias:
88 | return conv, weights, biases
89 | else:
90 | return conv, weights
91 | else:
92 | return conv
93 |
94 |
95 | def conv2d(value, output_num, k_h=3, k_w=3, strides=[1, 1, 1, 1], name='conv', padding='SAME', with_weight=False,
96 | with_bias=True):
97 | with tf.variable_scope(name + '_2d'):
98 | weights = weight([k_h, k_w, value.get_shape()[-1], output_num])
99 | conv = tf.nn.conv2d(value, weights, strides, padding, use_cudnn_on_gpu=True)
100 | if with_bias:
101 | biases = bias(output_num)
102 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
103 | if with_weight:
104 | if with_bias:
105 | return conv, weights, biases
106 | else:
107 | return conv, weights
108 | else:
109 | return conv
110 |
111 |
112 | def bn(x, is_training, scope='bn'):
113 | return tf.layers.batch_normalization(x,
114 | axis=-1,
115 | training=is_training,
116 | name=scope)
117 |
--------------------------------------------------------------------------------
/m2m_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 |
5 |
6 | class CAPGConfig(object):
7 | """capg config"""
8 | def __init__(self, m_type, fold_num, seg_len):
9 | self.mot_data_dir = './dataset/motion_feature/exp/'
10 | self.mus_data_dir = './dataset/music_feature/librosa/'
11 | self.json_dir = './dataset/fold_json/'
12 | self.all_json_path = os.path.join(self.json_dir, 'all-f4', fold_num, 'train_list.json')
13 | self.train_json_path = os.path.join(self.json_dir, m_type, fold_num, 'train_list.json')
14 | self.test_json_path = os.path.join(self.json_dir, m_type, fold_num, 'test_list.json')
15 | self.hidden_size = 512
16 | self.mot_hidden_size = 1024
17 |
18 | self.is_save_model = True
19 | self.is_load_model = False
20 | self.save_epoch = 0
21 | self.test_epoch = 5
22 | self.gen_hop = 10
23 | self.seq_shift = 1 # 15 for beat
24 | self.use_mus_rnn = True
25 | self.mus_rnn_layers = 1
26 | self.max_max_epoch = 15
27 | self.is_reg = False
28 | self.reg_scale = 5e-4
29 | self.rnn_keep_prob = 1
30 |
31 | self.is_shuffle = True
32 | self.has_random_seed = False
33 |
34 | self.is_align = True
35 | self.mus_delay = 0 # 1 for beat
36 | self.mot_ignore_dims = [18, 19, 20, 33, 34, 35, 48, 49, 50, 60, 61, 62, 72, 73, 74]
37 | self.mot_dim = 60
38 |
39 | self.is_z_score = True
40 | self.is_all_norm = False
41 | self.mus_dim = 201
42 | self.mus_kernel_size = 51
43 | self.batch_size = 32
44 | self.num_steps = int(seg_len)
45 | self.test_num_steps = int(seg_len)
46 | self.max_epoch = 20
47 | self.lr_decay = 1
48 | self.max_grad_norm = 25
49 | self.val_data_len = 150
50 |
51 | self.rnn_layers = 3
52 | self.mot_rnn_layers = 2
53 |
54 | self.info = "gan_gt"
55 | self.val_batch_size = 1
56 | self.test_batch_size = 1
57 | self.is_use_pre_mot = False
58 |
59 | self.use_noise = False
60 | self.noise_schedule = ['2:0.05', '6:0.1', '12:0.2', '16:0.3', '22:0.5', '30:0.8', '36:1.0']
61 | self.start_idx = 1
62 |
63 | def save_config(self, path):
64 | config_dict = dict()
65 | for name, value in vars(self).items():
66 | if isinstance(value, list):
67 | value = np.asarray(value).tolist()
68 | config_dict[name] = value
69 | json.dump(config_dict, open(path, 'w'), indent=4, sort_keys=True)
70 |
71 |
72 | def get_config(m_type, fold_num, seg_len):
73 | config = CAPGConfig(m_type, fold_num, seg_len)
74 | return config
75 |
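`save_config()` above serializes every instance attribute via `vars()` into sorted, indented JSON. A minimal standalone sketch of the same pattern; `TinyConfig` is a hypothetical stand-in for `CAPGConfig` (the repo version additionally normalizes lists through `np.asarray(...).tolist()`):

```python
import json
import os
import tempfile

class TinyConfig:
    def __init__(self):
        self.hidden_size = 512
        self.rnn_keep_list = [0.95, 0.9, 1.0]

    def save_config(self, path):
        # vars(self) collects all instance attributes into a dict
        with open(path, 'w') as f:
            json.dump(vars(self), f, indent=4, sort_keys=True)

fd, path = tempfile.mkstemp(suffix='.json')
os.close(fd)
TinyConfig().save_config(path)
with open(path) as f:
    loaded = json.load(f)
os.remove(path)
```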
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computer-animation-perception-group/DeepDance_train/9f47de0b9f599b2590608465ffb96c9eb0344502/model/__init__.py
--------------------------------------------------------------------------------
/model/base_model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib as tf_contrib
3 |
4 |
5 | def data_type():
6 | return tf.float32
7 |
8 |
9 | class BaseModel(object):
10 | """The Base model."""
11 | def __init__(self, train_type, config):
12 | if train_type == 0:
13 | print("---init training graph---")
14 | is_training = True
15 | elif train_type == 1:
16 | print('---init validate graph---')
17 | is_training = False
18 | else:
19 | print("---init test graph---")
20 | is_training = False
21 |
22 | self.is_training = is_training
23 | config.is_training = is_training
24 | self.mot_dim = config.mot_dim
25 | self.mus_ebd_dim = config.mus_ebd_dim
26 | self.batch_size = config.batch_size
27 | self.num_steps = config.num_steps
28 | self.input_x = tf.placeholder(shape=[self.batch_size, self.num_steps, None],
29 | dtype=data_type(), name="input_x")
30 | # mot_input
31 | self.input_y = tf.placeholder(shape=[self.batch_size, self.num_steps, self.mot_dim],
32 | dtype=data_type(), name="input_y")
33 |
34 | self.init_step_mot = tf.placeholder(data_type(), [self.batch_size, self.mot_dim], name="init_step_mot")
35 | self.tf_mask = tf.placeholder(shape=[self.num_steps], dtype=tf.bool, name='tf_mask')
36 |
37 | mot_predictions, mus_ebd_outputs, mot_state, mus_state = \
38 | self._build_mot_rnn_graph(mus_inputs=self.input_x,
39 | config=config,
40 | train_type=train_type)
41 |
42 | self.mot_final_state = mot_state
43 | self.mus_final_state = mus_state
44 | self.mot_predictions = mot_predictions
45 | self.mot_truth = self.input_y
46 | self.mus_ebd_outputs = mus_ebd_outputs
47 |
48 | def _build_mus_graph(self, time_step, mus_cell, mus_state, inputs, config, is_training):
49 | print("mus_graph")
50 | # outputs = []
51 | with tf.variable_scope("mus_rnn"):
52 | fc_weights = tf.get_variable('fc', [config.hidden_size, self.mus_ebd_dim],
53 | initializer=tf.truncated_normal_initializer())
54 | fc_biases = tf.get_variable('bias', [self.mus_ebd_dim],
55 | initializer=tf.zeros_initializer())
56 | if time_step > 0:
57 | tf.get_variable_scope().reuse_variables()
58 |
59 | mus_input = self._build_mus_conv_graph(inputs, config, is_training)
60 | (cell_output, mus_state) = mus_cell(mus_input, mus_state)
61 | # outputs.append(cell_output)
62 | output = tf.reshape(cell_output, [-1, config.hidden_size])
63 | fc_output = tf.nn.xw_plus_b(output, fc_weights, fc_biases)
64 |
65 | return fc_output, mus_state
66 |
67 | @staticmethod
68 | def _get_lstm_cell(rnn_layer_idx, hidden_size, config, is_training):
69 | lstm_cell = tf_contrib.rnn.BasicLSTMCell(
70 | hidden_size, forget_bias=0.0, state_is_tuple=True,
71 | reuse=tf.get_variable_scope().reuse)
72 | print('rnn_layer: ', rnn_layer_idx, config.rnn_keep_list[rnn_layer_idx])
73 | if is_training and config.rnn_keep_list[rnn_layer_idx] < 1:
74 | lstm_cell = tf_contrib.rnn.DropoutWrapper(lstm_cell,
75 | output_keep_prob=config.rnn_keep_list[rnn_layer_idx])
76 | return lstm_cell
77 |
78 | def _build_mot_rnn_graph(self, mus_inputs, config, train_type):
79 | if train_type == 0:
80 | is_training = True
81 | else:
82 | is_training = False
83 |
84 | rnn_layer_idx = 0
85 | mus_cell = tf_contrib.rnn.MultiRNNCell(
86 | [self._get_lstm_cell(i, config.hidden_size, config, is_training)
87 | for i in range(rnn_layer_idx, rnn_layer_idx + config.mus_rnn_layers)], state_is_tuple=True)
88 |
89 | rnn_layer_idx += config.mus_rnn_layers
90 | mot_cell = tf_contrib.rnn.MultiRNNCell(
91 | [self._get_lstm_cell(i, config.mot_hidden_size, config, is_training)
92 | for i in range(rnn_layer_idx, rnn_layer_idx + config.mot_rnn_layers)], state_is_tuple=True)
93 |
94 | self.mot_initial_state = mot_cell.zero_state(config.batch_size, data_type())
95 | mot_state = self.mot_initial_state
96 |
97 | self.mus_initial_state = mus_cell.zero_state(config.batch_size, data_type())
98 | mus_state = self.mus_initial_state
99 |
100 | last_step_mot = self.init_step_mot
101 | outputs = []
102 | mus_ebd_outputs = []
103 |
104 | with tf.variable_scope("generator/mot_rnn"):
105 |
106 | for time_step in range(self.num_steps):
107 | if time_step > 0:
108 | tf.get_variable_scope().reuse_variables()
109 | last_step_mot = tf.cond(tf.equal(self.tf_mask[time_step], tf.constant(True)),
110 | lambda: self.input_y[:, time_step-1, :],
111 | lambda: self.last_step_mot)
112 |
113 | mus_input = mus_inputs[:, time_step, :]
114 | print("mot_rnn: ", time_step)
115 | if not config.use_mus_rnn:
116 | with tf.variable_scope("mus_rnn"):
117 | mus_fea = self._build_mus_conv_graph(mus_input, config, is_training)
118 | else:
119 | mus_fea, mus_state = self._build_mus_graph(time_step, mus_cell,
120 | mus_state, mus_input, config, is_training)
121 | mot_input = last_step_mot
122 | mot_input = tf.reshape(mot_input, [-1, self.mot_dim])
123 | # mus_fea = tf.zeros(tf.shape(mus_fea))
124 | all_input = tf.concat([mus_fea, mot_input], 1, name='mus_mot_input')
125 |
126 | # fc1
127 | fc1_weights = tf.get_variable('fc1', [self.mus_ebd_dim + self.mot_dim, 500], dtype=data_type())
128 | fc1_biases = tf.get_variable('bias1', [500], dtype=data_type())
129 | fc1_linear = tf.nn.xw_plus_b(all_input, fc1_weights, fc1_biases, name='fc1_linear')
130 | fc1_relu = tf.nn.relu(fc1_linear, name='fc1_relu')
131 |
132 | # fc2
133 | fc2_weights = tf.get_variable('fc2', [500, 500], dtype=data_type())
134 | fc2_biases = tf.get_variable('bias2', [500], dtype=data_type())
135 | fc2_linear = tf.nn.xw_plus_b(fc1_relu, fc2_weights, fc2_biases, name='fc2_linear')
136 |
137 | (cell_output, mot_state) = mot_cell(fc2_linear, mot_state)
138 | output = tf.reshape(cell_output, [-1, config.mot_hidden_size])
139 |
140 | # fc3
141 | fc3_weights = tf.get_variable('fc3', [config.mot_hidden_size, 500], dtype=data_type())
142 | fc3_biases = tf.get_variable('bias3', [500], dtype=data_type())
143 | fc3_linear = tf.nn.xw_plus_b(output, fc3_weights, fc3_biases, name='fc3_linear')
144 | fc3_relu = tf.nn.relu(fc3_linear, name='fc3_relu')
145 |
146 | fc4_weights = tf.get_variable('fc4', [500, 100], dtype=data_type())
147 | fc4_biases = tf.get_variable('bias4', [100], dtype=data_type())
148 | fc4_linear = tf.nn.xw_plus_b(fc3_relu, fc4_weights, fc4_biases, name='fc4_linear')
149 | fc4_relu = tf.nn.relu(fc4_linear, name='fc4_relu')
150 |
151 | fc5_weights = tf.get_variable('fc5', [100, self.mot_dim], dtype=data_type())
152 | fc5_biases = tf.get_variable('bias5', [self.mot_dim], dtype=data_type())
153 | fc5_linear = tf.nn.xw_plus_b(fc4_relu, fc5_weights, fc5_biases, name='fc5_linear')
154 | self.last_step_mot = fc5_linear
155 |
156 | outputs.append(fc5_linear)
157 | mus_ebd_outputs.append(mus_fea)
158 |
159 | outputs = tf.reshape(tf.concat(outputs, 1), [self.batch_size, self.num_steps, self.mot_dim])
160 | mus_ebd_outputs = tf.reshape(tf.concat(mus_ebd_outputs, 1), [self.batch_size, self.num_steps, self.mus_ebd_dim])
161 |
162 | return outputs, mus_ebd_outputs, mot_state, mus_state
163 |
164 | @staticmethod
165 | def _mus_conv(inputs, kernel_shape, bias_shape, is_training):
166 | conv_weights = tf.get_variable('conv', kernel_shape,
167 | initializer=tf.truncated_normal_initializer())
168 | # tf.summary.histogram("conv weights", conv_weights)
169 | conv_biases = tf.get_variable('bias', bias_shape,
170 | initializer=tf.zeros_initializer())
171 | conv = tf.nn.conv2d(inputs,
172 | conv_weights,
173 | strides=[1, 1, 1, 1],
174 | padding='VALID')
175 | bias = tf.nn.bias_add(conv, conv_biases)
176 | norm = tf.layers.batch_normalization(bias, axis=3,
177 | training=is_training)
178 |
179 | elu = tf.nn.elu(norm)
180 | return elu
181 |
182 | def _build_mus_conv_graph(self, inputs, config, is_training):
183 | """Build music graph"""
184 |
185 | print("mus_conv_graph")
186 | mus_dim = config.mus_dim
187 | mus_input = tf.reshape(inputs, [-1, mus_dim, 5, 1])
188 |
189 | with tf.variable_scope('conv1'):
190 | elu1 = self._mus_conv(mus_input,
191 | kernel_shape=[mus_dim, 2, 1, 64],
192 | bias_shape=[64],
193 | is_training=is_training)
194 | with tf.variable_scope('conv2'):
195 | elu2 = self._mus_conv(elu1,
196 | kernel_shape=[1, 2, 64, 128],
197 | bias_shape=[128],
198 | is_training=is_training)
199 |
200 | with tf.variable_scope('conv3'):
201 | elu3 = self._mus_conv(elu2,
202 | kernel_shape=[1, 2, 128, 256],
203 | bias_shape=[256],
204 | is_training=is_training)
205 |
206 | with tf.variable_scope('conv4'):
207 | elu4 = self._mus_conv(elu3,
208 | kernel_shape=[1, 2, 256, 512],
209 | bias_shape=[512],
210 | is_training=is_training)
211 | mus_conv_output = tf.reshape(elu4, [-1, 512])
212 | return mus_conv_output
213 |
--------------------------------------------------------------------------------
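The `_build_mus_conv_graph` stack above collapses the five-frame music window to a single 512-channel feature via four VALID convolutions with kernel width 2. A minimal sketch (plain Python, not part of the repo) of the width arithmetic:

```python
# Sketch: output widths of the four VALID convs in _build_mus_conv_graph.
# Each kernel of width 2 with stride 1 shrinks the time axis by one frame.
def valid_conv_width(w, k, stride=1):
    """Output width of a VALID convolution: floor((w - k) / stride) + 1."""
    return (w - k) // stride + 1

widths = [5]                 # input window: 5 music frames
for _ in range(4):           # conv1 .. conv4, each kernel width 2
    widths.append(valid_conv_width(widths[-1], 2))
print(widths)                # [5, 4, 3, 2, 1]
```

After conv4 the time axis is 1 wide with 512 channels, which is why the final reshape to `[-1, 512]` is lossless.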
/model/cl_model.py:
--------------------------------------------------------------------------------
1 | from model.base_model import *
2 | from utils import exp_loss as es
3 |
4 |
5 | class CLModel(BaseModel):
6 | def __init__(self, train_type, config):
7 | super(CLModel, self).__init__(train_type, config)
8 |
9 | mot_predictions = self.mot_predictions
10 | mot_truth = self.mot_truth
11 |
12 | tru_pos, pre_pos = es.get_pos_chls(mot_predictions, mot_truth, config)
13 |
14 | # generator loss
15 | g_loss, loss_list = es.loss_impl(mot_predictions, mot_truth, pre_pos, tru_pos, config)
16 | self.g_loss = loss_list
17 |
18 | # if test, return
19 | if not self.is_training:
20 | return
21 |
22 | tvars = tf.trainable_variables()
23 | g_vars = [v for v in tvars if 'generator' in v.name]
24 |
25 | # add reg
26 | if config.is_reg:
27 | reg_cost = tf.reduce_sum([tf.nn.l2_loss(v) for v in g_vars
28 | if 'bias' not in v.name]) * config.reg_scale
29 | g_loss = g_loss + reg_cost
30 |
31 | gen_learning_rate = config.learning_rate
32 |
33 | if config.optimizer.lower() == 'adam':
34 | print('Adam optimizer')
35 | g_optimizer = tf.train.AdamOptimizer(learning_rate=gen_learning_rate)
36 | else:
37 | print('Rmsprop optimizer')
38 | g_optimizer = tf.train.RMSPropOptimizer(learning_rate=gen_learning_rate)
39 |
40 | # for batch_norm op
41 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
42 | g_grads = tf.gradients(g_loss, g_vars, aggregation_method=2)
43 | with tf.control_dependencies(update_ops):
44 | self.train_g_op = g_optimizer.apply_gradients(zip(g_grads, g_vars))
45 | print('train_g_op')
46 |
--------------------------------------------------------------------------------
/model/discriminator_model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import inlib.models as md
3 | import numpy as np
4 |
5 |
6 | class DisGraph(object):
7 | def __init__(self, inputs, cond_inputs, config, name, is_reuse):
8 | self.inputs = inputs
9 | self.cond_inputs = cond_inputs
10 | self.name = name
11 | self.is_reuse = is_reuse
12 | self.act_type = config.act_type
13 | self.kernel_size = config.kernel_size
14 | self.cond_axis = config.cond_axis
15 | self.stride = config.stride
16 | self.mus_ebd_dim = config.mus_ebd_dim
17 | self.batch_size = config.batch_size
18 | self.num_steps = config.num_steps
19 | self.is_training = config.is_training
20 | self.is_shuffle = config.is_shuffle
21 |
22 | def build_dis_graph(self):
23 | if self.name == 'mlp':
24 | outputs = self._build_dis_mlp_graph()
25 | elif self.name == 'cnn':
26 | outputs = self._build_dis_cnn_graph()
27 | elif self.name == 'sig_cnn':
28 | outputs = self._build_dis_sig_cnn_graph()
29 | elif self.name == 'cond_cnn':
30 | outputs = self._build_dis_cond_cnn_graph()
31 | elif self.name == 'time_cond_cnn':
32 | outputs = self._build_dis_time_cond_cnn_graph()
33 | elif self.name == 'tgan_cond_cnn':
34 | outputs = self._build_dis_tgan_cond_cnn_graph()
35 | elif self.name == 'time_tgan_cond_cnn':
36 | outputs = self._build_dis_time_tgan_cond_cnn_graph()
37 | else:
38 |             raise ValueError('Invalid discriminator name: {}'.format(self.name))
39 |
40 | return outputs
41 |
42 | def _build_dis_mlp_graph(self):
43 | return []
44 |
45 | def _build_dis_cnn_graph(self):
46 | return []
47 |
48 | def _build_dis_cond_cnn_graph(self):
49 | return []
50 |
51 | def _build_dis_time_cond_cnn_graph(self):
52 | return []
53 |
54 | def _build_dis_tgan_cond_cnn_graph(self):
55 | return []
56 |
57 | def _build_dis_time_tgan_cond_cnn_graph(self):
58 | return []
59 |
60 | def _build_dis_sig_cnn_graph(self):
61 | return []
62 |
63 |
64 | class DisFrameGraph(DisGraph):
65 | def __init__(self, inputs, cond_inputs, config, name='cnn', is_reuse=False):
66 | super(DisFrameGraph, self).__init__(inputs, cond_inputs, config, name, is_reuse)
67 |
68 | def _build_dis_mlp_graph(self):
69 | fc_list_d = [[100, self.act_type], [256, self.act_type], [500, self.act_type], [1, '']]
70 |         # [batch_size*num_steps, mot_dim]
71 | mot_input = tf.reshape(self.inputs, [-1, 60])
72 | outputs = md.mlp(mot_input, fc_list_d, 'discriminator', reuse=self.is_reuse)
73 | return outputs
74 |
75 | def _build_dis_cnn_graph(self):
76 | print('frame_cnn_graph')
77 | mot_input = tf.reshape(self.inputs, [-1, 20, 1, 3])
78 | conv_list_d = [[64, self.kernel_size, self.stride, 'SAME', self.act_type],
79 | [128, self.kernel_size, self.stride, 'SAME', self.act_type]]
80 | fc_list_d = [[1, '']]
81 | outputs = md.cnn(mot_input, conv_list_d, fc_list_d, name='discriminator', reuse=self.is_reuse)
82 | return outputs
83 |
84 | def _build_dis_cond_cnn_graph(self):
85 | print('frame_cond_cnn_graph')
86 | cond_input = tf.reshape(self.cond_inputs, [self.batch_size, self.num_steps, self.mus_ebd_dim, 1])
87 | mot_input = tf.reshape(self.inputs, [self.batch_size, self.num_steps, self.mus_ebd_dim, 1])
88 | # bs * mus_ebd_dim * num_steps * 1
89 | mot_input = tf.transpose(mot_input, [0, 2, 1, 3])
90 | cond_input = tf.transpose(cond_input, [0, 2, 1, 3])
91 | all_input = tf.concat([mot_input, cond_input], axis=self.cond_axis, name='concat_cond')
92 |
93 | [batch_size, m_dim, num_steps, chl] = all_input.get_shape()
94 | all_input = tf.transpose(all_input, [0, 2, 1, 3])
95 | all_input = tf.reshape(all_input, [int(batch_size)*int(num_steps), int(m_dim), 1, int(chl)])
96 | print('all_input: ', all_input)
97 |
98 | conv_list_d = [[64, self.kernel_size, self.stride, 'SAME', self.act_type],
99 | [128, self.kernel_size, self.stride, 'SAME', self.act_type]]
100 | fc_list_d = [[1, '']]
101 |         outputs = md.cnn(all_input, conv_list_d, fc_list_d, name='discriminator', reuse=self.is_reuse)
102 | return outputs
103 |
104 | def _build_dis_sig_cnn_graph(self):
105 | print('frame_sig_cnn_graph')
106 | inputs = tf.reshape(self.inputs, [-1, 20, 1, 3])
107 | idx_lists = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
108 | 18, 19, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 2, 4, 6, 8, 10,
109 | 12, 14, 16, 18, 1, 4, 7, 10, 13, 16, 19, 3, 6, 9, 12, 15, 18,
110 | 2, 5, 8, 11, 14, 17, 1, 5, 9, 13, 17, 2, 6, 10, 14, 18, 3,
111 | 7, 11, 15, 19, 4, 8, 12, 16, 1, 6, 11, 16, 2, 7, 12, 17, 3,
112 | 8, 13, 18, 4, 9, 14, 19, 5, 10, 15, 1, 7, 13, 19, 6, 12, 18,
113 | 5, 11, 17, 4, 10, 16, 3, 9, 15, 2, 8, 14, 1, 8, 15, 3, 10,
114 | 17, 5, 12, 19, 7, 14, 2, 9, 16, 4, 11, 18, 6, 13, 1, 9, 17,
115 | 6, 14, 3, 11, 19, 8, 16, 5, 13, 2, 10, 18, 7, 15, 4, 12, 1,
116 | 10, 19, 9, 18, 8, 17, 7, 16, 6, 15, 5, 14, 4, 13, 3, 12, 2,
117 | 11, 1]
118 |
119 | # TODO: need to check
120 | mot_input = []
121 | for i, idx in enumerate(idx_lists):
122 | mot_input.append(inputs[:, idx, :, :])
123 | mot_input = tf.reshape(tf.concat(mot_input, axis=1), [-1, 173, 1, 3])
124 | # [3, 1], [2, 1]
125 | conv_list_d = [[64, self.kernel_size, self.stride, 'SAME', self.act_type],
126 | [128, self.kernel_size, self.stride, 'SAME', self.act_type],
127 | [256, self.kernel_size, self.stride, 'SAME', self.act_type],
128 | [512, self.kernel_size, self.stride, 'SAME', self.act_type]]
129 |
130 | fc_list_d = [[1, '']]
131 | outputs = md.cnn(mot_input, conv_list_d, fc_list_d, name='discriminator', reuse=self.is_reuse)
132 | return outputs
133 |
134 |
135 | class DisSegGraph(DisGraph):
136 | def __init__(self, inputs, cond_inputs, config, name='mlp', is_reuse=False):
137 | super(DisSegGraph, self).__init__(inputs, cond_inputs, config, name, is_reuse)
138 |
139 | def _build_dis_mlp_graph(self):
140 | outputs = []
141 | return outputs
142 |
143 | def _build_dis_cnn_graph(self):
144 | print('seg_cnn_graph')
145 | mot_input = tf.reshape(self.inputs, [self.batch_size, self.num_steps, 20, 3])
146 | # bs * 20 * num_steps * 3
147 |         mot_input = tf.transpose(mot_input, [0, 2, 1, 3])
148 | # [3, 3] [2, 2]
149 | conv_list_d = [[64, self.kernel_size, self.stride, 'SAME', self.act_type],
150 | [128, self.kernel_size, self.stride, 'SAME', self.act_type],
151 | [256, self.kernel_size, self.stride, 'SAME', self.act_type],
152 | [512, self.kernel_size, self.stride, 'SAME', self.act_type]]
153 | fc_list_d = [[1, '']]
154 | outputs = md.cnn(mot_input, conv_list_d, fc_list_d, name='discriminator', reuse=self.is_reuse)
155 | return outputs
156 |
157 | def _build_dis_cond_cnn_graph(self):
158 | print('seg_cond_cnn_graph')
159 | cond_input = tf.reshape(self.cond_inputs, [self.batch_size, self.num_steps, self.mus_ebd_dim, 1])
160 | mot_input = tf.reshape(self.inputs, [self.batch_size, self.num_steps, self.mus_ebd_dim, 1])
161 | # bs * mus_ebd_dim * num_steps * 1
162 | # cond_input = tf.transpose(cond_input, [0, 2, 1, 3])
163 | # mot_input = tf.transpose(mot_input, [0, 2, 1, 3])
164 | all_input = tf.concat([mot_input, cond_input], axis=self.cond_axis, name='concat_cond')
165 | if self.is_shuffle:
166 | original_shape = all_input.get_shape().as_list()
167 | np.random.seed(1234567890)
168 | shuffle_list = list(np.random.permutation(original_shape[0]))
169 | all_inputs = []
170 | for i, idx in enumerate(shuffle_list):
171 | all_inputs.append(all_input[idx:idx+1, :, :, :])
172 | all_input = tf.concat(all_inputs, axis=0)
173 | print('all_input: ', all_input)
174 | # [3, 3] [2, 2]
175 | conv_list_d = [[32, self.kernel_size, self.stride, 'SAME', self.act_type],
176 | [64, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'],
177 | [128, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'],
178 | [256, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn']]
179 | fc_list_d = [[1, '']]
180 | outputs = md.cnn(all_input, conv_list_d, fc_list_d, name='discriminator', reuse=self.is_reuse)
181 | return outputs
182 |
183 | def _build_dis_time_cond_cnn_graph(self):
184 | print('seg_time_cond_cnn_graph')
185 | # bs * 1 * num_steps * 72
186 | cond_input = tf.reshape(self.cond_inputs, [self.batch_size, 1, self.num_steps, self.mus_ebd_dim])
187 | mot_input = tf.reshape(self.inputs, [self.batch_size, 1, self.num_steps, self.mus_ebd_dim])
188 |
189 | all_input = tf.concat([mot_input, cond_input], axis=self.cond_axis, name='concat_cond')
190 | if self.is_shuffle:
191 | original_shape = all_input.get_shape().as_list()
192 | np.random.seed(1234567890)
193 | shuffle_list = list(np.random.permutation(original_shape[0]))
194 | all_inputs = []
195 | for i, idx in enumerate(shuffle_list):
196 | all_inputs.append(all_input[idx:idx+1, :, :, :])
197 | all_input = tf.concat(all_inputs, axis=0)
198 | print('all_input: ', all_input)
199 | # [1, 3] [1, 2]
200 | conv_list_d = [[32, self.kernel_size, self.stride, 'SAME', self.act_type],
201 | [64, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'],
202 | [128, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'],
203 | [256, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn']]
204 | fc_list_d = [[1, '']]
205 | outputs = md.cnn(all_input, conv_list_d, fc_list_d, name='discriminator', reuse=self.is_reuse)
206 | return outputs
207 |
208 | def _build_dis_tgan_cond_cnn_graph(self):
209 | print('tgan_cond_cnn_graph')
210 | cond_input = tf.reshape(self.cond_inputs, [self.batch_size, self.num_steps, self.mus_ebd_dim, 1])
211 | mot_input = tf.reshape(self.inputs, [self.batch_size, self.num_steps, self.mus_ebd_dim, 1])
212 | all_input = tf.concat([mot_input, cond_input], axis=self.cond_axis, name='concat_cond')
213 | if self.is_shuffle:
214 | print('shuffle')
215 | original_shape = all_input.get_shape().as_list()
216 | np.random.seed(1234567890)
217 | shuffle_list = list(np.random.permutation(original_shape[0]))
218 | all_inputs = []
219 | for i, idx in enumerate(shuffle_list):
220 | all_inputs.append(all_input[idx:idx+1, :, :, :])
221 | all_input = tf.concat(all_inputs, axis=0)
222 | print('all_input: ', all_input)
223 | # [3, 3] [2, 2]
224 | conv_list_d = [[32, self.kernel_size, self.stride, 'SAME', self.act_type],
225 | [64, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'],
226 | [128, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'],
227 | [256, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn']]
228 | outputs = md.cnn(all_input, conv_list_d, [], name='discriminator',
229 | is_training=self.is_training, reuse=self.is_reuse)
230 | return outputs
231 |
232 | def _build_dis_time_tgan_cond_cnn_graph(self):
233 | print('time_tgan_cond_cnn_graph')
234 | # bs * 1 * num_steps * 72
235 | cond_input = tf.reshape(self.cond_inputs, [self.batch_size, 1, self.num_steps, self.mus_ebd_dim])
236 | mot_input = tf.reshape(self.inputs, [self.batch_size, 1, self.num_steps, self.mus_ebd_dim])
237 |
238 | all_input = tf.concat([mot_input, cond_input], axis=self.cond_axis, name='concat_cond')
239 | if self.is_shuffle:
240 | original_shape = all_input.get_shape().as_list()
241 | np.random.seed(1234567890)
242 | shuffle_list = list(np.random.permutation(original_shape[0]))
243 | all_inputs = []
244 | for i, idx in enumerate(shuffle_list):
245 | all_inputs.append(all_input[idx:idx+1, :, :, :])
246 | all_input = tf.concat(all_inputs, axis=0)
247 | print('all_input: ', all_input)
248 | # [1, 3] [1, 2]
249 | conv_list_d = [[32, self.kernel_size, self.stride, 'SAME', self.act_type],
250 | [64, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'],
251 | [128, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn'],
252 | [256, self.kernel_size, self.stride, 'SAME', self.act_type, 'bn']]
253 | outputs = md.cnn(all_input, conv_list_d, [], name='discriminator', reuse=self.is_reuse)
254 | return outputs
255 |
256 |
--------------------------------------------------------------------------------
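As a shape sanity check for the conditional discriminators above, here is a small NumPy sketch (illustrative shapes only, not repo code) of how `cond_axis` selects where the music condition is concatenated onto the motion tensor in `DisSegGraph`:

```python
# Sketch: cond_axis=1 stacks the condition along the "height" axis,
# cond_axis=3 stacks it along the channel axis. Shapes follow the
# [batch_size, num_steps, mus_ebd_dim, 1] reshape used in the class.
import numpy as np

batch_size, num_steps, mus_ebd_dim = 2, 90, 72
mot = np.zeros((batch_size, num_steps, mus_ebd_dim, 1))
cond = np.zeros((batch_size, num_steps, mus_ebd_dim, 1))

stacked_h = np.concatenate([mot, cond], axis=1)  # cond_axis=1
stacked_c = np.concatenate([mot, cond], axis=3)  # cond_axis=3
print(stacked_h.shape)  # (2, 180, 72, 1)
print(stacked_c.shape)  # (2, 90, 72, 2)
```

The first form doubles the spatial extent the convolutions sweep over; the second keeps the geometry and lets the first conv layer mix motion and condition per position.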
/model/gan_model.py:
--------------------------------------------------------------------------------
1 | from model.base_model import *
2 | from model import discriminator_model as dm
3 | from utils import exp_loss as es, gan_loss as gls
4 |
5 |
6 | class GanModel(BaseModel):
7 | """The Generative adversarial model"""
8 | def __init__(self, train_type, config):
9 | super(GanModel, self).__init__(train_type, config)
10 |
11 | mot_predictions = self.mot_predictions
12 | mot_truth = self.mot_truth
13 | mus_ebd_outputs = self.mus_ebd_outputs
14 |
15 | tru_pos, pre_pos = es.get_pos_chls(mot_predictions, mot_truth, config)
16 |
17 | dis_name = config.dis_name
18 | dis_graph = getattr(dm, config.dis_type)
19 |
20 | if self.mus_ebd_dim == 60:
21 | real_data = mot_truth
22 | fake_data = mot_predictions
23 | elif self.mus_ebd_dim == 72:
24 | real_data = tru_pos
25 | fake_data = pre_pos
26 | else:
27 | real_data = tf.concat([mot_truth, tru_pos], axis=-1)
28 | fake_data = tf.concat([mot_predictions, pre_pos], axis=-1)
29 |
30 | print('real_data:', real_data)
31 | print('fake_data:', fake_data)
32 |
33 | g_sig_loss, d_loss, clip_d_weights = \
34 | gls.gan_loss(dis_graph, dis_name, real_data=real_data, fake_data=fake_data,
35 | cond_inputs=mus_ebd_outputs, config=config)
36 |
37 | # generator loss
38 | g_loss, loss_list = es.loss_impl(mot_predictions, mot_truth, pre_pos, tru_pos, config)
39 | # g_mse_loss = tf.reduce_mean(tf.squared_difference(mot_predictions, mot_truth),
40 | # name='mean_square_loss')
41 | g_loss = config.mse_rate * g_loss + config.dis_rate * g_sig_loss
42 | self.g_loss = [loss_list, g_sig_loss]
43 | self.d_loss = d_loss
44 |
45 | # if test, return
46 | if not self.is_training:
47 | return
48 |
49 | tvars = tf.trainable_variables()
50 | d_vars = [v for v in tvars if 'discriminator' in v.name]
51 | g_vars = [v for v in tvars if 'generator' in v.name]
52 |
53 | # add reg
54 | if config.is_reg:
55 | reg_cost = tf.reduce_sum([tf.nn.l2_loss(v) for v in g_vars
56 | if 'bias' not in v.name]) * config.reg_scale
57 | g_loss = g_loss + reg_cost
58 |
59 | gen_learning_rate = config.learning_rate
60 | dis_learning_rate = config.dis_learning_rate
61 |
62 | if config.optimizer.lower() == 'adam':
63 | print('Adam optimizer')
64 | g_optimizer = tf.train.AdamOptimizer(learning_rate=gen_learning_rate)
65 | d_optimizer = tf.train.AdamOptimizer(learning_rate=dis_learning_rate)
66 | else:
67 | print('Rmsprop optimizer')
68 | g_optimizer = tf.train.RMSPropOptimizer(learning_rate=gen_learning_rate)
69 | d_optimizer = tf.train.RMSPropOptimizer(learning_rate=dis_learning_rate)
70 |
71 | # for batch_norm op
72 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
73 | g_grads = tf.gradients(g_loss, g_vars, aggregation_method=2)
74 | d_grads = tf.gradients(d_loss, d_vars, aggregation_method=2)
75 | with tf.control_dependencies(update_ops):
76 | self.train_g_op = g_optimizer.apply_gradients(zip(g_grads, g_vars))
77 | print('train_g_op')
78 |
79 | if clip_d_weights:
80 |             with tf.control_dependencies([clip_d_weights] + update_ops):
81 | self.train_d_op = d_optimizer.apply_gradients(zip(d_grads, d_vars))
82 | # self._train_d_op = optimizer.minimize(d_loss, var_list=d_vars)
83 | else:
84 | with tf.control_dependencies(update_ops):
85 | self.train_d_op = d_optimizer.apply_gradients(zip(d_grads, d_vars))
86 | print('train_d_op')
--------------------------------------------------------------------------------
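The generator objective in `GanModel` is a weighted mix of the reconstruction loss and the adversarial loss. A tiny sketch of that arithmetic (plain Python, not repo code), using the argparse defaults from train_gan.py (`mse_rate=0.99`, `dis_rate=0.01`):

```python
# Sketch: g_loss = mse_rate * reconstruction_loss + dis_rate * adversarial_loss,
# mirroring the line `g_loss = config.mse_rate * g_loss + config.dis_rate * g_sig_loss`.
def combined_g_loss(recon_loss, adv_loss, mse_rate=0.99, dis_rate=0.01):
    return mse_rate * recon_loss + dis_rate * adv_loss

print(combined_g_loss(2.0, 5.0))  # ~2.03: reconstruction dominates by design
```

With these weights the adversarial term acts as a small regularizer on top of the reconstruction objective rather than driving training on its own.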
/pretrain.sh:
--------------------------------------------------------------------------------
1 | gpu=0
2 | dis_type='DisSegGraph'
3 | loss_mode='gan'
4 | seg_len=90
5 | loss_type=2
6 | if [ "$loss_type" -eq 1 ]; then
7 | loss_arr=(1.0 0.1 0.0)
8 | elif [ "$loss_type" -eq 2 ]; then
9 | loss_arr=(1.0 0.1 0.1)
10 | else
11 | loss_arr=(1.0 0.0 0.0)
12 | fi
13 | mus_ebd_dim=72
14 | dis_name='time_cond_cnn'
15 | kernel_size=(1 3)
16 | stride=(1 2)
17 | cond_axis=1
18 | CUDA_VISIBLE_DEVICES=$gpu \
19 | python3 train_gan.py --learning_rate 1e-4 \
20 | --dis_learning_rate 2e-5 \
21 | --mse_rate 1 \
22 | --dis_rate 0.01 \
23 | --loss_mode $loss_mode \
24 | --is_load_model False \
25 | --is_reg False \
26 | --reg_scale 5e-5 \
27 |                       --rnn_keep_list 1.0 1.0 1.0 \
28 | --dis_type $dis_type \
29 | --dis_name $dis_name \
30 |                       --loss_rate_list ${loss_arr[0]} ${loss_arr[1]} ${loss_arr[2]} \
31 | --kernel_size ${kernel_size[0]} ${kernel_size[1]} \
32 |                       --stride ${stride[0]} ${stride[1]} \
33 | --act_type lrelu \
34 | --optimizer Adam \
35 | --cond_axis $cond_axis \
36 | --seg_list $seg_len \
37 | --seq_shift 1 \
38 | --gen_hop $seg_len \
39 | --fold_list 0 \
40 | --type_list all-f4 \
41 | --model_path '' \
42 | --max_max_epoch 20 \
43 | --save_data_epoch 5 \
44 | --save_model_epoch 5 \
45 | --is_save_train False \
46 | --mot_scale 100. \
47 | --norm_way zscore \
48 | --teacher_forcing_ratio 0. \
49 | --tf_decay 1. \
50 | --batch_size 128 \
51 | --mus_ebd_dim $mus_ebd_dim \
52 | --has_random_seed False \
53 | --is_all_norm False \
54 | --add_info ./output/pretrain
--------------------------------------------------------------------------------
/train_gan.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import json
4 | import os
5 | import time
6 | from collections import OrderedDict
7 | from datetime import datetime
8 |
9 | import numpy as np
10 | from model.gan_model import *
11 |
12 | import m2m_config as cfg
13 | from utils import reader as rd
14 |
15 |
16 | def run_epoch(session, model, data_info, config, teacher_forcing_ratio,
17 | path=None, train_type=0, verbose=False, epoch=0):
18 | """Runs the model on the given data"""
19 | d_losses = 0.0
20 | g_exp_losses = 0.0
21 | g_losses = 0.0
22 |
23 | mot_state = session.run(model.mot_initial_state)
24 | mus_state = session.run(model.mus_initial_state)
25 |
26 | step = 0
27 | np.random.seed(123456789)
28 |
29 | for batch_x, batch_y, batch_f in rd.capg_seq_generator(epoch, train_type, data_info, config):
30 | feed_dict = dict()
31 | if config.is_use_pre_mot:
32 | for i, (c, h) in enumerate(model.mot_initial_state):
33 | feed_dict[c] = mot_state[i].c
34 | feed_dict[h] = mot_state[i].h
35 |
36 | for i, (c, h) in enumerate(model.mus_initial_state):
37 | feed_dict[c] = mus_state[i].c
38 | feed_dict[h] = mus_state[i].h
39 |
40 | tf_mask = np.random.uniform(size=config.num_steps) < teacher_forcing_ratio
41 | # print(tf_mask)
42 | feed_dict[model.tf_mask] = tf_mask
43 |
44 | last_step_mot = copy.deepcopy(batch_f)
45 | last_step_mot[:, :6] = 0
46 |
47 | feed_dict[model.init_step_mot] = last_step_mot
48 | feed_dict[model.input_x] = batch_x
49 | feed_dict[model.input_y] = batch_y
50 |
51 | g_fetches = {
52 | "last_step_mot": model.last_step_mot,
53 | "g_loss": model.g_loss,
54 | "eval_op": model.train_g_op
55 | }
56 |
57 | d_fetches = {
58 | "d_loss": model.d_loss,
59 | "eval_op": model.train_d_op
60 | }
61 |
62 | d_vals = session.run(d_fetches, feed_dict)
63 | g_vals = session.run(g_fetches, feed_dict)
64 |
65 | d_loss = d_vals["d_loss"]
66 | g_loss = g_vals["g_loss"]
67 |
68 | d_losses += d_loss
69 | g_exp_losses += g_loss[0][-1]
70 | g_losses += g_loss[1]
71 | step += 1
72 |
73 | if verbose:
74 | info = "Epoch {0}: {1} d_loss: {2} g_loss: {3}, exp_loss: {4}\n".format(
75 | epoch, step, d_loss, g_loss[1], g_loss[0])
76 | print(info)
77 | with open(path, 'a') as fh:
78 | fh.write(info)
79 |
80 | return [d_losses/step, g_losses/step, g_exp_losses/step]
81 |
82 |
83 | def generate_motion(session, model, data_info, gen_str, test_config, hop, epoch=0,
84 | time_dir=None, use_pre_mot=True, prefix='test', is_save=True):
85 | """Runs the model on the given data"""
86 | g_exp_losses = 0.0
87 | g_losses = 0.0
88 | d_losses = 0.0
89 |
90 | fetches = {
91 | "prediction": model.mot_predictions,
92 | "last_step_mot": model.last_step_mot,
93 | "g_loss": model.g_loss,
94 | "d_loss": model.d_loss,
95 | "mot_final_state": model.mot_final_state,
96 | "mus_final_state": model.mus_final_state,
97 | }
98 |
99 | step = 0
100 | num_steps = test_config.num_steps
101 | pre_mot = []
102 | mus_data = data_info[gen_str][0]
103 | mot_data = copy.deepcopy(data_info[gen_str][1])
104 |
105 | seq_keys = list(mus_data.keys())
106 | seq_keys.sort()
107 | mus_delay = test_config.mus_delay
108 |
109 | for file_name in seq_keys:
110 | predictions = []
111 | mus_file_data = mus_data[file_name]
112 | mot_file_data = mot_data[file_name]
113 | test_len = min(mus_file_data.shape[1]+mus_delay, mot_file_data.shape[1])
114 | test_num = int((test_len - 1 - num_steps) / hop + 1)
115 |
116 | mot_state = session.run(model.mot_initial_state)
117 | mus_state = session.run(model.mus_initial_state)
118 |
119 | for t in range(test_num):
120 | batch_x = mus_file_data[:, t * hop + 1 - mus_delay: t * hop + num_steps + 1 - mus_delay, :]
121 | batch_y = mot_file_data[:, t * hop + 1: t * hop + num_steps + 1, :]
122 | batch_f = mot_file_data[:, t * hop, :] # first frame
123 |
124 | feed_dict = dict()
125 | if use_pre_mot:
126 | for i, (c, h) in enumerate(model.mot_initial_state):
127 | feed_dict[c] = mot_state[i].c
128 | feed_dict[h] = mot_state[i].h
129 |
130 | for i, (c, h) in enumerate(model.mus_initial_state):
131 | feed_dict[c] = mus_state[i].c
132 | feed_dict[h] = mus_state[i].h
133 |
134 | if t > 0 and use_pre_mot:
135 | last_step_mot = copy.deepcopy(pre_mot)
136 | else:
137 | last_step_mot = copy.deepcopy(batch_f)
138 | last_step_mot[:, :6] = 0
139 |
140 | feed_dict[model.init_step_mot] = last_step_mot
141 | feed_dict[model.input_x] = batch_x
142 | feed_dict[model.input_y] = batch_y
143 | feed_dict[model.tf_mask] = [False] * test_config.num_steps
144 |
145 | vals = session.run(fetches, feed_dict)
146 |
147 | prediction = vals["prediction"]
148 | g_loss = vals["g_loss"]
149 | d_loss = vals["d_loss"]
150 | mot_state = vals["mot_final_state"]
151 | mus_state = vals["mus_final_state"]
152 | pre_mot = vals["last_step_mot"]
153 |
154 | d_losses += d_loss
155 | g_exp_losses += g_loss[0][-1]
156 | g_losses += g_loss[1]
157 |
158 | step += 1
159 | prediction = np.reshape(prediction, [test_config.num_steps, test_config.mot_dim])
160 | predictions.append(prediction)
161 |
162 | if is_save and ((epoch+1) % test_config.save_data_epoch == 0 or epoch == 0):
163 | test_pred_path = os.path.join(time_dir, prefix, str(epoch+1), file_name + ".csv")
164 | if len(predictions):
165 | predictions = np.concatenate(predictions, 0)
166 | rd.save_predict_data(predictions, test_pred_path, data_info,
167 | test_config.norm_way, test_config.mot_ignore_dims,
168 | test_config.mot_scale)
169 |
170 | return [d_losses/step, g_losses/step, g_exp_losses/step]
171 |
172 |
173 | def save_arg(config, path):
174 | config_dict = dict()
175 | for name, value in vars(config).items():
176 | config_dict[name] = value
177 | json.dump(config_dict, open(path, 'w'), indent=4, sort_keys=True)
178 |
179 |
180 | def run_main(config, test_config, data_info):
181 |
182 | with tf.Graph().as_default():
183 | with tf.name_scope("Train"):
184 | with tf.variable_scope("Model", reuse=None):
185 | train_model = GanModel(config=config,
186 | train_type=0)
187 |
188 | with tf.name_scope("Test"):
189 | with tf.variable_scope("Model", reuse=True):
190 | test_model = GanModel(config=test_config,
191 | train_type=2)
192 |
193 | # allowing gpu memory growth
194 |     gpu_config = tf.ConfigProto()
195 |     gpu_config.gpu_options.allow_growth = True
196 |     saver = tf.train.Saver(max_to_keep=20)
197 |
198 | with tf.Session(config=gpu_config) as session:
199 |
200 | # initialize all variables
201 | if config.is_load_model:
202 | saver.restore(session, config.model_path)
203 | else:
204 | session.run(tf.global_variables_initializer())
205 |
206 | # start queue
207 | coord = tf.train.Coordinator()
208 | tf.train.start_queue_runners(sess=session, coord=coord)
209 |
210 | time_str = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
211 | save_dir = config.save_dir
212 | if not os.path.exists(save_dir):
213 | os.makedirs(save_dir)
214 | model_save_dir = os.path.join(save_dir, 'model')
215 |         if not os.path.exists(model_save_dir):
216 | os.makedirs(model_save_dir)
217 | train_loss_dict = OrderedDict()
218 | test_loss_dict = OrderedDict()
219 | start_time = time.time()
220 | train_loss_path = os.path.join(save_dir, "train_loss.txt")
221 | train_step_loss_path = os.path.join(save_dir, "train_step_loss.txt")
222 | config_path = os.path.join(save_dir, "config.txt")
223 | time_path = os.path.join(save_dir, "time.txt")
224 | config.save_config(config_path)
225 | arg_path = os.path.join(save_dir, "args.txt")
226 | save_arg(args, arg_path)
227 |
228 | teacher_forcing_ratio = config.teacher_forcing_ratio
229 |
230 | for i in range(config.max_max_epoch):
231 | train_loss = \
232 | run_epoch(session, train_model,
233 | data_info, config,
234 | teacher_forcing_ratio,
235 | path=train_step_loss_path,
236 | train_type=0,
237 | epoch=i,
238 | verbose=True)
239 |
240 | print("---Epoch {0} train_loss: {1}\n".format(i, train_loss))
241 | train_loss_dict[str(i+1)] = train_loss
242 | json.dump(train_loss_dict, open(train_loss_path, 'w'), indent=4)
243 |
244 | if (i + 1) % test_config.save_data_epoch == 0:
245 | _ = \
246 | generate_motion(session, test_model,
247 | data_info, 'test', test_config, hop=test_config.num_steps,
248 | epoch=i, time_dir=save_dir,
249 | use_pre_mot=True, prefix='seq')
250 |
251 | if test_config.is_save_train:
252 | _ = \
253 | generate_motion(session, test_model,
254 | data_info, 'train', test_config, hop=test_config.num_steps,
255 | epoch=i, time_dir=save_dir,
256 | use_pre_mot=True, prefix='seq_train')
257 |
258 | if (i == 0 or (i + 1) % config.save_model_epoch == 0) and config.is_save_model:
259 | model_save_path = os.path.join(model_save_dir, 'cnn-erd_'+str(i)+'_model.ckpt')
260 | saver.save(session, model_save_path)
261 |
262 | time_info = "Epoch: {0} Elapsed Time : {1}\n".format(i + 1, time.time()-start_time)
263 | print(time_info)
264 | with open(time_path, 'a') as fh:
265 | fh.write(time_info)
266 |
267 | teacher_forcing_ratio *= config.tf_decay
268 |
269 | coord.request_stop()
270 | coord.join()
271 |
272 |
273 | def main(_):
274 | type_list = args.type_list
275 | fold_list = args.fold_list
276 | seg_list = args.seg_list
277 |
278 | for seg_len in seg_list:
279 | for fold_idx in fold_list:
280 | for i, m_type in enumerate(type_list):
281 | seg_str = str(seg_len)
282 | fold_str = 'fold_' + str(fold_idx)
283 | print(m_type, seg_str, fold_str)
284 | if fold_idx != 0 and m_type in ['hiphop', 'salsa']:
285 | continue
286 | if fold_idx == 3 and m_type == 'groovenet':
287 | continue
288 | config = cfg.get_config(m_type, fold_str, seg_str)
289 | cfg_list = []
290 | care_list = ['add_info', 'mse_rate', 'dis_rate', 'dis_learning_rate',
291 | 'reg_scale', 'rnn_keep_list', 'is_reg', 'cond_axis']
292 | for k, v in sorted(vars(args).items()):
293 | print(k, v)
294 | setattr(config, k, v)
295 | if k in care_list:
296 | v_str = str(v)
297 | if isinstance(v, bool):
298 | v_str = v_str[0]
299 | cfg_list.append(v_str)
300 | config.save_dir = os.path.join(args.add_info, m_type)
301 |
302 | args.care_list = care_list
303 | test_config = copy.deepcopy(config)
304 | test_config.batch_size = config.test_batch_size
305 | test_config.num_steps = config.test_num_steps
306 |
307 | print(config.save_dir)
308 | data_info = rd.run_all(config)
309 | config.mot_data_info = data_info['mot']
310 | test_config.mot_data_info = data_info['mot']
311 | run_main(config, test_config, data_info)
312 |
313 |
314 | if __name__ == "__main__":
315 |
316 | parser = argparse.ArgumentParser()
317 | parser.add_argument('--add_info', type=str, default='')
318 | parser.add_argument('--learning_rate', type=float, default=1e-4)
319 | parser.add_argument('--is_load_model', type=lambda x: (str(x).lower() == 'true'))
320 | parser.add_argument('--optimizer', type=str, default='Adam')
321 | parser.add_argument('--fold_list', nargs='+', type=int, help='0, 1, 2, 3')
322 | parser.add_argument('--seg_list', nargs='+', type=int, help='150, 90')
323 | parser.add_argument('--type_list', nargs='+', type=str, help='music type')
324 | parser.add_argument('--model_path', type=str, help='model_path')
325 | parser.add_argument('--max_max_epoch', type=int, help='training epoch number')
326 | parser.add_argument('--save_model_epoch', type=int, help='save_model_epoch_number')
327 | parser.add_argument('--save_data_epoch', type=int, help='save_data_epoch_number')
328 | parser.add_argument('--is_reg', type=lambda x: (str(x).lower() == 'true'), help='if add regularization')
329 | parser.add_argument('--reg_scale', type=float, help='5e-4')
330 | parser.add_argument('--rnn_keep_list', nargs='+', type=float, help='rnn_keep_probability list, [1.0, 1.0, 1.0]')
331 | parser.add_argument('--batch_size', type=int, help='32 or 64')
332 | parser.add_argument('--has_random_seed', type=lambda x: (str(x).lower() == 'true'), help='')
333 | parser.add_argument('--teacher_forcing_ratio', type=float, help='')
334 | parser.add_argument('--tf_decay', type=float, help='')
335 | parser.add_argument('--norm_way', type=str, help='zscore, maxmin, no')
336 | parser.add_argument('--seq_shift', type=int, help='seq_shift')
337 | parser.add_argument('--gen_hop', type=int, help='gen_hop')
338 | parser.add_argument('--mot_scale', type=float, help='motion scale')
339 | parser.add_argument('--is_save_train', type=lambda x: (str(x).lower() == 'true'))
340 | parser.add_argument('--cond_axis', type=int, help='1: height, 3: channel', default=3)
341 | parser.add_argument('--act_type', type=str, default='lrelu')
342 | parser.add_argument('--kernel_size', nargs='+', type=int)
343 | parser.add_argument('--stride', nargs='+', type=int)
344 | parser.add_argument('--dis_learning_rate', type=float, default=1e-4)
345 | parser.add_argument('--dis_type', type=str, help='DisFrameGraph or DisSegGraph')
346 | parser.add_argument('--dis_name', type=str, default='cond_cnn')
347 | parser.add_argument('--mse_rate', type=float, default=0.99)
348 | parser.add_argument('--dis_rate', type=float, default=0.01)
349 | parser.add_argument('--loss_mode', type=str, default='gan')
350 | parser.add_argument('--clip_value', type=float, default=0.01)
351 | parser.add_argument('--pen_lambda', type=float, default=10)
352 | parser.add_argument('--mus_ebd_dim', type=int)
353 | parser.add_argument('--is_all_norm', type=lambda x: (str(x).lower() == 'true'), default=False)
354 | parser.add_argument('--loss_rate_list', nargs='+', type=float, default=[1., 0., 0.])
355 |
356 | args = parser.parse_args()
357 |
358 | tf.app.run()
359 |
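A note on the boolean flags above: the script uses `type=lambda x: (str(x).lower() == 'true')` rather than `type=bool`, because `argparse` with `type=bool` treats any non-empty string (including `"False"`) as truthy. A minimal standalone sketch of the same pattern, using a hypothetical `--is_reg` flag as in this script:

```python
import argparse

def str2bool(x):
    # Same idea as the inline lambdas in train_gan.py: only the
    # string "true" (case-insensitive) parses as True; everything
    # else, including "yes" or "1", parses as False.
    return str(x).lower() == 'true'

parser = argparse.ArgumentParser()
parser.add_argument('--is_reg', type=str2bool, default=False)

print(parser.parse_args(['--is_reg', 'True']).is_reg)   # True
print(parser.parse_args(['--is_reg', 'False']).is_reg)  # False
```

Note that this strict-match converter silently maps typos like `--is_reg Ture` to `False`; a stricter variant could raise `argparse.ArgumentTypeError` on unrecognized values.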
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computer-animation-perception-group/DeepDance_train/9f47de0b9f599b2590608465ffb96c9eb0344502/utils/__init__.py
--------------------------------------------------------------------------------
/utils/capg_exp_skel.pkl:
--------------------------------------------------------------------------------
1 | (lp0
2 | ccopy_reg
3 | _reconstructor
4 | p1
5 | (cbvh
6 | Node
7 | p2
8 | c__builtin__
9 | object
10 | p3
11 | Ntp4
12 | Rp5
13 | (dp6
14 | Vparent
15 | p7
16 | L-1L
17 | sV_is_root
18 | p8
19 | I01
20 | sVoffset
21 | p9
22 | (F38.2704
23 | F101.351
24 | F-83.7549
25 | tp10
26 | sVchannels
27 | p11
28 | (lp12
29 | VXposition
30 | p13
31 | aVYposition
32 | p14
33 | aVZposition
34 | p15
35 | aVZrotation
36 | p16
37 | aVXrotation
38 | p17
39 | aVYrotation
40 | p18
41 | asVorder
42 | p19
43 | VZXY
44 | p20
45 | sVquat_idx
46 | p21
47 | L0L
48 | sVname
49 | p22
50 | VHips
51 | p23
52 | sVchildren
53 | p24
54 | (lp25
55 | g1
56 | (g2
57 | g3
58 | Ntp26
59 | Rp27
60 | (dp28
61 | g7
62 | L0L
63 | sg8
64 | I00
65 | sg9
66 | (F0.0
67 | F11.6235
68 | F0.0
69 | tp29
70 | sg11
71 | (lp30
72 | VZrotation
73 | p31
74 | aVXrotation
75 | p32
76 | aVYrotation
77 | p33
78 | asg19
79 | VZXY
80 | p34
81 | sg21
82 | L1L
83 | sg22
84 | VChest
85 | p35
86 | sg24
87 | (lp36
88 | g1
89 | (g2
90 | g3
91 | Ntp37
92 | Rp38
93 | (dp39
94 | g7
95 | L1L
96 | sg8
97 | I00
98 | sg9
99 | (F0.0
100 | F16.7104
101 | F0.0
102 | tp40
103 | sg11
104 | (lp41
105 | VZrotation
106 | p42
107 | aVXrotation
108 | p43
109 | aVYrotation
110 | p44
111 | asg19
112 | VZXY
113 | p45
114 | sg21
115 | L2L
116 | sg22
117 | VChest2
118 | p46
119 | sg24
120 | (lp47
121 | g1
122 | (g2
123 | g3
124 | Ntp48
125 | Rp49
126 | (dp50
127 | g7
128 | L2L
129 | sg8
130 | I00
131 | sg9
132 | (F0.0
133 | F10.3009
134 | F0.0
135 | tp51
136 | sg11
137 | (lp52
138 | VZrotation
139 | p53
140 | aVXrotation
141 | p54
142 | aVYrotation
143 | p55
144 | asg19
145 | VZXY
146 | p56
147 | sg21
148 | L3L
149 | sg22
150 | VNeck
151 | p57
152 | sg24
153 | (lp58
154 | g1
155 | (g2
156 | g3
157 | Ntp59
158 | Rp60
159 | (dp61
160 | g7
161 | L3L
162 | sg8
163 | I00
164 | sg9
165 | (F0.0
166 | F13.2005
167 | F0.0
168 | tp62
169 | sg11
170 | (lp63
171 | VZrotation
172 | p64
173 | aVXrotation
174 | p65
175 | aVYrotation
176 | p66
177 | asg19
178 | VZXY
179 | p67
180 | sg21
181 | L4L
182 | sg22
183 | VHead
184 | p68
185 | sg24
186 | (lp69
187 | g1
188 | (g2
189 | g3
190 | Ntp70
191 | Rp71
192 | (dp72
193 | g7
194 | L4L
195 | sg8
196 | I00
197 | sg9
198 | (F0.0
199 | F10.5299
200 | F0.0
201 | tp73
202 | sg11
203 | (lp74
204 | sg19
205 | V
206 | p75
207 | sg21
208 | Nsg22
209 | VEnd Site
210 | p76
211 | sg24
212 | (lp77
213 | sVrot_idx
214 | p78
215 | (lp79
216 | sVpos_idx
217 | p80
218 | (lp81
219 | sVexp_idx
220 | p82
221 | (lp83
222 | L18L
223 | aL19L
224 | aL20L
225 | asbasg78
226 | cnumpy.core.multiarray
227 | _reconstruct
228 | p84
229 | (cnumpy
230 | ndarray
231 | p85
232 | (L0L
233 | tp86
234 | c_codecs
235 | encode
236 | p87
237 | (Vb
238 | p88
239 | Vlatin1
240 | p89
241 | tp90
242 | Rp91
243 | tp92
244 | Rp93
245 | (L1L
246 | (L3L
247 | tp94
248 | cnumpy
249 | dtype
250 | p95
251 | (Vi4
252 | p96
253 | L0L
254 | L1L
255 | tp97
256 | Rp98
257 | (L3L
258 | V<
259 | p99
260 | NNNL-1L
261 | L-1L
262 | L0L
263 | tp100
264 | bI00
265 | g87
266 | (V
267 | p101
268 | g89
269 | tp102
270 | Rp103
271 | tp104
272 | bsg80
273 | (lp105
274 | sg82
275 | (lp106
276 | L15L
277 | aL16L
278 | aL17L
279 | asbasg78
280 | g84
281 | (g85
282 | (L0L
283 | tp107
284 | g91
285 | tp108
286 | Rp109
287 | (L1L
288 | (L3L
289 | tp110
290 | g98
291 | I00
292 | g87
293 | (V
294 | p111
295 | g89
296 | tp112
297 | Rp113
298 | tp114
299 | bsg80
300 | (lp115
301 | sg82
302 | (lp116
303 | L12L
304 | aL13L
305 | aL14L
306 | asbag1
307 | (g2
308 | g3
309 | Ntp117
310 | Rp118
311 | (dp119
312 | g7
313 | L2L
314 | sg8
315 | I00
316 | sg9
317 | (F2.69605
318 | F10.657
319 | F4.47645
320 | tp120
321 | sg11
322 | (lp121
323 | VZrotation
324 | p122
325 | aVXrotation
326 | p123
327 | aVYrotation
328 | p124
329 | asg19
330 | VZXY
331 | p125
332 | sg21
333 | L5L
334 | sg22
335 | VLeftCollar
336 | p126
337 | sg24
338 | (lp127
339 | g1
340 | (g2
341 | g3
342 | Ntp128
343 | Rp129
344 | (dp130
345 | g7
346 | L6L
347 | sg8
348 | I00
349 | sg9
350 | (F14.7774
351 | F0.0
352 | F0.0
353 | tp131
354 | sg11
355 | (lp132
356 | VZrotation
357 | p133
358 | aVXrotation
359 | p134
360 | aVYrotation
361 | p135
362 | asg19
363 | VZXY
364 | p136
365 | sg21
366 | L6L
367 | sg22
368 | VLeftShoulder
369 | p137
370 | sg24
371 | (lp138
372 | g1
373 | (g2
374 | g3
375 | Ntp139
376 | Rp140
377 | (dp141
378 | g7
379 | L7L
380 | sg8
381 | I00
382 | sg9
383 | (F0.0
384 | F-30.7247
385 | F0.0
386 | tp142
387 | sg11
388 | (lp143
389 | VZrotation
390 | p144
391 | aVXrotation
392 | p145
393 | aVYrotation
394 | p146
395 | asg19
396 | VZXY
397 | p147
398 | sg21
399 | L7L
400 | sg22
401 | VLeftElbow
402 | p148
403 | sg24
404 | (lp149
405 | g1
406 | (g2
407 | g3
408 | Ntp150
409 | Rp151
410 | (dp152
411 | g7
412 | L8L
413 | sg8
414 | I00
415 | sg9
416 | (F0.0
417 | F-24.9766
418 | F0.0
419 | tp153
420 | sg11
421 | (lp154
422 | VZrotation
423 | p155
424 | aVXrotation
425 | p156
426 | aVYrotation
427 | p157
428 | asg19
429 | VZXY
430 | p158
431 | sg21
432 | L8L
433 | sg22
434 | VLeftWrist
435 | p159
436 | sg24
437 | (lp160
438 | g1
439 | (g2
440 | g3
441 | Ntp161
442 | Rp162
443 | (dp163
444 | g7
445 | L9L
446 | sg8
447 | I00
448 | sg9
449 | (F0.0
450 | F-18.7451
451 | F0.0
452 | tp164
453 | sg11
454 | (lp165
455 | sg19
456 | g75
457 | sg21
458 | Nsg22
459 | g76
460 | sg24
461 | (lp166
462 | sg78
463 | (lp167
464 | sg80
465 | (lp168
466 | sg82
467 | (lp169
468 | L33L
469 | aL34L
470 | aL35L
471 | asbasg78
472 | g84
473 | (g85
474 | (L0L
475 | tp170
476 | g91
477 | tp171
478 | Rp172
479 | (L1L
480 | (L3L
481 | tp173
482 | g98
483 | I00
484 | g87
485 | (V
486 | p174
487 | g89
488 | tp175
489 | Rp176
490 | tp177
491 | bsg80
492 | (lp178
493 | sg82
494 | (lp179
495 | L30L
496 | aL31L
497 | aL32L
498 | asbasg78
499 | g84
500 | (g85
501 | (L0L
502 | tp180
503 | g91
504 | tp181
505 | Rp182
506 | (L1L
507 | (L3L
508 | tp183
509 | g98
510 | I00
511 | g87
512 | (V
513 | p184
514 | g89
515 | tp185
516 | Rp186
517 | tp187
518 | bsg80
519 | (lp188
520 | sg82
521 | (lp189
522 | L27L
523 | aL28L
524 | aL29L
525 | asbasg78
526 | g84
527 | (g85
528 | (L0L
529 | tp190
530 | g91
531 | tp191
532 | Rp192
533 | (L1L
534 | (L3L
535 | tp193
536 | g98
537 | I00
538 | g87
539 | (V
540 | p194
541 | g89
542 | tp195
543 | Rp196
544 | tp197
545 | bsg80
546 | (lp198
547 | sg82
548 | (lp199
549 | L24L
550 | aL25L
551 | aL26L
552 | asbasg78
553 | g84
554 | (g85
555 | (L0L
556 | tp200
557 | g91
558 | tp201
559 | Rp202
560 | (L1L
561 | (L3L
562 | tp203
563 | g98
564 | I00
565 | g87
566 | (V
567 | p204
568 | g89
569 | tp205
570 | Rp206
571 | tp207
572 | bsg80
573 | (lp208
574 | sg82
575 | (lp209
576 | L21L
577 | aL22L
578 | aL23L
579 | asbag1
580 | (g2
581 | g3
582 | Ntp210
583 | Rp211
584 | (dp212
585 | g7
586 | L2L
587 | sg8
588 | I00
589 | sg9
590 | (F-2.69605
591 | F10.657
592 | F4.47645
593 | tp213
594 | sg11
595 | (lp214
596 | VZrotation
597 | p215
598 | aVXrotation
599 | p216
600 | aVYrotation
601 | p217
602 | asg19
603 | VZXY
604 | p218
605 | sg21
606 | L9L
607 | sg22
608 | VRightCollar
609 | p219
610 | sg24
611 | (lp220
612 | g1
613 | (g2
614 | g3
615 | Ntp221
616 | Rp222
617 | (dp223
618 | g7
619 | L11L
620 | sg8
621 | I00
622 | sg9
623 | (F-15.4132
624 | F0.0
625 | F0.0
626 | tp224
627 | sg11
628 | (lp225
629 | VZrotation
630 | p226
631 | aVXrotation
632 | p227
633 | aVYrotation
634 | p228
635 | asg19
636 | VZXY
637 | p229
638 | sg21
639 | L10L
640 | sg22
641 | VRightShoulder
642 | p230
643 | sg24
644 | (lp231
645 | g1
646 | (g2
647 | g3
648 | Ntp232
649 | Rp233
650 | (dp234
651 | g7
652 | L12L
653 | sg8
654 | I00
655 | sg9
656 | (F0.0
657 | F-28.1813
658 | F0.0
659 | tp235
660 | sg11
661 | (lp236
662 | VZrotation
663 | p237
664 | aVXrotation
665 | p238
666 | aVYrotation
667 | p239
668 | asg19
669 | VZXY
670 | p240
671 | sg21
672 | L11L
673 | sg22
674 | VRightElbow
675 | p241
676 | sg24
677 | (lp242
678 | g1
679 | (g2
680 | g3
681 | Ntp243
682 | Rp244
683 | (dp245
684 | g7
685 | L13L
686 | sg8
687 | I00
688 | sg9
689 | (F0.0
690 | F-24.9766
691 | F0.0
692 | tp246
693 | sg11
694 | (lp247
695 | VZrotation
696 | p248
697 | aVXrotation
698 | p249
699 | aVYrotation
700 | p250
701 | asg19
702 | VZXY
703 | p251
704 | sg21
705 | L12L
706 | sg22
707 | VRightWrist
708 | p252
709 | sg24
710 | (lp253
711 | g1
712 | (g2
713 | g3
714 | Ntp254
715 | Rp255
716 | (dp256
717 | g7
718 | L14L
719 | sg8
720 | I00
721 | sg9
722 | (F0.0
723 | F-18.1602
724 | F0.0
725 | tp257
726 | sg11
727 | (lp258
728 | sg19
729 | g75
730 | sg21
731 | Nsg22
732 | g76
733 | sg24
734 | (lp259
735 | sg78
736 | (lp260
737 | sg80
738 | (lp261
739 | sg82
740 | (lp262
741 | L48L
742 | aL49L
743 | aL50L
744 | asbasg78
745 | g84
746 | (g85
747 | (L0L
748 | tp263
749 | g91
750 | tp264
751 | Rp265
752 | (L1L
753 | (L3L
754 | tp266
755 | g98
756 | I00
757 | g87
758 | (V( ) '
759 | p267
760 | g89
761 | tp268
762 | Rp269
763 | tp270
764 | bsg80
765 | (lp271
766 | sg82
767 | (lp272
768 | L45L
769 | aL46L
770 | aL47L
771 | asbasg78
772 | g84
773 | (g85
774 | (L0L
775 | tp273
776 | g91
777 | tp274
778 | Rp275
779 | (L1L
780 | (L3L
781 | tp276
782 | g98
783 | I00
784 | g87
785 | (V% &