├── evaluation ├── __init__.py ├── diversity_LVD.py ├── get_quality_samples.py ├── mode_transition.py ├── peak_velocity.py ├── metrics.py ├── util.py └── FGD.py ├── scripts ├── __init__.py ├── .idea │ ├── __init__.py │ ├── lower body │ ├── test.png │ ├── vcs.xml │ ├── inspectionProfiles │ │ ├── profiles_settings.xml │ │ └── Project_Default.xml │ ├── modules.xml │ ├── scripts.iml │ ├── aws.xml │ ├── testtext.py │ ├── deployment.xml │ ├── workspace.xml │ └── get_prevar.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── diversity.cpython-37.pyc ├── train.py ├── test_vq.py ├── test_face.py ├── continuity.py └── test_body.py ├── visualise ├── __init__.py ├── .DS_Store └── image.png ├── losses ├── __init__.py └── losses.py ├── trainer ├── __init__.py ├── config.py ├── training_config.cfg └── options.py ├── data_utils ├── split_more_than_2s.pkl ├── __init__.py ├── axis2matrix.py ├── get_j.py ├── apply_split.py ├── test.py ├── dataset_preprocess.py ├── lower_body.py └── dataloader_torch.py ├── voca ├── __pycache__ │ ├── rendering.cpython-37.pyc │ ├── rendering.cpython-38.pyc │ └── rendering.cpython-39.pyc └── rendering.py ├── train_face.sh ├── train_body_vq.sh ├── requirements.txt ├── train_body_pixel.sh ├── test_body.sh ├── test_face.sh ├── visualise.sh ├── nets ├── __init__.py ├── init_model.py ├── base.py ├── body_ae.py ├── utils.py ├── spg │ ├── wav2vec.py │ ├── qformer.py │ ├── gated_pixelcnn_v2.py │ ├── t2m_trans.py │ ├── s2g_face.py │ ├── vqvae_1d.py │ └── blip.py └── smplx_face.py ├── config ├── bert_config.json ├── face.json ├── LS3DCG.json ├── body_vq.json └── body_pixel.json └── README.md /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /visualise/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/.idea/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import * -------------------------------------------------------------------------------- /scripts/.idea/lower body: -------------------------------------------------------------------------------- 1 | 0, 1, 3, 4, 6, 7, 9, 10, -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .Trainer import Trainer 2 | from .Trainer_vq import Trainer_vq -------------------------------------------------------------------------------- /visualise/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/visualise/.DS_Store -------------------------------------------------------------------------------- /visualise/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/visualise/image.png 
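losses/__init__.py above re-exports everything defined in losses/losses.py (that file appears later in this listing), so the loss modules are importable straight from the package. A minimal usage sketch, assuming PyTorch is installed; the (B, C, T) shape with pose_dim=99 and generate_length=88 is illustrative, taken from the config/*.json files further below:

import torch
from losses import KeypointLoss  # re-exported by losses/__init__.py via `from .losses import *`

pred_seq = torch.randn(2, 99, 88)  # hypothetical (B, C, T) batch
gt_seq = torch.randn(2, 99, 88)
loss = KeypointLoss()(pred_seq, gt_seq)  # reduces to plain MSE when no confidence mask is passed
print(loss.item())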
-------------------------------------------------------------------------------- /scripts/.idea/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/scripts/.idea/test.png -------------------------------------------------------------------------------- /data_utils/split_more_than_2s.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/data_utils/split_more_than_2s.pkl -------------------------------------------------------------------------------- /voca/__pycache__/rendering.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/voca/__pycache__/rendering.cpython-37.pyc -------------------------------------------------------------------------------- /voca/__pycache__/rendering.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/voca/__pycache__/rendering.cpython-38.pyc -------------------------------------------------------------------------------- /voca/__pycache__/rendering.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/voca/__pycache__/rendering.cpython-39.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/scripts/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/diversity.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/scripts/__pycache__/diversity.cpython-37.pyc -------------------------------------------------------------------------------- /train_face.sh: -------------------------------------------------------------------------------- 1 | python -W ignore scripts/train.py \ 2 | --save_dir experiments \ 3 | --exp_name smplx_S2G \ 4 | --speakers oliver seth conan chemistry \ 5 | --config_file ./config/face.json -------------------------------------------------------------------------------- /train_body_vq.sh: -------------------------------------------------------------------------------- 1 | python -W ignore scripts/train.py \ 2 | --save_dir experiments \ 3 | --exp_name smplx_S2G \ 4 | --speakers oliver seth conan chemistry \ 5 | --config_file ./config/body_vq.json -------------------------------------------------------------------------------- /scripts/.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | transformers 3 | matplotlib 4 | textgrid 5 | smplx 6 | scikit-learn 7 | pyrender 8 | trimesh 9 | tqdm 10 | librosa 11 | scipy 12 | python_speech_features 13 | opencv-python 14 | pyglet 15 | encodec 16 | -------------------------------------------------------------------------------- /scripts/.idea/inspectionProfiles/profiles_settings.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /train_body_pixel.sh: -------------------------------------------------------------------------------- 1 | python -W ignore scripts/train.py \ 2 | --save_dir your/save/path \ 3 | --exp_name smplx_S2G \ 4 | --speakers oliver seth conan chemistry \ 5 | --config_file ./config/body_pixel.json \ 6 | --bert_config ./config/bert_config.json -------------------------------------------------------------------------------- /data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # from .dataloader_csv import MultiVidData as csv_data 2 | from .dataloader_torch import MultiVidData as torch_data 3 | from .utils import get_melspec, get_mfcc, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta,get_encodec,get_encodec_token -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | 6 | os.environ['CUDA_VISIBLE_DEVICES'] = '2' 7 | sys.path.append(os.getcwd()) 8 | from trainer import Trainer,Trainer_vq 9 | 10 | if __name__ == '__main__': 11 | 12 | trainer = Trainer() 13 | trainer.train() 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /test_body.sh: -------------------------------------------------------------------------------- 1 | python -W ignore scripts/test_body.py \ 2 | --save_dir experiments \ 3 | --exp_name smplx_S2G \ 4 | --speakers oliver seth conan chemistry \ 5 | --config_file ./config/body_pixel.json \ 6 | --body_model_name s2g_body_pixel \ 7 | --body_model_path your/train/model/path \ 8 | --infer 9 | 10 | -------------------------------------------------------------------------------- /test_face.sh: -------------------------------------------------------------------------------- 1 | python -W ignore scripts/test_face.py \ 2 | --save_dir experiments \ 3 | --exp_name smplx_S2G \ 4 | --speakers oliver seth conan chemistry \ 5 | --config_file ./config/face.json \ 6 | --face_model_name s2g_face \ 7 | --face_model_path ./experiments/2022-10-15-smplx_S2G-face-3d/ckpt-99.pth \ 8 | --infer -------------------------------------------------------------------------------- /scripts/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /visualise.sh: -------------------------------------------------------------------------------- 1 | python -W ignore scripts/diversity.py \ 2 | --save_dir experiments \ 3 | --exp_name smplx_S2G \ 4 | --speakers oliver seth conan chemistry \ 5 | --config_file ./config/body_pixel.json \ 6 | --face_model_path ./experiments/2022-10-15-smplx_S2G-face-3d/ckpt-99.pth \ 7 | --body_model_path your/body/model/path \ 8 | --infer -------------------------------------------------------------------------------- /scripts/.idea/scripts.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /scripts/.idea/aws.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 10 | 11 | 
-------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- 1 | from .smplx_face import TrainWrapper as s2g_face 2 | from .smplx_body_vq import TrainWrapper as s2g_body_vq 3 | from .smplx_body_pixel import TrainWrapper as s2g_body_pixel 4 | from .body_ae import TrainWrapper as s2g_body_ae 5 | from .LS3DCG import TrainWrapper as LS3DCG 6 | from .base import TrainWrapperBaseClass 7 | 8 | from .utils import normalize, denormalize -------------------------------------------------------------------------------- /scripts/.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 12 | -------------------------------------------------------------------------------- /config/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 2048, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 8, 15 | "num_hidden_layers": 6, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /trainer/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | load config from json file 3 | ''' 4 | import json 5 | import os 6 | 7 | import configparser 8 | 9 | 10 | class Object(): 11 | def __init__(self, config:dict) -> None: 12 | for key in list(config.keys()): 13 | if isinstance(config[key], dict): 14 | setattr(self, key, Object(config[key])) 15 | else: 16 | setattr(self, key, config[key]) 17 | 18 | def load_JsonConfig(json_file): 19 | with open(json_file, 'r') as f: 20 | config = json.load(f) 21 | 22 | return Object(config) 23 | 24 | 25 | if __name__ == '__main__': 26 | config = load_JsonConfig('config/style_gestures.json') 27 | print(dir(config)) -------------------------------------------------------------------------------- /data_utils/axis2matrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import scipy.linalg as linalg 4 | 5 | 6 | def rotate_mat(axis, radian): 7 | 8 | a = np.cross(np.eye(3), axis / linalg.norm(axis) * radian) 9 | 10 | rot_matrix = linalg.expm(a) 11 | 12 | return rot_matrix 13 | 14 | def aaa2mat(axis, sin, cos): 15 | i = np.eye(3) 16 | nnt = np.dot(axis.T, axis) 17 | s = np.asarray([[0, -axis[0,2], axis[0,1]], 18 | [axis[0,2], 0, -axis[0,0]], 19 | [-axis[0,1], axis[0,0], 0]]) 20 | r = cos * i + (1-cos)*nnt +sin * s 21 | return r 22 | 23 | rand_axis = np.asarray([[1,0,0]]) 24 | # rotation angle (radians) 25 | r = math.pi/2 26 | # compute and return the rotation matrix 27 | rot_matrix = rotate_mat(rand_axis, r) 28 | r2 = aaa2mat(rand_axis, np.sin(r), np.cos(r)) 29 | print(rot_matrix) -------------------------------------------------------------------------------- /nets/init_model.py: -------------------------------------------------------------------------------- 1 | from nets import * 2 | 3 | 4 | def init_model(model_name, args, config,bert_config=None,iteration=None): 5 | 6 | if model_name == 's2g_face': 7 | generator
= s2g_face( 8 | args, 9 | config, 10 | ) 11 | elif model_name == 's2g_body_vq': 12 | generator = s2g_body_vq( 13 | args, 14 | config, 15 | ) 16 | elif model_name == 's2g_body_pixel': 17 | generator = s2g_body_pixel( 18 | args, 19 | config, 20 | bert_config, 21 | iteration 22 | ) 23 | elif model_name == 's2g_body_ae': 24 | generator = s2g_body_ae( 25 | args, 26 | config, 27 | ) 28 | elif model_name == 's2g_LS3DCG': 29 | generator = LS3DCG( 30 | args, 31 | config, 32 | ) 33 | else: 34 | raise ValueError 35 | return generator 36 | 37 | 38 | -------------------------------------------------------------------------------- /scripts/.idea/testtext.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | # path being defined from where the system will read the image 4 | path = r'test.png' 5 | # command used for reading an image from the disk disk, cv2.imread function is used 6 | image1 = cv2.imread(path) 7 | # Window name being specified where the image will be displayed 8 | window_name1 = 'image' 9 | # font for the text being specified 10 | font1 = cv2.FONT_HERSHEY_SIMPLEX 11 | # org for the text being specified 12 | org1 = (50, 50) 13 | # font scale for the text being specified 14 | fontScale1 = 1 15 | # Blue color for the text being specified from BGR 16 | color1 = (255, 255, 255) 17 | # Line thickness for the text being specified at 2 px 18 | thickness1 = 2 19 | # Using the cv2.putText() method for inserting text in the image of the specified path 20 | image_1 = cv2.putText(image1, 'CAT IN BOX', org1, font1, fontScale1, color1, thickness1, cv2.LINE_AA) 21 | # Displaying the output image 22 | cv2.imshow(window_name1, image_1) 23 | cv2.waitKey(0) 24 | cv2.destroyAllWindows() 25 | -------------------------------------------------------------------------------- /trainer/training_config.cfg: -------------------------------------------------------------------------------- 1 | [Input Output] 2 | checkpoint_dir = ./training 3 | expression_basis_fname = ./training_data/init_expression_basis.npy 4 | template_fname = ./template/FLAME_sample.ply 5 | deepspeech_graph_fname = ./ds_graph/output_graph.pb 6 | face_or_body = body 7 | verts_mmaps_path = ./training_data/data_verts.npy 8 | raw_audio_path = ./training_data/raw_audio_fixed.pkl 9 | processed_audio_path = ./training_data/processed_audio_deepspeech.pkl 10 | templates_path = ./training_data/templates.pkl 11 | data2array_verts_path = ./training_data/subj_seq_to_idx.pkl 12 | 13 | [Audio Parameters] 14 | audio_feature_type = deepspeech 15 | num_audio_features = 29 16 | audio_window_size = 16 17 | audio_window_stride = 1 18 | condition_speech_features = True 19 | speech_encoder_size_factor = 1.0 20 | 21 | [Model Parameters] 22 | num_vertices = 10475 23 | expression_dim = 50 24 | init_expression = False 25 | num_consecutive_frames = 30 26 | absolute_reconstruction_loss = False 27 | velocity_weight = 10.0 28 | acceleration_weight = 0.0 29 | verts_regularizer_weight = 0.0 30 | 31 | [Data Setup] 32 | subject_for_training = speeker_oliver 33 | sequence_for_training = 0-00'00'05-00'00'10 1-00'00'32-00'00'37 2-00'01'05-00'01'10 34 | subject_for_validation = speeker_oliver 35 | sequence_for_validation = 2-00'01'05-00'01'10 36 | subject_for_testing = speeker_oliver 37 | sequence_for_testing = 2-00'01'05-00'01'10 38 | 39 | [Learning Parameters] 40 | batch_size = 64 41 | learning_rate = 1e-4 42 | decay_rate = 1.0 43 | epoch_num = 1000 44 | adam_beta1_value = 0.9 45 | 46 | [Visualization Parameters] 47 | 
num_render_sequences = 3 48 | 49 | -------------------------------------------------------------------------------- /config/face.json: -------------------------------------------------------------------------------- 1 | { 2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts", 3 | "dataset_load_mode": "json", 4 | "store_file_path": "store.pkl", 5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz", 6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml", 7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl", 8 | "param": { 9 | "w_j": 1, 10 | "w_b": 1, 11 | "w_h": 1 12 | }, 13 | "Data": { 14 | "data_root": "/mnt/nj-aigc/usr/pengwenshuo/TalkSHOW/ExpressiveWholeBodyDatasetReleaseV1.0", 15 | "pklname": "_3d_wv2.pkl", 16 | "whole_video": true, 17 | "pose": { 18 | "normalization": false, 19 | "convert_to_6d": false, 20 | "norm_method": "all", 21 | "augmentation": false, 22 | "generate_length": 88, 23 | "pre_pose_length": 0, 24 | "pose_dim": 99, 25 | "expression": true 26 | }, 27 | "aud": { 28 | "feat_method": "mfcc", 29 | "aud_feat_dim": 64, 30 | "aud_feat_win_size": null, 31 | "context_info": false 32 | } 33 | }, 34 | "Model": { 35 | "model_type": "face", 36 | "model_name": "s2g_face", 37 | "AudioOpt": "SGD", 38 | "encoder_choice": "faceformer", 39 | "gan": false 40 | }, 41 | "DataLoader": { 42 | "batch_size": 1, 43 | "num_workers": 16 44 | }, 45 | "Train": { 46 | "epochs": 100, 47 | "max_gradient_norm": 5, 48 | "learning_rate": { 49 | "generator_learning_rate": 1e-3, 50 | "discriminator_learning_rate": 1e-4 51 | } 52 | }, 53 | "Log": { 54 | "save_every": 50, 55 | "print_every": 1000, 56 | "name": "face" 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /config/LS3DCG.json: -------------------------------------------------------------------------------- 1 | { 2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts", 3 | "dataset_load_mode": "pickle", 4 | "store_file_path": "store.pkl", 5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz", 6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml", 7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl", 8 | "param": { 9 | "w_j": 1, 10 | "w_b": 1, 11 | "w_h": 1 12 | }, 13 | "Data": { 14 | "data_root": "../ExpressiveWholeBodyDatasetv1.0/", 15 | "pklname": "_3d_mfcc.pkl", 16 | "whole_video": false, 17 | "pose": { 18 | "normalization": false, 19 | "convert_to_6d": false, 20 | "norm_method": "all", 21 | "augmentation": false, 22 | "generate_length": 88, 23 | "pre_pose_length": 0, 24 | "pose_dim": 99, 25 | "expression": true 26 | }, 27 | "aud": { 28 | "feat_method": "mfcc", 29 | "aud_feat_dim": 64, 30 | "aud_feat_win_size": null, 31 | "context_info": false 32 | } 33 | }, 34 | "Model": { 35 | "model_type": "body", 36 | "model_name": "s2g_LS3DCG", 37 | "code_num": 2048, 38 | "AudioOpt": "Adam", 39 | "encoder_choice": "mfcc", 40 | "gan": false 41 | }, 42 | "DataLoader": { 43 | "batch_size": 128, 44 | "num_workers": 0 45 | }, 46 | "Train": { 47 | "epochs": 100, 48 | "max_gradient_norm": 5, 49 | "learning_rate": { 50 | "generator_learning_rate": 1e-4, 51 | "discriminator_learning_rate": 1e-4 52 | }, 53 | "weights": { 54 | "keypoint_loss_weight": 1.0, 55 | "gan_loss_weight": 1.0 56 | } 57 | }, 58 | "Log": { 59 | "save_every": 50, 60 | "print_every": 200, 61 | "name": "LS3DCG" 62 | } 63 | } 64 | 
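The JSON files under config/ are loaded through load_JsonConfig in trainer/config.py (see scripts/test_vq.py below), which recursively wraps the nested dicts into attribute-style objects; this is the config.Data.pose.* access pattern used in data_utils/get_j.py. A minimal sketch, assuming it is run from the repository root:

import os, sys
sys.path.append(os.getcwd())  # mirrors what scripts/train.py does before importing project packages
from trainer.config import load_JsonConfig

config = load_JsonConfig('./config/LS3DCG.json')
print(config.Model.model_name)                               # "s2g_LS3DCG"
print(config.Data.pose.pose_dim)                             # 99
print(config.Train.learning_rate.generator_learning_rate)    # 1e-4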
-------------------------------------------------------------------------------- /config/body_vq.json: -------------------------------------------------------------------------------- 1 | { 2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts", 3 | "dataset_load_mode": "pickle", 4 | "store_file_path": "store.pkl", 5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz", 6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml", 7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl", 8 | "param": { 9 | "w_j": 1, 10 | "w_b": 1, 11 | "w_h": 1 12 | }, 13 | "Data": { 14 | "data_root": "/mnt/nj-aigc/usr/pengwenshuo/TalkSHOW/ExpressiveWholeBodyDatasetReleaseV1.0", 15 | "pklname": "_3d_encodec_viclip_200m.pkl", 16 | "whole_video": false, 17 | "pose": { 18 | "normalization": false, 19 | "convert_to_6d": false, 20 | "norm_method": "all", 21 | "augmentation": false, 22 | "generate_length": 88, 23 | "pre_pose_length": 0, 24 | "pose_dim": 99, 25 | "expression": true 26 | }, 27 | "aud": { 28 | "feat_method": "mfcc", 29 | "aud_feat_dim": 64, 30 | "aud_feat_win_size": null, 31 | "context_info": false 32 | } 33 | }, 34 | "Model": { 35 | "model_type": "body", 36 | "model_name": "s2g_body_vq", 37 | "composition": true, 38 | "code_num": 2048, 39 | "bh_model": true, 40 | "AudioOpt": "Adam", 41 | "encoder_choice": "mfcc", 42 | "gan": false 43 | }, 44 | "DataLoader": { 45 | "batch_size": 128, 46 | "num_workers": 0 47 | }, 48 | "Train": { 49 | "epochs": 100, 50 | "max_gradient_norm": 5, 51 | "learning_rate": { 52 | "generator_learning_rate": 5e-5, 53 | "discriminator_learning_rate": 1e-4 54 | } 55 | }, 56 | "Log": { 57 | "save_every": 50, 58 | "print_every": 200, 59 | "name": "body-vq" 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /config/body_pixel.json: -------------------------------------------------------------------------------- 1 | { 2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts", 3 | "dataset_load_mode": "pickle", 4 | "store_file_path": "store.pkl", 5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz", 6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml", 7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl", 8 | "param": { 9 | "w_j": 1, 10 | "w_b": 1, 11 | "w_h": 1 12 | }, 13 | 14 | "Data": { 15 | "data_root": "./ExpressiveWholeBodyDatasetReleaseV1.0", 16 | "pklname": "_3d_encodec_viclip_200m_token.pkl", 17 | "whole_video": false, 18 | "pose": { 19 | "normalization": false, 20 | "convert_to_6d": false, 21 | "norm_method": "all", 22 | "augmentation": false, 23 | "generate_length": 88, 24 | "pre_pose_length": 0, 25 | "pose_dim": 99, 26 | "expression": true 27 | }, 28 | "aud": { 29 | "feat_method": "mfcc", 30 | "aud_feat_dim":128, 31 | "aud_feat_win_size": null, 32 | "context_info": false 33 | } 34 | }, 35 | 36 | "Model": { 37 | "model_type": "body", 38 | "model_name": "s2g_body_pixel", 39 | "composition": true, 40 | "code_num": 2048, 41 | "bh_model": true, 42 | "AudioOpt": "Adam", 43 | "encoder_choice": "mfcc", 44 | "gan": false, 45 | "vq_path": "./experiments/2022-10-31-smplx_S2G-body-vq-3d/ckpt-99.pth" 46 | }, 47 | 48 | "DataLoader": { 49 | "batch_size": 128, 50 | "num_workers": 0 51 | }, 52 | 53 | "Train": { 54 | "epochs": 300, 55 | "max_gradient_norm": 5, 56 | "learning_rate": { 57 | "generator_learning_rate": 1e-4, 58 | "discriminator_learning_rate": 1e-4 59 | } 60 | }, 61 | 62 | "Log": { 63 | 
"save_every": 50, 64 | "print_every": 400, 65 | "name": "body-pixel2" 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /data_utils/get_j.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def to3d(poses, config): 5 | if config.Data.pose.convert_to_6d: 6 | if config.Data.pose.expression: 7 | poses_exp = poses[:, -100:] 8 | poses = poses[:, :-100] 9 | 10 | poses = poses.reshape(poses.shape[0], -1, 5) 11 | sin, cos = poses[:, :, 3], poses[:, :, 4] 12 | pose_angle = torch.atan2(sin, cos) 13 | poses = (poses[:, :, :3] * pose_angle.unsqueeze(dim=-1)).reshape(poses.shape[0], -1) 14 | 15 | if config.Data.pose.expression: 16 | poses = torch.cat([poses, poses_exp], dim=-1) 17 | return poses 18 | 19 | 20 | def get_joint(smplx_model, betas, pred): 21 | joint = smplx_model(betas=betas.repeat(pred.shape[0], 1), 22 | expression=pred[:, 165:265], 23 | jaw_pose=pred[:, 0:3], 24 | leye_pose=pred[:, 3:6], 25 | reye_pose=pred[:, 6:9], 26 | global_orient=pred[:, 9:12], 27 | body_pose=pred[:, 12:75], 28 | left_hand_pose=pred[:, 75:120], 29 | right_hand_pose=pred[:, 120:165], 30 | return_verts=True)['joints'] 31 | return joint 32 | 33 | 34 | def get_joints(smplx_model, betas, pred): 35 | if len(pred.shape) == 3: 36 | B = pred.shape[0] 37 | x = 4 if B>= 4 else B 38 | T = pred.shape[1] 39 | pred = pred.reshape(-1, 265) 40 | smplx_model.batch_size = L = T * x 41 | 42 | times = pred.shape[0] // smplx_model.batch_size 43 | joints = [] 44 | for i in range(times): 45 | joints.append(get_joint(smplx_model, betas, pred[i*L:(i+1)*L])) 46 | joints = torch.cat(joints, dim=0) 47 | joints = joints.reshape(B, T, -1, 3) 48 | else: 49 | smplx_model.batch_size = pred.shape[0] 50 | joints = get_joint(smplx_model, betas, pred) 51 | return joints -------------------------------------------------------------------------------- /data_utils/apply_split.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import pickle 4 | import shutil 5 | 6 | speakers = ['seth', 'oliver', 'conan', 'chemistry'] 7 | source_data_root = "../expressive_body-V0.7" 8 | data_root = "D:/Downloads/SHOW_dataset_v1.0/ExpressiveWholeBodyDatasetReleaseV1.0" 9 | 10 | f_read = open('split_more_than_2s.pkl', 'rb') 11 | f_save = open('none.pkl', 'wb') 12 | data_split = pickle.load(f_read) 13 | none_split = [] 14 | 15 | train = val = test = 0 16 | 17 | for speaker_name in speakers: 18 | speaker_root = os.path.join(data_root, speaker_name) 19 | 20 | videos = [v for v in data_split[speaker_name]] 21 | 22 | for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)): 23 | for split in data_split[speaker_name][vid]: 24 | for seq in data_split[speaker_name][vid][split]: 25 | 26 | seq = seq.replace('\\', '/') 27 | old_file_path = os.path.join(data_root, speaker_name, vid, seq.split('/')[-1]) 28 | old_file_path = old_file_path.replace('\\', '/') 29 | new_file_path = seq.replace(source_data_root.split('/')[-1], data_root.split('/')[-1]) 30 | try: 31 | shutil.move(old_file_path, new_file_path) 32 | if split == 'train': 33 | train = train + 1 34 | elif split == 'test': 35 | test = test + 1 36 | elif split == 'val': 37 | val = val + 1 38 | except FileNotFoundError: 39 | none_split.append(old_file_path) 40 | print(f"The file {old_file_path} does not exists.") 41 | except shutil.Error: 42 | none_split.append(old_file_path) 43 | print(f"The file {old_file_path} does 
not exists.") 44 | 45 | print(none_split.__len__()) 46 | pickle.dump(none_split, f_save) 47 | f_save.close() 48 | 49 | print(train, val, test) 50 | 51 | 52 | -------------------------------------------------------------------------------- /trainer/options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | def parse_args(): 4 | parser = ArgumentParser() 5 | parser.add_argument('--gpu', default=0, type=int) 6 | parser.add_argument('--save_dir', default='experiments', type=str) 7 | parser.add_argument('--exp_name', default='smplx_S2G', type=str) 8 | parser.add_argument('--speakers', nargs='+') 9 | parser.add_argument('--seed', default=1, type=int) 10 | parser.add_argument('--model_name', type=str) 11 | #parser.add_argument('--bert_config', default='experiments', type=str) 12 | #for Tmpt and S2G 13 | parser.add_argument('--use_template', action='store_true') 14 | parser.add_argument('--template_length', default=0, type=int) 15 | #for training from a ckpt 16 | parser.add_argument('--resume', action='store_true') 17 | parser.add_argument('--pretrained_pth', default=None, type=str) 18 | parser.add_argument('--style_layer_norm', action='store_true') 19 | #required 20 | parser.add_argument('--config_file', default='./config/style_gestures.json', type=str) 21 | parser.add_argument('--bert_config', default='./config/bert_config.json', type=str) 22 | # for visualization and test 23 | parser.add_argument('--audio_file', default=None, type=str) 24 | parser.add_argument('--id', default=0, type=int, help='0=oliver, 1=chemistry, 2=seth, 3=conan') 25 | parser.add_argument('--only_face', action='store_true') 26 | parser.add_argument('--stand', action='store_true') 27 | parser.add_argument('--num_sample', default=1, type=int) 28 | parser.add_argument('--whole_body', action='store_true') 29 | parser.add_argument('--face_model_name', default='s2g_face', type=str) 30 | parser.add_argument('--face_model_path', default='./experiments/2022-10-15-smplx_S2G-face-3d/ckpt-99.pth', type=str) 31 | parser.add_argument('--body_model_name', default='s2g_body_pixel', type=str) 32 | parser.add_argument('--body_model_path', default='./viclip_lr1e-4_layer6_head8_repeat_half/2024-02-25-smplx_S2G-body-pixel2/ckpt-299.pth', type=str) 33 | parser.add_argument('--infer', action='store_true') 34 | return parser 35 | 36 | -------------------------------------------------------------------------------- /scripts/.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /evaluation/diversity_LVD.py: -------------------------------------------------------------------------------- 1 | ''' 2 | LVD: different initial pose 3 | diversity: same initial pose 4 | ''' 5 | import os 6 | import sys 7 | sys.path.append(os.getcwd()) 8 | 9 | from glob import glob 10 | 11 | from argparse import ArgumentParser 12 | import json 13 | 14 | from evaluation.util import * 15 | from evaluation.metrics import * 16 | from tqdm import tqdm 17 | 18 | parser = ArgumentParser() 19 | parser.add_argument('--speaker', 
required=True, type=str) 20 | parser.add_argument('--post_fix', nargs='+', default=['base'], type=str) 21 | args = parser.parse_args() 22 | 23 | speaker = args.speaker 24 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker))) 25 | 26 | LVD_list = [] 27 | diversity_list = [] 28 | 29 | for aud in tqdm(test_audios): 30 | base_name = os.path.splitext(aud)[0] 31 | gt_path = get_full_path(aud, speaker, 'val') 32 | _, gt_poses, _ = get_gts(gt_path) 33 | gt_poses = gt_poses[np.newaxis,...] 34 | # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face 35 | for post_fix in args.post_fix: 36 | pred_path = base_name + '_'+post_fix+'.json' 37 | pred_poses = np.array(json.load(open(pred_path))) 38 | # print(pred_poses.shape)#(B, seq_len, 108) 39 | pred_poses = cvt25(pred_poses, gt_poses) 40 | # print(pred_poses.shape)#(B, seq, pose_dim) 41 | 42 | gt_valid_points = hand_points(gt_poses) 43 | pred_valid_points = hand_points(pred_poses) 44 | 45 | lvd = LVD(gt_valid_points, pred_valid_points) 46 | # div = diversity(pred_valid_points) 47 | 48 | LVD_list.append(lvd) 49 | # diversity_list.append(div) 50 | 51 | # gt_velocity = peak_velocity(gt_valid_points, order=2) 52 | # pred_velocity = peak_velocity(pred_valid_points, order=2) 53 | 54 | # gt_consistency = velocity_consistency(gt_velocity, pred_velocity) 55 | # pred_consistency = velocity_consistency(pred_velocity, gt_velocity) 56 | 57 | # gt_consistency_list.append(gt_consistency) 58 | # pred_consistency_list.append(pred_consistency) 59 | 60 | lvd = np.mean(LVD_list) 61 | # diversity_list = np.mean(diversity_list) 62 | 63 | print('LVD:', lvd) 64 | # print("diversity:", diversity_list) -------------------------------------------------------------------------------- /evaluation/get_quality_samples.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | import os 4 | import sys 5 | sys.path.append(os.getcwd()) 6 | 7 | from glob import glob 8 | 9 | from argparse import ArgumentParser 10 | import json 11 | 12 | from evaluation.util import * 13 | from evaluation.metrics import * 14 | from tqdm import tqdm 15 | 16 | parser = ArgumentParser() 17 | parser.add_argument('--speaker', required=True, type=str) 18 | parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str) 19 | args = parser.parse_args() 20 | 21 | speaker = args.speaker 22 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker))) 23 | 24 | quality_samples={'gt':[]} 25 | for post_fix in args.post_fix: 26 | quality_samples[post_fix] = [] 27 | 28 | for aud in tqdm(test_audios): 29 | base_name = os.path.splitext(aud)[0] 30 | gt_path = get_full_path(aud, speaker, 'val') 31 | _, gt_poses, _ = get_gts(gt_path) 32 | gt_poses = gt_poses[np.newaxis,...] 
33 | gt_valid_points = valid_points(gt_poses) 34 | # print(gt_valid_points.shape) 35 | quality_samples['gt'].append(gt_valid_points) 36 | 37 | for post_fix in args.post_fix: 38 | pred_path = base_name + '_'+post_fix+'.json' 39 | pred_poses = np.array(json.load(open(pred_path))) 40 | # print(pred_poses.shape)#(B, seq_len, 108) 41 | pred_poses = cvt25(pred_poses, gt_poses) 42 | # print(pred_poses.shape)#(B, seq, pose_dim) 43 | 44 | pred_valid_points = valid_points(pred_poses)[0:1] 45 | quality_samples[post_fix].append(pred_valid_points) 46 | 47 | quality_samples['gt'] = np.concatenate(quality_samples['gt'], axis=1) 48 | for post_fix in args.post_fix: 49 | quality_samples[post_fix] = np.concatenate(quality_samples[post_fix], axis=1) 50 | 51 | print('gt:', quality_samples['gt'].shape) 52 | quality_samples['gt'] = quality_samples['gt'].tolist() 53 | for post_fix in args.post_fix: 54 | print(post_fix, ':', quality_samples[post_fix].shape) 55 | quality_samples[post_fix] = quality_samples[post_fix].tolist() 56 | 57 | save_dir = '../../experiments/' 58 | os.makedirs(save_dir, exist_ok=True) 59 | save_name = os.path.join(save_dir, 'quality_samples_%s.json'%(speaker)) 60 | with open(save_name, 'w') as f: 61 | json.dump(quality_samples, f) 62 | 63 | -------------------------------------------------------------------------------- /evaluation/mode_transition.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.getcwd()) 4 | 5 | from glob import glob 6 | 7 | from argparse import ArgumentParser 8 | import json 9 | 10 | from evaluation.util import * 11 | from evaluation.metrics import * 12 | from tqdm import tqdm 13 | 14 | parser = ArgumentParser() 15 | parser.add_argument('--speaker', required=True, type=str) 16 | parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str) 17 | args = parser.parse_args() 18 | 19 | speaker = args.speaker 20 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker))) 21 | 22 | precision_list=[] 23 | recall_list=[] 24 | accuracy_list=[] 25 | 26 | for aud in tqdm(test_audios): 27 | base_name = os.path.splitext(aud)[0] 28 | gt_path = get_full_path(aud, speaker, 'val') 29 | _, gt_poses, _ = get_gts(gt_path) 30 | if gt_poses.shape[0] < 50: 31 | continue 32 | gt_poses = gt_poses[np.newaxis,...] 
33 | # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face 34 | for post_fix in args.post_fix: 35 | pred_path = base_name + '_'+post_fix+'.json' 36 | pred_poses = np.array(json.load(open(pred_path))) 37 | # print(pred_poses.shape)#(B, seq_len, 108) 38 | pred_poses = cvt25(pred_poses, gt_poses) 39 | # print(pred_poses.shape)#(B, seq, pose_dim) 40 | 41 | gt_valid_points = valid_points(gt_poses) 42 | pred_valid_points = valid_points(pred_poses) 43 | 44 | # print(gt_valid_points.shape, pred_valid_points.shape) 45 | 46 | gt_mode_transition_seq = mode_transition_seq(gt_valid_points, speaker)#(B, N) 47 | pred_mode_transition_seq = mode_transition_seq(pred_valid_points, speaker)#(B, N) 48 | 49 | # baseline = np.random.randint(0, 2, size=pred_mode_transition_seq.shape) 50 | # pred_mode_transition_seq = baseline 51 | precision, recall, accuracy = mode_transition_consistency(pred_mode_transition_seq, gt_mode_transition_seq) 52 | precision_list.append(precision) 53 | recall_list.append(recall) 54 | accuracy_list.append(accuracy) 55 | print(len(precision_list), len(recall_list), len(accuracy_list)) 56 | precision_list = np.mean(precision_list) 57 | recall_list = np.mean(recall_list) 58 | accuracy_list = np.mean(accuracy_list) 59 | 60 | print('precision, recall, accu:', precision_list, recall_list, accuracy_list) 61 | -------------------------------------------------------------------------------- /evaluation/peak_velocity.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.getcwd()) 4 | 5 | from glob import glob 6 | 7 | from argparse import ArgumentParser 8 | import json 9 | 10 | from evaluation.util import * 11 | from evaluation.metrics import * 12 | from tqdm import tqdm 13 | 14 | parser = ArgumentParser() 15 | parser.add_argument('--speaker', required=True, type=str) 16 | parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str) 17 | args = parser.parse_args() 18 | 19 | speaker = args.speaker 20 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker))) 21 | 22 | gt_consistency_list=[] 23 | pred_consistency_list=[] 24 | 25 | for aud in tqdm(test_audios): 26 | base_name = os.path.splitext(aud)[0] 27 | gt_path = get_full_path(aud, speaker, 'val') 28 | _, gt_poses, _ = get_gts(gt_path) 29 | gt_poses = gt_poses[np.newaxis,...] 
30 | # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face 31 | for post_fix in args.post_fix: 32 | pred_path = base_name + '_'+post_fix+'.json' 33 | pred_poses = np.array(json.load(open(pred_path))) 34 | # print(pred_poses.shape)#(B, seq_len, 108) 35 | pred_poses = cvt25(pred_poses, gt_poses) 36 | # print(pred_poses.shape)#(B, seq, pose_dim) 37 | 38 | gt_valid_points = hand_points(gt_poses) 39 | pred_valid_points = hand_points(pred_poses) 40 | 41 | gt_velocity = peak_velocity(gt_valid_points, order=2) 42 | pred_velocity = peak_velocity(pred_valid_points, order=2) 43 | 44 | gt_consistency = velocity_consistency(gt_velocity, pred_velocity) 45 | pred_consistency = velocity_consistency(pred_velocity, gt_velocity) 46 | 47 | gt_consistency_list.append(gt_consistency) 48 | pred_consistency_list.append(pred_consistency) 49 | 50 | gt_consistency_list = np.concatenate(gt_consistency_list) 51 | pred_consistency_list = np.concatenate(pred_consistency_list) 52 | 53 | print(gt_consistency_list.max(), gt_consistency_list.min()) 54 | print(pred_consistency_list.max(), pred_consistency_list.min()) 55 | print(np.mean(gt_consistency_list), np.mean(pred_consistency_list)) 56 | print(np.std(gt_consistency_list), np.std(pred_consistency_list)) 57 | 58 | draw_cdf(gt_consistency_list, save_name='%s_gt.jpg'%(speaker), color='slateblue') 59 | draw_cdf(pred_consistency_list, save_name='%s_pred.jpg'%(speaker), color='lightskyblue') 60 | 61 | to_excel(gt_consistency_list, '%s_gt.xlsx'%(speaker)) 62 | to_excel(pred_consistency_list, '%s_pred.xlsx'%(speaker)) 63 | 64 | np.save('%s_gt.npy'%(speaker), gt_consistency_list) 65 | np.save('%s_pred.npy'%(speaker), pred_consistency_list) -------------------------------------------------------------------------------- /losses/losses.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | 11 | class KeypointLoss(nn.Module): 12 | def __init__(self): 13 | super(KeypointLoss, self).__init__() 14 | 15 | def forward(self, pred_seq, gt_seq, gt_conf=None): 16 | #pred_seq: (B, C, T) 17 | if gt_conf is not None: 18 | gt_conf = gt_conf >= 0.01 19 | return F.mse_loss(pred_seq[gt_conf], gt_seq[gt_conf], reduction='mean') 20 | else: 21 | return F.mse_loss(pred_seq, gt_seq) 22 | 23 | 24 | class KLLoss(nn.Module): 25 | def __init__(self, kl_tolerance): 26 | super(KLLoss, self).__init__() 27 | self.kl_tolerance = kl_tolerance 28 | 29 | def forward(self, mu, var, mul=1): 30 | kl_tolerance = self.kl_tolerance * mul * var.shape[1] / 64 31 | kld_loss = -0.5 * torch.sum(1 + var - mu**2 - var.exp(), dim=1) 32 | # kld_loss = -0.5 * torch.sum(1 + (var-1) - (mu) ** 2 - (var-1).exp(), dim=1) 33 | if self.kl_tolerance is not None: 34 | # above_line = kld_loss[kld_loss > self.kl_tolerance] 35 | # if len(above_line) > 0: 36 | # kld_loss = torch.mean(kld_loss) 37 | # else: 38 | # kld_loss = 0 39 | kld_loss = torch.where(kld_loss > kl_tolerance, kld_loss, torch.tensor(kl_tolerance, device='cuda')) 40 | # else: 41 | kld_loss = torch.mean(kld_loss) 42 | return kld_loss 43 | 44 | 45 | class L2KLLoss(nn.Module): 46 | def __init__(self, kl_tolerance): 47 | super(L2KLLoss, self).__init__() 48 | self.kl_tolerance = kl_tolerance 49 | 50 | def forward(self, x): 51 | # TODO: check 52 | kld_loss = torch.sum(x ** 2, dim=1) 53 | if self.kl_tolerance is not None: 54 | above_line = kld_loss[kld_loss > 
self.kl_tolerance] 55 | if len(above_line) > 0: 56 | kld_loss = torch.mean(kld_loss) 57 | else: 58 | kld_loss = 0 59 | else: 60 | kld_loss = torch.mean(kld_loss) 61 | return kld_loss 62 | 63 | class L2RegLoss(nn.Module): 64 | def __init__(self): 65 | super(L2RegLoss, self).__init__() 66 | 67 | def forward(self, x): 68 | #TODO: check 69 | return torch.sum(x**2) 70 | 71 | 72 | class L2Loss(nn.Module): 73 | def __init__(self): 74 | super(L2Loss, self).__init__() 75 | 76 | def forward(self, x): 77 | # TODO: check 78 | return torch.sum(x ** 2) 79 | 80 | 81 | class AudioLoss(nn.Module): 82 | def __init__(self): 83 | super(AudioLoss, self).__init__() 84 | 85 | def forward(self, dynamics, gt_poses): 86 | # note: gt_poses are mean-centered along the last dim, so the target is normalized 87 | mean = torch.mean(gt_poses, dim=-1).unsqueeze(-1) 88 | gt = gt_poses - mean 89 | return F.mse_loss(dynamics, gt) 90 | 91 | L1Loss = nn.L1Loss -------------------------------------------------------------------------------- /data_utils/test.py: -------------------------------------------------------------------------------- 1 | """def get_encodec(audio_fn,model): 2 | samples, sr = ta.load(audio_fn) 3 | seq_lenth = samples.shape[1] / sr 4 | print("sr:",sr) 5 | samples = samples.to('cuda') 6 | 7 | if samples.shape[0] > 1: 8 | samples = torch.mean(samples, dim=0, keepdim=True) 9 | print("samples audio:",sample_audio.shape) 10 | with torch.no_grad(): 11 | #model = EncodecModel.encodec_model_24khz().to('cuda') 12 | model.set_target_bandwidth(6) 13 | samples = samples.unsqueeze(0) 14 | codes_raw = model.encode(samples) 15 | for frame in codes_raw: 16 | codes,_ = frame 17 | codes = codes.transpose(0,1) 18 | emb = model.quantizer.decode(codes) 19 | emb = emb.transpose(1,2) 20 | emb = linear_interpolation(emb,seq_len=seq_lenth,output_fps=30,output_len=None) 21 | emb = emb.squeeze(0).cpu().numpy() 22 | return emb""" 23 | 24 | def get_encodec(audio_fn,model): 25 | wav, sr = torchaudio.load(audio_fn) 26 | model.set_target_bandwidth(6.0) 27 | #print(sr,wav.shape) 28 | wav = convert_audio(wav, sr, model.sample_rate, model.channels) 29 | #print(wav.shape) 30 | seq_lenth = wav.shape[1]/model.sample_rate 31 | #print(seq_lenth) 32 | wav = wav.unsqueeze(0).to("cuda") 33 | # Extract discrete codes from EnCodec 34 | with torch.no_grad(): 35 | encoded_frames = model.encode(wav) 36 | codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze(0)+1 37 | 38 | token_sample = get_target_token(fps=30,codes=codes,duration_time=seq_lenth,num_codes_layer=-3) 39 | # [B, n_q, T] 40 | return token_sample 41 | 42 | def interpolate_vector(input_vector, outdim): 43 | indim = input_vector.shape[1] 44 | interval = indim / outdim 45 | # generate sampling indices 46 | idx = (np.arange(outdim) * interval).astype(int) 47 | 48 | # sample at equal intervals 49 | output_vector = input_vector[:, idx] 50 | 51 | return output_vector 52 | 53 | def get_target_token(fps,codes,duration_time,num_codes_layer): 54 | seq_len = fps*duration_time 55 | #print(codes.shape) ### 8x750 56 | token_codes = codes[num_codes_layer,:].unsqueeze(0) ### 1x750 57 | for t in token_codes: 58 | p = torch.unique_consecutive(t) 59 | print(p.shape) 60 | token_sample = interpolate_vector(token_codes,seq_len) ### 1x300 61 | #print(token_sample.shape) 62 | return token_sample 63 | 64 | 65 | import numpy as np 66 | 67 | 68 | if __name__ == "__main__": 69 | audio_fn = '/mnt/nj-aigc/usr/pengwenshuo/TalkSHOW/demo_audio/214428-00_00_58-00_01_08.wav' 70 | from encodec import EncodecModel 71 | from encodec.utils import convert_audio 72 | from pathlib import Path 73 | import
torchaudio 74 | import torch 75 | model = EncodecModel.encodec_model_24khz(repository=Path("/mnt/nj-aigc/dataset/pengwenshuo/encodec")).to('cuda') 76 | codes = get_encodec(audio_fn,model) 77 | codes = codes+1 78 | for code in codes: 79 | c = torch.unique_consecutive(code) 80 | print(c.shape) 81 | #print(torch.max(codes),torch.min(codes)) 82 | #token = get_target_token(fps=30,codes=codes,duration_time=duration_time,num_codes_layer=-1) 83 | 84 | 85 | -------------------------------------------------------------------------------- /scripts/test_vq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 5 | sys.path.append(os.getcwd()) 6 | 7 | from tqdm import tqdm 8 | from transformers import Wav2Vec2Processor 9 | 10 | from evaluation.metrics import LVD 11 | 12 | import numpy as np 13 | import smplx as smpl 14 | 15 | from data_utils.lower_body import part2full, poses2pred, c_index_3d 16 | from nets import * 17 | from nets.utils import get_path, get_dpath 18 | from trainer.options import parse_args 19 | from data_utils import torch_data 20 | from trainer.config import load_JsonConfig 21 | 22 | import torch 23 | from torch.utils import data 24 | from data_utils.get_j import to3d, get_joints 25 | from scripts.test_body import init_model, init_dataloader 26 | 27 | 28 | def test(test_loader, generator, config): 29 | print('start testing') 30 | 31 | loss_dict = {} 32 | B = 1 33 | with torch.no_grad(): 34 | count = 0 35 | for bat in tqdm(test_loader, desc="Testing......"): 36 | count = count + 1 37 | aud, poses, exp = bat['aud_feat'].to('cuda').to(torch.float32), bat['poses'].to('cuda').to(torch.float32), \ 38 | bat['expression'].to('cuda').to(torch.float32) 39 | id = bat['speaker'].to('cuda') - 20 40 | betas = bat['betas'][0].to('cuda').to(torch.float64) 41 | poses = torch.cat([poses, exp], dim=-2).transpose(-1, -2).squeeze() 42 | poses = to3d(poses, config).unsqueeze(dim=0).transpose(1, 2) 43 | # poses = poses[:, c_index_3d, :] 44 | 45 | cur_wav_file = bat['aud_file'][0] 46 | 47 | pred = generator.infer_on_audio(cur_wav_file, 48 | initial_pose=poses, 49 | id=id, 50 | fps=30, 51 | B=B 52 | ) 53 | pred = torch.tensor(pred, device='cuda') 54 | bat_loss_dict = {'capacity': (poses[:, c_index_3d, :pred.shape[0]].transpose(1,2) - pred).abs().sum(-1).mean()} 55 | 56 | if loss_dict: # 非空 57 | for key in list(bat_loss_dict.keys()): 58 | loss_dict[key] += bat_loss_dict[key] 59 | else: 60 | for key in list(bat_loss_dict.keys()): 61 | loss_dict[key] = bat_loss_dict[key] 62 | for key in loss_dict.keys(): 63 | loss_dict[key] = loss_dict[key] / count 64 | print(key + '=' + str(loss_dict[key].item())) 65 | 66 | 67 | def main(): 68 | parser = parse_args() 69 | args = parser.parse_args() 70 | device = torch.device(args.gpu) 71 | torch.cuda.set_device(device) 72 | 73 | config = load_JsonConfig(args.config_file) 74 | 75 | os.environ['smplx_npz_path'] = config.smplx_npz_path 76 | os.environ['extra_joint_path'] = config.extra_joint_path 77 | os.environ['j14_regressor_path'] = config.j14_regressor_path 78 | 79 | print('init dataloader...') 80 | test_set, test_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config) 81 | print('init model...') 82 | model_name = 's2g_body_vq' 83 | model_type = 'n_com_8192' 84 | model_path = get_path(model_name, model_type) 85 | generator = init_model(model_name, model_path, args, config) 86 | 87 | test(test_loader, generator, config) 88 | 89 | 90 | if 
__name__ == '__main__': 91 | main() 92 | -------------------------------------------------------------------------------- /scripts/.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |