├── .gitignore ├── .idea ├── .gitignore ├── DH_live.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── README.md ├── app.py ├── data ├── face_pts_mean.txt ├── face_pts_mean_mainKps.txt └── pca.pkl ├── data_preparation_mini.py ├── data_preparation_web.py ├── demo_mini.py ├── mini_live ├── bs_texture.png ├── bs_texture_halfFace.png ├── face_fusion_mask.png ├── generate_fusion_mask.py ├── icon40.png ├── mouth_fusion_mask.png ├── obj │ ├── image_utils.py │ ├── obj_mediapipe │ │ ├── face3D.obj │ │ ├── face_wrap_entity.obj │ │ ├── generate_wrap_obj.py │ │ ├── modified_obj.py │ │ ├── modified_teeth_lower.obj │ │ ├── modified_teeth_upper.obj │ │ ├── teeth_lower.obj │ │ ├── teeth_upper.obj │ │ ├── wrap.obj │ │ └── wrap_index.py │ ├── obj_utils.py │ ├── utils.py │ ├── weights478.txt │ └── wrap_utils.py ├── opengl_render_interface.py ├── render.py ├── shader │ ├── prompt3.fsh │ └── prompt3.vsh ├── train.py └── train_input_validation.py ├── requirements.txt ├── talkingface ├── __init__.py ├── audio_model.py ├── config │ └── config.py ├── data │ ├── DHLive_mini_dataset.py │ ├── __init__.py │ ├── dataset_wav.py │ ├── face_mask.py │ └── few_shot_dataset.py ├── face_pts_mean.txt ├── mediapipe_utils.py ├── model_utils.py ├── models │ ├── DINet.py │ ├── DINet_mini.py │ ├── __init__.py │ ├── audio2bs_lstm.py │ ├── common │ │ ├── Discriminator.py │ │ └── VGG19.py │ └── speed_test.py ├── preprocess.py ├── render_model.py ├── render_model_mini.py ├── run_utils.py ├── util │ ├── __init__.py │ ├── get_data.py │ ├── html.py │ ├── image_pool.py │ ├── log_board.py │ ├── smooth.py │ ├── util.py │ ├── utils.py │ └── visualizer.py └── utils.py ├── video_data ├── 000001 │ └── video.mp4 ├── 000002 │ └── video.mp4 ├── audio0.wav ├── audio1.wav └── teeth_ref │ ├── 221.png │ ├── 252.png │ ├── 328.png │ ├── 377.png │ ├── 398.png │ ├── 519.png │ ├── 558.png │ ├── 682.png │ ├── 743.png │ ├── 760.png │ └── 794.png └── web_demo ├── Flowchart.jpg ├── README.md ├── server.py ├── server_realtime.py ├── static ├── DHLiveMini.wasm ├── MiniLive.html ├── MiniLive_RealTime.html ├── MiniLive_new.html ├── assets │ ├── 01.mp4 │ └── combined_data.json.gz ├── assets2 │ ├── 01.mp4 │ └── combined_data.json.gz ├── common │ ├── bs_texture_halfFace.png │ ├── favicon.ico │ └── test.wav ├── css │ └── material-icons.css ├── dialog.html ├── dialog_RealTime.html ├── fonts │ └── flUhRq6tzZclQEJ-Vdg-IuiaDsNcIhQ8tQ.woff2 └── js │ ├── DHLiveMini.js │ ├── MiniLive2.js │ ├── MiniMateLoader.js │ ├── audio_recorder.js │ ├── dialog.js │ ├── dialog_realtime.js │ ├── mp4box.all.min.js │ └── pako.min.js └── voiceapi ├── asr.py ├── llm.py ├── offline_tts.py └── tts.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | .idea/ 3 | /checkpoint/ 4 | /mini_live/pt/ 5 | /web_demo_authorized/ 6 | /utils/ 7 | # 忽略 video_data 文件夹下的所有内容 8 | video_data/* 9 | !video_data/000001/ 10 | !video_data/000002/ 11 | !video_data/teeth_ref/ 12 | video_data/teeth_ref/*_2.png 13 | !video_data/test/ 14 | !video_data/audio1.wav 15 | 16 | /web_demo/static/ios_test.html 17 | 18 | /websocket_demo/ 19 | web_demo/models/ -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- 
/.idea/DH_live.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mobile and Web Real-time Live Streaming Digital Human! 2 | # 实时数字人 全网最快 3 | Notes:目前项目主要维护DH_live_mini, 目前最快的2D视频数字人方案,没有之一,项目含有网页推理的案例,不依赖任何GPU,可在任何手机设备实时运行。 4 | 5 | 原版DH_live已不再获支持,希望慎重考虑使用。原版使用方法参见另一分支 [here](https://github.com/kleinlee/DH_live/blob/main_250508/README_DH_live.md)。 6 | 7 | DHLive_mini手机浏览器直接推理演示 [bilibili video](https://www.bilibili.com/video/BV1UgFFeKEpp) 8 | 9 | 商业化网页应用:[matesx.com](matesx.com), 你可以直接打开网页查看完整体应用。 10 | 11 | ![微信图片_20250209153828](https://github.com/user-attachments/assets/32650fac-3885-4c98-886f-66258ef891a7) 12 | 13 | 14 | # News 15 | - 2025-01-26 最小化简化网页资源包,gzip资源小于2MB。简化视频数据,数据大小减半 16 | - 2025-02-09 增加ASR入口、增加一键切换形象。 17 | - 2025-02-27 优化渲染、去除参照视频,目前只需要一段视频即可生成。 18 | - 2025-03-11 增加DH_live_mini的CPU支持。 19 | - 2025-04-09 增加对IOS17以上的长视频支持。 20 | - 2025-04-25 增加完整的实时对话服务,包含vad-asr-llm-tts-数字人全流程,请见web_demo/server_realtime.py。 21 | 22 | # 数字人方案对比 23 | 24 | | 方案名称 | 单帧算力(Mflops) | 使用方式 | 脸部分辨率 | 适用设备 | 25 | |------------------------------|-------------------|------------|------------|------------------------------------| 26 | | Ultralight-Digital-Human(mobile) | 1100 | 单人训练 | 160 | 中高端手机APP | 27 | | DH_live_mini | 39 | 无须训练 | 128 | 所有设备,网页&APP&小程序 | 28 | | DH_live | 55046 | 无须训练 | 256 | 30系以上显卡 | 29 | | duix.ai | 1200 | 单人训练 | 160 | 中高端手机APP | 30 | 31 | ### checkpoint 32 | All checkpoint files are moved to [BaiduDrive](https://pan.baidu.com/s/1jH3WrIAfwI3U5awtnt9KPQ?pwd=ynd7) 33 | [GoogleDrive](https://drive.google.com/drive/folders/1az5WEWOFmh0_yrF3I9DEyctMyjPolo8V?usp=sharing) 34 | 35 | ### Key Features 36 | - **最低算力**: 推理一帧的算力39 Mflops,有多小?小于手机端大部分的人脸检测算法。 37 | - **最小存储**:整个网页资源可以压缩到3MB! 
38 | - **无须训练**: 开箱即用,无需复杂的训练过程。 39 | 40 | ### 平台支持 41 | - **windows**: 支持视频数据处理、离线视频合成、网页服务器。 42 | - **linux&macOS**:支持视频数据处理、搭建网页服务器,不支持离线视频合成。 43 | - **网页&小程序**:支持客户端直接打开(可搜索小程序“MatesX数字生命”,功能和网页版完全一致)。 44 | - **App**:webview方式调用网页或重构原生应用。 45 | 46 | 47 | | 平台 | Windows | Linux/macOS | 48 | |---------------|---------------|-------------| 49 | | 原始视频处理&网页资源准备 | ✅ | ✅ | 50 | | 离线视频合成 | ✅ | ❌ | 51 | | 构建网页服务器 | ✅ | ✅ | 52 | | 实时对话 | ✅ | ✅ | 53 | 54 | ## Easy Usage (Gradio) 55 | 第一次使用或想获取完整流程请运行此Gradio。 56 | ```bash 57 | python app.py 58 | ``` 59 | 60 | ## Usage 61 | 62 | ### Create Environment 63 | First, navigate to the `checkpoint` directory and unzip the model file: 64 | ```bash 65 | conda create -n dh_live python=3.11 66 | conda activate dh_live 67 | pip install torch --index-url https://download.pytorch.org/whl/cu124 68 | pip install -r requirements.txt 69 | cd checkpoint 70 | ``` 71 | 注意如果没有GPU可以安装CPU版本的pytorch: pip install torch 72 | 73 | Download and unzip checkpoint files. 74 | ### Prepare Your Video 75 | ```bash 76 | python data_preparation_mini.py video_data/000002/video.mp4 video_data/000002 77 | python data_preparation_web.py video_data/000002 78 | ``` 79 | 处理后的视频信息将存储在 ./video_data 目录中。 80 | ### Run with Audio File ( linux and MacOS not supported!!! ) 81 | 语音文件必须是单通道16K Hz的wav文件格式。 82 | ```bash 83 | python demo_mini.py video_data/000002/assets video_data/audio0.wav 1.mp4 84 | ``` 85 | ### Web demo 86 | 请将新形象包中的assets文件(譬如video_data/000002/assets)替换 assets 文件夹中的对应文件 87 | ```bash 88 | python web_demo/server.py 89 | ``` 90 | 可以打开 localhost:8888/static/MiniLive.html。 91 | 92 | 如果想体验最佳的流式对话效果,请认真阅读 [web_demo/README.md](https://github.com/kleinlee/DH_live/blob/main/web_demo/README.md),内含完整的可商用工程。 93 | ### Authorize 94 | 网页部分的商业应用涉及形象授权(去除logo):访问[授权说明] (www.matesx.com/authorized.html) 95 | 96 | 上传你生成的combined_data.json.gz, 授权后下载得到新的combined_data.json.gz,覆盖原文件即可去除logo。 97 | ### Chat Now 98 | 访问 matesx.com, 即刻在任意设备开启定制形象、克隆语音、打造人设的数字人对话之旅。 99 | 100 | 小程序请搜索“MatesX数字生命” 101 | ## Algorithm Architecture 102 | ![deepseek_mermaid_20250506_c244f8](https://github.com/user-attachments/assets/548f65aa-3ede-4657-bf4e-56b3c93272bb) 103 | 104 | ## License 105 | DH_live is licensed under the MIT License. 
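### Dialogue API example
For reference, a hedged Python sketch of a client for the mock dialogue endpoint `/eb_stream` served by `web_demo/server.py`, using the request/response fields documented in the Gradio app (`app.py`): `input_mode`/`prompt`/`audio`/`voice_speed`/`voice_id` in, streamed `text`/`audio`/`endpoint` out. The newline-delimited JSON framing, the port, and the `requests` dependency are assumptions; check `server.py` and `static/js/dialog.js` for the actual behavior.
```python
import base64
import json

import requests

SERVER_URL = "http://localhost:8888/eb_stream"  # replace with your own dialogue service

payload = {
    "input_mode": "text",               # "text" or "audio"
    "prompt": "你好,今天天气怎么样?",    # required when input_mode == "text"
    "voice_speed": "",                  # optional TTS speed
    "voice_id": "",                     # optional TTS voice
}

with requests.post(SERVER_URL, json=payload, stream=True, timeout=60) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue
        chunk = json.loads(line)        # assumed framing: one JSON object per streamed line
        print(chunk.get("text", ""))    # partial reply text
        if chunk.get("audio"):          # optional Base64-encoded audio chunk
            wav_bytes = base64.b64decode(chunk["audio"])
        if chunk.get("endpoint"):       # True on the final chunk
            break
```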
106 | 107 | ## 联系 108 | | 加我好友,请备注“进群”,拉你进去微信交流群。| 进入QQ群聊,分享看法和最新资讯。 | 109 | |-------------------|------------------------------------------------------------------------------------------| 110 | | ![微信交流群](https://github.com/user-attachments/assets/b1f24ebb-153b-44b1-b522-14f765154110) | ![QQ群聊](https://github.com/user-attachments/assets/29bfef3f-438a-4b9f-ba09-e1926d1669cb) | 111 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import shutil 3 | import gradio as gr 4 | import subprocess 5 | import uuid 6 | from data_preparation_mini import data_preparation_mini 7 | from data_preparation_web import data_preparation_web 8 | 9 | 10 | # 自定义 CSS 样式 11 | css = """ 12 | #video-output video { 13 | max-width: 300px; 14 | max-height: 300px; 15 | display: block; 16 | margin: 0 auto; 17 | } 18 | """ 19 | 20 | video_dir_path = "" 21 | # 假设你已经有了这两个函数 22 | def data_preparation(video1, resize_option): 23 | global video_dir_path 24 | # 处理视频的逻辑 25 | video_dir_path = "video_data/{}".format(uuid.uuid4()) 26 | data_preparation_mini(video1, video_dir_path, resize_option) 27 | data_preparation_web(video_dir_path) 28 | 29 | return "视频处理完成,保存至目录{}".format(video_dir_path) 30 | 31 | def demo_mini(audio): 32 | global video_dir_path 33 | # 生成视频的逻辑 34 | audio_path = audio # 解包元组 35 | wav_path = "video_data/tmp.wav" 36 | ffmpeg_cmd = "ffmpeg -i {} -ac 1 -ar 16000 -y {}".format(audio_path, wav_path) 37 | print(ffmpeg_cmd) 38 | os.system(ffmpeg_cmd) 39 | output_video_name = "video_data/tmp.mp4" 40 | asset_path = os.path.join(video_dir_path, "assets") 41 | from demo_mini import interface_mini 42 | interface_mini(asset_path, wav_path, output_video_name) 43 | return output_video_name # 返回生成的视频文件路径 44 | 45 | # 启动网页的函数 46 | def launch_server(): 47 | global video_dir_path 48 | asset_path = os.path.join(video_dir_path, "assets") 49 | target_path = os.path.join("web_demo", "static", "assets") 50 | 51 | # 如果目标目录存在,先删除 52 | if os.path.exists(target_path): 53 | shutil.rmtree(target_path) 54 | 55 | # 将 asset_path 目录下的所有文件拷贝到 web_demo/static/assets 目录下 56 | shutil.copytree(asset_path, target_path) 57 | 58 | # 启动 server.py 59 | subprocess.Popen(["python", "web_demo/server.py"]) 60 | 61 | return "访问 http://localhost:8888/static/MiniLive_new.html" 62 | 63 | # 定义 Gradio 界面 64 | def create_interface(): 65 | with gr.Blocks(css=css) as demo: 66 | # 标题 67 | gr.Markdown("# 视频处理与生成工具") 68 | 69 | # 第一部分:上传静默视频和说话视频 70 | gr.Markdown("## 第一部分:视频处理") 71 | gr.Markdown(""" 72 | - **静默视频**:时长建议在 5-30 秒之间,嘴巴不要动(保持闭嘴或微张)。嘴巴如果有动作会影响效果,请认真对待。 73 | """) 74 | with gr.Row(): 75 | with gr.Column(): 76 | video1 = gr.Video(label="上传静默视频", elem_id="video-output", sources="upload") 77 | # 增加可选项 78 | resize_option = gr.Checkbox(label="是否转为最高720P(适配手机)", value=True) 79 | process_button = gr.Button("处理视频") 80 | process_output = gr.Textbox(label="处理结果") 81 | 82 | # 分隔线 83 | gr.Markdown("---") 84 | 85 | # 第二部分:上传音频文件并生成视频 86 | gr.Markdown("## 第二部分:测试语音生成视频(不支持linux和MacOS,请跳过此步)") 87 | gr.Markdown(""" 88 | - 上传音频文件后,点击“生成视频”按钮,程序会调用 `demo_mini` 函数完成推理并生成视频。 89 | - 此步骤用于初步验证结果。网页demo请执行第三步。 90 | """) 91 | # audio = gr.Audio(label="上传音频文件") 92 | 93 | with gr.Row(): 94 | with gr.Column(): 95 | audio = gr.Audio(label="上传音频文件", type="filepath") 96 | generate_button = gr.Button("生成视频") 97 | with gr.Column(): 98 | video_output = gr.Video(label="生成的视频", elem_id="video-output") 99 | 100 | # 分隔线 101 | gr.Markdown("---") 102 | 103 | 
# 第三部分:启动网页 104 | gr.Markdown("## 第三部分:启动网页") 105 | launch_button = gr.Button("启动网页") 106 | gr.Markdown(""" 107 | - **注意**:本项目使用了 WebCodecs API,该 API 仅在安全上下文(HTTPS 或 localhost)中可用。因此,在部署或测试时,请确保您的网页在 HTTPS 环境下运行,或者使用 localhost 进行本地测试。 108 | """) 109 | launch_output = gr.Textbox(label="启动结果") 110 | # 扩展功能提示 111 | gr.Markdown(""" 112 | **🔔 扩展功能提示:** 113 | > 更多高级功能(实时大模型对话、动态更换任务、音色切换等)请前往 114 | > `web_demo` 目录按照说明配置后,启动 115 | > `web_demo/server_realtime.py` 体验完整功能 116 | """) 117 | gr.Markdown(""" 118 | - 点击“启动网页”按钮后,会启动 `server.py`,提供一个模拟对话服务。 119 | - 在 `static/js/dialog.js` 文件中,找到第 1 行,将 server_url=`http://localhost:8888/eb_stream` 替换为您自己的对话服务网址。例如: 120 | ```bash 121 | https://your-dialogue-service.com/eb_stream 122 | ``` 123 | - `server.py` 提供了一个模拟对话服务的示例。它接收 JSON 格式的输入,并流式返回 JSON 格式的响应。 124 | # API 接口说明 125 | 126 | ## 输入 JSON 格式 127 | 128 | | 字段名 | 必填 | 类型 | 说明 | 默认值 | 129 | |--------------|------|--------|----------------------------------------------------------------------|--------| 130 | | `input_mode` | 是 | 字符串 | 输入模式,可选值为 `"text"` 或 `"audio"`,分别对应文字对话和语音对话输入 | "audio" | 131 | | `prompt` | 条件 | 字符串 | 当 `input_mode` 为 `"text"` 时必填,表示用户输入的对话内容 | 无 | 132 | | `audio` | 条件 | 字符串 | 当 `input_mode` 为 `"audio"` 时必填,表示 Base64 编码的音频数据 | 无 | 133 | | `voice_speed`| 否 | 字符串 | TTS 语速,可选 | "" | 134 | | `voice_id` | 否 | 字符串 | TTS 音色,可选 | "" | 135 | 136 | ## 输出 JSON 格式(流式返回) 137 | 138 | | 字段名 | 必填 | 类型 | 说明 | 默认值 | 139 | |------------|------|--------|----------------------------------------------------------------------|----------| 140 | | `text` | 是 | 字符串 | 返回的部分对话文本 | 无 | 141 | | `audio` | 否 | 字符串 | Base64 编码的音频数据,可选 | 无 | 142 | | `endpoint` | 是 | 布尔 | 是否为对话的最后一个片段,`true` 表示结束 | `false` | 143 | 144 | --- 145 | 146 | #### 输入输出示例 147 | ```json 148 | { 149 | "input_mode": "text", 150 | "prompt": "你好,今天天气怎么样?", 151 | "voice_speed": "", 152 | "voice_id": "" 153 | } 154 | 输出 155 | { 156 | "text": "今天天气晴朗,温度适宜。", 157 | "audio": "SGVsbG8sIFdvcm...", 158 | "endpoint": false 159 | } 160 | ``` 161 | """) 162 | # 第四部分:商业授权和更新 163 | gr.Markdown("## 第四部分:完整服务与更新") 164 | gr.Markdown(""" 165 | - 可访问www.matesx.com 体验完整服务。 166 | - 商业授权(去除logo):访问www.matesx.com/authorized.html, 上传你生成的combined_data.json.gz, 授权后下载得到新的combined_data.json.gz,覆盖原文件即可去除logo。 167 | - 人物切换:已开放功能,可自己整改,官方后续会完善。 168 | - 未来12个月会持续更新效果,可以关注公众号”Mates数字生命“获取即时动态。 169 | """) 170 | 171 | 172 | # 绑定按钮点击事件 173 | process_button.click(data_preparation, inputs=[video1, resize_option], outputs=process_output) 174 | generate_button.click(demo_mini, inputs=audio, outputs=video_output) 175 | launch_button.click(launch_server, outputs=launch_output) 176 | 177 | return demo 178 | 179 | # 创建 Gradio 界面并启动 180 | if __name__ == "__main__": 181 | demo = create_interface() 182 | demo.launch() -------------------------------------------------------------------------------- /data/pca.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/data/pca.pkl -------------------------------------------------------------------------------- /data_preparation_web.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import tqdm 3 | import numpy as np 4 | import cv2 5 | import sys 6 | import os 7 | import gzip 8 | from talkingface.data.few_shot_dataset import get_image 9 | import shutil 10 | from talkingface.utils import crop_mouth, main_keypoints_index, smooth_array,normalizeLips 11 | import json 12 | from 
mini_live.obj.wrap_utils import index_wrap, index_edge_wrap 13 | import pickle 14 | from talkingface.models.DINet_mini import model_size 15 | 16 | def step0_keypoints(video_path, out_path): 17 | Path_output_pkl = video_path + "/processed.pkl" 18 | with open(Path_output_pkl, "rb") as f: 19 | pts_3d = pickle.load(f) 20 | 21 | pts_3d = pts_3d.reshape(len(pts_3d), -1) 22 | smooth_array_ = smooth_array(pts_3d, weight=[0.02, 0.09, 0.78, 0.09, 0.02]) 23 | pts_3d = smooth_array_.reshape(len(pts_3d), 478, 3) 24 | 25 | video_path = os.path.join(video_path, "processed.mp4") 26 | cap = cv2.VideoCapture(video_path) 27 | vid_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # 宽度 28 | vid_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # 高度 29 | cap.release() 30 | out_path = os.path.join(out_path, "01.mp4") 31 | try: 32 | # 复制文件 33 | shutil.copy(video_path, out_path) 34 | print(f"视频已成功复制到 {out_path}") 35 | except Exception as e: 36 | print(f"复制文件时出错: {e}") 37 | return pts_3d,vid_width,vid_height 38 | 39 | def step1_crop_mouth(pts_3d, vid_width, vid_height): 40 | list_source_crop_rect = [crop_mouth(source_pts[main_keypoints_index], vid_width, vid_height) for source_pts in 41 | pts_3d] 42 | list_source_crop_rect = np.array(list_source_crop_rect).reshape(len(pts_3d), -1) 43 | face_size = (list_source_crop_rect[:,2] - list_source_crop_rect[:,0]).mean()/2.0 + (list_source_crop_rect[:,3] - list_source_crop_rect[:,1]).mean()/2.0 44 | face_size = int(face_size)//2 * 2 45 | face_mid = (list_source_crop_rect[:,2:] + list_source_crop_rect[:,0:2])/2. 46 | # step 1: Smooth Cropping Rectangle Transition 47 | # Since HTML video playback can have inconsistent frame rates and may not align precisely from frame to frame, adjust the cropping rectangle to transition smoothly, compensating for potential misalignment. 
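    # Assumed behavior of smooth_array (imported from talkingface.utils): a 5-tap weighted
    # moving average along the frame axis, roughly
    #   out[i] = 0.10*x[i-2] + 0.20*x[i-1] + 0.40*x[i] + 0.20*x[i+1] + 0.10*x[i+2],
    # applied to the crop-center coordinates so the rectangle drifts smoothly between
    # frames instead of jittering with the raw landmarks.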
48 | face_mid = smooth_array(face_mid, weight=[0.10, 0.20, 0.40, 0.20, 0.10]) 49 | face_mid = face_mid.astype(int) 50 | if face_mid[:, 0].max() + face_size / 2 > vid_width or face_mid[:, 1].max() + face_size / 2 > vid_height: 51 | raise ValueError("人脸范围超出了视频,请保证视频合格后再重试") 52 | 53 | list_source_crop_rect = np.concatenate([face_mid - face_size // 2, face_mid + face_size // 2], axis = 1) 54 | 55 | # import pandas as pd 56 | # pd.DataFrame(list_source_crop_rect).to_csv("sss.csv") 57 | 58 | standard_size = model_size 59 | list_standard_v = [] 60 | for frame_index in range(len(list_source_crop_rect)): 61 | source_pts = pts_3d[frame_index] 62 | source_crop_rect = list_source_crop_rect[frame_index] 63 | print(source_crop_rect) 64 | standard_v = get_image(source_pts, source_crop_rect, input_type="mediapipe", resize=standard_size) 65 | 66 | list_standard_v.append(standard_v) 67 | 68 | return list_source_crop_rect, list_standard_v 69 | 70 | def generate_combined_data(list_source_crop_rect, list_standard_v, video_path, out_path): 71 | from mini_live.obj.obj_utils import generateRenderInfo, generateWrapModel 72 | from talkingface.run_utils import calc_face_mat 73 | from mini_live.obj.wrap_utils import newWrapModel 74 | from talkingface.render_model_mini import RenderModel_Mini 75 | 76 | # Step 2: Generate face3D.obj data 77 | render_verts, render_face = generateRenderInfo() 78 | face_pts_mean = render_verts[:478, :3].copy() 79 | 80 | wrapModel_verts, wrapModel_face = generateWrapModel() 81 | mat_list, _, face_pts_mean_personal_primer = calc_face_mat(np.array(list_standard_v), face_pts_mean) 82 | 83 | # face_pts_mean_personal_primer[INDEX_MP_LIPS] = face_pts_mean[INDEX_MP_LIPS] * 0.33 + face_pts_mean_personal_primer[INDEX_MP_LIPS] * 0.66 84 | face_pts_mean_personal_primer = normalizeLips(face_pts_mean_personal_primer, face_pts_mean) 85 | face_wrap_entity = newWrapModel(wrapModel_verts, face_pts_mean_personal_primer) 86 | 87 | face3D_data = [] 88 | for i in face_wrap_entity: 89 | face3D_data.append("v {:.3f} {:.3f} {:.3f} {:.02f} {:.0f}\n".format(i[0], i[1], i[2], i[3], i[4])) 90 | for i in range(len(wrapModel_face) // 3): 91 | face3D_data.append("f {0} {1} {2}\n".format(wrapModel_face[3 * i] + 1, wrapModel_face[3 * i + 1] + 1, 92 | wrapModel_face[3 * i + 2] + 1)) 93 | 94 | # Step 3: Generate ref_data.txt data 95 | renderModel_mini = RenderModel_Mini() 96 | renderModel_mini.loadModel("checkpoint/DINet_mini/epoch_40.pth") 97 | 98 | Path_output_pkl = "{}/processed.pkl".format(video_path) 99 | with open(Path_output_pkl, "rb") as f: 100 | ref_images_info = pickle.load(f) 101 | 102 | video_path = "{}/processed.mp4".format(video_path) 103 | cap = cv2.VideoCapture(video_path) 104 | vid_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 105 | assert vid_frame_count > 0, "处理后的视频无有效帧" 106 | vid_width_ref = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 107 | vid_height_ref = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 108 | 109 | standard_size = model_size 110 | frame_index = 0 111 | # cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index) 112 | ret, frame = cap.read() 113 | cap.release() 114 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGBA) 115 | source_pts = ref_images_info[frame_index] 116 | source_crop_rect = crop_mouth(source_pts[main_keypoints_index], vid_width_ref, vid_height_ref) 117 | 118 | standard_img = get_image(frame, source_crop_rect, input_type="image", resize=standard_size) 119 | standard_v = get_image(source_pts, source_crop_rect, input_type="mediapipe", resize=standard_size) 120 | 121 | 122 | 
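    # The first processed frame acts as the reference identity: reset_charactor pushes it
    # through the generator once, and the resulting intermediate feature map
    # (ref_in_feature, which demo_mini.py later reshapes to [1, 20, input_height//4, input_width//4])
    # is flattened, rounded and stored in combined_data["ref_data"], so demo_mini.py and the
    # web player can load it directly instead of recomputing the reference branch.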
renderModel_mini.reset_charactor(standard_img, standard_v[main_keypoints_index], standard_size=standard_size) 123 | 124 | ref_in_feature = renderModel_mini.net.infer_model.ref_in_feature 125 | ref_in_feature = ref_in_feature.detach().squeeze(0).cpu().float().numpy().flatten() 126 | # cv2.imwrite(os.path.join(out_path, 'ref.png'), renderModel_mini.ref_img_save) 127 | rounded_array = np.round(ref_in_feature, 6) 128 | 129 | # Combine all data into a single JSON object 130 | combined_data = { 131 | "uid": "matesx_" + str(uuid.uuid4()), 132 | "frame_num": len(list_standard_v), 133 | "face3D_obj": face3D_data, 134 | "ref_data": rounded_array.tolist(), 135 | "json_data": [], 136 | "authorized": False, 137 | } 138 | 139 | for frame_index in range(len(list_source_crop_rect)): 140 | source_crop_rect = list_source_crop_rect[frame_index] 141 | standard_v = list_standard_v[frame_index] 142 | 143 | standard_v = standard_v[index_wrap, :2].flatten().tolist() 144 | mat = mat_list[frame_index].T.flatten().tolist() 145 | standard_v_rounded = [round(i, 5) for i in mat] + [round(i, 1) for i in standard_v] 146 | combined_data["json_data"].append({"rect": source_crop_rect.tolist(), "points": standard_v_rounded}) 147 | 148 | # with open(os.path.join(out_path, "combined_data.json"), "w") as f: 149 | # json.dump(combined_data, f) 150 | 151 | # Save as Gzip compressed JSON 152 | output_file = os.path.join(out_path, "combined_data.json.gz") 153 | with gzip.open(output_file, 'wt', encoding='UTF-8') as f: 154 | json.dump(combined_data, f) 155 | 156 | def data_preparation_web(path): 157 | video_path = os.path.join(path, "data") 158 | out_path = os.path.join(path, "assets") 159 | os.makedirs(out_path, exist_ok=True) 160 | pts_3d, vid_width,vid_height = step0_keypoints(video_path, out_path) 161 | list_source_crop_rect, list_standard_v = step1_crop_mouth(pts_3d, vid_width, vid_height) 162 | generate_combined_data(list_source_crop_rect, list_standard_v, video_path, out_path) 163 | 164 | def main(): 165 | # 检查命令行参数的数量 166 | if len(sys.argv) != 2: 167 | print("Usage: python data_preparation_web.py ") 168 | sys.exit(1) # 参数数量不正确时退出程序 169 | 170 | # 获取video_name参数 171 | video_dir_path = sys.argv[1] 172 | 173 | data_preparation_web(video_dir_path) 174 | 175 | if __name__ == "__main__": 176 | main() 177 | -------------------------------------------------------------------------------- /demo_mini.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | import gzip 4 | import json 5 | import cv2 6 | import numpy as np 7 | import sys 8 | import torch 9 | from talkingface.model_utils import LoadAudioModel, Audio2bs 10 | from talkingface.data.few_shot_dataset import get_image 11 | from mini_live.render import create_render_model 12 | from talkingface.models.DINet_mini import input_height,input_width 13 | from talkingface.model_utils import device 14 | 15 | from talkingface.models.DINet_mini import model_size 16 | 17 | def interface_mini(path, wav_path, output_video_path): 18 | # 加载音频模型 19 | Audio2FeatureModel = LoadAudioModel(r'checkpoint/lstm/lstm_model_epoch_325.pkl') 20 | 21 | # 加载渲染模型 22 | from talkingface.render_model_mini import RenderModel_Mini 23 | renderModel_mini = RenderModel_Mini() 24 | renderModel_mini.loadModel("checkpoint/DINet_mini/epoch_40.pth") 25 | 26 | # 设置标准尺寸和裁剪比例 27 | standard_size = model_size * 2 28 | crop_rotio = [0.5, 0.5, 0.5, 0.5] 29 | out_w = int(standard_size * (crop_rotio[0] + crop_rotio[1])) 30 | out_h = int(standard_size * (crop_rotio[2] + 
crop_rotio[3])) 31 | out_size = (out_w, out_h) 32 | renderModel_gl = create_render_model((out_w, out_h), floor=20) 33 | 34 | # 读取 Gzip 压缩的 JSON 文件 35 | combined_data_path = os.path.join(path, "combined_data.json.gz") 36 | with gzip.open(combined_data_path, 'rt', encoding='UTF-8') as f: 37 | combined_data = json.load(f) 38 | 39 | # 从 combined_data 中提取数据 40 | face3D_obj = combined_data["face3D_obj"] 41 | json_data = combined_data["json_data"] 42 | ref_data = np.array(combined_data["ref_data"], dtype=np.float32).reshape([1, 20, input_height//4, input_width//4]) 43 | 44 | # 设置 ref_data 到渲染模型 45 | renderModel_mini.net.infer_model.ref_in_feature = torch.from_numpy(ref_data).float().to(device) 46 | 47 | # 读取视频信息 48 | video_path = os.path.join(path, "01.mp4") 49 | cap = cv2.VideoCapture(video_path) 50 | vid_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 51 | vid_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 52 | vid_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 53 | 54 | # 初始化列表 55 | list_source_crop_rect = [] 56 | list_video_img = [] 57 | list_standard_img = [] 58 | list_standard_v = [] 59 | 60 | # 处理每一帧 61 | for frame_index in range(min(vid_frame_count, len(json_data))): 62 | ret, frame = cap.read() 63 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGBA) 64 | standard_v = json_data[frame_index]["points"][16:] 65 | source_crop_rect = json_data[frame_index]["rect"] 66 | 67 | standard_img = get_image(frame, source_crop_rect, input_type="image", resize=standard_size) 68 | 69 | list_video_img.append(frame) 70 | list_source_crop_rect.append(source_crop_rect) 71 | list_standard_img.append(standard_img) 72 | list_standard_v.append(np.array(standard_v).reshape(-1, 2) * 2) 73 | cap.release() 74 | 75 | # 生成矩阵列表 76 | mat_list = [np.array(i["points"][:16]).reshape(4, 4) * 2 for i in json_data] 77 | 78 | # 反转列表中的数据 79 | list_video_img_reversed = list_video_img[::-1] 80 | list_source_crop_rect_reversed = list_source_crop_rect[::-1] 81 | list_standard_img_reversed = list_standard_img[::-1] 82 | list_standard_v_reversed = list_standard_v[::-1] 83 | mat_list_reversed = mat_list[::-1] 84 | 85 | # 将反转后的数据与原有数据合并 86 | list_video_img = list_video_img + list_video_img_reversed 87 | list_source_crop_rect = list_source_crop_rect + list_source_crop_rect_reversed 88 | list_standard_img = list_standard_img + list_standard_img_reversed 89 | list_standard_v = list_standard_v + list_standard_v_reversed 90 | mat_list = mat_list + mat_list_reversed 91 | 92 | # 解析 face3D.obj 数据 93 | v_ = [] 94 | for line in face3D_obj: 95 | if line.startswith("v "): 96 | v0, v1, v2, v3, v4 = line[2:].split() 97 | v_.append(float(v0)) 98 | v_.append(float(v1)) 99 | v_.append(float(v2)) 100 | v_.append(float(v3)) 101 | v_.append(float(v4)) 102 | face_wrap_entity = np.array(v_).reshape(-1, 5) 103 | 104 | # 生成 VBO 105 | renderModel_gl.GenVBO(face_wrap_entity) 106 | 107 | # 生成音频特征 108 | bs_array = Audio2bs(wav_path, Audio2FeatureModel)[5:] * 0.5 109 | 110 | # 创建视频写入器 111 | task_id = str(uuid.uuid1()) 112 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 113 | save_path = "{}.mp4".format(task_id) 114 | videoWriter = cv2.VideoWriter(save_path, fourcc, 25, (int(vid_width), int(vid_height))) 115 | 116 | # 渲染每一帧 117 | for index2_ in range(len(bs_array)): 118 | frame_index = index2_ % len(mat_list) 119 | bs = np.zeros([12], dtype=np.float32) 120 | bs[:6] = bs_array[frame_index, :6] 121 | bs[1] = bs[1] / 2 * 1.6 122 | 123 | verts_frame_buffer = np.array(list_standard_v)[frame_index, :, :2].copy() / model_size - 1 124 | 125 | rgba = 
renderModel_gl.render2cv(verts_frame_buffer, out_size=out_size, mat_world=mat_list[frame_index], 126 | bs_array=bs) 127 | rgba = rgba[::2, ::2, :] 128 | gl_tensor = torch.from_numpy(rgba / 255.).float().permute(2, 0, 1).unsqueeze(0) 129 | source_tensor = cv2.resize(list_standard_img[frame_index], (model_size, model_size)) 130 | source_tensor = torch.from_numpy(source_tensor / 255.).float().permute(2, 0, 1).unsqueeze(0) 131 | 132 | warped_img = renderModel_mini.interface(source_tensor.to(device), gl_tensor.to(device)) 133 | 134 | image_numpy = warped_img.detach().squeeze(0).cpu().float().numpy() 135 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 136 | image_numpy = image_numpy.clip(0, 255) 137 | image_numpy = image_numpy.astype(np.uint8) 138 | 139 | x_min, y_min, x_max, y_max = list_source_crop_rect[frame_index] 140 | 141 | img_face = cv2.resize(image_numpy, (x_max - x_min, y_max - y_min)) 142 | img_bg = list_video_img[frame_index][:, :, :3] 143 | img_bg[y_min:y_max, x_min:x_max, :3] = img_face[:, :, :3] 144 | 145 | videoWriter.write(img_bg[:, :, ::-1]) 146 | videoWriter.release() 147 | 148 | # 使用 ffmpeg 合并音频和视频 149 | os.system( 150 | "ffmpeg -i {} -i {} -c:v libx264 -pix_fmt yuv420p -y {}".format(save_path, wav_path, output_video_path)) 151 | os.remove(save_path) 152 | 153 | cv2.destroyAllWindows() 154 | 155 | def main(): 156 | # 检查命令行参数的数量 157 | if len(sys.argv) < 4: 158 | print("Usage: python demo_mini.py ") 159 | sys.exit(1) # 参数数量不正确时退出程序 160 | 161 | # 获取命令行参数 162 | asset_path = sys.argv[1] 163 | print(f"Video asset path is set to: {asset_path}") 164 | wav_path = sys.argv[2] 165 | print(f"Audio path is set to: {wav_path}") 166 | output_video_name = sys.argv[3] 167 | print(f"Output video name is set to: {output_video_name}") 168 | 169 | # 调用主函数 170 | interface_mini(asset_path, wav_path, output_video_name) 171 | 172 | # 示例使用 173 | if __name__ == "__main__": 174 | main() -------------------------------------------------------------------------------- /mini_live/bs_texture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/mini_live/bs_texture.png -------------------------------------------------------------------------------- /mini_live/bs_texture_halfFace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/mini_live/bs_texture_halfFace.png -------------------------------------------------------------------------------- /mini_live/face_fusion_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/mini_live/face_fusion_mask.png -------------------------------------------------------------------------------- /mini_live/generate_fusion_mask.py: -------------------------------------------------------------------------------- 1 | # import numpy as np 2 | # import cv2 3 | # import os 4 | # 5 | # face_fusion_mask = np.zeros([128, 128], dtype = np.uint8) 6 | # for i in range(17): 7 | # face_fusion_mask[i:-i,i:-i] = min(255, 16*i) 8 | # 9 | # cv2.imwrite("face_fusion_mask.png", face_fusion_mask) 10 | 11 | 12 | # # from mini_live.obj.wrap_utils import index_wrap 13 | # # image2 = cv2.imread("bs_texture.png") 14 | # # image3 = np.zeros([12, 256, 3], dtype=np.uint8) 15 | # # image3[:, :len(index_wrap)] = 
image2[:, index_wrap] 16 | # # cv2.imwrite("bs_texture_halfFace.png", image3) 17 | # 18 | # 19 | from PIL import Image, ImageDraw 20 | 21 | # 1. 构造一个全黑的100*100的图片 22 | image = Image.new('RGB', (100, 100), color=(0, 0, 0)) 23 | draw = ImageDraw.Draw(image) 24 | 25 | # 2. 在图片中构造19个矩形 26 | for i in range(19): 27 | size = 98 - 2 * i 28 | color = (14 * i, 14 * i, 14 * i) 29 | x0 = (100 - size) // 2 30 | y0 = (100 - size) // 2 31 | x1 = x0 + size 32 | y1 = y0 + size 33 | draw.rounded_rectangle([x0, y0, x1, y1], radius=25, fill=color) 34 | 35 | image.save('final_image.png') 36 | image.show() 37 | 38 | # 3. 图片按照高度分为20、60、20三个区域 39 | region1 = image.crop((0, 0, 100, 20)) 40 | region2 = image.crop((0, 20, 100, 80)) 41 | region3 = image.crop((0, 80, 100, 100)) 42 | 43 | # 对第一个区域resize为100*6的区域 44 | region1_resized = region1.resize((100, 8)) 45 | # 对第三个区域resize为100*8的区域 46 | region3_resized = region3.resize((100, 10)) 47 | 48 | # 将三个区域concatenate起来形成新图片 49 | new_image1 = Image.new('RGB', (100, 78)) 50 | new_image1.paste(region1_resized, (0, 0)) 51 | new_image1.paste(region2, (0, 8)) 52 | new_image1.paste(region3_resized, (0, 68)) 53 | 54 | # 4. 新图片按照宽度分为20、60、20三个区域 55 | region1_width = new_image1.crop((0, 0, 20, 78)) 56 | region2_width = new_image1.crop((20, 0, 80, 78)) 57 | region3_width = new_image1.crop((80, 0, 100, 78)) 58 | 59 | # 对第一个、第三个区域resize为4*74的区域 60 | region1_width_resized = region1_width.resize((10, 78)) 61 | region3_width_resized = region3_width.resize((10, 78)) 62 | 63 | # 将三个区域concatenate起来,再次形成新图片 64 | new_image2 = Image.new('RGB', (80, 78)) 65 | new_image2.paste(region1_width_resized, (0, 0)) 66 | new_image2.paste(region2_width, (10, 0)) 67 | new_image2.paste(region3_width_resized, (70, 0)) 68 | 69 | # 5. 新图片resize为(72, 56) 70 | final_image = new_image2.resize((72, 72)) 71 | 72 | # 保存最终图片 73 | final_image.save('mouth_fusion_mask.png') 74 | final_image.show() -------------------------------------------------------------------------------- /mini_live/icon40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/mini_live/icon40.png -------------------------------------------------------------------------------- /mini_live/mouth_fusion_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/mini_live/mouth_fusion_mask.png -------------------------------------------------------------------------------- /mini_live/obj/image_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | current_dir = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | def get_standard_image_(img, kps, crop_coords, resize = (256, 256)): 7 | h = img.shape[0] 8 | w = img.shape[1] 9 | c = img.shape[2] 10 | (x_min, y_min, x_max, y_max) = [int(ii) for ii in crop_coords] 11 | new_w = x_max - x_min 12 | new_h = y_max - y_min 13 | img_new = np.zeros([new_h, new_w, c], dtype=np.uint8) 14 | 15 | # 确定裁剪区域上边top和左边left坐标 16 | top = int(y_min) 17 | left = int(x_min) 18 | # 裁剪区域与原图的重合区域 19 | top_coincidence = int(max(top, 0)) 20 | bottom_coincidence = int(min(y_max, h)) 21 | left_coincidence = int(max(left, 0)) 22 | right_coincidence = int(min(x_max, w)) 23 | img_new[top_coincidence - top:bottom_coincidence - top, left_coincidence - left:right_coincidence - left, :] = img[ 24 | 
top_coincidence:bottom_coincidence, 25 | left_coincidence:right_coincidence, 26 | :] 27 | 28 | img_new = cv2.resize(img_new, resize) 29 | kps = kps - np.array([left, top, 0]) 30 | 31 | factor = resize[0]/new_w 32 | kps = kps * factor 33 | return img_new, kps 34 | 35 | def get_standard_image(img_rgba, source_pts, source_crop_rect, out_size): 36 | ''' 37 | 将输入的RGBA图像和关键点点集转换为标准图像和标准顶点集。 38 | 39 | 参数: 40 | img_rgba (numpy.ndarray): 输入的RGBA图像,形状为 (H, W, 4)。 41 | source_pts (numpy.ndarray): 源点集,形状为 (N, 3),其中N是点的数量,每个点有三个坐标 (x, y, z)。 42 | source_crop_rect (tuple): 源图像的裁剪矩形,格式为 (x, y, width, height)。 43 | out_size (int): 输出图像的大小,格式为 (width, height)。 44 | 45 | 返回: 46 | standard_img (numpy.ndarray): 标准化的图像,形状为 (out_size, out_size, 4)。 47 | standard_v (numpy.ndarray): 标准化的顶点集,形状为 (N, 3)。 48 | standard_vt (numpy.ndarray): 标准化的顶点集的纹理坐标,形状为 (N, 2)。 49 | ''' 50 | source_pts[:, 2] = source_pts[:, 2] - np.max(source_pts[:, 2]) 51 | standard_img, standard_v = get_standard_image_(img_rgba, source_pts, source_crop_rect, resize=out_size) 52 | 53 | standard_vt = standard_v.copy() 54 | standard_vt = standard_vt[:,:2]/ out_size 55 | return standard_img, standard_v, standard_vt 56 | 57 | def crop_face_from_several_images(pts_array_origin, img_w, img_h): 58 | x_min, y_min, x_max, y_max = np.min(pts_array_origin[:, :, 0]), np.min( 59 | pts_array_origin[:, :, 1]), np.max( 60 | pts_array_origin[:, :, 0]), np.max(pts_array_origin[:, :, 1]) 61 | new_w = (x_max - x_min) * 2 62 | new_h = (y_max - y_min) * 2 63 | center_x = (x_max + x_min) / 2. 64 | center_y = y_min + (y_max - y_min) * 0.25 65 | x_min, y_min, x_max, y_max = int(center_x - new_w / 2), int(center_y - new_h / 2), int( 66 | center_x + new_w / 2), int(center_y + new_h / 2) 67 | x_min = max(0, x_min) 68 | y_min = max(0, y_min) 69 | x_max = min(x_max, img_w) 70 | y_max = min(y_max, img_h) 71 | new_size = min((x_max + x_min) / 2., (y_max + y_min) / 2.) 72 | center_x = (x_max + x_min) / 2. 73 | center_y = (y_max + y_min) / 2. 74 | x_min, y_min, x_max, y_max = int(center_x - new_size), int(center_y - new_size), int( 75 | center_x + new_size), int(center_y + new_size) 76 | return np.array([x_min, y_min, x_max, y_max]) 77 | 78 | def crop_face_from_image(kps, crop_rotio = [0.6,0.6,0.65,1.35]): 79 | ''' 80 | 只为了裁剪图片 81 | :param kps: 82 | :param crop_rotio: 83 | :param standard_size: 84 | :return: 85 | ''' 86 | x2d = kps[:, 0] 87 | y2d = kps[:, 1] 88 | w_span = x2d.max() - x2d.min() 89 | h_span = y2d.max() - y2d.min() 90 | crop_size = int(2*max(h_span, w_span)) 91 | center_x = (x2d.max() + x2d.min()) / 2. 92 | center_y = (y2d.max() + y2d.min()) / 2. 93 | # 确定裁剪区域上边top和左边left坐标,中心点是(x2d.max() + x2d.min()/2, y2d.max() + y2d.min()/2) 94 | y_min = int(center_y - crop_size*crop_rotio[2]) 95 | y_max = int(center_y + crop_size*crop_rotio[3]) 96 | x_min = int(center_x - crop_size*crop_rotio[0]) 97 | x_max = int(center_x + crop_size*crop_rotio[1]) 98 | return np.array([x_min, y_min, x_max, y_max]) 99 | 100 | def check_keypoint(img, pts_): 101 | point_size = 1 102 | point_color = (0, 0, 255) # BGR 103 | thickness = 4 # 0 、4、8 104 | for coor in pts_: 105 | # coor = (coor +1 )/2. 
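        # pts_ is expected in pixel coordinates here; the commented-out line above hints at
        # an earlier normalized ([-1, 1]) representation that would need rescaling first.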
106 | cv2.circle(img, (int(coor[0]), int(coor[1])), point_size, point_color, thickness) 107 | cv2.imshow("a", img) 108 | cv2.waitKey(-1) 109 | 110 | if __name__ == "__main__": 111 | from talkingface.mediapipe_utils import detect_face_mesh 112 | import glob 113 | 114 | image_list = glob.glob(r"F:\C\AI\CV\TalkingFace\OpenGLRender_0830\face_rgba/*.png") 115 | image_list.sort() 116 | for index, img_path in enumerate(image_list): 117 | img_primer_rgba = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) 118 | source_pts = detect_face_mesh([img_primer_rgba[:, :, :3]])[0] 119 | img_primer_rgba = cv2.cvtColor(img_primer_rgba, cv2.COLOR_BGRA2RGBA) 120 | 121 | source_crop_rect = crop_face_from_image(source_pts, crop_rotio=[0.75, 0.75, 0.65, 1.35]) 122 | standard_img, standard_v, standard_vt = get_standard_image(img_primer_rgba, source_pts, source_crop_rect, 123 | out_size=(750, 1000)) 124 | print(np.max(standard_vt[:, 0])) 125 | print(np.max(standard_vt[:, 1])) 126 | point_size = 1 127 | point_color = (0, 0, 255) # BGR 128 | thickness = 4 # 0 、4、8 129 | pts_ = standard_v 130 | img = standard_img 131 | for coor in pts_: 132 | # coor = (coor +1 )/2. 133 | cv2.circle(img, (int(coor[0]), int(coor[1])), point_size, point_color, thickness) 134 | cv2.imshow("a", img) 135 | cv2.waitKey(-1) -------------------------------------------------------------------------------- /mini_live/obj/obj_mediapipe/generate_wrap_obj.py: -------------------------------------------------------------------------------- 1 | index_wrap = [0, 2, 11, 12, 13, 14, 15, 16, 17, 18, 32, 36, 37, 38, 39, 40, 41, 42, 43, 50, 57, 58, 61, 2 | 62, 72, 73, 74, 76, 77, 78, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 95, 3 | 96, 97, 98, 100, 101, 106, 116, 117, 118, 119, 123, 129, 132, 135, 136, 137, 138, 140, 4 | 142, 146, 147, 148, 149, 150, 152, 164, 165, 167, 169, 170, 171, 172, 175, 176, 177, 5 | 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 191, 192, 194, 199, 200, 201, 202, 6 | 203, 204, 205, 206, 207, 208, 210, 211, 212, 213, 214, 215, 216, 227, 234, 262, 266, 7 | 267, 268, 269, 270, 271, 272, 273, 280, 287, 288, 291, 292, 302, 303, 304, 306, 307, 8 | 308, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 9 | 326, 327, 329, 330, 335, 345, 346, 347, 348, 352, 358, 361, 364, 365, 366, 367, 369, 10 | 371, 375, 376, 377, 378, 379, 391, 393, 394, 395, 396, 397, 400, 401, 402, 403, 404, 11 | 405, 406, 407, 408, 409, 410, 411, 415, 416, 418, 421, 422, 423, 424, 425, 426, 427, 12 | 428, 430, 431, 432, 433, 434, 435, 436, 447, 454] 13 | 14 | # INDEX_MP_LIPS = [ 15 | # 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61, 16 | # 146, 91, 181, 84, 17, 314, 405, 321, 375, 17 | # 306, 408, 304, 303, 302, 11, 72, 73, 74, 184, 76, 18 | # 77, 90, 180, 85, 16, 315, 404, 320, 307, 19 | # 292, 407, 272, 271, 268, 12, 38, 41, 42, 183, 62, 20 | # 96, 89, 179, 86, 15, 316, 403, 319, 325, 21 | # 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78, 22 | # 95, 88, 178, 87, 14, 317, 402, 318, 324, 23 | # ] 24 | INDEX_MP_LIPS_LOWER = [ 25 | 146, 91, 181, 84, 17, 314, 405, 321, 375, 26 | 77, 90, 180, 85, 16, 315, 404, 320, 307, 27 | 96, 89, 179, 86, 15, 316, 403, 319, 325, 28 | 95, 88, 178, 87, 14, 317, 402, 318, 324, 29 | ] 30 | INDEX_MP_LIPS_UPPER = [ 31 | 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61, 32 | 306, 408, 304, 303, 302, 11, 72, 73, 74, 184, 76, 33 | 292, 407, 272, 271, 268, 12, 38, 41, 42, 183, 62, 34 | 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78, 35 | ] 36 | index_lips_upper_wrap = [] 37 | for i in 
INDEX_MP_LIPS_UPPER: 38 | for j in range(len(index_wrap)): 39 | if index_wrap[j] == i: 40 | index_lips_upper_wrap.append(j) 41 | index_lips_lower_wrap = [] 42 | for i in INDEX_MP_LIPS_LOWER: 43 | for j in range(len(index_wrap)): 44 | if index_wrap[j] == i: 45 | index_lips_lower_wrap.append(j) 46 | print(index_lips_upper_wrap[:11] + index_lips_upper_wrap[33:44][::-1]) 47 | print(index_lips_lower_wrap[:9] + index_lips_upper_wrap[27:36][::-1]) 48 | exit(-1) 49 | if __name__ == "__main__": 50 | # index_edge_wrap = [111,43,57,21,76,59,68,67,78,66,69,168,177,169,170,161,176,123,159,145,208] 51 | # index_edge_wrap = [110,60,79,108,61,58,73,74,62,75,77,175,164,174,173,160,163,205,178,162,207] 52 | index_edge_wrap = [110, 60, 79, 108, 61, 58, 73, 67, 78, 66, 69, 168, 177, 169, 173, 160, 163, 205, 178, 162, 207] 53 | index_edge_wrap_upper = [111, 110, 51, 52, 53, 54, 48, 63, 56, 47, 46, 1, 148, 149, 158, 165, 150, 156, 155, 154, 54 | 153, 207, 208] 55 | import numpy as np 56 | 57 | 58 | def readObjFile(filepath): 59 | v_ = [] 60 | face = [] 61 | with open(filepath) as f: 62 | # with open(r"face3D.obj") as f: 63 | content = f.readlines() 64 | for i in content: 65 | if i[:2] == "v ": 66 | v0, v1, v2 = i[2:-1].split(" ") 67 | v_.append(float(v0)) 68 | v_.append(float(v1)) 69 | v_.append(float(v2)) 70 | if i[:2] == "f ": 71 | tmp = i[2:-1].split(" ") 72 | for ii in tmp: 73 | a = ii.split("/")[0] 74 | a = int(a) - 1 75 | face.append(a) 76 | return v_, face 77 | 78 | 79 | verts_wrap, faces_wrap = readObjFile(r"wrap.obj") 80 | verts_wrap = np.array(verts_wrap).reshape(-1, 3) 81 | vert_mid = verts_wrap[index_edge_wrap[:4] + index_edge_wrap[-4:]].mean(axis=0) 82 | 83 | face_verts_num = len(verts_wrap) 84 | index_new_edge = [] 85 | new_vert_list = [] 86 | for i in range(len(index_edge_wrap)): 87 | index = index_edge_wrap[i] 88 | new_vert = verts_wrap[index] + (verts_wrap[index] - vert_mid) * 0.3 89 | new_vert[2] = verts_wrap[index, 2] 90 | new_vert_list.append(new_vert) 91 | index_new_edge.append(len(index_wrap) + i) 92 | for i in range(len(index_edge_wrap) - 1): 93 | faces_wrap.extend([index_edge_wrap[i], face_verts_num + i, index_edge_wrap[(i + 1) % len(index_edge_wrap)]]) 94 | faces_wrap.extend([index_edge_wrap[(i + 1) % len(index_edge_wrap)], face_verts_num + i, 95 | face_verts_num + (i + 1) % len(index_edge_wrap)]) 96 | 97 | verts_wrap = np.concatenate([verts_wrap, np.array(new_vert_list).reshape(-1, 3)], axis=0) 98 | 99 | v_teeth, face_teeth = readObjFile("modified_teeth_upper.obj") 100 | v_teeth2, face_teeth2 = readObjFile("modified_teeth_lower.obj") 101 | 102 | faces_wrap = faces_wrap + [i + len(verts_wrap) for i in face_teeth] + [i + len(verts_wrap) + len(v_teeth) // 3 for i 103 | in 104 | face_teeth2] 105 | 106 | verts_wrap = np.concatenate([verts_wrap, np.array(v_teeth).reshape(-1, 3)], axis=0) 107 | verts_wrap = np.concatenate([verts_wrap, np.array(v_teeth2).reshape(-1, 3)], axis=0) 108 | 109 | # 边缘-1 正常0 上嘴唇2 下嘴唇2.01 上牙3 下牙4 110 | verts_wrap2 = np.zeros([len(verts_wrap), 5]) 111 | verts_wrap2[:, :3] = verts_wrap 112 | verts_wrap2[index_lips_upper_wrap, 3] = 2.0 113 | verts_wrap2[index_lips_lower_wrap, 3] = 2.01 114 | verts_wrap2[-36:-18, 3] = 3 115 | verts_wrap2[-18:, 3] = 4 116 | verts_wrap2[index_edge_wrap_upper, 3] = -1 117 | verts_wrap2[index_new_edge, 3] = -1 118 | verts_wrap2[:, 4] = range(len(verts_wrap2)) 119 | 120 | with open("face_wrap_entity.obj", "w") as f: 121 | for i in verts_wrap2: 122 | f.write("v {:.3f} {:.3f} {:.3f} {:.02f} {:.0f}\n".format(i[0], i[1], i[2], i[3], i[4])) 123 | 
for i in range(len(faces_wrap) // 3): 124 | f.write( 125 | "f {0} {1} {2}\n".format(faces_wrap[3 * i] + 1, faces_wrap[3 * i + 1] + 1, faces_wrap[3 * i + 2] + 1)) 126 | 127 | # f 240 247 254 128 | # f 240 254 255 129 | 130 | # f 233 231 250 131 | # f 231 264 250 -------------------------------------------------------------------------------- /mini_live/obj/obj_mediapipe/modified_obj.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def ModifyObjFile(path): 4 | with_vn = False 5 | verts = [] 6 | vt = [] 7 | vn = [] 8 | face = [] 9 | map_face = {} 10 | with open(path) as f: 11 | content = f.readlines() 12 | for i in content: 13 | if i[:2] == "v ": 14 | verts.append(i) 15 | if i[:3] == "vt ": 16 | vt.append(i) 17 | if i[:3] == "vn ": 18 | with_vn = True 19 | vn.append(i) 20 | if i[:2] == "f ": 21 | tmp = i[2:-1].split(" ") 22 | face.extend(tmp) 23 | print(len(verts),len(vt),len(vn),len(face),len(set(face))) 24 | set_face = set(face) 25 | print(len(set_face)) 26 | for index,i in enumerate(set_face): 27 | map_face[i] = index 28 | 29 | 30 | 31 | with open("modified_" + path, "w") as f: 32 | for i in set_face: 33 | index = int(i.split("/")[0]) - 1 34 | f.write(verts[index]) 35 | for i in set_face: 36 | index = int(i.split("/")[1]) - 1 37 | f.write(vt[index]) 38 | if with_vn: 39 | for i in set_face: 40 | index = int(i.split("/")[2]) - 1 41 | f.write(vn[index]) 42 | for i in range(len(face) // 3): 43 | f.write("f {0}/{0}/{0} {1}/{1}/{1} {2}/{2}/{2}\n".format(map_face[face[3 * i]] + 1, 44 | map_face[face[3 * i + 1]] + 1, 45 | map_face[face[3 * i + 2]] + 1)) 46 | else: 47 | for i in range(len(face) // 3): 48 | f.write("f {0}/{0} {1}/{1} {2}/{2}\n".format(map_face[face[3 * i]] + 1, 49 | map_face[face[3 * i + 1]] + 1, 50 | map_face[face[3 * i + 2]] + 1)) 51 | 52 | 53 | ModifyObjFile("teeth_lower.obj") 54 | ModifyObjFile("teeth_upper.obj") 55 | 56 | # ModifyObjFile("teeth_new.obj") 57 | # 58 | # import cv2 59 | # import numpy as np 60 | # img0 = cv2.imread(r"texture/Std_Lower_Teeth_diffuse.jpg") 61 | # img1 = cv2.imread(r"texture/Std_Tongue_diffuse.jpg") 62 | # img2 = cv2.imread(r"texture/Std_Upper_Teeth_diffuse.jpg") 63 | # img = np.concatenate([img0, img1, img2], axis = 1) 64 | # cv2.imwrite("teeth.png", cv2.resize(img, (256*3,256))) 65 | # 66 | # img0 = cv2.imread(r"texture/Std_Lower_Teeth_normal.png") 67 | # img1 = cv2.imread(r"texture/Std_Tongue_normal.jpg") 68 | # img2 = cv2.imread(r"texture/Std_Upper_Teeth_normal.png") 69 | # img = np.concatenate([img0, img1, img2], axis = 1) 70 | # cv2.imwrite("teeth_normal.png", cv2.resize(img, (256*3,256))) 71 | # print(img[:100,:100]) 72 | # img0 = cv2.imread(r"teeth.png").astype(float) 73 | # img1 = cv2.imread(r"teeth_ao.png", cv2.IMREAD_GRAYSCALE).astype(float)/255. 
74 | # 75 | # img0[:,:,0] = img0[:,:,0]*img1 76 | # img0[:,:,1] = img0[:,:,1]*img1 77 | # img0[:,:,2] = img0[:,:,2]*img1 78 | # 79 | # cv2.imwrite("teeth_.png", img0.astype(np.uint8)) -------------------------------------------------------------------------------- /mini_live/obj/obj_mediapipe/modified_teeth_lower.obj: -------------------------------------------------------------------------------- 1 | v 518.879272 791.716567 -194.869528 2 | v 437.176788 786.993483 -148.410498 3 | v 564.105408 811.103163 -148.188971 4 | v 564.111938 786.930068 -148.587774 5 | v 500.202179 813.422133 -199.281134 6 | v 500.247925 792.438185 -199.317450 7 | v 448.235626 812.271743 -169.162191 8 | v 448.281067 788.599501 -169.281454 9 | v 536.171387 812.935256 -186.142706 10 | v 464.157166 813.003676 -186.019003 11 | v 519.147461 813.607375 -194.448736 12 | v 481.376465 813.749221 -194.769247 13 | v 553.153503 788.594680 -169.177298 14 | v 481.245972 791.741103 -194.800222 15 | v 536.373901 790.437209 -186.078527 16 | v 464.153625 790.481154 -185.781210 17 | v 437.316589 810.741469 -148.175405 18 | v 553.047424 812.201857 -169.284903 19 | vt 1.000000 0.500000 20 | vt 0.000000 0.500000 21 | vt 0.000000 1.000000 22 | vt 0.000000 0.500000 23 | vt 0.000000 1.000000 24 | vt 0.000000 0.500000 25 | vt 1.000000 1.000000 26 | vt 1.000000 0.500000 27 | vt 0.000000 1.000000 28 | vt 0.000000 1.000000 29 | vt 1.000000 1.000000 30 | vt 1.000000 1.000000 31 | vt 1.000000 0.500000 32 | vt 1.000000 0.500000 33 | vt 0.000000 0.500000 34 | vt 0.000000 0.500000 35 | vt 0.000000 1.000000 36 | vt 1.000000 1.000000 37 | f 4/4 13/13 3/3 38 | f 3/3 13/13 18/18 39 | f 13/13 15/15 18/18 40 | f 18/18 15/15 9/9 41 | f 15/15 1/1 9/9 42 | f 9/9 1/1 11/11 43 | f 1/1 6/6 11/11 44 | f 11/11 6/6 5/5 45 | f 6/6 14/14 5/5 46 | f 5/5 14/14 12/12 47 | f 14/14 16/16 12/12 48 | f 12/12 16/16 10/10 49 | f 16/16 8/8 10/10 50 | f 10/10 8/8 7/7 51 | f 8/8 2/2 7/7 52 | f 7/7 2/2 17/17 53 | -------------------------------------------------------------------------------- /mini_live/obj/obj_mediapipe/modified_teeth_upper.obj: -------------------------------------------------------------------------------- 1 | v 522.885315 745.574219 -210.641006 2 | v 434.873383 736.817932 -155.635681 3 | v 565.570435 756.208923 -155.356766 4 | v 565.995178 736.791138 -155.788223 5 | v 500.232025 766.599976 -215.250641 6 | v 500.255157 746.713623 -215.260330 7 | v 443.795746 759.215149 -176.353271 8 | v 443.661713 739.395630 -176.518082 9 | v 541.328491 762.087402 -196.421417 10 | v 458.383057 762.203857 -195.788666 11 | v 522.946716 765.776550 -210.566010 12 | v 478.100006 765.962036 -210.878372 13 | v 557.205872 739.355713 -176.471313 14 | v 477.799988 745.691895 -210.943954 15 | v 540.990906 742.639099 -196.660324 16 | v 458.399994 742.563721 -195.760330 17 | v 434.232269 756.359375 -155.438797 18 | v 557.125732 759.150391 -176.335144 19 | vt 1.000000 0.000000 20 | vt 0.000000 0.000000 21 | vt 0.000000 0.500000 22 | vt 0.000000 0.000000 23 | vt 0.000000 0.500000 24 | vt 0.000000 0.000000 25 | vt 1.000000 0.500000 26 | vt 1.000000 0.000000 27 | vt 0.000000 0.500000 28 | vt 0.000000 0.500000 29 | vt 1.000000 0.500000 30 | vt 1.000000 0.500000 31 | vt 1.000000 0.000000 32 | vt 1.000000 0.000000 33 | vt 0.000000 0.000000 34 | vt 0.000000 0.000000 35 | vt 0.000000 0.500000 36 | vt 1.000000 0.500000 37 | f 4/4 13/13 3/3 38 | f 3/3 13/13 18/18 39 | f 13/13 15/15 18/18 40 | f 18/18 15/15 9/9 41 | f 15/15 1/1 9/9 42 | f 9/9 1/1 11/11 43 | f 1/1 6/6 11/11 44 | f 11/11 6/6 5/5 45 | f 
6/6 14/14 5/5 46 | f 5/5 14/14 12/12 47 | f 14/14 16/16 12/12 48 | f 12/12 16/16 10/10 49 | f 16/16 8/8 10/10 50 | f 10/10 8/8 7/7 51 | f 8/8 2/2 7/7 52 | f 7/7 2/2 17/17 53 | -------------------------------------------------------------------------------- /mini_live/obj/obj_mediapipe/teeth_lower.obj: -------------------------------------------------------------------------------- 1 | # This file uses centimeters as units for non-parametric coordinates. 2 | 3 | mtllib teeth_lower.mtl 4 | v 437.176788 786.993483 -148.410498 5 | v 448.235626 812.271743 -169.162191 6 | v 518.879272 791.716567 -194.869528 7 | v 536.171387 812.935256 -186.142706 8 | v 464.153625 790.481154 -185.781210 9 | v 500.247925 792.438185 -199.317450 10 | v 564.111938 786.930068 -148.587774 11 | v 564.105408 811.103163 -148.188971 12 | v 500.202179 813.422133 -199.281134 13 | v 481.376465 813.749221 -194.769247 14 | v 464.157166 813.003676 -186.019003 15 | v 448.281067 788.599501 -169.281454 16 | v 519.147461 813.607375 -194.448736 17 | v 553.153503 788.594680 -169.177298 18 | v 437.316589 810.741469 -148.175405 19 | v 553.047424 812.201857 -169.284903 20 | v 481.245972 791.741103 -194.800222 21 | v 536.373901 790.437209 -186.078527 22 | vt 0.000000 0.500000 23 | vt 1.000000 0.500000 24 | vt 0.000000 1.000000 25 | vt 1.000000 1.000000 26 | vt 0.000000 0.500000 27 | vt 0.000000 1.000000 28 | vt 1.000000 0.500000 29 | vt 1.000000 1.000000 30 | vt 0.000000 0.500000 31 | vt 0.000000 1.000000 32 | vt 1.000000 0.500000 33 | vt 1.000000 1.000000 34 | vt 0.000000 0.500000 35 | vt 0.000000 1.000000 36 | vt 1.000000 0.500000 37 | vt 1.000000 1.000000 38 | vt 0.000000 0.500000 39 | vt 0.000000 1.000000 40 | usemtl initialShadingGroup 41 | f 7/1 14/2 8/3 42 | f 8/3 14/2 16/4 43 | f 14/2 18/5 16/4 44 | f 16/4 18/5 4/6 45 | f 18/5 3/7 4/6 46 | f 4/6 3/7 13/8 47 | f 3/7 6/9 13/8 48 | f 13/8 6/9 9/10 49 | f 6/9 17/11 9/10 50 | f 9/10 17/11 10/12 51 | f 17/11 5/13 10/12 52 | f 10/12 5/13 11/14 53 | f 5/13 12/15 11/14 54 | f 11/14 12/15 2/16 55 | f 12/15 1/17 2/16 56 | f 2/16 1/17 15/18 57 | -------------------------------------------------------------------------------- /mini_live/obj/obj_mediapipe/teeth_upper.obj: -------------------------------------------------------------------------------- 1 | # This file uses centimeters as units for non-parametric coordinates. 
2 | 3 | mtllib teeth_upper.mtl 4 | v 434.873383 736.817932 -155.635681 5 | v 443.795746 759.215149 -176.353271 6 | v 522.885315 745.574219 -210.641006 7 | v 541.328491 762.087402 -196.421417 8 | v 458.399994 742.563721 -195.760330 9 | v 500.255157 746.713623 -215.260330 10 | v 565.995178 736.791138 -155.788223 11 | v 565.570435 756.208923 -155.356766 12 | v 500.232025 766.599976 -215.250641 13 | v 478.100006 765.962036 -210.878372 14 | v 458.383057 762.203857 -195.788666 15 | v 443.661713 739.395630 -176.518082 16 | v 522.946716 765.776550 -210.566010 17 | v 557.205872 739.355713 -176.471313 18 | v 434.232269 756.359375 -155.438797 19 | v 557.125732 759.150391 -176.335144 20 | v 477.799988 745.691895 -210.943954 21 | v 540.990906 742.639099 -196.660324 22 | vt 0.000000 0.000000 23 | vt 1.000000 0.000000 24 | vt 0.000000 0.500000 25 | vt 1.000000 0.500000 26 | vt 0.000000 0.000000 27 | vt 0.000000 0.500000 28 | vt 1.000000 0.000000 29 | vt 1.000000 0.500000 30 | vt 0.000000 0.000000 31 | vt 0.000000 0.500000 32 | vt 1.000000 0.000000 33 | vt 1.000000 0.500000 34 | vt 0.000000 0.000000 35 | vt 0.000000 0.500000 36 | vt 1.000000 0.000000 37 | vt 1.000000 0.500000 38 | vt 0.000000 0.000000 39 | vt 0.000000 0.500000 40 | usemtl initialShadingGroup 41 | f 7/1 14/2 8/3 42 | f 8/3 14/2 16/4 43 | f 14/2 18/5 16/4 44 | f 16/4 18/5 4/6 45 | f 18/5 3/7 4/6 46 | f 4/6 3/7 13/8 47 | f 3/7 6/9 13/8 48 | f 13/8 6/9 9/10 49 | f 6/9 17/11 9/10 50 | f 9/10 17/11 10/12 51 | f 17/11 5/13 10/12 52 | f 10/12 5/13 11/14 53 | f 5/13 12/15 11/14 54 | f 11/14 12/15 2/16 55 | f 12/15 1/17 2/16 56 | f 2/16 1/17 15/18 57 | -------------------------------------------------------------------------------- /mini_live/obj/obj_mediapipe/wrap_index.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | def readObjFile(filepath): 4 | v_ = [] 5 | face = [] 6 | with open(filepath) as f: 7 | # with open(r"face3D.obj") as f: 8 | content = f.readlines() 9 | for i in content: 10 | if i[:2] == "v ": 11 | v0,v1,v2 = i[2:-1].split(" ") 12 | v_.append(float(v0)) 13 | v_.append(float(v1)) 14 | v_.append(float(v2)) 15 | if i[:2] == "f ": 16 | tmp = i[2:-1].split(" ") 17 | for ii in tmp: 18 | a = ii.split("/")[0] 19 | a = int(a) - 1 20 | face.append(a) 21 | return v_, face 22 | 23 | verts_face,_ = readObjFile(r"face3D.obj") 24 | verts_wrap,_ = readObjFile(r"wrap.obj") 25 | 26 | verts_flame = np.array(verts_face).reshape(-1, 3) 27 | verts_mouth = np.array(verts_wrap).reshape(-1, 3) 28 | index_mouthInFlame = [] 29 | for index in range(len(verts_mouth)): 30 | vert = verts_mouth[index] 31 | dist_list = [] 32 | for i in verts_flame: 33 | dist_list.append(np.linalg.norm(i - vert)) 34 | align_index = np.argmin(dist_list) 35 | index_mouthInFlame.append(align_index) 36 | print(index_mouthInFlame) 37 | # exit() 38 | 39 | # from obj.utils import INDEX_FLAME_LIPS 40 | # index_mouthInFlame = np.array(index_mouthInFlame, dtype = int)[INDEX_FLAME_LIPS] 41 | # np.savetxt("index_mouthInFlame.txt", index_mouthInFlame) -------------------------------------------------------------------------------- /mini_live/obj/obj_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | current_dir = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | INDEX_FACE_EDGE = [ 7 | 234, 127, 162, 21, 8 | 54, 103, 67, 109, 10, 338, 297, 332, 284, 251, 9 | 389, 356, 10 | 454, 323, 361, 288, 397, 365, 11 | 379, 378, 400, 
377, 152, 148, 176, 149, 150, 12 | 136, 172, 58, 132, 13 | 93, 14 | ] 15 | def readObjFile(filepath): 16 | with_vn = False 17 | with_vt = False 18 | v_ = [] 19 | vt = [] 20 | vn = [] 21 | face = [] 22 | with open(filepath) as f: 23 | # with open(r"face3D.obj") as f: 24 | content = f.readlines() 25 | for i in content: 26 | if i[:2] == "v ": 27 | v0,v1,v2 = i[2:-1].split(" ") 28 | v_.append(float(v0)) 29 | v_.append(float(v1)) 30 | v_.append(float(v2)) 31 | if i[:3] == "vt ": 32 | with_vt = True 33 | vt0,vt1 = i[3:-1].split(" ") 34 | vt.append(float(vt0)) 35 | vt.append(float(vt1)) 36 | if i[:3] == "vn ": 37 | with_vn = True 38 | vn0,vn1,vn2 = i[3:-1].split(" ") 39 | vn.append(float(vn0)) 40 | vn.append(float(vn1)) 41 | vn.append(float(vn2)) 42 | if i[:2] == "f ": 43 | tmp = i[2:-1].split(" ") 44 | for ii in tmp: 45 | a = ii.split("/")[0] 46 | a = int(a) - 1 47 | face.append(a) 48 | if not with_vn: 49 | vn = [0 for i in v_] 50 | if not with_vt: 51 | vt = [0 for i in range(len(v_)//3*2)] 52 | return v_, vt, vn, face 53 | 54 | def generateRenderInfo_mediapipe(): 55 | v_face, vt_face, vn_face, face_face = readObjFile(os.path.join(current_dir,"../obj/obj_mediapipe/face3D.obj")) 56 | v_teeth, vt_teeth, vn_teeth, face_teeth = readObjFile(os.path.join(current_dir,"../obj/obj_mediapipe/modified_teeth_upper.obj")) 57 | v_teeth2, vt_teeth2, vn_teeth2, face_teeth2 = readObjFile(os.path.join(current_dir,"../obj/obj_mediapipe/modified_teeth_lower.obj")) 58 | 59 | v_, vt, vn, face = ( 60 | v_face + v_teeth + v_teeth2, vt_face + vt_teeth + vt_teeth2, vn_face + vn_teeth + vn_teeth2, 61 | face_face + [i + len(v_face)//3 for i in face_teeth] + [i + len(v_face)//3 + len(v_teeth)//3 for i in face_teeth2]) 62 | v_ = np.array(v_).reshape(-1, 3) 63 | 64 | # v_[:, 1] = -v_[:, 1] 65 | 66 | # 0-2: verts 3-4: vt 5:category 6: index 7-10 bone_weight 11-12 another vt 67 | vertices = np.zeros([len(v_), 13]) 68 | # vertices = np.zeros([len(pts_array_), 6]) 69 | 70 | vertices[:, :3] = v_ 71 | vertices[:, 3:5] = np.array(vt).reshape(-1, 2) 72 | vertices[:, 11:13] = np.array(vt).reshape(-1, 2) 73 | vertices[:, 12] = 1 - vertices[:, 12] 74 | # vertices[:, 5] = 0 75 | # 脸部为0,眼睛1,上牙2,下牙3 76 | vertices[468:478, 5] = 1. 77 | vertices[478:478 + 18, 5] = 2. 78 | vertices[478 + 18:478 + 36, 5] = 3. 
79 | vertices[:, 6] = list(range(len(v_))) 80 | return vertices, face 81 | 82 | def generateRenderInfo(floor = 5): 83 | v_face, vt_face, vn_face, face_face = readObjFile(os.path.join(current_dir,"../obj/obj_mediapipe/face3D.obj")) 84 | v_teeth, vt_teeth, vn_teeth, face_teeth = readObjFile(os.path.join(current_dir,"../obj/obj_mediapipe/modified_teeth_upper.obj")) 85 | v_teeth2, vt_teeth2, vn_teeth2, face_teeth2 = readObjFile(os.path.join(current_dir,"../obj/obj_mediapipe/modified_teeth_lower.obj")) 86 | print(len(v_face), len(vt_face), len(vn_face), len(face_face)) 87 | print(len(v_teeth)//3, len(vt_teeth), len(vn_teeth), len(face_teeth)) 88 | print(len(v_face)//3 + len(v_teeth)//3 + len(v_teeth2)//3) 89 | 90 | v_, vt, vn, face = ( 91 | v_face + v_teeth + v_teeth2, vt_face + vt_teeth + vt_teeth2, vn_face + vn_teeth + vn_teeth2, 92 | face_face + [i + len(v_face)//3 for i in face_teeth] + [i + len(v_face)//3 + len(v_teeth)//3 for i in face_teeth2]) 93 | v_ = np.array(v_).reshape(-1, 3) 94 | 95 | # v_[:, 1] = -v_[:, 1] 96 | 97 | vertices = np.zeros([len(v_), 13]) 98 | # vertices = np.zeros([len(pts_array_), 6]) 99 | 100 | vertices[:, :3] = v_ 101 | vertices[:, 3:5] = np.array(vt).reshape(-1, 2) 102 | 103 | # 脸部为0,眼睛1,上牙2,下牙3, 补充的为9 104 | vertices[468:478, 5] = 1. 105 | 106 | vertices[len(v_face)//3:len(v_face)//3 + len(v_teeth)//3, 5] = 2. 107 | vertices[len(v_face)//3 + len(v_teeth)//3:len(v_face)//3 + len(v_teeth)//3 + len(v_teeth2)//3, 5] = 3. 108 | vertices[:, 6] = list(range(len(v_))) 109 | return vertices, face 110 | 111 | 112 | def generateWrapModel(): 113 | v_ = [] 114 | face = [] 115 | filepath = os.path.join(current_dir,"../obj/obj_mediapipe/face_wrap_entity.obj") 116 | with open(filepath) as f: 117 | # with open(r"face3D.obj") as f: 118 | content = f.readlines() 119 | for i in content: 120 | if i[:2] == "v ": 121 | v0, v1, v2, v3, v4 = i[2:-1].split(" ") 122 | v_.append(float(v0)) 123 | v_.append(float(v1)) 124 | v_.append(float(v2)) 125 | v_.append(float(v3)) 126 | v_.append(float(v4)) 127 | if i[:2] == "f ": 128 | tmp = i[2:-1].split(" ") 129 | for ii in tmp: 130 | a = ii.split("/")[0] 131 | a = int(a) - 1 132 | face.append(a) 133 | return np.array(v_).reshape(-1, 5), face 134 | 135 | def NewFaceVerts(render_verts, source_crop_pts, face_pts_mean): 136 | from talkingface.run_utils import calc_face_mat 137 | mat_list, _, face_pts_mean_personal_primer = calc_face_mat(source_crop_pts[np.newaxis, :478, :], 138 | face_pts_mean) 139 | 140 | 141 | 142 | # print(face_pts_mean_personal_primer.shape) 143 | mat_list__ = mat_list[0].T 144 | # mat_list__ = np.linalg.inv(mat_list[0]) 145 | render_verts[:478,:3] = face_pts_mean_personal_primer 146 | 147 | # 牙齿部分校正 148 | from talkingface.utils import INDEX_LIPS,main_keypoints_index,INDEX_LIPS_UPPER,INDEX_LIPS_LOWER 149 | # # 上嘴唇中点 150 | # mid_upper_mouth = np.mean(face_pts_mean_personal_primer[main_keypoints_index][INDEX_LIPS],axis = 0) 151 | # mid_upper_teeth = np.mean(render_verts[478:478 + 36,:3], axis=0) 152 | # tmp = mid_upper_teeth - mid_upper_mouth 153 | # print(tmp) 154 | # render_verts[478:478 + 36, :2] = render_verts[478:478 + 36, :2] - tmp[:2] 155 | 156 | # 上嘴唇中点 157 | mid_upper_mouth = np.mean(face_pts_mean_personal_primer[main_keypoints_index][INDEX_LIPS_UPPER],axis = 0) 158 | mid_upper_teeth = np.mean(render_verts[478:478 + 18,:3], axis=0) 159 | tmp = mid_upper_teeth - mid_upper_mouth 160 | print(tmp) 161 | render_verts[478:478 + 18, :2] = render_verts[478:478 + 18, :2] - tmp[:2] 162 | 163 | # 下嘴唇中点 164 | mid_lower_mouth = 
np.mean(face_pts_mean_personal_primer[main_keypoints_index][INDEX_LIPS_LOWER],axis = 0) 165 | mid_lower_teeth = np.mean(render_verts[478:478 + 18,:3], axis=0) 166 | tmp = mid_lower_teeth - mid_lower_mouth 167 | print(tmp) 168 | render_verts[478 + 18:478 + 36, :2] = render_verts[478 + 18:478 + 36, :2] - tmp[:2] 169 | 170 | 171 | return render_verts, mat_list__ 172 | -------------------------------------------------------------------------------- /mini_live/obj/utils.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import numpy as np 4 | import cv2 5 | import os 6 | current_dir = os.path.dirname(os.path.abspath(__file__)) 7 | 8 | def translation_matrix(point): 9 | """生成平移矩阵""" 10 | return np.array([ 11 | [1, 0, 0, point[0]], 12 | [0, 1, 0, point[1]], 13 | [0, 0, 1, point[2]], 14 | [0, 0, 0, 1] 15 | ]) 16 | def rotate_around_point(point, theta, phi, psi): 17 | """围绕点P旋转""" 18 | # 将点P平移到原点 19 | T1 = translation_matrix(-point) 20 | 21 | # 定义欧拉角 22 | theta = np.radians(theta) # 俯仰角 23 | phi = np.radians(phi) # 偏航角 24 | psi = np.radians(psi) # 翻滚角 25 | 26 | # 创建旋转矩阵 27 | tmp = [theta, phi, psi] 28 | matX = np.array([[1.0, 0, 0, 0], 29 | [0.0, np.cos(tmp[0]), -np.sin(tmp[0]), 0], 30 | [0.0, np.sin(tmp[0]), np.cos(tmp[0]), 0], 31 | [0, 0, 0, 1] 32 | ]) 33 | matY = np.array([[np.cos(tmp[1]), 0, np.sin(tmp[1]), 0], 34 | [0, 1, 0, 0], 35 | [-np.sin(tmp[1]),0, np.cos(tmp[1]), 0], 36 | [0, 0, 0, 1] 37 | ]) 38 | matZ = np.array([[np.cos(tmp[2]), -np.sin(tmp[2]), 0, 0], 39 | [np.sin(tmp[2]), np.cos(tmp[2]), 0, 0], 40 | [0, 0, 1, 0], 41 | [0, 0, 0, 1] 42 | ]) 43 | 44 | R = matZ @ matY @ matX 45 | 46 | # 将点P移回其原始位置 47 | T2 = translation_matrix(point) 48 | 49 | # 总的变换矩阵 50 | total_transform = T2 @ R @ T1 51 | 52 | return total_transform 53 | 54 | def rodrigues_rotation_formula(axis, theta): 55 | """Calculate the rotation matrix using Rodrigues' rotation formula.""" 56 | axis = np.asarray(axis) / np.linalg.norm(axis) # Normalize the axis 57 | cos_theta = np.cos(theta) 58 | sin_theta = np.sin(theta) 59 | K = np.array([[0, -axis[2], axis[1]], 60 | [axis[2], 0, -axis[0]], 61 | [-axis[1], axis[0], 0]]) 62 | R = np.eye(3) + sin_theta * K + (1 - cos_theta) * np.dot(K, K) 63 | return R 64 | def RotateAngle2Matrix(center, axis, theta): 65 | """Rotate around a center point.""" 66 | # Step 1: Translate the center to the origin 67 | translation_to_origin = np.eye(4) 68 | translation_to_origin[:3, 3] = -center 69 | 70 | # Step 2: Apply the rotation 71 | R = rodrigues_rotation_formula(axis, theta) 72 | R_ = np.eye(4) 73 | R_[:3,:3] = R 74 | R = R_ 75 | 76 | # Step 3: Translate back to the original position 77 | translation_back = np.eye(4) 78 | translation_back[:3, 3] = center 79 | 80 | # Combine the transformations 81 | rotation_matrix = translation_back @ R @ translation_to_origin 82 | 83 | return rotation_matrix 84 | 85 | INDEX_FLAME_LIPS = [ 86 | 1,26,23,21,8,155,83,96,98,101, 87 | 73,112,123,124,143,146,71,52,51,40, 88 | 2,25,24,22,7,156,82,97,99,100, 89 | 74,113,122,125,138,148,66,53,50,41, 90 | 30,31,32,38,39,157,111,110,106,105, 91 | 104,120,121,126,137,147,65,54,49,48, 92 | 4,28,33,20,19,153,94,95,107,103, 93 | 76,118,119,127,136,149,64,55,47,46, 94 | 95 | 3,27,35,17,18,154,93,92,109,102, 96 | 75,114,115,128,133,151,61,56,43,42, 97 | 6,29, 13, 12, 11, 158, 86, 87, 88, 79, 98 | 80,117, 116, 135, 134, 150, 62, 63, 44, 45, 99 | 5,36,14,9,10,159,85,84,89,78, 100 | 77,141,130,131,139,145,67,59,58,69, 101 | 0,37,34,15,16,152,91,90,108,81,72, 102 | 
142,129,132,140,144,68,60,57,70, 103 | ] 104 | INDEX_MP_LIPS = [ 105 | 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61, 106 | 146, 91, 181, 84, 17, 314, 405, 321, 375, 107 | 306, 408, 304, 303, 302, 11, 72, 73, 74, 184, 76, 108 | 77, 90, 180, 85, 16, 315, 404, 320, 307, 109 | 292, 407, 272, 271, 268, 12, 38, 41, 42, 183, 62, 110 | 96, 89, 179, 86, 15, 316, 403, 319, 325, 111 | 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78, 112 | 95, 88, 178, 87, 14, 317, 402, 318, 324, 113 | ] 114 | 115 | def crop_mouth(mouth_pts, mat_list__): 116 | """ 117 | x_ratio: 裁剪出一个正方形,边长根据keypoints的宽度 * x_ratio决定 118 | """ 119 | num_ = len(mouth_pts) 120 | keypoints = np.ones([4, num_]) 121 | keypoints[:3, :] = mouth_pts.T 122 | keypoints = mat_list__.dot(keypoints).T 123 | keypoints = keypoints[:, :3] 124 | 125 | x_min, y_min, x_max, y_max = np.min(keypoints[:, 0]), np.min(keypoints[:, 1]), np.max(keypoints[:, 0]), np.max(keypoints[:, 1]) 126 | border_width_half = max(x_max - x_min, y_max - y_min) * 0.66 127 | y_min = y_min + border_width_half * 0.3 128 | center_x = (x_min + x_max) /2. 129 | center_y = (y_min + y_max) /2. 130 | x_min, y_min, x_max, y_max = int(center_x - border_width_half), int(center_y - border_width_half*0.75), int( 131 | center_x + border_width_half), int(center_y + border_width_half*0.75) 132 | print([x_min, y_min, x_max, y_max]) 133 | 134 | # pts = np.array([ 135 | # [x_min, y_min], 136 | # [x_max, y_min], 137 | # [x_max, y_max], 138 | # [x_min, y_max] 139 | # ]) 140 | return [x_min, y_min, x_max, y_max] 141 | 142 | def drawMouth(keypoints, source_texture, out_size = (700, 1400)): 143 | INDEX_LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191] 144 | INDEX_LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, ] 145 | INDEX_LIPS_LOWWER = INDEX_LIPS_INNER[:11] + INDEX_LIPS_OUTER[:11][::-1] 146 | INDEX_LIPS_UPPER = INDEX_LIPS_INNER[10:] + [INDEX_LIPS_INNER[0], INDEX_LIPS_OUTER[0]] + INDEX_LIPS_OUTER[10:][::-1] 147 | INDEX_LIPS = INDEX_LIPS_INNER + INDEX_LIPS_OUTER 148 | # keypoints = keypoints[INDEX_LIPS] 149 | keypoints[:, 0] = keypoints[:, 0] * out_size[0] 150 | keypoints[:, 1] = keypoints[:, 1] * out_size[1] 151 | # pts = keypoints[20:40] 152 | # pts = pts.reshape((-1, 1, 2)).astype(np.int32) 153 | # cv2.fillPoly(source_texture, [pts], color=(255, 0, 0,)) 154 | # pts = keypoints[:20] 155 | # pts = pts.reshape((-1, 1, 2)).astype(np.int32) 156 | # cv2.fillPoly(source_texture, [pts], color=(0, 0, 0,)) 157 | 158 | pts = keypoints[INDEX_LIPS_OUTER] 159 | pts = pts.reshape((-1, 1, 2)).astype(np.int32) 160 | cv2.fillPoly(source_texture, [pts], color=(0, 0, 0)) 161 | pts = keypoints[INDEX_LIPS_UPPER] 162 | pts = pts.reshape((-1, 1, 2)).astype(np.int32) 163 | cv2.fillPoly(source_texture, [pts], color=(255, 0, 0)) 164 | pts = keypoints[INDEX_LIPS_LOWWER] 165 | pts = pts.reshape((-1, 1, 2)).astype(np.int32) 166 | cv2.fillPoly(source_texture, [pts], color=(127, 0, 0)) 167 | 168 | prompt_texture = np.zeros_like(source_texture) 169 | pts = keypoints[INDEX_LIPS_UPPER] 170 | pts = pts.reshape((-1, 1, 2)).astype(np.int32) 171 | cv2.fillPoly(prompt_texture, [pts], color=(255, 0, 0)) 172 | pts = keypoints[INDEX_LIPS_LOWWER] 173 | pts = pts.reshape((-1, 1, 2)).astype(np.int32) 174 | cv2.fillPoly(prompt_texture, [pts], color=(127, 0, 0)) 175 | return source_texture, prompt_texture 176 | 177 | 178 | 179 | # 180 | # def draw_face_feature_maps(keypoints, mode = ["mouth", "nose", "eye", "oval"], size=(256, 256), 
im_edges = None): 181 | # w, h = size 182 | # # edge map for face region from keypoints 183 | # if im_edges is None: 184 | # im_edges = np.zeros((h, w, 3), np.uint8) # edge map for all edges 185 | # if "mouth" in mode: 186 | # pts = keypoints[INDEX_LIPS_OUTER] 187 | # pts = pts.reshape((-1, 1, 2)).astype(np.int32) 188 | # cv2.fillPoly(im_edges, [pts], color=(0, 0, 0)) 189 | # pts = keypoints[INDEX_LIPS_UPPER] 190 | # pts = pts.reshape((-1, 1, 2)).astype(np.int32) 191 | # cv2.fillPoly(im_edges, [pts], color=(255, 0, 0)) 192 | # pts = keypoints[INDEX_LIPS_LOWWER] 193 | # pts = pts.reshape((-1, 1, 2)).astype(np.int32) 194 | # cv2.fillPoly(im_edges, [pts], color=(127, 0, 0)) 195 | # return im_edges -------------------------------------------------------------------------------- /mini_live/obj/wrap_utils.py: -------------------------------------------------------------------------------- 1 | index_wrap = [0, 2, 11, 12, 13, 14, 15, 16, 17, 18, 32, 36, 37, 38, 39, 40, 41, 42, 43, 50, 57, 58, 61, 2 | 62, 72, 73, 74, 76, 77, 78, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 95, 3 | 96, 97, 98, 100, 101, 106, 116, 117, 118, 119, 123, 129, 132, 135, 136, 137, 138, 140, 4 | 142, 146, 147, 148, 149, 150, 152, 164, 165, 167, 169, 170, 171, 172, 175, 176, 177, 5 | 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 191, 192, 194, 199, 200, 201, 202, 6 | 203, 204, 205, 206, 207, 208, 210, 211, 212, 213, 214, 215, 216, 227, 234, 262, 266, 7 | 267, 268, 269, 270, 271, 272, 273, 280, 287, 288, 291, 292, 302, 303, 304, 306, 307, 8 | 308, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 9 | 326, 327, 329, 330, 335, 345, 346, 347, 348, 352, 358, 361, 364, 365, 366, 367, 369, 10 | 371, 375, 376, 377, 378, 379, 391, 393, 394, 395, 396, 397, 400, 401, 402, 403, 404, 11 | 405, 406, 407, 408, 409, 410, 411, 415, 416, 418, 421, 422, 423, 424, 425, 426, 427, 12 | 428, 430, 431, 432, 433, 434, 435, 436, 447, 454] 13 | 14 | # index_edge_wrap = [111,43,57,21,76,59,68,67,78,66,69,168,177,169,170,161,176,123,159,145,208] 15 | index_edge_wrap = [110,60,79,108,61,58,73,67,78,66,69,168,177,169,173,160,163,205,178,162,207] 16 | index_edge_wrap_upper = [111, 110, 51, 52, 53, 54, 48, 63, 56, 47, 46, 1, 148, 149, 158, 165, 150, 156, 155, 154, 153, 207, 208] 17 | 18 | print(len(index_wrap), len(set(index_wrap))) 19 | 20 | # index_wrap = index_wrap + [291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61, 21 | # 146, 91, 181, 84, 17, 314, 405, 321, 375,] 22 | import numpy as np 23 | # 求平均人脸 24 | def newWrapModel(wrapModel, face_pts_mean_personal_primer): 25 | 26 | face_wrap_entity = wrapModel.copy() 27 | # 正规点 28 | face_wrap_entity[:len(index_wrap),:3] = face_pts_mean_personal_primer[index_wrap, :3] 29 | # 边缘点 30 | vert_mid = face_wrap_entity[:,:3][index_edge_wrap[:4] + index_edge_wrap[-4:]].mean(axis=0) 31 | for index_, jj in enumerate(index_edge_wrap): 32 | face_wrap_entity[len(index_wrap) + index_,:3] = face_wrap_entity[jj, :3] + (face_wrap_entity[jj, :3] - vert_mid) * 0.32 33 | face_wrap_entity[len(index_wrap) + index_, 2] = face_wrap_entity[jj, 2] 34 | # 牙齿点 35 | from talkingface.utils import INDEX_LIPS, main_keypoints_index, INDEX_LIPS_UPPER, INDEX_LIPS_LOWER 36 | # 上嘴唇中点 37 | mid_upper_mouth = np.mean(face_pts_mean_personal_primer[main_keypoints_index][INDEX_LIPS_UPPER], axis=0) 38 | mid_upper_teeth = np.mean(face_wrap_entity[-36:-18, :3], axis=0) 39 | tmp = mid_upper_teeth - mid_upper_mouth 40 | face_wrap_entity[-36:-18, :2] = face_wrap_entity[-36:-18, :2] - tmp[:2] 41 | # # 下嘴唇中点 42 | # 
mid_lower_mouth = np.mean(face_pts_mean_personal_primer[main_keypoints_index][INDEX_LIPS_LOWER], axis=0) 43 | # mid_lower_teeth = np.mean(face_wrap_entity[-18:, :3], axis=0) 44 | # tmp = mid_lower_teeth - mid_lower_mouth 45 | # # print(tmp) 46 | # face_wrap_entity[-18:, :2] = face_wrap_entity[-18:, :2] - tmp[:2] 47 | 48 | return face_wrap_entity -------------------------------------------------------------------------------- /mini_live/shader/prompt3.fsh: -------------------------------------------------------------------------------- 1 | # version 330 2 | precision mediump float; 3 | in mediump vec2 v_texture; 4 | in mediump vec2 v_bias; 5 | out highp vec4 out_color; 6 | 7 | void main() 8 | { 9 | if (v_texture.x == 2.0f) 10 | { 11 | out_color = vec4(1.0, 0.0, 0.0, 1.0); 12 | } 13 | else if (v_texture.x > 2.0f && v_texture.x < 2.1f) 14 | { 15 | out_color = vec4(0.5f, 0.0, 0.0, 1.0); 16 | } 17 | else if (v_texture.x == 3.0f) 18 | { 19 | out_color = vec4(0.0, 1.0, 0.0, 1.0); 20 | } 21 | else if (v_texture.x == 4.0f) 22 | { 23 | out_color = vec4(0.0, 0.0, 1.0, 1.0); 24 | } 25 | else if (v_texture.x > 3.0f && v_texture.x < 4.0f) 26 | { 27 | out_color = vec4(0.0, 0.0, 0.0, 1.0); 28 | } 29 | else 30 | { 31 | vec2 wrap = (v_bias.xy + 1.0)/2.0; 32 | out_color = vec4(wrap.xy, 0.5, 1.0); 33 | } 34 | } -------------------------------------------------------------------------------- /mini_live/shader/prompt3.vsh: -------------------------------------------------------------------------------- 1 | # version 330 2 | 3 | layout(location = 0) in vec3 a_position; 4 | layout(location = 1) in vec2 a_texture; 5 | uniform float bsVec[12]; 6 | uniform mat4 gProjection; 7 | uniform mat4 gWorld0; 8 | uniform sampler2D texture_bs; 9 | 10 | uniform vec2 vertBuffer[209]; 11 | 12 | out vec2 v_texture; 13 | out vec2 v_bias; 14 | vec4 calculateMorphPosition(vec3 position, vec2 textureCoord) { 15 | vec4 tmp_Position2 = vec4(position, 1.0); 16 | if (textureCoord.x < 3.0) { 17 | vec3 morphSum = vec3(0.0); 18 | for (int i = 0; i < 6; i++) { 19 | ivec2 coord = ivec2(int(textureCoord.y), i); 20 | vec3 morph = texelFetch(texture_bs, coord, 0).xyz * 2.0 - 1.0; 21 | morphSum += bsVec[i] * morph; 22 | } 23 | ivec2 coord6 = ivec2(int(textureCoord.y), 6); 24 | morphSum += bsVec[6] * texelFetch(texture_bs, coord6, 0).xyz; 25 | tmp_Position2.xyz += morphSum; 26 | } 27 | else if (textureCoord.x == 4.0) 28 | { 29 | // lower teeth 30 | vec3 morphSum = vec3(0.0, (bsVec[0] + bsVec[1])/ 2.7 + 6, 0.0); 31 | tmp_Position2.xyz += morphSum; 32 | } 33 | return tmp_Position2; 34 | } 35 | 36 | void main() { 37 | mat4 gWorld = gWorld0; 38 | 39 | vec4 tmp_Position2 = calculateMorphPosition(a_position, a_texture); 40 | vec4 tmp_Position = gWorld * tmp_Position2; 41 | // vec4 tmp_Position = gWorld * vec4(a_position, 1.0); 42 | // vec3 tmp_Position = a_position; 43 | // vec4 pos_ = gProjection * vec4(tmp_Position.x, tmp_Position.y, tmp_Position.z, 1.0); 44 | // # upper lips 1 lower lips2 teeth3 edge4 45 | 46 | v_bias = vec2(0.0, 0.0); 47 | if (a_texture.x == -1.0f) 48 | { 49 | v_bias = vec2(0.0, 0.0); 50 | } 51 | else if (a_texture.y < 209.0f) 52 | { 53 | vec4 vert_new = gProjection * vec4(tmp_Position.x, tmp_Position.y, tmp_Position.z, 1.0); 54 | v_bias = vert_new.xy - vertBuffer[int(a_texture.y)].xy; 55 | } 56 | 57 | if (a_texture.x >= 3.0f) 58 | { 59 | gl_Position = gProjection * vec4(tmp_Position.x, tmp_Position.y, 500.0, 1.0); 60 | } 61 | else 62 | { 63 | gl_Position = gProjection * vec4(tmp_Position.x, tmp_Position.y, tmp_Position.z, 1.0); 64 | 
} 65 | 66 | v_texture = a_texture; 67 | } -------------------------------------------------------------------------------- /mini_live/train_input_validation.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["KMP_DUPLICATE_LIB_OK"] = "true" # env var name must be upper-case to take effect (as in talkingface/data/face_mask.py) 3 | import pickle 4 | import cv2 5 | import numpy as np 6 | import random 7 | import pandas as pd 8 | import glob 9 | import copy 10 | import torch 11 | from torch.utils.data import DataLoader 12 | from talkingface.data.DHLive_mini_dataset import Few_Shot_Dataset,data_preparation 13 | from talkingface.utils import * 14 | from talkingface.model_utils import device 15 | # video_list = glob.glob(r"E:\data\video\video\*.mp4") 16 | # video_list = [os.path.basename(i).split(".")[0] for i in video_list] 17 | 18 | df = pd.read_csv(r"F:\C\AI\CV\DH008_few_shot\DH0119_mouth64_48/imageVar2.csv") 19 | video_list = df[df["imageVar"] > 265000]["name"].tolist() 20 | video_list = [os.path.dirname(os.path.dirname(i)) for i in video_list] 21 | print(len(video_list)) 22 | point_size = 1 23 | point_color = (0, 0, 255) # BGR 24 | thickness = 4 # 0 、4、8 25 | video_list = video_list[105:125] 26 | 27 | dict_info = data_preparation(video_list) 28 | test_set = Few_Shot_Dataset(dict_info, is_train=True, n_ref = 3) 29 | testing_data_loader = DataLoader(dataset=test_set, num_workers=0, batch_size=1, shuffle=False) 30 | 31 | def Tensor2img(tensor_, channel_index): 32 | frame = tensor_[channel_index:channel_index + 3, :, :].detach().squeeze(0).cpu().float().numpy() 33 | frame = np.transpose(frame, (1, 2, 0)) * 255.0 34 | frame = frame.clip(0, 255) 35 | return frame.astype(np.uint8) 36 | size_ = 256 37 | for iteration, batch in enumerate(testing_data_loader): 38 | # source_tensor, source_prompt_tensor, ref_tensor, ref_prompt_tensor, target_tensor = [iii.to(device) for iii in batch] 39 | source_tensor, ref_tensor, target_tensor = [iii.to(device) for iii in batch[:3]] 40 | print(source_tensor.size(), ref_tensor.size(), target_tensor.size(), batch[3][0]) 41 | 42 | frame0 = Tensor2img(source_tensor[0], 0) 43 | frame1 = Tensor2img(ref_tensor[0], 0) 44 | frame2 = Tensor2img(ref_tensor[0], 1) 45 | frame3 = Tensor2img(ref_tensor[0], 4) 46 | frame4 = Tensor2img(ref_tensor[0], 5) 47 | frame5 = Tensor2img(target_tensor[0], 0) 48 | 49 | # cv2.imwrite("in0.png", frame0) 50 | # cv2.imwrite("in1.png", frame1) 51 | # cv2.imwrite("in2.png", frame2) 52 | # cv2.imwrite("in3.png", frame3) 53 | # cv2.imwrite("in4.png", frame4) 54 | # exit() 55 | 56 | 57 | frame = np.concatenate([frame0, frame1, frame2, frame3, frame4, frame5], axis=1) 58 | 59 | cv2.imshow("ss", frame[:, :, ::-1]) 60 | # if iteration > 840: 61 | # cv2.waitKey(-1) 62 | cv2.waitKey(-1) 63 | # break 64 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | kaldi_native_fbank 2 | mediapipe 3 | tqdm 4 | scikit-learn 5 | pyglm 6 | glfw 7 | PyOpenGL 8 | gradio 9 | sherpa-onnx -------------------------------------------------------------------------------- /talkingface/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/talkingface/__init__.py -------------------------------------------------------------------------------- /talkingface/audio_model.py:
-------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import kaldi_native_fbank as knf 4 | from scipy.io import wavfile 5 | import torch 6 | import pickle 7 | from model_utils import device 8 | import pickle 9 | import os 10 | def pca_process(x): 11 | a = x.reshape(15, 30, 3) 12 | # a = pca.mean_.reshape(15,30,3) 13 | tmp = a[:, :15] + a[:, 15:][:, ::-1] 14 | a[:, :15] = tmp / 2 15 | a[:, 15:] = a[:, :15][:, ::-1] 16 | return a.flatten() 17 | class AudioModel: 18 | def __init__(self): 19 | self.__net = None 20 | self.__fbank = None 21 | self.__fbank_processed_index = 0 22 | self.frame_index = 0 23 | 24 | current_dir = os.path.dirname(os.path.abspath(__file__)) 25 | Path_output_pkl = os.path.join(current_dir, "../data/pca.pkl") 26 | with open(Path_output_pkl, "rb") as f: 27 | pca = pickle.load(f) 28 | self.pca_mean_ = pca_process(pca.mean_) 29 | self.pca_components_ = np.zeros_like(pca.components_) 30 | self.pca_components_[0] = pca_process(pca.components_[0]) 31 | self.pca_components_[1] = pca_process(pca.components_[1]) 32 | self.pca_components_[2] = pca_process(pca.components_[2]) 33 | self.pca_components_[3] = pca_process(pca.components_[3]) 34 | self.pca_components_[4] = pca_process(pca.components_[4]) 35 | self.pca_components_[5] = pca_process(pca.components_[5]) 36 | 37 | self.reset() 38 | 39 | def loadModel(self, ckpt_path): 40 | # if method == "lstm": 41 | # ckpt_path = 'checkpoint/lstm/lstm_model_epoch_560.pth' 42 | # Audio2FeatureModel = torch.load(model_path).to(device) 43 | # Audio2FeatureModel.eval() 44 | from talkingface.models.audio2bs_lstm import Audio2Feature 45 | self.__net = Audio2Feature() # 调用模型Model 46 | self.__net.load_state_dict(torch.load(ckpt_path)) 47 | self.__net = self.__net.to(device) 48 | self.__net.eval() 49 | 50 | def reset(self): 51 | opts = knf.FbankOptions() 52 | opts.frame_opts.dither = 0 53 | opts.frame_opts.frame_length_ms = 50 54 | opts.frame_opts.frame_shift_ms = 20 55 | opts.mel_opts.num_bins = 80 56 | opts.frame_opts.snip_edges = False 57 | opts.mel_opts.debug_mel = False 58 | self.__fbank = knf.OnlineFbank(opts) 59 | 60 | self.h0 = torch.zeros(2, 1, 192).to(device) 61 | self.c0 = torch.zeros(2, 1, 192).to(device) 62 | 63 | self.__fbank_processed_index = 0 64 | 65 | audio_samples = np.zeros([320]) 66 | self.__fbank.accept_waveform(16000, audio_samples.tolist()) 67 | 68 | def interface_frame(self, audio_samples): 69 | # pcm为uint16位数据。 只处理一帧的数据, 16000/25 = 640 70 | self.__fbank.accept_waveform(16000, audio_samples.tolist()) 71 | orig_mel = np.zeros([2, 80]) 72 | 73 | orig_mel[0] = self.__fbank.get_frame(self.__fbank_processed_index) 74 | orig_mel[1] = self.__fbank.get_frame(self.__fbank_processed_index + 1) 75 | 76 | input = torch.from_numpy(orig_mel).unsqueeze(0).float().to(device) 77 | bs_array, self.h0, self.c0 = self.__net(input, self.h0, self.c0) 78 | bs_array = bs_array[0].detach().cpu().float().numpy() 79 | bs_real = bs_array[0] 80 | # print(self.__fbank_processed_index, self.__fbank.num_frames_ready, bs_real) 81 | 82 | frame = np.dot(bs_real[:6], self.pca_components_[:6]) + self.pca_mean_ 83 | # print(frame_index, frame.shape) 84 | frame = frame.reshape(15, 30, 3).clip(0, 255).astype(np.uint8) 85 | self.__fbank_processed_index += 2 86 | return frame 87 | 88 | def interface_wav(self, wavpath): 89 | rate, wav = wavfile.read(wavpath, mmap=False) 90 | augmented_samples = wav 91 | augmented_samples2 = augmented_samples.astype(np.float32, order='C') / 32768.0 92 | # 
print(augmented_samples2.shape, augmented_samples2.shape[0] / 16000) 93 | 94 | opts = knf.FbankOptions() 95 | opts.frame_opts.dither = 0 96 | opts.frame_opts.frame_length_ms = 50 97 | opts.frame_opts.frame_shift_ms = 20 98 | opts.mel_opts.num_bins = 80 99 | opts.frame_opts.snip_edges = False 100 | opts.mel_opts.debug_mel = False 101 | fbank = knf.OnlineFbank(opts) 102 | fbank.accept_waveform(16000, augmented_samples2.tolist()) 103 | seq_len = fbank.num_frames_ready // 2 104 | A2Lsamples = np.zeros([2 * seq_len, 80]) 105 | for i in range(2 * seq_len): 106 | f2 = fbank.get_frame(i) 107 | A2Lsamples[i] = f2 108 | 109 | orig_mel = A2Lsamples 110 | # print(orig_mel.shape) 111 | input = torch.from_numpy(orig_mel).unsqueeze(0).float().to(device) 112 | # print(input.shape) 113 | h0 = torch.zeros(2, 1, 192).to(device) 114 | c0 = torch.zeros(2, 1, 192).to(device) 115 | bs_array, hn, cn = self.__net(input, h0, c0) 116 | bs_array = bs_array[0].detach().cpu().float().numpy() 117 | bs_array = bs_array[4:] 118 | 119 | frame_num = len(bs_array) 120 | output = np.zeros([frame_num, 15, 30, 3], dtype = np.uint8) 121 | for frame_index in range(frame_num): 122 | bs_real = bs_array[frame_index] 123 | # bs_real[1:4] = - bs_real[1:4] 124 | frame = np.dot(bs_real[:6], self.pca_components_[:6]) + self.pca_mean_ 125 | # print(frame_index, frame.shape) 126 | frame = frame.reshape(15, 30, 3).clip(0, 255).astype(np.uint8) 127 | output[frame_index] = frame 128 | 129 | return output -------------------------------------------------------------------------------- /talkingface/config/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | class DataProcessingOptions(): 4 | def __init__(self): 5 | self.parser = argparse.ArgumentParser() 6 | 7 | def parse_args(self): 8 | self.parser.add_argument('--extract_video_frame', action='store_true', help='extract video frame') 9 | self.parser.add_argument('--extract_audio', action='store_true', help='extract audio files from videos') 10 | self.parser.add_argument('--extract_deep_speech', action='store_true', help='extract deep speech features') 11 | self.parser.add_argument('--crop_face', action='store_true', help='crop face') 12 | self.parser.add_argument('--generate_training_json', action='store_true', help='generate training json file') 13 | 14 | self.parser.add_argument('--source_video_dir', type=str, default="./asserts/training_data/split_video_25fps", 15 | help='path of source video in 25 fps') 16 | self.parser.add_argument('--openface_landmark_dir', type=str, default="./asserts/training_data/split_video_25fps_landmark_openface", 17 | help='path of openface landmark dir') 18 | self.parser.add_argument('--video_frame_dir', type=str, default="./asserts/training_data/split_video_25fps_frame", 19 | help='path of video frames') 20 | self.parser.add_argument('--audio_dir', type=str, default="./asserts/training_data/split_video_25fps_audio", 21 | help='path of audios') 22 | self.parser.add_argument('--deep_speech_dir', type=str, default="./asserts/training_data/split_video_25fps_deepspeech", 23 | help='path of deep speech') 24 | self.parser.add_argument('--crop_face_dir', type=str, default="./asserts/training_data/split_video_25fps_crop_face", 25 | help='path of crop face dir') 26 | self.parser.add_argument('--json_path', type=str, default="./asserts/training_data/training_json.json", 27 | help='path of training json') 28 | self.parser.add_argument('--clip_length', type=int, default=9, help='clip length') 29 | 
self.parser.add_argument('--deep_speech_model', type=str, default="./asserts/output_graph.pb", 30 | help='path of pretrained deepspeech model') 31 | return self.parser.parse_args() 32 | 33 | class DINetTrainingOptions(): 34 | def __init__(self): 35 | self.parser = argparse.ArgumentParser() 36 | 37 | def parse_args(self): 38 | self.parser.add_argument('--seed', type=int, default=456, help='random seed to use.') 39 | self.parser.add_argument('--source_channel', type=int, default=3, help='input source image channels') 40 | self.parser.add_argument('--ref_channel', type=int, default=15, help='input reference image channels') 41 | self.parser.add_argument('--audio_channel', type=int, default=29, help='input audio channels') 42 | self.parser.add_argument('--augment_num', type=int, default=32, help='augment training data') 43 | self.parser.add_argument('--mouth_region_size', type=int, default=64, help='augment training data') 44 | self.parser.add_argument('--train_data', type=str, default=r"./asserts/training_data/training_json.json", 45 | help='path of training json') 46 | self.parser.add_argument('--batch_size', type=int, default=24, help='training batch size') 47 | self.parser.add_argument('--lamb_perception', type=int, default=10, help='weight of perception loss') 48 | self.parser.add_argument('--lamb_syncnet_perception', type=int, default=0.1, help='weight of perception loss') 49 | self.parser.add_argument('--lamb_pixel', type=int, default=10, help='weight of perception loss') 50 | self.parser.add_argument('--lr_g', type=float, default=0.00008, help='initial learning rate for adam') 51 | self.parser.add_argument('--lr_d', type=float, default=0.00008, help='initial learning rate for adam') 52 | self.parser.add_argument('--start_epoch', default=1, type=int, help='start epoch in training stage') 53 | self.parser.add_argument('--non_decay', default=4, type=int, help='num of epoches with fixed learning rate') 54 | self.parser.add_argument('--decay', default=36, type=int, help='num of linearly decay epochs') 55 | self.parser.add_argument('--checkpoint', type=int, default=2, help='num of checkpoints in training stage') 56 | self.parser.add_argument('--result_path', type=str, default=r"./asserts/training_model_weight/frame_training_64", 57 | help='result path to save model') 58 | self.parser.add_argument('--coarse2fine', action='store_true', help='If true, load pretrained model path.') 59 | self.parser.add_argument('--coarse_model_path', 60 | default='', 61 | type=str, 62 | help='Save data (.pth) of previous training') 63 | self.parser.add_argument('--pretrained_syncnet_path', 64 | default='', 65 | type=str, 66 | help='Save data (.pth) of pretrained syncnet') 67 | self.parser.add_argument('--pretrained_frame_DINet_path', 68 | default='', 69 | type=str, 70 | help='Save data (.pth) of frame trained DINet') 71 | # ========================= Discriminator ========================== 72 | self.parser.add_argument('--D_num_blocks', type=int, default=4, help='num of down blocks in discriminator') 73 | self.parser.add_argument('--D_block_expansion', type=int, default=64, help='block expansion in discriminator') 74 | self.parser.add_argument('--D_max_features', type=int, default=256, help='max channels in discriminator') 75 | return self.parser.parse_args() 76 | 77 | 78 | class DINetInferenceOptions(): 79 | def __init__(self): 80 | self.parser = argparse.ArgumentParser() 81 | 82 | def parse_args(self): 83 | self.parser.add_argument('--source_channel', type=int, default=3, help='channels of source image') 84 | 
self.parser.add_argument('--ref_channel', type=int, default=15, help='channels of reference image') 85 | self.parser.add_argument('--audio_channel', type=int, default=29, help='channels of audio feature') 86 | self.parser.add_argument('--mouth_region_size', type=int, default=256, help='help to resize window') 87 | self.parser.add_argument('--source_video_path', 88 | default='./asserts/examples/test4.mp4', 89 | type=str, 90 | help='path of source video') 91 | self.parser.add_argument('--source_openface_landmark_path', 92 | default='./asserts/examples/test4.csv', 93 | type=str, 94 | help='path of detected openface landmark') 95 | self.parser.add_argument('--driving_audio_path', 96 | default='./asserts/examples/driving_audio_1.wav', 97 | type=str, 98 | help='path of driving audio') 99 | self.parser.add_argument('--pretrained_clip_DINet_path', 100 | default='./asserts/clip_training_DINet_256mouth.pth', 101 | type=str, 102 | help='pretrained model of DINet(clip trained)') 103 | self.parser.add_argument('--deepspeech_model_path', 104 | default='./asserts/output_graph.pb', 105 | type=str, 106 | help='path of deepspeech model') 107 | self.parser.add_argument('--res_video_dir', 108 | default='./asserts/inference_result', 109 | type=str, 110 | help='path of generated videos') 111 | return self.parser.parse_args() -------------------------------------------------------------------------------- /talkingface/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/talkingface/data/__init__.py -------------------------------------------------------------------------------- /talkingface/data/dataset_wav.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import numpy as np 4 | from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift, PolarityInversion 5 | # from audio import melspectrogram,mel_bar 6 | import kaldi_native_fbank as knf 7 | import random 8 | 9 | 10 | 11 | class AudioVisualDataset(data.Dataset): 12 | """ audio-visual dataset. currently, return 2D info and 3D tracking info. 
13 | 14 | ''' 15 | 多个片段的APC语音特征和嘴部顶点的PCA信息 16 | :param audio_features: list 17 | :param mouth_features: list 18 | ''' 19 | 20 | """ 21 | 22 | def __init__(self, audio_features, mouth_features, is_train = True, seq_len = 9): 23 | super(AudioVisualDataset, self).__init__() 24 | 25 | self.fps = 25 26 | # 每0.2s一个序列 27 | # self.seq_len = int(self.fps /5) 28 | self.seq_len = seq_len 29 | self.frame_jump_stride = 2 30 | self.audio_features = audio_features 31 | self.bs_features = mouth_features 32 | self.is_train = is_train 33 | 34 | self.augment = Compose([ 35 | AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5), 36 | # TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5), 37 | PolarityInversion(p=0.5), 38 | PitchShift(min_semitones=-4, max_semitones=4, p=0.5), 39 | # Shift(min_fraction=-0.5, max_fraction=0.5, p=0.5), 40 | ]) 41 | 42 | 43 | # 每个序列的裁剪片段个数 44 | self.clip_num = [] 45 | for i in range(len(audio_features)): 46 | audio_frame_num = int(len(self.audio_features[i])/(16000/25)) - 2 47 | self.clip_num.append(min(len(self.bs_features[i]), audio_frame_num) - self.seq_len + 1) 48 | 49 | def __getitem__(self, index): 50 | if self.is_train: 51 | video_index = random.randint(0, len(self.bs_features) - 1) 52 | # print(video_index, self.clip_num[video_index]) 53 | clips_index = random.sample(range(self.clip_num[video_index]), 1) # 从当前视频选1个片段 54 | current_frame = clips_index[0] 55 | else: 56 | video_index = 0 57 | # video_index = 0 58 | # for i in range(len(self.clip_num)): 59 | # if index < np.sum(self.clip_num[:i+1]): 60 | # video_index = i 61 | # break 62 | # current_frame = index - np.sum(self.clip_num[:video_index], dtype=int) 63 | # print(video_index, current_frame, current_frame + self.seq_len, self.clip_num, self.bs_features[video_index].shape) 64 | 65 | # start point is current frame 66 | A2Lsamples = self.audio_features[video_index][current_frame*640: (current_frame + self.seq_len + 2)*640] 67 | # A2Lsamples = copy.deepcopy(A2Lsamples_) 68 | # print("A2Lsamples: ", A2Lsamples.shape, A2Lsamples.dtype, A2Lsamples.__class__) 69 | augmented_samples = self.augment(np.array(A2Lsamples, dtype=np.float32), sample_rate=16000) 70 | # print(augmented_samples.shape, augmented_samples.dtype) 71 | # int16转换为float格式 72 | augmented_samples2 = augmented_samples.astype(np.float32, order='C') / 32768.0 73 | # orig_mel = mel_bar(augmented_samples2) 74 | # orig_mel = melspectrogram(augmented_samples2).T 75 | opts = knf.FbankOptions() 76 | opts.frame_opts.dither = 0 77 | opts.frame_opts.frame_length_ms = 50 78 | opts.frame_opts.frame_shift_ms = 20 79 | opts.mel_opts.num_bins = 80 80 | opts.frame_opts.snip_edges = False 81 | opts.mel_opts.debug_mel = False 82 | fbank = knf.OnlineFbank(opts) 83 | fbank.accept_waveform(16000, augmented_samples2.tolist()) 84 | A2Lsamples = np.zeros([2*self.seq_len, 80]) 85 | for i in range(2*self.seq_len): 86 | f2 = fbank.get_frame(i) 87 | A2Lsamples[i] = f2 88 | fbank.input_finished() 89 | 90 | target_bs = self.bs_features[video_index][current_frame: current_frame + self.seq_len, :].reshape( 91 | self.seq_len, -1) 92 | 93 | # target_bs = self.bs_features[video_index][current_frame + self.seq_len//2, :] 94 | 95 | A2Lsamples = torch.from_numpy(A2Lsamples).float() 96 | target_bs = torch.from_numpy(target_bs).float() 97 | # print("*****", A2Lsamples.size(), target_bs.size(), len(self.clip_num)) 98 | 99 | return [A2Lsamples, target_bs] 100 | 101 | def __len__(self): 102 | return len(self.clip_num) 103 | # return np.sum(self.clip_num, dtype = int) 104 | # return 10000 
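A quick orientation note on the dataset above: AudioVisualDataset pairs raw 16 kHz audio (640 samples per 25 fps video frame) with per-frame mouth/blendshape targets, and each item is a (2*seq_len, 80) fbank window (20 ms shift, so two fbank frames per video frame) plus a (seq_len, N) target tensor. The sketch below is a minimal, assumed usage, not part of the repo: the synthetic shapes and the 6-dim target are illustrative only (real features come from the preprocessing scripts), it sticks to is_train=True because the evaluation branch relies on index bookkeeping that is currently commented out, and audiomentations must be installed separately since it is not in requirements.txt.

```python
# Minimal sketch with synthetic data (shapes are assumptions, not the real pipeline).
import numpy as np
from torch.utils.data import DataLoader
from talkingface.data.dataset_wav import AudioVisualDataset

fps, sr, n_frames = 25, 16000, 200
audio_features = [np.zeros(n_frames * sr // fps, dtype=np.int16)]   # one clip of raw 16-bit PCM
mouth_features = [np.zeros((n_frames, 6), dtype=np.float32)]        # per-frame targets (6-dim assumed)

dataset = AudioVisualDataset(audio_features, mouth_features, is_train=True, seq_len=9)
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

fbank_win, target_bs = next(iter(loader))
print(fbank_win.shape)   # torch.Size([1, 18, 80]): two 80-dim fbank frames per video frame
print(target_bs.shape)   # torch.Size([1, 9, 6])
```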
-------------------------------------------------------------------------------- /talkingface/data/face_mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["KMP_DUPLICATE_LIB_OK"] = "true" 3 | import pickle 4 | import cv2 5 | import numpy as np 6 | import os 7 | import glob 8 | from talkingface.util.smooth import smooth_array 9 | from talkingface.run_utils import calc_face_mat 10 | import tqdm 11 | from talkingface.utils import * 12 | 13 | path_ = r"../../../preparation_mix" 14 | video_list = [os.path.join(path_, i) for i in os.listdir(path_)] 15 | path_ = r"../../../preparation_hdtf" 16 | video_list += [os.path.join(path_, i) for i in os.listdir(path_)] 17 | path_ = r"../../../preparation_vfhq" 18 | video_list += [os.path.join(path_, i) for i in os.listdir(path_)] 19 | path_ = r"../../../preparation_bilibili" 20 | video_list += [os.path.join(path_, i) for i in os.listdir(path_)] 21 | print(video_list) 22 | video_list = video_list[:] 23 | img_all = [] 24 | keypoints_all = [] 25 | point_size = 1 26 | point_color = (0, 0, 255) # BGR 27 | thickness = 4 # 0 、4、8 28 | for path_ in tqdm.tqdm(video_list): 29 | img_filelist = glob.glob("{}/image/*.png".format(path_)) 30 | img_filelist.sort() 31 | if len(img_filelist) == 0: 32 | continue 33 | img_all.append(img_filelist) 34 | 35 | Path_output_pkl = "{}/keypoint_rotate.pkl".format(path_) 36 | 37 | with open(Path_output_pkl, "rb") as f: 38 | images_info = pickle.load(f)[:, main_keypoints_index, :] 39 | pts_driven = images_info.reshape(len(images_info), -1) 40 | pts_driven = smooth_array(pts_driven).reshape(len(pts_driven), -1, 3) 41 | 42 | face_pts_mean = np.loadtxt(r"data\face_pts_mean_mainKps.txt") 43 | mat_list,pts_normalized_list,face_pts_mean_personal = calc_face_mat(pts_driven, face_pts_mean) 44 | pts_normalized_list = np.array(pts_normalized_list) 45 | # print(face_pts_mean_personal[INDEX_FACE_OVAL[:10], 1]) 46 | # print(np.max(pts_normalized_list[:,INDEX_FACE_OVAL[:10], 1], axis = 1)) 47 | face_pts_mean_personal[INDEX_FACE_OVAL[:10], 1] = np.max(pts_normalized_list[:,INDEX_FACE_OVAL[:10], 1], axis = 0) + np.arange(5,25,2) 48 | face_pts_mean_personal[INDEX_FACE_OVAL[:10], 0] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[:10], 0], axis=0) - (9 - np.arange(0,10)) 49 | face_pts_mean_personal[INDEX_FACE_OVAL[-10:], 1] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[-10:], 1], axis=0) - np.arange(5,25,2) + 28 50 | face_pts_mean_personal[INDEX_FACE_OVAL[-10:], 0] = np.min(pts_normalized_list[:, INDEX_FACE_OVAL[-10:], 0], axis=0) + np.arange(0,10) 51 | 52 | face_pts_mean_personal[INDEX_FACE_OVAL[10], 1] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[10], 1], axis=0) + 25 53 | 54 | # for keypoints_normalized in pts_normalized_list: 55 | # img = np.zeros([1000,1000,3], dtype=np.uint8) 56 | # for coor in face_pts_mean_personal: 57 | # # coor = (coor +1 )/2. 58 | # cv2.circle(img, (int(coor[0]), int(coor[1])), point_size, (255, 0, 0), thickness) 59 | # for coor in keypoints_normalized: 60 | # # coor = (coor +1 )/2. 
61 | # cv2.circle(img, (int(coor[0]), int(coor[1])), point_size, point_color, thickness) 62 | # cv2.imshow("a", img) 63 | # cv2.waitKey(30) 64 | 65 | with open("{}/face_mat_mask20240722.pkl".format(path_), "wb") as f: 66 | pickle.dump([mat_list, face_pts_mean_personal], f) 67 | -------------------------------------------------------------------------------- /talkingface/mediapipe_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import mediapipe as mp 4 | mp_face_mesh = mp.solutions.face_mesh 5 | mp_face_detection = mp.solutions.face_detection 6 | 7 | def detect_face_mesh(frames): 8 | pts_3d = np.zeros([len(frames), 478, 3]) 9 | with mp_face_mesh.FaceMesh( 10 | static_image_mode=True, 11 | max_num_faces=1, 12 | refine_landmarks=True, 13 | min_detection_confidence=0.5) as face_mesh: 14 | for frame_index, frame in enumerate(frames): 15 | results = face_mesh.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 16 | if results.multi_face_landmarks: 17 | image_height, image_width = frame.shape[:2] 18 | for face_landmarks in results.multi_face_landmarks: 19 | for index_, i in enumerate(face_landmarks.landmark): 20 | x_px = i.x * image_width 21 | y_px = i.y * image_height 22 | z_px = i.z * image_width 23 | pts_3d[frame_index, index_] = np.array([x_px, y_px, z_px]) 24 | else: 25 | break 26 | return pts_3d 27 | def detect_face(frames): 28 | rect_2d = np.zeros([len(frames), 4]) 29 | # 剔除掉多个人脸、大角度侧脸(鼻子不在两个眼之间)、部分人脸框在画面外、人脸像素低于80*80的 30 | with mp_face_detection.FaceDetection( 31 | model_selection=1, min_detection_confidence=0.5) as face_detection: 32 | for frame_index, frame in enumerate(frames): 33 | results = face_detection.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 34 | if not results.detections or len(results.detections) > 1: 35 | break 36 | rect = results.detections[0].location_data.relative_bounding_box 37 | rect_2d[frame_index] = np.array([rect.xmin, rect.xmin + rect.width, rect.ymin, rect.ymin + rect.height]) 38 | return rect_2d -------------------------------------------------------------------------------- /talkingface/model_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import kaldi_native_fbank as knf 4 | from scipy.io import wavfile 5 | import torch 6 | device = "cuda" if torch.cuda.is_available() else "cpu" 7 | # device = "cpu" 8 | pca = None 9 | def LoadAudioModel(ckpt_path): 10 | # if method == "lstm": 11 | # ckpt_path = 'checkpoint/lstm/lstm_model_epoch_560.pth' 12 | # Audio2FeatureModel = torch.load(model_path).to(device) 13 | # Audio2FeatureModel.eval() 14 | from talkingface.models.audio2bs_lstm import Audio2Feature 15 | Audio2FeatureModel = Audio2Feature() # 调用模型Model 16 | checkpoint = torch.load(ckpt_path, map_location=device) 17 | Audio2FeatureModel.load_state_dict(checkpoint) 18 | Audio2FeatureModel = Audio2FeatureModel.to(device) 19 | Audio2FeatureModel.eval() 20 | return Audio2FeatureModel 21 | 22 | def LoadRenderModel(ckpt_path, model_name = "one_ref"): 23 | if model_name == "one_ref": 24 | from talkingface.models.DINet import LeeNet as DINet 25 | n_ref = 1 26 | source_channel = 3 27 | ref_channel = n_ref * 6 28 | else: 29 | from talkingface.models.DINet import DINet_five_Ref as DINet 30 | n_ref = 5 31 | source_channel = 6 32 | ref_channel = n_ref * 6 33 | net_g = DINet(source_channel, ref_channel).to(device) 34 | checkpoint = torch.load(ckpt_path) 35 | net_g_static = checkpoint['state_dict']['net_g'] 36 | 
net_g.load_state_dict(net_g_static) 37 | net_g.eval() 38 | return net_g 39 | 40 | 41 | def Audio2mouth(wavpath, Audio2FeatureModel, method = "lstm"): 42 | rate, wav = wavfile.read(wavpath, mmap=False) 43 | augmented_samples = wav 44 | augmented_samples2 = augmented_samples.astype(np.float32, order='C') / 32768.0 45 | print(augmented_samples2.shape, augmented_samples2.shape[0] / 16000) 46 | 47 | opts = knf.FbankOptions() 48 | opts.frame_opts.dither = 0 49 | opts.frame_opts.frame_length_ms = 50 50 | opts.frame_opts.frame_shift_ms = 20 51 | opts.mel_opts.num_bins = 80 52 | opts.frame_opts.snip_edges = False 53 | opts.mel_opts.debug_mel = False 54 | fbank = knf.OnlineFbank(opts) 55 | # sss = augmented_samples2.tolist() 56 | # for ii in range(0, len(sss), 10000): 57 | # fbank.accept_waveform(16000, sss[ii:ii+10000]) 58 | fbank.accept_waveform(16000, augmented_samples2.tolist()) 59 | seq_len = fbank.num_frames_ready // 2 60 | A2Lsamples = np.zeros([2 * seq_len, 80]) 61 | for i in range(2 * seq_len): 62 | f2 = fbank.get_frame(i) 63 | A2Lsamples[i] = f2 64 | 65 | orig_mel = A2Lsamples 66 | # print(orig_mel.shape) 67 | input = torch.from_numpy(orig_mel).unsqueeze(0).float().to(device) 68 | # print(input.shape) 69 | h0 = torch.zeros(2, 1, 192).to(device) 70 | c0 = torch.zeros(2, 1, 192).to(device) 71 | bs_array, hn, cn = Audio2FeatureModel(input, h0, c0) 72 | # print(bs_array.shape) 73 | bs_array = bs_array[0].detach().cpu().float().numpy() 74 | # print(bs_array.shape) 75 | bs_array = bs_array[4:] 76 | bs_array[:, :2] = bs_array[:, :2] / 8 77 | bs_array[:, 2] = - bs_array[:, 2] / 8 78 | 79 | return bs_array 80 | from scipy.signal import resample 81 | def Audio2bs(wavpath, Audio2FeatureModel): 82 | rate, wav = wavfile.read(wavpath, mmap=False) 83 | wav = resample(wav, len(wav) //2) 84 | augmented_samples = wav 85 | augmented_samples2 = augmented_samples.astype(np.float32, order='C') / 32768.0 86 | # print(augmented_samples2.shape, augmented_samples2.shape[0] / 16000) 87 | 88 | opts = knf.FbankOptions() 89 | opts.frame_opts.dither = 0 90 | opts.frame_opts.samp_freq = 8000 91 | opts.frame_opts.frame_length_ms = 50 92 | opts.frame_opts.frame_shift_ms = 20 93 | opts.mel_opts.num_bins = 80 94 | opts.frame_opts.snip_edges = False 95 | opts.mel_opts.debug_mel = False 96 | fbank = knf.OnlineFbank(opts) 97 | # sss = augmented_samples2.tolist() 98 | # for ii in range(0, len(sss), 10000): 99 | # fbank.accept_waveform(16000, sss[ii:ii+10000]) 100 | fbank.accept_waveform(8000, augmented_samples2.tolist()) 101 | seq_len = fbank.num_frames_ready // 2 102 | A2Lsamples = np.zeros([2 * seq_len, 80]) 103 | for i in range(2 * seq_len): 104 | f2 = fbank.get_frame(i) 105 | A2Lsamples[i] = f2 106 | 107 | orig_mel = A2Lsamples 108 | # print(orig_mel.shape) 109 | input = torch.from_numpy(orig_mel).unsqueeze(0).float().to(device) 110 | # print(input.shape) 111 | h0 = torch.zeros(2, 1, 192).to(device) 112 | c0 = torch.zeros(2, 1, 192).to(device) 113 | bs_array, hn, cn = Audio2FeatureModel(input, h0, c0) 114 | # print(bs_array.shape) 115 | bs_array = bs_array[0].detach().cpu().float().numpy() 116 | # print(bs_array.shape) 117 | # bs_array = bs_array[4:] 118 | return bs_array 119 | -------------------------------------------------------------------------------- /talkingface/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/talkingface/models/__init__.py 
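For orientation, model_utils.py above is the offline counterpart of audio_model.py: LoadAudioModel restores the Audio2Feature LSTM (defined in the next file) and Audio2bs turns a wav file into per-frame 6-dim blendshape coefficients at roughly 25 rows per second (the wav is halved to 8 kHz, 80-dim fbank features are taken at a 20 ms shift, and two fbank frames feed each output step through a hidden state of shape (2, 1, 192)). A minimal sketch of chaining the two follows; the checkpoint path only echoes the commented hint in audio_model.py and is an assumption, so point it at wherever the downloaded weights actually live.

```python
# Minimal offline sketch (checkpoint path is an assumption): wav in,
# (num_frames, 6) blendshape coefficients out, ~25 rows per second of audio.
from talkingface.model_utils import LoadAudioModel, Audio2bs

audio_model = LoadAudioModel("checkpoint/lstm/lstm_model_epoch_560.pth")  # assumed path
bs_array = Audio2bs("video_data/audio0.wav", audio_model)                 # wav shipped with the repo
print(bs_array.shape)
```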
-------------------------------------------------------------------------------- /talkingface/models/audio2bs_lstm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | class Audio2Feature(nn.Module): 4 | def __init__(self): 5 | super(Audio2Feature, self).__init__() 6 | num_pred = 1 7 | self.output_size = 6 8 | self.ndim = 80 9 | APC_hidden_size = 80 10 | # define networks 11 | self.downsample = nn.Sequential( 12 | nn.Linear(in_features=APC_hidden_size * 2, out_features=APC_hidden_size), 13 | nn.BatchNorm1d(APC_hidden_size), 14 | nn.LeakyReLU(0.2), 15 | nn.Linear(APC_hidden_size, APC_hidden_size), 16 | ) 17 | self.LSTM = nn.LSTM(input_size=APC_hidden_size, 18 | hidden_size=192, 19 | num_layers=2, 20 | dropout=0, 21 | bidirectional=False, 22 | batch_first=True) 23 | self.fc = nn.Sequential( 24 | nn.Linear(in_features=192, out_features=256), 25 | nn.BatchNorm1d(256), 26 | nn.LeakyReLU(0.2), 27 | nn.Linear(256, 256), 28 | nn.BatchNorm1d(256), 29 | nn.LeakyReLU(0.2), 30 | nn.Linear(256, self.output_size)) 31 | 32 | def forward(self, audio_features, h0, c0): 33 | ''' 34 | Args: 35 | audio_features: [b, T, ndim] 36 | ''' 37 | self.item_len = audio_features.size()[1] 38 | # new in 0324 39 | audio_features = audio_features.reshape(-1, self.ndim * 2) 40 | down_audio_feats = self.downsample(audio_features) 41 | # print(down_audio_feats) 42 | down_audio_feats = down_audio_feats.reshape(-1, int(self.item_len / 2), self.ndim) 43 | output, (hn, cn) = self.LSTM(down_audio_feats, (h0, c0)) 44 | 45 | # output, (hn, cn) = self.LSTM(audio_features) 46 | pred = self.fc(output.reshape(-1, 192)).reshape(-1, int(self.item_len / 2), self.output_size) 47 | return pred, hn, cn 48 | -------------------------------------------------------------------------------- /talkingface/models/common/Discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | class DownBlock2d(nn.Module): 5 | def __init__(self, in_features, out_features, kernel_size=4, pool=False): 6 | super(DownBlock2d, self).__init__() 7 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) 8 | self.pool = pool 9 | def forward(self, x): 10 | out = x 11 | out = self.conv(out) 12 | out = F.leaky_relu(out, 0.2) 13 | if self.pool: 14 | out = F.avg_pool2d(out, (2, 2)) 15 | return out 16 | 17 | 18 | class Discriminator(nn.Module): 19 | """ 20 | Discriminator for GAN loss 21 | """ 22 | def __init__(self, num_channels, block_expansion=64, num_blocks=4, max_features=512): 23 | super(Discriminator, self).__init__() 24 | down_blocks = [] 25 | for i in range(num_blocks): 26 | down_blocks.append( 27 | DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)), 28 | min(max_features, block_expansion * (2 ** (i + 1))), 29 | kernel_size=4, pool=(i != num_blocks - 1))) 30 | self.down_blocks = nn.ModuleList(down_blocks) 31 | self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) 32 | def forward(self, x): 33 | feature_maps = [] 34 | out = x 35 | for down_block in self.down_blocks: 36 | feature_maps.append(down_block(out)) 37 | out = feature_maps[-1] 38 | out = self.conv(out) 39 | return feature_maps, out 40 | -------------------------------------------------------------------------------- /talkingface/models/common/VGG19.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import models 3 | import numpy as np 4 | 5 | 6 | class Vgg19(torch.nn.Module): 7 | """ 8 | Vgg19 network for perceptual loss 9 | """ 10 | def __init__(self, requires_grad=False): 11 | super(Vgg19, self).__init__() 12 | vgg_model = models.vgg19(pretrained=True) 13 | vgg_pretrained_features = vgg_model.features 14 | self.slice1 = torch.nn.Sequential() 15 | self.slice2 = torch.nn.Sequential() 16 | self.slice3 = torch.nn.Sequential() 17 | self.slice4 = torch.nn.Sequential() 18 | self.slice5 = torch.nn.Sequential() 19 | for x in range(2): 20 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 21 | for x in range(2, 7): 22 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 23 | for x in range(7, 12): 24 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 25 | for x in range(12, 21): 26 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 27 | for x in range(21, 30): 28 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 29 | 30 | self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), 31 | requires_grad=False) 32 | self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), 33 | requires_grad=False) 34 | 35 | if not requires_grad: 36 | for param in self.parameters(): 37 | param.requires_grad = False 38 | 39 | def forward(self, X): 40 | X = (X - self.mean) / self.std 41 | h_relu1 = self.slice1(X) 42 | h_relu2 = self.slice2(h_relu1) 43 | h_relu3 = self.slice3(h_relu2) 44 | h_relu4 = self.slice4(h_relu3) 45 | h_relu5 = self.slice5(h_relu4) 46 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 47 | return out 48 | -------------------------------------------------------------------------------- /talkingface/models/speed_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from talkingface.models.audio2bs_lstm import Audio2Feature 3 | import time 4 | import random 5 | import numpy as np 6 | import cv2 7 | device = "cpu" 8 | 9 | model = Audio2Feature() 10 | model.eval() 11 | x = torch.ones((1, 2, 80)) 12 | h0 = torch.zeros(2, 1, 192) 13 | c0 = torch.zeros(2, 1, 192) 14 | y, hn, cn = model(x, h0, c0) 15 | start_time = time.time() 16 | 17 | from thop import profile 18 | from thop import clever_format 19 | flops, params = profile(model.to(device), inputs=(x, h0, c0)) 20 | flops, params = clever_format([flops, params], "%.3f") 21 | print(flops, params) 22 | -------------------------------------------------------------------------------- /talkingface/preprocess.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import numpy as np 3 | import cv2 4 | import os 5 | import sys 6 | import time 7 | import argparse 8 | from talkingface.run_utils import video_pts_process, concat_output_2binfile 9 | from talkingface.mediapipe_utils import detect_face_mesh,detect_face 10 | from talkingface.utils import main_keypoints_index,INDEX_LIPS 11 | # 1、是否是mp4,宽高是否大于200,时长是否大于2s,可否成功转换为符合格式的mp4 12 | # 2、面部关键点检测及是否可以构成循环视频 13 | # 4、旋转矩阵、面部mask估计 14 | # 5、验证文件完整性 15 | 16 | dir_ = "data/asset/Actor" 17 | def print_log(task_id, progress, status, Error, mode = 0): 18 | ''' 19 | status: -1代表未开始, 0代表处理中, 1代表已完成, 2代表出错中断 20 | progress: 0-1000, 进度千分比 21 | ''' 22 | print("task_id: {}. progress: {:0>4d}. status: {}. mode: {}. 
Error: {}".format(task_id, progress, status, mode, Error)) 23 | sys.stdout.flush() 24 | 25 | def check_step0(task_id, video_path): 26 | try: 27 | cap = cv2.VideoCapture(video_path) 28 | vid_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # 宽度 29 | vid_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # 高度 30 | frames = cap.get(cv2.CAP_PROP_FRAME_COUNT) 31 | fps = cap.get(cv2.CAP_PROP_FPS) 32 | cap.release() 33 | if vid_width < 200 or vid_height < 200: 34 | print_log(task_id, 0, 2, "video width/height < 200") 35 | return 0 36 | if frames < 2*fps: 37 | print_log(task_id, 0, 2, "video duration < 2s") 38 | return 0 39 | os.makedirs(os.path.join(dir_, task_id), exist_ok=True) 40 | front_video_path = os.path.join("data", "front.mp4") 41 | scale = max(vid_width / 720., vid_height / 1280.) 42 | if scale > 1: 43 | new_width = int(vid_width / scale + 0.1)//2 * 2 44 | new_height = int(vid_height / scale + 0.1)//2 * 2 45 | ffmpeg_cmd = "ffmpeg -i {} -r 25 -ss 00:00:00 -t 00:02:00 -vf scale={}:{} -an -loglevel quiet -y {}".format( 46 | video_path,new_width,new_height,front_video_path) 47 | else: 48 | ffmpeg_cmd = "ffmpeg -i {} -r 25 -ss 00:00:00 -t 00:02:00 -an -loglevel quiet -y {}".format( 49 | video_path, front_video_path) 50 | os.system(ffmpeg_cmd) 51 | if not os.path.isfile(front_video_path): 52 | return 0 53 | return 1 54 | except: 55 | print_log(task_id, 0, 2, "video cant be opened") 56 | return 0 57 | 58 | def check_step1(task_id): 59 | front_video_path = os.path.join("data", "front.mp4") 60 | back_video_path = os.path.join("data", "back.mp4") 61 | video_out_path = os.path.join(dir_, task_id, "video.mp4") 62 | face_info_path = os.path.join(dir_, task_id, "video_info.bin") 63 | preview_path = os.path.join(dir_, task_id, "preview.jpg") 64 | if ExtractFromVideo(task_id, front_video_path) != 1: 65 | shutil.rmtree(os.path.join(dir_, task_id)) 66 | return 0 67 | ffmpeg_cmd = "ffmpeg -i {} -vf reverse -loglevel quiet -y {}".format(front_video_path, back_video_path) 68 | os.system(ffmpeg_cmd) 69 | ffmpeg_cmd = "ffmpeg -f concat -i {} -loglevel quiet -y {}".format("data/video_concat.txt", video_out_path) 70 | os.system(ffmpeg_cmd) 71 | ffmpeg_cmd = "ffmpeg -i {} -vf crop=w='min(iw\,ih)':h='min(iw\,ih)',scale=256:256,setsar=1 -vframes 1 {}".format(front_video_path, preview_path) 72 | # ffmpeg_cmd = "ffmpeg -i {} -vf scale=256:-1 -loglevel quiet -y {}".format(front_video_path, preview_path) 73 | os.system(ffmpeg_cmd) 74 | if os.path.isfile(front_video_path): 75 | os.remove(front_video_path) 76 | if os.path.isfile(back_video_path): 77 | os.remove(back_video_path) 78 | if os.path.isfile(video_out_path) and os.path.isfile(face_info_path): 79 | return 1 80 | else: 81 | return 0 82 | 83 | # def check_step2(task_id, ): 84 | # mat_list, pts_normalized_list, face_mask_pts = video_pts_process(pts_array_origin) 85 | 86 | 87 | def ExtractFromVideo(task_id, front_video_path): 88 | cap = cv2.VideoCapture(front_video_path) 89 | if not cap.isOpened(): 90 | print_log(task_id, 0, 2, "front_video cant be opened by opencv") 91 | return -1 92 | 93 | vid_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # 宽度 94 | vid_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # 高度 95 | 96 | totalFrames = cap.get(cv2.CAP_PROP_FRAME_COUNT) # 总帧数 97 | totalFrames = int(totalFrames) 98 | pts_3d = np.zeros([totalFrames, 478, 3]) 99 | frame_index = 0 100 | face_rect_list = [] 101 | start_time = time.time() 102 | while cap.isOpened(): 103 | ret, frame = cap.read() # 按帧读取视频 104 | # #到视频结尾时终止 105 | if ret is False: 106 | break 107 | rect_2d = detect_face([frame]) 108 | 
rect = rect_2d[0] 109 | tag_ = 1 if np.sum(rect) > 0 else 0 110 | if frame_index == 0 and tag_ != 1: 111 | print_log(task_id, 0, 2, "no face detected in first frame") 112 | cap.release() # 释放视频对象 113 | return 0 114 | elif tag_ == 0: # 有时候人脸检测会失败,就用上一帧的结果替代这一帧的结果 115 | rect = face_rect_list[-1] 116 | 117 | face_rect_list.append(rect) 118 | 119 | x_min = rect[0] * vid_width 120 | y_min = rect[2] * vid_height 121 | x_max = rect[1] * vid_width 122 | y_max = rect[3] * vid_height 123 | seq_w, seq_h = x_max - x_min, y_max - y_min 124 | x_mid, y_mid = (x_min + x_max) / 2, (y_min + y_max) / 2 125 | x_min = int(max(0, x_mid - seq_w * 0.65)) 126 | y_min = int(max(0, y_mid - seq_h * 0.4)) 127 | x_max = int(min(vid_width, x_mid + seq_w * 0.65)) 128 | y_max = int(min(vid_height, y_mid + seq_h * 0.8)) 129 | 130 | frame_face = frame[y_min:y_max, x_min:x_max] 131 | frame_kps = detect_face_mesh([frame_face])[0] 132 | if np.sum(frame_kps) == 0: 133 | print_log(task_id, 0, 2, "Frame num {} keypoint error".format(frame_index)) 134 | cap.release() # 释放视频对象 135 | return 0 136 | pts_3d[frame_index] = frame_kps + np.array([x_min, y_min, 0]) 137 | frame_index += 1 138 | 139 | if time.time() - start_time > 0.5: 140 | progress = int(1000 * frame_index / totalFrames * 0.99) 141 | print_log(task_id, progress, 0, "handling...") 142 | start_time = time.time() 143 | cap.release() # 释放视频对象 144 | if type(pts_3d) is np.ndarray and len(pts_3d) == totalFrames: 145 | pts_3d_main = pts_3d[:, main_keypoints_index] 146 | mat_list, pts_normalized_list, face_pts_mean_personal, face_mask_pts_normalized = video_pts_process(pts_3d_main) 147 | 148 | output = concat_output_2binfile(mat_list, pts_3d, face_pts_mean_personal, face_mask_pts_normalized) 149 | # print(output.shape) 150 | pts_normalized_list = np.array(pts_normalized_list)[:, INDEX_LIPS] 151 | # 找出此模特正面人脸的嘴巴区域范围 152 | x_max, x_min = np.max(pts_normalized_list[:, :, 0]), np.min(pts_normalized_list[:, :, 0]) 153 | y_max, y_min = np.max(pts_normalized_list[:, :, 1]), np.min(pts_normalized_list[:, :, 1]) 154 | y_min = y_min + (y_max - y_min) / 10. 
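# The lip bounding box computed above (with y_min nudged inward by 10% of the
# box height) becomes the header of video_info.bin: one row of 406 float32
# values whose first four entries are x_min, x_max, y_min, y_max and whose
# remainder stays zero. Each following row is one frame's data as packed by
# concat_output_2binfile (transform matrices, keypoints and face-mask data).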
155 | 156 | first_line = np.zeros([406]) 157 | first_line[:4] = np.array([x_min,x_max,y_min,y_max]) 158 | # print(first_line) 159 | # 160 | # pts_2d_main = pts_3d[:, main_keypoints_index, :2].reshape(len(pts_3d), -1) 161 | # smooth_array_ = np.array(mat_list).reshape(-1, 16)*100 162 | # 163 | # output = np.concatenate([smooth_array_, pts_2d_main], axis=1).astype(np.float32) 164 | output = np.concatenate([first_line.reshape(1,-1), output], axis=0).astype(np.float32) 165 | # print(smooth_array_.shape, pts_2d_main.shape, first_line.shape, output.shape) 166 | face_info_path = os.path.join(dir_, task_id, "video_info.bin") 167 | # np.savetxt(face_info_path, output, fmt='%.1f') 168 | # print(222) 169 | output.tofile(face_info_path) 170 | return 1 171 | else: 172 | print_log(task_id, 0, 2, "keypoint cant be saved") 173 | return 0 174 | 175 | def check_step0_audio(task_id, video_path): 176 | dir_ = "data/asset/Audio" 177 | wav_path = os.path.join(dir_, task_id + ".wav") 178 | ffmpeg_cmd = "ffmpeg -i {} -ac 1 -ar 16000 -loglevel quiet -y {}".format( 179 | video_path, wav_path) 180 | os.system(ffmpeg_cmd) 181 | if not os.path.isfile(wav_path): 182 | print_log(task_id, 0, 2, "audio convert failed", 2) 183 | return 0 184 | return 1 185 | 186 | def new_task(task_id, task_mode, video_path): 187 | # print(task_id, task_mode, video_path) 188 | if task_mode == "0": # "actor" 189 | print_log(task_id, 0, 0, "handling...") 190 | if check_step0(task_id, video_path): 191 | print_log(task_id, 0, 0, "handling...") 192 | if check_step1(task_id): 193 | print_log(task_id, 1000, 1, "process finished, click to confirm") 194 | if task_mode == "2": # "audio" 195 | print_log(task_id, 0, 0, "handling...", 2) 196 | if check_step0_audio(task_id, video_path): 197 | print_log(task_id, 1000, 1, "process finished, click to confirm", 2) 198 | 199 | if __name__ == '__main__': 200 | parser = argparse.ArgumentParser(description='Inference code to preprocess videos') 201 | parser.add_argument('--task_id', type=str, help='task_id') 202 | parser.add_argument('--task_mode', type=str, help='task_mode') 203 | parser.add_argument('--video_path', type=str, help='Filepath of video that contains faces to use') 204 | args = parser.parse_args() 205 | new_task(args.task_id, args.task_mode, args.video_path) -------------------------------------------------------------------------------- /talkingface/render_model.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torch 3 | import os 4 | import numpy as np 5 | import time 6 | from talkingface.run_utils import smooth_array, video_pts_process 7 | from talkingface.run_utils import mouth_replace, prepare_video_data 8 | from talkingface.utils import generate_face_mask, INDEX_LIPS_OUTER 9 | from talkingface.data.few_shot_dataset import select_ref_index,get_ref_images_fromVideo,generate_input, generate_input_pixels 10 | from talkingface.model_utils import device 11 | import pickle 12 | import cv2 13 | 14 | 15 | face_mask = generate_face_mask() 16 | 17 | 18 | class RenderModel: 19 | def __init__(self): 20 | self.__net = None 21 | 22 | self.__pts_driven = None 23 | self.__mat_list = None 24 | self.__pts_normalized_list = None 25 | self.__face_mask_pts = None 26 | self.__ref_img = None 27 | self.__cap_input = None 28 | self.frame_index = 0 29 | self.__mouth_coords_array = None 30 | 31 | def loadModel(self, ckpt_path): 32 | from talkingface.models.DINet import DINet_five_Ref as DINet 33 | n_ref = 5 34 | source_channel = 6 35 | ref_channel = n_ref 
* 6 36 | self.__net = DINet(source_channel, ref_channel).to(device) 37 | checkpoint = torch.load(ckpt_path) 38 | self.__net.load_state_dict(checkpoint) 39 | self.__net.eval() 40 | 41 | def reset_charactor(self, video_path, Path_pkl, ref_img_index_list = None): 42 | if self.__cap_input is not None: 43 | self.__cap_input.release() 44 | 45 | self.__pts_driven, self.__mat_list,self.__pts_normalized_list, self.__face_mask_pts, self.__ref_img, self.__cap_input = \ 46 | prepare_video_data(video_path, Path_pkl, ref_img_index_list) 47 | 48 | ref_tensor = torch.from_numpy(self.__ref_img / 255.).float().permute(2, 0, 1).unsqueeze(0).to(device) 49 | self.__net.ref_input(ref_tensor) 50 | 51 | x_min, x_max = np.min(self.__pts_normalized_list[:, INDEX_LIPS_OUTER, 0]), np.max(self.__pts_normalized_list[:, INDEX_LIPS_OUTER, 0]) 52 | y_min, y_max = np.min(self.__pts_normalized_list[:, INDEX_LIPS_OUTER, 1]), np.max(self.__pts_normalized_list[:, INDEX_LIPS_OUTER, 1]) 53 | z_min, z_max = np.min(self.__pts_normalized_list[:, INDEX_LIPS_OUTER, 2]), np.max(self.__pts_normalized_list[:, INDEX_LIPS_OUTER, 2]) 54 | 55 | x_mid,y_mid,z_mid = (x_min + x_max)/2, (y_min + y_max)/2, (z_min + z_max)/2 56 | x_len, y_len, z_len = (x_max - x_min)/2, (y_max - y_min)/2, (z_max - z_min)/2 57 | x_min, x_max = x_mid - x_len*0.9, x_mid + x_len*0.9 58 | y_min, y_max = y_mid - y_len*0.9, y_mid + y_len*0.9 59 | z_min, z_max = z_mid - z_len*0.9, z_mid + z_len*0.9 60 | 61 | # print(face_personal.shape, x_min, x_max, y_min, y_max, z_min, z_max) 62 | coords_array = np.zeros([100, 150, 4]) 63 | for i in range(100): 64 | for j in range(150): 65 | coords_array[i, j, 0] = j/149 66 | coords_array[i, j, 1] = i/100 67 | # coords_array[i, j, 2] = int((-75 + abs(j - 75))*(2./3)) 68 | coords_array[i, j, 2] = ((j - 75)/ 75) ** 2 69 | coords_array[i, j, 3] = 1 70 | 71 | coords_array = coords_array*np.array([x_max - x_min, y_max - y_min, z_max - z_min, 1]) + np.array([x_min, y_min, z_min, 0]) 72 | self.__mouth_coords_array = coords_array.reshape(-1, 4).transpose(1, 0) 73 | 74 | 75 | 76 | def interface(self, mouth_frame): 77 | vid_frame_count = self.__cap_input.get(cv2.CAP_PROP_FRAME_COUNT) 78 | if self.frame_index % vid_frame_count == 0: 79 | self.__cap_input.set(cv2.CAP_PROP_POS_FRAMES, 0) # 设置要获取的帧号 80 | ret, frame = self.__cap_input.read() # 按帧读取视频 81 | 82 | epoch = self.frame_index // len(self.__mat_list) 83 | if epoch % 2 == 0: 84 | new_index = self.frame_index % len(self.__mat_list) 85 | else: 86 | new_index = -1 - self.frame_index % len(self.__mat_list) 87 | 88 | # print(self.__face_mask_pts.shape, "ssssssss") 89 | source_img, target_img, crop_coords = generate_input_pixels(frame, self.__pts_driven[new_index], self.__mat_list[new_index], 90 | mouth_frame, self.__face_mask_pts[new_index], 91 | self.__mouth_coords_array) 92 | 93 | # tensor 94 | source_tensor = torch.from_numpy(source_img / 255.).float().permute(2, 0, 1).unsqueeze(0).to(device) 95 | target_tensor = torch.from_numpy(target_img / 255.).float().permute(2, 0, 1).unsqueeze(0).to(device) 96 | 97 | source_tensor, source_prompt_tensor = source_tensor[:, :3], source_tensor[:, 3:] 98 | fake_out = self.__net.interface(source_tensor, source_prompt_tensor) 99 | 100 | image_numpy = fake_out.detach().squeeze(0).cpu().float().numpy() 101 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 102 | image_numpy = image_numpy.clip(0, 255) 103 | image_numpy = image_numpy.astype(np.uint8) 104 | 105 | image_numpy = target_img * face_mask + image_numpy * (1 - face_mask) 106 | 107 | img_bg = frame 
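# Paste-back step: the network only generates the cropped face region, already
# blended with the driving frame through face_mask above, so it is resized
# back to the crop's original size and written into the full background frame.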
108 | x_min, y_min, x_max, y_max = crop_coords 109 | 110 | img_face = cv2.resize(image_numpy, (x_max - x_min, y_max - y_min)) 111 | img_bg[y_min:y_max, x_min:x_max] = img_face 112 | self.frame_index += 1 113 | return img_bg 114 | 115 | def save(self, path): 116 | torch.save(self.__net.state_dict(), path) -------------------------------------------------------------------------------- /talkingface/render_model_mini.py: -------------------------------------------------------------------------------- 1 | import os 2 | current_dir = os.path.dirname(os.path.abspath(__file__)) 3 | import random 4 | import glob 5 | import torch 6 | import numpy as np 7 | import cv2 8 | 9 | from talkingface.utils import draw_mouth_maps 10 | from talkingface.models.DINet_mini import input_height,input_width,model_size 11 | from talkingface.model_utils import device 12 | class RenderModel_Mini: 13 | def __init__(self): 14 | self.__net = None 15 | 16 | def loadModel(self, ckpt_path): 17 | from talkingface.models.DINet_mini import DINet_mini_pipeline as DINet 18 | n_ref = 3 19 | source_channel = 3 20 | ref_channel = n_ref * 4 21 | self.net = DINet(source_channel, ref_channel, device == "cuda").to(device) 22 | checkpoint = torch.load(ckpt_path, map_location=device) 23 | net_g_static = checkpoint['state_dict']['net_g'] 24 | self.net.infer_model.load_state_dict(net_g_static) 25 | self.net.eval() 26 | 27 | 28 | def reset_charactor(self, ref_img, ref_keypoints, standard_size = 256): 29 | ref_img_list = [] 30 | ref_face_edge = draw_mouth_maps(ref_keypoints, size=(standard_size, standard_size)) 31 | # cv2.imshow("ss", ref_face_edge) 32 | # cv2.waitKey(-1) 33 | # cv2.imshow("ss", ref_img) 34 | # cv2.waitKey(-1) 35 | ref_face_edge = cv2.resize(ref_face_edge, (model_size, model_size)) 36 | ref_img = cv2.resize(ref_img, (model_size, model_size)) 37 | w_pad = int((model_size - input_width) / 2) 38 | h_pad = int((model_size - input_height) / 2) 39 | 40 | ref_img = np.concatenate( 41 | [ref_img[h_pad:-h_pad, w_pad:-w_pad, :3], ref_face_edge[h_pad:-h_pad, w_pad:-w_pad, :1]], axis=2) 42 | # cv2.imshow("ss", ref_face_edge[h_pad:-h_pad, w_pad:-w_pad]) 43 | # cv2.waitKey(-1) 44 | ref_img_list.append(ref_img) 45 | 46 | teeth_ref_img = os.path.join(current_dir, r"../video_data/teeth_ref/*.png") 47 | teeth_ref_img = random.sample(glob.glob(teeth_ref_img), 1)[0] 48 | # teeth_ref_img = teeth_ref_img.replace("_2", "") 49 | teeth_ref_img = cv2.imread(teeth_ref_img, cv2.IMREAD_UNCHANGED) 50 | teeth_ref_img = cv2.resize(teeth_ref_img, (input_width, input_height)) 51 | ref_img_list.append(teeth_ref_img) 52 | ref_img_list.append(teeth_ref_img) 53 | 54 | self.ref_img_save = np.concatenate([i[:,:,:3] for i in ref_img_list], axis=1) 55 | self.ref_img = np.concatenate(ref_img_list, axis=2) 56 | 57 | ref_tensor = torch.from_numpy(self.ref_img / 255.).float().permute(2, 0, 1).unsqueeze(0).to(device) 58 | 59 | self.net.ref_input(ref_tensor) 60 | 61 | 62 | def interface(self, source_tensor, gl_tensor): 63 | ''' 64 | 65 | Args: 66 | source_tensor: [batch, 3, 128, 128] 67 | gl_tensor: [batch, 3, 128, 128] 68 | 69 | Returns: 70 | warped_img: [batch, 3, 128, 128] 71 | ''' 72 | warped_img = self.net.interface(source_tensor, gl_tensor) 73 | return warped_img 74 | 75 | def save(self, path): 76 | torch.save(self.net.state_dict(), path) -------------------------------------------------------------------------------- /talkingface/util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a 
miscellaneous collection of useful helper functions.""" 2 | -------------------------------------------------------------------------------- /talkingface/util/get_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import tarfile 4 | import requests 5 | from warnings import warn 6 | from zipfile import ZipFile 7 | from bs4 import BeautifulSoup 8 | from os.path import abspath, isdir, join, basename 9 | 10 | 11 | class GetData(object): 12 | """A Python script for downloading CycleGAN or pix2pix datasets. 13 | 14 | Parameters: 15 | technique (str) -- One of: 'cyclegan' or 'pix2pix'. 16 | verbose (bool) -- If True, print additional information. 17 | 18 | Examples: 19 | >>> from util.get_data import GetData 20 | >>> gd = GetData(technique='cyclegan') 21 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. 22 | 23 | Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh' 24 | and 'scripts/download_cyclegan_model.sh'. 25 | """ 26 | 27 | def __init__(self, technique='cyclegan', verbose=True): 28 | url_dict = { 29 | 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/', 30 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' 31 | } 32 | self.url = url_dict.get(technique.lower()) 33 | self._verbose = verbose 34 | 35 | def _print(self, text): 36 | if self._verbose: 37 | print(text) 38 | 39 | @staticmethod 40 | def _get_options(r): 41 | soup = BeautifulSoup(r.text, 'lxml') 42 | options = [h.text for h in soup.find_all('a', href=True) 43 | if h.text.endswith(('.zip', 'tar.gz'))] 44 | return options 45 | 46 | def _present_options(self): 47 | r = requests.get(self.url) 48 | options = self._get_options(r) 49 | print('Options:\n') 50 | for i, o in enumerate(options): 51 | print("{0}: {1}".format(i, o)) 52 | choice = input("\nPlease enter the number of the " 53 | "dataset above you wish to download:") 54 | return options[int(choice)] 55 | 56 | def _download_data(self, dataset_url, save_path): 57 | if not isdir(save_path): 58 | os.makedirs(save_path) 59 | 60 | base = basename(dataset_url) 61 | temp_save_path = join(save_path, base) 62 | 63 | with open(temp_save_path, "wb") as f: 64 | r = requests.get(dataset_url) 65 | f.write(r.content) 66 | 67 | if base.endswith('.tar.gz'): 68 | obj = tarfile.open(temp_save_path) 69 | elif base.endswith('.zip'): 70 | obj = ZipFile(temp_save_path, 'r') 71 | else: 72 | raise ValueError("Unknown File Type: {0}.".format(base)) 73 | 74 | self._print("Unpacking Data...") 75 | obj.extractall(save_path) 76 | obj.close() 77 | os.remove(temp_save_path) 78 | 79 | def get(self, save_path, dataset=None): 80 | """ 81 | 82 | Download a dataset. 83 | 84 | Parameters: 85 | save_path (str) -- A directory to save the data to. 86 | dataset (str) -- (optional). A specific dataset to download. 87 | Note: this must include the file extension. 88 | If None, options will be presented for you 89 | to choose from. 90 | 91 | Returns: 92 | save_path_full (str) -- the absolute path to the downloaded data. 93 | 94 | """ 95 | if dataset is None: 96 | selected_dataset = self._present_options() 97 | else: 98 | selected_dataset = dataset 99 | 100 | save_path_full = join(save_path, selected_dataset.split('.')[0]) 101 | 102 | if isdir(save_path_full): 103 | warn("\n'{0}' already exists. 
Voiding Download.".format( 104 | save_path_full)) 105 | else: 106 | self._print('Downloading Data...') 107 | url = "{0}/{1}".format(self.url, selected_dataset) 108 | self._download_data(url, save_path=save_path) 109 | 110 | return abspath(save_path_full) 111 | -------------------------------------------------------------------------------- /talkingface/util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 8 | 9 | It consists of functions such as (add a text header to the HTML file), 10 | (add a row of images to the HTML file), and (save the HTML to the disk). 11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 12 | """ 13 | 14 | def __init__(self, web_dir, title, refresh=0): 15 | """Initialize the HTML classes 16 | 17 | Parameters: 18 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: 32 | with self.doc.head: 33 | meta(http_equiv="refresh", content=str(refresh)) 34 | 35 | def get_image_dir(self): 36 | """Return the directory that stores images""" 37 | return self.img_dir 38 | 39 | def add_header(self, text): 40 | """Insert a header to the HTML file 41 | 42 | Parameters: 43 | text (str) -- the header text 44 | """ 45 | with self.doc: 46 | h3(text) 47 | 48 | def add_images(self, ims, txts, links, width=400): 49 | """add images to the HTML file 50 | 51 | Parameters: 52 | ims (str list) -- a list of image paths 53 | txts (str list) -- a list of image names shown on the website 54 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 55 | """ 56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 57 | self.doc.add(self.t) 58 | with self.t: 59 | with tr(): 60 | for im, txt, link in zip(ims, txts, links): 61 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 62 | with p(): 63 | with a(href=os.path.join('images', link)): 64 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 65 | br() 66 | p(txt) 67 | 68 | def save(self): 69 | """save the current content to the HMTL file""" 70 | html_file = '%s/index.html' % self.web_dir 71 | f = open(html_file, 'wt') 72 | f.write(self.doc.render()) 73 | f.close() 74 | 75 | 76 | if __name__ == '__main__': # we show an example usage here. 77 | html = HTML('web/', 'test_html') 78 | html.add_header('hello world') 79 | 80 | ims, txts, links = [], [], [] 81 | for n in range(4): 82 | ims.append('image_%d.png' % n) 83 | txts.append('text_%d' % n) 84 | links.append('image_%d.png' % n) 85 | html.add_images(ims, txts, links) 86 | html.save() 87 | -------------------------------------------------------------------------------- /talkingface/util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class ImagePool(): 6 | """This class implements an image buffer that stores previously generated images. 7 | 8 | This buffer enables us to update discriminators using a history of generated images 9 | rather than the ones produced by the latest generators. 
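    This image-history trick follows the buffer used in CycleGAN-style training.

    Illustrative usage inside a GAN training loop (fake_images is a placeholder
    tensor of generated images, not a name from this project):
        >>> pool = ImagePool(pool_size=50)
        >>> fake_for_d = pool.query(fake_images.detach())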
10 | """ 11 | 12 | def __init__(self, pool_size): 13 | """Initialize the ImagePool class 14 | 15 | Parameters: 16 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 17 | """ 18 | self.pool_size = pool_size 19 | if self.pool_size > 0: # create an empty pool 20 | self.num_imgs = 0 21 | self.images = [] 22 | 23 | def query(self, images): 24 | """Return an image from the pool. 25 | 26 | Parameters: 27 | images: the latest generated images from the generator 28 | 29 | Returns images from the buffer. 30 | 31 | By 50/100, the buffer will return input images. 32 | By 50/100, the buffer will return images previously stored in the buffer, 33 | and insert the current images to the buffer. 34 | """ 35 | if self.pool_size == 0: # if the buffer size is 0, do nothing 36 | return images 37 | return_images = [] 38 | for image in images: 39 | image = torch.unsqueeze(image.data, 0) 40 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer 41 | self.num_imgs = self.num_imgs + 1 42 | self.images.append(image) 43 | return_images.append(image) 44 | else: 45 | p = random.uniform(0, 1) 46 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer 47 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 48 | tmp = self.images[random_id].clone() 49 | self.images[random_id] = image 50 | return_images.append(tmp) 51 | else: # by another 50% chance, the buffer will return the current image 52 | return_images.append(image) 53 | return_images = torch.cat(return_images, 0) # collect all the images and return 54 | return return_images 55 | -------------------------------------------------------------------------------- /talkingface/util/log_board.py: -------------------------------------------------------------------------------- 1 | def log( 2 | logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag="" 3 | ): 4 | if losses is not None: 5 | logger.add_scalar("Loss/d_loss", losses[0], step) 6 | logger.add_scalar("Loss/g_gan_loss", losses[1], step) 7 | logger.add_scalar("Loss/g_l1_loss", losses[2], step) 8 | 9 | if fig is not None: 10 | logger.add_image(tag, fig, 2, dataformats='HWC') 11 | 12 | if audio is not None: 13 | logger.add_audio( 14 | tag, 15 | audio / max(abs(audio)), 16 | sample_rate=sampling_rate, 17 | ) -------------------------------------------------------------------------------- /talkingface/util/smooth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def smooth_array(array, weight = [0.1,0.8,0.1]): 7 | ''' 8 | 9 | Args: 10 | array: [n_frames, n_values], 需要转换为[n_values, 1, n_frames] 11 | weight: Conv1d.weight, 一维卷积核权重 12 | Returns: 13 | array: [n_frames, n_values], 光滑后的array 14 | ''' 15 | input = torch.Tensor(np.transpose(array[:,np.newaxis,:], (2, 1, 0))) 16 | smooth_length = len(weight) 17 | assert smooth_length%2 == 1, "卷积核权重个数必须使用奇数" 18 | pad = (smooth_length//2, smooth_length//2) # 当pad只有两个参数时,仅改变最后一个维度, 左边扩充1列,右边扩充1列 19 | input = F.pad(input, pad, "replicate") 20 | 21 | with torch.no_grad(): 22 | conv1 = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=smooth_length) 23 | # 卷积核的元素值初始化 24 | weight = torch.tensor(weight).view(1, 1, -1) 25 | conv1.weight = torch.nn.Parameter(weight) 26 | nn.init.constant_(conv1.bias, 0) # 偏置值为0 27 | # print(conv1.weight) 28 
| out = conv1(input) 29 | return out.permute(2,1,0).squeeze().numpy() 30 | 31 | if __name__ == '__main__': 32 | model_id = "new_case" 33 | Path_output_pkl = "../preparation/{}/mouth_info.pkl".format(model_id + "/00001") 34 | import pickle 35 | with open(Path_output_pkl, "rb") as f: 36 | images_info = pickle.load(f) 37 | pts_array_normalized = np.array(images_info[2]) 38 | pts_array_normalized = pts_array_normalized.reshape(-1, 16) 39 | smooth_array_ = smooth_array(pts_array_normalized) 40 | print(smooth_array_, smooth_array_.shape) 41 | smooth_array_ = smooth_array_.reshape(-1, 4, 4) 42 | import pandas as pd 43 | 44 | pd.DataFrame(smooth_array_[:, :, 0]).to_csv("mat2.csv") -------------------------------------------------------------------------------- /talkingface/util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | 8 | 9 | def tensor2im(input_image, imtype=np.uint8): 10 | """"Converts a Tensor array into a numpy image array. 11 | 12 | Parameters: 13 | input_image (tensor) -- the input image tensor array 14 | imtype (type) -- the desired type of the converted numpy array 15 | """ 16 | if not isinstance(input_image, np.ndarray): 17 | if isinstance(input_image, torch.Tensor): # get the data from a variable 18 | image_tensor = input_image.data 19 | else: 20 | return input_image 21 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 22 | if image_numpy.shape[0] == 1: # grayscale to RGB 23 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 24 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 25 | else: # if it is a numpy array, do nothing 26 | image_numpy = input_image 27 | return image_numpy.astype(imtype) 28 | 29 | 30 | def diagnose_network(net, name='network'): 31 | """Calculate and print the mean of average absolute(gradients) 32 | 33 | Parameters: 34 | net (torch network) -- Torch network 35 | name (str) -- the name of the network 36 | """ 37 | mean = 0.0 38 | count = 0 39 | for param in net.parameters(): 40 | if param.grad is not None: 41 | mean += torch.mean(torch.abs(param.grad.data)) 42 | count += 1 43 | if count > 0: 44 | mean = mean / count 45 | print(name) 46 | print(mean) 47 | 48 | 49 | def save_image(image_numpy, image_path, aspect_ratio=1.0): 50 | """Save a numpy image to the disk 51 | 52 | Parameters: 53 | image_numpy (numpy array) -- input numpy array 54 | image_path (str) -- the path of the image 55 | """ 56 | 57 | image_pil = Image.fromarray(image_numpy) 58 | h, w, _ = image_numpy.shape 59 | 60 | if aspect_ratio > 1.0: 61 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) 62 | if aspect_ratio < 1.0: 63 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) 64 | image_pil.save(image_path) 65 | 66 | 67 | def print_numpy(x, val=True, shp=False): 68 | """Print the mean, min, max, median, std, and size of a numpy array 69 | 70 | Parameters: 71 | val (bool) -- if print the values of the numpy array 72 | shp (bool) -- if print the shape of the numpy array 73 | """ 74 | x = x.astype(np.float64) 75 | if shp: 76 | print('shape,', x.shape) 77 | if val: 78 | x = x.flatten() 79 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 80 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 81 | 82 | 83 | 
def mkdirs(paths): 84 | """create empty directories if they don't exist 85 | 86 | Parameters: 87 | paths (str list) -- a list of directory paths 88 | """ 89 | if isinstance(paths, list) and not isinstance(paths, str): 90 | for path in paths: 91 | mkdir(path) 92 | else: 93 | mkdir(paths) 94 | 95 | 96 | def mkdir(path): 97 | """create a single empty directory if it didn't exist 98 | 99 | Parameters: 100 | path (str) -- a single directory path 101 | """ 102 | if not os.path.exists(path): 103 | os.makedirs(path) 104 | -------------------------------------------------------------------------------- /talkingface/util/utils.py: -------------------------------------------------------------------------------- 1 | from torch.optim import lr_scheduler 2 | 3 | import torch.nn as nn 4 | import torch 5 | 6 | ######################################################### training utils########################################################## 7 | 8 | def get_scheduler(optimizer, niter,niter_decay,lr_policy='lambda',lr_decay_iters=50): 9 | ''' 10 | scheduler in training stage 11 | ''' 12 | if lr_policy == 'lambda': 13 | def lambda_rule(epoch): 14 | lr_l = 1.0 - max(0, epoch - niter) / float(niter_decay + 1) 15 | return lr_l 16 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 17 | elif lr_policy == 'step': 18 | scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_decay_iters, gamma=0.1) 19 | elif lr_policy == 'plateau': 20 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 21 | elif lr_policy == 'cosine': 22 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=niter, eta_min=0) 23 | else: 24 | return NotImplementedError('learning rate policy [%s] is not implemented', lr_policy) 25 | return scheduler 26 | 27 | def update_learning_rate(scheduler, optimizer): 28 | scheduler.step() 29 | lr = optimizer.param_groups[0]['lr'] 30 | print('learning rate = %.7f' % lr) 31 | 32 | class GANLoss(nn.Module): 33 | ''' 34 | GAN loss 35 | ''' 36 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): 37 | super(GANLoss, self).__init__() 38 | self.register_buffer('real_label', torch.tensor(target_real_label)) 39 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 40 | if use_lsgan: 41 | self.loss = nn.MSELoss() 42 | else: 43 | self.loss = nn.BCELoss() 44 | 45 | def get_target_tensor(self, input, target_is_real): 46 | if target_is_real: 47 | target_tensor = self.real_label 48 | else: 49 | target_tensor = self.fake_label 50 | return target_tensor.expand_as(input) 51 | 52 | def forward(self, input, target_is_real): 53 | target_tensor = self.get_target_tensor(input, target_is_real) 54 | return self.loss(input, target_tensor) 55 | 56 | 57 | 58 | import tqdm 59 | import numpy as np 60 | import cv2 61 | import glob 62 | import os 63 | import math 64 | import pickle 65 | import mediapipe as mp 66 | mp_face_mesh = mp.solutions.face_mesh 67 | landmark_points_68 = [162,234,93,58,172,136,149,148,152,377,378,365,397,288,323,454,389, 68 | 71,63,105,66,107,336,296,334,293,301, 69 | 168,197,5,4,75,97,2,326,305, 70 | 33,160,158,133,153,144,362,385,387,263,373, 71 | 380,61,39,37,0,267,269,291,405,314,17,84,181,78,82,13,312,308,317,14,87] 72 | def ExtractFaceFromFrameList(frames_list, vid_height, vid_width, out_size = 256): 73 | pts_3d = np.zeros([len(frames_list), 478, 3]) 74 | with mp_face_mesh.FaceMesh( 75 | static_image_mode=True, 76 | max_num_faces=1, 77 | refine_landmarks=True, 78 | 
min_detection_confidence=0.5) as face_mesh: 79 | 80 | for index, frame in tqdm.tqdm(enumerate(frames_list)): 81 | results = face_mesh.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 82 | if not results.multi_face_landmarks: 83 | print("****** WARNING! No face detected! ******") 84 | pts_3d[index] = 0 85 | return 86 | # continue 87 | image_height, image_width = frame.shape[:2] 88 | for face_landmarks in results.multi_face_landmarks: 89 | for index_, i in enumerate(face_landmarks.landmark): 90 | x_px = min(math.floor(i.x * image_width), image_width - 1) 91 | y_px = min(math.floor(i.y * image_height), image_height - 1) 92 | z_px = min(math.floor(i.z * image_height), image_height - 1) 93 | pts_3d[index, index_] = np.array([x_px, y_px, z_px]) 94 | 95 | # 计算整个视频中人脸的范围 96 | 97 | x_min, y_min, x_max, y_max = np.min(pts_3d[:, :, 0]), np.min( 98 | pts_3d[:, :, 1]), np.max( 99 | pts_3d[:, :, 0]), np.max(pts_3d[:, :, 1]) 100 | new_w = int((x_max - x_min) * 0.55)*2 101 | new_h = int((y_max - y_min) * 0.6)*2 102 | center_x = int((x_max + x_min) / 2.) 103 | center_y = int(y_min + (y_max - y_min) * 0.6) 104 | size = max(new_h, new_w) 105 | x_min, y_min, x_max, y_max = int(center_x - size // 2), int(center_y - size // 2), int( 106 | center_x + size // 2), int(center_y + size // 2) 107 | 108 | # 确定裁剪区域上边top和左边left坐标 109 | top = y_min 110 | left = x_min 111 | # 裁剪区域与原图的重合区域 112 | top_coincidence = int(max(top, 0)) 113 | bottom_coincidence = int(min(y_max, vid_height)) 114 | left_coincidence = int(max(left, 0)) 115 | right_coincidence = int(min(x_max, vid_width)) 116 | 117 | scale = out_size / size 118 | pts_3d = (pts_3d - np.array([left, top, 0])) * scale 119 | pts_3d = pts_3d 120 | 121 | face_rect = np.array([center_x, center_y, size]) 122 | print(np.array([x_min, y_min, x_max, y_max])) 123 | 124 | img_array = np.zeros([len(frames_list), out_size, out_size, 3], dtype = np.uint8) 125 | for index, frame in tqdm.tqdm(enumerate(frames_list)): 126 | img_new = np.zeros([size, size, 3], dtype=np.uint8) 127 | img_new[top_coincidence - top:bottom_coincidence - top, left_coincidence - left:right_coincidence - left,:] = \ 128 | frame[top_coincidence:bottom_coincidence, left_coincidence:right_coincidence, :] 129 | img_new = cv2.resize(img_new, (out_size, out_size)) 130 | img_array[index] = img_new 131 | return pts_3d,img_array, face_rect 132 | 133 | -------------------------------------------------------------------------------- /video_data/000001/video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/000001/video.mp4 -------------------------------------------------------------------------------- /video_data/000002/video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/000002/video.mp4 -------------------------------------------------------------------------------- /video_data/audio0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/audio0.wav -------------------------------------------------------------------------------- /video_data/audio1.wav: -------------------------------------------------------------------------------- 
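A minimal sketch of feeding the ExtractFaceFromFrameList helper above (talkingface/util/utils.py) with frames decoded by OpenCV; the clip path reuses the bundled sample video, and the printed shapes are what the function returns on success (it returns None if any frame has no detectable face).

import cv2
from talkingface.util.utils import ExtractFaceFromFrameList

cap = cv2.VideoCapture("video_data/000001/video.mp4")   # bundled sample clip; replace with your own video
frames = []
while True:
    ret, frame = cap.read()
    if not ret:
        break
    frames.append(frame)
cap.release()

h, w = frames[0].shape[:2]
result = ExtractFaceFromFrameList(frames, vid_height=h, vid_width=w, out_size=256)
if result is not None:
    pts_3d, face_images, face_rect = result
    print(pts_3d.shape, face_images.shape, face_rect)    # (N, 478, 3), (N, 256, 256, 3), [cx, cy, size]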
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/audio1.wav -------------------------------------------------------------------------------- /video_data/teeth_ref/221.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/221.png -------------------------------------------------------------------------------- /video_data/teeth_ref/252.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/252.png -------------------------------------------------------------------------------- /video_data/teeth_ref/328.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/328.png -------------------------------------------------------------------------------- /video_data/teeth_ref/377.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/377.png -------------------------------------------------------------------------------- /video_data/teeth_ref/398.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/398.png -------------------------------------------------------------------------------- /video_data/teeth_ref/519.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/519.png -------------------------------------------------------------------------------- /video_data/teeth_ref/558.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/558.png -------------------------------------------------------------------------------- /video_data/teeth_ref/682.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/682.png -------------------------------------------------------------------------------- /video_data/teeth_ref/743.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/743.png -------------------------------------------------------------------------------- /video_data/teeth_ref/760.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/760.png -------------------------------------------------------------------------------- /video_data/teeth_ref/794.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/794.png -------------------------------------------------------------------------------- /web_demo/Flowchart.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/Flowchart.jpg -------------------------------------------------------------------------------- /web_demo/README.md: -------------------------------------------------------------------------------- 1 | # DH_Live_mini 部署说明 2 | 3 | > [!NOTE] 4 | > 本项目专注于在最小硬件资源(无GPU、普通2核4G CPU)环境下实现低延迟的数字人服务部署。 5 | 6 | ## 服务组件分布 7 | 8 | | 组件 | 部署位置 | 9 | |--------|------------| 10 | | VAD | Web本地 | 11 | | ASR | 服务器本地 | 12 | | LLM | 云端服务 | 13 | | TTS | 服务器本地 | 14 | | 数字人 | Web本地 | 15 | 16 | ![deepseek_mermaid_20250428_94e921](https://github.com/user-attachments/assets/505a1602-86c8-4b80-b692-9f6c9dcb19ac) 17 | 18 | ## 目录结构 19 | 20 | 本项目目录结构如下: 21 | ```bash 22 | 项目根目录/ 23 | ├── models/ # 本地TTS及ASR模型 24 | │ ├── sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/ # ASR 25 | │ ├── sherpa-onnx-vits-zh-ll/ # TTS 26 | ├── static/ # 静态资源文件夹 27 | │ ├── assets/ # 人物形象资源文件夹 28 | │ ├── assets2/ # 人物2形象资源文件夹 29 | │ ├── common/ # 公共资源文件夹 30 | │ ├── css/ # CSS样式文件夹 31 | │ ├── js/ # JavaScript脚本文件夹 32 | │ ├── DHLiveMini.wasm # AI推理组件 33 | │ ├── dialog.html # MiniLive.html包含的纯对话iframe页面 34 | │ ├── dialog_RealTime.html # MiniLive_RealTime.html包含的纯对话iframe页面 35 | │ └── MiniLive.html # 数字人视频流主页面(简单demo) 36 | │ └── MiniLive_RealTime.html # 数字人视频流主页面(实时语音对话页面,推荐!) 37 | ├── voiceapi/ # asr、llm、tts具体设置 38 | └── server.py # 启动网页服务的Python程序 39 | └── server_realtime.py # 启动实时语音对话网页服务的Python程序 40 | ``` 41 | ### 运行项目 42 | (New!)启动实时语音对话服务: 43 | 44 | (注意需要下载本地ASR&TTS模型,并设置openai API进行大模型对话),请看下方配置说明。 45 | ```bash 46 | # 切换到DH_live根目录下 47 | python web_demo/server_realtime.py 48 | ``` 49 | 打开浏览器,访问 http://localhost:8888/static/MiniLive_RealTime.html 50 | 51 | 52 | 如果只是需要简单演示服务: 53 | ```bash 54 | # 切换到DH_live根目录下 55 | python web_demo/server.py 56 | ``` 57 | 打开浏览器,访问 http://localhost:8888/static/MiniLive.html 58 | 59 | ## 配置说明 60 | 61 | ### 1. 替换对话服务网址 62 | 63 | 对于全流程语音通话demo,在 static/js/dialog_realtime.js 文件中,找到第1行,将 http://localhost:8888/eb_stream 替换为您自己的对话服务网址。例如: 64 | https://your-dialogue-service.com/eb_stream, 将第二行的websocket url也改为"wss://your-dialogue-service.com/asr?samplerate=16000" 65 | 66 | 对于简单演示demo,在 static/js/dialog.js 文件中,找到第1行,将 http://localhost:8888/eb_stream 替换为您自己的对话服务网址。例如: 67 | https://your-dialogue-service.com/eb_stream 68 | 69 | ### 2. 模拟对话服务 70 | 71 | server.py 提供了一个模拟对话服务的示例。它接收JSON格式的输入,并流式返回JSON格式的响应。示例代码如下: 72 | 73 | 输入 JSON: 74 | ```bash 75 | { 76 | "prompt": "用户输入的对话内容" 77 | } 78 | ``` 79 | 输出 JSON(流式返回): 80 | ```bash 81 | { 82 | "text": "返回的部分对话文本", 83 | "audio": "base64编码的音频数据", 84 | "endpoint": false // 是否为对话的最后一个片段,true表示结束 85 | } 86 | ``` 87 | ### 3. 
全流程的实时语音对话 88 | 下载相关模型(可以替换为其他类似模型): 89 | 90 | ASR model: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 91 | 92 | TTS model: https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/sherpa-onnx-vits-zh-ll.tar.bz2 93 | 94 | 在voiceapi/llm.py中,按照OpneAI API格式配置大模型接口: 95 | 96 | 豆包: 97 | ```bash 98 | from openai import OpenAI 99 | base_url = "https://ark.cn-beijing.volces.com/api/v3" 100 | api_key = "*****************************" 101 | model_name = "doubao-pro-32k-character-241215" 102 | 103 | llm_client = OpenAI( 104 | base_url=base_url, 105 | api_key=api_key, 106 | ) 107 | ``` 108 | 109 | DeepSeek: 110 | ```bash 111 | from openai import OpenAI 112 | base_url = "https://api.deepseek.com" 113 | api_key = "" 114 | model_name = "deepseek-chat" 115 | 116 | llm_client = OpenAI( 117 | base_url=base_url, 118 | api_key=api_key, 119 | ) 120 | ``` 121 | 122 | ### 4. 更换人物形象 123 | 124 | 要更换人物形象,请将新形象包中的文件替换 assets 文件夹中的对应文件。确保新文件的命名和路径与原有文件一致,以避免引用错误。 125 | 126 | ### 5. WebCodecs API 使用注意事项 127 | 128 | 本项目使用了 WebCodecs API,该 API 仅在安全上下文(HTTPS 或 localhost)中可用。因此,在部署或测试时,请确保您的网页在 HTTPS 环境下运行,或者使用 localhost 进行本地测试。 129 | 130 | ### 6. Thanks 131 | 此处重点感谢以下项目,本项目大量使用了以下项目的相关代码 132 | 133 | - [Project AIRI](https://github.com/moeru-ai/airi) 134 | - [sherpa-onnx](https://github.com/k2-fsa/sherpa-onnx) 135 | -------------------------------------------------------------------------------- /web_demo/server.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | import asyncio 4 | import re 5 | import base64 6 | from fastapi.responses import StreamingResponse 7 | from fastapi.staticfiles import StaticFiles 8 | from fastapi import FastAPI, Request, UploadFile, File,HTTPException 9 | app = FastAPI() 10 | 11 | # 挂载静态文件 12 | app.mount("/static", StaticFiles(directory="web_demo/static"), name="static") 13 | 14 | def get_audio(text_cache, voice_speed, voice_id): 15 | # 读取一个语音文件模拟语音合成的结果 16 | with open("web_demo/static/common/test.wav", "rb") as audio_file: 17 | audio_value = audio_file.read() 18 | base64_string = base64.b64encode(audio_value).decode('utf-8') 19 | return base64_string 20 | 21 | def llm_answer(prompt): 22 | # 模拟大模型的回答 23 | answer = "我会重复三遍来模仿大模型的回答,我会重复三遍来模仿大模型的回答,我会重复三遍来模仿大模型的回答。" 24 | return answer 25 | 26 | def split_sentence(sentence, min_length=10): 27 | # 定义包括小括号在内的主要标点符号 28 | punctuations = r'[。?!;…,、()()]' 29 | # 使用正则表达式切分句子,保留标点符号 30 | parts = re.split(f'({punctuations})', sentence) 31 | parts = [p for p in parts if p] # 移除空字符串 32 | sentences = [] 33 | current = '' 34 | for part in parts: 35 | if current: 36 | # 如果当前片段加上新片段长度超过最小长度,则将当前片段添加到结果中 37 | if len(current) + len(part) >= min_length: 38 | sentences.append(current + part) 39 | current = '' 40 | else: 41 | current += part 42 | else: 43 | current = part 44 | # 将剩余的片段添加到结果中 45 | if len(current) >= 2: 46 | sentences.append(current) 47 | return sentences 48 | 49 | 50 | import asyncio 51 | async def gen_stream(prompt, asr = False, voice_speed=None, voice_id=None): 52 | print("XXXXXXXXX", voice_speed, voice_id) 53 | if asr: 54 | chunk = { 55 | "prompt": prompt 56 | } 57 | yield f"{json.dumps(chunk)}\n" # 使用换行符分隔 JSON 块 58 | 59 | text_cache = llm_answer(prompt) 60 | sentences = split_sentence(text_cache) 61 | 62 | for index_, sub_text in enumerate(sentences): 63 | base64_string = get_audio(sub_text, voice_speed, voice_id) 64 | # 生成 JSON 格式的数据块 65 | chunk = { 66 | "text": sub_text, 67 | 
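                    # "text": sentence fragment for display, "audio": base64-encoded
                    # WAV bytes, "endpoint": True only on the reply's final chunk
                    # (the line-delimited JSON contract documented in web_demo/README.md).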
"audio": base64_string, 68 | "endpoint": index_ == len(sentences)-1 69 | } 70 | yield f"{json.dumps(chunk)}\n" # 使用换行符分隔 JSON 块 71 | await asyncio.sleep(0.2) # 模拟异步延迟 72 | 73 | # 处理 ASR 和 TTS 的端点 74 | @app.post("/process_audio") 75 | async def process_audio(file: UploadFile = File(...)): 76 | # 模仿调用 ASR API 获取文本 77 | text = "语音已收到,这里只是模仿,真正对话需要您自己设置ASR服务。" 78 | # 调用 TTS 生成流式响应 79 | return StreamingResponse(gen_stream(text, asr=True), media_type="application/json") 80 | 81 | 82 | async def call_asr_api(audio_data): 83 | # 调用ASR完成语音识别 84 | answer = "语音已收到,这里只是模仿,真正对话需要您自己设置ASR服务。" 85 | return answer 86 | 87 | @app.post("/eb_stream") # 前端调用的path 88 | async def eb_stream(request: Request): 89 | try: 90 | body = await request.json() 91 | input_mode = body.get("input_mode") 92 | voice_speed = body.get("voice_speed") 93 | voice_id = body.get("voice_id") 94 | 95 | if input_mode == "audio": 96 | base64_audio = body.get("audio") 97 | # 解码 Base64 音频数据 98 | audio_data = base64.b64decode(base64_audio) 99 | # 这里可以添加对音频数据的处理逻辑 100 | prompt = await call_asr_api(audio_data) # 假设 call_asr_api 可以处理音频数据 101 | return StreamingResponse(gen_stream(prompt, asr=True, voice_speed=voice_speed, voice_id=voice_id), media_type="application/json") 102 | elif input_mode == "text": 103 | prompt = body.get("prompt") 104 | return StreamingResponse(gen_stream(prompt, asr=False, voice_speed=voice_speed, voice_id=voice_id), media_type="application/json") 105 | else: 106 | raise HTTPException(status_code=400, detail="Invalid input mode") 107 | except Exception as e: 108 | raise HTTPException(status_code=500, detail=str(e)) 109 | 110 | # 启动Uvicorn服务器 111 | if __name__ == "__main__": 112 | import uvicorn 113 | uvicorn.run(app, host="0.0.0.0", port=8888) 114 | -------------------------------------------------------------------------------- /web_demo/server_realtime.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from contextlib import asynccontextmanager 4 | import re 5 | import asyncio 6 | import base64 7 | from fastapi.responses import StreamingResponse 8 | from fastapi.staticfiles import StaticFiles 9 | from fastapi import FastAPI, Request, UploadFile, File,HTTPException,WebSocketDisconnect,WebSocket 10 | from voiceapi.asr import start_asr_stream, ASRResult,ASREngineManager 11 | import uvicorn 12 | import argparse 13 | from voiceapi.llm import llm_stream 14 | from voiceapi.tts import get_audio,TTSEngineManager 15 | 16 | # 2. 
生命周期管理 17 | @asynccontextmanager 18 | async def lifespan(app: FastAPI): 19 | # 服务启动时初始化模型(示例参数) 20 | print("ASR模型正在初始化,请稍等") 21 | ASREngineManager.initialize(samplerate=16000, args = args) 22 | print("TTS模型正在初始化,请稍等") 23 | TTSEngineManager.initialize(args = args) 24 | yield 25 | # 服务关闭时清理资源 26 | if ASREngineManager.get_engine(): 27 | ASREngineManager.get_engine().cleanup() 28 | 29 | 30 | app = FastAPI(lifespan=lifespan) 31 | 32 | # 挂载静态文件 33 | app.mount("/static", StaticFiles(directory="web_demo/static"), name="static") 34 | 35 | 36 | def split_sentence(sentence, min_length=10): 37 | # 定义包括小括号在内的主要标点符号 38 | punctuations = r'[。?!;…,、()()]' 39 | # 使用正则表达式切分句子,保留标点符号 40 | parts = re.split(f'({punctuations})', sentence) 41 | parts = [p for p in parts if p] # 移除空字符串 42 | sentences = [] 43 | current = '' 44 | for part in parts: 45 | if current: 46 | # 如果当前片段加上新片段长度超过最小长度,则将当前片段添加到结果中 47 | if len(current) + len(part) >= min_length: 48 | sentences.append(current + part) 49 | current = '' 50 | else: 51 | current += part 52 | else: 53 | current = part 54 | # 将剩余的片段添加到结果中 55 | if len(current) >= 2: 56 | sentences.append(current) 57 | return sentences 58 | 59 | PUNCTUATION_SET = { 60 | ',', " ", '。', '!', '?', ';', ':', '、', '(', ')', '【', '】', '“', '”', 61 | ',', '.', '!', '?', ';', ':', '(', ')', '[', ']', '"', "'" 62 | } 63 | 64 | async def gen_stream(prompt, asr = False, voice_speed=None, voice_id=None): 65 | print("gen_stream", voice_speed, voice_id) 66 | if asr: 67 | chunk = { 68 | "prompt": prompt 69 | } 70 | yield f"{json.dumps(chunk)}\n" # 使用换行符分隔 JSON 块 71 | 72 | # Streaming: 73 | print("----- streaming request -----") 74 | stream = llm_stream(prompt) 75 | llm_answer_cache = "" 76 | for chunk in stream: 77 | if not chunk.choices: 78 | continue 79 | llm_answer_cache += chunk.choices[0].delta.content 80 | 81 | # 查找第一个标点符号的位置 82 | punctuation_pos = -1 83 | for i, char in enumerate(llm_answer_cache[8:]): 84 | if char in PUNCTUATION_SET: 85 | punctuation_pos = i + 8 86 | break 87 | # 如果找到标点符号且第一小句字数大于8 88 | if punctuation_pos != -1: 89 | # 获取第一小句 90 | first_sentence = llm_answer_cache[:punctuation_pos + 1] 91 | # 剩余的文字 92 | remaining_text = llm_answer_cache[punctuation_pos + 1:] 93 | print("get_audio: ", first_sentence) 94 | base64_string = await get_audio(first_sentence, voice_id=voice_id, voice_speed=voice_speed) 95 | chunk = { 96 | "text": first_sentence, 97 | "audio": base64_string, 98 | "endpoint": False 99 | } 100 | 101 | # 更新缓存为剩余的文字 102 | llm_answer_cache = remaining_text 103 | yield f"{json.dumps(chunk)}\n" # 使用换行符分隔 JSON 块 104 | await asyncio.sleep(0.2) # 模拟异步延迟 105 | print("get_audio: ", llm_answer_cache) 106 | if len(llm_answer_cache) >= 2: 107 | base64_string = await get_audio(llm_answer_cache, voice_id=voice_id, voice_speed=voice_speed) 108 | else: 109 | base64_string = "" 110 | chunk = { 111 | "text": llm_answer_cache, 112 | "audio": base64_string, 113 | "endpoint": True 114 | } 115 | yield f"{json.dumps(chunk)}\n" # 使用换行符分隔 JSON 块 116 | 117 | @app.websocket("/asr") 118 | async def websocket_asr(websocket: WebSocket, samplerate: int = 16000): 119 | await websocket.accept() 120 | 121 | asr_stream = await start_asr_stream(samplerate, args) 122 | if not asr_stream: 123 | print("failed to start ASR stream") 124 | await websocket.close() 125 | return 126 | 127 | async def task_recv_pcm(): 128 | while True: 129 | try: 130 | data = await asyncio.wait_for(websocket.receive(), timeout=1.0) 131 | # print(f"message: {data}") 132 | except asyncio.TimeoutError: 133 | continue # 没有数据到达,继续循环 134 
| 135 | if "text" in data.keys(): 136 | print(f"Received text message: {data}") 137 | data = data["text"] 138 | if data.strip() == "vad": 139 | print("VAD signal received") 140 | await asr_stream.vad_touched() 141 | elif "bytes" in data.keys(): 142 | pcm_bytes = data["bytes"] 143 | print("XXXX pcm_bytes", len(pcm_bytes)) 144 | if not pcm_bytes: 145 | return 146 | await asr_stream.write(pcm_bytes) 147 | 148 | 149 | async def task_send_result(): 150 | while True: 151 | result: ASRResult = await asr_stream.read() 152 | if not result: 153 | return 154 | await websocket.send_json(result.to_dict()) 155 | try: 156 | await asyncio.gather(task_recv_pcm(), task_send_result()) 157 | except WebSocketDisconnect: 158 | print("asr: disconnected") 159 | finally: 160 | await asr_stream.close() 161 | 162 | @app.post("/eb_stream") # 前端调用的path 163 | async def eb_stream(request: Request): 164 | try: 165 | body = await request.json() 166 | input_mode = body.get("input_mode") 167 | voice_speed = body.get("voice_speed", 1.0) 168 | voice_id = body.get("voice_id", 0) 169 | 170 | if voice_speed == "": 171 | voice_speed = 1.0 172 | if voice_id == "": 173 | voice_id = 0 174 | 175 | if input_mode == "text": 176 | prompt = body.get("prompt") 177 | return StreamingResponse(gen_stream(prompt, asr=False, voice_speed=voice_speed, voice_id=voice_id), media_type="application/json") 178 | else: 179 | raise HTTPException(status_code=400, detail="Invalid input mode") 180 | except Exception as e: 181 | raise HTTPException(status_code=500, detail=str(e)) 182 | 183 | # 启动Uvicorn服务器 184 | if __name__ == "__main__": 185 | models_root = './models' 186 | 187 | for d in ['.', '..', 'web_demo']: 188 | if os.path.isdir(f'{d}/models'): 189 | models_root = f'{d}/models' 190 | break 191 | 192 | parser = argparse.ArgumentParser() 193 | parser.add_argument("--port", type=int, default=8888, help="port number") 194 | parser.add_argument("--addr", type=str, 195 | default="0.0.0.0", help="serve address") 196 | 197 | parser.add_argument("--asr-provider", type=str, 198 | default="cpu", help="asr provider, cpu or cuda") 199 | parser.add_argument("--tts-provider", type=str, 200 | default="cpu", help="tts provider, cpu or cuda") 201 | 202 | parser.add_argument("--threads", type=int, default=2, 203 | help="number of threads") 204 | 205 | parser.add_argument("--models-root", type=str, default=models_root, 206 | help="model root directory") 207 | 208 | parser.add_argument("--asr-model", type=str, default='zipformer-bilingual', 209 | help="ASR model name: zipformer-bilingual, sensevoice, paraformer-trilingual, paraformer-en, whisper-medium") 210 | 211 | parser.add_argument("--asr-lang", type=str, default='zh', 212 | help="ASR language, zh, en, ja, ko, yue") 213 | 214 | parser.add_argument("--tts-model", type=str, default='sherpa-onnx-vits-zh-ll', 215 | help="TTS model name: vits-zh-hf-theresa, vits-melo-tts-zh_en") 216 | 217 | args = parser.parse_args() 218 | 219 | if args.tts_model == 'vits-melo-tts-zh_en' and args.tts_provider == 'cuda': 220 | print( 221 | "vits-melo-tts-zh_en does not support CUDA fallback to CPU") 222 | args.tts_provider = 'cpu' 223 | 224 | uvicorn.run(app, host=args.addr, port=args.port) 225 | -------------------------------------------------------------------------------- /web_demo/static/DHLiveMini.wasm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/DHLiveMini.wasm 
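Both server.py and server_realtime.py stream /eb_stream responses as newline-delimited JSON chunks. The snippet below is an illustrative command-line client for testing that endpoint outside the browser; the host and port match the default uvicorn settings, while the prompt text and output filenames are placeholders.

import base64
import json
import requests

resp = requests.post(
    "http://localhost:8888/eb_stream",
    json={"input_mode": "text", "prompt": "你好", "voice_speed": "", "voice_id": ""},
    stream=True,
)
reply_text = ""
for i, line in enumerate(resp.iter_lines()):
    if not line:
        continue
    chunk = json.loads(line)
    reply_text += chunk.get("text", "")
    if chunk.get("audio"):                        # base64 audio clip (a WAV file in the mock server.py)
        with open(f"reply_{i}.wav", "wb") as f:   # placeholder output path
            f.write(base64.b64decode(chunk["audio"]))
    if chunk.get("endpoint"):
        break
print(reply_text)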
-------------------------------------------------------------------------------- /web_demo/static/MiniLive.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | MiniLive 9 | 81 | 82 | 83 | 84 |
85 | MiniMates: loading... 86 |
87 | 88 | 89 | 90 | 91 | 92 | 93 |
94 | 95 | 96 |
加载中
97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /web_demo/static/MiniLive_RealTime.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | MiniLive 9 | 129 | 130 | 131 | 137 |
138 | 145 |
146 | 147 |
148 | MiniMates: loading... 149 |
150 | 151 | 152 | 153 |
154 | 155 |
加载中
156 | 157 | 158 | 159 | 160 | 161 | 162 | -------------------------------------------------------------------------------- /web_demo/static/MiniLive_new.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | MiniLive 8 | 128 | 129 | 130 | 138 |
139 | 145 |
146 | 147 |
148 | MiniMates: loading... 149 |
150 | 151 | 152 | 153 |
154 | 155 |
加载中
156 | 157 | --> 158 | 159 | 160 | 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /web_demo/static/assets/01.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/assets/01.mp4 -------------------------------------------------------------------------------- /web_demo/static/assets/combined_data.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/assets/combined_data.json.gz -------------------------------------------------------------------------------- /web_demo/static/assets2/01.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/assets2/01.mp4 -------------------------------------------------------------------------------- /web_demo/static/assets2/combined_data.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/assets2/combined_data.json.gz -------------------------------------------------------------------------------- /web_demo/static/common/bs_texture_halfFace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/common/bs_texture_halfFace.png -------------------------------------------------------------------------------- /web_demo/static/common/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/common/favicon.ico -------------------------------------------------------------------------------- /web_demo/static/common/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/common/test.wav -------------------------------------------------------------------------------- /web_demo/static/css/material-icons.css: -------------------------------------------------------------------------------- 1 | /* 定义字体 */ 2 | @font-face { 3 | font-family: 'Material Icons'; 4 | font-style: normal; 5 | font-weight: 400; 6 | src: url('../fonts/flUhRq6tzZclQEJ-Vdg-IuiaDsNcIhQ8tQ.woff2') format('woff2'); 7 | } 8 | 9 | /* 定义图标样式 */ 10 | .material-icons { 11 | font-family: 'Material Icons'; 12 | font-weight: normal; 13 | font-style: normal; 14 | font-size: 24px; 15 | line-height: 1; 16 | letter-spacing: normal; 17 | text-transform: none; 18 | display: inline-block; 19 | white-space: nowrap; 20 | word-wrap: normal; 21 | direction: ltr; 22 | font-feature-settings: 'liga'; 23 | -webkit-font-smoothing: antialiased; 24 | } -------------------------------------------------------------------------------- /web_demo/static/dialog.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | AI聊天 7 | 8 | 124 | 125 | 126 | 127 |
128 | 129 | 130 |
131 | 132 |
133 | 134 | 135 |
136 |
137 | 140 |
141 | 点击说话 142 |
143 | 144 | 147 |
148 |
149 | 150 | 151 | -------------------------------------------------------------------------------- /web_demo/static/dialog_RealTime.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | AI聊天 7 | 8 | 128 | 129 | 130 | 131 |
132 | 133 | 134 |
135 | 136 |
137 | 138 | 139 |
140 |
141 | 144 |
145 | 点击开始对话 146 |
147 | 148 | 151 |
152 |
153 | 154 | 155 | 156 | -------------------------------------------------------------------------------- /web_demo/static/fonts/flUhRq6tzZclQEJ-Vdg-IuiaDsNcIhQ8tQ.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/fonts/flUhRq6tzZclQEJ-Vdg-IuiaDsNcIhQ8tQ.woff2 -------------------------------------------------------------------------------- /web_demo/static/js/MiniMateLoader.js: -------------------------------------------------------------------------------- 1 | document.addEventListener('DOMContentLoaded', function () { 2 | init(); 3 | }); 4 | 5 | async function init() 6 | { 7 | const spinner = document.querySelector('#loadingSpinner'); 8 | const screen = document.querySelector('#screen'); 9 | const showUi = () => { 10 | spinner.style.display = 'none'; 11 | screen.style.display = 'block'; 12 | } 13 | const instance = await qtLoad({ 14 | qt: { 15 | onLoaded: () => showUi(), 16 | entryFunction: window.createQtAppInstance, 17 | containerElements: [screen], 18 | } 19 | }); 20 | await newVideoTask(); 21 | document.getElementById('screen2').style.display = 'block'; 22 | } 23 | 24 | 25 | async function qtLoad(config) 26 | { 27 | const throwIfEnvUsedButNotExported = (instance, config) => 28 | { 29 | const environment = config.environment; 30 | if (!environment || Object.keys(environment).length === 0) 31 | return; 32 | const isEnvExported = typeof instance.ENV === 'object'; 33 | if (!isEnvExported) 34 | throw new Error('ENV must be exported if environment variables are passed'); 35 | }; 36 | 37 | const throwIfFsUsedButNotExported = (instance, config) => 38 | { 39 | const environment = config.environment; 40 | if (!environment || Object.keys(environment).length === 0) 41 | return; 42 | const isFsExported = typeof instance.FS === 'object'; 43 | if (!isFsExported) 44 | throw new Error('FS must be exported if preload is used'); 45 | }; 46 | 47 | if (typeof config !== 'object') 48 | throw new Error('config is required, expected an object'); 49 | if (typeof config.qt !== 'object') 50 | throw new Error('config.qt is required, expected an object'); 51 | if (typeof config.qt.entryFunction !== 'function') 52 | config.qt.entryFunction = window.createQtAppInstance; 53 | 54 | config.qt.qtdir ??= 'qt'; 55 | config.qt.preload ??= []; 56 | 57 | config.qtContainerElements = config.qt.containerElements; 58 | delete config.qt.containerElements; 59 | config.qtFontDpi = config.qt.fontDpi; 60 | delete config.qt.fontDpi; 61 | 62 | // Used for rejecting a failed load's promise where emscripten itself does not allow it, 63 | // like in instantiateWasm below. This allows us to throw in case of a load error instead of 64 | // hanging on a promise to entry function, which emscripten unfortunately does. 65 | let circuitBreakerReject; 66 | const circuitBreaker = new Promise((_, reject) => { circuitBreakerReject = reject; }); 67 | 68 | // If module async getter is present, use it so that module reuse is possible. 
69 | if (config.qt.module) { 70 | config.instantiateWasm = async (imports, successCallback) => 71 | { 72 | try { 73 | const module = await config.qt.module; 74 | successCallback( 75 | await WebAssembly.instantiate(module, imports), module); 76 | } catch (e) { 77 | circuitBreakerReject(e); 78 | } 79 | } 80 | } 81 | 82 | const qtPreRun = (instance) => { 83 | // Copy qt.environment to instance.ENV 84 | throwIfEnvUsedButNotExported(instance, config); 85 | for (const [name, value] of Object.entries(config.qt.environment ?? {})) 86 | instance.ENV[name] = value; 87 | 88 | // Copy self.preloadData to MEMFS 89 | const makeDirs = (FS, filePath) => { 90 | const parts = filePath.split("/"); 91 | let path = "/"; 92 | for (let i = 0; i < parts.length - 1; ++i) { 93 | const part = parts[i]; 94 | if (part == "") 95 | continue; 96 | path += part + "/"; 97 | try { 98 | FS.mkdir(path); 99 | } catch (error) { 100 | const EEXIST = 20; 101 | if (error.errno != EEXIST) 102 | throw error; 103 | } 104 | } 105 | } 106 | throwIfFsUsedButNotExported(instance, config); 107 | for ({destination, data} of self.preloadData) { 108 | makeDirs(instance.FS, destination); 109 | instance.FS.writeFile(destination, new Uint8Array(data)); 110 | } 111 | } 112 | 113 | if (!config.preRun) 114 | config.preRun = []; 115 | config.preRun.push(qtPreRun); 116 | 117 | config.onRuntimeInitialized = () => config.qt.onLoaded?.(); 118 | 119 | const originalLocateFile = config.locateFile; 120 | config.locateFile = filename => 121 | { 122 | const originalLocatedFilename = originalLocateFile ? originalLocateFile(filename) : filename; 123 | if (originalLocatedFilename.startsWith('libQt6')) 124 | return `${config.qt.qtdir}/lib/${originalLocatedFilename}`; 125 | return originalLocatedFilename; 126 | } 127 | 128 | const originalOnExit = config.onExit; 129 | config.onExit = code => { 130 | originalOnExit?.(); 131 | config.qt.onExit?.({ 132 | code, 133 | crashed: false 134 | }); 135 | } 136 | 137 | const originalOnAbort = config.onAbort; 138 | config.onAbort = text => 139 | { 140 | originalOnAbort?.(); 141 | 142 | aborted = true; 143 | config.qt.onExit?.({ 144 | text, 145 | crashed: true 146 | }); 147 | }; 148 | 149 | const fetchPreloadFiles = async () => { 150 | const fetchJson = async path => (await fetch(path)).json(); 151 | const fetchArrayBuffer = async path => (await fetch(path)).arrayBuffer(); 152 | const loadFiles = async (paths) => { 153 | const source = paths['source'].replace('$QTDIR', config.qt.qtdir); 154 | return { 155 | destination: paths['destination'], 156 | data: await fetchArrayBuffer(source) 157 | }; 158 | } 159 | const fileList = (await Promise.all(config.qt.preload.map(fetchJson))).flat(); 160 | self.preloadData = (await Promise.all(fileList.map(loadFiles))).flat(); 161 | } 162 | 163 | await fetchPreloadFiles(); 164 | 165 | // Call app/emscripten module entry function. It may either come from the emscripten 166 | // runtime script or be customized as needed. 167 | let instance; 168 | try { 169 | instance = await Promise.race( 170 | [circuitBreaker, config.qt.entryFunction(config)]); 171 | } catch (e) { 172 | config.qt.onExit?.({ 173 | text: e.message, 174 | crashed: true 175 | }); 176 | throw e; 177 | } 178 | 179 | return instance; 180 | } 181 | 182 | // Compatibility API. This API is deprecated, 183 | // and will be removed in a future version of Qt. 184 | function QtLoader(qtConfig) { 185 | 186 | const warning = 'Warning: The QtLoader API is deprecated and will be removed in ' + 187 | 'a future version of Qt. 
Please port to the new qtLoad() API.'; 188 | console.warn(warning); 189 | 190 | let emscriptenConfig = qtConfig.moduleConfig || {} 191 | qtConfig.moduleConfig = undefined; 192 | const showLoader = qtConfig.showLoader; 193 | qtConfig.showLoader = undefined; 194 | const showError = qtConfig.showError; 195 | qtConfig.showError = undefined; 196 | const showExit = qtConfig.showExit; 197 | qtConfig.showExit = undefined; 198 | const showCanvas = qtConfig.showCanvas; 199 | qtConfig.showCanvas = undefined; 200 | if (qtConfig.canvasElements) { 201 | qtConfig.containerElements = qtConfig.canvasElements 202 | qtConfig.canvasElements = undefined; 203 | } else { 204 | qtConfig.containerElements = qtConfig.containerElements; 205 | qtConfig.containerElements = undefined; 206 | } 207 | emscriptenConfig.qt = qtConfig; 208 | 209 | let qtloader = { 210 | exitCode: undefined, 211 | exitText: "", 212 | loadEmscriptenModule: _name => { 213 | try { 214 | qtLoad(emscriptenConfig); 215 | } catch (e) { 216 | showError?.(e.message); 217 | } 218 | } 219 | } 220 | 221 | qtConfig.onLoaded = () => { 222 | showCanvas?.(); 223 | } 224 | 225 | qtConfig.onExit = exit => { 226 | qtloader.exitCode = exit.code 227 | qtloader.exitText = exit.text; 228 | showExit?.(); 229 | } 230 | 231 | showLoader?.("Loading"); 232 | 233 | return qtloader; 234 | }; 235 | -------------------------------------------------------------------------------- /web_demo/static/js/audio_recorder.js: -------------------------------------------------------------------------------- 1 | // 将AudioWorklet处理逻辑转为字符串嵌入主文件 2 | const workletCode = ` 3 | class PCMProcessor extends AudioWorkletProcessor { 4 | constructor() { 5 | super(); 6 | this.port.onmessage = (event) => { 7 | if (event.data === 'stop') { 8 | this.port.postMessage('prepare to stop'); 9 | this.isStopped = true; 10 | if (this.buffer.length > 0 && this.buffer.length > this.targetSampleCount) { 11 | this.port.postMessage(new Int16Array(this.buffer)); 12 | this.port.postMessage({'event':'stopped'}); 13 | this.buffer = []; 14 | } 15 | } 16 | }; 17 | this.buffer = []; 18 | this.targetSampleCount = 1024; 19 | } 20 | 21 | process(inputs) { 22 | const input = inputs[0]; 23 | if (input.length > 0) { 24 | const inputData = input[0]; 25 | // 优化数据转换 26 | const samples = inputData.map(sample => 27 | Math.max(-32768, Math.min(32767, Math.round(sample * 32767))) 28 | ); 29 | this.buffer.push(...samples); 30 | 31 | while (this.buffer.length >= this.targetSampleCount) { 32 | const pcmData = this.buffer.splice(0, this.targetSampleCount); 33 | this.port.postMessage(new Int16Array(pcmData)); 34 | this.port.postMessage({'event':'sending'}); 35 | } 36 | } 37 | return true; 38 | } 39 | } 40 | 41 | registerProcessor('pcm-processor', PCMProcessor); 42 | `; 43 | class PCMAudioRecorder { 44 | constructor() { 45 | this.audioContext = null; 46 | this.stream = null; 47 | this.currentSource = null; 48 | this.audioCallback = null; 49 | } 50 | 51 | async connect(audioCallback) { 52 | this.audioCallback = audioCallback; 53 | if (!this.audioContext) { 54 | this.audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16000 }); 55 | } 56 | console.log('Current sample rate:', this.audioContext.sampleRate, 'Hz'); 57 | 58 | // 生成动态worklet 59 | const blob = new Blob([workletCode], { type: 'application/javascript' }); 60 | const url = URL.createObjectURL(blob); 61 | 62 | try { 63 | await this.audioContext.audioWorklet.addModule(url); 64 | URL.revokeObjectURL(url); // 清除内存 65 | } catch (e) { 66 | 
console.error('Error loading AudioWorklet:', e); 67 | return; 68 | } 69 | 70 | this.stream = await navigator.mediaDevices.getUserMedia({ audio: true }); 71 | this.currentSource = this.audioContext.createMediaStreamSource(this.stream); 72 | 73 | this.processorNode = new AudioWorkletNode(this.audioContext, 'pcm-processor'); 74 | 75 | this.processorNode.port.onmessage = (event) => { 76 | if (event.data instanceof Int16Array) { 77 | this.audioCallback?.(event.data); 78 | } else if (event.data?.event === 'stopped') { 79 | console.log('Recorder stopped.'); 80 | } 81 | }; 82 | 83 | this.currentSource.connect(this.processorNode); 84 | this.processorNode.connect(this.audioContext.destination); 85 | } 86 | 87 | stop() { 88 | if (this.processorNode) { 89 | this.processorNode.port.postMessage('stop'); 90 | this.processorNode.disconnect(); 91 | this.processorNode = null; 92 | } 93 | 94 | this.stream?.getTracks().forEach(track => track.stop()); 95 | this.currentSource?.disconnect(); 96 | 97 | if (this.audioContext) { 98 | this.audioContext.close().then(() => { 99 | this.audioContext = null; 100 | }); 101 | } 102 | } 103 | } 104 | 105 | // 暴露到全局环境 106 | window.PCMAudioRecorder = PCMAudioRecorder; -------------------------------------------------------------------------------- /web_demo/voiceapi/llm.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | # 豆包 3 | base_url = "https://ark.cn-beijing.volces.com/api/v3" 4 | api_key = "" 5 | model_name = "doubao-pro-32k-character-241215" 6 | 7 | # # DeepSeek 8 | # base_url = "https://api.deepseek.com" 9 | # api_key = "" 10 | # model_name = "deepseek-chat" 11 | 12 | assert api_key, "您必须配置自己的LLM API秘钥" 13 | 14 | llm_client = OpenAI( 15 | base_url=base_url, 16 | api_key=api_key, 17 | ) 18 | 19 | 20 | def llm_stream(prompt): 21 | stream = llm_client.chat.completions.create( 22 | # 指定您创建的方舟推理接入点 ID,此处已帮您修改为您的推理接入点 ID 23 | model=model_name, 24 | messages=[ 25 | {"role": "system", "content": "你是人工智能助手"}, 26 | {"role": "user", "content": prompt}, 27 | ], 28 | # 响应内容是否流式返回 29 | stream=True, 30 | ) 31 | return stream 32 | -------------------------------------------------------------------------------- /web_demo/voiceapi/tts.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import os 3 | import time 4 | import sherpa_onnx 5 | import logging 6 | import numpy as np 7 | import asyncio 8 | import time 9 | import soundfile 10 | from scipy.signal import resample 11 | import io 12 | import re 13 | import threading 14 | import base64 15 | logger = logging.getLogger(__file__) 16 | 17 | splitter = re.compile(r'[,,。.!?!?;;、\n]') 18 | _tts_engines = {} 19 | 20 | tts_configs = { 21 | 'sherpa-onnx-vits-zh-ll': { 22 | 'model': 'model.onnx', 23 | 'lexicon': 'lexicon.txt', 24 | 'dict_dir': 'dict', 25 | 'tokens': 'tokens.txt', 26 | 'sample_rate': 16000, 27 | # 'rule_fsts': ['phone.fst', 'date.fst', 'number.fst'], 28 | }, 29 | 'vits-zh-hf-theresa': { 30 | 'model': 'theresa.onnx', 31 | 'lexicon': 'lexicon.txt', 32 | 'dict_dir': 'dict', 33 | 'tokens': 'tokens.txt', 34 | 'sample_rate': 22050, 35 | # 'rule_fsts': ['phone.fst', 'date.fst', 'number.fst'], 36 | }, 37 | 'vits-melo-tts-zh_en': { 38 | 'model': 'model.onnx', 39 | 'lexicon': 'lexicon.txt', 40 | 'dict_dir': 'dict', 41 | 'tokens': 'tokens.txt', 42 | 'sample_rate': 44100, 43 | 'rule_fsts': ['phone.fst', 'date.fst', 'number.fst'], 44 | }, 45 | } 46 | 47 | 48 | def load_tts_model(name: str, model_root: str, 
provider: str, num_threads: int = 1, max_num_sentences: int = 20) -> sherpa_onnx.OfflineTtsConfig: 49 | cfg = tts_configs[name] 50 | fsts = [] 51 | model_dir = os.path.join(model_root, name) 52 | for f in cfg.get('rule_fsts', ''): 53 | fsts.append(os.path.join(model_dir, f)) 54 | tts_rule_fsts = ','.join(fsts) if fsts else '' 55 | 56 | model_config = sherpa_onnx.OfflineTtsModelConfig( 57 | vits=sherpa_onnx.OfflineTtsVitsModelConfig( 58 | model=os.path.join(model_dir, cfg['model']), 59 | lexicon=os.path.join(model_dir, cfg['lexicon']), 60 | dict_dir=os.path.join(model_dir, cfg['dict_dir']), 61 | tokens=os.path.join(model_dir, cfg['tokens']), 62 | ), 63 | provider=provider, 64 | debug=0, 65 | num_threads=num_threads, 66 | ) 67 | tts_config = sherpa_onnx.OfflineTtsConfig( 68 | model=model_config, 69 | rule_fsts=tts_rule_fsts, 70 | max_num_sentences=max_num_sentences) 71 | 72 | if not tts_config.validate(): 73 | raise ValueError("tts: invalid config") 74 | 75 | return tts_config 76 | 77 | 78 | def get_tts_engine(args) -> Tuple[sherpa_onnx.OfflineTts, int]: 79 | sample_rate = tts_configs[args.tts_model]['sample_rate'] 80 | cache_engine = _tts_engines.get(args.tts_model) 81 | if cache_engine: 82 | return cache_engine, sample_rate 83 | st = time.time() 84 | tts_config = load_tts_model( 85 | args.tts_model, args.models_root, args.tts_provider) 86 | 87 | cache_engine = sherpa_onnx.OfflineTts(tts_config) 88 | elapsed = time.time() - st 89 | logger.info(f"tts: loaded {args.tts_model} in {elapsed:.2f}s") 90 | _tts_engines[args.tts_model] = cache_engine 91 | 92 | return cache_engine, sample_rate 93 | 94 | # 1. 全局模型管理类 95 | class TTSEngineManager: 96 | _instance = None 97 | _lock = threading.Lock() 98 | 99 | def __new__(cls): 100 | with cls._lock: 101 | if not cls._instance: 102 | cls._instance = super().__new__(cls) 103 | cls._instance.engine = None 104 | return cls._instance 105 | 106 | @classmethod 107 | def initialize(cls, args): 108 | instance = cls() 109 | if instance.engine is None: # 安全访问属性 110 | instance.engine, instance.original_sample_rate = get_tts_engine(args) 111 | 112 | @classmethod 113 | def get_engine(cls): 114 | instance = cls() # 确保实例存在 115 | return instance.engine,instance.original_sample_rate # 安全访问属性 116 | 117 | 118 | async def get_audio(text, voice_speed=1.0, voice_id=0, target_sample_rate = 16000): 119 | print("run_tts", text, voice_speed, voice_id) 120 | # 获取全局共享的ASR引擎 121 | tts_engine,original_sample_rate = TTSEngineManager.get_engine() 122 | 123 | # 将同步方法放入线程池执行 124 | loop = asyncio.get_event_loop() 125 | audio = await loop.run_in_executor( 126 | None, 127 | lambda: tts_engine.generate(text, voice_id, voice_speed) 128 | ) 129 | # audio = tts_engine.generate(text, voice_id, voice_speed) 130 | samples = audio.samples 131 | if target_sample_rate != original_sample_rate: 132 | num_samples = int( 133 | len(samples) * target_sample_rate / original_sample_rate) 134 | resampled_chunk = resample(samples, num_samples) 135 | audio.samples = resampled_chunk.astype(np.float32) 136 | audio.sample_rate = target_sample_rate 137 | 138 | output = io.BytesIO() 139 | # 使用 soundfile 写入 WAV 格式数据(自动生成头部) 140 | soundfile.write( 141 | output, 142 | audio.samples, # 音频数据(numpy 数组) 143 | samplerate=audio.sample_rate, # 采样率(如 16000) 144 | subtype="PCM_16", # 16-bit PCM 编码 145 | format="WAV" # WAV 容器格式 146 | ) 147 | 148 | # 获取字节数据并 Base64 编码 149 | wav_data = output.getvalue() 150 | return base64.b64encode(wav_data).decode("utf-8") 151 | 152 | # import wave 153 | # import uuid 154 | # with 
wave.open('{}.wav'.format(uuid.uuid4()), 'w') as f: 155 | # f.setnchannels(1) 156 | # f.setsampwidth(2) 157 | # f.setframerate(16000) 158 | # f.writeframes(samples) 159 | # return base64.b64encode(samples).decode('utf-8') 160 | 161 | --------------------------------------------------------------------------------
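`get_audio` above synthesizes speech with the cached sherpa-onnx engine, resamples it to `target_sample_rate` (16 kHz by default) and returns the WAV bytes as a base64 string, ready to be embedded in a JSON response. The sketch below shows one possible way to drive these helpers outside the FastAPI server; the import path, the `./models` layout and the `argparse.Namespace` standing in for the server's `args` are assumptions for illustration, not something the repository prescribes.

```python
# Illustrative driver for the TTS helpers above (assumed import path and model
# layout; ./models/sherpa-onnx-vits-zh-ll must contain the files listed in
# tts_configs for this to run).
import argparse
import asyncio
import base64

from web_demo.voiceapi.tts import TTSEngineManager, get_audio

args = argparse.Namespace(
    tts_model="sherpa-onnx-vits-zh-ll",
    models_root="./models",
    tts_provider="cpu",
)

async def main():
    TTSEngineManager.initialize(args)       # loads and caches the sherpa-onnx engine once
    wav_b64 = await get_audio("你好,欢迎体验数字人。", voice_speed=1.0, voice_id=0)
    with open("tts_demo.wav", "wb") as f:
        f.write(base64.b64decode(wav_b64))  # get_audio returns base64-encoded WAV data

if __name__ == "__main__":
    asyncio.run(main())
```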