├── asr ├── wenet │ ├── cli │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── transcribe.py │ │ ├── paraformer_model.py │ │ └── hub.py │ ├── espnet │ │ └── __init__.py │ ├── k2 │ │ └── __init__.py │ ├── text │ │ ├── __init__.py │ │ ├── base_tokenizer.py │ │ ├── bpe_tokenizer.py │ │ ├── paraformer_tokenizer.py │ │ ├── hugging_face_tokenizer.py │ │ ├── tokenize_utils.py │ │ ├── char_tokenizer.py │ │ ├── rev_bpe_tokenizer.py │ │ └── whisper_tokenizer.py │ ├── utils │ │ ├── __init__.py │ │ ├── config.py │ │ ├── file_utils.py │ │ ├── class_utils.py │ │ ├── init_tokenizer.py │ │ ├── cmvn.py │ │ └── ctc_utils.py │ ├── dataset │ │ ├── __init__.py │ │ └── audio_utils.py │ ├── paraformer │ │ ├── __init__.py │ │ ├── embedding.py │ │ └── subsampling.py │ ├── transducer │ │ ├── __init__.py │ │ └── search │ │ │ └── greedy_search.py │ ├── transformer │ │ ├── __init__.py │ │ ├── functional.py │ │ ├── mish.py │ │ ├── swish.py │ │ ├── cmvn.py │ │ ├── label_smoothing_loss.py │ │ ├── positionwise_feed_forward.py │ │ ├── ctc.py │ │ ├── convolution.py │ │ └── context_adaptor.py │ ├── whisper │ │ ├── __init__.py │ │ └── whisper.py │ ├── branchformer │ │ └── __init__.py │ ├── squeezeformer │ │ ├── __init__.py │ │ ├── conv2d.py │ │ ├── positionwise_feed_forward.py │ │ └── encoder_layer.py │ ├── efficient_conformer │ │ ├── __init__.py │ │ ├── subsampling.py │ │ └── convolution.py │ ├── transducer_espnet │ │ ├── __init__.py │ │ ├── abs_decoder.py │ │ ├── utils.py │ │ ├── joint_network.py │ │ ├── bitransducer.py │ │ ├── transducer_decoder_interface.py │ │ └── scorer_interface.py │ ├── __init__.py │ ├── README.md │ ├── bin │ │ ├── average_model_fixed_list.py │ │ ├── export_ipex.py │ │ └── ctc_align.py │ ├── finetune │ │ └── lora │ │ │ ├── utils.py │ │ │ └── attention.py │ ├── onmt_translate │ │ └── penalties.py │ └── ssl │ │ ├── wav2vec2 │ │ └── quantizer.py │ │ └── bestrq │ │ └── mask.py ├── requirements.txt ├── wer_evaluation │ ├── RESULTS.md │ ├── README.md │ ├── aggregate_scoring.py │ └── scoring_commands.py ├── app.py └── README.md ├── diarization ├── requirements.txt ├── data │ └── database.yaml ├── infer_pyannote3.0.py ├── train_pyannote3.0.py ├── assign_words2speakers.py └── README.md ├── resources ├── logo_purple.png └── logo_white.png ├── pyproject.toml ├── Dockerfile ├── Dockerfile.arm64 ├── examples └── stream.py └── .gitignore /asr/wenet/cli/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asr/wenet/espnet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asr/wenet/k2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asr/wenet/text/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asr/wenet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asr/wenet/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/asr/wenet/paraformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asr/wenet/transducer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asr/wenet/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asr/wenet/whisper/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asr/wenet/branchformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asr/wenet/squeezeformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asr/wenet/efficient_conformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asr/wenet/transducer_espnet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /diarization/requirements.txt: -------------------------------------------------------------------------------- 1 | pyannote.audio==3.3.1 2 | intervaltree==3.1.0 3 | -------------------------------------------------------------------------------- /resources/logo_purple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/revdotcom/reverb/HEAD/resources/logo_purple.png -------------------------------------------------------------------------------- /resources/logo_white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/revdotcom/reverb/HEAD/resources/logo_white.png -------------------------------------------------------------------------------- /asr/requirements.txt: -------------------------------------------------------------------------------- 1 | torchaudio==2.2.2 2 | torch==2.2.2 3 | openai-whisper 4 | typeguard==2.* 5 | wandb 6 | pyyaml 7 | numpy<2 8 | GitPython 9 | -------------------------------------------------------------------------------- /asr/wenet/__init__.py: -------------------------------------------------------------------------------- 1 | from wenet.cli.reverb import ( 2 | download_model, 3 | get_available_models, 4 | load_model, 5 | ReverbASR 6 | ) 7 | -------------------------------------------------------------------------------- /asr/wenet/transformer/functional.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Script provides functional interface for Mish activation function. 3 | ''' 4 | 5 | # import pytorch 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | @torch.jit.script 10 | def mish(input): 11 | ''' 12 | Applies the mish function element-wise: 13 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) 14 | See additional documentation for mish class. 
15 | ''' 16 | return input * torch.tanh(F.softplus(input)) 17 | 18 | -------------------------------------------------------------------------------- /diarization/data/database.yaml: -------------------------------------------------------------------------------- 1 | Databases: 2 | audiodb: 3 | - /path/to/audio/{uri}.wav 4 | 5 | Protocols: 6 | audiodb: 7 | SpeakerDiarization: 8 | train_protocol: 9 | scope: file 10 | train: 11 | uri: /path/to/train/all.uri 12 | annotation: /path/to/train/all.rttm 13 | annotated: /path/to/train/all.uem 14 | development: 15 | uri: /path/to/dev/all.uri 16 | annotation: /path/to/dev/all.rttm 17 | annotated: /path/to/dev/all.uem -------------------------------------------------------------------------------- /asr/wenet/paraformer/embedding.py: -------------------------------------------------------------------------------- 1 | from wenet.transformer.embedding import WhisperPositionalEncoding 2 | 3 | 4 | class ParaformerPositinoalEncoding(WhisperPositionalEncoding): 5 | """ Sinusoids position encoding used in paraformer.encoder 6 | """ 7 | 8 | def __init__(self, 9 | depth: int, 10 | d_model: int, 11 | dropout_rate: float = 0.1, 12 | max_len: int = 1500): 13 | super().__init__(depth, dropout_rate, max_len) 14 | self.xscale = d_model**0.5 15 | -------------------------------------------------------------------------------- /asr/wenet/transducer_espnet/abs_decoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from abc import abstractmethod 3 | from typing import Tuple 4 | 5 | import torch 6 | 7 | from wenet.transducer_espnet.scorer_interface import ScorerInterface 8 | 9 | 10 | class AbsDecoder(torch.nn.Module, ScorerInterface, ABC): 11 | @abstractmethod 12 | def forward( 13 | self, 14 | hs_pad: torch.Tensor, 15 | hlens: torch.Tensor, 16 | ys_in_pad: torch.Tensor, 17 | ys_in_lens: torch.Tensor, 18 | ) -> Tuple[torch.Tensor, torch.Tensor]: 19 | raise NotImplementedError 20 | -------------------------------------------------------------------------------- /asr/wenet/cli/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Generator 2 | 3 | 4 | def hyps_to_ctm( 5 | audio_name: str, 6 | path: list[dict[str, Any]] 7 | ) -> Generator[str, None, None]: 8 | """Convert a given set of decode results for a single audio file into CTM lines.""" 9 | for line in path: 10 | start_seconds = line['start_time_ms'] / 1000 11 | duration_seconds = line['end_time_ms'] / 1000 - start_seconds 12 | ctm_line = f"{audio_name} 0 {start_seconds:.2f} {duration_seconds:.2f} {line['word']} {line['confidence']:.2f}" 13 | yield ctm_line 14 | 15 | 16 | def hyps_to_txt( 17 | path: list[dict[str, Any]] 18 | ) -> Generator[str, None, None]: 19 | """Convert a given set of decode results for a single audio file into CTM lines.""" 20 | for line in path: 21 | yield line['word'] 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "rev-reverb" 7 | version = "0.1.0" 8 | dynamic = ["dependencies"] 9 | requires-python = ">=3.10" 10 | authors = [ 11 | {name = "Rev Speechteam", email = "speechteam@rev.com"}, 12 | ] 13 | maintainers = [ 14 | {name = "Rev Speechteam", email = "speechteam@rev.com"}, 15 | ] 16 | description = "A 
simplified python packge to interact with the reverb models" 17 | readme = "README.md" 18 | license = {file = "LICENSE"} 19 | keywords = ["asr", "reverb", "rev", "diarization"] 20 | classifiers = [ 21 | "Development Status :: 5 - Production/Stable", 22 | "Programming Language :: Python :: 3.10" 23 | ] 24 | 25 | [tool.setuptools.package-dir] 26 | wenet = "asr/wenet" 27 | 28 | [tool.setuptools.dynamic] 29 | dependencies = {file = ["asr/requirements.txt"]} 30 | 31 | [project.scripts] 32 | reverb = "wenet.bin.recognize_wav:main" 33 | -------------------------------------------------------------------------------- /asr/wenet/transformer/mish.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Applies the mish function element-wise: 3 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) 4 | ''' 5 | 6 | # import pytorch 7 | from torch import nn 8 | 9 | # import activation functions 10 | import wenet.transformer.functional as Func 11 | 12 | class Mish(nn.Module): 13 | ''' 14 | Applies the mish function element-wise: 15 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) 16 | 17 | Shape: 18 | - Input: (N, *) where * means, any number of additional 19 | dimensions 20 | - Output: (N, *), same shape as the input 21 | 22 | Examples: 23 | >>> m = Mish() 24 | >>> input = torch.randn(2) 25 | >>> output = m(input) 26 | 27 | ''' 28 | def __init__(self): 29 | ''' 30 | Init method. 31 | ''' 32 | super().__init__() 33 | 34 | def forward(self, input): 35 | ''' 36 | Forward pass of the function. 37 | ''' 38 | return Func.mish(input) 39 | -------------------------------------------------------------------------------- /asr/wenet/transformer/swish.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) 2 | # 2020 Northwestern Polytechnical University (Pengcheng Guo) 3 | # 2020 Mobvoi Inc (Binbin Zhang) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Swish() activation function for Conformer.""" 17 | 18 | import torch 19 | 20 | 21 | class Swish(torch.nn.Module): 22 | """Construct an Swish object.""" 23 | 24 | def forward(self, x: torch.Tensor) -> torch.Tensor: 25 | """Return Swish activation function.""" 26 | return x * torch.sigmoid(x) 27 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-runtime 2 | 3 | ENV PYTHONUNBUFFERED=1 4 | 5 | RUN apt-get update && \ 6 | apt-get install -y --no-install-recommends \ 7 | git \ 8 | git-lfs \ 9 | locales && \ 10 | rm -rf /var/lib/apt/lists/* 11 | 12 | RUN echo "en_US.UTF-8 UTF-8" > /etc/locale.gen && \ 13 | locale-gen en_US.UTF-8 && \ 14 | update-locale LANG=en_US.UTF-8 15 | ENV LANG en_US.UTF-8 16 | ENV LANGUAGE en_US:en 17 | ENV LC_ALL en_US.UTF-8 18 | 19 | WORKDIR /workspace 20 | COPY . 
/workspace/ 21 | 22 | ARG HUGGINGFACE_ACCESS_TOKEN 23 | ENV HUGGINGFACE_ACCESS_TOKEN=${HUGGINGFACE_ACCESS_TOKEN} 24 | 25 | # manually download ASR model 26 | # diarization will be download automatically when running the script due to HF integration 27 | RUN git lfs install 28 | RUN git clone https://${HUGGINGFACE_ACCESS_TOKEN}:${HUGGINGFACE_ACCESS_TOKEN}@huggingface.co/Revai/reverb-asr /root/.cache/reverb/reverb_asr_v1 29 | 30 | 31 | RUN pip3 install /workspace/ 32 | RUN pip3 install -r /workspace/diarization/requirements.txt 33 | 34 | ENV PYTHONPATH=/workspace/asr/:$PYTHONPATH 35 | 36 | RUN reverb --help 37 | RUN python3 /workspace/diarization/infer_pyannote3.0.py --help 38 | -------------------------------------------------------------------------------- /asr/wer_evaluation/RESULTS.md: -------------------------------------------------------------------------------- 1 | **Earnings21** 2 | 3 | | Model | WER | 4 | | -------- | ------- | 5 | | attention\_rescoring | 9.62% | 6 | | ctc\_greedy\_search | 9.66% | 7 | | ctc\_prefix\_beam\_search | 9.64% | 8 | 9 | **Earnings22** 10 | 11 | | Model | WER | 12 | | -------- | ------- | 13 | | attention\_rescoring | 13.32% | 14 | | ctc\_greedy\_search | 13.48% | 15 | | ctc\_prefix\_beam\_search | 13.45% | 16 | 17 | **Rev16** 18 | 19 | | Verbatimicity | Model | Verbatim reference WER | Nonverbatim reference WER | 20 | | -------- | -------- | ------- | -------| 21 | |1.0 | attention\_rescoring | 10.54% | 14.99% | 22 | |1.0 | ctc\_greedy\_search | 10.64% | 15.25% | 23 | |1.0 | ctc\_prefix\_beam\_search | 10.81% | 15.56% | 24 | |1.0 | joint\_decoding | 11.41% | n/a | 25 | |0.0 | attention\_rescoring | 13.97% | 9.08% | 26 | |0.0 | ctc\_greedy\_search | 13.65% | 9.17% | 27 | |0.0 | ctc\_prefix\_beam\_search | 13.57% | 9.16% 28 | 29 | **Gigaspeech (English, filtered)** 30 | 31 | | Model | WER | 32 | | -------- | ------- | 33 | | attention\_rescoring | 10.37% | 34 | | ctc\_greedy\_search | 10.67% | 35 | | ctc\_prefix\_beam\_search | 10.67% | 36 | -------------------------------------------------------------------------------- /asr/wenet/text/base_tokenizer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod, abstractproperty 2 | from typing import Dict, List, Tuple, Union 3 | 4 | T = Union[str, bytes] 5 | 6 | 7 | class BaseTokenizer(ABC): 8 | 9 | def tokenize(self, line: str) -> Tuple[List[T], List[int]]: 10 | tokens = self.text2tokens(line) 11 | ids = self.tokens2ids(tokens) 12 | return tokens, ids 13 | 14 | def detokenize(self, ids: List[int]) -> Tuple[str, List[T]]: 15 | tokens = self.ids2tokens(ids) 16 | text = self.tokens2text(tokens) 17 | return text, tokens 18 | 19 | @abstractmethod 20 | def text2tokens(self, line: str) -> List[T]: 21 | raise NotImplementedError("abstract method") 22 | 23 | @abstractmethod 24 | def tokens2text(self, tokens: List[T]) -> str: 25 | raise NotImplementedError("abstract method") 26 | 27 | @abstractmethod 28 | def tokens2ids(self, tokens: List[T]) -> List[int]: 29 | raise NotImplementedError("abstract method") 30 | 31 | @abstractmethod 32 | def ids2tokens(self, ids: List[int]) -> List[T]: 33 | raise NotImplementedError("abstract method") 34 | 35 | @abstractmethod 36 | def vocab_size(self) -> int: 37 | raise NotImplementedError("abstract method") 38 | 39 | @abstractproperty 40 | def symbol_table(self) -> Dict[T, int]: 41 | raise NotImplementedError("abstract method") 42 | -------------------------------------------------------------------------------- 
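
The abstract `BaseTokenizer` above fixes the round-trip contract that the concrete tokenizers in `asr/wenet/text/` (char, BPE, Paraformer, Hugging Face, Whisper) build on: `tokenize()` composes `text2tokens` with `tokens2ids`, and `detokenize()` composes `ids2tokens` with `tokens2text`, so a subclass only has to provide the four conversion methods plus `vocab_size` and `symbol_table`. Below is a minimal toy subclass, purely as an illustrative sketch (it is not a file in this repository), using a hypothetical whitespace vocabulary:

```python
# Illustrative sketch only (not part of the repository): a toy whitespace
# tokenizer that satisfies the BaseTokenizer contract defined in
# asr/wenet/text/base_tokenizer.py.
from typing import Dict, List

from wenet.text.base_tokenizer import BaseTokenizer


class WhitespaceTokenizer(BaseTokenizer):

    def __init__(self, vocab: List[str], unk: str = "<unk>"):
        # symbol <-> id maps built from a plain list of symbols
        self._sym2id = {sym: i for i, sym in enumerate(vocab)}
        self._id2sym = {i: sym for sym, i in self._sym2id.items()}
        self.unk = unk

    def text2tokens(self, line: str) -> List[str]:
        return line.strip().split()

    def tokens2text(self, tokens: List[str]) -> str:
        return " ".join(tokens)

    def tokens2ids(self, tokens: List[str]) -> List[int]:
        unk_id = self._sym2id.get(self.unk, 0)
        return [self._sym2id.get(t, unk_id) for t in tokens]

    def ids2tokens(self, ids: List[int]) -> List[str]:
        return [self._id2sym.get(i, self.unk) for i in ids]

    def vocab_size(self) -> int:
        return len(self._sym2id)

    @property
    def symbol_table(self) -> Dict[str, int]:
        return self._sym2id


tok = WhitespaceTokenizer(["<unk>", "hello", "world"])
tokens, ids = tok.tokenize("hello world")  # (['hello', 'world'], [1, 2])
text, _ = tok.detokenize(ids)              # 'hello world'
```
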
/Dockerfile.arm64: -------------------------------------------------------------------------------- 1 | # Use an ARM64 base image 2 | FROM --platform=linux/arm64 python:3.11-slim 3 | 4 | # Install system dependencies 5 | RUN apt-get update && apt-get install -y \ 6 | build-essential \ 7 | && apt-get clean 8 | 9 | # Install PyTorch 10 | RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu 11 | 12 | ENV PYTHONUNBUFFERED=1 13 | 14 | RUN apt-get update && \ 15 | apt-get install -y --no-install-recommends \ 16 | git \ 17 | git-lfs \ 18 | libsndfile1-dev \ 19 | locales && \ 20 | rm -rf /var/lib/apt/lists/* 21 | 22 | RUN echo "en_US.UTF-8 UTF-8" > /etc/locale.gen && \ 23 | locale-gen en_US.UTF-8 && \ 24 | update-locale LANG=en_US.UTF-8 25 | ENV LANG en_US.UTF-8 26 | ENV LANGUAGE en_US:en 27 | ENV LC_ALL en_US.UTF-8 28 | 29 | WORKDIR /workspace 30 | COPY . /workspace/ 31 | 32 | ARG HUGGINGFACE_ACCESS_TOKEN 33 | ENV HUGGINGFACE_ACCESS_TOKEN=${HUGGINGFACE_ACCESS_TOKEN} 34 | 35 | # manually download ASR model 36 | # diarization will be download automatically when running the script due to HF integration 37 | RUN git lfs install 38 | RUN git clone https://${HUGGINGFACE_ACCESS_TOKEN}:${HUGGINGFACE_ACCESS_TOKEN}@huggingface.co/Revai/reverb-asr /root/.cache/reverb/reverb_asr_v1 39 | 40 | 41 | RUN pip3 install /workspace/ 42 | RUN pip3 install -r /workspace/diarization/requirements.txt 43 | 44 | ENV PYTHONPATH=/workspace/asr/:$PYTHONPATH 45 | 46 | RUN reverb --help 47 | RUN python3 /workspace/diarization/infer_pyannote3.0.py --help 48 | -------------------------------------------------------------------------------- /asr/wenet/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Shaoshang Qi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import copy 16 | 17 | 18 | def override_config(configs, override_list): 19 | new_configs = copy.deepcopy(configs) 20 | for item in override_list: 21 | arr = item.split() 22 | if len(arr) != 2: 23 | print(f"the overrive {item} format not correct, skip it") 24 | continue 25 | keys = arr[0].split('.') 26 | s_configs = new_configs 27 | for i, key in enumerate(keys): 28 | if key not in s_configs: 29 | print(f"the overrive {item} format not correct, skip it") 30 | if i == len(keys) - 1: 31 | param_type = type(s_configs[key]) 32 | if param_type != bool: 33 | s_configs[key] = param_type(arr[1]) 34 | else: 35 | s_configs[key] = arr[1] in ['true', 'True'] 36 | print(f"override {arr[0]} with {arr[1]}") 37 | else: 38 | s_configs = s_configs[key] 39 | return new_configs 40 | -------------------------------------------------------------------------------- /asr/wenet/transformer/cmvn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | 18 | class GlobalCMVN(torch.nn.Module): 19 | 20 | def __init__(self, 21 | mean: torch.Tensor, 22 | istd: torch.Tensor, 23 | norm_var: bool = True): 24 | """ 25 | Args: 26 | mean (torch.Tensor): mean stats 27 | istd (torch.Tensor): inverse std, std which is 1.0 / std 28 | """ 29 | super().__init__() 30 | assert mean.shape == istd.shape 31 | self.norm_var = norm_var 32 | # The buffer can be accessed from this module using self.mean 33 | self.register_buffer("mean", mean) 34 | self.register_buffer("istd", istd) 35 | 36 | def forward(self, x: torch.Tensor): 37 | """ 38 | Args: 39 | x (torch.Tensor): (batch, max_len, feat_dim) 40 | 41 | Returns: 42 | (torch.Tensor): normalized feature 43 | """ 44 | x = x - self.mean 45 | if self.norm_var: 46 | x = x * self.istd 47 | return x 48 | -------------------------------------------------------------------------------- /examples/stream.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tempfile 3 | import threading 4 | from pydub import AudioSegment 5 | from pydub.playback import play 6 | from wenet import load_model 7 | 8 | 9 | mdl = load_model('reverb_asr_v1') 10 | 11 | 12 | def get_args(): 13 | parser = argparse.ArgumentParser( 14 | description="Simple example of how to run Reverb while streaming an audio file." 
15 | ) 16 | parser.add_argument("audio_file", help="Audio to stream") 17 | parser.add_argument( 18 | "--chunk-size", 19 | type=int, 20 | default=10_000, 21 | help="Fixed size chunk to cut the audio in milliseconds.", 22 | ) 23 | return parser.parse_args() 24 | 25 | 26 | def stream_audio_chunks(file_path, chunk_size_ms=10000): 27 | audio = AudioSegment.from_file(file_path) 28 | 29 | for i in range(0, len(audio), chunk_size_ms): 30 | chunk = audio[i:i + chunk_size_ms] 31 | yield chunk 32 | 33 | if __name__ == '__main__': 34 | args = get_args() 35 | # Example usage 36 | for chunk in stream_audio_chunks(args.audio_file, chunk_size_ms=args.chunk_size): 37 | # Process the chunk 38 | with tempfile.NamedTemporaryFile() as tfile: 39 | def play_chunk(): 40 | play(chunk) 41 | def transcribe_chunk(): 42 | chunk.export(tfile.name, format="wav") 43 | print(mdl.transcribe(tfile.name)) 44 | 45 | thread1 = threading.Thread(target=play_chunk) 46 | thread2 = threading.Thread(target=transcribe_chunk) 47 | 48 | thread1.start() 49 | thread2.start() 50 | 51 | # Wait for threads to finish 52 | thread1.join() 53 | thread2.join() 54 | -------------------------------------------------------------------------------- /asr/wenet/paraformer/subsampling.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | import torch 3 | from wenet.transformer.subsampling import BaseSubsampling 4 | 5 | 6 | class IdentitySubsampling(BaseSubsampling): 7 | """ Paraformer subsampling 8 | """ 9 | 10 | def __init__(self, idim: int, odim: int, dropout_rate: float, 11 | pos_enc_class: torch.nn.Module): 12 | super().__init__() 13 | _, _ = idim, odim 14 | self.right_context = 6 15 | self.subsampling_rate = 6 16 | self.pos_enc = pos_enc_class 17 | 18 | def forward( 19 | self, 20 | x: torch.Tensor, 21 | x_mask: torch.Tensor, 22 | offset: Union[torch.Tensor, int] = 0 23 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 24 | """Subsample x. 25 | 26 | Args: 27 | x (torch.Tensor): Input tensor (#batch, time, idim). 28 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 29 | 30 | Returns: 31 | torch.Tensor: Subsampled tensor (#batch, time', odim), 32 | where time' = time. 
33 | torch.Tensor: Subsampled mask (#batch, 1, time'), 34 | where time' = time 35 | torch.Tensor: positional encoding 36 | 37 | """ 38 | # NOTE(Mddct): Paraformer starts from 1 39 | if isinstance(offset, torch.Tensor): 40 | offset = torch.add(offset, 1) 41 | else: 42 | offset = offset + 1 43 | x, pos_emb = self.pos_enc(x, offset) 44 | return x, pos_emb, x_mask 45 | 46 | def position_encoding(self, offset: Union[int, torch.Tensor], 47 | size: int) -> torch.Tensor: 48 | return self.pos_enc.position_encoding(offset + 1, size) 49 | -------------------------------------------------------------------------------- /diarization/infer_pyannote3.0.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2024 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import argparse 9 | import os 10 | from pathlib import Path 11 | 12 | import torch 13 | 14 | from pyannote.audio import Pipeline 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser(description='Run inference on audio files') 19 | parser.add_argument('audios', nargs='+') 20 | parser.add_argument('--out-dir', type=Path, required=True) 21 | parser.add_argument('--hf-access-token', type=str, required=False, default=None) 22 | # we offer 2 models, reverb-diarization-v1 that is faster and a little bit less accurate, 23 | # and reverb-diarization-v2, the most accurate model but considerably slower 24 | parser.add_argument('--pipeline-model', type=str, required=False, 25 | choices=['Revai/reverb-diarization-v1', 26 | 'Revai/reverb-diarization-v2'], 27 | default='Revai/reverb-diarization-v1') 28 | 29 | args = parser.parse_args() 30 | os.makedirs(args.out_dir, exist_ok=True) 31 | 32 | hf_access_token = args.hf_access_token if args.hf_access_token else os.environ['HUGGINGFACE_ACCESS_TOKEN'] 33 | finetuned_pipeline = Pipeline.from_pretrained(args.pipeline_model, use_auth_token=hf_access_token) 34 | 35 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 36 | finetuned_pipeline.to(torch.device(device)) 37 | 38 | for audio in args.audios: 39 | print('Processing', audio) 40 | annotation = finetuned_pipeline(audio) 41 | with open(args.out_dir / f'{os.path.splitext(os.path.basename(audio))[0]}.rttm', 'w') as f: 42 | annotation.write_rttm(f) 43 | -------------------------------------------------------------------------------- /asr/wenet/text/bpe_tokenizer.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from typing import Dict, List, Optional, Union 3 | from wenet.text.char_tokenizer import CharTokenizer 4 | from wenet.text.tokenize_utils import tokenize_by_bpe_model 5 | 6 | 7 | class BpeTokenizer(CharTokenizer): 8 | 9 | def __init__( 10 | self, 11 | bpe_model: Union[PathLike, str], 12 | symbol_table: Union[str, PathLike, Dict], 13 | non_lang_syms: Optional[Union[str, PathLike, List]] = None, 14 | split_with_space: bool = False, 15 | connect_symbol: str = '', 16 | unk='', 17 | ) -> None: 18 | super().__init__(symbol_table, non_lang_syms, split_with_space, 19 | connect_symbol, unk) 20 | self._model = bpe_model 21 | # NOTE(Mddct): multiprocessing.Process() issues 22 | # don't build sp here 23 | self.bpe_model = None 24 | 25 | def _build_sp(self): 26 | if self.bpe_model is None: 27 | import sentencepiece as spm 28 | self.bpe_model = spm.SentencePieceProcessor() 29 | self.bpe_model.load(self._model) 30 | 31 | def text2tokens(self, line: 
str) -> List[str]: 32 | self._build_sp() 33 | line = line.strip() 34 | if self.non_lang_syms_pattern is not None: 35 | parts = self.non_lang_syms_pattern.split(line.upper()) 36 | parts = [w for w in parts if len(w.strip()) > 0] 37 | else: 38 | parts = [line] 39 | 40 | tokens = [] 41 | for part in parts: 42 | if part in self.non_lang_syms: 43 | tokens.append(part) 44 | else: 45 | tokens.extend(tokenize_by_bpe_model(self.bpe_model, part)) 46 | return tokens 47 | 48 | def tokens2text(self, tokens: List[str]) -> str: 49 | self._build_sp() 50 | text = super().tokens2text(tokens) 51 | return text.replace("▁", ' ').strip() 52 | -------------------------------------------------------------------------------- /asr/wenet/README.md: -------------------------------------------------------------------------------- 1 | # Module Introduction 2 | 3 | Here is a brief introduction to each module (directory). 4 | 5 | * `bin`: training and recognition binaries 6 | * `dataset`: IO design 7 | * `utils`: common utils 8 | * `transformer`: the core of `WeNet`, in which the standard transformer/conformer is implemented. It contains the common blocks (backbone) of speech transformers. 9 | * transformer/attention.py: Standard multi-head attention 10 | * transformer/embedding.py: Standard position encoding 11 | * transformer/positionwise_feed_forward.py: Standard feed forward in transformer 12 | * transformer/convolution.py: ConvolutionModule in Conformer model 13 | * transformer/subsampling.py: Subsampling implementation for speech tasks 14 | * `transducer`: transducer implementation 15 | * `squeezeformer`: squeezeformer implementation, please refer to the [paper](https://arxiv.org/pdf/2206.00888.pdf) 16 | * `efficient_conformer`: efficient conformer implementation, please refer to the [paper](https://arxiv.org/pdf/2109.01163.pdf) 17 | * `paraformer`: paraformer implementation, please refer to the [paper](https://arxiv.org/pdf/1905.11235.pdf) 18 | * `paraformer/cif.py`: Continuous Integrate-and-Fire implementation, please refer to the [paper](https://arxiv.org/pdf/1905.11235.pdf) 19 | * `branchformer`: branchformer implementation, please refer to the [paper](https://arxiv.org/abs/2207.02971) 20 | * `whisper`: whisper implementation, please refer to the [paper](https://arxiv.org/abs/2212.04356) 21 | * `ssl`: Self-supervised speech model implementations, e.g. wav2vec2, bestrq, w2vbert. 22 | * `ctl_model`: Enhancing the Unified Streaming and Non-streaming Model with Contrastive Learning implementation, see the [paper](https://arxiv.org/abs/2306.00755) 23 | 24 | `transducer`, `squeezeformer`, `efficient_conformer`, `branchformer` and `cif` are all based on `transformer`; 25 | they reuse a lot of the common blocks of `transformer`. 26 | 27 | **If you want to contribute your own x-former, please reuse the current code as much as possible**.
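
As a quick orientation to how these modules are driven end to end, here is a minimal inference sketch (an illustrative addition, not a file from this repository), based on the `load_model` export in `asr/wenet/__init__.py` and the pattern used in `examples/stream.py`; the audio path is a placeholder:

```python
# Illustrative sketch only: loads the Reverb ASR model named in examples/stream.py
# and the Dockerfile ('reverb_asr_v1') and transcribes a single file.
from wenet import load_model

model = load_model('reverb_asr_v1')          # the Dockerfile pre-caches this checkpoint under /root/.cache/reverb/
hyp = model.transcribe('path/to/audio.wav')  # placeholder path; return format depends on the model wrapper
print(hyp)
```
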
28 | 29 | 30 | -------------------------------------------------------------------------------- /asr/wenet/text/paraformer_tokenizer.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from typing import Dict, List, Optional, Union 3 | from wenet.paraformer.search import paraformer_beautify_result 4 | from wenet.text.char_tokenizer import CharTokenizer 5 | from wenet.text.tokenize_utils import tokenize_by_seg_dict 6 | 7 | 8 | def read_seg_dict(path): 9 | seg_table = {} 10 | with open(path, 'r', encoding='utf8') as fin: 11 | for line in fin: 12 | arr = line.strip().split('\t') 13 | assert len(arr) == 2 14 | seg_table[arr[0]] = arr[1] 15 | return seg_table 16 | 17 | 18 | class ParaformerTokenizer(CharTokenizer): 19 | 20 | def __init__(self, 21 | symbol_table: Union[str, PathLike, Dict], 22 | seg_dict: Optional[Union[str, PathLike, Dict]] = None, 23 | split_with_space: bool = False, 24 | connect_symbol: str = '', 25 | unk='') -> None: 26 | super().__init__(symbol_table, None, split_with_space, connect_symbol, 27 | unk) 28 | self.seg_dict = seg_dict 29 | if seg_dict is not None and not isinstance(seg_dict, Dict): 30 | self.seg_dict = read_seg_dict(seg_dict) 31 | 32 | def text2tokens(self, line: str) -> List[str]: 33 | assert self.seg_dict is not None 34 | 35 | # TODO(Mddct): duplicated here, refine later 36 | line = line.strip() 37 | if self.non_lang_syms_pattern is not None: 38 | parts = self.non_lang_syms_pattern.split(line) 39 | parts = [w for w in parts if len(w.strip()) > 0] 40 | else: 41 | parts = [line] 42 | 43 | tokens = [] 44 | for part in parts: 45 | if part in self.non_lang_syms: 46 | tokens.append(part) 47 | else: 48 | tokens.extend(tokenize_by_seg_dict(self.seg_dict, part)) 49 | return tokens 50 | 51 | def tokens2text(self, tokens: List[str]) -> str: 52 | return paraformer_beautify_result(tokens) 53 | -------------------------------------------------------------------------------- /asr/wenet/text/hugging_face_tokenizer.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from typing import Dict, List, Union 3 | from wenet.text.base_tokenizer import BaseTokenizer, T as Type 4 | 5 | 6 | class HuggingFaceTokenizer(BaseTokenizer): 7 | 8 | def __init__(self, model: Union[str, PathLike], *args, **kwargs) -> None: 9 | # NOTE(Mddct): don't build here, pickle issues 10 | self.model = model 11 | self.tokenizer = None 12 | 13 | self.args = args 14 | self.kwargs = kwargs 15 | 16 | def __getstate__(self): 17 | state = self.__dict__.copy() 18 | del state['tokenizer'] 19 | return state 20 | 21 | def __setstate__(self, state): 22 | self.__dict__.update(state) 23 | recovery = {'tokenizer': None} 24 | self.__dict__.update(recovery) 25 | 26 | def _build_hugging_face(self): 27 | from transformers import AutoTokenizer 28 | if self.tokenizer is None: 29 | self.tokenizer = AutoTokenizer.from_pretrained( 30 | self.model, **self.kwargs) 31 | self.t2i = self.tokenizer.get_vocab() 32 | 33 | def text2tokens(self, line: str) -> List[Type]: 34 | self._build_hugging_face() 35 | return self.tokenizer.tokenize(line) 36 | 37 | def tokens2text(self, tokens: List[Type]) -> str: 38 | self._build_hugging_face() 39 | ids = self.tokens2ids(tokens) 40 | return self.tokenizer.decode(ids) 41 | 42 | def tokens2ids(self, tokens: List[Type]) -> List[int]: 43 | self._build_hugging_face() 44 | return self.tokenizer.convert_tokens_to_ids(tokens) 45 | 46 | def ids2tokens(self, ids: List[int]) -> 
List[Type]: 47 | self._build_hugging_face() 48 | return self.tokenizer.convert_ids_to_tokens(ids) 49 | 50 | def vocab_size(self) -> int: 51 | self._build_hugging_face() 52 | # TODO: we need special tokenize size in future 53 | return len(self.tokenizer) 54 | 55 | @property 56 | def symbol_table(self) -> Dict[Type, int]: 57 | self._build_hugging_face() 58 | return self.t2i 59 | -------------------------------------------------------------------------------- /asr/wenet/bin/average_model_fixed_list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Mobvoi Inc. All Rights Reserved. 2 | # Author: di.wu@mobvoi.com (DI WU) 3 | import os 4 | import argparse 5 | import glob 6 | 7 | import yaml 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def get_args(): 13 | parser = argparse.ArgumentParser(description='average model') 14 | parser.add_argument('--dst_model', required=True, help='averaged model') 15 | parser.add_argument('--src_path', help='src model path for average') 16 | parser.add_argument('--list', help='list of snapshots to merge') 17 | 18 | args = parser.parse_args() 19 | print(args) 20 | return args 21 | 22 | 23 | def main(): 24 | args = get_args() 25 | chkpt_paths = [] 26 | 27 | if args.list: 28 | with open(args.list, "r") as reader: 29 | for line in reader: 30 | line=line.strip() 31 | if line[-3:] != ".pt": 32 | line = line + ".pt" 33 | 34 | if os.path.isabs(line): 35 | chkpt_paths.append(line) 36 | elif os.path.exists(line): 37 | chkpt_paths.append(line) 38 | elif args.src_path: 39 | chkpt_paths.append(args.src_path + "/" + line) 40 | 41 | num=len(chkpt_paths) 42 | print(f"num ({num}), len(chkpt_paths) = {len(chkpt_paths)}") 43 | avg = None 44 | for path in chkpt_paths: 45 | print('Processing {}'.format(path)) 46 | global_states = torch.load(path, map_location=torch.device('cpu')) 47 | if 'model0' in global_states: 48 | states = global_states['model0'] 49 | else: 50 | states = global_states 51 | 52 | if avg is None: 53 | avg = states 54 | else: 55 | for k in avg.keys(): 56 | avg[k] += states[k] 57 | # average 58 | for k in avg.keys(): 59 | if avg[k] is not None: 60 | # pytorch 1.6 use true_divide instead of /= 61 | avg[k] = torch.true_divide(avg[k], num) 62 | print('Saving to {}'.format(args.dst_model)) 63 | torch.save(avg, args.dst_model) 64 | 65 | 66 | if __name__ == '__main__': 67 | main() 68 | -------------------------------------------------------------------------------- /asr/wenet/transducer/search/greedy_search.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | 6 | def basic_greedy_search( 7 | model: torch.nn.Module, 8 | encoder_out: torch.Tensor, 9 | encoder_out_lens: torch.Tensor, 10 | n_steps: int = 64, 11 | ) -> List[List[int]]: 12 | # fake padding 13 | padding = torch.zeros(1, 1).to(encoder_out.device) 14 | # sos 15 | pred_input_step = torch.tensor([model.blank]).reshape(1, 1) 16 | cache = model.predictor.init_state(1, 17 | method="zero", 18 | device=encoder_out.device) 19 | new_cache: List[torch.Tensor] = [] 20 | t = 0 21 | hyps = [] 22 | prev_out_nblk = True 23 | pred_out_step = None 24 | per_frame_max_noblk = n_steps 25 | per_frame_noblk = 0 26 | while t < encoder_out_lens: 27 | encoder_out_step = encoder_out[:, t:t + 1, :] # [1, 1, E] 28 | if prev_out_nblk: 29 | step_outs = model.predictor.forward_step(pred_input_step, padding, 30 | cache) # [1, 1, P] 31 | pred_out_step, new_cache = step_outs[0], step_outs[1] 32 | 33 | 
joint_out_step = model.joint(encoder_out_step, 34 | pred_out_step) # [1,1,v] 35 | joint_out_probs = joint_out_step.log_softmax(dim=-1) 36 | 37 | joint_out_max = joint_out_probs.argmax(dim=-1).squeeze() # [] 38 | if joint_out_max != model.blank: 39 | hyps.append(joint_out_max.item()) 40 | prev_out_nblk = True 41 | per_frame_noblk = per_frame_noblk + 1 42 | pred_input_step = joint_out_max.reshape(1, 1) 43 | # state_m, state_c = clstate_out_m, state_out_c 44 | cache = new_cache 45 | 46 | if joint_out_max == model.blank or per_frame_noblk >= per_frame_max_noblk: 47 | if joint_out_max == model.blank: 48 | prev_out_nblk = False 49 | # TODO(Mddct): make t in chunk for streamming 50 | # or t should't be too lang to predict none blank 51 | t = t + 1 52 | per_frame_noblk = 0 53 | 54 | return [hyps] 55 | -------------------------------------------------------------------------------- /asr/app.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Horizon Robotics. (authors: Binbin Zhang) 2 | # 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import gradio as gr 17 | import wenet 18 | 19 | # TODO: add hotword 20 | chs_model = wenet.load_model('chinese') 21 | en_model = wenet.load_model('english') 22 | 23 | 24 | def recognition(audio, lang='CN'): 25 | if audio is None: 26 | return "Input Error! Please enter one audio!" 27 | # NOTE: model supports 16k sample_rate 28 | if lang == 'CN': 29 | ans = chs_model.transcribe(audio) 30 | elif lang == 'EN': 31 | ans = en_model.transcribe(audio) 32 | else: 33 | return "ERROR! Please select a language!" 34 | 35 | if ans is None: 36 | return "ERROR! No text output! Please try again!" 37 | txt = ans['text'] 38 | return txt 39 | 40 | 41 | # input 42 | inputs = [ 43 | gr.inputs.Audio(source="microphone", type="filepath", label='Input audio'), 44 | gr.Radio(['EN', 'CN'], label='Language') 45 | ] 46 | 47 | output = gr.outputs.Textbox(label="Output Text") 48 | 49 | text = "Speech Recognition in WeNet | 基于 WeNet 的语音识别" 50 | 51 | # description 52 | description = ( 53 | "Wenet Demo ! This is a speech recognition demo that supports Mandarin and English !" # noqa 54 | ) 55 | 56 | article = ( 57 | "
<p style='text-align: center'>" 58 | "<a href='https://github.com/wenet-e2e/wenet' target='_blank'>Github: Learn more about WeNet</a>" # noqa 59 | "</p>
") 60 | 61 | interface = gr.Interface( 62 | fn=recognition, 63 | inputs=inputs, 64 | outputs=output, 65 | title=text, 66 | description=description, 67 | article=article, 68 | theme='huggingface', 69 | ) 70 | 71 | interface.launch(enable_queue=True) 72 | -------------------------------------------------------------------------------- /asr/wer_evaluation/README.md: -------------------------------------------------------------------------------- 1 | # Evaluating ASR Systems 2 | We provide two python scripts to facilitate aligment and metric aggregation: 3 | 1. `scoring_commands.py` prints to STDOUT a series of commands that can be run to get the scoring between a reference and hypothesis transcript. _How_ to run these commands is up to the user, but we highly recommend using parallelization when possible. This script assumes a dataset from our [speech-datasets](https://github.com/revdotcom/speech-datasets) repository is being used; these have reference "nlp" transcripts along with reference normalizations that can improve the quality of scoring. This script also depends on having the [fstalign](https://github.com/revdotcom/fstalign/tree/develop) binary locally, follow the instructions on that repository to install. For more information on how to run this script, run `python3 scoring_commands.py --help`. 4 | 2. `aggregate_scoring.py` aggregates the output results from `scoring_commands.py` to produce the key metrics we use to evaluate ASR systems. For more information on how to run this script, run `python3 aggregate_scoring.py --help`. 5 | 6 | ## Example 7 | ```bash 8 | ~$ FSTALIGN_BINARY=fstalign/fstalign 9 | ~$ NLP_REFERENCE_DIRECTORY=speech-datasets/earnings21/transcripts/nlp_references/ 10 | ~$ ASR_HYPOTHESIS_DIRECTORY=asr_output 11 | ~$ FSTALIGN_OUTPUT_DIRECTORY=scoring_output 12 | ~$ NLP_NORMALIZATIONS_DIRECTORY=speech-datasets/earnings21/transcripts/normalizations/ 13 | ~$ FSTALIGN_SYNONYMS_FILE=fstalign/sample_data/synonyms.rules.txt 14 | 15 | ~$ python3 scoring_commands.py \ 16 | ${FSTALIGN_BINARY} \ 17 | ${NLP_REFERENCE_DIRECTORY} \ 18 | ${ASR_HYPOTHESIS_DIRECTORY} \ 19 | ${FSTALIGN_OUTPUT_DIRECTORY} \ 20 | --ref-norm ${NLP_NORMALIZATIONS_DIRECTORY} \ 21 | --synonyms-file ${FSTALIGN_SYNONYMS_FILE} > cmds.sh 22 | 23 | ~$ head -n 1 cmds.sh 24 | fstalign/fstalign wer --ref speech-datasets/earnings21/transcripts/nlp_references/4387865.nlp --hyp asr_output/4387865.ctm --json-log scoring_output/4387865.json --ref-json speech-datasets/earnings21/transcripts/normalizations/4387865.norm.json --syn fstalign/sample_data/synonyms.rules.txt 25 | 26 | ~$ ./cmds.sh 27 | ~$ python3 aggregate_scoring.py ${FSTALIGN_OUTPUT_DIRECTORY} 28 | TOTAL WER: 29172/374486 = 7.79% 29 | Insertion Rate: 6443/374486 = 1.72% 30 | Deletion Rate: 8405/374486 = 2.24% 31 | Substitution Rate: 14324/374486 = 3.82% 32 | ``` 33 | -------------------------------------------------------------------------------- /asr/wenet/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | 17 | 18 | def read_lists(list_file): 19 | lists = [] 20 | with open(list_file, 'r', encoding='utf8') as fin: 21 | for line in fin: 22 | lists.append(line.strip()) 23 | return lists 24 | 25 | 26 | def read_non_lang_symbols(non_lang_sym_path): 27 | """read non-linguistic symbol from file. 28 | 29 | The file format is like below: 30 | 31 | {NOISE}\n 32 | {BRK}\n 33 | ... 34 | 35 | 36 | Args: 37 | non_lang_sym_path: non-linguistic symbol file path, None means no any 38 | syms. 39 | 40 | """ 41 | if non_lang_sym_path is None: 42 | return [] 43 | else: 44 | syms = read_lists(non_lang_sym_path) 45 | non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") 46 | for sym in syms: 47 | if non_lang_syms_pattern.fullmatch(sym) is None: 48 | 49 | class BadSymbolFormat(Exception): 50 | pass 51 | 52 | raise BadSymbolFormat( 53 | "Non-linguistic symbols should be " 54 | "formatted in {xxx}//[xxx], consider" 55 | " modify '%s' to meet the requirment. " 56 | "More details can be found in discussions here : " 57 | "https://github.com/wenet-e2e/wenet/pull/819" % (sym)) 58 | return syms 59 | 60 | 61 | def read_symbol_table(symbol_table_file): 62 | symbol_table = {} 63 | with open(symbol_table_file, 'r', encoding='utf8') as fin: 64 | for line in fin: 65 | arr = line.strip().split() 66 | assert len(arr) == 2 67 | symbol_table[arr[0]] = int(arr[1]) 68 | return symbol_table 69 | -------------------------------------------------------------------------------- /asr/wenet/finetune/lora/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 microsoft 2 | # 2023 Alan (alanfangemail@gmail.com) 3 | # ----------------------------------------------------------------------------- 4 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for 5 | # license information. 
6 | # ----------------------------------------------------------------------------- 7 | 8 | import logging 9 | import torch 10 | import torch.nn as nn 11 | 12 | from typing import Dict 13 | 14 | from wenet.finetune.lora.attention import (LoRARelPositionMultiHeadedAttention, 15 | LoRAMultiHeadedAttention) 16 | from wenet.finetune.lora.layers import LoRALayer 17 | 18 | WENET_LORA_ATTENTION_CLASSES = { 19 | "selfattn": LoRAMultiHeadedAttention, 20 | "rel_selfattn": LoRARelPositionMultiHeadedAttention, 21 | } 22 | 23 | 24 | def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None: 25 | logging.info('freezing all params except lora module.') 26 | for n, p in model.named_parameters(): 27 | if 'lora_' not in n: 28 | p.requires_grad = False 29 | if bias == 'none': 30 | return 31 | elif bias == 'all': 32 | for n, p in model.named_parameters(): 33 | if 'bias' in n: 34 | p.requires_grad = True 35 | elif bias == 'lora_only': 36 | for m in model.modules(): 37 | if isinstance(m, LoRALayer) and \ 38 | hasattr(m, 'bias') and \ 39 | m.bias is not None: 40 | m.bias.requires_grad = True 41 | else: 42 | raise NotImplementedError 43 | 44 | 45 | def lora_state_dict(model: nn.Module, 46 | bias: str = 'none') -> Dict[str, torch.Tensor]: 47 | my_state_dict = model.state_dict() 48 | if bias == 'none': 49 | return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k} 50 | elif bias == 'all': 51 | return { 52 | k: my_state_dict[k] 53 | for k in my_state_dict if 'lora_' in k or 'bias' in k 54 | } 55 | elif bias == 'lora_only': 56 | to_return = {} 57 | for k in my_state_dict: 58 | if 'lora_' in k: 59 | to_return[k] = my_state_dict[k] 60 | bias_name = k.split('lora_')[0] + 'bias' 61 | if bias_name in my_state_dict: 62 | to_return[bias_name] = my_state_dict[bias_name] 63 | return to_return 64 | else: 65 | raise NotImplementedError 66 | -------------------------------------------------------------------------------- /asr/wenet/transducer_espnet/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for Transducer models.""" 2 | 3 | import torch 4 | 5 | # from wenet.transformer.transducer.nets_utils import pad_list 6 | 7 | def pad_list(xs, pad_value): 8 | """Perform padding for the list of tensors. 9 | 10 | Args: 11 | xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. 12 | pad_value (float): Value for padding. 13 | 14 | Returns: 15 | Tensor: Padded tensor (B, Tmax, `*`). 16 | 17 | Examples: 18 | >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] 19 | >>> x 20 | [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] 21 | >>> pad_list(x, 0) 22 | tensor([[1., 1., 1., 1.], 23 | [1., 1., 0., 0.], 24 | [1., 0., 0., 0.]]) 25 | 26 | """ 27 | n_batch = len(xs) 28 | max_len = max(x.size(0) for x in xs) 29 | pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) 30 | 31 | for i in range(n_batch): 32 | pad[i, : xs[i].size(0)] = xs[i] 33 | 34 | return pad 35 | 36 | def get_transducer_task_io( 37 | labels: torch.Tensor, 38 | encoder_out_lens: torch.Tensor, 39 | ignore_id: int = -1, 40 | blank_id: int = 0, 41 | ): 42 | """Get Transducer loss I/O. 43 | 44 | Args: 45 | labels: Label ID sequences. (B, L) 46 | encoder_out_lens: Encoder output lengths. (B,) 47 | ignore_id: Padding symbol ID. 48 | blank_id: Blank symbol ID. 49 | 50 | Return: 51 | decoder_in: Decoder inputs. (B, U) 52 | target: Target label ID sequences. (B, U) 53 | t_len: Time lengths. (B,) 54 | u_len: Label lengths. 
(B,) 55 | 56 | """ 57 | device = labels.device 58 | 59 | labels_unpad = [y[y != ignore_id] for y in labels] 60 | blank = labels[0].new([blank_id]) 61 | 62 | decoder_in = pad_list( 63 | [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id 64 | ).to(device) 65 | 66 | target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device) 67 | 68 | if encoder_out_lens.dim() > 1: 69 | enc_mask = [m[m != 0] for m in encoder_out_lens] 70 | encoder_out_lens = list(map(int, [m.size(0) for m in enc_mask])) 71 | else: 72 | encoder_out_lens = list(map(int, encoder_out_lens)) 73 | 74 | t_len = torch.IntTensor(encoder_out_lens).to(device) 75 | u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device) 76 | 77 | return decoder_in, target, t_len, u_len 78 | -------------------------------------------------------------------------------- /asr/wenet/dataset/audio_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torchaudio 5 | import torchaudio.backend.sox_io_backend as sox 6 | import subprocess 7 | import shlex 8 | import io 9 | import numpy as np 10 | from functools import lru_cache 11 | # for debugging purposes 12 | # import psutil 13 | 14 | def bytes_to_list_of_int16(bytes): 15 | return np.frombuffer(bytes, dtype=np.int16) 16 | 17 | # execute a shell command and read all stdout output to a string 18 | def read_cmd_output_old(cmd): 19 | p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) 20 | out = p.communicate()[0] 21 | ret = p.returncode 22 | p = None 23 | 24 | if(ret!= 0): 25 | return None 26 | return out 27 | 28 | def read_cmd_output(cmd): 29 | p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, universal_newlines=False) 30 | out = p.stdout.read() 31 | # ret = p.returncode 32 | p = None 33 | 34 | # print(f"len out = {len(out)}, ret = {ret}") 35 | # if(ret!= 0): 36 | # return None 37 | return out 38 | 39 | # @lru_cache(maxsize=5) 40 | def get_wavdata_and_samplerate(wavscp_filedescriptor): 41 | """ 42 | Get wav data and sample rate from wavscp line 43 | """ 44 | wavscp_filedescriptor = wavscp_filedescriptor.strip() 45 | if wavscp_filedescriptor[-1] == '|': # this is a pipe command 46 | # here, we assume that the output of the pipe command will be a format torchaudio can load 47 | tmp_x = read_cmd_output(wavscp_filedescriptor.strip()[:-1]) 48 | ff = io.BytesIO(tmp_x) 49 | ff.seek(0) 50 | if ff.getbuffer().nbytes == 0: 51 | print(f"io.bytesio len {ff.getbuffer().nbytes}--> {wavscp_filedescriptor}") 52 | if tmp_x is None: 53 | return None, None 54 | 55 | try: 56 | wav_data, sample_rate = torchaudio.load(ff) 57 | # print(f"wav data shape {wav_data.shape}") 58 | except: 59 | print(f"Error loading wav file {wavscp_filedescriptor}") 60 | return None, None 61 | else: 62 | try: 63 | wav_data, sample_rate = torchaudio.load(wavscp_filedescriptor) 64 | # print(f"wav data shape {wav_data.shape}") 65 | except: 66 | print(f"Error loading wav file {wavscp_filedescriptor}") 67 | return None, None 68 | 69 | # useful for debugging purposes 70 | # print(f"Memory used: {psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2} MB", flush=True) 71 | return wav_data, sample_rate -------------------------------------------------------------------------------- /asr/wenet/text/tokenize_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. 
(authors: Binbin Zhang) 2 | # 2023 Horizon Inc. (authors: Xingchen Song) 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | 17 | 18 | def tokenize_by_bpe_model(sp, txt): 19 | return _tokenize_by_seg_dic_or_bpe_model(txt, sp=sp, upper=True) 20 | 21 | 22 | def tokenize_by_seg_dict(seg_dict, txt): 23 | return _tokenize_by_seg_dic_or_bpe_model(txt, 24 | seg_dict=seg_dict, 25 | upper=False) 26 | 27 | 28 | def _tokenize_by_seg_dic_or_bpe_model( 29 | txt, 30 | sp=None, 31 | seg_dict=None, 32 | upper=True, 33 | ): 34 | if sp is None: 35 | assert seg_dict is not None 36 | if seg_dict is None: 37 | assert sp is not None 38 | tokens = [] 39 | # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref: 40 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 41 | pattern = re.compile(r'([\u4e00-\u9fff])') 42 | # Example: 43 | # txt = "你好 ITS'S OKAY 的" 44 | # chars = ["你", "好", " ITS'S OKAY ", "的"] 45 | chars = pattern.split(txt.upper() if upper else txt) 46 | mix_chars = [w for w in chars if len(w.strip()) > 0] 47 | for ch_or_w in mix_chars: 48 | # ch_or_w is a single CJK charater(i.e., "你"), do nothing. 49 | if pattern.fullmatch(ch_or_w) is not None: 50 | tokens.append(ch_or_w) 51 | # ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "), 52 | # encode ch_or_w using bpe_model. 53 | else: 54 | if sp is not None: 55 | for p in sp.encode_as_pieces(ch_or_w): 56 | tokens.append(p) 57 | else: 58 | for en_token in ch_or_w.split(): 59 | en_token = en_token.strip() 60 | if en_token in seg_dict: 61 | tokens.extend(seg_dict[en_token].split(' ')) 62 | else: 63 | tokens.append(en_token) 64 | 65 | return tokens 66 | -------------------------------------------------------------------------------- /asr/wenet/squeezeformer/conv2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Conv2d Module with Valid Padding""" 15 | 16 | import torch.nn.functional as F 17 | from torch.nn.modules.conv import _ConvNd, _size_2_t, Union, _pair, Tensor, Optional 18 | 19 | 20 | class Conv2dValid(_ConvNd): 21 | """ 22 | Conv2d operator for VALID mode padding. 
23 | """ 24 | 25 | def __init__( 26 | self, 27 | in_channels: int, 28 | out_channels: int, 29 | kernel_size: _size_2_t, 30 | stride: _size_2_t = 1, 31 | padding: Union[str, _size_2_t] = 0, 32 | dilation: _size_2_t = 1, 33 | groups: int = 1, 34 | bias: bool = True, 35 | padding_mode: str = 'zeros', # TODO: refine this type 36 | device=None, 37 | dtype=None, 38 | valid_trigx: bool = False, 39 | valid_trigy: bool = False) -> None: 40 | factory_kwargs = {'device': device, 'dtype': dtype} 41 | kernel_size_ = _pair(kernel_size) 42 | stride_ = _pair(stride) 43 | padding_ = padding if isinstance(padding, str) else _pair(padding) 44 | dilation_ = _pair(dilation) 45 | super(Conv2dValid, 46 | self).__init__(in_channels, out_channels, 47 | kernel_size_, stride_, padding_, dilation_, False, 48 | _pair(0), groups, bias, padding_mode, 49 | **factory_kwargs) 50 | self.valid_trigx = valid_trigx 51 | self.valid_trigy = valid_trigy 52 | 53 | def _conv_forward(self, input: Tensor, weight: Tensor, 54 | bias: Optional[Tensor]): 55 | validx, validy = 0, 0 56 | if self.valid_trigx: 57 | validx = (input.size(-2) * 58 | (self.stride[-2] - 1) - 1 + self.kernel_size[-2]) // 2 59 | if self.valid_trigy: 60 | validy = (input.size(-1) * 61 | (self.stride[-1] - 1) - 1 + self.kernel_size[-1]) // 2 62 | return F.conv2d(input, weight, bias, self.stride, (validx, validy), 63 | self.dilation, self.groups) 64 | 65 | def forward(self, input: Tensor) -> Tensor: 66 | return self._conv_forward(input, self.weight, self.bias) 67 | -------------------------------------------------------------------------------- /asr/wenet/utils/class_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # Copyright [2023-11-28] 4 | import torch 5 | from torch.nn import BatchNorm1d, LayerNorm 6 | from wenet.paraformer.embedding import ParaformerPositinoalEncoding 7 | from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward 8 | 9 | from wenet.transformer.swish import Swish 10 | from wenet.transformer.subsampling import ( 11 | LinearNoSubsampling, 12 | EmbedinigNoSubsampling, 13 | Conv1dSubsampling2, 14 | Conv2dSubsampling4, 15 | Conv2dSubsampling6, 16 | Conv2dSubsampling8, 17 | ) 18 | from wenet.efficient_conformer.subsampling import Conv2dSubsampling2 19 | from wenet.squeezeformer.subsampling import DepthwiseConv2dSubsampling4 20 | from wenet.transformer.embedding import (PositionalEncoding, 21 | RelPositionalEncoding, 22 | WhisperPositionalEncoding, 23 | LearnablePositionalEncoding, 24 | NoPositionalEncoding) 25 | from wenet.transformer.attention import (MultiHeadedAttention, 26 | RelPositionMultiHeadedAttention) 27 | from wenet.efficient_conformer.attention import GroupedRelPositionMultiHeadedAttention 28 | 29 | WENET_ACTIVATION_CLASSES = { 30 | "hardtanh": torch.nn.Hardtanh, 31 | "tanh": torch.nn.Tanh, 32 | "relu": torch.nn.ReLU, 33 | "selu": torch.nn.SELU, 34 | "swish": getattr(torch.nn, "SiLU", Swish), 35 | "gelu": torch.nn.GELU, 36 | } 37 | 38 | WENET_RNN_CLASSES = { 39 | "rnn": torch.nn.RNN, 40 | "lstm": torch.nn.LSTM, 41 | "gru": torch.nn.GRU, 42 | } 43 | 44 | WENET_SUBSAMPLE_CLASSES = { 45 | "linear": LinearNoSubsampling, 46 | "embed": EmbedinigNoSubsampling, 47 | "conv1d2": Conv1dSubsampling2, 48 | "conv2d2": Conv2dSubsampling2, 49 | "conv2d": Conv2dSubsampling4, 50 | "dwconv2d4": DepthwiseConv2dSubsampling4, 51 | "conv2d6": Conv2dSubsampling6, 52 | "conv2d8": Conv2dSubsampling8, 53 | 'paraformer_dummy': torch.nn.Identity 54 
| } 55 | 56 | WENET_EMB_CLASSES = { 57 | "embed": PositionalEncoding, 58 | "abs_pos": PositionalEncoding, 59 | "rel_pos": RelPositionalEncoding, 60 | "no_pos": NoPositionalEncoding, 61 | "abs_pos_whisper": WhisperPositionalEncoding, 62 | "embed_learnable_pe": LearnablePositionalEncoding, 63 | "abs_pos_paraformer": ParaformerPositinoalEncoding, 64 | } 65 | 66 | WENET_ATTENTION_CLASSES = { 67 | "selfattn": MultiHeadedAttention, 68 | "rel_selfattn": RelPositionMultiHeadedAttention, 69 | "grouped_rel_selfattn": GroupedRelPositionMultiHeadedAttention, 70 | } 71 | 72 | WENET_MLP_CLASSES = { 73 | 'position_wise_feed_forward': PositionwiseFeedForward, 74 | } 75 | 76 | WENET_NORM_CLASSES = { 77 | 'layer_norm': LayerNorm, 78 | 'batch_norm': BatchNorm1d, 79 | } -------------------------------------------------------------------------------- /diarization/train_pyannote3.0.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2024 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | import argparse 8 | import os 9 | 10 | import torch 11 | 12 | from pyannote.audio import Pipeline, Model 13 | from pyannote.database import FileFinder, registry 14 | from pyannote.audio.tasks import Segmentation 15 | from types import MethodType 16 | from torch.optim import Adam 17 | from pytorch_lightning.callbacks import ( 18 | EarlyStopping, 19 | ModelCheckpoint, 20 | RichProgressBar, 21 | ) 22 | from pytorch_lightning import Trainer 23 | 24 | torch.set_float32_matmul_precision('high') 25 | 26 | 27 | def configure_optimizers(self): 28 | return Adam(self.parameters(), lr=1e-4) 29 | 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser(description='Train Pyannote LSTM model') 33 | parser.add_argument('--database', type=str, required=True) 34 | parser.add_argument('--hf-access-token', type=str, required=False, default=None) 35 | 36 | args = parser.parse_args() 37 | registry.load_database(args.database) 38 | dataset = registry.get_protocol('audiodb.SpeakerDiarization.train_protocol', 39 | preprocessors={'audio': FileFinder()}) 40 | 41 | hf_access_token = args.hf_access_token if args.hf_access_token else os.environ['HUGGINGFACE_ACCESS_TOKEN'] 42 | model = Model.from_pretrained( 43 | "pyannote/segmentation-3.0", 44 | use_auth_token=args.hf_access_token) # start from pre-trained model 45 | print(model) 46 | 47 | task = Segmentation( 48 | dataset, 49 | duration=model.specifications.duration, 50 | max_num_speakers=len(model.specifications.classes), 51 | batch_size=64, 52 | num_workers=16, 53 | loss="bce", 54 | vad_loss="bce") 55 | 56 | model.configure_optimizers = MethodType(configure_optimizers, model) 57 | model.task = task 58 | model.setup(stage='fit') 59 | 60 | monitor, direction = task.val_monitor 61 | checkpoint = ModelCheckpoint( 62 | monitor=monitor, 63 | mode=direction, 64 | save_top_k=1, 65 | every_n_epochs=1, 66 | save_last=False, 67 | save_weights_only=False, 68 | filename="{epoch}", 69 | verbose=False, 70 | ) 71 | early_stopping = EarlyStopping( 72 | monitor=monitor, 73 | mode=direction, 74 | min_delta=0.0, 75 | patience=10, 76 | strict=True, 77 | verbose=False, 78 | ) 79 | 80 | callbacks = [RichProgressBar(), checkpoint, early_stopping] 81 | 82 | # we train for at most 20 epochs (might be shorter in case of early stopping) 83 | 84 | trainer = Trainer(accelerator="gpu", 85 | callbacks=callbacks, 86 | max_epochs=20, 87 | gradient_clip_val=0.5) 88 | trainer.fit(model) 89 | 
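# Illustrative follow-up (not executed by this script): once fine-tuning is done,
# a diarization pipeline can be run to produce the RTTM files that
# assign_words2speakers.py consumes; infer_pyannote3.0.py in this directory is
# presumably the project's actual inference entry point. The pipeline name and
# file paths below are placeholders for the sketch.
#
#   from pyannote.audio import Pipeline
#   pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.0",
#                                       use_auth_token=hf_access_token)
#   diarization = pipeline("audio.wav")
#   with open("audio.rttm", "w") as f:
#       diarization.write_rttm(f)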
-------------------------------------------------------------------------------- /asr/wenet/efficient_conformer/subsampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) 2 | # 2022 58.com(Wuba) Inc AI Lab. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | """Subsampling layer definition.""" 17 | 18 | from typing import Tuple, Union 19 | 20 | import torch 21 | from wenet.transformer.subsampling import BaseSubsampling 22 | 23 | 24 | class Conv2dSubsampling2(BaseSubsampling): 25 | """Convolutional 2D subsampling (to 1/4 length). 26 | 27 | Args: 28 | idim (int): Input dimension. 29 | odim (int): Output dimension. 30 | dropout_rate (float): Dropout rate. 31 | 32 | """ 33 | 34 | def __init__(self, idim: int, odim: int, dropout_rate: float, 35 | pos_enc_class: torch.nn.Module): 36 | """Construct an Conv2dSubsampling4 object.""" 37 | super().__init__() 38 | self.conv = torch.nn.Sequential(torch.nn.Conv2d(1, odim, 3, 2), 39 | torch.nn.ReLU()) 40 | self.out = torch.nn.Sequential( 41 | torch.nn.Linear(odim * ((idim - 1) // 2), odim)) 42 | self.pos_enc = pos_enc_class 43 | # The right context for every conv layer is computed by: 44 | # (kernel_size - 1) * frame_rate_of_this_layer 45 | self.subsampling_rate = 2 46 | # 2 = (3 - 1) * 1 47 | self.right_context = 2 48 | 49 | def forward( 50 | self, 51 | x: torch.Tensor, 52 | x_mask: torch.Tensor, 53 | offset: Union[int, torch.Tensor] = 0 54 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 55 | """Subsample x. 56 | 57 | Args: 58 | x (torch.Tensor): Input tensor (#batch, time, idim). 59 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 60 | 61 | Returns: 62 | torch.Tensor: Subsampled tensor (#batch, time', odim), 63 | where time' = time // 2. 64 | torch.Tensor: Subsampled mask (#batch, 1, time'), 65 | where time' = time // 2. 
66 | torch.Tensor: positional encoding 67 | 68 | """ 69 | x = x.unsqueeze(1) # (b, c=1, t, f) 70 | x = self.conv(x) 71 | b, c, t, f = x.size() 72 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 73 | x, pos_emb = self.pos_enc(x, offset) 74 | return x, pos_emb, x_mask[:, :, :-2:2] 75 | -------------------------------------------------------------------------------- /asr/wenet/text/char_tokenizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | 4 | from os import PathLike 5 | from typing import Dict, List, Optional, Union 6 | from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols 7 | from wenet.text.base_tokenizer import BaseTokenizer 8 | 9 | 10 | class CharTokenizer(BaseTokenizer): 11 | def __init__( 12 | self, 13 | symbol_table: Union[str, PathLike, Dict], 14 | non_lang_syms: Optional[Union[str, PathLike, List]] = None, 15 | split_with_space: bool = False, 16 | connect_symbol: str = '', 17 | unk='', 18 | ) -> None: 19 | self.non_lang_syms_pattern = None 20 | if non_lang_syms is not None: 21 | self.non_lang_syms_pattern = re.compile( 22 | r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") 23 | if not isinstance(symbol_table, Dict): 24 | self._symbol_table = read_symbol_table(symbol_table) 25 | else: 26 | # symbol_table = {"我": 1, "是": 2, "{NOISE}": 3} 27 | self._symbol_table = symbol_table 28 | if not isinstance(non_lang_syms, List): 29 | self.non_lang_syms = read_non_lang_symbols(non_lang_syms) 30 | else: 31 | # non_lang_syms=["{NOISE}"] 32 | self.non_lang_syms = non_lang_syms 33 | self.char_dict = {v: k for k, v in self._symbol_table.items()} 34 | self.split_with_space = split_with_space 35 | self.connect_symbol = connect_symbol 36 | self.unk = unk 37 | 38 | def text2tokens(self, line: str) -> List[str]: 39 | line = line.strip() 40 | if self.non_lang_syms_pattern is not None: 41 | parts = self.non_lang_syms_pattern.split(line.upper()) 42 | parts = [w for w in parts if len(w.strip()) > 0] 43 | else: 44 | parts = [line] 45 | 46 | tokens = [] 47 | for part in parts: 48 | if part in self.non_lang_syms: 49 | tokens.append(part) 50 | else: 51 | if self.split_with_space: 52 | part = part.split(" ") 53 | for ch in part: 54 | if ch == ' ': 55 | ch = "▁" 56 | tokens.append(ch) 57 | return tokens 58 | 59 | def tokens2text(self, tokens: List[str]) -> str: 60 | return self.connect_symbol.join(tokens) 61 | 62 | def tokens2ids(self, tokens: List[str]) -> List[int]: 63 | ids = [] 64 | for ch in tokens: 65 | if ch in self._symbol_table: 66 | ids.append(self._symbol_table[ch]) 67 | elif self.unk in self._symbol_table: 68 | ids.append(self._symbol_table[self.unk]) 69 | return ids 70 | 71 | def ids2tokens(self, ids: List[int]) -> List[str]: 72 | content = [self.char_dict[w] for w in ids] 73 | return content 74 | 75 | def vocab_size(self) -> int: 76 | return len(self.char_dict) 77 | 78 | @property 79 | def symbol_table(self) -> Dict[str, int]: 80 | return self._symbol_table 81 | -------------------------------------------------------------------------------- /asr/wenet/utils/init_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Wenet Community. (authors: Dinghao Zhou) 2 | # (authors: Xingchen Song) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import logging 17 | 18 | from wenet.text.base_tokenizer import BaseTokenizer 19 | from wenet.text.bpe_tokenizer import BpeTokenizer 20 | from wenet.text.rev_bpe_tokenizer import RevBpeTokenizer 21 | from wenet.text.char_tokenizer import CharTokenizer 22 | from wenet.text.paraformer_tokenizer import ParaformerTokenizer 23 | from wenet.text.whisper_tokenizer import WhisperTokenizer 24 | 25 | 26 | def init_tokenizer(configs) -> BaseTokenizer: 27 | # TODO(xcsong): Forcefully read the 'tokenizer' attribute. 28 | tokenizer_type = configs.get("tokenizer", "char") 29 | if tokenizer_type == "whisper": 30 | tokenizer = WhisperTokenizer( 31 | multilingual=configs['tokenizer_conf']['is_multilingual'], 32 | num_languages=configs['tokenizer_conf']['num_languages']) 33 | elif tokenizer_type == "char": 34 | tokenizer = CharTokenizer( 35 | configs['tokenizer_conf']['symbol_table_path'], 36 | configs['tokenizer_conf']['non_lang_syms_path'], 37 | split_with_space=configs['tokenizer_conf'].get( 38 | 'split_with_space', False), 39 | connect_symbol=configs['tokenizer_conf'].get('connect_symbol', '')) 40 | elif tokenizer_type == "bpe": 41 | tokenizer = BpeTokenizer( 42 | configs['tokenizer_conf']['bpe_path'], 43 | configs['tokenizer_conf']['symbol_table_path'], 44 | configs['tokenizer_conf']['non_lang_syms_path'], 45 | split_with_space=configs['tokenizer_conf'].get( 46 | 'split_with_space', False)) 47 | elif tokenizer_type == "rev_bpe": 48 | tokenizer = RevBpeTokenizer( 49 | configs['tokenizer_conf']['bpe_path'], 50 | configs['tokenizer_conf']['symbol_table_path'], 51 | configs['tokenizer_conf']['non_lang_syms_path'], 52 | split_with_space=configs['tokenizer_conf'].get( 53 | 'split_with_space', False), full_config=configs['tokenizer_conf']) 54 | elif tokenizer_type == 'paraformer': 55 | tokenizer = ParaformerTokenizer( 56 | symbol_table=configs['tokenizer_conf']['symbol_table_path'], 57 | seg_dict=configs['tokenizer_conf']['seg_dict_path']) 58 | else: 59 | raise NotImplementedError 60 | logging.info("use {} tokenizer".format(configs["tokenizer"])) 61 | 62 | return tokenizer 63 | -------------------------------------------------------------------------------- /asr/wenet/text/rev_bpe_tokenizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from os import PathLike 3 | from typing import Dict, List, Optional, Union 4 | from wenet.text.char_tokenizer import CharTokenizer 5 | from wenet.text.tokenize_utils import tokenize_by_bpe_model 6 | import torch 7 | # import numpy as np 8 | 9 | 10 | class RevBpeTokenizer(CharTokenizer): 11 | 12 | def __init__( 13 | self, 14 | bpe_model: Union[PathLike, str], 15 | symbol_table: Union[str, PathLike, Dict], 16 | non_lang_syms: Optional[Union[str, PathLike, List]] = None, 17 | split_with_space: bool = False, 18 | connect_symbol: str = '', 19 | unk='', 20 | full_config: Dict={} 21 | ) -> None: 22 | logging.debug(f"{bpe_model=}, {symbol_table=}, {non_lang_syms=}, {split_with_space=}, {connect_symbol=}, {unk=}, {full_config=}") 23 | super().__init__(symbol_table, 
non_lang_syms, split_with_space, 24 | connect_symbol, unk) 25 | self.remove_sw = full_config.get('remove_sw', True) 26 | self.replace_unk_as_unknown = full_config.get('replace_unk_as_unknown', True) 27 | self.connect_symbol = connect_symbol 28 | # JPR: TODO: other flags to implement: force_wb 29 | 30 | self._model = bpe_model 31 | # NOTE(Mddct): multiprocessing.Process() issues 32 | # don't build sp here 33 | self.bpe_model = None 34 | 35 | def _build_sp(self): 36 | if self.bpe_model is None: 37 | import sentencepiece as spm 38 | self.bpe_model = spm.SentencePieceProcessor() 39 | self.bpe_model.load(self._model) 40 | 41 | # overriding a lot of things here... 42 | def text2tokens(self, line: str) -> List[str]: 43 | self._build_sp() 44 | line = line.strip() 45 | 46 | if self.remove_sw: 47 | line = line.replace('', '').replace(' ',' ').strip() 48 | 49 | if self.replace_unk_as_unknown: 50 | line = line.replace("", "") 51 | 52 | # other things might be required here... 53 | # like removing trailing dashes, etc. 54 | 55 | tokens = self.bpe_model.encode(line, out_type=str) 56 | #tokens = torch.tensor([tk for tk in self.bpe_model.encode(line, out_type=int)]) 57 | #print(f"line = {line}, tokens = {tokens}") 58 | 59 | return tokens 60 | 61 | 62 | # if self.non_lang_syms_pattern is not None: 63 | # parts = self.non_lang_syms_pattern.split(line.upper()) 64 | # parts = [w for w in parts if len(w.strip()) > 0] 65 | # else: 66 | # parts = [line] 67 | 68 | # tokens = [] 69 | # for part in parts: 70 | # if part in self.non_lang_syms: 71 | # tokens.append(part) 72 | # else: 73 | # tokens.extend(tokenize_by_bpe_model(self.bpe_model, part)) 74 | # return tokens 75 | 76 | # from base class 77 | def tokens2text(self, tokens: List[str]) -> str: 78 | #self._build_sp() 79 | #text = super().tokens2text(tokens) 80 | text = self.connect_symbol.join(tokens) 81 | #return text 82 | return text.replace("▁", ' ').strip() 83 | -------------------------------------------------------------------------------- /asr/wenet/transducer_espnet/joint_network.py: -------------------------------------------------------------------------------- 1 | """Transducer joint network implementation.""" 2 | 3 | import torch 4 | 5 | from wenet.utils.common import get_activation 6 | 7 | 8 | class JointNetwork(torch.nn.Module): 9 | """Transducer joint network module. 10 | 11 | Args: 12 | joint_output_size: Joint network output dimension 13 | encoder_output_size: Encoder output dimension. 14 | decoder_output_size: Decoder output dimension. 15 | joint_space_size: Dimension of joint space. 16 | joint_activation_type: Type of activation for joint network. 17 | 18 | """ 19 | 20 | def __init__( 21 | self, 22 | joint_output_size: int, 23 | encoder_output_size: int, 24 | decoder_output_size: int, 25 | joint_space_size: int = 256, 26 | joint_activation_type: str = "tanh", 27 | ): 28 | """Joint network initializer.""" 29 | super().__init__() 30 | 31 | self.lin_enc = torch.nn.Linear(encoder_output_size, joint_space_size) 32 | self.lin_dec = torch.nn.Linear(decoder_output_size, joint_space_size) 33 | 34 | self.lin_out = torch.nn.Linear(joint_space_size, joint_output_size) 35 | 36 | self.joint_activation = get_activation(joint_activation_type) 37 | 38 | self.joint_space_size = joint_space_size 39 | 40 | def forward( 41 | self, 42 | enc_out: torch.Tensor, 43 | dec_out: torch.Tensor, 44 | ) -> torch.Tensor: 45 | """Joint computation of encoder and decoder hidden state sequences. 
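        The two projections broadcast against each other: lin_enc(enc_out) has
        shape (B, T, 1, joint_space_size) and lin_dec(dec_out) has shape
        (B, 1, U, joint_space_size), so their sum, and therefore the returned
        tensor, covers the full (B, T, U, ...) lattice.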
46 | 47 | Args: 48 | enc_out: Expanded encoder output state sequences (B, T, 1, D_enc) 49 | dec_out: Expanded decoder output state sequences (B, 1, U, D_dec) 50 | 51 | Returns: 52 | joint_out: Joint output state sequences. (B, T, U, D_out) 53 | 54 | """ 55 | joint_out = self.joint_activation(self.lin_enc(enc_out) + self.lin_dec(dec_out)) 56 | 57 | return self.lin_out(joint_out) 58 | 59 | # Optimized_transducer loss implementation expects other dimensions for joint_out 60 | def forward_optimized( 61 | self, 62 | enc_out: torch.Tensor, 63 | logit_lengths: torch.Tensor, 64 | dec_out: torch.Tensor, 65 | target_lengths: torch.Tensor 66 | ) -> torch.Tensor: 67 | """Joint computation of encoder and decoder hidden state sequences. 68 | 69 | Args: 70 | enc_out: Expanded encoder output state sequences (B, T, 1, D_enc) 71 | dec_out: Expanded decoder output state sequences (B, 1, U, D_dec) 72 | 73 | Returns: 74 | joint_out: Joint output state sequences. (B, T, U, D_out) 75 | 76 | """ 77 | 78 | B = enc_out.size(0) 79 | encoder_out_list = [enc_out[i, :logit_lengths[i], :, :] for i in range(B)] 80 | decoder_out_list = [dec_out[i, :, :target_lengths[i]+1, :] for i in range(B)] 81 | 82 | joint_out = [self.joint_activation(self.lin_enc(e) + self.lin_dec(d)) for e, d in zip(encoder_out_list, decoder_out_list)] 83 | joint_out = [p.reshape(-1, self.joint_space_size) for p in joint_out] 84 | joint_out = torch.cat(joint_out) 85 | 86 | return self.lin_out(joint_out) 87 | -------------------------------------------------------------------------------- /asr/wenet/utils/cmvn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
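# Note: _load_kaldi_cmvn below reports errors via logging.error and exits via
# sys.exit, so both modules must be imported here.
import logging
import sys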
14 | 15 | import json 16 | import math 17 | 18 | import numpy as np 19 | 20 | 21 | def _load_json_cmvn(json_cmvn_file): 22 | """ Load the json format cmvn stats file and calculate cmvn 23 | 24 | Args: 25 | json_cmvn_file: cmvn stats file in json format 26 | 27 | Returns: 28 | a numpy array of [means, vars] 29 | """ 30 | with open(json_cmvn_file) as f: 31 | cmvn_stats = json.load(f) 32 | 33 | means = cmvn_stats['mean_stat'] 34 | variance = cmvn_stats['var_stat'] 35 | count = cmvn_stats['frame_num'] 36 | for i in range(len(means)): 37 | means[i] /= count 38 | variance[i] = variance[i] / count - means[i] * means[i] 39 | if variance[i] < 1.0e-20: 40 | variance[i] = 1.0e-20 41 | variance[i] = 1.0 / math.sqrt(variance[i]) 42 | cmvn = np.array([means, variance]) 43 | return cmvn 44 | 45 | 46 | def _load_kaldi_cmvn(kaldi_cmvn_file): 47 | """ Load the kaldi format cmvn stats file and calculate cmvn 48 | 49 | Args: 50 | kaldi_cmvn_file: kaldi text style global cmvn file, which 51 | is generated by: 52 | compute-cmvn-stats --binary=false scp:feats.scp global_cmvn 53 | 54 | Returns: 55 | a numpy array of [means, vars] 56 | """ 57 | means = [] 58 | variance = [] 59 | with open(kaldi_cmvn_file, 'r') as fid: 60 | # kaldi binary file start with '\0B' 61 | if fid.read(2) == '\0B': 62 | logging.error('kaldi cmvn binary file is not supported, please ' 63 | 'recompute it by: compute-cmvn-stats --binary=false ' 64 | ' scp:feats.scp global_cmvn') 65 | sys.exit(1) 66 | fid.seek(0) 67 | arr = fid.read().split() 68 | assert (arr[0] == '[') 69 | assert (arr[-2] == '0') 70 | assert (arr[-1] == ']') 71 | feat_dim = int((len(arr) - 2 - 2) / 2) 72 | for i in range(1, feat_dim + 1): 73 | means.append(float(arr[i])) 74 | count = float(arr[feat_dim + 1]) 75 | for i in range(feat_dim + 2, 2 * feat_dim + 2): 76 | variance.append(float(arr[i])) 77 | 78 | for i in range(len(means)): 79 | means[i] /= count 80 | variance[i] = variance[i] / count - means[i] * means[i] 81 | if variance[i] < 1.0e-20: 82 | variance[i] = 1.0e-20 83 | variance[i] = 1.0 / math.sqrt(variance[i]) 84 | cmvn = np.array([means, variance]) 85 | return cmvn 86 | 87 | 88 | def load_cmvn(cmvn_file, is_json): 89 | if is_json: 90 | cmvn = _load_json_cmvn(cmvn_file) 91 | else: 92 | cmvn = _load_kaldi_cmvn(cmvn_file) 93 | return cmvn[0], cmvn[1] 94 | -------------------------------------------------------------------------------- /asr/wenet/squeezeformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 2022 Ximalaya Inc (Yuguang Yang) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Positionwise feed forward layer definition.""" 17 | 18 | import torch 19 | 20 | 21 | class PositionwiseFeedForward(torch.nn.Module): 22 | """Positionwise feed forward layer. 23 | 24 | FeedForward are appied on each position of the sequence. 25 | The output dim is same with the input dim. 
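    When ``adaptive_scale`` is True, the input is first rescaled element-wise as
    ``ada_scale * xs + ada_bias`` with learned per-dimension parameters before
    the two linear layers are applied (see ``forward`` below).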
26 | 27 | Args: 28 | idim (int): Input dimenstion. 29 | hidden_units (int): The number of hidden units. 30 | dropout_rate (float): Dropout rate. 31 | activation (torch.nn.Module): Activation function 32 | """ 33 | 34 | def __init__(self, 35 | idim: int, 36 | hidden_units: int, 37 | dropout_rate: float, 38 | activation: torch.nn.Module = torch.nn.ReLU(), 39 | adaptive_scale: bool = False, 40 | init_weights: bool = False): 41 | """Construct a PositionwiseFeedForward object.""" 42 | super(PositionwiseFeedForward, self).__init__() 43 | self.idim = idim 44 | self.hidden_units = hidden_units 45 | self.w_1 = torch.nn.Linear(idim, hidden_units) 46 | self.activation = activation 47 | self.dropout = torch.nn.Dropout(dropout_rate) 48 | self.w_2 = torch.nn.Linear(hidden_units, idim) 49 | self.ada_scale = None 50 | self.ada_bias = None 51 | self.adaptive_scale = adaptive_scale 52 | self.ada_scale = torch.nn.Parameter(torch.ones([1, 1, idim]), 53 | requires_grad=adaptive_scale) 54 | self.ada_bias = torch.nn.Parameter(torch.zeros([1, 1, idim]), 55 | requires_grad=adaptive_scale) 56 | if init_weights: 57 | self.init_weights() 58 | 59 | def init_weights(self): 60 | ffn1_max = self.idim**-0.5 61 | ffn2_max = self.hidden_units**-0.5 62 | torch.nn.init.uniform_(self.w_1.weight.data, -ffn1_max, ffn1_max) 63 | torch.nn.init.uniform_(self.w_1.bias.data, -ffn1_max, ffn1_max) 64 | torch.nn.init.uniform_(self.w_2.weight.data, -ffn2_max, ffn2_max) 65 | torch.nn.init.uniform_(self.w_2.bias.data, -ffn2_max, ffn2_max) 66 | 67 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 68 | """Forward function. 69 | 70 | Args: 71 | xs: input tensor (B, L, D) 72 | Returns: 73 | output tensor, (B, L, D) 74 | """ 75 | if self.adaptive_scale: 76 | xs = self.ada_scale * xs + self.ada_bias 77 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 78 | -------------------------------------------------------------------------------- /asr/wenet/cli/transcribe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Binbin Zhang (binbzha@qq.com) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
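# Example invocation (assuming the package is installed so this module is
# importable; the audio path is a placeholder):
#
#   python -m wenet.cli.transcribe path/to/audio.wav --language english \
#       --beam 5 --show_tokens_info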
14 | 15 | import argparse 16 | 17 | from wenet.cli.paraformer_model import load_model as load_paraformer 18 | from wenet.cli.model import load_model 19 | 20 | 21 | def get_args(): 22 | parser = argparse.ArgumentParser(description='') 23 | parser.add_argument('audio_file', help='audio file to transcribe') 24 | parser.add_argument('-l', 25 | '--language', 26 | choices=[ 27 | 'chinese', 28 | 'english', 29 | ], 30 | default='chinese', 31 | help='language type') 32 | parser.add_argument('-m', 33 | '--model_dir', 34 | default=None, 35 | help='specify your own model dir') 36 | parser.add_argument('-g', 37 | '--gpu', 38 | type=int, 39 | default='-1', 40 | help='gpu id to decode, default is cpu.') 41 | parser.add_argument('-t', 42 | '--show_tokens_info', 43 | action='store_true', 44 | help='whether to output token(word) level information' 45 | ', such times/confidence') 46 | parser.add_argument('--align', 47 | action='store_true', 48 | help='force align the input audio and transcript') 49 | parser.add_argument('--label', type=str, help='the input label to align') 50 | parser.add_argument('--paraformer', 51 | action='store_true', 52 | help='whether to use the best chinese model') 53 | parser.add_argument('--beam', type=int, default=5, help="beam size") 54 | parser.add_argument('--context_path', 55 | type=str, 56 | default=None, 57 | help='context list file') 58 | parser.add_argument('--context_score', 59 | type=float, 60 | default=6.0, 61 | help='context score') 62 | args = parser.parse_args() 63 | return args 64 | 65 | 66 | def main(): 67 | args = get_args() 68 | 69 | if args.paraformer: 70 | model = load_paraformer(args.model_dir, args.gpu) 71 | else: 72 | model = load_model(args.language, args.model_dir, args.gpu, args.beam, 73 | args.context_path, args.context_score) 74 | if args.align: 75 | result = model.align(args.audio_file, args.label) 76 | else: 77 | result = model.transcribe(args.audio_file, args.show_tokens_info) 78 | print(result) 79 | 80 | 81 | if __name__ == "__main__": 82 | main() 83 | -------------------------------------------------------------------------------- /asr/wenet/cli/paraformer_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torchaudio 5 | import torchaudio.compliance.kaldi as kaldi 6 | 7 | from wenet.cli.hub import Hub 8 | from wenet.paraformer.search import (gen_timestamps_from_peak, 9 | paraformer_greedy_search) 10 | from wenet.text.paraformer_tokenizer import ParaformerTokenizer 11 | 12 | 13 | class Paraformer: 14 | 15 | def __init__(self, 16 | model_dir: str, 17 | device: int = -1, 18 | resample_rate: int = 16000) -> None: 19 | 20 | model_path = os.path.join(model_dir, 'final.zip') 21 | units_path = os.path.join(model_dir, 'units.txt') 22 | self.model = torch.jit.load(model_path) 23 | self.resample_rate = resample_rate 24 | if device >= 0: 25 | device = 'cuda:{}'.format(device) 26 | else: 27 | device = 'cpu' 28 | self.device = torch.device(device) 29 | self.model = self.model.to(self.device) 30 | self.tokenizer = ParaformerTokenizer(symbol_table=units_path) 31 | 32 | def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict: 33 | waveform, sample_rate = torchaudio.load(audio_file, normalize=False) 34 | waveform = waveform.to(torch.float).to(self.device) 35 | if sample_rate != self.resample_rate: 36 | waveform = torchaudio.transforms.Resample( 37 | orig_freq=sample_rate, new_freq=self.resample_rate)(waveform) 38 | feats = kaldi.fbank(waveform, 39 | num_mel_bins=80, 
40 | frame_length=25, 41 | frame_shift=10, 42 | energy_floor=0.0, 43 | sample_frequency=self.resample_rate) 44 | feats = feats.unsqueeze(0) 45 | feats_lens = torch.tensor([feats.size(1)], 46 | dtype=torch.int64, 47 | device=feats.device) 48 | 49 | decoder_out, token_num, tp_alphas = self.model.forward_paraformer( 50 | feats, feats_lens) 51 | cif_peaks = self.model.forward_cif_peaks(tp_alphas, token_num) 52 | res = paraformer_greedy_search(decoder_out, token_num, cif_peaks)[0] 53 | result = {} 54 | result['confidence'] = res.confidence 55 | result['text'] = self.tokenizer.detokenize(res.tokens)[0] 56 | if tokens_info: 57 | tokens_info = [] 58 | times = gen_timestamps_from_peak(res.times, 59 | num_frames=tp_alphas.size(1), 60 | frame_rate=0.02) 61 | 62 | for i, x in enumerate(res.tokens): 63 | tokens_info.append({ 64 | 'token': self.tokenizer.char_dict[x], 65 | 'start': times[i][0], 66 | 'end': times[i][1], 67 | 'confidence': res.tokens_confidence[i] 68 | }) 69 | result['tokens'] = tokens_info 70 | 71 | return result 72 | 73 | def align(self, audio_file: str, label: str) -> dict: 74 | raise NotImplementedError("Align is currently not supported") 75 | 76 | 77 | def load_model(model_dir: str = None, gpu: int = -1) -> Paraformer: 78 | if model_dir is None: 79 | model_dir = Hub.get_model_by_lang('paraformer') 80 | return Paraformer(model_dir, gpu) 81 | -------------------------------------------------------------------------------- /diarization/assign_words2speakers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2024 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import argparse 9 | import csv 10 | from collections import defaultdict 11 | 12 | from intervaltree import IntervalTree, Interval 13 | 14 | from pyannote.database.util import load_rttm 15 | 16 | 17 | def read_ctm(ctm_path): 18 | with open(ctm_path, 'r') as f: 19 | csv_file = csv.reader(f, delimiter=' ') 20 | for row in csv_file: 21 | yield row 22 | 23 | 24 | def speaker_for_segment(start: float, 25 | dur: float, 26 | tree: IntervalTree) -> str: 27 | """Given a start and duration in seconds, and an interval tree representing 28 | speaker segments, return what speaker is speaking. 29 | 30 | If there are overlapping speakers, return the speaker who spoke most of the 31 | time. If there are no speakers, return the nearest one. 32 | 33 | The interval tree could represent reference or hypothesis. 34 | The data inside the interval tree should be the speaker label. 
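    Example: with a tree holding Interval(0, 5, 'A') and Interval(4, 8, 'B'),
    a word with start=4.5 and dur=1.0 overlaps both speakers; 'B' covers 1.0 s
    of it versus 0.5 s for 'A', so 'B' is returned.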
35 | """ 36 | intervals = tree[start:start + dur] 37 | 38 | # Easy case, only one possible interval 39 | if len(intervals) == 1: 40 | return intervals.pop().data 41 | 42 | # First special case, no match 43 | # so we need to find the nearest interval 44 | elif len(intervals) == 0: 45 | seg = Interval(start, start + dur) 46 | distances = {interval: seg.distance_to(interval) 47 | for interval in tree} 48 | if not distances: 49 | return "" 50 | return min(distances, key=distances.get).data 51 | 52 | # Second special case, overlapping speakers 53 | # so we return whichever speaker has majority 54 | else: 55 | seg = Interval(start, start + dur) 56 | overlap_sizes = defaultdict(int) 57 | for interval in intervals: 58 | i0 = max(seg[0], interval[0]) 59 | i1 = min(seg[1], interval[1]) 60 | overlap_sizes[interval.data] += i1 - i0 61 | return max(overlap_sizes, key=overlap_sizes.get) 62 | 63 | 64 | if __name__ == '__main__': 65 | parser = argparse.ArgumentParser('Assign words to speakers based on a diarization rttm file and ctm transcription') 66 | parser.add_argument('diarization_rttm', help='diarization rttm file') 67 | parser.add_argument('ctm_transcription', help='ctm transcription file') 68 | """ Read more about stm format here (we can't store speaker identities in ctm) """ 69 | """ https://www.nist.gov/system/files/documents/2021/08/31/OpenASR21_EvalPlan_v1_3_1.pdf """ 70 | parser.add_argument('output_stm_transcription', help='output file in .stm format as described above') 71 | 72 | args = parser.parse_args() 73 | 74 | ctm = read_ctm(args.ctm_transcription) 75 | rttm = load_rttm(args.diarization_rttm) 76 | dict_key = list(rttm.keys()) 77 | assert len(dict_key) == 1, dict_key 78 | rttm = rttm[dict_key[0]] 79 | 80 | hypothesis_spkr_tree = IntervalTree(Interval(segment.start, segment.end, label) 81 | for segment, _, label in rttm.itertracks(yield_label=True)) 82 | 83 | with open(args.output_stm_transcription, 'w') as f: 84 | for _, channel, start, dur, token, _ in ctm: 85 | start, dur = float(start), float(dur) 86 | hyp_speaker = speaker_for_segment(float(start), float(dur), hypothesis_spkr_tree) 87 | f.write(f'{dict_key[0]} 1 {hyp_speaker} {start:.3f} {(start + dur):.3f} {token}\n') 88 | -------------------------------------------------------------------------------- /asr/wenet/whisper/whisper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Wenet Community. (authors: Xingchen Song) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # Modified from [Whisper](https://github.com/openai/whisper) 16 | 17 | import torch 18 | 19 | from typing import Tuple, Dict, List 20 | 21 | from wenet.transformer.asr_model import ASRModel 22 | from wenet.transformer.ctc import CTC 23 | from wenet.transformer.encoder import TransformerEncoder 24 | from wenet.transformer.decoder import TransformerDecoder 25 | from wenet.utils.common import IGNORE_ID, add_whisper_tokens, th_accuracy 26 | 27 | 28 | class Whisper(ASRModel): 29 | 30 | def __init__( 31 | self, 32 | vocab_size: int, 33 | encoder: TransformerEncoder, 34 | decoder: TransformerDecoder, 35 | ctc: CTC = None, 36 | ctc_weight: float = 0.5, 37 | ignore_id: int = IGNORE_ID, 38 | reverse_weight: float = 0.0, 39 | lsm_weight: float = 0.0, 40 | length_normalized_loss: bool = False, 41 | special_tokens: dict = None, 42 | ): 43 | super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight, 44 | ignore_id, reverse_weight, lsm_weight, 45 | length_normalized_loss, special_tokens) 46 | assert reverse_weight == 0.0 47 | self.sos = special_tokens["sot"] 48 | self.eos = special_tokens["eot"] 49 | 50 | # TODO(xcsong): time align 51 | def set_alignment_heads(self, dump: bytes): 52 | raise NotImplementedError 53 | 54 | @property 55 | def is_multilingual(self): 56 | return self.vocab_size >= 51865 57 | 58 | @property 59 | def num_languages(self): 60 | return self.vocab_size - 51765 - int(self.is_multilingual) 61 | 62 | def _calc_att_loss( 63 | self, 64 | encoder_out: torch.Tensor, 65 | encoder_mask: torch.Tensor, 66 | ys_pad: torch.Tensor, 67 | ys_pad_lens: torch.Tensor, 68 | infos: Dict[str, List[str]], 69 | ) -> Tuple[torch.Tensor, float]: 70 | prev_len = ys_pad.size(1) 71 | ys_in_pad, ys_out_pad = add_whisper_tokens(self.special_tokens, 72 | ys_pad, 73 | self.ignore_id, 74 | tasks=infos['tasks'], 75 | no_timestamp=True, 76 | langs=infos['langs'], 77 | use_prev=False) 78 | cur_len = ys_in_pad.size(1) 79 | ys_in_lens = ys_pad_lens + cur_len - prev_len 80 | 81 | # 1. Forward decoder 82 | decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask, 83 | ys_in_pad, ys_in_lens) 84 | 85 | # 2. 
Compute attention loss 86 | loss_att = self.criterion_att(decoder_out, ys_out_pad) 87 | acc_att = th_accuracy( 88 | decoder_out.view(-1, self.vocab_size), 89 | ys_out_pad, 90 | ignore_label=self.ignore_id, 91 | ) 92 | return loss_att, acc_att 93 | -------------------------------------------------------------------------------- /asr/wenet/bin/export_ipex.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023 Intel Corporation 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from __future__ import print_function 5 | 6 | import argparse 7 | import logging 8 | import os 9 | 10 | import torch 11 | import yaml 12 | 13 | from wenet.utils.init_model import init_model 14 | import intel_extension_for_pytorch as ipex 15 | from intel_extension_for_pytorch.quantization import prepare, convert 16 | 17 | 18 | def get_args(): 19 | parser = argparse.ArgumentParser(description='export your script model') 20 | parser.add_argument('--config', required=True, help='config file') 21 | parser.add_argument('--checkpoint', required=True, help='checkpoint model') 22 | parser.add_argument('--output_file', default=None, help='output file') 23 | parser.add_argument('--dtype', 24 | default="fp32", 25 | help='choose the dtype to run:[fp32,bf16]') 26 | parser.add_argument('--output_quant_file', 27 | default=None, 28 | help='output quantized model file') 29 | args = parser.parse_args() 30 | return args 31 | 32 | 33 | def scripting(model): 34 | with torch.inference_mode(): 35 | script_model = torch.jit.script(model) 36 | script_model = torch.jit.freeze( 37 | script_model, 38 | preserved_attrs=[ 39 | "forward_encoder_chunk", "ctc_activation", 40 | "forward_attention_decoder", "subsampling_rate", 41 | "right_context", "sos_symbol", "eos_symbol", 42 | "is_bidirectional_decoder" 43 | ]) 44 | return script_model 45 | 46 | 47 | def main(): 48 | args = get_args() 49 | logging.basicConfig(level=logging.DEBUG, 50 | format='%(asctime)s %(levelname)s %(message)s') 51 | # No need gpu for model export 52 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 53 | 54 | with open(args.config, 'r') as fin: 55 | configs = yaml.load(fin, Loader=yaml.FullLoader) 56 | model, configs = init_model(args, configs) 57 | print(model) 58 | 59 | # Apply IPEX optimization 60 | model.eval() 61 | torch._C._jit_set_texpr_fuser_enabled(False) 62 | model.to(memory_format=torch.channels_last) 63 | if args.dtype == "fp32": 64 | ipex_model = ipex.optimize(model) 65 | elif args.dtype == "bf16": # For Intel 4th generation Xeon (SPR) 66 | ipex_model = ipex.optimize(model, 67 | dtype=torch.bfloat16, 68 | weights_prepack=False) 69 | 70 | # Export jit torch script model 71 | if args.output_file: 72 | if args.dtype == "fp32": 73 | script_model = scripting(ipex_model) 74 | elif args.dtype == "bf16": 75 | torch._C._jit_set_autocast_mode(True) 76 | with torch.cpu.amp.autocast(): 77 | script_model = scripting(ipex_model) 78 | script_model.save(args.output_file) 79 | print('Export model successfully, see {}'.format(args.output_file)) 80 | 81 | # Export quantized jit torch script model 82 | if args.output_quant_file: 83 | dynamic_qconfig = ipex.quantization.default_dynamic_qconfig 84 | dummy_data = (torch.zeros(1, 67, 80), 16, -16, 85 | torch.zeros(12, 4, 32, 128), torch.zeros(12, 1, 256, 7)) 86 | model = prepare(model, dynamic_qconfig, dummy_data) 87 | model = convert(model) 88 | script_quant_model = scripting(model) 89 | script_quant_model.save(args.output_quant_file) 90 | print('Export quantized model successfully, ' 91 
| 'see {}'.format(args.output_quant_file)) 92 | 93 | 94 | if __name__ == '__main__': 95 | main() 96 | -------------------------------------------------------------------------------- /asr/wenet/transformer/label_smoothing_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Label smoothing module.""" 16 | 17 | import torch 18 | from torch import nn 19 | 20 | 21 | class LabelSmoothingLoss(nn.Module): 22 | """Label-smoothing loss. 23 | 24 | In a standard CE loss, the label's data distribution is: 25 | [0,1,2] -> 26 | [ 27 | [1.0, 0.0, 0.0], 28 | [0.0, 1.0, 0.0], 29 | [0.0, 0.0, 1.0], 30 | ] 31 | 32 | In the smoothing version CE Loss,some probabilities 33 | are taken from the true label prob (1.0) and are divided 34 | among other labels. 35 | 36 | e.g. 37 | smoothing=0.1 38 | [0,1,2] -> 39 | [ 40 | [0.9, 0.05, 0.05], 41 | [0.05, 0.9, 0.05], 42 | [0.05, 0.05, 0.9], 43 | ] 44 | 45 | Args: 46 | size (int): the number of class 47 | padding_idx (int): padding class id which will be ignored for loss 48 | smoothing (float): smoothing rate (0.0 means the conventional CE) 49 | normalize_length (bool): 50 | normalize loss by sequence length if True 51 | normalize loss by batch size if False 52 | """ 53 | 54 | def __init__(self, 55 | size: int, 56 | padding_idx: int, 57 | smoothing: float, 58 | normalize_length: bool = False): 59 | """Construct an LabelSmoothingLoss object.""" 60 | super(LabelSmoothingLoss, self).__init__() 61 | self.criterion = nn.KLDivLoss(reduction="none") 62 | self.padding_idx = padding_idx 63 | self.confidence = 1.0 - smoothing 64 | self.smoothing = smoothing 65 | self.size = size 66 | self.normalize_length = normalize_length 67 | 68 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 69 | """Compute loss between x and target. 70 | 71 | The model outputs and data labels tensors are flatten to 72 | (batch*seqlen, class) shape and a mask is applied to the 73 | padding part which should not be calculated for loss. 
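        Concretely, with smoothing rate e and vocabulary size K the target
        distribution places (1 - e) on the gold label and e / (K - 1) on every
        other label; the loss is the KL divergence between that distribution
        and log_softmax(x), normalized by the number of non-padding tokens when
        ``normalize_length`` is True and by the batch size otherwise.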
74 | 75 | Args: 76 | x (torch.Tensor): prediction (batch, seqlen, class) 77 | target (torch.Tensor): 78 | target signal masked with self.padding_id (batch, seqlen) 79 | Returns: 80 | loss (torch.Tensor) : The KL loss, scalar float value 81 | """ 82 | assert x.size(2) == self.size 83 | batch_size = x.size(0) 84 | x = x.view(-1, self.size) 85 | target = target.view(-1) 86 | # use zeros_like instead of torch.no_grad() for true_dist, 87 | # since no_grad() can not be exported by JIT 88 | true_dist = torch.zeros_like(x) 89 | true_dist.fill_(self.smoothing / (self.size - 1)) 90 | ignore = target == self.padding_idx # (B,) 91 | total = len(target) - ignore.sum().item() 92 | target = target.masked_fill(ignore, 0) # avoid -1 index 93 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence) 94 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) 95 | denom = total if self.normalize_length else batch_size 96 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 97 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /asr/wenet/text/whisper_tokenizer.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from typing import Dict, List, Optional, Tuple, Union 3 | from wenet.text.base_tokenizer import BaseTokenizer 4 | 5 | from wenet.utils.file_utils import read_non_lang_symbols 6 | 7 | 8 | class WhisperTokenizer(BaseTokenizer): 9 | 10 | def __init__( 11 | self, 12 | multilingual: bool, 13 | num_languages: int = 99, 14 | language: Optional[str] = None, 15 | task: Optional[str] = None, 16 | non_lang_syms: Optional[Union[str, PathLike, List]] = None, 17 | *args, 18 | **kwargs, 19 | ) -> None: 20 | # NOTE(Mddct): don't build here, pickle issues 21 | self.tokenizer = None 22 | # TODO: we don't need this in future 23 | self.multilingual = multilingual 24 | self.num_languages = num_languages 25 | self.language = language 26 | self.task = task 27 | 28 | if not isinstance(non_lang_syms, List): 29 | self.non_lang_syms = read_non_lang_symbols(non_lang_syms) 30 | else: 31 | # non_lang_syms=["{NOISE}"] 32 | self.non_lang_syms = non_lang_syms 33 | # TODO(Mddct): add special tokens, like non_lang_syms 34 | del self.non_lang_syms 35 | 36 | def __getstate__(self): 37 | state = self.__dict__.copy() 38 | del state['tokenizer'] 39 | return state 40 | 41 | def __setstate__(self, state): 42 | self.__dict__.update(state) 43 | recovery = {'tokenizer': None} 44 | self.__dict__.update(recovery) 45 | 46 | def _build_tiktoken(self): 47 | if self.tokenizer is None: 48 | from whisper.tokenizer import get_tokenizer 49 | self.tokenizer = get_tokenizer(multilingual=self.multilingual, 50 | num_languages=self.num_languages, 51 | language=self.language, 52 | task=self.task) 53 | self.t2i = {} 54 | self.i2t = {} 55 | for i in 
range(self.tokenizer.encoding.n_vocab): 56 | unit = str( 57 | self.tokenizer.encoding.decode_single_token_bytes(i)) 58 | if len(unit) == 0: 59 | unit = str(i) 60 | unit = unit.replace(" ", "") 61 | # unit = bytes(unit, 'utf-8') 62 | self.t2i[unit] = i 63 | self.i2t[i] = unit 64 | assert len(self.t2i) == len(self.i2t) 65 | 66 | def tokenize(self, line: str) -> Tuple[List[str], List[int]]: 67 | self._build_tiktoken() 68 | ids = self.tokenizer.encoding.encode(line) 69 | text = [self.i2t[d] for d in ids] 70 | return text, ids 71 | 72 | def detokenize(self, ids: List[int]) -> Tuple[str, List[str]]: 73 | self._build_tiktoken() 74 | tokens = [self.i2t[d] for d in ids] 75 | text = self.tokenizer.encoding.decode(ids) 76 | return text, tokens 77 | 78 | def text2tokens(self, line: str) -> List[str]: 79 | self._build_tiktoken() 80 | return self.tokenize(line)[0] 81 | 82 | def tokens2text(self, tokens: List[str]) -> str: 83 | self._build_tiktoken() 84 | ids = [self.t2i[t] for t in tokens] 85 | return self.detokenize(ids)[0] 86 | 87 | def tokens2ids(self, tokens: List[str]) -> List[int]: 88 | self._build_tiktoken() 89 | ids = [self.t2i[t] for t in tokens] 90 | return ids 91 | 92 | def ids2tokens(self, ids: List[int]) -> List[str]: 93 | self._build_tiktoken() 94 | return [self.tokenizer.encoding.decode([id]) for id in ids] 95 | 96 | def vocab_size(self) -> int: 97 | self._build_tiktoken() 98 | return len(self.t2i) 99 | 100 | @property 101 | def symbol_table(self) -> Dict[str, int]: 102 | self._build_tiktoken() 103 | return self.t2i 104 | -------------------------------------------------------------------------------- /asr/wenet/cli/hub.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Mddct(hamddct@gmail.com) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
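# Typical use (see the CLI model loaders, e.g. wenet/cli/paraformer_model.py):
# resolve a cached model directory for a supported key from Hub.Assets,
# downloading and unpacking the archive on first use.
#
#   from wenet.cli.hub import Hub
#   model_dir = Hub.get_model_by_lang("english")   # -> ~/.wenet/english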
14 | 15 | import os 16 | import requests 17 | import sys 18 | import tarfile 19 | from pathlib import Path 20 | from urllib.request import urlretrieve 21 | 22 | import tqdm 23 | 24 | 25 | def download(url: str, dest: str, only_child=True): 26 | """ download from url to dest 27 | """ 28 | assert os.path.exists(dest) 29 | print('Downloading {} to {}'.format(url, dest)) 30 | 31 | def progress_hook(t): 32 | last_b = [0] 33 | 34 | def update_to(b=1, bsize=1, tsize=None): 35 | if tsize not in (None, -1): 36 | t.total = tsize 37 | displayed = t.update((b - last_b[0]) * bsize) 38 | last_b[0] = b 39 | return displayed 40 | 41 | return update_to 42 | 43 | # *.tar.gz 44 | name = url.split('?')[0].split('/')[-1] 45 | tar_path = os.path.join(dest, name) 46 | with tqdm.tqdm(unit='B', 47 | unit_scale=True, 48 | unit_divisor=1024, 49 | miniters=1, 50 | desc=(name)) as t: 51 | urlretrieve(url, 52 | filename=tar_path, 53 | reporthook=progress_hook(t), 54 | data=None) 55 | t.total = t.n 56 | 57 | with tarfile.open(tar_path) as f: 58 | if not only_child: 59 | f.extractall(dest) 60 | else: 61 | for tarinfo in f: 62 | if "/" not in tarinfo.name: 63 | continue 64 | name = os.path.basename(tarinfo.name) 65 | fileobj = f.extractfile(tarinfo) 66 | with open(os.path.join(dest, name), "wb") as writer: 67 | writer.write(fileobj.read()) 68 | 69 | 70 | class Hub(object): 71 | """Hub for wenet pretrain runtime model 72 | """ 73 | # TODO(Mddct): make assets class to support other language 74 | Assets = { 75 | # wenetspeech 76 | "chinese": "wenetspeech_u2pp_conformer_libtorch.tar.gz", 77 | # gigaspeech 78 | "english": "gigaspeech_u2pp_conformer_libtorch.tar.gz", 79 | # paraformer 80 | "paraformer": "paraformer.tar.gz" 81 | } 82 | 83 | def __init__(self) -> None: 84 | pass 85 | 86 | @staticmethod 87 | def get_model_by_lang(lang: str) -> str: 88 | if lang not in Hub.Assets.keys(): 89 | print('ERROR: Unsupported language {} !!!'.format(lang)) 90 | sys.exit(1) 91 | 92 | # NOTE(Mddct): model_dir structure 93 | # Path.Home()/.wenet 94 | # - chs 95 | # - units.txt 96 | # - final.zip 97 | # - en 98 | # - units.txt 99 | # - final.zip 100 | model = Hub.Assets[lang] 101 | model_dir = os.path.join(Path.home(), ".wenet", lang) 102 | if not os.path.exists(model_dir): 103 | os.makedirs(model_dir) 104 | # TODO(Mddct): model metadata 105 | if set(["final.zip", 106 | "units.txt"]).issubset(set(os.listdir(model_dir))): 107 | return model_dir 108 | # If not exist, download 109 | response = requests.get( 110 | "https://modelscope.cn/api/v1/datasets/wenet/wenet_pretrained_models/oss/tree" # noqa 111 | ) 112 | model_info = next(data for data in response.json()["Data"] 113 | if data["Key"] == model) 114 | model_url = model_info['Url'] 115 | download(model_url, model_dir, only_child=True) 116 | return model_dir 117 | -------------------------------------------------------------------------------- /asr/wenet/transducer_espnet/bitransducer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | # from typeguard import check_argument_types 4 | 5 | # from wenet.transformer.transducer.utils import get_transducer_task_io 6 | from wenet.transducer_espnet.abs_decoder import AbsDecoder 7 | from wenet.transducer_espnet.transducer import Transducer 8 | 9 | from wenet.utils.common import reverse_pad_list 10 | from torch.nn.utils.rnn import pad_sequence 11 | 12 | class BiTransducer(torch.nn.Module): 13 | """Bidirectional Transducer module""" 14 | def __init__( 15 | self, 16 | 
joint_network: torch.nn.Module, 17 | transducer_decoder: AbsDecoder, 18 | joint_network_r: torch.nn.Module, 19 | transducer_decoder_r: AbsDecoder, 20 | ignore_id: int, 21 | trans_type: str 22 | ): 23 | """ Construct CTC module 24 | Args: 25 | odim: dimension of outputs 26 | encoder_output_size: number of encoder projection units 27 | dropout_rate: dropout rate (0.0 ~ 1.0) 28 | reduce: reduce the CTC loss into a scalar 29 | """ 30 | # assert check_argument_types() 31 | super().__init__() 32 | 33 | # from warprnnt_pytorch import RNNTLoss 34 | 35 | self.transducer_decoder = transducer_decoder 36 | self.joint_network = joint_network 37 | self.transducer_decoder_r = transducer_decoder_r 38 | self.joint_network_r = joint_network_r 39 | 40 | self.blank_id = 0 41 | self.ignore_id = ignore_id 42 | self.trans_type = trans_type 43 | 44 | self.transducer = Transducer(joint_network, transducer_decoder, ignore_id, trans_type) 45 | self.transducer_r = Transducer(joint_network_r, transducer_decoder_r, ignore_id, trans_type) 46 | 47 | def reverse_features_pad_list(self, x_pad: torch.Tensor, 48 | x_lengths: torch.Tensor, 49 | pad_value: float = -1.0) -> torch.Tensor: 50 | """Reverse padding for the list of tensors. 51 | 52 | Args: 53 | ys_pad (tensor): The padded tensor (B, Tokenmax). 54 | ys_lens (tensor): The lens of token seqs (B) 55 | pad_value (int): Value for padding. 56 | 57 | Returns: 58 | Tensor: Padded tensor (B, Tokenmax). 59 | 60 | Examples: 61 | >>> x 62 | tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]]) 63 | >>> pad_list(x, 0) 64 | tensor([[4, 3, 2, 1], 65 | [7, 6, 5, 0], 66 | [9, 8, 0, 0]]) 67 | 68 | """ 69 | 70 | x_pad_reverse = pad_sequence([(torch.flip(x[:i], [0])) 71 | for x, i in zip(x_pad, x_lengths)], True, 72 | pad_value) 73 | return x_pad_reverse 74 | 75 | def forward( 76 | self, 77 | encoder_out: torch.Tensor, 78 | encoder_out_lens: torch.Tensor, 79 | labels: torch.Tensor, 80 | labels_lengths: torch.Tensor, 81 | ): 82 | """Compute Transducer loss. 83 | 84 | Args: 85 | encoder_out: Encoder output sequences. (B, T, D_enc) 86 | encoder_out_lens: Encoder output sequences lengths. (B,) 87 | labels: Label ID sequences. (B, L) 88 | 89 | Return: 90 | loss_transducer: Transducer loss value. 91 | 92 | """ 93 | 94 | encoder_out_r = self.reverse_features_pad_list(encoder_out, encoder_out_lens, 0.0) 95 | labels_r = reverse_pad_list(labels, labels_lengths, float(self.ignore_id)) 96 | # print(labels.size()) 97 | # print(labels[1]) 98 | # print(labels_r[1]) 99 | # print(encoder_out.size()) 100 | # print(encoder_out_lens) 101 | # print(encoder_out[0,394,:5]) 102 | # print(encoder_out[2,394,:5]) 103 | # print(encoder_out_r.size()) 104 | 105 | loss_transducer_l = self.transducer(encoder_out, encoder_out_lens, labels, labels_lengths) 106 | # reverse encoder and lables? 107 | loss_transducer_r = self.transducer_r(encoder_out_r, encoder_out_lens, labels_r, labels_lengths) 108 | loss_transducer = 0.7*loss_transducer_l + 0.3*loss_transducer_r 109 | 110 | return loss_transducer 111 | -------------------------------------------------------------------------------- /asr/wenet/onmt_translate/penalties.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/translate/beam_search.py 2 | # MIT Licence 3 | import torch 4 | 5 | from typing import List, Optional, Tuple, Dict, Any 6 | 7 | class PenaltyBuilder(object): 8 | """Returns the Length and Coverage Penalty function for Beam Search. 
9 | 10 | Args: 11 | length_pen (str): option name of length pen 12 | cov_pen (str): option name of cov pen 13 | 14 | Attributes: 15 | has_cov_pen (bool): Whether coverage penalty is None (applying it 16 | is a no-op). Note that the converse isn't true. Setting beta 17 | to 0 should force coverage length to be a no-op. 18 | has_len_pen (bool): Whether length penalty is None (applying it 19 | is a no-op). Note that the converse isn't true. Setting alpha 20 | to 1 should force length penalty to be a no-op. 21 | coverage_penalty (callable[[FloatTensor, float], FloatTensor]): 22 | Calculates the coverage penalty. 23 | length_penalty (callable[[int, float], float]): Calculates 24 | the length penalty. 25 | """ 26 | 27 | def __init__(self, cov_pen: Optional[str], length_pen: Optional[str]): 28 | self.has_cov_pen = not self._pen_is_none(cov_pen) 29 | self.coverage_penalty = cov_pen #self._coverage_penalty(cov_pen) 30 | self.has_len_pen = not self._pen_is_none(length_pen) 31 | self.length_penalty = length_pen #self._length_penalty(length_pen) 32 | 33 | #@staticmethod 34 | def _pen_is_none(self, pen: Optional[str]): 35 | return pen is None or pen == "none" 36 | 37 | """ 38 | def _coverage_penalty(self, cov_pen: Optional[str]): 39 | if self._pen_is_none(cov_pen): 40 | return self.coverage_none 41 | elif cov_pen == "wu": 42 | return self.coverage_wu 43 | elif cov_pen == "summary": 44 | return self.coverage_summary 45 | else: 46 | raise NotImplementedError("No '{:s}' coverage penalty.".format(cov_pen)) 47 | 48 | def _length_penalty(self, length_pen: Optional[str]): 49 | if self._pen_is_none(length_pen): 50 | return self.length_none 51 | elif length_pen == "wu": 52 | return self.length_wu 53 | elif length_pen == "avg": 54 | return self.length_average 55 | else: 56 | raise NotImplementedError("No '{:s}' length penalty.".format(length_pen)) 57 | """ 58 | 59 | # Below are all the different penalty terms implemented so far. 60 | # Subtract coverage penalty from topk log probs. 61 | # Divide topk log probs by length penalty. 62 | 63 | def coverage_wu(self, cov, beta: float=0.0): 64 | """GNMT coverage re-ranking score. 65 | 66 | See "Google's Neural Machine Translation System" :cite:`wu2016google`. 67 | ``cov`` is expected to be sized ``(*, seq_len)``, where ``*`` is 68 | probably ``batch_size x beam_size`` but could be several 69 | dimensions like ``(batch_size, beam_size)``. If ``cov`` is attention, 70 | then the ``seq_len`` axis probably sums to (almost) 1. 71 | """ 72 | 73 | penalty = -torch.min(cov, cov.clone().fill_(1.0)).log().sum(-1) 74 | return beta * penalty 75 | 76 | def coverage_summary(self, cov, beta:float=0.0): 77 | """Our summary penalty.""" 78 | penalty = torch.max(cov, cov.clone().fill_(1.0)).sum(-1) 79 | penalty -= cov.size(-1) 80 | return beta * penalty 81 | 82 | def coverage_none(self, cov, beta:float=0.0): 83 | """Returns zero as penalty""" 84 | none = torch.zeros((1,), device=cov.device, dtype=torch.float) 85 | if cov.dim() == 3: 86 | none = torch.zeros((cov.shape[0], 1,), device=cov.device, dtype=torch.float) 87 | return none 88 | 89 | def length_wu(self, cur_len: int, alpha:float=0.0): 90 | """GNMT length re-ranking score. 91 | 92 | See "Google's Neural Machine Translation System" :cite:`wu2016google`. 
93 | """ 94 | 95 | return ((5 + cur_len) / 6.0) ** alpha 96 | 97 | def length_average(self, cur_len: int, alpha:float=1.0): 98 | """Returns the current sequence length.""" 99 | return cur_len**alpha 100 | 101 | def length_none(self, cur_len: int, alpha:float=0.0): 102 | """Returns unmodified scores.""" 103 | return 1.0 104 | -------------------------------------------------------------------------------- /asr/wer_evaluation/aggregate_scoring.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | from argparse import ArgumentParser 6 | from dataclasses import dataclass 7 | import json 8 | from pathlib import Path 9 | from typing import Dict 10 | 11 | 12 | def init_args(): 13 | parser = ArgumentParser( 14 | description="Takes in directory of fstalign outputs and calculates " 15 | " the aggregate WER metric over the full test suite." 16 | ) 17 | parser.add_argument( 18 | 'fstalign_out', type=Path, 19 | help="Path to an fstalign alignment output directory. This script" 20 | " relies specifically on the output from setting the --log-json" 21 | " flag in fstalign." 22 | ) 23 | return parser.parse_args() 24 | 25 | 26 | @dataclass 27 | class WERAggregator: 28 | insertion_count: int = 0 29 | deletion_count: int = 0 30 | substitution_count: int = 0 31 | correct_count: int = 0 32 | reference_count: int = 0 33 | 34 | def update(self, alignment_dict: Dict[str, float]): 35 | """Given a dictionary with alignment statistics from fstalign, 36 | update the corresponding counts. 37 | """ 38 | self.insertion_count += alignment_dict['insertions'] 39 | self.deletion_count += alignment_dict['deletions'] 40 | self.substitution_count += (alignment_dict['numErrors'] - alignment_dict['insertions'] - alignment_dict['deletions']) 41 | self.correct_count += (alignment_dict['numWordsInReference'] - alignment_dict['substitutions'] - alignment_dict['deletions']) 42 | self.reference_count += alignment_dict['numWordsInReference'] 43 | 44 | @property 45 | def num_errors(self): 46 | """Calculates the total number of errors of the aggregator. 47 | """ 48 | return self.insertion_count + self.deletion_count + self.substitution_count 49 | 50 | def check_state(self): 51 | """Ensures all assumptions are valid prior to calculations. Raises 52 | Raises an error if assumptions are broken. 53 | """ 54 | if self.reference_count == 0: 55 | raise RuntimeError("Something went wrong! Cannot compute a rate when `reference_count` is 0.") 56 | 57 | def insertion_rate(self) -> float: 58 | """Returns a float of the aggregator's insertion rate. 59 | """ 60 | self.check_state() 61 | return self.insertion_count / self.reference_count 62 | 63 | def deletion_rate(self) -> float: 64 | """Returns a float of the aggregator's deletion rate. 65 | """ 66 | self.check_state() 67 | return self.deletion_count / self.reference_count 68 | 69 | def substitution_rate(self) -> float: 70 | """Returns a float of the aggregator's substitution rate. 71 | """ 72 | self.check_state() 73 | return self.substitution_count / self.reference_count 74 | 75 | def wer(self) -> float: 76 | """Returns a float of the aggregator's WER (word error rate). 77 | """ 78 | self.check_state() 79 | return self.num_errors / self.reference_count 80 | 81 | def summary(self) -> str: 82 | """Provides a string summary of all aggregator's state in a formatted string. 
83 | This includes: 84 | * WER 85 | * Insertion Rate 86 | * Deletion Rate 87 | * Substitution Rate 88 | """ 89 | def format_rate(title: str, numerator: int, rate: float) -> str: 90 | """Creates a string to represent a rate given its numerator 91 | and denominator. 92 | """ 93 | return f"{title}:\t{numerator}/{self.reference_count} = {rate:3.2%}" 94 | 95 | summary = [ 96 | format_rate("TOTAL WER", self.num_errors, self.wer()), 97 | format_rate("Insertion Rate", self.insertion_count, self.insertion_rate()), 98 | format_rate("Deletion Rate", self.deletion_count, self.deletion_rate()), 99 | format_rate("Substitution Rate", self.substitution_count, self.substitution_rate()), 100 | ] 101 | return '\n'.join(summary) 102 | 103 | 104 | if __name__ == '__main__': 105 | args = init_args() 106 | 107 | aggregator = WERAggregator() 108 | 109 | for json_path in args.fstalign_out.glob("*.json"): 110 | with json_path.open('r') as jfile: 111 | alignment_results = json.load(jfile) 112 | aggregator.update(alignment_results['wer']['bestWER']) 113 | 114 | print(aggregator.summary()) 115 | -------------------------------------------------------------------------------- /asr/wer_evaluation/scoring_commands.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | from argparse import ArgumentParser 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | 10 | def init_args(): 11 | parser = ArgumentParser( 12 | description="Generates fstalign a list of commands that" 13 | " will perform adequate alignment between a test-suite" 14 | " and hypothesis directory. This script assumes the" 15 | " hypotheses are in CTM format and the references are" 16 | " in NLP format." 17 | ) 18 | parser.add_argument( 19 | 'fstalign', type=Path, 20 | help="Path to the fstalign binary." 21 | ) 22 | parser.add_argument( 23 | 'ref', type=Path, 24 | help="Path to test suite transcript file(s). Pass in either" 25 | " one file or a directory. Files *must* be in NLP format." 26 | ) 27 | parser.add_argument( 28 | 'hyp', type=Path, 29 | help="Path to ASR Hypothesis file(s). Pass in either one file" 30 | " or a directory. *Assumes files are in CTM format*." 31 | ) 32 | parser.add_argument( 33 | 'out', type=Path, 34 | help="Path to output directory to contain fstalign files." 35 | ) 36 | parser.add_argument( 37 | '--ref-norm', type=Path, 38 | default=None, 39 | help="Path to test suite normalization file(s). Pass in either" 40 | " one file or a directory. *In order for fstalign to use" 41 | " normalizations files, references in NLP format with normalization" 42 | " tags are required. Unexpected results may occur if used" 43 | " incorrectly.*" 44 | ) 45 | parser.add_argument( 46 | '--synonyms-file', type=Path, 47 | default=None, 48 | help="Path to fstalign, synonym file." 49 | ) 50 | return parser.parse_args() 51 | 52 | 53 | def prepare_IO( 54 | ref_path: Path, 55 | hyp_path: Path, 56 | out_path: Path, 57 | ref_norm_path: Optional[Path]=None, 58 | ): 59 | """Determines if the hyp_path provided is a directory or a file 60 | and appropriately identifies paths for fstalign files. Creates 61 | a Generator to return the files. 62 | 63 | out_path will always be made into a directory 64 | 65 | When hyp_path is a file: 66 | * ref_path and ref_norm_path are assumed to be files. 67 | * The resulting JSON from fstalign will be in out_path with the 68 | same name as the hyp_path. 
69 |
70 |     When hyp_path is a directory:
71 |     * ref_path and ref_norm_path are assumed to be directories.
72 |       For each hypothesis CTM in hyp_path, the equivalent reference
73 |       transcripts and normalizations are assumed to have the same
74 |       name as the hypothesis file.
75 |     * The resulting JSONs from fstalign will be in out_path with the
76 |       same name as the CTMs in hyp_path.
77 |     """
78 |     out_path.mkdir(parents=True, exist_ok=True)
79 |     if hyp_path.is_dir():
80 |         for hyp_file in hyp_path.glob("**/*.ctm"):
81 |             hyp_name = hyp_file.stem
82 |             ref_file = (ref_path / (hyp_name + ".nlp")).resolve()
83 |             out_file = (out_path / (hyp_name + ".log.json")).resolve()
84 |             ref_norm_file = None
85 |             if ref_norm_path:
86 |                 ref_norm_file = (ref_norm_path / (hyp_name + '.norm.json')).resolve()
87 |             yield ref_file, hyp_file.resolve(), out_file, ref_norm_file
88 |     else:
89 |         out_file = (out_path / (hyp_path.stem + ".log.json")).resolve()
90 |         if ref_norm_path:
91 |             ref_norm_path = ref_norm_path.resolve()
92 |         yield ref_path.resolve(), hyp_path.resolve(), out_file, ref_norm_path
93 |
94 |
95 | if __name__ == '__main__':
96 |     args = init_args()
97 |
98 |     for ref_file, hyp_file, out_file, ref_norm_file in prepare_IO(args.ref, args.hyp, args.out, args.ref_norm):
99 |         alignment_command = [
100 |             str(args.fstalign),
101 |             "wer",
102 |             "--ref",
103 |             str(ref_file),
104 |             "--hyp",
105 |             str(hyp_file),
106 |             "--json-log",
107 |             str(out_file),
108 |         ]
109 |         if ref_norm_file:
110 |             alignment_command.extend([
111 |                 "--ref-json",
112 |                 str(ref_norm_file),
113 |             ])
114 |         if args.synonyms_file:
115 |             alignment_command.extend([
116 |                 "--syn",
117 |                 str(args.synonyms_file),
118 |             ])
119 |
120 |         print(' '.join(alignment_command))
121 | -------------------------------------------------------------------------------- /diarization/README.md: --------------------------------------------------------------------------------
1 | ## Rev's diarization models
2 | This repository contains two new speaker diarization models built upon the
3 | [PyAnnote](https://github.com/pyannote/pyannote-audio) framework. These models are trained and intended
4 | for use with ASR systems (speaker-attributed ASR).
5 |
6 | The smaller model - `Reverb Diarization V1` - provides a **16.5%** relative improvement in WDER (Word Diarization Error Rate)
7 | compared to the baseline pyannote3.0 model,
8 | evaluated on over 1,250,000 tokens across five different test suites.
9 | The larger model - `Reverb Diarization V2` - offers a **22.25%** relative improvement over the pyannote3.0 model.
10 |
11 | Both models can be found on Hugging Face at https://huggingface.co/Revai and are integrated into Hugging Face via pyannote.
12 |
13 | ## Table of Contents
14 | - [Usage](#usage)
15 | - [Assigning words to speakers](#assigning-words-to-speakers)
16 | - [Running training script](#running-training-script)
17 | - [Results](#results)
18 | - [Acknowledgments](#acknowledgments)
19 |
20 | ## Usage
21 | We recommend running on a GPU. The Dockerfile is CUDA-ready; CUDA 12.4+ is required.
22 |
23 | ```python
24 | # taken from https://huggingface.co/pyannote/speaker-diarization-3.1 - see for more details
25 | # instantiate the pipeline
26 | from pyannote.audio import Pipeline
27 | pipeline = Pipeline.from_pretrained(
28 |   "Revai/reverb-diarization-v1",
29 |   use_auth_token="HUGGINGFACE_ACCESS_TOKEN_GOES_HERE")
30 |
31 | # run the pipeline on an audio file
32 | diarization = pipeline("audio.wav")
33 |
34 | # dump the diarization output to disk using RTTM format
35 | with open("audio.rttm", "w") as rttm:
36 |     diarization.write_rttm(rttm)
37 | ```
38 |
39 | Alternatively, you can use the provided script `infer_pyannote3.0.py`.
40 | You can run diarization on a single audio file (or a list of audio files). The same approach can be used from within Docker.
41 | The output is standard RTTM, stored in the output directory as `basename.rttm`.
42 | ```bash
43 | python infer_pyannote3.0.py /path/to/audios --out-dir /path/to/outdir
44 | ```
45 | You can specify the model you want to run via the `--pipeline-model` argument:
46 | either `Revai/reverb-diarization-pipeline-v1` or `Revai/reverb-diarization-pipeline-v2`.
47 |
48 |
49 | ### Assigning words to speakers
50 | It is possible to assign words to speakers if ASR was previously run.
51 | The script `assign_words2speakers.py` takes a diarization segmentation and an ASR transcription in
52 | CTM format, and outputs the speaker assignment for each token (word).
53 | ```bash
54 | python assign_words2speakers.py speaker_segments.rttm words.ctm transcript.stm
55 | ```
56 | Read more about the STM format [here](https://www.nist.gov/system/files/documents/2021/08/31/OpenASR21_EvalPlan_v1_3_1.pdf) (speaker identities cannot be stored in CTM).
57 |
58 | ### Running training script
59 | We also provide the training script that was used to fine-tune the original pyannote3.0 model.
60 | The training script is run as follows:
61 | ```bash
62 | python train_pyannote3.0.py --database data/database.yaml
63 | ```
64 | The `--database` parameter points to the YAML database file; we provide an example file that is
65 | easy to use. You need to specify the .uri, .uem, and .rttm files; for a more detailed
66 | description, please refer to the pyannote documentation.
67 |
68 |
69 | ## Results
70 | While DER is a valuable metric for assessing the technical performance of a diarization model
71 | in isolation, WDER is more crucial in the context of ASR because it reflects the combined
72 | effectiveness of both the diarization and ASR components in producing accurate,
73 | speaker-attributed text. In practical applications where the accuracy of both “who spoke”
74 | and “what was spoken” is essential, WDER provides a more meaningful and relevant measure
75 | for evaluating system performance and guiding improvements.
76 | For this reason, we only report WDER metrics. We also plan to add WDER to the `pyannote.metrics`
77 | codebase.
78 |
79 | ### Reverb Diarization V1
80 | | Test suite | WDER |
81 | |------------------------------------------------------------------------------------|-------|
82 | | [earnings21](https://github.com/revdotcom/speech-datasets/tree/rttm_v1/earnings21) | 0.047 |
83 | | rev16 | 0.077 |
84 |
85 | ### Reverb Diarization V2
86 | | Test suite | WDER |
87 | |------------------------------------------------------------------------------------|-------|
88 | | [earnings21](https://github.com/revdotcom/speech-datasets/tree/rttm_v1/earnings21) | 0.046 |
89 | | rev16 | 0.078 |
90 |
91 | ## Acknowledgments
92 | Special thanks to Hervé Bredin for developing and open-sourcing pyannote.
93 | -------------------------------------------------------------------------------- /asr/wenet/finetune/lora/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) 4 | # 2024 Alan (alanfangemail@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """Multi-Head Attention layer definition with lora.""" 18 | 19 | from typing import Optional, List 20 | 21 | import torch 22 | from torch import nn 23 | 24 | from wenet.transformer.attention import (MultiHeadedAttention, 25 | RelPositionMultiHeadedAttention) 26 | import wenet.finetune.lora.layers as lora 27 | 28 | 29 | class LoRAMultiHeadedAttention(MultiHeadedAttention): 30 | """Multi-Head Attention layer with lora. 31 | 32 | Args: 33 | n_head (int): The number of heads. 34 | n_feat (int): The number of features. 35 | dropout_rate (float): Dropout rate. 36 | 37 | """ 38 | 39 | def __init__(self, 40 | n_head: int, 41 | n_feat: int, 42 | dropout_rate: float, 43 | key_bias: bool = True, 44 | lora_rank: int = 8, 45 | lora_alpha: int = 8, 46 | lora_dropout: float = 0.0, 47 | lora_list: Optional[List[str]] = None): 48 | """Construct an MultiHeadedAttention object.""" 49 | super().__init__(n_head, n_feat, dropout_rate, key_bias) 50 | assert n_feat % n_head == 0 51 | # We assume d_v always equals d_k 52 | self.d_k = n_feat // n_head 53 | self.h = n_head 54 | self.linear_out = lora.Linear( 55 | n_feat, 56 | n_feat, 57 | r=lora_rank, 58 | lora_alpha=lora_alpha, 59 | lora_dropout=lora_dropout 60 | ) if lora_list and "o" in lora_list else nn.Linear(n_feat, n_feat) 61 | 62 | lora_qkv_dict = { 63 | "q": lora_list and "q" in lora_list, 64 | "k": lora_list and "k" in lora_list, 65 | "v": lora_list and "v" in lora_list 66 | } 67 | 68 | for key, value in lora_qkv_dict.items(): 69 | setattr( 70 | self, f"linear_{key}", 71 | lora.Linear(n_feat, 72 | n_feat, 73 | r=lora_rank, 74 | lora_alpha=lora_alpha, 75 | lora_dropout=lora_dropout) if value else nn.Linear( 76 | n_feat, n_feat)) 77 | self.dropout = nn.Dropout(p=dropout_rate) 78 | 79 | 80 | class LoRARelPositionMultiHeadedAttention(LoRAMultiHeadedAttention, 81 | RelPositionMultiHeadedAttention): 82 | """Multi-Head Attention layer with relative position encoding. 83 | Paper: https://arxiv.org/abs/1901.02860 84 | Args: 85 | n_head (int): The number of heads. 86 | n_feat (int): The number of features. 87 | dropout_rate (float): Dropout rate. 
88 | """ 89 | 90 | def __init__(self, 91 | n_head: int, 92 | n_feat: int, 93 | dropout_rate: float, 94 | key_bias: bool = True, 95 | lora_rank: int = 8, 96 | lora_alpha: int = 8, 97 | lora_dropout: float = 0.0, 98 | lora_list: Optional[List[str]] = None): 99 | """Construct an RelPositionMultiHeadedAttention object.""" 100 | super().__init__(n_head, n_feat, dropout_rate, key_bias, 101 | lora_rank, lora_alpha, 102 | lora_dropout, lora_list) 103 | # linear transformation for positional encoding 104 | self.linear_pos = nn.Linear(n_feat, n_feat) 105 | # these two learnable bias are used in matrix c and matrix d 106 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3 107 | self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) 108 | self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) 109 | torch.nn.init.xavier_uniform_(self.pos_bias_u) 110 | torch.nn.init.xavier_uniform_(self.pos_bias_v) 111 | -------------------------------------------------------------------------------- /asr/wenet/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Positionwise feed forward layer definition.""" 16 | 17 | import torch 18 | 19 | 20 | class PositionwiseFeedForward(torch.nn.Module): 21 | """Positionwise feed forward layer. 22 | 23 | FeedForward are appied on each position of the sequence. 24 | The output dim is same with the input dim. 25 | 26 | Args: 27 | idim (int): Input dimenstion. 28 | hidden_units (int): The number of hidden units. 29 | dropout_rate (float): Dropout rate. 30 | activation (torch.nn.Module): Activation function 31 | """ 32 | 33 | def __init__( 34 | self, 35 | idim: int, 36 | hidden_units: int, 37 | dropout_rate: float, 38 | activation: torch.nn.Module = torch.nn.ReLU(), 39 | ): 40 | """Construct a PositionwiseFeedForward object.""" 41 | super(PositionwiseFeedForward, self).__init__() 42 | self.w_1 = torch.nn.Linear(idim, hidden_units) 43 | self.activation = activation 44 | self.dropout = torch.nn.Dropout(dropout_rate) 45 | self.w_2 = torch.nn.Linear(hidden_units, idim) 46 | 47 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 48 | """Forward function. 49 | 50 | Args: 51 | xs: input tensor (B, L, D) 52 | Returns: 53 | output tensor, (B, L, D) 54 | """ 55 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 56 | 57 | 58 | class MoEFFNLayer(torch.nn.Module): 59 | """ 60 | Mixture of expert with Positionwise feed forward layer 61 | See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf 62 | The output dim is same with the input dim. 63 | 64 | Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 65 | https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 66 | Args: 67 | n_expert: number of expert. 
68 | n_expert_per_token: The actual number of experts used for each frame 69 | idim (int): Input dimenstion. 70 | hidden_units (int): The number of hidden units. 71 | dropout_rate (float): Dropout rate. 72 | activation (torch.nn.Module): Activation function 73 | """ 74 | 75 | def __init__( 76 | self, 77 | n_expert: int, 78 | n_expert_per_token: int, 79 | idim: int, 80 | hidden_units: int, 81 | dropout_rate: float, 82 | activation: torch.nn.Module = torch.nn.ReLU(), 83 | ): 84 | super(MoEFFNLayer, self).__init__() 85 | self.gate = torch.nn.Linear(idim, n_expert, bias=False) 86 | self.experts = torch.nn.ModuleList( 87 | PositionwiseFeedForward(idim, hidden_units, dropout_rate, 88 | activation) for _ in range(n_expert)) 89 | self.n_expert_per_token = n_expert_per_token 90 | 91 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 92 | """Foward function. 93 | Args: 94 | xs: input tensor (B, L, D) 95 | Returns: 96 | output tensor, (B, L, D) 97 | 98 | """ 99 | B, L, D = xs.size( 100 | ) # batch size, sequence length, embedding dimension (idim) 101 | xs = xs.view(-1, D) # (B*L, D) 102 | router = self.gate(xs) # (B*L, n_expert) 103 | logits, indices = torch.topk( 104 | router, self.n_expert_per_token 105 | ) # probs:(B*L, n_expert), indices: (B*L, n_expert) 106 | weights = torch.nn.functional.softmax( 107 | logits, dim=1, 108 | dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) 109 | output = torch.zeros_like(xs) # (B*L, D) 110 | for i, expert in enumerate(self.experts): 111 | mask = indices == i 112 | batch_idx, ith_expert = torch.where(mask) 113 | output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( 114 | xs[batch_idx]) 115 | return output.view(B, L, D) 116 | -------------------------------------------------------------------------------- /asr/wenet/ssl/wav2vec2/quantizer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | 4 | 5 | def gumbel(shape: torch.Size, dtype: torch.dtype, device: torch.device): 6 | """Sample Gumbel random values with given shape and float dtype. 7 | 8 | The values are distributed according to the probability density function: 9 | 10 | .. math:: 11 | f(x) = e^{-(x + e^{-x})} 12 | 13 | Args: 14 | shape (torch.Size): pdf shape 15 | dtype (torch.dtype): pdf value dtype 16 | 17 | Returns: 18 | A random array with the specified shape and dtype. 
19 | """ 20 | # see https://www.cnblogs.com/initial-h/p/9468974.html for more details 21 | return -torch.log(-torch.log( 22 | torch.empty(shape, device=device).uniform_( 23 | torch.finfo(dtype).tiny, 1.))) 24 | 25 | 26 | class Wav2vecGumbelVectorQuantizer(torch.nn.Module): 27 | 28 | def __init__(self, 29 | features_dim: int = 256, 30 | num_codebooks: int = 2, 31 | num_embeddings: int = 8192, 32 | embedding_dim: int = 16, 33 | hard: bool = False) -> None: 34 | 35 | super().__init__() 36 | 37 | self.num_groups = num_codebooks 38 | self.num_codevectors_per_group = num_embeddings 39 | # codebooks 40 | # means [C, G, D] see quantize_vector in bestrq_model.py 41 | assert embedding_dim % num_codebooks == 0.0 42 | self.embeddings = torch.nn.parameter.Parameter( 43 | torch.empty(1, num_codebooks * num_embeddings, 44 | embedding_dim // num_codebooks), 45 | requires_grad=True, 46 | ) 47 | torch.nn.init.uniform_(self.embeddings) 48 | 49 | self.weight_proj = torch.nn.Linear(features_dim, 50 | num_codebooks * num_embeddings) 51 | # use gumbel softmax or argmax(non-differentiable) 52 | self.hard = hard 53 | 54 | @staticmethod 55 | def _compute_perplexity(probs, mask=None): 56 | if mask is not None: 57 | 58 | mask_extended = torch.broadcast_to(mask.flatten()[:, None, None], 59 | probs.shape) 60 | probs = torch.where(mask_extended.to(torch.bool), probs, 61 | torch.zeros_like(probs)) 62 | marginal_probs = probs.sum(dim=0) / mask.sum() 63 | else: 64 | marginal_probs = probs.mean(dim=0) 65 | 66 | perplexity = torch.exp(-torch.sum( 67 | marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum() 68 | return perplexity 69 | 70 | def forward( 71 | self, 72 | input: torch.Tensor, 73 | input_mask: torch.Tensor, 74 | temperature: float = 1. 75 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 76 | 77 | b, t, _ = input.size() 78 | 79 | hidden = self.weight_proj(input) 80 | hidden = hidden.reshape(b * t * self.num_groups, -1) 81 | if not self.hard: 82 | # sample code vector probs via gumbel in differentiateable way 83 | gumbels = gumbel(hidden.size(), hidden.dtype, hidden.device) 84 | codevector_probs = torch.nn.functional.softmax( 85 | (hidden + gumbels) / temperature, dim=-1) 86 | 87 | # compute perplexity 88 | codevector_soft_dist = torch.nn.functional.softmax( 89 | hidden.reshape(b * t, self.num_groups, -1), 90 | dim=-1, 91 | ) # [B*T, num_codebooks, num_embeddings] 92 | perplexity = self._compute_perplexity(codevector_soft_dist, 93 | input_mask) 94 | else: 95 | # take argmax in non-differentiable way 96 | # comptute hard codevector distribution (one hot) 97 | codevector_idx = hidden.argmax(axis=-1) 98 | codevector_probs = torch.nn.functional.one_hot( 99 | codevector_idx, hidden.shape[-1]) * 1.0 100 | codevector_probs = codevector_probs.reshape( 101 | b * t, self.num_groups, -1) 102 | perplexity = self._compute_perplexity(codevector_probs, input_mask) 103 | 104 | targets_idx = codevector_probs.argmax(-1).reshape(b, t, -1) 105 | codevector_probs = codevector_probs.reshape(b * t, -1) 106 | # use probs to retrieve codevectors 107 | codevectors_per_group = codevector_probs.unsqueeze( 108 | -1) * self.embeddings 109 | codevectors = codevectors_per_group.reshape( 110 | b * t, self.num_groups, self.num_codevectors_per_group, -1) 111 | 112 | codevectors = codevectors.sum(-2).reshape(b, t, -1) 113 | return codevectors, perplexity, targets_idx 114 | -------------------------------------------------------------------------------- /asr/wenet/transformer/ctc.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Modified from ESPnet(https://github.com/espnet/espnet) 15 | 16 | from typing import Tuple 17 | 18 | import torch 19 | import torch.nn.functional as F 20 | 21 | 22 | class CTC(torch.nn.Module): 23 | """CTC module""" 24 | 25 | def __init__( 26 | self, 27 | odim: int, 28 | encoder_output_size: int, 29 | dropout_rate: float = 0.0, 30 | reduce: bool = True, 31 | blank_id: int = 0, 32 | do_focal_loss: bool = False, 33 | focal_alpha: float = 0.5, 34 | focal_gamma: float = 2, 35 | ): 36 | """ Construct CTC module 37 | Args: 38 | odim: dimension of outputs 39 | encoder_output_size: number of encoder projection units 40 | dropout_rate: dropout rate (0.0 ~ 1.0) 41 | reduce: reduce the CTC loss into a scalar 42 | blank_id: blank label. 43 | """ 44 | super().__init__() 45 | eprojs = encoder_output_size 46 | self.dropout_rate = dropout_rate 47 | self.ctc_lo = torch.nn.Linear(eprojs, odim) 48 | 49 | self.do_focal_loss = do_focal_loss 50 | self.focal_alpha = focal_alpha 51 | self.focal_gamma = focal_gamma 52 | 53 | self.ctc_loss2 = torch.nn.CTCLoss(reduction='sum') 54 | self.ctc_loss3 = torch.nn.CTCLoss(reduction='mean') 55 | 56 | if do_focal_loss: 57 | reduction_type = "none" 58 | self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type) 59 | else: 60 | reduction_type = "sum" if reduce else "none" 61 | self.ctc_loss = torch.nn.CTCLoss(blank=blank_id, 62 | reduction=reduction_type, 63 | zero_infinity=True) 64 | 65 | def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor, 66 | ys_pad: torch.Tensor, 67 | ys_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 68 | """Calculate CTC loss. 
69 | 70 | Args: 71 | hs_pad: batch of padded hidden state sequences (B, Tmax, D) 72 | hlens: batch of lengths of hidden state sequences (B) 73 | ys_pad: batch of padded character id sequence tensor (B, Lmax) 74 | ys_lens: batch of lengths of character sequence (B) 75 | """ 76 | # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab) 77 | ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate)) 78 | # ys_hat: (B, L, D) -> (L, B, D) 79 | ys_hat = ys_hat.transpose(0, 1) 80 | ys_hat = ys_hat.log_softmax(2) 81 | 82 | loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens) 83 | if self.do_focal_loss: 84 | dbg=False 85 | # dbg = (torch.rand(1).item() < 0.05) 86 | p = torch.exp(-loss) 87 | if dbg: 88 | print(f"loss {loss}") 89 | print(f"log(p) {torch.log(p)}") 90 | 91 | loss = ((self.focal_alpha)*((1-p)**self.focal_gamma)*(loss)) 92 | if dbg: 93 | print(f"f-loss {loss}") 94 | loss = torch.mean(loss) 95 | if dbg: 96 | loss2 = self.ctc_loss2(ys_hat, ys_pad, hlens, ys_lens) 97 | loss3 = self.ctc_loss3(ys_hat, ys_pad, hlens, ys_lens) 98 | loss2 = loss2 / ys_hat.size(1) 99 | print(f"final loss = {loss}, loss2 {loss2}, loss3 {loss3}") 100 | else: 101 | # Batch-size average 102 | loss = loss / ys_hat.size(1) 103 | ys_hat = ys_hat.transpose(0, 1) 104 | return loss, ys_hat 105 | 106 | def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor: 107 | """log_softmax of frame activations 108 | 109 | Args: 110 | Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 111 | Returns: 112 | torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim) 113 | """ 114 | return F.log_softmax(self.ctc_lo(hs_pad), dim=2) 115 | 116 | def argmax(self, hs_pad: torch.Tensor) -> torch.Tensor: 117 | """argmax of frame activations 118 | 119 | Args: 120 | torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 121 | Returns: 122 | torch.Tensor: argmax applied 2d tensor (B, Tmax) 123 | """ 124 | return torch.argmax(self.ctc_lo(hs_pad), dim=2) 125 | -------------------------------------------------------------------------------- /asr/wenet/transducer_espnet/transducer_decoder_interface.py: -------------------------------------------------------------------------------- 1 | """Transducer decoder interface module.""" 2 | 3 | from dataclasses import dataclass 4 | from typing import Any 5 | from typing import Dict 6 | from typing import List 7 | from typing import Optional 8 | from typing import Tuple 9 | from typing import Union 10 | 11 | import torch 12 | 13 | 14 | @dataclass 15 | class Hypothesis: 16 | """Default hypothesis definition for Transducer search algorithms.""" 17 | 18 | score: float 19 | yseq: List[int] 20 | dec_state: Union[ 21 | Tuple[torch.Tensor, Optional[torch.Tensor]], 22 | List[Optional[torch.Tensor]], 23 | torch.Tensor, 24 | ] 25 | lm_state: Union[Dict[str, Any], List[Any]] = None 26 | 27 | 28 | @dataclass 29 | class ExtendedHypothesis(Hypothesis): 30 | """Extended hypothesis definition for NSC beam search and mAES.""" 31 | 32 | dec_out: List[torch.Tensor] = None 33 | lm_scores: torch.Tensor = None 34 | 35 | 36 | class TransducerDecoderInterface: 37 | """Decoder interface for Transducer models.""" 38 | 39 | def init_state( 40 | self, 41 | batch_size: int, 42 | ) -> Union[ 43 | Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] 44 | ]: 45 | """Initialize decoder states. 46 | 47 | Args: 48 | batch_size: Batch size. 49 | 50 | Returns: 51 | state: Initial decoder hidden states. 52 | 53 | """ 54 | raise NotImplementedError("init_state(...) 
is not implemented") 55 | 56 | def score( 57 | self, 58 | hyp: Hypothesis, 59 | cache: Dict[str, Any], 60 | ) -> Tuple[ 61 | torch.Tensor, 62 | Union[ 63 | Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] 64 | ], 65 | torch.Tensor, 66 | ]: 67 | """One-step forward hypothesis. 68 | 69 | Args: 70 | hyp: Hypothesis. 71 | cache: Pairs of (dec_out, dec_state) for each token sequence. (key) 72 | 73 | Returns: 74 | dec_out: Decoder output sequence. 75 | new_state: Decoder hidden states. 76 | lm_tokens: Label ID for LM. 77 | 78 | """ 79 | raise NotImplementedError("score(...) is not implemented") 80 | 81 | def batch_score( 82 | self, 83 | hyps: Union[List[Hypothesis], List[ExtendedHypothesis]], 84 | dec_states: Union[ 85 | Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] 86 | ], 87 | cache: Dict[str, Any], 88 | use_lm: bool, 89 | ) -> Tuple[ 90 | torch.Tensor, 91 | Union[ 92 | Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] 93 | ], 94 | torch.Tensor, 95 | ]: 96 | """One-step forward hypotheses. 97 | 98 | Args: 99 | hyps: Hypotheses. 100 | dec_states: Decoder hidden states. 101 | cache: Pairs of (dec_out, dec_states) for each label sequence. (key) 102 | use_lm: Whether to compute label ID sequences for LM. 103 | 104 | Returns: 105 | dec_out: Decoder output sequences. 106 | dec_states: Decoder hidden states. 107 | lm_labels: Label ID sequences for LM. 108 | 109 | """ 110 | raise NotImplementedError("batch_score(...) is not implemented") 111 | 112 | def select_state( 113 | self, 114 | batch_states: Union[ 115 | Tuple[torch.Tensor, Optional[torch.Tensor]], List[torch.Tensor] 116 | ], 117 | idx: int, 118 | ) -> Union[ 119 | Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] 120 | ]: 121 | """Get specified ID state from decoder hidden states. 122 | 123 | Args: 124 | batch_states: Decoder hidden states. 125 | idx: State ID to extract. 126 | 127 | Returns: 128 | state_idx: Decoder hidden state for given ID. 129 | 130 | """ 131 | raise NotImplementedError("select_state(...) is not implemented") 132 | 133 | def create_batch_states( 134 | self, 135 | states: Union[ 136 | Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] 137 | ], 138 | new_states: List[ 139 | Union[ 140 | Tuple[torch.Tensor, Optional[torch.Tensor]], 141 | List[Optional[torch.Tensor]], 142 | ] 143 | ], 144 | l_tokens: List[List[int]], 145 | ) -> Union[ 146 | Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] 147 | ]: 148 | """Create decoder hidden states. 149 | 150 | Args: 151 | batch_states: Batch of decoder states 152 | l_states: List of decoder states 153 | l_tokens: List of token sequences for input batch 154 | 155 | Returns: 156 | batch_states: Batch of decoder states 157 | 158 | """ 159 | raise NotImplementedError("create_batch_states(...) is not implemented") 160 | -------------------------------------------------------------------------------- /asr/wenet/squeezeformer/encoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """SqueezeformerEncoderLayer definition.""" 15 | 16 | import torch 17 | import torch.nn as nn 18 | from typing import Optional, Tuple 19 | 20 | 21 | class SqueezeformerEncoderLayer(nn.Module): 22 | """Encoder layer module. 23 | Args: 24 | size (int): Input dimension. 25 | self_attn (torch.nn.Module): Self-attention module instance. 26 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` 27 | instance can be used as the argument. 28 | feed_forward1 (torch.nn.Module): Feed-forward module instance. 29 | `PositionwiseFeedForward` instance can be used as the argument. 30 | conv_module (torch.nn.Module): Convolution module instance. 31 | `ConvlutionModule` instance can be used as the argument. 32 | feed_forward2 (torch.nn.Module): Feed-forward module instance. 33 | `PositionwiseFeedForward` instance can be used as the argument. 34 | dropout_rate (float): Dropout rate. 35 | normalize_before (bool): 36 | True: use layer_norm before each sub-block. 37 | False: use layer_norm after each sub-block. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | size: int, 43 | self_attn: torch.nn.Module, 44 | feed_forward1: Optional[nn.Module] = None, 45 | conv_module: Optional[nn.Module] = None, 46 | feed_forward2: Optional[nn.Module] = None, 47 | normalize_before: bool = False, 48 | dropout_rate: float = 0.1, 49 | concat_after: bool = False, 50 | ): 51 | super(SqueezeformerEncoderLayer, self).__init__() 52 | self.size = size 53 | self.self_attn = self_attn 54 | self.layer_norm1 = nn.LayerNorm(size) 55 | self.ffn1 = feed_forward1 56 | self.layer_norm2 = nn.LayerNorm(size) 57 | self.conv_module = conv_module 58 | self.layer_norm3 = nn.LayerNorm(size) 59 | self.ffn2 = feed_forward2 60 | self.layer_norm4 = nn.LayerNorm(size) 61 | self.normalize_before = normalize_before 62 | self.dropout = nn.Dropout(dropout_rate) 63 | self.concat_after = concat_after 64 | if concat_after: 65 | self.concat_linear = nn.Linear(size + size, size) 66 | else: 67 | self.concat_linear = nn.Identity() 68 | 69 | def forward( 70 | self, 71 | x: torch.Tensor, 72 | mask: torch.Tensor, 73 | pos_emb: torch.Tensor, 74 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 75 | att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), 76 | cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), 77 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 78 | # self attention module 79 | residual = x 80 | if self.normalize_before: 81 | x = self.layer_norm1(x) 82 | x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, 83 | att_cache) 84 | if self.concat_after: 85 | x_concat = torch.cat((x, x_att), dim=-1) 86 | x = residual + self.concat_linear(x_concat) 87 | else: 88 | x = residual + self.dropout(x_att) 89 | if not self.normalize_before: 90 | x = self.layer_norm1(x) 91 | 92 | # ffn module 93 | residual = x 94 | if self.normalize_before: 95 | x = self.layer_norm2(x) 96 | x = self.ffn1(x) 97 | x = residual + self.dropout(x) 98 | if not self.normalize_before: 99 | x = self.layer_norm2(x) 100 | 101 | # conv module 102 | new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, 
device=x.device) 103 | residual = x 104 | if self.normalize_before: 105 | x = self.layer_norm3(x) 106 | x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) 107 | x = residual + self.dropout(x) 108 | if not self.normalize_before: 109 | x = self.layer_norm3(x) 110 | 111 | # ffn module 112 | residual = x 113 | if self.normalize_before: 114 | x = self.layer_norm4(x) 115 | x = self.ffn2(x) 116 | # we do not use dropout here since it is inside feed forward function 117 | x = residual + self.dropout(x) 118 | if not self.normalize_before: 119 | x = self.layer_norm4(x) 120 | 121 | return x, mask, new_att_cache, new_cnn_cache 122 | -------------------------------------------------------------------------------- /asr/wenet/bin/ctc_align.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | wenet_space_symbol = '▁' 4 | 5 | 6 | def id_to_token(tok, tokenizer): 7 | return tokenizer.detokenize([tok])[1][0] 8 | 9 | 10 | def is_special_token(word): 11 | open_bracket = word.find('<') 12 | close_bracket = word.find('>') 13 | return (open_bracket != -1) & (close_bracket != -1) & (open_bracket < close_bracket) 14 | 15 | 16 | def is_empty_word(word): 17 | return (word == "" or word == wenet_space_symbol) 18 | 19 | 20 | def is_start_of_word_token(word): 21 | return (word.find(wenet_space_symbol) != -1) 22 | 23 | 24 | def ctc_align(hypothesis, time_stamp, confidence_scores, tokenizer, 25 | frame_shift_ms, time_shift_ms): 26 | """ Convert tokens to words and assign timestamps based on frame indices from CTC output 27 | """ 28 | assert len(hypothesis) == len(time_stamp) 29 | word = "" 30 | unit_ids = [] 31 | start_ts_ms = -1 32 | unit_start = -1 33 | path = [] 34 | g_time_stamp_gap_ms = 100 35 | for i in range(len(hypothesis)): 36 | 37 | token = id_to_token(hypothesis[i], tokenizer) 38 | next_token = id_to_token(hypothesis[i + 1], tokenizer)if i + 1 < len(hypothesis) else wenet_space_symbol 39 | pos = token.find(wenet_space_symbol) 40 | 41 | # Trim starting _ if necessary 42 | if pos != -1: 43 | word += token[len(wenet_space_symbol):] 44 | else: 45 | word += token 46 | 47 | unit_ids.append(hypothesis[i]) 48 | 49 | if start_ts_ms == -1: 50 | # To ensure start is always greater than 0 51 | start_ts_ms = max(time_stamp[i] * frame_shift_ms - g_time_stamp_gap_ms, 0) 52 | if i > 0: 53 | start_ts_ms = ((time_stamp[i - 1] + time_stamp[i]) // 2 * frame_shift_ms 54 | if (time_stamp[i] - time_stamp[i - 1]) * frame_shift_ms < g_time_stamp_gap_ms 55 | else start_ts_ms) 56 | unit_start = i 57 | 58 | # Cutting a word if the word is a special token 59 | if not is_empty_word(word) and is_special_token(word): 60 | end_ts_ms = time_stamp[i] * frame_shift_ms 61 | if i < len(hypothesis) - 1: 62 | end_ts_ms = ((time_stamp[i + 1] + time_stamp[i]) // 2 * frame_shift_ms 63 | if (time_stamp[i + 1] - time_stamp[i]) * frame_shift_ms < g_time_stamp_gap_ms 64 | else end_ts_ms) 65 | 66 | if confidence_scores: 67 | confidence = max(c for c in confidence_scores[unit_start:i+1]) 68 | else: 69 | confidence = 0 70 | assert start_ts_ms < end_ts_ms 71 | assert len(unit_ids) == 1 72 | path.append({ 73 | 'word': word, 74 | 'unit_id': unit_ids[0], 75 | 'start_time_ms': start_ts_ms + time_shift_ms, 76 | 'end_time_ms': end_ts_ms + time_shift_ms, 77 | 'confidence': confidence, 78 | 'unit_ids': unit_ids 79 | }) 80 | 81 | start_ts_ms = -1 82 | unit_start = 0 83 | unit_ids = [] 84 | word = "" 85 | 86 | # Cutting a word if next token starts from _ or next token is a special token 87 | if 
is_start_of_word_token(next_token) or is_special_token(next_token): 88 | end_ts_ms = time_stamp[i] * frame_shift_ms 89 | if i < len(hypothesis) - 1: 90 | end_ts_ms = ((time_stamp[i + 1] + time_stamp[i]) // 2 * frame_shift_ms 91 | if (time_stamp[i + 1] - time_stamp[i]) * frame_shift_ms < g_time_stamp_gap_ms 92 | else end_ts_ms) 93 | if not is_empty_word(word): 94 | assert len(unit_ids) > 0 95 | if confidence_scores: 96 | confidence = max(c for c in confidence_scores[unit_start:i+1]) 97 | else: 98 | confidence = 0 99 | assert start_ts_ms <= end_ts_ms 100 | assert not is_special_token(word) 101 | path.append({ 102 | 'word': word, 103 | 'unit_id': -1, 104 | 'start_time_ms': start_ts_ms + time_shift_ms, 105 | 'end_time_ms': end_ts_ms + time_shift_ms, 106 | 'confidence': confidence, 107 | 'unit_ids': unit_ids 108 | }) 109 | start_ts_ms = -1 110 | unit_start = 0 111 | unit_ids = [] 112 | word = "" 113 | return path 114 | 115 | 116 | def adjust_model_time_offset(hypothesis, adjustment): 117 | if adjustment == 0: 118 | return 119 | 120 | adjusted_hyp = [] 121 | for i in range(len(hypothesis)): 122 | word = hypothesis[i] 123 | assert word['start_time_ms'] >= 0 124 | assert word['start_time_ms'] <= word['end_time_ms'] 125 | word_adjustment = 0 126 | if i == 0: 127 | word_adjustment = min(adjustment, word['start_time_ms']) 128 | else: 129 | prev_word = hypothesis[i-1] 130 | assert word['start_time_ms'] >= prev_word['end_time_ms'], f"ERROR! {word} >= {prev_word}" 131 | if word['start_time_ms'] >= prev_word['end_time_ms']: 132 | word_adjustment = min(adjustment, word['start_time_ms'] - prev_word['end_time_ms']) 133 | assert word_adjustment >= 0 134 | word['start_time_ms'] -= word_adjustment 135 | word['end_time_ms'] -= word_adjustment 136 | adjusted_hyp.append(word) 137 | 138 | return adjusted_hyp 139 | -------------------------------------------------------------------------------- /asr/wenet/transformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Modified from ESPnet(https://github.com/espnet/espnet) 15 | """ConvolutionModule definition.""" 16 | 17 | from typing import Tuple 18 | 19 | import torch 20 | from torch import nn 21 | 22 | 23 | class ConvolutionModule(nn.Module): 24 | """ConvolutionModule in Conformer model.""" 25 | 26 | def __init__(self, 27 | channels: int, 28 | kernel_size: int = 15, 29 | activation: nn.Module = nn.ReLU(), 30 | norm: str = "batch_norm", 31 | causal: bool = False, 32 | bias: bool = True): 33 | """Construct an ConvolutionModule object. 34 | Args: 35 | channels (int): The number of channels of conv layers. 36 | kernel_size (int): Kernel size of conv layers. 
37 | causal (int): Whether use causal convolution or not 38 | """ 39 | super().__init__() 40 | 41 | self.pointwise_conv1 = nn.Conv1d( 42 | channels, 43 | 2 * channels, 44 | kernel_size=1, 45 | stride=1, 46 | padding=0, 47 | bias=bias, 48 | ) 49 | # self.lorder is used to distinguish if it's a causal convolution, 50 | # if self.lorder > 0: it's a causal convolution, the input will be 51 | # padded with self.lorder frames on the left in forward. 52 | # else: it's a symmetrical convolution 53 | if causal: 54 | padding = 0 55 | self.lorder = kernel_size - 1 56 | else: 57 | # kernel_size should be an odd number for none causal convolution 58 | assert (kernel_size - 1) % 2 == 0 59 | padding = (kernel_size - 1) // 2 60 | self.lorder = 0 61 | self.depthwise_conv = nn.Conv1d( 62 | channels, 63 | channels, 64 | kernel_size, 65 | stride=1, 66 | padding=padding, 67 | groups=channels, 68 | bias=bias, 69 | ) 70 | 71 | assert norm in ['batch_norm', 'layer_norm'] 72 | if norm == "batch_norm": 73 | self.use_layer_norm = False 74 | self.norm = nn.BatchNorm1d(channels) 75 | else: 76 | self.use_layer_norm = True 77 | self.norm = nn.LayerNorm(channels) 78 | 79 | self.pointwise_conv2 = nn.Conv1d( 80 | channels, 81 | channels, 82 | kernel_size=1, 83 | stride=1, 84 | padding=0, 85 | bias=bias, 86 | ) 87 | self.activation = activation 88 | 89 | def forward( 90 | self, 91 | x: torch.Tensor, 92 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 93 | cache: torch.Tensor = torch.zeros((0, 0, 0)), 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | """Compute convolution module. 96 | Args: 97 | x (torch.Tensor): Input tensor (#batch, time, channels). 98 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), 99 | (0, 0, 0) means fake mask. 100 | cache (torch.Tensor): left context cache, it is only 101 | used in causal convolution (#batch, channels, cache_t), 102 | (0, 0, 0) meas fake cache. 103 | Returns: 104 | torch.Tensor: Output tensor (#batch, time, channels). 105 | """ 106 | # exchange the temporal dimension and the feature dimension 107 | x = x.transpose(1, 2) # (#batch, channels, time) 108 | 109 | # mask batch padding 110 | if mask_pad.size(2) > 0: # time > 0 111 | x.masked_fill_(~mask_pad, 0.0) 112 | 113 | if self.lorder > 0: 114 | if cache.size(2) == 0: # cache_t == 0 115 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) 116 | else: 117 | assert cache.size(0) == x.size(0) # equal batch 118 | assert cache.size(1) == x.size(1) # equal channel 119 | x = torch.cat((cache, x), dim=2) 120 | assert (x.size(2) > self.lorder) 121 | new_cache = x[:, :, -self.lorder:] 122 | else: 123 | # It's better we just return None if no cache is required, 124 | # However, for JIT export, here we just fake one tensor instead of 125 | # None. 
126 |             new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
127 |
128 |         # GLU mechanism
129 |         x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
130 |         x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)
131 |
132 |         # 1D Depthwise Conv
133 |         x = self.depthwise_conv(x)
134 |         if self.use_layer_norm:
135 |             x = x.transpose(1, 2)
136 |         x = self.activation(self.norm(x))
137 |         if self.use_layer_norm:
138 |             x = x.transpose(1, 2)
139 |         x = self.pointwise_conv2(x)
140 |         # mask batch padding
141 |         if mask_pad.size(2) > 0:  # time > 0
142 |             x.masked_fill_(~mask_pad, 0.0)
143 |
144 |         return x.transpose(1, 2), new_cache
145 | -------------------------------------------------------------------------------- /asr/README.md: --------------------------------------------------------------------------------
1 | # Table of Contents
2 | - [Docker Usage](#docker-usage)
3 | - [About](#about)
4 | - [Code](#code)
5 | - [Features](#features)
6 | - [Benchmarking](#benchmarking)
7 | - [Acknowledgments](#acknowledgments)
8 |
9 | ## Docker Usage
10 | The `reverb` package is also installed within the docker image. You can run transcription using the `reverb` binary:
11 | ```bash
12 | reverb --config $config \
13 |     --checkpoint $checkpoint \
14 |     --audio_file $audio \
15 |     --modes ctc_prefix_beam_search attention_rescoring \
16 |     --gpu 0 \
17 |     --verbatimicity 1.0 \
18 |     --result_dir output
19 | ```
20 | where `$config` points to the `config.yaml` file and `$checkpoint` points to the `reverb_asr_v1.pt` file. Alternatively, you can simply name the model using `--model` to run our checkpoint.
21 | ```bash
22 | reverb --model reverb_asr_v1 \
23 |     --audio_file $audio \
24 |     --modes ctc_prefix_beam_search attention_rescoring \
25 |     --gpu 0 \
26 |     --verbatimicity 1.0 \
27 |     --result_dir output
28 | ```
29 |
30 | If you are using the docker image, these paths will be:
31 | ```bash
32 | checkpoint="/root/.cache/reverb/reverb_asr_v1/reverb_asr_v1.pt"
33 | config="/root/.cache/reverb/reverb_asr_v1/config.yaml"
34 | ```
35 |
36 | In place of `$audio`, pass in the wav file you want to run ASR on.
37 |
38 | Or check out our demo [on HuggingFace](https://huggingface.co/spaces/Revai/reverb-asr-demo).
39 |
40 | # About
41 | Reverb ASR was trained on 200,000 hours of English speech, all expertly transcribed by humans - the largest corpus of human-transcribed audio ever used to train an open-source model. The quality of this data has produced the world’s most accurate English automatic speech recognition (ASR) system, using an efficient model architecture that can be run on either CPU or GPU. Additionally, Reverb ASR provides user control over the level of verbatimicity of the output transcript, making it ideal for both clean, readable transcription and use cases like audio editing that require transcription of every spoken word, including hesitations and re-wordings. Users can specify fully verbatim, fully non-verbatim, or anywhere in between for their transcription output.
42 |
43 | # Code
44 | The folder `wenet` is a fork of the [WeNet](https://github.com/wenet-e2e/wenet) repository, with some modifications made for the Rev-specific architecture.
45 |
46 | The folder `wer_evaluation` contains instructions and code for running different benchmark utilities. These scripts are not specific to the Reverb architecture.
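For example, the two scripts in `wer_evaluation` can be chained to reproduce the scoring pipeline used in the Benchmarking section below. This is only a minimal sketch: it assumes you run it from the `asr` directory, that `refs/`, `hyps/`, and `fstalign_out/` are placeholder directories of NLP references, CTM hypotheses, and output logs, and that fstalign has been built separately. `scoring_commands.py` prints one fstalign invocation per hypothesis CTM, and `aggregate_scoring.py` then reads the resulting JSON logs.

```bash
# Emit one fstalign command per hypothesis CTM (references must be in NLP format).
python wer_evaluation/scoring_commands.py /path/to/fstalign refs/ hyps/ fstalign_out/ > align_commands.sh

# Run the alignments, then aggregate the per-file JSON logs.
bash align_commands.sh
python wer_evaluation/aggregate_scoring.py fstalign_out/
```

The aggregation step prints the total micro-average WER along with insertion, deletion, and substitution rates over the whole test suite.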
47 | 
48 | # Features
49 | 
50 | ## Transcription Style Options
51 | Reverb ASR was trained to produce transcriptions in either a verbatim style, in which every word is transcribed as spoken, or a non-verbatim style, in which disfluencies may be removed from the transcript.
52 | 
53 | Users can specify Reverb ASR's output style with the `verbatimicity` parameter. A value of 1 corresponds to a verbatim transcript that retains all spoken content, and 0 corresponds to a non-verbatim transcript that removes unnecessary phrases to improve readability. Values between 0 and 1 are accepted and produce an intermediate style. The Rev team has found that halfway between verbatim and non-verbatim produces a reader-preferred style for captioning - capturing all content while reducing some hesitations and stutters so that captions fit better on screen. See our demo [here](https://huggingface.co/spaces/Revai/reverb-asr-demo) to test the `verbatimicity` parameter with your own audio.
54 | 
55 | ## Decoding Options
56 | 
57 | Reverb ASR uses the joint CTC/attention architecture described [here](https://arxiv.org/pdf/2102.01547) and [here](https://www.rev.com/blog/speech-to-text-technology/what-makes-revs-v2-best-in-class), and supports multiple decoding modes. Users can specify one or more decoding modes, and a separate output directory will be created for each mode.
58 | 
59 | Decoding options are:
60 | - `attention`
61 | - `ctc_greedy_search`
62 | - `ctc_prefix_beam_search`
63 | - `attention_rescoring`
64 | - `joint_decoding`
65 | 
66 | # Benchmarking
67 | 
68 | Unlike many ASR providers, Rev primarily uses long-form speech recognition corpora for benchmarking. We use each model to produce a transcript of an entire audio file, then use [fstalign](https://github.com/revdotcom/fstalign) to align and score the complete transcript. We report micro-averaged WER across all of the reference words in a given test suite (see the short aggregation sketch at the end of this README). We have included our scoring scripts in this repository so that anyone can replicate our work, benchmark other models, or experiment with new long-form test suites.
69 | 
70 | Here, we’ve benchmarked the Reverb ASR model against the best-performing open-source models currently available: OpenAI’s Whisper large-v3 and NVIDIA’s Canary-1B, both accessed through HuggingFace. Note that both of these models have significantly more parameters than Reverb ASR. We use simple chunking with no overlap - 30s chunks for Whisper and Canary, and 20s chunks for Reverb. These results use CTC prefix beam search with attention rescoring. For Whisper and Canary, we use NeMo to normalize the model outputs before scoring.
71 | 
72 | For long-form ASR, we’ve used three corpora: Rev16 (podcasts), [Earnings21](https://github.com/revdotcom/speech-datasets/tree/main/earnings21) (earnings calls from US-based companies), and [Earnings22](https://github.com/revdotcom/speech-datasets/tree/main/earnings22) (earnings calls from global companies).
73 | 
74 | | Model            | Earnings21 | Earnings22 | Rev16 |
75 | |------------------|------------|------------|-------|
76 | | Reverb ASR       | 9.68       | 13.68      | 10.30 |
77 | | Whisper large-v3 | 14.26      | 19.05      | 10.86 |
78 | | Canary-1B        | 14.40      | 19.01      | 13.82 |
79 | 
80 | See the `wer_evaluation` folder for benchmarking scripts and usage instructions.
81 | 
82 | # Acknowledgements
83 | Special thanks to the WeNet team for their work and for making it available under an open-source license.
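A note on the WER numbers above: they are micro-averaged, meaning word errors are summed over every file in a test suite and then divided by the total number of reference words, rather than averaging per-file WERs. Below is a minimal sketch of that aggregation, assuming the per-file error and reference-word counts have already been parsed from fstalign's reports (the tuple layout used here is purely illustrative).

```python
from typing import Iterable, Tuple


def micro_average_wer(per_file_counts: Iterable[Tuple[int, int]]) -> float:
    """Micro-averaged WER over a test suite, as a percentage.

    Each tuple is (word_errors, reference_words) for one audio file, e.g.
    parsed from fstalign's per-file report (the layout is illustrative).
    """
    total_errors = 0
    total_ref_words = 0
    for errors, ref_words in per_file_counts:
        total_errors += errors
        total_ref_words += ref_words
    return 100.0 * total_errors / total_ref_words


# Longer files contribute proportionally more than short ones.
print(micro_average_wer([(12, 150), (40, 800), (3, 50)]))  # 5.5
```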
84 | -------------------------------------------------------------------------------- /asr/wenet/efficient_conformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 2022 58.com(Wuba) Inc AI Lab. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | """ConvolutionModule definition.""" 17 | from typing import Tuple 18 | 19 | import torch 20 | from torch import nn 21 | 22 | 23 | class ConvolutionModule(nn.Module): 24 | """ConvolutionModule in Conformer model.""" 25 | 26 | def __init__(self, 27 | channels: int, 28 | kernel_size: int = 15, 29 | activation: nn.Module = nn.ReLU(), 30 | norm: str = "batch_norm", 31 | causal: bool = False, 32 | bias: bool = True, 33 | stride: int = 1): 34 | """Construct an ConvolutionModule object. 35 | Args: 36 | channels (int): The number of channels of conv layers. 37 | kernel_size (int): Kernel size of conv layers. 38 | causal (int): Whether use causal convolution or not 39 | stride (int): Stride Convolution, for efficient Conformer 40 | """ 41 | super().__init__() 42 | 43 | self.pointwise_conv1 = nn.Conv1d( 44 | channels, 45 | 2 * channels, 46 | kernel_size=1, 47 | stride=1, 48 | padding=0, 49 | bias=bias, 50 | ) 51 | # self.lorder is used to distinguish if it's a causal convolution, 52 | # if self.lorder > 0: it's a causal convolution, the input will be 53 | # padded with self.lorder frames on the left in forward. 54 | # else: it's a symmetrical convolution 55 | if causal: 56 | padding = 0 57 | self.lorder = kernel_size - 1 58 | else: 59 | # kernel_size should be an odd number for none causal convolution 60 | assert (kernel_size - 1) % 2 == 0 61 | padding = (kernel_size - 1) // 2 62 | self.lorder = 0 63 | 64 | self.depthwise_conv = nn.Conv1d( 65 | channels, 66 | channels, 67 | kernel_size, 68 | stride=stride, # for depthwise_conv in StrideConv 69 | padding=padding, 70 | groups=channels, 71 | bias=bias, 72 | ) 73 | 74 | assert norm in ['batch_norm', 'layer_norm'] 75 | if norm == "batch_norm": 76 | self.use_layer_norm = False 77 | self.norm = nn.BatchNorm1d(channels) 78 | else: 79 | self.use_layer_norm = True 80 | self.norm = nn.LayerNorm(channels) 81 | 82 | self.pointwise_conv2 = nn.Conv1d( 83 | channels, 84 | channels, 85 | kernel_size=1, 86 | stride=1, 87 | padding=0, 88 | bias=bias, 89 | ) 90 | self.activation = activation 91 | self.stride = stride 92 | 93 | def forward( 94 | self, 95 | x: torch.Tensor, 96 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 97 | cache: torch.Tensor = torch.zeros((0, 0, 0)), 98 | ) -> Tuple[torch.Tensor, torch.Tensor]: 99 | """Compute convolution module. 100 | Args: 101 | x (torch.Tensor): Input tensor (#batch, time, channels). 102 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), 103 | (0, 0, 0) means fake mask. 
104 | cache (torch.Tensor): left context cache, it is only 105 | used in causal convolution (#batch, channels, cache_t), 106 | (0, 0, 0) meas fake cache. 107 | Returns: 108 | torch.Tensor: Output tensor (#batch, time, channels). 109 | """ 110 | # exchange the temporal dimension and the feature dimension 111 | x = x.transpose(1, 2) # (#batch, channels, time) 112 | 113 | # mask batch padding 114 | if mask_pad.size(2) > 0: # time > 0 115 | x.masked_fill_(~mask_pad, 0.0) 116 | 117 | if self.lorder > 0: 118 | if cache.size(2) == 0: # cache_t == 0 119 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) 120 | else: 121 | # When export ONNX,the first cache is not None but all-zero, 122 | # cause shape error in residual block, 123 | # eg. cache14 + x9 = 23, 23-7+1=17 != 9 124 | cache = cache[:, :, -self.lorder:] 125 | assert cache.size(0) == x.size(0) # equal batch 126 | assert cache.size(1) == x.size(1) # equal channel 127 | x = torch.cat((cache, x), dim=2) 128 | assert (x.size(2) > self.lorder) 129 | new_cache = x[:, :, -self.lorder:] 130 | else: 131 | # It's better we just return None if no cache is requried, 132 | # However, for JIT export, here we just fake one tensor instead of 133 | # None. 134 | new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 135 | 136 | # GLU mechanism 137 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim) 138 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 139 | 140 | # 1D Depthwise Conv 141 | x = self.depthwise_conv(x) 142 | if self.use_layer_norm: 143 | x = x.transpose(1, 2) 144 | x = self.activation(self.norm(x)) 145 | if self.use_layer_norm: 146 | x = x.transpose(1, 2) 147 | x = self.pointwise_conv2(x) 148 | # mask batch padding 149 | if mask_pad.size(2) > 0: # time > 0 150 | if mask_pad.size(2) != x.size(2): 151 | mask_pad = mask_pad[:, :, ::self.stride] 152 | x.masked_fill_(~mask_pad, 0.0) 153 | 154 | return x.transpose(1, 2), new_cache 155 | -------------------------------------------------------------------------------- /asr/wenet/ssl/bestrq/mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def _sampler(pdf: torch.Tensor, num_samples: int, 6 | device=torch.device('cpu')) -> torch.Tensor: 7 | size = pdf.size() 8 | z = -torch.log(torch.rand(size, device=device)) 9 | _, indices = torch.topk(pdf + z, num_samples) 10 | return indices 11 | 12 | 13 | def compute_mask_indices( 14 | size: torch.Size, 15 | mask_prob: float, 16 | mask_length: int, 17 | min_masks: int = 0, 18 | device=torch.device('cpu'), 19 | ) -> torch.Tensor: 20 | 21 | assert len(size) == 2 22 | batch_size, seq_length = size 23 | 24 | # compute number of masked span in batch 25 | num_masked_spans = mask_prob * float(seq_length) / float( 26 | mask_length) + torch.rand(1)[0] 27 | num_masked_spans = int(num_masked_spans) 28 | num_masked_spans = max(num_masked_spans, min_masks) 29 | 30 | # num_masked <= seq_length 31 | if num_masked_spans * mask_length > seq_length: 32 | num_masked_spans = seq_length // mask_length 33 | 34 | pdf = torch.ones(batch_size, seq_length - (mask_length - 1), device=device) 35 | mask_idxs = _sampler(pdf, num_masked_spans, device=device) 36 | 37 | mask_idxs = mask_idxs.unsqueeze(-1).repeat(1, 1, mask_length).view( 38 | batch_size, 39 | num_masked_spans * mask_length) # [B,num_masked_spans*mask_length] 40 | 41 | offset = torch.arange(mask_length, device=device).view(1, 1, -1).repeat( 42 | 1, num_masked_spans, 1) # [1,num_masked_spans,mask_length] 43 | 
offset = offset.view(1, num_masked_spans * mask_length) 44 | 45 | mask_idxs = mask_idxs + offset # [B,num_masked_spans, mask_length] 46 | 47 | ones = torch.ones(batch_size, 48 | seq_length, 49 | dtype=torch.bool, 50 | device=mask_idxs.device) 51 | # masks to fill 52 | full_mask = torch.zeros_like(ones, 53 | dtype=torch.bool, 54 | device=mask_idxs.device) 55 | return torch.scatter(full_mask, dim=1, index=mask_idxs, src=ones) 56 | 57 | 58 | def compute_mask_indices_v2( 59 | shape, 60 | padding_mask, 61 | mask_prob: float, 62 | mask_length: int, 63 | mask_type: str = 'static', 64 | mask_other: float = 0.0, 65 | min_masks: int = 2, 66 | no_overlap: bool = False, 67 | min_space: int = 1, 68 | device=torch.device('cpu'), 69 | ): 70 | bsz, all_sz = shape 71 | mask = np.full((bsz, all_sz), False) 72 | padding_mask = padding_mask.cpu().numpy() 73 | all_num_mask = int( 74 | # add a random number for probabilistic rounding 75 | mask_prob * all_sz / float(mask_length) + np.random.rand()) 76 | 77 | all_num_mask = max(min_masks, all_num_mask) 78 | 79 | mask_idcs = [] 80 | for i in range(bsz): 81 | if padding_mask is not None and not isinstance(padding_mask, bytes): 82 | sz = all_sz - padding_mask[i].sum() 83 | num_mask = int( 84 | # add a random number for probabilistic rounding 85 | mask_prob * sz / float(mask_length) + np.random.rand()) 86 | num_mask = max(min_masks, num_mask) 87 | else: 88 | sz = all_sz 89 | num_mask = all_num_mask 90 | 91 | if mask_type == 'static': 92 | lengths = np.full(num_mask, mask_length) 93 | elif mask_type == 'uniform': 94 | lengths = np.random.randint(mask_other, 95 | mask_length * 2 + 1, 96 | size=num_mask) 97 | elif mask_type == 'normal': 98 | lengths = np.random.normal(mask_length, mask_other, size=num_mask) 99 | lengths = [max(1, int(round(x))) for x in lengths] 100 | elif mask_type == 'poisson': 101 | lengths = np.random.poisson(mask_length, size=num_mask) 102 | lengths = [int(round(x)) for x in lengths] 103 | else: 104 | raise Exception('unknown mask selection ' + mask_type) 105 | 106 | if sum(lengths) == 0: 107 | lengths[0] = min(mask_length, sz - 1) 108 | 109 | if no_overlap: 110 | mask_idc = [] 111 | 112 | def arrange(s, e, length, keep_length, mask_idc): 113 | span_start = np.random.randint(s, e - length) 114 | mask_idc.extend(span_start + i for i in range(length)) 115 | 116 | new_parts = [] 117 | if span_start - s - min_space >= keep_length: 118 | new_parts.append((s, span_start - min_space + 1)) 119 | if e - span_start - keep_length - min_space > keep_length: 120 | new_parts.append((span_start + length + min_space, e)) 121 | return new_parts 122 | 123 | parts = [(0, sz)] 124 | min_length = min(lengths) 125 | for length in sorted(lengths, reverse=True): 126 | lens = np.fromiter( 127 | (e - s if e - s >= length + min_space else 0 128 | for s, e in parts), 129 | np.int, 130 | ) 131 | l_sum = np.sum(lens) 132 | if l_sum == 0: 133 | break 134 | probs = lens / np.sum(lens) 135 | c = np.random.choice(len(parts), p=probs) 136 | s, e = parts.pop(c) 137 | parts.extend(arrange(s, e, length, min_length, mask_idc)) 138 | mask_idc = np.asarray(mask_idc) 139 | else: 140 | min_len = min(lengths) 141 | if sz - min_len <= num_mask: 142 | min_len = sz - num_mask - 1 143 | 144 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) 145 | 146 | mask_idc = np.asarray([ 147 | mask_idc[j] + offset for j in range(len(mask_idc)) 148 | for offset in range(lengths[j]) 149 | ]) 150 | 151 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) 152 | 153 | min_len = min([len(m) 
for m in mask_idcs]) 154 | for i, mask_idc in enumerate(mask_idcs): 155 | if len(mask_idc) > min_len: 156 | mask_idc = np.random.choice(mask_idc, min_len, replace=False) 157 | mask[i, mask_idc] = True 158 | 159 | mask = torch.from_numpy(mask).to(device) 160 | return mask 161 | -------------------------------------------------------------------------------- /asr/wenet/transformer/context_adaptor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Modified from ESPnet(https://github.com/espnet/espnet) 15 | 16 | import logging 17 | from typing import Tuple, List, Optional 18 | 19 | import torch 20 | from typeguard import check_argument_types 21 | 22 | from wenet.utils.mask import (subsequent_mask, make_pad_mask) 23 | 24 | 25 | class ContextAdaptor(torch.nn.Module): 26 | """ContextAdaptor: https://assets.amazon.science/43/13/104c968c45ea9ed02cffaa1448e0/personalization-of-ctc-speech-recognition-models.pdf 27 | Args: 28 | vocab_size: subword vocab size 29 | embedding_dim: size of subword embeddings 30 | encoder_output_size: dimension of attention 31 | num_layers: the number of bilstm layers 32 | dropout_rate: dropout rate 33 | attention_heads: the number of heads of multi head attention 34 | attention_dropout_rate: dropout rate for attention 35 | """ 36 | def __init__( 37 | self, 38 | vocab_size: int, 39 | output_size: int = 512, 40 | embedding_dim: int = 128, 41 | num_layers: int = 2, 42 | dropout_rate: float = 0.1, 43 | attention_heads: int = 1, 44 | attention_dropout_rate: float = 0.0, 45 | ): 46 | assert check_argument_types() 47 | super().__init__() 48 | 49 | self.vocab_size = vocab_size 50 | # embedding layer (subword unit --> embedding) 51 | self.embed = torch.nn.Embedding(vocab_size+1, embedding_dim) 52 | 53 | # bidirectional LSTM -- output size will be doubled 54 | lstm_output_size = int(output_size/2) 55 | self.encoder = torch.nn.LSTM(embedding_dim, 56 | lstm_output_size, 57 | num_layers, 58 | batch_first = True, 59 | dropout = dropout_rate, 60 | bidirectional = True) 61 | 62 | # attention mechanism - ASR encoder outputs vs. context terms 63 | self.attention = torch.nn.MultiheadAttention(output_size, attention_heads, 64 | attention_dropout_rate, 65 | batch_first=True) 66 | 67 | def forward( 68 | self, 69 | encoder_layer_outs: List[torch.Tensor], 70 | cv_encoder_out: torch.Tensor, 71 | ) -> torch.Tensor: 72 | """Forward just attention piece of contextual adaptor 73 | Args: 74 | encoder_layer_outs: list of outputs of ASR encoder layers. 
each one is (batch, maxlen, output_size) 75 | cv_encoder_out: output of cv encoder (1, n_cv_terms, output_size) 76 | Returns: 77 | x: decoded token score before softmax (batch, maxlen, output_size) 78 | """ 79 | combined_encoder_layer_outs = self.combine_layers(encoder_layer_outs) 80 | cv_encoder_out = cv_encoder_out.expand(combined_encoder_layer_outs.shape[0], -1, -1) 81 | 82 | x, y = self.attention(combined_encoder_layer_outs, cv_encoder_out, cv_encoder_out) 83 | # x = batch x frame x embedding 84 | # y = batch x frame x CV term 85 | assert y is not None 86 | #mask = y[:, :, 0] > 0.5 87 | mask = torch.argmax(y, dim=2) == 0 88 | x[mask.unsqueeze(2).expand(-1, -1, x.shape[2])] = 0. 89 | # JDF: uncomment for CV detection during decoding 90 | #for i in range(y.shape[1]): 91 | # if y[0,i,0] <= 0.5: 92 | # logging.info(str(i) + " " + str(y[0, i, :])) 93 | return x 94 | 95 | def encode_cv( 96 | self, 97 | cv: torch.Tensor, 98 | lengths: torch.Tensor 99 | ) -> torch.Tensor: 100 | """Encode context terms - separated from main forward step so that it can be done just once per audio file at inference time 101 | Args: 102 | memory: encoded memory, float32 (batch, maxlen_in, feat) 103 | memory_mask: encoder memory mask, (batch, 1, maxlen_in) 104 | Returns: 105 | x: decoded token score before softmax (batch, maxlen_out, 106 | vocab_size) 107 | """ 108 | blank_token = torch.zeros(1, cv.shape[1], dtype=torch.int32) 109 | blank_token[0,0] = self.vocab_size 110 | if cv.get_device() >= 0: 111 | blank_token = blank_token.to(cv.get_device(), non_blocking=True) 112 | blank_length = torch.ones(1, dtype=torch.int32) 113 | 114 | if lengths.get_device() >= 0: 115 | blank_length = blank_length.to(lengths.get_device(), non_blocking=True) 116 | 117 | cv = torch.cat([blank_token, cv]) 118 | 119 | lengths = torch.cat([blank_length, lengths]) 120 | # pack_padded_sequence requires lengths to be on CPU 121 | lengths = lengths.to('cpu') 122 | 123 | # subwords --> embeddings 124 | x = self.embed(cv) # nTerms x maxlen x embdding_dim 125 | # padding 126 | x = torch.nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False) 127 | 128 | # lstm on each CV term, pull out last hidden state 129 | _, (x,_) = self.encoder(x) # (nLayers x 2) x nTerms x output_dim 130 | x = x.view(-1, 2, x.shape[1], x.shape[2]) 131 | 132 | # concat forward and backward from last layer 133 | x = torch.cat([x[-1, 0, :, :], x[-1, 1, :, :]], dim=1).unsqueeze(0) # nTerms x 1 x output_dim*2 134 | 135 | return x 136 | 137 | def combine_layers( 138 | self, 139 | layer_outs: List[torch.Tensor] 140 | ) -> torch.Tensor: 141 | # in https://assets.amazon.science/43/13/104c968c45ea9ed02cffaa1448e0/personalization-of-ctc-speech-recognition-models.pdf 142 | # they use a weighted sum of the 6th, 12th, and 20th layers out of a 20-layer encoder 143 | # but they don't say what the weights are :/ 144 | return 0.5*layer_outs[-1] + 0.25*layer_outs[-9] + 0.25*layer_outs[-15] 145 | 146 | 147 | -------------------------------------------------------------------------------- /asr/wenet/utils/ctc_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import List, Tuple 16 | 17 | import numpy as np 18 | 19 | import torch 20 | 21 | 22 | def remove_duplicates_and_blank(hyp: List[int], 23 | blank_id: int = 0) -> List[int]: 24 | new_hyp: List[int] = [] 25 | cur = 0 26 | while cur < len(hyp): 27 | if hyp[cur] != blank_id: 28 | new_hyp.append(hyp[cur]) 29 | prev = cur 30 | while cur < len(hyp) and hyp[cur] == hyp[prev]: 31 | cur += 1 32 | return new_hyp 33 | 34 | 35 | def replace_duplicates_with_blank(hyp: List[int], 36 | blank_id: int = 0) -> List[int]: 37 | new_hyp: List[int] = [] 38 | cur = 0 39 | while cur < len(hyp): 40 | new_hyp.append(hyp[cur]) 41 | prev = cur 42 | cur += 1 43 | while cur < len( 44 | hyp) and hyp[cur] == hyp[prev] and hyp[cur] != blank_id: 45 | new_hyp.append(blank_id) 46 | cur += 1 47 | return new_hyp 48 | 49 | 50 | def gen_ctc_peak_time(hyp: List[int], blank_id: int = 0) -> List[int]: 51 | times = [] 52 | cur = 0 53 | while cur < len(hyp): 54 | if hyp[cur] != blank_id: 55 | times.append(cur) 56 | prev = cur 57 | while cur < len(hyp) and hyp[cur] == hyp[prev]: 58 | cur += 1 59 | return times 60 | 61 | 62 | def gen_timestamps_from_peak( 63 | peaks: List[int], 64 | max_duration: float, 65 | frame_rate: float = 0.04, 66 | max_token_duration: float = 1.0, 67 | ) -> List[Tuple[float, float]]: 68 | """ 69 | Args: 70 | peaks: ctc peaks time stamp 71 | max_duration: max_duration of the sentence 72 | frame_rate: frame rate of every time stamp, in seconds 73 | max_token_duration: max duration of the token, in seconds 74 | Returns: 75 | list(start, end) of each token 76 | """ 77 | times = [] 78 | half_max = max_token_duration / 2 79 | for i in range(len(peaks)): 80 | if i == 0: 81 | start = max(0, peaks[0] * frame_rate - half_max) 82 | else: 83 | start = max((peaks[i - 1] + peaks[i]) / 2 * frame_rate, 84 | peaks[i] * frame_rate - half_max) 85 | 86 | if i == len(peaks) - 1: 87 | end = min(max_duration, peaks[-1] * frame_rate + half_max) 88 | else: 89 | end = min((peaks[i] + peaks[i + 1]) / 2 * frame_rate, 90 | peaks[i] * frame_rate + half_max) 91 | times.append((start, end)) 92 | return times 93 | 94 | 95 | def insert_blank(label, blank_id=0): 96 | """Insert blank token between every two label token.""" 97 | label = np.expand_dims(label, 1) 98 | blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id 99 | label = np.concatenate([blanks, label], axis=1) 100 | label = label.reshape(-1) 101 | label = np.append(label, label[0]) 102 | return label 103 | 104 | 105 | def force_align(ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> list: 106 | """ctc forced alignment. 
107 | 108 | Args: 109 | torch.Tensor ctc_probs: hidden state sequence, 2d tensor (T, D) 110 | torch.Tensor y: id sequence tensor 1d tensor (L) 111 | int blank_id: blank symbol index 112 | Returns: 113 | torch.Tensor: alignment result 114 | """ 115 | ctc_probs = ctc_probs.cpu() 116 | y = y.cpu() 117 | y_insert_blank = insert_blank(y, blank_id) 118 | 119 | log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank))) 120 | log_alpha = log_alpha - float('inf') # log of zero 121 | state_path = torch.zeros((ctc_probs.size(0), len(y_insert_blank)), 122 | dtype=torch.int16) - 1 # state path 123 | 124 | # init start state 125 | log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] 126 | log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] 127 | 128 | for t in range(1, ctc_probs.size(0)): 129 | for s in range(len(y_insert_blank)): 130 | if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ 131 | s] == y_insert_blank[s - 2]: 132 | candidates = torch.tensor( 133 | [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]]) 134 | prev_state = [s, s - 1] 135 | else: 136 | candidates = torch.tensor([ 137 | log_alpha[t - 1, s], 138 | log_alpha[t - 1, s - 1], 139 | log_alpha[t - 1, s - 2], 140 | ]) 141 | prev_state = [s, s - 1, s - 2] 142 | log_alpha[ 143 | t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]] 144 | state_path[t, s] = prev_state[torch.argmax(candidates)] 145 | 146 | state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16) 147 | 148 | candidates = torch.tensor([ 149 | log_alpha[-1, len(y_insert_blank) - 1], 150 | log_alpha[-1, len(y_insert_blank) - 2] 151 | ]) 152 | final_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2] 153 | state_seq[-1] = final_state[torch.argmax(candidates)] 154 | for t in range(ctc_probs.size(0) - 2, -1, -1): 155 | state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] 156 | 157 | output_alignment = [] 158 | for t in range(0, ctc_probs.size(0)): 159 | output_alignment.append(y_insert_blank[state_seq[t, 0]]) 160 | 161 | return output_alignment 162 | 163 | 164 | def get_blank_id(configs, symbol_table): 165 | if 'ctc_conf' not in configs: 166 | configs['ctc_conf'] = {} 167 | 168 | if '' in symbol_table: 169 | if 'ctc_blank_id' in configs['ctc_conf']: 170 | assert configs['ctc_conf']['ctc_blank_id'] == symbol_table[ 171 | ''] 172 | else: 173 | configs['ctc_conf']['ctc_blank_id'] = symbol_table[''] 174 | else: 175 | assert 'ctc_blank_id' in configs[ 176 | 'ctc_conf'], "PLZ set ctc_blank_id in yaml" 177 | 178 | return configs, configs['ctc_conf']['ctc_blank_id'] 179 | -------------------------------------------------------------------------------- /asr/wenet/transducer_espnet/scorer_interface.py: -------------------------------------------------------------------------------- 1 | """Scorer interface module.""" 2 | 3 | from typing import Any 4 | from typing import List 5 | from typing import Tuple 6 | 7 | import torch 8 | import warnings 9 | 10 | 11 | class ScorerInterface: 12 | """Scorer interface for beam search. 13 | 14 | The scorer performs scoring of the all tokens in vocabulary. 
15 | 16 | Examples: 17 | * Search heuristics 18 | * :class:`espnet.nets.scorers.length_bonus.LengthBonus` 19 | * Decoder networks of the sequence-to-sequence models 20 | * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder` 21 | * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder` 22 | * Neural language models 23 | * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM` 24 | * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM` 25 | * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM` 26 | 27 | """ 28 | 29 | def init_state(self, x: torch.Tensor) -> Any: 30 | """Get an initial state for decoding (optional). 31 | 32 | Args: 33 | x (torch.Tensor): The encoded feature tensor 34 | 35 | Returns: initial state 36 | 37 | """ 38 | return None 39 | 40 | def select_state(self, state: Any, i: int, new_id: int = None) -> Any: 41 | """Select state with relative ids in the main beam search. 42 | 43 | Args: 44 | state: Decoder state for prefix tokens 45 | i (int): Index to select a state in the main beam search 46 | new_id (int): New label index to select a state if necessary 47 | 48 | Returns: 49 | state: pruned state 50 | 51 | """ 52 | return None if state is None else state[i] 53 | 54 | def score( 55 | self, y: torch.Tensor, state: Any, x: torch.Tensor 56 | ) -> Tuple[torch.Tensor, Any]: 57 | """Score new token (required). 58 | 59 | Args: 60 | y (torch.Tensor): 1D torch.int64 prefix tokens. 61 | state: Scorer state for prefix tokens 62 | x (torch.Tensor): The encoder feature that generates ys. 63 | 64 | Returns: 65 | tuple[torch.Tensor, Any]: Tuple of 66 | scores for next token that has a shape of `(n_vocab)` 67 | and next state for ys 68 | 69 | """ 70 | raise NotImplementedError 71 | 72 | def final_score(self, state: Any) -> float: 73 | """Score eos (optional). 74 | 75 | Args: 76 | state: Scorer state for prefix tokens 77 | 78 | Returns: 79 | float: final score 80 | 81 | """ 82 | return 0.0 83 | 84 | 85 | class BatchScorerInterface(ScorerInterface): 86 | """Batch scorer interface.""" 87 | 88 | def batch_init_state(self, x: torch.Tensor) -> Any: 89 | """Get an initial state for decoding (optional). 90 | 91 | Args: 92 | x (torch.Tensor): The encoded feature tensor 93 | 94 | Returns: initial state 95 | 96 | """ 97 | return self.init_state(x) 98 | 99 | def batch_score( 100 | self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor 101 | ) -> Tuple[torch.Tensor, List[Any]]: 102 | """Score new token batch (required). 103 | 104 | Args: 105 | ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). 106 | states (List[Any]): Scorer states for prefix tokens. 107 | xs (torch.Tensor): 108 | The encoder feature that generates ys (n_batch, xlen, n_feat). 109 | 110 | Returns: 111 | tuple[torch.Tensor, List[Any]]: Tuple of 112 | batchfied scores for next token with shape of `(n_batch, n_vocab)` 113 | and next state list for ys. 114 | 115 | """ 116 | warnings.warn( 117 | "{} batch score is implemented through for loop not parallelized".format( 118 | self.__class__.__name__ 119 | ) 120 | ) 121 | scores = list() 122 | outstates = list() 123 | for i, (y, state, x) in enumerate(zip(ys, states, xs)): 124 | score, outstate = self.score(y, state, x) 125 | outstates.append(outstate) 126 | scores.append(score) 127 | scores = torch.cat(scores, 0).view(ys.shape[0], -1) 128 | return scores, outstates 129 | 130 | 131 | class PartialScorerInterface(ScorerInterface): 132 | """Partial scorer interface for beam search. 
133 | 134 | The partial scorer performs scoring when non-partial scorer finished scoring, 135 | and receives pre-pruned next tokens to score because it is too heavy to score 136 | all the tokens. 137 | 138 | Examples: 139 | * Prefix search for connectionist-temporal-classification models 140 | * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer` 141 | 142 | """ 143 | 144 | def score_partial( 145 | self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor 146 | ) -> Tuple[torch.Tensor, Any]: 147 | """Score new token (required). 148 | 149 | Args: 150 | y (torch.Tensor): 1D prefix token 151 | next_tokens (torch.Tensor): torch.int64 next token to score 152 | state: decoder state for prefix tokens 153 | x (torch.Tensor): The encoder feature that generates ys 154 | 155 | Returns: 156 | tuple[torch.Tensor, Any]: 157 | Tuple of a score tensor for y that has a shape `(len(next_tokens),)` 158 | and next state for ys 159 | 160 | """ 161 | raise NotImplementedError 162 | 163 | 164 | class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface): 165 | """Batch partial scorer interface for beam search.""" 166 | 167 | def batch_score_partial( 168 | self, 169 | ys: torch.Tensor, 170 | next_tokens: torch.Tensor, 171 | states: List[Any], 172 | xs: torch.Tensor, 173 | ) -> Tuple[torch.Tensor, Any]: 174 | """Score new token (required). 175 | 176 | Args: 177 | ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). 178 | next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch, n_token). 179 | states (List[Any]): Scorer states for prefix tokens. 180 | xs (torch.Tensor): 181 | The encoder feature that generates ys (n_batch, xlen, n_feat). 182 | 183 | Returns: 184 | tuple[torch.Tensor, Any]: 185 | Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)` 186 | and next states for ys 187 | """ 188 | raise NotImplementedError 189 | --------------------------------------------------------------------------------