├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── donate.jpg
├── inference.py
├── inference_realtime.py
├── musetalk
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── unet.cpython-310.pyc
│   │   │   └── vae.cpython-310.pyc
│   │   ├── unet.py
│   │   └── vae.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── blending.cpython-310.pyc
│   │   │   ├── preprocessing.cpython-310.pyc
│   │   │   └── utils.cpython-310.pyc
│   │   ├── blending.py
│   │   ├── dwpose
│   │   │   ├── default_runtime.py
│   │   │   └── rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py
│   │   ├── face_detection
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   ├── api.cpython-310.pyc
│   │   │   │   ├── models.cpython-310.pyc
│   │   │   │   └── utils.cpython-310.pyc
│   │   │   ├── api.py
│   │   │   ├── detection
│   │   │   │   ├── __init__.py
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   │   └── core.cpython-310.pyc
│   │   │   │   ├── core.py
│   │   │   │   └── sfd
│   │   │   │       ├── __init__.py
│   │   │   │       ├── __pycache__
│   │   │   │       │   ├── __init__.cpython-310.pyc
│   │   │   │       │   ├── bbox.cpython-310.pyc
│   │   │   │       │   ├── detect.cpython-310.pyc
│   │   │   │       │   ├── net_s3fd.cpython-310.pyc
│   │   │   │       │   └── sfd_detector.cpython-310.pyc
│   │   │   │       ├── bbox.py
│   │   │   │       ├── detect.py
│   │   │   │       ├── net_s3fd.py
│   │   │   │       └── sfd_detector.py
│   │   │   ├── models.py
│   │   │   └── utils.py
│   │   ├── face_parsing
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   ├── model.cpython-310.pyc
│   │   │   │   └── resnet.cpython-310.pyc
│   │   │   ├── model.py
│   │   │   └── resnet.py
│   │   ├── preprocessing.py
│   │   └── utils.py
│   └── whisper
│       ├── __pycache__
│       │   └── audio2feature.cpython-310.pyc
│       ├── audio2feature.py
│       └── whisper
│           ├── __init__.py
│           ├── __main__.py
│           ├── __pycache__
│           │   ├── __init__.cpython-310.pyc
│           │   ├── audio.cpython-310.pyc
│           │   ├── decoding.cpython-310.pyc
│           │   ├── model.cpython-310.pyc
│           │   ├── tokenizer.cpython-310.pyc
│           │   ├── transcribe.cpython-310.pyc
│           │   └── utils.cpython-310.pyc
│           ├── assets
│           │   ├── gpt2
│           │   │   ├── merges.txt
│           │   │   ├── special_tokens_map.json
│           │   │   ├── tokenizer_config.json
│           │   │   └── vocab.json
│           │   ├── mel_filters.npz
│           │   └── multilingual
│           │       ├── added_tokens.json
│           │       ├── merges.txt
│           │       ├── special_tokens_map.json
│           │       ├── tokenizer_config.json
│           │       └── vocab.json
│           ├── audio.py
│           ├── decoding.py
│           ├── model.py
│           ├── normalizers
│           │   ├── __init__.py
│           │   ├── basic.py
│           │   ├── english.json
│           │   └── english.py
│           ├── tokenizer.py
│           ├── transcribe.py
│           └── utils.py
├── nodes.py
├── requirements.txt
├── web.png
├── web
│   └── js
│       ├── previewVideo.js
│       └── uploadVideo.js
└── wechat.jpg
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | /models
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | This software and its related code are open-sourced under the MIT License. The author has no control over the software; anyone who uses the software or distributes audio generated with it does so entirely at their own risk.
2 | If you do not accept these terms, you may not use or reference any code or files in this package.
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 | The Software is provided "as is", without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose, and noninfringement. In no event shall the authors or copyright holders be liable for any claim, damages, or other liability, whether in an action of contract, tort, or otherwise, arising from, out of, or in connection with the Software or the use of or other dealings in the Software.
7 |
8 |
9 |
10 | MIT License
11 |
12 | Copyright (c) 2024 AIFSH
13 |
14 | Permission is hereby granted, free of charge, to any person obtaining a copy
15 | of this software and associated documentation files (the "Software"), to deal
16 | in the Software without restriction, including without limitation the rights
17 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18 | copies of the Software, and to permit persons to whom the Software is
19 | furnished to do so, subject to the following conditions:
20 |
21 | The above copyright notice and this permission notice shall be included in all
22 | copies or substantial portions of the Software.
23 |
24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30 | SOFTWARE.
31 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI-MuseTalk_FSH
2 | a ComfyUI custom node for [MuseTalk](https://github.com/TMElyralab/MuseTalk.git) to make audio-driven talking videos!
3 |
4 |
5 |
6 |
7 |
8 |
9 | ## How to use
10 | make sure `ffmpeg` works in your command line
11 | for Linux:
12 | ```
13 | apt update
14 | apt install ffmpeg
15 | ```
16 | for Windows, you can install `ffmpeg` automatically with [WingetUI](https://github.com/marticliment/WingetUI)
17 |
18 | then clone this repo and install its requirements:
19 | ```
20 | git clone https://github.com/AIFSH/ComfyUI-MuseTalk_FSH.git
21 | cd ComfyUI-MuseTalk_FSH
22 | pip install -r requirements.txt
23 | ```
24 | ### mmlab packages
25 | ```bash
26 | pip install --no-cache-dir -U openmim
27 | mim install mmengine
28 | mim install "mmcv>=2.0.1"
29 | mim install "mmdet>=3.1.0"
30 | mim install "mmpose>=1.1.0"
31 | ```
32 |
33 | ### Download weights
34 | You can download weights manually as follows:
35 |
36 | 1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk).
37 |
38 | 2. Download the weights of other components:
39 | - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
40 | - [whisper](https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt)
41 | - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
42 | - [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
43 | - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
44 |
45 | Or download [MuseTalk.zip](https://pan.quark.cn/s/d6e76084ae92),
46 | unzip it and put its subfolders into the `ComfyUI-MuseTalk_FSH/models/` directory
47 |
48 | Finally, these weights should be organized in `models` as follows:
49 | ```
50 | ComfyUI-MuseTalk_FSH/models/
51 | ├── musetalk
52 | │   ├── musetalk.json
53 | │ └── pytorch_model.bin
54 | ├── dwpose
55 | │ └── dw-ll_ucoco_384.pth
56 | ├── face-parse-bisent
57 | │ ├── 79999_iter.pth
58 | │ └── resnet18-5c106cde.pth
59 | ├── sd-vae-ft-mse
60 | │ ├── config.json
61 | │ └── diffusion_pytorch_model.bin
62 | └── whisper
63 | └── tiny.pt
64 | ```
65 |
66 | ## Tutorial
67 | - [Demo on 3060 12GB](https://www.bilibili.com/video/BV1St421w7Qn)
68 | - [Demo on 4090 24GB](https://www.bilibili.com/video/BV13T42117uM/)
69 |
70 |
71 | ## WeChat Group && Donate
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 | ## Thanks
80 | - [MuseTalk](https://github.com/TMElyralab/MuseTalk.git)
81 |
--------------------------------------------------------------------------------
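
The README's "Download weights" section above lists the Hugging Face repos to fetch by hand. As a convenience, here is a minimal sketch (not part of this repo) of pulling the two Hugging Face-hosted pieces with `huggingface_hub`; the `local_dir` targets mirror the `models/` layout shown in the README, but the exact folder structure inside each remote repo should be double-checked against that tree:

```python
# Hypothetical helper: fetch the HF-hosted weights into ComfyUI-MuseTalk_FSH/models/.
# The whisper tiny.pt, dwpose and face-parse-bisent files come from direct URLs
# (see the README) and are not covered here.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="TMElyralab/MuseTalk",
                  local_dir="ComfyUI-MuseTalk_FSH/models")
snapshot_download(repo_id="stabilityai/sd-vae-ft-mse",
                  local_dir="ComfyUI-MuseTalk_FSH/models/sd-vae-ft-mse")
```
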
/__init__.py:
--------------------------------------------------------------------------------
1 | from .nodes import MuseTalk,LoadVideo,PreViewVideo,CombineAudioVideo,MuseTalkRealTime
2 | WEB_DIRECTORY = "./web"
3 | # A dictionary that contains all nodes you want to export with their names
4 | # NOTE: names should be globally unique
5 | NODE_CLASS_MAPPINGS = {
6 | "MuseTalk": MuseTalk,
7 | "LoadVideo": LoadVideo,
8 | "PreViewVideo": PreViewVideo,
9 | "CombineAudioVideo": CombineAudioVideo,
10 | "MuseTalkRealTime": MuseTalkRealTime
11 | }
12 |
13 | # A dictionary that contains the friendly/humanly readable titles for the nodes
14 | NODE_DISPLAY_NAME_MAPPINGS = {
15 | "MuseTalk": "MuseTalk Node",
16 | "LoadVideo": "Video Loader",
17 | "PreViewVideo": "PreView Video",
18 | "CombineAudioVideo": "Combine Audio Video",
19 | "MuseTalkRealTime": "MuseTalk RealTime Node"
20 | }
21 |
--------------------------------------------------------------------------------
/donate.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/donate.jpg
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import cv2
4 | import glob,copy
5 | import shutil
6 | from tqdm import tqdm
7 | import torch
8 | import pickle
9 | import numpy as np
10 | from cuda_malloc import cuda_malloc_supported
11 | from typing import Any
12 | import folder_paths
13 | from .musetalk.utils.face_parsing import FaceParsing
14 | from mmpose.apis import init_model
15 | from .musetalk.utils.blending import get_image
16 | from .musetalk.utils.utils import load_all_model,get_file_type,get_video_fps,datagen
17 | from .musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
18 |
19 | parent_directory = os.path.dirname(os.path.abspath(__file__))
20 | input_path = folder_paths.get_input_directory()
21 | out_path = folder_paths.get_output_directory()
22 |
23 | class MuseTalk_INFER:
24 | def __init__(self,bbox_shift=0,fps=25,
25 | batch_size=8,batch_size_fa=2,
26 | use_saved_coord=False) -> None:
27 | self.fps = fps
28 | self.bbox_shift = bbox_shift
29 | self.batch_size = batch_size
30 | self.batch_size_fa = batch_size_fa
31 | self.use_saved_coord = use_saved_coord
32 | self.device = torch.device("cuda" if cuda_malloc_supported() else "cpu")
33 | config_file = os.path.join(parent_directory,"musetalk","utils","dwpose","rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py")
34 | checkpoint_file = os.path.join(parent_directory,'models','dwpose','dw-ll_ucoco_384.pth')
35 | resnet_path = os.path.join(parent_directory,'models','face-parse-bisent','resnet18-5c106cde.pth')
36 | face_model_pth = os.path.join(parent_directory,'models','face-parse-bisent','79999_iter.pth')
37 | self.fp_model = FaceParsing(resnet_path,face_model_pth)
38 | self.dwpose_model = init_model(config_file, checkpoint_file, device=self.device)
39 | self.audio_processor,self.vae,self.unet,self.pe = load_all_model(os.path.join(parent_directory,"models"))
40 | self.timesteps = torch.tensor([0], device=self.device)
41 |
42 | def __call__(self, video_path,audio_path,*args: Any, **kwds: Any) -> Any:
43 | input_basename = os.path.basename(video_path).split('.')[0]
44 | audio_basename = os.path.basename(audio_path).split('.')[0]
45 | output_basename = f"{input_basename}_{audio_basename}"
46 | result_img_save_path = os.path.join(out_path,"musetalk_result", output_basename) # related to video & audio inputs
47 | os.makedirs(result_img_save_path, exist_ok=True)
48 | crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
49 | output_vid_name = os.path.join(out_path, output_basename+".mp4")
50 |
51 | ############################################## extract frames from source video ##############################################
52 | if get_file_type(video_path)=="video":
53 | save_dir_full = os.path.join(out_path,"musetalk_result",input_basename)
54 | os.makedirs(save_dir_full,exist_ok = True)
55 | png_path = os.path.join(save_dir_full,"%08d.png")
56 | cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {png_path}"
57 | os.system(cmd)
58 | input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
59 | fps = get_video_fps(video_path)
60 | else: # input img folder
61 | input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
62 | input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
63 | fps = self.fps
64 |
65 | #print(input_img_list)
66 | ############################################## extract audio feature ##############################################
67 | whisper_feature = self.audio_processor.audio2feat(audio_path)
68 | whisper_chunks = self.audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
69 | ############################################## preprocess input image ##############################################
70 | if os.path.exists(crop_coord_save_path) and self.use_saved_coord:
71 | print("using extracted coordinates")
72 | with open(crop_coord_save_path,'rb') as f:
73 | coord_list = pickle.load(f)
74 | frame_list = read_imgs(input_img_list)
75 | else:
76 | print("extracting landmarks...time consuming")
77 | coord_list, frame_list = get_landmark_and_bbox(self.dwpose_model,input_img_list, self.batch_size_fa,self.bbox_shift)
78 | with open(crop_coord_save_path, 'wb') as f:
79 | pickle.dump(coord_list, f)
80 |
81 | i = 0
82 | input_latent_list = []
83 | for bbox, frame in zip(coord_list, frame_list):
84 | if bbox == coord_placeholder:
85 | continue
86 | x1, y1, x2, y2 = bbox
87 | crop_frame = frame[y1:y2, x1:x2]
88 | crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
89 | latents = self.vae.get_latents_for_unet(crop_frame)
90 | input_latent_list.append(latents)
91 |
92 | # to smooth the first and the last frame
93 | frame_list_cycle = frame_list + frame_list[::-1]
94 | coord_list_cycle = coord_list + coord_list[::-1]
95 | input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
96 | ############################################## inference batch by batch ##############################################
97 | print("start inference")
98 | video_num = len(whisper_chunks)
99 | batch_size = self.batch_size
100 | gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
101 | res_frame_list = []
102 | for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
103 | tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
104 | audio_feature_batch = torch.stack(tensor_list).to(self.unet.device) # torch, B, 5*N,384
105 | audio_feature_batch = self.pe(audio_feature_batch)
106 |
107 | pred_latents = self.unet.model(latent_batch, self.timesteps, encoder_hidden_states=audio_feature_batch).sample
108 | recon = self.vae.decode_latents(pred_latents)
109 | for res_frame in recon:
110 | res_frame_list.append(res_frame)
111 |
112 | ############################################## pad to full image ##############################################
113 | print("pad talking image to original video")
114 | for i, res_frame in enumerate(tqdm(res_frame_list)):
115 | bbox = coord_list_cycle[i%(len(coord_list_cycle))]
116 | ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
117 | x1, y1, x2, y2 = bbox
118 | try:
119 | res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
120 | except:
121 | # print(bbox)
122 | continue
123 |
124 | combine_frame = get_image(self.fp_model, ori_frame,res_frame,bbox)
125 | cv2.imwrite(os.path.join(result_img_save_path,f"{str(i).zfill(8)}.png"),combine_frame)
126 |
127 | res_tmp_path = os.path.join(result_img_save_path,"%08d.png")
128 | cmd_img2video = f"ffmpeg -y -v fatal -r {fps} -f image2 -i {res_tmp_path} -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {output_vid_name}"
129 | print(cmd_img2video)
130 | os.system(cmd_img2video)
131 | '''
132 | cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}"
133 | print(cmd_combine_audio)
134 | os.system(cmd_combine_audio)
135 | '''
136 | del self.fp_model,self.dwpose_model,self.audio_processor,self.vae,self.unet,self.pe
137 | torch.cuda.empty_cache()
138 | # os.remove("temp.mp4")
139 | shutil.rmtree(result_img_save_path)
140 | print(f"result is save to {output_vid_name}")
141 | return output_vid_name
142 |
--------------------------------------------------------------------------------
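
For orientation, a minimal sketch of driving the `MuseTalk_INFER` pipeline above directly (in the real workflow `nodes.py` constructs and calls it; the paths here are placeholders and the relative import assumes the code runs inside this package):

```python
# Sketch only: nodes.py would import this as a relative module of the package.
from .inference import MuseTalk_INFER

infer = MuseTalk_INFER(bbox_shift=0, fps=25, batch_size=8, use_saved_coord=False)
# video_path may be an .mp4 file or a folder of numbered frames
output_video = infer(video_path="input/source.mp4", audio_path="input/speech.wav")
# the ffmpeg audio mux is commented out above, so the returned .mp4 is video-only;
# presumably the CombineAudioVideo node adds the audio track afterwards.
print(output_video)
```
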
/inference_realtime.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import json
5 | import torch
6 | import shutil
7 | import pickle
8 | import glob,time
9 | import queue,copy
10 | import threading
11 | from tqdm import tqdm
12 | import numpy as np
13 | import folder_paths
14 | from cuda_malloc import cuda_malloc_supported
15 | from typing import Any
16 | from .musetalk.utils.face_parsing import FaceParsing
17 | from mmpose.apis import init_model
18 | from .musetalk.utils.utils import load_all_model,datagen
19 | from .musetalk.utils.preprocessing import read_imgs,get_landmark_and_bbox
20 | from .musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
21 |
22 | parent_directory = os.path.dirname(os.path.abspath(__file__))
23 | device = torch.device("cuda" if cuda_malloc_supported() else "cpu")
24 | timesteps = torch.tensor([0], device=device)
25 |
26 | output_path = folder_paths.get_output_directory()
27 | musetalk_out_path = os.path.join(output_path,"musetalk_realtime")
28 | os.makedirs(musetalk_out_path, exist_ok=True)
29 |
30 | def osmakedirs(path_list):
31 | for path in path_list:
32 | os.makedirs(path) if not os.path.exists(path) else None
33 |
34 | def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
35 | cap = cv2.VideoCapture(vid_path)
36 | count = 0
37 | while True:
38 | if count > cut_frame:
39 | break
40 | ret, frame = cap.read()
41 | if ret:
42 | cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
43 | count += 1
44 | else:
45 | break
46 |
47 | @torch.no_grad()
48 | class Avatar:
49 | def __init__(self, avatar_id, video_path, bbox_shift, batch_size, preparation):
50 | self.avatar_id = avatar_id
51 | self.video_path = video_path
52 | self.bbox_shift = bbox_shift
53 | self.avatar_path = os.path.join(musetalk_out_path,avatar_id)
54 | self.full_imgs_path = os.path.join(self.avatar_path, "full_imgs")
55 | self.coords_path = os.path.join(self.avatar_path, "coords.pkl")
56 | self.latents_out_path= os.path.join(self.avatar_path, "latents.pt")
57 | self.video_out_path = output_path
58 | self.mask_out_path = os.path.join(self.avatar_path, "mask")
59 | self.mask_coords_path = os.path.join(self.avatar_path, "mask_coords.pkl")
60 | self.avatar_info_path = os.path.join(self.avatar_path, "avator_info.json")
61 | self.avatar_info = {
62 | "avatar_id":avatar_id,
63 | "video_path":video_path,
64 | "bbox_shift":bbox_shift
65 | }
66 | self.preparation = preparation
67 | self.batch_size = batch_size
68 | self.idx = 0
69 | # load model weights
70 | config_file = os.path.join(parent_directory,"musetalk","utils","dwpose","rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py")
71 | checkpoint_file = os.path.join(parent_directory,"models",'dwpose','dw-ll_ucoco_384.pth')
72 | resnet_path = os.path.join(parent_directory,'models','face-parse-bisent','resnet18-5c106cde.pth')
73 | face_model_pth = os.path.join(parent_directory,'models','face-parse-bisent',"79999_iter.pth")
74 |
75 | self.fp_model = FaceParsing(resnet_path,face_model_pth)
76 | self.dwpose_model = init_model(config_file, checkpoint_file, device=device)
77 | self.audio_processor,self.vae,self.unet,self.pe = load_all_model(os.path.join(parent_directory,"models"))
78 | self.vae.vae = self.vae.vae.half()
79 | self.unet.model = self.unet.model.half()
80 | self.pe = self.pe.half()
81 | self.init()
82 |
83 | def init(self):
84 | if self.preparation:
85 | if os.path.exists(self.avatar_path):
86 | response = input(f"{self.avatar_id} exists, Do you want to re-create it ? (y/n)")
87 | if response.lower() == "y":
88 | shutil.rmtree(self.avatar_path)
89 | print("*********************************")
90 | print(f" creating avator: {self.avatar_id}")
91 | print("*********************************")
92 | osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
93 | self.prepare_material()
94 | else:
95 | self.input_latent_list_cycle = torch.load(self.latents_out_path)
96 | with open(self.coords_path, 'rb') as f:
97 | self.coord_list_cycle = pickle.load(f)
98 | input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
99 | input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
100 | self.frame_list_cycle = read_imgs(input_img_list)
101 | with open(self.mask_coords_path, 'rb') as f:
102 | self.mask_coords_list_cycle = pickle.load(f)
103 | input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
104 | input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
105 | self.mask_list_cycle = read_imgs(input_mask_list)
106 | else:
107 | print("*********************************")
108 | print(f" creating avator: {self.avatar_id}")
109 | print("*********************************")
110 | osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
111 | self.prepare_material()
112 | else:
113 | with open(self.avatar_info_path, "r") as f:
114 | avatar_info = json.load(f)
115 |
116 | if avatar_info['bbox_shift'] != self.avatar_info['bbox_shift']:
117 | response = input(f" 【bbox_shift】 is changed, you need to re-create it ! (c/continue)")
118 | if response.lower() == "c":
119 | shutil.rmtree(self.avatar_path)
120 | print("*********************************")
121 | print(f" creating avator: {self.avatar_id}")
122 | print("*********************************")
123 | osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
124 | self.prepare_material()
125 | else:
126 | sys.exit()
127 | else:
128 | self.input_latent_list_cycle = torch.load(self.latents_out_path)
129 | with open(self.coords_path, 'rb') as f:
130 | self.coord_list_cycle = pickle.load(f)
131 | input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
132 | input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
133 | self.frame_list_cycle = read_imgs(input_img_list)
134 | with open(self.mask_coords_path, 'rb') as f:
135 | self.mask_coords_list_cycle = pickle.load(f)
136 | input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
137 | input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
138 | self.mask_list_cycle = read_imgs(input_mask_list)
139 | try:
140 | del self.dwpose_model,self.fp_model
141 | import gc; gc.collect(); torch.cuda.empty_cache();
142 | except:
143 | pass
144 |
145 | def prepare_material(self):
146 | print("preparing data materials ... ...")
147 | with open(self.avatar_info_path, "w") as f:
148 | json.dump(self.avatar_info, f)
149 |
150 | if os.path.isfile(self.video_path):
151 | video2imgs(self.video_path, self.full_imgs_path, ext = 'png')
152 | else:
153 | print(f"copy files in {self.video_path}")
154 | files = os.listdir(self.video_path)
155 | files.sort()
156 | files = [file for file in files if file.split(".")[-1]=="png"]
157 | for filename in files:
158 | shutil.copyfile(os.path.join(self.video_path,filename), os.path.join(self.full_imgs_path,filename))
159 | input_img_list = sorted(glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
160 |
161 | print("extracting landmarks...")
162 | coord_list, frame_list = get_landmark_and_bbox(self.dwpose_model,input_img_list,1,self.bbox_shift,)
163 | input_latent_list = []
164 | idx = -1
165 | # marker if the bbox is not sufficient
166 | coord_placeholder = (0.0,0.0,0.0,0.0)
167 | for bbox, frame in zip(coord_list, frame_list):
168 | idx = idx + 1
169 | if bbox == coord_placeholder:
170 | continue
171 | x1, y1, x2, y2 = bbox
172 | crop_frame = frame[y1:y2, x1:x2]
173 |
174 | resized_crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
175 | latents = self.vae.get_latents_for_unet(resized_crop_frame)
176 | input_latent_list.append(latents)
177 |
178 |
179 | self.frame_list_cycle = frame_list + frame_list[::-1]
180 | self.coord_list_cycle = coord_list + coord_list[::-1]
181 | self.input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
182 | self.mask_coords_list_cycle = []
183 | self.mask_list_cycle = []
184 |
185 | for i,frame in enumerate(tqdm(self.frame_list_cycle)):
186 | cv2.imwrite(os.path.join(self.full_imgs_path,str(i).zfill(8)+".png"),frame)
187 |
188 | face_box = self.coord_list_cycle[i]
189 | mask,crop_box = get_image_prepare_material(self.fp_model,frame,face_box)
190 | cv2.imwrite(os.path.join(self.mask_out_path,str(i).zfill(8)+".png"),mask)
191 | self.mask_coords_list_cycle += [crop_box]
192 | self.mask_list_cycle.append(mask)
193 |
194 | with open(self.mask_coords_path, 'wb') as f:
195 | pickle.dump(self.mask_coords_list_cycle, f)
196 |
197 | with open(self.coords_path, 'wb') as f:
198 | pickle.dump(self.coord_list_cycle, f)
199 |
200 | torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
201 |
202 |
203 |
204 | def process_frames(self, res_frame_queue,video_len):
205 | print(video_len)
206 | while True:
207 | if self.idx>=video_len-1:
208 | break
209 | try:
210 | start = time.time()
211 | res_frame = res_frame_queue.get(block=True, timeout=1)
212 | except queue.Empty:
213 | continue
214 |
215 | bbox = self.coord_list_cycle[self.idx%(len(self.coord_list_cycle))]
216 | ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx%(len(self.frame_list_cycle))])
217 | x1, y1, x2, y2 = bbox
218 | try:
219 | res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
220 | except:
221 | continue
222 | mask = self.mask_list_cycle[self.idx%(len(self.mask_list_cycle))]
223 | mask_crop_box = self.mask_coords_list_cycle[self.idx%(len(self.mask_coords_list_cycle))]
224 | #combine_frame = get_image(ori_frame,res_frame,bbox)
225 | combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
226 |
227 | fps = 1/(time.time()-start+1e-6)
228 | print(f"Displaying the {self.idx}-th frame with FPS: {fps:.2f}")
229 | cv2.imwrite(os.path.join(self.avatar_path,"tmp",str(self.idx).zfill(8)+".png"),combine_frame)
230 | self.idx = self.idx + 1
231 |
232 | def inference(self, audio_path, out_vid_name, fps):
233 | os.makedirs(os.path.join(self.avatar_path,'tmp'),exist_ok =True)
234 | ############################################## extract audio feature ##############################################
235 | whisper_feature = self.audio_processor.audio2feat(audio_path)
236 | whisper_chunks = self.audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
237 | ############################################## inference batch by batch ##############################################
238 | video_num = len(whisper_chunks)
239 | print("start inference")
240 | res_frame_queue = queue.Queue()
241 | self.idx = 0
242 | # # Create a sub-thread and start it
243 | process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue,video_num))
244 | process_thread.start()
245 | start_time = time.time()
246 | gen = datagen(whisper_chunks,self.input_latent_list_cycle, self.batch_size)
247 | print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
248 | start_time = time.time()
249 | res_frame_list = []
250 |
251 | for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/self.batch_size)))):
252 | start_time = time.time()
253 | tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
254 | audio_feature_batch = torch.stack(tensor_list).to(self.unet.device) # torch, B, 5*N,384
255 | audio_feature_batch = self.pe(audio_feature_batch)
256 |
257 | pred_latents = self.unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
258 | recon = self.vae.decode_latents(pred_latents)
259 | for res_frame in recon:
260 | res_frame_queue.put(res_frame)
261 | # Close the queue and sub-thread after all tasks are completed
262 | process_thread.join()
263 |
264 | if out_vid_name is not None:
265 | # optional
266 | img_path = os.path.join(self.avatar_path,'tmp','%08d.png')
267 | tmp_mp4 = os.path.join(self.avatar_path,"temp.mp4")
268 | cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {img_path} -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {tmp_mp4}"
269 | print(cmd_img2video)
270 | os.system(cmd_img2video)
271 |
272 | output_vid = os.path.join(self.video_out_path, out_vid_name+".mp4") # on
273 | cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {tmp_mp4} {output_vid}"
274 | print(cmd_combine_audio)
275 | os.system(cmd_combine_audio)
276 |
277 | os.remove(tmp_mp4)
278 | shutil.rmtree(os.path.join(self.avatar_path,"tmp"))
279 | print(f"result is save to {output_vid}")
280 | del self.audio_processor,self.vae,self.unet,self.pe
281 | import gc; gc.collect(); torch.cuda.empty_cache();
282 | return output_vid
283 |
284 | class Infer_Real_Time:
285 | def __init__(self) -> None:
286 | pass
287 |
288 | def __call__(self, audio_path,video_path,
289 | avatar_id,fps=25,batch_size=4,
290 | preparation=True,bbox_shift=0,
291 | *args: Any, **kwds: Any) -> Any:
292 |
293 | avatar = Avatar(
294 | avatar_id = avatar_id,
295 | video_path = video_path,
296 | bbox_shift = bbox_shift,
297 | batch_size = batch_size,
298 | preparation= preparation)
299 | output_name = os.path.basename(audio_path)[:-4] + "musetalk"
300 | return avatar.inference(audio_path,output_name,fps)
301 |
--------------------------------------------------------------------------------
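
Likewise, a minimal sketch of the real-time path: `Infer_Real_Time` builds an `Avatar`, which on the first run caches frames, masks, coordinates and latents under `output/musetalk_realtime/<avatar_id>/`, then renders against an audio file. Paths and the avatar id are placeholders, and the relative import again assumes execution inside this package:

```python
from .inference_realtime import Infer_Real_Time

infer_rt = Infer_Real_Time()
output_video = infer_rt(
    audio_path="input/speech.wav",
    video_path="input/source.mp4",
    avatar_id="demo_avatar",
    fps=25,
    batch_size=4,
    preparation=True,   # set to False to reuse previously prepared avatar material
    bbox_shift=0,
)
print(output_video)      # already muxed with the audio via ffmpeg
```
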
/musetalk/models/__pycache__/unet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/models/__pycache__/unet.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/models/__pycache__/vae.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/models/__pycache__/vae.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/models/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import json
5 |
6 | from diffusers import UNet2DConditionModel
7 | import sys
8 | import time
9 | import numpy as np
10 | import os
11 |
12 | class PositionalEncoding(nn.Module):
13 | def __init__(self, d_model=384, max_len=5000):
14 | super(PositionalEncoding, self).__init__()
15 | pe = torch.zeros(max_len, d_model)
16 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
17 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
18 | pe[:, 0::2] = torch.sin(position * div_term)
19 | pe[:, 1::2] = torch.cos(position * div_term)
20 | pe = pe.unsqueeze(0)
21 | self.register_buffer('pe', pe)
22 |
23 | def forward(self, x):
24 | b, seq_len, d_model = x.size()
25 | pe = self.pe[:, :seq_len, :]
26 | x = x + pe.to(x.device)
27 | return x
28 |
29 | class UNet():
30 | def __init__(self,
31 | unet_config,
32 | model_path,
33 | use_float16=False,
34 | ):
35 | with open(unet_config, 'r') as f:
36 | unet_config = json.load(f)
37 | self.model = UNet2DConditionModel(**unet_config)
38 | self.pe = PositionalEncoding(d_model=384)
39 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40 | self.weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
41 | self.model.load_state_dict(self.weights)
42 | if use_float16:
43 | self.model = self.model.half()
44 | self.model.to(self.device)
45 |
46 | if __name__ == "__main__":
47 | unet = UNet()
--------------------------------------------------------------------------------
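
A small sanity-check sketch for the `PositionalEncoding` module above: it adds fixed sinusoidal embeddings along the sequence dimension and leaves the tensor shape unchanged, which is how the inference scripts apply it to the stacked whisper feature batches (shape roughly `[B, seq_len, 384]`; the sizes below are illustrative and the import assumes the repo root is on `sys.path`):

```python
import torch
from musetalk.models.unet import PositionalEncoding

pe = PositionalEncoding(d_model=384)
audio_feature_batch = torch.randn(8, 50, 384)   # illustrative batch of whisper feature chunks
out = pe(audio_feature_batch)
assert out.shape == audio_feature_batch.shape   # only an additive positional term is applied
```
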
/musetalk/models/vae.py:
--------------------------------------------------------------------------------
1 | from diffusers import AutoencoderKL
2 | import torch
3 | import torchvision.transforms as transforms
4 | import torch.nn.functional as F
5 | import cv2
6 | import numpy as np
7 | from PIL import Image
8 | import os
9 |
10 | class VAE():
11 | """
12 | VAE (Variational Autoencoder) class for image processing.
13 | """
14 |
15 | def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
16 | """
17 | Initialize the VAE instance.
18 |
19 | :param model_path: Path to the trained model.
20 | :param resized_img: The size to which images are resized.
21 | :param use_float16: Whether to use float16 precision.
22 | """
23 | self.model_path = model_path
24 | self.vae = AutoencoderKL.from_pretrained(self.model_path)
25 |
26 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27 | self.vae.to(self.device)
28 |
29 | if use_float16:
30 | self.vae = self.vae.half()
31 | self._use_float16 = True
32 | else:
33 | self._use_float16 = False
34 |
35 | self.scaling_factor = self.vae.config.scaling_factor
36 | self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
37 | self._resized_img = resized_img
38 | self._mask_tensor = self.get_mask_tensor()
39 |
40 | def get_mask_tensor(self):
41 | """
42 | Creates a mask tensor for image processing.
43 | :return: A mask tensor.
44 | """
45 | mask_tensor = torch.zeros((self._resized_img,self._resized_img))
46 | mask_tensor[:self._resized_img//2,:] = 1
47 | mask_tensor[mask_tensor< 0.5] = 0
48 | mask_tensor[mask_tensor>= 0.5] = 1
49 | return mask_tensor
50 |
51 | def preprocess_img(self,img_name,half_mask=False):
52 | """
53 | Preprocess an image for the VAE.
54 |
55 | :param img_name: The image file path or a list of image file paths.
56 | :param half_mask: Whether to apply a half mask to the image.
57 | :return: A preprocessed image tensor.
58 | """
59 | window = []
60 | if isinstance(img_name, str):
61 | window_fnames = [img_name]
62 | for fname in window_fnames:
63 | img = cv2.imread(fname)
64 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
65 | img = cv2.resize(img, (self._resized_img, self._resized_img),
66 | interpolation=cv2.INTER_LANCZOS4)
67 | window.append(img)
68 | else:
69 | img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
70 | window.append(img)
71 |
72 | x = np.asarray(window) / 255.
73 | x = np.transpose(x, (3, 0, 1, 2))
74 | x = torch.squeeze(torch.FloatTensor(x))
75 | if half_mask:
76 | x = x * (self._mask_tensor>0.5)
77 | x = self.transform(x)
78 |
79 | x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor
80 | x = x.to(self.vae.device)
81 |
82 | return x
83 |
84 | def encode_latents(self,image):
85 | """
86 | Encode an image into latent variables.
87 |
88 | :param image: The image tensor to encode.
89 | :return: The encoded latent variables.
90 | """
91 | with torch.no_grad():
92 | init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
93 | init_latents = self.scaling_factor * init_latent_dist.sample()
94 | return init_latents
95 |
96 | def decode_latents(self, latents):
97 | """
98 | Decode latent variables back into an image.
99 | :param latents: The latent variables to decode.
100 | :return: A NumPy array representing the decoded image.
101 | """
102 | latents = (1/ self.scaling_factor) * latents
103 | image = self.vae.decode(latents.to(self.vae.dtype)).sample
104 | image = (image / 2 + 0.5).clamp(0, 1)
105 | image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
106 | image = (image * 255).round().astype("uint8")
107 | image = image[...,::-1] # RGB to BGR
108 | return image
109 |
110 | def get_latents_for_unet(self,img):
111 | """
112 | Prepare latent variables for a U-Net model.
113 | :param img: The image to process.
114 | :return: A concatenated tensor of latents for U-Net input.
115 | """
116 |
117 | ref_image = self.preprocess_img(img,half_mask=True) # [1, 3, 256, 256] RGB, torch tensor
118 | masked_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
119 | ref_image = self.preprocess_img(img,half_mask=False) # [1, 3, 256, 256] RGB, torch tensor
120 | ref_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
121 | latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
122 | return latent_model_input
123 |
124 | if __name__ == "__main__":
125 | vae_mode_path = "./models/sd-vae-ft-mse/"
126 | vae = VAE(model_path = vae_mode_path,use_float16=False)
127 | img_path = "./results/sun001_crop/00000.png"
128 |
129 | crop_imgs_path = "./results/sun001_crop/"
130 | latents_out_path = "./results/latents/"
131 | if not os.path.exists(latents_out_path):
132 | os.mkdir(latents_out_path)
133 |
134 | files = os.listdir(crop_imgs_path)
135 | files.sort()
136 | files = [file for file in files if file.split(".")[-1] == "png"]
137 |
138 | for file in files:
139 | index = file.split(".")[0]
140 | img_path = crop_imgs_path + file
141 | latents = vae.get_latents_for_unet(img_path)
142 | print(img_path,"latents",latents.size())
143 | #torch.save(latents,os.path.join(latents_out_path,index+".pt"))
144 | #reload_tensor = torch.load('tensor.pt')
145 | #print(reload_tensor.size())
146 |
147 |
148 |
--------------------------------------------------------------------------------
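
A minimal sketch of how the `VAE` wrapper above is used elsewhere in this repo: `get_latents_for_unet` builds the 8-channel (masked + reference) latent input for the U-Net, and `decode_latents` turns predicted 4-channel latents back into BGR uint8 frames. Paths are placeholders and the import assumes the repo root is on `sys.path`:

```python
import torch
from musetalk.models.vae import VAE

vae = VAE(model_path="./models/sd-vae-ft-mse/")            # as in the __main__ block above
latents = vae.get_latents_for_unet("./results/sun001_crop/00000.png")
print(latents.shape)                                        # [1, 8, 32, 32]

# decode_latents expects the 4-channel latents predicted by the U-Net
dummy_pred = torch.randn(1, 4, 32, 32, device=vae.device)
frames_bgr = vae.decode_latents(dummy_pred)                 # numpy uint8, [1, 256, 256, 3]
print(frames_bgr.shape, frames_bgr.dtype)
```
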
/musetalk/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from os.path import abspath, dirname
3 | current_dir = dirname(abspath(__file__))
4 | parent_dir = dirname(current_dir)
5 | sys.path.append(parent_dir+'/utils')
6 |
--------------------------------------------------------------------------------
/musetalk/utils/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/__pycache__/blending.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/__pycache__/blending.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/__pycache__/preprocessing.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/__pycache__/preprocessing.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/blending.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import cv2
4 |
5 |
6 | def get_crop_box(box, expand):
7 | x, y, x1, y1 = box
8 | x_c, y_c = (x+x1)//2, (y+y1)//2
9 | w, h = x1-x, y1-y
10 | s = int(max(w, h)//2*expand)
11 | crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
12 | return crop_box, s
13 |
14 | def face_seg(fp, image):
15 | seg_image = fp(image)
16 | if seg_image is None:
17 | print("error, no person_segment")
18 | return None
19 |
20 | seg_image = seg_image.resize(image.size)
21 | return seg_image
22 |
23 | def get_image(fp_model,image,face,face_box,upper_boundary_ratio = 0.5,expand=1.2):
24 | #print(image.shape)
25 | #print(face.shape)
26 |
27 | body = Image.fromarray(image[:,:,::-1])
28 | face = Image.fromarray(face[:,:,::-1])
29 |
30 | x, y, x1, y1 = face_box
31 | #print(x1-x,y1-y)
32 | crop_box, s = get_crop_box(face_box, expand)
33 | x_s, y_s, x_e, y_e = crop_box
34 | face_position = (x, y)
35 |
36 | face_large = body.crop(crop_box)
37 | ori_shape = face_large.size
38 |
39 | mask_image = face_seg(fp_model,face_large)
40 | mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
41 | mask_image = Image.new('L', ori_shape, 0)
42 | mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
43 |
44 | # keep upper_boundary_ratio of talking area
45 | width, height = mask_image.size
46 | top_boundary = int(height * upper_boundary_ratio)
47 | modified_mask_image = Image.new('L', ori_shape, 0)
48 | modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
49 |
50 | blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
51 | mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
52 | mask_image = Image.fromarray(mask_array)
53 |
54 | face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
55 | body.paste(face_large, crop_box[:2], mask_image)
56 | body = np.array(body)
57 | return body[:,:,::-1]
58 |
59 | def get_image_prepare_material(fp_model,image,face_box,upper_boundary_ratio = 0.5,expand=1.2):
60 | body = Image.fromarray(image[:,:,::-1])
61 |
62 | x, y, x1, y1 = face_box
63 | #print(x1-x,y1-y)
64 | crop_box, s = get_crop_box(face_box, expand)
65 | x_s, y_s, x_e, y_e = crop_box
66 |
67 | face_large = body.crop(crop_box)
68 | ori_shape = face_large.size
69 |
70 | mask_image = face_seg(fp_model,face_large)
71 | mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
72 | mask_image = Image.new('L', ori_shape, 0)
73 | mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
74 |
75 | # keep upper_boundary_ratio of talking area
76 | width, height = mask_image.size
77 | top_boundary = int(height * upper_boundary_ratio)
78 | modified_mask_image = Image.new('L', ori_shape, 0)
79 | modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
80 |
81 | blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
82 | mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
83 | return mask_array,crop_box
84 |
85 | def get_image_blending(image,face,face_box,mask_array,crop_box):
86 | body = Image.fromarray(image[:,:,::-1])
87 | face = Image.fromarray(face[:,:,::-1])
88 |
89 | x, y, x1, y1 = face_box
90 | x_s, y_s, x_e, y_e = crop_box
91 | face_large = body.crop(crop_box)
92 |
93 | mask_image = Image.fromarray(mask_array)
94 | mask_image = mask_image.convert("L")
95 | face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
96 | body.paste(face_large, crop_box[:2], mask_image)
97 | body = np.array(body)
98 | return body[:,:,::-1]
--------------------------------------------------------------------------------
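
To make the split between the two blending paths explicit: `get_image` runs face segmentation and blending per frame (the offline path in `inference.py`), while the real-time path precomputes the feathered mask and crop box once per cycled frame with `get_image_prepare_material` and reuses them for every generated frame via `get_image_blending`. A minimal sketch of that pairing, with placeholder paths, box coordinates and a stand-in for the decoded face:

```python
import cv2
import numpy as np
from musetalk.utils.blending import get_image_prepare_material, get_image_blending
from musetalk.utils.face_parsing import FaceParsing

# placeholder weight paths, matching the layout used in inference.py
fp_model = FaceParsing("models/face-parse-bisent/resnet18-5c106cde.pth",
                       "models/face-parse-bisent/79999_iter.pth")

frame = cv2.imread("frames/00000000.png")        # full original frame (BGR)
face_box = (120, 80, 380, 360)                   # x, y, x1, y1 from the landmark step
x, y, x1, y1 = face_box

# once per cycled frame: segmentation mask + crop box around the face
mask, crop_box = get_image_prepare_material(fp_model, frame, face_box)

# every generated frame: resize the 256x256 U-Net output and paste it back with the cached mask
generated_face = np.zeros((256, 256, 3), dtype=np.uint8)   # stand-in for a decoded frame
res_face = cv2.resize(generated_face, (x1 - x, y1 - y))
blended = get_image_blending(frame, res_face, face_box, mask, crop_box)
```
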
/musetalk/utils/dwpose/default_runtime.py:
--------------------------------------------------------------------------------
1 | default_scope = 'mmpose'
2 |
3 | # hooks
4 | default_hooks = dict(
5 | timer=dict(type='IterTimerHook'),
6 | logger=dict(type='LoggerHook', interval=50),
7 | param_scheduler=dict(type='ParamSchedulerHook'),
8 | checkpoint=dict(type='CheckpointHook', interval=10),
9 | sampler_seed=dict(type='DistSamplerSeedHook'),
10 | visualization=dict(type='PoseVisualizationHook', enable=False),
11 | badcase=dict(
12 | type='BadCaseAnalysisHook',
13 | enable=False,
14 | out_dir='badcase',
15 | metric_type='loss',
16 | badcase_thr=5))
17 |
18 | # custom hooks
19 | custom_hooks = [
20 | # Synchronize model buffers such as running_mean and running_var in BN
21 | # at the end of each epoch
22 | dict(type='SyncBuffersHook')
23 | ]
24 |
25 | # multi-processing backend
26 | env_cfg = dict(
27 | cudnn_benchmark=False,
28 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
29 | dist_cfg=dict(backend='nccl'),
30 | )
31 |
32 | # visualizer
33 | vis_backends = [
34 | dict(type='LocalVisBackend'),
35 | # dict(type='TensorboardVisBackend'),
36 | # dict(type='WandbVisBackend'),
37 | ]
38 | visualizer = dict(
39 | type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
40 |
41 | # logger
42 | log_processor = dict(
43 | type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
44 | log_level = 'INFO'
45 | load_from = None
46 | resume = False
47 |
48 | # file I/O backend
49 | backend_args = dict(backend='local')
50 |
51 | # training/validation/testing progress
52 | train_cfg = dict(by_epoch=True)
53 | val_cfg = dict()
54 | test_cfg = dict()
55 |
--------------------------------------------------------------------------------
/musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py:
--------------------------------------------------------------------------------
1 | #_base_ = ['../../../_base_/default_runtime.py']
2 | _base_ = ['default_runtime.py']
3 |
4 | # runtime
5 | max_epochs = 270
6 | stage2_num_epochs = 30
7 | base_lr = 4e-3
8 | train_batch_size = 32
9 | val_batch_size = 32
10 |
11 | train_cfg = dict(max_epochs=max_epochs, val_interval=10)
12 | randomness = dict(seed=21)
13 |
14 | # optimizer
15 | optim_wrapper = dict(
16 | type='OptimWrapper',
17 | optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
18 | paramwise_cfg=dict(
19 | norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
20 |
21 | # learning rate
22 | param_scheduler = [
23 | dict(
24 | type='LinearLR',
25 | start_factor=1.0e-5,
26 | by_epoch=False,
27 | begin=0,
28 | end=1000),
29 | dict(
30 | # use cosine lr from 150 to 300 epoch
31 | type='CosineAnnealingLR',
32 | eta_min=base_lr * 0.05,
33 | begin=max_epochs // 2,
34 | end=max_epochs,
35 | T_max=max_epochs // 2,
36 | by_epoch=True,
37 | convert_to_iter_based=True),
38 | ]
39 |
40 | # automatically scaling LR based on the actual training batch size
41 | auto_scale_lr = dict(base_batch_size=512)
42 |
43 | # codec settings
44 | codec = dict(
45 | type='SimCCLabel',
46 | input_size=(288, 384),
47 | sigma=(6., 6.93),
48 | simcc_split_ratio=2.0,
49 | normalize=False,
50 | use_dark=False)
51 |
52 | # model settings
53 | model = dict(
54 | type='TopdownPoseEstimator',
55 | data_preprocessor=dict(
56 | type='PoseDataPreprocessor',
57 | mean=[123.675, 116.28, 103.53],
58 | std=[58.395, 57.12, 57.375],
59 | bgr_to_rgb=True),
60 | backbone=dict(
61 | _scope_='mmdet',
62 | type='CSPNeXt',
63 | arch='P5',
64 | expand_ratio=0.5,
65 | deepen_factor=1.,
66 | widen_factor=1.,
67 | out_indices=(4, ),
68 | channel_attention=True,
69 | norm_cfg=dict(type='SyncBN'),
70 | act_cfg=dict(type='SiLU'),
71 | init_cfg=dict(
72 | type='Pretrained',
73 | prefix='backbone.',
74 | checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
75 | 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa: E501
76 | )),
77 | head=dict(
78 | type='RTMCCHead',
79 | in_channels=1024,
80 | out_channels=133,
81 | input_size=codec['input_size'],
82 | in_featuremap_size=(9, 12),
83 | simcc_split_ratio=codec['simcc_split_ratio'],
84 | final_layer_kernel_size=7,
85 | gau_cfg=dict(
86 | hidden_dims=256,
87 | s=128,
88 | expansion_factor=2,
89 | dropout_rate=0.,
90 | drop_path=0.,
91 | act_fn='SiLU',
92 | use_rel_bias=False,
93 | pos_enc=False),
94 | loss=dict(
95 | type='KLDiscretLoss',
96 | use_target_weight=True,
97 | beta=10.,
98 | label_softmax=True),
99 | decoder=codec),
100 | test_cfg=dict(flip_test=True, ))
101 |
102 | # base dataset settings
103 | dataset_type = 'UBody2dDataset'
104 | data_mode = 'topdown'
105 | data_root = 'data/UBody/'
106 |
107 | backend_args = dict(backend='local')
108 |
109 | scenes = [
110 | 'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
111 | 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
112 | 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
113 | ]
114 |
115 | train_datasets = [
116 | dict(
117 | type='CocoWholeBodyDataset',
118 | data_root='data/coco/',
119 | data_mode=data_mode,
120 | ann_file='annotations/coco_wholebody_train_v1.0.json',
121 | data_prefix=dict(img='train2017/'),
122 | pipeline=[])
123 | ]
124 |
125 | for scene in scenes:
126 | train_dataset = dict(
127 | type=dataset_type,
128 | data_root=data_root,
129 | data_mode=data_mode,
130 | ann_file=f'annotations/{scene}/train_annotations.json',
131 | data_prefix=dict(img='images/'),
132 | pipeline=[],
133 | sample_interval=10)
134 | train_datasets.append(train_dataset)
135 |
136 | # pipelines
137 | train_pipeline = [
138 | dict(type='LoadImage', backend_args=backend_args),
139 | dict(type='GetBBoxCenterScale'),
140 | dict(type='RandomFlip', direction='horizontal'),
141 | dict(type='RandomHalfBody'),
142 | dict(
143 | type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=90),
144 | dict(type='TopdownAffine', input_size=codec['input_size']),
145 | dict(type='mmdet.YOLOXHSVRandomAug'),
146 | dict(
147 | type='Albumentation',
148 | transforms=[
149 | dict(type='Blur', p=0.1),
150 | dict(type='MedianBlur', p=0.1),
151 | dict(
152 | type='CoarseDropout',
153 | max_holes=1,
154 | max_height=0.4,
155 | max_width=0.4,
156 | min_holes=1,
157 | min_height=0.2,
158 | min_width=0.2,
159 | p=1.0),
160 | ]),
161 | dict(type='GenerateTarget', encoder=codec),
162 | dict(type='PackPoseInputs')
163 | ]
164 | val_pipeline = [
165 | dict(type='LoadImage', backend_args=backend_args),
166 | dict(type='GetBBoxCenterScale'),
167 | dict(type='TopdownAffine', input_size=codec['input_size']),
168 | dict(type='PackPoseInputs')
169 | ]
170 |
171 | train_pipeline_stage2 = [
172 | dict(type='LoadImage', backend_args=backend_args),
173 | dict(type='GetBBoxCenterScale'),
174 | dict(type='RandomFlip', direction='horizontal'),
175 | dict(type='RandomHalfBody'),
176 | dict(
177 | type='RandomBBoxTransform',
178 | shift_factor=0.,
179 | scale_factor=[0.5, 1.5],
180 | rotate_factor=90),
181 | dict(type='TopdownAffine', input_size=codec['input_size']),
182 | dict(type='mmdet.YOLOXHSVRandomAug'),
183 | dict(
184 | type='Albumentation',
185 | transforms=[
186 | dict(type='Blur', p=0.1),
187 | dict(type='MedianBlur', p=0.1),
188 | dict(
189 | type='CoarseDropout',
190 | max_holes=1,
191 | max_height=0.4,
192 | max_width=0.4,
193 | min_holes=1,
194 | min_height=0.2,
195 | min_width=0.2,
196 | p=0.5),
197 | ]),
198 | dict(type='GenerateTarget', encoder=codec),
199 | dict(type='PackPoseInputs')
200 | ]
201 |
202 | # data loaders
203 | train_dataloader = dict(
204 | batch_size=train_batch_size,
205 | num_workers=10,
206 | persistent_workers=True,
207 | sampler=dict(type='DefaultSampler', shuffle=True),
208 | dataset=dict(
209 | type='CombinedDataset',
210 | metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
211 | datasets=train_datasets,
212 | pipeline=train_pipeline,
213 | test_mode=False,
214 | ))
215 |
216 | val_dataloader = dict(
217 | batch_size=val_batch_size,
218 | num_workers=10,
219 | persistent_workers=True,
220 | drop_last=False,
221 | sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
222 | dataset=dict(
223 | type='CocoWholeBodyDataset',
224 | data_root=data_root,
225 | data_mode=data_mode,
226 | ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
227 | bbox_file='data/coco/person_detection_results/'
228 | 'COCO_val2017_detections_AP_H_56_person.json',
229 | data_prefix=dict(img='coco/val2017/'),
230 | test_mode=True,
231 | pipeline=val_pipeline,
232 | ))
233 | test_dataloader = val_dataloader
234 |
235 | # hooks
236 | default_hooks = dict(
237 | checkpoint=dict(
238 | save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
239 |
240 | custom_hooks = [
241 | dict(
242 | type='EMAHook',
243 | ema_type='ExpMomentumEMA',
244 | momentum=0.0002,
245 | update_buffers=True,
246 | priority=49),
247 | dict(
248 | type='mmdet.PipelineSwitchHook',
249 | switch_epoch=max_epochs - stage2_num_epochs,
250 | switch_pipeline=train_pipeline_stage2)
251 | ]
252 |
253 | # evaluators
254 | val_evaluator = dict(
255 | type='CocoWholeBodyMetric',
256 | ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
257 | test_evaluator = val_evaluator
258 |
--------------------------------------------------------------------------------
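
The two dwpose config files above are only used so that `mmpose.apis.init_model` can build the RTMPose whole-body model locally; no training happens in this repo. A minimal sketch of how `inference.py` wires the config to the downloaded checkpoint (paths relative to this package):

```python
import os
from mmpose.apis import init_model

parent_directory = os.path.dirname(os.path.abspath(__file__))
config_file = os.path.join(parent_directory, "musetalk", "utils", "dwpose",
                           "rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py")
checkpoint_file = os.path.join(parent_directory, "models", "dwpose", "dw-ll_ucoco_384.pth")
dwpose_model = init_model(config_file, checkpoint_file, device="cuda")
```
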
/musetalk/utils/face_detection/README.md:
--------------------------------------------------------------------------------
1 | The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | __author__ = """Adrian Bulat"""
4 | __email__ = 'adrian.bulat@nottingham.ac.uk'
5 | __version__ = '1.0.1'
6 |
7 | from .api import FaceAlignment, LandmarksType, NetworkSize, YOLOv8_face
8 |
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/__pycache__/api.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/__pycache__/api.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/__pycache__/models.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/__pycache__/models.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/api.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import torch
4 | from torch.utils.model_zoo import load_url
5 | from enum import Enum
6 | import numpy as np
7 | import cv2, math  # math is needed by YOLOv8_face (math.ceil below)
8 | try:
9 | import urllib.request as request_file
10 | except BaseException:
11 | import urllib as request_file
12 |
13 | from .models import FAN, ResNetDepth
14 | from .utils import *
15 |
16 |
17 | class LandmarksType(Enum):
18 | """Enum class defining the type of landmarks to detect.
19 |
20 | ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
21 | ``_2halfD`` - these points represent the projection of the 3D points into a 2D space
22 | ``_3D`` - detect the points ``(x,y,z)`` in a 3D space
23 |
24 | """
25 | _2D = 1
26 | _2halfD = 2
27 | _3D = 3
28 |
29 |
30 | class NetworkSize(Enum):
31 | # TINY = 1
32 | # SMALL = 2
33 | # MEDIUM = 3
34 | LARGE = 4
35 |
36 | def __new__(cls, value):
37 | member = object.__new__(cls)
38 | member._value_ = value
39 | return member
40 |
41 | def __int__(self):
42 | return self.value
43 |
44 |
45 |
46 | class FaceAlignment:
47 | def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
48 | device='cuda', flip_input=False, face_detector='sfd', verbose=False):
49 | self.device = device
50 | self.flip_input = flip_input
51 | self.landmarks_type = landmarks_type
52 | self.verbose = verbose
53 |
54 | network_size = int(network_size)
55 |
56 | if 'cuda' in device:
57 | torch.backends.cudnn.benchmark = True
58 | # torch.backends.cuda.matmul.allow_tf32 = False
59 | # torch.backends.cudnn.benchmark = True
60 | # torch.backends.cudnn.deterministic = False
61 | # torch.backends.cudnn.allow_tf32 = True
62 | print('cuda start')
63 |
64 |
65 | # Get the face detector
66 | face_detector_module = __import__('face_detection.detection.' + face_detector,
67 | globals(), locals(), [face_detector], 0)
68 |
69 | self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
70 |
71 | def get_detections_for_batch(self, images):
72 | images = images[..., ::-1]
73 | detected_faces = self.face_detector.detect_from_batch(images.copy())
74 | results = []
75 |
76 | for i, d in enumerate(detected_faces):
77 | if len(d) == 0:
78 | results.append(None)
79 | continue
80 | d = d[0]
81 | d = np.clip(d, 0, None)
82 |
83 | x1, y1, x2, y2 = map(int, d[:-1])
84 | results.append((x1, y1, x2, y2))
85 |
86 | return results
87 |
88 |
89 | class YOLOv8_face:
90 | def __init__(self, path = 'face_detection/weights/yolov8n-face.onnx', conf_thres=0.2, iou_thres=0.5):
91 | self.conf_threshold = conf_thres
92 | self.iou_threshold = iou_thres
93 | self.class_names = ['face']
94 | self.num_classes = len(self.class_names)
95 | # Initialize model
96 | self.net = cv2.dnn.readNet(path)
97 | self.input_height = 640
98 | self.input_width = 640
99 | self.reg_max = 16
100 |
101 | self.project = np.arange(self.reg_max)
102 | self.strides = (8, 16, 32)
103 | self.feats_hw = [(math.ceil(self.input_height / self.strides[i]), math.ceil(self.input_width / self.strides[i])) for i in range(len(self.strides))]
104 | self.anchors = self.make_anchors(self.feats_hw)
105 |
106 | def make_anchors(self, feats_hw, grid_cell_offset=0.5):
107 | """Generate anchors from features."""
108 | anchor_points = {}
109 | for i, stride in enumerate(self.strides):
110 | h,w = feats_hw[i]
111 | x = np.arange(0, w) + grid_cell_offset # shift x
112 | y = np.arange(0, h) + grid_cell_offset # shift y
113 | sx, sy = np.meshgrid(x, y)
114 | # sy, sx = np.meshgrid(y, x)
115 | anchor_points[stride] = np.stack((sx, sy), axis=-1).reshape(-1, 2)
116 | return anchor_points
117 |
118 | def softmax(self, x, axis=1):
119 | x_exp = np.exp(x)
120 |         # for a column vector, use axis=0
121 | x_sum = np.sum(x_exp, axis=axis, keepdims=True)
122 | s = x_exp / x_sum
123 | return s
124 |
125 | def resize_image(self, srcimg, keep_ratio=True):
126 | top, left, newh, neww = 0, 0, self.input_width, self.input_height
127 | if keep_ratio and srcimg.shape[0] != srcimg.shape[1]:
128 | hw_scale = srcimg.shape[0] / srcimg.shape[1]
129 | if hw_scale > 1:
130 | newh, neww = self.input_height, int(self.input_width / hw_scale)
131 | img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
132 | left = int((self.input_width - neww) * 0.5)
133 | img = cv2.copyMakeBorder(img, 0, 0, left, self.input_width - neww - left, cv2.BORDER_CONSTANT,
134 | value=(0, 0, 0)) # add border
135 | else:
136 | newh, neww = int(self.input_height * hw_scale), self.input_width
137 | img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
138 | top = int((self.input_height - newh) * 0.5)
139 | img = cv2.copyMakeBorder(img, top, self.input_height - newh - top, 0, 0, cv2.BORDER_CONSTANT,
140 | value=(0, 0, 0))
141 | else:
142 | img = cv2.resize(srcimg, (self.input_width, self.input_height), interpolation=cv2.INTER_AREA)
143 | return img, newh, neww, top, left
144 |
145 | def detect(self, srcimg):
146 | input_img, newh, neww, padh, padw = self.resize_image(cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB))
147 | scale_h, scale_w = srcimg.shape[0]/newh, srcimg.shape[1]/neww
148 | input_img = input_img.astype(np.float32) / 255.0
149 |
150 | blob = cv2.dnn.blobFromImage(input_img)
151 | self.net.setInput(blob)
152 | outputs = self.net.forward(self.net.getUnconnectedOutLayersNames())
153 | # if isinstance(outputs, tuple):
154 | # outputs = list(outputs)
155 | # if float(cv2.__version__[:3])>=4.7:
156 |         #     outputs = [outputs[2], outputs[0], outputs[1]]  ### this reordering is needed for OpenCV 4.7 but not for 4.5
157 | # Perform inference on the image
158 | det_bboxes, det_conf, det_classid, landmarks = self.post_process(outputs, scale_h, scale_w, padh, padw)
159 | return det_bboxes, det_conf, det_classid, landmarks
160 |
161 | def post_process(self, preds, scale_h, scale_w, padh, padw):
162 | bboxes, scores, landmarks = [], [], []
163 | for i, pred in enumerate(preds):
164 | stride = int(self.input_height/pred.shape[2])
165 | pred = pred.transpose((0, 2, 3, 1))
166 |
167 | box = pred[..., :self.reg_max * 4]
168 | cls = 1 / (1 + np.exp(-pred[..., self.reg_max * 4:-15])).reshape((-1,1))
169 | kpts = pred[..., -15:].reshape((-1,15)) ### x1,y1,score1, ..., x5,y5,score5
170 |
171 | # tmp = box.reshape(self.feats_hw[i][0], self.feats_hw[i][1], 4, self.reg_max)
172 | tmp = box.reshape(-1, 4, self.reg_max)
173 | bbox_pred = self.softmax(tmp, axis=-1)
174 | bbox_pred = np.dot(bbox_pred, self.project).reshape((-1,4))
175 |
176 | bbox = self.distance2bbox(self.anchors[stride], bbox_pred, max_shape=(self.input_height, self.input_width)) * stride
177 | kpts[:, 0::3] = (kpts[:, 0::3] * 2.0 + (self.anchors[stride][:, 0].reshape((-1,1)) - 0.5)) * stride
178 | kpts[:, 1::3] = (kpts[:, 1::3] * 2.0 + (self.anchors[stride][:, 1].reshape((-1,1)) - 0.5)) * stride
179 | kpts[:, 2::3] = 1 / (1+np.exp(-kpts[:, 2::3]))
180 |
181 |             bbox -= np.array([[padw, padh, padw, padh]])  ### remove the letterbox padding (broadcast over rows)
182 | bbox *= np.array([[scale_w, scale_h, scale_w, scale_h]])
183 | kpts -= np.tile(np.array([padw, padh, 0]), 5).reshape((1,15))
184 | kpts *= np.tile(np.array([scale_w, scale_h, 1]), 5).reshape((1,15))
185 |
186 | bboxes.append(bbox)
187 | scores.append(cls)
188 | landmarks.append(kpts)
189 |
190 | bboxes = np.concatenate(bboxes, axis=0)
191 | scores = np.concatenate(scores, axis=0)
192 | landmarks = np.concatenate(landmarks, axis=0)
193 |
194 | bboxes_wh = bboxes.copy()
195 | bboxes_wh[:, 2:4] = bboxes[:, 2:4] - bboxes[:, 0:2] ####xywh
196 | classIds = np.argmax(scores, axis=1)
197 | confidences = np.max(scores, axis=1) ####max_class_confidence
198 |
199 | mask = confidences>self.conf_threshold
200 |         bboxes_wh = bboxes_wh[mask]  ### keep only boxes above the confidence threshold
201 | confidences = confidences[mask]
202 | classIds = classIds[mask]
203 | landmarks = landmarks[mask]
204 |
205 | indices = cv2.dnn.NMSBoxes(bboxes_wh.tolist(), confidences.tolist(), self.conf_threshold,
206 | self.iou_threshold).flatten()
207 | if len(indices) > 0:
208 | mlvl_bboxes = bboxes_wh[indices]
209 | confidences = confidences[indices]
210 | classIds = classIds[indices]
211 | landmarks = landmarks[indices]
212 | return mlvl_bboxes, confidences, classIds, landmarks
213 | else:
214 |             print('nothing detected')
215 | return np.array([]), np.array([]), np.array([]), np.array([])
216 |
217 | def distance2bbox(self, points, distance, max_shape=None):
218 | x1 = points[:, 0] - distance[:, 0]
219 | y1 = points[:, 1] - distance[:, 1]
220 | x2 = points[:, 0] + distance[:, 2]
221 | y2 = points[:, 1] + distance[:, 3]
222 | if max_shape is not None:
223 | x1 = np.clip(x1, 0, max_shape[1])
224 | y1 = np.clip(y1, 0, max_shape[0])
225 | x2 = np.clip(x2, 0, max_shape[1])
226 | y2 = np.clip(y2, 0, max_shape[0])
227 | return np.stack([x1, y1, x2, y2], axis=-1)
228 |
229 | def draw_detections(self, image, boxes, scores, kpts):
230 | for box, score, kp in zip(boxes, scores, kpts):
231 | x, y, w, h = box.astype(int)
232 | # Draw rectangle
233 | cv2.rectangle(image, (x, y), (x + w, y + h), (0, 0, 255), thickness=3)
234 | cv2.putText(image, "face:"+str(round(score,2)), (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), thickness=2)
235 | for i in range(5):
236 | cv2.circle(image, (int(kp[i * 3]), int(kp[i * 3 + 1])), 4, (0, 255, 0), thickness=-1)
237 | # cv2.putText(image, str(i), (int(kp[i * 3]), int(kp[i * 3 + 1]) - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=1)
238 | return image
239 |
240 | ROOT = os.path.dirname(os.path.abspath(__file__))
--------------------------------------------------------------------------------
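A minimal usage sketch for the FaceAlignment wrapper above (not part of the repository). It assumes the musetalk/utils directory has been added to sys.path so that face_detection is importable as a top-level package, matching the internal __import__('face_detection.detection.' + face_detector) call in api.py, and that the S3FD weights can be fetched by load_url on first use; the path and the dummy frames are placeholders.

import sys
import numpy as np

sys.path.append('musetalk/utils')          # hypothetical path; adjust to your checkout
from face_detection import api as fd_api   # import the module directly rather than relying on __init__ re-exports

# Two dummy BGR frames of shape (H, W, 3), as produced by cv2.VideoCapture.read().
frames = np.zeros((2, 480, 640, 3), dtype=np.uint8)

fa = fd_api.FaceAlignment(fd_api.LandmarksType._2D, flip_input=False, device='cpu')
boxes = fa.get_detections_for_batch(frames)   # one (x1, y1, x2, y2) tuple, or None, per frame
print(boxes)

On blank frames every entry is None; with real video frames each entry is the highest-scoring face box for that frame.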
/musetalk/utils/face_detection/detection/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import FaceDetector
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/detection/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/detection/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/detection/__pycache__/core.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/detection/__pycache__/core.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/detection/core.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import glob
3 | from tqdm import tqdm
4 | import numpy as np
5 | import torch
6 | import cv2
7 |
8 |
9 | class FaceDetector(object):
10 | """An abstract class representing a face detector.
11 |
12 | Any other face detection implementation must subclass it. All subclasses
13 |     must implement ``detect_from_image``, which returns a list of detected
14 |     bounding boxes. Optionally, for speed, implementing detection directly from a path is
15 | recommended.
16 | """
17 |
18 | def __init__(self, device, verbose):
19 | self.device = device
20 | self.verbose = verbose
21 |
22 |         if verbose:
23 |             logger = logging.getLogger(__name__)  # created here so the error branch below can use it as well
24 |             if 'cpu' in device:
25 |                 logger.warning("Detection running on CPU, this may be potentially slow.")
26 |
27 | if 'cpu' not in device and 'cuda' not in device:
28 | if verbose:
29 | logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
30 | raise ValueError
31 |
32 | def detect_from_image(self, tensor_or_path):
33 | """Detects faces in a given image.
34 |
35 |         This function detects the faces present in a provided (usually BGR)
36 | image. The input can be either the image itself or the path to it.
37 |
38 | Arguments:
39 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
40 | to an image or the image itself.
41 |
42 | Example::
43 |
44 | >>> path_to_image = 'data/image_01.jpg'
45 | ... detected_faces = detect_from_image(path_to_image)
46 | [A list of bounding boxes (x1, y1, x2, y2)]
47 | >>> image = cv2.imread(path_to_image)
48 | ... detected_faces = detect_from_image(image)
49 | [A list of bounding boxes (x1, y1, x2, y2)]
50 |
51 | """
52 | raise NotImplementedError
53 |
54 | def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
55 | """Detects faces from all the images present in a given directory.
56 |
57 | Arguments:
58 | path {string} -- a string containing a path that points to the folder containing the images
59 |
60 | Keyword Arguments:
61 |             extensions {list} -- list of strings containing the extensions to be
62 |                 considered, in the format ``.extension_name``
63 |                 (default: {['.jpg', '.png']})
64 |             recursive {bool} -- whether to scan the folder recursively (default: {False})
65 |             show_progress_bar {bool} -- display a progress bar (default: {True})
66 |
67 | Example:
68 | >>> directory = 'data'
69 | ... detected_faces = detect_from_directory(directory)
70 | {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
71 |
72 | """
73 | if self.verbose:
74 | logger = logging.getLogger(__name__)
75 |
76 | if len(extensions) == 0:
77 | if self.verbose:
78 |                 logger.error("Expected at least one extension, but none was received.")
79 | raise ValueError
80 |
81 | if self.verbose:
82 | logger.info("Constructing the list of images.")
83 | additional_pattern = '/**/*' if recursive else '/*'
84 | files = []
85 | for extension in extensions:
86 | files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
87 |
88 | if self.verbose:
89 | logger.info("Finished searching for images. %s images found", len(files))
90 | logger.info("Preparing to run the detection.")
91 |
92 | predictions = {}
93 | for image_path in tqdm(files, disable=not show_progress_bar):
94 | if self.verbose:
95 | logger.info("Running the face detector on image: %s", image_path)
96 | predictions[image_path] = self.detect_from_image(image_path)
97 |
98 | if self.verbose:
99 | logger.info("The detector was successfully run on all %s images", len(files))
100 |
101 | return predictions
102 |
103 | @property
104 | def reference_scale(self):
105 | raise NotImplementedError
106 |
107 | @property
108 | def reference_x_shift(self):
109 | raise NotImplementedError
110 |
111 | @property
112 | def reference_y_shift(self):
113 | raise NotImplementedError
114 |
115 | @staticmethod
116 | def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
117 | """Convert path (represented as a string) or torch.tensor to a numpy.ndarray
118 |
119 | Arguments:
120 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
121 | """
122 | if isinstance(tensor_or_path, str):
123 | return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
124 | elif torch.is_tensor(tensor_or_path):
125 | # Call cpu in case its coming from cuda
126 | return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
127 | elif isinstance(tensor_or_path, np.ndarray):
128 | return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
129 | else:
130 | raise TypeError
131 |
--------------------------------------------------------------------------------
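As a sketch of the contract core.py defines (not part of the repository), the stub below subclasses FaceDetector and implements detect_from_image with a fake whole-frame detection; a real subclass such as sfd_detector.py returns one [x1, y1, x2, y2, confidence] row per detected face. The import path is an assumption about how the package is reachable in your setup.

import numpy as np
from musetalk.utils.face_detection.detection.core import FaceDetector  # assumed import path


class WholeFrameDetector(FaceDetector):
    """Toy detector that reports the entire frame as a single 'face'."""

    def detect_from_image(self, tensor_or_path):
        image = self.tensor_or_path_to_ndarray(tensor_or_path)
        h, w = image.shape[:2]
        return [np.array([0, 0, w, h, 1.0])]   # [x1, y1, x2, y2, confidence]


detector = WholeFrameDetector(device='cpu', verbose=False)
print(detector.detect_from_image(np.zeros((240, 320, 3), dtype=np.uint8)))

Once detect_from_image is implemented, the inherited detect_from_directory works unchanged.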
/musetalk/utils/face_detection/detection/sfd/__init__.py:
--------------------------------------------------------------------------------
1 | from .sfd_detector import SFDDetector as FaceDetector
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/detection/sfd/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/detection/sfd/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/detection/sfd/__pycache__/bbox.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/detection/sfd/__pycache__/bbox.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/detection/sfd/__pycache__/detect.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/detection/sfd/__pycache__/detect.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/detection/sfd/__pycache__/net_s3fd.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/detection/sfd/__pycache__/net_s3fd.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/detection/sfd/__pycache__/sfd_detector.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_detection/detection/sfd/__pycache__/sfd_detector.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/detection/sfd/bbox.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import cv2
5 | import random
6 | import datetime
7 | import time
8 | import math
9 | import argparse
10 | import numpy as np
11 | import torch
12 |
13 | try:
14 | from iou import IOU
15 | except BaseException:
16 | # IOU cython speedup 10x
17 | def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
18 | sa = abs((ax2 - ax1) * (ay2 - ay1))
19 | sb = abs((bx2 - bx1) * (by2 - by1))
20 | x1, y1 = max(ax1, bx1), max(ay1, by1)
21 | x2, y2 = min(ax2, bx2), min(ay2, by2)
22 | w = x2 - x1
23 | h = y2 - y1
24 | if w < 0 or h < 0:
25 | return 0.0
26 | else:
27 | return 1.0 * w * h / (sa + sb - w * h)
28 |
29 |
30 | def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
31 | xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
32 | dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
33 | dw, dh = math.log(ww / aww), math.log(hh / ahh)
34 | return dx, dy, dw, dh
35 |
36 |
37 | def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
38 | xc, yc = dx * aww + axc, dy * ahh + ayc
39 | ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
40 | x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
41 | return x1, y1, x2, y2
42 |
43 |
44 | def nms(dets, thresh):
45 | if 0 == len(dets):
46 | return []
47 | x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
48 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
49 | order = scores.argsort()[::-1]
50 |
51 | keep = []
52 | while order.size > 0:
53 | i = order[0]
54 | keep.append(i)
55 | xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
56 | xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
57 |
58 | w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
59 | ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
60 |
61 | inds = np.where(ovr <= thresh)[0]
62 | order = order[inds + 1]
63 |
64 | return keep
65 |
66 |
67 | def encode(matched, priors, variances):
68 | """Encode the variances from the priorbox layers into the ground truth boxes
69 | we have matched (based on jaccard overlap) with the prior boxes.
70 | Args:
71 | matched: (tensor) Coords of ground truth for each prior in point-form
72 | Shape: [num_priors, 4].
73 | priors: (tensor) Prior boxes in center-offset form
74 | Shape: [num_priors,4].
75 | variances: (list[float]) Variances of priorboxes
76 | Return:
77 | encoded boxes (tensor), Shape: [num_priors, 4]
78 | """
79 |
80 | # dist b/t match center and prior's center
81 | g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
82 | # encode variance
83 | g_cxcy /= (variances[0] * priors[:, 2:])
84 | # match wh / prior wh
85 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
86 | g_wh = torch.log(g_wh) / variances[1]
87 | # return target for smooth_l1_loss
88 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
89 |
90 |
91 | def decode(loc, priors, variances):
92 | """Decode locations from predictions using priors to undo
93 | the encoding we did for offset regression at train time.
94 | Args:
95 | loc (tensor): location predictions for loc layers,
96 | Shape: [num_priors,4]
97 | priors (tensor): Prior boxes in center-offset form.
98 | Shape: [num_priors,4].
99 | variances: (list[float]) Variances of priorboxes
100 | Return:
101 | decoded bounding box predictions
102 | """
103 |
104 | boxes = torch.cat((
105 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
106 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
107 | boxes[:, :2] -= boxes[:, 2:] / 2
108 | boxes[:, 2:] += boxes[:, :2]
109 | return boxes
110 |
111 | def batch_decode(loc, priors, variances):
112 | """Decode locations from predictions using priors to undo
113 | the encoding we did for offset regression at train time.
114 | Args:
115 | loc (tensor): location predictions for loc layers,
116 | Shape: [num_priors,4]
117 | priors (tensor): Prior boxes in center-offset form.
118 | Shape: [num_priors,4].
119 | variances: (list[float]) Variances of priorboxes
120 | Return:
121 | decoded bounding box predictions
122 | """
123 |
124 | boxes = torch.cat((
125 | priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
126 | priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
127 | boxes[:, :, :2] -= boxes[:, :, 2:] / 2
128 | boxes[:, :, 2:] += boxes[:, :, :2]
129 | return boxes
130 |
--------------------------------------------------------------------------------
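A small sketch (not part of the repository) exercising the nms() helper above on toy detections; the import path is an assumption. Each row is [x1, y1, x2, y2, score], and a box is dropped when its IoU with an already-kept, higher-scoring box exceeds thresh.

import numpy as np
from musetalk.utils.face_detection.detection.sfd.bbox import nms  # assumed import path

dets = np.array([
    [10.0, 10.0, 110.0, 110.0, 0.9],    # kept: highest score
    [12.0, 12.0, 108.0, 108.0, 0.8],    # suppressed: near-duplicate of the first box
    [200.0, 200.0, 300.0, 300.0, 0.7],  # kept: disjoint from the first box
])
keep = nms(dets, thresh=0.3)
print(keep)        # expected: [0, 2]
print(dets[keep])  # the surviving detections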
/musetalk/utils/face_detection/detection/sfd/detect.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import os
5 | import sys
6 | import cv2
7 | import random
8 | import datetime
9 | import math
10 | import argparse
11 | import numpy as np
12 |
13 | import scipy.io as sio
14 | import zipfile
15 | from .net_s3fd import s3fd
16 | from .bbox import *
17 |
18 |
19 | def detect(net, img, device):
20 | img = img - np.array([104, 117, 123])
21 | img = img.transpose(2, 0, 1)
22 | img = img.reshape((1,) + img.shape)
23 |
24 | if 'cuda' in device:
25 | torch.backends.cudnn.benchmark = True
26 |
27 | img = torch.from_numpy(img).float().to(device)
28 | BB, CC, HH, WW = img.size()
29 | with torch.no_grad():
30 | olist = net(img)
31 |
32 | bboxlist = []
33 | for i in range(len(olist) // 2):
34 | olist[i * 2] = F.softmax(olist[i * 2], dim=1)
35 | olist = [oelem.data.cpu() for oelem in olist]
36 | for i in range(len(olist) // 2):
37 | ocls, oreg = olist[i * 2], olist[i * 2 + 1]
38 | FB, FC, FH, FW = ocls.size() # feature map size
39 | stride = 2**(i + 2) # 4,8,16,32,64,128
40 | anchor = stride * 4
41 | poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
42 | for Iindex, hindex, windex in poss:
43 | axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
44 | score = ocls[0, 1, hindex, windex]
45 | loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
46 | priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
47 | variances = [0.1, 0.2]
48 | box = decode(loc, priors, variances)
49 | x1, y1, x2, y2 = box[0] * 1.0
50 | # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
51 | bboxlist.append([x1, y1, x2, y2, score])
52 | bboxlist = np.array(bboxlist)
53 | if 0 == len(bboxlist):
54 | bboxlist = np.zeros((1, 5))
55 |
56 | return bboxlist
57 |
58 | def batch_detect(net, imgs, device):
59 | imgs = imgs - np.array([104, 117, 123])
60 | imgs = imgs.transpose(0, 3, 1, 2)
61 |
62 | if 'cuda' in device:
63 | torch.backends.cudnn.benchmark = True
64 |
65 | imgs = torch.from_numpy(imgs).float().to(device)
66 | BB, CC, HH, WW = imgs.size()
67 | with torch.no_grad():
68 | olist = net(imgs)
69 | # print(olist)
70 |
71 | bboxlist = []
72 | for i in range(len(olist) // 2):
73 | olist[i * 2] = F.softmax(olist[i * 2], dim=1)
74 |
75 | olist = [oelem.cpu() for oelem in olist]
76 | for i in range(len(olist) // 2):
77 | ocls, oreg = olist[i * 2], olist[i * 2 + 1]
78 | FB, FC, FH, FW = ocls.size() # feature map size
79 | stride = 2**(i + 2) # 4,8,16,32,64,128
80 | anchor = stride * 4
81 | poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
82 | for Iindex, hindex, windex in poss:
83 | axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
84 | score = ocls[:, 1, hindex, windex]
85 | loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
86 | priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
87 | variances = [0.1, 0.2]
88 | box = batch_decode(loc, priors, variances)
89 | box = box[:, 0] * 1.0
90 | # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
91 | bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
92 | bboxlist = np.array(bboxlist)
93 | if 0 == len(bboxlist):
94 | bboxlist = np.zeros((1, BB, 5))
95 |
96 | return bboxlist
97 |
98 | def flip_detect(net, img, device):
99 | img = cv2.flip(img, 1)
100 | b = detect(net, img, device)
101 |
102 | bboxlist = np.zeros(b.shape)
103 | bboxlist[:, 0] = img.shape[1] - b[:, 2]
104 | bboxlist[:, 1] = b[:, 1]
105 | bboxlist[:, 2] = img.shape[1] - b[:, 0]
106 | bboxlist[:, 3] = b[:, 3]
107 | bboxlist[:, 4] = b[:, 4]
108 | return bboxlist
109 |
110 |
111 | def pts_to_bb(pts):
112 | min_x, min_y = np.min(pts, axis=0)
113 | max_x, max_y = np.max(pts, axis=0)
114 | return np.array([min_x, min_y, max_x, max_y])
115 |
--------------------------------------------------------------------------------
/musetalk/utils/face_detection/detection/sfd/net_s3fd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class L2Norm(nn.Module):
7 | def __init__(self, n_channels, scale=1.0):
8 | super(L2Norm, self).__init__()
9 | self.n_channels = n_channels
10 | self.scale = scale
11 | self.eps = 1e-10
12 | self.weight = nn.Parameter(torch.Tensor(self.n_channels))
13 | self.weight.data *= 0.0
14 | self.weight.data += self.scale
15 |
16 | def forward(self, x):
17 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
18 | x = x / norm * self.weight.view(1, -1, 1, 1)
19 | return x
20 |
21 |
22 | class s3fd(nn.Module):
23 | def __init__(self):
24 | super(s3fd, self).__init__()
25 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
26 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
27 |
28 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
29 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
30 |
31 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
32 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
33 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
34 |
35 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
36 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
37 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
38 |
39 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
40 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
41 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
42 |
43 | self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
44 | self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
45 |
46 | self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
47 | self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
48 |
49 | self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
50 | self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
51 |
52 | self.conv3_3_norm = L2Norm(256, scale=10)
53 | self.conv4_3_norm = L2Norm(512, scale=8)
54 | self.conv5_3_norm = L2Norm(512, scale=5)
55 |
56 | self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
57 | self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
58 | self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
59 | self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
60 | self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
61 | self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
62 |
63 | self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
64 | self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
65 | self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
66 | self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
67 | self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
68 | self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
69 |
70 | def forward(self, x):
71 | h = F.relu(self.conv1_1(x))
72 | h = F.relu(self.conv1_2(h))
73 | h = F.max_pool2d(h, 2, 2)
74 |
75 | h = F.relu(self.conv2_1(h))
76 | h = F.relu(self.conv2_2(h))
77 | h = F.max_pool2d(h, 2, 2)
78 |
79 | h = F.relu(self.conv3_1(h))
80 | h = F.relu(self.conv3_2(h))
81 | h = F.relu(self.conv3_3(h))
82 | f3_3 = h
83 | h = F.max_pool2d(h, 2, 2)
84 |
85 | h = F.relu(self.conv4_1(h))
86 | h = F.relu(self.conv4_2(h))
87 | h = F.relu(self.conv4_3(h))
88 | f4_3 = h
89 | h = F.max_pool2d(h, 2, 2)
90 |
91 | h = F.relu(self.conv5_1(h))
92 | h = F.relu(self.conv5_2(h))
93 | h = F.relu(self.conv5_3(h))
94 | f5_3 = h
95 | h = F.max_pool2d(h, 2, 2)
96 |
97 | h = F.relu(self.fc6(h))
98 | h = F.relu(self.fc7(h))
99 | ffc7 = h
100 | h = F.relu(self.conv6_1(h))
101 | h = F.relu(self.conv6_2(h))
102 | f6_2 = h
103 | h = F.relu(self.conv7_1(h))
104 | h = F.relu(self.conv7_2(h))
105 | f7_2 = h
106 |
107 | f3_3 = self.conv3_3_norm(f3_3)
108 | f4_3 = self.conv4_3_norm(f4_3)
109 | f5_3 = self.conv5_3_norm(f5_3)
110 |
111 | cls1 = self.conv3_3_norm_mbox_conf(f3_3)
112 | reg1 = self.conv3_3_norm_mbox_loc(f3_3)
113 | cls2 = self.conv4_3_norm_mbox_conf(f4_3)
114 | reg2 = self.conv4_3_norm_mbox_loc(f4_3)
115 | cls3 = self.conv5_3_norm_mbox_conf(f5_3)
116 | reg3 = self.conv5_3_norm_mbox_loc(f5_3)
117 | cls4 = self.fc7_mbox_conf(ffc7)
118 | reg4 = self.fc7_mbox_loc(ffc7)
119 | cls5 = self.conv6_2_mbox_conf(f6_2)
120 | reg5 = self.conv6_2_mbox_loc(f6_2)
121 | cls6 = self.conv7_2_mbox_conf(f7_2)
122 | reg6 = self.conv7_2_mbox_loc(f7_2)
123 |
124 | # max-out background label
125 | chunk = torch.chunk(cls1, 4, 1)
126 | bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
127 | cls1 = torch.cat([bmax, chunk[3]], dim=1)
128 |
129 | return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
130 |
--------------------------------------------------------------------------------
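A short sketch (not part of the repository) that instantiates the s3fd network above with untrained weights and inspects its outputs: six (classification, regression) map pairs at decreasing resolutions, which detect.py later decodes with strides 4 through 128. The import path is an assumption.

import torch
from musetalk.utils.face_detection.detection.sfd.net_s3fd import s3fd  # assumed import path

net = s3fd().eval()
with torch.no_grad():
    outputs = net(torch.randn(1, 3, 256, 256))   # a single random 256x256 "image"

# outputs alternates [cls1, reg1, cls2, reg2, ...]; cls maps have 2 channels, reg maps 4.
for cls_map, reg_map in zip(outputs[0::2], outputs[1::2]):
    print(tuple(cls_map.shape), tuple(reg_map.shape))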
/musetalk/utils/face_detection/detection/sfd/sfd_detector.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | from torch.utils.model_zoo import load_url
4 | import torch  # for torch.load below; imported explicitly rather than relying on the star imports
5 | from ..core import FaceDetector
6 |
7 | from .net_s3fd import s3fd
8 | from .bbox import *
9 | from .detect import *
10 |
11 | models_urls = {
12 | 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
13 | }
14 |
15 |
16 | class SFDDetector(FaceDetector):
17 | def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
18 | super(SFDDetector, self).__init__(device, verbose)
19 |
20 | # Initialise the face detector
21 | if not os.path.isfile(path_to_detector):
22 | model_weights = load_url(models_urls['s3fd'])
23 | else:
24 | model_weights = torch.load(path_to_detector)
25 |
26 | self.face_detector = s3fd()
27 | self.face_detector.load_state_dict(model_weights)
28 | self.face_detector.to(device)
29 | self.face_detector.eval()
30 |
31 | def detect_from_image(self, tensor_or_path):
32 | image = self.tensor_or_path_to_ndarray(tensor_or_path)
33 |
34 | bboxlist = detect(self.face_detector, image, device=self.device)
35 | keep = nms(bboxlist, 0.3)
36 | bboxlist = bboxlist[keep, :]
37 | bboxlist = [x for x in bboxlist if x[-1] > 0.5]
38 |
39 | return bboxlist
40 |
41 | def detect_from_batch(self, images):
42 | bboxlists = batch_detect(self.face_detector, images, device=self.device)
43 | keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
44 | bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
45 | bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
46 |
47 | return bboxlists
48 |
49 | @property
50 | def reference_scale(self):
51 | return 195
52 |
53 | @property
54 | def reference_x_shift(self):
55 | return 0
56 |
57 | @property
58 | def reference_y_shift(self):
59 | return 0
60 |
--------------------------------------------------------------------------------
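A minimal sketch (not part of the repository) running SFDDetector on one image. It assumes the s3fd weights either sit next to sfd_detector.py as s3fd.pth or can be downloaded via load_url; the import path and image file name are placeholders.

import cv2
from musetalk.utils.face_detection.detection.sfd.sfd_detector import SFDDetector  # assumed import path

detector = SFDDetector(device='cpu')        # loads (or downloads) the s3fd weights
image = cv2.imread('some_frame.png')        # hypothetical BGR input frame
faces = detector.detect_from_image(image)   # list of [x1, y1, x2, y2, confidence] rows
for x1, y1, x2, y2, conf in faces:
    print(int(x1), int(y1), int(x2), int(y2), round(float(conf), 3))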
/musetalk/utils/face_detection/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 |
7 | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
8 | "3x3 convolution with padding"
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3,
10 | stride=strd, padding=padding, bias=bias)
11 |
12 |
13 | class ConvBlock(nn.Module):
14 | def __init__(self, in_planes, out_planes):
15 | super(ConvBlock, self).__init__()
16 | self.bn1 = nn.BatchNorm2d(in_planes)
17 | self.conv1 = conv3x3(in_planes, int(out_planes / 2))
18 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
19 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
20 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
21 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
22 |
23 | if in_planes != out_planes:
24 | self.downsample = nn.Sequential(
25 | nn.BatchNorm2d(in_planes),
26 | nn.ReLU(True),
27 | nn.Conv2d(in_planes, out_planes,
28 | kernel_size=1, stride=1, bias=False),
29 | )
30 | else:
31 | self.downsample = None
32 |
33 | def forward(self, x):
34 | residual = x
35 |
36 | out1 = self.bn1(x)
37 | out1 = F.relu(out1, True)
38 | out1 = self.conv1(out1)
39 |
40 | out2 = self.bn2(out1)
41 | out2 = F.relu(out2, True)
42 | out2 = self.conv2(out2)
43 |
44 | out3 = self.bn3(out2)
45 | out3 = F.relu(out3, True)
46 | out3 = self.conv3(out3)
47 |
48 | out3 = torch.cat((out1, out2, out3), 1)
49 |
50 | if self.downsample is not None:
51 | residual = self.downsample(residual)
52 |
53 | out3 += residual
54 |
55 | return out3
56 |
57 |
58 | class Bottleneck(nn.Module):
59 |
60 | expansion = 4
61 |
62 | def __init__(self, inplanes, planes, stride=1, downsample=None):
63 | super(Bottleneck, self).__init__()
64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
65 | self.bn1 = nn.BatchNorm2d(planes)
66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
67 | padding=1, bias=False)
68 | self.bn2 = nn.BatchNorm2d(planes)
69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
70 | self.bn3 = nn.BatchNorm2d(planes * 4)
71 | self.relu = nn.ReLU(inplace=True)
72 | self.downsample = downsample
73 | self.stride = stride
74 |
75 | def forward(self, x):
76 | residual = x
77 |
78 | out = self.conv1(x)
79 | out = self.bn1(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv2(out)
83 | out = self.bn2(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv3(out)
87 | out = self.bn3(out)
88 |
89 | if self.downsample is not None:
90 | residual = self.downsample(x)
91 |
92 | out += residual
93 | out = self.relu(out)
94 |
95 | return out
96 |
97 |
98 | class HourGlass(nn.Module):
99 | def __init__(self, num_modules, depth, num_features):
100 | super(HourGlass, self).__init__()
101 | self.num_modules = num_modules
102 | self.depth = depth
103 | self.features = num_features
104 |
105 | self._generate_network(self.depth)
106 |
107 | def _generate_network(self, level):
108 | self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
109 |
110 | self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
111 |
112 | if level > 1:
113 | self._generate_network(level - 1)
114 | else:
115 | self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
116 |
117 | self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
118 |
119 | def _forward(self, level, inp):
120 | # Upper branch
121 | up1 = inp
122 | up1 = self._modules['b1_' + str(level)](up1)
123 |
124 | # Lower branch
125 | low1 = F.avg_pool2d(inp, 2, stride=2)
126 | low1 = self._modules['b2_' + str(level)](low1)
127 |
128 | if level > 1:
129 | low2 = self._forward(level - 1, low1)
130 | else:
131 | low2 = low1
132 | low2 = self._modules['b2_plus_' + str(level)](low2)
133 |
134 | low3 = low2
135 | low3 = self._modules['b3_' + str(level)](low3)
136 |
137 | up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
138 |
139 | return up1 + up2
140 |
141 | def forward(self, x):
142 | return self._forward(self.depth, x)
143 |
144 |
145 | class FAN(nn.Module):
146 |
147 | def __init__(self, num_modules=1):
148 | super(FAN, self).__init__()
149 | self.num_modules = num_modules
150 |
151 | # Base part
152 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
153 | self.bn1 = nn.BatchNorm2d(64)
154 | self.conv2 = ConvBlock(64, 128)
155 | self.conv3 = ConvBlock(128, 128)
156 | self.conv4 = ConvBlock(128, 256)
157 |
158 | # Stacking part
159 | for hg_module in range(self.num_modules):
160 | self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
161 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
162 | self.add_module('conv_last' + str(hg_module),
163 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
164 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
165 | self.add_module('l' + str(hg_module), nn.Conv2d(256,
166 | 68, kernel_size=1, stride=1, padding=0))
167 |
168 | if hg_module < self.num_modules - 1:
169 | self.add_module(
170 | 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
171 | self.add_module('al' + str(hg_module), nn.Conv2d(68,
172 | 256, kernel_size=1, stride=1, padding=0))
173 |
174 | def forward(self, x):
175 | x = F.relu(self.bn1(self.conv1(x)), True)
176 | x = F.avg_pool2d(self.conv2(x), 2, stride=2)
177 | x = self.conv3(x)
178 | x = self.conv4(x)
179 |
180 | previous = x
181 |
182 | outputs = []
183 | for i in range(self.num_modules):
184 | hg = self._modules['m' + str(i)](previous)
185 |
186 | ll = hg
187 | ll = self._modules['top_m_' + str(i)](ll)
188 |
189 | ll = F.relu(self._modules['bn_end' + str(i)]
190 | (self._modules['conv_last' + str(i)](ll)), True)
191 |
192 | # Predict heatmaps
193 | tmp_out = self._modules['l' + str(i)](ll)
194 | outputs.append(tmp_out)
195 |
196 | if i < self.num_modules - 1:
197 | ll = self._modules['bl' + str(i)](ll)
198 | tmp_out_ = self._modules['al' + str(i)](tmp_out)
199 | previous = previous + ll + tmp_out_
200 |
201 | return outputs
202 |
203 |
204 | class ResNetDepth(nn.Module):
205 |
206 | def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
207 | self.inplanes = 64
208 | super(ResNetDepth, self).__init__()
209 | self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
210 | bias=False)
211 | self.bn1 = nn.BatchNorm2d(64)
212 | self.relu = nn.ReLU(inplace=True)
213 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
214 | self.layer1 = self._make_layer(block, 64, layers[0])
215 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
216 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
217 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
218 | self.avgpool = nn.AvgPool2d(7)
219 | self.fc = nn.Linear(512 * block.expansion, num_classes)
220 |
221 | for m in self.modules():
222 | if isinstance(m, nn.Conv2d):
223 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
224 | m.weight.data.normal_(0, math.sqrt(2. / n))
225 | elif isinstance(m, nn.BatchNorm2d):
226 | m.weight.data.fill_(1)
227 | m.bias.data.zero_()
228 |
229 | def _make_layer(self, block, planes, blocks, stride=1):
230 | downsample = None
231 | if stride != 1 or self.inplanes != planes * block.expansion:
232 | downsample = nn.Sequential(
233 | nn.Conv2d(self.inplanes, planes * block.expansion,
234 | kernel_size=1, stride=stride, bias=False),
235 | nn.BatchNorm2d(planes * block.expansion),
236 | )
237 |
238 | layers = []
239 | layers.append(block(self.inplanes, planes, stride, downsample))
240 | self.inplanes = planes * block.expansion
241 | for i in range(1, blocks):
242 | layers.append(block(self.inplanes, planes))
243 |
244 | return nn.Sequential(*layers)
245 |
246 | def forward(self, x):
247 | x = self.conv1(x)
248 | x = self.bn1(x)
249 | x = self.relu(x)
250 | x = self.maxpool(x)
251 |
252 | x = self.layer1(x)
253 | x = self.layer2(x)
254 | x = self.layer3(x)
255 | x = self.layer4(x)
256 |
257 | x = self.avgpool(x)
258 | x = x.view(x.size(0), -1)
259 | x = self.fc(x)
260 |
261 | return x
262 |
--------------------------------------------------------------------------------
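A short sketch (not part of the repository) instantiating the FAN landmark network above and checking its output shapes on random input; the import path is an assumption. FAN expects 3x256x256 face crops, and each stacked hourglass module emits a 68-channel heatmap at 64x64 resolution.

import torch
from musetalk.utils.face_detection.models import FAN  # assumed import path

net = FAN(num_modules=1).eval()
with torch.no_grad():
    heatmaps = net(torch.randn(2, 3, 256, 256))   # batch of 2 random "face crops"
print(len(heatmaps), heatmaps[0].shape)           # 1, torch.Size([2, 68, 64, 64])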
/musetalk/utils/face_detection/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import time
5 | import torch
6 | import math
7 | import numpy as np
8 | import cv2
9 |
10 |
11 | def _gaussian(
12 | size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
13 | height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
14 | mean_vert=0.5):
15 | # handle some defaults
16 | if width is None:
17 | width = size
18 | if height is None:
19 | height = size
20 | if sigma_horz is None:
21 | sigma_horz = sigma
22 | if sigma_vert is None:
23 | sigma_vert = sigma
24 | center_x = mean_horz * width + 0.5
25 | center_y = mean_vert * height + 0.5
26 | gauss = np.empty((height, width), dtype=np.float32)
27 | # generate kernel
28 | for i in range(height):
29 | for j in range(width):
30 | gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
31 | sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
32 | if normalize:
33 | gauss = gauss / np.sum(gauss)
34 | return gauss
35 |
36 |
37 | def draw_gaussian(image, point, sigma):
38 | # Check if the gaussian is inside
39 | ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
40 | br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
41 | if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
42 | return image
43 | size = 6 * sigma + 1
44 | g = _gaussian(size)
45 | g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
46 | g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
47 | img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
48 | img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
49 | assert (g_x[0] > 0 and g_y[1] > 0)
50 | image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
51 | ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
52 | image[image > 1] = 1
53 | return image
54 |
55 |
56 | def transform(point, center, scale, resolution, invert=False):
57 |     """Generate an affine transformation matrix.
58 | 
59 |     Given a set of points, a center, a scale and a target resolution, the
60 |     function generates an affine transformation matrix. If invert is ``True``
61 | it will produce the inverse transformation.
62 |
63 | Arguments:
64 | point {torch.tensor} -- the input 2D point
65 | center {torch.tensor or numpy.array} -- the center around which to perform the transformations
66 | scale {float} -- the scale of the face/object
67 | resolution {float} -- the output resolution
68 |
69 | Keyword Arguments:
70 |         invert {bool} -- define whether the function should produce the direct or the
71 | inverse transformation matrix (default: {False})
72 | """
73 | _pt = torch.ones(3)
74 | _pt[0] = point[0]
75 | _pt[1] = point[1]
76 |
77 | h = 200.0 * scale
78 | t = torch.eye(3)
79 | t[0, 0] = resolution / h
80 | t[1, 1] = resolution / h
81 | t[0, 2] = resolution * (-center[0] / h + 0.5)
82 | t[1, 2] = resolution * (-center[1] / h + 0.5)
83 |
84 | if invert:
85 | t = torch.inverse(t)
86 |
87 | new_point = (torch.matmul(t, _pt))[0:2]
88 |
89 | return new_point.int()
90 |
91 |
92 | def crop(image, center, scale, resolution=256.0):
93 | """Center crops an image or set of heatmaps
94 |
95 | Arguments:
96 | image {numpy.array} -- an rgb image
97 | center {numpy.array} -- the center of the object, usually the same as of the bounding box
98 | scale {float} -- scale of the face
99 |
100 | Keyword Arguments:
101 | resolution {float} -- the size of the output cropped image (default: {256.0})
102 |
103 | Returns:
104 |         numpy.array -- the cropped image, resized to ``resolution`` x ``resolution``
105 |     """
106 |     # Crop around the center point. Input is expected to be an np.ndarray.
107 | ul = transform([1, 1], center, scale, resolution, True)
108 | br = transform([resolution, resolution], center, scale, resolution, True)
109 | # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
110 | if image.ndim > 2:
111 | newDim = np.array([br[1] - ul[1], br[0] - ul[0],
112 | image.shape[2]], dtype=np.int32)
113 | newImg = np.zeros(newDim, dtype=np.uint8)
114 | else:
115 |         newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32)  # np.int was removed in NumPy 1.24
116 | newImg = np.zeros(newDim, dtype=np.uint8)
117 | ht = image.shape[0]
118 | wd = image.shape[1]
119 | newX = np.array(
120 | [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
121 | newY = np.array(
122 | [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
123 | oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
124 | oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
125 | newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
126 | ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
127 | newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
128 | interpolation=cv2.INTER_LINEAR)
129 | return newImg
130 |
131 |
132 | def get_preds_fromhm(hm, center=None, scale=None):
133 | """Obtain (x,y) coordinates given a set of N heatmaps. If the center
134 |     and the scale are provided, the function will also return the points in
135 | the original coordinate frame.
136 |
137 | Arguments:
138 | hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
139 |
140 | Keyword Arguments:
141 | center {torch.tensor} -- the center of the bounding box (default: {None})
142 | scale {float} -- face scale (default: {None})
143 | """
144 | max, idx = torch.max(
145 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
146 | idx += 1
147 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
148 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
149 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
150 |
151 | for i in range(preds.size(0)):
152 | for j in range(preds.size(1)):
153 | hm_ = hm[i, j, :]
154 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
155 | if pX > 0 and pX < 63 and pY > 0 and pY < 63:
156 | diff = torch.FloatTensor(
157 | [hm_[pY, pX + 1] - hm_[pY, pX - 1],
158 | hm_[pY + 1, pX] - hm_[pY - 1, pX]])
159 | preds[i, j].add_(diff.sign_().mul_(.25))
160 |
161 | preds.add_(-.5)
162 |
163 | preds_orig = torch.zeros(preds.size())
164 | if center is not None and scale is not None:
165 | for i in range(hm.size(0)):
166 | for j in range(hm.size(1)):
167 | preds_orig[i, j] = transform(
168 | preds[i, j], center, scale, hm.size(2), True)
169 |
170 | return preds, preds_orig
171 |
172 | def get_preds_fromhm_batch(hm, centers=None, scales=None):
173 | """Obtain (x,y) coordinates given a set of N heatmaps. If the centers
174 |     and the scales are provided, the function will also return the points in
175 | the original coordinate frame.
176 |
177 | Arguments:
178 | hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
179 |
180 | Keyword Arguments:
181 | centers {torch.tensor} -- the centers of the bounding box (default: {None})
182 | scales {float} -- face scales (default: {None})
183 | """
184 | max, idx = torch.max(
185 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
186 | idx += 1
187 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
188 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
189 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
190 |
191 | for i in range(preds.size(0)):
192 | for j in range(preds.size(1)):
193 | hm_ = hm[i, j, :]
194 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
195 | if pX > 0 and pX < 63 and pY > 0 and pY < 63:
196 | diff = torch.FloatTensor(
197 | [hm_[pY, pX + 1] - hm_[pY, pX - 1],
198 | hm_[pY + 1, pX] - hm_[pY - 1, pX]])
199 | preds[i, j].add_(diff.sign_().mul_(.25))
200 |
201 | preds.add_(-.5)
202 |
203 | preds_orig = torch.zeros(preds.size())
204 | if centers is not None and scales is not None:
205 | for i in range(hm.size(0)):
206 | for j in range(hm.size(1)):
207 | preds_orig[i, j] = transform(
208 | preds[i, j], centers[i], scales[i], hm.size(2), True)
209 |
210 | return preds, preds_orig
211 |
212 | def shuffle_lr(parts, pairs=None):
213 | """Shuffle the points left-right according to the axis of symmetry
214 | of the object.
215 |
216 | Arguments:
217 | parts {torch.tensor} -- a 3D or 4D object containing the
218 | heatmaps.
219 |
220 | Keyword Arguments:
221 | pairs {list of integers} -- [order of the flipped points] (default: {None})
222 | """
223 | if pairs is None:
224 | pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
225 | 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
226 | 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
227 | 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
228 | 62, 61, 60, 67, 66, 65]
229 | if parts.ndimension() == 3:
230 | parts = parts[pairs, ...]
231 | else:
232 | parts = parts[:, pairs, ...]
233 |
234 | return parts
235 |
236 |
237 | def flip(tensor, is_label=False):
238 | """Flip an image or a set of heatmaps left-right
239 |
240 | Arguments:
241 | tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
242 |
243 | Keyword Arguments:
244 |         is_label {bool} -- [denote whether the input is an image or a set of heatmaps] (default: {False})
245 | """
246 | if not torch.is_tensor(tensor):
247 | tensor = torch.from_numpy(tensor)
248 |
249 | if is_label:
250 | tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
251 | else:
252 | tensor = tensor.flip(tensor.ndimension() - 1)
253 |
254 | return tensor
255 |
256 | # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
257 |
258 |
259 | def appdata_dir(appname=None, roaming=False):
260 | """ appdata_dir(appname=None, roaming=False)
261 |
262 | Get the path to the application directory, where applications are allowed
263 | to write user specific files (e.g. configurations). For non-user specific
264 | data, consider using common_appdata_dir().
265 | If appname is given, a subdir is appended (and created if necessary).
266 | If roaming is True, will prefer a roaming directory (Windows Vista/7).
267 | """
268 |
269 | # Define default user directory
270 | userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
271 | if userDir is None:
272 | userDir = os.path.expanduser('~')
273 | if not os.path.isdir(userDir): # pragma: no cover
274 | userDir = '/var/tmp' # issue #54
275 |
276 | # Get system app data dir
277 | path = None
278 | if sys.platform.startswith('win'):
279 | path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
280 | path = (path2 or path1) if roaming else (path1 or path2)
281 | elif sys.platform.startswith('darwin'):
282 | path = os.path.join(userDir, 'Library', 'Application Support')
283 | # On Linux and as fallback
284 | if not (path and os.path.isdir(path)):
285 | path = userDir
286 |
287 | # Maybe we should store things local to the executable (in case of a
288 | # portable distro or a frozen application that wants to be portable)
289 | prefix = sys.prefix
290 | if getattr(sys, 'frozen', None):
291 | prefix = os.path.abspath(os.path.dirname(sys.executable))
292 | for reldir in ('settings', '../settings'):
293 | localpath = os.path.abspath(os.path.join(prefix, reldir))
294 | if os.path.isdir(localpath): # pragma: no cover
295 | try:
296 | open(os.path.join(localpath, 'test.write'), 'wb').close()
297 | os.remove(os.path.join(localpath, 'test.write'))
298 | except IOError:
299 | pass # We cannot write in this directory
300 | else:
301 | path = localpath
302 | break
303 |
304 | # Get path specific for this app
305 | if appname:
306 | if path == userDir:
307 | appname = '.' + appname.lstrip('.') # Make it a hidden directory
308 | path = os.path.join(path, appname)
309 | if not os.path.isdir(path): # pragma: no cover
310 | os.mkdir(path)
311 |
312 | # Done
313 | return path
314 |
--------------------------------------------------------------------------------
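A short sketch (not part of the repository) of the heatmap post-processing above: get_preds_fromhm converts a [B, N, 64, 64] heatmap tensor (for example, a FAN output) into N (x, y) landmark coordinates per image, and maps them back to the source frame when a center and scale are supplied. The import path is an assumption.

import torch
from musetalk.utils.face_detection.utils import get_preds_fromhm  # assumed import path

heatmaps = torch.zeros(1, 68, 64, 64)
heatmaps[0, :, 32, 16] = 1.0                    # place every landmark peak at column 16, row 32

preds, preds_orig = get_preds_fromhm(heatmaps)  # without center/scale, preds_orig stays zero
print(preds[0, 0])                              # tensor([16.5000, 32.5000]) in heatmap coordinates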
/musetalk/utils/face_parsing/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import os
4 | import cv2
5 | import numpy as np
6 | from PIL import Image
7 | from .model import BiSeNet
8 | import torchvision.transforms as transforms
9 |
10 | class FaceParsing():
11 |     def __init__(self, resnet_path, model_pth):
12 |         self.net = self.model_init(resnet_path, model_pth)
13 | self.preprocess = self.image_preprocess()
14 |
15 | def model_init(self,
16 | resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
17 | model_pth='./models/face-parse-bisent/79999_iter.pth'):
18 | net = BiSeNet(resnet_path)
19 | if torch.cuda.is_available():
20 | net.cuda()
21 | net.load_state_dict(torch.load(model_pth))
22 | else:
23 | net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
24 | net.eval()
25 | return net
26 |
27 | def image_preprocess(self):
28 | return transforms.Compose([
29 | transforms.ToTensor(),
30 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
31 | ])
32 |
33 | def __call__(self, image, size=(512, 512)):
34 | if isinstance(image, str):
35 | image = Image.open(image)
36 |
37 | width, height = image.size
38 | with torch.no_grad():
39 | image = image.resize(size, Image.BILINEAR)
40 | img = self.preprocess(image)
41 | if torch.cuda.is_available():
42 | img = torch.unsqueeze(img, 0).cuda()
43 | else:
44 | img = torch.unsqueeze(img, 0)
45 | out = self.net(img)[0]
46 | parsing = out.squeeze(0).cpu().numpy().argmax(0)
47 | parsing[np.where(parsing>13)] = 0
48 | parsing[np.where(parsing>=1)] = 255
49 | parsing = Image.fromarray(parsing.astype(np.uint8))
50 | return parsing
51 |
52 | if __name__ == "__main__":
53 |     fp = FaceParsing('./models/face-parse-bisent/resnet18-5c106cde.pth', './models/face-parse-bisent/79999_iter.pth')  # __init__ requires explicit weight paths
54 | segmap = fp('154_small.png')
55 | segmap.save('res.png')
56 |
57 |
--------------------------------------------------------------------------------
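A short sketch (not part of the repository) of one way the FaceParsing mask above can be used: the 512x512 mask it returns is resized back to the source frame and used to blend a generated face into the original crop. The weight paths mirror the defaults in model_init; the image file names are placeholders, and the blending shown here is only illustrative, not the repository's own blending code.

import cv2
import numpy as np
from PIL import Image
from musetalk.utils.face_parsing import FaceParsing  # assumed import path

fp = FaceParsing(resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
                 model_pth='./models/face-parse-bisent/79999_iter.pth')

frame = cv2.imread('face_crop.png')                        # hypothetical BGR face crop
mask = fp(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
mask = np.array(mask.resize(frame.shape[1::-1])) / 255.0   # back to frame size, values in [0, 1]

generated = cv2.imread('generated_face.png')               # hypothetical same-size generated face
blended = (mask[..., None] * generated + (1.0 - mask[..., None]) * frame).astype(np.uint8)
cv2.imwrite('blended.png', blended)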
/musetalk/utils/face_parsing/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_parsing/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/face_parsing/__pycache__/model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_parsing/__pycache__/model.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/face_parsing/__pycache__/resnet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/utils/face_parsing/__pycache__/resnet.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/utils/face_parsing/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torchvision
9 |
10 | from .resnet import Resnet18
11 | # from modules.bn import InPlaceABNSync as BatchNorm2d
12 |
13 |
14 | class ConvBNReLU(nn.Module):
15 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16 | super(ConvBNReLU, self).__init__()
17 | self.conv = nn.Conv2d(in_chan,
18 | out_chan,
19 | kernel_size = ks,
20 | stride = stride,
21 | padding = padding,
22 | bias = False)
23 | self.bn = nn.BatchNorm2d(out_chan)
24 | self.init_weight()
25 |
26 | def forward(self, x):
27 | x = self.conv(x)
28 | x = F.relu(self.bn(x))
29 | return x
30 |
31 | def init_weight(self):
32 | for ly in self.children():
33 | if isinstance(ly, nn.Conv2d):
34 | nn.init.kaiming_normal_(ly.weight, a=1)
35 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
36 |
37 | class BiSeNetOutput(nn.Module):
38 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
39 | super(BiSeNetOutput, self).__init__()
40 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
41 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
42 | self.init_weight()
43 |
44 | def forward(self, x):
45 | x = self.conv(x)
46 | x = self.conv_out(x)
47 | return x
48 |
49 | def init_weight(self):
50 | for ly in self.children():
51 | if isinstance(ly, nn.Conv2d):
52 | nn.init.kaiming_normal_(ly.weight, a=1)
53 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
54 |
55 | def get_params(self):
56 | wd_params, nowd_params = [], []
57 | for name, module in self.named_modules():
58 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
59 | wd_params.append(module.weight)
60 | if not module.bias is None:
61 | nowd_params.append(module.bias)
62 | elif isinstance(module, nn.BatchNorm2d):
63 | nowd_params += list(module.parameters())
64 | return wd_params, nowd_params
65 |
66 |
67 | class AttentionRefinementModule(nn.Module):
68 | def __init__(self, in_chan, out_chan, *args, **kwargs):
69 | super(AttentionRefinementModule, self).__init__()
70 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
71 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
72 | self.bn_atten = nn.BatchNorm2d(out_chan)
73 | self.sigmoid_atten = nn.Sigmoid()
74 | self.init_weight()
75 |
76 | def forward(self, x):
77 | feat = self.conv(x)
78 | atten = F.avg_pool2d(feat, feat.size()[2:])
79 | atten = self.conv_atten(atten)
80 | atten = self.bn_atten(atten)
81 | atten = self.sigmoid_atten(atten)
82 | out = torch.mul(feat, atten)
83 | return out
84 |
85 | def init_weight(self):
86 | for ly in self.children():
87 | if isinstance(ly, nn.Conv2d):
88 | nn.init.kaiming_normal_(ly.weight, a=1)
89 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
90 |
91 |
92 | class ContextPath(nn.Module):
93 | def __init__(self, resnet_path, *args, **kwargs):
94 | super(ContextPath, self).__init__()
95 | self.resnet = Resnet18(resnet_path)
96 | self.arm16 = AttentionRefinementModule(256, 128)
97 | self.arm32 = AttentionRefinementModule(512, 128)
98 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
99 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
100 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
101 |
102 | self.init_weight()
103 |
104 | def forward(self, x):
105 | H0, W0 = x.size()[2:]
106 | feat8, feat16, feat32 = self.resnet(x)
107 | H8, W8 = feat8.size()[2:]
108 | H16, W16 = feat16.size()[2:]
109 | H32, W32 = feat32.size()[2:]
110 |
111 | avg = F.avg_pool2d(feat32, feat32.size()[2:])
112 | avg = self.conv_avg(avg)
113 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
114 |
115 | feat32_arm = self.arm32(feat32)
116 | feat32_sum = feat32_arm + avg_up
117 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
118 | feat32_up = self.conv_head32(feat32_up)
119 |
120 | feat16_arm = self.arm16(feat16)
121 | feat16_sum = feat16_arm + feat32_up
122 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
123 | feat16_up = self.conv_head16(feat16_up)
124 |
125 | return feat8, feat16_up, feat32_up # x8, x8, x16
126 |
127 | def init_weight(self):
128 | for ly in self.children():
129 | if isinstance(ly, nn.Conv2d):
130 | nn.init.kaiming_normal_(ly.weight, a=1)
131 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
132 |
133 | def get_params(self):
134 | wd_params, nowd_params = [], []
135 | for name, module in self.named_modules():
136 | if isinstance(module, (nn.Linear, nn.Conv2d)):
137 | wd_params.append(module.weight)
138 | if not module.bias is None:
139 | nowd_params.append(module.bias)
140 | elif isinstance(module, nn.BatchNorm2d):
141 | nowd_params += list(module.parameters())
142 | return wd_params, nowd_params
143 |
144 |
145 | ### Not used: the spatial path is replaced by the resnet feature of the same resolution
146 | class SpatialPath(nn.Module):
147 | def __init__(self, *args, **kwargs):
148 | super(SpatialPath, self).__init__()
149 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
150 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
151 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
152 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
153 | self.init_weight()
154 |
155 | def forward(self, x):
156 | feat = self.conv1(x)
157 | feat = self.conv2(feat)
158 | feat = self.conv3(feat)
159 | feat = self.conv_out(feat)
160 | return feat
161 |
162 | def init_weight(self):
163 | for ly in self.children():
164 | if isinstance(ly, nn.Conv2d):
165 | nn.init.kaiming_normal_(ly.weight, a=1)
166 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
167 |
168 | def get_params(self):
169 | wd_params, nowd_params = [], []
170 | for name, module in self.named_modules():
171 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
172 | wd_params.append(module.weight)
173 | if not module.bias is None:
174 | nowd_params.append(module.bias)
175 | elif isinstance(module, nn.BatchNorm2d):
176 | nowd_params += list(module.parameters())
177 | return wd_params, nowd_params
178 |
179 |
180 | class FeatureFusionModule(nn.Module):
181 | def __init__(self, in_chan, out_chan, *args, **kwargs):
182 | super(FeatureFusionModule, self).__init__()
183 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
184 | self.conv1 = nn.Conv2d(out_chan,
185 | out_chan//4,
186 | kernel_size = 1,
187 | stride = 1,
188 | padding = 0,
189 | bias = False)
190 | self.conv2 = nn.Conv2d(out_chan//4,
191 | out_chan,
192 | kernel_size = 1,
193 | stride = 1,
194 | padding = 0,
195 | bias = False)
196 | self.relu = nn.ReLU(inplace=True)
197 | self.sigmoid = nn.Sigmoid()
198 | self.init_weight()
199 |
200 | def forward(self, fsp, fcp):
201 | fcat = torch.cat([fsp, fcp], dim=1)
202 | feat = self.convblk(fcat)
203 | atten = F.avg_pool2d(feat, feat.size()[2:])
204 | atten = self.conv1(atten)
205 | atten = self.relu(atten)
206 | atten = self.conv2(atten)
207 | atten = self.sigmoid(atten)
208 | feat_atten = torch.mul(feat, atten)
209 | feat_out = feat_atten + feat
210 | return feat_out
211 |
212 | def init_weight(self):
213 | for ly in self.children():
214 | if isinstance(ly, nn.Conv2d):
215 | nn.init.kaiming_normal_(ly.weight, a=1)
216 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
217 |
218 | def get_params(self):
219 | wd_params, nowd_params = [], []
220 | for name, module in self.named_modules():
221 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
222 | wd_params.append(module.weight)
223 | if not module.bias is None:
224 | nowd_params.append(module.bias)
225 | elif isinstance(module, nn.BatchNorm2d):
226 | nowd_params += list(module.parameters())
227 | return wd_params, nowd_params
228 |
229 |
230 | class BiSeNet(nn.Module):
231 | def __init__(self, resnet_path='models/resnet18-5c106cde.pth', n_classes=19, *args, **kwargs):
232 | super(BiSeNet, self).__init__()
233 | self.cp = ContextPath(resnet_path)
234 |         ## the spatial path (self.sp) is removed; its feature comes from the resnet backbone
235 | self.ffm = FeatureFusionModule(256, 256)
236 | self.conv_out = BiSeNetOutput(256, 256, n_classes)
237 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
238 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
239 | self.init_weight()
240 |
241 | def forward(self, x):
242 | H, W = x.size()[2:]
243 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
244 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
245 | feat_fuse = self.ffm(feat_sp, feat_cp8)
246 |
247 | feat_out = self.conv_out(feat_fuse)
248 | feat_out16 = self.conv_out16(feat_cp8)
249 | feat_out32 = self.conv_out32(feat_cp16)
250 |
251 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
252 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
253 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
254 | return feat_out, feat_out16, feat_out32
255 |
256 | def init_weight(self):
257 | for ly in self.children():
258 | if isinstance(ly, nn.Conv2d):
259 | nn.init.kaiming_normal_(ly.weight, a=1)
260 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
261 |
262 | def get_params(self):
263 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
264 | for name, child in self.named_children():
265 | child_wd_params, child_nowd_params = child.get_params()
266 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
267 | lr_mul_wd_params += child_wd_params
268 | lr_mul_nowd_params += child_nowd_params
269 | else:
270 | wd_params += child_wd_params
271 | nowd_params += child_nowd_params
272 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
273 |
274 |
275 | if __name__ == "__main__":
276 |     net = BiSeNet(n_classes=19)  # first positional arg is resnet_path, so pass the class count by name
277 |     net.cuda()
278 |     net.eval()
279 |     in_ten = torch.randn(16, 3, 640, 480).cuda()
280 |     out, out16, out32 = net(in_ten)
281 |     print(out.shape)
282 | 
283 |     net.get_params()
284 |
--------------------------------------------------------------------------------
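
A minimal sketch of running BiSeNet for face parsing, mirroring the post-processing in FaceParsing.__call__ earlier in this package; the checkpoint paths and the 512x512 normalised input are assumptions about the local model layout, not something this file defines.

import torch

# assumed checkpoint locations; BiSeNet's default resnet_path is 'models/resnet18-5c106cde.pth'
net = BiSeNet(resnet_path='models/resnet18-5c106cde.pth', n_classes=19)
net.load_state_dict(torch.load('models/face-parse-bisent/79999_iter.pth', map_location='cpu'))
net.eval()

with torch.no_grad():
    img = torch.randn(1, 3, 512, 512)        # stand-in for a normalised face crop
    out = net(img)[0]                        # main output head: (1, 19, 512, 512) logits
    parsing = out.squeeze(0).argmax(0)       # per-pixel class labels, as in FaceParsing.__call__
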
/musetalk/utils/face_parsing/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.utils.model_zoo as modelzoo
8 |
9 | # from modules.bn import InPlaceABNSync as BatchNorm2d
10 |
11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12 |
13 |
14 | def conv3x3(in_planes, out_planes, stride=1):
15 | """3x3 convolution with padding"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=1, bias=False)
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | def __init__(self, in_chan, out_chan, stride=1):
22 | super(BasicBlock, self).__init__()
23 | self.conv1 = conv3x3(in_chan, out_chan, stride)
24 | self.bn1 = nn.BatchNorm2d(out_chan)
25 | self.conv2 = conv3x3(out_chan, out_chan)
26 | self.bn2 = nn.BatchNorm2d(out_chan)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.downsample = None
29 | if in_chan != out_chan or stride != 1:
30 | self.downsample = nn.Sequential(
31 | nn.Conv2d(in_chan, out_chan,
32 | kernel_size=1, stride=stride, bias=False),
33 | nn.BatchNorm2d(out_chan),
34 | )
35 |
36 | def forward(self, x):
37 | residual = self.conv1(x)
38 | residual = F.relu(self.bn1(residual))
39 | residual = self.conv2(residual)
40 | residual = self.bn2(residual)
41 |
42 | shortcut = x
43 | if self.downsample is not None:
44 | shortcut = self.downsample(x)
45 |
46 | out = shortcut + residual
47 | out = self.relu(out)
48 | return out
49 |
50 |
51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53 | for i in range(bnum-1):
54 | layers.append(BasicBlock(out_chan, out_chan, stride=1))
55 | return nn.Sequential(*layers)
56 |
57 |
58 | class Resnet18(nn.Module):
59 | def __init__(self, model_path):
60 | super(Resnet18, self).__init__()
61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62 | bias=False)
63 | self.bn1 = nn.BatchNorm2d(64)
64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69 | self.init_weight(model_path)
70 |
71 | def forward(self, x):
72 | x = self.conv1(x)
73 | x = F.relu(self.bn1(x))
74 | x = self.maxpool(x)
75 |
76 | x = self.layer1(x)
77 | feat8 = self.layer2(x) # 1/8
78 | feat16 = self.layer3(feat8) # 1/16
79 | feat32 = self.layer4(feat16) # 1/32
80 | return feat8, feat16, feat32
81 |
82 | def init_weight(self, model_path):
83 | state_dict = torch.load(model_path) #modelzoo.load_url(resnet18_url)
84 | self_state_dict = self.state_dict()
85 | for k, v in state_dict.items():
86 | if 'fc' in k: continue
87 | self_state_dict.update({k: v})
88 | self.load_state_dict(self_state_dict)
89 |
90 | def get_params(self):
91 | wd_params, nowd_params = [], []
92 | for name, module in self.named_modules():
93 | if isinstance(module, (nn.Linear, nn.Conv2d)):
94 | wd_params.append(module.weight)
95 | if not module.bias is None:
96 | nowd_params.append(module.bias)
97 | elif isinstance(module, nn.BatchNorm2d):
98 | nowd_params += list(module.parameters())
99 | return wd_params, nowd_params
100 |
101 |
102 | if __name__ == "__main__":
103 |     net = Resnet18('models/resnet18-5c106cde.pth')  # Resnet18 requires a local checkpoint path
104 |     x = torch.randn(16, 3, 224, 224)
105 |     out = net(x)
106 |     print(out[0].size())
107 |     print(out[1].size())
108 |     print(out[2].size())
109 |     net.get_params()
110 |
--------------------------------------------------------------------------------
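
Resnet18 loads its ImageNet weights from a local file rather than downloading them on the fly; a small sketch of fetching the checkpoint once via resnet18_url (defined above) and inspecting the three feature maps. The save path is an assumption matching BiSeNet's default.

import os
import torch
import torch.utils.model_zoo as modelzoo

ckpt = 'models/resnet18-5c106cde.pth'                      # assumed location (BiSeNet's default)
if not os.path.isfile(ckpt):
    os.makedirs(os.path.dirname(ckpt), exist_ok=True)
    torch.save(modelzoo.load_url(resnet18_url), ckpt)      # same URL as the hint in init_weight

backbone = Resnet18(ckpt)
feat8, feat16, feat32 = backbone(torch.randn(1, 3, 256, 256))
print(feat8.shape, feat16.shape, feat32.shape)             # 1/8, 1/16 and 1/32 resolution maps
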
/musetalk/utils/preprocessing.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from face_detection import FaceAlignment,LandmarksType
3 | from os import listdir, path
4 | import subprocess
5 | import numpy as np
6 | import cv2
7 | import pickle
8 | import os
9 | import json
10 | from mmpose.apis import inference_topdown, init_model
11 | from mmpose.structures import merge_data_samples
12 | import torch
13 | from tqdm import tqdm
14 | parent_directory = os.path.dirname(os.path.abspath(__file__))
15 | # initialize the mmpose model
16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17 | '''
18 | config_file = os.path.join(parent_directory,"dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py")
19 | checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth'
20 | model = init_model(config_file, checkpoint_file, device=device)
21 | '''
22 |
23 |
24 | # initialize the face detection model
25 | device = "cuda" if torch.cuda.is_available() else "cpu"
26 | fa = FaceAlignment(LandmarksType._2D, flip_input=False,device=device)
27 |
28 | # placeholder returned when no usable face bbox is found for a frame
29 | coord_placeholder = (0.0,0.0,0.0,0.0)
30 |
31 | def resize_landmark(landmark, w, h, new_w, new_h):
32 | w_ratio = new_w / w
33 | h_ratio = new_h / h
34 | landmark_norm = landmark / [w, h]
35 | landmark_resized = landmark_norm * [new_w, new_h]
36 | return landmark_resized
37 |
38 | def read_imgs(img_list):
39 | frames = []
40 | print('reading images...')
41 | for img_path in tqdm(img_list):
42 | frame = cv2.imread(img_path)
43 | frames.append(frame)
44 | return frames
45 |
46 | def get_bbox_range(model,img_list,batch_size_fa,upperbondrange =0):
47 | frames = read_imgs(img_list)
48 | # batch_size_fa = 1
49 | batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
50 | coords_list = []
51 | landmarks = []
52 | if upperbondrange != 0:
53 | print('get key_landmark and face bounding boxes with the bbox_shift:',upperbondrange)
54 | else:
55 | print('get key_landmark and face bounding boxes with the default value')
56 | average_range_minus = []
57 | average_range_plus = []
58 | for fb in tqdm(batches):
59 | results = inference_topdown(model, np.asarray(fb)[0])
60 | results = merge_data_samples(results)
61 | keypoints = results.pred_instances.keypoints
62 | face_land_mark= keypoints[0][23:91]
63 | face_land_mark = face_land_mark.astype(np.int32)
64 |
65 |         # get bounding boxes by face detection
66 | bbox = fa.get_detections_for_batch(np.asarray(fb))
67 |
68 |         # adjust the bounding box according to the landmarks
69 |         # add each bounding box as a tuple to the coordinates list
70 | for j, f in enumerate(bbox):
71 | if f is None: # no face in the image
72 | coords_list += [coord_placeholder]
73 | continue
74 |
75 | half_face_coord = face_land_mark[29]#np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
76 | range_minus = (face_land_mark[30]- face_land_mark[29])[1]
77 | range_plus = (face_land_mark[29]- face_land_mark[28])[1]
78 | average_range_minus.append(range_minus)
79 | average_range_plus.append(range_plus)
80 | if upperbondrange != 0:
81 |                 half_face_coord[1] = upperbondrange+half_face_coord[1] # manual adjustment: positive shifts down (towards landmark 29), negative shifts up (towards landmark 28)
82 |
83 | text_range=f"Total frame:「{len(frames)}」 Manually adjust range : [ -{int(sum(average_range_minus) / len(average_range_minus))}~{int(sum(average_range_plus) / len(average_range_plus))} ] , the current value: {upperbondrange}"
84 | return text_range
85 |
86 |
87 | def get_landmark_and_bbox(model,img_list,batch_size_fa,upperbondrange =0):
88 | frames = read_imgs(img_list)
89 | # batch_size_fa = 1
90 | batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
91 | coords_list = []
92 | landmarks = []
93 | if upperbondrange != 0:
94 | print('get key_landmark and face bounding boxes with the bbox_shift:',upperbondrange)
95 | else:
96 | print('get key_landmark and face bounding boxes with the default value')
97 | average_range_minus = []
98 | average_range_plus = []
99 | for fb in tqdm(batches):
100 | results = inference_topdown(model, np.asarray(fb)[0])
101 | results = merge_data_samples(results)
102 | keypoints = results.pred_instances.keypoints
103 | face_land_mark= keypoints[0][23:91]
104 | face_land_mark = face_land_mark.astype(np.int32)
105 |
106 |         # get bounding boxes by face detection
107 | bbox = fa.get_detections_for_batch(np.asarray(fb))
108 |
109 |         # adjust the bounding box according to the landmarks
110 |         # add each bounding box as a tuple to the coordinates list
111 | for j, f in enumerate(bbox):
112 | if f is None: # no face in the image
113 | coords_list += [coord_placeholder]
114 | continue
115 |
116 | half_face_coord = face_land_mark[29]#np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
117 | range_minus = (face_land_mark[30]- face_land_mark[29])[1]
118 | range_plus = (face_land_mark[29]- face_land_mark[28])[1]
119 | average_range_minus.append(range_minus)
120 | average_range_plus.append(range_plus)
121 | if upperbondrange != 0:
122 |                 half_face_coord[1] = upperbondrange+half_face_coord[1] # manual adjustment: positive shifts down (towards landmark 29), negative shifts up (towards landmark 28)
123 | half_face_dist = np.max(face_land_mark[:,1]) - half_face_coord[1]
124 | upper_bond = half_face_coord[1]-half_face_dist
125 | upper_bond = max(0, upper_bond)
126 | # https://github.com/TMElyralab/MuseTalk/issues/38
127 | f_landmark = (np.min(face_land_mark[:, 0]),int(upper_bond),np.max(face_land_mark[:, 0]),np.max(face_land_mark[:,1]))
128 | x1, y1, x2, y2 = f_landmark
129 |
130 | if y2-y1<=0 or x2-x1<=0 or x1<0: # if the landmark bbox is not suitable, reuse the bbox
131 | coords_list += [f]
132 | w,h = f[2]-f[0], f[3]-f[1]
133 | print("error bbox:",f)
134 | else:
135 | coords_list += [f_landmark]
136 |
137 | print("********************************************bbox_shift parameter adjustment**********************************************************")
138 | print(f"Total frame:「{len(frames)}」 Manually adjust range : [ -{int(sum(average_range_minus) / len(average_range_minus))}~{int(sum(average_range_plus) / len(average_range_plus))} ] , the current value: {upperbondrange}")
139 | print("*************************************************************************************************************************************")
140 | return coords_list,frames
141 |
142 |
143 | if __name__ == "__main__":
144 | img_list = ["./results/lyria/00000.png","./results/lyria/00001.png","./results/lyria/00002.png","./results/lyria/00003.png"]
145 | crop_coord_path = "./coord_face.pkl"
146 |     coords_list,full_frames = get_landmark_and_bbox(model, img_list, batch_size_fa=1)  # requires an initialised dwpose model; see the sketch after this file
147 | with open(crop_coord_path, 'wb') as f:
148 | pickle.dump(coords_list, f)
149 |
150 | for bbox, frame in zip(coords_list,full_frames):
151 | if bbox == coord_placeholder:
152 | continue
153 | x1, y1, x2, y2 = bbox
154 | crop_frame = frame[y1:y2, x1:x2]
155 | print('Cropped shape', crop_frame.shape)
156 |
157 | #cv2.imwrite(path.join(save_dir, '{}.png'.format(i)),full_frames[i][0][y1:y2, x1:x2])
158 | print(coords_list)
159 |
--------------------------------------------------------------------------------
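
As noted in the __main__ block above, get_landmark_and_bbox needs an initialised dwpose model; a sketch of the full call, reusing the config and checkpoint paths from the commented-out block near the top of this file (the checkpoint location depends on where dw-ll_ucoco_384.pth was downloaded):

from mmpose.apis import init_model

config_file = os.path.join(parent_directory, "dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py")
checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth'     # assumed download location
dwpose_model = init_model(config_file, checkpoint_file, device=device)

img_list = ["./results/lyria/00000.png", "./results/lyria/00001.png"]
coords_list, frames = get_landmark_and_bbox(dwpose_model, img_list, batch_size_fa=1, upperbondrange=0)
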
/musetalk/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import torch
5 |
6 | ffmpeg_path = os.getenv('FFMPEG_PATH')
7 | if ffmpeg_path is None:
8 |     print("please download a static ffmpeg build and export its directory as FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
9 | elif ffmpeg_path not in os.getenv('PATH'):
10 | print("add ffmpeg to path")
11 | os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
12 |
13 |
14 | from ..whisper.audio2feature import Audio2Feature
15 | from ..models.vae import VAE
16 | from ..models.unet import UNet,PositionalEncoding
17 |
18 | def load_all_model(base_dir):
19 | audio_processor = Audio2Feature(model_path=os.path.join(base_dir,"whisper/tiny.pt"))
20 | vae = VAE(model_path =os.path.join(base_dir,"sd-vae-ft-mse/"))
21 | unet = UNet(unet_config=os.path.join(base_dir,"musetalk/musetalk.json"),
22 | model_path =os.path.join(base_dir,"musetalk/pytorch_model.bin"))
23 | pe = PositionalEncoding(d_model=384)
24 | return audio_processor,vae,unet,pe
25 |
26 | def get_file_type(video_path):
27 | _, ext = os.path.splitext(video_path)
28 |
29 | if ext.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
30 | return 'image'
31 | elif ext.lower() in ['.avi', '.mp4', '.mov', '.flv', '.mkv']:
32 | return 'video'
33 | else:
34 | return 'unsupported'
35 |
36 | def get_video_fps(video_path):
37 | video = cv2.VideoCapture(video_path)
38 | fps = video.get(cv2.CAP_PROP_FPS)
39 | video.release()
40 | return fps
41 |
42 | def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0):
43 | whisper_batch, latent_batch = [], []
44 | for i, w in enumerate(whisper_chunks):
45 | idx = (i+delay_frame)%len(vae_encode_latents)
46 | latent = vae_encode_latents[idx]
47 | whisper_batch.append(w)
48 | latent_batch.append(latent)
49 |
50 | if len(latent_batch) >= batch_size:
51 | whisper_batch = np.asarray(whisper_batch)
52 | latent_batch = torch.cat(latent_batch, dim=0)
53 | yield whisper_batch, latent_batch
54 | whisper_batch, latent_batch = [], []
55 |
56 |     # the last batch may be smaller than the batch size
57 | if len(latent_batch) > 0:
58 | whisper_batch = np.asarray(whisper_batch)
59 | latent_batch = torch.cat(latent_batch, dim=0)
60 |
61 | yield whisper_batch, latent_batch
--------------------------------------------------------------------------------
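
A small sketch of how datagen pairs whisper chunks with per-frame VAE latents; the shapes below are illustrative placeholders (real chunks come from Audio2Feature.feature2chunks and real latents from the VAE encoder), but the batching and wrap-around behaviour are as implemented above.

import numpy as np
import torch

# placeholder inputs: 40 audio chunks driving a 25-frame clip (shapes are illustrative only)
whisper_chunks = [np.zeros((50, 384), dtype=np.float32) for _ in range(40)]
vae_encode_latents = [torch.zeros(1, 8, 32, 32) for _ in range(25)]

for whisper_batch, latent_batch in datagen(whisper_chunks, vae_encode_latents, batch_size=8):
    # whisper_batch: (b, 50, 384) ndarray; latent_batch: (b, 8, 32, 32) tensor.
    # Frame latents wrap around (modulo) when the audio is longer than the video.
    print(whisper_batch.shape, latent_batch.shape)
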
/musetalk/whisper/__pycache__/audio2feature.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/__pycache__/audio2feature.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/whisper/audio2feature.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .whisper import load_model
3 | import soundfile as sf
4 | import numpy as np
5 | import time
6 | import sys
7 | sys.path.append("..")
8 |
9 | class Audio2Feature():
10 | def __init__(self,
11 | whisper_model_type="tiny",
12 | model_path="./models/whisper/tiny.pt"):
13 | self.whisper_model_type = whisper_model_type
14 | self.model = load_model(model_path) #
15 |
16 | def get_sliced_feature(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25):
17 | """
18 |         Get a slice of audio features centred on a given video frame index
19 |         :param feature_array: per-frame whisper encoder embeddings
20 |         :param vid_idx: the video frame index the slice is centred on
21 |         :param audio_feat_length: number of context frames before and after the centre
22 |         :return: the sliced features and the audio indices they were taken from
23 | """
24 | length = len(feature_array)
25 | selected_feature = []
26 | selected_idx = []
27 |
28 | center_idx = int(vid_idx*50/fps)
29 | left_idx = center_idx-audio_feat_length[0]*2
30 | right_idx = center_idx + (audio_feat_length[1]+1)*2
31 |
32 | for idx in range(left_idx,right_idx):
33 | idx = max(0, idx)
34 | idx = min(length-1, idx)
35 | x = feature_array[idx]
36 | selected_feature.append(x)
37 | selected_idx.append(idx)
38 |
39 | selected_feature = np.concatenate(selected_feature, axis=0)
40 | selected_feature = selected_feature.reshape(-1, 384)# 50*384
41 | return selected_feature,selected_idx
42 |
43 | def get_sliced_feature_sparse(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25):
44 | """
45 |         Get a sparse slice of audio features centred on a given video frame index
46 |         :param feature_array: per-frame whisper encoder embeddings
47 |         :param vid_idx: the video frame index the slice is centred on
48 |         :param audio_feat_length: number of context frames before and after the centre
49 |         :return: the sliced features and the audio indices they were taken from
50 | """
51 | length = len(feature_array)
52 | selected_feature = []
53 | selected_idx = []
54 |
55 | for dt in range(-audio_feat_length[0],audio_feat_length[1]+1):
56 | left_idx = int((vid_idx+dt)*50/fps)
57 | if left_idx<1 or left_idx>length-1:
58 | left_idx = max(0, left_idx)
59 | left_idx = min(length-1, left_idx)
60 |
61 | x = feature_array[left_idx]
62 | x = x[np.newaxis,:,:]
63 | x = np.repeat(x, 2, axis=0)
64 | selected_feature.append(x)
65 | selected_idx.append(left_idx)
66 | selected_idx.append(left_idx)
67 | else:
68 | x = feature_array[left_idx-1:left_idx+1]
69 | selected_feature.append(x)
70 | selected_idx.append(left_idx-1)
71 | selected_idx.append(left_idx)
72 | selected_feature = np.concatenate(selected_feature, axis=0)
73 | selected_feature = selected_feature.reshape(-1, 384)# 50*384
74 | return selected_feature,selected_idx
75 |
76 |
77 | def feature2chunks(self,feature_array,fps,audio_feat_length = [2,2]):
78 | whisper_chunks = []
79 | whisper_idx_multiplier = 50./fps
80 | i = 0
81 | print(f"video in {fps} FPS, audio idx in 50FPS")
82 | while 1:
83 | start_idx = int(i * whisper_idx_multiplier)
84 | selected_feature,selected_idx = self.get_sliced_feature(feature_array= feature_array,vid_idx = i,audio_feat_length=audio_feat_length,fps=fps)
85 | #print(f"i:{i},selected_idx {selected_idx}")
86 | whisper_chunks.append(selected_feature)
87 | i += 1
88 | if start_idx>len(feature_array):
89 | break
90 |
91 | return whisper_chunks
92 |
93 | def audio2feat(self,audio_path):
94 |         # run whisper on the audio and collect the encoder embeddings of each segment
95 | result = self.model.transcribe(audio_path)
96 | embed_list = []
97 | for emb in result['segments']:
98 | encoder_embeddings = emb['encoder_embeddings']
99 | encoder_embeddings = encoder_embeddings.transpose(0,2,1,3)
100 | encoder_embeddings = encoder_embeddings.squeeze(0)
101 | start_idx = int(emb['start'])
102 | end_idx = int(emb['end'])
103 | emb_end_idx = int((end_idx - start_idx)/2)
104 | embed_list.append(encoder_embeddings[:emb_end_idx])
105 | concatenated_array = np.concatenate(embed_list, axis=0)
106 | return concatenated_array
107 |
108 | if __name__ == "__main__":
109 |     audio_processor = Audio2Feature(model_path="../../models/whisper/whisper_tiny.pt")
110 |     audio_path = "./test.mp3"
111 |     array = audio_processor.audio2feat(audio_path)
112 |     print(array.shape)
113 |     fps = 25
114 |     whisper_idx_multiplier = 50./fps
115 | 
116 |     i = 0
117 |     print(f"video in {fps} FPS, audio idx in 50FPS")
118 |     while 1:
119 |         start_idx = int(i * whisper_idx_multiplier)
120 |         selected_feature,selected_idx = audio_processor.get_sliced_feature(feature_array= array,vid_idx = i,audio_feat_length=[2,2],fps=fps)
121 |         print(f"video idx {i},\t audio idx {selected_idx},\t shape {selected_feature.shape}")
122 |         i += 1
123 |         if start_idx>len(array):
124 |             break
125 |
--------------------------------------------------------------------------------
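
The slicing in get_sliced_feature maps every video frame onto a fixed window of whisper features; a worked example of the index arithmetic at 25 fps with the default audio_feat_length=[2, 2]:

fps = 25
vid_idx = 10
center_idx = int(vid_idx * 50 / fps)       # 20 -- whisper features are indexed at 50 per second
left_idx = center_idx - 2 * 2              # 16
right_idx = center_idx + (2 + 1) * 2       # 26 -> 10 feature indices per video frame
# Each index holds the per-layer encoder embeddings collected in audio2feat (5 rows of 384 for
# the tiny model: the embedded input plus 4 encoder blocks), so one chunk reshapes to (50, 384).
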
/musetalk/whisper/whisper/__init__.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import io
3 | import os
4 | import urllib
5 | import warnings
6 | from typing import List, Optional, Union
7 |
8 | import torch
9 | from tqdm import tqdm
10 |
11 | from .audio import load_audio, log_mel_spectrogram, pad_or_trim
12 | from .decoding import DecodingOptions, DecodingResult, decode, detect_language
13 | from .model import Whisper, ModelDimensions
14 | from .transcribe import transcribe
15 |
16 |
17 | _MODELS = {
18 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
19 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
20 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
21 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
22 | "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
23 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
24 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
25 | "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
26 | "large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
27 | "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
28 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
29 | "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
30 | }
31 |
32 |
33 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
34 | os.makedirs(root, exist_ok=True)
35 |
36 | expected_sha256 = url.split("/")[-2]
37 | download_target = os.path.join(root, os.path.basename(url))
38 |
39 | if os.path.exists(download_target) and not os.path.isfile(download_target):
40 | raise RuntimeError(f"{download_target} exists and is not a regular file")
41 |
42 | if os.path.isfile(download_target):
43 | model_bytes = open(download_target, "rb").read()
44 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
45 | return model_bytes if in_memory else download_target
46 | else:
47 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
48 |
49 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
50 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
51 | while True:
52 | buffer = source.read(8192)
53 | if not buffer:
54 | break
55 |
56 | output.write(buffer)
57 | loop.update(len(buffer))
58 |
59 | model_bytes = open(download_target, "rb").read()
60 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
61 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model.")
62 |
63 | return model_bytes if in_memory else download_target
64 |
65 |
66 | def available_models() -> List[str]:
67 | """Returns the names of available models"""
68 | return list(_MODELS.keys())
69 |
70 |
71 | def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
72 | """
73 | Load a Whisper ASR model
74 |
75 | Parameters
76 | ----------
77 | name : str
78 | one of the official model names listed by `whisper.available_models()`, or
79 | path to a model checkpoint containing the model dimensions and the model state_dict.
80 | device : Union[str, torch.device]
81 | the PyTorch device to put the model into
82 | download_root: str
83 | path to download the model files; by default, it uses "~/.cache/whisper"
84 | in_memory: bool
85 | whether to preload the model weights into host memory
86 |
87 | Returns
88 | -------
89 | model : Whisper
90 | The Whisper ASR model instance
91 | """
92 |
93 | if device is None:
94 | device = "cuda" if torch.cuda.is_available() else "cpu"
95 | if download_root is None:
96 | download_root = os.getenv(
97 | "XDG_CACHE_HOME",
98 | os.path.join(os.path.expanduser("~"), ".cache", "whisper")
99 | )
100 |
101 | if name in _MODELS:
102 | checkpoint_file = _download(_MODELS[name], download_root, in_memory)
103 | elif os.path.isfile(name):
104 | checkpoint_file = open(name, "rb").read() if in_memory else name
105 | else:
106 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
107 |
108 | with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
109 | checkpoint = torch.load(fp, map_location=device)
110 | del checkpoint_file
111 |
112 | dims = ModelDimensions(**checkpoint["dims"])
113 | model = Whisper(dims)
114 | model.load_state_dict(checkpoint["model_state_dict"])
115 |
116 | return model.to(device)
117 |
--------------------------------------------------------------------------------
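
load_model accepts either an official model name, which is downloaded and checksum-verified into the cache directory, or a path to a local checkpoint; the latter is how Audio2Feature uses it. The local path below is the one load_all_model assembles and is assumed to exist.

# download by name into ~/.cache/whisper (or the directory in XDG_CACHE_HOME)
model = load_model("tiny")

# or load an already-downloaded checkpoint, e.g. the copy bundled with MuseTalk's weights
model = load_model("./models/whisper/tiny.pt", device="cpu")
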
/musetalk/whisper/whisper/__main__.py:
--------------------------------------------------------------------------------
1 | from .transcribe import cli
2 |
3 |
4 | cli()
5 |
--------------------------------------------------------------------------------
/musetalk/whisper/whisper/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/whisper/whisper/__pycache__/audio.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/__pycache__/audio.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/whisper/whisper/__pycache__/decoding.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/__pycache__/decoding.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/whisper/whisper/__pycache__/model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/__pycache__/model.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/whisper/whisper/__pycache__/tokenizer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/__pycache__/tokenizer.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/whisper/whisper/__pycache__/transcribe.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/__pycache__/transcribe.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/whisper/whisper/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/musetalk/whisper/whisper/assets/gpt2/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
--------------------------------------------------------------------------------
/musetalk/whisper/whisper/assets/gpt2/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}
--------------------------------------------------------------------------------
/musetalk/whisper/whisper/assets/mel_filters.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/musetalk/whisper/whisper/assets/mel_filters.npz
--------------------------------------------------------------------------------
/musetalk/whisper/whisper/assets/multilingual/added_tokens.json:
--------------------------------------------------------------------------------
1 | {"<|endoftext|>": 50257}
2 |
--------------------------------------------------------------------------------
/musetalk/whisper/whisper/assets/multilingual/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
--------------------------------------------------------------------------------
/musetalk/whisper/whisper/assets/multilingual/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}
--------------------------------------------------------------------------------
/musetalk/whisper/whisper/audio.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import lru_cache
3 | from typing import Union
4 |
5 | import ffmpeg
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from .utils import exact_div
11 |
12 | # hard-coded audio hyperparameters
13 | SAMPLE_RATE = 16000
14 | N_FFT = 400
15 | N_MELS = 80
16 | HOP_LENGTH = 160
17 | CHUNK_LENGTH = 30
18 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
19 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
20 |
21 |
22 | def load_audio(file: str, sr: int = SAMPLE_RATE):
23 | """
24 | Open an audio file and read as mono waveform, resampling as necessary
25 |
26 | Parameters
27 | ----------
28 | file: str
29 | The audio file to open
30 |
31 | sr: int
32 | The sample rate to resample the audio if necessary
33 |
34 | Returns
35 | -------
36 | A NumPy array containing the audio waveform, in float32 dtype.
37 | """
38 | try:
39 | # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
40 | # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
41 | out, _ = (
42 | ffmpeg.input(file, threads=0)
43 | .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
44 | .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
45 | )
46 | except ffmpeg.Error as e:
47 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
48 |
49 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
50 |
51 |
52 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
53 | """
54 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
55 | """
56 | if torch.is_tensor(array):
57 | if array.shape[axis] > length:
58 | array = array.index_select(dim=axis, index=torch.arange(length))
59 |
60 | if array.shape[axis] < length:
61 | pad_widths = [(0, 0)] * array.ndim
62 | pad_widths[axis] = (0, length - array.shape[axis])
63 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
64 | else:
65 | if array.shape[axis] > length:
66 | array = array.take(indices=range(length), axis=axis)
67 |
68 | if array.shape[axis] < length:
69 | pad_widths = [(0, 0)] * array.ndim
70 | pad_widths[axis] = (0, length - array.shape[axis])
71 | array = np.pad(array, pad_widths)
72 |
73 | return array
74 |
75 |
76 | @lru_cache(maxsize=None)
77 | def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
78 | """
79 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
80 | Allows decoupling librosa dependency; saved using:
81 |
82 | np.savez_compressed(
83 | "mel_filters.npz",
84 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
85 | )
86 | """
87 | assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
88 | with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
89 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
90 |
91 |
92 | def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
93 | """
94 |     Compute the log-Mel spectrogram of an audio waveform
95 |
96 | Parameters
97 | ----------
98 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
99 |         The path to an audio file, or a NumPy array or Tensor containing the audio waveform at 16 kHz
100 |
101 | n_mels: int
102 | The number of Mel-frequency filters, only 80 is supported
103 |
104 | Returns
105 | -------
106 | torch.Tensor, shape = (80, n_frames)
107 | A Tensor that contains the Mel spectrogram
108 | """
109 | if not torch.is_tensor(audio):
110 | if isinstance(audio, str):
111 | audio = load_audio(audio)
112 | audio = torch.from_numpy(audio)
113 |
114 | window = torch.hann_window(N_FFT).to(audio.device)
115 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
116 |
117 | magnitudes = stft[:, :-1].abs() ** 2
118 |
119 | filters = mel_filters(audio.device, n_mels)
120 | mel_spec = filters @ magnitudes
121 |
122 | log_spec = torch.clamp(mel_spec, min=1e-10).log10()
123 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
124 | log_spec = (log_spec + 4.0) / 4.0
125 | return log_spec
126 |
--------------------------------------------------------------------------------
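
These helpers chain together in the usual whisper preprocessing order; a short sketch (the audio path is a placeholder):

audio = load_audio("speech.wav")     # decode to a mono float32 waveform at 16 kHz
audio = pad_or_trim(audio)           # pad or clip to 30 s (N_SAMPLES)
mel = log_mel_spectrogram(audio)     # (80, 3000) log-Mel tensor, ready for the audio encoder
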
/musetalk/whisper/whisper/model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict
3 | from typing import Iterable, Optional
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import Tensor
9 | from torch import nn
10 |
11 | from .transcribe import transcribe as transcribe_function
12 | from .decoding import detect_language as detect_language_function, decode as decode_function
13 |
14 |
15 | @dataclass
16 | class ModelDimensions:
17 | n_mels: int
18 | n_audio_ctx: int
19 | n_audio_state: int
20 | n_audio_head: int
21 | n_audio_layer: int
22 | n_vocab: int
23 | n_text_ctx: int
24 | n_text_state: int
25 | n_text_head: int
26 | n_text_layer: int
27 |
28 |
29 | class LayerNorm(nn.LayerNorm):
30 | def forward(self, x: Tensor) -> Tensor:
31 | return super().forward(x.float()).type(x.dtype)
32 |
33 |
34 | class Linear(nn.Linear):
35 | def forward(self, x: Tensor) -> Tensor:
36 | return F.linear(
37 | x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
38 | )
39 |
40 |
41 | class Conv1d(nn.Conv1d):
42 | def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
43 | return super()._conv_forward(
44 | x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
45 | )
46 |
47 |
48 | def sinusoids(length, channels, max_timescale=10000):
49 | """Returns sinusoids for positional embedding"""
50 | assert channels % 2 == 0
51 | log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
52 | inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
53 | scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
54 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
55 |
56 |
57 | class MultiHeadAttention(nn.Module):
58 | def __init__(self, n_state: int, n_head: int):
59 | super().__init__()
60 | self.n_head = n_head
61 | self.query = Linear(n_state, n_state)
62 | self.key = Linear(n_state, n_state, bias=False)
63 | self.value = Linear(n_state, n_state)
64 | self.out = Linear(n_state, n_state)
65 |
66 | def forward(
67 | self,
68 | x: Tensor,
69 | xa: Optional[Tensor] = None,
70 | mask: Optional[Tensor] = None,
71 | kv_cache: Optional[dict] = None,
72 | ):
73 | q = self.query(x)
74 |
75 | if kv_cache is None or xa is None:
76 | # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
77 | # otherwise, perform key/value projections for self- or cross-attention as usual.
78 | k = self.key(x if xa is None else xa)
79 | v = self.value(x if xa is None else xa)
80 | else:
81 | # for cross-attention, calculate keys and values once and reuse in subsequent calls.
82 | k = kv_cache.get(self.key, self.key(xa))
83 | v = kv_cache.get(self.value, self.value(xa))
84 |
85 | wv = self.qkv_attention(q, k, v, mask)
86 | return self.out(wv)
87 |
88 | def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
89 | n_batch, n_ctx, n_state = q.shape
90 | scale = (n_state // self.n_head) ** -0.25
91 | q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
92 | k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
93 | v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
94 |
95 | qk = q @ k
96 | if mask is not None:
97 | qk = qk + mask[:n_ctx, :n_ctx]
98 |
99 | w = F.softmax(qk.float(), dim=-1).to(q.dtype)
100 | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
101 |
102 |
103 | class ResidualAttentionBlock(nn.Module):
104 | def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
105 | super().__init__()
106 |
107 | self.attn = MultiHeadAttention(n_state, n_head)
108 | self.attn_ln = LayerNorm(n_state)
109 |
110 | self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
111 | self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
112 |
113 | n_mlp = n_state * 4
114 | self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
115 | self.mlp_ln = LayerNorm(n_state)
116 |
117 | def forward(
118 | self,
119 | x: Tensor,
120 | xa: Optional[Tensor] = None,
121 | mask: Optional[Tensor] = None,
122 | kv_cache: Optional[dict] = None,
123 | ):
124 | x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
125 | if self.cross_attn:
126 | x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
127 | x = x + self.mlp(self.mlp_ln(x))
128 | return x
129 |
130 |
131 | class AudioEncoder(nn.Module):
132 | def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
133 | super().__init__()
134 | self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
135 | self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
136 | self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
137 |
138 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
139 | [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
140 | )
141 | self.ln_post = LayerNorm(n_state)
142 |
143 | def forward(self, x: Tensor, include_embeddings: bool = False):
144 | """
145 | x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
146 | the mel spectrogram of the audio
147 | include_embeddings: bool
148 | whether to include intermediate steps in the output
149 | """
150 | x = F.gelu(self.conv1(x))
151 | x = F.gelu(self.conv2(x))
152 | x = x.permute(0, 2, 1)
153 |
154 | assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
155 | x = (x + self.positional_embedding).to(x.dtype)
156 |
157 | if include_embeddings:
158 | embeddings = [x.cpu().detach().numpy()]
159 |
160 | for block in self.blocks:
161 | x = block(x)
162 | if include_embeddings:
163 | embeddings.append(x.cpu().detach().numpy())
164 |
165 | x = self.ln_post(x)
166 |
167 | if include_embeddings:
168 | embeddings = np.stack(embeddings, axis=1)
169 | return x, embeddings
170 | else:
171 | return x
172 |
173 |
174 | class TextDecoder(nn.Module):
175 | def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
176 | super().__init__()
177 |
178 | self.token_embedding = nn.Embedding(n_vocab, n_state)
179 | self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
180 |
181 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
182 | [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
183 | )
184 | self.ln = LayerNorm(n_state)
185 |
186 | mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
187 | self.register_buffer("mask", mask, persistent=False)
188 |
189 | def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None, include_embeddings: bool = False):
190 | """
191 | x : torch.LongTensor, shape = (batch_size, <= n_ctx)
192 | the text tokens
193 | xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
194 | the encoded audio features to be attended on
195 | include_embeddings : bool
196 |             Whether to also return the per-layer hidden states alongside the logits
197 | """
198 | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
199 | x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
200 | x = x.to(xa.dtype)
201 |
202 | if include_embeddings:
203 | embeddings = [x.cpu().detach().numpy()]
204 |
205 | for block in self.blocks:
206 | x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
207 | if include_embeddings:
208 | embeddings.append(x.cpu().detach().numpy())
209 |
210 | x = self.ln(x)
211 | logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
212 |
213 | if include_embeddings:
214 | embeddings = np.stack(embeddings, axis=1)
215 | return logits, embeddings
216 | else:
217 | return logits
218 |
219 |
220 | class Whisper(nn.Module):
221 | def __init__(self, dims: ModelDimensions):
222 | super().__init__()
223 | self.dims = dims
224 | self.encoder = AudioEncoder(
225 | self.dims.n_mels,
226 | self.dims.n_audio_ctx,
227 | self.dims.n_audio_state,
228 | self.dims.n_audio_head,
229 | self.dims.n_audio_layer,
230 | )
231 | self.decoder = TextDecoder(
232 | self.dims.n_vocab,
233 | self.dims.n_text_ctx,
234 | self.dims.n_text_state,
235 | self.dims.n_text_head,
236 | self.dims.n_text_layer,
237 | )
238 |
239 | def embed_audio(self, mel: torch.Tensor):
240 | return self.encoder.forward(mel)
241 |
242 | def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
243 | return self.decoder.forward(tokens, audio_features)
244 |
245 | def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
246 | return self.decoder(tokens, self.encoder(mel))
247 |
248 | @property
249 | def device(self):
250 | return next(self.parameters()).device
251 |
252 | @property
253 | def is_multilingual(self):
254 | return self.dims.n_vocab == 51865
255 |
256 | def install_kv_cache_hooks(self, cache: Optional[dict] = None):
257 | """
258 | The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
259 | tensors calculated for the previous positions. This method returns a dictionary that stores
260 | all caches, and the necessary hooks for the key and value projection modules that save the
261 | intermediate tensors to be reused during later calculations.
262 |
263 | Returns
264 | -------
265 | cache : Dict[nn.Module, torch.Tensor]
266 |             A dictionary mapping each key/value projection module to its cached tensor
267 |         hooks : List[RemovableHandle]
268 |             List of PyTorch RemovableHandle objects that can be used to remove the installed hooks
269 | """
270 | cache = {**cache} if cache is not None else {}
271 | hooks = []
272 |
273 | def save_to_cache(module, _, output):
274 | if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
275 | cache[module] = output # save as-is, for the first token or cross attention
276 | else:
277 | cache[module] = torch.cat([cache[module], output], dim=1).detach()
278 | return cache[module]
279 |
280 | def install_hooks(layer: nn.Module):
281 | if isinstance(layer, MultiHeadAttention):
282 | hooks.append(layer.key.register_forward_hook(save_to_cache))
283 | hooks.append(layer.value.register_forward_hook(save_to_cache))
284 |
285 | self.decoder.apply(install_hooks)
286 | return cache, hooks
287 |
288 | detect_language = detect_language_function
289 | transcribe = transcribe_function
290 | decode = decode_function
291 |
--------------------------------------------------------------------------------
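
The include_embeddings flag added to the encoder and decoder is what Audio2Feature relies on: a forward pass can return the per-layer hidden states alongside the usual output. A minimal sketch with random mel input; the dimension values below are those of the multilingual tiny checkpoint and are listed here as an assumption for illustration.

import torch

dims = ModelDimensions(n_mels=80, n_audio_ctx=1500, n_audio_state=384, n_audio_head=6,
                       n_audio_layer=4, n_vocab=51865, n_text_ctx=448, n_text_state=384,
                       n_text_head=6, n_text_layer=4)
model = Whisper(dims).eval()

mel = torch.randn(1, 80, 3000)                               # one 30 s log-Mel segment
with torch.no_grad():
    feats, embeddings = model.encoder(mel, include_embeddings=True)
print(feats.shape)       # torch.Size([1, 1500, 384])
print(embeddings.shape)  # (1, 5, 1500, 384): embedded input plus one entry per encoder block
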
/musetalk/whisper/whisper/normalizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import BasicTextNormalizer
2 | from .english import EnglishTextNormalizer
3 |
--------------------------------------------------------------------------------
/musetalk/whisper/whisper/normalizers/basic.py:
--------------------------------------------------------------------------------
1 | import re
2 | import unicodedata
3 |
4 | import regex
5 |
6 | # non-ASCII letters that are not separated by "NFKD" normalization
7 | ADDITIONAL_DIACRITICS = {
8 | "œ": "oe",
9 | "Œ": "OE",
10 | "ø": "o",
11 | "Ø": "O",
12 | "æ": "ae",
13 | "Æ": "AE",
14 | "ß": "ss",
15 | "ẞ": "SS",
16 | "đ": "d",
17 | "Đ": "D",
18 | "ð": "d",
19 | "Ð": "D",
20 | "þ": "th",
21 | "Þ": "th",
22 | "ł": "l",
23 | "Ł": "L",
24 | }
25 |
26 |
27 | def remove_symbols_and_diacritics(s: str, keep=""):
28 | """
29 | Replace any other markers, symbols, and punctuations with a space,
30 | and drop any diacritics (category 'Mn' and some manual mappings)
31 | """
32 | return "".join(
33 | c
34 | if c in keep
35 | else ADDITIONAL_DIACRITICS[c]
36 | if c in ADDITIONAL_DIACRITICS
37 | else ""
38 | if unicodedata.category(c) == "Mn"
39 | else " "
40 | if unicodedata.category(c)[0] in "MSP"
41 | else c
42 | for c in unicodedata.normalize("NFKD", s)
43 | )
44 |
45 |
46 | def remove_symbols(s: str):
47 | """
48 | Replace any other markers, symbols, punctuations with a space, keeping diacritics
49 | """
50 | return "".join(
51 | " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
52 | )
53 |
54 |
55 | class BasicTextNormalizer:
56 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
57 | self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
58 | self.split_letters = split_letters
59 |
60 | def __call__(self, s: str):
61 | s = s.lower()
62 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
63 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
64 | s = self.clean(s).lower()
65 |
66 | if self.split_letters:
67 | s = " ".join(regex.findall(r"\X", s, regex.U))
68 |
69 | s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
70 |
71 | return s
72 |
--------------------------------------------------------------------------------
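
A quick usage example of BasicTextNormalizer: lower-casing, dropping bracketed annotations, stripping punctuation and collapsing whitespace (the exact output depends on the Unicode categories of the characters involved):

normalizer = BasicTextNormalizer()
print(normalizer("Hello, [APPLAUSE] World!"))   # roughly: "hello world"
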
/musetalk/whisper/whisper/tokenizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass
3 | from functools import lru_cache
4 | from typing import List, Optional, Tuple, Union
5 |
6 | import numpy as np
7 | import torch
8 | from transformers import GPT2TokenizerFast
9 |
10 | LANGUAGES = {
11 | "en": "english",
12 | "zh": "chinese",
13 | "de": "german",
14 | "es": "spanish",
15 | "ru": "russian",
16 | "ko": "korean",
17 | "fr": "french",
18 | "ja": "japanese",
19 | "pt": "portuguese",
20 | "tr": "turkish",
21 | "pl": "polish",
22 | "ca": "catalan",
23 | "nl": "dutch",
24 | "ar": "arabic",
25 | "sv": "swedish",
26 | "it": "italian",
27 | "id": "indonesian",
28 | "hi": "hindi",
29 | "fi": "finnish",
30 | "vi": "vietnamese",
31 | "iw": "hebrew",
32 | "uk": "ukrainian",
33 | "el": "greek",
34 | "ms": "malay",
35 | "cs": "czech",
36 | "ro": "romanian",
37 | "da": "danish",
38 | "hu": "hungarian",
39 | "ta": "tamil",
40 | "no": "norwegian",
41 | "th": "thai",
42 | "ur": "urdu",
43 | "hr": "croatian",
44 | "bg": "bulgarian",
45 | "lt": "lithuanian",
46 | "la": "latin",
47 | "mi": "maori",
48 | "ml": "malayalam",
49 | "cy": "welsh",
50 | "sk": "slovak",
51 | "te": "telugu",
52 | "fa": "persian",
53 | "lv": "latvian",
54 | "bn": "bengali",
55 | "sr": "serbian",
56 | "az": "azerbaijani",
57 | "sl": "slovenian",
58 | "kn": "kannada",
59 | "et": "estonian",
60 | "mk": "macedonian",
61 | "br": "breton",
62 | "eu": "basque",
63 | "is": "icelandic",
64 | "hy": "armenian",
65 | "ne": "nepali",
66 | "mn": "mongolian",
67 | "bs": "bosnian",
68 | "kk": "kazakh",
69 | "sq": "albanian",
70 | "sw": "swahili",
71 | "gl": "galician",
72 | "mr": "marathi",
73 | "pa": "punjabi",
74 | "si": "sinhala",
75 | "km": "khmer",
76 | "sn": "shona",
77 | "yo": "yoruba",
78 | "so": "somali",
79 | "af": "afrikaans",
80 | "oc": "occitan",
81 | "ka": "georgian",
82 | "be": "belarusian",
83 | "tg": "tajik",
84 | "sd": "sindhi",
85 | "gu": "gujarati",
86 | "am": "amharic",
87 | "yi": "yiddish",
88 | "lo": "lao",
89 | "uz": "uzbek",
90 | "fo": "faroese",
91 | "ht": "haitian creole",
92 | "ps": "pashto",
93 | "tk": "turkmen",
94 | "nn": "nynorsk",
95 | "mt": "maltese",
96 | "sa": "sanskrit",
97 | "lb": "luxembourgish",
98 | "my": "myanmar",
99 | "bo": "tibetan",
100 | "tl": "tagalog",
101 | "mg": "malagasy",
102 | "as": "assamese",
103 | "tt": "tatar",
104 | "haw": "hawaiian",
105 | "ln": "lingala",
106 | "ha": "hausa",
107 | "ba": "bashkir",
108 | "jw": "javanese",
109 | "su": "sundanese",
110 | }
111 |
112 | # language code lookup by name, with a few language aliases
113 | TO_LANGUAGE_CODE = {
114 | **{language: code for code, language in LANGUAGES.items()},
115 | "burmese": "my",
116 | "valencian": "ca",
117 | "flemish": "nl",
118 | "haitian": "ht",
119 | "letzeburgesch": "lb",
120 | "pushto": "ps",
121 | "panjabi": "pa",
122 | "moldavian": "ro",
123 | "moldovan": "ro",
124 | "sinhalese": "si",
125 | "castilian": "es",
126 | }
127 |
128 |
129 | @dataclass(frozen=True)
130 | class Tokenizer:
131 | """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
132 |
133 | tokenizer: "GPT2TokenizerFast"
134 | language: Optional[str]
135 | sot_sequence: Tuple[int]
136 |
137 | def encode(self, text, **kwargs):
138 | return self.tokenizer.encode(text, **kwargs)
139 |
140 | def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
141 | return self.tokenizer.decode(token_ids, **kwargs)
142 |
143 | def decode_with_timestamps(self, tokens) -> str:
144 | """
145 | Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
146 |         This method decodes the given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
147 | """
148 | outputs = [[]]
149 | for token in tokens:
150 | if token >= self.timestamp_begin:
151 | timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
152 | outputs.append(timestamp)
153 | outputs.append([])
154 | else:
155 | outputs[-1].append(token)
156 | outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
157 | return "".join(outputs)
158 |
159 | @property
160 | @lru_cache()
161 | def eot(self) -> int:
162 | return self.tokenizer.eos_token_id
163 |
164 | @property
165 | @lru_cache()
166 | def sot(self) -> int:
167 | return self._get_single_token_id("<|startoftranscript|>")
168 |
169 | @property
170 | @lru_cache()
171 | def sot_lm(self) -> int:
172 | return self._get_single_token_id("<|startoflm|>")
173 |
174 | @property
175 | @lru_cache()
176 | def sot_prev(self) -> int:
177 | return self._get_single_token_id("<|startofprev|>")
178 |
179 | @property
180 | @lru_cache()
181 | def no_speech(self) -> int:
182 | return self._get_single_token_id("<|nospeech|>")
183 |
184 | @property
185 | @lru_cache()
186 | def no_timestamps(self) -> int:
187 | return self._get_single_token_id("<|notimestamps|>")
188 |
189 | @property
190 | @lru_cache()
191 | def timestamp_begin(self) -> int:
192 | return self.tokenizer.all_special_ids[-1] + 1
193 |
194 | @property
195 | @lru_cache()
196 | def language_token(self) -> int:
197 | """Returns the token id corresponding to the value of the `language` field"""
198 | if self.language is None:
199 |             raise ValueError("This tokenizer does not have a language token configured")
200 |
201 | additional_tokens = dict(
202 | zip(
203 | self.tokenizer.additional_special_tokens,
204 | self.tokenizer.additional_special_tokens_ids,
205 | )
206 | )
207 | candidate = f"<|{self.language}|>"
208 | if candidate in additional_tokens:
209 | return additional_tokens[candidate]
210 |
211 | raise KeyError(f"Language {self.language} not found in tokenizer.")
212 |
213 | @property
214 | @lru_cache()
215 | def all_language_tokens(self) -> Tuple[int]:
216 | result = []
217 | for token, token_id in zip(
218 | self.tokenizer.additional_special_tokens,
219 | self.tokenizer.additional_special_tokens_ids,
220 | ):
221 | if token.strip("<|>") in LANGUAGES:
222 | result.append(token_id)
223 | return tuple(result)
224 |
225 | @property
226 | @lru_cache()
227 | def all_language_codes(self) -> Tuple[str]:
228 | return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
229 |
230 | @property
231 | @lru_cache()
232 | def sot_sequence_including_notimestamps(self) -> Tuple[int]:
233 | return tuple(list(self.sot_sequence) + [self.no_timestamps])
234 |
235 | @property
236 | @lru_cache()
237 | def non_speech_tokens(self) -> Tuple[int]:
238 | """
239 | Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
240 | annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
241 |
242 | - ♪♪♪
243 | - ( SPEAKING FOREIGN LANGUAGE )
244 | - [DAVID] Hey there,
245 |
246 | keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
247 | """
248 | symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
249 | symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
250 |
251 | # symbols that may be a single token or multiple tokens depending on the tokenizer.
252 | # In case they're multiple tokens, suppress the first token, which is safe because:
253 | # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
254 | # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
255 | miscellaneous = set("♩♪♫♬♭♮♯")
256 | assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
257 |
258 | # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
259 | result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
260 | for symbol in symbols + list(miscellaneous):
261 | for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
262 | if len(tokens) == 1 or symbol in miscellaneous:
263 | result.add(tokens[0])
264 |
265 | return tuple(sorted(result))
266 |
267 | def _get_single_token_id(self, text) -> int:
268 | tokens = self.tokenizer.encode(text)
269 | assert len(tokens) == 1, f"{text} is not encoded as a single token"
270 | return tokens[0]
271 |
272 |
273 | @lru_cache(maxsize=None)
274 | def build_tokenizer(name: str = "gpt2"):
275 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
276 | path = os.path.join(os.path.dirname(__file__), "assets", name)
277 | tokenizer = GPT2TokenizerFast.from_pretrained(path)
278 |
279 | specials = [
280 | "<|startoftranscript|>",
281 | *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
282 | "<|translate|>",
283 | "<|transcribe|>",
284 | "<|startoflm|>",
285 | "<|startofprev|>",
286 | "<|nospeech|>",
287 | "<|notimestamps|>",
288 | ]
289 |
290 | tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
291 | return tokenizer
292 |
293 |
294 | @lru_cache(maxsize=None)
295 | def get_tokenizer(
296 | multilingual: bool,
297 | *,
298 | task: Optional[str] = None, # Literal["transcribe", "translate", None]
299 | language: Optional[str] = None,
300 | ) -> Tokenizer:
301 | if language is not None:
302 | language = language.lower()
303 | if language not in LANGUAGES:
304 | if language in TO_LANGUAGE_CODE:
305 | language = TO_LANGUAGE_CODE[language]
306 | else:
307 | raise ValueError(f"Unsupported language: {language}")
308 |
309 | if multilingual:
310 | tokenizer_name = "multilingual"
311 | task = task or "transcribe"
312 | language = language or "en"
313 | else:
314 | tokenizer_name = "gpt2"
315 | task = None
316 | language = None
317 |
318 | tokenizer = build_tokenizer(name=tokenizer_name)
319 | all_special_ids: List[int] = tokenizer.all_special_ids
320 | sot: int = all_special_ids[1]
321 | translate: int = all_special_ids[-6]
322 | transcribe: int = all_special_ids[-5]
323 |
324 | langs = tuple(LANGUAGES.keys())
325 | sot_sequence = [sot]
326 | if language is not None:
327 | sot_sequence.append(sot + 1 + langs.index(language))
328 | if task is not None:
329 | sot_sequence.append(transcribe if task == "transcribe" else translate)
330 |
331 | return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
332 |
--------------------------------------------------------------------------------
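A minimal sketch of obtaining and using the tokenizer defined above; the language/task values and the import path are illustrative assumptions:

```python
# Minimal sketch: build a multilingual Whisper tokenizer and inspect its
# start-of-transcript sequence. The import path is an assumption for this repo.
from musetalk.whisper.whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=True, task="transcribe", language="zh")

# sot_sequence holds the ids of <|startoftranscript|>, <|zh|>, <|transcribe|>
print(tokenizer.sot_sequence)

# Round-trip text through the underlying GPT2TokenizerFast
ids = tokenizer.encode("你好，世界")
print(ids, tokenizer.decode(ids))
```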
/musetalk/whisper/whisper/transcribe.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import warnings
4 | from typing import List, Optional, Tuple, Union, TYPE_CHECKING
5 |
6 | import numpy as np
7 | import torch
8 | import tqdm
9 |
10 | from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
11 | from .decoding import DecodingOptions, DecodingResult
12 | from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
13 | from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
14 |
15 | if TYPE_CHECKING:
16 | from .model import Whisper
17 |
18 |
19 | def transcribe(
20 | model: "Whisper",
21 | audio: Union[str, np.ndarray, torch.Tensor],
22 | *,
23 | verbose: Optional[bool] = None,
24 | temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
25 | compression_ratio_threshold: Optional[float] = 2.4,
26 | logprob_threshold: Optional[float] = -1.0,
27 | no_speech_threshold: Optional[float] = 0.6,
28 | condition_on_previous_text: bool = True,
29 | force_extraction: bool = False,
30 | **decode_options,
31 | ):
32 | """
33 | Transcribe an audio file using Whisper
34 |
35 | Parameters
36 | ----------
37 | model: Whisper
38 | The Whisper model instance
39 |
40 | audio: Union[str, np.ndarray, torch.Tensor]
41 | The path to the audio file to open, or the audio waveform
42 |
43 | verbose: bool
44 |         Whether to display the text being decoded to the console. If True, displays all the details;
45 |         if False, displays minimal details. If None, does not display anything.
46 |
47 | temperature: Union[float, Tuple[float, ...]]
48 |         Temperature for sampling. It can be a tuple of temperatures, which will be successively used
49 | upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
50 |
51 | compression_ratio_threshold: float
52 | If the gzip compression ratio is above this value, treat as failed
53 |
54 | logprob_threshold: float
55 | If the average log probability over sampled tokens is below this value, treat as failed
56 |
57 | no_speech_threshold: float
58 | If the no_speech probability is higher than this value AND the average log probability
59 | over sampled tokens is below `logprob_threshold`, consider the segment as silent
60 |
61 | condition_on_previous_text: bool
62 | if True, the previous output of the model is provided as a prompt for the next window;
63 | disabling may make the text inconsistent across windows, but the model becomes less prone to
64 | getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
65 |
66 | decode_options: dict
67 | Keyword arguments to construct `DecodingOptions` instances
68 |
69 | Returns
70 | -------
71 | A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
72 | the spoken language ("language"), which is detected when `decode_options["language"]` is None.
73 | """
74 | dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
75 | if model.device == torch.device("cpu"):
76 | if torch.cuda.is_available():
77 | warnings.warn("Performing inference on CPU when CUDA is available")
78 | if dtype == torch.float16:
79 | warnings.warn("FP16 is not supported on CPU; using FP32 instead")
80 | dtype = torch.float32
81 |
82 | if dtype == torch.float32:
83 | decode_options["fp16"] = False
84 |
85 | mel = log_mel_spectrogram(audio)
86 |
87 | all_segments = []
88 | def add_segment(
89 | *, start: float, end: float, encoder_embeddings
90 | ):
91 |
92 | all_segments.append(
93 | {
94 | "start": start,
95 | "end": end,
96 | "encoder_embeddings":encoder_embeddings,
97 | }
98 | )
99 | # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
100 | num_frames = mel.shape[-1]
101 | seek = 0
102 | previous_seek_value = seek
103 |     sample_skip = 3000  # mel frames per chunk: 3000 frames = 30 s of audio (equals N_FRAMES)
104 | with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
105 | while seek < num_frames:
106 |             # seek is the index of the first mel frame of the current chunk
107 | end_seek = min(seek + sample_skip, num_frames)
108 | segment = pad_or_trim(mel[:,seek:seek+sample_skip], N_FRAMES).to(model.device).to(dtype)
109 |
110 | single = segment.ndim == 2
111 | if single:
112 | segment = segment.unsqueeze(0)
113 | if dtype == torch.float16:
114 | segment = segment.half()
115 | audio_features, embeddings = model.encoder(segment, include_embeddings = True)
116 |
117 | encoder_embeddings = embeddings
118 | #print(f"encoder_embeddings shape {encoder_embeddings.shape}")
119 | add_segment(
120 | start=seek,
121 | end=end_seek,
122 | #text_tokens=tokens,
123 | #result=result,
124 | encoder_embeddings=encoder_embeddings,
125 | )
126 | seek+=sample_skip
127 |
128 | return dict(segments=all_segments)
129 |
130 |
131 | def cli():
132 | from . import available_models
133 |
134 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
135 | parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
136 | parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
137 | parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
138 | parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
139 | parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
140 | parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
141 |
142 | parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
143 | parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
144 |
145 | parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
146 | parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
147 | parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
148 | parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
149 | parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
150 |
151 | parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
152 | parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
153 | parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
154 | parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
155 |
156 | parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
157 | parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
158 | parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
159 | parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
160 |     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")
161 |
162 | args = parser.parse_args().__dict__
163 | model_name: str = args.pop("model")
164 | model_dir: str = args.pop("model_dir")
165 | output_dir: str = args.pop("output_dir")
166 | device: str = args.pop("device")
167 | os.makedirs(output_dir, exist_ok=True)
168 |
169 | if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
170 | if args["language"] is not None:
171 |             warnings.warn(f"{model_name} is an English-only model but received '{args['language']}'; using English instead.")
172 | args["language"] = "en"
173 |
174 | temperature = args.pop("temperature")
175 | temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
176 | if temperature_increment_on_fallback is not None:
177 | temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
178 | else:
179 | temperature = [temperature]
180 |
181 | threads = args.pop("threads")
182 | if threads > 0:
183 | torch.set_num_threads(threads)
184 |
185 | from . import load_model
186 | model = load_model(model_name, device=device, download_root=model_dir)
187 |
188 | for audio_path in args.pop("audio"):
189 | result = transcribe(model, audio_path, temperature=temperature, **args)
190 |
191 | audio_basename = os.path.basename(audio_path)
192 |
193 | # save TXT
194 | with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
195 | write_txt(result["segments"], file=txt)
196 |
197 | # save VTT
198 | with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
199 | write_vtt(result["segments"], file=vtt)
200 |
201 | # save SRT
202 | with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
203 | write_srt(result["segments"], file=srt)
204 |
205 |
206 | if __name__ == '__main__':
207 | cli()
208 |
--------------------------------------------------------------------------------
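A minimal sketch of calling the trimmed-down `transcribe()` above, which, unlike upstream Whisper, returns per-chunk encoder embeddings instead of decoded text; the model name and audio path are placeholder assumptions:

```python
# Minimal sketch: extract Whisper encoder embeddings with the modified
# transcribe() above. "tiny" and "audio.wav" are placeholder assumptions.
import torch
from musetalk.whisper.whisper import load_model
from musetalk.whisper.whisper.transcribe import transcribe

device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model("tiny", device=device)
result = transcribe(model, "audio.wav")

for seg in result["segments"]:
    # Each segment spans up to 3000 mel frames (30 s) and carries the encoder
    # activations consumed downstream as the audio feature.
    print(seg["start"], seg["end"], seg["encoder_embeddings"].shape)
```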
/musetalk/whisper/whisper/utils.py:
--------------------------------------------------------------------------------
1 | import zlib
2 | from typing import Iterator, TextIO
3 |
4 |
5 | def exact_div(x, y):
6 | assert x % y == 0
7 | return x // y
8 |
9 |
10 | def str2bool(string):
11 | str2val = {"True": True, "False": False}
12 | if string in str2val:
13 | return str2val[string]
14 | else:
15 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
16 |
17 |
18 | def optional_int(string):
19 | return None if string == "None" else int(string)
20 |
21 |
22 | def optional_float(string):
23 | return None if string == "None" else float(string)
24 |
25 |
26 | def compression_ratio(text) -> float:
27 | return len(text) / len(zlib.compress(text.encode("utf-8")))
28 |
29 |
30 | def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
31 | assert seconds >= 0, "non-negative timestamp expected"
32 | milliseconds = round(seconds * 1000.0)
33 |
34 | hours = milliseconds // 3_600_000
35 | milliseconds -= hours * 3_600_000
36 |
37 | minutes = milliseconds // 60_000
38 | milliseconds -= minutes * 60_000
39 |
40 | seconds = milliseconds // 1_000
41 | milliseconds -= seconds * 1_000
42 |
43 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
44 | return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
45 |
46 |
47 | def write_txt(transcript: Iterator[dict], file: TextIO):
48 | for segment in transcript:
49 | print(segment['text'].strip(), file=file, flush=True)
50 |
51 |
52 | def write_vtt(transcript: Iterator[dict], file: TextIO):
53 | print("WEBVTT\n", file=file)
54 | for segment in transcript:
55 | print(
56 | f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
57 | f"{segment['text'].strip().replace('-->', '->')}\n",
58 | file=file,
59 | flush=True,
60 | )
61 |
62 |
63 | def write_srt(transcript: Iterator[dict], file: TextIO):
64 | """
65 | Write a transcript to a file in SRT format.
66 |
67 | Example usage:
68 | from pathlib import Path
69 | from whisper.utils import write_srt
70 |
71 | result = transcribe(model, audio_path, temperature=temperature, **args)
72 |
73 | # save SRT
74 | audio_basename = Path(audio_path).stem
75 | with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
76 | write_srt(result["segments"], file=srt)
77 | """
78 | for i, segment in enumerate(transcript, start=1):
79 | # write srt lines
80 | print(
81 | f"{i}\n"
82 | f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
83 | f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
84 | f"{segment['text'].strip().replace('-->', '->')}\n",
85 | file=file,
86 | flush=True,
87 | )
88 |
--------------------------------------------------------------------------------
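A short worked example of the timestamp helper above; the input values are illustrative:

```python
# Worked example for format_timestamp(): milliseconds are split into hours,
# minutes and seconds; the hours field is shown when non-zero or when
# always_include_hours=True.
from musetalk.whisper.whisper.utils import format_timestamp

print(format_timestamp(75.5))                        # 01:15.500
print(format_timestamp(3661.5))                      # 01:01:01.500
print(format_timestamp(3661.5, decimal_marker=","))  # 01:01:01,500 (SRT style)
```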
/nodes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import folder_paths
3 | from .inference import MuseTalk_INFER
4 | from .inference_realtime import Infer_Real_Time
5 | from pydub import AudioSegment
6 | from moviepy.editor import VideoFileClip,AudioFileClip
7 |
8 | parent_directory = os.path.dirname(os.path.abspath(__file__))
9 | input_path = folder_paths.get_input_directory()
10 | out_path = folder_paths.get_output_directory()
11 |
12 | class MuseTalkRealTime:
13 | @classmethod
14 | def INPUT_TYPES(s):
15 | return {
16 | "required":{
17 | "audio":("AUDIO",),
18 | "video":("VIDEO",),
19 | "avatar_id":("STRING",{
20 | "default": "talker1"
21 | }),
22 | "bbox_shift":("INT",{
23 | "default":0
24 | }),
25 | "fps":("INT",{
26 | "default":25
27 | }),
28 | "batch_size":("INT",{
29 | "default":4
30 | }),
31 | "preparation":("BOOLEAN",{
32 | "default":True
33 | })
34 | }
35 | }
36 | CATEGORY = "AIFSH_MuseTalk"
37 |     DESCRIPTION = "Real-time MuseTalk lip-sync inference using a prepared avatar"
38 |
39 | RETURN_TYPES = ("VIDEO",)
40 |
41 | OUTPUT_NODE = False
42 |
43 | FUNCTION = "process"
44 |
45 | def process(self,audio,video,avatar_id,bbox_shift,fps,batch_size,preparation):
46 | muse_talk_real_time = Infer_Real_Time()
47 | output_vid_name = muse_talk_real_time(audio, video,avatar_id,fps=fps,batch_size=batch_size,
48 | preparation=preparation,bbox_shift=bbox_shift)
49 | return (output_vid_name,)
50 |
51 |
52 | class MuseTalk:
53 | @classmethod
54 | def INPUT_TYPES(s):
55 | return {
56 | "required":{
57 | "audio":("AUDIO",),
58 | "video":("VIDEO",),
59 | "bbox_shift":("INT",{
60 | "default":0
61 | }),
62 | "fps":("INT",{
63 | "default":25
64 | }),
65 | "batch_size":("INT",{
66 | "default":8
67 | }),
68 | "batch_size_fa":("INT",{
69 | "default":2
70 | }),
71 | "use_saved_coord":("BOOLEAN",{
72 | "default":False
73 | })
74 | }
75 | }
76 | CATEGORY = "AIFSH_MuseTalk"
77 |     DESCRIPTION = "MuseTalk lip-sync inference: sync the video's lips to the given audio"
78 |
79 | RETURN_TYPES = ("VIDEO",)
80 |
81 | OUTPUT_NODE = False
82 |
83 | FUNCTION = "process"
84 |
85 | def process(self,audio,video,bbox_shift,fps,batch_size,batch_size_fa,use_saved_coord):
86 | muse_talk = MuseTalk_INFER(bbox_shift,fps,batch_size,batch_size_fa,use_saved_coord)
87 | output_vid_name = muse_talk(video, audio)
88 | return (output_vid_name,)
89 |
90 |
91 | class CombineAudioVideo:
92 | @classmethod
93 | def INPUT_TYPES(s):
94 | return {"required":
95 | {"vocal_AUDIO": ("AUDIO",),
96 | "bgm_AUDIO": ("AUDIO",),
97 | "video": ("VIDEO",)
98 | }
99 | }
100 |
101 | CATEGORY = "AIFSH_MuseTalk"
102 |     DESCRIPTION = "Overlay vocal and BGM audio and mux the mixed track onto the video"
103 |
104 | RETURN_TYPES = ("VIDEO",)
105 |
106 | OUTPUT_NODE = False
107 |
108 | FUNCTION = "combine"
109 |
110 | def combine(self, vocal_AUDIO,bgm_AUDIO,video):
111 | vocal = AudioSegment.from_file(vocal_AUDIO)
112 | bgm = AudioSegment.from_file(bgm_AUDIO)
113 | audio = vocal.overlay(bgm)
114 | audio_file = os.path.join(out_path,"ip_lap_voice.wav")
115 | audio.export(audio_file, format="wav")
116 | cm_video_file = os.path.join(out_path,"voice_"+os.path.basename(video))
117 | video_clip = VideoFileClip(video)
118 | audio_clip = AudioFileClip(audio_file)
119 | new_video_clip = video_clip.set_audio(audio_clip)
120 | new_video_clip.write_videofile(cm_video_file)
121 | return (cm_video_file,)
122 |
123 |
124 | class PreViewVideo:
125 | @classmethod
126 | def INPUT_TYPES(s):
127 | return {"required":{
128 | "video":("VIDEO",),
129 | }}
130 |
131 | CATEGORY = "AIFSH_MuseTalk"
132 |     DESCRIPTION = "Preview a video in the ComfyUI interface"
133 |
134 | RETURN_TYPES = ()
135 |
136 | OUTPUT_NODE = True
137 |
138 | FUNCTION = "load_video"
139 |
140 | def load_video(self, video):
141 | video_name = os.path.basename(video)
142 | video_path_name = os.path.basename(os.path.dirname(video))
143 | return {"ui":{"video":[video_name,video_path_name]}}
144 |
145 | class LoadVideo:
146 | @classmethod
147 | def INPUT_TYPES(s):
148 | files = [f for f in os.listdir(input_path) if os.path.isfile(os.path.join(input_path, f)) and f.split('.')[-1] in ["mp4", "webm","mkv","avi"]]
149 | return {"required":{
150 | "video":(files,),
151 | }}
152 |
153 | CATEGORY = "AIFSH_MuseTalk"
154 |     DESCRIPTION = "Load a video from the input directory and extract its audio track"
155 |
156 | RETURN_TYPES = ("VIDEO","AUDIO")
157 |
158 | OUTPUT_NODE = False
159 |
160 | FUNCTION = "load_video"
161 |
162 | def load_video(self, video):
163 | video_path = os.path.join(input_path,video)
164 | video_clip = VideoFileClip(video_path)
165 | audio_path = os.path.join(input_path,video+".wav")
166 | video_clip.audio.write_audiofile(audio_path)
167 | return (video_path,audio_path,)
--------------------------------------------------------------------------------
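The project's top-level `__init__.py` is not shown in this section; the sketch below illustrates the conventional ComfyUI registration for node classes like those above. The mappings and display names are assumptions, not the repository's actual `__init__.py`:

```python
# Hypothetical registration sketch (standard ComfyUI custom-node convention),
# not the repository's actual __init__.py. The keys must match the names the
# web extensions check for ("PreViewVideo", "LoadVideo").
from .nodes import MuseTalk, MuseTalkRealTime, CombineAudioVideo, PreViewVideo, LoadVideo

NODE_CLASS_MAPPINGS = {
    "MuseTalk": MuseTalk,
    "MuseTalkRealTime": MuseTalkRealTime,
    "CombineAudioVideo": CombineAudioVideo,
    "PreViewVideo": PreViewVideo,
    "LoadVideo": LoadVideo,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "MuseTalk": "MuseTalk",
    "MuseTalkRealTime": "MuseTalk (real-time)",
    "CombineAudioVideo": "Combine Audio & Video",
    "PreViewVideo": "Preview Video",
    "LoadVideo": "Load Video",
}

# Serve web/js/previewVideo.js and web/js/uploadVideo.js to the frontend.
WEB_DIRECTORY = "./web"
```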
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | ffmpeg-python
3 | soundfile
4 | diffusers
5 | pydub
6 | moviepy
7 | accelerate
--------------------------------------------------------------------------------
/web.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/web.png
--------------------------------------------------------------------------------
/web/js/previewVideo.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../../scripts/app.js";
2 | import { api } from '../../../scripts/api.js'
3 |
4 | function fitHeight(node) {
5 | node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]])
6 | node?.graph?.setDirtyCanvas(true);
7 | }
8 | function chainCallback(object, property, callback) {
9 | if (object == undefined) {
10 | //This should not happen.
11 |         console.error("Tried to add callback to non-existent object")
12 | return;
13 | }
14 | if (property in object) {
15 | const callback_orig = object[property]
16 | object[property] = function () {
17 | const r = callback_orig.apply(this, arguments);
18 | callback.apply(this, arguments);
19 | return r
20 | };
21 | } else {
22 | object[property] = callback;
23 | }
24 | }
25 |
26 | function addPreviewOptions(nodeType) {
27 | chainCallback(nodeType.prototype, "getExtraMenuOptions", function(_, options) {
28 | // The intended way of appending options is returning a list of extra options,
29 | // but this isn't used in widgetInputs.js and would require
30 | // less generalization of chainCallback
31 | let optNew = []
32 | try {
33 | const previewWidget = this.widgets.find((w) => w.name === "videopreview");
34 |
35 | let url = null
36 | if (previewWidget.videoEl?.hidden == false && previewWidget.videoEl.src) {
37 | //Use full quality video
38 | //url = api.apiURL('/view?' + new URLSearchParams(previewWidget.value.params));
39 | url = previewWidget.videoEl.src
40 | }
41 | if (url) {
42 | optNew.push(
43 | {
44 | content: "Open preview",
45 | callback: () => {
46 | window.open(url, "_blank")
47 | },
48 | },
49 | {
50 | content: "Save preview",
51 | callback: () => {
52 | const a = document.createElement("a");
53 | a.href = url;
54 | a.setAttribute("download", new URLSearchParams(previewWidget.value.params).get("filename"));
55 | document.body.append(a);
56 | a.click();
57 | requestAnimationFrame(() => a.remove());
58 | },
59 | }
60 | );
61 | }
62 | if(options.length > 0 && options[0] != null && optNew.length > 0) {
63 | optNew.push(null);
64 | }
65 | options.unshift(...optNew);
66 |
67 | } catch (error) {
68 | console.log(error);
69 | }
70 |
71 | });
72 | }
73 | function previewVideo(node,file,type){
74 | var element = document.createElement("div");
75 | const previewNode = node;
76 | var previewWidget = node.addDOMWidget("videopreview", "preview", element, {
77 | serialize: false,
78 | hideOnZoom: false,
79 | getValue() {
80 | return element.value;
81 | },
82 | setValue(v) {
83 | element.value = v;
84 | },
85 | });
86 | previewWidget.computeSize = function(width) {
87 | if (this.aspectRatio && !this.parentEl.hidden) {
88 | let height = (previewNode.size[0]-20)/ this.aspectRatio + 10;
89 | if (!(height > 0)) {
90 | height = 0;
91 | }
92 | this.computedHeight = height + 10;
93 | return [width, height];
94 | }
95 | return [width, -4];//no loaded src, widget should not display
96 | }
97 | // element.style['pointer-events'] = "none"
98 | previewWidget.value = {hidden: false, paused: false, params: {}}
99 | previewWidget.parentEl = document.createElement("div");
100 | previewWidget.parentEl.className = "video_preview";
101 | previewWidget.parentEl.style['width'] = "100%"
102 | element.appendChild(previewWidget.parentEl);
103 | previewWidget.videoEl = document.createElement("video");
104 | previewWidget.videoEl.controls = true;
105 | previewWidget.videoEl.loop = false;
106 | previewWidget.videoEl.muted = false;
107 | previewWidget.videoEl.style['width'] = "100%"
108 | previewWidget.videoEl.addEventListener("loadedmetadata", () => {
109 |
110 | previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight;
111 |         fitHeight(previewNode);  // "this" would be undefined here; use the captured node
112 | });
113 | previewWidget.videoEl.addEventListener("error", () => {
114 | //TODO: consider a way to properly notify the user why a preview isn't shown.
115 | previewWidget.parentEl.hidden = true;
116 |         fitHeight(previewNode);
117 | });
118 |
119 | let params = {
120 | "filename": file,
121 | "type": type,
122 | }
123 |
124 | previewWidget.parentEl.hidden = previewWidget.value.hidden;
125 | previewWidget.videoEl.autoplay = !previewWidget.value.paused && !previewWidget.value.hidden;
126 | let target_width = 256
127 | if (element.style?.width) {
128 | //overscale to allow scrolling. Endpoint won't return higher than native
129 | target_width = element.style.width.slice(0,-2)*2;
130 | }
131 | if (!params.force_size || params.force_size.includes("?") || params.force_size == "Disabled") {
132 | params.force_size = target_width+"x?"
133 | } else {
134 | let size = params.force_size.split("x")
135 | let ar = parseInt(size[0])/parseInt(size[1])
136 | params.force_size = target_width+"x"+(target_width/ar)
137 | }
138 |
139 | previewWidget.videoEl.src = api.apiURL('/view?' + new URLSearchParams(params));
140 |
141 | previewWidget.videoEl.hidden = false;
142 | previewWidget.parentEl.appendChild(previewWidget.videoEl)
143 | }
144 |
145 | app.registerExtension({
146 | name: "MuseTalk.VideoPreviewer",
147 | async beforeRegisterNodeDef(nodeType, nodeData, app) {
148 | if (nodeData?.name == "PreViewVideo") {
149 | nodeType.prototype.onExecuted = function (data) {
150 | previewVideo(this, data.video[0], data.video[1]);
151 | }
152 | addPreviewOptions(nodeType)
153 | }
154 | }
155 | });
156 |
--------------------------------------------------------------------------------
/web/js/uploadVideo.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../../scripts/app.js";
2 | import { api } from '../../../scripts/api.js'
3 | import { ComfyWidgets } from "../../../scripts/widgets.js"
4 |
5 | function fitHeight(node) {
6 | node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]])
7 | node?.graph?.setDirtyCanvas(true);
8 | }
9 |
10 | function previewVideo(node,file){
11 | while (node.widgets.length > 2){
12 | node.widgets.pop()
13 | }
14 | try {
15 | var el = document.getElementById("uploadVideo");
16 | el.remove();
17 | } catch (error) {
18 | console.log(error);
19 | }
20 | var element = document.createElement("div");
21 | element.id = "uploadVideo";
22 | const previewNode = node;
23 | var previewWidget = node.addDOMWidget("videopreview", "preview", element, {
24 | serialize: false,
25 | hideOnZoom: false,
26 | getValue() {
27 | return element.value;
28 | },
29 | setValue(v) {
30 | element.value = v;
31 | },
32 | });
33 | previewWidget.computeSize = function(width) {
34 | if (this.aspectRatio && !this.parentEl.hidden) {
35 | let height = (previewNode.size[0]-20)/ this.aspectRatio + 10;
36 | if (!(height > 0)) {
37 | height = 0;
38 | }
39 | this.computedHeight = height + 10;
40 | return [width, height];
41 | }
42 | return [width, -4];//no loaded src, widget should not display
43 | }
44 | // element.style['pointer-events'] = "none"
45 | previewWidget.value = {hidden: false, paused: false, params: {}}
46 | previewWidget.parentEl = document.createElement("div");
47 | previewWidget.parentEl.className = "video_preview";
48 | previewWidget.parentEl.style['width'] = "100%"
49 | element.appendChild(previewWidget.parentEl);
50 | previewWidget.videoEl = document.createElement("video");
51 | previewWidget.videoEl.controls = true;
52 | previewWidget.videoEl.loop = false;
53 | previewWidget.videoEl.muted = false;
54 | previewWidget.videoEl.style['width'] = "100%"
55 | previewWidget.videoEl.addEventListener("loadedmetadata", () => {
56 |
57 | previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight;
58 |         fitHeight(previewNode);  // "this" would be undefined here; use the captured node
59 | });
60 | previewWidget.videoEl.addEventListener("error", () => {
61 | //TODO: consider a way to properly notify the user why a preview isn't shown.
62 | previewWidget.parentEl.hidden = true;
63 |         fitHeight(previewNode);
64 | });
65 |
66 | let params = {
67 | "filename": file,
68 | "type": "input",
69 | }
70 |
71 | previewWidget.parentEl.hidden = previewWidget.value.hidden;
72 | previewWidget.videoEl.autoplay = !previewWidget.value.paused && !previewWidget.value.hidden;
73 | let target_width = 256
74 | if (element.style?.width) {
75 | //overscale to allow scrolling. Endpoint won't return higher than native
76 | target_width = element.style.width.slice(0,-2)*2;
77 | }
78 | if (!params.force_size || params.force_size.includes("?") || params.force_size == "Disabled") {
79 | params.force_size = target_width+"x?"
80 | } else {
81 | let size = params.force_size.split("x")
82 | let ar = parseInt(size[0])/parseInt(size[1])
83 | params.force_size = target_width+"x"+(target_width/ar)
84 | }
85 |
86 | previewWidget.videoEl.src = api.apiURL('/view?' + new URLSearchParams(params));
87 |
88 | previewWidget.videoEl.hidden = false;
89 | previewWidget.parentEl.appendChild(previewWidget.videoEl)
90 | }
91 |
92 | function videoUpload(node, inputName, inputData, app) {
93 | const videoWidget = node.widgets.find((w) => w.name === "video");
94 | let uploadWidget;
95 |     /*
96 |         Override the widget's "value" so it reads back as "subfolder/filename [type]".
97 |     */
98 | var default_value = videoWidget.value;
99 | Object.defineProperty(videoWidget, "value", {
100 | set : function(value) {
101 | this._real_value = value;
102 | },
103 |
104 | get : function() {
105 | let value = "";
106 | if (this._real_value) {
107 | value = this._real_value;
108 | } else {
109 | return default_value;
110 | }
111 |
112 | if (value.filename) {
113 | let real_value = value;
114 | value = "";
115 | if (real_value.subfolder) {
116 | value = real_value.subfolder + "/";
117 | }
118 |
119 | value += real_value.filename;
120 |
121 | if(real_value.type && real_value.type !== "input")
122 | value += ` [${real_value.type}]`;
123 | }
124 | return value;
125 | }
126 | });
127 | async function uploadFile(file, updateNode, pasted = false) {
128 | try {
129 | // Wrap file in formdata so it includes filename
130 | const body = new FormData();
131 | body.append("image", file);
132 | if (pasted) body.append("subfolder", "pasted");
133 | const resp = await api.fetchApi("/upload/image", {
134 | method: "POST",
135 | body,
136 | });
137 |
138 | if (resp.status === 200) {
139 | const data = await resp.json();
140 | // Add the file to the dropdown list and update the widget value
141 | let path = data.name;
142 | if (data.subfolder) path = data.subfolder + "/" + path;
143 |
144 | if (!videoWidget.options.values.includes(path)) {
145 | videoWidget.options.values.push(path);
146 | }
147 |
148 | if (updateNode) {
149 | videoWidget.value = path;
150 | previewVideo(node,path)
151 |
152 | }
153 | } else {
154 | alert(resp.status + " - " + resp.statusText);
155 | }
156 | } catch (error) {
157 | alert(error);
158 | }
159 | }
160 |
161 | const fileInput = document.createElement("input");
162 | Object.assign(fileInput, {
163 | type: "file",
164 |         accept: "video/webm,video/mp4,video/x-matroska,video/x-msvideo",
165 | style: "display: none",
166 | onchange: async () => {
167 | if (fileInput.files.length) {
168 | await uploadFile(fileInput.files[0], true);
169 | }
170 | },
171 | });
172 | document.body.append(fileInput);
173 |
174 | // Create the button widget for selecting the files
175 | uploadWidget = node.addWidget("button", "choose video file to upload", "Video", () => {
176 | fileInput.click();
177 | });
178 |
179 | uploadWidget.serialize = false;
180 |
181 | previewVideo(node, videoWidget.value);
182 | const cb = node.callback;
183 | videoWidget.callback = function () {
184 | previewVideo(node,videoWidget.value);
185 | if (cb) {
186 | return cb.apply(this, arguments);
187 | }
188 | };
189 |
190 | return { widget: uploadWidget };
191 | }
192 |
193 | ComfyWidgets.VIDEOPLOAD = videoUpload;
194 |
195 | app.registerExtension({
196 | name: "MuseTalk.UploadVideo",
197 | async beforeRegisterNodeDef(nodeType, nodeData, app) {
198 | if (nodeData?.name == "LoadVideo") {
199 | nodeData.input.required.upload = ["VIDEOPLOAD"];
200 | }
201 | },
202 | });
203 |
204 |
--------------------------------------------------------------------------------
/wechat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MuseTalk_FSH/e93586c997982e951d65cecc32d30ab60ac1cd9b/wechat.jpg
--------------------------------------------------------------------------------