├── raw
│   └── wav_file_here
├── wavs
│   └── slicer_files_here
├── whisper_model
│   └── whisper_model_here
├── 查看cuda版本.bat
├── models_from_modelscope
│   └── place_model_files_here
├── GPU诊断.bat
├── 1_Dataset.bat
├── common
│   ├── __pycache__
│   │   ├── log.cpython-39.pyc
│   │   ├── constants.cpython-39.pyc
│   │   ├── stdout_wrapper.cpython-39.pyc
│   │   └── subprocess_utils.cpython-39.pyc
│   ├── log.py
│   ├── constants.py
│   ├── stdout_wrapper.py
│   ├── subprocess_utils.py
│   └── tts_model.py
├── bcut_asr
│   ├── __pycache__
│   │   ├── orm.cpython-310.pyc
│   │   ├── orm.cpython-39.pyc
│   │   ├── __init__.cpython-310.pyc
│   │   └── __init__.cpython-39.pyc
│   ├── orm.py
│   ├── __main__.py
│   └── __init__.py
├── becut_test.py
├── README.md
├── gpu_diagnostics.py
├── esd_whisper_large_v2.list
├── audio_slicer_pre.py
├── argparse_tools.py
├── trans_utils.py
├── short_audio_transcribe_bcut.py
├── requirements.txt
├── short_audio_transcribe_fwhisper.py
├── short_audio_transcribe_whisper.py
├── subtitle_utils.py
├── short_audio_transcribe_ali.py
├── slicer2.py
├── webui_dataset.py
└── videoclipper.py
/raw/wav_file_here:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/wavs/slicer_files_here:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/whisper_model/whisper_model_here:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/查看cuda版本.bat:
--------------------------------------------------------------------------------
1 | nvcc --version
2 | pause
3 |
--------------------------------------------------------------------------------
/models_from_modelscope/place_model_files_here:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/GPU诊断.bat:
--------------------------------------------------------------------------------
1 | venv\python.exe gpu_diagnostics.py
2 | pause
3 |
--------------------------------------------------------------------------------
/1_Dataset.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | chcp 65001
3 |
4 | call venv\python.exe webui_dataset.py
5 |
6 | @echo 请按任意键继续
7 | call pause
--------------------------------------------------------------------------------
/common/__pycache__/log.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/common/__pycache__/log.cpython-39.pyc
--------------------------------------------------------------------------------
/bcut_asr/__pycache__/orm.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/bcut_asr/__pycache__/orm.cpython-310.pyc
--------------------------------------------------------------------------------
/bcut_asr/__pycache__/orm.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/bcut_asr/__pycache__/orm.cpython-39.pyc
--------------------------------------------------------------------------------
/bcut_asr/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/bcut_asr/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/bcut_asr/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/bcut_asr/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/common/__pycache__/constants.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/common/__pycache__/constants.cpython-39.pyc
--------------------------------------------------------------------------------
/common/__pycache__/stdout_wrapper.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/common/__pycache__/stdout_wrapper.cpython-39.pyc
--------------------------------------------------------------------------------
/common/__pycache__/subprocess_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/v3ucn/ASR_TOOLS_SenseVoice_WebUI/HEAD/common/__pycache__/subprocess_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/common/log.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger wrapper.
3 | """
4 | from loguru import logger
5 |
6 | from .stdout_wrapper import SAFE_STDOUT
7 |
8 | # Remove all default handlers
9 | logger.remove()
10 |
11 | # Define a custom format and attach it to standard output
12 | log_format = (
13 | "{time:MM-DD HH:mm:ss} |{level:^8}| {file}:{line} | {message}"
14 | )
15 |
16 | logger.add(SAFE_STDOUT, format=log_format, backtrace=True, diagnose=True)
17 |
--------------------------------------------------------------------------------
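
Other modules in this repository import this shared logger instead of configuring their own (for example `from .log import logger` in common/subprocess_utils.py); a minimal usage sketch:

```
from common.log import logger

logger.info("dataset build started")  # routed through SAFE_STDOUT with the custom format
```
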
/becut_test.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | from bcut_asr import BcutASR
4 | from bcut_asr.orm import ResultStateEnum
5 | 
6 | asr = BcutASR("./wavs/Erwin_0.wav")
7 | asr.upload()  # upload the audio file
8 | asr.create_task()  # create the recognition task
9 | 
10 | # Poll until the task completes
11 | while True:
12 |     result = asr.result()
13 |     # Stop polling once recognition has finished
14 |     if result.state == ResultStateEnum.COMPLETE:
15 |         break
16 |     time.sleep(1.0)  # avoid hammering the API while waiting
17 | 
18 | # Parse the subtitle payload
19 | subtitle = result.parse()
20 | 
21 | # Check whether any speech was recognized
22 | if subtitle.has_data():
23 |     # Print the plain-text transcript (use to_srt() for SRT output)
24 |     print(subtitle.to_txt())
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Install Python 3
2 | 
3 | ## Install dependencies
4 | 
5 | ```
6 | pip install -r requirements.txt
7 | ```
8 | 
9 | ## Put your audio files in the raw directory
10 | 
11 | ```
12 | File naming: <speaker_name>.wav
13 | ```
14 | 
15 | ## Run the WebUI
16 | 
17 | ```
18 | python3 webui_dataset.py
19 | ```
20 | 
21 | 
22 | 
23 | ## Video walkthrough
24 | 
25 | https://www.bilibili.com/video/BV1da4y117Y6/
26 | 
27 | ## Official bcut-asr project: SocialSisterYi/bcut-asr: speech subtitle recognition using the BiliBili BCut API (github.com)
28 | 
--------------------------------------------------------------------------------
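
The transcription scripts in this repository write their results to ./esd.list, one line per sliced clip in the form file|speaker|language|text; see esd_whisper_large_v2.list for a full example:

```
Erwin_0.wav|Erwin|ZH|如果这个作战顺利
Erwin_1.wav|Erwin|ZH|你也许可以趁此机会干掉兽之巨人
```
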
/common/constants.py:
--------------------------------------------------------------------------------
1 | import enum
2 |
3 | DEFAULT_STYLE: str = "Neutral"
4 | DEFAULT_STYLE_WEIGHT: float = 26.0
5 |
6 |
7 | class Languages(str, enum.Enum):
8 | JP = "JP"
9 | EN = "EN"
10 | ZH = "ZH"
11 |
12 |
13 | DEFAULT_SDP_RATIO: float = 0.2
14 | DEFAULT_NOISE: float = 0.6
15 | DEFAULT_NOISEW: float = 0.8
16 | DEFAULT_LENGTH: float = 1.0
17 | DEFAULT_LINE_SPLIT: bool = True
18 | DEFAULT_SPLIT_INTERVAL: float = 0.5
19 | DEFAULT_ASSIST_TEXT_WEIGHT: float = 0.7
20 | DEFAULT_ASSIST_TEXT_WEIGHT: float = 1.0  # NOTE: duplicate of the constant above; this value takes effect
21 |
--------------------------------------------------------------------------------
/common/stdout_wrapper.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import tempfile
3 |
4 |
5 | class StdoutWrapper:
6 | def __init__(self):
7 | self.temp_file = tempfile.NamedTemporaryFile(mode="w+", delete=False)
8 | self.original_stdout = sys.stdout
9 |
10 | def write(self, message: str):
11 | self.temp_file.write(message)
12 | self.temp_file.flush()
13 | print(message, end="", file=self.original_stdout)
14 |
15 | def flush(self):
16 | self.temp_file.flush()
17 |
18 | def read(self):
19 | self.temp_file.seek(0)
20 | return self.temp_file.read()
21 |
22 | def close(self):
23 | self.temp_file.close()
24 |
25 | def fileno(self):
26 | return self.temp_file.fileno()
27 |
28 |
29 | try:
30 | import google.colab
31 |
32 | SAFE_STDOUT = StdoutWrapper()
33 | except ImportError:
34 | SAFE_STDOUT = sys.stdout
35 |
--------------------------------------------------------------------------------
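
SAFE_STDOUT exists so that progress output still has a usable file object (with fileno()) in environments such as Google Colab; the transcription scripts pass it to tqdm. A minimal sketch (the file list is illustrative):

```
from tqdm import tqdm
from common.stdout_wrapper import SAFE_STDOUT

for wav_file in tqdm(["Erwin_0.wav", "Erwin_1.wav"], file=SAFE_STDOUT):
    pass  # transcribe each sliced clip here
```
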
/gpu_diagnostics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def gpu_diagnostics():
4 | if torch.cuda.is_available():
5 | print("GPU 诊断报告:")
6 | print("="*40)
7 | for i in range(torch.cuda.device_count()):
8 | props = torch.cuda.get_device_properties(i)
9 | total_memory = props.total_memory / (1024 * 1024)
10 | reserved_memory = torch.cuda.memory_reserved(i) / (1024 * 1024)
11 | allocated_memory = torch.cuda.memory_allocated(i) / (1024 * 1024)
12 | free_memory = total_memory - allocated_memory
13 |
14 | print(f"GPU {i}: {props.name}")
15 | print(f" 总显存 : {round(total_memory, 2)} MB")
16 | print(f" 已保留显存 : {round(reserved_memory, 2)} MB")
17 | print(f" 已分配显存 : {round(allocated_memory, 2)} MB")
18 | print(f" 空闲显存 : {round(free_memory, 2)} MB")
19 | print("="*40)
20 | else:
21 | print("未找到 GPU,使用 CPU")
22 |
23 | if __name__ == "__main__":
24 | gpu_diagnostics()
25 |
--------------------------------------------------------------------------------
/common/subprocess_utils.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import sys
3 |
4 | from .log import logger
5 | from .stdout_wrapper import SAFE_STDOUT
6 |
7 |
8 | #python = ".\\venv\\python.exe"
9 | python = "python3"
10 |
11 | print(python)
12 |
13 |
14 | def run_script_with_log(cmd: list[str], ignore_warning=False) -> tuple[bool, str]:
15 | logger.info(f"Running: {' '.join(cmd)}")
16 | print([python] + cmd)
17 |
18 | result = subprocess.run(
19 | [python] + cmd,
20 | stdout=SAFE_STDOUT, # type: ignore
21 | stderr=subprocess.PIPE,
22 | text=True,
23 | )
24 |
25 | if result.returncode != 0:
26 | logger.error(f"Error: {' '.join(cmd)}\n{result.stderr}")
27 | return False, result.stderr
28 | elif result.stderr and not ignore_warning:
29 | logger.warning(f"Warning: {' '.join(cmd)}\n{result.stderr}")
30 | return True, result.stderr
31 | logger.success(f"Success: {' '.join(cmd)}")
32 | return True, ""
33 |
34 |
35 | def second_elem_of(original_function):
36 | def inner_function(*args, **kwargs):
37 | return original_function(*args, **kwargs)[1]
38 |
39 | return inner_function
40 |
--------------------------------------------------------------------------------
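
webui_dataset.py imports run_script_with_log to launch the other scripts in a subprocess; a minimal sketch (the argument values are illustrative only):

```
from common.subprocess_utils import run_script_with_log

ok, err = run_script_with_log(["audio_slicer_pre.py", "--min_sec", "2000", "--min_silence_dur_ms", "700"])
if not ok:
    print("slicing failed:", err)
```
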
/esd_whisper_large_v2.list:
--------------------------------------------------------------------------------
1 | Erwin_0.wav|Erwin|ZH|如果这个作战顺利
2 | Erwin_1.wav|Erwin|ZH|你也许可以趁此机会干掉兽之巨人
3 | Erwin_10.wav|Erwin|ZH|如果到时候我不冲在最前面
4 | Erwin_11.wav|Erwin|ZH|他们根本不会往前冲,然后我会第一个去死
5 | Erwin_12.wav|Erwin|ZH|地下室里到底有什么
6 | Erwin_13.wav|Erwin|ZH|也就无从知晓了,好想去地下室看一看,我之所以能撑着走到今天
7 | Erwin_14.wav|Erwin|ZH|就是因为相信这一天的到来
8 | Erwin_15.wav|Erwin|ZH|因为潜行者
9 | Erwin_16.wav|Erwin|ZH|我的猜想能够得到证实
10 | Erwin_17.wav|Erwin|ZH|我之前无数次的想过,要不然干脆死了算了
11 | Erwin_18.wav|Erwin|ZH|可即便如此,我还是想要实现父亲的梦想
12 | Erwin_19.wav|Erwin|ZH|然而现在
13 | Erwin_2.wav|Erwin|ZH|但得到所有新兵不管选择哪条路
14 | Erwin_20.wav|Erwin|ZH|他的答案就在我触手可及的地方
15 | Erwin_21.wav|Erwin|ZH|近在咫尺死去的同伴们也是如此吗
16 | Erwin_22.wav|Erwin|ZH|那些流血的牺牲都是没有意义的吗
17 | Erwin_23.wav|Erwin|ZH|不不对
18 | Erwin_24.wav|Erwin|ZH|那些死去士兵的意义将由我们来赋予
19 | Erwin_25.wav|Erwin|ZH|那些勇敢的死者,可怜的死者
20 | Erwin_26.wav|Erwin|ZH|是他们的牺牲换来了我们活着的今天
21 | Erwin_27.wav|Erwin|ZH|让我们能站在这里,而今天我们将会死去
22 | Erwin_28.wav|Erwin|ZH|将一一托付给下一个活着的人
23 | Erwin_29.wav|Erwin|ZH|这就是我们与这个残酷的世界
24 | Erwin_3.wav|Erwin|ZH|我们基本都会死吧,是的全灭的可能性相当的高
25 | Erwin_30.wav|Erwin|ZH|抗争的意义
26 | Erwin_4.wav|Erwin|ZH|但事到如今也只能做好玉石俱焚的觉悟
27 | Erwin_5.wav|Erwin|ZH|将一切堵在获胜渺茫的战术上
28 | Erwin_6.wav|Erwin|ZH|到了这一步
29 | Erwin_7.wav|Erwin|ZH|要让那些年轻人们去死
30 | Erwin_8.wav|Erwin|ZH|就必须像一个一流的诈骗犯一样
31 | Erwin_9.wav|Erwin|ZH|为他们花言巧语一番
32 |
--------------------------------------------------------------------------------
/audio_slicer_pre.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import soundfile
3 | import os
4 | import argparse
5 | from slicer2 import Slicer
6 |
7 | # Command-line arguments
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument(
10 |     "--min_sec", "-m", type=int, default=2000, help="Minimum length of a slice in milliseconds"
11 | )
12 | parser.add_argument(
13 |     "--max_sec", "-M", type=int, default=5000, help="Maximum length of a slice in milliseconds"
14 | )
15 | parser.add_argument(
16 | "--dataset_path",
17 | type=str,
18 | default="inputs",
19 | help="Directory of input wav files",
20 | )
21 | parser.add_argument(
22 | "--min_silence_dur_ms",
23 | "-s",
24 | type=int,
25 | default=700,
26 | help="Silence above this duration (ms) is considered as a split point.",
27 | )
28 | args = parser.parse_args()
29 |
30 | # Clear the output directory
31 | folder_path = './wavs'
32 | if os.path.exists(folder_path):
33 | for filename in os.listdir(folder_path):
34 | file_path = os.path.join(folder_path, filename)
35 | if os.path.isfile(file_path):
36 | os.remove(file_path)
37 | else:
38 | os.makedirs(folder_path)
39 |
40 | # Iterate over all .wav files in the input directory
41 | audio_directory = f'{args.dataset_path}'
42 | for filename in os.listdir(audio_directory):
43 | file_path = os.path.join(audio_directory, filename)
44 | if os.path.isfile(file_path) and filename.endswith('.wav'):
45 |         # Load the audio file
46 |         audio, sr = librosa.load(file_path, sr=None, mono=False)
47 | 
48 |         # Create the Slicer
49 | slicer = Slicer(
50 | sr=sr,
51 | threshold=-40,
52 | min_length=args.min_sec,
53 | min_interval=300,
54 | hop_size=10,
55 | max_sil_kept=args.min_silence_dur_ms
56 | )
57 |
58 |         # Slice the audio
59 | chunks = slicer.slice(audio)
60 | for i, chunk in enumerate(chunks):
61 | if len(chunk.shape) > 1:
62 | chunk = chunk.T # Swap axes if the audio is stereo.
63 | soundfile.write(f'./wavs/{filename[:-4]}_{i}.wav', chunk, sr)
64 |
--------------------------------------------------------------------------------
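
A minimal direct invocation of the slicer script above (pointing --dataset_path at this repo's raw directory is an assumption; the script's built-in default is "inputs"):

```
python audio_slicer_pre.py --dataset_path raw --min_sec 2000 --min_silence_dur_ms 700
```
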
/argparse_tools.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import yaml
5 | import sys
6 |
7 |
8 | class ArgumentParser(argparse.ArgumentParser):
9 | """Simple implementation of ArgumentParser supporting config file
10 |
11 |     This class originates from https://github.com/bw2/ConfigArgParse,
12 |     but it lacks some of the features that ConfigArgParse has.
13 |
14 | - Not supporting multiple config files
15 | - Automatically adding "--config" as an option.
16 | - Not supporting any formats other than yaml
17 | - Not checking argument type
18 |
19 | """
20 |
21 | def __init__(self, *args, **kwargs):
22 | super().__init__(*args, **kwargs)
23 | self.add_argument("--config", help="Give config file in yaml format")
24 |
25 | def parse_known_args(self, args=None, namespace=None):
26 | # Once parsing for setting from "--config"
27 | _args, _ = super().parse_known_args(args, namespace)
28 | if _args.config is not None:
29 | if not Path(_args.config).exists():
30 | self.error(f"No such file: {_args.config}")
31 |
32 | with open(_args.config, "r", encoding="utf-8") as f:
33 | d = yaml.safe_load(f)
34 | if not isinstance(d, dict):
35 |                 self.error(f"Config file has non dict value: {_args.config}")
36 |
37 | for key in d:
38 | for action in self._actions:
39 | if key == action.dest:
40 | break
41 | else:
42 | self.error(f"unrecognized arguments: {key} (from {_args.config})")
43 |
44 | # NOTE(kamo): Ignore "--config" from a config file
45 | # NOTE(kamo): Unlike "configargparse", this module doesn't check type.
46 | # i.e. We can set any type value regardless of argument type.
47 | self.set_defaults(**d)
48 | return super().parse_known_args(args, namespace)
49 |
50 |
51 | def get_commandline_args():
52 | extra_chars = [
53 | " ",
54 | ";",
55 | "&",
56 | "(",
57 | ")",
58 | "|",
59 | "^",
60 | "<",
61 | ">",
62 | "?",
63 | "*",
64 | "[",
65 | "]",
66 | "$",
67 | "`",
68 | '"',
69 | "\\",
70 | "!",
71 | "{",
72 | "}",
73 | ]
74 |
75 | # Escape the extra characters for shell
76 | argv = [
77 | arg.replace("'", "'\\''")
78 | if all(char not in arg for char in extra_chars)
79 | else "'" + arg.replace("'", "'\\''") + "'"
80 | for arg in sys.argv
81 | ]
82 |
83 | return sys.executable + " " + " ".join(argv)
--------------------------------------------------------------------------------
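
A minimal sketch of how this YAML-aware parser merges config values into argument defaults (the file demo.yaml and the --model_name argument are hypothetical):

```
# demo.yaml contains a single line:  model_name: Erwin
from argparse_tools import ArgumentParser

parser = ArgumentParser(description="demo")
parser.add_argument("--model_name", type=str, default="sample")
args = parser.parse_args(["--config", "demo.yaml"])
print(args.model_name)  # "Erwin": values from the YAML file become argument defaults
```
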
/trans_utils.py:
--------------------------------------------------------------------------------
1 | PUNC_LIST = [',', '。', '!', '?', '、']
2 |
3 |
4 | def pre_proc(text):
5 | res = ''
6 | for i in range(len(text)):
7 | if text[i] in PUNC_LIST:
8 | continue
9 | if '\u4e00' <= text[i] <= '\u9fff':
10 | if len(res) and res[-1] != " ":
11 | res += ' ' + text[i]+' '
12 | else:
13 | res += text[i]+' '
14 | else:
15 | res += text[i]
16 | if res[-1] == ' ':
17 | res = res[:-1]
18 | return res
19 |
20 | def proc(raw_text, timestamp, dest_text):
21 | # simple matching
22 | ld = len(dest_text.split())
23 | mi, ts = [], []
24 | offset = 0
25 | while True:
26 | fi = raw_text.find(dest_text, offset, len(raw_text))
27 | # import pdb; pdb.set_trace()
28 | ti = raw_text[:fi].count(' ')
29 | if fi == -1:
30 | break
31 | offset = fi + ld
32 | mi.append(fi)
33 | ts.append([timestamp[ti][0]*16, timestamp[ti+ld-1][1]*16])
34 | # import pdb; pdb.set_trace()
35 | return ts
36 |
37 | def proc_spk(dest_spk, sd_sentences):
38 | ts = []
39 | for d in sd_sentences:
40 | d_start = d['ts_list'][0][0]
41 | d_end = d['ts_list'][-1][1]
42 | spkid=dest_spk[3:]
43 | if str(d['spk']) == spkid and d_end-d_start>999:
44 | ts.append([d['start']*16, d['end']*16])
45 | return ts
46 |
47 | def generate_vad_data(data, sd_sentences, sr=16000):
48 | assert len(data.shape) == 1
49 | vad_data = []
50 | for d in sd_sentences:
51 | d_start = round(d['ts_list'][0][0]/1000, 2)
52 | d_end = round(d['ts_list'][-1][1]/1000, 2)
53 | vad_data.append([d_start, d_end, data[int(d_start * sr):int(d_end * sr)]])
54 | return vad_data
55 |
56 | def write_state(output_dir, state):
57 |     for key in ['/recog_res_raw', '/timestamp', '/sentences']:  # 'sd_sentences' is optional and written below
58 |         with open(output_dir+key, 'w') as fout:
59 |             fout.write(str(state[key[1:]]))
60 |     if 'sd_sentences' in state:
61 |         with open(output_dir+'/sd_sentences', 'w') as fout:
62 |             fout.write(str(state['sd_sentences']))
63 |
64 | import os
65 | def load_state(output_dir):
66 | state = {}
67 | with open(output_dir+'/recog_res_raw') as fin:
68 | line = fin.read()
69 | state['recog_res_raw'] = line
70 | with open(output_dir+'/timestamp') as fin:
71 | line = fin.read()
72 | state['timestamp'] = eval(line)
73 | with open(output_dir+'/sentences') as fin:
74 | line = fin.read()
75 | state['sentences'] = eval(line)
76 | if os.path.exists(output_dir+'/sd_sentences'):
77 | with open(output_dir+'/sd_sentences') as fin:
78 | line = fin.read()
79 | state['sd_sentences'] = eval(line)
80 | return state
81 |
82 |
--------------------------------------------------------------------------------
/bcut_asr/orm.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from pydantic import BaseModel
3 |
4 | class ASRDataSeg(BaseModel):
5 |     'ASR result: sentence-level segment'
6 |     class ASRDataWords(BaseModel):
7 |         'ASR result: word-level item'
8 | label: str
9 | start_time: int
10 | end_time: int
11 | confidence: int
12 | start_time: int
13 | end_time: int
14 | transcript: str
15 | words: list[ASRDataWords]
16 | confidence: int
17 |
18 | def to_srt_ts(self) -> str:
19 |         'Convert to an SRT timestamp range'
20 | def _conv(ms: int) -> tuple[int ,int, int, int]:
21 | return ms // 3600000, ms // 60000 % 60, ms // 1000 % 60, ms % 1000
22 | s_h, s_m, s_s, s_ms = _conv(self.start_time)
23 | e_h, e_m, e_s, e_ms = _conv(self.end_time)
24 | return f'{s_h:02d}:{s_m:02d}:{s_s:02d},{s_ms:03d} --> {e_h:02d}:{e_m:02d}:{e_s:02d},{e_ms:03d}'
25 |
26 | def to_lrc_ts(self) -> str:
27 |         'Convert to an LRC timestamp'
28 | def _conv(ms: int) -> tuple[int ,int, int]:
29 | return ms // 60000, ms // 1000 % 60, ms % 1000 // 10
30 | s_m, s_s, s_ms = _conv(self.start_time)
31 | return f'[{s_m:02d}:{s_s:02d}.{s_ms:02d}]'
32 |
33 |
34 |
35 | class ASRData(BaseModel):
36 |     'Speech recognition result'
37 | utterances: list[ASRDataSeg]
38 | version: str
39 |
40 | def __iter__(self):
41 |         'Delegate iteration to the utterance list'
42 |         return iter(self.utterances)
43 | 
44 |     def has_data(self) -> bool:
45 |         'Whether any utterances were recognized'
46 |         return len(self.utterances) > 0
47 | 
48 |     def to_txt(self) -> str:
49 |         'Convert to plain-text subtitles (no timestamps)'
50 | return '\n'.join(
51 | seg.transcript
52 | for seg
53 | in self.utterances
54 | )
55 |
56 | def to_srt(self) -> str:
57 |         'Convert to SRT subtitles'
58 | return '\n'.join(
59 | f'{n}\n{seg.to_srt_ts()}\n{seg.transcript}\n'
60 | for n, seg
61 | in enumerate(self.utterances, 1)
62 | )
63 |
64 | def to_lrc(self) -> str:
65 |         'Convert to LRC subtitles'
66 | return '\n'.join(
67 | f'{seg.to_lrc_ts()}{seg.transcript}'
68 | for seg
69 | in self.utterances
70 | )
71 |
72 | def to_ass(self) -> str:
73 | ...
74 |
75 |
76 | class ResourceCreateRspSchema(BaseModel):
77 |     'Upload request response'
78 | resource_id: str
79 | title: str
80 | type: int
81 | in_boss_key: str
82 | size: int
83 | upload_urls: list[str]
84 | upload_id: str
85 | per_size: int
86 |
87 | class ResourceCompleteRspSchema(BaseModel):
88 |     'Upload completion response'
89 | resource_id: str
90 | download_url: str
91 |
92 | class TaskCreateRspSchema(BaseModel):
93 |     'Task creation response'
94 |     resource: str
95 |     result: str
96 |     task_id: str  # task id
97 | 
98 | class ResultStateEnum(Enum):
99 |     'Task state enum'
100 |     STOP = 0  # not started
101 |     RUNING = 1  # running
102 |     ERROR = 3  # error
103 |     COMPLETE = 4  # done
104 |
105 | class ResultRspSchema(BaseModel):
106 |     'Task result query response'
107 |     task_id: str  # task id
108 |     result: str  # result payload (JSON string)
109 |     remark: str  # task state detail
110 |     state: ResultStateEnum  # task state
111 | 
112 |     def parse(self) -> ASRData:
113 |         'Parse the result payload'
114 | return ASRData.parse_raw(self.result)
115 |
--------------------------------------------------------------------------------
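
A minimal sketch of how these models are used (the segment values are made up; ResultRspSchema.parse() builds the same ASRData from the API's JSON payload):

```
from bcut_asr.orm import ASRData, ASRDataSeg

seg = ASRDataSeg(start_time=0, end_time=1500, transcript="如果这个作战顺利", words=[], confidence=100)
data = ASRData(utterances=[seg], version="1")
print(data.to_srt())  # SRT block: 00:00:00,000 --> 00:00:01,500
print(data.to_lrc())  # [00:00.00]如果这个作战顺利
```
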
/short_audio_transcribe_bcut.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 |
5 | from tqdm import tqdm
6 | import sys
7 | import os
8 |
9 | from common.constants import Languages
10 | from common.log import logger
11 | from common.stdout_wrapper import SAFE_STDOUT
12 |
13 | from bcut_asr import BcutASR
14 | from bcut_asr.orm import ResultStateEnum
15 |
16 | import whisper
17 | import torch
18 |
19 | import re
20 |
21 |
22 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
23 |
24 |
25 | model = whisper.load_model("medium",download_root="./whisper_model/")
26 |
27 |
28 |
29 | lang2token = {
30 | 'zh': "ZH|",
31 | 'ja': "JP|",
32 | "en": "EN|",
33 | }
34 |
35 |
36 | def transcribe_one(audio_path):
37 |
38 | audio = whisper.load_audio(audio_path)
39 | audio = whisper.pad_or_trim(audio)
40 | mel = whisper.log_mel_spectrogram(audio).to(model.device)
41 | _, probs = model.detect_language(mel)
42 | language = max(probs, key=probs.get)
43 |
44 |     asr = BcutASR(audio_path)
45 |     asr.upload()  # upload the audio file
46 |     asr.create_task()  # create the recognition task
47 | 
48 |     # Poll until the task completes
49 |     while True:
50 |         result = asr.result()
51 |         # Stop polling once recognition has finished
52 |         if result.state == ResultStateEnum.COMPLETE:
53 |             break
54 | 
55 |     # Parse the subtitle payload
56 |     subtitle = result.parse()
57 | 
58 |     # Check whether any speech was recognized
59 |     if subtitle.has_data():
60 |
61 |
62 |
63 | text = subtitle.to_txt()
64 | text = repr(text)
65 | text = text.replace("'","")
66 | text = text.replace("\\n",",")
67 | text = text.replace("\\r",",")
68 |
69 | print(text)
70 |
71 |         # Return the recognized text together with the detected language
72 | return text,language
73 | else:
74 | return "必剪无法识别",language
75 |
76 |
77 |
78 | if __name__ == "__main__":
79 |
80 | parser = argparse.ArgumentParser()
81 |
82 | parser.add_argument(
83 | "--language", type=str, default="ja", choices=["ja", "en", "zh"]
84 | )
85 | parser.add_argument("--model_name", type=str, required=True)
86 |
87 | parser.add_argument("--input_file", type=str, default="./wavs/")
88 |
89 | parser.add_argument("--file_pos", type=str, default="")
90 |
91 |
92 | args = parser.parse_args()
93 |
94 | speaker_name = args.model_name
95 |
96 | language = args.language
97 |
98 | input_file = args.input_file
99 |
100 | if input_file == "":
101 | input_file = "./wavs/"
102 |
103 | file_pos = args.file_pos
104 |
105 |
106 | wav_files = [
107 | f for f in os.listdir(f"{input_file}") if f.endswith(".wav")
108 | ]
109 |
110 |
111 | with open("./esd.list", "w", encoding="utf-8") as f:
112 | for wav_file in tqdm(wav_files, file=SAFE_STDOUT):
113 | file_name = os.path.basename(wav_file)
114 |
115 |             # Extract the speaker name (the part before the first "_") with a regex
116 | match = re.search(r'(^.*?)_.*?(\..*?$)', wav_file)
117 | if match:
118 | extracted_name = match.group(1) + match.group(2)
119 | else:
120 | print("No match found")
121 | extracted_name = "sample"
122 |
123 | text,lang = transcribe_one(f"{input_file}"+wav_file)
124 |
125 | if lang == "ja":
126 | language_id = "JA"
127 | elif lang == "en":
128 | language_id = "EN"
129 | elif lang == "zh":
130 | language_id = "ZH"
131 |
132 | f.write(file_pos+f"{file_name}|{extracted_name.replace('.wav','')}|{language_id}|{text}\n")
133 |
134 | f.flush()
135 | sys.exit(0)
136 |
137 |
138 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.0.0
2 | addict==2.4.0
3 | aiofiles==23.2.1
4 | aiohttp==3.9.1
5 | aiosignal==1.3.1
6 | aliyun-python-sdk-core==2.14.0
7 | aliyun-python-sdk-kms==2.16.2
8 | altair==5.2.0
9 | annotated-types==0.6.0
10 | anyio==4.2.0
11 | async-timeout==4.0.3
12 | attrs==23.2.0
13 | audioread==3.0.1
14 | cachetools==5.3.2
15 | certifi==2023.11.17
16 | cffi==1.16.0
17 | charset-normalizer==3.3.2
18 | click==8.1.7
19 | colorama==0.4.6
20 | contourpy==1.2.0
21 | crcmod==1.7
22 | cryptography==41.0.7
23 | cycler==0.12.1
24 | Cython==0.29.37
25 | datasets==2.16.1
26 | decorator==5.1.1
27 | dill==0.3.7
28 | editdistance==0.6.2
29 | einops==0.7.0
30 | exceptiongroup==1.2.0
31 | fastapi==0.109.0
32 | ffmpeg==1.4
33 | ffmpy==0.3.1
34 | filelock==3.13.1
35 | fonttools==4.47.2
36 | frozenlist==1.4.1
37 | fsspec==2023.10.0
38 | funasr
39 | gast==0.5.4
40 | google-auth==2.26.2
41 | google-auth-oauthlib==1.2.0
42 | gradio==4.14.0
43 | gradio_client==0.8.0
44 | grpcio==1.60.0
45 | h11==0.14.0
46 | hdbscan==0.8.33
47 | httpcore==1.0.2
48 | httpx==0.26.0
49 | huggingface-hub==0.20.2
50 | humanfriendly==10.0
51 | idna==3.6
52 | importlib-metadata==7.0.1
53 | importlib-resources==6.1.1
54 | jaconv==0.3.4
55 | jamo==0.4.1
56 | jieba==0.42.1
57 | Jinja2==3.1.3
58 | jmespath==0.10.0
59 | joblib==1.3.2
60 | jsonschema==4.20.0
61 | jsonschema-specifications==2023.12.1
62 | kaldiio==2.18.0
63 | kiwisolver==1.4.5
64 | lazy_loader==0.3
65 | librosa==0.10.1
66 | llvmlite==0.41.1
67 | loguru==0.7.2
68 | Markdown==3.5.2
69 | markdown-it-py==3.0.0
70 | MarkupSafe==2.1.3
71 | matplotlib==3.8.2
72 | mdurl==0.1.2
73 | mecab-python3==1.0.8
74 | modelscope==1.10.0
75 | more-itertools==10.2.0
76 | mpmath==1.3.0
77 | msgpack==1.0.7
78 | multidict==6.0.4
79 | multiprocess==0.70.15
80 | networkx==3.2.1
81 | numba==0.58.1
82 | numpy==1.26.3
83 | oauthlib==3.2.2
84 | openai-whisper @ git+https://github.com/openai/whisper.git@ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab
85 | orjson==3.9.10
86 | oss2==2.18.4
87 | packaging==23.2
88 | pandas==2.1.4
89 | pillow==10.2.0
90 | platformdirs==4.1.0
91 | pooch==1.8.0
92 | protobuf==4.23.4
93 | pyarrow==14.0.2
94 | pyarrow-hotfix==0.6
95 | pyasn1==0.5.1
96 | pyasn1-modules==0.3.0
97 | pycparser==2.21
98 | pycryptodome==3.20.0
99 | pydantic==2.5.3
100 | pydantic_core==2.14.6
101 | pydub==0.25.1
102 | Pygments==2.17.2
103 | pyparsing==3.1.1
104 | pypinyin==0.50.0
105 | pyreadline3==3.4.1
106 | python-dateutil==2.8.2
107 | python-multipart==0.0.6
108 | pytorch-wpe==0.0.1
109 | pytz==2023.3.post1
110 | PyYAML==6.0.1
111 | referencing==0.32.1
112 | regex==2023.12.25
113 | requests==2.31.0
114 | requests-oauthlib==1.3.1
115 | rich==13.7.0
116 | rpds-py==0.17.1
117 | rsa==4.9
118 | safetensors==0.4.1
119 | scikit-learn==1.3.2
120 | scipy==1.11.4
121 | semantic-version==2.10.0
122 | sentencepiece==0.1.99
123 | shellingham==1.5.4
124 | simplejson==3.19.2
125 | six==1.16.0
126 | sniffio==1.3.0
127 | sortedcontainers==2.4.0
128 | soundfile==0.12.1
129 | soxr==0.3.7
130 | starlette==0.35.1
131 | sympy==1.12
132 | tensorboard==2.15.1
133 | tensorboard-data-server==0.7.2
134 | threadpoolctl==3.2.0
135 | tiktoken==0.5.2
136 | tokenizers==0.15.0
137 | tomli==2.0.1
138 | tomlkit==0.12.0
139 | toolz==0.12.0
140 | torch==2.1.2+cu118
141 | torch-complex==0.4.3
142 | torchaudio==2.1.2
143 | torchvision==0.16.2+cu118
144 | tqdm==4.66.1
145 | transformers==4.36.2
146 | typer==0.9.0
147 | typing_extensions==4.9.0
148 | tzdata==2023.4
149 | umap==0.1.1
150 | urllib3==2.1.0
151 | uvicorn==0.25.0
152 | websockets==11.0.3
153 | Werkzeug==3.0.1
154 | win32-setctime==1.1.0
155 | xxhash==3.4.1
156 | yapf==0.40.2
157 | yarl==1.9.4
158 | zipp==3.17.0
159 | faster-whisper
160 | moviepy
161 |
--------------------------------------------------------------------------------
/short_audio_transcribe_fwhisper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 |
6 | from tqdm import tqdm
7 | import sys
8 | import os
9 |
10 | from common.constants import Languages
11 | from common.log import logger
12 | from common.stdout_wrapper import SAFE_STDOUT
13 |
14 | import re
15 |
16 | from transformers import pipeline
17 |
18 | from faster_whisper import WhisperModel
19 |
20 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
21 |
22 | model = None
23 |
24 | lang2token = {
25 | 'zh': "ZH|",
26 | 'ja': "JP|",
27 | "en": "EN|",
28 | }
29 |
30 |
31 | def transcribe_bela(audio_path):
32 |
33 | transcriber = pipeline(
34 | "automatic-speech-recognition",
35 | model="BELLE-2/Belle-whisper-large-v2-zh",
36 | device=device
37 | )
38 |
39 | transcriber.model.config.forced_decoder_ids = (
40 | transcriber.tokenizer.get_decoder_prompt_ids(
41 | language="zh",
42 | task="transcribe",
43 | )
44 | )
45 |
46 | transcription = transcriber(audio_path)
47 |
48 | print(transcription["text"])
49 | return transcription["text"]
50 |
51 |
52 | def transcribe_one(audio_path,mytype):
53 |
54 | segments, info = model.transcribe(audio_path, beam_size=5,vad_filter=True,vad_parameters=dict(min_silence_duration_ms=500),)
55 | print("Detected language '%s' with probability %f" % (info.language, info.language_probability))
56 |
57 | text_str = ""
58 | for segment in segments:
59 | text_str += f"{segment.text.lstrip()},"
60 | print(text_str.rstrip(","))
61 |
62 | return text_str.rstrip(","),info.language
63 |
64 |
65 |
66 | if __name__ == "__main__":
67 |
68 | parser = argparse.ArgumentParser()
69 |
70 | parser.add_argument(
71 | "--language", type=str, default="ja", choices=["ja", "en", "zh"]
72 | )
73 |
74 | parser.add_argument(
75 | "--mytype", type=str, default="medium"
76 | )
77 |
78 | parser.add_argument("--model_name", type=str, required=True)
79 |
80 | parser.add_argument("--input_file", type=str, default="./wavs/")
81 |
82 | parser.add_argument("--file_pos", type=str, default="")
83 |
84 |
85 | args = parser.parse_args()
86 |
87 | speaker_name = args.model_name
88 |
89 | language = args.language
90 |
91 | mytype = args.mytype
92 |
93 | input_file = args.input_file
94 |
95 | if input_file == "":
96 | input_file = "./wavs/"
97 |
98 | file_pos = args.file_pos
99 |
100 |     if device.startswith("cuda"):  # device is set to "cuda:0" above, so compare by prefix
101 | try:
102 | model = WhisperModel(mytype, device="cuda", compute_type="float16",download_root="./whisper_model",local_files_only=False)
103 | except Exception as e:
104 | model = WhisperModel(mytype, device="cuda", compute_type="int8_float16",download_root="./whisper_model",local_files_only=False)
105 | else:
106 | model = WhisperModel(mytype, device="cpu", compute_type="int8",download_root="./whisper_model",local_files_only=False)
107 |
108 |
109 | wav_files = [
110 | f for f in os.listdir(f"{input_file}") if f.endswith(".wav")
111 | ]
112 |
113 |
114 |
115 | with open("./esd.list", "w", encoding="utf-8") as f:
116 | for wav_file in tqdm(wav_files, file=SAFE_STDOUT):
117 | file_name = os.path.basename(wav_file)
118 |
119 | if model:
120 | text,lang = transcribe_one(f"{input_file}"+wav_file,mytype)
121 | else:
122 | text,lang = transcribe_bela(f"{input_file}"+wav_file)
123 |
124 |             # Extract the speaker name (the part before the first "_") with a regex
125 | match = re.search(r'(^.*?)_.*?(\..*?$)', wav_file)
126 | if match:
127 | extracted_name = match.group(1) + match.group(2)
128 | else:
129 | print("No match found")
130 | extracted_name = "sample"
131 |
132 | if lang == "ja":
133 | language_id = "JA"
134 | elif lang == "en":
135 | language_id = "EN"
136 | elif lang == "zh":
137 | language_id = "ZH"
138 |
139 | f.write(file_pos+f"{file_name}|{extracted_name.replace('.wav','')}|{language_id}|{text}\n")
140 |
141 | f.flush()
142 | sys.exit(0)
143 |
144 |
145 |
--------------------------------------------------------------------------------
/bcut_asr/__main__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | import time
4 | from argparse import ArgumentParser, FileType
5 | import ffmpeg
6 | from . import APIError, BcutASR, ResultStateEnum
7 |
8 | logging.basicConfig(format='%(asctime)s - [%(levelname)s] %(message)s', level=logging.INFO)
9 |
10 | INFILE_FMT = ('flac', 'aac', 'm4a', 'mp3', 'wav')
11 | OUTFILE_FMT = ('srt', 'json', 'lrc', 'txt')
12 |
13 | parser = ArgumentParser(
14 | prog='bcut-asr',
15 | description='必剪语音识别\n',
16 | epilog=f"支持输入音频格式: {', '.join(INFILE_FMT)} 支持自动调用ffmpeg提取视频伴音"
17 | )
18 | parser.add_argument('-f', '--format', nargs='?', default='srt', choices=OUTFILE_FMT, help='输出字幕格式')
19 | parser.add_argument('input', type=FileType('rb'), help='输入媒体文件')
20 | parser.add_argument('output', nargs='?', type=FileType('w', encoding='utf8'), help='输出字幕文件, 可stdout')
21 |
22 | args = parser.parse_args()
23 |
24 |
25 | def ffmpeg_render(media_file: str) -> bytes:
26 |     'Extract the audio track from a video and transcode it to AAC'
27 | out, err = (ffmpeg
28 | .input(media_file, v='warning')
29 | .output('pipe:', ac=1, format='adts')
30 | .run(capture_stdout=True)
31 | )
32 | return out
33 |
34 |
35 | def main():
36 |     # Handle the input file
37 | infile = args.input
38 | infile_name = infile.name
39 | if infile_name == '':
40 | logging.error('输入文件错误')
41 | sys.exit(-1)
42 | suffix = infile_name.rsplit('.', 1)[-1]
43 | if suffix in INFILE_FMT:
44 | infile_fmt = suffix
45 | infile_data = infile.read()
46 | else:
47 |         # Use ffmpeg to extract the audio track from the video
48 | logging.info('非标准音频文件, 尝试调用ffmpeg转码')
49 | try:
50 | infile_data = ffmpeg_render(infile_name)
51 | except ffmpeg.Error:
52 | logging.error('ffmpeg转码失败')
53 | sys.exit(-1)
54 | else:
55 | logging.info('ffmpeg转码完成')
56 | infile_fmt = 'aac'
57 |
58 |     # Handle the output file
59 | outfile = args.output
60 | if outfile is None:
61 |         # No output file specified: name it after the input file; the format comes from -f (default: srt)
62 | if args.format is not None:
63 | outfile_fmt = args.format
64 | else:
65 | outfile_fmt = 'srt'
66 | else:
67 |         # An output file was specified
68 | outfile_name = outfile.name
69 | if outfile.name == '':
70 |             # stdout case: the format comes from -f (default: srt)
71 | if args.format is not None:
72 | outfile_fmt = args.format
73 | else:
74 | outfile_fmt = 'srt'
75 | else:
76 | suffix = outfile_name.rsplit('.', 1)[-1]
77 | if suffix in OUTFILE_FMT:
78 | outfile_fmt = suffix
79 | else:
80 | logging.error('输出格式错误')
81 | sys.exit(-1)
82 |
83 |     # Run the recognition workflow
84 | asr = BcutASR()
85 | asr.set_data(raw_data=infile_data, data_fmt=infile_fmt)
86 | try:
87 |         # Upload the file
88 |         asr.upload()
89 |         # Create the task
90 |         task_id = asr.create_task()
91 |         while True:
92 |             # Poll the task state
93 | task_resp = asr.result()
94 | match task_resp.state:
95 | case ResultStateEnum.STOP:
96 | logging.info(f'等待识别开始')
97 | case ResultStateEnum.RUNING:
98 | logging.info(f'识别中-{task_resp.remark}')
99 | case ResultStateEnum.ERROR:
100 | logging.error(f'识别失败-{task_resp.remark}')
101 | sys.exit(-1)
102 | case ResultStateEnum.COMPLETE:
103 | logging.info(f'识别成功')
104 | outfile_name = f"{infile_name.rsplit('.', 1)[-2]}.{outfile_fmt}"
105 | outfile = open(outfile_name, 'w', encoding='utf8')
106 |                     # Recognition succeeded; read back the subtitle data
107 |                     result = task_resp.parse()
108 | break
109 | time.sleep(1.0)
110 | if not result.has_data():
111 | logging.error('未识别到语音')
112 | sys.exit(-1)
113 | match outfile_fmt:
114 | case 'srt':
115 | outfile.write(result.to_srt())
116 | case 'lrc':
117 | outfile.write(result.to_lrc())
118 | case 'json':
119 | outfile.write(result.json())
120 | case 'txt':
121 | outfile.write(result.to_txt())
122 | logging.info(f'转换成功: {outfile_name}')
123 | except APIError as err:
124 | logging.error(f'接口错误: {err.__str__()}')
125 | sys.exit(-1)
126 |
127 |
128 | if __name__ == '__main__':
129 | main()
--------------------------------------------------------------------------------
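
For reference, the CLI defined above can be invoked as a module; the file names below are illustrative:

```
python -m bcut_asr -f srt ./wavs/Erwin_0.wav Erwin_0.srt
```
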
/short_audio_transcribe_whisper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import whisper
4 | import torch
5 |
6 | from tqdm import tqdm
7 | import sys
8 | import os
9 |
10 | from common.constants import Languages
11 | from common.log import logger
12 | from common.stdout_wrapper import SAFE_STDOUT
13 |
14 | import re
15 |
16 | from transformers import pipeline
17 |
18 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
19 |
20 | model = None
21 |
22 |
23 | lang2token = {
24 | 'zh': "ZH|",
25 | 'ja': "JP|",
26 | "en": "EN|",
27 | }
28 |
29 |
30 | def transcribe_bela(audio_path):
31 |
32 | transcriber = pipeline(
33 | "automatic-speech-recognition",
34 | model="BELLE-2/Belle-whisper-large-v2-zh",
35 | device=device
36 | )
37 |
38 | transcriber.model.config.forced_decoder_ids = (
39 | transcriber.tokenizer.get_decoder_prompt_ids(
40 | language="zh",
41 | task="transcribe",
42 | )
43 | )
44 |
45 | transcription = transcriber(audio_path)
46 |
47 | print(transcription["text"])
48 | return transcription["text"]
49 |
50 |
51 | def transcribe_one(audio_path,mytype):
52 | # load audio and pad/trim it to fit 30 seconds
53 | audio = whisper.load_audio(audio_path)
54 | audio = whisper.pad_or_trim(audio)
55 |
56 | # make log-Mel spectrogram and move to the same device as the model
57 |
58 | if mytype == "large-v3":
59 |
60 | mel = whisper.log_mel_spectrogram(audio,n_mels=128).to(model.device)
61 |
62 | else:
63 |
64 | mel = whisper.log_mel_spectrogram(audio).to(model.device)
65 |
66 |
67 | # detect the spoken language
68 | _, probs = model.detect_language(mel)
69 | print(f"Detected language: {max(probs, key=probs.get)}")
70 | lang = max(probs, key=probs.get)
71 | # decode the audio
72 |
73 |
74 | if lang == "zh":
75 |
76 |
77 | if torch.cuda.is_available():
78 | options = whisper.DecodingOptions(beam_size=5,prompt="生于忧患,死于欢乐。不亦快哉!")
79 | else:
80 | options = whisper.DecodingOptions(beam_size=5,fp16 = False,prompt="生于忧患,死于欢乐。不亦快哉!")
81 |
82 | else:
83 |
84 |
85 |
86 | if torch.cuda.is_available():
87 | options = whisper.DecodingOptions(beam_size=5)
88 | else:
89 | options = whisper.DecodingOptions(beam_size=5,fp16 = False)
90 |
91 |
92 |
93 |
94 |
95 |
96 | result = whisper.decode(model, mel, options)
97 |
98 | # print the recognized text
99 | print(result.text)
100 | return result.text,max(probs, key=probs.get)
101 |
102 |
103 | if __name__ == "__main__":
104 |
105 | parser = argparse.ArgumentParser()
106 |
107 | parser.add_argument(
108 | "--language", type=str, default="ja", choices=["ja", "en", "zh"]
109 | )
110 |
111 | parser.add_argument(
112 | "--mytype", type=str, default="medium"
113 | )
114 |
115 | parser.add_argument("--model_name", type=str, required=True)
116 |
117 | parser.add_argument("--input_file", type=str, default="./wavs/")
118 |
119 | parser.add_argument("--file_pos", type=str, default="")
120 |
121 |
122 | args = parser.parse_args()
123 |
124 | speaker_name = args.model_name
125 |
126 | language = args.language
127 |
128 | mytype = args.mytype
129 |
130 | input_file = args.input_file
131 |
132 | if input_file == "":
133 | input_file = "./wavs/"
134 |
135 | file_pos = args.file_pos
136 |
137 | try:
138 | model = whisper.load_model(mytype,download_root="./whisper_model/")
139 | except Exception as e:
140 |
141 | print(str(e))
142 | print("中文特化逻辑")
143 |
144 |
145 | wav_files = [
146 | f for f in os.listdir(f"{input_file}") if f.endswith(".wav")
147 | ]
148 |
149 |
150 |
151 | with open("./esd.list", "w", encoding="utf-8") as f:
152 | for wav_file in tqdm(wav_files, file=SAFE_STDOUT):
153 | file_name = os.path.basename(wav_file)
154 |
155 | if model:
156 | text,lang = transcribe_one(f"{input_file}"+wav_file,mytype)
157 | else:
158 | text,lang = transcribe_bela(f"{input_file}"+wav_file)
159 |
160 |             # Extract the speaker name (the part before the first "_") with a regex
161 | match = re.search(r'(^.*?)_.*?(\..*?$)', wav_file)
162 | if match:
163 | extracted_name = match.group(1) + match.group(2)
164 | else:
165 | print("No match found")
166 | extracted_name = "sample"
167 |
168 | if lang == "ja":
169 | language_id = "JA"
170 | elif lang == "en":
171 | language_id = "EN"
172 | elif lang == "zh":
173 | language_id = "ZH"
174 |
175 | f.write(file_pos+f"{file_name}|{extracted_name.replace('.wav','')}|{language_id}|{text}\n")
176 |
177 | f.flush()
178 | sys.exit(0)
179 |
180 |
181 |
--------------------------------------------------------------------------------
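
A typical invocation of the script above (the speaker/model name is illustrative; --mytype selects the Whisper checkpoint size and results are written to ./esd.list):

```
python short_audio_transcribe_whisper.py --model_name Erwin --language zh --mytype medium --input_file ./wavs/
```
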
/subtitle_utils.py:
--------------------------------------------------------------------------------
1 | def time_convert(ms):
2 | ms = int(ms)
3 | tail = ms % 1000
4 | s = ms // 1000
5 | mi = s // 60
6 | s = s % 60
7 | h = mi // 60
8 | mi = mi % 60
9 | h = "00" if h == 0 else str(h)
10 | mi = "00" if mi == 0 else str(mi)
11 | s = "00" if s == 0 else str(s)
12 |     tail = str(tail).zfill(3)  # pad milliseconds to three digits for SRT
13 | if len(h) == 1: h = '0' + h
14 | if len(mi) == 1: mi = '0' + mi
15 | if len(s) == 1: s = '0' + s
16 | return "{}:{}:{},{}".format(h, mi, s, tail)
17 |
18 |
19 | class Text2SRT():
20 | def __init__(self, text_seg, ts_list, offset=0):
21 | self.token_list = [i for i in text_seg.split() if len(i)]
22 | self.ts_list = ts_list
23 | start, end = ts_list[0][0] - offset, ts_list[-1][1] - offset
24 | self.start_sec, self.end_sec = start, end
25 | self.start_time = time_convert(start)
26 | self.end_time = time_convert(end)
27 | def text(self):
28 | res = ""
29 | for word in self.token_list:
30 | if '\u4e00' <= word <= '\u9fff':
31 | res += word
32 | else:
33 | res += " " + word
34 | return res
35 | def len(self):
36 | return len(self.token_list)
37 | def srt(self, acc_ost=0.0):
38 | return "{} --> {}\n{}\n".format(
39 | time_convert(self.start_sec+acc_ost*1000),
40 | time_convert(self.end_sec+acc_ost*1000),
41 | self.text())
42 | def time(self, acc_ost=0.0):
43 | return (self.start_sec/1000+acc_ost, self.end_sec/1000+acc_ost)
44 |
45 | def distribute_spk(sentence_list, sd_time_list):
46 | sd_sentence_list = []
47 | for d in sentence_list:
48 | sentence_start = d['ts_list'][0][0]
49 | sentence_end = d['ts_list'][-1][1]
50 | sentence_spk = 0
51 | max_overlap = 0
52 | for sd_time in sd_time_list:
53 | spk_st, spk_ed, spk = sd_time
54 | spk_st = spk_st*1000
55 | spk_ed = spk_ed*1000
56 | overlap = max(
57 | min(sentence_end, spk_ed) - max(sentence_start, spk_st), 0)
58 | if overlap > max_overlap:
59 | max_overlap = overlap
60 | sentence_spk = spk
61 | d['spk'] = sentence_spk
62 | sd_sentence_list.append(d)
63 | return sd_sentence_list
64 |
65 | def generate_srt(sentence_list):
66 | srt_total = ''
67 | for i, d in enumerate(sentence_list):
68 | t2s = Text2SRT(d['text_seg'], d['ts_list'])
69 | if 'spk' in d:
70 | srt_total += "{} spk{}\n{}".format(i, d['spk'], t2s.srt())
71 | else:
72 | srt_total += "{}\n{}".format(i, t2s.srt())
73 | return srt_total
74 |
75 | def generate_srt_clip(sentence_list, start, end, begin_index=0, time_acc_ost=0.0):
76 | start, end = int(start * 1000), int(end * 1000)
77 | srt_total = ''
78 | cc = 1 + begin_index
79 | subs = []
80 | for i, d in enumerate(sentence_list):
81 | if d['ts_list'][-1][1] <= start:
82 | continue
83 | if d['ts_list'][0][0] >= end:
84 | break
85 | # parts in between
86 | if (d['ts_list'][-1][1] <= end and d['ts_list'][0][0] > start) or (d['ts_list'][-1][1] == end and d['ts_list'][0][0] == start):
87 | t2s = Text2SRT(d['text_seg'], d['ts_list'], offset=start)
88 | srt_total += "{}\n{}".format(cc, t2s.srt(time_acc_ost))
89 | subs.append((t2s.time(time_acc_ost), t2s.text()))
90 | cc += 1
91 | continue
92 | if d['ts_list'][0][0] <= start:
93 | if not d['ts_list'][-1][1] > end:
94 | for j, ts in enumerate(d['ts_list']):
95 | if ts[1] > start:
96 | break
97 | _text = " ".join(d['text_seg'].split()[j:])
98 | _ts = d['ts_list'][j:]
99 | else:
100 | for j, ts in enumerate(d['ts_list']):
101 | if ts[1] > start:
102 | _start = j
103 | break
104 | for j, ts in enumerate(d['ts_list']):
105 | if ts[1] > end:
106 | _end = j
107 | break
108 | _text = " ".join(d['text_seg'].split()[_start:_end])
109 | _ts = d['ts_list'][_start:_end]
110 |             if len(_ts):
111 | t2s = Text2SRT(_text, _ts, offset=start)
112 | srt_total += "{}\n{}".format(cc, t2s.srt(time_acc_ost))
113 | subs.append((t2s.time(time_acc_ost), t2s.text()))
114 | cc += 1
115 | continue
116 | if d['ts_list'][-1][1] > end:
117 | for j, ts in enumerate(d['ts_list']):
118 | if ts[1] > end:
119 | break
120 | _text = " ".join(d['text_seg'].split()[:j])
121 | _ts = d['ts_list'][:j]
122 | if len(_ts):
123 | t2s = Text2SRT(_text, _ts, offset=start)
124 | srt_total += "{}\n{}".format(cc, t2s.srt(time_acc_ost))
125 | subs.append(
126 | (t2s.time(time_acc_ost), t2s.text())
127 | )
128 | cc += 1
129 | continue
130 | return srt_total, subs, cc
131 |
--------------------------------------------------------------------------------
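
A quick sanity check of time_convert (the inputs are arbitrary millisecond values):

```
from subtitle_utils import time_convert

print(time_convert(83500))    # 00:01:23,500
print(time_convert(3599999))  # 00:59:59,999
```
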
/short_audio_transcribe_ali.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import whisper
4 | import torch
5 |
6 | from tqdm import tqdm
7 | import sys
8 | import os
9 |
10 |
11 |
12 | from common.constants import Languages
13 | from common.log import logger
14 | from common.stdout_wrapper import SAFE_STDOUT
15 |
16 | import re
17 |
18 |
19 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
20 |
21 |
22 | from funasr import AutoModel
23 |
24 | model_dir = "iic/SenseVoiceSmall"
25 |
26 |
27 | emo_dict = {
28 | "<|HAPPY|>": "😊",
29 | "<|SAD|>": "😔",
30 | "<|ANGRY|>": "😡",
31 | "<|NEUTRAL|>": "",
32 | "<|FEARFUL|>": "😰",
33 | "<|DISGUSTED|>": "🤢",
34 | "<|SURPRISED|>": "😮",
35 | }
36 |
37 | event_dict = {
38 | "<|BGM|>": "🎼",
39 | "<|Speech|>": "",
40 | "<|Applause|>": "👏",
41 | "<|Laughter|>": "😀",
42 | "<|Cry|>": "😭",
43 | "<|Sneeze|>": "🤧",
44 | "<|Breath|>": "",
45 | "<|Cough|>": "🤧",
46 | }
47 |
48 | emoji_dict = {
49 | "<|nospeech|><|Event_UNK|>": "❓",
50 | "<|zh|>": "",
51 | "<|en|>": "",
52 | "<|yue|>": "",
53 | "<|ja|>": "",
54 | "<|ko|>": "",
55 | "<|nospeech|>": "",
56 | "<|HAPPY|>": "😊",
57 | "<|SAD|>": "😔",
58 | "<|ANGRY|>": "😡",
59 | "<|NEUTRAL|>": "",
60 | "<|BGM|>": "🎼",
61 | "<|Speech|>": "",
62 | "<|Applause|>": "👏",
63 | "<|Laughter|>": "😀",
64 | "<|FEARFUL|>": "😰",
65 | "<|DISGUSTED|>": "🤢",
66 | "<|SURPRISED|>": "😮",
67 | "<|Cry|>": "😭",
68 | "<|EMO_UNKNOWN|>": "",
69 | "<|Sneeze|>": "🤧",
70 | "<|Breath|>": "",
71 | "<|Cough|>": "😷",
72 | "<|Sing|>": "",
73 | "<|Speech_Noise|>": "",
74 | "<|withitn|>": "",
75 | "<|woitn|>": "",
76 | "<|GBG|>": "",
77 | "<|Event_UNK|>": "",
78 | }
79 |
80 | lang_dict = {
81 | "<|zh|>": "<|lang|>",
82 | "<|en|>": "<|lang|>",
83 | "<|yue|>": "<|lang|>",
84 | "<|ja|>": "<|lang|>",
85 | "<|ko|>": "<|lang|>",
86 | "<|nospeech|>": "<|lang|>",
87 | }
88 |
89 | emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
90 | event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷",}
91 |
92 | lang2token = {
93 | 'zh': "ZH|",
94 | 'ja': "JP|",
95 | "en": "EN|",
96 | "ko": "KO|",
97 | "yue": "YUE|",
98 | }
99 |
100 | def format_str(s):
101 | for sptk in emoji_dict:
102 | s = s.replace(sptk, emoji_dict[sptk])
103 | return s
104 |
105 |
106 | def format_str_v2(s):
107 | sptk_dict = {}
108 | for sptk in emoji_dict:
109 | sptk_dict[sptk] = s.count(sptk)
110 | s = s.replace(sptk, "")
111 | emo = "<|NEUTRAL|>"
112 | for e in emo_dict:
113 | if sptk_dict[e] > sptk_dict[emo]:
114 | emo = e
115 | for e in event_dict:
116 | if sptk_dict[e] > 0:
117 | s = event_dict[e] + s
118 | s = s + emo_dict[emo]
119 |
120 | for emoji in emo_set.union(event_set):
121 | s = s.replace(" " + emoji, emoji)
122 | s = s.replace(emoji + " ", emoji)
123 | return s.strip()
124 |
125 | def format_str_v3(s):
126 | def get_emo(s):
127 | return s[-1] if s[-1] in emo_set else None
128 | def get_event(s):
129 | return s[0] if s[0] in event_set else None
130 |
131 | s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
132 | for lang in lang_dict:
133 | s = s.replace(lang, "<|lang|>")
134 | s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
135 | new_s = " " + s_list[0]
136 | cur_ent_event = get_event(new_s)
137 | for i in range(1, len(s_list)):
138 | if len(s_list[i]) == 0:
139 | continue
140 | if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None:
141 | s_list[i] = s_list[i][1:]
142 | #else:
143 | cur_ent_event = get_event(s_list[i])
144 | if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s):
145 | new_s = new_s[:-1]
146 | new_s += s_list[i].strip().lstrip()
147 | new_s = new_s.replace("The.", " ")
148 | return new_s.strip()
149 |
150 | def transcribe_one(audio_path,language):
151 |
152 | model = AutoModel(model=model_dir,
153 | vad_model="fsmn-vad",
154 | vad_kwargs={"max_single_segment_time": 30000},
155 |                       trust_remote_code=True, device=device)  # fall back to CPU when CUDA is unavailable
156 |
157 | res = model.generate(
158 | input=audio_path,
159 | cache={},
160 |         language=language,  # "zh", "en", "yue", "ja", "ko", "nospeech"
161 | use_itn=False,
162 | batch_size_s=0,
163 | )
164 |
165 | try:
166 |
167 | text = res[0]["text"]
168 | text = format_str_v3(text)
169 | print(text)
170 | except Exception as e:
171 | print(e)
172 | text = ""
173 |
174 |
175 | return text,language
176 |
177 |
178 | if __name__ == "__main__":
179 |
180 | parser = argparse.ArgumentParser()
181 |
182 | parser.add_argument(
183 | "--language", type=str, default="ja", choices=["ja", "en", "zh","yue","ko"]
184 | )
185 | parser.add_argument("--model_name", type=str, required=True)
186 |
187 |
188 | parser.add_argument("--input_file", type=str, default="./wavs/")
189 |
190 | parser.add_argument("--file_pos", type=str, default="")
191 |
192 |
193 | args = parser.parse_args()
194 |
195 | speaker_name = args.model_name
196 |
197 | language = args.language
198 |
199 |
200 | input_file = args.input_file
201 |
202 | if input_file == "":
203 | input_file = "./wavs/"
204 |
205 | file_pos = args.file_pos
206 |
207 |
208 | wav_files = [
209 | f for f in os.listdir(f"{input_file}") if f.endswith(".wav")
210 | ]
211 |
212 |
213 | with open("./esd.list", "w", encoding="utf-8") as f:
214 | for wav_file in tqdm(wav_files, file=SAFE_STDOUT):
215 | file_name = os.path.basename(wav_file)
216 |
217 | text,lang = transcribe_one(f"{input_file}"+wav_file,language)
218 |
219 |             # Extract the speaker name (the part before the first "_") with a regex
220 | match = re.search(r'(^.*?)_.*?(\..*?$)', wav_file)
221 | if match:
222 | extracted_name = match.group(1) + match.group(2)
223 | else:
224 | print("No match found")
225 | extracted_name = "sample"
226 |
227 | if lang == "ja":
228 | language_id = "JA"
229 | elif lang == "en":
230 | language_id = "EN"
231 | elif lang == "zh":
232 | language_id = "ZH"
233 | elif lang == "yue":
234 | language_id = "YUE"
235 | elif lang == "ko":
236 | language_id = "KO"
237 |
238 | f.write(file_pos+f"{file_name}|{extracted_name.replace('.wav','')}|{language_id}|{text}\n")
239 |
240 | f.flush()
241 | sys.exit(0)
242 |
243 |
244 |
--------------------------------------------------------------------------------
/slicer2.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | # This function is obtained from librosa.
5 | def get_rms(
6 | y,
7 | *,
8 | frame_length=2048,
9 | hop_length=512,
10 | pad_mode="constant",
11 | ):
12 | padding = (int(frame_length // 2), int(frame_length // 2))
13 | y = np.pad(y, padding, mode=pad_mode)
14 |
15 | axis = -1
16 | # put our new within-frame axis at the end for now
17 | out_strides = y.strides + tuple([y.strides[axis]])
18 | # Reduce the shape on the framing axis
19 | x_shape_trimmed = list(y.shape)
20 | x_shape_trimmed[axis] -= frame_length - 1
21 | out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
22 | xw = np.lib.stride_tricks.as_strided(
23 | y, shape=out_shape, strides=out_strides
24 | )
25 | if axis < 0:
26 | target_axis = axis - 1
27 | else:
28 | target_axis = axis + 1
29 | xw = np.moveaxis(xw, -1, target_axis)
30 | # Downsample along the target axis
31 | slices = [slice(None)] * xw.ndim
32 | slices[axis] = slice(0, None, hop_length)
33 | x = xw[tuple(slices)]
34 |
35 | # Calculate power
36 | power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
37 |
38 | return np.sqrt(power)
39 |
40 |
41 | class Slicer:
42 | def __init__(self,
43 | sr: int,
44 | threshold: float = -40.,
45 | min_length: int = 5000,
46 | min_interval: int = 300,
47 | hop_size: int = 20,
48 | max_sil_kept: int = 5000):
49 | if not min_length >= min_interval >= hop_size:
50 | raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
51 | if not max_sil_kept >= hop_size:
52 | raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
53 | min_interval = sr * min_interval / 1000
54 | self.threshold = 10 ** (threshold / 20.)
55 | self.hop_size = round(sr * hop_size / 1000)
56 | self.win_size = min(round(min_interval), 4 * self.hop_size)
57 | self.min_length = round(sr * min_length / 1000 / self.hop_size)
58 | self.min_interval = round(min_interval / self.hop_size)
59 | self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
60 |
61 | def _apply_slice(self, waveform, begin, end):
62 | if len(waveform.shape) > 1:
63 | return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
64 | else:
65 | return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]
66 |
67 | # @timeit
68 | def slice(self, waveform):
69 | if len(waveform.shape) > 1:
70 | samples = waveform.mean(axis=0)
71 | else:
72 | samples = waveform
73 | if (samples.shape[0] + self.hop_size - 1) // self.hop_size <= self.min_length:
74 | return [waveform]
75 | rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
76 | sil_tags = []
77 | silence_start = None
78 | clip_start = 0
79 | for i, rms in enumerate(rms_list):
80 | # Keep looping while frame is silent.
81 | if rms < self.threshold:
82 | # Record start of silent frames.
83 | if silence_start is None:
84 | silence_start = i
85 | continue
86 | # Keep looping while frame is not silent and silence start has not been recorded.
87 | if silence_start is None:
88 | continue
89 | # Clear recorded silence start if interval is not enough or clip is too short
90 | is_leading_silence = silence_start == 0 and i > self.max_sil_kept
91 | need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
92 | if not is_leading_silence and not need_slice_middle:
93 | silence_start = None
94 | continue
95 | # Need slicing. Record the range of silent frames to be removed.
96 | if i - silence_start <= self.max_sil_kept:
97 | pos = rms_list[silence_start: i + 1].argmin() + silence_start
98 | if silence_start == 0:
99 | sil_tags.append((0, pos))
100 | else:
101 | sil_tags.append((pos, pos))
102 | clip_start = pos
103 | elif i - silence_start <= self.max_sil_kept * 2:
104 | pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
105 | pos += i - self.max_sil_kept
106 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
107 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
108 | if silence_start == 0:
109 | sil_tags.append((0, pos_r))
110 | clip_start = pos_r
111 | else:
112 | sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
113 | clip_start = max(pos_r, pos)
114 | else:
115 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
116 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
117 | if silence_start == 0:
118 | sil_tags.append((0, pos_r))
119 | else:
120 | sil_tags.append((pos_l, pos_r))
121 | clip_start = pos_r
122 | silence_start = None
123 | # Deal with trailing silence.
124 | total_frames = rms_list.shape[0]
125 | if silence_start is not None and total_frames - silence_start >= self.min_interval:
126 | silence_end = min(total_frames, silence_start + self.max_sil_kept)
127 | pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start
128 | sil_tags.append((pos, total_frames + 1))
129 | # Apply and return slices.
130 | if len(sil_tags) == 0:
131 | return [waveform]
132 | else:
133 | chunks = []
134 | if sil_tags[0][0] > 0:
135 | chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0]))
136 | for i in range(len(sil_tags) - 1):
137 | chunks.append(self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]))
138 | if sil_tags[-1][1] < total_frames:
139 | chunks.append(self._apply_slice(waveform, sil_tags[-1][1], total_frames))
140 | return chunks
141 |
142 |
143 | def main():
144 | import os.path
145 | from argparse import ArgumentParser
146 |
147 | import librosa
148 | import soundfile
149 |
150 | parser = ArgumentParser()
151 | parser.add_argument('audio', type=str, help='The audio to be sliced')
152 | parser.add_argument('--out', type=str, help='Output directory of the sliced audio clips')
153 | parser.add_argument('--db_thresh', type=float, required=False, default=-40,
154 | help='The dB threshold for silence detection')
155 | parser.add_argument('--min_length', type=int, required=False, default=5000,
156 | help='The minimum milliseconds required for each sliced audio clip')
157 | parser.add_argument('--min_interval', type=int, required=False, default=300,
158 | help='The minimum milliseconds for a silence part to be sliced')
159 | parser.add_argument('--hop_size', type=int, required=False, default=10,
160 | help='Frame length in milliseconds')
161 | parser.add_argument('--max_sil_kept', type=int, required=False, default=500,
162 | help='The maximum silence length kept around the sliced clip, presented in milliseconds')
163 | args = parser.parse_args()
164 | out = args.out
165 | if out is None:
166 | out = os.path.dirname(os.path.abspath(args.audio))
167 | audio, sr = librosa.load(args.audio, sr=None, mono=False)
168 | slicer = Slicer(
169 | sr=sr,
170 | threshold=args.db_thresh,
171 | min_length=args.min_length,
172 | min_interval=args.min_interval,
173 | hop_size=args.hop_size,
174 | max_sil_kept=args.max_sil_kept
175 | )
176 | chunks = slicer.slice(audio)
177 | if not os.path.exists(out):
178 | os.makedirs(out)
179 | for i, chunk in enumerate(chunks):
180 | if len(chunk.shape) > 1:
181 | chunk = chunk.T
182 |         soundfile.write(os.path.join(out, '%s_%d.wav' % (os.path.basename(args.audio).rsplit('.', maxsplit=1)[0], i)), chunk, sr)
183 |
184 |
185 | if __name__ == '__main__':
186 | main()
--------------------------------------------------------------------------------
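For reference, the slicer above can also be driven programmatically instead of through main(); a minimal sketch mirroring the CLI defaults (the input path is illustrative):

import librosa
import soundfile
from slicer2 import Slicer

# load with the native sample rate and keep all channels, exactly as main() does
audio, sr = librosa.load('./raw/sample.wav', sr=None, mono=False)
slicer = Slicer(sr=sr, threshold=-40, min_length=5000, min_interval=300,
                hop_size=10, max_sil_kept=500)
for i, chunk in enumerate(slicer.slice(audio)):
    # soundfile expects (frames, channels), so transpose multi-channel chunks
    soundfile.write(f'./wavs/sample_{i}.wav', chunk.T if chunk.ndim > 1 else chunk, sr)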
/webui_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import gradio as gr
5 | import yaml
6 |
7 | from common.log import logger
8 | from common.subprocess_utils import run_script_with_log
9 |
10 | from modelscope.pipelines import pipeline
11 | from modelscope.utils.constant import Tasks
12 | from videoclipper import VideoClipper
13 | import librosa
14 | import soundfile as sf
15 | import numpy as np
16 | import random
17 |
18 | dataset_root = ".\\raw\\"
19 |
20 |
21 |
22 | sd_pipeline = pipeline(
23 | task='speaker-diarization',
24 | model='damo/speech_campplus_speaker-diarization_common',
25 | model_revision='v1.0.0'
26 | )
27 |
28 | def audio_change(audio):
29 |
30 | print(audio)
31 |
32 | sf.write('./output_44100.wav', audio[1], audio[0], 'PCM_24')
33 |
34 | y, sr = librosa.load('./output_44100.wav', sr=16000)
35 |
36 | # sf.write('./output_16000.wav', y, sr, 'PCM_24')
37 |
38 | # arr = np.array(y, dtype=np.int32)
39 |
40 | # y, sr = librosa.load('./output_16000.wav', sr=16000)
41 |
42 | audio_data = np.array(y)
43 |
44 | print(y, sr)
45 |
46 | return (16000,audio_data)
47 |
48 | def write_list(text,audio):
49 |
50 | random_number = random.randint(10000, 99999)
51 |
52 | wav_name = f'./wavs/sample_{random_number}.wav'
53 |
54 | sf.write(wav_name, audio[1], audio[0], 'PCM_24')
55 |
56 | text = text.replace("#",",")
57 |
58 |     with open("./esd.list", "a", encoding="utf-8") as f:
59 |         f.write(f"\n{wav_name}|sample|en|{text}")
60 |
61 |
62 | # NOTE: audio_recog/audio_clip below are not wired into this UI and assume a module-level audio_clipper (VideoClipper) instance that this script never creates.
63 | def audio_recog(audio_input, sd_switch):
64 | print(audio_input)
65 | return audio_clipper.recog(audio_input, sd_switch)
66 |
67 | def audio_clip(dest_text, audio_spk_input, start_ost, end_ost, state):
68 | return audio_clipper.clip(dest_text, start_ost, end_ost, state, dest_spk=audio_spk_input)
69 |
70 | # Audio denoising
71 |
72 | def reset_tts_wav(audio):
73 |
74 | ans = pipeline(
75 | Tasks.acoustic_noise_suppression,
76 | model='damo/speech_frcrn_ans_cirm_16k')
77 | ans(audio,output_path='./output_ins.wav')
78 |
79 | return "./output_ins.wav","./output_ins.wav"
80 |
81 |
82 | def do_slice(
83 | dataset_path: str,
84 | min_sec: int,
85 | max_sec: int,
86 | min_silence_dur_ms: int,
87 | ):
88 | if dataset_path == "":
89 | return "Error: 数据集路径不能为空"
90 | logger.info("Start slicing...")
91 | output_dir = os.path.join(dataset_root, dataset_path, ".\\wavs")
92 |
93 |
94 | cmd = [
95 | "audio_slicer_pre.py",
96 | "--dataset_path",
97 | dataset_path,
98 | "--min_sec",
99 | str(min_sec),
100 | "--max_sec",
101 | str(max_sec),
102 | "--min_silence_dur_ms",
103 | str(min_silence_dur_ms),
104 | ]
105 |
106 |
107 | success, message = run_script_with_log(cmd, ignore_warning=True)
108 | if not success:
109 | return f"Error: {message}"
110 | return "切分完毕"
111 |
112 |
113 | def do_transcribe_fwhisper(
114 | model_name,mytype,language,input_file,file_pos
115 | ):
116 | # if model_name == "":
117 | # return "Error: 角色名不能为空"
118 |
119 |
120 | cmd_py = "short_audio_transcribe_fwhisper.py"
121 |
122 |
123 | success, message = run_script_with_log(
124 | [
125 | cmd_py,
126 | "--model_name",
127 | model_name,
128 | "--language",
129 | language,
130 | "--mytype",
131 |             mytype,
132 |             "--input_file",
133 |             input_file,
134 |             "--file_pos",
135 |             file_pos,
136 | ]
137 | )
138 | if not success:
139 | return f"Error: {message}"
140 | return "转写完毕"
141 |
142 | def do_transcribe_whisper(
143 | model_name,mytype,language,input_file,file_pos
144 | ):
145 | # if model_name == "":
146 | # return "Error: 角色名不能为空"
147 |
148 |
149 | cmd_py = "short_audio_transcribe_whisper.py"
150 |
151 |
152 | success, message = run_script_with_log(
153 | [
154 | cmd_py,
155 | "--model_name",
156 | model_name,
157 | "--language",
158 | language,
159 | "--mytype",
160 |             mytype,
161 |             "--input_file",
162 |             input_file,
163 |             "--file_pos",
164 |             file_pos,
165 | ]
166 | )
167 | if not success:
168 | return f"Error: {message}"
169 | return "转写完毕"
170 |
171 |
172 | def do_transcribe_all(
173 | model_name,mytype,language,input_file,file_pos
174 | ):
175 | # if model_name == "":
176 | # return "Error: 角色名不能为空"
177 |
178 |
179 | cmd_py = "short_audio_transcribe_ali.py"
180 |
181 |
182 | if mytype == "bcut":
183 |
184 | cmd_py = "short_audio_transcribe_bcut.py"
185 |
186 | success, message = run_script_with_log(
187 | [
188 | cmd_py,
189 | "--model_name",
190 | model_name,
191 | "--language",
192 | language,
193 | "--input_file",
194 | input_file,
195 | "--file_pos",
196 | file_pos,
197 |
198 | ]
199 | )
200 | if not success:
201 | return f"Error: {message}"
202 | return "转写完毕"
203 |
204 |
205 | initial_md = """
206 |
207 | 请把格式为 角色名.wav 的素材文件放入项目的raw目录
208 |
209 | 作者:刘悦的技术博客 https://space.bilibili.com/3031494
210 |
211 | """
212 |
213 | with gr.Blocks(theme="NoCrypt/miku") as app:
214 | gr.Markdown(initial_md)
215 | model_name = gr.Textbox(label="角色名",placeholder="请输入角色名",visible=False)
216 |
217 |
218 | with gr.Accordion("干声抽离和降噪"):
219 | with gr.Row():
220 | audio_inp_path = gr.Audio(label="请上传克隆对象音频", type="filepath")
221 | reset_inp_button = gr.Button("针对原始素材进行降噪", variant="primary",visible=True)
222 | reset_dataset_path = gr.Textbox(label="降噪后音频地址",placeholder="降噪后生成的音频地址")
223 |
224 |
225 | reset_inp_button.click(reset_tts_wav,[audio_inp_path],[audio_inp_path,reset_dataset_path])
226 |
227 | with gr.Accordion("音频素材切割"):
228 | with gr.Row():
229 |             # added by hyh: a textbox for the dataset path
230 | dataset_path = gr.Textbox(label="音频素材所在路径,默认在项目的raw文件夹,支持批量角色切分",placeholder="设置音频素材所在路径",value="./raw/")
231 | with gr.Column():
232 |
233 | min_sec = gr.Slider(
234 | minimum=0, maximum=7000, value=2500, step=100, label="最低几毫秒"
235 | )
236 | max_sec = gr.Slider(
237 | minimum=0, maximum=15000, value=5000, step=100, label="最高几毫秒"
238 | )
239 | min_silence_dur_ms = gr.Slider(
240 | minimum=500,
241 | maximum=5000,
242 | value=500,
243 | step=100,
244 | label="max_sil_kept长度",
245 | )
246 | slice_button = gr.Button("开始切分")
247 | result1 = gr.Textbox(label="結果")
248 |
249 |
250 |
251 |
252 |     with gr.Accordion("音频批量转写,转写文件存放在根目录的esd.list"):
253 | with gr.Row():
254 | with gr.Column():
255 |
256 | language = gr.Dropdown(["ja", "en", "zh","ko","yue"], value="zh", label="选择转写的语言")
257 |
258 | mytype = gr.Dropdown(["small","medium","large-v3","large-v2"], value="medium", label="选择Whisper模型")
259 |
260 | input_file = gr.Textbox(label="切片所在目录",placeholder="不填默认为./wavs目录")
261 |
262 | file_pos = gr.Textbox(label="切片名称前缀",placeholder="不填只有切片文件名")
263 |
264 | transcribe_button_whisper = gr.Button("Whisper开始转写")
265 |
266 | transcribe_button_fwhisper = gr.Button("Faster-Whisper开始转写")
267 |
268 | transcribe_button_ali = gr.Button("阿里SenseVoice开始转写")
269 |
270 | transcribe_button_bcut = gr.Button("必剪ASR开始转写")
271 |
272 |
273 | result2 = gr.Textbox(label="結果")
274 |
275 | slice_button.click(
276 | do_slice,
277 | inputs=[dataset_path, min_sec, max_sec, min_silence_dur_ms],
278 | outputs=[result1],
279 | )
280 | transcribe_button_whisper.click(
281 | do_transcribe_whisper,
282 | inputs=[
283 | model_name,
284 | mytype,
285 | language,input_file,file_pos
286 | ],
287 | outputs=[result2],)
288 |
289 |
290 | transcribe_button_fwhisper.click(
291 | do_transcribe_fwhisper,
292 | inputs=[
293 | model_name,
294 | mytype,
295 | language,input_file,file_pos
296 | ],
297 | outputs=[result2],)
298 |
299 |
300 | ali = gr.Text(value="ali",visible=False)
301 |
302 | bcut = gr.Text(value="bcut",visible=False)
303 |
304 |
305 | transcribe_button_ali.click(
306 | do_transcribe_all,
307 | inputs=[
308 | model_name,
309 | ali,
310 | language,input_file,file_pos
311 | ],
312 | outputs=[result2],
313 | )
314 |
315 | transcribe_button_bcut.click(
316 | do_transcribe_all,
317 | inputs=[
318 | model_name,
319 | bcut,
320 | language,input_file,file_pos
321 | ],
322 | outputs=[result2],
323 | )
324 |
325 | parser = argparse.ArgumentParser()
326 | parser.add_argument(
327 | "--server-name",
328 | type=str,
329 | default=None,
330 | help="Server name for Gradio app",
331 | )
332 | parser.add_argument(
333 | "--no-autolaunch",
334 | action="store_true",
335 | default=False,
336 | help="Do not launch app automatically",
337 | )
338 | args = parser.parse_args()
339 |
340 | app.launch(inbrowser=not args.no_autolaunch, server_name=args.server_name, server_port=7971)
341 |
--------------------------------------------------------------------------------
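For orientation, write_list above appends one pipe-separated record per clip to ./esd.list, in the form wav_path|speaker|language|text; an illustrative line (the values are made up):

./wavs/sample_12345.wav|sample|en|Hello there, this is a sliced sample.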
/common/tts_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import gradio as gr
3 | import torch
4 | import os
5 | import warnings
6 | from gradio.processing_utils import convert_to_16_bit_wav
7 | from typing import Dict, List, Optional
8 |
9 | import utils
10 | from infer import get_net_g, infer
11 | from models import SynthesizerTrn
12 |
13 | from .log import logger
14 | from .constants import (
15 | DEFAULT_ASSIST_TEXT_WEIGHT,
16 | DEFAULT_LENGTH,
17 | DEFAULT_LINE_SPLIT,
18 | DEFAULT_NOISE,
19 | DEFAULT_NOISEW,
20 | DEFAULT_SDP_RATIO,
21 | DEFAULT_SPLIT_INTERVAL,
22 | DEFAULT_STYLE,
23 | DEFAULT_STYLE_WEIGHT,
24 | )
25 |
26 |
27 | class Model:
28 | def __init__(
29 | self, model_path: str, config_path: str, style_vec_path: str, device: str
30 | ):
31 | self.model_path: str = model_path
32 | self.config_path: str = config_path
33 | self.device: str = device
34 | self.style_vec_path: str = style_vec_path
35 | self.hps: utils.HParams = utils.get_hparams_from_file(self.config_path)
36 | self.spk2id: Dict[str, int] = self.hps.data.spk2id
37 | self.id2spk: Dict[int, str] = {v: k for k, v in self.spk2id.items()}
38 |
39 | self.num_styles: int = self.hps.data.num_styles
40 | if hasattr(self.hps.data, "style2id"):
41 | self.style2id: Dict[str, int] = self.hps.data.style2id
42 | else:
43 | self.style2id: Dict[str, int] = {str(i): i for i in range(self.num_styles)}
44 | if len(self.style2id) != self.num_styles:
45 | raise ValueError(
46 | f"Number of styles ({self.num_styles}) does not match the number of style2id ({len(self.style2id)})"
47 | )
48 |
49 | self.style_vectors: np.ndarray = np.load(self.style_vec_path)
50 | if self.style_vectors.shape[0] != self.num_styles:
51 | raise ValueError(
52 | f"The number of styles ({self.num_styles}) does not match the number of style vectors ({self.style_vectors.shape[0]})"
53 | )
54 |
55 | self.net_g: Optional[SynthesizerTrn] = None
56 |
57 | def load_net_g(self):
58 | self.net_g = get_net_g(
59 | model_path=self.model_path,
60 | version=self.hps.version,
61 | device=self.device,
62 | hps=self.hps,
63 | )
64 |
65 | def get_style_vector(self, style_id: int, weight: float = 1.0) -> np.ndarray:
66 | mean = self.style_vectors[0]
67 | style_vec = self.style_vectors[style_id]
68 | style_vec = mean + (style_vec - mean) * weight
69 | return style_vec
70 |
71 | def get_style_vector_from_audio(
72 | self, audio_path: str, weight: float = 1.0
73 | ) -> np.ndarray:
74 | from style_gen import get_style_vector
75 |
76 | xvec = get_style_vector(audio_path)
77 | mean = self.style_vectors[0]
78 | xvec = mean + (xvec - mean) * weight
79 | return xvec
80 |
81 | def infer(
82 | self,
83 | text: str,
84 | language: str = "JP",
85 | sid: int = 0,
86 | reference_audio_path: Optional[str] = None,
87 | sdp_ratio: float = DEFAULT_SDP_RATIO,
88 | noise: float = DEFAULT_NOISE,
89 | noisew: float = DEFAULT_NOISEW,
90 | length: float = DEFAULT_LENGTH,
91 | line_split: bool = DEFAULT_LINE_SPLIT,
92 | split_interval: float = DEFAULT_SPLIT_INTERVAL,
93 | assist_text: Optional[str] = None,
94 | assist_text_weight: float = DEFAULT_ASSIST_TEXT_WEIGHT,
95 | use_assist_text: bool = False,
96 | style: str = DEFAULT_STYLE,
97 | style_weight: float = DEFAULT_STYLE_WEIGHT,
98 | given_tone: Optional[list[int]] = None,
99 | ) -> tuple[int, np.ndarray]:
100 | logger.info(f"Start generating audio data from text:\n{text}")
101 | if reference_audio_path == "":
102 | reference_audio_path = None
103 | if assist_text == "" or not use_assist_text:
104 | assist_text = None
105 |
106 | if self.net_g is None:
107 | self.load_net_g()
108 | if reference_audio_path is None:
109 | style_id = self.style2id[style]
110 | style_vector = self.get_style_vector(style_id, style_weight)
111 | else:
112 | style_vector = self.get_style_vector_from_audio(
113 | reference_audio_path, style_weight
114 | )
115 | if not line_split:
116 | with torch.no_grad():
117 | audio = infer(
118 | text=text,
119 | sdp_ratio=sdp_ratio,
120 | noise_scale=noise,
121 | noise_scale_w=noisew,
122 | length_scale=length,
123 | sid=sid,
124 | language=language,
125 | hps=self.hps,
126 | net_g=self.net_g,
127 | device=self.device,
128 | assist_text=assist_text,
129 | assist_text_weight=assist_text_weight,
130 | style_vec=style_vector,
131 | given_tone=given_tone,
132 | )
133 | else:
134 | texts = text.split("\n")
135 | texts = [t for t in texts if t != ""]
136 | audios = []
137 | with torch.no_grad():
138 | for i, t in enumerate(texts):
139 | audios.append(
140 | infer(
141 | text=t,
142 | sdp_ratio=sdp_ratio,
143 | noise_scale=noise,
144 | noise_scale_w=noisew,
145 | length_scale=length,
146 | sid=sid,
147 | language=language,
148 | hps=self.hps,
149 | net_g=self.net_g,
150 | device=self.device,
151 | assist_text=assist_text,
152 | assist_text_weight=assist_text_weight,
153 | style_vec=style_vector,
154 | )
155 | )
156 | if i != len(texts) - 1:
157 |                         audios.append(np.zeros(int(44100 * split_interval)))  # NOTE: silence gap assumes 44.1 kHz output audio
158 | audio = np.concatenate(audios)
159 | with warnings.catch_warnings():
160 | warnings.simplefilter("ignore")
161 | audio = convert_to_16_bit_wav(audio)
162 | logger.info("Audio data generated successfully")
163 | return (self.hps.data.sampling_rate, audio)
164 |
165 |
166 | class ModelHolder:
167 | def __init__(self, root_dir: str, device: str):
168 | self.root_dir: str = root_dir
169 | self.device: str = device
170 | self.model_files_dict: Dict[str, List[str]] = {}
171 | self.current_model: Optional[Model] = None
172 | self.model_names: List[str] = []
173 | self.models: List[Model] = []
174 | self.refresh()
175 |
176 | def refresh(self):
177 | self.model_files_dict = {}
178 | self.model_names = []
179 | self.current_model = None
180 | model_dirs = [
181 | d
182 | for d in os.listdir(self.root_dir)
183 | if os.path.isdir(os.path.join(self.root_dir, d))
184 | ]
185 | for model_name in model_dirs:
186 | model_dir = os.path.join(self.root_dir, model_name)
187 | model_files = [
188 | os.path.join(model_dir, f)
189 | for f in os.listdir(model_dir)
190 | if f.endswith(".pth") or f.endswith(".pt") or f.endswith(".safetensors")
191 | ]
192 | if len(model_files) == 0:
193 | logger.warning(
194 | f"No model files found in {self.root_dir}/{model_name}, so skip it"
195 | )
196 | continue
197 | self.model_files_dict[model_name] = model_files
198 | self.model_names.append(model_name)
199 |
200 | def load_model_gr(
201 | self, model_name: str, model_path: str
202 | ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]:
203 | if model_name not in self.model_files_dict:
204 | raise ValueError(f"Model `{model_name}` is not found")
205 | if model_path not in self.model_files_dict[model_name]:
206 | raise ValueError(f"Model file `{model_path}` is not found")
207 | self.current_model = Model(
208 | model_path=model_path,
209 | config_path=os.path.join(self.root_dir, model_name, "config.json"),
210 | style_vec_path=os.path.join(self.root_dir, model_name, "style_vectors.npy"),
211 | device=self.device,
212 | )
213 | speakers = list(self.current_model.spk2id.keys())
214 | styles = list(self.current_model.style2id.keys())
215 | return (
216 | gr.Dropdown(choices=styles, value=styles[0]),
217 | gr.Button(interactive=True, value="音声合成"),
218 | gr.Dropdown(choices=speakers, value=speakers[0]),
219 | )
220 |
221 | def update_model_files_gr(self, model_name: str) -> gr.Dropdown:
222 | model_files = self.model_files_dict[model_name]
223 | return gr.Dropdown(choices=model_files, value=model_files[0])
224 |
225 | def update_model_names_gr(self) -> tuple[gr.Dropdown, gr.Dropdown, gr.Button]:
226 | self.refresh()
227 | initial_model_name = self.model_names[0]
228 | initial_model_files = self.model_files_dict[initial_model_name]
229 | return (
230 | gr.Dropdown(choices=self.model_names, value=initial_model_name),
231 | gr.Dropdown(choices=initial_model_files, value=initial_model_files[0]),
232 | gr.Button(interactive=False), # For tts_button
233 | )
234 |
--------------------------------------------------------------------------------
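The style weighting in Model.get_style_vector above is a plain linear interpolation between the mean vector (style_vectors[0]) and the selected style vector; a tiny numeric illustration (the vectors are made up):

import numpy as np

mean = np.array([0.0, 0.0])    # style_vectors[0]
style = np.array([1.0, 2.0])   # style_vectors[style_id]
for weight in (0.0, 0.5, 1.0, 2.0):
    print(weight, mean + (style - mean) * weight)
# weight 0.0 returns the mean, 1.0 the style itself, and values above 1.0 extrapolate past it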
/bcut_asr/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | from os import PathLike
4 | from pathlib import Path
5 | from typing import Literal, Optional, Union
6 | import requests
7 | import sys
8 | import ffmpeg
9 | from .orm import (ResourceCompleteRspSchema, ResourceCreateRspSchema,
10 | ResultRspSchema, ResultStateEnum, TaskCreateRspSchema)
11 |
12 |
13 |
14 | __version__ = '0.0.2'
15 |
16 | API_REQ_UPLOAD = 'https://member.bilibili.com/x/bcut/rubick-interface/resource/create' # request upload
17 | API_COMMIT_UPLOAD = 'https://member.bilibili.com/x/bcut/rubick-interface/resource/create/complete' # commit upload
18 | API_CREATE_TASK = 'https://member.bilibili.com/x/bcut/rubick-interface/task' # create task
19 | API_QUERY_RESULT = 'https://member.bilibili.com/x/bcut/rubick-interface/task/result' # query result
20 |
21 | SUPPORT_SOUND_FORMAT = Literal['flac', 'aac', 'm4a', 'mp3', 'wav']
22 |
23 | INFILE_FMT = ('flac', 'aac', 'm4a', 'mp3', 'wav')
24 | OUTFILE_FMT = ('srt', 'json', 'lrc', 'txt')
25 |
26 |
27 | def ffmpeg_render(media_file: str) -> bytes:
28 |     'Extract the audio track from a video and transcode it to AAC'
29 | out, err = (ffmpeg
30 | .input(media_file, v='warning')
31 | .output('pipe:', ac=1, format='adts')
32 | .run(capture_stdout=True)
33 | )
34 | return out
35 |
36 |
37 | def run_everywhere(argg):
38 | logging.basicConfig(format='%(asctime)s - [%(levelname)s] %(message)s', level=logging.INFO)
39 |     # handle the input file
40 | infile = argg.input
41 | infile_name = infile.name
42 | if infile_name == '':
43 | logging.error('输入文件错误')
44 | sys.exit(-1)
45 | suffix = infile_name.rsplit('.', 1)[-1]
46 | if suffix in INFILE_FMT:
47 | infile_fmt = suffix
48 | infile_data = infile.read()
49 | else:
50 |         # use ffmpeg to extract the audio track from the video
51 | logging.info('非标准音频文件, 尝试调用ffmpeg转码')
52 | try:
53 | infile_data = ffmpeg_render(infile_name)
54 | except ffmpeg.Error:
55 | logging.error('ffmpeg转码失败')
56 | sys.exit(-1)
57 | else:
58 | logging.info('ffmpeg转码完成')
59 | infile_fmt = 'aac'
60 |
61 |     # handle the output file
62 | outfile = argg.output
63 | if outfile is None:
64 |         # no output file given: default to the input file name; the format can be set with -t, default srt
65 | if argg.format is not None:
66 | outfile_fmt = argg.format
67 | else:
68 | outfile_fmt = 'srt'
69 | else:
70 |         # an output file was given
71 | outfile_name = outfile.name
72 | if outfile.name == '':
73 |             # stdout case: the format can be set with -t, default srt
74 | if argg.format is not None:
75 | outfile_fmt = argg.format
76 | else:
77 | outfile_fmt = 'srt'
78 | else:
79 | suffix = outfile_name.rsplit('.', 1)[-1]
80 | if suffix in OUTFILE_FMT:
81 | outfile_fmt = suffix
82 | else:
83 | logging.error('输出格式错误')
84 | sys.exit(-1)
85 |
86 |     # run the conversion flow
87 | asr = BcutASR()
88 | asr.set_data(raw_data=infile_data, data_fmt=infile_fmt)
89 | try:
90 |         # upload the file
91 |         asr.upload()
92 |         # create the task
93 |         task_id = asr.create_task()
94 |         while True:
95 |             # poll the task status
96 | task_resp = asr.result()
97 | # match task_resp.state:
98 | # case ResultStateEnum.STOP:
99 | # logging.info(f'等待识别开始')
100 | # case ResultStateEnum.RUNING:
101 | # logging.info(f'识别中-{task_resp.remark}')
102 | # case ResultStateEnum.ERROR:
103 | # logging.error(f'识别失败-{task_resp.remark}')
104 | # sys.exit(-1)
105 | # case ResultStateEnum.COMPLETE:
106 | # outfile_name = f"{infile_name.rsplit('.', 1)[-2]}.{outfile_fmt}"
107 | # outfile = open(outfile_name, 'w', encoding='utf8')
108 | # logging.info(f'识别成功')
109 | # # 识别成功, 回读字幕数据
110 | # result = task_resp.parse()
111 | # break
112 |
113 | if task_resp.state == ResultStateEnum.STOP:
114 | logging.info(f'等待识别开始')
115 | elif task_resp.state == ResultStateEnum.RUNING:
116 | logging.info(f'识别中-{task_resp.remark}')
117 | elif task_resp.state == ResultStateEnum.ERROR:
118 | logging.error(f'识别失败-{task_resp.remark}')
119 | sys.exit(-1)
120 | elif task_resp.state == ResultStateEnum.COMPLETE:
121 | outfile_name = f"{infile_name.rsplit('.', 1)[-2]}.{outfile_fmt}"
122 | outfile = open(outfile_name, 'w', encoding='utf8')
123 | logging.info(f'识别成功')
124 |                 # recognition succeeded; read back the subtitle data
125 | result = task_resp.parse()
126 | break
127 |
128 |
129 |             time.sleep(1.0)  # poll the task status roughly once per second
130 | if not result.has_data():
131 | logging.error('未识别到语音')
132 | sys.exit(-1)
133 | # match outfile_fmt:
134 | # case 'srt':
135 | # outfile.write(result.to_srt())
136 | # case 'lrc':
137 | # outfile.write(result.to_lrc())
138 | # case 'json':
139 | # outfile.write(result.json())
140 | # case 'txt':
141 | # outfile.write(result.to_txt())
142 | if outfile_fmt == 'srt':
143 | outfile.write(result.to_srt())
144 | elif outfile_fmt == 'lrc':
145 | outfile.write(result.to_lrc())
146 | elif outfile_fmt == 'json':
147 | outfile.write(result.json())
148 | elif outfile_fmt == 'txt':
149 | outfile.write(result.to_txt())
150 |
151 | logging.info(f'转换成功: {outfile_name}')
152 | except APIError as err:
153 | logging.error(f'接口错误: {err.__str__()}')
154 | sys.exit(-1)
155 |
156 | class APIError(Exception):
157 |     'API call error'
158 | def __init__(self, code, msg) -> None:
159 | self.code = code
160 | self.msg = msg
161 | super().__init__()
162 | def __str__(self) -> str:
163 | return f'{self.code}:{self.msg}'
164 |
165 | class BcutASR:
166 |     'Bcut (必剪) speech recognition API client'
167 | session: requests.Session
168 | sound_name: str
169 | sound_bin: bytes
170 | sound_fmt: SUPPORT_SOUND_FORMAT
171 | __in_boss_key: str
172 | __resource_id: str
173 | __upload_id: str
174 | __upload_urls: list[str]
175 | __per_size: int
176 | __clips: int
177 | __etags: list[str]
178 | __download_url: str
179 | task_id: str
180 |
181 | # def __init__(self, file: Optional[str | PathLike] = None) -> None:
182 | def __init__(self, file: Optional[Union[str, PathLike]] = None) -> None:
183 | self.session = requests.Session()
184 | self.task_id = None
185 | self.__etags = []
186 | if file:
187 | self.set_data(file)
188 |
189 | def set_data(self,
190 | # file: Optional[str | PathLike] = None,
191 | file: Optional[Union[str, PathLike]] = None,
192 | raw_data: Optional[bytes] = None,
193 | data_fmt: Optional[SUPPORT_SOUND_FORMAT] = None
194 | ) -> None:
195 |         'Set the data to be recognized'
196 | if file:
197 | if not isinstance(file, (str, PathLike)):
198 |                 raise TypeError('unknown file pointer')
199 |             # file path input
200 | file = Path(file)
201 | self.sound_bin = open(file, 'rb').read()
202 | suffix = data_fmt or file.suffix[1:]
203 | self.sound_name = file.name
204 | elif raw_data:
205 |             # raw bytes input
206 | self.sound_bin = raw_data
207 | suffix = data_fmt
208 | self.sound_name = f'{int(time.time())}.{suffix}'
209 | else:
210 |             raise ValueError('no data set')
211 |         if suffix not in SUPPORT_SOUND_FORMAT.__args__:
212 |             raise TypeError('format is not supported')
213 | self.sound_fmt = suffix
214 | logging.info(f'加载文件成功: {self.sound_name}')
215 |
216 | def upload(self) -> None:
217 |         'Request an upload and push the audio data'
218 | if not self.sound_bin or not self.sound_fmt:
219 |             raise ValueError('no data set')
220 | resp = self.session.post(API_REQ_UPLOAD, data={
221 | 'type': 2,
222 | 'name': self.sound_name,
223 | 'size': len(self.sound_bin),
224 | 'resource_file_type': self.sound_fmt,
225 | 'model_id': 7
226 | })
227 | resp.raise_for_status()
228 | resp = resp.json()
229 | code = resp['code']
230 | if code:
231 | raise APIError(code, resp['message'])
232 | resp_data = ResourceCreateRspSchema.parse_obj(resp['data'])
233 | self.__in_boss_key = resp_data.in_boss_key
234 | self.__resource_id = resp_data.resource_id
235 | self.__upload_id = resp_data.upload_id
236 | self.__upload_urls = resp_data.upload_urls
237 | self.__per_size = resp_data.per_size
238 | self.__clips = len(resp_data.upload_urls)
239 | logging.info(f'申请上传成功, 总计大小{resp_data.size // 1024}KB, {self.__clips}分片, 分片大小{resp_data.per_size // 1024}KB: {self.__in_boss_key}')
240 | self.__upload_part()
241 | self.__commit_upload()
242 |
243 | def __upload_part(self) -> None:
244 |         'Upload the audio data in chunks'
245 | for clip in range(self.__clips):
246 | start_range = clip * self.__per_size
247 | end_range = (clip + 1) * self.__per_size
248 | logging.info(f'开始上传分片{clip}: {start_range}-{end_range}')
249 | resp = self.session.put(self.__upload_urls[clip],
250 | data=self.sound_bin[start_range:end_range],
251 | )
252 | resp.raise_for_status()
253 | etag = resp.headers.get('Etag')
254 | self.__etags.append(etag)
255 | logging.info(f'分片{clip}上传成功: {etag}')
256 |
257 | def __commit_upload(self) -> None:
258 |         'Commit the uploaded data'
259 | resp = self.session.post(API_COMMIT_UPLOAD, data={
260 | 'in_boss_key': self.__in_boss_key,
261 | 'resource_id': self.__resource_id,
262 | 'etags': ','.join(self.__etags),
263 | 'upload_id': self.__upload_id,
264 | 'model_id': 7
265 | })
266 | resp.raise_for_status()
267 | resp = resp.json()
268 | code = resp['code']
269 | if code:
270 | raise APIError(code, resp['message'])
271 | resp_data = ResourceCompleteRspSchema.parse_obj(resp['data'])
272 | self.__download_url = resp_data.download_url
273 | logging.info(f'提交成功')
274 |
275 | def create_task(self) -> str:
276 |         'Create the transcription task'
277 | resp = self.session.post(API_CREATE_TASK, json={
278 | 'resource': self.__download_url,
279 | 'model_id': '7'
280 | })
281 | resp.raise_for_status()
282 | resp = resp.json()
283 | code = resp['code']
284 | if code:
285 | raise APIError(code, resp['message'])
286 | resp_data = TaskCreateRspSchema.parse_obj(resp['data'])
287 | self.task_id = resp_data.task_id
288 | logging.info(f'任务已创建: {self.task_id}')
289 | return self.task_id
290 |
291 | def result(self, task_id: Optional[str] = None) -> ResultRspSchema:
292 |         'Query the transcription result'
293 | resp = self.session.get(API_QUERY_RESULT, params={
294 | 'model_id': 7,
295 | 'task_id': task_id or self.task_id
296 | })
297 | resp.raise_for_status()
298 | resp = resp.json()
299 | code = resp['code']
300 | if code:
301 | raise APIError(code, resp['message'])
302 | return ResultRspSchema.parse_obj(resp['data'])
303 |
--------------------------------------------------------------------------------
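A minimal sketch of driving BcutASR directly, following the same upload -> create_task -> poll flow as run_everywhere above (the wav path is illustrative):

import time
from bcut_asr import BcutASR
from bcut_asr.orm import ResultStateEnum

asr = BcutASR('./wavs/sample_12345.wav')
asr.upload()                       # request the upload, push the chunks, commit
asr.create_task()
while True:
    resp = asr.result()
    if resp.state == ResultStateEnum.COMPLETE:
        break
    if resp.state == ResultStateEnum.ERROR:
        raise RuntimeError(resp.remark)
    time.sleep(1.0)                # short poll interval
subtitles = resp.parse()           # parsed subtitle object, as used in run_everywhere
print(subtitles.to_srt())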
/videoclipper.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import copy
3 | import librosa
4 | import logging
5 | import argparse
6 | import numpy as np
7 | import soundfile as sf
8 | import moviepy.editor as mpy
9 | # from modelscope.pipelines import pipeline
10 | # from modelscope.utils.constant import Tasks
11 | from subtitle_utils import generate_srt, generate_srt_clip, distribute_spk
12 | from trans_utils import pre_proc, proc, write_state, load_state, proc_spk, generate_vad_data
13 | from argparse_tools import ArgumentParser, get_commandline_args
14 |
15 | from moviepy.editor import *
16 | from moviepy.video.tools.subtitles import SubtitlesClip
17 |
18 |
19 | class VideoClipper():
20 | def __init__(self, asr_pipeline, sd_pipeline=None):
21 | logging.warning("Initializing VideoClipper.")
22 | self.asr_pipeline = asr_pipeline
23 | self.sd_pipeline = sd_pipeline
24 |
25 | def recog(self, audio_input, sd_switch='no', state=None):
26 | if state is None:
27 | state = {}
28 | sr, data = audio_input
29 | assert sr == 16000, "16kHz sample rate required, {} given.".format(sr)
30 | if len(data.shape) == 2: # multi-channel wav input
31 |             logging.warning("Input wav shape: {}, only first channel reserved.".format(data.shape))
32 | data = data[:,0]
33 | state['audio_input'] = (sr, data)
34 | data = data.astype(np.float64)
35 | rec_result = self.asr_pipeline(audio_in=data)
36 | if sd_switch == 'yes':
37 | vad_data = generate_vad_data(data.astype(np.float32), rec_result['sentences'], sr)
38 | sd_result = self.sd_pipeline(audio=vad_data, batch_size=1)
39 | rec_result['sd_sentences'] = distribute_spk(rec_result['sentences'], sd_result['text'])
40 | res_srt = generate_srt(rec_result['sd_sentences'])
41 | state['sd_sentences'] = rec_result['sd_sentences']
42 | else:
43 | res_srt = generate_srt(rec_result['sentences'])
44 | state['recog_res_raw'] = rec_result['text_postprocessed']
45 | state['timestamp'] = rec_result['time_stamp']
46 | state['sentences'] = rec_result['sentences']
47 | res_text = rec_result['text']
48 | return res_text, res_srt, state
49 |
50 | def clip(self, dest_text, start_ost, end_ost, state, dest_spk=None):
51 | # get from state
52 | audio_input = state['audio_input']
53 | recog_res_raw = state['recog_res_raw']
54 | timestamp = state['timestamp']
55 | sentences = state['sentences']
56 | sr, data = audio_input
57 | data = data.astype(np.float64)
58 |
59 | all_ts = []
60 | if dest_spk is None or dest_spk == '' or 'sd_sentences' not in state:
61 | for _dest_text in dest_text.split('#'):
62 | _dest_text = pre_proc(_dest_text)
63 | ts = proc(recog_res_raw, timestamp, _dest_text)
64 | for _ts in ts: all_ts.append(_ts)
65 | else:
66 | for _dest_spk in dest_spk.split('#'):
67 | ts = proc_spk(_dest_spk, state['sd_sentences'])
68 | for _ts in ts: all_ts.append(_ts)
69 | ts = all_ts
70 | # ts.sort()
71 | srt_index = 0
72 | clip_srt = ""
73 | if len(ts):
74 | start, end = ts[0]
75 | start = min(max(0, start+start_ost*16), len(data))
76 | end = min(max(0, end+end_ost*16), len(data))
77 | res_audio = data[start:end]
78 | start_end_info = "from {} to {}".format(start/16000, end/16000)
79 | srt_clip, _, srt_index = generate_srt_clip(sentences, start/16000.0, end/16000.0, begin_index=srt_index)
80 | clip_srt += srt_clip
81 | for _ts in ts[1:]: # multiple sentence input or multiple output matched
82 | start, end = _ts
83 | start = min(max(0, start+start_ost*16), len(data))
84 | end = min(max(0, end+end_ost*16), len(data))
85 |                 start_end_info += ", from {} to {}".format(start/16000, end/16000)
86 |                 res_audio = np.concatenate([res_audio, data[start:end]], -1)  # start/end already include the offsets applied above
87 | srt_clip, _, srt_index = generate_srt_clip(sentences, start/16000.0, end/16000.0, begin_index=srt_index-1)
88 | clip_srt += srt_clip
89 | if len(ts):
90 | message = "{} periods found in the speech: ".format(len(ts)) + start_end_info
91 | else:
92 | message = "No period found in the speech, return raw speech. You may check the recognition result and try other destination text."
93 | res_audio = data
94 | return (sr, res_audio), message, clip_srt
95 |
96 | def video_recog(self, vedio_filename, sd_switch='no'):
97 | vedio_filename = vedio_filename
98 | clip_video_file = vedio_filename[:-4] + '_clip.mp4'
99 | video = mpy.VideoFileClip(vedio_filename)
100 | audio_file = vedio_filename[:-3] + 'wav'
101 | video.audio.write_audiofile(audio_file)
102 | wav = librosa.load(audio_file, sr=16000)[0]
103 | state = {
104 | 'vedio_filename': vedio_filename,
105 | 'clip_video_file': clip_video_file,
106 | 'video': video,
107 | }
108 | # res_text, res_srt = self.recog((16000, wav), state)
109 | return self.recog((16000, wav), sd_switch, state)
110 |
111 | def video_clip(self, dest_text, start_ost, end_ost, state, font_size=32, font_color='white', add_sub=False, dest_spk=None):
112 | # get from state
113 | recog_res_raw = state['recog_res_raw']
114 | timestamp = state['timestamp']
115 | sentences = state['sentences']
116 | video = state['video']
117 | clip_video_file = state['clip_video_file']
118 | vedio_filename = state['vedio_filename']
119 |
120 | all_ts = []
121 | srt_index = 0
122 | if dest_spk is None or dest_spk == '' or 'sd_sentences' not in state:
123 | for _dest_text in dest_text.split('#'):
124 | _dest_text = pre_proc(_dest_text)
125 | ts = proc(recog_res_raw, timestamp, _dest_text)
126 | for _ts in ts: all_ts.append(_ts)
127 | else:
128 | for _dest_spk in dest_spk.split('#'):
129 | ts = proc_spk(_dest_spk, state['sd_sentences'])
130 | for _ts in ts: all_ts.append(_ts)
131 | time_acc_ost = 0.0
132 | ts = all_ts
133 | # ts.sort()
134 | clip_srt = ""
135 | if len(ts):
136 | start, end = ts[0][0] / 16000, ts[0][1] / 16000
137 | srt_clip, subs, srt_index = generate_srt_clip(sentences, start, end, begin_index=srt_index, time_acc_ost=time_acc_ost)
138 | start, end = start+start_ost/1000.0, end+end_ost/1000.0
139 | video_clip = video.subclip(start, end)
140 | start_end_info = "from {} to {}".format(start, end)
141 | clip_srt += srt_clip
142 | if add_sub:
143 | generator = lambda txt: TextClip(txt, font='./font/STHeitiMedium.ttc', fontsize=font_size, color=font_color)
144 | subtitles = SubtitlesClip(subs, generator)
145 | video_clip = CompositeVideoClip([video_clip, subtitles.set_pos(('center','bottom'))])
146 | concate_clip = [video_clip]
147 | time_acc_ost += end+end_ost/1000.0 - (start+start_ost/1000.0)
148 | for _ts in ts[1:]:
149 | start, end = _ts[0] / 16000, _ts[1] / 16000
150 | srt_clip, subs, srt_index = generate_srt_clip(sentences, start, end, begin_index=srt_index-1, time_acc_ost=time_acc_ost)
151 | chi_subs = []
152 | sub_starts = subs[0][0][0]
153 | for sub in subs:
154 | chi_subs.append(((sub[0][0]-sub_starts, sub[0][1]-sub_starts), sub[1]))
155 | start, end = start+start_ost/1000.0, end+end_ost/1000.0
156 | _video_clip = video.subclip(start, end)
157 | start_end_info += ", from {} to {}".format(start, end)
158 | clip_srt += srt_clip
159 | if add_sub:
160 | generator = lambda txt: TextClip(txt, font='./font/STHeitiMedium.ttc', fontsize=font_size, color=font_color)
161 | subtitles = SubtitlesClip(chi_subs, generator)
162 | _video_clip = CompositeVideoClip([_video_clip, subtitles.set_pos(('center','bottom'))])
163 | # _video_clip.write_videofile("debug.mp4", audio_codec="aac")
164 | concate_clip.append(copy.copy(_video_clip))
165 | time_acc_ost += end+end_ost/1000.0 - (start+start_ost/1000.0)
166 | message = "{} periods found in the audio: ".format(len(ts)) + start_end_info
167 | logging.warning("Concating...")
168 | if len(concate_clip) > 1:
169 | video_clip = concatenate_videoclips(concate_clip)
170 | video_clip.write_videofile(clip_video_file, audio_codec="aac")
171 | else:
172 | clip_video_file = vedio_filename
173 | message = "No period found in the audio, return raw speech. You may check the recognition result and try other destination text."
174 | srt_clip = ''
175 | return clip_video_file, message, clip_srt
176 |
177 |
178 | def get_parser():
179 | parser = ArgumentParser(
180 | description="ClipVideo Argument",
181 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
182 | )
183 | parser.add_argument(
184 | "--stage",
185 | type=int,
186 | choices=(1, 2),
187 |         help="Stage, 1 for recognizing and 2 for clipping",
188 | required=True
189 | )
190 | parser.add_argument(
191 | "--file",
192 | type=str,
193 | default=None,
194 | help="Input file path",
195 | required=True
196 | )
197 | parser.add_argument(
198 | "--sd_switch",
199 | type=str,
200 | choices=("no", "yes"),
201 | default="no",
202 |         help="Turn on speaker diarization or not",
203 | )
204 | parser.add_argument(
205 | "--output_dir",
206 | type=str,
207 | default='./output',
208 | help="Output files path",
209 | )
210 | parser.add_argument(
211 | "--dest_text",
212 | type=str,
213 | default=None,
214 | help="Destination text string for clipping",
215 | )
216 | parser.add_argument(
217 | "--dest_spk",
218 | type=str,
219 | default=None,
220 | help="Destination spk id for clipping",
221 | )
222 | parser.add_argument(
223 | "--start_ost",
224 | type=int,
225 | default=0,
226 | help="Offset time in ms at beginning for clipping"
227 | )
228 | parser.add_argument(
229 | "--end_ost",
230 | type=int,
231 | default=0,
232 | help="Offset time in ms at ending for clipping"
233 | )
234 | parser.add_argument(
235 | "--output_file",
236 | type=str,
237 | default=None,
238 | help="Output file path"
239 | )
240 | return parser
241 |
242 |
243 | def runner(stage, file, sd_switch, output_dir, dest_text, dest_spk, start_ost, end_ost, output_file, config=None):
244 | audio_suffixs = ['wav']
245 | video_suffixs = ['mp4']
246 | if file[-3:] in audio_suffixs:
247 | mode = 'audio'
248 | elif file[-3:] in video_suffixs:
249 | mode = 'video'
250 | else:
251 |         raise ValueError("Unsupported file format: {}".format(file))  # fail fast; otherwise mode is undefined below
252 | while output_dir.endswith('/'):
253 | output_dir = output_dir[:-1]
254 | if stage == 1:
255 | from modelscope.pipelines import pipeline
256 | from modelscope.utils.constant import Tasks
257 | # initialize modelscope asr pipeline
258 | logging.warning("Initializing modelscope asr pipeline.")
259 | inference_pipeline = pipeline(
260 | task=Tasks.auto_speech_recognition,
261 | model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
262 | vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
263 | punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
264 | output_dir=output_dir,
265 | )
266 | sd_pipeline = pipeline(
267 | task='speaker-diarization',
268 | model='damo/speech_campplus_speaker-diarization_common',
269 | model_revision='v1.0.0'
270 | )
271 | audio_clipper = VideoClipper(inference_pipeline, sd_pipeline)
272 |         if mode == 'audio':
273 |             logging.warning("Recognizing audio file: {}".format(file))
274 |             wav, sr = librosa.load(file, sr=16000)
275 |             res_text, res_srt, state = audio_clipper.recog((sr, wav), sd_switch)
276 |         if mode == 'video':
277 |             logging.warning("Recognizing video file: {}".format(file))
278 |             res_text, res_srt, state = audio_clipper.video_recog(file, sd_switch)
279 |         total_srt_file = output_dir + '/total.srt'
280 |         with open(total_srt_file, 'w') as fout:
281 |             fout.write(res_srt)
282 |         logging.warning("Write total subtitle to {}".format(total_srt_file))
283 |         write_state(output_dir, state)
284 |         logging.warning("Recognition succeeded. You can copy the text segment from below and use stage 2.")
285 |         print(res_text)
286 | if stage == 2:
287 | audio_clipper = VideoClipper(None)
288 | if mode == 'audio':
289 | state = load_state(output_dir)
290 | wav, sr = librosa.load(file, sr=16000)
291 | state['audio_input'] = (sr, wav)
292 | (sr, audio), message, srt_clip = audio_clipper.clip(dest_text, start_ost, end_ost, state, dest_spk=dest_spk)
293 | if output_file is None:
294 | output_file = output_dir + '/result.wav'
295 | clip_srt_file = output_file[:-3] + 'srt'
296 | logging.warning(message)
297 |             assert output_file.endswith('.wav'), "output_file must end with '.wav'"
298 |             sf.write(output_file, audio, 16000)
299 |             logging.warning("Save clipped wav file to {}".format(output_file))
300 |             with open(clip_srt_file, 'w') as fout:
301 |                 fout.write(srt_clip)
302 |             logging.warning("Write clipped subtitle to {}".format(clip_srt_file))
303 | if mode == 'video':
304 | state = load_state(output_dir)
305 | state['vedio_filename'] = file
306 | if output_file is None:
307 | state['clip_video_file'] = file[:-4] + '_clip.mp4'
308 | else:
309 | state['clip_video_file'] = output_file
310 | clip_srt_file = state['clip_video_file'][:-3] + 'srt'
311 | state['video'] = mpy.VideoFileClip(file)
312 | clip_video_file, message, srt_clip = audio_clipper.video_clip(dest_text, start_ost, end_ost, state, dest_spk=dest_spk)
313 |             logging.warning("Clipping Log: {}".format(message))
314 |             logging.warning("Save clipped mp4 file to {}".format(clip_video_file))
315 |             with open(clip_srt_file, 'w') as fout:
316 |                 fout.write(srt_clip)
317 |             logging.warning("Write clipped subtitle to {}".format(clip_srt_file))
318 |
319 |
320 | def main(cmd=None):
321 | print(get_commandline_args(), file=sys.stderr)
322 | parser = get_parser()
323 | args = parser.parse_args(cmd)
324 | kwargs = vars(args)
325 | runner(**kwargs)
326 |
327 |
328 | if __name__ == '__main__':
329 | main()
--------------------------------------------------------------------------------
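A minimal sketch of the two-stage flow exposed by get_parser/runner above, driven through main() (paths and text are illustrative; stage 1 initializes the modelscope pipelines):

from videoclipper import main

# Stage 1: recognize the input and write ./output/total.srt plus the saved state
main(['--stage', '1', '--file', './raw/sample.wav', '--output_dir', './output'])

# Stage 2: clip the segments matching the destination text from the same file
main(['--stage', '2', '--file', './raw/sample.wav', '--output_dir', './output',
      '--dest_text', 'some recognized sentence', '--output_file', './output/result.wav'])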