├── wenet_asr_server ├── wenet │ ├── cli │ │ ├── __init__.py │ │ ├── transcribe.py │ │ ├── paraformer_model.py │ │ └── hub.py │ ├── k2 │ │ └── __init__.py │ ├── dataset │ │ ├── __init__.py │ │ └── dataset.py │ ├── text │ │ ├── __init__.py │ │ ├── base_tokenizer.py │ │ ├── bpe_tokenizer.py │ │ ├── paraformer_tokenizer.py │ │ ├── hugging_face_tokenizer.py │ │ ├── tokenize_utils.py │ │ ├── char_tokenizer.py │ │ └── whisper_tokenizer.py │ ├── utils │ │ ├── __init__.py │ │ ├── config.py │ │ ├── file_utils.py │ │ ├── init_tokenizer.py │ │ ├── cmvn.py │ │ ├── class_utils.py │ │ ├── checkpoint.py │ │ ├── ctc_utils.py │ │ ├── executor.py │ │ └── init_model.py │ ├── whisper │ │ ├── __init__.py │ │ └── whisper.py │ ├── branchformer │ │ └── __init__.py │ ├── paraformer │ │ ├── __init__.py │ │ ├── embedding.py │ │ └── subsampling.py │ ├── squeezeformer │ │ ├── __init__.py │ │ ├── conv2d.py │ │ ├── positionwise_feed_forward.py │ │ └── encoder_layer.py │ ├── transducer │ │ ├── __init__.py │ │ ├── search │ │ │ ├── greedy_search.py │ │ │ └── prefix_beam_search.py │ │ └── joint.py │ ├── transformer │ │ ├── __init__.py │ │ ├── norm.py │ │ ├── swish.py │ │ ├── cmvn.py │ │ ├── ctc.py │ │ ├── label_smoothing_loss.py │ │ ├── convolution.py │ │ ├── positionwise_feed_forward.py │ │ └── decoder_layer.py │ ├── efficient_conformer │ │ ├── __init__.py │ │ ├── subsampling.py │ │ └── convolution.py │ ├── __init__.py │ ├── README.md │ ├── bin │ │ ├── export_jit.py │ │ ├── export_ipex.py │ │ ├── average_model.py │ │ └── train.py │ └── ssl │ │ ├── wav2vec2 │ │ └── quantizer.py │ │ └── bestrq │ │ └── mask.py ├── lib │ ├── audio_read.py │ └── wenet_asr_pb2.py ├── README.md └── requirements.txt ├── wenet_asr_client ├── requirements.txt ├── example_single.py ├── cer.py ├── lib │ ├── utils.py │ └── wenet_asr_pb2.py ├── transcribe_to_txt.py ├── cal_cer.py ├── README.md ├── example_streaming.py └── recognizer.py ├── .gitignore ├── README.md └── protos └── wenet_asr.proto /wenet_asr_server/wenet/cli/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/k2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/text/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/whisper/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/branchformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/paraformer/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/squeezeformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/transducer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wenet_asr_client/requirements.txt: -------------------------------------------------------------------------------- 1 | grpcio 2 | grpcio-tools -------------------------------------------------------------------------------- /wenet_asr_server/wenet/efficient_conformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/__init__.py: -------------------------------------------------------------------------------- 1 | from wenet.cli.model import load_model # noqa 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | wenet_asr_client/audio/ 3 | wenet_asr_server/models/ 4 | log/ 5 | transcript* 6 | test*.py 7 | cer*.csv 8 | baidu_api 9 | user_dict.txt 10 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/paraformer/embedding.py: -------------------------------------------------------------------------------- 1 | from wenet.transformer.embedding import WhisperPositionalEncoding 2 | 3 | 4 | class ParaformerPositinoalEncoding(WhisperPositionalEncoding): 5 | """ Sinusoids position encoding used in paraformer.encoder 6 | """ 7 | 8 | def __init__(self, 9 | depth: int, 10 | d_model: int, 11 | dropout_rate: float = 0.1, 12 | max_len: int = 1500): 13 | super().__init__(depth, dropout_rate, max_len) 14 | self.xscale = d_model**0.5 15 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/transformer/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class RMSNorm(torch.nn.Module): 5 | """ https://arxiv.org/pdf/1910.07467.pdf 6 | """ 7 | 8 | def __init__( 9 | self, 10 | dim: int, 11 | eps: float = 1e-6, 12 | ): 13 | super().__init__() 14 | self.eps = eps 15 | self.weight = torch.nn.Parameter(torch.ones(dim)) 16 | 17 | def _norm(self, x): 18 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 19 | 20 | def forward(self, x): 21 | x = self._norm(x.float()).type_as(x) 22 | return x * self.weight 23 | -------------------------------------------------------------------------------- /wenet_asr_client/example_single.py: -------------------------------------------------------------------------------- 1 | import wave 2 | 3 | from recognizer import Recognizer 4 | 5 | 6 | IP_PORT = "localhost:50051" 7 | AUDIO_FILE = 'audio/zh.wav' 8 | 9 | 10 | def nonstreaming_test(): 11 | # 将待识别语音读取为bytes类型 12 | with wave.open(AUDIO_FILE, 'rb') as f: 13 | 
samprate = f.getframerate() 14 | sampwidth = f.getsampwidth() 15 | data = f.readframes(f.getnframes()) 16 | 17 | # Initialize a Recognizer object 18 | recognizer = Recognizer(sample_rate=samprate, 19 | bit_depth=8 * sampwidth) 20 | 21 | # Connect to the server 22 | print('connecting...') 23 | recognizer.connect(IP_PORT) 24 | 25 | # Request a single (non-streaming) recognition 26 | print('recognizing...') 27 | result = recognizer.recognize(data, punctuation=True) 28 | 29 | # Print the recognition result 30 | print(f'result:{result}') 31 | 32 | recognizer.disconnect() 33 | 34 | 35 | if __name__ == '__main__': 36 | nonstreaming_test() 37 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/transformer/swish.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) 2 | # 2020 Northwestern Polytechnical University (Pengcheng Guo) 3 | # 2020 Mobvoi Inc (Binbin Zhang) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Swish() activation function for Conformer.""" 17 | 18 | import torch 19 | 20 | 21 | class Swish(torch.nn.Module): 22 | """Construct a Swish object.""" 23 | 24 | def forward(self, x: torch.Tensor) -> torch.Tensor: 25 | """Return Swish activation function.""" 26 | return x * torch.sigmoid(x) 27 | -------------------------------------------------------------------------------- /wenet_asr_client/cer.py: -------------------------------------------------------------------------------- 1 | def cer(s1, s2): 2 | """ 3 | Compute the character error rate (CER) of a Chinese speech recognition result. 4 | 5 | Args: 6 | s1: recognized (hypothesis) string 7 | s2: manually annotated reference string 8 | 9 | Returns: 10 | the CER value and the number of matching characters 11 | """ 12 | # Convert the strings to character lists 13 | s1_chars = list(s1) 14 | s2_chars = list(s2) 15 | 16 | # Initialize the 2-D table that stores the Levenshtein distance 17 | dp = [[0] * (len(s2_chars) + 1) for _ in range(len(s1_chars) + 1)] 18 | 19 | # Initialize the first row and the first column 20 | for i in range(len(s1_chars) + 1): 21 | dp[i][0] = i 22 | for j in range(len(s2_chars) + 1): 23 | dp[0][j] = j 24 | 25 | # Compute the Levenshtein distance by dynamic programming 26 | correct = 0 # number of matching character comparisons 27 | for i in range(1, len(s1_chars) + 1): 28 | for j in range(1, len(s2_chars) + 1): 29 | cost = 0 if s1_chars[i - 1] == s2_chars[j - 1] else 1 30 | dp[i][j] = min(dp[i - 1][j] + 1, # deletion 31 | dp[i][j - 1] + 1, # insertion 32 | dp[i - 1][j - 1] + cost) # substitution 33 | if cost == 0: 34 | correct += 1 35 | 36 | # Compute the character error rate 37 | cer = dp[len(s1_chars)][len(s2_chars)] / len(s2_chars) 38 | 39 | # Return the CER and the matching-character count 40 | return cer, correct 41 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WeNet ASR Demo 2 | 3 | The project implements an Automatic Speech Recognition (ASR) service based on the WeNet model. It consists of two subprojects: the WeNet ASR Server project and the WeNet ASR Client project. 4 | 5 | Here is the [official GitHub page of the WeNet project](https://github.com/wenet-e2e/wenet).
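For a quick orientation, the sketch below shows how the two subprojects are typically used together. It is only a condensed version of *wenet_asr_client/example_single.py*, and it assumes a server started with the entry point and defaults described in *wenet_asr_server/README.md* (listening on port 50051).

```python
# Server side, run from wenet_asr_server/ (see its README for all options):
#     python main.py --port 50051
# Client side, condensed from wenet_asr_client/example_single.py:
import wave

from recognizer import Recognizer

with wave.open('audio/zh.wav', 'rb') as f:  # sample file used by example_single.py
    rate, width = f.getframerate(), f.getsampwidth()
    data = f.readframes(f.getnframes())

recognizer = Recognizer(sample_rate=rate, bit_depth=8 * width)
recognizer.connect('localhost:50051')  # address:port of the running server
print(recognizer.recognize(data, punctuation=True))
recognizer.disconnect()
```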
6 | 7 | ## WeNet ASR Server Project 8 | 9 | The server project implements a remote API service for calling the WeNet ASR model, built on the gRPC framework. The project directory is *wenet_asr_server*. Please refer to *wenet_asr_server/README.md* for more information. 10 | 11 | ## WeNet ASR Client Project 12 | 13 | The client project provides a Python API implemented as a Python class. The API wraps the remote service provided by **WeNet ASR Server**: it supports both single-shot recognition calls and a real-time (streaming) recognition method. The project directory is *wenet_asr_client*. Please refer to *wenet_asr_client/README.md* for more information. 14 | 15 | ## gRPC Message Prototype 16 | 17 | The *protos* directory defines the gRPC services and messages using Protocol Buffers. The *.proto* file has already been compiled into Python code, and that code has been adjusted to fit the project, so you do not need to regenerate it yourself. 18 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/text/base_tokenizer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod, abstractproperty 2 | from typing import Dict, List, Tuple, Union 3 | 4 | T = Union[str, bytes] 5 | 6 | 7 | class BaseTokenizer(ABC): 8 | 9 | def tokenize(self, line: str) -> Tuple[List[T], List[int]]: 10 | tokens = self.text2tokens(line) 11 | ids = self.tokens2ids(tokens) 12 | return tokens, ids 13 | 14 | def detokenize(self, ids: List[int]) -> Tuple[str, List[T]]: 15 | tokens = self.ids2tokens(ids) 16 | text = self.tokens2text(tokens) 17 | return text, tokens 18 | 19 | @abstractmethod 20 | def text2tokens(self, line: str) -> List[T]: 21 | raise NotImplementedError("abstract method") 22 | 23 | @abstractmethod 24 | def tokens2text(self, tokens: List[T]) -> str: 25 | raise NotImplementedError("abstract method") 26 | 27 | @abstractmethod 28 | def tokens2ids(self, tokens: List[T]) -> List[int]: 29 | raise NotImplementedError("abstract method") 30 | 31 | @abstractmethod 32 | def ids2tokens(self, ids: List[int]) -> List[T]: 33 | raise NotImplementedError("abstract method") 34 | 35 | @abstractmethod 36 | def vocab_size(self) -> int: 37 | raise NotImplementedError("abstract method") 38 | 39 | @abstractproperty 40 | def symbol_table(self) -> Dict[T, int]: 41 | raise NotImplementedError("abstract method") 42 | -------------------------------------------------------------------------------- /protos/wenet_asr.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | package wenetasr; 3 | 4 | service WenetASR { 5 | rpc Test (TextMessage) returns (TextMessage); 6 | rpc GetServerID (Empty) returns (TextMessage); 7 | rpc ReloadModel (ReloadModelRequest) returns (Empty); 8 | rpc Recognize (RecognizeRequest) returns (RecognizeResponse); 9 | rpc StreamingRecognize (stream StreamingRecognizeRequest) returns (stream StreamingRecognizeResponse); 10 | rpc Punct (TextMessage) returns (TextMessage); 11 | } 12 | 13 | message Empty {} 14 | 15 | message TextMessage { 16 | string text = 1; 17 | } 18 | 19 | message ReloadModelRequest { 20 | bool asr = 1; 21 | string model = 2; 22 | string hotwords = 3; 23 | int32 context_score = 4; 24 | bool punctuation = 5; 25 | string punctuation_model = 6; 26 | } 27 | 28 | message RecognitionConfig { 29 | int32 sample_rate_hertz = 1; 30 | } 31 | 32 | message RecognizeRequest { 33 | RecognitionConfig config = 1; 34 | bytes
data = 2; 35 | } 36 | 37 | // message RecognizeRequest { 38 | // int32 sample_rate = 1; 39 | // bytes data = 2; 40 | // } 41 | 42 | message RecognizeResponse { 43 | string transcript = 1; 44 | } 45 | 46 | message StreamingRecognizeRequest { 47 | oneof streaming_request { 48 | // RecognitionConfig config = 1; 49 | bytes data = 2; 50 | } 51 | } 52 | 53 | message StreamingRecognizeResponse { 54 | string transcript = 1; 55 | } -------------------------------------------------------------------------------- /wenet_asr_client/lib/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def rm_cn_punc(text): 4 | # 匹配中文标点的正则表达式 5 | chinese_punctuation_pattern = "[\u3000\u3001-\u3011\u2014\u2018-\u201D\u2022\u2026\u2030\u25EF,!?:;()——]+" 6 | # 使用正则表达式替换中文标点为空格 7 | cleaned_text = re.sub(chinese_punctuation_pattern, "", text) 8 | return cleaned_text 9 | 10 | 11 | def is_approximately_equal(char1, char2): 12 | # 检查两个字符是否完全相同 13 | if char1 == char2: 14 | return True 15 | 16 | # 检查两个字符是否为同一字母的不同大小写形式 17 | if char1.lower() == char2.lower(): 18 | return True 19 | 20 | # 定义数字的不同表达形式 21 | digit_forms = { 22 | '0': ['0', '0', '〇', '零'], 23 | '1': ['1', '1', '一'], 24 | '2': ['2', '2', '二'], 25 | '3': ['3', '3', '三'], 26 | '4': ['4', '4', '四'], 27 | '5': ['5', '5', '五'], 28 | '6': ['6', '6', '六'], 29 | '7': ['7', '7', '七'], 30 | '8': ['8', '8', '八'], 31 | '9': ['9', '9', '九'] 32 | } 33 | 34 | # 检查两个字符是否为同一数字的不同表达形式 35 | for digit, forms in digit_forms.items(): 36 | if char1 in forms and char2 in forms: 37 | return True 38 | 39 | # 若以上条件都不满足,则认为两个字符不近似相同 40 | return False 41 | 42 | 43 | if __name__ == '__main__': 44 | text = "这是一个示例,包含中文标点————:;,。!?《》……“”()" 45 | print(rm_cn_punc(text)) 46 | 47 | print(is_approximately_equal('A', 'a')) # True 48 | print(is_approximately_equal('1', '一')) # True 49 | print(is_approximately_equal('b', 'c')) # False -------------------------------------------------------------------------------- /wenet_asr_server/wenet/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Shaoshang Qi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import copy 16 | 17 | 18 | def override_config(configs, override_list): 19 | new_configs = copy.deepcopy(configs) 20 | for item in override_list: 21 | arr = item.split() 22 | if len(arr) != 2: 23 | print(f"the overrive {item} format not correct, skip it") 24 | continue 25 | keys = arr[0].split('.') 26 | s_configs = new_configs 27 | for i, key in enumerate(keys): 28 | if key not in s_configs: 29 | print(f"the overrive {item} format not correct, skip it") 30 | if i == len(keys) - 1: 31 | param_type = type(s_configs[key]) 32 | if param_type != bool: 33 | s_configs[key] = param_type(arr[1]) 34 | else: 35 | s_configs[key] = arr[1] in ['true', 'True'] 36 | print(f"override {arr[0]} with {arr[1]}") 37 | else: 38 | s_configs = s_configs[key] 39 | return new_configs 40 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/transformer/cmvn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | 18 | class GlobalCMVN(torch.nn.Module): 19 | 20 | def __init__(self, 21 | mean: torch.Tensor, 22 | istd: torch.Tensor, 23 | norm_var: bool = True): 24 | """ 25 | Args: 26 | mean (torch.Tensor): mean stats 27 | istd (torch.Tensor): inverse std, std which is 1.0 / std 28 | """ 29 | super().__init__() 30 | assert mean.shape == istd.shape 31 | self.norm_var = norm_var 32 | # The buffer can be accessed from this module using self.mean 33 | self.register_buffer("mean", mean) 34 | self.register_buffer("istd", istd) 35 | 36 | def forward(self, x: torch.Tensor): 37 | """ 38 | Args: 39 | x (torch.Tensor): (batch, max_len, feat_dim) 40 | 41 | Returns: 42 | (torch.Tensor): normalized feature 43 | """ 44 | x = x - self.mean 45 | if self.norm_var: 46 | x = x * self.istd 47 | return x 48 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/paraformer/subsampling.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | import torch 3 | from wenet.transformer.subsampling import BaseSubsampling 4 | 5 | 6 | class IdentitySubsampling(BaseSubsampling): 7 | """ Paraformer subsampling 8 | """ 9 | 10 | def __init__(self, idim: int, odim: int, dropout_rate: float, 11 | pos_enc_class: torch.nn.Module): 12 | super().__init__() 13 | _, _ = idim, odim 14 | self.right_context = 6 15 | self.subsampling_rate = 6 16 | self.pos_enc = pos_enc_class 17 | 18 | def forward( 19 | self, 20 | x: torch.Tensor, 21 | x_mask: torch.Tensor, 22 | offset: Union[torch.Tensor, int] = 0 23 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 24 | """Subsample x. 25 | 26 | Args: 27 | x (torch.Tensor): Input tensor (#batch, time, idim). 28 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 
29 | 30 | Returns: 31 | torch.Tensor: Subsampled tensor (#batch, time', odim), 32 | where time' = time. 33 | torch.Tensor: Subsampled mask (#batch, 1, time'), 34 | where time' = time 35 | torch.Tensor: positional encoding 36 | 37 | """ 38 | # NOTE(Mddct): Paraformer starts from 1 39 | if isinstance(offset, torch.Tensor): 40 | offset = torch.add(offset, 1) 41 | else: 42 | offset = offset + 1 43 | x, pos_emb = self.pos_enc(x, offset) 44 | return x, pos_emb, x_mask 45 | 46 | def position_encoding(self, offset: Union[int, torch.Tensor], 47 | size: int) -> torch.Tensor: 48 | return self.pos_enc.position_encoding(offset + 1, size) 49 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/text/bpe_tokenizer.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from typing import Dict, List, Optional, Union 3 | from wenet.text.char_tokenizer import CharTokenizer 4 | from wenet.text.tokenize_utils import tokenize_by_bpe_model 5 | 6 | 7 | class BpeTokenizer(CharTokenizer): 8 | 9 | def __init__( 10 | self, 11 | bpe_model: Union[PathLike, str], 12 | symbol_table: Union[str, PathLike, Dict], 13 | non_lang_syms: Optional[Union[str, PathLike, List]] = None, 14 | split_with_space: bool = False, 15 | connect_symbol: str = '', 16 | unk='', 17 | ) -> None: 18 | super().__init__(symbol_table, non_lang_syms, split_with_space, 19 | connect_symbol, unk) 20 | self._model = bpe_model 21 | # NOTE(Mddct): multiprocessing.Process() issues 22 | # don't build sp here 23 | self.bpe_model = None 24 | 25 | def _build_sp(self): 26 | if self.bpe_model is None: 27 | import sentencepiece as spm 28 | self.bpe_model = spm.SentencePieceProcessor() 29 | self.bpe_model.load(self._model) 30 | 31 | def text2tokens(self, line: str) -> List[str]: 32 | self._build_sp() 33 | line = line.strip() 34 | if self.non_lang_syms_pattern is not None: 35 | parts = self.non_lang_syms_pattern.split(line.upper()) 36 | parts = [w for w in parts if len(w.strip()) > 0] 37 | else: 38 | parts = [line] 39 | 40 | tokens = [] 41 | for part in parts: 42 | if part in self.non_lang_syms: 43 | tokens.append(part) 44 | else: 45 | tokens.extend(tokenize_by_bpe_model(self.bpe_model, part)) 46 | return tokens 47 | 48 | def tokens2text(self, tokens: List[str]) -> str: 49 | self._build_sp() 50 | text = super().tokens2text(tokens) 51 | return text.replace("▁", ' ').strip() 52 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/README.md: -------------------------------------------------------------------------------- 1 | # Module Introduction 2 | 3 | Here is a brief introduction of each module(directory). 4 | 5 | * `bin`: training and recognition binaries 6 | * `dataset`: IO design 7 | * `utils`: common utils 8 | * `transformer`: the core of `WeNet`, in which the standard transformer/conformer is implemented. It contains the common blocks(backbone) of speech transformers. 
9 | * transformer/attention.py: Standard multi head attention 10 | * transformer/embedding.py: Standard position encoding 11 | * transformer/positionwise_feed_forward.py: Standard feed forward in transformer 12 | * transformer/convolution.py: ConvolutionModule in Conformer model 13 | * transformer/subsampling.py: Subsampling implementation for speech task 14 | * `transducer`: transducer implementation 15 | * `squeezeformer`: squeezeformer implementation, please refer [paper](https://arxiv.org/pdf/2206.00888.pdf) 16 | * `efficient_conformer`: efficient conformer implementation, please refer [paper](https://arxiv.org/pdf/2109.01163.pdf) 17 | * `paraformer`: paraformer implementation, please refer [paper](https://arxiv.org/pdf/1905.11235.pdf) 18 | * `paraformer/cif.py`: Continuous Integrate-and-Fire implemented, please refer [paper](https://arxiv.org/pdf/1905.11235.pdf) 19 | * `branchformer`: branchformer implementation, please refer [paper](https://arxiv.org/abs/2207.02971) 20 | * `whisper`: whisper implementation, please refer [paper](https://arxiv.org/abs/2212.04356) 21 | * `ssl`: Self-supervised speech model implementation. e.g. wav2vec2, bestrq, w2vbert. 22 | * `ctl_model`: Enhancing the Unified Streaming and Non-streaming Model with with Contrastive Learning implementation [paper](https://arxiv.org/abs/2306.00755) 23 | 24 | `transducer`, `squeezeformer`, `efficient_conformer`, `branchformer` and `cif` are all based on `transformer`, 25 | they resue a lot of the common blocks of `tranformer`. 26 | 27 | **If you want to contribute your own x-former, please reuse the current code as much as possible**. 28 | 29 | 30 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/text/paraformer_tokenizer.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from typing import Dict, List, Optional, Union 3 | from wenet.paraformer.search import paraformer_beautify_result 4 | from wenet.text.char_tokenizer import CharTokenizer 5 | from wenet.text.tokenize_utils import tokenize_by_seg_dict 6 | 7 | 8 | def read_seg_dict(path): 9 | seg_table = {} 10 | with open(path, 'r', encoding='utf8') as fin: 11 | for line in fin: 12 | arr = line.strip().split('\t') 13 | assert len(arr) == 2 14 | seg_table[arr[0]] = arr[1] 15 | return seg_table 16 | 17 | 18 | class ParaformerTokenizer(CharTokenizer): 19 | 20 | def __init__(self, 21 | symbol_table: Union[str, PathLike, Dict], 22 | seg_dict: Optional[Union[str, PathLike, Dict]] = None, 23 | split_with_space: bool = False, 24 | connect_symbol: str = '', 25 | unk='') -> None: 26 | super().__init__(symbol_table, None, split_with_space, connect_symbol, 27 | unk) 28 | self.seg_dict = seg_dict 29 | if seg_dict is not None and not isinstance(seg_dict, Dict): 30 | self.seg_dict = read_seg_dict(seg_dict) 31 | 32 | def text2tokens(self, line: str) -> List[str]: 33 | assert self.seg_dict is not None 34 | 35 | # TODO(Mddct): duplicated here, refine later 36 | line = line.strip() 37 | if self.non_lang_syms_pattern is not None: 38 | parts = self.non_lang_syms_pattern.split(line) 39 | parts = [w for w in parts if len(w.strip()) > 0] 40 | else: 41 | parts = [line] 42 | 43 | tokens = [] 44 | for part in parts: 45 | if part in self.non_lang_syms: 46 | tokens.append(part) 47 | else: 48 | tokens.extend(tokenize_by_seg_dict(self.seg_dict, part)) 49 | return tokens 50 | 51 | def tokens2text(self, tokens: List[str]) -> str: 52 | return paraformer_beautify_result(tokens) 53 | 
-------------------------------------------------------------------------------- /wenet_asr_server/wenet/text/hugging_face_tokenizer.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from typing import Dict, List, Union 3 | from wenet.text.base_tokenizer import BaseTokenizer, T as Type 4 | 5 | 6 | class HuggingFaceTokenizer(BaseTokenizer): 7 | 8 | def __init__(self, model: Union[str, PathLike], *args, **kwargs) -> None: 9 | # NOTE(Mddct): don't build here, pickle issues 10 | self.model = model 11 | self.tokenizer = None 12 | 13 | self.args = args 14 | self.kwargs = kwargs 15 | 16 | def __getstate__(self): 17 | state = self.__dict__.copy() 18 | del state['tokenizer'] 19 | return state 20 | 21 | def __setstate__(self, state): 22 | self.__dict__.update(state) 23 | recovery = {'tokenizer': None} 24 | self.__dict__.update(recovery) 25 | 26 | def _build_hugging_face(self): 27 | from transformers import AutoTokenizer 28 | if self.tokenizer is None: 29 | self.tokenizer = AutoTokenizer.from_pretrained( 30 | self.model, **self.kwargs) 31 | self.t2i = self.tokenizer.get_vocab() 32 | 33 | def text2tokens(self, line: str) -> List[Type]: 34 | self._build_hugging_face() 35 | return self.tokenizer.tokenize(line) 36 | 37 | def tokens2text(self, tokens: List[Type]) -> str: 38 | self._build_hugging_face() 39 | ids = self.tokens2ids(tokens) 40 | return self.tokenizer.decode(ids) 41 | 42 | def tokens2ids(self, tokens: List[Type]) -> List[int]: 43 | self._build_hugging_face() 44 | return self.tokenizer.convert_tokens_to_ids(tokens) 45 | 46 | def ids2tokens(self, ids: List[int]) -> List[Type]: 47 | self._build_hugging_face() 48 | return self.tokenizer.convert_ids_to_tokens(ids) 49 | 50 | def vocab_size(self) -> int: 51 | self._build_hugging_face() 52 | # TODO: we need special tokenize size in future 53 | return len(self.tokenizer) 54 | 55 | @property 56 | def symbol_table(self) -> Dict[Type, int]: 57 | self._build_hugging_face() 58 | return self.t2i 59 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/transducer/search/greedy_search.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | 6 | def basic_greedy_search( 7 | model: torch.nn.Module, 8 | encoder_out: torch.Tensor, 9 | encoder_out_lens: torch.Tensor, 10 | n_steps: int = 64, 11 | ) -> List[List[int]]: 12 | # fake padding 13 | padding = torch.zeros(1, 1).to(encoder_out.device) 14 | # sos 15 | pred_input_step = torch.tensor([model.blank]).reshape(1, 1) 16 | cache = model.predictor.init_state(1, 17 | method="zero", 18 | device=encoder_out.device) 19 | new_cache: List[torch.Tensor] = [] 20 | t = 0 21 | hyps = [] 22 | prev_out_nblk = True 23 | pred_out_step = None 24 | per_frame_max_noblk = n_steps 25 | per_frame_noblk = 0 26 | while t < encoder_out_lens: 27 | encoder_out_step = encoder_out[:, t:t + 1, :] # [1, 1, E] 28 | if prev_out_nblk: 29 | step_outs = model.predictor.forward_step(pred_input_step, padding, 30 | cache) # [1, 1, P] 31 | pred_out_step, new_cache = step_outs[0], step_outs[1] 32 | 33 | joint_out_step = model.joint(encoder_out_step, 34 | pred_out_step) # [1,1,v] 35 | joint_out_probs = joint_out_step.log_softmax(dim=-1) 36 | 37 | joint_out_max = joint_out_probs.argmax(dim=-1).squeeze() # [] 38 | if joint_out_max != model.blank: 39 | hyps.append(joint_out_max.item()) 40 | prev_out_nblk = True 41 | per_frame_noblk = per_frame_noblk + 1 42 | 
pred_input_step = joint_out_max.reshape(1, 1) 43 | # state_m, state_c = clstate_out_m, state_out_c 44 | cache = new_cache 45 | 46 | if joint_out_max == model.blank or per_frame_noblk >= per_frame_max_noblk: 47 | if joint_out_max == model.blank: 48 | prev_out_nblk = False 49 | # TODO(Mddct): make t in chunk for streamming 50 | # or t should't be too lang to predict none blank 51 | t = t + 1 52 | per_frame_noblk = 0 53 | 54 | return [hyps] 55 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | 17 | 18 | def read_lists(list_file): 19 | lists = [] 20 | with open(list_file, 'r', encoding='utf8') as fin: 21 | for line in fin: 22 | lists.append(line.strip()) 23 | return lists 24 | 25 | 26 | def read_non_lang_symbols(non_lang_sym_path): 27 | """read non-linguistic symbol from file. 28 | 29 | The file format is like below: 30 | 31 | {NOISE}\n 32 | {BRK}\n 33 | ... 34 | 35 | 36 | Args: 37 | non_lang_sym_path: non-linguistic symbol file path, None means no any 38 | syms. 39 | 40 | """ 41 | if non_lang_sym_path is None: 42 | return [] 43 | else: 44 | syms = read_lists(non_lang_sym_path) 45 | non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") 46 | for sym in syms: 47 | if non_lang_syms_pattern.fullmatch(sym) is None: 48 | 49 | class BadSymbolFormat(Exception): 50 | pass 51 | 52 | raise BadSymbolFormat( 53 | "Non-linguistic symbols should be " 54 | "formatted in {xxx}//[xxx], consider" 55 | " modify '%s' to meet the requirment. " 56 | "More details can be found in discussions here : " 57 | "https://github.com/wenet-e2e/wenet/pull/819" % (sym)) 58 | return syms 59 | 60 | 61 | def read_symbol_table(symbol_table_file): 62 | symbol_table = {} 63 | with open(symbol_table_file, 'r', encoding='utf8') as fin: 64 | for line in fin: 65 | arr = line.strip().split() 66 | assert len(arr) == 2 67 | symbol_table[arr[0]] = int(arr[1]) 68 | return symbol_table 69 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/utils/init_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Wenet Community. (authors: Dinghao Zhou) 2 | # (authors: Xingchen Song) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import logging 17 | 18 | from wenet.text.base_tokenizer import BaseTokenizer 19 | from wenet.text.bpe_tokenizer import BpeTokenizer 20 | from wenet.text.char_tokenizer import CharTokenizer 21 | from wenet.text.paraformer_tokenizer import ParaformerTokenizer 22 | from wenet.text.whisper_tokenizer import WhisperTokenizer 23 | 24 | 25 | def init_tokenizer(configs) -> BaseTokenizer: 26 | # TODO(xcsong): Forcefully read the 'tokenizer' attribute. 27 | tokenizer_type = configs.get("tokenizer", "char") 28 | if tokenizer_type == "whisper": 29 | tokenizer = WhisperTokenizer( 30 | multilingual=configs['tokenizer_conf']['is_multilingual'], 31 | num_languages=configs['tokenizer_conf']['num_languages']) 32 | elif tokenizer_type == "char": 33 | tokenizer = CharTokenizer( 34 | configs['tokenizer_conf']['symbol_table_path'], 35 | configs['tokenizer_conf']['non_lang_syms_path'], 36 | split_with_space=configs['tokenizer_conf'].get( 37 | 'split_with_space', False), 38 | connect_symbol=configs['tokenizer_conf'].get('connect_symbol', '')) 39 | elif tokenizer_type == "bpe": 40 | tokenizer = BpeTokenizer( 41 | configs['tokenizer_conf']['bpe_path'], 42 | configs['tokenizer_conf']['symbol_table_path'], 43 | configs['tokenizer_conf']['non_lang_syms_path'], 44 | split_with_space=configs['tokenizer_conf'].get( 45 | 'split_with_space', False)) 46 | elif tokenizer_type == 'paraformer': 47 | tokenizer = ParaformerTokenizer( 48 | symbol_table=configs['tokenizer_conf']['symbol_table_path'], 49 | seg_dict=configs['tokenizer_conf']['seg_dict_path']) 50 | else: 51 | raise NotImplementedError 52 | logging.info("use {} tokenizer".format(configs["tokenizer"])) 53 | 54 | return tokenizer 55 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/text/tokenize_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 2023 Horizon Inc. (authors: Xingchen Song) 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import re 16 | 17 | 18 | def tokenize_by_bpe_model(sp, txt): 19 | return _tokenize_by_seg_dic_or_bpe_model(txt, sp=sp, upper=True) 20 | 21 | 22 | def tokenize_by_seg_dict(seg_dict, txt): 23 | return _tokenize_by_seg_dic_or_bpe_model(txt, 24 | seg_dict=seg_dict, 25 | upper=False) 26 | 27 | 28 | def _tokenize_by_seg_dic_or_bpe_model( 29 | txt, 30 | sp=None, 31 | seg_dict=None, 32 | upper=True, 33 | ): 34 | if sp is None: 35 | assert seg_dict is not None 36 | if seg_dict is None: 37 | assert sp is not None 38 | tokens = [] 39 | # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref: 40 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 41 | pattern = re.compile(r'([\u4e00-\u9fff])') 42 | # Example: 43 | # txt = "你好 ITS'S OKAY 的" 44 | # chars = ["你", "好", " ITS'S OKAY ", "的"] 45 | chars = pattern.split(txt.upper() if upper else txt) 46 | mix_chars = [w for w in chars if len(w.strip()) > 0] 47 | for ch_or_w in mix_chars: 48 | # ch_or_w is a single CJK charater(i.e., "你"), do nothing. 49 | if pattern.fullmatch(ch_or_w) is not None: 50 | tokens.append(ch_or_w) 51 | # ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "), 52 | # encode ch_or_w using bpe_model. 53 | else: 54 | if sp is not None: 55 | for p in sp.encode_as_pieces(ch_or_w): 56 | tokens.append(p) 57 | else: 58 | for en_token in ch_or_w.split(): 59 | en_token = en_token.strip() 60 | if en_token in seg_dict: 61 | tokens.extend(seg_dict[en_token].split(' ')) 62 | else: 63 | tokens.append(en_token) 64 | 65 | return tokens 66 | -------------------------------------------------------------------------------- /wenet_asr_server/lib/audio_read.py: -------------------------------------------------------------------------------- 1 | import pyaudio 2 | import wave 3 | import time 4 | import numpy as np 5 | import torch 6 | # from typing import Union 7 | 8 | # CHUNK = 1024 # 每次读取的音频数据块大小 9 | # FORMAT = pyaudio.paInt16 # 音频格式 10 | # CHANNELS = 1 # 声道数 11 | 12 | def read_by_bytes(audio_file: str, length: int = None) -> tuple[torch.Tensor, int]: 13 | wav_file = wave.open(audio_file, 'rb') 14 | frame_rate = wav_file.getframerate() 15 | 16 | with open(audio_file, 'rb') as f: 17 | if length is None: 18 | data = f.read() 19 | else: 20 | data = f.read(length) 21 | 22 | return bytes2tensor(data), frame_rate 23 | 24 | def to_stream(audio_name: str, chunk=1024) -> tuple[pyaudio.Stream, int]: 25 | wf = wave.open(audio_name, 'rb') 26 | rate = wf.getframerate() 27 | 28 | # 初始化音频流 29 | p = pyaudio.PyAudio() 30 | stream = p.open(format=p.get_format_from_width(wf.getsampwidth()), 31 | channels=wf.getnchannels(), 32 | rate=rate, 33 | output=True, 34 | input=True) 35 | 36 | # # 实时发送音频数据流 37 | data = wf.readframes(chunk) 38 | while data: 39 | stream.write(data) # 播放音频 40 | data = wf.readframes(chunk) 41 | # time.sleep(len(data) / float(rate)) 42 | return stream, rate 43 | 44 | 45 | def bytes2tensor(data: bytes) -> torch.Tensor: 46 | waveform = np.frombuffer(data, dtype=np.int16) 47 | waveform = waveform.astype(np.float32) # 转换数据类型为32位浮点数 48 | waveform = torch.from_numpy(waveform) # 转换为torch Tensor 49 | return waveform.unsqueeze(0) # 添加batch维度,假设没有通道维度 50 | 51 | 52 | def send_audio_stream(): 53 | # 连接服务端 54 | server_address = ('localhost', 12345) # 服务端地址和端口 55 | client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 56 | client_socket.connect(server_address) 57 | 58 | # 打开音频文件 59 | wf = wave.open(WAVE_FILENAME, 'rb') 60 | 61 | # 初始化音频流 62 | p = pyaudio.PyAudio() 63 | stream = 
p.open(format=p.get_format_from_width(wf.getsampwidth()), 64 | channels=wf.getnchannels(), 65 | rate=wf.getframerate(), 66 | output=True) 67 | 68 | # 实时发送音频数据流 69 | data = wf.readframes(CHUNK) 70 | while data: 71 | client_socket.sendall(data) # 发送音频数据 72 | stream.write(data) # 播放音频 73 | data = wf.readframes(CHUNK) 74 | time.sleep(len(data) / float(RATE)) # 按音频播放速率延迟发送数据 75 | 76 | # 关闭连接 77 | client_socket.close() 78 | stream.stop_stream() 79 | stream.close() 80 | p.terminate() 81 | wf.close() 82 | 83 | if __name__ == "__main__": 84 | # send_audio_stream() 85 | pass -------------------------------------------------------------------------------- /wenet_asr_server/wenet/bin/export_jit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import logging 19 | import os 20 | 21 | import torch 22 | import yaml 23 | 24 | from wenet.utils.init_model import init_model 25 | 26 | 27 | def get_args(): 28 | parser = argparse.ArgumentParser(description='export your script model') 29 | parser.add_argument('--config', required=True, help='config file') 30 | parser.add_argument('--checkpoint', required=True, help='checkpoint model') 31 | parser.add_argument('--output_file', default=None, help='output file') 32 | parser.add_argument('--output_quant_file', 33 | default=None, 34 | help='output quantized model file') 35 | args = parser.parse_args() 36 | return args 37 | 38 | 39 | def main(): 40 | args = get_args() 41 | logging.basicConfig(level=logging.DEBUG, 42 | format='%(asctime)s %(levelname)s %(message)s') 43 | # No need gpu for model export 44 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 45 | 46 | with open(args.config, 'r') as fin: 47 | configs = yaml.load(fin, Loader=yaml.FullLoader) 48 | model, configs = init_model(args, configs) 49 | model.eval() 50 | print(model) 51 | # Export jit torch script model 52 | 53 | if args.output_file: 54 | script_model = torch.jit.script(model) 55 | script_model.save(args.output_file) 56 | print('Export model successfully, see {}'.format(args.output_file)) 57 | 58 | # Export quantized jit torch script model 59 | if args.output_quant_file: 60 | quantized_model = torch.quantization.quantize_dynamic( 61 | model, {torch.nn.Linear}, dtype=torch.qint8) 62 | print(quantized_model) 63 | script_quant_model = torch.jit.script(quantized_model) 64 | script_quant_model.save(args.output_quant_file) 65 | print('Export quantized model successfully, ' 66 | 'see {}'.format(args.output_quant_file)) 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/squeezeformer/conv2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Ximalaya Inc. 
(authors: Yuguang Yang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Conv2d Module with Valid Padding""" 15 | 16 | import torch.nn.functional as F 17 | from torch.nn.modules.conv import _ConvNd, _size_2_t, Union, _pair, Tensor, Optional 18 | 19 | 20 | class Conv2dValid(_ConvNd): 21 | """ 22 | Conv2d operator for VALID mode padding. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | in_channels: int, 28 | out_channels: int, 29 | kernel_size: _size_2_t, 30 | stride: _size_2_t = 1, 31 | padding: Union[str, _size_2_t] = 0, 32 | dilation: _size_2_t = 1, 33 | groups: int = 1, 34 | bias: bool = True, 35 | padding_mode: str = 'zeros', # TODO: refine this type 36 | device=None, 37 | dtype=None, 38 | valid_trigx: bool = False, 39 | valid_trigy: bool = False) -> None: 40 | factory_kwargs = {'device': device, 'dtype': dtype} 41 | kernel_size_ = _pair(kernel_size) 42 | stride_ = _pair(stride) 43 | padding_ = padding if isinstance(padding, str) else _pair(padding) 44 | dilation_ = _pair(dilation) 45 | super(Conv2dValid, 46 | self).__init__(in_channels, out_channels, 47 | kernel_size_, stride_, padding_, dilation_, False, 48 | _pair(0), groups, bias, padding_mode, 49 | **factory_kwargs) 50 | self.valid_trigx = valid_trigx 51 | self.valid_trigy = valid_trigy 52 | 53 | def _conv_forward(self, input: Tensor, weight: Tensor, 54 | bias: Optional[Tensor]): 55 | validx, validy = 0, 0 56 | if self.valid_trigx: 57 | validx = (input.size(-2) * 58 | (self.stride[-2] - 1) - 1 + self.kernel_size[-2]) // 2 59 | if self.valid_trigy: 60 | validy = (input.size(-1) * 61 | (self.stride[-1] - 1) - 1 + self.kernel_size[-1]) // 2 62 | return F.conv2d(input, weight, bias, self.stride, (validx, validy), 63 | self.dilation, self.groups) 64 | 65 | def forward(self, input: Tensor) -> Tensor: 66 | return self._conv_forward(input, self.weight, self.bias) 67 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/efficient_conformer/subsampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) 2 | # 2022 58.com(Wuba) Inc AI Lab. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | """Subsampling layer definition.""" 17 | 18 | from typing import Tuple, Union 19 | 20 | import torch 21 | from wenet.transformer.subsampling import BaseSubsampling 22 | 23 | 24 | class Conv2dSubsampling2(BaseSubsampling): 25 | """Convolutional 2D subsampling (to 1/4 length). 26 | 27 | Args: 28 | idim (int): Input dimension. 29 | odim (int): Output dimension. 30 | dropout_rate (float): Dropout rate. 31 | 32 | """ 33 | 34 | def __init__(self, idim: int, odim: int, dropout_rate: float, 35 | pos_enc_class: torch.nn.Module): 36 | """Construct an Conv2dSubsampling4 object.""" 37 | super().__init__() 38 | self.conv = torch.nn.Sequential(torch.nn.Conv2d(1, odim, 3, 2), 39 | torch.nn.ReLU()) 40 | self.out = torch.nn.Sequential( 41 | torch.nn.Linear(odim * ((idim - 1) // 2), odim)) 42 | self.pos_enc = pos_enc_class 43 | # The right context for every conv layer is computed by: 44 | # (kernel_size - 1) * frame_rate_of_this_layer 45 | self.subsampling_rate = 2 46 | # 2 = (3 - 1) * 1 47 | self.right_context = 2 48 | 49 | def forward( 50 | self, 51 | x: torch.Tensor, 52 | x_mask: torch.Tensor, 53 | offset: Union[int, torch.Tensor] = 0 54 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 55 | """Subsample x. 56 | 57 | Args: 58 | x (torch.Tensor): Input tensor (#batch, time, idim). 59 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 60 | 61 | Returns: 62 | torch.Tensor: Subsampled tensor (#batch, time', odim), 63 | where time' = time // 2. 64 | torch.Tensor: Subsampled mask (#batch, 1, time'), 65 | where time' = time // 2. 66 | torch.Tensor: positional encoding 67 | 68 | """ 69 | x = x.unsqueeze(1) # (b, c=1, t, f) 70 | x = self.conv(x) 71 | b, c, t, f = x.size() 72 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 73 | x, pos_emb = self.pos_enc(x, offset) 74 | return x, pos_emb, x_mask[:, :, :-2:2] 75 | -------------------------------------------------------------------------------- /wenet_asr_client/transcribe_to_txt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wave 3 | import time 4 | from datetime import datetime 5 | 6 | from lib import wenet_asr_client 7 | 8 | IP_PORT = "192.168.81.10:50051" 9 | 10 | def cal_chunk(time_interval, framerate, sampwidth): 11 | # # 计算每个时间间隔内需要读取的样本数 12 | # samples_per_interval = time_interval * framerate 13 | 14 | # # 计算每个时间间隔内需要读取的字节数 15 | # bytes_per_interval = samples_per_interval * sampwidth 16 | 17 | # 返回每个时间间隔内需要读取的字节数 18 | return time_interval * framerate 19 | 20 | 21 | def read_long_audio(audio_dir, sep_dur=10): 22 | datastream = [] 23 | 24 | with wave.open(audio_dir, 'rb') as f: 25 | samprate = f.getframerate() 26 | sampwidth = f.getsampwidth() 27 | chunk = cal_chunk(sep_dur, samprate, sampwidth) 28 | chunk = 81920 29 | print(f'chunk: {chunk}') 30 | 31 | while True: 32 | dataframes = f.readframes(chunk) 33 | if not dataframes: 34 | break 35 | if len(dataframes) < chunk / 3: 36 | datastream[-1] += dataframes 37 | else: 38 | datastream.append(dataframes) 39 | 40 | return datastream, samprate, sampwidth 41 | 42 | 43 | def transcribe_to_txt(recognizer, audio_dir, filename): 44 | datastream, samprate, sampwidth = read_long_audio(audio_dir) 45 | print(len(datastream)) 46 | 47 | recognizer.sample_rate = samprate 48 | recognizer.bit_depth = sampwidth * 8 49 | 50 | print('recognizing...') 51 | text = '' 52 | for data in datastream: 53 | print('len', len(data)) 54 | result = recognizer.recognize(data) 55 | 
print(result) 56 | text += result 57 | print(f'result: {text}') 58 | 59 | try: 60 | with open(filename, 'w') as file: 61 | file.write(text) 62 | print(f"File '{filename}' was created and written successfully.") 63 | except Exception as e: 64 | print(f"Error writing file '{filename}': {e}") 65 | 66 | 67 | if __name__ == '__main__': 68 | recognizer = wenet_asr_client.Recognizer( 69 | sample_rate=None, 70 | bit_depth=None, 71 | ) 72 | 73 | print('connecting...') 74 | recognizer.connect(IP_PORT) 75 | recognizer.reload_model( 76 | hotwords='default', 77 | context_score=6, 78 | ) 79 | 80 | # audio_dir = 'audio/samples3/1702257074_2632_70_0_1.wav' 81 | # transcribe_to_txt(recognizer, audio_dir, 'transcript/1702257074_2632_70_0_1.txt') 82 | 83 | samples = 'samples2' 84 | audio_path = f'audio/{samples}' 85 | txt_path = f'transcript_3/{samples}' 86 | # txt_path = 'transcript/samples3' 87 | if not os.path.exists(txt_path): 88 | os.makedirs(txt_path) 89 | for filename in os.listdir(audio_path): 90 | if not filename.endswith('.wav'): 91 | continue 92 | print(f'\ncurrent: {filename}') 93 | audio_dir = os.path.join(audio_path, filename) 94 | transcribe_to_txt(recognizer, audio_dir, os.path.join(txt_path, os.path.splitext(filename)[0] + '.txt')) 95 | -------------------------------------------------------------------------------- /wenet_asr_server/README.md: -------------------------------------------------------------------------------- 1 | # WeNet ASR Server 2 | 3 | The server project implements an API service for calling the WeNet ASR model, using gRPC as the service framework. 4 | 5 | ## Environment Installation 6 | 7 | Run the following command to install the required Python packages in a newly created conda environment: 8 | 9 | ```bash 10 | pip install -r requirements.txt # Note: run this command in the wenet_asr_server directory. 11 | ``` 12 | 13 | The environment consists of 3 parts:
14 | · environment required by *WeNet*<br>
15 | · environment required by the *PPASR* punctuation model<br>
16 | · environment required by the *gRPC* framework 17 | 18 | You can refer to the following resources if you have problems installing the environment.<br>
19 | [WeNet](https://github.com/wenet-e2e/wenet) - official github page
20 | [Adding punctuation to speech recognition text (给语音识别文本加上标点符号)](https://blog.csdn.net/qq_33200967/article/details/122474859) - CSDN blog (in Chinese)<br>
21 | [PPASR](https://github.com/yeyupiaoling/PPASR) - official github page
22 | [PaddleSpeech](https://github.com/PaddlePaddle/PaddleSpeech) - official github page 23 | 24 | 44 | 45 | ## Service Running 46 | 47 | ```bash 48 | python main.py \ 49 | --model \ 50 | --pun-model \ 51 | --hotwords \ 52 | --context-score \ 53 | --port 54 | ``` 55 | 56 | Command line arguments of *main.py*: 57 | 58 | >--model: If you have downloaded a WeNet model before, you can configure it with this argument. If not configured, the WeNet model will be auto-downloaded. 59 | > 60 | >--pun-model: If you have downloaded a PPASR punctuation model before, you can configure it with this argument. If not configured, the PPASR punctuation model will be auto-downloaded. You can download the model [here](https://download.csdn.net/download/qq_33200967/75664996) manually. 61 | > 62 | >--hotwords: You can configure a path to a *.txt* file in which each line records a hotword. The hotword is more probably recognized by WeNet model. See details of the hotword mechanism in the [official documents of WeNet.](https://wenet.org.cn/wenet/context.html) 63 | > 64 | >--context-score: Additional score of each character in a hotword when processing beam-search. 3.0 by default. See details of the hotword mechanism in the [official documents of WeNet.](https://wenet.org.cn/wenet/context.html) 65 | > 66 | >--port: Port of the service process. 50051 by default. 67 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/text/char_tokenizer.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from os import PathLike 4 | from typing import Dict, List, Optional, Union 5 | from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols 6 | from wenet.text.base_tokenizer import BaseTokenizer 7 | 8 | 9 | class CharTokenizer(BaseTokenizer): 10 | 11 | def __init__( 12 | self, 13 | symbol_table: Union[str, PathLike, Dict], 14 | non_lang_syms: Optional[Union[str, PathLike, List]] = None, 15 | split_with_space: bool = False, 16 | connect_symbol: str = '', 17 | unk='', 18 | ) -> None: 19 | self.non_lang_syms_pattern = None 20 | if non_lang_syms is not None: 21 | self.non_lang_syms_pattern = re.compile( 22 | r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") 23 | if not isinstance(symbol_table, Dict): 24 | self._symbol_table = read_symbol_table(symbol_table) 25 | else: 26 | # symbol_table = {"我": 1, "是": 2, "{NOISE}": 3} 27 | self._symbol_table = symbol_table 28 | if not isinstance(non_lang_syms, List): 29 | self.non_lang_syms = read_non_lang_symbols(non_lang_syms) 30 | else: 31 | # non_lang_syms=["{NOISE}"] 32 | self.non_lang_syms = non_lang_syms 33 | self.char_dict = {v: k for k, v in self._symbol_table.items()} 34 | self.split_with_space = split_with_space 35 | self.connect_symbol = connect_symbol 36 | self.unk = unk 37 | 38 | def text2tokens(self, line: str) -> List[str]: 39 | line = line.strip() 40 | if self.non_lang_syms_pattern is not None: 41 | parts = self.non_lang_syms_pattern.split(line.upper()) 42 | parts = [w for w in parts if len(w.strip()) > 0] 43 | else: 44 | parts = [line] 45 | 46 | tokens = [] 47 | for part in parts: 48 | if part in self.non_lang_syms: 49 | tokens.append(part) 50 | else: 51 | if self.split_with_space: 52 | part = part.split(" ") 53 | for ch in part: 54 | if ch == ' ': 55 | ch = "▁" 56 | tokens.append(ch) 57 | return tokens 58 | 59 | def tokens2text(self, tokens: List[str]) -> str: 60 | return self.connect_symbol.join(tokens) 61 | 62 | def tokens2ids(self, tokens: List[str]) -> List[int]: 63 | 
ids = [] 64 | for ch in tokens: 65 | if ch in self._symbol_table: 66 | ids.append(self._symbol_table[ch]) 67 | elif self.unk in self._symbol_table: 68 | ids.append(self._symbol_table[self.unk]) 69 | return ids 70 | 71 | def ids2tokens(self, ids: List[int]) -> List[str]: 72 | content = [self.char_dict[w] for w in ids] 73 | return content 74 | 75 | def vocab_size(self) -> int: 76 | return len(self.char_dict) 77 | 78 | @property 79 | def symbol_table(self) -> Dict[str, int]: 80 | return self._symbol_table 81 | -------------------------------------------------------------------------------- /wenet_asr_client/cal_cer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | 4 | from lib.tester import cer_levenshtein 5 | 6 | name_map = { 7 | '075589939388-20230120105407-szBZ202301200044':'2_1', 8 | '075589939388-20230120144717-szBZ202301200064':'2_2', 9 | '075589939388-20230208174118-szBZ202302080210':'2_3', 10 | '075589939388-20230209142552-szBZ202302090102':'2_4', 11 | '075589939388-20230210115147-szBZ202302100143':'2_5', 12 | '1702257074_2632_70_0_1':'3_1', 13 | '1702257572_14763_56_0_1':'3_2', 14 | '1702262200_10509_67_0_1':'3_3', 15 | '1702434389_2340_152_0_1':'3_4', 16 | '1702435068_20113_149_0_1':'3_5', 17 | '1702435792_16578_159_0_1':'3_6', 18 | } 19 | 20 | def cal_folder(dir, df, hotwords=True, transcript_dir: str = None): 21 | if transcript_dir is None: 22 | transcript_dir = 'transcript' if hotwords else 'transcript_without_hotwords' 23 | for filename in os.listdir(os.path.join(transcript_dir, dir)): 24 | print(f'current: {filename}') 25 | with open(os.path.join(transcript_dir, dir, filename), 'r') as f: 26 | s1 = f.read() 27 | with open(os.path.join('audio', dir, filename), 'r') as f: 28 | s2 = f.read() 29 | 30 | filename = filename.strip('.txt') 31 | index = name_map[filename] 32 | df.loc[index, 'filename'] = filename 33 | df.loc[index, 'text_length'] = len(s2) 34 | if hotwords: 35 | df.loc[index, 'transcript_length'] = len(s1) 36 | _, df.loc[index, 'cer'], (insert, delete, replace) = cer_levenshtein(s1, s2) 37 | df.loc[index, 'correct_length'] = insert + replace 38 | else: 39 | df.loc[index, 'transcript_length_without_hotwords'] = len(s1) 40 | _, df.loc[index, 'cer_without_hotwords'], (insert, delete, replace) = cer_levenshtein(s1, s2) 41 | df.loc[index, 'correct_length_without_hotwords'] = len(s1) - delete - replace 42 | print(df) 43 | 44 | 45 | if __name__ == '__main__': 46 | save_file = 'cer_without_hotwords.csv' 47 | transcript_dir = 'transcript_without_hotwords' 48 | 49 | if not os.path.exists(save_file): 50 | # 创建索引 51 | index = [] 52 | for i in range(2, 4): 53 | for j in range(1, 7 if i == 3 else 6): 54 | index.append(f"{i}_{j}") 55 | row = len(index) 56 | data = { 57 | 'filename': [None] * row, 58 | 'text_length': [None] * row, 59 | 'transcript_length': [None] * row, 60 | 'correct_length': [None] * row, 61 | 'cer': [None] * row, 62 | # 'transcript_length_without_hotwords': [None] * row, 63 | # 'correct_length_without_hotwords': [None] * row, 64 | # 'cer_without_hotwords': [None] * row, 65 | } 66 | df = pd.DataFrame(data, index=index) 67 | else: 68 | df = pd.read_csv(save_file, index_col=0) 69 | 70 | cal_folder('samples2', df, transcript_dir=transcript_dir) 71 | cal_folder('samples3', df, transcript_dir=transcript_dir) 72 | # cal_folder('samples2', df, hotwords=False) 73 | # cal_folder('samples3', df, hotwords=False) 74 | 75 | df.to_csv(save_file) 76 | 77 | 78 | 
-------------------------------------------------------------------------------- /wenet_asr_client/lib/wenet_asr_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: wenet_asr.proto 4 | # Protobuf Python Version: 4.25.1 5 | """Generated protocol buffer code.""" 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import symbol_database as _symbol_database 9 | from google.protobuf.internal import builder as _builder 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fwenet_asr.proto\x12\x08wenetasr\"\x07\n\x05\x45mpty\"\x1b\n\x0bTextMessage\x12\x0c\n\x04text\x18\x01 \x01(\t\"\x89\x01\n\x12ReloadModelRequest\x12\x0b\n\x03\x61sr\x18\x01 \x01(\x08\x12\r\n\x05model\x18\x02 \x01(\t\x12\x10\n\x08hotwords\x18\x03 \x01(\t\x12\x15\n\rcontext_score\x18\x04 \x01(\x05\x12\x13\n\x0bpunctuation\x18\x05 \x01(\x08\x12\x19\n\x11punctuation_model\x18\x06 \x01(\t\".\n\x11RecognitionConfig\x12\x19\n\x11sample_rate_hertz\x18\x01 \x01(\x05\"M\n\x10RecognizeRequest\x12+\n\x06\x63onfig\x18\x01 \x01(\x0b\x32\x1b.wenetasr.RecognitionConfig\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\'\n\x11RecognizeResponse\x12\x12\n\ntranscript\x18\x01 \x01(\t\"@\n\x19StreamingRecognizeRequest\x12\x0e\n\x04\x64\x61ta\x18\x02 \x01(\x0cH\x00\x42\x13\n\x11streaming_request\"0\n\x1aStreamingRecognizeResponse\x12\x12\n\ntranscript\x18\x01 \x01(\t2\x97\x03\n\x08WenetASR\x12\x34\n\x04Test\x12\x15.wenetasr.TextMessage\x1a\x15.wenetasr.TextMessage\x12\x35\n\x0bGetServerID\x12\x0f.wenetasr.Empty\x1a\x15.wenetasr.TextMessage\x12<\n\x0bReloadModel\x12\x1c.wenetasr.ReloadModelRequest\x1a\x0f.wenetasr.Empty\x12\x44\n\tRecognize\x12\x1a.wenetasr.RecognizeRequest\x1a\x1b.wenetasr.RecognizeResponse\x12\x63\n\x12StreamingRecognize\x12#.wenetasr.StreamingRecognizeRequest\x1a$.wenetasr.StreamingRecognizeResponse(\x01\x30\x01\x12\x35\n\x05Punct\x12\x15.wenetasr.TextMessage\x1a\x15.wenetasr.TextMessageb\x06proto3') 18 | 19 | _globals = globals() 20 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 21 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'wenet_asr_pb2', _globals) 22 | if _descriptor._USE_C_DESCRIPTORS == False: 23 | DESCRIPTOR._options = None 24 | _globals['_EMPTY']._serialized_start=29 25 | _globals['_EMPTY']._serialized_end=36 26 | _globals['_TEXTMESSAGE']._serialized_start=38 27 | _globals['_TEXTMESSAGE']._serialized_end=65 28 | _globals['_RELOADMODELREQUEST']._serialized_start=68 29 | _globals['_RELOADMODELREQUEST']._serialized_end=205 30 | _globals['_RECOGNITIONCONFIG']._serialized_start=207 31 | _globals['_RECOGNITIONCONFIG']._serialized_end=253 32 | _globals['_RECOGNIZEREQUEST']._serialized_start=255 33 | _globals['_RECOGNIZEREQUEST']._serialized_end=332 34 | _globals['_RECOGNIZERESPONSE']._serialized_start=334 35 | _globals['_RECOGNIZERESPONSE']._serialized_end=373 36 | _globals['_STREAMINGRECOGNIZEREQUEST']._serialized_start=375 37 | _globals['_STREAMINGRECOGNIZEREQUEST']._serialized_end=439 38 | _globals['_STREAMINGRECOGNIZERESPONSE']._serialized_start=441 39 | _globals['_STREAMINGRECOGNIZERESPONSE']._serialized_end=489 40 | _globals['_WENETASR']._serialized_start=492 41 | _globals['_WENETASR']._serialized_end=899 42 | # 
@@protoc_insertion_point(module_scope) 43 | -------------------------------------------------------------------------------- /wenet_asr_server/lib/wenet_asr_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: wenet_asr.proto 4 | # Protobuf Python Version: 4.25.1 5 | """Generated protocol buffer code.""" 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import symbol_database as _symbol_database 9 | from google.protobuf.internal import builder as _builder 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fwenet_asr.proto\x12\x08wenetasr\"\x07\n\x05\x45mpty\"\x1b\n\x0bTextMessage\x12\x0c\n\x04text\x18\x01 \x01(\t\"\x89\x01\n\x12ReloadModelRequest\x12\x0b\n\x03\x61sr\x18\x01 \x01(\x08\x12\r\n\x05model\x18\x02 \x01(\t\x12\x10\n\x08hotwords\x18\x03 \x01(\t\x12\x15\n\rcontext_score\x18\x04 \x01(\x05\x12\x13\n\x0bpunctuation\x18\x05 \x01(\x08\x12\x19\n\x11punctuation_model\x18\x06 \x01(\t\".\n\x11RecognitionConfig\x12\x19\n\x11sample_rate_hertz\x18\x01 \x01(\x05\"M\n\x10RecognizeRequest\x12+\n\x06\x63onfig\x18\x01 \x01(\x0b\x32\x1b.wenetasr.RecognitionConfig\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\'\n\x11RecognizeResponse\x12\x12\n\ntranscript\x18\x01 \x01(\t\"@\n\x19StreamingRecognizeRequest\x12\x0e\n\x04\x64\x61ta\x18\x02 \x01(\x0cH\x00\x42\x13\n\x11streaming_request\"0\n\x1aStreamingRecognizeResponse\x12\x12\n\ntranscript\x18\x01 \x01(\t2\x97\x03\n\x08WenetASR\x12\x34\n\x04Test\x12\x15.wenetasr.TextMessage\x1a\x15.wenetasr.TextMessage\x12\x35\n\x0bGetServerID\x12\x0f.wenetasr.Empty\x1a\x15.wenetasr.TextMessage\x12<\n\x0bReloadModel\x12\x1c.wenetasr.ReloadModelRequest\x1a\x0f.wenetasr.Empty\x12\x44\n\tRecognize\x12\x1a.wenetasr.RecognizeRequest\x1a\x1b.wenetasr.RecognizeResponse\x12\x63\n\x12StreamingRecognize\x12#.wenetasr.StreamingRecognizeRequest\x1a$.wenetasr.StreamingRecognizeResponse(\x01\x30\x01\x12\x35\n\x05Punct\x12\x15.wenetasr.TextMessage\x1a\x15.wenetasr.TextMessageb\x06proto3') 18 | 19 | _globals = globals() 20 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 21 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'wenet_asr_pb2', _globals) 22 | if _descriptor._USE_C_DESCRIPTORS == False: 23 | DESCRIPTOR._options = None 24 | _globals['_EMPTY']._serialized_start=29 25 | _globals['_EMPTY']._serialized_end=36 26 | _globals['_TEXTMESSAGE']._serialized_start=38 27 | _globals['_TEXTMESSAGE']._serialized_end=65 28 | _globals['_RELOADMODELREQUEST']._serialized_start=68 29 | _globals['_RELOADMODELREQUEST']._serialized_end=205 30 | _globals['_RECOGNITIONCONFIG']._serialized_start=207 31 | _globals['_RECOGNITIONCONFIG']._serialized_end=253 32 | _globals['_RECOGNIZEREQUEST']._serialized_start=255 33 | _globals['_RECOGNIZEREQUEST']._serialized_end=332 34 | _globals['_RECOGNIZERESPONSE']._serialized_start=334 35 | _globals['_RECOGNIZERESPONSE']._serialized_end=373 36 | _globals['_STREAMINGRECOGNIZEREQUEST']._serialized_start=375 37 | _globals['_STREAMINGRECOGNIZEREQUEST']._serialized_end=439 38 | _globals['_STREAMINGRECOGNIZERESPONSE']._serialized_start=441 39 | _globals['_STREAMINGRECOGNIZERESPONSE']._serialized_end=489 40 | _globals['_WENETASR']._serialized_start=492 41 | 
_globals['_WENETASR']._serialized_end=899 42 | # @@protoc_insertion_point(module_scope) 43 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/utils/cmvn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | import math 17 | 18 | import numpy as np 19 | 20 | 21 | def _load_json_cmvn(json_cmvn_file): 22 | """ Load the json format cmvn stats file and calculate cmvn 23 | 24 | Args: 25 | json_cmvn_file: cmvn stats file in json format 26 | 27 | Returns: 28 | a numpy array of [means, vars] 29 | """ 30 | with open(json_cmvn_file) as f: 31 | cmvn_stats = json.load(f) 32 | 33 | means = cmvn_stats['mean_stat'] 34 | variance = cmvn_stats['var_stat'] 35 | count = cmvn_stats['frame_num'] 36 | for i in range(len(means)): 37 | means[i] /= count 38 | variance[i] = variance[i] / count - means[i] * means[i] 39 | if variance[i] < 1.0e-20: 40 | variance[i] = 1.0e-20 41 | variance[i] = 1.0 / math.sqrt(variance[i]) 42 | cmvn = np.array([means, variance]) 43 | return cmvn 44 | 45 | 46 | def _load_kaldi_cmvn(kaldi_cmvn_file): 47 | """ Load the kaldi format cmvn stats file and calculate cmvn 48 | 49 | Args: 50 | kaldi_cmvn_file: kaldi text style global cmvn file, which 51 | is generated by: 52 | compute-cmvn-stats --binary=false scp:feats.scp global_cmvn 53 | 54 | Returns: 55 | a numpy array of [means, vars] 56 | """ 57 | means = [] 58 | variance = [] 59 | with open(kaldi_cmvn_file, 'r') as fid: 60 | # kaldi binary file start with '\0B' 61 | if fid.read(2) == '\0B': 62 | logging.error('kaldi cmvn binary file is not supported, please ' 63 | 'recompute it by: compute-cmvn-stats --binary=false ' 64 | ' scp:feats.scp global_cmvn') 65 | sys.exit(1) 66 | fid.seek(0) 67 | arr = fid.read().split() 68 | assert (arr[0] == '[') 69 | assert (arr[-2] == '0') 70 | assert (arr[-1] == ']') 71 | feat_dim = int((len(arr) - 2 - 2) / 2) 72 | for i in range(1, feat_dim + 1): 73 | means.append(float(arr[i])) 74 | count = float(arr[feat_dim + 1]) 75 | for i in range(feat_dim + 2, 2 * feat_dim + 2): 76 | variance.append(float(arr[i])) 77 | 78 | for i in range(len(means)): 79 | means[i] /= count 80 | variance[i] = variance[i] / count - means[i] * means[i] 81 | if variance[i] < 1.0e-20: 82 | variance[i] = 1.0e-20 83 | variance[i] = 1.0 / math.sqrt(variance[i]) 84 | cmvn = np.array([means, variance]) 85 | return cmvn 86 | 87 | 88 | def load_cmvn(cmvn_file, is_json): 89 | if is_json: 90 | cmvn = _load_json_cmvn(cmvn_file) 91 | else: 92 | cmvn = _load_kaldi_cmvn(cmvn_file) 93 | return cmvn[0], cmvn[1] 94 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/squeezeformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc 
(Binbin Zhang) 3 | # 2022 Ximalaya Inc (Yuguang Yang) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Positionwise feed forward layer definition.""" 17 | 18 | import torch 19 | 20 | 21 | class PositionwiseFeedForward(torch.nn.Module): 22 | """Positionwise feed forward layer. 23 | 24 | FeedForward are appied on each position of the sequence. 25 | The output dim is same with the input dim. 26 | 27 | Args: 28 | idim (int): Input dimenstion. 29 | hidden_units (int): The number of hidden units. 30 | dropout_rate (float): Dropout rate. 31 | activation (torch.nn.Module): Activation function 32 | """ 33 | 34 | def __init__(self, 35 | idim: int, 36 | hidden_units: int, 37 | dropout_rate: float, 38 | activation: torch.nn.Module = torch.nn.ReLU(), 39 | adaptive_scale: bool = False, 40 | init_weights: bool = False): 41 | """Construct a PositionwiseFeedForward object.""" 42 | super(PositionwiseFeedForward, self).__init__() 43 | self.idim = idim 44 | self.hidden_units = hidden_units 45 | self.w_1 = torch.nn.Linear(idim, hidden_units) 46 | self.activation = activation 47 | self.dropout = torch.nn.Dropout(dropout_rate) 48 | self.w_2 = torch.nn.Linear(hidden_units, idim) 49 | self.ada_scale = None 50 | self.ada_bias = None 51 | self.adaptive_scale = adaptive_scale 52 | self.ada_scale = torch.nn.Parameter(torch.ones([1, 1, idim]), 53 | requires_grad=adaptive_scale) 54 | self.ada_bias = torch.nn.Parameter(torch.zeros([1, 1, idim]), 55 | requires_grad=adaptive_scale) 56 | if init_weights: 57 | self.init_weights() 58 | 59 | def init_weights(self): 60 | ffn1_max = self.idim**-0.5 61 | ffn2_max = self.hidden_units**-0.5 62 | torch.nn.init.uniform_(self.w_1.weight.data, -ffn1_max, ffn1_max) 63 | torch.nn.init.uniform_(self.w_1.bias.data, -ffn1_max, ffn1_max) 64 | torch.nn.init.uniform_(self.w_2.weight.data, -ffn2_max, ffn2_max) 65 | torch.nn.init.uniform_(self.w_2.bias.data, -ffn2_max, ffn2_max) 66 | 67 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 68 | """Forward function. 69 | 70 | Args: 71 | xs: input tensor (B, L, D) 72 | Returns: 73 | output tensor, (B, L, D) 74 | """ 75 | if self.adaptive_scale: 76 | xs = self.ada_scale * xs + self.ada_bias 77 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 78 | -------------------------------------------------------------------------------- /wenet_asr_client/README.md: -------------------------------------------------------------------------------- 1 | # WeNet ASR Client 2 | 3 | The client project provides python API implemented by a python class. Encapsulating systematic usage of the API service provided by **WeNet ASR Server**, the API not only supports single-time model call, but also provides a real-time ASR calling method. 4 | 5 | ## Environment Installation 6 | 7 | Run the following command to install python packages in a newly created conda environment: 8 | 9 | ```bash 10 | pip install -r requirements.txt # Note: run this command in wenet_asr_client directory. 
11 | ``` 12 | 13 | ## API Usage 14 | 15 | > Before you start using *WeNet ASR Client*, you need to run the service provided by *WeNet ASR Server* on a server. (Please refer to *wenet_asr_server/README.md*.) After that, follow the instructions below to use the ASR service through the client API. 16 | 17 | Import the **Recognizer** class from *recognizer.py* and create a *recognizer* instance. 18 | 19 | ```python 20 | from recognizer import Recognizer 21 | recognizer = Recognizer(sample_rate=samprate, bit_depth=8 * sampwidth) 22 | ``` 23 | 24 | The **Recognizer** class has 4 initialization arguments. 25 | > **sample_rate**: *int*, the sample rate of the audio you are going to send.
26 | > **bit_depth**: *int*, the bit depth (8 × the sample width in bytes) of the audio you are going to send.
27 | > **update_duration**: *int*, the interval at which the partial transcript is truncated and fixed. Only takes effect in real-time calls.
28 | > **puncting_len**: *int*, the truncation length used to fix punctuation predictions. Only takes effect in real-time calls. 29 | 30 | These properties can also be configured or modified later. Only mono audio is supported; if you need to transcribe multichannel audio, convert it to mono first. 31 | 32 | Next, connect to the server. 33 | ```python 34 | recognizer.connect(IP_PORT) 35 | ``` 36 | The format of the argument IP_PORT is "{ip}:{port}", for example "192.168.81.10:50051". 37 | 38 | ### Single-time Call 39 | 40 | ```python 41 | recognizer.recognize(audio, punctuation=True) 42 | ``` 43 | 44 | The *Recognizer.recognize()* method accepts two arguments. 45 | > **audio**: *bytes*, the raw audio data.
46 | > **punctuation**: *bool*. If true, the returned text includes predicted punctuation. 47 | 48 | To read an audio file as *bytes* and get its audio parameters at the same time, you can use Python's *wave* module. 49 | ```python 50 | import wave 51 | with wave.open(AUDIO_FILE, 'rb') as f: 52 | samprate = f.getframerate() # get the sample rate of the audio 53 | sampwidth = f.getsampwidth() # get the sample width of the audio 54 | audio = f.readframes(f.getnframes()) # read the audio as bytes 55 | ``` 56 | 57 | *example_single.py* provides an example of a single-time ASR call. Refer to it if you are unsure how to use single-time calls. 58 | 59 | ### Real-time Call 60 | 61 | ```python 62 | recognizer.start_streaming() # start streaming mode 63 | 64 | while True: 65 | chunk = wait_for_data_receive() # wait for the next chunk of the audio stream 66 | 67 | recognizer.input(chunk) # feed the new data to the recognizer; it automatically runs the real-time ASR logic 68 | 69 | result = recognizer.result # get the current result 70 | print(result["text"]) 71 | ``` 72 | 73 | *example_streaming.py* provides an example of a real-time ASR call that simulates a real-time audio stream from an audio file. Refer to it if you are unsure how to use real-time calls. 74 | 75 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/cli/transcribe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Binbin Zhang (binbzha@qq.com) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
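# Typical command-line usage (a sketch only; the installed console entry point
# may differ, so running the module directly is assumed here, and the file
# names are placeholders):
#
#   python -m wenet.cli.transcribe -l chinese path/to/audio.wav
#   python -m wenet.cli.transcribe --paraformer --context_path hotwords.txt \
#       --context_score 6.0 path/to/audio.wav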
14 | 15 | import argparse 16 | 17 | from wenet.cli.paraformer_model import load_model as load_paraformer 18 | from wenet.cli.model import load_model 19 | 20 | 21 | def get_args(): 22 | parser = argparse.ArgumentParser(description='') 23 | parser.add_argument('audio_file', help='audio file to transcribe') 24 | parser.add_argument('-l', 25 | '--language', 26 | choices=[ 27 | 'chinese', 28 | 'english', 29 | ], 30 | default='chinese', 31 | help='language type') 32 | parser.add_argument('-m', 33 | '--model_dir', 34 | default=None, 35 | help='specify your own model dir') 36 | parser.add_argument('-g', 37 | '--gpu', 38 | type=int, 39 | default='-1', 40 | help='gpu id to decode, default is cpu.') 41 | parser.add_argument('-t', 42 | '--show_tokens_info', 43 | action='store_true', 44 | help='whether to output token(word) level information' 45 | ', such times/confidence') 46 | parser.add_argument('--align', 47 | action='store_true', 48 | help='force align the input audio and transcript') 49 | parser.add_argument('--label', type=str, help='the input label to align') 50 | parser.add_argument('--paraformer', 51 | action='store_true', 52 | help='whether to use the best chinese model') 53 | parser.add_argument('--beam', type=int, default=5, help="beam size") 54 | parser.add_argument('--context_path', 55 | type=str, 56 | default=None, 57 | help='context list file') 58 | parser.add_argument('--context_score', 59 | type=float, 60 | default=6.0, 61 | help='context score') 62 | args = parser.parse_args() 63 | return args 64 | 65 | 66 | def main(): 67 | args = get_args() 68 | 69 | if args.paraformer: 70 | model = load_paraformer(args.model_dir, args.gpu) 71 | else: 72 | model = load_model(args.language, args.model_dir, args.gpu, args.beam, 73 | args.context_path, args.context_score) 74 | if args.align: 75 | result = model.align(args.audio_file, args.label) 76 | else: 77 | result = model.transcribe(args.audio_file, args.show_tokens_info) 78 | print(result) 79 | 80 | 81 | if __name__ == "__main__": 82 | main() 83 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/cli/paraformer_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torchaudio 5 | import torchaudio.compliance.kaldi as kaldi 6 | 7 | from wenet.cli.hub import Hub 8 | from wenet.paraformer.search import (gen_timestamps_from_peak, 9 | paraformer_greedy_search) 10 | from wenet.text.paraformer_tokenizer import ParaformerTokenizer 11 | 12 | 13 | class Paraformer: 14 | 15 | def __init__(self, 16 | model_dir: str, 17 | device: int = -1, 18 | resample_rate: int = 16000) -> None: 19 | 20 | model_path = os.path.join(model_dir, 'final.zip') 21 | units_path = os.path.join(model_dir, 'units.txt') 22 | self.model = torch.jit.load(model_path) 23 | self.resample_rate = resample_rate 24 | if device >= 0: 25 | device = 'cuda:{}'.format(device) 26 | else: 27 | device = 'cpu' 28 | self.device = torch.device(device) 29 | self.model = self.model.to(self.device) 30 | self.tokenizer = ParaformerTokenizer(symbol_table=units_path) 31 | 32 | def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict: 33 | waveform, sample_rate = torchaudio.load(audio_file, normalize=False) 34 | waveform = waveform.to(torch.float).to(self.device) 35 | if sample_rate != self.resample_rate: 36 | waveform = torchaudio.transforms.Resample( 37 | orig_freq=sample_rate, new_freq=self.resample_rate)(waveform) 38 | feats = kaldi.fbank(waveform, 39 | 
num_mel_bins=80, 40 | frame_length=25, 41 | frame_shift=10, 42 | energy_floor=0.0, 43 | sample_frequency=self.resample_rate) 44 | feats = feats.unsqueeze(0) 45 | feats_lens = torch.tensor([feats.size(1)], 46 | dtype=torch.int64, 47 | device=feats.device) 48 | 49 | decoder_out, token_num, tp_alphas = self.model.forward_paraformer( 50 | feats, feats_lens) 51 | cif_peaks = self.model.forward_cif_peaks(tp_alphas, token_num) 52 | res = paraformer_greedy_search(decoder_out, token_num, cif_peaks)[0] 53 | result = {} 54 | result['confidence'] = res.confidence 55 | result['text'] = self.tokenizer.detokenize(res.tokens)[0] 56 | if tokens_info: 57 | tokens_info = [] 58 | times = gen_timestamps_from_peak(res.times, 59 | num_frames=tp_alphas.size(1), 60 | frame_rate=0.02) 61 | 62 | for i, x in enumerate(res.tokens): 63 | tokens_info.append({ 64 | 'token': self.tokenizer.char_dict[x], 65 | 'start': times[i][0], 66 | 'end': times[i][1], 67 | 'confidence': res.tokens_confidence[i] 68 | }) 69 | result['tokens'] = tokens_info 70 | 71 | return result 72 | 73 | def align(self, audio_file: str, label: str) -> dict: 74 | raise NotImplementedError("Align is currently not supported") 75 | 76 | 77 | def load_model(model_dir: str = None, gpu: int = -1) -> Paraformer: 78 | if model_dir is None: 79 | model_dir = Hub.get_model_by_lang('paraformer') 80 | return Paraformer(model_dir, gpu) 81 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/utils/class_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # Copyright [2023-11-28] 4 | import torch 5 | from torch.nn import BatchNorm1d, LayerNorm 6 | from wenet.paraformer.embedding import ParaformerPositinoalEncoding 7 | from wenet.transformer.norm import RMSNorm 8 | from wenet.transformer.positionwise_feed_forward import ( 9 | GatedVariantsMLP, MoEFFNLayer, PositionwiseFeedForward) 10 | 11 | from wenet.transformer.swish import Swish 12 | from wenet.transformer.subsampling import ( 13 | LinearNoSubsampling, 14 | EmbedinigNoSubsampling, 15 | Conv1dSubsampling2, 16 | Conv2dSubsampling4, 17 | Conv2dSubsampling6, 18 | Conv2dSubsampling8, 19 | StackNFramesSubsampling, 20 | ) 21 | from wenet.efficient_conformer.subsampling import Conv2dSubsampling2 22 | from wenet.squeezeformer.subsampling import DepthwiseConv2dSubsampling4 23 | from wenet.transformer.embedding import (PositionalEncoding, 24 | RelPositionalEncoding, 25 | WhisperPositionalEncoding, 26 | LearnablePositionalEncoding, 27 | NoPositionalEncoding) 28 | from wenet.transformer.attention import (MultiHeadedAttention, 29 | MultiHeadedCrossAttention, 30 | RelPositionMultiHeadedAttention, 31 | ShawRelPositionMultiHeadedAttention) 32 | from wenet.efficient_conformer.attention import ( 33 | GroupedRelPositionMultiHeadedAttention) 34 | 35 | WENET_ACTIVATION_CLASSES = { 36 | "hardtanh": torch.nn.Hardtanh, 37 | "tanh": torch.nn.Tanh, 38 | "relu": torch.nn.ReLU, 39 | "selu": torch.nn.SELU, 40 | "swish": getattr(torch.nn, "SiLU", Swish), 41 | "gelu": torch.nn.GELU, 42 | } 43 | 44 | WENET_RNN_CLASSES = { 45 | "rnn": torch.nn.RNN, 46 | "lstm": torch.nn.LSTM, 47 | "gru": torch.nn.GRU, 48 | } 49 | 50 | WENET_SUBSAMPLE_CLASSES = { 51 | "linear": LinearNoSubsampling, 52 | "embed": EmbedinigNoSubsampling, 53 | "conv1d2": Conv1dSubsampling2, 54 | "conv2d2": Conv2dSubsampling2, 55 | "conv2d": Conv2dSubsampling4, 56 | "dwconv2d4": DepthwiseConv2dSubsampling4, 57 | "conv2d6": 
Conv2dSubsampling6, 58 | "conv2d8": Conv2dSubsampling8, 59 | 'paraformer_dummy': torch.nn.Identity, 60 | 'stack_n_frames': StackNFramesSubsampling, 61 | } 62 | 63 | WENET_EMB_CLASSES = { 64 | "embed": PositionalEncoding, 65 | "abs_pos": PositionalEncoding, 66 | "rel_pos": RelPositionalEncoding, 67 | "no_pos": NoPositionalEncoding, 68 | "abs_pos_whisper": WhisperPositionalEncoding, 69 | "embed_learnable_pe": LearnablePositionalEncoding, 70 | "abs_pos_paraformer": ParaformerPositinoalEncoding, 71 | } 72 | 73 | WENET_ATTENTION_CLASSES = { 74 | "selfattn": MultiHeadedAttention, 75 | "rel_selfattn": RelPositionMultiHeadedAttention, 76 | "grouped_rel_selfattn": GroupedRelPositionMultiHeadedAttention, 77 | "crossattn": MultiHeadedCrossAttention, 78 | 'shaw_rel_selfattn': ShawRelPositionMultiHeadedAttention 79 | } 80 | 81 | WENET_MLP_CLASSES = { 82 | 'position_wise_feed_forward': PositionwiseFeedForward, 83 | 'moe': MoEFFNLayer, 84 | 'gated': GatedVariantsMLP 85 | } 86 | 87 | WENET_NORM_CLASSES = { 88 | 'layer_norm': LayerNorm, 89 | 'batch_norm': BatchNorm1d, 90 | 'rms_norm': RMSNorm 91 | } 92 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/transformer/ctc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Modified from ESPnet(https://github.com/espnet/espnet) 15 | 16 | from typing import Tuple 17 | 18 | import torch 19 | import torch.nn.functional as F 20 | 21 | 22 | class CTC(torch.nn.Module): 23 | """CTC module""" 24 | 25 | def __init__( 26 | self, 27 | odim: int, 28 | encoder_output_size: int, 29 | dropout_rate: float = 0.0, 30 | reduce: bool = True, 31 | blank_id: int = 0, 32 | ): 33 | """ Construct CTC module 34 | Args: 35 | odim: dimension of outputs 36 | encoder_output_size: number of encoder projection units 37 | dropout_rate: dropout rate (0.0 ~ 1.0) 38 | reduce: reduce the CTC loss into a scalar 39 | blank_id: blank label. 40 | """ 41 | super().__init__() 42 | eprojs = encoder_output_size 43 | self.dropout_rate = dropout_rate 44 | self.ctc_lo = torch.nn.Linear(eprojs, odim) 45 | 46 | reduction_type = "sum" if reduce else "none" 47 | self.ctc_loss = torch.nn.CTCLoss(blank=blank_id, 48 | reduction=reduction_type, 49 | zero_infinity=True) 50 | 51 | def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor, 52 | ys_pad: torch.Tensor, 53 | ys_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 54 | """Calculate CTC loss. 
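        Returns a tuple (loss, ys_hat): the CTC loss (averaged over the batch
        when ``reduce=True``) and the per-frame log-probabilities of shape
        (B, Tmax, odim).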
55 | 56 | Args: 57 | hs_pad: batch of padded hidden state sequences (B, Tmax, D) 58 | hlens: batch of lengths of hidden state sequences (B) 59 | ys_pad: batch of padded character id sequence tensor (B, Lmax) 60 | ys_lens: batch of lengths of character sequence (B) 61 | """ 62 | # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab) 63 | ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate)) 64 | # ys_hat: (B, L, D) -> (L, B, D) 65 | ys_hat = ys_hat.transpose(0, 1) 66 | ys_hat = ys_hat.log_softmax(2) 67 | loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens) 68 | # Batch-size average 69 | loss = loss / ys_hat.size(1) 70 | ys_hat = ys_hat.transpose(0, 1) 71 | return loss, ys_hat 72 | 73 | def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor: 74 | """log_softmax of frame activations 75 | 76 | Args: 77 | Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 78 | Returns: 79 | torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim) 80 | """ 81 | return F.log_softmax(self.ctc_lo(hs_pad), dim=2) 82 | 83 | def argmax(self, hs_pad: torch.Tensor) -> torch.Tensor: 84 | """argmax of frame activations 85 | 86 | Args: 87 | torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 88 | Returns: 89 | torch.Tensor: argmax applied 2d tensor (B, Tmax) 90 | """ 91 | return torch.argmax(self.ctc_lo(hs_pad), dim=2) 92 | -------------------------------------------------------------------------------- /wenet_asr_client/example_streaming.py: -------------------------------------------------------------------------------- 1 | import wave 2 | import time 3 | from datetime import datetime 4 | 5 | from recognizer import Recognizer 6 | 7 | 8 | SERVER_IP = 'localhost' # 服务器IP地址 9 | SERVER_PORT = 50051 # 服务器端口号 10 | 11 | """ 12 | CHUNK: 模拟实时音频流时,每次读取的音频帧大小 13 | 用于控制一次接收的音频流长度,影响接收音频流间隔,从而控制识别结果实时更新速率 14 | CHUNK越小实时更新率越高,但应保证实时更新时间间隔大于一次性识别请求的时长 15 | """ 16 | CHUNK = 1024 * 16 17 | 18 | AUDIO_FILE = 'audio/082311171430575628101.wav' 19 | 20 | 21 | def audio_stream_simulate(file_path, chunk): 22 | ''' 23 | This function is used for simulating real-time audio stream from an audio file. 
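    Returns a tuple (datastream, samprate, sampwidth, duration): the list of
    byte chunks read from the file, its sample rate, its sample width in bytes,
    and the playback duration of one chunk in seconds.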
24 | ''' 25 | # 按指定CHUNK大小切割读取.wav文件,模拟实时数据流 26 | print('reading audio...') 27 | datastream = [] 28 | with wave.open(file_path, 'rb') as f: 29 | samprate = f.getframerate() # 使用.wav文件模拟数据流时,数据采样率为.wav文件的采样率;实际应用时应使用实际数据流的采样率(8k) 30 | channels = f.getnchannels() 31 | sampwidth = f.getsampwidth() 32 | duration = 0 33 | 34 | first = True 35 | while True: 36 | dataframes = f.readframes(chunk) 37 | if not dataframes: 38 | break 39 | 40 | if first: 41 | first = False 42 | num_frames = len(dataframes) / channels / sampwidth 43 | duration = num_frames / samprate # 每个CHUNK对应的音频时长(s) 44 | 45 | datastream.append(dataframes) 46 | print(f'Sample rate: {samprate}') 47 | print(f'Channels: {channels}') 48 | print(f'Sample width: {sampwidth}') 49 | print(f'CHUNK duration: {duration}(s)') 50 | return datastream, samprate, sampwidth, duration 51 | 52 | 53 | def streaming_test(): 54 | # 使用音频文件模拟实时音频流 55 | datastream, samprate, sampwidth, duration = audio_stream_simulate(AUDIO_FILE, CHUNK) 56 | 57 | # 初始化一个Recognizer对象,需指定待识别音频的采样率和位深度。仅支持单声道,多声道音频需提前转换为单声道。 58 | recognizer = Recognizer(sample_rate=samprate, 59 | bit_depth=8 * sampwidth, 60 | update_duration=8) 61 | 62 | # 连接服务器 63 | print('connecting...') 64 | recognizer.connect(f'{SERVER_IP}:{SERVER_PORT}') 65 | 66 | # recognizer.reload_model(context_score=6) 67 | 68 | # 开启流式识别模式 69 | recognizer.start_streaming() 70 | 71 | # 流式识别模式下,按CHUNK持续输入数据流至Recognizer对象,每次输入后自动执行流式识别逻辑 72 | print('recognizing...') 73 | ini_time = datetime.now() # 记录开始时间 74 | time.sleep(duration) # 模拟实时数据流,等待第一个CHUNK的时长 75 | 76 | for data in datastream: 77 | chunkStart = time.time() # 记录该CHUNK输入时的时间 78 | 79 | # 输入新音频数据,自动执行实时识别逻辑 80 | recognizer.input(data) # 输入一个CHUNK的数据,并自动执行流式识别逻辑,包括音频合并、转写、截断固定、标点预测等 81 | 82 | # 获取当前结果 83 | result = recognizer.result # 获取当前已获取部分音频流的总转写结果 84 | fixed, puncted, new = result['fixed'], result['puncted'], result['new'] 85 | print(f"[{datetime.now()-ini_time}] [{fixed}]{puncted}{new}") # 打印时间和结果(格式:[总时间] [标点固定部分]标点可变部分+文本可变部分) 86 | # print(f"[{datetime.now()-ini_time}] [{fixed}]{new}") # (无标点情况下)打印时间和结果(格式:[总时间] [文本固定部分]文本可变部分) 87 | 88 | # 模拟实时数据流,延迟一个CHUNK音频的实际时长 89 | run_delay = time.time() - chunkStart 90 | print('run_delay:', run_delay) 91 | if run_delay > duration: 92 | print(f'[WARNING]: Single recognition request duration ({run_delay}) is longer than CHUNK duration ({duration}). (RTF>1.) 
Try to set CHUNK larger to avoid this.') 93 | else: 94 | time.sleep(duration - run_delay) 95 | 96 | # 断开链接 97 | recognizer.disconnect() 98 | 99 | 100 | if __name__ == '__main__': 101 | streaming_test() 102 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/bin/export_ipex.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023 Intel Corporation 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from __future__ import print_function 5 | 6 | import argparse 7 | import logging 8 | import os 9 | 10 | import torch 11 | import yaml 12 | 13 | from wenet.utils.init_model import init_model 14 | import intel_extension_for_pytorch as ipex 15 | from intel_extension_for_pytorch.quantization import prepare, convert 16 | 17 | 18 | def get_args(): 19 | parser = argparse.ArgumentParser(description='export your script model') 20 | parser.add_argument('--config', required=True, help='config file') 21 | parser.add_argument('--checkpoint', required=True, help='checkpoint model') 22 | parser.add_argument('--output_file', default=None, help='output file') 23 | parser.add_argument('--dtype', 24 | default="fp32", 25 | help='choose the dtype to run:[fp32,bf16]') 26 | parser.add_argument('--output_quant_file', 27 | default=None, 28 | help='output quantized model file') 29 | args = parser.parse_args() 30 | return args 31 | 32 | 33 | def scripting(model): 34 | with torch.inference_mode(): 35 | script_model = torch.jit.script(model) 36 | script_model = torch.jit.freeze( 37 | script_model, 38 | preserved_attrs=[ 39 | "forward_encoder_chunk", "ctc_activation", 40 | "forward_attention_decoder", "subsampling_rate", 41 | "right_context", "sos_symbol", "eos_symbol", 42 | "is_bidirectional_decoder" 43 | ]) 44 | return script_model 45 | 46 | 47 | def main(): 48 | args = get_args() 49 | logging.basicConfig(level=logging.DEBUG, 50 | format='%(asctime)s %(levelname)s %(message)s') 51 | # No need gpu for model export 52 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 53 | 54 | with open(args.config, 'r') as fin: 55 | configs = yaml.load(fin, Loader=yaml.FullLoader) 56 | model, configs = init_model(args, configs) 57 | print(model) 58 | 59 | # Apply IPEX optimization 60 | model.eval() 61 | torch._C._jit_set_texpr_fuser_enabled(False) 62 | model.to(memory_format=torch.channels_last) 63 | if args.dtype == "fp32": 64 | ipex_model = ipex.optimize(model) 65 | elif args.dtype == "bf16": # For Intel 4th generation Xeon (SPR) 66 | ipex_model = ipex.optimize(model, 67 | dtype=torch.bfloat16, 68 | weights_prepack=False) 69 | 70 | # Export jit torch script model 71 | if args.output_file: 72 | if args.dtype == "fp32": 73 | script_model = scripting(ipex_model) 74 | elif args.dtype == "bf16": 75 | torch._C._jit_set_autocast_mode(True) 76 | with torch.cpu.amp.autocast(): 77 | script_model = scripting(ipex_model) 78 | script_model.save(args.output_file) 79 | print('Export model successfully, see {}'.format(args.output_file)) 80 | 81 | # Export quantized jit torch script model 82 | if args.output_quant_file: 83 | dynamic_qconfig = ipex.quantization.default_dynamic_qconfig 84 | dummy_data = (torch.zeros(1, 67, 80), 16, -16, 85 | torch.zeros(12, 4, 32, 128), torch.zeros(12, 1, 256, 7)) 86 | model = prepare(model, dynamic_qconfig, dummy_data) 87 | model = convert(model) 88 | script_quant_model = scripting(model) 89 | script_quant_model.save(args.output_quant_file) 90 | print('Export quantized model successfully, ' 91 | 'see 
{}'.format(args.output_quant_file)) 92 | 93 | 94 | if __name__ == '__main__': 95 | main() 96 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/whisper/whisper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Wenet Community. (authors: Xingchen Song) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Modified from [Whisper](https://github.com/openai/whisper) 16 | 17 | import torch 18 | 19 | from typing import Tuple, Dict, List 20 | 21 | from wenet.transformer.asr_model import ASRModel 22 | from wenet.transformer.ctc import CTC 23 | from wenet.transformer.encoder import TransformerEncoder 24 | from wenet.transformer.decoder import TransformerDecoder 25 | from wenet.utils.common import IGNORE_ID, add_whisper_tokens, th_accuracy 26 | 27 | 28 | class Whisper(ASRModel): 29 | 30 | def __init__( 31 | self, 32 | vocab_size: int, 33 | encoder: TransformerEncoder, 34 | decoder: TransformerDecoder, 35 | ctc: CTC = None, 36 | ctc_weight: float = 0.5, 37 | ignore_id: int = IGNORE_ID, 38 | reverse_weight: float = 0.0, 39 | lsm_weight: float = 0.0, 40 | length_normalized_loss: bool = False, 41 | special_tokens: dict = None, 42 | ): 43 | super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight, 44 | ignore_id, reverse_weight, lsm_weight, 45 | length_normalized_loss, special_tokens) 46 | assert reverse_weight == 0.0 47 | self.sos = special_tokens["sot"] 48 | self.eos = special_tokens["eot"] 49 | self.decode_maxlen = self.decoder.embed[1].max_len 50 | 51 | # TODO(xcsong): time align 52 | def set_alignment_heads(self, dump: bytes): 53 | raise NotImplementedError 54 | 55 | @property 56 | def is_multilingual(self): 57 | return self.vocab_size >= 51865 58 | 59 | @property 60 | def num_languages(self): 61 | return self.vocab_size - 51765 - int(self.is_multilingual) 62 | 63 | def _calc_att_loss( 64 | self, 65 | encoder_out: torch.Tensor, 66 | encoder_mask: torch.Tensor, 67 | ys_pad: torch.Tensor, 68 | ys_pad_lens: torch.Tensor, 69 | infos: Dict[str, List[str]], 70 | ) -> Tuple[torch.Tensor, float]: 71 | prev_len = ys_pad.size(1) 72 | ys_in_pad, ys_out_pad = add_whisper_tokens(self.special_tokens, 73 | ys_pad, 74 | self.ignore_id, 75 | tasks=infos['tasks'], 76 | no_timestamp=True, 77 | langs=infos['langs'], 78 | use_prev=False) 79 | cur_len = ys_in_pad.size(1) 80 | ys_in_lens = ys_pad_lens + cur_len - prev_len 81 | 82 | # 1. Forward decoder 83 | decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask, 84 | ys_in_pad, ys_in_lens) 85 | 86 | # 2. 
Compute attention loss 87 | loss_att = self.criterion_att(decoder_out, ys_out_pad) 88 | acc_att = th_accuracy( 89 | decoder_out.view(-1, self.vocab_size), 90 | ys_out_pad, 91 | ignore_label=self.ignore_id, 92 | ) 93 | return loss_att, acc_att 94 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/transformer/label_smoothing_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Label smoothing module.""" 16 | 17 | import torch 18 | from torch import nn 19 | 20 | 21 | class LabelSmoothingLoss(nn.Module): 22 | """Label-smoothing loss. 23 | 24 | In a standard CE loss, the label's data distribution is: 25 | [0,1,2] -> 26 | [ 27 | [1.0, 0.0, 0.0], 28 | [0.0, 1.0, 0.0], 29 | [0.0, 0.0, 1.0], 30 | ] 31 | 32 | In the smoothing version CE Loss,some probabilities 33 | are taken from the true label prob (1.0) and are divided 34 | among other labels. 35 | 36 | e.g. 37 | smoothing=0.1 38 | [0,1,2] -> 39 | [ 40 | [0.9, 0.05, 0.05], 41 | [0.05, 0.9, 0.05], 42 | [0.05, 0.05, 0.9], 43 | ] 44 | 45 | Args: 46 | size (int): the number of class 47 | padding_idx (int): padding class id which will be ignored for loss 48 | smoothing (float): smoothing rate (0.0 means the conventional CE) 49 | normalize_length (bool): 50 | normalize loss by sequence length if True 51 | normalize loss by batch size if False 52 | """ 53 | 54 | def __init__(self, 55 | size: int, 56 | padding_idx: int, 57 | smoothing: float, 58 | normalize_length: bool = False): 59 | """Construct an LabelSmoothingLoss object.""" 60 | super(LabelSmoothingLoss, self).__init__() 61 | self.criterion = nn.KLDivLoss(reduction="none") 62 | self.padding_idx = padding_idx 63 | self.confidence = 1.0 - smoothing 64 | self.smoothing = smoothing 65 | self.size = size 66 | self.normalize_length = normalize_length 67 | 68 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 69 | """Compute loss between x and target. 70 | 71 | The model outputs and data labels tensors are flatten to 72 | (batch*seqlen, class) shape and a mask is applied to the 73 | padding part which should not be calculated for loss. 
74 | 75 | Args: 76 | x (torch.Tensor): prediction (batch, seqlen, class) 77 | target (torch.Tensor): 78 | target signal masked with self.padding_id (batch, seqlen) 79 | Returns: 80 | loss (torch.Tensor) : The KL loss, scalar float value 81 | """ 82 | assert x.size(2) == self.size 83 | batch_size = x.size(0) 84 | x = x.view(-1, self.size) 85 | target = target.view(-1) 86 | # use zeros_like instead of torch.no_grad() for true_dist, 87 | # since no_grad() can not be exported by JIT 88 | true_dist = torch.zeros_like(x) 89 | true_dist.fill_(self.smoothing / (self.size - 1)) 90 | ignore = target == self.padding_idx # (B,) 91 | total = len(target) - ignore.sum().item() 92 | target = target.masked_fill(ignore, 0) # avoid -1 index 93 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence) 94 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) 95 | denom = total if self.normalize_length else batch_size 96 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 97 | -------------------------------------------------------------------------------- /wenet_asr_server/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | aiohttp==3.9.3 3 | aiosignal==1.3.1 4 | aistudio-sdk==0.1.7 5 | annotated-types==0.6.0 6 | anyio==4.3.0 7 | astor==0.8.1 8 | async-timeout==4.0.3 9 | attrs==23.2.0 10 | audioread==3.0.1 11 | av==12.0.0 12 | babel==2.14.0 13 | bce-python-sdk==0.9.6 14 | blinker==1.7.0 15 | certifi==2024.2.2 16 | cffi==1.16.0 17 | cfgv==3.4.0 18 | charset-normalizer==3.3.2 19 | clang-format==17.0.6 20 | click==8.1.7 21 | colorama==0.4.6 22 | coloredlogs==15.0.1 23 | colorlog==6.8.2 24 | contourpy==1.2.1 25 | cpplint==1.6.1 26 | cycler==0.12.1 27 | datasets==2.18.0 28 | deepspeed==0.12.6 29 | dill==0.3.4 30 | distlib==0.3.8 31 | fastapi==0.110.1 32 | filelock==3.13.3 33 | flake8==3.8.2 34 | flake8-bugbear==23.3.12 35 | flake8-comprehensions==3.14.0 36 | flake8-executable==2.1.3 37 | flake8-pyi==20.5.0 38 | flask==3.0.3 39 | flask-babel==4.0.0 40 | flatbuffers==24.3.25 41 | fonttools==4.51.0 42 | frozenlist==1.4.1 43 | fsspec==2024.2.0 44 | future==1.0.0 45 | grpcio 46 | grpcio-tools 47 | h11==0.14.0 48 | hjson==3.1.0 49 | httpcore==1.0.5 50 | httpx==0.27.0 51 | huggingface-hub==0.22.2 52 | humanfriendly==10.0 53 | identify==2.5.35 54 | idna==3.6 55 | ijson==3.2.3 56 | iniconfig==2.0.0 57 | itsdangerous==2.1.2 58 | jieba==0.42.1 59 | jinja2==3.1.3 60 | joblib==1.3.2 61 | kaldiio==2.18.0 62 | kiwisolver==1.4.5 63 | langid==1.1.6 64 | lazy_loader==0.3 65 | librosa==0.8.1 66 | llvmlite==0.42.0 67 | markdown==3.6 68 | markupsafe==2.1.5 69 | matplotlib==3.8.4 70 | mccabe==0.6.1 71 | mdurl==0.1.2 72 | more-itertools==10.2.0 73 | mpmath==1.3.0 74 | msgpack==1.0.8 75 | multidict==6.0.5 76 | multiprocess==0.70.12.2 77 | networkx==3.2.1 78 | ninja==1.11.1.1 79 | nodeenv==1.8.0 80 | numba==0.59.1 81 | numpy==1.26.4 82 | nvidia-cublas-cu12==12.1.3.1 83 | nvidia-cuda-cupti-cu12==12.1.105 84 | nvidia-cuda-nvrtc-cu12==12.1.105 85 | nvidia-cuda-runtime-cu12==12.1.105 86 | nvidia-cudnn-cu12==8.9.2.26 87 | nvidia-cufft-cu12==11.0.2.54 88 | nvidia-curand-cu12==10.3.2.106 89 | nvidia-cusolver-cu12==11.4.5.107 90 | nvidia-cusparse-cu12==12.1.0.106 91 | nvidia-nccl-cu12==2.19.3 92 | nvidia-nvjitlink-cu12==12.4.99 93 | nvidia-nvtx-cu12==12.1.105 94 | onnx==1.16.0 95 | onnxruntime==1.17.1 96 | openai-whisper==20231117 97 | opt-einsum==3.3.0 98 | paddle2onnx==1.1.0 99 | paddleaudio==1.1.0 100 | paddlefsl==1.1.0 101 | paddlenlp==2.7.2 102 | 
paddlepaddle==2.6.1 103 | pandas==2.2.1 104 | parameterized==0.9.0 105 | pathos==0.2.8 106 | pillow==10.2.0 107 | platformdirs==4.2.0 108 | pluggy==1.4.0 109 | pooch==1.8.1 110 | pox==0.3.4 111 | ppasr==2.4.6 112 | ppft==1.7.6.8 113 | pre-commit==3.5.0 114 | protobuf==4.25.3 115 | py-cpuinfo==9.0.0 116 | pyarrow==15.0.2 117 | pyarrow-hotfix==0.6 118 | pyaudio==0.2.11 119 | pybind11==2.12.0 120 | pycodestyle==2.6.0 121 | pycparser==2.21 122 | pycryptodome==3.20.0 123 | pydantic==2.6.4 124 | pydantic_core==2.16.3 125 | pyflakes==2.2.0 126 | pynvml==11.5.0 127 | pyparsing==3.1.2 128 | pytest==8.1.1 129 | python-Levenshtein==0.12.2 130 | pytz==2024.1 131 | pyyaml==6.0.1 132 | rarfile==4.2 133 | regex==2023.12.25 134 | requests==2.31.0 135 | resampy==0.4.3 136 | rich==13.7.1 137 | safetensors==0.4.2 138 | scikit-learn==1.4.1.post1 139 | scipy==1.12.0 140 | sentencepiece==0.2.0 141 | seqeval==1.2.2 142 | shellingham==1.5.4 143 | sniffio==1.3.1 144 | soundfile==0.12.1 145 | soxr==0.3.7 146 | starlette==0.37.2 147 | sympy==1.12 148 | tensorboard==2.16.2 149 | tensorboard-data-server==0.7.2 150 | tensorboardX==2.6.2.2 151 | termcolor==2.4.0 152 | textgrid==1.6.1 153 | threadpoolctl==3.4.0 154 | tiktoken==0.6.0 155 | tomli==2.0.1 156 | tool-helpers==0.1.1 157 | torch==2.2.2 158 | torchaudio==2.2.2 159 | tqdm==4.66.2 160 | triton==2.2.0 161 | typeguard==2.13.3 162 | typer==0.12.1 163 | typing_extensions==4.10.0 164 | tzdata==2024.1 165 | urllib3==2.2.1 166 | uvicorn==0.29.0 167 | virtualenv==20.25.1 168 | visualdl==2.5.3 169 | xxhash==3.4.1 170 | yarl==1.9.4 171 | zhconv==1.4.3 172 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/text/whisper_tokenizer.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from typing import Dict, List, Optional, Tuple, Union 3 | from wenet.text.base_tokenizer import BaseTokenizer 4 | 5 | from wenet.utils.file_utils import read_non_lang_symbols 6 | 7 | 8 | class WhisperTokenizer(BaseTokenizer): 9 | 10 | def __init__( 11 | self, 12 | multilingual: bool, 13 | num_languages: int = 99, 14 | language: Optional[str] = None, 15 | task: Optional[str] = None, 16 | non_lang_syms: Optional[Union[str, PathLike, List]] = None, 17 | *args, 18 | **kwargs, 19 | ) -> None: 20 | # NOTE(Mddct): don't build here, pickle issues 21 | self.tokenizer = None 22 | # TODO: we don't need this in future 23 | self.multilingual = multilingual 24 | self.num_languages = num_languages 25 | self.language = language 26 | self.task = task 27 | 28 | if not isinstance(non_lang_syms, List): 29 | self.non_lang_syms = read_non_lang_symbols(non_lang_syms) 30 | else: 31 | # non_lang_syms=["{NOISE}"] 32 | self.non_lang_syms = non_lang_syms 33 | # TODO(Mddct): add special tokens, like non_lang_syms 34 | del self.non_lang_syms 35 | 36 | def __getstate__(self): 37 | state = self.__dict__.copy() 38 | del state['tokenizer'] 39 | return state 40 | 41 | def __setstate__(self, state): 42 | self.__dict__.update(state) 43 | recovery = {'tokenizer': None} 44 | self.__dict__.update(recovery) 45 | 46 | def _build_tiktoken(self): 47 | if self.tokenizer is None: 48 | from whisper.tokenizer import get_tokenizer 49 | self.tokenizer = get_tokenizer(multilingual=self.multilingual, 50 | num_languages=self.num_languages, 51 | language=self.language, 52 | task=self.task) 53 | self.t2i = {} 54 | self.i2t = {} 55 | for i in range(self.tokenizer.encoding.n_vocab): 56 | unit = str( 57 | 
self.tokenizer.encoding.decode_single_token_bytes(i)) 58 | if len(unit) == 0: 59 | unit = str(i) 60 | unit = unit.replace(" ", "") 61 | # unit = bytes(unit, 'utf-8') 62 | self.t2i[unit] = i 63 | self.i2t[i] = unit 64 | assert len(self.t2i) == len(self.i2t) 65 | 66 | def tokenize(self, line: str) -> Tuple[List[str], List[int]]: 67 | self._build_tiktoken() 68 | ids = self.tokenizer.encoding.encode(line) 69 | text = [self.i2t[d] for d in ids] 70 | return text, ids 71 | 72 | def detokenize(self, ids: List[int]) -> Tuple[str, List[str]]: 73 | self._build_tiktoken() 74 | tokens = [self.i2t[d] for d in ids] 75 | text = self.tokenizer.encoding.decode(ids) 76 | return text, tokens 77 | 78 | def text2tokens(self, line: str) -> List[str]: 79 | self._build_tiktoken() 80 | return self.tokenize(line)[0] 81 | 82 | def tokens2text(self, tokens: List[str]) -> str: 83 | self._build_tiktoken() 84 | ids = [self.t2i[t] for t in tokens] 85 | return self.detokenize(ids)[0] 86 | 87 | def tokens2ids(self, tokens: List[str]) -> List[int]: 88 | self._build_tiktoken() 89 | ids = [self.t2i[t] for t in tokens] 90 | return ids 91 | 92 | def ids2tokens(self, ids: List[int]) -> List[str]: 93 | self._build_tiktoken() 94 | return [self.tokenizer.encoding.decode([id]) for id in ids] 95 | 96 | def vocab_size(self) -> int: 97 | self._build_tiktoken() 98 | return len(self.t2i) 99 | 100 | @property 101 | def symbol_table(self) -> Dict[str, int]: 102 | self._build_tiktoken() 103 | return self.t2i 104 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/cli/hub.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Mddct(hamddct@gmail.com) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
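# `Hub` (defined below) downloads and caches WeNet's pretrained runtime models.
# A minimal usage sketch, assuming network access to the modelscope mirror:
#
#   from wenet.cli.hub import Hub
#   model_dir = Hub.get_model_by_lang("chinese")  # cached under ~/.wenet/chinese
#
# Supported keys are the entries of `Hub.Assets` ("chinese", "english",
# "paraformer").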
14 | 15 | import os 16 | import requests 17 | import sys 18 | import tarfile 19 | from pathlib import Path 20 | from urllib.request import urlretrieve 21 | 22 | import tqdm 23 | 24 | 25 | def download(url: str, dest: str, only_child=True): 26 | """ download from url to dest 27 | """ 28 | assert os.path.exists(dest) 29 | print('Downloading {} to {}'.format(url, dest)) 30 | 31 | def progress_hook(t): 32 | last_b = [0] 33 | 34 | def update_to(b=1, bsize=1, tsize=None): 35 | if tsize not in (None, -1): 36 | t.total = tsize 37 | displayed = t.update((b - last_b[0]) * bsize) 38 | last_b[0] = b 39 | return displayed 40 | 41 | return update_to 42 | 43 | # *.tar.gz 44 | name = url.split('?')[0].split('/')[-1] 45 | tar_path = os.path.join(dest, name) 46 | with tqdm.tqdm(unit='B', 47 | unit_scale=True, 48 | unit_divisor=1024, 49 | miniters=1, 50 | desc=(name)) as t: 51 | urlretrieve(url, 52 | filename=tar_path, 53 | reporthook=progress_hook(t), 54 | data=None) 55 | t.total = t.n 56 | 57 | with tarfile.open(tar_path) as f: 58 | if not only_child: 59 | f.extractall(dest) 60 | else: 61 | for tarinfo in f: 62 | if "/" not in tarinfo.name: 63 | continue 64 | name = os.path.basename(tarinfo.name) 65 | fileobj = f.extractfile(tarinfo) 66 | with open(os.path.join(dest, name), "wb") as writer: 67 | writer.write(fileobj.read()) 68 | 69 | 70 | class Hub(object): 71 | """Hub for wenet pretrain runtime model 72 | """ 73 | # TODO(Mddct): make assets class to support other language 74 | Assets = { 75 | # wenetspeech 76 | "chinese": "wenetspeech_u2pp_conformer_libtorch.tar.gz", 77 | # gigaspeech 78 | "english": "gigaspeech_u2pp_conformer_libtorch.tar.gz", 79 | # paraformer 80 | "paraformer": "paraformer.tar.gz" 81 | } 82 | 83 | def __init__(self) -> None: 84 | pass 85 | 86 | @staticmethod 87 | def get_model_by_lang(lang: str) -> str: 88 | if lang not in Hub.Assets.keys(): 89 | print('ERROR: Unsupported language {} !!!'.format(lang)) 90 | sys.exit(1) 91 | 92 | # NOTE(Mddct): model_dir structure 93 | # Path.Home()/.wenet 94 | # - chs 95 | # - units.txt 96 | # - final.zip 97 | # - en 98 | # - units.txt 99 | # - final.zip 100 | model = Hub.Assets[lang] 101 | model_dir = os.path.join(Path.home(), ".wenet", lang) 102 | if not os.path.exists(model_dir): 103 | os.makedirs(model_dir) 104 | # TODO(Mddct): model metadata 105 | if set(["final.zip", 106 | "units.txt"]).issubset(set(os.listdir(model_dir))): 107 | return model_dir 108 | # If not exist, download 109 | response = requests.get( 110 | "https://modelscope.cn/api/v1/datasets/wenet/wenet_pretrained_models/oss/tree" # noqa 111 | ) 112 | model_info = next(data for data in response.json()["Data"] 113 | if data["Key"] == model) 114 | model_url = model_info['Url'] 115 | download(model_url, model_dir, only_child=True) 116 | return model_dir 117 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import os 17 | import re 18 | 19 | import yaml 20 | import torch 21 | from collections import OrderedDict 22 | 23 | import datetime 24 | 25 | 26 | def load_checkpoint(model: torch.nn.Module, path: str) -> dict: 27 | logging.info('Checkpoint: loading from checkpoint %s' % path) 28 | checkpoint = torch.load(path, map_location='cpu') 29 | missing_keys, unexpected_keys = model.load_state_dict(checkpoint, 30 | strict=False) 31 | for key in missing_keys: 32 | logging.info("missing tensor: {}".format(key)) 33 | for key in unexpected_keys: 34 | logging.info("unexpected tensor: {}".format(key)) 35 | info_path = re.sub('.pt$', '.yaml', path) 36 | configs = {} 37 | if os.path.exists(info_path): 38 | with open(info_path, 'r') as fin: 39 | configs = yaml.load(fin, Loader=yaml.FullLoader) 40 | return configs 41 | 42 | 43 | def save_checkpoint(model: torch.nn.Module, path: str, infos=None): 44 | ''' 45 | Args: 46 | infos (dict or None): any info you want to save. 47 | ''' 48 | logging.info('Checkpoint: save to checkpoint %s' % path) 49 | if isinstance(model, torch.nn.DataParallel): 50 | state_dict = model.module.state_dict() 51 | elif isinstance(model, torch.nn.parallel.DistributedDataParallel): 52 | state_dict = model.module.state_dict() 53 | else: 54 | state_dict = model.state_dict() 55 | torch.save(state_dict, path) 56 | info_path = re.sub('.pt$', '.yaml', path) 57 | if infos is None: 58 | infos = {} 59 | infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') 60 | with open(info_path, 'w') as fout: 61 | data = yaml.dump(infos) 62 | fout.write(data) 63 | 64 | 65 | def filter_modules(model_state_dict, modules): 66 | new_mods = [] 67 | incorrect_mods = [] 68 | mods_model = model_state_dict.keys() 69 | for mod in modules: 70 | if any(key.startswith(mod) for key in mods_model): 71 | new_mods += [mod] 72 | else: 73 | incorrect_mods += [mod] 74 | if incorrect_mods: 75 | logging.warning( 76 | "module(s) %s don't match or (partially match) " 77 | "available modules in model.", 78 | incorrect_mods, 79 | ) 80 | logging.warning("for information, the existing modules in model are:") 81 | logging.warning("%s", mods_model) 82 | 83 | return new_mods 84 | 85 | 86 | def load_trained_modules(model: torch.nn.Module, args: None): 87 | # Load encoder modules with pre-trained model(s). 
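    # `args.enc_init` is expected to be the path to a pretrained checkpoint and
    # `args.enc_init_mods` a list of module-name prefixes (e.g. "encoder.");
    # only parameters whose names start with one of those prefixes are copied
    # into the current model's state dict.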
88 | enc_model_path = args.enc_init 89 | enc_modules = args.enc_init_mods 90 | main_state_dict = model.state_dict() 91 | logging.warning("model(s) found for pre-initialization") 92 | if os.path.isfile(enc_model_path): 93 | logging.info('Checkpoint: loading from checkpoint %s for CPU' % 94 | enc_model_path) 95 | model_state_dict = torch.load(enc_model_path, map_location='cpu') 96 | modules = filter_modules(model_state_dict, enc_modules) 97 | partial_state_dict = OrderedDict() 98 | for key, value in model_state_dict.items(): 99 | if any(key.startswith(m) for m in modules): 100 | partial_state_dict[key] = value 101 | main_state_dict.update(partial_state_dict) 102 | else: 103 | logging.warning("model was not found : %s", enc_model_path) 104 | 105 | model.load_state_dict(main_state_dict) 106 | configs = {} 107 | return configs 108 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/transducer/joint.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES 6 | 7 | 8 | class TransducerJoint(torch.nn.Module): 9 | 10 | def __init__(self, 11 | vocab_size: int, 12 | enc_output_size: int, 13 | pred_output_size: int, 14 | join_dim: int, 15 | prejoin_linear: bool = True, 16 | postjoin_linear: bool = False, 17 | joint_mode: str = 'add', 18 | activation: str = "tanh", 19 | hat_joint: bool = False, 20 | dropout_rate: float = 0.1, 21 | hat_activation: str = 'tanh'): 22 | # TODO(Mddct): concat in future 23 | assert joint_mode in ['add'] 24 | super().__init__() 25 | 26 | self.activatoin = WENET_ACTIVATION_CLASSES[activation]() 27 | self.prejoin_linear = prejoin_linear 28 | self.postjoin_linear = postjoin_linear 29 | self.joint_mode = joint_mode 30 | 31 | if not self.prejoin_linear and not self.postjoin_linear: 32 | assert enc_output_size == pred_output_size == join_dim 33 | # torchscript compatibility 34 | self.enc_ffn: Optional[nn.Linear] = None 35 | self.pred_ffn: Optional[nn.Linear] = None 36 | if self.prejoin_linear: 37 | self.enc_ffn = nn.Linear(enc_output_size, join_dim) 38 | self.pred_ffn = nn.Linear(pred_output_size, join_dim) 39 | # torchscript compatibility 40 | self.post_ffn: Optional[nn.Linear] = None 41 | if self.postjoin_linear: 42 | self.post_ffn = nn.Linear(join_dim, join_dim) 43 | 44 | # NOTE: in vocab_size 45 | self.hat_joint = hat_joint 46 | self.vocab_size = vocab_size 47 | self.ffn_out: Optional[torch.nn.Linear] = None 48 | if not self.hat_joint: 49 | self.ffn_out = nn.Linear(join_dim, vocab_size) 50 | 51 | self.blank_pred: Optional[torch.nn.Module] = None 52 | self.token_pred: Optional[torch.nn.Module] = None 53 | if self.hat_joint: 54 | self.blank_pred = torch.nn.Sequential( 55 | torch.nn.Tanh(), torch.nn.Dropout(dropout_rate), 56 | torch.nn.Linear(join_dim, 1), torch.nn.LogSigmoid()) 57 | self.token_pred = torch.nn.Sequential( 58 | WENET_ACTIVATION_CLASSES[hat_activation](), 59 | torch.nn.Dropout(dropout_rate), 60 | torch.nn.Linear(join_dim, self.vocab_size - 1)) 61 | 62 | def forward(self, 63 | enc_out: torch.Tensor, 64 | pred_out: torch.Tensor, 65 | pre_project: bool = True) -> torch.Tensor: 66 | """ 67 | Args: 68 | enc_out (torch.Tensor): [B, T, E] 69 | pred_out (torch.Tensor): [B, T, P] 70 | Return: 71 | [B,T,U,V] 72 | """ 73 | if (pre_project and self.prejoin_linear and self.enc_ffn is not None 74 | and self.pred_ffn is not None): 75 | enc_out = 
self.enc_ffn(enc_out) # [B,T,E] -> [B,T,D] 76 | pred_out = self.pred_ffn(pred_out) 77 | if enc_out.ndim != 4: 78 | enc_out = enc_out.unsqueeze(2) # [B,T,D] -> [B,T,1,D] 79 | if pred_out.ndim != 4: 80 | pred_out = pred_out.unsqueeze(1) # [B,U,D] -> [B,1,U,D] 81 | 82 | # TODO(Mddct): concat joint 83 | _ = self.joint_mode 84 | out = enc_out + pred_out # [B,T,U,V] 85 | 86 | if self.postjoin_linear and self.post_ffn is not None: 87 | out = self.post_ffn(out) 88 | 89 | if not self.hat_joint and self.ffn_out is not None: 90 | out = self.activatoin(out) 91 | out = self.ffn_out(out) 92 | return out 93 | else: 94 | assert self.blank_pred is not None 95 | assert self.token_pred is not None 96 | blank_logp = self.blank_pred(out) # [B,T,U,1] 97 | 98 | # scale blank logp 99 | scale_logp = torch.clamp(1 - torch.exp(blank_logp), min=1e-6) 100 | label_logp = self.token_pred(out).log_softmax( 101 | dim=-1) # [B,T,U,vocab-1] 102 | # scale token logp 103 | label_logp = torch.log(scale_logp) + label_logp 104 | 105 | out = torch.cat((blank_logp, label_logp), dim=-1) # [B,T,U,vocab] 106 | return out 107 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/ssl/wav2vec2/quantizer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | 4 | 5 | def gumbel(shape: torch.Size, dtype: torch.dtype, device: torch.device): 6 | """Sample Gumbel random values with given shape and float dtype. 7 | 8 | The values are distributed according to the probability density function: 9 | 10 | .. math:: 11 | f(x) = e^{-(x + e^{-x})} 12 | 13 | Args: 14 | shape (torch.Size): pdf shape 15 | dtype (torch.dtype): pdf value dtype 16 | 17 | Returns: 18 | A random array with the specified shape and dtype. 
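    Note: G = -log(-log(U)) with U ~ Uniform(0, 1) is the standard Gumbel
    reparameterisation; adding such noise to logits and applying a
    temperature softmax yields differentiable (Gumbel-softmax) samples,
    which is how Wav2vecGumbelVectorQuantizer below uses this helper.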
19 | """ 20 | # see https://www.cnblogs.com/initial-h/p/9468974.html for more details 21 | return -torch.log(-torch.log( 22 | torch.empty(shape, device=device).uniform_( 23 | torch.finfo(dtype).tiny, 1.))) 24 | 25 | 26 | class Wav2vecGumbelVectorQuantizer(torch.nn.Module): 27 | 28 | def __init__(self, 29 | features_dim: int = 256, 30 | num_codebooks: int = 2, 31 | num_embeddings: int = 8192, 32 | embedding_dim: int = 16, 33 | hard: bool = False) -> None: 34 | 35 | super().__init__() 36 | 37 | self.num_groups = num_codebooks 38 | self.num_codevectors_per_group = num_embeddings 39 | # codebooks 40 | # means [C, G, D] see quantize_vector in bestrq_model.py 41 | assert embedding_dim % num_codebooks == 0.0 42 | self.embeddings = torch.nn.parameter.Parameter( 43 | torch.empty(1, num_codebooks * num_embeddings, 44 | embedding_dim // num_codebooks), 45 | requires_grad=True, 46 | ) 47 | torch.nn.init.uniform_(self.embeddings) 48 | 49 | self.weight_proj = torch.nn.Linear(features_dim, 50 | num_codebooks * num_embeddings) 51 | # use gumbel softmax or argmax(non-differentiable) 52 | self.hard = hard 53 | 54 | @staticmethod 55 | def _compute_perplexity(probs, mask=None): 56 | if mask is not None: 57 | 58 | mask_extended = torch.broadcast_to(mask.flatten()[:, None, None], 59 | probs.shape) 60 | probs = torch.where(mask_extended.to(torch.bool), probs, 61 | torch.zeros_like(probs)) 62 | marginal_probs = probs.sum(dim=0) / mask.sum() 63 | else: 64 | marginal_probs = probs.mean(dim=0) 65 | 66 | perplexity = torch.exp(-torch.sum( 67 | marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum() 68 | return perplexity 69 | 70 | def forward( 71 | self, 72 | input: torch.Tensor, 73 | input_mask: torch.Tensor, 74 | temperature: float = 1. 75 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 76 | 77 | b, t, _ = input.size() 78 | 79 | hidden = self.weight_proj(input) 80 | hidden = hidden.reshape(b * t * self.num_groups, -1) 81 | if not self.hard: 82 | # sample code vector probs via gumbel in differentiateable way 83 | gumbels = gumbel(hidden.size(), hidden.dtype, hidden.device) 84 | codevector_probs = torch.nn.functional.softmax( 85 | (hidden + gumbels) / temperature, dim=-1) 86 | 87 | # compute perplexity 88 | codevector_soft_dist = torch.nn.functional.softmax( 89 | hidden.reshape(b * t, self.num_groups, -1), 90 | dim=-1, 91 | ) # [B*T, num_codebooks, num_embeddings] 92 | perplexity = self._compute_perplexity(codevector_soft_dist, 93 | input_mask) 94 | else: 95 | # take argmax in non-differentiable way 96 | # comptute hard codevector distribution (one hot) 97 | codevector_idx = hidden.argmax(axis=-1) 98 | codevector_probs = torch.nn.functional.one_hot( 99 | codevector_idx, hidden.shape[-1]) * 1.0 100 | codevector_probs = codevector_probs.reshape( 101 | b * t, self.num_groups, -1) 102 | perplexity = self._compute_perplexity(codevector_probs, input_mask) 103 | 104 | targets_idx = codevector_probs.argmax(-1).reshape(b, t, -1) 105 | codevector_probs = codevector_probs.reshape(b * t, -1) 106 | # use probs to retrieve codevectors 107 | codevectors_per_group = codevector_probs.unsqueeze( 108 | -1) * self.embeddings 109 | codevectors = codevectors_per_group.reshape( 110 | b * t, self.num_groups, self.num_codevectors_per_group, -1) 111 | 112 | codevectors = codevectors.sum(-2).reshape(b, t, -1) 113 | return codevectors, perplexity, targets_idx 114 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/bin/average_model.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import argparse 17 | import glob 18 | import sys 19 | 20 | import yaml 21 | import torch 22 | 23 | 24 | def get_args(): 25 | parser = argparse.ArgumentParser(description='average model') 26 | parser.add_argument('--dst_model', required=True, help='averaged model') 27 | parser.add_argument('--src_path', 28 | required=True, 29 | help='src model path for average') 30 | parser.add_argument('--val_best', 31 | action="store_true", 32 | help='averaged model') 33 | parser.add_argument('--num', 34 | default=5, 35 | type=int, 36 | help='nums for averaged model') 37 | parser.add_argument('--min_epoch', 38 | default=0, 39 | type=int, 40 | help='min epoch used for averaging model') 41 | parser.add_argument('--max_epoch', 42 | default=sys.maxsize, 43 | type=int, 44 | help='max epoch used for averaging model') 45 | parser.add_argument('--min_step', 46 | default=0, 47 | type=int, 48 | help='min step used for averaging model') 49 | parser.add_argument('--max_step', 50 | default=sys.maxsize, 51 | type=int, 52 | help='max step used for averaging model') 53 | parser.add_argument('--mode', 54 | default="hybrid", 55 | choices=["hybrid", "epoch", "step"], 56 | type=str, 57 | help='average mode') 58 | 59 | args = parser.parse_args() 60 | print(args) 61 | return args 62 | 63 | 64 | def main(): 65 | args = get_args() 66 | checkpoints = [] 67 | val_scores = [] 68 | if args.val_best: 69 | if args.mode == "hybrid": 70 | yamls = glob.glob('{}/*.yaml'.format(args.src_path)) 71 | yamls = [ 72 | f for f in yamls 73 | if not (os.path.basename(f).startswith('train') 74 | or os.path.basename(f).startswith('init')) 75 | ] 76 | elif args.mode == "step": 77 | yamls = glob.glob('{}/step_*.yaml'.format(args.src_path)) 78 | else: 79 | yamls = glob.glob('{}/epoch_*.yaml'.format(args.src_path)) 80 | for y in yamls: 81 | with open(y, 'r') as f: 82 | dic_yaml = yaml.load(f, Loader=yaml.FullLoader) 83 | loss = dic_yaml['loss_dict']['loss'] 84 | epoch = dic_yaml['epoch'] 85 | step = dic_yaml['step'] 86 | tag = dic_yaml['tag'] 87 | if epoch >= args.min_epoch and epoch <= args.max_epoch \ 88 | and step >= args.min_step and step <= args.max_step: 89 | val_scores += [[epoch, step, loss, tag]] 90 | sorted_val_scores = sorted(val_scores, 91 | key=lambda x: x[2], 92 | reverse=False) 93 | print("best val (epoch, step, loss, tag) = " + 94 | str(sorted_val_scores[:args.num])) 95 | path_list = [ 96 | args.src_path + '/{}.pt'.format(score[-1]) 97 | for score in sorted_val_scores[:args.num] 98 | ] 99 | else: 100 | path_list = glob.glob('{}/[!init]*.pt'.format(args.src_path)) 101 | path_list = sorted(path_list, key=os.path.getmtime) 102 | path_list = path_list[-args.num:] 103 | print(path_list) 104 | avg = {} 105 | num = args.num 106 | assert num == len(path_list) 107 | for path in path_list: 108 | print('Processing 
{}'.format(path)) 109 | states = torch.load(path, map_location=torch.device('cpu')) 110 | for k in states.keys(): 111 | if k not in avg.keys(): 112 | avg[k] = states[k].clone() 113 | else: 114 | avg[k] += states[k] 115 | # average 116 | for k in avg.keys(): 117 | if avg[k] is not None: 118 | # pytorch 1.6 use true_divide instead of /= 119 | avg[k] = torch.true_divide(avg[k], num) 120 | print('Saving to {}'.format(args.dst_model)) 121 | torch.save(avg, args.dst_model) 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/squeezeformer/encoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """SqueezeformerEncoderLayer definition.""" 15 | 16 | import torch 17 | import torch.nn as nn 18 | from typing import Optional, Tuple 19 | 20 | 21 | class SqueezeformerEncoderLayer(nn.Module): 22 | """Encoder layer module. 23 | Args: 24 | size (int): Input dimension. 25 | self_attn (torch.nn.Module): Self-attention module instance. 26 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` 27 | instance can be used as the argument. 28 | feed_forward1 (torch.nn.Module): Feed-forward module instance. 29 | `PositionwiseFeedForward` instance can be used as the argument. 30 | conv_module (torch.nn.Module): Convolution module instance. 31 | `ConvlutionModule` instance can be used as the argument. 32 | feed_forward2 (torch.nn.Module): Feed-forward module instance. 33 | `PositionwiseFeedForward` instance can be used as the argument. 34 | dropout_rate (float): Dropout rate. 35 | normalize_before (bool): 36 | True: use layer_norm before each sub-block. 37 | False: use layer_norm after each sub-block. 
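        Note: the sub-blocks run in the order MHSA -> FFN -> Conv -> FFN,
        each followed by its own LayerNorm when normalize_before is False
        (the default here), which differs from the macaron-style
        FFN-MHSA-Conv-FFN ordering of the Conformer layer.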
38 | """ 39 | 40 | def __init__( 41 | self, 42 | size: int, 43 | self_attn: torch.nn.Module, 44 | feed_forward1: Optional[nn.Module] = None, 45 | conv_module: Optional[nn.Module] = None, 46 | feed_forward2: Optional[nn.Module] = None, 47 | normalize_before: bool = False, 48 | dropout_rate: float = 0.1, 49 | concat_after: bool = False, 50 | ): 51 | super(SqueezeformerEncoderLayer, self).__init__() 52 | self.size = size 53 | self.self_attn = self_attn 54 | self.layer_norm1 = nn.LayerNorm(size) 55 | self.ffn1 = feed_forward1 56 | self.layer_norm2 = nn.LayerNorm(size) 57 | self.conv_module = conv_module 58 | self.layer_norm3 = nn.LayerNorm(size) 59 | self.ffn2 = feed_forward2 60 | self.layer_norm4 = nn.LayerNorm(size) 61 | self.normalize_before = normalize_before 62 | self.dropout = nn.Dropout(dropout_rate) 63 | self.concat_after = concat_after 64 | if concat_after: 65 | self.concat_linear = nn.Linear(size + size, size) 66 | else: 67 | self.concat_linear = nn.Identity() 68 | 69 | def forward( 70 | self, 71 | x: torch.Tensor, 72 | mask: torch.Tensor, 73 | pos_emb: torch.Tensor, 74 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 75 | att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), 76 | cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), 77 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 78 | # self attention module 79 | residual = x 80 | if self.normalize_before: 81 | x = self.layer_norm1(x) 82 | x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, 83 | att_cache) 84 | if self.concat_after: 85 | x_concat = torch.cat((x, x_att), dim=-1) 86 | x = residual + self.concat_linear(x_concat) 87 | else: 88 | x = residual + self.dropout(x_att) 89 | if not self.normalize_before: 90 | x = self.layer_norm1(x) 91 | 92 | # ffn module 93 | residual = x 94 | if self.normalize_before: 95 | x = self.layer_norm2(x) 96 | x = self.ffn1(x) 97 | x = residual + self.dropout(x) 98 | if not self.normalize_before: 99 | x = self.layer_norm2(x) 100 | 101 | # conv module 102 | new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 103 | residual = x 104 | if self.normalize_before: 105 | x = self.layer_norm3(x) 106 | x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) 107 | x = residual + self.dropout(x) 108 | if not self.normalize_before: 109 | x = self.layer_norm3(x) 110 | 111 | # ffn module 112 | residual = x 113 | if self.normalize_before: 114 | x = self.layer_norm4(x) 115 | x = self.ffn2(x) 116 | # we do not use dropout here since it is inside feed forward function 117 | x = residual + self.dropout(x) 118 | if not self.normalize_before: 119 | x = self.layer_norm4(x) 120 | 121 | return x, mask, new_att_cache, new_cnn_cache 122 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/transformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Modified from ESPnet(https://github.com/espnet/espnet) 15 | """ConvolutionModule definition.""" 16 | 17 | from typing import Tuple 18 | 19 | import torch 20 | from torch import nn 21 | 22 | from wenet.utils.class_utils import WENET_NORM_CLASSES 23 | 24 | 25 | class ConvolutionModule(nn.Module): 26 | """ConvolutionModule in Conformer model.""" 27 | 28 | def __init__( 29 | self, 30 | channels: int, 31 | kernel_size: int = 15, 32 | activation: nn.Module = nn.ReLU(), 33 | norm: str = "batch_norm", 34 | causal: bool = False, 35 | bias: bool = True, 36 | norm_eps: float = 1e-5, 37 | ): 38 | """Construct an ConvolutionModule object. 39 | Args: 40 | channels (int): The number of channels of conv layers. 41 | kernel_size (int): Kernel size of conv layers. 42 | causal (int): Whether use causal convolution or not 43 | """ 44 | super().__init__() 45 | 46 | self.pointwise_conv1 = nn.Conv1d( 47 | channels, 48 | 2 * channels, 49 | kernel_size=1, 50 | stride=1, 51 | padding=0, 52 | bias=bias, 53 | ) 54 | # self.lorder is used to distinguish if it's a causal convolution, 55 | # if self.lorder > 0: it's a causal convolution, the input will be 56 | # padded with self.lorder frames on the left in forward. 57 | # else: it's a symmetrical convolution 58 | if causal: 59 | padding = 0 60 | self.lorder = kernel_size - 1 61 | else: 62 | # kernel_size should be an odd number for none causal convolution 63 | assert (kernel_size - 1) % 2 == 0 64 | padding = (kernel_size - 1) // 2 65 | self.lorder = 0 66 | self.depthwise_conv = nn.Conv1d( 67 | channels, 68 | channels, 69 | kernel_size, 70 | stride=1, 71 | padding=padding, 72 | groups=channels, 73 | bias=bias, 74 | ) 75 | 76 | assert norm in ['batch_norm', 'layer_norm', 'rms_norm'] 77 | if norm == "batch_norm": 78 | self.use_layer_norm = False 79 | self.norm = WENET_NORM_CLASSES['batch_norm'](channels, 80 | eps=norm_eps) 81 | else: 82 | self.use_layer_norm = True 83 | self.norm = WENET_NORM_CLASSES[norm](channels, eps=norm_eps) 84 | 85 | self.pointwise_conv2 = nn.Conv1d( 86 | channels, 87 | channels, 88 | kernel_size=1, 89 | stride=1, 90 | padding=0, 91 | bias=bias, 92 | ) 93 | self.activation = activation 94 | 95 | def forward( 96 | self, 97 | x: torch.Tensor, 98 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 99 | cache: torch.Tensor = torch.zeros((0, 0, 0)), 100 | ) -> Tuple[torch.Tensor, torch.Tensor]: 101 | """Compute convolution module. 102 | Args: 103 | x (torch.Tensor): Input tensor (#batch, time, channels). 104 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), 105 | (0, 0, 0) means fake mask. 106 | cache (torch.Tensor): left context cache, it is only 107 | used in causal convolution (#batch, channels, cache_t), 108 | (0, 0, 0) meas fake cache. 109 | Returns: 110 | torch.Tensor: Output tensor (#batch, time, channels). 
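            torch.Tensor: New left-context cache (#batch, channels, cache_t);
                a (0, 0, 0) placeholder tensor when the convolution is not
                causal.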
111 | """ 112 | # exchange the temporal dimension and the feature dimension 113 | x = x.transpose(1, 2) # (#batch, channels, time) 114 | 115 | # mask batch padding 116 | if mask_pad.size(2) > 0: # time > 0 117 | x.masked_fill_(~mask_pad, 0.0) 118 | 119 | if self.lorder > 0: 120 | if cache.size(2) == 0: # cache_t == 0 121 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) 122 | else: 123 | assert cache.size(0) == x.size(0) # equal batch 124 | assert cache.size(1) == x.size(1) # equal channel 125 | x = torch.cat((cache, x), dim=2) 126 | assert (x.size(2) > self.lorder) 127 | new_cache = x[:, :, -self.lorder:] 128 | else: 129 | # It's better we just return None if no cache is required, 130 | # However, for JIT export, here we just fake one tensor instead of 131 | # None. 132 | new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 133 | 134 | # GLU mechanism 135 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim) 136 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 137 | 138 | # 1D Depthwise Conv 139 | x = self.depthwise_conv(x) 140 | if self.use_layer_norm: 141 | x = x.transpose(1, 2) 142 | x = self.activation(self.norm(x)) 143 | if self.use_layer_norm: 144 | x = x.transpose(1, 2) 145 | x = self.pointwise_conv2(x) 146 | # mask batch padding 147 | if mask_pad.size(2) > 0: # time > 0 148 | x.masked_fill_(~mask_pad, 0.0) 149 | 150 | return x.transpose(1, 2), new_cache 151 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Positionwise feed forward layer definition.""" 16 | 17 | import torch 18 | 19 | 20 | class PositionwiseFeedForward(torch.nn.Module): 21 | """Positionwise feed forward layer. 22 | 23 | FeedForward are appied on each position of the sequence. 24 | The output dim is same with the input dim. 25 | 26 | Args: 27 | idim (int): Input dimenstion. 28 | hidden_units (int): The number of hidden units. 29 | dropout_rate (float): Dropout rate. 30 | activation (torch.nn.Module): Activation function 31 | """ 32 | 33 | def __init__( 34 | self, 35 | idim: int, 36 | hidden_units: int, 37 | dropout_rate: float, 38 | activation: torch.nn.Module = torch.nn.ReLU(), 39 | bias: bool = True, 40 | ): 41 | """Construct a PositionwiseFeedForward object.""" 42 | super(PositionwiseFeedForward, self).__init__() 43 | self.w_1 = torch.nn.Linear(idim, hidden_units, bias=bias) 44 | self.activation = activation 45 | self.dropout = torch.nn.Dropout(dropout_rate) 46 | self.w_2 = torch.nn.Linear(hidden_units, idim, bias=bias) 47 | 48 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 49 | """Forward function. 
50 | 51 | Args: 52 | xs: input tensor (B, L, D) 53 | Returns: 54 | output tensor, (B, L, D) 55 | """ 56 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 57 | 58 | 59 | class MoEFFNLayer(torch.nn.Module): 60 | """ 61 | Mixture of expert with Positionwise feed forward layer 62 | See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf 63 | The output dim is same with the input dim. 64 | 65 | Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 66 | https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 67 | Args: 68 | n_expert: number of expert. 69 | n_expert_per_token: The actual number of experts used for each frame 70 | idim (int): Input dimenstion. 71 | hidden_units (int): The number of hidden units. 72 | dropout_rate (float): Dropout rate. 73 | activation (torch.nn.Module): Activation function 74 | """ 75 | 76 | def __init__( 77 | self, 78 | n_expert: int, 79 | n_expert_per_token: int, 80 | idim: int, 81 | hidden_units: int, 82 | dropout_rate: float, 83 | activation: torch.nn.Module = torch.nn.ReLU(), 84 | bias: bool = False, 85 | ): 86 | super(MoEFFNLayer, self).__init__() 87 | bias = False 88 | self.gate = torch.nn.Linear(idim, n_expert, bias=bias) 89 | self.experts = torch.nn.ModuleList( 90 | PositionwiseFeedForward(idim, hidden_units, dropout_rate, 91 | activation) for _ in range(n_expert)) 92 | self.n_expert_per_token = n_expert_per_token 93 | 94 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 95 | """Foward function. 96 | Args: 97 | xs: input tensor (B, L, D) 98 | Returns: 99 | output tensor, (B, L, D) 100 | 101 | """ 102 | B, L, D = xs.size( 103 | ) # batch size, sequence length, embedding dimension (idim) 104 | xs = xs.view(-1, D) # (B*L, D) 105 | router = self.gate(xs) # (B*L, n_expert) 106 | logits, indices = torch.topk( 107 | router, self.n_expert_per_token 108 | ) # probs:(B*L, n_expert), indices: (B*L, n_expert) 109 | weights = torch.nn.functional.softmax( 110 | logits, dim=1, 111 | dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) 112 | output = torch.zeros_like(xs) # (B*L, D) 113 | for i, expert in enumerate(self.experts): 114 | mask = indices == i 115 | batch_idx, ith_expert = torch.where(mask) 116 | output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( 117 | xs[batch_idx]) 118 | return output.view(B, L, D) 119 | 120 | 121 | class GatedVariantsMLP(torch.nn.Module): 122 | """ https://arxiv.org/pdf/2002.05202.pdf 123 | """ 124 | 125 | def __init__( 126 | self, 127 | idim: int, 128 | hidden_units: int, 129 | dropout_rate: float, 130 | activation: torch.nn.Module = torch.nn.GELU(), 131 | bias: bool = True, 132 | ): 133 | """Construct a PositionwiseFeedForward object.""" 134 | super(GatedVariantsMLP, self).__init__() 135 | self.gate = torch.nn.Linear(idim, hidden_units, bias=False) 136 | self.activation = activation 137 | # w_1 as up proj 138 | self.w_1 = torch.nn.Linear(idim, hidden_units, bias=bias) 139 | self.dropout = torch.nn.Dropout(dropout_rate) 140 | # w_2 as down proj 141 | self.w_2 = torch.nn.Linear(hidden_units, idim, bias=bias) 142 | 143 | def forward(self, x): 144 | """Foward function. 
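        Computes w_2(dropout(activation(gate(x)) * w_1(x))), i.e. the gated
        (GLU-variant) feed-forward from the paper linked above.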
145 | Args: 146 | xs: input tensor (B, L, D) 147 | Returns: 148 | output tensor, (B, L, D) 149 | 150 | """ 151 | gate = self.activation(self.gate(x)) 152 | up = self.w_1(x) 153 | fuse = gate * up 154 | return self.w_2(self.dropout(fuse)) 155 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/transformer/decoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Decoder self-attention layer definition.""" 16 | from typing import Dict, Optional, Tuple 17 | 18 | import torch 19 | from torch import nn 20 | 21 | from wenet.utils.class_utils import WENET_NORM_CLASSES 22 | 23 | 24 | class DecoderLayer(nn.Module): 25 | """Single decoder layer module. 26 | 27 | Args: 28 | size (int): Input dimension. 29 | self_attn (torch.nn.Module): Self-attention module instance. 30 | `MultiHeadedAttention` instance can be used as the argument. 31 | src_attn (torch.nn.Module): Inter-attention module instance. 32 | `MultiHeadedAttention` instance can be used as the argument. 33 | If `None` is passed, Inter-attention is not used, such as 34 | CIF, GPT, and other decoder only model. 35 | feed_forward (torch.nn.Module): Feed-forward module instance. 36 | `PositionwiseFeedForward` instance can be used as the argument. 37 | dropout_rate (float): Dropout rate. 38 | normalize_before (bool): 39 | True: use layer_norm before each sub-block. 40 | False: to use layer_norm after each sub-block. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | size: int, 46 | self_attn: nn.Module, 47 | src_attn: Optional[nn.Module], 48 | feed_forward: nn.Module, 49 | dropout_rate: float, 50 | normalize_before: bool = True, 51 | layer_norm_type: str = 'layer_norm', 52 | norm_eps: float = 1e-5, 53 | ): 54 | """Construct an DecoderLayer object.""" 55 | super().__init__() 56 | self.size = size 57 | self.self_attn = self_attn 58 | self.src_attn = src_attn 59 | self.feed_forward = feed_forward 60 | assert layer_norm_type in ['layer_norm', 'rms_norm'] 61 | self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps) 62 | self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps) 63 | self.norm3 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps) 64 | self.dropout = nn.Dropout(dropout_rate) 65 | self.normalize_before = normalize_before 66 | 67 | def forward( 68 | self, 69 | tgt: torch.Tensor, 70 | tgt_mask: torch.Tensor, 71 | memory: torch.Tensor, 72 | memory_mask: torch.Tensor, 73 | cache: Optional[Dict[str, Optional[torch.Tensor]]] = None 74 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 75 | """Compute decoded features. 76 | 77 | Args: 78 | tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). 79 | tgt_mask (torch.Tensor): Mask for input tensor 80 | (#batch, maxlen_out). 
81 | memory (torch.Tensor): Encoded memory 82 | (#batch, maxlen_in, size). 83 | memory_mask (torch.Tensor): Encoded memory mask 84 | (#batch, maxlen_in). 85 | cache (torch.Tensor): cached tensors. 86 | (#batch, maxlen_out - 1, size). 87 | 88 | Returns: 89 | torch.Tensor: Output tensor (#batch, maxlen_out, size). 90 | torch.Tensor: Mask for output tensor (#batch, maxlen_out). 91 | torch.Tensor: Encoded memory (#batch, maxlen_in, size). 92 | torch.Tensor: Encoded memory mask (#batch, maxlen_in). 93 | 94 | """ 95 | if cache is not None: 96 | att_cache = cache['self_att_cache'] 97 | cross_att_cache = cache['cross_att_cache'] 98 | else: 99 | att_cache, cross_att_cache = None, None 100 | 101 | residual = tgt 102 | if self.normalize_before: 103 | tgt = self.norm1(tgt) 104 | 105 | if att_cache is None: 106 | tgt_q = tgt 107 | tgt_q_mask = tgt_mask 108 | att_cache = torch.empty(0, 0, 0, 0) 109 | else: 110 | tgt_q = tgt[:, -1:, :] 111 | residual = residual[:, -1:, :] 112 | tgt_q_mask = tgt_mask[:, -1:, :] 113 | 114 | x, new_att_cache = self.self_attn( 115 | tgt_q, 116 | tgt_q, 117 | tgt_q, 118 | tgt_q_mask, 119 | cache=att_cache, 120 | ) 121 | if cache is not None: 122 | cache['self_att_cache'] = new_att_cache 123 | x = residual + self.dropout(x) 124 | if not self.normalize_before: 125 | x = self.norm1(x) 126 | 127 | if self.src_attn is not None: 128 | residual = x 129 | if self.normalize_before: 130 | x = self.norm2(x) 131 | if cross_att_cache is None: 132 | cross_att_cache = torch.empty(0, 0, 0, 0) 133 | x, new_cross_cache = self.src_attn(x, 134 | memory, 135 | memory, 136 | memory_mask, 137 | cache=cross_att_cache) 138 | if cache is not None: 139 | cache['cross_att_cache'] = new_cross_cache 140 | x = residual + self.dropout(x) 141 | if not self.normalize_before: 142 | x = self.norm2(x) 143 | 144 | residual = x 145 | if self.normalize_before: 146 | x = self.norm3(x) 147 | x = residual + self.dropout(self.feed_forward(x)) 148 | if not self.normalize_before: 149 | x = self.norm3(x) 150 | 151 | return x, tgt_mask, memory, memory_mask 152 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/efficient_conformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 2022 58.com(Wuba) Inc AI Lab. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | """ConvolutionModule definition.""" 17 | from typing import Tuple 18 | 19 | import torch 20 | from torch import nn 21 | 22 | 23 | class ConvolutionModule(nn.Module): 24 | """ConvolutionModule in Conformer model.""" 25 | 26 | def __init__(self, 27 | channels: int, 28 | kernel_size: int = 15, 29 | activation: nn.Module = nn.ReLU(), 30 | norm: str = "batch_norm", 31 | causal: bool = False, 32 | bias: bool = True, 33 | stride: int = 1): 34 | """Construct an ConvolutionModule object. 
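        Note: unlike wenet.transformer.convolution.ConvolutionModule, the
        depthwise conv here may use stride > 1, letting the Efficient
        Conformer downsample the time axis inside the convolution block.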
35 | Args: 36 | channels (int): The number of channels of conv layers. 37 | kernel_size (int): Kernel size of conv layers. 38 | causal (int): Whether use causal convolution or not 39 | stride (int): Stride Convolution, for efficient Conformer 40 | """ 41 | super().__init__() 42 | 43 | self.pointwise_conv1 = nn.Conv1d( 44 | channels, 45 | 2 * channels, 46 | kernel_size=1, 47 | stride=1, 48 | padding=0, 49 | bias=bias, 50 | ) 51 | # self.lorder is used to distinguish if it's a causal convolution, 52 | # if self.lorder > 0: it's a causal convolution, the input will be 53 | # padded with self.lorder frames on the left in forward. 54 | # else: it's a symmetrical convolution 55 | if causal: 56 | padding = 0 57 | self.lorder = kernel_size - 1 58 | else: 59 | # kernel_size should be an odd number for none causal convolution 60 | assert (kernel_size - 1) % 2 == 0 61 | padding = (kernel_size - 1) // 2 62 | self.lorder = 0 63 | 64 | self.depthwise_conv = nn.Conv1d( 65 | channels, 66 | channels, 67 | kernel_size, 68 | stride=stride, # for depthwise_conv in StrideConv 69 | padding=padding, 70 | groups=channels, 71 | bias=bias, 72 | ) 73 | 74 | assert norm in ['batch_norm', 'layer_norm'] 75 | if norm == "batch_norm": 76 | self.use_layer_norm = False 77 | self.norm = nn.BatchNorm1d(channels) 78 | else: 79 | self.use_layer_norm = True 80 | self.norm = nn.LayerNorm(channels) 81 | 82 | self.pointwise_conv2 = nn.Conv1d( 83 | channels, 84 | channels, 85 | kernel_size=1, 86 | stride=1, 87 | padding=0, 88 | bias=bias, 89 | ) 90 | self.activation = activation 91 | self.stride = stride 92 | 93 | def forward( 94 | self, 95 | x: torch.Tensor, 96 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 97 | cache: torch.Tensor = torch.zeros((0, 0, 0)), 98 | ) -> Tuple[torch.Tensor, torch.Tensor]: 99 | """Compute convolution module. 100 | Args: 101 | x (torch.Tensor): Input tensor (#batch, time, channels). 102 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), 103 | (0, 0, 0) means fake mask. 104 | cache (torch.Tensor): left context cache, it is only 105 | used in causal convolution (#batch, channels, cache_t), 106 | (0, 0, 0) meas fake cache. 107 | Returns: 108 | torch.Tensor: Output tensor (#batch, time, channels). 109 | """ 110 | # exchange the temporal dimension and the feature dimension 111 | x = x.transpose(1, 2) # (#batch, channels, time) 112 | 113 | # mask batch padding 114 | if mask_pad.size(2) > 0: # time > 0 115 | x.masked_fill_(~mask_pad, 0.0) 116 | 117 | if self.lorder > 0: 118 | if cache.size(2) == 0: # cache_t == 0 119 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) 120 | else: 121 | # When export ONNX,the first cache is not None but all-zero, 122 | # cause shape error in residual block, 123 | # eg. cache14 + x9 = 23, 23-7+1=17 != 9 124 | cache = cache[:, :, -self.lorder:] 125 | assert cache.size(0) == x.size(0) # equal batch 126 | assert cache.size(1) == x.size(1) # equal channel 127 | x = torch.cat((cache, x), dim=2) 128 | assert (x.size(2) > self.lorder) 129 | new_cache = x[:, :, -self.lorder:] 130 | else: 131 | # It's better we just return None if no cache is requried, 132 | # However, for JIT export, here we just fake one tensor instead of 133 | # None. 
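            # NOTE: the rest mirrors the standard module: pointwise conv +
            # GLU, (possibly strided) depthwise conv, norm + activation,
            # second pointwise conv. Because the depthwise conv may have
            # stride > 1, mask_pad is subsampled with [:, :, ::stride] below
            # when its length no longer matches the downsampled output.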
134 | new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 135 | 136 | # GLU mechanism 137 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim) 138 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 139 | 140 | # 1D Depthwise Conv 141 | x = self.depthwise_conv(x) 142 | if self.use_layer_norm: 143 | x = x.transpose(1, 2) 144 | x = self.activation(self.norm(x)) 145 | if self.use_layer_norm: 146 | x = x.transpose(1, 2) 147 | x = self.pointwise_conv2(x) 148 | # mask batch padding 149 | if mask_pad.size(2) > 0: # time > 0 150 | if mask_pad.size(2) != x.size(2): 151 | mask_pad = mask_pad[:, :, ::self.stride] 152 | x.masked_fill_(~mask_pad, 0.0) 153 | 154 | return x.transpose(1, 2), new_cache 155 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/ssl/bestrq/mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def _sampler(pdf: torch.Tensor, num_samples: int, 6 | device=torch.device('cpu')) -> torch.Tensor: 7 | size = pdf.size() 8 | z = -torch.log(torch.rand(size, device=device)) 9 | _, indices = torch.topk(pdf + z, num_samples) 10 | return indices 11 | 12 | 13 | def compute_mask_indices( 14 | size: torch.Size, 15 | mask_prob: float, 16 | mask_length: int, 17 | min_masks: int = 0, 18 | device=torch.device('cpu'), 19 | ) -> torch.Tensor: 20 | 21 | assert len(size) == 2 22 | batch_size, seq_length = size 23 | 24 | # compute number of masked span in batch 25 | num_masked_spans = mask_prob * float(seq_length) / float( 26 | mask_length) + torch.rand(1)[0] 27 | num_masked_spans = int(num_masked_spans) 28 | num_masked_spans = max(num_masked_spans, min_masks) 29 | 30 | # num_masked <= seq_length 31 | if num_masked_spans * mask_length > seq_length: 32 | num_masked_spans = seq_length // mask_length 33 | 34 | pdf = torch.ones(batch_size, seq_length - (mask_length - 1), device=device) 35 | mask_idxs = _sampler(pdf, num_masked_spans, device=device) 36 | 37 | mask_idxs = mask_idxs.unsqueeze(-1).repeat(1, 1, mask_length).view( 38 | batch_size, 39 | num_masked_spans * mask_length) # [B,num_masked_spans*mask_length] 40 | 41 | offset = torch.arange(mask_length, device=device).view(1, 1, -1).repeat( 42 | 1, num_masked_spans, 1) # [1,num_masked_spans,mask_length] 43 | offset = offset.view(1, num_masked_spans * mask_length) 44 | 45 | mask_idxs = mask_idxs + offset # [B,num_masked_spans, mask_length] 46 | 47 | ones = torch.ones(batch_size, 48 | seq_length, 49 | dtype=torch.bool, 50 | device=mask_idxs.device) 51 | # masks to fill 52 | full_mask = torch.zeros_like(ones, 53 | dtype=torch.bool, 54 | device=mask_idxs.device) 55 | return torch.scatter(full_mask, dim=1, index=mask_idxs, src=ones) 56 | 57 | 58 | def compute_mask_indices_v2( 59 | shape, 60 | padding_mask, 61 | mask_prob: float, 62 | mask_length: int, 63 | mask_type: str = 'static', 64 | mask_other: float = 0.0, 65 | min_masks: int = 2, 66 | no_overlap: bool = False, 67 | min_space: int = 1, 68 | device=torch.device('cpu'), 69 | ): 70 | bsz, all_sz = shape 71 | mask = np.full((bsz, all_sz), False) 72 | padding_mask = padding_mask.cpu().numpy() 73 | all_num_mask = int( 74 | # add a random number for probabilistic rounding 75 | mask_prob * all_sz / float(mask_length) + np.random.rand()) 76 | 77 | all_num_mask = max(min_masks, all_num_mask) 78 | 79 | mask_idcs = [] 80 | for i in range(bsz): 81 | if padding_mask is not None and not isinstance(padding_mask, bytes): 82 | sz = all_sz - 
padding_mask[i].sum() 83 | num_mask = int( 84 | # add a random number for probabilistic rounding 85 | mask_prob * sz / float(mask_length) + np.random.rand()) 86 | num_mask = max(min_masks, num_mask) 87 | else: 88 | sz = all_sz 89 | num_mask = all_num_mask 90 | 91 | if mask_type == 'static': 92 | lengths = np.full(num_mask, mask_length) 93 | elif mask_type == 'uniform': 94 | lengths = np.random.randint(mask_other, 95 | mask_length * 2 + 1, 96 | size=num_mask) 97 | elif mask_type == 'normal': 98 | lengths = np.random.normal(mask_length, mask_other, size=num_mask) 99 | lengths = [max(1, int(round(x))) for x in lengths] 100 | elif mask_type == 'poisson': 101 | lengths = np.random.poisson(mask_length, size=num_mask) 102 | lengths = [int(round(x)) for x in lengths] 103 | else: 104 | raise Exception('unknown mask selection ' + mask_type) 105 | 106 | if sum(lengths) == 0: 107 | lengths[0] = min(mask_length, sz - 1) 108 | 109 | if no_overlap: 110 | mask_idc = [] 111 | 112 | def arrange(s, e, length, keep_length, mask_idc): 113 | span_start = np.random.randint(s, e - length) 114 | mask_idc.extend(span_start + i for i in range(length)) 115 | 116 | new_parts = [] 117 | if span_start - s - min_space >= keep_length: 118 | new_parts.append((s, span_start - min_space + 1)) 119 | if e - span_start - keep_length - min_space > keep_length: 120 | new_parts.append((span_start + length + min_space, e)) 121 | return new_parts 122 | 123 | parts = [(0, sz)] 124 | min_length = min(lengths) 125 | for length in sorted(lengths, reverse=True): 126 | lens = np.fromiter( 127 | (e - s if e - s >= length + min_space else 0 128 | for s, e in parts), 129 | np.int, 130 | ) 131 | l_sum = np.sum(lens) 132 | if l_sum == 0: 133 | break 134 | probs = lens / np.sum(lens) 135 | c = np.random.choice(len(parts), p=probs) 136 | s, e = parts.pop(c) 137 | parts.extend(arrange(s, e, length, min_length, mask_idc)) 138 | mask_idc = np.asarray(mask_idc) 139 | else: 140 | min_len = min(lengths) 141 | if sz - min_len <= num_mask: 142 | min_len = sz - num_mask - 1 143 | 144 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) 145 | 146 | mask_idc = np.asarray([ 147 | mask_idc[j] + offset for j in range(len(mask_idc)) 148 | for offset in range(lengths[j]) 149 | ]) 150 | 151 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) 152 | 153 | min_len = min([len(m) for m in mask_idcs]) 154 | for i, mask_idc in enumerate(mask_idcs): 155 | if len(mask_idc) > min_len: 156 | mask_idc = np.random.choice(mask_idc, min_len, replace=False) 157 | mask[i, mask_idc] = True 158 | 159 | mask = torch.from_numpy(mask).to(device) 160 | return mask 161 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/transducer/search/prefix_beam_search.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | from wenet.utils.common import log_add 5 | 6 | 7 | class Sequence(): 8 | 9 | __slots__ = {'hyp', 'score', 'cache'} 10 | 11 | def __init__( 12 | self, 13 | hyp: List[torch.Tensor], 14 | score, 15 | cache: List[torch.Tensor], 16 | ): 17 | self.hyp = hyp 18 | self.score = score 19 | self.cache = cache 20 | 21 | 22 | class PrefixBeamSearch(): 23 | 24 | def __init__(self, encoder, predictor, joint, ctc, blank): 25 | self.encoder = encoder 26 | self.predictor = predictor 27 | self.joint = joint 28 | self.ctc = ctc 29 | self.blank = blank 30 | 31 | def forward_decoder_one_step( 32 | self, encoder_x: torch.Tensor, pre_t: 
torch.Tensor, 33 | cache: List[torch.Tensor] 34 | ) -> Tuple[torch.Tensor, List[torch.Tensor]]: 35 | padding = torch.zeros(pre_t.size(0), 1, device=encoder_x.device) 36 | pre_t, new_cache = self.predictor.forward_step(pre_t.unsqueeze(-1), 37 | padding, cache) 38 | x = self.joint(encoder_x, pre_t) # [beam, 1, 1, vocab] 39 | x = x.log_softmax(dim=-1) 40 | return x, new_cache 41 | 42 | def prefix_beam_search(self, 43 | speech: torch.Tensor, 44 | speech_lengths: torch.Tensor, 45 | decoding_chunk_size: int = -1, 46 | beam_size: int = 5, 47 | num_decoding_left_chunks: int = -1, 48 | simulate_streaming: bool = False, 49 | ctc_weight: float = 0.3, 50 | transducer_weight: float = 0.7): 51 | """prefix beam search 52 | also see wenet.transducer.transducer.beam_search 53 | """ 54 | assert speech.shape[0] == speech_lengths.shape[0] 55 | assert decoding_chunk_size != 0 56 | device = speech.device 57 | batch_size = speech.shape[0] 58 | assert batch_size == 1 59 | 60 | # 1. Encoder 61 | encoder_out, _ = self.encoder( 62 | speech, speech_lengths, decoding_chunk_size, 63 | num_decoding_left_chunks) # (B, maxlen, encoder_dim) 64 | maxlen = encoder_out.size(1) 65 | 66 | ctc_probs = self.ctc.log_softmax(encoder_out).squeeze(0) 67 | beam_init: List[Sequence] = [] 68 | 69 | # 2. init beam using Sequence to save beam unit 70 | cache = self.predictor.init_state(1, method="zero", device=device) 71 | beam_init.append(Sequence(hyp=[self.blank], score=0.0, cache=cache)) 72 | # 3. start decoding (notice: we use breathwise first searching) 73 | # !!!! In this decoding method: one frame do not output multi units. !!!! 74 | # !!!! Experiments show that this strategy has little impact !!!! 75 | for i in range(maxlen): 76 | # 3.1 building input 77 | # decoder taking the last token to predict the next token 78 | input_hyp = [s.hyp[-1] for s in beam_init] 79 | input_hyp_tensor = torch.tensor(input_hyp, 80 | dtype=torch.int, 81 | device=device) 82 | # building statement from beam 83 | cache_batch = self.predictor.cache_to_batch( 84 | [s.cache for s in beam_init]) 85 | # build score tensor to do torch.add() function 86 | scores = torch.tensor([s.score for s in beam_init]).to(device) 87 | 88 | # 3.2 forward decoder 89 | logp, new_cache = self.forward_decoder_one_step( 90 | encoder_out[:, i, :].unsqueeze(1), 91 | input_hyp_tensor, 92 | cache_batch, 93 | ) # logp: (N, 1, 1, vocab_size) 94 | logp = logp.squeeze(1).squeeze(1) # logp: (N, vocab_size) 95 | new_cache = self.predictor.batch_to_cache(new_cache) 96 | 97 | # 3.3 shallow fusion for transducer score 98 | # and ctc score where we can also add the LM score 99 | logp = torch.log( 100 | torch.add(transducer_weight * torch.exp(logp), 101 | ctc_weight * torch.exp(ctc_probs[i].unsqueeze(0)))) 102 | 103 | # 3.4 first beam prune 104 | top_k_logp, top_k_index = logp.topk(beam_size) # (N, N) 105 | scores = torch.add(scores.unsqueeze(1), top_k_logp) 106 | 107 | # 3.5 generate new beam (N*N) 108 | beam_A = [] 109 | for j in range(len(beam_init)): 110 | # update seq 111 | base_seq = beam_init[j] 112 | for t in range(beam_size): 113 | # blank: only update the score 114 | if top_k_index[j, t] == self.blank: 115 | new_seq = Sequence(hyp=base_seq.hyp.copy(), 116 | score=scores[j, t].item(), 117 | cache=base_seq.cache) 118 | 119 | beam_A.append(new_seq) 120 | # other unit: update hyp score statement and last 121 | else: 122 | hyp_new = base_seq.hyp.copy() 123 | hyp_new.append(top_k_index[j, t].item()) 124 | new_seq = Sequence(hyp=hyp_new, 125 | score=scores[j, t].item(), 126 | 
cache=new_cache[j]) 127 | beam_A.append(new_seq) 128 | 129 | # 3.6 prefix fusion 130 | fusion_A = [beam_A[0]] 131 | for j in range(1, len(beam_A)): 132 | s1 = beam_A[j] 133 | if_do_append = True 134 | for t in range(len(fusion_A)): 135 | # notice: A_ can not fusion with A 136 | if s1.hyp == fusion_A[t].hyp: 137 | fusion_A[t].score = log_add( 138 | [fusion_A[t].score, s1.score]) 139 | if_do_append = False 140 | break 141 | if if_do_append: 142 | fusion_A.append(s1) 143 | 144 | # 4. second pruned 145 | fusion_A.sort(key=lambda x: x.score, reverse=True) 146 | beam_init = fusion_A[:beam_size] 147 | 148 | return beam_init, encoder_out 149 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/utils/ctc_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import List, Tuple 16 | 17 | import numpy as np 18 | 19 | import torch 20 | 21 | 22 | def remove_duplicates_and_blank(hyp: List[int], 23 | blank_id: int = 0) -> List[int]: 24 | new_hyp: List[int] = [] 25 | cur = 0 26 | while cur < len(hyp): 27 | if hyp[cur] != blank_id: 28 | new_hyp.append(hyp[cur]) 29 | prev = cur 30 | while cur < len(hyp) and hyp[cur] == hyp[prev]: 31 | cur += 1 32 | return new_hyp 33 | 34 | 35 | def replace_duplicates_with_blank(hyp: List[int], 36 | blank_id: int = 0) -> List[int]: 37 | new_hyp: List[int] = [] 38 | cur = 0 39 | while cur < len(hyp): 40 | new_hyp.append(hyp[cur]) 41 | prev = cur 42 | cur += 1 43 | while cur < len( 44 | hyp) and hyp[cur] == hyp[prev] and hyp[cur] != blank_id: 45 | new_hyp.append(blank_id) 46 | cur += 1 47 | return new_hyp 48 | 49 | 50 | def gen_ctc_peak_time(hyp: List[int], blank_id: int = 0) -> List[int]: 51 | times = [] 52 | cur = 0 53 | while cur < len(hyp): 54 | if hyp[cur] != blank_id: 55 | times.append(cur) 56 | prev = cur 57 | while cur < len(hyp) and hyp[cur] == hyp[prev]: 58 | cur += 1 59 | return times 60 | 61 | 62 | def gen_timestamps_from_peak( 63 | peaks: List[int], 64 | max_duration: float, 65 | frame_rate: float = 0.04, 66 | max_token_duration: float = 1.0, 67 | ) -> List[Tuple[float, float]]: 68 | """ 69 | Args: 70 | peaks: ctc peaks time stamp 71 | max_duration: max_duration of the sentence 72 | frame_rate: frame rate of every time stamp, in seconds 73 | max_token_duration: max duration of the token, in seconds 74 | Returns: 75 | list(start, end) of each token 76 | """ 77 | times = [] 78 | half_max = max_token_duration / 2 79 | for i in range(len(peaks)): 80 | if i == 0: 81 | start = max(0, peaks[0] * frame_rate - half_max) 82 | else: 83 | start = max((peaks[i - 1] + peaks[i]) / 2 * frame_rate, 84 | peaks[i] * frame_rate - half_max) 85 | 86 | if i == len(peaks) - 1: 87 | end = min(max_duration, peaks[-1] * frame_rate + half_max) 88 | else: 89 | end = min((peaks[i] + peaks[i + 1]) / 2 * frame_rate, 90 | peaks[i] * frame_rate + 
half_max) 91 | times.append((start, end)) 92 | return times 93 | 94 | 95 | def insert_blank(label, blank_id=0): 96 | """Insert blank token between every two label token.""" 97 | label = np.expand_dims(label, 1) 98 | blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id 99 | label = np.concatenate([blanks, label], axis=1) 100 | label = label.reshape(-1) 101 | label = np.append(label, label[0]) 102 | return label 103 | 104 | 105 | def force_align(ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> list: 106 | """ctc forced alignment. 107 | 108 | Args: 109 | torch.Tensor ctc_probs: hidden state sequence, 2d tensor (T, D) 110 | torch.Tensor y: id sequence tensor 1d tensor (L) 111 | int blank_id: blank symbol index 112 | Returns: 113 | torch.Tensor: alignment result 114 | """ 115 | ctc_probs = ctc_probs.cpu() 116 | y = y.cpu() 117 | y_insert_blank = insert_blank(y, blank_id) 118 | 119 | log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank))) 120 | log_alpha = log_alpha - float('inf') # log of zero 121 | state_path = torch.zeros((ctc_probs.size(0), len(y_insert_blank)), 122 | dtype=torch.int16) - 1 # state path 123 | 124 | # init start state 125 | log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] 126 | log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] 127 | 128 | for t in range(1, ctc_probs.size(0)): 129 | for s in range(len(y_insert_blank)): 130 | if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ 131 | s] == y_insert_blank[s - 2]: 132 | candidates = torch.tensor( 133 | [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]]) 134 | prev_state = [s, s - 1] 135 | else: 136 | candidates = torch.tensor([ 137 | log_alpha[t - 1, s], 138 | log_alpha[t - 1, s - 1], 139 | log_alpha[t - 1, s - 2], 140 | ]) 141 | prev_state = [s, s - 1, s - 2] 142 | log_alpha[ 143 | t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]] 144 | state_path[t, s] = prev_state[torch.argmax(candidates)] 145 | 146 | state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16) 147 | 148 | candidates = torch.tensor([ 149 | log_alpha[-1, len(y_insert_blank) - 1], 150 | log_alpha[-1, len(y_insert_blank) - 2] 151 | ]) 152 | final_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2] 153 | state_seq[-1] = final_state[torch.argmax(candidates)] 154 | for t in range(ctc_probs.size(0) - 2, -1, -1): 155 | state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] 156 | 157 | output_alignment = [] 158 | for t in range(0, ctc_probs.size(0)): 159 | output_alignment.append(y_insert_blank[state_seq[t, 0]]) 160 | 161 | return output_alignment 162 | 163 | 164 | def get_blank_id(configs, symbol_table): 165 | if 'ctc_conf' not in configs: 166 | configs['ctc_conf'] = {} 167 | 168 | if '' in symbol_table: 169 | if 'ctc_blank_id' in configs['ctc_conf']: 170 | assert configs['ctc_conf']['ctc_blank_id'] == symbol_table[ 171 | ''] 172 | else: 173 | configs['ctc_conf']['ctc_blank_id'] = symbol_table[''] 174 | else: 175 | assert 'ctc_blank_id' in configs[ 176 | 'ctc_conf'], "PLZ set ctc_blank_id in yaml" 177 | 178 | return configs, configs['ctc_conf']['ctc_blank_id'] 179 | -------------------------------------------------------------------------------- /wenet_asr_client/recognizer.py: -------------------------------------------------------------------------------- 1 | import grpc 2 | 3 | from lib import wenet_asr_pb2 4 | from lib import wenet_asr_pb2_grpc 5 | 6 | 7 | class Recognizer: 8 | def __init__(self, sample_rate: int, 9 | bit_depth: int = 16, 10 | update_duration: int = 10, 11 | puncting_len: int = 40) 
-> None: 12 | ''' 13 | 初始化音频相关参数。仅支持单声道,故不设channels参数。 14 | 15 | Parameters: 16 | sample_rate (int): 音频数据采样率。 17 | bit_depth (int): 音频数据位深度。 18 | update_duration (int): 更新currentSession间隔。用于伪流式逻辑。 19 | puncting_len: (int): puncting的自动更新长度。用于流式识别时的标点预测。 20 | 21 | ''' 22 | 23 | # self.config = wenet_asr_pb2.RecognitionConfig( 24 | # sample_rate_hertz=sample_rate, 25 | # ) 26 | self.sample_rate = sample_rate 27 | self.sampwidth = bit_depth / 8 if bit_depth else None 28 | self.update_duration = update_duration 29 | 30 | self.connected = False 31 | self.channel = None 32 | self.stub: wenet_asr_pb2_grpc.WenetASRStub = None 33 | 34 | self.streaming = False 35 | self.datastream = [] 36 | self.currentSession = b'' 37 | self.fixed = '' 38 | self.new = '' 39 | 40 | self.punctuation = True 41 | self.puncting = '' 42 | self.puncted = '' 43 | self.puncted_temp = '' 44 | self.puncfixed = '' 45 | self.puncting_len = puncting_len 46 | 47 | @property 48 | def config(self): 49 | return wenet_asr_pb2.RecognitionConfig( 50 | sample_rate_hertz=self.sample_rate, 51 | ) 52 | 53 | @property 54 | def result(self): 55 | if self.punctuation: 56 | return { 57 | 'fixed': self.puncfixed, 58 | 'puncted': self.puncted, 59 | 'new': self.new, 60 | 'text': self.puncfixed + self.puncted + self.new, 61 | } 62 | return { 63 | 'fixed': self.fixed, 64 | 'puncted': self.puncfixed + self.puncted, 65 | 'new': self.new, 66 | 'text': self.fixed + self.new 67 | } 68 | 69 | # Build connection between client and server. 70 | def connect(self, server_port) -> None: 71 | try: 72 | self.channel = grpc.insecure_channel(server_port) 73 | self.stub = wenet_asr_pb2_grpc.WenetASRStub(self.channel) 74 | self.connected = True 75 | print(f'Server ID: {self.stub.GetServerID(wenet_asr_pb2.Empty()).text}') 76 | except Exception as e: 77 | raise Exception(f'Error connecting to server {server_port}: {e}') 78 | 79 | # Stop the connection between client and server. 80 | def disconnect(self) -> None: 81 | self.connected = False 82 | self.stub = None 83 | self.channel = None 84 | 85 | def start_streaming(self, punctuation=True): 86 | self.streaming = True 87 | self.punctuation = punctuation 88 | 89 | def stop_streaming(self): 90 | self.streaming = False 91 | 92 | def reload_model(self, 93 | asr: bool = True, 94 | model: str = 'chinese', 95 | hotwords: str = None, 96 | context_score: int = 3, 97 | punctuation: bool = False, 98 | punc_model: str = 'pun_models'): 99 | ''' 100 | 重新加载模型。 101 | 102 | 参数: 103 | asr (bool): 是否重新加载asr模型。 104 | model (str): 所要加载的asr模型。 105 | hotwords (str): asr模型所要加载的热词增强文件。 106 | context_score (int): 热词增强中每个字的分数。 107 | punctuation (bool): 是否重新加载标点符号预测模型。 108 | punc_model (str): 所要加载的标点符号预测模型。 109 | 110 | ''' 111 | self.stub.ReloadModel(wenet_asr_pb2.ReloadModelRequest(asr=asr, 112 | model=model, 113 | hotwords='None' if hotwords is None else hotwords, 114 | context_score=context_score, 115 | punctuation=punctuation, 116 | punctuation_model=punc_model)) 117 | 118 | # Upload a new frame of data to server (and trigger recognition logic if streaming). 
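    # NOTE: pseudo-streaming works by re-recognizing the whole currentSession
    # buffer on every input() call; once the buffer exceeds update_duration
    # seconds its transcript is appended to `fixed` and the buffer is cleared,
    # and the pending un-punctuated text is punctuated and frozen once it
    # grows past puncting_len characters. A minimal client loop might look
    # like this (illustrative sketch; audio_chunks() is a hypothetical
    # generator of raw PCM bytes):
    #
    #     rec = Recognizer(sample_rate=16000)
    #     rec.connect('localhost:50051')
    #     rec.start_streaming()
    #     for chunk in audio_chunks():
    #         rec.input(chunk)
    #         print(rec.result['text'])
    #     rec.stop_streaming()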
119 | def input(self, data: bytes) -> None: 120 | ''' 121 | Append a chunk of data to the end of currentSession. During streaming recognition, the pseudo-streaming logic runs automatically after new data is appended. 122 | 123 | Parameters: 124 | data (bytes): Audio data to append. 125 | 126 | ''' 127 | self.datastream.append(data) 128 | self.currentSession += data 129 | 130 | # Pseudo-streaming logic 131 | if self.streaming: 132 | self.new = self.recognize() 133 | # Add punctuation in real time 134 | if self.punctuation and len(self.puncting + self.new): 135 | self.puncted = self.punct(self.puncting + self.new)[:-1] 136 | 137 | if len(self.currentSession) / self.sampwidth / self.sample_rate > self.update_duration: # / self.channels 138 | self.fixed += self.new 139 | self.puncting += self.new 140 | 141 | # Finalize punctuation 142 | if self.punctuation: 143 | # self.puncted = self.punct() 144 | if len(self.puncting) > self.puncting_len: 145 | self.puncfixed += self.punct(self.puncting)[:-1] 146 | self.puncted = '' 147 | self.puncting = '' 148 | 149 | self.new = '' 150 | self.currentSession = b'' 151 | 152 | 153 | # Conduct recognition logic. 154 | def recognize(self, data=None, punctuation=False) -> str: 155 | ''' 156 | Run a single speech recognition pass on the input audio. 157 | 158 | Parameters: 159 | data (bytes): Input audio data. If omitted (e.g. when streaming via the input method), the data in currentSession is used. 160 | punctuation (bool): For non-streaming recognition, whether to also run punctuation prediction. 161 | 162 | Returns: 163 | str: Result of this single recognition pass. 164 | 165 | ''' 166 | if data is None: 167 | data = self.currentSession 168 | if not len(data): 169 | return '' 170 | text = self.stub.Recognize(wenet_asr_pb2.RecognizeRequest(config=self.config, data=data)).transcript 171 | if punctuation: 172 | return self.stub.Punct(wenet_asr_pb2.TextMessage(text=text)).text 173 | return text 174 | 175 | def punct(self, text=None) -> str: 176 | ''' 177 | Run punctuation prediction on unpunctuated text. 178 | 179 | Parameters: 180 | text (str): Unpunctuated text. If omitted, the puncting buffer is used. 181 | 182 | Returns: 183 | str: Text with predicted punctuation. 184 | 185 | ''' 186 | if text is None: 187 | text = self.puncting 188 | if not len(text): 189 | return '' 190 | return self.stub.Punct(wenet_asr_pb2.TextMessage(text=text)).text 191 | 192 | 193 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/utils/executor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
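# NOTE: a minimal usage sketch of the Executor defined below, assuming the argument
# objects (model, optimizer, data loaders, configs, ...) are prepared exactly as in
# wenet/bin/train.py later in this dump; shown here only to make the call flow explicit.
#
#     executor = Executor(global_step=0)
#     executor.train(model, optimizer, scheduler, train_data_loader,
#                    cv_data_loader, writer, configs, scaler, group_join)  # one epoch
#     loss_dict = executor.cv(model, cv_data_loader, configs)              # cross validation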
14 | 15 | import copy 16 | import datetime 17 | import logging 18 | import sys 19 | from contextlib import nullcontext 20 | 21 | # if your python version < 3.7 use the below one 22 | # from contextlib import suppress as nullcontext 23 | import torch 24 | from wenet.utils.common import StepTimer 25 | 26 | from wenet.utils.train_utils import (wenet_join, batch_forward, batch_backward, 27 | update_parameter_and_lr, log_per_step, 28 | save_model) 29 | 30 | 31 | class Executor: 32 | 33 | def __init__(self, global_step: int = 0): 34 | self.step = global_step 35 | self.train_step_timer = None 36 | self.cv_step_timer = None 37 | 38 | def train(self, model, optimizer, scheduler, train_data_loader, 39 | cv_data_loader, writer, configs, scaler, group_join): 40 | ''' Train one epoch 41 | ''' 42 | if self.train_step_timer is None: 43 | self.train_step_timer = StepTimer(self.step) 44 | model.train() 45 | info_dict = copy.deepcopy(configs) 46 | logging.info('using accumulate grad, new batch size is {} times' 47 | ' larger than before'.format(info_dict['accum_grad'])) 48 | # A context manager to be used in conjunction with an instance of 49 | # torch.nn.parallel.DistributedDataParallel to be able to train 50 | # with uneven inputs across participating processes. 51 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 52 | model_context = model.join 53 | else: 54 | model_context = nullcontext 55 | 56 | with model_context(): 57 | for batch_idx, batch_dict in enumerate(train_data_loader): 58 | info_dict["tag"] = "TRAIN" 59 | info_dict["step"] = self.step 60 | info_dict["batch_idx"] = batch_idx 61 | if wenet_join(group_join, info_dict): 62 | break 63 | 64 | if batch_dict["target_lengths"].size(0) == 0: 65 | continue 66 | 67 | context = None 68 | # Disable gradient synchronizations across DDP processes. 69 | # Within this context, gradients will be accumulated on module 70 | # variables, which will later be synchronized. 71 | if info_dict.get("train_engine", "torch_ddp") == "torch_ddp" and \ 72 | (batch_idx + 1) % info_dict["accum_grad"] != 0: 73 | context = model.no_sync 74 | # Used for single gpu training and DDP gradient synchronization 75 | # processes. 
76 | else: 77 | context = nullcontext 78 | 79 | with context(): 80 | info_dict = batch_forward(model, batch_dict, scaler, 81 | info_dict) 82 | info_dict = batch_backward(model, scaler, info_dict) 83 | 84 | info_dict = update_parameter_and_lr(model, optimizer, 85 | scheduler, scaler, 86 | info_dict) 87 | save_interval = info_dict.get('save_interval', sys.maxsize) 88 | if self.step % save_interval == 0 and self.step != 0 \ 89 | and (batch_idx + 1) % info_dict["accum_grad"] == 0: 90 | loss_dict = self.cv(model, cv_data_loader, configs) 91 | model.train() 92 | info_dict.update({ 93 | "tag": 94 | "step_{}".format(self.step), 95 | "loss_dict": 96 | loss_dict, 97 | "save_time": 98 | datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'), 99 | "lr": 100 | optimizer.param_groups[0]['lr'] 101 | }) 102 | save_model(model, info_dict) 103 | log_per_step(writer, info_dict, timer=self.train_step_timer) 104 | self.step += 1 if (batch_idx + 105 | 1) % info_dict["accum_grad"] == 0 else 0 106 | 107 | 108 | def cv(self, model, cv_data_loader, configs): 109 | ''' Cross validation on 110 | ''' 111 | if self.cv_step_timer is None: 112 | self.cv_step_timer = StepTimer(0.0) 113 | else: 114 | self.cv_step_timer.last_iteration = 0.0 115 | model.eval() 116 | info_dict = copy.deepcopy(configs) 117 | num_seen_utts, loss_dict, total_acc = 1, {}, [] # avoid division by 0 118 | with torch.no_grad(): 119 | for batch_idx, batch_dict in enumerate(cv_data_loader): 120 | info_dict["tag"] = "CV" 121 | info_dict["step"] = self.step 122 | info_dict["batch_idx"] = batch_idx 123 | info_dict["cv_step"] = batch_idx 124 | 125 | num_utts = batch_dict["target_lengths"].size(0) 126 | if num_utts == 0: 127 | continue 128 | 129 | info_dict = batch_forward(model, batch_dict, None, info_dict) 130 | _dict = info_dict["loss_dict"] 131 | 132 | num_seen_utts += num_utts 133 | total_acc.append(_dict['th_accuracy'].item( 134 | ) if _dict['th_accuracy'] is not None else 0.0) 135 | for loss_name, loss_value in _dict.items(): 136 | if loss_value is not None and "loss" in loss_name \ 137 | and torch.isfinite(loss_value): 138 | loss_value = loss_value.item() 139 | loss_dict[loss_name] = loss_dict.get(loss_name, 0) + \ 140 | loss_value * num_utts 141 | 142 | log_per_step(writer=None, 143 | info_dict=info_dict, 144 | timer=self.cv_step_timer) 145 | for loss_name, loss_value in loss_dict.items(): 146 | loss_dict[loss_name] = loss_dict[loss_name] / num_seen_utts 147 | loss_dict["acc"] = sum(total_acc) / len(total_acc) 148 | return loss_dict 149 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/bin/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
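# NOTE: a hedged launch sketch for the training entry point below; flag names are
# inferred from the args.* attributes used in this file (--train_engine is defined
# here, --config and --model_dir are assumed to come from add_model_args), and the
# paths are illustrative only.
#
#     torchrun --nproc_per_node=4 wenet/bin/train.py \
#         --config conf/train_conformer.yaml \
#         --model_dir exp/conformer \
#         --train_engine torch_ddp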
14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import datetime 19 | import logging 20 | import os 21 | import torch 22 | import yaml 23 | 24 | import torch.distributed as dist 25 | 26 | from torch.distributed.elastic.multiprocessing.errors import record 27 | 28 | from wenet.utils.executor import Executor 29 | from wenet.utils.config import override_config 30 | from wenet.utils.init_model import init_model 31 | from wenet.utils.init_tokenizer import init_tokenizer 32 | from wenet.utils.train_utils import ( 33 | add_model_args, add_dataset_args, add_ddp_args, add_deepspeed_args, 34 | add_trace_args, init_distributed, init_dataset_and_dataloader, 35 | check_modify_and_save_config, init_optimizer_and_scheduler, 36 | trace_and_print_model, wrap_cuda_model, init_summarywriter, save_model, 37 | log_per_epoch) 38 | 39 | 40 | def get_args(): 41 | parser = argparse.ArgumentParser(description='training your network') 42 | parser.add_argument('--train_engine', 43 | default='torch_ddp', 44 | choices=['torch_ddp', 'deepspeed'], 45 | help='Engine for parallel training') 46 | parser = add_model_args(parser) 47 | parser = add_dataset_args(parser) 48 | parser = add_ddp_args(parser) 49 | parser = add_deepspeed_args(parser) 50 | parser = add_trace_args(parser) 51 | args = parser.parse_args() 52 | if args.train_engine == "deepspeed": 53 | args.deepspeed = True 54 | assert args.deepspeed_config is not None 55 | return args 56 | 57 | 58 | # NOTE(xcsong): On worker errors, this record tool will summarize the 59 | # details of the error (e.g. time, rank, host, pid, traceback, etc). 60 | @record 61 | def main(): 62 | args = get_args() 63 | logging.basicConfig(level=logging.DEBUG, 64 | format='%(asctime)s %(levelname)s %(message)s') 65 | 66 | # Set random seed 67 | torch.manual_seed(777) 68 | 69 | # Read config 70 | with open(args.config, 'r') as fin: 71 | configs = yaml.load(fin, Loader=yaml.FullLoader) 72 | if len(args.override_config) > 0: 73 | configs = override_config(configs, args.override_config) 74 | 75 | # init tokenizer 76 | tokenizer = init_tokenizer(configs) 77 | 78 | # Init env for ddp OR deepspeed 79 | _, _, rank = init_distributed(args) 80 | 81 | # Get dataset & dataloader 82 | train_dataset, cv_dataset, train_data_loader, cv_data_loader = \ 83 | init_dataset_and_dataloader(args, configs, tokenizer) 84 | 85 | # Do some sanity checks and save config to args.model_dir 86 | configs = check_modify_and_save_config(args, configs, 87 | tokenizer.symbol_table) 88 | 89 | # Init asr model from configs 90 | model, configs = init_model(args, configs) 91 | 92 | # Check model is jitable & print model architectures 93 | trace_and_print_model(args, model) 94 | 95 | # Tensorboard summary 96 | writer = init_summarywriter(args) 97 | 98 | # Dispatch model from cpu to gpu 99 | model, device = wrap_cuda_model(args, model) 100 | 101 | # Get optimizer & scheduler 102 | model, optimizer, scheduler = init_optimizer_and_scheduler( 103 | args, configs, model) 104 | 105 | # Save checkpoints 106 | save_model(model, 107 | info_dict={ 108 | "save_time": 109 | datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'), 110 | "tag": 111 | "init", 112 | **configs 113 | }) 114 | 115 | # Get executor 116 | tag = configs["init_infos"].get("tag", "init") 117 | executor = Executor(global_step=configs["init_infos"].get('step', -1) + 118 | int("step_" in tag)) 119 | 120 | # Init scaler, used for pytorch amp mixed precision training 121 | scaler = None 122 | if args.use_amp: 123 | scaler = torch.cuda.amp.GradScaler()
124 | 125 | # Start training loop 126 | start_epoch = configs["init_infos"].get('epoch', 0) + int("epoch_" in tag) 127 | # if save_interval in configs, steps mode else epoch mode 128 | end_epoch = configs.get('max_epoch', 129 | 100) if "save_interval" not in configs else 1 130 | assert start_epoch <= end_epoch 131 | configs.pop("init_infos", None) 132 | final_epoch = None 133 | for epoch in range(start_epoch, end_epoch): 134 | configs['epoch'] = epoch 135 | 136 | lr = optimizer.param_groups[0]['lr'] 137 | logging.info('Epoch {} TRAIN info lr {} rank {}'.format( 138 | epoch, lr, rank)) 139 | 140 | dist.barrier( 141 | ) # NOTE(xcsong): Ensure all ranks start Train at the same time. 142 | # NOTE(xcsong): Why we need a new group? see `train_utils.py::wenet_join` 143 | group_join = dist.new_group( 144 | backend="gloo", timeout=datetime.timedelta(seconds=args.timeout)) 145 | executor.train(model, optimizer, scheduler, train_data_loader, 146 | cv_data_loader, writer, configs, scaler, group_join) 147 | dist.destroy_process_group(group_join) 148 | 149 | dist.barrier( 150 | ) # NOTE(xcsong): Ensure all ranks start CV at the same time. 151 | loss_dict = executor.cv(model, cv_data_loader, configs) 152 | 153 | lr = optimizer.param_groups[0]['lr'] 154 | logging.info('Epoch {} CV info lr {} cv_loss {} rank {} acc {}'.format( 155 | epoch, lr, loss_dict["loss"], rank, loss_dict["acc"])) 156 | info_dict = { 157 | 'epoch': epoch, 158 | 'lr': lr, 159 | 'step': executor.step, 160 | 'save_time': datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'), 161 | 'tag': "epoch_{}".format(epoch), 162 | 'loss_dict': loss_dict, 163 | **configs 164 | } 165 | log_per_epoch(writer, info_dict=info_dict) 166 | save_model(model, info_dict=info_dict) 167 | 168 | final_epoch = epoch 169 | 170 | if final_epoch is not None and rank == 0: 171 | final_model_path = os.path.join(args.model_dir, 'final.pt') 172 | os.remove(final_model_path) if os.path.exists( 173 | final_model_path) else None 174 | os.symlink('{}.pt'.format(final_epoch), final_model_path) 175 | writer.close() 176 | 177 | 178 | if __name__ == '__main__': 179 | main() 180 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Wenet Community. (authors: Binbin Zhang) 2 | # 2023 Wenet Community. (authors: Dinghao Zhou) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
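# NOTE: Dataset() below is driven entirely by a `conf` dict; a minimal sketch for
# raw (non-shard) data, assuming key names exactly as read by the conf.get(...)
# calls in this file -- the concrete values and file names are illustrative only.
#
#     conf = {
#         'fbank_conf': {},                                    # fbank feature defaults
#         'spec_aug': True, 'spec_aug_conf': {},
#         'shuffle': True, 'shuffle_conf': {'shuffle_size': 1500},
#         'sort': True, 'sort_conf': {'sort_size': 500},
#         'batch_conf': {'batch_type': 'static', 'batch_size': 16},
#     }
#     dataset = Dataset('raw', 'data.list', tokenizer=tokenizer, conf=conf)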
15 | 16 | from functools import partial 17 | import sys 18 | from typing import Optional 19 | from wenet.dataset import processor 20 | from wenet.dataset.datapipes import (WenetRawDatasetSource, 21 | WenetTarShardDatasetSource) 22 | from wenet.text.base_tokenizer import BaseTokenizer 23 | from wenet.utils.file_utils import read_symbol_table 24 | 25 | 26 | def Dataset(data_type, 27 | data_list_file, 28 | tokenizer: Optional[BaseTokenizer] = None, 29 | conf=None, 30 | partition=True): 31 | """ Construct dataset from arguments 32 | 33 | We have two shuffle stage in the Dataset. The first is global 34 | shuffle at shards tar/raw file level. The second is global shuffle 35 | at training samples level. 36 | 37 | Args: 38 | data_type(str): raw/shard 39 | tokenizer (BaseTokenizer or None): tokenizer to tokenize 40 | partition(bool): whether to do data partition in terms of rank 41 | """ 42 | assert conf is not None 43 | assert data_type in ['raw', 'shard'] 44 | # cycle dataset 45 | cycle = conf.get('cycle', 1) 46 | # stage1 shuffle: source 47 | list_shuffle = conf.get('list_shuffle', True) 48 | list_shuffle_size = sys.maxsize 49 | if list_shuffle: 50 | list_shuffle_conf = conf.get('list_shuffle_conf', {}) 51 | list_shuffle_size = list_shuffle_conf.get('shuffle_size', 52 | list_shuffle_size) 53 | if data_type == 'raw': 54 | dataset = WenetRawDatasetSource(data_list_file, 55 | partition=partition, 56 | shuffle=list_shuffle, 57 | shuffle_size=list_shuffle_size, 58 | cycle=cycle) 59 | dataset = dataset.map(processor.parse_json) 60 | else: 61 | dataset = WenetTarShardDatasetSource(data_list_file, 62 | partition=partition, 63 | shuffle=list_shuffle, 64 | shuffle_size=list_shuffle_size, 65 | cycle=cycle) 66 | dataset = dataset.map_ignore_error(processor.decode_wav) 67 | 68 | singal_channel_conf = conf.get('singal_channel_conf', {}) 69 | dataset = dataset.map(partial(processor.singal_channel, **singal_channel_conf)) 70 | 71 | speaker_conf = conf.get('speaker_conf', None) 72 | if speaker_conf is not None: 73 | assert 'speaker_table_path' in speaker_conf 74 | speaker_table = read_symbol_table(speaker_conf['speaker_table_path']) 75 | dataset = dataset.map( 76 | partial(processor.parse_speaker, speaker_dict=speaker_table)) 77 | 78 | if tokenizer is not None: 79 | dataset = dataset.map(partial(processor.tokenize, tokenizer=tokenizer)) 80 | 81 | filter_conf = conf.get('filter_conf', {}) 82 | dataset = dataset.filter(partial(processor.filter, **filter_conf)) 83 | 84 | resample_conf = conf.get('resample_conf', {}) 85 | dataset = dataset.map(partial(processor.resample, **resample_conf)) 86 | 87 | speed_perturb = conf.get('speed_perturb', False) 88 | if speed_perturb: 89 | dataset = dataset.map(partial(processor.speed_perturb)) 90 | 91 | feats_type = conf.get('feats_type', 'fbank') 92 | assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram'] 93 | if feats_type == 'fbank': 94 | fbank_conf = conf.get('fbank_conf', {}) 95 | dataset = dataset.map(partial(processor.compute_fbank, **fbank_conf)) 96 | elif feats_type == 'mfcc': 97 | mfcc_conf = conf.get('mfcc_conf', {}) 98 | dataset = dataset.map(partial(processor.compute_mfcc, **mfcc_conf)) 99 | elif feats_type == 'log_mel_spectrogram': 100 | log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {}) 101 | dataset = dataset.map( 102 | partial(processor.compute_log_mel_spectrogram, 103 | **log_mel_spectrogram_conf)) 104 | spec_aug = conf.get('spec_aug', True) 105 | spec_sub = conf.get('spec_sub', False) 106 | spec_trim = conf.get('spec_trim', False) 107 | 
if spec_aug: 108 | spec_aug_conf = conf.get('spec_aug_conf', {}) 109 | dataset = dataset.map(partial(processor.spec_aug, **spec_aug_conf)) 110 | if spec_sub: 111 | spec_sub_conf = conf.get('spec_sub_conf', {}) 112 | dataset = dataset.map(partial(processor.spec_sub, **spec_sub_conf)) 113 | if spec_trim: 114 | spec_trim_conf = conf.get('spec_trim_conf', {}) 115 | dataset = dataset.map(partial(processor.spec_trim, **spec_trim_conf)) 116 | 117 | language_conf = conf.get('language_conf', {"limited_langs": ['zh', 'en']}) 118 | dataset = dataset.map(partial(processor.detect_language, **language_conf)) 119 | dataset = dataset.map(processor.detect_task) 120 | 121 | shuffle = conf.get('shuffle', True) 122 | if shuffle: 123 | shuffle_conf = conf.get('shuffle_conf', {}) 124 | dataset = dataset.shuffle(buffer_size=shuffle_conf['shuffle_size']) 125 | 126 | sort = conf.get('sort', True) 127 | if sort: 128 | sort_conf = conf.get('sort_conf', {}) 129 | dataset = dataset.sort(buffer_size=sort_conf['sort_size'], 130 | key_func=processor.sort_by_feats) 131 | 132 | batch_conf = conf.get('batch_conf', {}) 133 | batch_type = batch_conf.get('batch_type', 'static') 134 | assert batch_type in ['static', 'bucket', 'dynamic'] 135 | if batch_type == 'static': 136 | assert 'batch_size' in batch_conf 137 | batch_size = batch_conf.get('batch_size', 16) 138 | dataset = dataset.batch(batch_size, wrapper_class=processor.padding) 139 | elif batch_type == 'bucket': 140 | assert 'bucket_boundaries' in batch_conf 141 | assert 'bucket_batch_sizes' in batch_conf 142 | dataset = dataset.bucket_by_sequence_length( 143 | processor.feats_length_fn, 144 | batch_conf['bucket_boundaries'], 145 | batch_conf['bucket_batch_sizes'], 146 | wrapper_class=processor.padding) 147 | else: 148 | max_frames_in_batch = batch_conf.get('max_frames_in_batch', 12000) 149 | dataset = dataset.dynamic_batch( 150 | processor.DynamicBatchWindow(max_frames_in_batch), 151 | wrapper_class=processor.padding, 152 | ) 153 | 154 | return dataset 155 | -------------------------------------------------------------------------------- /wenet_asr_server/wenet/utils/init_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Binbin Zhang (binbzha@qq.com) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
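# NOTE: init_model() below resolves component classes from plain strings in the YAML
# config via the WENET_*_CLASSES registries; an illustrative mapping (assumed config
# fragment on the left, class chosen by this file on the right):
#
#     encoder: conformer      -> ConformerEncoder
#     decoder: bitransformer  -> BiTransformerDecoder
#     model: asr_model        -> ASRModel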
14 | 15 | import torch 16 | 17 | from wenet.k2.model import K2Model 18 | from wenet.paraformer.cif import Cif 19 | from wenet.paraformer.layers import SanmDecoder, SanmEncoder 20 | from wenet.paraformer.paraformer import Paraformer, Predictor 21 | from wenet.transducer.joint import TransducerJoint 22 | from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor, 23 | RNNPredictor) 24 | from wenet.transducer.transducer import Transducer 25 | from wenet.transformer.asr_model import ASRModel 26 | from wenet.transformer.cmvn import GlobalCMVN 27 | from wenet.transformer.ctc import CTC 28 | from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder 29 | from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder 30 | from wenet.branchformer.encoder import BranchformerEncoder 31 | from wenet.e_branchformer.encoder import EBranchformerEncoder 32 | from wenet.squeezeformer.encoder import SqueezeformerEncoder 33 | from wenet.efficient_conformer.encoder import EfficientConformerEncoder 34 | from wenet.ctl_model.encoder import DualTransformerEncoder, DualConformerEncoder 35 | from wenet.ctl_model.asr_model_ctl import CTLModel 36 | from wenet.whisper.whisper import Whisper 37 | from wenet.utils.cmvn import load_cmvn 38 | from wenet.utils.checkpoint import load_checkpoint, load_trained_modules 39 | 40 | WENET_ENCODER_CLASSES = { 41 | "transformer": TransformerEncoder, 42 | "conformer": ConformerEncoder, 43 | "squeezeformer": SqueezeformerEncoder, 44 | "efficientConformer": EfficientConformerEncoder, 45 | "branchformer": BranchformerEncoder, 46 | "e_branchformer": EBranchformerEncoder, 47 | "dual_transformer": DualTransformerEncoder, 48 | "dual_conformer": DualConformerEncoder, 49 | 'sanm_encoder': SanmEncoder, 50 | } 51 | 52 | WENET_DECODER_CLASSES = { 53 | "transformer": TransformerDecoder, 54 | "bitransformer": BiTransformerDecoder, 55 | "sanm_decoder": SanmDecoder, 56 | } 57 | 58 | WENET_CTC_CLASSES = { 59 | "ctc": CTC, 60 | } 61 | 62 | WENET_PREDICTOR_CLASSES = { 63 | "rnn": RNNPredictor, 64 | "embedding": EmbeddingPredictor, 65 | "conv": ConvPredictor, 66 | "cif_predictor": Cif, 67 | "paraformer_predictor": Predictor, 68 | } 69 | 70 | WENET_JOINT_CLASSES = { 71 | "transducer_joint": TransducerJoint, 72 | } 73 | 74 | WENET_MODEL_CLASSES = { 75 | "asr_model": ASRModel, 76 | "ctl_model": CTLModel, 77 | "whisper": Whisper, 78 | "k2_model": K2Model, 79 | "transducer": Transducer, 80 | 'paraformer': Paraformer, 81 | } 82 | 83 | 84 | def init_model(args, configs): 85 | 86 | # TODO(xcsong): Forcefully read the 'cmvn' attribute. 
87 | if configs.get('cmvn', None) == 'global_cmvn': 88 | mean, istd = load_cmvn(configs['cmvn_conf']['cmvn_file'], 89 | configs['cmvn_conf']['is_json_cmvn']) 90 | global_cmvn = GlobalCMVN( 91 | torch.from_numpy(mean).float(), 92 | torch.from_numpy(istd).float()) 93 | else: 94 | global_cmvn = None 95 | 96 | input_dim = configs['input_dim'] 97 | vocab_size = configs['output_dim'] 98 | 99 | encoder_type = configs.get('encoder', 'conformer') 100 | decoder_type = configs.get('decoder', 'bitransformer') 101 | ctc_type = configs.get('ctc', 'ctc') 102 | 103 | encoder = WENET_ENCODER_CLASSES[encoder_type]( 104 | input_dim, 105 | global_cmvn=global_cmvn, 106 | **configs['encoder_conf'], 107 | **configs['encoder_conf']['efficient_conf'] 108 | if 'efficient_conf' in configs['encoder_conf'] else {}) 109 | 110 | decoder = WENET_DECODER_CLASSES[decoder_type](vocab_size, 111 | encoder.output_size(), 112 | **configs['decoder_conf']) 113 | 114 | ctc = WENET_CTC_CLASSES[ctc_type]( 115 | vocab_size, 116 | encoder.output_size(), 117 | blank_id=configs['ctc_conf']['ctc_blank_id'] 118 | if 'ctc_conf' in configs else 0) 119 | 120 | model_type = configs.get('model', 'asr_model') 121 | if model_type == "transducer": 122 | predictor_type = configs.get('predictor', 'rnn') 123 | joint_type = configs.get('joint', 'transducer_joint') 124 | predictor = WENET_PREDICTOR_CLASSES[predictor_type]( 125 | vocab_size, **configs['predictor_conf']) 126 | joint = WENET_JOINT_CLASSES[joint_type](vocab_size, 127 | **configs['joint_conf']) 128 | model = WENET_MODEL_CLASSES[model_type]( 129 | vocab_size=vocab_size, 130 | blank=0, 131 | predictor=predictor, 132 | encoder=encoder, 133 | attention_decoder=decoder, 134 | joint=joint, 135 | ctc=ctc, 136 | special_tokens=configs.get('tokenizer_conf', 137 | {}).get('special_tokens', None), 138 | **configs['model_conf']) 139 | elif model_type == 'paraformer': 140 | predictor_type = configs.get('predictor', 'cif') 141 | predictor = WENET_PREDICTOR_CLASSES[predictor_type]( 142 | **configs['predictor_conf']) 143 | model = WENET_MODEL_CLASSES[model_type]( 144 | vocab_size=vocab_size, 145 | encoder=encoder, 146 | decoder=decoder, 147 | predictor=predictor, 148 | ctc=ctc, 149 | **configs['model_conf'], 150 | special_tokens=configs.get('tokenizer_conf', 151 | {}).get('special_tokens', None), 152 | ) 153 | else: 154 | model = WENET_MODEL_CLASSES[model_type]( 155 | vocab_size=vocab_size, 156 | encoder=encoder, 157 | decoder=decoder, 158 | ctc=ctc, 159 | special_tokens=configs.get('tokenizer_conf', 160 | {}).get('special_tokens', None), 161 | **configs['model_conf']) 162 | 163 | # If specify checkpoint, load some info from checkpoint 164 | if hasattr(args, 'checkpoint') and args.checkpoint is not None: 165 | infos = load_checkpoint(model, args.checkpoint) 166 | elif hasattr(args, 'enc_init') and args.enc_init is not None: 167 | infos = load_trained_modules(model, args) 168 | else: 169 | infos = {} 170 | configs["init_infos"] = infos 171 | print(configs) 172 | 173 | # Tie emb.weight to decoder.output_layer.weight 174 | if model.decoder.tie_word_embedding: 175 | model.decoder.tie_or_clone_weights(jit_mode=args.jit) 176 | 177 | return model, configs 178 | --------------------------------------------------------------------------------
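A minimal pseudo-streaming client sketch built on wenet_asr_client/recognizer.py above, assuming it is run from the wenet_asr_client directory against an already running WeNet ASR gRPC server; the server address, wav path, and chunk size below are illustrative assumptions, not values taken from the repository.

    import wave

    from recognizer import Recognizer

    # 16 kHz, 16-bit, mono PCM wav assumed (the Recognizer only supports mono audio).
    with wave.open('audio/example.wav', 'rb') as f:
        rec = Recognizer(sample_rate=f.getframerate(),
                         bit_depth=8 * f.getsampwidth())
        rec.connect('127.0.0.1:50051')          # address of the running server (assumed)
        rec.start_streaming(punctuation=True)
        chunk = f.getframerate() // 2           # feed roughly 0.5 s of audio per call
        data = f.readframes(chunk)
        while data:
            rec.input(data)                     # runs the pseudo-streaming logic
            print(rec.result['text'])           # finalized text plus the newest hypothesis
            data = f.readframes(chunk)
        rec.stop_streaming()
        rec.disconnect()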