32 | with self.doc.head:
33 | meta(http_equiv="refresh", content=str(refresh))
34 |
35 | def get_image_dir(self):
36 | """Return the directory that stores images"""
37 | return self.img_dir
38 |
39 | def add_header(self, text):
40 | """Insert a header to the HTML file
41 |
42 | Parameters:
43 | text (str) -- the header text
44 | """
45 | with self.doc:
46 | h3(text)
47 |
48 | def add_images(self, ims, txts, links, width=400):
49 | """add images to the HTML file
50 |
51 | Parameters:
52 | ims (str list) -- a list of image paths
53 | txts (str list) -- a list of image names shown on the website
54 | links (str list) -- a list of hyperlinks; clicking an image redirects you to a new page
55 | """
56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57 | self.doc.add(self.t)
58 | with self.t:
59 | with tr():
60 | for im, txt, link in zip(ims, txts, links):
61 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 | with p():
63 | with a(href=os.path.join('images', link)):
64 | img(style="width:%dpx" % width, src=os.path.join('images', im))
65 | br()
66 | p(txt)
67 |
68 | def save(self):
69 | """save the current content to the HMTL file"""
70 | html_file = '%s/index.html' % self.web_dir
71 | f = open(html_file, 'wt')
72 | f.write(self.doc.render())
73 | f.close()
74 |
75 |
76 | if __name__ == '__main__': # we show an example usage here.
77 | html = HTML('web/', 'test_html')
78 | html.add_header('hello world')
79 |
80 | ims, txts, links = [], [], []
81 | for n in range(4):
82 | ims.append('image_%d.png' % n)
83 | txts.append('text_%d' % n)
84 | links.append('image_%d.png' % n)
85 | html.add_images(ims, txts, links)
86 | html.save()
87 |
--------------------------------------------------------------------------------
/talkingface/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 |
5 | class ImagePool():
6 | """This class implements an image buffer that stores previously generated images.
7 |
8 | This buffer enables us to update discriminators using a history of generated images
9 | rather than the ones produced by the latest generators.
10 | """
11 |
12 | def __init__(self, pool_size):
13 | """Initialize the ImagePool class
14 |
15 | Parameters:
16 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
17 | """
18 | self.pool_size = pool_size
19 | if self.pool_size > 0: # create an empty pool
20 | self.num_imgs = 0
21 | self.images = []
22 |
23 | def query(self, images):
24 | """Return an image from the pool.
25 |
26 | Parameters:
27 | images: the latest generated images from the generator
28 |
29 | Returns images from the buffer.
30 |
31 | With probability 0.5, the buffer returns the input images.
32 | With probability 0.5, the buffer returns images previously stored in the buffer
33 | and inserts the current images into the buffer.
34 | """
35 | if self.pool_size == 0: # if the buffer size is 0, do nothing
36 | return images
37 | return_images = []
38 | for image in images:
39 | image = torch.unsqueeze(image.data, 0)
40 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
41 | self.num_imgs = self.num_imgs + 1
42 | self.images.append(image)
43 | return_images.append(image)
44 | else:
45 | p = random.uniform(0, 1)
46 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
47 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
48 | tmp = self.images[random_id].clone()
49 | self.images[random_id] = image
50 | return_images.append(tmp)
51 | else: # by another 50% chance, the buffer will return the current image
52 | return_images.append(image)
53 | return_images = torch.cat(return_images, 0) # collect all the images and return
54 | return return_images
55 |
--------------------------------------------------------------------------------
/talkingface/util/log_board.py:
--------------------------------------------------------------------------------
1 | def log(
2 | logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag=""
3 | ):
4 | if losses is not None:
5 | logger.add_scalar("Loss/d_loss", losses[0], step)
6 | logger.add_scalar("Loss/g_gan_loss", losses[1], step)
7 | logger.add_scalar("Loss/g_l1_loss", losses[2], step)
8 |
9 | if fig is not None:
10 | logger.add_image(tag, fig, 2, dataformats='HWC')
11 |
12 | if audio is not None:
13 | logger.add_audio(
14 | tag,
15 | audio / max(abs(audio)),
16 | sample_rate=sampling_rate,
17 | )
--------------------------------------------------------------------------------
/talkingface/util/smooth.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | def smooth_array(array, weight = [0.1,0.8,0.1]):
7 | '''
8 |
9 | Args:
10 | array: [n_frames, n_values]; will be reshaped to [n_values, 1, n_frames]
11 | weight: Conv1d.weight, the weights of the 1-D convolution kernel
12 | Returns:
13 | array: [n_frames, n_values], the smoothed array
14 | '''
15 | input = torch.Tensor(np.transpose(array[:,np.newaxis,:], (2, 1, 0)))
16 | smooth_length = len(weight)
17 | assert smooth_length%2 == 1, "the number of kernel weights must be odd"
18 | pad = (smooth_length//2, smooth_length//2) # with only two pad values, only the last dimension is padded: one column on the left and one on the right
19 | input = F.pad(input, pad, "replicate")
20 |
21 | with torch.no_grad():
22 | conv1 = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=smooth_length)
23 | # initialize the convolution kernel weights
24 | weight = torch.tensor(weight).view(1, 1, -1)
25 | conv1.weight = torch.nn.Parameter(weight)
26 | nn.init.constant_(conv1.bias, 0) # set the bias to 0
27 | # print(conv1.weight)
28 | out = conv1(input)
29 | return out.permute(2,1,0).squeeze().numpy()
30 |
31 | if __name__ == '__main__':
32 | model_id = "new_case"
33 | Path_output_pkl = "../preparation/{}/mouth_info.pkl".format(model_id + "/00001")
34 | import pickle
35 | with open(Path_output_pkl, "rb") as f:
36 | images_info = pickle.load(f)
37 | pts_array_normalized = np.array(images_info[2])
38 | pts_array_normalized = pts_array_normalized.reshape(-1, 16)
39 | smooth_array_ = smooth_array(pts_array_normalized)
40 | print(smooth_array_, smooth_array_.shape)
41 | smooth_array_ = smooth_array_.reshape(-1, 4, 4)
42 | import pandas as pd
43 |
44 | pd.DataFrame(smooth_array_[:, :, 0]).to_csv("mat2.csv")
--------------------------------------------------------------------------------
/talkingface/util/util.py:
--------------------------------------------------------------------------------
1 | """This module contains simple helper functions """
2 | from __future__ import print_function
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | import os
7 |
8 |
9 | def tensor2im(input_image, imtype=np.uint8):
10 | """"Converts a Tensor array into a numpy image array.
11 |
12 | Parameters:
13 | input_image (tensor) -- the input image tensor array
14 | imtype (type) -- the desired type of the converted numpy array
15 | """
16 | if not isinstance(input_image, np.ndarray):
17 | if isinstance(input_image, torch.Tensor): # get the data from a variable
18 | image_tensor = input_image.data
19 | else:
20 | return input_image
21 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
22 | if image_numpy.shape[0] == 1: # grayscale to RGB
23 | image_numpy = np.tile(image_numpy, (3, 1, 1))
24 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: transpose and scaling
25 | else: # if it is a numpy array, do nothing
26 | image_numpy = input_image
27 | return image_numpy.astype(imtype)
28 |
29 |
30 | def diagnose_network(net, name='network'):
31 | """Calculate and print the mean of average absolute(gradients)
32 |
33 | Parameters:
34 | net (torch network) -- Torch network
35 | name (str) -- the name of the network
36 | """
37 | mean = 0.0
38 | count = 0
39 | for param in net.parameters():
40 | if param.grad is not None:
41 | mean += torch.mean(torch.abs(param.grad.data))
42 | count += 1
43 | if count > 0:
44 | mean = mean / count
45 | print(name)
46 | print(mean)
47 |
48 |
49 | def save_image(image_numpy, image_path, aspect_ratio=1.0):
50 | """Save a numpy image to the disk
51 |
52 | Parameters:
53 | image_numpy (numpy array) -- input numpy array
54 | image_path (str) -- the path of the image
55 | """
56 |
57 | image_pil = Image.fromarray(image_numpy)
58 | h, w, _ = image_numpy.shape
59 |
60 | if aspect_ratio > 1.0:
61 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
62 | if aspect_ratio < 1.0:
63 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
64 | image_pil.save(image_path)
65 |
66 |
67 | def print_numpy(x, val=True, shp=False):
68 | """Print the mean, min, max, median, std, and size of a numpy array
69 |
70 | Parameters:
71 | val (bool) -- whether to print the values of the numpy array
72 | shp (bool) -- whether to print the shape of the numpy array
73 | """
74 | x = x.astype(np.float64)
75 | if shp:
76 | print('shape,', x.shape)
77 | if val:
78 | x = x.flatten()
79 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
80 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
81 |
82 |
83 | def mkdirs(paths):
84 | """create empty directories if they don't exist
85 |
86 | Parameters:
87 | paths (str list) -- a list of directory paths
88 | """
89 | if isinstance(paths, list) and not isinstance(paths, str):
90 | for path in paths:
91 | mkdir(path)
92 | else:
93 | mkdir(paths)
94 |
95 |
96 | def mkdir(path):
97 | """create a single empty directory if it didn't exist
98 |
99 | Parameters:
100 | path (str) -- a single directory path
101 | """
102 | if not os.path.exists(path):
103 | os.makedirs(path)
104 |
--------------------------------------------------------------------------------
/talkingface/util/utils.py:
--------------------------------------------------------------------------------
1 | from torch.optim import lr_scheduler
2 |
3 | import torch.nn as nn
4 | import torch
5 |
6 | ######################################################### training utils##########################################################
7 |
8 | def get_scheduler(optimizer, niter,niter_decay,lr_policy='lambda',lr_decay_iters=50):
9 | '''
10 | scheduler in training stage
11 | '''
12 | if lr_policy == 'lambda':
13 | def lambda_rule(epoch):
14 | lr_l = 1.0 - max(0, epoch - niter) / float(niter_decay + 1)
15 | return lr_l
16 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
17 | elif lr_policy == 'step':
18 | scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_decay_iters, gamma=0.1)
19 | elif lr_policy == 'plateau':
20 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
21 | elif lr_policy == 'cosine':
22 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=niter, eta_min=0)
23 | else:
24 | raise NotImplementedError('learning rate policy [%s] is not implemented' % lr_policy)
25 | return scheduler
26 |
27 | def update_learning_rate(scheduler, optimizer):
28 | scheduler.step()
29 | lr = optimizer.param_groups[0]['lr']
30 | print('learning rate = %.7f' % lr)
31 |
32 | class GANLoss(nn.Module):
33 | '''
34 | GAN loss
35 | '''
36 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
37 | super(GANLoss, self).__init__()
38 | self.register_buffer('real_label', torch.tensor(target_real_label))
39 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
40 | if use_lsgan:
41 | self.loss = nn.MSELoss()
42 | else:
43 | self.loss = nn.BCELoss()
44 |
45 | def get_target_tensor(self, input, target_is_real):
46 | if target_is_real:
47 | target_tensor = self.real_label
48 | else:
49 | target_tensor = self.fake_label
50 | return target_tensor.expand_as(input)
51 |
52 | def forward(self, input, target_is_real):
53 | target_tensor = self.get_target_tensor(input, target_is_real)
54 | return self.loss(input, target_tensor)
55 |
56 |
57 |
58 | import tqdm
59 | import numpy as np
60 | import cv2
61 | import glob
62 | import os
63 | import math
64 | import pickle
65 | import mediapipe as mp
66 | mp_face_mesh = mp.solutions.face_mesh
67 | landmark_points_68 = [162,234,93,58,172,136,149,148,152,377,378,365,397,288,323,454,389,
68 | 71,63,105,66,107,336,296,334,293,301,
69 | 168,197,5,4,75,97,2,326,305,
70 | 33,160,158,133,153,144,362,385,387,263,373,
71 | 380,61,39,37,0,267,269,291,405,314,17,84,181,78,82,13,312,308,317,14,87]
72 | def ExtractFaceFromFrameList(frames_list, vid_height, vid_width, out_size = 256):
73 | pts_3d = np.zeros([len(frames_list), 478, 3])
74 | with mp_face_mesh.FaceMesh(
75 | static_image_mode=True,
76 | max_num_faces=1,
77 | refine_landmarks=True,
78 | min_detection_confidence=0.5) as face_mesh:
79 |
80 | for index, frame in tqdm.tqdm(enumerate(frames_list)):
81 | results = face_mesh.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
82 | if not results.multi_face_landmarks:
83 | print("****** WARNING! No face detected! ******")
84 | pts_3d[index] = 0
85 | return
86 | # continue
87 | image_height, image_width = frame.shape[:2]
88 | for face_landmarks in results.multi_face_landmarks:
89 | for index_, i in enumerate(face_landmarks.landmark):
90 | x_px = min(math.floor(i.x * image_width), image_width - 1)
91 | y_px = min(math.floor(i.y * image_height), image_height - 1)
92 | z_px = min(math.floor(i.z * image_height), image_height - 1)
93 | pts_3d[index, index_] = np.array([x_px, y_px, z_px])
94 |
95 | # compute the face bounding range over the whole video
96 |
97 | x_min, y_min, x_max, y_max = np.min(pts_3d[:, :, 0]), np.min(
98 | pts_3d[:, :, 1]), np.max(
99 | pts_3d[:, :, 0]), np.max(pts_3d[:, :, 1])
100 | new_w = int((x_max - x_min) * 0.55)*2
101 | new_h = int((y_max - y_min) * 0.6)*2
102 | center_x = int((x_max + x_min) / 2.)
103 | center_y = int(y_min + (y_max - y_min) * 0.6)
104 | size = max(new_h, new_w)
105 | x_min, y_min, x_max, y_max = int(center_x - size // 2), int(center_y - size // 2), int(
106 | center_x + size // 2), int(center_y + size // 2)
107 |
108 | # determine the top and left coordinates of the crop region
109 | top = y_min
110 | left = x_min
111 | # the overlap between the crop region and the original frame
112 | top_coincidence = int(max(top, 0))
113 | bottom_coincidence = int(min(y_max, vid_height))
114 | left_coincidence = int(max(left, 0))
115 | right_coincidence = int(min(x_max, vid_width))
116 |
117 | scale = out_size / size
118 | pts_3d = (pts_3d - np.array([left, top, 0])) * scale
119 | pts_3d = pts_3d
120 |
121 | face_rect = np.array([center_x, center_y, size])
122 | print(np.array([x_min, y_min, x_max, y_max]))
123 |
124 | img_array = np.zeros([len(frames_list), out_size, out_size, 3], dtype = np.uint8)
125 | for index, frame in tqdm.tqdm(enumerate(frames_list)):
126 | img_new = np.zeros([size, size, 3], dtype=np.uint8)
127 | img_new[top_coincidence - top:bottom_coincidence - top, left_coincidence - left:right_coincidence - left,:] = \
128 | frame[top_coincidence:bottom_coincidence, left_coincidence:right_coincidence, :]
129 | img_new = cv2.resize(img_new, (out_size, out_size))
130 | img_array[index] = img_new
131 | return pts_3d,img_array, face_rect
132 |
133 |
--------------------------------------------------------------------------------
/video_data/000001/video.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/000001/video.mp4
--------------------------------------------------------------------------------
/video_data/000002/video.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/000002/video.mp4
--------------------------------------------------------------------------------
/video_data/audio0.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/audio0.wav
--------------------------------------------------------------------------------
/video_data/audio1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/audio1.wav
--------------------------------------------------------------------------------
/video_data/teeth_ref/221.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/221.png
--------------------------------------------------------------------------------
/video_data/teeth_ref/252.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/252.png
--------------------------------------------------------------------------------
/video_data/teeth_ref/328.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/328.png
--------------------------------------------------------------------------------
/video_data/teeth_ref/377.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/377.png
--------------------------------------------------------------------------------
/video_data/teeth_ref/398.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/398.png
--------------------------------------------------------------------------------
/video_data/teeth_ref/519.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/519.png
--------------------------------------------------------------------------------
/video_data/teeth_ref/558.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/558.png
--------------------------------------------------------------------------------
/video_data/teeth_ref/682.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/682.png
--------------------------------------------------------------------------------
/video_data/teeth_ref/743.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/743.png
--------------------------------------------------------------------------------
/video_data/teeth_ref/760.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/760.png
--------------------------------------------------------------------------------
/video_data/teeth_ref/794.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/video_data/teeth_ref/794.png
--------------------------------------------------------------------------------
/web_demo/Flowchart.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/Flowchart.jpg
--------------------------------------------------------------------------------
/web_demo/README.md:
--------------------------------------------------------------------------------
1 | # DH_Live_mini Deployment Guide
2 |
3 | > [!NOTE]
4 | > This project focuses on deploying a low-latency digital-human service on minimal hardware: no GPU, just an ordinary 2-core / 4 GB CPU machine.
5 |
6 | ## Service Component Layout
7 |
8 | | Component | Deployment location |
9 | |--------|------------|
10 | | VAD | In the browser (web) |
11 | | ASR | Local on the server |
12 | | LLM | Cloud service |
13 | | TTS | Local on the server |
14 | | Digital human | In the browser (web) |
15 |
16 | ![Flowchart](Flowchart.jpg)
17 |
18 | ## Directory Structure
19 |
20 | The project directory structure is as follows:
21 | ```bash
22 | project root/
23 | ├── models/ # local TTS and ASR models
24 | │ ├── sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/ # ASR
25 | │ ├── sherpa-onnx-vits-zh-ll/ # TTS
26 | ├── static/ # static resources
27 | │ ├── assets/ # character asset folder
28 | │ ├── assets2/ # asset folder for a second character
29 | │ ├── common/ # shared resources
30 | │ ├── css/ # CSS stylesheets
31 | │ ├── js/ # JavaScript scripts
32 | │ ├── DHLiveMini.wasm # AI inference component
33 | │ ├── dialog.html # dialog-only iframe page embedded by MiniLive.html
34 | │ ├── dialog_RealTime.html # dialog-only iframe page embedded by MiniLive_RealTime.html
35 | │ └── MiniLive.html # main digital-human video page (simple demo)
36 | │ └── MiniLive_RealTime.html # main digital-human video page (real-time voice chat, recommended!)
37 | ├── voiceapi/ # ASR / LLM / TTS settings
38 | └── server.py # Python program that starts the web service
39 | └── server_realtime.py # Python program that starts the real-time voice chat web service
40 | ```
41 | ### Running the Project
42 | (New!) Start the real-time voice chat service:
43 |
44 | (Note: you need to download the local ASR & TTS models and configure an OpenAI-style API for the LLM chat; see the configuration notes below.)
45 | ```bash
46 | # run from the DH_live root directory
47 | python web_demo/server_realtime.py
48 | ```
49 | Open a browser and visit http://localhost:8888/static/MiniLive_RealTime.html
50 |
51 |
52 | If you only need a simple demo service:
53 | ```bash
54 | # run from the DH_live root directory
55 | python web_demo/server.py
56 | ```
57 | Open a browser and visit http://localhost:8888/static/MiniLive.html
58 |
59 | ## Configuration
60 |
61 | ### 1. Replace the dialogue service URL
62 |
63 | For the full real-time voice chat demo, open static/js/dialog_realtime.js, find line 1, and replace http://localhost:8888/eb_stream with your own dialogue service URL, for example:
64 | https://your-dialogue-service.com/eb_stream. Also change the websocket url on line 2 to "wss://your-dialogue-service.com/asr?samplerate=16000"
65 |
66 | For the simple demo, open static/js/dialog.js, find line 1, and replace http://localhost:8888/eb_stream with your own dialogue service URL, for example:
67 | https://your-dialogue-service.com/eb_stream
68 |
69 | ### 2. Mock dialogue service
70 |
71 | server.py provides an example of a mock dialogue service. It accepts JSON input and streams back JSON responses. For example:
72 |
73 | Input JSON:
74 | ```json
75 | {
76 | "prompt": "the dialogue text entered by the user"
77 | }
78 | ```
79 | Output JSON (streamed):
80 | ```json
81 | {
82 | "text": "a partial piece of the reply text",
83 | "audio": "base64-encoded audio data",
84 | "endpoint": false // whether this is the last chunk of the reply; true means the reply has finished
85 | }
86 | ```
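A minimal client sketch (assuming the demo service from server.py is running on localhost:8888) that posts a text prompt to /eb_stream and reads the newline-delimited JSON chunks:
```python
import json
import requests

# post a text prompt to the mock dialogue service and stream the reply
resp = requests.post(
    "http://localhost:8888/eb_stream",
    json={"input_mode": "text", "prompt": "你好"},
    stream=True,
)
for line in resp.iter_lines():  # one JSON object per line
    if not line:
        continue
    chunk = json.loads(line)
    print(chunk["text"], "endpoint:", chunk["endpoint"])
    # chunk["audio"] holds the base64-encoded WAV audio for this sentence
```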
87 | ### 3. Full-pipeline real-time voice chat
88 | Download the models below (they can be replaced with other similar models):
89 |
90 | ASR model: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
91 |
92 | TTS model: https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/sherpa-onnx-vits-zh-ll.tar.bz2
93 |
94 | In voiceapi/llm.py, configure the LLM endpoint following the OpenAI API format:
95 |
96 | Doubao:
97 | ```python
98 | from openai import OpenAI
99 | base_url = "https://ark.cn-beijing.volces.com/api/v3"
100 | api_key = "*****************************"
101 | model_name = "doubao-pro-32k-character-241215"
102 |
103 | llm_client = OpenAI(
104 | base_url=base_url,
105 | api_key=api_key,
106 | )
107 | ```
108 |
109 | DeepSeek:
110 | ```python
111 | from openai import OpenAI
112 | base_url = "https://api.deepseek.com"
113 | api_key = ""
114 | model_name = "deepseek-chat"
115 |
116 | llm_client = OpenAI(
117 | base_url=base_url,
118 | api_key=api_key,
119 | )
120 | ```
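The configured client is consumed through llm_stream in voiceapi/llm.py. A minimal sketch of streaming a reply with it, mirroring how gen_stream in server_realtime.py reads the chunks:
```python
# sketch: stream a chat completion with the llm_client / model_name configured above
stream = llm_client.chat.completions.create(
    model=model_name,
    messages=[
        {"role": "system", "content": "你是人工智能助手"},
        {"role": "user", "content": "你好"},
    ],
    stream=True,  # return the reply as a stream of chunks
)
answer = ""
for chunk in stream:
    if not chunk.choices:
        continue
    answer += chunk.choices[0].delta.content or ""
print(answer)
```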
121 |
122 | ### 4. Changing the character
123 |
124 | To change the character, replace the corresponding files in the assets folder with the files from the new character package. Make sure the new files keep the same names and paths as the originals to avoid broken references.
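For example, a minimal sketch of swapping in a new character package (my_new_character/ is a placeholder for wherever your new asset package lives; the target file names match the ones shipped in static/assets/):
```python
# sketch: overwrite the current character assets with a new package
# ("my_new_character/" is a placeholder for your own asset package)
import shutil

shutil.copy("my_new_character/01.mp4", "web_demo/static/assets/01.mp4")
shutil.copy("my_new_character/combined_data.json.gz", "web_demo/static/assets/combined_data.json.gz")
```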
125 |
126 | ### 5. Notes on the WebCodecs API
127 |
128 | This project uses the WebCodecs API, which is only available in a secure context (HTTPS or localhost). When deploying or testing, make sure the page is served over HTTPS, or use localhost for local testing.
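One option, sketched below under the assumption that you already have a certificate (cert.pem / key.pem are placeholder paths), is to pass TLS files directly to uvicorn; putting the service behind an HTTPS reverse proxy works just as well:
```python
# sketch: replace the uvicorn.run(...) call at the bottom of web_demo/server_realtime.py
# so the page is served over HTTPS (cert.pem / key.pem are placeholder paths)
import uvicorn

uvicorn.run(
    app,
    host="0.0.0.0",
    port=8888,
    ssl_certfile="cert.pem",
    ssl_keyfile="key.pem",
)
```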
129 |
130 | ### 6. Thanks
131 | Special thanks to the following projects; this project makes heavy use of their code:
132 |
133 | - [Project AIRI](https://github.com/moeru-ai/airi)
134 | - [sherpa-onnx](https://github.com/k2-fsa/sherpa-onnx)
135 |
--------------------------------------------------------------------------------
/web_demo/server.py:
--------------------------------------------------------------------------------
1 | import json
2 | import requests
3 | import asyncio
4 | import re
5 | import base64
6 | from fastapi.responses import StreamingResponse
7 | from fastapi.staticfiles import StaticFiles
8 | from fastapi import FastAPI, Request, UploadFile, File,HTTPException
9 | app = FastAPI()
10 |
11 | # mount static files
12 | app.mount("/static", StaticFiles(directory="web_demo/static"), name="static")
13 |
14 | def get_audio(text_cache, voice_speed, voice_id):
15 | # read an audio file to simulate the result of speech synthesis
16 | with open("web_demo/static/common/test.wav", "rb") as audio_file:
17 | audio_value = audio_file.read()
18 | base64_string = base64.b64encode(audio_value).decode('utf-8')
19 | return base64_string
20 |
21 | def llm_answer(prompt):
22 | # simulate an LLM answer
23 | answer = "我会重复三遍来模仿大模型的回答,我会重复三遍来模仿大模型的回答,我会重复三遍来模仿大模型的回答。"
24 | return answer
25 |
26 | def split_sentence(sentence, min_length=10):
27 | # main punctuation marks, including parentheses
28 | punctuations = r'[。?!;…,、()()]'
29 | # split the sentence with a regex, keeping the punctuation marks
30 | parts = re.split(f'({punctuations})', sentence)
31 | parts = [p for p in parts if p] # remove empty strings
32 | sentences = []
33 | current = ''
34 | for part in parts:
35 | if current:
36 | # if the current fragment plus the new part reaches the minimum length, append it to the results
37 | if len(current) + len(part) >= min_length:
38 | sentences.append(current + part)
39 | current = ''
40 | else:
41 | current += part
42 | else:
43 | current = part
44 | # append the remaining fragment to the results
45 | if len(current) >= 2:
46 | sentences.append(current)
47 | return sentences
48 |
49 |
50 | import asyncio
51 | async def gen_stream(prompt, asr = False, voice_speed=None, voice_id=None):
52 | print("XXXXXXXXX", voice_speed, voice_id)
53 | if asr:
54 | chunk = {
55 | "prompt": prompt
56 | }
57 | yield f"{json.dumps(chunk)}\n" # JSON chunks are separated by newlines
58 |
59 | text_cache = llm_answer(prompt)
60 | sentences = split_sentence(text_cache)
61 |
62 | for index_, sub_text in enumerate(sentences):
63 | base64_string = get_audio(sub_text, voice_speed, voice_id)
64 | # build the JSON data chunk
65 | chunk = {
66 | "text": sub_text,
67 | "audio": base64_string,
68 | "endpoint": index_ == len(sentences)-1
69 | }
70 | yield f"{json.dumps(chunk)}\n" # JSON chunks are separated by newlines
71 | await asyncio.sleep(0.2) # simulate async latency
72 |
73 | # endpoint that handles ASR and TTS
74 | @app.post("/process_audio")
75 | async def process_audio(file: UploadFile = File(...)):
76 | # simulate calling an ASR API to obtain text
77 | text = "语音已收到,这里只是模仿,真正对话需要您自己设置ASR服务。"
78 | # call TTS to generate a streaming response
79 | return StreamingResponse(gen_stream(text, asr=True), media_type="application/json")
80 |
81 |
82 | async def call_asr_api(audio_data):
83 | # call ASR to perform speech recognition
84 | answer = "语音已收到,这里只是模仿,真正对话需要您自己设置ASR服务。"
85 | return answer
86 |
87 | @app.post("/eb_stream") # 前端调用的path
88 | async def eb_stream(request: Request):
89 | try:
90 | body = await request.json()
91 | input_mode = body.get("input_mode")
92 | voice_speed = body.get("voice_speed")
93 | voice_id = body.get("voice_id")
94 |
95 | if input_mode == "audio":
96 | base64_audio = body.get("audio")
97 | # decode the Base64 audio data
98 | audio_data = base64.b64decode(base64_audio)
99 | # audio-processing logic could be added here
100 | prompt = await call_asr_api(audio_data) # assuming call_asr_api can handle the audio data
101 | return StreamingResponse(gen_stream(prompt, asr=True, voice_speed=voice_speed, voice_id=voice_id), media_type="application/json")
102 | elif input_mode == "text":
103 | prompt = body.get("prompt")
104 | return StreamingResponse(gen_stream(prompt, asr=False, voice_speed=voice_speed, voice_id=voice_id), media_type="application/json")
105 | else:
106 | raise HTTPException(status_code=400, detail="Invalid input mode")
107 | except Exception as e:
108 | raise HTTPException(status_code=500, detail=str(e))
109 |
110 | # start the Uvicorn server
111 | if __name__ == "__main__":
112 | import uvicorn
113 | uvicorn.run(app, host="0.0.0.0", port=8888)
114 |
--------------------------------------------------------------------------------
/web_demo/server_realtime.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from contextlib import asynccontextmanager
4 | import re
5 | import asyncio
6 | import base64
7 | from fastapi.responses import StreamingResponse
8 | from fastapi.staticfiles import StaticFiles
9 | from fastapi import FastAPI, Request, UploadFile, File,HTTPException,WebSocketDisconnect,WebSocket
10 | from voiceapi.asr import start_asr_stream, ASRResult,ASREngineManager
11 | import uvicorn
12 | import argparse
13 | from voiceapi.llm import llm_stream
14 | from voiceapi.tts import get_audio,TTSEngineManager
15 |
16 | # 2. lifespan management
17 | @asynccontextmanager
18 | async def lifespan(app: FastAPI):
19 | # initialize the models at service startup (example parameters)
20 | print("ASR模型正在初始化,请稍等")
21 | ASREngineManager.initialize(samplerate=16000, args = args)
22 | print("TTS模型正在初始化,请稍等")
23 | TTSEngineManager.initialize(args = args)
24 | yield
25 | # clean up resources at service shutdown
26 | if ASREngineManager.get_engine():
27 | ASREngineManager.get_engine().cleanup()
28 |
29 |
30 | app = FastAPI(lifespan=lifespan)
31 |
32 | # mount static files
33 | app.mount("/static", StaticFiles(directory="web_demo/static"), name="static")
34 |
35 |
36 | def split_sentence(sentence, min_length=10):
37 | # main punctuation marks, including parentheses
38 | punctuations = r'[。?!;…,、()()]'
39 | # split the sentence with a regex, keeping the punctuation marks
40 | parts = re.split(f'({punctuations})', sentence)
41 | parts = [p for p in parts if p] # remove empty strings
42 | sentences = []
43 | current = ''
44 | for part in parts:
45 | if current:
46 | # if the current fragment plus the new part reaches the minimum length, append it to the results
47 | if len(current) + len(part) >= min_length:
48 | sentences.append(current + part)
49 | current = ''
50 | else:
51 | current += part
52 | else:
53 | current = part
54 | # append the remaining fragment to the results
55 | if len(current) >= 2:
56 | sentences.append(current)
57 | return sentences
58 |
59 | PUNCTUATION_SET = {
60 | ',', " ", '。', '!', '?', ';', ':', '、', '(', ')', '【', '】', '“', '”',
61 | ',', '.', '!', '?', ';', ':', '(', ')', '[', ']', '"', "'"
62 | }
63 |
64 | async def gen_stream(prompt, asr = False, voice_speed=None, voice_id=None):
65 | print("gen_stream", voice_speed, voice_id)
66 | if asr:
67 | chunk = {
68 | "prompt": prompt
69 | }
70 | yield f"{json.dumps(chunk)}\n" # JSON chunks are separated by newlines
71 |
72 | # Streaming:
73 | print("----- streaming request -----")
74 | stream = llm_stream(prompt)
75 | llm_answer_cache = ""
76 | for chunk in stream:
77 | if not chunk.choices:
78 | continue
79 | llm_answer_cache += chunk.choices[0].delta.content or "" # delta.content may be None on some chunks
80 |
81 | # find the position of the first punctuation mark
82 | punctuation_pos = -1
83 | for i, char in enumerate(llm_answer_cache[8:]):
84 | if char in PUNCTUATION_SET:
85 | punctuation_pos = i + 8
86 | break
87 | # if punctuation was found and the first clause has more than 8 characters
88 | if punctuation_pos != -1:
89 | # take the first clause
90 | first_sentence = llm_answer_cache[:punctuation_pos + 1]
91 | # the remaining text
92 | remaining_text = llm_answer_cache[punctuation_pos + 1:]
93 | print("get_audio: ", first_sentence)
94 | base64_string = await get_audio(first_sentence, voice_id=voice_id, voice_speed=voice_speed)
95 | chunk = {
96 | "text": first_sentence,
97 | "audio": base64_string,
98 | "endpoint": False
99 | }
100 |
101 | # update the cache to the remaining text
102 | llm_answer_cache = remaining_text
103 | yield f"{json.dumps(chunk)}\n" # JSON chunks are separated by newlines
104 | await asyncio.sleep(0.2) # simulate async latency
105 | print("get_audio: ", llm_answer_cache)
106 | if len(llm_answer_cache) >= 2:
107 | base64_string = await get_audio(llm_answer_cache, voice_id=voice_id, voice_speed=voice_speed)
108 | else:
109 | base64_string = ""
110 | chunk = {
111 | "text": llm_answer_cache,
112 | "audio": base64_string,
113 | "endpoint": True
114 | }
115 | yield f"{json.dumps(chunk)}\n" # 使用换行符分隔 JSON 块
116 |
117 | @app.websocket("/asr")
118 | async def websocket_asr(websocket: WebSocket, samplerate: int = 16000):
119 | await websocket.accept()
120 |
121 | asr_stream = await start_asr_stream(samplerate, args)
122 | if not asr_stream:
123 | print("failed to start ASR stream")
124 | await websocket.close()
125 | return
126 |
127 | async def task_recv_pcm():
128 | while True:
129 | try:
130 | data = await asyncio.wait_for(websocket.receive(), timeout=1.0)
131 | # print(f"message: {data}")
132 | except asyncio.TimeoutError:
133 | continue # no data arrived; keep looping
134 |
135 | if "text" in data.keys():
136 | print(f"Received text message: {data}")
137 | data = data["text"]
138 | if data.strip() == "vad":
139 | print("VAD signal received")
140 | await asr_stream.vad_touched()
141 | elif "bytes" in data.keys():
142 | pcm_bytes = data["bytes"]
143 | print("XXXX pcm_bytes", len(pcm_bytes))
144 | if not pcm_bytes:
145 | return
146 | await asr_stream.write(pcm_bytes)
147 |
148 |
149 | async def task_send_result():
150 | while True:
151 | result: ASRResult = await asr_stream.read()
152 | if not result:
153 | return
154 | await websocket.send_json(result.to_dict())
155 | try:
156 | await asyncio.gather(task_recv_pcm(), task_send_result())
157 | except WebSocketDisconnect:
158 | print("asr: disconnected")
159 | finally:
160 | await asr_stream.close()
161 |
162 | @app.post("/eb_stream") # path called by the front end
163 | async def eb_stream(request: Request):
164 | try:
165 | body = await request.json()
166 | input_mode = body.get("input_mode")
167 | voice_speed = body.get("voice_speed", 1.0)
168 | voice_id = body.get("voice_id", 0)
169 |
170 | if voice_speed == "":
171 | voice_speed = 1.0
172 | if voice_id == "":
173 | voice_id = 0
174 |
175 | if input_mode == "text":
176 | prompt = body.get("prompt")
177 | return StreamingResponse(gen_stream(prompt, asr=False, voice_speed=voice_speed, voice_id=voice_id), media_type="application/json")
178 | else:
179 | raise HTTPException(status_code=400, detail="Invalid input mode")
180 | except Exception as e:
181 | raise HTTPException(status_code=500, detail=str(e))
182 |
183 | # start the Uvicorn server
184 | if __name__ == "__main__":
185 | models_root = './models'
186 |
187 | for d in ['.', '..', 'web_demo']:
188 | if os.path.isdir(f'{d}/models'):
189 | models_root = f'{d}/models'
190 | break
191 |
192 | parser = argparse.ArgumentParser()
193 | parser.add_argument("--port", type=int, default=8888, help="port number")
194 | parser.add_argument("--addr", type=str,
195 | default="0.0.0.0", help="serve address")
196 |
197 | parser.add_argument("--asr-provider", type=str,
198 | default="cpu", help="asr provider, cpu or cuda")
199 | parser.add_argument("--tts-provider", type=str,
200 | default="cpu", help="tts provider, cpu or cuda")
201 |
202 | parser.add_argument("--threads", type=int, default=2,
203 | help="number of threads")
204 |
205 | parser.add_argument("--models-root", type=str, default=models_root,
206 | help="model root directory")
207 |
208 | parser.add_argument("--asr-model", type=str, default='zipformer-bilingual',
209 | help="ASR model name: zipformer-bilingual, sensevoice, paraformer-trilingual, paraformer-en, whisper-medium")
210 |
211 | parser.add_argument("--asr-lang", type=str, default='zh',
212 | help="ASR language, zh, en, ja, ko, yue")
213 |
214 | parser.add_argument("--tts-model", type=str, default='sherpa-onnx-vits-zh-ll',
215 | help="TTS model name: vits-zh-hf-theresa, vits-melo-tts-zh_en")
216 |
217 | args = parser.parse_args()
218 |
219 | if args.tts_model == 'vits-melo-tts-zh_en' and args.tts_provider == 'cuda':
220 | print(
221 | "vits-melo-tts-zh_en does not support CUDA; falling back to CPU")
222 | args.tts_provider = 'cpu'
223 |
224 | uvicorn.run(app, host=args.addr, port=args.port)
225 |
--------------------------------------------------------------------------------
/web_demo/static/DHLiveMini.wasm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/DHLiveMini.wasm
--------------------------------------------------------------------------------
/web_demo/static/MiniLive.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | MiniLive
9 |
81 |
82 |
83 |
84 |
85 | MiniMates: loading...
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 | 加载中
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
--------------------------------------------------------------------------------
/web_demo/static/MiniLive_RealTime.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | MiniLive
9 |
129 |
130 |
131 |
132 |
136 |
137 |
138 |
145 |
146 |
147 |
148 | MiniMates: loading...
149 |
150 |
151 |
152 |
153 |
154 |
155 | 加载中
156 |
157 |
158 |
159 |
160 |
161 |
162 |
--------------------------------------------------------------------------------
/web_demo/static/MiniLive_new.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | MiniLive
8 |
128 |
129 |
130 |
131 |
137 |
138 |
139 |
145 |
146 |
147 |
148 | MiniMates: loading...
149 |
150 |
151 |
152 |
153 |
154 |
155 | 加载中
156 |
157 | -->
158 |
159 |
160 |
161 |
162 |
163 |
164 |
--------------------------------------------------------------------------------
/web_demo/static/assets/01.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/assets/01.mp4
--------------------------------------------------------------------------------
/web_demo/static/assets/combined_data.json.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/assets/combined_data.json.gz
--------------------------------------------------------------------------------
/web_demo/static/assets2/01.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/assets2/01.mp4
--------------------------------------------------------------------------------
/web_demo/static/assets2/combined_data.json.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/assets2/combined_data.json.gz
--------------------------------------------------------------------------------
/web_demo/static/common/bs_texture_halfFace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/common/bs_texture_halfFace.png
--------------------------------------------------------------------------------
/web_demo/static/common/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/common/favicon.ico
--------------------------------------------------------------------------------
/web_demo/static/common/test.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/common/test.wav
--------------------------------------------------------------------------------
/web_demo/static/css/material-icons.css:
--------------------------------------------------------------------------------
1 | /* font definition */
2 | @font-face {
3 | font-family: 'Material Icons';
4 | font-style: normal;
5 | font-weight: 400;
6 | src: url('../fonts/flUhRq6tzZclQEJ-Vdg-IuiaDsNcIhQ8tQ.woff2') format('woff2');
7 | }
8 |
9 | /* icon styles */
10 | .material-icons {
11 | font-family: 'Material Icons';
12 | font-weight: normal;
13 | font-style: normal;
14 | font-size: 24px;
15 | line-height: 1;
16 | letter-spacing: normal;
17 | text-transform: none;
18 | display: inline-block;
19 | white-space: nowrap;
20 | word-wrap: normal;
21 | direction: ltr;
22 | font-feature-settings: 'liga';
23 | -webkit-font-smoothing: antialiased;
24 | }
--------------------------------------------------------------------------------
/web_demo/static/dialog.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | AI聊天
7 |
8 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
149 |
150 |
151 |
--------------------------------------------------------------------------------
/web_demo/static/dialog_RealTime.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | AI聊天
7 |
8 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
153 |
154 |
155 |
156 |
--------------------------------------------------------------------------------
/web_demo/static/fonts/flUhRq6tzZclQEJ-Vdg-IuiaDsNcIhQ8tQ.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kleinlee/DH_live/01a3648d467af58290f07449c2e51bbfb69c10d5/web_demo/static/fonts/flUhRq6tzZclQEJ-Vdg-IuiaDsNcIhQ8tQ.woff2
--------------------------------------------------------------------------------
/web_demo/static/js/MiniMateLoader.js:
--------------------------------------------------------------------------------
1 | document.addEventListener('DOMContentLoaded', function () {
2 | init();
3 | });
4 |
5 | async function init()
6 | {
7 | const spinner = document.querySelector('#loadingSpinner');
8 | const screen = document.querySelector('#screen');
9 | const showUi = () => {
10 | spinner.style.display = 'none';
11 | screen.style.display = 'block';
12 | }
13 | const instance = await qtLoad({
14 | qt: {
15 | onLoaded: () => showUi(),
16 | entryFunction: window.createQtAppInstance,
17 | containerElements: [screen],
18 | }
19 | });
20 | await newVideoTask();
21 | document.getElementById('screen2').style.display = 'block';
22 | }
23 |
24 |
25 | async function qtLoad(config)
26 | {
27 | const throwIfEnvUsedButNotExported = (instance, config) =>
28 | {
29 | const environment = config.environment;
30 | if (!environment || Object.keys(environment).length === 0)
31 | return;
32 | const isEnvExported = typeof instance.ENV === 'object';
33 | if (!isEnvExported)
34 | throw new Error('ENV must be exported if environment variables are passed');
35 | };
36 |
37 | const throwIfFsUsedButNotExported = (instance, config) =>
38 | {
39 | const environment = config.environment;
40 | if (!environment || Object.keys(environment).length === 0)
41 | return;
42 | const isFsExported = typeof instance.FS === 'object';
43 | if (!isFsExported)
44 | throw new Error('FS must be exported if preload is used');
45 | };
46 |
47 | if (typeof config !== 'object')
48 | throw new Error('config is required, expected an object');
49 | if (typeof config.qt !== 'object')
50 | throw new Error('config.qt is required, expected an object');
51 | if (typeof config.qt.entryFunction !== 'function')
52 | config.qt.entryFunction = window.createQtAppInstance;
53 |
54 | config.qt.qtdir ??= 'qt';
55 | config.qt.preload ??= [];
56 |
57 | config.qtContainerElements = config.qt.containerElements;
58 | delete config.qt.containerElements;
59 | config.qtFontDpi = config.qt.fontDpi;
60 | delete config.qt.fontDpi;
61 |
62 | // Used for rejecting a failed load's promise where emscripten itself does not allow it,
63 | // like in instantiateWasm below. This allows us to throw in case of a load error instead of
64 | // hanging on a promise to entry function, which emscripten unfortunately does.
65 | let circuitBreakerReject;
66 | const circuitBreaker = new Promise((_, reject) => { circuitBreakerReject = reject; });
67 |
68 | // If module async getter is present, use it so that module reuse is possible.
69 | if (config.qt.module) {
70 | config.instantiateWasm = async (imports, successCallback) =>
71 | {
72 | try {
73 | const module = await config.qt.module;
74 | successCallback(
75 | await WebAssembly.instantiate(module, imports), module);
76 | } catch (e) {
77 | circuitBreakerReject(e);
78 | }
79 | }
80 | }
81 |
82 | const qtPreRun = (instance) => {
83 | // Copy qt.environment to instance.ENV
84 | throwIfEnvUsedButNotExported(instance, config);
85 | for (const [name, value] of Object.entries(config.qt.environment ?? {}))
86 | instance.ENV[name] = value;
87 |
88 | // Copy self.preloadData to MEMFS
89 | const makeDirs = (FS, filePath) => {
90 | const parts = filePath.split("/");
91 | let path = "/";
92 | for (let i = 0; i < parts.length - 1; ++i) {
93 | const part = parts[i];
94 | if (part == "")
95 | continue;
96 | path += part + "/";
97 | try {
98 | FS.mkdir(path);
99 | } catch (error) {
100 | const EEXIST = 20;
101 | if (error.errno != EEXIST)
102 | throw error;
103 | }
104 | }
105 | }
106 | throwIfFsUsedButNotExported(instance, config);
107 | for ({destination, data} of self.preloadData) {
108 | makeDirs(instance.FS, destination);
109 | instance.FS.writeFile(destination, new Uint8Array(data));
110 | }
111 | }
112 |
113 | if (!config.preRun)
114 | config.preRun = [];
115 | config.preRun.push(qtPreRun);
116 |
117 | config.onRuntimeInitialized = () => config.qt.onLoaded?.();
118 |
119 | const originalLocateFile = config.locateFile;
120 | config.locateFile = filename =>
121 | {
122 | const originalLocatedFilename = originalLocateFile ? originalLocateFile(filename) : filename;
123 | if (originalLocatedFilename.startsWith('libQt6'))
124 | return `${config.qt.qtdir}/lib/${originalLocatedFilename}`;
125 | return originalLocatedFilename;
126 | }
127 |
128 | const originalOnExit = config.onExit;
129 | config.onExit = code => {
130 | originalOnExit?.();
131 | config.qt.onExit?.({
132 | code,
133 | crashed: false
134 | });
135 | }
136 |
137 | const originalOnAbort = config.onAbort;
138 | config.onAbort = text =>
139 | {
140 | originalOnAbort?.();
141 |
142 | aborted = true;
143 | config.qt.onExit?.({
144 | text,
145 | crashed: true
146 | });
147 | };
148 |
149 | const fetchPreloadFiles = async () => {
150 | const fetchJson = async path => (await fetch(path)).json();
151 | const fetchArrayBuffer = async path => (await fetch(path)).arrayBuffer();
152 | const loadFiles = async (paths) => {
153 | const source = paths['source'].replace('$QTDIR', config.qt.qtdir);
154 | return {
155 | destination: paths['destination'],
156 | data: await fetchArrayBuffer(source)
157 | };
158 | }
159 | const fileList = (await Promise.all(config.qt.preload.map(fetchJson))).flat();
160 | self.preloadData = (await Promise.all(fileList.map(loadFiles))).flat();
161 | }
162 |
163 | await fetchPreloadFiles();
164 |
165 | // Call app/emscripten module entry function. It may either come from the emscripten
166 | // runtime script or be customized as needed.
167 | let instance;
168 | try {
169 | instance = await Promise.race(
170 | [circuitBreaker, config.qt.entryFunction(config)]);
171 | } catch (e) {
172 | config.qt.onExit?.({
173 | text: e.message,
174 | crashed: true
175 | });
176 | throw e;
177 | }
178 |
179 | return instance;
180 | }
181 |
182 | // Compatibility API. This API is deprecated,
183 | // and will be removed in a future version of Qt.
184 | function QtLoader(qtConfig) {
185 |
186 | const warning = 'Warning: The QtLoader API is deprecated and will be removed in ' +
187 | 'a future version of Qt. Please port to the new qtLoad() API.';
188 | console.warn(warning);
189 |
190 | let emscriptenConfig = qtConfig.moduleConfig || {}
191 | qtConfig.moduleConfig = undefined;
192 | const showLoader = qtConfig.showLoader;
193 | qtConfig.showLoader = undefined;
194 | const showError = qtConfig.showError;
195 | qtConfig.showError = undefined;
196 | const showExit = qtConfig.showExit;
197 | qtConfig.showExit = undefined;
198 | const showCanvas = qtConfig.showCanvas;
199 | qtConfig.showCanvas = undefined;
200 | if (qtConfig.canvasElements) {
201 | qtConfig.containerElements = qtConfig.canvasElements
202 | qtConfig.canvasElements = undefined;
203 | } else {
204 | qtConfig.containerElements = qtConfig.containerElements;
205 | qtConfig.containerElements = undefined;
206 | }
207 | emscriptenConfig.qt = qtConfig;
208 |
209 | let qtloader = {
210 | exitCode: undefined,
211 | exitText: "",
212 | loadEmscriptenModule: _name => {
213 | try {
214 | qtLoad(emscriptenConfig);
215 | } catch (e) {
216 | showError?.(e.message);
217 | }
218 | }
219 | }
220 |
221 | qtConfig.onLoaded = () => {
222 | showCanvas?.();
223 | }
224 |
225 | qtConfig.onExit = exit => {
226 | qtloader.exitCode = exit.code
227 | qtloader.exitText = exit.text;
228 | showExit?.();
229 | }
230 |
231 | showLoader?.("Loading");
232 |
233 | return qtloader;
234 | };
235 |
--------------------------------------------------------------------------------
/web_demo/static/js/audio_recorder.js:
--------------------------------------------------------------------------------
1 | // embed the AudioWorklet processing logic into the main file as a string
2 | const workletCode = `
3 | class PCMProcessor extends AudioWorkletProcessor {
4 | constructor() {
5 | super();
6 | this.port.onmessage = (event) => {
7 | if (event.data === 'stop') {
8 | this.port.postMessage('prepare to stop');
9 | this.isStopped = true;
10 | if (this.buffer.length > 0 && this.buffer.length > this.targetSampleCount) {
11 | this.port.postMessage(new Int16Array(this.buffer));
12 | this.port.postMessage({'event':'stopped'});
13 | this.buffer = [];
14 | }
15 | }
16 | };
17 | this.buffer = [];
18 | this.targetSampleCount = 1024;
19 | }
20 |
21 | process(inputs) {
22 | const input = inputs[0];
23 | if (input.length > 0) {
24 | const inputData = input[0];
25 | // optimized sample conversion
26 | const samples = inputData.map(sample =>
27 | Math.max(-32768, Math.min(32767, Math.round(sample * 32767)))
28 | );
29 | this.buffer.push(...samples);
30 |
31 | while (this.buffer.length >= this.targetSampleCount) {
32 | const pcmData = this.buffer.splice(0, this.targetSampleCount);
33 | this.port.postMessage(new Int16Array(pcmData));
34 | this.port.postMessage({'event':'sending'});
35 | }
36 | }
37 | return true;
38 | }
39 | }
40 |
41 | registerProcessor('pcm-processor', PCMProcessor);
42 | `;
43 | class PCMAudioRecorder {
44 | constructor() {
45 | this.audioContext = null;
46 | this.stream = null;
47 | this.currentSource = null;
48 | this.audioCallback = null;
49 | }
50 |
51 | async connect(audioCallback) {
52 | this.audioCallback = audioCallback;
53 | if (!this.audioContext) {
54 | this.audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16000 });
55 | }
56 | console.log('Current sample rate:', this.audioContext.sampleRate, 'Hz');
57 |
58 | // create the worklet module dynamically
59 | const blob = new Blob([workletCode], { type: 'application/javascript' });
60 | const url = URL.createObjectURL(blob);
61 |
62 | try {
63 | await this.audioContext.audioWorklet.addModule(url);
64 | URL.revokeObjectURL(url); // release the memory
65 | } catch (e) {
66 | console.error('Error loading AudioWorklet:', e);
67 | return;
68 | }
69 |
70 | this.stream = await navigator.mediaDevices.getUserMedia({ audio: true });
71 | this.currentSource = this.audioContext.createMediaStreamSource(this.stream);
72 |
73 | this.processorNode = new AudioWorkletNode(this.audioContext, 'pcm-processor');
74 |
75 | this.processorNode.port.onmessage = (event) => {
76 | if (event.data instanceof Int16Array) {
77 | this.audioCallback?.(event.data);
78 | } else if (event.data?.event === 'stopped') {
79 | console.log('Recorder stopped.');
80 | }
81 | };
82 |
83 | this.currentSource.connect(this.processorNode);
84 | this.processorNode.connect(this.audioContext.destination);
85 | }
86 |
87 | stop() {
88 | if (this.processorNode) {
89 | this.processorNode.port.postMessage('stop');
90 | this.processorNode.disconnect();
91 | this.processorNode = null;
92 | }
93 |
94 | this.stream?.getTracks().forEach(track => track.stop());
95 | this.currentSource?.disconnect();
96 |
97 | if (this.audioContext) {
98 | this.audioContext.close().then(() => {
99 | this.audioContext = null;
100 | });
101 | }
102 | }
103 | }
104 |
105 | // expose to the global scope
106 | window.PCMAudioRecorder = PCMAudioRecorder;
--------------------------------------------------------------------------------
/web_demo/voiceapi/llm.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 | # Doubao
3 | base_url = "https://ark.cn-beijing.volces.com/api/v3"
4 | api_key = ""
5 | model_name = "doubao-pro-32k-character-241215"
6 |
7 | # # DeepSeek
8 | # base_url = "https://api.deepseek.com"
9 | # api_key = ""
10 | # model_name = "deepseek-chat"
11 |
12 | assert api_key, "You must configure your own LLM API key"
13 |
14 | llm_client = OpenAI(
15 | base_url=base_url,
16 | api_key=api_key,
17 | )
18 |
19 |
20 | def llm_stream(prompt):
21 | stream = llm_client.chat.completions.create(
22 | # specify the Ark inference endpoint ID that you created; it has already been set to your endpoint ID here
23 | model=model_name,
24 | messages=[
25 | {"role": "system", "content": "你是人工智能助手"},
26 | {"role": "user", "content": prompt},
27 | ],
28 | # whether the response is returned as a stream
29 | stream=True,
30 | )
31 | return stream
32 |
--------------------------------------------------------------------------------
/web_demo/voiceapi/tts.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | import os
3 | import time
4 | import sherpa_onnx
5 | import logging
6 | import numpy as np
7 | import asyncio
8 | import time
9 | import soundfile
10 | from scipy.signal import resample
11 | import io
12 | import re
13 | import threading
14 | import base64
15 | logger = logging.getLogger(__file__)
16 |
17 | splitter = re.compile(r'[,,。.!?!?;;、\n]')
18 | _tts_engines = {}
19 |
20 | tts_configs = {
21 | 'sherpa-onnx-vits-zh-ll': {
22 | 'model': 'model.onnx',
23 | 'lexicon': 'lexicon.txt',
24 | 'dict_dir': 'dict',
25 | 'tokens': 'tokens.txt',
26 | 'sample_rate': 16000,
27 | # 'rule_fsts': ['phone.fst', 'date.fst', 'number.fst'],
28 | },
29 | 'vits-zh-hf-theresa': {
30 | 'model': 'theresa.onnx',
31 | 'lexicon': 'lexicon.txt',
32 | 'dict_dir': 'dict',
33 | 'tokens': 'tokens.txt',
34 | 'sample_rate': 22050,
35 | # 'rule_fsts': ['phone.fst', 'date.fst', 'number.fst'],
36 | },
37 | 'vits-melo-tts-zh_en': {
38 | 'model': 'model.onnx',
39 | 'lexicon': 'lexicon.txt',
40 | 'dict_dir': 'dict',
41 | 'tokens': 'tokens.txt',
42 | 'sample_rate': 44100,
43 | 'rule_fsts': ['phone.fst', 'date.fst', 'number.fst'],
44 | },
45 | }
46 |
47 |
48 | def load_tts_model(name: str, model_root: str, provider: str, num_threads: int = 1, max_num_sentences: int = 20) -> sherpa_onnx.OfflineTtsConfig:
49 | cfg = tts_configs[name]
50 | fsts = []
51 | model_dir = os.path.join(model_root, name)
52 | for f in cfg.get('rule_fsts', ''):
53 | fsts.append(os.path.join(model_dir, f))
54 | tts_rule_fsts = ','.join(fsts) if fsts else ''
55 |
56 | model_config = sherpa_onnx.OfflineTtsModelConfig(
57 | vits=sherpa_onnx.OfflineTtsVitsModelConfig(
58 | model=os.path.join(model_dir, cfg['model']),
59 | lexicon=os.path.join(model_dir, cfg['lexicon']),
60 | dict_dir=os.path.join(model_dir, cfg['dict_dir']),
61 | tokens=os.path.join(model_dir, cfg['tokens']),
62 | ),
63 | provider=provider,
64 | debug=0,
65 | num_threads=num_threads,
66 | )
67 | tts_config = sherpa_onnx.OfflineTtsConfig(
68 | model=model_config,
69 | rule_fsts=tts_rule_fsts,
70 | max_num_sentences=max_num_sentences)
71 |
72 | if not tts_config.validate():
73 | raise ValueError("tts: invalid config")
74 |
75 | return tts_config
76 |
77 |
78 | def get_tts_engine(args) -> Tuple[sherpa_onnx.OfflineTts, int]:
79 | sample_rate = tts_configs[args.tts_model]['sample_rate']
80 | cache_engine = _tts_engines.get(args.tts_model)
81 | if cache_engine:
82 | return cache_engine, sample_rate
83 | st = time.time()
84 | tts_config = load_tts_model(
85 | args.tts_model, args.models_root, args.tts_provider)
86 |
87 | cache_engine = sherpa_onnx.OfflineTts(tts_config)
88 | elapsed = time.time() - st
89 | logger.info(f"tts: loaded {args.tts_model} in {elapsed:.2f}s")
90 | _tts_engines[args.tts_model] = cache_engine
91 |
92 | return cache_engine, sample_rate
93 |
94 | # 1. global model manager class
95 | class TTSEngineManager:
96 | _instance = None
97 | _lock = threading.Lock()
98 |
99 | def __new__(cls):
100 | with cls._lock:
101 | if not cls._instance:
102 | cls._instance = super().__new__(cls)
103 | cls._instance.engine = None
104 | return cls._instance
105 |
106 | @classmethod
107 | def initialize(cls, args):
108 | instance = cls()
109 | if instance.engine is None: # safe attribute access
110 | instance.engine, instance.original_sample_rate = get_tts_engine(args)
111 |
112 | @classmethod
113 | def get_engine(cls):
114 | instance = cls() # make sure the instance exists
115 | return instance.engine,instance.original_sample_rate # safe attribute access
116 |
117 |
118 | async def get_audio(text, voice_speed=1.0, voice_id=0, target_sample_rate = 16000):
119 | print("run_tts", text, voice_speed, voice_id)
120 | # get the globally shared TTS engine
121 | tts_engine,original_sample_rate = TTSEngineManager.get_engine()
122 |
123 | # run the synchronous TTS call in a thread pool
124 | loop = asyncio.get_event_loop()
125 | audio = await loop.run_in_executor(
126 | None,
127 | lambda: tts_engine.generate(text, voice_id, voice_speed)
128 | )
129 | # audio = tts_engine.generate(text, voice_id, voice_speed)
130 | samples = audio.samples
131 | if target_sample_rate != original_sample_rate:
132 | num_samples = int(
133 | len(samples) * target_sample_rate / original_sample_rate)
134 | resampled_chunk = resample(samples, num_samples)
135 | audio.samples = resampled_chunk.astype(np.float32)
136 | audio.sample_rate = target_sample_rate
137 |
138 | output = io.BytesIO()
139 | # write WAV data with soundfile (the header is generated automatically)
140 | soundfile.write(
141 | output,
142 | audio.samples, # audio data (numpy array)
143 | samplerate=audio.sample_rate, # sample rate (e.g. 16000)
144 | subtype="PCM_16", # 16-bit PCM encoding
145 | format="WAV" # WAV container format
146 | )
147 |
148 | # get the bytes and Base64-encode them
149 | wav_data = output.getvalue()
150 | return base64.b64encode(wav_data).decode("utf-8")
151 |
152 | # import wave
153 | # import uuid
154 | # with wave.open('{}.wav'.format(uuid.uuid4()), 'w') as f:
155 | # f.setnchannels(1)
156 | # f.setsampwidth(2)
157 | # f.setframerate(16000)
158 | # f.writeframes(samples)
159 | # return base64.b64encode(samples).decode('utf-8')
160 |
161 |
--------------------------------------------------------------------------------