├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── donate.jpg ├── inference.py ├── inference_realtime.py ├── musetalk ├── models │ ├── __pycache__ │ │ ├── unet.cpython-310.pyc │ │ └── vae.cpython-310.pyc │ ├── unet.py │ └── vae.py ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── blending.cpython-310.pyc │ │ ├── preprocessing.cpython-310.pyc │ │ └── utils.cpython-310.pyc │ ├── blending.py │ ├── dwpose │ │ ├── default_runtime.py │ │ └── rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py │ ├── face_detection │ │ ├── README.md │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── api.cpython-310.pyc │ │ │ ├── models.cpython-310.pyc │ │ │ └── utils.cpython-310.pyc │ │ ├── api.py │ │ ├── detection │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ └── core.cpython-310.pyc │ │ │ ├── core.py │ │ │ └── sfd │ │ │ │ ├── __init__.py │ │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ ├── bbox.cpython-310.pyc │ │ │ │ ├── detect.cpython-310.pyc │ │ │ │ ├── net_s3fd.cpython-310.pyc │ │ │ │ └── sfd_detector.cpython-310.pyc │ │ │ │ ├── bbox.py │ │ │ │ ├── detect.py │ │ │ │ ├── net_s3fd.py │ │ │ │ └── sfd_detector.py │ │ ├── models.py │ │ └── utils.py │ ├── face_parsing │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── model.cpython-310.pyc │ │ │ └── resnet.cpython-310.pyc │ │ ├── model.py │ │ └── resnet.py │ ├── preprocessing.py │ └── utils.py └── whisper │ ├── __pycache__ │ └── audio2feature.cpython-310.pyc │ ├── audio2feature.py │ └── whisper │ ├── __init__.py │ ├── __main__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── audio.cpython-310.pyc │ ├── decoding.cpython-310.pyc │ ├── model.cpython-310.pyc │ ├── tokenizer.cpython-310.pyc │ ├── transcribe.cpython-310.pyc │ └── utils.cpython-310.pyc │ ├── assets │ ├── gpt2 │ │ ├── merges.txt │ │ ├── special_tokens_map.json │ │ ├── tokenizer_config.json │ │ └── vocab.json │ ├── mel_filters.npz │ └── multilingual │ │ ├── added_tokens.json │ │ ├── merges.txt │ │ ├── special_tokens_map.json │ │ ├── tokenizer_config.json │ │ └── vocab.json │ ├── audio.py │ ├── decoding.py │ ├── model.py │ ├── normalizers │ ├── __init__.py │ ├── basic.py │ ├── english.json │ └── english.py │ ├── tokenizer.py │ ├── transcribe.py │ └── utils.py ├── nodes.py ├── requirements.txt ├── web.png ├── web └── js │ ├── previewVideo.js │ └── uploadVideo.js └── wechat.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | /models -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 本软件及其相关代码以MIT协议开源,作者不对软件具备任何控制力,使用软件者、传播软件导出的声音者自负全责。 2 | 如不认可该条款,则不能使用或引用软件包内任何代码和文件。 3 | 4 | 特此授予任何获得本软件和相关文档文件(以下简称“软件”)副本的人免费使用、复制、修改、合并、出版、分发、再授权和/或销售本软件的权利,以及授予本软件所提供的人使用本软件的权利,但须符合以下条件: 5 | 上述版权声明和本许可声明应包含在软件的所有副本或实质部分中。 6 | 软件是“按原样”提供的,没有任何明示或暗示的保证,包括但不限于适销性、适用于特定目的和不侵权的保证。在任何情况下,作者或版权持有人均不承担因软件或软件的使用或其他交易而产生、产生或与之相关的任何索赔、损害赔偿或其他责任,无论是在合同诉讼、侵权诉讼还是其他诉讼中。 7 | 8 | 9 | 10 | MIT License 11 | 12 | Copyright (c) 2024 AIFSH 13 | 14 | Permission is hereby granted, free of charge, to any person obtaining a copy 15 | of this software and associated documentation files (the "Software"), to deal 16 | in the Software without restriction, including without limitation the rights 17 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 18 | 
copies of the Software, and to permit persons to whom the Software is 19 | furnished to do so, subject to the following conditions: 20 | 21 | The above copyright notice and this permission notice shall be included in all 22 | copies or substantial portions of the Software. 23 | 24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 25 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 26 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 27 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 28 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 29 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 30 | SOFTWARE. 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-MuseTalk_FSH 2 | the comfyui custom node of [MuseTalk](https://github.com/TMElyralab/MuseTalk.git) to make audio driven videos! 3 |
4 |
5 | webpage 6 |
7 |
8 | 9 | ## How to use 10 | make sure `ffmpeg` works in your command line 11 | For Linux: 12 | ``` 13 | apt update 14 | apt install ffmpeg 15 | ``` 16 | For Windows, you can install `ffmpeg` automatically with [WingetUI](https://github.com/marticliment/WingetUI) 17 | 18 | Then clone this node and install its requirements: 19 | ``` 20 | git clone https://github.com/AIFSH/ComfyUI-MuseTalk_FSH.git 21 | cd ComfyUI-MuseTalk_FSH 22 | pip install -r requirements.txt 23 | ``` 24 | ### MMLab packages 25 | ```bash 26 | pip install --no-cache-dir -U openmim 27 | mim install mmengine 28 | mim install "mmcv>=2.0.1" 29 | mim install "mmdet>=3.1.0" 30 | mim install "mmpose>=1.1.0" 31 | ``` 32 | 33 | ### Download weights 34 | You can download the weights manually as follows: 35 | 36 | 1. Download the trained MuseTalk [weights](https://huggingface.co/TMElyralab/MuseTalk). 37 | 38 | 2. Download the weights of the other components: 39 | - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse) 40 | - [whisper](https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt) 41 | - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main) 42 | - [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch) 43 | - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth) 44 | 45 | Alternatively, download [MuseTalk.zip](https://pan.quark.cn/s/d6e76084ae92), 46 | unzip it, and put its subfolders into the `ComfyUI-MuseTalk_FSH/models/` directory. 47 | 48 | Finally, these weights should be organized in `models` as follows (an optional sanity-check sketch for this layout is included further down in this README): 49 | ``` 50 | ComfyUI-MuseTalk_FSH/models/ 51 | ├── musetalk 52 | │ ├── musetalk.json 53 | │ └── pytorch_model.bin 54 | ├── dwpose 55 | │ └── dw-ll_ucoco_384.pth 56 | ├── face-parse-bisent 57 | │ ├── 79999_iter.pth 58 | │ └── resnet18-5c106cde.pth 59 | ├── sd-vae-ft-mse 60 | │ ├── config.json 61 | │ └── diffusion_pytorch_model.bin 62 | └── whisper 63 | └── tiny.pt 64 | ``` 65 | 66 | ## Tutorial 67 | - [Demo on 3060 12GB](https://www.bilibili.com/video/BV1St421w7Qn) 68 | - [Demo on 4090 24GB](https://www.bilibili.com/video/BV13T42117uM/) 69 | 70 | 71 | ## WeChat Group && Donate 72 |
73 |
74 | Wechat 75 | donate 76 |
77 |
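## Weights sanity check (optional)

The node expects exactly the `models/` layout listed in the "Download weights" section above. The snippet below is not part of this repository; it is a minimal sketch that checks the expected files are present before you run a workflow. `MODELS_DIR` is an assumed install path — point it at your own `ComfyUI-MuseTalk_FSH/models/` folder.

```python
# Hypothetical helper (not shipped with this node): verify the weight layout
# described in the "Download weights" section of this README.
from pathlib import Path

# Assumed location; adjust to wherever you cloned the custom node.
MODELS_DIR = Path("ComfyUI/custom_nodes/ComfyUI-MuseTalk_FSH/models")

EXPECTED = [
    "musetalk/musetalk.json",
    "musetalk/pytorch_model.bin",
    "dwpose/dw-ll_ucoco_384.pth",
    "face-parse-bisent/79999_iter.pth",
    "face-parse-bisent/resnet18-5c106cde.pth",
    "sd-vae-ft-mse/config.json",
    "sd-vae-ft-mse/diffusion_pytorch_model.bin",
    "whisper/tiny.pt",
]

missing = [rel for rel in EXPECTED if not (MODELS_DIR / rel).is_file()]
if missing:
    print("Missing weight files:")
    for rel in missing:
        print("  -", rel)
else:
    print("All expected weight files are in place.")
```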
78 | 79 | ## Thanks 80 | - [MuseTalk](https://github.com/TMElyralab/MuseTalk.git) 81 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import MuseTalk,LoadVideo,PreViewVideo,CombineAudioVideo,MuseTalkRealTime 2 | WEB_DIRECTORY = "./web" 3 | # A dictionary that contains all nodes you want to export with their names 4 | # NOTE: names should be globally unique 5 | NODE_CLASS_MAPPINGS = { 6 | "MuseTalk": MuseTalk, 7 | "LoadVideo": LoadVideo, 8 | "PreViewVideo": PreViewVideo, 9 | "CombineAudioVideo": CombineAudioVideo, 10 | "MuseTalkRealTime": MuseTalkRealTime 11 | } 12 | 13 | # A dictionary that contains the friendly/humanly readable titles for the nodes 14 | NODE_DISPLAY_NAME_MAPPINGS = { 15 | "MuseTalk": "MuseTalk Node", 16 | "LoadVideo": "Video Loader", 17 | "PreViewVideo": "PreView Video", 18 | "CombineAudioVideo": "Combine Audio Video", 19 | "MuseTalkRealTime": "MuseTalk RealTime Node" 20 | } 21 | -------------------------------------------------------------------------------- /donate.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/donate.jpg -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import cv2 4 | import glob,copy 5 | import shutil 6 | from tqdm import tqdm 7 | import torch 8 | import pickle 9 | import numpy as np 10 | from cuda_malloc import cuda_malloc_supported 11 | from typing import Any 12 | import folder_paths 13 | from .musetalk.utils.face_parsing import FaceParsing 14 | from mmpose.apis import init_model 15 | from .musetalk.utils.blending import get_image 16 | from .musetalk.utils.utils import load_all_model,get_file_type,get_video_fps,datagen 17 | from .musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder 18 | 19 | parent_directory = os.path.dirname(os.path.abspath(__file__)) 20 | input_path = folder_paths.get_input_directory() 21 | out_path = folder_paths.get_output_directory() 22 | 23 | class MuseTalk_INFER: 24 | def __init__(self,bbox_shift=0,fps=25, 25 | batch_size=8,batch_size_fa=2, 26 | use_saved_coord=False) -> None: 27 | self.fps = fps 28 | self.bbox_shift = bbox_shift 29 | self.batch_size = batch_size 30 | self.batch_size_fa = batch_size_fa 31 | self.use_saved_coord = use_saved_coord 32 | self.device = torch.device("cuda" if cuda_malloc_supported() else "cpu") 33 | config_file = os.path.join(parent_directory,"musetalk","utils","dwpose","rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py") 34 | checkpoint_file = os.path.join(parent_directory,'models','dwpose','dw-ll_ucoco_384.pth') 35 | resnet_path = os.path.join(parent_directory,'models','face-parse-bisent','resnet18-5c106cde.pth') 36 | face_model_pth = os.path.join(parent_directory,'models','face-parse-bisent','79999_iter.pth') 37 | self.fp_model = FaceParsing(resnet_path,face_model_pth) 38 | self.dwpose_model = init_model(config_file, checkpoint_file, device=self.device) 39 | self.audio_processor,self.vae,self.unet,self.pe = load_all_model(os.path.join(parent_directory,"models")) 40 | self.timesteps = torch.tensor([0], device=self.device) 41 | 42 | def __call__(self, video_path,audio_path,*args: Any, **kwds: Any) -> Any: 43 | input_basename = 
os.path.basename(video_path).split('.')[0] 44 | audio_basename = os.path.basename(audio_path).split('.')[0] 45 | output_basename = f"{input_basename}_{audio_basename}" 46 | result_img_save_path = os.path.join(out_path,"musetalk_result", output_basename) # related to video & audio inputs 47 | os.makedirs(result_img_save_path, exist_ok=True) 48 | crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input 49 | output_vid_name = os.path.join(out_path, output_basename+".mp4") 50 | 51 | ############################################## extract frames from source video ############################################## 52 | if get_file_type(video_path)=="video": 53 | save_dir_full = os.path.join(out_path,"musetalk_result",input_basename) 54 | os.makedirs(save_dir_full,exist_ok = True) 55 | png_path = os.path.join(save_dir_full,"%08d.png") 56 | cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {png_path}" 57 | os.system(cmd) 58 | input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]'))) 59 | fps = get_video_fps(video_path) 60 | else: # input img folder 61 | input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]')) 62 | input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) 63 | fps = self.fps 64 | 65 | #print(input_img_list) 66 | ############################################## extract audio feature ############################################## 67 | whisper_feature = self.audio_processor.audio2feat(audio_path) 68 | whisper_chunks = self.audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps) 69 | ############################################## preprocess input image ############################################## 70 | if os.path.exists(crop_coord_save_path) and self.use_saved_coord: 71 | print("using extracted coordinates") 72 | with open(crop_coord_save_path,'rb') as f: 73 | coord_list = pickle.load(f) 74 | frame_list = read_imgs(input_img_list) 75 | else: 76 | print("extracting landmarks...time consuming") 77 | coord_list, frame_list = get_landmark_and_bbox(self.dwpose_model,input_img_list, self.batch_size_fa,self.bbox_shift) 78 | with open(crop_coord_save_path, 'wb') as f: 79 | pickle.dump(coord_list, f) 80 | 81 | i = 0 82 | input_latent_list = [] 83 | for bbox, frame in zip(coord_list, frame_list): 84 | if bbox == coord_placeholder: 85 | continue 86 | x1, y1, x2, y2 = bbox 87 | crop_frame = frame[y1:y2, x1:x2] 88 | crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4) 89 | latents = self.vae.get_latents_for_unet(crop_frame) 90 | input_latent_list.append(latents) 91 | 92 | # to smooth the first and the last frame 93 | frame_list_cycle = frame_list + frame_list[::-1] 94 | coord_list_cycle = coord_list + coord_list[::-1] 95 | input_latent_list_cycle = input_latent_list + input_latent_list[::-1] 96 | ############################################## inference batch by batch ############################################## 97 | print("start inference") 98 | video_num = len(whisper_chunks) 99 | batch_size = self.batch_size 100 | gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size) 101 | res_frame_list = [] 102 | for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))): 103 | tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch] 104 | audio_feature_batch = torch.stack(tensor_list).to(self.unet.device) # torch, B, 5*N,384 105 | audio_feature_batch = 
self.pe(audio_feature_batch) 106 | 107 | pred_latents = self.unet.model(latent_batch, self.timesteps, encoder_hidden_states=audio_feature_batch).sample 108 | recon = self.vae.decode_latents(pred_latents) 109 | for res_frame in recon: 110 | res_frame_list.append(res_frame) 111 | 112 | ############################################## pad to full image ############################################## 113 | print("pad talking image to original video") 114 | for i, res_frame in enumerate(tqdm(res_frame_list)): 115 | bbox = coord_list_cycle[i%(len(coord_list_cycle))] 116 | ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))]) 117 | x1, y1, x2, y2 = bbox 118 | try: 119 | res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1)) 120 | except: 121 | # print(bbox) 122 | continue 123 | 124 | combine_frame = get_image(self.fp_model, ori_frame,res_frame,bbox) 125 | cv2.imwrite(os.path.join(result_img_save_path,f"{str(i).zfill(8)}.png"),combine_frame) 126 | 127 | res_tmp_path = os.path.join(result_img_save_path,"%08d.png") 128 | cmd_img2video = f"ffmpeg -y -v fatal -r {fps} -f image2 -i {res_tmp_path} -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {output_vid_name}" 129 | print(cmd_img2video) 130 | os.system(cmd_img2video) 131 | ''' 132 | cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}" 133 | print(cmd_combine_audio) 134 | os.system(cmd_combine_audio) 135 | ''' 136 | del self.fp_model,self.dwpose_model,self.audio_processor,self.vae,self.unet,self.pe 137 | torch.cuda.empty_cache() 138 | # os.remove("temp.mp4") 139 | shutil.rmtree(result_img_save_path) 140 | print(f"result is save to {output_vid_name}") 141 | return output_vid_name 142 | -------------------------------------------------------------------------------- /inference_realtime.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import json 5 | import torch 6 | import shutil 7 | import pickle 8 | import glob,time 9 | import queue,copy 10 | import threading 11 | from tqdm import tqdm 12 | import numpy as np 13 | import folder_paths 14 | from cuda_malloc import cuda_malloc_supported 15 | from typing import Any 16 | from .musetalk.utils.face_parsing import FaceParsing 17 | from mmpose.apis import init_model 18 | from .musetalk.utils.utils import load_all_model,datagen 19 | from .musetalk.utils.preprocessing import read_imgs,get_landmark_and_bbox 20 | from .musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending 21 | 22 | parent_directory = os.path.dirname(os.path.abspath(__file__)) 23 | device = torch.device("cuda" if cuda_malloc_supported() else "cpu") 24 | timesteps = torch.tensor([0], device=device) 25 | 26 | output_path = folder_paths.get_output_directory() 27 | musetalk_out_path = os.path.join(output_path,"musetalk_realtime") 28 | os.makedirs(musetalk_out_path, exist_ok=True) 29 | 30 | def osmakedirs(path_list): 31 | for path in path_list: 32 | os.makedirs(path) if not os.path.exists(path) else None 33 | 34 | def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000): 35 | cap = cv2.VideoCapture(vid_path) 36 | count = 0 37 | while True: 38 | if count > cut_frame: 39 | break 40 | ret, frame = cap.read() 41 | if ret: 42 | cv2.imwrite(f"{save_path}/{count:08d}.png", frame) 43 | count += 1 44 | else: 45 | break 46 | 47 | @torch.no_grad() 48 | class Avatar: 49 | def __init__(self, avatar_id, video_path, bbox_shift, batch_size, 
preparation): 50 | self.avatar_id = avatar_id 51 | self.video_path = video_path 52 | self.bbox_shift = bbox_shift 53 | self.avatar_path = os.path.join(musetalk_out_path,avatar_id) 54 | self.full_imgs_path = os.path.join(self.avatar_path, "full_imgs") 55 | self.coords_path = os.path.join(self.avatar_path, "coords.pkl") 56 | self.latents_out_path= os.path.join(self.avatar_path, "latents.pt") 57 | self.video_out_path = output_path 58 | self.mask_out_path = os.path.join(self.avatar_path, "mask") 59 | self.mask_coords_path = os.path.join(self.avatar_path, "mask_coords.pkl") 60 | self.avatar_info_path = os.path.join(self.avatar_path, "avator_info.json") 61 | self.avatar_info = { 62 | "avatar_id":avatar_id, 63 | "video_path":video_path, 64 | "bbox_shift":bbox_shift 65 | } 66 | self.preparation = preparation 67 | self.batch_size = batch_size 68 | self.idx = 0 69 | # load model weights 70 | config_file = os.path.join(parent_directory,"musetalk","utils","dwpose","rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py") 71 | checkpoint_file = os.path.join(parent_directory,"models",'dwpose','dw-ll_ucoco_384.pth') 72 | resnet_path = os.path.join(parent_directory,'models','face-parse-bisent','resnet18-5c106cde.pth') 73 | face_model_pth = os.path.join(parent_directory,'models','face-parse-bisent',"79999_iter.pth") 74 | 75 | self.fp_model = FaceParsing(resnet_path,face_model_pth) 76 | self.dwpose_model = init_model(config_file, checkpoint_file, device=device) 77 | self.audio_processor,self.vae,self.unet,self.pe = load_all_model(os.path.join(parent_directory,"models")) 78 | self.vae.vae = self.vae.vae.half() 79 | self.unet.model = self.unet.model.half() 80 | self.pe = self.pe.half() 81 | self.init() 82 | 83 | def init(self): 84 | if self.preparation: 85 | if os.path.exists(self.avatar_path): 86 | response = input(f"{self.avatar_id} exists, Do you want to re-create it ? (y/n)") 87 | if response.lower() == "y": 88 | shutil.rmtree(self.avatar_path) 89 | print("*********************************") 90 | print(f" creating avator: {self.avatar_id}") 91 | print("*********************************") 92 | osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path]) 93 | self.prepare_material() 94 | else: 95 | self.input_latent_list_cycle = torch.load(self.latents_out_path) 96 | with open(self.coords_path, 'rb') as f: 97 | self.coord_list_cycle = pickle.load(f) 98 | input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')) 99 | input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) 100 | self.frame_list_cycle = read_imgs(input_img_list) 101 | with open(self.mask_coords_path, 'rb') as f: 102 | self.mask_coords_list_cycle = pickle.load(f) 103 | input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]')) 104 | input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) 105 | self.mask_list_cycle = read_imgs(input_mask_list) 106 | else: 107 | print("*********************************") 108 | print(f" creating avator: {self.avatar_id}") 109 | print("*********************************") 110 | osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path]) 111 | self.prepare_material() 112 | else: 113 | with open(self.avatar_info_path, "r") as f: 114 | avatar_info = json.load(f) 115 | 116 | if avatar_info['bbox_shift'] != self.avatar_info['bbox_shift']: 117 | response = input(f" 【bbox_shift】 is changed, you need to re-create it ! 
(c/continue)") 118 | if response.lower() == "c": 119 | shutil.rmtree(self.avatar_path) 120 | print("*********************************") 121 | print(f" creating avator: {self.avatar_id}") 122 | print("*********************************") 123 | osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path]) 124 | self.prepare_material() 125 | else: 126 | sys.exit() 127 | else: 128 | self.input_latent_list_cycle = torch.load(self.latents_out_path) 129 | with open(self.coords_path, 'rb') as f: 130 | self.coord_list_cycle = pickle.load(f) 131 | input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')) 132 | input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) 133 | self.frame_list_cycle = read_imgs(input_img_list) 134 | with open(self.mask_coords_path, 'rb') as f: 135 | self.mask_coords_list_cycle = pickle.load(f) 136 | input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]')) 137 | input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) 138 | self.mask_list_cycle = read_imgs(input_mask_list) 139 | try: 140 | del self.dwpose_model,self.fp_model 141 | import gc; gc.collect(); torch.cuda.empty_cache(); 142 | except: 143 | pass 144 | 145 | def prepare_material(self): 146 | print("preparing data materials ... ...") 147 | with open(self.avatar_info_path, "w") as f: 148 | json.dump(self.avatar_info, f) 149 | 150 | if os.path.isfile(self.video_path): 151 | video2imgs(self.video_path, self.full_imgs_path, ext = 'png') 152 | else: 153 | print(f"copy files in {self.video_path}") 154 | files = os.listdir(self.video_path) 155 | files.sort() 156 | files = [file for file in files if file.split(".")[-1]=="png"] 157 | for filename in files: 158 | shutil.copyfile(os.path.join(self.video_path,filename), os.path.join(self.full_imgs_path,filename)) 159 | input_img_list = sorted(glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))) 160 | 161 | print("extracting landmarks...") 162 | coord_list, frame_list = get_landmark_and_bbox(self.dwpose_model,input_img_list,1,self.bbox_shift,) 163 | input_latent_list = [] 164 | idx = -1 165 | # maker if the bbox is not sufficient 166 | coord_placeholder = (0.0,0.0,0.0,0.0) 167 | for bbox, frame in zip(coord_list, frame_list): 168 | idx = idx + 1 169 | if bbox == coord_placeholder: 170 | continue 171 | x1, y1, x2, y2 = bbox 172 | crop_frame = frame[y1:y2, x1:x2] 173 | 174 | resized_crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4) 175 | latents = self.vae.get_latents_for_unet(resized_crop_frame) 176 | input_latent_list.append(latents) 177 | 178 | 179 | self.frame_list_cycle = frame_list + frame_list[::-1] 180 | self.coord_list_cycle = coord_list + coord_list[::-1] 181 | self.input_latent_list_cycle = input_latent_list + input_latent_list[::-1] 182 | self.mask_coords_list_cycle = [] 183 | self.mask_list_cycle = [] 184 | 185 | for i,frame in enumerate(tqdm(self.frame_list_cycle)): 186 | cv2.imwrite(os.path.join(self.full_imgs_path,str(i).zfill(8)+".png"),frame) 187 | 188 | face_box = self.coord_list_cycle[i] 189 | mask,crop_box = get_image_prepare_material(self.fp_model,frame,face_box) 190 | cv2.imwrite(os.path.join(self.mask_out_path,str(i).zfill(8)+".png"),mask) 191 | self.mask_coords_list_cycle += [crop_box] 192 | self.mask_list_cycle.append(mask) 193 | 194 | with open(self.mask_coords_path, 'wb') as f: 195 | pickle.dump(self.mask_coords_list_cycle, f) 196 | 197 
| with open(self.coords_path, 'wb') as f: 198 | pickle.dump(self.coord_list_cycle, f) 199 | 200 | torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path)) 201 | 202 | 203 | 204 | def process_frames(self, res_frame_queue,video_len): 205 | print(video_len) 206 | while True: 207 | if self.idx>=video_len-1: 208 | break 209 | try: 210 | start = time.time() 211 | res_frame = res_frame_queue.get(block=True, timeout=1) 212 | except queue.Empty: 213 | continue 214 | 215 | bbox = self.coord_list_cycle[self.idx%(len(self.coord_list_cycle))] 216 | ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx%(len(self.frame_list_cycle))]) 217 | x1, y1, x2, y2 = bbox 218 | try: 219 | res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1)) 220 | except: 221 | continue 222 | mask = self.mask_list_cycle[self.idx%(len(self.mask_list_cycle))] 223 | mask_crop_box = self.mask_coords_list_cycle[self.idx%(len(self.mask_coords_list_cycle))] 224 | #combine_frame = get_image(ori_frame,res_frame,bbox) 225 | combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box) 226 | 227 | fps = 1/(time.time()-start+1e-6) 228 | print(f"Displaying the {self.idx}-th frame with FPS: {fps:.2f}") 229 | cv2.imwrite(os.path.join(self.avatar_path,"tmp",str(self.idx).zfill(8)+".png"),combine_frame) 230 | self.idx = self.idx + 1 231 | 232 | def inference(self, audio_path, out_vid_name, fps): 233 | os.makedirs(os.path.join(self.avatar_path,'tmp'),exist_ok =True) 234 | ############################################## extract audio feature ############################################## 235 | whisper_feature = self.audio_processor.audio2feat(audio_path) 236 | whisper_chunks = self.audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps) 237 | ############################################## inference batch by batch ############################################## 238 | video_num = len(whisper_chunks) 239 | print("start inference") 240 | res_frame_queue = queue.Queue() 241 | self.idx = 0 242 | # # Create a sub-thread and start it 243 | process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue,video_num)) 244 | process_thread.start() 245 | start_time = time.time() 246 | gen = datagen(whisper_chunks,self.input_latent_list_cycle, self.batch_size) 247 | print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms") 248 | start_time = time.time() 249 | res_frame_list = [] 250 | 251 | for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/self.batch_size)))): 252 | start_time = time.time() 253 | tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch] 254 | audio_feature_batch = torch.stack(tensor_list).to(self.unet.device) # torch, B, 5*N,384 255 | audio_feature_batch = self.pe(audio_feature_batch) 256 | 257 | pred_latents = self.unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample 258 | recon = self.vae.decode_latents(pred_latents) 259 | for res_frame in recon: 260 | res_frame_queue.put(res_frame) 261 | # Close the queue and sub-thread after all tasks are completed 262 | process_thread.join() 263 | 264 | if out_vid_name is not None: 265 | # optional 266 | img_path = os.path.join(self.avatar_path,'tmp','%08d.png') 267 | tmp_mp4 = os.path.join(self.avatar_path,"temp.mp4") 268 | cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {img_path} -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {tmp_mp4}" 269 | print(cmd_img2video) 270 
| os.system(cmd_img2video) 271 | 272 | output_vid = os.path.join(self.video_out_path, out_vid_name+".mp4") # on 273 | cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {tmp_mp4} {output_vid}" 274 | print(cmd_combine_audio) 275 | os.system(cmd_combine_audio) 276 | 277 | os.remove(tmp_mp4) 278 | shutil.rmtree(os.path.join(self.avatar_path,"tmp")) 279 | print(f"result is save to {output_vid}") 280 | del self.audio_processor,self.vae,self.unet,self.pe 281 | import gc; gc.collect(); torch.cuda.empty_cache(); 282 | return output_vid 283 | 284 | class Infer_Real_Time: 285 | def __init__(self) -> None: 286 | pass 287 | 288 | def __call__(self, audio_path,video_path, 289 | avatar_id,fps=25,batch_size=4, 290 | preparation=True,bbox_shift=0, 291 | *args: Any, **kwds: Any) -> Any: 292 | 293 | avatar = Avatar( 294 | avatar_id = avatar_id, 295 | video_path = video_path, 296 | bbox_shift = bbox_shift, 297 | batch_size = batch_size, 298 | preparation= preparation) 299 | output_name = os.path.basename(audio_path)[:-4] + "musetalk" 300 | return avatar.inference(audio_path,output_name,fps) 301 | -------------------------------------------------------------------------------- /musetalk/models/__pycache__/unet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/models/__pycache__/unet.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/models/__pycache__/vae.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/models/__pycache__/vae.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import json 5 | 6 | from diffusers import UNet2DConditionModel 7 | import sys 8 | import time 9 | import numpy as np 10 | import os 11 | 12 | class PositionalEncoding(nn.Module): 13 | def __init__(self, d_model=384, max_len=5000): 14 | super(PositionalEncoding, self).__init__() 15 | pe = torch.zeros(max_len, d_model) 16 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 17 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 18 | pe[:, 0::2] = torch.sin(position * div_term) 19 | pe[:, 1::2] = torch.cos(position * div_term) 20 | pe = pe.unsqueeze(0) 21 | self.register_buffer('pe', pe) 22 | 23 | def forward(self, x): 24 | b, seq_len, d_model = x.size() 25 | pe = self.pe[:, :seq_len, :] 26 | x = x + pe.to(x.device) 27 | return x 28 | 29 | class UNet(): 30 | def __init__(self, 31 | unet_config, 32 | model_path, 33 | use_float16=False, 34 | ): 35 | with open(unet_config, 'r') as f: 36 | unet_config = json.load(f) 37 | self.model = UNet2DConditionModel(**unet_config) 38 | self.pe = PositionalEncoding(d_model=384) 39 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 40 | self.weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device) 41 | self.model.load_state_dict(self.weights) 42 | if use_float16: 43 | self.model = self.model.half() 44 | self.model.to(self.device) 45 | 46 | if __name__ == "__main__": 47 | unet = 
UNet() -------------------------------------------------------------------------------- /musetalk/models/vae.py: -------------------------------------------------------------------------------- 1 | from diffusers import AutoencoderKL 2 | import torch 3 | import torchvision.transforms as transforms 4 | import torch.nn.functional as F 5 | import cv2 6 | import numpy as np 7 | from PIL import Image 8 | import os 9 | 10 | class VAE(): 11 | """ 12 | VAE (Variational Autoencoder) class for image processing. 13 | """ 14 | 15 | def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False): 16 | """ 17 | Initialize the VAE instance. 18 | 19 | :param model_path: Path to the trained model. 20 | :param resized_img: The size to which images are resized. 21 | :param use_float16: Whether to use float16 precision. 22 | """ 23 | self.model_path = model_path 24 | self.vae = AutoencoderKL.from_pretrained(self.model_path) 25 | 26 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 27 | self.vae.to(self.device) 28 | 29 | if use_float16: 30 | self.vae = self.vae.half() 31 | self._use_float16 = True 32 | else: 33 | self._use_float16 = False 34 | 35 | self.scaling_factor = self.vae.config.scaling_factor 36 | self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 37 | self._resized_img = resized_img 38 | self._mask_tensor = self.get_mask_tensor() 39 | 40 | def get_mask_tensor(self): 41 | """ 42 | Creates a mask tensor for image processing. 43 | :return: A mask tensor. 44 | """ 45 | mask_tensor = torch.zeros((self._resized_img,self._resized_img)) 46 | mask_tensor[:self._resized_img//2,:] = 1 47 | mask_tensor[mask_tensor< 0.5] = 0 48 | mask_tensor[mask_tensor>= 0.5] = 1 49 | return mask_tensor 50 | 51 | def preprocess_img(self,img_name,half_mask=False): 52 | """ 53 | Preprocess an image for the VAE. 54 | 55 | :param img_name: The image file path or a list of image file paths. 56 | :param half_mask: Whether to apply a half mask to the image. 57 | :return: A preprocessed image tensor. 58 | """ 59 | window = [] 60 | if isinstance(img_name, str): 61 | window_fnames = [img_name] 62 | for fname in window_fnames: 63 | img = cv2.imread(fname) 64 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 65 | img = cv2.resize(img, (self._resized_img, self._resized_img), 66 | interpolation=cv2.INTER_LANCZOS4) 67 | window.append(img) 68 | else: 69 | img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB) 70 | window.append(img) 71 | 72 | x = np.asarray(window) / 255. 73 | x = np.transpose(x, (3, 0, 1, 2)) 74 | x = torch.squeeze(torch.FloatTensor(x)) 75 | if half_mask: 76 | x = x * (self._mask_tensor>0.5) 77 | x = self.transform(x) 78 | 79 | x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor 80 | x = x.to(self.vae.device) 81 | 82 | return x 83 | 84 | def encode_latents(self,image): 85 | """ 86 | Encode an image into latent variables. 87 | 88 | :param image: The image tensor to encode. 89 | :return: The encoded latent variables. 90 | """ 91 | with torch.no_grad(): 92 | init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist 93 | init_latents = self.scaling_factor * init_latent_dist.sample() 94 | return init_latents 95 | 96 | def decode_latents(self, latents): 97 | """ 98 | Decode latent variables back into an image. 99 | :param latents: The latent variables to decode. 100 | :return: A NumPy array representing the decoded image. 
101 | """ 102 | latents = (1/ self.scaling_factor) * latents 103 | image = self.vae.decode(latents.to(self.vae.dtype)).sample 104 | image = (image / 2 + 0.5).clamp(0, 1) 105 | image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy() 106 | image = (image * 255).round().astype("uint8") 107 | image = image[...,::-1] # RGB to BGR 108 | return image 109 | 110 | def get_latents_for_unet(self,img): 111 | """ 112 | Prepare latent variables for a U-Net model. 113 | :param img: The image to process. 114 | :return: A concatenated tensor of latents for U-Net input. 115 | """ 116 | 117 | ref_image = self.preprocess_img(img,half_mask=True) # [1, 3, 256, 256] RGB, torch tensor 118 | masked_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor 119 | ref_image = self.preprocess_img(img,half_mask=False) # [1, 3, 256, 256] RGB, torch tensor 120 | ref_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor 121 | latent_model_input = torch.cat([masked_latents, ref_latents], dim=1) 122 | return latent_model_input 123 | 124 | if __name__ == "__main__": 125 | vae_mode_path = "./models/sd-vae-ft-mse/" 126 | vae = VAE(model_path = vae_mode_path,use_float16=False) 127 | img_path = "./results/sun001_crop/00000.png" 128 | 129 | crop_imgs_path = "./results/sun001_crop/" 130 | latents_out_path = "./results/latents/" 131 | if not os.path.exists(latents_out_path): 132 | os.mkdir(latents_out_path) 133 | 134 | files = os.listdir(crop_imgs_path) 135 | files.sort() 136 | files = [file for file in files if file.split(".")[-1] == "png"] 137 | 138 | for file in files: 139 | index = file.split(".")[0] 140 | img_path = crop_imgs_path + file 141 | latents = vae.get_latents_for_unet(img_path) 142 | print(img_path,"latents",latents.size()) 143 | #torch.save(latents,os.path.join(latents_out_path,index+".pt")) 144 | #reload_tensor = torch.load('tensor.pt') 145 | #print(reload_tensor.size()) 146 | 147 | 148 | -------------------------------------------------------------------------------- /musetalk/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from os.path import abspath, dirname 3 | current_dir = dirname(abspath(__file__)) 4 | parent_dir = dirname(current_dir) 5 | sys.path.append(parent_dir+'/utils') 6 | -------------------------------------------------------------------------------- /musetalk/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/__pycache__/blending.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/__pycache__/blending.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/__pycache__/preprocessing.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/__pycache__/preprocessing.cpython-310.pyc -------------------------------------------------------------------------------- 
/musetalk/utils/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/blending.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | def get_crop_box(box, expand): 7 | x, y, x1, y1 = box 8 | x_c, y_c = (x+x1)//2, (y+y1)//2 9 | w, h = x1-x, y1-y 10 | s = int(max(w, h)//2*expand) 11 | crop_box = [x_c-s, y_c-s, x_c+s, y_c+s] 12 | return crop_box, s 13 | 14 | def face_seg(fp, image): 15 | seg_image = fp(image) 16 | if seg_image is None: 17 | print("error, no person_segment") 18 | return None 19 | 20 | seg_image = seg_image.resize(image.size) 21 | return seg_image 22 | 23 | def get_image(fp_model,image,face,face_box,upper_boundary_ratio = 0.5,expand=1.2): 24 | #print(image.shape) 25 | #print(face.shape) 26 | 27 | body = Image.fromarray(image[:,:,::-1]) 28 | face = Image.fromarray(face[:,:,::-1]) 29 | 30 | x, y, x1, y1 = face_box 31 | #print(x1-x,y1-y) 32 | crop_box, s = get_crop_box(face_box, expand) 33 | x_s, y_s, x_e, y_e = crop_box 34 | face_position = (x, y) 35 | 36 | face_large = body.crop(crop_box) 37 | ori_shape = face_large.size 38 | 39 | mask_image = face_seg(fp_model,face_large) 40 | mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s)) 41 | mask_image = Image.new('L', ori_shape, 0) 42 | mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s)) 43 | 44 | # keep upper_boundary_ratio of talking area 45 | width, height = mask_image.size 46 | top_boundary = int(height * upper_boundary_ratio) 47 | modified_mask_image = Image.new('L', ori_shape, 0) 48 | modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary)) 49 | 50 | blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1 51 | mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0) 52 | mask_image = Image.fromarray(mask_array) 53 | 54 | face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s)) 55 | body.paste(face_large, crop_box[:2], mask_image) 56 | body = np.array(body) 57 | return body[:,:,::-1] 58 | 59 | def get_image_prepare_material(fp_model,image,face_box,upper_boundary_ratio = 0.5,expand=1.2): 60 | body = Image.fromarray(image[:,:,::-1]) 61 | 62 | x, y, x1, y1 = face_box 63 | #print(x1-x,y1-y) 64 | crop_box, s = get_crop_box(face_box, expand) 65 | x_s, y_s, x_e, y_e = crop_box 66 | 67 | face_large = body.crop(crop_box) 68 | ori_shape = face_large.size 69 | 70 | mask_image = face_seg(fp_model,face_large) 71 | mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s)) 72 | mask_image = Image.new('L', ori_shape, 0) 73 | mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s)) 74 | 75 | # keep upper_boundary_ratio of talking area 76 | width, height = mask_image.size 77 | top_boundary = int(height * upper_boundary_ratio) 78 | modified_mask_image = Image.new('L', ori_shape, 0) 79 | modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary)) 80 | 81 | blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1 82 | mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0) 83 | return mask_array,crop_box 84 | 85 | def 
get_image_blending(image,face,face_box,mask_array,crop_box): 86 | body = Image.fromarray(image[:,:,::-1]) 87 | face = Image.fromarray(face[:,:,::-1]) 88 | 89 | x, y, x1, y1 = face_box 90 | x_s, y_s, x_e, y_e = crop_box 91 | face_large = body.crop(crop_box) 92 | 93 | mask_image = Image.fromarray(mask_array) 94 | mask_image = mask_image.convert("L") 95 | face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s)) 96 | body.paste(face_large, crop_box[:2], mask_image) 97 | body = np.array(body) 98 | return body[:,:,::-1] -------------------------------------------------------------------------------- /musetalk/utils/dwpose/default_runtime.py: -------------------------------------------------------------------------------- 1 | default_scope = 'mmpose' 2 | 3 | # hooks 4 | default_hooks = dict( 5 | timer=dict(type='IterTimerHook'), 6 | logger=dict(type='LoggerHook', interval=50), 7 | param_scheduler=dict(type='ParamSchedulerHook'), 8 | checkpoint=dict(type='CheckpointHook', interval=10), 9 | sampler_seed=dict(type='DistSamplerSeedHook'), 10 | visualization=dict(type='PoseVisualizationHook', enable=False), 11 | badcase=dict( 12 | type='BadCaseAnalysisHook', 13 | enable=False, 14 | out_dir='badcase', 15 | metric_type='loss', 16 | badcase_thr=5)) 17 | 18 | # custom hooks 19 | custom_hooks = [ 20 | # Synchronize model buffers such as running_mean and running_var in BN 21 | # at the end of each epoch 22 | dict(type='SyncBuffersHook') 23 | ] 24 | 25 | # multi-processing backend 26 | env_cfg = dict( 27 | cudnn_benchmark=False, 28 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), 29 | dist_cfg=dict(backend='nccl'), 30 | ) 31 | 32 | # visualizer 33 | vis_backends = [ 34 | dict(type='LocalVisBackend'), 35 | # dict(type='TensorboardVisBackend'), 36 | # dict(type='WandbVisBackend'), 37 | ] 38 | visualizer = dict( 39 | type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer') 40 | 41 | # logger 42 | log_processor = dict( 43 | type='LogProcessor', window_size=50, by_epoch=True, num_digits=6) 44 | log_level = 'INFO' 45 | load_from = None 46 | resume = False 47 | 48 | # file I/O backend 49 | backend_args = dict(backend='local') 50 | 51 | # training/validation/testing progress 52 | train_cfg = dict(by_epoch=True) 53 | val_cfg = dict() 54 | test_cfg = dict() 55 | -------------------------------------------------------------------------------- /musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py: -------------------------------------------------------------------------------- 1 | #_base_ = ['../../../_base_/default_runtime.py'] 2 | _base_ = ['default_runtime.py'] 3 | 4 | # runtime 5 | max_epochs = 270 6 | stage2_num_epochs = 30 7 | base_lr = 4e-3 8 | train_batch_size = 32 9 | val_batch_size = 32 10 | 11 | train_cfg = dict(max_epochs=max_epochs, val_interval=10) 12 | randomness = dict(seed=21) 13 | 14 | # optimizer 15 | optim_wrapper = dict( 16 | type='OptimWrapper', 17 | optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05), 18 | paramwise_cfg=dict( 19 | norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) 20 | 21 | # learning rate 22 | param_scheduler = [ 23 | dict( 24 | type='LinearLR', 25 | start_factor=1.0e-5, 26 | by_epoch=False, 27 | begin=0, 28 | end=1000), 29 | dict( 30 | # use cosine lr from 150 to 300 epoch 31 | type='CosineAnnealingLR', 32 | eta_min=base_lr * 0.05, 33 | begin=max_epochs // 2, 34 | end=max_epochs, 35 | T_max=max_epochs // 2, 36 | by_epoch=True, 37 | convert_to_iter_based=True), 38 | ] 39 | 40 | # automatically scaling LR based on 
the actual training batch size 41 | auto_scale_lr = dict(base_batch_size=512) 42 | 43 | # codec settings 44 | codec = dict( 45 | type='SimCCLabel', 46 | input_size=(288, 384), 47 | sigma=(6., 6.93), 48 | simcc_split_ratio=2.0, 49 | normalize=False, 50 | use_dark=False) 51 | 52 | # model settings 53 | model = dict( 54 | type='TopdownPoseEstimator', 55 | data_preprocessor=dict( 56 | type='PoseDataPreprocessor', 57 | mean=[123.675, 116.28, 103.53], 58 | std=[58.395, 57.12, 57.375], 59 | bgr_to_rgb=True), 60 | backbone=dict( 61 | _scope_='mmdet', 62 | type='CSPNeXt', 63 | arch='P5', 64 | expand_ratio=0.5, 65 | deepen_factor=1., 66 | widen_factor=1., 67 | out_indices=(4, ), 68 | channel_attention=True, 69 | norm_cfg=dict(type='SyncBN'), 70 | act_cfg=dict(type='SiLU'), 71 | init_cfg=dict( 72 | type='Pretrained', 73 | prefix='backbone.', 74 | checkpoint='https://download.openmmlab.com/mmpose/v1/projects/' 75 | 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa: E501 76 | )), 77 | head=dict( 78 | type='RTMCCHead', 79 | in_channels=1024, 80 | out_channels=133, 81 | input_size=codec['input_size'], 82 | in_featuremap_size=(9, 12), 83 | simcc_split_ratio=codec['simcc_split_ratio'], 84 | final_layer_kernel_size=7, 85 | gau_cfg=dict( 86 | hidden_dims=256, 87 | s=128, 88 | expansion_factor=2, 89 | dropout_rate=0., 90 | drop_path=0., 91 | act_fn='SiLU', 92 | use_rel_bias=False, 93 | pos_enc=False), 94 | loss=dict( 95 | type='KLDiscretLoss', 96 | use_target_weight=True, 97 | beta=10., 98 | label_softmax=True), 99 | decoder=codec), 100 | test_cfg=dict(flip_test=True, )) 101 | 102 | # base dataset settings 103 | dataset_type = 'UBody2dDataset' 104 | data_mode = 'topdown' 105 | data_root = 'data/UBody/' 106 | 107 | backend_args = dict(backend='local') 108 | 109 | scenes = [ 110 | 'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow', 111 | 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing', 112 | 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference' 113 | ] 114 | 115 | train_datasets = [ 116 | dict( 117 | type='CocoWholeBodyDataset', 118 | data_root='data/coco/', 119 | data_mode=data_mode, 120 | ann_file='annotations/coco_wholebody_train_v1.0.json', 121 | data_prefix=dict(img='train2017/'), 122 | pipeline=[]) 123 | ] 124 | 125 | for scene in scenes: 126 | train_dataset = dict( 127 | type=dataset_type, 128 | data_root=data_root, 129 | data_mode=data_mode, 130 | ann_file=f'annotations/{scene}/train_annotations.json', 131 | data_prefix=dict(img='images/'), 132 | pipeline=[], 133 | sample_interval=10) 134 | train_datasets.append(train_dataset) 135 | 136 | # pipelines 137 | train_pipeline = [ 138 | dict(type='LoadImage', backend_args=backend_args), 139 | dict(type='GetBBoxCenterScale'), 140 | dict(type='RandomFlip', direction='horizontal'), 141 | dict(type='RandomHalfBody'), 142 | dict( 143 | type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=90), 144 | dict(type='TopdownAffine', input_size=codec['input_size']), 145 | dict(type='mmdet.YOLOXHSVRandomAug'), 146 | dict( 147 | type='Albumentation', 148 | transforms=[ 149 | dict(type='Blur', p=0.1), 150 | dict(type='MedianBlur', p=0.1), 151 | dict( 152 | type='CoarseDropout', 153 | max_holes=1, 154 | max_height=0.4, 155 | max_width=0.4, 156 | min_holes=1, 157 | min_height=0.2, 158 | min_width=0.2, 159 | p=1.0), 160 | ]), 161 | dict(type='GenerateTarget', encoder=codec), 162 | dict(type='PackPoseInputs') 163 | ] 164 | val_pipeline = [ 165 | dict(type='LoadImage', backend_args=backend_args), 166 | 
dict(type='GetBBoxCenterScale'), 167 | dict(type='TopdownAffine', input_size=codec['input_size']), 168 | dict(type='PackPoseInputs') 169 | ] 170 | 171 | train_pipeline_stage2 = [ 172 | dict(type='LoadImage', backend_args=backend_args), 173 | dict(type='GetBBoxCenterScale'), 174 | dict(type='RandomFlip', direction='horizontal'), 175 | dict(type='RandomHalfBody'), 176 | dict( 177 | type='RandomBBoxTransform', 178 | shift_factor=0., 179 | scale_factor=[0.5, 1.5], 180 | rotate_factor=90), 181 | dict(type='TopdownAffine', input_size=codec['input_size']), 182 | dict(type='mmdet.YOLOXHSVRandomAug'), 183 | dict( 184 | type='Albumentation', 185 | transforms=[ 186 | dict(type='Blur', p=0.1), 187 | dict(type='MedianBlur', p=0.1), 188 | dict( 189 | type='CoarseDropout', 190 | max_holes=1, 191 | max_height=0.4, 192 | max_width=0.4, 193 | min_holes=1, 194 | min_height=0.2, 195 | min_width=0.2, 196 | p=0.5), 197 | ]), 198 | dict(type='GenerateTarget', encoder=codec), 199 | dict(type='PackPoseInputs') 200 | ] 201 | 202 | # data loaders 203 | train_dataloader = dict( 204 | batch_size=train_batch_size, 205 | num_workers=10, 206 | persistent_workers=True, 207 | sampler=dict(type='DefaultSampler', shuffle=True), 208 | dataset=dict( 209 | type='CombinedDataset', 210 | metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), 211 | datasets=train_datasets, 212 | pipeline=train_pipeline, 213 | test_mode=False, 214 | )) 215 | 216 | val_dataloader = dict( 217 | batch_size=val_batch_size, 218 | num_workers=10, 219 | persistent_workers=True, 220 | drop_last=False, 221 | sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), 222 | dataset=dict( 223 | type='CocoWholeBodyDataset', 224 | data_root=data_root, 225 | data_mode=data_mode, 226 | ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json', 227 | bbox_file='data/coco/person_detection_results/' 228 | 'COCO_val2017_detections_AP_H_56_person.json', 229 | data_prefix=dict(img='coco/val2017/'), 230 | test_mode=True, 231 | pipeline=val_pipeline, 232 | )) 233 | test_dataloader = val_dataloader 234 | 235 | # hooks 236 | default_hooks = dict( 237 | checkpoint=dict( 238 | save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1)) 239 | 240 | custom_hooks = [ 241 | dict( 242 | type='EMAHook', 243 | ema_type='ExpMomentumEMA', 244 | momentum=0.0002, 245 | update_buffers=True, 246 | priority=49), 247 | dict( 248 | type='mmdet.PipelineSwitchHook', 249 | switch_epoch=max_epochs - stage2_num_epochs, 250 | switch_pipeline=train_pipeline_stage2) 251 | ] 252 | 253 | # evaluators 254 | val_evaluator = dict( 255 | type='CocoWholeBodyMetric', 256 | ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json') 257 | test_evaluator = val_evaluator 258 | -------------------------------------------------------------------------------- /musetalk/utils/face_detection/README.md: -------------------------------------------------------------------------------- 1 | The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time. 
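As a quick orientation for the code that follows, here is a minimal usage sketch — it is not part of the repository. It assumes the package is importable under the path shown (a placeholder), that the listed image files exist, and that same-sized BGR frames are stacked along the first axis, which is what `FaceAlignment.get_detections_for_batch` in `api.py` below expects.

```python
# Usage sketch (assumptions: placeholder import path and file names; CUDA available).
import cv2
import numpy as np
from musetalk.utils.face_detection import FaceAlignment, LandmarksType

detector = FaceAlignment(LandmarksType._2D, flip_input=False, device="cuda")

# get_detections_for_batch takes a batch of same-sized BGR frames stacked on axis 0
# and returns one (x1, y1, x2, y2) box, or None, per frame.
frames = np.stack([cv2.imread(p) for p in ["frame_000.png", "frame_001.png"]])
boxes = detector.get_detections_for_batch(frames)

for i, box in enumerate(boxes):
    if box is None:
        print(f"frame {i}: no face detected")
    else:
        x1, y1, x2, y2 = box
        print(f"frame {i}: face box ({x1}, {y1}) - ({x2}, {y2})")
```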
-------------------------------------------------------------------------------- /musetalk/utils/face_detection/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | __author__ = """Adrian Bulat""" 4 | __email__ = 'adrian.bulat@nottingham.ac.uk' 5 | __version__ = '1.0.1' 6 | 7 | from .api import FaceAlignment, LandmarksType, NetworkSize, YOLOv8_face 8 | -------------------------------------------------------------------------------- /musetalk/utils/face_detection/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/face_detection/__pycache__/api.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/__pycache__/api.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/face_detection/__pycache__/models.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/__pycache__/models.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/face_detection/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/face_detection/api.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import torch 4 | from torch.utils.model_zoo import load_url 5 | from enum import Enum 6 | import numpy as np 7 | import cv2 8 | try: 9 | import urllib.request as request_file 10 | except BaseException: 11 | import urllib as request_file 12 | 13 | from .models import FAN, ResNetDepth 14 | from .utils import * 15 | 16 | 17 | class LandmarksType(Enum): 18 | """Enum class defining the type of landmarks to detect. 
19 | 20 | ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face 21 | ``_2halfD`` - this points represent the projection of the 3D points into 3D 22 | ``_3D`` - detect the points ``(x,y,z)``` in a 3D space 23 | 24 | """ 25 | _2D = 1 26 | _2halfD = 2 27 | _3D = 3 28 | 29 | 30 | class NetworkSize(Enum): 31 | # TINY = 1 32 | # SMALL = 2 33 | # MEDIUM = 3 34 | LARGE = 4 35 | 36 | def __new__(cls, value): 37 | member = object.__new__(cls) 38 | member._value_ = value 39 | return member 40 | 41 | def __int__(self): 42 | return self.value 43 | 44 | 45 | 46 | class FaceAlignment: 47 | def __init__(self, landmarks_type, network_size=NetworkSize.LARGE, 48 | device='cuda', flip_input=False, face_detector='sfd', verbose=False): 49 | self.device = device 50 | self.flip_input = flip_input 51 | self.landmarks_type = landmarks_type 52 | self.verbose = verbose 53 | 54 | network_size = int(network_size) 55 | 56 | if 'cuda' in device: 57 | torch.backends.cudnn.benchmark = True 58 | # torch.backends.cuda.matmul.allow_tf32 = False 59 | # torch.backends.cudnn.benchmark = True 60 | # torch.backends.cudnn.deterministic = False 61 | # torch.backends.cudnn.allow_tf32 = True 62 | print('cuda start') 63 | 64 | 65 | # Get the face detector 66 | face_detector_module = __import__('face_detection.detection.' + face_detector, 67 | globals(), locals(), [face_detector], 0) 68 | 69 | self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose) 70 | 71 | def get_detections_for_batch(self, images): 72 | images = images[..., ::-1] 73 | detected_faces = self.face_detector.detect_from_batch(images.copy()) 74 | results = [] 75 | 76 | for i, d in enumerate(detected_faces): 77 | if len(d) == 0: 78 | results.append(None) 79 | continue 80 | d = d[0] 81 | d = np.clip(d, 0, None) 82 | 83 | x1, y1, x2, y2 = map(int, d[:-1]) 84 | results.append((x1, y1, x2, y2)) 85 | 86 | return results 87 | 88 | 89 | class YOLOv8_face: 90 | def __init__(self, path = 'face_detection/weights/yolov8n-face.onnx', conf_thres=0.2, iou_thres=0.5): 91 | self.conf_threshold = conf_thres 92 | self.iou_threshold = iou_thres 93 | self.class_names = ['face'] 94 | self.num_classes = len(self.class_names) 95 | # Initialize model 96 | self.net = cv2.dnn.readNet(path) 97 | self.input_height = 640 98 | self.input_width = 640 99 | self.reg_max = 16 100 | 101 | self.project = np.arange(self.reg_max) 102 | self.strides = (8, 16, 32) 103 | self.feats_hw = [(math.ceil(self.input_height / self.strides[i]), math.ceil(self.input_width / self.strides[i])) for i in range(len(self.strides))] 104 | self.anchors = self.make_anchors(self.feats_hw) 105 | 106 | def make_anchors(self, feats_hw, grid_cell_offset=0.5): 107 | """Generate anchors from features.""" 108 | anchor_points = {} 109 | for i, stride in enumerate(self.strides): 110 | h,w = feats_hw[i] 111 | x = np.arange(0, w) + grid_cell_offset # shift x 112 | y = np.arange(0, h) + grid_cell_offset # shift y 113 | sx, sy = np.meshgrid(x, y) 114 | # sy, sx = np.meshgrid(y, x) 115 | anchor_points[stride] = np.stack((sx, sy), axis=-1).reshape(-1, 2) 116 | return anchor_points 117 | 118 | def softmax(self, x, axis=1): 119 | x_exp = np.exp(x) 120 | # 如果是列向量,则axis=0 121 | x_sum = np.sum(x_exp, axis=axis, keepdims=True) 122 | s = x_exp / x_sum 123 | return s 124 | 125 | def resize_image(self, srcimg, keep_ratio=True): 126 | top, left, newh, neww = 0, 0, self.input_width, self.input_height 127 | if keep_ratio and srcimg.shape[0] != srcimg.shape[1]: 128 | hw_scale 
= srcimg.shape[0] / srcimg.shape[1] 129 | if hw_scale > 1: 130 | newh, neww = self.input_height, int(self.input_width / hw_scale) 131 | img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA) 132 | left = int((self.input_width - neww) * 0.5) 133 | img = cv2.copyMakeBorder(img, 0, 0, left, self.input_width - neww - left, cv2.BORDER_CONSTANT, 134 | value=(0, 0, 0)) # add border 135 | else: 136 | newh, neww = int(self.input_height * hw_scale), self.input_width 137 | img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA) 138 | top = int((self.input_height - newh) * 0.5) 139 | img = cv2.copyMakeBorder(img, top, self.input_height - newh - top, 0, 0, cv2.BORDER_CONSTANT, 140 | value=(0, 0, 0)) 141 | else: 142 | img = cv2.resize(srcimg, (self.input_width, self.input_height), interpolation=cv2.INTER_AREA) 143 | return img, newh, neww, top, left 144 | 145 | def detect(self, srcimg): 146 | input_img, newh, neww, padh, padw = self.resize_image(cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB)) 147 | scale_h, scale_w = srcimg.shape[0]/newh, srcimg.shape[1]/neww 148 | input_img = input_img.astype(np.float32) / 255.0 149 | 150 | blob = cv2.dnn.blobFromImage(input_img) 151 | self.net.setInput(blob) 152 | outputs = self.net.forward(self.net.getUnconnectedOutLayersNames()) 153 | # if isinstance(outputs, tuple): 154 | # outputs = list(outputs) 155 | # if float(cv2.__version__[:3])>=4.7: 156 | # outputs = [outputs[2], outputs[0], outputs[1]] ###opencv4.7需要这一步,opencv4.5不需要 157 | # Perform inference on the image 158 | det_bboxes, det_conf, det_classid, landmarks = self.post_process(outputs, scale_h, scale_w, padh, padw) 159 | return det_bboxes, det_conf, det_classid, landmarks 160 | 161 | def post_process(self, preds, scale_h, scale_w, padh, padw): 162 | bboxes, scores, landmarks = [], [], [] 163 | for i, pred in enumerate(preds): 164 | stride = int(self.input_height/pred.shape[2]) 165 | pred = pred.transpose((0, 2, 3, 1)) 166 | 167 | box = pred[..., :self.reg_max * 4] 168 | cls = 1 / (1 + np.exp(-pred[..., self.reg_max * 4:-15])).reshape((-1,1)) 169 | kpts = pred[..., -15:].reshape((-1,15)) ### x1,y1,score1, ..., x5,y5,score5 170 | 171 | # tmp = box.reshape(self.feats_hw[i][0], self.feats_hw[i][1], 4, self.reg_max) 172 | tmp = box.reshape(-1, 4, self.reg_max) 173 | bbox_pred = self.softmax(tmp, axis=-1) 174 | bbox_pred = np.dot(bbox_pred, self.project).reshape((-1,4)) 175 | 176 | bbox = self.distance2bbox(self.anchors[stride], bbox_pred, max_shape=(self.input_height, self.input_width)) * stride 177 | kpts[:, 0::3] = (kpts[:, 0::3] * 2.0 + (self.anchors[stride][:, 0].reshape((-1,1)) - 0.5)) * stride 178 | kpts[:, 1::3] = (kpts[:, 1::3] * 2.0 + (self.anchors[stride][:, 1].reshape((-1,1)) - 0.5)) * stride 179 | kpts[:, 2::3] = 1 / (1+np.exp(-kpts[:, 2::3])) 180 | 181 | bbox -= np.array([[padw, padh, padw, padh]]) ###合理使用广播法则 182 | bbox *= np.array([[scale_w, scale_h, scale_w, scale_h]]) 183 | kpts -= np.tile(np.array([padw, padh, 0]), 5).reshape((1,15)) 184 | kpts *= np.tile(np.array([scale_w, scale_h, 1]), 5).reshape((1,15)) 185 | 186 | bboxes.append(bbox) 187 | scores.append(cls) 188 | landmarks.append(kpts) 189 | 190 | bboxes = np.concatenate(bboxes, axis=0) 191 | scores = np.concatenate(scores, axis=0) 192 | landmarks = np.concatenate(landmarks, axis=0) 193 | 194 | bboxes_wh = bboxes.copy() 195 | bboxes_wh[:, 2:4] = bboxes[:, 2:4] - bboxes[:, 0:2] ####xywh 196 | classIds = np.argmax(scores, axis=1) 197 | confidences = np.max(scores, axis=1) ####max_class_confidence 198 | 199 | mask = 
confidences>self.conf_threshold 200 | bboxes_wh = bboxes_wh[mask] ###合理使用广播法则 201 | confidences = confidences[mask] 202 | classIds = classIds[mask] 203 | landmarks = landmarks[mask] 204 | 205 | indices = cv2.dnn.NMSBoxes(bboxes_wh.tolist(), confidences.tolist(), self.conf_threshold, 206 | self.iou_threshold).flatten() 207 | if len(indices) > 0: 208 | mlvl_bboxes = bboxes_wh[indices] 209 | confidences = confidences[indices] 210 | classIds = classIds[indices] 211 | landmarks = landmarks[indices] 212 | return mlvl_bboxes, confidences, classIds, landmarks 213 | else: 214 | print('nothing detect') 215 | return np.array([]), np.array([]), np.array([]), np.array([]) 216 | 217 | def distance2bbox(self, points, distance, max_shape=None): 218 | x1 = points[:, 0] - distance[:, 0] 219 | y1 = points[:, 1] - distance[:, 1] 220 | x2 = points[:, 0] + distance[:, 2] 221 | y2 = points[:, 1] + distance[:, 3] 222 | if max_shape is not None: 223 | x1 = np.clip(x1, 0, max_shape[1]) 224 | y1 = np.clip(y1, 0, max_shape[0]) 225 | x2 = np.clip(x2, 0, max_shape[1]) 226 | y2 = np.clip(y2, 0, max_shape[0]) 227 | return np.stack([x1, y1, x2, y2], axis=-1) 228 | 229 | def draw_detections(self, image, boxes, scores, kpts): 230 | for box, score, kp in zip(boxes, scores, kpts): 231 | x, y, w, h = box.astype(int) 232 | # Draw rectangle 233 | cv2.rectangle(image, (x, y), (x + w, y + h), (0, 0, 255), thickness=3) 234 | cv2.putText(image, "face:"+str(round(score,2)), (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), thickness=2) 235 | for i in range(5): 236 | cv2.circle(image, (int(kp[i * 3]), int(kp[i * 3 + 1])), 4, (0, 255, 0), thickness=-1) 237 | # cv2.putText(image, str(i), (int(kp[i * 3]), int(kp[i * 3 + 1]) - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=1) 238 | return image 239 | 240 | ROOT = os.path.dirname(os.path.abspath(__file__)) -------------------------------------------------------------------------------- /musetalk/utils/face_detection/detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import FaceDetector -------------------------------------------------------------------------------- /musetalk/utils/face_detection/detection/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/detection/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/face_detection/detection/__pycache__/core.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/detection/__pycache__/core.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/face_detection/detection/core.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import glob 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | import cv2 7 | 8 | 9 | class FaceDetector(object): 10 | """An abstract class representing a face detector. 11 | 12 | Any other face detection implementation must subclass it. All subclasses 13 | must implement ``detect_from_image``, that return a list of detected 14 | bounding boxes. 
Optionally, for speed considerations detect from path is 15 | recommended. 16 | """ 17 | 18 | def __init__(self, device, verbose): 19 | self.device = device 20 | self.verbose = verbose 21 | 22 | if verbose: 23 | if 'cpu' in device: 24 | logger = logging.getLogger(__name__) 25 | logger.warning("Detection running on CPU, this may be potentially slow.") 26 | 27 | if 'cpu' not in device and 'cuda' not in device: 28 | if verbose: 29 | logger.error("Expected values for device are: {cpu, cuda} but got: %s", device) 30 | raise ValueError 31 | 32 | def detect_from_image(self, tensor_or_path): 33 | """Detects faces in a given image. 34 | 35 | This function detects the faces present in a provided BGR(usually) 36 | image. The input can be either the image itself or the path to it. 37 | 38 | Arguments: 39 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path 40 | to an image or the image itself. 41 | 42 | Example:: 43 | 44 | >>> path_to_image = 'data/image_01.jpg' 45 | ... detected_faces = detect_from_image(path_to_image) 46 | [A list of bounding boxes (x1, y1, x2, y2)] 47 | >>> image = cv2.imread(path_to_image) 48 | ... detected_faces = detect_from_image(image) 49 | [A list of bounding boxes (x1, y1, x2, y2)] 50 | 51 | """ 52 | raise NotImplementedError 53 | 54 | def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True): 55 | """Detects faces from all the images present in a given directory. 56 | 57 | Arguments: 58 | path {string} -- a string containing a path that points to the folder containing the images 59 | 60 | Keyword Arguments: 61 | extensions {list} -- list of string containing the extensions to be 62 | consider in the following format: ``.extension_name`` (default: 63 | {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the 64 | folder recursively (default: {False}) show_progress_bar {bool} -- 65 | display a progressbar (default: {True}) 66 | 67 | Example: 68 | >>> directory = 'data' 69 | ... detected_faces = detect_from_directory(directory) 70 | {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]} 71 | 72 | """ 73 | if self.verbose: 74 | logger = logging.getLogger(__name__) 75 | 76 | if len(extensions) == 0: 77 | if self.verbose: 78 | logger.error("Expected at list one extension, but none was received.") 79 | raise ValueError 80 | 81 | if self.verbose: 82 | logger.info("Constructing the list of images.") 83 | additional_pattern = '/**/*' if recursive else '/*' 84 | files = [] 85 | for extension in extensions: 86 | files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive)) 87 | 88 | if self.verbose: 89 | logger.info("Finished searching for images. 
%s images found", len(files)) 90 | logger.info("Preparing to run the detection.") 91 | 92 | predictions = {} 93 | for image_path in tqdm(files, disable=not show_progress_bar): 94 | if self.verbose: 95 | logger.info("Running the face detector on image: %s", image_path) 96 | predictions[image_path] = self.detect_from_image(image_path) 97 | 98 | if self.verbose: 99 | logger.info("The detector was successfully run on all %s images", len(files)) 100 | 101 | return predictions 102 | 103 | @property 104 | def reference_scale(self): 105 | raise NotImplementedError 106 | 107 | @property 108 | def reference_x_shift(self): 109 | raise NotImplementedError 110 | 111 | @property 112 | def reference_y_shift(self): 113 | raise NotImplementedError 114 | 115 | @staticmethod 116 | def tensor_or_path_to_ndarray(tensor_or_path, rgb=True): 117 | """Convert path (represented as a string) or torch.tensor to a numpy.ndarray 118 | 119 | Arguments: 120 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself 121 | """ 122 | if isinstance(tensor_or_path, str): 123 | return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1] 124 | elif torch.is_tensor(tensor_or_path): 125 | # Call cpu in case its coming from cuda 126 | return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy() 127 | elif isinstance(tensor_or_path, np.ndarray): 128 | return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path 129 | else: 130 | raise TypeError 131 | -------------------------------------------------------------------------------- /musetalk/utils/face_detection/detection/sfd/__init__.py: -------------------------------------------------------------------------------- 1 | from .sfd_detector import SFDDetector as FaceDetector -------------------------------------------------------------------------------- /musetalk/utils/face_detection/detection/sfd/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/detection/sfd/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/face_detection/detection/sfd/__pycache__/bbox.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/detection/sfd/__pycache__/bbox.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/face_detection/detection/sfd/__pycache__/detect.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/detection/sfd/__pycache__/detect.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/face_detection/detection/sfd/__pycache__/net_s3fd.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/detection/sfd/__pycache__/net_s3fd.cpython-310.pyc 
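The face_detection pieces shown above fit together as a small pipeline: `FaceAlignment` (api.py) builds a concrete `FaceDetector` by module name, the default `'sfd'` resolving to the S3FD detector defined below, and `get_detections_for_batch` turns a batch of BGR frames into one `(x1, y1, x2, y2)` box per frame, or `None` where no face is found. Below is a minimal usage sketch, assuming `musetalk/utils` is on `sys.path` so the package imports as `face_detection`, and assuming the enum opening api.py is named `LandmarksType` as in the upstream face_alignment project; the frame path is hypothetical. If `sfd/s3fd.pth` is absent, `SFDDetector` downloads the weights on first use via `load_url`.

```python
import sys
import cv2
import numpy as np

sys.path.append("musetalk/utils")  # assumption: run from the repository root

# assumption: the enum defined at the top of api.py is named LandmarksType
from face_detection.api import FaceAlignment, LandmarksType

# Only the detector inside FaceAlignment is exercised here; landmarks are not predicted.
fa = FaceAlignment(LandmarksType._2D, flip_input=False, device="cuda")  # or "cpu"

frames = [cv2.imread("assets/frame_0000.png")]  # hypothetical BGR frame(s)
batch = np.stack(frames)                        # shape (N, H, W, 3), as the API expects

for i, bbox in enumerate(fa.get_detections_for_batch(batch)):
    if bbox is None:
        print(f"frame {i}: no face detected")
    else:
        x1, y1, x2, y2 = bbox
        print(f"frame {i}: face box ({x1}, {y1}) - ({x2}, {y2})")
```

The ONNX-based `YOLOv8_face` class above serves the same purpose through OpenCV's DNN module: `detect()` returns boxes, confidences, class ids and five-point landmarks mapped back to original image coordinates, and `draw_detections()` overlays them for a quick visual check.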
-------------------------------------------------------------------------------- /musetalk/utils/face_detection/detection/sfd/__pycache__/sfd_detector.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/detection/sfd/__pycache__/sfd_detector.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/face_detection/detection/sfd/bbox.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import cv2 5 | import random 6 | import datetime 7 | import time 8 | import math 9 | import argparse 10 | import numpy as np 11 | import torch 12 | 13 | try: 14 | from iou import IOU 15 | except BaseException: 16 | # IOU cython speedup 10x 17 | def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2): 18 | sa = abs((ax2 - ax1) * (ay2 - ay1)) 19 | sb = abs((bx2 - bx1) * (by2 - by1)) 20 | x1, y1 = max(ax1, bx1), max(ay1, by1) 21 | x2, y2 = min(ax2, bx2), min(ay2, by2) 22 | w = x2 - x1 23 | h = y2 - y1 24 | if w < 0 or h < 0: 25 | return 0.0 26 | else: 27 | return 1.0 * w * h / (sa + sb - w * h) 28 | 29 | 30 | def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh): 31 | xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1 32 | dx, dy = (xc - axc) / aww, (yc - ayc) / ahh 33 | dw, dh = math.log(ww / aww), math.log(hh / ahh) 34 | return dx, dy, dw, dh 35 | 36 | 37 | def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh): 38 | xc, yc = dx * aww + axc, dy * ahh + ayc 39 | ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh 40 | x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2 41 | return x1, y1, x2, y2 42 | 43 | 44 | def nms(dets, thresh): 45 | if 0 == len(dets): 46 | return [] 47 | x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4] 48 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 49 | order = scores.argsort()[::-1] 50 | 51 | keep = [] 52 | while order.size > 0: 53 | i = order[0] 54 | keep.append(i) 55 | xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]]) 56 | xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]]) 57 | 58 | w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1) 59 | ovr = w * h / (areas[i] + areas[order[1:]] - w * h) 60 | 61 | inds = np.where(ovr <= thresh)[0] 62 | order = order[inds + 1] 63 | 64 | return keep 65 | 66 | 67 | def encode(matched, priors, variances): 68 | """Encode the variances from the priorbox layers into the ground truth boxes 69 | we have matched (based on jaccard overlap) with the prior boxes. 70 | Args: 71 | matched: (tensor) Coords of ground truth for each prior in point-form 72 | Shape: [num_priors, 4]. 73 | priors: (tensor) Prior boxes in center-offset form 74 | Shape: [num_priors,4]. 
75 | variances: (list[float]) Variances of priorboxes 76 | Return: 77 | encoded boxes (tensor), Shape: [num_priors, 4] 78 | """ 79 | 80 | # dist b/t match center and prior's center 81 | g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] 82 | # encode variance 83 | g_cxcy /= (variances[0] * priors[:, 2:]) 84 | # match wh / prior wh 85 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] 86 | g_wh = torch.log(g_wh) / variances[1] 87 | # return target for smooth_l1_loss 88 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] 89 | 90 | 91 | def decode(loc, priors, variances): 92 | """Decode locations from predictions using priors to undo 93 | the encoding we did for offset regression at train time. 94 | Args: 95 | loc (tensor): location predictions for loc layers, 96 | Shape: [num_priors,4] 97 | priors (tensor): Prior boxes in center-offset form. 98 | Shape: [num_priors,4]. 99 | variances: (list[float]) Variances of priorboxes 100 | Return: 101 | decoded bounding box predictions 102 | """ 103 | 104 | boxes = torch.cat(( 105 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], 106 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) 107 | boxes[:, :2] -= boxes[:, 2:] / 2 108 | boxes[:, 2:] += boxes[:, :2] 109 | return boxes 110 | 111 | def batch_decode(loc, priors, variances): 112 | """Decode locations from predictions using priors to undo 113 | the encoding we did for offset regression at train time. 114 | Args: 115 | loc (tensor): location predictions for loc layers, 116 | Shape: [num_priors,4] 117 | priors (tensor): Prior boxes in center-offset form. 118 | Shape: [num_priors,4]. 119 | variances: (list[float]) Variances of priorboxes 120 | Return: 121 | decoded bounding box predictions 122 | """ 123 | 124 | boxes = torch.cat(( 125 | priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:], 126 | priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2) 127 | boxes[:, :, :2] -= boxes[:, :, 2:] / 2 128 | boxes[:, :, 2:] += boxes[:, :, :2] 129 | return boxes 130 | -------------------------------------------------------------------------------- /musetalk/utils/face_detection/detection/sfd/detect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import os 5 | import sys 6 | import cv2 7 | import random 8 | import datetime 9 | import math 10 | import argparse 11 | import numpy as np 12 | 13 | import scipy.io as sio 14 | import zipfile 15 | from .net_s3fd import s3fd 16 | from .bbox import * 17 | 18 | 19 | def detect(net, img, device): 20 | img = img - np.array([104, 117, 123]) 21 | img = img.transpose(2, 0, 1) 22 | img = img.reshape((1,) + img.shape) 23 | 24 | if 'cuda' in device: 25 | torch.backends.cudnn.benchmark = True 26 | 27 | img = torch.from_numpy(img).float().to(device) 28 | BB, CC, HH, WW = img.size() 29 | with torch.no_grad(): 30 | olist = net(img) 31 | 32 | bboxlist = [] 33 | for i in range(len(olist) // 2): 34 | olist[i * 2] = F.softmax(olist[i * 2], dim=1) 35 | olist = [oelem.data.cpu() for oelem in olist] 36 | for i in range(len(olist) // 2): 37 | ocls, oreg = olist[i * 2], olist[i * 2 + 1] 38 | FB, FC, FH, FW = ocls.size() # feature map size 39 | stride = 2**(i + 2) # 4,8,16,32,64,128 40 | anchor = stride * 4 41 | poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) 42 | for Iindex, hindex, windex in poss: 43 | axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride 44 | score = ocls[0, 1, hindex, windex] 45 | loc = oreg[0, :, 
hindex, windex].contiguous().view(1, 4) 46 | priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]) 47 | variances = [0.1, 0.2] 48 | box = decode(loc, priors, variances) 49 | x1, y1, x2, y2 = box[0] * 1.0 50 | # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1) 51 | bboxlist.append([x1, y1, x2, y2, score]) 52 | bboxlist = np.array(bboxlist) 53 | if 0 == len(bboxlist): 54 | bboxlist = np.zeros((1, 5)) 55 | 56 | return bboxlist 57 | 58 | def batch_detect(net, imgs, device): 59 | imgs = imgs - np.array([104, 117, 123]) 60 | imgs = imgs.transpose(0, 3, 1, 2) 61 | 62 | if 'cuda' in device: 63 | torch.backends.cudnn.benchmark = True 64 | 65 | imgs = torch.from_numpy(imgs).float().to(device) 66 | BB, CC, HH, WW = imgs.size() 67 | with torch.no_grad(): 68 | olist = net(imgs) 69 | # print(olist) 70 | 71 | bboxlist = [] 72 | for i in range(len(olist) // 2): 73 | olist[i * 2] = F.softmax(olist[i * 2], dim=1) 74 | 75 | olist = [oelem.cpu() for oelem in olist] 76 | for i in range(len(olist) // 2): 77 | ocls, oreg = olist[i * 2], olist[i * 2 + 1] 78 | FB, FC, FH, FW = ocls.size() # feature map size 79 | stride = 2**(i + 2) # 4,8,16,32,64,128 80 | anchor = stride * 4 81 | poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) 82 | for Iindex, hindex, windex in poss: 83 | axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride 84 | score = ocls[:, 1, hindex, windex] 85 | loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4) 86 | priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4) 87 | variances = [0.1, 0.2] 88 | box = batch_decode(loc, priors, variances) 89 | box = box[:, 0] * 1.0 90 | # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1) 91 | bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy()) 92 | bboxlist = np.array(bboxlist) 93 | if 0 == len(bboxlist): 94 | bboxlist = np.zeros((1, BB, 5)) 95 | 96 | return bboxlist 97 | 98 | def flip_detect(net, img, device): 99 | img = cv2.flip(img, 1) 100 | b = detect(net, img, device) 101 | 102 | bboxlist = np.zeros(b.shape) 103 | bboxlist[:, 0] = img.shape[1] - b[:, 2] 104 | bboxlist[:, 1] = b[:, 1] 105 | bboxlist[:, 2] = img.shape[1] - b[:, 0] 106 | bboxlist[:, 3] = b[:, 3] 107 | bboxlist[:, 4] = b[:, 4] 108 | return bboxlist 109 | 110 | 111 | def pts_to_bb(pts): 112 | min_x, min_y = np.min(pts, axis=0) 113 | max_x, max_y = np.max(pts, axis=0) 114 | return np.array([min_x, min_y, max_x, max_y]) 115 | -------------------------------------------------------------------------------- /musetalk/utils/face_detection/detection/sfd/net_s3fd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class L2Norm(nn.Module): 7 | def __init__(self, n_channels, scale=1.0): 8 | super(L2Norm, self).__init__() 9 | self.n_channels = n_channels 10 | self.scale = scale 11 | self.eps = 1e-10 12 | self.weight = nn.Parameter(torch.Tensor(self.n_channels)) 13 | self.weight.data *= 0.0 14 | self.weight.data += self.scale 15 | 16 | def forward(self, x): 17 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps 18 | x = x / norm * self.weight.view(1, -1, 1, 1) 19 | return x 20 | 21 | 22 | class s3fd(nn.Module): 23 | def __init__(self): 24 | super(s3fd, self).__init__() 25 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 26 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 27 | 28 | 
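        # The layers below lay out the S3FD (Single Shot Scale-invariant Face Detector)
        # architecture: a VGG16-style backbone (conv1_1 .. conv5_3, with fc6/fc7 kept as
        # convolutions and extra conv6/conv7 blocks), L2Norm layers that rescale the
        # conv3_3, conv4_3 and conv5_3 feature maps, and one (conf, loc) prediction head
        # per detection scale. conv3_3's confidence head has 4 output channels so that
        # forward() can apply max-out background label selection on the finest scale.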
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) 29 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 30 | 31 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) 32 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 33 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 34 | 35 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) 36 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 37 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 38 | 39 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 40 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 41 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 42 | 43 | self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3) 44 | self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0) 45 | 46 | self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) 47 | self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) 48 | 49 | self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0) 50 | self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) 51 | 52 | self.conv3_3_norm = L2Norm(256, scale=10) 53 | self.conv4_3_norm = L2Norm(512, scale=8) 54 | self.conv5_3_norm = L2Norm(512, scale=5) 55 | 56 | self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 57 | self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 58 | self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 59 | self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 60 | self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 61 | self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 62 | 63 | self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1) 64 | self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1) 65 | self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 66 | self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 67 | self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) 68 | self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 69 | 70 | def forward(self, x): 71 | h = F.relu(self.conv1_1(x)) 72 | h = F.relu(self.conv1_2(h)) 73 | h = F.max_pool2d(h, 2, 2) 74 | 75 | h = F.relu(self.conv2_1(h)) 76 | h = F.relu(self.conv2_2(h)) 77 | h = F.max_pool2d(h, 2, 2) 78 | 79 | h = F.relu(self.conv3_1(h)) 80 | h = F.relu(self.conv3_2(h)) 81 | h = F.relu(self.conv3_3(h)) 82 | f3_3 = h 83 | h = F.max_pool2d(h, 2, 2) 84 | 85 | h = F.relu(self.conv4_1(h)) 86 | h = F.relu(self.conv4_2(h)) 87 | h = F.relu(self.conv4_3(h)) 88 | f4_3 = h 89 | h = F.max_pool2d(h, 2, 2) 90 | 91 | h = F.relu(self.conv5_1(h)) 92 | h = F.relu(self.conv5_2(h)) 93 | h = F.relu(self.conv5_3(h)) 94 | f5_3 = h 95 | h = F.max_pool2d(h, 2, 2) 96 | 97 | h = F.relu(self.fc6(h)) 98 | h = F.relu(self.fc7(h)) 99 | ffc7 = h 100 | h = F.relu(self.conv6_1(h)) 101 | h = F.relu(self.conv6_2(h)) 102 | f6_2 = h 103 | h = F.relu(self.conv7_1(h)) 104 | h = F.relu(self.conv7_2(h)) 105 | f7_2 = h 106 | 107 | f3_3 = self.conv3_3_norm(f3_3) 108 | f4_3 = self.conv4_3_norm(f4_3) 109 | f5_3 = self.conv5_3_norm(f5_3) 110 | 111 | cls1 = 
self.conv3_3_norm_mbox_conf(f3_3) 112 | reg1 = self.conv3_3_norm_mbox_loc(f3_3) 113 | cls2 = self.conv4_3_norm_mbox_conf(f4_3) 114 | reg2 = self.conv4_3_norm_mbox_loc(f4_3) 115 | cls3 = self.conv5_3_norm_mbox_conf(f5_3) 116 | reg3 = self.conv5_3_norm_mbox_loc(f5_3) 117 | cls4 = self.fc7_mbox_conf(ffc7) 118 | reg4 = self.fc7_mbox_loc(ffc7) 119 | cls5 = self.conv6_2_mbox_conf(f6_2) 120 | reg5 = self.conv6_2_mbox_loc(f6_2) 121 | cls6 = self.conv7_2_mbox_conf(f7_2) 122 | reg6 = self.conv7_2_mbox_loc(f7_2) 123 | 124 | # max-out background label 125 | chunk = torch.chunk(cls1, 4, 1) 126 | bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2]) 127 | cls1 = torch.cat([bmax, chunk[3]], dim=1) 128 | 129 | return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6] 130 | -------------------------------------------------------------------------------- /musetalk/utils/face_detection/detection/sfd/sfd_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from torch.utils.model_zoo import load_url 4 | 5 | from ..core import FaceDetector 6 | 7 | from .net_s3fd import s3fd 8 | from .bbox import * 9 | from .detect import * 10 | 11 | models_urls = { 12 | 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth', 13 | } 14 | 15 | 16 | class SFDDetector(FaceDetector): 17 | def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False): 18 | super(SFDDetector, self).__init__(device, verbose) 19 | 20 | # Initialise the face detector 21 | if not os.path.isfile(path_to_detector): 22 | model_weights = load_url(models_urls['s3fd']) 23 | else: 24 | model_weights = torch.load(path_to_detector) 25 | 26 | self.face_detector = s3fd() 27 | self.face_detector.load_state_dict(model_weights) 28 | self.face_detector.to(device) 29 | self.face_detector.eval() 30 | 31 | def detect_from_image(self, tensor_or_path): 32 | image = self.tensor_or_path_to_ndarray(tensor_or_path) 33 | 34 | bboxlist = detect(self.face_detector, image, device=self.device) 35 | keep = nms(bboxlist, 0.3) 36 | bboxlist = bboxlist[keep, :] 37 | bboxlist = [x for x in bboxlist if x[-1] > 0.5] 38 | 39 | return bboxlist 40 | 41 | def detect_from_batch(self, images): 42 | bboxlists = batch_detect(self.face_detector, images, device=self.device) 43 | keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])] 44 | bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)] 45 | bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists] 46 | 47 | return bboxlists 48 | 49 | @property 50 | def reference_scale(self): 51 | return 195 52 | 53 | @property 54 | def reference_x_shift(self): 55 | return 0 56 | 57 | @property 58 | def reference_y_shift(self): 59 | return 0 60 | -------------------------------------------------------------------------------- /musetalk/utils/face_detection/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, 10 | stride=strd, padding=padding, bias=bias) 11 | 12 | 13 | class ConvBlock(nn.Module): 14 | def __init__(self, in_planes, out_planes): 15 | super(ConvBlock, self).__init__() 16 | self.bn1 = nn.BatchNorm2d(in_planes) 
17 | self.conv1 = conv3x3(in_planes, int(out_planes / 2)) 18 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) 19 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) 20 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) 21 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4)) 22 | 23 | if in_planes != out_planes: 24 | self.downsample = nn.Sequential( 25 | nn.BatchNorm2d(in_planes), 26 | nn.ReLU(True), 27 | nn.Conv2d(in_planes, out_planes, 28 | kernel_size=1, stride=1, bias=False), 29 | ) 30 | else: 31 | self.downsample = None 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out1 = self.bn1(x) 37 | out1 = F.relu(out1, True) 38 | out1 = self.conv1(out1) 39 | 40 | out2 = self.bn2(out1) 41 | out2 = F.relu(out2, True) 42 | out2 = self.conv2(out2) 43 | 44 | out3 = self.bn3(out2) 45 | out3 = F.relu(out3, True) 46 | out3 = self.conv3(out3) 47 | 48 | out3 = torch.cat((out1, out2, out3), 1) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(residual) 52 | 53 | out3 += residual 54 | 55 | return out3 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 67 | padding=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(planes) 69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 70 | self.bn3 = nn.BatchNorm2d(planes * 4) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class HourGlass(nn.Module): 99 | def __init__(self, num_modules, depth, num_features): 100 | super(HourGlass, self).__init__() 101 | self.num_modules = num_modules 102 | self.depth = depth 103 | self.features = num_features 104 | 105 | self._generate_network(self.depth) 106 | 107 | def _generate_network(self, level): 108 | self.add_module('b1_' + str(level), ConvBlock(self.features, self.features)) 109 | 110 | self.add_module('b2_' + str(level), ConvBlock(self.features, self.features)) 111 | 112 | if level > 1: 113 | self._generate_network(level - 1) 114 | else: 115 | self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features)) 116 | 117 | self.add_module('b3_' + str(level), ConvBlock(self.features, self.features)) 118 | 119 | def _forward(self, level, inp): 120 | # Upper branch 121 | up1 = inp 122 | up1 = self._modules['b1_' + str(level)](up1) 123 | 124 | # Lower branch 125 | low1 = F.avg_pool2d(inp, 2, stride=2) 126 | low1 = self._modules['b2_' + str(level)](low1) 127 | 128 | if level > 1: 129 | low2 = self._forward(level - 1, low1) 130 | else: 131 | low2 = low1 132 | low2 = self._modules['b2_plus_' + str(level)](low2) 133 | 134 | low3 = low2 135 | low3 = self._modules['b3_' + str(level)](low3) 136 | 137 | up2 = F.interpolate(low3, scale_factor=2, mode='nearest') 138 | 139 | return up1 + up2 140 | 141 | def forward(self, x): 142 | return 
self._forward(self.depth, x) 143 | 144 | 145 | class FAN(nn.Module): 146 | 147 | def __init__(self, num_modules=1): 148 | super(FAN, self).__init__() 149 | self.num_modules = num_modules 150 | 151 | # Base part 152 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 153 | self.bn1 = nn.BatchNorm2d(64) 154 | self.conv2 = ConvBlock(64, 128) 155 | self.conv3 = ConvBlock(128, 128) 156 | self.conv4 = ConvBlock(128, 256) 157 | 158 | # Stacking part 159 | for hg_module in range(self.num_modules): 160 | self.add_module('m' + str(hg_module), HourGlass(1, 4, 256)) 161 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) 162 | self.add_module('conv_last' + str(hg_module), 163 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 164 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) 165 | self.add_module('l' + str(hg_module), nn.Conv2d(256, 166 | 68, kernel_size=1, stride=1, padding=0)) 167 | 168 | if hg_module < self.num_modules - 1: 169 | self.add_module( 170 | 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 171 | self.add_module('al' + str(hg_module), nn.Conv2d(68, 172 | 256, kernel_size=1, stride=1, padding=0)) 173 | 174 | def forward(self, x): 175 | x = F.relu(self.bn1(self.conv1(x)), True) 176 | x = F.avg_pool2d(self.conv2(x), 2, stride=2) 177 | x = self.conv3(x) 178 | x = self.conv4(x) 179 | 180 | previous = x 181 | 182 | outputs = [] 183 | for i in range(self.num_modules): 184 | hg = self._modules['m' + str(i)](previous) 185 | 186 | ll = hg 187 | ll = self._modules['top_m_' + str(i)](ll) 188 | 189 | ll = F.relu(self._modules['bn_end' + str(i)] 190 | (self._modules['conv_last' + str(i)](ll)), True) 191 | 192 | # Predict heatmaps 193 | tmp_out = self._modules['l' + str(i)](ll) 194 | outputs.append(tmp_out) 195 | 196 | if i < self.num_modules - 1: 197 | ll = self._modules['bl' + str(i)](ll) 198 | tmp_out_ = self._modules['al' + str(i)](tmp_out) 199 | previous = previous + ll + tmp_out_ 200 | 201 | return outputs 202 | 203 | 204 | class ResNetDepth(nn.Module): 205 | 206 | def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68): 207 | self.inplanes = 64 208 | super(ResNetDepth, self).__init__() 209 | self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3, 210 | bias=False) 211 | self.bn1 = nn.BatchNorm2d(64) 212 | self.relu = nn.ReLU(inplace=True) 213 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 214 | self.layer1 = self._make_layer(block, 64, layers[0]) 215 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 216 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 217 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 218 | self.avgpool = nn.AvgPool2d(7) 219 | self.fc = nn.Linear(512 * block.expansion, num_classes) 220 | 221 | for m in self.modules(): 222 | if isinstance(m, nn.Conv2d): 223 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 224 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 225 | elif isinstance(m, nn.BatchNorm2d): 226 | m.weight.data.fill_(1) 227 | m.bias.data.zero_() 228 | 229 | def _make_layer(self, block, planes, blocks, stride=1): 230 | downsample = None 231 | if stride != 1 or self.inplanes != planes * block.expansion: 232 | downsample = nn.Sequential( 233 | nn.Conv2d(self.inplanes, planes * block.expansion, 234 | kernel_size=1, stride=stride, bias=False), 235 | nn.BatchNorm2d(planes * block.expansion), 236 | ) 237 | 238 | layers = [] 239 | layers.append(block(self.inplanes, planes, stride, downsample)) 240 | self.inplanes = planes * block.expansion 241 | for i in range(1, blocks): 242 | layers.append(block(self.inplanes, planes)) 243 | 244 | return nn.Sequential(*layers) 245 | 246 | def forward(self, x): 247 | x = self.conv1(x) 248 | x = self.bn1(x) 249 | x = self.relu(x) 250 | x = self.maxpool(x) 251 | 252 | x = self.layer1(x) 253 | x = self.layer2(x) 254 | x = self.layer3(x) 255 | x = self.layer4(x) 256 | 257 | x = self.avgpool(x) 258 | x = x.view(x.size(0), -1) 259 | x = self.fc(x) 260 | 261 | return x 262 | -------------------------------------------------------------------------------- /musetalk/utils/face_detection/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import time 5 | import torch 6 | import math 7 | import numpy as np 8 | import cv2 9 | 10 | 11 | def _gaussian( 12 | size=3, sigma=0.25, amplitude=1, normalize=False, width=None, 13 | height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5, 14 | mean_vert=0.5): 15 | # handle some defaults 16 | if width is None: 17 | width = size 18 | if height is None: 19 | height = size 20 | if sigma_horz is None: 21 | sigma_horz = sigma 22 | if sigma_vert is None: 23 | sigma_vert = sigma 24 | center_x = mean_horz * width + 0.5 25 | center_y = mean_vert * height + 0.5 26 | gauss = np.empty((height, width), dtype=np.float32) 27 | # generate kernel 28 | for i in range(height): 29 | for j in range(width): 30 | gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / ( 31 | sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0)) 32 | if normalize: 33 | gauss = gauss / np.sum(gauss) 34 | return gauss 35 | 36 | 37 | def draw_gaussian(image, point, sigma): 38 | # Check if the gaussian is inside 39 | ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)] 40 | br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)] 41 | if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1): 42 | return image 43 | size = 6 * sigma + 1 44 | g = _gaussian(size) 45 | g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))] 46 | g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))] 47 | img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))] 48 | img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))] 49 | assert (g_x[0] > 0 and g_y[1] > 0) 50 | image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1] 51 | ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]] 52 | image[image > 1] = 1 53 | return image 54 | 55 | 56 | def transform(point, center, scale, resolution, invert=False): 57 | """Generate and affine transformation matrix. 
58 | 59 | Given a set of points, a center, a scale and a targer resolution, the 60 | function generates and affine transformation matrix. If invert is ``True`` 61 | it will produce the inverse transformation. 62 | 63 | Arguments: 64 | point {torch.tensor} -- the input 2D point 65 | center {torch.tensor or numpy.array} -- the center around which to perform the transformations 66 | scale {float} -- the scale of the face/object 67 | resolution {float} -- the output resolution 68 | 69 | Keyword Arguments: 70 | invert {bool} -- define wherever the function should produce the direct or the 71 | inverse transformation matrix (default: {False}) 72 | """ 73 | _pt = torch.ones(3) 74 | _pt[0] = point[0] 75 | _pt[1] = point[1] 76 | 77 | h = 200.0 * scale 78 | t = torch.eye(3) 79 | t[0, 0] = resolution / h 80 | t[1, 1] = resolution / h 81 | t[0, 2] = resolution * (-center[0] / h + 0.5) 82 | t[1, 2] = resolution * (-center[1] / h + 0.5) 83 | 84 | if invert: 85 | t = torch.inverse(t) 86 | 87 | new_point = (torch.matmul(t, _pt))[0:2] 88 | 89 | return new_point.int() 90 | 91 | 92 | def crop(image, center, scale, resolution=256.0): 93 | """Center crops an image or set of heatmaps 94 | 95 | Arguments: 96 | image {numpy.array} -- an rgb image 97 | center {numpy.array} -- the center of the object, usually the same as of the bounding box 98 | scale {float} -- scale of the face 99 | 100 | Keyword Arguments: 101 | resolution {float} -- the size of the output cropped image (default: {256.0}) 102 | 103 | Returns: 104 | [type] -- [description] 105 | """ # Crop around the center point 106 | """ Crops the image around the center. Input is expected to be an np.ndarray """ 107 | ul = transform([1, 1], center, scale, resolution, True) 108 | br = transform([resolution, resolution], center, scale, resolution, True) 109 | # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0) 110 | if image.ndim > 2: 111 | newDim = np.array([br[1] - ul[1], br[0] - ul[0], 112 | image.shape[2]], dtype=np.int32) 113 | newImg = np.zeros(newDim, dtype=np.uint8) 114 | else: 115 | newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int) 116 | newImg = np.zeros(newDim, dtype=np.uint8) 117 | ht = image.shape[0] 118 | wd = image.shape[1] 119 | newX = np.array( 120 | [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32) 121 | newY = np.array( 122 | [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32) 123 | oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32) 124 | oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32) 125 | newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1] 126 | ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :] 127 | newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)), 128 | interpolation=cv2.INTER_LINEAR) 129 | return newImg 130 | 131 | 132 | def get_preds_fromhm(hm, center=None, scale=None): 133 | """Obtain (x,y) coordinates given a set of N heatmaps. If the center 134 | and the scale is provided the function will return the points also in 135 | the original coordinate frame. 
136 | 137 | Arguments: 138 | hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H] 139 | 140 | Keyword Arguments: 141 | center {torch.tensor} -- the center of the bounding box (default: {None}) 142 | scale {float} -- face scale (default: {None}) 143 | """ 144 | max, idx = torch.max( 145 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) 146 | idx += 1 147 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() 148 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) 149 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) 150 | 151 | for i in range(preds.size(0)): 152 | for j in range(preds.size(1)): 153 | hm_ = hm[i, j, :] 154 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 155 | if pX > 0 and pX < 63 and pY > 0 and pY < 63: 156 | diff = torch.FloatTensor( 157 | [hm_[pY, pX + 1] - hm_[pY, pX - 1], 158 | hm_[pY + 1, pX] - hm_[pY - 1, pX]]) 159 | preds[i, j].add_(diff.sign_().mul_(.25)) 160 | 161 | preds.add_(-.5) 162 | 163 | preds_orig = torch.zeros(preds.size()) 164 | if center is not None and scale is not None: 165 | for i in range(hm.size(0)): 166 | for j in range(hm.size(1)): 167 | preds_orig[i, j] = transform( 168 | preds[i, j], center, scale, hm.size(2), True) 169 | 170 | return preds, preds_orig 171 | 172 | def get_preds_fromhm_batch(hm, centers=None, scales=None): 173 | """Obtain (x,y) coordinates given a set of N heatmaps. If the centers 174 | and the scales is provided the function will return the points also in 175 | the original coordinate frame. 176 | 177 | Arguments: 178 | hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H] 179 | 180 | Keyword Arguments: 181 | centers {torch.tensor} -- the centers of the bounding box (default: {None}) 182 | scales {float} -- face scales (default: {None}) 183 | """ 184 | max, idx = torch.max( 185 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) 186 | idx += 1 187 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() 188 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) 189 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) 190 | 191 | for i in range(preds.size(0)): 192 | for j in range(preds.size(1)): 193 | hm_ = hm[i, j, :] 194 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 195 | if pX > 0 and pX < 63 and pY > 0 and pY < 63: 196 | diff = torch.FloatTensor( 197 | [hm_[pY, pX + 1] - hm_[pY, pX - 1], 198 | hm_[pY + 1, pX] - hm_[pY - 1, pX]]) 199 | preds[i, j].add_(diff.sign_().mul_(.25)) 200 | 201 | preds.add_(-.5) 202 | 203 | preds_orig = torch.zeros(preds.size()) 204 | if centers is not None and scales is not None: 205 | for i in range(hm.size(0)): 206 | for j in range(hm.size(1)): 207 | preds_orig[i, j] = transform( 208 | preds[i, j], centers[i], scales[i], hm.size(2), True) 209 | 210 | return preds, preds_orig 211 | 212 | def shuffle_lr(parts, pairs=None): 213 | """Shuffle the points left-right according to the axis of symmetry 214 | of the object. 215 | 216 | Arguments: 217 | parts {torch.tensor} -- a 3D or 4D object containing the 218 | heatmaps. 
219 | 220 | Keyword Arguments: 221 | pairs {list of integers} -- [order of the flipped points] (default: {None}) 222 | """ 223 | if pairs is None: 224 | pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 225 | 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35, 226 | 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41, 227 | 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63, 228 | 62, 61, 60, 67, 66, 65] 229 | if parts.ndimension() == 3: 230 | parts = parts[pairs, ...] 231 | else: 232 | parts = parts[:, pairs, ...] 233 | 234 | return parts 235 | 236 | 237 | def flip(tensor, is_label=False): 238 | """Flip an image or a set of heatmaps left-right 239 | 240 | Arguments: 241 | tensor {numpy.array or torch.tensor} -- [the input image or heatmaps] 242 | 243 | Keyword Arguments: 244 | is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False}) 245 | """ 246 | if not torch.is_tensor(tensor): 247 | tensor = torch.from_numpy(tensor) 248 | 249 | if is_label: 250 | tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1) 251 | else: 252 | tensor = tensor.flip(tensor.ndimension() - 1) 253 | 254 | return tensor 255 | 256 | # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py) 257 | 258 | 259 | def appdata_dir(appname=None, roaming=False): 260 | """ appdata_dir(appname=None, roaming=False) 261 | 262 | Get the path to the application directory, where applications are allowed 263 | to write user specific files (e.g. configurations). For non-user specific 264 | data, consider using common_appdata_dir(). 265 | If appname is given, a subdir is appended (and created if necessary). 266 | If roaming is True, will prefer a roaming directory (Windows Vista/7). 267 | """ 268 | 269 | # Define default user directory 270 | userDir = os.getenv('FACEALIGNMENT_USERDIR', None) 271 | if userDir is None: 272 | userDir = os.path.expanduser('~') 273 | if not os.path.isdir(userDir): # pragma: no cover 274 | userDir = '/var/tmp' # issue #54 275 | 276 | # Get system app data dir 277 | path = None 278 | if sys.platform.startswith('win'): 279 | path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA') 280 | path = (path2 or path1) if roaming else (path1 or path2) 281 | elif sys.platform.startswith('darwin'): 282 | path = os.path.join(userDir, 'Library', 'Application Support') 283 | # On Linux and as fallback 284 | if not (path and os.path.isdir(path)): 285 | path = userDir 286 | 287 | # Maybe we should store things local to the executable (in case of a 288 | # portable distro or a frozen application that wants to be portable) 289 | prefix = sys.prefix 290 | if getattr(sys, 'frozen', None): 291 | prefix = os.path.abspath(os.path.dirname(sys.executable)) 292 | for reldir in ('settings', '../settings'): 293 | localpath = os.path.abspath(os.path.join(prefix, reldir)) 294 | if os.path.isdir(localpath): # pragma: no cover 295 | try: 296 | open(os.path.join(localpath, 'test.write'), 'wb').close() 297 | os.remove(os.path.join(localpath, 'test.write')) 298 | except IOError: 299 | pass # We cannot write in this directory 300 | else: 301 | path = localpath 302 | break 303 | 304 | # Get path specific for this app 305 | if appname: 306 | if path == userDir: 307 | appname = '.' 
+ appname.lstrip('.') # Make it a hidden directory 308 | path = os.path.join(path, appname) 309 | if not os.path.isdir(path): # pragma: no cover 310 | os.mkdir(path) 311 | 312 | # Done 313 | return path 314 | -------------------------------------------------------------------------------- /musetalk/utils/face_parsing/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import os 4 | import cv2 5 | import numpy as np 6 | from PIL import Image 7 | from .model import BiSeNet 8 | import torchvision.transforms as transforms 9 | 10 | class FaceParsing(): 11 | def __init__(self,resnet_path,model_pth): 12 | self.net = self.model_init(resnet_path,model_pth) 13 | self.preprocess = self.image_preprocess() 14 | 15 | def model_init(self, 16 | resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth', 17 | model_pth='./models/face-parse-bisent/79999_iter.pth'): 18 | net = BiSeNet(resnet_path) 19 | if torch.cuda.is_available(): 20 | net.cuda() 21 | net.load_state_dict(torch.load(model_pth)) 22 | else: 23 | net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu'))) 24 | net.eval() 25 | return net 26 | 27 | def image_preprocess(self): 28 | return transforms.Compose([ 29 | transforms.ToTensor(), 30 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 31 | ]) 32 | 33 | def __call__(self, image, size=(512, 512)): 34 | if isinstance(image, str): 35 | image = Image.open(image) 36 | 37 | width, height = image.size 38 | with torch.no_grad(): 39 | image = image.resize(size, Image.BILINEAR) 40 | img = self.preprocess(image) 41 | if torch.cuda.is_available(): 42 | img = torch.unsqueeze(img, 0).cuda() 43 | else: 44 | img = torch.unsqueeze(img, 0) 45 | out = self.net(img)[0] 46 | parsing = out.squeeze(0).cpu().numpy().argmax(0) 47 | parsing[np.where(parsing>13)] = 0 48 | parsing[np.where(parsing>=1)] = 255 49 | parsing = Image.fromarray(parsing.astype(np.uint8)) 50 | return parsing 51 | 52 | if __name__ == "__main__": 53 | fp = FaceParsing() 54 | segmap = fp('154_small.png') 55 | segmap.save('res.png') 56 | 57 | -------------------------------------------------------------------------------- /musetalk/utils/face_parsing/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_parsing/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/face_parsing/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_parsing/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/face_parsing/__pycache__/resnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_parsing/__pycache__/resnet.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/utils/face_parsing/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- 
encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision 9 | 10 | from .resnet import Resnet18 11 | # from modules.bn import InPlaceABNSync as BatchNorm2d 12 | 13 | 14 | class ConvBNReLU(nn.Module): 15 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): 16 | super(ConvBNReLU, self).__init__() 17 | self.conv = nn.Conv2d(in_chan, 18 | out_chan, 19 | kernel_size = ks, 20 | stride = stride, 21 | padding = padding, 22 | bias = False) 23 | self.bn = nn.BatchNorm2d(out_chan) 24 | self.init_weight() 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = F.relu(self.bn(x)) 29 | return x 30 | 31 | def init_weight(self): 32 | for ly in self.children(): 33 | if isinstance(ly, nn.Conv2d): 34 | nn.init.kaiming_normal_(ly.weight, a=1) 35 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 36 | 37 | class BiSeNetOutput(nn.Module): 38 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): 39 | super(BiSeNetOutput, self).__init__() 40 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 41 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) 42 | self.init_weight() 43 | 44 | def forward(self, x): 45 | x = self.conv(x) 46 | x = self.conv_out(x) 47 | return x 48 | 49 | def init_weight(self): 50 | for ly in self.children(): 51 | if isinstance(ly, nn.Conv2d): 52 | nn.init.kaiming_normal_(ly.weight, a=1) 53 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 54 | 55 | def get_params(self): 56 | wd_params, nowd_params = [], [] 57 | for name, module in self.named_modules(): 58 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 59 | wd_params.append(module.weight) 60 | if not module.bias is None: 61 | nowd_params.append(module.bias) 62 | elif isinstance(module, nn.BatchNorm2d): 63 | nowd_params += list(module.parameters()) 64 | return wd_params, nowd_params 65 | 66 | 67 | class AttentionRefinementModule(nn.Module): 68 | def __init__(self, in_chan, out_chan, *args, **kwargs): 69 | super(AttentionRefinementModule, self).__init__() 70 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 71 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) 72 | self.bn_atten = nn.BatchNorm2d(out_chan) 73 | self.sigmoid_atten = nn.Sigmoid() 74 | self.init_weight() 75 | 76 | def forward(self, x): 77 | feat = self.conv(x) 78 | atten = F.avg_pool2d(feat, feat.size()[2:]) 79 | atten = self.conv_atten(atten) 80 | atten = self.bn_atten(atten) 81 | atten = self.sigmoid_atten(atten) 82 | out = torch.mul(feat, atten) 83 | return out 84 | 85 | def init_weight(self): 86 | for ly in self.children(): 87 | if isinstance(ly, nn.Conv2d): 88 | nn.init.kaiming_normal_(ly.weight, a=1) 89 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 90 | 91 | 92 | class ContextPath(nn.Module): 93 | def __init__(self, resnet_path, *args, **kwargs): 94 | super(ContextPath, self).__init__() 95 | self.resnet = Resnet18(resnet_path) 96 | self.arm16 = AttentionRefinementModule(256, 128) 97 | self.arm32 = AttentionRefinementModule(512, 128) 98 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 99 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 100 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 101 | 102 | self.init_weight() 103 | 104 | def forward(self, x): 105 | H0, W0 = x.size()[2:] 106 | feat8, feat16, feat32 = self.resnet(x) 107 | H8, W8 = feat8.size()[2:] 108 | H16, W16 
= feat16.size()[2:] 109 | H32, W32 = feat32.size()[2:] 110 | 111 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 112 | avg = self.conv_avg(avg) 113 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest') 114 | 115 | feat32_arm = self.arm32(feat32) 116 | feat32_sum = feat32_arm + avg_up 117 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') 118 | feat32_up = self.conv_head32(feat32_up) 119 | 120 | feat16_arm = self.arm16(feat16) 121 | feat16_sum = feat16_arm + feat32_up 122 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') 123 | feat16_up = self.conv_head16(feat16_up) 124 | 125 | return feat8, feat16_up, feat32_up # x8, x8, x16 126 | 127 | def init_weight(self): 128 | for ly in self.children(): 129 | if isinstance(ly, nn.Conv2d): 130 | nn.init.kaiming_normal_(ly.weight, a=1) 131 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 132 | 133 | def get_params(self): 134 | wd_params, nowd_params = [], [] 135 | for name, module in self.named_modules(): 136 | if isinstance(module, (nn.Linear, nn.Conv2d)): 137 | wd_params.append(module.weight) 138 | if not module.bias is None: 139 | nowd_params.append(module.bias) 140 | elif isinstance(module, nn.BatchNorm2d): 141 | nowd_params += list(module.parameters()) 142 | return wd_params, nowd_params 143 | 144 | 145 | ### This is not used, since I replace this with the resnet feature with the same size 146 | class SpatialPath(nn.Module): 147 | def __init__(self, *args, **kwargs): 148 | super(SpatialPath, self).__init__() 149 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) 150 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 151 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 152 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) 153 | self.init_weight() 154 | 155 | def forward(self, x): 156 | feat = self.conv1(x) 157 | feat = self.conv2(feat) 158 | feat = self.conv3(feat) 159 | feat = self.conv_out(feat) 160 | return feat 161 | 162 | def init_weight(self): 163 | for ly in self.children(): 164 | if isinstance(ly, nn.Conv2d): 165 | nn.init.kaiming_normal_(ly.weight, a=1) 166 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 167 | 168 | def get_params(self): 169 | wd_params, nowd_params = [], [] 170 | for name, module in self.named_modules(): 171 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 172 | wd_params.append(module.weight) 173 | if not module.bias is None: 174 | nowd_params.append(module.bias) 175 | elif isinstance(module, nn.BatchNorm2d): 176 | nowd_params += list(module.parameters()) 177 | return wd_params, nowd_params 178 | 179 | 180 | class FeatureFusionModule(nn.Module): 181 | def __init__(self, in_chan, out_chan, *args, **kwargs): 182 | super(FeatureFusionModule, self).__init__() 183 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 184 | self.conv1 = nn.Conv2d(out_chan, 185 | out_chan//4, 186 | kernel_size = 1, 187 | stride = 1, 188 | padding = 0, 189 | bias = False) 190 | self.conv2 = nn.Conv2d(out_chan//4, 191 | out_chan, 192 | kernel_size = 1, 193 | stride = 1, 194 | padding = 0, 195 | bias = False) 196 | self.relu = nn.ReLU(inplace=True) 197 | self.sigmoid = nn.Sigmoid() 198 | self.init_weight() 199 | 200 | def forward(self, fsp, fcp): 201 | fcat = torch.cat([fsp, fcp], dim=1) 202 | feat = self.convblk(fcat) 203 | atten = F.avg_pool2d(feat, feat.size()[2:]) 204 | atten = self.conv1(atten) 205 | atten = self.relu(atten) 206 | atten = self.conv2(atten) 207 | atten = self.sigmoid(atten) 208 | feat_atten = 
torch.mul(feat, atten) 209 | feat_out = feat_atten + feat 210 | return feat_out 211 | 212 | def init_weight(self): 213 | for ly in self.children(): 214 | if isinstance(ly, nn.Conv2d): 215 | nn.init.kaiming_normal_(ly.weight, a=1) 216 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 217 | 218 | def get_params(self): 219 | wd_params, nowd_params = [], [] 220 | for name, module in self.named_modules(): 221 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 222 | wd_params.append(module.weight) 223 | if not module.bias is None: 224 | nowd_params.append(module.bias) 225 | elif isinstance(module, nn.BatchNorm2d): 226 | nowd_params += list(module.parameters()) 227 | return wd_params, nowd_params 228 | 229 | 230 | class BiSeNet(nn.Module): 231 | def __init__(self, resnet_path='models/resnet18-5c106cde.pth', n_classes=19, *args, **kwargs): 232 | super(BiSeNet, self).__init__() 233 | self.cp = ContextPath(resnet_path) 234 | ## here self.sp is deleted 235 | self.ffm = FeatureFusionModule(256, 256) 236 | self.conv_out = BiSeNetOutput(256, 256, n_classes) 237 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes) 238 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes) 239 | self.init_weight() 240 | 241 | def forward(self, x): 242 | H, W = x.size()[2:] 243 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature 244 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature 245 | feat_fuse = self.ffm(feat_sp, feat_cp8) 246 | 247 | feat_out = self.conv_out(feat_fuse) 248 | feat_out16 = self.conv_out16(feat_cp8) 249 | feat_out32 = self.conv_out32(feat_cp16) 250 | 251 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) 252 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) 253 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) 254 | return feat_out, feat_out16, feat_out32 255 | 256 | def init_weight(self): 257 | for ly in self.children(): 258 | if isinstance(ly, nn.Conv2d): 259 | nn.init.kaiming_normal_(ly.weight, a=1) 260 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 261 | 262 | def get_params(self): 263 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] 264 | for name, child in self.named_children(): 265 | child_wd_params, child_nowd_params = child.get_params() 266 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): 267 | lr_mul_wd_params += child_wd_params 268 | lr_mul_nowd_params += child_nowd_params 269 | else: 270 | wd_params += child_wd_params 271 | nowd_params += child_nowd_params 272 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params 273 | 274 | 275 | if __name__ == "__main__": 276 | net = BiSeNet(19) 277 | net.cuda() 278 | net.eval() 279 | in_ten = torch.randn(16, 3, 640, 480).cuda() 280 | out, out16, out32 = net(in_ten) 281 | print(out.shape) 282 | 283 | net.get_params() 284 | -------------------------------------------------------------------------------- /musetalk/utils/face_parsing/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | # from modules.bn import InPlaceABNSync as BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def conv3x3(in_planes, out_planes, 
stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = nn.BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = nn.BatchNorm2d(out_chan) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_chan), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = F.relu(self.bn1(residual)) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum-1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self, model_path): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 69 | self.init_weight(model_path) 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = F.relu(self.bn1(x)) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self, model_path): 83 | state_dict = torch.load(model_path) #modelzoo.load_url(resnet18_url) 84 | self_state_dict = self.state_dict() 85 | for k, v in state_dict.items(): 86 | if 'fc' in k: continue 87 | self_state_dict.update({k: v}) 88 | self.load_state_dict(self_state_dict) 89 | 90 | def get_params(self): 91 | wd_params, nowd_params = [], [] 92 | for name, module in self.named_modules(): 93 | if isinstance(module, (nn.Linear, nn.Conv2d)): 94 | wd_params.append(module.weight) 95 | if not module.bias is None: 96 | nowd_params.append(module.bias) 97 | elif isinstance(module, nn.BatchNorm2d): 98 | nowd_params += list(module.parameters()) 99 | return wd_params, nowd_params 100 | 101 | 102 | if __name__ == "__main__": 103 | net = Resnet18() 104 | x = torch.randn(16, 3, 224, 224) 105 | out = net(x) 106 | print(out[0].size()) 107 | print(out[1].size()) 108 | print(out[2].size()) 109 | net.get_params() 110 | -------------------------------------------------------------------------------- /musetalk/utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from face_detection import FaceAlignment,LandmarksType 3 | from os import listdir, path 4 | import subprocess 5 | import numpy as np 6 | import cv2 7 | 
import pickle 8 | import os 9 | import json 10 | from mmpose.apis import inference_topdown, init_model 11 | from mmpose.structures import merge_data_samples 12 | import torch 13 | from tqdm import tqdm 14 | parent_directory = os.path.dirname(os.path.abspath(__file__)) 15 | # initialize the mmpose model 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | ''' 18 | config_file = os.path.join(parent_directory,"dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py") 19 | checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth' 20 | model = init_model(config_file, checkpoint_file, device=device) 21 | ''' 22 | 23 | 24 | # initialize the face detection model 25 | device = "cuda" if torch.cuda.is_available() else "cpu" 26 | fa = FaceAlignment(LandmarksType._2D, flip_input=False,device=device) 27 | 28 | # maker if the bbox is not sufficient 29 | coord_placeholder = (0.0,0.0,0.0,0.0) 30 | 31 | def resize_landmark(landmark, w, h, new_w, new_h): 32 | w_ratio = new_w / w 33 | h_ratio = new_h / h 34 | landmark_norm = landmark / [w, h] 35 | landmark_resized = landmark_norm * [new_w, new_h] 36 | return landmark_resized 37 | 38 | def read_imgs(img_list): 39 | frames = [] 40 | print('reading images...') 41 | for img_path in tqdm(img_list): 42 | frame = cv2.imread(img_path) 43 | frames.append(frame) 44 | return frames 45 | 46 | def get_bbox_range(model,img_list,batch_size_fa,upperbondrange =0): 47 | frames = read_imgs(img_list) 48 | # batch_size_fa = 1 49 | batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)] 50 | coords_list = [] 51 | landmarks = [] 52 | if upperbondrange != 0: 53 | print('get key_landmark and face bounding boxes with the bbox_shift:',upperbondrange) 54 | else: 55 | print('get key_landmark and face bounding boxes with the default value') 56 | average_range_minus = [] 57 | average_range_plus = [] 58 | for fb in tqdm(batches): 59 | results = inference_topdown(model, np.asarray(fb)[0]) 60 | results = merge_data_samples(results) 61 | keypoints = results.pred_instances.keypoints 62 | face_land_mark= keypoints[0][23:91] 63 | face_land_mark = face_land_mark.astype(np.int32) 64 | 65 | # get bounding boxes by face detetion 66 | bbox = fa.get_detections_for_batch(np.asarray(fb)) 67 | 68 | # adjust the bounding box refer to landmark 69 | # Add the bounding box to a tuple and append it to the coordinates list 70 | for j, f in enumerate(bbox): 71 | if f is None: # no face in the image 72 | coords_list += [coord_placeholder] 73 | continue 74 | 75 | half_face_coord = face_land_mark[29]#np.mean([face_land_mark[28], face_land_mark[29]], axis=0) 76 | range_minus = (face_land_mark[30]- face_land_mark[29])[1] 77 | range_plus = (face_land_mark[29]- face_land_mark[28])[1] 78 | average_range_minus.append(range_minus) 79 | average_range_plus.append(range_plus) 80 | if upperbondrange != 0: 81 | half_face_coord[1] = upperbondrange+half_face_coord[1] #手动调整 + 向下(偏29) - 向上(偏28) 82 | 83 | text_range=f"Total frame:「{len(frames)}」 Manually adjust range : [ -{int(sum(average_range_minus) / len(average_range_minus))}~{int(sum(average_range_plus) / len(average_range_plus))} ] , the current value: {upperbondrange}" 84 | return text_range 85 | 86 | 87 | def get_landmark_and_bbox(model,img_list,batch_size_fa,upperbondrange =0): 88 | frames = read_imgs(img_list) 89 | # batch_size_fa = 1 90 | batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)] 91 | coords_list = [] 92 | landmarks = [] 93 | if upperbondrange != 0: 94 | print('get 
key_landmark and face bounding boxes with the bbox_shift:',upperbondrange) 95 | else: 96 | print('get key_landmark and face bounding boxes with the default value') 97 | average_range_minus = [] 98 | average_range_plus = [] 99 | for fb in tqdm(batches): 100 | results = inference_topdown(model, np.asarray(fb)[0]) 101 | results = merge_data_samples(results) 102 | keypoints = results.pred_instances.keypoints 103 | face_land_mark= keypoints[0][23:91] 104 | face_land_mark = face_land_mark.astype(np.int32) 105 | 106 | # get bounding boxes by face detetion 107 | bbox = fa.get_detections_for_batch(np.asarray(fb)) 108 | 109 | # adjust the bounding box refer to landmark 110 | # Add the bounding box to a tuple and append it to the coordinates list 111 | for j, f in enumerate(bbox): 112 | if f is None: # no face in the image 113 | coords_list += [coord_placeholder] 114 | continue 115 | 116 | half_face_coord = face_land_mark[29]#np.mean([face_land_mark[28], face_land_mark[29]], axis=0) 117 | range_minus = (face_land_mark[30]- face_land_mark[29])[1] 118 | range_plus = (face_land_mark[29]- face_land_mark[28])[1] 119 | average_range_minus.append(range_minus) 120 | average_range_plus.append(range_plus) 121 | if upperbondrange != 0: 122 | half_face_coord[1] = upperbondrange+half_face_coord[1] #手动调整 + 向下(偏29) - 向上(偏28) 123 | half_face_dist = np.max(face_land_mark[:,1]) - half_face_coord[1] 124 | upper_bond = half_face_coord[1]-half_face_dist 125 | upper_bond = max(0, upper_bond) 126 | # https://github.com/TMElyralab/MuseTalk/issues/38 127 | f_landmark = (np.min(face_land_mark[:, 0]),int(upper_bond),np.max(face_land_mark[:, 0]),np.max(face_land_mark[:,1])) 128 | x1, y1, x2, y2 = f_landmark 129 | 130 | if y2-y1<=0 or x2-x1<=0 or x1<0: # if the landmark bbox is not suitable, reuse the bbox 131 | coords_list += [f] 132 | w,h = f[2]-f[0], f[3]-f[1] 133 | print("error bbox:",f) 134 | else: 135 | coords_list += [f_landmark] 136 | 137 | print("********************************************bbox_shift parameter adjustment**********************************************************") 138 | print(f"Total frame:「{len(frames)}」 Manually adjust range : [ -{int(sum(average_range_minus) / len(average_range_minus))}~{int(sum(average_range_plus) / len(average_range_plus))} ] , the current value: {upperbondrange}") 139 | print("*************************************************************************************************************************************") 140 | return coords_list,frames 141 | 142 | 143 | if __name__ == "__main__": 144 | img_list = ["./results/lyria/00000.png","./results/lyria/00001.png","./results/lyria/00002.png","./results/lyria/00003.png"] 145 | crop_coord_path = "./coord_face.pkl" 146 | coords_list,full_frames = get_landmark_and_bbox(img_list) 147 | with open(crop_coord_path, 'wb') as f: 148 | pickle.dump(coords_list, f) 149 | 150 | for bbox, frame in zip(coords_list,full_frames): 151 | if bbox == coord_placeholder: 152 | continue 153 | x1, y1, x2, y2 = bbox 154 | crop_frame = frame[y1:y2, x1:x2] 155 | print('Cropped shape', crop_frame.shape) 156 | 157 | #cv2.imwrite(path.join(save_dir, '{}.png'.format(i)),full_frames[i][0][y1:y2, x1:x2]) 158 | print(coords_list) 159 | -------------------------------------------------------------------------------- /musetalk/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import torch 5 | 6 | ffmpeg_path = os.getenv('FFMPEG_PATH') 7 | if ffmpeg_path is None: 8 | 
print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static") 9 | elif ffmpeg_path not in os.getenv('PATH'): 10 | print("add ffmpeg to path") 11 | os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}" 12 | 13 | 14 | from ..whisper.audio2feature import Audio2Feature 15 | from ..models.vae import VAE 16 | from ..models.unet import UNet,PositionalEncoding 17 | 18 | def load_all_model(base_dir): 19 | audio_processor = Audio2Feature(model_path=os.path.join(base_dir,"whisper/tiny.pt")) 20 | vae = VAE(model_path =os.path.join(base_dir,"sd-vae-ft-mse/")) 21 | unet = UNet(unet_config=os.path.join(base_dir,"musetalk/musetalk.json"), 22 | model_path =os.path.join(base_dir,"musetalk/pytorch_model.bin")) 23 | pe = PositionalEncoding(d_model=384) 24 | return audio_processor,vae,unet,pe 25 | 26 | def get_file_type(video_path): 27 | _, ext = os.path.splitext(video_path) 28 | 29 | if ext.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']: 30 | return 'image' 31 | elif ext.lower() in ['.avi', '.mp4', '.mov', '.flv', '.mkv']: 32 | return 'video' 33 | else: 34 | return 'unsupported' 35 | 36 | def get_video_fps(video_path): 37 | video = cv2.VideoCapture(video_path) 38 | fps = video.get(cv2.CAP_PROP_FPS) 39 | video.release() 40 | return fps 41 | 42 | def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0): 43 | whisper_batch, latent_batch = [], [] 44 | for i, w in enumerate(whisper_chunks): 45 | idx = (i+delay_frame)%len(vae_encode_latents) 46 | latent = vae_encode_latents[idx] 47 | whisper_batch.append(w) 48 | latent_batch.append(latent) 49 | 50 | if len(latent_batch) >= batch_size: 51 | whisper_batch = np.asarray(whisper_batch) 52 | latent_batch = torch.cat(latent_batch, dim=0) 53 | yield whisper_batch, latent_batch 54 | whisper_batch, latent_batch = [], [] 55 | 56 | # the last batch may smaller than batch size 57 | if len(latent_batch) > 0: 58 | whisper_batch = np.asarray(whisper_batch) 59 | latent_batch = torch.cat(latent_batch, dim=0) 60 | 61 | yield whisper_batch, latent_batch -------------------------------------------------------------------------------- /musetalk/whisper/__pycache__/audio2feature.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/__pycache__/audio2feature.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/whisper/audio2feature.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .whisper import load_model 3 | import soundfile as sf 4 | import numpy as np 5 | import time 6 | import sys 7 | sys.path.append("..") 8 | 9 | class Audio2Feature(): 10 | def __init__(self, 11 | whisper_model_type="tiny", 12 | model_path="./models/whisper/tiny.pt"): 13 | self.whisper_model_type = whisper_model_type 14 | self.model = load_model(model_path) # 15 | 16 | def get_sliced_feature(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25): 17 | """ 18 | Get sliced features based on a given index 19 | :param feature_array: 20 | :param start_idx: the start index of the feature 21 | :param audio_feat_length: 22 | :return: 23 | """ 24 | length = len(feature_array) 25 | selected_feature = [] 26 | selected_idx = [] 27 | 28 | center_idx = int(vid_idx*50/fps) 29 | left_idx = center_idx-audio_feat_length[0]*2 30 | right_idx = 
center_idx + (audio_feat_length[1]+1)*2 31 | 32 | for idx in range(left_idx,right_idx): 33 | idx = max(0, idx) 34 | idx = min(length-1, idx) 35 | x = feature_array[idx] 36 | selected_feature.append(x) 37 | selected_idx.append(idx) 38 | 39 | selected_feature = np.concatenate(selected_feature, axis=0) 40 | selected_feature = selected_feature.reshape(-1, 384)# 50*384 41 | return selected_feature,selected_idx 42 | 43 | def get_sliced_feature_sparse(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25): 44 | """ 45 | Get sliced features based on a given index 46 | :param feature_array: 47 | :param start_idx: the start index of the feature 48 | :param audio_feat_length: 49 | :return: 50 | """ 51 | length = len(feature_array) 52 | selected_feature = [] 53 | selected_idx = [] 54 | 55 | for dt in range(-audio_feat_length[0],audio_feat_length[1]+1): 56 | left_idx = int((vid_idx+dt)*50/fps) 57 | if left_idx<1 or left_idx>length-1: 58 | left_idx = max(0, left_idx) 59 | left_idx = min(length-1, left_idx) 60 | 61 | x = feature_array[left_idx] 62 | x = x[np.newaxis,:,:] 63 | x = np.repeat(x, 2, axis=0) 64 | selected_feature.append(x) 65 | selected_idx.append(left_idx) 66 | selected_idx.append(left_idx) 67 | else: 68 | x = feature_array[left_idx-1:left_idx+1] 69 | selected_feature.append(x) 70 | selected_idx.append(left_idx-1) 71 | selected_idx.append(left_idx) 72 | selected_feature = np.concatenate(selected_feature, axis=0) 73 | selected_feature = selected_feature.reshape(-1, 384)# 50*384 74 | return selected_feature,selected_idx 75 | 76 | 77 | def feature2chunks(self,feature_array,fps,audio_feat_length = [2,2]): 78 | whisper_chunks = [] 79 | whisper_idx_multiplier = 50./fps 80 | i = 0 81 | print(f"video in {fps} FPS, audio idx in 50FPS") 82 | while 1: 83 | start_idx = int(i * whisper_idx_multiplier) 84 | selected_feature,selected_idx = self.get_sliced_feature(feature_array= feature_array,vid_idx = i,audio_feat_length=audio_feat_length,fps=fps) 85 | #print(f"i:{i},selected_idx {selected_idx}") 86 | whisper_chunks.append(selected_feature) 87 | i += 1 88 | if start_idx>len(feature_array): 89 | break 90 | 91 | return whisper_chunks 92 | 93 | def audio2feat(self,audio_path): 94 | # get the sample rate of the audio 95 | result = self.model.transcribe(audio_path) 96 | embed_list = [] 97 | for emb in result['segments']: 98 | encoder_embeddings = emb['encoder_embeddings'] 99 | encoder_embeddings = encoder_embeddings.transpose(0,2,1,3) 100 | encoder_embeddings = encoder_embeddings.squeeze(0) 101 | start_idx = int(emb['start']) 102 | end_idx = int(emb['end']) 103 | emb_end_idx = int((end_idx - start_idx)/2) 104 | embed_list.append(encoder_embeddings[:emb_end_idx]) 105 | concatenated_array = np.concatenate(embed_list, axis=0) 106 | return concatenated_array 107 | 108 | if __name__ == "__main__": 109 | audio_processor = Audio2Feature(model_path="../../models/whisper/whisper_tiny.pt") 110 | audio_path = "./test.mp3" 111 | array = audio_processor.audio2feat(audio_path) 112 | print(array.shape) 113 | fps = 25 114 | whisper_idx_multiplier = 50./fps 115 | 116 | i = 0 117 | print(f"video in {fps} FPS, audio idx in 50FPS") 118 | while 1: 119 | start_idx = int(i * whisper_idx_multiplier) 120 | selected_feature,selected_idx = audio_processor.get_sliced_feature(feature_array= array,vid_idx = i,audio_feat_length=[2,2],fps=fps) 121 | print(f"video idx {i},\t audio idx {selected_idx},\t shape {selected_feature.shape}") 122 | i += 1 123 | if start_idx>len(array): 124 | break 125 | 
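
Note: `get_sliced_feature` above maps each video frame onto the Whisper encoder's 50 Hz feature timeline (50 encoder feature frames per second of audio, versus `fps` video frames per second). A minimal sketch of just that index arithmetic, kept separate from the class; the helper name `sliced_indices` and the frame counts in the demo are illustrative, not part of this repo:

```python
# Sketch of the window selection performed by Audio2Feature.get_sliced_feature.
def sliced_indices(num_audio_frames, vid_idx, audio_feat_length=(2, 2), fps=25):
    center_idx = int(vid_idx * 50 / fps)                 # 50 audio feature frames per second
    left_idx = center_idx - audio_feat_length[0] * 2
    right_idx = center_idx + (audio_feat_length[1] + 1) * 2
    # clamp every index into [0, num_audio_frames - 1], exactly as the class does
    return [min(max(idx, 0), num_audio_frames - 1) for idx in range(left_idx, right_idx)]

print(sliced_indices(num_audio_frames=500, vid_idx=10, fps=25))
# -> [16, 17, 18, 19, 20, 21, 22, 23, 24, 25]  (a 10-index window centred on audio frame 20)
```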
-------------------------------------------------------------------------------- /musetalk/whisper/whisper/__init__.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import io 3 | import os 4 | import urllib 5 | import warnings 6 | from typing import List, Optional, Union 7 | 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from .audio import load_audio, log_mel_spectrogram, pad_or_trim 12 | from .decoding import DecodingOptions, DecodingResult, decode, detect_language 13 | from .model import Whisper, ModelDimensions 14 | from .transcribe import transcribe 15 | 16 | 17 | _MODELS = { 18 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", 19 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", 20 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt", 21 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt", 22 | "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", 23 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", 24 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt", 25 | "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", 26 | "large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt", 27 | "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt", 28 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", 29 | "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", 30 | } 31 | 32 | 33 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: 34 | os.makedirs(root, exist_ok=True) 35 | 36 | expected_sha256 = url.split("/")[-2] 37 | download_target = os.path.join(root, os.path.basename(url)) 38 | 39 | if os.path.exists(download_target) and not os.path.isfile(download_target): 40 | raise RuntimeError(f"{download_target} exists and is not a regular file") 41 | 42 | if os.path.isfile(download_target): 43 | model_bytes = open(download_target, "rb").read() 44 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: 45 | return model_bytes if in_memory else download_target 46 | else: 47 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 48 | 49 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 50 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 51 | while True: 52 | buffer = source.read(8192) 53 | if not buffer: 54 | break 55 | 56 | output.write(buffer) 57 | 
loop.update(len(buffer)) 58 | 59 | model_bytes = open(download_target, "rb").read() 60 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: 61 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.") 62 | 63 | return model_bytes if in_memory else download_target 64 | 65 | 66 | def available_models() -> List[str]: 67 | """Returns the names of available models""" 68 | return list(_MODELS.keys()) 69 | 70 | 71 | def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper: 72 | """ 73 | Load a Whisper ASR model 74 | 75 | Parameters 76 | ---------- 77 | name : str 78 | one of the official model names listed by `whisper.available_models()`, or 79 | path to a model checkpoint containing the model dimensions and the model state_dict. 80 | device : Union[str, torch.device] 81 | the PyTorch device to put the model into 82 | download_root: str 83 | path to download the model files; by default, it uses "~/.cache/whisper" 84 | in_memory: bool 85 | whether to preload the model weights into host memory 86 | 87 | Returns 88 | ------- 89 | model : Whisper 90 | The Whisper ASR model instance 91 | """ 92 | 93 | if device is None: 94 | device = "cuda" if torch.cuda.is_available() else "cpu" 95 | if download_root is None: 96 | download_root = os.getenv( 97 | "XDG_CACHE_HOME", 98 | os.path.join(os.path.expanduser("~"), ".cache", "whisper") 99 | ) 100 | 101 | if name in _MODELS: 102 | checkpoint_file = _download(_MODELS[name], download_root, in_memory) 103 | elif os.path.isfile(name): 104 | checkpoint_file = open(name, "rb").read() if in_memory else name 105 | else: 106 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 107 | 108 | with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp: 109 | checkpoint = torch.load(fp, map_location=device) 110 | del checkpoint_file 111 | 112 | dims = ModelDimensions(**checkpoint["dims"]) 113 | model = Whisper(dims) 114 | model.load_state_dict(checkpoint["model_state_dict"]) 115 | 116 | return model.to(device) 117 | -------------------------------------------------------------------------------- /musetalk/whisper/whisper/__main__.py: -------------------------------------------------------------------------------- 1 | from .transcribe import cli 2 | 3 | 4 | cli() 5 | -------------------------------------------------------------------------------- /musetalk/whisper/whisper/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/whisper/whisper/__pycache__/audio.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/__pycache__/audio.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/whisper/whisper/__pycache__/decoding.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/__pycache__/decoding.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/whisper/whisper/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/whisper/whisper/__pycache__/tokenizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/__pycache__/tokenizer.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/whisper/whisper/__pycache__/transcribe.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/__pycache__/transcribe.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/whisper/whisper/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /musetalk/whisper/whisper/assets/gpt2/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /musetalk/whisper/whisper/assets/gpt2/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /musetalk/whisper/whisper/assets/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/assets/mel_filters.npz -------------------------------------------------------------------------------- /musetalk/whisper/whisper/assets/multilingual/added_tokens.json: -------------------------------------------------------------------------------- 1 | {"<|endoftext|>": 50257} 2 | -------------------------------------------------------------------------------- /musetalk/whisper/whisper/assets/multilingual/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- 
/musetalk/whisper/whisper/assets/multilingual/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /musetalk/whisper/whisper/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from typing import Union 4 | 5 | import ffmpeg 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from .utils import exact_div 11 | 12 | # hard-coded audio hyperparameters 13 | SAMPLE_RATE = 16000 14 | N_FFT = 400 15 | N_MELS = 80 16 | HOP_LENGTH = 160 17 | CHUNK_LENGTH = 30 18 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk 19 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input 20 | 21 | 22 | def load_audio(file: str, sr: int = SAMPLE_RATE): 23 | """ 24 | Open an audio file and read as mono waveform, resampling as necessary 25 | 26 | Parameters 27 | ---------- 28 | file: str 29 | The audio file to open 30 | 31 | sr: int 32 | The sample rate to resample the audio if necessary 33 | 34 | Returns 35 | ------- 36 | A NumPy array containing the audio waveform, in float32 dtype. 37 | """ 38 | try: 39 | # This launches a subprocess to decode audio while down-mixing and resampling as necessary. 40 | # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. 41 | out, _ = ( 42 | ffmpeg.input(file, threads=0) 43 | .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr) 44 | .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) 45 | ) 46 | except ffmpeg.Error as e: 47 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 48 | 49 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 50 | 51 | 52 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 53 | """ 54 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 55 | """ 56 | if torch.is_tensor(array): 57 | if array.shape[axis] > length: 58 | array = array.index_select(dim=axis, index=torch.arange(length)) 59 | 60 | if array.shape[axis] < length: 61 | pad_widths = [(0, 0)] * array.ndim 62 | pad_widths[axis] = (0, length - array.shape[axis]) 63 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 64 | else: 65 | if array.shape[axis] > length: 66 | array = array.take(indices=range(length), axis=axis) 67 | 68 | if array.shape[axis] < length: 69 | pad_widths = [(0, 0)] * array.ndim 70 | pad_widths[axis] = (0, length - array.shape[axis]) 71 | array = np.pad(array, pad_widths) 72 | 73 | return array 74 | 75 | 76 | @lru_cache(maxsize=None) 77 | def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: 78 | """ 79 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
80 | Allows decoupling librosa dependency; saved using: 81 | 82 | np.savez_compressed( 83 | "mel_filters.npz", 84 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), 85 | ) 86 | """ 87 | assert n_mels == 80, f"Unsupported n_mels: {n_mels}" 88 | with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f: 89 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) 90 | 91 | 92 | def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS): 93 | """ 94 | Compute the log-Mel spectrogram of 95 | 96 | Parameters 97 | ---------- 98 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 99 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 100 | 101 | n_mels: int 102 | The number of Mel-frequency filters, only 80 is supported 103 | 104 | Returns 105 | ------- 106 | torch.Tensor, shape = (80, n_frames) 107 | A Tensor that contains the Mel spectrogram 108 | """ 109 | if not torch.is_tensor(audio): 110 | if isinstance(audio, str): 111 | audio = load_audio(audio) 112 | audio = torch.from_numpy(audio) 113 | 114 | window = torch.hann_window(N_FFT).to(audio.device) 115 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 116 | 117 | magnitudes = stft[:, :-1].abs() ** 2 118 | 119 | filters = mel_filters(audio.device, n_mels) 120 | mel_spec = filters @ magnitudes 121 | 122 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 123 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 124 | log_spec = (log_spec + 4.0) / 4.0 125 | return log_spec 126 | -------------------------------------------------------------------------------- /musetalk/whisper/whisper/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict 3 | from typing import Iterable, Optional 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | from torch import nn 10 | 11 | from .transcribe import transcribe as transcribe_function 12 | from .decoding import detect_language as detect_language_function, decode as decode_function 13 | 14 | 15 | @dataclass 16 | class ModelDimensions: 17 | n_mels: int 18 | n_audio_ctx: int 19 | n_audio_state: int 20 | n_audio_head: int 21 | n_audio_layer: int 22 | n_vocab: int 23 | n_text_ctx: int 24 | n_text_state: int 25 | n_text_head: int 26 | n_text_layer: int 27 | 28 | 29 | class LayerNorm(nn.LayerNorm): 30 | def forward(self, x: Tensor) -> Tensor: 31 | return super().forward(x.float()).type(x.dtype) 32 | 33 | 34 | class Linear(nn.Linear): 35 | def forward(self, x: Tensor) -> Tensor: 36 | return F.linear( 37 | x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype) 38 | ) 39 | 40 | 41 | class Conv1d(nn.Conv1d): 42 | def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: 43 | return super()._conv_forward( 44 | x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) 45 | ) 46 | 47 | 48 | def sinusoids(length, channels, max_timescale=10000): 49 | """Returns sinusoids for positional embedding""" 50 | assert channels % 2 == 0 51 | log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) 52 | inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) 53 | scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] 54 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) 
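
Note: the `sinusoids` helper above builds the fixed positional embedding that `AudioEncoder` later registers as a buffer. A standalone sketch that repeats the same formula purely to confirm the output shape; the 1500 x 384 dimensions are an assumption matching the tiny checkpoint used elsewhere in this repo:

```python
import numpy as np
import torch

def sinusoids(length, channels, max_timescale=10000):
    # half sine / half cosine columns, identical to the helper above
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

pe = sinusoids(1500, 384)   # assumed tiny-model dims: n_audio_ctx=1500, n_audio_state=384
print(pe.shape)             # torch.Size([1500, 384])
```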
55 | 56 | 57 | class MultiHeadAttention(nn.Module): 58 | def __init__(self, n_state: int, n_head: int): 59 | super().__init__() 60 | self.n_head = n_head 61 | self.query = Linear(n_state, n_state) 62 | self.key = Linear(n_state, n_state, bias=False) 63 | self.value = Linear(n_state, n_state) 64 | self.out = Linear(n_state, n_state) 65 | 66 | def forward( 67 | self, 68 | x: Tensor, 69 | xa: Optional[Tensor] = None, 70 | mask: Optional[Tensor] = None, 71 | kv_cache: Optional[dict] = None, 72 | ): 73 | q = self.query(x) 74 | 75 | if kv_cache is None or xa is None: 76 | # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; 77 | # otherwise, perform key/value projections for self- or cross-attention as usual. 78 | k = self.key(x if xa is None else xa) 79 | v = self.value(x if xa is None else xa) 80 | else: 81 | # for cross-attention, calculate keys and values once and reuse in subsequent calls. 82 | k = kv_cache.get(self.key, self.key(xa)) 83 | v = kv_cache.get(self.value, self.value(xa)) 84 | 85 | wv = self.qkv_attention(q, k, v, mask) 86 | return self.out(wv) 87 | 88 | def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None): 89 | n_batch, n_ctx, n_state = q.shape 90 | scale = (n_state // self.n_head) ** -0.25 91 | q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale 92 | k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale 93 | v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) 94 | 95 | qk = q @ k 96 | if mask is not None: 97 | qk = qk + mask[:n_ctx, :n_ctx] 98 | 99 | w = F.softmax(qk.float(), dim=-1).to(q.dtype) 100 | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2) 101 | 102 | 103 | class ResidualAttentionBlock(nn.Module): 104 | def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): 105 | super().__init__() 106 | 107 | self.attn = MultiHeadAttention(n_state, n_head) 108 | self.attn_ln = LayerNorm(n_state) 109 | 110 | self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None 111 | self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None 112 | 113 | n_mlp = n_state * 4 114 | self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)) 115 | self.mlp_ln = LayerNorm(n_state) 116 | 117 | def forward( 118 | self, 119 | x: Tensor, 120 | xa: Optional[Tensor] = None, 121 | mask: Optional[Tensor] = None, 122 | kv_cache: Optional[dict] = None, 123 | ): 124 | x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache) 125 | if self.cross_attn: 126 | x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache) 127 | x = x + self.mlp(self.mlp_ln(x)) 128 | return x 129 | 130 | 131 | class AudioEncoder(nn.Module): 132 | def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): 133 | super().__init__() 134 | self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) 135 | self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) 136 | self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) 137 | 138 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( 139 | [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] 140 | ) 141 | self.ln_post = LayerNorm(n_state) 142 | 143 | def forward(self, x: Tensor, include_embeddings: bool = False): 144 | """ 145 | x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) 146 | the mel spectrogram of the audio 147 | include_embeddings: bool 148 | whether to include intermediate 
steps in the output 149 | """ 150 | x = F.gelu(self.conv1(x)) 151 | x = F.gelu(self.conv2(x)) 152 | x = x.permute(0, 2, 1) 153 | 154 | assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" 155 | x = (x + self.positional_embedding).to(x.dtype) 156 | 157 | if include_embeddings: 158 | embeddings = [x.cpu().detach().numpy()] 159 | 160 | for block in self.blocks: 161 | x = block(x) 162 | if include_embeddings: 163 | embeddings.append(x.cpu().detach().numpy()) 164 | 165 | x = self.ln_post(x) 166 | 167 | if include_embeddings: 168 | embeddings = np.stack(embeddings, axis=1) 169 | return x, embeddings 170 | else: 171 | return x 172 | 173 | 174 | class TextDecoder(nn.Module): 175 | def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): 176 | super().__init__() 177 | 178 | self.token_embedding = nn.Embedding(n_vocab, n_state) 179 | self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) 180 | 181 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( 182 | [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)] 183 | ) 184 | self.ln = LayerNorm(n_state) 185 | 186 | mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) 187 | self.register_buffer("mask", mask, persistent=False) 188 | 189 | def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None, include_embeddings: bool = False): 190 | """ 191 | x : torch.LongTensor, shape = (batch_size, <= n_ctx) 192 | the text tokens 193 | xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx) 194 | the encoded audio features to be attended on 195 | include_embeddings : bool 196 | Whether to include intermediate values in the output to this function 197 | """ 198 | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 199 | x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]] 200 | x = x.to(xa.dtype) 201 | 202 | if include_embeddings: 203 | embeddings = [x.cpu().detach().numpy()] 204 | 205 | for block in self.blocks: 206 | x = block(x, xa, mask=self.mask, kv_cache=kv_cache) 207 | if include_embeddings: 208 | embeddings.append(x.cpu().detach().numpy()) 209 | 210 | x = self.ln(x) 211 | logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() 212 | 213 | if include_embeddings: 214 | embeddings = np.stack(embeddings, axis=1) 215 | return logits, embeddings 216 | else: 217 | return logits 218 | 219 | 220 | class Whisper(nn.Module): 221 | def __init__(self, dims: ModelDimensions): 222 | super().__init__() 223 | self.dims = dims 224 | self.encoder = AudioEncoder( 225 | self.dims.n_mels, 226 | self.dims.n_audio_ctx, 227 | self.dims.n_audio_state, 228 | self.dims.n_audio_head, 229 | self.dims.n_audio_layer, 230 | ) 231 | self.decoder = TextDecoder( 232 | self.dims.n_vocab, 233 | self.dims.n_text_ctx, 234 | self.dims.n_text_state, 235 | self.dims.n_text_head, 236 | self.dims.n_text_layer, 237 | ) 238 | 239 | def embed_audio(self, mel: torch.Tensor): 240 | return self.encoder.forward(mel) 241 | 242 | def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): 243 | return self.decoder.forward(tokens, audio_features) 244 | 245 | def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]: 246 | return self.decoder(tokens, self.encoder(mel)) 247 | 248 | @property 249 | def device(self): 250 | return next(self.parameters()).device 251 | 252 | @property 253 | def is_multilingual(self): 254 | return self.dims.n_vocab == 51865 255 | 
256 | def install_kv_cache_hooks(self, cache: Optional[dict] = None): 257 | """ 258 | The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value 259 | tensors calculated for the previous positions. This method returns a dictionary that stores 260 | all caches, and the necessary hooks for the key and value projection modules that save the 261 | intermediate tensors to be reused during later calculations. 262 | 263 | Returns 264 | ------- 265 | cache : Dict[nn.Module, torch.Tensor] 266 | A dictionary object mapping the key/value projection modules to its cache 267 | hooks : List[RemovableHandle] 268 | List of PyTorch RemovableHandle objects to stop the hooks to be called 269 | """ 270 | cache = {**cache} if cache is not None else {} 271 | hooks = [] 272 | 273 | def save_to_cache(module, _, output): 274 | if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]: 275 | cache[module] = output # save as-is, for the first token or cross attention 276 | else: 277 | cache[module] = torch.cat([cache[module], output], dim=1).detach() 278 | return cache[module] 279 | 280 | def install_hooks(layer: nn.Module): 281 | if isinstance(layer, MultiHeadAttention): 282 | hooks.append(layer.key.register_forward_hook(save_to_cache)) 283 | hooks.append(layer.value.register_forward_hook(save_to_cache)) 284 | 285 | self.decoder.apply(install_hooks) 286 | return cache, hooks 287 | 288 | detect_language = detect_language_function 289 | transcribe = transcribe_function 290 | decode = decode_function 291 | -------------------------------------------------------------------------------- /musetalk/whisper/whisper/normalizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import BasicTextNormalizer 2 | from .english import EnglishTextNormalizer 3 | -------------------------------------------------------------------------------- /musetalk/whisper/whisper/normalizers/basic.py: -------------------------------------------------------------------------------- 1 | import re 2 | import unicodedata 3 | 4 | import regex 5 | 6 | # non-ASCII letters that are not separated by "NFKD" normalization 7 | ADDITIONAL_DIACRITICS = { 8 | "œ": "oe", 9 | "Œ": "OE", 10 | "ø": "o", 11 | "Ø": "O", 12 | "æ": "ae", 13 | "Æ": "AE", 14 | "ß": "ss", 15 | "ẞ": "SS", 16 | "đ": "d", 17 | "Đ": "D", 18 | "ð": "d", 19 | "Ð": "D", 20 | "þ": "th", 21 | "Þ": "th", 22 | "ł": "l", 23 | "Ł": "L", 24 | } 25 | 26 | 27 | def remove_symbols_and_diacritics(s: str, keep=""): 28 | """ 29 | Replace any other markers, symbols, and punctuations with a space, 30 | and drop any diacritics (category 'Mn' and some manual mappings) 31 | """ 32 | return "".join( 33 | c 34 | if c in keep 35 | else ADDITIONAL_DIACRITICS[c] 36 | if c in ADDITIONAL_DIACRITICS 37 | else "" 38 | if unicodedata.category(c) == "Mn" 39 | else " " 40 | if unicodedata.category(c)[0] in "MSP" 41 | else c 42 | for c in unicodedata.normalize("NFKD", s) 43 | ) 44 | 45 | 46 | def remove_symbols(s: str): 47 | """ 48 | Replace any other markers, symbols, punctuations with a space, keeping diacritics 49 | """ 50 | return "".join( 51 | " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s) 52 | ) 53 | 54 | 55 | class BasicTextNormalizer: 56 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): 57 | self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols 58 | self.split_letters = split_letters 59 | 
60 | def __call__(self, s: str): 61 | s = s.lower() 62 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 63 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis 64 | s = self.clean(s).lower() 65 | 66 | if self.split_letters: 67 | s = " ".join(regex.findall(r"\X", s, regex.U)) 68 | 69 | s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space 70 | 71 | return s 72 | -------------------------------------------------------------------------------- /musetalk/whisper/whisper/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from functools import lru_cache 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | import torch 8 | from transformers import GPT2TokenizerFast 9 | 10 | LANGUAGES = { 11 | "en": "english", 12 | "zh": "chinese", 13 | "de": "german", 14 | "es": "spanish", 15 | "ru": "russian", 16 | "ko": "korean", 17 | "fr": "french", 18 | "ja": "japanese", 19 | "pt": "portuguese", 20 | "tr": "turkish", 21 | "pl": "polish", 22 | "ca": "catalan", 23 | "nl": "dutch", 24 | "ar": "arabic", 25 | "sv": "swedish", 26 | "it": "italian", 27 | "id": "indonesian", 28 | "hi": "hindi", 29 | "fi": "finnish", 30 | "vi": "vietnamese", 31 | "iw": "hebrew", 32 | "uk": "ukrainian", 33 | "el": "greek", 34 | "ms": "malay", 35 | "cs": "czech", 36 | "ro": "romanian", 37 | "da": "danish", 38 | "hu": "hungarian", 39 | "ta": "tamil", 40 | "no": "norwegian", 41 | "th": "thai", 42 | "ur": "urdu", 43 | "hr": "croatian", 44 | "bg": "bulgarian", 45 | "lt": "lithuanian", 46 | "la": "latin", 47 | "mi": "maori", 48 | "ml": "malayalam", 49 | "cy": "welsh", 50 | "sk": "slovak", 51 | "te": "telugu", 52 | "fa": "persian", 53 | "lv": "latvian", 54 | "bn": "bengali", 55 | "sr": "serbian", 56 | "az": "azerbaijani", 57 | "sl": "slovenian", 58 | "kn": "kannada", 59 | "et": "estonian", 60 | "mk": "macedonian", 61 | "br": "breton", 62 | "eu": "basque", 63 | "is": "icelandic", 64 | "hy": "armenian", 65 | "ne": "nepali", 66 | "mn": "mongolian", 67 | "bs": "bosnian", 68 | "kk": "kazakh", 69 | "sq": "albanian", 70 | "sw": "swahili", 71 | "gl": "galician", 72 | "mr": "marathi", 73 | "pa": "punjabi", 74 | "si": "sinhala", 75 | "km": "khmer", 76 | "sn": "shona", 77 | "yo": "yoruba", 78 | "so": "somali", 79 | "af": "afrikaans", 80 | "oc": "occitan", 81 | "ka": "georgian", 82 | "be": "belarusian", 83 | "tg": "tajik", 84 | "sd": "sindhi", 85 | "gu": "gujarati", 86 | "am": "amharic", 87 | "yi": "yiddish", 88 | "lo": "lao", 89 | "uz": "uzbek", 90 | "fo": "faroese", 91 | "ht": "haitian creole", 92 | "ps": "pashto", 93 | "tk": "turkmen", 94 | "nn": "nynorsk", 95 | "mt": "maltese", 96 | "sa": "sanskrit", 97 | "lb": "luxembourgish", 98 | "my": "myanmar", 99 | "bo": "tibetan", 100 | "tl": "tagalog", 101 | "mg": "malagasy", 102 | "as": "assamese", 103 | "tt": "tatar", 104 | "haw": "hawaiian", 105 | "ln": "lingala", 106 | "ha": "hausa", 107 | "ba": "bashkir", 108 | "jw": "javanese", 109 | "su": "sundanese", 110 | } 111 | 112 | # language code lookup by name, with a few language aliases 113 | TO_LANGUAGE_CODE = { 114 | **{language: code for code, language in LANGUAGES.items()}, 115 | "burmese": "my", 116 | "valencian": "ca", 117 | "flemish": "nl", 118 | "haitian": "ht", 119 | "letzeburgesch": "lb", 120 | "pushto": "ps", 121 | "panjabi": "pa", 122 | "moldavian": "ro", 123 | "moldovan": "ro", 124 | "sinhalese": "si", 125 | "castilian": "es", 126 | } 127 | 
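
Note: `get_tokenizer` (defined further down in this file) accepts either a language code from `LANGUAGES` or a full name/alias from `TO_LANGUAGE_CODE` and normalizes it before building the start-of-transcript sequence. A minimal sketch of that lookup, assuming the two tables above are in scope; `to_code` is a hypothetical helper name used only for illustration:

```python
def to_code(language: str) -> str:
    # Mirror the normalization performed at the top of get_tokenizer().
    language = language.lower()
    if language in LANGUAGES:            # already a code, e.g. "en", "zh"
        return language
    if language in TO_LANGUAGE_CODE:     # full name or alias, e.g. "castilian"
        return TO_LANGUAGE_CODE[language]
    raise ValueError(f"Unsupported language: {language}")

print(to_code("Castilian"))  # -> "es"
print(to_code("japanese"))   # -> "ja"
```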
128 | 129 | @dataclass(frozen=True) 130 | class Tokenizer: 131 | """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens""" 132 | 133 | tokenizer: "GPT2TokenizerFast" 134 | language: Optional[str] 135 | sot_sequence: Tuple[int] 136 | 137 | def encode(self, text, **kwargs): 138 | return self.tokenizer.encode(text, **kwargs) 139 | 140 | def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs): 141 | return self.tokenizer.decode(token_ids, **kwargs) 142 | 143 | def decode_with_timestamps(self, tokens) -> str: 144 | """ 145 | Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. 146 | This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". 147 | """ 148 | outputs = [[]] 149 | for token in tokens: 150 | if token >= self.timestamp_begin: 151 | timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>" 152 | outputs.append(timestamp) 153 | outputs.append([]) 154 | else: 155 | outputs[-1].append(token) 156 | outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs] 157 | return "".join(outputs) 158 | 159 | @property 160 | @lru_cache() 161 | def eot(self) -> int: 162 | return self.tokenizer.eos_token_id 163 | 164 | @property 165 | @lru_cache() 166 | def sot(self) -> int: 167 | return self._get_single_token_id("<|startoftranscript|>") 168 | 169 | @property 170 | @lru_cache() 171 | def sot_lm(self) -> int: 172 | return self._get_single_token_id("<|startoflm|>") 173 | 174 | @property 175 | @lru_cache() 176 | def sot_prev(self) -> int: 177 | return self._get_single_token_id("<|startofprev|>") 178 | 179 | @property 180 | @lru_cache() 181 | def no_speech(self) -> int: 182 | return self._get_single_token_id("<|nospeech|>") 183 | 184 | @property 185 | @lru_cache() 186 | def no_timestamps(self) -> int: 187 | return self._get_single_token_id("<|notimestamps|>") 188 | 189 | @property 190 | @lru_cache() 191 | def timestamp_begin(self) -> int: 192 | return self.tokenizer.all_special_ids[-1] + 1 193 | 194 | @property 195 | @lru_cache() 196 | def language_token(self) -> int: 197 | """Returns the token id corresponding to the value of the `language` field""" 198 | if self.language is None: 199 | raise ValueError(f"This tokenizer does not have language token configured") 200 | 201 | additional_tokens = dict( 202 | zip( 203 | self.tokenizer.additional_special_tokens, 204 | self.tokenizer.additional_special_tokens_ids, 205 | ) 206 | ) 207 | candidate = f"<|{self.language}|>" 208 | if candidate in additional_tokens: 209 | return additional_tokens[candidate] 210 | 211 | raise KeyError(f"Language {self.language} not found in tokenizer.") 212 | 213 | @property 214 | @lru_cache() 215 | def all_language_tokens(self) -> Tuple[int]: 216 | result = [] 217 | for token, token_id in zip( 218 | self.tokenizer.additional_special_tokens, 219 | self.tokenizer.additional_special_tokens_ids, 220 | ): 221 | if token.strip("<|>") in LANGUAGES: 222 | result.append(token_id) 223 | return tuple(result) 224 | 225 | @property 226 | @lru_cache() 227 | def all_language_codes(self) -> Tuple[str]: 228 | return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens) 229 | 230 | @property 231 | @lru_cache() 232 | def sot_sequence_including_notimestamps(self) -> Tuple[int]: 233 | return tuple(list(self.sot_sequence) + [self.no_timestamps]) 234 | 235 | @property 236 | @lru_cache() 237 | def non_speech_tokens(self) -> Tuple[int]: 238 | """ 239 | Returns the list of tokens to suppress in 
order to avoid any speaker tags or non-speech 240 | annotations, to prevent sampling texts that are not actually spoken in the audio, e.g. 241 | 242 | - ♪♪♪ 243 | - ( SPEAKING FOREIGN LANGUAGE ) 244 | - [DAVID] Hey there, 245 | 246 | keeping basic punctuations like commas, periods, question marks, exclamation points, etc. 247 | """ 248 | symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』") 249 | symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() 250 | 251 | # symbols that may be a single token or multiple tokens depending on the tokenizer. 252 | # In case they're multiple tokens, suppress the first token, which is safe because: 253 | # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress 254 | # in generations, and in the 3-byte UTF-8 representation they share the first two bytes. 255 | miscellaneous = set("♩♪♫♬♭♮♯") 256 | assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) 257 | 258 | # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word 259 | result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]} 260 | for symbol in symbols + list(miscellaneous): 261 | for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]: 262 | if len(tokens) == 1 or symbol in miscellaneous: 263 | result.add(tokens[0]) 264 | 265 | return tuple(sorted(result)) 266 | 267 | def _get_single_token_id(self, text) -> int: 268 | tokens = self.tokenizer.encode(text) 269 | assert len(tokens) == 1, f"{text} is not encoded as a single token" 270 | return tokens[0] 271 | 272 | 273 | @lru_cache(maxsize=None) 274 | def build_tokenizer(name: str = "gpt2"): 275 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 276 | path = os.path.join(os.path.dirname(__file__), "assets", name) 277 | tokenizer = GPT2TokenizerFast.from_pretrained(path) 278 | 279 | specials = [ 280 | "<|startoftranscript|>", 281 | *[f"<|{lang}|>" for lang in LANGUAGES.keys()], 282 | "<|translate|>", 283 | "<|transcribe|>", 284 | "<|startoflm|>", 285 | "<|startofprev|>", 286 | "<|nospeech|>", 287 | "<|notimestamps|>", 288 | ] 289 | 290 | tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) 291 | return tokenizer 292 | 293 | 294 | @lru_cache(maxsize=None) 295 | def get_tokenizer( 296 | multilingual: bool, 297 | *, 298 | task: Optional[str] = None, # Literal["transcribe", "translate", None] 299 | language: Optional[str] = None, 300 | ) -> Tokenizer: 301 | if language is not None: 302 | language = language.lower() 303 | if language not in LANGUAGES: 304 | if language in TO_LANGUAGE_CODE: 305 | language = TO_LANGUAGE_CODE[language] 306 | else: 307 | raise ValueError(f"Unsupported language: {language}") 308 | 309 | if multilingual: 310 | tokenizer_name = "multilingual" 311 | task = task or "transcribe" 312 | language = language or "en" 313 | else: 314 | tokenizer_name = "gpt2" 315 | task = None 316 | language = None 317 | 318 | tokenizer = build_tokenizer(name=tokenizer_name) 319 | all_special_ids: List[int] = tokenizer.all_special_ids 320 | sot: int = all_special_ids[1] 321 | translate: int = all_special_ids[-6] 322 | transcribe: int = all_special_ids[-5] 323 | 324 | langs = tuple(LANGUAGES.keys()) 325 | sot_sequence = [sot] 326 | if language is not None: 327 | sot_sequence.append(sot + 1 + langs.index(language)) 328 | if task is not None: 329 | sot_sequence.append(transcribe if task == "transcribe" else translate) 330 | 331 | return Tokenizer(tokenizer=tokenizer, language=language, 
sot_sequence=tuple(sot_sequence)) 332 | -------------------------------------------------------------------------------- /musetalk/whisper/whisper/transcribe.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | from typing import List, Optional, Tuple, Union, TYPE_CHECKING 5 | 6 | import numpy as np 7 | import torch 8 | import tqdm 9 | 10 | from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram 11 | from .decoding import DecodingOptions, DecodingResult 12 | from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer 13 | from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt 14 | 15 | if TYPE_CHECKING: 16 | from .model import Whisper 17 | 18 | 19 | def transcribe( 20 | model: "Whisper", 21 | audio: Union[str, np.ndarray, torch.Tensor], 22 | *, 23 | verbose: Optional[bool] = None, 24 | temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), 25 | compression_ratio_threshold: Optional[float] = 2.4, 26 | logprob_threshold: Optional[float] = -1.0, 27 | no_speech_threshold: Optional[float] = 0.6, 28 | condition_on_previous_text: bool = True, 29 | force_extraction: bool = False, 30 | **decode_options, 31 | ): 32 | """ 33 | Transcribe an audio file using Whisper 34 | 35 | Parameters 36 | ---------- 37 | model: Whisper 38 | The Whisper model instance 39 | 40 | audio: Union[str, np.ndarray, torch.Tensor] 41 | The path to the audio file to open, or the audio waveform 42 | 43 | verbose: bool 44 | Whether to display the text being decoded to the console. If True, displays all the details, 45 | If False, displays minimal details. If None, does not display anything 46 | 47 | temperature: Union[float, Tuple[float, ...]] 48 | Temperature for sampling. It can be a tuple of temperatures, which will be successfully used 49 | upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. 50 | 51 | compression_ratio_threshold: float 52 | If the gzip compression ratio is above this value, treat as failed 53 | 54 | logprob_threshold: float 55 | If the average log probability over sampled tokens is below this value, treat as failed 56 | 57 | no_speech_threshold: float 58 | If the no_speech probability is higher than this value AND the average log probability 59 | over sampled tokens is below `logprob_threshold`, consider the segment as silent 60 | 61 | condition_on_previous_text: bool 62 | if True, the previous output of the model is provided as a prompt for the next window; 63 | disabling may make the text inconsistent across windows, but the model becomes less prone to 64 | getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. 65 | 66 | decode_options: dict 67 | Keyword arguments to construct `DecodingOptions` instances 68 | 69 | Returns 70 | ------- 71 | A dictionary containing the resulting text ("text") and segment-level details ("segments"), and 72 | the spoken language ("language"), which is detected when `decode_options["language"]` is None. 
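    In this MuseTalk fork, however, the decoding stage is skipped: the returned dictionary
    holds only "segments", and each segment carries "start", "end" and the Whisper encoder
    embeddings ("encoder_embeddings") for that window rather than transcribed text.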
73 | """ 74 | dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 75 | if model.device == torch.device("cpu"): 76 | if torch.cuda.is_available(): 77 | warnings.warn("Performing inference on CPU when CUDA is available") 78 | if dtype == torch.float16: 79 | warnings.warn("FP16 is not supported on CPU; using FP32 instead") 80 | dtype = torch.float32 81 | 82 | if dtype == torch.float32: 83 | decode_options["fp16"] = False 84 | 85 | mel = log_mel_spectrogram(audio) 86 | 87 | all_segments = [] 88 | def add_segment( 89 | *, start: float, end: float, encoder_embeddings 90 | ): 91 | 92 | all_segments.append( 93 | { 94 | "start": start, 95 | "end": end, 96 | "encoder_embeddings":encoder_embeddings, 97 | } 98 | ) 99 | # show the progress bar when verbose is False (otherwise the transcribed text will be printed) 100 | num_frames = mel.shape[-1] 101 | seek = 0 102 | previous_seek_value = seek 103 | sample_skip = 3000 # 104 | with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar: 105 | while seek < num_frames: 106 | # seek是开始的帧数 107 | end_seek = min(seek + sample_skip, num_frames) 108 | segment = pad_or_trim(mel[:,seek:seek+sample_skip], N_FRAMES).to(model.device).to(dtype) 109 | 110 | single = segment.ndim == 2 111 | if single: 112 | segment = segment.unsqueeze(0) 113 | if dtype == torch.float16: 114 | segment = segment.half() 115 | audio_features, embeddings = model.encoder(segment, include_embeddings = True) 116 | 117 | encoder_embeddings = embeddings 118 | #print(f"encoder_embeddings shape {encoder_embeddings.shape}") 119 | add_segment( 120 | start=seek, 121 | end=end_seek, 122 | #text_tokens=tokens, 123 | #result=result, 124 | encoder_embeddings=encoder_embeddings, 125 | ) 126 | seek+=sample_skip 127 | 128 | return dict(segments=all_segments) 129 | 130 | 131 | def cli(): 132 | from . 
import available_models 133 | 134 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 135 | parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") 136 | parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") 137 | parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") 138 | parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") 139 | parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") 140 | parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") 141 | 142 | parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") 143 | parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection") 144 | 145 | parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling") 146 | parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") 147 | parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero") 148 | parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search") 149 | parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default") 150 | 151 | parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") 152 | parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") 153 | parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") 154 | parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") 155 | 156 | parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below") 157 | parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") 158 | parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") 159 | 
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") 160 | parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") 161 | 162 | args = parser.parse_args().__dict__ 163 | model_name: str = args.pop("model") 164 | model_dir: str = args.pop("model_dir") 165 | output_dir: str = args.pop("output_dir") 166 | device: str = args.pop("device") 167 | os.makedirs(output_dir, exist_ok=True) 168 | 169 | if model_name.endswith(".en") and args["language"] not in {"en", "English"}: 170 | if args["language"] is not None: 171 | warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.") 172 | args["language"] = "en" 173 | 174 | temperature = args.pop("temperature") 175 | temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback") 176 | if temperature_increment_on_fallback is not None: 177 | temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback)) 178 | else: 179 | temperature = [temperature] 180 | 181 | threads = args.pop("threads") 182 | if threads > 0: 183 | torch.set_num_threads(threads) 184 | 185 | from . import load_model 186 | model = load_model(model_name, device=device, download_root=model_dir) 187 | 188 | for audio_path in args.pop("audio"): 189 | result = transcribe(model, audio_path, temperature=temperature, **args) 190 | 191 | audio_basename = os.path.basename(audio_path) 192 | 193 | # save TXT 194 | with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt: 195 | write_txt(result["segments"], file=txt) 196 | 197 | # save VTT 198 | with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt: 199 | write_vtt(result["segments"], file=vtt) 200 | 201 | # save SRT 202 | with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt: 203 | write_srt(result["segments"], file=srt) 204 | 205 | 206 | if __name__ == '__main__': 207 | cli() 208 | -------------------------------------------------------------------------------- /musetalk/whisper/whisper/utils.py: -------------------------------------------------------------------------------- 1 | import zlib 2 | from typing import Iterator, TextIO 3 | 4 | 5 | def exact_div(x, y): 6 | assert x % y == 0 7 | return x // y 8 | 9 | 10 | def str2bool(string): 11 | str2val = {"True": True, "False": False} 12 | if string in str2val: 13 | return str2val[string] 14 | else: 15 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") 16 | 17 | 18 | def optional_int(string): 19 | return None if string == "None" else int(string) 20 | 21 | 22 | def optional_float(string): 23 | return None if string == "None" else float(string) 24 | 25 | 26 | def compression_ratio(text) -> float: 27 | return len(text) / len(zlib.compress(text.encode("utf-8"))) 28 | 29 | 30 | def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'): 31 | assert seconds >= 0, "non-negative timestamp expected" 32 | milliseconds = round(seconds * 1000.0) 33 | 34 | hours = milliseconds // 3_600_000 35 | milliseconds -= hours * 3_600_000 36 | 37 | minutes = milliseconds // 60_000 38 | milliseconds -= minutes * 60_000 39 | 40 | seconds = milliseconds // 1_000 
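    # `milliseconds` still holds the sub-minute remainder here; the next line strips the
    # whole seconds, leaving only the final millisecond field. Worked example:
    # 3661.789 s -> hours=1, minutes=1, seconds=1, milliseconds=789 -> "01:01:01.789".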
41 | milliseconds -= seconds * 1_000 42 | 43 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 44 | return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 45 | 46 | 47 | def write_txt(transcript: Iterator[dict], file: TextIO): 48 | for segment in transcript: 49 | print(segment['text'].strip(), file=file, flush=True) 50 | 51 | 52 | def write_vtt(transcript: Iterator[dict], file: TextIO): 53 | print("WEBVTT\n", file=file) 54 | for segment in transcript: 55 | print( 56 | f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" 57 | f"{segment['text'].strip().replace('-->', '->')}\n", 58 | file=file, 59 | flush=True, 60 | ) 61 | 62 | 63 | def write_srt(transcript: Iterator[dict], file: TextIO): 64 | """ 65 | Write a transcript to a file in SRT format. 66 | 67 | Example usage: 68 | from pathlib import Path 69 | from whisper.utils import write_srt 70 | 71 | result = transcribe(model, audio_path, temperature=temperature, **args) 72 | 73 | # save SRT 74 | audio_basename = Path(audio_path).stem 75 | with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt: 76 | write_srt(result["segments"], file=srt) 77 | """ 78 | for i, segment in enumerate(transcript, start=1): 79 | # write srt lines 80 | print( 81 | f"{i}\n" 82 | f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " 83 | f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" 84 | f"{segment['text'].strip().replace('-->', '->')}\n", 85 | file=file, 86 | flush=True, 87 | ) 88 | -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import folder_paths 3 | from .inference import MuseTalk_INFER 4 | from .inference_realtime import Infer_Real_Time 5 | from pydub import AudioSegment 6 | from moviepy.editor import VideoFileClip,AudioFileClip 7 | 8 | parent_directory = os.path.dirname(os.path.abspath(__file__)) 9 | input_path = folder_paths.get_input_directory() 10 | out_path = folder_paths.get_output_directory() 11 | 12 | class MuseTalkRealTime: 13 | @classmethod 14 | def INPUT_TYPES(s): 15 | return { 16 | "required":{ 17 | "audio":("AUDIO",), 18 | "video":("VIDEO",), 19 | "avatar_id":("STRING",{ 20 | "default": "talker1" 21 | }), 22 | "bbox_shift":("INT",{ 23 | "default":0 24 | }), 25 | "fps":("INT",{ 26 | "default":25 27 | }), 28 | "batch_size":("INT",{ 29 | "default":4 30 | }), 31 | "preparation":("BOOLEAN",{ 32 | "default":True 33 | }) 34 | } 35 | } 36 | CATEGORY = "AIFSH_MuseTalk" 37 | DESCRIPTION = "hello world!" 
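    # The attributes below follow the standard ComfyUI custom-node contract: RETURN_TYPES
    # lists the output socket types, OUTPUT_NODE marks whether ComfyUI should treat the node
    # as a workflow output, and FUNCTION names the method called with the values collected
    # from INPUT_TYPES (here: process()).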
38 | 39 | RETURN_TYPES = ("VIDEO",) 40 | 41 | OUTPUT_NODE = False 42 | 43 | FUNCTION = "process" 44 | 45 | def process(self,audio,video,avatar_id,bbox_shift,fps,batch_size,preparation): 46 | muse_talk_real_time = Infer_Real_Time() 47 | output_vid_name = muse_talk_real_time(audio, video,avatar_id,fps=fps,batch_size=batch_size, 48 | preparation=preparation,bbox_shift=bbox_shift) 49 | return (output_vid_name,) 50 | 51 | 52 | class MuseTalk: 53 | @classmethod 54 | def INPUT_TYPES(s): 55 | return { 56 | "required":{ 57 | "audio":("AUDIO",), 58 | "video":("VIDEO",), 59 | "bbox_shift":("INT",{ 60 | "default":0 61 | }), 62 | "fps":("INT",{ 63 | "default":25 64 | }), 65 | "batch_size":("INT",{ 66 | "default":8 67 | }), 68 | "batch_size_fa":("INT",{ 69 | "default":2 70 | }), 71 | "use_saved_coord":("BOOLEAN",{ 72 | "default":False 73 | }) 74 | } 75 | } 76 | CATEGORY = "AIFSH_MuseTalk" 77 | DESCRIPTION = "hello world!" 78 | 79 | RETURN_TYPES = ("VIDEO",) 80 | 81 | OUTPUT_NODE = False 82 | 83 | FUNCTION = "process" 84 | 85 | def process(self,audio,video,bbox_shift,fps,batch_size,batch_size_fa,use_saved_coord): 86 | muse_talk = MuseTalk_INFER(bbox_shift,fps,batch_size,batch_size_fa,use_saved_coord) 87 | output_vid_name = muse_talk(video, audio) 88 | return (output_vid_name,) 89 | 90 | 91 | class CombineAudioVideo: 92 | @classmethod 93 | def INPUT_TYPES(s): 94 | return {"required": 95 | {"vocal_AUDIO": ("AUDIO",), 96 | "bgm_AUDIO": ("AUDIO",), 97 | "video": ("VIDEO",) 98 | } 99 | } 100 | 101 | CATEGORY = "AIFSH_MuseTalk" 102 | DESCRIPTION = "hello world!" 103 | 104 | RETURN_TYPES = ("VIDEO",) 105 | 106 | OUTPUT_NODE = False 107 | 108 | FUNCTION = "combine" 109 | 110 | def combine(self, vocal_AUDIO,bgm_AUDIO,video): 111 | vocal = AudioSegment.from_file(vocal_AUDIO) 112 | bgm = AudioSegment.from_file(bgm_AUDIO) 113 | audio = vocal.overlay(bgm) 114 | audio_file = os.path.join(out_path,"ip_lap_voice.wav") 115 | audio.export(audio_file, format="wav") 116 | cm_video_file = os.path.join(out_path,"voice_"+os.path.basename(video)) 117 | video_clip = VideoFileClip(video) 118 | audio_clip = AudioFileClip(audio_file) 119 | new_video_clip = video_clip.set_audio(audio_clip) 120 | new_video_clip.write_videofile(cm_video_file) 121 | return (cm_video_file,) 122 | 123 | 124 | class PreViewVideo: 125 | @classmethod 126 | def INPUT_TYPES(s): 127 | return {"required":{ 128 | "video":("VIDEO",), 129 | }} 130 | 131 | CATEGORY = "AIFSH_MuseTalk" 132 | DESCRIPTION = "hello world!" 133 | 134 | RETURN_TYPES = () 135 | 136 | OUTPUT_NODE = True 137 | 138 | FUNCTION = "load_video" 139 | 140 | def load_video(self, video): 141 | video_name = os.path.basename(video) 142 | video_path_name = os.path.basename(os.path.dirname(video)) 143 | return {"ui":{"video":[video_name,video_path_name]}} 144 | 145 | class LoadVideo: 146 | @classmethod 147 | def INPUT_TYPES(s): 148 | files = [f for f in os.listdir(input_path) if os.path.isfile(os.path.join(input_path, f)) and f.split('.')[-1] in ["mp4", "webm","mkv","avi"]] 149 | return {"required":{ 150 | "video":(files,), 151 | }} 152 | 153 | CATEGORY = "AIFSH_MuseTalk" 154 | DESCRIPTION = "hello world!" 
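    # LoadVideo returns two values: the path of the chosen video ("VIDEO") and the path of a
    # WAV file extracted from its audio track with moviepy ("AUDIO"); load_video() below
    # writes that WAV next to the source video in the ComfyUI input directory.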
155 | 156 | RETURN_TYPES = ("VIDEO","AUDIO") 157 | 158 | OUTPUT_NODE = False 159 | 160 | FUNCTION = "load_video" 161 | 162 | def load_video(self, video): 163 | video_path = os.path.join(input_path,video) 164 | video_clip = VideoFileClip(video_path) 165 | audio_path = os.path.join(input_path,video+".wav") 166 | video_clip.audio.write_audiofile(audio_path) 167 | return (video_path,audio_path,) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | ffmpeg-python 3 | soundfile 4 | diffusers 5 | pydub 6 | moviepy 7 | accelerate -------------------------------------------------------------------------------- /web.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/web.png -------------------------------------------------------------------------------- /web/js/previewVideo.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | import { api } from '../../../scripts/api.js' 3 | 4 | function fitHeight(node) { 5 | node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]) 6 | node?.graph?.setDirtyCanvas(true); 7 | } 8 | function chainCallback(object, property, callback) { 9 | if (object == undefined) { 10 | //This should not happen. 11 | console.error("Tried to add callback to non-existant object") 12 | return; 13 | } 14 | if (property in object) { 15 | const callback_orig = object[property] 16 | object[property] = function () { 17 | const r = callback_orig.apply(this, arguments); 18 | callback.apply(this, arguments); 19 | return r 20 | }; 21 | } else { 22 | object[property] = callback; 23 | } 24 | } 25 | 26 | function addPreviewOptions(nodeType) { 27 | chainCallback(nodeType.prototype, "getExtraMenuOptions", function(_, options) { 28 | // The intended way of appending options is returning a list of extra options, 29 | // but this isn't used in widgetInputs.js and would require 30 | // less generalization of chainCallback 31 | let optNew = [] 32 | try { 33 | const previewWidget = this.widgets.find((w) => w.name === "videopreview"); 34 | 35 | let url = null 36 | if (previewWidget.videoEl?.hidden == false && previewWidget.videoEl.src) { 37 | //Use full quality video 38 | //url = api.apiURL('/view?' 
+ new URLSearchParams(previewWidget.value.params)); 39 | url = previewWidget.videoEl.src 40 | } 41 | if (url) { 42 | optNew.push( 43 | { 44 | content: "Open preview", 45 | callback: () => { 46 | window.open(url, "_blank") 47 | }, 48 | }, 49 | { 50 | content: "Save preview", 51 | callback: () => { 52 | const a = document.createElement("a"); 53 | a.href = url; 54 | a.setAttribute("download", new URLSearchParams(previewWidget.value.params).get("filename")); 55 | document.body.append(a); 56 | a.click(); 57 | requestAnimationFrame(() => a.remove()); 58 | }, 59 | } 60 | ); 61 | } 62 | if(options.length > 0 && options[0] != null && optNew.length > 0) { 63 | optNew.push(null); 64 | } 65 | options.unshift(...optNew); 66 | 67 | } catch (error) { 68 | console.log(error); 69 | } 70 | 71 | }); 72 | } 73 | function previewVideo(node,file,type){ 74 | var element = document.createElement("div"); 75 | const previewNode = node; 76 | var previewWidget = node.addDOMWidget("videopreview", "preview", element, { 77 | serialize: false, 78 | hideOnZoom: false, 79 | getValue() { 80 | return element.value; 81 | }, 82 | setValue(v) { 83 | element.value = v; 84 | }, 85 | }); 86 | previewWidget.computeSize = function(width) { 87 | if (this.aspectRatio && !this.parentEl.hidden) { 88 | let height = (previewNode.size[0]-20)/ this.aspectRatio + 10; 89 | if (!(height > 0)) { 90 | height = 0; 91 | } 92 | this.computedHeight = height + 10; 93 | return [width, height]; 94 | } 95 | return [width, -4];//no loaded src, widget should not display 96 | } 97 | // element.style['pointer-events'] = "none" 98 | previewWidget.value = {hidden: false, paused: false, params: {}} 99 | previewWidget.parentEl = document.createElement("div"); 100 | previewWidget.parentEl.className = "video_preview"; 101 | previewWidget.parentEl.style['width'] = "100%" 102 | element.appendChild(previewWidget.parentEl); 103 | previewWidget.videoEl = document.createElement("video"); 104 | previewWidget.videoEl.controls = true; 105 | previewWidget.videoEl.loop = false; 106 | previewWidget.videoEl.muted = false; 107 | previewWidget.videoEl.style['width'] = "100%" 108 | previewWidget.videoEl.addEventListener("loadedmetadata", () => { 109 | 110 | previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight; 111 | fitHeight(this); 112 | }); 113 | previewWidget.videoEl.addEventListener("error", () => { 114 | //TODO: consider a way to properly notify the user why a preview isn't shown. 115 | previewWidget.parentEl.hidden = true; 116 | fitHeight(this); 117 | }); 118 | 119 | let params = { 120 | "filename": file, 121 | "type": type, 122 | } 123 | 124 | previewWidget.parentEl.hidden = previewWidget.value.hidden; 125 | previewWidget.videoEl.autoplay = !previewWidget.value.paused && !previewWidget.value.hidden; 126 | let target_width = 256 127 | if (element.style?.width) { 128 | //overscale to allow scrolling. Endpoint won't return higher than native 129 | target_width = element.style.width.slice(0,-2)*2; 130 | } 131 | if (!params.force_size || params.force_size.includes("?") || params.force_size == "Disabled") { 132 | params.force_size = target_width+"x?" 133 | } else { 134 | let size = params.force_size.split("x") 135 | let ar = parseInt(size[0])/parseInt(size[1]) 136 | params.force_size = target_width+"x"+(target_width/ar) 137 | } 138 | 139 | previewWidget.videoEl.src = api.apiURL('/view?' 
+ new URLSearchParams(params)); 140 | 141 | previewWidget.videoEl.hidden = false; 142 | previewWidget.parentEl.appendChild(previewWidget.videoEl) 143 | } 144 | 145 | app.registerExtension({ 146 | name: "MuseTalk.VideoPreviewer", 147 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 148 | if (nodeData?.name == "PreViewVideo") { 149 | nodeType.prototype.onExecuted = function (data) { 150 | previewVideo(this, data.video[0], data.video[1]); 151 | } 152 | addPreviewOptions(nodeType) 153 | } 154 | } 155 | }); 156 | -------------------------------------------------------------------------------- /web/js/uploadVideo.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | import { api } from '../../../scripts/api.js' 3 | import { ComfyWidgets } from "../../../scripts/widgets.js" 4 | 5 | function fitHeight(node) { 6 | node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]) 7 | node?.graph?.setDirtyCanvas(true); 8 | } 9 | 10 | function previewVideo(node,file){ 11 | while (node.widgets.length > 2){ 12 | node.widgets.pop() 13 | } 14 | try { 15 | var el = document.getElementById("uploadVideo"); 16 | el.remove(); 17 | } catch (error) { 18 | console.log(error); 19 | } 20 | var element = document.createElement("div"); 21 | element.id = "uploadVideo"; 22 | const previewNode = node; 23 | var previewWidget = node.addDOMWidget("videopreview", "preview", element, { 24 | serialize: false, 25 | hideOnZoom: false, 26 | getValue() { 27 | return element.value; 28 | }, 29 | setValue(v) { 30 | element.value = v; 31 | }, 32 | }); 33 | previewWidget.computeSize = function(width) { 34 | if (this.aspectRatio && !this.parentEl.hidden) { 35 | let height = (previewNode.size[0]-20)/ this.aspectRatio + 10; 36 | if (!(height > 0)) { 37 | height = 0; 38 | } 39 | this.computedHeight = height + 10; 40 | return [width, height]; 41 | } 42 | return [width, -4];//no loaded src, widget should not display 43 | } 44 | // element.style['pointer-events'] = "none" 45 | previewWidget.value = {hidden: false, paused: false, params: {}} 46 | previewWidget.parentEl = document.createElement("div"); 47 | previewWidget.parentEl.className = "video_preview"; 48 | previewWidget.parentEl.style['width'] = "100%" 49 | element.appendChild(previewWidget.parentEl); 50 | previewWidget.videoEl = document.createElement("video"); 51 | previewWidget.videoEl.controls = true; 52 | previewWidget.videoEl.loop = false; 53 | previewWidget.videoEl.muted = false; 54 | previewWidget.videoEl.style['width'] = "100%" 55 | previewWidget.videoEl.addEventListener("loadedmetadata", () => { 56 | 57 | previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight; 58 | fitHeight(this); 59 | }); 60 | previewWidget.videoEl.addEventListener("error", () => { 61 | //TODO: consider a way to properly notify the user why a preview isn't shown. 62 | previewWidget.parentEl.hidden = true; 63 | fitHeight(this); 64 | }); 65 | 66 | let params = { 67 | "filename": file, 68 | "type": "input", 69 | } 70 | 71 | previewWidget.parentEl.hidden = previewWidget.value.hidden; 72 | previewWidget.videoEl.autoplay = !previewWidget.value.paused && !previewWidget.value.hidden; 73 | let target_width = 256 74 | if (element.style?.width) { 75 | //overscale to allow scrolling. 
Endpoint won't return higher than native 76 | target_width = element.style.width.slice(0,-2)*2; 77 | } 78 | if (!params.force_size || params.force_size.includes("?") || params.force_size == "Disabled") { 79 | params.force_size = target_width+"x?" 80 | } else { 81 | let size = params.force_size.split("x") 82 | let ar = parseInt(size[0])/parseInt(size[1]) 83 | params.force_size = target_width+"x"+(target_width/ar) 84 | } 85 | 86 | previewWidget.videoEl.src = api.apiURL('/view?' + new URLSearchParams(params)); 87 | 88 | previewWidget.videoEl.hidden = false; 89 | previewWidget.parentEl.appendChild(previewWidget.videoEl) 90 | } 91 | 92 | function videoUpload(node, inputName, inputData, app) { 93 | const videoWidget = node.widgets.find((w) => w.name === "video"); 94 | let uploadWidget; 95 | /* 96 | A method that returns the required style for the html 97 | */ 98 | var default_value = videoWidget.value; 99 | Object.defineProperty(videoWidget, "value", { 100 | set : function(value) { 101 | this._real_value = value; 102 | }, 103 | 104 | get : function() { 105 | let value = ""; 106 | if (this._real_value) { 107 | value = this._real_value; 108 | } else { 109 | return default_value; 110 | } 111 | 112 | if (value.filename) { 113 | let real_value = value; 114 | value = ""; 115 | if (real_value.subfolder) { 116 | value = real_value.subfolder + "/"; 117 | } 118 | 119 | value += real_value.filename; 120 | 121 | if(real_value.type && real_value.type !== "input") 122 | value += ` [${real_value.type}]`; 123 | } 124 | return value; 125 | } 126 | }); 127 | async function uploadFile(file, updateNode, pasted = false) { 128 | try { 129 | // Wrap file in formdata so it includes filename 130 | const body = new FormData(); 131 | body.append("image", file); 132 | if (pasted) body.append("subfolder", "pasted"); 133 | const resp = await api.fetchApi("/upload/image", { 134 | method: "POST", 135 | body, 136 | }); 137 | 138 | if (resp.status === 200) { 139 | const data = await resp.json(); 140 | // Add the file to the dropdown list and update the widget value 141 | let path = data.name; 142 | if (data.subfolder) path = data.subfolder + "/" + path; 143 | 144 | if (!videoWidget.options.values.includes(path)) { 145 | videoWidget.options.values.push(path); 146 | } 147 | 148 | if (updateNode) { 149 | videoWidget.value = path; 150 | previewVideo(node,path) 151 | 152 | } 153 | } else { 154 | alert(resp.status + " - " + resp.statusText); 155 | } 156 | } catch (error) { 157 | alert(error); 158 | } 159 | } 160 | 161 | const fileInput = document.createElement("input"); 162 | Object.assign(fileInput, { 163 | type: "file", 164 | accept: "video/webm,video/mp4,video/mkv,video/avi", 165 | style: "display: none", 166 | onchange: async () => { 167 | if (fileInput.files.length) { 168 | await uploadFile(fileInput.files[0], true); 169 | } 170 | }, 171 | }); 172 | document.body.append(fileInput); 173 | 174 | // Create the button widget for selecting the files 175 | uploadWidget = node.addWidget("button", "choose video file to upload", "Video", () => { 176 | fileInput.click(); 177 | }); 178 | 179 | uploadWidget.serialize = false; 180 | 181 | previewVideo(node, videoWidget.value); 182 | const cb = node.callback; 183 | videoWidget.callback = function () { 184 | previewVideo(node,videoWidget.value); 185 | if (cb) { 186 | return cb.apply(this, arguments); 187 | } 188 | }; 189 | 190 | return { widget: uploadWidget }; 191 | } 192 | 193 | ComfyWidgets.VIDEOPLOAD = videoUpload; 194 | 195 | app.registerExtension({ 196 | name: "MuseTalk.UploadVideo", 197 
| async beforeRegisterNodeDef(nodeType, nodeData, app) { 198 | if (nodeData?.name == "LoadVideo") { 199 | nodeData.input.required.upload = ["VIDEOPLOAD"]; 200 | } 201 | }, 202 | }); 203 | 204 | -------------------------------------------------------------------------------- /wechat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/wechat.jpg --------------------------------------------------------------------------------