├── .gitignore ├── README.md ├── __init__.py ├── config ├── LS3DCG.json ├── body_pixel.json ├── body_vq.json └── face.json ├── data_utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── consts.cpython-37.pyc │ ├── dataloader_torch.cpython-37.pyc │ ├── lower_body.cpython-37.pyc │ ├── mesh_dataset.cpython-37.pyc │ ├── rotation_conversion.cpython-37.pyc │ └── utils.cpython-37.pyc ├── apply_split.py ├── axis2matrix.py ├── consts.py ├── dataloader_torch.py ├── dataset_preprocess.py ├── get_j.py ├── hand_component.json ├── lower_body.py ├── mesh_dataset.py ├── rotation_conversion.py ├── split_more_than_2s.pkl ├── split_train_val_test.py ├── train_val_test.json └── utils.py ├── demo ├── 1st-page │ ├── 1st-page-upper.mp4 │ └── 1st-page-upper.npy ├── french │ ├── french.mp4 │ └── french.npy ├── rich │ ├── rich.mp4 │ └── rich.npy ├── song │ ├── cut.mp4 │ ├── song.mp4 │ └── song.npy └── style │ ├── chemistry.mp4 │ ├── chemistry.npy │ ├── conan.mp4 │ ├── conan.npy │ ├── diversity.mp4 │ ├── diversity.npy │ ├── face.mp4 │ ├── face.npy │ ├── oliver.mp4 │ ├── oliver.npy │ ├── seth.mp4 │ └── seth.npy ├── demo_audio ├── 1st-page.wav ├── french.wav ├── rich.wav ├── rich_short.wav ├── song.wav └── style.wav ├── evaluation ├── FGD.py ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── metrics.cpython-37.pyc ├── diversity_LVD.py ├── get_quality_samples.py ├── metrics.py ├── mode_transition.py ├── peak_velocity.py └── util.py ├── losses ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── losses.cpython-37.pyc └── losses.py ├── nets ├── LS3DCG.py ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── base.cpython-37.pyc │ ├── init_model.cpython-37.pyc │ ├── layers.cpython-37.pyc │ ├── smplx_body_pixel.cpython-37.pyc │ ├── smplx_body_vq.cpython-37.pyc │ ├── smplx_face.cpython-37.pyc │ └── utils.cpython-37.pyc ├── base.py ├── body_ae.py ├── init_model.py ├── layers.py ├── smplx_body_pixel.py ├── smplx_body_vq.py ├── smplx_face.py ├── spg │ ├── __pycache__ │ │ ├── gated_pixelcnn_v2.cpython-37.pyc │ │ ├── s2g_face.cpython-37.pyc │ │ ├── s2glayers.cpython-37.pyc │ │ ├── vqvae_1d.cpython-37.pyc │ │ ├── vqvae_modules.cpython-37.pyc │ │ └── wav2vec.cpython-37.pyc │ ├── gated_pixelcnn_v2.py │ ├── s2g_face.py │ ├── s2glayers.py │ ├── vqvae_1d.py │ ├── vqvae_modules.py │ └── wav2vec.py └── utils.py ├── requirements.txt ├── scripts ├── .idea │ ├── __init__.py │ ├── aws.xml │ ├── deployment.xml │ ├── get_prevar.py │ ├── inspectionProfiles │ │ ├── Project_Default.xml │ │ └── profiles_settings.xml │ ├── lower body │ ├── modules.xml │ ├── scripts.iml │ ├── test.png │ ├── testtext.py │ ├── vcs.xml │ └── workspace.xml ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── diversity.cpython-37.pyc ├── continuity.py ├── demo.py ├── diversity.py ├── test_body.py ├── test_face.py ├── test_vq.py └── train.py ├── test_body.sh ├── test_face.sh ├── train_body_pixel.sh ├── train_body_vq.sh ├── train_face.sh ├── trainer ├── Trainer.py ├── __init__.py ├── __pycache__ │ ├── Trainer.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── config.cpython-37.pyc │ └── options.cpython-37.pyc ├── config.py ├── options.py └── training_config.cfg ├── visualise.sh ├── visualise ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── rendering.cpython-37.pyc ├── rendering.py ├── smplx │ ├── SMPLX_to_J14.pkl │ └── smplx_extra_joints.yaml └── teaser_01.png └── voca ├── __pycache__ └── rendering.cpython-37.pyc └── rendering.py /.gitignore: 
-------------------------------------------------------------------------------- 1 | /experiments 2 | visualise/smplx/SMPLX_NEUTRAL.npz 3 | test_6d_mfcc.pkl 4 | test_3d_mfcc.pkl 5 | visualise/video 6 | .idea 7 | .git -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TalkSHOW: Generating Holistic 3D Human Motion from Speech [CVPR2023] 2 | 3 | The official PyTorch implementation of the **CVPR2023** paper [**"Generating Holistic 3D Human Motion from Speech"**](https://arxiv.org/abs/2212.04420). 4 | 5 | Please visit our [**webpage**](https://talkshow.is.tue.mpg.de/) for more details. 6 | 7 | ![teaser](visualise/teaser_01.png) 8 | 9 | ## Highlights 10 | 11 | We directly provide the input audio and our generated output for the demo data; you can find them in `/demo/` and `/demo_audio/`. So far, TalkSHOW generalizes well to English, French, and songs. We look forward to more demos. 12 | 13 | You can directly use the generated motion to animate your 3D character or your own digital avatar. We will provide more demos, so please stay tuned. We also look forward to your pull requests. 14 | 15 | ## Notes 16 | 17 | We use 100-dimensional parameters for the SMPL-X facial expression. If you need a different number of expression dimensions, you can convert them with this script: 18 | 19 | ``` 20 | https://github.com/yhw-yhw/SHOW/blob/main/cvt_exp_dim_tool.py 21 | ``` 22 | 23 | ## TODO 24 | 25 | - [x] [🤗Hugging Face Demo](https://huggingface.co/spaces/feifeifeiliu/TalkSHOW) 26 | - [ ] Animated 2D videos driven by the motion generated by TalkSHOW. 27 | 28 | 29 | ## Getting started 30 | 31 | The training code was tested on `Ubuntu 18.04.5 LTS`, the visualization code was tested on `Windows 10`, and the project requires: 32 | 33 | * Python 3.7 34 | * conda3 or miniconda3 35 | * CUDA-capable GPU (one is enough) 36 | 37 | 38 | 39 | ### 1. Setup environment 40 | 41 | Clone the repo: 42 | ```bash 43 | git clone https://github.com/yhw-yhw/TalkSHOW 44 | cd TalkSHOW 45 | ``` 46 | Create the conda environment: 47 | ```bash 48 | conda create --name talkshow python=3.7 49 | conda activate talkshow 50 | ``` 51 | Please install PyTorch (v1.10.1), then install the remaining dependencies: 52 | 53 | pip install -r requirements.txt 54 | 55 | Please also install [**MPI-Mesh**](https://github.com/MPI-IS/mesh). 56 | 57 | ### 2. Get data 58 | 59 | Please note that if you only want to generate demo videos, you can skip this step and directly download the pretrained models. 60 | 61 | Download [**SHOW_dataset_v1.0.zip**](https://download.is.tue.mpg.de/download.php?domain=talkshow&resume=1&sfile=SHOW_dataset_v1.0.zip) from the [**TalkSHOW download webpage**](https://talkshow.is.tue.mpg.de/download.php), 62 | then unzip it using ``for i in $(ls *.tar.gz);do tar xvf $i;done``. 63 | 64 | ~~Run ``python data_utils/dataset_preprocess.py`` to check and split the dataset. 65 | Modify ``data_root`` in ``config/*.json`` to the dataset path.~~ 66 | 67 | Modify ``data_root`` in ``data_utils/apply_split.py`` to the dataset path and run it to apply ``data_utils/split_more_than_2s.pkl`` to the dataset. 68 | 69 | We will update the benchmark soon. 70 | 71 | ### 3. Download the pretrained models (Optional) 72 | 73 | Download the [**pretrained models**](https://drive.google.com/file/d/1bC0ZTza8HOhLB46WOJ05sBywFvcotDZG/view?usp=sharing), 74 | unzip them, and place them in the TalkSHOW folder, i.e. ``path-to-TalkSHOW/experiments``. 75 | 76 | ### 4. Training 77 | Please note that loading the data for the first time can be quite slow. Once that first loading pass has completed, setting ``dataset_load_mode`` to ``pickle`` in ``config/[config_name].json`` makes subsequent loading much faster, for example:
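The excerpt below shows only the relevant key (here in ``config/body_pixel.json``; the other configs use the same key); all other keys stay as they are.

```json
"dataset_load_mode": "pickle",
```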
78 | 79 | # 1. Train the VQ-VAEs. 80 | bash train_body_vq.sh 81 | # 2. Train the PixelCNN. Please modify "Model:vq_path" in config/body_pixel.json to point to the trained VQ-VAE checkpoint. 82 | bash train_body_pixel.sh 83 | # 3. Train the face generator. 84 | bash train_face.sh 85 | 86 | ### 5. Testing 87 | 88 | Modify the arguments in ``test_face.sh`` and ``test_body.sh``, then run 89 | 90 | bash test_face.sh 91 | bash test_body.sh 92 | 93 | ### 6. Visualization 94 | 95 | If you are logged into the Linux machine over SSH, a NotImplementedError might occur. In this case, please refer to this [**issue**](https://github.com/MPI-IS/mesh/issues/66) to resolve the error. 96 | 97 | Download the [**smplx model**](https://drive.google.com/file/d/1Ly_hQNLQcZ89KG0Nj4jYZwccQiimSUVn/view?usp=share_link) (please register on the official [**SMPL-X webpage**](https://smpl-x.is.tue.mpg.de) before using it) 98 | and place it in ``path-to-TalkSHOW/visualise/smplx_model``. 99 | To visualise the test set and the generated results (in each video, left: generated result | right: ground truth), run the command below. 100 | The videos and generated motion data are saved in ``./visualise/video/body-pixel``: 101 | 102 | bash visualise.sh 103 | 104 | If you are logged in over SSH, an error about the OffscreenRenderer might also occur; the same [**issue**](https://github.com/MPI-IS/mesh/issues/66) covers it. 105 | 106 | To reproduce the demo videos, run 107 | ```bash 108 | # the whole body demo 109 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/1st-page.wav --id 0 --whole_body 110 | # the face demo 111 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 0 --only_face 112 | # the identity-specific demo 113 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 0 114 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 1 115 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 2 116 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 3 --stand 117 | # the diversity demo 118 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 0 --num_samples 12 119 | # the french demo 120 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/french.wav --id 0 121 | # the synthetic speech demo 122 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/rich.wav --id 0 123 | # the song demo 124 | python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/song.wav --id 0 125 | ``` The demo folders pair each rendered video with the corresponding generated motion stored as a ``.npy`` file (e.g., ``demo/style/oliver.npy``); see the sketch below for one way to load such a file.
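As a quick sanity check, here is a minimal sketch (not part of the repo) of one way to inspect a generated motion file. It assumes the array stores one 265-dimensional SMPL-X parameter vector per frame in the order used by ``data_utils/get_j.py``; if the shape you see differs, adapt the slicing accordingly.

```python
# Minimal sketch (assumption: one 265-D SMPL-X parameter vector per frame,
# ordered as in data_utils/get_j.py). Not part of the official pipeline.
import numpy as np

motion = np.load("demo/style/oliver.npy")  # any of the .npy files in /demo/
print(motion.shape)

if motion.ndim == 2 and motion.shape[1] == 265:
    params = {
        "jaw_pose":        motion[:, 0:3],
        "leye_pose":       motion[:, 3:6],
        "reye_pose":       motion[:, 6:9],
        "global_orient":   motion[:, 9:12],
        "body_pose":       motion[:, 12:75],
        "left_hand_pose":  motion[:, 75:120],
        "right_hand_pose": motion[:, 120:165],
        "expression":      motion[:, 165:265],
    }
    # These slices match get_joint() in data_utils/get_j.py and can be passed
    # to a smplx.SMPLX body model to recover joints/vertices for your own avatar.
```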
### 7. Baseline 127 | 128 | To train the reproduced "Learning Speech-driven 3D Conversational Gestures from Video" (Habibie et al.) baseline, you can run 129 | ```bash 130 | python -W ignore scripts/train.py --speakers oliver seth conan chemistry --config_file ./config/LS3DCG.json 131 | ``` 132 | 133 | For visualization with the pretrained model, download the above [pretrained models](#3-download-the-pretrained-models--optional-) and run 134 | ```bash 135 | python scripts/demo.py --config_file ./config/LS3DCG.json --infer --audio_file ./demo_audio/style.wav --body_model_name s2g_LS3DCG --body_model_path experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth --id 0 136 | ``` 137 | 138 | ## Citation 139 | If you find our work useful for your research, please consider citing: 140 | ``` 141 | @inproceedings{yi2022generating, 142 | title={Generating Holistic 3D Human Motion from Speech}, 143 | author={Yi, Hongwei and Liang, Hualin and Liu, Yifei and Cao, Qiong and Wen, Yandong and Bolkart, Timo and Tao, Dacheng and Black, Michael J}, 144 | booktitle={CVPR}, 145 | year={2023} 146 | } 147 | ``` 148 | 149 | ## Acknowledgements 150 | For functions or scripts that are based on external sources, we acknowledge the origin individually in each file. 151 | Here are some great resources we benefit from: 152 | - [Freeform](https://github.com/TheTempAccount/Co-Speech-Motion-Generation) for the training pipeline 153 | - [MPI-Mesh](https://github.com/MPI-IS/mesh), [Pyrender](https://github.com/mmatl/pyrender), [Smplx](https://github.com/vchoutas/smplx), [VOCA](https://github.com/TimoBolkart/voca) for rendering 154 | - [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base-960h) and [Faceformer](https://github.com/EvelynFan/FaceFormer) for the audio encoders 155 | 156 | ## Contact 157 | For questions, please contact talkshow@tue.mpg.de, hongwei.yi@tuebingen.mpg.de, fthualinliang@mail.scut.edu.cn, or ft_lyf@mail.scut.edu.cn 158 | 159 | For commercial licensing, please contact ps-licensing@tue.mpg.de 160 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/__init__.py -------------------------------------------------------------------------------- /config/LS3DCG.json: -------------------------------------------------------------------------------- 1 | { 2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts", 3 | "dataset_load_mode": "pickle", 4 | "store_file_path": "store.pkl", 5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz", 6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml", 7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl", 8 | "param": { 9 | "w_j": 1, 10 | "w_b": 1, 11 | "w_h": 1 12 | }, 13 | "Data": { 14 | "data_root": "../ExpressiveWholeBodyDatasetv1.0/", 15 | "pklname": "_3d_mfcc.pkl", 16 | "whole_video": false, 17 | "pose": { 18 | "normalization": false, 19 | "convert_to_6d": false, 20 | "norm_method": "all", 21 | "augmentation": false, 22 | "generate_length": 88, 23 | "pre_pose_length": 0, 24 | "pose_dim": 99, 25 | "expression": true 26 | }, 27 | "aud": { 28 | "feat_method": "mfcc", 29 | "aud_feat_dim": 64, 30 | "aud_feat_win_size": null, 31 | "context_info": false 32 | } 33 | }, 34 | "Model": { 35 | "model_type": "body", 36 | "model_name": "s2g_LS3DCG", 37 | "code_num": 2048, 38 | "AudioOpt": "Adam", 39 | "encoder_choice":
"mfcc", 40 | "gan": false 41 | }, 42 | "DataLoader": { 43 | "batch_size": 128, 44 | "num_workers": 0 45 | }, 46 | "Train": { 47 | "epochs": 100, 48 | "max_gradient_norm": 5, 49 | "learning_rate": { 50 | "generator_learning_rate": 1e-4, 51 | "discriminator_learning_rate": 1e-4 52 | }, 53 | "weights": { 54 | "keypoint_loss_weight": 1.0, 55 | "gan_loss_weight": 1.0 56 | } 57 | }, 58 | "Log": { 59 | "save_every": 50, 60 | "print_every": 200, 61 | "name": "LS3DCG" 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /config/body_pixel.json: -------------------------------------------------------------------------------- 1 | { 2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts", 3 | "dataset_load_mode": "json", 4 | "store_file_path": "store.pkl", 5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz", 6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml", 7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl", 8 | "param": { 9 | "w_j": 1, 10 | "w_b": 1, 11 | "w_h": 1 12 | }, 13 | "Data": { 14 | "data_root": "../ExpressiveWholeBodyDatasetv1.0/", 15 | "pklname": "_3d_mfcc.pkl", 16 | "whole_video": false, 17 | "pose": { 18 | "normalization": false, 19 | "convert_to_6d": false, 20 | "norm_method": "all", 21 | "augmentation": false, 22 | "generate_length": 88, 23 | "pre_pose_length": 0, 24 | "pose_dim": 99, 25 | "expression": true 26 | }, 27 | "aud": { 28 | "feat_method": "mfcc", 29 | "aud_feat_dim": 64, 30 | "aud_feat_win_size": null, 31 | "context_info": false 32 | } 33 | }, 34 | "Model": { 35 | "model_type": "body", 36 | "model_name": "s2g_body_pixel", 37 | "composition": true, 38 | "code_num": 2048, 39 | "bh_model": true, 40 | "AudioOpt": "Adam", 41 | "encoder_choice": "mfcc", 42 | "gan": false, 43 | "vq_path": "./experiments/2022-10-31-smplx_S2G-body-vq-3d/ckpt-99.pth" 44 | }, 45 | "DataLoader": { 46 | "batch_size": 128, 47 | "num_workers": 0 48 | }, 49 | "Train": { 50 | "epochs": 100, 51 | "max_gradient_norm": 5, 52 | "learning_rate": { 53 | "generator_learning_rate": 1e-4, 54 | "discriminator_learning_rate": 1e-4 55 | } 56 | }, 57 | "Log": { 58 | "save_every": 50, 59 | "print_every": 200, 60 | "name": "body-pixel2" 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /config/body_vq.json: -------------------------------------------------------------------------------- 1 | { 2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts", 3 | "dataset_load_mode": "json", 4 | "store_file_path": "store.pkl", 5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz", 6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml", 7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl", 8 | "param": { 9 | "w_j": 1, 10 | "w_b": 1, 11 | "w_h": 1 12 | }, 13 | "Data": { 14 | "data_root": "../ExpressiveWholeBodyDatasetv1.0/", 15 | "pklname": "_3d_mfcc.pkl", 16 | "whole_video": false, 17 | "pose": { 18 | "normalization": false, 19 | "convert_to_6d": false, 20 | "norm_method": "all", 21 | "augmentation": false, 22 | "generate_length": 88, 23 | "pre_pose_length": 0, 24 | "pose_dim": 99, 25 | "expression": true 26 | }, 27 | "aud": { 28 | "feat_method": "mfcc", 29 | "aud_feat_dim": 64, 30 | "aud_feat_win_size": null, 31 | "context_info": false 32 | } 33 | }, 34 | "Model": { 35 | "model_type": "body", 36 | "model_name": "s2g_body_vq", 37 | "composition": true, 38 | "code_num": 2048, 39 | 
"bh_model": true, 40 | "AudioOpt": "Adam", 41 | "encoder_choice": "mfcc", 42 | "gan": false 43 | }, 44 | "DataLoader": { 45 | "batch_size": 128, 46 | "num_workers": 0 47 | }, 48 | "Train": { 49 | "epochs": 100, 50 | "max_gradient_norm": 5, 51 | "learning_rate": { 52 | "generator_learning_rate": 1e-4, 53 | "discriminator_learning_rate": 1e-4 54 | } 55 | }, 56 | "Log": { 57 | "save_every": 50, 58 | "print_every": 200, 59 | "name": "body-vq" 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /config/face.json: -------------------------------------------------------------------------------- 1 | { 2 | "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts", 3 | "dataset_load_mode": "json", 4 | "store_file_path": "store.pkl", 5 | "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz", 6 | "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml", 7 | "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl", 8 | "param": { 9 | "w_j": 1, 10 | "w_b": 1, 11 | "w_h": 1 12 | }, 13 | "Data": { 14 | "data_root": "../ExpressiveWholeBodyDatasetv1.0/", 15 | "pklname": "_3d_wv2.pkl", 16 | "whole_video": true, 17 | "pose": { 18 | "normalization": false, 19 | "convert_to_6d": false, 20 | "norm_method": "all", 21 | "augmentation": false, 22 | "generate_length": 88, 23 | "pre_pose_length": 0, 24 | "pose_dim": 99, 25 | "expression": true 26 | }, 27 | "aud": { 28 | "feat_method": "mfcc", 29 | "aud_feat_dim": 64, 30 | "aud_feat_win_size": null, 31 | "context_info": false 32 | } 33 | }, 34 | "Model": { 35 | "model_type": "face", 36 | "model_name": "s2g_face", 37 | "AudioOpt": "SGD", 38 | "encoder_choice": "faceformer", 39 | "gan": false 40 | }, 41 | "DataLoader": { 42 | "batch_size": 1, 43 | "num_workers": 0 44 | }, 45 | "Train": { 46 | "epochs": 100, 47 | "max_gradient_norm": 5, 48 | "learning_rate": { 49 | "generator_learning_rate": 1e-4, 50 | "discriminator_learning_rate": 1e-4 51 | } 52 | }, 53 | "Log": { 54 | "save_every": 50, 55 | "print_every": 1000, 56 | "name": "face" 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # from .dataloader_csv import MultiVidData as csv_data 2 | from .dataloader_torch import MultiVidData as torch_data 3 | from .utils import get_melspec, get_mfcc, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta -------------------------------------------------------------------------------- /data_utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data_utils/__pycache__/consts.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/__pycache__/consts.cpython-37.pyc -------------------------------------------------------------------------------- /data_utils/__pycache__/dataloader_torch.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/__pycache__/dataloader_torch.cpython-37.pyc -------------------------------------------------------------------------------- /data_utils/__pycache__/lower_body.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/__pycache__/lower_body.cpython-37.pyc -------------------------------------------------------------------------------- /data_utils/__pycache__/mesh_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/__pycache__/mesh_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data_utils/__pycache__/rotation_conversion.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/__pycache__/rotation_conversion.cpython-37.pyc -------------------------------------------------------------------------------- /data_utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /data_utils/apply_split.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import pickle 4 | import shutil 5 | 6 | speakers = ['seth', 'oliver', 'conan', 'chemistry'] 7 | source_data_root = "../expressive_body-V0.7" 8 | data_root = "D:/Downloads/SHOW_dataset_v1.0/ExpressiveWholeBodyDatasetReleaseV1.0" 9 | 10 | f_read = open('split_more_than_2s.pkl', 'rb') 11 | f_save = open('none.pkl', 'wb') 12 | data_split = pickle.load(f_read) 13 | none_split = [] 14 | 15 | train = val = test = 0 16 | 17 | for speaker_name in speakers: 18 | speaker_root = os.path.join(data_root, speaker_name) 19 | 20 | videos = [v for v in data_split[speaker_name]] 21 | 22 | for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)): 23 | for split in data_split[speaker_name][vid]: 24 | for seq in data_split[speaker_name][vid][split]: 25 | 26 | seq = seq.replace('\\', '/') 27 | old_file_path = os.path.join(data_root, speaker_name, vid, seq.split('/')[-1]) 28 | old_file_path = old_file_path.replace('\\', '/') 29 | new_file_path = seq.replace(source_data_root.split('/')[-1], data_root.split('/')[-1]) 30 | try: 31 | shutil.move(old_file_path, new_file_path) 32 | if split == 'train': 33 | train = train + 1 34 | elif split == 'test': 35 | test = test + 1 36 | elif split == 'val': 37 | val = val + 1 38 | except FileNotFoundError: 39 | none_split.append(old_file_path) 40 | print(f"The file {old_file_path} does not exists.") 41 | except shutil.Error: 42 | none_split.append(old_file_path) 43 | print(f"The file {old_file_path} does not exists.") 44 | 45 | print(none_split.__len__()) 46 | pickle.dump(none_split, f_save) 47 | f_save.close() 48 | 49 | print(train, val, test) 50 | 51 | 52 | -------------------------------------------------------------------------------- /data_utils/axis2matrix.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import scipy.linalg as linalg 4 | 5 | 6 | def rotate_mat(axis, radian): 7 | 8 | a = np.cross(np.eye(3), axis / linalg.norm(axis) * radian) 9 | 10 | rot_matrix = linalg.expm(a) 11 | 12 | return rot_matrix 13 | 14 | def aaa2mat(axis, sin, cos): 15 | i = np.eye(3) 16 | nnt = np.dot(axis.T, axis) 17 | s = np.asarray([[0, -axis[0,2], axis[0,1]], 18 | [axis[0,2], 0, -axis[0,0]], 19 | [-axis[0,1], axis[0,0], 0]]) 20 | r = cos * i + (1-cos)*nnt +sin * s 21 | return r 22 | 23 | rand_axis = np.asarray([[1,0,0]]) 24 | #旋转角度 25 | r = math.pi/2 26 | #返回旋转矩阵 27 | rot_matrix = rotate_mat(rand_axis, r) 28 | r2 = aaa2mat(rand_axis, np.sin(r), np.cos(r)) 29 | print(rot_matrix) -------------------------------------------------------------------------------- /data_utils/dataset_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from tqdm import tqdm 4 | import shutil 5 | import torch 6 | import numpy as np 7 | import librosa 8 | import random 9 | 10 | speakers = ['seth', 'conan', 'oliver', 'chemistry'] 11 | data_root = "../ExpressiveWholeBodyDatasetv1.0/" 12 | split = 'train' 13 | 14 | 15 | 16 | def split_list(full_list,shuffle=False,ratio=0.2): 17 | n_total = len(full_list) 18 | offset_0 = int(n_total * ratio) 19 | offset_1 = int(n_total * ratio * 2) 20 | if n_total==0 or offset_1<1: 21 | return [],full_list 22 | if shuffle: 23 | random.shuffle(full_list) 24 | sublist_0 = full_list[:offset_0] 25 | sublist_1 = full_list[offset_0:offset_1] 26 | sublist_2 = full_list[offset_1:] 27 | return sublist_0, sublist_1, sublist_2 28 | 29 | 30 | def moveto(list, file): 31 | for f in list: 32 | before, after = '/'.join(f.split('/')[:-1]), f.split('/')[-1] 33 | new_path = os.path.join(before, file) 34 | new_path = os.path.join(new_path, after) 35 | # os.makedirs(new_path) 36 | # os.path.isdir(new_path) 37 | # shutil.move(f, new_path) 38 | 39 | #转移到新目录 40 | shutil.copytree(f, new_path) 41 | #删除原train里的文件 42 | shutil.rmtree(f) 43 | return None 44 | 45 | 46 | def read_pkl(data): 47 | betas = np.array(data['betas']) 48 | 49 | jaw_pose = np.array(data['jaw_pose']) 50 | leye_pose = np.array(data['leye_pose']) 51 | reye_pose = np.array(data['reye_pose']) 52 | global_orient = np.array(data['global_orient']).squeeze() 53 | body_pose = np.array(data['body_pose_axis']) 54 | left_hand_pose = np.array(data['left_hand_pose']) 55 | right_hand_pose = np.array(data['right_hand_pose']) 56 | 57 | full_body = np.concatenate( 58 | (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1) 59 | 60 | expression = np.array(data['expression']) 61 | full_body = np.concatenate((full_body, expression), axis=1) 62 | 63 | if (full_body.shape[0] < 90) or (torch.isnan(torch.from_numpy(full_body)).sum() > 0): 64 | return 1 65 | else: 66 | return 0 67 | 68 | 69 | for speaker_name in speakers: 70 | speaker_root = os.path.join(data_root, speaker_name) 71 | 72 | videos = [v for v in os.listdir(speaker_root)] 73 | print(videos) 74 | 75 | haode = huaide = 0 76 | total_seqs = [] 77 | 78 | for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)): 79 | # for vid in videos: 80 | source_vid = vid 81 | vid_pth = os.path.join(speaker_root, source_vid) 82 | # vid_pth = os.path.join(speaker_root, source_vid, 'images/half', split) 83 | t = os.path.join(speaker_root, source_vid, 'test') 84 | v = 
os.path.join(speaker_root, source_vid, 'val') 85 | 86 | # if os.path.exists(t): 87 | # shutil.rmtree(t) 88 | # if os.path.exists(v): 89 | # shutil.rmtree(v) 90 | try: 91 | seqs = [s for s in os.listdir(vid_pth)] 92 | except: 93 | continue 94 | # if len(seqs) == 0: 95 | # shutil.rmtree(os.path.join(speaker_root, source_vid)) 96 | # None 97 | for s in seqs: 98 | quality = 0 99 | total_seqs.append(os.path.join(vid_pth,s)) 100 | seq_root = os.path.join(vid_pth, s) 101 | key = seq_root # correspond to clip****** 102 | audio_fname = os.path.join(speaker_root, source_vid, s, '%s.wav' % (s)) 103 | 104 | # delete the data without audio or the audio file could not be read 105 | if os.path.isfile(audio_fname): 106 | try: 107 | audio = librosa.load(audio_fname) 108 | except: 109 | # print(key) 110 | shutil.rmtree(key) 111 | huaide = huaide + 1 112 | continue 113 | else: 114 | huaide = huaide + 1 115 | # print(key) 116 | shutil.rmtree(key) 117 | continue 118 | 119 | # check motion file 120 | motion_fname = os.path.join(speaker_root, source_vid, s, '%s.pkl' % (s)) 121 | try: 122 | f = open(motion_fname, 'rb+') 123 | except: 124 | shutil.rmtree(key) 125 | huaide = huaide + 1 126 | continue 127 | 128 | data = pickle.load(f) 129 | w = read_pkl(data) 130 | f.close() 131 | quality = quality + w 132 | 133 | if w == 1: 134 | shutil.rmtree(key) 135 | # print(key) 136 | huaide = huaide + 1 137 | continue 138 | 139 | haode = haode + 1 140 | 141 | print("huaide:{}, haode:{}, total_seqs:{}".format(huaide, haode, total_seqs.__len__())) 142 | 143 | for speaker_name in speakers: 144 | speaker_root = os.path.join(data_root, speaker_name) 145 | 146 | videos = [v for v in os.listdir(speaker_root)] 147 | print(videos) 148 | 149 | haode = huaide = 0 150 | total_seqs = [] 151 | 152 | for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)): 153 | # for vid in videos: 154 | source_vid = vid 155 | vid_pth = os.path.join(speaker_root, source_vid) 156 | try: 157 | seqs = [s for s in os.listdir(vid_pth)] 158 | except: 159 | continue 160 | for s in seqs: 161 | quality = 0 162 | total_seqs.append(os.path.join(vid_pth, s)) 163 | print("total_seqs:{}".format(total_seqs.__len__())) 164 | # split the dataset 165 | test_list, val_list, train_list = split_list(total_seqs, True, 0.1) 166 | print(len(test_list), len(val_list), len(train_list)) 167 | moveto(train_list, 'train') 168 | moveto(test_list, 'test') 169 | moveto(val_list, 'val') 170 | 171 | -------------------------------------------------------------------------------- /data_utils/get_j.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def to3d(poses, config): 5 | if config.Data.pose.convert_to_6d: 6 | if config.Data.pose.expression: 7 | poses_exp = poses[:, -100:] 8 | poses = poses[:, :-100] 9 | 10 | poses = poses.reshape(poses.shape[0], -1, 5) 11 | sin, cos = poses[:, :, 3], poses[:, :, 4] 12 | pose_angle = torch.atan2(sin, cos) 13 | poses = (poses[:, :, :3] * pose_angle.unsqueeze(dim=-1)).reshape(poses.shape[0], -1) 14 | 15 | if config.Data.pose.expression: 16 | poses = torch.cat([poses, poses_exp], dim=-1) 17 | return poses 18 | 19 | 20 | def get_joint(smplx_model, betas, pred): 21 | joint = smplx_model(betas=betas.repeat(pred.shape[0], 1), 22 | expression=pred[:, 165:265], 23 | jaw_pose=pred[:, 0:3], 24 | leye_pose=pred[:, 3:6], 25 | reye_pose=pred[:, 6:9], 26 | global_orient=pred[:, 9:12], 27 | body_pose=pred[:, 12:75], 28 | left_hand_pose=pred[:, 75:120], 29 | 
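                        # NOTE (added): each row of pred is a 265-D SMPL-X parameter vector laid out as
                        # jaw_pose (0:3), leye/reye_pose (3:9), global_orient (9:12), body_pose (12:75),
                        # left/right_hand_pose (75:165) and expression (165:265), matching read_pkl()
                        # in data_utils/dataset_preprocess.py.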
right_hand_pose=pred[:, 120:165], 30 | return_verts=True)['joints'] 31 | return joint 32 | 33 | 34 | def get_joints(smplx_model, betas, pred): 35 | if len(pred.shape) == 3: 36 | B = pred.shape[0] 37 | x = 4 if B>= 4 else B 38 | T = pred.shape[1] 39 | pred = pred.reshape(-1, 265) 40 | smplx_model.batch_size = L = T * x 41 | 42 | times = pred.shape[0] // smplx_model.batch_size 43 | joints = [] 44 | for i in range(times): 45 | joints.append(get_joint(smplx_model, betas, pred[i*L:(i+1)*L])) 46 | joints = torch.cat(joints, dim=0) 47 | joints = joints.reshape(B, T, -1, 3) 48 | else: 49 | smplx_model.batch_size = pred.shape[0] 50 | joints = get_joint(smplx_model, betas, pred) 51 | return joints -------------------------------------------------------------------------------- /data_utils/lower_body.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | lower_pose = torch.tensor( 5 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0747, -0.0158, -0.0152, -1.1826512813568115, 0.23866955935955048, 6 | 0.15146760642528534, -1.2604516744613647, -0.3160211145877838, 7 | -0.1603458970785141, 1.1654603481292725, 0.0, 0.0, 1.2521806955337524, 0.041598282754421234, -0.06312154978513718, 8 | 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) 9 | lower_pose_stand = torch.tensor([ 10 | 8.9759e-04, 7.1074e-04, -5.9163e-06, 8.9759e-04, 7.1074e-04, -5.9163e-06, 11 | 3.0747, -0.0158, -0.0152, 12 | -3.6665e-01, -8.8455e-03, 1.6113e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01, 13 | -3.9716e-01, -4.0229e-02, -1.2637e-01, 14 | 7.9163e-01, 6.8519e-02, -1.5091e-01, 7.9163e-01, 6.8519e-02, -1.5091e-01, 15 | 7.8632e-01, -4.3810e-02, 1.4375e-02, 16 | -1.0675e-01, 1.2635e-01, 1.6711e-02, -1.0675e-01, 1.2635e-01, 1.6711e-02, ]) 17 | # lower_pose_stand = torch.tensor( 18 | # [6.4919e-02, 3.3018e-02, 1.7485e-02, 8.9759e-04, 7.1074e-04, -5.9163e-06, 19 | # 3.0747, -0.0158, -0.0152, 20 | # -3.3633e+00, -9.3915e-02, 3.0996e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01, 21 | # 1.1654603481292725, 0.0, 0.0, 22 | # 4.4167e-01, 6.7183e-03, -3.6379e-03, 7.9163e-01, 6.8519e-02, -1.5091e-01, 23 | # 0.0, 0.0, 0.0, 24 | # 2.2910e-02, -2.4797e-02, -5.5657e-03, -1.0675e-01, 1.2635e-01, 1.6711e-02,]) 25 | lower_body = [0, 1, 3, 4, 6, 7, 9, 10] 26 | count_part = [6, 9, 12, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 27 | 29, 30, 31, 32, 33, 34, 35, 36, 37, 28 | 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54] 29 | fix_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30 | 29, 31 | 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 32 | 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 33 | 65, 66, 67, 68, 69, 70, 71, 72, 73, 74] 34 | all_index = np.ones(275) 35 | all_index[fix_index] = 0 36 | c_index = [] 37 | i = 0 38 | for num in all_index: 39 | if num == 1: 40 | c_index.append(i) 41 | i = i + 1 42 | c_index = np.asarray(c_index) 43 | 44 | fix_index_3d = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 45 | 21, 22, 23, 24, 25, 26, 46 | 30, 31, 32, 33, 34, 35, 47 | 45, 46, 47, 48, 49, 50] 48 | all_index_3d = np.ones(165) 49 | all_index_3d[fix_index_3d] = 0 50 | c_index_3d = [] 51 | i = 0 52 | for num in all_index_3d: 53 | if num == 1: 54 | c_index_3d.append(i) 55 | i = i + 1 56 | c_index_3d = np.asarray(c_index_3d) 57 | 58 | c_index_6d = [] 59 | i = 0 60 | for num in all_index_3d: 61 | if num == 1: 62 | c_index_6d.append(2*i) 63 | c_index_6d.append(2 * i + 1) 64 | i = i + 1 65 | c_index_6d 
= np.asarray(c_index_6d) 66 | 67 | 68 | def part2full(input, stand=False): 69 | if stand: 70 | # lp = lower_pose_stand.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device) 71 | lp = torch.zeros_like(lower_pose) 72 | lp[6:9] = torch.tensor([3.0747, -0.0158, -0.0152]) 73 | lp = lp.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device) 74 | else: 75 | lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device) 76 | 77 | input = torch.cat([input[:, :3], 78 | lp[:, :15], 79 | input[:, 3:6], 80 | lp[:, 15:21], 81 | input[:, 6:9], 82 | lp[:, 21:27], 83 | input[:, 9:12], 84 | lp[:, 27:], 85 | input[:, 12:]] 86 | , dim=1) 87 | return input 88 | 89 | 90 | def pred2poses(input, gt): 91 | input = torch.cat([input[:, :3], 92 | gt[0:1, 3:18].repeat(input.shape[0], 1), 93 | input[:, 3:6], 94 | gt[0:1, 21:27].repeat(input.shape[0], 1), 95 | input[:, 6:9], 96 | gt[0:1, 30:36].repeat(input.shape[0], 1), 97 | input[:, 9:12], 98 | gt[0:1, 39:45].repeat(input.shape[0], 1), 99 | input[:, 12:]] 100 | , dim=1) 101 | return input 102 | 103 | 104 | def poses2poses(input, gt): 105 | input = torch.cat([input[:, :3], 106 | gt[0:1, 3:18].repeat(input.shape[0], 1), 107 | input[:, 18:21], 108 | gt[0:1, 21:27].repeat(input.shape[0], 1), 109 | input[:, 27:30], 110 | gt[0:1, 30:36].repeat(input.shape[0], 1), 111 | input[:, 36:39], 112 | gt[0:1, 39:45].repeat(input.shape[0], 1), 113 | input[:, 45:]] 114 | , dim=1) 115 | return input 116 | 117 | def poses2pred(input, stand=False): 118 | if stand: 119 | lp = lower_pose_stand.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device) 120 | # lp = torch.zeros_like(lower_pose).unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device) 121 | else: 122 | lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device) 123 | input = torch.cat([input[:, :3], 124 | lp[:, :15], 125 | input[:, 18:21], 126 | lp[:, 15:21], 127 | input[:, 27:30], 128 | lp[:, 21:27], 129 | input[:, 36:39], 130 | lp[:, 27:], 131 | input[:, 45:]] 132 | , dim=1) 133 | return input 134 | 135 | 136 | rearrange = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]\ 137 | # ,22, 23, 24, 25, 40, 26, 41, 138 | # 27, 42, 28, 43, 29, 44, 30, 45, 31, 46, 32, 47, 33, 48, 34, 49, 35, 50, 36, 51, 37, 52, 38, 53, 39, 54, 55, 139 | # 57, 56, 59, 58, 60, 63, 61, 64, 62, 65, 66, 71, 67, 72, 68, 73, 69, 74, 70, 75] 140 | 141 | symmetry = [0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1]#, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 142 | # 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 143 | # 1, 1, 1, 1, 1, 1] 144 | -------------------------------------------------------------------------------- /data_utils/split_more_than_2s.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/data_utils/split_more_than_2s.pkl -------------------------------------------------------------------------------- /data_utils/split_train_val_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import shutil 4 | 5 | if __name__ =='__main__': 6 | id_list = "chemistry conan oliver seth" 7 | id_list = id_list.split(' ') 8 | 9 | old_root = '/home/usename/talkshow_data/ExpressiveWholeBodyDatasetReleaseV1.0' 10 | new_root = '/home/usename/talkshow_data/ExpressiveWholeBodyDatasetReleaseV1.0/talkshow_data_splited' 11 | 12 | with 
open('train_val_test.json') as f: 13 | split_info = json.load(f) 14 | phase_list = ['train', 'val', 'test'] 15 | for phase in phase_list: 16 | phase_path_list = split_info[phase] 17 | for p in phase_path_list: 18 | old_path = os.path.join(old_root, p) 19 | if not os.path.exists(old_path): 20 | print(f'{old_path} not found, continue' ) 21 | continue 22 | new_path = os.path.join(new_root, phase, p) 23 | dir_name = os.path.dirname(new_path) 24 | if not os.path.isdir(dir_name): 25 | os.makedirs(dir_name, exist_ok=True) 26 | shutil.move(old_path, new_path) 27 | 28 | -------------------------------------------------------------------------------- /data_utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # import librosa #has to do this cause librosa is not supported on my server 3 | import python_speech_features 4 | from scipy.io import wavfile 5 | from scipy import signal 6 | import librosa 7 | import torch 8 | import torchaudio as ta 9 | import torchaudio.functional as ta_F 10 | import torchaudio.transforms as ta_T 11 | # import pyloudnorm as pyln 12 | 13 | 14 | def load_wav_old(audio_fn, sr = 16000): 15 | sample_rate, sig = wavfile.read(audio_fn) 16 | if sample_rate != sr: 17 | result = int((sig.shape[0]) / sample_rate * sr) 18 | x_resampled = signal.resample(sig, result) 19 | x_resampled = x_resampled.astype(np.float64) 20 | return x_resampled, sr 21 | 22 | sig = sig / (2**15) 23 | return sig, sample_rate 24 | 25 | 26 | def get_mfcc(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None): 27 | 28 | y, sr = librosa.load(audio_fn, sr=sr, mono=True) 29 | 30 | if win_size is None: 31 | hop_len=int(sr / fps) 32 | else: 33 | hop_len=int(sr / win_size) 34 | 35 | n_fft=2048 36 | 37 | C = librosa.feature.mfcc( 38 | y = y, 39 | sr = sr, 40 | n_mfcc = n_mfcc, 41 | hop_length = hop_len, 42 | n_fft = n_fft 43 | ) 44 | 45 | if C.shape[0] == n_mfcc: 46 | C = C.transpose(1, 0) 47 | 48 | return C 49 | 50 | 51 | def get_melspec(audio_fn, eps=1e-6, fps = 25, sr=16000, n_mels=64): 52 | raise NotImplementedError 53 | ''' 54 | # y, sr = load_wav(audio_fn=audio_fn, sr=sr) 55 | 56 | # hop_len = int(sr / fps) 57 | # n_fft = 2048 58 | 59 | # C = librosa.feature.melspectrogram( 60 | # y = y, 61 | # sr = sr, 62 | # n_fft=n_fft, 63 | # hop_length=hop_len, 64 | # n_mels = n_mels, 65 | # fmin=0, 66 | # fmax=8000) 67 | 68 | 69 | # mask = (C == 0).astype(np.float) 70 | # C = mask * eps + (1-mask) * C 71 | 72 | # C = np.log(C) 73 | # #wierd error may occur here 74 | # assert not (np.isnan(C).any()), audio_fn 75 | # if C.shape[0] == n_mels: 76 | # C = C.transpose(1, 0) 77 | 78 | # return C 79 | ''' 80 | 81 | def extract_mfcc(audio,sample_rate=16000): 82 | mfcc = zip(*python_speech_features.mfcc(audio,sample_rate, numcep=64, nfilt=64, nfft=2048, winstep=0.04)) 83 | mfcc = np.stack([np.array(i) for i in mfcc]) 84 | return mfcc 85 | 86 | def get_mfcc_psf(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None): 87 | y, sr = load_wav_old(audio_fn, sr=sr) 88 | 89 | if y.shape.__len__() > 1: 90 | y = (y[:,0]+y[:,1])/2 91 | 92 | if win_size is None: 93 | hop_len=int(sr / fps) 94 | else: 95 | hop_len=int(sr/ win_size) 96 | 97 | n_fft=2048 98 | 99 | #hard coded for 25 fps 100 | if not smlpx: 101 | C = python_speech_features.mfcc(y, sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=0.04) 102 | else: 103 | C = python_speech_features.mfcc(y, sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01/15) 104 | # if C.shape[0] 
== n_mfcc: 105 | # C = C.transpose(1, 0) 106 | 107 | return C 108 | 109 | 110 | def get_mfcc_psf_min(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None): 111 | y, sr = load_wav_old(audio_fn, sr=sr) 112 | 113 | if y.shape.__len__() > 1: 114 | y = (y[:, 0] + y[:, 1]) / 2 115 | n_fft = 2048 116 | 117 | slice_len = 22000 * 5 118 | slice = y.size // slice_len 119 | 120 | C = [] 121 | 122 | for i in range(slice): 123 | if i != (slice - 1): 124 | feat = python_speech_features.mfcc(y[i*slice_len:(i+1)*slice_len], sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01 / 15) 125 | else: 126 | feat = python_speech_features.mfcc(y[i * slice_len:], sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01 / 15) 127 | 128 | C.append(feat) 129 | 130 | return C 131 | 132 | 133 | def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000): 134 | """ 135 | :param audio: 1 x T tensor containing a 16kHz audio signal 136 | :param frame_rate: frame rate for video (we need one audio chunk per video frame) 137 | :param chunk_size: number of audio samples per chunk 138 | :return: num_chunks x chunk_size tensor containing sliced audio 139 | """ 140 | samples_per_frame = chunk_size // frame_rate 141 | padding = (chunk_size - samples_per_frame) // 2 142 | audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0) 143 | anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame)) 144 | audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0) 145 | return audio 146 | 147 | 148 | def get_mfcc_ta(audio_fn, eps=1e-6, fps=15, smlpx=False, sr=16000, n_mfcc=64, win_size=None, type='mfcc', am=None, am_sr=None, encoder_choice='mfcc'): 149 | if am is None: 150 | audio, sr_0 = ta.load(audio_fn) 151 | if sr != sr_0: 152 | audio = ta.transforms.Resample(sr_0, sr)(audio) 153 | if audio.shape[0] > 1: 154 | audio = torch.mean(audio, dim=0, keepdim=True) 155 | 156 | n_fft = 2048 157 | if fps == 15: 158 | hop_length = 1467 159 | elif fps == 30: 160 | hop_length = 734 161 | win_length = hop_length * 2 162 | n_mels = 256 163 | n_mfcc = 64 164 | 165 | if type == 'mfcc': 166 | mfcc_transform = ta_T.MFCC( 167 | sample_rate=sr, 168 | n_mfcc=n_mfcc, 169 | melkwargs={ 170 | "n_fft": n_fft, 171 | "n_mels": n_mels, 172 | # "win_length": win_length, 173 | "hop_length": hop_length, 174 | "mel_scale": "htk", 175 | }, 176 | ) 177 | audio_ft = mfcc_transform(audio).squeeze(dim=0).transpose(0,1).numpy() 178 | elif type == 'mel': 179 | # audio = 0.01 * audio / torch.mean(torch.abs(audio)) 180 | mel_transform = ta_T.MelSpectrogram( 181 | sample_rate=sr, n_fft=n_fft, win_length=None, hop_length=hop_length, n_mels=n_mels 182 | ) 183 | audio_ft = mel_transform(audio).squeeze(0).transpose(0,1).numpy() 184 | # audio_ft = torch.log(audio_ft.clamp(min=1e-10, max=None)).transpose(0,1).numpy() 185 | elif type == 'mel_mul': 186 | audio = 0.01 * audio / torch.mean(torch.abs(audio)) 187 | audio = audio_chunking(audio, frame_rate=fps, chunk_size=sr) 188 | mel_transform = ta_T.MelSpectrogram( 189 | sample_rate=sr, n_fft=n_fft, win_length=int(sr/20), hop_length=int(sr/100), n_mels=n_mels 190 | ) 191 | audio_ft = mel_transform(audio).squeeze(1) 192 | audio_ft = torch.log(audio_ft.clamp(min=1e-10, max=None)).numpy() 193 | else: 194 | speech_array, sampling_rate = librosa.load(audio_fn, sr=16000) 195 | 196 | if encoder_choice == 'faceformer': 197 | # audio_ft = np.squeeze(am(speech_array, 
sampling_rate=16000).input_values).reshape(-1, 1) 198 | audio_ft = speech_array.reshape(-1, 1) 199 | elif encoder_choice == 'meshtalk': 200 | audio_ft = 0.01 * speech_array / np.mean(np.abs(speech_array)) 201 | elif encoder_choice == 'onset': 202 | audio_ft = librosa.onset.onset_detect(y=speech_array, sr=16000, units='time').reshape(-1, 1) 203 | else: 204 | audio, sr_0 = ta.load(audio_fn) 205 | if sr != sr_0: 206 | audio = ta.transforms.Resample(sr_0, sr)(audio) 207 | if audio.shape[0] > 1: 208 | audio = torch.mean(audio, dim=0, keepdim=True) 209 | 210 | n_fft = 2048 211 | if fps == 15: 212 | hop_length = 1467 213 | elif fps == 30: 214 | hop_length = 734 215 | win_length = hop_length * 2 216 | n_mels = 256 217 | n_mfcc = 64 218 | 219 | mfcc_transform = ta_T.MFCC( 220 | sample_rate=sr, 221 | n_mfcc=n_mfcc, 222 | melkwargs={ 223 | "n_fft": n_fft, 224 | "n_mels": n_mels, 225 | # "win_length": win_length, 226 | "hop_length": hop_length, 227 | "mel_scale": "htk", 228 | }, 229 | ) 230 | audio_ft = mfcc_transform(audio).squeeze(dim=0).transpose(0, 1).numpy() 231 | return audio_ft 232 | 233 | 234 | def get_mfcc_sepa(audio_fn, fps=15, sr=16000): 235 | audio, sr_0 = ta.load(audio_fn) 236 | if sr != sr_0: 237 | audio = ta.transforms.Resample(sr_0, sr)(audio) 238 | if audio.shape[0] > 1: 239 | audio = torch.mean(audio, dim=0, keepdim=True) 240 | 241 | n_fft = 2048 242 | if fps == 15: 243 | hop_length = 1467 244 | elif fps == 30: 245 | hop_length = 734 246 | n_mels = 256 247 | n_mfcc = 64 248 | 249 | mfcc_transform = ta_T.MFCC( 250 | sample_rate=sr, 251 | n_mfcc=n_mfcc, 252 | melkwargs={ 253 | "n_fft": n_fft, 254 | "n_mels": n_mels, 255 | # "win_length": win_length, 256 | "hop_length": hop_length, 257 | "mel_scale": "htk", 258 | }, 259 | ) 260 | audio_ft_0 = mfcc_transform(audio[0, :sr*2]).squeeze(dim=0).transpose(0,1).numpy() 261 | audio_ft_1 = mfcc_transform(audio[0, sr*2:]).squeeze(dim=0).transpose(0,1).numpy() 262 | audio_ft = np.concatenate((audio_ft_0, audio_ft_1), axis=0) 263 | return audio_ft, audio_ft_0.shape[0] 264 | 265 | 266 | def get_mfcc_old(wav_file): 267 | sig, sample_rate = load_wav_old(wav_file) 268 | mfcc = extract_mfcc(sig) 269 | return mfcc 270 | 271 | 272 | def smooth_geom(geom, mask: torch.Tensor = None, filter_size: int = 9, sigma: float = 2.0): 273 | """ 274 | :param geom: T x V x 3 tensor containing a temporal sequence of length T with V vertices in each frame 275 | :param mask: V-dimensional Tensor containing a mask with vertices to be smoothed 276 | :param filter_size: size of the Gaussian filter 277 | :param sigma: standard deviation of the Gaussian filter 278 | :return: T x V x 3 tensor containing smoothed geometry (i.e., smoothed in the area indicated by the mask) 279 | """ 280 | assert filter_size % 2 == 1, f"filter size must be odd but is {filter_size}" 281 | # Gaussian smoothing (low-pass filtering) 282 | fltr = np.arange(-(filter_size // 2), filter_size // 2 + 1) 283 | fltr = np.exp(-0.5 * fltr ** 2 / sigma ** 2) 284 | fltr = torch.Tensor(fltr) / np.sum(fltr) 285 | # apply fltr 286 | fltr = fltr.view(1, 1, -1).to(device=geom.device) 287 | T, V = geom.shape[1], geom.shape[2] 288 | g = torch.nn.functional.pad( 289 | geom.permute(2, 0, 1).view(V, 1, T), 290 | pad=[filter_size // 2, filter_size // 2], mode='replicate' 291 | ) 292 | g = torch.nn.functional.conv1d(g, fltr).view(V, 1, T) 293 | smoothed = g.permute(1, 2, 0).contiguous() 294 | # blend smoothed signal with original signal 295 | if mask is None: 296 | return smoothed 297 | else: 298 | return smoothed * 
mask[None, :, None] + geom * (-mask[None, :, None] + 1) 299 | 300 | if __name__ == '__main__': 301 | audio_fn = '../sample_audio/clip000028_tCAkv4ggPgI.wav' 302 | 303 | C = get_mfcc_psf(audio_fn) 304 | print(C.shape) 305 | 306 | C_2 = get_mfcc_librosa(audio_fn) 307 | print(C.shape) 308 | 309 | print(C) 310 | print(C_2) 311 | print((C == C_2).all()) 312 | # print(y.shape, sr) 313 | # mel_spec = get_melspec(audio_fn) 314 | # print(mel_spec.shape) 315 | # mfcc = get_mfcc(audio_fn, sr = 16000) 316 | # print(mfcc.shape) 317 | # print(mel_spec.max(), mel_spec.min()) 318 | # print(mfcc.max(), mfcc.min()) -------------------------------------------------------------------------------- /demo/1st-page/1st-page-upper.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/1st-page/1st-page-upper.mp4 -------------------------------------------------------------------------------- /demo/1st-page/1st-page-upper.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/1st-page/1st-page-upper.npy -------------------------------------------------------------------------------- /demo/french/french.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/french/french.mp4 -------------------------------------------------------------------------------- /demo/french/french.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/french/french.npy -------------------------------------------------------------------------------- /demo/rich/rich.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/rich/rich.mp4 -------------------------------------------------------------------------------- /demo/rich/rich.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/rich/rich.npy -------------------------------------------------------------------------------- /demo/song/cut.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/song/cut.mp4 -------------------------------------------------------------------------------- /demo/song/song.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/song/song.mp4 -------------------------------------------------------------------------------- /demo/song/song.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/song/song.npy -------------------------------------------------------------------------------- /demo/style/chemistry.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/chemistry.mp4 -------------------------------------------------------------------------------- /demo/style/chemistry.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/chemistry.npy -------------------------------------------------------------------------------- /demo/style/conan.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/conan.mp4 -------------------------------------------------------------------------------- /demo/style/conan.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/conan.npy -------------------------------------------------------------------------------- /demo/style/diversity.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/diversity.mp4 -------------------------------------------------------------------------------- /demo/style/diversity.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/diversity.npy -------------------------------------------------------------------------------- /demo/style/face.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/face.mp4 -------------------------------------------------------------------------------- /demo/style/face.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/face.npy -------------------------------------------------------------------------------- /demo/style/oliver.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/oliver.mp4 -------------------------------------------------------------------------------- /demo/style/oliver.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/oliver.npy -------------------------------------------------------------------------------- /demo/style/seth.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/seth.mp4 -------------------------------------------------------------------------------- /demo/style/seth.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo/style/seth.npy -------------------------------------------------------------------------------- /demo_audio/1st-page.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo_audio/1st-page.wav -------------------------------------------------------------------------------- /demo_audio/french.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo_audio/french.wav -------------------------------------------------------------------------------- /demo_audio/rich.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo_audio/rich.wav -------------------------------------------------------------------------------- /demo_audio/rich_short.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo_audio/rich_short.wav -------------------------------------------------------------------------------- /demo_audio/song.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo_audio/song.wav -------------------------------------------------------------------------------- /demo_audio/style.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/demo_audio/style.wav -------------------------------------------------------------------------------- /evaluation/FGD.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from scipy import linalg 7 | import math 8 | from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d 9 | 10 | import warnings 11 | warnings.filterwarnings("ignore", category=RuntimeWarning) # ignore warnings 12 | 13 | 14 | change_angle = torch.tensor([6.0181e-05, 5.1597e-05, 2.1344e-04, 2.1899e-04]) 15 | class EmbeddingSpaceEvaluator: 16 | def __init__(self, ae, vae, device): 17 | 18 | # init embed net 19 | self.ae = ae 20 | # self.vae = vae 21 | 22 | # storage 23 | self.real_feat_list = [] 24 | self.generated_feat_list = [] 25 | self.real_joints_list = [] 26 | self.generated_joints_list = [] 27 | self.real_6d_list = [] 28 | self.generated_6d_list = [] 29 | self.audio_beat_list = [] 30 | 31 | def reset(self): 32 | self.real_feat_list = [] 33 | self.generated_feat_list = [] 34 | 35 | def get_no_of_samples(self): 36 | return len(self.real_feat_list) 37 | 38 | def push_samples(self, generated_poses, real_poses): 39 | # self.net.eval() 40 | # convert poses to latent features 41 | real_feat, real_poses = self.ae.extract(real_poses) 42 | generated_feat, generated_poses = self.ae.extract(generated_poses) 43 | 44 | num_joints = real_poses.shape[2] // 3 45 | 46 | real_feat = real_feat.squeeze() 47 | generated_feat = generated_feat.reshape(generated_feat.shape[0]*generated_feat.shape[1], -1) 48 | 49 | self.real_feat_list.append(real_feat.data.cpu().numpy()) 50 | self.generated_feat_list.append(generated_feat.data.cpu().numpy()) 51 | 52 | # real_poses = matrix_to_rotation_6d(axis_angle_to_matrix(real_poses.reshape(-1, 
3))).reshape(-1, num_joints, 6) 53 | # generated_poses = matrix_to_rotation_6d(axis_angle_to_matrix(generated_poses.reshape(-1, 3))).reshape(-1, num_joints, 6) 54 | # 55 | # self.real_feat_list.append(real_poses.data.cpu().numpy()) 56 | # self.generated_feat_list.append(generated_poses.data.cpu().numpy()) 57 | 58 | def push_joints(self, generated_poses, real_poses): 59 | self.real_joints_list.append(real_poses.data.cpu()) 60 | self.generated_joints_list.append(generated_poses.squeeze().data.cpu()) 61 | 62 | def push_aud(self, aud): 63 | self.audio_beat_list.append(aud.squeeze().data.cpu()) 64 | 65 | def get_MAAC(self): 66 | ang_vel_list = [] 67 | for real_joints in self.real_joints_list: 68 | real_joints[:, 15:21] = real_joints[:, 16:22] 69 | vec = real_joints[:, 15:21] - real_joints[:, 13:19] 70 | inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]]) 71 | inner_product = torch.clamp(inner_product, -1, 1, out=None) 72 | angle = torch.acos(inner_product) / math.pi 73 | ang_vel = (angle[1:] - angle[:-1]).abs().mean(dim=0) 74 | ang_vel_list.append(ang_vel.unsqueeze(dim=0)) 75 | all_vel = torch.cat(ang_vel_list, dim=0) 76 | MAAC = all_vel.mean(dim=0) 77 | return MAAC 78 | 79 | def get_BCscore(self): 80 | thres = 0.01 81 | sigma = 0.1 82 | sum_1 = 0 83 | total_beat = 0 84 | for joints, audio_beat_time in zip(self.generated_joints_list, self.audio_beat_list): 85 | motion_beat_time = [] 86 | if joints.dim() == 4: 87 | joints = joints[0] 88 | joints[:, 15:21] = joints[:, 16:22] 89 | vec = joints[:, 15:21] - joints[:, 13:19] 90 | inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]]) 91 | inner_product = torch.clamp(inner_product, -1, 1, out=None) 92 | angle = torch.acos(inner_product) / math.pi 93 | ang_vel = (angle[1:] - angle[:-1]).abs() / change_angle / len(change_angle) 94 | 95 | angle_diff = torch.cat((torch.zeros(1, 4), ang_vel), dim=0) 96 | 97 | sum_2 = 0 98 | for i in range(angle_diff.shape[1]): 99 | motion_beat_time = [] 100 | for t in range(1, joints.shape[0]-1): 101 | if (angle_diff[t][i] < angle_diff[t - 1][i] and angle_diff[t][i] < angle_diff[t + 1][i]): 102 | if (angle_diff[t - 1][i] - angle_diff[t][i] >= thres or angle_diff[t + 1][i] - angle_diff[ 103 | t][i] >= thres): 104 | motion_beat_time.append(float(t) / 30.0) 105 | if (len(motion_beat_time) == 0): 106 | continue 107 | motion_beat_time = torch.tensor(motion_beat_time) 108 | sum = 0 109 | for audio in audio_beat_time: 110 | sum += np.power(math.e, -(np.power((audio.item() - motion_beat_time), 2)).min() / (2 * sigma * sigma)) 111 | sum_2 = sum_2 + sum 112 | total_beat = total_beat + len(audio_beat_time) 113 | sum_1 = sum_1 + sum_2 114 | return sum_1/total_beat 115 | 116 | 117 | def get_scores(self): 118 | generated_feats = np.vstack(self.generated_feat_list) 119 | real_feats = np.vstack(self.real_feat_list) 120 | 121 | def frechet_distance(samples_A, samples_B): 122 | A_mu = np.mean(samples_A, axis=0) 123 | A_sigma = np.cov(samples_A, rowvar=False) 124 | B_mu = np.mean(samples_B, axis=0) 125 | B_sigma = np.cov(samples_B, rowvar=False) 126 | try: 127 | frechet_dist = self.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma) 128 | except ValueError: 129 | frechet_dist = 1e+10 130 | return frechet_dist 131 | 132 | #################################################################### 133 | # frechet distance 134 | frechet_dist = frechet_distance(generated_feats, real_feats) 135 | 136 | #################################################################### 137 | # distance between real and generated 
samples on the latent feature space 138 | dists = [] 139 | for i in range(real_feats.shape[0]): 140 | d = np.sum(np.absolute(real_feats[i] - generated_feats[i])) # MAE 141 | dists.append(d) 142 | feat_dist = np.mean(dists) 143 | 144 | return frechet_dist, feat_dist 145 | 146 | @staticmethod 147 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 148 | """ from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """ 149 | """Numpy implementation of the Frechet Distance. 150 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 151 | and X_2 ~ N(mu_2, C_2) is 152 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 153 | Stable version by Dougal J. Sutherland. 154 | Params: 155 | -- mu1 : Numpy array containing the activations of a layer of the 156 | inception net (like returned by the function 'get_predictions') 157 | for generated samples. 158 | -- mu2 : The sample mean over activations, precalculated on an 159 | representative data set. 160 | -- sigma1: The covariance matrix over activations for generated samples. 161 | -- sigma2: The covariance matrix over activations, precalculated on an 162 | representative data set. 163 | Returns: 164 | -- : The Frechet Distance. 165 | """ 166 | 167 | mu1 = np.atleast_1d(mu1) 168 | mu2 = np.atleast_1d(mu2) 169 | 170 | sigma1 = np.atleast_2d(sigma1) 171 | sigma2 = np.atleast_2d(sigma2) 172 | 173 | assert mu1.shape == mu2.shape, \ 174 | 'Training and test mean vectors have different lengths' 175 | assert sigma1.shape == sigma2.shape, \ 176 | 'Training and test covariances have different dimensions' 177 | 178 | diff = mu1 - mu2 179 | 180 | # Product might be almost singular 181 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 182 | if not np.isfinite(covmean).all(): 183 | msg = ('fid calculation produces singular product; ' 184 | 'adding %s to diagonal of cov estimates') % eps 185 | print(msg) 186 | offset = np.eye(sigma1.shape[0]) * eps 187 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 188 | 189 | # Numerical error might give slight imaginary component 190 | if np.iscomplexobj(covmean): 191 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 192 | m = np.max(np.abs(covmean.imag)) 193 | raise ValueError('Imaginary component {}'.format(m)) 194 | covmean = covmean.real 195 | 196 | tr_covmean = np.trace(covmean) 197 | 198 | return (diff.dot(diff) + np.trace(sigma1) + 199 | np.trace(sigma2) - 2 * tr_covmean) -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/evaluation/__init__.py -------------------------------------------------------------------------------- /evaluation/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/evaluation/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /evaluation/__pycache__/metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/evaluation/__pycache__/metrics.cpython-37.pyc 
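A minimal usage sketch for the `EmbeddingSpaceEvaluator` defined in `evaluation/FGD.py` above. Only `EmbeddingSpaceEvaluator`, `push_samples` and `get_scores` come from that file; how the body autoencoder is built and how the test batches are iterated are placeholders/assumptions, not repo code.

```python
# Minimal sketch (not repo code). Assumptions: load_pretrained_body_ae() and
# test_batches are hypothetical; the evaluator, push_samples and get_scores
# are used exactly as defined in evaluation/FGD.py.
import torch
from evaluation.FGD import EmbeddingSpaceEvaluator

device = torch.device('cuda')
ae = load_pretrained_body_ae()        # hypothetical helper: any model exposing .extract(poses)
evaluator = EmbeddingSpaceEvaluator(ae, vae=None, device=device)

for real_poses, generated_poses in test_batches:   # hypothetical iterable of pose tensors
    # push_samples runs ae.extract() on both sets and stores the latent features
    evaluator.push_samples(generated_poses.to(device), real_poses.to(device))

# Frechet distance between the two latent distributions, plus a per-sample L1 distance
fgd, feat_dist = evaluator.get_scores()
print('FGD:', fgd, 'latent L1:', feat_dist)
```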
-------------------------------------------------------------------------------- /evaluation/diversity_LVD.py: -------------------------------------------------------------------------------- 1 | ''' 2 | LVD: different initial pose 3 | diversity: same initial pose 4 | ''' 5 | import os 6 | import sys 7 | sys.path.append(os.getcwd()) 8 | 9 | from glob import glob 10 | 11 | from argparse import ArgumentParser 12 | import json 13 | 14 | from evaluation.util import * 15 | from evaluation.metrics import * 16 | from tqdm import tqdm 17 | 18 | parser = ArgumentParser() 19 | parser.add_argument('--speaker', required=True, type=str) 20 | parser.add_argument('--post_fix', nargs='+', default=['base'], type=str) 21 | args = parser.parse_args() 22 | 23 | speaker = args.speaker 24 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker))) 25 | 26 | LVD_list = [] 27 | diversity_list = [] 28 | 29 | for aud in tqdm(test_audios): 30 | base_name = os.path.splitext(aud)[0] 31 | gt_path = get_full_path(aud, speaker, 'val') 32 | _, gt_poses, _ = get_gts(gt_path) 33 | gt_poses = gt_poses[np.newaxis,...] 34 | # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face 35 | for post_fix in args.post_fix: 36 | pred_path = base_name + '_'+post_fix+'.json' 37 | pred_poses = np.array(json.load(open(pred_path))) 38 | # print(pred_poses.shape)#(B, seq_len, 108) 39 | pred_poses = cvt25(pred_poses, gt_poses) 40 | # print(pred_poses.shape)#(B, seq, pose_dim) 41 | 42 | gt_valid_points = hand_points(gt_poses) 43 | pred_valid_points = hand_points(pred_poses) 44 | 45 | lvd = LVD(gt_valid_points, pred_valid_points) 46 | # div = diversity(pred_valid_points) 47 | 48 | LVD_list.append(lvd) 49 | # diversity_list.append(div) 50 | 51 | # gt_velocity = peak_velocity(gt_valid_points, order=2) 52 | # pred_velocity = peak_velocity(pred_valid_points, order=2) 53 | 54 | # gt_consistency = velocity_consistency(gt_velocity, pred_velocity) 55 | # pred_consistency = velocity_consistency(pred_velocity, gt_velocity) 56 | 57 | # gt_consistency_list.append(gt_consistency) 58 | # pred_consistency_list.append(pred_consistency) 59 | 60 | lvd = np.mean(LVD_list) 61 | # diversity_list = np.mean(diversity_list) 62 | 63 | print('LVD:', lvd) 64 | # print("diversity:", diversity_list) -------------------------------------------------------------------------------- /evaluation/get_quality_samples.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | import os 4 | import sys 5 | sys.path.append(os.getcwd()) 6 | 7 | from glob import glob 8 | 9 | from argparse import ArgumentParser 10 | import json 11 | 12 | from evaluation.util import * 13 | from evaluation.metrics import * 14 | from tqdm import tqdm 15 | 16 | parser = ArgumentParser() 17 | parser.add_argument('--speaker', required=True, type=str) 18 | parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str) 19 | args = parser.parse_args() 20 | 21 | speaker = args.speaker 22 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker))) 23 | 24 | quality_samples={'gt':[]} 25 | for post_fix in args.post_fix: 26 | quality_samples[post_fix] = [] 27 | 28 | for aud in tqdm(test_audios): 29 | base_name = os.path.splitext(aud)[0] 30 | gt_path = get_full_path(aud, speaker, 'val') 31 | _, gt_poses, _ = get_gts(gt_path) 32 | gt_poses = gt_poses[np.newaxis,...] 
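# valid_points (evaluation/util.py) keeps body joints 0-11 plus both 21-joint hands: 54 joints x 2 coords = 108 values per frame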
33 | gt_valid_points = valid_points(gt_poses) 34 | # print(gt_valid_points.shape) 35 | quality_samples['gt'].append(gt_valid_points) 36 | 37 | for post_fix in args.post_fix: 38 | pred_path = base_name + '_'+post_fix+'.json' 39 | pred_poses = np.array(json.load(open(pred_path))) 40 | # print(pred_poses.shape)#(B, seq_len, 108) 41 | pred_poses = cvt25(pred_poses, gt_poses) 42 | # print(pred_poses.shape)#(B, seq, pose_dim) 43 | 44 | pred_valid_points = valid_points(pred_poses)[0:1] 45 | quality_samples[post_fix].append(pred_valid_points) 46 | 47 | quality_samples['gt'] = np.concatenate(quality_samples['gt'], axis=1) 48 | for post_fix in args.post_fix: 49 | quality_samples[post_fix] = np.concatenate(quality_samples[post_fix], axis=1) 50 | 51 | print('gt:', quality_samples['gt'].shape) 52 | quality_samples['gt'] = quality_samples['gt'].tolist() 53 | for post_fix in args.post_fix: 54 | print(post_fix, ':', quality_samples[post_fix].shape) 55 | quality_samples[post_fix] = quality_samples[post_fix].tolist() 56 | 57 | save_dir = '../../experiments/' 58 | os.makedirs(save_dir, exist_ok=True) 59 | save_name = os.path.join(save_dir, 'quality_samples_%s.json'%(speaker)) 60 | with open(save_name, 'w') as f: 61 | json.dump(quality_samples, f) 62 | 63 | -------------------------------------------------------------------------------- /evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Warning: metrics are for reference only, may have limited significance 3 | ''' 4 | import os 5 | import sys 6 | sys.path.append(os.getcwd()) 7 | import numpy as np 8 | import torch 9 | 10 | from data_utils.lower_body import rearrange, symmetry 11 | import torch.nn.functional as F 12 | 13 | def data_driven_baselines(gt_kps): 14 | ''' 15 | gt_kps: T, D 16 | ''' 17 | gt_velocity = np.abs(gt_kps[1:] - gt_kps[:-1]) 18 | 19 | mean= np.mean(gt_velocity, axis=0)[np.newaxis] #(1, D) 20 | mean = np.mean(np.abs(gt_velocity-mean)) 21 | last_step = gt_kps[1] - gt_kps[0] 22 | last_step = last_step[np.newaxis] #(1, D) 23 | last_step = np.mean(np.abs(gt_velocity-last_step)) 24 | return last_step, mean 25 | 26 | def Batch_LVD(gt_kps, pr_kps, symmetrical, weight): 27 | if gt_kps.shape[0] > pr_kps.shape[1]: 28 | length = pr_kps.shape[1] 29 | else: 30 | length = gt_kps.shape[0] 31 | gt_kps = gt_kps[:length] 32 | pr_kps = pr_kps[:, :length] 33 | global symmetry 34 | symmetry = torch.tensor(symmetry).bool() 35 | 36 | if symmetrical: 37 | # rearrange for compute symmetric. ns means non-symmetrical joints, ys means symmetrical joints. 
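# Symmetric (left/right) joint pairs are grouped below; for each frame only the side with the
# larger summed velocity ('move_side') is kept, computed separately for ground truth and predictions.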
38 | gt_kps = gt_kps[:, rearrange] 39 | ns_gt_kps = gt_kps[:, ~symmetry] 40 | ys_gt_kps = gt_kps[:, symmetry] 41 | ys_gt_kps = ys_gt_kps.reshape(ys_gt_kps.shape[0], -1, 2, 3) 42 | ns_gt_velocity = (ns_gt_kps[1:] - ns_gt_kps[:-1]).norm(p=2, dim=-1) 43 | ys_gt_velocity = (ys_gt_kps[1:] - ys_gt_kps[:-1]).norm(p=2, dim=-1) 44 | left_gt_vel = ys_gt_velocity[:, :, 0].sum(dim=-1) 45 | right_gt_vel = ys_gt_velocity[:, :, 1].sum(dim=-1) 46 | move_side = torch.where(left_gt_vel>right_gt_vel, torch.ones(left_gt_vel.shape).cuda(), torch.zeros(left_gt_vel.shape).cuda()) 47 | ys_gt_velocity = torch.mul(ys_gt_velocity[:, :, 0].transpose(0,1), move_side) + torch.mul(ys_gt_velocity[:, :, 1].transpose(0,1), ~move_side.bool()) 48 | ys_gt_velocity = ys_gt_velocity.transpose(0,1) 49 | gt_velocity = torch.cat([ns_gt_velocity, ys_gt_velocity], dim=1) 50 | 51 | pr_kps = pr_kps[:, :, rearrange] 52 | ns_pr_kps = pr_kps[:, :, ~symmetry] 53 | ys_pr_kps = pr_kps[:, :, symmetry] 54 | ys_pr_kps = ys_pr_kps.reshape(ys_pr_kps.shape[0], ys_pr_kps.shape[1], -1, 2, 3) 55 | ns_pr_velocity = (ns_pr_kps[:, 1:] - ns_pr_kps[:, :-1]).norm(p=2, dim=-1) 56 | ys_pr_velocity = (ys_pr_kps[:, 1:] - ys_pr_kps[:, :-1]).norm(p=2, dim=-1) 57 | left_pr_vel = ys_pr_velocity[:, :, :, 0].sum(dim=-1) 58 | right_pr_vel = ys_pr_velocity[:, :, :, 1].sum(dim=-1) 59 | move_side = torch.where(left_pr_vel > right_pr_vel, torch.ones(left_pr_vel.shape).cuda(), 60 | torch.zeros(left_pr_vel.shape).cuda()) 61 | ys_pr_velocity = torch.mul(ys_pr_velocity[..., 0].permute(2, 0, 1), move_side) + torch.mul( 62 | ys_pr_velocity[..., 1].permute(2, 0, 1), ~move_side.long()) 63 | ys_pr_velocity = ys_pr_velocity.permute(1, 2, 0) 64 | pr_velocity = torch.cat([ns_pr_velocity, ys_pr_velocity], dim=2) 65 | else: 66 | gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1) 67 | pr_velocity = (pr_kps[:, 1:] - pr_kps[:, :-1]).norm(p=2, dim=-1) 68 | 69 | if weight: 70 | w = F.softmax(gt_velocity.sum(dim=1).normal_(), dim=0) 71 | else: 72 | w = 1 / gt_velocity.shape[0] 73 | 74 | v_diff = ((pr_velocity - gt_velocity).abs().sum(dim=-1) * w).sum(dim=-1).mean() 75 | 76 | return v_diff 77 | 78 | 79 | def LVD(gt_kps, pr_kps, symmetrical=False, weight=False): 80 | gt_kps = gt_kps.squeeze() 81 | pr_kps = pr_kps.squeeze() 82 | if len(pr_kps.shape) == 4: 83 | return Batch_LVD(gt_kps, pr_kps, symmetrical, weight) 84 | # length = np.minimum(gt_kps.shape[0], pr_kps.shape[0]) 85 | length = gt_kps.shape[0]-10 86 | # gt_kps = gt_kps[25:length] 87 | # pr_kps = pr_kps[25:length] #(T, D) 88 | # if pr_kps.shape[0] < gt_kps.shape[0]: 89 | # pr_kps = np.pad(pr_kps, [[0, int(gt_kps.shape[0]-pr_kps.shape[0])], [0, 0]], mode='constant') 90 | 91 | gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1) 92 | pr_velocity = (pr_kps[1:] - pr_kps[:-1]).norm(p=2, dim=-1) 93 | 94 | return (pr_velocity-gt_velocity).abs().sum(dim=-1).mean() 95 | 96 | def diversity(kps): 97 | ''' 98 | kps: bs, seq, dim 99 | ''' 100 | dis_list = [] 101 | #the distance between each pair 102 | for i in range(kps.shape[0]): 103 | for j in range(i+1, kps.shape[0]): 104 | seq_i = kps[i] 105 | seq_j = kps[j] 106 | 107 | dis = np.mean(np.abs(seq_i - seq_j)) 108 | dis_list.append(dis) 109 | return np.mean(dis_list) 110 | -------------------------------------------------------------------------------- /evaluation/mode_transition.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.getcwd()) 4 | 5 | from glob import glob 6 | 7 | from argparse import 
ArgumentParser 8 | import json 9 | 10 | from evaluation.util import * 11 | from evaluation.metrics import * 12 | from tqdm import tqdm 13 | 14 | parser = ArgumentParser() 15 | parser.add_argument('--speaker', required=True, type=str) 16 | parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str) 17 | args = parser.parse_args() 18 | 19 | speaker = args.speaker 20 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker))) 21 | 22 | precision_list=[] 23 | recall_list=[] 24 | accuracy_list=[] 25 | 26 | for aud in tqdm(test_audios): 27 | base_name = os.path.splitext(aud)[0] 28 | gt_path = get_full_path(aud, speaker, 'val') 29 | _, gt_poses, _ = get_gts(gt_path) 30 | if gt_poses.shape[0] < 50: 31 | continue 32 | gt_poses = gt_poses[np.newaxis,...] 33 | # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face 34 | for post_fix in args.post_fix: 35 | pred_path = base_name + '_'+post_fix+'.json' 36 | pred_poses = np.array(json.load(open(pred_path))) 37 | # print(pred_poses.shape)#(B, seq_len, 108) 38 | pred_poses = cvt25(pred_poses, gt_poses) 39 | # print(pred_poses.shape)#(B, seq, pose_dim) 40 | 41 | gt_valid_points = valid_points(gt_poses) 42 | pred_valid_points = valid_points(pred_poses) 43 | 44 | # print(gt_valid_points.shape, pred_valid_points.shape) 45 | 46 | gt_mode_transition_seq = mode_transition_seq(gt_valid_points, speaker)#(B, N) 47 | pred_mode_transition_seq = mode_transition_seq(pred_valid_points, speaker)#(B, N) 48 | 49 | # baseline = np.random.randint(0, 2, size=pred_mode_transition_seq.shape) 50 | # pred_mode_transition_seq = baseline 51 | precision, recall, accuracy = mode_transition_consistency(pred_mode_transition_seq, gt_mode_transition_seq) 52 | precision_list.append(precision) 53 | recall_list.append(recall) 54 | accuracy_list.append(accuracy) 55 | print(len(precision_list), len(recall_list), len(accuracy_list)) 56 | precision_list = np.mean(precision_list) 57 | recall_list = np.mean(recall_list) 58 | accuracy_list = np.mean(accuracy_list) 59 | 60 | print('precision, recall, accu:', precision_list, recall_list, accuracy_list) 61 | -------------------------------------------------------------------------------- /evaluation/peak_velocity.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.getcwd()) 4 | 5 | from glob import glob 6 | 7 | from argparse import ArgumentParser 8 | import json 9 | 10 | from evaluation.util import * 11 | from evaluation.metrics import * 12 | from tqdm import tqdm 13 | 14 | parser = ArgumentParser() 15 | parser.add_argument('--speaker', required=True, type=str) 16 | parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str) 17 | args = parser.parse_args() 18 | 19 | speaker = args.speaker 20 | test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker))) 21 | 22 | gt_consistency_list=[] 23 | pred_consistency_list=[] 24 | 25 | for aud in tqdm(test_audios): 26 | base_name = os.path.splitext(aud)[0] 27 | gt_path = get_full_path(aud, speaker, 'val') 28 | _, gt_poses, _ = get_gts(gt_path) 29 | gt_poses = gt_poses[np.newaxis,...] 
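# add a batch axis so the GT matches the (B, seq_len, pose_dim) layout expected by cvt25 and hand_points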
30 | # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face 31 | for post_fix in args.post_fix: 32 | pred_path = base_name + '_'+post_fix+'.json' 33 | pred_poses = np.array(json.load(open(pred_path))) 34 | # print(pred_poses.shape)#(B, seq_len, 108) 35 | pred_poses = cvt25(pred_poses, gt_poses) 36 | # print(pred_poses.shape)#(B, seq, pose_dim) 37 | 38 | gt_valid_points = hand_points(gt_poses) 39 | pred_valid_points = hand_points(pred_poses) 40 | 41 | gt_velocity = peak_velocity(gt_valid_points, order=2) 42 | pred_velocity = peak_velocity(pred_valid_points, order=2) 43 | 44 | gt_consistency = velocity_consistency(gt_velocity, pred_velocity) 45 | pred_consistency = velocity_consistency(pred_velocity, gt_velocity) 46 | 47 | gt_consistency_list.append(gt_consistency) 48 | pred_consistency_list.append(pred_consistency) 49 | 50 | gt_consistency_list = np.concatenate(gt_consistency_list) 51 | pred_consistency_list = np.concatenate(pred_consistency_list) 52 | 53 | print(gt_consistency_list.max(), gt_consistency_list.min()) 54 | print(pred_consistency_list.max(), pred_consistency_list.min()) 55 | print(np.mean(gt_consistency_list), np.mean(pred_consistency_list)) 56 | print(np.std(gt_consistency_list), np.std(pred_consistency_list)) 57 | 58 | draw_cdf(gt_consistency_list, save_name='%s_gt.jpg'%(speaker), color='slateblue') 59 | draw_cdf(pred_consistency_list, save_name='%s_pred.jpg'%(speaker), color='lightskyblue') 60 | 61 | to_excel(gt_consistency_list, '%s_gt.xlsx'%(speaker)) 62 | to_excel(pred_consistency_list, '%s_pred.xlsx'%(speaker)) 63 | 64 | np.save('%s_gt.npy'%(speaker), gt_consistency_list) 65 | np.save('%s_pred.npy'%(speaker), pred_consistency_list) -------------------------------------------------------------------------------- /evaluation/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import numpy as np 4 | import json 5 | from matplotlib import pyplot as plt 6 | import pandas as pd 7 | def get_gts(clip): 8 | ''' 9 | clip: abs path to the clip dir 10 | ''' 11 | keypoints_files = sorted(glob(os.path.join(clip, 'keypoints_new/person_1')+'/*.json')) 12 | 13 | upper_body_points = list(np.arange(0, 25)) 14 | poses = [] 15 | confs = [] 16 | neck_to_nose_len = [] 17 | mean_position = [] 18 | for kp_file in keypoints_files: 19 | kp_load = json.load(open(kp_file, 'r'))['people'][0] 20 | posepts = kp_load['pose_keypoints_2d'] 21 | lhandpts = kp_load['hand_left_keypoints_2d'] 22 | rhandpts = kp_load['hand_right_keypoints_2d'] 23 | facepts = kp_load['face_keypoints_2d'] 24 | 25 | neck = np.array(posepts).reshape(-1,3)[1] 26 | nose = np.array(posepts).reshape(-1,3)[0] 27 | x_offset = abs(neck[0]-nose[0]) 28 | y_offset = abs(neck[1]-nose[1]) 29 | neck_to_nose_len.append(y_offset) 30 | mean_position.append([neck[0],neck[1]]) 31 | 32 | keypoints=np.array(posepts+lhandpts+rhandpts+facepts).reshape(-1,3)[:,:2] 33 | 34 | upper_body = keypoints[upper_body_points, :] 35 | hand_points = keypoints[25:, :] 36 | keypoints = np.vstack([upper_body, hand_points]) 37 | 38 | poses.append(keypoints) 39 | 40 | if len(neck_to_nose_len) > 0: 41 | scale_factor = np.mean(neck_to_nose_len) 42 | else: 43 | raise ValueError(clip) 44 | mean_position = np.mean(np.array(mean_position), axis=0) 45 | 46 | unlocalized_poses = np.array(poses).copy() 47 | localized_poses = [] 48 | for i in range(len(poses)): 49 | keypoints = poses[i] 50 | neck = keypoints[1].copy() 51 | 52 | keypoints[:, 0] = (keypoints[:, 0] - neck[0]) / scale_factor 
53 | keypoints[:, 1] = (keypoints[:, 1] - neck[1]) / scale_factor 54 | localized_poses.append(keypoints.reshape(-1)) 55 | 56 | localized_poses=np.array(localized_poses) 57 | return unlocalized_poses, localized_poses, (scale_factor, mean_position) 58 | 59 | def get_full_path(wav_name, speaker, split): 60 | ''' 61 | get clip path from aud file 62 | ''' 63 | wav_name = os.path.basename(wav_name) 64 | wav_name = os.path.splitext(wav_name)[0] 65 | clip_name, vid_name = wav_name[:10], wav_name[11:] 66 | 67 | full_path = os.path.join('pose_dataset/videos/', speaker, 'clips', vid_name, 'images/half', split, clip_name) 68 | 69 | assert os.path.isdir(full_path), full_path 70 | 71 | return full_path 72 | 73 | def smooth(res): 74 | ''' 75 | res: (B, seq_len, pose_dim) 76 | ''' 77 | window = [res[:, 7, :], res[:, 8, :], res[:, 9, :], res[:, 10, :], res[:, 11, :], res[:, 12, :]] 78 | w_size=7 79 | for i in range(10, res.shape[1]-3): 80 | window.append(res[:, i+3, :]) 81 | if len(window) > w_size: 82 | window = window[1:] 83 | 84 | if (i%25) in [22, 23, 24, 0, 1, 2, 3]: 85 | res[:, i, :] = np.mean(window, axis=1) 86 | 87 | return res 88 | 89 | def cvt25(pred_poses, gt_poses=None): 90 | ''' 91 | gt_poses: (1, seq_len, 270), 135 *2 92 | pred_poses: (B, seq_len, 108), 54 * 2 93 | ''' 94 | if gt_poses is None: 95 | gt_poses = np.zeros_like(pred_poses) 96 | else: 97 | gt_poses = gt_poses.repeat(pred_poses.shape[0], axis=0) 98 | 99 | length = min(pred_poses.shape[1], gt_poses.shape[1]) 100 | pred_poses = pred_poses[:, :length, :] 101 | gt_poses = gt_poses[:, :length, :] 102 | gt_poses = gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1, 2) 103 | pred_poses = pred_poses.reshape(pred_poses.shape[0], pred_poses.shape[1], -1, 2) 104 | 105 | gt_poses[:, :, [1, 2, 3, 4, 5, 6, 7], :] = pred_poses[:, :, 1:8, :] 106 | gt_poses[:, :, 25:25+21+21, :] = pred_poses[:, :, 12:, :] 107 | 108 | return gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1) 109 | 110 | def hand_points(seq): 111 | ''' 112 | seq: (B, seq_len, 135*2) 113 | hands only 114 | ''' 115 | hand_idx = [1, 2, 3, 4,5 ,6,7] + list(range(25, 25+21+21)) 116 | seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2) 117 | return seq[:, :, hand_idx, :].reshape(seq.shape[0], seq.shape[1], -1) 118 | 119 | def valid_points(seq): 120 | ''' 121 | hands with some head points 122 | ''' 123 | valid_idx = [0, 1, 2, 3, 4,5 ,6,7, 8, 9, 10, 11] + list(range(25, 25+21+21)) 124 | seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2) 125 | 126 | seq = seq[:, :, valid_idx, :].reshape(seq.shape[0], seq.shape[1], -1) 127 | assert seq.shape[-1] == 108, seq.shape 128 | return seq 129 | 130 | def draw_cdf(seq, save_name='cdf.jpg', color='slatebule'): 131 | plt.figure() 132 | plt.hist(seq, bins=100, range=(0, 100), color=color) 133 | plt.savefig(save_name) 134 | 135 | def to_excel(seq, save_name='res.xlsx'): 136 | ''' 137 | seq: (T) 138 | ''' 139 | df = pd.DataFrame(seq) 140 | writer = pd.ExcelWriter(save_name) 141 | df.to_excel(writer, 'sheet1') 142 | writer.save() 143 | writer.close() 144 | 145 | 146 | if __name__ == '__main__': 147 | random_data = np.random.randint(0, 10, 100) 148 | draw_cdf(random_data) -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import * -------------------------------------------------------------------------------- /losses/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/losses/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /losses/__pycache__/losses.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/losses/__pycache__/losses.cpython-37.pyc -------------------------------------------------------------------------------- /losses/losses.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | 11 | class KeypointLoss(nn.Module): 12 | def __init__(self): 13 | super(KeypointLoss, self).__init__() 14 | 15 | def forward(self, pred_seq, gt_seq, gt_conf=None): 16 | #pred_seq: (B, C, T) 17 | if gt_conf is not None: 18 | gt_conf = gt_conf >= 0.01 19 | return F.mse_loss(pred_seq[gt_conf], gt_seq[gt_conf], reduction='mean') 20 | else: 21 | return F.mse_loss(pred_seq, gt_seq) 22 | 23 | 24 | class KLLoss(nn.Module): 25 | def __init__(self, kl_tolerance): 26 | super(KLLoss, self).__init__() 27 | self.kl_tolerance = kl_tolerance 28 | 29 | def forward(self, mu, var, mul=1): 30 | kl_tolerance = self.kl_tolerance * mul * var.shape[1] / 64 31 | kld_loss = -0.5 * torch.sum(1 + var - mu**2 - var.exp(), dim=1) 32 | # kld_loss = -0.5 * torch.sum(1 + (var-1) - (mu) ** 2 - (var-1).exp(), dim=1) 33 | if self.kl_tolerance is not None: 34 | # above_line = kld_loss[kld_loss > self.kl_tolerance] 35 | # if len(above_line) > 0: 36 | # kld_loss = torch.mean(kld_loss) 37 | # else: 38 | # kld_loss = 0 39 | kld_loss = torch.where(kld_loss > kl_tolerance, kld_loss, torch.tensor(kl_tolerance, device='cuda')) 40 | # else: 41 | kld_loss = torch.mean(kld_loss) 42 | return kld_loss 43 | 44 | 45 | class L2KLLoss(nn.Module): 46 | def __init__(self, kl_tolerance): 47 | super(L2KLLoss, self).__init__() 48 | self.kl_tolerance = kl_tolerance 49 | 50 | def forward(self, x): 51 | # TODO: check 52 | kld_loss = torch.sum(x ** 2, dim=1) 53 | if self.kl_tolerance is not None: 54 | above_line = kld_loss[kld_loss > self.kl_tolerance] 55 | if len(above_line) > 0: 56 | kld_loss = torch.mean(kld_loss) 57 | else: 58 | kld_loss = 0 59 | else: 60 | kld_loss = torch.mean(kld_loss) 61 | return kld_loss 62 | 63 | class L2RegLoss(nn.Module): 64 | def __init__(self): 65 | super(L2RegLoss, self).__init__() 66 | 67 | def forward(self, x): 68 | #TODO: check 69 | return torch.sum(x**2) 70 | 71 | 72 | class L2Loss(nn.Module): 73 | def __init__(self): 74 | super(L2Loss, self).__init__() 75 | 76 | def forward(self, x): 77 | # TODO: check 78 | return torch.sum(x ** 2) 79 | 80 | 81 | class AudioLoss(nn.Module): 82 | def __init__(self): 83 | super(AudioLoss, self).__init__() 84 | 85 | def forward(self, dynamics, gt_poses): 86 | #pay attention, normalized 87 | mean = torch.mean(gt_poses, dim=-1).unsqueeze(-1) 88 | gt = gt_poses - mean 89 | return F.mse_loss(dynamics, gt) 90 | 91 | L1Loss = nn.L1Loss -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- 1 | from .smplx_face import TrainWrapper as s2g_face 2 | from .smplx_body_vq import 
TrainWrapper as s2g_body_vq 3 | from .smplx_body_pixel import TrainWrapper as s2g_body_pixel 4 | from .body_ae import TrainWrapper as s2g_body_ae 5 | from .LS3DCG import TrainWrapper as LS3DCG 6 | from .base import TrainWrapperBaseClass 7 | 8 | from .utils import normalize, denormalize -------------------------------------------------------------------------------- /nets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/base.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/base.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/init_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/init_model.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/layers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/layers.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/smplx_body_pixel.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/smplx_body_pixel.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/smplx_body_vq.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/smplx_body_vq.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/smplx_face.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/smplx_face.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /nets/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | class TrainWrapperBaseClass(): 6 | def __init__(self, args, config) -> None: 7 | self.init_optimizer() 8 | 9 | def init_optimizer(self) -> None: 10 | print('using Adam') 11 | self.generator_optimizer = optim.Adam( 12 | self.generator.parameters(), 13 | lr = 
self.config.Train.learning_rate.generator_learning_rate, 14 | betas=[0.9, 0.999] 15 | ) 16 | if self.discriminator is not None: 17 | self.discriminator_optimizer = optim.Adam( 18 | self.discriminator.parameters(), 19 | lr = self.config.Train.learning_rate.discriminator_learning_rate, 20 | betas=[0.9, 0.999] 21 | ) 22 | 23 | def __call__(self, bat): 24 | raise NotImplementedError 25 | 26 | def get_loss(self, **kwargs): 27 | raise NotImplementedError 28 | 29 | def state_dict(self): 30 | model_state = { 31 | 'generator': self.generator.state_dict(), 32 | 'generator_optim': self.generator_optimizer.state_dict(), 33 | 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None, 34 | 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None 35 | } 36 | return model_state 37 | 38 | def parameters(self): 39 | return self.generator.parameters() 40 | 41 | def load_state_dict(self, state_dict): 42 | if 'generator' in state_dict: 43 | self.generator.load_state_dict(state_dict['generator']) 44 | else: 45 | self.generator.load_state_dict(state_dict) 46 | 47 | if 'generator_optim' in state_dict and self.generator_optimizer is not None: 48 | self.generator_optimizer.load_state_dict(state_dict['generator_optim']) 49 | 50 | if self.discriminator is not None: 51 | self.discriminator.load_state_dict(state_dict['discriminator']) 52 | 53 | if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None: 54 | self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim']) 55 | 56 | def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, **kwargs): 57 | raise NotImplementedError 58 | 59 | def init_params(self): 60 | if self.config.Data.pose.convert_to_6d: 61 | scale = 2 62 | else: 63 | scale = 1 64 | 65 | global_orient = round(0 * scale) 66 | leye_pose = reye_pose = round(0 * scale) 67 | jaw_pose = round(0 * scale) 68 | body_pose = round((63 - 24) * scale) 69 | left_hand_pose = right_hand_pose = round(45 * scale) 70 | if self.expression: 71 | expression = 100 72 | else: 73 | expression = 0 74 | 75 | b_j = 0 76 | jaw_dim = jaw_pose 77 | b_e = b_j + jaw_dim 78 | eye_dim = leye_pose + reye_pose 79 | b_b = b_e + eye_dim 80 | body_dim = global_orient + body_pose 81 | b_h = b_b + body_dim 82 | hand_dim = left_hand_pose + right_hand_pose 83 | b_f = b_h + hand_dim 84 | face_dim = expression 85 | 86 | self.dim_list = [b_j, b_e, b_b, b_h, b_f] 87 | self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim 88 | self.pose = int(self.full_dim / round(3 * scale)) 89 | self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim] -------------------------------------------------------------------------------- /nets/body_ae.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | from nets.base import TrainWrapperBaseClass 7 | from nets.spg.s2glayers import Discriminator as D_S2G 8 | from nets.spg.vqvae_1d import AE as s2g_body 9 | import torch 10 | import torch.optim as optim 11 | import torch.nn.functional as F 12 | 13 | from data_utils.lower_body import c_index, c_index_3d, c_index_6d 14 | 15 | 16 | def separate_aa(aa): 17 | aa = aa[:, :, :].reshape(aa.shape[0], aa.shape[1], -1, 5) 18 | axis = F.normalize(aa[:, :, :, :3], dim=-1) 19 | angle = F.normalize(aa[:, :, :, 3:5], dim=-1) 20 | return axis, angle 21 | 22 | 23 | class TrainWrapper(TrainWrapperBaseClass): 24 | ''' 25 | a wrapper receving 
a batch from data_utils and calculate loss 26 | ''' 27 | 28 | def __init__(self, args, config): 29 | self.args = args 30 | self.config = config 31 | self.device = torch.device(self.args.gpu) 32 | self.global_step = 0 33 | 34 | self.gan = False 35 | self.convert_to_6d = self.config.Data.pose.convert_to_6d 36 | self.preleng = self.config.Data.pose.pre_pose_length 37 | self.expression = self.config.Data.pose.expression 38 | self.epoch = 0 39 | self.init_params() 40 | self.num_classes = 4 41 | self.g = s2g_body(self.each_dim[1] + self.each_dim[2], embedding_dim=64, num_embeddings=0, 42 | num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device) 43 | if self.gan: 44 | self.discriminator = D_S2G( 45 | pose_dim=110 + 64, pose=self.pose 46 | ).to(self.device) 47 | else: 48 | self.discriminator = None 49 | 50 | if self.convert_to_6d: 51 | self.c_index = c_index_6d 52 | else: 53 | self.c_index = c_index_3d 54 | 55 | super().__init__(args, config) 56 | 57 | def init_optimizer(self): 58 | 59 | self.g_optimizer = optim.Adam( 60 | self.g.parameters(), 61 | lr=self.config.Train.learning_rate.generator_learning_rate, 62 | betas=[0.9, 0.999] 63 | ) 64 | 65 | def state_dict(self): 66 | model_state = { 67 | 'g': self.g.state_dict(), 68 | 'g_optim': self.g_optimizer.state_dict(), 69 | 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None, 70 | 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None 71 | } 72 | return model_state 73 | 74 | 75 | def __call__(self, bat): 76 | # assert (not self.args.infer), "infer mode" 77 | self.global_step += 1 78 | 79 | total_loss = None 80 | loss_dict = {} 81 | 82 | aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32) 83 | 84 | # id = bat['speaker'].to(self.device) - 20 85 | # id = F.one_hot(id, self.num_classes) 86 | 87 | poses = poses[:, self.c_index, :] 88 | gt_poses = poses[:, :, self.preleng:].permute(0, 2, 1) 89 | 90 | loss = 0 91 | loss_dict, loss = self.vq_train(gt_poses[:, :], 'g', self.g, loss_dict, loss) 92 | 93 | return total_loss, loss_dict 94 | 95 | def vq_train(self, gt, name, model, dict, total_loss, pre=None): 96 | x_recon = model(gt_poses=gt, pre_state=pre) 97 | loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, pre=pre) 98 | # total_loss = total_loss + loss 99 | 100 | if name == 'g': 101 | optimizer_name = 'g_optimizer' 102 | 103 | optimizer = getattr(self, optimizer_name) 104 | optimizer.zero_grad() 105 | loss.backward() 106 | optimizer.step() 107 | 108 | for key in list(loss_dict.keys()): 109 | dict[name + key] = loss_dict.get(key, 0).item() 110 | return dict, total_loss 111 | 112 | def get_loss(self, 113 | pred_poses, 114 | gt_poses, 115 | pre=None 116 | ): 117 | loss_dict = {} 118 | 119 | 120 | rec_loss = torch.mean(torch.abs(pred_poses - gt_poses)) 121 | v_pr = pred_poses[:, 1:] - pred_poses[:, :-1] 122 | v_gt = gt_poses[:, 1:] - gt_poses[:, :-1] 123 | velocity_loss = torch.mean(torch.abs(v_pr - v_gt)) 124 | 125 | if pre is None: 126 | f0_vel = 0 127 | else: 128 | v0_pr = pred_poses[:, 0] - pre[:, -1] 129 | v0_gt = gt_poses[:, 0] - pre[:, -1] 130 | f0_vel = torch.mean(torch.abs(v0_pr - v0_gt)) 131 | 132 | gen_loss = rec_loss + velocity_loss + f0_vel 133 | 134 | loss_dict['rec_loss'] = rec_loss 135 | loss_dict['velocity_loss'] = velocity_loss 136 | # loss_dict['e_q_loss'] = e_q_loss 137 | if pre is not None: 138 | loss_dict['f0_vel'] = f0_vel 139 | 140 | return 
gen_loss, loss_dict 141 | 142 | def load_state_dict(self, state_dict): 143 | self.g.load_state_dict(state_dict['g']) 144 | 145 | def extract(self, x): 146 | self.g.eval() 147 | if x.shape[2] > self.full_dim: 148 | if x.shape[2] == 239: 149 | x = x[:, :, 102:] 150 | x = x[:, :, self.c_index] 151 | feat = self.g.encode(x) 152 | return feat.transpose(1, 2), x 153 | -------------------------------------------------------------------------------- /nets/init_model.py: -------------------------------------------------------------------------------- 1 | from nets import * 2 | 3 | 4 | def init_model(model_name, args, config): 5 | 6 | if model_name == 's2g_face': 7 | generator = s2g_face( 8 | args, 9 | config, 10 | ) 11 | elif model_name == 's2g_body_vq': 12 | generator = s2g_body_vq( 13 | args, 14 | config, 15 | ) 16 | elif model_name == 's2g_body_pixel': 17 | generator = s2g_body_pixel( 18 | args, 19 | config, 20 | ) 21 | elif model_name == 's2g_body_ae': 22 | generator = s2g_body_ae( 23 | args, 24 | config, 25 | ) 26 | elif model_name == 's2g_LS3DCG': 27 | generator = LS3DCG( 28 | args, 29 | config, 30 | ) 31 | else: 32 | raise ValueError 33 | return generator 34 | 35 | 36 | -------------------------------------------------------------------------------- /nets/smplx_face.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | from nets.layers import * 7 | from nets.base import TrainWrapperBaseClass 8 | # from nets.spg.faceformer import Faceformer 9 | from nets.spg.s2g_face import Generator as s2g_face 10 | from losses import KeypointLoss 11 | from nets.utils import denormalize 12 | from data_utils import get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta 13 | import numpy as np 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | from sklearn.preprocessing import normalize 17 | import smplx 18 | 19 | 20 | class TrainWrapper(TrainWrapperBaseClass): 21 | ''' 22 | a wrapper receving a batch from data_utils and calculate loss 23 | ''' 24 | 25 | def __init__(self, args, config): 26 | self.args = args 27 | self.config = config 28 | self.device = torch.device(self.args.gpu) 29 | self.global_step = 0 30 | 31 | self.convert_to_6d = self.config.Data.pose.convert_to_6d 32 | self.expression = self.config.Data.pose.expression 33 | self.epoch = 0 34 | self.init_params() 35 | self.num_classes = 4 36 | 37 | self.generator = s2g_face( 38 | n_poses=self.config.Data.pose.generate_length, 39 | each_dim=self.each_dim, 40 | dim_list=self.dim_list, 41 | training=not self.args.infer, 42 | device=self.device, 43 | identity=False if self.convert_to_6d else True, 44 | num_classes=self.num_classes, 45 | ).to(self.device) 46 | 47 | # self.generator = Faceformer().to(self.device) 48 | 49 | self.discriminator = None 50 | self.am = None 51 | 52 | self.MSELoss = KeypointLoss().to(self.device) 53 | super().__init__(args, config) 54 | 55 | def init_optimizer(self): 56 | self.generator_optimizer = optim.SGD( 57 | filter(lambda p: p.requires_grad,self.generator.parameters()), 58 | lr=0.001, 59 | momentum=0.9, 60 | nesterov=False, 61 | ) 62 | 63 | def init_params(self): 64 | if self.convert_to_6d: 65 | scale = 2 66 | else: 67 | scale = 1 68 | 69 | global_orient = round(3 * scale) 70 | leye_pose = reye_pose = round(3 * scale) 71 | jaw_pose = round(3 * scale) 72 | body_pose = round(63 * scale) 73 | left_hand_pose = right_hand_pose = round(45 * scale) 74 | if self.expression: 75 | expression = 100 76 | else: 77 | 
expression = 0 78 | 79 | b_j = 0 80 | jaw_dim = jaw_pose 81 | b_e = b_j + jaw_dim 82 | eye_dim = leye_pose + reye_pose 83 | b_b = b_e + eye_dim 84 | body_dim = global_orient + body_pose 85 | b_h = b_b + body_dim 86 | hand_dim = left_hand_pose + right_hand_pose 87 | b_f = b_h + hand_dim 88 | face_dim = expression 89 | 90 | self.dim_list = [b_j, b_e, b_b, b_h, b_f] 91 | self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim + face_dim 92 | self.pose = int(self.full_dim / round(3 * scale)) 93 | self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim] 94 | 95 | def __call__(self, bat): 96 | # assert (not self.args.infer), "infer mode" 97 | self.global_step += 1 98 | 99 | total_loss = None 100 | loss_dict = {} 101 | 102 | aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32) 103 | id = bat['speaker'].to(self.device) - 20 104 | id = F.one_hot(id, self.num_classes) 105 | 106 | aud = aud.permute(0, 2, 1) 107 | gt_poses = poses.permute(0, 2, 1) 108 | 109 | if self.expression: 110 | expression = bat['expression'].to(self.device).to(torch.float32) 111 | gt_poses = torch.cat([gt_poses, expression.permute(0, 2, 1)], dim=2) 112 | 113 | pred_poses, _ = self.generator( 114 | aud, 115 | gt_poses, 116 | id, 117 | ) 118 | 119 | G_loss, G_loss_dict = self.get_loss( 120 | pred_poses=pred_poses, 121 | gt_poses=gt_poses, 122 | pre_poses=None, 123 | mode='training_G', 124 | gt_conf=None, 125 | aud=aud, 126 | ) 127 | 128 | self.generator_optimizer.zero_grad() 129 | G_loss.backward() 130 | grad = torch.nn.utils.clip_grad_norm(self.generator.parameters(), self.config.Train.max_gradient_norm) 131 | loss_dict['grad'] = grad.item() 132 | self.generator_optimizer.step() 133 | 134 | for key in list(G_loss_dict.keys()): 135 | loss_dict[key] = G_loss_dict.get(key, 0).item() 136 | 137 | return total_loss, loss_dict 138 | 139 | def get_loss(self, 140 | pred_poses, 141 | gt_poses, 142 | pre_poses, 143 | aud, 144 | mode='training_G', 145 | gt_conf=None, 146 | exp=1, 147 | gt_nzero=None, 148 | pre_nzero=None, 149 | ): 150 | loss_dict = {} 151 | 152 | 153 | [b_j, b_e, b_b, b_h, b_f] = self.dim_list 154 | 155 | MSELoss = torch.mean(torch.abs(pred_poses[:, :, :6] - gt_poses[:, :, :6])) 156 | if self.expression: 157 | expl = torch.mean((pred_poses[:, :, -100:] - gt_poses[:, :, -100:])**2) 158 | else: 159 | expl = 0 160 | 161 | gen_loss = expl + MSELoss 162 | 163 | loss_dict['MSELoss'] = MSELoss 164 | if self.expression: 165 | loss_dict['exp_loss'] = expl 166 | 167 | return gen_loss, loss_dict 168 | 169 | def infer_on_audio(self, aud_fn, id=None, initial_pose=None, norm_stats=None, w_pre=False, frame=None, am=None, am_sr=16000, **kwargs): 170 | ''' 171 | initial_pose: (B, C, T), normalized 172 | (aud_fn, txgfile) -> generated motion (B, T, C) 173 | ''' 174 | output = [] 175 | 176 | # assert self.args.infer, "train mode" 177 | self.generator.eval() 178 | 179 | if self.config.Data.pose.normalization: 180 | assert norm_stats is not None 181 | data_mean = norm_stats[0] 182 | data_std = norm_stats[1] 183 | 184 | # assert initial_pose.shape[-1] == pre_length 185 | if initial_pose is not None: 186 | gt = initial_pose[:,:,:].permute(0, 2, 1).to(self.generator.device).to(torch.float32) 187 | pre_poses = initial_pose[:,:,:15].permute(0, 2, 1).to(self.generator.device).to(torch.float32) 188 | poses = initial_pose.permute(0, 2, 1).to(self.generator.device).to(torch.float32) 189 | B = pre_poses.shape[0] 190 | else: 191 | gt = None 192 | pre_poses=None 193 | B = 1 194 | 195 | if 
type(aud_fn) == torch.Tensor: 196 | aud_feat = torch.tensor(aud_fn, dtype=torch.float32).to(self.generator.device) 197 | num_poses_to_generate = aud_feat.shape[-1] 198 | else: 199 | aud_feat = get_mfcc_ta(aud_fn, am=am, am_sr=am_sr, fps=30, encoder_choice='faceformer') 200 | aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0) 201 | aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.generator.device).transpose(1, 2) 202 | if frame is None: 203 | frame = aud_feat.shape[2]*30//16000 204 | # 205 | if id is None: 206 | id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device) 207 | else: 208 | id = F.one_hot(id, self.num_classes).to(self.generator.device) 209 | 210 | with torch.no_grad(): 211 | pred_poses = self.generator(aud_feat, pre_poses, id, time_steps=frame)[0] 212 | pred_poses = pred_poses.cpu().numpy() 213 | output = pred_poses 214 | 215 | if self.config.Data.pose.normalization: 216 | output = denormalize(output, data_mean, data_std) 217 | 218 | return output 219 | 220 | 221 | def generate(self, wv2_feat, frame): 222 | ''' 223 | initial_pose: (B, C, T), normalized 224 | (aud_fn, txgfile) -> generated motion (B, T, C) 225 | ''' 226 | output = [] 227 | 228 | # assert self.args.infer, "train mode" 229 | self.generator.eval() 230 | 231 | B = 1 232 | 233 | id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device) 234 | id = id.repeat(wv2_feat.shape[0], 1) 235 | 236 | with torch.no_grad(): 237 | pred_poses = self.generator(wv2_feat, None, id, time_steps=frame)[0] 238 | return pred_poses 239 | -------------------------------------------------------------------------------- /nets/spg/__pycache__/gated_pixelcnn_v2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/spg/__pycache__/gated_pixelcnn_v2.cpython-37.pyc -------------------------------------------------------------------------------- /nets/spg/__pycache__/s2g_face.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/spg/__pycache__/s2g_face.cpython-37.pyc -------------------------------------------------------------------------------- /nets/spg/__pycache__/s2glayers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/spg/__pycache__/s2glayers.cpython-37.pyc -------------------------------------------------------------------------------- /nets/spg/__pycache__/vqvae_1d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/spg/__pycache__/vqvae_1d.cpython-37.pyc -------------------------------------------------------------------------------- /nets/spg/__pycache__/vqvae_modules.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/spg/__pycache__/vqvae_modules.cpython-37.pyc -------------------------------------------------------------------------------- /nets/spg/__pycache__/wav2vec.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/nets/spg/__pycache__/wav2vec.cpython-37.pyc -------------------------------------------------------------------------------- /nets/spg/gated_pixelcnn_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def weights_init(m): 7 | classname = m.__class__.__name__ 8 | if classname.find('Conv') != -1: 9 | try: 10 | nn.init.xavier_uniform_(m.weight.data) 11 | m.bias.data.fill_(0) 12 | except AttributeError: 13 | print("Skipping initialization of ", classname) 14 | 15 | 16 | class GatedActivation(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def forward(self, x): 21 | x, y = x.chunk(2, dim=1) 22 | return F.tanh(x) * F.sigmoid(y) 23 | 24 | 25 | class GatedMaskedConv2d(nn.Module): 26 | def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10, bh_model=False): 27 | super().__init__() 28 | assert kernel % 2 == 1, print("Kernel size must be odd") 29 | self.mask_type = mask_type 30 | self.residual = residual 31 | self.bh_model = bh_model 32 | 33 | self.class_cond_embedding = nn.Embedding( 34 | n_classes, 2 * dim 35 | ) 36 | 37 | kernel_shp = (kernel // 2 + 1, 3 if self.bh_model else 1) # (ceil(n/2), n) 38 | padding_shp = (kernel // 2, 1 if self.bh_model else 0) 39 | self.vert_stack = nn.Conv2d( 40 | dim, dim * 2, 41 | kernel_shp, 1, padding_shp 42 | ) 43 | 44 | self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1) 45 | 46 | kernel_shp = (1, 2) 47 | padding_shp = (0, 1) 48 | self.horiz_stack = nn.Conv2d( 49 | dim, dim * 2, 50 | kernel_shp, 1, padding_shp 51 | ) 52 | 53 | self.horiz_resid = nn.Conv2d(dim, dim, 1) 54 | 55 | self.gate = GatedActivation() 56 | 57 | def make_causal(self): 58 | self.vert_stack.weight.data[:, :, -1].zero_() # Mask final row 59 | self.horiz_stack.weight.data[:, :, :, -1].zero_() # Mask final column 60 | 61 | def forward(self, x_v, x_h, h): 62 | if self.mask_type == 'A': 63 | self.make_causal() 64 | 65 | h = self.class_cond_embedding(h) 66 | h_vert = self.vert_stack(x_v) 67 | h_vert = h_vert[:, :, :x_v.size(-2), :] 68 | out_v = self.gate(h_vert + h[:, :, None, None]) 69 | 70 | if self.bh_model: 71 | h_horiz = self.horiz_stack(x_h) 72 | h_horiz = h_horiz[:, :, :, :x_h.size(-1)] 73 | v2h = self.vert_to_horiz(h_vert) 74 | 75 | out = self.gate(v2h + h_horiz + h[:, :, None, None]) 76 | if self.residual: 77 | out_h = self.horiz_resid(out) + x_h 78 | else: 79 | out_h = self.horiz_resid(out) 80 | else: 81 | if self.residual: 82 | out_v = self.horiz_resid(out_v) + x_v 83 | else: 84 | out_v = self.horiz_resid(out_v) 85 | out_h = out_v 86 | 87 | return out_v, out_h 88 | 89 | 90 | class GatedPixelCNN(nn.Module): 91 | def __init__(self, input_dim=256, dim=64, n_layers=15, n_classes=10, audio=False, bh_model=False): 92 | super().__init__() 93 | self.dim = dim 94 | self.audio = audio 95 | self.bh_model = bh_model 96 | 97 | if self.audio: 98 | self.embedding_aud = nn.Conv2d(256, dim, 1, 1, padding=0) 99 | self.fusion_v = nn.Conv2d(dim * 2, dim, 1, 1, padding=0) 100 | self.fusion_h = nn.Conv2d(dim * 2, dim, 1, 1, padding=0) 101 | 102 | # Create embedding layer to embed input 103 | self.embedding = nn.Embedding(input_dim, dim) 104 | 105 | # Building the PixelCNN layer by layer 106 | self.layers = nn.ModuleList() 107 | 108 | # Initial block with Mask-A convolution 109 | # Rest with Mask-B convolutions 110 | for i in range(n_layers): 111 | 
mask_type = 'A' if i == 0 else 'B' 112 | kernel = 7 if i == 0 else 3 113 | residual = False if i == 0 else True 114 | 115 | self.layers.append( 116 | GatedMaskedConv2d(mask_type, dim, kernel, residual, n_classes, bh_model) 117 | ) 118 | 119 | # Add the output layer 120 | self.output_conv = nn.Sequential( 121 | nn.Conv2d(dim, 512, 1), 122 | nn.ReLU(True), 123 | nn.Conv2d(512, input_dim, 1) 124 | ) 125 | 126 | self.apply(weights_init) 127 | 128 | self.dp = nn.Dropout(0.1) 129 | 130 | def forward(self, x, label, aud=None): 131 | shp = x.size() + (-1,) 132 | x = self.embedding(x.view(-1)).view(shp) # (B, H, W, C) 133 | x = x.permute(0, 3, 1, 2) # (B, C, W, W) 134 | 135 | x_v, x_h = (x, x) 136 | for i, layer in enumerate(self.layers): 137 | if i == 1 and self.audio is True: 138 | aud = self.embedding_aud(aud) 139 | a = torch.ones(aud.shape[-2]).to(aud.device) 140 | a = self.dp(a) 141 | aud = (aud.transpose(-1, -2) * a).transpose(-1, -2) 142 | x_v = self.fusion_v(torch.cat([x_v, aud], dim=1)) 143 | if self.bh_model: 144 | x_h = self.fusion_h(torch.cat([x_h, aud], dim=1)) 145 | x_v, x_h = layer(x_v, x_h, label) 146 | 147 | if self.bh_model: 148 | return self.output_conv(x_h) 149 | else: 150 | return self.output_conv(x_v) 151 | 152 | def generate(self, label, shape=(8, 8), batch_size=64, aud_feat=None, pre_latents=None, pre_audio=None): 153 | param = next(self.parameters()) 154 | x = torch.zeros( 155 | (batch_size, *shape), 156 | dtype=torch.int64, device=param.device 157 | ) 158 | if pre_latents is not None: 159 | x = torch.cat([pre_latents, x], dim=1) 160 | aud_feat = torch.cat([pre_audio, aud_feat], dim=2) 161 | h0 = pre_latents.shape[1] 162 | h = h0 + shape[0] 163 | else: 164 | h0 = 0 165 | h = shape[0] 166 | 167 | for i in range(h0, h): 168 | for j in range(shape[1]): 169 | if self.audio: 170 | logits = self.forward(x, label, aud_feat) 171 | else: 172 | logits = self.forward(x, label) 173 | probs = F.softmax(logits[:, :, i, j], -1) 174 | x.data[:, i, j].copy_( 175 | probs.multinomial(1).squeeze().data 176 | ) 177 | return x[:, h0:h] 178 | -------------------------------------------------------------------------------- /nets/spg/s2g_face.py: -------------------------------------------------------------------------------- 1 | ''' 2 | not exactly the same as the official repo but the results are good 3 | ''' 4 | import sys 5 | import os 6 | 7 | from transformers import Wav2Vec2Processor 8 | 9 | from .wav2vec import Wav2Vec2Model 10 | from torchaudio.sox_effects import apply_effects_tensor 11 | 12 | sys.path.append(os.getcwd()) 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torchaudio as ta 19 | import math 20 | from nets.layers import SeqEncoder1D, SeqTranslator1D, ConvNormRelu 21 | 22 | 23 | """ from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """ 24 | 25 | 26 | def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000): 27 | """ 28 | :param audio: 1 x T tensor containing a 16kHz audio signal 29 | :param frame_rate: frame rate for video (we need one audio chunk per video frame) 30 | :param chunk_size: number of audio samples per chunk 31 | :return: num_chunks x chunk_size tensor containing sliced audio 32 | """ 33 | samples_per_frame = 16000 // frame_rate 34 | padding = (chunk_size - samples_per_frame) // 2 35 | audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0) 36 | anchor_points = list(range(chunk_size//2, 
audio.shape[-1]-chunk_size//2, samples_per_frame)) 37 | audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0) 38 | return audio 39 | 40 | 41 | class MeshtalkEncoder(nn.Module): 42 | def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'): 43 | """ 44 | :param latent_dim: size of the latent audio embedding 45 | :param model_name: name of the model, used to load and save the model 46 | """ 47 | super().__init__() 48 | 49 | self.melspec = ta.transforms.MelSpectrogram( 50 | sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80 51 | ) 52 | 53 | conv_len = 5 54 | self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len) 55 | self.weights_init(self.convert_dimensions) 56 | self.receptive_field = conv_len 57 | 58 | convs = [] 59 | for i in range(6): 60 | dilation = 2 * (i % 3 + 1) 61 | self.receptive_field += (conv_len - 1) * dilation 62 | convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)] 63 | self.weights_init(convs[-1]) 64 | self.convs = torch.nn.ModuleList(convs) 65 | self.code = torch.nn.Linear(128, latent_dim) 66 | 67 | self.apply(lambda x: self.weights_init(x)) 68 | 69 | def weights_init(self, m): 70 | if isinstance(m, torch.nn.Conv1d): 71 | torch.nn.init.xavier_uniform_(m.weight) 72 | try: 73 | torch.nn.init.constant_(m.bias, .01) 74 | except: 75 | pass 76 | 77 | def forward(self, audio: torch.Tensor): 78 | """ 79 | :param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame 80 | :return: code: B x T x latent_dim Tensor containing a latent audio code/embedding 81 | """ 82 | B, T = audio.shape[0], audio.shape[1] 83 | x = self.melspec(audio).squeeze(1) 84 | x = torch.log(x.clamp(min=1e-10, max=None)) 85 | if T == 1: 86 | x = x.unsqueeze(1) 87 | 88 | # Convert to the right dimensionality 89 | x = x.view(-1, x.shape[2], x.shape[3]) 90 | x = F.leaky_relu(self.convert_dimensions(x), .2) 91 | 92 | # Process stacks 93 | for conv in self.convs: 94 | x_ = F.leaky_relu(conv(x), .2) 95 | if self.training: 96 | x_ = F.dropout(x_, .2) 97 | l = (x.shape[2] - x_.shape[2]) // 2 98 | x = (x[:, :, l:-l] + x_) / 2 99 | 100 | x = torch.mean(x, dim=-1) 101 | x = x.view(B, T, x.shape[-1]) 102 | x = self.code(x) 103 | 104 | return {"code": x} 105 | 106 | 107 | class AudioEncoder(nn.Module): 108 | def __init__(self, in_dim, out_dim, identity=False, num_classes=0): 109 | super().__init__() 110 | self.identity = identity 111 | if self.identity: 112 | in_dim = in_dim + 64 113 | self.id_mlp = nn.Conv1d(num_classes, 64, 1, 1) 114 | self.first_net = SeqTranslator1D(in_dim, out_dim, 115 | min_layers_num=3, 116 | residual=True, 117 | norm='ln' 118 | ) 119 | self.grus = nn.GRU(out_dim, out_dim, 1, batch_first=True) 120 | self.dropout = nn.Dropout(0.1) 121 | # self.att = nn.MultiheadAttention(out_dim, 4, dropout=0.1, batch_first=True) 122 | 123 | def forward(self, spectrogram, pre_state=None, id=None, time_steps=None): 124 | 125 | spectrogram = spectrogram 126 | spectrogram = self.dropout(spectrogram) 127 | if self.identity: 128 | id = id.reshape(id.shape[0], -1, 1).repeat(1, 1, spectrogram.shape[2]).to(torch.float32) 129 | id = self.id_mlp(id) 130 | spectrogram = torch.cat([spectrogram, id], dim=1) 131 | x1 = self.first_net(spectrogram)# .permute(0, 2, 1) 132 | if time_steps is not None: 133 | x1 = F.interpolate(x1, size=time_steps, align_corners=False, mode='linear') 134 | # x1, _ = self.att(x1, x1, x1) 135 | # x1, hidden_state = self.grus(x1) 136 | # x1 = 
x1.permute(0, 2, 1) 137 | hidden_state=None 138 | 139 | return x1, hidden_state 140 | 141 | 142 | class Generator(nn.Module): 143 | def __init__(self, 144 | n_poses, 145 | each_dim: list, 146 | dim_list: list, 147 | training=False, 148 | device=None, 149 | identity=True, 150 | num_classes=0, 151 | ): 152 | super().__init__() 153 | 154 | self.training = training 155 | self.device = device 156 | self.gen_length = n_poses 157 | self.identity = identity 158 | 159 | norm = 'ln' 160 | in_dim = 256 161 | out_dim = 256 162 | 163 | self.encoder_choice = 'faceformer' 164 | 165 | if self.encoder_choice == 'meshtalk': 166 | self.audio_encoder = MeshtalkEncoder(latent_dim=in_dim) 167 | elif self.encoder_choice == 'faceformer': 168 | # wav2vec 2.0 weights initialization 169 | self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") # "vitouphy/wav2vec2-xls-r-300m-phoneme""facebook/wav2vec2-base-960h" 170 | self.audio_encoder.feature_extractor._freeze_parameters() 171 | self.audio_feature_map = nn.Linear(768, in_dim) 172 | else: 173 | self.audio_encoder = AudioEncoder(in_dim=64, out_dim=out_dim) 174 | 175 | self.audio_middle = AudioEncoder(in_dim, out_dim, identity, num_classes) 176 | 177 | self.dim_list = dim_list 178 | 179 | self.decoder = nn.ModuleList() 180 | self.final_out = nn.ModuleList() 181 | 182 | self.decoder.append(nn.Sequential( 183 | ConvNormRelu(out_dim, 64, norm=norm), 184 | ConvNormRelu(64, 64, norm=norm), 185 | ConvNormRelu(64, 64, norm=norm), 186 | )) 187 | self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1)) 188 | 189 | self.decoder.append(nn.Sequential( 190 | ConvNormRelu(out_dim, out_dim, norm=norm), 191 | ConvNormRelu(out_dim, out_dim, norm=norm), 192 | ConvNormRelu(out_dim, out_dim, norm=norm), 193 | )) 194 | self.final_out.append(nn.Conv1d(out_dim, each_dim[3], 1, 1)) 195 | 196 | def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None): 197 | if self.training: 198 | time_steps = gt_poses.shape[1] 199 | 200 | # vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps) 201 | if self.encoder_choice == 'meshtalk': 202 | in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000) 203 | feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2) 204 | elif self.encoder_choice == 'faceformer': 205 | hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1), frame_num=time_steps).last_hidden_state 206 | feature = self.audio_feature_map(hidden_states).transpose(1, 2) 207 | else: 208 | feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps) 209 | 210 | # hidden_states = in_spec 211 | 212 | feature, _ = self.audio_middle(feature, id=id) 213 | 214 | out = [] 215 | 216 | for i in range(self.decoder.__len__()): 217 | mid = self.decoder[i](feature) 218 | mid = self.final_out[i](mid) 219 | out.append(mid) 220 | 221 | out = torch.cat(out, dim=1) 222 | out = out.transpose(1, 2) 223 | 224 | return out, None 225 | 226 | 227 | -------------------------------------------------------------------------------- /nets/spg/vqvae_1d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .wav2vec import Wav2Vec2Model 7 | from .vqvae_modules import VectorQuantizerEMA, ConvNormRelu, Res_CNR_Stack 8 | 9 | 10 | 11 | class AudioEncoder(nn.Module): 12 | def __init__(self, in_dim, num_hiddens, 
num_residual_layers, num_residual_hiddens): 13 | super(AudioEncoder, self).__init__() 14 | self._num_hiddens = num_hiddens 15 | self._num_residual_layers = num_residual_layers 16 | self._num_residual_hiddens = num_residual_hiddens 17 | 18 | self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True) 19 | 20 | self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True) 21 | self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True, 22 | sample='down') 23 | self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True) 24 | self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True, sample='down') 25 | self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True) 26 | 27 | def forward(self, x, frame_num=0): 28 | h = self.project(x) 29 | h = self._enc_1(h) 30 | h = self._down_1(h) 31 | h = self._enc_2(h) 32 | h = self._down_2(h) 33 | h = self._enc_3(h) 34 | return h 35 | 36 | 37 | class Wav2VecEncoder(nn.Module): 38 | def __init__(self, num_hiddens, num_residual_layers): 39 | super(Wav2VecEncoder, self).__init__() 40 | self._num_hiddens = num_hiddens 41 | self._num_residual_layers = num_residual_layers 42 | 43 | self.audio_encoder = Wav2Vec2Model.from_pretrained( 44 | "facebook/wav2vec2-base-960h") # "vitouphy/wav2vec2-xls-r-300m-phoneme""facebook/wav2vec2-base-960h" 45 | self.audio_encoder.feature_extractor._freeze_parameters() 46 | 47 | self.project = ConvNormRelu(768, self._num_hiddens, leaky=True) 48 | 49 | self._enc_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True) 50 | self._down_1 = ConvNormRelu(self._num_hiddens, self._num_hiddens, leaky=True, residual=True, sample='down') 51 | self._enc_2 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True) 52 | self._down_2 = ConvNormRelu(self._num_hiddens, self._num_hiddens, leaky=True, residual=True, sample='down') 53 | self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True) 54 | 55 | def forward(self, x, frame_num): 56 | h = self.audio_encoder(x.squeeze(), frame_num=frame_num).last_hidden_state.transpose(1, 2) 57 | h = self.project(h) 58 | h = self._enc_1(h) 59 | h = self._down_1(h) 60 | h = self._enc_2(h) 61 | h = self._down_2(h) 62 | h = self._enc_3(h) 63 | return h 64 | 65 | 66 | class Encoder(nn.Module): 67 | def __init__(self, in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens): 68 | super(Encoder, self).__init__() 69 | self._num_hiddens = num_hiddens 70 | self._num_residual_layers = num_residual_layers 71 | self._num_residual_hiddens = num_residual_hiddens 72 | 73 | self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True) 74 | 75 | self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True) 76 | self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True, 77 | sample='down') 78 | self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True) 79 | self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True, sample='down') 80 | self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True) 81 | 82 | self.pre_vq_conv = nn.Conv1d(self._num_hiddens, embedding_dim, 1, 1) 83 | 84 | def forward(self, x): 85 | h = self.project(x) 86 | h = self._enc_1(h) 87 | h = self._down_1(h) 88 | h = self._enc_2(h) 
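# _down_2 below is the second sample='down' stage of the encoder; assuming the
# usual stride-2 ConvNormRelu, the pose sequence leaving _enc_3 is at roughly a
# quarter of the input temporal length, and pre_vq_conv then maps it to
# embedding_dim-channel features for the vector-quantization layer.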
89 | h = self._down_2(h) 90 | h = self._enc_3(h) 91 | h = self.pre_vq_conv(h) 92 | return h 93 | 94 | 95 | class Frame_Enc(nn.Module): 96 | def __init__(self, in_dim, num_hiddens): 97 | super(Frame_Enc, self).__init__() 98 | self.in_dim = in_dim 99 | self.num_hiddens = num_hiddens 100 | 101 | # self.enc = transformer_Enc(in_dim, num_hiddens, 2, 8, 256, 256, 256, 256, 0, dropout=0.1, n_position=4) 102 | self.proj = nn.Conv1d(in_dim, num_hiddens, 1, 1) 103 | self.enc = Res_CNR_Stack(num_hiddens, 2, leaky=True) 104 | self.proj_1 = nn.Conv1d(256*4, num_hiddens, 1, 1) 105 | self.proj_2 = nn.Conv1d(256*4, num_hiddens*2, 1, 1) 106 | 107 | def forward(self, x): 108 | # x = self.enc(x, None)[0].reshape(x.shape[0], -1, 1) 109 | x = self.enc(self.proj(x)).reshape(x.shape[0], -1, 1) 110 | second_last = self.proj_2(x) 111 | last = self.proj_1(x) 112 | return second_last, last 113 | 114 | 115 | 116 | class Decoder(nn.Module): 117 | def __init__(self, out_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, ae=False): 118 | super(Decoder, self).__init__() 119 | self._num_hiddens = num_hiddens 120 | self._num_residual_layers = num_residual_layers 121 | self._num_residual_hiddens = num_residual_hiddens 122 | 123 | self.aft_vq_conv = nn.Conv1d(embedding_dim, self._num_hiddens, 1, 1) 124 | 125 | self._dec_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True) 126 | self._up_2 = ConvNormRelu(self._num_hiddens, self._num_hiddens // 2, leaky=True, residual=True, sample='up') 127 | self._dec_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True) 128 | self._up_3 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens // 4, leaky=True, residual=True, 129 | sample='up') 130 | self._dec_3 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True) 131 | 132 | if ae: 133 | self.frame_enc = Frame_Enc(out_dim, self._num_hiddens // 4) 134 | self.gru_sl = nn.GRU(self._num_hiddens // 2, self._num_hiddens // 2, 1, batch_first=True) 135 | self.gru_l = nn.GRU(self._num_hiddens // 4, self._num_hiddens // 4, 1, batch_first=True) 136 | 137 | self.project = nn.Conv1d(self._num_hiddens // 4, out_dim, 1, 1) 138 | 139 | def forward(self, h, last_frame=None): 140 | 141 | h = self.aft_vq_conv(h) 142 | h = self._dec_1(h) 143 | h = self._up_2(h) 144 | h = self._dec_2(h) 145 | h = self._up_3(h) 146 | h = self._dec_3(h) 147 | 148 | recon = self.project(h) 149 | return recon, None 150 | 151 | 152 | class Pre_VQ(nn.Module): 153 | def __init__(self, num_hiddens, embedding_dim, num_chunks): 154 | super(Pre_VQ, self).__init__() 155 | self.conv = nn.Conv1d(num_hiddens, num_hiddens, 1, 1, 0, groups=num_chunks) 156 | self.bn = nn.GroupNorm(num_chunks, num_hiddens) 157 | self.relu = nn.ReLU() 158 | self.proj = nn.Conv1d(num_hiddens, embedding_dim, 1, 1, 0, groups=num_chunks) 159 | 160 | def forward(self, x): 161 | x = self.conv(x) 162 | x = self.bn(x) 163 | x = self.relu(x) 164 | x = self.proj(x) 165 | return x 166 | 167 | 168 | class VQVAE(nn.Module): 169 | """VQ-VAE""" 170 | 171 | def __init__(self, in_dim, embedding_dim, num_embeddings, 172 | num_hiddens, num_residual_layers, num_residual_hiddens, 173 | commitment_cost=0.25, decay=0.99, share=False): 174 | super().__init__() 175 | self.in_dim = in_dim 176 | self.embedding_dim = embedding_dim 177 | self.num_embeddings = num_embeddings 178 | self.share_code_vq = share 179 | 180 | self.encoder = Encoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens) 181 | self.vq_layer = 
VectorQuantizerEMA(embedding_dim, num_embeddings, commitment_cost, decay) 182 | self.decoder = Decoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens) 183 | 184 | def forward(self, gt_poses, id=None, pre_state=None): 185 | z = self.encoder(gt_poses.transpose(1, 2)) 186 | if not self.training: 187 | e, _ = self.vq_layer(z) 188 | x_recon, cur_state = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None) 189 | return e, x_recon 190 | 191 | e, e_q_loss = self.vq_layer(z) 192 | gt_recon, cur_state = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None) 193 | 194 | return e_q_loss, gt_recon.transpose(1, 2) 195 | 196 | def encode(self, gt_poses, id=None): 197 | z = self.encoder(gt_poses.transpose(1, 2)) 198 | e, latents = self.vq_layer(z) 199 | return e, latents 200 | 201 | def decode(self, b, w, e=None, latents=None, pre_state=None): 202 | if e is not None: 203 | x = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None) 204 | else: 205 | e = self.vq_layer.quantize(latents) 206 | e = e.view(b, w, -1).permute(0, 2, 1).contiguous() 207 | x = self.decoder(e, pre_state.transpose(1, 2) if pre_state is not None else None) 208 | return x 209 | 210 | 211 | class AE(nn.Module): 212 | """VQ-VAE""" 213 | 214 | def __init__(self, in_dim, embedding_dim, num_embeddings, 215 | num_hiddens, num_residual_layers, num_residual_hiddens): 216 | super().__init__() 217 | self.in_dim = in_dim 218 | self.embedding_dim = embedding_dim 219 | self.num_embeddings = num_embeddings 220 | 221 | self.encoder = Encoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens) 222 | self.decoder = Decoder(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, True) 223 | 224 | def forward(self, gt_poses, id=None, pre_state=None): 225 | z = self.encoder(gt_poses.transpose(1, 2)) 226 | if not self.training: 227 | x_recon, cur_state = self.decoder(z, pre_state.transpose(1, 2) if pre_state is not None else None) 228 | return z, x_recon 229 | gt_recon, cur_state = self.decoder(z, pre_state.transpose(1, 2) if pre_state is not None else None) 230 | 231 | return gt_recon.transpose(1, 2) 232 | 233 | def encode(self, gt_poses, id=None): 234 | z = self.encoder(gt_poses.transpose(1, 2)) 235 | return z 236 | -------------------------------------------------------------------------------- /nets/spg/wav2vec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import copy 6 | import math 7 | from transformers import Wav2Vec2Model,Wav2Vec2Config 8 | from transformers.modeling_outputs import BaseModelOutput 9 | from typing import Optional, Tuple 10 | _CONFIG_FOR_DOC = "Wav2Vec2Config" 11 | 12 | # the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model 13 | # initialize our encoder with the pre-trained wav2vec 2.0 weights. 
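# Usage sketch (an assumption, mirroring how s2g_face.py and vqvae_1d.py call this class):
#
#   encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
#   encoder.feature_extractor._freeze_parameters()  # keep the CNN front-end frozen
#   feats = encoder(wav_16khz, frame_num=n_frames).last_hidden_state  # (B, n_frames, 768)
#
# wav_16khz and n_frames are placeholder names. The extra frame_num argument is
# what this subclass adds over the stock transformers model: linear_interpolation()
# below resamples the ~50 Hz wav2vec feature sequence to the 30 fps motion frame
# rate (or to frame_num frames when given) before the transformer encoder runs.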
14 | def _compute_mask_indices( 15 | shape: Tuple[int, int], 16 | mask_prob: float, 17 | mask_length: int, 18 | attention_mask: Optional[torch.Tensor] = None, 19 | min_masks: int = 0, 20 | ) -> np.ndarray: 21 | bsz, all_sz = shape 22 | mask = np.full((bsz, all_sz), False) 23 | 24 | all_num_mask = int( 25 | mask_prob * all_sz / float(mask_length) 26 | + np.random.rand() 27 | ) 28 | all_num_mask = max(min_masks, all_num_mask) 29 | mask_idcs = [] 30 | padding_mask = attention_mask.ne(1) if attention_mask is not None else None 31 | for i in range(bsz): 32 | if padding_mask is not None: 33 | sz = all_sz - padding_mask[i].long().sum().item() 34 | num_mask = int( 35 | mask_prob * sz / float(mask_length) 36 | + np.random.rand() 37 | ) 38 | num_mask = max(min_masks, num_mask) 39 | else: 40 | sz = all_sz 41 | num_mask = all_num_mask 42 | 43 | lengths = np.full(num_mask, mask_length) 44 | 45 | if sum(lengths) == 0: 46 | lengths[0] = min(mask_length, sz - 1) 47 | 48 | min_len = min(lengths) 49 | if sz - min_len <= num_mask: 50 | min_len = sz - num_mask - 1 51 | 52 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) 53 | mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]) 54 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) 55 | 56 | min_len = min([len(m) for m in mask_idcs]) 57 | for i, mask_idc in enumerate(mask_idcs): 58 | if len(mask_idc) > min_len: 59 | mask_idc = np.random.choice(mask_idc, min_len, replace=False) 60 | mask[i, mask_idc] = True 61 | return mask 62 | 63 | # linear interpolation layer 64 | def linear_interpolation(features, input_fps, output_fps, output_len=None): 65 | features = features.transpose(1, 2) 66 | seq_len = features.shape[2] / float(input_fps) 67 | if output_len is None: 68 | output_len = int(seq_len * output_fps) 69 | output_features = F.interpolate(features,size=output_len,align_corners=False,mode='linear') 70 | return output_features.transpose(1, 2) 71 | 72 | 73 | class Wav2Vec2Model(Wav2Vec2Model): 74 | def __init__(self, config): 75 | super().__init__(config) 76 | def forward( 77 | self, 78 | input_values, 79 | attention_mask=None, 80 | output_attentions=None, 81 | output_hidden_states=None, 82 | return_dict=None, 83 | frame_num=None 84 | ): 85 | self.config.output_attentions = True 86 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 87 | output_hidden_states = ( 88 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 89 | ) 90 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 91 | 92 | hidden_states = self.feature_extractor(input_values) 93 | hidden_states = hidden_states.transpose(1, 2) 94 | 95 | hidden_states = linear_interpolation(hidden_states, 50, 30,output_len=frame_num) 96 | 97 | if attention_mask is not None: 98 | output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)) 99 | attention_mask = torch.zeros( 100 | hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device 101 | ) 102 | attention_mask[ 103 | (torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1) 104 | ] = 1 105 | attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() 106 | 107 | hidden_states = self.feature_projection(hidden_states) 108 | 109 | if self.config.apply_spec_augment and self.training: 110 | batch_size, sequence_length, hidden_size = hidden_states.size() 111 | if 
self.config.mask_time_prob > 0: 112 | mask_time_indices = _compute_mask_indices( 113 | (batch_size, sequence_length), 114 | self.config.mask_time_prob, 115 | self.config.mask_time_length, 116 | attention_mask=attention_mask, 117 | min_masks=2, 118 | ) 119 | hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype) 120 | if self.config.mask_feature_prob > 0: 121 | mask_feature_indices = _compute_mask_indices( 122 | (batch_size, hidden_size), 123 | self.config.mask_feature_prob, 124 | self.config.mask_feature_length, 125 | ) 126 | mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device) 127 | hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 128 | encoder_outputs = self.encoder( 129 | hidden_states[0], 130 | attention_mask=attention_mask, 131 | output_attentions=output_attentions, 132 | output_hidden_states=output_hidden_states, 133 | return_dict=return_dict, 134 | ) 135 | hidden_states = encoder_outputs[0] 136 | if not return_dict: 137 | return (hidden_states,) + encoder_outputs[1:] 138 | 139 | return BaseModelOutput( 140 | last_hidden_state=hidden_states, 141 | hidden_states=encoder_outputs.hidden_states, 142 | attentions=encoder_outputs.attentions, 143 | ) 144 | -------------------------------------------------------------------------------- /nets/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import textgrid as tg 3 | import numpy as np 4 | 5 | def get_parameter_size(model): 6 | total_num = sum(p.numel() for p in model.parameters()) 7 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 8 | return total_num, trainable_num 9 | 10 | def denormalize(kps, data_mean, data_std): 11 | ''' 12 | kps: (B, T, C) 13 | ''' 14 | data_std = data_std.reshape(1, 1, -1) 15 | data_mean = data_mean.reshape(1, 1, -1) 16 | return (kps * data_std) + data_mean 17 | 18 | def normalize(kps, data_mean, data_std): 19 | ''' 20 | kps: (B, T, C) 21 | ''' 22 | data_std = data_std.squeeze().reshape(1, 1, -1) 23 | data_mean = data_mean.squeeze().reshape(1, 1, -1) 24 | 25 | return (kps-data_mean) / data_std 26 | 27 | def parse_audio(textgrid_file): 28 | '''a demo implementation''' 29 | words=['but', 'as', 'to', 'that', 'with', 'of', 'the', 'and', 'or', 'not', 'which', 'what', 'this', 'for', 'because', 'if', 'so', 'just', 'about', 'like', 'by', 'how', 'from', 'whats', 'now', 'very', 'that', 'also', 'actually', 'who', 'then', 'well', 'where', 'even', 'today', 'between', 'than', 'when'] 30 | txt=tg.TextGrid.fromFile(textgrid_file) 31 | 32 | total_time=int(np.ceil(txt.maxTime)) 33 | code_seq=np.zeros(total_time) 34 | 35 | word_level=txt[0] 36 | 37 | for i in range(len(word_level)): 38 | start_time=word_level[i].minTime 39 | end_time=word_level[i].maxTime 40 | mark=word_level[i].mark 41 | 42 | if mark in words: 43 | start=int(np.round(start_time)) 44 | end=int(np.round(end_time)) 45 | 46 | if start >= len(code_seq) or end >= len(code_seq): 47 | code_seq[-1] = 1 48 | else: 49 | code_seq[start]=1 50 | 51 | return code_seq 52 | 53 | 54 | def get_path(model_name, model_type): 55 | if model_name == 's2g_body_pixel': 56 | if model_type == 'mfcc': 57 | return './experiments/2022-10-09-smplx_S2G-body-pixel-aud-3p/ckpt-99.pth' 58 | elif model_type == 'wv2': 59 | return './experiments/2022-10-28-smplx_S2G-body-pixel-wv2-sg2/ckpt-99.pth' 60 | elif model_type == 'random': 61 | return 
'./experiments/2022-10-09-smplx_S2G-body-pixel-random-3p/ckpt-99.pth' 62 | elif model_type == 'wbhmodel': 63 | return './experiments/2022-11-02-smplx_S2G-body-pixel-w-bhmodel/ckpt-99.pth' 64 | elif model_type == 'wobhmodel': 65 | return './experiments/2022-11-02-smplx_S2G-body-pixel-wo-bhmodel/ckpt-99.pth' 66 | elif model_name == 's2g_body': 67 | if model_type == 'a+m-vae': 68 | return './experiments/2022-10-19-smplx_S2G-body-audio-motion-vae/ckpt-99.pth' 69 | elif model_type == 'a-vae': 70 | return './experiments/2022-10-18-smplx_S2G-body-audiovae/ckpt-99.pth' 71 | elif model_type == 'a-ed': 72 | return './experiments/2022-10-18-smplx_S2G-body-audioae/ckpt-99.pth' 73 | elif model_name == 's2g_LS3DCG': 74 | return './experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth' 75 | elif model_name == 's2g_body_vq': 76 | if model_type == 'n_com_1024': 77 | return './experiments/2022-10-29-smplx_S2G-body-vq-cn1024/ckpt-99.pth' 78 | elif model_type == 'n_com_2048': 79 | return './experiments/2022-10-29-smplx_S2G-body-vq-cn2048/ckpt-99.pth' 80 | elif model_type == 'n_com_4096': 81 | return './experiments/2022-10-29-smplx_S2G-body-vq-cn4096/ckpt-99.pth' 82 | elif model_type == 'n_com_8192': 83 | return './experiments/2022-11-02-smplx_S2G-body-vq-cn8192/ckpt-99.pth' 84 | elif model_type == 'n_com_16384': 85 | return './experiments/2022-11-02-smplx_S2G-body-vq-cn16384/ckpt-99.pth' 86 | elif model_type == 'n_com_170000': 87 | return './experiments/2022-10-30-smplx_S2G-body-vq-cn170000/ckpt-99.pth' 88 | elif model_type == 'com_1024': 89 | return './experiments/2022-10-29-smplx_S2G-body-vq-composition/ckpt-99.pth' 90 | elif model_type == 'com_2048': 91 | return './experiments/2022-10-31-smplx_S2G-body-vq-composition2048/ckpt-99.pth' 92 | elif model_type == 'com_4096': 93 | return './experiments/2022-10-31-smplx_S2G-body-vq-composition4096/ckpt-99.pth' 94 | elif model_type == 'com_8192': 95 | return './experiments/2022-11-02-smplx_S2G-body-vq-composition8192/ckpt-99.pth' 96 | elif model_type == 'com_16384': 97 | return './experiments/2022-11-02-smplx_S2G-body-vq-composition16384/ckpt-99.pth' 98 | 99 | 100 | def get_dpath(model_name, model_type): 101 | if model_name == 's2g_body_pixel': 102 | if model_type == 'audio': 103 | return './experiments/2022-10-26-smplx_S2G-d-pixel-aud/ckpt-9.pth' 104 | elif model_type == 'wv2': 105 | return './experiments/2022-11-04-smplx_S2G-d-pixel-wv2/ckpt-9.pth' 106 | elif model_type == 'random': 107 | return './experiments/2022-10-26-smplx_S2G-d-pixel-random/ckpt-9.pth' 108 | elif model_type == 'wbhmodel': 109 | return './experiments/2022-11-10-smplx_S2G-hD-wbhmodel/ckpt-9.pth' 110 | # return './experiments/2022-11-05-smplx_S2G-d-pixel-wbhmodel/ckpt-9.pth' 111 | elif model_type == 'wobhmodel': 112 | return './experiments/2022-11-10-smplx_S2G-hD-wobhmodel/ckpt-9.pth' 113 | # return './experiments/2022-11-05-smplx_S2G-d-pixel-wobhmodel/ckpt-9.pth' 114 | elif model_name == 's2g_body': 115 | if model_type == 'a+m-vae': 116 | return './experiments/2022-10-26-smplx_S2G-d-audio+motion-vae/ckpt-9.pth' 117 | elif model_type == 'a-vae': 118 | return './experiments/2022-10-26-smplx_S2G-d-audio-vae/ckpt-9.pth' 119 | elif model_type == 'a-ed': 120 | return './experiments/2022-10-26-smplx_S2G-d-audio-ae/ckpt-9.pth' 121 | elif model_name == 's2g_LS3DCG': 122 | return './experiments/2022-10-26-smplx_S2G-d-ls3dcg/ckpt-9.pth' -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | numpy~=1.21.5 2 | transformers~=4.22.1 3 | matplotlib~=3.2.2 4 | textgrid~=1.5 5 | smplx~=0.1.28 6 | scikit-learn~=1.0.2 7 | pyrender~=0.1.45 8 | trimesh~=3.14.1 9 | tqdm~=4.64.1 10 | librosa~=0.9.2 11 | scipy~=1.7.3 12 | python_speech_features~=0.6 13 | opencv-python~=4.7.0.68 14 | pyglet~=1.5 -------------------------------------------------------------------------------- /scripts/.idea/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/scripts/.idea/__init__.py -------------------------------------------------------------------------------- /scripts/.idea/aws.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 10 | 11 | -------------------------------------------------------------------------------- /scripts/.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /scripts/.idea/get_prevar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 4 | 5 | sys.path.append(os.getcwd()) 6 | from glob import glob 7 | 8 | import numpy as np 9 | import json 10 | import smplx as smpl 11 | 12 | from nets import * 13 | from repro_nets import * 14 | from trainer.options import parse_args 15 | from data_utils import torch_data 16 | from trainer.config import load_JsonConfig 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch.utils import data 22 | 23 | def init_model(model_name, model_path, args, config): 24 | if model_name == 'freeMo': 25 | # generator = freeMo_Generator(args) 26 | # generator = freeMo_Generator(args) 27 | generator = freeMo_dev(args, config) 28 | # generator.load_state_dict(torch.load(model_path)['generator']) 29 | elif model_name == 'smplx_S2G': 30 | generator = smplx_S2G(args, config) 31 | elif model_name == 'StyleGestures': 32 | generator = StyleGesture_Generator( 33 | args, 34 | config 35 | ) 36 | elif model_name == 'Audio2Gestures': 37 | config.Train.using_mspec_stat = False 38 | generator = Audio2Gesture_Generator( 39 | args, 40 | config, 41 | torch.zeros([1, 1, 108]), 42 | torch.ones([1, 1, 108]) 43 | ) 44 | elif model_name == 'S2G': 45 | generator = S2G_Generator( 46 | args, 47 | config, 48 | ) 49 | elif model_name == 'Tmpt': 50 | generator = S2G_Generator( 51 | args, 52 | config, 53 | ) 54 | else: 55 | raise NotImplementedError 56 | 57 | model_ckpt = torch.load(model_path, map_location=torch.device('cpu')) 58 | if model_name == 'smplx_S2G': 59 | generator.generator.load_state_dict(model_ckpt['generator']['generator']) 60 | elif 'generator' in list(model_ckpt.keys()): 61 | generator.load_state_dict(model_ckpt['generator']) 62 | else: 63 | model_ckpt = {'generator': model_ckpt} 64 | generator.load_state_dict(model_ckpt) 65 | 66 | return generator 67 | 68 | 69 | 70 | def prevar_loader(data_root, 
speakers, args, config, model_path, device, generator): 71 | path = model_path.split('ckpt')[0] 72 | file = os.path.join(os.path.dirname(path), "pre_variable.npy") 73 | data_base = torch_data( 74 | data_root=data_root, 75 | speakers=speakers, 76 | split='pre', 77 | limbscaling=False, 78 | normalization=config.Data.pose.normalization, 79 | norm_method=config.Data.pose.norm_method, 80 | split_trans_zero=False, 81 | num_pre_frames=config.Data.pose.pre_pose_length, 82 | num_generate_length=config.Data.pose.generate_length, 83 | num_frames=15, 84 | aud_feat_win_size=config.Data.aud.aud_feat_win_size, 85 | aud_feat_dim=config.Data.aud.aud_feat_dim, 86 | feat_method=config.Data.aud.feat_method, 87 | smplx=True, 88 | audio_sr=22000, 89 | convert_to_6d=config.Data.pose.convert_to_6d, 90 | expression=config.Data.pose.expression 91 | ) 92 | 93 | data_base.get_dataset() 94 | pre_set = data_base.all_dataset 95 | pre_loader = data.DataLoader(pre_set, batch_size=config.DataLoader.batch_size, shuffle=False, drop_last=True) 96 | 97 | total_pose = [] 98 | 99 | with torch.no_grad(): 100 | for bat in pre_loader: 101 | pose = bat['poses'].to(device).to(torch.float32) 102 | expression = bat['expression'].to(device).to(torch.float32) 103 | pose = pose.permute(0, 2, 1) 104 | pose = torch.cat([pose[:, :15], pose[:, 15:30], pose[:, 30:45], pose[:, 45:60], pose[:, 60:]], dim=0) 105 | expression = expression.permute(0, 2, 1) 106 | expression = torch.cat([expression[:, :15], expression[:, 15:30], expression[:, 30:45], expression[:, 45:60], expression[:, 60:]], dim=0) 107 | pose = torch.cat([pose, expression], dim=-1) 108 | pose = pose.reshape(pose.shape[0], -1, 1) 109 | pose_code = generator.generator.pre_pose_encoder(pose).squeeze().detach().cpu() 110 | total_pose.append(np.asarray(pose_code)) 111 | total_pose = np.concatenate(total_pose, axis=0) 112 | mean = np.mean(total_pose, axis=0) 113 | std = np.std(total_pose, axis=0) 114 | prevar = (mean, std) 115 | np.save(file, prevar, allow_pickle=True) 116 | 117 | return mean, std 118 | 119 | def main(): 120 | parser = parse_args() 121 | args = parser.parse_args() 122 | device = torch.device(args.gpu) 123 | torch.cuda.set_device(device) 124 | 125 | config = load_JsonConfig(args.config_file) 126 | 127 | print('init model...') 128 | generator = init_model(config.Model.model_name, args.model_path, args, config) 129 | print('init pre-pose vectors...') 130 | mean, std = prevar_loader(config.Data.data_root, args.speakers, args, config, args.model_path, device, generator) 131 | 132 | main() -------------------------------------------------------------------------------- /scripts/.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 12 | -------------------------------------------------------------------------------- /scripts/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /scripts/.idea/lower body: -------------------------------------------------------------------------------- 1 | 0, 1, 3, 4, 6, 7, 9, 10, -------------------------------------------------------------------------------- /scripts/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- 
/scripts/.idea/scripts.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /scripts/.idea/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhw-yhw/TalkSHOW/9aef82df5ff1082f0cfa0cfc116c0b7208e85d5b/scripts/.idea/test.png -------------------------------------------------------------------------------- /scripts/.idea/testtext.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | # path being defined from where the system will read the image 4 | path = r'test.png' 5 | # command used for reading an image from the disk disk, cv2.imread function is used 6 | image1 = cv2.imread(path) 7 | # Window name being specified where the image will be displayed 8 | window_name1 = 'image' 9 | # font for the text being specified 10 | font1 = cv2.FONT_HERSHEY_SIMPLEX 11 | # org for the text being specified 12 | org1 = (50, 50) 13 | # font scale for the text being specified 14 | fontScale1 = 1 15 | # Blue color for the text being specified from BGR 16 | color1 = (255, 255, 255) 17 | # Line thickness for the text being specified at 2 px 18 | thickness1 = 2 19 | # Using the cv2.putText() method for inserting text in the image of the specified path 20 | image_1 = cv2.putText(image1, 'CAT IN BOX', org1, font1, fontScale1, color1, thickness1, cv2.LINE_AA) 21 | # Displaying the output image 22 | cv2.imshow(window_name1, image_1) 23 | cv2.waitKey(0) 24 | cv2.destroyAllWindows() 25 | -------------------------------------------------------------------------------- /scripts/.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /scripts/.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |