├── raw └── wav_file_here ├── wavs └── slicer_files_here ├── whisper_model └── whisper_model_here ├── 查看cuda版本.bat ├── models_from_modelscope └── place_model_files_here ├── GPU诊断.bat ├── 1_Dataset.bat ├── common ├── __pycache__ │ ├── log.cpython-39.pyc │ ├── constants.cpython-39.pyc │ ├── stdout_wrapper.cpython-39.pyc │ └── subprocess_utils.cpython-39.pyc ├── log.py ├── constants.py ├── stdout_wrapper.py ├── subprocess_utils.py └── tts_model.py ├── bcut_asr ├── __pycache__ │ ├── orm.cpython-310.pyc │ ├── orm.cpython-39.pyc │ ├── __init__.cpython-310.pyc │ └── __init__.cpython-39.pyc ├── orm.py ├── __main__.py └── __init__.py ├── becut_test.py ├── README.md ├── gpu_diagnostics.py ├── esd_whisper_large_v2.list ├── audio_slicer_pre.py ├── argparse_tools.py ├── trans_utils.py ├── short_audio_transcribe_bcut.py ├── requirements.txt ├── short_audio_transcribe_fwhisper.py ├── short_audio_transcribe_whisper.py ├── subtitle_utils.py ├── short_audio_transcribe_ali.py ├── slicer2.py ├── webui_dataset.py └── videoclipper.py /raw/wav_file_here: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /wavs/slicer_files_here: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /whisper_model/whisper_model_here: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /查看cuda版本.bat: -------------------------------------------------------------------------------- 1 | nvcc --version 2 | pause 3 | -------------------------------------------------------------------------------- /models_from_modelscope/place_model_files_here: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /GPU诊断.bat: -------------------------------------------------------------------------------- 1 | venv\python.exe gpu_diagnostics.py 2 | pause 3 | -------------------------------------------------------------------------------- /1_Dataset.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | chcp 65001 3 | 4 | call venv\python.exe webui_dataset.py 5 | 6 | @echo 请按任意键继续 7 | call pause -------------------------------------------------------------------------------- /common/__pycache__/log.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/common/__pycache__/log.cpython-39.pyc -------------------------------------------------------------------------------- /bcut_asr/__pycache__/orm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/bcut_asr/__pycache__/orm.cpython-310.pyc -------------------------------------------------------------------------------- /bcut_asr/__pycache__/orm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/bcut_asr/__pycache__/orm.cpython-39.pyc -------------------------------------------------------------------------------- 
/bcut_asr/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/bcut_asr/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /bcut_asr/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/bcut_asr/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /common/__pycache__/constants.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/common/__pycache__/constants.cpython-39.pyc -------------------------------------------------------------------------------- /common/__pycache__/stdout_wrapper.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/common/__pycache__/stdout_wrapper.cpython-39.pyc -------------------------------------------------------------------------------- /common/__pycache__/subprocess_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/common/__pycache__/subprocess_utils.cpython-39.pyc -------------------------------------------------------------------------------- /common/log.py: -------------------------------------------------------------------------------- 1 | """ 2 | logger封装 3 | """ 4 | from loguru import logger 5 | 6 | from .stdout_wrapper import SAFE_STDOUT 7 | 8 | # 移除所有默认的处理器 9 | logger.remove() 10 | 11 | # 自定义格式并添加到标准输出 12 | log_format = ( 13 | "{time:MM-DD HH:mm:ss} |{level:^8}| {file}:{line} | {message}" 14 | ) 15 | 16 | logger.add(SAFE_STDOUT, format=log_format, backtrace=True, diagnose=True) 17 | -------------------------------------------------------------------------------- /becut_test.py: -------------------------------------------------------------------------------- 1 | from bcut_asr import BcutASR 2 | from bcut_asr.orm import ResultStateEnum 3 | 4 | asr = BcutASR("./wavs/Erwin_0.wav") 5 | asr.upload() # 上传文件 6 | asr.create_task() # 创建任务 7 | 8 | # 轮询检查结果 9 | while True: 10 | result = asr.result() 11 | # 判断识别成功 12 | if result.state == ResultStateEnum.COMPLETE: 13 | break 14 | 15 | # 解析字幕内容 16 | subtitle = result.parse() 17 | 18 | # 判断是否存在字幕 19 | if subtitle.has_data(): 20 | # 输出srt格式 21 | print(subtitle.to_txt()) 22 | 23 | 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 安装Python3 2 | 3 | ## 安装依赖 4 | 5 | ``` 6 | pip install -r requirements.txt 7 | ``` 8 | 9 | ## 音频放入raw目录 10 | 11 | ``` 12 | 音频格式 角色名.wav 13 | ``` 14 | 15 | ## 运行webui 16 | 17 | ``` 18 | python3 webui_dataset.py 19 | ``` 20 | ![PixPin_2024-07-10_13-57-16](https://github.com/v3ucn/ASR_TOOLS_WebUI/assets/1288038/cc1f019a-647e-4076-895e-18dfba1ad2e5) 21 | 22 | 23 | ## 视频攻略 24 | 25 | https://www.bilibili.com/video/BV1da4y117Y6/ 26 | 27 | ## 必剪项目官方地址:SocialSisterYi/bcut-asr: 使用必剪API的语音字幕识别 (github.com) 28 | -------------------------------------------------------------------------------- /common/constants.py: 
-------------------------------------------------------------------------------- 1 | import enum 2 | 3 | DEFAULT_STYLE: str = "Neutral" 4 | DEFAULT_STYLE_WEIGHT: float = 26.0 5 | 6 | 7 | class Languages(str, enum.Enum): 8 | JP = "JP" 9 | EN = "EN" 10 | ZH = "ZH" 11 | 12 | 13 | DEFAULT_SDP_RATIO: float = 0.2 14 | DEFAULT_NOISE: float = 0.6 15 | DEFAULT_NOISEW: float = 0.8 16 | DEFAULT_LENGTH: float = 1.0 17 | DEFAULT_LINE_SPLIT: bool = True 18 | DEFAULT_SPLIT_INTERVAL: float = 0.5 19 | DEFAULT_ASSIST_TEXT_WEIGHT: float = 0.7 20 | DEFAULT_ASSIST_TEXT_WEIGHT: float = 1.0 21 | -------------------------------------------------------------------------------- /common/stdout_wrapper.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import tempfile 3 | 4 | 5 | class StdoutWrapper: 6 | def __init__(self): 7 | self.temp_file = tempfile.NamedTemporaryFile(mode="w+", delete=False) 8 | self.original_stdout = sys.stdout 9 | 10 | def write(self, message: str): 11 | self.temp_file.write(message) 12 | self.temp_file.flush() 13 | print(message, end="", file=self.original_stdout) 14 | 15 | def flush(self): 16 | self.temp_file.flush() 17 | 18 | def read(self): 19 | self.temp_file.seek(0) 20 | return self.temp_file.read() 21 | 22 | def close(self): 23 | self.temp_file.close() 24 | 25 | def fileno(self): 26 | return self.temp_file.fileno() 27 | 28 | 29 | try: 30 | import google.colab 31 | 32 | SAFE_STDOUT = StdoutWrapper() 33 | except ImportError: 34 | SAFE_STDOUT = sys.stdout 35 | -------------------------------------------------------------------------------- /gpu_diagnostics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def gpu_diagnostics(): 4 | if torch.cuda.is_available(): 5 | print("GPU 诊断报告:") 6 | print("="*40) 7 | for i in range(torch.cuda.device_count()): 8 | props = torch.cuda.get_device_properties(i) 9 | total_memory = props.total_memory / (1024 * 1024) 10 | reserved_memory = torch.cuda.memory_reserved(i) / (1024 * 1024) 11 | allocated_memory = torch.cuda.memory_allocated(i) / (1024 * 1024) 12 | free_memory = total_memory - allocated_memory 13 | 14 | print(f"GPU {i}: {props.name}") 15 | print(f" 总显存 : {round(total_memory, 2)} MB") 16 | print(f" 已保留显存 : {round(reserved_memory, 2)} MB") 17 | print(f" 已分配显存 : {round(allocated_memory, 2)} MB") 18 | print(f" 空闲显存 : {round(free_memory, 2)} MB") 19 | print("="*40) 20 | else: 21 | print("未找到 GPU,使用 CPU") 22 | 23 | if __name__ == "__main__": 24 | gpu_diagnostics() 25 | -------------------------------------------------------------------------------- /common/subprocess_utils.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | 4 | from .log import logger 5 | from .stdout_wrapper import SAFE_STDOUT 6 | 7 | 8 | #python = ".\\venv\\python.exe" 9 | python = "python3" 10 | 11 | print(python) 12 | 13 | 14 | def run_script_with_log(cmd: list[str], ignore_warning=False) -> tuple[bool, str]: 15 | logger.info(f"Running: {' '.join(cmd)}") 16 | print([python] + cmd) 17 | 18 | result = subprocess.run( 19 | [python] + cmd, 20 | stdout=SAFE_STDOUT, # type: ignore 21 | stderr=subprocess.PIPE, 22 | text=True, 23 | ) 24 | 25 | if result.returncode != 0: 26 | logger.error(f"Error: {' '.join(cmd)}\n{result.stderr}") 27 | return False, result.stderr 28 | elif result.stderr and not ignore_warning: 29 | logger.warning(f"Warning: {' '.join(cmd)}\n{result.stderr}") 30 | return True, result.stderr 
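# Falling through to here means the script exited with returncode 0 and produced no stderr (or stderr is being ignored via ignore_warning), so the run is reported as a clean success.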
31 | logger.success(f"Success: {' '.join(cmd)}") 32 | return True, "" 33 | 34 | 35 | def second_elem_of(original_function): 36 | def inner_function(*args, **kwargs): 37 | return original_function(*args, **kwargs)[1] 38 | 39 | return inner_function 40 | -------------------------------------------------------------------------------- /esd_whisper_large_v2.list: -------------------------------------------------------------------------------- 1 | Erwin_0.wav|Erwin|ZH|如果这个作战顺利 2 | Erwin_1.wav|Erwin|ZH|你也许可以趁此机会干掉兽之巨人 3 | Erwin_10.wav|Erwin|ZH|如果到时候我不冲在最前面 4 | Erwin_11.wav|Erwin|ZH|他们根本不会往前冲,然后我会第一个去死 5 | Erwin_12.wav|Erwin|ZH|地下室里到底有什么 6 | Erwin_13.wav|Erwin|ZH|也就无从知晓了,好想去地下室看一看,我之所以能撑着走到今天 7 | Erwin_14.wav|Erwin|ZH|就是因为相信这一天的到来 8 | Erwin_15.wav|Erwin|ZH|因为潜行者 9 | Erwin_16.wav|Erwin|ZH|我的猜想能够得到证实 10 | Erwin_17.wav|Erwin|ZH|我之前无数次的想过,要不然干脆死了算了 11 | Erwin_18.wav|Erwin|ZH|可即便如此,我还是想要实现父亲的梦想 12 | Erwin_19.wav|Erwin|ZH|然而现在 13 | Erwin_2.wav|Erwin|ZH|但得到所有新兵不管选择哪条路 14 | Erwin_20.wav|Erwin|ZH|他的答案就在我触手可及的地方 15 | Erwin_21.wav|Erwin|ZH|近在咫尺死去的同伴们也是如此吗 16 | Erwin_22.wav|Erwin|ZH|那些流血的牺牲都是没有意义的吗 17 | Erwin_23.wav|Erwin|ZH|不不对 18 | Erwin_24.wav|Erwin|ZH|那些死去士兵的意义将由我们来赋予 19 | Erwin_25.wav|Erwin|ZH|那些勇敢的死者,可怜的死者 20 | Erwin_26.wav|Erwin|ZH|是他们的牺牲换来了我们活着的今天 21 | Erwin_27.wav|Erwin|ZH|让我们能站在这里,而今天我们将会死去 22 | Erwin_28.wav|Erwin|ZH|将一一托付给下一个活着的人 23 | Erwin_29.wav|Erwin|ZH|这就是我们与这个残酷的世界 24 | Erwin_3.wav|Erwin|ZH|我们基本都会死吧,是的全灭的可能性相当的高 25 | Erwin_30.wav|Erwin|ZH|抗争的意义 26 | Erwin_4.wav|Erwin|ZH|但事到如今也只能做好玉石俱焚的觉悟 27 | Erwin_5.wav|Erwin|ZH|将一切堵在获胜渺茫的战术上 28 | Erwin_6.wav|Erwin|ZH|到了这一步 29 | Erwin_7.wav|Erwin|ZH|要让那些年轻人们去死 30 | Erwin_8.wav|Erwin|ZH|就必须像一个一流的诈骗犯一样 31 | Erwin_9.wav|Erwin|ZH|为他们花言巧语一番 32 | -------------------------------------------------------------------------------- /audio_slicer_pre.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import soundfile 3 | import os 4 | import argparse 5 | from slicer2 import Slicer 6 | 7 | # 设置命令行参数 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument( 10 | "--min_sec", "-m", type=int, default=2000, help="Minimum seconds of a slice" 11 | ) 12 | parser.add_argument( 13 | "--max_sec", "-M", type=int, default=5000, help="Maximum seconds of a slice" 14 | ) 15 | parser.add_argument( 16 | "--dataset_path", 17 | type=str, 18 | default="inputs", 19 | help="Directory of input wav files", 20 | ) 21 | parser.add_argument( 22 | "--min_silence_dur_ms", 23 | "-s", 24 | type=int, 25 | default=700, 26 | help="Silence above this duration (ms) is considered as a split point.", 27 | ) 28 | args = parser.parse_args() 29 | 30 | # 清空输出目录 31 | folder_path = './wavs' 32 | if os.path.exists(folder_path): 33 | for filename in os.listdir(folder_path): 34 | file_path = os.path.join(folder_path, filename) 35 | if os.path.isfile(file_path): 36 | os.remove(file_path) 37 | else: 38 | os.makedirs(folder_path) 39 | 40 | # 遍历指定目录下的所有.wav文件 41 | audio_directory = f'{args.dataset_path}' 42 | for filename in os.listdir(audio_directory): 43 | file_path = os.path.join(audio_directory, filename) 44 | if os.path.isfile(file_path) and filename.endswith('.wav'): 45 | # 加载音频文件 46 | audio, sr = librosa.load(file_path, sr=None, mono=False) 47 | 48 | # 创建Slicer对象 49 | slicer = Slicer( 50 | sr=sr, 51 | threshold=-40, 52 | min_length=args.min_sec, 53 | min_interval=300, 54 | hop_size=10, 55 | max_sil_kept=args.min_silence_dur_ms 56 | ) 57 | 58 | # 切割音频 59 | chunks = slicer.slice(audio) 60 | for i, chunk in enumerate(chunks): 61 | if len(chunk.shape) > 1: 
62 | chunk = chunk.T # Swap axes if the audio is stereo. 63 | soundfile.write(f'./wavs/{filename[:-4]}_{i}.wav', chunk, sr) 64 | -------------------------------------------------------------------------------- /argparse_tools.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import yaml 5 | import sys 6 | 7 | 8 | class ArgumentParser(argparse.ArgumentParser): 9 | """Simple implementation of ArgumentParser supporting config file 10 | 11 | This class originated from https://github.com/bw2/ConfigArgParse, 12 | but it lacks some of the features that library provides. 13 | 14 | - Not supporting multiple config files 15 | - Automatically adding "--config" as an option. 16 | - Not supporting any formats other than yaml 17 | - Not checking argument type 18 | 19 | """ 20 | 21 | def __init__(self, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | self.add_argument("--config", help="Give config file in yaml format") 24 | 25 | def parse_known_args(self, args=None, namespace=None): 26 | # Once parsing for setting from "--config" 27 | _args, _ = super().parse_known_args(args, namespace) 28 | if _args.config is not None: 29 | if not Path(_args.config).exists(): 30 | self.error(f"No such file: {_args.config}") 31 | 32 | with open(_args.config, "r", encoding="utf-8") as f: 33 | d = yaml.safe_load(f) 34 | if not isinstance(d, dict): 35 | self.error(f"Config file has non dict value: {_args.config}") 36 | 37 | for key in d: 38 | for action in self._actions: 39 | if key == action.dest: 40 | break 41 | else: 42 | self.error(f"unrecognized arguments: {key} (from {_args.config})") 43 | 44 | # NOTE(kamo): Ignore "--config" from a config file 45 | # NOTE(kamo): Unlike "configargparse", this module doesn't check type. 46 | # i.e. We can set any type value regardless of argument type.
47 | self.set_defaults(**d) 48 | return super().parse_known_args(args, namespace) 49 | 50 | 51 | def get_commandline_args(): 52 | extra_chars = [ 53 | " ", 54 | ";", 55 | "&", 56 | "(", 57 | ")", 58 | "|", 59 | "^", 60 | "<", 61 | ">", 62 | "?", 63 | "*", 64 | "[", 65 | "]", 66 | "$", 67 | "`", 68 | '"', 69 | "\\", 70 | "!", 71 | "{", 72 | "}", 73 | ] 74 | 75 | # Escape the extra characters for shell 76 | argv = [ 77 | arg.replace("'", "'\\''") 78 | if all(char not in arg for char in extra_chars) 79 | else "'" + arg.replace("'", "'\\''") + "'" 80 | for arg in sys.argv 81 | ] 82 | 83 | return sys.executable + " " + " ".join(argv) -------------------------------------------------------------------------------- /trans_utils.py: -------------------------------------------------------------------------------- 1 | PUNC_LIST = [',', '。', '!', '?', '、'] 2 | 3 | 4 | def pre_proc(text): 5 | res = '' 6 | for i in range(len(text)): 7 | if text[i] in PUNC_LIST: 8 | continue 9 | if '\u4e00' <= text[i] <= '\u9fff': 10 | if len(res) and res[-1] != " ": 11 | res += ' ' + text[i]+' ' 12 | else: 13 | res += text[i]+' ' 14 | else: 15 | res += text[i] 16 | if res[-1] == ' ': 17 | res = res[:-1] 18 | return res 19 | 20 | def proc(raw_text, timestamp, dest_text): 21 | # simple matching 22 | ld = len(dest_text.split()) 23 | mi, ts = [], [] 24 | offset = 0 25 | while True: 26 | fi = raw_text.find(dest_text, offset, len(raw_text)) 27 | # import pdb; pdb.set_trace() 28 | ti = raw_text[:fi].count(' ') 29 | if fi == -1: 30 | break 31 | offset = fi + ld 32 | mi.append(fi) 33 | ts.append([timestamp[ti][0]*16, timestamp[ti+ld-1][1]*16]) 34 | # import pdb; pdb.set_trace() 35 | return ts 36 | 37 | def proc_spk(dest_spk, sd_sentences): 38 | ts = [] 39 | for d in sd_sentences: 40 | d_start = d['ts_list'][0][0] 41 | d_end = d['ts_list'][-1][1] 42 | spkid=dest_spk[3:] 43 | if str(d['spk']) == spkid and d_end-d_start>999: 44 | ts.append([d['start']*16, d['end']*16]) 45 | return ts 46 | 47 | def generate_vad_data(data, sd_sentences, sr=16000): 48 | assert len(data.shape) == 1 49 | vad_data = [] 50 | for d in sd_sentences: 51 | d_start = round(d['ts_list'][0][0]/1000, 2) 52 | d_end = round(d['ts_list'][-1][1]/1000, 2) 53 | vad_data.append([d_start, d_end, data[int(d_start * sr):int(d_end * sr)]]) 54 | return vad_data 55 | 56 | def write_state(output_dir, state): 57 | for key in ['/recog_res_raw', '/timestamp', '/sentences', '/sd_sentences']: 58 | with open(output_dir+key, 'w') as fout: 59 | fout.write(str(state[key[1:]])) 60 | if 'sd_sentences' in state: 61 | with open(output_dir+'/sd_sentences', 'w') as fout: 62 | fout.write(str(state['sd_sentences'])) 63 | 64 | import os 65 | def load_state(output_dir): 66 | state = {} 67 | with open(output_dir+'/recog_res_raw') as fin: 68 | line = fin.read() 69 | state['recog_res_raw'] = line 70 | with open(output_dir+'/timestamp') as fin: 71 | line = fin.read() 72 | state['timestamp'] = eval(line) 73 | with open(output_dir+'/sentences') as fin: 74 | line = fin.read() 75 | state['sentences'] = eval(line) 76 | if os.path.exists(output_dir+'/sd_sentences'): 77 | with open(output_dir+'/sd_sentences') as fin: 78 | line = fin.read() 79 | state['sd_sentences'] = eval(line) 80 | return state 81 | 82 | -------------------------------------------------------------------------------- /bcut_asr/orm.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from pydantic import BaseModel 3 | 4 | class ASRDataSeg(BaseModel): 5 | '文字识别-断句' 6 | class 
ASRDataWords(BaseModel): 7 | '文字识别-逐字' 8 | label: str 9 | start_time: int 10 | end_time: int 11 | confidence: int 12 | start_time: int 13 | end_time: int 14 | transcript: str 15 | words: list[ASRDataWords] 16 | confidence: int 17 | 18 | def to_srt_ts(self) -> str: 19 | '转换为srt时间戳' 20 | def _conv(ms: int) -> tuple[int ,int, int, int]: 21 | return ms // 3600000, ms // 60000 % 60, ms // 1000 % 60, ms % 1000 22 | s_h, s_m, s_s, s_ms = _conv(self.start_time) 23 | e_h, e_m, e_s, e_ms = _conv(self.end_time) 24 | return f'{s_h:02d}:{s_m:02d}:{s_s:02d},{s_ms:03d} --> {e_h:02d}:{e_m:02d}:{e_s:02d},{e_ms:03d}' 25 | 26 | def to_lrc_ts(self) -> str: 27 | '转换为lrc时间戳' 28 | def _conv(ms: int) -> tuple[int ,int, int]: 29 | return ms // 60000, ms // 1000 % 60, ms % 1000 // 10 30 | s_m, s_s, s_ms = _conv(self.start_time) 31 | return f'[{s_m:02d}:{s_s:02d}.{s_ms:02d}]' 32 | 33 | 34 | 35 | class ASRData(BaseModel): 36 | '语音识别结果' 37 | utterances: list[ASRDataSeg] 38 | version: str 39 | 40 | def __iter__(self): 41 | 'iter穿透' 42 | return iter(self.utterances) 43 | 44 | def has_data(self) -> bool: 45 | '是否识别到数据' 46 | return len(self.utterances) > 0 47 | 48 | def to_txt(self) -> str: 49 | '转成txt格式字幕 (无时间标记)' 50 | return '\n'.join( 51 | seg.transcript 52 | for seg 53 | in self.utterances 54 | ) 55 | 56 | def to_srt(self) -> str: 57 | '转成srt格式字幕' 58 | return '\n'.join( 59 | f'{n}\n{seg.to_srt_ts()}\n{seg.transcript}\n' 60 | for n, seg 61 | in enumerate(self.utterances, 1) 62 | ) 63 | 64 | def to_lrc(self) -> str: 65 | '转成lrc格式字幕' 66 | return '\n'.join( 67 | f'{seg.to_lrc_ts()}{seg.transcript}' 68 | for seg 69 | in self.utterances 70 | ) 71 | 72 | def to_ass(self) -> str: 73 | ... 74 | 75 | 76 | class ResourceCreateRspSchema(BaseModel): 77 | '上传申请响应' 78 | resource_id: str 79 | title: str 80 | type: int 81 | in_boss_key: str 82 | size: int 83 | upload_urls: list[str] 84 | upload_id: str 85 | per_size: int 86 | 87 | class ResourceCompleteRspSchema(BaseModel): 88 | '上传提交响应' 89 | resource_id: str 90 | download_url: str 91 | 92 | class TaskCreateRspSchema(BaseModel): 93 | '任务创建响应' 94 | resource: str 95 | result: str 96 | task_id: str # 任务id 97 | 98 | class ResultStateEnum(Enum): 99 | '任务状态枚举' 100 | STOP = 0 # 未开始 101 | RUNING = 1 # 运行中 102 | ERROR = 3 # 错误 103 | COMPLETE = 4 # 完成 104 | 105 | class ResultRspSchema(BaseModel): 106 | '任务结果查询响应' 107 | task_id: str # 任务id 108 | result: str # 结果数据-json 109 | remark: str # 任务状态详情 110 | state: ResultStateEnum # 任务状态 111 | 112 | def parse(self) -> ASRData: 113 | '解析结果数据' 114 | return ASRData.parse_raw(self.result) 115 | -------------------------------------------------------------------------------- /short_audio_transcribe_bcut.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | from tqdm import tqdm 6 | import sys 7 | import os 8 | 9 | from common.constants import Languages 10 | from common.log import logger 11 | from common.stdout_wrapper import SAFE_STDOUT 12 | 13 | from bcut_asr import BcutASR 14 | from bcut_asr.orm import ResultStateEnum 15 | 16 | import whisper 17 | import torch 18 | 19 | import re 20 | 21 | 22 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 23 | 24 | 25 | model = whisper.load_model("medium",download_root="./whisper_model/") 26 | 27 | 28 | 29 | lang2token = { 30 | 'zh': "ZH|", 31 | 'ja': "JP|", 32 | "en": "EN|", 33 | } 34 | 35 | 36 | def transcribe_one(audio_path): 37 | 38 | audio = whisper.load_audio(audio_path) 39 | audio = whisper.pad_or_trim(audio) 40 | mel = 
whisper.log_mel_spectrogram(audio).to(model.device) 41 | _, probs = model.detect_language(mel) 42 | language = max(probs, key=probs.get) 43 | 44 | asr = BcutASR(audio_path) 45 | asr.upload() # 上传文件 46 | asr.create_task() # 创建任务 47 | 48 | # 轮询检查结果 49 | while True: 50 | result = asr.result() 51 | # 判断识别成功 52 | if result.state == ResultStateEnum.COMPLETE: 53 | break 54 | 55 | # 解析字幕内容 56 | subtitle = result.parse() 57 | 58 | # 判断是否存在字幕 59 | if subtitle.has_data(): 60 | 61 | 62 | 63 | text = subtitle.to_txt() 64 | text = repr(text) 65 | text = text.replace("'","") 66 | text = text.replace("\\n",",") 67 | text = text.replace("\\r",",") 68 | 69 | print(text) 70 | 71 | # 输出srt格式 72 | return text,language 73 | else: 74 | return "必剪无法识别",language 75 | 76 | 77 | 78 | if __name__ == "__main__": 79 | 80 | parser = argparse.ArgumentParser() 81 | 82 | parser.add_argument( 83 | "--language", type=str, default="ja", choices=["ja", "en", "zh"] 84 | ) 85 | parser.add_argument("--model_name", type=str, required=True) 86 | 87 | parser.add_argument("--input_file", type=str, default="./wavs/") 88 | 89 | parser.add_argument("--file_pos", type=str, default="") 90 | 91 | 92 | args = parser.parse_args() 93 | 94 | speaker_name = args.model_name 95 | 96 | language = args.language 97 | 98 | input_file = args.input_file 99 | 100 | if input_file == "": 101 | input_file = "./wavs/" 102 | 103 | file_pos = args.file_pos 104 | 105 | 106 | wav_files = [ 107 | f for f in os.listdir(f"{input_file}") if f.endswith(".wav") 108 | ] 109 | 110 | 111 | with open("./esd.list", "w", encoding="utf-8") as f: 112 | for wav_file in tqdm(wav_files, file=SAFE_STDOUT): 113 | file_name = os.path.basename(wav_file) 114 | 115 | # 使用正则表达式提取'deedee' 116 | match = re.search(r'(^.*?)_.*?(\..*?$)', wav_file) 117 | if match: 118 | extracted_name = match.group(1) + match.group(2) 119 | else: 120 | print("No match found") 121 | extracted_name = "sample" 122 | 123 | text,lang = transcribe_one(f"{input_file}"+wav_file) 124 | 125 | if lang == "ja": 126 | language_id = "JA" 127 | elif lang == "en": 128 | language_id = "EN" 129 | elif lang == "zh": 130 | language_id = "ZH" 131 | 132 | f.write(file_pos+f"{file_name}|{extracted_name.replace('.wav','')}|{language_id}|{text}\n") 133 | 134 | f.flush() 135 | sys.exit(0) 136 | 137 | 138 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.0.0 2 | addict==2.4.0 3 | aiofiles==23.2.1 4 | aiohttp==3.9.1 5 | aiosignal==1.3.1 6 | aliyun-python-sdk-core==2.14.0 7 | aliyun-python-sdk-kms==2.16.2 8 | altair==5.2.0 9 | annotated-types==0.6.0 10 | anyio==4.2.0 11 | async-timeout==4.0.3 12 | attrs==23.2.0 13 | audioread==3.0.1 14 | cachetools==5.3.2 15 | certifi==2023.11.17 16 | cffi==1.16.0 17 | charset-normalizer==3.3.2 18 | click==8.1.7 19 | colorama==0.4.6 20 | contourpy==1.2.0 21 | crcmod==1.7 22 | cryptography==41.0.7 23 | cycler==0.12.1 24 | Cython==0.29.37 25 | datasets==2.16.1 26 | decorator==5.1.1 27 | dill==0.3.7 28 | editdistance==0.6.2 29 | einops==0.7.0 30 | exceptiongroup==1.2.0 31 | fastapi==0.109.0 32 | ffmpeg==1.4 33 | ffmpy==0.3.1 34 | filelock==3.13.1 35 | fonttools==4.47.2 36 | frozenlist==1.4.1 37 | fsspec==2023.10.0 38 | funasr 39 | gast==0.5.4 40 | google-auth==2.26.2 41 | google-auth-oauthlib==1.2.0 42 | gradio==4.14.0 43 | gradio_client==0.8.0 44 | grpcio==1.60.0 45 | h11==0.14.0 46 | hdbscan==0.8.33 47 | httpcore==1.0.2 48 | httpx==0.26.0 49 | 
huggingface-hub==0.20.2 50 | humanfriendly==10.0 51 | idna==3.6 52 | importlib-metadata==7.0.1 53 | importlib-resources==6.1.1 54 | jaconv==0.3.4 55 | jamo==0.4.1 56 | jieba==0.42.1 57 | Jinja2==3.1.3 58 | jmespath==0.10.0 59 | joblib==1.3.2 60 | jsonschema==4.20.0 61 | jsonschema-specifications==2023.12.1 62 | kaldiio==2.18.0 63 | kiwisolver==1.4.5 64 | lazy_loader==0.3 65 | librosa==0.10.1 66 | llvmlite==0.41.1 67 | loguru==0.7.2 68 | Markdown==3.5.2 69 | markdown-it-py==3.0.0 70 | MarkupSafe==2.1.3 71 | matplotlib==3.8.2 72 | mdurl==0.1.2 73 | mecab-python3==1.0.8 74 | modelscope==1.10.0 75 | more-itertools==10.2.0 76 | mpmath==1.3.0 77 | msgpack==1.0.7 78 | multidict==6.0.4 79 | multiprocess==0.70.15 80 | networkx==3.2.1 81 | numba==0.58.1 82 | numpy==1.26.3 83 | oauthlib==3.2.2 84 | openai-whisper @ git+https://github.com/openai/whisper.git@ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab 85 | orjson==3.9.10 86 | oss2==2.18.4 87 | packaging==23.2 88 | pandas==2.1.4 89 | pillow==10.2.0 90 | platformdirs==4.1.0 91 | pooch==1.8.0 92 | protobuf==4.23.4 93 | pyarrow==14.0.2 94 | pyarrow-hotfix==0.6 95 | pyasn1==0.5.1 96 | pyasn1-modules==0.3.0 97 | pycparser==2.21 98 | pycryptodome==3.20.0 99 | pydantic==2.5.3 100 | pydantic_core==2.14.6 101 | pydub==0.25.1 102 | Pygments==2.17.2 103 | pyparsing==3.1.1 104 | pypinyin==0.50.0 105 | pyreadline3==3.4.1 106 | python-dateutil==2.8.2 107 | python-multipart==0.0.6 108 | pytorch-wpe==0.0.1 109 | pytz==2023.3.post1 110 | PyYAML==6.0.1 111 | referencing==0.32.1 112 | regex==2023.12.25 113 | requests==2.31.0 114 | requests-oauthlib==1.3.1 115 | rich==13.7.0 116 | rpds-py==0.17.1 117 | rsa==4.9 118 | safetensors==0.4.1 119 | scikit-learn==1.3.2 120 | scipy==1.11.4 121 | semantic-version==2.10.0 122 | sentencepiece==0.1.99 123 | shellingham==1.5.4 124 | simplejson==3.19.2 125 | six==1.16.0 126 | sniffio==1.3.0 127 | sortedcontainers==2.4.0 128 | soundfile==0.12.1 129 | soxr==0.3.7 130 | starlette==0.35.1 131 | sympy==1.12 132 | tensorboard==2.15.1 133 | tensorboard-data-server==0.7.2 134 | threadpoolctl==3.2.0 135 | tiktoken==0.5.2 136 | tokenizers==0.15.0 137 | tomli==2.0.1 138 | tomlkit==0.12.0 139 | toolz==0.12.0 140 | torch==2.1.2+cu118 141 | torch-complex==0.4.3 142 | torchaudio==2.1.2 143 | torchvision==0.16.2+cu118 144 | tqdm==4.66.1 145 | transformers==4.36.2 146 | typer==0.9.0 147 | typing_extensions==4.9.0 148 | tzdata==2023.4 149 | umap==0.1.1 150 | urllib3==2.1.0 151 | uvicorn==0.25.0 152 | websockets==11.0.3 153 | Werkzeug==3.0.1 154 | win32-setctime==1.1.0 155 | xxhash==3.4.1 156 | yapf==0.40.2 157 | yarl==1.9.4 158 | zipp==3.17.0 159 | faster-whisper 160 | moviepy 161 | -------------------------------------------------------------------------------- /short_audio_transcribe_fwhisper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | 6 | from tqdm import tqdm 7 | import sys 8 | import os 9 | 10 | from common.constants import Languages 11 | from common.log import logger 12 | from common.stdout_wrapper import SAFE_STDOUT 13 | 14 | import re 15 | 16 | from transformers import pipeline 17 | 18 | from faster_whisper import WhisperModel 19 | 20 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 21 | 22 | model = None 23 | 24 | lang2token = { 25 | 'zh': "ZH|", 26 | 'ja': "JP|", 27 | "en": "EN|", 28 | } 29 | 30 | 31 | def transcribe_bela(audio_path): 32 | 33 | transcriber = pipeline( 34 | "automatic-speech-recognition", 35 | 
model="BELLE-2/Belle-whisper-large-v2-zh", 36 | device=device 37 | ) 38 | 39 | transcriber.model.config.forced_decoder_ids = ( 40 | transcriber.tokenizer.get_decoder_prompt_ids( 41 | language="zh", 42 | task="transcribe", 43 | ) 44 | ) 45 | 46 | transcription = transcriber(audio_path) 47 | 48 | print(transcription["text"]) 49 | return transcription["text"] 50 | 51 | 52 | def transcribe_one(audio_path,mytype): 53 | 54 | segments, info = model.transcribe(audio_path, beam_size=5,vad_filter=True,vad_parameters=dict(min_silence_duration_ms=500),) 55 | print("Detected language '%s' with probability %f" % (info.language, info.language_probability)) 56 | 57 | text_str = "" 58 | for segment in segments: 59 | text_str += f"{segment.text.lstrip()}," 60 | print(text_str.rstrip(",")) 61 | 62 | return text_str.rstrip(","),info.language 63 | 64 | 65 | 66 | if __name__ == "__main__": 67 | 68 | parser = argparse.ArgumentParser() 69 | 70 | parser.add_argument( 71 | "--language", type=str, default="ja", choices=["ja", "en", "zh"] 72 | ) 73 | 74 | parser.add_argument( 75 | "--mytype", type=str, default="medium" 76 | ) 77 | 78 | parser.add_argument("--model_name", type=str, required=True) 79 | 80 | parser.add_argument("--input_file", type=str, default="./wavs/") 81 | 82 | parser.add_argument("--file_pos", type=str, default="") 83 | 84 | 85 | args = parser.parse_args() 86 | 87 | speaker_name = args.model_name 88 | 89 | language = args.language 90 | 91 | mytype = args.mytype 92 | 93 | input_file = args.input_file 94 | 95 | if input_file == "": 96 | input_file = "./wavs/" 97 | 98 | file_pos = args.file_pos 99 | 100 | if device == "cuda": 101 | try: 102 | model = WhisperModel(mytype, device="cuda", compute_type="float16",download_root="./whisper_model",local_files_only=False) 103 | except Exception as e: 104 | model = WhisperModel(mytype, device="cuda", compute_type="int8_float16",download_root="./whisper_model",local_files_only=False) 105 | else: 106 | model = WhisperModel(mytype, device="cpu", compute_type="int8",download_root="./whisper_model",local_files_only=False) 107 | 108 | 109 | wav_files = [ 110 | f for f in os.listdir(f"{input_file}") if f.endswith(".wav") 111 | ] 112 | 113 | 114 | 115 | with open("./esd.list", "w", encoding="utf-8") as f: 116 | for wav_file in tqdm(wav_files, file=SAFE_STDOUT): 117 | file_name = os.path.basename(wav_file) 118 | 119 | if model: 120 | text,lang = transcribe_one(f"{input_file}"+wav_file,mytype) 121 | else: 122 | text,lang = transcribe_bela(f"{input_file}"+wav_file) 123 | 124 | # 使用正则表达式提取'deedee' 125 | match = re.search(r'(^.*?)_.*?(\..*?$)', wav_file) 126 | if match: 127 | extracted_name = match.group(1) + match.group(2) 128 | else: 129 | print("No match found") 130 | extracted_name = "sample" 131 | 132 | if lang == "ja": 133 | language_id = "JA" 134 | elif lang == "en": 135 | language_id = "EN" 136 | elif lang == "zh": 137 | language_id = "ZH" 138 | 139 | f.write(file_pos+f"{file_name}|{extracted_name.replace('.wav','')}|{language_id}|{text}\n") 140 | 141 | f.flush() 142 | sys.exit(0) 143 | 144 | 145 | -------------------------------------------------------------------------------- /bcut_asr/__main__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import time 4 | from argparse import ArgumentParser, FileType 5 | import ffmpeg 6 | from . 
import APIError, BcutASR, ResultStateEnum 7 | 8 | logging.basicConfig(format='%(asctime)s - [%(levelname)s] %(message)s', level=logging.INFO) 9 | 10 | INFILE_FMT = ('flac', 'aac', 'm4a', 'mp3', 'wav') 11 | OUTFILE_FMT = ('srt', 'json', 'lrc', 'txt') 12 | 13 | parser = ArgumentParser( 14 | prog='bcut-asr', 15 | description='必剪语音识别\n', 16 | epilog=f"支持输入音频格式: {', '.join(INFILE_FMT)} 支持自动调用ffmpeg提取视频伴音" 17 | ) 18 | parser.add_argument('-f', '--format', nargs='?', default='srt', choices=OUTFILE_FMT, help='输出字幕格式') 19 | parser.add_argument('input', type=FileType('rb'), help='输入媒体文件') 20 | parser.add_argument('output', nargs='?', type=FileType('w', encoding='utf8'), help='输出字幕文件, 可stdout') 21 | 22 | args = parser.parse_args() 23 | 24 | 25 | def ffmpeg_render(media_file: str) -> bytes: 26 | '提取视频伴音并转码为aac格式' 27 | out, err = (ffmpeg 28 | .input(media_file, v='warning') 29 | .output('pipe:', ac=1, format='adts') 30 | .run(capture_stdout=True) 31 | ) 32 | return out 33 | 34 | 35 | def main(): 36 | # 处理输入文件情况 37 | infile = args.input 38 | infile_name = infile.name 39 | if infile_name == '': 40 | logging.error('输入文件错误') 41 | sys.exit(-1) 42 | suffix = infile_name.rsplit('.', 1)[-1] 43 | if suffix in INFILE_FMT: 44 | infile_fmt = suffix 45 | infile_data = infile.read() 46 | else: 47 | # ffmpeg分离视频伴音 48 | logging.info('非标准音频文件, 尝试调用ffmpeg转码') 49 | try: 50 | infile_data = ffmpeg_render(infile_name) 51 | except ffmpeg.Error: 52 | logging.error('ffmpeg转码失败') 53 | sys.exit(-1) 54 | else: 55 | logging.info('ffmpeg转码完成') 56 | infile_fmt = 'aac' 57 | 58 | # 处理输出文件情况 59 | outfile = args.output 60 | if outfile is None: 61 | # 未指定输出文件,默认为文件名同输入,可以 -t 传参,默认str格式 62 | if args.format is not None: 63 | outfile_fmt = args.format 64 | else: 65 | outfile_fmt = 'srt' 66 | else: 67 | # 指定输出文件 68 | outfile_name = outfile.name 69 | if outfile.name == '': 70 | # stdout情况,可以 -t 传参,默认str格式 71 | if args.format is not None: 72 | outfile_fmt = args.format 73 | else: 74 | outfile_fmt = 'srt' 75 | else: 76 | suffix = outfile_name.rsplit('.', 1)[-1] 77 | if suffix in OUTFILE_FMT: 78 | outfile_fmt = suffix 79 | else: 80 | logging.error('输出格式错误') 81 | sys.exit(-1) 82 | 83 | # 开始执行转换逻辑 84 | asr = BcutASR() 85 | asr.set_data(raw_data=infile_data, data_fmt=infile_fmt) 86 | try: 87 | # 上传文件 88 | asr.upload() 89 | # 创建任务 90 | task_id = asr.create_task() 91 | while True: 92 | # 轮询检查任务状态 93 | task_resp = asr.result() 94 | match task_resp.state: 95 | case ResultStateEnum.STOP: 96 | logging.info(f'等待识别开始') 97 | case ResultStateEnum.RUNING: 98 | logging.info(f'识别中-{task_resp.remark}') 99 | case ResultStateEnum.ERROR: 100 | logging.error(f'识别失败-{task_resp.remark}') 101 | sys.exit(-1) 102 | case ResultStateEnum.COMPLETE: 103 | logging.info(f'识别成功') 104 | outfile_name = f"{infile_name.rsplit('.', 1)[-2]}.{outfile_fmt}" 105 | outfile = open(outfile_name, 'w', encoding='utf8') 106 | # 识别成功, 回读字幕数据 107 | result = task_resp.parse() 108 | break 109 | time.sleep(1.0) 110 | if not result.has_data(): 111 | logging.error('未识别到语音') 112 | sys.exit(-1) 113 | match outfile_fmt: 114 | case 'srt': 115 | outfile.write(result.to_srt()) 116 | case 'lrc': 117 | outfile.write(result.to_lrc()) 118 | case 'json': 119 | outfile.write(result.json()) 120 | case 'txt': 121 | outfile.write(result.to_txt()) 122 | logging.info(f'转换成功: {outfile_name}') 123 | except APIError as err: 124 | logging.error(f'接口错误: {err.__str__()}') 125 | sys.exit(-1) 126 | 127 | 128 | if __name__ == '__main__': 129 | main() 
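# Usage sketch (illustrative assumption, not taken from the original repo): with the bcut_asr package importable and
# ffmpeg on PATH for non-audio inputs, running
#   python -m bcut_asr -f srt input.wav
# polls the task until completion and, since no output file is given, writes the subtitle next to the input as input.srt.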
-------------------------------------------------------------------------------- /short_audio_transcribe_whisper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import whisper 4 | import torch 5 | 6 | from tqdm import tqdm 7 | import sys 8 | import os 9 | 10 | from common.constants import Languages 11 | from common.log import logger 12 | from common.stdout_wrapper import SAFE_STDOUT 13 | 14 | import re 15 | 16 | from transformers import pipeline 17 | 18 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 19 | 20 | model = None 21 | 22 | 23 | lang2token = { 24 | 'zh': "ZH|", 25 | 'ja': "JP|", 26 | "en": "EN|", 27 | } 28 | 29 | 30 | def transcribe_bela(audio_path): 31 | 32 | transcriber = pipeline( 33 | "automatic-speech-recognition", 34 | model="BELLE-2/Belle-whisper-large-v2-zh", 35 | device=device 36 | ) 37 | 38 | transcriber.model.config.forced_decoder_ids = ( 39 | transcriber.tokenizer.get_decoder_prompt_ids( 40 | language="zh", 41 | task="transcribe", 42 | ) 43 | ) 44 | 45 | transcription = transcriber(audio_path) 46 | 47 | print(transcription["text"]) 48 | return transcription["text"] 49 | 50 | 51 | def transcribe_one(audio_path,mytype): 52 | # load audio and pad/trim it to fit 30 seconds 53 | audio = whisper.load_audio(audio_path) 54 | audio = whisper.pad_or_trim(audio) 55 | 56 | # make log-Mel spectrogram and move to the same device as the model 57 | 58 | if mytype == "large-v3": 59 | 60 | mel = whisper.log_mel_spectrogram(audio,n_mels=128).to(model.device) 61 | 62 | else: 63 | 64 | mel = whisper.log_mel_spectrogram(audio).to(model.device) 65 | 66 | 67 | # detect the spoken language 68 | _, probs = model.detect_language(mel) 69 | print(f"Detected language: {max(probs, key=probs.get)}") 70 | lang = max(probs, key=probs.get) 71 | # decode the audio 72 | 73 | 74 | if lang == "zh": 75 | 76 | 77 | if torch.cuda.is_available(): 78 | options = whisper.DecodingOptions(beam_size=5,prompt="生于忧患,死于欢乐。不亦快哉!") 79 | else: 80 | options = whisper.DecodingOptions(beam_size=5,fp16 = False,prompt="生于忧患,死于欢乐。不亦快哉!") 81 | 82 | else: 83 | 84 | 85 | 86 | if torch.cuda.is_available(): 87 | options = whisper.DecodingOptions(beam_size=5) 88 | else: 89 | options = whisper.DecodingOptions(beam_size=5,fp16 = False) 90 | 91 | 92 | 93 | 94 | 95 | 96 | result = whisper.decode(model, mel, options) 97 | 98 | # print the recognized text 99 | print(result.text) 100 | return result.text,max(probs, key=probs.get) 101 | 102 | 103 | if __name__ == "__main__": 104 | 105 | parser = argparse.ArgumentParser() 106 | 107 | parser.add_argument( 108 | "--language", type=str, default="ja", choices=["ja", "en", "zh"] 109 | ) 110 | 111 | parser.add_argument( 112 | "--mytype", type=str, default="medium" 113 | ) 114 | 115 | parser.add_argument("--model_name", type=str, required=True) 116 | 117 | parser.add_argument("--input_file", type=str, default="./wavs/") 118 | 119 | parser.add_argument("--file_pos", type=str, default="") 120 | 121 | 122 | args = parser.parse_args() 123 | 124 | speaker_name = args.model_name 125 | 126 | language = args.language 127 | 128 | mytype = args.mytype 129 | 130 | input_file = args.input_file 131 | 132 | if input_file == "": 133 | input_file = "./wavs/" 134 | 135 | file_pos = args.file_pos 136 | 137 | try: 138 | model = whisper.load_model(mytype,download_root="./whisper_model/") 139 | except Exception as e: 140 | 141 | print(str(e)) 142 | print("中文特化逻辑") 143 | 144 | 145 | wav_files = [ 146 | f for f in 
os.listdir(f"{input_file}") if f.endswith(".wav") 147 | ] 148 | 149 | 150 | 151 | with open("./esd.list", "w", encoding="utf-8") as f: 152 | for wav_file in tqdm(wav_files, file=SAFE_STDOUT): 153 | file_name = os.path.basename(wav_file) 154 | 155 | if model: 156 | text,lang = transcribe_one(f"{input_file}"+wav_file,mytype) 157 | else: 158 | text,lang = transcribe_bela(f"{input_file}"+wav_file) 159 | 160 | # 使用正则表达式提取'deedee' 161 | match = re.search(r'(^.*?)_.*?(\..*?$)', wav_file) 162 | if match: 163 | extracted_name = match.group(1) + match.group(2) 164 | else: 165 | print("No match found") 166 | extracted_name = "sample" 167 | 168 | if lang == "ja": 169 | language_id = "JA" 170 | elif lang == "en": 171 | language_id = "EN" 172 | elif lang == "zh": 173 | language_id = "ZH" 174 | 175 | f.write(file_pos+f"{file_name}|{extracted_name.replace('.wav','')}|{language_id}|{text}\n") 176 | 177 | f.flush() 178 | sys.exit(0) 179 | 180 | 181 | -------------------------------------------------------------------------------- /subtitle_utils.py: -------------------------------------------------------------------------------- 1 | def time_convert(ms): 2 | ms = int(ms) 3 | tail = ms % 1000 4 | s = ms // 1000 5 | mi = s // 60 6 | s = s % 60 7 | h = mi // 60 8 | mi = mi % 60 9 | h = "00" if h == 0 else str(h) 10 | mi = "00" if mi == 0 else str(mi) 11 | s = "00" if s == 0 else str(s) 12 | tail = str(tail) 13 | if len(h) == 1: h = '0' + h 14 | if len(mi) == 1: mi = '0' + mi 15 | if len(s) == 1: s = '0' + s 16 | return "{}:{}:{},{}".format(h, mi, s, tail) 17 | 18 | 19 | class Text2SRT(): 20 | def __init__(self, text_seg, ts_list, offset=0): 21 | self.token_list = [i for i in text_seg.split() if len(i)] 22 | self.ts_list = ts_list 23 | start, end = ts_list[0][0] - offset, ts_list[-1][1] - offset 24 | self.start_sec, self.end_sec = start, end 25 | self.start_time = time_convert(start) 26 | self.end_time = time_convert(end) 27 | def text(self): 28 | res = "" 29 | for word in self.token_list: 30 | if '\u4e00' <= word <= '\u9fff': 31 | res += word 32 | else: 33 | res += " " + word 34 | return res 35 | def len(self): 36 | return len(self.token_list) 37 | def srt(self, acc_ost=0.0): 38 | return "{} --> {}\n{}\n".format( 39 | time_convert(self.start_sec+acc_ost*1000), 40 | time_convert(self.end_sec+acc_ost*1000), 41 | self.text()) 42 | def time(self, acc_ost=0.0): 43 | return (self.start_sec/1000+acc_ost, self.end_sec/1000+acc_ost) 44 | 45 | def distribute_spk(sentence_list, sd_time_list): 46 | sd_sentence_list = [] 47 | for d in sentence_list: 48 | sentence_start = d['ts_list'][0][0] 49 | sentence_end = d['ts_list'][-1][1] 50 | sentence_spk = 0 51 | max_overlap = 0 52 | for sd_time in sd_time_list: 53 | spk_st, spk_ed, spk = sd_time 54 | spk_st = spk_st*1000 55 | spk_ed = spk_ed*1000 56 | overlap = max( 57 | min(sentence_end, spk_ed) - max(sentence_start, spk_st), 0) 58 | if overlap > max_overlap: 59 | max_overlap = overlap 60 | sentence_spk = spk 61 | d['spk'] = sentence_spk 62 | sd_sentence_list.append(d) 63 | return sd_sentence_list 64 | 65 | def generate_srt(sentence_list): 66 | srt_total = '' 67 | for i, d in enumerate(sentence_list): 68 | t2s = Text2SRT(d['text_seg'], d['ts_list']) 69 | if 'spk' in d: 70 | srt_total += "{} spk{}\n{}".format(i, d['spk'], t2s.srt()) 71 | else: 72 | srt_total += "{}\n{}".format(i, t2s.srt()) 73 | return srt_total 74 | 75 | def generate_srt_clip(sentence_list, start, end, begin_index=0, time_acc_ost=0.0): 76 | start, end = int(start * 1000), int(end * 1000) 77 | srt_total = '' 78 | 
cc = 1 + begin_index 79 | subs = [] 80 | for i, d in enumerate(sentence_list): 81 | if d['ts_list'][-1][1] <= start: 82 | continue 83 | if d['ts_list'][0][0] >= end: 84 | break 85 | # parts in between 86 | if (d['ts_list'][-1][1] <= end and d['ts_list'][0][0] > start) or (d['ts_list'][-1][1] == end and d['ts_list'][0][0] == start): 87 | t2s = Text2SRT(d['text_seg'], d['ts_list'], offset=start) 88 | srt_total += "{}\n{}".format(cc, t2s.srt(time_acc_ost)) 89 | subs.append((t2s.time(time_acc_ost), t2s.text())) 90 | cc += 1 91 | continue 92 | if d['ts_list'][0][0] <= start: 93 | if not d['ts_list'][-1][1] > end: 94 | for j, ts in enumerate(d['ts_list']): 95 | if ts[1] > start: 96 | break 97 | _text = " ".join(d['text_seg'].split()[j:]) 98 | _ts = d['ts_list'][j:] 99 | else: 100 | for j, ts in enumerate(d['ts_list']): 101 | if ts[1] > start: 102 | _start = j 103 | break 104 | for j, ts in enumerate(d['ts_list']): 105 | if ts[1] > end: 106 | _end = j 107 | break 108 | _text = " ".join(d['text_seg'].split()[_start:_end]) 109 | _ts = d['ts_list'][_start:_end] 110 | if len(ts): 111 | t2s = Text2SRT(_text, _ts, offset=start) 112 | srt_total += "{}\n{}".format(cc, t2s.srt(time_acc_ost)) 113 | subs.append((t2s.time(time_acc_ost), t2s.text())) 114 | cc += 1 115 | continue 116 | if d['ts_list'][-1][1] > end: 117 | for j, ts in enumerate(d['ts_list']): 118 | if ts[1] > end: 119 | break 120 | _text = " ".join(d['text_seg'].split()[:j]) 121 | _ts = d['ts_list'][:j] 122 | if len(_ts): 123 | t2s = Text2SRT(_text, _ts, offset=start) 124 | srt_total += "{}\n{}".format(cc, t2s.srt(time_acc_ost)) 125 | subs.append( 126 | (t2s.time(time_acc_ost), t2s.text()) 127 | ) 128 | cc += 1 129 | continue 130 | return srt_total, subs, cc 131 | -------------------------------------------------------------------------------- /short_audio_transcribe_ali.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import whisper 4 | import torch 5 | 6 | from tqdm import tqdm 7 | import sys 8 | import os 9 | 10 | 11 | 12 | from common.constants import Languages 13 | from common.log import logger 14 | from common.stdout_wrapper import SAFE_STDOUT 15 | 16 | import re 17 | 18 | 19 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 20 | 21 | 22 | from funasr import AutoModel 23 | 24 | model_dir = "iic/SenseVoiceSmall" 25 | 26 | 27 | emo_dict = { 28 | "<|HAPPY|>": "😊", 29 | "<|SAD|>": "😔", 30 | "<|ANGRY|>": "😡", 31 | "<|NEUTRAL|>": "", 32 | "<|FEARFUL|>": "😰", 33 | "<|DISGUSTED|>": "🤢", 34 | "<|SURPRISED|>": "😮", 35 | } 36 | 37 | event_dict = { 38 | "<|BGM|>": "🎼", 39 | "<|Speech|>": "", 40 | "<|Applause|>": "👏", 41 | "<|Laughter|>": "😀", 42 | "<|Cry|>": "😭", 43 | "<|Sneeze|>": "🤧", 44 | "<|Breath|>": "", 45 | "<|Cough|>": "🤧", 46 | } 47 | 48 | emoji_dict = { 49 | "<|nospeech|><|Event_UNK|>": "❓", 50 | "<|zh|>": "", 51 | "<|en|>": "", 52 | "<|yue|>": "", 53 | "<|ja|>": "", 54 | "<|ko|>": "", 55 | "<|nospeech|>": "", 56 | "<|HAPPY|>": "😊", 57 | "<|SAD|>": "😔", 58 | "<|ANGRY|>": "😡", 59 | "<|NEUTRAL|>": "", 60 | "<|BGM|>": "🎼", 61 | "<|Speech|>": "", 62 | "<|Applause|>": "👏", 63 | "<|Laughter|>": "😀", 64 | "<|FEARFUL|>": "😰", 65 | "<|DISGUSTED|>": "🤢", 66 | "<|SURPRISED|>": "😮", 67 | "<|Cry|>": "😭", 68 | "<|EMO_UNKNOWN|>": "", 69 | "<|Sneeze|>": "🤧", 70 | "<|Breath|>": "", 71 | "<|Cough|>": "😷", 72 | "<|Sing|>": "", 73 | "<|Speech_Noise|>": "", 74 | "<|withitn|>": "", 75 | "<|woitn|>": "", 76 | "<|GBG|>": "", 77 | "<|Event_UNK|>": "", 78 | } 79 | 80 | lang_dict = { 
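# Every SenseVoice language tag is mapped to the same generic <|lang|> marker; format_str_v3 later splits the transcript on this marker so each language segment can be cleaned up separately.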
81 | "<|zh|>": "<|lang|>", 82 | "<|en|>": "<|lang|>", 83 | "<|yue|>": "<|lang|>", 84 | "<|ja|>": "<|lang|>", 85 | "<|ko|>": "<|lang|>", 86 | "<|nospeech|>": "<|lang|>", 87 | } 88 | 89 | emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"} 90 | event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷",} 91 | 92 | lang2token = { 93 | 'zh': "ZH|", 94 | 'ja': "JP|", 95 | "en": "EN|", 96 | "ko": "KO|", 97 | "yue": "YUE|", 98 | } 99 | 100 | def format_str(s): 101 | for sptk in emoji_dict: 102 | s = s.replace(sptk, emoji_dict[sptk]) 103 | return s 104 | 105 | 106 | def format_str_v2(s): 107 | sptk_dict = {} 108 | for sptk in emoji_dict: 109 | sptk_dict[sptk] = s.count(sptk) 110 | s = s.replace(sptk, "") 111 | emo = "<|NEUTRAL|>" 112 | for e in emo_dict: 113 | if sptk_dict[e] > sptk_dict[emo]: 114 | emo = e 115 | for e in event_dict: 116 | if sptk_dict[e] > 0: 117 | s = event_dict[e] + s 118 | s = s + emo_dict[emo] 119 | 120 | for emoji in emo_set.union(event_set): 121 | s = s.replace(" " + emoji, emoji) 122 | s = s.replace(emoji + " ", emoji) 123 | return s.strip() 124 | 125 | def format_str_v3(s): 126 | def get_emo(s): 127 | return s[-1] if s[-1] in emo_set else None 128 | def get_event(s): 129 | return s[0] if s[0] in event_set else None 130 | 131 | s = s.replace("<|nospeech|><|Event_UNK|>", "❓") 132 | for lang in lang_dict: 133 | s = s.replace(lang, "<|lang|>") 134 | s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")] 135 | new_s = " " + s_list[0] 136 | cur_ent_event = get_event(new_s) 137 | for i in range(1, len(s_list)): 138 | if len(s_list[i]) == 0: 139 | continue 140 | if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None: 141 | s_list[i] = s_list[i][1:] 142 | #else: 143 | cur_ent_event = get_event(s_list[i]) 144 | if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s): 145 | new_s = new_s[:-1] 146 | new_s += s_list[i].strip().lstrip() 147 | new_s = new_s.replace("The.", " ") 148 | return new_s.strip() 149 | 150 | def transcribe_one(audio_path,language): 151 | 152 | model = AutoModel(model=model_dir, 153 | vad_model="fsmn-vad", 154 | vad_kwargs={"max_single_segment_time": 30000}, 155 | trust_remote_code=True, device="cuda:0") 156 | 157 | res = model.generate( 158 | input=audio_path, 159 | cache={}, 160 | language=language, # "zn", "en", "yue", "ja", "ko", "nospeech" 161 | use_itn=False, 162 | batch_size_s=0, 163 | ) 164 | 165 | try: 166 | 167 | text = res[0]["text"] 168 | text = format_str_v3(text) 169 | print(text) 170 | except Exception as e: 171 | print(e) 172 | text = "" 173 | 174 | 175 | return text,language 176 | 177 | 178 | if __name__ == "__main__": 179 | 180 | parser = argparse.ArgumentParser() 181 | 182 | parser.add_argument( 183 | "--language", type=str, default="ja", choices=["ja", "en", "zh","yue","ko"] 184 | ) 185 | parser.add_argument("--model_name", type=str, required=True) 186 | 187 | 188 | parser.add_argument("--input_file", type=str, default="./wavs/") 189 | 190 | parser.add_argument("--file_pos", type=str, default="") 191 | 192 | 193 | args = parser.parse_args() 194 | 195 | speaker_name = args.model_name 196 | 197 | language = args.language 198 | 199 | 200 | input_file = args.input_file 201 | 202 | if input_file == "": 203 | input_file = "./wavs/" 204 | 205 | file_pos = args.file_pos 206 | 207 | 208 | wav_files = [ 209 | f for f in os.listdir(f"{input_file}") if f.endswith(".wav") 210 | ] 211 | 212 | 213 | with open("./esd.list", "w", encoding="utf-8") as f: 214 | for wav_file in tqdm(wav_files, file=SAFE_STDOUT): 215 | file_name = 
os.path.basename(wav_file) 216 | 217 | text,lang = transcribe_one(f"{input_file}"+wav_file,language) 218 | 219 | # 使用正则表达式提取'deedee' 220 | match = re.search(r'(^.*?)_.*?(\..*?$)', wav_file) 221 | if match: 222 | extracted_name = match.group(1) + match.group(2) 223 | else: 224 | print("No match found") 225 | extracted_name = "sample" 226 | 227 | if lang == "ja": 228 | language_id = "JA" 229 | elif lang == "en": 230 | language_id = "EN" 231 | elif lang == "zh": 232 | language_id = "ZH" 233 | elif lang == "yue": 234 | language_id = "YUE" 235 | elif lang == "ko": 236 | language_id = "KO" 237 | 238 | f.write(file_pos+f"{file_name}|{extracted_name.replace('.wav','')}|{language_id}|{text}\n") 239 | 240 | f.flush() 241 | sys.exit(0) 242 | 243 | 244 | -------------------------------------------------------------------------------- /slicer2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # This function is obtained from librosa. 5 | def get_rms( 6 | y, 7 | *, 8 | frame_length=2048, 9 | hop_length=512, 10 | pad_mode="constant", 11 | ): 12 | padding = (int(frame_length // 2), int(frame_length // 2)) 13 | y = np.pad(y, padding, mode=pad_mode) 14 | 15 | axis = -1 16 | # put our new within-frame axis at the end for now 17 | out_strides = y.strides + tuple([y.strides[axis]]) 18 | # Reduce the shape on the framing axis 19 | x_shape_trimmed = list(y.shape) 20 | x_shape_trimmed[axis] -= frame_length - 1 21 | out_shape = tuple(x_shape_trimmed) + tuple([frame_length]) 22 | xw = np.lib.stride_tricks.as_strided( 23 | y, shape=out_shape, strides=out_strides 24 | ) 25 | if axis < 0: 26 | target_axis = axis - 1 27 | else: 28 | target_axis = axis + 1 29 | xw = np.moveaxis(xw, -1, target_axis) 30 | # Downsample along the target axis 31 | slices = [slice(None)] * xw.ndim 32 | slices[axis] = slice(0, None, hop_length) 33 | x = xw[tuple(slices)] 34 | 35 | # Calculate power 36 | power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True) 37 | 38 | return np.sqrt(power) 39 | 40 | 41 | class Slicer: 42 | def __init__(self, 43 | sr: int, 44 | threshold: float = -40., 45 | min_length: int = 5000, 46 | min_interval: int = 300, 47 | hop_size: int = 20, 48 | max_sil_kept: int = 5000): 49 | if not min_length >= min_interval >= hop_size: 50 | raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size') 51 | if not max_sil_kept >= hop_size: 52 | raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size') 53 | min_interval = sr * min_interval / 1000 54 | self.threshold = 10 ** (threshold / 20.) 
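# The dB threshold is converted to a linear amplitude (10 ** (dB / 20)) so it can be compared directly with the RMS values returned by get_rms(); the lines below convert the millisecond-based parameters into hop-sized frame counts.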
55 | self.hop_size = round(sr * hop_size / 1000) 56 | self.win_size = min(round(min_interval), 4 * self.hop_size) 57 | self.min_length = round(sr * min_length / 1000 / self.hop_size) 58 | self.min_interval = round(min_interval / self.hop_size) 59 | self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) 60 | 61 | def _apply_slice(self, waveform, begin, end): 62 | if len(waveform.shape) > 1: 63 | return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)] 64 | else: 65 | return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)] 66 | 67 | # @timeit 68 | def slice(self, waveform): 69 | if len(waveform.shape) > 1: 70 | samples = waveform.mean(axis=0) 71 | else: 72 | samples = waveform 73 | if (samples.shape[0] + self.hop_size - 1) // self.hop_size <= self.min_length: 74 | return [waveform] 75 | rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0) 76 | sil_tags = [] 77 | silence_start = None 78 | clip_start = 0 79 | for i, rms in enumerate(rms_list): 80 | # Keep looping while frame is silent. 81 | if rms < self.threshold: 82 | # Record start of silent frames. 83 | if silence_start is None: 84 | silence_start = i 85 | continue 86 | # Keep looping while frame is not silent and silence start has not been recorded. 87 | if silence_start is None: 88 | continue 89 | # Clear recorded silence start if interval is not enough or clip is too short 90 | is_leading_silence = silence_start == 0 and i > self.max_sil_kept 91 | need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length 92 | if not is_leading_silence and not need_slice_middle: 93 | silence_start = None 94 | continue 95 | # Need slicing. Record the range of silent frames to be removed. 96 | if i - silence_start <= self.max_sil_kept: 97 | pos = rms_list[silence_start: i + 1].argmin() + silence_start 98 | if silence_start == 0: 99 | sil_tags.append((0, pos)) 100 | else: 101 | sil_tags.append((pos, pos)) 102 | clip_start = pos 103 | elif i - silence_start <= self.max_sil_kept * 2: 104 | pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin() 105 | pos += i - self.max_sil_kept 106 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start 107 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept 108 | if silence_start == 0: 109 | sil_tags.append((0, pos_r)) 110 | clip_start = pos_r 111 | else: 112 | sil_tags.append((min(pos_l, pos), max(pos_r, pos))) 113 | clip_start = max(pos_r, pos) 114 | else: 115 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start 116 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept 117 | if silence_start == 0: 118 | sil_tags.append((0, pos_r)) 119 | else: 120 | sil_tags.append((pos_l, pos_r)) 121 | clip_start = pos_r 122 | silence_start = None 123 | # Deal with trailing silence. 124 | total_frames = rms_list.shape[0] 125 | if silence_start is not None and total_frames - silence_start >= self.min_interval: 126 | silence_end = min(total_frames, silence_start + self.max_sil_kept) 127 | pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start 128 | sil_tags.append((pos, total_frames + 1)) 129 | # Apply and return slices. 
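# If no qualifying silence was found, the original waveform is returned as a single chunk; otherwise the audio between consecutive silence tags is kept and the tagged silence regions are cut out.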
130 | if len(sil_tags) == 0: 131 | return [waveform] 132 | else: 133 | chunks = [] 134 | if sil_tags[0][0] > 0: 135 | chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0])) 136 | for i in range(len(sil_tags) - 1): 137 | chunks.append(self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0])) 138 | if sil_tags[-1][1] < total_frames: 139 | chunks.append(self._apply_slice(waveform, sil_tags[-1][1], total_frames)) 140 | return chunks 141 | 142 | 143 | def main(): 144 | import os.path 145 | from argparse import ArgumentParser 146 | 147 | import librosa 148 | import soundfile 149 | 150 | parser = ArgumentParser() 151 | parser.add_argument('audio', type=str, help='The audio to be sliced') 152 | parser.add_argument('--out', type=str, help='Output directory of the sliced audio clips') 153 | parser.add_argument('--db_thresh', type=float, required=False, default=-40, 154 | help='The dB threshold for silence detection') 155 | parser.add_argument('--min_length', type=int, required=False, default=5000, 156 | help='The minimum milliseconds required for each sliced audio clip') 157 | parser.add_argument('--min_interval', type=int, required=False, default=300, 158 | help='The minimum milliseconds for a silence part to be sliced') 159 | parser.add_argument('--hop_size', type=int, required=False, default=10, 160 | help='Frame length in milliseconds') 161 | parser.add_argument('--max_sil_kept', type=int, required=False, default=500, 162 | help='The maximum silence length kept around the sliced clip, presented in milliseconds') 163 | args = parser.parse_args() 164 | out = args.out 165 | if out is None: 166 | out = os.path.dirname(os.path.abspath(args.audio)) 167 | audio, sr = librosa.load(args.audio, sr=None, mono=False) 168 | slicer = Slicer( 169 | sr=sr, 170 | threshold=args.db_thresh, 171 | min_length=args.min_length, 172 | min_interval=args.min_interval, 173 | hop_size=args.hop_size, 174 | max_sil_kept=args.max_sil_kept 175 | ) 176 | chunks = slicer.slice(audio) 177 | if not os.path.exists(out): 178 | os.makedirs(out) 179 | for i, chunk in enumerate(chunks): 180 | if len(chunk.shape) > 1: 181 | chunk = chunk.T 182 | soundfile.write(os.path.join(out, f'%s_%d.wav' % (os.path.basename(args.audio).rsplit('.', maxsplit=1)[0], i)), chunk, sr) 183 | 184 | 185 | if __name__ == '__main__': 186 | main() -------------------------------------------------------------------------------- /webui_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import gradio as gr 5 | import yaml 6 | 7 | from common.log import logger 8 | from common.subprocess_utils import run_script_with_log 9 | 10 | from modelscope.pipelines import pipeline 11 | from modelscope.utils.constant import Tasks 12 | from videoclipper import VideoClipper 13 | import librosa 14 | import soundfile as sf 15 | import numpy as np 16 | import random 17 | 18 | dataset_root = ".\\raw\\" 19 | 20 | 21 | 22 | sd_pipeline = pipeline( 23 | task='speaker-diarization', 24 | model='damo/speech_campplus_speaker-diarization_common', 25 | model_revision='v1.0.0' 26 | ) 27 | 28 | def audio_change(audio): 29 | 30 | print(audio) 31 | 32 | sf.write('./output_44100.wav', audio[1], audio[0], 'PCM_24') 33 | 34 | y, sr = librosa.load('./output_44100.wav', sr=16000) 35 | 36 | # sf.write('./output_16000.wav', y, sr, 'PCM_24') 37 | 38 | # arr = np.array(y, dtype=np.int32) 39 | 40 | # y, sr = librosa.load('./output_16000.wav', sr=16000) 41 | 42 | audio_data = np.array(y) 43 | 44 | print(y, sr) 
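    # audio_change takes the (sample_rate, ndarray) tuple a gr.Audio component produces,
    # round-trips it through output_44100.wav, and reloads it with librosa at sr=16000,
    # so the value returned below is 16 kHz mono audio of the kind VideoClipper.recog
    # asserts on (sr == 16000).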
45 | 46 | return (16000,audio_data) 47 | 48 | def write_list(text,audio): 49 | 50 | random_number = random.randint(10000, 99999) 51 | 52 | wav_name = f'./wavs/sample_{random_number}.wav' 53 | 54 | sf.write(wav_name, audio[1], audio[0], 'PCM_24') 55 | 56 | text = text.replace("#",",") 57 | 58 | with open("./esd.list","a",encoding="utf-8")as f:f.write(f"\n{wav_name}|sample|en|{text}") 59 | 60 | 61 | 62 | 63 | def audio_recog(audio_input, sd_switch): 64 | print(audio_input) 65 | return audio_clipper.recog(audio_input, sd_switch) 66 | 67 | def audio_clip(dest_text, audio_spk_input, start_ost, end_ost, state): 68 | return audio_clipper.clip(dest_text, start_ost, end_ost, state, dest_spk=audio_spk_input) 69 | 70 | # 音频降噪 71 | 72 | def reset_tts_wav(audio): 73 | 74 | ans = pipeline( 75 | Tasks.acoustic_noise_suppression, 76 | model='damo/speech_frcrn_ans_cirm_16k') 77 | ans(audio,output_path='./output_ins.wav') 78 | 79 | return "./output_ins.wav","./output_ins.wav" 80 | 81 | 82 | def do_slice( 83 | dataset_path: str, 84 | min_sec: int, 85 | max_sec: int, 86 | min_silence_dur_ms: int, 87 | ): 88 | if dataset_path == "": 89 | return "Error: 数据集路径不能为空" 90 | logger.info("Start slicing...") 91 | output_dir = os.path.join(dataset_root, dataset_path, ".\\wavs") 92 | 93 | 94 | cmd = [ 95 | "audio_slicer_pre.py", 96 | "--dataset_path", 97 | dataset_path, 98 | "--min_sec", 99 | str(min_sec), 100 | "--max_sec", 101 | str(max_sec), 102 | "--min_silence_dur_ms", 103 | str(min_silence_dur_ms), 104 | ] 105 | 106 | 107 | success, message = run_script_with_log(cmd, ignore_warning=True) 108 | if not success: 109 | return f"Error: {message}" 110 | return "切分完毕" 111 | 112 | 113 | def do_transcribe_fwhisper( 114 | model_name,mytype,language,input_file,file_pos 115 | ): 116 | # if model_name == "": 117 | # return "Error: 角色名不能为空" 118 | 119 | 120 | cmd_py = "short_audio_transcribe_fwhisper.py" 121 | 122 | 123 | success, message = run_script_with_log( 124 | [ 125 | cmd_py, 126 | "--model_name", 127 | model_name, 128 | "--language", 129 | language, 130 | "--mytype", 131 | mytype,"--input_file", 132 | input_file, 133 | "--file_pos", 134 | file_pos, 135 | 136 | ] 137 | ) 138 | if not success: 139 | return f"Error: {message}" 140 | return "转写完毕" 141 | 142 | def do_transcribe_whisper( 143 | model_name,mytype,language,input_file,file_pos 144 | ): 145 | # if model_name == "": 146 | # return "Error: 角色名不能为空" 147 | 148 | 149 | cmd_py = "short_audio_transcribe_whisper.py" 150 | 151 | 152 | success, message = run_script_with_log( 153 | [ 154 | cmd_py, 155 | "--model_name", 156 | model_name, 157 | "--language", 158 | language, 159 | "--mytype", 160 | mytype,"--input_file", 161 | input_file, 162 | "--file_pos", 163 | file_pos, 164 | 165 | ] 166 | ) 167 | if not success: 168 | return f"Error: {message}" 169 | return "转写完毕" 170 | 171 | 172 | def do_transcribe_all( 173 | model_name,mytype,language,input_file,file_pos 174 | ): 175 | # if model_name == "": 176 | # return "Error: 角色名不能为空" 177 | 178 | 179 | cmd_py = "short_audio_transcribe_ali.py" 180 | 181 | 182 | if mytype == "bcut": 183 | 184 | cmd_py = "short_audio_transcribe_bcut.py" 185 | 186 | success, message = run_script_with_log( 187 | [ 188 | cmd_py, 189 | "--model_name", 190 | model_name, 191 | "--language", 192 | language, 193 | "--input_file", 194 | input_file, 195 | "--file_pos", 196 | file_pos, 197 | 198 | ] 199 | ) 200 | if not success: 201 | return f"Error: {message}" 202 | return "转写完毕" 203 | 204 | 205 | initial_md = """ 206 | 207 | 请把格式为 角色名.wav 的素材文件放入项目的raw目录 208 | 
209 | 作者:刘悦的技术博客 https://space.bilibili.com/3031494 210 | 211 | """ 212 | 213 | with gr.Blocks(theme="NoCrypt/miku") as app: 214 | gr.Markdown(initial_md) 215 | model_name = gr.Textbox(label="角色名",placeholder="请输入角色名",visible=False) 216 | 217 | 218 | with gr.Accordion("干声抽离和降噪"): 219 | with gr.Row(): 220 | audio_inp_path = gr.Audio(label="请上传克隆对象音频", type="filepath") 221 | reset_inp_button = gr.Button("针对原始素材进行降噪", variant="primary",visible=True) 222 | reset_dataset_path = gr.Textbox(label="降噪后音频地址",placeholder="降噪后生成的音频地址") 223 | 224 | 225 | reset_inp_button.click(reset_tts_wav,[audio_inp_path],[audio_inp_path,reset_dataset_path]) 226 | 227 | with gr.Accordion("音频素材切割"): 228 | with gr.Row(): 229 | ##add by hyh 添加一个数据集路径的文本框 230 | dataset_path = gr.Textbox(label="音频素材所在路径,默认在项目的raw文件夹,支持批量角色切分",placeholder="设置音频素材所在路径",value="./raw/") 231 | with gr.Column(): 232 | 233 | min_sec = gr.Slider( 234 | minimum=0, maximum=7000, value=2500, step=100, label="最低几毫秒" 235 | ) 236 | max_sec = gr.Slider( 237 | minimum=0, maximum=15000, value=5000, step=100, label="最高几毫秒" 238 | ) 239 | min_silence_dur_ms = gr.Slider( 240 | minimum=500, 241 | maximum=5000, 242 | value=500, 243 | step=100, 244 | label="max_sil_kept长度", 245 | ) 246 | slice_button = gr.Button("开始切分") 247 | result1 = gr.Textbox(label="結果") 248 | 249 | 250 | 251 | 252 | with gr.Accordion("音频批量转写,转写文件存放在根目录的est.list"): 253 | with gr.Row(): 254 | with gr.Column(): 255 | 256 | language = gr.Dropdown(["ja", "en", "zh","ko","yue"], value="zh", label="选择转写的语言") 257 | 258 | mytype = gr.Dropdown(["small","medium","large-v3","large-v2"], value="medium", label="选择Whisper模型") 259 | 260 | input_file = gr.Textbox(label="切片所在目录",placeholder="不填默认为./wavs目录") 261 | 262 | file_pos = gr.Textbox(label="切片名称前缀",placeholder="不填只有切片文件名") 263 | 264 | transcribe_button_whisper = gr.Button("Whisper开始转写") 265 | 266 | transcribe_button_fwhisper = gr.Button("Faster-Whisper开始转写") 267 | 268 | transcribe_button_ali = gr.Button("阿里SenseVoice开始转写") 269 | 270 | transcribe_button_bcut = gr.Button("必剪ASR开始转写") 271 | 272 | 273 | result2 = gr.Textbox(label="結果") 274 | 275 | slice_button.click( 276 | do_slice, 277 | inputs=[dataset_path, min_sec, max_sec, min_silence_dur_ms], 278 | outputs=[result1], 279 | ) 280 | transcribe_button_whisper.click( 281 | do_transcribe_whisper, 282 | inputs=[ 283 | model_name, 284 | mytype, 285 | language,input_file,file_pos 286 | ], 287 | outputs=[result2],) 288 | 289 | 290 | transcribe_button_fwhisper.click( 291 | do_transcribe_fwhisper, 292 | inputs=[ 293 | model_name, 294 | mytype, 295 | language,input_file,file_pos 296 | ], 297 | outputs=[result2],) 298 | 299 | 300 | ali = gr.Text(value="ali",visible=False) 301 | 302 | bcut = gr.Text(value="bcut",visible=False) 303 | 304 | 305 | transcribe_button_ali.click( 306 | do_transcribe_all, 307 | inputs=[ 308 | model_name, 309 | ali, 310 | language,input_file,file_pos 311 | ], 312 | outputs=[result2], 313 | ) 314 | 315 | transcribe_button_bcut.click( 316 | do_transcribe_all, 317 | inputs=[ 318 | model_name, 319 | bcut, 320 | language,input_file,file_pos 321 | ], 322 | outputs=[result2], 323 | ) 324 | 325 | parser = argparse.ArgumentParser() 326 | parser.add_argument( 327 | "--server-name", 328 | type=str, 329 | default=None, 330 | help="Server name for Gradio app", 331 | ) 332 | parser.add_argument( 333 | "--no-autolaunch", 334 | action="store_true", 335 | default=False, 336 | help="Do not launch app automatically", 337 | ) 338 | args = parser.parse_args() 339 | 340 | app.launch(inbrowser=not 
args.no_autolaunch, server_name=args.server_name, server_port=7971) 341 | -------------------------------------------------------------------------------- /common/tts_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gradio as gr 3 | import torch 4 | import os 5 | import warnings 6 | from gradio.processing_utils import convert_to_16_bit_wav 7 | from typing import Dict, List, Optional 8 | 9 | import utils 10 | from infer import get_net_g, infer 11 | from models import SynthesizerTrn 12 | 13 | from .log import logger 14 | from .constants import ( 15 | DEFAULT_ASSIST_TEXT_WEIGHT, 16 | DEFAULT_LENGTH, 17 | DEFAULT_LINE_SPLIT, 18 | DEFAULT_NOISE, 19 | DEFAULT_NOISEW, 20 | DEFAULT_SDP_RATIO, 21 | DEFAULT_SPLIT_INTERVAL, 22 | DEFAULT_STYLE, 23 | DEFAULT_STYLE_WEIGHT, 24 | ) 25 | 26 | 27 | class Model: 28 | def __init__( 29 | self, model_path: str, config_path: str, style_vec_path: str, device: str 30 | ): 31 | self.model_path: str = model_path 32 | self.config_path: str = config_path 33 | self.device: str = device 34 | self.style_vec_path: str = style_vec_path 35 | self.hps: utils.HParams = utils.get_hparams_from_file(self.config_path) 36 | self.spk2id: Dict[str, int] = self.hps.data.spk2id 37 | self.id2spk: Dict[int, str] = {v: k for k, v in self.spk2id.items()} 38 | 39 | self.num_styles: int = self.hps.data.num_styles 40 | if hasattr(self.hps.data, "style2id"): 41 | self.style2id: Dict[str, int] = self.hps.data.style2id 42 | else: 43 | self.style2id: Dict[str, int] = {str(i): i for i in range(self.num_styles)} 44 | if len(self.style2id) != self.num_styles: 45 | raise ValueError( 46 | f"Number of styles ({self.num_styles}) does not match the number of style2id ({len(self.style2id)})" 47 | ) 48 | 49 | self.style_vectors: np.ndarray = np.load(self.style_vec_path) 50 | if self.style_vectors.shape[0] != self.num_styles: 51 | raise ValueError( 52 | f"The number of styles ({self.num_styles}) does not match the number of style vectors ({self.style_vectors.shape[0]})" 53 | ) 54 | 55 | self.net_g: Optional[SynthesizerTrn] = None 56 | 57 | def load_net_g(self): 58 | self.net_g = get_net_g( 59 | model_path=self.model_path, 60 | version=self.hps.version, 61 | device=self.device, 62 | hps=self.hps, 63 | ) 64 | 65 | def get_style_vector(self, style_id: int, weight: float = 1.0) -> np.ndarray: 66 | mean = self.style_vectors[0] 67 | style_vec = self.style_vectors[style_id] 68 | style_vec = mean + (style_vec - mean) * weight 69 | return style_vec 70 | 71 | def get_style_vector_from_audio( 72 | self, audio_path: str, weight: float = 1.0 73 | ) -> np.ndarray: 74 | from style_gen import get_style_vector 75 | 76 | xvec = get_style_vector(audio_path) 77 | mean = self.style_vectors[0] 78 | xvec = mean + (xvec - mean) * weight 79 | return xvec 80 | 81 | def infer( 82 | self, 83 | text: str, 84 | language: str = "JP", 85 | sid: int = 0, 86 | reference_audio_path: Optional[str] = None, 87 | sdp_ratio: float = DEFAULT_SDP_RATIO, 88 | noise: float = DEFAULT_NOISE, 89 | noisew: float = DEFAULT_NOISEW, 90 | length: float = DEFAULT_LENGTH, 91 | line_split: bool = DEFAULT_LINE_SPLIT, 92 | split_interval: float = DEFAULT_SPLIT_INTERVAL, 93 | assist_text: Optional[str] = None, 94 | assist_text_weight: float = DEFAULT_ASSIST_TEXT_WEIGHT, 95 | use_assist_text: bool = False, 96 | style: str = DEFAULT_STYLE, 97 | style_weight: float = DEFAULT_STYLE_WEIGHT, 98 | given_tone: Optional[list[int]] = None, 99 | ) -> tuple[int, np.ndarray]: 100 | logger.info(f"Start 
generating audio data from text:\n{text}") 101 | if reference_audio_path == "": 102 | reference_audio_path = None 103 | if assist_text == "" or not use_assist_text: 104 | assist_text = None 105 | 106 | if self.net_g is None: 107 | self.load_net_g() 108 | if reference_audio_path is None: 109 | style_id = self.style2id[style] 110 | style_vector = self.get_style_vector(style_id, style_weight) 111 | else: 112 | style_vector = self.get_style_vector_from_audio( 113 | reference_audio_path, style_weight 114 | ) 115 | if not line_split: 116 | with torch.no_grad(): 117 | audio = infer( 118 | text=text, 119 | sdp_ratio=sdp_ratio, 120 | noise_scale=noise, 121 | noise_scale_w=noisew, 122 | length_scale=length, 123 | sid=sid, 124 | language=language, 125 | hps=self.hps, 126 | net_g=self.net_g, 127 | device=self.device, 128 | assist_text=assist_text, 129 | assist_text_weight=assist_text_weight, 130 | style_vec=style_vector, 131 | given_tone=given_tone, 132 | ) 133 | else: 134 | texts = text.split("\n") 135 | texts = [t for t in texts if t != ""] 136 | audios = [] 137 | with torch.no_grad(): 138 | for i, t in enumerate(texts): 139 | audios.append( 140 | infer( 141 | text=t, 142 | sdp_ratio=sdp_ratio, 143 | noise_scale=noise, 144 | noise_scale_w=noisew, 145 | length_scale=length, 146 | sid=sid, 147 | language=language, 148 | hps=self.hps, 149 | net_g=self.net_g, 150 | device=self.device, 151 | assist_text=assist_text, 152 | assist_text_weight=assist_text_weight, 153 | style_vec=style_vector, 154 | ) 155 | ) 156 | if i != len(texts) - 1: 157 | audios.append(np.zeros(int(44100 * split_interval))) 158 | audio = np.concatenate(audios) 159 | with warnings.catch_warnings(): 160 | warnings.simplefilter("ignore") 161 | audio = convert_to_16_bit_wav(audio) 162 | logger.info("Audio data generated successfully") 163 | return (self.hps.data.sampling_rate, audio) 164 | 165 | 166 | class ModelHolder: 167 | def __init__(self, root_dir: str, device: str): 168 | self.root_dir: str = root_dir 169 | self.device: str = device 170 | self.model_files_dict: Dict[str, List[str]] = {} 171 | self.current_model: Optional[Model] = None 172 | self.model_names: List[str] = [] 173 | self.models: List[Model] = [] 174 | self.refresh() 175 | 176 | def refresh(self): 177 | self.model_files_dict = {} 178 | self.model_names = [] 179 | self.current_model = None 180 | model_dirs = [ 181 | d 182 | for d in os.listdir(self.root_dir) 183 | if os.path.isdir(os.path.join(self.root_dir, d)) 184 | ] 185 | for model_name in model_dirs: 186 | model_dir = os.path.join(self.root_dir, model_name) 187 | model_files = [ 188 | os.path.join(model_dir, f) 189 | for f in os.listdir(model_dir) 190 | if f.endswith(".pth") or f.endswith(".pt") or f.endswith(".safetensors") 191 | ] 192 | if len(model_files) == 0: 193 | logger.warning( 194 | f"No model files found in {self.root_dir}/{model_name}, so skip it" 195 | ) 196 | continue 197 | self.model_files_dict[model_name] = model_files 198 | self.model_names.append(model_name) 199 | 200 | def load_model_gr( 201 | self, model_name: str, model_path: str 202 | ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]: 203 | if model_name not in self.model_files_dict: 204 | raise ValueError(f"Model `{model_name}` is not found") 205 | if model_path not in self.model_files_dict[model_name]: 206 | raise ValueError(f"Model file `{model_path}` is not found") 207 | self.current_model = Model( 208 | model_path=model_path, 209 | config_path=os.path.join(self.root_dir, model_name, "config.json"), 210 | 
style_vec_path=os.path.join(self.root_dir, model_name, "style_vectors.npy"), 211 | device=self.device, 212 | ) 213 | speakers = list(self.current_model.spk2id.keys()) 214 | styles = list(self.current_model.style2id.keys()) 215 | return ( 216 | gr.Dropdown(choices=styles, value=styles[0]), 217 | gr.Button(interactive=True, value="音声合成"), 218 | gr.Dropdown(choices=speakers, value=speakers[0]), 219 | ) 220 | 221 | def update_model_files_gr(self, model_name: str) -> gr.Dropdown: 222 | model_files = self.model_files_dict[model_name] 223 | return gr.Dropdown(choices=model_files, value=model_files[0]) 224 | 225 | def update_model_names_gr(self) -> tuple[gr.Dropdown, gr.Dropdown, gr.Button]: 226 | self.refresh() 227 | initial_model_name = self.model_names[0] 228 | initial_model_files = self.model_files_dict[initial_model_name] 229 | return ( 230 | gr.Dropdown(choices=self.model_names, value=initial_model_name), 231 | gr.Dropdown(choices=initial_model_files, value=initial_model_files[0]), 232 | gr.Button(interactive=False), # For tts_button 233 | ) 234 | -------------------------------------------------------------------------------- /bcut_asr/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from os import PathLike 4 | from pathlib import Path 5 | from typing import Literal, Optional 6 | import requests 7 | import sys 8 | import ffmpeg 9 | from .orm import (ResourceCompleteRspSchema, ResourceCreateRspSchema, 10 | ResultRspSchema, ResultStateEnum, TaskCreateRspSchema) 11 | from typing import Optional, Union 12 | from os import PathLike 13 | 14 | __version__ = '0.0.2' 15 | 16 | API_REQ_UPLOAD = 'https://member.bilibili.com/x/bcut/rubick-interface/resource/create' # 申请上传 17 | API_COMMIT_UPLOAD = 'https://member.bilibili.com/x/bcut/rubick-interface/resource/create/complete' # 提交上传 18 | API_CREATE_TASK = 'https://member.bilibili.com/x/bcut/rubick-interface/task' # 创建任务 19 | API_QUERY_RESULT = 'https://member.bilibili.com/x/bcut/rubick-interface/task/result' # 查询结果 20 | 21 | SUPPORT_SOUND_FORMAT = Literal['flac', 'aac', 'm4a', 'mp3', 'wav'] 22 | 23 | INFILE_FMT = ('flac', 'aac', 'm4a', 'mp3', 'wav') 24 | OUTFILE_FMT = ('srt', 'json', 'lrc', 'txt') 25 | 26 | 27 | def ffmpeg_render(media_file: str) -> bytes: 28 | '提取视频伴音并转码为aac格式' 29 | out, err = (ffmpeg 30 | .input(media_file, v='warning') 31 | .output('pipe:', ac=1, format='adts') 32 | .run(capture_stdout=True) 33 | ) 34 | return out 35 | 36 | 37 | def run_everywhere(argg): 38 | logging.basicConfig(format='%(asctime)s - [%(levelname)s] %(message)s', level=logging.INFO) 39 | # 处理输入文件情况 40 | infile = argg.input 41 | infile_name = infile.name 42 | if infile_name == '': 43 | logging.error('输入文件错误') 44 | sys.exit(-1) 45 | suffix = infile_name.rsplit('.', 1)[-1] 46 | if suffix in INFILE_FMT: 47 | infile_fmt = suffix 48 | infile_data = infile.read() 49 | else: 50 | # ffmpeg分离视频伴音 51 | logging.info('非标准音频文件, 尝试调用ffmpeg转码') 52 | try: 53 | infile_data = ffmpeg_render(infile_name) 54 | except ffmpeg.Error: 55 | logging.error('ffmpeg转码失败') 56 | sys.exit(-1) 57 | else: 58 | logging.info('ffmpeg转码完成') 59 | infile_fmt = 'aac' 60 | 61 | # 处理输出文件情况 62 | outfile = argg.output 63 | if outfile is None: 64 | # 未指定输出文件,默认为文件名同输入,可以 -t 传参,默认str格式 65 | if argg.format is not None: 66 | outfile_fmt = argg.format 67 | else: 68 | outfile_fmt = 'srt' 69 | else: 70 | # 指定输出文件 71 | outfile_name = outfile.name 72 | if outfile.name == '': 73 | # stdout情况,可以 -t 传参,默认str格式 74 | if argg.format is not None: 
75 | outfile_fmt = argg.format 76 | else: 77 | outfile_fmt = 'srt' 78 | else: 79 | suffix = outfile_name.rsplit('.', 1)[-1] 80 | if suffix in OUTFILE_FMT: 81 | outfile_fmt = suffix 82 | else: 83 | logging.error('输出格式错误') 84 | sys.exit(-1) 85 | 86 | # 开始执行转换逻辑 87 | asr = BcutASR() 88 | asr.set_data(raw_data=infile_data, data_fmt=infile_fmt) 89 | try: 90 | # 上传文件 91 | asr.upload() 92 | # 创建任务 93 | task_id = asr.create_task() 94 | while True: 95 | # 轮询检查任务状态 96 | task_resp = asr.result() 97 | # match task_resp.state: 98 | # case ResultStateEnum.STOP: 99 | # logging.info(f'等待识别开始') 100 | # case ResultStateEnum.RUNING: 101 | # logging.info(f'识别中-{task_resp.remark}') 102 | # case ResultStateEnum.ERROR: 103 | # logging.error(f'识别失败-{task_resp.remark}') 104 | # sys.exit(-1) 105 | # case ResultStateEnum.COMPLETE: 106 | # outfile_name = f"{infile_name.rsplit('.', 1)[-2]}.{outfile_fmt}" 107 | # outfile = open(outfile_name, 'w', encoding='utf8') 108 | # logging.info(f'识别成功') 109 | # # 识别成功, 回读字幕数据 110 | # result = task_resp.parse() 111 | # break 112 | 113 | if task_resp.state == ResultStateEnum.STOP: 114 | logging.info(f'等待识别开始') 115 | elif task_resp.state == ResultStateEnum.RUNING: 116 | logging.info(f'识别中-{task_resp.remark}') 117 | elif task_resp.state == ResultStateEnum.ERROR: 118 | logging.error(f'识别失败-{task_resp.remark}') 119 | sys.exit(-1) 120 | elif task_resp.state == ResultStateEnum.COMPLETE: 121 | outfile_name = f"{infile_name.rsplit('.', 1)[-2]}.{outfile_fmt}" 122 | outfile = open(outfile_name, 'w', encoding='utf8') 123 | logging.info(f'识别成功') 124 | # 识别成功, 回读字幕数据 125 | result = task_resp.parse() 126 | break 127 | 128 | 129 | time.sleep(300.0) 130 | if not result.has_data(): 131 | logging.error('未识别到语音') 132 | sys.exit(-1) 133 | # match outfile_fmt: 134 | # case 'srt': 135 | # outfile.write(result.to_srt()) 136 | # case 'lrc': 137 | # outfile.write(result.to_lrc()) 138 | # case 'json': 139 | # outfile.write(result.json()) 140 | # case 'txt': 141 | # outfile.write(result.to_txt()) 142 | if outfile_fmt == 'srt': 143 | outfile.write(result.to_srt()) 144 | elif outfile_fmt == 'lrc': 145 | outfile.write(result.to_lrc()) 146 | elif outfile_fmt == 'json': 147 | outfile.write(result.json()) 148 | elif outfile_fmt == 'txt': 149 | outfile.write(result.to_txt()) 150 | 151 | logging.info(f'转换成功: {outfile_name}') 152 | except APIError as err: 153 | logging.error(f'接口错误: {err.__str__()}') 154 | sys.exit(-1) 155 | 156 | class APIError(Exception): 157 | '接口调用错误' 158 | def __init__(self, code, msg) -> None: 159 | self.code = code 160 | self.msg = msg 161 | super().__init__() 162 | def __str__(self) -> str: 163 | return f'{self.code}:{self.msg}' 164 | 165 | class BcutASR: 166 | '必剪 语音识别接口' 167 | session: requests.Session 168 | sound_name: str 169 | sound_bin: bytes 170 | sound_fmt: SUPPORT_SOUND_FORMAT 171 | __in_boss_key: str 172 | __resource_id: str 173 | __upload_id: str 174 | __upload_urls: list[str] 175 | __per_size: int 176 | __clips: int 177 | __etags: list[str] 178 | __download_url: str 179 | task_id: str 180 | 181 | # def __init__(self, file: Optional[str | PathLike] = None) -> None: 182 | def __init__(self, file: Optional[Union[str, PathLike]] = None) -> None: 183 | self.session = requests.Session() 184 | self.task_id = None 185 | self.__etags = [] 186 | if file: 187 | self.set_data(file) 188 | 189 | def set_data(self, 190 | # file: Optional[str | PathLike] = None, 191 | file: Optional[Union[str, PathLike]] = None, 192 | raw_data: Optional[bytes] = None, 193 | data_fmt: 
Optional[SUPPORT_SOUND_FORMAT] = None 194 | ) -> None: 195 | '设置欲识别的数据' 196 | if file: 197 | if not isinstance(file, (str, PathLike)): 198 | raise TypeError('unknow file ptr') 199 | # 文件类 200 | file = Path(file) 201 | self.sound_bin = open(file, 'rb').read() 202 | suffix = data_fmt or file.suffix[1:] 203 | self.sound_name = file.name 204 | elif raw_data: 205 | # bytes类 206 | self.sound_bin = raw_data 207 | suffix = data_fmt 208 | self.sound_name = f'{int(time.time())}.{suffix}' 209 | else: 210 | raise ValueError('none set data') 211 | if suffix not in SUPPORT_SOUND_FORMAT.__args__: 212 | raise TypeError('format is not support') 213 | self.sound_fmt = suffix 214 | logging.info(f'加载文件成功: {self.sound_name}') 215 | 216 | def upload(self) -> None: 217 | '申请上传' 218 | if not self.sound_bin or not self.sound_fmt: 219 | raise ValueError('none set data') 220 | resp = self.session.post(API_REQ_UPLOAD, data={ 221 | 'type': 2, 222 | 'name': self.sound_name, 223 | 'size': len(self.sound_bin), 224 | 'resource_file_type': self.sound_fmt, 225 | 'model_id': 7 226 | }) 227 | resp.raise_for_status() 228 | resp = resp.json() 229 | code = resp['code'] 230 | if code: 231 | raise APIError(code, resp['message']) 232 | resp_data = ResourceCreateRspSchema.parse_obj(resp['data']) 233 | self.__in_boss_key = resp_data.in_boss_key 234 | self.__resource_id = resp_data.resource_id 235 | self.__upload_id = resp_data.upload_id 236 | self.__upload_urls = resp_data.upload_urls 237 | self.__per_size = resp_data.per_size 238 | self.__clips = len(resp_data.upload_urls) 239 | logging.info(f'申请上传成功, 总计大小{resp_data.size // 1024}KB, {self.__clips}分片, 分片大小{resp_data.per_size // 1024}KB: {self.__in_boss_key}') 240 | self.__upload_part() 241 | self.__commit_upload() 242 | 243 | def __upload_part(self) -> None: 244 | '上传音频数据' 245 | for clip in range(self.__clips): 246 | start_range = clip * self.__per_size 247 | end_range = (clip + 1) * self.__per_size 248 | logging.info(f'开始上传分片{clip}: {start_range}-{end_range}') 249 | resp = self.session.put(self.__upload_urls[clip], 250 | data=self.sound_bin[start_range:end_range], 251 | ) 252 | resp.raise_for_status() 253 | etag = resp.headers.get('Etag') 254 | self.__etags.append(etag) 255 | logging.info(f'分片{clip}上传成功: {etag}') 256 | 257 | def __commit_upload(self) -> None: 258 | '提交上传数据' 259 | resp = self.session.post(API_COMMIT_UPLOAD, data={ 260 | 'in_boss_key': self.__in_boss_key, 261 | 'resource_id': self.__resource_id, 262 | 'etags': ','.join(self.__etags), 263 | 'upload_id': self.__upload_id, 264 | 'model_id': 7 265 | }) 266 | resp.raise_for_status() 267 | resp = resp.json() 268 | code = resp['code'] 269 | if code: 270 | raise APIError(code, resp['message']) 271 | resp_data = ResourceCompleteRspSchema.parse_obj(resp['data']) 272 | self.__download_url = resp_data.download_url 273 | logging.info(f'提交成功') 274 | 275 | def create_task(self) -> str: 276 | '开始创建转换任务' 277 | resp = self.session.post(API_CREATE_TASK, json={ 278 | 'resource': self.__download_url, 279 | 'model_id': '7' 280 | }) 281 | resp.raise_for_status() 282 | resp = resp.json() 283 | code = resp['code'] 284 | if code: 285 | raise APIError(code, resp['message']) 286 | resp_data = TaskCreateRspSchema.parse_obj(resp['data']) 287 | self.task_id = resp_data.task_id 288 | logging.info(f'任务已创建: {self.task_id}') 289 | return self.task_id 290 | 291 | def result(self, task_id: Optional[str] = None) -> ResultRspSchema: 292 | '查询转换结果' 293 | resp = self.session.get(API_QUERY_RESULT, params={ 294 | 'model_id': 7, 295 | 'task_id': task_id or 
self.task_id 296 | }) 297 | resp.raise_for_status() 298 | resp = resp.json() 299 | code = resp['code'] 300 | if code: 301 | raise APIError(code, resp['message']) 302 | return ResultRspSchema.parse_obj(resp['data']) 303 | -------------------------------------------------------------------------------- /videoclipper.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import librosa 4 | import logging 5 | import argparse 6 | import numpy as np 7 | import soundfile as sf 8 | import moviepy.editor as mpy 9 | # from modelscope.pipelines import pipeline 10 | # from modelscope.utils.constant import Tasks 11 | from subtitle_utils import generate_srt, generate_srt_clip, distribute_spk 12 | from trans_utils import pre_proc, proc, write_state, load_state, proc_spk, generate_vad_data 13 | from argparse_tools import ArgumentParser, get_commandline_args 14 | 15 | from moviepy.editor import * 16 | from moviepy.video.tools.subtitles import SubtitlesClip 17 | 18 | 19 | class VideoClipper(): 20 | def __init__(self, asr_pipeline, sd_pipeline=None): 21 | logging.warning("Initializing VideoClipper.") 22 | self.asr_pipeline = asr_pipeline 23 | self.sd_pipeline = sd_pipeline 24 | 25 | def recog(self, audio_input, sd_switch='no', state=None): 26 | if state is None: 27 | state = {} 28 | sr, data = audio_input 29 | assert sr == 16000, "16kHz sample rate required, {} given.".format(sr) 30 | if len(data.shape) == 2: # multi-channel wav input 31 | logging.warning("Input wav shape: {}, only first channel reserved.").format(data.shape) 32 | data = data[:,0] 33 | state['audio_input'] = (sr, data) 34 | data = data.astype(np.float64) 35 | rec_result = self.asr_pipeline(audio_in=data) 36 | if sd_switch == 'yes': 37 | vad_data = generate_vad_data(data.astype(np.float32), rec_result['sentences'], sr) 38 | sd_result = self.sd_pipeline(audio=vad_data, batch_size=1) 39 | rec_result['sd_sentences'] = distribute_spk(rec_result['sentences'], sd_result['text']) 40 | res_srt = generate_srt(rec_result['sd_sentences']) 41 | state['sd_sentences'] = rec_result['sd_sentences'] 42 | else: 43 | res_srt = generate_srt(rec_result['sentences']) 44 | state['recog_res_raw'] = rec_result['text_postprocessed'] 45 | state['timestamp'] = rec_result['time_stamp'] 46 | state['sentences'] = rec_result['sentences'] 47 | res_text = rec_result['text'] 48 | return res_text, res_srt, state 49 | 50 | def clip(self, dest_text, start_ost, end_ost, state, dest_spk=None): 51 | # get from state 52 | audio_input = state['audio_input'] 53 | recog_res_raw = state['recog_res_raw'] 54 | timestamp = state['timestamp'] 55 | sentences = state['sentences'] 56 | sr, data = audio_input 57 | data = data.astype(np.float64) 58 | 59 | all_ts = [] 60 | if dest_spk is None or dest_spk == '' or 'sd_sentences' not in state: 61 | for _dest_text in dest_text.split('#'): 62 | _dest_text = pre_proc(_dest_text) 63 | ts = proc(recog_res_raw, timestamp, _dest_text) 64 | for _ts in ts: all_ts.append(_ts) 65 | else: 66 | for _dest_spk in dest_spk.split('#'): 67 | ts = proc_spk(_dest_spk, state['sd_sentences']) 68 | for _ts in ts: all_ts.append(_ts) 69 | ts = all_ts 70 | # ts.sort() 71 | srt_index = 0 72 | clip_srt = "" 73 | if len(ts): 74 | start, end = ts[0] 75 | start = min(max(0, start+start_ost*16), len(data)) 76 | end = min(max(0, end+end_ost*16), len(data)) 77 | res_audio = data[start:end] 78 | start_end_info = "from {} to {}".format(start/16000, end/16000) 79 | srt_clip, _, srt_index = generate_srt_clip(sentences, 
start/16000.0, end/16000.0, begin_index=srt_index) 80 | clip_srt += srt_clip 81 | for _ts in ts[1:]: # multiple sentence input or multiple output matched 82 | start, end = _ts 83 | start = min(max(0, start+start_ost*16), len(data)) 84 | end = min(max(0, end+end_ost*16), len(data)) 85 | start_end_info += ", from {} to {}".format(start, end) 86 | res_audio = np.concatenate([res_audio, data[start+start_ost*16:end+end_ost*16]], -1) 87 | srt_clip, _, srt_index = generate_srt_clip(sentences, start/16000.0, end/16000.0, begin_index=srt_index-1) 88 | clip_srt += srt_clip 89 | if len(ts): 90 | message = "{} periods found in the speech: ".format(len(ts)) + start_end_info 91 | else: 92 | message = "No period found in the speech, return raw speech. You may check the recognition result and try other destination text." 93 | res_audio = data 94 | return (sr, res_audio), message, clip_srt 95 | 96 | def video_recog(self, vedio_filename, sd_switch='no'): 97 | vedio_filename = vedio_filename 98 | clip_video_file = vedio_filename[:-4] + '_clip.mp4' 99 | video = mpy.VideoFileClip(vedio_filename) 100 | audio_file = vedio_filename[:-3] + 'wav' 101 | video.audio.write_audiofile(audio_file) 102 | wav = librosa.load(audio_file, sr=16000)[0] 103 | state = { 104 | 'vedio_filename': vedio_filename, 105 | 'clip_video_file': clip_video_file, 106 | 'video': video, 107 | } 108 | # res_text, res_srt = self.recog((16000, wav), state) 109 | return self.recog((16000, wav), sd_switch, state) 110 | 111 | def video_clip(self, dest_text, start_ost, end_ost, state, font_size=32, font_color='white', add_sub=False, dest_spk=None): 112 | # get from state 113 | recog_res_raw = state['recog_res_raw'] 114 | timestamp = state['timestamp'] 115 | sentences = state['sentences'] 116 | video = state['video'] 117 | clip_video_file = state['clip_video_file'] 118 | vedio_filename = state['vedio_filename'] 119 | 120 | all_ts = [] 121 | srt_index = 0 122 | if dest_spk is None or dest_spk == '' or 'sd_sentences' not in state: 123 | for _dest_text in dest_text.split('#'): 124 | _dest_text = pre_proc(_dest_text) 125 | ts = proc(recog_res_raw, timestamp, _dest_text) 126 | for _ts in ts: all_ts.append(_ts) 127 | else: 128 | for _dest_spk in dest_spk.split('#'): 129 | ts = proc_spk(_dest_spk, state['sd_sentences']) 130 | for _ts in ts: all_ts.append(_ts) 131 | time_acc_ost = 0.0 132 | ts = all_ts 133 | # ts.sort() 134 | clip_srt = "" 135 | if len(ts): 136 | start, end = ts[0][0] / 16000, ts[0][1] / 16000 137 | srt_clip, subs, srt_index = generate_srt_clip(sentences, start, end, begin_index=srt_index, time_acc_ost=time_acc_ost) 138 | start, end = start+start_ost/1000.0, end+end_ost/1000.0 139 | video_clip = video.subclip(start, end) 140 | start_end_info = "from {} to {}".format(start, end) 141 | clip_srt += srt_clip 142 | if add_sub: 143 | generator = lambda txt: TextClip(txt, font='./font/STHeitiMedium.ttc', fontsize=font_size, color=font_color) 144 | subtitles = SubtitlesClip(subs, generator) 145 | video_clip = CompositeVideoClip([video_clip, subtitles.set_pos(('center','bottom'))]) 146 | concate_clip = [video_clip] 147 | time_acc_ost += end+end_ost/1000.0 - (start+start_ost/1000.0) 148 | for _ts in ts[1:]: 149 | start, end = _ts[0] / 16000, _ts[1] / 16000 150 | srt_clip, subs, srt_index = generate_srt_clip(sentences, start, end, begin_index=srt_index-1, time_acc_ost=time_acc_ost) 151 | chi_subs = [] 152 | sub_starts = subs[0][0][0] 153 | for sub in subs: 154 | chi_subs.append(((sub[0][0]-sub_starts, sub[0][1]-sub_starts), sub[1])) 155 | start, end = 
start+start_ost/1000.0, end+end_ost/1000.0 156 | _video_clip = video.subclip(start, end) 157 | start_end_info += ", from {} to {}".format(start, end) 158 | clip_srt += srt_clip 159 | if add_sub: 160 | generator = lambda txt: TextClip(txt, font='./font/STHeitiMedium.ttc', fontsize=font_size, color=font_color) 161 | subtitles = SubtitlesClip(chi_subs, generator) 162 | _video_clip = CompositeVideoClip([_video_clip, subtitles.set_pos(('center','bottom'))]) 163 | # _video_clip.write_videofile("debug.mp4", audio_codec="aac") 164 | concate_clip.append(copy.copy(_video_clip)) 165 | time_acc_ost += end+end_ost/1000.0 - (start+start_ost/1000.0) 166 | message = "{} periods found in the audio: ".format(len(ts)) + start_end_info 167 | logging.warning("Concating...") 168 | if len(concate_clip) > 1: 169 | video_clip = concatenate_videoclips(concate_clip) 170 | video_clip.write_videofile(clip_video_file, audio_codec="aac") 171 | else: 172 | clip_video_file = vedio_filename 173 | message = "No period found in the audio, return raw speech. You may check the recognition result and try other destination text." 174 | srt_clip = '' 175 | return clip_video_file, message, clip_srt 176 | 177 | 178 | def get_parser(): 179 | parser = ArgumentParser( 180 | description="ClipVideo Argument", 181 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 182 | ) 183 | parser.add_argument( 184 | "--stage", 185 | type=int, 186 | choices=(1, 2), 187 | help="Stage, 0 for recognizing and 1 for clipping", 188 | required=True 189 | ) 190 | parser.add_argument( 191 | "--file", 192 | type=str, 193 | default=None, 194 | help="Input file path", 195 | required=True 196 | ) 197 | parser.add_argument( 198 | "--sd_switch", 199 | type=str, 200 | choices=("no", "yes"), 201 | default="no", 202 | help="Trun on the speaker diarization or not", 203 | ) 204 | parser.add_argument( 205 | "--output_dir", 206 | type=str, 207 | default='./output', 208 | help="Output files path", 209 | ) 210 | parser.add_argument( 211 | "--dest_text", 212 | type=str, 213 | default=None, 214 | help="Destination text string for clipping", 215 | ) 216 | parser.add_argument( 217 | "--dest_spk", 218 | type=str, 219 | default=None, 220 | help="Destination spk id for clipping", 221 | ) 222 | parser.add_argument( 223 | "--start_ost", 224 | type=int, 225 | default=0, 226 | help="Offset time in ms at beginning for clipping" 227 | ) 228 | parser.add_argument( 229 | "--end_ost", 230 | type=int, 231 | default=0, 232 | help="Offset time in ms at ending for clipping" 233 | ) 234 | parser.add_argument( 235 | "--output_file", 236 | type=str, 237 | default=None, 238 | help="Output file path" 239 | ) 240 | return parser 241 | 242 | 243 | def runner(stage, file, sd_switch, output_dir, dest_text, dest_spk, start_ost, end_ost, output_file, config=None): 244 | audio_suffixs = ['wav'] 245 | video_suffixs = ['mp4'] 246 | if file[-3:] in audio_suffixs: 247 | mode = 'audio' 248 | elif file[-3:] in video_suffixs: 249 | mode = 'video' 250 | else: 251 | logging.error("Unsupported file format: {}".format(file)) 252 | while output_dir.endswith('/'): 253 | output_dir = output_dir[:-1] 254 | if stage == 1: 255 | from modelscope.pipelines import pipeline 256 | from modelscope.utils.constant import Tasks 257 | # initialize modelscope asr pipeline 258 | logging.warning("Initializing modelscope asr pipeline.") 259 | inference_pipeline = pipeline( 260 | task=Tasks.auto_speech_recognition, 261 | model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch', 262 | 
vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch', 263 | punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch', 264 | output_dir=output_dir, 265 | ) 266 | sd_pipeline = pipeline( 267 | task='speaker-diarization', 268 | model='damo/speech_campplus_speaker-diarization_common', 269 | model_revision='v1.0.0' 270 | ) 271 | audio_clipper = VideoClipper(inference_pipeline, sd_pipeline) 272 | if mode == 'audio': 273 | logging.warning("Recognizing audio file: {}".format(file)) 274 | wav, sr = librosa.load(file, sr=16000) 275 | res_text, res_srt, state = audio_clipper.recog((sr, wav), sd_switch) 276 | if mode == 'video': 277 | logging.warning("Recognizing video file: {}".format(file)) 278 | res_text, res_srt, state = audio_clipper.video_recog(file, sd_switch) 279 | total_srt_file = output_dir + '/total.srt' 280 | with open(total_srt_file, 'w') as fout: 281 | fout.write(res_srt) 282 | logging.warning("Write total subtitile to {}".format(total_srt_file)) 283 | write_state(output_dir, state) 284 | logging.warning("Recognition successed. You can copy the text segment from below and use stage 2.") 285 | print(res_text) 286 | if stage == 2: 287 | audio_clipper = VideoClipper(None) 288 | if mode == 'audio': 289 | state = load_state(output_dir) 290 | wav, sr = librosa.load(file, sr=16000) 291 | state['audio_input'] = (sr, wav) 292 | (sr, audio), message, srt_clip = audio_clipper.clip(dest_text, start_ost, end_ost, state, dest_spk=dest_spk) 293 | if output_file is None: 294 | output_file = output_dir + '/result.wav' 295 | clip_srt_file = output_file[:-3] + 'srt' 296 | logging.warning(message) 297 | sf.write(output_file, audio, 16000) 298 | assert output_file.endswith('.wav'), "output_file must ends with '.wav'" 299 | logging.warning("Save clipped wav file to {}".format(output_file)) 300 | with open(clip_srt_file, 'w') as fout: 301 | fout.write(srt_clip) 302 | logging.warning("Write clipped subtitile to {}".format(clip_srt_file)) 303 | if mode == 'video': 304 | state = load_state(output_dir) 305 | state['vedio_filename'] = file 306 | if output_file is None: 307 | state['clip_video_file'] = file[:-4] + '_clip.mp4' 308 | else: 309 | state['clip_video_file'] = output_file 310 | clip_srt_file = state['clip_video_file'][:-3] + 'srt' 311 | state['video'] = mpy.VideoFileClip(file) 312 | clip_video_file, message, srt_clip = audio_clipper.video_clip(dest_text, start_ost, end_ost, state, dest_spk=dest_spk) 313 | logging.warning("Clipping Log: {}".format(message)) 314 | logging.warning("Save clipped mp4 file to {}".format(clip_video_file)) 315 | with open(clip_srt_file, 'w') as fout: 316 | fout.write(srt_clip) 317 | logging.warning("Write clipped subtitile to {}".format(clip_srt_file)) 318 | 319 | 320 | def main(cmd=None): 321 | print(get_commandline_args(), file=sys.stderr) 322 | parser = get_parser() 323 | args = parser.parse_args(cmd) 324 | kwargs = vars(args) 325 | runner(**kwargs) 326 | 327 | 328 | if __name__ == '__main__': 329 | main() --------------------------------------------------------------------------------
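For reference, the dataset pipeline above can also be driven without the Gradio UI. The sketch below is a minimal, untested outline that simply mirrors the argument lists do_slice and do_transcribe_whisper hand to run_script_with_log: calling the scripts directly through sys.executable, the ./raw/ and ./wavs/ paths, the empty model_name and file_pos values, and the slider defaults are assumptions taken from the UI code rather than documented CLI usage. The output is the esd.list transcript in the project root (the "est.list" in the accordion label appears to be a typo for esd.list).

import subprocess
import sys

python = sys.executable  # assumption: run with the same interpreter the repo's venv uses

# 1) Slice every 角色名.wav under ./raw/ into short clips (mirrors do_slice's defaults).
subprocess.run(
    [python, "audio_slicer_pre.py",
     "--dataset_path", "./raw/",
     "--min_sec", "2500",
     "--max_sec", "5000",
     "--min_silence_dur_ms", "500"],
    check=True,
)

# 2) Transcribe the clips with Whisper, appending to esd.list (mirrors do_transcribe_whisper).
subprocess.run(
    [python, "short_audio_transcribe_whisper.py",
     "--model_name", "",
     "--language", "zh",
     "--mytype", "medium",
     "--input_file", "./wavs/",
     "--file_pos", ""],
    check=True,
)

# Each appended line follows the format written by the transcription scripts:
#   {file_pos}{file_name}|{role_name}|{ZH|EN|JA|KO|YUE}|{text}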