├── .gitignore
├── README.md
├── configuration
│   ├── __init__.py
│   └── development_config.py
├── main.py
├── objects
│   ├── __init__.py
│   ├── audio_video_stream.py
│   ├── queue_over.py
│   └── video_full_frames.py
├── preprocess.py
├── requirements.txt
├── result_videos
│   └── README.md
├── routers
│   ├── audio2head_stream
│   │   ├── __init__.py
│   │   └── routes.py
│   ├── helper1.py
│   ├── helper2.py
│   ├── inference_video1
│   │   ├── __init__.py
│   │   └── routes.py
│   └── inference_video2
│       ├── __init__.py
│       └── routes.py
├── scripts
│   ├── 1697513088193.wav
│   ├── client_sample.py
│   └── yangshi.mp4
└── utilities
    ├── async_generator2q.py
    ├── async_q2async_generator.py
    ├── audio_bytes2np_array.py
    ├── base64_estimate.py
    ├── extract_frames_from_video.py
    ├── is_bytes_wav.py
    ├── ndarray2frame.py
    ├── p2jpg.py
    ├── singleton.py
    ├── text2voice_gener
    │   ├── __init__.py
    │   ├── aliyun_text2voice_gener.py
    │   └── azure_tts.py
    └── wav_bytes_2channel.py

/.gitignore:
--------------------------------------------------------------------------------
*.pkl
*.jpg
*.mp4
*.pth
*.pyc
__pycache__
*.h5
*.avi
*.wav
filelists/*.txt
evaluation/test_filelists/lr*.txt
*.mkv
*.gif
*.webm
*.mp3
.idea
test.py
.run
log.log
DINet
wav2lip_288x288
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DigtalAvatarRealtime

Real-time streaming and video-synthesis service for digital humans.

## Installation

Tested on Ubuntu 22.04 with Python 3.10.8.

For the torch family of packages in requirements.txt, it is best to look up the install command for your exact CUDA build on the PyTorch website.

```shell
sudo apt-get install ffmpeg
pip install -r requirements.txt
```

## Configuration

configuration/development_config.py

## Launch

```shell
nohup python main.py &
```

## Client HTTP API documentation

Open in a browser:

{hostname}/docs

For example: http://localhost:6006/docs

## Client scripts and more demos

[./scripts](./scripts)
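
## End-to-end example (Python)

A minimal upload-and-poll sketch against the V2 endpoints. Host and port are assumed to be the defaults from configuration/development_config.py, and the sample media comes from [./scripts](./scripts):

```python
import time

import requests

BASE = "http://127.0.0.1:6006/inferenceVideoV2"

files = {"audio": open("scripts/1697513088193.wav", "rb"),
         "video": open("scripts/yangshi.mp4", "rb")}
vid = requests.post(f"{BASE}/uploadAv", files=files).json()["videoId"]

# Synthesis runs as a background task; /downloadVideo answers 404 until it is done.
while (resp := requests.get(f"{BASE}/downloadVideo", params={"vid": vid})).status_code != 200:
    time.sleep(5)

with open("result.mp4", "wb") as f:
    f.write(resp.content)
```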
--------------------------------------------------------------------------------
/configuration/__init__.py:
--------------------------------------------------------------------------------
from .development_config import Settings
--------------------------------------------------------------------------------
/configuration/development_config.py:
--------------------------------------------------------------------------------
from pydantic_settings import BaseSettings
from utilities.singleton import singleton


@singleton
class Settings(BaseSettings):
    host: str = '0.0.0.0'
    port: int = 6006

    fps: float = 25  # frame rate
    mouth_region_size: int = 256  # determines the size of the resized mouth window

    max_workers: int = 10  # maximum number of worker subprocesses
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os

os.environ['USE_SIMPLE_THREADED_LEVEL3'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'
from loguru import logger
import git

if not os.path.exists("./DINet"):
    logger.info("Download DINet...")
    git.Repo.clone_from("https://github.com/monk-after-90s/DINet.git", "./DINet")
if not os.path.exists("./wav2lip_288x288"):
    logger.info("Download wav2lip_288x288...")
    git.Repo.clone_from("https://github.com/monk-after-90s/wav2lip_288x288.git", "./wav2lip_288x288")

import sys

sys.path.append(os.path.abspath("./DINet"))
sys.path.append(os.path.abspath("./wav2lip_288x288"))

import multiprocessing
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from configuration import Settings
# from routers.inference_video1 import router as inference_video_router1
from routers.inference_video2 import router as inference_video_router2
from preprocess import ensure_pool_executor_closed

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# app.include_router(inference_video_router1, prefix="/inferenceVideoV1")
app.include_router(inference_video_router2, prefix="/inferenceVideoV2")


@app.on_event("startup")
def startup_event():
    from preprocess import preload_videos, load_model
    # preload the videos
    preload_videos()
    # preload the models onto the GPU
    load_model()


@app.on_event("shutdown")
async def shutdown_event():
    ensure_pool_executor_closed()


if __name__ == "__main__":
    logger.info("multiprocessing.set_start_method: fork")
    multiprocessing.set_start_method("fork", True)
    # This project must be launched via this script, not through the uvicorn CLI
    develop_mode = os.getenv("PYTHONUNBUFFERED") == "1"

    settings = Settings()
    uvicorn.run("main:app",
                host=settings.host,
                port=settings.port,
                log_level="debug" if develop_mode else None,
                reload=develop_mode
                )
--------------------------------------------------------------------------------
/objects/__init__.py:
--------------------------------------------------------------------------------
from .video_full_frames import VideoFrames
from .queue_over import QUEUE_OVER
--------------------------------------------------------------------------------
/objects/audio_video_stream.py:
--------------------------------------------------------------------------------
import uuid
from typing import Optional
from apscheduler.schedulers.asyncio import AsyncIOScheduler


class AVStream:
    """Audio/video stream"""

    def __init__(self, digital_man: str, scheduler: AsyncIOScheduler):
        self.digital_man: str = digital_man
        self._scheduler: Optional[AsyncIOScheduler] = scheduler
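        # routers/audio2head_stream/routes.py reads `stream_id` and calls
        # `update_del_time_from`, neither of which survives in this snapshot;
        # below is a minimal assumed sketch of both.
        self.stream_id: str = str(uuid.uuid4())

    def update_del_time_from(self, streams: "dict[str, AVStream]", ttl_sec: float = 300):
        """(Assumed semantics) re-arm a one-shot scheduler job that drops this
        stream from `streams` after `ttl_sec` without activity."""
        from datetime import datetime, timedelta
        self._scheduler.add_job(streams.pop, "date",
                                run_date=datetime.now() + timedelta(seconds=ttl_sec),
                                args=[self.stream_id, None],
                                id=self.stream_id, replace_existing=True)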
--------------------------------------------------------------------------------
/objects/queue_over.py:
--------------------------------------------------------------------------------
import uuid

# Sentinel marking the end of a queue-backed stream (see utilities/async_generator2q.py);
# unlike a bare object(), a uuid4 still compares equal after being copied.
QUEUE_OVER = uuid.uuid4()
--------------------------------------------------------------------------------
/objects/video_full_frames.py:
--------------------------------------------------------------------------------
import os
import uuid
from typing import List
import torch
from numpy import ndarray
import numpy as np
from torch import Tensor
from configuration import Settings
import cv2
import magic
import subprocess
import platform
import random
from DINet.utils.data_processing import compute_crop_radius  # same module path as in routers/helper1.py


class VideoFrames:
    """Data structure holding a video's frame data"""

    def __init__(self, video_path: str):
        self.fps = 0
        self.video_path = video_path
        self.video_name = os.path.splitext(os.path.basename(video_path))[0]
        self.full_frames: List[VideoFullFrame] = []
        self.ref_img_tensor: Tensor | None = None

    def gen_frames(self):
        """Turn the video file into frame data and facial landmarks"""
        # file name, i.e. the digital human's name
        file_name_without_ext = os.path.splitext(os.path.basename(self.video_path))[0]

        settings = Settings()
        video_stream = cv2.VideoCapture(self.video_path)
        self.fps = video_stream.get(cv2.CAP_PROP_FPS)
        if self.fps != float(settings.fps) or 'video/mp4' not in magic.Magic(mime=True).from_file(
                self.video_path):  # adjust the frame rate and container type
            video_stream.release()
            tmp_file = os.path.join("/dev/shm", file_name_without_ext + f"_{settings.fps}fps_{uuid.uuid4()}.mp4")
            # transcode to the target frame rate and type
            command = f'ffmpeg -i {self.video_path} -r {settings.fps} {tmp_file} -y'
            subprocess.call(command, shell=platform.system() != 'Windows')
            # read it into memory (on Linux the open handle outlives the unlink)
            try:
                video_stream = cv2.VideoCapture(tmp_file)
            finally:
                os.remove(tmp_file)
            self.fps = video_stream.get(cv2.CAP_PROP_FPS)
            assert self.fps == settings.fps
        # From here on, the container type and frame rate are guaranteed
        # Collect the per-frame ndarray representation
        full_frames: List[ndarray] = []
        while 1:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break
            full_frames.append(frame)
        # 68 facial landmarks
        full_frames: ndarray = np.stack(full_frames)
        # batched facial-landmark detection
        from preprocess import get_fa
        fa = get_fa()
        batch_landmarks = fa.get_landmarks_from_batch(torch.Tensor(full_frames.transpose(0, 3, 1, 2)))
        batch_landmarks = [landmarks[:68, :] for landmarks in batch_landmarks]
        assert all(landmarks.shape == (68, 2) for landmarks in batch_landmarks)
        # wrap into VideoFullFrame objects
        self.full_frames = [VideoFullFrame(full_frame, landmarks) for full_frame, landmarks in
                            zip(full_frames, batch_landmarks)]
        self.pick5ref_images()

    def pick5ref_images(self):
        '''selecting five reference images'''
        ref_img_list = []
        settings = Settings()
        resize_w = int(settings.mouth_region_size + settings.mouth_region_size // 4)
        resize_h = int((settings.mouth_region_size // 2) * 3 + settings.mouth_region_size // 8)
        ref_index_list = random.sample(range(5, len(self.full_frames)), 5)
        for ref_index in ref_index_list:
            crop_flag, crop_radius = compute_crop_radius(
                self.full_frames[0].full_frame.shape[:2][::-1],
                np.stack(
                    [video_full_frame.landmarks for video_full_frame in self.full_frames[ref_index - 5:ref_index]]),
            )
            if not crop_flag:
                raise ValueError('Our method can not handle videos with large change of facial size!!')
            crop_radius_1_4 = crop_radius // 4
            ref_img = self.full_frames[ref_index - 3].full_frame[:, :, ::-1]
            ref_landmark = self.full_frames[ref_index - 3].landmarks
            ref_img_crop = ref_img[
                           ref_landmark[29, 1] - crop_radius:
                           ref_landmark[29, 1] + crop_radius * 2 + crop_radius_1_4,
                           ref_landmark[33, 0] - crop_radius - crop_radius_1_4:
                           ref_landmark[33, 0] + crop_radius + crop_radius_1_4,
                           :]  # crop the nose and mouth region
            ref_img_crop = cv2.resize(ref_img_crop, (resize_w, resize_h))
            ref_img_crop = ref_img_crop / 255.0  # normalize colors to [0, 1]
            ref_img_list.append(ref_img_crop)
        ref_video_frame = np.concatenate(ref_img_list, 2)
        self.ref_img_tensor = torch.from_numpy(ref_video_frame).permute(2, 0, 1).unsqueeze(0).float().cuda()  # preload


class VideoFullFrame:
    """Data structure holding a single video frame"""

    def __init__(self, full_frame: ndarray, landmarks: ndarray):
        # the original frame
        self._full_frame: ndarray = full_frame
        # the 68 facial landmarks
        self._landmarks: ndarray = landmarks.astype(int)

    @property
    def full_frame(self):
        return self._full_frame.copy()

    @property
    def landmarks(self):
        return self._landmarks
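

if __name__ == "__main__":
    # Usage sketch (assumes a face video at ./faces/demo.mp4, the folder that
    # preprocess.preload_videos() scans, plus a CUDA device and model weights).
    vf = VideoFrames("./faces/demo.mp4")
    vf.gen_frames()
    print(len(vf.full_frames), "frames; ref tensor shape:", tuple(vf.ref_img_tensor.shape))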
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
import os
import glob
from typing import Dict
import face_alignment
import torch
from DINet.models.DINet import DINet
from objects import VideoFrames
from collections import OrderedDict
from DINet.utils.deep_speech import DeepSpeech
from loguru import logger
from wav2lip_288x288.inference import load_model as _load_wav2lip_model
from concurrent.futures import ProcessPoolExecutor
from configuration.development_config import Settings

video_full_frames: Dict[str, VideoFrames] = {}
# face detection
_fa = None
# DINet inference model
_DINet_model = None
# deepspeech model
_DSModel = None
# Wav2Lip inference model
_Wav2Lip_model = None
# process pool executor
_pool_executor: ProcessPoolExecutor = None


def get_DINet_model():
    """Get the DINet inference model"""
    global _DINet_model
    if _DINet_model is None:
        logger.info("load DINet model")
        _DINet_model = DINet(3, 15, 29).cuda()
        pretrained_clip_DINet_path = "./DINet/asserts/clip_training_DINet_256mouth.pth"
        if not os.path.exists(pretrained_clip_DINet_path):
            raise FileNotFoundError(
                'Wrong path of pretrained model weight: {}. Refer to "https://github.com/monk-after-90s/DINet" to download it.'.format(
                    pretrained_clip_DINet_path))
        state_dict = torch.load(pretrained_clip_DINet_path)['state_dict']['net_g']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove the "module." prefix
            new_state_dict[name] = v
        _DINet_model.load_state_dict(new_state_dict)
        _DINet_model.eval()
    return _DINet_model


def get_Wav2Lip_model():
    """Get the Wav2Lip 288×288 inference model"""
    global _Wav2Lip_model
    if _Wav2Lip_model is None:
        logger.info("load Wav2Lip 288×288 model...")
        checkpoint_path = "./wav2lip_288x288/checkpoints/checkpoint_prod.pth"
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"File {checkpoint_path} doesn't exist. Refer to https://github.com/monk-after-90s/wav2lip_288x288.git.")
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        _Wav2Lip_model = _load_wav2lip_model(checkpoint_path, device)
    return _Wav2Lip_model


def get_DSModel():
    """Get the deepspeech model"""
    global _DSModel
    if _DSModel is None:
        logger.info("load deepspeech model")
        deepspeech_model_path = "./DINet/asserts/output_graph.pb"
        if not os.path.exists(deepspeech_model_path):
            raise FileNotFoundError(
                'Please download the pretrained deepspeech model. Refer to "https://github.com/monk-after-90s/DINet" to download it.')
        _DSModel = DeepSpeech(deepspeech_model_path)
    return _DSModel


def get_fa():
    """Get the FaceAlignment instance (TODO: to be deprecated)"""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    global _fa
    if _fa is None:
        logger.info("load face_alignment model")
        _fa = face_alignment.FaceAlignment(
            face_alignment.LandmarksType.TWO_HALF_D, device=device, face_detector='blazeface')
    return _fa


def preload_videos():
    """
    Preload the video files into in-memory data for inference
    """
    # video file extensions to look for
    video_extensions = ["*.mp4", "*.avi", "*.mkv", "*.flv", "*.mov", "*.wmv"]
    # the folder to search
    folder = os.path.join(os.getcwd(), 'faces')
    # paths of all video files found
    videos = []
    for video_extension in video_extensions:
        # os.path.join joins the path segments;
        # glob.glob returns the list of all matching file paths
        videos.extend(glob.glob(os.path.join(folder, video_extension)))

    # iterate over the found video paths and convert each into video_frames
    for video in videos:
        vff = VideoFrames(video)
        vff.gen_frames()
        video_full_frames[os.path.splitext(os.path.basename(video))[0]] = vff


def load_model():
    """Load the models onto the GPU"""
    # pretrained DINet model
    get_DINet_model()
    # deepspeech model
    get_DSModel()
    # face-alignment
    get_fa()
    # pretrained Wav2Lip 288×288 model
    get_Wav2Lip_model()


def get_pool_executor():
    global _pool_executor
    if _pool_executor is None:
        logger.info("instantiate ProcessPoolExecutor")
        _pool_executor = ProcessPoolExecutor(max_workers=Settings().max_workers,
                                             mp_context=torch.multiprocessing.get_context("spawn"))
    return _pool_executor


def ensure_pool_executor_closed():
    """Shut the process pool executor down"""
    if _pool_executor is not None:
        _pool_executor.shutdown()
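

if __name__ == "__main__":
    # Standalone warm-up sketch (assumes the weights are in place and ./faces
    # contains at least one video); useful for verifying a fresh deployment.
    load_model()
    preload_videos()
    logger.info(f"preloaded digital humans: {list(video_full_frames)}")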
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
opencv_python==4.6.0.66
python_speech_features==0.6
resampy==0.2.2
tensorflow==1.15.2
# For the torch family of packages, look up the install command for your exact CUDA build on the PyTorch website
torch==1.13.1
torchvision==0.14.1
protobuf==3.20.*
loguru==0.7.0
uvicorn==0.22.0
fastapi==0.100.1
pydantic-settings==2.0.3
GitPython==3.1.37
python-magic==0.4.27
face-alignment==1.4.0
apscheduler==3.10.1
aiohttp==3.8.5
python-multipart==0.0.6
moviepy==1.0.3
aiofiles==23.2.1
soundfile==0.12.1
librosa==0.10.1
numba==0.56.4
--------------------------------------------------------------------------------
/result_videos/README.md:
--------------------------------------------------------------------------------
directory of result videos
--------------------------------------------------------------------------------
/routers/audio2head_stream/__init__.py:
--------------------------------------------------------------------------------
from .routes import router
--------------------------------------------------------------------------------
/routers/audio2head_stream/routes.py:
--------------------------------------------------------------------------------
from fastapi import APIRouter
from typing import Dict
from objects.audio_video_stream import AVStream
from routers.helper1 import scheduler
from loguru import logger

router = APIRouter()
# dict of stream queues
streams: Dict[str, AVStream] = {}


@router.get("/establish_stream")
async def establish_stream(digital_man: str):
    """
    Establish a persistent stream

    :return:
    """
    av_stream = AVStream(digital_man, scheduler)
    streams[av_stream.stream_id] = av_stream
    streams[av_stream.stream_id].update_del_time_from(streams)
    logger.info(f"len(streams)={len(streams)}")
    logger.debug(f"len(scheduler.get_jobs())={len(scheduler.get_jobs())}")
    return av_stream.stream_id
--------------------------------------------------------------------------------
/routers/helper1.py:
--------------------------------------------------------------------------------
import asyncio
import os  # previously reached this module only via the moviepy star import below
import traceback
from typing import List
from configuration.development_config import Settings
from DINet.utils.data_processing import compute_crop_radius
import shutil
import numpy as np
from numpy import ndarray
import cv2
import random
from loguru import logger
from tqdm import tqdm
import tempfile
from moviepy.editor import *  # provides ImageSequenceClip and AudioFileClip
import io
import torch
from preprocess import get_pool_executor
from preprocess import get_DSModel, get_fa, get_DINet_model
from utilities.extract_frames_from_video import extract_frames_from_video_bytes
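
# routers/audio2head_stream/routes.py imports `scheduler` from this module, but no
# definition survives in this snapshot; a minimal assumed stand-in follows.
from apscheduler.schedulers.asyncio import AsyncIOScheduler

scheduler = AsyncIOScheduler()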


def _get_frames_landmarks_pad(frames_ndarray: ndarray, video_landmark_data: ndarray, res_frame_length: int):
    """
    align frames with the driving audio,
    padding a couple of entries at the head and tail
    """
    video_frames_data_cycle = np.concatenate([frames_ndarray, np.flip(frames_ndarray, 0)], 0)
    video_landmark_data_cycle = np.concatenate([video_landmark_data, np.flip(video_landmark_data, 0)], 0)
    video_frames_data_cycle_length = len(video_frames_data_cycle)
    if video_frames_data_cycle_length >= res_frame_length:
        res_video_frames_data = video_frames_data_cycle[:res_frame_length, :, :, :]
        res_video_landmark_data = video_landmark_data_cycle[:res_frame_length, :, :]
    else:
        divisor = res_frame_length // video_frames_data_cycle_length
        remainder = res_frame_length % video_frames_data_cycle_length
        res_video_frames_data = np.concatenate(
            [video_frames_data_cycle] * divisor + [video_frames_data_cycle[:remainder]], 0)
        res_video_landmark_data = np.concatenate(
            [video_landmark_data_cycle] * divisor + [video_landmark_data_cycle[:remainder, :, :]], 0)
    res_video_frames_data_pad: ndarray = np.pad(res_video_frames_data, ((2, 2), (0, 0), (0, 0), (0, 0)), mode='edge')
    res_video_landmark_data_pad = np.pad(res_video_landmark_data, ((2, 2), (0, 0), (0, 0)), mode='edge')
    return res_video_frames_data_pad, res_video_landmark_data_pad


def _pick5frames(res_video_frames_data_pad: ndarray,
                 res_video_landmark_data_pad: ndarray,
                 resize_w: int,
                 resize_h: int):
    ref_index_list = random.sample(range(5, res_video_frames_data_pad.shape[0] - 2), 5)
    ref_img_list = []
    video_size = res_video_frames_data_pad.shape[1:3][::-1]
    for ref_index in ref_index_list:
        crop_flag, crop_radius = compute_crop_radius(video_size,
                                                     res_video_landmark_data_pad[ref_index - 5:ref_index, :, :])
        if not crop_flag:
            raise ValueError('our method can not handle videos with large change of facial size!!')
        crop_radius_1_4 = crop_radius // 4
        ref_img = res_video_frames_data_pad[ref_index - 3, :, :, ::-1]
        ref_landmark = res_video_landmark_data_pad[ref_index - 3, :, :]
        ref_img_crop = ref_img[
                       ref_landmark[29, 1] - crop_radius:
                       ref_landmark[29, 1] + crop_radius * 2 + crop_radius_1_4,
                       ref_landmark[33, 0] - crop_radius - crop_radius_1_4:
                       ref_landmark[33, 0] + crop_radius + crop_radius_1_4,
                       :]  # crop the nose and mouth region
        ref_img_crop = cv2.resize(ref_img_crop, (resize_w, resize_h))
        ref_img_crop = ref_img_crop / 255.0  # normalize colors to [0, 1]
        ref_img_list.append(ref_img_crop)
    ref_video_frame = np.concatenate(ref_img_list, 2)
    return ref_video_frame, video_size


def inf2frames(ref_video_frame,
               video_size,
               pad_length,
               res_video_landmark_data_pad,
               res_video_frames_data_pad,
               resize_w,
               resize_h,
               mouth_region_size,
               ds_feature_padding,
               model) -> List[ndarray]:
    ref_img_tensor = torch.from_numpy(ref_video_frame).permute(2, 0, 1).unsqueeze(0).float().cuda()
    frames = []
    for clip_end_index in tqdm(range(5, pad_length, 1)):
        crop_flag, crop_radius = compute_crop_radius(
            video_size,
            res_video_landmark_data_pad[clip_end_index - 5:clip_end_index, :, :],
            random_scale=1.05)  # sliding window over packets of 5 frames
        if not crop_flag:
            raise ValueError('our method can not handle videos with large change of facial size!!')
        crop_radius_1_4 = crop_radius // 4
        frame_data = res_video_frames_data_pad[clip_end_index - 3, :, :, ::-1]  # the middle frame of the 5-frame packet
        frame_landmark = res_video_landmark_data_pad[clip_end_index - 3, :, :]
        crop_frame_data = frame_data[
                          frame_landmark[29, 1] - crop_radius:frame_landmark[29, 1] + crop_radius * 2 + crop_radius_1_4,
                          frame_landmark[33, 0] - crop_radius - crop_radius_1_4:frame_landmark[
                                                                                    33, 0] + crop_radius + crop_radius_1_4,
                          :]  # crop the face
        crop_frame_h, crop_frame_w = crop_frame_data.shape[0], crop_frame_data.shape[1]
        crop_frame_data = cv2.resize(crop_frame_data, (resize_w, resize_h))  # [32:224, 32:224, :]
        crop_frame_data = crop_frame_data / 255.0
        # TODO: average-brightness correction
        crop_frame_data[mouth_region_size // 2:mouth_region_size // 2 + mouth_region_size,
                        mouth_region_size // 8:mouth_region_size // 8 + mouth_region_size, :] = 0
        crop_frame_tensor = torch.from_numpy(crop_frame_data).float().cuda().permute(2, 0, 1).unsqueeze(0)
        deepspeech_tensor = torch.from_numpy(
            ds_feature_padding[clip_end_index - 5:clip_end_index, :]).permute(1, 0).unsqueeze(0).float().cuda()
        with torch.no_grad():
            pre_frame = model(crop_frame_tensor, ref_img_tensor, deepspeech_tensor)
            pre_frame = pre_frame.squeeze(0).permute(1, 2, 0).detach().cpu().numpy() * 255  # the face
        pre_frame_resize = cv2.resize(pre_frame, (crop_frame_w, crop_frame_h))  # restore the original face size
        frame_data[
            frame_landmark[29, 1] - crop_radius:
            frame_landmark[29, 1] + crop_radius * 2,
            frame_landmark[33, 0] - crop_radius - crop_radius_1_4:
            frame_landmark[33, 0] + crop_radius + crop_radius_1_4,
            :] = pre_frame_resize[:crop_radius * 3, :, :]  # write the inferred face back into the original frame
        frames.append(frame_data)
    return frames


async def inf_video(filename, audio_bytes, video_bytes, inf_video_tasks, vid):
    ds_feature_fut = asyncio.get_running_loop().run_in_executor(
        None, lambda audio_bytes:
        get_DSModel().compute_audio_feature(io.BytesIO(audio_bytes)),
        audio_bytes)  # turn the audio into the features used for inference
    frames_ndarray = await asyncio.get_running_loop().run_in_executor(
        get_pool_executor(), extract_frames_from_video_bytes, video_bytes)  # decode the video into an ndarray of frames
    batch_landmarks: ndarray = await asyncio.get_running_loop().run_in_executor(
        None,
        lambda frames_ndarray:
        get_fa().get_landmarks_from_batch(torch.Tensor(frames_ndarray.transpose(0, 3, 1, 2))),
        frames_ndarray)

    res_video_frames_data_pad, pred_frames = await asyncio.get_running_loop().run_in_executor(get_pool_executor(),
                                                                                              inf_video_from_ndarray2frames,
                                                                                              frames_ndarray,
                                                                                              get_DINet_model(),
                                                                                              await ds_feature_fut,
                                                                                              batch_landmarks)
    assert res_video_frames_data_pad.shape[0] - 5 == len(pred_frames)
    # landmarks of the inferred faces, used for the mask cut-out
    pred_batch_landmarks = await asyncio.get_running_loop().run_in_executor(
        None,
        lambda frames_ndarray:
        get_fa().get_landmarks_from_batch(torch.Tensor(frames_ndarray.transpose(0, 3, 1, 2))),
        np.stack(pred_frames))

    res_video_dir = f"./result_videos/{vid}"
    if os.path.exists(res_video_dir):
        try:
            shutil.rmtree(res_video_dir)
        except:
            ...
    os.makedirs(res_video_dir, exist_ok=True)
    res_video_path = os.path.join(res_video_dir, filename + '_facial_dubbing_add_audio.mp4')
    if os.path.exists(res_video_path):
        os.remove(res_video_path)
    await asyncio.get_running_loop().run_in_executor(get_pool_executor(),
                                                     face_join2video_file,
                                                     pred_frames,
                                                     pred_batch_landmarks,
                                                     res_video_frames_data_pad[2:-3],
                                                     audio_bytes,
                                                     res_video_path)

    asyncio.create_task(delay_clear(300, vid, inf_video_tasks))


def face_join2video_file(pred_frames, pred_batch_landmarks, org_frames_ndarr, audio_bytes, res_video_path):
    joined_frames = []
    for p, points_68, frame_ndarray in zip(pred_frames, pred_batch_landmarks, org_frames_ndarr):
        if points_68.shape[0] >= 17:
            face_points = points_68[:17]
            if points_68.shape[0] >= 25:
                face_points = np.append(face_points, [points_68[24], points_68[19]], axis=0)
            face_points = np.stack(face_points).astype(np.int32)
            # 1. create a blank rectangular mask
            mask = np.zeros(p.shape[:2], dtype=np.uint8)
            # 2. draw the face mask with fillPoly
            cv2.fillPoly(mask, [face_points], (255, 255, 255))
            # inverted mask
            reverse_mask = cv2.bitwise_not(mask)
            # 3. extract the face using the mask
            face_image = cv2.bitwise_and(p, p, mask=mask)
            # extract the surroundings of the face
            face_surrounding = cv2.bitwise_and(frame_ndarray, frame_ndarray, mask=reverse_mask)
            # paste the inferred face back onto the original frame
            joined_frame = cv2.add(face_image, face_surrounding[:, :, ::-1])
            joined_frames.append(joined_frame)

    # add the audio track
    # create a VideoClip object
    video_clip = ImageSequenceClip(joined_frames, fps=25)
    # save the audio bytes to a temporary file
    temp_audio_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
    temp_audio_file.write(audio_bytes)
    temp_audio_file.close()
    # create an AudioFileClip object
    audio_clip = AudioFileClip(temp_audio_file.name)
    # attach the audio to the video
    final_clip = video_clip.set_audio(audio_clip)
    # write the final video file
    final_clip.write_videofile(res_video_path, codec="libx264", audio_codec="aac")
    # remove the temporary audio file
    os.unlink(temp_audio_file.name)


async def delay_clear(delay_sec: float, vid, inf_video_tasks):
    # clean the video up after a delay
    sleep_task = asyncio.create_task(asyncio.sleep(delay_sec))
    try:
        await sleep_task
    finally:
        if vid in inf_video_tasks.keys(): inf_video_tasks.pop(vid)
        res_video_dir = f"./result_videos/{vid}"
        if os.path.exists(res_video_dir):
            try:
                shutil.rmtree(res_video_dir)
            except:
                ...
        logger.info(f"Video id:{vid} cleared.")


def inf_video_from_ndarray2frames(frames_ndarray,
                                  DINet_model,
                                  ds_feature,
                                  batch_landmarks):
    try:
        res_frame_length = ds_feature.shape[0]
        ds_feature_padding = np.pad(ds_feature, ((2, 2), (0, 0)), mode='edge')
        # 68 facial landmarks
        batch_landmarks = [landmarks[:68, :] for landmarks in batch_landmarks]
        video_landmark_data: np.ndarray = np.stack(batch_landmarks).astype(int)
        ######## align frames with the driving audio: cut from an endless loop of the video to match the audio ########
        res_video_frames_data_pad, res_video_landmark_data_pad = _get_frames_landmarks_pad(frames_ndarray,
                                                                                           video_landmark_data,
                                                                                           res_frame_length)
        assert ds_feature_padding.shape[0] == res_video_frames_data_pad.shape[0] == res_video_landmark_data_pad.shape[0]
        pad_length = ds_feature_padding.shape[0]
        ######## randomly select 5 reference images ########
        mouth_region_size = Settings().mouth_region_size
        resize_w = int(mouth_region_size + mouth_region_size // 4)
        resize_h = int((mouth_region_size // 2) * 3 + mouth_region_size // 8)
        ref_video_frame, video_size = _pick5frames(res_video_frames_data_pad, res_video_landmark_data_pad, resize_w,
                                                   resize_h)
        ######## inference frame by frame ########
        return res_video_frames_data_pad.copy(), inf2frames(ref_video_frame, video_size, pad_length,
                                                            res_video_landmark_data_pad,
                                                            res_video_frames_data_pad, resize_w, resize_h,
                                                            mouth_region_size,
                                                            ds_feature_padding,
                                                            DINet_model)
    except:
        traceback.print_exc()
        raise  # re-raise so the awaiting caller fails loudly instead of unpacking None
--------------------------------------------------------------------------------
/routers/helper2.py:
--------------------------------------------------------------------------------
import asyncio
import os.path
import shutil
import traceback
import torch
from loguru import logger
from preprocess import get_Wav2Lip_model, get_fa, get_pool_executor
import aiofiles
import uuid
from functools import partial

from wav2lip_288x288.inference import main


async def save_video_bytes_2shm_file(video_bytes: bytes):
    """
    Save video bytes as an in-memory (/dev/shm) file
    """
    face = f"/dev/shm/{uuid.uuid4()}.mp4"
    async with aiofiles.open(face, mode='wb') as f:
        await f.write(video_bytes)
    return face


async def save_audio_bytes_2shm_file(audio_bytes: bytes):
    """
    Save audio bytes as an in-memory (/dev/shm) file
    """
    audio = f"/dev/shm/{uuid.uuid4()}.wav"
    async with aiofiles.open(audio, mode='wb') as f:
        await f.write(audio_bytes)
    return audio


async def inf_video(filename, audio_bytes, video_bytes, inf_video_tasks, vid):
    """
    filename: file name (without extension)
    """
    face, audio = '', ''
    try:
        # write both payloads to in-memory files
        face, audio = await asyncio.gather(save_video_bytes_2shm_file(video_bytes),
                                           save_audio_bytes_2shm_file(audio_bytes))
        # directory that will hold the synthesized result video
        res_video_dir = f"./result_videos/{vid}"
        if os.path.exists(res_video_dir):
            try:
                shutil.rmtree(res_video_dir)
            except:
                ...
        os.makedirs(res_video_dir, exist_ok=True)

        await asyncio.get_running_loop().run_in_executor(
            get_pool_executor(),
            partial(main, face, audio, get_Wav2Lip_model(),
                    device='cuda' if torch.cuda.is_available() else 'cpu',
                    outfile=os.path.join(res_video_dir, f"{filename}_result.mp4")))
    except:
        traceback.print_exc()
    finally:
        asyncio.create_task(delay_clear(300, vid, inf_video_tasks))
        if os.path.exists(face):
            os.remove(face)
        if os.path.exists(audio):
            os.remove(audio)


async def delay_clear(delay_sec: float, vid, inf_video_tasks):
    # clean the video up after a delay
    sleep_task = asyncio.create_task(asyncio.sleep(delay_sec))
    try:
        await sleep_task
    finally:
        if vid in inf_video_tasks.keys(): inf_video_tasks.pop(vid)
        res_video_dir = f"./result_videos/{vid}"
        if os.path.exists(res_video_dir):
            try:
                shutil.rmtree(res_video_dir)
            except:
                ...
        logger.info(f"Video id:{vid} cleared.")
--------------------------------------------------------------------------------
/routers/inference_video1/__init__.py:
--------------------------------------------------------------------------------
"""Inference and synthesis of digital-human videos"""
from .routes import router
--------------------------------------------------------------------------------
/routers/inference_video1/routes.py:
--------------------------------------------------------------------------------
from fastapi import APIRouter
from fastapi import UploadFile, File
import asyncio
from asyncio import Task
import uuid
from typing import Dict
import os
from starlette.responses import FileResponse
import fnmatch
from fastapi import HTTPException
from routers.helper1 import inf_video

router = APIRouter()

inf_video_tasks: Dict[str, Task] = {}


# TODO: optimize along the lines of inference_video2

@router.post("/uploadAv", summary="Receive audio and video", tags=["Audio/Video"])
async def uploadAv(audio: UploadFile = File(...), video: UploadFile = File(...)):
    """
    Run video inference
    """
    audio_bytes, video_bytes = await asyncio.gather(audio.read(), video.read())  # read the audio and video payloads
    filename, extension = os.path.splitext(video.filename)
    vid = str(uuid.uuid4())  # video id
    inf_video_tasks[vid] = asyncio.create_task(inf_video(filename, audio_bytes, video_bytes, inf_video_tasks, vid))
    return {"videoId": vid}


@router.get("/downloadVideo")
def downloadVideo(vid: str):
    if vid not in inf_video_tasks.keys():
        raise HTTPException(status_code=404, detail="Unknown video ID")
    if not inf_video_tasks[vid].done():
        raise HTTPException(status_code=404, detail="The video file does not exist yet: it is still being synthesized")
    # surface any error raised by the task
    inf_video_tasks[vid].result()

    target_dir = f"result_videos/{vid}"
    pattern = '*facial_dubbing_add_audio.mp4'
    matches = []
    for root, dirnames, filenames in os.walk(target_dir):
        for filename in fnmatch.filter(filenames, pattern):
            matches.append(os.path.join(root, filename))
    if not matches:
        raise HTTPException(status_code=404, detail="The video file no longer exists: it has expired")

    return FileResponse(matches[0], filename=os.path.basename(matches[0]))
--------------------------------------------------------------------------------
/routers/inference_video2/__init__.py:
--------------------------------------------------------------------------------
"""Inference and synthesis of digital-human videos"""
from .routes import router
--------------------------------------------------------------------------------
/routers/inference_video2/routes.py:
--------------------------------------------------------------------------------
import fnmatch
from starlette.responses import FileResponse
from fastapi import APIRouter
from fastapi import UploadFile, File
import asyncio
from asyncio import Task
import uuid
from typing import Dict
import os
from routers.helper2 import inf_video
from utilities.is_bytes_wav import is_wav
from fastapi import HTTPException, Query
import aiofiles

router = APIRouter()

inf_video_tasks: Dict[str, Task] = {}


@router.post("/uploadAv", summary=
"""
Receive audio and video for digital-human video synthesis. The returned vid is used to fetch the result.
Python Demo:

import requests


def uploadAv():
    '''Upload audio and video'''
    url = "http://127.0.0.1:6006/inferenceVideoV2/uploadAv"

    audio = "1697513088193.wav"
    video = "yangshi.mp4"

    files = {'video': open(video, 'rb'), 'audio': open(audio, 'rb')}
    result = requests.post(url=url, files=files)
    print(result.text)


if __name__ == '__main__':
    uploadAv()


""", tags=["Audio/Video"])
async def uploadAv(audio: UploadFile = File(..., title="Audio", description="The audio the digital human will speak"),
                   video: UploadFile = File(..., title="Video", description="The base video for synthesizing the digital human")):
    """
    Run video inference
    """
    audio_bytes, video_bytes = await asyncio.gather(audio.read(), video.read())  # read the audio and video payloads
    if not is_wav(audio_bytes):
        # convert to wav format
        tmp_file_name = str(uuid.uuid4())
        tmp_audio_path = os.path.join("/dev/shm", tmp_file_name)
        async with aiofiles.open(tmp_audio_path, mode='wb') as f:
            await f.write(audio_bytes)

        cmd = f"ffmpeg -i {tmp_audio_path} {tmp_audio_path}.wav"
        proc = await asyncio.create_subprocess_shell(cmd,
                                                     stdout=asyncio.subprocess.PIPE,
                                                     stderr=asyncio.subprocess.PIPE)
        stdout, stderr = await proc.communicate()
        # ffmpeg writes its log to stderr even on success, so test the exit code
        if proc.returncode != 0:
            raise HTTPException(status_code=500, detail=f"cmd '{cmd}' failed.\nstdout={stdout},\nstderr={stderr}")
        # all good
        os.remove(tmp_audio_path)
        # read back in binary mode: the rest of the pipeline expects bytes
        async with aiofiles.open(f"{tmp_audio_path}.wav", mode='rb') as f:
            audio_bytes = await f.read()
        os.remove(f"{tmp_audio_path}.wav")
    filename, extension = os.path.splitext(video.filename)
    if extension != ".mp4":
        raise HTTPException(status_code=415, detail="The video format must be mp4")
    vid = str(uuid.uuid4())  # video id
    inf_video_tasks[vid] = asyncio.create_task(inf_video(filename, audio_bytes, video_bytes, inf_video_tasks, vid))
    return {"videoId": vid}


@router.get("/downloadVideo",
            summary="Download the synthesized video in a browser: paste a URL such as 'http://127.0.0.1:6006/inferenceVideoV2/downloadVideo?vid=b299784a-854c-4dee-92be-1e9e7755be52' into the address bar and press Enter, which normally triggers the download. Downloading from code works the same way.")
def downloadVideo(vid: str = Query(title="video id", description="The ID for video downloading")):
    """
    Download the synthesized video matching vid
    """
    if vid not in inf_video_tasks.keys():
        raise HTTPException(status_code=404, detail="Unknown video ID")
    if not inf_video_tasks[vid].done():
        raise HTTPException(status_code=404, detail="The video file does not exist yet: it is still being synthesized")
    # surface any error raised by the task
    inf_video_tasks[vid].result()

    target_dir = f"result_videos/{vid}"
    pattern = '*.mp4'
    matches = []
    for root, dirnames, filenames in os.walk(target_dir):
        for filename in fnmatch.filter(filenames, pattern):
            matches.append(os.path.join(root, filename))
    if not matches:
        raise HTTPException(status_code=404, detail="The video file no longer exists: it has expired")

    return FileResponse(matches[0], filename=os.path.basename(matches[0]))
--------------------------------------------------------------------------------
/scripts/1697513088193.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangnn520/digitalAvatarRealtime/f2c4a9333569181f4038525d266e2fc29afd6a13/scripts/1697513088193.wav
--------------------------------------------------------------------------------
/scripts/client_sample.py:
--------------------------------------------------------------------------------
import requests


def uploadAv():
    """Upload audio and video"""
    url = "http://127.0.0.1:6006/inferenceVideoV2/uploadAv"

    audio = "1697513088193.wav"
    video = "yangshi.mp4"

    files = {'video': open(video, 'rb'), 'audio': open(audio, 'rb')}
    result = requests.post(url=url, files=files)
    print(result.text)


if __name__ == '__main__':
    uploadAv()
--------------------------------------------------------------------------------
/scripts/yangshi.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangnn520/digitalAvatarRealtime/f2c4a9333569181f4038525d266e2fc29afd6a13/scripts/yangshi.mp4
--------------------------------------------------------------------------------
/utilities/async_generator2q.py:
--------------------------------------------------------------------------------
import asyncio
from typing import AsyncGenerator
from objects import QUEUE_OVER


def async_generator2q(gen: AsyncGenerator, q: asyncio.Queue = None):
    """Convert an async generator into an async queue; QUEUE_OVER on the queue marks the end of the generator"""
    if q is None:
        q = asyncio.Queue()
    asyncio.create_task(_generate(gen, q))
    return q


async def _generate(gen: AsyncGenerator, q: asyncio.Queue):
    async for data in gen:
        q.put_nowait(data)
    q.put_nowait(QUEUE_OVER)
--------------------------------------------------------------------------------
/utilities/async_q2async_generator.py:
--------------------------------------------------------------------------------
import asyncio
from asyncio import Queue

from objects import QUEUE_OVER
from typing import Optional, Callable
from concurrent.futures import CancelledError
import traceback


async def async_q2async_generator(q: Queue,
                                  timeout: Optional[int] = None,
                                  cancelled_callback: Callable = None):
    """Convert an async queue into an async generator; the queue must use QUEUE_OVER to mark the end"""
    while True:
        try:
            if timeout is None:
                data = await q.get()
            else:
                data = await asyncio.wait_for(q.get(), timeout=timeout)
            q.task_done()
        except asyncio.TimeoutError:
            break
        except CancelledError:
            traceback.print_exc()
            if cancelled_callback is not None:
                # plain function
                if not asyncio.iscoroutinefunction(cancelled_callback):
                    cancelled_callback()
                else:  # coroutine function
                    await cancelled_callback()

            raise
        else:
            # end marker
            if data == QUEUE_OVER:
                break
            yield data
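

if __name__ == '__main__':
    # Round-trip sketch pairing this module with utilities/async_generator2q.py
    # (assumes the repo root is on sys.path so the absolute import resolves).
    from utilities.async_generator2q import async_generator2q

    async def numbers():
        for i in range(3):
            yield i

    async def _demo():
        q = async_generator2q(numbers())
        async for n in async_q2async_generator(q):
            print(n)  # prints 0, 1, 2

    asyncio.run(_demo())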
--------------------------------------------------------------------------------
/utilities/audio_bytes2np_array.py:
--------------------------------------------------------------------------------
from typing import Union, Iterable, Generator
import soundfile as sf
from librosa.util import valid_audio
import numpy as np
import io
from librosa import resample


def voice_bytes2array(voice_byte: bytes):
    """Convert audio bytes into a numpy ndarray"""
    data, samplerate = sf.read(io.BytesIO(voice_byte))
    data = data.T

    valid_audio(data, mono=False)

    if data.ndim > 1:
        data = np.mean(data, axis=0)

    data = resample(y=data, orig_sr=samplerate, target_sr=16000, res_type="kaiser_best")
    return data


def voice_iter2array(it: Union[Iterable, Generator]):
    """Convert an iterable or generator of audio byte chunks into a numpy ndarray"""
    data, samplerate = sf.read(io.BytesIO(b''.join(it)))
    data = data.T

    valid_audio(data, mono=False)

    if data.ndim > 1:
        data = np.mean(data, axis=0)

    # librosa >= 0.10 (pinned in requirements.txt) requires keyword arguments here
    data = resample(y=data, orig_sr=samplerate, target_sr=16000, res_type="kaiser_best")
    return data
--------------------------------------------------------------------------------
/utilities/base64_estimate.py:
--------------------------------------------------------------------------------
def is_base64_code(s):
    '''Check s is Base64.b64encode'''
    if not isinstance(s, str) or not s:
        return False  # s is not a non-empty string

    _base64_code = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I',
                    'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R',
                    'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a',
                    'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j',
                    'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
                    't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1',
                    '2', '3', '4', '5', '6', '7', '8', '9', '+',
                    '/', '=']
    _base64_code_set = set(_base64_code)  # a set makes the `in` checks cheap
    # Check base64 OR codeCheck % 4
    code_fail = [i for i in s if i not in _base64_code_set]
    if code_fail or len(s) % 4 != 0:
        return False
    return True


if __name__ == '__main__':
    print(is_base64_code("c2RhZGFzZGFkYWQ="))
--------------------------------------------------------------------------------
/utilities/extract_frames_from_video.py:
--------------------------------------------------------------------------------
import os.path
import cv2
from typing import List
import numpy as np
from loguru import logger
import uuid


def extract_frames_from_video_bytes(video_bytes: bytes, return_list: bool = False):
    """
    Get the ndarray representation of video bytes

    return_list: whether to return the result as a list
    """
    temp_file_path = os.path.join("/dev/shm", str(uuid.uuid4()) + ".mp4")
    with open(temp_file_path, "wb") as f:
        f.write(video_bytes)

    videoCapture = cv2.VideoCapture(temp_file_path)
    fps = videoCapture.get(cv2.CAP_PROP_FPS)
    if int(fps) != 25:
        # TODO: convert to 25 fps
        logger.warning('The input video is not 25 fps, it would be better to trans it to 25 fps!')
    frames = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
    frame_ndarrays: List[np.ndarray] = []
    for i in range(int(frames)):
        ret, frame = videoCapture.read()
        frame_ndarrays.append(frame)
    videoCapture.release()
    os.remove(temp_file_path)

    if not return_list:
        frame_ndarrays: np.ndarray = np.stack(frame_ndarrays)
    return frame_ndarrays
--------------------------------------------------------------------------------
/utilities/is_bytes_wav.py:
--------------------------------------------------------------------------------
def is_wav(data: bytes):
    """
    Check whether a byte buffer is wav-format audio data
    """
    # Ensure there are enough bytes to check
    if len(data) < 16:
        return False

    # Check for the 'RIFF' identifier
    if data[0:4] != b'RIFF':
        return False

    # Check for the 'WAVE' identifier
    if data[8:12] != b'WAVE':
        return False

    # Check for the 'fmt ' subchunk identifier
    if data[12:16] != b'fmt ':
        return False

    return True

# Example usage:
# with open('audio_file.wav', 'rb') as f:
#     data = f.read()
# print(is_wav(data))  # Output: True or False
--------------------------------------------------------------------------------
/utilities/ndarray2frame.py:
--------------------------------------------------------------------------------
import cv2
from numpy import ndarray


def ndarray2frame(full_frame: ndarray):
    """Convert an ndarray into a JPEG frame"""
    # encode the frame as JPEG
    ret, buffer = cv2.imencode(".jpg", full_frame)
    if ret:
        return buffer
--------------------------------------------------------------------------------
/utilities/p2jpg.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
from numpy import ndarray


def p2jpg(digman, p, x1, y1, x2, y2, points_68, full_frame, color_comple):
    """Merge a single inferred face image back into the full frame"""
    p = p.astype(np.uint8)
    # resize back to the original face-detection rectangle
    height, width = y2 - y1, x2 - x1
    if p.shape[:2] != (height, width):
        p = cv2.resize(p, (width, height))
    try:
        if isinstance(points_68, ndarray) and points_68.shape[0] >= 17:
            # facial outline
            face_points = points_68[:17]
            if points_68.shape[0] >= 25:
                face_points = np.append(face_points, [points_68[24], points_68[19]], axis=0)
            face_points = np.stack(face_points).astype(np.int32)
            # 1. create a blank rectangular mask
            mask = np.zeros(p.shape[:2], dtype=np.uint8)
            # 2. draw the face mask with fillPoly
            cv2.fillPoly(mask, [face_points], (255, 255, 255))
            # inverted mask
            reverse_mask = cv2.bitwise_not(mask)
            # 3. extract the face using the mask
            face_image = cv2.bitwise_and(p, p, mask=mask)
            # corresponding region of the original frame
            org_face_rect = full_frame[y1:y2, x1:x2]
            # extract the surroundings of the face
            face_surrounding = cv2.bitwise_and(org_face_rect, org_face_rect, mask=reverse_mask)
            # paste the inferred face back onto the original frame
            inferenced_face_rect = cv2.add(face_image, face_surrounding)
        else:
            # overwrite the original region with the inferred face
            inferenced_face_rect = p
    except:
        # overwrite the original region with the inferred face
        full_frame[y1:y2, x1:x2] = p
    else:
        # overwrite the original region with the merged face
        full_frame[y1:y2, x1:x2] = inferenced_face_rect

    if digman in color_comple:
        full_frame = np.clip(
            full_frame + np.array(color_comple[digman]), 0, 255).astype(np.uint8)

    # encode the frame as a JPEG chunk for an MJPEG (multipart/x-mixed-replace) stream
    ret, buffer = cv2.imencode(".jpg", full_frame)
    if ret:
        return (b"--frame\r\n"
                b"Content-Type: image/jpeg\r\n\r\n" + buffer.tobytes() + b"\r\n")
--------------------------------------------------------------------------------
/utilities/singleton.py:
--------------------------------------------------------------------------------
"""Singleton pattern"""


def singleton(cls):
    _instance = {}

    def inner():
        if cls not in _instance:
            _instance[cls] = cls()
        return _instance[cls]

    return inner
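

if __name__ == "__main__":
    # Minimal demo: every call returns the same cached instance.
    @singleton
    class Config:
        pass

    assert Config() is Config()
    print("singleton OK")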
--------------------------------------------------------------------------------
/utilities/text2voice_gener/__init__.py:
--------------------------------------------------------------------------------
from .aliyun_text2voice_gener import text2voice_gener, sync_text2voice_gener, text2voice
from .azure_tts import azure_text2speech
--------------------------------------------------------------------------------
/utilities/text2voice_gener/aliyun_text2voice_gener.py:
--------------------------------------------------------------------------------
import asyncio
import json
import requests
from aiohttp import ClientSession
from utilities.wav_bytes_2channel import wav_bytes_2channel

host = 'nls-gateway.aliyuncs.com'
url = 'https://' + host + '/stream/v1/tts'

sem = None

common_payload = {
    "format": "mp3",
    "sample_rate": 32000
}


async def text2voice(token: str, appkey: str, text: str, voice=None):
    """Speak the given text"""
    payload = {
        "appkey": appkey,
        "text": text,
        "token": token
    }
    # the speaker voice
    if voice:
        payload['voice'] = voice

    payload.update(common_payload)

    global sem
    if sem is None:
        sem = asyncio.Semaphore(2)  # fixme

    async with sem:
        async with ClientSession(headers={"Content-Type": "application/json"}) as session:
            async with session.post(url, data=json.dumps(payload)) as resp:
                resp.raise_for_status()
                if 'audio/mpeg' == resp.content_type:
                    return await resp.read()


async def text2voice_gener(token: str, appkey: str, text: str):
    """Voice generator that speaks the text; real audio streaming is not needed here.
    TODO: if latency is too high, wrap a singleton class that reuses one session;
    failing that, stream the Aliyun request as well for finer-grained streaming"""
    payload = {
        "appkey": appkey,
        "text": text,
        "token": token
    }
    payload.update(common_payload)

    async with ClientSession(headers={"Content-Type": "application/json"}) as session:
        async with session.post(url, data=json.dumps(payload)) as resp:
            resp.raise_for_status()
            if 'audio/mpeg' == resp.content_type:
                v_data = await resp.read()

    v_data_2chan = await asyncio.get_running_loop().run_in_executor(None, wav_bytes_2channel, v_data)
    yield v_data_2chan
    # used to check whether the audio data is actually streaming
    await asyncio.sleep(5)
    print("text2voice_gener over")


def sync_text2voice_gener(token: str, appkey: str, text: str):
    payload = {
        "appkey": appkey,
        "text": text,
        "token": token,
    }
    payload.update(common_payload)

    response = requests.post(url, stream=True, headers={"Content-Type": "application/json"}, json=payload)
    for chunk in response.iter_content(chunk_size=1024):
        yield chunk


# async def text2voice_gener(TOKEN: str, APPKEY: str, TEXT: str):
#     """
#     Generator for a spoken-text audio stream
#
#     :return:
#     """
#     q = asyncio.Queue()
#     asyncio.get_running_loop().run_in_executor(None, read_text, q, TOKEN, APPKEY, TEXT)
#     while True:
#         try:
#             yield await asyncio.wait_for(q.get(), 5)
#         except asyncio.TimeoutError:
#             print("audio over")
#             return
#
#


if __name__ == '__main__':
    ...
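    # e.g. (hypothetical env-var names; requires valid Aliyun NLS credentials):
    # import os
    # audio_bytes = asyncio.run(
    #     text2voice(os.environ["NLS_TOKEN"], os.environ["NLS_APPKEY"], "你好"))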
--------------------------------------------------------------------------------
/utilities/text2voice_gener/azure_tts.py:
--------------------------------------------------------------------------------
import os
from loguru import logger
from aiohttp import ClientSession
from collections import namedtuple
import time

# token with a 10-minute lifetime
Token = namedtuple("Token", ["create_ts", "token_s"])
token: Token = None


async def get_available_token():
    """Get a usable token"""
    global token
    if token is None or time.time() - token.create_ts > 9 * 60:  # time to refresh
        async with ClientSession(headers={"Ocp-Apim-Subscription-Key": os.environ['RESOURCE_KEY']}) as session:
            async with session.post("https://eastasia.api.cognitive.microsoft.com/sts/v1.0/issuetoken") as resp:
                token = Token(token_s=await resp.text(), create_ts=time.time())
                logger.info(f"New azure resource token:{token}")
    return token.token_s


# NOTE: the SSML tags of the default template were lost from this snapshot; the
# skeleton below is a reconstruction and the voice name is a placeholder.
async def azure_text2speech(text: str,
                            ssml: str = """
<speak version='1.0' xml:lang='zh-CN'>
    <voice name='zh-CN-XiaoxiaoNeural'>
        {}
    </voice>
</speak>
"""
                            ):
    """
    Text to speech

    ssml: the SSML markup template
    """
    global token
    async with ClientSession(headers={
        "Authorization": f"Bearer {await get_available_token()}",
        "Content-Type": "application/ssml+xml",
        "X-Microsoft-OutputFormat": "riff-24khz-16bit-mono-pcm"}) as session:  # TODO: a streamable output format
        # NOTE: the token above is issued by eastasia while this endpoint is southeastasia;
        # normally both must belong to the same Speech resource region.
        async with session.post("https://southeastasia.tts.speech.microsoft.com/cognitiveservices/v1",
                                data=ssml.format(text)) as resp:
            return await resp.read()
--------------------------------------------------------------------------------
/utilities/wav_bytes_2channel.py:
--------------------------------------------------------------------------------
import io
import wave


def wav_bytes_2channel(single_channel_data: bytes):
    """
    Convert mono wav bytes into stereo bytes at a 32000 sample rate;
    the result is untested if multi-channel audio bytes are passed in
    """
    # split the mono data into 16-bit samples
    samples = [single_channel_data[i:i + 2] for i in range(0, len(single_channel_data), 2)]

    # duplicate each sample to create the two channels
    # (the first 20 "samples", i.e. 40 bytes and roughly the RIFF header, pass through as-is)
    double_channel_data = b''.join(samples[:20] + [sample + sample for sample in samples[20:]])

    # byte stream the WAV is written into
    output_wav = io.BytesIO()

    # WAV parameters
    num_channels = 2
    sampwidth = 2  # 16-bit sample width
    framerate = 32000  # frame rate of 32000
    num_frames = len(double_channel_data) // (num_channels * sampwidth)

    # create the WAV file
    with wave.open(output_wav, 'wb') as wav_file:
        wav_file.setnchannels(num_channels)
        wav_file.setsampwidth(sampwidth)
        wav_file.setframerate(framerate)
        wav_file.setnframes(num_frames)
        wav_file.writeframes(double_channel_data)

    return output_wav.getvalue()
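

if __name__ == '__main__':
    # Quick sketch, assuming it is run from the repo root with the sample audio present:
    with open('scripts/1697513088193.wav', 'rb') as f:
        stereo = wav_bytes_2channel(f.read())
    with open('/dev/shm/stereo_demo.wav', 'wb') as f:
        f.write(stereo)
--------------------------------------------------------------------------------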