├── README.md
├── TANGO_gradio_jupyter.ipynb
└── TANGO_jupyter.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | 🐣 Please follow me for new updates https://twitter.com/camenduru
2 | 🔥 Please join our discord server https://discord.gg/k5BwmmvJJU
3 | 🥳 Please join my patreon community https://patreon.com/camenduru
4 |
5 | 🚦 Prototype Model 🚦
6 |
7 | ### 🍊 Jupyter Notebook
8 |
9 | | Colab | Info |
10 | | --- | --- |
11 | | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/TANGO-jupyter/blob/main/TANGO_jupyter.ipynb) | TANGO_jupyter |
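12 | | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/TANGO-jupyter/blob/main/TANGO_gradio_jupyter.ipynb) | TANGO_gradio_jupyter |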
12 |
13 | ### 🧬 Code
14 | https://github.com/CyberAgentAILab/TANGO
15 |
16 | ### 📄 Paper
17 | https://arxiv.org/abs/2410.04221
18 |
19 | ### 🌐 Page
20 | https://pantomatrix.github.io/TANGO/
21 |
22 | ### 🖼 Output
23 |
24 | Input:
25 |
26 | https://github.com/user-attachments/assets/5660da40-af7e-46b5-ba2c-6b1bba9bed67
27 |
28 | Output:
29 |
30 | https://github.com/user-attachments/assets/40486269-7a90-4c03-9904-43113f0f1281
31 |
32 | Input:
33 |
34 | https://github.com/user-attachments/assets/aa5c59c2-46e7-4023-839c-d62d8d771c81
35 |
36 | Output:
37 |
38 | https://github.com/user-attachments/assets/d2665857-2dfc-46ba-98e5-557cb2331fbd
39 |
40 | Output:
41 |
42 | https://github.com/user-attachments/assets/35a50632-6d1e-4bdf-8cb0-df2c43083c8a
43 |
44 | ### 🏢 Sponsor
45 | https://runpod.io
46 |
--------------------------------------------------------------------------------
/TANGO_gradio_jupyter.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github"
7 | },
8 | "source": [
9 | "[](https://colab.research.google.com/github/camenduru/TANGO-jupyter/blob/main/TANGO_gradio_jupyter.ipynb)"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "id": "VjYy0F2gZIPR"
17 | },
18 | "outputs": [],
19 | "source": [
20 | "%cd /content\n",
21 | "!GIT_LFS_SKIP_SMUDGE=1 git clone -b dev https://github.com/camenduru/TANGO-hf /content/TANGO\n",
22 | "\n",
23 | "!pip install gradio omegaconf wget decord smplx igraph av\n",
24 | "!pip install git+https://github.com/elliottzheng/batch-face\n",
25 | "\n",
26 | "!apt install -y -qq aria2\n",
27 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/emage/smplx_models/smplx/SMPLX_NEUTRAL_2020.npz -d /content/TANGO/emage/smplx_models/smplx -o SMPLX_NEUTRAL_2020.npz\n",
28 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/emage/AESKConv_240_100.bin -d /content/TANGO/emage -o AESKConv_240_100.bin\n",
29 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/emage/mean_vel_smplxflame_30.npy -d /content/TANGO/emage -o mean_vel_smplxflame_30.npy\n",
30 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/Wav2Lip/checkpoints/mobilenet.pth -d /content/TANGO/Wav2Lip/checkpoints -o mobilenet.pth\n",
31 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/Wav2Lip/checkpoints/resnet50.pth -d /content/TANGO/Wav2Lip/checkpoints -o resnet50.pth\n",
32 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/Wav2Lip/checkpoints/wav2lip_gan.pth -d /content/TANGO/Wav2Lip/checkpoints -o wav2lip_gan.pth\n",
33 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/frame-interpolation-pytorch/film_net_fp16.pt -d /content/TANGO/frame-interpolation-pytorch -o film_net_fp16.pt\n",
34 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/frame-interpolation-pytorch/film_net_fp32.pt -d /content/TANGO/frame-interpolation-pytorch -o film_net_fp32.pt\n",
35 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_ckpts/ckpt.pth -d /content/TANGO/datasets/cached_ckpts -o ckpt.pth\n",
36 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker1.pkl -d /content/TANGO/datasets/cached_graph/youtube_test -o speaker1.pkl\n",
37 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker7.pkl -d /content/TANGO/datasets/cached_graph/youtube_test -o speaker7.pkl\n",
38 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker8.pkl -d /content/TANGO/datasets/cached_graph/youtube_test -o speaker8.pkl\n",
39 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker9.pkl -d /content/TANGO/datasets/cached_graph/youtube_test -o speaker9.pkl\n",
40 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker1.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker1.json\n",
41 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker7.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker7.json\n",
42 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker8.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker8.json\n",
43 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker9.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker9.json\n",
44 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/example_female_voice_9_seconds.wav -d /content/TANGO/datasets/cached_audio -o example_female_voice_9_seconds.wav\n",
45 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/example_male_voice_9_seconds.wav -d /content/TANGO/datasets/cached_audio -o example_male_voice_9_seconds.wav\n",
46 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/1wrQ6Msp7wM_00-00-39.69_00-00-45.68.mp4 -d /content/TANGO/datasets/cached_audio -o 1wrQ6Msp7wM_00-00-39.69_00-00-45.68.mp4\n",
47 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/speaker8_jjRWaMCWs44_00-00-30.16_00-00-33.32.mp4 -d /content/TANGO/datasets/cached_audio -o speaker8_jjRWaMCWs44_00-00-30.16_00-00-33.32.mp4\n",
48 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/speaker7_iuYlGRnC7J8_00-00-0.00_00-00-3.25.mp4 -d /content/TANGO/datasets/cached_audio -o speaker7_iuYlGRnC7J8_00-00-0.00_00-00-3.25.mp4\n",
49 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4 -d /content/TANGO/datasets/cached_audio -o speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4\n",
50 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/101099-00_18_09-00_18_19.mp4 -d /content/TANGO/datasets/cached_audio -o 101099-00_18_09-00_18_19.mp4\n",
51 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo0.mp4 -d /content/TANGO/datasets/cached_audio -o demo0.mp4\n",
52 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo1.mp4 -d /content/TANGO/datasets/cached_audio -o demo1.mp4\n",
53 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo2.mp4 -d /content/TANGO/datasets/cached_audio -o demo2.mp4\n",
54 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo3.mp4 -d /content/TANGO/datasets/cached_audio -o demo3.mp4\n",
55 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo4.mp4 -d /content/TANGO/datasets/cached_audio -o demo4.mp4\n",
56 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo5.mp4 -d /content/TANGO/datasets/cached_audio -o demo5.mp4\n",
57 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo6.mp4 -d /content/TANGO/datasets/cached_audio -o demo6.mp4\n",
58 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo7.mp4 -d /content/TANGO/datasets/cached_audio -o demo7.mp4\n",
59 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo8.mp4 -d /content/TANGO/datasets/cached_audio -o demo8.mp4\n",
60 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/demo9.mp4 -d /content/TANGO/datasets/cached_audio -o demo9.mp4\n",
61 | "\n",
62 | "%cd /content/TANGO\n",
63 | "!python app.py"
64 | ]
65 | }
66 | ],
67 | "metadata": {
68 | "accelerator": "GPU",
69 | "colab": {
70 | "gpuType": "T4",
71 | "provenance": []
72 | },
73 | "kernelspec": {
74 | "display_name": "Python 3",
75 | "name": "python3"
76 | },
77 | "language_info": {
78 | "name": "python"
79 | }
80 | },
81 | "nbformat": 4,
82 | "nbformat_minor": 0
83 | }
84 |
--------------------------------------------------------------------------------
/TANGO_jupyter.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github"
7 | },
8 | "source": [
9 | "[](https://colab.research.google.com/github/camenduru/TANGO-jupyter/blob/main/TANGO_jupyter.ipynb)"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "id": "VjYy0F2gZIPR"
17 | },
18 | "outputs": [],
19 | "source": [
20 | "%cd /content\n",
21 | "!GIT_LFS_SKIP_SMUDGE=1 git clone -b dev https://github.com/camenduru/TANGO-hf /content/TANGO\n",
22 | "\n",
23 | "!pip install omegaconf wget decord smplx igraph av\n",
24 | "!pip install git+https://github.com/elliottzheng/batch-face\n",
25 | "\n",
26 | "!apt install -y -qq aria2\n",
27 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/emage/smplx_models/smplx/SMPLX_NEUTRAL_2020.npz -d /content/TANGO/emage/smplx_models/smplx -o SMPLX_NEUTRAL_2020.npz\n",
28 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/emage/AESKConv_240_100.bin -d /content/TANGO/emage -o AESKConv_240_100.bin\n",
29 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/emage/mean_vel_smplxflame_30.npy -d /content/TANGO/emage -o mean_vel_smplxflame_30.npy\n",
30 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/Wav2Lip/checkpoints/mobilenet.pth -d /content/TANGO/Wav2Lip/checkpoints -o mobilenet.pth\n",
31 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/Wav2Lip/checkpoints/resnet50.pth -d /content/TANGO/Wav2Lip/checkpoints -o resnet50.pth\n",
32 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/Wav2Lip/checkpoints/wav2lip_gan.pth -d /content/TANGO/Wav2Lip/checkpoints -o wav2lip_gan.pth\n",
33 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/frame-interpolation-pytorch/film_net_fp16.pt -d /content/TANGO/frame-interpolation-pytorch -o film_net_fp16.pt\n",
34 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/frame-interpolation-pytorch/film_net_fp32.pt -d /content/TANGO/frame-interpolation-pytorch -o film_net_fp32.pt\n",
35 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_ckpts/ckpt.pth -d /content/TANGO/datasets/cached_ckpts -o ckpt.pth\n",
36 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker1.pkl -d /content/TANGO/datasets/cached_graph/youtube_test -o speaker1.pkl\n",
37 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker7.pkl -d /content/TANGO/datasets/cached_graph/youtube_test -o speaker7.pkl\n",
38 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker8.pkl -d /content/TANGO/datasets/cached_graph/youtube_test -o speaker8.pkl\n",
39 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/youtube_test/speaker9.pkl -d /content/TANGO/datasets/cached_graph/youtube_test -o speaker9.pkl\n",
40 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_graph/show_oliver_test/Stupid_Watergate_-_Last_Week_Tonight_with_John_Oliver_HBO-FVFdsl29s_Q.mkv.pkl -d /content/TANGO/datasets/cached_graph/show_oliver_test -o Stupid_Watergate_-_Last_Week_Tonight_with_John_Oliver_HBO-FVFdsl29s_Q.mkv.pkl\n",
41 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker1.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker1.json\n",
42 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker7.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker7.json\n",
43 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker8.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker8.json\n",
44 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/youtube_test/speaker9.json -d /content/TANGO/datasets/data_json/youtube_test -o speaker9.json\n",
45 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/data_json/show_oliver_test/Stupid_Watergate_-_Last_Week_Tonight_with_John_Oliver_HBO-FVFdsl29s_Q.mkv.json -d /content/TANGO/datasets/data_json/show_oliver_test -o Stupid_Watergate_-_Last_Week_Tonight_with_John_Oliver_HBO-FVFdsl29s_Q.mkv.json\n",
46 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/example_female_voice_9_seconds.wav -d /content/TANGO/datasets/cached_audio -o example_female_voice_9_seconds.wav\n",
47 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/example_male_voice_9_seconds.wav -d /content/TANGO/datasets/cached_audio -o example_male_voice_9_seconds.wav\n",
48 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/1wrQ6Msp7wM_00-00-39.69_00-00-45.68.mp4 -d /content/TANGO/datasets/cached_audio -o 1wrQ6Msp7wM_00-00-39.69_00-00-45.68.mp4\n",
49 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/speaker8_jjRWaMCWs44_00-00-30.16_00-00-33.32.mp4 -d /content/TANGO/datasets/cached_audio -o speaker8_jjRWaMCWs44_00-00-30.16_00-00-33.32.mp4\n",
50 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/speaker7_iuYlGRnC7J8_00-00-0.00_00-00-3.25.mp4 -d /content/TANGO/datasets/cached_audio -o speaker7_iuYlGRnC7J8_00-00-0.00_00-00-3.25.mp4\n",
51 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4 -d /content/TANGO/datasets/cached_audio -o speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4\n",
52 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/H-Liu1997/TANGO/resolve/main/datasets/cached_audio/101099-00_18_09-00_18_19.mp4 -d /content/TANGO/datasets/cached_audio -o 101099-00_18_09-00_18_19.mp4"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "%cd /content/TANGO\n",
62 | "\n",
63 | "import os\n",
64 | "import gc\n",
65 | "import soundfile as sf\n",
66 | "import shutil\n",
67 | "import argparse\n",
68 | "from moviepy.tools import verbose_print\n",
69 | "from omegaconf import OmegaConf\n",
70 | "import random\n",
71 | "import numpy as np\n",
72 | "import json \n",
73 | "import librosa\n",
74 | "import emage.mertic\n",
75 | "from datetime import datetime\n",
76 | "from decord import VideoReader\n",
77 | "from PIL import Image\n",
78 | "import copy\n",
79 | "\n",
80 | "import importlib\n",
81 | "import torch\n",
82 | "import torch.nn as nn\n",
83 | "import torch.nn.functional as F\n",
84 | "from torch.optim import AdamW\n",
85 | "from torch.utils.data import DataLoader\n",
86 | "from torch.nn.parallel import DistributedDataParallel as DDP\n",
87 | "from tqdm import tqdm\n",
88 | "import smplx\n",
89 | "from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip\n",
90 | "import igraph\n",
91 | "\n",
92 | "# import emage\n",
93 | "import utils.rotation_conversions as rc\n",
94 | "from utils.video_io import save_videos_from_pil\n",
95 | "from utils.genextend_inference_utils import adjust_statistics_to_match_reference\n",
96 | "from create_graph import path_visualization, graph_pruning, get_motion_reps_tensor, path_visualization_v2\n",
97 | "\n",
98 | "def search_path_dp(graph, audio_low_np, audio_high_np, loop_penalty=0.1, top_k=1, search_mode=\"both\", continue_penalty=0.1):\n",
99 | " T = audio_low_np.shape[0] # Total time steps\n",
100 | " N = len(graph.vs) # Total number of nodes in the graph\n",
101 | "\n",
102 | " # Initialize DP tables\n",
103 | " min_cost = [{} for _ in range(T)] # min_cost[t][node_index] = list of tuples: (cost, prev_node_index, prev_tuple_index, non_continue_count, visited_nodes)\n",
104 | "\n",
105 | " # Initialize the first time step\n",
106 | " start_nodes = [v for v in graph.vs if v['previous'] is None or v['previous'] == -1]\n",
107 | " for node in start_nodes:\n",
108 | " node_index = node.index\n",
109 | " motion_low = node['motion_low'] # Shape: [C]\n",
110 | " motion_high = node['motion_high'] # Shape: [C]\n",
111 | "\n",
112 | " # Cost using cosine similarity\n",
113 | " if search_mode == \"both\":\n",
114 | " cost = 2 - (np.dot(audio_low_np[0], motion_low.T) + np.dot(audio_high_np[0], motion_high.T))\n",
115 | " elif search_mode == \"high_level\":\n",
116 | " cost = 1 - np.dot(audio_high_np[0], motion_high.T)\n",
117 | " elif search_mode == \"low_level\":\n",
118 | " cost = 1 - np.dot(audio_low_np[0], motion_low.T)\n",
119 | "\n",
120 | " visited_nodes = {node_index: 1} # Initialize visit count as a dictionary\n",
121 | "\n",
122 | " min_cost[0][node_index] = [ (cost, None, None, 0, visited_nodes) ] # Initialize with no predecessor and 0 non-continue count\n",
123 | "\n",
124 | " # DP over time steps\n",
125 | " for t in range(1, T):\n",
126 | " for node in graph.vs:\n",
127 | " node_index = node.index\n",
128 | " candidates = []\n",
129 | "\n",
130 | " # Incoming edges to the current node\n",
131 | " incoming_edges = graph.es.select(_to=node_index)\n",
132 | " for edge in incoming_edges:\n",
133 | " prev_node_index = edge.source\n",
134 | " edge_id = edge.index\n",
135 | " is_continue_edge = graph.es[edge_id]['is_continue']\n",
136 | " prev_node = graph.vs[prev_node_index]\n",
137 | " if prev_node_index in min_cost[t-1]:\n",
138 | " for tuple_index, (prev_cost, _, _, prev_non_continue_count, prev_visited) in enumerate(min_cost[t-1][prev_node_index]):\n",
139 | " # Loop punishment\n",
140 | " if node_index in prev_visited:\n",
141 | " loop_time = prev_visited[node_index] # Get the count of previous visits\n",
142 | " loop_cost = prev_cost + loop_penalty * np.exp(loop_time) # Apply exponential penalty\n",
143 | " new_visited = prev_visited.copy()\n",
144 | " new_visited[node_index] = loop_time + 1 # Increment visit count\n",
145 | " else:\n",
146 | " loop_cost = prev_cost\n",
147 | " new_visited = prev_visited.copy()\n",
148 | " new_visited[node_index] = 1 # Initialize visit count for the new node\n",
149 | "\n",
150 | " motion_low = node['motion_low'] # Shape: [C]\n",
151 | " motion_high = node['motion_high'] # Shape: [C]\n",
152 | "\n",
153 | " if search_mode == \"both\":\n",
154 | " cost_increment = 2 - (np.dot(audio_low_np[t], motion_low.T) + np.dot(audio_high_np[t], motion_high.T))\n",
155 | " elif search_mode == \"high_level\":\n",
156 | " cost_increment = 1 - np.dot(audio_high_np[t], motion_high.T)\n",
157 | " elif search_mode == \"low_level\":\n",
158 | " cost_increment = 1 - np.dot(audio_low_np[t], motion_low.T)\n",
159 | "\n",
160 | " # Check if the edge is \"is_continue\"\n",
161 | " if not is_continue_edge:\n",
162 | " non_continue_count = prev_non_continue_count + 1 # Increment the count of non-continue edges\n",
163 | " else:\n",
164 | " non_continue_count = prev_non_continue_count\n",
165 | "\n",
166 | " # Apply the penalty based on the square of the number of non-continuous edges\n",
167 | " continue_penalty_cost = continue_penalty * non_continue_count\n",
168 | "\n",
169 | " total_cost = loop_cost + cost_increment + continue_penalty_cost\n",
170 | "\n",
171 | " candidates.append( (total_cost, prev_node_index, tuple_index, non_continue_count, new_visited) )\n",
172 | "\n",
173 | " # Keep the top k candidates\n",
174 | " if candidates:\n",
175 | " # Sort candidates by total_cost\n",
176 | " candidates.sort(key=lambda x: x[0])\n",
177 | " # Keep top k\n",
178 | " min_cost[t][node_index] = candidates[:top_k]\n",
179 | " else:\n",
180 | " # No candidates, do nothing\n",
181 | " pass\n",
182 | "\n",
183 | " # Collect all possible end paths at time T-1\n",
184 | " end_candidates = []\n",
185 | " for node_index, tuples in min_cost[T-1].items():\n",
186 | " for tuple_index, (cost, _, _, _, _) in enumerate(tuples):\n",
187 | " end_candidates.append( (cost, node_index, tuple_index) )\n",
188 | "\n",
189 | " if not end_candidates:\n",
190 | " print(\"No valid path found.\")\n",
191 | " return [], []\n",
192 | "\n",
193 | " # Sort end candidates by cost\n",
194 | " end_candidates.sort(key=lambda x: x[0])\n",
195 | "\n",
196 | " # Keep top k paths\n",
197 | " top_k_paths_info = end_candidates[:top_k]\n",
198 | "\n",
199 | " # Reconstruct the paths\n",
200 | " optimal_paths = []\n",
201 | " is_continue_lists = []\n",
202 | " for final_cost, node_index, tuple_index in top_k_paths_info:\n",
203 | " optimal_path_indices = []\n",
204 | " current_node_index = node_index\n",
205 | " current_tuple_index = tuple_index\n",
206 | " for t in range(T-1, -1, -1):\n",
207 | " optimal_path_indices.append(current_node_index)\n",
208 | " tuple_data = min_cost[t][current_node_index][current_tuple_index]\n",
209 | " _, prev_node_index, prev_tuple_index, _, _ = tuple_data\n",
210 | " current_node_index = prev_node_index\n",
211 | " current_tuple_index = prev_tuple_index\n",
212 | " if current_node_index is None:\n",
213 | " break # Reached the start node\n",
214 | " optimal_path_indices = optimal_path_indices[::-1] # Reverse to get correct order\n",
215 | " optimal_path = [graph.vs[idx] for idx in optimal_path_indices]\n",
216 | " optimal_paths.append(optimal_path)\n",
217 | "\n",
218 | " # Extract continuity information\n",
219 | " is_continue = []\n",
220 | " for i in range(len(optimal_path) - 1):\n",
221 | " edge_id = graph.get_eid(optimal_path[i].index, optimal_path[i + 1].index)\n",
222 | " is_cont = graph.es[edge_id]['is_continue']\n",
223 | " is_continue.append(is_cont)\n",
224 | " is_continue_lists.append(is_continue)\n",
225 | "\n",
226 | " print(\"Top {} Paths:\".format(len(optimal_paths)))\n",
227 | " for i, path in enumerate(optimal_paths):\n",
228 | " path_indices = [node.index for node in path]\n",
229 | " print(\"Path {}: Cost: {}, Nodes: {}\".format(i+1, top_k_paths_info[i][0], path_indices))\n",
230 | "\n",
231 | " return optimal_paths, is_continue_lists\n",
232 | "\n",
233 | "\n",
234 | "def test_fn(model, device, iteration, candidate_json_path, test_path, cfg, audio_path, **kwargs):\n",
235 | " torch.set_grad_enabled(False)\n",
236 | " pool_path = candidate_json_path.replace(\"data_json\", \"cached_graph\").replace(\".json\", \".pkl\")\n",
237 | " print(pool_path)\n",
238 | " graph = igraph.Graph.Read_Pickle(fname=pool_path)\n",
239 | " # print(len(graph.vs))\n",
240 | "\n",
241 | " save_dir = os.path.join(test_path, f\"retrieved_motions_{iteration}\")\n",
242 | " os.makedirs(save_dir, exist_ok=True)\n",
243 | "\n",
244 | " actual_model = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model\n",
245 | " actual_model.eval()\n",
246 | "\n",
247 | " # with open(candidate_json_path, 'r') as f:\n",
248 | " # candidate_data = json.load(f)\n",
249 | " all_motions = {}\n",
250 | " for i, node in enumerate(graph.vs):\n",
251 | " if all_motions.get(node[\"name\"]) is None:\n",
252 | " all_motions[node[\"name\"]] = [node[\"axis_angle\"].reshape(-1)]\n",
253 | " else:\n",
254 | " all_motions[node[\"name\"]].append(node[\"axis_angle\"].reshape(-1))\n",
255 | " for k, v in all_motions.items():\n",
256 | " all_motions[k] = np.stack(v) # T, J*3\n",
257 | " # print(k, all_motions[k].shape)\n",
258 | " \n",
259 | " window_size = cfg.data.pose_length\n",
260 | " motion_high_all = []\n",
261 | " motion_low_all = []\n",
262 | " for k, v in all_motions.items():\n",
263 | " motion_tensor = torch.from_numpy(v).float().to(device).unsqueeze(0)\n",
264 | " _, t, _ = motion_tensor.shape\n",
265 | " \n",
266 | " if t >= window_size:\n",
267 | " num_chunks = t // window_size\n",
268 | " motion_high_list = []\n",
269 | " motion_low_list = []\n",
270 | "\n",
271 | " for i in range(num_chunks):\n",
272 | " start_idx = i * window_size\n",
273 | " end_idx = start_idx + window_size\n",
274 | " motion_slice = motion_tensor[:, start_idx:end_idx, :]\n",
275 | " \n",
276 | " motion_features = actual_model.get_motion_features(motion_slice)\n",
277 | " \n",
278 | " motion_low = motion_features[\"motion_low\"].cpu().numpy()\n",
279 | " motion_high = motion_features[\"motion_cls\"].unsqueeze(0).repeat(1, motion_low.shape[1], 1).cpu().numpy()\n",
280 | "\n",
281 | " motion_high_list.append(motion_high[0])\n",
282 | " motion_low_list.append(motion_low[0])\n",
283 | "\n",
284 | " remain_length = t % window_size\n",
285 | " if remain_length > 0:\n",
286 | " start_idx = t - window_size\n",
287 | " motion_slice = motion_tensor[:, start_idx:, :]\n",
288 | "\n",
289 | " motion_features = actual_model.get_motion_features(motion_slice)\n",
290 | " # motion_high = motion_features[\"motion_high_weight\"].cpu().numpy()\n",
291 | " motion_low = motion_features[\"motion_low\"].cpu().numpy()\n",
292 | " motion_high = motion_features[\"motion_cls\"].unsqueeze(0).repeat(1, motion_low.shape[1], 1).cpu().numpy()\n",
293 | "\n",
294 | " motion_high_list.append(motion_high[0][-remain_length:])\n",
295 | " motion_low_list.append(motion_low[0][-remain_length:])\n",
296 | "\n",
297 | " motion_high_all.append(np.concatenate(motion_high_list, axis=0))\n",
298 | " motion_low_all.append(np.concatenate(motion_low_list, axis=0))\n",
299 | "\n",
300 | " else: # t < window_size:\n",
301 | " gap = window_size - t\n",
302 | " motion_slice = torch.cat([motion_tensor, torch.zeros((motion_tensor.shape[0], gap, motion_tensor.shape[2])).to(motion_tensor.device)], 1)\n",
303 | " motion_features = actual_model.get_motion_features(motion_slice)\n",
304 | " # motion_high = motion_features[\"motion_high_weight\"].cpu().numpy()\n",
305 | " motion_low = motion_features[\"motion_low\"].cpu().numpy()\n",
306 | " motion_high = motion_features[\"motion_cls\"].unsqueeze(0).repeat(1, motion_low.shape[1], 1).cpu().numpy()\n",
307 | "\n",
308 | " motion_high_all.append(motion_high[0][:t])\n",
309 | " motion_low_all.append(motion_low[0][:t])\n",
310 | " \n",
311 | " motion_high_all = np.concatenate(motion_high_all, axis=0)\n",
312 | " motion_low_all = np.concatenate(motion_low_all, axis=0)\n",
313 | " # print(motion_high_all.shape, motion_low_all.shape, len(graph.vs))\n",
314 | " motion_low_all = motion_low_all / np.linalg.norm(motion_low_all, axis=1, keepdims=True)\n",
315 | " motion_high_all = motion_high_all / np.linalg.norm(motion_high_all, axis=1, keepdims=True)\n",
316 | " assert motion_high_all.shape[0] == len(graph.vs)\n",
317 | " assert motion_low_all.shape[0] == len(graph.vs)\n",
318 | " \n",
319 | " for i, node in enumerate(graph.vs):\n",
320 | " node[\"motion_high\"] = motion_high_all[i]\n",
321 | " node[\"motion_low\"] = motion_low_all[i]\n",
322 | "\n",
323 | " graph = graph_pruning(graph)\n",
324 | " # for gradio, use a subgraph\n",
325 | " if len(graph.vs) > 1800:\n",
326 | " gap = len(graph.vs) - 1800\n",
327 | " start_d = random.randint(0, 1800)\n",
328 | " graph.delete_vertices(range(start_d, start_d + gap))\n",
329 | " ascc_2 = graph.clusters(mode=\"STRONG\")\n",
330 | " graph = ascc_2.giant()\n",
331 | "\n",
332 | " # drop the id of gt\n",
333 | " idx = 0\n",
334 | " audio_waveform, sr = librosa.load(audio_path)\n",
335 | " audio_waveform = librosa.resample(audio_waveform, orig_sr=sr, target_sr=cfg.data.audio_sr)\n",
336 | " audio_tensor = torch.from_numpy(audio_waveform).float().to(device).unsqueeze(0)\n",
337 | " \n",
338 | " target_length = audio_tensor.shape[1] // cfg.data.audio_sr * 30\n",
339 | " window_size = int(cfg.data.audio_sr * (cfg.data.pose_length / 30))\n",
340 | " _, t = audio_tensor.shape\n",
341 | " audio_low_list = []\n",
342 | " audio_high_list = []\n",
343 | "\n",
344 | " if t >= window_size:\n",
345 | " num_chunks = t // window_size\n",
346 | " # print(num_chunks, t % window_size)\n",
347 | " for i in range(num_chunks):\n",
348 | " start_idx = i * window_size\n",
349 | " end_idx = start_idx + window_size\n",
350 | " # print(start_idx, end_idx, window_size)\n",
351 | " audio_slice = audio_tensor[:, start_idx:end_idx]\n",
352 | "\n",
353 | " model_out_candidates = actual_model.get_audio_features(audio_slice)\n",
354 | " audio_low = model_out_candidates[\"audio_low\"]\n",
355 | " # audio_high = model_out_candidates[\"audio_high_weight\"]\n",
356 | " audio_high = model_out_candidates[\"audio_cls\"].unsqueeze(0).repeat(1, audio_low.shape[1], 1)\n",
357 | " # print(audio_low.shape, audio_high.shape)\n",
358 | "\n",
359 | " audio_low = F.normalize(audio_low, dim=2)[0].cpu().numpy()\n",
360 | " audio_high = F.normalize(audio_high, dim=2)[0].cpu().numpy()\n",
361 | "\n",
362 | " audio_low_list.append(audio_low)\n",
363 | " audio_high_list.append(audio_high)\n",
364 | " # print(audio_low.shape, audio_high.shape)\n",
365 | " \n",
366 | "\n",
367 | " remain_length = t % window_size\n",
368 | " if remain_length > 1:\n",
369 | " start_idx = t - window_size\n",
370 | " audio_slice = audio_tensor[:, start_idx:]\n",
371 | "\n",
372 | " model_out_candidates = actual_model.get_audio_features(audio_slice)\n",
373 | " audio_low = model_out_candidates[\"audio_low\"]\n",
374 | " # audio_high = model_out_candidates[\"audio_high_weight\"]\n",
375 | " audio_high = model_out_candidates[\"audio_cls\"].unsqueeze(0).repeat(1, audio_low.shape[1], 1)\n",
376 | " \n",
377 | " gap = target_length - np.concatenate(audio_low_list, axis=0).shape[1]\n",
378 | " audio_low = F.normalize(audio_low, dim=2)[0][-gap:].cpu().numpy()\n",
379 | " audio_high = F.normalize(audio_high, dim=2)[0][-gap:].cpu().numpy()\n",
380 | " \n",
381 | " # print(audio_low.shape, audio_high.shape)\n",
382 | " audio_low_list.append(audio_low)\n",
383 | " audio_high_list.append(audio_high)\n",
384 | " else:\n",
385 | " gap = window_size - t\n",
386 | " audio_slice = audio_tensor \n",
387 | " model_out_candidates = actual_model.get_audio_features(audio_slice)\n",
388 | " audio_low = model_out_candidates[\"audio_low\"]\n",
389 | " # audio_high = model_out_candidates[\"audio_high_weight\"]\n",
390 | " audio_high = model_out_candidates[\"audio_cls\"].unsqueeze(0).repeat(1, audio_low.shape[1], 1)\n",
391 | " \n",
392 | " gap = target_length - np.concatenate(audio_low_list, axis=0).shape[1]\n",
393 | " audio_low = F.normalize(audio_low, dim=2)[0][:gap].cpu().numpy()\n",
394 | " audio_high = F.normalize(audio_high, dim=2)[0][:gap].cpu().numpy()\n",
395 | " audio_low_list.append(audio_low)\n",
396 | " audio_high_list.append(audio_high)\n",
397 | " \n",
398 | " audio_low_all = np.concatenate(audio_low_list, axis=0)\n",
399 | " audio_high_all = np.concatenate(audio_high_list, axis=0)\n",
400 | " path_list, is_continue_list = search_path_dp(graph, audio_low_all, audio_high_all, top_k=1, search_mode=\"both\")\n",
401 | " \n",
402 | " res_motion = []\n",
403 | " counter = 0\n",
404 | " for path, is_continue in zip(path_list, is_continue_list):\n",
405 | " # print(path)\n",
406 | " # res_motion_current = path_visualization(\n",
407 | " # graph, path, is_continue, os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\"), audio_path=audio_path, return_motion=True, verbose_continue=True\n",
408 | " # )\n",
409 | " res_motion_current = path_visualization_v2(\n",
410 | " graph, path, is_continue, os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\"), audio_path=audio_path, return_motion=True, verbose_continue=True\n",
411 | " )\n",
412 | "\n",
413 | " video_temp_path = os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\")\n",
414 | " \n",
415 | " video_reader = VideoReader(video_temp_path)\n",
416 | " video_np = []\n",
417 | " for i in range(len(video_reader)):\n",
418 | " if i == 0: continue\n",
419 | " video_frame = video_reader[i].asnumpy()\n",
420 | " video_np.append(Image.fromarray(video_frame))\n",
421 | " adjusted_video_pil = adjust_statistics_to_match_reference([video_np])\n",
422 | " save_videos_from_pil(adjusted_video_pil[0], os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\"), fps=30, bitrate=2000000)\n",
423 | "\n",
424 | "\n",
425 | " audio_temp_path = audio_path\n",
426 | " lipsync_output_path = os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\")\n",
427 | " checkpoint_path = './Wav2Lip/checkpoints/wav2lip_gan.pth' # Update this path to your Wav2Lip checkpoint\n",
428 | " os.system(f'python ./Wav2Lip/inference.py --checkpoint_path {checkpoint_path} --face {video_temp_path} --audio {audio_temp_path} --outfile {lipsync_output_path} --nosmooth')\n",
429 | "\n",
430 | " res_motion.append(res_motion_current)\n",
431 | " np.savez(os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.npz\"), motion=res_motion_current)\n",
432 | " \n",
433 | " start_node = path[1].index\n",
434 | " end_node = start_node + 100\n",
435 | " print(f\"delete gt-nodes {start_node}, {end_node}\")\n",
436 | " nodes_to_delete = list(range(start_node, end_node))\n",
437 | " graph.delete_vertices(nodes_to_delete)\n",
438 | " graph = graph_pruning(graph)\n",
439 | " path_list, is_continue_list = search_path_dp(graph, audio_low_all, audio_high_all, top_k=1, search_mode=\"both\")\n",
440 | " res_motion = []\n",
441 | " counter = 1\n",
442 | " for path, is_continue in zip(path_list, is_continue_list):\n",
443 | " res_motion_current = path_visualization(\n",
444 | " graph, path, is_continue, os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\"), audio_path=audio_path, return_motion=True, verbose_continue=True\n",
445 | " )\n",
446 | " video_temp_path = os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\")\n",
447 | " \n",
448 | " video_reader = VideoReader(video_temp_path)\n",
449 | " video_np = []\n",
450 | " for i in range(len(video_reader)):\n",
451 | " if i == 0: continue\n",
452 | " video_frame = video_reader[i].asnumpy()\n",
453 | " video_np.append(Image.fromarray(video_frame))\n",
454 | " adjusted_video_pil = adjust_statistics_to_match_reference([video_np])\n",
455 | " save_videos_from_pil(adjusted_video_pil[0], os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\"), fps=30, bitrate=2000000)\n",
456 | "\n",
457 | "\n",
458 | " audio_temp_path = audio_path\n",
459 | " lipsync_output_path = os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.mp4\")\n",
460 | " checkpoint_path = './Wav2Lip/checkpoints/wav2lip_gan.pth' # Update this path to your Wav2Lip checkpoint\n",
461 | " os.system(f'python ./Wav2Lip/inference.py --checkpoint_path {checkpoint_path} --face {video_temp_path} --audio {audio_temp_path} --outfile {lipsync_output_path} --nosmooth')\n",
462 | " res_motion.append(res_motion_current)\n",
463 | " np.savez(os.path.join(save_dir, f\"audio_{idx}_retri_{counter}.npz\"), motion=res_motion_current)\n",
464 | " \n",
465 | " result = [\n",
466 | " os.path.join(save_dir, f\"audio_{idx}_retri_0.mp4\"),\n",
467 | " os.path.join(save_dir, f\"audio_{idx}_retri_1.mp4\"),\n",
468 | " os.path.join(save_dir, f\"audio_{idx}_retri_0.npz\"),\n",
469 | " os.path.join(save_dir, f\"audio_{idx}_retri_1.npz\")\n",
470 | " ]\n",
471 | " return result\n",
472 | "\n",
473 | "\n",
474 | "def init_class(module_name, class_name, config, **kwargs):\n",
475 | " module = importlib.import_module(module_name)\n",
476 | " model_class = getattr(module, class_name)\n",
477 | " instance = model_class(config, **kwargs)\n",
478 | " return instance\n",
479 | "\n",
480 | "def seed_everything(seed):\n",
481 | " random.seed(seed)\n",
482 | " np.random.seed(seed)\n",
483 | " torch.manual_seed(seed)\n",
484 | " torch.cuda.manual_seed_all(seed)\n",
485 | "\n",
486 | "def prepare_all(yaml_name):\n",
487 | " if yaml_name.endswith(\".yaml\"):\n",
488 | " config = OmegaConf.load(yaml_name)\n",
489 | " config.exp_name = os.path.basename(yaml_name)[:-5]\n",
490 | " else:\n",
491 | " raise ValueError(\"Unsupported config file format. Only .yaml files are allowed.\")\n",
492 | " save_dir = os.path.join(config.output_dir, config.exp_name)\n",
493 | " os.makedirs(save_dir, exist_ok=True)\n",
494 | " return config\n",
495 | "\n",
496 | "def save_first_10_seconds(video_path, output_path=\"./save_video.mp4\"):\n",
497 | " import cv2\n",
498 | " cap = cv2.VideoCapture(video_path)\n",
499 | " \n",
500 | " if not cap.isOpened():\n",
501 | " return\n",
502 | "\n",
503 | " fps = int(cap.get(cv2.CAP_PROP_FPS))\n",
504 | " width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n",
505 | " height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
506 | "\n",
507 | " fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n",
508 | " out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))\n",
509 | "\n",
510 | " frames_to_save = fps * 10\n",
511 | " frame_count = 0\n",
512 | " \n",
513 | " while cap.isOpened() and frame_count < frames_to_save:\n",
514 | " ret, frame = cap.read()\n",
515 | " if not ret:\n",
516 | " break\n",
517 | " out.write(frame)\n",
518 | " frame_count += 1\n",
519 | "\n",
520 | " cap.release()\n",
521 | " out.release()\n",
522 | "\n",
523 | "\n",
524 | "character_name_to_yaml = {\n",
525 | " \"speaker8_jjRWaMCWs44_00-00-30.16_00-00-33.32.mp4\": \"./datasets/data_json/youtube_test/speaker8.json\",\n",
526 | " \"speaker7_iuYlGRnC7J8_00-00-0.00_00-00-3.25.mp4\": \"./datasets/data_json/youtube_test/speaker7.json\",\n",
527 | " \"speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4\": \"./datasets/data_json/youtube_test/speaker9.json\",\n",
528 | " \"1wrQ6Msp7wM_00-00-39.69_00-00-45.68.mp4\": \"./datasets/data_json/youtube_test/speaker1.json\",\n",
529 | " \"101099-00_18_09-00_18_19.mp4\": \"./datasets/data_json/show_oliver_test/Stupid_Watergate_-_Last_Week_Tonight_with_John_Oliver_HBO-FVFdsl29s_Q.mkv.json\",\n",
530 | "}\n",
531 | "\n",
532 | "def tango(audio_path, character_name, seed, create_graph=False, video_folder_path=None):\n",
533 | " cfg = prepare_all(\"./configs/gradio.yaml\")\n",
534 | " cfg.seed = seed\n",
535 | " seed_everything(cfg.seed)\n",
536 | " experiment_ckpt_dir = experiment_log_dir = os.path.join(cfg.output_dir, cfg.exp_name)\n",
537 | " saved_audio_path = \"./saved_audio.wav\"\n",
538 | "\n",
539 | " # Load audio from the specified path\n",
540 | " audio_waveform, sample_rate = librosa.load(audio_path, sr=None) # Load with original sample rate\n",
541 | " sf.write(saved_audio_path, audio_waveform, sample_rate)\n",
542 | "\n",
543 | " # Resample the audio to 16000 Hz\n",
544 | " resampled_audio = librosa.resample(audio_waveform, orig_sr=sample_rate, target_sr=16000)\n",
545 | " required_length = int(16000 * (128 / 30)) * 2\n",
546 | " resampled_audio = resampled_audio[:required_length]\n",
547 | " sf.write(saved_audio_path, resampled_audio, 16000)\n",
548 | " audio_path = saved_audio_path # Update the audio_path to point to the saved resampled audio\n",
549 | "\n",
550 | " yaml_name = character_name_to_yaml.get(character_name.split(\"/\")[-1], \"./datasets/data_json/youtube_test/speaker1.json\")\n",
551 | " cfg.data.test_meta_paths = yaml_name\n",
552 | " print(yaml_name, character_name.split(\"/\")[-1])\n",
553 | "\n",
554 | " if character_name.split(\"/\")[-1] not in character_name_to_yaml.keys():\n",
555 | " create_graph = True\n",
556 | " os.makedirs(\"./outputs/tmpvideo/\", exist_ok=True)\n",
557 | " save_first_10_seconds(character_name, \"./outputs/tmpvideo/save_video.mp4\")\n",
558 | "\n",
559 | " if create_graph:\n",
560 | " video_folder_path = \"./outputs/tmpvideo/\"\n",
561 | " data_save_path = \"./outputs/tmpdata/\"\n",
562 | " json_save_path = \"./outputs/save_video.json\"\n",
563 | " graph_save_path = \"./outputs/save_video.pkl\"\n",
564 | " os.system(f\"cd ./SMPLer-X/ && python app.py --video_folder_path {video_folder_path} --data_save_path {data_save_path} --json_save_path {json_save_path} && cd ..\")\n",
565 | " os.system(f\"python ./create_graph.py --json_save_path {json_save_path} --graph_save_path {graph_save_path}\") \n",
566 | " cfg.data.test_meta_paths = json_save_path\n",
567 | "\n",
568 | " smplx_model = smplx.create(\n",
569 | " \"./emage/smplx_models/\", \n",
570 | " model_type='smplx',\n",
571 | " gender='NEUTRAL_2020', \n",
572 | " use_face_contour=False,\n",
573 | " num_betas=300,\n",
574 | " num_expression_coeffs=100, \n",
575 | " ext='npz',\n",
576 | " use_pca=False,\n",
577 | " )\n",
578 | " model = init_class(cfg.model.name_pyfile, cfg.model.class_name, cfg)\n",
579 | " for param in model.parameters():\n",
580 | " param.requires_grad = False\n",
581 | " model.smplx_model = smplx_model\n",
582 | " model.get_motion_reps = get_motion_reps_tensor\n",
583 | " \n",
584 | " local_rank = 0 \n",
585 | " torch.cuda.set_device(local_rank)\n",
586 | " device = torch.device(\"cuda\", local_rank)\n",
587 | "\n",
588 | " smplx_model = smplx_model.to(device).eval()\n",
589 | " model = model.to(device)\n",
590 | " model.smplx_model = model.smplx_model.to(device)\n",
591 | "\n",
592 | " checkpoint_path = \"./datasets/cached_ckpts/ckpt.pth\"\n",
593 | " checkpoint = torch.load(checkpoint_path)\n",
594 | " state_dict = checkpoint['model_state_dict']\n",
595 | " new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}\n",
596 | " model.load_state_dict(new_state_dict, strict=False)\n",
597 | " \n",
598 | " test_path = os.path.join(experiment_ckpt_dir, f\"test_{0}\")\n",
599 | " os.makedirs(test_path, exist_ok=True)\n",
600 | " result = test_fn(model, device, 0, cfg.data.test_meta_paths, test_path, cfg, audio_path)\n",
601 | " gc.collect()\n",
602 | " torch.cuda.empty_cache()\n",
603 | " return result"
604 | ]
605 | },
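606 | {
607 | "cell_type": "markdown",
608 | "metadata": {},
609 | "source": [
610 | "Optional sanity check (a minimal sketch, not part of the original pipeline): run `search_path_dp` from the cell above on a tiny synthetic 3-node graph with random unit-norm features and a 4-frame \"audio\" sequence. The `toy` graph, the 8-dim feature size, and the edge labels are made up purely for illustration."
611 | ]
612 | },
613 | {
614 | "cell_type": "code",
615 | "execution_count": null,
616 | "metadata": {},
617 | "outputs": [],
618 | "source": [
619 | "# Build a toy directed graph 0 -> 1 -> 2 -> 0 where only the 2 -> 0 edge is a jump.\n",
620 | "import numpy as np\n",
621 | "import igraph\n",
622 | "\n",
623 | "rng = np.random.default_rng(0)\n",
624 | "feats = rng.normal(size=(3, 8))\n",
625 | "feats /= np.linalg.norm(feats, axis=1, keepdims=True)\n",
626 | "\n",
627 | "toy = igraph.Graph(directed=True)\n",
628 | "toy.add_vertices(3)\n",
629 | "toy.vs[\"motion_low\"] = list(feats)\n",
630 | "toy.vs[\"motion_high\"] = list(feats)\n",
631 | "toy.vs[\"previous\"] = [None, 0, 1]  # node 0 is the only valid start node\n",
632 | "toy.add_edges([(0, 1), (1, 2), (2, 0)])\n",
633 | "toy.es[\"is_continue\"] = [1, 1, 0]\n",
634 | "\n",
635 | "# Four frames of unit-norm \"audio\" features; the DP should print one 4-node path.\n",
636 | "audio = rng.normal(size=(4, 8))\n",
637 | "audio /= np.linalg.norm(audio, axis=1, keepdims=True)\n",
638 | "paths, continuity = search_path_dp(toy, audio, audio, top_k=1)"
639 | ]
640 | },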
606 | {
607 | "cell_type": "code",
608 | "execution_count": null,
609 | "metadata": {},
610 | "outputs": [],
611 | "source": [
612 | "%cd /content/TANGO\n",
613 | "seed = 42\n",
614 | "wav = \"./datasets/cached_audio/example_male_voice_9_seconds.wav\"\n",
615 | "mp4 = \"./datasets/cached_audio/speaker9_o7Ik1OB4TaE_00-00-38.15_00-00-42.33.mp4\"\n",
616 | "video_output_1, video_output_2, file_output_1, file_output_2 = tango(wav, mp4, seed)\n",
617 | "\n",
618 | "from IPython.display import Video\n",
619 | "# Video(video_output_1, embed=True)\n",
620 | "Video(video_output_2, embed=True)"
621 | ]
622 | }
623 | ],
624 | "metadata": {
625 | "accelerator": "GPU",
626 | "colab": {
627 | "gpuType": "T4",
628 | "provenance": []
629 | },
630 | "kernelspec": {
631 | "display_name": "Python 3",
632 | "name": "python3"
633 | },
634 | "language_info": {
635 | "name": "python"
636 | }
637 | },
638 | "nbformat": 4,
639 | "nbformat_minor": 0
640 | }
641 |
--------------------------------------------------------------------------------