├── evaluation
│   ├── __init__.py
│   ├── diversity_LVD.py
│   ├── get_quality_samples.py
│   ├── mode_transition.py
│   ├── peak_velocity.py
│   ├── metrics.py
│   ├── util.py
│   └── FGD.py
├── scripts
│   ├── __init__.py
│   ├── .idea
│   │   ├── __init__.py
│   │   ├── lower body
│   │   ├── test.png
│   │   ├── vcs.xml
│   │   ├── inspectionProfiles
│   │   │   ├── profiles_settings.xml
│   │   │   └── Project_Default.xml
│   │   ├── modules.xml
│   │   ├── scripts.iml
│   │   ├── aws.xml
│   │   ├── testtext.py
│   │   ├── deployment.xml
│   │   ├── workspace.xml
│   │   └── get_prevar.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── diversity.cpython-37.pyc
│   ├── train.py
│   ├── test_vq.py
│   ├── test_face.py
│   ├── continuity.py
│   └── test_body.py
├── visualise
│   ├── __init__.py
│   ├── .DS_Store
│   └── image.png
├── losses
│   ├── __init__.py
│   └── losses.py
├── trainer
│   ├── __init__.py
│   ├── config.py
│   ├── training_config.cfg
│   └── options.py
├── data_utils
│   ├── split_more_than_2s.pkl
│   ├── __init__.py
│   ├── axis2matrix.py
│   ├── get_j.py
│   ├── apply_split.py
│   ├── test.py
│   ├── dataset_preprocess.py
│   ├── lower_body.py
│   └── dataloader_torch.py
├── voca
│   ├── __pycache__
│   │   ├── rendering.cpython-37.pyc
│   │   ├── rendering.cpython-38.pyc
│   │   └── rendering.cpython-39.pyc
│   └── rendering.py
├── train_face.sh
├── train_body_vq.sh
├── requirements.txt
├── train_body_pixel.sh
├── test_body.sh
├── test_face.sh
├── visualise.sh
├── nets
│   ├── __init__.py
│   ├── init_model.py
│   ├── base.py
│   ├── body_ae.py
│   ├── utils.py
│   ├── spg
│   │   ├── wav2vec.py
│   │   ├── qformer.py
│   │   ├── gated_pixelcnn_v2.py
│   │   ├── t2m_trans.py
│   │   ├── s2g_face.py
│   │   ├── vqvae_1d.py
│   │   └── blip.py
│   └── smplx_face.py
├── config
│   ├── bert_config.json
│   ├── face.json
│   ├── LS3DCG.json
│   ├── body_vq.json
│   └── body_pixel.json
└── README.md
/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visualise/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/scripts/.idea/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .losses import *
--------------------------------------------------------------------------------
/scripts/.idea/lower body:
--------------------------------------------------------------------------------
1 | 0, 1, 3, 4, 6, 7, 9, 10,
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .Trainer import Trainer
2 | from .Trainer_vq import Trainer_vq
--------------------------------------------------------------------------------
/visualise/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/visualise/.DS_Store
--------------------------------------------------------------------------------
/visualise/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/visualise/image.png
--------------------------------------------------------------------------------
/scripts/.idea/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/scripts/.idea/test.png
--------------------------------------------------------------------------------
/data_utils/split_more_than_2s.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/data_utils/split_more_than_2s.pkl
--------------------------------------------------------------------------------
/voca/__pycache__/rendering.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/voca/__pycache__/rendering.cpython-37.pyc
--------------------------------------------------------------------------------
/voca/__pycache__/rendering.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/voca/__pycache__/rendering.cpython-38.pyc
--------------------------------------------------------------------------------
/voca/__pycache__/rendering.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/voca/__pycache__/rendering.cpython-39.pyc
--------------------------------------------------------------------------------
/scripts/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/scripts/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/scripts/__pycache__/diversity.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gloria2tt/T3M/HEAD/scripts/__pycache__/diversity.cpython-37.pyc
--------------------------------------------------------------------------------
/train_face.sh:
--------------------------------------------------------------------------------
1 | python -W ignore scripts/train.py \
2 | --save_dir experiments \
3 | --exp_name smplx_S2G \
4 | --speakers oliver seth conan chemistry \
5 | --config_file ./config/face.json
--------------------------------------------------------------------------------
/train_body_vq.sh:
--------------------------------------------------------------------------------
1 | python -W ignore scripts/train.py \
2 | --save_dir experiments \
3 | --exp_name smplx_S2G \
4 | --speakers oliver seth conan chemistry \
5 | --config_file ./config/body_vq.json
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | transformers
3 | matplotlib
4 | textgrid
5 | smplx
6 | scikit-learn
7 | pyrender
8 | trimesh
9 | tqdm
10 | librosa
11 | scipy
12 | python_speech_features
13 | opencv-python
14 | pyglet
15 | encodec
16 |
--------------------------------------------------------------------------------
/train_body_pixel.sh:
--------------------------------------------------------------------------------
1 | python -W ignore scripts/train.py \
2 | --save_dir your/save/path \
3 | --exp_name smplx_S2G \
4 | --speakers oliver seth conan chemistry \
5 | --config_file ./config/body_pixel.json \
6 | --bert_config ./config/bert_config.json
--------------------------------------------------------------------------------
/data_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # from .dataloader_csv import MultiVidData as csv_data
2 | from .dataloader_torch import MultiVidData as torch_data
3 | from .utils import get_melspec, get_mfcc, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta,get_encodec,get_encodec_token
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 |
5 |
6 | os.environ['CUDA_VISIBLE_DEVICES'] = '2'
7 | sys.path.append(os.getcwd())
8 | from trainer import Trainer,Trainer_vq
9 |
10 | if __name__ == '__main__':
11 |
12 | trainer = Trainer()
13 | trainer.train()
14 |
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/test_body.sh:
--------------------------------------------------------------------------------
1 | python -W ignore scripts/test_body.py \
2 | --save_dir experiments \
3 | --exp_name smplx_S2G \
4 | --speakers oliver seth conan chemistry \
5 | --config_file ./config/body_pixel.json \
6 | --body_model_name s2g_body_pixel \
7 | --body_model_path your/train/model/path \
8 | --infer
9 |
10 |
--------------------------------------------------------------------------------
/test_face.sh:
--------------------------------------------------------------------------------
1 | python -W ignore scripts/test_face.py \
2 | --save_dir experiments \
3 | --exp_name smplx_S2G \
4 | --speakers oliver seth conan chemistry \
5 | --config_file ./config/face.json \
6 | --face_model_name s2g_face \
7 | --face_model_path ./experiments/2022-10-15-smplx_S2G-face-3d/ckpt-99.pth \
8 | --infer
--------------------------------------------------------------------------------
/visualise.sh:
--------------------------------------------------------------------------------
1 | python -W ignore scripts/diversity.py \
2 | --save_dir experiments \
3 | --exp_name smplx_S2G \
4 | --speakers oliver seth conan chemistry \
5 | --config_file ./config/body_pixel.json \
6 | --face_model_path ./experiments/2022-10-15-smplx_S2G-face-3d/ckpt-99.pth \
7 | --body_model_path your/body/model/path \
8 | --infer
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
1 | from .smplx_face import TrainWrapper as s2g_face
2 | from .smplx_body_vq import TrainWrapper as s2g_body_vq
3 | from .smplx_body_pixel import TrainWrapper as s2g_body_pixel
4 | from .body_ae import TrainWrapper as s2g_body_ae
5 | from .LS3DCG import TrainWrapper as LS3DCG
6 | from .base import TrainWrapperBaseClass
7 |
8 | from .utils import normalize, denormalize
--------------------------------------------------------------------------------
/config/bert_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertModel"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 2048,
11 | "layer_norm_eps": 1e-12,
12 | "max_position_embeddings": 512,
13 | "model_type": "bert",
14 | "num_attention_heads": 8,
15 | "num_hidden_layers": 6,
16 | "pad_token_id": 0,
17 | "type_vocab_size": 2,
18 | "vocab_size": 30522,
19 | "encoder_width": 768,
20 | "add_cross_attention": true
21 | }
22 |
--------------------------------------------------------------------------------
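The file above can be parsed as a standard Hugging Face `BertConfig`; the extra `encoder_width` and `add_cross_attention` fields follow the BLIP-style cross-attention setup used by the modules under `nets/spg/`. A minimal loading sketch (illustrative only, not part of the repository; it uses the `transformers` package from `requirements.txt` and assumes it is run from the repository root):

```python
# Sketch: parse config/bert_config.json with Hugging Face transformers.
# Fields BertConfig does not know about (e.g. "encoder_width") are kept as
# plain attributes on the resulting config object.
from transformers import BertConfig

bert_config = BertConfig.from_json_file('./config/bert_config.json')
print(bert_config.hidden_size, bert_config.num_hidden_layers)  # 768 6
print(bert_config.add_cross_attention)                         # True
```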
/trainer/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | load config from json file
3 | '''
4 | import json
5 | import os
6 |
7 | import configparser
8 |
9 |
10 | class Object():
11 | def __init__(self, config:dict) -> None:
12 | for key in list(config.keys()):
13 | if isinstance(config[key], dict):
14 | setattr(self, key, Object(config[key]))
15 | else:
16 | setattr(self, key, config[key])
17 |
18 | def load_JsonConfig(json_file):
19 | with open(json_file, 'r') as f:
20 | config = json.load(f)
21 |
22 | return Object(config)
23 |
24 |
25 | if __name__ == '__main__':
26 | config = load_JsonConfig('config/style_gestures.json')
27 | print(dir(config))
--------------------------------------------------------------------------------
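`load_JsonConfig` wraps the parsed JSON in nested `Object` instances, so configuration values are reached with attribute access instead of dictionary lookups. A small usage sketch (illustrative only; it reads one of the configs shipped in `config/` and assumes the repository root is on `PYTHONPATH`):

```python
# Sketch: attribute-style access to a JSON config via trainer.config.Object.
from trainer.config import load_JsonConfig

config = load_JsonConfig('./config/body_pixel.json')
print(config.Model.model_name)                               # 's2g_body_pixel'
print(config.Train.learning_rate.generator_learning_rate)    # 0.0001
print(config.Data.pose.generate_length)                      # 88
```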
/data_utils/axis2matrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import scipy.linalg as linalg
4 |
5 |
6 | def rotate_mat(axis, radian):
7 |
8 | a = np.cross(np.eye(3), axis / linalg.norm(axis) * radian)
9 |
10 | rot_matrix = linalg.expm(a)
11 |
12 | return rot_matrix
13 |
14 | def aaa2mat(axis, sin, cos):
15 | i = np.eye(3)
16 | nnt = np.dot(axis.T, axis)
17 | s = np.asarray([[0, -axis[0,2], axis[0,1]],
18 | [axis[0,2], 0, -axis[0,0]],
19 | [-axis[0,1], axis[0,0], 0]])
20 | r = cos * i + (1-cos)*nnt +sin * s
21 | return r
22 |
23 | rand_axis = np.asarray([[1,0,0]])
24 | # rotation angle
25 | r = math.pi/2
26 | # return the rotation matrix
27 | rot_matrix = rotate_mat(rand_axis, r)
28 | r2 = aaa2mat(rand_axis, np.sin(r), np.cos(r))
29 | print(rot_matrix)
--------------------------------------------------------------------------------
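Both helpers above perform the axis-angle to rotation-matrix conversion: `rotate_mat` exponentiates the skew-symmetric matrix built from the normalized axis, while `aaa2mat` applies Rodrigues' formula R = cos(θ)·I + (1 − cos(θ))·n nᵀ + sin(θ)·[n]× directly. A quick consistency check (a sketch only; it assumes the repository root is on `PYTHONPATH` so the two functions can be imported, and note the module prints its own demo matrix on import):

```python
# Sketch: the matrix exponential and the closed-form Rodrigues formula agree.
import numpy as np
from data_utils.axis2matrix import rotate_mat, aaa2mat

theta = np.pi / 3
axis = np.asarray([[0.0, 1.0, 0.0]])        # unit rotation axis, shape (1, 3)
R_expm = rotate_mat(axis, theta)
R_rodrigues = aaa2mat(axis, np.sin(theta), np.cos(theta))
assert np.allclose(R_expm, R_rodrigues, atol=1e-6)
```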
/nets/init_model.py:
--------------------------------------------------------------------------------
1 | from nets import *
2 |
3 |
4 | def init_model(model_name, args, config,bert_config=None,iteration=None):
5 |
6 | if model_name == 's2g_face':
7 | generator = s2g_face(
8 | args,
9 | config,
10 | )
11 | elif model_name == 's2g_body_vq':
12 | generator = s2g_body_vq(
13 | args,
14 | config,
15 | )
16 | elif model_name == 's2g_body_pixel':
17 | generator = s2g_body_pixel(
18 | args,
19 | config,
20 | bert_config,
21 | iteration
22 | )
23 | elif model_name == 's2g_body_ae':
24 | generator = s2g_body_ae(
25 | args,
26 | config,
27 | )
28 | elif model_name == 's2g_LS3DCG':
29 | generator = LS3DCG(
30 | args,
31 | config,
32 | )
33 | else:
34 | raise ValueError
35 | return generator
36 |
37 |
38 |
--------------------------------------------------------------------------------
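`init_model` is a thin factory over the training wrappers exported by `nets/__init__.py`. A minimal usage sketch (illustrative only; it reuses the project's own option parser and JSON config loader, assumes the repository root is on `PYTHONPATH`, and constructing the wrapper may additionally require the SMPL-X/audio-encoder assets referenced by the config):

```python
# Sketch: build the face generator through the factory with default CLI options.
from trainer.options import parse_args
from trainer.config import load_JsonConfig
from nets.init_model import init_model

args = parse_args().parse_args([])                  # default command-line options
config = load_JsonConfig('./config/face.json')
generator = init_model('s2g_face', args, config)    # s2g_face takes (args, config)
```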
/scripts/.idea/testtext.py:
--------------------------------------------------------------------------------
1 | import cv2
2 |
3 | # path being defined from where the system will read the image
4 | path = r'test.png'
 5 | # command used for reading an image from the disk; the cv2.imread function is used
6 | image1 = cv2.imread(path)
7 | # Window name being specified where the image will be displayed
8 | window_name1 = 'image'
9 | # font for the text being specified
10 | font1 = cv2.FONT_HERSHEY_SIMPLEX
11 | # org for the text being specified
12 | org1 = (50, 50)
13 | # font scale for the text being specified
14 | fontScale1 = 1
15 | # Blue color for the text being specified from BGR
16 | color1 = (255, 255, 255)
17 | # Line thickness for the text being specified at 2 px
18 | thickness1 = 2
19 | # Using the cv2.putText() method for inserting text in the image of the specified path
20 | image_1 = cv2.putText(image1, 'CAT IN BOX', org1, font1, fontScale1, color1, thickness1, cv2.LINE_AA)
21 | # Displaying the output image
22 | cv2.imshow(window_name1, image_1)
23 | cv2.waitKey(0)
24 | cv2.destroyAllWindows()
25 |
--------------------------------------------------------------------------------
/trainer/training_config.cfg:
--------------------------------------------------------------------------------
1 | [Input Output]
2 | checkpoint_dir = ./training
3 | expression_basis_fname = ./training_data/init_expression_basis.npy
4 | template_fname = ./template/FLAME_sample.ply
5 | deepspeech_graph_fname = ./ds_graph/output_graph.pb
6 | face_or_body = body
7 | verts_mmaps_path = ./training_data/data_verts.npy
8 | raw_audio_path = ./training_data/raw_audio_fixed.pkl
9 | processed_audio_path = ./training_data/processed_audio_deepspeech.pkl
10 | templates_path = ./training_data/templates.pkl
11 | data2array_verts_path = ./training_data/subj_seq_to_idx.pkl
12 |
13 | [Audio Parameters]
14 | audio_feature_type = deepspeech
15 | num_audio_features = 29
16 | audio_window_size = 16
17 | audio_window_stride = 1
18 | condition_speech_features = True
19 | speech_encoder_size_factor = 1.0
20 |
21 | [Model Parameters]
22 | num_vertices = 10475
23 | expression_dim = 50
24 | init_expression = False
25 | num_consecutive_frames = 30
26 | absolute_reconstruction_loss = False
27 | velocity_weight = 10.0
28 | acceleration_weight = 0.0
29 | verts_regularizer_weight = 0.0
30 |
31 | [Data Setup]
32 | subject_for_training = speeker_oliver
33 | sequence_for_training = 0-00'00'05-00'00'10 1-00'00'32-00'00'37 2-00'01'05-00'01'10
34 | subject_for_validation = speeker_oliver
35 | sequence_for_validation = 2-00'01'05-00'01'10
36 | subject_for_testing = speeker_oliver
37 | sequence_for_testing = 2-00'01'05-00'01'10
38 |
39 | [Learning Parameters]
40 | batch_size = 64
41 | learning_rate = 1e-4
42 | decay_rate = 1.0
43 | epoch_num = 1000
44 | adam_beta1_value = 0.9
45 |
46 | [Visualization Parameters]
47 | num_render_sequences = 3
48 |
49 |
--------------------------------------------------------------------------------
/config/face.json:
--------------------------------------------------------------------------------
1 | {
2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3 | "dataset_load_mode": "json",
4 | "store_file_path": "store.pkl",
5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8 | "param": {
9 | "w_j": 1,
10 | "w_b": 1,
11 | "w_h": 1
12 | },
13 | "Data": {
14 | "data_root": "/mnt/nj-aigc/usr/pengwenshuo/TalkSHOW/ExpressiveWholeBodyDatasetReleaseV1.0",
15 | "pklname": "_3d_wv2.pkl",
16 | "whole_video": true,
17 | "pose": {
18 | "normalization": false,
19 | "convert_to_6d": false,
20 | "norm_method": "all",
21 | "augmentation": false,
22 | "generate_length": 88,
23 | "pre_pose_length": 0,
24 | "pose_dim": 99,
25 | "expression": true
26 | },
27 | "aud": {
28 | "feat_method": "mfcc",
29 | "aud_feat_dim": 64,
30 | "aud_feat_win_size": null,
31 | "context_info": false
32 | }
33 | },
34 | "Model": {
35 | "model_type": "face",
36 | "model_name": "s2g_face",
37 | "AudioOpt": "SGD",
38 | "encoder_choice": "faceformer",
39 | "gan": false
40 | },
41 | "DataLoader": {
42 | "batch_size": 1,
43 | "num_workers": 16
44 | },
45 | "Train": {
46 | "epochs": 100,
47 | "max_gradient_norm": 5,
48 | "learning_rate": {
49 | "generator_learning_rate": 1e-3,
50 | "discriminator_learning_rate": 1e-4
51 | }
52 | },
53 | "Log": {
54 | "save_every": 50,
55 | "print_every": 1000,
56 | "name": "face"
57 | }
58 | }
59 |
--------------------------------------------------------------------------------
/config/LS3DCG.json:
--------------------------------------------------------------------------------
1 | {
2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3 | "dataset_load_mode": "pickle",
4 | "store_file_path": "store.pkl",
5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8 | "param": {
9 | "w_j": 1,
10 | "w_b": 1,
11 | "w_h": 1
12 | },
13 | "Data": {
14 | "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
15 | "pklname": "_3d_mfcc.pkl",
16 | "whole_video": false,
17 | "pose": {
18 | "normalization": false,
19 | "convert_to_6d": false,
20 | "norm_method": "all",
21 | "augmentation": false,
22 | "generate_length": 88,
23 | "pre_pose_length": 0,
24 | "pose_dim": 99,
25 | "expression": true
26 | },
27 | "aud": {
28 | "feat_method": "mfcc",
29 | "aud_feat_dim": 64,
30 | "aud_feat_win_size": null,
31 | "context_info": false
32 | }
33 | },
34 | "Model": {
35 | "model_type": "body",
36 | "model_name": "s2g_LS3DCG",
37 | "code_num": 2048,
38 | "AudioOpt": "Adam",
39 | "encoder_choice": "mfcc",
40 | "gan": false
41 | },
42 | "DataLoader": {
43 | "batch_size": 128,
44 | "num_workers": 0
45 | },
46 | "Train": {
47 | "epochs": 100,
48 | "max_gradient_norm": 5,
49 | "learning_rate": {
50 | "generator_learning_rate": 1e-4,
51 | "discriminator_learning_rate": 1e-4
52 | },
53 | "weights": {
54 | "keypoint_loss_weight": 1.0,
55 | "gan_loss_weight": 1.0
56 | }
57 | },
58 | "Log": {
59 | "save_every": 50,
60 | "print_every": 200,
61 | "name": "LS3DCG"
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/config/body_vq.json:
--------------------------------------------------------------------------------
1 | {
2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3 | "dataset_load_mode": "pickle",
4 | "store_file_path": "store.pkl",
5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8 | "param": {
9 | "w_j": 1,
10 | "w_b": 1,
11 | "w_h": 1
12 | },
13 | "Data": {
14 | "data_root": "/mnt/nj-aigc/usr/pengwenshuo/TalkSHOW/ExpressiveWholeBodyDatasetReleaseV1.0",
15 | "pklname": "_3d_encodec_viclip_200m.pkl",
16 | "whole_video": false,
17 | "pose": {
18 | "normalization": false,
19 | "convert_to_6d": false,
20 | "norm_method": "all",
21 | "augmentation": false,
22 | "generate_length": 88,
23 | "pre_pose_length": 0,
24 | "pose_dim": 99,
25 | "expression": true
26 | },
27 | "aud": {
28 | "feat_method": "mfcc",
29 | "aud_feat_dim": 64,
30 | "aud_feat_win_size": null,
31 | "context_info": false
32 | }
33 | },
34 | "Model": {
35 | "model_type": "body",
36 | "model_name": "s2g_body_vq",
37 | "composition": true,
38 | "code_num": 2048,
39 | "bh_model": true,
40 | "AudioOpt": "Adam",
41 | "encoder_choice": "mfcc",
42 | "gan": false
43 | },
44 | "DataLoader": {
45 | "batch_size": 128,
46 | "num_workers": 0
47 | },
48 | "Train": {
49 | "epochs": 100,
50 | "max_gradient_norm": 5,
51 | "learning_rate": {
52 | "generator_learning_rate": 5e-5,
53 | "discriminator_learning_rate": 1e-4
54 | }
55 | },
56 | "Log": {
57 | "save_every": 50,
58 | "print_every": 200,
59 | "name": "body-vq"
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
/config/body_pixel.json:
--------------------------------------------------------------------------------
1 | {
2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3 | "dataset_load_mode": "pickle",
4 | "store_file_path": "store.pkl",
5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8 | "param": {
9 | "w_j": 1,
10 | "w_b": 1,
11 | "w_h": 1
12 | },
13 |
14 | "Data": {
15 | "data_root": "./ExpressiveWholeBodyDatasetReleaseV1.0",
16 | "pklname": "_3d_encodec_viclip_200m_token.pkl",
17 | "whole_video": false,
18 | "pose": {
19 | "normalization": false,
20 | "convert_to_6d": false,
21 | "norm_method": "all",
22 | "augmentation": false,
23 | "generate_length": 88,
24 | "pre_pose_length": 0,
25 | "pose_dim": 99,
26 | "expression": true
27 | },
28 | "aud": {
29 | "feat_method": "mfcc",
30 | "aud_feat_dim":128,
31 | "aud_feat_win_size": null,
32 | "context_info": false
33 | }
34 | },
35 |
36 | "Model": {
37 | "model_type": "body",
38 | "model_name": "s2g_body_pixel",
39 | "composition": true,
40 | "code_num": 2048,
41 | "bh_model": true,
42 | "AudioOpt": "Adam",
43 | "encoder_choice": "mfcc",
44 | "gan": false,
45 | "vq_path": "./experiments/2022-10-31-smplx_S2G-body-vq-3d/ckpt-99.pth"
46 | },
47 |
48 | "DataLoader": {
49 | "batch_size": 128,
50 | "num_workers": 0
51 | },
52 |
53 | "Train": {
54 | "epochs": 300,
55 | "max_gradient_norm": 5,
56 | "learning_rate": {
57 | "generator_learning_rate": 1e-4,
58 | "discriminator_learning_rate": 1e-4
59 | }
60 | },
61 |
62 | "Log": {
63 | "save_every": 50,
64 | "print_every": 400,
65 | "name": "body-pixel2"
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/data_utils/get_j.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def to3d(poses, config):
5 | if config.Data.pose.convert_to_6d:
6 | if config.Data.pose.expression:
7 | poses_exp = poses[:, -100:]
8 | poses = poses[:, :-100]
9 |
10 | poses = poses.reshape(poses.shape[0], -1, 5)
11 | sin, cos = poses[:, :, 3], poses[:, :, 4]
12 | pose_angle = torch.atan2(sin, cos)
13 | poses = (poses[:, :, :3] * pose_angle.unsqueeze(dim=-1)).reshape(poses.shape[0], -1)
14 |
15 | if config.Data.pose.expression:
16 | poses = torch.cat([poses, poses_exp], dim=-1)
17 | return poses
18 |
19 |
20 | def get_joint(smplx_model, betas, pred):
21 | joint = smplx_model(betas=betas.repeat(pred.shape[0], 1),
22 | expression=pred[:, 165:265],
23 | jaw_pose=pred[:, 0:3],
24 | leye_pose=pred[:, 3:6],
25 | reye_pose=pred[:, 6:9],
26 | global_orient=pred[:, 9:12],
27 | body_pose=pred[:, 12:75],
28 | left_hand_pose=pred[:, 75:120],
29 | right_hand_pose=pred[:, 120:165],
30 | return_verts=True)['joints']
31 | return joint
32 |
33 |
34 | def get_joints(smplx_model, betas, pred):
35 | if len(pred.shape) == 3:
36 | B = pred.shape[0]
37 | x = 4 if B>= 4 else B
38 | T = pred.shape[1]
39 | pred = pred.reshape(-1, 265)
40 | smplx_model.batch_size = L = T * x
41 |
42 | times = pred.shape[0] // smplx_model.batch_size
43 | joints = []
44 | for i in range(times):
45 | joints.append(get_joint(smplx_model, betas, pred[i*L:(i+1)*L]))
46 | joints = torch.cat(joints, dim=0)
47 | joints = joints.reshape(B, T, -1, 3)
48 | else:
49 | smplx_model.batch_size = pred.shape[0]
50 | joints = get_joint(smplx_model, betas, pred)
51 | return joints
--------------------------------------------------------------------------------
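`to3d` undoes the 6D encoding: each joint is stored as five values (axis plus sin/cos), the angle is recovered with `atan2`, and the axis is rescaled by that angle, so a 265-dimensional SMPL-X vector (165 pose + 100 expression) comes back out, matching the slicing in `get_joint` above. A shape-only sketch (illustrative; the 55-joint layout is inferred from that slicing, and a `SimpleNamespace` stands in for the JSON config object):

```python
# Sketch: shape check for to3d with convert_to_6d enabled.
import torch
from types import SimpleNamespace as NS
from data_utils.get_j import to3d

config = NS(Data=NS(pose=NS(convert_to_6d=True, expression=True)))
poses_6d = torch.randn(4, 55 * 5 + 100)    # 55 joints x (axis, sin, cos) + 100 expression
poses_3d = to3d(poses_6d, config)
print(poses_3d.shape)                      # torch.Size([4, 265]) = 55*3 + 100
```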
/data_utils/apply_split.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import pickle
4 | import shutil
5 |
6 | speakers = ['seth', 'oliver', 'conan', 'chemistry']
7 | source_data_root = "../expressive_body-V0.7"
8 | data_root = "D:/Downloads/SHOW_dataset_v1.0/ExpressiveWholeBodyDatasetReleaseV1.0"
9 |
10 | f_read = open('split_more_than_2s.pkl', 'rb')
11 | f_save = open('none.pkl', 'wb')
12 | data_split = pickle.load(f_read)
13 | none_split = []
14 |
15 | train = val = test = 0
16 |
17 | for speaker_name in speakers:
18 | speaker_root = os.path.join(data_root, speaker_name)
19 |
20 | videos = [v for v in data_split[speaker_name]]
21 |
22 | for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
23 | for split in data_split[speaker_name][vid]:
24 | for seq in data_split[speaker_name][vid][split]:
25 |
26 | seq = seq.replace('\\', '/')
27 | old_file_path = os.path.join(data_root, speaker_name, vid, seq.split('/')[-1])
28 | old_file_path = old_file_path.replace('\\', '/')
29 | new_file_path = seq.replace(source_data_root.split('/')[-1], data_root.split('/')[-1])
30 | try:
31 | shutil.move(old_file_path, new_file_path)
32 | if split == 'train':
33 | train = train + 1
34 | elif split == 'test':
35 | test = test + 1
36 | elif split == 'val':
37 | val = val + 1
38 | except FileNotFoundError:
39 | none_split.append(old_file_path)
40 |                     print(f"The file {old_file_path} does not exist.")
41 | except shutil.Error:
42 | none_split.append(old_file_path)
43 |                     print(f"The file {old_file_path} could not be moved.")
44 |
45 | print(none_split.__len__())
46 | pickle.dump(none_split, f_save)
47 | f_save.close()
48 |
49 | print(train, val, test)
50 |
51 |
52 |
--------------------------------------------------------------------------------
/trainer/options.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | def parse_args():
4 | parser = ArgumentParser()
5 | parser.add_argument('--gpu', default=0, type=int)
6 | parser.add_argument('--save_dir', default='experiments', type=str)
7 | parser.add_argument('--exp_name', default='smplx_S2G', type=str)
8 | parser.add_argument('--speakers', nargs='+')
9 | parser.add_argument('--seed', default=1, type=int)
10 | parser.add_argument('--model_name', type=str)
11 | #parser.add_argument('--bert_config', default='experiments', type=str)
12 | #for Tmpt and S2G
13 | parser.add_argument('--use_template', action='store_true')
14 | parser.add_argument('--template_length', default=0, type=int)
15 | #for training from a ckpt
16 | parser.add_argument('--resume', action='store_true')
17 | parser.add_argument('--pretrained_pth', default=None, type=str)
18 | parser.add_argument('--style_layer_norm', action='store_true')
19 | #required
20 | parser.add_argument('--config_file', default='./config/style_gestures.json', type=str)
21 | parser.add_argument('--bert_config', default='./config/bert_config.json', type=str)
22 | # for visualization and test
23 | parser.add_argument('--audio_file', default=None, type=str)
24 | parser.add_argument('--id', default=0, type=int, help='0=oliver, 1=chemistry, 2=seth, 3=conan')
25 | parser.add_argument('--only_face', action='store_true')
26 | parser.add_argument('--stand', action='store_true')
27 | parser.add_argument('--num_sample', default=1, type=int)
28 | parser.add_argument('--whole_body', action='store_true')
29 | parser.add_argument('--face_model_name', default='s2g_face', type=str)
30 | parser.add_argument('--face_model_path', default='./experiments/2022-10-15-smplx_S2G-face-3d/ckpt-99.pth', type=str)
31 | parser.add_argument('--body_model_name', default='s2g_body_pixel', type=str)
32 | parser.add_argument('--body_model_path', default='./viclip_lr1e-4_layer6_head8_repeat_half/2024-02-25-smplx_S2G-body-pixel2/ckpt-299.pth', type=str)
33 | parser.add_argument('--infer', action='store_true')
34 | return parser
35 |
36 |
--------------------------------------------------------------------------------
/evaluation/diversity_LVD.py:
--------------------------------------------------------------------------------
1 | '''
2 | LVD: different initial pose
3 | diversity: same initial pose
4 | '''
5 | import os
6 | import sys
7 | sys.path.append(os.getcwd())
8 |
9 | from glob import glob
10 |
11 | from argparse import ArgumentParser
12 | import json
13 |
14 | from evaluation.util import *
15 | from evaluation.metrics import *
16 | from tqdm import tqdm
17 |
18 | parser = ArgumentParser()
19 | parser.add_argument('--speaker', required=True, type=str)
20 | parser.add_argument('--post_fix', nargs='+', default=['base'], type=str)
21 | args = parser.parse_args()
22 |
23 | speaker = args.speaker
24 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
25 |
26 | LVD_list = []
27 | diversity_list = []
28 |
29 | for aud in tqdm(test_audios):
30 | base_name = os.path.splitext(aud)[0]
31 | gt_path = get_full_path(aud, speaker, 'val')
32 | _, gt_poses, _ = get_gts(gt_path)
33 | gt_poses = gt_poses[np.newaxis,...]
34 | # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
35 | for post_fix in args.post_fix:
36 | pred_path = base_name + '_'+post_fix+'.json'
37 | pred_poses = np.array(json.load(open(pred_path)))
38 | # print(pred_poses.shape)#(B, seq_len, 108)
39 | pred_poses = cvt25(pred_poses, gt_poses)
40 | # print(pred_poses.shape)#(B, seq, pose_dim)
41 |
42 | gt_valid_points = hand_points(gt_poses)
43 | pred_valid_points = hand_points(pred_poses)
44 |
45 | lvd = LVD(gt_valid_points, pred_valid_points)
46 | # div = diversity(pred_valid_points)
47 |
48 | LVD_list.append(lvd)
49 | # diversity_list.append(div)
50 |
51 | # gt_velocity = peak_velocity(gt_valid_points, order=2)
52 | # pred_velocity = peak_velocity(pred_valid_points, order=2)
53 |
54 | # gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
55 | # pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
56 |
57 | # gt_consistency_list.append(gt_consistency)
58 | # pred_consistency_list.append(pred_consistency)
59 |
60 | lvd = np.mean(LVD_list)
61 | # diversity_list = np.mean(diversity_list)
62 |
63 | print('LVD:', lvd)
64 | # print("diversity:", diversity_list)
--------------------------------------------------------------------------------
/evaluation/get_quality_samples.py:
--------------------------------------------------------------------------------
1 | '''
2 | '''
3 | import os
4 | import sys
5 | sys.path.append(os.getcwd())
6 |
7 | from glob import glob
8 |
9 | from argparse import ArgumentParser
10 | import json
11 |
12 | from evaluation.util import *
13 | from evaluation.metrics import *
14 | from tqdm import tqdm
15 |
16 | parser = ArgumentParser()
17 | parser.add_argument('--speaker', required=True, type=str)
18 | parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
19 | args = parser.parse_args()
20 |
21 | speaker = args.speaker
22 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
23 |
24 | quality_samples={'gt':[]}
25 | for post_fix in args.post_fix:
26 | quality_samples[post_fix] = []
27 |
28 | for aud in tqdm(test_audios):
29 | base_name = os.path.splitext(aud)[0]
30 | gt_path = get_full_path(aud, speaker, 'val')
31 | _, gt_poses, _ = get_gts(gt_path)
32 | gt_poses = gt_poses[np.newaxis,...]
33 | gt_valid_points = valid_points(gt_poses)
34 | # print(gt_valid_points.shape)
35 | quality_samples['gt'].append(gt_valid_points)
36 |
37 | for post_fix in args.post_fix:
38 | pred_path = base_name + '_'+post_fix+'.json'
39 | pred_poses = np.array(json.load(open(pred_path)))
40 | # print(pred_poses.shape)#(B, seq_len, 108)
41 | pred_poses = cvt25(pred_poses, gt_poses)
42 | # print(pred_poses.shape)#(B, seq, pose_dim)
43 |
44 | pred_valid_points = valid_points(pred_poses)[0:1]
45 | quality_samples[post_fix].append(pred_valid_points)
46 |
47 | quality_samples['gt'] = np.concatenate(quality_samples['gt'], axis=1)
48 | for post_fix in args.post_fix:
49 | quality_samples[post_fix] = np.concatenate(quality_samples[post_fix], axis=1)
50 |
51 | print('gt:', quality_samples['gt'].shape)
52 | quality_samples['gt'] = quality_samples['gt'].tolist()
53 | for post_fix in args.post_fix:
54 | print(post_fix, ':', quality_samples[post_fix].shape)
55 | quality_samples[post_fix] = quality_samples[post_fix].tolist()
56 |
57 | save_dir = '../../experiments/'
58 | os.makedirs(save_dir, exist_ok=True)
59 | save_name = os.path.join(save_dir, 'quality_samples_%s.json'%(speaker))
60 | with open(save_name, 'w') as f:
61 | json.dump(quality_samples, f)
62 |
63 |
--------------------------------------------------------------------------------
/evaluation/mode_transition.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.getcwd())
4 |
5 | from glob import glob
6 |
7 | from argparse import ArgumentParser
8 | import json
9 |
10 | from evaluation.util import *
11 | from evaluation.metrics import *
12 | from tqdm import tqdm
13 |
14 | parser = ArgumentParser()
15 | parser.add_argument('--speaker', required=True, type=str)
16 | parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
17 | args = parser.parse_args()
18 |
19 | speaker = args.speaker
20 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
21 |
22 | precision_list=[]
23 | recall_list=[]
24 | accuracy_list=[]
25 |
26 | for aud in tqdm(test_audios):
27 | base_name = os.path.splitext(aud)[0]
28 | gt_path = get_full_path(aud, speaker, 'val')
29 | _, gt_poses, _ = get_gts(gt_path)
30 | if gt_poses.shape[0] < 50:
31 | continue
32 | gt_poses = gt_poses[np.newaxis,...]
33 | # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
34 | for post_fix in args.post_fix:
35 | pred_path = base_name + '_'+post_fix+'.json'
36 | pred_poses = np.array(json.load(open(pred_path)))
37 | # print(pred_poses.shape)#(B, seq_len, 108)
38 | pred_poses = cvt25(pred_poses, gt_poses)
39 | # print(pred_poses.shape)#(B, seq, pose_dim)
40 |
41 | gt_valid_points = valid_points(gt_poses)
42 | pred_valid_points = valid_points(pred_poses)
43 |
44 | # print(gt_valid_points.shape, pred_valid_points.shape)
45 |
46 | gt_mode_transition_seq = mode_transition_seq(gt_valid_points, speaker)#(B, N)
47 | pred_mode_transition_seq = mode_transition_seq(pred_valid_points, speaker)#(B, N)
48 |
49 | # baseline = np.random.randint(0, 2, size=pred_mode_transition_seq.shape)
50 | # pred_mode_transition_seq = baseline
51 | precision, recall, accuracy = mode_transition_consistency(pred_mode_transition_seq, gt_mode_transition_seq)
52 | precision_list.append(precision)
53 | recall_list.append(recall)
54 | accuracy_list.append(accuracy)
55 | print(len(precision_list), len(recall_list), len(accuracy_list))
56 | precision_list = np.mean(precision_list)
57 | recall_list = np.mean(recall_list)
58 | accuracy_list = np.mean(accuracy_list)
59 |
60 | print('precision, recall, accu:', precision_list, recall_list, accuracy_list)
61 |
--------------------------------------------------------------------------------
/evaluation/peak_velocity.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.getcwd())
4 |
5 | from glob import glob
6 |
7 | from argparse import ArgumentParser
8 | import json
9 |
10 | from evaluation.util import *
11 | from evaluation.metrics import *
12 | from tqdm import tqdm
13 |
14 | parser = ArgumentParser()
15 | parser.add_argument('--speaker', required=True, type=str)
16 | parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
17 | args = parser.parse_args()
18 |
19 | speaker = args.speaker
20 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
21 |
22 | gt_consistency_list=[]
23 | pred_consistency_list=[]
24 |
25 | for aud in tqdm(test_audios):
26 | base_name = os.path.splitext(aud)[0]
27 | gt_path = get_full_path(aud, speaker, 'val')
28 | _, gt_poses, _ = get_gts(gt_path)
29 | gt_poses = gt_poses[np.newaxis,...]
30 | # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
31 | for post_fix in args.post_fix:
32 | pred_path = base_name + '_'+post_fix+'.json'
33 | pred_poses = np.array(json.load(open(pred_path)))
34 | # print(pred_poses.shape)#(B, seq_len, 108)
35 | pred_poses = cvt25(pred_poses, gt_poses)
36 | # print(pred_poses.shape)#(B, seq, pose_dim)
37 |
38 | gt_valid_points = hand_points(gt_poses)
39 | pred_valid_points = hand_points(pred_poses)
40 |
41 | gt_velocity = peak_velocity(gt_valid_points, order=2)
42 | pred_velocity = peak_velocity(pred_valid_points, order=2)
43 |
44 | gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
45 | pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
46 |
47 | gt_consistency_list.append(gt_consistency)
48 | pred_consistency_list.append(pred_consistency)
49 |
50 | gt_consistency_list = np.concatenate(gt_consistency_list)
51 | pred_consistency_list = np.concatenate(pred_consistency_list)
52 |
53 | print(gt_consistency_list.max(), gt_consistency_list.min())
54 | print(pred_consistency_list.max(), pred_consistency_list.min())
55 | print(np.mean(gt_consistency_list), np.mean(pred_consistency_list))
56 | print(np.std(gt_consistency_list), np.std(pred_consistency_list))
57 |
58 | draw_cdf(gt_consistency_list, save_name='%s_gt.jpg'%(speaker), color='slateblue')
59 | draw_cdf(pred_consistency_list, save_name='%s_pred.jpg'%(speaker), color='lightskyblue')
60 |
61 | to_excel(gt_consistency_list, '%s_gt.xlsx'%(speaker))
62 | to_excel(pred_consistency_list, '%s_pred.xlsx'%(speaker))
63 |
64 | np.save('%s_gt.npy'%(speaker), gt_consistency_list)
65 | np.save('%s_pred.npy'%(speaker), pred_consistency_list)
--------------------------------------------------------------------------------
/losses/losses.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.getcwd())
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import numpy as np
10 |
11 | class KeypointLoss(nn.Module):
12 | def __init__(self):
13 | super(KeypointLoss, self).__init__()
14 |
15 | def forward(self, pred_seq, gt_seq, gt_conf=None):
16 | #pred_seq: (B, C, T)
17 | if gt_conf is not None:
18 | gt_conf = gt_conf >= 0.01
19 | return F.mse_loss(pred_seq[gt_conf], gt_seq[gt_conf], reduction='mean')
20 | else:
21 | return F.mse_loss(pred_seq, gt_seq)
22 |
23 |
24 | class KLLoss(nn.Module):
25 | def __init__(self, kl_tolerance):
26 | super(KLLoss, self).__init__()
27 | self.kl_tolerance = kl_tolerance
28 |
29 | def forward(self, mu, var, mul=1):
30 | kl_tolerance = self.kl_tolerance * mul * var.shape[1] / 64
31 | kld_loss = -0.5 * torch.sum(1 + var - mu**2 - var.exp(), dim=1)
32 | # kld_loss = -0.5 * torch.sum(1 + (var-1) - (mu) ** 2 - (var-1).exp(), dim=1)
33 | if self.kl_tolerance is not None:
34 | # above_line = kld_loss[kld_loss > self.kl_tolerance]
35 | # if len(above_line) > 0:
36 | # kld_loss = torch.mean(kld_loss)
37 | # else:
38 | # kld_loss = 0
39 | kld_loss = torch.where(kld_loss > kl_tolerance, kld_loss, torch.tensor(kl_tolerance, device='cuda'))
40 | # else:
41 | kld_loss = torch.mean(kld_loss)
42 | return kld_loss
43 |
44 |
45 | class L2KLLoss(nn.Module):
46 | def __init__(self, kl_tolerance):
47 | super(L2KLLoss, self).__init__()
48 | self.kl_tolerance = kl_tolerance
49 |
50 | def forward(self, x):
51 | # TODO: check
52 | kld_loss = torch.sum(x ** 2, dim=1)
53 | if self.kl_tolerance is not None:
54 | above_line = kld_loss[kld_loss > self.kl_tolerance]
55 | if len(above_line) > 0:
56 | kld_loss = torch.mean(kld_loss)
57 | else:
58 | kld_loss = 0
59 | else:
60 | kld_loss = torch.mean(kld_loss)
61 | return kld_loss
62 |
63 | class L2RegLoss(nn.Module):
64 | def __init__(self):
65 | super(L2RegLoss, self).__init__()
66 |
67 | def forward(self, x):
68 | #TODO: check
69 | return torch.sum(x**2)
70 |
71 |
72 | class L2Loss(nn.Module):
73 | def __init__(self):
74 | super(L2Loss, self).__init__()
75 |
76 | def forward(self, x):
77 | # TODO: check
78 | return torch.sum(x ** 2)
79 |
80 |
81 | class AudioLoss(nn.Module):
82 | def __init__(self):
83 | super(AudioLoss, self).__init__()
84 |
85 | def forward(self, dynamics, gt_poses):
86 | #pay attention, normalized
87 | mean = torch.mean(gt_poses, dim=-1).unsqueeze(-1)
88 | gt = gt_poses - mean
89 | return F.mse_loss(dynamics, gt)
90 |
91 | L1Loss = nn.L1Loss
--------------------------------------------------------------------------------
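`KeypointLoss` is a confidence-masked MSE over pose sequences shaped `(B, C, T)`, and `KLLoss` treats `var` as a log-variance with a tolerance floor on the KL term (its `forward` creates the tolerance tensor on `'cuda'`, so it is GPU-only as written). A small smoke test for the keypoint loss (illustrative only; the shapes follow the `pose_dim: 99` and `generate_length: 88` values in the configs, and it assumes the repository root is on `PYTHONPATH`):

```python
# Sketch: masked vs. unmasked keypoint MSE on random (B, C, T) tensors.
import torch
from losses import KeypointLoss

pred = torch.randn(2, 99, 88)                  # (B, C, T)
gt = torch.randn(2, 99, 88)
conf = torch.rand(2, 99, 88)                   # per-element confidence in [0, 1]

print(KeypointLoss()(pred, gt))                # plain MSE when gt_conf is None
print(KeypointLoss()(pred, gt, gt_conf=conf))  # MSE over entries with conf >= 0.01
```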
/data_utils/test.py:
--------------------------------------------------------------------------------
1 | """def get_encodec(audio_fn,model):
2 | samples, sr = ta.load(audio_fn)
3 | seq_lenth = samples.shape[1] / sr
4 | print("sr:",sr)
5 | samples = samples.to('cuda')
6 |
7 | if samples.shape[0] > 1:
8 | samples = torch.mean(samples, dim=0, keepdim=True)
9 | print("samples audio:",sample_audio.shape)
10 | with torch.no_grad():
11 | #model = EncodecModel.encodec_model_24khz().to('cuda')
12 | model.set_target_bandwidth(6)
13 | samples = samples.unsqueeze(0)
14 | codes_raw = model.encode(samples)
15 | for frame in codes_raw:
16 | codes,_ = frame
17 | codes = codes.transpose(0,1)
18 | emb = model.quantizer.decode(codes)
19 | emb = emb.transpose(1,2)
20 | emb = linear_interpolation(emb,seq_len=seq_lenth,output_fps=30,output_len=None)
21 | emb = emb.squeeze(0).cpu().numpy()
22 | return emb"""
23 |
24 | def get_encodec(audio_fn,model):
25 | wav, sr = torchaudio.load(audio_fn)
26 | model.set_target_bandwidth(6.0)
27 | #print(sr,wav.shape)
28 | wav = convert_audio(wav, sr, model.sample_rate, model.channels)
29 | #print(wav.shape)
30 | seq_lenth = wav.shape[1]/model.sample_rate
31 | #print(seq_lenth)
32 | wav = wav.unsqueeze(0).to("cuda")
33 | # Extract discrete codes from EnCodec
34 | with torch.no_grad():
35 | encoded_frames = model.encode(wav)
36 | codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze(0)+1
37 |
38 | token_sample = get_target_token(fps=30,codes=codes,duration_time=seq_lenth,num_codes_layer=-3)
39 | # [B, n_q, T]
40 | return token_sample
41 |
42 | def interpolate_vector(input_vector, outdim):
43 | indim = input_vector.shape[1]
44 | interval = indim / outdim
45 |     # generate the sampling indices
46 | idx = (np.arange(outdim) * interval).astype(int)
47 |
48 |     # sample at equal intervals
49 | output_vector = input_vector[:, idx]
50 |
51 | return output_vector
52 |
53 | def get_target_token(fps,codes,duration_time,num_codes_layer):
54 | seq_len = fps*duration_time
55 | #print(codes.shape) ### 8x750
56 | token_codes = codes[num_codes_layer,:].unsqueeze(0) ### 1x750
57 | for t in token_codes:
58 | p = torch.unique_consecutive(t)
59 | print(p.shape)
60 | token_sample = interpolate_vector(token_codes,seq_len) ### 1x300
61 | #print(token_sample.shape)
62 | return token_sample
63 |
64 |
65 | import numpy as np
66 |
67 |
68 | if __name__ == "__main__":
69 | audio_fn = '/mnt/nj-aigc/usr/pengwenshuo/TalkSHOW/demo_audio/214428-00_00_58-00_01_08.wav'
70 | from encodec import EncodecModel
71 | from encodec.utils import convert_audio
72 | from pathlib import Path
73 | import torchaudio
74 | import torch
75 | model = EncodecModel.encodec_model_24khz(repository=Path("/mnt/nj-aigc/dataset/pengwenshuo/encodec")).to('cuda')
76 | codes = get_encodec(audio_fn,model)
77 | codes = codes+1
78 | for code in codes:
79 | c = torch.unique_consecutive(code)
80 | print(c.shape)
81 | #print(torch.max(codes),torch.min(codes))
82 | #token = get_target_token(fps=30,codes=codes,duration_time=duration_time,num_codes_layer=-1)
83 |
84 |
85 |
--------------------------------------------------------------------------------
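`interpolate_vector` resamples the EnCodec token stream onto the target pose frame rate by fixed-stride index subsampling: for a 10-second clip, 750 EnCodec frames (the 24 kHz model emits 75 frames per second) are mapped onto 300 frames at 30 fps. A shape-only sketch (illustrative; it assumes `interpolate_vector` is importable from `data_utils.test` with the repository root on `PYTHONPATH`):

```python
# Sketch: stride-based resampling of a (1, 750) token stream down to 300 frames.
import numpy as np
from data_utils.test import interpolate_vector

tokens = np.arange(750)[None, :]             # stand-in for one EnCodec code layer
resampled = interpolate_vector(tokens, 300)  # 30 fps x 10 s target length
print(resampled.shape)                       # (1, 300)
print(resampled[0, :5])                      # [0 2 5 7 10] -> every ~2.5th frame
```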
/scripts/test_vq.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
5 | sys.path.append(os.getcwd())
6 |
7 | from tqdm import tqdm
8 | from transformers import Wav2Vec2Processor
9 |
10 | from evaluation.metrics import LVD
11 |
12 | import numpy as np
13 | import smplx as smpl
14 |
15 | from data_utils.lower_body import part2full, poses2pred, c_index_3d
16 | from nets import *
17 | from nets.utils import get_path, get_dpath
18 | from trainer.options import parse_args
19 | from data_utils import torch_data
20 | from trainer.config import load_JsonConfig
21 |
22 | import torch
23 | from torch.utils import data
24 | from data_utils.get_j import to3d, get_joints
25 | from scripts.test_body import init_model, init_dataloader
26 |
27 |
28 | def test(test_loader, generator, config):
29 | print('start testing')
30 |
31 | loss_dict = {}
32 | B = 1
33 | with torch.no_grad():
34 | count = 0
35 | for bat in tqdm(test_loader, desc="Testing......"):
36 | count = count + 1
37 | aud, poses, exp = bat['aud_feat'].to('cuda').to(torch.float32), bat['poses'].to('cuda').to(torch.float32), \
38 | bat['expression'].to('cuda').to(torch.float32)
39 | id = bat['speaker'].to('cuda') - 20
40 | betas = bat['betas'][0].to('cuda').to(torch.float64)
41 | poses = torch.cat([poses, exp], dim=-2).transpose(-1, -2).squeeze()
42 | poses = to3d(poses, config).unsqueeze(dim=0).transpose(1, 2)
43 | # poses = poses[:, c_index_3d, :]
44 |
45 | cur_wav_file = bat['aud_file'][0]
46 |
47 | pred = generator.infer_on_audio(cur_wav_file,
48 | initial_pose=poses,
49 | id=id,
50 | fps=30,
51 | B=B
52 | )
53 | pred = torch.tensor(pred, device='cuda')
54 | bat_loss_dict = {'capacity': (poses[:, c_index_3d, :pred.shape[0]].transpose(1,2) - pred).abs().sum(-1).mean()}
55 |
56 |             if loss_dict:  # non-empty
57 | for key in list(bat_loss_dict.keys()):
58 | loss_dict[key] += bat_loss_dict[key]
59 | else:
60 | for key in list(bat_loss_dict.keys()):
61 | loss_dict[key] = bat_loss_dict[key]
62 | for key in loss_dict.keys():
63 | loss_dict[key] = loss_dict[key] / count
64 | print(key + '=' + str(loss_dict[key].item()))
65 |
66 |
67 | def main():
68 | parser = parse_args()
69 | args = parser.parse_args()
70 | device = torch.device(args.gpu)
71 | torch.cuda.set_device(device)
72 |
73 | config = load_JsonConfig(args.config_file)
74 |
75 | os.environ['smplx_npz_path'] = config.smplx_npz_path
76 | os.environ['extra_joint_path'] = config.extra_joint_path
77 | os.environ['j14_regressor_path'] = config.j14_regressor_path
78 |
79 | print('init dataloader...')
80 | test_set, test_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)
81 | print('init model...')
82 | model_name = 's2g_body_vq'
83 | model_type = 'n_com_8192'
84 | model_path = get_path(model_name, model_type)
85 | generator = init_model(model_name, model_path, args, config)
86 |
87 | test(test_loader, generator, config)
88 |
89 |
90 | if __name__ == '__main__':
91 | main()
92 |
--------------------------------------------------------------------------------
/nets/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 |
5 | class TrainWrapperBaseClass():
6 | def __init__(self, args, config) -> None:
7 | #self.init_optimizer()
8 | pass
9 | def init_optimizer(self) -> None:
10 | print('using Adam')
11 | self.generator_optimizer = optim.Adam(
12 | self.generator.parameters(),
13 | lr = self.config.Train.learning_rate.generator_learning_rate,
14 | betas=[0.9, 0.999]
15 | )
16 | if self.discriminator is not None:
17 | self.discriminator_optimizer = optim.Adam(
18 | self.discriminator.parameters(),
19 | lr = self.config.Train.learning_rate.discriminator_learning_rate,
20 | betas=[0.9, 0.999]
21 | )
22 |
23 | def __call__(self, bat):
24 | raise NotImplementedError
25 |
26 | def get_loss(self, **kwargs):
27 | raise NotImplementedError
28 |
29 | def state_dict(self):
30 | model_state = {
31 | 'generator': self.generator.state_dict(),
32 | 'generator_optim': self.generator_optimizer.state_dict(),
33 | 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
34 | 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
35 | }
36 | return model_state
37 |
38 | def parameters(self):
39 | return self.generator.parameters()
40 |
41 | def load_state_dict(self, state_dict):
42 | if 'generator' in state_dict:
43 | self.generator.load_state_dict(state_dict['generator'])
44 | else:
45 | self.generator.load_state_dict(state_dict)
46 |
47 |
48 | """if 'generator_optim' in state_dict and self.generator_optimizer is not None:
49 | self.generator_optimizer.load_state_dict(state_dict['generator_optim'])"""
50 |
51 | if self.discriminator is not None:
52 | self.discriminator.load_state_dict(state_dict['discriminator'])
53 |
54 | if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None:
55 | self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim'])
56 |
57 | def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, **kwargs):
58 | raise NotImplementedError
59 |
60 | def init_params(self):
61 | if self.config.Data.pose.convert_to_6d:
62 | scale = 2
63 | else:
64 | scale = 1
65 |
66 | global_orient = round(0 * scale)
67 | leye_pose = reye_pose = round(0 * scale)
68 | jaw_pose = round(0 * scale)
69 | body_pose = round((63 - 24) * scale)
70 | left_hand_pose = right_hand_pose = round(45 * scale)
71 | if self.expression:
72 | expression = 100
73 | else:
74 | expression = 0
75 |
76 | b_j = 0
77 | jaw_dim = jaw_pose
78 | b_e = b_j + jaw_dim
79 | eye_dim = leye_pose + reye_pose
80 | b_b = b_e + eye_dim
81 | body_dim = global_orient + body_pose
82 | b_h = b_b + body_dim
83 | hand_dim = left_hand_pose + right_hand_pose
84 | b_f = b_h + hand_dim
85 | face_dim = expression
86 |
87 | self.dim_list = [b_j, b_e, b_b, b_h, b_f]
88 | self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
89 | self.pose = int(self.full_dim / round(3 * scale))
90 | self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
--------------------------------------------------------------------------------
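`init_params` only does dimension bookkeeping for the SMPL-X parameter vector used by the body models; for the plain 3D case the arithmetic works out as in the short check below (a sketch mirroring the code above, assuming `convert_to_6d` is false so `scale == 1`):

```python
# Sketch: dimensions computed by TrainWrapperBaseClass.init_params with scale = 1.
scale = 1
jaw_dim = round(0 * scale)                              # 0 in this wrapper
eye_dim = round(0 * scale) + round(0 * scale)           # 0
body_dim = round(0 * scale) + round((63 - 24) * scale)  # 39: body pose with lower-body joints removed
hand_dim = round(45 * scale) * 2                        # 90: both hands
full_dim = jaw_dim + eye_dim + body_dim + hand_dim
print(full_dim, int(full_dim / round(3 * scale)))       # 129 -> 43 joints at 3 DoF each
```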
/README.md:
--------------------------------------------------------------------------------
1 | # T3M: Text Guided 3D Human Motion Synthesis from Speech[NAACL 2024]
2 | This repository contains the PyTorch implementation of the "T3M: Text Guided 3D Human Motion Synthesis from Speech" project. The goal of this project is to synthesize realistic 3D human motion based on both speech and text inputs.
3 | 
4 |
5 |
6 | ## Environment Setup
7 |
8 | To get started with this project, you will need to set up a Python environment using `miniconda3`. Follow the steps below to create the required environment:
9 | ### Prerequisites
10 |
11 | - Python 3.10
12 | - [Miniconda3](https://docs.anaconda.com/miniconda/)
13 |
14 | ### Creating the Environment
15 |
16 | 1. Install Miniconda3 if you haven't done so already.
17 | 2. Create a new conda environment named `t3m` with Python 3.10 and install the dependencies:
18 | ```bash
19 | conda create -n t3m python=3.10
20 | conda activate t3m
21 | pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
22 | pip install -r requirements.txt
23 | ```
24 |
25 | ## Usage
26 | To use this project, follow these steps:
27 |
28 | 1. Clone the repository:
29 | ```bash
30 | git clone https://github.com/Gloria2tt/T3M.git
31 | cd T3M
32 | 2. Download the dataset and pre-trained weight:
33 |
34 | We provide an enhanced approach compared to the original papers by utilizing a more advanced video-text alignment model, [InternVid](https://github.com/OpenGVLab/InternVideo/tree/main/Data/InternVid), to extract video embeddings from the SHOW dataset.
35 |
36 | - Download the SHOW dataset:
37 | Download the Talkshow dataset from [this link](https://download.is.tue.mpg.de/download.php?domain=talkshow&resume=1&sfile=SHOW_dataset_v1.0.zip) and unzip the folder.
38 |
39 | - In addition to audio and pose data, the original video is also required for training. Download the original video following instructions from the [SHOW repository](https://github.com/yhw-yhw/SHOW?tab=readme-ov-file).
40 |
41 | - Extract the audio-aligned segments from the video based on the file names, and use the video encoder from InternVid to extract the video embeddings. We recommend performing this step on A100 or H100 GPUs.
42 |
43 | - Following the instructions in the TalkSHOW repository, download the pre-trained face model and VQ-VAE model from [this](https://drive.google.com/file/d/1bC0ZTza8HOhLB46WOJ05sBywFvcotDZG/view), as our paper modifies only the body and hand generation parts.
44 | - I've noticed that the SHOW repository no longer contains the original videos. Therefore, we've established a new repository to facilitate the download of the preprocessed dataset, which you can find [here](https://huggingface.co/Wenshuo1/t3m_dataset/tree/main)
45 |
46 | ### Train
47 | To train the model, you need to modify the body_pixel.json configuration file to match your environment:
48 |
49 | - If this is your first time running the code, set the `dataset_load_mode` option from `pickle` to `json`.
50 |
51 | - Adjust the `vq_path` option to match the location of your pre-trained VQ-VAE checkpoint folder.
52 |
53 |
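For reference, the two options look roughly like the fragment below. This is illustrative only: the exact nesting inside `config/body_pixel.json` may differ (they may sit under the `Data` / `Model` sections), and the checkpoint path is a placeholder for your own folder.

```json
{
  "dataset_load_mode": "json",
  "vq_path": "/path/to/your/pretrained_body_vq/ckpt-99.pth"
}
```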
54 | Finally, use the following command to start training:
55 | ```bash
56 | sh train_body_pixel.sh
57 | ```
58 |
59 | ### Visualize
60 |
61 | 1. To visualize the results after training, ensure `ffmpeg` is installed: `sudo apt-get install ffmpeg`
62 |
63 | 2. Run the visualization script: `bash visualise.sh`
64 |
65 | 3. Alternatively, you can visualize a specific audio file:
66 |
67 |    ```bash
68 |    python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file your/voice/file
69 |    ```
70 |
71 | Make sure you have set the model path correctly before running inference.
72 |
73 | ## Citation
74 | If you find our work interesting, please consider citing:
75 | ```bibtex
76 | @inproceedings{peng2024t3m,
77 |   title={T3M: Text Guided 3D Human Motion Synthesis from Speech},
78 |   author={Peng, Wenshuo and Zhang, Kaipeng and Zhang, Sai Qian},
79 |   booktitle={Findings of the Association for Computational Linguistics: NAACL 2024},
80 |   pages={1168--1177},
81 |   year={2024}
82 | }
83 | ```
84 | ## Acknowledgement
85 | Our code is built upon [TalkSHOW](https://github.com/yhw-yhw/TalkSHOW) and [SHOW](https://github.com/yhw-yhw/SHOW). We especially thank [Hongwei Yi](https://xyyhw.top/) for sharing their codebase.
86 |
87 | ## Contact
88 | For any questions, feel free to email me directly at gin2pws@gmail.com.
89 |
90 |
--------------------------------------------------------------------------------
/evaluation/metrics.py:
--------------------------------------------------------------------------------
1 | '''
2 | Warning: metrics are for reference only, may have limited significance
3 | '''
4 | import os
5 | import sys
6 | sys.path.append(os.getcwd())
7 | import numpy as np
8 | import torch
9 |
10 | from data_utils.lower_body import rearrange, symmetry
11 | import torch.nn.functional as F
12 |
13 | def data_driven_baselines(gt_kps):
14 | '''
15 | gt_kps: T, D
16 | '''
17 | gt_velocity = np.abs(gt_kps[1:] - gt_kps[:-1])
18 |
19 | mean= np.mean(gt_velocity, axis=0)[np.newaxis] #(1, D)
20 | mean = np.mean(np.abs(gt_velocity-mean))
21 | last_step = gt_kps[1] - gt_kps[0]
22 | last_step = last_step[np.newaxis] #(1, D)
23 | last_step = np.mean(np.abs(gt_velocity-last_step))
24 | return last_step, mean
25 |
26 | def Batch_LVD(gt_kps, pr_kps, symmetrical, weight):
27 | if gt_kps.shape[0] > pr_kps.shape[1]:
28 | length = pr_kps.shape[1]
29 | else:
30 | length = gt_kps.shape[0]
31 | gt_kps = gt_kps[:length]
32 | pr_kps = pr_kps[:, :length]
33 | global symmetry
34 | symmetry = torch.tensor(symmetry).bool()
35 |
36 | if symmetrical:
37 | # rearrange for compute symmetric. ns means non-symmetrical joints, ys means symmetrical joints.
38 | gt_kps = gt_kps[:, rearrange]
39 | ns_gt_kps = gt_kps[:, ~symmetry]
40 | ys_gt_kps = gt_kps[:, symmetry]
41 | ys_gt_kps = ys_gt_kps.reshape(ys_gt_kps.shape[0], -1, 2, 3)
42 | ns_gt_velocity = (ns_gt_kps[1:] - ns_gt_kps[:-1]).norm(p=2, dim=-1)
43 | ys_gt_velocity = (ys_gt_kps[1:] - ys_gt_kps[:-1]).norm(p=2, dim=-1)
44 | left_gt_vel = ys_gt_velocity[:, :, 0].sum(dim=-1)
45 | right_gt_vel = ys_gt_velocity[:, :, 1].sum(dim=-1)
46 | move_side = torch.where(left_gt_vel>right_gt_vel, torch.ones(left_gt_vel.shape).cuda(), torch.zeros(left_gt_vel.shape).cuda())
47 | ys_gt_velocity = torch.mul(ys_gt_velocity[:, :, 0].transpose(0,1), move_side) + torch.mul(ys_gt_velocity[:, :, 1].transpose(0,1), ~move_side.bool())
48 | ys_gt_velocity = ys_gt_velocity.transpose(0,1)
49 | gt_velocity = torch.cat([ns_gt_velocity, ys_gt_velocity], dim=1)
50 |
51 | pr_kps = pr_kps[:, :, rearrange]
52 | ns_pr_kps = pr_kps[:, :, ~symmetry]
53 | ys_pr_kps = pr_kps[:, :, symmetry]
54 | ys_pr_kps = ys_pr_kps.reshape(ys_pr_kps.shape[0], ys_pr_kps.shape[1], -1, 2, 3)
55 | ns_pr_velocity = (ns_pr_kps[:, 1:] - ns_pr_kps[:, :-1]).norm(p=2, dim=-1)
56 | ys_pr_velocity = (ys_pr_kps[:, 1:] - ys_pr_kps[:, :-1]).norm(p=2, dim=-1)
57 | left_pr_vel = ys_pr_velocity[:, :, :, 0].sum(dim=-1)
58 | right_pr_vel = ys_pr_velocity[:, :, :, 1].sum(dim=-1)
59 | move_side = torch.where(left_pr_vel > right_pr_vel, torch.ones(left_pr_vel.shape).cuda(),
60 | torch.zeros(left_pr_vel.shape).cuda())
61 | ys_pr_velocity = torch.mul(ys_pr_velocity[..., 0].permute(2, 0, 1), move_side) + torch.mul(
62 | ys_pr_velocity[..., 1].permute(2, 0, 1), ~move_side.bool())  # logical negation; bitwise ~ on a long tensor would give -1/-2
63 | ys_pr_velocity = ys_pr_velocity.permute(1, 2, 0)
64 | pr_velocity = torch.cat([ns_pr_velocity, ys_pr_velocity], dim=2)
65 | else:
66 | gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
67 | pr_velocity = (pr_kps[:, 1:] - pr_kps[:, :-1]).norm(p=2, dim=-1)
68 |
69 | if weight:
70 | w = F.softmax(gt_velocity.sum(dim=1).normal_(), dim=0)
71 | else:
72 | w = 1 / gt_velocity.shape[0]
73 |
74 | v_diff = ((pr_velocity - gt_velocity).abs().sum(dim=-1) * w).sum(dim=-1).mean()
75 |
76 | return v_diff
77 |
78 |
79 | def LVD(gt_kps, pr_kps, symmetrical=False, weight=False):
80 | gt_kps = gt_kps.squeeze()
81 | pr_kps = pr_kps.squeeze()
82 | if len(pr_kps.shape) == 4:
83 | return Batch_LVD(gt_kps, pr_kps, symmetrical, weight)
84 | # length = np.minimum(gt_kps.shape[0], pr_kps.shape[0])
85 | length = gt_kps.shape[0]-10
86 | # gt_kps = gt_kps[25:length]
87 | # pr_kps = pr_kps[25:length] #(T, D)
88 | # if pr_kps.shape[0] < gt_kps.shape[0]:
89 | # pr_kps = np.pad(pr_kps, [[0, int(gt_kps.shape[0]-pr_kps.shape[0])], [0, 0]], mode='constant')
90 |
91 | gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
92 | pr_velocity = (pr_kps[1:] - pr_kps[:-1]).norm(p=2, dim=-1)
93 |
94 | return (pr_velocity-gt_velocity).abs().sum(dim=-1).mean()
95 |
96 | def diversity(kps):
97 | '''
98 | kps: bs, seq, dim
99 | '''
100 | dis_list = []
101 | #the distance between each pair
102 | for i in range(kps.shape[0]):
103 | for j in range(i+1, kps.shape[0]):
104 | seq_i = kps[i]
105 | seq_j = kps[j]
106 |
107 | dis = np.mean(np.abs(seq_i - seq_j))
108 | dis_list.append(dis)
109 | return np.mean(dis_list)
110 |
--------------------------------------------------------------------------------
/scripts/.idea/get_prevar.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
4 |
5 | sys.path.append(os.getcwd())
6 | from glob import glob
7 |
8 | import numpy as np
9 | import json
10 | import smplx as smpl
11 |
12 | from nets import *
13 | from repro_nets import *
14 | from trainer.options import parse_args
15 | from data_utils import torch_data
16 | from trainer.config import load_JsonConfig
17 |
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | from torch.utils import data
22 |
23 | def init_model(model_name, model_path, args, config):
24 | if model_name == 'freeMo':
25 | # generator = freeMo_Generator(args)
26 | # generator = freeMo_Generator(args)
27 | generator = freeMo_dev(args, config)
28 | # generator.load_state_dict(torch.load(model_path)['generator'])
29 | elif model_name == 'smplx_S2G':
30 | generator = smplx_S2G(args, config)
31 | elif model_name == 'StyleGestures':
32 | generator = StyleGesture_Generator(
33 | args,
34 | config
35 | )
36 | elif model_name == 'Audio2Gestures':
37 | config.Train.using_mspec_stat = False
38 | generator = Audio2Gesture_Generator(
39 | args,
40 | config,
41 | torch.zeros([1, 1, 108]),
42 | torch.ones([1, 1, 108])
43 | )
44 | elif model_name == 'S2G':
45 | generator = S2G_Generator(
46 | args,
47 | config,
48 | )
49 | elif model_name == 'Tmpt':
50 | generator = S2G_Generator(
51 | args,
52 | config,
53 | )
54 | else:
55 | raise NotImplementedError
56 |
57 | model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
58 | if model_name == 'smplx_S2G':
59 | generator.generator.load_state_dict(model_ckpt['generator']['generator'])
60 | elif 'generator' in list(model_ckpt.keys()):
61 | generator.load_state_dict(model_ckpt['generator'])
62 | else:
63 | model_ckpt = {'generator': model_ckpt}
64 | generator.load_state_dict(model_ckpt)
65 |
66 | return generator
67 |
68 |
69 |
70 | def prevar_loader(data_root, speakers, args, config, model_path, device, generator):
71 | path = model_path.split('ckpt')[0]
72 | file = os.path.join(os.path.dirname(path), "pre_variable.npy")
73 | data_base = torch_data(
74 | data_root=data_root,
75 | speakers=speakers,
76 | split='pre',
77 | limbscaling=False,
78 | normalization=config.Data.pose.normalization,
79 | norm_method=config.Data.pose.norm_method,
80 | split_trans_zero=False,
81 | num_pre_frames=config.Data.pose.pre_pose_length,
82 | num_generate_length=config.Data.pose.generate_length,
83 | num_frames=15,
84 | aud_feat_win_size=config.Data.aud.aud_feat_win_size,
85 | aud_feat_dim=config.Data.aud.aud_feat_dim,
86 | feat_method=config.Data.aud.feat_method,
87 | smplx=True,
88 | audio_sr=22000,
89 | convert_to_6d=config.Data.pose.convert_to_6d,
90 | expression=config.Data.pose.expression
91 | )
92 |
93 | data_base.get_dataset()
94 | pre_set = data_base.all_dataset
95 | pre_loader = data.DataLoader(pre_set, batch_size=config.DataLoader.batch_size, shuffle=False, drop_last=True)
96 |
97 | total_pose = []
98 |
99 | with torch.no_grad():
100 | for bat in pre_loader:
101 | pose = bat['poses'].to(device).to(torch.float32)
102 | expression = bat['expression'].to(device).to(torch.float32)
103 | pose = pose.permute(0, 2, 1)
104 | pose = torch.cat([pose[:, :15], pose[:, 15:30], pose[:, 30:45], pose[:, 45:60], pose[:, 60:]], dim=0)
105 | expression = expression.permute(0, 2, 1)
106 | expression = torch.cat([expression[:, :15], expression[:, 15:30], expression[:, 30:45], expression[:, 45:60], expression[:, 60:]], dim=0)
107 | pose = torch.cat([pose, expression], dim=-1)
108 | pose = pose.reshape(pose.shape[0], -1, 1)
109 | pose_code = generator.generator.pre_pose_encoder(pose).squeeze().detach().cpu()
110 | total_pose.append(np.asarray(pose_code))
111 | total_pose = np.concatenate(total_pose, axis=0)
112 | mean = np.mean(total_pose, axis=0)
113 | std = np.std(total_pose, axis=0)
114 | prevar = (mean, std)
115 | np.save(file, prevar, allow_pickle=True)
116 |
117 | return mean, std
118 |
119 | def main():
120 | parser = parse_args()
121 | args = parser.parse_args()
122 | device = torch.device(args.gpu)
123 | torch.cuda.set_device(device)
124 |
125 | config = load_JsonConfig(args.config_file)
126 |
127 | print('init model...')
128 | generator = init_model(config.Model.model_name, args.model_path, args, config)
129 | print('init pre-pose vectors...')
130 | mean, std = prevar_loader(config.Data.data_root, args.speakers, args, config, args.model_path, device, generator)
131 |
132 | if __name__ == '__main__':
133 |     main()
--------------------------------------------------------------------------------
/evaluation/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import numpy as np
4 | import json
5 | from matplotlib import pyplot as plt
6 | import pandas as pd
7 | def get_gts(clip):
8 | '''
9 | clip: abs path to the clip dir
10 | '''
11 | keypoints_files = sorted(glob(os.path.join(clip, 'keypoints_new/person_1')+'/*.json'))
12 |
13 | upper_body_points = list(np.arange(0, 25))
14 | poses = []
15 | confs = []
16 | neck_to_nose_len = []
17 | mean_position = []
18 | for kp_file in keypoints_files:
19 | kp_load = json.load(open(kp_file, 'r'))['people'][0]
20 | posepts = kp_load['pose_keypoints_2d']
21 | lhandpts = kp_load['hand_left_keypoints_2d']
22 | rhandpts = kp_load['hand_right_keypoints_2d']
23 | facepts = kp_load['face_keypoints_2d']
24 |
25 | neck = np.array(posepts).reshape(-1,3)[1]
26 | nose = np.array(posepts).reshape(-1,3)[0]
27 | x_offset = abs(neck[0]-nose[0])
28 | y_offset = abs(neck[1]-nose[1])
29 | neck_to_nose_len.append(y_offset)
30 | mean_position.append([neck[0],neck[1]])
31 |
32 | keypoints=np.array(posepts+lhandpts+rhandpts+facepts).reshape(-1,3)[:,:2]
33 |
34 | upper_body = keypoints[upper_body_points, :]
35 | hand_points = keypoints[25:, :]
36 | keypoints = np.vstack([upper_body, hand_points])
37 |
38 | poses.append(keypoints)
39 |
40 | if len(neck_to_nose_len) > 0:
41 | scale_factor = np.mean(neck_to_nose_len)
42 | else:
43 | raise ValueError(clip)
44 | mean_position = np.mean(np.array(mean_position), axis=0)
45 |
46 | unlocalized_poses = np.array(poses).copy()
47 | localized_poses = []
48 | for i in range(len(poses)):
49 | keypoints = poses[i]
50 | neck = keypoints[1].copy()
51 |
52 | keypoints[:, 0] = (keypoints[:, 0] - neck[0]) / scale_factor
53 | keypoints[:, 1] = (keypoints[:, 1] - neck[1]) / scale_factor
54 | localized_poses.append(keypoints.reshape(-1))
55 |
56 | localized_poses=np.array(localized_poses)
57 | return unlocalized_poses, localized_poses, (scale_factor, mean_position)
58 |
59 | def get_full_path(wav_name, speaker, split):
60 | '''
61 | get clip path from aud file
62 | '''
63 | wav_name = os.path.basename(wav_name)
64 | wav_name = os.path.splitext(wav_name)[0]
65 | clip_name, vid_name = wav_name[:10], wav_name[11:]
66 |
67 | full_path = os.path.join('pose_dataset/videos/', speaker, 'clips', vid_name, 'images/half', split, clip_name)
68 |
69 | assert os.path.isdir(full_path), full_path
70 |
71 | return full_path
72 |
73 | def smooth(res):
74 | '''
75 | res: (B, seq_len, pose_dim)
76 | '''
77 | window = [res[:, 7, :], res[:, 8, :], res[:, 9, :], res[:, 10, :], res[:, 11, :], res[:, 12, :]]
78 | w_size=7
79 | for i in range(10, res.shape[1]-3):
80 | window.append(res[:, i+3, :])
81 | if len(window) > w_size:
82 | window = window[1:]
83 |
84 | if (i%25) in [22, 23, 24, 0, 1, 2, 3]:
85 | res[:, i, :] = np.mean(window, axis=1)
86 |
87 | return res
88 |
89 | def cvt25(pred_poses, gt_poses=None):
90 | '''
91 | gt_poses: (1, seq_len, 270), 135 *2
92 | pred_poses: (B, seq_len, 108), 54 * 2
93 | '''
94 | if gt_poses is None:
95 | gt_poses = np.zeros_like(pred_poses)
96 | else:
97 | gt_poses = gt_poses.repeat(pred_poses.shape[0], axis=0)
98 |
99 | length = min(pred_poses.shape[1], gt_poses.shape[1])
100 | pred_poses = pred_poses[:, :length, :]
101 | gt_poses = gt_poses[:, :length, :]
102 | gt_poses = gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1, 2)
103 | pred_poses = pred_poses.reshape(pred_poses.shape[0], pred_poses.shape[1], -1, 2)
104 |
105 | gt_poses[:, :, [1, 2, 3, 4, 5, 6, 7], :] = pred_poses[:, :, 1:8, :]
106 | gt_poses[:, :, 25:25+21+21, :] = pred_poses[:, :, 12:, :]
107 |
108 | return gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1)
109 |
110 | def hand_points(seq):
111 | '''
112 | seq: (B, seq_len, 135*2)
113 | hands only
114 | '''
115 | hand_idx = [1, 2, 3, 4,5 ,6,7] + list(range(25, 25+21+21))
116 | seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
117 | return seq[:, :, hand_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
118 |
119 | def valid_points(seq):
120 | '''
121 | hands with some head points
122 | '''
123 | valid_idx = [0, 1, 2, 3, 4,5 ,6,7, 8, 9, 10, 11] + list(range(25, 25+21+21))
124 | seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
125 |
126 | seq = seq[:, :, valid_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
127 | assert seq.shape[-1] == 108, seq.shape
128 | return seq
129 |
130 | def draw_cdf(seq, save_name='cdf.jpg', color='slateblue'):
131 | plt.figure()
132 | plt.hist(seq, bins=100, range=(0, 100), color=color)
133 | plt.savefig(save_name)
134 |
135 | def to_excel(seq, save_name='res.xlsx'):
136 | '''
137 | seq: (T)
138 | '''
139 | df = pd.DataFrame(seq)
140 | writer = pd.ExcelWriter(save_name)
141 | df.to_excel(writer, 'sheet1')
142 | writer.save()
143 | writer.close()
144 |
145 |
146 | if __name__ == '__main__':
147 | random_data = np.random.randint(0, 10, 100)
148 | draw_cdf(random_data)
--------------------------------------------------------------------------------
/nets/body_ae.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.getcwd())
5 |
6 | from nets.base import TrainWrapperBaseClass
7 | from nets.spg.s2glayers import Discriminator as D_S2G
8 | from nets.spg.vqvae_1d import AE as s2g_body
9 | import torch
10 | import torch.optim as optim
11 | import torch.nn.functional as F
12 |
13 | from data_utils.lower_body import c_index, c_index_3d, c_index_6d
14 |
15 |
16 | def separate_aa(aa):
17 | aa = aa[:, :, :].reshape(aa.shape[0], aa.shape[1], -1, 5)
18 | axis = F.normalize(aa[:, :, :, :3], dim=-1)
19 | angle = F.normalize(aa[:, :, :, 3:5], dim=-1)
20 | return axis, angle
21 |
22 |
23 | class TrainWrapper(TrainWrapperBaseClass):
24 | '''
25 | a wrapper receiving a batch from data_utils and calculating the loss
26 | '''
27 |
28 | def __init__(self, args, config):
29 | self.args = args
30 | self.config = config
31 | self.device = torch.device(self.args.gpu)
32 | self.global_step = 0
33 |
34 | self.gan = False
35 | self.convert_to_6d = self.config.Data.pose.convert_to_6d
36 | self.preleng = self.config.Data.pose.pre_pose_length
37 | self.expression = self.config.Data.pose.expression
38 | self.epoch = 0
39 | self.init_params()
40 | self.num_classes = 4
41 | self.g = s2g_body(self.each_dim[1] + self.each_dim[2], embedding_dim=64, num_embeddings=0,
42 | num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device)
43 | if self.gan:
44 | self.discriminator = D_S2G(
45 | pose_dim=110 + 64, pose=self.pose
46 | ).to(self.device)
47 | else:
48 | self.discriminator = None
49 |
50 | if self.convert_to_6d:
51 | self.c_index = c_index_6d
52 | else:
53 | self.c_index = c_index_3d
54 |
55 | super().__init__(args, config)
56 |
57 | def init_optimizer(self):
58 |
59 | self.g_optimizer = optim.Adam(
60 | self.g.parameters(),
61 | lr=self.config.Train.learning_rate.generator_learning_rate,
62 | betas=[0.9, 0.999]
63 | )
64 |
65 | def state_dict(self):
66 | model_state = {
67 | 'g': self.g.state_dict(),
68 | 'g_optim': self.g_optimizer.state_dict(),
69 | 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
70 | 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
71 | }
72 | return model_state
73 |
74 |
75 | def __call__(self, bat):
76 | # assert (not self.args.infer), "infer mode"
77 | self.global_step += 1
78 |
79 | total_loss = None
80 | loss_dict = {}
81 |
82 | aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
83 |
84 | # id = bat['speaker'].to(self.device) - 20
85 | # id = F.one_hot(id, self.num_classes)
86 |
87 | poses = poses[:, self.c_index, :]
88 | gt_poses = poses[:, :, self.preleng:].permute(0, 2, 1)
89 |
90 | loss = 0
91 | loss_dict, loss = self.vq_train(gt_poses[:, :], 'g', self.g, loss_dict, loss)
92 |
93 | return total_loss, loss_dict
94 |
95 | def vq_train(self, gt, name, model, dict, total_loss, pre=None):
96 | x_recon = model(gt_poses=gt, pre_state=pre)
97 | loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, pre=pre)
98 | # total_loss = total_loss + loss
99 |
100 | if name == 'g':
101 | optimizer_name = 'g_optimizer'
102 |
103 | optimizer = getattr(self, optimizer_name)
104 | optimizer.zero_grad()
105 | loss.backward()
106 | optimizer.step()
107 |
108 | for key in list(loss_dict.keys()):
109 | dict[name + key] = loss_dict.get(key, 0).item()
110 | return dict, total_loss
111 |
112 | def get_loss(self,
113 | pred_poses,
114 | gt_poses,
115 | pre=None
116 | ):
117 | loss_dict = {}
118 |
119 |
120 | rec_loss = torch.mean(torch.abs(pred_poses - gt_poses))
121 | v_pr = pred_poses[:, 1:] - pred_poses[:, :-1]
122 | v_gt = gt_poses[:, 1:] - gt_poses[:, :-1]
123 | velocity_loss = torch.mean(torch.abs(v_pr - v_gt))
124 |
125 | if pre is None:
126 | f0_vel = 0
127 | else:
128 | v0_pr = pred_poses[:, 0] - pre[:, -1]
129 | v0_gt = gt_poses[:, 0] - pre[:, -1]
130 | f0_vel = torch.mean(torch.abs(v0_pr - v0_gt))
131 |
132 | gen_loss = rec_loss + velocity_loss + f0_vel
133 |
134 | loss_dict['rec_loss'] = rec_loss
135 | loss_dict['velocity_loss'] = velocity_loss
136 | # loss_dict['e_q_loss'] = e_q_loss
137 | if pre is not None:
138 | loss_dict['f0_vel'] = f0_vel
139 |
140 | return gen_loss, loss_dict
141 |
142 | def load_state_dict(self, state_dict):
143 | self.g.load_state_dict(state_dict['g'])
144 |
145 | def extract(self, x):
146 | self.g.eval()
147 | if x.shape[2] > self.full_dim:
148 | if x.shape[2] == 239:
149 | x = x[:, :, 102:]
150 | x = x[:, :, self.c_index]
151 | feat = self.g.encode(x)
152 | return feat.transpose(1, 2), x
153 |
--------------------------------------------------------------------------------
/data_utils/dataset_preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from tqdm import tqdm
4 | import shutil
5 | import torch
6 | import numpy as np
7 | import librosa
8 | import random
9 |
10 | speakers = ['seth', 'conan', 'oliver', 'chemistry']
11 | data_root = "../ExpressiveWholeBodyDatasetv1.0/"
12 | split = 'train'
13 |
14 |
15 |
16 | def split_list(full_list,shuffle=False,ratio=0.2):
17 | n_total = len(full_list)
18 | offset_0 = int(n_total * ratio)
19 | offset_1 = int(n_total * ratio * 2)
20 | if n_total==0 or offset_1<1:
21 | return [], [], full_list  # keep the (test, val, train) 3-tuple shape expected by the caller
22 | if shuffle:
23 | random.shuffle(full_list)
24 | sublist_0 = full_list[:offset_0]
25 | sublist_1 = full_list[offset_0:offset_1]
26 | sublist_2 = full_list[offset_1:]
27 | return sublist_0, sublist_1, sublist_2
28 |
29 |
30 | def moveto(list, file):
31 | for f in list:
32 | before, after = '/'.join(f.split('/')[:-1]), f.split('/')[-1]
33 | new_path = os.path.join(before, file)
34 | new_path = os.path.join(new_path, after)
35 | # os.makedirs(new_path)
36 | # os.path.isdir(new_path)
37 | # shutil.move(f, new_path)
38 |
39 | # copy the sequence folder to the new split directory
40 | shutil.copytree(f, new_path)
41 | # then remove the original folder from the train split
42 | shutil.rmtree(f)
43 | return None
44 |
45 |
46 | def read_pkl(data):
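    """Quality check: return 1 if the sequence should be discarded (fewer than 90 frames
    or NaNs in the concatenated SMPL-X parameters), otherwise return 0."""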
47 | betas = np.array(data['betas'])
48 |
49 | jaw_pose = np.array(data['jaw_pose'])
50 | leye_pose = np.array(data['leye_pose'])
51 | reye_pose = np.array(data['reye_pose'])
52 | global_orient = np.array(data['global_orient']).squeeze()
53 | body_pose = np.array(data['body_pose_axis'])
54 | left_hand_pose = np.array(data['left_hand_pose'])
55 | right_hand_pose = np.array(data['right_hand_pose'])
56 |
57 | full_body = np.concatenate(
58 | (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1)
59 |
60 | expression = np.array(data['expression'])
61 | full_body = np.concatenate((full_body, expression), axis=1)
62 |
63 | if (full_body.shape[0] < 90) or (torch.isnan(torch.from_numpy(full_body)).sum() > 0):
64 | return 1
65 | else:
66 | return 0
67 |
68 |
69 | for speaker_name in speakers:
70 | speaker_root = os.path.join(data_root, speaker_name)
71 |
72 | videos = [v for v in os.listdir(speaker_root)]
73 | print(videos)
74 |
75 | haode = huaide = 0
76 | total_seqs = []
77 |
78 | for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
79 | # for vid in videos:
80 | source_vid = vid
81 | vid_pth = os.path.join(speaker_root, source_vid)
82 | # vid_pth = os.path.join(speaker_root, source_vid, 'images/half', split)
83 | t = os.path.join(speaker_root, source_vid, 'test')
84 | v = os.path.join(speaker_root, source_vid, 'val')
85 |
86 | # if os.path.exists(t):
87 | # shutil.rmtree(t)
88 | # if os.path.exists(v):
89 | # shutil.rmtree(v)
90 | try:
91 | seqs = [s for s in os.listdir(vid_pth)]
92 | except:
93 | continue
94 | # if len(seqs) == 0:
95 | # shutil.rmtree(os.path.join(speaker_root, source_vid))
96 | # None
97 | for s in seqs:
98 | quality = 0
99 | total_seqs.append(os.path.join(vid_pth,s))
100 | seq_root = os.path.join(vid_pth, s)
101 | key = seq_root # correspond to clip******
102 | audio_fname = os.path.join(speaker_root, source_vid, s, '%s.wav' % (s))
103 |
104 | # delete the data without audio or the audio file could not be read
105 | if os.path.isfile(audio_fname):
106 | try:
107 | audio = librosa.load(audio_fname)
108 | except:
109 | # print(key)
110 | shutil.rmtree(key)
111 | huaide = huaide + 1
112 | continue
113 | else:
114 | huaide = huaide + 1
115 | # print(key)
116 | shutil.rmtree(key)
117 | continue
118 |
119 | # check motion file
120 | motion_fname = os.path.join(speaker_root, source_vid, s, '%s.pkl' % (s))
121 | try:
122 | f = open(motion_fname, 'rb+')
123 | except:
124 | shutil.rmtree(key)
125 | huaide = huaide + 1
126 | continue
127 |
128 | data = pickle.load(f)
129 | w = read_pkl(data)
130 | f.close()
131 | quality = quality + w
132 |
133 | if w == 1:
134 | shutil.rmtree(key)
135 | # print(key)
136 | huaide = huaide + 1
137 | continue
138 |
139 | haode = haode + 1
140 |
141 | print("huaide:{}, haode:{}, total_seqs:{}".format(huaide, haode, total_seqs.__len__()))
142 |
143 | for speaker_name in speakers:
144 | speaker_root = os.path.join(data_root, speaker_name)
145 |
146 | videos = [v for v in os.listdir(speaker_root)]
147 | print(videos)
148 |
149 | haode = huaide = 0
150 | total_seqs = []
151 |
152 | for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
153 | # for vid in videos:
154 | source_vid = vid
155 | vid_pth = os.path.join(speaker_root, source_vid)
156 | try:
157 | seqs = [s for s in os.listdir(vid_pth)]
158 | except:
159 | continue
160 | for s in seqs:
161 | quality = 0
162 | total_seqs.append(os.path.join(vid_pth, s))
163 | print("total_seqs:{}".format(total_seqs.__len__()))
164 | # split the dataset
165 | test_list, val_list, train_list = split_list(total_seqs, True, 0.1)
166 | print(len(test_list), len(val_list), len(train_list))
167 | moveto(train_list, 'train')
168 | moveto(test_list, 'test')
169 | moveto(val_list, 'val')
170 |
171 |
--------------------------------------------------------------------------------
/nets/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import textgrid as tg
3 | import numpy as np
4 |
5 | def get_parameter_size(model):
6 | total_num = sum(p.numel() for p in model.parameters())
7 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
8 | return total_num, trainable_num
9 |
10 | def denormalize(kps, data_mean, data_std):
11 | '''
12 | kps: (B, T, C)
13 | '''
14 | data_std = data_std.reshape(1, 1, -1)
15 | data_mean = data_mean.reshape(1, 1, -1)
16 | return (kps * data_std) + data_mean
17 |
18 | def normalize(kps, data_mean, data_std):
19 | '''
20 | kps: (B, T, C)
21 | '''
22 | data_std = data_std.squeeze().reshape(1, 1, -1)
23 | data_mean = data_mean.squeeze().reshape(1, 1, -1)
24 |
25 | return (kps-data_mean) / data_std
26 |
27 | def parse_audio(textgrid_file):
28 | '''a demo implementation'''
29 | words=['but', 'as', 'to', 'that', 'with', 'of', 'the', 'and', 'or', 'not', 'which', 'what', 'this', 'for', 'because', 'if', 'so', 'just', 'about', 'like', 'by', 'how', 'from', 'whats', 'now', 'very', 'that', 'also', 'actually', 'who', 'then', 'well', 'where', 'even', 'today', 'between', 'than', 'when']
30 | txt=tg.TextGrid.fromFile(textgrid_file)
31 |
32 | total_time=int(np.ceil(txt.maxTime))
33 | code_seq=np.zeros(total_time)
34 |
35 | word_level=txt[0]
36 |
37 | for i in range(len(word_level)):
38 | start_time=word_level[i].minTime
39 | end_time=word_level[i].maxTime
40 | mark=word_level[i].mark
41 |
42 | if mark in words:
43 | start=int(np.round(start_time))
44 | end=int(np.round(end_time))
45 |
46 | if start >= len(code_seq) or end >= len(code_seq):
47 | code_seq[-1] = 1
48 | else:
49 | code_seq[start]=1
50 |
51 | return code_seq
52 |
53 |
54 | def get_path(model_name, model_type):
55 | if model_name == 's2g_body_pixel':
56 | if model_type == 'mfcc':
57 | return './experiments/2022-10-09-smplx_S2G-body-pixel-aud-3p/ckpt-99.pth'
58 | elif model_type == 'wv2':
59 | return './experiments/2022-10-28-smplx_S2G-body-pixel-wv2-sg2/ckpt-99.pth'
60 | elif model_type == 'random':
61 | return './experiments/2022-10-09-smplx_S2G-body-pixel-random-3p/ckpt-99.pth'
62 | elif model_type == 'wbhmodel':
63 | return './experiments/2022-11-02-smplx_S2G-body-pixel-w-bhmodel/ckpt-99.pth'
64 | elif model_type == 'wobhmodel':
65 | return './experiments/2022-11-02-smplx_S2G-body-pixel-wo-bhmodel/ckpt-99.pth'
66 | elif model_name == 's2g_body':
67 | if model_type == 'a+m-vae':
68 | return './experiments/2022-10-19-smplx_S2G-body-audio-motion-vae/ckpt-99.pth'
69 | elif model_type == 'a-vae':
70 | return './experiments/2022-10-18-smplx_S2G-body-audiovae/ckpt-99.pth'
71 | elif model_type == 'a-ed':
72 | return './experiments/2022-10-18-smplx_S2G-body-audioae/ckpt-99.pth'
73 | elif model_name == 's2g_LS3DCG':
74 | return './experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth'
75 | elif model_name == 's2g_body_vq':
76 | if model_type == 'n_com_1024':
77 | return './experiments/2022-10-29-smplx_S2G-body-vq-cn1024/ckpt-99.pth'
78 | elif model_type == 'n_com_2048':
79 | return './experiments/2022-10-29-smplx_S2G-body-vq-cn2048/ckpt-99.pth'
80 | elif model_type == 'n_com_4096':
81 | return './experiments/2022-10-29-smplx_S2G-body-vq-cn4096/ckpt-99.pth'
82 | elif model_type == 'n_com_8192':
83 | return './experiments/2022-11-02-smplx_S2G-body-vq-cn8192/ckpt-99.pth'
84 | elif model_type == 'n_com_16384':
85 | return './experiments/2022-11-02-smplx_S2G-body-vq-cn16384/ckpt-99.pth'
86 | elif model_type == 'n_com_170000':
87 | return './experiments/2022-10-30-smplx_S2G-body-vq-cn170000/ckpt-99.pth'
88 | elif model_type == 'com_1024':
89 | return './experiments/2022-10-29-smplx_S2G-body-vq-composition/ckpt-99.pth'
90 | elif model_type == 'com_2048':
91 | return './experiments/2022-10-31-smplx_S2G-body-vq-composition2048/ckpt-99.pth'
92 | elif model_type == 'com_4096':
93 | return './experiments/2022-10-31-smplx_S2G-body-vq-composition4096/ckpt-99.pth'
94 | elif model_type == 'com_8192':
95 | return './experiments/2022-11-02-smplx_S2G-body-vq-composition8192/ckpt-99.pth'
96 | elif model_type == 'com_16384':
97 | return './experiments/2022-11-02-smplx_S2G-body-vq-composition16384/ckpt-99.pth'
98 |
99 |
100 | def get_dpath(model_name, model_type):
101 | if model_name == 's2g_body_pixel':
102 | if model_type == 'audio':
103 | return './experiments/2022-10-26-smplx_S2G-d-pixel-aud/ckpt-9.pth'
104 | elif model_type == 'wv2':
105 | return './experiments/2022-11-04-smplx_S2G-d-pixel-wv2/ckpt-9.pth'
106 | elif model_type == 'random':
107 | return './experiments/2022-10-26-smplx_S2G-d-pixel-random/ckpt-9.pth'
108 | elif model_type == 'wbhmodel':
109 | return './experiments/2022-11-10-smplx_S2G-hD-wbhmodel/ckpt-9.pth'
110 | # return './experiments/2022-11-05-smplx_S2G-d-pixel-wbhmodel/ckpt-9.pth'
111 | elif model_type == 'wobhmodel':
112 | return './experiments/2022-11-10-smplx_S2G-hD-wobhmodel/ckpt-9.pth'
113 | # return './experiments/2022-11-05-smplx_S2G-d-pixel-wobhmodel/ckpt-9.pth'
114 | elif model_name == 's2g_body':
115 | if model_type == 'a+m-vae':
116 | return './experiments/2022-10-26-smplx_S2G-d-audio+motion-vae/ckpt-9.pth'
117 | elif model_type == 'a-vae':
118 | return './experiments/2022-10-26-smplx_S2G-d-audio-vae/ckpt-9.pth'
119 | elif model_type == 'a-ed':
120 | return './experiments/2022-10-26-smplx_S2G-d-audio-ae/ckpt-9.pth'
121 | elif model_name == 's2g_LS3DCG':
122 | return './experiments/2022-10-26-smplx_S2G-d-ls3dcg/ckpt-9.pth'
--------------------------------------------------------------------------------
/nets/spg/wav2vec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import copy
6 | import math
7 | from transformers import Wav2Vec2Model,Wav2Vec2Config
8 | from transformers.modeling_outputs import BaseModelOutput
9 | from typing import Optional, Tuple
10 | _CONFIG_FOR_DOC = "Wav2Vec2Config"
11 |
12 | # the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
13 | # initialize our encoder with the pre-trained wav2vec 2.0 weights.
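# Illustrative usage (an assumption, not taken from this repository): the subclass below
# is typically instantiated from a public pre-trained checkpoint, e.g.
#     audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
# so the convolutional feature extractor and transformer encoder start from the released
# wav2vec 2.0 weights before fine-tuning.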
14 | def _compute_mask_indices(
15 | shape: Tuple[int, int],
16 | mask_prob: float,
17 | mask_length: int,
18 | attention_mask: Optional[torch.Tensor] = None,
19 | min_masks: int = 0,
20 | ) -> np.ndarray:
21 | bsz, all_sz = shape
22 | mask = np.full((bsz, all_sz), False)
23 |
24 | all_num_mask = int(
25 | mask_prob * all_sz / float(mask_length)
26 | + np.random.rand()
27 | )
28 | all_num_mask = max(min_masks, all_num_mask)
29 | mask_idcs = []
30 | padding_mask = attention_mask.ne(1) if attention_mask is not None else None
31 | for i in range(bsz):
32 | if padding_mask is not None:
33 | sz = all_sz - padding_mask[i].long().sum().item()
34 | num_mask = int(
35 | mask_prob * sz / float(mask_length)
36 | + np.random.rand()
37 | )
38 | num_mask = max(min_masks, num_mask)
39 | else:
40 | sz = all_sz
41 | num_mask = all_num_mask
42 |
43 | lengths = np.full(num_mask, mask_length)
44 |
45 | if sum(lengths) == 0:
46 | lengths[0] = min(mask_length, sz - 1)
47 |
48 | min_len = min(lengths)
49 | if sz - min_len <= num_mask:
50 | min_len = sz - num_mask - 1
51 |
52 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
53 | mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
54 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
55 |
56 | min_len = min([len(m) for m in mask_idcs])
57 | for i, mask_idc in enumerate(mask_idcs):
58 | if len(mask_idc) > min_len:
59 | mask_idc = np.random.choice(mask_idc, min_len, replace=False)
60 | mask[i, mask_idc] = True
61 | return mask
62 |
63 | # linear interpolation layer
64 | def linear_interpolation(features, input_fps, output_fps, output_len=None):
65 | features = features.transpose(1, 2)
66 | seq_len = features.shape[2] / float(input_fps)
67 | if output_len is None:
68 | output_len = int(seq_len * output_fps)
69 | output_features = F.interpolate(features,size=output_len,align_corners=False,mode='linear')
70 | return output_features.transpose(1, 2)
71 |
72 |
73 | class Wav2Vec2Model(Wav2Vec2Model):
74 | def __init__(self, config):
75 | super().__init__(config)
76 | def forward(
77 | self,
78 | input_values,
79 | attention_mask=None,
80 | output_attentions=None,
81 | output_hidden_states=None,
82 | return_dict=None,
83 | frame_num=None
84 | ):
85 | self.config.output_attentions = True
86 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
87 | output_hidden_states = (
88 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
89 | )
90 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
91 |
92 | hidden_states = self.feature_extractor(input_values)
93 | hidden_states = hidden_states.transpose(1, 2)
94 |
95 | hidden_states = linear_interpolation(hidden_states, 50, 30,output_len=frame_num)
96 |
97 | if attention_mask is not None:
98 | output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
99 | attention_mask = torch.zeros(
100 | hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
101 | )
102 | attention_mask[
103 | (torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
104 | ] = 1
105 | attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
106 |
107 | hidden_states = self.feature_projection(hidden_states)
108 |
109 | if self.config.apply_spec_augment and self.training:
110 | batch_size, sequence_length, hidden_size = hidden_states.size()
111 | if self.config.mask_time_prob > 0:
112 | mask_time_indices = _compute_mask_indices(
113 | (batch_size, sequence_length),
114 | self.config.mask_time_prob,
115 | self.config.mask_time_length,
116 | attention_mask=attention_mask,
117 | min_masks=2,
118 | )
119 | hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
120 | if self.config.mask_feature_prob > 0:
121 | mask_feature_indices = _compute_mask_indices(
122 | (batch_size, hidden_size),
123 | self.config.mask_feature_prob,
124 | self.config.mask_feature_length,
125 | )
126 | mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
127 | hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
128 | encoder_outputs = self.encoder(
129 | hidden_states[0],
130 | attention_mask=attention_mask,
131 | output_attentions=output_attentions,
132 | output_hidden_states=output_hidden_states,
133 | return_dict=return_dict,
134 | )
135 | hidden_states = encoder_outputs[0]
136 | if not return_dict:
137 | return (hidden_states,) + encoder_outputs[1:]
138 |
139 | return BaseModelOutput(
140 | last_hidden_state=hidden_states,
141 | hidden_states=encoder_outputs.hidden_states,
142 | attentions=encoder_outputs.attentions,
143 | )
144 |
--------------------------------------------------------------------------------
/nets/spg/qformer.py:
--------------------------------------------------------------------------------
1 | from .med import BertConfig, BertModel, BertLMHeadModel
2 | from transformers import BertTokenizer
3 | import transformers
4 | transformers.logging.set_verbosity_error()
5 | from .vqvae_modules import VectorQuantizerEMA, ConvNormRelu, Res_CNR_Stack
6 |
7 | import torch
8 | from torch import nn
9 | import torch.nn.functional as F
10 |
11 | from .blip import create_vit, init_tokenizer, load_checkpoint
12 | import math
13 | import torch.nn.init as init
14 |
15 | class PositionalEncoding(nn.Module):
16 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1024):
17 | super().__init__()
18 | pe = torch.zeros(max_len, d_model)
19 | position = torch.arange(0, max_len).unsqueeze(1)
20 | div_term = torch.exp(
21 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
22 | )
23 | pe[:, 0::2] = torch.sin(position * div_term)
24 | pe[:, 1::2] = torch.cos(position * div_term)
25 |
26 | self.register_buffer("pe", pe)
27 | self.dropout = nn.Dropout(p=dropout)
28 |
29 | def forward(self, x: torch.Tensor):
30 | """
31 | :param x: B x T x d_model tensor
32 | :return: B x T x d_model tensor
33 | """
34 | x = x + self.pe[None, : x.shape[1], :]
35 | x = self.dropout(x)
36 | return x
37 |
38 | class AudioEncoder(nn.Module):
39 | def __init__(self, in_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
40 | super(AudioEncoder, self).__init__()
41 | self._num_hiddens = num_hiddens
42 | self._num_residual_layers = num_residual_layers
43 | self._num_residual_hiddens = num_residual_hiddens
44 |
45 | self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True)
46 |
47 | self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
48 | self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True,
49 | sample='down')
50 | self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
51 | self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True)
52 | self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
53 |
54 | def forward(self, x, frame_num=0):
55 | h = self.project(x)
56 | h = self._enc_1(h)
57 | h = self._down_1(h)
58 | h = self._enc_2(h)
59 | h = self._down_2(h)
60 | h = self._enc_3(h)
61 | return h
62 |
63 |
64 |
65 |
66 |
67 |
68 | class qformer(nn.Module):
69 | def __init__(self,
70 | batchsize = 128 ,
71 | q_lenth = 44,
72 | width = 768,
73 | embed_dim = 512,
74 | codebook_size = 2048,
75 | num_layers = 6,
76 | random = True,
77 | num_q = 1
78 | ):
79 | super().__init__()
80 |
81 |
82 | self.decoder_layer = nn.TransformerDecoderLayer(d_model=512,nhead=8,batch_first=True,activation="relu")
83 | self.position = PositionalEncoding(embed_dim, dropout=0)
84 | self.text_encoder = nn.TransformerDecoder(decoder_layer=self.decoder_layer,num_layers=num_layers)
85 | self.audioencoder = AudioEncoder(in_dim=128,num_hiddens=512,num_residual_layers=2,num_residual_hiddens=0)
86 | self.width = width
87 | self.proj = nn.Linear(512,codebook_size)
88 | self.apply(self.weights_init)
89 | self.num_q = num_q
90 |
91 | def weights_init(self, m):
92 | if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
93 | init.xavier_uniform_(m.weight.data)
94 | if m.bias is not None:
95 | init.constant_(m.bias.data, 0.01)
96 |
97 | def get_tgt_mask(self, size: int, device: str) -> torch.tensor:
98 | mask = torch.tril(
99 | torch.ones((size, size), device=device) == 1
100 | ) # Lower triangular matrix
101 | mask = mask.float()
102 | mask = mask.masked_fill(mask == 0, float("-inf")) # Convert zeros to -inf
103 | mask = mask.masked_fill(mask == 1, float(0.0)) # Convert ones to 0
104 | return mask
105 |
106 | def forward(self,audio_feat,video_input,ablation=False):
107 | if not ablation:
108 | b,seq_len = audio_feat.shape[0],audio_feat.shape[1]
109 | audio_feat = self.audioencoder(audio_feat[:,:].transpose(1,2),frame_num = 0).transpose(1,2)### bx seqx 512
110 | inputs_embeds = audio_feat
111 | if len(video_input.shape) == 2:
112 | video_input = video_input.unsqueeze(0)
113 | audio_embeds = self.position(inputs_embeds)
114 | #print(audio_embeds.shape)
115 | tgt_mask = self.get_tgt_mask(audio_embeds.shape[1], audio_embeds.device)
116 | video_input = video_input.repeat(1,seq_len//2,1)
117 | video_embeds = self.position(video_input)
118 | out_put = self.text_encoder(tgt=audio_embeds,memory = video_embeds,tgt_mask=tgt_mask)
119 | projhead = self.proj(out_put)
120 | try:
121 | proj_head = projhead.view(b, -1, 2, 2048)
122 | except RuntimeError:  # odd sequence length: drop the last step before reshaping
123 | proj_head = projhead[:, :-1, :].view(b, -1, 2, 2048)
124 | return proj_head
125 |
126 | def generate(self,shape=(8,8),batch_size=64,aud_feat=None,text_feat=None): ### 22x2
127 |
128 | param = next(self.parameters())
129 |
130 | shape = (shape[0] // 4, shape[1])  # tuples are immutable, so rebuild instead of item assignment
131 |
132 | x = torch.zeros(
133 | (batch_size,*shape),
134 | dtype=torch.int64,device=param.device
135 | )
136 |
137 | h0 = 0
138 | print(x.shape)
139 | h = shape[0]
140 | print(shape)
141 | for i in range(h0,h):
142 | for j in range(shape[1]):
143 | logits = self.forward(aud_feat,text_feat,ablation=False).permute(0, 3, 1, 2).contiguous()
144 | probs = F.softmax(logits[:, :, i, j], -1)
145 | x.data[:, i, j].copy_(
146 | probs.multinomial(1).squeeze().data
147 | )
148 | return x[:, h0:h]
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
--------------------------------------------------------------------------------
/nets/spg/gated_pixelcnn_v2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def weights_init(m):
7 | classname = m.__class__.__name__
8 | if classname.find('Conv') != -1:
9 | try:
10 | nn.init.xavier_uniform_(m.weight.data)
11 | m.bias.data.fill_(0)
12 | except AttributeError:
13 | print("Skipping initialization of ", classname)
14 |
15 |
16 | class GatedActivation(nn.Module):
17 | def __init__(self):
18 | super().__init__()
19 |
20 | def forward(self, x):
21 | x, y = x.chunk(2, dim=1)
22 | return torch.tanh(x) * torch.sigmoid(y)  # F.tanh / F.sigmoid are deprecated aliases
23 |
24 |
25 | class GatedMaskedConv2d(nn.Module):
26 | def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10, bh_model=False):
27 | super().__init__()
28 | assert kernel % 2 == 1, "Kernel size must be odd"  # passing print(...) as the message would evaluate to None
29 | self.mask_type = mask_type
30 | self.residual = residual
31 | self.bh_model = bh_model
32 |
33 | self.class_cond_embedding = nn.Embedding(
34 | n_classes, 2 * dim
35 | )
36 |
37 | kernel_shp = (kernel // 2 + 1, 3 if self.bh_model else 1) # (ceil(n/2), n)
38 | padding_shp = (kernel // 2, 1 if self.bh_model else 0)
39 | self.vert_stack = nn.Conv2d(
40 | dim, dim * 2,
41 | kernel_shp, 1, padding_shp
42 | )
43 |
44 | self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1)
45 |
46 | kernel_shp = (1, 2)
47 | padding_shp = (0, 1)
48 | self.horiz_stack = nn.Conv2d(
49 | dim, dim * 2,
50 | kernel_shp, 1, padding_shp
51 | )
52 |
53 | self.horiz_resid = nn.Conv2d(dim, dim, 1)
54 |
55 | self.gate = GatedActivation()
56 |
57 | def make_causal(self):
58 | self.vert_stack.weight.data[:, :, -1].zero_() # Mask final row
59 | self.horiz_stack.weight.data[:, :, :, -1].zero_() # Mask final column
60 |
61 | def forward(self, x_v, x_h, h):
62 | if self.mask_type == 'A':
63 | self.make_causal()
64 |
65 | h = self.class_cond_embedding(h)
66 | h_vert = self.vert_stack(x_v)
67 | h_vert = h_vert[:, :, :x_v.size(-2), :]
68 | out_v = self.gate(h_vert + h[:, :, None, None])
69 |
70 | if self.bh_model:
71 | h_horiz = self.horiz_stack(x_h)
72 | h_horiz = h_horiz[:, :, :, :x_h.size(-1)]
73 | v2h = self.vert_to_horiz(h_vert)
74 |
75 | out = self.gate(v2h + h_horiz + h[:, :, None, None])
76 | if self.residual:
77 | out_h = self.horiz_resid(out) + x_h
78 | else:
79 | out_h = self.horiz_resid(out)
80 | else:
81 | if self.residual:
82 | out_v = self.horiz_resid(out_v) + x_v
83 | else:
84 | out_v = self.horiz_resid(out_v)
85 | out_h = out_v
86 |
87 | return out_v, out_h
88 |
89 |
90 | class GatedPixelCNN(nn.Module):
91 | def __init__(self, input_dim=256, dim=64, n_layers=15, n_classes=10, audio=False, bh_model=False):
92 | super().__init__()
93 | self.dim = dim
94 | self.audio = audio
95 | self.bh_model = bh_model
96 |
97 | if self.audio:
98 | self.embedding_aud = nn.Conv2d(256, dim, 1, 1, padding=0)
99 | self.fusion_v = nn.Conv2d(dim * 2, dim, 1, 1, padding=0)
100 | self.fusion_h = nn.Conv2d(dim * 2, dim, 1, 1, padding=0)
101 |
102 | # Create embedding layer to embed input
103 | self.embedding = nn.Embedding(input_dim, dim)
104 |
105 | # Building the PixelCNN layer by layer
106 | self.layers = nn.ModuleList()
107 |
108 | # Initial block with Mask-A convolution
109 | # Rest with Mask-B convolutions
110 | for i in range(n_layers):
111 | mask_type = 'A' if i == 0 else 'B'
112 | kernel = 7 if i == 0 else 3
113 | residual = False if i == 0 else True
114 |
115 | self.layers.append(
116 | GatedMaskedConv2d(mask_type, dim, kernel, residual, n_classes, bh_model)
117 | )
118 |
119 | # Add the output layer
120 | self.output_conv = nn.Sequential(
121 | nn.Conv2d(dim, 512, 1),
122 | nn.ReLU(True),
123 | nn.Conv2d(512, input_dim, 1)
124 | )
125 |
126 | self.apply(weights_init)
127 |
128 | self.dp = nn.Dropout(0.1)
129 |
130 | def forward(self, x, label, aud=None):
131 | shp = x.size() + (-1,)
132 | x = self.embedding(x.view(-1)).view(shp) # (B, H, W, C)
133 | x = x.permute(0, 3, 1, 2) # (B, C, W, W)
134 |
135 | x_v, x_h = (x, x)
136 | for i, layer in enumerate(self.layers):
137 | if i == 1 and self.audio is True:
138 | aud = self.embedding_aud(aud)
139 | a = torch.ones(aud.shape[-2]).to(aud.device)
140 | a = self.dp(a)
141 | aud = (aud.transpose(-1, -2) * a).transpose(-1, -2)
142 | x_v = self.fusion_v(torch.cat([x_v, aud], dim=1))
143 | if self.bh_model:
144 | x_h = self.fusion_h(torch.cat([x_h, aud], dim=1))
145 | x_v, x_h = layer(x_v, x_h, label)
146 |
147 | if self.bh_model:
148 | return self.output_conv(x_h)
149 | else:
150 | return self.output_conv(x_v)
151 |
152 | def generate(self, label, shape=(8, 8), batch_size=64, aud_feat=None, pre_latents=None, pre_audio=None):
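        # Autoregressive sampling over the latent grid: each forward pass re-scores the
        # partially filled grid, and the categorical distribution at position (i, j) is
        # sampled to fill that cell. When pre_latents is given, the first h0 rows are
        # kept fixed and only the remaining rows are generated.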
153 | param = next(self.parameters())
154 | x = torch.zeros(
155 | (batch_size, *shape),
156 | dtype=torch.int64, device=param.device
157 | )
158 | if pre_latents is not None:
159 | x = torch.cat([pre_latents, x], dim=1)
160 | aud_feat = torch.cat([pre_audio, aud_feat], dim=2)
161 | h0 = pre_latents.shape[1]
162 | h = h0 + shape[0]
163 | else:
164 | h0 = 0
165 | h = shape[0]
166 |
167 | for i in range(h0, h):
168 | for j in range(shape[1]):
169 | if self.audio:
170 | logits = self.forward(x, label, aud_feat)
171 | else:
172 | logits = self.forward(x, label)
173 | probs = F.softmax(logits[:, :, i, j], -1)
174 | x.data[:, i, j].copy_(
175 | probs.multinomial(1).squeeze().data
176 | )
177 | return x[:, h0:h]
178 |
--------------------------------------------------------------------------------
/data_utils/lower_body.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | lower_pose = torch.tensor(
5 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0747, -0.0158, -0.0152, -1.1826512813568115, 0.23866955935955048,
6 | 0.15146760642528534, -1.2604516744613647, -0.3160211145877838,
7 | -0.1603458970785141, 1.1654603481292725, 0.0, 0.0, 1.2521806955337524, 0.041598282754421234, -0.06312154978513718,
8 | 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
9 | lower_pose_stand = torch.tensor([
10 | 8.9759e-04, 7.1074e-04, -5.9163e-06, 8.9759e-04, 7.1074e-04, -5.9163e-06,
11 | 3.0747, -0.0158, -0.0152,
12 | -3.6665e-01, -8.8455e-03, 1.6113e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
13 | -3.9716e-01, -4.0229e-02, -1.2637e-01,
14 | 7.9163e-01, 6.8519e-02, -1.5091e-01, 7.9163e-01, 6.8519e-02, -1.5091e-01,
15 | 7.8632e-01, -4.3810e-02, 1.4375e-02,
16 | -1.0675e-01, 1.2635e-01, 1.6711e-02, -1.0675e-01, 1.2635e-01, 1.6711e-02, ])
17 | # lower_pose_stand = torch.tensor(
18 | # [6.4919e-02, 3.3018e-02, 1.7485e-02, 8.9759e-04, 7.1074e-04, -5.9163e-06,
19 | # 3.0747, -0.0158, -0.0152,
20 | # -3.3633e+00, -9.3915e-02, 3.0996e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
21 | # 1.1654603481292725, 0.0, 0.0,
22 | # 4.4167e-01, 6.7183e-03, -3.6379e-03, 7.9163e-01, 6.8519e-02, -1.5091e-01,
23 | # 0.0, 0.0, 0.0,
24 | # 2.2910e-02, -2.4797e-02, -5.5657e-03, -1.0675e-01, 1.2635e-01, 1.6711e-02,])
25 | lower_body = [0, 1, 3, 4, 6, 7, 9, 10]
26 | count_part = [6, 9, 12, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
27 | 29, 30, 31, 32, 33, 34, 35, 36, 37,
28 | 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54]
29 | fix_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
30 | 29,
31 | 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
32 | 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
33 | 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]
34 | all_index = np.ones(275)
35 | all_index[fix_index] = 0
36 | c_index = []
37 | i = 0
38 | for num in all_index:
39 | if num == 1:
40 | c_index.append(i)
41 | i = i + 1
42 | c_index = np.asarray(c_index)
43 | ### 18 19 20 27 28 29 36 37 38
44 | fix_index_3d = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
45 | 21, 22, 23, 24, 25, 26,
46 | 30, 31, 32, 33, 34, 35, 45,46,47,48,49,50]
47 | #39,40,41,42,43,44]
48 | all_index_3d = np.ones(165)
49 | all_index_3d[fix_index_3d] = 0
50 | c_index_3d = []
51 | i = 0
52 | for num in all_index_3d:
53 | if num == 1:
54 | c_index_3d.append(i)
55 | i = i + 1
56 | c_index_3d = np.asarray(c_index_3d)
57 | #print(c_index_3d,c_index_3d.shape)
58 | c_index_6d = []
59 | i = 0
60 | for num in all_index_3d:
61 | if num == 1:
62 | c_index_6d.append(2*i)
63 | c_index_6d.append(2 * i + 1)
64 | i = i + 1
65 | c_index_6d = np.asarray(c_index_6d)
66 |
67 |
68 | def part2full(input, stand=False): ### 300x232
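    # Re-assemble a full pose vector: the network predicts only jaw, upper-body and hand
    # parameters, so the fixed (or standing) lower-body rotations from `lower_pose` /
    # `lower_pose_stand` are spliced back in at their original index positions.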
69 | if stand:
70 | # lp = lower_pose_stand.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
71 | lp = torch.zeros_like(lower_pose)
72 | lp[6:9] = torch.tensor([3.0747, -0.0158, -0.0152])
73 | lp = lp.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
74 | else:
75 | lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
76 |
77 | input = torch.cat([input[:, :3], ### 03 ### jaw
78 | lp[:, :15],#### 3 18 ### 3-12 uk 12-18 fix pose 12-75。63
79 | input[:, 3:6], #### 18 21 ### 18-21 pred pose
80 | lp[:, 15:21], #### 21 27 ### 21 - 27 fix pose
81 | input[:, 6:9], #### 27 30 ### 27-30 pred pose
82 | lp[:, 21:27], ### 30 36 #### 30- 36 fix pose
83 | input[:, 9:18], #### 36 45 ### 36-45 pred pose 36-39
84 | lp[:, 27:], ### 45 51 ### 45 - 51 fix pose 39 45
85 | input[:, 18:]] ### 51 ### 51 - 51+(232-18) 51 - 265 ## 45 - 265
86 | , dim=1)
87 | return input
88 |
89 |
90 | def pred2poses(input, gt):
91 | input = torch.cat([input[:, :3],
92 | gt[0:1, 3:18].repeat(input.shape[0], 1),
93 | input[:, 3:6],
94 | gt[0:1, 21:27].repeat(input.shape[0], 1),
95 | input[:, 6:9],
96 | gt[0:1, 30:36].repeat(input.shape[0], 1),
97 | input[:, 9:12],
98 | gt[0:1, 39:45].repeat(input.shape[0], 1),
99 | input[:, 12:]]
100 | , dim=1)
101 | return input
102 |
103 |
104 | def poses2poses(input, gt):
105 | input = torch.cat([input[:, :3],
106 | gt[0:1, 3:18].repeat(input.shape[0], 1),
107 | input[:, 18:21],
108 | gt[0:1, 21:27].repeat(input.shape[0], 1),
109 | input[:, 27:30],
110 | gt[0:1, 30:36].repeat(input.shape[0], 1),
111 | input[:, 36:39],
112 | gt[0:1, 39:45].repeat(input.shape[0], 1),
113 | input[:, 45:]]
114 | , dim=1)
115 | return input
116 |
117 | def poses2pred(input, stand=False):
118 | if stand:
119 | lp = lower_pose_stand.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
120 | # lp = torch.zeros_like(lower_pose).unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
121 | else:
122 | lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
123 | input = torch.cat([input[:, :3],
124 | lp[:, :15],
125 | input[:, 18:21],
126 | lp[:, 15:21],
127 | input[:, 27:30],
128 | lp[:, 21:27],
129 | input[:, 36:39],
130 | lp[:, 27:],
131 | input[:, 45:]]
132 | , dim=1)
133 | return input
134 |
135 |
136 | rearrange = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]\
137 | # ,22, 23, 24, 25, 40, 26, 41,
138 | # 27, 42, 28, 43, 29, 44, 30, 45, 31, 46, 32, 47, 33, 48, 34, 49, 35, 50, 36, 51, 37, 52, 38, 53, 39, 54, 55,
139 | # 57, 56, 59, 58, 60, 63, 61, 64, 62, 65, 66, 71, 67, 72, 68, 73, 69, 74, 70, 75]
140 |
141 | symmetry = [0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1]#, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
142 | # 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
143 | # 1, 1, 1, 1, 1, 1]
144 |
--------------------------------------------------------------------------------
/voca/rendering.py:
--------------------------------------------------------------------------------
1 | '''
2 | Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights on this
3 | computer program.
4 |
5 | You can only use this computer program if you have closed a license agreement with MPG or you get the right to use
6 | the computer program from someone who is authorized to grant you that right.
7 |
8 | Any use of the computer program without a valid license is prohibited and liable to prosecution.
9 |
10 | Copyright 2019 Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG). acting on behalf of its
11 | Max Planck Institute for Intelligent Systems and the Max Planck Institute for Biological Cybernetics.
12 | All rights reserved.
13 |
14 | More information about VOCA is available at http://voca.is.tue.mpg.de.
15 | For comments or questions, please email us at voca@tue.mpg.de
16 | '''
17 |
18 | from __future__ import division
19 | import os
20 | # os.environ['PYOPENGL_PLATFORM'] = 'osmesa' # Uncommnet this line while running remotely
21 | import cv2
22 | import pyrender
23 | import trimesh
24 | import tempfile
25 | import numpy as np
26 | import matplotlib as mpl
27 | import matplotlib.cm as cm
28 |
29 |
30 | def get_unit_factor(unit):
31 | if unit == 'mm':
32 | return 1000.0
33 | elif unit == 'cm':
34 | return 100.0
35 | elif unit == 'm':
36 | return 1.0
37 | else:
38 | raise ValueError('Unit not supported')
39 |
40 |
41 | def render_mesh_helper(mesh, t_center, rot=np.zeros(3), tex_img=None, v_colors=None,
42 | errors=None, error_unit='m', min_dist_in_mm=0.0, max_dist_in_mm=3.0, z_offset=1.0, xmag=0.5,
43 | y=0.7, z=1, camera='o', r=None):
44 | camera_params = {'c': np.array([0, 0]),
45 | 'k': np.array([-0.19816071, 0.92822711, 0, 0, 0]),
46 | 'f': np.array([5000, 5000])}
47 |
48 | frustum = {'near': 0.01, 'far': 3.0, 'height': 800, 'width': 800}
49 |
50 | v, f = mesh
51 | v = cv2.Rodrigues(rot)[0].dot((v - t_center).T).T + t_center
52 |
53 | texture_rendering = tex_img is not None and hasattr(mesh, 'vt') and hasattr(mesh, 'ft')
54 | if texture_rendering:
55 | intensity = 0.5
56 | tex = pyrender.Texture(source=tex_img, source_channels='RGB')
57 | material = pyrender.material.MetallicRoughnessMaterial(baseColorTexture=tex)
58 |
59 | # Workaround as pyrender requires number of vertices and uv coordinates to be the same
60 | temp_filename = '%s.obj' % next(tempfile._get_candidate_names())
61 | mesh.write_obj(temp_filename)
62 | tri_mesh = trimesh.load(temp_filename, process=False)
63 | try:
64 | os.remove(temp_filename)
65 | except:
66 | print('Failed deleting temporary file - %s' % temp_filename)
67 | render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=material)
68 | elif errors is not None:
69 | intensity = 0.5
70 | unit_factor = get_unit_factor('mm') / get_unit_factor(error_unit)
71 | errors = unit_factor * errors
72 |
73 | norm = mpl.colors.Normalize(vmin=min_dist_in_mm, vmax=max_dist_in_mm)
74 | cmap = cm.get_cmap(name='jet')
75 | colormapper = cm.ScalarMappable(norm=norm, cmap=cmap)
76 | rgba_per_v = colormapper.to_rgba(errors)
77 | rgb_per_v = rgba_per_v[:, 0:3]
78 | elif v_colors is not None:
79 | intensity = 0.5
80 | rgb_per_v = v_colors
81 | else:
82 | intensity = 6.
83 | rgb_per_v = None
84 |
85 | color = np.array([0.3, 0.5, 0.55])
86 |
87 | if not texture_rendering:
88 | tri_mesh = trimesh.Trimesh(vertices=v, faces=f, vertex_colors=rgb_per_v)
89 | render_mesh = pyrender.Mesh.from_trimesh(tri_mesh,
90 | smooth=True,
91 | material=pyrender.MetallicRoughnessMaterial(
92 | metallicFactor=0.05,
93 | roughnessFactor=0.7,
94 | alphaMode='OPAQUE',
95 | baseColorFactor=(color[0], color[1], color[2], 1.0)
96 | ))
97 |
98 | scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[255, 255, 255])
99 |
100 | if camera == 'o':
101 | ymag = xmag * z_offset
102 | camera = pyrender.OrthographicCamera(xmag=xmag, ymag=ymag)
103 | elif camera == 'i':
104 | camera = pyrender.IntrinsicsCamera(fx=camera_params['f'][0],
105 | fy=camera_params['f'][1],
106 | cx=camera_params['c'][0],
107 | cy=camera_params['c'][1],
108 | znear=frustum['near'],
109 | zfar=frustum['far'])
110 | elif camera == 'y':
111 | camera = pyrender.PerspectiveCamera(yfov=(np.pi / 2.0))
112 |
113 | scene.add(render_mesh, pose=np.eye(4))
114 |
115 | camera_pose = np.eye(4)
116 | camera_pose[:3, 3] = np.array([0, 0.7, 1.0 - z_offset])
117 | scene.add(camera, pose=[[1, 0, 0, 0],
118 | [0, 1, 0, y], # 0.25
119 | [0, 0, 1, z], # 0.2
120 | [0, 0, 0, 1]])
121 |
122 |
123 | angle = np.pi / 6.0
124 | # pos = camera_pose[:3,3]
125 | pos = np.array([0, 0.7, 2.0])
126 | if False:
127 | light_color = np.array([1., 1., 1.])
128 | light = pyrender.DirectionalLight(color=light_color, intensity=intensity)
129 |
130 | light_pose = np.eye(4)
131 | light_pose[:3, 3] = np.array([0, 0.7, 2.0])
132 | scene.add(light, pose=light_pose.copy())
133 | else:
134 | light = pyrender.PointLight(color=np.array([1.0, 1.0, 1.0]) * 0.2, intensity=2)
135 | light_pose = np.eye(4)
136 | light_pose[:3, 3] = [0, -1, 1]
137 | scene.add(light, pose=light_pose)
138 |
139 | light_pose[:3, 3] = [0, 1, 1]
140 | scene.add(light, pose=light_pose)
141 |
142 | light_pose[:3, 3] = [-1, 1, 2]
143 | scene.add(light, pose=light_pose)
144 |
145 | spot_l = pyrender.SpotLight(color=np.ones(3), intensity=15.0,
146 | innerConeAngle=np.pi / 3, outerConeAngle=np.pi / 2)
147 |
148 | light_pose[:3, 3] = [-1, 2, 2]
149 | scene.add(spot_l, pose=light_pose)
150 |
151 | light_pose[:3, 3] = [1, 2, 2]
152 | scene.add(spot_l, pose=light_pose)
153 |
154 | # light_pose[:3,3] = cv2.Rodrigues(np.array([angle, 0, 0]))[0].dot(pos)
155 | # scene.add(light, pose=light_pose.copy())
156 | #
157 | # light_pose[:3,3] = cv2.Rodrigues(np.array([-angle, 0, 0]))[0].dot(pos)
158 | # scene.add(light, pose=light_pose.copy())
159 | #
160 | # light_pose[:3,3] = cv2.Rodrigues(np.array([0, -angle, 0]))[0].dot(pos)
161 | # scene.add(light, pose=light_pose.copy())
162 | #
163 | # light_pose[:3,3] = cv2.Rodrigues(np.array([0, angle, 0]))[0].dot(pos)
164 | # scene.add(light, pose=light_pose.copy())
165 |
166 | # pyrender.Viewer(scene)
167 |
168 | flags = pyrender.RenderFlags.SKIP_CULL_FACES
169 | # try:
170 | # r = pyrender.OffscreenRenderer(viewport_width=frustum['width'], viewport_height=frustum['height'])
171 | color, _ = r.render(scene, flags=flags)
172 | # r.delete()
173 | # except:
174 | # print('pyrender: Failed rendering frame')
175 | # color = np.zeros((frustum['height'], frustum['width'], 3), dtype='uint8')
176 |
177 | return color[..., ::-1]
178 |
--------------------------------------------------------------------------------
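
Note (not part of the repo): a minimal usage sketch for render_mesh_helper above. It assumes the repo root is on PYTHONPATH, that pyrender, trimesh and OpenCV are installed, and that an OSMesa/EGL context is available when running headless (see the commented-out PYOPENGL_PLATFORM line in the file). The caller creates the offscreen renderer and passes it in as r, and the mesh is a (vertices, faces) tuple, matching the v, f = mesh unpacking in the function.

import numpy as np
import cv2
import trimesh
import pyrender
from voca.rendering import render_mesh_helper

# Hypothetical input: a unit icosphere stands in for a SMPL-X mesh frame.
sphere = trimesh.creation.icosphere(subdivisions=3)
v, f = np.array(sphere.vertices), np.array(sphere.faces)
t_center = v.mean(axis=0)

# The caller owns the offscreen renderer and passes it in as r.
r = pyrender.OffscreenRenderer(viewport_width=800, viewport_height=800)
frame = render_mesh_helper((v, f), t_center, z_offset=0.5, r=r)  # returns a BGR uint8 image
cv2.imwrite('frame.png', frame)
r.delete()
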
/scripts/test_face.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
5 | sys.path.append(os.getcwd())
6 |
7 | from tqdm import tqdm
8 | from transformers import Wav2Vec2Processor
9 |
10 | from evaluation.metrics import LVD
11 |
12 | import numpy as np
13 | import smplx as smpl
14 |
15 | from nets import *
16 | from trainer.options import parse_args
17 | from data_utils import torch_data
18 | from trainer.config import load_JsonConfig
19 | from data_utils.get_j import get_joints
20 |
21 | import torch
22 | from torch.utils import data
23 |
24 |
25 | def init_model(model_name, model_path, args, config):
26 | if model_name == 's2g_face':
27 | generator = s2g_face(
28 | args,
29 | config,
30 | )
31 | elif model_name == 's2g_body_vq':
32 | generator = s2g_body_vq(
33 | args,
34 | config,
35 | )
36 | elif model_name == 's2g_body_pixel':
37 | generator = s2g_body_pixel(
38 | args,
39 | config,
40 | )
41 | else:
42 | raise NotImplementedError
43 |
44 | model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
45 | if model_name == 'smplx_S2G':
46 | generator.generator.load_state_dict(model_ckpt['generator']['generator'])
47 | elif 'generator' in list(model_ckpt.keys()):
48 | generator.load_state_dict(model_ckpt['generator'])
49 | else:
50 | model_ckpt = {'generator': model_ckpt}
51 | generator.load_state_dict(model_ckpt)
52 |
53 | return generator
54 |
55 |
56 | def init_dataloader(data_root, speakers, args, config):
57 | data_base = torch_data(
58 | data_root=data_root,
59 | speakers=speakers,
60 | split='test',
61 | limbscaling=False,
62 | normalization=config.Data.pose.normalization,
63 | norm_method=config.Data.pose.norm_method,
64 | split_trans_zero=False,
65 | num_pre_frames=config.Data.pose.pre_pose_length,
66 | num_generate_length=config.Data.pose.generate_length,
67 | num_frames=30,
68 | aud_feat_win_size=config.Data.aud.aud_feat_win_size,
69 | aud_feat_dim=config.Data.aud.aud_feat_dim,
70 | feat_method=config.Data.aud.feat_method,
71 | smplx=True,
72 | audio_sr=22000,
73 | convert_to_6d=config.Data.pose.convert_to_6d,
74 | expression=config.Data.pose.expression,
75 | config=config
76 | )
77 |
78 | if config.Data.pose.normalization:
79 | norm_stats_fn = os.path.join(os.path.dirname(args.model_path), "norm_stats.npy")
80 | norm_stats = np.load(norm_stats_fn, allow_pickle=True)
81 | data_base.data_mean = norm_stats[0]
82 | data_base.data_std = norm_stats[1]
83 | else:
84 | norm_stats = None
85 |
86 | data_base.get_dataset()
87 | test_set = data_base.all_dataset
88 | test_loader = data.DataLoader(test_set, batch_size=1, shuffle=False)
89 |
90 | return test_set, test_loader, norm_stats
91 |
92 |
93 | def face_loss(gt, gt_param, pr, pr_param):
94 | loss_dict = {}
95 |
96 | jaw_xyz = gt[:, 22:25, :] - pr[:, 22:25, :]
97 | jaw_dist = jaw_xyz.norm(p=2, dim=-1)
98 | jaw_dist = jaw_dist.sum(dim=-1).mean()
99 | loss_dict['jaw_l1'] = jaw_dist  # per-frame L2 joint distance summed over the jaw joints, averaged over frames
100 |
101 | landmark_xyz = gt[:, 74:] - pr[:, 74:]
102 | landmark_dist = landmark_xyz.norm(p=2, dim=-1)
103 | landmark_dist = landmark_dist.sum(dim=-1).mean()
104 | loss_dict['landmark_l1'] = landmark_dist  # same L2-based distance over the landmark joints
105 |
106 | face_gt = torch.cat([gt[:, 22:25], gt[:, 74:]], dim=1)
107 | face_pr = torch.cat([pr[:, 22:25], pr[:, 74:]], dim=1)
108 |
109 | loss_dict['LVD'] = LVD(face_gt, face_pr, symmetrical=False, weight=False)
110 |
111 | return loss_dict
112 |
113 |
114 | def test(test_loader, generator, smplx_model, args, config):
115 | print('start testing')
116 |
117 | am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
118 | am_sr = 16000
119 |
120 | loss_dict = {}
121 | with torch.no_grad():
122 | i = 0
123 | for bat in tqdm(test_loader, desc="Testing......"):
124 | i = i + 1
125 | aud, poses, exp = bat['aud_feat'].to('cuda').to(torch.float32), bat['poses'].to('cuda').to(torch.float32), \
126 | bat['expression'].to('cuda').to(torch.float32)
127 | id = bat['speaker'].to('cuda') - 20
128 | betas = bat['betas'][0].to('cuda').to(torch.float64)
129 | poses = torch.cat([poses, exp], dim=-2).transpose(-1, -2).squeeze()
130 | # poses = to3d(poses, config)
131 |
132 | cur_wav_file = bat['aud_file'][0]
133 | pred_face = generator.infer_on_audio(cur_wav_file,
134 | id=id,
135 | frame=poses.shape[0],
136 | am=am,
137 | am_sr=am_sr
138 | )
139 |
140 | pred_face = torch.tensor(pred_face).to('cuda').squeeze()
141 | if pred_face.shape[1] > 103:
142 | pred_face = pred_face[:, :103]
143 | zero_poses = torch.zeros([pred_face.shape[0], 162], device='cuda')
144 |
145 | full_param = torch.cat([pred_face[:, :3], zero_poses, pred_face[:, 3:]], dim=-1)
146 |
147 | poses[:, 3:165] = full_param[:, 3:165]
148 | gt_joints = get_joints(smplx_model, betas, poses)
149 | pred_joints = get_joints(smplx_model, betas, full_param)
150 | bat_loss_dict = face_loss(gt_joints, poses, pred_joints, full_param)
151 |
152 | if loss_dict:  # non-empty: accumulate
153 | for key in list(bat_loss_dict.keys()):
154 | loss_dict[key] += bat_loss_dict[key]
155 | else:
156 | for key in list(bat_loss_dict.keys()):
157 | loss_dict[key] = bat_loss_dict[key]
158 | for key in loss_dict.keys():
159 | loss_dict[key] = loss_dict[key] / i
160 | print(key + '=' + str(loss_dict[key].item()))
161 |
162 |
163 | def main():
164 | parser = parse_args()
165 | args = parser.parse_args()
166 | device = torch.device(args.gpu)
167 | torch.cuda.set_device(device)
168 |
169 | config = load_JsonConfig(args.config_file)
170 |
171 | os.environ['smplx_npz_path'] = config.smplx_npz_path
172 | os.environ['extra_joint_path'] = config.extra_joint_path
173 | os.environ['j14_regressor_path'] = config.j14_regressor_path
174 |
175 | print('init dataloader...')
176 | test_set, test_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)
177 | print('init model...')
178 | face_model_name = args.face_model_name
179 | face_model_path = args.face_model_path
180 | generator_face = init_model(face_model_name, face_model_path, args, config)
181 |
182 | print('init smplx model...')
183 | dtype = torch.float64
184 | smplx_path = './visualise/'
185 | model_params = dict(model_path=smplx_path,
186 | model_type='smplx',
187 | create_global_orient=True,
188 | create_body_pose=True,
189 | create_betas=True,
190 | num_betas=300,
191 | create_left_hand_pose=True,
192 | create_right_hand_pose=True,
193 | use_pca=False,
194 | flat_hand_mean=False,
195 | create_expression=True,
196 | num_expression_coeffs=100,
197 | num_pca_comps=12,
198 | create_jaw_pose=True,
199 | create_leye_pose=True,
200 | create_reye_pose=True,
201 | create_transl=False,
202 | dtype=dtype, )
203 | smplx_model = smpl.create(**model_params).to('cuda')
204 |
205 | test(test_loader, generator_face, smplx_model, args, config)
206 |
207 |
208 | if __name__ == '__main__':
209 | main()
210 |
--------------------------------------------------------------------------------
/nets/spg/t2m_trans.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 | from torch.distributions import Categorical
6 | import models1.pos_encoding as pos_encoding
7 |
8 |
9 |
10 | class Text2Motion_Transformer(nn.Module):
11 |
12 | def __init__(self,
13 | num_vq=1024,
14 | embed_dim=512,
15 | clip_dim=512,
16 | block_size=16,
17 | num_layers=2,
18 | n_head=8,
19 | drop_out_rate=0.1,
20 | fc_rate=4):
21 | super().__init__()
22 | self.trans_base = CrossCondTransBase(num_vq, embed_dim, clip_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate)
23 | self.trans_head = CrossCondTransHead(num_vq, embed_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate)
24 | self.block_size = block_size
25 | self.num_vq = num_vq
26 |
27 | def get_block_size(self):
28 | return self.block_size
29 |
30 | def forward(self, idxs, clip_feature):
31 | feat = self.trans_base(idxs, clip_feature)
32 | logits = self.trans_head(feat)
33 | return logits
34 |
35 | def sample(self, clip_feature, if_categorial=False):
36 | for k in range(self.block_size):
37 | if k == 0:
38 | x = []
39 | else:
40 | x = xs
41 | logits = self.forward(x, clip_feature)
42 | logits = logits[:, -1, :]
43 | probs = F.softmax(logits, dim=-1)
44 | if if_categorial:
45 | dist = Categorical(probs)
46 | idx = dist.sample()
47 | if idx == self.num_vq:
48 | break
49 | idx = idx.unsqueeze(-1)
50 | else:
51 | _, idx = torch.topk(probs, k=1, dim=-1)
52 | if idx[0] == self.num_vq:
53 | break
54 | # append to the sequence and continue
55 | if k == 0:
56 | xs = idx
57 | else:
58 | xs = torch.cat((xs, idx), dim=1)
59 |
60 | if k == self.block_size - 1:
61 | return xs[:, :-1]
62 | return xs
63 |
64 | class CausalCrossConditionalSelfAttention(nn.Module):
65 |
66 | def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1):
67 | super().__init__()
68 | assert embed_dim % n_head == 0
69 | # key, query, value projections for all heads
70 | self.key = nn.Linear(embed_dim, embed_dim)
71 | self.query = nn.Linear(embed_dim, embed_dim)
72 | self.value = nn.Linear(embed_dim, embed_dim)
73 |
74 | self.attn_drop = nn.Dropout(drop_out_rate)
75 | self.resid_drop = nn.Dropout(drop_out_rate)
76 |
77 | self.proj = nn.Linear(embed_dim, embed_dim)
78 | # causal mask to ensure that attention is only applied to the left in the input sequence
79 | self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))
80 | self.n_head = n_head
81 |
82 | def forward(self, x):
83 | B, T, C = x.size()
84 |
85 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
86 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
87 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
88 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
89 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
90 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
91 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
92 | att = F.softmax(att, dim=-1)
93 | att = self.attn_drop(att)
94 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
95 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
96 |
97 | # output projection
98 | y = self.resid_drop(self.proj(y))
99 | return y
100 |
101 | class Block(nn.Module):
102 |
103 | def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1, fc_rate=4):
104 | super().__init__()
105 | self.ln1 = nn.LayerNorm(embed_dim)
106 | self.ln2 = nn.LayerNorm(embed_dim)
107 | self.attn = CausalCrossConditionalSelfAttention(embed_dim, block_size, n_head, drop_out_rate)
108 | self.mlp = nn.Sequential(
109 | nn.Linear(embed_dim, fc_rate * embed_dim),
110 | nn.GELU(),
111 | nn.Linear(fc_rate * embed_dim, embed_dim),
112 | nn.Dropout(drop_out_rate),
113 | )
114 |
115 | def forward(self, x):
116 | x = x + self.attn(self.ln1(x))
117 | x = x + self.mlp(self.ln2(x))
118 | return x
119 |
120 | class CrossCondTransBase(nn.Module):
121 |
122 | def __init__(self,
123 | num_vq=1024,
124 | embed_dim=512,
125 | clip_dim=512,
126 | block_size=16,
127 | num_layers=2,
128 | n_head=8,
129 | drop_out_rate=0.1,
130 | fc_rate=4):
131 | super().__init__()
132 | self.tok_emb = nn.Embedding(num_vq + 2, embed_dim)
133 | self.cond_emb = nn.Linear(clip_dim, embed_dim)
134 | self.pos_embedding = nn.Embedding(block_size, embed_dim)#### 16x512
135 | self.drop = nn.Dropout(drop_out_rate)
136 | # transformer block
137 | self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)])
138 | self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False)
139 |
140 | self.block_size = block_size
141 |
142 | self.apply(self._init_weights)
143 |
144 | def get_block_size(self):
145 | return self.block_size
146 |
147 | def _init_weights(self, module):
148 | if isinstance(module, (nn.Linear, nn.Embedding)):
149 | module.weight.data.normal_(mean=0.0, std=0.02)
150 | if isinstance(module, nn.Linear) and module.bias is not None:
151 | module.bias.data.zero_()
152 | elif isinstance(module, nn.LayerNorm):
153 | module.bias.data.zero_()
154 | module.weight.data.fill_(1.0)
155 |
156 | def forward(self, audio_feature, clip_feature):
157 | if len(audio_feature) == 0:
158 | token_embeddings = self.cond_emb(clip_feature).unsqueeze(1)
159 | else:
160 | #b, t = idx.size()
161 | #assert t <= self.block_size, "Cannot forward, model block size is exhausted."
162 | # forward the Trans model
163 | #token_embeddings = self.tok_emb(idx)
164 | token_embeddings = torch.cat([self.cond_emb(clip_feature).unsqueeze(1), audio_feature], dim=1)
165 |
166 | x = self.pos_embed(token_embeddings)
167 | x = self.blocks(x)
168 |
169 | return x
170 |
171 |
172 | class CrossCondTransHead(nn.Module):
173 |
174 | def __init__(self,
175 | num_vq=1024,
176 | embed_dim=512,
177 | block_size=16,
178 | num_layers=2,
179 | n_head=8,
180 | drop_out_rate=0.1,
181 | fc_rate=4):
182 | super().__init__()
183 |
184 | self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)])
185 | self.ln_f = nn.LayerNorm(embed_dim)
186 | self.head = nn.Linear(embed_dim, num_vq + 1, bias=False)
187 | self.block_size = block_size
188 |
189 | self.apply(self._init_weights)
190 |
191 | def get_block_size(self):
192 | return self.block_size
193 |
194 | def _init_weights(self, module):
195 | if isinstance(module, (nn.Linear, nn.Embedding)):
196 | module.weight.data.normal_(mean=0.0, std=0.02)
197 | if isinstance(module, nn.Linear) and module.bias is not None:
198 | module.bias.data.zero_()
199 | elif isinstance(module, nn.LayerNorm):
200 | module.bias.data.zero_()
201 | module.weight.data.fill_(1.0)
202 |
203 | def forward(self, x):
204 | x = self.blocks(x)
205 | x = self.ln_f(x)
206 | logits = self.head(x)
207 | return logits
208 |
209 |
210 |
211 |
212 |
213 |
214 |
--------------------------------------------------------------------------------
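
Note (illustrative, standalone): the "causal mask" registered in CausalCrossConditionalSelfAttention above is the standard GPT-style lower-triangular mask. The sketch below reproduces that masking step in plain PyTorch, with toy shapes, to show that position t only attends to positions <= t.

import math
import torch
import torch.nn.functional as F

B, n_head, T, head_dim = 1, 2, 4, 4
q = k = v = torch.randn(B, n_head, T, head_dim)

mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)   # same construction as the register_buffer above
att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
att = att.masked_fill(mask == 0, float('-inf'))        # future positions are set to -inf ...
att = F.softmax(att, dim=-1)                           # ... so they get exactly zero attention weight
print(att[0, 0])                                       # upper triangle is all zeros
y = att @ v                                            # (B, n_head, T, head_dim)
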
/nets/smplx_face.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.getcwd())
5 |
6 | from nets.layers import *
7 | from nets.base import TrainWrapperBaseClass
8 | # from nets.spg.faceformer import Faceformer
9 | from nets.spg.s2g_face import Generator as s2g_face
10 | from losses import KeypointLoss
11 | from nets.utils import denormalize
12 | from data_utils import get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
13 | import numpy as np
14 | import torch.optim as optim
15 | import torch.nn.functional as F
16 | from sklearn.preprocessing import normalize
17 | import smplx
18 |
19 |
20 | class TrainWrapper(TrainWrapperBaseClass):
21 | '''
22 | a wrapper receiving a batch from data_utils and calculating the loss
23 | '''
24 |
25 | def __init__(self, args, config):
26 | self.args = args
27 | self.config = config
28 | self.device = torch.device(self.args.gpu)
29 | self.global_step = 0
30 |
31 | self.convert_to_6d = self.config.Data.pose.convert_to_6d
32 | self.expression = self.config.Data.pose.expression
33 | self.epoch = 0
34 | self.init_params()
35 | self.num_classes = 4
36 |
37 | self.generator = s2g_face(
38 | n_poses=self.config.Data.pose.generate_length,
39 | each_dim=self.each_dim,
40 | dim_list=self.dim_list,
41 | training=not self.args.infer,
42 | device=self.device,
43 | identity=False if self.convert_to_6d else True,
44 | num_classes=self.num_classes,
45 | ).to(self.device)
46 |
47 | # self.generator = Faceformer().to(self.device)
48 |
49 | self.discriminator = None
50 | self.am = None
51 |
52 | self.MSELoss = KeypointLoss().to(self.device)
53 | super().__init__(args, config)
54 |
55 | def init_optimizer(self):
56 | self.generator_optimizer = optim.SGD(
57 | filter(lambda p: p.requires_grad,self.generator.parameters()),
58 | lr=0.001,
59 | momentum=0.9,
60 | nesterov=False,
61 | )
62 |
63 | def init_params(self):
64 | if self.convert_to_6d:
65 | scale = 2
66 | else:
67 | scale = 1
68 |
69 | global_orient = round(3 * scale)
70 | leye_pose = reye_pose = round(3 * scale)
71 | jaw_pose = round(3 * scale)
72 | body_pose = round(63 * scale)
73 | left_hand_pose = right_hand_pose = round(45 * scale)
74 | if self.expression:
75 | expression = 100
76 | else:
77 | expression = 0
78 |
79 | b_j = 0
80 | jaw_dim = jaw_pose
81 | b_e = b_j + jaw_dim
82 | eye_dim = leye_pose + reye_pose
83 | b_b = b_e + eye_dim
84 | body_dim = global_orient + body_pose
85 | b_h = b_b + body_dim
86 | hand_dim = left_hand_pose + right_hand_pose
87 | b_f = b_h + hand_dim
88 | face_dim = expression
89 |
90 | self.dim_list = [b_j, b_e, b_b, b_h, b_f]
91 | self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim + face_dim
92 | self.pose = int(self.full_dim / round(3 * scale))
93 | self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
94 |
95 | def __call__(self, bat):
96 | # assert (not self.args.infer), "infer mode"
97 | self.global_step += 1
98 |
99 | total_loss = None
100 | loss_dict = {}
101 |
102 | aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
103 | id = bat['speaker'].to(self.device) - 20
104 | id = F.one_hot(id, self.num_classes)
105 |
106 | aud = aud.permute(0, 2, 1)
107 | gt_poses = poses.permute(0, 2, 1)
108 |
109 | if self.expression:
110 | expression = bat['expression'].to(self.device).to(torch.float32)
111 | gt_poses = torch.cat([gt_poses, expression.permute(0, 2, 1)], dim=2)
112 |
113 | pred_poses, _ = self.generator(
114 | aud,
115 | gt_poses,
116 | id,
117 | )
118 |
119 | G_loss, G_loss_dict = self.get_loss(
120 | pred_poses=pred_poses,
121 | gt_poses=gt_poses,
122 | pre_poses=None,
123 | mode='training_G',
124 | gt_conf=None,
125 | aud=aud,
126 | )
127 |
128 | self.generator_optimizer.zero_grad()
129 | G_loss.backward()
130 | grad = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.config.Train.max_gradient_norm)
131 | loss_dict['grad'] = grad.item()
132 | self.generator_optimizer.step()
133 |
134 | for key in list(G_loss_dict.keys()):
135 | loss_dict[key] = G_loss_dict.get(key, 0).item()
136 |
137 | return total_loss, loss_dict
138 |
139 | def get_loss(self,
140 | pred_poses,
141 | gt_poses,
142 | pre_poses,
143 | aud,
144 | mode='training_G',
145 | gt_conf=None,
146 | exp=1,
147 | gt_nzero=None,
148 | pre_nzero=None,
149 | ):
150 | loss_dict = {}
151 |
152 |
153 | [b_j, b_e, b_b, b_h, b_f] = self.dim_list
154 |
155 | MSELoss = torch.mean(torch.abs(pred_poses[:, :, :6] - gt_poses[:, :, :6]))  # mean absolute error on the first 6 pose channels
156 | if self.expression:
157 | expl = torch.mean((pred_poses[:, :, -100:] - gt_poses[:, :, -100:])**2)
158 | else:
159 | expl = 0
160 |
161 | gen_loss = expl + MSELoss
162 |
163 | loss_dict['MSELoss'] = MSELoss
164 | if self.expression:
165 | loss_dict['exp_loss'] = expl
166 |
167 | return gen_loss, loss_dict
168 |
169 | def infer_on_audio(self, aud_fn, id=None, initial_pose=None, norm_stats=None, w_pre=False, frame=None, am=None, am_sr=16000, **kwargs):
170 | '''
171 | initial_pose: (B, C, T), normalized
172 | (aud_fn, txgfile) -> generated motion (B, T, C)
173 | '''
174 | output = []
175 |
176 | # assert self.args.infer, "train mode"
177 | self.generator.eval()
178 |
179 | if self.config.Data.pose.normalization:
180 | assert norm_stats is not None
181 | data_mean = norm_stats[0]
182 | data_std = norm_stats[1]
183 |
184 | # assert initial_pose.shape[-1] == pre_length
185 | if initial_pose is not None:
186 | gt = initial_pose[:,:,:].permute(0, 2, 1).to(self.generator.device).to(torch.float32)
187 | pre_poses = initial_pose[:,:,:15].permute(0, 2, 1).to(self.generator.device).to(torch.float32)
188 | poses = initial_pose.permute(0, 2, 1).to(self.generator.device).to(torch.float32)
189 | B = pre_poses.shape[0]
190 | else:
191 | gt = None
192 | pre_poses=None
193 | B = 1
194 |
195 | if type(aud_fn) == torch.Tensor:
196 | aud_feat = torch.tensor(aud_fn, dtype=torch.float32).to(self.generator.device)
197 | num_poses_to_generate = aud_feat.shape[-1]
198 | else:
199 | aud_feat = get_mfcc_ta(aud_fn, am=am, am_sr=am_sr, fps=30, encoder_choice='faceformer')
200 | aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
201 | aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.generator.device).transpose(1, 2)
202 | if frame is None:
203 | frame = aud_feat.shape[2]*30//16000
204 | #
205 | if id is None:
206 | id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device)
207 | else:
208 | id = F.one_hot(id, self.num_classes).to(self.generator.device)
209 |
210 | with torch.no_grad():
211 | pred_poses = self.generator(aud_feat, pre_poses, id, time_steps=frame)[0]
212 | pred_poses = pred_poses.cpu().numpy()
213 | output = pred_poses
214 |
215 | if self.config.Data.pose.normalization:
216 | output = denormalize(output, data_mean, data_std)
217 |
218 | return output
219 |
220 |
221 | def generate(self, wv2_feat, frame):
222 | '''
223 | initial_pose: (B, C, T), normalized
224 | (aud_fn, txgfile) -> generated motion (B, T, C)
225 | '''
226 | output = []
227 |
228 | # assert self.args.infer, "train mode"
229 | self.generator.eval()
230 |
231 | B = 1
232 |
233 | id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device)
234 | id = id.repeat(wv2_feat.shape[0], 1)
235 |
236 | with torch.no_grad():
237 | pred_poses = self.generator(wv2_feat, None, id, time_steps=frame)[0]
238 | return pred_poses
239 |
--------------------------------------------------------------------------------
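
Note (illustrative): the dimensions produced by init_params above for the plain axis-angle face configuration (convert_to_6d=False, expression=True, hence scale=1). This is the 265-dimensional parameter layout that scripts/test_face.py reassembles as jaw (3) + zeroed eyes/body/hands (162) + expression (100).

# Mirrors init_params with scale = 1 and expression = True (assumed face-model defaults).
jaw, eyes, body, hands, expr = 3, 3 + 3, 3 + 63, 45 + 45, 100
dim_list = [0, jaw, jaw + eyes, jaw + eyes + body, jaw + eyes + body + hands]
full_dim = jaw + eyes + body + hands + expr
each_dim = [jaw, eyes + body, hands, expr]

print(dim_list)  # [0, 3, 9, 75, 165]
print(full_dim)  # 265
print(each_dim)  # [3, 72, 90, 100]
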
/evaluation/FGD.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | from scipy import linalg
7 | import math
8 | from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d
9 |
10 | import warnings
11 | warnings.filterwarnings("ignore", category=RuntimeWarning) # ignore warnings
12 |
13 |
14 | change_angle = torch.tensor([6.0181e-05, 5.1597e-05, 2.1344e-04, 2.1899e-04])
15 | class EmbeddingSpaceEvaluator:
16 | def __init__(self, ae, vae, device):
17 |
18 | # init embed net
19 | self.ae = ae
20 | # self.vae = vae
21 |
22 | # storage
23 | self.real_feat_list = []
24 | self.generated_feat_list = []
25 | self.real_joints_list = []
26 | self.generated_joints_list = []
27 | self.real_6d_list = []
28 | self.generated_6d_list = []
29 | self.audio_beat_list = []
30 |
31 | def reset(self):
32 | self.real_feat_list = []
33 | self.generated_feat_list = []
34 |
35 | def get_no_of_samples(self):
36 | return len(self.real_feat_list)
37 |
38 | def push_samples(self, generated_poses, real_poses):
39 | # self.net.eval()
40 | # convert poses to latent features
41 | real_feat, real_poses = self.ae.extract(real_poses)
42 | generated_feat, generated_poses = self.ae.extract(generated_poses)
43 |
44 | num_joints = real_poses.shape[2] // 3
45 |
46 | real_feat = real_feat.squeeze()
47 | generated_feat = generated_feat.reshape(generated_feat.shape[0]*generated_feat.shape[1], -1)
48 |
49 | self.real_feat_list.append(real_feat.data.cpu().numpy())
50 | self.generated_feat_list.append(generated_feat.data.cpu().numpy())
51 |
52 | # real_poses = matrix_to_rotation_6d(axis_angle_to_matrix(real_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
53 | # generated_poses = matrix_to_rotation_6d(axis_angle_to_matrix(generated_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
54 | #
55 | # self.real_feat_list.append(real_poses.data.cpu().numpy())
56 | # self.generated_feat_list.append(generated_poses.data.cpu().numpy())
57 |
58 | def push_joints(self, generated_poses, real_poses):
59 | self.real_joints_list.append(real_poses.data.cpu())
60 | self.generated_joints_list.append(generated_poses.squeeze().data.cpu())
61 |
62 | def push_aud(self, aud):
63 | self.audio_beat_list.append(aud.squeeze().data.cpu())
64 |
65 | def get_MAAC(self):
66 | ang_vel_list = []
67 | for real_joints in self.real_joints_list:
68 | real_joints[:, 15:21] = real_joints[:, 16:22]
69 | vec = real_joints[:, 15:21] - real_joints[:, 13:19]
70 | inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
71 | inner_product = torch.clamp(inner_product, -1, 1, out=None)
72 | angle = torch.acos(inner_product) / math.pi
73 | ang_vel = (angle[1:] - angle[:-1]).abs().mean(dim=0)
74 | ang_vel_list.append(ang_vel.unsqueeze(dim=0))
75 | all_vel = torch.cat(ang_vel_list, dim=0)
76 | MAAC = all_vel.mean(dim=0)
77 | return MAAC
78 |
79 | def get_BCscore(self):
80 | thres = 0.01
81 | sigma = 0.1
82 | sum_1 = 0
83 | total_beat = 0
84 | for joints, audio_beat_time in zip(self.generated_joints_list, self.audio_beat_list):
85 | motion_beat_time = []
86 | if joints.dim() == 4:
87 | joints = joints[0]
88 | joints[:, 15:21] = joints[:, 16:22]
89 | vec = joints[:, 15:21] - joints[:, 13:19]
90 | inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
91 | inner_product = torch.clamp(inner_product, -1, 1, out=None)
92 | angle = torch.acos(inner_product) / math.pi
93 | ang_vel = (angle[1:] - angle[:-1]).abs() / change_angle / len(change_angle)
94 |
95 | angle_diff = torch.cat((torch.zeros(1, 4), ang_vel), dim=0)
96 |
97 | sum_2 = 0
98 | for i in range(angle_diff.shape[1]):
99 | motion_beat_time = []
100 | for t in range(1, joints.shape[0]-1):
101 | if (angle_diff[t][i] < angle_diff[t - 1][i] and angle_diff[t][i] < angle_diff[t + 1][i]):
102 | if (angle_diff[t - 1][i] - angle_diff[t][i] >= thres or angle_diff[t + 1][i] - angle_diff[
103 | t][i] >= thres):
104 | motion_beat_time.append(float(t) / 30.0)
105 | if (len(motion_beat_time) == 0):
106 | continue
107 | motion_beat_time = torch.tensor(motion_beat_time)
108 | sum = 0
109 | for audio in audio_beat_time:
110 | sum += np.power(math.e, -(np.power((audio.item() - motion_beat_time), 2)).min() / (2 * sigma * sigma))
111 | sum_2 = sum_2 + sum
112 | total_beat = total_beat + len(audio_beat_time)
113 | sum_1 = sum_1 + sum_2
114 | return sum_1/total_beat
115 |
116 |
117 | def get_scores(self):
118 | generated_feats = np.vstack(self.generated_feat_list)
119 | real_feats = np.vstack(self.real_feat_list)
120 |
121 | def frechet_distance(samples_A, samples_B):
122 | A_mu = np.mean(samples_A, axis=0)
123 | A_sigma = np.cov(samples_A, rowvar=False)
124 | B_mu = np.mean(samples_B, axis=0)
125 | B_sigma = np.cov(samples_B, rowvar=False)
126 | try:
127 | frechet_dist = self.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma)
128 | except ValueError:
129 | frechet_dist = 1e+10
130 | return frechet_dist
131 |
132 | ####################################################################
133 | # frechet distance
134 | frechet_dist = frechet_distance(generated_feats, real_feats)
135 |
136 | ####################################################################
137 | # distance between real and generated samples on the latent feature space
138 | dists = []
139 | for i in range(real_feats.shape[0]):
140 | d = np.sum(np.absolute(real_feats[i] - generated_feats[i])) # MAE
141 | dists.append(d)
142 | feat_dist = np.mean(dists)
143 |
144 | return frechet_dist, feat_dist
145 |
146 | @staticmethod
147 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
148 | """ from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """
149 | """Numpy implementation of the Frechet Distance.
150 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
151 | and X_2 ~ N(mu_2, C_2) is
152 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
153 | Stable version by Dougal J. Sutherland.
154 | Params:
155 | -- mu1 : Numpy array containing the activations of a layer of the
156 | inception net (like returned by the function 'get_predictions')
157 | for generated samples.
159 | -- mu2 : The sample mean over activations, precalculated on a
159 | representative data set.
160 | -- sigma1: The covariance matrix over activations for generated samples.
162 | -- sigma2: The covariance matrix over activations, precalculated on a
162 | representative data set.
163 | Returns:
164 | -- : The Frechet Distance.
165 | """
166 |
167 | mu1 = np.atleast_1d(mu1)
168 | mu2 = np.atleast_1d(mu2)
169 |
170 | sigma1 = np.atleast_2d(sigma1)
171 | sigma2 = np.atleast_2d(sigma2)
172 |
173 | assert mu1.shape == mu2.shape, \
174 | 'Training and test mean vectors have different lengths'
175 | assert sigma1.shape == sigma2.shape, \
176 | 'Training and test covariances have different dimensions'
177 |
178 | diff = mu1 - mu2
179 |
180 | # Product might be almost singular
181 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
182 | if not np.isfinite(covmean).all():
183 | msg = ('fid calculation produces singular product; '
184 | 'adding %s to diagonal of cov estimates') % eps
185 | print(msg)
186 | offset = np.eye(sigma1.shape[0]) * eps
187 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
188 |
189 | # Numerical error might give slight imaginary component
190 | if np.iscomplexobj(covmean):
191 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
192 | m = np.max(np.abs(covmean.imag))
193 | raise ValueError('Imaginary component {}'.format(m))
194 | covmean = covmean.real
195 |
196 | tr_covmean = np.trace(covmean)
197 |
198 | return (diff.dot(diff) + np.trace(sigma1) +
199 | np.trace(sigma2) - 2 * tr_covmean)
--------------------------------------------------------------------------------
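
Note (illustrative): a toy run of the Frechet distance used by get_scores above, assuming evaluation/FGD.py and its data_utils.rotation_conversion dependency import cleanly in your environment. The feature dimension and sample counts are arbitrary.

import numpy as np
from evaluation.FGD import EmbeddingSpaceEvaluator

rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, size=(500, 32))   # toy 'real' latent features
gen = rng.normal(0.3, 1.0, size=(500, 32))    # toy 'generated' features with a shifted mean

mu_r, cov_r = real.mean(axis=0), np.cov(real, rowvar=False)
mu_g, cov_g = gen.mean(axis=0), np.cov(gen, rowvar=False)

# d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)), as in the docstring above
fgd = EmbeddingSpaceEvaluator.calculate_frechet_distance(mu_g, cov_g, mu_r, cov_r)
print(fgd)  # roughly 32 * 0.3**2 plus covariance-estimation error
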
/nets/spg/s2g_face.py:
--------------------------------------------------------------------------------
1 | '''
2 | not exactly the same as the official repo but the results are good
3 | '''
4 | import sys
5 | import os
6 |
7 | from transformers import Wav2Vec2Processor
8 |
9 | from .wav2vec import Wav2Vec2Model
10 | from torchaudio.sox_effects import apply_effects_tensor
11 |
12 | sys.path.append(os.getcwd())
13 |
14 | import numpy as np
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | import torchaudio as ta
19 | import math
20 | from nets.layers import SeqEncoder1D, SeqTranslator1D, ConvNormRelu
21 |
22 |
23 | """ from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
24 |
25 |
26 | def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
27 | """
28 | :param audio: 1 x T tensor containing a 16kHz audio signal
29 | :param frame_rate: frame rate for video (we need one audio chunk per video frame)
30 | :param chunk_size: number of audio samples per chunk
31 | :return: num_chunks x chunk_size tensor containing sliced audio
32 | """
33 | samples_per_frame = 16000 // frame_rate
34 | padding = (chunk_size - samples_per_frame) // 2
35 | audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
36 | anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame))
37 | audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0)
38 | return audio
39 |
40 |
41 | class MeshtalkEncoder(nn.Module):
42 | def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'):
43 | """
44 | :param latent_dim: size of the latent audio embedding
45 | :param model_name: name of the model, used to load and save the model
46 | """
47 | super().__init__()
48 |
49 | self.melspec = ta.transforms.MelSpectrogram(
50 | sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80
51 | )
52 |
53 | conv_len = 5
54 | self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len)
55 | self.weights_init(self.convert_dimensions)
56 | self.receptive_field = conv_len
57 |
58 | convs = []
59 | for i in range(6):
60 | dilation = 2 * (i % 3 + 1)
61 | self.receptive_field += (conv_len - 1) * dilation
62 | convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)]
63 | self.weights_init(convs[-1])
64 | self.convs = torch.nn.ModuleList(convs)
65 | self.code = torch.nn.Linear(128, latent_dim)
66 |
67 | self.apply(lambda x: self.weights_init(x))
68 |
69 | def weights_init(self, m):
70 | if isinstance(m, torch.nn.Conv1d):
71 | torch.nn.init.xavier_uniform_(m.weight)
72 | try:
73 | torch.nn.init.constant_(m.bias, .01)
74 | except:
75 | pass
76 |
77 | def forward(self, audio: torch.Tensor):
78 | """
79 | :param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame
80 | :return: code: B x T x latent_dim Tensor containing a latent audio code/embedding
81 | """
82 | B, T = audio.shape[0], audio.shape[1]
83 | x = self.melspec(audio).squeeze(1)
84 | x = torch.log(x.clamp(min=1e-10, max=None))
85 | if T == 1:
86 | x = x.unsqueeze(1)
87 |
88 | # Convert to the right dimensionality
89 | x = x.view(-1, x.shape[2], x.shape[3])
90 | x = F.leaky_relu(self.convert_dimensions(x), .2)
91 |
92 | # Process stacks
93 | for conv in self.convs:
94 | x_ = F.leaky_relu(conv(x), .2)
95 | if self.training:
96 | x_ = F.dropout(x_, .2)
97 | l = (x.shape[2] - x_.shape[2]) // 2
98 | x = (x[:, :, l:-l] + x_) / 2
99 |
100 | x = torch.mean(x, dim=-1)
101 | x = x.view(B, T, x.shape[-1])
102 | x = self.code(x)
103 |
104 | return {"code": x}
105 |
106 |
107 | class AudioEncoder(nn.Module):
108 | def __init__(self, in_dim, out_dim, identity=False, num_classes=0):
109 | super().__init__()
110 | self.identity = identity
111 | if self.identity:
112 | in_dim = in_dim + 64
113 | self.id_mlp = nn.Conv1d(num_classes, 64, 1, 1)
114 | self.first_net = SeqTranslator1D(in_dim, out_dim,
115 | min_layers_num=3,
116 | residual=True,
117 | norm='ln'
118 | )
119 | self.grus = nn.GRU(out_dim, out_dim, 1, batch_first=True)
120 | self.dropout = nn.Dropout(0.1)
121 | # self.att = nn.MultiheadAttention(out_dim, 4, dropout=0.1, batch_first=True)
122 |
123 | def forward(self, spectrogram, pre_state=None, id=None, time_steps=None):
124 |
125 | spectrogram = spectrogram
126 | spectrogram = self.dropout(spectrogram)
127 | if self.identity:
128 | id = id.reshape(id.shape[0], -1, 1).repeat(1, 1, spectrogram.shape[2]).to(torch.float32)
129 | id = self.id_mlp(id)
130 | spectrogram = torch.cat([spectrogram, id], dim=1)
131 | x1 = self.first_net(spectrogram)# .permute(0, 2, 1)
132 | if time_steps is not None:
133 | x1 = F.interpolate(x1, size=time_steps, align_corners=False, mode='linear')
134 | # x1, _ = self.att(x1, x1, x1)
135 | # x1, hidden_state = self.grus(x1)
136 | # x1 = x1.permute(0, 2, 1)
137 | hidden_state=None
138 |
139 | return x1, hidden_state
140 |
141 |
142 | class Generator(nn.Module):
143 | def __init__(self,
144 | n_poses,
145 | each_dim: list,
146 | dim_list: list,
147 | training=False,
148 | device=None,
149 | identity=True,
150 | num_classes=0,
151 | ):
152 | super().__init__()
153 |
154 | self.training = training
155 | self.device = device
156 | self.gen_length = n_poses
157 | self.identity = identity
158 |
159 | norm = 'ln'
160 | in_dim = 256
161 | out_dim = 256
162 |
163 | self.encoder_choice = 'faceformer'
164 |
165 | if self.encoder_choice == 'meshtalk':
166 | self.audio_encoder = MeshtalkEncoder(latent_dim=in_dim)
167 | elif self.encoder_choice == 'faceformer':
168 | # wav2vec 2.0 weights initialization
169 | self.audio_encoder = Wav2Vec2Model.from_pretrained("../dataset/wav2vec2-base-960h") # "vitouphy/wav2vec2-xls-r-300m-phoneme""facebook/wav2vec2-base-960h"
170 | self.audio_encoder.feature_extractor._freeze_parameters()
171 | self.audio_feature_map = nn.Linear(768, in_dim)
172 | else:
173 | self.audio_encoder = AudioEncoder(in_dim=64, out_dim=out_dim)
174 |
175 | self.audio_middle = AudioEncoder(in_dim, out_dim, identity, num_classes)
176 |
177 | self.dim_list = dim_list
178 |
179 | self.decoder = nn.ModuleList()
180 | self.final_out = nn.ModuleList()
181 |
182 | self.decoder.append(nn.Sequential(
183 | ConvNormRelu(out_dim, 64, norm=norm),
184 | ConvNormRelu(64, 64, norm=norm),
185 | ConvNormRelu(64, 64, norm=norm),
186 | ))
187 | self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1))
188 |
189 | self.decoder.append(nn.Sequential(
190 | ConvNormRelu(out_dim, out_dim, norm=norm),
191 | ConvNormRelu(out_dim, out_dim, norm=norm),
192 | ConvNormRelu(out_dim, out_dim, norm=norm),
193 | ))
194 | self.final_out.append(nn.Conv1d(out_dim, each_dim[3], 1, 1))
195 |
196 | def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None):
197 | if self.training:
198 | time_steps = gt_poses.shape[1]
199 |
200 | # vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
201 | if self.encoder_choice == 'meshtalk':
202 | in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000)
203 | feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2)
204 | elif self.encoder_choice == 'faceformer':
205 | hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1), frame_num=time_steps).last_hidden_state
206 | feature = self.audio_feature_map(hidden_states).transpose(1, 2)
207 | else:
208 | feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
209 |
210 | # hidden_states = in_spec
211 |
212 | feature, _ = self.audio_middle(feature, id=id)
213 |
214 | out = []
215 |
216 | for i in range(self.decoder.__len__()):
217 | mid = self.decoder[i](feature)
218 | mid = self.final_out[i](mid)
219 | out.append(mid)
220 |
221 | out = torch.cat(out, dim=1)
222 | out = out.transpose(1, 2)
223 |
224 | return out, None
225 |
226 |
227 |
--------------------------------------------------------------------------------
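
Note (illustrative): a quick sanity check of audio_chunking above, assuming the module's dependencies (transformers, torchaudio, nets.layers) are installed so that nets.spg.s2g_face imports. Two seconds of 16 kHz audio at 30 fps should yield one 1-second chunk per video frame.

import torch
from nets.spg.s2g_face import audio_chunking

audio = torch.randn(1, 32000)  # 2 s of 16 kHz audio, shape 1 x T
chunks = audio_chunking(audio, frame_rate=30, chunk_size=16000)
print(chunks.shape)  # torch.Size([60, 16000]) -> 60 chunks for 60 video frames at 30 fps
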
/scripts/continuity.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | # os.environ["PYOPENGL_PLATFORM"] = "egl"
4 | from transformers import Wav2Vec2Processor
5 | from visualise.rendering import RenderTool
6 |
7 | sys.path.append(os.getcwd())
8 | from glob import glob
9 |
10 | import numpy as np
11 | import json
12 | import smplx as smpl
13 |
14 | from nets import *
15 | from trainer.options import parse_args
16 | from data_utils import torch_data
17 | from trainer.config import load_JsonConfig
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | from torch.utils import data
23 | from scripts.diversity import init_model, init_dataloader, get_vertices
24 | from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
25 | from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
26 | import time
27 |
28 |
29 | global_orient = torch.tensor([3.0747, -0.0158, -0.0152])
30 |
31 |
32 | def infer(data_root, g_body, g_face, g_body2, exp_name, infer_loader, infer_set, device, norm_stats, smplx,
33 | smplx_model, rendertool, args=None, config=None, var=None):
34 | am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
35 | am_sr = 16000
36 | num_sample = 1
37 | face = False
38 | if face:
39 | body_static = torch.zeros([1, 162], device='cuda')
40 | body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
41 | stand = False
42 | j = 0
43 | gt_0 = None
44 |
45 | for bat in infer_loader:
46 | poses_ = bat['poses'].to(torch.float32).to(device)
47 | if poses_.shape[-1] == 300:
48 | j = j + 1
49 | if j > 1000:
50 | continue
51 | id = bat['speaker'].to('cuda') - 20
52 | if config.Data.pose.expression:
53 | expression = bat['expression'].to(device).to(torch.float32)
54 | poses = torch.cat([poses_, expression], dim=1)
55 | else:
56 | poses = poses_
57 | cur_wav_file = bat['aud_file'][0]
58 | betas = bat['betas'][0].to(torch.float64).to('cuda')
59 | # betas = torch.zeros([1, 300], dtype=torch.float64).to('cuda')
60 | gt = poses.to('cuda').squeeze().transpose(1, 0)
61 | if config.Data.pose.normalization:
62 | gt = denormalize(gt, norm_stats[0], norm_stats[1]).squeeze(dim=0)
63 | if config.Data.pose.convert_to_6d:
64 | if config.Data.pose.expression:
65 | gt_exp = gt[:, -100:]
66 | gt = gt[:, :-100]
67 |
68 | gt = gt.reshape(gt.shape[0], -1, 6)
69 | gt = matrix_to_axis_angle(rotation_6d_to_matrix(gt)).reshape(gt.shape[0], -1)
70 | gt = torch.cat([gt, gt_exp], -1)
71 | if face:
72 | gt = torch.cat([gt[:, :3], body_static.repeat(gt.shape[0], 1), gt[:, -100:]], dim=-1)
73 |
74 | result_list = [gt]
75 |
76 | # cur_wav_file = '.\\training_data\\french-V4.wav'
77 |
78 | # pred_face = g_face.infer_on_audio(cur_wav_file,
79 | # initial_pose=poses_,
80 | # norm_stats=None,
81 | # w_pre=False,
82 | # # id=id,
83 | # frame=None,
84 | # am=am,
85 | # am_sr=am_sr
86 | # )
87 | #
88 | # pred_face = torch.tensor(pred_face).squeeze().to('cuda')
89 |
90 | pred_face = torch.zeros([gt.shape[0], 103], device='cuda')
91 | pred_jaw = pred_face[:, :3]
92 | pred_face = pred_face[:, 3:]
93 |
94 | # id = torch.tensor([0], device='cuda')
95 |
96 | for i in range(num_sample):
97 | pred_res = g_body.infer_on_audio(cur_wav_file,
98 | initial_pose=poses_,
99 | norm_stats=norm_stats,
100 | txgfile=None,
101 | id=id,
102 | var=var,
103 | fps=30,
104 | continuity=True,
105 | smooth=False
106 | )
107 | pred = torch.tensor(pred_res).squeeze().to('cuda')
108 |
109 | if pred.shape[0] < pred_face.shape[0]:
110 | repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
111 | pred = torch.cat([pred, repeat_frame], dim=0)
112 | else:
113 | pred = pred[:pred_face.shape[0], :]
114 |
115 | if config.Data.pose.convert_to_6d:
116 | pred = pred.reshape(pred.shape[0], -1, 6)
117 | pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
118 | pred = pred.reshape(pred.shape[0], -1)
119 |
120 | pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)
121 | # pred[:, 9:12] = global_orient
122 | pred = part2full(pred, stand)
123 | if face:
124 | pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)
125 | # result_list[0] = poses2pred(result_list[0], stand)
126 | # if gt_0 is None:
127 | # gt_0 = gt
128 | # pred = pred2poses(pred, gt_0)
129 | # result_list[0] = poses2poses(result_list[0], gt_0)
130 |
131 | result_list.append(pred)
132 |
133 | vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)
134 |
135 | result_list = [res.to('cpu') for res in result_list]
136 | dict = np.concatenate(result_list[1:], axis=0)
137 | file_name = 'visualise/video/' + config.Log.name + '/' + \
138 | cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
139 | np.save(file_name, dict)
140 |
141 | rendertool._render_continuity(cur_wav_file, vertices_list[1], frame=60)
142 |
143 |
144 | def main():
145 | parser = parse_args()
146 | args = parser.parse_args()
147 | device = torch.device(args.gpu)
148 | torch.cuda.set_device(device)
149 |
150 | config = load_JsonConfig(args.config_file)
151 |
152 | smplx = True
153 |
154 | os.environ['smplx_npz_path'] = config.smplx_npz_path
155 | os.environ['extra_joint_path'] = config.extra_joint_path
156 | os.environ['j14_regressor_path'] = config.j14_regressor_path
157 |
158 | print('init model...')
159 | body_model_name = 's2g_body_pixel'
160 | body_model_path = './experiments/2022-12-31-smplx_S2G-body-pixel-conti-wide/ckpt-99.pth' # './experiments/2022-10-09-smplx_S2G-body-pixel-aud-3p/ckpt-99.pth'
161 | generator = init_model(body_model_name, body_model_path, args, config)
162 |
163 | # face_model_name = 's2g_face'
164 | # face_model_path = './experiments/2022-10-15-smplx_S2G-face-sgd-3p-wv2/ckpt-99.pth' # './experiments/2022-09-28-smplx_S2G-face-faceformer-3d/ckpt-99.pth'
165 | # generator_face = init_model(face_model_name, face_model_path, args, config)
166 | generator_face = None
167 | print('init dataloader...')
168 | infer_set, infer_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)
169 |
170 | print('init smplx model...')
171 | dtype = torch.float64
172 | model_params = dict(model_path='E:/PycharmProjects/Motion-Projects/models',
173 | model_type='smplx',
174 | create_global_orient=True,
175 | create_body_pose=True,
176 | create_betas=True,
177 | num_betas=300,
178 | create_left_hand_pose=True,
179 | create_right_hand_pose=True,
180 | use_pca=False,
181 | flat_hand_mean=False,
182 | create_expression=True,
183 | num_expression_coeffs=100,
184 | num_pca_comps=12,
185 | create_jaw_pose=True,
186 | create_leye_pose=True,
187 | create_reye_pose=True,
188 | create_transl=False,
189 | # gender='ne',
190 | dtype=dtype, )
191 | smplx_model = smpl.create(**model_params).to('cuda')
192 | print('init rendertool...')
193 | rendertool = RenderTool('visualise/video/' + config.Log.name)
194 |
195 | infer(config.Data.data_root, generator, generator_face, None, args.exp_name, infer_loader, infer_set, device,
196 | norm_stats, smplx, smplx_model, rendertool, args, config, (None, None))
197 |
198 |
199 | if __name__ == '__main__':
200 | main()
201 |
--------------------------------------------------------------------------------
/scripts/test_body.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 |
5 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
6 | sys.path.append(os.getcwd())
7 |
8 | from tqdm import tqdm
9 | from transformers import Wav2Vec2Processor
10 |
11 | from evaluation.FGD import EmbeddingSpaceEvaluator
12 |
13 | from evaluation.metrics import LVD
14 |
15 | import numpy as np
16 | import smplx as smpl
17 |
18 | from data_utils.lower_body import part2full, poses2pred
19 | from data_utils.utils import get_mfcc_ta
20 | from nets import *
21 | from nets.utils import get_path, get_dpath
22 | from trainer.options import parse_args
23 | from data_utils import torch_data
24 | from trainer.config import load_JsonConfig
25 |
26 | import torch
27 | from torch.utils import data
28 | from data_utils.get_j import to3d, get_joints
29 |
30 |
31 | def init_model(model_name, model_path, args, config,bert_config,iteration):
32 | if model_name == 's2g_face':
33 | generator = s2g_face(
34 | args,
35 | config,
36 | )
37 | elif model_name == 's2g_body_vq':
38 | generator = s2g_body_vq(
39 | args,
40 | config,
41 | )
42 | elif model_name == 's2g_body_pixel':
43 | generator = s2g_body_pixel(
44 | args,
45 | config,
46 | bert_config,
47 | iteration
48 | )
49 | elif model_name == 's2g_body_ae':
50 | generator = s2g_body_ae(
51 | args,
52 | config,
53 | )
54 | else:
55 | raise NotImplementedError
56 |
57 | model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
58 | generator.load_state_dict(model_ckpt['generator'])
59 |
60 | return generator
61 |
62 |
63 | def init_dataloader(data_root, speakers, args, config):
64 | data_base = torch_data(
65 | data_root=data_root,
66 | speakers=speakers,
67 | split='test',
68 | limbscaling=False,
69 | normalization=config.Data.pose.normalization,
70 | norm_method=config.Data.pose.norm_method,
71 | split_trans_zero=False,
72 | num_pre_frames=config.Data.pose.pre_pose_length,
73 | num_generate_length=config.Data.pose.generate_length,
74 | num_frames=30,
75 | aud_feat_win_size=config.Data.aud.aud_feat_win_size,
76 | aud_feat_dim=config.Data.aud.aud_feat_dim,
77 | feat_method=config.Data.aud.feat_method,
78 | smplx=True,
79 | audio_sr=22000,
80 | convert_to_6d=config.Data.pose.convert_to_6d,
81 | expression=config.Data.pose.expression,
82 | config=config
83 | )
84 |
85 | if config.Data.pose.normalization:
86 | norm_stats_fn = os.path.join(os.path.dirname(args.model_path), "norm_stats.npy")
87 | norm_stats = np.load(norm_stats_fn, allow_pickle=True)
88 | data_base.data_mean = norm_stats[0]
89 | data_base.data_std = norm_stats[1]
90 | else:
91 | norm_stats = None
92 |
93 | data_base.get_dataset()
94 | test_set = data_base.all_dataset
95 | test_loader = data.DataLoader(test_set, batch_size=1, shuffle=False)
96 |
97 | return test_set, test_loader, norm_stats
98 |
99 |
100 | def body_loss(gt, prs):
101 | loss_dict = {}
102 | # LVD
103 | v_diff = LVD(gt[:, :22, :], prs[:, :, :22, :], symmetrical=False, weight=False)
104 | loss_dict['LVD'] = v_diff
105 | # Accuracy
106 | error = (gt - prs).norm(p=2, dim=-1).sum(dim=-1).mean()
107 | loss_dict['error'] = error
108 | # Diversity
109 | var = prs.var(dim=0).norm(p=2, dim=-1).sum(dim=-1).mean()
110 | loss_dict['diverse'] = var
111 |
112 | return loss_dict
113 |
114 |
115 | def test(test_loader, generator, FGD_handler, smplx_model, config):
116 | print('start testing')
117 |
118 | am = Wav2Vec2Processor.from_pretrained("/mnt/nj-1/usr/xuanqing/pws/dataset/wav2vec2-xls-r-300m-phoneme")
119 | am_sr = 16000
120 |
121 | loss_dict = {}
122 | B = 2
123 | with torch.no_grad():
124 | count = 0
125 | for bat in tqdm(test_loader, desc="Testing......"):
126 | count = count + 1
127 | # if count == 10:
128 | # break
129 | _, poses, exp = bat['aud_feat'].to('cuda').to(torch.float32), bat['poses'].to('cuda').to(torch.float32), \
130 | bat['expression'].to('cuda').to(torch.float32)
131 | id = bat['speaker'].to('cuda') - 20
132 | betas = bat['betas'][0].to('cuda').to(torch.float64)
133 | poses = torch.cat([poses, exp], dim=-2).transpose(-1, -2)
134 | text_feat = bat["video"].to('cuda')
135 | cur_wav_file = bat['aud_file'][0]
136 |
137 | zero_face = torch.zeros([B, poses.shape[1], 103], device='cuda')
138 |
139 | joints_list = []
140 |
141 | pred = generator.infer_on_audio(cur_wav_file,
142 | id=id,
143 | fps=30,
144 | B=B,
145 | am=am,
146 | am_sr=am_sr,
147 | frame=poses.shape[0],
148 | text_feat=text_feat
149 | )
150 | pred = torch.tensor(pred, device='cuda')
151 |
152 | FGD_handler.push_samples(pred, poses)
153 |
154 | poses = poses.squeeze()
155 | poses = to3d(poses, config)
156 |
157 | if pred.shape[2] > 129:
158 | pred = pred[:, :, 103:]
159 |
160 | pred = torch.cat([zero_face[:, :pred.shape[1], :3], pred, zero_face[:, :pred.shape[1], 3:]], dim=-1)
161 | full_pred = []
162 | for j in range(B):
163 | f_pred = part2full(pred[j])
164 | full_pred.append(f_pred)
165 |
166 | for i in range(full_pred.__len__()):
167 | full_pred[i] = full_pred[i].unsqueeze(dim=0)
168 | full_pred = torch.cat(full_pred, dim=0)
169 |
170 | pred_joints = get_joints(smplx_model, betas, full_pred)
171 |
172 | poses = poses2pred(poses)
173 | poses = torch.cat([zero_face[0, :, :3], poses[:, 3:165], zero_face[0, :, 3:]], dim=-1)
174 | gt_joints = get_joints(smplx_model, betas, poses[:pred_joints.shape[1]])
175 | FGD_handler.push_joints(pred_joints, gt_joints)
176 | aud = get_mfcc_ta(cur_wav_file, fps=30, sr=16000, am='not None', encoder_choice='onset')
177 | FGD_handler.push_aud(torch.from_numpy(aud))
178 |
179 | bat_loss_dict = body_loss(gt_joints, pred_joints)
180 |
181 | if loss_dict:  # non-empty: accumulate
182 | for key in list(bat_loss_dict.keys()):
183 | loss_dict[key] += bat_loss_dict[key]
184 | else:
185 | for key in list(bat_loss_dict.keys()):
186 | loss_dict[key] = bat_loss_dict[key]
187 | for key in loss_dict.keys():
188 | loss_dict[key] = loss_dict[key] / count
189 | print(key + '=' + str(loss_dict[key].item()))
190 |
191 | # MAAC = FGD_handler.get_MAAC()
192 | # print(MAAC)
193 | fgd_dist, feat_dist = FGD_handler.get_scores()
194 | print('fgd_dist=', fgd_dist.item())
195 | print('feat_dist=', feat_dist.item())
196 | BCscore = FGD_handler.get_BCscore()
197 | print('Beat consistency score=', BCscore)
198 |
199 |
200 |
201 |
202 |
203 | def main():
204 | parser = parse_args()
205 | args = parser.parse_args()
206 | device = torch.device(args.gpu)
207 | torch.cuda.set_device(device)
208 |
209 | config = load_JsonConfig(args.config_file)
210 |
211 | os.environ['smplx_npz_path'] = config.smplx_npz_path
212 | os.environ['extra_joint_path'] = config.extra_joint_path
213 | os.environ['j14_regressor_path'] = config.j14_regressor_path
214 |
215 | print('init dataloader...')
216 | test_set, test_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)
217 | print('init model...')
218 | model_name = args.body_model_name
219 | # model_path = get_path(model_name, model_type)
220 | model_path = args.body_model_path
221 | bert_config = args.bert_config
222 | iteration = 100
223 | generator = init_model(model_name, model_path, args, config,bert_config,iteration)
224 |
225 | ae = init_model('s2g_body_ae', './experiments/feature_extractor.pth', args,
226 | config,bert_config,iteration)
227 | FGD_handler = EmbeddingSpaceEvaluator(ae, None, 'cuda')
228 |
229 | print('init smplx model...')
230 | dtype = torch.float64
231 | smplx_path = './visualise/'
232 | model_params = dict(model_path=smplx_path,
233 | model_type='smplx',
234 | create_global_orient=True,
235 | create_body_pose=True,
236 | create_betas=True,
237 | num_betas=300,
238 | create_left_hand_pose=True,
239 | create_right_hand_pose=True,
240 | use_pca=False,
241 | flat_hand_mean=False,
242 | create_expression=True,
243 | num_expression_coeffs=100,
244 | num_pca_comps=12,
245 | create_jaw_pose=True,
246 | create_leye_pose=True,
247 | create_reye_pose=True,
248 | create_transl=False,
249 | dtype=dtype, )
250 |
251 | smplx_model = smpl.create(**model_params).to('cuda')
252 |
253 | test(test_loader, generator, FGD_handler, smplx_model, config)
254 |
255 |
256 | if __name__ == '__main__':
257 | main()
258 |
--------------------------------------------------------------------------------
/nets/spg/vqvae_1d.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from .wav2vec import Wav2Vec2Model
7 | from .vqvae_modules import VectorQuantizerEMA, ConvNormRelu, Res_CNR_Stack
8 |
9 |
10 |
11 | class AudioEncoder(nn.Module):
12 | def __init__(self, in_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
13 | super(AudioEncoder, self).__init__()
14 | self._num_hiddens = num_hiddens
15 | self._num_residual_layers = num_residual_layers
16 | self._num_residual_hiddens = num_residual_hiddens
17 |
18 | self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True)
19 |
20 | self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
21 | self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True,
22 | sample='down')
23 | self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
24 | self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True, sample='down')
25 | self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
26 |
27 | def forward(self, x, frame_num=0):
28 | h = self.project(x)
29 | h = self._enc_1(h)
30 | h = self._down_1(h)
31 | h = self._enc_2(h)
32 | h = self._down_2(h)
33 | h = self._enc_3(h)
34 | return h
35 |
36 |
37 | class Wav2VecEncoder(nn.Module):
38 | def __init__(self, num_hiddens, num_residual_layers):
39 | super(Wav2VecEncoder, self).__init__()
40 | self._num_hiddens = num_hiddens
41 | self._num_residual_layers = num_residual_layers
42 |
43 | self.audio_encoder = Wav2Vec2Model.from_pretrained(
44 |             "facebook/wav2vec2-base-960h")  # alternative checkpoint: "vitouphy/wav2vec2-xls-r-300m-phoneme"
45 | self.audio_encoder.feature_extractor._freeze_parameters()
46 |
47 | self.project = ConvNormRelu(768, self._num_hiddens, leaky=True)
48 |
49 | self._enc_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
50 | self._down_1 = ConvNormRelu(self._num_hiddens, self._num_hiddens, leaky=True, residual=True, sample='down')
51 | self._enc_2 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
52 | self._down_2 = ConvNormRelu(self._num_hiddens, self._num_hiddens, leaky=True, residual=True, sample='down')
53 | self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
54 |
55 | def forward(self, x, frame_num):
56 | h = self.audio_encoder(x.squeeze(), frame_num=frame_num).last_hidden_state.transpose(1, 2)
57 | h = self.project(h)
58 | h = self._enc_1(h)
59 | h = self._down_1(h)
60 | h = self._enc_2(h)
61 | h = self._down_2(h)
62 | h = self._enc_3(h)
63 | return h
64 |
65 |
66 | class Encoder(nn.Module):
67 | def __init__(self, in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
68 | super(Encoder, self).__init__()
69 | self._num_hiddens = num_hiddens
70 | self._num_residual_layers = num_residual_layers
71 | self._num_residual_hiddens = num_residual_hiddens
72 |
73 | self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True)
74 |
75 | self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
76 | self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True,
77 | sample='down')
78 | self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
79 | self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True, sample='down')
80 | self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
81 |
82 | self.pre_vq_conv = nn.Conv1d(self._num_hiddens, embedding_dim, 1, 1)
83 |
84 | def forward(self, x):
85 |         # Project to num_hiddens // 4, run three residual stacks with two
86 |         # strided downsampling convs in between, then a 1x1 conv to embedding_dim.
87 |         h = self.project(x)
88 |         h = self._enc_1(h)
89 |         h = self._down_1(h)
90 |         h = self._enc_2(h)
91 |         h = self._down_2(h)
92 |         h = self._enc_3(h)
93 |         h = self.pre_vq_conv(h)
94 |         return h
109 |
110 |
111 | class Frame_Enc(nn.Module):
112 | def __init__(self, in_dim, num_hiddens):
113 | super(Frame_Enc, self).__init__()
114 | self.in_dim = in_dim
115 | self.num_hiddens = num_hiddens
116 |
117 | # self.enc = transformer_Enc(in_dim, num_hiddens, 2, 8, 256, 256, 256, 256, 0, dropout=0.1, n_position=4)
118 | self.proj = nn.Conv1d(in_dim, num_hiddens, 1, 1)
119 | self.enc = Res_CNR_Stack(num_hiddens, 2, leaky=True)
120 | self.proj_1 = nn.Conv1d(256*4, num_hiddens, 1, 1)
121 | self.proj_2 = nn.Conv1d(256*4, num_hiddens*2, 1, 1)
122 |
123 | def forward(self, x):
124 | # x = self.enc(x, None)[0].reshape(x.shape[0], -1, 1)
125 | x = self.enc(self.proj(x)).reshape(x.shape[0], -1, 1)
126 | second_last = self.proj_2(x)
127 | last = self.proj_1(x)
128 | return second_last, last
129 |
130 |
131 |
132 | class Decoder(nn.Module):
133 | def __init__(self, out_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, ae=False):
134 | super(Decoder, self).__init__()
135 | self._num_hiddens = num_hiddens
136 | self._num_residual_layers = num_residual_layers
137 | self._num_residual_hiddens = num_residual_hiddens
138 |
139 | self.aft_vq_conv = nn.Conv1d(embedding_dim, self._num_hiddens, 1, 1)
140 |
141 | self._dec_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
142 | self._up_2 = ConvNormRelu(self._num_hiddens, self._num_hiddens // 2, leaky=True, residual=True, sample='up')
143 | self._dec_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
144 | self._up_3 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens // 4, leaky=True, residual=True,
145 | sample='up')
146 | self._dec_3 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
147 |
148 | if ae:
149 | self.frame_enc = Frame_Enc(out_dim, self._num_hiddens // 4)
150 | self.gru_sl = nn.GRU(self._num_hiddens // 2, self._num_hiddens // 2, 1, batch_first=True)
151 | self.gru_l = nn.GRU(self._num_hiddens // 4, self._num_hiddens // 4, 1, batch_first=True)
152 |
153 | self.project = nn.Conv1d(self._num_hiddens // 4, out_dim, 1, 1)
154 |
155 | def forward(self, h, last_frame=None):
156 |
157 | h = self.aft_vq_conv(h)
158 | h = self._dec_1(h)
159 | h = self._up_2(h)
160 | h = self._dec_2(h)
161 | h = self._up_3(h)
162 | h = self._dec_3(h)
163 |
164 | recon = self.project(h)
165 | return recon, None
166 |
167 |
168 | class Pre_VQ(nn.Module):
169 | def __init__(self, num_hiddens, embedding_dim, num_chunks):
170 | super(Pre_VQ, self).__init__()
171 | self.conv = nn.Conv1d(num_hiddens, num_hiddens, 1, 1, 0, groups=num_chunks)
172 | self.bn = nn.GroupNorm(num_chunks, num_hiddens)
173 | self.relu = nn.ReLU()
174 | self.proj = nn.Conv1d(num_hiddens, embedding_dim, 1, 1, 0, groups=num_chunks)
175 |
176 | def forward(self, x):
177 | x = self.conv(x)
178 | x = self.bn(x)
179 | x = self.relu(x)
180 | x = self.proj(x)
181 | return x
182 |
183 |
184 | class VQVAE(nn.Module):
185 | """VQ-VAE"""
186 |
187 | def __init__(self, in_dim, embedding_dim, num_embeddings,
188 | num_hiddens, num_residual_layers, num_residual_hiddens,
189 | commitment_cost=0.25, decay=0.99, share=False):
190 | super().__init__()
191 | self.in_dim = in_dim
192 | self.embedding_dim = embedding_dim
193 | self.num_embeddings = num_embeddings
194 | self.share_code_vq = share
195 |
196 | self.encoder = Encoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
197 | self.vq_layer = VectorQuantizerEMA(embedding_dim, num_embeddings, commitment_cost, decay)
198 | self.decoder = Decoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
199 |
200 | def forward(self, gt_poses, id=None, pre_state=None):
201 |         # gt_poses: (B, T, C) pose sequence; the conv encoder expects channels-first input.
202 |         z = self.encoder(gt_poses.transpose(1, 2))
203 |         # z: latent features passed to the vector quantizer.
204 | if not self.training:
205 | e, _ = self.vq_layer(z)
206 | x_recon, cur_state = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
207 | return e, x_recon
208 |
209 | e, e_q_loss = self.vq_layer(z)
210 | gt_recon, cur_state = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
211 |
212 | return e_q_loss, gt_recon.transpose(1, 2)
213 |
214 | def encode(self, gt_poses, id=None):
215 | z = self.encoder(gt_poses.transpose(1, 2))
216 | e, latents = self.vq_layer(z)
217 | return e, latents
218 |
219 | def decode(self, b, w, e=None, latents=None, pre_state=None):
220 | if e is not None:
221 | x = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
222 | else:
223 | e = self.vq_layer.quantize(latents)
224 | e = e.view(b, w, -1).permute(0, 2, 1).contiguous()
225 | x = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
226 | return x
227 |
228 |
229 | class AE(nn.Module):
230 | """VQ-VAE"""
231 |
232 | def __init__(self, in_dim, embedding_dim, num_embeddings,
233 | num_hiddens, num_residual_layers, num_residual_hiddens):
234 | super().__init__()
235 | self.in_dim = in_dim
236 | self.embedding_dim = embedding_dim
237 | self.num_embeddings = num_embeddings
238 |
239 | self.encoder = Encoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
240 | self.decoder = Decoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, True)
241 |
242 | def forward(self, gt_poses, id=None, pre_state=None):
243 | z = self.encoder(gt_poses.transpose(1, 2))
244 | if not self.training:
245 | x_recon, cur_state = self.decoder(z, pre_state.transpose(1, 2) if pre_state is not None else None)
246 | return z, x_recon
247 | gt_recon, cur_state = self.decoder(z, pre_state.transpose(1, 2) if pre_state is not None else None)
248 |
249 | return gt_recon.transpose(1, 2)
250 |
251 | def encode(self, gt_poses, id=None):
252 | z = self.encoder(gt_poses.transpose(1, 2))
253 | return z
254 |
--------------------------------------------------------------------------------
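For orientation, a minimal smoke-test sketch of the VQVAE and AE classes defined above. Everything below is illustrative: the dimensions are placeholders rather than values from the repository's configs, and it assumes the sibling vqvae_modules/wav2vec modules imported at the top of vqvae_1d.py are importable.

import torch

from nets.spg.vqvae_1d import VQVAE, AE

# Placeholder sizes (batch, frames, pose channels); not taken from the configs.
B, T, C = 2, 88, 129
poses = torch.randn(B, T, C)

vq = VQVAE(in_dim=C, embedding_dim=64, num_embeddings=512,
           num_hiddens=256, num_residual_layers=2, num_residual_hiddens=256)
vq.train()
e_q_loss, recon = vq(poses)        # training path: commitment loss + reconstruction
vq.eval()
with torch.no_grad():
    e, recon = vq(poses)           # inference path: quantized latents + reconstruction

ae = AE(in_dim=C, embedding_dim=64, num_embeddings=512,
        num_hiddens=256, num_residual_layers=2, num_residual_hiddens=256)
feat = ae.encode(poses)            # latent features, e.g. for embedding-space evaluation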
/nets/spg/blip.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2022, salesforce.com, inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause
5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | * By Junnan Li
7 | '''
8 | import warnings
9 | warnings.filterwarnings("ignore")
10 |
11 | from .vit import VisionTransformer, interpolate_pos_embed
12 | from .med import BertConfig, BertModel, BertLMHeadModel
13 | from transformers import BertTokenizer
14 |
15 | import torch
16 | from torch import nn
17 | import torch.nn.functional as F
18 |
19 | import os
20 | from urllib.parse import urlparse
21 | from timm.models.hub import download_cached_file
22 |
23 | class BLIP_Base(nn.Module):
24 | def __init__(self,
25 | med_config = 'configs/med_config.json',
26 | image_size = 224,
27 | vit = 'base',
28 | vit_grad_ckpt = False,
29 | vit_ckpt_layer = 0,
30 | ):
31 | """
32 | Args:
33 | med_config (str): path for the mixture of encoder-decoder model's configuration file
34 | image_size (int): input image size
35 | vit (str): model size of vision transformer
36 | """
37 | super().__init__()
38 |
39 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
40 | self.tokenizer = init_tokenizer()
41 | med_config = BertConfig.from_json_file(med_config)
42 | med_config.encoder_width = vision_width
43 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
44 |
45 |
46 | def forward(self, image, caption, mode):
47 |
48 | assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
49 | text = self.tokenizer(caption, return_tensors="pt").to(image.device)
50 |
51 | if mode=='image':
52 | # return image features
53 | image_embeds = self.visual_encoder(image)
54 | return image_embeds
55 |
56 | elif mode=='text':
57 | # return text features
58 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
59 | return_dict = True, mode = 'text')
60 | return text_output.last_hidden_state
61 |
62 | elif mode=='multimodal':
63 |             # return multimodal features
64 | image_embeds = self.visual_encoder(image)
65 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
66 |
67 | text.input_ids[:,0] = self.tokenizer.enc_token_id
68 | output = self.text_encoder(text.input_ids,
69 | attention_mask = text.attention_mask,
70 | encoder_hidden_states = image_embeds,
71 | encoder_attention_mask = image_atts,
72 | return_dict = True,
73 | )
74 | return output.last_hidden_state
75 |
76 |
77 |
78 | class BLIP_Decoder(nn.Module):
79 | def __init__(self,
80 | med_config = 'configs/med_config.json',
81 | image_size = 384,
82 | vit = 'base',
83 | vit_grad_ckpt = False,
84 | vit_ckpt_layer = 0,
85 | prompt = 'a picture of ',
86 | ):
87 | """
88 | Args:
89 | med_config (str): path for the mixture of encoder-decoder model's configuration file
90 | image_size (int): input image size
91 | vit (str): model size of vision transformer
92 | """
93 | super().__init__()
94 |
95 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
96 | self.tokenizer = init_tokenizer()
97 | med_config = BertConfig.from_json_file(med_config)
98 | med_config.encoder_width = vision_width
99 | self.text_decoder = BertLMHeadModel(config=med_config)
100 |
101 | self.prompt = prompt
102 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
103 |
104 |
105 | def forward(self, image, caption):
106 |
107 | image_embeds = self.visual_encoder(image)
108 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
109 |
110 | text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
111 |
112 | text.input_ids[:,0] = self.tokenizer.bos_token_id
113 |
114 | decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
115 | decoder_targets[:,:self.prompt_length] = -100
116 |
117 | decoder_output = self.text_decoder(text.input_ids,
118 | attention_mask = text.attention_mask,
119 | encoder_hidden_states = image_embeds,
120 | encoder_attention_mask = image_atts,
121 | labels = decoder_targets,
122 | return_dict = True,
123 | )
124 | loss_lm = decoder_output.loss
125 |
126 | return loss_lm
127 |
128 | def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
129 | image_embeds = self.visual_encoder(image)
130 |
131 | if not sample:
132 | image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
133 |
134 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
135 | model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
136 |
137 | prompt = [self.prompt] * image.size(0)
138 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
139 | input_ids[:,0] = self.tokenizer.bos_token_id
140 | input_ids = input_ids[:, :-1]
141 |
142 | if sample:
143 | #nucleus sampling
144 | outputs = self.text_decoder.generate(input_ids=input_ids,
145 | max_length=max_length,
146 | min_length=min_length,
147 | do_sample=True,
148 | top_p=top_p,
149 | num_return_sequences=1,
150 | eos_token_id=self.tokenizer.sep_token_id,
151 | pad_token_id=self.tokenizer.pad_token_id,
152 | repetition_penalty=1.1,
153 | **model_kwargs)
154 | else:
155 | #beam search
156 | outputs = self.text_decoder.generate(input_ids=input_ids,
157 | max_length=max_length,
158 | min_length=min_length,
159 | num_beams=num_beams,
160 | eos_token_id=self.tokenizer.sep_token_id,
161 | pad_token_id=self.tokenizer.pad_token_id,
162 | repetition_penalty=repetition_penalty,
163 | **model_kwargs)
164 |
165 | captions = []
166 | for output in outputs:
167 | caption = self.tokenizer.decode(output, skip_special_tokens=True)
168 | captions.append(caption[len(self.prompt):])
169 | return captions
170 |
171 |
172 | def blip_decoder(pretrained='',**kwargs):
173 | model = BLIP_Decoder(**kwargs)
174 | if pretrained:
175 | model,msg = load_checkpoint(model,pretrained)
176 | assert(len(msg.missing_keys)==0)
177 | return model
178 |
179 | def blip_feature_extractor(pretrained='',**kwargs):
180 | model = BLIP_Base(**kwargs)
181 | if pretrained:
182 | model,msg = load_checkpoint(model,pretrained)
183 | assert(len(msg.missing_keys)==0)
184 | return model
185 |
186 | def init_tokenizer():
187 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
188 | tokenizer.add_special_tokens({'bos_token':'[DEC]'})
189 | tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
190 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
191 | return tokenizer
192 |
193 |
194 | def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
195 |
196 | assert vit in ['base', 'large'], "vit parameter must be base or large"
197 | if vit=='base':
198 | vision_width = 768
199 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
200 | num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
201 | drop_path_rate=0 or drop_path_rate
202 | )
203 | elif vit=='large':
204 | vision_width = 1024
205 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
206 | num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
207 | drop_path_rate=0.1 or drop_path_rate
208 | )
209 | return visual_encoder, vision_width
210 |
211 | def is_url(url_or_filename):
212 | parsed = urlparse(url_or_filename)
213 | return parsed.scheme in ("http", "https")
214 |
215 | def load_checkpoint(model,url_or_filename):
216 | if is_url(url_or_filename):
217 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
218 | checkpoint = torch.load(cached_file, map_location='cpu')
219 | elif os.path.isfile(url_or_filename):
220 | checkpoint = torch.load(url_or_filename, map_location='cpu')
221 | else:
222 | raise RuntimeError('checkpoint url or path is invalid')
223 |
224 | state_dict = checkpoint['model']
225 |
226 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
227 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
228 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
229 | model.visual_encoder_m)
230 | for key in model.state_dict().keys():
231 | if key in state_dict.keys():
232 | if state_dict[key].shape!=model.state_dict()[key].shape:
233 | del state_dict[key]
234 |
235 | msg = model.load_state_dict(state_dict,strict=False)
236 | print('load checkpoint from %s'%url_or_filename)
237 | return model,msg
238 |
239 |
--------------------------------------------------------------------------------
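A hedged usage sketch of the BLIP wrappers above, illustrative only: the med_config path and the image tensor are placeholders, pretrained='' skips checkpoint loading, and the snippet assumes the .vit/.med helper modules imported at the top of blip.py plus the HuggingFace bert-base-uncased tokenizer are available.

import torch

from nets.spg.blip import blip_decoder

# Placeholder config path; a real med_config JSON must exist at this location.
model = blip_decoder(pretrained='', med_config='configs/med_config.json',
                     image_size=384, vit='base', prompt='a picture of ')
model.eval()

image = torch.randn(1, 3, 384, 384)    # one normalized RGB image (random placeholder)
with torch.no_grad():
    captions = model.generate(image, sample=False, num_beams=3,
                              max_length=30, min_length=10)
print(captions)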
/data_utils/dataloader_torch.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.getcwd())
4 |
5 | from tqdm import tqdm
6 | from data_utils.utils import *
7 | import torch.utils.data as data
8 | from data_utils.mesh_dataset import SmplxDataset
9 | from transformers import Wav2Vec2Processor
10 |
11 |
12 | class MultiVidData():
13 | def __init__(self,
14 | data_root,
15 | speakers,
16 | split='train',
17 | limbscaling=False,
18 | normalization=False,
19 | norm_method='new',
20 | split_trans_zero=False,
21 | num_frames=25,
22 | num_pre_frames=25,
23 | num_generate_length=None,
24 | aud_feat_win_size=None,
25 | aud_feat_dim=64,
26 | feat_method='mel_spec',
27 | context_info=False,
28 | smplx=False,
29 | audio_sr=16000,
30 | convert_to_6d=False,
31 | expression=False,
32 | config=None
33 | ):
34 | self.data_root = data_root
35 | self.speakers = speakers
36 | self.split = split
37 | if split == 'pre':
38 | self.split = 'train'
39 | self.norm_method=norm_method
40 | self.normalization = normalization
41 | self.limbscaling = limbscaling
42 | self.convert_to_6d = convert_to_6d
43 | self.num_frames=num_frames
44 | self.num_pre_frames=num_pre_frames
45 | if num_generate_length is None:
46 | self.num_generate_length = num_frames
47 | else:
48 | self.num_generate_length = num_generate_length
49 | self.split_trans_zero=split_trans_zero
50 |
51 | dataset = SmplxDataset
52 |
53 | if self.split_trans_zero:
54 | self.trans_dataset_list = []
55 | self.zero_dataset_list = []
56 | else:
57 | self.all_dataset_list = []
58 | self.dataset={}
59 | self.complete_data=[]
60 | self.config=config
61 | load_mode=self.config.dataset_load_mode
62 |
63 | ######################load with pickle file
64 | if load_mode=='pickle':
65 | import pickle
66 | import subprocess
67 |
68 | # store_file_path='/tmp/store.pkl'
69 | # cp /is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts/store.pkl /tmp/store.pkl
70 | # subprocess.run(f'cp /is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts/store.pkl {store_file_path}',shell=True)
71 |
72 | # f = open(self.config.store_file_path, 'rb+')
73 | f = open(self.split+config.Data.pklname, 'rb+')
74 | self.dataset=pickle.load(f)
75 | f.close()
76 | for key in self.dataset:
77 | self.complete_data.append(self.dataset[key].complete_data)
78 | ######################load with pickle file
79 |
80 | ######################load with a csv file
81 | elif load_mode=='csv':
82 |
83 |             # Imported from a personal code folder; to be cleaned up and integrated later.
84 | try:
85 | sys.path.append(self.config.config_root_path)
86 | from config import config_path
87 | from csv_parser import csv_parse
88 |
89 | except ImportError as e:
90 | print(f'err: {e}')
91 | raise ImportError('config root path error...')
92 |
93 |
94 | for speaker_name in self.speakers:
95 |                 import pandas as pd
96 |                 df_intervals = pd.read_csv(self.config.voca_csv_file_path)
97 |                 df_intervals = df_intervals[df_intervals['speaker'] == speaker_name]
98 | df_intervals = df_intervals[df_intervals['dataset'] == self.split]
99 |
100 | print(f'speaker {speaker_name} train interval length: {len(df_intervals)}')
101 | for iter_index, (_, interval) in tqdm(
102 | (enumerate(df_intervals.iterrows())),desc=f'load {speaker_name}'
103 | ):
104 |
105 | (
106 | interval_index,
107 | interval_speaker,
108 | interval_video_fn,
109 | interval_id,
110 |
111 | start_time,
112 | end_time,
113 | duration_time,
114 | start_time_10,
115 | over_flow_flag,
116 | short_dur_flag,
117 |
118 | big_video_dir,
119 | small_video_dir_name,
120 | speaker_video_path,
121 |
122 | voca_basename,
123 | json_basename,
124 | wav_basename,
125 | voca_top_clip_path,
126 | voca_json_clip_path,
127 | voca_wav_clip_path,
128 |
129 | audio_output_fn,
130 | image_output_path,
131 | pifpaf_output_path,
132 | mp_output_path,
133 | op_output_path,
134 | deca_output_path,
135 | pixie_output_path,
136 | cam_output_path,
137 | ours_output_path,
138 | merge_output_path,
139 | multi_output_path,
140 | gt_output_path,
141 | ours_images_path,
142 | pkl_fil_path,
143 | )=csv_parse(interval)
144 |
145 | if not os.path.exists(pkl_fil_path) or not os.path.exists(audio_output_fn):
146 | continue
147 |
148 | key=f'{interval_video_fn}/{small_video_dir_name}'
149 | self.dataset[key] = dataset(
150 | data_root=pkl_fil_path,
151 | speaker=speaker_name,
152 | audio_fn=audio_output_fn,
153 | audio_sr=audio_sr,
154 | fps=num_frames,
155 | feat_method=feat_method,
156 | audio_feat_dim=aud_feat_dim,
157 | train=(self.split == 'train'),
158 | load_all=True,
159 | split_trans_zero=self.split_trans_zero,
160 | limbscaling=self.limbscaling,
161 | num_frames=self.num_frames,
162 | num_pre_frames=self.num_pre_frames,
163 | num_generate_length=self.num_generate_length,
164 | audio_feat_win_size=aud_feat_win_size,
165 | context_info=context_info,
166 | convert_to_6d=convert_to_6d,
167 | expression=expression,
168 | config=self.config
169 | )
170 | self.complete_data.append(self.dataset[key].complete_data)
171 | ######################load with a csv file
172 |
173 | ######################origin load method
174 | elif load_mode=='json':
175 |
176 | # if self.split == 'train':
177 | # import pickle
178 | # f = open('store.pkl', 'rb+')
179 | # self.dataset=pickle.load(f)
180 | # f.close()
181 | # for key in self.dataset:
182 | # self.complete_data.append(self.dataset[key].complete_data)
183 |             # else:
184 | # if config.Model.model_type == 'face':
185 | am = Wav2Vec2Processor.from_pretrained("/mnt/nj-1/usr/xuanqing/pws/dataset/wav2vec2-xls-r-300m-phoneme")
186 | am_sr = 16000
187 | # else:
188 | # am, am_sr = None, None
189 | for speaker_name in self.speakers:
190 | speaker_root = os.path.join(self.data_root, speaker_name)
191 |
192 |             videos = os.listdir(speaker_root)
193 | print(videos)
194 |
195 |             haode = huaide = 0  # haode: clips loaded successfully, huaide: clips skipped (missing files)
196 |
197 | for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
198 | source_vid=vid
199 | # vid_pth=os.path.join(speaker_root, source_vid, 'images/half', self.split)
200 | vid_pth = os.path.join(speaker_root, source_vid, self.split)
201 | if smplx == 'pose':
202 | seqs = [s for s in os.listdir(vid_pth) if (s.startswith('clip'))]
203 | else:
204 |                     try:
205 |                         seqs = os.listdir(vid_pth)
206 |                     except OSError:
207 |                         continue
208 |
209 | for s in seqs:
210 | seq_root=os.path.join(vid_pth, s)
211 | key = seq_root # correspond to clip******
212 | audio_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.wav' % (s))
213 | motion_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.pkl' % (s))
214 | video_fname = os.path.join(speaker_root, source_vid, self.split, s,'embedding_viclip_base_200m.npy')
215 | if not os.path.isfile(audio_fname) or not os.path.isfile(motion_fname):
216 | huaide = huaide + 1
217 | continue
218 |
219 | self.dataset[key]=dataset(
220 | data_root=seq_root,
221 | speaker=speaker_name,
222 | motion_fn=motion_fname,
223 | video_fn = video_fname,
224 | audio_fn=audio_fname,
225 | audio_sr=audio_sr,
226 | fps=num_frames,
227 | feat_method=feat_method,
228 | audio_feat_dim=aud_feat_dim,
229 | train=(self.split=='train'),
230 | load_all=True,
231 | split_trans_zero=self.split_trans_zero,
232 | limbscaling=self.limbscaling,
233 | num_frames=self.num_frames,
234 | num_pre_frames=self.num_pre_frames,
235 | num_generate_length=self.num_generate_length,
236 | audio_feat_win_size=aud_feat_win_size,
237 | context_info=context_info,
238 | convert_to_6d=convert_to_6d,
239 | expression=expression,
240 | config=self.config,
241 | am=am,
242 | am_sr=am_sr,
243 | whole_video=config.Data.whole_video
244 | )
245 | self.complete_data.append(self.dataset[key].complete_data)
246 | haode = haode + 1
247 |             print("skipped clips (huaide): {}, usable clips (haode): {}".format(huaide, haode))
248 | import pickle
249 |
250 | f = open(self.split+config.Data.pklname, 'wb')
251 | pickle.dump(self.dataset, f)
252 | f.close()
253 | ######################origin load method
254 |
255 | self.complete_data=np.concatenate(self.complete_data, axis=0)
256 |
257 | # assert self.complete_data.shape[-1] == (12+21+21)*2
258 | self.normalize_stats = {}
259 |
260 | self.data_mean = None
261 | self.data_std = None
262 |
263 | def get_dataset(self):
264 | self.normalize_stats['mean'] = self.data_mean
265 | self.normalize_stats['std'] = self.data_std
266 |
267 | for key in list(self.dataset.keys()):
268 | if self.dataset[key].complete_data.shape[0] < self.num_generate_length:
269 | continue
270 | self.dataset[key].num_generate_length = self.num_generate_length
271 | self.dataset[key].get_dataset(self.normalization, self.normalize_stats, self.split)
272 | self.all_dataset_list.append(self.dataset[key].all_dataset)
273 |
274 | if self.split_trans_zero:
275 | self.trans_dataset = data.ConcatDataset(self.trans_dataset_list)
276 | self.zero_dataset = data.ConcatDataset(self.zero_dataset_list)
277 | else:
278 | self.all_dataset = data.ConcatDataset(self.all_dataset_list)
279 |
280 |
281 |
282 |
--------------------------------------------------------------------------------
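To close the section, a sketch of how MultiVidData is typically driven. Paths, speaker names, and frame counts are placeholders; load_JsonConfig is assumed to live in trainer.config; and note that the 'json' load mode above hard-codes a local Wav2Vec2Processor path, so the call only succeeds once that checkpoint (or an equivalent HuggingFace id) is available.

import torch.utils.data as data

from data_utils.dataloader_torch import MultiVidData
from trainer.config import load_JsonConfig      # assumed location of the config loader

config = load_JsonConfig('./config/body_pixel.json')   # placeholder config path
dataset = MultiVidData(
    data_root=config.Data.data_root,
    speakers=['speaker_placeholder'],
    split='train',
    num_frames=88,                    # placeholder clip length
    aud_feat_dim=64,
    feat_method='mel_spec',
    convert_to_6d=False,
    expression=True,
    config=config,
)
dataset.get_dataset()                 # builds the ConcatDataset over all loaded clips
loader = data.DataLoader(dataset.all_dataset, batch_size=32,
                         shuffle=True, drop_last=True)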