├── README.md ├── TANGO_gradio_jupyter.ipynb └── TANGO_jupyter.ipynb -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 🐣 Please follow me for new updates https://twitter.com/camenduru
2 | 🔥 Please join our Discord server https://discord.gg/k5BwmmvJJU
3 | 🥳 Please join my Patreon community https://patreon.com/camenduru
4 | 5 | 🚦 Prototype Model 🚦 6 | 7 | ### 🍊 Jupyter Notebook 8 | 9 | | Colab | Info 10 | | --- | --- | 11 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/TANGO-jupyter/blob/main/TANGO_jupyter.ipynb) | TANGO_jupyter 12 | 13 | ### 🧬 Code 14 | https://github.com/CyberAgentAILab/TANGO 15 | 16 | ### 📄 Paper 17 | https://arxiv.org/abs/2410.04221 18 | 19 | ### 🌐 Page 20 | https://pantomatrix.github.io/TANGO/ 21 | 22 | ### 🖼 Output 23 | 24 | Input: 25 | 26 | https://github.com/user-attachments/assets/5660da40-af7e-46b5-ba2c-6b1bba9bed67 27 | 28 | Output: 29 | 30 | https://github.com/user-attachments/assets/40486269-7a90-4c03-9904-43113f0f1281 31 | 32 | Input: 33 | 34 | https://github.com/user-attachments/assets/aa5c59c2-46e7-4023-839c-d62d8d771c81 35 | 36 | Output: 37 | 38 | https://github.com/user-attachments/assets/d2665857-2dfc-46ba-98e5-557cb2331fbd 39 | 40 | Output: 41 | 42 | https://github.com/user-attachments/assets/35a50632-6d1e-4bdf-8cb0-df2c43083c8a 43 | 44 | ### 🏢 Sponsor 45 | https://runpod.io 46 | -------------------------------------------------------------------------------- /TANGO_gradio_jupyter.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github" 7 | }, 8 | "source": [ 9 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/TANGO-jupyter/blob/main/TANGO_gradio_jupyter.ipynb)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "id": "VjYy0F2gZIPR" 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "%cd /content\n", 21 | "!GIT_LFS_SKIP_SMUDGE=1 git clone -b dev https://github.com/camenduru/TANGO-hf /content/TANGO\n", 22 | "\n", 23 | "!pip install gradio omegaconf wget decord smplx igraph av\n", 24 | "!pip install git+https://github.com/elliottzheng/batch-face\n", 25 | "\n", 26 | "!apt install -y -qq aria2\n", 27 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/emage/smplx_models/smplx/SMPLX_NEUTRAL_2020.npz -d /content/TANGO/emage/smplx_models/smplx -o SMPLX_NEUTRAL_2020.npz\n", 28 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/emage/AESKConv_240_100.bin -d /content/TANGO/emage -o AESKConv_240_100.bin\n", 29 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/emage/mean_vel_smplxflame_30.npy -d /content/TANGO/emage -o mean_vel_smplxflame_30.npy\n", 30 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/Wav2Lip/checkpoints/mobilenet.pth -d /content/TANGO/Wav2Lip/checkpoints -o mobilenet.pth\n", 31 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/Wav2Lip/checkpoints/resnet50.pth -d /content/TANGO/Wav2Lip/checkpoints -o resnet50.pth\n", 32 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/Wav2Lip/checkpoints/wav2lip_gan.pth -d /content/TANGO/Wav2Lip/checkpoints -o wav2lip_gan.pth\n", 33 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/frame-interpolation-pytorch/film_net_fp16.pt -d 
/content/TANGO/frame-interpolation-pytorch -o film_net_fp16.pt\n", 34 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/frame-interpolation-pytorch/film_net_fp32.pt -d /content/TANGO/frame-interpolation-pytorch -o film_net_fp32.pt\n", 35 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_ckpts/ckpt.pth -d /content/TANGO/datasets/cached_ckpts -o ckpt.pth\n", 36 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker1.pkl -d /content/TANGO/datasets/cached_graph/youtube_test -o speaker1.pkl\n", 37 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker7.pkl -d /content/TANGO/datasets/cached_graph/youtube_test -o speaker7.pkl\n", 38 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker8.pkl -d /content/TANGO/datasets/cached_graph/youtube_test -o speaker8.pkl\n", 39 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker9.pkl -d /content/TANGO/datasets/cached_graph/youtube_test -o speaker9.pkl\n", 40 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker1.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker1.json\n", 41 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker7.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker7.json\n", 42 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker8.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker8.json\n", 43 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker9.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker9.json\n", 44 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/example_female_voice_9_seconds.wav -d /content/TANGO/datasets/cached_audio -o example_female_voice_9_seconds.wav\n", 45 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/example_male_voice_9_seconds.wav -d /content/TANGO/datasets/cached_audio -o example_male_voice_9_seconds.wav\n", 46 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/1wrQ6Msp7wM_00-00-39.69_00-00-45.68.mp4 -d /content/TANGO/datasets/cached_audio -o 1wrQ6Msp7wM_00-00-39.69_00-00-45.68.mp4\n", 47 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/speaker8_jjRWaMCWs44_00-00-30.16_00-00-33.32.mp4 -d /content/TANGO/datasets/cached_audio -o speaker8_jjRWaMCWs44_00-00-30.16_00-00-33.32.mp4\n", 48 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M 
https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/speaker7_iuYlGRnC7J8_00-00-0.00_00-00-3.25.mp4 -d /content/TANGO/datasets/cached_audio -o speaker7_iuYlGRnC7J8_00-00-0.00_00-00-3.25.mp4\n", 49 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4 -d /content/TANGO/datasets/cached_audio -o speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4\n", 50 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/101099-00_18_09-00_18_19.mp4 -d /content/TANGO/datasets/cached_audio -o 101099-00_18_09-00_18_19.mp4\n", 51 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo0.mp4 -d /content/TANGO/datasets/cached_audio -o demo0.mp4\n", 52 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo1.mp4 -d /content/TANGO/datasets/cached_audio -o demo1.mp4\n", 53 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo2.mp4 -d /content/TANGO/datasets/cached_audio -o demo2.mp4\n", 54 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo3.mp4 -d /content/TANGO/datasets/cached_audio -o demo3.mp4\n", 55 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo4.mp4 -d /content/TANGO/datasets/cached_audio -o demo4.mp4\n", 56 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo5.mp4 -d /content/TANGO/datasets/cached_audio -o demo5.mp4\n", 57 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo6.mp4 -d /content/TANGO/datasets/cached_audio -o demo6.mp4\n", 58 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo7.mp4 -d /content/TANGO/datasets/cached_audio -o demo7.mp4\n", 59 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo8.mp4 -d /content/TANGO/datasets/cached_audio -o demo8.mp4\n", 60 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo9.mp4 -d /content/TANGO/datasets/cached_audio -o demo9.mp4\n", 61 | "\n", 62 | "%cd /content/TANGO\n", 63 | "!python app.py" 64 | ] 65 | } 66 | ], 67 | "metadata": { 68 | "accelerator": "GPU", 69 | "colab": { 70 | "gpuType": "T4", 71 | "provenance": [] 72 | }, 73 | "kernelspec": { 74 | "display_name": "Python 3", 75 | "name": "python3" 76 | }, 77 | "language_info": { 78 | "name": "python" 79 | } 80 | }, 81 | "nbformat": 4, 82 | "nbformat_minor": 0 83 | } 84 | -------------------------------------------------------------------------------- /TANGO_jupyter.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github" 7 | }, 8 | "source": [ 
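"TANGO_jupyter: downloads the cached checkpoints, motion graphs, and demo audio/video, then runs graph-based motion retrieval with Wav2Lip lip-sync directly in the notebook (no Gradio UI).\n",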
9 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/TANGO-jupyter/blob/main/TANGO_jupyter.ipynb)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "id": "VjYy0F2gZIPR" 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "%cd /content\n", 21 | "!GIT_LFS_SKIP_SMUDGE=1 git clone -b dev https://github.com/camenduru/TANGO-hf /content/TANGO\n", 22 | "\n", 23 | "!pip install omegaconf wget decord smplx igraph av\n", 24 | "!pip install git+https://github.com/elliottzheng/batch-face\n", 25 | "\n", 26 | "!apt install -y -qq aria2\n", 27 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/emage/smplx_models/smplx/SMPLX_NEUTRAL_2020.npz -d /content/TANGO/emage/smplx_models/smplx -o SMPLX_NEUTRAL_2020.npz\n", 28 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/emage/AESKConv_240_100.bin -d /content/TANGO/emage -o AESKConv_240_100.bin\n", 29 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/emage/mean_vel_smplxflame_30.npy -d /content/TANGO/emage -o mean_vel_smplxflame_30.npy\n", 30 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/Wav2Lip/checkpoints/mobilenet.pth -d /content/TANGO/Wav2Lip/checkpoints -o mobilenet.pth\n", 31 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/Wav2Lip/checkpoints/resnet50.pth -d /content/TANGO/Wav2Lip/checkpoints -o resnet50.pth\n", 32 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/Wav2Lip/checkpoints/wav2lip_gan.pth -d /content/TANGO/Wav2Lip/checkpoints -o wav2lip_gan.pth\n", 33 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/frame-interpolation-pytorch/film_net_fp16.pt -d /content/TANGO/frame-interpolation-pytorch -o film_net_fp16.pt\n", 34 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/frame-interpolation-pytorch/film_net_fp32.pt -d /content/TANGO/frame-interpolation-pytorch -o film_net_fp32.pt\n", 35 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_ckpts/ckpt.pth -d /content/TANGO/datasets/cached_ckpts -o ckpt.pth\n", 36 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker1.pkl -d /content/TANGO/datasets/cached_graph/youtube_test -o speaker1.pkl\n", 37 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker7.pkl -d /content/TANGO/datasets/cached_graph/youtube_test -o speaker7.pkl\n", 38 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker8.pkl -d /content/TANGO/datasets/cached_graph/youtube_test -o speaker8.pkl\n", 39 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker9.pkl -d 
/content/TANGO/datasets/cached_graph/youtube_test -o speaker9.pkl\n", 40 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/show_oliver_test/Stupid_Watergate_-_Last_Week_Tonight_with_John_Oliver_HBO-FVFdsl29s_Q.mkv.pkl -d /content/TANGO/datasets/cached_graph/show_oliver_test -o Stupid_Watergate_-_Last_Week_Tonight_with_John_Oliver_HBO-FVFdsl29s_Q.mkv.pkl\n", 41 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker1.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker1.json\n", 42 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker7.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker7.json\n", 43 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker8.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker8.json\n", 44 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker9.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker9.json\n", 45 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/show_oliver_test/Stupid_Watergate_-_Last_Week_Tonight_with_John_Oliver_HBO-FVFdsl29s_Q.mkv.json -d /content/TANGO/datasets/data_json/show_oliver_test -o Stupid_Watergate_-_Last_Week_Tonight_with_John_Oliver_HBO-FVFdsl29s_Q.mkv.json\n", 46 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/example_female_voice_9_seconds.wav -d /content/TANGO/datasets/cached_audio -o example_female_voice_9_seconds.wav\n", 47 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/example_male_voice_9_seconds.wav -d /content/TANGO/datasets/cached_audio -o example_male_voice_9_seconds.wav\n", 48 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/1wrQ6Msp7wM_00-00-39.69_00-00-45.68.mp4 -d /content/TANGO/datasets/cached_audio -o 1wrQ6Msp7wM_00-00-39.69_00-00-45.68.mp4\n", 49 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/speaker8_jjRWaMCWs44_00-00-30.16_00-00-33.32.mp4 -d /content/TANGO/datasets/cached_audio -o speaker8_jjRWaMCWs44_00-00-30.16_00-00-33.32.mp4\n", 50 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/speaker7_iuYlGRnC7J8_00-00-0.00_00-00-3.25.mp4 -d /content/TANGO/datasets/cached_audio -o speaker7_iuYlGRnC7J8_00-00-0.00_00-00-3.25.mp4\n", 51 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4 -d /content/TANGO/datasets/cached_audio -o speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4\n", 52 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M 
https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/101099-00_18_09-00_18_19.mp4 -d /content/TANGO/datasets/cached_audio -o 101099-00_18_09-00_18_19.mp4" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "%cd /content/TANGO\n", 62 | "\n", 63 | "import os\n", 64 | "import gc\n", 65 | "import soundfile as sf\n", 66 | "import shutil\n", 67 | "import argparse\n", 68 | "from moviepy.tools import verbose_print\n", 69 | "from omegaconf import OmegaConf\n", 70 | "import random\n", 71 | "import numpy as np\n", 72 | "import json \n", 73 | "import librosa\n", 74 | "import emage.mertic\n", 75 | "from datetime import datetime\n", 76 | "from decord import VideoReader\n", 77 | "from PIL import Image\n", 78 | "import copy\n", 79 | "\n", 80 | "import importlib\n", 81 | "import torch\n", 82 | "import torch.nn as nn\n", 83 | "import torch.nn.functional as F\n", 84 | "from torch.optim import AdamW\n", 85 | "from torch.utils.data import DataLoader\n", 86 | "from torch.nn.parallel import DistributedDataParallel as DDP\n", 87 | "from tqdm import tqdm\n", 88 | "import smplx\n", 89 | "from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip\n", 90 | "import igraph\n", 91 | "\n", 92 | "# import emage\n", 93 | "import utils.rotation_conversions as rc\n", 94 | "from utils.video_io import save_videos_from_pil\n", 95 | "from utils.genextend_inference_utils import adjust_statistics_to_match_reference\n", 96 | "from create_graph import path_visualization, graph_pruning, get_motion_reps_tensor, path_visualization_v2\n", 97 | "\n", 98 | "def search_path_dp(graph, audio_low_np, audio_high_np, loop_penalty=0.1, top_k=1, search_mode=\"both\", continue_penalty=0.1):\n", 99 | " T = audio_low_np.shape[0] # Total time steps\n", 100 | " N = len(graph.vs) # Total number of nodes in the graph\n", 101 | "\n", 102 | " # Initialize DP tables\n", 103 | " min_cost = [{} for _ in range(T)] # min_cost[t][node_index] = list of tuples: (cost, prev_node_index, prev_tuple_index, non_continue_count, visited_nodes)\n", 104 | "\n", 105 | " # Initialize the first time step\n", 106 | " start_nodes = [v for v in graph.vs if v['previous'] is None or v['previous'] == -1]\n", 107 | " for node in start_nodes:\n", 108 | " node_index = node.index\n", 109 | " motion_low = node['motion_low'] # Shape: [C]\n", 110 | " motion_high = node['motion_high'] # Shape: [C]\n", 111 | "\n", 112 | " # Cost using cosine similarity\n", 113 | " if search_mode == \"both\":\n", 114 | " cost = 2 - (np.dot(audio_low_np[0], motion_low.T) + np.dot(audio_high_np[0], motion_high.T))\n", 115 | " elif search_mode == \"high_level\":\n", 116 | " cost = 1 - np.dot(audio_high_np[0], motion_high.T)\n", 117 | " elif search_mode == \"low_level\":\n", 118 | " cost = 1 - np.dot(audio_low_np[0], motion_low.T)\n", 119 | "\n", 120 | " visited_nodes = {node_index: 1} # Initialize visit count as a dictionary\n", 121 | "\n", 122 | " min_cost[0][node_index] = [ (cost, None, None, 0, visited_nodes) ] # Initialize with no predecessor and 0 non-continue count\n", 123 | "\n", 124 | " # DP over time steps\n", 125 | " for t in range(1, T):\n", 126 | " for node in graph.vs:\n", 127 | " node_index = node.index\n", 128 | " candidates = []\n", 129 | "\n", 130 | " # Incoming edges to the current node\n", 131 | " incoming_edges = graph.es.select(_to=node_index)\n", 132 | " for edge in incoming_edges:\n", 133 | " prev_node_index = edge.source\n", 134 | " edge_id = 
edge.index\n", 135 | " is_continue_edge = graph.es[edge_id]['is_continue']\n", 136 | " prev_node = graph.vs[prev_node_index]\n", 137 | " if prev_node_index in min_cost[t-1]:\n", 138 | " for tuple_index, (prev_cost, _, _, prev_non_continue_count, prev_visited) in enumerate(min_cost[t-1][prev_node_index]):\n", 139 | " # Loop punishment\n", 140 | " if node_index in prev_visited:\n", 141 | " loop_time = prev_visited[node_index] # Get the count of previous visits\n", 142 | " loop_cost = prev_cost + loop_penalty * np.exp(loop_time) # Apply exponential penalty\n", 143 | " new_visited = prev_visited.copy()\n", 144 | " new_visited[node_index] = loop_time + 1 # Increment visit count\n", 145 | " else:\n", 146 | " loop_cost = prev_cost\n", 147 | " new_visited = prev_visited.copy()\n", 148 | " new_visited[node_index] = 1 # Initialize visit count for the new node\n", 149 | "\n", 150 | " motion_low = node['motion_low'] # Shape: [C]\n", 151 | " motion_high = node['motion_high'] # Shape: [C]\n", 152 | "\n", 153 | " if search_mode == \"both\":\n", 154 | " cost_increment = 2 - (np.dot(audio_low_np[t], motion_low.T) + np.dot(audio_high_np[t], motion_high.T))\n", 155 | " elif search_mode == \"high_level\":\n", 156 | " cost_increment = 1 - np.dot(audio_high_np[t], motion_high.T)\n", 157 | " elif search_mode == \"low_level\":\n", 158 | " cost_increment = 1 - np.dot(audio_low_np[t], motion_low.T)\n", 159 | "\n", 160 | " # Check if the edge is \"is_continue\"\n", 161 | " if not is_continue_edge:\n", 162 | " non_continue_count = prev_non_continue_count + 1 # Increment the count of non-continue edges\n", 163 | " else:\n", 164 | " non_continue_count = prev_non_continue_count\n", 165 | "\n", 166 | " # Apply the penalty based on the square of the number of non-continuous edges\n", 167 | " continue_penalty_cost = continue_penalty * non_continue_count\n", 168 | "\n", 169 | " total_cost = loop_cost + cost_increment + continue_penalty_cost\n", 170 | "\n", 171 | " candidates.append( (total_cost, prev_node_index, tuple_index, non_continue_count, new_visited) )\n", 172 | "\n", 173 | " # Keep the top k candidates\n", 174 | " if candidates:\n", 175 | " # Sort candidates by total_cost\n", 176 | " candidates.sort(key=lambda x: x[0])\n", 177 | " # Keep top k\n", 178 | " min_cost[t][node_index] = candidates[:top_k]\n", 179 | " else:\n", 180 | " # No candidates, do nothing\n", 181 | " pass\n", 182 | "\n", 183 | " # Collect all possible end paths at time T-1\n", 184 | " end_candidates = []\n", 185 | " for node_index, tuples in min_cost[T-1].items():\n", 186 | " for tuple_index, (cost, _, _, _, _) in enumerate(tuples):\n", 187 | " end_candidates.append( (cost, node_index, tuple_index) )\n", 188 | "\n", 189 | " if not end_candidates:\n", 190 | " print(\"No valid path found.\")\n", 191 | " return [], []\n", 192 | "\n", 193 | " # Sort end candidates by cost\n", 194 | " end_candidates.sort(key=lambda x: x[0])\n", 195 | "\n", 196 | " # Keep top k paths\n", 197 | " top_k_paths_info = end_candidates[:top_k]\n", 198 | "\n", 199 | " # Reconstruct the paths\n", 200 | " optimal_paths = []\n", 201 | " is_continue_lists = []\n", 202 | " for final_cost, node_index, tuple_index in top_k_paths_info:\n", 203 | " optimal_path_indices = []\n", 204 | " current_node_index = node_index\n", 205 | " current_tuple_index = tuple_index\n", 206 | " for t in range(T-1, -1, -1):\n", 207 | " optimal_path_indices.append(current_node_index)\n", 208 | " tuple_data = min_cost[t][current_node_index][current_tuple_index]\n", 209 | " _, prev_node_index, prev_tuple_index, 
_, _ = tuple_data\n", 210 | " current_node_index = prev_node_index\n", 211 | " current_tuple_index = prev_tuple_index\n", 212 | " if current_node_index is None:\n", 213 | " break # Reached the start node\n", 214 | " optimal_path_indices = optimal_path_indices[::-1] # Reverse to get correct order\n", 215 | " optimal_path = [graph.vs[idx] for idx in optimal_path_indices]\n", 216 | " optimal_paths.append(optimal_path)\n", 217 | "\n", 218 | " # Extract continuity information\n", 219 | " is_continue = []\n", 220 | " for i in range(len(optimal_path) - 1):\n", 221 | " edge_id = graph.get_eid(optimal_path[i].index, optimal_path[i + 1].index)\n", 222 | " is_cont = graph.es[edge_id]['is_continue']\n", 223 | " is_continue.append(is_cont)\n", 224 | " is_continue_lists.append(is_continue)\n", 225 | "\n", 226 | " print(\"Top {} Paths:\".format(len(optimal_paths)))\n", 227 | " for i, path in enumerate(optimal_paths):\n", 228 | " path_indices = [node.index for node in path]\n", 229 | " print(\"Path {}: Cost: {}, Nodes: {}\".format(i+1, top_k_paths_info[i][0], path_indices))\n", 230 | "\n", 231 | " return optimal_paths, is_continue_lists\n", 232 | "\n", 233 | "\n", 234 | "def test_fn(model, device, iteration, candidate_json_path, test_path, cfg, audio_path, **kwargs):\n", 235 | " torch.set_grad_enabled(False)\n", 236 | " pool_path = candidate_json_path.replace(\"data_json\", \"cached_graph\").replace(\".json\", \".pkl\")\n", 237 | " print(pool_path)\n", 238 | " graph = igraph.Graph.Read_Pickle(fname=pool_path)\n", 239 | " # print(len(graph.vs))\n", 240 | "\n", 241 | " save_dir = os.path.join(test_path, f\"retrieved_motions_{iteration}\")\n", 242 | " os.makedirs(save_dir, exist_ok=True)\n", 243 | "\n", 244 | " actual_model = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model\n", 245 | " actual_model.eval()\n", 246 | "\n", 247 | " # with open(candidate_json_path, 'r') as f:\n", 248 | " # candidate_data = json.load(f)\n", 249 | " all_motions = {}\n", 250 | " for i, node in enumerate(graph.vs):\n", 251 | " if all_motions.get(node[\"name\"]) is None:\n", 252 | " all_motions[node[\"name\"]] = [node[\"axis_angle\"].reshape(-1)]\n", 253 | " else:\n", 254 | " all_motions[node[\"name\"]].append(node[\"axis_angle\"].reshape(-1))\n", 255 | " for k, v in all_motions.items():\n", 256 | " all_motions[k] = np.stack(v) # T, J*3\n", 257 | " # print(k, all_motions[k].shape)\n", 258 | " \n", 259 | " window_size = cfg.data.pose_length\n", 260 | " motion_high_all = []\n", 261 | " motion_low_all = []\n", 262 | " for k, v in all_motions.items():\n", 263 | " motion_tensor = torch.from_numpy(v).float().to(device).unsqueeze(0)\n", 264 | " _, t, _ = motion_tensor.shape\n", 265 | " \n", 266 | " if t >= window_size:\n", 267 | " num_chunks = t // window_size\n", 268 | " motion_high_list = []\n", 269 | " motion_low_list = []\n", 270 | "\n", 271 | " for i in range(num_chunks):\n", 272 | " start_idx = i * window_size\n", 273 | " end_idx = start_idx + window_size\n", 274 | " motion_slice = motion_tensor[:, start_idx:end_idx, :]\n", 275 | " \n", 276 | " motion_features = actual_model.get_motion_features(motion_slice)\n", 277 | " \n", 278 | " motion_low = motion_features[\"motion_low\"].cpu().numpy()\n", 279 | " motion_high = motion_features[\"motion_cls\"].unsqueeze(0).repeat(1, motion_low.shape[1], 1).cpu().numpy()\n", 280 | "\n", 281 | " motion_high_list.append(motion_high[0])\n", 282 | " motion_low_list.append(motion_low[0])\n", 283 | "\n", 284 | " remain_length = t % window_size\n", 285 | " if 
remain_length > 0:\n", 286 | " start_idx = t - window_size\n", 287 | " motion_slice = motion_tensor[:, start_idx:, :]\n", 288 | "\n", 289 | " motion_features = actual_model.get_motion_features(motion_slice)\n", 290 | " # motion_high = motion_features[\"motion_high_weight\"].cpu().numpy()\n", 291 | " motion_low = motion_features[\"motion_low\"].cpu().numpy()\n", 292 | " motion_high = motion_features[\"motion_cls\"].unsqueeze(0).repeat(1, motion_low.shape[1], 1).cpu().numpy()\n", 293 | "\n", 294 | " motion_high_list.append(motion_high[0][-remain_length:])\n", 295 | " motion_low_list.append(motion_low[0][-remain_length:])\n", 296 | "\n", 297 | " motion_high_all.append(np.concatenate(motion_high_list, axis=0))\n", 298 | " motion_low_all.append(np.concatenate(motion_low_list, axis=0))\n", 299 | "\n", 300 | " else: # t < window_size:\n", 301 | " gap = window_size - t\n", 302 | " motion_slice = torch.cat([motion_tensor, torch.zeros((motion_tensor.shape[0], gap, motion_tensor.shape[2])).to(motion_tensor.device)], 1)\n", 303 | " motion_features = actual_model.get_motion_features(motion_slice)\n", 304 | " # motion_high = motion_features[\"motion_high_weight\"].cpu().numpy()\n", 305 | " motion_low = motion_features[\"motion_low\"].cpu().numpy()\n", 306 | " motion_high = motion_features[\"motion_cls\"].unsqueeze(0).repeat(1, motion_low.shape[1], 1).cpu().numpy()\n", 307 | "\n", 308 | " motion_high_all.append(motion_high[0][:t])\n", 309 | " motion_low_all.append(motion_low[0][:t])\n", 310 | " \n", 311 | " motion_high_all = np.concatenate(motion_high_all, axis=0)\n", 312 | " motion_low_all = np.concatenate(motion_low_all, axis=0)\n", 313 | " # print(motion_high_all.shape, motion_low_all.shape, len(graph.vs))\n", 314 | " motion_low_all = motion_low_all / np.linalg.norm(motion_low_all, axis=1, keepdims=True)\n", 315 | " motion_high_all = motion_high_all / np.linalg.norm(motion_high_all, axis=1, keepdims=True)\n", 316 | " assert motion_high_all.shape[0] == len(graph.vs)\n", 317 | " assert motion_low_all.shape[0] == len(graph.vs)\n", 318 | " \n", 319 | " for i, node in enumerate(graph.vs):\n", 320 | " node[\"motion_high\"] = motion_high_all[i]\n", 321 | " node[\"motion_low\"] = motion_low_all[i]\n", 322 | "\n", 323 | " graph = graph_pruning(graph)\n", 324 | " # for gradio, use a subgraph\n", 325 | " if len(graph.vs) > 1800:\n", 326 | " gap = len(graph.vs) - 1800\n", 327 | " start_d = random.randint(0, 1800)\n", 328 | " graph.delete_vertices(range(start_d, start_d + gap))\n", 329 | " ascc_2 = graph.clusters(mode=\"STRONG\")\n", 330 | " graph = ascc_2.giant()\n", 331 | "\n", 332 | " # drop the id of gt\n", 333 | " idx = 0\n", 334 | " audio_waveform, sr = librosa.load(audio_path)\n", 335 | " audio_waveform = librosa.resample(audio_waveform, orig_sr=sr, target_sr=cfg.data.audio_sr)\n", 336 | " audio_tensor = torch.from_numpy(audio_waveform).float().to(device).unsqueeze(0)\n", 337 | " \n", 338 | " target_length = audio_tensor.shape[1] // cfg.data.audio_sr * 30\n", 339 | " window_size = int(cfg.data.audio_sr * (cfg.data.pose_length / 30))\n", 340 | " _, t = audio_tensor.shape\n", 341 | " audio_low_list = []\n", 342 | " audio_high_list = []\n", 343 | "\n", 344 | " if t >= window_size:\n", 345 | " num_chunks = t // window_size\n", 346 | " # print(num_chunks, t % window_size)\n", 347 | " for i in range(num_chunks):\n", 348 | " start_idx = i * window_size\n", 349 | " end_idx = start_idx + window_size\n", 350 | " # print(start_idx, end_idx, window_size)\n", 351 | " audio_slice = audio_tensor[:, start_idx:end_idx]\n", 
352 | "\n", 353 | " model_out_candidates = actual_model.get_audio_features(audio_slice)\n", 354 | " audio_low = model_out_candidates[\"audio_low\"]\n", 355 | " # audio_high = model_out_candidates[\"audio_high_weight\"]\n", 356 | " audio_high = model_out_candidates[\"audio_cls\"].unsqueeze(0).repeat(1, audio_low.shape[1], 1)\n", 357 | " # print(audio_low.shape, audio_high.shape)\n", 358 | "\n", 359 | " audio_low = F.normalize(audio_low, dim=2)[0].cpu().numpy()\n", 360 | " audio_high = F.normalize(audio_high, dim=2)[0].cpu().numpy()\n", 361 | "\n", 362 | " audio_low_list.append(audio_low)\n", 363 | " audio_high_list.append(audio_high)\n", 364 | " # print(audio_low.shape, audio_high.shape)\n", 365 | " \n", 366 | "\n", 367 | " remain_length = t % window_size\n", 368 | " if remain_length > 1:\n", 369 | " start_idx = t - window_size\n", 370 | " audio_slice = audio_tensor[:, start_idx:]\n", 371 | "\n", 372 | " model_out_candidates = actual_model.get_audio_features(audio_slice)\n", 373 | " audio_low = model_out_candidates[\"audio_low\"]\n", 374 | " # audio_high = model_out_candidates[\"audio_high_weight\"]\n", 375 | " audio_high = model_out_candidates[\"audio_cls\"].unsqueeze(0).repeat(1, audio_low.shape[1], 1)\n", 376 | " \n", 377 | " gap = target_length - np.concatenate(audio_low_list, axis=0).shape[1]\n", 378 | " audio_low = F.normalize(audio_low, dim=2)[0][-gap:].cpu().numpy()\n", 379 | " audio_high = F.normalize(audio_high, dim=2)[0][-gap:].cpu().numpy()\n", 380 | " \n", 381 | " # print(audio_low.shape, audio_high.shape)\n", 382 | " audio_low_list.append(audio_low)\n", 383 | " audio_high_list.append(audio_high)\n", 384 | " else:\n", 385 | " gap = window_size - t\n", 386 | " audio_slice = audio_tensor \n", 387 | " model_out_candidates = actual_model.get_audio_features(audio_slice)\n", 388 | " audio_low = model_out_candidates[\"audio_low\"]\n", 389 | " # audio_high = model_out_candidates[\"audio_high_weight\"]\n", 390 | " audio_high = model_out_candidates[\"audio_cls\"].unsqueeze(0).repeat(1, audio_low.shape[1], 1)\n", 391 | " \n", 392 | " gap = target_length - np.concatenate(audio_low_list, axis=0).shape[1]\n", 393 | " audio_low = F.normalize(audio_low, dim=2)[0][:gap].cpu().numpy()\n", 394 | " audio_high = F.normalize(audio_high, dim=2)[0][:gap].cpu().numpy()\n", 395 | " audio_low_list.append(audio_low)\n", 396 | " audio_high_list.append(audio_high)\n", 397 | " \n", 398 | " audio_low_all = np.concatenate(audio_low_list, axis=0)\n", 399 | " audio_high_all = np.concatenate(audio_high_list, axis=0)\n", 400 | " path_list, is_continue_list = search_path_dp(graph, audio_low_all, audio_high_all, top_k=1, search_mode=\"both\")\n", 401 | " \n", 402 | " res_motion = []\n", 403 | " counter = 0\n", 404 | " for path, is_continue in zip(path_list, is_continue_list):\n", 405 | " # print(path)\n", 406 | " # res_motion_current = path_visualization(\n", 407 | " # graph, path, is_continue, os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\"), audio_path=audio_path, return_motion=True, verbose_continue=True\n", 408 | " # )\n", 409 | " res_motion_current = path_visualization_v2(\n", 410 | " graph, path, is_continue, os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\"), audio_path=audio_path, return_motion=True, verbose_continue=True\n", 411 | " )\n", 412 | "\n", 413 | " video_temp_path = os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\")\n", 414 | " \n", 415 | " video_reader = VideoReader(video_temp_path)\n", 416 | " video_np = []\n", 417 | " for i in range(len(video_reader)):\n", 418 | 
" if i == 0: continue\n", 419 | " video_frame = video_reader[i].asnumpy()\n", 420 | " video_np.append(Image.fromarray(video_frame))\n", 421 | " adjusted_video_pil = adjust_statistics_to_match_reference([video_np])\n", 422 | " save_videos_from_pil(adjusted_video_pil[0], os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\"), fps=30, bitrate=2000000)\n", 423 | "\n", 424 | "\n", 425 | " audio_temp_path = audio_path\n", 426 | " lipsync_output_path = os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\")\n", 427 | " checkpoint_path = './Wav2Lip/checkpoints/wav2lip_gan.pth' # Update this path to your Wav2Lip checkpoint\n", 428 | " os.system(f'python ./Wav2Lip/inference.py --checkpoint_path {checkpoint_path} --face {video_temp_path} --audio {audio_temp_path} --outfile {lipsync_output_path} --nosmooth')\n", 429 | "\n", 430 | " res_motion.append(res_motion_current)\n", 431 | " np.savez(os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.npz\"), motion=res_motion_current)\n", 432 | " \n", 433 | " start_node = path[1].index\n", 434 | " end_node = start_node + 100\n", 435 | " print(f\"delete gt-nodes {start_node}, {end_node}\")\n", 436 | " nodes_to_delete = list(range(start_node, end_node))\n", 437 | " graph.delete_vertices(nodes_to_delete)\n", 438 | " graph = graph_pruning(graph)\n", 439 | " path_list, is_continue_list = search_path_dp(graph, audio_low_all, audio_high_all, top_k=1, search_mode=\"both\")\n", 440 | " res_motion = []\n", 441 | " counter = 1\n", 442 | " for path, is_continue in zip(path_list, is_continue_list):\n", 443 | " res_motion_current = path_visualization(\n", 444 | " graph, path, is_continue, os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\"), audio_path=audio_path, return_motion=True, verbose_continue=True\n", 445 | " )\n", 446 | " video_temp_path = os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\")\n", 447 | " \n", 448 | " video_reader = VideoReader(video_temp_path)\n", 449 | " video_np = []\n", 450 | " for i in range(len(video_reader)):\n", 451 | " if i == 0: continue\n", 452 | " video_frame = video_reader[i].asnumpy()\n", 453 | " video_np.append(Image.fromarray(video_frame))\n", 454 | " adjusted_video_pil = adjust_statistics_to_match_reference([video_np])\n", 455 | " save_videos_from_pil(adjusted_video_pil[0], os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\"), fps=30, bitrate=2000000)\n", 456 | "\n", 457 | "\n", 458 | " audio_temp_path = audio_path\n", 459 | " lipsync_output_path = os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\")\n", 460 | " checkpoint_path = './Wav2Lip/checkpoints/wav2lip_gan.pth' # Update this path to your Wav2Lip checkpoint\n", 461 | " os.system(f'python ./Wav2Lip/inference.py --checkpoint_path {checkpoint_path} --face {video_temp_path} --audio {audio_temp_path} --outfile {lipsync_output_path} --nosmooth')\n", 462 | " res_motion.append(res_motion_current)\n", 463 | " np.savez(os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.npz\"), motion=res_motion_current)\n", 464 | " \n", 465 | " result = [\n", 466 | " os.path.join(save_dir, f\"audio_{idx}_retri_0.mp4\"),\n", 467 | " os.path.join(save_dir, f\"audio_{idx}_retri_1.mp4\"),\n", 468 | " os.path.join(save_dir, f\"audio_{idx}_retri_0.npz\"),\n", 469 | " os.path.join(save_dir, f\"audio_{idx}_retri_1.npz\")\n", 470 | " ]\n", 471 | " return result\n", 472 | "\n", 473 | "\n", 474 | "def init_class(module_name, class_name, config, **kwargs):\n", 475 | " module = importlib.import_module(module_name)\n", 476 | " model_class = getattr(module, 
class_name)\n", 477 | " instance = model_class(config, **kwargs)\n", 478 | " return instance\n", 479 | "\n", 480 | "def seed_everything(seed):\n", 481 | " random.seed(seed)\n", 482 | " np.random.seed(seed)\n", 483 | " torch.manual_seed(seed)\n", 484 | " torch.cuda.manual_seed_all(seed)\n", 485 | "\n", 486 | "def prepare_all(yaml_name):\n", 487 | " if yaml_name.endswith(\".yaml\"):\n", 488 | " config = OmegaConf.load(yaml_name)\n", 489 | " config.exp_name = os.path.basename(yaml_name)[:-5]\n", 490 | " else:\n", 491 | " raise ValueError(\"Unsupported config file format. Only .yaml files are allowed.\")\n", 492 | " save_dir = os.path.join(config.output_dir, config.exp_name)\n", 493 | " os.makedirs(save_dir, exist_ok=True)\n", 494 | " return config\n", 495 | "\n", 496 | "def save_first_10_seconds(video_path, output_path=\"./save_video.mp4\"):\n", 497 | " import cv2\n", 498 | " cap = cv2.VideoCapture(video_path)\n", 499 | " \n", 500 | " if not cap.isOpened():\n", 501 | " return\n", 502 | "\n", 503 | " fps = int(cap.get(cv2.CAP_PROP_FPS))\n", 504 | " width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n", 505 | " height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n", 506 | "\n", 507 | " fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n", 508 | " out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))\n", 509 | "\n", 510 | " frames_to_save = fps * 10\n", 511 | " frame_count = 0\n", 512 | " \n", 513 | " while cap.isOpened() and frame_count < frames_to_save:\n", 514 | " ret, frame = cap.read()\n", 515 | " if not ret:\n", 516 | " break\n", 517 | " out.write(frame)\n", 518 | " frame_count += 1\n", 519 | "\n", 520 | " cap.release()\n", 521 | " out.release()\n", 522 | "\n", 523 | "\n", 524 | "character_name_to_yaml = {\n", 525 | " \"speaker8_jjRWaMCWs44_00-00-30.16_00-00-33.32.mp4\": \"./datasets/data_json/youtube_test/speaker8.json\",\n", 526 | " \"speaker7_iuYlGRnC7J8_00-00-0.00_00-00-3.25.mp4\": \"./datasets/data_json/youtube_test/speaker7.json\",\n", 527 | " \"speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4\": \"./datasets/data_json/youtube_test/speaker9.json\",\n", 528 | " \"1wrQ6Msp7wM_00-00-39.69_00-00-45.68.mp4\": \"./datasets/data_json/youtube_test/speaker1.json\",\n", 529 | " \"101099-00_18_09-00_18_19.mp4\": \"./datasets/data_json/show_oliver_test/Stupid_Watergate_-_Last_Week_Tonight_with_John_Oliver_HBO-FVFdsl29s_Q.mkv.json\",\n", 530 | "}\n", 531 | "\n", 532 | "def tango(audio_path, character_name, seed, create_graph=False, video_folder_path=None):\n", 533 | " cfg = prepare_all(\"./configs/gradio.yaml\")\n", 534 | " cfg.seed = seed\n", 535 | " seed_everything(cfg.seed)\n", 536 | " experiment_ckpt_dir = experiment_log_dir = os.path.join(cfg.output_dir, cfg.exp_name)\n", 537 | " saved_audio_path = \"./saved_audio.wav\"\n", 538 | "\n", 539 | " # Load audio from the specified path\n", 540 | " audio_waveform, sample_rate = librosa.load(audio_path, sr=None) # Load with original sample rate\n", 541 | " sf.write(saved_audio_path, audio_waveform, sample_rate)\n", 542 | "\n", 543 | " # Resample the audio to 16000 Hz\n", 544 | " resampled_audio = librosa.resample(audio_waveform, orig_sr=sample_rate, target_sr=16000)\n", 545 | " required_length = int(16000 * (128 / 30)) * 2\n", 546 | " resampled_audio = resampled_audio[:required_length]\n", 547 | " sf.write(saved_audio_path, resampled_audio, 16000)\n", 548 | " audio_path = saved_audio_path # Update the audio_path to point to the saved resampled audio\n", 549 | "\n", 550 | " yaml_name = character_name_to_yaml.get(character_name.split(\"/\")[-1], 
\"./datasets/data_json/youtube_test/speaker1.json\")\n", 551 | " cfg.data.test_meta_paths = yaml_name\n", 552 | " print(yaml_name, character_name.split(\"/\")[-1])\n", 553 | "\n", 554 | " if character_name.split(\"/\")[-1] not in character_name_to_yaml.keys():\n", 555 | " create_graph = True\n", 556 | " os.makedirs(\"./outputs/tmpvideo/\", exist_ok=True)\n", 557 | " save_first_10_seconds(character_name, \"./outputs/tmpvideo/save_video.mp4\")\n", 558 | "\n", 559 | " if create_graph:\n", 560 | " video_folder_path = \"./outputs/tmpvideo/\"\n", 561 | " data_save_path = \"./outputs/tmpdata/\"\n", 562 | " json_save_path = \"./outputs/save_video.json\"\n", 563 | " graph_save_path = \"./outputs/save_video.pkl\"\n", 564 | " os.system(f\"cd ./SMPLer-X/ && python app.py --video_folder_path {video_folder_path} --data_save_path {data_save_path} --json_save_path {json_save_path} && cd ..\")\n", 565 | " os.system(f\"python ./create_graph.py --json_save_path {json_save_path} --graph_save_path {graph_save_path}\") \n", 566 | " cfg.data.test_meta_paths = json_save_path\n", 567 | "\n", 568 | " smplx_model = smplx.create(\n", 569 | " \"./emage/smplx_models/\", \n", 570 | " model_type='smplx',\n", 571 | " gender='NEUTRAL_2020', \n", 572 | " use_face_contour=False,\n", 573 | " num_betas=300,\n", 574 | " num_expression_coeffs=100, \n", 575 | " ext='npz',\n", 576 | " use_pca=False,\n", 577 | " )\n", 578 | " model = init_class(cfg.model.name_pyfile, cfg.model.class_name, cfg)\n", 579 | " for param in model.parameters():\n", 580 | " param.requires_grad = False\n", 581 | " model.smplx_model = smplx_model\n", 582 | " model.get_motion_reps = get_motion_reps_tensor\n", 583 | " \n", 584 | " local_rank = 0 \n", 585 | " torch.cuda.set_device(local_rank)\n", 586 | " device = torch.device(\"cuda\", local_rank)\n", 587 | "\n", 588 | " smplx_model = smplx_model.to(device).eval()\n", 589 | " model = model.to(device)\n", 590 | " model.smplx_model = model.smplx_model.to(device)\n", 591 | "\n", 592 | " checkpoint_path = \"./datasets/cached_ckpts/ckpt.pth\"\n", 593 | " checkpoint = torch.load(checkpoint_path)\n", 594 | " state_dict = checkpoint['model_state_dict']\n", 595 | " new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}\n", 596 | " model.load_state_dict(new_state_dict, strict=False)\n", 597 | " \n", 598 | " test_path = os.path.join(experiment_ckpt_dir, f\"test_{0}\")\n", 599 | " os.makedirs(test_path, exist_ok=True)\n", 600 | " result = test_fn(model, device, 0, cfg.data.test_meta_paths, test_path, cfg, audio_path)\n", 601 | " gc.collect()\n", 602 | " torch.cuda.empty_cache()\n", 603 | " return result" 604 | ] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "execution_count": null, 609 | "metadata": {}, 610 | "outputs": [], 611 | "source": [ 612 | "%cd /content/TANGO\n", 613 | "seed = 42\n", 614 | "wav = \"./datasets/cached_audio/example_male_voice_9_seconds.wav\"\n", 615 | "mp4 = \"./datasets/cached_audio/speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4\"\n", 616 | "video_output_1, video_output_2, file_output_1, file_output_2 = tango(wav, mp4, seed)\n", 617 | "\n", 618 | "from IPython.display import Video\n", 619 | "# Video(video_output_1, embed=True)\n", 620 | "Video(video_output_2, embed=True)" 621 | ] 622 | } 623 | ], 624 | "metadata": { 625 | "accelerator": "GPU", 626 | "colab": { 627 | "gpuType": "T4", 628 | "provenance": [] 629 | }, 630 | "kernelspec": { 631 | "display_name": "Python 3", 632 | "name": "python3" 633 | }, 634 | "language_info": { 635 | "name": "python" 636 | } 637 | }, 638 | 
"nbformat": 4, 639 | "nbformat_minor": 0 640 | } 641 | --------------------------------------------------------------------------------