├── utils ├── __init__.py ├── paramUtil.py ├── get_opt.py ├── other_tools.py ├── utils.py ├── metrics.py └── quaternion.py ├── datasets ├── pymo │ ├── __init__.py │ ├── mocapplayer │ │ ├── data-template.js │ │ ├── libs │ │ │ ├── threejs │ │ │ │ └── Detector.js │ │ │ ├── pace.min.js │ │ │ └── papaparse.min.js │ │ ├── styles │ │ │ └── pace.css │ │ ├── js │ │ │ └── skeletonFactory.js │ │ └── playURL.html │ ├── features.py │ ├── data.py │ ├── writers.py │ ├── rotation_tools.py │ ├── viz_tools.py │ └── parsers.py ├── __init__.py ├── extract_hubert.py ├── dataloader.py └── show.py ├── assets ├── data.tar.gz ├── beat_visualize.blend ├── teaser_for_demo_cvpr.png ├── requirements.txt └── environment.yml ├── audios ├── Forrest_tts.wav └── 2_scott_0_3_3.wav ├── trainers ├── __init__.py └── loss_factory.py ├── models ├── __init__.py ├── respace.py ├── motion_autoencoder.py ├── scheduler.py └── ddpm_utils.py ├── .gitignore ├── inference_custom_audio_beat.sh ├── inference_custom_audio_show.sh ├── LICENSE ├── options ├── evaluate_options.py ├── train_options.py └── base_options.py ├── train_test_scripts.sh ├── README.md └── runner.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/pymo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/data.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyCJM/DiffSHEG/HEAD/assets/data.tar.gz -------------------------------------------------------------------------------- /audios/Forrest_tts.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyCJM/DiffSHEG/HEAD/audios/Forrest_tts.wav -------------------------------------------------------------------------------- /datasets/pymo/mocapplayer/data-template.js: -------------------------------------------------------------------------------- 1 | var dataBuffer = `$$DATA$$`; 2 | 3 | start(dataBuffer); -------------------------------------------------------------------------------- /audios/2_scott_0_3_3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyCJM/DiffSHEG/HEAD/audios/2_scott_0_3_3.wav -------------------------------------------------------------------------------- /assets/beat_visualize.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyCJM/DiffSHEG/HEAD/assets/beat_visualize.blend -------------------------------------------------------------------------------- /assets/teaser_for_demo_cvpr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyCJM/DiffSHEG/HEAD/assets/teaser_for_demo_cvpr.png -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .ddpm_beat_trainer import DDPMTrainer_beat 2 | from .ddpm_show_trainer import DDPMTrainer_show 3 | 4 | 5 | __all__ = ['DDPMTrainer_beat', 'DDPMTrainer_show'] -------------------------------------------------------------------------------- /datasets/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .show import ShowDataset 2 | from .beat import BeatDataset 3 | from .dataloader import build_dataloader 4 | 5 | 6 | 7 | __all__ = [ 8 | 'ShowDataset', 'BeatDataset', 'build_dataloader'] -------------------------------------------------------------------------------- /assets/requirements.txt: -------------------------------------------------------------------------------- 1 | pyarrow==3.0.0 2 | mmcv-full==1.6.0 3 | matplotlib==3.7.1 4 | scipy==1.10.1 5 | lmdb==0.96 6 | termcolor==2.1.1 7 | librosa==0.9.2 8 | loguru==0.6.0 9 | umap==0.1.1 10 | wandb==0.13.10 11 | transformers==4.31.0 12 | pandas==1.4.2 13 | IPython -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import MotionTransformer, UniDiffuser 2 | from .gaussian_diffusion import GaussianDiffusion 3 | # from .respace import GaussianDiffusion 4 | from .scheduler import get_schedule_jump, get_schedule_jump_paper 5 | 6 | __all__ = ['MotionTransformer', 'UniDiffuser', 'GaussianDiffusion', 'get_schedule_jump', 'get_schedule_jump_paper'] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints 2 | results 3 | wandb 4 | SMPLX_NEUTRAL_2020.npz 5 | rendered 6 | data 7 | rendering 8 | temp 9 | 10 | __pycache__ 11 | *.pyc 12 | 13 | motion_ae_fid/logs 14 | test.sh 15 | train.sh 16 | 17 | videos 18 | rebuttal 19 | 20 | test* 21 | 22 | ## BEAT_ori 23 | 0_BEAT_ori/outputs/* 24 | 0_BEAT_ori/datasets 25 | 0_BEAT_ori/codes/audio2pose/docs/* 26 | 0_BEAT_ori/codes/audio2pose/*.sh 27 | 28 | 29 | ## TalkSHOW 30 | *.pkl 31 | *.zip 32 | *.mp4 33 | A_TalkSHOW_ori/.idea 34 | A_TalkSHOW_ori/data 35 | A_TalkSHOW_ori/experiments 36 | A_TalkSHOW_ori/visualise/smplx 37 | A_TalkSHOW_ori/visualise/video 38 | A_TalkSHOW_ori/visualise/mesh_vertices 39 | A_TalkSHOW_ori/visualise/mouth_keypoints 40 | A_TalkSHOW_ori/demo_audio 41 | A_TalkSHOW_ori/demo 42 | A_TalkSHOW_ori/results 43 | 44 | 45 | 46 | # SpeechSplit 47 | B_SpeechSplit/assets 48 | B_SpeechSplit/results* 49 | 50 | -------------------------------------------------------------------------------- /inference_custom_audio_beat.sh: -------------------------------------------------------------------------------- 1 | ## For 25+ FPS on A100 2 | # PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 3 | # OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0 python -u runner.py \ 4 | # --dataset_name beat \ 5 | # --name beat_GesExpr_unify_addHubert_encodeHubert_mlpIncludeX_condRes_LN \ 6 | # --n_poses 34 \ 7 | # --ddim \ 8 | # --ckpt fgd_best.tar \ 9 | # --ddim \ 10 | # --timestep_respacing ddim25 \ 11 | # --overlap_len 4 \ 12 | # --mode test_custom_audio \ 13 | # --test_audio_path audios/2_scott_0_3_3.wav 14 | 15 | ## For 30+ FPS on 3090; For 55+ FPS on A100 16 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 17 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0 python -u runner.py \ 18 | --dataset_name beat \ 19 | --name beat_GesExpr_unify_addHubert_encodeHubert_mlpIncludeX_condRes_LN \ 20 | --n_poses 34 \ 21 | --ddim \ 22 | --ckpt fgd_best.tar \ 23 | --ddim \ 24 | --timestep_respacing ddim25 \ 25 | --overlap_len 4 \ 26 | --mode test_custom_audio \ 27 | --jump_n_sample 2 \ 28 | --test_audio_path audios/2_scott_0_3_3.wav 29 | 
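##
## Note (added for illustration, not part of the original script): --timestep_respacing ddim25 makes
## space_timesteps() in models/respace.py keep 25 evenly strided steps out of the default 1000
## diffusion steps (stride 40, i.e. timesteps 0, 40, 80, ..., 960). --jump_n_sample appears to thin the
## resampling schedule from models/scheduler.py (get_schedule_jump) for extra speed. A quick sanity
## check of the respacing, assuming it is run from the repo root:
# python -c "from models.respace import space_timesteps; print(sorted(space_timesteps(1000, 'ddim25')))"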
-------------------------------------------------------------------------------- /inference_custom_audio_show.sh: -------------------------------------------------------------------------------- 1 | # For 50+ FPS on A100 2 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 3 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0 python -u runner.py \ 4 | --dataset_name talkshow \ 5 | --name talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree \ 6 | --n_poses 88 \ 7 | --model_base transformer_encoder \ 8 | --classifier_free \ 9 | --cond_scale 1.15 \ 10 | --ckpt ckpt_e2599.tar \ 11 | --ddim \ 12 | --timestep_respacing ddim25 \ 13 | --overlap_len 10 \ 14 | --mode test_custom_audio \ 15 | --test_audio_path audios/Forrest_tts.wav 16 | 17 | 18 | ## For 120+ FPS on A100 19 | # PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 20 | # OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0 python -u runner.py \ 21 | # --dataset_name talkshow \ 22 | # --name talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree \ 23 | # --n_poses 88 \ 24 | # --model_base transformer_encoder \ 25 | # --classifier_free \ 26 | # --cond_scale 1.15 \ 27 | # --ckpt ckpt_e2599.tar \ 28 | # --ddim \ 29 | # --timestep_respacing ddim25 \ 30 | # --overlap_len 10 \ 31 | # --mode test_custom_audio \ 32 | # --jump_n_sample 2 \ 33 | # --test_audio_path audios/Forrest_tts.wav -------------------------------------------------------------------------------- /datasets/pymo/features.py: -------------------------------------------------------------------------------- 1 | ''' 2 | A set of mocap feature extraction functions 3 | 4 | Created by Omid Alemi | Nov 17 2017 5 | 6 | ''' 7 | import numpy as np 8 | import pandas as pd 9 | import peakutils 10 | import matplotlib.pyplot as plt 11 | 12 | def get_foot_contact_idxs(signal, t=0.02, min_dist=120): 13 | up_idxs = peakutils.indexes(signal, thres=t/max(signal), min_dist=min_dist) 14 | down_idxs = peakutils.indexes(-signal, thres=t/min(signal), min_dist=min_dist) 15 | 16 | return [up_idxs, down_idxs] 17 | 18 | 19 | def create_foot_contact_signal(mocap_track, col_name, start=1, t=0.02, min_dist=120): 20 | signal = mocap_track.values[col_name].values 21 | idxs = get_foot_contact_idxs(signal, t, min_dist) 22 | 23 | step_signal = [] 24 | 25 | c = start 26 | for f in range(len(signal)): 27 | if f in idxs[1]: 28 | c = 0 29 | elif f in idxs[0]: 30 | c = 1 31 | 32 | step_signal.append(c) 33 | 34 | return step_signal 35 | 36 | def plot_foot_up_down(mocap_track, col_name, t=0.02, min_dist=120): 37 | 38 | signal = mocap_track.values[col_name].values 39 | idxs = get_foot_contact_idxs(signal, t, min_dist) 40 | 41 | plt.plot(mocap_track.values.index, signal) 42 | plt.plot(mocap_track.values.index[idxs[0]], signal[idxs[0]], 'ro') 43 | plt.plot(mocap_track.values.index[idxs[1]], signal[idxs[1]], 'go') 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, JeremyCJM 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /datasets/pymo/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Joint(): 4 | def __init__(self, name, parent=None, children=None): 5 | self.name = name 6 | self.parent = parent 7 | self.children = children 8 | 9 | class MocapData(): 10 | def __init__(self): 11 | self.skeleton = {} 12 | self.values = None 13 | self.channel_names = [] 14 | self.framerate = 0.0 15 | self.root_name = '' 16 | 17 | def traverse(self, j=None): 18 | stack = [self.root_name] 19 | while stack: 20 | joint = stack.pop() 21 | yield joint 22 | for c in self.skeleton[joint]['children']: 23 | stack.append(c) 24 | 25 | def clone(self): 26 | import copy 27 | new_data = MocapData() 28 | new_data.skeleton = copy.copy(self.skeleton) 29 | new_data.values = copy.copy(self.values) 30 | new_data.channel_names = copy.copy(self.channel_names) 31 | new_data.root_name = copy.copy(self.root_name) 32 | new_data.framerate = copy.copy(self.framerate) 33 | return new_data 34 | 35 | def get_all_channels(self): 36 | '''Returns all of the channels parsed from the file as a 2D numpy array''' 37 | 38 | frames = [f[1] for f in self.values] 39 | return np.asarray([[channel[2] for channel in frame] for frame in frames]) 40 | 41 | def get_skeleton_tree(self): 42 | tree = [] 43 | root_key = [j for j in self.skeleton if self.skeleton[j]['parent']==None][0] 44 | 45 | root_joint = Joint(root_key) 46 | 47 | def get_empty_channels(self): 48 | #TODO 49 | pass 50 | 51 | def get_constant_channels(self): 52 | #TODO 53 | pass 54 | -------------------------------------------------------------------------------- /options/evaluate_options.py: -------------------------------------------------------------------------------- 1 | from options.base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.parser.add_argument('--batch_size', type=int, default=1, help='Batch size') 8 | self.parser.add_argument('--start_mov_len', type=int, default=10) 9 | self.parser.add_argument('--est_length', action="store_true", help="Whether to use sampled motion length") 10 | self.parser.add_argument('--num_layers', type=int, 
default=8, help='num_layers of transformer') 11 | self.parser.add_argument('--latent_dim', type=int, default=512, help='latent_dim of transformer') 12 | self.parser.add_argument('--diffusion_steps', type=int, default=1000, help='diffusion_steps of transformer') 13 | self.parser.add_argument('--no_clip', action='store_true', help='whether use clip pretrain') 14 | self.parser.add_argument('--no_eff', action='store_true', help='whether use efficient attention') 15 | 16 | 17 | self.parser.add_argument('--repeat_times', type=int, default=3, help="Number of generation rounds for each text description") 18 | self.parser.add_argument('--split_file', type=str, default='test.txt') 19 | self.parser.add_argument('--text', type=str, default="", help='Text description for motion generation') 20 | self.parser.add_argument('--motion_length', type=int, default=0, help='Number of framese for motion generation') 21 | self.parser.add_argument('--text_file', type=str, default="", help='Path of text description for motion generation') 22 | self.parser.add_argument('--which_epoch', type=str, default="latest", help='Checkpoint that will be used') 23 | self.parser.add_argument('--result_path', type=str, default="./eval_results/", help='Path to save generation results') 24 | self.parser.add_argument('--num_results', type=int, default=40, help='Number of descriptions that will be used') 25 | self.parser.add_argument('--ext', type=str, default='default', help='Save file path extension') 26 | 27 | self.is_train = False 28 | -------------------------------------------------------------------------------- /utils/paramUtil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Define a kinematic tree for the skeletal struture 4 | kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] 5 | 6 | kit_raw_offsets = np.array( 7 | [ 8 | [0, 0, 0], 9 | [0, 1, 0], 10 | [0, 1, 0], 11 | [0, 1, 0], 12 | [0, 1, 0], 13 | [1, 0, 0], 14 | [0, -1, 0], 15 | [0, -1, 0], 16 | [-1, 0, 0], 17 | [0, -1, 0], 18 | [0, -1, 0], 19 | [1, 0, 0], 20 | [0, -1, 0], 21 | [0, -1, 0], 22 | [0, 0, 1], 23 | [0, 0, 1], 24 | [-1, 0, 0], 25 | [0, -1, 0], 26 | [0, -1, 0], 27 | [0, 0, 1], 28 | [0, 0, 1] 29 | ] 30 | ) 31 | 32 | t2m_raw_offsets = np.array([[0,0,0], 33 | [1,0,0], 34 | [-1,0,0], 35 | [0,1,0], 36 | [0,-1,0], 37 | [0,-1,0], 38 | [0,1,0], 39 | [0,-1,0], 40 | [0,-1,0], 41 | [0,1,0], 42 | [0,0,1], 43 | [0,0,1], 44 | [0,1,0], 45 | [1,0,0], 46 | [-1,0,0], 47 | [0,0,1], 48 | [0,-1,0], 49 | [0,-1,0], 50 | [0,-1,0], 51 | [0,-1,0], 52 | [0,-1,0], 53 | [0,-1,0]]) 54 | 55 | t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] 56 | t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] 57 | t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] 58 | 59 | 60 | kit_tgt_skel_id = '03950' 61 | 62 | t2m_tgt_skel_id = '000021' 63 | 64 | -------------------------------------------------------------------------------- /datasets/pymo/writers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | class BVHWriter(): 5 | def __init__(self): 6 | pass 7 | 8 | def write(self, X, ofile): 9 | 10 | # Writing the skeleton info 11 | ofile.write('HIERARCHY\n') 12 | 13 | self.motions_ = [] 14 | 
self._printJoint(X, X.root_name, 0, ofile) 15 | 16 | # Writing the motion header 17 | ofile.write('MOTION\n') 18 | ofile.write('Frames: %d\n'%X.values.shape[0]) 19 | ofile.write('Frame Time: %f\n'%X.framerate) 20 | 21 | # Writing the data 22 | self.motions_ = np.asarray(self.motions_).T 23 | lines = [" ".join(item) for item in self.motions_.astype(str)] 24 | ofile.write("".join("%s\n"%l for l in lines)) 25 | 26 | def _printJoint(self, X, joint, tab, ofile): 27 | 28 | if X.skeleton[joint]['parent'] == None: 29 | ofile.write('ROOT %s\n'%joint) 30 | elif len(X.skeleton[joint]['children']) > 0: 31 | ofile.write('%sJOINT %s\n'%('\t'*(tab), joint)) 32 | else: 33 | ofile.write('%sEnd site\n'%('\t'*(tab))) 34 | 35 | ofile.write('%s{\n'%('\t'*(tab))) 36 | 37 | ofile.write('%sOFFSET %3.5f %3.5f %3.5f\n'%('\t'*(tab+1), 38 | X.skeleton[joint]['offsets'][0], 39 | X.skeleton[joint]['offsets'][1], 40 | X.skeleton[joint]['offsets'][2])) 41 | channels = X.skeleton[joint]['channels'] 42 | n_channels = len(channels) 43 | 44 | if n_channels > 0: 45 | for ch in channels: 46 | self.motions_.append(np.asarray(X.values['%s_%s'%(joint, ch)].values)) 47 | 48 | if len(X.skeleton[joint]['children']) > 0: 49 | ch_str = ''.join(' %s'*n_channels%tuple(channels)) 50 | ofile.write('%sCHANNELS %d%s\n' %('\t'*(tab+1), n_channels, ch_str)) 51 | 52 | for c in X.skeleton[joint]['children']: 53 | self._printJoint(X, c, tab+1, ofile) 54 | 55 | ofile.write('%s}\n'%('\t'*(tab))) 56 | -------------------------------------------------------------------------------- /trainers/loss_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) HuaWei, Inc. and its affiliates. 2 | # liu.haiyang@huawei.com 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch 7 | import numpy as np 8 | 9 | 10 | class BCE_Loss(nn.Module): 11 | def __init__(self, args=None): 12 | super(BCE_Loss, self).__init__() 13 | 14 | def forward(self, fake_outputs, real_target): 15 | final_loss = F.cross_entropy(fake_outputs, real_target, reduce="mean") 16 | return final_loss 17 | 18 | 19 | class HuberLoss(nn.Module): 20 | def __init__(self, beta=0.1, reduction="mean"): 21 | super(HuberLoss, self).__init__() 22 | self.beta = beta 23 | self.reduction = reduction 24 | 25 | def forward(self, outputs, targets): 26 | final_loss = F.smooth_l1_loss(outputs / self.beta, targets / self.beta, reduction=self.reduction) * self.beta 27 | return final_loss 28 | 29 | 30 | class KLDLoss(nn.Module): 31 | def __init__(self, beta=0.1): 32 | super(KLDLoss, self).__init__() 33 | self.beta = beta 34 | 35 | def forward(self, outputs, targets): 36 | final_loss = F.smooth_l1_loss((outputs / self.beta, targets / self.beta) * self.beta) 37 | return final_loss 38 | 39 | 40 | class REGLoss(nn.Module): 41 | def __init__(self, beta=0.1): 42 | super(REGLoss, self).__init__() 43 | self.beta = beta 44 | 45 | def forward(self, outputs, targets): 46 | final_loss = F.smooth_l1_loss((outputs / self.beta, targets / self.beta) * self.beta) 47 | return final_loss 48 | 49 | 50 | class L2Loss(nn.Module): 51 | def __init__(self): 52 | super(L2Loss, self).__init__() 53 | 54 | def forward(self, outputs, targets): 55 | final_loss = F.l2_loss(outputs, targets) 56 | return final_loss 57 | 58 | LOSS_FUNC_LUT = { 59 | "bce_loss": BCE_Loss, 60 | "l2_loss": L2Loss, 61 | "huber_loss": HuberLoss, 62 | "kl_loss": KLDLoss, 63 | "id_loss": REGLoss, 64 | } 65 | 66 | 67 | def get_loss_func(loss_name, **kwargs): 68 | loss_func_class = 
LOSS_FUNC_LUT.get(loss_name) 69 | loss_func = loss_func_class(**kwargs) 70 | return loss_func 71 | 72 | 73 | -------------------------------------------------------------------------------- /train_test_scripts.sh: -------------------------------------------------------------------------------- 1 | ################################# BEAT ################################# 2 | 3 | ### Train on BEAT. 4 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 5 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0,1,2,3,4 python -u runner.py \ 6 | --dataset_name beat \ 7 | --name beat_diffsheg \ 8 | --batch_size 2500 \ 9 | --num_epochs 1000 \ 10 | --save_every_e 20 \ 11 | --eval_every_e 40 \ 12 | --n_poses 34 \ 13 | --ddim \ 14 | --multiprocessing-distributed \ 15 | --dist-url 'tcp://127.0.0.1:6666' 16 | 17 | 18 | 19 | ### Test on BEAT. 20 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 21 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -u runner.py \ 22 | --dataset_name beat \ 23 | --name beat_GesExpr_unify_addHubert_encodeHubert_mlpIncludeX_condRes_LN \ 24 | --n_poses 34 \ 25 | --multiprocessing-distributed \ 26 | --dist-url 'tcp://127.0.0.1:8888' \ 27 | --ckpt fgd_best.tar \ 28 | --mode test_arbitrary_len \ 29 | --ddim \ 30 | --timestep_respacing ddim25 \ 31 | --overlap_len 4 32 | 33 | 34 | ################################# SHOW ################################# 35 | 36 | ### Train on SHOW. 37 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 38 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=1,2,3,4,5 python -u runner.py \ 39 | --dataset_name talkshow \ 40 | --name talkshow_diffsheg \ 41 | --batch_size 950 \ 42 | --num_epochs 4000 \ 43 | --save_every_e 20 \ 44 | --eval_every_e 40 \ 45 | --n_poses 88 \ 46 | --classifier_free \ 47 | --multiprocessing-distributed \ 48 | --dist-url 'tcp://127.0.0.1:6667' \ 49 | --ddim \ 50 | --max_eval_samples 200 51 | 52 | 53 | 54 | ### Test on SHOW. 
55 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 56 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -u runner.py \ 57 | --dataset_name talkshow \ 58 | --name talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree \ 59 | --PE pe_sinu \ 60 | --n_poses 88 \ 61 | --multiprocessing-distributed \ 62 | --dist-url 'tcp://127.0.0.1:8889' \ 63 | --classifier_free \ 64 | --cond_scale 1.25 \ 65 | --ckpt ckpt_e2599.tar \ 66 | --mode test_arbitrary_len \ 67 | --ddim \ 68 | --timestep_respacing ddim25 \ 69 | --overlap_len 10 70 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from options.base_options import BaseOptions 2 | import argparse 3 | 4 | class TrainCompOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.parser.add_argument('--num_layers', type=int, default=8, help='num_layers of transformer') 8 | self.parser.add_argument('--latent_dim', type=int, default=512, help='latent_dim of transformer') 9 | self.parser.add_argument('--diffusion_steps', type=int, default=1000, help='diffusion_steps of transformer') 10 | self.parser.add_argument('--no_clip', action='store_true', help='whether use clip pretrain') 11 | self.parser.add_argument('--no_eff', action='store_true', help='whether use efficient attention') 12 | 13 | self.parser.add_argument('--num_epochs', type=int, default=5000, help='Number of epochs') 14 | self.parser.add_argument('--lr', type=float, default=2e-4, help='Learning rate') 15 | self.parser.add_argument('--reset_lr', action='store_true', help='Reset the optimizer lr to args.lr after resume from a ckpt') 16 | self.parser.add_argument('--batch_size', type=int, default=32, help='Batch size per GPU') 17 | self.parser.add_argument('--times', type=int, default=1, help='times of dataset') 18 | 19 | self.parser.add_argument('--feat_bias', type=float, default=5, help='Scales for global motion features and foot contact') 20 | 21 | self.parser.add_argument('--resume', action="store_true", help='Is this trail continued from previous trail?') 22 | 23 | self.parser.add_argument('--log_every', type=int, default=50, help='Frequency of printing training progress (by iteration)') 24 | self.parser.add_argument('--save_every_e', type=int, default=5, help='Frequency of saving models (by epoch)') 25 | self.parser.add_argument('--eval_every_e', type=int, default=5, help='Frequency of animation results (by epoch)') 26 | self.parser.add_argument('--save_latest', type=int, default=500, help='Frequency of saving models (by iteration)') 27 | 28 | 29 | self.parser.add_argument('--ckpt', type=str, default='latest.tar', help='choose which checkpoint to use') 30 | # self.parser.add_argument('--max_motion_length', type=int, default=200, help='max_motion_length') 31 | 32 | self.is_train = True 33 | -------------------------------------------------------------------------------- /datasets/pymo/mocapplayer/libs/threejs/Detector.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @author alteredq / http://alteredqualia.com/ 3 | * @author mr.doob / http://mrdoob.com/ 4 | */ 5 | 6 | var Detector = { 7 | 8 | canvas: !! window.CanvasRenderingContext2D, 9 | webgl: ( function () { 10 | 11 | try { 12 | 13 | var canvas = document.createElement( 'canvas' ); return !! 
( window.WebGLRenderingContext && ( canvas.getContext( 'webgl' ) || canvas.getContext( 'experimental-webgl' ) ) ); 14 | 15 | } catch ( e ) { 16 | 17 | return false; 18 | 19 | } 20 | 21 | } )(), 22 | workers: !! window.Worker, 23 | fileapi: window.File && window.FileReader && window.FileList && window.Blob, 24 | 25 | getWebGLErrorMessage: function () { 26 | 27 | var element = document.createElement( 'div' ); 28 | element.id = 'webgl-error-message'; 29 | element.style.fontFamily = 'monospace'; 30 | element.style.fontSize = '13px'; 31 | element.style.fontWeight = 'normal'; 32 | element.style.textAlign = 'center'; 33 | element.style.background = '#fff'; 34 | element.style.color = '#000'; 35 | element.style.padding = '1.5em'; 36 | element.style.width = '400px'; 37 | element.style.margin = '5em auto 0'; 38 | 39 | if ( ! this.webgl ) { 40 | 41 | element.innerHTML = window.WebGLRenderingContext ? [ 42 | 'Your graphics card does not seem to support WebGL.
', 43 | 'Find out how to get it here.' 44 | ].join( '\n' ) : [ 45 | 'Your browser does not seem to support WebGL.
', 46 | 'Find out how to get it here.' 47 | ].join( '\n' ); 48 | 49 | } 50 | 51 | return element; 52 | 53 | }, 54 | 55 | addGetWebGLMessage: function ( parameters ) { 56 | 57 | var parent, id, element; 58 | 59 | parameters = parameters || {}; 60 | 61 | parent = parameters.parent !== undefined ? parameters.parent : document.body; 62 | id = parameters.id !== undefined ? parameters.id : 'oldie'; 63 | 64 | element = Detector.getWebGLErrorMessage(); 65 | element.id = id; 66 | 67 | parent.appendChild( element ); 68 | 69 | } 70 | 71 | }; 72 | 73 | // browserify support 74 | if ( typeof module === 'object' ) { 75 | 76 | module.exports = Detector; 77 | 78 | } -------------------------------------------------------------------------------- /datasets/pymo/mocapplayer/styles/pace.css: -------------------------------------------------------------------------------- 1 | .pace { 2 | -webkit-pointer-events: none; 3 | pointer-events: none; 4 | -webkit-user-select: none; 5 | -moz-user-select: none; 6 | user-select: none; 7 | } 8 | 9 | .pace-inactive { 10 | display: none; 11 | } 12 | 13 | .pace .pace-progress { 14 | background: #29d; 15 | position: fixed; 16 | z-index: 2000; 17 | top: 0; 18 | right: 100%; 19 | width: 100%; 20 | height: 2px; 21 | } 22 | 23 | .pace .pace-progress-inner { 24 | display: block; 25 | position: absolute; 26 | right: 0px; 27 | width: 100px; 28 | height: 100%; 29 | box-shadow: 0 0 10px #29d, 0 0 5px #29d; 30 | opacity: 1.0; 31 | -webkit-transform: rotate(3deg) translate(0px, -4px); 32 | -moz-transform: rotate(3deg) translate(0px, -4px); 33 | -ms-transform: rotate(3deg) translate(0px, -4px); 34 | -o-transform: rotate(3deg) translate(0px, -4px); 35 | transform: rotate(3deg) translate(0px, -4px); 36 | } 37 | 38 | .pace .pace-activity { 39 | display: block; 40 | position: fixed; 41 | z-index: 2000; 42 | top: 15px; 43 | right: 20px; 44 | width: 34px; 45 | height: 34px; 46 | border: solid 2px transparent; 47 | border-top-color: #9ea7ac; 48 | border-left-color: #9ea7ac; 49 | border-radius: 30px; 50 | -webkit-animation: pace-spinner 700ms linear infinite; 51 | -moz-animation: pace-spinner 700ms linear infinite; 52 | -ms-animation: pace-spinner 700ms linear infinite; 53 | -o-animation: pace-spinner 700ms linear infinite; 54 | animation: pace-spinner 700ms linear infinite; 55 | } 56 | 57 | @-webkit-keyframes pace-spinner { 58 | 0% { -webkit-transform: rotate(0deg); transform: rotate(0deg); } 59 | 100% { -webkit-transform: rotate(360deg); transform: rotate(360deg); } 60 | } 61 | @-moz-keyframes pace-spinner { 62 | 0% { -moz-transform: rotate(0deg); transform: rotate(0deg); } 63 | 100% { -moz-transform: rotate(360deg); transform: rotate(360deg); } 64 | } 65 | @-o-keyframes pace-spinner { 66 | 0% { -o-transform: rotate(0deg); transform: rotate(0deg); } 67 | 100% { -o-transform: rotate(360deg); transform: rotate(360deg); } 68 | } 69 | @-ms-keyframes pace-spinner { 70 | 0% { -ms-transform: rotate(0deg); transform: rotate(0deg); } 71 | 100% { -ms-transform: rotate(360deg); transform: rotate(360deg); } 72 | } 73 | @keyframes pace-spinner { 74 | 0% { transform: rotate(0deg); transform: rotate(0deg); } 75 | 100% { transform: rotate(360deg); transform: rotate(360deg); } 76 | } -------------------------------------------------------------------------------- /utils/get_opt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import Namespace 3 | import re 4 | from os.path import join as pjoin 5 | from utils.word_vectorizer import 
POS_enumerator 6 | 7 | 8 | def is_float(numStr): 9 | flag = False 10 | numStr = str(numStr).strip().lstrip('-').lstrip('+') 11 | try: 12 | reg = re.compile(r'^[-+]?[0-9]+\.[0-9]+$') 13 | res = reg.match(str(numStr)) 14 | if res: 15 | flag = True 16 | except Exception as ex: 17 | print("is_float() - error: " + str(ex)) 18 | return flag 19 | 20 | 21 | def is_number(numStr): 22 | flag = False 23 | numStr = str(numStr).strip().lstrip('-').lstrip('+') 24 | if str(numStr).isdigit(): 25 | flag = True 26 | return flag 27 | 28 | 29 | def get_opt(opt_path, device): 30 | opt = Namespace() 31 | opt_dict = vars(opt) 32 | 33 | skip = ('-------------- End ----------------', 34 | '------------ Options -------------', 35 | '\n') 36 | print('Reading', opt_path) 37 | with open(opt_path) as f: 38 | for line in f: 39 | if line.strip() not in skip: 40 | # print(line.strip()) 41 | key, value = line.strip().split(': ') 42 | if value in ('True', 'False'): 43 | opt_dict[key] = True if value == 'True' else False 44 | elif is_float(value): 45 | opt_dict[key] = float(value) 46 | elif is_number(value): 47 | opt_dict[key] = int(value) 48 | else: 49 | opt_dict[key] = str(value) 50 | 51 | opt_dict['which_epoch'] = 'latest' 52 | if 'num_layers' not in opt_dict: 53 | opt_dict['num_layers'] = 8 54 | if 'latent_dim' not in opt_dict: 55 | opt_dict['latent_dim'] = 512 56 | if 'diffusion_steps' not in opt_dict: 57 | opt_dict['diffusion_steps'] = 1000 58 | if 'no_clip' not in opt_dict: 59 | opt_dict['no_clip'] = False 60 | if 'no_eff' not in opt_dict: 61 | opt_dict['no_eff'] = False 62 | 63 | opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) 64 | opt.model_dir = pjoin(opt.save_root, 'model') 65 | opt.meta_dir = pjoin(opt.save_root, 'meta') 66 | 67 | if opt.dataset_name == 't2m': 68 | opt.data_root = './data/HumanML3D' 69 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') 70 | opt.text_dir = pjoin(opt.data_root, 'texts') 71 | opt.joints_num = 22 72 | opt.dim_pose = 263 73 | opt.max_motion_length = 196 74 | elif opt.dataset_name == 'kit': 75 | opt.data_root = './data/KIT-ML' 76 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') 77 | opt.text_dir = pjoin(opt.data_root, 'texts') 78 | opt.joints_num = 21 79 | opt.dim_pose = 251 80 | opt.max_motion_length = 196 81 | else: 82 | raise KeyError('Dataset not recognized') 83 | 84 | opt.dim_word = 300 85 | opt.num_classes = 200 // opt.unit_length 86 | opt.dim_pos_ohot = len(POS_enumerator) 87 | opt.is_train = False 88 | opt.is_continue = False 89 | opt.device = device 90 | 91 | return opt -------------------------------------------------------------------------------- /assets/environment.yml: -------------------------------------------------------------------------------- 1 | name: diffsheg 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - ca-certificates=2024.3.11=h06a4308_0 9 | - ld_impl_linux-64=2.38=h1181459_1 10 | - libffi=3.4.4=h6a678d5_0 11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - ncurses=6.4=h6a678d5_0 15 | - openssl=3.0.13=h7f8727e_0 16 | - pip=23.3.1=py39h06a4308_0 17 | - python=3.9.19=h955ad1f_0 18 | - readline=8.2=h5eee18b_0 19 | - setuptools=68.2.2=py39h06a4308_0 20 | - sqlite=3.41.2=h5eee18b_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - tzdata=2024a=h04d1e81_0 23 | - wheel=0.41.2=py39h06a4308_0 24 | - xz=5.4.6=h5eee18b_0 25 | - zlib=1.2.13=h5eee18b_0 26 | - pip: 27 | - addict==2.4.0 28 | - appdirs==1.4.4 
29 | - asttokens==2.4.1 30 | - audioread==3.0.1 31 | - certifi==2024.2.2 32 | - cffi==1.16.0 33 | - charset-normalizer==3.3.2 34 | - click==8.1.7 35 | - contourpy==1.2.1 36 | - cycler==0.12.1 37 | - decorator==5.1.1 38 | - docker-pycreds==0.4.0 39 | - exceptiongroup==1.2.1 40 | - executing==2.0.1 41 | - filelock==3.14.0 42 | - fonttools==4.51.0 43 | - fsspec==2024.3.1 44 | - gitdb==4.0.11 45 | - gitpython==3.1.43 46 | - huggingface-hub==0.22.2 47 | - idna==3.7 48 | - importlib-metadata==7.1.0 49 | - importlib-resources==6.4.0 50 | - ipython==8.18.1 51 | - jedi==0.19.1 52 | - joblib==1.4.0 53 | - kiwisolver==1.4.5 54 | - librosa==0.9.2 55 | - llvmlite==0.42.0 56 | - lmdb==0.96 57 | - loguru==0.6.0 58 | - matplotlib==3.7.1 59 | - matplotlib-inline==0.1.7 60 | - mmcv-full==1.6.0 61 | - numba==0.59.1 62 | - numpy==1.26.4 63 | - opencv-python==4.9.0.80 64 | - packaging==24.0 65 | - pandas==1.4.2 66 | - parso==0.8.4 67 | - pathtools==0.1.2 68 | - pexpect==4.9.0 69 | - pillow==10.3.0 70 | - platformdirs==4.2.1 71 | - pooch==1.8.1 72 | - prompt-toolkit==3.0.43 73 | - protobuf==4.25.3 74 | - psutil==5.9.8 75 | - ptyprocess==0.7.0 76 | - pure-eval==0.2.2 77 | - pyarrow==3.0.0 78 | - pycparser==2.22 79 | - pygments==2.17.2 80 | - pyparsing==3.1.2 81 | - python-dateutil==2.9.0.post0 82 | - pytz==2024.1 83 | - pyyaml==6.0.1 84 | - regex==2024.4.28 85 | - requests==2.31.0 86 | - resampy==0.4.3 87 | - safetensors==0.4.3 88 | - scikit-learn==1.4.2 89 | - scipy==1.10.1 90 | - sentry-sdk==2.0.1 91 | - setproctitle==1.3.3 92 | - six==1.16.0 93 | - smmap==5.0.1 94 | - soundfile==0.12.1 95 | - stack-data==0.6.3 96 | - termcolor==2.1.1 97 | - threadpoolctl==3.5.0 98 | - tokenizers==0.13.3 99 | - tomli==2.0.1 100 | - torch==1.13.1+cu117 101 | - torchaudio==0.13.1+cu117 102 | - torchvision==0.14.1+cu117 103 | - tqdm==4.66.2 104 | - traitlets==5.14.3 105 | - transformers==4.31.0 106 | - typing-extensions==4.11.0 107 | - umap==0.1.1 108 | - urllib3==2.2.1 109 | - wandb==0.13.10 110 | - wcwidth==0.2.13 111 | - yapf==0.40.2 112 | - zipp==3.18.1 113 | -------------------------------------------------------------------------------- /datasets/extract_hubert.py: -------------------------------------------------------------------------------- 1 | from transformers import Wav2Vec2Processor, HubertModel 2 | import soundfile as sf 3 | import numpy as np 4 | import torch 5 | 6 | print("Loading the Wav2Vec2 Processor...") 7 | wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") 8 | print("Loading the HuBERT Model...") 9 | hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft") 10 | 11 | 12 | def get_hubert_from_16k_wav(wav_16k_name): 13 | speech_16k, _ = sf.read(wav_16k_name) 14 | hubert = get_hubert_from_16k_speech(speech_16k) 15 | return hubert 16 | 17 | @torch.no_grad() 18 | def get_hubert_from_16k_speech(speech, device="cuda:0"): 19 | global hubert_model 20 | hubert_model = hubert_model.to(device) 21 | if speech.ndim ==2: 22 | speech = speech[:, 0] # [T, 2] ==> [T,] 23 | input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T] 24 | input_values_all = input_values_all.to(device) 25 | # For long audio sequence, due to the memory limitation, we cannot process them in one run 26 | # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320 27 | # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step. 
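# (Added, illustrative example: a 10-second 16 kHz clip has t = 160000 samples, so HuBERT yields
#  roughly floor((160000 - 400) / 320) = 498 hidden states, i.e. about 50 per second.)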
28 | # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320 29 | # We have the equation to calculate out time step: T = floor((t-k)/s) 30 | # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip 31 | # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N 32 | kernel = 400 33 | stride = 320 34 | clip_length = stride * 1000 35 | num_iter = input_values_all.shape[1] // clip_length 36 | expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride 37 | res_lst = [] 38 | for i in range(num_iter): 39 | if i == 0: 40 | start_idx = 0 41 | end_idx = clip_length - stride + kernel 42 | else: 43 | start_idx = clip_length * i 44 | end_idx = start_idx + (clip_length - stride + kernel) 45 | input_values = input_values_all[:, start_idx: end_idx] 46 | hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] 47 | res_lst.append(hidden_states[0]) 48 | if num_iter > 0: 49 | input_values = input_values_all[:, clip_length * num_iter:] 50 | else: 51 | input_values = input_values_all 52 | # if input_values.shape[1] != 0: 53 | if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it 54 | hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] 55 | res_lst.append(hidden_states[0]) 56 | ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024] 57 | # assert ret.shape[0] == expected_T 58 | assert abs(ret.shape[0] - expected_T) <= 1 59 | if ret.shape[0] < expected_T: 60 | ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0])) 61 | else: 62 | ret = ret[:expected_T] 63 | return ret 64 | 65 | 66 | if __name__ == '__main__': 67 | ### Process Single Long Audio for NeRF dataset 68 | # person_id = 'May' 69 | # wav_16k_name = f"data/processed/videos/{person_id}/aud.wav" 70 | # hubert_npy_name = f"data/processed/videos/{person_id}/hubert.npy" 71 | # speech_16k, _ = sf.read(wav_16k_name) 72 | # hubert_hidden = get_hubert_from_16k_speech(speech_16k) 73 | # np.save(hubert_npy_name, hubert_hidden.detach().numpy()) 74 | 75 | ### Process short audio clips for LRS3 dataset 76 | import glob, os, tqdm 77 | lrs3_dir = '/home/yezhenhui/datasets/raw/lrs3_raw/' 78 | wav_16k_names = glob.glob(os.path.join(lrs3_dir, '*/*.wav')) 79 | for wav_16k_name in tqdm.tqdm(wav_16k_names, total=len(wav_16k_names)): 80 | spk_id = wav_16k_name.split("/")[-2] 81 | clip_id = wav_16k_name.split("/")[-1][:-4] 82 | out_name = os.path.join(lrs3_dir, spk_id, clip_id+'_hubert.npy') 83 | if os.path.exists(out_name): 84 | continue 85 | speech_16k, _ = sf.read(wav_16k_name) 86 | hubert_hidden = get_hubert_from_16k_speech(speech_16k) 87 | np.save(out_name, hubert_hidden.detach().numpy()) -------------------------------------------------------------------------------- /utils/other_tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | import torch 5 | import csv 6 | import pprint 7 | from loguru import logger 8 | from collections import OrderedDict 9 | 10 | def print_exp_info(args): 11 | logger.info(pprint.pformat(vars(args))) 12 | logger.info(f"# ------------ {args.name} ----------- #") 13 | logger.info("PyTorch version: {}".format(torch.__version__)) 14 | logger.info("CUDA version: {}".format(torch.version.cuda)) 15 | logger.info("{} GPUs".format(torch.cuda.device_count())) 16 | logger.info(f"Random Seed: {args.random_seed}") 17 | 18 | 
def args2csv(args, get_head=False, list4print=[]): 19 | for k, v in args.items(): 20 | if isinstance(args[k], dict): 21 | args2csv(args[k], get_head, list4print) 22 | else: list4print.append(k) if get_head else list4print.append(v) 23 | return list4print 24 | 25 | def record_trial(args, csv_path, best_metric, best_epoch): 26 | metric_name = [] 27 | metric_value = [] 28 | metric_epoch = [] 29 | list4print = [] 30 | name4print = [] 31 | for k, v in vars(args).items(): 32 | list4print.append(v) 33 | name4print.append(k) 34 | 35 | for k, v in best_metric.items(): 36 | metric_name.append(k) 37 | metric_value.append(v) 38 | metric_epoch.append(best_epoch[k]) 39 | 40 | if not os.path.exists(csv_path): 41 | with open(csv_path, "a+") as f: 42 | csv_writer = csv.writer(f) 43 | csv_writer.writerow([*metric_name, *metric_name, *name4print]) 44 | 45 | with open(csv_path, "a+") as f: 46 | csv_writer = csv.writer(f) 47 | csv_writer.writerow([*metric_value,*metric_epoch, *list4print]) 48 | 49 | 50 | def set_random_seed(args): 51 | os.environ['PYTHONHASHSEED'] = str(args.random_seed) 52 | random.seed(args.random_seed) 53 | np.random.seed(args.random_seed) 54 | torch.manual_seed(args.random_seed) 55 | torch.cuda.manual_seed_all(args.random_seed) 56 | torch.cuda.manual_seed(args.random_seed) 57 | torch.backends.cudnn.deterministic = args.deterministic #args.CUDNN_DETERMINISTIC 58 | torch.backends.cudnn.benchmark = args.benchmark 59 | torch.backends.cudnn.enabled = args.cudnn_enabled 60 | 61 | 62 | def save_checkpoints(save_path, model, opt=None, epoch=None, lrs=None): 63 | if lrs is not None: 64 | states = { 'model_state': model.state_dict(), 65 | 'epoch': epoch + 1, 66 | 'opt_state': opt.state_dict(), 67 | 'lrs':lrs.state_dict(),} 68 | elif opt is not None: 69 | states = { 'model_state': model.state_dict(), 70 | 'epoch': epoch + 1, 71 | 'opt_state': opt.state_dict(),} 72 | else: 73 | states = { 'model_state': model.state_dict(),} 74 | torch.save(states, save_path) 75 | 76 | 77 | def load_checkpoints(model, save_path, load_name='model'): 78 | states = torch.load(save_path) 79 | new_weights = OrderedDict() 80 | flag=False 81 | for k, v in states['model_state'].items(): 82 | if "module" not in k: 83 | break 84 | else: 85 | new_weights[k[7:]]=v 86 | flag=True 87 | if flag: 88 | model.load_state_dict(new_weights) 89 | else: 90 | model.load_state_dict(states['model_state']) 91 | logger.info(f"load self-pretrained checkpoints for {load_name}") 92 | 93 | 94 | def model_complexity(model, args): 95 | from ptflops import get_model_complexity_info 96 | flops, params = get_model_complexity_info(model, (args.T_GLOBAL._DIM, args.TRAIN.CROP, args.TRAIN), 97 | as_strings=False, print_per_layer_stat=False) 98 | logging.info('{:<30} {:<8} BFlops'.format('Computational complexity: ', flops / 1e9)) 99 | logging.info('{:<30} {:<8} MParams'.format('Number of parameters: ', params / 1e6)) 100 | 101 | 102 | class AverageMeter(object): 103 | """Computes and stores the average and current value""" 104 | def __init__(self, name, fmt=':f'): 105 | self.name = name 106 | self.fmt = fmt 107 | self.reset() 108 | 109 | def reset(self): 110 | self.val = 0 111 | self.avg = 0 112 | self.sum = 0 113 | self.count = 0 114 | 115 | def update(self, val, n=1): 116 | self.val = val 117 | self.sum += val * n 118 | self.count += n 119 | self.avg = self.sum / self.count 120 | 121 | def __str__(self): 122 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 123 | return fmtstr.format(**self.__dict__) 
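# --- Usage sketch (added for illustration; `model`, `train_loader`, `batch_size`, `optimizer`,
# `scheduler` and `epoch` are placeholders, not defined in this file) ---
# loss_meter = AverageMeter('loss', ':.4f')
# for it, batch in enumerate(train_loader):
#     loss = model(batch)                               # hypothetical forward pass returning a scalar loss
#     loss_meter.update(loss.item(), n=batch_size)      # prints as e.g. "loss 0.1234 (0.1201)" via __str__
# save_checkpoints('latest.tar', model, opt=optimizer, epoch=epoch, lrs=scheduler)
# load_checkpoints(model, 'latest.tar', load_name='generator')   # strips the "module." prefix left by DDP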
-------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | # import cv2 4 | from PIL import Image 5 | from utils import paramUtil 6 | import math 7 | import time 8 | import matplotlib.pyplot as plt 9 | from scipy.ndimage import gaussian_filter 10 | 11 | 12 | def mkdir(path): 13 | if not os.path.exists(path): 14 | os.makedirs(path) 15 | 16 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 17 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 18 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 19 | 20 | MISSING_VALUE = -1 21 | 22 | def save_image(image_numpy, image_path): 23 | img_pil = Image.fromarray(image_numpy) 24 | img_pil.save(image_path) 25 | 26 | 27 | def save_logfile(log_loss, save_path): 28 | with open(save_path, 'wt') as f: 29 | for k, v in log_loss.items(): 30 | w_line = k 31 | for digit in v: 32 | w_line += ' %.3f' % digit 33 | f.write(w_line + '\n') 34 | 35 | 36 | def print_current_loss(start_time, niter_state, losses, epoch=None, inner_iter=None, mode='train'): 37 | 38 | def as_minutes(s): 39 | m = math.floor(s / 60) 40 | s -= m * 60 41 | return '%dm %ds' % (m, s) 42 | 43 | def time_since(since, percent): 44 | now = time.time() 45 | s = now - since 46 | es = s / percent 47 | rs = es - s 48 | return '%s (- %s)' % (as_minutes(s), as_minutes(rs)) 49 | 50 | if epoch is not None: 51 | if mode == 'train': 52 | print('epoch: %3d niter: %6d inner_iter: %4d' % (epoch, niter_state, inner_iter), end=" ") 53 | elif mode == 'val': 54 | print('[Validation] epoch: %3d niter: %6d inner_iter: %4d' % (epoch, niter_state, inner_iter), end=" ") 55 | 56 | now = time.time() 57 | message = '%s'%(as_minutes(now - start_time)) 58 | 59 | for k, v in losses.items(): 60 | message += ' %s: %.4f ' % (k, v) 61 | print(message) 62 | 63 | 64 | def compose_gif_img_list(img_list, fp_out, duration): 65 | img, *imgs = [Image.fromarray(np.array(image)) for image in img_list] 66 | img.save(fp=fp_out, format='GIF', append_images=imgs, optimize=False, 67 | save_all=True, loop=0, duration=duration) 68 | 69 | 70 | def save_images(visuals, image_path): 71 | if not os.path.exists(image_path): 72 | os.makedirs(image_path) 73 | 74 | for i, (label, img_numpy) in enumerate(visuals.items()): 75 | img_name = '%d_%s.jpg' % (i, label) 76 | save_path = os.path.join(image_path, img_name) 77 | save_image(img_numpy, save_path) 78 | 79 | 80 | def save_images_test(visuals, image_path, from_name, to_name): 81 | if not os.path.exists(image_path): 82 | os.makedirs(image_path) 83 | 84 | for i, (label, img_numpy) in enumerate(visuals.items()): 85 | img_name = "%s_%s_%s" % (from_name, to_name, label) 86 | save_path = os.path.join(image_path, img_name) 87 | save_image(img_numpy, save_path) 88 | 89 | 90 | def compose_and_save_img(img_list, save_dir, img_name, col=4, row=1, img_size=(256, 200)): 91 | # print(col, row) 92 | compose_img = compose_image(img_list, col, row, img_size) 93 | if not os.path.exists(save_dir): 94 | os.makedirs(save_dir) 95 | img_path = os.path.join(save_dir, img_name) 96 | # print(img_path) 97 | compose_img.save(img_path) 98 | 99 | 100 | def compose_image(img_list, col, row, img_size): 101 | to_image = Image.new('RGB', (col * img_size[0], row * img_size[1])) 102 | for y in range(0, row): 103 | for x in range(0, col): 104 | from_img = 
Image.fromarray(img_list[y * col + x]) 105 | # print((x * img_size[0], y*img_size[1], 106 | # (x + 1) * img_size[0], (y + 1) * img_size[1])) 107 | paste_area = (x * img_size[0], y*img_size[1], 108 | (x + 1) * img_size[0], (y + 1) * img_size[1]) 109 | to_image.paste(from_img, paste_area) 110 | # to_image[y*img_size[1]:(y + 1) * img_size[1], x * img_size[0] :(x + 1) * img_size[0]] = from_img 111 | return to_image 112 | 113 | 114 | def list_cut_average(ll, intervals): 115 | if intervals == 1: 116 | return ll 117 | 118 | bins = math.ceil(len(ll) * 1.0 / intervals) 119 | ll_new = [] 120 | for i in range(bins): 121 | l_low = intervals * i 122 | l_high = l_low + intervals 123 | l_high = l_high if l_high < len(ll) else len(ll) 124 | ll_new.append(np.mean(ll[l_low:l_high])) 125 | return ll_new 126 | 127 | 128 | def motion_temporal_filter(motion, sigma=1): 129 | motion = motion.reshape(motion.shape[0], -1) 130 | # print(motion.shape) 131 | for i in range(motion.shape[1]): 132 | motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest") 133 | return motion.reshape(motion.shape[0], -1, 3) 134 | 135 | -------------------------------------------------------------------------------- /datasets/pymo/rotation_tools.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Tools for Manipulating and Converting 3D Rotations 3 | 4 | By Omid Alemi 5 | Created: June 12, 2017 6 | 7 | Adapted from that matlab file... 8 | ''' 9 | 10 | import math 11 | import numpy as np 12 | 13 | def deg2rad(x): 14 | return x/180*math.pi 15 | 16 | 17 | def rad2deg(x): 18 | return x/math.pi*180 19 | 20 | class Rotation(): 21 | def __init__(self,rot, param_type, rotation_order, **params): 22 | self.rotmat = [] 23 | self.rotation_order = rotation_order 24 | if param_type == 'euler': 25 | self._from_euler(rot[0],rot[1],rot[2], params) 26 | elif param_type == 'expmap': 27 | self._from_expmap(rot[0], rot[1], rot[2], params) 28 | 29 | def _from_euler(self, alpha, beta, gamma, params): 30 | '''Expecting degress''' 31 | 32 | if params['from_deg']==True: 33 | alpha = deg2rad(alpha) 34 | beta = deg2rad(beta) 35 | gamma = deg2rad(gamma) 36 | 37 | ca = math.cos(alpha) 38 | cb = math.cos(beta) 39 | cg = math.cos(gamma) 40 | sa = math.sin(alpha) 41 | sb = math.sin(beta) 42 | sg = math.sin(gamma) 43 | 44 | Rx = np.asarray([[1, 0, 0], 45 | [0, ca, sa], 46 | [0, -sa, ca] 47 | ]) 48 | 49 | Ry = np.asarray([[cb, 0, -sb], 50 | [0, 1, 0], 51 | [sb, 0, cb]]) 52 | 53 | Rz = np.asarray([[cg, sg, 0], 54 | [-sg, cg, 0], 55 | [0, 0, 1]]) 56 | 57 | self.rotmat = np.eye(3) 58 | 59 | ############################ inner product rotation matrix in order defined at BVH file ######################### 60 | for axis in self.rotation_order : 61 | if axis == 'X' : 62 | self.rotmat = np.matmul(Rx, self.rotmat) 63 | elif axis == 'Y': 64 | self.rotmat = np.matmul(Ry, self.rotmat) 65 | else : 66 | self.rotmat = np.matmul(Rz, self.rotmat) 67 | ################################################################################################################ 68 | 69 | def _from_expmap(self, alpha, beta, gamma, params): 70 | if (alpha == 0 and beta == 0 and gamma == 0): 71 | self.rotmat = np.eye(3) 72 | return 73 | 74 | #TODO: Check exp map params 75 | 76 | theta = np.linalg.norm([alpha, beta, gamma]) 77 | 78 | expmap = [alpha, beta, gamma] / theta 79 | 80 | x = expmap[0] 81 | y = expmap[1] 82 | z = expmap[2] 83 | 84 | s = math.sin(theta/2) 85 | c = math.cos(theta/2) 86 | 87 | self.rotmat = np.asarray([ 88 | [2*(x**2-1)*s**2+1, 
2*x*y*s**2-2*z*c*s, 2*x*z*s**2+2*y*c*s], 89 | [2*x*y*s**2+2*z*c*s, 2*(y**2-1)*s**2+1, 2*y*z*s**2-2*x*c*s], 90 | [2*x*z*s**2-2*y*c*s, 2*y*z*s**2+2*x*c*s , 2*(z**2-1)*s**2+1] 91 | ]) 92 | 93 | 94 | 95 | def get_euler_axis(self): 96 | R = self.rotmat 97 | theta = math.acos((self.rotmat.trace() - 1) / 2) 98 | axis = np.asarray([R[2,1] - R[1,2], R[0,2] - R[2,0], R[1,0] - R[0,1]]) 99 | axis = axis/(2*math.sin(theta)) 100 | return theta, axis 101 | 102 | def to_expmap(self): 103 | theta, axis = self.get_euler_axis() 104 | rot_arr = theta * axis 105 | if np.isnan(rot_arr).any(): 106 | rot_arr = [0, 0, 0] 107 | return rot_arr 108 | 109 | def to_euler(self, use_deg=False): 110 | eulers = np.zeros((2, 3)) 111 | 112 | if np.absolute(np.absolute(self.rotmat[2, 0]) - 1) < 1e-12: 113 | #GIMBAL LOCK! 114 | print('Gimbal') 115 | if np.absolute(self.rotmat[2, 0]) - 1 < 1e-12: 116 | eulers[:,0] = math.atan2(-self.rotmat[0,1], -self.rotmat[0,2]) 117 | eulers[:,1] = -math.pi/2 118 | else: 119 | eulers[:,0] = math.atan2(self.rotmat[0,1], -elf.rotmat[0,2]) 120 | eulers[:,1] = math.pi/2 121 | 122 | return eulers 123 | 124 | theta = - math.asin(self.rotmat[2,0]) 125 | theta2 = math.pi - theta 126 | 127 | # psi1, psi2 128 | eulers[0,0] = math.atan2(self.rotmat[2,1]/math.cos(theta), self.rotmat[2,2]/math.cos(theta)) 129 | eulers[1,0] = math.atan2(self.rotmat[2,1]/math.cos(theta2), self.rotmat[2,2]/math.cos(theta2)) 130 | 131 | # theta1, theta2 132 | eulers[0,1] = theta 133 | eulers[1,1] = theta2 134 | 135 | # phi1, phi2 136 | eulers[0,2] = math.atan2(self.rotmat[1,0]/math.cos(theta), self.rotmat[0,0]/math.cos(theta)) 137 | eulers[1,2] = math.atan2(self.rotmat[1,0]/math.cos(theta2), self.rotmat[0,0]/math.cos(theta2)) 138 | 139 | if use_deg: 140 | eulers = rad2deg(eulers) 141 | 142 | return eulers 143 | 144 | def to_quat(self): 145 | #TODO 146 | pass 147 | 148 | def __str__(self): 149 | return "Rotation Matrix: \n " + self.rotmat.__str__() 150 | 151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /datasets/dataloader.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import random 3 | from functools import partial 4 | from typing import Optional, Union 5 | 6 | import numpy as np 7 | from mmcv.runner import get_dist_info 8 | from mmcv.utils import Registry, build_from_cfg 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data.dataset import Dataset 11 | 12 | import torch 13 | from torch.utils.data import DistributedSampler as _DistributedSampler 14 | 15 | 16 | class DistributedSampler(_DistributedSampler): 17 | 18 | def __init__(self, 19 | dataset, 20 | num_replicas=None, 21 | rank=None, 22 | shuffle=True, 23 | round_up=True): 24 | super().__init__(dataset, num_replicas=num_replicas, rank=rank) 25 | self.shuffle = shuffle 26 | self.round_up = round_up 27 | if self.round_up: 28 | self.total_size = self.num_samples * self.num_replicas 29 | else: 30 | self.total_size = len(self.dataset) 31 | 32 | def __iter__(self): 33 | # deterministically shuffle based on epoch 34 | if self.shuffle: 35 | g = torch.Generator() 36 | g.manual_seed(self.epoch) 37 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 38 | else: 39 | indices = torch.arange(len(self.dataset)).tolist() 40 | 41 | # add extra samples to make it evenly divisible 42 | if self.round_up: 43 | indices = ( 44 | indices * 45 | int(self.total_size / len(indices) + 1))[:self.total_size] 46 | assert len(indices) == self.total_size 47 
| 48 | # subsample 49 | indices = indices[self.rank:self.total_size:self.num_replicas] 50 | if self.round_up: 51 | assert len(indices) == self.num_samples 52 | 53 | return iter(indices) 54 | 55 | 56 | def build_dataloader(dataset: Dataset, 57 | samples_per_gpu: int, 58 | workers_per_gpu: int, 59 | num_gpus: Optional[int] = 1, 60 | dist: Optional[bool] = True, 61 | shuffle: Optional[bool] = True, 62 | round_up: Optional[bool] = True, 63 | seed: Optional[Union[int, None]] = None, 64 | persistent_workers: Optional[bool] = True, 65 | **kwargs): 66 | """Build PyTorch DataLoader. 67 | In distributed training, each GPU/process has a dataloader. 68 | In non-distributed training, there is only one dataloader for all GPUs. 69 | Args: 70 | dataset (:obj:`Dataset`): A PyTorch dataset. 71 | samples_per_gpu (int): Number of training samples on each GPU, i.e., 72 | batch size of each GPU. 73 | workers_per_gpu (int): How many subprocesses to use for data loading 74 | for each GPU. 75 | num_gpus (int, optional): Number of GPUs. Only used in non-distributed 76 | training. 77 | dist (bool, optional): Distributed training/test or not. Default: True. 78 | shuffle (bool, optional): Whether to shuffle the data at every epoch. 79 | Default: True. 80 | round_up (bool, optional): Whether to round up the length of dataset by 81 | adding extra samples to make it evenly divisible. Default: True. 82 | persistent_workers (bool): If True, the data loader will not shutdown 83 | the worker processes after a dataset has been consumed once. 84 | This allows to maintain the workers Dataset instances alive. 85 | The argument also has effect in PyTorch>=1.7.0. 86 | Default: True 87 | kwargs: any keyword argument to be used to initialize DataLoader 88 | Returns: 89 | DataLoader: A PyTorch dataloader. 90 | """ 91 | rank, world_size = get_dist_info() 92 | if dist: 93 | sampler = DistributedSampler( 94 | dataset, world_size, rank, shuffle=shuffle, round_up=round_up) 95 | shuffle = False 96 | batch_size = samples_per_gpu 97 | num_workers = workers_per_gpu 98 | else: 99 | sampler = None 100 | batch_size = num_gpus * samples_per_gpu 101 | num_workers = num_gpus * workers_per_gpu 102 | 103 | init_fn = partial( 104 | worker_init_fn, num_workers=num_workers, rank=rank, 105 | seed=seed) if seed is not None else None 106 | 107 | data_loader = DataLoader( 108 | dataset, 109 | batch_size=batch_size, 110 | sampler=sampler, 111 | num_workers=num_workers, 112 | pin_memory=False, 113 | shuffle=shuffle, 114 | worker_init_fn=init_fn, 115 | persistent_workers=persistent_workers, 116 | **kwargs) 117 | 118 | return data_loader 119 | 120 | 121 | def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): 122 | """Init random seed for each worker.""" 123 | # The seed of each worker equals to 124 | # num_worker * rank + worker_id + user_seed 125 | worker_seed = num_workers * rank + worker_id + seed 126 | np.random.seed(worker_seed) 127 | random.seed(worker_seed) -------------------------------------------------------------------------------- /models/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 
12 | For example, if there's 300 timesteps and the section counts are [10,15,20] 13 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 14 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 15 | If the stride is a string starting with "ddim", then the fixed striding 16 | from the DDIM paper is used, and only one section is allowed. 17 | :param num_timesteps: the number of diffusion steps in the original 18 | process to divide up. 19 | :param section_counts: either a list of numbers, or a string containing 20 | comma-separated numbers, indicating the step count 21 | per section. As a special case, use "ddimN" where N 22 | is a number of steps to use the striding from the 23 | DDIM paper. 24 | :return: a set of diffusion steps from the original process to use. 25 | """ 26 | if isinstance(section_counts, str): 27 | if section_counts.startswith("ddim"): 28 | desired_count = int(section_counts[len("ddim") :]) 29 | for i in range(1, num_timesteps): 30 | if len(range(0, num_timesteps, i)) == desired_count: 31 | return set(range(0, num_timesteps, i)) 32 | raise ValueError( 33 | f"cannot create exactly {num_timesteps} steps with an integer stride" 34 | ) 35 | section_counts = [int(x) for x in section_counts.split(",")] 36 | size_per = num_timesteps // len(section_counts) 37 | extra = num_timesteps % len(section_counts) 38 | start_idx = 0 39 | all_steps = [] 40 | for i, section_count in enumerate(section_counts): 41 | size = size_per + (1 if i < extra else 0) 42 | if size < section_count: 43 | raise ValueError( 44 | f"cannot divide section of {size} steps into {section_count}" 45 | ) 46 | if section_count <= 1: 47 | frac_stride = 1 48 | else: 49 | frac_stride = (size - 1) / (section_count - 1) 50 | cur_idx = 0.0 51 | taken_steps = [] 52 | for _ in range(section_count): 53 | taken_steps.append(start_idx + round(cur_idx)) 54 | cur_idx += frac_stride 55 | all_steps += taken_steps 56 | start_idx += size 57 | return set(all_steps) 58 | 59 | 60 | class SpacedDiffusion(GaussianDiffusion): 61 | """ 62 | A diffusion process which can skip steps in a base diffusion process. 63 | :param use_timesteps: a collection (sequence or set) of timesteps from the 64 | original diffusion process to retain. 65 | :param kwargs: the kwargs to create the base diffusion process. 
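The `space_timesteps` helper above is what a `--timestep_respacing ddim25` style option ultimately resolves to: it selects which of the original diffusion steps are kept. A quick illustrative check (the import path follows this repo's layout; the printed values are what the code above computes):

```python
from models.respace import space_timesteps

# "ddimN": find an even stride that keeps exactly N of the original steps
kept = space_timesteps(num_timesteps=1000, section_counts="ddim25")
print(len(kept))           # 25
print(sorted(kept)[:4])    # [0, 40, 80, 120] -- an even stride of 40

# list form: split 300 original steps into three equal sections,
# retained at 10, 15 and 20 steps respectively
kept = space_timesteps(300, [10, 15, 20])
print(len(kept))           # 45
```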
66 | """ 67 | 68 | def __init__(self, use_timesteps, **kwargs): 69 | self.use_timesteps = set(use_timesteps) 70 | self.timestep_map = [] 71 | self.original_num_steps = len(kwargs["betas"]) 72 | 73 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 74 | last_alpha_cumprod = 1.0 75 | new_betas = [] 76 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 77 | if i in self.use_timesteps: 78 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 79 | last_alpha_cumprod = alpha_cumprod 80 | self.timestep_map.append(i) 81 | kwargs["betas"] = np.array(new_betas) 82 | super().__init__(**kwargs) 83 | 84 | def p_mean_variance( 85 | self, model, *args, **kwargs 86 | ): # pylint: disable=signature-differs 87 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 88 | 89 | def training_losses( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 93 | 94 | def condition_mean(self, cond_fn, *args, **kwargs): 95 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 96 | 97 | def condition_score(self, cond_fn, *args, **kwargs): 98 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 99 | 100 | def _wrap_model(self, model): 101 | if isinstance(model, _WrappedModel): 102 | return model 103 | return _WrappedModel( 104 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 105 | ) 106 | 107 | def _scale_timesteps(self, t): 108 | # Scaling is done by the wrapped model. 109 | return t 110 | 111 | 112 | class _WrappedModel: 113 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 114 | self.model = model 115 | self.timestep_map = timestep_map 116 | self.rescale_timesteps = rescale_timesteps 117 | self.original_num_steps = original_num_steps 118 | 119 | def __call__(self, x, ts, **kwargs): 120 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 121 | new_ts = map_tensor[ts] 122 | if self.rescale_timesteps: 123 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 124 | return self.model(x, new_ts, **kwargs) 125 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | 4 | 5 | # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train 6 | def euclidean_distance_matrix(matrix1, matrix2): 7 | """ 8 | Params: 9 | -- matrix1: N1 x D 10 | -- matrix2: N2 x D 11 | Returns: 12 | -- dist: N1 x N2 13 | dist[i, j] == distance(matrix1[i], matrix2[j]) 14 | """ 15 | assert matrix1.shape[1] == matrix2.shape[1] 16 | d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) 17 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) 18 | d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) 19 | dists = np.sqrt(d1 + d2 + d3) # broadcasting 20 | return dists 21 | 22 | def calculate_top_k(mat, top_k): 23 | size = mat.shape[0] 24 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1) 25 | bool_mat = (mat == gt_mat) 26 | correct_vec = False 27 | top_k_list = [] 28 | for i in range(top_k): 29 | # print(correct_vec, bool_mat[:, i]) 30 | correct_vec = (correct_vec | bool_mat[:, i]) 31 | # print(correct_vec) 32 | top_k_list.append(correct_vec[:, None]) 33 | top_k_mat = np.concatenate(top_k_list, axis=1) 34 | return 
top_k_mat 35 | 36 | 37 | def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False): 38 | dist_mat = euclidean_distance_matrix(embedding1, embedding2) 39 | argmax = np.argsort(dist_mat, axis=1) 40 | top_k_mat = calculate_top_k(argmax, top_k) 41 | if sum_all: 42 | return top_k_mat.sum(axis=0) 43 | else: 44 | return top_k_mat 45 | 46 | 47 | def calculate_matching_score(embedding1, embedding2, sum_all=False): 48 | assert len(embedding1.shape) == 2 49 | assert embedding1.shape[0] == embedding2.shape[0] 50 | assert embedding1.shape[1] == embedding2.shape[1] 51 | 52 | dist = linalg.norm(embedding1 - embedding2, axis=1) 53 | if sum_all: 54 | return dist.sum(axis=0) 55 | else: 56 | return dist 57 | 58 | 59 | 60 | def calculate_activation_statistics(activations): 61 | """ 62 | Params: 63 | -- activation: num_samples x dim_feat 64 | Returns: 65 | -- mu: dim_feat 66 | -- sigma: dim_feat x dim_feat 67 | """ 68 | mu = np.mean(activations, axis=0) 69 | cov = np.cov(activations, rowvar=False) 70 | return mu, cov 71 | 72 | 73 | def calculate_diversity(activation, diversity_times): 74 | assert len(activation.shape) == 2 75 | assert activation.shape[0] > diversity_times 76 | num_samples = activation.shape[0] 77 | 78 | first_indices = np.random.choice(num_samples, diversity_times, replace=False) 79 | second_indices = np.random.choice(num_samples, diversity_times, replace=False) 80 | dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1) 81 | return dist.mean() 82 | 83 | 84 | def calculate_multimodality(activation, multimodality_times): 85 | assert len(activation.shape) == 3 86 | assert activation.shape[1] > multimodality_times 87 | num_per_sent = activation.shape[1] 88 | 89 | first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 90 | second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 91 | dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2) 92 | return dist.mean() 93 | 94 | 95 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 96 | """Numpy implementation of the Frechet Distance. 97 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 98 | and X_2 ~ N(mu_2, C_2) is 99 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 100 | Stable version by Dougal J. Sutherland. 101 | Params: 102 | -- mu1 : Numpy array containing the activations of a layer of the 103 | inception net (like returned by the function 'get_predictions') 104 | for generated samples. 105 | -- mu2 : The sample mean over activations, precalculated on an 106 | representative data set. 107 | -- sigma1: The covariance matrix over activations for generated samples. 108 | -- sigma2: The covariance matrix over activations, precalculated on an 109 | representative data set. 110 | Returns: 111 | -- : The Frechet Distance. 
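These helpers implement the standard FID recipe: fit a Gaussian (mean and covariance) to each set of embeddings, then evaluate the closed-form Fréchet distance above. An illustrative call with random features (shapes and values are made up):

```python
import numpy as np
from utils.metrics import calculate_activation_statistics, calculate_frechet_distance

gen_feats = np.random.randn(512, 64)        # embeddings of generated clips
gt_feats  = np.random.randn(512, 64) + 0.5  # ground-truth embeddings with a shifted mean

mu_g, cov_g = calculate_activation_statistics(gen_feats)
mu_r, cov_r = calculate_activation_statistics(gt_feats)
print(calculate_frechet_distance(mu_g, cov_g, mu_r, cov_r))  # grows as the two distributions diverge
```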
112 | """ 113 | 114 | mu1 = np.atleast_1d(mu1) 115 | mu2 = np.atleast_1d(mu2) 116 | 117 | sigma1 = np.atleast_2d(sigma1) 118 | sigma2 = np.atleast_2d(sigma2) 119 | 120 | assert mu1.shape == mu2.shape, \ 121 | 'Training and test mean vectors have different lengths' 122 | assert sigma1.shape == sigma2.shape, \ 123 | 'Training and test covariances have different dimensions' 124 | 125 | diff = mu1 - mu2 126 | 127 | # Product might be almost singular 128 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 129 | if not np.isfinite(covmean).all(): 130 | msg = ('fid calculation produces singular product; ' 131 | 'adding %s to diagonal of cov estimates') % eps 132 | print(msg) 133 | offset = np.eye(sigma1.shape[0]) * eps 134 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 135 | 136 | # Numerical error might give slight imaginary component 137 | if np.iscomplexobj(covmean): 138 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 139 | m = np.max(np.abs(covmean.imag)) 140 | raise ValueError('Imaginary component {}'.format(m)) 141 | covmean = covmean.real 142 | 143 | tr_covmean = np.trace(covmean) 144 | 145 | return (diff.dot(diff) + np.trace(sigma1) + 146 | np.trace(sigma2) - 2 * tr_covmean) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | ## DiffSHEG: A Diffusion-Based Approach for Real-Time Speech-driven Holistic 3D Expression and Gesture Generation 4 | (CVPR 2024 Official Repo) 5 | 6 | [Junming Chen](https://jeremycjm.github.io)†1,2, [Yunfei Liu](http://liuyunfei.net/)2, [Jianan Wang](https://scholar.google.com/citations?user=mt5mvZ8AAAAJ&hl=en&inst=1381320739207392350)2, [Ailing Zeng](https://ailingzeng.site/)2, [Yu Li](https://yu-li.github.io/)*2, [Qifeng Chen](https://cqf.io)*1 7 | 8 |

1HKUST   2International Digital Economy Academy (IDEA)    9 |
*Corresponding authors   †Work done during an internship at IDEA

10 | 11 | #### [Project Page](https://jeremycjm.github.io/proj/DiffSHEG/) · [Paper](https://arxiv.org/abs/2401.04747) · [Video](https://www.youtube.com/watch?v=HFaSd5do-zI) 12 | 13 |

14 | 15 | ![DiffSHEG Teaser](./assets/teaser_for_demo_cvpr.png) 16 | 17 | ## Environment 18 | We have tested the code on Ubuntu 18.04 and 20.04. 19 | ``` 20 | cd assets 21 | ``` 22 | - Option 1: conda install 23 | ``` 24 | conda env create -f environment.yml 25 | conda activate diffsheg 26 | ``` 27 | - Option 2: pip install 28 | ``` 29 | conda create -n "diffsheg" python=3.9 30 | conda activate diffsheg 31 | pip install -r requirements.txt 32 | pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117 33 | ``` 34 | - Untar data.tar.gz for data statistics 35 | ``` 36 | tar zxvf data.tar.gz 37 | mv data ../ 38 | ``` 39 | 40 | ## Checkpoints 41 | [Google Drive](https://drive.google.com/file/d/1JPoMOcGDrvkFt7QbN6sEyYAPOOWkVN0h/view) 42 | 43 | ## Inference on a Custom Audio 44 | First, set the '--test_audio_path' argument in the bash scripts below to the path of your own audio. Note that the audio must be a .wav file. 45 | 46 | - Use model trained on BEAT dataset: 47 | ``` 48 | bash inference_custom_audio_beat.sh 49 | ``` 50 | 51 | - Use model trained on SHOW dataset: 52 | ``` 53 | bash inference_custom_audio_show.sh 54 | ``` 55 | ## Training 56 |
Train on BEAT dataset 57 | 58 | ``` 59 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 60 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0,1,2,3,4 python -u runner.py \ 61 | --dataset_name beat \ 62 | --name beat_diffsheg \ 63 | --batch_size 2500 \ 64 | --num_epochs 1000 \ 65 | --save_every_e 20 \ 66 | --eval_every_e 40 \ 67 | --n_poses 34 \ 68 | --ddim \ 69 | --multiprocessing-distributed \ 70 | --dist-url 'tcp://127.0.0.1:6666' 71 | ``` 72 |
73 | 74 | 75 |
Train on SHOW dataset 76 | 77 | ``` 78 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 79 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0,1,2,3,4 python -u runner.py \ 80 | --dataset_name talkshow \ 81 | --name talkshow_diffsheg \ 82 | --batch_size 950 \ 83 | --num_epochs 4000 \ 84 | --save_every_e 20 \ 85 | --eval_every_e 40 \ 86 | --n_poses 88 \ 87 | --classifier_free \ 88 | --multiprocessing-distributed \ 89 | --dist-url 'tcp://127.0.0.1:6667' \ 90 | --ddim \ 91 | --max_eval_samples 200 92 | ``` 93 |
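Note that this SHOW configuration trains with `--classifier_free`, i.e. the audio condition is randomly dropped during training so that, at test time, `--cond_scale` (e.g. 1.25 in the testing commands below) can blend conditional and unconditional predictions. A minimal sketch of that guidance rule (generic classifier-free guidance; the `model` call signature is hypothetical, not this repo's exact API):

```python
def guided_prediction(model, x_t, t, audio_cond, cond_scale=1.25):
    # Classifier-free guidance: push the conditional output away from the unconditional one.
    out_cond = model(x_t, t, cond=audio_cond)
    out_uncond = model(x_t, t, cond=None)  # condition dropped, as seen during training
    return out_uncond + cond_scale * (out_cond - out_uncond)
```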
94 | 95 | ## Testing 96 | 97 |
Test on BEAT dataset 98 | 99 | ``` 100 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 101 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -u runner.py \ 102 | --dataset_name talkshow \ 103 | --name talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree \ 104 | --PE pe_sinu \ 105 | --n_poses 88 \ 106 | --multiprocessing-distributed \ 107 | --dist-url 'tcp://127.0.0.1:8889' \ 108 | --classifier_free \ 109 | --cond_scale 1.25 \ 110 | --ckpt ckpt_e2599.tar \ 111 | --mode test_arbitrary_len \ 112 | --ddim \ 113 | --timestep_respacing ddim25 \ 114 | --overlap_len 10 115 | ``` 116 |
117 | 118 | 119 |
Test on SHOW dataset 120 | 121 | ``` 122 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 123 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -u runner.py \ 124 | --dataset_name talkshow \ 125 | --name talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree \ 126 | --PE pe_sinu \ 127 | --n_poses 88 \ 128 | --multiprocessing-distributed \ 129 | --dist-url 'tcp://127.0.0.1:8889' \ 130 | --classifier_free \ 131 | --cond_scale 1.25 \ 132 | --ckpt ckpt_e2599.tar \ 133 | --mode test_arbitrary_len \ 134 | --ddim \ 135 | --timestep_respacing ddim25 \ 136 | --overlap_len 10 137 | ``` 138 |
139 | 140 | ## Visualization 141 | After running under the test or test-custom-audio mode, the Gesture and Expression results will be saved in the ./results directory. 142 | ### BEAT 143 | 1. Open ```assets/beat_visualize.blend``` with latest Blender on your local computer. 144 | 2. Specify the audio, BVH (for gesture), JSON (for expression), and video saving path in the transcript in Blender. 145 | 3. (Optional) Click Window --> Toggle System Console to check the visulization progress. 146 | 4. Run the script in Blender. 147 | ### SHOW 148 | Please refer the the [TalkSHOW](https://github.com/yhw-yhw/TalkSHOW) code for the visualization of our generated motion. 149 | 150 | ## Acknowledgement 151 | Our implementation is partially based on [BEAT](https://github.com/PantoMatrix/BEAT), [TalkSHOW](https://github.com/yhw-yhw/TalkSHOW), and [MotionDiffuse](https://github.com/mingyuan-zhang/MotionDiffuse/tree/main). 152 | 153 | ## Citation 154 | If you use our code or find this repo useful, please consider cite our paper: 155 | ``` 156 | @inproceedings{chen2024diffsheg, 157 | title = {DiffSHEG: A Diffusion-Based Approach for Real-Time Speech-driven Holistic 3D Expression and Gesture Generation}, 158 | author = {Chen, Junming and Liu, Yunfei and Wang, Jianan and Zeng, Ailing and Li, Yu and Chen, Qifeng}, 159 | booktitle = {CVPR}, 160 | year = {2024} 161 | } 162 | ``` 163 | 164 | 165 | -------------------------------------------------------------------------------- /datasets/pymo/mocapplayer/js/skeletonFactory.js: -------------------------------------------------------------------------------- 1 | bm_v = new THREE.MeshPhongMaterial({ 2 | color: 0x08519c, 3 | emissive: 0x08306b, 4 | specular: 0x08519c, 5 | shininess: 10, 6 | side: THREE.DoubleSide 7 | }); 8 | 9 | jm_v = new THREE.MeshPhongMaterial({ 10 | color: 0x08306b, 11 | emissive: 0x000000, 12 | specular: 0x111111, 13 | shininess: 90, 14 | side: THREE.DoubleSide 15 | }); 16 | 17 | bm_a = new THREE.MeshPhongMaterial({ 18 | color: 0x980043, 19 | emissive: 0x67001f, 20 | specular: 0x6a51a3, 21 | shininess: 10, 22 | side: THREE.DoubleSide 23 | }); 24 | 25 | jm_a = new THREE.MeshPhongMaterial({ 26 | color: 0x67001f, 27 | emissive: 0x000000, 28 | specular: 0x111111, 29 | shininess: 90, 30 | side: THREE.DoubleSide 31 | }); 32 | 33 | bm_b = new THREE.MeshPhongMaterial({ 34 | color: 0x3f007d, 35 | emissive: 0x3f007d, 36 | specular: 0x807dba, 37 | shininess: 2, 38 | side: THREE.DoubleSide 39 | }); 40 | 41 | jm_b = new THREE.MeshPhongMaterial({ 42 | color: 0x3f007d, 43 | emissive: 0x000000, 44 | specular: 0x807dba, 45 | shininess: 90, 46 | side: THREE.DoubleSide 47 | }); 48 | 49 | //------------------ 50 | 51 | 52 | jointmaterial = new THREE.MeshLambertMaterial({ 53 | color: 0xc57206, 54 | emissive: 0x271c18, 55 | side: THREE.DoubleSide, 56 | // shading: THREE.FlatShading, 57 | wireframe: false, 58 | shininess: 90, 59 | }); 60 | 61 | bonematerial = new THREE.MeshPhongMaterial({ 62 | color: 0xbd9a6d, 63 | emissive: 0x271c18, 64 | side: THREE.DoubleSide, 65 | // shading: THREE.FlatShading, 66 | wireframe: false 67 | }); 68 | 69 | jointmaterial2 = new THREE.MeshPhongMaterial({ 70 | color: 0x1562a2, 71 | emissive: 0x000000, 72 | specular: 0x111111, 73 | shininess: 30, 74 | side: THREE.DoubleSide 75 | }); 76 | 77 | bonematerial2 = new THREE.MeshPhongMaterial({ 78 | color: 0x552211, 79 | emissive: 0x882211, 80 | // emissive: 0x000000, 81 | specular: 0x111111, 82 | shininess: 30, 83 | side: THREE.DoubleSide 84 | }); 85 | 86 | bonematerial3 = new 
THREE.MeshPhongMaterial({ 87 | color: 0x176793, 88 | emissive: 0x000000, 89 | specular: 0x111111, 90 | shininess: 90, 91 | side: THREE.DoubleSide 92 | }); 93 | 94 | 95 | 96 | jointmaterial4 = new THREE.MeshPhongMaterial({ 97 | color: 0xFF8A00, 98 | emissive: 0x000000, 99 | specular: 0x111111, 100 | shininess: 90, 101 | side: THREE.DoubleSide 102 | }); 103 | 104 | 105 | bonematerial4 = new THREE.MeshPhongMaterial({ 106 | color: 0x53633D, 107 | emissive: 0x000000, 108 | specular: 0xFFC450, 109 | shininess: 90, 110 | side: THREE.DoubleSide 111 | }); 112 | 113 | 114 | 115 | bonematerial44 = new THREE.MeshPhongMaterial({ 116 | color: 0x582A72, 117 | emissive: 0x000000, 118 | specular: 0xFFC450, 119 | shininess: 90, 120 | side: THREE.DoubleSide 121 | }); 122 | 123 | jointmaterial5 = new THREE.MeshPhongMaterial({ 124 | color: 0xAA5533, 125 | emissive: 0x000000, 126 | specular: 0x111111, 127 | shininess: 30, 128 | side: THREE.DoubleSide 129 | }); 130 | 131 | bonematerial5 = new THREE.MeshPhongMaterial({ 132 | color: 0x552211, 133 | emissive: 0x772211, 134 | specular: 0x111111, 135 | shininess: 30, 136 | side: THREE.DoubleSide 137 | }); 138 | 139 | 140 | markermaterial = new THREE.MeshPhongMaterial({ 141 | color: 0xc57206, 142 | emissive: 0x271c18, 143 | side: THREE.DoubleSide, 144 | // shading: THREE.FlatShading, 145 | wireframe: false, 146 | shininess: 20, 147 | }); 148 | 149 | markermaterial2 = new THREE.MeshPhongMaterial({ 150 | color: 0x1562a2, 151 | emissive: 0x271c18, 152 | side: THREE.DoubleSide, 153 | // shading: THREE.FlatShading, 154 | wireframe: false, 155 | shininess: 20, 156 | }); 157 | 158 | markermaterial3 = new THREE.MeshPhongMaterial({ 159 | color: 0x555555, 160 | emissive: 0x999999, 161 | side: THREE.DoubleSide, 162 | // shading: THREE.FlatShading, 163 | wireframe: false, 164 | shininess: 20, 165 | }); 166 | 167 | 168 | var makeMarkerGeometry_Sphere10 = function(markerName, scale) { 169 | return new THREE.SphereGeometry(10, 60, 60); 170 | }; 171 | 172 | var makeMarkerGeometry_Sphere3 = function(markerName, scale) { 173 | return new THREE.SphereGeometry(3, 60, 60); 174 | }; 175 | 176 | var makeMarkerGeometry_SphereX = function(markerName, scale) { 177 | return new THREE.SphereGeometry(5, 60, 60); 178 | }; 179 | 180 | var makeJointGeometry_SphereX = function(X) { 181 | return function(jointName, scale) { 182 | return new THREE.SphereGeometry(X, 60, 60); 183 | }; 184 | }; 185 | 186 | 187 | var makeJointGeometry_Sphere1 = function(jointName, scale) { 188 | return new THREE.SphereGeometry(2 / scale, 60, 60); 189 | }; 190 | 191 | var makeJointGeometry_Sphere2 = function(jointName, scale) { 192 | return new THREE.SphereGeometry(1 / scale, 60, 60); 193 | }; 194 | 195 | var makeJointGeometry_Dode = function(jointName, scale) { 196 | return new THREE.DodecahedronGeometry(1 / scale, 0); 197 | }; 198 | 199 | var makeBoneGeometry_Cylinder1 = function(joint1Name, joint2Name, length, scale) { 200 | return new THREE.CylinderGeometry(1.5 / scale, 0.7 / scale, length, 40); 201 | }; 202 | 203 | var makeBoneGeometry_Cylinder2 = function(joint1Name, joint2Name, length, scale) { 204 | // if (joint1Name.includes("LeftHip")) 205 | // length = 400; 206 | return new THREE.CylinderGeometry(1.5 / scale, 0.2 / scale, length, 40); 207 | }; 208 | 209 | var makeBoneGeometry_Cylinder3 = function(joint1Name, joint2Name, length, scale) { 210 | var c1 = new THREE.CylinderGeometry(1.5 / scale, 0.2 / scale, length / 1, 20); 211 | var c2 = new THREE.CylinderGeometry(0.2 / scale, 1.5 / scale, length / 1, 40); 212 | 213 
| var material = new THREE.MeshPhongMaterial({ 214 | color: 0xF7FE2E 215 | }); 216 | var mmesh = new THREE.Mesh(c1, material); 217 | mmesh.updateMatrix(); 218 | c2.merge(mmesh.geometry, mmesh.matrix); 219 | return c2; 220 | }; 221 | 222 | var makeBoneGeometry_Box1 = function(joint1Name, joint2Name, length, scale) { 223 | return new THREE.BoxGeometry(1 / scale, length, 1 / scale, 40); 224 | }; 225 | 226 | 227 | var makeJointGeometry_Empty = function(jointName, scale) { 228 | return new THREE.SphereGeometry(0.001, 60, 60); 229 | }; 230 | 231 | var makeBoneGeometry_Empty = function(joint1Name, joint2Name, length, scale) { 232 | return new THREE.CylinderGeometry(0.001, 0.001, 0.001, 40); 233 | }; 234 | -------------------------------------------------------------------------------- /models/motion_autoencoder.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def reparameterize(mu, logvar): 8 | std = torch.exp(0.5 * logvar) 9 | eps = torch.randn_like(std) 10 | return mu + eps * std 11 | 12 | 13 | def ConvNormRelu(in_channels, out_channels, downsample=False, padding=0, batchnorm=True): 14 | if not downsample: 15 | k = 3 16 | s = 1 17 | else: 18 | k = 4 19 | s = 2 20 | 21 | conv_block = nn.Conv1d(in_channels, out_channels, kernel_size=k, stride=s, padding=padding) 22 | norm_block = nn.BatchNorm1d(out_channels) 23 | 24 | if batchnorm: 25 | net = nn.Sequential( 26 | conv_block, 27 | norm_block, 28 | nn.LeakyReLU(0.2, True) 29 | ) 30 | else: 31 | net = nn.Sequential( 32 | conv_block, 33 | nn.LeakyReLU(0.2, True) 34 | ) 35 | return net 36 | 37 | 38 | class PoseEncoderConv(nn.Module): 39 | def __init__(self, length, dim, feature_length=32): 40 | super().__init__() 41 | self.base = feature_length 42 | self.net = nn.Sequential( 43 | ConvNormRelu(dim, self.base, batchnorm=True), #32 44 | ConvNormRelu(self.base, self.base*2, batchnorm=True), #30 45 | ConvNormRelu(self.base*2, self.base*2, True, batchnorm=True), #14 46 | nn.Conv1d(self.base*2, self.base, 3) 47 | ) 48 | if length == 88: 49 | self.out_net = nn.Sequential( 50 | nn.Linear(39*self.base, self.base*12), # for 64 frames 51 | nn.BatchNorm1d(self.base*12), 52 | nn.Linear(12*self.base, self.base*4), 53 | nn.BatchNorm1d(self.base*4), 54 | nn.LeakyReLU(True), 55 | nn.Linear(self.base*4, self.base*2), 56 | nn.BatchNorm1d(self.base*2), 57 | nn.LeakyReLU(True), 58 | nn.Linear(self.base*2, self.base), 59 | ) 60 | 61 | if length == 64: 62 | self.out_net = nn.Sequential( 63 | nn.Linear(27*self.base, self.base*12), # for 64 frames 64 | nn.BatchNorm1d(self.base*12), 65 | nn.Linear(12*self.base, self.base*4), 66 | nn.BatchNorm1d(self.base*4), 67 | nn.LeakyReLU(True), 68 | nn.Linear(self.base*4, self.base*2), 69 | nn.BatchNorm1d(self.base*2), 70 | nn.LeakyReLU(True), 71 | nn.Linear(self.base*2, self.base), 72 | ) 73 | 74 | if length == 34: 75 | self.out_net = nn.Sequential( 76 | nn.Linear(12*self.base, self.base*4), # for 34 frames 77 | nn.BatchNorm1d(self.base*4), 78 | nn.LeakyReLU(True), 79 | nn.Linear(self.base*4, self.base*2), 80 | nn.BatchNorm1d(self.base*2), 81 | nn.LeakyReLU(True), 82 | nn.Linear(self.base*2, self.base), 83 | ) 84 | 85 | self.fc_mu = nn.Linear(self.base, self.base) 86 | self.fc_logvar = nn.Linear(self.base, self.base) 87 | 88 | def forward(self, poses, variational_encoding=None): 89 | # encode 90 | poses = poses.transpose(1, 2) 91 | out = self.net(poses) 92 | out = out.flatten(1) 93 | out = 
self.out_net(out) 94 | mu = self.fc_mu(out) 95 | logvar = self.fc_logvar(out) 96 | if variational_encoding: 97 | z = reparameterize(mu, logvar) 98 | else: 99 | z = mu 100 | return z, mu, logvar 101 | 102 | 103 | class PoseDecoderConv(nn.Module): 104 | def __init__(self, length, dim, use_pre_poses=False, feature_length=32): 105 | super().__init__() 106 | self.use_pre_poses = use_pre_poses 107 | self.feat_size = feature_length 108 | 109 | if use_pre_poses: 110 | self.pre_pose_net = nn.Sequential( 111 | nn.Linear(dim * 4, 32), 112 | nn.BatchNorm1d(32), 113 | nn.ReLU(), 114 | nn.Linear(32, 32), 115 | ) 116 | self.feat_size += 32 117 | 118 | # if length == 64: 119 | # self.pre_net = nn.Sequential( 120 | # nn.Linear(self.feat_size, 128), 121 | # nn.BatchNorm1d(128), 122 | # nn.LeakyReLU(True), 123 | # nn.Linear(128, 256), 124 | # ) 125 | # elif length == 34: 126 | # self.pre_net = nn.Sequential( 127 | # nn.Linear(self.feat_size, self.feat_size*2), 128 | # nn.BatchNorm1d(self.feat_size*2), 129 | # nn.LeakyReLU(True), 130 | # nn.Linear(self.feat_size*2, self.feat_size//8*34), 131 | # ) 132 | # else: 133 | # assert False 134 | self.pre_net = nn.Sequential( 135 | nn.Linear(self.feat_size, self.feat_size*2), 136 | nn.BatchNorm1d(self.feat_size*2), 137 | nn.LeakyReLU(True), 138 | nn.Linear(self.feat_size*2, self.feat_size//8*length), 139 | ) 140 | 141 | 142 | self.decoder_size = self.feat_size//8 143 | self.net = nn.Sequential( 144 | nn.ConvTranspose1d(self.decoder_size, self.feat_size, 3), 145 | nn.BatchNorm1d(self.feat_size), 146 | nn.LeakyReLU(0.2, True), 147 | 148 | nn.ConvTranspose1d(self.feat_size, self.feat_size, 3), 149 | nn.BatchNorm1d(self.feat_size), 150 | nn.LeakyReLU(0.2, True), 151 | nn.Conv1d(self.feat_size, self.feat_size*2, 3), 152 | nn.Conv1d(self.feat_size*2, dim, 3), 153 | ) 154 | 155 | def forward(self, feat, pre_poses=None): 156 | if self.use_pre_poses: 157 | pre_pose_feat = self.pre_pose_net(pre_poses.reshape(pre_poses.shape[0], -1)) 158 | feat = torch.cat((pre_pose_feat, feat), dim=1) 159 | #print(feat.shape) 160 | out = self.pre_net(feat) 161 | #print(out.shape) 162 | out = out.view(feat.shape[0], self.decoder_size, -1) 163 | #print(out.shape) 164 | out = self.net(out) 165 | out = out.transpose(1, 2) 166 | return out 167 | 168 | 169 | class EmbeddingNet(nn.Module): 170 | def __init__(self, args): 171 | super().__init__() 172 | n_frames = args.n_poses 173 | # n_frames = 34 174 | pose_dim = args.net_dim_pose 175 | feature_length = args.vae_length 176 | self.pose_encoder = PoseEncoderConv(n_frames, pose_dim, feature_length=feature_length) 177 | self.decoder = PoseDecoderConv(n_frames, pose_dim, feature_length=feature_length) 178 | 179 | def forward(self, pre_poses, poses, variational_encoding=False): 180 | poses_feat, pose_mu, pose_logvar = self.pose_encoder(poses, variational_encoding) 181 | latent_feat = poses_feat 182 | out_poses = self.decoder(latent_feat, pre_poses) 183 | return poses_feat, pose_mu, pose_logvar, out_poses 184 | 185 | def freeze_pose_nets(self): 186 | for param in self.pose_encoder.parameters(): 187 | param.requires_grad = False 188 | for param in self.decoder.parameters(): 189 | param.requires_grad = False 190 | 191 | 192 | class HalfEmbeddingNet(nn.Module): 193 | def __init__(self, args): 194 | super().__init__() 195 | n_frames = args.n_poses 196 | # n_frames = 34 197 | pose_dim = args.net_dim_pose 198 | feature_length = args.vae_length 199 | self.pose_encoder = PoseEncoderConv(n_frames, pose_dim, feature_length=feature_length) 200 | self.decoder = 
PoseDecoderConv(n_frames, pose_dim, feature_length=feature_length) 201 | 202 | def forward(self, poses): 203 | poses_feat, _, _ = self.pose_encoder(poses) 204 | return poses_feat -------------------------------------------------------------------------------- /models/scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license 16 | 17 | def get_schedule(t_T, t_0, n_sample, n_steplength, debug=0): 18 | if n_steplength > 1: 19 | if not n_sample > 1: 20 | raise RuntimeError('n_steplength has no effect if n_sample=1') 21 | 22 | t = t_T 23 | times = [t] 24 | while t >= 0: 25 | t = t - 1 26 | times.append(t) 27 | n_steplength_cur = min(n_steplength, t_T - t) 28 | 29 | for _ in range(n_sample - 1): 30 | 31 | for _ in range(n_steplength_cur): 32 | t = t + 1 33 | times.append(t) 34 | for _ in range(n_steplength_cur): 35 | t = t - 1 36 | times.append(t) 37 | 38 | _check_times(times, t_0, t_T) 39 | 40 | if debug == 2: 41 | for x in [list(range(0, 50)), list(range(-1, -50, -1))]: 42 | _plot_times(x=x, times=[times[i] for i in x]) 43 | 44 | return times 45 | 46 | 47 | def _check_times(times, t_0, t_T): 48 | # Check end 49 | assert times[0] > times[1], (times[0], times[1]) 50 | 51 | # Check beginning 52 | assert times[-1] == -1, times[-1] 53 | 54 | # Steplength = 1 55 | for t_last, t_cur in zip(times[:-1], times[1:]): 56 | assert abs(t_last - t_cur) == 1, (t_last, t_cur) 57 | 58 | # Value range 59 | for t in times: 60 | assert t >= t_0, (t, t_0) 61 | assert t <= t_T, (t, t_T) 62 | 63 | 64 | def _plot_times(x, times): 65 | import matplotlib.pyplot as plt 66 | plt.plot(x, times) 67 | plt.show() 68 | 69 | 70 | def get_schedule_jump(t_T, n_sample, jump_length, jump_n_sample, 71 | jump2_length=1, jump2_n_sample=1, 72 | jump3_length=1, jump3_n_sample=1, 73 | start_resampling=100000000): 74 | 75 | jumps = {} 76 | for j in range(0, t_T - jump_length, jump_length): 77 | jumps[j] = jump_n_sample - 1 78 | 79 | jumps2 = {} 80 | for j in range(0, t_T - jump2_length, jump2_length): 81 | jumps2[j] = jump2_n_sample - 1 82 | 83 | jumps3 = {} 84 | for j in range(0, t_T - jump3_length, jump3_length): 85 | jumps3[j] = jump3_n_sample - 1 86 | 87 | t = t_T 88 | ts = [] 89 | 90 | while t >= 1: 91 | t = t-1 92 | ts.append(t) 93 | 94 | if ( 95 | t + 1 < t_T - 1 and 96 | t <= start_resampling 97 | ): 98 | for _ in range(n_sample - 1): 99 | t = t + 1 100 | ts.append(t) 101 | 102 | if t >= 0: 103 | t = t - 1 104 | ts.append(t) 105 | 106 | if ( 107 | jumps3.get(t, 0) > 0 and 108 | t <= start_resampling - jump3_length 109 | ): 110 | jumps3[t] = jumps3[t] - 1 111 | for _ 
in range(jump3_length): 112 | t = t + 1 113 | ts.append(t) 114 | 115 | if ( 116 | jumps2.get(t, 0) > 0 and 117 | t <= start_resampling - jump2_length 118 | ): 119 | jumps2[t] = jumps2[t] - 1 120 | for _ in range(jump2_length): 121 | t = t + 1 122 | ts.append(t) 123 | jumps3 = {} 124 | for j in range(0, t_T - jump3_length, jump3_length): 125 | jumps3[j] = jump3_n_sample - 1 126 | 127 | if ( 128 | jumps.get(t, 0) > 0 and 129 | t <= start_resampling - jump_length 130 | ): 131 | jumps[t] = jumps[t] - 1 132 | for _ in range(jump_length): 133 | t = t + 1 134 | ts.append(t) 135 | jumps2 = {} 136 | for j in range(0, t_T - jump2_length, jump2_length): 137 | jumps2[j] = jump2_n_sample - 1 138 | 139 | jumps3 = {} 140 | for j in range(0, t_T - jump3_length, jump3_length): 141 | jumps3[j] = jump3_n_sample - 1 142 | 143 | ts.append(-1) 144 | 145 | _check_times(ts, -1, t_T) 146 | 147 | return ts 148 | 149 | 150 | def get_schedule_jump_paper(): 151 | t_T = 250 152 | jump_length = 10 153 | jump_n_sample = 10 154 | 155 | jumps = {} 156 | for j in range(0, t_T - jump_length, jump_length): 157 | jumps[j] = jump_n_sample - 1 158 | 159 | t = t_T 160 | ts = [] 161 | 162 | while t >= 1: 163 | t = t-1 164 | ts.append(t) 165 | 166 | if jumps.get(t, 0) > 0: 167 | jumps[t] = jumps[t] - 1 168 | for _ in range(jump_length): 169 | t = t + 1 170 | ts.append(t) 171 | 172 | ts.append(-1) 173 | 174 | _check_times(ts, -1, t_T) 175 | 176 | return ts 177 | 178 | def get_schedule_jump_cjm_ddim(time_respacing=25, jump_length=1, jump_n_sample=1): 179 | if time_respacing == 25: 180 | t_T = 15 181 | else: 182 | t_T = int(time_respacing * 0.6) 183 | 184 | # t_T = time_respacing 185 | # t_T = 15 186 | 187 | jumps = {} 188 | for j in range(0, t_T - jump_length, jump_length): 189 | jumps[j] = jump_n_sample - 1 190 | 191 | t = t_T 192 | ts = [] 193 | 194 | while t >= 1: 195 | t = t-1 196 | ts.append(t) 197 | 198 | if jumps.get(t, 0) > 0: 199 | jumps[t] = jumps[t] - 1 200 | for _ in range(jump_length): 201 | t = t + 1 202 | ts.append(t) 203 | 204 | ts.append(-1) 205 | 206 | _check_times(ts, -1, t_T) 207 | 208 | return ts 209 | 210 | 211 | def get_schedule_jump_test(to_supplement=False): 212 | ts = get_schedule_jump(t_T=250, n_sample=1, 213 | jump_length=10, jump_n_sample=10, 214 | jump2_length=1, jump2_n_sample=1, 215 | jump3_length=1, jump3_n_sample=1, 216 | start_resampling=250) 217 | 218 | import matplotlib.pyplot as plt 219 | SMALL_SIZE = 8*3 220 | MEDIUM_SIZE = 10*3 221 | BIGGER_SIZE = 12*3 222 | 223 | plt.rc('font', size=SMALL_SIZE) # controls default text sizes 224 | plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title 225 | plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels 226 | plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels 227 | plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels 228 | plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize 229 | plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title 230 | 231 | plt.plot(ts) 232 | 233 | fig = plt.gcf() 234 | fig.set_size_inches(20, 10) 235 | 236 | ax = plt.gca() 237 | ax.set_xlabel('Number of Transitions') 238 | ax.set_ylabel('Diffusion time $t$') 239 | 240 | fig.tight_layout() 241 | 242 | if to_supplement: 243 | out_path = "/cluster/home/alugmayr/gdiff/paper/supplement/figures/jump_sched.pdf" 244 | plt.savefig(out_path) 245 | 246 | out_path = "./schedule.png" 247 | plt.savefig(out_path) 248 | print(out_path) 249 | 250 | 251 | def main(): 252 | get_schedule_jump_test() 253 | 254 | 255 | if 
__name__ == "__main__": 256 | main() -------------------------------------------------------------------------------- /datasets/show.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import numpy as np 4 | import os 5 | from os.path import join as pjoin 6 | import random 7 | import codecs as cs 8 | from tqdm import tqdm 9 | import pickle 10 | import lmdb 11 | import pyarrow 12 | import torch.nn.functional as F 13 | 14 | class ShowDataset(data.Dataset): 15 | """ 16 | TED dataset. 17 | Prepares conditioning information (previous poses + control signal) and the corresponding next poses""" 18 | 19 | def __init__(self, opt, data_path): 20 | """ 21 | Args: 22 | control_data: The control input 23 | joint_data: body pose input 24 | Both with shape (samples, time-slices, features)s 25 | """ 26 | self.opt = opt 27 | 28 | if self.opt.mode != "test_custom_audio": 29 | print("Loading data ...") 30 | self.lmdb_env = lmdb.open(data_path, readonly=True, lock=False) 31 | if opt.use_aud_feat or self.opt.expAddHubert or self.opt.addHubert: 32 | self.aud_feat_path = os.path.join(os.path.dirname(os.path.dirname(data_path)), f"cached_aud_hubert/{os.path.basename(data_path).split('_')[-2]}/hubert_large_ls960_ft") 33 | self.aud_lmdb_env = lmdb.open(self.aud_feat_path, readonly=True, lock=False) 34 | 35 | if opt.use_aud_feat or self.opt.addWav2Vec2: 36 | self.aud_feat_path = os.path.join(os.path.dirname(os.path.dirname(data_path)), f"cached_aud_wav2vec2/{os.path.basename(data_path).split('_')[-2]}/wav2vec2_base_960h") 37 | self.aud_lmdb_env = lmdb.open(self.aud_feat_path, readonly=True, lock=False) 38 | 39 | with self.lmdb_env.begin() as txn: 40 | self.n_samples = txn.stat()["entries"] 41 | # dict_keys(['poses', 'expression', 'aud_feat', 'speaker', 'aud_file', 'betas']) 42 | 43 | mean_std_dict = np.load("data/SHOW/talkshow_mean_std.npy", allow_pickle=True)[()] 44 | 45 | self.pose_mean = self.extract_pose(torch.from_numpy(mean_std_dict["pose_mean"]).float()) 46 | self.pose_std = self.extract_pose(torch.from_numpy(mean_std_dict["pose_std"]).float()) 47 | self.expression_mean = torch.cat([torch.from_numpy(mean_std_dict["pose_mean"][:3]).float(), torch.from_numpy(mean_std_dict["expression_mean"]).float()], dim=-1) 48 | self.expression_std = torch.cat([torch.from_numpy(mean_std_dict["pose_mean"][:3]).float(), torch.from_numpy(mean_std_dict["expression_std"]).float()], dim=-1) 49 | 50 | self.motion_mean = torch.cat([self.pose_mean, self.expression_mean], dim=-1) 51 | self.motion_std = torch.cat([self.pose_std, self.expression_std], dim=-1) 52 | 53 | if self.opt.usePredExpr: 54 | face_list = os.listdir(self.opt.usePredExpr) 55 | face_list = [ff for ff in face_list if ff.endswith(".npy")] 56 | face_list.sort() 57 | self.face_list = [os.path.join(self.opt.usePredExpr, pp) for pp in face_list] 58 | 59 | 60 | 61 | 62 | def __len__(self): 63 | return self.n_samples 64 | 65 | def __getitem__(self, idx): 66 | """ 67 | Returns poses and conditioning. 
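For orientation, a hypothetical way to wire this dataset into the `build_dataloader` helper from `datasets/dataloader.py` (the option fields and the LMDB path are placeholders, not the repo's exact configuration):

```python
from types import SimpleNamespace
from datasets.show import ShowDataset
from datasets.dataloader import build_dataloader

# Placeholder options covering the fields read by ShowDataset; real runs use the repo's option parser.
opt = SimpleNamespace(mode='train', use_aud_feat=False, expAddHubert=False, addHubert=False,
                      addWav2Vec2=False, usePredExpr=None, audio_feat='mfcc')

dataset = ShowDataset(opt, data_path='data/SHOW/train_lmdb')   # path is illustrative
loader = build_dataloader(dataset, samples_per_gpu=32, workers_per_gpu=4, dist=False, shuffle=True)

batch = next(iter(loader))
print(batch['poses'].shape, batch['expression'].shape, batch['aud_feat'].shape)
```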
68 | """ 69 | with self.lmdb_env.begin(write=False) as txn: 70 | key = "{:005}".format(idx).encode("ascii") 71 | sample = txn.get(key) 72 | sample = pyarrow.deserialize(sample) 73 | pose, expression, aud_raw, mfcc, mel, speaker, aud_file, betas = sample 74 | # pose, expression, mfcc = pose.T, expression.T, mfcc.T 75 | pose = torch.from_numpy(pose.copy()).float() 76 | expression = torch.from_numpy(expression.copy()).float() 77 | aud_raw = torch.from_numpy(aud_raw.copy()).float() 78 | mfcc = torch.from_numpy(mfcc.copy()).float() 79 | mel = torch.from_numpy(mel.copy()).float() 80 | speaker = torch.from_numpy(speaker.copy()).float() 81 | betas = torch.from_numpy(betas.copy()).float() 82 | 83 | jaw_pose, leye_pose, reye_pose, global_orient, body_pose, hand_pose = torch.split(pose, [3,3,3,3, 63, 90], dim=-1) 84 | low1, up1, low2, up2, low3, up3, low4, up4 = torch.split(body_pose, [6, 3, 6, 3, 6, 3, 6, 30], dim=-1) 85 | pose = torch.cat([up1, up2, up3, up4, hand_pose], dim=-1) 86 | expression = torch.cat([jaw_pose, expression], dim=-1) 87 | 88 | pose = self.standardize(pose, self.pose_mean, self.pose_std) 89 | expression = self.standardize(expression, self.expression_mean, self.expression_std) 90 | 91 | hubert = None 92 | if self.opt.audio_feat == "hubert" or self.opt.expAddHubert or self.opt.addHubert: 93 | with self.aud_lmdb_env.begin(write=False) as txn_aud: 94 | key = "{:005}".format(idx).encode("ascii") 95 | hubert = txn_aud.get(key) 96 | hubert = pyarrow.deserialize(hubert) 97 | hubert = torch.from_numpy(hubert.copy()).float() 98 | hubert = F.interpolate(hubert.swapaxes(-1,-2).unsqueeze(0), size=pose.shape[0], mode='linear', align_corners=True).swapaxes(-1,-2).squeeze() 99 | 100 | wav2vec2 = None 101 | if self.opt.audio_feat == "wav2vec2" or self.opt.addWav2Vec2: 102 | with self.aud_lmdb_env.begin(write=False) as txn_aud: 103 | key = "{:005}".format(idx).encode("ascii") 104 | wav2vec2 = txn_aud.get(key) 105 | wav2vec2 = pyarrow.deserialize(wav2vec2) 106 | wav2vec2 = torch.from_numpy(wav2vec2.copy()).float() 107 | 108 | if self.opt.audio_feat == "mfcc": 109 | aud_feat = mfcc 110 | elif self.opt.audio_feat == "mel": 111 | aud_feat = mel 112 | elif self.opt.audio_feat == "raw": 113 | aud_feat = aud_raw 114 | elif self.opt.audio_feat == "hubert": 115 | aud_feat = hubert 116 | elif self.opt.audio_feat == "wav2vec2": 117 | aud_feat = wav2vec2 118 | 119 | if self.opt.expAddHubert or self.opt.addHubert: 120 | return {'poses': pose, 121 | 'expression': expression, 122 | 'aud_feat': aud_feat, 123 | 'pretrain_aud_feat': hubert, 124 | 'speaker': speaker, 125 | 'aud_file': aud_file, 126 | 'betas': betas 127 | } 128 | elif self.opt.addWav2Vec2: 129 | return {'poses': pose, 130 | 'expression': expression, 131 | 'aud_feat': aud_feat, 132 | 'pretrain_aud_feat': wav2vec2, 133 | 'speaker': speaker, 134 | 'aud_file': aud_file, 135 | 'betas': betas 136 | } 137 | else: 138 | return {'poses': pose, 139 | 'expression': expression, 140 | 'aud_feat': aud_feat, 141 | 'speaker': speaker, 142 | 'aud_file': aud_file, 143 | 'betas': betas 144 | } 145 | 146 | def extract_pose(self, pose): 147 | jaw_pose, leye_pose, reye_pose, global_orient, body_pose, hand_pose = torch.split(pose, [3,3,3,3, 63, 90], dim=-1) 148 | low1, up1, low2, up2, low3, up3, low4, up4 = torch.split(body_pose, [6, 3, 6, 3, 6, 3, 6, 30], dim=-1) 149 | # pose = torch.cat([jaw_pose, up1, up2, up3, up4, hand_pose], dim=-1) 150 | pose = torch.cat([up1, up2, up3, up4, hand_pose], dim=-1) 151 | return pose 152 | 153 | def standardize(self, data, mean, std): 
154 | scaled = (data - mean) / std 155 | return scaled 156 | 157 | def inv_standardize(self, data, mean, std): 158 | try: 159 | inv_scaled = data * std + mean 160 | except: 161 | inv_scaled = data * std.numpy() + mean.numpy() 162 | return inv_scaled -------------------------------------------------------------------------------- /datasets/pymo/viz_tools.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import IPython 5 | import os 6 | 7 | def save_fig(fig_id, tight_layout=True): 8 | if tight_layout: 9 | plt.tight_layout() 10 | plt.savefig(fig_id + '.png', format='png', dpi=300) 11 | 12 | 13 | def draw_stickfigure(mocap_track, frame, data=None, joints=None, draw_names=False, ax=None, figsize=(8,8)): 14 | if ax is None: 15 | fig = plt.figure(figsize=figsize) 16 | ax = fig.add_subplot(111) 17 | 18 | if joints is None: 19 | joints_to_draw = mocap_track.skeleton.keys() 20 | else: 21 | joints_to_draw = joints 22 | 23 | if data is None: 24 | df = mocap_track.values 25 | else: 26 | df = data 27 | 28 | for joint in joints_to_draw: 29 | ax.scatter(x=df['%s_Xposition'%joint][frame], 30 | y=df['%s_Yposition'%joint][frame], 31 | alpha=0.6, c='b', marker='o') 32 | 33 | parent_x = df['%s_Xposition'%joint][frame] 34 | parent_y = df['%s_Yposition'%joint][frame] 35 | 36 | children_to_draw = [c for c in mocap_track.skeleton[joint]['children'] if c in joints_to_draw] 37 | 38 | for c in children_to_draw: 39 | child_x = df['%s_Xposition'%c][frame] 40 | child_y = df['%s_Yposition'%c][frame] 41 | ax.plot([parent_x, child_x], [parent_y, child_y], 'k-', lw=2) 42 | 43 | if draw_names: 44 | ax.annotate(joint, 45 | (df['%s_Xposition'%joint][frame] + 0.1, 46 | df['%s_Yposition'%joint][frame] + 0.1)) 47 | 48 | return ax 49 | 50 | def draw_stickfigure3d(mocap_track, frame, data=None, joints=None, draw_names=False, ax=None, figsize=(8,8)): 51 | from mpl_toolkits.mplot3d import Axes3D 52 | 53 | if ax is None: 54 | fig = plt.figure(figsize=figsize) 55 | ax = fig.add_subplot(111, projection='3d') 56 | 57 | if joints is None: 58 | joints_to_draw = mocap_track.skeleton.keys() 59 | else: 60 | joints_to_draw = joints 61 | 62 | if data is None: 63 | df = mocap_track.values 64 | else: 65 | df = data 66 | 67 | for joint in joints_to_draw: 68 | parent_x = df['%s_Xposition'%joint][frame] 69 | parent_y = df['%s_Zposition'%joint][frame] 70 | parent_z = df['%s_Yposition'%joint][frame] 71 | # ^ In mocaps, Y is the up-right axis 72 | 73 | ax.scatter(xs=parent_x, 74 | ys=parent_y, 75 | zs=parent_z, 76 | alpha=0.6, c='b', marker='o') 77 | 78 | 79 | children_to_draw = [c for c in mocap_track.skeleton[joint]['children'] if c in joints_to_draw] 80 | 81 | for c in children_to_draw: 82 | child_x = df['%s_Xposition'%c][frame] 83 | child_y = df['%s_Zposition'%c][frame] 84 | child_z = df['%s_Yposition'%c][frame] 85 | # ^ In mocaps, Y is the up-right axis 86 | 87 | ax.plot([parent_x, child_x], [parent_y, child_y], [parent_z, child_z], 'k-', lw=2, c='black') 88 | 89 | if draw_names: 90 | ax.text(x=parent_x + 0.1, 91 | y=parent_y + 0.1, 92 | z=parent_z + 0.1, 93 | s=joint, 94 | color='rgba(0,0,0,0.9)') 95 | 96 | return ax 97 | 98 | 99 | def sketch_move(mocap_track, data=None, ax=None, figsize=(16,8)): 100 | if ax is None: 101 | fig = plt.figure(figsize=figsize) 102 | ax = fig.add_subplot(111) 103 | 104 | if data is None: 105 | data = mocap_track.values 106 | 107 | for frame in range(0, data.shape[0], 4): 108 | # 
draw_stickfigure(mocap_track, f, data=data, ax=ax) 109 | 110 | for joint in mocap_track.skeleton.keys(): 111 | children_to_draw = [c for c in mocap_track.skeleton[joint]['children']] 112 | 113 | parent_x = data['%s_Xposition'%joint][frame] 114 | parent_y = data['%s_Yposition'%joint][frame] 115 | 116 | frame_alpha = frame/data.shape[0] 117 | 118 | for c in children_to_draw: 119 | child_x = data['%s_Xposition'%c][frame] 120 | child_y = data['%s_Yposition'%c][frame] 121 | 122 | ax.plot([parent_x, child_x], [parent_y, child_y], '-', lw=1, color='gray', alpha=frame_alpha) 123 | 124 | 125 | 126 | def viz_cnn_filter(feature_to_viz, mocap_track, data, gap=25): 127 | fig = plt.figure(figsize=(16,4)) 128 | ax = plt.subplot2grid((1,8),(0,0)) 129 | ax.imshow(feature_to_viz.T, aspect='auto', interpolation='nearest') 130 | 131 | ax = plt.subplot2grid((1,8),(0,1), colspan=7) 132 | for frame in range(feature_to_viz.shape[0]): 133 | frame_alpha = 0.2#frame/data.shape[0] * 2 + 0.2 134 | 135 | for joint_i, joint in enumerate(mocap_track.skeleton.keys()): 136 | children_to_draw = [c for c in mocap_track.skeleton[joint]['children']] 137 | 138 | parent_x = data['%s_Xposition'%joint][frame] + frame * gap 139 | parent_y = data['%s_Yposition'%joint][frame] 140 | 141 | ax.scatter(x=parent_x, 142 | y=parent_y, 143 | alpha=0.6, 144 | cmap='RdBu', 145 | c=feature_to_viz[frame][joint_i] * 10000, 146 | marker='o', 147 | s = abs(feature_to_viz[frame][joint_i] * 10000)) 148 | plt.axis('off') 149 | for c in children_to_draw: 150 | child_x = data['%s_Xposition'%c][frame] + frame * gap 151 | child_y = data['%s_Yposition'%c][frame] 152 | 153 | ax.plot([parent_x, child_x], [parent_y, child_y], '-', lw=1, color='gray', alpha=frame_alpha) 154 | 155 | 156 | def print_skel(X): 157 | stack = [X.root_name] 158 | tab=0 159 | while stack: 160 | joint = stack.pop() 161 | tab = len(stack) 162 | print('%s- %s (%s)'%('| '*tab, joint, X.skeleton[joint]['parent'])) 163 | for c in X.skeleton[joint]['children']: 164 | stack.append(c) 165 | 166 | 167 | def nb_play_mocap_fromurl(mocap, mf, frame_time=1/30, scale=1, base_url='http://titan:8385'): 168 | if mf == 'bvh': 169 | bw = BVHWriter() 170 | with open('test.bvh', 'w') as ofile: 171 | bw.write(mocap, ofile) 172 | 173 | filepath = '../notebooks/test.bvh' 174 | elif mf == 'pos': 175 | c = list(mocap.values.columns) 176 | 177 | for cc in c: 178 | if 'rotation' in cc: 179 | c.remove(cc) 180 | mocap.values.to_csv('test.csv', index=False, columns=c) 181 | 182 | filepath = '../notebooks/test.csv' 183 | else: 184 | return 185 | 186 | url = '%s/mocapplayer/player.html?data_url=%s&scale=%f&cz=200&order=xzyi&frame_time=%f'%(base_url, filepath, scale, frame_time) 187 | iframe = '' 188 | link = 'New Window'%url 189 | return IPython.display.HTML(iframe+link) 190 | 191 | def nb_play_mocap(mocap, mf, meta=None, frame_time=1/30, scale=1, camera_z=500, base_url=None): 192 | data_template = 'var dataBuffer = `$$DATA$$`;' 193 | data_template += 'var metadata = $$META$$;' 194 | data_template += 'start(dataBuffer, metadata, $$CZ$$, $$SCALE$$, $$FRAMETIME$$);' 195 | dir_path = os.path.dirname(os.path.realpath(__file__)) 196 | 197 | 198 | if base_url is None: 199 | base_url = os.path.join(dir_path, 'mocapplayer/playBuffer.html') 200 | 201 | # print(dir_path) 202 | 203 | if mf == 'bvh': 204 | pass 205 | elif mf == 'pos': 206 | cols = list(mocap.values.columns) 207 | for c in cols: 208 | if 'rotation' in c: 209 | cols.remove(c) 210 | 211 | data_csv = mocap.values.to_csv(index=False, columns=cols) 212 | 213 | if meta 
is not None: 214 | lines = [','.join(item) for item in meta.astype('str')] 215 | meta_csv = '[' + ','.join('[%s]'%l for l in lines) +']' 216 | else: 217 | meta_csv = '[]' 218 | 219 | data_assigned = data_template.replace('$$DATA$$', data_csv) 220 | data_assigned = data_assigned.replace('$$META$$', meta_csv) 221 | data_assigned = data_assigned.replace('$$CZ$$', str(camera_z)) 222 | data_assigned = data_assigned.replace('$$SCALE$$', str(scale)) 223 | data_assigned = data_assigned.replace('$$FRAMETIME$$', str(frame_time)) 224 | 225 | else: 226 | return 227 | 228 | 229 | 230 | with open(os.path.join(dir_path, 'mocapplayer/data.js'), 'w') as oFile: 231 | oFile.write(data_assigned) 232 | 233 | url = '%s?&cz=200&order=xzyi&frame_time=%f&scale=%f'%(base_url, frame_time, scale) 234 | iframe = '' 235 | link = 'New Window'%url 236 | return IPython.display.HTML(iframe+link) 237 | -------------------------------------------------------------------------------- /datasets/pymo/mocapplayer/playURL.html: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | 7 | 8 | BVH Player 9 | 10 | 11 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 69 | 70 | 267 | 268 | 269 | 270 | -------------------------------------------------------------------------------- /datasets/pymo/parsers.py: -------------------------------------------------------------------------------- 1 | ''' 2 | BVH Parser Class 3 | 4 | By Omid Alemi 5 | Created: June 12, 2017 6 | 7 | Based on: https://gist.github.com/johnfredcee/2007503 8 | 9 | ''' 10 | import re 11 | import numpy as np 12 | from .data import Joint, MocapData 13 | 14 | class BVHScanner(): 15 | ''' 16 | A wrapper class for re.Scanner 17 | ''' 18 | def __init__(self): 19 | 20 | def identifier(scanner, token): 21 | return 'IDENT', token 22 | 23 | def operator(scanner, token): 24 | return 'OPERATOR', token 25 | 26 | def digit(scanner, token): 27 | return 'DIGIT', token 28 | 29 | def open_brace(scanner, token): 30 | return 'OPEN_BRACE', token 31 | 32 | def close_brace(scanner, token): 33 | return 'CLOSE_BRACE', token 34 | 35 | self.scanner = re.Scanner([ 36 | (r'[a-zA-Z_]\w*', identifier), 37 | #(r'-*[0-9]+(\.[0-9]+)?', digit), # won't work for .34 38 | #(r'[-+]?[0-9]*\.?[0-9]+', digit), # won't work for 4.56e-2 39 | #(r'[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?', digit), 40 | (r'-*[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?', digit), 41 | (r'}', close_brace), 42 | (r'}', close_brace), 43 | (r'{', open_brace), 44 | (r':', None), 45 | (r'\s+', None) 46 | ]) 47 | 48 | def scan(self, stuff): 49 | return self.scanner.scan(stuff) 50 | 51 | 52 | 53 | class BVHParser(): 54 | ''' 55 | A class to parse a BVH file. 
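A hypothetical end-to-end use of the parser defined here (the BVH path is a placeholder):

```python
from datasets.pymo.parsers import BVHParser

parser = BVHParser()
mocap = parser.parse('example_motion.bvh')  # returns a MocapData object
print(mocap.root_name)       # name of the skeleton's root joint
print(mocap.framerate)       # frame time read from the BVH MOTION section
print(mocap.values.shape)    # (n_frames, n_channels) pandas DataFrame of channel values
```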
56 | 57 | Extracts the skeleton and channel values 58 | ''' 59 | def __init__(self, filename=None): 60 | self.reset() 61 | 62 | def reset(self): 63 | self._skeleton = {} 64 | self.bone_context = [] 65 | self._motion_channels = [] 66 | self._motions = [] 67 | self.current_token = 0 68 | self.framerate = 0.0 69 | self.root_name = '' 70 | 71 | self.scanner = BVHScanner() 72 | 73 | self.data = MocapData() 74 | 75 | 76 | def parse(self, filename, start=0, stop=-1): 77 | self.reset() 78 | 79 | with open(filename, 'r') as bvh_file: 80 | raw_contents = bvh_file.read() 81 | tokens, remainder = self.scanner.scan(raw_contents) 82 | self._parse_hierarchy(tokens) 83 | self.current_token = self.current_token + 1 84 | self._parse_motion(tokens, start, stop) 85 | 86 | self.data.skeleton = self._skeleton 87 | self.data.channel_names = self._motion_channels 88 | self.data.values = self._to_DataFrame() 89 | self.data.root_name = self.root_name 90 | self.data.framerate = self.framerate 91 | 92 | return self.data 93 | 94 | def _to_DataFrame(self): 95 | '''Returns all of the channels parsed from the file as a pandas DataFrame''' 96 | 97 | import pandas as pd 98 | time_index = pd.to_timedelta([f[0] for f in self._motions], unit='s') 99 | frames = [f[1] for f in self._motions] 100 | channels = np.asarray([[channel[2] for channel in frame] for frame in frames]) 101 | column_names = ['%s_%s'%(c[0], c[1]) for c in self._motion_channels] 102 | 103 | return pd.DataFrame(data=channels, index=time_index, columns=column_names) 104 | 105 | 106 | def _new_bone(self, parent, name): 107 | bone = {'parent': parent, 'channels': [], 'offsets': [], 'order': '','children': []} 108 | return bone 109 | 110 | def _push_bone_context(self,name): 111 | self.bone_context.append(name) 112 | 113 | def _get_bone_context(self): 114 | return self.bone_context[len(self.bone_context)-1] 115 | 116 | def _pop_bone_context(self): 117 | self.bone_context = self.bone_context[:-1] 118 | return self.bone_context[len(self.bone_context)-1] 119 | 120 | def _read_offset(self, bvh, token_index): 121 | if bvh[token_index] != ('IDENT', 'OFFSET'): 122 | return None, None 123 | token_index = token_index + 1 124 | offsets = [0.0] * 3 125 | for i in range(3): 126 | offsets[i] = float(bvh[token_index][1]) 127 | token_index = token_index + 1 128 | return offsets, token_index 129 | 130 | def _read_channels(self, bvh, token_index): 131 | if bvh[token_index] != ('IDENT', 'CHANNELS'): 132 | return None, None 133 | token_index = token_index + 1 134 | channel_count = int(bvh[token_index][1]) 135 | token_index = token_index + 1 136 | channels = [""] * channel_count 137 | order = "" 138 | for i in range(channel_count): 139 | channels[i] = bvh[token_index][1] 140 | token_index = token_index + 1 141 | if(channels[i] == "Xrotation" or channels[i]== "Yrotation" or channels[i]== "Zrotation"): 142 | order += channels[i][0] 143 | else : 144 | order = "" 145 | return channels, token_index, order 146 | 147 | def _parse_joint(self, bvh, token_index): 148 | end_site = False 149 | joint_id = bvh[token_index][1] 150 | token_index = token_index + 1 151 | joint_name = bvh[token_index][1] 152 | token_index = token_index + 1 153 | 154 | parent_name = self._get_bone_context() 155 | 156 | if (joint_id == "End"): 157 | joint_name = parent_name+ '_Nub' 158 | end_site = True 159 | joint = self._new_bone(parent_name, joint_name) 160 | if bvh[token_index][0] != 'OPEN_BRACE': 161 | print('Was expecting brance, got ', bvh[token_index]) 162 | return None 163 | token_index = token_index + 1 164 | 
offsets, token_index = self._read_offset(bvh, token_index) 165 | joint['offsets'] = offsets 166 | if not end_site: 167 | channels, token_index, order = self._read_channels(bvh, token_index) 168 | joint['channels'] = channels 169 | joint['order'] = order 170 | for channel in channels: 171 | self._motion_channels.append((joint_name, channel)) 172 | 173 | self._skeleton[joint_name] = joint 174 | self._skeleton[parent_name]['children'].append(joint_name) 175 | 176 | while (bvh[token_index][0] == 'IDENT' and bvh[token_index][1] == 'JOINT') or (bvh[token_index][0] == 'IDENT' and bvh[token_index][1] == 'End'): 177 | self._push_bone_context(joint_name) 178 | token_index = self._parse_joint(bvh, token_index) 179 | self._pop_bone_context() 180 | 181 | if bvh[token_index][0] == 'CLOSE_BRACE': 182 | return token_index + 1 183 | 184 | print('Unexpected token ', bvh[token_index]) 185 | 186 | def _parse_hierarchy(self, bvh): 187 | self.current_token = 0 188 | if bvh[self.current_token] != ('IDENT', 'HIERARCHY'): 189 | return None 190 | self.current_token = self.current_token + 1 191 | if bvh[self.current_token] != ('IDENT', 'ROOT'): 192 | return None 193 | self.current_token = self.current_token + 1 194 | if bvh[self.current_token][0] != 'IDENT': 195 | return None 196 | 197 | root_name = bvh[self.current_token][1] 198 | root_bone = self._new_bone(None, root_name) 199 | self.current_token = self.current_token + 2 #skipping open brace 200 | offsets, self.current_token = self._read_offset(bvh, self.current_token) 201 | channels, self.current_token, order = self._read_channels(bvh, self.current_token) 202 | root_bone['offsets'] = offsets 203 | root_bone['channels'] = channels 204 | root_bone['order'] = order 205 | self._skeleton[root_name] = root_bone 206 | self._push_bone_context(root_name) 207 | 208 | for channel in channels: 209 | self._motion_channels.append((root_name, channel)) 210 | 211 | while bvh[self.current_token][1] == 'JOINT': 212 | self.current_token = self._parse_joint(bvh, self.current_token) 213 | 214 | self.root_name = root_name 215 | 216 | def _parse_motion(self, bvh, start, stop): 217 | if bvh[self.current_token][0] != 'IDENT': 218 | print('Unexpected text') 219 | return None 220 | if bvh[self.current_token][1] != 'MOTION': 221 | print('No motion section') 222 | return None 223 | self.current_token = self.current_token + 1 224 | if bvh[self.current_token][1] != 'Frames': 225 | return None 226 | self.current_token = self.current_token + 1 227 | frame_count = int(bvh[self.current_token][1]) 228 | 229 | if stop<0 or stop>frame_count: 230 | stop = frame_count 231 | 232 | assert(start>=0) 233 | assert(start=start: 258 | self._motions[idx] = (frame_time, channel_values) 259 | frame_time = frame_time + frame_rate 260 | idx+=1 261 | -------------------------------------------------------------------------------- /datasets/pymo/mocapplayer/libs/pace.min.js: -------------------------------------------------------------------------------- 1 | /*! 
pace 1.0.2 */ 2 | (function(){var a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z,A,B,C,D,E,F,G,H,I,J,K,L,M,N,O,P,Q,R,S,T,U,V,W,X=[].slice,Y={}.hasOwnProperty,Z=function(a,b){function c(){this.constructor=a}for(var d in b)Y.call(b,d)&&(a[d]=b[d]);return c.prototype=b.prototype,a.prototype=new c,a.__super__=b.prototype,a},$=[].indexOf||function(a){for(var b=0,c=this.length;c>b;b++)if(b in this&&this[b]===a)return b;return-1};for(u={catchupTime:100,initialRate:.03,minTime:250,ghostTime:100,maxProgressPerFrame:20,easeFactor:1.25,startOnPageLoad:!0,restartOnPushState:!0,restartOnRequestAfter:500,target:"body",elements:{checkInterval:100,selectors:["body"]},eventLag:{minSamples:10,sampleCount:3,lagThreshold:3},ajax:{trackMethods:["GET"],trackWebSockets:!0,ignoreURLs:[]}},C=function(){var a;return null!=(a="undefined"!=typeof performance&&null!==performance&&"function"==typeof performance.now?performance.now():void 0)?a:+new Date},E=window.requestAnimationFrame||window.mozRequestAnimationFrame||window.webkitRequestAnimationFrame||window.msRequestAnimationFrame,t=window.cancelAnimationFrame||window.mozCancelAnimationFrame,null==E&&(E=function(a){return setTimeout(a,50)},t=function(a){return clearTimeout(a)}),G=function(a){var b,c;return b=C(),(c=function(){var d;return d=C()-b,d>=33?(b=C(),a(d,function(){return E(c)})):setTimeout(c,33-d)})()},F=function(){var a,b,c;return c=arguments[0],b=arguments[1],a=3<=arguments.length?X.call(arguments,2):[],"function"==typeof c[b]?c[b].apply(c,a):c[b]},v=function(){var a,b,c,d,e,f,g;for(b=arguments[0],d=2<=arguments.length?X.call(arguments,1):[],f=0,g=d.length;g>f;f++)if(c=d[f])for(a in c)Y.call(c,a)&&(e=c[a],null!=b[a]&&"object"==typeof b[a]&&null!=e&&"object"==typeof e?v(b[a],e):b[a]=e);return b},q=function(a){var b,c,d,e,f;for(c=b=0,e=0,f=a.length;f>e;e++)d=a[e],c+=Math.abs(d),b++;return c/b},x=function(a,b){var c,d,e;if(null==a&&(a="options"),null==b&&(b=!0),e=document.querySelector("[data-pace-"+a+"]")){if(c=e.getAttribute("data-pace-"+a),!b)return c;try{return JSON.parse(c)}catch(f){return d=f,"undefined"!=typeof console&&null!==console?console.error("Error parsing inline pace options",d):void 0}}},g=function(){function a(){}return a.prototype.on=function(a,b,c,d){var e;return null==d&&(d=!1),null==this.bindings&&(this.bindings={}),null==(e=this.bindings)[a]&&(e[a]=[]),this.bindings[a].push({handler:b,ctx:c,once:d})},a.prototype.once=function(a,b,c){return this.on(a,b,c,!0)},a.prototype.off=function(a,b){var c,d,e;if(null!=(null!=(d=this.bindings)?d[a]:void 0)){if(null==b)return delete this.bindings[a];for(c=0,e=[];cQ;Q++)K=U[Q],D[K]===!0&&(D[K]=u[K]);i=function(a){function b(){return V=b.__super__.constructor.apply(this,arguments)}return Z(b,a),b}(Error),b=function(){function a(){this.progress=0}return a.prototype.getElement=function(){var a;if(null==this.el){if(a=document.querySelector(D.target),!a)throw new i;this.el=document.createElement("div"),this.el.className="pace pace-active",document.body.className=document.body.className.replace(/pace-done/g,""),document.body.className+=" pace-running",this.el.innerHTML='
\n
\n
\n
',null!=a.firstChild?a.insertBefore(this.el,a.firstChild):a.appendChild(this.el)}return this.el},a.prototype.finish=function(){var a;return a=this.getElement(),a.className=a.className.replace("pace-active",""),a.className+=" pace-inactive",document.body.className=document.body.className.replace("pace-running",""),document.body.className+=" pace-done"},a.prototype.update=function(a){return this.progress=a,this.render()},a.prototype.destroy=function(){try{this.getElement().parentNode.removeChild(this.getElement())}catch(a){i=a}return this.el=void 0},a.prototype.render=function(){var a,b,c,d,e,f,g;if(null==document.querySelector(D.target))return!1;for(a=this.getElement(),d="translate3d("+this.progress+"%, 0, 0)",g=["webkitTransform","msTransform","transform"],e=0,f=g.length;f>e;e++)b=g[e],a.children[0].style[b]=d;return(!this.lastRenderedProgress||this.lastRenderedProgress|0!==this.progress|0)&&(a.children[0].setAttribute("data-progress-text",""+(0|this.progress)+"%"),this.progress>=100?c="99":(c=this.progress<10?"0":"",c+=0|this.progress),a.children[0].setAttribute("data-progress",""+c)),this.lastRenderedProgress=this.progress},a.prototype.done=function(){return this.progress>=100},a}(),h=function(){function a(){this.bindings={}}return a.prototype.trigger=function(a,b){var c,d,e,f,g;if(null!=this.bindings[a]){for(f=this.bindings[a],g=[],d=0,e=f.length;e>d;d++)c=f[d],g.push(c.call(this,b));return g}},a.prototype.on=function(a,b){var c;return null==(c=this.bindings)[a]&&(c[a]=[]),this.bindings[a].push(b)},a}(),P=window.XMLHttpRequest,O=window.XDomainRequest,N=window.WebSocket,w=function(a,b){var c,d,e;e=[];for(d in b.prototype)try{e.push(null==a[d]&&"function"!=typeof b[d]?"function"==typeof Object.defineProperty?Object.defineProperty(a,d,{get:function(){return b.prototype[d]},configurable:!0,enumerable:!0}):a[d]=b.prototype[d]:void 0)}catch(f){c=f}return e},A=[],j.ignore=function(){var a,b,c;return b=arguments[0],a=2<=arguments.length?X.call(arguments,1):[],A.unshift("ignore"),c=b.apply(null,a),A.shift(),c},j.track=function(){var a,b,c;return b=arguments[0],a=2<=arguments.length?X.call(arguments,1):[],A.unshift("track"),c=b.apply(null,a),A.shift(),c},J=function(a){var b;if(null==a&&(a="GET"),"track"===A[0])return"force";if(!A.length&&D.ajax){if("socket"===a&&D.ajax.trackWebSockets)return!0;if(b=a.toUpperCase(),$.call(D.ajax.trackMethods,b)>=0)return!0}return!1},k=function(a){function b(){var a,c=this;b.__super__.constructor.apply(this,arguments),a=function(a){var b;return b=a.open,a.open=function(d,e){return J(d)&&c.trigger("request",{type:d,url:e,request:a}),b.apply(a,arguments)}},window.XMLHttpRequest=function(b){var c;return c=new P(b),a(c),c};try{w(window.XMLHttpRequest,P)}catch(d){}if(null!=O){window.XDomainRequest=function(){var b;return b=new O,a(b),b};try{w(window.XDomainRequest,O)}catch(d){}}if(null!=N&&D.ajax.trackWebSockets){window.WebSocket=function(a,b){var d;return d=null!=b?new N(a,b):new N(a),J("socket")&&c.trigger("request",{type:"socket",url:a,protocols:b,request:d}),d};try{w(window.WebSocket,N)}catch(d){}}}return Z(b,a),b}(h),R=null,y=function(){return null==R&&(R=new k),R},I=function(a){var b,c,d,e;for(e=D.ajax.ignoreURLs,c=0,d=e.length;d>c;c++)if(b=e[c],"string"==typeof b){if(-1!==a.indexOf(b))return!0}else if(b.test(a))return!0;return!1},y().on("request",function(b){var c,d,e,f,g;return f=b.type,e=b.request,g=b.url,I(g)?void 0:j.running||D.restartOnRequestAfter===!1&&"force"!==J(f)?void 0:(d=arguments,c=D.restartOnRequestAfter||0,"boolean"==typeof 
c&&(c=0),setTimeout(function(){var b,c,g,h,i,k;if(b="socket"===f?e.readyState<2:0<(h=e.readyState)&&4>h){for(j.restart(),i=j.sources,k=[],c=0,g=i.length;g>c;c++){if(K=i[c],K instanceof a){K.watch.apply(K,d);break}k.push(void 0)}return k}},c))}),a=function(){function a(){var a=this;this.elements=[],y().on("request",function(){return a.watch.apply(a,arguments)})}return a.prototype.watch=function(a){var b,c,d,e;return d=a.type,b=a.request,e=a.url,I(e)?void 0:(c="socket"===d?new n(b):new o(b),this.elements.push(c))},a}(),o=function(){function a(a){var b,c,d,e,f,g,h=this;if(this.progress=0,null!=window.ProgressEvent)for(c=null,a.addEventListener("progress",function(a){return h.progress=a.lengthComputable?100*a.loaded/a.total:h.progress+(100-h.progress)/2},!1),g=["load","abort","timeout","error"],d=0,e=g.length;e>d;d++)b=g[d],a.addEventListener(b,function(){return h.progress=100},!1);else f=a.onreadystatechange,a.onreadystatechange=function(){var b;return 0===(b=a.readyState)||4===b?h.progress=100:3===a.readyState&&(h.progress=50),"function"==typeof f?f.apply(null,arguments):void 0}}return a}(),n=function(){function a(a){var b,c,d,e,f=this;for(this.progress=0,e=["error","open"],c=0,d=e.length;d>c;c++)b=e[c],a.addEventListener(b,function(){return f.progress=100},!1)}return a}(),d=function(){function a(a){var b,c,d,f;for(null==a&&(a={}),this.elements=[],null==a.selectors&&(a.selectors=[]),f=a.selectors,c=0,d=f.length;d>c;c++)b=f[c],this.elements.push(new e(b))}return a}(),e=function(){function a(a){this.selector=a,this.progress=0,this.check()}return a.prototype.check=function(){var a=this;return document.querySelector(this.selector)?this.done():setTimeout(function(){return a.check()},D.elements.checkInterval)},a.prototype.done=function(){return this.progress=100},a}(),c=function(){function a(){var a,b,c=this;this.progress=null!=(b=this.states[document.readyState])?b:100,a=document.onreadystatechange,document.onreadystatechange=function(){return null!=c.states[document.readyState]&&(c.progress=c.states[document.readyState]),"function"==typeof a?a.apply(null,arguments):void 0}}return a.prototype.states={loading:0,interactive:50,complete:100},a}(),f=function(){function a(){var a,b,c,d,e,f=this;this.progress=0,a=0,e=[],d=0,c=C(),b=setInterval(function(){var g;return g=C()-c-50,c=C(),e.push(g),e.length>D.eventLag.sampleCount&&e.shift(),a=q(e),++d>=D.eventLag.minSamples&&a=100&&(this.done=!0),b===this.last?this.sinceLastUpdate+=a:(this.sinceLastUpdate&&(this.rate=(b-this.last)/this.sinceLastUpdate),this.catchup=(b-this.progress)/D.catchupTime,this.sinceLastUpdate=0,this.last=b),b>this.progress&&(this.progress+=this.catchup*a),c=1-Math.pow(this.progress/100,D.easeFactor),this.progress+=c*this.rate*a,this.progress=Math.min(this.lastProgress+D.maxProgressPerFrame,this.progress),this.progress=Math.max(0,this.progress),this.progress=Math.min(100,this.progress),this.lastProgress=this.progress,this.progress},a}(),L=null,H=null,r=null,M=null,p=null,s=null,j.running=!1,z=function(){return D.restartOnPushState?j.restart():void 0},null!=window.history.pushState&&(T=window.history.pushState,window.history.pushState=function(){return z(),T.apply(window.history,arguments)}),null!=window.history.replaceState&&(W=window.history.replaceState,window.history.replaceState=function(){return z(),W.apply(window.history,arguments)}),l={ajax:a,elements:d,document:c,eventLag:f},(B=function(){var a,c,d,e,f,g,h,i;for(j.sources=L=[],g=["ajax","elements","document","eventLag"],c=0,e=g.length;e>c;c++)a=g[c],D[a]!==!1&&L.push(new 
l[a](D[a]));for(i=null!=(h=D.extraSources)?h:[],d=0,f=i.length;f>d;d++)K=i[d],L.push(new K(D));return j.bar=r=new b,H=[],M=new m})(),j.stop=function(){return j.trigger("stop"),j.running=!1,r.destroy(),s=!0,null!=p&&("function"==typeof t&&t(p),p=null),B()},j.restart=function(){return j.trigger("restart"),j.stop(),j.start()},j.go=function(){var a;return j.running=!0,r.render(),a=C(),s=!1,p=G(function(b,c){var d,e,f,g,h,i,k,l,n,o,p,q,t,u,v,w;for(l=100-r.progress,e=p=0,f=!0,i=q=0,u=L.length;u>q;i=++q)for(K=L[i],o=null!=H[i]?H[i]:H[i]=[],h=null!=(w=K.elements)?w:[K],k=t=0,v=h.length;v>t;k=++t)g=h[k],n=null!=o[k]?o[k]:o[k]=new m(g),f&=n.done,n.done||(e++,p+=n.tick(b));return d=p/e,r.update(M.tick(b,d)),r.done()||f||s?(r.update(100),j.trigger("done"),setTimeout(function(){return r.finish(),j.running=!1,j.trigger("hide")},Math.max(D.ghostTime,Math.max(D.minTime-(C()-a),0)))):c()})},j.start=function(a){v(D,a),j.running=!0;try{r.render()}catch(b){i=b}return document.querySelector(".pace")?(j.trigger("start"),j.go()):setTimeout(j.start,50)},"function"==typeof define&&define.amd?define(["pace"],function(){return j}):"object"==typeof exports?module.exports=j:D.startOnPageLoad&&j.start()}).call(this); -------------------------------------------------------------------------------- /datasets/pymo/mocapplayer/libs/papaparse.min.js: -------------------------------------------------------------------------------- 1 | /*! 2 | Papa Parse 3 | v4.1.2 4 | https://github.com/mholt/PapaParse 5 | */ 6 | !function(e){"use strict";function t(t,r){if(r=r||{},r.worker&&S.WORKERS_SUPPORTED){var n=f();return n.userStep=r.step,n.userChunk=r.chunk,n.userComplete=r.complete,n.userError=r.error,r.step=m(r.step),r.chunk=m(r.chunk),r.complete=m(r.complete),r.error=m(r.error),delete r.worker,void n.postMessage({input:t,config:r,workerId:n.id})}var o=null;return"string"==typeof t?o=r.download?new i(r):new a(r):(e.File&&t instanceof File||t instanceof Object)&&(o=new s(r)),o.stream(t)}function r(e,t){function r(){"object"==typeof t&&("string"==typeof t.delimiter&&1==t.delimiter.length&&-1==S.BAD_DELIMITERS.indexOf(t.delimiter)&&(u=t.delimiter),("boolean"==typeof t.quotes||t.quotes instanceof Array)&&(o=t.quotes),"string"==typeof t.newline&&(h=t.newline))}function n(e){if("object"!=typeof e)return[];var t=[];for(var r in e)t.push(r);return t}function i(e,t){var r="";"string"==typeof e&&(e=JSON.parse(e)),"string"==typeof t&&(t=JSON.parse(t));var n=e instanceof Array&&e.length>0,i=!(t[0]instanceof Array);if(n){for(var a=0;a0&&(r+=u),r+=s(e[a],a);t.length>0&&(r+=h)}for(var o=0;oc;c++){c>0&&(r+=u);var d=n&&i?e[c]:c;r+=s(t[o][d],c)}o-1||" "==e.charAt(0)||" "==e.charAt(e.length-1);return r?'"'+e+'"':e}function a(e,t){for(var r=0;r-1)return!0;return!1}var o=!1,u=",",h="\r\n";if(r(),"string"==typeof e&&(e=JSON.parse(e)),e instanceof Array){if(!e.length||e[0]instanceof Array)return i(null,e);if("object"==typeof e[0])return i(n(e[0]),e)}else if("object"==typeof e)return"string"==typeof e.data&&(e.data=JSON.parse(e.data)),e.data instanceof Array&&(e.fields||(e.fields=e.data[0]instanceof Array?e.fields:n(e.data[0])),e.data[0]instanceof Array||"object"==typeof e.data[0]||(e.data=[e.data])),i(e.fields||[],e.data||[]);throw"exception: Unable to serialize unrecognized input"}function n(t){function r(e){var t=_(e);t.chunkSize=parseInt(t.chunkSize),e.step||e.chunk||(t.chunkSize=null),this._handle=new 
o(t),this._handle.streamer=this,this._config=t}this._handle=null,this._paused=!1,this._finished=!1,this._input=null,this._baseIndex=0,this._partialLine="",this._rowCount=0,this._start=0,this._nextChunk=null,this.isFirstChunk=!0,this._completeResults={data:[],errors:[],meta:{}},r.call(this,t),this.parseChunk=function(t){if(this.isFirstChunk&&m(this._config.beforeFirstChunk)){var r=this._config.beforeFirstChunk(t);void 0!==r&&(t=r)}this.isFirstChunk=!1;var n=this._partialLine+t;this._partialLine="";var i=this._handle.parse(n,this._baseIndex,!this._finished);if(!this._handle.paused()&&!this._handle.aborted()){var s=i.meta.cursor;this._finished||(this._partialLine=n.substring(s-this._baseIndex),this._baseIndex=s),i&&i.data&&(this._rowCount+=i.data.length);var a=this._finished||this._config.preview&&this._rowCount>=this._config.preview;if(y)e.postMessage({results:i,workerId:S.WORKER_ID,finished:a});else if(m(this._config.chunk)){if(this._config.chunk(i,this._handle),this._paused)return;i=void 0,this._completeResults=void 0}return this._config.step||this._config.chunk||(this._completeResults.data=this._completeResults.data.concat(i.data),this._completeResults.errors=this._completeResults.errors.concat(i.errors),this._completeResults.meta=i.meta),!a||!m(this._config.complete)||i&&i.meta.aborted||this._config.complete(this._completeResults),a||i&&i.meta.paused||this._nextChunk(),i}},this._sendError=function(t){m(this._config.error)?this._config.error(t):y&&this._config.error&&e.postMessage({workerId:S.WORKER_ID,error:t,finished:!1})}}function i(e){function t(e){var t=e.getResponseHeader("Content-Range");return parseInt(t.substr(t.lastIndexOf("/")+1))}e=e||{},e.chunkSize||(e.chunkSize=S.RemoteChunkSize),n.call(this,e);var r;this._nextChunk=k?function(){this._readChunk(),this._chunkLoaded()}:function(){this._readChunk()},this.stream=function(e){this._input=e,this._nextChunk()},this._readChunk=function(){if(this._finished)return void this._chunkLoaded();if(r=new XMLHttpRequest,k||(r.onload=g(this._chunkLoaded,this),r.onerror=g(this._chunkError,this)),r.open("GET",this._input,!k),this._config.chunkSize){var e=this._start+this._config.chunkSize-1;r.setRequestHeader("Range","bytes="+this._start+"-"+e),r.setRequestHeader("If-None-Match","webkit-no-cache")}try{r.send()}catch(t){this._chunkError(t.message)}k&&0==r.status?this._chunkError():this._start+=this._config.chunkSize},this._chunkLoaded=function(){if(4==r.readyState){if(r.status<200||r.status>=400)return void this._chunkError();this._finished=!this._config.chunkSize||this._start>t(r),this.parseChunk(r.responseText)}},this._chunkError=function(e){var t=r.statusText||e;this._sendError(t)}}function s(e){e=e||{},e.chunkSize||(e.chunkSize=S.LocalChunkSize),n.call(this,e);var t,r,i="undefined"!=typeof FileReader;this.stream=function(e){this._input=e,r=e.slice||e.webkitSlice||e.mozSlice,i?(t=new FileReader,t.onload=g(this._chunkLoaded,this),t.onerror=g(this._chunkError,this)):t=new FileReaderSync,this._nextChunk()},this._nextChunk=function(){this._finished||this._config.preview&&!(this._rowCount=this._input.size,this.parseChunk(e.target.result)},this._chunkError=function(){this._sendError(t.error)}}function a(e){e=e||{},n.call(this,e);var t,r;this.stream=function(e){return t=e,r=e,this._nextChunk()},this._nextChunk=function(){if(!this._finished){var e=this._config.chunkSize,t=e?r.substr(0,e):r;return r=e?r.substr(e):"",this._finished=!r,this.parseChunk(t)}}}function o(e){function t(){if(b&&d&&(h("Delimiter","UndetectableDelimiter","Unable to auto-detect 
delimiting character; defaulted to '"+S.DefaultDelimiter+"'"),d=!1),e.skipEmptyLines)for(var t=0;t=y.length?(r.__parsed_extra||(r.__parsed_extra=[]),r.__parsed_extra.push(b.data[t][n])):r[y[n]]=b.data[t][n])}e.header&&(b.data[t]=r,n>y.length?h("FieldMismatch","TooManyFields","Too many fields: expected "+y.length+" fields but parsed "+n,t):n1&&(h+=Math.abs(l-i),i=l):i=l}c.data.length>0&&(f/=c.data.length),("undefined"==typeof n||n>h)&&f>1.99&&(n=h,r=o)}return e.delimiter=r,{successful:!!r,bestDelimiter:r}}function a(e){e=e.substr(0,1048576);var t=e.split("\r");if(1==t.length)return"\n";for(var r=0,n=0;n=t.length/2?"\r\n":"\r"}function o(e){var t=l.test(e);return t?parseFloat(e):e}function h(e,t,r,n){b.errors.push({type:e,code:t,message:r,row:n})}var f,c,d,l=/^\s*-?(\d*\.?\d+|\d+\.?\d*)(e[-+]?\d+)?\s*$/i,p=this,g=0,v=!1,k=!1,y=[],b={data:[],errors:[],meta:{}};if(m(e.step)){var R=e.step;e.step=function(n){if(b=n,r())t();else{if(t(),0==b.data.length)return;g+=n.data.length,e.preview&&g>e.preview?c.abort():R(b,p)}}}this.parse=function(r,n,i){if(e.newline||(e.newline=a(r)),d=!1,!e.delimiter){var o=s(r);o.successful?e.delimiter=o.bestDelimiter:(d=!0,e.delimiter=S.DefaultDelimiter),b.meta.delimiter=e.delimiter}var h=_(e);return e.preview&&e.header&&h.preview++,f=r,c=new u(h),b=c.parse(f,n,i),t(),v?{meta:{paused:!0}}:b||{meta:{paused:!1}}},this.paused=function(){return v},this.pause=function(){v=!0,c.abort(),f=f.substr(c.getCharIndex())},this.resume=function(){v=!1,p.streamer.parseChunk(f)},this.aborted=function(){return k},this.abort=function(){k=!0,c.abort(),b.meta.aborted=!0,m(e.complete)&&e.complete(b),f=""}}function u(e){e=e||{};var t=e.delimiter,r=e.newline,n=e.comments,i=e.step,s=e.preview,a=e.fastMode;if(("string"!=typeof t||S.BAD_DELIMITERS.indexOf(t)>-1)&&(t=","),n===t)throw"Comment character same as delimiter";n===!0?n="#":("string"!=typeof n||S.BAD_DELIMITERS.indexOf(n)>-1)&&(n=!1),"\n"!=r&&"\r"!=r&&"\r\n"!=r&&(r="\n");var o=0,u=!1;this.parse=function(e,h,f){function c(e){b.push(e),S=o}function d(t){return f?p():("undefined"==typeof t&&(t=e.substr(o)),w.push(t),o=g,c(w),y&&_(),p())}function l(t){o=t,c(w),w=[],O=e.indexOf(r,o)}function p(e){return{data:b,errors:R,meta:{delimiter:t,linebreak:r,aborted:u,truncated:!!e,cursor:S+(h||0)}}}function _(){i(p()),b=[],R=[]}if("string"!=typeof e)throw"Input must be a string";var g=e.length,m=t.length,v=r.length,k=n.length,y="function"==typeof i;o=0;var b=[],R=[],w=[],S=0;if(!e)return p();if(a||a!==!1&&-1===e.indexOf('"')){for(var C=e.split(r),E=0;E=s)return b=b.slice(0,s),p(!0)}}return p()}for(var x=e.indexOf(t,o),O=e.indexOf(r,o);;)if('"'!=e[o])if(n&&0===w.length&&e.substr(o,k)===n){if(-1==O)return p();o=O+v,O=e.indexOf(r,o),x=e.indexOf(t,o)}else if(-1!==x&&(O>x||-1===O))w.push(e.substring(o,x)),o=x+m,x=e.indexOf(t,o);else{if(-1===O)break;if(w.push(e.substring(o,O)),l(O+v),y&&(_(),u))return p();if(s&&b.length>=s)return p(!0)}else{var I=o;for(o++;;){var I=e.indexOf('"',I+1);if(-1===I)return f||R.push({type:"Quotes",code:"MissingQuotes",message:"Quoted field unterminated",row:b.length,index:o}),d();if(I===g-1){var D=e.substring(o,I).replace(/""/g,'"');return d(D)}if('"'!=e[I+1]){if(e[I+1]==t){w.push(e.substring(o,I).replace(/""/g,'"')),o=I+1+m,x=e.indexOf(t,o),O=e.indexOf(r,o);break}if(e.substr(I+1,v)===r){if(w.push(e.substring(o,I).replace(/""/g,'"')),l(I+1+v),x=e.indexOf(t,o),y&&(_(),u))return p();if(s&&b.length>=s)return p(!0);break}}else I++}}return d()},this.abort=function(){u=!0},this.getCharIndex=function(){return o}}function h(){var 
e=document.getElementsByTagName("script");return e.length?e[e.length-1].src:""}function f(){if(!S.WORKERS_SUPPORTED)return!1;if(!b&&null===S.SCRIPT_PATH)throw new Error("Script path cannot be determined automatically when Papa Parse is loaded asynchronously. You need to set Papa.SCRIPT_PATH manually.");var t=S.SCRIPT_PATH||v;t+=(-1!==t.indexOf("?")?"&":"?")+"papaworker";var r=new e.Worker(t);return r.onmessage=c,r.id=w++,R[r.id]=r,r}function c(e){var t=e.data,r=R[t.workerId],n=!1;if(t.error)r.userError(t.error,t.file);else if(t.results&&t.results.data){var i=function(){n=!0,d(t.workerId,{data:[],errors:[],meta:{aborted:!0}})},s={abort:i,pause:l,resume:l};if(m(r.userStep)){for(var a=0;a 1 or opt.multiprocessing_distributed 75 | 76 | if torch.cuda.is_available(): 77 | ngpus_per_node = torch.cuda.device_count() 78 | else: 79 | ngpus_per_node = 1 80 | if opt.multiprocessing_distributed: 81 | # Since we have ngpus_per_node processes per node, the total world_size 82 | # needs to be adjusted accordingly 83 | opt.world_size = ngpus_per_node * opt.world_size 84 | # Use torch.multiprocessing.spawn to launch distributed processes: the 85 | # main_worker process function 86 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt)) 87 | else: 88 | # Simply call main_worker function 89 | main_worker(opt.gpu_id, ngpus_per_node, opt) 90 | 91 | 92 | 93 | def main_worker(gpu_id, ngpus_per_node, opt): 94 | # rank, world_size = get_dist_info() 95 | opt.gpu_id = gpu_id 96 | 97 | if opt.gpu_id is not None: 98 | print("Use GPU: {}".format(opt.gpu_id)) 99 | 100 | if opt.distributed: 101 | if opt.dist_url == "env://" and opt.rank == -1: 102 | opt.rank = int(os.environ["RANK"]) 103 | if opt.multiprocessing_distributed: 104 | # For multiprocessing distributed training, rank needs to be the 105 | # global rank among all the processes 106 | opt.rank = opt.rank * ngpus_per_node + gpu_id 107 | dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url, 108 | world_size=opt.world_size, rank=opt.rank) 109 | 110 | 111 | opt.device = torch.device("cuda") 112 | torch.autograd.set_detect_anomaly(True) 113 | 114 | opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) 115 | opt.model_dir = pjoin(opt.save_root, 'model') 116 | opt.meta_dir = pjoin(opt.save_root, 'meta') 117 | 118 | # if opt.rank == 0: 119 | os.makedirs(opt.model_dir, exist_ok=True) 120 | os.makedirs(opt.meta_dir, exist_ok=True) 121 | if opt.world_size > 1: 122 | dist.barrier() 123 | 124 | if opt.dataset_name.lower() == 'beat': 125 | opt.data_root = 'data/BEAT' 126 | opt.fps = 15 127 | opt.net_dim_pose = 192 # body: [16, 34, 141], expression: [16, 34, 51], in_audio: [16, 36266] 128 | opt.split_pos = 141 129 | opt.dim_pose = 141 130 | if opt.remove_hand: 131 | opt.dim_pose = 33 132 | opt.expression_dim = 51 133 | 134 | if opt.expression_only or opt.gesCondition_expression_only: 135 | opt.net_dim_pose = opt.expression_dim # expression 136 | opt.e_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/weights/face_300.bin' 137 | elif opt.gesture_only or opt.expCondition_gesture_only != None or \ 138 | opt.textExpEmoCondition_gesture_only: 139 | opt.net_dim_pose = opt.dim_pose # gesture 140 | if opt.axis_angle: 141 | opt.e_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/weights/ges_axis_angle_300.bin' 142 | else: 143 | opt.e_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/weights/ae_300.bin' 144 | else: 145 | opt.net_dim_pose = opt.dim_pose + opt.expression_dim # gesture + expression 146 | if opt.axis_angle: 147 
| opt.e_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/weights/GesAxisAngle_Face_300.bin' 148 | else: 149 | raise NotImplementedError 150 | 151 | opt.audio_dim = 128 152 | if opt.use_aud_feat: 153 | opt.audio_dim = 1024 154 | opt.style_dim = 30 # totally 30 subjects 155 | opt.speaker_dim = 30 156 | opt.word_index_num = 5793 157 | opt.word_dims = 300 158 | opt.word_f = 128 159 | opt.emotion_f = 8 160 | opt.emotion_dims = 8 161 | opt.freeze_wordembed = False 162 | opt.hidden_size = 256 163 | opt.n_layer = 4 164 | 165 | if opt.n_poses == 150: 166 | opt.stride = 50 167 | elif opt.n_poses == 34: 168 | opt.stride = 10 169 | opt.pose_fps = 15 170 | opt.vae_length = 300 171 | opt.new_cache = False 172 | opt.audio_norm = False 173 | opt.facial_norm = True 174 | opt.pose_norm = True 175 | opt.train_data_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/train/' 176 | opt.val_data_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/val/' 177 | opt.test_data_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/test/' 178 | opt.mean_pose_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/train/' 179 | opt.std_pose_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/train/' 180 | opt.multi_length_training = [1.0] 181 | opt.audio_rep = 'wave16k' 182 | opt.facial_rep = 'facial52' 183 | opt.speaker_id = 'id' 184 | opt.pose_rep = 'bvh_rot' 185 | opt.word_rep = 'text' 186 | opt.sem_rep = 'sem' 187 | opt.emo_rep = 'emo' 188 | 189 | elif opt.dataset_name.lower() == 'talkshow': 190 | opt.talkshow_config = 'options/talkshow_configs/body_pixel.json' 191 | opt.speaker_dim = 4 192 | opt.fps = 30 193 | opt.dim_pose = 129 194 | opt.split_pos = 129 195 | if opt.remove_hand: 196 | opt.dim_pose = 39 197 | opt.expression_dim = 103 198 | if opt.ablation == "reverse_ges2exp": 199 | opt.expression_dim, opt.dim_pose = opt.dim_pose, opt.expression_dim 200 | if opt.expression_only or opt.gesCondition_expression_only: 201 | opt.net_dim_pose = opt.expression_dim # expression 202 | opt.e_path = f'data/SHOW/ae_weights/expression.pth.tar' 203 | elif opt.gesture_only or opt.expCondition_gesture_only != None: 204 | opt.net_dim_pose = opt.dim_pose # gesture 205 | opt.e_path = f'data/SHOW/ae_weights/gesture.pth.tar' 206 | else: 207 | opt.net_dim_pose = opt.dim_pose + opt.expression_dim # gesture + expression 208 | opt.e_path = f'data/SHOW/ae_weights/gesture_expression.pth.tar' 209 | 210 | if opt.audio_feat == 'mfcc': 211 | opt.audio_dim = 64 212 | elif opt.audio_feat == 'mel': 213 | opt.audio_dim = 128 214 | elif opt.audio_feat == 'raw': 215 | opt.audio_dim = 1 216 | elif opt.audio_feat == 'hubert': 217 | opt.audio_dim = 1024 218 | opt.style_dim = 4 219 | opt.speaker_dim = 4 220 | opt.n_poses = 88 221 | opt.pose_fps = 30 222 | opt.vae_length = 300 223 | 224 | else: 225 | raise KeyError('Dataset Does Not Exist') 226 | 227 | 228 | 229 | print("=> creating model '{}'".format(opt.model_base)) 230 | model = build_models(opt, opt.net_dim_pose, opt.audio_dim, opt.audio_latent_dim, opt.style_dim) 231 | 232 | if opt.no_fgd == False: 233 | eval_model = build_fgd_val_model(opt) 234 | else: 235 | eval_model = None 236 | 237 | if not torch.cuda.is_available() and not torch.backends.mps.is_available(): 238 | print('using CPU, this will be slow') 239 | elif opt.distributed: 240 | # For multiprocessing distributed, DistributedDataParallel constructor 241 | # should always set the single device scope, otherwise, 242 | # DistributedDataParallel will use all available devices. 
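# Editor's note (not part of the original source): a worked example of the per-process
# sizing performed just below, using assumed values. On one node with ngpus_per_node = 4
# and a global opt.batch_size of 32, each spawned process ends up with
#     opt.batch_size = int(32 / 4) = 8
# and the dataloader workers are split with a ceiling division, e.g. for opt.workers = 8:
#     opt.workers = int((8 + 4 - 1) / 4) = int(2.75) = 2
# so the four DistributedDataParallel replicas together still consume the requested 32 samples per step.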
243 | if torch.cuda.is_available(): 244 | if opt.gpu_id is not None: 245 | torch.cuda.set_device(opt.gpu_id) 246 | model.cuda(opt.gpu_id) 247 | # When using a single GPU per process and per 248 | # DistributedDataParallel, we need to divide the batch size 249 | # ourselves based on the total number of GPUs of the current node. 250 | opt.batch_size = int(opt.batch_size / ngpus_per_node) 251 | opt.workers = int((opt.workers + ngpus_per_node - 1) / ngpus_per_node) 252 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt.gpu_id], find_unused_parameters=False) 253 | 254 | if not opt.no_fgd: 255 | eval_model.cuda(opt.gpu_id) 256 | eval_model = torch.nn.parallel.DistributedDataParallel(eval_model, device_ids=[opt.gpu_id], find_unused_parameters=False) 257 | else: 258 | # DistributedDataParallel will divide and allocate batch_size to all 259 | # available GPUs if device_ids are not set 260 | # model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True) 261 | model = torch.nn.parallel.DistributedDataParallel(model) 262 | if not opt.no_fgd: 263 | # eval_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(eval_model) 264 | eval_model = torch.nn.parallel.DistributedDataParallel(eval_model, device_ids=[opt.rank], broadcast_buffers=True, find_unused_parameters=False).to(opt.rank) 265 | elif opt.gpu_id is not None and torch.cuda.is_available(): 266 | torch.cuda.set_device(opt.gpu_id) 267 | model = model.cuda(opt.gpu_id) 268 | if not opt.no_fgd: 269 | eval_model = eval_model.cuda(opt.gpu_id) 270 | elif torch.backends.mps.is_available(): 271 | device = torch.device("mps") 272 | model = model.to(device) 273 | if not opt.no_fgd: 274 | eval_model = eval_model.to(device) 275 | else: 276 | # Use single gpu 277 | model = model.cuda() 278 | if not opt.no_fgd: 279 | eval_model = eval_model.cuda() 280 | 281 | if torch.cuda.is_available(): 282 | if opt.gpu_id: 283 | device = torch.device('cuda:{}'.format(opt.gpu_id)) 284 | else: 285 | device = torch.device("cuda") 286 | elif torch.backends.mps.is_available(): 287 | device = torch.device("mps") 288 | else: 289 | device = torch.device("cpu") 290 | 291 | if opt.dataset_name == 'beat': 292 | runner = DDPMTrainer_beat(opt, model, eval_model=eval_model) 293 | elif opt.dataset_name == 'talkshow': 294 | runner = DDPMTrainer_show(opt, model, eval_model=eval_model) 295 | else: 296 | runner = DDPMTrainer(opt, model) 297 | 298 | if opt.mode == "train": 299 | if opt.dataset_name.lower() == 'beat': 300 | train_dataset = __import__(f"datasets.{opt.dataset_name}", fromlist=["something"]).BeatDataset(opt, "train") 301 | val_dataset = __import__(f"datasets.{opt.dataset_name}", fromlist=["something"]).BeatDataset(opt, "val") 302 | 303 | elif opt.dataset_name.lower() == 'talkshow': 304 | train_dataset = ShowDataset(opt, 'data/SHOW/cached_data/talkshow_train_cache') 305 | val_dataset = ShowDataset(opt, 'data/SHOW/cached_data/talkshow_val_cache') 306 | 307 | 308 | runner.train(train_dataset, val_dataset) 309 | 310 | elif "test" in opt.mode: 311 | if opt.dataset_name.lower() == 'beat': 312 | test_dataset = __import__(f"datasets.{opt.dataset_name}", fromlist=["something"]).BeatDataset(opt, "test") 313 | 314 | elif opt.dataset_name.lower() == 'talkshow': 315 | test_dataset = ShowDataset(opt, 'data/SHOW/cached_data/talkshow_test_cache') 316 | 317 | if opt.mode == "test": 318 | results_dir = runner.test(test_dataset) 319 | elif opt.mode == "test_arbitrary_len": 320 | opt.batch_size = 1 321 | results_dir = 
runner.test_arbitrary_len(test_dataset) 322 | elif opt.mode == "test_custom_audio": 323 | results_dir = runner.test_custom_aud(opt.test_audio_path, test_dataset) 324 | print(results_dir) 325 | 326 | 327 | 328 | 329 | if __name__ == '__main__': 330 | main() 331 | 332 | -------------------------------------------------------------------------------- /models/ddpm_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . import gaussian_diffusion as gd 5 | from .respace import SpacedDiffusion, space_timesteps 6 | # from .unet import SuperResModel, UNetModel, EncoderUNetModel 7 | 8 | NUM_CLASSES = 1000 9 | 10 | 11 | def diffusion_defaults(): 12 | """ 13 | Defaults for image and classifier training. 14 | """ 15 | return dict( 16 | learn_sigma=False, 17 | diffusion_steps=1000, 18 | noise_schedule="linear", 19 | timestep_respacing="", 20 | use_kl=False, 21 | predict_xstart=False, 22 | rescale_timesteps=False, 23 | rescale_learned_sigmas=False, 24 | ) 25 | 26 | 27 | def classifier_defaults(): 28 | """ 29 | Defaults for classifier models. 30 | """ 31 | return dict( 32 | image_size=64, 33 | classifier_use_fp16=False, 34 | classifier_width=128, 35 | classifier_depth=2, 36 | classifier_attention_resolutions="32,16,8", # 16 37 | classifier_use_scale_shift_norm=True, # False 38 | classifier_resblock_updown=True, # False 39 | classifier_pool="attention", 40 | ) 41 | 42 | 43 | def model_and_diffusion_defaults(): 44 | """ 45 | Defaults for image training. 46 | """ 47 | res = dict( 48 | image_size=64, 49 | num_channels=128, 50 | num_res_blocks=2, 51 | num_heads=4, 52 | num_heads_upsample=-1, 53 | num_head_channels=-1, 54 | attention_resolutions="16,8", 55 | channel_mult="", 56 | dropout=0.0, 57 | class_cond=False, 58 | use_checkpoint=False, 59 | use_scale_shift_norm=True, 60 | resblock_updown=False, 61 | use_fp16=False, 62 | use_new_attention_order=False, 63 | ) 64 | res.update(diffusion_defaults()) 65 | return res 66 | 67 | 68 | def classifier_and_diffusion_defaults(): 69 | res = classifier_defaults() 70 | res.update(diffusion_defaults()) 71 | return res 72 | 73 | 74 | def create_model_and_diffusion( 75 | image_size, 76 | class_cond, 77 | learn_sigma, 78 | num_channels, 79 | num_res_blocks, 80 | channel_mult, 81 | num_heads, 82 | num_head_channels, 83 | num_heads_upsample, 84 | attention_resolutions, 85 | dropout, 86 | diffusion_steps, 87 | noise_schedule, 88 | timestep_respacing, 89 | use_kl, 90 | predict_xstart, 91 | rescale_timesteps, 92 | rescale_learned_sigmas, 93 | use_checkpoint, 94 | use_scale_shift_norm, 95 | resblock_updown, 96 | use_fp16, 97 | use_new_attention_order, 98 | ): 99 | model = create_model( 100 | image_size, 101 | num_channels, 102 | num_res_blocks, 103 | channel_mult=channel_mult, 104 | learn_sigma=learn_sigma, 105 | class_cond=class_cond, 106 | use_checkpoint=use_checkpoint, 107 | attention_resolutions=attention_resolutions, 108 | num_heads=num_heads, 109 | num_head_channels=num_head_channels, 110 | num_heads_upsample=num_heads_upsample, 111 | use_scale_shift_norm=use_scale_shift_norm, 112 | dropout=dropout, 113 | resblock_updown=resblock_updown, 114 | use_fp16=use_fp16, 115 | use_new_attention_order=use_new_attention_order, 116 | ) 117 | diffusion = create_gaussian_diffusion( 118 | steps=diffusion_steps, 119 | learn_sigma=learn_sigma, 120 | noise_schedule=noise_schedule, 121 | use_kl=use_kl, 122 | predict_xstart=predict_xstart, 123 | rescale_timesteps=rescale_timesteps, 124 | 
rescale_learned_sigmas=rescale_learned_sigmas, 125 | timestep_respacing=timestep_respacing, 126 | ) 127 | return model, diffusion 128 | 129 | 130 | def create_model( 131 | image_size, 132 | num_channels, 133 | num_res_blocks, 134 | channel_mult="", 135 | learn_sigma=False, 136 | class_cond=False, 137 | use_checkpoint=False, 138 | attention_resolutions="16", 139 | num_heads=1, 140 | num_head_channels=-1, 141 | num_heads_upsample=-1, 142 | use_scale_shift_norm=False, 143 | dropout=0, 144 | resblock_updown=False, 145 | use_fp16=False, 146 | use_new_attention_order=False, 147 | ): 148 | if channel_mult == "": 149 | if image_size == 512: 150 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 151 | elif image_size == 256: 152 | channel_mult = (1, 1, 2, 2, 4, 4) 153 | elif image_size == 128: 154 | channel_mult = (1, 1, 2, 3, 4) 155 | elif image_size == 64: 156 | channel_mult = (1, 2, 3, 4) 157 | else: 158 | raise ValueError(f"unsupported image size: {image_size}") 159 | else: 160 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 161 | 162 | attention_ds = [] 163 | for res in attention_resolutions.split(","): 164 | attention_ds.append(image_size // int(res)) 165 | 166 | return UNetModel( 167 | image_size=image_size, 168 | in_channels=3, 169 | model_channels=num_channels, 170 | out_channels=(3 if not learn_sigma else 6), 171 | num_res_blocks=num_res_blocks, 172 | attention_resolutions=tuple(attention_ds), 173 | dropout=dropout, 174 | channel_mult=channel_mult, 175 | num_classes=(NUM_CLASSES if class_cond else None), 176 | use_checkpoint=use_checkpoint, 177 | use_fp16=use_fp16, 178 | num_heads=num_heads, 179 | num_head_channels=num_head_channels, 180 | num_heads_upsample=num_heads_upsample, 181 | use_scale_shift_norm=use_scale_shift_norm, 182 | resblock_updown=resblock_updown, 183 | use_new_attention_order=use_new_attention_order, 184 | ) 185 | 186 | 187 | def create_classifier_and_diffusion( 188 | image_size, 189 | classifier_use_fp16, 190 | classifier_width, 191 | classifier_depth, 192 | classifier_attention_resolutions, 193 | classifier_use_scale_shift_norm, 194 | classifier_resblock_updown, 195 | classifier_pool, 196 | learn_sigma, 197 | diffusion_steps, 198 | noise_schedule, 199 | timestep_respacing, 200 | use_kl, 201 | predict_xstart, 202 | rescale_timesteps, 203 | rescale_learned_sigmas, 204 | ): 205 | classifier = create_classifier( 206 | image_size, 207 | classifier_use_fp16, 208 | classifier_width, 209 | classifier_depth, 210 | classifier_attention_resolutions, 211 | classifier_use_scale_shift_norm, 212 | classifier_resblock_updown, 213 | classifier_pool, 214 | ) 215 | diffusion = create_gaussian_diffusion( 216 | steps=diffusion_steps, 217 | learn_sigma=learn_sigma, 218 | noise_schedule=noise_schedule, 219 | use_kl=use_kl, 220 | predict_xstart=predict_xstart, 221 | rescale_timesteps=rescale_timesteps, 222 | rescale_learned_sigmas=rescale_learned_sigmas, 223 | timestep_respacing=timestep_respacing, 224 | ) 225 | return classifier, diffusion 226 | 227 | 228 | def create_classifier( 229 | image_size, 230 | classifier_use_fp16, 231 | classifier_width, 232 | classifier_depth, 233 | classifier_attention_resolutions, 234 | classifier_use_scale_shift_norm, 235 | classifier_resblock_updown, 236 | classifier_pool, 237 | ): 238 | if image_size == 512: 239 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 240 | elif image_size == 256: 241 | channel_mult = (1, 1, 2, 2, 4, 4) 242 | elif image_size == 128: 243 | channel_mult = (1, 1, 2, 3, 4) 244 | elif image_size == 64: 245 | channel_mult = (1, 
2, 3, 4) 246 | else: 247 | raise ValueError(f"unsupported image size: {image_size}") 248 | 249 | attention_ds = [] 250 | for res in classifier_attention_resolutions.split(","): 251 | attention_ds.append(image_size // int(res)) 252 | 253 | return EncoderUNetModel( 254 | image_size=image_size, 255 | in_channels=3, 256 | model_channels=classifier_width, 257 | out_channels=1000, 258 | num_res_blocks=classifier_depth, 259 | attention_resolutions=tuple(attention_ds), 260 | channel_mult=channel_mult, 261 | use_fp16=classifier_use_fp16, 262 | num_head_channels=64, 263 | use_scale_shift_norm=classifier_use_scale_shift_norm, 264 | resblock_updown=classifier_resblock_updown, 265 | pool=classifier_pool, 266 | ) 267 | 268 | 269 | def sr_model_and_diffusion_defaults(): 270 | res = model_and_diffusion_defaults() 271 | res["large_size"] = 256 272 | res["small_size"] = 64 273 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 274 | for k in res.copy().keys(): 275 | if k not in arg_names: 276 | del res[k] 277 | return res 278 | 279 | 280 | def sr_create_model_and_diffusion( 281 | large_size, 282 | small_size, 283 | class_cond, 284 | learn_sigma, 285 | num_channels, 286 | num_res_blocks, 287 | num_heads, 288 | num_head_channels, 289 | num_heads_upsample, 290 | attention_resolutions, 291 | dropout, 292 | diffusion_steps, 293 | noise_schedule, 294 | timestep_respacing, 295 | use_kl, 296 | predict_xstart, 297 | rescale_timesteps, 298 | rescale_learned_sigmas, 299 | use_checkpoint, 300 | use_scale_shift_norm, 301 | resblock_updown, 302 | use_fp16, 303 | ): 304 | model = sr_create_model( 305 | large_size, 306 | small_size, 307 | num_channels, 308 | num_res_blocks, 309 | learn_sigma=learn_sigma, 310 | class_cond=class_cond, 311 | use_checkpoint=use_checkpoint, 312 | attention_resolutions=attention_resolutions, 313 | num_heads=num_heads, 314 | num_head_channels=num_head_channels, 315 | num_heads_upsample=num_heads_upsample, 316 | use_scale_shift_norm=use_scale_shift_norm, 317 | dropout=dropout, 318 | resblock_updown=resblock_updown, 319 | use_fp16=use_fp16, 320 | ) 321 | diffusion = create_gaussian_diffusion( 322 | steps=diffusion_steps, 323 | learn_sigma=learn_sigma, 324 | noise_schedule=noise_schedule, 325 | use_kl=use_kl, 326 | predict_xstart=predict_xstart, 327 | rescale_timesteps=rescale_timesteps, 328 | rescale_learned_sigmas=rescale_learned_sigmas, 329 | timestep_respacing=timestep_respacing, 330 | ) 331 | return model, diffusion 332 | 333 | 334 | def sr_create_model( 335 | large_size, 336 | small_size, 337 | num_channels, 338 | num_res_blocks, 339 | learn_sigma, 340 | class_cond, 341 | use_checkpoint, 342 | attention_resolutions, 343 | num_heads, 344 | num_head_channels, 345 | num_heads_upsample, 346 | use_scale_shift_norm, 347 | dropout, 348 | resblock_updown, 349 | use_fp16, 350 | ): 351 | _ = small_size # hack to prevent unused variable 352 | 353 | if large_size == 512: 354 | channel_mult = (1, 1, 2, 2, 4, 4) 355 | elif large_size == 256: 356 | channel_mult = (1, 1, 2, 2, 4, 4) 357 | elif large_size == 64: 358 | channel_mult = (1, 2, 3, 4) 359 | else: 360 | raise ValueError(f"unsupported large size: {large_size}") 361 | 362 | attention_ds = [] 363 | for res in attention_resolutions.split(","): 364 | attention_ds.append(large_size // int(res)) 365 | 366 | return SuperResModel( 367 | image_size=large_size, 368 | in_channels=3, 369 | model_channels=num_channels, 370 | out_channels=(3 if not learn_sigma else 6), 371 | num_res_blocks=num_res_blocks, 372 | 
attention_resolutions=tuple(attention_ds), 373 | dropout=dropout, 374 | channel_mult=channel_mult, 375 | num_classes=(NUM_CLASSES if class_cond else None), 376 | use_checkpoint=use_checkpoint, 377 | num_heads=num_heads, 378 | num_head_channels=num_head_channels, 379 | num_heads_upsample=num_heads_upsample, 380 | use_scale_shift_norm=use_scale_shift_norm, 381 | resblock_updown=resblock_updown, 382 | use_fp16=use_fp16, 383 | ) 384 | 385 | 386 | def create_gaussian_diffusion( 387 | *, 388 | steps=1000, 389 | learn_sigma=False, 390 | sigma_small=False, 391 | noise_schedule="linear", 392 | use_kl=False, 393 | predict_xstart=False, 394 | rescale_timesteps=False, 395 | rescale_learned_sigmas=False, 396 | timestep_respacing="", 397 | ): 398 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 399 | if use_kl: 400 | loss_type = gd.LossType.RESCALED_KL 401 | elif rescale_learned_sigmas: 402 | loss_type = gd.LossType.RESCALED_MSE 403 | else: 404 | loss_type = gd.LossType.MSE 405 | if not timestep_respacing: 406 | timestep_respacing = [steps] 407 | return SpacedDiffusion( 408 | use_timesteps=space_timesteps(steps, timestep_respacing), 409 | betas=betas, 410 | model_mean_type=( 411 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 412 | ), 413 | model_var_type=( 414 | ( 415 | gd.ModelVarType.FIXED_LARGE 416 | if not sigma_small 417 | else gd.ModelVarType.FIXED_SMALL 418 | ) 419 | if not learn_sigma 420 | else gd.ModelVarType.LEARNED_RANGE 421 | ), 422 | loss_type=loss_type, 423 | rescale_timesteps=rescale_timesteps, 424 | ) 425 | 426 | 427 | def add_dict_to_argparser(parser, default_dict): 428 | for k, v in default_dict.items(): 429 | v_type = type(v) 430 | if v is None: 431 | v_type = str 432 | elif isinstance(v, bool): 433 | v_type = str2bool 434 | parser.add_argument(f"--{k}", default=v, type=v_type) 435 | 436 | 437 | def args_to_dict(args, keys): 438 | return {k: getattr(args, k) for k in keys} 439 | 440 | 441 | def str2bool(v): 442 | """ 443 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 444 | """ 445 | if isinstance(v, bool): 446 | return v 447 | if v.lower() in ("yes", "true", "t", "y", "1"): 448 | return True 449 | elif v.lower() in ("no", "false", "f", "n", "0"): 450 | return False 451 | else: 452 | raise argparse.ArgumentTypeError("boolean value expected") -------------------------------------------------------------------------------- /utils/quaternion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | 11 | _EPS4 = np.finfo(float).eps * 4.0 12 | 13 | _FLOAT_EPS = np.finfo(np.float).eps 14 | 15 | # PyTorch-backed implementations 16 | def qinv(q): 17 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 18 | mask = torch.ones_like(q) 19 | mask[..., 1:] = -mask[..., 1:] 20 | return q * mask 21 | 22 | 23 | def qinv_np(q): 24 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 25 | return qinv(torch.from_numpy(q).float()).numpy() 26 | 27 | 28 | def qnormalize(q): 29 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 30 | return q / torch.norm(q, dim=-1, keepdim=True) 31 | 32 | 33 | def qmul(q, r): 34 | """ 35 | Multiply quaternion(s) q with quaternion(s) r. 
36 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. 37 | Returns q*r as a tensor of shape (*, 4). 38 | """ 39 | assert q.shape[-1] == 4 40 | assert r.shape[-1] == 4 41 | 42 | original_shape = q.shape 43 | 44 | # Compute outer product 45 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) 46 | 47 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] 48 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] 49 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] 50 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] 51 | return torch.stack((w, x, y, z), dim=1).view(original_shape) 52 | 53 | 54 | def qrot(q, v): 55 | """ 56 | Rotate vector(s) v about the rotation described by quaternion(s) q. 57 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, 58 | where * denotes any number of dimensions. 59 | Returns a tensor of shape (*, 3). 60 | """ 61 | assert q.shape[-1] == 4 62 | assert v.shape[-1] == 3 63 | assert q.shape[:-1] == v.shape[:-1] 64 | 65 | original_shape = list(v.shape) 66 | # print(q.shape) 67 | q = q.contiguous().view(-1, 4) 68 | v = v.contiguous().view(-1, 3) 69 | 70 | qvec = q[:, 1:] 71 | uv = torch.cross(qvec, v, dim=1) 72 | uuv = torch.cross(qvec, uv, dim=1) 73 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) 74 | 75 | 76 | def qeuler(q, order, epsilon=0, deg=True): 77 | """ 78 | Convert quaternion(s) q to Euler angles. 79 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions. 80 | Returns a tensor of shape (*, 3). 81 | """ 82 | assert q.shape[-1] == 4 83 | 84 | original_shape = list(q.shape) 85 | original_shape[-1] = 3 86 | q = q.view(-1, 4) 87 | 88 | q0 = q[:, 0] 89 | q1 = q[:, 1] 90 | q2 = q[:, 2] 91 | q3 = q[:, 3] 92 | 93 | if order == 'xyz': 94 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 95 | y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) 96 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 97 | elif order == 'yzx': 98 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 99 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 100 | z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) 101 | elif order == 'zxy': 102 | x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) 103 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 104 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) 105 | elif order == 'xzy': 106 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 107 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 108 | z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) 109 | elif order == 'yxz': 110 | x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) 111 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) 112 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 113 | elif order == 'zyx': 114 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 115 | y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) 116 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 117 | else: 118 | raise 119 | 120 | if deg: 121 | return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi 122 | 
else: 123 | return torch.stack((x, y, z), dim=1).view(original_shape) 124 | 125 | 126 | # Numpy-backed implementations 127 | 128 | def qmul_np(q, r): 129 | q = torch.from_numpy(q).contiguous().float() 130 | r = torch.from_numpy(r).contiguous().float() 131 | return qmul(q, r).numpy() 132 | 133 | 134 | def qrot_np(q, v): 135 | q = torch.from_numpy(q).contiguous().float() 136 | v = torch.from_numpy(v).contiguous().float() 137 | return qrot(q, v).numpy() 138 | 139 | 140 | def qeuler_np(q, order, epsilon=0, use_gpu=False): 141 | if use_gpu: 142 | q = torch.from_numpy(q).cuda().float() 143 | return qeuler(q, order, epsilon).cpu().numpy() 144 | else: 145 | q = torch.from_numpy(q).contiguous().float() 146 | return qeuler(q, order, epsilon).numpy() 147 | 148 | 149 | def qfix(q): 150 | """ 151 | Enforce quaternion continuity across the time dimension by selecting 152 | the representation (q or -q) with minimal distance (or, equivalently, maximal dot product) 153 | between two consecutive frames. 154 | 155 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. 156 | Returns a tensor of the same shape. 157 | """ 158 | assert len(q.shape) == 3 159 | assert q.shape[-1] == 4 160 | 161 | result = q.copy() 162 | dot_products = np.sum(q[1:] * q[:-1], axis=2) 163 | mask = dot_products < 0 164 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool) 165 | result[1:][mask] *= -1 166 | return result 167 | 168 | 169 | def euler2quat(e, order, deg=True): 170 | """ 171 | Convert Euler angles to quaternions. 172 | """ 173 | assert e.shape[-1] == 3 174 | 175 | original_shape = list(e.shape) 176 | original_shape[-1] = 4 177 | 178 | e = e.view(-1, 3) 179 | 180 | ## if euler angles in degrees 181 | if deg: 182 | e = e * np.pi / 180. 183 | 184 | x = e[:, 0] 185 | y = e[:, 1] 186 | z = e[:, 2] 187 | 188 | rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1) 189 | ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1) 190 | rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1) 191 | 192 | result = None 193 | for coord in order: 194 | if coord == 'x': 195 | r = rx 196 | elif coord == 'y': 197 | r = ry 198 | elif coord == 'z': 199 | r = rz 200 | else: 201 | raise 202 | if result is None: 203 | result = r 204 | else: 205 | result = qmul(result, r) 206 | 207 | # Reverse antipodal representation to have a non-negative "w" 208 | if order in ['xyz', 'yzx', 'zxy']: 209 | result *= -1 210 | 211 | return result.view(original_shape) 212 | 213 | 214 | def expmap_to_quaternion(e): 215 | """ 216 | Convert axis-angle rotations (aka exponential maps) to quaternions. 217 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". 218 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions. 219 | Returns a tensor of shape (*, 4). 220 | """ 221 | assert e.shape[-1] == 3 222 | 223 | original_shape = list(e.shape) 224 | original_shape[-1] = 4 225 | e = e.reshape(-1, 3) 226 | 227 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1) 228 | w = np.cos(0.5 * theta).reshape(-1, 1) 229 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e 230 | return np.concatenate((w, xyz), axis=1).reshape(original_shape) 231 | 232 | 233 | def euler_to_quaternion(e, order): 234 | """ 235 | Convert Euler angles to quaternions. 
236 | """ 237 | assert e.shape[-1] == 3 238 | 239 | original_shape = list(e.shape) 240 | original_shape[-1] = 4 241 | 242 | e = e.reshape(-1, 3) 243 | 244 | x = e[:, 0] 245 | y = e[:, 1] 246 | z = e[:, 2] 247 | 248 | rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1) 249 | ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1) 250 | rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1) 251 | 252 | result = None 253 | for coord in order: 254 | if coord == 'x': 255 | r = rx 256 | elif coord == 'y': 257 | r = ry 258 | elif coord == 'z': 259 | r = rz 260 | else: 261 | raise 262 | if result is None: 263 | result = r 264 | else: 265 | result = qmul_np(result, r) 266 | 267 | # Reverse antipodal representation to have a non-negative "w" 268 | if order in ['xyz', 'yzx', 'zxy']: 269 | result *= -1 270 | 271 | return result.reshape(original_shape) 272 | 273 | 274 | def quaternion_to_matrix(quaternions): 275 | """ 276 | Convert rotations given as quaternions to rotation matrices. 277 | Args: 278 | quaternions: quaternions with real part first, 279 | as tensor of shape (..., 4). 280 | Returns: 281 | Rotation matrices as tensor of shape (..., 3, 3). 282 | """ 283 | r, i, j, k = torch.unbind(quaternions, -1) 284 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 285 | 286 | o = torch.stack( 287 | ( 288 | 1 - two_s * (j * j + k * k), 289 | two_s * (i * j - k * r), 290 | two_s * (i * k + j * r), 291 | two_s * (i * j + k * r), 292 | 1 - two_s * (i * i + k * k), 293 | two_s * (j * k - i * r), 294 | two_s * (i * k - j * r), 295 | two_s * (j * k + i * r), 296 | 1 - two_s * (i * i + j * j), 297 | ), 298 | -1, 299 | ) 300 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 301 | 302 | 303 | def quaternion_to_matrix_np(quaternions): 304 | q = torch.from_numpy(quaternions).contiguous().float() 305 | return quaternion_to_matrix(q).numpy() 306 | 307 | 308 | def quaternion_to_cont6d_np(quaternions): 309 | rotation_mat = quaternion_to_matrix_np(quaternions) 310 | cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1) 311 | return cont_6d 312 | 313 | 314 | def quaternion_to_cont6d(quaternions): 315 | rotation_mat = quaternion_to_matrix(quaternions) 316 | cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1) 317 | return cont_6d 318 | 319 | 320 | def cont6d_to_matrix(cont6d): 321 | assert cont6d.shape[-1] == 6, "The last dimension must be 6" 322 | x_raw = cont6d[..., 0:3] 323 | y_raw = cont6d[..., 3:6] 324 | 325 | x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True) 326 | z = torch.cross(x, y_raw, dim=-1) 327 | z = z / torch.norm(z, dim=-1, keepdim=True) 328 | 329 | y = torch.cross(z, x, dim=-1) 330 | 331 | x = x[..., None] 332 | y = y[..., None] 333 | z = z[..., None] 334 | 335 | mat = torch.cat([x, y, z], dim=-1) 336 | return mat 337 | 338 | 339 | def cont6d_to_matrix_np(cont6d): 340 | q = torch.from_numpy(cont6d).contiguous().float() 341 | return cont6d_to_matrix(q).numpy() 342 | 343 | 344 | def qpow(q0, t, dtype=torch.float): 345 | ''' q0 : tensor of quaternions 346 | t: tensor of powers 347 | ''' 348 | q0 = qnormalize(q0) 349 | theta0 = torch.acos(q0[..., 0]) 350 | 351 | ## if theta0 is close to zero, add epsilon to avoid NaNs 352 | mask = (theta0 <= 10e-10) * (theta0 >= -10e-10) 353 | theta0 = (1 - mask) * theta0 + mask * 10e-10 354 | v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1) 355 | 356 | if isinstance(t, torch.Tensor): 357 | q = torch.zeros(t.shape + 
q0.shape) 358 | theta = t.view(-1, 1) * theta0.view(1, -1) 359 | else: ## if t is a number 360 | q = torch.zeros(q0.shape) 361 | theta = t * theta0 362 | 363 | q[..., 0] = torch.cos(theta) 364 | q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1) 365 | 366 | return q.to(dtype) 367 | 368 | 369 | def qslerp(q0, q1, t): 370 | ''' 371 | q0: starting quaternion 372 | q1: ending quaternion 373 | t: array of points along the way 374 | 375 | Returns: 376 | Tensor of Slerps: t.shape + q0.shape 377 | ''' 378 | 379 | q0 = qnormalize(q0) 380 | q1 = qnormalize(q1) 381 | q_ = qpow(qmul(q1, qinv(q0)), t) 382 | 383 | return qmul(q_, 384 | q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous()) 385 | 386 | 387 | def qbetween(v0, v1): 388 | ''' 389 | find the quaternion used to rotate v0 to v1 390 | ''' 391 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' 392 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' 393 | 394 | v = torch.cross(v0, v1) 395 | w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1, 396 | keepdim=True) 397 | return qnormalize(torch.cat([w, v], dim=-1)) 398 | 399 | 400 | def qbetween_np(v0, v1): 401 | ''' 402 | find the quaternion used to rotate v0 to v1 403 | ''' 404 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' 405 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' 406 | 407 | v0 = torch.from_numpy(v0).float() 408 | v1 = torch.from_numpy(v1).float() 409 | return qbetween(v0, v1).numpy() 410 | 411 | 412 | def lerp(p0, p1, t): 413 | if not isinstance(t, torch.Tensor): 414 | t = torch.Tensor([t]) 415 | 416 | new_shape = t.shape + p0.shape 417 | new_view_t = t.shape + torch.Size([1] * len(p0.shape)) 418 | new_view_p = torch.Size([1] * len(t.shape)) + p0.shape 419 | p0 = p0.view(new_view_p).expand(new_shape) 420 | p1 = p1.view(new_view_p).expand(new_shape) 421 | t = t.view(new_view_t).expand(new_shape) 422 | 423 | return p0 + t * (p1 - p0) 424 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from mmcv.runner.dist_utils import get_dist_info 5 | import torch.distributed as dist 6 | 7 | 8 | class BaseOptions(): 9 | def __init__(self): 10 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 11 | self.initialized = False 12 | 13 | def initialize(self): 14 | self.parser.add_argument('--name', type=str, default="test", help='Name of this trial') 15 | self.parser.add_argument('--decomp_name', type=str, default="Decomp_SP001_SM001_H512", help='Name of autoencoder model') 16 | self.parser.add_argument('--model_base', type=str, default="transformer_encoder", choices=["transformer_decoder", "transformer_encoder", "st_unet"], help='Model architecture') 17 | self.parser.add_argument('--model_mean_type', type=str, default="epsilon", choices=["epsilon", "start_x", "previous_x"], help='Choose which type of data the model ouputs') 18 | self.parser.add_argument('--PE', type=str, default='pe_sinu', choices=['learnable', 'ppe_sinu', 'pe_sinu', 'pe_sinu_repeat', 'ppe_sinu_dropout'], help='Choose the type of positional emb') 19 | self.parser.add_argument("--ddim", action="store_true", help='Use ddim sampling') 20 | self.parser.add_argument("--timestep_respacing", type=str, default='ddim1000', help="Set ddim steps 'ddim{STEP}'") 
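# Editor's note (not part of the original source): a minimal sketch of how a
# --timestep_respacing value such as 'ddim25' is consumed. create_gaussian_diffusion()
# in models/ddpm_utils.py (shown earlier in this listing) forwards the string to
# space_timesteps(), which keeps a reduced, evenly spaced subset of the full schedule
# for fast DDIM sampling. The concrete values below are illustrative assumptions.
from models.ddpm_utils import create_gaussian_diffusion

diffusion = create_gaussian_diffusion(
    steps=1000,                   # full training schedule
    noise_schedule="linear",
    timestep_respacing="ddim25",  # sample with only 25 respaced steps
)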
21 |         self.parser.add_argument("--cond_projection", type=str, default='mlp_includeX', choices=["linear_includeX", "mlp_includeX", "none", "linear_excludeX", "mlp_excludeX"], help="condition projection choices")
22 |         self.parser.add_argument("--cond_residual", type=bool, default=True, help='Whether to use residual during condition projection')
23 | 
24 | 
25 |         self.parser.add_argument("--gpu_id", type=int, default=None, help='GPU id')
26 |         self.parser.add_argument("--distributed", action="store_true", help='Whether to use DDP training')
27 |         self.parser.add_argument("--data_parallel", action="store_true", help='Whether to use DP training')
28 |         self.parser.add_argument("--max_eval_samples", type=int, default=-1, help='max_eval_samples')
29 |         self.parser.add_argument("--n_poses", type=int, help='number of poses for a training sequence')
30 |         self.parser.add_argument("--axis_angle", type=bool, default=True, help='whether to use the axis_angle rot representation')
31 |         self.parser.add_argument("--rename", default=None, help='rename the experiment name during test')
32 | 
33 |         self.parser.add_argument("--debug", action="store_true", help='debug mode, only run one iteration')
34 |         self.parser.add_argument('--mode', type=str, default='train', choices=["train", "val", "test", "test_arbitrary_len", "test_custom_audio"], help='train, val or test')
35 | 
36 |         self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name')
37 |         self.parser.add_argument('--data_mode', type=str, default='original', choices=['original', 'add_init_state'], help='Data modes')
38 |         self.parser.add_argument('--data_type', type=str, default='pos', choices=['pos', 'vel', 'pos_vel'], help='Data types')
39 |         self.parser.add_argument('--data_sel', type=str, default='upperbody', choices=['upperbody', 'all', 'upperbody_head', 'upperbody_hands'], help='Data selection')
40 |         self.parser.add_argument('--data_root', type=str, default='./Freeform/processed_data_200', help='Dataset path')
41 |         self.parser.add_argument('--beat_cache_name', default='beat_4english_15_141', help='Beat cache name')
42 |         self.parser.add_argument('--use_aud_feat', type=str, default=None, choices=["interpolate", "conv"], help='Audio feature path')
43 |         self.parser.add_argument('--audio_feat', type=str, default='mel', choices=["mel", "mfcc", "raw", "hubert", 'wav2vec2'], help='Audio feature type')
44 |         self.parser.add_argument('--test_audio_path', type=str, default=None, help='test audio file or directory path')
45 |         self.parser.add_argument('--same_overlap_noisy', action="store_true", help='During the outpainting process, use the same overlapping noisy GT')
46 |         self.parser.add_argument('--no_repaint', action="store_true", help='Do not perform repaint during long-form generation')
47 | 
48 | 
49 |         self.parser.add_argument('--vel_interval', type=int, default=10, help='Interval to compute the velocity')
50 |         self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
51 |         self.parser.add_argument('--overlap_len', type=int, default=0, help='Fix the initial N frames for this clip')
52 |         self.parser.add_argument('--addBlend', type=bool, default=True, help='Blend in the overlapping region at the last two denoise steps')
53 |         self.parser.add_argument('--fix_very_first', action='store_true', help='Fix the very first {overlap_len} frames for this video to be the same as GT')
54 |         self.parser.add_argument('--remove_audio', action='store_true', help='set audio to 0')
55 |         self.parser.add_argument('--remove_style', action='store_true', help='set style to 0')
56 |         self.parser.add_argument('--remove_hand', action='store_true', help='remove hand rotations from motion data')
57 |         self.parser.add_argument('--no_fgd', action='store_true', help='do not compute fgd')
58 |         self.parser.add_argument('--ablation', type=str, default=None, choices=["no_x0", "no_detach", "reverse_ges2exp"], help='ablation options')
59 |         self.parser.add_argument('--rebuttal', type=str, default=None, choices=["noMelSpec", "noHuBert", "noMidAud"], help='rebuttal ablation options')
60 |         self.parser.add_argument('--visualize_unify_x0_step', type=int, default=None, help='visualize expression x0 in unified mode every N steps')
61 | 
62 | 
63 |         self.parser.add_argument('--audio_dim', type=int, default=128, help='Input Audio feature dimension.')
64 |         self.parser.add_argument('--audio_latent_dim', type=int, default=256, help='Audio latent dimension.')
65 |         self.parser.add_argument('--style_dim', type=int, default=4, help='Input Style vector dimension. Can be one hot.')
66 |         self.parser.add_argument('--dim_text_hidden', type=int, default=512, help='Dimension of hidden unit in text encoder')
67 |         self.parser.add_argument('--dim_att_vec', type=int, default=512, help='Dimension of attention vector')
68 |         self.parser.add_argument('--dim_z', type=int, default=128, help='Dimension of latent Gaussian vector')
69 | 
70 |         self.parser.add_argument('--n_layers_pri', type=int, default=1, help='Number of layers in prior network')
71 |         self.parser.add_argument('--n_layers_pos', type=int, default=1, help='Number of layers in posterior network')
72 |         self.parser.add_argument('--n_layers_dec', type=int, default=1, help='Number of layers in generator')
73 | 
74 |         self.parser.add_argument('--dim_pri_hidden', type=int, default=1024, help='Dimension of hidden unit in prior network')
75 |         self.parser.add_argument('--dim_pos_hidden', type=int, default=1024, help='Dimension of hidden unit in posterior network')
76 |         self.parser.add_argument('--dim_dec_hidden', type=int, default=1024, help='Dimension of hidden unit in generator')
77 | 
78 |         self.parser.add_argument('--dim_movement_enc_hidden', type=int, default=512,
79 |                                  help='Dimension of hidden in AutoEncoder(encoder)')
80 |         self.parser.add_argument('--dim_movement_dec_hidden', type=int, default=512,
81 |                                  help='Dimension of hidden in AutoEncoder(decoder)')
82 |         self.parser.add_argument('--dim_movement_latent', type=int, default=512, help='Dimension of motion snippet')
83 | 
84 |         self.parser.add_argument('--embed_net_path', type=str, default="feature_extractor/gesture_autoencoder_checkpoint_best.bin", help='embed_net_path')
85 | 
86 |         self.parser.add_argument('--fix_head_var', action="store_true", help='Make expression prediction deterministic')
87 |         self.parser.add_argument('--expression_only', action="store_true", help='train expression only')
88 |         self.parser.add_argument('--gesture_only', action="store_true", help='train gesture only')
89 |         self.parser.add_argument('--expCondition_gesture_only', type=str, choices=['gt', 'pred'], default=None, help='train gesture only, with expressions as condition')
90 |         self.parser.add_argument('--gesCondition_expression_only', action="store_true", help='train expression only, with gesture as condition')
91 |         self.parser.add_argument('--textExpEmoCondition_gesture_only', action="store_true", help='use all conditions: audio, text, emo, pid, facial')
92 |         self.parser.add_argument('--addTextCond', action="store_true", help='add Text feature to audio feature')
93 |         self.parser.add_argument('--addEmoCond', action="store_true", help='add Emo feature to audio feature')
94 |         self.parser.add_argument('--expAddHubert', action="store_true", help='concat Hubert feature to encoded audio feature only for expression generation')
95 |         self.parser.add_argument('--addHubert', type=bool, default=True, help='concat Hubert feature to encoded audio feature for both expression and gesture generation')
96 |         self.parser.add_argument('--addWav2Vec2', action="store_true", help='concat Wav2Vec2 feature to encoded audio feature for both expression and gesture generation')
97 |         self.parser.add_argument('--encode_wav2vec2', action="store_true", help='encode the wav2vec2 feature')
98 |         self.parser.add_argument('--encode_hubert', type=bool, default=True, help='encode the hubert feature')
99 |         self.parser.add_argument('--separate', type=str, choices=['v1', 'v2'], default=None, help='limit information exchange between expression and gestures; v1 shares one encoder, v2 uses two independent encoders')
100 |         self.parser.add_argument('--usePredExpr', type=str, default=None, help='Path to the predicted expressions.')
101 |         self.parser.add_argument('--unidiffuser', type=bool, default=True, help='Use the unified framework for joint expression and gesture generation')
102 | 
103 |         self.parser.add_argument('--separate_pure', action="store_true", help='use two fully independent encoders')
104 | 
105 |         # classifier-free guidance
106 |         self.parser.add_argument('--classifier_free', action="store_true", help='Use classifier-free guidance')
107 |         self.parser.add_argument('--null_cond_prob', type=float, default=0.2, help='Probability of null condition during classifier-free training')
108 |         self.parser.add_argument('--cond_scale', type=float, default=1.0, help='Scale of the condition in classifier-free guidance sampling')
109 | 
110 |         # Try Expression ID off
111 |         self.parser.add_argument('--ExprID_off', action="store_true", help='Turn off the expression ID condition')
112 |         self.parser.add_argument('--ExprID_off_uncond', action="store_true", help='Turn off the expression ID condition under the classifier-free uncondition part of training')
113 | 
114 | 
115 |         self.parser.add_argument('--use_joints', action="store_true", help='Whether to convert to joints when using the TED 3D dataset')
116 |         self.parser.add_argument('--use_single_style', action="store_true", help='Whether to use single style')
117 |         self.parser.add_argument('--test_on_trainset', action="store_true", help='Whether to test on training set')
118 |         self.parser.add_argument('--test_on_val', action="store_true", help='Whether to test on validation set')
119 |         self.parser.add_argument('--output_gt', action="store_true", help='Directly output GT during test')
120 |         self.parser.add_argument('--no_style', action="store_true", help='Do not use style vectors')
121 |         self.parser.add_argument('--no_resample', action="store_true", help='Do not use resampling during inpainting-based sampling')
122 |         self.parser.add_argument('--add_vel_loss', type=bool, default=True, help='Add velocity loss')
123 |         self.parser.add_argument('--vel_loss_start', type=int, default=-1, help='velocity loss and huber loss start epoch')
124 |         self.parser.add_argument('--expr_weight', type=int, default=1, help='expression weight')
125 | 
126 |         # inference
127 |         self.parser.add_argument('--jump_n_sample', type=int, default=5, help='hyperparameter for resampling')
128 |         self.parser.add_argument('--jump_length', type=int, default=3, help='hyperparameter for resampling')
129 | 
130 | 
131 |         ## Distributed
132 |         self.parser.add_argument('--world-size', default=1, type=int,
133 |                                  help='number of nodes for distributed training')
134 |         self.parser.add_argument('--rank', default=0, type=int,
135 |                                  help='node rank for distributed training')
136 |         self.parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
137 |                                  help='url used to set up distributed training')
138 |         self.parser.add_argument('--dist-backend', default='nccl', type=str,
139 |                                  help='distributed backend')
140 |         self.parser.add_argument('--multiprocessing-distributed', action='store_true',
141 |                                  help='Use multi-processing distributed training to launch '
142 |                                       'N processes per node, which has N GPUs. This is the '
143 |                                       'fastest way to use PyTorch for either single node or '
144 |                                       'multi node data parallel training')
145 |         self.parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
146 |                                  help='number of data loading workers (default: 0)')
147 | 
148 |         self.initialized = True
149 | 
150 | 
151 |     def parse(self):
152 |         if not self.initialized:
153 |             self.initialize()
154 | 
155 |         self.opt = self.parser.parse_args()
156 | 
157 |         self.opt.is_train = self.is_train
158 | 
159 |         args = vars(self.opt)
160 | 
161 |         if self.opt.rank == 0:
162 |             print('------------ Options -------------')
163 |             for k, v in sorted(args.items()):
164 |                 print('%s: %s' % (str(k), str(v)))
165 |             print('-------------- End ----------------')
166 |             if self.is_train:
167 |                 # save to the disk
168 |                 expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.dataset_name, self.opt.name)
169 |                 if not os.path.exists(expr_dir):
170 |                     os.makedirs(expr_dir)
171 |                 file_name = os.path.join(expr_dir, 'opt.txt')
172 |                 with open(file_name, 'wt') as opt_file:
173 |                     opt_file.write('------------ Options -------------\n')
174 |                     for k, v in sorted(args.items()):
175 |                         opt_file.write('%s: %s\n' % (str(k), str(v)))
176 |                     opt_file.write('-------------- End ----------------\n')
177 |         if self.opt.world_size > 1:
178 |             dist.barrier()
179 |         return self.opt
180 | 
--------------------------------------------------------------------------------
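
For orientation, below is a minimal usage sketch of BaseOptions. It is illustrative only and not code from this repository: DemoOptions is a hypothetical subclass (the repository's real subclasses live in options/train_options.py and options/evaluate_options.py and may differ), and the sketch assumes the project root is on PYTHONPATH with the repo's dependencies (e.g. mmcv-full) installed. The one contract visible in parse() is that a subclass must set self.is_train before parsing.

import sys  # only to show that CLI flags come from sys.argv

from options.base_options import BaseOptions


class DemoOptions(BaseOptions):
    """Hypothetical subclass, for illustration only."""

    def initialize(self):
        BaseOptions.initialize(self)
        # parse() copies self.is_train onto the parsed namespace,
        # so every subclass is expected to define it.
        self.is_train = False


if __name__ == '__main__':
    # e.g. python demo_options.py --dataset_name beat --n_poses 34 --ddim --timestep_respacing ddim25
    opt = DemoOptions().parse()
    print(opt.dataset_name, opt.n_poses, opt.timestep_respacing)
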
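As a quick sanity check on the quaternion utilities shown earlier in utils/quaternion.py, the following sketch (again illustrative, assuming the module is importable as utils.quaternion) rotates the x-axis onto the y-axis with qbetween and interpolates a 3D position with lerp.

import torch

from utils.quaternion import qbetween, lerp

# Unit quaternion (w, x, y, z) rotating the x-axis onto the y-axis:
# a 90-degree rotation about z, roughly (0.7071, 0, 0, 0.7071).
x_axis = torch.tensor([1.0, 0.0, 0.0])
y_axis = torch.tensor([0.0, 1.0, 0.0])
q = qbetween(x_axis, y_axis)

# lerp broadcasts t over the endpoints; the result has shape t.shape + p0.shape.
t = torch.linspace(0.0, 1.0, steps=5)
path = lerp(torch.zeros(3), torch.tensor([1.0, 2.0, 3.0]), t)
print(q, path.shape)  # ..., torch.Size([5, 3])
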