├── .gitignore
├── README.md
├── __init__.py
├── config
│   ├── LS3DCG.json
│   ├── body_pixel.json
│   ├── body_vq.json
│   └── face.json
├── data_utils
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── consts.cpython-37.pyc
│   │   ├── dataloader_torch.cpython-37.pyc
│   │   ├── lower_body.cpython-37.pyc
│   │   ├── mesh_dataset.cpython-37.pyc
│   │   ├── rotation_conversion.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── apply_split.py
│   ├── axis2matrix.py
│   ├── consts.py
│   ├── dataloader_torch.py
│   ├── dataset_preprocess.py
│   ├── get_j.py
│   ├── hand_component.json
│   ├── lower_body.py
│   ├── mesh_dataset.py
│   ├── rotation_conversion.py
│   ├── split_more_than_2s.pkl
│   ├── split_train_val_test.py
│   ├── train_val_test.json
│   └── utils.py
├── demo
│   ├── 1st-page
│   │   ├── 1st-page-upper.mp4
│   │   └── 1st-page-upper.npy
│   ├── french
│   │   ├── french.mp4
│   │   └── french.npy
│   ├── rich
│   │   ├── rich.mp4
│   │   └── rich.npy
│   ├── song
│   │   ├── cut.mp4
│   │   ├── song.mp4
│   │   └── song.npy
│   └── style
│       ├── chemistry.mp4
│       ├── chemistry.npy
│       ├── conan.mp4
│       ├── conan.npy
│       ├── diversity.mp4
│       ├── diversity.npy
│       ├── face.mp4
│       ├── face.npy
│       ├── oliver.mp4
│       ├── oliver.npy
│       ├── seth.mp4
│       └── seth.npy
├── demo_audio
│   ├── 1st-page.wav
│   ├── french.wav
│   ├── rich.wav
│   ├── rich_short.wav
│   ├── song.wav
│   └── style.wav
├── evaluation
│   ├── FGD.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── metrics.cpython-37.pyc
│   ├── diversity_LVD.py
│   ├── get_quality_samples.py
│   ├── metrics.py
│   ├── mode_transition.py
│   ├── peak_velocity.py
│   └── util.py
├── losses
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── losses.cpython-37.pyc
│   └── losses.py
├── nets
│   ├── LS3DCG.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── base.cpython-37.pyc
│   │   ├── init_model.cpython-37.pyc
│   │   ├── layers.cpython-37.pyc
│   │   ├── smplx_body_pixel.cpython-37.pyc
│   │   ├── smplx_body_vq.cpython-37.pyc
│   │   ├── smplx_face.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── base.py
│   ├── body_ae.py
│   ├── init_model.py
│   ├── layers.py
│   ├── smplx_body_pixel.py
│   ├── smplx_body_vq.py
│   ├── smplx_face.py
│   ├── spg
│   │   ├── __pycache__
│   │   │   ├── gated_pixelcnn_v2.cpython-37.pyc
│   │   │   ├── s2g_face.cpython-37.pyc
│   │   │   ├── s2glayers.cpython-37.pyc
│   │   │   ├── vqvae_1d.cpython-37.pyc
│   │   │   ├── vqvae_modules.cpython-37.pyc
│   │   │   └── wav2vec.cpython-37.pyc
│   │   ├── gated_pixelcnn_v2.py
│   │   ├── s2g_face.py
│   │   ├── s2glayers.py
│   │   ├── vqvae_1d.py
│   │   ├── vqvae_modules.py
│   │   └── wav2vec.py
│   └── utils.py
├── requirements.txt
├── scripts
│   ├── .idea
│   │   ├── __init__.py
│   │   ├── aws.xml
│   │   ├── deployment.xml
│   │   ├── get_prevar.py
│   │   ├── inspectionProfiles
│   │   │   ├── Project_Default.xml
│   │   │   └── profiles_settings.xml
│   │   ├── lower body
│   │   ├── modules.xml
│   │   ├── scripts.iml
│   │   ├── test.png
│   │   ├── testtext.py
│   │   ├── vcs.xml
│   │   └── workspace.xml
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── diversity.cpython-37.pyc
│   ├── continuity.py
│   ├── demo.py
│   ├── diversity.py
│   ├── test_body.py
│   ├── test_face.py
│   ├── test_vq.py
│   └── train.py
├── test_body.sh
├── test_face.sh
├── train_body_pixel.sh
├── train_body_vq.sh
├── train_face.sh
├── trainer
│   ├── Trainer.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── Trainer.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── config.cpython-37.pyc
│   │   └── options.cpython-37.pyc
│   ├── config.py
│   ├── options.py
│   └── training_config.cfg
├── visualise.sh
├── visualise
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── rendering.cpython-37.pyc
│   ├── rendering.py
│   ├── smplx
│   │   ├── SMPLX_to_J14.pkl
│   │   └── smplx_extra_joints.yaml
│   └── teaser_01.png
└── voca
    ├── __pycache__
    │   └── rendering.cpython-37.pyc
    └── rendering.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /experiments
2 | visualise/smplx/SMPLX_NEUTRAL.npz
3 | test_6d_mfcc.pkl
4 | test_3d_mfcc.pkl
5 | visualise/video
6 | .idea
7 | .git
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TalkSHOW: Generating Holistic 3D Human Motion from Speech [CVPR2023]
2 |
3 | The official PyTorch implementation of the **CVPR2023** paper [**"Generating Holistic 3D Human Motion from Speech"**](https://arxiv.org/abs/2212.04420).
4 |
5 | Please visit our [**webpage**](https://talkshow.is.tue.mpg.de/) for more details.
6 |
7 | 
8 |
9 | ## Highlights
10 |
11 | We directly provide the input audio and our generated output for the demo data; you can find them in `/demo/` and `/demo_audio/`. So far, TalkSHOW generalizes well to English, French, and songs, and we look forward to adding more demos.
12 |
13 | You can directly use the generated motion to animate your 3D character or your own digital avatar. We will provide more demos, so please stay tuned. Pull requests are very welcome.
14 |
15 | ## Notes
16 |
17 | We use 100-dimensional parameters for the SMPL-X facial expression. If you need a different number of expression dimensions, you can convert with this script:
18 |
19 | ```
20 | https://github.com/yhw-yhw/SHOW/blob/main/cvt_exp_dim_tool.py
21 | ```
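
A minimal sketch of the idea (not the script above; it assumes the expression coefficients are ordered PCA components, so a lower-dimensional parameter vector is simply a prefix of the 100-dimensional one — use the linked tool for an exact conversion):

```python
import numpy as np

def convert_expression_dim(expression: np.ndarray, target_dim: int) -> np.ndarray:
    """Truncate or zero-pad SMPL-X expression coefficients along the last axis."""
    cur_dim = expression.shape[-1]
    if target_dim <= cur_dim:
        # keep only the first target_dim components
        return expression[..., :target_dim]
    # pad the missing components with zeros
    pad_shape = expression.shape[:-1] + (target_dim - cur_dim,)
    pad = np.zeros(pad_shape, dtype=expression.dtype)
    return np.concatenate([expression, pad], axis=-1)

# e.g. reduce the 100-dim expression block predicted by TalkSHOW (pred[:, 165:265]) to 10 dims:
# expr_10 = convert_expression_dim(pred[:, 165:265], 10)
```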
22 |
23 | ## TODO
24 |
25 | - [x] [🤗Hugging Face Demo](https://huggingface.co/spaces/feifeifeiliu/TalkSHOW)
26 | - [ ] Animated 2D videos by the generated motion from TalkSHOW.
27 |
28 |
29 | ## Getting started
30 |
31 | The training code was tested on `Ubuntu 18.04.5 LTS` and the visualization code was tested on `Windows 10`. It requires:
32 |
33 | * Python 3.7
34 | * conda3 or miniconda3
35 | * CUDA capable GPU (one is enough)
36 |
37 |
38 |
39 | ### 1. Setup environment
40 |
41 | Clone the repo:
42 | ```bash
43 | git clone https://github.com/yhw-yhw/TalkSHOW
44 | cd TalkSHOW
45 | ```
46 | Create conda environment:
47 | ```bash
48 | conda create --name talkshow python=3.7
49 | conda activate talkshow
50 | ```
51 | Please install PyTorch (v1.10.1), then install the remaining dependencies:
52 |
53 |     pip install -r requirements.txt
54 |
55 | Please also install [**MPI-Mesh**](https://github.com/MPI-IS/mesh).
56 |
57 | ### 2. Get data
58 |
59 | Please note that if you only want to generate demo videos, you can skip this step and directly download the pretrained models.
60 |
61 | Download [**SHOW_dataset_v1.0.zip**](https://download.is.tue.mpg.de/download.php?domain=talkshow&resume=1&sfile=SHOW_dataset_v1.0.zip) from the [**TalkSHOW download webpage**](https://talkshow.is.tue.mpg.de/download.php)
62 | and unzip it using ``for i in $(ls *.tar.gz);do tar xvf $i;done``.
63 |
64 | ~~Run ``python data_utils/dataset_preprocess.py`` to check and split dataset.
65 | Modify ``data_root`` in ``config/*.json`` to the dataset-path.~~
66 |
67 | Modify ``data_root`` in ``data_utils/apply_split.py`` to the dataset path and run it to apply ``data_utils/split_more_than_2s.pkl`` to the dataset.
68 |
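For example, a minimal invocation (an assumption for illustration: the script opens ``split_more_than_2s.pkl`` with a path relative to its own folder, so run it from ``data_utils``):

```bash
# after setting data_root inside data_utils/apply_split.py
cd data_utils
python apply_split.py
```
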
69 | We will update the benchmark soon.
70 |
71 | ### 3. Download the pretrained models (Optional)
72 |
73 | Download [**pretrained models**](https://drive.google.com/file/d/1bC0ZTza8HOhLB46WOJ05sBywFvcotDZG/view?usp=sharing),
74 | unzip and place it in the TalkSHOW folder, i.e. ``path-to-TalkSHOW/experiments``.
75 |
76 | ### 4. Training
77 | Please note that loading the data for the first time can be quite slow. Once that first pass has completed, setting ``dataset_load_mode`` to ``pickle`` in ``config/[config_name].json`` (as in the snippet below) makes subsequent loading much faster.
78 |
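For example, the relevant key after the first load (only this key is shown; the configs in this repository ship with either ``"json"`` or ``"pickle"`` depending on the file):

```json
{
  "dataset_load_mode": "pickle"
}
```
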
79 |     # 1. Train the VQ-VAEs.
80 |     bash train_body_vq.sh
81 |     # 2. Train the PixelCNN. Please modify "Model:vq_path" in config/body_pixel.json to the path of the trained VQ-VAEs.
82 |     bash train_body_pixel.sh
83 |     # 3. Train the face generator.
84 |     bash train_face.sh
85 |
86 | ### 5. Testing
87 |
88 | Modify the arguments in ``test_face.sh`` and ``test_body.sh``, then run
89 |
90 |     bash test_face.sh
91 |     bash test_body.sh
92 |
93 | ### 6. Visualization
94 |
95 | If you ssh into a Linux machine, a NotImplementedError might occur during rendering. In this case, please refer to this [**issue**](https://github.com/MPI-IS/mesh/issues/66) for a fix.
96 |
97 | Download the [**smplx model**](https://drive.google.com/file/d/1Ly_hQNLQcZ89KG0Nj4jYZwccQiimSUVn/view?usp=share_link) (please register on the official [**SMPL-X webpage**](https://smpl-x.is.tue.mpg.de) before using it)
98 | and place it in ``path-to-TalkSHOW/visualise/smplx_model``.
99 | To visualise the test set and the generated results (in each video, left: generated result, right: ground truth), run the command below.
100 | The videos and generated motion data are saved in ``./visualise/video/body-pixel``:
101 |
102 |     bash visualise.sh
103 |
104 | If you ssh into a Linux machine, an error about the OffscreenRenderer might also occur; the same [**issue**](https://github.com/MPI-IS/mesh/issues/66) covers this case.
105 |
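One common workaround for headless rendering with pyrender (an assumption based on general pyrender usage, not taken from the linked issue, which may recommend a different fix) is to force an off-screen rendering backend before running the visualisation:

```bash
# use software (OSMesa) rendering; "egl" is an alternative when GPU EGL is available
export PYOPENGL_PLATFORM=osmesa
bash visualise.sh
```
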
106 | To reproduce the demo videos, run
107 | ```bash
108 | # the whole body demo
109 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/1st-page.wav --id 0 --whole_body
110 | # the face demo
111 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 0 --only_face
112 | # the identity-specific demo
113 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 0
114 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 1
115 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 2
116 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 3 --stand
117 | # the diversity demo
118 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 0 --num_samples 12
119 | # the french demo
120 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/french.wav --id 0
121 | # the synthetic speech demo
122 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/rich.wav --id 0
123 | # the song demo
124 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/song.wav --id 0
125 | ```
126 | ### 7. Baseline
127 |
128 | To train our reproduction of "Learning Speech-driven 3D Conversational Gestures from Video" (Habibie et al.), run
129 | ```bash
130 | python -W ignore scripts/train.py --speakers oliver seth conan chemistry --config_file ./config/LS3DCG.json
131 | ```
132 |
133 | For visualization with the pretrained model, download the above [pretrained models](#3-download-the-pretrained-models--optional-) and run
134 | ```bash
135 | python scripts/demo.py --config_file ./config/LS3DCG.json --infer --audio_file ./demo_audio/style.wav --body_model_name s2g_LS3DCG --body_model_path experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth --id 0
136 | ```
137 |
138 | ## Citation
139 | If you find our work useful to your research, please consider citing:
140 | ```
141 | @inproceedings{yi2022generating,
142 | title={Generating Holistic 3D Human Motion from Speech},
143 | author={Yi, Hongwei and Liang, Hualin and Liu, Yifei and Cao, Qiong and Wen, Yandong and Bolkart, Timo and Tao, Dacheng and Black, Michael J},
144 | booktitle={CVPR},
145 | year={2023}
146 | }
147 | ```
148 |
149 | ## Acknowledgements
150 | For functions or scripts that are based on external sources, we acknowledge the origin individually in each file.
151 | Here are some great resources we benefit from:
152 | - [Freeform](https://github.com/TheTempAccount/Co-Speech-Motion-Generation) for the training pipeline
153 | - [MPI-Mesh](https://github.com/MPI-IS/mesh), [Pyrender](https://github.com/mmatl/pyrender), [Smplx](https://github.com/vchoutas/smplx), [VOCA](https://github.com/TimoBolkart/voca) for rendering
154 | - [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base-960h) and [Faceformer](https://github.com/EvelynFan/FaceFormer) for the audio encoders
155 |
156 | ## Contact
157 | For questions, please contact talkshow@tue.mpg.de, hongwei.yi@tuebingen.mpg.de, fthualinliang@mail.scut.edu.cn, or ft_lyf@mail.scut.edu.cn.
158 |
159 | For commercial licensing, please contact ps-licensing@tue.mpg.de
160 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/__init__.py
--------------------------------------------------------------------------------
/config/LS3DCG.json:
--------------------------------------------------------------------------------
1 | {
2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3 | "dataset_load_mode": "pickle",
4 | "store_file_path": "store.pkl",
5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8 | "param": {
9 | "w_j": 1,
10 | "w_b": 1,
11 | "w_h": 1
12 | },
13 | "Data": {
14 | "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
15 | "pklname": "_3d_mfcc.pkl",
16 | "whole_video": false,
17 | "pose": {
18 | "normalization": false,
19 | "convert_to_6d": false,
20 | "norm_method": "all",
21 | "augmentation": false,
22 | "generate_length": 88,
23 | "pre_pose_length": 0,
24 | "pose_dim": 99,
25 | "expression": true
26 | },
27 | "aud": {
28 | "feat_method": "mfcc",
29 | "aud_feat_dim": 64,
30 | "aud_feat_win_size": null,
31 | "context_info": false
32 | }
33 | },
34 | "Model": {
35 | "model_type": "body",
36 | "model_name": "s2g_LS3DCG",
37 | "code_num": 2048,
38 | "AudioOpt": "Adam",
39 | "encoder_choice": "mfcc",
40 | "gan": false
41 | },
42 | "DataLoader": {
43 | "batch_size": 128,
44 | "num_workers": 0
45 | },
46 | "Train": {
47 | "epochs": 100,
48 | "max_gradient_norm": 5,
49 | "learning_rate": {
50 | "generator_learning_rate": 1e-4,
51 | "discriminator_learning_rate": 1e-4
52 | },
53 | "weights": {
54 | "keypoint_loss_weight": 1.0,
55 | "gan_loss_weight": 1.0
56 | }
57 | },
58 | "Log": {
59 | "save_every": 50,
60 | "print_every": 200,
61 | "name": "LS3DCG"
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/config/body_pixel.json:
--------------------------------------------------------------------------------
1 | {
2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3 | "dataset_load_mode": "json",
4 | "store_file_path": "store.pkl",
5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8 | "param": {
9 | "w_j": 1,
10 | "w_b": 1,
11 | "w_h": 1
12 | },
13 | "Data": {
14 | "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
15 | "pklname": "_3d_mfcc.pkl",
16 | "whole_video": false,
17 | "pose": {
18 | "normalization": false,
19 | "convert_to_6d": false,
20 | "norm_method": "all",
21 | "augmentation": false,
22 | "generate_length": 88,
23 | "pre_pose_length": 0,
24 | "pose_dim": 99,
25 | "expression": true
26 | },
27 | "aud": {
28 | "feat_method": "mfcc",
29 | "aud_feat_dim": 64,
30 | "aud_feat_win_size": null,
31 | "context_info": false
32 | }
33 | },
34 | "Model": {
35 | "model_type": "body",
36 | "model_name": "s2g_body_pixel",
37 | "composition": true,
38 | "code_num": 2048,
39 | "bh_model": true,
40 | "AudioOpt": "Adam",
41 | "encoder_choice": "mfcc",
42 | "gan": false,
43 | "vq_path": "./experiments/2022-10-31-smplx_S2G-body-vq-3d/ckpt-99.pth"
44 | },
45 | "DataLoader": {
46 | "batch_size": 128,
47 | "num_workers": 0
48 | },
49 | "Train": {
50 | "epochs": 100,
51 | "max_gradient_norm": 5,
52 | "learning_rate": {
53 | "generator_learning_rate": 1e-4,
54 | "discriminator_learning_rate": 1e-4
55 | }
56 | },
57 | "Log": {
58 | "save_every": 50,
59 | "print_every": 200,
60 | "name": "body-pixel2"
61 | }
62 | }
63 |
--------------------------------------------------------------------------------
/config/body_vq.json:
--------------------------------------------------------------------------------
1 | {
2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3 | "dataset_load_mode": "json",
4 | "store_file_path": "store.pkl",
5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8 | "param": {
9 | "w_j": 1,
10 | "w_b": 1,
11 | "w_h": 1
12 | },
13 | "Data": {
14 | "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
15 | "pklname": "_3d_mfcc.pkl",
16 | "whole_video": false,
17 | "pose": {
18 | "normalization": false,
19 | "convert_to_6d": false,
20 | "norm_method": "all",
21 | "augmentation": false,
22 | "generate_length": 88,
23 | "pre_pose_length": 0,
24 | "pose_dim": 99,
25 | "expression": true
26 | },
27 | "aud": {
28 | "feat_method": "mfcc",
29 | "aud_feat_dim": 64,
30 | "aud_feat_win_size": null,
31 | "context_info": false
32 | }
33 | },
34 | "Model": {
35 | "model_type": "body",
36 | "model_name": "s2g_body_vq",
37 | "composition": true,
38 | "code_num": 2048,
39 | "bh_model": true,
40 | "AudioOpt": "Adam",
41 | "encoder_choice": "mfcc",
42 | "gan": false
43 | },
44 | "DataLoader": {
45 | "batch_size": 128,
46 | "num_workers": 0
47 | },
48 | "Train": {
49 | "epochs": 100,
50 | "max_gradient_norm": 5,
51 | "learning_rate": {
52 | "generator_learning_rate": 1e-4,
53 | "discriminator_learning_rate": 1e-4
54 | }
55 | },
56 | "Log": {
57 | "save_every": 50,
58 | "print_every": 200,
59 | "name": "body-vq"
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
/config/face.json:
--------------------------------------------------------------------------------
1 | {
2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3 | "dataset_load_mode": "json",
4 | "store_file_path": "store.pkl",
5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8 | "param": {
9 | "w_j": 1,
10 | "w_b": 1,
11 | "w_h": 1
12 | },
13 | "Data": {
14 | "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
15 | "pklname": "_3d_wv2.pkl",
16 | "whole_video": true,
17 | "pose": {
18 | "normalization": false,
19 | "convert_to_6d": false,
20 | "norm_method": "all",
21 | "augmentation": false,
22 | "generate_length": 88,
23 | "pre_pose_length": 0,
24 | "pose_dim": 99,
25 | "expression": true
26 | },
27 | "aud": {
28 | "feat_method": "mfcc",
29 | "aud_feat_dim": 64,
30 | "aud_feat_win_size": null,
31 | "context_info": false
32 | }
33 | },
34 | "Model": {
35 | "model_type": "face",
36 | "model_name": "s2g_face",
37 | "AudioOpt": "SGD",
38 | "encoder_choice": "faceformer",
39 | "gan": false
40 | },
41 | "DataLoader": {
42 | "batch_size": 1,
43 | "num_workers": 0
44 | },
45 | "Train": {
46 | "epochs": 100,
47 | "max_gradient_norm": 5,
48 | "learning_rate": {
49 | "generator_learning_rate": 1e-4,
50 | "discriminator_learning_rate": 1e-4
51 | }
52 | },
53 | "Log": {
54 | "save_every": 50,
55 | "print_every": 1000,
56 | "name": "face"
57 | }
58 | }
59 |
--------------------------------------------------------------------------------
/data_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # from .dataloader_csv import MultiVidData as csv_data
2 | from .dataloader_torch import MultiVidData as torch_data
3 | from .utils import get_melspec, get_mfcc, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
--------------------------------------------------------------------------------
/data_utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/data_utils/__pycache__/consts.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/__pycache__/consts.cpython-37.pyc
--------------------------------------------------------------------------------
/data_utils/__pycache__/dataloader_torch.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/__pycache__/dataloader_torch.cpython-37.pyc
--------------------------------------------------------------------------------
/data_utils/__pycache__/lower_body.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/__pycache__/lower_body.cpython-37.pyc
--------------------------------------------------------------------------------
/data_utils/__pycache__/mesh_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/__pycache__/mesh_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/data_utils/__pycache__/rotation_conversion.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/__pycache__/rotation_conversion.cpython-37.pyc
--------------------------------------------------------------------------------
/data_utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/data_utils/apply_split.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import pickle
4 | import shutil
5 |
6 | speakers = ['seth', 'oliver', 'conan', 'chemistry']
7 | source_data_root = "../expressive_body-V0.7"
8 | data_root = "D:/Downloads/SHOW_dataset_v1.0/ExpressiveWholeBodyDatasetReleaseV1.0"
9 |
10 | f_read = open('split_more_than_2s.pkl', 'rb')
11 | f_save = open('none.pkl', 'wb')
12 | data_split = pickle.load(f_read)
13 | none_split = []
14 |
15 | train = val = test = 0
16 |
17 | for speaker_name in speakers:
18 | speaker_root = os.path.join(data_root, speaker_name)
19 |
20 | videos = [v for v in data_split[speaker_name]]
21 |
22 | for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
23 | for split in data_split[speaker_name][vid]:
24 | for seq in data_split[speaker_name][vid][split]:
25 |
26 | seq = seq.replace('\\', '/')
27 | old_file_path = os.path.join(data_root, speaker_name, vid, seq.split('/')[-1])
28 | old_file_path = old_file_path.replace('\\', '/')
29 | new_file_path = seq.replace(source_data_root.split('/')[-1], data_root.split('/')[-1])
30 | try:
31 | shutil.move(old_file_path, new_file_path)
32 | if split == 'train':
33 | train = train + 1
34 | elif split == 'test':
35 | test = test + 1
36 | elif split == 'val':
37 | val = val + 1
38 | except FileNotFoundError:
39 | none_split.append(old_file_path)
40 | print(f"The file {old_file_path} does not exist.")
41 | except shutil.Error:
42 | none_split.append(old_file_path)
43 | print(f"The file {old_file_path} could not be moved.")
44 |
45 | print(none_split.__len__())
46 | pickle.dump(none_split, f_save)
47 | f_save.close()
48 |
49 | print(train, val, test)
50 |
51 |
52 |
--------------------------------------------------------------------------------
/data_utils/axis2matrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import scipy.linalg as linalg
4 |
5 |
6 | def rotate_mat(axis, radian):
7 |
8 | a = np.cross(np.eye(3), axis / linalg.norm(axis) * radian)
9 |
10 | rot_matrix = linalg.expm(a)
11 |
12 | return rot_matrix
13 |
14 | def aaa2mat(axis, sin, cos):
15 | i = np.eye(3)
16 | nnt = np.dot(axis.T, axis)
17 | s = np.asarray([[0, -axis[0,2], axis[0,1]],
18 | [axis[0,2], 0, -axis[0,0]],
19 | [-axis[0,1], axis[0,0], 0]])
20 | r = cos * i + (1-cos)*nnt +sin * s
21 | return r
22 |
23 | rand_axis = np.asarray([[1,0,0]])
24 | # rotation angle
25 | r = math.pi/2
26 | # return the rotation matrix
27 | rot_matrix = rotate_mat(rand_axis, r)
28 | r2 = aaa2mat(rand_axis, np.sin(r), np.cos(r))
29 | print(rot_matrix)
--------------------------------------------------------------------------------
/data_utils/dataset_preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from tqdm import tqdm
4 | import shutil
5 | import torch
6 | import numpy as np
7 | import librosa
8 | import random
9 |
10 | speakers = ['seth', 'conan', 'oliver', 'chemistry']
11 | data_root = "../ExpressiveWholeBodyDatasetv1.0/"
12 | split = 'train'
13 |
14 |
15 |
16 | def split_list(full_list,shuffle=False,ratio=0.2):
17 | n_total = len(full_list)
18 | offset_0 = int(n_total * ratio)
19 | offset_1 = int(n_total * ratio * 2)
20 | if n_total==0 or offset_1<1:
21 | return [],full_list
22 | if shuffle:
23 | random.shuffle(full_list)
24 | sublist_0 = full_list[:offset_0]
25 | sublist_1 = full_list[offset_0:offset_1]
26 | sublist_2 = full_list[offset_1:]
27 | return sublist_0, sublist_1, sublist_2
28 |
29 |
30 | def moveto(list, file):
31 | for f in list:
32 | before, after = '/'.join(f.split('/')[:-1]), f.split('/')[-1]
33 | new_path = os.path.join(before, file)
34 | new_path = os.path.join(new_path, after)
35 | # os.makedirs(new_path)
36 | # os.path.isdir(new_path)
37 | # shutil.move(f, new_path)
38 |
39 | # copy to the new directory
40 | shutil.copytree(f, new_path)
41 | # delete the original files under train
42 | shutil.rmtree(f)
43 | return None
44 |
45 |
46 | def read_pkl(data):
47 | betas = np.array(data['betas'])
48 |
49 | jaw_pose = np.array(data['jaw_pose'])
50 | leye_pose = np.array(data['leye_pose'])
51 | reye_pose = np.array(data['reye_pose'])
52 | global_orient = np.array(data['global_orient']).squeeze()
53 | body_pose = np.array(data['body_pose_axis'])
54 | left_hand_pose = np.array(data['left_hand_pose'])
55 | right_hand_pose = np.array(data['right_hand_pose'])
56 |
57 | full_body = np.concatenate(
58 | (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1)
59 |
60 | expression = np.array(data['expression'])
61 | full_body = np.concatenate((full_body, expression), axis=1)
62 |
63 | if (full_body.shape[0] < 90) or (torch.isnan(torch.from_numpy(full_body)).sum() > 0):
64 | return 1
65 | else:
66 | return 0
67 |
68 |
69 | for speaker_name in speakers:
70 | speaker_root = os.path.join(data_root, speaker_name)
71 |
72 | videos = [v for v in os.listdir(speaker_root)]
73 | print(videos)
74 |
75 | haode = huaide = 0
76 | total_seqs = []
77 |
78 | for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
79 | # for vid in videos:
80 | source_vid = vid
81 | vid_pth = os.path.join(speaker_root, source_vid)
82 | # vid_pth = os.path.join(speaker_root, source_vid, 'images/half', split)
83 | t = os.path.join(speaker_root, source_vid, 'test')
84 | v = os.path.join(speaker_root, source_vid, 'val')
85 |
86 | # if os.path.exists(t):
87 | # shutil.rmtree(t)
88 | # if os.path.exists(v):
89 | # shutil.rmtree(v)
90 | try:
91 | seqs = [s for s in os.listdir(vid_pth)]
92 | except:
93 | continue
94 | # if len(seqs) == 0:
95 | # shutil.rmtree(os.path.join(speaker_root, source_vid))
96 | # None
97 | for s in seqs:
98 | quality = 0
99 | total_seqs.append(os.path.join(vid_pth,s))
100 | seq_root = os.path.join(vid_pth, s)
101 | key = seq_root # correspond to clip******
102 | audio_fname = os.path.join(speaker_root, source_vid, s, '%s.wav' % (s))
103 |
104 | # delete the data without audio or the audio file could not be read
105 | if os.path.isfile(audio_fname):
106 | try:
107 | audio = librosa.load(audio_fname)
108 | except:
109 | # print(key)
110 | shutil.rmtree(key)
111 | huaide = huaide + 1
112 | continue
113 | else:
114 | huaide = huaide + 1
115 | # print(key)
116 | shutil.rmtree(key)
117 | continue
118 |
119 | # check motion file
120 | motion_fname = os.path.join(speaker_root, source_vid, s, '%s.pkl' % (s))
121 | try:
122 | f = open(motion_fname, 'rb+')
123 | except:
124 | shutil.rmtree(key)
125 | huaide = huaide + 1
126 | continue
127 |
128 | data = pickle.load(f)
129 | w = read_pkl(data)
130 | f.close()
131 | quality = quality + w
132 |
133 | if w == 1:
134 | shutil.rmtree(key)
135 | # print(key)
136 | huaide = huaide + 1
137 | continue
138 |
139 | haode = haode + 1
140 |
141 | print("huaide:{}, haode:{}, total_seqs:{}".format(huaide, haode, total_seqs.__len__()))
142 |
143 | for speaker_name in speakers:
144 | speaker_root = os.path.join(data_root, speaker_name)
145 |
146 | videos = [v for v in os.listdir(speaker_root)]
147 | print(videos)
148 |
149 | haode = huaide = 0
150 | total_seqs = []
151 |
152 | for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
153 | # for vid in videos:
154 | source_vid = vid
155 | vid_pth = os.path.join(speaker_root, source_vid)
156 | try:
157 | seqs = [s for s in os.listdir(vid_pth)]
158 | except:
159 | continue
160 | for s in seqs:
161 | quality = 0
162 | total_seqs.append(os.path.join(vid_pth, s))
163 | print("total_seqs:{}".format(total_seqs.__len__()))
164 | # split the dataset
165 | test_list, val_list, train_list = split_list(total_seqs, True, 0.1)
166 | print(len(test_list), len(val_list), len(train_list))
167 | moveto(train_list, 'train')
168 | moveto(test_list, 'test')
169 | moveto(val_list, 'val')
170 |
171 |
--------------------------------------------------------------------------------
/data_utils/get_j.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def to3d(poses, config):
5 | if config.Data.pose.convert_to_6d:
6 | if config.Data.pose.expression:
7 | poses_exp = poses[:, -100:]
8 | poses = poses[:, :-100]
9 |
10 | poses = poses.reshape(poses.shape[0], -1, 5)
11 | sin, cos = poses[:, :, 3], poses[:, :, 4]
12 | pose_angle = torch.atan2(sin, cos)
13 | poses = (poses[:, :, :3] * pose_angle.unsqueeze(dim=-1)).reshape(poses.shape[0], -1)
14 |
15 | if config.Data.pose.expression:
16 | poses = torch.cat([poses, poses_exp], dim=-1)
17 | return poses
18 |
19 |
20 | def get_joint(smplx_model, betas, pred):
21 | joint = smplx_model(betas=betas.repeat(pred.shape[0], 1),
22 | expression=pred[:, 165:265],
23 | jaw_pose=pred[:, 0:3],
24 | leye_pose=pred[:, 3:6],
25 | reye_pose=pred[:, 6:9],
26 | global_orient=pred[:, 9:12],
27 | body_pose=pred[:, 12:75],
28 | left_hand_pose=pred[:, 75:120],
29 | right_hand_pose=pred[:, 120:165],
30 | return_verts=True)['joints']
31 | return joint
32 |
33 |
34 | def get_joints(smplx_model, betas, pred):
35 | if len(pred.shape) == 3:
36 | B = pred.shape[0]
37 | x = 4 if B>= 4 else B
38 | T = pred.shape[1]
39 | pred = pred.reshape(-1, 265)
40 | smplx_model.batch_size = L = T * x
41 |
42 | times = pred.shape[0] // smplx_model.batch_size
43 | joints = []
44 | for i in range(times):
45 | joints.append(get_joint(smplx_model, betas, pred[i*L:(i+1)*L]))
46 | joints = torch.cat(joints, dim=0)
47 | joints = joints.reshape(B, T, -1, 3)
48 | else:
49 | smplx_model.batch_size = pred.shape[0]
50 | joints = get_joint(smplx_model, betas, pred)
51 | return joints
--------------------------------------------------------------------------------
/data_utils/lower_body.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | lower_pose = torch.tensor(
5 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0747, -0.0158, -0.0152, -1.1826512813568115, 0.23866955935955048,
6 | 0.15146760642528534, -1.2604516744613647, -0.3160211145877838,
7 | -0.1603458970785141, 1.1654603481292725, 0.0, 0.0, 1.2521806955337524, 0.041598282754421234, -0.06312154978513718,
8 | 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
9 | lower_pose_stand = torch.tensor([
10 | 8.9759e-04, 7.1074e-04, -5.9163e-06, 8.9759e-04, 7.1074e-04, -5.9163e-06,
11 | 3.0747, -0.0158, -0.0152,
12 | -3.6665e-01, -8.8455e-03, 1.6113e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
13 | -3.9716e-01, -4.0229e-02, -1.2637e-01,
14 | 7.9163e-01, 6.8519e-02, -1.5091e-01, 7.9163e-01, 6.8519e-02, -1.5091e-01,
15 | 7.8632e-01, -4.3810e-02, 1.4375e-02,
16 | -1.0675e-01, 1.2635e-01, 1.6711e-02, -1.0675e-01, 1.2635e-01, 1.6711e-02, ])
17 | # lower_pose_stand = torch.tensor(
18 | # [6.4919e-02, 3.3018e-02, 1.7485e-02, 8.9759e-04, 7.1074e-04, -5.9163e-06,
19 | # 3.0747, -0.0158, -0.0152,
20 | # -3.3633e+00, -9.3915e-02, 3.0996e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
21 | # 1.1654603481292725, 0.0, 0.0,
22 | # 4.4167e-01, 6.7183e-03, -3.6379e-03, 7.9163e-01, 6.8519e-02, -1.5091e-01,
23 | # 0.0, 0.0, 0.0,
24 | # 2.2910e-02, -2.4797e-02, -5.5657e-03, -1.0675e-01, 1.2635e-01, 1.6711e-02,])
25 | lower_body = [0, 1, 3, 4, 6, 7, 9, 10]
26 | count_part = [6, 9, 12, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
27 | 29, 30, 31, 32, 33, 34, 35, 36, 37,
28 | 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54]
29 | fix_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
30 | 29,
31 | 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
32 | 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
33 | 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]
34 | all_index = np.ones(275)
35 | all_index[fix_index] = 0
36 | c_index = []
37 | i = 0
38 | for num in all_index:
39 | if num == 1:
40 | c_index.append(i)
41 | i = i + 1
42 | c_index = np.asarray(c_index)
43 |
44 | fix_index_3d = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
45 | 21, 22, 23, 24, 25, 26,
46 | 30, 31, 32, 33, 34, 35,
47 | 45, 46, 47, 48, 49, 50]
48 | all_index_3d = np.ones(165)
49 | all_index_3d[fix_index_3d] = 0
50 | c_index_3d = []
51 | i = 0
52 | for num in all_index_3d:
53 | if num == 1:
54 | c_index_3d.append(i)
55 | i = i + 1
56 | c_index_3d = np.asarray(c_index_3d)
57 |
58 | c_index_6d = []
59 | i = 0
60 | for num in all_index_3d:
61 | if num == 1:
62 | c_index_6d.append(2*i)
63 | c_index_6d.append(2 * i + 1)
64 | i = i + 1
65 | c_index_6d = np.asarray(c_index_6d)
66 |
67 |
68 | def part2full(input, stand=False):
69 | if stand:
70 | # lp = lower_pose_stand.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
71 | lp = torch.zeros_like(lower_pose)
72 | lp[6:9] = torch.tensor([3.0747, -0.0158, -0.0152])
73 | lp = lp.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
74 | else:
75 | lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
76 |
77 | input = torch.cat([input[:, :3],
78 | lp[:, :15],
79 | input[:, 3:6],
80 | lp[:, 15:21],
81 | input[:, 6:9],
82 | lp[:, 21:27],
83 | input[:, 9:12],
84 | lp[:, 27:],
85 | input[:, 12:]]
86 | , dim=1)
87 | return input
88 |
89 |
90 | def pred2poses(input, gt):
91 | input = torch.cat([input[:, :3],
92 | gt[0:1, 3:18].repeat(input.shape[0], 1),
93 | input[:, 3:6],
94 | gt[0:1, 21:27].repeat(input.shape[0], 1),
95 | input[:, 6:9],
96 | gt[0:1, 30:36].repeat(input.shape[0], 1),
97 | input[:, 9:12],
98 | gt[0:1, 39:45].repeat(input.shape[0], 1),
99 | input[:, 12:]]
100 | , dim=1)
101 | return input
102 |
103 |
104 | def poses2poses(input, gt):
105 | input = torch.cat([input[:, :3],
106 | gt[0:1, 3:18].repeat(input.shape[0], 1),
107 | input[:, 18:21],
108 | gt[0:1, 21:27].repeat(input.shape[0], 1),
109 | input[:, 27:30],
110 | gt[0:1, 30:36].repeat(input.shape[0], 1),
111 | input[:, 36:39],
112 | gt[0:1, 39:45].repeat(input.shape[0], 1),
113 | input[:, 45:]]
114 | , dim=1)
115 | return input
116 |
117 | def poses2pred(input, stand=False):
118 | if stand:
119 | lp = lower_pose_stand.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
120 | # lp = torch.zeros_like(lower_pose).unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
121 | else:
122 | lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
123 | input = torch.cat([input[:, :3],
124 | lp[:, :15],
125 | input[:, 18:21],
126 | lp[:, 15:21],
127 | input[:, 27:30],
128 | lp[:, 21:27],
129 | input[:, 36:39],
130 | lp[:, 27:],
131 | input[:, 45:]]
132 | , dim=1)
133 | return input
134 |
135 |
136 | rearrange = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]\
137 | # ,22, 23, 24, 25, 40, 26, 41,
138 | # 27, 42, 28, 43, 29, 44, 30, 45, 31, 46, 32, 47, 33, 48, 34, 49, 35, 50, 36, 51, 37, 52, 38, 53, 39, 54, 55,
139 | # 57, 56, 59, 58, 60, 63, 61, 64, 62, 65, 66, 71, 67, 72, 68, 73, 69, 74, 70, 75]
140 |
141 | symmetry = [0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1]#, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
142 | # 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
143 | # 1, 1, 1, 1, 1, 1]
144 |
--------------------------------------------------------------------------------
/data_utils/split_more_than_2s.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/split_more_than_2s.pkl
--------------------------------------------------------------------------------
/data_utils/split_train_val_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import shutil
4 |
5 | if __name__ =='__main__':
6 | id_list = "chemistry conan oliver seth"
7 | id_list = id_list.split(' ')
8 |
9 | old_root = '/home/usename/talkshow_data/ExpressiveWholeBodyDatasetReleaseV1.0'
10 | new_root = '/home/usename/talkshow_data/ExpressiveWholeBodyDatasetReleaseV1.0/talkshow_data_splited'
11 |
12 | with open('train_val_test.json') as f:
13 | split_info = json.load(f)
14 | phase_list = ['train', 'val', 'test']
15 | for phase in phase_list:
16 | phase_path_list = split_info[phase]
17 | for p in phase_path_list:
18 | old_path = os.path.join(old_root, p)
19 | if not os.path.exists(old_path):
20 | print(f'{old_path} not found, continue' )
21 | continue
22 | new_path = os.path.join(new_root, phase, p)
23 | dir_name = os.path.dirname(new_path)
24 | if not os.path.isdir(dir_name):
25 | os.makedirs(dir_name, exist_ok=True)
26 | shutil.move(old_path, new_path)
27 |
28 |
--------------------------------------------------------------------------------
/data_utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | # import librosa #has to do this cause librosa is not supported on my server
3 | import python_speech_features
4 | from scipy.io import wavfile
5 | from scipy import signal
6 | import librosa
7 | import torch
8 | import torchaudio as ta
9 | import torchaudio.functional as ta_F
10 | import torchaudio.transforms as ta_T
11 | # import pyloudnorm as pyln
12 |
13 |
14 | def load_wav_old(audio_fn, sr = 16000):
15 | sample_rate, sig = wavfile.read(audio_fn)
16 | if sample_rate != sr:
17 | result = int((sig.shape[0]) / sample_rate * sr)
18 | x_resampled = signal.resample(sig, result)
19 | x_resampled = x_resampled.astype(np.float64)
20 | return x_resampled, sr
21 |
22 | sig = sig / (2**15)
23 | return sig, sample_rate
24 |
25 |
26 | def get_mfcc(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None):
27 |
28 | y, sr = librosa.load(audio_fn, sr=sr, mono=True)
29 |
30 | if win_size is None:
31 | hop_len=int(sr / fps)
32 | else:
33 | hop_len=int(sr / win_size)
34 |
35 | n_fft=2048
36 |
37 | C = librosa.feature.mfcc(
38 | y = y,
39 | sr = sr,
40 | n_mfcc = n_mfcc,
41 | hop_length = hop_len,
42 | n_fft = n_fft
43 | )
44 |
45 | if C.shape[0] == n_mfcc:
46 | C = C.transpose(1, 0)
47 |
48 | return C
49 |
50 |
51 | def get_melspec(audio_fn, eps=1e-6, fps = 25, sr=16000, n_mels=64):
52 | raise NotImplementedError
53 | '''
54 | # y, sr = load_wav(audio_fn=audio_fn, sr=sr)
55 |
56 | # hop_len = int(sr / fps)
57 | # n_fft = 2048
58 |
59 | # C = librosa.feature.melspectrogram(
60 | # y = y,
61 | # sr = sr,
62 | # n_fft=n_fft,
63 | # hop_length=hop_len,
64 | # n_mels = n_mels,
65 | # fmin=0,
66 | # fmax=8000)
67 |
68 |
69 | # mask = (C == 0).astype(np.float)
70 | # C = mask * eps + (1-mask) * C
71 |
72 | # C = np.log(C)
73 | # # weird error may occur here
74 | # assert not (np.isnan(C).any()), audio_fn
75 | # if C.shape[0] == n_mels:
76 | # C = C.transpose(1, 0)
77 |
78 | # return C
79 | '''
80 |
81 | def extract_mfcc(audio,sample_rate=16000):
82 | mfcc = zip(*python_speech_features.mfcc(audio,sample_rate, numcep=64, nfilt=64, nfft=2048, winstep=0.04))
83 | mfcc = np.stack([np.array(i) for i in mfcc])
84 | return mfcc
85 |
86 | def get_mfcc_psf(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None):
87 | y, sr = load_wav_old(audio_fn, sr=sr)
88 |
89 | if y.shape.__len__() > 1:
90 | y = (y[:,0]+y[:,1])/2
91 |
92 | if win_size is None:
93 | hop_len=int(sr / fps)
94 | else:
95 | hop_len=int(sr/ win_size)
96 |
97 | n_fft=2048
98 |
99 | #hard coded for 25 fps
100 | if not smlpx:
101 | C = python_speech_features.mfcc(y, sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=0.04)
102 | else:
103 | C = python_speech_features.mfcc(y, sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01/15)
104 | # if C.shape[0] == n_mfcc:
105 | # C = C.transpose(1, 0)
106 |
107 | return C
108 |
109 |
110 | def get_mfcc_psf_min(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None):
111 | y, sr = load_wav_old(audio_fn, sr=sr)
112 |
113 | if y.shape.__len__() > 1:
114 | y = (y[:, 0] + y[:, 1]) / 2
115 | n_fft = 2048
116 |
117 | slice_len = 22000 * 5
118 | slice = y.size // slice_len
119 |
120 | C = []
121 |
122 | for i in range(slice):
123 | if i != (slice - 1):
124 | feat = python_speech_features.mfcc(y[i*slice_len:(i+1)*slice_len], sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01 / 15)
125 | else:
126 | feat = python_speech_features.mfcc(y[i * slice_len:], sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01 / 15)
127 |
128 | C.append(feat)
129 |
130 | return C
131 |
132 |
133 | def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
134 | """
135 | :param audio: 1 x T tensor containing a 16kHz audio signal
136 | :param frame_rate: frame rate for video (we need one audio chunk per video frame)
137 | :param chunk_size: number of audio samples per chunk
138 | :return: num_chunks x chunk_size tensor containing sliced audio
139 | """
140 | samples_per_frame = chunk_size // frame_rate
141 | padding = (chunk_size - samples_per_frame) // 2
142 | audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
143 | anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame))
144 | audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0)
145 | return audio
146 |
147 |
148 | def get_mfcc_ta(audio_fn, eps=1e-6, fps=15, smlpx=False, sr=16000, n_mfcc=64, win_size=None, type='mfcc', am=None, am_sr=None, encoder_choice='mfcc'):
149 | if am is None:
150 | audio, sr_0 = ta.load(audio_fn)
151 | if sr != sr_0:
152 | audio = ta.transforms.Resample(sr_0, sr)(audio)
153 | if audio.shape[0] > 1:
154 | audio = torch.mean(audio, dim=0, keepdim=True)
155 |
156 | n_fft = 2048
157 | if fps == 15:
158 | hop_length = 1467
159 | elif fps == 30:
160 | hop_length = 734
161 | win_length = hop_length * 2
162 | n_mels = 256
163 | n_mfcc = 64
164 |
165 | if type == 'mfcc':
166 | mfcc_transform = ta_T.MFCC(
167 | sample_rate=sr,
168 | n_mfcc=n_mfcc,
169 | melkwargs={
170 | "n_fft": n_fft,
171 | "n_mels": n_mels,
172 | # "win_length": win_length,
173 | "hop_length": hop_length,
174 | "mel_scale": "htk",
175 | },
176 | )
177 | audio_ft = mfcc_transform(audio).squeeze(dim=0).transpose(0,1).numpy()
178 | elif type == 'mel':
179 | # audio = 0.01 * audio / torch.mean(torch.abs(audio))
180 | mel_transform = ta_T.MelSpectrogram(
181 | sample_rate=sr, n_fft=n_fft, win_length=None, hop_length=hop_length, n_mels=n_mels
182 | )
183 | audio_ft = mel_transform(audio).squeeze(0).transpose(0,1).numpy()
184 | # audio_ft = torch.log(audio_ft.clamp(min=1e-10, max=None)).transpose(0,1).numpy()
185 | elif type == 'mel_mul':
186 | audio = 0.01 * audio / torch.mean(torch.abs(audio))
187 | audio = audio_chunking(audio, frame_rate=fps, chunk_size=sr)
188 | mel_transform = ta_T.MelSpectrogram(
189 | sample_rate=sr, n_fft=n_fft, win_length=int(sr/20), hop_length=int(sr/100), n_mels=n_mels
190 | )
191 | audio_ft = mel_transform(audio).squeeze(1)
192 | audio_ft = torch.log(audio_ft.clamp(min=1e-10, max=None)).numpy()
193 | else:
194 | speech_array, sampling_rate = librosa.load(audio_fn, sr=16000)
195 |
196 | if encoder_choice == 'faceformer':
197 | # audio_ft = np.squeeze(am(speech_array, sampling_rate=16000).input_values).reshape(-1, 1)
198 | audio_ft = speech_array.reshape(-1, 1)
199 | elif encoder_choice == 'meshtalk':
200 | audio_ft = 0.01 * speech_array / np.mean(np.abs(speech_array))
201 | elif encoder_choice == 'onset':
202 | audio_ft = librosa.onset.onset_detect(y=speech_array, sr=16000, units='time').reshape(-1, 1)
203 | else:
204 | audio, sr_0 = ta.load(audio_fn)
205 | if sr != sr_0:
206 | audio = ta.transforms.Resample(sr_0, sr)(audio)
207 | if audio.shape[0] > 1:
208 | audio = torch.mean(audio, dim=0, keepdim=True)
209 |
210 | n_fft = 2048
211 | if fps == 15:
212 | hop_length = 1467
213 | elif fps == 30:
214 | hop_length = 734
215 | win_length = hop_length * 2
216 | n_mels = 256
217 | n_mfcc = 64
218 |
219 | mfcc_transform = ta_T.MFCC(
220 | sample_rate=sr,
221 | n_mfcc=n_mfcc,
222 | melkwargs={
223 | "n_fft": n_fft,
224 | "n_mels": n_mels,
225 | # "win_length": win_length,
226 | "hop_length": hop_length,
227 | "mel_scale": "htk",
228 | },
229 | )
230 | audio_ft = mfcc_transform(audio).squeeze(dim=0).transpose(0, 1).numpy()
231 | return audio_ft
232 |
233 |
234 | def get_mfcc_sepa(audio_fn, fps=15, sr=16000):
235 | audio, sr_0 = ta.load(audio_fn)
236 | if sr != sr_0:
237 | audio = ta.transforms.Resample(sr_0, sr)(audio)
238 | if audio.shape[0] > 1:
239 | audio = torch.mean(audio, dim=0, keepdim=True)
240 |
241 | n_fft = 2048
242 | if fps == 15:
243 | hop_length = 1467
244 | elif fps == 30:
245 | hop_length = 734
246 | n_mels = 256
247 | n_mfcc = 64
248 |
249 | mfcc_transform = ta_T.MFCC(
250 | sample_rate=sr,
251 | n_mfcc=n_mfcc,
252 | melkwargs={
253 | "n_fft": n_fft,
254 | "n_mels": n_mels,
255 | # "win_length": win_length,
256 | "hop_length": hop_length,
257 | "mel_scale": "htk",
258 | },
259 | )
260 | audio_ft_0 = mfcc_transform(audio[0, :sr*2]).squeeze(dim=0).transpose(0,1).numpy()
261 | audio_ft_1 = mfcc_transform(audio[0, sr*2:]).squeeze(dim=0).transpose(0,1).numpy()
262 | audio_ft = np.concatenate((audio_ft_0, audio_ft_1), axis=0)
263 | return audio_ft, audio_ft_0.shape[0]
264 |
265 |
266 | def get_mfcc_old(wav_file):
267 | sig, sample_rate = load_wav_old(wav_file)
268 | mfcc = extract_mfcc(sig)
269 | return mfcc
270 |
271 |
272 | def smooth_geom(geom, mask: torch.Tensor = None, filter_size: int = 9, sigma: float = 2.0):
273 | """
274 | :param geom: T x V x 3 tensor containing a temporal sequence of length T with V vertices in each frame
275 | :param mask: V-dimensional Tensor containing a mask with vertices to be smoothed
276 | :param filter_size: size of the Gaussian filter
277 | :param sigma: standard deviation of the Gaussian filter
278 | :return: T x V x 3 tensor containing smoothed geometry (i.e., smoothed in the area indicated by the mask)
279 | """
280 | assert filter_size % 2 == 1, f"filter size must be odd but is {filter_size}"
281 | # Gaussian smoothing (low-pass filtering)
282 | fltr = np.arange(-(filter_size // 2), filter_size // 2 + 1)
283 | fltr = np.exp(-0.5 * fltr ** 2 / sigma ** 2)
284 | fltr = torch.Tensor(fltr) / np.sum(fltr)
285 | # apply fltr
286 | fltr = fltr.view(1, 1, -1).to(device=geom.device)
287 | T, V = geom.shape[1], geom.shape[2]
288 | g = torch.nn.functional.pad(
289 | geom.permute(2, 0, 1).view(V, 1, T),
290 | pad=[filter_size // 2, filter_size // 2], mode='replicate'
291 | )
292 | g = torch.nn.functional.conv1d(g, fltr).view(V, 1, T)
293 | smoothed = g.permute(1, 2, 0).contiguous()
294 | # blend smoothed signal with original signal
295 | if mask is None:
296 | return smoothed
297 | else:
298 | return smoothed * mask[None, :, None] + geom * (-mask[None, :, None] + 1)
299 |
300 | if __name__ == '__main__':
301 | audio_fn = '../sample_audio/clip000028_tCAkv4ggPgI.wav'
302 |
303 | C = get_mfcc_psf(audio_fn)
304 | print(C.shape)
305 |
306 | C_2 = get_mfcc(audio_fn)
307 | print(C_2.shape)
308 |
309 | print(C)
310 | print(C_2)
311 | print((C == C_2).all())
312 | # print(y.shape, sr)
313 | # mel_spec = get_melspec(audio_fn)
314 | # print(mel_spec.shape)
315 | # mfcc = get_mfcc(audio_fn, sr = 16000)
316 | # print(mfcc.shape)
317 | # print(mel_spec.max(), mel_spec.min())
318 | # print(mfcc.max(), mfcc.min())
--------------------------------------------------------------------------------
/demo/1st-page/1st-page-upper.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/1st-page/1st-page-upper.mp4
--------------------------------------------------------------------------------
/demo/1st-page/1st-page-upper.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/1st-page/1st-page-upper.npy
--------------------------------------------------------------------------------
/demo/french/french.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/french/french.mp4
--------------------------------------------------------------------------------
/demo/french/french.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/french/french.npy
--------------------------------------------------------------------------------
/demo/rich/rich.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/rich/rich.mp4
--------------------------------------------------------------------------------
/demo/rich/rich.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/rich/rich.npy
--------------------------------------------------------------------------------
/demo/song/cut.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/song/cut.mp4
--------------------------------------------------------------------------------
/demo/song/song.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/song/song.mp4
--------------------------------------------------------------------------------
/demo/song/song.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/song/song.npy
--------------------------------------------------------------------------------
/demo/style/chemistry.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/chemistry.mp4
--------------------------------------------------------------------------------
/demo/style/chemistry.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/chemistry.npy
--------------------------------------------------------------------------------
/demo/style/conan.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/conan.mp4
--------------------------------------------------------------------------------
/demo/style/conan.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/conan.npy
--------------------------------------------------------------------------------
/demo/style/diversity.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/diversity.mp4
--------------------------------------------------------------------------------
/demo/style/diversity.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/diversity.npy
--------------------------------------------------------------------------------
/demo/style/face.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/face.mp4
--------------------------------------------------------------------------------
/demo/style/face.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/face.npy
--------------------------------------------------------------------------------
/demo/style/oliver.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/oliver.mp4
--------------------------------------------------------------------------------
/demo/style/oliver.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/oliver.npy
--------------------------------------------------------------------------------
/demo/style/seth.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/seth.mp4
--------------------------------------------------------------------------------
/demo/style/seth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/seth.npy
--------------------------------------------------------------------------------
/demo_audio/1st-page.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo_audio/1st-page.wav
--------------------------------------------------------------------------------
/demo_audio/french.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo_audio/french.wav
--------------------------------------------------------------------------------
/demo_audio/rich.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo_audio/rich.wav
--------------------------------------------------------------------------------
/demo_audio/rich_short.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo_audio/rich_short.wav
--------------------------------------------------------------------------------
/demo_audio/song.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo_audio/song.wav
--------------------------------------------------------------------------------
/demo_audio/style.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo_audio/style.wav
--------------------------------------------------------------------------------
/evaluation/FGD.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | from scipy import linalg
7 | import math
8 | from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d
9 |
10 | import warnings
11 | warnings.filterwarnings("ignore", category=RuntimeWarning) # ignore warnings
12 |
13 |
14 | change_angle = torch.tensor([6.0181e-05, 5.1597e-05, 2.1344e-04, 2.1899e-04])
15 | class EmbeddingSpaceEvaluator:
16 | def __init__(self, ae, vae, device):
17 |
18 | # init embed net
19 | self.ae = ae
20 | # self.vae = vae
21 |
22 | # storage
23 | self.real_feat_list = []
24 | self.generated_feat_list = []
25 | self.real_joints_list = []
26 | self.generated_joints_list = []
27 | self.real_6d_list = []
28 | self.generated_6d_list = []
29 | self.audio_beat_list = []
30 |
31 | def reset(self):
32 | self.real_feat_list = []
33 | self.generated_feat_list = []
34 |
35 | def get_no_of_samples(self):
36 | return len(self.real_feat_list)
37 |
38 | def push_samples(self, generated_poses, real_poses):
39 | # self.net.eval()
40 | # convert poses to latent features
41 | real_feat, real_poses = self.ae.extract(real_poses)
42 | generated_feat, generated_poses = self.ae.extract(generated_poses)
43 |
44 | num_joints = real_poses.shape[2] // 3
45 |
46 | real_feat = real_feat.squeeze()
47 | generated_feat = generated_feat.reshape(generated_feat.shape[0]*generated_feat.shape[1], -1)
48 |
49 | self.real_feat_list.append(real_feat.data.cpu().numpy())
50 | self.generated_feat_list.append(generated_feat.data.cpu().numpy())
51 |
52 | # real_poses = matrix_to_rotation_6d(axis_angle_to_matrix(real_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
53 | # generated_poses = matrix_to_rotation_6d(axis_angle_to_matrix(generated_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
54 | #
55 | # self.real_feat_list.append(real_poses.data.cpu().numpy())
56 | # self.generated_feat_list.append(generated_poses.data.cpu().numpy())
57 |
58 | def push_joints(self, generated_poses, real_poses):
59 | self.real_joints_list.append(real_poses.data.cpu())
60 | self.generated_joints_list.append(generated_poses.squeeze().data.cpu())
61 |
62 | def push_aud(self, aud):
63 | self.audio_beat_list.append(aud.squeeze().data.cpu())
64 |
65 | def get_MAAC(self):
66 | ang_vel_list = []
67 | for real_joints in self.real_joints_list:
68 | real_joints[:, 15:21] = real_joints[:, 16:22]
69 | vec = real_joints[:, 15:21] - real_joints[:, 13:19]
70 | inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
71 | inner_product = torch.clamp(inner_product, -1, 1, out=None)
72 | angle = torch.acos(inner_product) / math.pi
73 | ang_vel = (angle[1:] - angle[:-1]).abs().mean(dim=0)
74 | ang_vel_list.append(ang_vel.unsqueeze(dim=0))
75 | all_vel = torch.cat(ang_vel_list, dim=0)
76 | MAAC = all_vel.mean(dim=0)
77 | return MAAC
78 |
79 | def get_BCscore(self):
80 | thres = 0.01
81 | sigma = 0.1
82 | sum_1 = 0
83 | total_beat = 0
84 | for joints, audio_beat_time in zip(self.generated_joints_list, self.audio_beat_list):
85 | motion_beat_time = []
86 | if joints.dim() == 4:
87 | joints = joints[0]
88 | joints[:, 15:21] = joints[:, 16:22]
89 | vec = joints[:, 15:21] - joints[:, 13:19]
90 | inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
91 | inner_product = torch.clamp(inner_product, -1, 1, out=None)
92 | angle = torch.acos(inner_product) / math.pi
93 | ang_vel = (angle[1:] - angle[:-1]).abs() / change_angle / len(change_angle)
94 |
95 | angle_diff = torch.cat((torch.zeros(1, 4), ang_vel), dim=0)
96 |
97 | sum_2 = 0
98 | for i in range(angle_diff.shape[1]):
99 | motion_beat_time = []
100 | for t in range(1, joints.shape[0]-1):
101 | if (angle_diff[t][i] < angle_diff[t - 1][i] and angle_diff[t][i] < angle_diff[t + 1][i]):
102 | if (angle_diff[t - 1][i] - angle_diff[t][i] >= thres or angle_diff[t + 1][i] - angle_diff[
103 | t][i] >= thres):
104 | motion_beat_time.append(float(t) / 30.0)
105 | if (len(motion_beat_time) == 0):
106 | continue
107 | motion_beat_time = torch.tensor(motion_beat_time)
108 | sum = 0
109 | for audio in audio_beat_time:
110 | sum += np.power(math.e, -(np.power((audio.item() - motion_beat_time), 2)).min() / (2 * sigma * sigma))
111 | sum_2 = sum_2 + sum
112 | total_beat = total_beat + len(audio_beat_time)
113 | sum_1 = sum_1 + sum_2
114 | return sum_1/total_beat
115 |
116 |
117 | def get_scores(self):
118 | generated_feats = np.vstack(self.generated_feat_list)
119 | real_feats = np.vstack(self.real_feat_list)
120 |
121 | def frechet_distance(samples_A, samples_B):
122 | A_mu = np.mean(samples_A, axis=0)
123 | A_sigma = np.cov(samples_A, rowvar=False)
124 | B_mu = np.mean(samples_B, axis=0)
125 | B_sigma = np.cov(samples_B, rowvar=False)
126 | try:
127 | frechet_dist = self.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma)
128 | except ValueError:
129 | frechet_dist = 1e+10
130 | return frechet_dist
131 |
132 | ####################################################################
133 | # frechet distance
134 | frechet_dist = frechet_distance(generated_feats, real_feats)
135 |
136 | ####################################################################
137 | # distance between real and generated samples on the latent feature space
138 | dists = []
139 | for i in range(real_feats.shape[0]):
140 | d = np.sum(np.absolute(real_feats[i] - generated_feats[i])) # MAE
141 | dists.append(d)
142 | feat_dist = np.mean(dists)
143 |
144 | return frechet_dist, feat_dist
145 |
146 | @staticmethod
147 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
148 | """ from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """
149 | """Numpy implementation of the Frechet Distance.
150 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
151 | and X_2 ~ N(mu_2, C_2) is
152 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
153 | Stable version by Dougal J. Sutherland.
154 | Params:
155 | -- mu1 : Numpy array containing the activations of a layer of the
156 | inception net (like returned by the function 'get_predictions')
157 | for generated samples.
158 | -- mu2 : The sample mean over activations, precalculated on an
159 | representative data set.
160 | -- sigma1: The covariance matrix over activations for generated samples.
161 | -- sigma2: The covariance matrix over activations, precalculated on an
162 | representative data set.
163 | Returns:
164 | -- : The Frechet Distance.
165 | """
166 |
167 | mu1 = np.atleast_1d(mu1)
168 | mu2 = np.atleast_1d(mu2)
169 |
170 | sigma1 = np.atleast_2d(sigma1)
171 | sigma2 = np.atleast_2d(sigma2)
172 |
173 | assert mu1.shape == mu2.shape, \
174 | 'Training and test mean vectors have different lengths'
175 | assert sigma1.shape == sigma2.shape, \
176 | 'Training and test covariances have different dimensions'
177 |
178 | diff = mu1 - mu2
179 |
180 | # Product might be almost singular
181 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
182 | if not np.isfinite(covmean).all():
183 | msg = ('fid calculation produces singular product; '
184 | 'adding %s to diagonal of cov estimates') % eps
185 | print(msg)
186 | offset = np.eye(sigma1.shape[0]) * eps
187 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
188 |
189 | # Numerical error might give slight imaginary component
190 | if np.iscomplexobj(covmean):
191 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
192 | m = np.max(np.abs(covmean.imag))
193 | raise ValueError('Imaginary component {}'.format(m))
194 | covmean = covmean.real
195 |
196 | tr_covmean = np.trace(covmean)
197 |
198 | return (diff.dot(diff) + np.trace(sigma1) +
199 | np.trace(sigma2) - 2 * tr_covmean)
--------------------------------------------------------------------------------
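A minimal sketch of the Frechet-distance computation above on synthetic features (not part of the repository): the feature dimension, sample counts, and Gaussian toy data are assumptions for illustration, and the import assumes the repository root is on the Python path.

import numpy as np
from evaluation.FGD import EmbeddingSpaceEvaluator

# Toy stand-ins for the latent features produced by ae.extract() (shapes assumed).
rng = np.random.default_rng(0)
real_feats = rng.normal(0.0, 1.0, size=(500, 64))
generated_feats = rng.normal(0.1, 1.0, size=(500, 64))

mu_r, sigma_r = real_feats.mean(axis=0), np.cov(real_feats, rowvar=False)
mu_g, sigma_g = generated_feats.mean(axis=0), np.cov(generated_feats, rowvar=False)

# calculate_frechet_distance is a @staticmethod, so no autoencoder is needed here.
fgd = EmbeddingSpaceEvaluator.calculate_frechet_distance(mu_g, sigma_g, mu_r, sigma_r)
print(fgd)  # close to 0 for matching distributions, larger as they drift apart
--------------------------------------------------------------------------------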
/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/evaluation/__init__.py
--------------------------------------------------------------------------------
/evaluation/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/evaluation/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/__pycache__/metrics.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/evaluation/__pycache__/metrics.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/diversity_LVD.py:
--------------------------------------------------------------------------------
1 | '''
2 | LVD: different initial pose
3 | diversity: same initial pose
4 | '''
5 | import os
6 | import sys
7 | sys.path.append(os.getcwd())
8 |
9 | from glob import glob
10 |
11 | from argparse import ArgumentParser
12 | import json
13 |
14 | from evaluation.util import *
15 | from evaluation.metrics import *
16 | from tqdm import tqdm
17 |
18 | parser = ArgumentParser()
19 | parser.add_argument('--speaker', required=True, type=str)
20 | parser.add_argument('--post_fix', nargs='+', default=['base'], type=str)
21 | args = parser.parse_args()
22 |
23 | speaker = args.speaker
24 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
25 |
26 | LVD_list = []
27 | diversity_list = []
28 |
29 | for aud in tqdm(test_audios):
30 | base_name = os.path.splitext(aud)[0]
31 | gt_path = get_full_path(aud, speaker, 'val')
32 | _, gt_poses, _ = get_gts(gt_path)
33 | gt_poses = gt_poses[np.newaxis,...]
34 | # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
35 | for post_fix in args.post_fix:
36 | pred_path = base_name + '_'+post_fix+'.json'
37 | pred_poses = np.array(json.load(open(pred_path)))
38 | # print(pred_poses.shape)#(B, seq_len, 108)
39 | pred_poses = cvt25(pred_poses, gt_poses)
40 | # print(pred_poses.shape)#(B, seq, pose_dim)
41 |
42 | gt_valid_points = hand_points(gt_poses)
43 | pred_valid_points = hand_points(pred_poses)
44 |
45 | lvd = LVD(gt_valid_points, pred_valid_points)
46 | # div = diversity(pred_valid_points)
47 |
48 | LVD_list.append(lvd)
49 | # diversity_list.append(div)
50 |
51 | # gt_velocity = peak_velocity(gt_valid_points, order=2)
52 | # pred_velocity = peak_velocity(pred_valid_points, order=2)
53 |
54 | # gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
55 | # pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
56 |
57 | # gt_consistency_list.append(gt_consistency)
58 | # pred_consistency_list.append(pred_consistency)
59 |
60 | lvd = np.mean(LVD_list)
61 | # diversity_list = np.mean(diversity_list)
62 |
63 | print('LVD:', lvd)
64 | # print("diversity:", diversity_list)
--------------------------------------------------------------------------------
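A small sketch of the file-naming convention the script above relies on (not part of the repository); the speaker name and clip filename below are hypothetical examples of the expected layout.

import os

# For each test wav and each --post_fix value, predictions are expected at
# '<audio path without extension>_<post_fix>.json'.
aud = 'pose_dataset/videos/test_audios/oliver/clip_000.wav'  # hypothetical path
post_fix = 'base'
pred_path = os.path.splitext(aud)[0] + '_' + post_fix + '.json'
print(pred_path)  # pose_dataset/videos/test_audios/oliver/clip_000_base.json
--------------------------------------------------------------------------------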
/evaluation/get_quality_samples.py:
--------------------------------------------------------------------------------
1 | '''
2 | '''
3 | import os
4 | import sys
5 | sys.path.append(os.getcwd())
6 |
7 | from glob import glob
8 |
9 | from argparse import ArgumentParser
10 | import json
11 |
12 | from evaluation.util import *
13 | from evaluation.metrics import *
14 | from tqdm import tqdm
15 |
16 | parser = ArgumentParser()
17 | parser.add_argument('--speaker', required=True, type=str)
18 | parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
19 | args = parser.parse_args()
20 |
21 | speaker = args.speaker
22 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
23 |
24 | quality_samples={'gt':[]}
25 | for post_fix in args.post_fix:
26 | quality_samples[post_fix] = []
27 |
28 | for aud in tqdm(test_audios):
29 | base_name = os.path.splitext(aud)[0]
30 | gt_path = get_full_path(aud, speaker, 'val')
31 | _, gt_poses, _ = get_gts(gt_path)
32 | gt_poses = gt_poses[np.newaxis,...]
33 | gt_valid_points = valid_points(gt_poses)
34 | # print(gt_valid_points.shape)
35 | quality_samples['gt'].append(gt_valid_points)
36 |
37 | for post_fix in args.post_fix:
38 | pred_path = base_name + '_'+post_fix+'.json'
39 | pred_poses = np.array(json.load(open(pred_path)))
40 | # print(pred_poses.shape)#(B, seq_len, 108)
41 | pred_poses = cvt25(pred_poses, gt_poses)
42 | # print(pred_poses.shape)#(B, seq, pose_dim)
43 |
44 | pred_valid_points = valid_points(pred_poses)[0:1]
45 | quality_samples[post_fix].append(pred_valid_points)
46 |
47 | quality_samples['gt'] = np.concatenate(quality_samples['gt'], axis=1)
48 | for post_fix in args.post_fix:
49 | quality_samples[post_fix] = np.concatenate(quality_samples[post_fix], axis=1)
50 |
51 | print('gt:', quality_samples['gt'].shape)
52 | quality_samples['gt'] = quality_samples['gt'].tolist()
53 | for post_fix in args.post_fix:
54 | print(post_fix, ':', quality_samples[post_fix].shape)
55 | quality_samples[post_fix] = quality_samples[post_fix].tolist()
56 |
57 | save_dir = '../../experiments/'
58 | os.makedirs(save_dir, exist_ok=True)
59 | save_name = os.path.join(save_dir, 'quality_samples_%s.json'%(speaker))
60 | with open(save_name, 'w') as f:
61 | json.dump(quality_samples, f)
62 |
63 |
--------------------------------------------------------------------------------
/evaluation/metrics.py:
--------------------------------------------------------------------------------
1 | '''
2 | Warning: these metrics are for reference only and may have limited significance
3 | '''
4 | import os
5 | import sys
6 | sys.path.append(os.getcwd())
7 | import numpy as np
8 | import torch
9 |
10 | from data_utils.lower_body import rearrange, symmetry
11 | import torch.nn.functional as F
12 |
13 | def data_driven_baselines(gt_kps):
14 | '''
15 | gt_kps: T, D
16 | '''
17 | gt_velocity = np.abs(gt_kps[1:] - gt_kps[:-1])
18 |
19 | mean= np.mean(gt_velocity, axis=0)[np.newaxis] #(1, D)
20 | mean = np.mean(np.abs(gt_velocity-mean))
21 | last_step = gt_kps[1] - gt_kps[0]
22 | last_step = last_step[np.newaxis] #(1, D)
23 | last_step = np.mean(np.abs(gt_velocity-last_step))
24 | return last_step, mean
25 |
26 | def Batch_LVD(gt_kps, pr_kps, symmetrical, weight):
27 | if gt_kps.shape[0] > pr_kps.shape[1]:
28 | length = pr_kps.shape[1]
29 | else:
30 | length = gt_kps.shape[0]
31 | gt_kps = gt_kps[:length]
32 | pr_kps = pr_kps[:, :length]
33 | global symmetry
34 | symmetry = torch.tensor(symmetry).bool()
35 |
36 | if symmetrical:
37 | # rearrange for compute symmetric. ns means non-symmetrical joints, ys means symmetrical joints.
38 | gt_kps = gt_kps[:, rearrange]
39 | ns_gt_kps = gt_kps[:, ~symmetry]
40 | ys_gt_kps = gt_kps[:, symmetry]
41 | ys_gt_kps = ys_gt_kps.reshape(ys_gt_kps.shape[0], -1, 2, 3)
42 | ns_gt_velocity = (ns_gt_kps[1:] - ns_gt_kps[:-1]).norm(p=2, dim=-1)
43 | ys_gt_velocity = (ys_gt_kps[1:] - ys_gt_kps[:-1]).norm(p=2, dim=-1)
44 | left_gt_vel = ys_gt_velocity[:, :, 0].sum(dim=-1)
45 | right_gt_vel = ys_gt_velocity[:, :, 1].sum(dim=-1)
46 | move_side = torch.where(left_gt_vel>right_gt_vel, torch.ones(left_gt_vel.shape).cuda(), torch.zeros(left_gt_vel.shape).cuda())
47 | ys_gt_velocity = torch.mul(ys_gt_velocity[:, :, 0].transpose(0,1), move_side) + torch.mul(ys_gt_velocity[:, :, 1].transpose(0,1), ~move_side.bool())
48 | ys_gt_velocity = ys_gt_velocity.transpose(0,1)
49 | gt_velocity = torch.cat([ns_gt_velocity, ys_gt_velocity], dim=1)
50 |
51 | pr_kps = pr_kps[:, :, rearrange]
52 | ns_pr_kps = pr_kps[:, :, ~symmetry]
53 | ys_pr_kps = pr_kps[:, :, symmetry]
54 | ys_pr_kps = ys_pr_kps.reshape(ys_pr_kps.shape[0], ys_pr_kps.shape[1], -1, 2, 3)
55 | ns_pr_velocity = (ns_pr_kps[:, 1:] - ns_pr_kps[:, :-1]).norm(p=2, dim=-1)
56 | ys_pr_velocity = (ys_pr_kps[:, 1:] - ys_pr_kps[:, :-1]).norm(p=2, dim=-1)
57 | left_pr_vel = ys_pr_velocity[:, :, :, 0].sum(dim=-1)
58 | right_pr_vel = ys_pr_velocity[:, :, :, 1].sum(dim=-1)
59 | move_side = torch.where(left_pr_vel > right_pr_vel, torch.ones(left_pr_vel.shape).cuda(),
60 | torch.zeros(left_pr_vel.shape).cuda())
61 | ys_pr_velocity = torch.mul(ys_pr_velocity[..., 0].permute(2, 0, 1), move_side) + torch.mul(
62 |             ys_pr_velocity[..., 1].permute(2, 0, 1), ~move_side.bool())
63 | ys_pr_velocity = ys_pr_velocity.permute(1, 2, 0)
64 | pr_velocity = torch.cat([ns_pr_velocity, ys_pr_velocity], dim=2)
65 | else:
66 | gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
67 | pr_velocity = (pr_kps[:, 1:] - pr_kps[:, :-1]).norm(p=2, dim=-1)
68 |
69 | if weight:
70 | w = F.softmax(gt_velocity.sum(dim=1).normal_(), dim=0)
71 | else:
72 | w = 1 / gt_velocity.shape[0]
73 |
74 | v_diff = ((pr_velocity - gt_velocity).abs().sum(dim=-1) * w).sum(dim=-1).mean()
75 |
76 | return v_diff
77 |
78 |
79 | def LVD(gt_kps, pr_kps, symmetrical=False, weight=False):
80 | gt_kps = gt_kps.squeeze()
81 | pr_kps = pr_kps.squeeze()
82 | if len(pr_kps.shape) == 4:
83 | return Batch_LVD(gt_kps, pr_kps, symmetrical, weight)
84 | # length = np.minimum(gt_kps.shape[0], pr_kps.shape[0])
85 | length = gt_kps.shape[0]-10
86 | # gt_kps = gt_kps[25:length]
87 | # pr_kps = pr_kps[25:length] #(T, D)
88 | # if pr_kps.shape[0] < gt_kps.shape[0]:
89 | # pr_kps = np.pad(pr_kps, [[0, int(gt_kps.shape[0]-pr_kps.shape[0])], [0, 0]], mode='constant')
90 |
91 | gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
92 | pr_velocity = (pr_kps[1:] - pr_kps[:-1]).norm(p=2, dim=-1)
93 |
94 | return (pr_velocity-gt_velocity).abs().sum(dim=-1).mean()
95 |
96 | def diversity(kps):
97 | '''
98 | kps: bs, seq, dim
99 | '''
100 | dis_list = []
101 | #the distance between each pair
102 | for i in range(kps.shape[0]):
103 | for j in range(i+1, kps.shape[0]):
104 | seq_i = kps[i]
105 | seq_j = kps[j]
106 |
107 | dis = np.mean(np.abs(seq_i - seq_j))
108 | dis_list.append(dis)
109 | return np.mean(dis_list)
110 |
--------------------------------------------------------------------------------
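A shape-only sketch of the unbatched LVD and the diversity metric above on random data (not part of the repository); the sequence length and keypoint dimension are placeholders, and the import assumes the repository root is on the Python path.

import numpy as np
import torch
from evaluation.metrics import LVD, diversity

# Random stand-ins for (T, D) keypoint sequences; real inputs come from
# evaluation.util.hand_points / valid_points.
T, D = 120, 108
gt_kps = torch.randn(T, D)
pr_kps = gt_kps + 0.05 * torch.randn(T, D)

print(float(LVD(gt_kps, pr_kps)))   # mean absolute velocity difference

# diversity averages pairwise L1 distances over a batch of (bs, seq, dim) samples.
samples = np.random.randn(4, T, D)
print(diversity(samples))
--------------------------------------------------------------------------------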
/evaluation/mode_transition.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.getcwd())
4 |
5 | from glob import glob
6 |
7 | from argparse import ArgumentParser
8 | import json
9 |
10 | from evaluation.util import *
11 | from evaluation.metrics import *
12 | from tqdm import tqdm
13 |
14 | parser = ArgumentParser()
15 | parser.add_argument('--speaker', required=True, type=str)
16 | parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
17 | args = parser.parse_args()
18 |
19 | speaker = args.speaker
20 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
21 |
22 | precision_list=[]
23 | recall_list=[]
24 | accuracy_list=[]
25 |
26 | for aud in tqdm(test_audios):
27 | base_name = os.path.splitext(aud)[0]
28 | gt_path = get_full_path(aud, speaker, 'val')
29 | _, gt_poses, _ = get_gts(gt_path)
30 | if gt_poses.shape[0] < 50:
31 | continue
32 | gt_poses = gt_poses[np.newaxis,...]
33 | # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
34 | for post_fix in args.post_fix:
35 | pred_path = base_name + '_'+post_fix+'.json'
36 | pred_poses = np.array(json.load(open(pred_path)))
37 | # print(pred_poses.shape)#(B, seq_len, 108)
38 | pred_poses = cvt25(pred_poses, gt_poses)
39 | # print(pred_poses.shape)#(B, seq, pose_dim)
40 |
41 | gt_valid_points = valid_points(gt_poses)
42 | pred_valid_points = valid_points(pred_poses)
43 |
44 | # print(gt_valid_points.shape, pred_valid_points.shape)
45 |
46 | gt_mode_transition_seq = mode_transition_seq(gt_valid_points, speaker)#(B, N)
47 | pred_mode_transition_seq = mode_transition_seq(pred_valid_points, speaker)#(B, N)
48 |
49 | # baseline = np.random.randint(0, 2, size=pred_mode_transition_seq.shape)
50 | # pred_mode_transition_seq = baseline
51 | precision, recall, accuracy = mode_transition_consistency(pred_mode_transition_seq, gt_mode_transition_seq)
52 | precision_list.append(precision)
53 | recall_list.append(recall)
54 | accuracy_list.append(accuracy)
55 | print(len(precision_list), len(recall_list), len(accuracy_list))
56 | precision_list = np.mean(precision_list)
57 | recall_list = np.mean(recall_list)
58 | accuracy_list = np.mean(accuracy_list)
59 |
60 | print('precision, recall, accu:', precision_list, recall_list, accuracy_list)
61 |
--------------------------------------------------------------------------------
/evaluation/peak_velocity.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.getcwd())
4 |
5 | from glob import glob
6 |
7 | from argparse import ArgumentParser
8 | import json
9 |
10 | from evaluation.util import *
11 | from evaluation.metrics import *
12 | from tqdm import tqdm
13 |
14 | parser = ArgumentParser()
15 | parser.add_argument('--speaker', required=True, type=str)
16 | parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
17 | args = parser.parse_args()
18 |
19 | speaker = args.speaker
20 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
21 |
22 | gt_consistency_list=[]
23 | pred_consistency_list=[]
24 |
25 | for aud in tqdm(test_audios):
26 | base_name = os.path.splitext(aud)[0]
27 | gt_path = get_full_path(aud, speaker, 'val')
28 | _, gt_poses, _ = get_gts(gt_path)
29 | gt_poses = gt_poses[np.newaxis,...]
30 | # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
31 | for post_fix in args.post_fix:
32 | pred_path = base_name + '_'+post_fix+'.json'
33 | pred_poses = np.array(json.load(open(pred_path)))
34 | # print(pred_poses.shape)#(B, seq_len, 108)
35 | pred_poses = cvt25(pred_poses, gt_poses)
36 | # print(pred_poses.shape)#(B, seq, pose_dim)
37 |
38 | gt_valid_points = hand_points(gt_poses)
39 | pred_valid_points = hand_points(pred_poses)
40 |
41 | gt_velocity = peak_velocity(gt_valid_points, order=2)
42 | pred_velocity = peak_velocity(pred_valid_points, order=2)
43 |
44 | gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
45 | pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
46 |
47 | gt_consistency_list.append(gt_consistency)
48 | pred_consistency_list.append(pred_consistency)
49 |
50 | gt_consistency_list = np.concatenate(gt_consistency_list)
51 | pred_consistency_list = np.concatenate(pred_consistency_list)
52 |
53 | print(gt_consistency_list.max(), gt_consistency_list.min())
54 | print(pred_consistency_list.max(), pred_consistency_list.min())
55 | print(np.mean(gt_consistency_list), np.mean(pred_consistency_list))
56 | print(np.std(gt_consistency_list), np.std(pred_consistency_list))
57 |
58 | draw_cdf(gt_consistency_list, save_name='%s_gt.jpg'%(speaker), color='slateblue')
59 | draw_cdf(pred_consistency_list, save_name='%s_pred.jpg'%(speaker), color='lightskyblue')
60 |
61 | to_excel(gt_consistency_list, '%s_gt.xlsx'%(speaker))
62 | to_excel(pred_consistency_list, '%s_pred.xlsx'%(speaker))
63 |
64 | np.save('%s_gt.npy'%(speaker), gt_consistency_list)
65 | np.save('%s_pred.npy'%(speaker), pred_consistency_list)
--------------------------------------------------------------------------------
/evaluation/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import numpy as np
4 | import json
5 | from matplotlib import pyplot as plt
6 | import pandas as pd
7 | def get_gts(clip):
8 | '''
9 | clip: abs path to the clip dir
10 | '''
11 | keypoints_files = sorted(glob(os.path.join(clip, 'keypoints_new/person_1')+'/*.json'))
12 |
13 | upper_body_points = list(np.arange(0, 25))
14 | poses = []
15 | confs = []
16 | neck_to_nose_len = []
17 | mean_position = []
18 | for kp_file in keypoints_files:
19 | kp_load = json.load(open(kp_file, 'r'))['people'][0]
20 | posepts = kp_load['pose_keypoints_2d']
21 | lhandpts = kp_load['hand_left_keypoints_2d']
22 | rhandpts = kp_load['hand_right_keypoints_2d']
23 | facepts = kp_load['face_keypoints_2d']
24 |
25 | neck = np.array(posepts).reshape(-1,3)[1]
26 | nose = np.array(posepts).reshape(-1,3)[0]
27 | x_offset = abs(neck[0]-nose[0])
28 | y_offset = abs(neck[1]-nose[1])
29 | neck_to_nose_len.append(y_offset)
30 | mean_position.append([neck[0],neck[1]])
31 |
32 | keypoints=np.array(posepts+lhandpts+rhandpts+facepts).reshape(-1,3)[:,:2]
33 |
34 | upper_body = keypoints[upper_body_points, :]
35 | hand_points = keypoints[25:, :]
36 | keypoints = np.vstack([upper_body, hand_points])
37 |
38 | poses.append(keypoints)
39 |
40 | if len(neck_to_nose_len) > 0:
41 | scale_factor = np.mean(neck_to_nose_len)
42 | else:
43 | raise ValueError(clip)
44 | mean_position = np.mean(np.array(mean_position), axis=0)
45 |
46 | unlocalized_poses = np.array(poses).copy()
47 | localized_poses = []
48 | for i in range(len(poses)):
49 | keypoints = poses[i]
50 | neck = keypoints[1].copy()
51 |
52 | keypoints[:, 0] = (keypoints[:, 0] - neck[0]) / scale_factor
53 | keypoints[:, 1] = (keypoints[:, 1] - neck[1]) / scale_factor
54 | localized_poses.append(keypoints.reshape(-1))
55 |
56 | localized_poses=np.array(localized_poses)
57 | return unlocalized_poses, localized_poses, (scale_factor, mean_position)
58 |
59 | def get_full_path(wav_name, speaker, split):
60 | '''
61 | get clip path from aud file
62 | '''
63 | wav_name = os.path.basename(wav_name)
64 | wav_name = os.path.splitext(wav_name)[0]
65 | clip_name, vid_name = wav_name[:10], wav_name[11:]
66 |
67 | full_path = os.path.join('pose_dataset/videos/', speaker, 'clips', vid_name, 'images/half', split, clip_name)
68 |
69 | assert os.path.isdir(full_path), full_path
70 |
71 | return full_path
72 |
73 | def smooth(res):
74 | '''
75 | res: (B, seq_len, pose_dim)
76 | '''
77 | window = [res[:, 7, :], res[:, 8, :], res[:, 9, :], res[:, 10, :], res[:, 11, :], res[:, 12, :]]
78 | w_size=7
79 | for i in range(10, res.shape[1]-3):
80 | window.append(res[:, i+3, :])
81 | if len(window) > w_size:
82 | window = window[1:]
83 |
84 | if (i%25) in [22, 23, 24, 0, 1, 2, 3]:
85 |             res[:, i, :] = np.mean(window, axis=0)  # average over the temporal window
86 |
87 | return res
88 |
89 | def cvt25(pred_poses, gt_poses=None):
90 | '''
91 | gt_poses: (1, seq_len, 270), 135 *2
92 | pred_poses: (B, seq_len, 108), 54 * 2
93 | '''
94 | if gt_poses is None:
95 | gt_poses = np.zeros_like(pred_poses)
96 | else:
97 | gt_poses = gt_poses.repeat(pred_poses.shape[0], axis=0)
98 |
99 | length = min(pred_poses.shape[1], gt_poses.shape[1])
100 | pred_poses = pred_poses[:, :length, :]
101 | gt_poses = gt_poses[:, :length, :]
102 | gt_poses = gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1, 2)
103 | pred_poses = pred_poses.reshape(pred_poses.shape[0], pred_poses.shape[1], -1, 2)
104 |
105 | gt_poses[:, :, [1, 2, 3, 4, 5, 6, 7], :] = pred_poses[:, :, 1:8, :]
106 | gt_poses[:, :, 25:25+21+21, :] = pred_poses[:, :, 12:, :]
107 |
108 | return gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1)
109 |
110 | def hand_points(seq):
111 | '''
112 | seq: (B, seq_len, 135*2)
113 | hands only
114 | '''
115 | hand_idx = [1, 2, 3, 4,5 ,6,7] + list(range(25, 25+21+21))
116 | seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
117 | return seq[:, :, hand_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
118 |
119 | def valid_points(seq):
120 | '''
121 | hands with some head points
122 | '''
123 | valid_idx = [0, 1, 2, 3, 4,5 ,6,7, 8, 9, 10, 11] + list(range(25, 25+21+21))
124 | seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
125 |
126 | seq = seq[:, :, valid_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
127 | assert seq.shape[-1] == 108, seq.shape
128 | return seq
129 |
130 | def draw_cdf(seq, save_name='cdf.jpg', color='slateblue'):
131 | plt.figure()
132 | plt.hist(seq, bins=100, range=(0, 100), color=color)
133 | plt.savefig(save_name)
134 |
135 | def to_excel(seq, save_name='res.xlsx'):
136 | '''
137 | seq: (T)
138 | '''
139 | df = pd.DataFrame(seq)
140 | writer = pd.ExcelWriter(save_name)
141 | df.to_excel(writer, 'sheet1')
142 | writer.save()
143 | writer.close()
144 |
145 |
146 | if __name__ == '__main__':
147 | random_data = np.random.randint(0, 10, 100)
148 | draw_cdf(random_data)
--------------------------------------------------------------------------------
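To make the shape conventions in cvt25, hand_points and valid_points concrete, a small sketch with random arrays (not part of the repository); the batch size and sequence length are placeholders.

import numpy as np
from evaluation.util import cvt25, hand_points, valid_points

B, T = 3, 60
pred_poses = np.random.randn(B, T, 108)   # 54 predicted keypoints * 2
gt_poses = np.random.randn(1, T, 270)     # 135 OpenPose keypoints * 2

full = cvt25(pred_poses, gt_poses)        # predictions re-embedded into the 270-dim layout
print(full.shape)                         # (3, 60, 270)
print(hand_points(full).shape)            # (3, 60, 98): arm joints plus both hands
print(valid_points(full).shape)           # (3, 60, 108): hands plus some head/torso points
--------------------------------------------------------------------------------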
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .losses import *
--------------------------------------------------------------------------------
/losses/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/losses/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/losses/__pycache__/losses.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/losses/__pycache__/losses.cpython-37.pyc
--------------------------------------------------------------------------------
/losses/losses.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.getcwd())
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import numpy as np
10 |
11 | class KeypointLoss(nn.Module):
12 | def __init__(self):
13 | super(KeypointLoss, self).__init__()
14 |
15 | def forward(self, pred_seq, gt_seq, gt_conf=None):
16 | #pred_seq: (B, C, T)
17 | if gt_conf is not None:
18 | gt_conf = gt_conf >= 0.01
19 | return F.mse_loss(pred_seq[gt_conf], gt_seq[gt_conf], reduction='mean')
20 | else:
21 | return F.mse_loss(pred_seq, gt_seq)
22 |
23 |
24 | class KLLoss(nn.Module):
25 | def __init__(self, kl_tolerance):
26 | super(KLLoss, self).__init__()
27 | self.kl_tolerance = kl_tolerance
28 |
29 | def forward(self, mu, var, mul=1):
30 | kl_tolerance = self.kl_tolerance * mul * var.shape[1] / 64
31 | kld_loss = -0.5 * torch.sum(1 + var - mu**2 - var.exp(), dim=1)
32 | # kld_loss = -0.5 * torch.sum(1 + (var-1) - (mu) ** 2 - (var-1).exp(), dim=1)
33 | if self.kl_tolerance is not None:
34 | # above_line = kld_loss[kld_loss > self.kl_tolerance]
35 | # if len(above_line) > 0:
36 | # kld_loss = torch.mean(kld_loss)
37 | # else:
38 | # kld_loss = 0
39 |             kld_loss = torch.where(kld_loss > kl_tolerance, kld_loss, torch.tensor(kl_tolerance, device=kld_loss.device))
40 | # else:
41 | kld_loss = torch.mean(kld_loss)
42 | return kld_loss
43 |
44 |
45 | class L2KLLoss(nn.Module):
46 | def __init__(self, kl_tolerance):
47 | super(L2KLLoss, self).__init__()
48 | self.kl_tolerance = kl_tolerance
49 |
50 | def forward(self, x):
51 | # TODO: check
52 | kld_loss = torch.sum(x ** 2, dim=1)
53 | if self.kl_tolerance is not None:
54 | above_line = kld_loss[kld_loss > self.kl_tolerance]
55 | if len(above_line) > 0:
56 | kld_loss = torch.mean(kld_loss)
57 | else:
58 | kld_loss = 0
59 | else:
60 | kld_loss = torch.mean(kld_loss)
61 | return kld_loss
62 |
63 | class L2RegLoss(nn.Module):
64 | def __init__(self):
65 | super(L2RegLoss, self).__init__()
66 |
67 | def forward(self, x):
68 | #TODO: check
69 | return torch.sum(x**2)
70 |
71 |
72 | class L2Loss(nn.Module):
73 | def __init__(self):
74 | super(L2Loss, self).__init__()
75 |
76 | def forward(self, x):
77 | # TODO: check
78 | return torch.sum(x ** 2)
79 |
80 |
81 | class AudioLoss(nn.Module):
82 | def __init__(self):
83 | super(AudioLoss, self).__init__()
84 |
85 | def forward(self, dynamics, gt_poses):
86 | #pay attention, normalized
87 | mean = torch.mean(gt_poses, dim=-1).unsqueeze(-1)
88 | gt = gt_poses - mean
89 | return F.mse_loss(dynamics, gt)
90 |
91 | L1Loss = nn.L1Loss
--------------------------------------------------------------------------------
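A minimal sketch of the confidence-masked KeypointLoss above (not part of the repository); the masking threshold follows the code, while the tensor sizes are placeholders.

import torch
from losses import KeypointLoss

pred = torch.randn(2, 108, 30)            # (B, C, T)
gt = pred + 0.1 * torch.randn_like(pred)
conf = torch.rand(2, 108, 30)             # per-element confidence in [0, 1]

criterion = KeypointLoss()
print(criterion(pred, gt, gt_conf=conf).item())  # MSE over entries with confidence >= 0.01
print(criterion(pred, gt).item())                # plain MSE when no confidence is given
--------------------------------------------------------------------------------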
/nets/__init__.py:
--------------------------------------------------------------------------------
1 | from .smplx_face import TrainWrapper as s2g_face
2 | from .smplx_body_vq import TrainWrapper as s2g_body_vq
3 | from .smplx_body_pixel import TrainWrapper as s2g_body_pixel
4 | from .body_ae import TrainWrapper as s2g_body_ae
5 | from .LS3DCG import TrainWrapper as LS3DCG
6 | from .base import TrainWrapperBaseClass
7 |
8 | from .utils import normalize, denormalize
--------------------------------------------------------------------------------
/nets/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/base.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/base.cpython-37.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/init_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/init_model.cpython-37.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/layers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/layers.cpython-37.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/smplx_body_pixel.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/smplx_body_pixel.cpython-37.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/smplx_body_vq.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/smplx_body_vq.cpython-37.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/smplx_face.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/smplx_face.cpython-37.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/nets/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 |
5 | class TrainWrapperBaseClass():
6 | def __init__(self, args, config) -> None:
7 | self.init_optimizer()
8 |
9 | def init_optimizer(self) -> None:
10 | print('using Adam')
11 | self.generator_optimizer = optim.Adam(
12 | self.generator.parameters(),
13 | lr = self.config.Train.learning_rate.generator_learning_rate,
14 | betas=[0.9, 0.999]
15 | )
16 | if self.discriminator is not None:
17 | self.discriminator_optimizer = optim.Adam(
18 | self.discriminator.parameters(),
19 | lr = self.config.Train.learning_rate.discriminator_learning_rate,
20 | betas=[0.9, 0.999]
21 | )
22 |
23 | def __call__(self, bat):
24 | raise NotImplementedError
25 |
26 | def get_loss(self, **kwargs):
27 | raise NotImplementedError
28 |
29 | def state_dict(self):
30 | model_state = {
31 | 'generator': self.generator.state_dict(),
32 | 'generator_optim': self.generator_optimizer.state_dict(),
33 | 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
34 | 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
35 | }
36 | return model_state
37 |
38 | def parameters(self):
39 | return self.generator.parameters()
40 |
41 | def load_state_dict(self, state_dict):
42 | if 'generator' in state_dict:
43 | self.generator.load_state_dict(state_dict['generator'])
44 | else:
45 | self.generator.load_state_dict(state_dict)
46 |
47 | if 'generator_optim' in state_dict and self.generator_optimizer is not None:
48 | self.generator_optimizer.load_state_dict(state_dict['generator_optim'])
49 |
50 | if self.discriminator is not None:
51 | self.discriminator.load_state_dict(state_dict['discriminator'])
52 |
53 | if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None:
54 | self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim'])
55 |
56 | def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, **kwargs):
57 | raise NotImplementedError
58 |
59 | def init_params(self):
60 | if self.config.Data.pose.convert_to_6d:
61 | scale = 2
62 | else:
63 | scale = 1
64 |
65 | global_orient = round(0 * scale)
66 | leye_pose = reye_pose = round(0 * scale)
67 | jaw_pose = round(0 * scale)
68 | body_pose = round((63 - 24) * scale)
69 | left_hand_pose = right_hand_pose = round(45 * scale)
70 | if self.expression:
71 | expression = 100
72 | else:
73 | expression = 0
74 |
75 | b_j = 0
76 | jaw_dim = jaw_pose
77 | b_e = b_j + jaw_dim
78 | eye_dim = leye_pose + reye_pose
79 | b_b = b_e + eye_dim
80 | body_dim = global_orient + body_pose
81 | b_h = b_b + body_dim
82 | hand_dim = left_hand_pose + right_hand_pose
83 | b_f = b_h + hand_dim
84 | face_dim = expression
85 |
86 | self.dim_list = [b_j, b_e, b_b, b_h, b_f]
87 | self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
88 | self.pose = int(self.full_dim / round(3 * scale))
89 | self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
--------------------------------------------------------------------------------
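To make the dimension bookkeeping in init_params concrete, here is the arithmetic for the axis-angle case (scale = 1) with expressions enabled; the numbers follow directly from the assignments above (note that smplx_face.py overrides this method with full-body dimensions).

# scale = 1 (convert_to_6d == False), expression == True
jaw_dim  = 0              # round(0 * scale)
eye_dim  = 0 + 0          # leye_pose + reye_pose
body_dim = 0 + (63 - 24)  # global_orient + body_pose -> 39
hand_dim = 45 + 45        # left_hand_pose + right_hand_pose -> 90
face_dim = 100            # expression coefficients

dim_list = [0, 0, 0, 39, 129]                        # [b_j, b_e, b_b, b_h, b_f]
full_dim = jaw_dim + eye_dim + body_dim + hand_dim   # 129 (face_dim is tracked separately)
pose = full_dim // 3                                 # 43 joints in axis-angle form
each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]  # [0, 39, 90, 100]
--------------------------------------------------------------------------------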
/nets/body_ae.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.getcwd())
5 |
6 | from nets.base import TrainWrapperBaseClass
7 | from nets.spg.s2glayers import Discriminator as D_S2G
8 | from nets.spg.vqvae_1d import AE as s2g_body
9 | import torch
10 | import torch.optim as optim
11 | import torch.nn.functional as F
12 |
13 | from data_utils.lower_body import c_index, c_index_3d, c_index_6d
14 |
15 |
16 | def separate_aa(aa):
17 | aa = aa[:, :, :].reshape(aa.shape[0], aa.shape[1], -1, 5)
18 | axis = F.normalize(aa[:, :, :, :3], dim=-1)
19 | angle = F.normalize(aa[:, :, :, 3:5], dim=-1)
20 | return axis, angle
21 |
22 |
23 | class TrainWrapper(TrainWrapperBaseClass):
24 | '''
25 |     a wrapper that receives a batch from data_utils and computes the loss
26 | '''
27 |
28 | def __init__(self, args, config):
29 | self.args = args
30 | self.config = config
31 | self.device = torch.device(self.args.gpu)
32 | self.global_step = 0
33 |
34 | self.gan = False
35 | self.convert_to_6d = self.config.Data.pose.convert_to_6d
36 | self.preleng = self.config.Data.pose.pre_pose_length
37 | self.expression = self.config.Data.pose.expression
38 | self.epoch = 0
39 | self.init_params()
40 | self.num_classes = 4
41 | self.g = s2g_body(self.each_dim[1] + self.each_dim[2], embedding_dim=64, num_embeddings=0,
42 | num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device)
43 | if self.gan:
44 | self.discriminator = D_S2G(
45 | pose_dim=110 + 64, pose=self.pose
46 | ).to(self.device)
47 | else:
48 | self.discriminator = None
49 |
50 | if self.convert_to_6d:
51 | self.c_index = c_index_6d
52 | else:
53 | self.c_index = c_index_3d
54 |
55 | super().__init__(args, config)
56 |
57 | def init_optimizer(self):
58 |
59 | self.g_optimizer = optim.Adam(
60 | self.g.parameters(),
61 | lr=self.config.Train.learning_rate.generator_learning_rate,
62 | betas=[0.9, 0.999]
63 | )
64 |
65 | def state_dict(self):
66 | model_state = {
67 | 'g': self.g.state_dict(),
68 | 'g_optim': self.g_optimizer.state_dict(),
69 | 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
70 | 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
71 | }
72 | return model_state
73 |
74 |
75 | def __call__(self, bat):
76 | # assert (not self.args.infer), "infer mode"
77 | self.global_step += 1
78 |
79 | total_loss = None
80 | loss_dict = {}
81 |
82 | aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
83 |
84 | # id = bat['speaker'].to(self.device) - 20
85 | # id = F.one_hot(id, self.num_classes)
86 |
87 | poses = poses[:, self.c_index, :]
88 | gt_poses = poses[:, :, self.preleng:].permute(0, 2, 1)
89 |
90 | loss = 0
91 | loss_dict, loss = self.vq_train(gt_poses[:, :], 'g', self.g, loss_dict, loss)
92 |
93 | return total_loss, loss_dict
94 |
95 | def vq_train(self, gt, name, model, dict, total_loss, pre=None):
96 | x_recon = model(gt_poses=gt, pre_state=pre)
97 | loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, pre=pre)
98 | # total_loss = total_loss + loss
99 |
100 | if name == 'g':
101 | optimizer_name = 'g_optimizer'
102 |
103 | optimizer = getattr(self, optimizer_name)
104 | optimizer.zero_grad()
105 | loss.backward()
106 | optimizer.step()
107 |
108 | for key in list(loss_dict.keys()):
109 | dict[name + key] = loss_dict.get(key, 0).item()
110 | return dict, total_loss
111 |
112 | def get_loss(self,
113 | pred_poses,
114 | gt_poses,
115 | pre=None
116 | ):
117 | loss_dict = {}
118 |
119 |
120 | rec_loss = torch.mean(torch.abs(pred_poses - gt_poses))
121 | v_pr = pred_poses[:, 1:] - pred_poses[:, :-1]
122 | v_gt = gt_poses[:, 1:] - gt_poses[:, :-1]
123 | velocity_loss = torch.mean(torch.abs(v_pr - v_gt))
124 |
125 | if pre is None:
126 | f0_vel = 0
127 | else:
128 | v0_pr = pred_poses[:, 0] - pre[:, -1]
129 | v0_gt = gt_poses[:, 0] - pre[:, -1]
130 | f0_vel = torch.mean(torch.abs(v0_pr - v0_gt))
131 |
132 | gen_loss = rec_loss + velocity_loss + f0_vel
133 |
134 | loss_dict['rec_loss'] = rec_loss
135 | loss_dict['velocity_loss'] = velocity_loss
136 | # loss_dict['e_q_loss'] = e_q_loss
137 | if pre is not None:
138 | loss_dict['f0_vel'] = f0_vel
139 |
140 | return gen_loss, loss_dict
141 |
142 | def load_state_dict(self, state_dict):
143 | self.g.load_state_dict(state_dict['g'])
144 |
145 | def extract(self, x):
146 | self.g.eval()
147 | if x.shape[2] > self.full_dim:
148 | if x.shape[2] == 239:
149 | x = x[:, :, 102:]
150 | x = x[:, :, self.c_index]
151 | feat = self.g.encode(x)
152 | return feat.transpose(1, 2), x
153 |
--------------------------------------------------------------------------------
/nets/init_model.py:
--------------------------------------------------------------------------------
1 | from nets import *
2 |
3 |
4 | def init_model(model_name, args, config):
5 |
6 | if model_name == 's2g_face':
7 | generator = s2g_face(
8 | args,
9 | config,
10 | )
11 | elif model_name == 's2g_body_vq':
12 | generator = s2g_body_vq(
13 | args,
14 | config,
15 | )
16 | elif model_name == 's2g_body_pixel':
17 | generator = s2g_body_pixel(
18 | args,
19 | config,
20 | )
21 | elif model_name == 's2g_body_ae':
22 | generator = s2g_body_ae(
23 | args,
24 | config,
25 | )
26 | elif model_name == 's2g_LS3DCG':
27 | generator = LS3DCG(
28 | args,
29 | config,
30 | )
31 | else:
32 | raise ValueError
33 | return generator
34 |
35 |
36 |
--------------------------------------------------------------------------------
/nets/smplx_face.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.getcwd())
5 |
6 | from nets.layers import *
7 | from nets.base import TrainWrapperBaseClass
8 | # from nets.spg.faceformer import Faceformer
9 | from nets.spg.s2g_face import Generator as s2g_face
10 | from losses import KeypointLoss
11 | from nets.utils import denormalize
12 | from data_utils import get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
13 | import numpy as np
14 | import torch.optim as optim
15 | import torch.nn.functional as F
16 | from sklearn.preprocessing import normalize
17 | import smplx
18 |
19 |
20 | class TrainWrapper(TrainWrapperBaseClass):
21 | '''
22 |     a wrapper that receives a batch from data_utils and computes the loss
23 | '''
24 |
25 | def __init__(self, args, config):
26 | self.args = args
27 | self.config = config
28 | self.device = torch.device(self.args.gpu)
29 | self.global_step = 0
30 |
31 | self.convert_to_6d = self.config.Data.pose.convert_to_6d
32 | self.expression = self.config.Data.pose.expression
33 | self.epoch = 0
34 | self.init_params()
35 | self.num_classes = 4
36 |
37 | self.generator = s2g_face(
38 | n_poses=self.config.Data.pose.generate_length,
39 | each_dim=self.each_dim,
40 | dim_list=self.dim_list,
41 | training=not self.args.infer,
42 | device=self.device,
43 | identity=False if self.convert_to_6d else True,
44 | num_classes=self.num_classes,
45 | ).to(self.device)
46 |
47 | # self.generator = Faceformer().to(self.device)
48 |
49 | self.discriminator = None
50 | self.am = None
51 |
52 | self.MSELoss = KeypointLoss().to(self.device)
53 | super().__init__(args, config)
54 |
55 | def init_optimizer(self):
56 | self.generator_optimizer = optim.SGD(
57 | filter(lambda p: p.requires_grad,self.generator.parameters()),
58 | lr=0.001,
59 | momentum=0.9,
60 | nesterov=False,
61 | )
62 |
63 | def init_params(self):
64 | if self.convert_to_6d:
65 | scale = 2
66 | else:
67 | scale = 1
68 |
69 | global_orient = round(3 * scale)
70 | leye_pose = reye_pose = round(3 * scale)
71 | jaw_pose = round(3 * scale)
72 | body_pose = round(63 * scale)
73 | left_hand_pose = right_hand_pose = round(45 * scale)
74 | if self.expression:
75 | expression = 100
76 | else:
77 | expression = 0
78 |
79 | b_j = 0
80 | jaw_dim = jaw_pose
81 | b_e = b_j + jaw_dim
82 | eye_dim = leye_pose + reye_pose
83 | b_b = b_e + eye_dim
84 | body_dim = global_orient + body_pose
85 | b_h = b_b + body_dim
86 | hand_dim = left_hand_pose + right_hand_pose
87 | b_f = b_h + hand_dim
88 | face_dim = expression
89 |
90 | self.dim_list = [b_j, b_e, b_b, b_h, b_f]
91 | self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim + face_dim
92 | self.pose = int(self.full_dim / round(3 * scale))
93 | self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
94 |
95 | def __call__(self, bat):
96 | # assert (not self.args.infer), "infer mode"
97 | self.global_step += 1
98 |
99 | total_loss = None
100 | loss_dict = {}
101 |
102 | aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
103 | id = bat['speaker'].to(self.device) - 20
104 | id = F.one_hot(id, self.num_classes)
105 |
106 | aud = aud.permute(0, 2, 1)
107 | gt_poses = poses.permute(0, 2, 1)
108 |
109 | if self.expression:
110 | expression = bat['expression'].to(self.device).to(torch.float32)
111 | gt_poses = torch.cat([gt_poses, expression.permute(0, 2, 1)], dim=2)
112 |
113 | pred_poses, _ = self.generator(
114 | aud,
115 | gt_poses,
116 | id,
117 | )
118 |
119 | G_loss, G_loss_dict = self.get_loss(
120 | pred_poses=pred_poses,
121 | gt_poses=gt_poses,
122 | pre_poses=None,
123 | mode='training_G',
124 | gt_conf=None,
125 | aud=aud,
126 | )
127 |
128 | self.generator_optimizer.zero_grad()
129 | G_loss.backward()
130 |         grad = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.config.Train.max_gradient_norm)
131 | loss_dict['grad'] = grad.item()
132 | self.generator_optimizer.step()
133 |
134 | for key in list(G_loss_dict.keys()):
135 | loss_dict[key] = G_loss_dict.get(key, 0).item()
136 |
137 | return total_loss, loss_dict
138 |
139 | def get_loss(self,
140 | pred_poses,
141 | gt_poses,
142 | pre_poses,
143 | aud,
144 | mode='training_G',
145 | gt_conf=None,
146 | exp=1,
147 | gt_nzero=None,
148 | pre_nzero=None,
149 | ):
150 | loss_dict = {}
151 |
152 |
153 | [b_j, b_e, b_b, b_h, b_f] = self.dim_list
154 |
155 | MSELoss = torch.mean(torch.abs(pred_poses[:, :, :6] - gt_poses[:, :, :6]))
156 | if self.expression:
157 | expl = torch.mean((pred_poses[:, :, -100:] - gt_poses[:, :, -100:])**2)
158 | else:
159 | expl = 0
160 |
161 | gen_loss = expl + MSELoss
162 |
163 | loss_dict['MSELoss'] = MSELoss
164 | if self.expression:
165 | loss_dict['exp_loss'] = expl
166 |
167 | return gen_loss, loss_dict
168 |
169 | def infer_on_audio(self, aud_fn, id=None, initial_pose=None, norm_stats=None, w_pre=False, frame=None, am=None, am_sr=16000, **kwargs):
170 | '''
171 | initial_pose: (B, C, T), normalized
172 | (aud_fn, txgfile) -> generated motion (B, T, C)
173 | '''
174 | output = []
175 |
176 | # assert self.args.infer, "train mode"
177 | self.generator.eval()
178 |
179 | if self.config.Data.pose.normalization:
180 | assert norm_stats is not None
181 | data_mean = norm_stats[0]
182 | data_std = norm_stats[1]
183 |
184 | # assert initial_pose.shape[-1] == pre_length
185 | if initial_pose is not None:
186 | gt = initial_pose[:,:,:].permute(0, 2, 1).to(self.generator.device).to(torch.float32)
187 | pre_poses = initial_pose[:,:,:15].permute(0, 2, 1).to(self.generator.device).to(torch.float32)
188 | poses = initial_pose.permute(0, 2, 1).to(self.generator.device).to(torch.float32)
189 | B = pre_poses.shape[0]
190 | else:
191 | gt = None
192 | pre_poses=None
193 | B = 1
194 |
195 | if type(aud_fn) == torch.Tensor:
196 | aud_feat = torch.tensor(aud_fn, dtype=torch.float32).to(self.generator.device)
197 | num_poses_to_generate = aud_feat.shape[-1]
198 | else:
199 | aud_feat = get_mfcc_ta(aud_fn, am=am, am_sr=am_sr, fps=30, encoder_choice='faceformer')
200 | aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
201 | aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.generator.device).transpose(1, 2)
202 | if frame is None:
203 | frame = aud_feat.shape[2]*30//16000
204 | #
205 | if id is None:
206 | id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device)
207 | else:
208 | id = F.one_hot(id, self.num_classes).to(self.generator.device)
209 |
210 | with torch.no_grad():
211 | pred_poses = self.generator(aud_feat, pre_poses, id, time_steps=frame)[0]
212 | pred_poses = pred_poses.cpu().numpy()
213 | output = pred_poses
214 |
215 | if self.config.Data.pose.normalization:
216 | output = denormalize(output, data_mean, data_std)
217 |
218 | return output
219 |
220 |
221 | def generate(self, wv2_feat, frame):
222 | '''
223 | initial_pose: (B, C, T), normalized
224 | (aud_fn, txgfile) -> generated motion (B, T, C)
225 | '''
226 | output = []
227 |
228 | # assert self.args.infer, "train mode"
229 | self.generator.eval()
230 |
231 | B = 1
232 |
233 | id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device)
234 | id = id.repeat(wv2_feat.shape[0], 1)
235 |
236 | with torch.no_grad():
237 | pred_poses = self.generator(wv2_feat, None, id, time_steps=frame)[0]
238 | return pred_poses
239 |
--------------------------------------------------------------------------------
/nets/spg/__pycache__/gated_pixelcnn_v2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/spg/__pycache__/gated_pixelcnn_v2.cpython-37.pyc
--------------------------------------------------------------------------------
/nets/spg/__pycache__/s2g_face.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/spg/__pycache__/s2g_face.cpython-37.pyc
--------------------------------------------------------------------------------
/nets/spg/__pycache__/s2glayers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/spg/__pycache__/s2glayers.cpython-37.pyc
--------------------------------------------------------------------------------
/nets/spg/__pycache__/vqvae_1d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/spg/__pycache__/vqvae_1d.cpython-37.pyc
--------------------------------------------------------------------------------
/nets/spg/__pycache__/vqvae_modules.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/spg/__pycache__/vqvae_modules.cpython-37.pyc
--------------------------------------------------------------------------------
/nets/spg/__pycache__/wav2vec.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/spg/__pycache__/wav2vec.cpython-37.pyc
--------------------------------------------------------------------------------
/nets/spg/gated_pixelcnn_v2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def weights_init(m):
7 | classname = m.__class__.__name__
8 | if classname.find('Conv') != -1:
9 | try:
10 | nn.init.xavier_uniform_(m.weight.data)
11 | m.bias.data.fill_(0)
12 | except AttributeError:
13 | print("Skipping initialization of ", classname)
14 |
15 |
16 | class GatedActivation(nn.Module):
17 | def __init__(self):
18 | super().__init__()
19 |
20 | def forward(self, x):
21 | x, y = x.chunk(2, dim=1)
22 |         return torch.tanh(x) * torch.sigmoid(y)
23 |
24 |
25 | class GatedMaskedConv2d(nn.Module):
26 | def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10, bh_model=False):
27 | super().__init__()
28 | assert kernel % 2 == 1, print("Kernel size must be odd")
29 | self.mask_type = mask_type
30 | self.residual = residual
31 | self.bh_model = bh_model
32 |
33 | self.class_cond_embedding = nn.Embedding(
34 | n_classes, 2 * dim
35 | )
36 |
37 | kernel_shp = (kernel // 2 + 1, 3 if self.bh_model else 1) # (ceil(n/2), n)
38 | padding_shp = (kernel // 2, 1 if self.bh_model else 0)
39 | self.vert_stack = nn.Conv2d(
40 | dim, dim * 2,
41 | kernel_shp, 1, padding_shp
42 | )
43 |
44 | self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1)
45 |
46 | kernel_shp = (1, 2)
47 | padding_shp = (0, 1)
48 | self.horiz_stack = nn.Conv2d(
49 | dim, dim * 2,
50 | kernel_shp, 1, padding_shp
51 | )
52 |
53 | self.horiz_resid = nn.Conv2d(dim, dim, 1)
54 |
55 | self.gate = GatedActivation()
56 |
57 | def make_causal(self):
58 | self.vert_stack.weight.data[:, :, -1].zero_() # Mask final row
59 | self.horiz_stack.weight.data[:, :, :, -1].zero_() # Mask final column
60 |
61 | def forward(self, x_v, x_h, h):
62 | if self.mask_type == 'A':
63 | self.make_causal()
64 |
65 | h = self.class_cond_embedding(h)
66 | h_vert = self.vert_stack(x_v)
67 | h_vert = h_vert[:, :, :x_v.size(-2), :]
68 | out_v = self.gate(h_vert + h[:, :, None, None])
69 |
70 | if self.bh_model:
71 | h_horiz = self.horiz_stack(x_h)
72 | h_horiz = h_horiz[:, :, :, :x_h.size(-1)]
73 | v2h = self.vert_to_horiz(h_vert)
74 |
75 | out = self.gate(v2h + h_horiz + h[:, :, None, None])
76 | if self.residual:
77 | out_h = self.horiz_resid(out) + x_h
78 | else:
79 | out_h = self.horiz_resid(out)
80 | else:
81 | if self.residual:
82 | out_v = self.horiz_resid(out_v) + x_v
83 | else:
84 | out_v = self.horiz_resid(out_v)
85 | out_h = out_v
86 |
87 | return out_v, out_h
88 |
89 |
90 | class GatedPixelCNN(nn.Module):
91 | def __init__(self, input_dim=256, dim=64, n_layers=15, n_classes=10, audio=False, bh_model=False):
92 | super().__init__()
93 | self.dim = dim
94 | self.audio = audio
95 | self.bh_model = bh_model
96 |
97 | if self.audio:
98 | self.embedding_aud = nn.Conv2d(256, dim, 1, 1, padding=0)
99 | self.fusion_v = nn.Conv2d(dim * 2, dim, 1, 1, padding=0)
100 | self.fusion_h = nn.Conv2d(dim * 2, dim, 1, 1, padding=0)
101 |
102 | # Create embedding layer to embed input
103 | self.embedding = nn.Embedding(input_dim, dim)
104 |
105 | # Building the PixelCNN layer by layer
106 | self.layers = nn.ModuleList()
107 |
108 | # Initial block with Mask-A convolution
109 | # Rest with Mask-B convolutions
110 | for i in range(n_layers):
111 | mask_type = 'A' if i == 0 else 'B'
112 | kernel = 7 if i == 0 else 3
113 | residual = False if i == 0 else True
114 |
115 | self.layers.append(
116 | GatedMaskedConv2d(mask_type, dim, kernel, residual, n_classes, bh_model)
117 | )
118 |
119 | # Add the output layer
120 | self.output_conv = nn.Sequential(
121 | nn.Conv2d(dim, 512, 1),
122 | nn.ReLU(True),
123 | nn.Conv2d(512, input_dim, 1)
124 | )
125 |
126 | self.apply(weights_init)
127 |
128 | self.dp = nn.Dropout(0.1)
129 |
130 | def forward(self, x, label, aud=None):
131 | shp = x.size() + (-1,)
132 | x = self.embedding(x.view(-1)).view(shp) # (B, H, W, C)
133 |         x = x.permute(0, 3, 1, 2)  # (B, C, H, W)
134 |
135 | x_v, x_h = (x, x)
136 | for i, layer in enumerate(self.layers):
137 | if i == 1 and self.audio is True:
138 | aud = self.embedding_aud(aud)
139 | a = torch.ones(aud.shape[-2]).to(aud.device)
140 | a = self.dp(a)
141 | aud = (aud.transpose(-1, -2) * a).transpose(-1, -2)
142 | x_v = self.fusion_v(torch.cat([x_v, aud], dim=1))
143 | if self.bh_model:
144 | x_h = self.fusion_h(torch.cat([x_h, aud], dim=1))
145 | x_v, x_h = layer(x_v, x_h, label)
146 |
147 | if self.bh_model:
148 | return self.output_conv(x_h)
149 | else:
150 | return self.output_conv(x_v)
151 |
152 | def generate(self, label, shape=(8, 8), batch_size=64, aud_feat=None, pre_latents=None, pre_audio=None):
153 | param = next(self.parameters())
154 | x = torch.zeros(
155 | (batch_size, *shape),
156 | dtype=torch.int64, device=param.device
157 | )
158 | if pre_latents is not None:
159 | x = torch.cat([pre_latents, x], dim=1)
160 | aud_feat = torch.cat([pre_audio, aud_feat], dim=2)
161 | h0 = pre_latents.shape[1]
162 | h = h0 + shape[0]
163 | else:
164 | h0 = 0
165 | h = shape[0]
166 |
167 | for i in range(h0, h):
168 | for j in range(shape[1]):
169 | if self.audio:
170 | logits = self.forward(x, label, aud_feat)
171 | else:
172 | logits = self.forward(x, label)
173 | probs = F.softmax(logits[:, :, i, j], -1)
174 | x.data[:, i, j].copy_(
175 | probs.multinomial(1).squeeze().data
176 | )
177 | return x[:, h0:h]
178 |
--------------------------------------------------------------------------------
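
A minimal usage sketch for the GatedPixelCNN prior above. Every concrete number here (codebook size 512, an 8x8 latent grid, 4 classes, a batch of 2) is an illustrative assumption rather than a value taken from this repository's configs; the point is only the tensor shapes expected by forward and generate.

    import torch

    prior = GatedPixelCNN(input_dim=512, dim=64, n_layers=15, n_classes=4,
                          audio=True, bh_model=True)
    codes = torch.randint(0, 512, (2, 8, 8))   # discrete latent indices, (B, H, W)
    label = torch.randint(0, 4, (2,))          # one class / speaker id per sample
    aud = torch.randn(2, 256, 8, 8)            # audio feature map with 256 channels
    logits = prior(codes, label, aud)          # (2, 512, 8, 8): logits over the codebook
    samples = prior.generate(label, shape=(8, 8), batch_size=2, aud_feat=aud)

generate fills the grid autoregressively, re-running forward once per position and sampling each index from the softmax over codebook entries, which is why sampling is slow compared to a single forward pass.
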
/nets/spg/s2g_face.py:
--------------------------------------------------------------------------------
1 | '''
2 | not exactly the same as the official repo but the results are good
3 | '''
4 | import sys
5 | import os
6 |
7 | from transformers import Wav2Vec2Processor
8 |
9 | from .wav2vec import Wav2Vec2Model
10 | from torchaudio.sox_effects import apply_effects_tensor
11 |
12 | sys.path.append(os.getcwd())
13 |
14 | import numpy as np
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | import torchaudio as ta
19 | import math
20 | from nets.layers import SeqEncoder1D, SeqTranslator1D, ConvNormRelu
21 |
22 |
23 | """ from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
24 |
25 |
26 | def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
27 | """
28 | :param audio: 1 x T tensor containing a 16kHz audio signal
29 | :param frame_rate: frame rate for video (we need one audio chunk per video frame)
30 | :param chunk_size: number of audio samples per chunk
31 | :return: num_chunks x chunk_size tensor containing sliced audio
32 | """
33 | samples_per_frame = 16000 // frame_rate
34 | padding = (chunk_size - samples_per_frame) // 2
35 | audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
36 | anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame))
37 | audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0)
38 | return audio
39 |
40 |
41 | class MeshtalkEncoder(nn.Module):
42 | def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'):
43 | """
44 | :param latent_dim: size of the latent audio embedding
45 | :param model_name: name of the model, used to load and save the model
46 | """
47 | super().__init__()
48 |
49 | self.melspec = ta.transforms.MelSpectrogram(
50 | sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80
51 | )
52 |
53 | conv_len = 5
54 | self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len)
55 | self.weights_init(self.convert_dimensions)
56 | self.receptive_field = conv_len
57 |
58 | convs = []
59 | for i in range(6):
60 | dilation = 2 * (i % 3 + 1)
61 | self.receptive_field += (conv_len - 1) * dilation
62 | convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)]
63 | self.weights_init(convs[-1])
64 | self.convs = torch.nn.ModuleList(convs)
65 | self.code = torch.nn.Linear(128, latent_dim)
66 |
67 | self.apply(lambda x: self.weights_init(x))
68 |
69 | def weights_init(self, m):
70 | if isinstance(m, torch.nn.Conv1d):
71 | torch.nn.init.xavier_uniform_(m.weight)
72 | try:
73 | torch.nn.init.constant_(m.bias, .01)
74 | except:
75 | pass
76 |
77 | def forward(self, audio: torch.Tensor):
78 | """
79 | :param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame
80 | :return: code: B x T x latent_dim Tensor containing a latent audio code/embedding
81 | """
82 | B, T = audio.shape[0], audio.shape[1]
83 | x = self.melspec(audio).squeeze(1)
84 | x = torch.log(x.clamp(min=1e-10, max=None))
85 | if T == 1:
86 | x = x.unsqueeze(1)
87 |
88 | # Convert to the right dimensionality
89 | x = x.view(-1, x.shape[2], x.shape[3])
90 | x = F.leaky_relu(self.convert_dimensions(x), .2)
91 |
92 | # Process stacks
93 | for conv in self.convs:
94 | x_ = F.leaky_relu(conv(x), .2)
95 | if self.training:
96 | x_ = F.dropout(x_, .2)
97 | l = (x.shape[2] - x_.shape[2]) // 2
98 | x = (x[:, :, l:-l] + x_) / 2
99 |
100 | x = torch.mean(x, dim=-1)
101 | x = x.view(B, T, x.shape[-1])
102 | x = self.code(x)
103 |
104 | return {"code": x}
105 |
106 |
107 | class AudioEncoder(nn.Module):
108 | def __init__(self, in_dim, out_dim, identity=False, num_classes=0):
109 | super().__init__()
110 | self.identity = identity
111 | if self.identity:
112 | in_dim = in_dim + 64
113 | self.id_mlp = nn.Conv1d(num_classes, 64, 1, 1)
114 | self.first_net = SeqTranslator1D(in_dim, out_dim,
115 | min_layers_num=3,
116 | residual=True,
117 | norm='ln'
118 | )
119 | self.grus = nn.GRU(out_dim, out_dim, 1, batch_first=True)
120 | self.dropout = nn.Dropout(0.1)
121 | # self.att = nn.MultiheadAttention(out_dim, 4, dropout=0.1, batch_first=True)
122 |
123 | def forward(self, spectrogram, pre_state=None, id=None, time_steps=None):
124 |
125 | spectrogram = spectrogram
126 | spectrogram = self.dropout(spectrogram)
127 | if self.identity:
128 | id = id.reshape(id.shape[0], -1, 1).repeat(1, 1, spectrogram.shape[2]).to(torch.float32)
129 | id = self.id_mlp(id)
130 | spectrogram = torch.cat([spectrogram, id], dim=1)
131 | x1 = self.first_net(spectrogram)# .permute(0, 2, 1)
132 | if time_steps is not None:
133 | x1 = F.interpolate(x1, size=time_steps, align_corners=False, mode='linear')
134 | # x1, _ = self.att(x1, x1, x1)
135 | # x1, hidden_state = self.grus(x1)
136 | # x1 = x1.permute(0, 2, 1)
137 | hidden_state=None
138 |
139 | return x1, hidden_state
140 |
141 |
142 | class Generator(nn.Module):
143 | def __init__(self,
144 | n_poses,
145 | each_dim: list,
146 | dim_list: list,
147 | training=False,
148 | device=None,
149 | identity=True,
150 | num_classes=0,
151 | ):
152 | super().__init__()
153 |
154 | self.training = training
155 | self.device = device
156 | self.gen_length = n_poses
157 | self.identity = identity
158 |
159 | norm = 'ln'
160 | in_dim = 256
161 | out_dim = 256
162 |
163 | self.encoder_choice = 'faceformer'
164 |
165 | if self.encoder_choice == 'meshtalk':
166 | self.audio_encoder = MeshtalkEncoder(latent_dim=in_dim)
167 | elif self.encoder_choice == 'faceformer':
168 | # wav2vec 2.0 weights initialization
169 |             self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")  # alternative checkpoint: "vitouphy/wav2vec2-xls-r-300m-phoneme"
170 | self.audio_encoder.feature_extractor._freeze_parameters()
171 | self.audio_feature_map = nn.Linear(768, in_dim)
172 | else:
173 | self.audio_encoder = AudioEncoder(in_dim=64, out_dim=out_dim)
174 |
175 | self.audio_middle = AudioEncoder(in_dim, out_dim, identity, num_classes)
176 |
177 | self.dim_list = dim_list
178 |
179 | self.decoder = nn.ModuleList()
180 | self.final_out = nn.ModuleList()
181 |
182 | self.decoder.append(nn.Sequential(
183 | ConvNormRelu(out_dim, 64, norm=norm),
184 | ConvNormRelu(64, 64, norm=norm),
185 | ConvNormRelu(64, 64, norm=norm),
186 | ))
187 | self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1))
188 |
189 | self.decoder.append(nn.Sequential(
190 | ConvNormRelu(out_dim, out_dim, norm=norm),
191 | ConvNormRelu(out_dim, out_dim, norm=norm),
192 | ConvNormRelu(out_dim, out_dim, norm=norm),
193 | ))
194 | self.final_out.append(nn.Conv1d(out_dim, each_dim[3], 1, 1))
195 |
196 | def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None):
197 | if self.training:
198 | time_steps = gt_poses.shape[1]
199 |
200 | # vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
201 | if self.encoder_choice == 'meshtalk':
202 | in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000)
203 | feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2)
204 | elif self.encoder_choice == 'faceformer':
205 | hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1), frame_num=time_steps).last_hidden_state
206 | feature = self.audio_feature_map(hidden_states).transpose(1, 2)
207 | else:
208 | feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
209 |
210 | # hidden_states = in_spec
211 |
212 | feature, _ = self.audio_middle(feature, id=id)
213 |
214 | out = []
215 |
216 | for i in range(self.decoder.__len__()):
217 | mid = self.decoder[i](feature)
218 | mid = self.final_out[i](mid)
219 | out.append(mid)
220 |
221 | out = torch.cat(out, dim=1)
222 | out = out.transpose(1, 2)
223 |
224 | return out, None
225 |
226 |
227 |
--------------------------------------------------------------------------------
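
As its docstring says, audio_chunking above slices a 16 kHz waveform into one overlapping 1-second chunk per 30 fps video frame for the MeshTalk-style encoder path. A quick shape check with a made-up 2-second signal:

    import torch

    wav = torch.randn(1, 32000)   # 1 x T: two seconds of (random) 16 kHz audio
    chunks = audio_chunking(wav, frame_rate=30, chunk_size=16000)
    print(chunks.shape)           # about (60, 16000): one chunk per video frame
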
/nets/spg/vqvae_1d.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from .wav2vec import Wav2Vec2Model
7 | from .vqvae_modules import VectorQuantizerEMA, ConvNormRelu, Res_CNR_Stack
8 |
9 |
10 |
11 | class AudioEncoder(nn.Module):
12 | def __init__(self, in_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
13 | super(AudioEncoder, self).__init__()
14 | self._num_hiddens = num_hiddens
15 | self._num_residual_layers = num_residual_layers
16 | self._num_residual_hiddens = num_residual_hiddens
17 |
18 | self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True)
19 |
20 | self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
21 | self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True,
22 | sample='down')
23 | self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
24 | self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True, sample='down')
25 | self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
26 |
27 | def forward(self, x, frame_num=0):
28 | h = self.project(x)
29 | h = self._enc_1(h)
30 | h = self._down_1(h)
31 | h = self._enc_2(h)
32 | h = self._down_2(h)
33 | h = self._enc_3(h)
34 | return h
35 |
36 |
37 | class Wav2VecEncoder(nn.Module):
38 | def __init__(self, num_hiddens, num_residual_layers):
39 | super(Wav2VecEncoder, self).__init__()
40 | self._num_hiddens = num_hiddens
41 | self._num_residual_layers = num_residual_layers
42 |
43 | self.audio_encoder = Wav2Vec2Model.from_pretrained(
44 |             "facebook/wav2vec2-base-960h")  # alternative checkpoint: "vitouphy/wav2vec2-xls-r-300m-phoneme"
45 | self.audio_encoder.feature_extractor._freeze_parameters()
46 |
47 | self.project = ConvNormRelu(768, self._num_hiddens, leaky=True)
48 |
49 | self._enc_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
50 | self._down_1 = ConvNormRelu(self._num_hiddens, self._num_hiddens, leaky=True, residual=True, sample='down')
51 | self._enc_2 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
52 | self._down_2 = ConvNormRelu(self._num_hiddens, self._num_hiddens, leaky=True, residual=True, sample='down')
53 | self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
54 |
55 | def forward(self, x, frame_num):
56 | h = self.audio_encoder(x.squeeze(), frame_num=frame_num).last_hidden_state.transpose(1, 2)
57 | h = self.project(h)
58 | h = self._enc_1(h)
59 | h = self._down_1(h)
60 | h = self._enc_2(h)
61 | h = self._down_2(h)
62 | h = self._enc_3(h)
63 | return h
64 |
65 |
66 | class Encoder(nn.Module):
67 | def __init__(self, in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
68 | super(Encoder, self).__init__()
69 | self._num_hiddens = num_hiddens
70 | self._num_residual_layers = num_residual_layers
71 | self._num_residual_hiddens = num_residual_hiddens
72 |
73 | self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True)
74 |
75 | self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
76 | self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True,
77 | sample='down')
78 | self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
79 | self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True, sample='down')
80 | self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
81 |
82 | self.pre_vq_conv = nn.Conv1d(self._num_hiddens, embedding_dim, 1, 1)
83 |
84 | def forward(self, x):
85 | h = self.project(x)
86 | h = self._enc_1(h)
87 | h = self._down_1(h)
88 | h = self._enc_2(h)
89 | h = self._down_2(h)
90 | h = self._enc_3(h)
91 | h = self.pre_vq_conv(h)
92 | return h
93 |
94 |
95 | class Frame_Enc(nn.Module):
96 | def __init__(self, in_dim, num_hiddens):
97 | super(Frame_Enc, self).__init__()
98 | self.in_dim = in_dim
99 | self.num_hiddens = num_hiddens
100 |
101 | # self.enc = transformer_Enc(in_dim, num_hiddens, 2, 8, 256, 256, 256, 256, 0, dropout=0.1, n_position=4)
102 | self.proj = nn.Conv1d(in_dim, num_hiddens, 1, 1)
103 | self.enc = Res_CNR_Stack(num_hiddens, 2, leaky=True)
104 | self.proj_1 = nn.Conv1d(256*4, num_hiddens, 1, 1)
105 | self.proj_2 = nn.Conv1d(256*4, num_hiddens*2, 1, 1)
106 |
107 | def forward(self, x):
108 | # x = self.enc(x, None)[0].reshape(x.shape[0], -1, 1)
109 | x = self.enc(self.proj(x)).reshape(x.shape[0], -1, 1)
110 | second_last = self.proj_2(x)
111 | last = self.proj_1(x)
112 | return second_last, last
113 |
114 |
115 |
116 | class Decoder(nn.Module):
117 | def __init__(self, out_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, ae=False):
118 | super(Decoder, self).__init__()
119 | self._num_hiddens = num_hiddens
120 | self._num_residual_layers = num_residual_layers
121 | self._num_residual_hiddens = num_residual_hiddens
122 |
123 | self.aft_vq_conv = nn.Conv1d(embedding_dim, self._num_hiddens, 1, 1)
124 |
125 | self._dec_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)
126 | self._up_2 = ConvNormRelu(self._num_hiddens, self._num_hiddens // 2, leaky=True, residual=True, sample='up')
127 | self._dec_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)
128 | self._up_3 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens // 4, leaky=True, residual=True,
129 | sample='up')
130 | self._dec_3 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)
131 |
132 | if ae:
133 | self.frame_enc = Frame_Enc(out_dim, self._num_hiddens // 4)
134 | self.gru_sl = nn.GRU(self._num_hiddens // 2, self._num_hiddens // 2, 1, batch_first=True)
135 | self.gru_l = nn.GRU(self._num_hiddens // 4, self._num_hiddens // 4, 1, batch_first=True)
136 |
137 | self.project = nn.Conv1d(self._num_hiddens // 4, out_dim, 1, 1)
138 |
139 | def forward(self, h, last_frame=None):
140 |
141 | h = self.aft_vq_conv(h)
142 | h = self._dec_1(h)
143 | h = self._up_2(h)
144 | h = self._dec_2(h)
145 | h = self._up_3(h)
146 | h = self._dec_3(h)
147 |
148 | recon = self.project(h)
149 | return recon, None
150 |
151 |
152 | class Pre_VQ(nn.Module):
153 | def __init__(self, num_hiddens, embedding_dim, num_chunks):
154 | super(Pre_VQ, self).__init__()
155 | self.conv = nn.Conv1d(num_hiddens, num_hiddens, 1, 1, 0, groups=num_chunks)
156 | self.bn = nn.GroupNorm(num_chunks, num_hiddens)
157 | self.relu = nn.ReLU()
158 | self.proj = nn.Conv1d(num_hiddens, embedding_dim, 1, 1, 0, groups=num_chunks)
159 |
160 | def forward(self, x):
161 | x = self.conv(x)
162 | x = self.bn(x)
163 | x = self.relu(x)
164 | x = self.proj(x)
165 | return x
166 |
167 |
168 | class VQVAE(nn.Module):
169 | """VQ-VAE"""
170 |
171 | def __init__(self, in_dim, embedding_dim, num_embeddings,
172 | num_hiddens, num_residual_layers, num_residual_hiddens,
173 | commitment_cost=0.25, decay=0.99, share=False):
174 | super().__init__()
175 | self.in_dim = in_dim
176 | self.embedding_dim = embedding_dim
177 | self.num_embeddings = num_embeddings
178 | self.share_code_vq = share
179 |
180 | self.encoder = Encoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
181 | self.vq_layer = VectorQuantizerEMA(embedding_dim, num_embeddings, commitment_cost, decay)
182 | self.decoder = Decoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
183 |
184 | def forward(self, gt_poses, id=None, pre_state=None):
185 | z = self.encoder(gt_poses.transpose(1, 2))
186 | if not self.training:
187 | e, _ = self.vq_layer(z)
188 | x_recon, cur_state = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
189 | return e, x_recon
190 |
191 | e, e_q_loss = self.vq_layer(z)
192 | gt_recon, cur_state = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
193 |
194 | return e_q_loss, gt_recon.transpose(1, 2)
195 |
196 | def encode(self, gt_poses, id=None):
197 | z = self.encoder(gt_poses.transpose(1, 2))
198 | e, latents = self.vq_layer(z)
199 | return e, latents
200 |
201 | def decode(self, b, w, e=None, latents=None, pre_state=None):
202 | if e is not None:
203 | x = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
204 | else:
205 | e = self.vq_layer.quantize(latents)
206 | e = e.view(b, w, -1).permute(0, 2, 1).contiguous()
207 | x = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None)
208 | return x
209 |
210 |
211 | class AE(nn.Module):
212 |     """Plain autoencoder counterpart of the VQ-VAE (no vector quantization)"""
213 |
214 | def __init__(self, in_dim, embedding_dim, num_embeddings,
215 | num_hiddens, num_residual_layers, num_residual_hiddens):
216 | super().__init__()
217 | self.in_dim = in_dim
218 | self.embedding_dim = embedding_dim
219 | self.num_embeddings = num_embeddings
220 |
221 | self.encoder = Encoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
222 | self.decoder = Decoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, True)
223 |
224 | def forward(self, gt_poses, id=None, pre_state=None):
225 | z = self.encoder(gt_poses.transpose(1, 2))
226 | if not self.training:
227 | x_recon, cur_state = self.decoder(z, pre_state.transpose(1, 2) if pre_state is not None else None)
228 | return z, x_recon
229 | gt_recon, cur_state = self.decoder(z, pre_state.transpose(1, 2) if pre_state is not None else None)
230 |
231 | return gt_recon.transpose(1, 2)
232 |
233 | def encode(self, gt_poses, id=None):
234 | z = self.encoder(gt_poses.transpose(1, 2))
235 | return z
236 |
--------------------------------------------------------------------------------
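
A sketch of how the VQVAE above can be driven, assuming ConvNormRelu / Res_CNR_Stack / VectorQuantizerEMA from vqvae_modules.py behave as their call sites here imply (two stride-2 downsampling stages, so the sequence length should be divisible by 4). The hyper-parameters below are placeholders, not the values from config/*.json.

    import torch

    vq = VQVAE(in_dim=129, embedding_dim=64, num_embeddings=1024,
               num_hiddens=512, num_residual_layers=2, num_residual_hiddens=512)
    poses = torch.randn(4, 88, 129)    # (B, T, C) pose sequences
    e_q_loss, recon = vq(poses)        # training path: VQ loss + reconstruction
    vq.eval()
    codes, latents = vq.encode(poses)  # quantized features plus whatever the quantizer
                                       # returns as its second output (indices at test time)
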
/nets/spg/wav2vec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import copy
6 | import math
7 | from transformers import Wav2Vec2Model,Wav2Vec2Config
8 | from transformers.modeling_outputs import BaseModelOutput
9 | from typing import Optional, Tuple
10 | _CONFIG_FOR_DOC = "Wav2Vec2Config"
11 |
12 | # the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
13 | # initialize our encoder with the pre-trained wav2vec 2.0 weights.
14 | def _compute_mask_indices(
15 | shape: Tuple[int, int],
16 | mask_prob: float,
17 | mask_length: int,
18 | attention_mask: Optional[torch.Tensor] = None,
19 | min_masks: int = 0,
20 | ) -> np.ndarray:
21 | bsz, all_sz = shape
22 | mask = np.full((bsz, all_sz), False)
23 |
24 | all_num_mask = int(
25 | mask_prob * all_sz / float(mask_length)
26 | + np.random.rand()
27 | )
28 | all_num_mask = max(min_masks, all_num_mask)
29 | mask_idcs = []
30 | padding_mask = attention_mask.ne(1) if attention_mask is not None else None
31 | for i in range(bsz):
32 | if padding_mask is not None:
33 | sz = all_sz - padding_mask[i].long().sum().item()
34 | num_mask = int(
35 | mask_prob * sz / float(mask_length)
36 | + np.random.rand()
37 | )
38 | num_mask = max(min_masks, num_mask)
39 | else:
40 | sz = all_sz
41 | num_mask = all_num_mask
42 |
43 | lengths = np.full(num_mask, mask_length)
44 |
45 | if sum(lengths) == 0:
46 | lengths[0] = min(mask_length, sz - 1)
47 |
48 | min_len = min(lengths)
49 | if sz - min_len <= num_mask:
50 | min_len = sz - num_mask - 1
51 |
52 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
53 | mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
54 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
55 |
56 | min_len = min([len(m) for m in mask_idcs])
57 | for i, mask_idc in enumerate(mask_idcs):
58 | if len(mask_idc) > min_len:
59 | mask_idc = np.random.choice(mask_idc, min_len, replace=False)
60 | mask[i, mask_idc] = True
61 | return mask
62 |
63 | # linear interpolation layer
64 | def linear_interpolation(features, input_fps, output_fps, output_len=None):
65 | features = features.transpose(1, 2)
66 | seq_len = features.shape[2] / float(input_fps)
67 | if output_len is None:
68 | output_len = int(seq_len * output_fps)
69 | output_features = F.interpolate(features,size=output_len,align_corners=False,mode='linear')
70 | return output_features.transpose(1, 2)
71 |
72 |
73 | class Wav2Vec2Model(Wav2Vec2Model):
74 | def __init__(self, config):
75 | super().__init__(config)
76 | def forward(
77 | self,
78 | input_values,
79 | attention_mask=None,
80 | output_attentions=None,
81 | output_hidden_states=None,
82 | return_dict=None,
83 | frame_num=None
84 | ):
85 | self.config.output_attentions = True
86 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
87 | output_hidden_states = (
88 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
89 | )
90 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
91 |
92 | hidden_states = self.feature_extractor(input_values)
93 | hidden_states = hidden_states.transpose(1, 2)
94 |
95 | hidden_states = linear_interpolation(hidden_states, 50, 30,output_len=frame_num)
96 |
97 | if attention_mask is not None:
98 | output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
99 | attention_mask = torch.zeros(
100 | hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
101 | )
102 | attention_mask[
103 | (torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
104 | ] = 1
105 | attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
106 |
107 | hidden_states = self.feature_projection(hidden_states)
108 |
109 | if self.config.apply_spec_augment and self.training:
110 | batch_size, sequence_length, hidden_size = hidden_states.size()
111 | if self.config.mask_time_prob > 0:
112 | mask_time_indices = _compute_mask_indices(
113 | (batch_size, sequence_length),
114 | self.config.mask_time_prob,
115 | self.config.mask_time_length,
116 | attention_mask=attention_mask,
117 | min_masks=2,
118 | )
119 | hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
120 | if self.config.mask_feature_prob > 0:
121 | mask_feature_indices = _compute_mask_indices(
122 | (batch_size, hidden_size),
123 | self.config.mask_feature_prob,
124 | self.config.mask_feature_length,
125 | )
126 | mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
127 | hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
128 | encoder_outputs = self.encoder(
129 | hidden_states[0],
130 | attention_mask=attention_mask,
131 | output_attentions=output_attentions,
132 | output_hidden_states=output_hidden_states,
133 | return_dict=return_dict,
134 | )
135 | hidden_states = encoder_outputs[0]
136 | if not return_dict:
137 | return (hidden_states,) + encoder_outputs[1:]
138 |
139 | return BaseModelOutput(
140 | last_hidden_state=hidden_states,
141 | hidden_states=encoder_outputs.hidden_states,
142 | attentions=encoder_outputs.attentions,
143 | )
144 |
--------------------------------------------------------------------------------
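
The wrapper above mainly differs from the stock transformers Wav2Vec2Model by inserting linear_interpolation, which resamples the roughly 50 Hz wav2vec 2.0 feature stream to the 30 fps motion rate before the transformer encoder. A standalone shape check on a random feature tensor:

    import torch

    feats = torch.randn(1, 100, 768)   # two seconds of 50 Hz features, hidden size 768
    resampled = linear_interpolation(feats, input_fps=50, output_fps=30)
    print(resampled.shape)             # torch.Size([1, 60, 768])
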
/nets/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import textgrid as tg
3 | import numpy as np
4 |
5 | def get_parameter_size(model):
6 | total_num = sum(p.numel() for p in model.parameters())
7 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
8 | return total_num, trainable_num
9 |
10 | def denormalize(kps, data_mean, data_std):
11 | '''
12 | kps: (B, T, C)
13 | '''
14 | data_std = data_std.reshape(1, 1, -1)
15 | data_mean = data_mean.reshape(1, 1, -1)
16 | return (kps * data_std) + data_mean
17 |
18 | def normalize(kps, data_mean, data_std):
19 | '''
20 | kps: (B, T, C)
21 | '''
22 | data_std = data_std.squeeze().reshape(1, 1, -1)
23 | data_mean = data_mean.squeeze().reshape(1, 1, -1)
24 |
25 | return (kps-data_mean) / data_std
26 |
27 | def parse_audio(textgrid_file):
28 | '''a demo implementation'''
29 | words=['but', 'as', 'to', 'that', 'with', 'of', 'the', 'and', 'or', 'not', 'which', 'what', 'this', 'for', 'because', 'if', 'so', 'just', 'about', 'like', 'by', 'how', 'from', 'whats', 'now', 'very', 'that', 'also', 'actually', 'who', 'then', 'well', 'where', 'even', 'today', 'between', 'than', 'when']
30 | txt=tg.TextGrid.fromFile(textgrid_file)
31 |
32 | total_time=int(np.ceil(txt.maxTime))
33 | code_seq=np.zeros(total_time)
34 |
35 | word_level=txt[0]
36 |
37 | for i in range(len(word_level)):
38 | start_time=word_level[i].minTime
39 | end_time=word_level[i].maxTime
40 | mark=word_level[i].mark
41 |
42 | if mark in words:
43 | start=int(np.round(start_time))
44 | end=int(np.round(end_time))
45 |
46 | if start >= len(code_seq) or end >= len(code_seq):
47 | code_seq[-1] = 1
48 | else:
49 | code_seq[start]=1
50 |
51 | return code_seq
52 |
53 |
54 | def get_path(model_name, model_type):
55 | if model_name == 's2g_body_pixel':
56 | if model_type == 'mfcc':
57 | return './experiments/2022-10-09-smplx_S2G-body-pixel-aud-3p/ckpt-99.pth'
58 | elif model_type == 'wv2':
59 | return './experiments/2022-10-28-smplx_S2G-body-pixel-wv2-sg2/ckpt-99.pth'
60 | elif model_type == 'random':
61 | return './experiments/2022-10-09-smplx_S2G-body-pixel-random-3p/ckpt-99.pth'
62 | elif model_type == 'wbhmodel':
63 | return './experiments/2022-11-02-smplx_S2G-body-pixel-w-bhmodel/ckpt-99.pth'
64 | elif model_type == 'wobhmodel':
65 | return './experiments/2022-11-02-smplx_S2G-body-pixel-wo-bhmodel/ckpt-99.pth'
66 | elif model_name == 's2g_body':
67 | if model_type == 'a+m-vae':
68 | return './experiments/2022-10-19-smplx_S2G-body-audio-motion-vae/ckpt-99.pth'
69 | elif model_type == 'a-vae':
70 | return './experiments/2022-10-18-smplx_S2G-body-audiovae/ckpt-99.pth'
71 | elif model_type == 'a-ed':
72 | return './experiments/2022-10-18-smplx_S2G-body-audioae/ckpt-99.pth'
73 | elif model_name == 's2g_LS3DCG':
74 | return './experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth'
75 | elif model_name == 's2g_body_vq':
76 | if model_type == 'n_com_1024':
77 | return './experiments/2022-10-29-smplx_S2G-body-vq-cn1024/ckpt-99.pth'
78 | elif model_type == 'n_com_2048':
79 | return './experiments/2022-10-29-smplx_S2G-body-vq-cn2048/ckpt-99.pth'
80 | elif model_type == 'n_com_4096':
81 | return './experiments/2022-10-29-smplx_S2G-body-vq-cn4096/ckpt-99.pth'
82 | elif model_type == 'n_com_8192':
83 | return './experiments/2022-11-02-smplx_S2G-body-vq-cn8192/ckpt-99.pth'
84 | elif model_type == 'n_com_16384':
85 | return './experiments/2022-11-02-smplx_S2G-body-vq-cn16384/ckpt-99.pth'
86 | elif model_type == 'n_com_170000':
87 | return './experiments/2022-10-30-smplx_S2G-body-vq-cn170000/ckpt-99.pth'
88 | elif model_type == 'com_1024':
89 | return './experiments/2022-10-29-smplx_S2G-body-vq-composition/ckpt-99.pth'
90 | elif model_type == 'com_2048':
91 | return './experiments/2022-10-31-smplx_S2G-body-vq-composition2048/ckpt-99.pth'
92 | elif model_type == 'com_4096':
93 | return './experiments/2022-10-31-smplx_S2G-body-vq-composition4096/ckpt-99.pth'
94 | elif model_type == 'com_8192':
95 | return './experiments/2022-11-02-smplx_S2G-body-vq-composition8192/ckpt-99.pth'
96 | elif model_type == 'com_16384':
97 | return './experiments/2022-11-02-smplx_S2G-body-vq-composition16384/ckpt-99.pth'
98 |
99 |
100 | def get_dpath(model_name, model_type):
101 | if model_name == 's2g_body_pixel':
102 | if model_type == 'audio':
103 | return './experiments/2022-10-26-smplx_S2G-d-pixel-aud/ckpt-9.pth'
104 | elif model_type == 'wv2':
105 | return './experiments/2022-11-04-smplx_S2G-d-pixel-wv2/ckpt-9.pth'
106 | elif model_type == 'random':
107 | return './experiments/2022-10-26-smplx_S2G-d-pixel-random/ckpt-9.pth'
108 | elif model_type == 'wbhmodel':
109 | return './experiments/2022-11-10-smplx_S2G-hD-wbhmodel/ckpt-9.pth'
110 | # return './experiments/2022-11-05-smplx_S2G-d-pixel-wbhmodel/ckpt-9.pth'
111 | elif model_type == 'wobhmodel':
112 | return './experiments/2022-11-10-smplx_S2G-hD-wobhmodel/ckpt-9.pth'
113 | # return './experiments/2022-11-05-smplx_S2G-d-pixel-wobhmodel/ckpt-9.pth'
114 | elif model_name == 's2g_body':
115 | if model_type == 'a+m-vae':
116 | return './experiments/2022-10-26-smplx_S2G-d-audio+motion-vae/ckpt-9.pth'
117 | elif model_type == 'a-vae':
118 | return './experiments/2022-10-26-smplx_S2G-d-audio-vae/ckpt-9.pth'
119 | elif model_type == 'a-ed':
120 | return './experiments/2022-10-26-smplx_S2G-d-audio-ae/ckpt-9.pth'
121 | elif model_name == 's2g_LS3DCG':
122 | return './experiments/2022-10-26-smplx_S2G-d-ls3dcg/ckpt-9.pth'
--------------------------------------------------------------------------------
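
normalize and denormalize above are exact inverses whenever the same per-channel statistics are passed to both; a quick sanity check with arbitrary array sizes:

    import numpy as np

    kps = np.random.randn(2, 30, 129)          # (B, T, C); the sizes are arbitrary
    mean = kps.mean(axis=(0, 1))
    std = kps.std(axis=(0, 1)) + 1e-8          # guard against division by zero
    restored = denormalize(normalize(kps, mean, std), mean, std)
    assert np.allclose(restored, kps)
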
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy~=1.21.5
2 | transformers~=4.22.1
3 | matplotlib~=3.2.2
4 | textgrid~=1.5
5 | smplx~=0.1.28
6 | scikit-learn~=1.0.2
7 | pyrender~=0.1.45
8 | trimesh~=3.14.1
9 | tqdm~=4.64.1
10 | librosa~=0.9.2
11 | scipy~=1.7.3
12 | python_speech_features~=0.6
13 | opencv-python~=4.7.0.68
14 | pyglet~=1.5
--------------------------------------------------------------------------------
/scripts/.idea/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/scripts/.idea/__init__.py
--------------------------------------------------------------------------------
/scripts/.idea/aws.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/scripts/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/scripts/.idea/get_prevar.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
4 |
5 | sys.path.append(os.getcwd())
6 | from glob import glob
7 |
8 | import numpy as np
9 | import json
10 | import smplx as smpl
11 |
12 | from nets import *
13 | from repro_nets import *
14 | from trainer.options import parse_args
15 | from data_utils import torch_data
16 | from trainer.config import load_JsonConfig
17 |
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | from torch.utils import data
22 |
23 | def init_model(model_name, model_path, args, config):
24 | if model_name == 'freeMo':
25 | # generator = freeMo_Generator(args)
26 | # generator = freeMo_Generator(args)
27 | generator = freeMo_dev(args, config)
28 | # generator.load_state_dict(torch.load(model_path)['generator'])
29 | elif model_name == 'smplx_S2G':
30 | generator = smplx_S2G(args, config)
31 | elif model_name == 'StyleGestures':
32 | generator = StyleGesture_Generator(
33 | args,
34 | config
35 | )
36 | elif model_name == 'Audio2Gestures':
37 | config.Train.using_mspec_stat = False
38 | generator = Audio2Gesture_Generator(
39 | args,
40 | config,
41 | torch.zeros([1, 1, 108]),
42 | torch.ones([1, 1, 108])
43 | )
44 | elif model_name == 'S2G':
45 | generator = S2G_Generator(
46 | args,
47 | config,
48 | )
49 | elif model_name == 'Tmpt':
50 | generator = S2G_Generator(
51 | args,
52 | config,
53 | )
54 | else:
55 | raise NotImplementedError
56 |
57 | model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
58 | if model_name == 'smplx_S2G':
59 | generator.generator.load_state_dict(model_ckpt['generator']['generator'])
60 | elif 'generator' in list(model_ckpt.keys()):
61 | generator.load_state_dict(model_ckpt['generator'])
62 | else:
63 | model_ckpt = {'generator': model_ckpt}
64 | generator.load_state_dict(model_ckpt)
65 |
66 | return generator
67 |
68 |
69 |
70 | def prevar_loader(data_root, speakers, args, config, model_path, device, generator):
71 | path = model_path.split('ckpt')[0]
72 | file = os.path.join(os.path.dirname(path), "pre_variable.npy")
73 | data_base = torch_data(
74 | data_root=data_root,
75 | speakers=speakers,
76 | split='pre',
77 | limbscaling=False,
78 | normalization=config.Data.pose.normalization,
79 | norm_method=config.Data.pose.norm_method,
80 | split_trans_zero=False,
81 | num_pre_frames=config.Data.pose.pre_pose_length,
82 | num_generate_length=config.Data.pose.generate_length,
83 | num_frames=15,
84 | aud_feat_win_size=config.Data.aud.aud_feat_win_size,
85 | aud_feat_dim=config.Data.aud.aud_feat_dim,
86 | feat_method=config.Data.aud.feat_method,
87 | smplx=True,
88 | audio_sr=22000,
89 | convert_to_6d=config.Data.pose.convert_to_6d,
90 | expression=config.Data.pose.expression
91 | )
92 |
93 | data_base.get_dataset()
94 | pre_set = data_base.all_dataset
95 | pre_loader = data.DataLoader(pre_set, batch_size=config.DataLoader.batch_size, shuffle=False, drop_last=True)
96 |
97 | total_pose = []
98 |
99 | with torch.no_grad():
100 | for bat in pre_loader:
101 | pose = bat['poses'].to(device).to(torch.float32)
102 | expression = bat['expression'].to(device).to(torch.float32)
103 | pose = pose.permute(0, 2, 1)
104 | pose = torch.cat([pose[:, :15], pose[:, 15:30], pose[:, 30:45], pose[:, 45:60], pose[:, 60:]], dim=0)
105 | expression = expression.permute(0, 2, 1)
106 | expression = torch.cat([expression[:, :15], expression[:, 15:30], expression[:, 30:45], expression[:, 45:60], expression[:, 60:]], dim=0)
107 | pose = torch.cat([pose, expression], dim=-1)
108 | pose = pose.reshape(pose.shape[0], -1, 1)
109 | pose_code = generator.generator.pre_pose_encoder(pose).squeeze().detach().cpu()
110 | total_pose.append(np.asarray(pose_code))
111 | total_pose = np.concatenate(total_pose, axis=0)
112 | mean = np.mean(total_pose, axis=0)
113 | std = np.std(total_pose, axis=0)
114 | prevar = (mean, std)
115 | np.save(file, prevar, allow_pickle=True)
116 |
117 | return mean, std
118 |
119 | def main():
120 | parser = parse_args()
121 | args = parser.parse_args()
122 | device = torch.device(args.gpu)
123 | torch.cuda.set_device(device)
124 |
125 | config = load_JsonConfig(args.config_file)
126 |
127 | print('init model...')
128 | generator = init_model(config.Model.model_name, args.model_path, args, config)
129 | print('init pre-pose vectors...')
130 | mean, std = prevar_loader(config.Data.data_root, args.speakers, args, config, args.model_path, device, generator)
131 |
132 | if __name__ == '__main__':
133 |     main()
--------------------------------------------------------------------------------
/scripts/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/scripts/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/scripts/.idea/lower body:
--------------------------------------------------------------------------------
1 | 0, 1, 3, 4, 6, 7, 9, 10,
--------------------------------------------------------------------------------
/scripts/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/scripts/.idea/scripts.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/scripts/.idea/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/scripts/.idea/test.png
--------------------------------------------------------------------------------
/scripts/.idea/testtext.py:
--------------------------------------------------------------------------------
1 | import cv2
2 |
3 | # path being defined from where the system will read the image
4 | path = r'test.png'
5 | # read the image from disk with cv2.imread
6 | image1 = cv2.imread(path)
7 | # Window name being specified where the image will be displayed
8 | window_name1 = 'image'
9 | # font for the text being specified
10 | font1 = cv2.FONT_HERSHEY_SIMPLEX
11 | # org for the text being specified
12 | org1 = (50, 50)
13 | # font scale for the text being specified
14 | fontScale1 = 1
15 | # Blue color for the text being specified from BGR
16 | color1 = (255, 255, 255)
17 | # Line thickness for the text being specified at 2 px
18 | thickness1 = 2
19 | # Using the cv2.putText() method for inserting text in the image of the specified path
20 | image_1 = cv2.putText(image1, 'CAT IN BOX', org1, font1, fontScale1, color1, thickness1, cv2.LINE_AA)
21 | # Displaying the output image
22 | cv2.imshow(window_name1, image_1)
23 | cv2.waitKey(0)
24 | cv2.destroyAllWindows()
25 |
--------------------------------------------------------------------------------
/scripts/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/scripts/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/scripts/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/scripts/__pycache__/diversity.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/scripts/__pycache__/diversity.cpython-37.pyc
--------------------------------------------------------------------------------
/scripts/continuity.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | # os.environ["PYOPENGL_PLATFORM"] = "egl"
4 | from transformers import Wav2Vec2Processor
5 | from visualise.rendering import RenderTool
6 |
7 | sys.path.append(os.getcwd())
8 | from glob import glob
9 |
10 | import numpy as np
11 | import json
12 | import smplx as smpl
13 |
14 | from nets import *
15 | from trainer.options import parse_args
16 | from data_utils import torch_data
17 | from trainer.config import load_JsonConfig
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | from torch.utils import data
23 | from scripts.diversity import init_model, init_dataloader, get_vertices
24 | from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
25 | from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
26 | import time
27 |
28 |
29 | global_orient = torch.tensor([3.0747, -0.0158, -0.0152])
30 |
31 |
32 | def infer(data_root, g_body, g_face, g_body2, exp_name, infer_loader, infer_set, device, norm_stats, smplx,
33 | smplx_model, rendertool, args=None, config=None, var=None):
34 | am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
35 | am_sr = 16000
36 | num_sample = 1
37 | face = False
38 | if face:
39 | body_static = torch.zeros([1, 162], device='cuda')
40 | body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
41 | stand = False
42 | j = 0
43 | gt_0 = None
44 |
45 | for bat in infer_loader:
46 | poses_ = bat['poses'].to(torch.float32).to(device)
47 | if poses_.shape[-1] == 300:
48 | j = j + 1
49 | if j > 1000:
50 | continue
51 | id = bat['speaker'].to('cuda') - 20
52 | if config.Data.pose.expression:
53 | expression = bat['expression'].to(device).to(torch.float32)
54 | poses = torch.cat([poses_, expression], dim=1)
55 | else:
56 | poses = poses_
57 | cur_wav_file = bat['aud_file'][0]
58 | betas = bat['betas'][0].to(torch.float64).to('cuda')
59 | # betas = torch.zeros([1, 300], dtype=torch.float64).to('cuda')
60 | gt = poses.to('cuda').squeeze().transpose(1, 0)
61 | if config.Data.pose.normalization:
62 | gt = denormalize(gt, norm_stats[0], norm_stats[1]).squeeze(dim=0)
63 | if config.Data.pose.convert_to_6d:
64 | if config.Data.pose.expression:
65 | gt_exp = gt[:, -100:]
66 | gt = gt[:, :-100]
67 |
68 | gt = gt.reshape(gt.shape[0], -1, 6)
69 | gt = matrix_to_axis_angle(rotation_6d_to_matrix(gt)).reshape(gt.shape[0], -1)
70 | gt = torch.cat([gt, gt_exp], -1)
71 | if face:
72 | gt = torch.cat([gt[:, :3], body_static.repeat(gt.shape[0], 1), gt[:, -100:]], dim=-1)
73 |
74 | result_list = [gt]
75 |
76 | # cur_wav_file = '.\\training_data\\french-V4.wav'
77 |
78 | # pred_face = g_face.infer_on_audio(cur_wav_file,
79 | # initial_pose=poses_,
80 | # norm_stats=None,
81 | # w_pre=False,
82 | # # id=id,
83 | # frame=None,
84 | # am=am,
85 | # am_sr=am_sr
86 | # )
87 | #
88 | # pred_face = torch.tensor(pred_face).squeeze().to('cuda')
89 |
90 | pred_face = torch.zeros([gt.shape[0], 103], device='cuda')
91 | pred_jaw = pred_face[:, :3]
92 | pred_face = pred_face[:, 3:]
93 |
94 | # id = torch.tensor([0], device='cuda')
95 |
96 | for i in range(num_sample):
97 | pred_res = g_body.infer_on_audio(cur_wav_file,
98 | initial_pose=poses_,
99 | norm_stats=norm_stats,
100 | txgfile=None,
101 | id=id,
102 | var=var,
103 | fps=30,
104 | continuity=True,
105 | smooth=False
106 | )
107 | pred = torch.tensor(pred_res).squeeze().to('cuda')
108 |
109 | if pred.shape[0] < pred_face.shape[0]:
110 | repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
111 | pred = torch.cat([pred, repeat_frame], dim=0)
112 | else:
113 | pred = pred[:pred_face.shape[0], :]
114 |
115 | if config.Data.pose.convert_to_6d:
116 | pred = pred.reshape(pred.shape[0], -1, 6)
117 | pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
118 | pred = pred.reshape(pred.shape[0], -1)
119 |
120 | pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)
121 | # pred[:, 9:12] = global_orient
122 | pred = part2full(pred, stand)
123 | if face:
124 | pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)
125 | # result_list[0] = poses2pred(result_list[0], stand)
126 | # if gt_0 is None:
127 | # gt_0 = gt
128 | # pred = pred2poses(pred, gt_0)
129 | # result_list[0] = poses2poses(result_list[0], gt_0)
130 |
131 | result_list.append(pred)
132 |
133 | vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)
134 |
135 | result_list = [res.to('cpu') for res in result_list]
136 |         result_arr = np.concatenate(result_list[1:], axis=0)
137 | file_name = 'visualise/video/' + config.Log.name + '/' + \
138 | cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
139 |         np.save(file_name, result_arr)
140 |
141 | rendertool._render_continuity(cur_wav_file, vertices_list[1], frame=60)
142 |
143 |
144 | def main():
145 | parser = parse_args()
146 | args = parser.parse_args()
147 | device = torch.device(args.gpu)
148 | torch.cuda.set_device(device)
149 |
150 | config = load_JsonConfig(args.config_file)
151 |
152 | smplx = True
153 |
154 | os.environ['smplx_npz_path'] = config.smplx_npz_path
155 | os.environ['extra_joint_path'] = config.extra_joint_path
156 | os.environ['j14_regressor_path'] = config.j14_regressor_path
157 |
158 | print('init model...')
159 | body_model_name = 's2g_body_pixel'
160 | body_model_path = './experiments/2022-12-31-smplx_S2G-body-pixel-conti-wide/ckpt-99.pth' # './experiments/2022-10-09-smplx_S2G-body-pixel-aud-3p/ckpt-99.pth'
161 | generator = init_model(body_model_name, body_model_path, args, config)
162 |
163 | # face_model_name = 's2g_face'
164 | # face_model_path = './experiments/2022-10-15-smplx_S2G-face-sgd-3p-wv2/ckpt-99.pth' # './experiments/2022-09-28-smplx_S2G-face-faceformer-3d/ckpt-99.pth'
165 | # generator_face = init_model(face_model_name, face_model_path, args, config)
166 | generator_face = None
167 | print('init dataloader...')
168 | infer_set, infer_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)
169 |
170 |     print('init smplx model...')
171 | dtype = torch.float64
172 | model_params = dict(model_path='E:/PycharmProjects/Motion-Projects/models',
173 | model_type='smplx',
174 | create_global_orient=True,
175 | create_body_pose=True,
176 | create_betas=True,
177 | num_betas=300,
178 | create_left_hand_pose=True,
179 | create_right_hand_pose=True,
180 | use_pca=False,
181 | flat_hand_mean=False,
182 | create_expression=True,
183 | num_expression_coeffs=100,
184 | num_pca_comps=12,
185 | create_jaw_pose=True,
186 | create_leye_pose=True,
187 | create_reye_pose=True,
188 | create_transl=False,
189 | # gender='ne',
190 | dtype=dtype, )
191 | smplx_model = smpl.create(**model_params).to('cuda')
192 | print('init rendertool...')
193 | rendertool = RenderTool('visualise/video/' + config.Log.name)
194 |
195 | infer(config.Data.data_root, generator, generator_face, None, args.exp_name, infer_loader, infer_set, device,
196 | norm_stats, smplx, smplx_model, rendertool, args, config, (None, None))
197 |
198 |
199 | if __name__ == '__main__':
200 | main()
201 |
--------------------------------------------------------------------------------
/scripts/test_body.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 |
5 | os.environ['CUDA_VISIBLE_DEVICES'] = '3'
6 | sys.path.append(os.getcwd())
7 |
8 | from tqdm import tqdm
9 | from transformers import Wav2Vec2Processor
10 |
11 | from evaluation.FGD import EmbeddingSpaceEvaluator
12 |
13 | from evaluation.metrics import LVD
14 |
15 | import numpy as np
16 | import smplx as smpl
17 |
18 | from data_utils.lower_body import part2full, poses2pred
19 | from data_utils.utils import get_mfcc_ta
20 | from nets import *
21 | from nets.utils import get_path, get_dpath
22 | from trainer.options import parse_args
23 | from data_utils import torch_data
24 | from trainer.config import load_JsonConfig
25 |
26 | import torch
27 | from torch.utils import data
28 | from data_utils.get_j import to3d, get_joints
29 |
30 |
31 | def init_model(model_name, model_path, args, config):
32 | if model_name == 's2g_face':
33 | generator = s2g_face(
34 | args,
35 | config,
36 | )
37 | elif model_name == 's2g_body_vq':
38 | generator = s2g_body_vq(
39 | args,
40 | config,
41 | )
42 | elif model_name == 's2g_body_pixel':
43 | generator = s2g_body_pixel(
44 | args,
45 | config,
46 | )
47 | elif model_name == 's2g_body_ae':
48 | generator = s2g_body_ae(
49 | args,
50 | config,
51 | )
52 | else:
53 | raise NotImplementedError
54 |
55 | model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
56 | generator.load_state_dict(model_ckpt['generator'])
57 |
58 | return generator
59 |
60 |
61 | def init_dataloader(data_root, speakers, args, config):
62 | data_base = torch_data(
63 | data_root=data_root,
64 | speakers=speakers,
65 | split='test',
66 | limbscaling=False,
67 | normalization=config.Data.pose.normalization,
68 | norm_method=config.Data.pose.norm_method,
69 | split_trans_zero=False,
70 | num_pre_frames=config.Data.pose.pre_pose_length,
71 | num_generate_length=config.Data.pose.generate_length,
72 | num_frames=30,
73 | aud_feat_win_size=config.Data.aud.aud_feat_win_size,
74 | aud_feat_dim=config.Data.aud.aud_feat_dim,
75 | feat_method=config.Data.aud.feat_method,
76 | smplx=True,
77 | audio_sr=22000,
78 | convert_to_6d=config.Data.pose.convert_to_6d,
79 | expression=config.Data.pose.expression,
80 | config=config
81 | )
82 |
83 | if config.Data.pose.normalization:
84 | norm_stats_fn = os.path.join(os.path.dirname(args.model_path), "norm_stats.npy")
85 | norm_stats = np.load(norm_stats_fn, allow_pickle=True)
86 | data_base.data_mean = norm_stats[0]
87 | data_base.data_std = norm_stats[1]
88 | else:
89 | norm_stats = None
90 |
91 | data_base.get_dataset()
92 | test_set = data_base.all_dataset
93 | test_loader = data.DataLoader(test_set, batch_size=1, shuffle=False)
94 |
95 | return test_set, test_loader, norm_stats
96 |
97 |
98 | def body_loss(gt, prs):
99 | loss_dict = {}
100 | # LVD
101 | v_diff = LVD(gt[:, :22, :], prs[:, :, :22, :], symmetrical=False, weight=False)
102 | loss_dict['LVD'] = v_diff
103 | # Accuracy
104 | error = (gt - prs).norm(p=2, dim=-1).sum(dim=-1).mean()
105 | loss_dict['error'] = error
106 | # Diversity
107 | var = prs.var(dim=0).norm(p=2, dim=-1).sum(dim=-1).mean()
108 | loss_dict['diverse'] = var
109 |
110 | return loss_dict
111 |
112 |
113 | def test(test_loader, generator, FGD_handler, smplx_model, config):
114 | print('start testing')
115 |
116 | am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
117 | am_sr = 16000
118 |
119 | loss_dict = {}
120 | B = 2
121 | with torch.no_grad():
122 | count = 0
123 | for bat in tqdm(test_loader, desc="Testing......"):
124 | count = count + 1
125 | # if count == 10:
126 | # break
127 | _, poses, exp = bat['aud_feat'].to('cuda').to(torch.float32), bat['poses'].to('cuda').to(torch.float32), \
128 | bat['expression'].to('cuda').to(torch.float32)
129 | id = bat['speaker'].to('cuda') - 20
130 | betas = bat['betas'][0].to('cuda').to(torch.float64)
131 | poses = torch.cat([poses, exp], dim=-2).transpose(-1, -2)
132 |
133 | cur_wav_file = bat['aud_file'][0]
134 |
135 | zero_face = torch.zeros([B, poses.shape[1], 103], device='cuda')
136 |
137 | joints_list = []
138 |
139 | pred = generator.infer_on_audio(cur_wav_file,
140 | id=id,
141 | fps=30,
142 | B=B,
143 | am=am,
144 | am_sr=am_sr,
145 | frame=poses.shape[0]
146 | )
147 | pred = torch.tensor(pred, device='cuda')
148 |
149 | FGD_handler.push_samples(pred, poses)
150 |
151 | poses = poses.squeeze()
152 | poses = to3d(poses, config)
153 |
154 | if pred.shape[2] > 129:
155 | pred = pred[:, :, 103:]
156 |
157 | pred = torch.cat([zero_face[:, :pred.shape[1], :3], pred, zero_face[:, :pred.shape[1], 3:]], dim=-1)
158 | full_pred = []
159 | for j in range(B):
160 | f_pred = part2full(pred[j])
161 | full_pred.append(f_pred)
162 |
163 | for i in range(full_pred.__len__()):
164 | full_pred[i] = full_pred[i].unsqueeze(dim=0)
165 | full_pred = torch.cat(full_pred, dim=0)
166 |
167 | pred_joints = get_joints(smplx_model, betas, full_pred)
168 |
169 | poses = poses2pred(poses)
170 | poses = torch.cat([zero_face[0, :, :3], poses[:, 3:165], zero_face[0, :, 3:]], dim=-1)
171 | gt_joints = get_joints(smplx_model, betas, poses[:pred_joints.shape[1]])
172 | FGD_handler.push_joints(pred_joints, gt_joints)
173 | aud = get_mfcc_ta(cur_wav_file, fps=30, sr=16000, am='not None', encoder_choice='onset')
174 | FGD_handler.push_aud(torch.from_numpy(aud))
175 |
176 | bat_loss_dict = body_loss(gt_joints, pred_joints)
177 |
178 |             if loss_dict:  # non-empty
179 | for key in list(bat_loss_dict.keys()):
180 | loss_dict[key] += bat_loss_dict[key]
181 | else:
182 | for key in list(bat_loss_dict.keys()):
183 | loss_dict[key] = bat_loss_dict[key]
184 | for key in loss_dict.keys():
185 | loss_dict[key] = loss_dict[key] / count
186 | print(key + '=' + str(loss_dict[key].item()))
187 |
188 | # MAAC = FGD_handler.get_MAAC()
189 | # print(MAAC)
190 | fgd_dist, feat_dist = FGD_handler.get_scores()
191 | print('fgd_dist=', fgd_dist.item())
192 | print('feat_dist=', feat_dist.item())
193 | BCscore = FGD_handler.get_BCscore()
194 | print('Beat consistency score=', BCscore)
195 |
196 |
197 |
198 |
199 |
200 | def main():
201 | parser = parse_args()
202 | args = parser.parse_args()
203 | device = torch.device(args.gpu)
204 | torch.cuda.set_device(device)
205 |
206 | config = load_JsonConfig(args.config_file)
207 |
208 | os.environ['smplx_npz_path'] = config.smplx_npz_path
209 | os.environ['extra_joint_path'] = config.extra_joint_path
210 | os.environ['j14_regressor_path'] = config.j14_regressor_path
211 |
212 | print('init dataloader...')
213 | test_set, test_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)
214 | print('init model...')
215 | model_name = args.body_model_name
216 | # model_path = get_path(model_name, model_type)
217 | model_path = args.body_model_path
218 | generator = init_model(model_name, model_path, args, config)
219 |
220 | ae = init_model('s2g_body_ae', './experiments/feature_extractor.pth', args,
221 | config)
222 | FGD_handler = EmbeddingSpaceEvaluator(ae, None, 'cuda')
223 |
224 |     print('init smplx model...')
225 | dtype = torch.float64
226 | smplx_path = './visualise/'
227 | model_params = dict(model_path=smplx_path,
228 | model_type='smplx',
229 | create_global_orient=True,
230 | create_body_pose=True,
231 | create_betas=True,
232 | num_betas=300,
233 | create_left_hand_pose=True,
234 | create_right_hand_pose=True,
235 | use_pca=False,
236 | flat_hand_mean=False,
237 | create_expression=True,
238 | num_expression_coeffs=100,
239 | num_pca_comps=12,
240 | create_jaw_pose=True,
241 | create_leye_pose=True,
242 | create_reye_pose=True,
243 | create_transl=False,
244 | dtype=dtype, )
245 |
246 | smplx_model = smpl.create(**model_params).to('cuda')
247 |
248 | test(test_loader, generator, FGD_handler, smplx_model, config)
249 |
250 |
251 | if __name__ == '__main__':
252 | main()
253 |
--------------------------------------------------------------------------------
/scripts/test_face.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
5 | sys.path.append(os.getcwd())
6 |
7 | from tqdm import tqdm
8 | from transformers import Wav2Vec2Processor
9 |
10 | from evaluation.metrics import LVD
11 |
12 | import numpy as np
13 | import smplx as smpl
14 |
15 | from nets import *
16 | from trainer.options import parse_args
17 | from data_utils import torch_data
18 | from trainer.config import load_JsonConfig
19 | from data_utils.get_j import get_joints
20 |
21 | import torch
22 | from torch.utils import data
23 |
24 |
25 | def init_model(model_name, model_path, args, config):
26 | if model_name == 's2g_face':
27 | generator = s2g_face(
28 | args,
29 | config,
30 | )
31 | elif model_name == 's2g_body_vq':
32 | generator = s2g_body_vq(
33 | args,
34 | config,
35 | )
36 | elif model_name == 's2g_body_pixel':
37 | generator = s2g_body_pixel(
38 | args,
39 | config,
40 | )
41 | else:
42 | raise NotImplementedError
43 |
44 | model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
45 | if model_name == 'smplx_S2G':
46 | generator.generator.load_state_dict(model_ckpt['generator']['generator'])
47 | elif 'generator' in list(model_ckpt.keys()):
48 | generator.load_state_dict(model_ckpt['generator'])
49 | else:
50 | model_ckpt = {'generator': model_ckpt}
51 | generator.load_state_dict(model_ckpt)
52 |
53 | return generator
54 |
55 |
56 | def init_dataloader(data_root, speakers, args, config):
57 | data_base = torch_data(
58 | data_root=data_root,
59 | speakers=speakers,
60 | split='test',
61 | limbscaling=False,
62 | normalization=config.Data.pose.normalization,
63 | norm_method=config.Data.pose.norm_method,
64 | split_trans_zero=False,
65 | num_pre_frames=config.Data.pose.pre_pose_length,
66 | num_generate_length=config.Data.pose.generate_length,
67 | num_frames=30,
68 | aud_feat_win_size=config.Data.aud.aud_feat_win_size,
69 | aud_feat_dim=config.Data.aud.aud_feat_dim,
70 | feat_method=config.Data.aud.feat_method,
71 | smplx=True,
72 | audio_sr=22000,
73 | convert_to_6d=config.Data.pose.convert_to_6d,
74 | expression=config.Data.pose.expression,
75 | config=config
76 | )
77 |
78 | if config.Data.pose.normalization:
79 | norm_stats_fn = os.path.join(os.path.dirname(args.model_path), "norm_stats.npy")
80 | norm_stats = np.load(norm_stats_fn, allow_pickle=True)
81 | data_base.data_mean = norm_stats[0]
82 | data_base.data_std = norm_stats[1]
83 | else:
84 | norm_stats = None
85 |
86 | data_base.get_dataset()
87 | test_set = data_base.all_dataset
88 | test_loader = data.DataLoader(test_set, batch_size=1, shuffle=False)
89 |
90 | return test_set, test_loader, norm_stats
91 |
92 |
93 | def face_loss(gt, gt_param, pr, pr_param):
94 | loss_dict = {}
95 |
96 | jaw_xyz = gt[:, 22:25, :] - pr[:, 22:25, :]
97 | jaw_dist = jaw_xyz.norm(p=2, dim=-1)
98 | jaw_dist = jaw_dist.sum(dim=-1).mean()
99 | loss_dict['jaw_l1'] = jaw_dist
100 |
101 | landmark_xyz = gt[:, 74:] - pr[:, 74:]
102 | landmark_dist = landmark_xyz.norm(p=2, dim=-1)
103 | landmark_dist = landmark_dist.sum(dim=-1).mean()
104 | loss_dict['landmark_l1'] = landmark_dist
105 |
106 | face_gt = torch.cat([gt[:, 22:25], gt[:, 74:]], dim=1)
107 | face_pr = torch.cat([pr[:, 22:25], pr[:, 74:]], dim=1)
108 |
109 | loss_dict['LVD'] = LVD(face_gt, face_pr, symmetrical=False, weight=False)
110 |
111 | return loss_dict
112 |
113 |
114 | def test(test_loader, generator, smplx_model, args, config):
115 | print('start testing')
116 |
117 | am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
118 | am_sr = 16000
119 |
120 | loss_dict = {}
121 | with torch.no_grad():
122 | i = 0
123 | for bat in tqdm(test_loader, desc="Testing......"):
124 | i = i + 1
125 | aud, poses, exp = bat['aud_feat'].to('cuda').to(torch.float32), bat['poses'].to('cuda').to(torch.float32), \
126 | bat['expression'].to('cuda').to(torch.float32)
127 | id = bat['speaker'].to('cuda') - 20
128 | betas = bat['betas'][0].to('cuda').to(torch.float64)
129 | poses = torch.cat([poses, exp], dim=-2).transpose(-1, -2).squeeze()
130 | # poses = to3d(poses, config)
131 |
132 | cur_wav_file = bat['aud_file'][0]
133 | pred_face = generator.infer_on_audio(cur_wav_file,
134 | id=id,
135 | frame=poses.shape[0],
136 | am=am,
137 | am_sr=am_sr
138 | )
139 |
140 | pred_face = torch.tensor(pred_face).to('cuda').squeeze()
141 | if pred_face.shape[1] > 103:
142 | pred_face = pred_face[:, :103]
143 | zero_poses = torch.zeros([pred_face.shape[0], 162], device='cuda')
144 |
145 | full_param = torch.cat([pred_face[:, :3], zero_poses, pred_face[:, 3:]], dim=-1)
146 |
147 | poses[:, 3:165] = full_param[:, 3:165]
148 | gt_joints = get_joints(smplx_model, betas, poses)
149 | pred_joints = get_joints(smplx_model, betas, full_param)
150 | bat_loss_dict = face_loss(gt_joints, poses, pred_joints, full_param)
151 |
152 |             if loss_dict:  # non-empty
153 | for key in list(bat_loss_dict.keys()):
154 | loss_dict[key] += bat_loss_dict[key]
155 | else:
156 | for key in list(bat_loss_dict.keys()):
157 | loss_dict[key] = bat_loss_dict[key]
158 | for key in loss_dict.keys():
159 | loss_dict[key] = loss_dict[key] / i
160 | print(key + '=' + str(loss_dict[key].item()))
161 |
162 |
163 | def main():
164 | parser = parse_args()
165 | args = parser.parse_args()
166 | device = torch.device(args.gpu)
167 | torch.cuda.set_device(device)
168 |
169 | config = load_JsonConfig(args.config_file)
170 |
171 | os.environ['smplx_npz_path'] = config.smplx_npz_path
172 | os.environ['extra_joint_path'] = config.extra_joint_path
173 | os.environ['j14_regressor_path'] = config.j14_regressor_path
174 |
175 | print('init dataloader...')
176 | test_set, test_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)
177 | print('init model...')
178 | face_model_name = args.face_model_name
179 | face_model_path = args.face_model_path
180 | generator_face = init_model(face_model_name, face_model_path, args, config)
181 |
182 |     print('init smplx model...')
183 | dtype = torch.float64
184 | smplx_path = './visualise/'
185 | model_params = dict(model_path=smplx_path,
186 | model_type='smplx',
187 | create_global_orient=True,
188 | create_body_pose=True,
189 | create_betas=True,
190 | num_betas=300,
191 | create_left_hand_pose=True,
192 | create_right_hand_pose=True,
193 | use_pca=False,
194 | flat_hand_mean=False,
195 | create_expression=True,
196 | num_expression_coeffs=100,
197 | num_pca_comps=12,
198 | create_jaw_pose=True,
199 | create_leye_pose=True,
200 | create_reye_pose=True,
201 | create_transl=False,
202 | dtype=dtype, )
203 | smplx_model = smpl.create(**model_params).to('cuda')
204 |
205 | test(test_loader, generator_face, smplx_model, args, config)
206 |
207 |
208 | if __name__ == '__main__':
209 | main()
210 |
--------------------------------------------------------------------------------
/scripts/test_vq.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
5 | sys.path.append(os.getcwd())
6 |
7 | from tqdm import tqdm
8 | from transformers import Wav2Vec2Processor
9 |
10 | from evaluation.metrics import LVD
11 |
12 | import numpy as np
13 | import smplx as smpl
14 |
15 | from data_utils.lower_body import part2full, poses2pred, c_index_3d
16 | from nets import *
17 | from nets.utils import get_path, get_dpath
18 | from trainer.options import parse_args
19 | from data_utils import torch_data
20 | from trainer.config import load_JsonConfig
21 |
22 | import torch
23 | from torch.utils import data
24 | from data_utils.get_j import to3d, get_joints
25 | from scripts.test_body import init_model, init_dataloader
26 |
27 |
28 | def test(test_loader, generator, config):
29 | print('start testing')
30 |
31 | loss_dict = {}
32 | B = 1
33 | with torch.no_grad():
34 | count = 0
35 | for bat in tqdm(test_loader, desc="Testing......"):
36 | count = count + 1
37 | aud, poses, exp = bat['aud_feat'].to('cuda').to(torch.float32), bat['poses'].to('cuda').to(torch.float32), \
38 | bat['expression'].to('cuda').to(torch.float32)
39 | id = bat['speaker'].to('cuda') - 20
40 | betas = bat['betas'][0].to('cuda').to(torch.float64)
41 | poses = torch.cat([poses, exp], dim=-2).transpose(-1, -2).squeeze()
42 | poses = to3d(poses, config).unsqueeze(dim=0).transpose(1, 2)
43 | # poses = poses[:, c_index_3d, :]
44 |
45 | cur_wav_file = bat['aud_file'][0]
46 |
47 | pred = generator.infer_on_audio(cur_wav_file,
48 | initial_pose=poses,
49 | id=id,
50 | fps=30,
51 | B=B
52 | )
53 | pred = torch.tensor(pred, device='cuda')
54 | bat_loss_dict = {'capacity': (poses[:, c_index_3d, :pred.shape[0]].transpose(1,2) - pred).abs().sum(-1).mean()}
55 |
56 |             if loss_dict:  # non-empty
57 | for key in list(bat_loss_dict.keys()):
58 | loss_dict[key] += bat_loss_dict[key]
59 | else:
60 | for key in list(bat_loss_dict.keys()):
61 | loss_dict[key] = bat_loss_dict[key]
62 | for key in loss_dict.keys():
63 | loss_dict[key] = loss_dict[key] / count
64 | print(key + '=' + str(loss_dict[key].item()))
65 |
66 |
67 | def main():
68 | parser = parse_args()
69 | args = parser.parse_args()
70 | device = torch.device(args.gpu)
71 | torch.cuda.set_device(device)
72 |
73 | config = load_JsonConfig(args.config_file)
74 |
75 | os.environ['smplx_npz_path'] = config.smplx_npz_path
76 | os.environ['extra_joint_path'] = config.extra_joint_path
77 | os.environ['j14_regressor_path'] = config.j14_regressor_path
78 |
79 | print('init dataloader...')
80 | test_set, test_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)
81 | print('init model...')
82 | model_name = 's2g_body_vq'
83 | model_type = 'n_com_8192'
84 | model_path = get_path(model_name, model_type)
85 | generator = init_model(model_name, model_path, args, config)
86 |
87 | test(test_loader, generator, config)
88 |
89 |
90 | if __name__ == '__main__':
91 | main()
92 |
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | # os.chdir('/home/jovyan/Co-Speech-Motion-Generation/src')
4 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
5 | sys.path.append(os.getcwd())
6 |
7 | from trainer import Trainer
8 |
9 | if __name__ == '__main__':
10 | trainer = Trainer()
11 | trainer.train()
--------------------------------------------------------------------------------
/test_body.sh:
--------------------------------------------------------------------------------
1 | python -W ignore scripts/test_body.py \
2 | --save_dir experiments \
3 | --exp_name smplx_S2G \
4 | --speakers oliver seth conan chemistry \
5 | --config_file ./config/body_pixel.json \
6 | --body_model_name s2g_body_pixel \
7 | --body_model_path ./experiments/2022-11-02-smplx_S2G-body-pixel-3d/ckpt-99.pth \
8 | --infer
--------------------------------------------------------------------------------
/test_face.sh:
--------------------------------------------------------------------------------
1 | python -W ignore scripts/test_face.py \
2 | --save_dir experiments \
3 | --exp_name smplx_S2G \
4 | --speakers oliver seth conan chemistry \
5 | --config_file ./config/face.json \
6 | --face_model_name s2g_face \
7 | --face_model_path ./experiments/2022-10-15-smplx_S2G-face-3d/ckpt-99.pth \
8 | --infer
--------------------------------------------------------------------------------
/train_body_pixel.sh:
--------------------------------------------------------------------------------
1 | python -W ignore scripts/train.py \
2 | --save_dir experiments \
3 | --exp_name smplx_S2G \
4 | --speakers oliver seth conan chemistry \
5 | --config_file ./config/body_pixel.json
--------------------------------------------------------------------------------
/train_body_vq.sh:
--------------------------------------------------------------------------------
1 | python -W ignore scripts/train.py \
2 | --save_dir experiments \
3 | --exp_name smplx_S2G \
4 | --speakers oliver seth conan chemistry \
5 | --config_file ./config/body_vq.json
--------------------------------------------------------------------------------
/train_face.sh:
--------------------------------------------------------------------------------
1 | python -W ignore scripts/train.py \
2 | --save_dir experiments \
3 | --exp_name smplx_S2G \
4 | --speakers oliver seth conan chemistry \
5 | --config_file ./config/face.json
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .Trainer import Trainer
--------------------------------------------------------------------------------
/trainer/__pycache__/Trainer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/trainer/__pycache__/Trainer.cpython-37.pyc
--------------------------------------------------------------------------------
/trainer/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/trainer/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/trainer/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/trainer/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/trainer/__pycache__/options.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/trainer/__pycache__/options.cpython-37.pyc
--------------------------------------------------------------------------------
/trainer/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | load config from json file
3 | '''
4 | import json
5 | import os
6 |
7 | import configparser
8 |
9 |
10 | class Object():
11 | def __init__(self, config:dict) -> None:
12 | for key in list(config.keys()):
13 | if isinstance(config[key], dict):
14 | setattr(self, key, Object(config[key]))
15 | else:
16 | setattr(self, key, config[key])
17 |
18 | def load_JsonConfig(json_file):
19 | with open(json_file, 'r') as f:
20 | config = json.load(f)
21 |
22 | return Object(config)
23 |
24 |
25 | if __name__ == '__main__':
26 | config = load_JsonConfig('config/style_gestures.json')
27 | print(dir(config))
--------------------------------------------------------------------------------
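The `Object` wrapper in trainer/config.py recursively converts nested dictionaries into attribute access, which is why the test and training scripts can read values such as `config.Data.pose.normalization` directly. A minimal usage sketch, assuming a small inline dictionary as a stand-in for one of the shipped JSON configs (e.g. config/body_pixel.json):

    from trainer.config import Object, load_JsonConfig

    # Hypothetical stand-in for a real config file.
    raw = {'Data': {'pose': {'normalization': False, 'convert_to_6d': False}}}
    cfg = Object(raw)
    print(cfg.Data.pose.normalization)   # False, read as an attribute instead of a dict key

    # load_JsonConfig does the same after json.load:
    # cfg = load_JsonConfig('./config/body_pixel.json')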
/trainer/options.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | def parse_args():
4 | parser = ArgumentParser()
5 | parser.add_argument('--gpu', default=0, type=int)
6 | parser.add_argument('--save_dir', default='experiments', type=str)
7 | parser.add_argument('--exp_name', default='smplx_S2G', type=str)
8 | parser.add_argument('--speakers', nargs='+')
9 | parser.add_argument('--seed', default=1, type=int)
10 | parser.add_argument('--model_name', type=str)
11 |
12 | #for Tmpt and S2G
13 | parser.add_argument('--use_template', action='store_true')
14 | parser.add_argument('--template_length', default=0, type=int)
15 |
16 | #for training from a ckpt
17 | parser.add_argument('--resume', action='store_true')
18 | parser.add_argument('--pretrained_pth', default=None, type=str)
19 | parser.add_argument('--style_layer_norm', action='store_true')
20 |
21 | #required
22 | parser.add_argument('--config_file', default='./config/style_gestures.json', type=str)
23 |
24 | # for visualization and test
25 | parser.add_argument('--audio_file', default=None, type=str)
26 | parser.add_argument('--id', default=0, type=int, help='0=oliver, 1=chemistry, 2=seth, 3=conan')
27 | parser.add_argument('--only_face', action='store_true')
28 | parser.add_argument('--stand', action='store_true')
29 | parser.add_argument('--whole_body', action='store_true')
30 | parser.add_argument('--num_sample', default=1, type=int)
31 | parser.add_argument('--face_model_name', default='s2g_face', type=str)
32 | parser.add_argument('--face_model_path', default='./experiments/2022-10-15-smplx_S2G-face-3d/ckpt-99.pth', type=str)
33 | parser.add_argument('--body_model_name', default='s2g_body_pixel', type=str)
34 | parser.add_argument('--body_model_path', default='./experiments/2022-11-02-smplx_S2G-body-pixel-3d/ckpt-99.pth', type=str)
35 | parser.add_argument('--infer', action='store_true')
36 |
37 | return parser
--------------------------------------------------------------------------------
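Note that parse_args() returns the ArgumentParser itself rather than parsed arguments, so each entry point calls parser.parse_args() on top of it, as scripts/test_body.py and scripts/test_face.py do. A short sketch of driving it programmatically; the argument values below are only illustrative:

    from trainer.options import parse_args

    parser = parse_args()
    args = parser.parse_args([
        '--config_file', './config/body_pixel.json',
        '--speakers', 'oliver', 'seth', 'conan', 'chemistry',
        '--body_model_path', './experiments/2022-11-02-smplx_S2G-body-pixel-3d/ckpt-99.pth',
        '--infer',
    ])
    print(args.speakers)          # ['oliver', 'seth', 'conan', 'chemistry']
    print(args.body_model_name)   # 's2g_body_pixel' (the default)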
/trainer/training_config.cfg:
--------------------------------------------------------------------------------
1 | [Input Output]
2 | checkpoint_dir = ./training
3 | expression_basis_fname = ./training_data/init_expression_basis.npy
4 | template_fname = ./template/FLAME_sample.ply
5 | deepspeech_graph_fname = ./ds_graph/output_graph.pb
6 | face_or_body = body
7 | verts_mmaps_path = ./training_data/data_verts.npy
8 | raw_audio_path = ./training_data/raw_audio_fixed.pkl
9 | processed_audio_path = ./training_data/processed_audio_deepspeech.pkl
10 | templates_path = ./training_data/templates.pkl
11 | data2array_verts_path = ./training_data/subj_seq_to_idx.pkl
12 |
13 | [Audio Parameters]
14 | audio_feature_type = deepspeech
15 | num_audio_features = 29
16 | audio_window_size = 16
17 | audio_window_stride = 1
18 | condition_speech_features = True
19 | speech_encoder_size_factor = 1.0
20 |
21 | [Model Parameters]
22 | num_vertices = 10475
23 | expression_dim = 50
24 | init_expression = False
25 | num_consecutive_frames = 30
26 | absolute_reconstruction_loss = False
27 | velocity_weight = 10.0
28 | acceleration_weight = 0.0
29 | verts_regularizer_weight = 0.0
30 |
31 | [Data Setup]
32 | subject_for_training = speeker_oliver
33 | sequence_for_training = 0-00'00'05-00'00'10 1-00'00'32-00'00'37 2-00'01'05-00'01'10
34 | subject_for_validation = speeker_oliver
35 | sequence_for_validation = 2-00'01'05-00'01'10
36 | subject_for_testing = speeker_oliver
37 | sequence_for_testing = 2-00'01'05-00'01'10
38 |
39 | [Learning Parameters]
40 | batch_size = 64
41 | learning_rate = 1e-4
42 | decay_rate = 1.0
43 | epoch_num = 1000
44 | adam_beta1_value = 0.9
45 |
46 | [Visualization Parameters]
47 | num_render_sequences = 3
48 |
49 |
--------------------------------------------------------------------------------
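trainer/training_config.cfg is an INI-style file (trainer/config.py imports configparser alongside its JSON loader). A minimal sketch of reading it with Python's standard configparser, using section and key names from the file; the snippet itself is illustrative and not part of the repo:

    import configparser

    cfg = configparser.ConfigParser()
    cfg.read('./trainer/training_config.cfg')

    batch_size = cfg.getint('Learning Parameters', 'batch_size')          # 64
    lr = cfg.getfloat('Learning Parameters', 'learning_rate')             # 1e-4
    frames = cfg.getint('Model Parameters', 'num_consecutive_frames')     # 30
    print(batch_size, lr, frames)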
/visualise.sh:
--------------------------------------------------------------------------------
1 | python -W ignore scripts/diversity.py \
2 | --save_dir experiments \
3 | --exp_name smplx_S2G \
4 | --speakers oliver seth conan chemistry \
5 | --config_file ./config/body_pixel.json \
6 | --face_model_path ./experiments/2022-10-15-smplx_S2G-face-3d/ckpt-99.pth \
7 | --body_model_path ./experiments/2022-11-02-smplx_S2G-body-pixel-3d/ckpt-99.pth \
8 | --infer
--------------------------------------------------------------------------------
/visualise/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/visualise/__init__.py
--------------------------------------------------------------------------------
/visualise/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/visualise/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/visualise/__pycache__/rendering.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/visualise/__pycache__/rendering.cpython-37.pyc
--------------------------------------------------------------------------------
/visualise/smplx/smplx_extra_joints.yaml:
--------------------------------------------------------------------------------
1 | head_top:
2 | bc:
3 | - 0.8277337276382795
4 | - 0.1422200962169292
5 | - 0.030046176144791284
6 | face: 2581
7 | left_big_toe:
8 | bc:
9 | - 0.0
10 | - 0.0
11 | - 1.0
12 | face: 4407
13 | left_ear:
14 | bc:
15 | - 0.0
16 | - 0.0
17 | - 1.0
18 | face: 1946
19 | left_eye:
20 | bc:
21 | - 0.0
22 | - 1.0
23 | - 0.0
24 | face: 9470
25 | left_heel:
26 | bc:
27 | - 1.0
28 | - 0.0
29 | - 0.0
30 | face: 4621
31 | left_index:
32 | bc:
33 | - 0.0
34 | - 0.0
35 | - 1.0
36 | face: 3720
37 | left_middle:
38 | bc:
39 | - 0.0
40 | - 0.0
41 | - 1.0
42 | face: 3469
43 | left_pinky:
44 | bc:
45 | - 1.0
46 | - 0.0
47 | - 0.0
48 | face: 3575
49 | left_ring:
50 | bc:
51 | - 0.0
52 | - 0.0
53 | - 1.0
54 | face: 3542
55 | left_small_toe:
56 | bc:
57 | - 1.0
58 | - 0.0
59 | - 0.0
60 | face: 4329
61 | left_thumb:
62 | bc:
63 | - 0.0
64 | - 1.0
65 | - 0.0
66 | face: 3630
67 | nose:
68 | bc:
69 | - 0.0
70 | - 1.0
71 | - 0.0
72 | face: 9041
73 | right_big_toe:
74 | bc:
75 | - 1.0
76 | - 0.0
77 | - 0.0
78 | face: 8094
79 | right_ear:
80 | bc:
81 | - 1.0
82 | - 0.0
83 | - 0.0
84 | face: 351
85 | right_eye:
86 | bc:
87 | - 1.0
88 | - 0.0
89 | - 0.0
90 | face: 10093
91 | right_heel:
92 | bc:
93 | - 1.0
94 | - 0.0
95 | - 0.0
96 | face: 8247
97 | right_index:
98 | bc:
99 | - 1.0
100 | - 0.0
101 | - 0.0
102 | face: 6919
103 | right_middle:
104 | bc:
105 | - 0.0
106 | - 0.0
107 | - 1.0
108 | face: 7050
109 | right_pinky:
110 | bc:
111 | - 1.0
112 | - 0.0
113 | - 0.0
114 | face: 7284
115 | right_ring:
116 | bc:
117 | - 0.0
118 | - 0.0
119 | - 1.0
120 | face: 7168
121 | right_small_toe:
122 | bc:
123 | - 0.0
124 | - 0.0
125 | - 1.0
126 | face: 8096
127 | right_thumb:
128 | bc:
129 | - 0.0
130 | - 0.0
131 | - 1.0
132 | face: 7370
133 |
--------------------------------------------------------------------------------
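Each entry in smplx_extra_joints.yaml pairs a mesh face index with barycentric coordinates (`bc`), the usual recipe for regressing extra landmarks from SMPL-X vertices. A hedged sketch of that lookup; the helper name and the way vertices/faces are obtained are assumptions, not code from this repo:

    import numpy as np
    import yaml

    def extra_joint_positions(vertices, faces, yaml_path):
        # vertices: (V, 3) posed SMPL-X vertices; faces: (F, 3) triangle indices
        with open(yaml_path) as f:
            spec = yaml.safe_load(f)
        joints = {}
        for name, entry in spec.items():
            tri = faces[entry['face']]       # three vertex ids of the selected triangle
            bc = np.asarray(entry['bc'])     # barycentric weights, summing to 1
            joints[name] = (bc[:, None] * vertices[tri]).sum(axis=0)
        return joints

    # joints = extra_joint_positions(verts, faces, './visualise/smplx/smplx_extra_joints.yaml')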
/visualise/teaser_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/visualise/teaser_01.png
--------------------------------------------------------------------------------
/voca/__pycache__/rendering.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/voca/__pycache__/rendering.cpython-37.pyc
--------------------------------------------------------------------------------
/voca/rendering.py:
--------------------------------------------------------------------------------
1 | '''
2 | Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights on this
3 | computer program.
4 |
5 | You can only use this computer program if you have closed a license agreement with MPG or you get the right to use
6 | the computer program from someone who is authorized to grant you that right.
7 |
8 | Any use of the computer program without a valid license is prohibited and liable to prosecution.
9 |
10 | Copyright 2019 Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG). acting on behalf of its
11 | Max Planck Institute for Intelligent Systems and the Max Planck Institute for Biological Cybernetics.
12 | All rights reserved.
13 |
14 | More information about VOCA is available at http://voca.is.tue.mpg.de.
15 | For comments or questions, please email us at voca@tue.mpg.de
16 | '''
17 |
18 | from __future__ import division
19 | import os
20 | # os.environ['PYOPENGL_PLATFORM'] = 'osmesa' # Uncomment this line while running remotely
21 | import cv2
22 | import pyrender
23 | import trimesh
24 | import tempfile
25 | import numpy as np
26 | import matplotlib as mpl
27 | import matplotlib.cm as cm
28 |
29 |
30 | def get_unit_factor(unit):
31 | if unit == 'mm':
32 | return 1000.0
33 | elif unit == 'cm':
34 | return 100.0
35 | elif unit == 'm':
36 | return 1.0
37 | else:
38 | raise ValueError('Unit not supported')
39 |
40 |
41 | def render_mesh_helper(mesh, t_center, rot=np.zeros(3), tex_img=None, v_colors=None,
42 | errors=None, error_unit='m', min_dist_in_mm=0.0, max_dist_in_mm=3.0, z_offset=1.0, xmag=0.5,
43 | y=0.7, z=1, camera='o', r=None):
44 | camera_params = {'c': np.array([0, 0]),
45 | 'k': np.array([-0.19816071, 0.92822711, 0, 0, 0]),
46 | 'f': np.array([5000, 5000])}
47 |
48 | frustum = {'near': 0.01, 'far': 3.0, 'height': 800, 'width': 800}
49 |
50 | v, f = mesh
51 | v = cv2.Rodrigues(rot)[0].dot((v - t_center).T).T + t_center
52 |
53 | texture_rendering = tex_img is not None and hasattr(mesh, 'vt') and hasattr(mesh, 'ft')
54 | if texture_rendering:
55 | intensity = 0.5
56 | tex = pyrender.Texture(source=tex_img, source_channels='RGB')
57 | material = pyrender.material.MetallicRoughnessMaterial(baseColorTexture=tex)
58 |
59 | # Workaround as pyrender requires number of vertices and uv coordinates to be the same
60 | temp_filename = '%s.obj' % next(tempfile._get_candidate_names())
61 | mesh.write_obj(temp_filename)
62 | tri_mesh = trimesh.load(temp_filename, process=False)
63 | try:
64 | os.remove(temp_filename)
65 | except:
66 | print('Failed deleting temporary file - %s' % temp_filename)
67 | render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=material)
68 | elif errors is not None:
69 | intensity = 0.5
70 | unit_factor = get_unit_factor('mm') / get_unit_factor(error_unit)
71 | errors = unit_factor * errors
72 |
73 | norm = mpl.colors.Normalize(vmin=min_dist_in_mm, vmax=max_dist_in_mm)
74 | cmap = cm.get_cmap(name='jet')
75 | colormapper = cm.ScalarMappable(norm=norm, cmap=cmap)
76 | rgba_per_v = colormapper.to_rgba(errors)
77 | rgb_per_v = rgba_per_v[:, 0:3]
78 | elif v_colors is not None:
79 | intensity = 0.5
80 | rgb_per_v = v_colors
81 | else:
82 | intensity = 6.
83 | rgb_per_v = None
84 |
85 | color = np.array([0.3, 0.5, 0.55])
86 |
87 | if not texture_rendering:
88 | tri_mesh = trimesh.Trimesh(vertices=v, faces=f, vertex_colors=rgb_per_v)
89 | render_mesh = pyrender.Mesh.from_trimesh(tri_mesh,
90 | smooth=True,
91 | material=pyrender.MetallicRoughnessMaterial(
92 | metallicFactor=0.05,
93 | roughnessFactor=0.7,
94 | alphaMode='OPAQUE',
95 | baseColorFactor=(color[0], color[1], color[2], 1.0)
96 | ))
97 |
98 | scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[255, 255, 255])
99 |
100 | if camera == 'o':
101 | ymag = xmag * z_offset
102 | camera = pyrender.OrthographicCamera(xmag=xmag, ymag=ymag)
103 | elif camera == 'i':
104 | camera = pyrender.IntrinsicsCamera(fx=camera_params['f'][0],
105 | fy=camera_params['f'][1],
106 | cx=camera_params['c'][0],
107 | cy=camera_params['c'][1],
108 | znear=frustum['near'],
109 | zfar=frustum['far'])
110 | elif camera == 'y':
111 | camera = pyrender.PerspectiveCamera(yfov=(np.pi / 2.0))
112 |
113 | scene.add(render_mesh, pose=np.eye(4))
114 |
115 | camera_pose = np.eye(4)
116 | camera_pose[:3, 3] = np.array([0, 0.7, 1.0 - z_offset])
117 | scene.add(camera, pose=[[1, 0, 0, 0],
118 | [0, 1, 0, y], # 0.25
119 | [0, 0, 1, z], # 0.2
120 | [0, 0, 0, 1]])
121 |
122 |
123 | angle = np.pi / 6.0
124 | # pos = camera_pose[:3,3]
125 | pos = np.array([0, 0.7, 2.0])
126 | if False:
127 | light_color = np.array([1., 1., 1.])
128 | light = pyrender.DirectionalLight(color=light_color, intensity=intensity)
129 |
130 | light_pose = np.eye(4)
131 | light_pose[:3, 3] = np.array([0, 0.7, 2.0])
132 | scene.add(light, pose=light_pose.copy())
133 | else:
134 | light = pyrender.PointLight(color=np.array([1.0, 1.0, 1.0]) * 0.2, intensity=2)
135 | light_pose = np.eye(4)
136 | light_pose[:3, 3] = [0, -1, 1]
137 | scene.add(light, pose=light_pose)
138 |
139 | light_pose[:3, 3] = [0, 1, 1]
140 | scene.add(light, pose=light_pose)
141 |
142 | light_pose[:3, 3] = [-1, 1, 2]
143 | scene.add(light, pose=light_pose)
144 |
145 | spot_l = pyrender.SpotLight(color=np.ones(3), intensity=15.0,
146 | innerConeAngle=np.pi / 3, outerConeAngle=np.pi / 2)
147 |
148 | light_pose[:3, 3] = [-1, 2, 2]
149 | scene.add(spot_l, pose=light_pose)
150 |
151 | light_pose[:3, 3] = [1, 2, 2]
152 | scene.add(spot_l, pose=light_pose)
153 |
154 | # light_pose[:3,3] = cv2.Rodrigues(np.array([angle, 0, 0]))[0].dot(pos)
155 | # scene.add(light, pose=light_pose.copy())
156 | #
157 | # light_pose[:3,3] = cv2.Rodrigues(np.array([-angle, 0, 0]))[0].dot(pos)
158 | # scene.add(light, pose=light_pose.copy())
159 | #
160 | # light_pose[:3,3] = cv2.Rodrigues(np.array([0, -angle, 0]))[0].dot(pos)
161 | # scene.add(light, pose=light_pose.copy())
162 | #
163 | # light_pose[:3,3] = cv2.Rodrigues(np.array([0, angle, 0]))[0].dot(pos)
164 | # scene.add(light, pose=light_pose.copy())
165 |
166 | # pyrender.Viewer(scene)
167 |
168 | flags = pyrender.RenderFlags.SKIP_CULL_FACES
169 | # try:
170 | # r = pyrender.OffscreenRenderer(viewport_width=frustum['width'], viewport_height=frustum['height'])
171 | color, _ = r.render(scene, flags=flags)
172 | # r.delete()
173 | # except:
174 | # print('pyrender: Failed rendering frame')
175 | # color = np.zeros((frustum['height'], frustum['width'], 3), dtype='uint8')
176 |
177 | return color[..., ::-1]
178 |
--------------------------------------------------------------------------------
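render_mesh_helper takes a (vertices, faces) tuple plus an offscreen renderer `r` (the pyrender.OffscreenRenderer construction inside the function is commented out, so the caller must supply one). A minimal calling sketch with an 800x800 viewport to match the frustum above; the placeholder sphere stands in for the SMPL-X mesh used elsewhere in the repo:

    import numpy as np
    import pyrender
    import trimesh

    from voca.rendering import render_mesh_helper

    # Placeholder geometry; in the repo the vertices come from the posed SMPL-X model.
    sphere = trimesh.creation.icosphere(radius=0.1)
    v, f = np.asarray(sphere.vertices), np.asarray(sphere.faces)

    r = pyrender.OffscreenRenderer(viewport_width=800, viewport_height=800)
    frame = render_mesh_helper((v, f), t_center=v.mean(axis=0), r=r)   # HxWx3 uint8, BGR order
    r.delete()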