├── License.txt
├── README.md
├── data.py
├── demo.py
├── imgs
│   ├── .DS_Store
│   ├── example.gif
│   ├── long.gif
│   ├── multimodal.gif
│   └── v2v.gif
├── model_comp.py
├── model_decomp.py
├── modulate.py
├── networks.py
├── options.py
├── test.py
├── train_comp.py
├── train_decomp.py
└── utils.py
/License.txt: -------------------------------------------------------------------------------- 1 | Nvidia Source Code License-NC 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | 7 | “Software” means the original work of authorship made available under this License. 8 | “Work” means the Software and any additions to or derivative works of the Software that are made available under this License. 9 | 10 | “Nvidia Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or provided by Nvidia or its affiliates. 11 | 12 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 13 | 14 | Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 15 | 16 | 2. License Grants 17 | 18 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 19 | 20 | 3. Limitations 21 | 22 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 23 | 24 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 25 | 26 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. The Work or derivative works thereof may be used or intended for use by Nvidia or its affiliates commercially or non-commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 27 | 28 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grants in Sections 2.1 and 2.2) will terminate immediately. 29 | 30 | 3.5 Trademarks.
This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License. 31 | 32 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grants in Sections 2.1 and 2.2) will terminate immediately. 33 | 34 | 4. Disclaimer of Warranty. 35 | 36 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 37 | 38 | 5. Limitation of Liability. 39 | 40 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 41 | 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Python 2.7](https://img.shields.io/badge/python-2.7-green.svg) 2 | ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg) 3 | ## Dancing to Music 4 | PyTorch implementation of the cross-modality generative model that synthesizes dance from music. 5 | 6 | 7 | ### Paper 8 | [Hsin-Ying Lee](http://vllab.ucmerced.edu/hylee/), [Xiaodong Yang](https://xiaodongyang.org/), [Ming-Yu Liu](http://mingyuliu.net/), [Ting-Chun Wang](https://tcwang0509.github.io/), [Yu-Ding Lu](https://jonlu0602.github.io/), [Ming-Hsuan Yang](https://faculty.ucmerced.edu/mhyang/), [Jan Kautz](http://jankautz.com/) 9 | Dancing to Music 10 | Neural Information Processing Systems (**NeurIPS**) 2019 11 | [[Paper]](https://arxiv.org/abs/1911.02001) [[YouTube]](https://youtu.be/-e9USqfwZ4A) [[Project]](http://vllab.ucmerced.edu/hylee/Dancing2Music/script.txt) [[Blog]](https://news.developer.nvidia.com/nvidia-dance-to-music-neurips/) [[Supp]](http://xiaodongyang.org/publications/papers/dance2music-supp-neurips19.pdf) 12 | 13 | ### Example Videos 14 | - Beat-Matching 15 | 1st row: generated dance sequences, 2nd row: music beats, 3rd row: kinematic beats (see the beat-extraction sketch below) 16 |

17 | ![Beat-Matching](imgs/example.gif) 18 |

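The music beats in the 2nd row are obtained with librosa's beat tracker. Below is a minimal sketch of the extraction that `demo.py` performs (the audio path is a placeholder):

```python
import librosa
import numpy as np

# Track beats from a median-aggregated onset-strength envelope, as in demo.py.
y, sr = librosa.load('demo/demo.wav')
onset_env = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median)
tempo, beats = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr)

# Convert beat frames to seconds, then to 15-fps pose-frame indices,
# the time base of the generated dance sequences.
beat_frames = np.round(librosa.frames_to_time(beats, sr=sr) * 15)
```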
19 | 20 | - Multimodality 21 | Generate various dance sequences with the same music and the same initial pose (see the sampling sketch below). 22 |

23 | ![Multimodality](imgs/multimodal.gif) 24 |

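Multimodality comes from re-sampling the latent noise on every forward pass. A minimal sketch, assuming a `Trainer_Comp` instance `trainer` plus the normalized initial pose `initpose`, audio features `aud`, and threshold `thr` prepared as in `demo.py`:

```python
# Each call draws fresh Gaussian noise for the dance- and movement-level
# latents, so the same music and initial pose yield different dances.
samples = []
while len(samples) < 5:
    stdpSeq = trainer.test_final(initpose, aud, 3, thr)  # 3 movement units per sample
    if stdpSeq is not None:  # test_final returns None for near-static samples below thr
        samples.append(stdpSeq)
```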
25 | 26 | - Long-Term Generation 27 | Seamlessly generate a dance sequence of arbitrary length (see the chaining sketch below). 28 |

29 | 30 | ![Long-Term Generation](imgs/long.gif) 31 | 32 |

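Long sequences are produced by chaining fixed-length blocks: the last pose of each generated block seeds the initial pose of the next, as in the generation loop of `demo.py`. A minimal sketch under the same assumptions as above:

```python
import torch

# Each block holds 3 movement units of 32 frames; seeding every block with
# the final pose of the previous one keeps the motion continuous.
chunks = []
for t in range(total_t):  # total_t blocks cover the whole audio clip
    stdpSeq = None
    while stdpSeq is None:  # re-sample blocks rejected by the motion threshold
        stdpSeq = trainer.test_final(initpose, aud, 3, thr)
    chunks.append(stdpSeq)
    initpose = torch.Tensor(stdpSeq[2, -1]).cuda().view(1, -1)  # last pose of unit 3
```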
33 | 34 | - Photo-Realistic Videos 35 | Map generated dance sequences to photo-realistic videos. 36 |

37 | ![Photo-Realistic Videos](imgs/v2v.gif) 38 |

39 | 40 | 41 | ## Train Decomposition 42 | ``` 43 | python train_decomp.py --name Decomp 44 | ``` 45 | 46 | ## Train Composition 47 | ``` 48 | python train_comp.py --name Decomp --decomp_snapshot DECOMP_SNAPSHOT 49 | ``` 50 | 51 | ## Demo 52 | ``` 53 | python demo.py --decomp_snapshot DECOMP_SNAPSHOT --comp_snapshot COMP_SNAPSHOT --aud_path AUD_PATH --out_file OUT_FILE --out_dir OUT_DIR --thr THR 54 | ``` 55 | - Flags 56 | - `aud_path`: input .wav file 57 | - `out_file`: path of the output .mp4 file 58 | - `out_dir`: directory for the output frames 59 | - `thr`: motion-magnitude threshold; generated movement units whose total frame-to-frame motion falls below it are rejected and re-sampled 60 | - `modulate`: whether to apply beat warping, which aligns the kinematic beats of the generated dance with the music beats 61 | 62 | - Example 63 | ``` 64 | python demo.py --decomp_snapshot snapshot/Stage1.ckpt --comp_snapshot snapshot/Stage2.ckpt --aud_path demo/demo.wav --out_file demo/out.mp4 --out_dir demo/out_frame 65 | ``` 66 | 67 | 68 | ### Citation 69 | If you find this code useful for your research, please cite our paper: 70 | ```bibtex 71 | @inproceedings{lee2019dancing2music, 72 | title={Dancing to Music}, 73 | author={Lee, Hsin-Ying and Yang, Xiaodong and Liu, Ming-Yu and Wang, Ting-Chun and Lu, Yu-Ding and Yang, Ming-Hsuan and Kautz, Jan}, 74 | booktitle={NeurIPS}, 75 | year={2019} 76 | } 77 | ``` 78 | 79 | ### License 80 | Copyright (C) 2020 NVIDIA Corporation. All rights reserved. This work is made available under the NVIDIA Source Code License (1-Way Commercial). To view a copy of this license, visit https://nvlabs.github.io/Dancing2Music/LICENSE.txt. 81 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/Dancing2Music/License.txt 7 | import os 8 | import pickle 9 | import numpy as np 10 | import random 11 | import torch.utils.data 12 | from torchvision.datasets import ImageFolder 13 | import utils 14 | 15 | 16 | class PoseDataset(torch.utils.data.Dataset): 17 | def __init__(self, data_dir, tolerance=False): 18 | self.data_dir = data_dir 19 | z_fname = '{}/unitList/zumba_unit.txt'.format(data_dir) 20 | b_fname = '{}/unitList/ballet_unit.txt'.format(data_dir) 21 | h_fname = '{}/unitList/hiphop_unit.txt'.format(data_dir) 22 | self.z_data = [] 23 | self.b_data = [] 24 | self.h_data = [] 25 | with open(z_fname, 'r') as f: 26 | for line in f: 27 | self.z_data.append([s for s in line.strip().split(' ')]) 28 | with open(b_fname, 'r') as f: 29 | for line in f: 30 | self.b_data.append([s for s in line.strip().split(' ')]) 31 | with open(h_fname, 'r') as f: 32 | for line in f: 33 | self.h_data.append([s for s in line.strip().split(' ')]) 34 | self.data = [self.z_data, self.b_data, self.h_data] 35 | 36 | self.tolerance = tolerance 37 | if self.tolerance: 38 | z3_fname = '{}/unitList/zumba_unitseq3.txt'.format(data_dir) 39 | b3_fname = '{}/unitList/ballet_unitseq3.txt'.format(data_dir) 40 | h3_fname = '{}/unitList/hiphop_unitseq3.txt'.format(data_dir) 41 | z4_fname = '{}/unitList/zumba_unitseq4.txt'.format(data_dir) 42 | b4_fname = '{}/unitList/ballet_unitseq4.txt'.format(data_dir) 43 | h4_fname = '{}/unitList/hiphop_unitseq4.txt'.format(data_dir) 44 | z3_data = []; b3_data = []; h3_data = []; z4_data = []; b4_data = []; h4_data = [] 45 | with open(z3_fname, 'r') as f: 46 | for line in f: 47 | z3_data.append([s for s in line.strip().split(' ')]) 48 | with open(b3_fname, 'r') as f: 49 | for line in f: 50 | b3_data.append([s for s in line.strip().split(' ')]) 51 | with open(h3_fname, 'r') as f: 52 | for line in f: 53 | h3_data.append([s for s in line.strip().split(' ')]) 54 | with open(z4_fname, 'r') as f: 55 | for line in f: 56 | z4_data.append([s for s in line.strip().split(' ')]) 57 | with open(b4_fname, 'r') as f: 58 | for line in f: 59 | b4_data.append([s for s in line.strip().split(' ')]) 60 | with open(h4_fname, 'r') as f: 61 | for line in f: 62 | h4_data.append([s for s in line.strip().split(' ')]) 63 | self.zt_data = z3_data + z4_data 64 | self.bt_data = b3_data + b4_data 65 | self.ht_data = h3_data + h4_data 66 | self.t_data = [self.zt_data, self.bt_data, self.ht_data] 67 | 68 | self.mean_pose=np.load(data_dir+'/stats/all_onbeat_mean.npy') 69 | self.std_pose=np.load(data_dir+'/stats/all_onbeat_std.npy') 70 | 71 | def __getitem__(self, index): 72 | cls = random.randint(0,2) 73 | cls = random.randint(0,1) 74 | if self.tolerance and random.randint(0,9)==0: 75 | index = random.randint(0, len(self.t_data[cls])-1) 76 | path = self.t_data[cls][index][0] 77 | path = os.path.join(self.data_dir, path[5:]) 78 | orig_poses = np.load(path) 79 | sel = random.randint(0, orig_poses.shape[0]-1) 80 | orig_poses = orig_poses[sel] 81 | else: 82 | index = random.randint(0, len(self.data[cls])-1) 83 | path = self.data[cls][index][0] 84 | path = os.path.join(self.data_dir, path[5:]) 85 | orig_poses = np.load(path) 86 | 87 | xjit = np.random.uniform(low=-50, high=50) 88 | yjit = np.random.uniform(low=-20, high=20) 89 | poses = orig_poses.copy() 90 | poses[:,:,0] += xjit 91 | poses[:,:,1] += yjit 92 | xjit = np.random.uniform(low=-50, high=50) 93 | yjit = np.random.uniform(low=-20, high=20) 94 | poses2 = orig_poses.copy() 95 | poses2[:,:,0] += 
xjit 96 | poses2[:,:,1] += yjit 97 | 98 | poses = poses.reshape(poses.shape[0], poses.shape[1]*poses.shape[2]) 99 | poses2 = poses2.reshape(poses2.shape[0], poses2.shape[1]*poses2.shape[2]) 100 | for i in range(poses.shape[0]): 101 | poses[i] = (poses[i]-self.mean_pose)/self.std_pose 102 | poses2[i] = (poses2[i]-self.mean_pose)/self.std_pose 103 | 104 | return torch.Tensor(poses), torch.Tensor(poses2) 105 | 106 | def __len__(self): 107 | return len(self.z_data)+len(self.b_data) 108 | 109 | 110 | class MovementAudDataset(torch.utils.data.Dataset): 111 | def __init__(self, data_dir): 112 | self.data_dir = data_dir 113 | z3_fname = '{}/unitList/zumba_unitseq3.txt'.format(data_dir) 114 | b3_fname = '{}/unitList/ballet_unitseq3.txt'.format(data_dir) 115 | h3_fname = '{}/unitList/hiphop_unitseq3.txt'.format(data_dir) 116 | z4_fname = '{}/unitList/zumba_unitseq4.txt'.format(data_dir) 117 | b4_fname = '{}/unitList/ballet_unitseq4.txt'.format(data_dir) 118 | h4_fname = '{}/unitList/hiphop_unitseq4.txt'.format(data_dir) 119 | self.z3_data = [] 120 | self.b3_data = [] 121 | self.h3_data = [] 122 | self.z4_data = [] 123 | self.b4_data = [] 124 | self.h4_data = [] 125 | with open(z3_fname, 'r') as f: 126 | for line in f: 127 | self.z3_data.append([s for s in line.strip().split(' ')]) 128 | with open(b3_fname, 'r') as f: 129 | for line in f: 130 | self.b3_data.append([s for s in line.strip().split(' ')]) 131 | with open(h3_fname, 'r') as f: 132 | for line in f: 133 | self.h3_data.append([s for s in line.strip().split(' ')]) 134 | with open(z4_fname, 'r') as f: 135 | for line in f: 136 | self.z4_data.append([s for s in line.strip().split(' ')]) 137 | with open(b4_fname, 'r') as f: 138 | for line in f: 139 | self.b4_data.append([s for s in line.strip().split(' ')]) 140 | with open(h4_fname, 'r') as f: 141 | for line in f: 142 | self.h4_data.append([s for s in line.strip().split(' ')]) 143 | self.data_3 = [self.z3_data, self.b3_data, self.h3_data] 144 | self.data_4 = [self.z4_data, self.b4_data, self.h4_data] 145 | 146 | z_data_root = 'zumba/' 147 | b_data_root = 'ballet/' 148 | h_data_root = 'hiphop/' 149 | self.data_root = [z_data_root, b_data_root, h_data_root ] 150 | self.mean_pose=np.load(data_dir+'/stats/all_onbeat_mean.npy') 151 | self.std_pose=np.load(data_dir+'/stats/all_onbeat_std.npy') 152 | self.mean_aud=np.load(data_dir+'/stats/all_aud_mean.npy') 153 | self.std_aud=np.load(data_dir+'/stats/all_aud_std.npy') 154 | 155 | def __getitem__(self, index): 156 | cls = random.randint(0,2) 157 | cls = random.randint(0,1) 158 | isthree = random.randint(0,1) 159 | 160 | if isthree == 0: 161 | index = random.randint(0, len(self.data_4[cls])-1) 162 | path = self.data_4[cls][index][0] 163 | else: 164 | index = random.randint(0, len(self.data_3[cls])-1) 165 | path = self.data_3[cls][index][0] 166 | path = os.path.join(self.data_dir, path[5:]) 167 | stdpSeq = np.load(path) 168 | vid, cid = path.split('/')[-4], path.split('/')[-3] 169 | #vid, cid = vid_cid[:11], vid_cid[12:] 170 | aud = np.load('{}/{}/{}/{}/aud/c{}_fps15.npy'.format(self.data_dir, self.data_root[cls], vid, cid, cid)) 171 | 172 | stdpSeq = stdpSeq.reshape(stdpSeq.shape[0], stdpSeq.shape[1], stdpSeq.shape[2]*stdpSeq.shape[3]) 173 | for i in range(stdpSeq.shape[0]): 174 | for j in range(stdpSeq.shape[1]): 175 | stdpSeq[i,j] = (stdpSeq[i,j]-self.mean_pose)/self.std_pose 176 | if isthree == 0: 177 | start = random.randint(0,1) 178 | stdpSeq = stdpSeq[start:start+3] 179 | 180 | for i in range(aud.shape[0]): 181 | aud[i] = 
(aud[i]-self.mean_aud)/self.std_aud 182 | aud = aud[:30] 183 | return torch.Tensor(stdpSeq), torch.Tensor(aud) 184 | 185 | def __len__(self): 186 | return len(self.z3_data)+len(self.b3_data)+len(self.z4_data)+len(self.b4_data)+len(self.h3_data)+len(self.h4_data) 187 | 188 | def get_loader(batch_size, shuffle, num_workers, dataset, data_dir, tolerance=False): 189 | if dataset == 0: 190 | a2d = PoseDataset(data_dir, tolerance) 191 | elif dataset == 2: 192 | a2d = MovementAudDataset(data_dir) 193 | data_loader = torch.utils.data.DataLoader(dataset=a2d, 194 | batch_size=batch_size, 195 | shuffle=shuffle, 196 | num_workers=num_workers, 197 | ) 198 | return data_loader 199 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/Dancing2Music/License.txt 7 | import os 8 | import argparse 9 | import functools 10 | import librosa 11 | import shutil 12 | import sys 13 | sys.path.insert(0, 'preprocess') 14 | import preprocess as p 15 | import subprocess as sp 16 | from shutil import copyfile 17 | 18 | import torch 19 | from torch.utils.data import DataLoader 20 | from torchvision import transforms 21 | 22 | from model_comp import * 23 | from networks import * 24 | from options import TestOptions 25 | import modulate 26 | import utils 27 | 28 | def loadDecompModel(args): 29 | initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init) 30 | stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length, 31 | hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers) 32 | movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length, 33 | hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1)) 34 | checkpoint = torch.load(args.decomp_snapshot) 35 | initp_enc.load_state_dict(checkpoint['initp_enc']) 36 | stdp_dec.load_state_dict(checkpoint['stdp_dec']) 37 | movement_enc.load_state_dict(checkpoint['movement_enc']) 38 | return initp_enc, stdp_dec, movement_enc 39 | 40 | def loadCompModel(args): 41 | dance_enc = Dance_Enc(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement, 42 | hidden_size=args.dance_enc_hidden_size, num_layers=args.dance_enc_num_layers, bidirection=(args.dance_enc_bidirection==1)) 43 | dance_dec = Dance_Dec(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement, 44 | hidden_size=args.dance_dec_hidden_size, num_layers=args.dance_dec_num_layers) 45 | audstyle_enc = Audstyle_Enc(aud_size=args.aud_style_size, dim_z=args.dim_z_dance) 46 | dance_reg = Dance2Style(aud_size=args.aud_style_size, dim_z_dance=args.dim_z_dance) 47 | danceAud_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_movement, length=3) 48 | zdance_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_dance, length=1) 49 | checkpoint = torch.load(args.comp_snapshot) 50 | dance_enc.load_state_dict(checkpoint['dance_enc']) 51 | dance_dec.load_state_dict(checkpoint['dance_dec']) 52 | audstyle_enc.load_state_dict(checkpoint['audstyle_enc']) 53 | 54 | checkpoint2 = torch.load(args.neta_snapshot) 55 | 
neta_cls = AudioClassifier_rnn(10,30,28,cls=3) 56 | neta_cls.load_state_dict(checkpoint2) 57 | 58 | return dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls 59 | 60 | if __name__ == "__main__": 61 | parser = TestOptions() 62 | args = parser.parse() 63 | args.train = False 64 | 65 | thr = args.thr 66 | 67 | # Process music and get feature 68 | infile = args.aud_path 69 | outfile = 'style.npy' 70 | p.preprocess(infile, outfile) 71 | 72 | y, sr = librosa.load(infile) 73 | onset_env = librosa.onset.onset_strength(y, sr=sr, aggregate=np.median) 74 | times = librosa.frames_to_time(np.arange(len(onset_env)),sr=sr, hop_length=512) 75 | tempo, beats = librosa.beat.beat_track(onset_envelope=onset_env,sr=sr) 76 | np.save('beats.npy', times[beats]) 77 | beats = np.round(librosa.frames_to_time(beats, sr=sr)*15) 78 | 79 | # beats now holds 15-fps frame indices, the time base modulate.modulate() expects 80 | aud = np.load('style.npy') 81 | os.remove('beats.npy') 82 | os.remove('style.npy') 83 | shutil.rmtree('normalized') 84 | 85 | #### Pretrain network from Decomp 86 | initp_enc, stdp_dec, movement_enc = loadDecompModel(args) 87 | 88 | #### Comp network 89 | dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls = loadCompModel(args) 90 | 91 | trainer = Trainer_Comp(data_loader=None, 92 | movement_enc = movement_enc, 93 | initp_enc = initp_enc, 94 | stdp_dec = stdp_dec, 95 | dance_enc = dance_enc, 96 | dance_dec = dance_dec, 97 | danceAud_dis = danceAud_dis, 98 | zdance_dis = zdance_dis, 99 | aud_enc=neta_cls, 100 | audstyle_enc=audstyle_enc, 101 | dance_reg=dance_reg, 102 | args = args 103 | ) 104 | 105 | print('Loading Done') 106 | 107 | mean_pose=np.load('{}/stats/all_onbeat_mean.npy'.format(args.data_dir)) 108 | std_pose=np.load('{}/stats/all_onbeat_std.npy'.format(args.data_dir)) 109 | mean_aud=np.load('{}/stats/all_aud_mean.npy'.format(args.data_dir)) 110 | std_aud=np.load('{}/stats/all_aud_std.npy'.format(args.data_dir)) 111 | 112 | 113 | length = aud.shape[0] 114 | 115 | initpose = np.zeros((14, 2)) 116 | initpose = initpose.reshape(-1) 117 | #initpose = (initpose-mean_pose)/std_pose 118 | 119 | for j in range(aud.shape[0]): 120 | aud[j] = (aud[j]-mean_aud)/std_aud 121 | 122 | total_t = int(length/32+1) 123 | final_stdpSeq = np.zeros((total_t*3*32, 14, 2)) 124 | initpose, aud = torch.Tensor(initpose).cuda(), torch.Tensor(aud).cuda() 125 | initpose, aud = initpose.view(1, initpose.shape[0]), aud.view(1, aud.shape[0], aud.shape[1]) 126 | for t in range(total_t): 127 | print('process {}/{}'.format(t, total_t)) 128 | # re-sample until the generated units pass the motion-magnitude threshold 129 | while True: 130 | fake_stdpSeq = trainer.test_final(initpose, aud, 3, thr) 131 | if fake_stdpSeq is not None: 132 | break 133 | initpose = fake_stdpSeq[2,-1] 134 | initpose = torch.Tensor(initpose).cuda() 135 | initpose = initpose.view(1,-1) 136 | fake_stdpSeq = fake_stdpSeq.squeeze() 137 | for j in range(fake_stdpSeq.shape[0]): 138 | for k in range(fake_stdpSeq.shape[1]): 139 | fake_stdpSeq[j,k] = fake_stdpSeq[j,k]*std_pose + mean_pose 140 | fake_stdpSeq = np.resize(fake_stdpSeq, (fake_stdpSeq.shape[0],32, 14, 2)) 141 | for j in range(3): 142 | final_stdpSeq[96*t+32*j:96*t+32*(j+1)] = fake_stdpSeq[j] 143 | 144 | if args.modulate: 145 | final_stdpSeq = modulate.modulate(final_stdpSeq, beats, length) 146 | 147 | out_dir = args.out_dir 148 | if not os.path.exists(out_dir): 149 | os.mkdir(out_dir) 150 | utils.vis(final_stdpSeq, out_dir) 151 | sp.call('ffmpeg -r 15 -i {}/frame%03d.png -i {} -c:v libx264 -pix_fmt yuv420p
-crf 23 -r 30 -y -strict -2 {}'.format(out_dir, args.aud_path, args.out_file), shell=True) 152 | 153 | -------------------------------------------------------------------------------- /imgs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Dancing2Music/7ff1d95f9f3d3585e29ee7e4ca5a3a45e29db6de/imgs/.DS_Store -------------------------------------------------------------------------------- /imgs/example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Dancing2Music/7ff1d95f9f3d3585e29ee7e4ca5a3a45e29db6de/imgs/example.gif -------------------------------------------------------------------------------- /imgs/long.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Dancing2Music/7ff1d95f9f3d3585e29ee7e4ca5a3a45e29db6de/imgs/long.gif -------------------------------------------------------------------------------- /imgs/multimodal.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Dancing2Music/7ff1d95f9f3d3585e29ee7e4ca5a3a45e29db6de/imgs/multimodal.gif -------------------------------------------------------------------------------- /imgs/v2v.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Dancing2Music/7ff1d95f9f3d3585e29ee7e4ca5a3a45e29db6de/imgs/v2v.gif -------------------------------------------------------------------------------- /model_comp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/Dancing2Music/License.txt 7 | import os 8 | import time 9 | import numpy as np 10 | import random 11 | import math 12 | 13 | import torch 14 | from torch import nn 15 | from torch.autograd import Variable 16 | import torch.optim as optim 17 | from torch.nn.utils import clip_grad_norm_ 18 | 19 | from utils import Logger 20 | 21 | if torch.cuda.is_available(): 22 | T = torch.cuda 23 | else: 24 | T = torch 25 | 26 | class Trainer_Comp(object): 27 | def __init__(self, data_loader, dance_enc, dance_dec, danceAud_dis, movement_enc, initp_enc, stdp_dec, aud_enc, audstyle_enc, dance_reg=None, args=None, zdance_dis=None): 28 | self.data_loader = data_loader 29 | self.movement_enc = movement_enc 30 | self.initp_enc = initp_enc 31 | self.stdp_dec = stdp_dec 32 | self.dance_enc = dance_enc 33 | self.dance_dec = dance_dec 34 | self.danceAud_dis = danceAud_dis 35 | self.aud_enc = aud_enc 36 | self.audstyle_enc = audstyle_enc 37 | self.train = args.train 38 | self.args = args 39 | 40 | if args.train: 41 | self.zdance_dis = zdance_dis 42 | self.dance_reg = dance_reg 43 | 44 | self.logger = Logger(args.log_dir) 45 | self.logs = self.init_logs() 46 | self.log_interval = args.log_interval 47 | self.snapshot_ep = args.snapshot_ep 48 | self.snapshot_dir = args.snapshot_dir 49 | 50 | self.opt_dance_enc = torch.optim.Adam(self.dance_enc.parameters(), lr=args.lr) 51 | self.opt_dance_dec = torch.optim.Adam(self.dance_dec.parameters(), lr=args.lr) 52 | self.opt_danceAud_dis = torch.optim.Adam(self.danceAud_dis.parameters(), lr=args.lr) 53 | self.opt_audstyle_enc = torch.optim.Adam(self.audstyle_enc.parameters(), lr=args.lr) 54 | self.opt_zdance_dis = torch.optim.Adam(self.zdance_dis.parameters(), lr=args.lr) 55 | self.opt_dance_reg = torch.optim.Adam(self.dance_reg.parameters(), lr=args.lr) 56 | 57 | self.opt_stdp_dec = torch.optim.Adam(self.stdp_dec.parameters(), lr=args.lr*0.1) 58 | self.opt_movement_enc = torch.optim.Adam(self.movement_enc.parameters(), lr=args.lr*0.1) 59 | 60 | self.latent_dropout = nn.Dropout(p=args.latent_dropout) 61 | self.l1_criterion = torch.nn.L1Loss() 62 | self.gan_criterion = nn.BCEWithLogitsLoss() 63 | self.mse_criterion = nn.MSELoss().cuda() 64 | 65 | def init_logs(self): 66 | return {'l_kl_zdance':0, 'l_kl_zmovement':0, 'l_kl_fake_zdance':0, 'l_kl_fake_zmovement':0, 67 | 'l_l1_zmovement_mu':0, 'l_l1_zmovement_logvar':0, 'l_l1_stdpSeq':0, 'l_l1_zdance':0, 68 | 'l_dis':0, 'l_dis_true':0, 'l_dis_fake':0, 69 | 'l_info':0, 'l_info_real':0, 'l_info_fake':0, 70 | 'l_gen':0 71 | } 72 | 73 | def get_z_random(self, batchSize, nz, random_type='gauss'): 74 | z = torch.randn(batchSize, nz).cuda() 75 | return z 76 | 77 | @staticmethod 78 | def ones_like(tensor, val=1.): 79 | return T.FloatTensor(tensor.size()).fill_(val) 80 | 81 | @staticmethod 82 | def zeros_like(tensor, val=0.): 83 | return T.FloatTensor(tensor.size()).fill_(val) 84 | def kld_coef(self, i): 85 | return float(1/(1+np.exp(-0.0005*(i-15000)))) 86 | 87 | 88 | def forward(self, stdpSeq, batchsize, aud_style, aud): 89 | self.aud = torch.mean(aud, dim=1) 90 | 91 | self.batchsize = batchsize 92 | self.stdpSeq = stdpSeq 93 | self.aud_style = aud_style 94 | ### stdpSeq -> z_inits, z_movements 95 | self.pose_0 = stdpSeq[:,0,:] 96 | self.z_init_mu, self.z_init_logvar = self.initp_enc(self.pose_0) 97 | z_init_std = self.z_init_logvar.mul(0.5).exp_() 98 | z_init_eps = self.get_z_random(z_init_std.size(0), z_init_std.size(1), 'gauss') 99 | self.z_init = 
z_init_eps.mul(z_init_std).add_(self.z_init_mu) 100 | 101 | self.z_movement_mus, self.z_movement_logvars = self.movement_enc(stdpSeq) 102 | z_movement_stds = self.z_movement_logvars.mul(0.5).exp_() 103 | z_movement_epss = self.get_z_random(z_movement_stds.size(0), z_movement_stds.size(1), 'gauss') 104 | self.z_movements = z_movement_epss.mul(z_movement_stds).add_(self.z_movement_mus) 105 | self.z_movementSeq_mu = self.z_movement_mus.view(batchsize, -1, self.z_movements.shape[1]) 106 | self.z_movementSeq_logvar = self.z_movement_logvars.view(batchsize, -1, self.z_movements.shape[1]) 107 | 108 | self.z_init, self.z_movements = self.z_init.detach(), self.z_movements.detach() 109 | self.z_movement_mus, self.z_movement_logvars = self.z_movement_mus.detach(), self.z_movement_logvars.detach() 110 | 111 | ### z_movements -> z_dance 112 | self.z_dance_mu, self.z_dance_logvar = self.dance_enc(self.z_movementSeq_mu, self.z_movementSeq_logvar) 113 | z_dance_std = self.z_dance_logvar.mul(0.5).exp_() 114 | z_dance_eps = self.get_z_random(z_dance_std.size(0), z_dance_std.size(1), 'gauss') 115 | self.z_dance = z_dance_eps.mul(z_dance_std).add_(self.z_dance_mu) 116 | ### z_dance -> z_movements 117 | self.recon_z_movements_mu, self.recon_z_movements_logvar = self.dance_dec(self.z_dance) 118 | recon_z_movement_std = self.recon_z_movements_logvar.mul(0.5).exp_() 119 | recon_z_movement_eps = self.get_z_random(recon_z_movement_std.size(0), recon_z_movement_std.size(1), 'gauss') 120 | self.recon_z_movements = recon_z_movement_eps.mul(recon_z_movement_std).add_(self.recon_z_movements_mu) 121 | ### z_movements -> stdpSeq 122 | self.recon_stdpSeq = self.stdp_dec(self.z_init, self.recon_z_movements) 123 | 124 | ### Music to z_dance to z_movements 125 | self.fake_z_dance_mu, self.fake_z_dance_logvar = self.audstyle_enc(aud_style) 126 | fake_z_dance_std = self.fake_z_dance_logvar.mul(0.5).exp_() 127 | fake_z_dance_eps = self.get_z_random(fake_z_dance_std.size(0), fake_z_dance_std.size(1), 'gauss') 128 | self.fake_z_dance = fake_z_dance_eps.mul(fake_z_dance_std).add_(self.fake_z_dance_mu) 129 | self.fake_z_movements_mu, self.fake_z_movements_logvar = self.dance_dec(self.fake_z_dance) 130 | fake_z_movements_std = self.fake_z_movements_logvar.mul(0.5).exp_() 131 | fake_z_movements_eps = self.get_z_random(fake_z_movements_std.size(0), fake_z_movements_std.size(1), 'gauss') 132 | self.fake_z_movements = fake_z_movements_eps.mul(fake_z_movements_std).add_(self.fake_z_movements_mu) 133 | 134 | fake_z_movementSeq_mu = self.fake_z_movements_mu.view(batchsize, -1, self.fake_z_movements_mu.shape[1]) 135 | fake_z_movementSeq_logvar = self.fake_z_movements_logvar.view(batchsize, -1, self.fake_z_movements_logvar.shape[1]) 136 | self.fake_z_movementSeq = torch.cat((fake_z_movementSeq_mu, fake_z_movementSeq_logvar),2) 137 | 138 | def backward_D(self): 139 | #real_movements = torch.cat((self.z_movementSeq_mu, self.z_movementSeq_logvar),2) 140 | tmp_recon_mu = self.recon_z_movements_mu.view(self.batchsize, -1, self.z_movements.shape[1]) 141 | tmp_recon_logvar = self.recon_z_movements_logvar.view(self.batchsize, -1, self.z_movements.shape[1]) 142 | real_movements = torch.cat((tmp_recon_mu, tmp_recon_logvar),2) 143 | fake_movements = self.fake_z_movementSeq 144 | 145 | real_labels,_ = self.danceAud_dis(real_movements.detach(), self.aud) 146 | fake_labels,_ = self.danceAud_dis(fake_movements.detach(), self.aud) 147 | 148 | ones = self.ones_like(real_labels) 149 | zeros = self.zeros_like(fake_labels) 150 | 151 | self.loss_dis_true = 
self.gan_criterion(real_labels, ones) 152 | self.loss_dis_fake = self.gan_criterion(fake_labels, zeros) 153 | self.loss_dis = (self.loss_dis_true + self.loss_dis_fake)*self.args.lambda_gan 154 | 155 | real_dance = torch.cat((self.z_dance_mu, self.z_dance_logvar), 1) 156 | fake_dance = torch.cat((self.fake_z_dance_mu, self.fake_z_dance_logvar), 1) 157 | real_labels, _ = self.zdance_dis(real_dance.detach(), self.aud) 158 | fake_labels, _ = self.zdance_dis(fake_dance.detach(), self.aud) 159 | ones = self.ones_like(real_labels) 160 | zeros = self.zeros_like(fake_labels) 161 | 162 | self.loss_zdis_true = self.gan_criterion(real_labels, ones) 163 | self.loss_zdis_fake = self.gan_criterion(fake_labels, zeros) 164 | self.loss_dis += (self.loss_zdis_true + self.loss_zdis_fake)*self.args.lambda_gan 165 | 166 | 167 | def backward_danceED(self): 168 | # z_dance KL 169 | kl_element = self.z_dance_mu.pow(2).add_(self.z_dance_logvar.exp()).mul_(-1).add_(1).add_(self.z_dance_logvar) 170 | self.loss_kl_z_dance = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl_dance)) 171 | kl_element = self.fake_z_dance_mu.pow(2).add_(self.fake_z_dance_logvar.exp()).mul_(-1).add_(1).add_(self.fake_z_dance_logvar) 172 | self.loss_kl_fake_z_dance = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl_dance)) 173 | # z_movement KL 174 | kl_element = self.recon_z_movements_mu.pow(2).add_(self.recon_z_movements_logvar.exp()).mul_(-1).add_(1).add_(self.recon_z_movements_logvar) 175 | self.loss_kl_z_movement = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl)) 176 | kl_element = self.fake_z_movements_mu.pow(2).add_(self.fake_z_movements_logvar.exp()).mul_(-1).add_(1).add_(self.fake_z_movements_logvar) 177 | self.loss_kl_fake_z_movements = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl)) 178 | # z_movement reconstruction 179 | self.loss_l1_z_movement_mu = self.l1_criterion(self.recon_z_movements_mu, self.z_movement_mus) * self.args.lambda_zmovements_recon 180 | self.loss_l1_z_movement_logvar = self.l1_criterion(self.recon_z_movements_logvar, self.z_movement_logvars) * self.args.lambda_zmovements_recon 181 | 182 | # stdp reconstruction 183 | self.loss_l1_stdpSeq = self.l1_criterion(self.recon_stdpSeq, self.stdpSeq) * self.args.lambda_stdpSeq_recon 184 | 185 | # Music2Dance GAN 186 | fake_movements = self.fake_z_movementSeq 187 | fake_labels, _ = self.danceAud_dis(fake_movements, self.aud) 188 | 189 | ones = self.ones_like(fake_labels) 190 | self.loss_gen = self.gan_criterion(fake_labels, ones) * self.args.lambda_gan 191 | 192 | fake_dance = torch.cat((self.fake_z_dance_mu, self.fake_z_dance_logvar), 1) 193 | fake_labels, _ = self.zdance_dis(fake_dance, self.aud) 194 | ones = self.ones_like(fake_labels) 195 | self.loss_gen += self.gan_criterion(fake_labels, ones) * self.args.lambda_gan 196 | 197 | self.loss = self.loss_kl_z_movement + self.loss_kl_z_dance + self.loss_l1_z_movement_mu + self.loss_l1_z_movement_logvar + self.loss_l1_stdpSeq + self.loss_gen 198 | 199 | def backward_info_ondance(self): 200 | real_pred = self.dance_reg(self.z_dance) 201 | fake_pred = self.dance_reg(self.fake_z_dance) 202 | self.loss_info_real = self.mse_criterion(real_pred, self.aud_style) 203 | self.loss_info_fake = self.mse_criterion(fake_pred, self.aud_style) 204 | self.loss_info = self.loss_info_real + self.loss_info_fake 205 | 206 | def zero_grad(self, opt_list): 207 | for opt in opt_list: 208 | opt.zero_grad() 209 | 210 | def clip_norm(self, 
network_list): 211 | for network in network_list: 212 | clip_grad_norm_(network.parameters(), 0.5) 213 | 214 | def step(self, opt_list): 215 | for opt in opt_list: 216 | opt.step() 217 | 218 | def update(self): 219 | self.zero_grad([self.opt_danceAud_dis, self.opt_zdance_dis]) 220 | self.backward_D() 221 | self.loss_dis.backward(retain_graph=True) 222 | self.clip_norm([self.danceAud_dis, self.zdance_dis]) 223 | self.step([self.opt_danceAud_dis, self.opt_zdance_dis]) 224 | 225 | self.zero_grad([self.opt_dance_enc, self.opt_dance_dec, self.opt_audstyle_enc, self.opt_stdp_dec]) 226 | self.backward_danceED() 227 | self.loss.backward(retain_graph=True) 228 | self.clip_norm([self.dance_enc, self.dance_dec, self.audstyle_enc, self.stdp_dec]) 229 | self.step([self.opt_dance_enc, self.opt_dance_dec, self.opt_audstyle_enc, self.opt_stdp_dec]) 230 | 231 | self.zero_grad([self.opt_dance_enc, self.opt_audstyle_enc, self.opt_dance_reg, self.opt_stdp_dec]) 232 | self.backward_info_ondance() 233 | self.loss_info.backward() 234 | self.clip_norm([self.dance_enc, self.audstyle_enc, self.dance_reg, self.stdp_dec]) 235 | self.step([self.opt_dance_enc, self.opt_audstyle_enc, self.opt_dance_reg, self.opt_stdp_dec]) 236 | 237 | def test_final(self, initpose, aud, n, thr=0): 238 | self.cuda() 239 | self.movement_enc.eval() 240 | self.stdp_dec.eval() 241 | self.initp_enc.eval() 242 | self.dance_enc.eval() 243 | self.dance_dec.eval() 244 | self.aud_enc.eval() 245 | self.audstyle_enc.eval() 246 | aud_style = self.aud_enc.get_style(aud).detach() 247 | 248 | self.fake_z_dance_mu, self.fake_z_dance_logvar = self.audstyle_enc(aud_style) 249 | fake_z_dance_std = self.fake_z_dance_logvar.mul(0.5).exp_() 250 | fake_z_dance_eps = self.get_z_random(fake_z_dance_std.size(0), fake_z_dance_std.size(1), 'gauss') 251 | self.fake_z_dance = fake_z_dance_eps.mul(fake_z_dance_std).add_(self.fake_z_dance_mu) 252 | 253 | self.fake_z_movements_mu, self.fake_z_movements_logvar = self.dance_dec(self.fake_z_dance, length=3) 254 | fake_z_movements_std = self.fake_z_movements_logvar.mul(0.5).exp_() 255 | fake_z_movements_eps = self.get_z_random(fake_z_movements_std.size(0), fake_z_movements_std.size(1), 'gauss') 256 | self.fake_z_movements = fake_z_movements_eps.mul(fake_z_movements_std).add_(self.fake_z_movements_mu) 257 | 258 | fake_stdpSeq=[] 259 | for i in range(n): 260 | z_init_mus, z_init_logvars = self.initp_enc(initpose) 261 | z_init_stds = z_init_logvars.mul(0.5).exp_() 262 | z_init_epss = self.get_z_random(z_init_stds.size(0), z_init_stds.size(1), 'gauss') 263 | z_init = z_init_epss.mul(z_init_stds).add_(z_init_mus) 264 | fake_stdp = self.stdp_dec(z_init, self.fake_z_movements[i:i+1]) 265 | fake_stdpSeq.append(fake_stdp) 266 | initpose = fake_stdp[:,-1,:] 267 | fake_stdpSeq = torch.cat(fake_stdpSeq, dim=0) 268 | flag = False 269 | for i in range(n): 270 | s = fake_stdpSeq[i] 271 | diff = torch.abs(s[1:]-s[:-1]) 272 | diffsum = torch.sum(diff) 273 | if diffsum.cpu().detach().numpy() < thr: 274 | flag = True 275 | 276 | if flag: 277 | return None 278 | else: 279 | return fake_stdpSeq.cpu().detach().numpy() 280 | 281 | 282 | def resume(self, model_dir, train=True): 283 | checkpoint = torch.load(model_dir) 284 | self.dance_enc.load_state_dict(checkpoint['dance_enc']) 285 | self.dance_dec.load_state_dict(checkpoint['dance_dec']) 286 | self.audstyle_enc.load_state_dict(checkpoint['audstyle_enc']) 287 | self.stdp_dec.load_state_dict(checkpoint['stdp_dec']) 288 | self.movement_enc.load_state_dict(checkpoint['movement_enc']) 289 | if train: 
290 | self.danceAud_dis.load_state_dict(checkpoint['danceAud_dis']) 291 | self.dance_reg.load_state_dict(checkpoint['dance_reg']) 292 | self.opt_dance_enc.load_state_dict(checkpoint['opt_dance_enc']) 293 | self.opt_dance_dec.load_state_dict(checkpoint['opt_dance_dec']) 294 | self.opt_stdp_dec.load_state_dict(checkpoint['opt_stdp_dec']) 295 | self.opt_audstyle_enc.load_state_dict(checkpoint['opt_audstyle_enc']) 296 | self.opt_danceAud_dis.load_state_dict(checkpoint['opt_danceAud_dis']) 297 | self.opt_dance_reg.load_state_dict(checkpoint['opt_dance_reg']) 298 | return checkpoint['ep'], checkpoint['total_it'] 299 | 300 | def save(self, filename, ep, total_it): 301 | state = { 302 | 'stdp_dec': self.stdp_dec.state_dict(), 303 | 'movement_enc': self.movement_enc.state_dict(), 304 | 'dance_enc': self.dance_enc.state_dict(), 305 | 'dance_dec': self.dance_dec.state_dict(), 306 | 'audstyle_enc': self.audstyle_enc.state_dict(), 307 | 'danceAud_dis': self.danceAud_dis.state_dict(), 308 | 'zdance_dis': self.zdance_dis.state_dict(), 309 | 'dance_reg': self.dance_reg.state_dict(), 310 | 'opt_stdp_dec': self.opt_stdp_dec.state_dict(), 311 | 'opt_movement_enc': self.opt_movement_enc.state_dict(), 312 | 'opt_dance_enc': self.opt_dance_enc.state_dict(), 313 | 'opt_dance_dec': self.opt_dance_dec.state_dict(), 314 | 'opt_audstyle_enc': self.opt_audstyle_enc.state_dict(), 315 | 'opt_danceAud_dis': self.opt_danceAud_dis.state_dict(), 316 | 'opt_zdance_dis': self.opt_zdance_dis.state_dict(), 317 | 'opt_dance_reg': self.opt_dance_reg.state_dict(), 318 | 'ep': ep, 319 | 'total_it': total_it 320 | } 321 | torch.save(state, filename) 322 | return 323 | 324 | def cuda(self): 325 | if self.train: 326 | self.dance_reg.cuda() 327 | self.danceAud_dis.cuda() 328 | self.zdance_dis.cuda() 329 | self.stdp_dec.cuda() 330 | self.initp_enc.cuda() 331 | self.movement_enc.cuda() 332 | self.dance_enc.cuda() 333 | self.dance_dec.cuda() 334 | self.aud_enc.cuda() 335 | self.audstyle_enc.cuda() 336 | self.gan_criterion.cuda() 337 | 338 | def train(self, ep=0, it=0): 339 | self.cuda() 340 | for epoch in range(ep, self.args.num_epochs): 341 | self.movement_enc.train() 342 | self.stdp_dec.train() 343 | self.initp_enc.train() 344 | self.dance_enc.train() 345 | self.dance_dec.train() 346 | self.danceAud_dis.train() 347 | self.zdance_dis.train() 348 | self.audstyle_enc.train() 349 | self.dance_reg.train() 350 | self.aud_enc.eval() 351 | stdp_recon = 0 352 | 353 | for i, (stdpSeq, aud) in enumerate(self.data_loader): 354 | stdpSeq, aud = stdpSeq.cuda().detach(), aud.cuda().detach() 355 | stdpSeq = stdpSeq.view(stdpSeq.shape[0]*stdpSeq.shape[1], stdpSeq.shape[2], stdpSeq.shape[3]) 356 | aud_style = self.aud_enc.get_style(aud).detach() 357 | 358 | self.forward(stdpSeq, aud.shape[0], aud_style, aud) 359 | self.update() 360 | self.logs['l_kl_zmovement'] += self.loss_kl_z_movement.data 361 | self.logs['l_kl_zdance'] += self.loss_kl_z_dance.data 362 | self.logs['l_l1_zmovement_mu'] += self.loss_l1_z_movement_mu.data 363 | self.logs['l_l1_zmovement_logvar'] += self.loss_l1_z_movement_logvar.data 364 | self.logs['l_l1_stdpSeq'] += self.loss_l1_stdpSeq.data 365 | self.logs['l_kl_fake_zdance'] += self.loss_kl_fake_z_dance.data 366 | self.logs['l_kl_fake_zmovement'] += self.loss_kl_fake_z_movements 367 | self.logs['l_dis'] += self.loss_dis.data 368 | self.logs['l_dis_true'] += self.loss_dis_true.data 369 | self.logs['l_dis_fake'] += self.loss_dis_fake.data 370 | self.logs['l_gen'] += self.loss_gen.data 371 | self.logs['l_info'] += self.loss_info 372 | 
self.logs['l_info_real'] += self.loss_info_real 373 | self.logs['l_info_fake'] += self.loss_info_fake 374 | 375 | print('Epoch:{:3} Iter{}/{}\tl_l1_zmovement mu{:.3f} logvar{:.3f}\tl_l1_stdpSeq {:.3f}\tl_kl_dance {:.3f}\tl_kl_movement {:.3f}\n'.format(epoch, i, len(self.data_loader), 376 | self.loss_l1_z_movement_mu, self.loss_l1_z_movement_logvar, self.loss_l1_stdpSeq, self.loss_kl_z_dance, self.loss_kl_z_movement) + 377 | '\t\t\tl_kl_f_dance {:.3f}\tl_dis {:.3f} {:.3f}\tl_gen {:.3f}'.format(self.loss_kl_fake_z_dance, self.loss_dis_true, self.loss_dis_fake, self.loss_gen)) 378 | 379 | it += 1 380 | if it % self.log_interval == 0: 381 | for tag, value in self.logs.items(): 382 | self.logger.scalar_summary(tag, value/self.log_interval, it) 383 | self.logs = self.init_logs() 384 | if epoch % self.snapshot_ep == 0: 385 | self.save(os.path.join(self.snapshot_dir, '{:04}.ckpt'.format(epoch)), epoch, it) 386 | -------------------------------------------------------------------------------- /model_decomp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/Dancing2Music/License.txt 7 | import os 8 | import time 9 | import numpy as np 10 | import random 11 | import math 12 | 13 | import torch 14 | from torch import nn 15 | from torch.autograd import Variable 16 | import torch.optim as optim 17 | from torch.nn.utils import clip_grad_norm_ 18 | 19 | from utils import Logger 20 | 21 | if torch.cuda.is_available(): 22 | T = torch.cuda 23 | else: 24 | T = torch 25 | 26 | class Trainer_Decomp(object): 27 | def __init__(self, data_loader, initp_enc, initp_dec, movement_enc, stdp_dec, args=None): 28 | self.data_loader = data_loader 29 | self.initp_enc = initp_enc 30 | self.initp_dec = initp_dec 31 | self.movement_enc = movement_enc 32 | self.stdp_dec = stdp_dec 33 | 34 | 35 | self.args = args 36 | if args.train: 37 | self.logger = Logger(args.log_dir) 38 | self.logs = self.init_logs() 39 | self.log_interval = args.log_interval 40 | self.snapshot_ep = args.snapshot_ep 41 | self.snapshot_dir = args.snapshot_dir 42 | 43 | self.opt_initp_enc = torch.optim.Adam(self.initp_enc.parameters(), lr=args.lr) 44 | self.opt_initp_dec = torch.optim.Adam(self.initp_dec.parameters(), lr=args.lr) 45 | self.opt_movement_enc = torch.optim.Adam(self.movement_enc.parameters(), lr=args.lr) 46 | self.opt_stdp_dec = torch.optim.Adam(self.stdp_dec.parameters(), lr=args.lr) 47 | 48 | self.latent_dropout = nn.Dropout(p=args.latent_dropout) 49 | self.l1_criterion = torch.nn.L1Loss() 50 | self.gan_criterion = nn.BCEWithLogitsLoss() 51 | 52 | 53 | def init_logs(self): 54 | return {'l_kl_zinit':0, 'l_kl_zmovement':0, 'l_l1_stdp':0, 'l_l1_cross_stdp':0, 'l_dist_zmovement':0, 55 | 'l_l1_initp':0, 'l_l1_initp_con':0, 56 | 'kld_coef':0 57 | } 58 | 59 | def get_z_random(self, batchSize, nz, random_type='gauss'): 60 | z = torch.randn(batchSize, nz).cuda() 61 | return z 62 | 63 | @staticmethod 64 | def ones_like(tensor, val=1.): 65 | return T.FloatTensor(tensor.size()).fill_(val) 66 | 67 | @staticmethod 68 | def zeros_like(tensor, val=0.): 69 | return T.FloatTensor(tensor.size()).fill_(val) 70 | 71 | 72 | def random_generate_stdp(self, init_p): 73 | self.pose_0 = init_p 74 | self.z_init_mu, self.z_init_logvar = self.initp_enc(self.pose_0) 75 | z_init_std = 
self.z_init_logvar.mul(0.5).exp_() 76 | z_init_eps = self.get_z_random(z_init_std.size(0), z_init_std.size(1), 'gauss') 77 | self.z_init = z_init_eps.mul(z_init_std).add_(self.z_init_mu) 78 | self.z_random_movement = self.get_z_random(self.z_init.size(0), 512, 'gauss') 79 | self.fake_stdpose = self.stdp_dec(self.z_init, self.z_random_movement) 80 | return self.fake_stdpose 81 | 82 | def forward(self, stdpose1, stdpose2): 83 | self.stdpose1 = stdpose1 84 | self.stdpose2 = stdpose2 85 | 86 | # stdpose -> stdpose[0] -> z_init 87 | self.pose1_0 = stdpose1[:,0,:] 88 | self.pose2_0 = stdpose2[:,0,:] 89 | self.poses_0 = torch.cat((self.pose1_0, self.pose2_0), 0) 90 | self.z_init_mus, self.z_init_logvars = self.initp_enc(self.poses_0) 91 | z_init_stds = self.z_init_logvars.mul(0.5).exp_() 92 | z_init_epss = self.get_z_random(z_init_stds.size(0), z_init_stds.size(1), 'gauss') 93 | self.z_inits = z_init_epss.mul(z_init_stds).add_(self.z_init_mus) 94 | self.z_init1, self.z_init2 = torch.split(self.z_inits, self.stdpose1.size(0), dim=0) 95 | 96 | # stdpose -> z_movement 97 | stdposes = torch.cat((stdpose1, stdpose2), 0) 98 | self.z_movement_mus, self.z_movement_logvars = self.movement_enc(stdposes) 99 | z_movement_stds = self.z_movement_logvars.mul(0.5).exp_() 100 | z_movement_epss = self.get_z_random(z_movement_stds.size(0), z_movement_stds.size(1), 'gauss') 101 | self.z_movements = z_movement_epss.mul(z_movement_stds).add_(self.z_movement_mus) 102 | self.z_movement1, self.z_movement2 = torch.split(self.z_movements, self.stdpose1.size(0), dim=0) 103 | 104 | # zinit1+zmovement1->stdpose1 zinit2+zmovement2->stdpose2 105 | self.recon_stdpose1 = self.stdp_dec(self.z_init1, self.z_movement1) 106 | self.recon_stdpose2 = self.stdp_dec(self.z_init2, self.z_movement2) 107 | 108 | # zinit1+zmovement2->stdpose1 zinit2+zmovement1->stdpose2 109 | self.recon_stdpose1_cross = self.stdp_dec(self.z_init1, self.z_movement2) 110 | self.recon_stdpose2_cross = self.stdp_dec(self.z_init2, self.z_movement1) 111 | 112 | # z_init -> \hat{stdpose[0]} 113 | self.recon_pose1_0 = self.initp_dec(self.z_init1) 114 | self.recon_pose2_0 = self.initp_dec(self.z_init2) 115 | 116 | # single pose reconstruction 117 | randomlist = np.random.permutation(31)[:4] 118 | singlepose = [] 119 | for r in randomlist: 120 | singlepose.append(self.stdpose1[:,r,:]) 121 | self.singleposes = torch.cat(singlepose, dim=0).detach() 122 | self.z_single_mus, self.z_single_logvars = self.initp_enc(self.singleposes) 123 | z_single_stds = self.z_single_logvars.mul(0.5).exp_() 124 | z_single_epss = self.get_z_random(z_single_stds.size(0), z_single_stds.size(1), 'gauss') 125 | z_single = z_single_epss.mul(z_single_stds).add_(self.z_single_mus) 126 | self.recon_singleposes = self.initp_dec(z_single) 127 | 128 | def backward_initp_ED(self): 129 | # z_init KL 130 | kl_element = self.z_init_mus.pow(2).add_(self.z_init_logvars.exp()).mul_(-1).add_(1).add_(self.z_init_logvars) 131 | self.loss_kl_z_init = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl)) 132 | 133 | # initpose reconstruction 134 | self.loss_l1_initp = self.l1_criterion(self.recon_singleposes, self.singleposes) * self.args.lambda_initp_recon 135 | 136 | self.loss_initp = self.loss_kl_z_init + self.loss_l1_initp 137 | 138 | def backward_movement_ED(self): 139 | # z_movement KL 140 | kl_element = self.z_movement_mus.pow(2).add_(self.z_movement_logvars.exp()).mul_(-1).add_(1).add_(self.z_movement_logvars) 141 | #self.loss_kl_z_movement = torch.mean(kl_element).mul_(-0.5) * 
self.args.lambda_kl 142 | self.loss_kl_z_movement = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl)) 143 | 144 | # stdpose self reconstruction 145 | loss_l1_stdp1 = self.l1_criterion(self.recon_stdpose1, self.stdpose1) * self.args.lambda_stdp_recon 146 | loss_l1_stdp2 = self.l1_criterion(self.recon_stdpose2, self.stdpose2) * self.args.lambda_stdp_recon 147 | self.loss_l1_stdp = loss_l1_stdp1 + loss_l1_stdp2 148 | 149 | # stdpose cross reconstruction 150 | loss_l1_cross_stdp1 = self.l1_criterion(self.recon_stdpose1_cross, self.stdpose1) * self.args.lambda_stdp_recon 151 | loss_l1_cross_stdp2 = self.l1_criterion(self.recon_stdpose2_cross, self.stdpose2) * self.args.lambda_stdp_recon 152 | self.loss_l1_cross_stdp = loss_l1_cross_stdp1 + loss_l1_cross_stdp2 153 | 154 | # Movement dist 155 | self.loss_dist_z_movement = torch.mean(torch.abs(self.z_movement1-self.z_movement2)) * self.args.lambda_dist_z_movement 156 | 157 | self.loss_movement = self.loss_kl_z_movement + self.loss_l1_stdp + self.loss_l1_cross_stdp + self.loss_dist_z_movement 158 | 159 | 160 | def update(self): 161 | self.opt_initp_enc.zero_grad() 162 | self.opt_initp_dec.zero_grad() 163 | self.opt_movement_enc.zero_grad() 164 | self.opt_stdp_dec.zero_grad() 165 | self.backward_initp_ED() 166 | self.backward_movement_ED() 167 | self.g_loss = self.loss_initp + self.loss_movement 168 | self.g_loss.backward(retain_graph=True) 169 | clip_grad_norm_(self.movement_enc.parameters(), 0.5) 170 | clip_grad_norm_(self.stdp_dec.parameters(), 0.5) 171 | self.opt_initp_enc.step() 172 | self.opt_initp_dec.step() 173 | self.opt_movement_enc.step() 174 | self.opt_stdp_dec.step() 175 | 176 | 177 | def save(self, filename, ep, total_it): 178 | state = { 179 | 'stdp_dec': self.stdp_dec.state_dict(), 180 | 'movement_enc': self.movement_enc.state_dict(), 181 | 'initp_enc': self.initp_enc.state_dict(), 182 | 'initp_dec': self.initp_dec.state_dict(), 183 | 'opt_stdp_dec': self.opt_stdp_dec.state_dict(), 184 | 'opt_movement_enc': self.opt_movement_enc.state_dict(), 185 | 'opt_initp_enc': self.opt_initp_enc.state_dict(), 186 | 'opt_initp_dec': self.opt_initp_dec.state_dict(), 187 | 'ep': ep, 188 | 'total_it': total_it 189 | } 190 | torch.save(state, filename) 191 | return 192 | 193 | def resume(self, model_dir, train=True): 194 | checkpoint = torch.load(model_dir) 195 | # weight 196 | self.stdp_dec.load_state_dict(checkpoint['stdp_dec']) 197 | self.movement_enc.load_state_dict(checkpoint['movement_enc']) 198 | self.initp_enc.load_state_dict(checkpoint['initp_enc']) 199 | self.initp_dec.load_state_dict(checkpoint['initp_dec']) 200 | # optimizer 201 | if train: 202 | self.opt_stdp_dec.load_state_dict(checkpoint['opt_stdp_dec']) 203 | self.opt_movement_enc.load_state_dict(checkpoint['opt_movement_enc']) 204 | self.opt_initp_enc.load_state_dict(checkpoint['opt_initp_enc']) 205 | self.opt_initp_dec.load_state_dict(checkpoint['opt_initp_dec']) 206 | return checkpoint['ep'], checkpoint['total_it'] 207 | 208 | def kld_coef(self, i): 209 | return float(1/(1+np.exp(-0.0005*(i-15000)))) #v3 210 | 211 | 212 | def generate_stdp_sequence(self, initpose, aud, num_stdp): 213 | self.initp_enc.cuda() 214 | self.initp_dec.cuda() 215 | self.movement_enc.cuda() 216 | self.stdp_dec.cuda() 217 | self.initp_enc.eval() 218 | self.initp_dec.eval() 219 | self.movement_enc.eval() 220 | self.stdp_dec.eval() 221 | initpose = initpose.cuda() 222 | 223 | aud_style = self.aud_enc.get_style(aud) 224 | 225 | stdp_seq = [] 226 | cnt = 0 227 | #for i in 
range(num_stdp): 228 | while not cnt == num_stdp: 229 | if cnt==0: 230 | z_inits = self.get_z_random(1, 10, 'gauss') 231 | else: 232 | z_init_mus, z_init_logvars = self.initp_enc(initpose) 233 | z_init_stds = z_init_logvars.mul(0.5).exp_() 234 | z_init_epss = self.get_z_random(z_init_stds.size(0), z_init_stds.size(1), 'gauss') 235 | z_inits = z_init_epss.mul(z_init_stds).add_(z_init_mus) 236 | 237 | z_audstyle_mu, z_audstyle_logvar = self.audstyle_enc(aud_style) 238 | z_as_std = z_audstyle_logvar.mul(0.5).exp_() 239 | z_as_eps = self.get_z_random(z_as_std.size(0), z_as_std.size(1), 'gauss') 240 | z_audstyle = z_as_eps.mul(z_as_std).add_(z_audstyle_mu) 241 | if random.randint(0,5)==100: 242 | z_audstyle = self.get_z_random(z_inits.size(0), 512, 'gauss') 243 | 244 | fake_stdpose = self.stdp_dec(z_inits, z_audstyle) 245 | 246 | s = fake_stdpose[0] 247 | diff = torch.abs(s[1:]-s[:-1]) 248 | diffsum = torch.sum(diff) 249 | if diffsum.cpu().detach().numpy() < 70: 250 | continue 251 | 252 | cnt += 1 253 | stdp_seq.append(fake_stdpose.cpu().detach().numpy()) 254 | initpose = fake_stdpose[:,-1,:] 255 | return stdp_seq 256 | 257 | 258 | def cuda(self): 259 | self.initp_enc.cuda() 260 | self.initp_dec.cuda() 261 | self.movement_enc.cuda() 262 | self.stdp_dec.cuda() 263 | self.l1_criterion.cuda() 264 | 265 | def train(self, ep=0, it=0): 266 | self.cuda() 267 | 268 | full_kl = self.args.lambda_kl 269 | kl_w = 0 270 | kl_step = 0.05 271 | best_stdp_recon = 100 272 | for epoch in range(ep, self.args.num_epochs): 273 | self.initp_enc.train() 274 | self.initp_dec.train() 275 | self.movement_enc.train() 276 | self.stdp_dec.train() 277 | stdp_recon = 0 278 | for i, (stdpose, stdpose2) in enumerate(self.data_loader): 279 | self.args.lambda_kl = full_kl*self.kld_coef(it) 280 | stdpose, stdpose2 = stdpose.cuda().detach(), stdpose2.cuda().detach() 281 | 282 | self.forward(stdpose, stdpose2) 283 | self.update() 284 | self.logs['l_kl_zinit'] += self.loss_kl_z_init.data 285 | self.logs['l_kl_zmovement'] += self.loss_kl_z_movement.data 286 | self.logs['l_l1_initp'] += self.loss_l1_initp.data 287 | self.logs['l_l1_stdp'] += self.loss_l1_stdp.data 288 | self.logs['l_l1_cross_stdp'] += self.loss_l1_cross_stdp.data 289 | self.logs['l_dist_zmovement'] += self.loss_dist_z_movement.data 290 | self.logs['kld_coef'] += self.args.lambda_kl 291 | 292 | print('Epoch:{:3} Iter{}/{}\tl_l1_initp {:.3f}\tl_l1_stdp {:.3f}\tl_l1_cross_stdp {:.3f}\tl_dist_zmove {:.3f}\tl_kl_zinit {:.3f}\t l_kl_zmove {:.3f}'.format( 293 | epoch, i, len(self.data_loader), self.loss_l1_initp, self.loss_l1_stdp, self.loss_l1_cross_stdp, self.loss_dist_z_movement, self.loss_kl_z_init, self.loss_kl_z_movement)) 294 | 295 | it += 1 296 | if it % self.log_interval == 0: 297 | for tag, value in self.logs.items(): 298 | self.logger.scalar_summary(tag, value/self.log_interval, it) 299 | self.logs = self.init_logs() 300 | if epoch % self.snapshot_ep == 0: 301 | self.save(os.path.join(self.snapshot_dir, '{:04}.ckpt'.format(epoch)), epoch, it) 302 | -------------------------------------------------------------------------------- /modulate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/Dancing2Music/License.txt 7 | 8 | import os 9 | import numpy as np 10 | import librosa 11 | import utils 12 | 13 | 14 | def modulate(dance, beats, length): # warp the dance so kinematic beats land on music beats 15 | sec_interframe = 1/15 16 | 17 | beats_frame = np.around(beats) 18 | t_beat = beats_frame.astype(int) # music beats, in 15-fps frame indices 19 | s_beat = np.arange(3,dance.shape[0],8) # kinematic beats: every 8 frames, starting at frame 3 20 | final_pose = np.zeros((length, 14, 2)) 21 | 22 | if t_beat[0] >3: 23 | final_pose[t_beat[0]-3:t_beat[0]] = dance[:3] 24 | else: 25 | final_pose[:t_beat[0]] = dance[:t_beat[0]] 26 | if t_beat[0]-3 > 0: 27 | final_pose[:t_beat[0]-3] = dance[0] 28 | for t in range(t_beat.shape[0]-1): 29 | begin = int(t_beat[t]) 30 | end = int(t_beat[t+1]) 31 | interval = end-begin 32 | if t==s_beat.shape[0]-1: 33 | rest = min(final_pose.shape[0]-begin-1, dance.shape[0]-s_beat[t]-1) 34 | break 35 | if t+1 < s_beat.shape[0] and s_beat[t+1] < dance.shape[0]: 36 | pose = get_pose(dance[s_beat[t]:s_beat[t+1]+1], interval+1) # stretch one beat-to-beat segment 37 | if t==0 and t_beat[t]>=3: 38 | final_pose[begin-s_beat[t]:begin] = dance[:s_beat[t]] 39 | final_pose[begin:end+1]=pose 40 | rest = min(final_pose.shape[0]-end-1, dance.shape[0]-s_beat[t+1]-1) 41 | else: 42 | end = begin 43 | if t+1 < s_beat.shape[0]: 44 | rest = min(final_pose.shape[0]-end-1, dance.shape[0]-s_beat[t+1]-1) 45 | else: 46 | print(t_beat.shape, s_beat.shape, t) 47 | rest = min(final_pose.shape[0]-end-1, dance.shape[0]-s_beat[t]-1) 48 | if rest > 0: 49 | if t+1 < s_beat.shape[0]: 50 | final_pose[end+1:end+1+rest] = dance[s_beat[t+1]+1:s_beat[t+1]+1+rest] 51 | else: 52 | final_pose[end+1:end+1+rest] = dance[s_beat[t]+1:s_beat[t]+1+rest] 53 | 54 | return final_pose 55 | 56 | def get_pose(pose, n): # resample 9 on-beat poses to n frames by linear interpolation 57 | t_pose = np.zeros((n, 14, 2)) 58 | if n==11: 59 | t_pose[0] = pose[0] 60 | t_pose[1] = (pose[0]*1+pose[1]*4)/5 61 | t_pose[2] = (pose[1]*2+pose[2]*3)/5 62 | t_pose[3] = (pose[2]*3+pose[3]*2)/5 63 | t_pose[4] = (pose[3]*4+pose[4]*1)/5 64 | t_pose[5] = pose[4] 65 | t_pose[6] = (pose[4]*1+pose[5]*4)/5 66 | t_pose[7] = (pose[5]*2+pose[6]*3)/5 67 | t_pose[8] = (pose[6]*3+pose[7]*2)/5 68 | t_pose[9] = (pose[7]*4+pose[8]*1)/5 69 | t_pose[10] = pose[8] 70 | elif n==10: 71 | t_pose[0] = pose[0] 72 | t_pose[1] = (pose[0]*1+pose[1]*8)/9 73 | t_pose[2] = (pose[1]*2+pose[2]*7)/9 74 | t_pose[3] = (pose[2]*3+pose[3]*6)/9 75 | t_pose[4] = (pose[3]*4+pose[4]*5)/9 76 | t_pose[5] = (pose[4]*5+pose[5]*4)/9 77 | t_pose[6] = (pose[5]*6+pose[6]*3)/9 78 | t_pose[7] = (pose[6]*7+pose[7]*2)/9 79 | t_pose[8] = (pose[7]*8+pose[8]*1)/9 80 | t_pose[9] = pose[8] 81 | elif n==12: 82 | t_pose[0] = pose[0] 83 | t_pose[1] = (pose[0]*3+pose[1]*8)/11 84 | t_pose[2] = (pose[1]*6+pose[2]*5)/11 85 | t_pose[3] = (pose[2]*9+pose[3]*2)/11 86 | t_pose[4] = (pose[2]*1+pose[3]*10)/11 87 | t_pose[5] = (pose[3]*4+pose[4]*7)/11 88 | t_pose[6] = (pose[4]*7+pose[5]*4)/11 89 | t_pose[7] = (pose[5]*10+pose[6]*1)/11 90 | t_pose[8] = (pose[5]*2+pose[6]*9)/11 91 | t_pose[9] = (pose[6]*5+pose[7]*6)/11 92 | t_pose[10] = (pose[7]*8+pose[8]*3)/11 93 | t_pose[11] = pose[8] 94 | elif n==13: 95 | t_pose[0] = pose[0] 96 | t_pose[1] = (pose[0]*1+pose[1]*2)/3 97 | t_pose[2] = (pose[1]*2+pose[2]*1)/3 98 | t_pose[3] = pose[2] 99 | t_pose[4] = (pose[2]*1+pose[3]*2)/3 100 | t_pose[5] = (pose[3]*2+pose[4]*1)/3 101 | t_pose[6] = pose[4] 102 | t_pose[7] = (pose[4]*1+pose[5]*2)/3 103 | t_pose[8] = (pose[5]*2+pose[6]*1)/3 104 | t_pose[9] = pose[6] 105 | t_pose[10] = (pose[6]*1+pose[7]*2)/3 106 | t_pose[11] = (pose[7]*2+pose[8]*1)/3 107 | t_pose[12] = pose[8] 108 | elif n==14: 109 | t_pose[0] = pose[0] 110 | t_pose[1] = (pose[0]*5+pose[1]*8)/13 111 | t_pose[2] =
(pose[1]*10+pose[2]*3)/13 112 | t_pose[3] = (pose[1]*2+pose[2]*11)/13 113 | t_pose[4] = (pose[2]*7+pose[3]*6)/13 114 | t_pose[5] = (pose[3]*12+pose[4]*1)/13 115 | t_pose[6] = (pose[3]*4+pose[4]*9)/13 116 | t_pose[7] = (pose[4]*9+pose[5]*4)/13 117 | t_pose[8] = (pose[4]*12+pose[5]*1)/13 118 | t_pose[9] = (pose[5]*6+pose[6]*7)/13 119 | t_pose[10] = (pose[6]*11+pose[7]*2)/13 120 | t_pose[11] = (pose[6]*3+pose[7]*10)/13 121 | t_pose[12] = (pose[7]*8+pose[8]*5)/13 122 | t_pose[13] = pose[8] 123 | elif n==9: 124 | t_pose = pose 125 | elif n==8: 126 | t_pose[0] = pose[0] 127 | t_pose[1] = (pose[1]*6+pose[2]*1)/7 128 | t_pose[2] = (pose[2]*5+pose[3]*2)/7 129 | t_pose[3] = (pose[3]*4+pose[4]*3)/7 130 | t_pose[4] = (pose[4]*3+pose[5]*4)/7 131 | t_pose[5] = (pose[5]*2+pose[6]*5)/7 132 | t_pose[6] = (pose[6]*1+pose[7]*6)/7 133 | t_pose[7] = pose[8] 134 | elif n==7: 135 | t_pose[0] = pose[0] 136 | t_pose[1] = (pose[1]*2+pose[2]*1)/3 137 | t_pose[2] = (pose[2]*1+pose[3]*2)/3 138 | t_pose[3] = pose[4] 139 | t_pose[4] = (pose[5]*2+pose[6]*1)/3 140 | t_pose[5] = (pose[6]*1+pose[7]*2)/3 141 | t_pose[6] = pose[8] 142 | elif n==6: 143 | t_pose[0] = pose[0] 144 | t_pose[1] = (pose[1]*2+pose[2]*3)/5 145 | t_pose[2] = (pose[3]*4+pose[4]*1)/5 146 | t_pose[3] = (pose[4]*1+pose[5]*4)/5 147 | t_pose[4] = (pose[6]*3+pose[7]*2)/5 148 | t_pose[5] = pose[8] 149 | elif n<6: 150 | t_pose[0] = pose[0] 151 | t_pose[n-1] = pose[8] 152 | for i in range(1,n-1): 153 | t_pose[i] = pose[4] 154 | elif n>14: 155 | t_pose[0] = pose[0] 156 | t_pose[n-1] = pose[8] 157 | for i in range(1, n-1): 158 | k = int(8/(n-1)*i) 159 | t_pose[i] = t_pose[k] 160 | else: 161 | print('NOT IMPLEMENT {}'.format(n)) 162 | 163 | return t_pose 164 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/Dancing2Music/License.txt 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.utils.data 12 | from torch.autograd import Variable 13 | 14 | import numpy as np 15 | 16 | if torch.cuda.is_available(): 17 | T = torch.cuda 18 | else: 19 | T = torch 20 | 21 | ########################################################### 22 | ########## 23 | ########## Stage 1: Movement 24 | ########## 25 | ########################################################### 26 | class InitPose_Enc(nn.Module): 27 | def __init__(self, pose_size, dim_z_init): 28 | super(InitPose_Enc, self).__init__() 29 | nf = 64 30 | #nf = 32 31 | self.enc = nn.Sequential( 32 | nn.Linear(pose_size, nf), 33 | nn.LayerNorm(nf), 34 | nn.LeakyReLU(0.2, inplace=True), 35 | nn.Linear(nf, nf), 36 | nn.LayerNorm(nf), 37 | nn.LeakyReLU(0.2, inplace=True), 38 | ) 39 | self.mean = nn.Sequential( 40 | nn.Linear(nf,dim_z_init), 41 | ) 42 | self.std = nn.Sequential( 43 | nn.Linear(nf,dim_z_init), 44 | ) 45 | def forward(self, pose): 46 | enc = self.enc(pose) 47 | return self.mean(enc), self.std(enc) 48 | 49 | class InitPose_Dec(nn.Module): 50 | def __init__(self, pose_size, dim_z_init): 51 | super(InitPose_Dec, self).__init__() 52 | nf = 64 53 | #nf = dim_z_init 54 | self.dec = nn.Sequential( 55 | nn.Linear(dim_z_init, nf), 56 | nn.LayerNorm(nf), 57 | nn.LeakyReLU(0.2, inplace=True), 58 | nn.Linear(nf, nf), 59 | nn.LayerNorm(nf), 60 | nn.LeakyReLU(0.2, inplace=True), 61 | nn.Linear(nf,pose_size), 62 | ) 63 | def forward(self, z_init): 64 | return self.dec(z_init) 65 | 66 | class Movement_Enc(nn.Module): 67 | def __init__(self, pose_size, dim_z_movement, length, hidden_size, num_layers, bidirection=False): 68 | super(Movement_Enc, self).__init__() 69 | self.hidden_size = hidden_size 70 | self.bidirection = bidirection 71 | if bidirection: 72 | self.num_dir = 2 73 | else: 74 | self.num_dir = 1 75 | self.recurrent = nn.GRU(pose_size, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=bidirection) 76 | self.init_h = nn.Parameter(torch.randn(num_layers*self.num_dir, 1, hidden_size).type(T.FloatTensor), requires_grad=True) 77 | if bidirection: 78 | self.mean = nn.Sequential( 79 | nn.Linear(hidden_size*2,dim_z_movement), 80 | ) 81 | self.std = nn.Sequential( 82 | nn.Linear(hidden_size*2,dim_z_movement), 83 | ) 84 | else: 85 | ''' 86 | self.enc = nn.Sequential( 87 | nn.Linear(hidden_size, hidden_size//2), 88 | nn.LayerNorm(hidden_size//2), 89 | nn.ReLU(inplace=True), 90 | ) 91 | ''' 92 | self.mean = nn.Sequential( 93 | nn.Linear(hidden_size,dim_z_movement), 94 | ) 95 | self.std = nn.Sequential( 96 | nn.Linear(hidden_size,dim_z_movement), 97 | ) 98 | def forward(self, poses): 99 | num_samples = poses.shape[0] 100 | h_t = [self.init_h.repeat(1, num_samples, 1)] 101 | output, hidden = self.recurrent(poses, h_t[0]) 102 | if self.bidirection: 103 | output = torch.cat((output[:,-1,:self.hidden_size], output[:,0,self.hidden_size:]), 1) 104 | else: 105 | output = output[:,-1,:] 106 | #enc = self.enc(output) 107 | #return self.mean(enc), self.std(enc) 108 | return self.mean(output), self.std(output) 109 | 110 | def getFeature(self, poses): 111 | num_samples = poses.shape[0] 112 | h_t = [self.init_h.repeat(1, num_samples, 1)] 113 | output, hidden = self.recurrent(poses, h_t[0]) 114 | if self.bidirection: 115 | output = torch.cat((output[:,-1,:self.hidden_size], output[:,0,self.hidden_size:]), 1) 116 | else: 117 | output = output[:,-1,:] 
118 | return output 119 | 120 | class StandardPose_Dec(nn.Module): 121 | def __init__(self, pose_size, dim_z_init, dim_z_movement, length, hidden_size, num_layers): 122 | super(StandardPose_Dec, self).__init__() 123 | self.length = length 124 | self.pose_size = pose_size 125 | self.hidden_size = hidden_size 126 | self.num_layers = num_layers 127 | #dim_z_init=0 128 | ''' 129 | self.z2init = nn.Sequential( 130 | nn.Linear(dim_z_init+dim_z_movement, hidden_size), 131 | nn.LayerNorm(hidden_size), 132 | nn.ReLU(True), 133 | nn.Linear(hidden_size, num_layers*hidden_size) 134 | ) 135 | ''' 136 | self.z2init = nn.Sequential( 137 | nn.Linear(dim_z_init+dim_z_movement, num_layers*hidden_size) 138 | ) 139 | self.recurrent = nn.GRU(dim_z_movement, hidden_size, num_layers=num_layers, batch_first=True) 140 | self.pose_g = nn.Sequential( 141 | nn.Linear(hidden_size, hidden_size), 142 | nn.LayerNorm(hidden_size), 143 | nn.ReLU(True), 144 | nn.Linear(hidden_size, pose_size) 145 | ) 146 | 147 | def forward(self, z_init, z_movement): 148 | h_init = self.z2init(torch.cat((z_init, z_movement), 1)) 149 | #h_init = self.z2init(z_movement) 150 | h_init = h_init.view(self.num_layers, h_init.size(0), self.hidden_size) 151 | z_movements = z_movement.view(z_movement.size(0),1,z_movement.size(1)).repeat(1, self.length, 1) 152 | z_m_t, _ = self.recurrent(z_movements, h_init) 153 | z_m = z_m_t.contiguous().view(-1, self.hidden_size) 154 | poses = self.pose_g(z_m) 155 | poses = poses.view(z_movement.shape[0], self.length, self.pose_size) 156 | return poses 157 | 158 | class StandardPose_Dis(nn.Module): 159 | def __init__(self, pose_size, length): 160 | super(StandardPose_Dis, self).__init__() 161 | self.pose_size = pose_size 162 | self.length = length 163 | nd = 1024 164 | self.main = nn.Sequential( 165 | nn.Linear(length*pose_size, nd), 166 | nn.LayerNorm(nd), 167 | nn.LeakyReLU(0.2, inplace=True), 168 | nn.Linear(nd,nd//2), 169 | nn.LayerNorm(nd//2), 170 | nn.LeakyReLU(0.2, inplace=True), 171 | nn.Linear(nd//2,nd//4), 172 | nn.LayerNorm(nd//4), 173 | nn.LeakyReLU(0.2, inplace=True), 174 | nn.Linear(nd//4, 1) 175 | ) 176 | def forward(self, pose_seq): 177 | pose_seq = pose_seq.view(-1, self.pose_size*self.length) 178 | return self.main(pose_seq).squeeze() 179 | 180 | ########################################################### 181 | ########## 182 | ########## Stage 2: Dance 183 | ########## 184 | ########################################################### 185 | class Dance_Enc(nn.Module): 186 | def __init__(self, dim_z_movement, dim_z_dance, hidden_size, num_layers, bidirection=False): 187 | super(Dance_Enc, self).__init__() 188 | self.hidden_size = hidden_size 189 | self.bidirection = bidirection 190 | if bidirection: 191 | self.num_dir = 2 192 | else: 193 | self.num_dir = 1 194 | self.recurrent = nn.GRU(2*dim_z_movement, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=bidirection) 195 | self.init_h = nn.Parameter(torch.randn(num_layers*self.num_dir, 1, hidden_size).type(T.FloatTensor), requires_grad=True) 196 | if bidirection: 197 | self.mean = nn.Sequential( 198 | nn.Linear(hidden_size*2,dim_z_dance), 199 | ) 200 | self.std = nn.Sequential( 201 | nn.Linear(hidden_size*2,dim_z_dance), 202 | ) 203 | else: 204 | self.mean = nn.Sequential( 205 | nn.Linear(hidden_size,dim_z_dance), 206 | ) 207 | self.std = nn.Sequential( 208 | nn.Linear(hidden_size,dim_z_dance), 209 | ) 210 | def forward(self, movements_mean, movements_std): 211 | movements = torch.cat((movements_mean, movements_std),2) 212 | 
num_samples = movements.shape[0] 213 | h_t = [self.init_h.repeat(1, num_samples, 1)] 214 | output, hidden = self.recurrent(movements, h_t[0]) 215 | if self.bidirection: 216 | output = torch.cat((output[:,-1,:self.hidden_size], output[:,0,self.hidden_size:]), 1) 217 | else: 218 | output = output[:,-1,:] 219 | return self.mean(output), self.std(output) 220 | 221 | class Dance_Dec(nn.Module): 222 | def __init__(self, dim_z_dance, dim_z_movement, hidden_size, num_layers): 223 | super(Dance_Dec, self).__init__() 224 | #self.length = length 225 | self.num_layers = num_layers 226 | self.hidden_size = hidden_size 227 | self.dim_z_movement = dim_z_movement 228 | #dim_z_init=0 229 | ''' 230 | self.z2init = nn.Sequential( 231 | nn.Linear(dim_z_init+dim_z_movement, hidden_size), 232 | nn.LayerNorm(hidden_size), 233 | nn.ReLU(True), 234 | nn.Linear(hidden_size, num_layers*hidden_size) 235 | ) 236 | ''' 237 | self.z2init = nn.Sequential( 238 | nn.Linear(dim_z_dance, num_layers*hidden_size) 239 | ) 240 | self.recurrent = nn.GRU(dim_z_dance, hidden_size, num_layers=num_layers, batch_first=True) 241 | self.movement_g = nn.Sequential( 242 | nn.Linear(hidden_size, hidden_size), 243 | nn.LayerNorm(hidden_size), 244 | nn.ReLU(True), 245 | #nn.Linear(hidden_size, dim_z_movement) 246 | ) 247 | self.mean = nn.Sequential( 248 | nn.Linear(hidden_size,dim_z_movement), 249 | ) 250 | self.std = nn.Sequential( 251 | nn.Linear(hidden_size,dim_z_movement), 252 | ) 253 | 254 | def forward(self, z_dance, length=3): 255 | h_init = self.z2init(z_dance) 256 | h_init = h_init.view(self.num_layers, h_init.size(0), self.hidden_size) 257 | z_dance = z_dance.view(z_dance.size(0),1,z_dance.size(1)).repeat(1, length, 1) 258 | z_d_t, _ = self.recurrent(z_dance, h_init) 259 | z_d = z_d_t.contiguous().view(-1, self.hidden_size) 260 | z_movement = self.movement_g(z_d) 261 | z_movement_mean, z_movement_std = self.mean(z_movement), self.std(z_movement) 262 | #z_movement = z_movement.view(z_dance.shape[0], length, self.dim_z_movement) 263 | return z_movement_mean, z_movement_std 264 | 265 | 266 | class DanceAud_Dis2(nn.Module): 267 | def __init__(self, aud_size, dim_z_movement, length=3): 268 | super(DanceAud_Dis2, self).__init__() 269 | self.aud_size = aud_size 270 | self.dim_z_movement = dim_z_movement 271 | self.length = length 272 | nd = 1024 273 | self.movementd = nn.Sequential( 274 | nn.Linear(dim_z_movement*2*length, nd), 275 | nn.LayerNorm(nd), 276 | nn.LeakyReLU(0.2, inplace=True), 277 | nn.Linear(nd,nd//2), 278 | nn.LayerNorm(nd//2), 279 | nn.LeakyReLU(0.2, inplace=True), 280 | nn.Linear(nd//2,nd//4), 281 | nn.LayerNorm(nd//4), 282 | nn.LeakyReLU(0.2, inplace=True), 283 | #nn.Linear(nd//4, 30), 284 | nn.Linear(nd//4, 30), 285 | ) 286 | 287 | self.audd = nn.Sequential( 288 | nn.Linear(aud_size, 30), 289 | nn.LayerNorm(30), 290 | nn.LeakyReLU(0.2, inplace=True), 291 | nn.Linear(30, 30), 292 | nn.LayerNorm(30), 293 | nn.LeakyReLU(0.2, inplace=True), 294 | ) 295 | self.jointd = nn.Sequential( 296 | nn.Linear(60, 1) 297 | ) 298 | 299 | def forward(self, movements, aud): 300 | if len(movements.shape) == 3: 301 | movements = movements.view(movements.shape[0], movements.shape[1]*movements.shape[2]) 302 | m = self.movementd(movements) 303 | a = self.audd(aud) 304 | ma = torch.cat((m,a),1) 305 | 306 | return self.jointd(ma).squeeze(), None 307 | 308 | class DanceAud_Dis(nn.Module): 309 | def __init__(self, aud_size, dim_z_movement, length=3): 310 | super(DanceAud_Dis, self).__init__() 311 | self.aud_size = aud_size 312 | 
self.dim_z_movement = dim_z_movement 313 | self.length = length 314 | nd = 1024 315 | self.movementd = nn.Sequential( 316 | #nn.Linear(dim_z_movement*3, nd), 317 | nn.Linear(dim_z_movement*2, nd), 318 | nn.LayerNorm(nd), 319 | nn.LeakyReLU(0.2, inplace=True), 320 | nn.Linear(nd,nd//2), 321 | nn.LayerNorm(nd//2), 322 | nn.LeakyReLU(0.2, inplace=True), 323 | nn.Linear(nd//2,nd//4), 324 | nn.LayerNorm(nd//4), 325 | nn.LeakyReLU(0.2, inplace=True), 326 | #nn.Linear(nd//4, 30), 327 | nn.Linear(nd//4, 30), 328 | ) 329 | 330 | 331 | def forward(self, movements, aud): 332 | #movements = movements.view(movements.shape[0], movements.shape[1]*movements.shape[2]) 333 | m = self.movementd(movements) 334 | return m.squeeze() 335 | #a = self.audd(aud) 336 | #ma = torch.cat((m,a),1) 337 | 338 | #return self.jointd(ma).squeeze() 339 | 340 | class DanceAud_InfoDis(nn.Module): 341 | def __init__(self, aud_size, dim_z_movement, length): 342 | super(DanceAud_InfoDis, self).__init__() 343 | self.aud_size = aud_size 344 | self.dim_z_movement = dim_z_movement 345 | self.length = length 346 | nd = 1024 347 | 348 | self.movementd = nn.Sequential( 349 | nn.Linear(dim_z_movement*6, nd*2), 350 | nn.LayerNorm(nd*2), 351 | nn.LeakyReLU(0.2, inplace=True), 352 | nn.Linear(nd*2, nd), 353 | nn.LayerNorm(nd), 354 | nn.LeakyReLU(0.2, inplace=True), 355 | nn.Linear(nd,nd//2), 356 | nn.LayerNorm(nd//2), 357 | nn.LeakyReLU(0.2, inplace=True), 358 | nn.Linear(nd//2,nd//4), 359 | nn.LayerNorm(nd//4), 360 | nn.LeakyReLU(0.2, inplace=True), 361 | ) 362 | 363 | self.dis = nn.Sequential( 364 | nn.Linear(nd//4, 1) 365 | ) 366 | self.reg = nn.Sequential( 367 | nn.Linear(nd//4, aud_size) 368 | ) 369 | 370 | def forward(self, movements, aud): 371 | movements = movements.view(movements.shape[0], movements.shape[1]*movements.shape[2]) 372 | m = self.movementd(movements) 373 | return self.dis(m).squeeze(), self.reg(m) 374 | 375 | class Dance2Style(nn.Module): 376 | def __init__(self, dim_z_dance, aud_size): 377 | super(Dance2Style, self).__init__() 378 | self.aud_size = aud_size 379 | self.dim_z_dance = dim_z_dance 380 | nd = 512 381 | self.main = nn.Sequential( 382 | nn.Linear(dim_z_dance, nd), 383 | nn.LayerNorm(nd), 384 | nn.LeakyReLU(0.2, inplace=True), 385 | nn.Linear(nd, nd//2), 386 | nn.LayerNorm(nd//2), 387 | nn.LeakyReLU(0.2, inplace=True), 388 | nn.Linear(nd//2, nd//4), 389 | nn.LayerNorm(nd//4), 390 | nn.LeakyReLU(0.2, inplace=True), 391 | nn.Linear(nd//4, aud_size), 392 | ) 393 | def forward(self, zdance): 394 | return self.main(zdance) 395 | 396 | ########################################################### 397 | ########## 398 | ########## Audio 399 | ########## 400 | ########################################################### 401 | class AudioClassifier_rnn(nn.Module): 402 | def __init__(self, dim_z_motion, hidden_size, pose_size, cls, num_layers=1, h_init=2): 403 | super(AudioClassifier_rnn, self).__init__() 404 | self.dim_z_motion = dim_z_motion 405 | self.hidden_size = hidden_size 406 | self.pose_size = pose_size 407 | self.h_init = h_init 408 | self.num_layers = num_layers 409 | 410 | self.init_h = nn.Parameter(torch.randn(1, 1, self.hidden_size).type(T.FloatTensor), requires_grad=True) 411 | self.recurrent = nn.GRU(pose_size, hidden_size, num_layers=num_layers, batch_first=True) 412 | self.classifier = nn.Sequential( 413 | #nn.Dropout(p=0.2), 414 | nn.Linear(hidden_size, hidden_size), 415 | nn.ReLU(True), 416 | #nn.Dropout(p=0.2), 417 | nn.Linear(hidden_size, cls) 418 | ) 419 | def forward(self, poses): 420 | hidden, _ = 
self.recurrent(poses, self.init_h.repeat(1, poses.shape[0], 1))
421 |     last_hidden = hidden[:,-1,:]
422 |     cls = self.classifier(last_hidden)
423 |     return cls
424 |   def get_style(self, auds):
425 |     hidden, _ = self.recurrent(auds, self.init_h.repeat(1, auds.shape[0], 1))
426 |     last_hidden = hidden[:,-1,:]
427 |     return last_hidden
428 | 
429 | 
430 | class Audstyle_Enc(nn.Module):
431 |   def __init__(self, aud_size, dim_z, dim_noise=30):
432 |     super(Audstyle_Enc, self).__init__()
433 |     self.dim_noise = dim_noise
434 |     nf = 64
435 |     #nf = 32
436 |     self.enc = nn.Sequential(
437 |       nn.Linear(aud_size+dim_noise, nf),
438 |       nn.LayerNorm(nf),
439 |       nn.LeakyReLU(0.2, inplace=True),
440 |       nn.Linear(nf, nf*2),
441 |       nn.LayerNorm(nf*2),
442 |       nn.LeakyReLU(0.2, inplace=True),
443 |     )
444 |     self.mean = nn.Sequential(
445 |       nn.Linear(nf*2,dim_z),
446 |     )
447 |     self.std = nn.Sequential(
448 |       nn.Linear(nf*2,dim_z),
449 |     )
450 |   def forward(self, aud):
451 |     noise = torch.randn(aud.shape[0], self.dim_noise, device=aud.device)  # draw noise on the input's device rather than assuming CUDA
452 |     y = torch.cat((aud, noise), 1)
453 |     enc = self.enc(y)
454 |     return self.mean(enc), self.std(enc)
455 | 
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/Dancing2Music/License.txt
7 | 
8 | import argparse
9 | 
10 | 
11 | class DecompOptions():
12 |   def __init__(self):
13 |     parser = argparse.ArgumentParser()
14 | 
15 |     parser.add_argument('--name', default=None)
16 | 
17 |     parser.add_argument('--log_interval', type=int, default=50)
18 |     parser.add_argument('--log_dir', default='./logs')
19 |     parser.add_argument('--snapshot_ep', type=int, default=1)
20 |     parser.add_argument('--snapshot_dir', default='./snapshot')
21 |     parser.add_argument('--data_dir', default='./data')
22 | 
23 |     # Model architecture
24 |     parser.add_argument('--pose_size', type=int, default=28)
25 |     parser.add_argument('--dim_z_init', type=int, default=10)
26 |     parser.add_argument('--dim_z_movement', type=int, default=512)
27 |     parser.add_argument('--stdp_length', type=int, default=32)
28 |     parser.add_argument('--movement_enc_bidirection', type=int, default=1)
29 |     parser.add_argument('--movement_enc_hidden_size', type=int, default=1024)
30 |     parser.add_argument('--stdp_dec_hidden_size', type=int, default=1024)
31 |     parser.add_argument('--movement_enc_num_layers', type=int, default=1)
32 |     parser.add_argument('--stdp_dec_num_layers', type=int, default=1)
33 |     # Training
34 |     parser.add_argument('--lr', type=float, default=2e-4)
35 |     parser.add_argument('--batch_size', type=int, default=256)
36 |     parser.add_argument('--num_epochs', type=int, default=1000)
37 |     parser.add_argument('--latent_dropout', type=float, default=0.3)
38 |     parser.add_argument('--lambda_kl', type=float, default=0.01)
39 |     parser.add_argument('--lambda_initp_recon', type=float, default=1)
40 |     parser.add_argument('--lambda_initp_consistency', type=float, default=1)
41 |     parser.add_argument('--lambda_stdp_recon', type=float, default=1)
42 |     parser.add_argument('--lambda_dist_z_movement', type=float, default=1)
43 |     # Others
44 |     parser.add_argument('--num_workers', type=int, default=4)
45 |     parser.add_argument('--resume', default=None)
46 |     parser.add_argument('--dataset', type=int, default=0)
47 | 
parser.add_argument('--tolerance', action='store_true') 48 | 49 | self.parser = parser 50 | 51 | def parse(self): 52 | self.opt = self.parser.parse_args() 53 | args = vars(self.opt) 54 | return self.opt 55 | 56 | class CompOptions(): 57 | def __init__(self): 58 | parser = argparse.ArgumentParser() 59 | 60 | parser.add_argument('--name', default=None) 61 | 62 | parser.add_argument('--log_interval', type=int, default=50) 63 | parser.add_argument('--log_dir', default='./logs') 64 | parser.add_argument('--snapshot_ep', type=int, default=1) 65 | parser.add_argument('--snapshot_dir', default='./snapshot') 66 | parser.add_argument('--data_dir', default='./data') 67 | # Network architecture 68 | parser.add_argument('--pose_size', type=int, default=28) 69 | parser.add_argument('--aud_style_size', type=int, default=30) 70 | parser.add_argument('--dim_z_init', type=int, default=10) 71 | parser.add_argument('--dim_z_movement', type=int, default=512) 72 | parser.add_argument('--dim_z_dance', type=int, default=512) 73 | parser.add_argument('--stdp_length', type=int, default=32) 74 | parser.add_argument('--movement_enc_bidirection', type=int, default=1) 75 | parser.add_argument('--movement_enc_hidden_size', type=int, default=1024) 76 | parser.add_argument('--stdp_dec_hidden_size', type=int, default=1024) 77 | parser.add_argument('--movement_enc_num_layers', type=int, default=1) 78 | parser.add_argument('--stdp_dec_num_layers', type=int, default=1) 79 | parser.add_argument('--dance_enc_bidirection', type=int, default=0) 80 | parser.add_argument('--dance_enc_hidden_size', type=int, default=1024) 81 | parser.add_argument('--dance_enc_num_layers', type=int, default=1) 82 | parser.add_argument('--dance_dec_hidden_size', type=int, default=1024) 83 | parser.add_argument('--dance_dec_num_layers', type=int, default=1) 84 | # Training 85 | parser.add_argument('--lr', type=float, default=2e-4) 86 | parser.add_argument('--batch_size', type=int, default=256) 87 | parser.add_argument('--num_epochs', type=int, default=1500) 88 | parser.add_argument('--latent_dropout', type=float, default=0.3) 89 | parser.add_argument('--lambda_kl', type=float, default=0.01) 90 | parser.add_argument('--lambda_kl_dance', type=float, default=0.01) 91 | parser.add_argument('--lambda_gan', type=float, default=1) 92 | parser.add_argument('--lambda_zmovements_recon', type=float, default=1) 93 | parser.add_argument('--lambda_stdpSeq_recon', type=float, default=10) 94 | parser.add_argument('--lambda_dist_z_movement', type=float, default=1) 95 | # Other 96 | parser.add_argument('--num_workers', type=int, default=4) 97 | parser.add_argument('--decomp_snapshot', required=True) 98 | parser.add_argument('--neta_snapshot', default='./data/stats/aud_3cls.ckpt') 99 | parser.add_argument('--resume', default=None) 100 | parser.add_argument('--dataset', type=int, default=2) 101 | self.parser = parser 102 | 103 | def parse(self): 104 | self.opt = self.parser.parse_args() 105 | args = vars(self.opt) 106 | return self.opt 107 | 108 | class TestOptions(): 109 | def __init__(self): 110 | parser = argparse.ArgumentParser() 111 | 112 | parser.add_argument('--name', default=None) 113 | 114 | parser.add_argument('--log_interval', type=int, default=50) 115 | parser.add_argument('--log_dir', default='./logs') 116 | parser.add_argument('--snapshot_ep', type=int, default=1) 117 | parser.add_argument('--snapshot_dir', default='./snapshot') 118 | parser.add_argument('--data_dir', default='./data') 119 | # Network architecture 120 | parser.add_argument('--pose_size', 
type=int, default=28) 121 | parser.add_argument('--aud_style_size', type=int, default=30) 122 | parser.add_argument('--dim_z_init', type=int, default=10) 123 | parser.add_argument('--dim_z_movement', type=int, default=512) 124 | parser.add_argument('--dim_z_dance', type=int, default=512) 125 | parser.add_argument('--stdp_length', type=int, default=32) 126 | parser.add_argument('--movement_enc_bidirection', type=int, default=1) 127 | parser.add_argument('--movement_enc_hidden_size', type=int, default=1024) 128 | parser.add_argument('--stdp_dec_hidden_size', type=int, default=1024) 129 | parser.add_argument('--movement_enc_num_layers', type=int, default=1) 130 | parser.add_argument('--stdp_dec_num_layers', type=int, default=1) 131 | parser.add_argument('--dance_enc_bidirection', type=int, default=0) 132 | parser.add_argument('--dance_enc_hidden_size', type=int, default=1024) 133 | parser.add_argument('--dance_enc_num_layers', type=int, default=1) 134 | parser.add_argument('--dance_dec_hidden_size', type=int, default=1024) 135 | parser.add_argument('--dance_dec_num_layers', type=int, default=1) 136 | # Training 137 | parser.add_argument('--lr', type=float, default=2e-4) 138 | parser.add_argument('--batch_size', type=int, default=256) 139 | parser.add_argument('--num_epochs', type=int, default=1500) 140 | parser.add_argument('--latent_dropout', type=float, default=0.3) 141 | parser.add_argument('--lambda_kl', type=float, default=0.01) 142 | parser.add_argument('--lambda_kl_dance', type=float, default=0.01) 143 | parser.add_argument('--lambda_gan', type=float, default=1) 144 | parser.add_argument('--lambda_zmovements_recon', type=float, default=1) 145 | parser.add_argument('--lambda_stdpSeq_recon', type=float, default=10) 146 | parser.add_argument('--lambda_dist_z_movement', type=float, default=1) 147 | # Other 148 | parser.add_argument('--num_workers', type=int, default=4) 149 | parser.add_argument('--decomp_snapshot', required=True) 150 | parser.add_argument('--comp_snapshot', required=True) 151 | parser.add_argument('--neta_snapshot', default='./data/stats/aud_3cls.ckpt') 152 | parser.add_argument('--dataset', type=int, default=2) 153 | parser.add_argument('--thr', type=int, default=50) 154 | parser.add_argument('--aud_path', type=str, required=True) 155 | parser.add_argument('--modulate', action='store_true') 156 | parser.add_argument('--out_file', type=str, default='demo/out.mp4') 157 | parser.add_argument('--out_dir', type=str, default='demo/out_frame') 158 | self.parser = parser 159 | 160 | def parse(self): 161 | self.opt = self.parser.parse_args() 162 | args = vars(self.opt) 163 | return self.opt 164 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/Dancing2Music/License.txt
7 | 
8 | import os
9 | import argparse
10 | import functools
11 | 
12 | import numpy as np  # needed by the np.load calls in the __main__ block
13 | import torch
14 | from torch.utils.data import DataLoader
15 | from torchvision import transforms
16 | 
17 | from model_comp import *
18 | from networks import *
19 | from options import CompOptions
20 | from data import get_loader
21 | 
22 | 
23 | def loadDecompModel(args):
24 |   initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
25 |   stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,
26 |                 hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)
27 |   movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,
28 |                 hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))
29 |   checkpoint = torch.load(args.decomp_snapshot)
30 |   initp_enc.load_state_dict(checkpoint['initp_enc'])
31 |   stdp_dec.load_state_dict(checkpoint['stdp_dec'])
32 |   movement_enc.load_state_dict(checkpoint['movement_enc'])
33 |   return initp_enc, stdp_dec, movement_enc
34 | 
35 | def loadCompModel(args):
36 |   dance_enc = Dance_Enc(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
37 |             hidden_size=args.dance_enc_hidden_size, num_layers=args.dance_enc_num_layers, bidirection=(args.dance_enc_bidirection==1))
38 |   dance_dec = Dance_Dec(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
39 |             hidden_size=args.dance_dec_hidden_size, num_layers=args.dance_dec_num_layers)
40 |   audstyle_enc = Audstyle_Enc(aud_size=args.aud_style_size, dim_z=args.dim_z_dance)
41 |   dance_reg = Dance2Style(aud_size=args.aud_style_size, dim_z_dance=args.dim_z_dance)
42 |   danceAud_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_movement, length=3)
43 |   zdance_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_dance, length=1)
44 |   checkpoint = torch.load(args.resume)
45 |   dance_enc.load_state_dict(checkpoint['dance_enc'])
46 |   dance_dec.load_state_dict(checkpoint['dance_dec'])
47 |   audstyle_enc.load_state_dict(checkpoint['audstyle_enc'])
48 |   dance_reg.load_state_dict(checkpoint['dance_reg'])
49 |   danceAud_dis.load_state_dict(checkpoint['danceAud_dis'])
50 |   zdance_dis.load_state_dict(checkpoint['zdance_dis'])
51 | 
52 |   checkpoint2 = torch.load(args.neta_snapshot)
53 |   neta_cls = AudioClassifier_rnn(10,30,28,cls=3)
54 |   neta_cls.load_state_dict(checkpoint2)
55 | 
56 |   return dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls
57 | 
58 | 
59 | if __name__ == "__main__":  # kept below the loader definitions so the calls here resolve at runtime
60 |   parser = CompOptions()
61 |   args = parser.parse()
62 |   #### Pretrain network from Decomp
63 |   initp_enc, stdp_dec, movement_enc = loadDecompModel(args)
64 | 
65 |   #### Comp network
66 |   dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls = loadCompModel(args)
67 | 
68 |   mean_pose=np.load('../onbeat/all_onbeat_mean.npy')
69 |   std_pose=np.load('../onbeat/all_onbeat_std.npy')
70 |   mean_aud=np.load('../onbeat/all_aud_mean.npy')
71 |   std_aud=np.load('../onbeat/all_aud_std.npy')
72 | 
--------------------------------------------------------------------------------
/train_comp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/Dancing2Music/License.txt 7 | 8 | import os 9 | import argparse 10 | import functools 11 | 12 | import torch 13 | from torch.utils.data import DataLoader 14 | from torchvision import transforms 15 | 16 | from model_comp import * 17 | from networks import * 18 | from options import CompOptions 19 | from data import get_loader 20 | 21 | def loadDecompModel(args): 22 | initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init) 23 | stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length, 24 | hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers) 25 | movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length, 26 | hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1)) 27 | checkpoint = torch.load(args.decomp_snapshot) 28 | initp_enc.load_state_dict(checkpoint['initp_enc']) 29 | stdp_dec.load_state_dict(checkpoint['stdp_dec']) 30 | movement_enc.load_state_dict(checkpoint['movement_enc']) 31 | return initp_enc, stdp_dec, movement_enc 32 | 33 | def getCompNetworks(args): 34 | dance_enc = Dance_Enc(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement, 35 | hidden_size=args.dance_enc_hidden_size, num_layers=args.dance_enc_num_layers, bidirection=(args.dance_enc_bidirection==1)) 36 | dance_dec = Dance_Dec(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement, 37 | hidden_size=args.dance_dec_hidden_size, num_layers=args.dance_dec_num_layers) 38 | audstyle_enc = Audstyle_Enc(aud_size=args.aud_style_size, dim_z=args.dim_z_dance) 39 | dance_reg = Dance2Style(aud_size=args.aud_style_size, dim_z_dance=args.dim_z_dance) 40 | danceAud_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_movement, length=3) 41 | zdance_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_dance, length=1) 42 | 43 | checkpoint2 = torch.load(args.neta_snapshot) 44 | neta_cls = AudioClassifier_rnn(10,30,28,cls=3) 45 | neta_cls.load_state_dict(checkpoint2) 46 | 47 | return dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls 48 | 49 | if __name__ == "__main__": 50 | parser = CompOptions() 51 | args = parser.parse() 52 | 53 | args.train = True 54 | 55 | if args.name is None: 56 | args.name = 'Comp' 57 | 58 | args.log_dir = os.path.join(args.log_dir, args.name) 59 | if not os.path.exists(args.log_dir): 60 | os.mkdir(args.log_dir) 61 | args.snapshot_dir = os.path.join(args.snapshot_dir, args.name) 62 | if not os.path.exists(args.snapshot_dir): 63 | os.mkdir(args.snapshot_dir) 64 | 65 | data_loader = get_loader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, dataset=args.dataset, data_dir=args.data_dir) 66 | 67 | #### Pretrain network from Decomp 68 | initp_enc, stdp_dec, movement_enc = loadDecompModel(args) 69 | 70 | #### Comp network 71 | dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls = getCompNetworks(args) 72 | 73 | 74 | trainer = Trainer_Comp(data_loader, 75 | movement_enc = movement_enc, 76 | initp_enc = initp_enc, 77 | stdp_dec = stdp_dec, 78 | dance_enc = dance_enc, 79 | dance_dec = dance_dec, 80 | danceAud_dis = danceAud_dis, 81 | zdance_dis = zdance_dis, 82 | 
aud_enc=neta_cls, 83 | audstyle_enc=audstyle_enc, 84 | dance_reg=dance_reg, 85 | args = args 86 | ) 87 | 88 | if not args.resume is None: 89 | ep, it = trainer.resume(args.resume, True) 90 | else: 91 | ep, it = 0, 0 92 | trainer.train(ep, it) 93 | 94 | -------------------------------------------------------------------------------- /train_decomp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/Dancing2Music/License.txt 7 | 8 | import os 9 | import argparse 10 | import functools 11 | 12 | import torch 13 | from torch.utils.data import DataLoader 14 | from torchvision import transforms 15 | 16 | from model_decomp import * 17 | from networks import * 18 | from options import DecompOptions 19 | from data import get_loader 20 | 21 | def getDecompNetworks(args): 22 | initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init) 23 | initp_dec = InitPose_Dec(pose_size=args.pose_size, dim_z_init=args.dim_z_init) 24 | movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length, 25 | hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1)) 26 | stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length, 27 | hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers) 28 | return initp_enc, initp_dec, movement_enc, stdp_dec 29 | 30 | if __name__ == "__main__": 31 | parser = DecompOptions() 32 | args = parser.parse() 33 | 34 | args.train = True 35 | 36 | if args.name is None: 37 | args.name = 'Decomp' 38 | 39 | args.log_dir = os.path.join(args.log_dir, args.name) 40 | if not os.path.exists(args.log_dir): 41 | os.mkdir(args.log_dir) 42 | args.snapshot_dir = os.path.join(args.snapshot_dir, args.name) 43 | if not os.path.exists(args.snapshot_dir): 44 | os.mkdir(args.snapshot_dir) 45 | 46 | data_loader = get_loader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, dataset=args.dataset, data_dir=args.data_dir, tolerance=args.tolerance) 47 | 48 | initp_enc, initp_dec, movement_enc, stdp_dec = getDecompNetworks(args) 49 | 50 | trainer = Trainer_Decomp(data_loader, 51 | initp_enc = initp_enc, 52 | initp_dec = initp_dec, 53 | movement_enc = movement_enc, 54 | stdp_dec = stdp_dec, 55 | args = args 56 | ) 57 | 58 | if not args.resume is None: 59 | ep, it = trainer.resume(args.resume, False) 60 | else: 61 | ep, it = 0, 0 62 | 63 | trainer.train(ep=ep, it=it) 64 | 65 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/Dancing2Music/License.txt 7 | 8 | import numpy as np 9 | import pickle 10 | import cv2 11 | import math 12 | import os 13 | import random 14 | import tensorflow as tf 15 | 16 | class Logger(object): 17 | def __init__(self, log_dir): 18 | self.writer = tf.summary.FileWriter(log_dir) 19 | 20 | def scalar_summary(self, tag, value, step): 21 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 22 | self.writer.add_summary(summary, step) 23 | 24 | def vis(poses, outdir, aud=None): 25 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ 26 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ 27 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 28 | 29 | # find connection in the specified sequence, center 29 is in the position 15 30 | limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \ 31 | [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \ 32 | [1,16], [16,18], [3,17], [6,18]] 33 | 34 | neglect = [14,15,16,17] 35 | 36 | for t in range(poses.shape[0]): 37 | #break 38 | canvas = np.ones((256,500,3), np.uint8)*255 39 | 40 | thisPeak = poses[t] 41 | for i in range(18): 42 | if i in neglect: 43 | continue 44 | if thisPeak[i,0] == -1: 45 | continue 46 | cv2.circle(canvas, tuple(thisPeak[i,0:2].astype(int)), 4, colors[i], thickness=-1) 47 | 48 | for i in range(17): 49 | limbid = np.array(limbSeq[i])-1 50 | if limbid[0] in neglect or limbid[1] in neglect: 51 | continue 52 | X = thisPeak[[limbid[0],limbid[1]], 1] 53 | Y = thisPeak[[limbid[0],limbid[1]], 0] 54 | if X[0] == -1 or Y[0]==-1 or X[1]==-1 or Y[1]==-1: 55 | continue 56 | stickwidth = 4 57 | cur_canvas = canvas.copy() 58 | mX = np.mean(X) 59 | mY = np.mean(Y) 60 | #print(X, Y, limbid) 61 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 62 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 63 | polygon = cv2.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1) 64 | #print(i, n, int(mY), int(mX), limbid, X, Y) 65 | cv2.fillConvexPoly(cur_canvas, polygon, colors[i]) 66 | canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) 67 | if aud is not None: 68 | if aud[:,t] == 1: 69 | cv2.circle(canvas, (30, 30), 20, (0,0,255), -1) 70 | #canvas = cv2.copyMakeBorder(canvas,10,10,10,10,cv2.BORDER_CONSTANT,value=[255,0,0]) 71 | cv2.imwrite(os.path.join(outdir, 'frame{0:03d}.png'.format(t)),canvas) 72 | 73 | def vis2(poses, outdir, fibeat): 74 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ 75 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ 76 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 77 | 78 | # find connection in the specified sequence, center 29 is in the position 15 79 | limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \ 80 | [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \ 81 | [1,16], [16,18], [3,17], [6,18]] 82 | 83 | 84 | neglect = [14,15,16,17] 85 | 86 | ibeat = cv2.imread(fibeat); 87 | ibeat = cv2.resize(ibeat, (500,200)) 88 | 89 | for t in range(poses.shape[0]): 90 | subibeat = ibeat.copy() 91 | canvas = np.ones((256+200,500,3), np.uint8)*255 92 | canvas[256:,:,:] = subibeat 93 | 94 | overlay = canvas.copy() 95 | cv2.rectangle(overlay, 
(int(500/poses.shape[0]*(t+1)),256),(500,256+200), (100,100,100), -1) 96 | cv2.addWeighted(overlay, 0.4, canvas, 1-0.4, 0, canvas) 97 | thisPeak = poses[t] 98 | for i in range(18): 99 | if i in neglect: 100 | continue 101 | if thisPeak[i,0] == -1: 102 | continue 103 | cv2.circle(canvas, tuple(thisPeak[i,0:2].astype(int)), 4, colors[i], thickness=-1) 104 | 105 | for i in range(17): 106 | limbid = np.array(limbSeq[i])-1 107 | if limbid[0] in neglect or limbid[1] in neglect: 108 | continue 109 | X = thisPeak[[limbid[0],limbid[1]], 1] 110 | Y = thisPeak[[limbid[0],limbid[1]], 0] 111 | if X[0] == -1 or Y[0]==-1 or X[1]==-1 or Y[1]==-1: 112 | continue 113 | stickwidth = 4 114 | cur_canvas = canvas.copy() 115 | mX = np.mean(X) 116 | mY = np.mean(Y) 117 | #print(X, Y, limbid) 118 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 119 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 120 | polygon = cv2.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1) 121 | #print(i, n, int(mY), int(mX), limbid, X, Y) 122 | cv2.fillConvexPoly(cur_canvas, polygon, colors[i]) 123 | canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) 124 | cv2.imwrite(os.path.join(outdir, 'frame{0:03d}.png'.format(t)),canvas) 125 | 126 | def vis_single(pose, outfile): 127 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ 128 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ 129 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 130 | 131 | # find connection in the specified sequence, center 29 is in the position 15 132 | limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \ 133 | [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \ 134 | [1,16], [16,18], [3,17], [6,18]] 135 | 136 | neglect = [14,15,16,17] 137 | 138 | for t in range(1): 139 | #break 140 | canvas = np.ones((256,500,3), np.uint8)*255 141 | 142 | thisPeak = pose 143 | for i in range(18): 144 | if i in neglect: 145 | continue 146 | if thisPeak[i,0] == -1: 147 | continue 148 | cv2.circle(canvas, tuple(thisPeak[i,0:2].astype(int)), 4, colors[i], thickness=-1) 149 | 150 | for i in range(17): 151 | limbid = np.array(limbSeq[i])-1 152 | if limbid[0] in neglect or limbid[1] in neglect: 153 | continue 154 | X = thisPeak[[limbid[0],limbid[1]], 1] 155 | Y = thisPeak[[limbid[0],limbid[1]], 0] 156 | if X[0] == -1 or Y[0]==-1 or X[1]==-1 or Y[1]==-1: 157 | continue 158 | stickwidth = 4 159 | cur_canvas = canvas.copy() 160 | mX = np.mean(X) 161 | mY = np.mean(Y) 162 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 163 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 164 | polygon = cv2.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1) 165 | cv2.fillConvexPoly(cur_canvas, polygon, colors[i]) 166 | canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) 167 | cv2.imwrite(outfile,canvas) 168 | --------------------------------------------------------------------------------
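A note on wiring these pieces together: the repository tree lists a demo.py entry point that is not included in this dump, so the following is an editor's sketch, not the authors' code. It assumes a generated pose sequence `dance` of shape (T, 14, 2) in the canvas pixel coordinates that `utils.vis` draws, the 15-fps timing hard-coded in modulate.py (`sec_interframe = 1/15`), and stock librosa beat tracking; the function name `render_to_beats` is hypothetical.

import numpy as np
import librosa

from modulate import modulate
import utils

FPS = 15  # modulate.py operates at 15 frames per second

def render_to_beats(dance, wav_path, out_dir):
  # dance: (T, 14, 2) generated pose sequence in canvas (pixel) coordinates
  y, sr = librosa.load(wav_path)
  _, beat_frames = librosa.beat.beat_track(y=y, sr=sr)     # beat positions as librosa frame indices
  beat_times = librosa.frames_to_time(beat_frames, sr=sr)  # beat positions in seconds
  beats = beat_times * FPS                                 # beat positions as 15-fps video frame indices
  length = int(np.ceil(librosa.get_duration(y=y, sr=sr) * FPS))
  aligned = modulate(dance, beats, length)                 # warp the dance's kinematic beats onto the music beats
  utils.vis(aligned, out_dir)                              # writes frame000.png, frame001.png, ... into out_dir

The beat-to-frame conversion is the only real coupling between the audio and pose sides: modulate() rounds the beat positions to integers itself, so passing the float frame indices straight through is fine.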