├── License.txt
├── README.md
├── data.py
├── demo.py
├── imgs
│   ├── .DS_Store
│   ├── example.gif
│   ├── long.gif
│   ├── multimodal.gif
│   └── v2v.gif
├── model_comp.py
├── model_decomp.py
├── modulate.py
├── networks.py
├── options.py
├── test.py
├── train_comp.py
├── train_decomp.py
└── utils.py
/License.txt:
--------------------------------------------------------------------------------
1 | Nvidia Source Code License-NC
2 |
3 | 1. Definitions
4 |
5 | “Licensor” means any person or entity that distributes its Work.
6 |
7 | “Software” means the original work of authorship made available under this License.
8 | “Work” means the Software and any additions to or derivative works of the Software that are made available under this License.
9 |
10 | “Nvidia Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or provided by Nvidia or its affiliates.
11 |
12 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
13 |
14 | Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License.
15 |
16 | 2. License Grants
17 |
18 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
19 |
20 | 3. Limitations
21 |
22 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
23 |
24 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
25 |
26 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. The Work or derivative works thereof may be used or intended for use by Nvidia or its affiliates commercially or non-commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
27 |
28 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grants in Sections 2.1 and 2.2) will terminate immediately.
29 |
30 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License.
31 |
32 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grants in Sections 2.1 and 2.2) will terminate immediately.
33 |
34 | 4. Disclaimer of Warranty.
35 |
36 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
37 |
38 | 5. Limitation of Liability.
39 |
40 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
41 |
42 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## Dancing to Music
4 | PyTorch implementation of the cross-modality generative model that synthesizes dance from music.
5 |
6 |
7 | ### Paper
8 | [Hsin-Ying Lee](http://vllab.ucmerced.edu/hylee/), [Xiaodong Yang](https://xiaodongyang.org/), [Ming-Yu Liu](http://mingyuliu.net/), [Ting-Chun Wang](https://tcwang0509.github.io/), [Yu-Ding Lu](https://jonlu0602.github.io/), [Ming-Hsuan Yang](https://faculty.ucmerced.edu/mhyang/), [Jan Kautz](http://jankautz.com/)
9 | Dancing to Music
10 | Neural Information Processing Systems (**NeurIPS**) 2019
11 | [[Paper]](https://arxiv.org/abs/1911.02001) [[YouTube]](https://youtu.be/-e9USqfwZ4A) [[Project]](http://vllab.ucmerced.edu/hylee/Dancing2Music/script.txt) [[Blog]](https://news.developer.nvidia.com/nvidia-dance-to-music-neurips/) [[Supp]](http://xiaodongyang.org/publications/papers/dance2music-supp-neurips19.pdf)
12 |
13 | ### Example Videos
14 | - Beat-Matching
15 | 1st row: generated dance sequences, 2nd row: music beats, 3rd row: kinematic beats
16 |
17 |
18 |
19 |
20 | - Multimodality
21 | Generate various dance sequences with the same music and the same initial pose.
22 |
23 |
24 |
25 |
26 | - Long-Term Generation
27 | Seamlessly generate a dance sequence of arbitrary length.
28 |
29 |
30 |
31 |
32 |
33 |
34 | - Photo-Realistic Videos
35 | Map generated dance sequences to photo-realistic videos.
36 |
37 |
38 |
39 |
40 |
41 | ### Train Decomposition
42 | ```
43 | python train_decomp.py --name Decomp
44 | ```
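Decomposition training writes checkpoints every `snapshot_ep` epochs, named like `0099.ckpt` (see `save` in `model_decomp.py`); pass one of these as `DECOMP_SNAPSHOT` in the composition stage below.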
45 |
46 | ### Train Composition
47 | ```
48 | python train_comp.py --name Comp --decomp_snapshot DECOMP_SNAPSHOT
49 | ```
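- Example (a sketch; the checkpoint path is illustrative, use any snapshot written by the decomposition stage)
```
python train_comp.py --name Comp --decomp_snapshot snapshot/Decomp/0099.ckpt
```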
50 |
51 | ### Demo
52 | ```
53 | python demo.py --decomp_snapshot DECOMP_SNAPSHOT --comp_snapshot COMP_SNAPSHOT --aud_path AUD_PATH --out_file OUT_FILE --out_dir OUT_DIR --thr THR
54 | ```
55 | - Flags
56 | - `aud_path`: input .wav file
57 | - `out_file`: location of output .mp4 file
58 | - `out_dir`: directory of output frames
59 | - `thr`: threshold based on motion magnitude
60 | - `modulate`: whether to do beat warping
61 |
62 | - Example
63 | ```
64 | python demo.py --decomp_snapshot snapshot/Stage1.ckpt --comp_snapshot snapshot/Stage2.ckpt --aud_path demo/demo.wav --out_file demo/out.mp4 --out_dir demo/out_frame
65 | ```
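- Beat warping (a sketch, assuming `--modulate` is a boolean switch as its use in `demo.py` suggests; `--thr 70` is an illustrative threshold, not a tuned value)
```
python demo.py --decomp_snapshot snapshot/Stage1.ckpt --comp_snapshot snapshot/Stage2.ckpt --aud_path demo/demo.wav --out_file demo/out.mp4 --out_dir demo/out_frame --modulate --thr 70
```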
66 |
67 |
68 | ### Citation
69 | If you find this code useful for your research, please cite our paper:
70 | ```bibtex
71 | @inproceedings{lee2019dancing2music,
72 | title={Dancing to Music},
73 | author={Lee, Hsin-Ying and Yang, Xiaodong and Liu, Ming-Yu and Wang, Ting-Chun and Lu, Yu-Ding and Yang, Ming-Hsuan and Kautz, Jan},
74 | booktitle={NeurIPS},
75 | year={2019}
76 | }
77 | ```
78 |
79 | ### License
80 | Copyright (C) 2020 NVIDIA Corporation. All rights reserved. This work is made available under NVIDIA Source Code License (1-Way Commercial). To view a copy of this license, visit https://nvlabs.github.io/Dancing2Music/LICENSE.txt.
81 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/Dancing2Music/License.txt
7 | import os
8 | import pickle
9 | import numpy as np
10 | import random
11 | import torch.utils.data
12 | from torchvision.datasets import ImageFolder
13 | import utils
14 |
15 |
16 | class PoseDataset(torch.utils.data.Dataset):
17 | def __init__(self, data_dir, tolerance=False):
18 | self.data_dir = data_dir
19 | z_fname = '{}/unitList/zumba_unit.txt'.format(data_dir)
20 | b_fname = '{}/unitList/ballet_unit.txt'.format(data_dir)
21 | h_fname = '{}/unitList/hiphop_unit.txt'.format(data_dir)
22 | self.z_data = []
23 | self.b_data = []
24 | self.h_data = []
25 | with open(z_fname, 'r') as f:
26 | for line in f:
27 | self.z_data.append([s for s in line.strip().split(' ')])
28 | with open(b_fname, 'r') as f:
29 | for line in f:
30 | self.b_data.append([s for s in line.strip().split(' ')])
31 | with open(h_fname, 'r') as f:
32 | for line in f:
33 | self.h_data.append([s for s in line.strip().split(' ')])
34 | self.data = [self.z_data, self.b_data, self.h_data]
35 |
36 | self.tolerance = tolerance
37 | if self.tolerance:
38 | z3_fname = '{}/unitList/zumba_unitseq3.txt'.format(data_dir)
39 | b3_fname = '{}/unitList/ballet_unitseq3.txt'.format(data_dir)
40 | h3_fname = '{}/unitList/hiphop_unitseq3.txt'.format(data_dir)
41 | z4_fname = '{}/unitList/zumba_unitseq4.txt'.format(data_dir)
42 | b4_fname = '{}/unitList/ballet_unitseq4.txt'.format(data_dir)
43 | h4_fname = '{}/unitList/hiphop_unitseq4.txt'.format(data_dir)
44 | z3_data = []; b3_data = []; h3_data = []; z4_data = []; b4_data = []; h4_data = []
45 | with open(z3_fname, 'r') as f:
46 | for line in f:
47 | z3_data.append([s for s in line.strip().split(' ')])
48 | with open(b3_fname, 'r') as f:
49 | for line in f:
50 | b3_data.append([s for s in line.strip().split(' ')])
51 | with open(h3_fname, 'r') as f:
52 | for line in f:
53 | h3_data.append([s for s in line.strip().split(' ')])
54 | with open(z4_fname, 'r') as f:
55 | for line in f:
56 | z4_data.append([s for s in line.strip().split(' ')])
57 | with open(b4_fname, 'r') as f:
58 | for line in f:
59 | b4_data.append([s for s in line.strip().split(' ')])
60 | with open(h4_fname, 'r') as f:
61 | for line in f:
62 | h4_data.append([s for s in line.strip().split(' ')])
63 | self.zt_data = z3_data + z4_data
64 | self.bt_data = b3_data + b4_data
65 | self.ht_data = h3_data + h4_data
66 | self.t_data = [self.zt_data, self.bt_data, self.ht_data]
67 |
68 | self.mean_pose=np.load(data_dir+'/stats/all_onbeat_mean.npy')
69 | self.std_pose=np.load(data_dir+'/stats/all_onbeat_std.npy')
70 |
71 | def __getitem__(self, index):
72 | # sample a dance style; hiphop (class 2) is disabled here, only zumba (0) / ballet (1) are drawn
73 | cls = random.randint(0,1)
74 | if self.tolerance and random.randint(0,9)==0:
75 | index = random.randint(0, len(self.t_data[cls])-1)
76 | path = self.t_data[cls][index][0]
77 | path = os.path.join(self.data_dir, path[5:])  # drop the listed path's leading prefix (presumably 'data/')
78 | orig_poses = np.load(path)
79 | sel = random.randint(0, orig_poses.shape[0]-1)
80 | orig_poses = orig_poses[sel]
81 | else:
82 | index = random.randint(0, len(self.data[cls])-1)
83 | path = self.data[cls][index][0]
84 | path = os.path.join(self.data_dir, path[5:])
85 | orig_poses = np.load(path)
86 |
87 | xjit = np.random.uniform(low=-50, high=50)
88 | yjit = np.random.uniform(low=-20, high=20)
89 | poses = orig_poses.copy()
90 | poses[:,:,0] += xjit
91 | poses[:,:,1] += yjit
92 | xjit = np.random.uniform(low=-50, high=50)
93 | yjit = np.random.uniform(low=-20, high=20)
94 | poses2 = orig_poses.copy()
95 | poses2[:,:,0] += xjit
96 | poses2[:,:,1] += yjit
97 |
98 | poses = poses.reshape(poses.shape[0], poses.shape[1]*poses.shape[2])
99 | poses2 = poses2.reshape(poses2.shape[0], poses2.shape[1]*poses2.shape[2])
100 | for i in range(poses.shape[0]):
101 | poses[i] = (poses[i]-self.mean_pose)/self.std_pose
102 | poses2[i] = (poses2[i]-self.mean_pose)/self.std_pose
103 |
104 | return torch.Tensor(poses), torch.Tensor(poses2)
105 |
106 | def __len__(self):
107 | return len(self.z_data)+len(self.b_data)
108 |
109 |
110 | class MovementAudDataset(torch.utils.data.Dataset):
111 | def __init__(self, data_dir):
112 | self.data_dir = data_dir
113 | z3_fname = '{}/unitList/zumba_unitseq3.txt'.format(data_dir)
114 | b3_fname = '{}/unitList/ballet_unitseq3.txt'.format(data_dir)
115 | h3_fname = '{}/unitList/hiphop_unitseq3.txt'.format(data_dir)
116 | z4_fname = '{}/unitList/zumba_unitseq4.txt'.format(data_dir)
117 | b4_fname = '{}/unitList/ballet_unitseq4.txt'.format(data_dir)
118 | h4_fname = '{}/unitList/hiphop_unitseq4.txt'.format(data_dir)
119 | self.z3_data = []
120 | self.b3_data = []
121 | self.h3_data = []
122 | self.z4_data = []
123 | self.b4_data = []
124 | self.h4_data = []
125 | with open(z3_fname, 'r') as f:
126 | for line in f:
127 | self.z3_data.append([s for s in line.strip().split(' ')])
128 | with open(b3_fname, 'r') as f:
129 | for line in f:
130 | self.b3_data.append([s for s in line.strip().split(' ')])
131 | with open(h3_fname, 'r') as f:
132 | for line in f:
133 | self.h3_data.append([s for s in line.strip().split(' ')])
134 | with open(z4_fname, 'r') as f:
135 | for line in f:
136 | self.z4_data.append([s for s in line.strip().split(' ')])
137 | with open(b4_fname, 'r') as f:
138 | for line in f:
139 | self.b4_data.append([s for s in line.strip().split(' ')])
140 | with open(h4_fname, 'r') as f:
141 | for line in f:
142 | self.h4_data.append([s for s in line.strip().split(' ')])
143 | self.data_3 = [self.z3_data, self.b3_data, self.h3_data]
144 | self.data_4 = [self.z4_data, self.b4_data, self.h4_data]
145 |
146 | z_data_root = 'zumba/'
147 | b_data_root = 'ballet/'
148 | h_data_root = 'hiphop/'
149 | self.data_root = [z_data_root, b_data_root, h_data_root ]
150 | self.mean_pose=np.load(data_dir+'/stats/all_onbeat_mean.npy')
151 | self.std_pose=np.load(data_dir+'/stats/all_onbeat_std.npy')
152 | self.mean_aud=np.load(data_dir+'/stats/all_aud_mean.npy')
153 | self.std_aud=np.load(data_dir+'/stats/all_aud_std.npy')
154 |
155 | def __getitem__(self, index):
156 | # sample a dance style; hiphop (class 2) is disabled here as well
157 | cls = random.randint(0,1)
158 | isthree = random.randint(0,1)
159 |
160 | if isthree == 0:
161 | index = random.randint(0, len(self.data_4[cls])-1)
162 | path = self.data_4[cls][index][0]
163 | else:
164 | index = random.randint(0, len(self.data_3[cls])-1)
165 | path = self.data_3[cls][index][0]
166 | path = os.path.join(self.data_dir, path[5:])
167 | stdpSeq = np.load(path)
168 | vid, cid = path.split('/')[-4], path.split('/')[-3]
169 | #vid, cid = vid_cid[:11], vid_cid[12:]
170 | aud = np.load('{}/{}/{}/{}/aud/c{}_fps15.npy'.format(self.data_dir, self.data_root[cls], vid, cid, cid))
171 |
172 | stdpSeq = stdpSeq.reshape(stdpSeq.shape[0], stdpSeq.shape[1], stdpSeq.shape[2]*stdpSeq.shape[3])
173 | for i in range(stdpSeq.shape[0]):
174 | for j in range(stdpSeq.shape[1]):
175 | stdpSeq[i,j] = (stdpSeq[i,j]-self.mean_pose)/self.std_pose
176 | if isthree == 0:
177 | start = random.randint(0,1)
178 | stdpSeq = stdpSeq[start:start+3]
179 |
180 | for i in range(aud.shape[0]):
181 | aud[i] = (aud[i]-self.mean_aud)/self.std_aud
182 | aud = aud[:30]
183 | return torch.Tensor(stdpSeq), torch.Tensor(aud)
184 |
185 | def __len__(self):
186 | return len(self.z3_data)+len(self.b3_data)+len(self.z4_data)+len(self.b4_data)+len(self.h3_data)+len(self.h4_data)
187 |
188 | def get_loader(batch_size, shuffle, num_workers, dataset, data_dir, tolerance=False):
189 | if dataset == 0:
190 | a2d = PoseDataset(data_dir, tolerance)
191 | elif dataset == 2:  # other values leave a2d undefined and raise NameError below
192 | a2d = MovementAudDataset(data_dir)
193 | data_loader = torch.utils.data.DataLoader(dataset=a2d,
194 | batch_size=batch_size,
195 | shuffle=shuffle,
196 | num_workers=num_workers,
197 | )
198 | return data_loader
199 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/Dancing2Music/License.txt
7 | import os
8 | import argparse
9 | import functools
10 | import librosa
11 | import shutil
12 | import sys
13 | sys.path.insert(0, 'preprocess')
14 | import preprocess as p
15 | import subprocess as sp
16 | from shutil import copyfile
17 | import numpy as np  # explicit import; np was previously only visible via 'from model_comp import *'
18 | import torch
19 | from torch.utils.data import DataLoader
20 | from torchvision import transforms
21 |
22 | from model_comp import *
23 | from networks import *
24 | from options import TestOptions
25 | import modulate
26 | import utils
27 |
28 | def loadDecompModel(args):
29 | initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
30 | stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,
31 | hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)
32 | movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,
33 | hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))
34 | checkpoint = torch.load(args.decomp_snapshot)
35 | initp_enc.load_state_dict(checkpoint['initp_enc'])
36 | stdp_dec.load_state_dict(checkpoint['stdp_dec'])
37 | movement_enc.load_state_dict(checkpoint['movement_enc'])
38 | return initp_enc, stdp_dec, movement_enc
39 |
40 | def loadCompModel(args):
41 | dance_enc = Dance_Enc(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
42 | hidden_size=args.dance_enc_hidden_size, num_layers=args.dance_enc_num_layers, bidirection=(args.dance_enc_bidirection==1))
43 | dance_dec = Dance_Dec(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
44 | hidden_size=args.dance_dec_hidden_size, num_layers=args.dance_dec_num_layers)
45 | audstyle_enc = Audstyle_Enc(aud_size=args.aud_style_size, dim_z=args.dim_z_dance)
46 | dance_reg = Dance2Style(aud_size=args.aud_style_size, dim_z_dance=args.dim_z_dance)
47 | danceAud_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_movement, length=3)
48 | zdance_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_dance, length=1)
49 | checkpoint = torch.load(args.comp_snapshot)
50 | dance_enc.load_state_dict(checkpoint['dance_enc'])
51 | dance_dec.load_state_dict(checkpoint['dance_dec'])
52 | audstyle_enc.load_state_dict(checkpoint['audstyle_enc'])
53 |
54 | checkpoint2 = torch.load(args.neta_snapshot)
55 | neta_cls = AudioClassifier_rnn(10,30,28,cls=3)
56 | neta_cls.load_state_dict(checkpoint2)
57 |
58 | return dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls
59 |
60 | if __name__ == "__main__":
61 | parser = TestOptions()
62 | args = parser.parse()
63 | args.train = False
64 |
65 | thr = args.thr
66 |
67 | # Process music and get feature
68 | infile = args.aud_path
69 | outfile = 'style.npy'
70 | p.preprocess(infile, outfile)
71 |
72 | y, sr = librosa.load(infile)
73 | onset_env = librosa.onset.onset_strength(y, sr=sr,aggregate=np.median)
74 | times = librosa.frames_to_time(np.arange(len(onset_env)),sr=sr, hop_length=512)
75 | tempo, beats = librosa.beat.beat_track(onset_envelope=onset_env,sr=sr)
76 | beats = np.round(librosa.frames_to_time(beats, sr=sr)*15)  # beat positions as 15-fps frame indices, as modulate() expects
77 | np.save('beats.npy', beats)
78 |
79 | beats = np.load('beats.npy')
80 | aud = np.load('style.npy')
81 | os.remove('beats.npy')
82 | os.remove('style.npy')
83 | shutil.rmtree('normalized')
84 |
85 | #### Pretrain network from Decomp
86 | initp_enc, stdp_dec, movement_enc = loadDecompModel(args)
87 |
88 | #### Comp network
89 | dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls = loadCompModel(args)
90 |
91 | trainer = Trainer_Comp(data_loader=None,
92 | movement_enc = movement_enc,
93 | initp_enc = initp_enc,
94 | stdp_dec = stdp_dec,
95 | dance_enc = dance_enc,
96 | dance_dec = dance_dec,
97 | danceAud_dis = danceAud_dis,
98 | zdance_dis = zdance_dis,
99 | aud_enc=neta_cls,
100 | audstyle_enc=audstyle_enc,
101 | dance_reg=dance_reg,
102 | args = args
103 | )
104 |
105 | print('Loading Done')
106 |
107 | mean_pose=np.load('{}/stats/all_onbeat_mean.npy'.format(args.data_dir))
108 | std_pose=np.load('{}/stats/all_onbeat_std.npy'.format(args.data_dir))
109 | mean_aud=np.load('{}/stats/all_aud_mean.npy'.format(args.data_dir))
110 | std_aud=np.load('{}/stats/all_aud_std.npy'.format(args.data_dir))
111 |
112 |
113 | length = aud.shape[0]
114 |
115 | initpose = np.zeros((14, 2))
116 | initpose = initpose.reshape(-1)
117 | #initpose = (initpose-mean_pose)/std_pose
118 |
119 | for j in range(aud.shape[0]):
120 | aud[j] = (aud[j]-mean_aud)/std_aud
121 |
122 | total_t = int(length/32+1)
123 | final_stdpSeq = np.zeros((total_t*3*32, 14, 2))
124 | initpose, aud = torch.Tensor(initpose).cuda(), torch.Tensor(aud).cuda()
125 | initpose, aud = initpose.view(1, initpose.shape[0]), aud.view(1, aud.shape[0], aud.shape[1])
126 | for t in range(total_t):
127 | print('process {}/{}'.format(t, total_t))
128 | # resample until test_final returns a sequence whose motion magnitude passes thr
129 | while True:
130 | fake_stdpSeq = trainer.test_final(initpose, aud, 3, thr)
131 | if fake_stdpSeq is not None:
132 | break
133 | initpose = fake_stdpSeq[2,-1]
134 | initpose = torch.Tensor(initpose).cuda()
135 | initpose = initpose.view(1,-1)
136 | fake_stdpSeq = fake_stdpSeq.squeeze()
137 | for j in range(fake_stdpSeq.shape[0]):
138 | for k in range(fake_stdpSeq.shape[1]):
139 | fake_stdpSeq[j,k] = fake_stdpSeq[j,k]*std_pose + mean_pose
140 | fake_stdpSeq = np.resize(fake_stdpSeq, (fake_stdpSeq.shape[0],32, 14, 2))
141 | for j in range(3):
142 | final_stdpSeq[96*t+32*j:96*t+32*(j+1)] = fake_stdpSeq[j]
143 |
144 | if args.modulate:
145 | final_stdpSeq = modulate.modulate(final_stdpSeq, beats, length)
146 |
147 | out_dir = args.out_dir
148 | if not os.path.exists(out_dir):
149 | os.mkdir(out_dir)
150 | utils.vis(final_stdpSeq, out_dir)
151 | sp.call('ffmpeg -r 15 -i {}/frame%03d.png -i {} -c:v libx264 -pix_fmt yuv420p -crf 23 -r 30 -y -strict -2 {}'.format(out_dir, args.aud_path, args.out_file), shell=True)
152 |
153 |
--------------------------------------------------------------------------------
/imgs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Dancing2Music/7ff1d95f9f3d3585e29ee7e4ca5a3a45e29db6de/imgs/.DS_Store
--------------------------------------------------------------------------------
/imgs/example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Dancing2Music/7ff1d95f9f3d3585e29ee7e4ca5a3a45e29db6de/imgs/example.gif
--------------------------------------------------------------------------------
/imgs/long.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Dancing2Music/7ff1d95f9f3d3585e29ee7e4ca5a3a45e29db6de/imgs/long.gif
--------------------------------------------------------------------------------
/imgs/multimodal.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Dancing2Music/7ff1d95f9f3d3585e29ee7e4ca5a3a45e29db6de/imgs/multimodal.gif
--------------------------------------------------------------------------------
/imgs/v2v.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Dancing2Music/7ff1d95f9f3d3585e29ee7e4ca5a3a45e29db6de/imgs/v2v.gif
--------------------------------------------------------------------------------
/model_comp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/Dancing2Music/License.txt
7 | import os
8 | import time
9 | import numpy as np
10 | import random
11 | import math
12 |
13 | import torch
14 | from torch import nn
15 | from torch.autograd import Variable
16 | import torch.optim as optim
17 | from torch.nn.utils import clip_grad_norm_
18 |
19 | from utils import Logger
20 |
21 | if torch.cuda.is_available():
22 | T = torch.cuda
23 | else:
24 | T = torch
25 |
26 | class Trainer_Comp(object):
27 | def __init__(self, data_loader, dance_enc, dance_dec, danceAud_dis, movement_enc, initp_enc, stdp_dec, aud_enc, audstyle_enc, dance_reg=None, args=None, zdance_dis=None):
28 | self.data_loader = data_loader
29 | self.movement_enc = movement_enc
30 | self.initp_enc = initp_enc
31 | self.stdp_dec = stdp_dec
32 | self.dance_enc = dance_enc
33 | self.dance_dec = dance_dec
34 | self.danceAud_dis = danceAud_dis
35 | self.aud_enc = aud_enc
36 | self.audstyle_enc = audstyle_enc
37 | self.is_train = args.train  # keep distinct from the train() method defined below
38 | self.args = args
39 |
40 | if args.train:
41 | self.zdance_dis = zdance_dis
42 | self.dance_reg = dance_reg
43 |
44 | self.logger = Logger(args.log_dir)
45 | self.logs = self.init_logs()
46 | self.log_interval = args.log_interval
47 | self.snapshot_ep = args.snapshot_ep
48 | self.snapshot_dir = args.snapshot_dir
49 |
50 | self.opt_dance_enc = torch.optim.Adam(self.dance_enc.parameters(), lr=args.lr)
51 | self.opt_dance_dec = torch.optim.Adam(self.dance_dec.parameters(), lr=args.lr)
52 | self.opt_danceAud_dis = torch.optim.Adam(self.danceAud_dis.parameters(), lr=args.lr)
53 | self.opt_audstyle_enc = torch.optim.Adam(self.audstyle_enc.parameters(), lr=args.lr)
54 | self.opt_zdance_dis = torch.optim.Adam(self.zdance_dis.parameters(), lr=args.lr)
55 | self.opt_dance_reg = torch.optim.Adam(self.dance_reg.parameters(), lr=args.lr)
56 |
57 | self.opt_stdp_dec = torch.optim.Adam(self.stdp_dec.parameters(), lr=args.lr*0.1)
58 | self.opt_movement_enc = torch.optim.Adam(self.movement_enc.parameters(), lr=args.lr*0.1)
59 |
60 | self.latent_dropout = nn.Dropout(p=args.latent_dropout)
61 | self.l1_criterion = torch.nn.L1Loss()
62 | self.gan_criterion = nn.BCEWithLogitsLoss()
63 | self.mse_criterion = nn.MSELoss().cuda()
64 |
65 | def init_logs(self):
66 | return {'l_kl_zdance':0, 'l_kl_zmovement':0, 'l_kl_fake_zdance':0, 'l_kl_fake_zmovement':0,
67 | 'l_l1_zmovement_mu':0, 'l_l1_zmovement_logvar':0, 'l_l1_stdpSeq':0, 'l_l1_zdance':0,
68 | 'l_dis':0, 'l_dis_true':0, 'l_dis_fake':0,
69 | 'l_info':0, 'l_info_real':0, 'l_info_fake':0,
70 | 'l_gen':0
71 | }
72 |
73 | def get_z_random(self, batchSize, nz, random_type='gauss'):
74 | z = torch.randn(batchSize, nz).cuda()
75 | return z
76 |
77 | @staticmethod
78 | def ones_like(tensor, val=1.):
79 | return T.FloatTensor(tensor.size()).fill_(val)
80 |
81 | @staticmethod
82 | def zeros_like(tensor, val=0.):
83 | return T.FloatTensor(tensor.size()).fill_(val)
84 | def kld_coef(self, i):
85 | return float(1/(1+np.exp(-0.0005*(i-15000))))
86 |
87 |
88 | def forward(self, stdpSeq, batchsize, aud_style, aud):
89 | self.aud = torch.mean(aud, dim=1)
90 |
91 | self.batchsize = batchsize
92 | self.stdpSeq = stdpSeq
93 | self.aud_style = aud_style
94 | ### stdpSeq -> z_inits, z_movements
95 | self.pose_0 = stdpSeq[:,0,:]
96 | self.z_init_mu, self.z_init_logvar = self.initp_enc(self.pose_0)
97 | z_init_std = self.z_init_logvar.mul(0.5).exp_()
98 | z_init_eps = self.get_z_random(z_init_std.size(0), z_init_std.size(1), 'gauss')
99 | self.z_init = z_init_eps.mul(z_init_std).add_(self.z_init_mu)
100 |
101 | self.z_movement_mus, self.z_movement_logvars = self.movement_enc(stdpSeq)
102 | z_movement_stds = self.z_movement_logvars.mul(0.5).exp_()
103 | z_movement_epss = self.get_z_random(z_movement_stds.size(0), z_movement_stds.size(1), 'gauss')
104 | self.z_movements = z_movement_epss.mul(z_movement_stds).add_(self.z_movement_mus)
105 | self.z_movementSeq_mu = self.z_movement_mus.view(batchsize, -1, self.z_movements.shape[1])
106 | self.z_movementSeq_logvar = self.z_movement_logvars.view(batchsize, -1, self.z_movements.shape[1])
107 |
108 | self.z_init, self.z_movements = self.z_init.detach(), self.z_movements.detach()
109 | self.z_movement_mus, self.z_movement_logvars = self.z_movement_mus.detach(), self.z_movement_logvars.detach()
110 |
111 | ### z_movements -> z_dance
112 | self.z_dance_mu, self.z_dance_logvar = self.dance_enc(self.z_movementSeq_mu, self.z_movementSeq_logvar)
113 | z_dance_std = self.z_dance_logvar.mul(0.5).exp_()
114 | z_dance_eps = self.get_z_random(z_dance_std.size(0), z_dance_std.size(1), 'gauss')
115 | self.z_dance = z_dance_eps.mul(z_dance_std).add_(self.z_dance_mu)
116 | ### z_dance -> z_movements
117 | self.recon_z_movements_mu, self.recon_z_movements_logvar = self.dance_dec(self.z_dance)
118 | recon_z_movement_std = self.recon_z_movements_logvar.mul(0.5).exp_()
119 | recon_z_movement_eps = self.get_z_random(recon_z_movement_std.size(0), recon_z_movement_std.size(1), 'gauss')
120 | self.recon_z_movements = recon_z_movement_eps.mul(recon_z_movement_std).add_(self.recon_z_movements_mu)
121 | ### z_movements -> stdpSeq
122 | self.recon_stdpSeq = self.stdp_dec(self.z_init, self.recon_z_movements)
123 |
124 | ### Music to z_dance to z_movements
125 | self.fake_z_dance_mu, self.fake_z_dance_logvar = self.audstyle_enc(aud_style)
126 | fake_z_dance_std = self.fake_z_dance_logvar.mul(0.5).exp_()
127 | fake_z_dance_eps = self.get_z_random(fake_z_dance_std.size(0), fake_z_dance_std.size(1), 'gauss')
128 | self.fake_z_dance = fake_z_dance_eps.mul(fake_z_dance_std).add_(self.fake_z_dance_mu)
129 | self.fake_z_movements_mu, self.fake_z_movements_logvar = self.dance_dec(self.fake_z_dance)
130 | fake_z_movements_std = self.fake_z_movements_logvar.mul(0.5).exp_()
131 | fake_z_movements_eps = self.get_z_random(fake_z_movements_std.size(0), fake_z_movements_std.size(1), 'gauss')
132 | self.fake_z_movements = fake_z_movements_eps.mul(fake_z_movements_std).add_(self.fake_z_movements_mu)
133 |
134 | fake_z_movementSeq_mu = self.fake_z_movements_mu.view(batchsize, -1, self.fake_z_movements_mu.shape[1])
135 | fake_z_movementSeq_logvar = self.fake_z_movements_logvar.view(batchsize, -1, self.fake_z_movements_logvar.shape[1])
136 | self.fake_z_movementSeq = torch.cat((fake_z_movementSeq_mu, fake_z_movementSeq_logvar),2)
137 |
138 | def backward_D(self):
139 | #real_movements = torch.cat((self.z_movementSeq_mu, self.z_movementSeq_logvar),2)
140 | tmp_recon_mu = self.recon_z_movements_mu.view(self.batchsize, -1, self.z_movements.shape[1])
141 | tmp_recon_logvar = self.recon_z_movements_logvar.view(self.batchsize, -1, self.z_movements.shape[1])
142 | real_movements = torch.cat((tmp_recon_mu, tmp_recon_logvar),2)
143 | fake_movements = self.fake_z_movementSeq
144 |
145 | real_labels,_ = self.danceAud_dis(real_movements.detach(), self.aud)
146 | fake_labels,_ = self.danceAud_dis(fake_movements.detach(), self.aud)
147 |
148 | ones = self.ones_like(real_labels)
149 | zeros = self.zeros_like(fake_labels)
150 |
151 | self.loss_dis_true = self.gan_criterion(real_labels, ones)
152 | self.loss_dis_fake = self.gan_criterion(fake_labels, zeros)
153 | self.loss_dis = (self.loss_dis_true + self.loss_dis_fake)*self.args.lambda_gan
154 |
155 | real_dance = torch.cat((self.z_dance_mu, self.z_dance_logvar), 1)
156 | fake_dance = torch.cat((self.fake_z_dance_mu, self.fake_z_dance_logvar), 1)
157 | real_labels, _ = self.zdance_dis(real_dance.detach(), self.aud)
158 | fake_labels, _ = self.zdance_dis(fake_dance.detach(), self.aud)
159 | ones = self.ones_like(real_labels)
160 | zeros = self.zeros_like(fake_labels)
161 |
162 | self.loss_zdis_true = self.gan_criterion(real_labels, ones)
163 | self.loss_zdis_fake = self.gan_criterion(fake_labels, zeros)
164 | self.loss_dis += (self.loss_zdis_true + self.loss_zdis_fake)*self.args.lambda_gan
165 |
166 |
167 | def backward_danceED(self):
168 | # z_dance KL
169 | kl_element = self.z_dance_mu.pow(2).add_(self.z_dance_logvar.exp()).mul_(-1).add_(1).add_(self.z_dance_logvar)
170 | self.loss_kl_z_dance = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl_dance))
171 | kl_element = self.fake_z_dance_mu.pow(2).add_(self.fake_z_dance_logvar.exp()).mul_(-1).add_(1).add_(self.fake_z_dance_logvar)
172 | self.loss_kl_fake_z_dance = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl_dance))
173 | # z_movement KL
174 | kl_element = self.recon_z_movements_mu.pow(2).add_(self.recon_z_movements_logvar.exp()).mul_(-1).add_(1).add_(self.recon_z_movements_logvar)
175 | self.loss_kl_z_movement = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl))
176 | kl_element = self.fake_z_movements_mu.pow(2).add_(self.fake_z_movements_logvar.exp()).mul_(-1).add_(1).add_(self.fake_z_movements_logvar)
177 | self.loss_kl_fake_z_movements = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl))
178 | # z_movement reconstruction
179 | self.loss_l1_z_movement_mu = self.l1_criterion(self.recon_z_movements_mu, self.z_movement_mus) * self.args.lambda_zmovements_recon
180 | self.loss_l1_z_movement_logvar = self.l1_criterion(self.recon_z_movements_logvar, self.z_movement_logvars) * self.args.lambda_zmovements_recon
181 |
182 | # stdp reconstruction
183 | self.loss_l1_stdpSeq = self.l1_criterion(self.recon_stdpSeq, self.stdpSeq) * self.args.lambda_stdpSeq_recon
184 |
185 | # Music2Dance GAN
186 | fake_movements = self.fake_z_movementSeq
187 | fake_labels, _ = self.danceAud_dis(fake_movements, self.aud)
188 |
189 | ones = self.ones_like(fake_labels)
190 | self.loss_gen = self.gan_criterion(fake_labels, ones) * self.args.lambda_gan
191 |
192 | fake_dance = torch.cat((self.fake_z_dance_mu, self.fake_z_dance_logvar), 1)
193 | fake_labels, _ = self.zdance_dis(fake_dance, self.aud)
194 | ones = self.ones_like(fake_labels)
195 | self.loss_gen += self.gan_criterion(fake_labels, ones) * self.args.lambda_gan
196 |
197 | self.loss = self.loss_kl_z_movement + self.loss_kl_z_dance + self.loss_l1_z_movement_mu + self.loss_l1_z_movement_logvar + self.loss_l1_stdpSeq + self.loss_gen
198 |
199 | def backward_info_ondance(self):
200 | real_pred = self.dance_reg(self.z_dance)
201 | fake_pred = self.dance_reg(self.fake_z_dance)
202 | self.loss_info_real = self.mse_criterion(real_pred, self.aud_style)
203 | self.loss_info_fake = self.mse_criterion(fake_pred, self.aud_style)
204 | self.loss_info = self.loss_info_real + self.loss_info_fake
205 |
206 | def zero_grad(self, opt_list):
207 | for opt in opt_list:
208 | opt.zero_grad()
209 |
210 | def clip_norm(self, network_list):
211 | for network in network_list:
212 | clip_grad_norm_(network.parameters(), 0.5)
213 |
214 | def step(self, opt_list):
215 | for opt in opt_list:
216 | opt.step()
217 |
218 | def update(self):
219 | self.zero_grad([self.opt_danceAud_dis, self.opt_zdance_dis])
220 | self.backward_D()
221 | self.loss_dis.backward(retain_graph=True)
222 | self.clip_norm([self.danceAud_dis, self.zdance_dis])
223 | self.step([self.opt_danceAud_dis, self.opt_zdance_dis])
224 |
225 | self.zero_grad([self.opt_dance_enc, self.opt_dance_dec, self.opt_audstyle_enc, self.opt_stdp_dec])
226 | self.backward_danceED()
227 | self.loss.backward(retain_graph=True)
228 | self.clip_norm([self.dance_enc, self.dance_dec, self.audstyle_enc, self.stdp_dec])
229 | self.step([self.opt_dance_enc, self.opt_dance_dec, self.opt_audstyle_enc, self.opt_stdp_dec])
230 |
231 | self.zero_grad([self.opt_dance_enc, self.opt_audstyle_enc, self.opt_dance_reg, self.opt_stdp_dec])
232 | self.backward_info_ondance()
233 | self.loss_info.backward()
234 | self.clip_norm([self.dance_enc, self.audstyle_enc, self.dance_reg, self.stdp_dec])
235 | self.step([self.opt_dance_enc, self.opt_audstyle_enc, self.opt_dance_reg, self.opt_stdp_dec])
236 |
237 | def test_final(self, initpose, aud, n, thr=0):
238 | self.cuda()
239 | self.movement_enc.eval()
240 | self.stdp_dec.eval()
241 | self.initp_enc.eval()
242 | self.dance_enc.eval()
243 | self.dance_dec.eval()
244 | self.aud_enc.eval()
245 | self.audstyle_enc.eval()
246 | aud_style = self.aud_enc.get_style(aud).detach()
247 |
248 | self.fake_z_dance_mu, self.fake_z_dance_logvar = self.audstyle_enc(aud_style)
249 | fake_z_dance_std = self.fake_z_dance_logvar.mul(0.5).exp_()
250 | fake_z_dance_eps = self.get_z_random(fake_z_dance_std.size(0), fake_z_dance_std.size(1), 'gauss')
251 | self.fake_z_dance = fake_z_dance_eps.mul(fake_z_dance_std).add_(self.fake_z_dance_mu)
252 |
253 | self.fake_z_movements_mu, self.fake_z_movements_logvar = self.dance_dec(self.fake_z_dance, length=3)
254 | fake_z_movements_std = self.fake_z_movements_logvar.mul(0.5).exp_()
255 | fake_z_movements_eps = self.get_z_random(fake_z_movements_std.size(0), fake_z_movements_std.size(1), 'gauss')
256 | self.fake_z_movements = fake_z_movements_eps.mul(fake_z_movements_std).add_(self.fake_z_movements_mu)
257 |
258 | fake_stdpSeq=[]
259 | for i in range(n):
260 | z_init_mus, z_init_logvars = self.initp_enc(initpose)
261 | z_init_stds = z_init_logvars.mul(0.5).exp_()
262 | z_init_epss = self.get_z_random(z_init_stds.size(0), z_init_stds.size(1), 'gauss')
263 | z_init = z_init_epss.mul(z_init_stds).add_(z_init_mus)
264 | fake_stdp = self.stdp_dec(z_init, self.fake_z_movements[i:i+1])
265 | fake_stdpSeq.append(fake_stdp)
266 | initpose = fake_stdp[:,-1,:]
267 | fake_stdpSeq = torch.cat(fake_stdpSeq, dim=0)
268 | flag = False
269 | for i in range(n):
270 | s = fake_stdpSeq[i]
271 | diff = torch.abs(s[1:]-s[:-1])
272 | diffsum = torch.sum(diff)
273 | if diffsum.cpu().detach().numpy() < thr:
274 | flag = True
275 |
276 | if flag:
277 | return None
278 | else:
279 | return fake_stdpSeq.cpu().detach().numpy()
280 |
281 |
282 | def resume(self, model_dir, train=True):
283 | checkpoint = torch.load(model_dir)
284 | self.dance_enc.load_state_dict(checkpoint['dance_enc'])
285 | self.dance_dec.load_state_dict(checkpoint['dance_dec'])
286 | self.audstyle_enc.load_state_dict(checkpoint['audstyle_enc'])
287 | self.stdp_dec.load_state_dict(checkpoint['stdp_dec'])
288 | self.movement_enc.load_state_dict(checkpoint['movement_enc'])
289 | if train:
290 | self.danceAud_dis.load_state_dict(checkpoint['danceAud_dis'])
291 | self.dance_reg.load_state_dict(checkpoint['dance_reg'])
292 | self.opt_dance_enc.load_state_dict(checkpoint['opt_dance_enc'])
293 | self.opt_dance_dec.load_state_dict(checkpoint['opt_dance_dec'])
294 | self.opt_stdp_dec.load_state_dict(checkpoint['opt_stdp_dec'])
295 | self.opt_audstyle_enc.load_state_dict(checkpoint['opt_audstyle_enc'])
296 | self.opt_danceAud_dis.load_state_dict(checkpoint['opt_danceAud_dis'])
297 | self.opt_dance_reg.load_state_dict(checkpoint['opt_dance_reg'])
298 | return checkpoint['ep'], checkpoint['total_it']
299 |
300 | def save(self, filename, ep, total_it):
301 | state = {
302 | 'stdp_dec': self.stdp_dec.state_dict(),
303 | 'movement_enc': self.movement_enc.state_dict(),
304 | 'dance_enc': self.dance_enc.state_dict(),
305 | 'dance_dec': self.dance_dec.state_dict(),
306 | 'audstyle_enc': self.audstyle_enc.state_dict(),
307 | 'danceAud_dis': self.danceAud_dis.state_dict(),
308 | 'zdance_dis': self.zdance_dis.state_dict(),
309 | 'dance_reg': self.dance_reg.state_dict(),
310 | 'opt_stdp_dec': self.opt_stdp_dec.state_dict(),
311 | 'opt_movement_enc': self.opt_movement_enc.state_dict(),
312 | 'opt_dance_enc': self.opt_dance_enc.state_dict(),
313 | 'opt_dance_dec': self.opt_dance_dec.state_dict(),
314 | 'opt_audstyle_enc': self.opt_audstyle_enc.state_dict(),
315 | 'opt_danceAud_dis': self.opt_danceAud_dis.state_dict(),
316 | 'opt_zdance_dis': self.opt_zdance_dis.state_dict(),
317 | 'opt_dance_reg': self.opt_dance_reg.state_dict(),
318 | 'ep': ep,
319 | 'total_it': total_it
320 | }
321 | torch.save(state, filename)
322 | return
323 |
324 | def cuda(self):
325 | if self.train:
326 | self.dance_reg.cuda()
327 | self.danceAud_dis.cuda()
328 | self.zdance_dis.cuda()
329 | self.stdp_dec.cuda()
330 | self.initp_enc.cuda()
331 | self.movement_enc.cuda()
332 | self.dance_enc.cuda()
333 | self.dance_dec.cuda()
334 | self.aud_enc.cuda()
335 | self.audstyle_enc.cuda()
336 | self.gan_criterion.cuda()
337 |
338 | def train(self, ep=0, it=0):
339 | self.cuda()
340 | for epoch in range(ep, self.args.num_epochs):
341 | self.movement_enc.train()
342 | self.stdp_dec.train()
343 | self.initp_enc.train()
344 | self.dance_enc.train()
345 | self.dance_dec.train()
346 | self.danceAud_dis.train()
347 | self.zdance_dis.train()
348 | self.audstyle_enc.train()
349 | self.dance_reg.train()
350 | self.aud_enc.eval()
351 | stdp_recon = 0
352 |
353 | for i, (stdpSeq, aud) in enumerate(self.data_loader):
354 | stdpSeq, aud = stdpSeq.cuda().detach(), aud.cuda().detach()
355 | stdpSeq = stdpSeq.view(stdpSeq.shape[0]*stdpSeq.shape[1], stdpSeq.shape[2], stdpSeq.shape[3])
356 | aud_style = self.aud_enc.get_style(aud).detach()
357 |
358 | self.forward(stdpSeq, aud.shape[0], aud_style, aud)
359 | self.update()
360 | self.logs['l_kl_zmovement'] += self.loss_kl_z_movement.data
361 | self.logs['l_kl_zdance'] += self.loss_kl_z_dance.data
362 | self.logs['l_l1_zmovement_mu'] += self.loss_l1_z_movement_mu.data
363 | self.logs['l_l1_zmovement_logvar'] += self.loss_l1_z_movement_logvar.data
364 | self.logs['l_l1_stdpSeq'] += self.loss_l1_stdpSeq.data
365 | self.logs['l_kl_fake_zdance'] += self.loss_kl_fake_z_dance.data
366 | self.logs['l_kl_fake_zmovement'] += self.loss_kl_fake_z_movements.data
367 | self.logs['l_dis'] += self.loss_dis.data
368 | self.logs['l_dis_true'] += self.loss_dis_true.data
369 | self.logs['l_dis_fake'] += self.loss_dis_fake.data
370 | self.logs['l_gen'] += self.loss_gen.data
371 | self.logs['l_info'] += self.loss_info.data
372 | self.logs['l_info_real'] += self.loss_info_real.data
373 | self.logs['l_info_fake'] += self.loss_info_fake.data
374 |
375 | print('Epoch:{:3} Iter{}/{}\tl_l1_zmovement mu{:.3f} logvar{:.3f}\tl_l1_stdpSeq {:.3f}\tl_kl_dance {:.3f}\tl_kl_movement {:.3f}\n'.format(epoch, i, len(self.data_loader),
376 | self.loss_l1_z_movement_mu, self.loss_l1_z_movement_logvar, self.loss_l1_stdpSeq, self.loss_kl_z_dance, self.loss_kl_z_movement) +
377 | '\t\t\tl_kl_f_dance {:.3f}\tl_dis {:.3f} {:.3f}\tl_gen {:.3f}'.format(self.loss_kl_fake_z_dance, self.loss_dis_true, self.loss_dis_fake, self.loss_gen))
378 |
379 | it += 1
380 | if it % self.log_interval == 0:
381 | for tag, value in self.logs.items():
382 | self.logger.scalar_summary(tag, value/self.log_interval, it)
383 | self.logs = self.init_logs()
384 | if epoch % self.snapshot_ep == 0:
385 | self.save(os.path.join(self.snapshot_dir, '{:04}.ckpt'.format(epoch)), epoch, it)
386 |
--------------------------------------------------------------------------------
/model_decomp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/Dancing2Music/License.txt
7 | import os
8 | import time
9 | import numpy as np
10 | import random
11 | import math
12 |
13 | import torch
14 | from torch import nn
15 | from torch.autograd import Variable
16 | import torch.optim as optim
17 | from torch.nn.utils import clip_grad_norm_
18 |
19 | from utils import Logger
20 |
21 | if torch.cuda.is_available():
22 | T = torch.cuda
23 | else:
24 | T = torch
25 |
26 | class Trainer_Decomp(object):
27 | def __init__(self, data_loader, initp_enc, initp_dec, movement_enc, stdp_dec, args=None):
28 | self.data_loader = data_loader
29 | self.initp_enc = initp_enc
30 | self.initp_dec = initp_dec
31 | self.movement_enc = movement_enc
32 | self.stdp_dec = stdp_dec
33 |
34 |
35 | self.args = args
36 | if args.train:
37 | self.logger = Logger(args.log_dir)
38 | self.logs = self.init_logs()
39 | self.log_interval = args.log_interval
40 | self.snapshot_ep = args.snapshot_ep
41 | self.snapshot_dir = args.snapshot_dir
42 |
43 | self.opt_initp_enc = torch.optim.Adam(self.initp_enc.parameters(), lr=args.lr)
44 | self.opt_initp_dec = torch.optim.Adam(self.initp_dec.parameters(), lr=args.lr)
45 | self.opt_movement_enc = torch.optim.Adam(self.movement_enc.parameters(), lr=args.lr)
46 | self.opt_stdp_dec = torch.optim.Adam(self.stdp_dec.parameters(), lr=args.lr)
47 |
48 | self.latent_dropout = nn.Dropout(p=args.latent_dropout)
49 | self.l1_criterion = torch.nn.L1Loss()
50 | self.gan_criterion = nn.BCEWithLogitsLoss()
51 |
52 |
53 | def init_logs(self):
54 | return {'l_kl_zinit':0, 'l_kl_zmovement':0, 'l_l1_stdp':0, 'l_l1_cross_stdp':0, 'l_dist_zmovement':0,
55 | 'l_l1_initp':0, 'l_l1_initp_con':0,
56 | 'kld_coef':0
57 | }
58 |
59 | def get_z_random(self, batchSize, nz, random_type='gauss'):
60 | z = torch.randn(batchSize, nz).cuda()
61 | return z
62 |
63 | @staticmethod
64 | def ones_like(tensor, val=1.):
65 | return T.FloatTensor(tensor.size()).fill_(val)
66 |
67 | @staticmethod
68 | def zeros_like(tensor, val=0.):
69 | return T.FloatTensor(tensor.size()).fill_(val)
70 |
71 |
72 | def random_generate_stdp(self, init_p):
73 | self.pose_0 = init_p
74 | self.z_init_mu, self.z_init_logvar = self.initp_enc(self.pose_0)
75 | z_init_std = self.z_init_logvar.mul(0.5).exp_()
76 | z_init_eps = self.get_z_random(z_init_std.size(0), z_init_std.size(1), 'gauss')
77 | self.z_init = z_init_eps.mul(z_init_std).add_(self.z_init_mu)
78 | self.z_random_movement = self.get_z_random(self.z_init.size(0), 512, 'gauss')
79 | self.fake_stdpose = self.stdp_dec(self.z_init, self.z_random_movement)
80 | return self.fake_stdpose
81 |
82 | def forward(self, stdpose1, stdpose2):
83 | self.stdpose1 = stdpose1
84 | self.stdpose2 = stdpose2
85 |
86 | # stdpose -> stdpose[0] -> z_init
87 | self.pose1_0 = stdpose1[:,0,:]
88 | self.pose2_0 = stdpose2[:,0,:]
89 | self.poses_0 = torch.cat((self.pose1_0, self.pose2_0), 0)
90 | self.z_init_mus, self.z_init_logvars = self.initp_enc(self.poses_0)
91 | z_init_stds = self.z_init_logvars.mul(0.5).exp_()
92 | z_init_epss = self.get_z_random(z_init_stds.size(0), z_init_stds.size(1), 'gauss')
93 | self.z_inits = z_init_epss.mul(z_init_stds).add_(self.z_init_mus)
94 | self.z_init1, self.z_init2 = torch.split(self.z_inits, self.stdpose1.size(0), dim=0)
95 |
96 | # stdpose -> z_movement
97 | stdposes = torch.cat((stdpose1, stdpose2), 0)
98 | self.z_movement_mus, self.z_movement_logvars = self.movement_enc(stdposes)
99 | z_movement_stds = self.z_movement_logvars.mul(0.5).exp_()
100 | z_movement_epss = self.get_z_random(z_movement_stds.size(0), z_movement_stds.size(1), 'gauss')
101 | self.z_movements = z_movement_epss.mul(z_movement_stds).add_(self.z_movement_mus)
102 | self.z_movement1, self.z_movement2 = torch.split(self.z_movements, self.stdpose1.size(0), dim=0)
103 |
104 | # zinit1+zmovement1->stdpose1 zinit2+zmovement2->stdpose2
105 | self.recon_stdpose1 = self.stdp_dec(self.z_init1, self.z_movement1)
106 | self.recon_stdpose2 = self.stdp_dec(self.z_init2, self.z_movement2)
107 |
108 | # zinit1+zmovement2->stdpose1 zinit2+zmovement1->stdpose2
109 | self.recon_stdpose1_cross = self.stdp_dec(self.z_init1, self.z_movement2)
110 | self.recon_stdpose2_cross = self.stdp_dec(self.z_init2, self.z_movement1)
111 |
112 | # z_init -> \hat{stdpose[0]}
113 | self.recon_pose1_0 = self.initp_dec(self.z_init1)
114 | self.recon_pose2_0 = self.initp_dec(self.z_init2)
115 |
116 | # single pose reconstruction
117 | randomlist = np.random.permutation(31)[:4]  # 4 random frame indices drawn from the first 31 frames
118 | singlepose = []
119 | for r in randomlist:
120 | singlepose.append(self.stdpose1[:,r,:])
121 | self.singleposes = torch.cat(singlepose, dim=0).detach()
122 | self.z_single_mus, self.z_single_logvars = self.initp_enc(self.singleposes)
123 | z_single_stds = self.z_single_logvars.mul(0.5).exp_()
124 | z_single_epss = self.get_z_random(z_single_stds.size(0), z_single_stds.size(1), 'gauss')
125 | z_single = z_single_epss.mul(z_single_stds).add_(self.z_single_mus)
126 | self.recon_singleposes = self.initp_dec(z_single)
127 |
128 | def backward_initp_ED(self):
129 | # z_init KL
130 | kl_element = self.z_init_mus.pow(2).add_(self.z_init_logvars.exp()).mul_(-1).add_(1).add_(self.z_init_logvars)
131 | self.loss_kl_z_init = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl))
132 |
133 | # initpose reconstruction
134 | self.loss_l1_initp = self.l1_criterion(self.recon_singleposes, self.singleposes) * self.args.lambda_initp_recon
135 |
136 | self.loss_initp = self.loss_kl_z_init + self.loss_l1_initp
137 |
138 | def backward_movement_ED(self):
139 | # z_movement KL
140 | kl_element = self.z_movement_mus.pow(2).add_(self.z_movement_logvars.exp()).mul_(-1).add_(1).add_(self.z_movement_logvars)
141 | #self.loss_kl_z_movement = torch.mean(kl_element).mul_(-0.5) * self.args.lambda_kl
142 | self.loss_kl_z_movement = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl))
143 |
144 | # stdpose self reconstruction
145 | loss_l1_stdp1 = self.l1_criterion(self.recon_stdpose1, self.stdpose1) * self.args.lambda_stdp_recon
146 | loss_l1_stdp2 = self.l1_criterion(self.recon_stdpose2, self.stdpose2) * self.args.lambda_stdp_recon
147 | self.loss_l1_stdp = loss_l1_stdp1 + loss_l1_stdp2
148 |
149 | # stdpose cross reconstruction
150 | loss_l1_cross_stdp1 = self.l1_criterion(self.recon_stdpose1_cross, self.stdpose1) * self.args.lambda_stdp_recon
151 | loss_l1_cross_stdp2 = self.l1_criterion(self.recon_stdpose2_cross, self.stdpose2) * self.args.lambda_stdp_recon
152 | self.loss_l1_cross_stdp = loss_l1_cross_stdp1 + loss_l1_cross_stdp2
153 |
154 | # Movement dist
155 | self.loss_dist_z_movement = torch.mean(torch.abs(self.z_movement1-self.z_movement2)) * self.args.lambda_dist_z_movement
156 |
157 | self.loss_movement = self.loss_kl_z_movement + self.loss_l1_stdp + self.loss_l1_cross_stdp + self.loss_dist_z_movement
158 |
159 |
160 | def update(self):
161 | self.opt_initp_enc.zero_grad()
162 | self.opt_initp_dec.zero_grad()
163 | self.opt_movement_enc.zero_grad()
164 | self.opt_stdp_dec.zero_grad()
165 | self.backward_initp_ED()
166 | self.backward_movement_ED()
167 | self.g_loss = self.loss_initp + self.loss_movement
168 | self.g_loss.backward(retain_graph=True)
169 | clip_grad_norm_(self.movement_enc.parameters(), 0.5)
170 | clip_grad_norm_(self.stdp_dec.parameters(), 0.5)
171 | self.opt_initp_enc.step()
172 | self.opt_initp_dec.step()
173 | self.opt_movement_enc.step()
174 | self.opt_stdp_dec.step()
175 |
176 |
177 | def save(self, filename, ep, total_it):
178 | state = {
179 | 'stdp_dec': self.stdp_dec.state_dict(),
180 | 'movement_enc': self.movement_enc.state_dict(),
181 | 'initp_enc': self.initp_enc.state_dict(),
182 | 'initp_dec': self.initp_dec.state_dict(),
183 | 'opt_stdp_dec': self.opt_stdp_dec.state_dict(),
184 | 'opt_movement_enc': self.opt_movement_enc.state_dict(),
185 | 'opt_initp_enc': self.opt_initp_enc.state_dict(),
186 | 'opt_initp_dec': self.opt_initp_dec.state_dict(),
187 | 'ep': ep,
188 | 'total_it': total_it
189 | }
190 | torch.save(state, filename)
191 | return
192 |
193 | def resume(self, model_dir, train=True):
194 | checkpoint = torch.load(model_dir)
195 | # weight
196 | self.stdp_dec.load_state_dict(checkpoint['stdp_dec'])
197 | self.movement_enc.load_state_dict(checkpoint['movement_enc'])
198 | self.initp_enc.load_state_dict(checkpoint['initp_enc'])
199 | self.initp_dec.load_state_dict(checkpoint['initp_dec'])
200 | # optimizer
201 | if train:
202 | self.opt_stdp_dec.load_state_dict(checkpoint['opt_stdp_dec'])
203 | self.opt_movement_enc.load_state_dict(checkpoint['opt_movement_enc'])
204 | self.opt_initp_enc.load_state_dict(checkpoint['opt_initp_enc'])
205 | self.opt_initp_dec.load_state_dict(checkpoint['opt_initp_dec'])
206 | return checkpoint['ep'], checkpoint['total_it']
207 |
208 | def kld_coef(self, i):
209 | return float(1/(1+np.exp(-0.0005*(i-15000)))) #v3
210 |
211 |
212 | def generate_stdp_sequence(self, initpose, aud, num_stdp):  # expects aud_enc/audstyle_enc to be attached externally; they are not set in __init__
213 | self.initp_enc.cuda()
214 | self.initp_dec.cuda()
215 | self.movement_enc.cuda()
216 | self.stdp_dec.cuda()
217 | self.initp_enc.eval()
218 | self.initp_dec.eval()
219 | self.movement_enc.eval()
220 | self.stdp_dec.eval()
221 | initpose = initpose.cuda()
222 |
223 | aud_style = self.aud_enc.get_style(aud)
224 |
225 | stdp_seq = []
226 | cnt = 0
227 | #for i in range(num_stdp):
228 | while not cnt == num_stdp:
229 | if cnt==0:
230 | z_inits = self.get_z_random(1, 10, 'gauss')
231 | else:
232 | z_init_mus, z_init_logvars = self.initp_enc(initpose)
233 | z_init_stds = z_init_logvars.mul(0.5).exp_()
234 | z_init_epss = self.get_z_random(z_init_stds.size(0), z_init_stds.size(1), 'gauss')
235 | z_inits = z_init_epss.mul(z_init_stds).add_(z_init_mus)
236 |
237 | z_audstyle_mu, z_audstyle_logvar = self.audstyle_enc(aud_style)
238 | z_as_std = z_audstyle_logvar.mul(0.5).exp_()
239 | z_as_eps = self.get_z_random(z_as_std.size(0), z_as_std.size(1), 'gauss')
240 | z_audstyle = z_as_eps.mul(z_as_std).add_(z_audstyle_mu)
241 | if random.randint(0,5)==100:  # never true; the random-style branch below is effectively disabled
242 | z_audstyle = self.get_z_random(z_inits.size(0), 512, 'gauss')
243 |
244 | fake_stdpose = self.stdp_dec(z_inits, z_audstyle)
245 |
246 | s = fake_stdpose[0]
247 | diff = torch.abs(s[1:]-s[:-1])
248 | diffsum = torch.sum(diff)
249 | if diffsum.cpu().detach().numpy() < 70:
250 | continue
251 |
252 | cnt += 1
253 | stdp_seq.append(fake_stdpose.cpu().detach().numpy())
254 | initpose = fake_stdpose[:,-1,:]
255 | return stdp_seq
256 |
257 |
258 | def cuda(self):
259 | self.initp_enc.cuda()
260 | self.initp_dec.cuda()
261 | self.movement_enc.cuda()
262 | self.stdp_dec.cuda()
263 | self.l1_criterion.cuda()
264 |
265 | def train(self, ep=0, it=0):
266 | self.cuda()
267 |
268 | full_kl = self.args.lambda_kl
269 | kl_w = 0
270 | kl_step = 0.05
271 | best_stdp_recon = 100
272 | for epoch in range(ep, self.args.num_epochs):
273 | self.initp_enc.train()
274 | self.initp_dec.train()
275 | self.movement_enc.train()
276 | self.stdp_dec.train()
277 | stdp_recon = 0
278 | for i, (stdpose, stdpose2) in enumerate(self.data_loader):
279 | self.args.lambda_kl = full_kl*self.kld_coef(it)
280 | stdpose, stdpose2 = stdpose.cuda().detach(), stdpose2.cuda().detach()
281 |
282 | self.forward(stdpose, stdpose2)
283 | self.update()
284 | self.logs['l_kl_zinit'] += self.loss_kl_z_init.data
285 | self.logs['l_kl_zmovement'] += self.loss_kl_z_movement.data
286 | self.logs['l_l1_initp'] += self.loss_l1_initp.data
287 | self.logs['l_l1_stdp'] += self.loss_l1_stdp.data
288 | self.logs['l_l1_cross_stdp'] += self.loss_l1_cross_stdp.data
289 | self.logs['l_dist_zmovement'] += self.loss_dist_z_movement.data
290 | self.logs['kld_coef'] += self.args.lambda_kl
291 |
292 | print('Epoch:{:3} Iter{}/{}\tl_l1_initp {:.3f}\tl_l1_stdp {:.3f}\tl_l1_cross_stdp {:.3f}\tl_dist_zmove {:.3f}\tl_kl_zinit {:.3f}\t l_kl_zmove {:.3f}'.format(
293 | epoch, i, len(self.data_loader), self.loss_l1_initp, self.loss_l1_stdp, self.loss_l1_cross_stdp, self.loss_dist_z_movement, self.loss_kl_z_init, self.loss_kl_z_movement))
294 |
295 | it += 1
296 | if it % self.log_interval == 0:
297 | for tag, value in self.logs.items():
298 | self.logger.scalar_summary(tag, value/self.log_interval, it)
299 | self.logs = self.init_logs()
300 | if epoch % self.snapshot_ep == 0:
301 | self.save(os.path.join(self.snapshot_dir, '{:04}.ckpt'.format(epoch)), epoch, it)
302 |
--------------------------------------------------------------------------------
/modulate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/Dancing2Music/License.txt
7 |
8 | import os
9 | import numpy as np
10 | import librosa
11 | import utils
12 |
13 |
14 | def modulate(dance, beats, length):
15 | sec_interframe = 1/15
16 |
17 | beats_frame = np.around(beats)
18 | t_beat = beats_frame.astype(int)
19 | s_beat = np.arange(3,dance.shape[0],8)  # kinematic-beat frames of the generated dance: every 8 frames starting at frame 3
20 | final_pose = np.zeros((length, 14, 2))
21 |
22 | if t_beat[0] >3:
23 | final_pose[t_beat[0]-3:t_beat[0]] = dance[:3]
24 | else:
25 | final_pose[:t_beat[0]] = dance[:t_beat[0]]
26 | if t_beat[0]-3 > 0:
27 | final_pose[:t_beat[0]-3] = dance[0]
28 | for t in range(t_beat.shape[0]-1):
29 | begin = int(t_beat[t])
30 | end = int(t_beat[t+1])
31 | interval = end-begin
32 | if t==s_beat.shape[0]-1:
33 | rest = min(final_pose.shape[0]-begin-1, dance.shape[0]-s_beat[t]-1)
34 | break
35 |     if t+1 < s_beat.shape[0] and s_beat[t+1] < dance.shape[0]:
36 |       # resample the 9-frame segment between on-beat anchors to span this beat interval
37 |       pose = get_pose(dance[s_beat[t]:s_beat[t+1]+1], interval+1)
37 |       if t==0 and begin-s_beat[t] >= 3:
38 | final_pose[begin-s_beat[t]:begin] = dance[:s_beat[t]]
39 | final_pose[begin:end+1]=pose
40 | rest = min(final_pose.shape[0]-end-1, dance.shape[0]-s_beat[t+1]-1)
41 | else:
42 | end = begin
43 | if t+1 < s_beat.shape[0]:
44 | rest = min(final_pose.shape[0]-end-1, dance.shape[0]-s_beat[t+1]-1)
45 | else:
46 | print(t_beat.shape, s_beat.shape, t)
47 | rest = min(final_pose.shape[0]-end-1, dance.shape[0]-s_beat[t]-1)
48 | if rest > 0:
49 | if t+1 < s_beat.shape[0]:
50 | final_pose[end+1:end+1+rest] = dance[s_beat[t+1]+1:s_beat[t+1]+1+rest]
51 | else:
52 | final_pose[end+1:end+1+rest] = dance[s_beat[t]+1:s_beat[t]+1+rest]
53 |
54 | return final_pose
55 |
56 | def get_pose(pose, n):
57 | t_pose = np.zeros((n, 14, 2))
58 | if n==11:
59 | t_pose[0] = pose[0]
60 | t_pose[1] = (pose[0]*1+pose[1]*4)/5
61 | t_pose[2] = (pose[1]*2+pose[2]*3)/5
62 | t_pose[3] = (pose[2]*3+pose[3]*2)/5
63 | t_pose[4] = (pose[3]*4+pose[4]*1)/5
64 | t_pose[5] = pose[4]
65 | t_pose[6] = (pose[4]*1+pose[5]*4)/5
66 | t_pose[7] = (pose[5]*2+pose[6]*3)/5
67 | t_pose[8] = (pose[6]*3+pose[7]*2)/5
68 | t_pose[9] = (pose[7]*4+pose[8]*1)/5
69 | t_pose[10] = pose[8]
70 | elif n==10:
71 | t_pose[0] = pose[0]
72 | t_pose[1] = (pose[0]*1+pose[1]*8)/9
73 | t_pose[2] = (pose[1]*2+pose[2]*7)/9
74 | t_pose[3] = (pose[2]*3+pose[3]*6)/9
75 | t_pose[4] = (pose[3]*4+pose[4]*5)/9
76 | t_pose[5] = (pose[4]*5+pose[5]*4)/9
77 | t_pose[6] = (pose[5]*6+pose[6]*3)/9
78 | t_pose[7] = (pose[6]*7+pose[7]*2)/9
79 | t_pose[8] = (pose[7]*8+pose[8]*1)/9
80 | t_pose[9] = pose[8]
81 | elif n==12:
82 | t_pose[0] = pose[0]
83 | t_pose[1] = (pose[0]*3+pose[1]*8)/11
84 | t_pose[2] = (pose[1]*6+pose[2]*5)/11
85 | t_pose[3] = (pose[2]*9+pose[3]*2)/11
86 | t_pose[4] = (pose[2]*1+pose[3]*10)/11
87 | t_pose[5] = (pose[3]*4+pose[4]*7)/11
88 | t_pose[6] = (pose[4]*7+pose[5]*4)/11
89 | t_pose[7] = (pose[5]*10+pose[6]*1)/11
90 | t_pose[8] = (pose[5]*2+pose[6]*9)/11
91 | t_pose[9] = (pose[6]*5+pose[7]*6)/11
92 | t_pose[10] = (pose[7]*8+pose[8]*3)/11
93 | t_pose[11] = pose[8]
94 | elif n==13:
95 | t_pose[0] = pose[0]
96 | t_pose[1] = (pose[0]*1+pose[1]*2)/3
97 | t_pose[2] = (pose[1]*2+pose[2]*1)/3
98 | t_pose[3] = pose[2]
99 | t_pose[4] = (pose[2]*1+pose[3]*2)/3
100 | t_pose[5] = (pose[3]*2+pose[4]*1)/3
101 | t_pose[6] = pose[4]
102 | t_pose[7] = (pose[4]*1+pose[5]*2)/3
103 | t_pose[8] = (pose[5]*2+pose[6]*1)/3
104 | t_pose[9] = pose[6]
105 | t_pose[10] = (pose[6]*1+pose[7]*2)/3
106 | t_pose[11] = (pose[7]*2+pose[8]*1)/3
107 | t_pose[12] = pose[8]
108 | elif n==14:
109 | t_pose[0] = pose[0]
110 | t_pose[1] = (pose[0]*5+pose[1]*8)/13
111 | t_pose[2] = (pose[1]*10+pose[2]*3)/13
112 | t_pose[3] = (pose[1]*2+pose[2]*11)/13
113 | t_pose[4] = (pose[2]*7+pose[3]*6)/13
114 | t_pose[5] = (pose[3]*12+pose[4]*1)/13
115 | t_pose[6] = (pose[3]*4+pose[4]*9)/13
116 | t_pose[7] = (pose[4]*9+pose[5]*4)/13
117 | t_pose[8] = (pose[4]*12+pose[5]*1)/13
118 | t_pose[9] = (pose[5]*6+pose[6]*7)/13
119 | t_pose[10] = (pose[6]*11+pose[7]*2)/13
120 | t_pose[11] = (pose[6]*3+pose[7]*10)/13
121 | t_pose[12] = (pose[7]*8+pose[8]*5)/13
122 | t_pose[13] = pose[8]
123 | elif n==9:
124 | t_pose = pose
125 | elif n==8:
126 | t_pose[0] = pose[0]
127 | t_pose[1] = (pose[1]*6+pose[2]*1)/7
128 | t_pose[2] = (pose[2]*5+pose[3]*2)/7
129 | t_pose[3] = (pose[3]*4+pose[4]*3)/7
130 | t_pose[4] = (pose[4]*3+pose[5]*4)/7
131 | t_pose[5] = (pose[5]*2+pose[6]*5)/7
132 | t_pose[6] = (pose[6]*1+pose[7]*6)/7
133 | t_pose[7] = pose[8]
134 | elif n==7:
135 | t_pose[0] = pose[0]
136 | t_pose[1] = (pose[1]*2+pose[2]*1)/3
137 | t_pose[2] = (pose[2]*1+pose[3]*2)/3
138 | t_pose[3] = pose[4]
139 | t_pose[4] = (pose[5]*2+pose[6]*1)/3
140 | t_pose[5] = (pose[6]*1+pose[7]*2)/3
141 | t_pose[6] = pose[8]
142 | elif n==6:
143 | t_pose[0] = pose[0]
144 | t_pose[1] = (pose[1]*2+pose[2]*3)/5
145 | t_pose[2] = (pose[3]*4+pose[4]*1)/5
146 | t_pose[3] = (pose[4]*1+pose[5]*4)/5
147 | t_pose[4] = (pose[6]*3+pose[7]*2)/5
148 | t_pose[5] = pose[8]
149 | elif n<6:
150 | t_pose[0] = pose[0]
151 | t_pose[n-1] = pose[8]
152 | for i in range(1,n-1):
153 | t_pose[i] = pose[4]
154 | elif n>14:
155 | t_pose[0] = pose[0]
156 | t_pose[n-1] = pose[8]
157 | for i in range(1, n-1):
158 | k = int(8/(n-1)*i)
159 | t_pose[i] = pose[k]  # nearest source frame; k maps [1, n-2] onto [0, 8]
160 | else:
161 | print('NOT IMPLEMENTED {}'.format(n))
162 |
163 | return t_pose
164 |
--------------------------------------------------------------------------------
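
get_pose() resamples a 9-frame beat-to-beat segment (s_beat steps through the dance every 8 frames) to n frames using fixed linear-interpolation weights, and modulate() stitches the resampled segments onto the audio beat frames t_beat. A quick sanity-check sketch against dummy data, with shapes taken from modulate() (14 joints, 2-D coordinates):

import numpy as np
from modulate import get_pose

segment = np.random.randn(9, 14, 2)   # one beat-to-beat pose segment
stretched = get_pose(segment, 12)     # resample 9 frames -> 12 frames
assert stretched.shape == (12, 14, 2)
assert np.allclose(stretched[0], segment[0])    # endpoints are preserved
assert np.allclose(stretched[-1], segment[-1])
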
/networks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/Dancing2Music/License.txt
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.utils.data
12 | from torch.autograd import Variable
13 |
14 | import numpy as np
15 |
16 | if torch.cuda.is_available():
17 | T = torch.cuda
18 | else:
19 | T = torch
20 |
21 | ###########################################################
22 | ##########
23 | ########## Stage 1: Movement
24 | ##########
25 | ###########################################################
26 | class InitPose_Enc(nn.Module):
27 | def __init__(self, pose_size, dim_z_init):
28 | super(InitPose_Enc, self).__init__()
29 | nf = 64
30 | #nf = 32
31 | self.enc = nn.Sequential(
32 | nn.Linear(pose_size, nf),
33 | nn.LayerNorm(nf),
34 | nn.LeakyReLU(0.2, inplace=True),
35 | nn.Linear(nf, nf),
36 | nn.LayerNorm(nf),
37 | nn.LeakyReLU(0.2, inplace=True),
38 | )
39 | self.mean = nn.Sequential(
40 | nn.Linear(nf,dim_z_init),
41 | )
42 | self.std = nn.Sequential(
43 | nn.Linear(nf,dim_z_init),
44 | )
45 | def forward(self, pose):
46 | enc = self.enc(pose)
47 | return self.mean(enc), self.std(enc)
48 |
49 | class InitPose_Dec(nn.Module):
50 | def __init__(self, pose_size, dim_z_init):
51 | super(InitPose_Dec, self).__init__()
52 | nf = 64
53 | #nf = dim_z_init
54 | self.dec = nn.Sequential(
55 | nn.Linear(dim_z_init, nf),
56 | nn.LayerNorm(nf),
57 | nn.LeakyReLU(0.2, inplace=True),
58 | nn.Linear(nf, nf),
59 | nn.LayerNorm(nf),
60 | nn.LeakyReLU(0.2, inplace=True),
61 | nn.Linear(nf,pose_size),
62 | )
63 | def forward(self, z_init):
64 | return self.dec(z_init)
65 |
66 | class Movement_Enc(nn.Module):
67 | def __init__(self, pose_size, dim_z_movement, length, hidden_size, num_layers, bidirection=False):
68 | super(Movement_Enc, self).__init__()
69 | self.hidden_size = hidden_size
70 | self.bidirection = bidirection
71 | if bidirection:
72 | self.num_dir = 2
73 | else:
74 | self.num_dir = 1
75 | self.recurrent = nn.GRU(pose_size, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=bidirection)
76 | self.init_h = nn.Parameter(torch.randn(num_layers*self.num_dir, 1, hidden_size).type(T.FloatTensor), requires_grad=True)
77 | if bidirection:
78 | self.mean = nn.Sequential(
79 | nn.Linear(hidden_size*2,dim_z_movement),
80 | )
81 | self.std = nn.Sequential(
82 | nn.Linear(hidden_size*2,dim_z_movement),
83 | )
84 | else:
85 | '''
86 | self.enc = nn.Sequential(
87 | nn.Linear(hidden_size, hidden_size//2),
88 | nn.LayerNorm(hidden_size//2),
89 | nn.ReLU(inplace=True),
90 | )
91 | '''
92 | self.mean = nn.Sequential(
93 | nn.Linear(hidden_size,dim_z_movement),
94 | )
95 | self.std = nn.Sequential(
96 | nn.Linear(hidden_size,dim_z_movement),
97 | )
98 | def forward(self, poses):
99 | num_samples = poses.shape[0]
100 | h_t = [self.init_h.repeat(1, num_samples, 1)]
101 | output, hidden = self.recurrent(poses, h_t[0])
102 | if self.bidirection:
103 | output = torch.cat((output[:,-1,:self.hidden_size], output[:,0,self.hidden_size:]), 1)
104 | else:
105 | output = output[:,-1,:]
106 | #enc = self.enc(output)
107 | #return self.mean(enc), self.std(enc)
108 | return self.mean(output), self.std(output)
109 |
110 | def getFeature(self, poses):
111 | num_samples = poses.shape[0]
112 | h_t = [self.init_h.repeat(1, num_samples, 1)]
113 | output, hidden = self.recurrent(poses, h_t[0])
114 | if self.bidirection:
115 | output = torch.cat((output[:,-1,:self.hidden_size], output[:,0,self.hidden_size:]), 1)
116 | else:
117 | output = output[:,-1,:]
118 | return output
119 |
120 | class StandardPose_Dec(nn.Module):
121 | def __init__(self, pose_size, dim_z_init, dim_z_movement, length, hidden_size, num_layers):
122 | super(StandardPose_Dec, self).__init__()
123 | self.length = length
124 | self.pose_size = pose_size
125 | self.hidden_size = hidden_size
126 | self.num_layers = num_layers
127 | #dim_z_init=0
128 | '''
129 | self.z2init = nn.Sequential(
130 | nn.Linear(dim_z_init+dim_z_movement, hidden_size),
131 | nn.LayerNorm(hidden_size),
132 | nn.ReLU(True),
133 | nn.Linear(hidden_size, num_layers*hidden_size)
134 | )
135 | '''
136 | self.z2init = nn.Sequential(
137 | nn.Linear(dim_z_init+dim_z_movement, num_layers*hidden_size)
138 | )
139 | self.recurrent = nn.GRU(dim_z_movement, hidden_size, num_layers=num_layers, batch_first=True)
140 | self.pose_g = nn.Sequential(
141 | nn.Linear(hidden_size, hidden_size),
142 | nn.LayerNorm(hidden_size),
143 | nn.ReLU(True),
144 | nn.Linear(hidden_size, pose_size)
145 | )
146 |
147 | def forward(self, z_init, z_movement):
148 | h_init = self.z2init(torch.cat((z_init, z_movement), 1))
149 | #h_init = self.z2init(z_movement)
150 | h_init = h_init.view(self.num_layers, h_init.size(0), self.hidden_size)
151 | z_movements = z_movement.view(z_movement.size(0),1,z_movement.size(1)).repeat(1, self.length, 1)
152 | z_m_t, _ = self.recurrent(z_movements, h_init)
153 | z_m = z_m_t.contiguous().view(-1, self.hidden_size)
154 | poses = self.pose_g(z_m)
155 | poses = poses.view(z_movement.shape[0], self.length, self.pose_size)
156 | return poses
157 |
158 | class StandardPose_Dis(nn.Module):
159 | def __init__(self, pose_size, length):
160 | super(StandardPose_Dis, self).__init__()
161 | self.pose_size = pose_size
162 | self.length = length
163 | nd = 1024
164 | self.main = nn.Sequential(
165 | nn.Linear(length*pose_size, nd),
166 | nn.LayerNorm(nd),
167 | nn.LeakyReLU(0.2, inplace=True),
168 | nn.Linear(nd,nd//2),
169 | nn.LayerNorm(nd//2),
170 | nn.LeakyReLU(0.2, inplace=True),
171 | nn.Linear(nd//2,nd//4),
172 | nn.LayerNorm(nd//4),
173 | nn.LeakyReLU(0.2, inplace=True),
174 | nn.Linear(nd//4, 1)
175 | )
176 | def forward(self, pose_seq):
177 | pose_seq = pose_seq.view(-1, self.pose_size*self.length)
178 | return self.main(pose_seq).squeeze()
179 |
180 | ###########################################################
181 | ##########
182 | ########## Stage 2: Dance
183 | ##########
184 | ###########################################################
185 | class Dance_Enc(nn.Module):
186 | def __init__(self, dim_z_movement, dim_z_dance, hidden_size, num_layers, bidirection=False):
187 | super(Dance_Enc, self).__init__()
188 | self.hidden_size = hidden_size
189 | self.bidirection = bidirection
190 | if bidirection:
191 | self.num_dir = 2
192 | else:
193 | self.num_dir = 1
194 | self.recurrent = nn.GRU(2*dim_z_movement, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=bidirection)
195 | self.init_h = nn.Parameter(torch.randn(num_layers*self.num_dir, 1, hidden_size).type(T.FloatTensor), requires_grad=True)
196 | if bidirection:
197 | self.mean = nn.Sequential(
198 | nn.Linear(hidden_size*2,dim_z_dance),
199 | )
200 | self.std = nn.Sequential(
201 | nn.Linear(hidden_size*2,dim_z_dance),
202 | )
203 | else:
204 | self.mean = nn.Sequential(
205 | nn.Linear(hidden_size,dim_z_dance),
206 | )
207 | self.std = nn.Sequential(
208 | nn.Linear(hidden_size,dim_z_dance),
209 | )
210 | def forward(self, movements_mean, movements_std):
211 | movements = torch.cat((movements_mean, movements_std),2)
212 | num_samples = movements.shape[0]
213 | h_t = [self.init_h.repeat(1, num_samples, 1)]
214 | output, hidden = self.recurrent(movements, h_t[0])
215 | if self.bidirection:
216 | output = torch.cat((output[:,-1,:self.hidden_size], output[:,0,self.hidden_size:]), 1)
217 | else:
218 | output = output[:,-1,:]
219 | return self.mean(output), self.std(output)
220 |
221 | class Dance_Dec(nn.Module):
222 | def __init__(self, dim_z_dance, dim_z_movement, hidden_size, num_layers):
223 | super(Dance_Dec, self).__init__()
224 | #self.length = length
225 | self.num_layers = num_layers
226 | self.hidden_size = hidden_size
227 | self.dim_z_movement = dim_z_movement
228 | #dim_z_init=0
229 | '''
230 | self.z2init = nn.Sequential(
231 | nn.Linear(dim_z_init+dim_z_movement, hidden_size),
232 | nn.LayerNorm(hidden_size),
233 | nn.ReLU(True),
234 | nn.Linear(hidden_size, num_layers*hidden_size)
235 | )
236 | '''
237 | self.z2init = nn.Sequential(
238 | nn.Linear(dim_z_dance, num_layers*hidden_size)
239 | )
240 | self.recurrent = nn.GRU(dim_z_dance, hidden_size, num_layers=num_layers, batch_first=True)
241 | self.movement_g = nn.Sequential(
242 | nn.Linear(hidden_size, hidden_size),
243 | nn.LayerNorm(hidden_size),
244 | nn.ReLU(True),
245 | #nn.Linear(hidden_size, dim_z_movement)
246 | )
247 | self.mean = nn.Sequential(
248 | nn.Linear(hidden_size,dim_z_movement),
249 | )
250 | self.std = nn.Sequential(
251 | nn.Linear(hidden_size,dim_z_movement),
252 | )
253 |
254 | def forward(self, z_dance, length=3):
255 | h_init = self.z2init(z_dance)
256 | h_init = h_init.view(self.num_layers, h_init.size(0), self.hidden_size)
257 | z_dance = z_dance.view(z_dance.size(0),1,z_dance.size(1)).repeat(1, length, 1)
258 | z_d_t, _ = self.recurrent(z_dance, h_init)
259 | z_d = z_d_t.contiguous().view(-1, self.hidden_size)
260 | z_movement = self.movement_g(z_d)
261 | z_movement_mean, z_movement_std = self.mean(z_movement), self.std(z_movement)
262 | #z_movement = z_movement.view(z_dance.shape[0], length, self.dim_z_movement)
263 | return z_movement_mean, z_movement_std
264 |
265 |
266 | class DanceAud_Dis2(nn.Module):
267 | def __init__(self, aud_size, dim_z_movement, length=3):
268 | super(DanceAud_Dis2, self).__init__()
269 | self.aud_size = aud_size
270 | self.dim_z_movement = dim_z_movement
271 | self.length = length
272 | nd = 1024
273 | self.movementd = nn.Sequential(
274 | nn.Linear(dim_z_movement*2*length, nd),
275 | nn.LayerNorm(nd),
276 | nn.LeakyReLU(0.2, inplace=True),
277 | nn.Linear(nd,nd//2),
278 | nn.LayerNorm(nd//2),
279 | nn.LeakyReLU(0.2, inplace=True),
280 | nn.Linear(nd//2,nd//4),
281 | nn.LayerNorm(nd//4),
282 | nn.LeakyReLU(0.2, inplace=True),
283 | #nn.Linear(nd//4, 30),
284 | nn.Linear(nd//4, 30),
285 | )
286 |
287 | self.audd = nn.Sequential(
288 | nn.Linear(aud_size, 30),
289 | nn.LayerNorm(30),
290 | nn.LeakyReLU(0.2, inplace=True),
291 | nn.Linear(30, 30),
292 | nn.LayerNorm(30),
293 | nn.LeakyReLU(0.2, inplace=True),
294 | )
295 | self.jointd = nn.Sequential(
296 | nn.Linear(60, 1)
297 | )
298 |
299 | def forward(self, movements, aud):
300 | if len(movements.shape) == 3:
301 | movements = movements.view(movements.shape[0], movements.shape[1]*movements.shape[2])
302 | m = self.movementd(movements)
303 | a = self.audd(aud)
304 | ma = torch.cat((m,a),1)
305 |
306 | return self.jointd(ma).squeeze(), None
307 |
308 | class DanceAud_Dis(nn.Module):
309 | def __init__(self, aud_size, dim_z_movement, length=3):
310 | super(DanceAud_Dis, self).__init__()
311 | self.aud_size = aud_size
312 | self.dim_z_movement = dim_z_movement
313 | self.length = length
314 | nd = 1024
315 | self.movementd = nn.Sequential(
316 | #nn.Linear(dim_z_movement*3, nd),
317 | nn.Linear(dim_z_movement*2, nd),
318 | nn.LayerNorm(nd),
319 | nn.LeakyReLU(0.2, inplace=True),
320 | nn.Linear(nd,nd//2),
321 | nn.LayerNorm(nd//2),
322 | nn.LeakyReLU(0.2, inplace=True),
323 | nn.Linear(nd//2,nd//4),
324 | nn.LayerNorm(nd//4),
325 | nn.LeakyReLU(0.2, inplace=True),
326 | #nn.Linear(nd//4, 30),
327 | nn.Linear(nd//4, 30),
328 | )
329 |
330 |
331 | def forward(self, movements, aud):
332 | #movements = movements.view(movements.shape[0], movements.shape[1]*movements.shape[2])
333 | m = self.movementd(movements)
334 | return m.squeeze()
335 | #a = self.audd(aud)
336 | #ma = torch.cat((m,a),1)
337 |
338 | #return self.jointd(ma).squeeze()
339 |
340 | class DanceAud_InfoDis(nn.Module):
341 | def __init__(self, aud_size, dim_z_movement, length):
342 | super(DanceAud_InfoDis, self).__init__()
343 | self.aud_size = aud_size
344 | self.dim_z_movement = dim_z_movement
345 | self.length = length
346 | nd = 1024
347 |
348 | self.movementd = nn.Sequential(
349 | nn.Linear(dim_z_movement*6, nd*2),
350 | nn.LayerNorm(nd*2),
351 | nn.LeakyReLU(0.2, inplace=True),
352 | nn.Linear(nd*2, nd),
353 | nn.LayerNorm(nd),
354 | nn.LeakyReLU(0.2, inplace=True),
355 | nn.Linear(nd,nd//2),
356 | nn.LayerNorm(nd//2),
357 | nn.LeakyReLU(0.2, inplace=True),
358 | nn.Linear(nd//2,nd//4),
359 | nn.LayerNorm(nd//4),
360 | nn.LeakyReLU(0.2, inplace=True),
361 | )
362 |
363 | self.dis = nn.Sequential(
364 | nn.Linear(nd//4, 1)
365 | )
366 | self.reg = nn.Sequential(
367 | nn.Linear(nd//4, aud_size)
368 | )
369 |
370 | def forward(self, movements, aud):
371 | movements = movements.view(movements.shape[0], movements.shape[1]*movements.shape[2])
372 | m = self.movementd(movements)
373 | return self.dis(m).squeeze(), self.reg(m)
374 |
375 | class Dance2Style(nn.Module):
376 | def __init__(self, dim_z_dance, aud_size):
377 | super(Dance2Style, self).__init__()
378 | self.aud_size = aud_size
379 | self.dim_z_dance = dim_z_dance
380 | nd = 512
381 | self.main = nn.Sequential(
382 | nn.Linear(dim_z_dance, nd),
383 | nn.LayerNorm(nd),
384 | nn.LeakyReLU(0.2, inplace=True),
385 | nn.Linear(nd, nd//2),
386 | nn.LayerNorm(nd//2),
387 | nn.LeakyReLU(0.2, inplace=True),
388 | nn.Linear(nd//2, nd//4),
389 | nn.LayerNorm(nd//4),
390 | nn.LeakyReLU(0.2, inplace=True),
391 | nn.Linear(nd//4, aud_size),
392 | )
393 | def forward(self, zdance):
394 | return self.main(zdance)
395 |
396 | ###########################################################
397 | ##########
398 | ########## Audio
399 | ##########
400 | ###########################################################
401 | class AudioClassifier_rnn(nn.Module):
402 | def __init__(self, dim_z_motion, hidden_size, pose_size, cls, num_layers=1, h_init=2):
403 | super(AudioClassifier_rnn, self).__init__()
404 | self.dim_z_motion = dim_z_motion
405 | self.hidden_size = hidden_size
406 | self.pose_size = pose_size
407 | self.h_init = h_init
408 | self.num_layers = num_layers
409 |
410 | self.init_h = nn.Parameter(torch.randn(1, 1, self.hidden_size).type(T.FloatTensor), requires_grad=True)
411 | self.recurrent = nn.GRU(pose_size, hidden_size, num_layers=num_layers, batch_first=True)
412 | self.classifier = nn.Sequential(
413 | #nn.Dropout(p=0.2),
414 | nn.Linear(hidden_size, hidden_size),
415 | nn.ReLU(True),
416 | #nn.Dropout(p=0.2),
417 | nn.Linear(hidden_size, cls)
418 | )
419 | def forward(self, poses):
420 | hidden, _ = self.recurrent(poses, self.init_h.repeat(1, poses.shape[0], 1))
421 | last_hidden = hidden[:,-1,:]
422 | cls = self.classifier(last_hidden)
423 | return cls
424 | def get_style(self, auds):
425 | hidden, _ = self.recurrent(auds, self.init_h.repeat(1, auds.shape[0], 1))
426 | last_hidden = hidden[:,-1,:]
427 | return last_hidden
428 |
429 |
430 | class Audstyle_Enc(nn.Module):
431 | def __init__(self, aud_size, dim_z, dim_noise=30):
432 | super(Audstyle_Enc, self).__init__()
433 | self.dim_noise = dim_noise
434 | nf = 64
435 | #nf = 32
436 | self.enc = nn.Sequential(
437 | nn.Linear(aud_size+dim_noise, nf),
438 | nn.LayerNorm(nf),
439 | nn.LeakyReLU(0.2, inplace=True),
440 | nn.Linear(nf, nf*2),
441 | nn.LayerNorm(nf*2),
442 | nn.LeakyReLU(0.2, inplace=True),
443 | )
444 | self.mean = nn.Sequential(
445 | nn.Linear(nf*2,dim_z),
446 | )
447 | self.std = nn.Sequential(
448 | nn.Linear(nf*2,dim_z),
449 | )
450 | def forward(self, aud):
451 | noise = torch.randn(aud.shape[0], self.dim_noise).cuda()
452 | y = torch.cat((aud, noise), 1)
453 | enc = self.enc(y)
454 | return self.mean(enc), self.std(enc)
455 |
--------------------------------------------------------------------------------
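
The encoder/decoder pairs above all follow the same two-head mean/std pattern. A hedged shape-check sketch for the movement stage, plugging in the defaults from options.py (pose_size=28, dim_z_init=10, dim_z_movement=512, stdp_length=32); note that Movement_Enc's init_h is created on the GPU whenever one is available, so inputs must live on the same device:

import torch
from networks import Movement_Enc, StandardPose_Dec

device = 'cuda' if torch.cuda.is_available() else 'cpu'
enc = Movement_Enc(pose_size=28, dim_z_movement=512, length=32,
                   hidden_size=1024, num_layers=1, bidirection=True).to(device)
dec = StandardPose_Dec(pose_size=28, dim_z_init=10, dim_z_movement=512,
                       length=32, hidden_size=1024, num_layers=1).to(device)

poses = torch.randn(4, 32, 28, device=device)   # (batch, frames, pose_size)
z_mean, z_std = enc(poses)                      # each head returns (4, 512)
recon = dec(torch.randn(4, 10, device=device), z_mean)
assert recon.shape == (4, 32, 28)
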
/options.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/Dancing2Music/License.txt
7 |
8 | import argparse
9 |
10 |
11 | class DecompOptions():
12 | def __init__(self):
13 | parser = argparse.ArgumentParser()
14 |
15 | parser.add_argument('--name', default=None)
16 |
17 | parser.add_argument('--log_interval', type=int, default=50)
18 | parser.add_argument('--log_dir', default='./logs')
19 | parser.add_argument('--snapshot_ep', type=int, default=1)
20 | parser.add_argument('--snapshot_dir', default='./snapshot')
21 | parser.add_argument('--data_dir', default='./data')
22 |
23 | # Model architecture
24 | parser.add_argument('--pose_size', type=int, default=28)
25 | parser.add_argument('--dim_z_init', type=int, default=10)
26 | parser.add_argument('--dim_z_movement', type=int, default=512)
27 | parser.add_argument('--stdp_length', type=int, default=32)
28 | parser.add_argument('--movement_enc_bidirection', type=int, default=1)
29 | parser.add_argument('--movement_enc_hidden_size', type=int, default=1024)
30 | parser.add_argument('--stdp_dec_hidden_size', type=int, default=1024)
31 | parser.add_argument('--movement_enc_num_layers', type=int, default=1)
32 | parser.add_argument('--stdp_dec_num_layers', type=int, default=1)
33 | # Training
34 | parser.add_argument('--lr', type=float, default=2e-4)
35 | parser.add_argument('--batch_size', type=int, default=256)
36 | parser.add_argument('--num_epochs', type=int, default=1000)
37 | parser.add_argument('--latent_dropout', type=float, default=0.3)
38 | parser.add_argument('--lambda_kl', type=float, default=0.01)
39 | parser.add_argument('--lambda_initp_recon', type=float, default=1)
40 | parser.add_argument('--lambda_initp_consistency', type=float, default=1)
41 | parser.add_argument('--lambda_stdp_recon', type=float, default=1)
42 | parser.add_argument('--lambda_dist_z_movement', type=float, default=1)
43 | # Others
44 | parser.add_argument('--num_workers', type=int, default=4)
45 | parser.add_argument('--resume', default=None)
46 | parser.add_argument('--dataset', type=int, default=0)
47 | parser.add_argument('--tolerance', action='store_true')
48 |
49 | self.parser = parser
50 |
51 | def parse(self):
52 | self.opt = self.parser.parse_args()
53 | args = vars(self.opt)
54 | return self.opt
55 |
56 | class CompOptions():
57 | def __init__(self):
58 | parser = argparse.ArgumentParser()
59 |
60 | parser.add_argument('--name', default=None)
61 |
62 | parser.add_argument('--log_interval', type=int, default=50)
63 | parser.add_argument('--log_dir', default='./logs')
64 | parser.add_argument('--snapshot_ep', type=int, default=1)
65 | parser.add_argument('--snapshot_dir', default='./snapshot')
66 | parser.add_argument('--data_dir', default='./data')
67 | # Network architecture
68 | parser.add_argument('--pose_size', type=int, default=28)
69 | parser.add_argument('--aud_style_size', type=int, default=30)
70 | parser.add_argument('--dim_z_init', type=int, default=10)
71 | parser.add_argument('--dim_z_movement', type=int, default=512)
72 | parser.add_argument('--dim_z_dance', type=int, default=512)
73 | parser.add_argument('--stdp_length', type=int, default=32)
74 | parser.add_argument('--movement_enc_bidirection', type=int, default=1)
75 | parser.add_argument('--movement_enc_hidden_size', type=int, default=1024)
76 | parser.add_argument('--stdp_dec_hidden_size', type=int, default=1024)
77 | parser.add_argument('--movement_enc_num_layers', type=int, default=1)
78 | parser.add_argument('--stdp_dec_num_layers', type=int, default=1)
79 | parser.add_argument('--dance_enc_bidirection', type=int, default=0)
80 | parser.add_argument('--dance_enc_hidden_size', type=int, default=1024)
81 | parser.add_argument('--dance_enc_num_layers', type=int, default=1)
82 | parser.add_argument('--dance_dec_hidden_size', type=int, default=1024)
83 | parser.add_argument('--dance_dec_num_layers', type=int, default=1)
84 | # Training
85 | parser.add_argument('--lr', type=float, default=2e-4)
86 | parser.add_argument('--batch_size', type=int, default=256)
87 | parser.add_argument('--num_epochs', type=int, default=1500)
88 | parser.add_argument('--latent_dropout', type=float, default=0.3)
89 | parser.add_argument('--lambda_kl', type=float, default=0.01)
90 | parser.add_argument('--lambda_kl_dance', type=float, default=0.01)
91 | parser.add_argument('--lambda_gan', type=float, default=1)
92 | parser.add_argument('--lambda_zmovements_recon', type=float, default=1)
93 | parser.add_argument('--lambda_stdpSeq_recon', type=float, default=10)
94 | parser.add_argument('--lambda_dist_z_movement', type=float, default=1)
95 | # Other
96 | parser.add_argument('--num_workers', type=int, default=4)
97 | parser.add_argument('--decomp_snapshot', required=True)
98 | parser.add_argument('--neta_snapshot', default='./data/stats/aud_3cls.ckpt')
99 | parser.add_argument('--resume', default=None)
100 | parser.add_argument('--dataset', type=int, default=2)
101 | self.parser = parser
102 |
103 | def parse(self):
104 | self.opt = self.parser.parse_args()
105 | args = vars(self.opt)
106 | return self.opt
107 |
108 | class TestOptions():
109 | def __init__(self):
110 | parser = argparse.ArgumentParser()
111 |
112 | parser.add_argument('--name', default=None)
113 |
114 | parser.add_argument('--log_interval', type=int, default=50)
115 | parser.add_argument('--log_dir', default='./logs')
116 | parser.add_argument('--snapshot_ep', type=int, default=1)
117 | parser.add_argument('--snapshot_dir', default='./snapshot')
118 | parser.add_argument('--data_dir', default='./data')
119 | # Network architecture
120 | parser.add_argument('--pose_size', type=int, default=28)
121 | parser.add_argument('--aud_style_size', type=int, default=30)
122 | parser.add_argument('--dim_z_init', type=int, default=10)
123 | parser.add_argument('--dim_z_movement', type=int, default=512)
124 | parser.add_argument('--dim_z_dance', type=int, default=512)
125 | parser.add_argument('--stdp_length', type=int, default=32)
126 | parser.add_argument('--movement_enc_bidirection', type=int, default=1)
127 | parser.add_argument('--movement_enc_hidden_size', type=int, default=1024)
128 | parser.add_argument('--stdp_dec_hidden_size', type=int, default=1024)
129 | parser.add_argument('--movement_enc_num_layers', type=int, default=1)
130 | parser.add_argument('--stdp_dec_num_layers', type=int, default=1)
131 | parser.add_argument('--dance_enc_bidirection', type=int, default=0)
132 | parser.add_argument('--dance_enc_hidden_size', type=int, default=1024)
133 | parser.add_argument('--dance_enc_num_layers', type=int, default=1)
134 | parser.add_argument('--dance_dec_hidden_size', type=int, default=1024)
135 | parser.add_argument('--dance_dec_num_layers', type=int, default=1)
136 | # Training
137 | parser.add_argument('--lr', type=float, default=2e-4)
138 | parser.add_argument('--batch_size', type=int, default=256)
139 | parser.add_argument('--num_epochs', type=int, default=1500)
140 | parser.add_argument('--latent_dropout', type=float, default=0.3)
141 | parser.add_argument('--lambda_kl', type=float, default=0.01)
142 | parser.add_argument('--lambda_kl_dance', type=float, default=0.01)
143 | parser.add_argument('--lambda_gan', type=float, default=1)
144 | parser.add_argument('--lambda_zmovements_recon', type=float, default=1)
145 | parser.add_argument('--lambda_stdpSeq_recon', type=float, default=10)
146 | parser.add_argument('--lambda_dist_z_movement', type=float, default=1)
147 | # Other
148 | parser.add_argument('--num_workers', type=int, default=4)
149 | parser.add_argument('--decomp_snapshot', required=True)
150 | parser.add_argument('--comp_snapshot', required=True)
151 | parser.add_argument('--neta_snapshot', default='./data/stats/aud_3cls.ckpt')
152 | parser.add_argument('--dataset', type=int, default=2)
153 | parser.add_argument('--thr', type=int, default=50)
154 | parser.add_argument('--aud_path', type=str, required=True)
155 | parser.add_argument('--modulate', action='store_true')
156 | parser.add_argument('--out_file', type=str, default='demo/out.mp4')
157 | parser.add_argument('--out_dir', type=str, default='demo/out_frame')
158 | self.parser = parser
159 |
160 | def parse(self):
161 | self.opt = self.parser.parse_args()
162 | args = vars(self.opt)
163 | return self.opt
164 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/Dancing2Music/License.txt
7 |
8 | import os
9 | import argparse
10 | import functools
11 |
12 | import torch
13 | from torch.utils.data import DataLoader
14 | from torchvision import transforms
15 |
16 | from model_comp import *
17 | from networks import *
18 | from options import CompOptions
19 | from data import get_loader
20 |
21 |
22 | def loadDecompModel(args):
23 |   initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
24 |   stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,
25 |                               hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)
26 |   movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,
27 |                               hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))
28 |   checkpoint = torch.load(args.decomp_snapshot)
29 |   initp_enc.load_state_dict(checkpoint['initp_enc'])
30 |   stdp_dec.load_state_dict(checkpoint['stdp_dec'])
31 |   movement_enc.load_state_dict(checkpoint['movement_enc'])
32 |   return initp_enc, stdp_dec, movement_enc
33 |
34 | def loadCompModel(args):
35 |   dance_enc = Dance_Enc(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
36 |                         hidden_size=args.dance_enc_hidden_size, num_layers=args.dance_enc_num_layers, bidirection=(args.dance_enc_bidirection==1))
37 |   dance_dec = Dance_Dec(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
38 |                         hidden_size=args.dance_dec_hidden_size, num_layers=args.dance_dec_num_layers)
39 |   audstyle_enc = Audstyle_Enc(aud_size=args.aud_style_size, dim_z=args.dim_z_dance)
40 |   dance_reg = Dance2Style(aud_size=args.aud_style_size, dim_z_dance=args.dim_z_dance)
41 |   danceAud_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_movement, length=3)
42 |   zdance_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_dance, length=1)
43 |   checkpoint = torch.load(args.resume)
44 |   dance_enc.load_state_dict(checkpoint['dance_enc'])
45 |   dance_dec.load_state_dict(checkpoint['dance_dec'])
46 |   audstyle_enc.load_state_dict(checkpoint['audstyle_enc'])
47 |   dance_reg.load_state_dict(checkpoint['dance_reg'])
48 |   danceAud_dis.load_state_dict(checkpoint['danceAud_dis'])
49 |   zdance_dis.load_state_dict(checkpoint['zdance_dis'])
50 |
51 |   checkpoint2 = torch.load(args.neta_snapshot)
52 |   neta_cls = AudioClassifier_rnn(10,30,28,cls=3)
53 |   neta_cls.load_state_dict(checkpoint2)
54 |
55 |   return dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls
56 |
57 |
58 | if __name__ == "__main__":
59 |   parser = CompOptions()
60 |   args = parser.parse()
61 |   #### Pretrained networks from Decomp
62 |   initp_enc, stdp_dec, movement_enc = loadDecompModel(args)
63 |
64 |   #### Comp networks
65 |   dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls = loadCompModel(args)
66 |
67 |   mean_pose=np.load('../onbeat/all_onbeat_mean.npy')
68 |   std_pose=np.load('../onbeat/all_onbeat_std.npy')
69 |   mean_aud=np.load('../onbeat/all_aud_mean.npy')
70 |   std_aud=np.load('../onbeat/all_aud_std.npy')
--------------------------------------------------------------------------------
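
test.py ends by loading per-dimension pose and audio statistics that the excerpt shown never consumes; presumably they standardize audio inputs and de-standardize generated poses downstream. A sketch of that convention (variable names follow test.py; the normalization direction is an assumption, not confirmed by the code shown):

aud = (raw_aud - mean_aud) / std_aud        # standardize audio features
poses = gen_poses * std_pose + mean_pose    # map network output back to pose coordinates
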
/train_comp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/Dancing2Music/License.txt
7 |
8 | import os
9 | import argparse
10 | import functools
11 |
12 | import torch
13 | from torch.utils.data import DataLoader
14 | from torchvision import transforms
15 |
16 | from model_comp import *
17 | from networks import *
18 | from options import CompOptions
19 | from data import get_loader
20 |
21 | def loadDecompModel(args):
22 | initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
23 | stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,
24 | hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)
25 | movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,
26 | hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))
27 | checkpoint = torch.load(args.decomp_snapshot)
28 | initp_enc.load_state_dict(checkpoint['initp_enc'])
29 | stdp_dec.load_state_dict(checkpoint['stdp_dec'])
30 | movement_enc.load_state_dict(checkpoint['movement_enc'])
31 | return initp_enc, stdp_dec, movement_enc
32 |
33 | def getCompNetworks(args):
34 | dance_enc = Dance_Enc(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
35 | hidden_size=args.dance_enc_hidden_size, num_layers=args.dance_enc_num_layers, bidirection=(args.dance_enc_bidirection==1))
36 | dance_dec = Dance_Dec(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
37 | hidden_size=args.dance_dec_hidden_size, num_layers=args.dance_dec_num_layers)
38 | audstyle_enc = Audstyle_Enc(aud_size=args.aud_style_size, dim_z=args.dim_z_dance)
39 | dance_reg = Dance2Style(aud_size=args.aud_style_size, dim_z_dance=args.dim_z_dance)
40 | danceAud_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_movement, length=3)
41 | zdance_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_dance, length=1)
42 |
43 | checkpoint2 = torch.load(args.neta_snapshot)
44 | neta_cls = AudioClassifier_rnn(10,30,28,cls=3)
45 | neta_cls.load_state_dict(checkpoint2)
46 |
47 | return dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls
48 |
49 | if __name__ == "__main__":
50 | parser = CompOptions()
51 | args = parser.parse()
52 |
53 | args.train = True
54 |
55 | if args.name is None:
56 | args.name = 'Comp'
57 |
58 | args.log_dir = os.path.join(args.log_dir, args.name)
59 | if not os.path.exists(args.log_dir):
60 | os.mkdir(args.log_dir)
61 | args.snapshot_dir = os.path.join(args.snapshot_dir, args.name)
62 | if not os.path.exists(args.snapshot_dir):
63 | os.mkdir(args.snapshot_dir)
64 |
65 | data_loader = get_loader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, dataset=args.dataset, data_dir=args.data_dir)
66 |
67 | #### Pretrain network from Decomp
68 | initp_enc, stdp_dec, movement_enc = loadDecompModel(args)
69 |
70 | #### Comp network
71 | dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls = getCompNetworks(args)
72 |
73 |
74 | trainer = Trainer_Comp(data_loader,
75 | movement_enc = movement_enc,
76 | initp_enc = initp_enc,
77 | stdp_dec = stdp_dec,
78 | dance_enc = dance_enc,
79 | dance_dec = dance_dec,
80 | danceAud_dis = danceAud_dis,
81 | zdance_dis = zdance_dis,
82 | aud_enc=neta_cls,
83 | audstyle_enc=audstyle_enc,
84 | dance_reg=dance_reg,
85 | args = args
86 | )
87 |
88 |   if args.resume is not None:
89 | ep, it = trainer.resume(args.resume, True)
90 | else:
91 | ep, it = 0, 0
92 | trainer.train(ep, it)
93 |
94 |
--------------------------------------------------------------------------------
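
A hedged invocation sketch for the composition stage; the checkpoint path is hypothetical, and --decomp_snapshot is required by CompOptions (the audio classifier checkpoint falls back to the --neta_snapshot default):

python train_comp.py --name comp_run --data_dir ./data \
    --decomp_snapshot ./snapshot/Decomp/0100.ckpt
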
/train_decomp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/Dancing2Music/License.txt
7 |
8 | import os
9 | import argparse
10 | import functools
11 |
12 | import torch
13 | from torch.utils.data import DataLoader
14 | from torchvision import transforms
15 |
16 | from model_decomp import *
17 | from networks import *
18 | from options import DecompOptions
19 | from data import get_loader
20 |
21 | def getDecompNetworks(args):
22 | initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
23 | initp_dec = InitPose_Dec(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
24 | movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,
25 | hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))
26 | stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,
27 | hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)
28 | return initp_enc, initp_dec, movement_enc, stdp_dec
29 |
30 | if __name__ == "__main__":
31 | parser = DecompOptions()
32 | args = parser.parse()
33 |
34 | args.train = True
35 |
36 | if args.name is None:
37 | args.name = 'Decomp'
38 |
39 | args.log_dir = os.path.join(args.log_dir, args.name)
40 | if not os.path.exists(args.log_dir):
41 | os.mkdir(args.log_dir)
42 | args.snapshot_dir = os.path.join(args.snapshot_dir, args.name)
43 | if not os.path.exists(args.snapshot_dir):
44 | os.mkdir(args.snapshot_dir)
45 |
46 | data_loader = get_loader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, dataset=args.dataset, data_dir=args.data_dir, tolerance=args.tolerance)
47 |
48 | initp_enc, initp_dec, movement_enc, stdp_dec = getDecompNetworks(args)
49 |
50 | trainer = Trainer_Decomp(data_loader,
51 | initp_enc = initp_enc,
52 | initp_dec = initp_dec,
53 | movement_enc = movement_enc,
54 | stdp_dec = stdp_dec,
55 | args = args
56 | )
57 |
58 |   if args.resume is not None:
59 | ep, it = trainer.resume(args.resume, False)
60 | else:
61 | ep, it = 0, 0
62 |
63 | trainer.train(ep=ep, it=it)
64 |
65 |
--------------------------------------------------------------------------------
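
The decomposition stage needs no pretrained checkpoint, so a minimal run only names the experiment and points at the data directory (values illustrative; all flags are defined in DecompOptions):

python train_decomp.py --name decomp_run --data_dir ./data --batch_size 256
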
/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/Dancing2Music/License.txt
7 |
8 | import numpy as np
9 | import pickle
10 | import cv2
11 | import math
12 | import os
13 | import random
14 | import tensorflow as tf
15 |
16 | class Logger(object):
17 | def __init__(self, log_dir):
18 | self.writer = tf.summary.FileWriter(log_dir)  # TF1.x API (tf.compat.v1.summary.FileWriter under TF2)
19 |
20 | def scalar_summary(self, tag, value, step):
21 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
22 | self.writer.add_summary(summary, step)
23 |
24 | def vis(poses, outdir, aud=None):
25 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
26 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
27 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
28 |
29 | # find connection in the specified sequence, center 29 is in the position 15
30 | limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \
31 | [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \
32 | [1,16], [16,18], [3,17], [6,18]]
33 |
34 | neglect = [14,15,16,17]
35 |
36 | for t in range(poses.shape[0]):
37 | #break
38 | canvas = np.ones((256,500,3), np.uint8)*255
39 |
40 | thisPeak = poses[t]
41 | for i in range(18):
42 | if i in neglect:
43 | continue
44 | if thisPeak[i,0] == -1:
45 | continue
46 | cv2.circle(canvas, tuple(thisPeak[i,0:2].astype(int)), 4, colors[i], thickness=-1)
47 |
48 | for i in range(17):
49 | limbid = np.array(limbSeq[i])-1
50 | if limbid[0] in neglect or limbid[1] in neglect:
51 | continue
52 | X = thisPeak[[limbid[0],limbid[1]], 1]
53 | Y = thisPeak[[limbid[0],limbid[1]], 0]
54 | if X[0] == -1 or Y[0]==-1 or X[1]==-1 or Y[1]==-1:
55 | continue
56 | stickwidth = 4
57 | cur_canvas = canvas.copy()
58 | mX = np.mean(X)
59 | mY = np.mean(Y)
60 | #print(X, Y, limbid)
61 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
62 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
63 | polygon = cv2.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1)
64 | #print(i, n, int(mY), int(mX), limbid, X, Y)
65 | cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
66 | canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
67 | if aud is not None:
68 | if aud[:,t] == 1:
69 | cv2.circle(canvas, (30, 30), 20, (0,0,255), -1)
70 | #canvas = cv2.copyMakeBorder(canvas,10,10,10,10,cv2.BORDER_CONSTANT,value=[255,0,0])
71 | cv2.imwrite(os.path.join(outdir, 'frame{0:03d}.png'.format(t)),canvas)
72 |
73 | def vis2(poses, outdir, fibeat):
74 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
75 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
76 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
77 |
78 | # find connection in the specified sequence, center 29 is in the position 15
79 | limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \
80 | [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \
81 | [1,16], [16,18], [3,17], [6,18]]
82 |
83 |
84 | neglect = [14,15,16,17]
85 |
86 | ibeat = cv2.imread(fibeat)
87 | ibeat = cv2.resize(ibeat, (500,200))
88 |
89 | for t in range(poses.shape[0]):
90 | subibeat = ibeat.copy()
91 | canvas = np.ones((256+200,500,3), np.uint8)*255
92 | canvas[256:,:,:] = subibeat
93 |
94 | overlay = canvas.copy()
95 | cv2.rectangle(overlay, (int(500/poses.shape[0]*(t+1)),256),(500,256+200), (100,100,100), -1)
96 | cv2.addWeighted(overlay, 0.4, canvas, 1-0.4, 0, canvas)
97 | thisPeak = poses[t]
98 | for i in range(18):
99 | if i in neglect:
100 | continue
101 | if thisPeak[i,0] == -1:
102 | continue
103 | cv2.circle(canvas, tuple(thisPeak[i,0:2].astype(int)), 4, colors[i], thickness=-1)
104 |
105 | for i in range(17):
106 | limbid = np.array(limbSeq[i])-1
107 | if limbid[0] in neglect or limbid[1] in neglect:
108 | continue
109 | X = thisPeak[[limbid[0],limbid[1]], 1]
110 | Y = thisPeak[[limbid[0],limbid[1]], 0]
111 | if X[0] == -1 or Y[0]==-1 or X[1]==-1 or Y[1]==-1:
112 | continue
113 | stickwidth = 4
114 | cur_canvas = canvas.copy()
115 | mX = np.mean(X)
116 | mY = np.mean(Y)
117 | #print(X, Y, limbid)
118 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
119 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
120 | polygon = cv2.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1)
121 | #print(i, n, int(mY), int(mX), limbid, X, Y)
122 | cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
123 | canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
124 | cv2.imwrite(os.path.join(outdir, 'frame{0:03d}.png'.format(t)),canvas)
125 |
126 | def vis_single(pose, outfile):
127 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
128 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
129 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
130 |
131 | # find connection in the specified sequence, center 29 is in the position 15
132 | limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \
133 | [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \
134 | [1,16], [16,18], [3,17], [6,18]]
135 |
136 | neglect = [14,15,16,17]
137 |
138 | for t in range(1):
139 | #break
140 | canvas = np.ones((256,500,3), np.uint8)*255
141 |
142 | thisPeak = pose
143 | for i in range(18):
144 | if i in neglect:
145 | continue
146 | if thisPeak[i,0] == -1:
147 | continue
148 | cv2.circle(canvas, tuple(thisPeak[i,0:2].astype(int)), 4, colors[i], thickness=-1)
149 |
150 | for i in range(17):
151 | limbid = np.array(limbSeq[i])-1
152 | if limbid[0] in neglect or limbid[1] in neglect:
153 | continue
154 | X = thisPeak[[limbid[0],limbid[1]], 1]
155 | Y = thisPeak[[limbid[0],limbid[1]], 0]
156 | if X[0] == -1 or Y[0]==-1 or X[1]==-1 or Y[1]==-1:
157 | continue
158 | stickwidth = 4
159 | cur_canvas = canvas.copy()
160 | mX = np.mean(X)
161 | mY = np.mean(Y)
162 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
163 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
164 | polygon = cv2.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1)
165 | cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
166 | canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
167 | cv2.imwrite(outfile,canvas)
168 |
--------------------------------------------------------------------------------
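
vis() draws each frame of an 18-joint sequence in OpenPose ordering (joints 14-17 skipped) onto a 256x500 canvas and writes frame000.png, frame001.png, ... to outdir. A dummy-data sketch, keeping coordinates inside the canvas (output directory illustrative; note that importing utils also pulls in TensorFlow for the Logger class):

import os
import numpy as np
from utils import vis

poses = np.random.rand(5, 18, 2)              # 5 frames, 18 joints, (x, y)
poses[:, :, 0] = poses[:, :, 0] * 400 + 50    # x within the 500-px width
poses[:, :, 1] = poses[:, :, 1] * 200 + 28    # y within the 256-px height
os.makedirs('demo/out_frame', exist_ok=True)
vis(poses, 'demo/out_frame')                  # writes frame000.png ... frame004.png
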