├── .gitignore ├── LICENSE ├── README.md ├── ctc-asr-chunked-inference ├── README.md ├── ctc_asr_chunked_inference │ ├── __init__.py │ ├── asr_chunk_infer_glue_pipeline.py │ └── asr_infer_decode.py ├── requirements.txt ├── setup.py └── tests │ ├── __init__.py │ ├── conftest.py │ ├── resources │ ├── LibriSpeech_dev-other_116_288046_116-288046-0011.opus │ ├── LibriSpeech_dev-other_116_288046_116-288046-0011_ref.txt │ └── lm.arpa │ ├── test_aschinglupi.py │ └── test_asr_infer_decode.py ├── ctc-decoding ├── README.md ├── ctc_decoding │ ├── __init__.py │ ├── ctc_decoding.py │ ├── huggingface_ctc_decoding.py │ ├── lm_model_for_pyctcdecode.py │ ├── logit_aligned_transcript.py │ └── pyctc_decoder.py ├── requirements.txt ├── setup.py └── tests │ ├── conftest.py │ ├── resources │ ├── LibriSpeech_dev-other_116_288046_116-288046-0011_logits.npy │ ├── LibriSpeech_dev-other_116_288046_116-288046-0011_ref.txt │ ├── lm.arpa │ └── test_corpus.txt │ ├── test_greedy_decoding.py │ └── test_pyctc_decoding.py ├── fastapi-asr-service ├── app │ ├── __init__.py │ ├── fastapi_asr_service_utils.py │ └── main.py ├── build_model_in_docker.py ├── docker │ └── fastapi_cpu │ │ └── Dockerfile ├── readme.md ├── requirements.txt └── tests │ ├── resources │ ├── LibriSpeech_dev-other_116_288046_116-288046-0011.opus │ └── LibriSpeech_dev-other_116_288046_116-288046-0011_ref.txt │ └── test_fastapi_asr_service.py ├── huggingface_wav2vec2_finetuning ├── __init__.py ├── base_model_for_finetuning.py ├── ctc_data_collator.py ├── ctc_trainer.py ├── data_loading_commons.py ├── ds_config_zero3.json ├── hf_finetune_utils.py ├── huggingface_wav2vec2_finetuner.py ├── requirements.txt ├── run_finetuning_directly.py ├── run_speech_recognition_ctc_bnb_original.py └── stream_ftdataset.py ├── ml4audio ├── __init__.py ├── asr_inference │ ├── __init__.py │ ├── faster_whisper_inferencer.py │ ├── inference.py │ ├── logits_inferencer │ │ ├── __init__.py │ │ ├── asr_logits_inferencer.py │ │ ├── hfwav2vec2_logits_inferencer.py │ │ ├── huggingface_checkpoints.py │ │ └── nemo_asr_logits_inferencer.py │ ├── openai_whisper_inferencer.py │ ├── pytorch_to_onnx_for_wav2vec.py │ ├── transcript_glueing.py │ ├── transcript_gluer.py │ └── whisper_inference.py ├── audio_data │ ├── __init__.py │ ├── common_voice_datasets.py │ ├── hf_speech_iterable_dataset.py │ ├── mls_corpora.py │ ├── nemo_perturbation.py │ ├── sox_signal_augmentation.py │ └── targz_asr_dataset.py ├── audio_utils │ ├── __init__.py │ ├── aligned_transcript.py │ ├── audio_data_models.py │ ├── audio_io.py │ ├── audio_segmentation_utils.py │ ├── convert_video_to_mp3.py │ ├── nemo_utils.py │ ├── overlap_array_chunker.py │ ├── pyaudio_streaming.py │ ├── subtitle_utils.py │ ├── test_utils.py │ └── torchaudio_utils.py ├── service_utils │ ├── __init__.py │ └── fastapi_utils.py └── text_processing │ ├── __init__.py │ ├── asr_metrics.py │ ├── asr_text_cleaning.py │ ├── character_mappings │ ├── __init__.py │ ├── cyrillic_character_maps.py │ ├── latin_character_maps.py │ ├── not_str_translatable_maps.py │ └── text_cleaning.py │ ├── kenlm_arpa.py │ ├── pretty_diff.py │ ├── smith_waterman_alignment.py │ └── word_based_text_corpus.py ├── nemo_language_classification ├── __init__.py ├── benchmark_lang_clf.py ├── conf │ └── titanet-finetune.yaml ├── finetune_lang_clf.py ├── language_classification_data.py ├── nemo_lang_clf.py ├── prepare_lang_clf_splits.py ├── readme.md └── requirements.txt ├── nemo_punctuation_capitalization ├── README.md ├── __init__.py ├── punctcap_service │ ├── Dockerfile │ ├── __init__.py │ 
├── build_model.py │ ├── debug_punctcap_service.py │ ├── punctcap_fastapi_server.py │ ├── readme.md │ └── requirements.txt └── punctcap_training │ ├── __init__.py │ ├── conf │ └── punctuation_capitalization_config.yaml │ ├── lenta_data.py │ ├── nemo_punctcap_traindata.py │ ├── punctcap_inference.py │ ├── punctuation_capitalization_train_evaluate.py │ ├── punctuation_tatoeba_data.py │ ├── readme.md │ ├── requirements.txt │ └── run_punctuation_training.py ├── nemo_vad ├── __init__.py ├── images │ └── vad_demo.png ├── nemo_offline_vad.py ├── nemo_streaming_vad.py ├── readme.md ├── requirements.txt ├── scripts │ └── visualize_segmentation.py ├── streaming_vad_segmentation.py ├── tests │ ├── __init__.py │ ├── resources │ │ └── VAD_demo.wav │ ├── test_nemo_offline_vad.py │ ├── test_vad.py │ └── vad_infer_almost_original.py └── vad_inference_postprocessing.yaml ├── requirements.txt ├── setup.py ├── speaker-diarization ├── images │ ├── dw_africa_queen_elizabeth_prediction_umap_cluster.png │ ├── dw_africa_queen_elizabeth_speaker_segments.png │ ├── dw_africa_queen_elizabeth_youtube.jpg │ ├── dw_queen_elizabeth_africa_speakers.png │ └── speaker_visualization.png ├── nemo_diarization_tutorial.py ├── readme.md ├── requirements.txt ├── scripts │ ├── audio_segmentation_via_asr.py │ └── debug_speaker_clustering_service.py ├── setup.py ├── speaker_diarization │ ├── __init__.py │ ├── diarization │ │ ├── __init__.py │ │ ├── nemo_diarizers.py │ │ ├── offline_diarization.yaml │ │ ├── pyannote_diarizers.py │ │ ├── speaker_diarization_inferencer.py │ │ └── umaspeclu_diarizers.py │ ├── nemo_speaker_embedder.py │ ├── speaker_clusterer.py │ ├── speaker_embedding_utils.py │ └── speechbrain_der.py └── tests │ ├── resources │ ├── oLnl1D6owYA.opus │ └── oLnl1D6owYA_ref.rttm │ └── test_speaker_clusterer.py ├── speaker_clustering_service ├── app │ ├── __init__.py │ └── main.py ├── build_model_in_docker.py ├── docker │ └── fastapi_cpu │ │ └── Dockerfile ├── readme.md └── requirements.txt ├── tests ├── conftest.py ├── resources │ ├── aligned_transcript.json │ └── test_corpus.txt ├── test_arpa_from_corpus.py └── test_array_chunking.py ├── urdu_asr ├── readme.md └── some_urdu.png └── whisper-streaming ├── README.md ├── requirements.txt ├── setup.py ├── tests ├── __init__.py ├── conftest.py ├── resources │ ├── LibriSpeech_dev-other_116_288046_116-288046-0011.opus │ └── LibriSpeech_dev-other_116_288046_116-288046-0011_ref.txt └── test_whisper_streaming.py └── whisper_streaming ├── __init__.py └── whisper_streaming.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | *.env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 SELMA-project 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ml4audio 2 | audio, NLP, ML with huggingface, nvidia/nemo, speechbrain 3 | 4 | [![bear-ified](https://raw.githubusercontent.com/beartype/beartype-assets/main/badge/bear-ified.svg)](https://beartype.rtfd.io) 5 | 6 | -------------------------------------------------------------------------------- /ctc-asr-chunked-inference/README.md: -------------------------------------------------------------------------------- 1 | # streaming inference for ctc-based asr 2 | * currently only huggingface's wav2vec2 -------------------------------------------------------------------------------- /ctc-asr-chunked-inference/ctc_asr_chunked_inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/ctc-asr-chunked-inference/ctc_asr_chunked_inference/__init__.py -------------------------------------------------------------------------------- /ctc-asr-chunked-inference/ctc_asr_chunked_inference/asr_infer_decode.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from dataclasses import dataclass 4 | from typing import Union 5 | 6 | import numpy as np 7 | import torch 8 | from beartype import beartype 9 | from ml4audio.asr_inference.logits_inferencer.asr_logits_inferencer import ( 10 | ASRLogitsInferencer, 11 | ) 12 | from ml4audio.audio_utils.audio_io import MAX_16_BIT_PCM 13 | from ml4audio.audio_utils.torchaudio_utils import torchaudio_resample 14 | from transformers import set_seed 15 | 16 | from ctc_decoding.ctc_decoding import BaseCTCDecoder 17 | from ctc_decoding.logit_aligned_transcript import LogitAlignedTranscript 18 | from misc_utils.beartypes import ( 19 | TorchTensor2D, 20 | NeNpFloatDim1, 21 | NeNpFloatDim1, 22 | NeNpInt16Dim1, 23 | ) 24 | from misc_utils.buildable import Buildable 25 | from misc_utils.dataclass_utils import UNDEFINED, _UNDEFINED 26 | from ml4audio.audio_utils.aligned_transcript import ( 27 | TimestampedLetters, 28 | ) 29 | 30 | DEBUG = False 31 | counter = 0 32 | if DEBUG: 33 | # debug_name= "16kHz" 34 | debug_name = "8kHz" 35 | debug_wav_dir = f"/tmp/debug_wav_{debug_name}" 36 | shutil.rmtree(debug_wav_dir, ignore_errors=True) 37 | os.makedirs(debug_wav_dir, exist_ok=True) 38 | 39 | set_seed(42) 40 | 41 | 42 | NumpyFloatORInt16_1DArray = Union[NeNpFloatDim1, NeNpInt16Dim1] 43 | 44 | 45 | @beartype 46 | def convert_and_resample( 47 | audio: NumpyFloatORInt16_1DArray, input_sample_rate: int, target_sample_rate: int 48 | ) -> NeNpFloatDim1: 49 | if audio.dtype == np.int16: 50 | audio = audio.astype(np.float32) / MAX_16_BIT_PCM 51 | if input_sample_rate != target_sample_rate: 52 | audio = torchaudio_resample( 53 | signal=torch.from_numpy(audio.astype(np.float32)), 54 | sample_rate=input_sample_rate, 55 | target_sample_rate=target_sample_rate, 56 | ).numpy() 57 | return audio 58 | 59 | 60 | @dataclass 61 | class ASRInferDecoder(Buildable): 62 | """ 63 | does asr-inference WITH decoding greedy/lm-based 64 | TODO: 65 | also does preprocessing of the audio-array (conversion+resampling)! 66 | split into logits-inferencer and decoder 67 | well seems huggingface's "src/transformers/pipelines/automatic_speech_recognition.py" cannot yet do streaming! 
just "long audio-files" 68 | 69 | """ 70 | 71 | input_sample_rate: int = 16000 72 | logits_inferencer: ASRLogitsInferencer = UNDEFINED # order matters! first the logits_inferencer is build which builds the transcript_normalizer which is needed by decoder! 73 | decoder: Union[_UNDEFINED, BaseCTCDecoder] = UNDEFINED 74 | 75 | @property 76 | def vocab(self) -> list[str]: 77 | return self.logits_inferencer.vocab 78 | 79 | @beartype 80 | def transcribe_audio_array(self, audio_array: NeNpFloatDim1) -> TimestampedLetters: 81 | audio_array = convert_and_resample( 82 | audio_array, 83 | self.input_sample_rate, 84 | self.logits_inferencer.asr_model_sample_rate, 85 | ) 86 | logits = self.logits_inferencer.calc_logits(audio_array) 87 | return self.__aligned_decode(logits, len(audio_array)) 88 | 89 | @beartype 90 | def __aligned_decode( 91 | self, logits: TorchTensor2D, audio_array_seq_len: int 92 | ) -> TimestampedLetters: 93 | """ 94 | letters aligned to audio-frames 95 | 96 | """ 97 | dec_out: LogitAlignedTranscript = self.decoder.ctc_decode(logits.numpy())[0] 98 | 99 | logits_seq_len = logits.size()[0] 100 | audio_to_logits_ratio = audio_array_seq_len / logits_seq_len 101 | timestamps = [ 102 | audio_to_logits_ratio * i / self.input_sample_rate 103 | for i in dec_out.logit_ids 104 | ] 105 | 106 | return TimestampedLetters( 107 | dec_out.text, np.array(timestamps) 108 | ) # ,dec_out.logits_score,dec_out.lm_score 109 | -------------------------------------------------------------------------------- /ctc-asr-chunked-inference/requirements.txt: -------------------------------------------------------------------------------- 1 | pyctcdecode==0.5.0 -------------------------------------------------------------------------------- /ctc-asr-chunked-inference/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import find_packages 3 | from setuptools import setup 4 | 5 | 6 | def req_file(filename, folder="./"): 7 | with open(os.path.join(folder, filename), encoding="utf-8") as f: 8 | content = f.readlines() 9 | # you may also want to remove whitespace characters 10 | # Example: `\n` at the end of each line 11 | return [x.strip() for x in content] 12 | 13 | 14 | install_requires = req_file("requirements.txt") 15 | 16 | with open("README.md") as f: 17 | readme = f.read() 18 | 19 | 20 | setup( 21 | name="ctc-asr-chunked-inference", 22 | version="0.1", 23 | author="Tilo Himmelsbach", 24 | author_email="dertilo@gmail.com", 25 | packages=find_packages(include=["ctc_asr_chunked_inference*"]), 26 | license="MIT License", 27 | long_description=readme, 28 | install_requires=install_requires, 29 | python_requires=">=3.9", 30 | ) 31 | -------------------------------------------------------------------------------- /ctc-asr-chunked-inference/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/ctc-asr-chunked-inference/tests/__init__.py -------------------------------------------------------------------------------- /ctc-asr-chunked-inference/tests/conftest.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from warnings import filterwarnings 3 | 4 | from beartype import beartype 5 | from beartype.roar import BeartypeDecorHintPep585DeprecationWarning 6 | 7 | from ctc_asr_chunked_inference.asr_infer_decode import ASRInferDecoder 8 | 
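# --- Illustrative sketch (not part of the original conftest) ---------------
# ASRInferDecoder.__aligned_decode above maps each decoded character's logit
# index to a timestamp in seconds via
#     audio_to_logits_ratio = n_audio_samples / n_logit_frames
#     t_i = audio_to_logits_ratio * logit_id_i / input_sample_rate
# A minimal, dependency-free re-derivation of that mapping; the sample values
# below are made up for illustration only.
import numpy as np


def _sketch_logit_ids_to_seconds(
    logit_ids: list[int],
    n_audio_samples: int,
    n_logit_frames: int,
    sample_rate: int,
) -> np.ndarray:
    # e.g. wav2vec2 emits roughly one logit frame per 320 samples at 16 kHz (~20 ms)
    audio_to_logits_ratio = n_audio_samples / n_logit_frames
    return np.array(logit_ids) * audio_to_logits_ratio / sample_rate


# _sketch_logit_ids_to_seconds([0, 50, 100], 64_000, 200, 16_000)
# -> array([0., 1., 2.]): frame 50 of 200 frames over 4 s of audio lands at ~1 s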
from ctc_decoding.huggingface_ctc_decoding import ( 9 | HFCTCGreedyDecoder, 10 | ) 11 | from ctc_decoding.lm_model_for_pyctcdecode import GzippedArpaAndUnigramsForPyCTCDecode 12 | from ctc_decoding.pyctc_decoder import PyCTCKenLMDecoder 13 | from ml4audio.asr_inference.logits_inferencer.asr_logits_inferencer import ( 14 | ASRLogitsInferencer, 15 | determine_casing, 16 | ) 17 | from ml4audio.asr_inference.logits_inferencer.hfwav2vec2_logits_inferencer import ( 18 | HFWav2Vec2LogitsInferencer, 19 | ) 20 | from ml4audio.asr_inference.logits_inferencer.huggingface_checkpoints import ( 21 | HfModelFromCheckpoint, 22 | ) 23 | from ml4audio.asr_inference.logits_inferencer.nemo_asr_logits_inferencer import ( 24 | NemoASRLogitsInferencer, 25 | ) 26 | from ml4audio.audio_utils.test_utils import ( 27 | get_test_vocab, 28 | get_test_cache_base, 29 | TEST_RESOURCES, 30 | ) 31 | from ml4audio.text_processing.asr_text_cleaning import ( 32 | VocabCasingAwareTextCleaner, 33 | Letters, 34 | ) 35 | from ml4audio.text_processing.kenlm_arpa import AnArpaFile 36 | 37 | filterwarnings("ignore", category=BeartypeDecorHintPep585DeprecationWarning) 38 | 39 | from data_io.readwrite_files import read_lines 40 | import pytest 41 | 42 | cache_base = get_test_cache_base() 43 | 44 | TEST_MODEL_NAME = "facebook/wav2vec2-base-960h" 45 | 46 | 47 | @dataclass(frozen=True) 48 | class TestParams: 49 | input_sample_rate: int = 16000 50 | inferencer_name: str = "hf-wav2vec2" 51 | decoder_name: str = "greedy" 52 | lm_weight: float = 1.0 53 | 54 | 55 | @pytest.fixture 56 | def vocab(): 57 | return get_test_vocab() 58 | 59 | 60 | Words = list[str] 61 | 62 | 63 | @beartype 64 | def build_decoder(tp: TestParams, vocab: Words, letter_vocab: Letters): 65 | NAME2DECODER = { 66 | "greedy": HFCTCGreedyDecoder(tokenizer_name_or_path=TEST_MODEL_NAME), 67 | "beamsearch": PyCTCKenLMDecoder( 68 | vocab=vocab, 69 | lm_weight=tp.lm_weight, 70 | beta=0.5, 71 | beam_size=100, 72 | ngram_lm_model=GzippedArpaAndUnigramsForPyCTCDecode( 73 | cache_base=cache_base, 74 | raw_arpa=AnArpaFile(arpa_filepath=f"{TEST_RESOURCES}/lm.arpa"), 75 | transcript_cleaner=VocabCasingAwareTextCleaner( 76 | casing=determine_casing(letter_vocab), 77 | text_cleaner_name="en", 78 | letter_vocab=letter_vocab, 79 | ), 80 | ), 81 | ), 82 | } 83 | return NAME2DECODER[tp.decoder_name] 84 | 85 | 86 | def build_logits_inferencer(name: str) -> ASRLogitsInferencer: 87 | # SMALL_CTC_CONFORMER = "nvidia/stt_en_conformer_ctc_small" 88 | SMALL_CTC_CONFORMER = "stt_en_conformer_ctc_small" 89 | NAME2INFERENCER = { 90 | "hf-wav2vec2": HFWav2Vec2LogitsInferencer( 91 | checkpoint=HfModelFromCheckpoint( 92 | name=TEST_MODEL_NAME, 93 | model_name_or_path=TEST_MODEL_NAME, 94 | hf_model_type="Wav2Vec2ForCTC", 95 | base_dir=cache_base, 96 | ), 97 | ), 98 | "nemo-conformer": NemoASRLogitsInferencer(SMALL_CTC_CONFORMER), 99 | } 100 | return NAME2INFERENCER[name].build() 101 | 102 | 103 | @pytest.fixture 104 | def asr_infer_decoder(request): 105 | 106 | if not hasattr(request, "param"): 107 | tp = TestParams() 108 | else: 109 | tp: TestParams = request.param 110 | 111 | inferencer = build_logits_inferencer(tp.inferencer_name) 112 | asr = ASRInferDecoder( 113 | logits_inferencer=inferencer, 114 | decoder=build_decoder(tp, inferencer.vocab, inferencer.letter_vocab), 115 | input_sample_rate=tp.input_sample_rate, 116 | ) 117 | asr.build() 118 | return asr 119 | 120 | 121 | @pytest.fixture 122 | def librispeech_ref(): 123 | ref_txt = ( 124 | 
f"{TEST_RESOURCES}/LibriSpeech_dev-other_116_288046_116-288046-0011_ref.txt" 125 | ) 126 | raw_ref = next(iter(read_lines(ref_txt))) 127 | return raw_ref 128 | 129 | 130 | @pytest.fixture 131 | def librispeech_audio_file(): 132 | return f"{TEST_RESOURCES}/LibriSpeech_dev-other_116_288046_116-288046-0011.opus" 133 | -------------------------------------------------------------------------------- /ctc-asr-chunked-inference/tests/resources/LibriSpeech_dev-other_116_288046_116-288046-0011.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/ctc-asr-chunked-inference/tests/resources/LibriSpeech_dev-other_116_288046_116-288046-0011.opus -------------------------------------------------------------------------------- /ctc-asr-chunked-inference/tests/resources/LibriSpeech_dev-other_116_288046_116-288046-0011_ref.txt: -------------------------------------------------------------------------------- 1 | NOT HAVING THE COURAGE OR THE INDUSTRY OF OUR NEIGHBOUR WHO WORKS LIKE A BUSY BEE IN THE WORLD OF MEN AND BOOKS SEARCHING WITH THE SWEAT OF HIS BROW FOR THE REAL BREAD OF LIFE WETTING THE OPEN PAGE BEFORE HIM WITH HIS TEARS PUSHING INTO THE WE HOURS OF THE NIGHT HIS QUEST ANIMATED BY THE FAIREST OF ALL LOVES THE LOVE OF TRUTH WE EASE OUR OWN INDOLENT CONSCIENCE BY CALLING HIM NAMES -------------------------------------------------------------------------------- /ctc-asr-chunked-inference/tests/test_aschinglupi.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | import numpy as np 5 | import pytest 6 | 7 | from conftest import get_test_cache_base 8 | from ctc_asr_chunked_inference.asr_chunk_infer_glue_pipeline import Aschinglupi 9 | from ctc_asr_chunked_inference.asr_infer_decode import ASRInferDecoder 10 | from misc_utils.prefix_suffix import BASE_PATHES 11 | from ml4audio.asr_inference.transcript_glueing import ( 12 | accumulate_transcript_suffixes, 13 | ) 14 | from ml4audio.asr_inference.transcript_gluer import ( 15 | TranscriptGluer, 16 | ASRStreamInferenceOutput, 17 | ) 18 | from ml4audio.audio_utils.overlap_array_chunker import ( 19 | OverlapArrayChunker, 20 | ) 21 | from ml4audio.audio_utils.audio_io import audio_messages_from_file 22 | from ml4audio.text_processing.asr_metrics import calc_cer 23 | from ml4audio.text_processing.asr_text_cleaning import ( 24 | clean_and_filter_text, 25 | Casing, 26 | ) 27 | from ml4audio.text_processing.pretty_diff import smithwaterman_aligned_icdiff 28 | 29 | BASE_PATHES["asr_inference"] = get_test_cache_base() 30 | os.environ["DEBUG_GLUER"] = "True" 31 | 32 | 33 | @pytest.mark.parametrize( 34 | "step_dur,window_dur,max_step_dur,chunk_dur,max_CER,num_responses", 35 | [ 36 | # fmt: off 37 | (1.0, 2.0, None,0.1, 0.073, 25), 38 | # got worse due to using opus instead of wav 39 | (1.5, 3.0, None,0.1, 0.016, 17), 40 | (1.0, 4.0, None,0.1, 0.008, 25), 41 | 42 | (2.0, 4.0, None, 0.1, 0.008, 13), 43 | (1.0, 4.0, 2.0, 2.0, 0.008, 13), # same as above cause max_step_dur == chunk_dur == 2.0, the min_step_dur is kind of ignored, cause chunk_dur is fixed 44 | 45 | (4.0, 8.0, None,0.1, 0.0027, 7), 46 | (1.0, 8.0, None,0.1, 0.0, 25), 47 | # fmt: on 48 | ], 49 | ) 50 | def test_ASRStreamInferencer( 51 | asr_infer_decoder: ASRInferDecoder, 52 | librispeech_audio_file, 53 | librispeech_ref, 54 | step_dur: float, 55 | window_dur: float, 56 | max_step_dur: 
Optional[float], 57 | chunk_dur: float, 58 | max_CER: float, 59 | num_responses: int, 60 | ): 61 | 62 | SR = expected_sample_rate = asr_infer_decoder.input_sample_rate 63 | asr_input = list( 64 | audio_messages_from_file( 65 | librispeech_audio_file, expected_sample_rate, chunk_duration=chunk_dur 66 | ) 67 | ) 68 | assert asr_input[-1].end_of_signal 69 | audio_signal = np.concatenate([ac.array for ac in asr_input]) 70 | wav_length = 393920 71 | opus_is_alittle_longer = 70 72 | assert audio_signal.shape[0] == wav_length + opus_is_alittle_longer 73 | # audio_duration = audio_signal.shape[0] / SR 74 | 75 | streaming_asr: Aschinglupi = Aschinglupi( 76 | hf_asr_decoding_inferencer=asr_infer_decoder, 77 | transcript_gluer=TranscriptGluer(), 78 | audio_bufferer=OverlapArrayChunker( 79 | chunk_size=int(window_dur * SR), 80 | minimum_chunk_size=int(1 * SR), # one second! 81 | min_step_size=int(step_dur * SR), 82 | max_step_size=int(max_step_dur * SR) if max_step_dur is not None else None, 83 | ), 84 | ).build() 85 | 86 | streaming_asr.reset() 87 | 88 | outputs: list[ASRStreamInferenceOutput] = [ 89 | t for inpt in asr_input for t in streaming_asr.handle_inference_input(inpt) 90 | ] 91 | assert len(outputs) == num_responses 92 | assert outputs[-1].end_of_message 93 | 94 | suffixes_g = (tr.aligned_transcript for tr in outputs) 95 | transcript = accumulate_transcript_suffixes(suffixes_g) 96 | hyp = transcript.letters.strip(" ") 97 | 98 | # print(f"{audio_duration,prefix.timestamps[-1]}") 99 | ref = clean_and_filter_text( 100 | librispeech_ref, 101 | asr_infer_decoder.logits_inferencer.letter_vocab, 102 | text_cleaner="en", 103 | casing=Casing.upper, 104 | ) 105 | # print(smithwaterman_aligned_icdiff(ref, hyp)) 106 | cer = calc_cer([ref], [hyp]) 107 | print(f"{step_dur=},{window_dur=},{cer=}") 108 | 109 | assert cer <= max_CER 110 | -------------------------------------------------------------------------------- /ctc-asr-chunked-inference/tests/test_asr_infer_decode.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | import pytest 4 | 5 | from ctc_asr_chunked_inference.asr_infer_decode import ASRInferDecoder 6 | from ml4audio.audio_utils.torchaudio_utils import load_resample_with_torch 7 | from ml4audio.text_processing.asr_metrics import calc_cer 8 | from tests.conftest import TestParams 9 | 10 | 11 | @pytest.mark.parametrize( 12 | "asr_infer_decoder,max_CER", 13 | [ 14 | (TestParams(), 0.0), # WTF! this model reaches 0% CER! overfitted? 
15 | (TestParams(4_000), 0.079), 16 | (TestParams(4_000, decoder_name="beamsearch"), 0.021), 17 | ( 18 | TestParams( 19 | input_sample_rate=4000, 20 | inferencer_name="nemo-conformer", 21 | decoder_name="beamsearch", 22 | lm_weight=0.0, 23 | ), 24 | 0.029, 25 | ), 26 | ], 27 | indirect=["asr_infer_decoder"], 28 | ) 29 | def test_ASRInferDecoder( 30 | asr_infer_decoder: ASRInferDecoder, 31 | librispeech_audio_file, 32 | librispeech_ref, 33 | max_CER, 34 | ): 35 | 36 | expected_sample_rate = asr_infer_decoder.input_sample_rate 37 | print(f"{asr_infer_decoder.logits_inferencer.letter_vocab=}") 38 | audio_array = load_resample_with_torch( 39 | librispeech_audio_file, target_sample_rate=expected_sample_rate 40 | ).numpy() 41 | 42 | start_time = time() 43 | transcript = asr_infer_decoder.transcribe_audio_array(audio_array.squeeze()) 44 | inference_duration = time() - start_time 45 | hyp = transcript.letters.upper() 46 | # print(smithwaterman_aligned_icdiff(librispeech_ref, hyp)) 47 | 48 | cer = calc_cer([librispeech_ref], [hyp]) 49 | decoder_name = asr_infer_decoder.decoder.__class__.__name__ 50 | print( 51 | f"{asr_infer_decoder.logits_inferencer.name},{decoder_name}\t{cer=}, inference took: {inference_duration} seconds" 52 | ) 53 | assert cer <= max_CER 54 | -------------------------------------------------------------------------------- /ctc-decoding/README.md: -------------------------------------------------------------------------------- 1 | # mostly pyctcdecode based ctc-decoding -------------------------------------------------------------------------------- /ctc-decoding/ctc_decoding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/ctc-decoding/ctc_decoding/__init__.py -------------------------------------------------------------------------------- /ctc-decoding/ctc_decoding/ctc_decoding.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from dataclasses import dataclass 3 | 4 | from ctc_decoding.logit_aligned_transcript import LogitAlignedTranscript 5 | from misc_utils.beartypes import NumpyFloat2DArray 6 | 7 | NoneType = type(None) 8 | 9 | AlignedBeams = list[LogitAlignedTranscript] 10 | BatchOfAlignedBeams = list[AlignedBeams] 11 | 12 | 13 | @dataclass 14 | class BaseCTCDecoder: 15 | @abstractmethod 16 | def ctc_decode(self, logits: NumpyFloat2DArray) -> AlignedBeams: 17 | raise NotImplementedError 18 | 19 | # @property # TODO: needed? 
20 | # def vocab(self): 21 | # return list(self._tokenizer.get_vocab().keys()) 22 | -------------------------------------------------------------------------------- /ctc-decoding/ctc_decoding/huggingface_ctc_decoding.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from dataclasses import dataclass, field 3 | from typing import Union, Any 4 | 5 | import torch 6 | from beartype import beartype 7 | 8 | from transformers import PreTrainedTokenizer, AutoTokenizer 9 | from transformers.models.wav2vec2.tokenization_wav2vec2 import ( 10 | Wav2Vec2CTCTokenizerOutput, 11 | ) 12 | 13 | from ctc_decoding.ctc_decoding import BaseCTCDecoder, AlignedBeams 14 | from ctc_decoding.logit_aligned_transcript import LogitAlignedTranscript 15 | from misc_utils.beartypes import NumpyFloat2DArray, NeList 16 | from misc_utils.buildable import Buildable 17 | from misc_utils.prefix_suffix import PrefixSuffix 18 | from ml4audio.audio_utils.overlap_array_chunker import MessageChunk 19 | 20 | 21 | @dataclass 22 | class VocabFromHFTokenizer(Buildable, list[str]): 23 | tokenizer_name_or_path: Union[str, PrefixSuffix] 24 | 25 | @beartype 26 | def _build_self(self) -> NeList[str]: 27 | assert len(self) == 0 28 | self._tokenizer = AutoTokenizer.from_pretrained( 29 | str(self.tokenizer_name_or_path) 30 | ) 31 | vocab = list(self._tokenizer.get_vocab().keys()) 32 | 33 | self.extend(vocab) 34 | return vocab 35 | 36 | 37 | @dataclass 38 | class HFCTCDecoder(BaseCTCDecoder, Buildable): 39 | # TODO: remove this? 40 | vocab: NeList[str] 41 | # tokenizer_name_or_path: Union[str, PrefixSuffix] 42 | # _tokenizer: PreTrainedTokenizer = field(init=False) # default=UNDEFINED ? 43 | 44 | # def _build_self(self) -> Any: 45 | # self._tokenizer = AutoTokenizer.from_pretrained( 46 | # str(self.tokenizer_name_or_path) 47 | # ) 48 | # return self 49 | # 50 | # @property 51 | # def vocab(self): 52 | # return list(self._tokenizer.get_vocab().keys()) 53 | 54 | 55 | @dataclass 56 | class HFCTCGreedyDecoder(BaseCTCDecoder, Buildable): 57 | """ 58 | huggingface does not have a "proper" greedy decoder, but does argmax somewhere in the asr-pipeline 59 | see: https://github.com/huggingface/transformers/blob/7999ec125fc31428ed6879bf01bb013483daf704/src/transformers/pipelines/automatic_speech_recognition.py#L323 60 | 61 | method called: convert_tokens_to_string in tokenization_wav2vec2 62 | see: https://github.com/huggingface/transformers/blob/7999ec125fc31428ed6879bf01bb013483daf704/src/transformers/models/wav2vec2/tokenization_wav2vec2.py#L254 63 | does ctc to text conversion (collapsing the sequence) 64 | """ 65 | 66 | tokenizer_name_or_path: Union[str, PrefixSuffix] 67 | _tokenizer: PreTrainedTokenizer = field(init=False) # default=UNDEFINED ? 
68 | 69 | def _build_self(self) -> Any: 70 | self._tokenizer = AutoTokenizer.from_pretrained( 71 | str(self.tokenizer_name_or_path) 72 | ) 73 | 74 | @property 75 | def vocab(self): 76 | return list(self._tokenizer.get_vocab().keys()) 77 | 78 | @beartype 79 | def ctc_decode(self, logits: NumpyFloat2DArray) -> AlignedBeams: 80 | 81 | greedy_path = torch.argmax(torch.from_numpy(logits), dim=-1).squeeze() 82 | out: Wav2Vec2CTCTokenizerOutput = self._tokenizer.decode( # noqa 83 | token_ids=greedy_path, 84 | output_char_offsets=True, 85 | skip_special_tokens=False, # for ctc (see huggingface/transformers) 86 | ) 87 | char_offsets: list[dict] = out.char_offsets 88 | vocab_space = [" "] + self.vocab 89 | vocab_space = [ 90 | c for c in vocab_space if c not in ["", "", "", "", "|"] 91 | ] 92 | 93 | char_offsets = list(filter(lambda d: d["char"] in vocab_space, char_offsets)) 94 | if len(char_offsets) == 0: 95 | char_offsets = [{"char": " ", "start_offset": 0}] 96 | 97 | return [ 98 | LogitAlignedTranscript( 99 | text="".join([d["char"] for d in char_offsets]), 100 | logit_ids=[int(d["start_offset"]) for d in char_offsets], 101 | ) 102 | ] 103 | -------------------------------------------------------------------------------- /ctc-decoding/ctc_decoding/logit_aligned_transcript.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | from beartype import beartype 5 | 6 | from misc_utils.beartypes import NeStr, NeList 7 | 8 | TokenSpans = list[tuple[str, tuple[int, int]]] 9 | 10 | 11 | @dataclass 12 | class LogitAlignedTranscript: 13 | """ 14 | Text is character-wise aligned to logits, no time-stamps here. 15 | logits == ctc-matrix 16 | """ 17 | 18 | text: NeStr 19 | logit_ids: NeList[int] # TODO: not too strict? 20 | 21 | logits_score: Optional[float] = None 22 | lm_score: Optional[float] = None 23 | 24 | def __post_init__(self) -> None: 25 | """Validate data.""" 26 | have_same_len = len(self.text) == len(self.logit_ids) 27 | assert have_same_len, ( 28 | f"{self.text=} and {self.logit_ids=} have different length! 
" 29 | + f"{len(self.text)=}!={len(self.logit_ids)=}" 30 | ) 31 | 32 | @staticmethod 33 | def create_from_token_spans( 34 | token_spans: TokenSpans, 35 | lm_score: float, 36 | logits_score: float, 37 | ): 38 | text = " ".join([tok for tok, _ in token_spans]) 39 | return LogitAlignedTranscript( 40 | text=text, 41 | logit_ids=charwise_idx_for_tokenspans_via_linear_interpolation(token_spans), 42 | lm_score=lm_score, 43 | logits_score=logits_score, 44 | ) 45 | 46 | 47 | @beartype 48 | def charwise_idx_for_tokenspans_via_linear_interpolation( 49 | token_spans: TokenSpans, 50 | ) -> list[int]: 51 | seq_idx = [ 52 | round(start + (end - start) * k / len(word)) # interpolate 53 | for word, (start, end) in token_spans 54 | for k in range(len(word) + 1) 55 | ] 56 | return seq_idx[:-1] # all but the last one, which is a space 57 | -------------------------------------------------------------------------------- /ctc-decoding/requirements.txt: -------------------------------------------------------------------------------- 1 | kenlm@git+https://github.com/kpu/kenlm.git@master#egg=kenlm@996d7a6454b001337e9b8ea3d2ac1532f13c8e44 -------------------------------------------------------------------------------- /ctc-decoding/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import find_packages 3 | from setuptools import setup 4 | 5 | 6 | def req_file(filename, folder="./"): 7 | with open(os.path.join(folder, filename), encoding="utf-8") as f: 8 | content = f.readlines() 9 | # you may also want to remove whitespace characters 10 | # Example: `\n` at the end of each line 11 | return [x.strip() for x in content] 12 | 13 | 14 | install_requires = req_file("requirements.txt") 15 | 16 | with open("README.md") as f: 17 | readme = f.read() 18 | 19 | 20 | setup( 21 | name="ctc-decoding", 22 | version="0.1", 23 | author="Tilo Himmelsbach", 24 | author_email="dertilo@gmail.com", 25 | packages=find_packages(include=["ctc_decoding*"]), 26 | license="MIT License", 27 | long_description=readme, 28 | install_requires=install_requires, 29 | python_requires=">=3.9", 30 | ) 31 | -------------------------------------------------------------------------------- /ctc-decoding/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from ml4audio.audio_utils.test_utils import TEST_RESOURCES 5 | 6 | sys.path.append(os.path.dirname(__file__)) # TODO: WTF! this is a hack! 
7 | 8 | from data_io.readwrite_files import read_lines 9 | 10 | from warnings import filterwarnings 11 | 12 | from beartype.roar import BeartypeDecorHintPep585DeprecationWarning 13 | from transformers import Wav2Vec2CTCTokenizer 14 | 15 | filterwarnings("ignore", category=BeartypeDecorHintPep585DeprecationWarning) 16 | 17 | import pytest 18 | 19 | 20 | @pytest.fixture 21 | def hfwav2vec2_base_tokenizer(): 22 | return load_hfwav2vec2_base_tokenizer() 23 | 24 | 25 | def load_hfwav2vec2_base_tokenizer(): 26 | tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base-960h") 27 | return tokenizer 28 | 29 | 30 | @pytest.fixture 31 | def librispeech_logtis_file(): 32 | return ( 33 | f"{TEST_RESOURCES}/LibriSpeech_dev-other_116_288046_116-288046-0011_logits.npy" 34 | ) 35 | 36 | 37 | @pytest.fixture 38 | def librispeech_ref(): 39 | ref_txt = ( 40 | f"{TEST_RESOURCES}/LibriSpeech_dev-other_116_288046_116-288046-0011_ref.txt" 41 | ) 42 | raw_ref = next(iter(read_lines(ref_txt))) 43 | return raw_ref 44 | -------------------------------------------------------------------------------- /ctc-decoding/tests/resources/LibriSpeech_dev-other_116_288046_116-288046-0011_logits.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/ctc-decoding/tests/resources/LibriSpeech_dev-other_116_288046_116-288046-0011_logits.npy -------------------------------------------------------------------------------- /ctc-decoding/tests/resources/LibriSpeech_dev-other_116_288046_116-288046-0011_ref.txt: -------------------------------------------------------------------------------- 1 | NOT HAVING THE COURAGE OR THE INDUSTRY OF OUR NEIGHBOUR WHO WORKS LIKE A BUSY BEE IN THE WORLD OF MEN AND BOOKS SEARCHING WITH THE SWEAT OF HIS BROW FOR THE REAL BREAD OF LIFE WETTING THE OPEN PAGE BEFORE HIM WITH HIS TEARS PUSHING INTO THE WE HOURS OF THE NIGHT HIS QUEST ANIMATED BY THE FAIREST OF ALL LOVES THE LOVE OF TRUTH WE EASE OUR OWN INDOLENT CONSCIENCE BY CALLING HIM NAMES -------------------------------------------------------------------------------- /ctc-decoding/tests/test_greedy_decoding.py: -------------------------------------------------------------------------------- 1 | import icdiff 2 | import numpy as np 3 | 4 | from ctc_decoding.huggingface_ctc_decoding import ( 5 | HFCTCGreedyDecoder, 6 | ) 7 | from ml4audio.text_processing.asr_metrics import calc_cer 8 | 9 | TARGET_SAMPLE_RATE = 16000 10 | 11 | 12 | def test_GreedyDecoder( 13 | hfwav2vec2_base_tokenizer, 14 | librispeech_logtis_file, 15 | librispeech_ref, 16 | ): 17 | logits = np.load(librispeech_logtis_file, allow_pickle=True) 18 | decoder = HFCTCGreedyDecoder( 19 | tokenizer_name_or_path="facebook/wav2vec2-base-960h", 20 | ).build() 21 | transcript = decoder.ctc_decode(logits.squeeze())[0] 22 | hyp = transcript.text 23 | 24 | cer = calc_cer([librispeech_ref], [hyp]) 25 | assert cer == 0.0 26 | -------------------------------------------------------------------------------- /fastapi-asr-service/app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/fastapi-asr-service/app/__init__.py -------------------------------------------------------------------------------- /fastapi-asr-service/app/fastapi_asr_service_utils.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from omegaconf import OmegaConf 5 | 6 | from data_io.readwrite_files import read_json 7 | from misc_utils.dataclass_utils import decode_dataclass 8 | from misc_utils.prefix_suffix import BASE_PATHES, PrefixSuffix 9 | from nemo_vad.nemo_offline_vad import NemoOfflineVAD 10 | 11 | 12 | def load_asr_inferencer(): 13 | cache_root_in_container = os.environ["CACHE_ROOT"] 14 | cache_root = os.environ.get("cache_root", cache_root_in_container) 15 | BASE_PATHES["base_path"] = "/" 16 | BASE_PATHES["cache_root"] = cache_root 17 | BASE_PATHES["asr_inference"] = PrefixSuffix("cache_root", "ASR_INFERENCE") 18 | BASE_PATHES["am_models"] = PrefixSuffix("cache_root", "AM_MODELS") 19 | p = next( 20 | Path(cache_root).rglob("Aschinglupi*/dataclass.json") 21 | ) # TODO(tilo): hard-coded the class-name here!! 22 | jzon = read_json(str(p)) 23 | inferencer = decode_dataclass(jzon) 24 | inferencer.build() 25 | return inferencer 26 | 27 | 28 | # for parameters see: https://github.com/NVIDIA/NeMo/blob/aff169747378bcbcec3fc224748242b36205413f/examples/asr/conf/vad/vad_inference_postprocessing.yaml 29 | 30 | DEFAULT_NEMO_VAD_CONFIG = { 31 | "name": "vad_inference_postprocessing", 32 | "dataset": None, 33 | "num_workers": 0, 34 | "sample_rate": 16000, 35 | "gen_seg_table": True, 36 | "write_to_manifest": True, 37 | "prepare_manifest": {"auto_split": True, "split_duration": 400}, 38 | "vad": { 39 | "model_path": "app/vad_multilingual_marblenet.nemo", 40 | "parameters": { 41 | "normalize_audio": False, 42 | "window_length_in_sec": 0.15, 43 | "shift_length_in_sec": 0.01, 44 | "smoothing": "median", 45 | "overlap": 0.875, 46 | "postprocessing": { 47 | "onset": 0.3, 48 | "offset": 0.2, 49 | "pad_onset": 0.1, 50 | "pad_offset": 0.1, 51 | "min_duration_on": 0.5, 52 | "min_duration_off": 1.0, 53 | "filter_speech_first": True, 54 | }, 55 | }, 56 | }, 57 | "prepared_manifest_vad_input": None, 58 | "frame_out_dir": "vad_frame", 59 | "smoothing_out_dir": None, 60 | "table_out_dir": None, 61 | "out_manifest_filepath": None, 62 | } 63 | 64 | 65 | def load_vad_inferencer() -> NemoOfflineVAD: 66 | cfg = OmegaConf.create(DEFAULT_NEMO_VAD_CONFIG) 67 | cfg.vad.parameters.window_length_in_sec = 0.15 68 | cfg.vad.parameters.postprocessing.onset = 0.1 69 | cfg.vad.parameters.postprocessing.offset = 0.05 70 | cfg.vad.parameters.postprocessing.min_duration_on = 0.1 71 | cfg.vad.parameters.postprocessing.min_duration_off = 3.0 72 | cfg.vad.parameters.smoothing = "median" 73 | cfg.vad.parameters.overlap = 0.875 74 | vad = NemoOfflineVAD(cfg) 75 | vad.build() 76 | return vad 77 | -------------------------------------------------------------------------------- /fastapi-asr-service/build_model_in_docker.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | 3 | from fastapi_asr_service.app.fastapi_asr_service_utils import load_asr_inferencer, \ 4 | load_vad_inferencer 5 | from misc_utils.dataclass_utils import ( 6 | to_dict, 7 | ) 8 | 9 | if __name__ == "__main__": 10 | """ 11 | maybe it acts as kind of sanity/integration test?? 
12 | """ 13 | 14 | asr_inferencer = load_asr_inferencer() 15 | vad = load_vad_inferencer() 16 | pprint(to_dict(asr_inferencer)) 17 | pprint(to_dict(vad)) 18 | -------------------------------------------------------------------------------- /fastapi-asr-service/docker/fastapi_cpu/Dockerfile: -------------------------------------------------------------------------------- 1 | # see: https://github.com/docker/for-mac/issues/2155 2 | # global arg -> must declare it before every FROM or it will be local 3 | ARG MODEL_IMAGE 4 | FROM ${MODEL_IMAGE} as model_image 5 | 6 | FROM python:3.9-bullseye AS dependencies 7 | WORKDIR /code 8 | ENV APT_INSTALL="apt-get install -y --no-install-recommends" 9 | 10 | RUN apt-get update && \ 11 | DEBIAN_FRONTEND=noninteractive $APT_INSTALL \ 12 | build-essential \ 13 | ca-certificates \ 14 | wget \ 15 | git \ 16 | g++ \ 17 | cmake \ 18 | vim \ 19 | # for testing \ 20 | # libsndfile 21 | libsndfile1-dev \ 22 | # portaudio 23 | portaudio19-dev python3-pyaudio \ 24 | # ffmpeg 25 | ffmpeg libavcodec-extra \ 26 | # sox \ 27 | sox libsox-dev && \ 28 | apt-get clean && \ 29 | apt-get -y autoremove && \ 30 | rm -rf /var/lib/apt/lists/* 31 | 32 | 33 | ENV PATH="/venv/bin:$PATH" 34 | ENV PIP_INSTALL="/venv/bin/pip install --no-cache-dir --upgrade" 35 | 36 | RUN apt-get update && apt-get install -y python3-venv 37 | RUN python3 -m venv /venv && $PIP_INSTALL pip packaging setuptools 38 | RUN $PIP_INSTALL torchaudio@https://download.pytorch.org/whl/cpu/torchaudio-0.11.0%2Bcpu-cp39-cp39-linux_x86_64.whl 39 | RUN $PIP_INSTALL install Cython 40 | 41 | # to trigger re-run of following, "disable" caching, see: https://stackoverflow.com/questions/35134713/disable-cache-for-specific-run-commands 42 | # use with: --build-arg CACHEBUST=$(date +%s) 43 | 44 | ARG CACHEBUST=1 45 | RUN echo "$CACHEBUST" 46 | 47 | COPY requirements.txt requirements.txt 48 | RUN $PIP_INSTALL -r requirements.txt 49 | 50 | # pruning venv 51 | RUN rm -rf /venv/lib/python3.9/site-packages/sklearn/ensemble 52 | RUN rm -rf /venv/lib/python3.9/site-packages/pynini.libs 53 | 54 | # ================================================================== 55 | # BUILD MODELS - stage 56 | # ------------------------------------------------------------------ 57 | 58 | FROM dependencies AS build_models 59 | ENV CACHE_ROOT="/model" 60 | COPY --from=model_image . /model 61 | COPY build_model_in_docker.py /code/build_model_in_docker.py 62 | COPY app/fastapi_asr_service_utils.py /code/fastapi_asr_service/app/fastapi_asr_service_utils.py 63 | COPY app/vad_inference_postprocessing.yaml /code/app/vad_inference_postprocessing.yaml 64 | COPY app/vad_multilingual_marblenet.nemo /code/app/vad_multilingual_marblenet.nemo 65 | RUN python /code/build_model_in_docker.py 66 | 67 | # pruning .cache 68 | RUN rm -rf /root/.cache/pip 69 | RUN rm -rf /root/.cache/matplotlib 70 | 71 | # ================================================================== 72 | # PRODUCTION - stage 73 | # ------------------------------------------------------------------ 74 | FROM python:3.9.13-slim-buster AS production 75 | LABEL maintainer="Tilo Himmelsbach" 76 | WORKDIR /code 77 | ENV PATH="/venv/bin:$PATH" 78 | ENV APT_INSTALL="apt-get install -y --no-install-recommends" 79 | 80 | RUN apt-get update && \ 81 | DEBIAN_FRONTEND=noninteractive $APT_INSTALL \ 82 | # libsndfile TODO: currently asr_logits_inferencer uses librosa to resample!! 
83 | libsndfile1-dev \ 84 | # portaudio 85 | portaudio19-dev python3-pyaudio \ 86 | # ffmpeg 87 | ffmpeg libavcodec-extra \ 88 | # sox \ 89 | sox libsox-dev && \ 90 | 91 | apt-get clean && \ 92 | apt-get -y autoremove && \ 93 | rm -rf /var/lib/apt/lists/* 94 | 95 | # maybe for better docker-caching copy from model-image here, this only works if build_models-stage does not modify the models!! does it? well it could! 96 | #COPY --from=model_image . /model 97 | COPY --from=build_models /model /model 98 | COPY --from=build_models /venv /venv 99 | COPY --from=build_models /root/.cache /root/.cache 100 | 101 | ENV HF_DATASETS_OFFLINE=1 102 | ENV TRANSFORMERS_OFFLINE=1 103 | 104 | # PYTHONFAULTHANDLER TODO: wasdatdenn? 105 | ENV PYTHONFAULTHANDLER=1 106 | ENV CACHE_ROOT="/model" 107 | # ENV JINA_MP_START_METHOD=spawn 108 | 109 | COPY app /code/app 110 | 111 | CMD ["/bin/bash", "-c", "source /venv/bin/activate && \ 112 | uvicorn app.main:app --host 0.0.0.0 --port 8000"] -------------------------------------------------------------------------------- /fastapi-asr-service/requirements.txt: -------------------------------------------------------------------------------- 1 | ml4audio@git+https://github.com/SELMA-project/ml4audio@main#egg=ml4audio 2 | # if pushed changes to ml4audio do cache-bust here! -> docker (buildkit) does not reinstall if this file is not changing! 3 | # datasets # why? 4 | python-levenshtein 5 | beartype 6 | numba==0.53.1 7 | numpy==1.21.6 # why? 8 | librosa 9 | kenlm@git+https://github.com/kpu/kenlm.git@master#egg=kenlm 10 | pyctcdecode 11 | # pytest 12 | 13 | fastapi #==0.78.0 14 | Flask #==2.1.2 15 | icdiff 16 | jina==3.6.6 17 | jiwer 18 | # nemo-toolkit[nlp]==1.9.0 # hopefully this is not needed! 19 | torchaudio@https://download.pytorch.org/whl/cpu/torchaudio-0.11.0%2Bcpu-cp39-cp39-linux_x86_64.whl 20 | tqdm 21 | transformers==4.22.1 22 | resampy==0.2.2 # newer version here somehow deadlocked jina-ai executors! super strange! 23 | python-multipart 24 | uvicorn[standard] 25 | omegaconf 26 | nemo_toolkit[asr]==1.11.0 27 | wandb # WTF!! nemo wants it!! 
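# --- Illustrative client sketch (not a file from the repository) -----------
# How a client might call the containerized service defined by the Dockerfile
# above: the CMD starts uvicorn on port 8000, and tests/test_fastapi_asr_service.py
# below posts a multipart "file" field to /transcribe and reads "text" from the
# JSON response. Using `requests` here is an assumption; it is not listed in the
# requirements above, and host/port depend on how the container is run.
import requests

with open("LibriSpeech_dev-other_116_288046_116-288046-0011.opus", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/transcribe",
        files={"file": (f.name, f, "multipart/form-data")},
    )
resp.raise_for_status()
print(resp.json()["text"])  # transcript hypothesis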
-------------------------------------------------------------------------------- /fastapi-asr-service/tests/resources/LibriSpeech_dev-other_116_288046_116-288046-0011.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/fastapi-asr-service/tests/resources/LibriSpeech_dev-other_116_288046_116-288046-0011.opus -------------------------------------------------------------------------------- /fastapi-asr-service/tests/resources/LibriSpeech_dev-other_116_288046_116-288046-0011_ref.txt: -------------------------------------------------------------------------------- 1 | NOT HAVING THE COURAGE OR THE INDUSTRY OF OUR NEIGHBOUR WHO WORKS LIKE A BUSY BEE IN THE WORLD OF MEN AND BOOKS SEARCHING WITH THE SWEAT OF HIS BROW FOR THE REAL BREAD OF LIFE WETTING THE OPEN PAGE BEFORE HIM WITH HIS TEARS PUSHING INTO THE WE HOURS OF THE NIGHT HIS QUEST ANIMATED BY THE FAIREST OF ALL LOVES THE LOVE OF TRUTH WE EASE OUR OWN INDOLENT CONSCIENCE BY CALLING HIM NAMES -------------------------------------------------------------------------------- /fastapi-asr-service/tests/test_fastapi_asr_service.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import icdiff 4 | import pytest 5 | from fastapi import status 6 | from fastapi.testclient import TestClient 7 | 8 | from app.main import app as webapp 9 | from data_io.readwrite_files import read_file 10 | from ml4audio.text_processing.asr_metrics import calc_cer 11 | 12 | 13 | @pytest.fixture(scope="module") 14 | def test_client() -> TestClient: 15 | with TestClient(webapp) as tc: 16 | yield tc 17 | 18 | 19 | @pytest.fixture(scope="module") 20 | def audio_file() -> str: 21 | return "tests/resources/LibriSpeech_dev-other_116_288046_116-288046-0011.opus" 22 | 23 | 24 | @pytest.fixture(scope="module") 25 | def transcript_reference() -> str: 26 | return read_file( 27 | "tests/resources/LibriSpeech_dev-other_116_288046_116-288046-0011_ref.txt" 28 | ) 29 | 30 | 31 | def test_transcripe_endpoint(test_client, audio_file, transcript_reference): 32 | assert os.path.isdir( 33 | os.environ.get("CACHE_ROOT", "no-dir") 34 | ), "CACHE_ROOT where Aschinglupi got exported must be set as env-variable!" 
35 | max_CER = 0.02 36 | 37 | f = open(audio_file, "rb") 38 | files = {"file": (f.name, f, "multipart/form-data")} 39 | 40 | resp = test_client.post( 41 | "/transcribe", 42 | files=files, 43 | ) 44 | 45 | assert resp.status_code == status.HTTP_200_OK 46 | hyp = resp.json()["text"] 47 | ref = transcript_reference.upper() 48 | # cd = icdiff.ConsoleDiff(cols=120) 49 | # diff_line = "\n".join( 50 | # cd.make_table( 51 | # [ref], 52 | # [hyp], 53 | # "ref", 54 | # "hyp", 55 | # ) 56 | # ) 57 | # print(diff_line) 58 | 59 | cer = calc_cer([(hyp, transcript_reference)]) 60 | assert cer <= max_CER 61 | -------------------------------------------------------------------------------- /huggingface_wav2vec2_finetuning/__init__.py: -------------------------------------------------------------------------------- 1 | from misc_utils.beartyped_dataclass_patch import ( 2 | beartype_all_dataclasses_of_this_files_parent, 3 | ) 4 | 5 | beartype_all_dataclasses_of_this_files_parent(__file__) 6 | -------------------------------------------------------------------------------- /huggingface_wav2vec2_finetuning/ctc_data_collator.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | from dataclasses import dataclass 3 | from pprint import pprint 4 | from typing import Union, Optional, Any 5 | 6 | import torch 7 | from beartype import beartype 8 | from transformers import Wav2Vec2Processor, BatchFeature 9 | 10 | from misc_utils.beartypes import NeList 11 | 12 | 13 | @dataclass 14 | class DataCollatorCTCWithPadding: 15 | """ 16 | based on: https://github.com/huggingface/transformers/blob/b9bb417324c0d9013c505dc39c016ab9ca0e23c8/examples/research_projects/wav2vec2/run_common_voice.py#L143 17 | 18 | Data collator that will dynamically pad the inputs received. 19 | Args: 20 | processor (:class:`~transformers.Wav2Vec2Processor`) 21 | The processor used for proccessing the data. 22 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): 23 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 24 | among: 25 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 26 | sequence if provided). 27 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 28 | maximum acceptable input length for the model if that argument is not provided. 29 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 30 | different lengths). 31 | max_length (:obj:`int`, `optional`): 32 | Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). 33 | max_length_labels (:obj:`int`, `optional`): 34 | Maximum length of the ``labels`` returned list and optionally padding length (see above). 35 | pad_to_multiple_of (:obj:`int`, `optional`): 36 | If set will pad the sequence to a multiple of the provided value. 37 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 38 | 7.5 (Volta). 
39 | """ 40 | 41 | processor: Wav2Vec2Processor 42 | padding: Union[bool, str] = "longest" 43 | pad_to_multiple_of: Optional[int] = None 44 | pad_to_multiple_of_labels: Optional[int] = None 45 | 46 | some_batch: Optional[Any] = None 47 | 48 | @beartype 49 | def _process_pad(self, features: NeList) -> BatchFeature: 50 | # split inputs and labels since they have to be of different lenghts and need 51 | # different padding methods 52 | input_features = [ 53 | {"input_values": feature["input_values"]} for feature in features 54 | ] 55 | label_features = [{"input_ids": feature["labels"]} for feature in features] 56 | 57 | batch = self.processor.feature_extractor.pad( 58 | input_features, 59 | padding=self.padding, 60 | pad_to_multiple_of=self.pad_to_multiple_of, 61 | return_tensors="pt", 62 | ) 63 | # with self.processor.as_target_processor(): # tilo does not like this implicit processor switching 64 | labels_batch = self.processor.tokenizer.pad( # explicitly use tokenizier here 65 | label_features, 66 | padding=self.padding, 67 | pad_to_multiple_of=self.pad_to_multiple_of_labels, 68 | return_tensors="pt", 69 | ) 70 | 71 | # replace padding with -100 to ignore loss correctly 72 | labels = labels_batch["input_ids"].masked_fill( 73 | labels_batch.attention_mask.ne(1), -100 74 | ) 75 | batch["labels"] = labels 76 | return batch 77 | 78 | @beartype 79 | def __call__( 80 | self, features: NeList[dict[str, Union[list[int], torch.Tensor]]] 81 | ) -> BatchFeature: 82 | """ 83 | TODO(tilo): why did I want a try-except here? 84 | """ 85 | 86 | try: 87 | batch = self._process_pad(features) 88 | if self.some_batch is None: 89 | self.some_batch = batch 90 | 91 | except Exception as e: 92 | print(e) 93 | traceback.print_exc() 94 | print("WARNING: Collator failed!!!") 95 | pprint(f"input: {features}") 96 | batch = self.some_batch 97 | 98 | return batch 99 | -------------------------------------------------------------------------------- /huggingface_wav2vec2_finetuning/ctc_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union, Any 2 | 3 | import torch 4 | from torch import nn 5 | from torch.utils.data.dataloader import DataLoader 6 | from transformers import Trainer 7 | 8 | from huggingface_wav2vec2_finetuning.ctc_data_collator import DataCollatorCTCWithPadding 9 | from huggingface_wav2vec2_finetuning.hf_finetune_utils import ( 10 | ReduceLROnPlateauWithWarmup, 11 | ) 12 | 13 | 14 | def dummpy_step(**kwargs): 15 | pass 16 | 17 | 18 | class CTCTrainer(Trainer): 19 | """ 20 | in dryrun-mode does only one single forward-pass 21 | """ 22 | 23 | def __init__(self, **kwargs): 24 | super().__init__(**kwargs) 25 | 26 | def training_step( 27 | self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]] 28 | ) -> torch.Tensor: 29 | try: 30 | if ( 31 | isinstance(self.lr_scheduler, ReduceLROnPlateauWithWarmup) 32 | and self.state.global_step % self.args.eval_steps == 0 33 | ): 34 | self.lr_scheduler.step(metrics=self.state.best_metric) 35 | loss_d = super().training_step(model, inputs) 36 | except Exception as e: 37 | err = "CUDA out of memory" if "CUDA out of memory" in str(e) else e 38 | print(f"train-step failed with: {err}") 39 | model.zero_grad() 40 | loss_d = torch.tensor(torch.nan) 41 | return loss_d 42 | 43 | def get_train_dataloader(self) -> DataLoader: 44 | """ 45 | Returns the training :class:`~torch.utils.data.DataLoader`. 
46 | 47 | Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler (adapted 48 | to distributed training if necessary) otherwise. 49 | 50 | Subclass and override this method if you want to inject some custom behavior. 51 | """ 52 | 53 | assert isinstance( 54 | self.data_collator, DataCollatorCTCWithPadding 55 | ), f"{type(self.data_collator)=}" 56 | 57 | """ 58 | # https://pytorch.org/docs/stable/data.html 59 | 60 | # loading from an iterable-style dataset is roughly equivalent with: 61 | 62 | for data in iter(dataset): 63 | yield collate_fn(data) 64 | 65 | """ 66 | return DataLoader( 67 | self.train_dataset, 68 | batch_size=self.args.train_batch_size, 69 | collate_fn=self.data_collator, 70 | num_workers=self.args.dataloader_num_workers, 71 | pin_memory=self.args.dataloader_pin_memory, 72 | ) 73 | -------------------------------------------------------------------------------- /huggingface_wav2vec2_finetuning/ds_config_zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 11 | "optimizer": { 12 | "type": "AdamW", 13 | "params": { 14 | "lr": "auto", 15 | "betas": "auto", 16 | "eps": "auto", 17 | "weight_decay": "auto" 18 | } 19 | }, 20 | 21 | "scheduler": { 22 | "type": "WarmupLR", 23 | "params": { 24 | "warmup_min_lr": "auto", 25 | "warmup_max_lr": "auto", 26 | "warmup_num_steps": "auto" 27 | } 28 | }, 29 | 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "stage3_gather_16bit_weights_on_model_save": true 49 | }, 50 | 51 | "gradient_accumulation_steps": "auto", 52 | "gradient_clipping": "auto", 53 | "steps_per_print": 2000, 54 | "train_batch_size": "auto", 55 | "train_micro_batch_size_per_gpu": "auto", 56 | "wall_clock_breakdown": false 57 | } -------------------------------------------------------------------------------- /huggingface_wav2vec2_finetuning/requirements.txt: -------------------------------------------------------------------------------- 1 | nemo-toolkit[asr] 2 | datasets 3 | wandb 4 | bitsandbytes-cuda113 # maybe adapt to your coda-version -------------------------------------------------------------------------------- /huggingface_wav2vec2_finetuning/stream_ftdataset.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from dataclasses import dataclass 3 | from typing import Optional, Iterator, Union 4 | 5 | import math 6 | import torch 7 | from beartype import beartype 8 | 9 | from huggingface_wav2vec2_finetuning.data_loading_commons import IterableDatasetBase 10 | from misc_utils.buildable import Buildable 11 | from misc_utils.dataclass_utils import _UNDEFINED, UNDEFINED 12 | from misc_utils.utils import buffer_shuffle 13 | 14 | from ml4audio.audio_utils.audio_data_models import AudioTextData, ArrayText 15 | 16 | 17 | @beartype 18 | def calc_this_workers_start_end(start: int, end: int) -> tuple[int, 
int]: 19 | """ 20 | see: https://github.com/pytorch/pytorch/blob/f2582a59d0835323ebf143726ea79ba52e7cceff/torch/utils/data/dataset.py#L128 21 | 22 | TODO: actually this is a stupid idea! would be better if kth-worker would "not-skip" every k-th sample 23 | thereby no need to eat large portions of the entire input-iterable! which can be very expensive! 24 | """ 25 | worker_info = torch.utils.data.get_worker_info() 26 | if worker_info is None: # single-process data loading, return the full iterator 27 | iter_start = start 28 | iter_end = end 29 | else: # in a worker process 30 | per_worker = int(math.ceil((end - start) / float(worker_info.num_workers))) 31 | worker_id = worker_info.id 32 | iter_start = start + worker_id * per_worker 33 | iter_end = min(iter_start + per_worker, end) 34 | print(f"{worker_id=}: {iter_start=}, {iter_end=}") 35 | return iter_start, iter_end 36 | 37 | 38 | @dataclass 39 | class IterableSlicingDataset(IterableDatasetBase, Buildable): 40 | """ 41 | multiple data-loaders reading from corpus need to start-end at different "points" in the iterable 42 | """ 43 | 44 | array_texts: Union[_UNDEFINED, AudioTextData] = UNDEFINED 45 | limit: Optional[int] = None 46 | shufflebuffer_size: Optional[int] = None 47 | 48 | def __len__(self): 49 | return self.limit 50 | 51 | @beartype 52 | def _generate_array_texts(self) -> Iterator[ArrayText]: 53 | iter_start, iter_end = calc_this_workers_start_end(0, self.limit) 54 | g = ( 55 | (a, t) 56 | # for corpus in self.corpus 57 | for a, t in self.array_texts 58 | ) 59 | array_text_g = itertools.islice(g, iter_start, iter_end) 60 | if self.shufflebuffer_size is not None: 61 | g = buffer_shuffle(array_text_g, buffer_size=self.shufflebuffer_size) 62 | else: 63 | g = array_text_g 64 | return iter(g) 65 | -------------------------------------------------------------------------------- /ml4audio/__init__.py: -------------------------------------------------------------------------------- 1 | from misc_utils.beartyped_dataclass_patch import ( 2 | beartype_all_dataclasses_of_this_files_parent, 3 | ) 4 | 5 | beartype_all_dataclasses_of_this_files_parent(__file__) 6 | -------------------------------------------------------------------------------- /ml4audio/asr_inference/__init__.py: -------------------------------------------------------------------------------- 1 | from beartype.roar import BeartypeDecorHintPep585DeprecationWarning 2 | from warnings import filterwarnings 3 | 4 | filterwarnings("ignore", category=BeartypeDecorHintPep585DeprecationWarning) 5 | 6 | from misc_utils.beartyped_dataclass_patch import ( 7 | beartype_all_dataclasses_of_this_files_parent, 8 | ) 9 | 10 | beartype_all_dataclasses_of_this_files_parent(__file__) 11 | -------------------------------------------------------------------------------- /ml4audio/asr_inference/inference.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Annotated 4 | 5 | from beartype.vale import Is 6 | 7 | from misc_utils.beartypes import NpFloatDim1 8 | from ml4audio.audio_utils.audio_segmentation_utils import ( 9 | StartEndArray, 10 | StartEndText, 11 | is_non_overlapping, 12 | ) 13 | 14 | 15 | class SetupTearDown: 16 | @abstractmethod 17 | def __enter__(self): 18 | """ 19 | use to load the model into memory, prepare things 20 | """ 21 | raise NotImplementedError 22 | 23 | @abstractmethod 24 | def __exit__(self, exc_type=None, exc_val=None, exc_tb=None): 25 | """ 26 | 
use as tear-down, to free memory, unload model 27 | """ 28 | raise NotImplementedError 29 | 30 | 31 | StartEndTextsNonOverlap = Annotated[ 32 | list[StartEndText], 33 | Is[is_non_overlapping], 34 | ] 35 | 36 | 37 | @dataclass 38 | class AudioArray2SegmentedTranscripts(SetupTearDown): 39 | """ 40 | TODO: AA2ST = Audio Array 2 Segmented Transcripts 41 | """ 42 | 43 | @property 44 | @abstractmethod 45 | def name(self) -> str: 46 | raise NotImplementedError 47 | 48 | @property 49 | def sample_rate(self) -> int: 50 | # rename to input_sample_rate? 51 | return 16000 52 | 53 | @abstractmethod 54 | def audio_to_segmented_transcripts( 55 | self, audio_array: NpFloatDim1 56 | ) -> StartEndTextsNonOverlap: 57 | raise NotImplementedError 58 | -------------------------------------------------------------------------------- /ml4audio/asr_inference/logits_inferencer/__init__.py: -------------------------------------------------------------------------------- 1 | from beartype.roar import BeartypeDecorHintPep585DeprecationWarning 2 | from warnings import filterwarnings 3 | 4 | filterwarnings("ignore", category=BeartypeDecorHintPep585DeprecationWarning) 5 | -------------------------------------------------------------------------------- /ml4audio/asr_inference/logits_inferencer/asr_logits_inferencer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from dataclasses import dataclass 3 | from typing import ClassVar 4 | 5 | import torch 6 | from beartype import beartype 7 | from transformers import ( 8 | set_seed, 9 | ) 10 | 11 | from misc_utils.beartypes import ( 12 | NeList, 13 | NeStr, 14 | TorchTensor2D, 15 | NeNpFloatDim1, 16 | NeNpFloatDim1, 17 | ) 18 | from misc_utils.buildable import Buildable 19 | from ml4audio.text_processing.asr_text_cleaning import Casing, Letters 20 | 21 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | set_seed(42) 23 | 24 | 25 | def determine_casing(letter_vocab: Letters) -> Casing: 26 | more_than_half_is_upper = ( 27 | sum([1 if c.upper() == c else 0 for c in letter_vocab]) > len(letter_vocab) / 2 28 | ) 29 | casing = Casing.upper if more_than_half_is_upper else Casing.lower 30 | return casing 31 | 32 | 33 | @dataclass 34 | class ASRLogitsInferencer(Buildable): 35 | """ 36 | Asr Connectionis temporal classification (CTC) Logits Inference 37 | 38 | ────────────────────────────────────────────── 39 | ──────│─────│───────│─────│───────│────────│── 40 | ──────│─────│───────│─────│───────│────────│── 41 | ──────│──┌───┬────┬───┐──┌┐───────│┌┐──────│── 42 | ──────│──│┌─┐│┌┐┌┐│┌─┐│──││───────┌┘└┐─────│── 43 | ──────│──││─└┴┘││└┤││└┘──││┌──┬──┬┼┐┌┼──┐──│── 44 | ──────│──││─┌┐─││─│││┌┬──┤││┌┐│┌┐├┤│││──┤──│── 45 | ──────│──│└─┘│─││─│└─┘├──┤└┤└┘│└┘│││└┼──│──│── 46 | ──────│──└───┘─└┘─└───┘──└─┴──┴─┐├┘└─┴──┘──│── 47 | ──────│─────│───────│─────│───┌─┘││────────│── 48 | ──────│─────│───────│─────│───└──┘│────────│── 49 | ──────│─────│───────│─────│───────│────────│── 50 | ──────│─────│───────│─────│───────│────────│── 51 | 52 | """ 53 | 54 | asr_model_sample_rate: ClassVar[int] = 16000 55 | 56 | @property 57 | @beartype 58 | def name(self) -> NeStr: 59 | raise NotImplementedError 60 | 61 | @property 62 | @abstractmethod 63 | def vocab(self) -> NeList[str]: 64 | raise NotImplementedError 65 | 66 | @property 67 | @abstractmethod 68 | def letter_vocab(self) -> Letters: 69 | raise NotImplementedError 70 | 71 | @abstractmethod 72 | @beartype 73 | def calc_logits(self, audio: NeNpFloatDim1) -> 
TorchTensor2D: 74 | raise NotImplementedError 75 | -------------------------------------------------------------------------------- /ml4audio/asr_inference/logits_inferencer/nemo_asr_logits_inferencer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | 4 | import nemo.collections.asr as nemo_asr 5 | import torch 6 | from beartype import beartype 7 | from nemo.collections.asr.models import EncDecCTCModel 8 | 9 | from misc_utils.beartypes import ( 10 | NeNpFloatDim1, 11 | NeList, 12 | NeStr, 13 | TorchTensor2D, 14 | ) 15 | from misc_utils.dataclass_utils import UNDEFINED 16 | from misc_utils.utils import slugify_with_underscores 17 | from ml4audio.asr_inference.logits_inferencer.asr_logits_inferencer import ( 18 | ASRLogitsInferencer, 19 | ) 20 | from ml4audio.text_processing.asr_text_cleaning import Letters 21 | 22 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 23 | 24 | 25 | @dataclass 26 | class NemoASRLogitsInferencer(ASRLogitsInferencer): 27 | model_name_or_path: str = UNDEFINED 28 | _model: EncDecCTCModel = field(init=False) 29 | 30 | @property 31 | def name(self) -> NeStr: 32 | return slugify_with_underscores(self.model_name_or_path) 33 | 34 | def _build_self(self): 35 | # see: tools/ctc_segmentation/scripts/run_ctc_segmentation.py in nemo-code 36 | model_name = self.model_name_or_path 37 | if os.path.exists(model_name): 38 | raise NotImplementedError 39 | self._model = nemo_asr.models.EncDecCTCModel.restore_from(model_name) 40 | else: 41 | self._model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained( 42 | model_name, strict=False 43 | ) 44 | # else: 45 | # raise ValueError( 46 | # f"{model_name} not a valid model name or path. Provide path to the pre-trained checkpoint " 47 | # f"or choose from {nemo_asr.models.EncDecCTCModelBPE.list_available_models()}" 48 | # ) 49 | self._model.eval() 50 | self._model = self._model.to(DEVICE) 51 | return self 52 | 53 | @property 54 | @beartype 55 | def vocab(self) -> NeList[str]: 56 | vocabulary = self._model.cfg.decoder.vocabulary 57 | vocabulary = list(vocabulary) 58 | return vocabulary 59 | 60 | @property 61 | @beartype 62 | def letter_vocab(self) -> Letters: 63 | bad_letters = ["<", ">", "▁"] 64 | return [l for l in dict.fromkeys("".join(self.vocab)) if l not in bad_letters] 65 | 66 | @beartype 67 | def calc_logits(self, audio: NeNpFloatDim1) -> TorchTensor2D: 68 | device = next(self._model.parameters()).device 69 | audio_signal = torch.as_tensor(audio.reshape(1, -1), dtype=torch.float32) 70 | audio_signal_len = torch.as_tensor([audio.size], dtype=torch.int64) 71 | 72 | with torch.no_grad(): 73 | log_probs, _encoded_len, _greedy_predictions = self._model( 74 | input_signal=audio_signal.to(device), 75 | input_signal_length=audio_signal_len.to(device), 76 | ) 77 | log_probs = log_probs.cpu().squeeze() 78 | 79 | return log_probs 80 | 81 | 82 | # TODO: what about these? 
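# --- Editor's note: the lines below are an illustrative usage sketch added for
# clarity and are NOT part of the original nemo_asr_logits_inferencer.py. The
# pretrained-model name and the dummy audio are assumptions; the `.build()` call
# mirrors how other Buildable objects are instantiated elsewhere in this repo.
#
#     import numpy as np
#     inferencer = NemoASRLogitsInferencer(
#         model_name_or_path="stt_en_conformer_ctc_small"  # hypothetical NeMo model id
#     ).build()
#     audio = np.zeros(16_000, dtype=np.float32)  # 1 s of silence at the 16 kHz model rate
#     log_probs = inferencer.calc_logits(audio)   # (time_frames, vocab_size) log-probabilities
#
# The commented-out alternatives that the TODO above refers to follow below.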
83 | # 84 | # @beartype 85 | # def calc_logsoftmaxed_logits(self, audio: NpFloatDim1) -> NumpyFloat2DArray: 86 | # device = next(self._model.parameters()).device 87 | # audio_signal = torch.as_tensor(audio.reshape(1, -1), dtype=torch.float32) 88 | # audio_signal_len = torch.as_tensor([audio.size], dtype=torch.int64) 89 | # 90 | # with torch.no_grad(): 91 | # log_probs, encoded_len, greedy_predictions = self._model( 92 | # input_signal=audio_signal.to(device), 93 | # input_signal_length=audio_signal_len.to(device), 94 | # ) 95 | # log_probs = log_probs.cpu().squeeze() 96 | # 97 | # log_probs = self._post_process_for_ctc_alignment(log_probs) 98 | # assert log_probs.shape[1] == len( 99 | # self.vocab 100 | # ), f"{log_probs.shape=},{len(self.vocab)}" 101 | # return log_probs 102 | # 103 | # @beartype 104 | # def _post_process_for_ctc_alignment( 105 | # self, log_probs: NumpyFloat2DArray 106 | # ) -> NumpyFloat2DArray: 107 | # """ 108 | # see:nvidia-nemo-code: tools/ctc_segmentation/scripts/run_ctc_segmentation.py 109 | # """ 110 | # blank_col = log_probs[:, -1].reshape((log_probs.shape[0], 1)) 111 | # log_probs = np.concatenate((blank_col, log_probs[:, :-1]), axis=1) 112 | # return log_probs 113 | -------------------------------------------------------------------------------- /ml4audio/asr_inference/openai_whisper_inferencer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field, asdict 3 | from typing import Any, Optional 4 | 5 | import whisper as whisper_module 6 | from beartype import beartype 7 | from whisper import Whisper, DecodingOptions 8 | 9 | from misc_utils.beartypes import NeNpFloatDim1, NpFloatDim1 10 | from misc_utils.dataclass_utils import UNDEFINED, FillUndefined 11 | from misc_utils.prefix_suffix import PrefixSuffix, BASE_PATHES 12 | from ml4audio.asr_inference.inference import ( 13 | StartEndTextsNonOverlap, 14 | ) 15 | from ml4audio.asr_inference.whisper_inference import ( 16 | WhisperInferencer, 17 | fix_whisper_segments, 18 | WhisperArgs, 19 | ) 20 | from whisper.utils import exact_div 21 | 22 | 23 | @dataclass(frozen=True) 24 | class OpenAiWhisperArgs(WhisperArgs, DecodingOptions): 25 | """ 26 | for defaults see transcribe-method 27 | """ 28 | 29 | compression_ratio_threshold: Optional[float] = 2.4 30 | logprob_threshold: Optional[float] = -1.0 31 | no_speech_threshold: Optional[float] = 0.6 32 | condition_on_previous_text: bool = True 33 | initial_prompt: Optional[str] = None 34 | word_timestamps: bool = False 35 | prepend_punctuations: str = "\"'“¿([{-" 36 | append_punctuations: str = "\"'.。,,!!??::”)]}、" 37 | 38 | 39 | @dataclass 40 | class OpenAIWhisperASRSegmentInferencer(WhisperInferencer): 41 | """ 42 | https://github.com/saharmor/whisper-playground 43 | """ 44 | 45 | model_name: str = "base" 46 | whisper_args: Optional[OpenAiWhisperArgs] = None 47 | _model: Whisper = field(init=False, repr=False) 48 | base_dir: PrefixSuffix = field( 49 | default_factory=lambda: PrefixSuffix("cache_root", "MODELS/WHISPER_MODELS"), 50 | init=False, 51 | ) 52 | 53 | def __post_init__(self): 54 | if self.model_name.startswith("openai/whisper-"): 55 | self.model_name = self.model_name.replace("openai/whisper-", "") 56 | 57 | @property 58 | def name(self) -> str: 59 | return f"whisper-{self.model_name}" 60 | 61 | @property 62 | def sample_rate(self) -> int: 63 | return whisper_module.audio.SAMPLE_RATE 64 | 65 | @property 66 | def _is_data_valid(self) -> bool: 67 | return 
os.path.isfile(self._checkpoint_file) 68 | 69 | @property 70 | def _checkpoint_file(self) -> str: 71 | """ 72 | see: whisper/__init__.py _download method 73 | """ 74 | return f"{self.data_dir}/{os.path.basename(whisper_module._MODELS[self.model_name])}" 75 | 76 | def _build_data(self) -> Any: 77 | checkpoint_file = whisper_module._download( 78 | whisper_module._MODELS[self.model_name], self.data_dir, in_memory=False 79 | ) 80 | assert checkpoint_file == self._checkpoint_file 81 | 82 | def __enter__(self): 83 | self._model = whisper_module.load_model(self._checkpoint_file) 84 | 85 | def __exit__(self, exc_type=None, exc_val=None, exc_tb=None): 86 | del self._model 87 | 88 | @beartype 89 | def predict_transcribed_with_whisper_args( 90 | self, audio_array: NpFloatDim1, whisper_args: OpenAiWhisperArgs 91 | ) -> StartEndTextsNonOverlap: 92 | from whisper import audio 93 | 94 | if hasattr(whisper_args, "chunk_length"): 95 | audio.CHUNK_LENGTH = whisper_args.chunk_length 96 | audio.N_SAMPLES = audio.CHUNK_LENGTH * audio.SAMPLE_RATE 97 | audio.N_FRAMES = exact_div(audio.N_SAMPLES, audio.HOP_LENGTH) 98 | 99 | audio_dur = float(len(audio_array) / self.sample_rate) 100 | resp = self._model.transcribe(audio=audio_array, **asdict(whisper_args)) 101 | 102 | # resp["text"].strip(" ") # somehow this sometimes repeats the transcribt twice 103 | whisper_segments = resp["segments"] 104 | if len(whisper_segments) > 0: 105 | raw_whisper_segments = [ 106 | (seg["start"], seg["end"], seg["text"]) for seg in whisper_segments 107 | ] 108 | start_end_text = fix_whisper_segments( 109 | raw_whisper_segments, 110 | audio_dur, 111 | ) 112 | else: 113 | start_end_text = [] 114 | return start_end_text 115 | -------------------------------------------------------------------------------- /ml4audio/asr_inference/pytorch_to_onnx_for_wav2vec.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | 3 | from beartype import beartype 4 | from beartype.vale import Is 5 | from transformers import Wav2Vec2ForCTC 6 | import torch 7 | import argparse 8 | 9 | 10 | @beartype 11 | def convert_to_onnx(model_id_or_path: str, onnx_model_name): 12 | # based on : https://github.com/ccoreilly/wav2vec2-service/blob/master/convert_torch_to_onnx.py 13 | print(f"Converting {model_id_or_path} to onnx") 14 | # using: "torch_dtype=torch.float16" leads to "weight_norm_kernel" not implemented for 'Half' 15 | model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path) 16 | audio_len = 250000 17 | 18 | x = torch.randn(1, audio_len, requires_grad=True) 19 | 20 | torch.onnx.export( 21 | model, # model being run 22 | x, # model input (or a tuple for multiple inputs) 23 | onnx_model_name, # where to save the model (can be a file or file-like object) 24 | export_params=True, # store the trained parameter weights inside the model file 25 | opset_version=11, # the ONNX version to export the model to 26 | do_constant_folding=True, # whether to execute constant folding for optimization 27 | input_names=["input"], # the model's input names 28 | output_names=["output"], # the model's output names 29 | dynamic_axes={ 30 | "input": {1: "audio_len"}, # variable length axes 31 | "output": {1: "audio_len"}, 32 | }, 33 | ) 34 | 35 | 36 | # WeightTypeName = Annotated[str, Is[lambda s: s in ONNX_QUANT_WEIGHT_TYPES.keys()]] 37 | WeightTypeName = Annotated[ 38 | str, Is[lambda s: s in ["QUInt8", "QInt8"]] 39 | ] # TODO: I was too lazy to rebuild docker-image with onnxruntime! 
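# --- Editor's note: the lines below are an illustrative sketch added for clarity
# and are NOT part of the original pytorch_to_onnx_for_wav2vec.py. They show how a
# model exported by convert_to_onnx() above could be loaded with onnxruntime; the
# file name and dummy input length are assumptions, while the "input"/"output"
# tensor names come from the export call above.
#
#     import numpy as np
#     import onnxruntime as ort
#
#     session = ort.InferenceSession("wav2vec2-large-xlsr-53-english.onnx")
#     dummy_audio = np.random.randn(1, 16000).astype(np.float32)  # "audio_len" axis is dynamic
#     (logits,) = session.run(["output"], {"input": dummy_audio})  # per-frame CTC logits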
40 | 41 | 42 | @beartype 43 | def quantize_onnx_model( 44 | model_id_or_path: str, 45 | onnx_model_path: str, 46 | quantized_model_path: str, 47 | weight_type_name="QUInt8", 48 | ): 49 | """ 50 | TODO: 51 | use_external_data_format create extra file containing weights, this files absolute path on file system seems to be hard-coded in the onnx-file! 52 | so one cannot really copy it! 53 | """ 54 | 55 | # see: https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1150608315 56 | model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path) 57 | names = [name for name, _ in model.named_children()] 58 | 59 | prefix = ["MatMul", "Add", "Relu"] 60 | linear_names = [v for v in names if v.split("_")[0] in prefix] 61 | 62 | print("Starting quantization...") 63 | 64 | from onnxruntime.quantization import quantize_dynamic, QuantType 65 | 66 | ONNX_QUANT_WEIGHT_TYPES = { 67 | "QUInt8": QuantType.QUInt8, 68 | "QInt8": QuantType.QInt8, 69 | } 70 | 71 | quantize_dynamic( 72 | onnx_model_path, 73 | quantized_model_path, 74 | weight_type=ONNX_QUANT_WEIGHT_TYPES[ 75 | weight_type_name 76 | ], # better stay with default: QInt8 77 | use_external_data_format=True, # to support big models (>2GB) 78 | nodes_to_quantize=linear_names, 79 | extra_options={"MatMulConstBOnly": True}, 80 | ) 81 | 82 | print(f"Quantized model saved to: {quantized_model_path}") 83 | 84 | 85 | if __name__ == "__main__": 86 | """ 87 | # seems to work! 88 | python ml4audio/asr_inference/pytorch_to_onnx_for_wav2vec.py --model jonatasgrosman/wav2vec2-large-xlsr-53-english 89 | 90 | # big model 91 | python ml4audio/asr_inference/pytorch_to_onnx_for_wav2vec.py --model jonatasgrosman/wav2vec2-xls-r-1b-english 92 | # leads to 93 | ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB: 3850629924 94 | 95 | """ 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument( 98 | "--model", 99 | type=str, 100 | default="ccoreilly/wav2vec2-large-100k-voxpopuli-catala", 101 | help="Model HuggingFace ID or path that will converted to ONNX", 102 | ) 103 | parser.add_argument( 104 | "--quantize", 105 | action="store_true", 106 | help="Whether to use also quantize the model or not", 107 | ) 108 | args = parser.parse_args() 109 | 110 | model_id_or_path = args.model 111 | onnx_model_name = model_id_or_path.split("/")[-1] + ".onnx" 112 | convert_to_onnx(model_id_or_path, onnx_model_name) 113 | if args.quantize: 114 | quantized_model_path = model_id_or_path.split("/")[-1] + ".quant.onnx" 115 | onnx_model_name = quantize_onnx_model( 116 | model_id_or_path, onnx_model_name, quantized_model_path 117 | ) 118 | 119 | import onnx 120 | 121 | onnx_model = onnx.load(onnx_model_name) 122 | onnx.checker.check_model(onnx_model) 123 | -------------------------------------------------------------------------------- /ml4audio/asr_inference/transcript_gluer.py: -------------------------------------------------------------------------------- 1 | import difflib 2 | import os 3 | from dataclasses import field, dataclass 4 | from typing import Optional, Union 5 | 6 | import numpy as np 7 | from beartype import beartype 8 | 9 | from misc_utils.buildable import Buildable 10 | from misc_utils.utils import just_try 11 | from ml4audio.asr_inference.transcript_glueing import ( 12 | calc_new_suffix, 13 | NO_NEW_SUFFIX, 14 | _NO_NEW_SUFFIX, 15 | ) 16 | from ml4audio.audio_utils.aligned_transcript import ( 17 | TimestampedLetters, 18 | ) 19 | 20 | DEBUG = os.environ.get("DEBUG", "False").lower() != "false" 21 | 22 | 23 | @dataclass 24 | class 
ASRStreamInferenceOutput: 25 | id: str 26 | aligned_transcript: TimestampedLetters # TODO: rename 27 | end_of_message: bool = False 28 | 29 | 30 | @dataclass 31 | class TranscriptGluer(Buildable): 32 | """ 33 | ───▄▄▄ 34 | ─▄▀░▄░▀▄ 35 | ─█░█▄▀░█ 36 | ─█░▀▄▄▀█▄█▄▀ 37 | ▄▄█▄▄▄▄███▀ 38 | 39 | """ 40 | 41 | _prefix: Optional[TimestampedLetters] = field(init=False, repr=False, default=None) 42 | seqmatcher: Optional[difflib.SequenceMatcher] = field( 43 | init=False, repr=False, default=None 44 | ) 45 | 46 | def __enter__(self): 47 | return self.build() 48 | 49 | def __exit__(self, exc_type, exc_val, exc_tb): 50 | pass 51 | 52 | def reset(self) -> None: 53 | self._prefix: Optional[TimestampedLetters] = None 54 | 55 | def _build_self(self): 56 | self.reset() 57 | self.seqmatcher = difflib.SequenceMatcher() 58 | 59 | @beartype 60 | def calc_transcript_suffix( 61 | self, inp: TimestampedLetters 62 | ) -> Union[TimestampedLetters, _NO_NEW_SUFFIX]: 63 | 64 | if self._prefix is None: 65 | self._prefix, new_suffix = inp, inp 66 | else: 67 | self._prefix, new_suffix = self._calc_glued_and_suffix(self._prefix, inp) 68 | 69 | return new_suffix 70 | 71 | @beartype 72 | def _calc_glued_and_suffix( 73 | self, prefix: TimestampedLetters, inp: TimestampedLetters 74 | ) -> tuple[TimestampedLetters, Union[TimestampedLetters, _NO_NEW_SUFFIX]]: 75 | new_suffix = just_try( 76 | lambda: calc_new_suffix(left=prefix, right=inp, sm=self.seqmatcher), 77 | default=NO_NEW_SUFFIX, 78 | # a failed glue does not add anything! In the hope that overlap is big enough so that it can be recovered by next glue! 79 | verbose=DEBUG, 80 | print_stacktrace=True, 81 | reraise=False, 82 | ) 83 | if new_suffix is not NO_NEW_SUFFIX: 84 | glued_trimmed = self._glue_and_trim(prefix, new_suffix) 85 | else: 86 | glued_trimmed = prefix 87 | return glued_trimmed, new_suffix 88 | 89 | def _glue_and_trim(self, prefix, new_suffix): 90 | KEEP_DURATION = 100 # was not working with 40 91 | prefix_to_keep = prefix.slice( 92 | np.argwhere(prefix.timestamps < new_suffix.timestamps[0]) 93 | ) 94 | glued = TimestampedLetters( 95 | prefix_to_keep.letters + new_suffix.letters, 96 | np.concatenate([prefix_to_keep.timestamps, new_suffix.timestamps]), 97 | ) 98 | glued_trimmed = glued.slice( 99 | np.argwhere(glued.timestamps > glued.timestamps[-1] - KEEP_DURATION) 100 | ) 101 | return glued_trimmed 102 | -------------------------------------------------------------------------------- /ml4audio/asr_inference/whisper_inference.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Annotated, Optional, Union 3 | 4 | from beartype import beartype 5 | from beartype.vale import Is 6 | 7 | from misc_utils.beartypes import NpFloatDim1, NeList 8 | from misc_utils.buildable_data import BuildableData 9 | from ml4audio.asr_inference.inference import ( 10 | AudioArray2SegmentedTranscripts, 11 | StartEndTextsNonOverlap, 12 | ) 13 | from ml4audio.audio_utils.audio_segmentation_utils import ( 14 | StartEnd, 15 | fix_segments_to_non_overlapping, 16 | ) 17 | 18 | MINIMAL_SEGMENT_LENGHTS = 0.08 # TODO: this is super arbitray! 19 | 20 | 21 | @beartype 22 | def fix_start_end(start_end: tuple[float, float], audio_dur: float) -> StartEnd: 23 | """ 24 | TODO 25 | 13.06. I was changing audio_dur to "Optional[float]=None" -> why? 
26 | somehow I had an issue fix unfixable whisper-segments, multiple in a row being empty or something 27 | """ 28 | start, end = start_end 29 | if start < 0: 30 | print(f"WTF! whisper gave {start=}") 31 | start = 0.0 32 | 33 | # if end > audio_dur: # TODO: trying to correct "after-audio-hallucinations" like this is not working! 34 | # print(f"WTF! whisper gave {end=} that is after {audio_dur=} -> {audio_dur=}") 35 | # # if end-audio_dur>10.0: 36 | # # raise AssertionError(f"thats too much! cannot fix it!") 37 | # end = audio_dur 38 | 39 | if end - start < MINIMAL_SEGMENT_LENGHTS: 40 | print(f"WTF! whisper gave {(start,end)=}") 41 | start = end - MINIMAL_SEGMENT_LENGHTS 42 | # end = min(audio_dur, start + 0.04) 43 | 44 | return (start, end) 45 | 46 | 47 | @beartype 48 | def fix_whisper_segments( 49 | whisper_segments: NeList[tuple[float, float, str]], audio_dur: float 50 | ) -> StartEndTextsNonOverlap: 51 | 52 | start_end = [ 53 | fix_start_end((s, e), audio_dur) 54 | for s, e, text in whisper_segments 55 | if len(text) > 0 56 | # if s < audio_dur # TODO: one could filter for potentially hallucinated like this 57 | ] 58 | # start_ends_merged=[] 59 | # for s,e,t in zip(start_end,whisper_segments): 60 | 61 | start_end = fix_segments_to_non_overlapping(start_end) 62 | return [ 63 | (start, end, text) 64 | for (_, _, text), (start, end) in zip(whisper_segments, start_end) 65 | ] 66 | 67 | 68 | @dataclass(frozen=True) 69 | class WhisperArgs: 70 | task: Annotated[str, Is[lambda s: s in WHISPER_TASKS]] 71 | language: str = "de" 72 | temperature: Optional[Union[float, tuple[float, ...], list[float]]] = ( 73 | 0.0, 74 | 0.2, 75 | 0.4, 76 | 0.6, 77 | 0.8, 78 | 1.0, 79 | ) # this is default in whisper code 80 | # don't mess with the temperatures! they are needed for fallback if beam-search fails! 
81 | beam_size: Optional[int] = None # default=5 see whisper code 82 | 83 | 84 | WHISPER_TASKS = {"transcribe", "translate"} 85 | 86 | 87 | @dataclass 88 | class WhisperInferencer(BuildableData, AudioArray2SegmentedTranscripts): 89 | whisper_args: Optional[WhisperArgs] = None 90 | 91 | @beartype 92 | def audio_to_segmented_transcripts( 93 | self, audio_array: NpFloatDim1 94 | ) -> StartEndTextsNonOverlap: 95 | return self.predict_transcribed_with_whisper_args( 96 | audio_array, self.whisper_args 97 | ) 98 | 99 | @beartype 100 | def predict_transcribed_with_whisper_args( 101 | self, audio_array: NpFloatDim1, whisper_args: WhisperArgs 102 | ) -> StartEndTextsNonOverlap: 103 | raise NotImplementedError 104 | -------------------------------------------------------------------------------- /ml4audio/audio_data/__init__.py: -------------------------------------------------------------------------------- 1 | from misc_utils.beartyped_dataclass_patch import ( 2 | beartype_all_dataclasses_of_this_files_parent, 3 | ) 4 | 5 | beartype_all_dataclasses_of_this_files_parent(__file__) 6 | -------------------------------------------------------------------------------- /ml4audio/audio_data/hf_speech_iterable_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tarfile 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import Tuple 6 | 7 | from beartype import beartype 8 | 9 | from data_io.readwrite_files import ( 10 | read_lines, 11 | ) 12 | from misc_utils.cached_data import CachedData 13 | from misc_utils.dataclass_utils import UNDEFINED 14 | from misc_utils.prefix_suffix import BASE_PATHES, PrefixSuffix 15 | from ml4audio.audio_data.targz_asr_dataset import ( 16 | TarGzASRCorpus, 17 | TarGzTranscripts, 18 | TarGzArrayTextWithSize, 19 | ) 20 | 21 | HF_DATASETS = "huggingface_cache/datasets" 22 | 23 | 24 | from nemo.utils import logging 25 | 26 | logging.disabled = True 27 | logging._logger = None # TODO: WTF!! nemo is logging warnings at error level!! 
28 | 29 | 30 | @dataclass 31 | class HFTarGzTranscripts(TarGzTranscripts): 32 | """ 33 | TODO: rename to CommonVoiceTarGzTranscripts 34 | """ 35 | 36 | def contains_transcript(self, member: tarfile.TarInfo) -> bool: 37 | file_name = Path(member.name).stem 38 | return file_name in self.split_names 39 | 40 | @beartype 41 | def build_id_transcripts( 42 | self, split_name: str, transcript_files: list[str] 43 | ) -> list[Tuple[str, str]]: 44 | tsv_file = next( 45 | filter(lambda s: s.endswith(f"{split_name}.tsv"), transcript_files) 46 | ) 47 | lines_g = read_lines(tsv_file) 48 | header = next(lines_g).split("\t") 49 | data = [{k: v for k, v in zip(header, l.split("\t"))} for l in lines_g] 50 | return [(d["path"], d["sentence"]) for d in data] 51 | 52 | def build_transcript_file_name(self, member_name: str) -> str: 53 | s = member_name.split("huggingface/datasets/downloads/extracted")[-1] 54 | return s.replace("/", "__") 55 | 56 | 57 | @dataclass 58 | class HFIterableDataset(TarGzASRCorpus): 59 | def is_audiofile(self, member_name: str) -> bool: 60 | return member_name.endswith(".mp3") 61 | 62 | def audiofile_to_id(self, member_name: str) -> str: 63 | return Path(member_name).name 64 | 65 | 66 | if __name__ == "__main__": 67 | 68 | base_path = os.environ["BASE_PATH"] 69 | cache_root = f"{base_path}/data/cache" 70 | BASE_PATHES["base_path"] = base_path 71 | BASE_PATHES["cache_root"] = cache_root 72 | BASE_PATHES["raw_data"] = PrefixSuffix("cache_root", "RAW_DATA") 73 | 74 | corpus = TarGzArrayTextWithSize( 75 | corpus=HFIterableDataset( 76 | targztranscripts=HFTarGzTranscripts( 77 | targz_file=str( 78 | PrefixSuffix( 79 | "base_path", 80 | "/data/ASR_DATA/COMMON_VOICE/cv-corpus-10.0-2022-07-04-es.tar.gz", 81 | ) 82 | ), 83 | ), 84 | split="dev", 85 | ), 86 | sample_rate=16000 87 | # limit=10 88 | ).build() 89 | -------------------------------------------------------------------------------- /ml4audio/audio_data/mls_corpora.py: -------------------------------------------------------------------------------- 1 | import tarfile 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | 5 | from beartype import beartype 6 | 7 | from data_io.readwrite_files import read_lines 8 | from misc_utils.beartypes import NeStr 9 | from ml4audio.audio_data.targz_asr_dataset import TarGzASRCorpus, TarGzTranscripts 10 | 11 | 12 | @dataclass 13 | class MLSTarGzTranscripts(TarGzTranscripts): 14 | def contains_transcript(self, member: tarfile.TarInfo) -> bool: 15 | return member.name.endswith("transcripts.txt") 16 | 17 | @beartype 18 | def build_id_transcripts( 19 | self, split_name: str, transcript_files: list[str] 20 | ) -> list[tuple[str, NeStr]]: # NeStr too strict? 
21 | t_file = next( 22 | filter( 23 | lambda s: s.endswith(f"{split_name}/transcripts.txt"), transcript_files 24 | ) 25 | ) 26 | 27 | @beartype 28 | def parse_line(l: str) -> tuple[str, NeStr]: 29 | eid, transcript = l.split("\t") 30 | return eid, transcript 31 | 32 | return [parse_line(l) for l in read_lines(t_file)] 33 | 34 | def build_transcript_file_name(self, member_name: str) -> str: 35 | return member_name 36 | 37 | 38 | @dataclass 39 | class MLSIterableDataset(TarGzASRCorpus): 40 | def audiofile_to_id(self, member_name: str) -> str: 41 | return Path(member_name).name.replace(".flac", "") 42 | 43 | def is_audiofile(self, member_name: str) -> bool: 44 | return member_name.endswith(".flac") 45 | -------------------------------------------------------------------------------- /ml4audio/audio_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from misc_utils.beartyped_dataclass_patch import ( 2 | beartype_all_dataclasses_of_this_files_parent, 3 | ) 4 | 5 | beartype_all_dataclasses_of_this_files_parent(__file__) 6 | -------------------------------------------------------------------------------- /ml4audio/audio_utils/aligned_transcript.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import numpy as np 4 | from beartype import beartype 5 | from numpy.typing import NDArray 6 | from misc_utils.beartypes import NeNpFloatDim1 7 | 8 | 9 | @dataclass 10 | class TimestampedLetters: 11 | letters: str 12 | timestamps: NeNpFloatDim1 13 | 14 | def __post_init__(self): 15 | self.validate_data() 16 | 17 | def validate_data(self): 18 | strictly_increasing = np.all(np.diff(self.timestamps) >= 0) 19 | assert ( 20 | strictly_increasing 21 | ), f"{self.timestamps=}\n{np.argwhere(np.diff(self.timestamps)<=0)}" 22 | assert len(self.letters) == len(self.timestamps) 23 | 24 | def __len__(self): 25 | return len(self.letters) 26 | 27 | @beartype 28 | def slice(self, those: NDArray[int]): 29 | those = those.squeeze(1) 30 | sliced = TimestampedLetters( 31 | "".join([self.letters[i] for i in those]), self.timestamps[those] 32 | ) 33 | return sliced 34 | -------------------------------------------------------------------------------- /ml4audio/audio_utils/convert_video_to_mp3.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import Iterable, Union 4 | 5 | from tqdm import tqdm 6 | 7 | from ml4audio.audio_utils.audio_io import extract_streams_from_video_file 8 | from misc_utils.cached_data import CachedData 9 | from misc_utils.dataclass_utils import _UNDEFINED, UNDEFINED 10 | from misc_utils.prefix_suffix import PrefixSuffix 11 | 12 | 13 | @dataclass 14 | class ExtractedMp3s(CachedData, Iterable[str]): 15 | name: Union[_UNDEFINED, str] = UNDEFINED 16 | path: Union[_UNDEFINED, str] = UNDEFINED 17 | audio_files: list[str] = field(init=False, default_factory=lambda: []) 18 | cache_base: PrefixSuffix = field( 19 | default_factory=lambda: PrefixSuffix("processed_data", "extracted_mp3s") 20 | ) 21 | 22 | def build_audio_cmd(self, af, k, vf): 23 | # cmd = f'ffmpeg -i "{vf}" -y -filter_complex "[0:a:{k}]channelsplit=channel_layout=stereo[left][right]" -map "[left]" -c:a libopus -ar 16000 -ac 1 {af}_left.opus.ogg -map "[right]" -c:a libopus -ar 16000 -ac 1 {af}_right.opus.ogg' 24 | cmd = f'ffmpeg -i "{vf}" -y -map 0:a:{k} -q:a 0 -ac 1 -ar 16000 "{af}.mp3"' 25 | return 
cmd 26 | 27 | def _build_cache(self): 28 | for p in tqdm(Path(self.path).rglob("*.mp4")): 29 | audio_files = extract_streams_from_video_file( 30 | str(p), 31 | audio_file_target_folder=self.prefix_cache_dir("mp3s"), 32 | build_cmd_fun=self.build_audio_cmd, 33 | ) 34 | self.audio_files.extend(audio_files) 35 | 36 | def __iter__(self): 37 | for f in self.audio_files: 38 | yield f"{f}.mp3" 39 | -------------------------------------------------------------------------------- /ml4audio/audio_utils/nemo_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from beartype import beartype 3 | 4 | from misc_utils.beartypes import NpFloatDim1 5 | from nemo.collections.asr.models import EncDecSpeakerLabelModel 6 | from nemo.utils import logging 7 | from nemo_vad.nemo_offline_vad import NemoOfflineVAD 8 | 9 | 10 | @beartype 11 | def load_EncDecSpeakerLabelModel(pretrained_model: str) -> EncDecSpeakerLabelModel: 12 | """ 13 | based on: https://github.com/NVIDIA/NeMo/blob/ddd87197e94ca23ae54e641dc7784e64c00a43d6/examples/speaker_tasks/recognition/speaker_reco_finetune.py#L63 14 | """ 15 | if pretrained_model.endswith(".nemo"): 16 | logging.info(f"Using local speaker model from {pretrained_model}") 17 | model = EncDecSpeakerLabelModel.restore_from(restore_path=pretrained_model) 18 | elif pretrained_model.endswith(".ckpt"): 19 | logging.info(f"Using local speaker model from checkpoint {pretrained_model}") 20 | model = EncDecSpeakerLabelModel.load_from_checkpoint( 21 | checkpoint_path=pretrained_model 22 | ) 23 | else: 24 | logging.info("Using pretrained speaker recognition model from NGC") 25 | model = EncDecSpeakerLabelModel.from_pretrained(model_name=pretrained_model) 26 | return model 27 | 28 | 29 | @beartype 30 | def nemo_offline_vad_to_cut_away_noise( 31 | vad: NemoOfflineVAD, array: NpFloatDim1, SR: int = 16_000 32 | ) -> NpFloatDim1: 33 | start_ends, probas = vad.predict(array) 34 | if len(start_ends) == 0: 35 | # assuming that VAD fugedup so fallback to no-vad 36 | noise_free_array = array 37 | else: 38 | noise_free_array = np.concatenate( 39 | [array[round(s * SR) : round(e * SR)] for s, e in start_ends], axis=0 40 | ) 41 | return noise_free_array 42 | -------------------------------------------------------------------------------- /ml4audio/audio_utils/pyaudio_streaming.py: -------------------------------------------------------------------------------- 1 | import wave 2 | from typing import Iterator 3 | 4 | CHUNK_SIZE = 1000 5 | 6 | 7 | def build_pyaudio_stream( 8 | rate=16000, num_seconds_to_record=4, chunk_len=1.0 9 | ) -> Iterator[bytes]: 10 | import pyaudio 11 | 12 | chunk_size = 2 * round(chunk_len * rate) # 16bit need 2 bytes 13 | pyaudio_object = pyaudio.PyAudio() 14 | stream: pyaudio.Stream = pyaudio_object.open( 15 | channels=1, 16 | format=pyaudio.paInt16, # TODO 17 | rate=rate, 18 | input=True, 19 | frames_per_buffer=chunk_size, 20 | ) 21 | 22 | try: 23 | for _ in range(int(num_seconds_to_record * rate / chunk_size)): 24 | yield stream.read(chunk_size) 25 | finally: 26 | stream.close() 27 | pyaudio_object.terminate() 28 | 29 | 30 | def pyaudio_play_stream_from_file(audio_file): 31 | import pyaudio 32 | 33 | with wave.open(audio_file, "rb") as wf: 34 | 35 | p = pyaudio.PyAudio() 36 | formatt = p.get_format_from_width(wf.getsampwidth()) 37 | stream = p.open( 38 | format=formatt, 39 | channels=wf.getnchannels(), 40 | rate=wf.getframerate(), 41 | output=True, 42 | ) 43 | 44 | data = wf.readframes(CHUNK_SIZE) 45 | 46 | # play 
stream (3) 47 | while len(data) > 0: 48 | stream.write(data) 49 | data = wf.readframes(CHUNK_SIZE) 50 | 51 | # stop stream (4) 52 | stream.stop_stream() 53 | stream.close() 54 | 55 | # close PyAudio (5) 56 | p.terminate() 57 | -------------------------------------------------------------------------------- /ml4audio/audio_utils/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | from misc_utils.prefix_suffix import BASE_PATHES, PrefixSuffix 5 | 6 | TEST_RESOURCES = "tests/resources" 7 | BASE_PATHES["test_resources"] = TEST_RESOURCES 8 | 9 | 10 | def get_test_cache_base(): 11 | cache_base = PrefixSuffix("test_resources", "cache") 12 | if ( 13 | os.path.isdir(str(cache_base)) 14 | and not os.environ.get("DONT_REMOVE_TEST_CACHE", "False") != "False" 15 | ): 16 | shutil.rmtree(str(cache_base)) 17 | os.makedirs(str(cache_base), exist_ok=True) 18 | return cache_base 19 | 20 | 21 | def get_test_vocab(): 22 | return f""" 23 | 24 | 25 | 26 | | 27 | E 28 | T 29 | A 30 | O 31 | N 32 | I 33 | H 34 | S 35 | R 36 | D 37 | L 38 | U 39 | M 40 | W 41 | C 42 | F 43 | G 44 | Y 45 | P 46 | B 47 | V 48 | K 49 | ' 50 | X 51 | J 52 | Q 53 | Z""".split( 54 | "\n" 55 | ) 56 | -------------------------------------------------------------------------------- /ml4audio/service_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from misc_utils.beartyped_dataclass_patch import ( 2 | beartype_all_dataclasses_of_this_files_parent, 3 | ) 4 | 5 | beartype_all_dataclasses_of_this_files_parent(__file__) 6 | -------------------------------------------------------------------------------- /ml4audio/service_utils/fastapi_utils.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from dataclasses import dataclass 3 | from tempfile import NamedTemporaryFile 4 | from typing import Union 5 | 6 | import numpy as np 7 | from beartype import beartype 8 | from fastapi import UploadFile, HTTPException 9 | from starlette.datastructures import UploadFile as starlette_UploadFile 10 | 11 | from misc_utils.beartypes import NpFloatDim1, NumpyFloat32_1D, Dataclass 12 | from misc_utils.buildable import Buildable 13 | from misc_utils.dataclass_utils import encode_dataclass, decode_dataclass 14 | 15 | 16 | _UploadFile = Union[UploadFile, starlette_UploadFile] 17 | 18 | 19 | @dataclass 20 | class DictPredictor: 21 | @abstractmethod 22 | def predict(self, data: dict) -> dict: 23 | raise NotImplementedError 24 | 25 | 26 | @dataclass 27 | class DataclassPredictor(Buildable): 28 | @abstractmethod 29 | def predict(self, data: Dataclass) -> Dataclass: 30 | raise NotImplementedError 31 | 32 | 33 | @dataclass 34 | class DataclassEncoderDecoderPredictorWrapper(Buildable, DictPredictor): 35 | predictor: DataclassPredictor 36 | 37 | def predict(self, data: dict) -> dict: 38 | return encode_dataclass(self.predictor.predict(decode_dataclass(data))) 39 | 40 | 41 | @beartype 42 | async def read_uploaded_audio_file( 43 | file: _UploadFile, SR: int = 16000 44 | ) -> NumpyFloat32_1D: 45 | # TODO: cannot typehint from fastapi import UploadFile cause it hands in UploadFile from starlette! 
46 | from ml4audio.audio_utils.audio_io import ffmpeg_torch_load 47 | 48 | if not file: 49 | raise HTTPException(status_code=400, detail="Audio bytes expected") 50 | 51 | def save_file(filename, data): 52 | with open(filename, "wb") as f: 53 | f.write(data) 54 | 55 | with NamedTemporaryFile(delete=False, suffix=".wav") as tmp_original: 56 | # data_bytes = file.file.read() # if in synchronous context otherwise just file 57 | data_bytes = await file.read() # if in Asynchronous context 58 | save_file(tmp_original.name, data_bytes) 59 | 60 | raw_audio = ffmpeg_torch_load(tmp_original.name, target_sample_rate=SR).numpy() 61 | audio = raw_audio.astype(np.float32) 62 | return audio 63 | 64 | 65 | def get_full_model_config(asr_inferencer): 66 | return encode_dataclass( 67 | asr_inferencer, 68 | skip_keys=[ 69 | "_id_", 70 | "_target_", 71 | "cache_base", 72 | "cache_dir", 73 | "prefix", 74 | "use_hash_suffix", 75 | ], 76 | ) 77 | -------------------------------------------------------------------------------- /ml4audio/text_processing/__init__.py: -------------------------------------------------------------------------------- 1 | from misc_utils.beartyped_dataclass_patch import ( 2 | beartype_all_dataclasses_of_this_files_parent, 3 | ) 4 | 5 | beartype_all_dataclasses_of_this_files_parent(__file__) 6 | -------------------------------------------------------------------------------- /ml4audio/text_processing/asr_metrics.py: -------------------------------------------------------------------------------- 1 | import jiwer as jiwer 2 | from beartype import beartype 3 | 4 | from misc_utils.beartypes import NeList, NeStr 5 | 6 | 7 | @beartype 8 | def character_error_rates(refs: NeList[str], hyps: NeList[str]) -> dict[str, float]: 9 | cho = jiwer.process_characters(refs, hyps) 10 | num_chars = sum([len(r) for r in cho.references]) 11 | return { 12 | "cer": cho.cer, 13 | "insr": cho.insertions / num_chars, 14 | "delr": cho.deletions / num_chars, 15 | "subr": cho.substitutions / num_chars, 16 | # "hit": cho.hits / num_chars, # who cares about a hit-rate? 
17 | } 18 | 19 | 20 | @beartype 21 | def word_error_rates(refs: NeList[str], hyps: NeList[str]) -> dict[str, float]: 22 | who = jiwer.process_words(refs, hyps) 23 | num_words = sum(len(r) for r in who.references) 24 | assert num_words == who.hits + who.deletions + who.substitutions 25 | return { 26 | "wer": who.wer, 27 | "insr": who.insertions / num_words, 28 | "delr": who.deletions / num_words, 29 | "subr": who.substitutions / num_words, 30 | # "hit": who.hits / num_words, 31 | } 32 | 33 | 34 | @beartype 35 | def calc_cer(refs: NeList[NeStr], hyps: NeList[str]) -> float: 36 | return character_error_rates(refs, hyps)["cer"] 37 | 38 | 39 | @beartype 40 | def calc_wer(refs: NeList[NeStr], hyps: NeList[str]) -> float: 41 | return word_error_rates(refs, hyps)["wer"] 42 | 43 | 44 | @beartype 45 | def micro_avg_asr_scores( 46 | refs_hyps: NeList[tuple[NeStr, str]] 47 | ) -> dict[str, dict[str, float]]: 48 | refs, hyps = [list(x) for x in zip(*refs_hyps)] 49 | return { 50 | "word": word_error_rates(refs, hyps), 51 | "char": character_error_rates(refs, hyps), 52 | } 53 | -------------------------------------------------------------------------------- /ml4audio/text_processing/character_mappings/__init__.py: -------------------------------------------------------------------------------- 1 | from misc_utils.beartyped_dataclass_patch import ( 2 | beartype_all_dataclasses_of_this_files_parent, 3 | ) 4 | 5 | beartype_all_dataclasses_of_this_files_parent(__file__) 6 | -------------------------------------------------------------------------------- /ml4audio/text_processing/character_mappings/cyrillic_character_maps.py: -------------------------------------------------------------------------------- 1 | # tilo very pragmatically just wrote down some easily-mappable letters, no need to be perfect, just used as backup 2 | RECOVER_CYRILLIC = { 3 | # "jo": "ё", 4 | "a": "а", 5 | "b": "б", 6 | "v": "в", 7 | "g": "г", 8 | "d": "д", 9 | "e": "е", 10 | # "sh": "ж", 11 | # "s": "з", 12 | "i": "и", 13 | "j": "й", 14 | "k": "к", 15 | "l": "л", 16 | "m": "м", 17 | "n": "н", 18 | "o": "о", 19 | "p": "п", 20 | "r": "р", 21 | "s": "с", 22 | "t": "т", 23 | "u": "у", 24 | "f": "ф", 25 | "h": "х", 26 | "z": "ц", 27 | } # "ъ", "ы", "ь", "э", 28 | 29 | NO_JO = { 30 | "ё": "е", 31 | "ë": "е", 32 | } 33 | 34 | if __name__ == "__main__": 35 | 36 | for k, v in NO_JO.items(): 37 | print(f"{k}: {k.encode('utf-8')} -> {v}: {v.encode('utf-8')}") 38 | 39 | for k, v in RECOVER_CYRILLIC.items(): 40 | print(f"{k}: {k.encode('utf-8')} -> {v}: {v.encode('utf-8')}") 41 | -------------------------------------------------------------------------------- /ml4audio/text_processing/character_mappings/latin_character_maps.py: -------------------------------------------------------------------------------- 1 | """ 2 | see: https://www.businessballs.com/glossaries-and-terminology/accents-and-diacritical-marks/ 3 | é - accent acute 4 | è - accent grave 5 | ê - circumflex 6 | ë - umlaut or diaerisis 7 | ç - cedilla 8 | ñ - tilde 9 | ø - streg 10 | ð - eth (capital form Ð) 11 | å - bolle 12 | æ - ligature 13 | œ - ligature 14 | ē - macron 15 | č - háček 16 | ŭ - crescent 17 | 18 | TODO: what about upper-cased letters!! 19 | """ 20 | 21 | # assuming that this backward accent is just typo 22 | import string 23 | 24 | remove_backward_accent = { 25 | "à": "a", 26 | "è": "e", 27 | "ì": "i", 28 | "ò": "o", 29 | "ù": "u", 30 | } 31 | # TODO!! 
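# --- Editor's note: the lines below are an illustrative sketch added for clarity
# and are NOT part of the original latin_character_maps.py. Because these dicts map
# single characters to single characters, they can be applied with Python's built-in
# str.translate; whether the repo's text_cleaning module does exactly this is an
# assumption, and the example string is made up.
#
#     table = str.maketrans(remove_backward_accent)
#     assert "très voilà".translate(table) == "tres voila"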
32 | map_to_a = {k: "a" for k in ["ã", "ǎ", "á", "ă", "â", "å", "ā", "á", "à", "ą"]} 33 | map_to_A = {k: "A" for k in ["Á", "Ǎ", "À", "Å", "Â", "Ā"]} 34 | 35 | # hats, circumflex 36 | remove_hats = { 37 | "â": "a", 38 | "ê": "e", 39 | "ô": "o", 40 | "î": "i", 41 | "û": "u", 42 | } 43 | 44 | remove_tilde = { 45 | "ã": "a", 46 | "ñ": "n", 47 | } 48 | remove_flat = { 49 | "ō": "o", 50 | "ē": "e", 51 | } 52 | 53 | 54 | remove_accent = { 55 | # accent acute 56 | "ń": "n", 57 | "é": "e", # wtf didn't have this! 58 | } 59 | 60 | remove_diaeresis = { 61 | "ä": "a", 62 | "ë": "e", 63 | "ï": "i", 64 | "ö": "o", 65 | "ü": "u", 66 | } 67 | map_ligature = { 68 | "æ": "a", 69 | "œ": "o", 70 | } 71 | 72 | remove_reverse_hat = { 73 | "č": "c", 74 | "ŭ": "u", 75 | } 76 | 77 | strange_stuff = { 78 | # circle, bolle 79 | "å": "a", 80 | "ø": "o", 81 | "ç": "c", 82 | "ß": "s", # TODO: one or two s? 83 | } 84 | 85 | all_kinds_of_apostrophes = "'’‘`´ʹʻʼʽʿˈ" # also map itself, -> identity mapping to overwrite potentially removing mappings that came before 86 | NORMALIZE_APOSTROPHES = {c: "'" for c in all_kinds_of_apostrophes} 87 | NORMALIZE_DASH = {"–": "-", "-": "-"} # utf8: b'\xe2\x80\x93' 88 | 89 | REMOVE_EVERYTHING = ( 90 | remove_backward_accent 91 | | remove_hats 92 | | remove_tilde 93 | | remove_flat 94 | | remove_accent 95 | | remove_diaeresis 96 | | map_ligature 97 | | remove_reverse_hat 98 | | strange_stuff 99 | ) 100 | 101 | not_apostrophes_what_to_call_them = "„“”" 102 | string_punctuation = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" # string.punctuation 103 | PUNCTUATION = string_punctuation + not_apostrophes_what_to_call_them 104 | REPLACE_ALL_PUNCT_WITH_SPACE = {key: " " for key in PUNCTUATION} 105 | -------------------------------------------------------------------------------- /ml4audio/text_processing/character_mappings/not_str_translatable_maps.py: -------------------------------------------------------------------------------- 1 | SAME_SAME_BUT_DIFFERENT = { 2 | # same-same but different, strangely they have len of 2, so cannot simply map them 3 | "ä": "ä", 4 | "ü": "ü", 5 | "ö": "ö", 6 | "Ä": "Ä", 7 | "Ü": "Ü", 8 | "Ö": "Ö", 9 | } 10 | -------------------------------------------------------------------------------- /ml4audio/text_processing/kenlm_arpa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | from abc import abstractmethod 5 | from dataclasses import dataclass, field 6 | from pathlib import Path 7 | from typing import Union, Optional, Annotated 8 | 9 | import sys 10 | from beartype import beartype 11 | from beartype.vale import Is 12 | 13 | from data_io.readwrite_files import read_lines 14 | from misc_utils.beartypes import NeStr 15 | from misc_utils.cached_data import CachedData 16 | 17 | # based on: https://github.com/mozilla/DeepSpeech/blob/master/data/lm/generate_lm.py 18 | from misc_utils.dataclass_utils import ( 19 | UNDEFINED, 20 | _UNDEFINED, 21 | ) 22 | from misc_utils.prefix_suffix import BASE_PATHES, PrefixSuffix 23 | from ml4audio.text_processing.word_based_text_corpus import WordBasedLMCorpus 24 | 25 | # TODO: move this to its own package? cause it depends on kenlm 26 | 27 | 28 | @dataclass 29 | class ArpaArgs: 30 | order: int = 3 31 | max_memory: str = "80%" 32 | prune: str = "0|8|9" 33 | kenlm_bin: str = "/opt/kenlm/bin" 34 | vocab_size: Optional[int] = None 35 | 36 | 37 | arpa_suffixes = [".arpa.gz", ".arpa", ".gz"] # TODO: WTF! who calls a arpa "lm.gz"? 
38 | ArpaFile = Annotated[ 39 | str, 40 | Is[lambda s: any(s.endswith(suffix) for suffix in arpa_suffixes)], 41 | ] 42 | 43 | 44 | class GotArpaFile: 45 | name: NeStr 46 | arpa_filepath: ArpaFile = field(init=False) 47 | 48 | 49 | @dataclass 50 | class AnArpaFile(GotArpaFile): 51 | arpa_filepath: ArpaFile = field(init=True) 52 | 53 | def __post_init__(self): 54 | self.name = Path(self.arpa_filepath).name 55 | 56 | 57 | @dataclass 58 | class ArpaBuilder(CachedData, GotArpaFile): 59 | arpa_args: Union[_UNDEFINED, ArpaArgs] = UNDEFINED 60 | corpus: Union[_UNDEFINED, WordBasedLMCorpus] = UNDEFINED 61 | cache_base: PrefixSuffix = field(default_factory=lambda: BASE_PATHES["lm_models"]) 62 | 63 | @property 64 | def name(self): 65 | return f"arpa-{self.corpus.name}" 66 | 67 | @property 68 | def arpa_filepath(self) -> str: 69 | return self.prefix_cache_dir("lm.arpa") 70 | 71 | def _build_cache(self): 72 | corpus_file, word_counts_file = ( 73 | self.corpus.corpus_filepath, 74 | self.corpus.word_counts_filepath, 75 | ) 76 | if word_counts_file is not None: 77 | vocab_str = "\n".join( 78 | l.split("\t")[0] 79 | for l in read_lines(word_counts_file, limit=self.arpa_args.vocab_size) 80 | ) 81 | else: 82 | vocab_str = None 83 | 84 | build_kenlm_arpa( 85 | self.arpa_args, 86 | str(self.cache_dir), 87 | self.arpa_filepath, 88 | corpus_file, 89 | vocab_str, 90 | ) 91 | assert os.path.isfile(self.arpa_filepath), f"could build {self.arpa_filepath=}" 92 | 93 | 94 | @beartype 95 | def build_kenlm_arpa( 96 | args: ArpaArgs, 97 | output_dir: str, 98 | arpa_file: str, 99 | text_file: str, 100 | vocab_str: Optional[str] = None, 101 | ): 102 | print("\nCreating ARPA file ...") 103 | os.makedirs(output_dir, exist_ok=True) 104 | subargs = [ 105 | os.path.join(args.kenlm_bin, "lmplz"), 106 | "--order", 107 | str(args.order), 108 | "--temp_prefix", 109 | output_dir, 110 | "--memory", 111 | args.max_memory, 112 | "--text", 113 | text_file, 114 | "--arpa", 115 | arpa_file, 116 | "--prune", 117 | *args.prune.split("|"), 118 | "--skip_symbols", 119 | "--discount_fallback", 120 | ] 121 | subprocess.check_call(subargs, stdout=sys.stdout, stderr=sys.stdout) 122 | 123 | if vocab_str is not None: 124 | # Filter LM using vocabulary of top-k words 125 | print("\nFiltering ARPA file using vocabulary of top-k words ...") 126 | arpa_file_unfiltered = f"{output_dir}/lm_unfiltered.arpa" 127 | shutil.copy(arpa_file, arpa_file_unfiltered) 128 | 129 | subprocess.run( 130 | [ 131 | os.path.join(args.kenlm_bin, "filter"), 132 | "single", 133 | f"model:{arpa_file_unfiltered}", 134 | arpa_file, 135 | ], 136 | input=vocab_str.encode("utf-8"), 137 | check=True, 138 | ) 139 | -------------------------------------------------------------------------------- /ml4audio/text_processing/pretty_diff.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from beartype import beartype 4 | 5 | from ml4audio.text_processing.smith_waterman_alignment import align_split 6 | 7 | 8 | @beartype 9 | def smithwaterman_aligned_icdiff( 10 | ref: str, 11 | hyp: str, 12 | split_len_a=70, 13 | ref_header: Optional[str] = "ref", 14 | hyp_header: Optional[str] = "hyp", 15 | ) -> str: 16 | import icdiff 17 | 18 | refs, hyps = align_split(ref, hyp, split_len_a=split_len_a, debug=False) 19 | cd = icdiff.ConsoleDiff(cols=2 * split_len_a + 20) 20 | 21 | diff_line = "\n".join( 22 | cd.make_table( 23 | refs, 24 | hyps, 25 | ref_header, 26 | hyp_header, 27 | ) 28 | ) 29 | return diff_line 30 | 31 | 32 | if 
__name__ == "__main__": 33 | ref = "NOT HAVING THE COURAGE OR THE INDUSTRY OF OUR NEIGHBOUR WHO WORKS LIKE A BUSY BEE IN THE WORLD OF MEN AND BOOKS SEARCHING WITH THE SWEAT OF HIS BROW FOR THE REAL BREAD OF LIFE WETTING THE OPEN PAGE BEFORE HIM WITH HIS TEARS PUSHING INTO THE WE HOURS OF THE NIGHT HIS QUEST ANIMATED BY THE FAIREST OF ALL LOVES THE LOVE OF TRUTH WE EASE OUR OWN INDOLENT CONSCIENCE BY CALLING HIM NAMES" 34 | hyp = "NOT HAVING THE COURAGE OR THE INDUSTRY OF OUR NEIGHBOUR WHO WORKS LIKE A BUSY BEE IN THE WORLD OF MEN AN BOOKS SEARCHING WITH THE SWEAT OF HIS BROW FOR THE REAL BREAD OF LIFE WET IN THE OPEN PAGE BAFORE HIM WITH HIS TEARS PUSHING INTO THE WEE HOURS OF THE NIGHT HIS QUEST AND BY THE FAIREST OF ALL LOVES THE LOVE OF TRUTH WE EASE OUR OWN INDOLENT CONSCIENCE BY CALLING HIM NAMES" 35 | 36 | print(smithwaterman_aligned_icdiff(ref, hyp, ref_header=None, hyp_header=None)) 37 | -------------------------------------------------------------------------------- /nemo_language_classification/__init__.py: -------------------------------------------------------------------------------- 1 | from misc_utils.beartyped_dataclass_patch import ( 2 | beartype_all_dataclasses_of_this_files_parent, 3 | ) 4 | 5 | beartype_all_dataclasses_of_this_files_parent(__file__) 6 | -------------------------------------------------------------------------------- /nemo_language_classification/benchmark_lang_clf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from beartype import beartype 4 | 5 | from data_io.readwrite_files import read_jsonl 6 | from misc_utils.beartypes import NeList 7 | from ml4audio.audio_utils.audio_io import read_audio_chunks_from_file 8 | from functools import partial 9 | from itertools import groupby 10 | from itertools import islice 11 | from typing import Dict, List, Tuple 12 | 13 | from nemo_language_classification.nemo_lang_clf import NemoLangClf 14 | 15 | from pprint import pprint 16 | 17 | from sklearn import metrics 18 | from tqdm import tqdm 19 | 20 | 21 | def get_data( 22 | val_manifest=f"{os.environ['BASE_PATH']}/data/lang_clf_data/validation_manifest.jsonl", 23 | ): 24 | data = list(read_jsonl(val_manifest)) 25 | lang2data = { 26 | k: list(islice(g, 100)) 27 | for k, g in groupby( 28 | sorted(data, key=lambda x: x["label"]), lambda x: x["label"] 29 | ) 30 | } 31 | lang2data = {"en": lang2data["en"], "de": lang2data["de"]} 32 | data = [d for g in lang2data.values() for d in g] 33 | return data 34 | 35 | 36 | def get_max(o: Dict[str, float]): 37 | label, value = max(o.items(), key=lambda x: x[1]) 38 | return label 39 | 40 | 41 | @beartype 42 | def benchmark_lang_clf( 43 | mdl: NemoLangClf, wavfile_label: NeList[tuple[str, str]] 44 | ) -> dict: 45 | sr = 16000 46 | chunk_dur = 8 47 | id_pred_targets = [ 48 | (f"{wav_file}-{k}", get_max(mdl.predict(chunk)), target) 49 | for wav_file, target in tqdm(wavfile_label) 50 | for k, chunk in enumerate( 51 | read_audio_chunks_from_file(wav_file, sr, chunk_duration=chunk_dur) 52 | ) 53 | if len(chunk) > (chunk_dur / 2) * sr 54 | ] 55 | eids, preds, targets = (list(x) for x in zip(*id_pred_targets)) 56 | clf_report = metrics.classification_report( 57 | y_true=targets, 58 | y_pred=preds, 59 | # labels=target_names, 60 | digits=3, 61 | output_dict=True, 62 | ) 63 | return clf_report 64 | 65 | 66 | if __name__ == "__main__": 67 | 68 | # val_manifest=f"{os.environ['BASE_PATH']}/data/lang_clf_data/train_manifest.jsonl" 69 | # 
model_file="{BASE_PATH}/.../end2end-asr/nemo_experiments/SpeakerNet/2021-07-22_12-01-52/checkpoints/SpeakerNet--val_loss=8.97-epoch=0-last.ckpt" 70 | # "{os.environ['BASE_PATH']}/results/TRAINING/LANG_CLF/debug/SpeakerNet/2021-07-23_10-14-04/checkpoints/SpeakerNet--val_loss=6.84-epoch=1-last.ckpt" 71 | 72 | data = get_data( 73 | val_manifest=f"{os.environ['BASE_PATH']}/data/AUDIO_DATA/lang_clf_data_7lanuages/test_manifest.jsonl" 74 | ) 75 | input_output = [ 76 | ( 77 | d["audio_filepath"].replace("data/huggingface/", "huggingface_cache/"), 78 | d["label"], 79 | ) 80 | for d in data 81 | ] 82 | pprint( 83 | benchmark_lang_clf( 84 | NemoLangClf( 85 | model_file=f"{os.environ['BASE_PATH']}/iais_code/ml4audio/nemo_experiments/titanet-finetune-lang-clf/2022-09-25_13-12-59/checkpoints/titanet-finetune-lang-clf.nemo" 86 | ).build(), 87 | wavfile_label=input_output, 88 | ) 89 | ) 90 | # pprint(benchmark_fun(Wav2vecPyctcLangClf())) 91 | -------------------------------------------------------------------------------- /nemo_language_classification/finetune_lang_clf.py: -------------------------------------------------------------------------------- 1 | # based on: https://github.com/NVIDIA/NeMo/blob/aff169747378bcbcec3fc224748242b36205413f/examples/speaker_tasks/recognition/speaker_reco_finetune.py 2 | import os 3 | 4 | import pytorch_lightning as pl 5 | from omegaconf import OmegaConf 6 | from pytorch_lightning import seed_everything 7 | 8 | from nemo.collections.asr.models import EncDecSpeakerLabelModel 9 | from nemo.core.config import hydra_runner 10 | from nemo.utils import logging 11 | from nemo.utils.exp_manager import exp_manager 12 | from nemo_language_classification.prepare_lang_clf_splits import create_subset_manifest 13 | 14 | seed_everything(42) 15 | 16 | 17 | @hydra_runner(config_path="conf", config_name="titanet-finetune.yaml") 18 | def main(cfg): 19 | labels = ["en", "de", "es", "ru", "pt", "fr", "it"] 20 | cfg.model.train_ds.manifest_filepath = f"{os.environ['BASE_PATH']}/data/AUDIO_DATA/lang_clf_data_7lanuages/train_manifest.jsonl" 21 | cfg.model.validation_ds.manifest_filepath = f"{os.environ['BASE_PATH']}/data/AUDIO_DATA/lang_clf_data_7lanuages/validation_manifest.jsonl" 22 | # cfg.model.test_ds.manifest_filepath = f"{os.environ['BASE_PATH']}/data/AUDIO_DATA/lang_clf_data_7lanuages/test_manifest.jsonl" 23 | cfg.model.decoder.num_classes = len(labels) 24 | 25 | logging.info(f"Hydra config: {OmegaConf.to_yaml(cfg)}") 26 | trainer = pl.Trainer(**cfg.trainer) 27 | exp_man_cfg = cfg.get("exp_manager", None) 28 | _ = exp_manager(trainer, exp_man_cfg) 29 | mdl = EncDecSpeakerLabelModel(cfg=cfg.model, trainer=trainer) 30 | mdl.maybe_init_from_pretrained_checkpoint(cfg) 31 | mdl.cfg.train_ds.labels = labels # TODO(tilo):WTF! 
had to manually stick labels in there, so that I have "vocabulary" at inference-time 32 | 33 | trainer.fit(mdl) 34 | 35 | # if ( 36 | # hasattr(cfg.model, "test_ds") 37 | # and cfg.model.test_ds.manifest_filepath is not None 38 | # ): 39 | # trainer = pl.Trainer(devices=1, accelerator=cfg.trainer.accelerator) 40 | # if mdl.prepare_test(trainer): 41 | # trainer.test(mdl) 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /nemo_language_classification/language_classification_data.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | # pylint: disable-all 3 | import json 4 | from typing import Optional 5 | from urllib.request import urlopen 6 | 7 | import sys 8 | 9 | from data_io.readwrite_files import write_jsonl 10 | from ml4audio.audio_utils.torchaudio_utils import torchaudio_info 11 | 12 | sys.path.append("") 13 | 14 | from itertools import islice 15 | from random import shuffle 16 | 17 | import os 18 | import shutil 19 | 20 | import datasets 21 | from tqdm import tqdm 22 | 23 | TARGET_SAMPLE_RATE = 16000 24 | 25 | 26 | def lang_clf_nemo_datum(d) -> Optional[dict]: 27 | file = d["path"] 28 | if os.path.isfile(file): 29 | 30 | # x, _sr = librosa.load(file, sr=TARGET_SAMPLE_RATE) 31 | # duration=librosa.get_duration(x, sr=_sr) 32 | num_frames, sample_rate, duration = torchaudio_info(file) 33 | return { 34 | "audio_filepath": file, 35 | "duration": duration, 36 | "label": d["locale"], 37 | "text": "_", # d["sentence"] 38 | "offset": 0.0, 39 | } 40 | else: 41 | return None 42 | 43 | 44 | # def copy_data(d): 45 | # shutil.copy(d["audio_filepath"], corpus_dir) 46 | 47 | 48 | def rm_mkdir(dirr): 49 | if os.path.isdir(dirr): 50 | shutil.rmtree(dirr) 51 | os.makedirs(dirr) 52 | 53 | 54 | assert ( 55 | "HF_DATASETS_CACHE" in os.environ 56 | ), f'do: export HF_DATASETS_CACHE="/path/to/another/directory"' 57 | assert "HF_HOME" in os.environ, f'do export HF_HOME="/somewhere/path/huggingface_cache"' 58 | 59 | # export HF_DATASETS_CACHE={os.environ['BASE_PATH']}/huggingface_cache/datasets 60 | # export HF_HOME={os.environ['BASE_PATH']}/huggingface_cache 61 | 62 | if __name__ == "__main__": 63 | cv_info_json = urlopen( 64 | "https://huggingface.co/datasets/common_voice/raw/main/dataset_infos.json" 65 | ).read() 66 | cv_info = json.loads(cv_info_json) 67 | cv_languages = cv_info.keys() 68 | print(f"{cv_languages=}") 69 | # assert False 70 | 71 | num_samples_per_lang = 10_000 72 | manifests_dir = f"{os.environ['BASE_PATH']}/data/AUDIO_DATA/lang_clf_data_7lanuages" 73 | rm_mkdir(manifests_dir) 74 | 75 | # it = iter(_LANGUAGES.keys()) 76 | languages_of_interest = ["en", "de", "es", "ru", "pt", "fr", "it"] 77 | 78 | for lang in languages_of_interest: 79 | 80 | for split_name in ["train", "validation", "test"]: 81 | try: 82 | ttm = num_samples_per_lang * 10 # so ttm == ten times more 83 | ds_ttm = datasets.load_dataset( 84 | "common_voice", 85 | lang, 86 | keep_in_memory=True, 87 | split=f"{split_name}[:{ttm}]", 88 | ) 89 | except: 90 | ds_ttm = datasets.load_dataset( 91 | "common_voice", lang, keep_in_memory=True, split=f"{split_name}" 92 | ) 93 | 94 | data_ttm = list( 95 | filter( 96 | lambda x: x is not None, (lang_clf_nemo_datum(d) for d in ds_ttm) 97 | ) 98 | ) 99 | shuffle( 100 | data_ttm 101 | ) # not sure whether common-voice data is already shuffled? I don't want only few speakers! 
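# keep only num_samples_per_lang of the ten-times-more ("ttm") items loaded above;
# shuffling before this cut is presumably meant to increase speaker diversity per language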
102 | data = list(islice(data_ttm, num_samples_per_lang)) 103 | # for d in data: 104 | # copy_data(d) 105 | 106 | write_jsonl( 107 | f"/{manifests_dir}/{split_name}_manifest.jsonl", 108 | tqdm(data, desc=f"{lang}: {split_name}"), 109 | mode="ab", 110 | ) 111 | -------------------------------------------------------------------------------- /nemo_language_classification/nemo_lang_clf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from beartype import beartype 4 | 5 | from misc_utils.beartypes import ( 6 | TorchTensorFloat2D, 7 | TorchTensorInt, 8 | NeDict, 9 | NumpyInt16Dim1, 10 | ) 11 | 12 | from dataclasses import dataclass 13 | from typing import Any 14 | 15 | import torch 16 | 17 | from misc_utils.buildable import Buildable 18 | import numpy as np 19 | from tqdm import tqdm 20 | 21 | from ml4audio.audio_utils.audio_io import MAX_16_BIT_PCM, read_audio_chunks_from_file 22 | from ml4audio.audio_utils.nemo_utils import load_EncDecSpeakerLabelModel 23 | 24 | device = "cuda" if torch.cuda.is_available() else "cpu" 25 | 26 | 27 | @dataclass 28 | class NemoLangClf(Buildable): 29 | model_file: str 30 | 31 | def _build_self(self) -> Any: 32 | self.model = load_EncDecSpeakerLabelModel(self.model_file) 33 | self.model.eval() 34 | self.model.to(device) 35 | 36 | @beartype 37 | @torch.no_grad() 38 | def predict(self, audio_array: NumpyInt16Dim1) -> NeDict[str, float]: 39 | labels = list(self.model.cfg.train_ds.labels) 40 | sig, sig_len = prepare_audio_signal(audio_array) 41 | logits, _ = self.model.forward( 42 | input_signal=sig.to(device), input_signal_length=sig_len.to(device) 43 | ) 44 | 45 | probs = torch.softmax(logits, dim=-1).squeeze() 46 | label2proba = {k: float(p.cpu().numpy()) for k, p in zip(labels, probs)} 47 | # class_idx = np.argmax(probs) 48 | # class_label = labels[class_idx] 49 | return label2proba 50 | 51 | 52 | @beartype 53 | def prepare_audio_signal( 54 | signal: NumpyInt16Dim1, 55 | ) -> tuple[TorchTensorFloat2D, TorchTensorInt]: 56 | signal = signal.squeeze() 57 | assert signal.dtype == np.int16 58 | signal = signal.astype(np.float32) / MAX_16_BIT_PCM 59 | size_tensor = torch.as_tensor([signal.size], dtype=torch.int64) 60 | return ( 61 | torch.as_tensor(signal, dtype=torch.float32).unsqueeze(0), 62 | size_tensor, 63 | ) 64 | 65 | 66 | if __name__ == "__main__": 67 | mdl = NemoLangClf( 68 | model_file=f"{os.environ['BASE_PATH']}/results/TRAINING/LANG_CLF/debug/SpeakerNet/2021-07-23_10-14-04/checkpoints/SpeakerNet--val_loss=6.84-epoch=1-last.ckpt" 69 | ) 70 | mdl.build() 71 | input_sample_rate = 16000 72 | frame_duration = 4.0 73 | 74 | wav_file = "tests/resources/tuda_2015-02-03-13-51-36_Realtek.wav" 75 | for chunk in tqdm( 76 | read_audio_chunks_from_file( 77 | wav_file, input_sample_rate, chunk_duration=frame_duration 78 | ) 79 | ): 80 | print(mdl.predict(chunk)) 81 | -------------------------------------------------------------------------------- /nemo_language_classification/prepare_lang_clf_splits.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import Counter 3 | from itertools import groupby 4 | from random import shuffle 5 | from typing import List 6 | 7 | from beartype import beartype 8 | 9 | from data_io.readwrite_files import write_jsonl, read_jsonl 10 | 11 | OTHER = "OTHER" 12 | 13 | 14 | def fix_path(d): 15 | original_filepath = d["audio_filepath"] 16 | 17 | if "huggingface/datasets" in original_filepath: 18 | _, audio_filepath_tail = 
original_filepath.split("huggingface/datasets") 19 | else: 20 | _, audio_filepath_tail = original_filepath.split("huggingface_cache/datasets") 21 | d["audio_filepath"] = f"{os.environ['HF_DATASETS_CACHE']}/{audio_filepath_tail}" 22 | return d 23 | 24 | 25 | @beartype 26 | def create_subset_manifest( 27 | base_manifest: str, 28 | only_these_labels: list[str], 29 | subset_manifest_file: str, 30 | num_per_label: int = 100, 31 | ) -> None: 32 | """ 33 | base_manifest contains all data with many labels 34 | eventuall creates OTHER-class by randomly selecting from class-labels not contained in labels 35 | """ 36 | data = list(read_jsonl(base_manifest)) 37 | gr = ( 38 | (k, list(g)) 39 | for k, g in groupby( 40 | sorted(data, key=lambda x: x["label"]), lambda x: x["label"] 41 | ) 42 | ) 43 | lang2data = {k: g[:num_per_label] for k, g in gr} 44 | data = [d for k in only_these_labels if k != OTHER for d in lang2data.pop(k)] 45 | 46 | if OTHER in only_these_labels: 47 | other_data = [d for g in lang2data.values() for d in g] 48 | shuffle(other_data) 49 | for d in other_data: 50 | d["label"] = OTHER 51 | data += other_data[: len(only_these_labels) * num_per_label] 52 | 53 | print(subset_manifest_file) 54 | print(Counter(d["label"] for d in data)) 55 | shuffle(data) 56 | write_jsonl(subset_manifest_file, map(fix_path, data)) 57 | -------------------------------------------------------------------------------- /nemo_language_classification/readme.md: -------------------------------------------------------------------------------- 1 | # language classification 2 | * setup 3 | ```shell 4 | pip install -r nemo_language_classification/requirements.txt 5 | ``` 6 | * train 7 | ```shell 8 | export BASE_PATH= 9 | export PYTHONPATH=${PWD} 10 | export HF_DATASETS_CACHE=${BASE_PATH}/data/huggingface_cache/datasets 11 | source secrets.env 12 | python nemo_language_classification/finetune_lang_clf.py 13 | 14 | 15 | ``` 16 | -------------------------------------------------------------------------------- /nemo_language_classification/requirements.txt: -------------------------------------------------------------------------------- 1 | nemo_toolkit[asr] 2 | jiwer 3 | soundfile 4 | librosa 5 | # samplerate 6 | scipy 7 | # audiomentations 8 | # torch-audiomentations 9 | # pyloudnorm 10 | # homoglyphs 11 | # pandas 12 | matplotlib 13 | wandb -------------------------------------------------------------------------------- /nemo_punctuation_capitalization/__init__.py: -------------------------------------------------------------------------------- 1 | from misc_utils.beartyped_dataclass_patch import ( 2 | beartype_all_dataclasses_of_this_files_parent, 3 | ) 4 | 5 | beartype_all_dataclasses_of_this_files_parent(__file__) 6 | -------------------------------------------------------------------------------- /nemo_punctuation_capitalization/punctcap_service/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.9-bullseye AS dependencies 2 | WORKDIR /code 3 | ENV APT_INSTALL="apt-get install -y --no-install-recommends" 4 | 5 | RUN apt-get update && \ 6 | DEBIAN_FRONTEND=noninteractive $APT_INSTALL \ 7 | build-essential \ 8 | ca-certificates \ 9 | wget \ 10 | git \ 11 | g++ \ 12 | cmake \ 13 | vim && \ 14 | # ================================================================== 15 | # clean up everything 16 | # ------------------------------------------------------------------ 17 | apt-get clean && \ 18 | apt-get -y autoremove && \ 19 | rm -rf /var/lib/apt/lists/* 20 | 21 | 
RUN python3 -m venv /venv && /venv/bin/pip install --no-cache-dir -U \ 22 | pip packaging setuptools tqdm beartype Cython 23 | 24 | COPY ./requirements.txt /code/requirements.txt 25 | RUN /venv/bin/pip install --no-cache-dir --upgrade -r /code/requirements.txt 26 | 27 | RUN rm -rf /venv/lib/python3.9/site-packages/sklearn/ensemble && \ 28 | rm -rf /venv/lib/python3.9/site-packages/grpc 29 | 30 | FROM dependencies AS build_models 31 | ENV PATH="/venv/bin:$PATH" 32 | COPY . /code/ 33 | RUN python /code/build_model.py 34 | # TODO: cannot remove the following! 35 | # rm -rf /venv/lib/python3.9/site-packages/tensorboard && \ -> pytorch-pytorch_lightning wants it! 36 | # rm -rf /venv/lib/python3.9/site-packages/onnx -> nemo wants this! ->malparido! 37 | # rm -rf /venv/lib/python3.9/site-packages/matplotlib -> nemo wants this! ->malparido! 38 | 39 | FROM python:3.9-slim-bullseye AS production 40 | LABEL maintainer="Tilo Himmelsbach" 41 | ENV PATH="/venv/bin:$PATH" 42 | WORKDIR /code 43 | COPY --from=build_models /venv /venv 44 | COPY --from=build_models /root/.cache /root/.cache 45 | 46 | ENV PYTHONFAULTHANDLER=1 47 | ENV HF_DATASETS_OFFLINE=1 48 | ENV TRANSFORMERS_OFFLINE=1 49 | 50 | EXPOSE 8000:8000 51 | COPY . /code/ 52 | 53 | CMD ["/bin/bash", "-c", "source /venv/bin/activate && \ 54 | uvicorn punctcap_fastapi_server:app --host 0.0.0.0 --port 8000"] 55 | 56 | -------------------------------------------------------------------------------- /nemo_punctuation_capitalization/punctcap_service/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/nemo_punctuation_capitalization/punctcap_service/__init__.py -------------------------------------------------------------------------------- /nemo_punctuation_capitalization/punctcap_service/build_model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from nemo.collections.nlp.models import PunctuationCapitalizationModel 4 | 5 | if __name__ == "__main__": 6 | model_files = [str(p) for p in Path("/code").rglob("model.nemo")] 7 | nemo_model = model_files[0] 8 | inferencer = PunctuationCapitalizationModel.restore_from(nemo_model) 9 | default_query = "deutsche welle sometimes abbreviated to dw is a german public state-owned international broadcaster funded by the german federal tax budget the service is available in 32 languages dws satellite" 10 | result = inferencer.add_punctuation_capitalization([default_query]) 11 | print(f"{inferencer.cfg=}") 12 | print(f"{default_query=}") 13 | print(f"{result=}") 14 | -------------------------------------------------------------------------------- /nemo_punctuation_capitalization/punctcap_service/debug_punctcap_service.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | default_query = ( 4 | "nur leere drohungen oder ein realistisches szenario " 5 | "wirtschaftsminister robert habeck hält inzwischen nichts mehr für " 6 | "ausgeschlossen was würde es bedeuten wenn wladimir putin " 7 | "beschließt deutschland das gas über diese pipeline abzudrehen " 8 | "die wichtigsten antworten im überblick" 9 | ) 10 | 11 | if __name__ == "__main__": 12 | file = "{BASE_PATH}/data/cache/PROCESSED_DATA/NEMO_MODELS/NemoTrainedPunctuationCapitalizationModel-deu-1421ee5c9e895d0334f3d3c8a93d21eda0de2c61/nemo_exp_dir/model.nemo" 13 | f = open(file, "rb") 14 | port = 8000 15 | 
files = {"file": (f.name, f, "multipart/form-data")} 16 | requests.post(url=f"http://127.0.0.1:{port}/upload_modelfile", files=files) 17 | 18 | r = requests.post( 19 | url=f"http://127.0.0.1:{port}/predict", json={"text": default_query} 20 | ) 21 | print(r.text) 22 | -------------------------------------------------------------------------------- /nemo_punctuation_capitalization/punctcap_service/punctcap_fastapi_server.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging 3 | import os 4 | from pathlib import Path 5 | from typing import Any, Dict, Optional 6 | 7 | import uvicorn 8 | from fastapi import FastAPI, UploadFile 9 | from nemo.collections.nlp.models import PunctuationCapitalizationModel 10 | from pydantic import BaseModel 11 | 12 | from misc_utils.dataclass_utils import ( 13 | encode_dataclass, 14 | ) 15 | from misc_utils.utils import just_try 16 | 17 | DEBUG = os.environ.get("DEBUG", "False").lower() != "false" 18 | if DEBUG: 19 | print("DEBUGGING MODE") 20 | 21 | app = FastAPI(debug=DEBUG) 22 | inferencer: Optional[PunctuationCapitalizationModel] = None 23 | 24 | 25 | @app.get("/get_inferencer_dataclass") 26 | def get_inferencer_dataclass() -> Dict[str, Any]: 27 | global inferencer 28 | if inferencer is not None: 29 | d = encode_dataclass(inferencer) 30 | else: 31 | d = {"response": "no model loaded yet!"} 32 | return d 33 | 34 | 35 | @app.post("/upload_modelfile") 36 | async def upload_modelfile(file: UploadFile): 37 | global inferencer 38 | 39 | def save_file(filename, data): 40 | with open(filename, "wb") as f: 41 | f.write(data) 42 | 43 | nemo_model_file = "model.nemo" 44 | save_file(nemo_model_file, await file.read()) 45 | 46 | just_try( 47 | lambda: load_nemo_model(nemo_model_file), default=None, reraise=True 48 | ) # TODO: here some 400er error if model-file does not pass the sanity_check 49 | return {"filename": file.filename} 50 | 51 | 52 | default_query = ( 53 | "nur leere drohungen oder ein realistisches szenario " 54 | "wirtschaftsminister robert habeck hält inzwischen nichts mehr für " 55 | "ausgeschlossen was würde es bedeuten wenn wladimir putin " 56 | "beschließt deutschland das gas über diese pipeline abzudrehen " 57 | "die wichtigsten antworten im überblick" 58 | ) 59 | 60 | 61 | class PunctuationCapitalizationRequest(BaseModel): 62 | # TODO(tilo): do I need pydantic at all if just simple str is sent? 63 | text: str 64 | 65 | class Config: 66 | schema_extra = { 67 | "example": { 68 | "text": default_query, 69 | } 70 | } 71 | 72 | 73 | @app.post("/predict") # TODO: response_model=SomeResponsePydanticDataModel 74 | async def predict(req: PunctuationCapitalizationRequest): 75 | global inferencer 76 | 77 | result = inferencer.add_punctuation_capitalization([req.text]) 78 | 79 | return {"text": result} 80 | 81 | 82 | def load_nemo_model(nemo_model="model.nemo"): 83 | """ 84 | Nur leere Drohungen oder ein realistisches Szenario. Wirtschaftsminister Robert 85 | Habeck hält inzwischen nichts mehr für ausgeschlossen. Was würde es bedeuten, 86 | wenn Wladimir Putin beschließt, Deutschland das Gas über diese Pipeline 87 | abzudrehen? Die wichtigsten Antworten im Überblick. 
88 | """ 89 | global inferencer 90 | inferencer = PunctuationCapitalizationModel.restore_from(nemo_model) 91 | result = inferencer.add_punctuation_capitalization([default_query]) 92 | print(f"{inferencer.cfg=}") 93 | # print(f"DE-Result : {result}") 94 | 95 | 96 | @app.on_event("startup") 97 | async def startup_event(): 98 | model_files = [str(p) for p in Path("/code").rglob("model.nemo")] 99 | 100 | if len(model_files) > 0: 101 | load_nemo_model(model_files[0]) 102 | else: 103 | print(f"no model found in container, use /upload_modelfile") 104 | 105 | 106 | if __name__ == "__main__": 107 | """ 108 | #TODO: why is that necessary? 109 | export PYTHONPATH=${PYTHONPATH}:{BASE_PATH}/iais_code/NeMo 110 | """ 111 | uvicorn.run( 112 | "punctcap_fastapi_server:app", 113 | host="127.0.0.1", 114 | port=2700, 115 | # log_level="debug" 116 | ) 117 | -------------------------------------------------------------------------------- /nemo_punctuation_capitalization/punctcap_service/readme.md: -------------------------------------------------------------------------------- 1 | # fastapi+NeMo-based punctuation&capitalization-service 2 | ```commandline 3 | IMAGE=selmaproject/iais-punctcap:en 4 | DOCKER_BUILDKIT=1 docker build -t $IMAGE . 5 | 6 | docker run --rm -p 8000:8000 selmaproject/iais-punctcap:en 7 | 8 | curl -H "Content-Type: application/json;charset=UTF-8" -X POST -d '{"text":"deutsche welle sometimes abbreviated to dw is a german public state-owned international broadcaster funded by the german federal tax budget the service is available in 32 languages dws satellite"}' http://localhost:8000/predict 9 | 10 | text='Deutsche Welle, sometimes abbreviated to DW, is a German public, state-owned international broadcaster funded by the German federal tax budget. The service is available in 32 languages. DWs satellite' 11 | pred="Deutsche Welle, sometimes abbreviated to Dw, is a German public state-owned international broadcaster funded by the German federal tax budget the service is available in 32 languages Dws satellite" 12 | 13 | export MODEL=$HOME/data/cache/NEMO_PUNCTCAP_MODELS/NemoTrainedPunctuationCapitalizationModel-wiki-deu-dc4c2436c1d3f07b6a969ce1c28f43e19dc3221b/nemo_exp_dir/model.nemo 14 | 15 | curl -F file=@"${MODEL}" localhost:8000/upload_modelfile 16 | ``` 17 | 18 | * debugging 19 | ```commandline 20 | docker run -it --rm -p 8000:8000 -v ${PWD}:/code punctcap:latest bash 21 | uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload 22 | 23 | { echo "FROM scratch" ; echo "COPY . ."; CMD ["fake"] } > Dockerfile && \ 24 | export IMAGE=some_test:bla && \ 25 | docker build -t $IMAGE . 
26 | 27 | ``` -------------------------------------------------------------------------------- /nemo_punctuation_capitalization/punctcap_service/requirements.txt: -------------------------------------------------------------------------------- 1 | misc-utils@git+https://github.com/dertilo/misc-utils.git#egg=misc-utils 2 | python-multipart 3 | nemo_toolkit[nlp] 4 | uvicorn[standard] 5 | fastapi -------------------------------------------------------------------------------- /nemo_punctuation_capitalization/punctcap_training/__init__.py: -------------------------------------------------------------------------------- 1 | from misc_utils.beartyped_dataclass_patch import ( 2 | beartype_all_dataclasses_of_this_files_parent, 3 | ) 4 | 5 | beartype_all_dataclasses_of_this_files_parent(__file__) 6 | -------------------------------------------------------------------------------- /nemo_punctuation_capitalization/punctcap_training/lenta_data.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os.path 3 | from dataclasses import dataclass, field 4 | from typing import Iterable, Iterator, Any 5 | 6 | import pandas 7 | from data_io.download_extract_files import wget_file 8 | from misc_utils.buildable_data import BuildableData 9 | from misc_utils.prefix_suffix import PrefixSuffix, BASE_PATHES 10 | from pandas import Series 11 | 12 | """ 13 | this code is completely unused! 14 | def clean_russian_text_for_punctcap_training(text: str, punct_marks: str = ",?."): 15 | copypasted from tugtekins "russian_normalization_train"-method /mturan/russian-punct-casing/-/blob/master/local/get_lenta_data.py#L13 16 | # TODO(tilo): @tugtekin -> please explanation for: NFKC, re.sub, regex.sub, ! 17 | we need pytests! 18 | written2spoken before this cleaning method, OR adapt cleaning method as to keep numbers! 19 | Normalization and deRomanization of Russian data for training purpose 20 | Args: 21 | text: input text 22 | punct_marks: supported punctuation marks 23 | # \xa0 (i.e. chr(160)) creates problems in Lenta dataset 24 | # this is a non-breaking space in Latin1 (ISO 8859-1) 25 | # for other languages (with Latin alphabet) better to use 'NFKD' instead of 'NFKC' 26 | unicoded = unicodedata.normalize("NFKC", text) 27 | 28 | # remove links if exist 29 | no_URL = re.sub(r"http\S+", "", unicoded) 30 | 31 | # delete alphanumeric Latin letters except for Cyrillic (no transliteration) 32 | # also, remove all the punctuations except defined in 'punct_marks' 33 | match = "[^\s\p{IsCyrillic}" + punct_marks + "]" 34 | only_cyrillic_and_punctuations = regex.sub(match, "", no_URL) 35 | 36 | # remove repetitive whitespace 37 | normalized = " ".join(only_cyrillic_and_punctuations.split()) 38 | 39 | # replace punctuations with extra spaces 40 | # e.g. "hey . you are , okay ?" --> "hey. you are, okay?" 41 | remove_extra_space = {" ,": ",", " .": ".", " ?": "?"} 42 | for key, value in remove_extra_space.items(): 43 | normalized = normalized.replace(key, value) 44 | 45 | return normalized 46 | """ 47 | 48 | 49 | def got_text(d) -> bool: 50 | return ( 51 | isinstance(d, tuple) and isinstance(d[1], Series) and isinstance(d[1].text, str) 52 | ) 53 | 54 | 55 | @dataclass 56 | class LentaData(BuildableData, Iterable[str]): 57 | 58 | """ 59 | is simply downloading the lenta-ru-news.csv.bz2 file into the "russian_text_data"-folder 60 | """ 61 | 62 | punct_marks: str = ",?." # TODO: ! 
63 | 64 | base_dir: PrefixSuffix = field( 65 | default_factory=lambda: BASE_PATHES["russian_text_data"] 66 | ) 67 | 68 | _bz2_file_url: str = field( 69 | init=False, 70 | default="https://github.com/yutkin/Lenta.Ru-News-Dataset/releases/download/v1.1/lenta-ru-news.csv.bz2", 71 | ) 72 | 73 | @property 74 | def name(self) -> str: 75 | return "russian-lenta-data" 76 | 77 | @property 78 | def _is_data_valid(self) -> bool: 79 | return os.path.isfile(self.raw_file) 80 | 81 | @property 82 | def raw_file(self): 83 | return f'{self.data_dir}/{self._bz2_file_url.split("/")[-1]}' 84 | 85 | def _build_data(self) -> Any: 86 | os.makedirs(str(self.data_dir), exist_ok=True) 87 | if not os.path.isfile(self.raw_file): 88 | wget_file(self._bz2_file_url, str(self.data_dir)) 89 | return self 90 | 91 | def __iter__(self) -> Iterator[str]: 92 | chunksize = 1000 93 | 94 | with pandas.read_csv( 95 | self.raw_file, usecols=["text"], index_col=False, chunksize=chunksize 96 | ) as reader: 97 | rows = (chunk for chunk in reader for d in chunk.iterrows()) 98 | good_rows = (d for d in rows if got_text(d)) 99 | for d in good_rows: 100 | original = d[1].text.replace("\n", " ").replace("\r", "") 101 | yield original 102 | 103 | 104 | if __name__ == "__main__": 105 | 106 | processed_corproa_dir = f"{os.environ['corpora']}/processed_corpora" 107 | BASE_PATHES["processed_corproa_dir"] = processed_corproa_dir 108 | BASE_PATHES["russian_text_data"] = PrefixSuffix( 109 | "processed_corproa_dir", "RUSSIAN_TEXT_DATA" 110 | ) 111 | 112 | corpus = LentaData().build() 113 | 114 | for t in itertools.islice(corpus, 0, 10): 115 | print(f"{t=}") 116 | -------------------------------------------------------------------------------- /nemo_punctuation_capitalization/punctcap_training/punctcap_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from nemo.collections.nlp.models import PunctuationCapitalizationModel 4 | 5 | if __name__ == "__main__": 6 | """ 7 | Nur leere Drohungen oder ein realistisches Szenario. Wirtschaftsminister Robert 8 | Habeck hält inzwischen nichts mehr für ausgeschlossen. Was würde es bedeuten, 9 | wenn Wladimir Putin beschließt, Deutschland das Gas über diese Pipeline 10 | abzudrehen? Die wichtigsten Antworten im Überblick. 
11 | """ 12 | query = ( 13 | "nur leere drohungen oder ein realistisches szenario " 14 | "wirtschaftsminister robert habeck hält inzwischen nichts mehr für " 15 | "ausgeschlossen was würde es bedeuten wenn wladimir putin " 16 | "beschließt deutschland das gas über diese pipeline abzudrehen " 17 | "die wichtigsten antworten im überblick" 18 | ) 19 | 20 | nemo_model = f"{os.environ['BASE_PATH']}/data/cache/PROCESSED_DATA/NEMO_MODELS/NemoTrainedPunctuationCapitalizationModel-deu-1421ee5c9e895d0334f3d3c8a93d21eda0de2c61/nemo_exp_dir/Punctuation_and_Capitalization/2022-03-21_18-05-06/checkpoints/Punctuation_and_Capitalization.nemo" 21 | print(f"{nemo_model=}") 22 | DE_model = PunctuationCapitalizationModel.restore_from(nemo_model) 23 | result = DE_model.add_punctuation_capitalization([query]) 24 | print(f"DE-Query : {query}") 25 | print(f"DE-Result : {result}") 26 | -------------------------------------------------------------------------------- /nemo_punctuation_capitalization/punctcap_training/punctuation_capitalization_train_evaluate.py: -------------------------------------------------------------------------------- 1 | # based on: https://github.com/NVIDIA/NeMo/blob/3d0c29a317b89b20c93757010db80271eeea6816/examples/nlp/token_classification/punctuation_capitalization_train_evaluate.py 2 | 3 | import os 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | from omegaconf import DictConfig, OmegaConf 8 | 9 | from nemo.collections.nlp.models import PunctuationCapitalizationModel 10 | from nemo.collections.nlp.models.token_classification.punctuation_capitalization_config import ( 11 | PunctuationCapitalizationConfig, 12 | ) 13 | from nemo.core.config import hydra_runner 14 | from nemo.utils import logging 15 | from nemo.utils.exp_manager import exp_manager 16 | 17 | 18 | @hydra_runner(config_path="conf", config_name="punctuation_capitalization_config") 19 | def main(cfg: DictConfig) -> None: 20 | torch.manual_seed(42) 21 | cfg = OmegaConf.merge(OmegaConf.structured(PunctuationCapitalizationConfig()), cfg) 22 | trainer = pl.Trainer(**cfg.trainer) 23 | exp_manager(trainer, cfg.get("exp_manager", None)) 24 | if not cfg.do_training and not cfg.do_testing: 25 | raise ValueError( 26 | "At least one of config parameters `do_training` and `do_testing` has to `true`." 27 | ) 28 | if cfg.do_training: 29 | if cfg.model.get("train_ds") is None: 30 | raise ValueError( 31 | "`model.train_ds` config section is required if `do_training` config item is `True`." 32 | ) 33 | if cfg.do_testing: 34 | if cfg.model.get("test_ds") is None: 35 | raise ValueError( 36 | "`model.test_ds` config section is required if `do_testing` config item is `True`." 
37 | ) 38 | 39 | if not isinstance(cfg.pretrained_model, str) or cfg.pretrained_model in [ 40 | "null", 41 | "None", 42 | ]: 43 | logging.info(f"Config: {OmegaConf.to_yaml(cfg)}") 44 | model = PunctuationCapitalizationModel(cfg.model, trainer=trainer) 45 | else: 46 | if os.path.exists(cfg.pretrained_model): 47 | model = PunctuationCapitalizationModel.restore_from(cfg.pretrained_model) 48 | elif ( 49 | cfg.pretrained_model 50 | in PunctuationCapitalizationModel.get_available_model_names() 51 | ): 52 | model = PunctuationCapitalizationModel.from_pretrained(cfg.pretrained_model) 53 | else: 54 | raise ValueError( 55 | f"Provide path to the pre-trained .nemo file or choose from " 56 | f"{PunctuationCapitalizationModel.list_available_models()}" 57 | ) 58 | 59 | if cfg.do_training: 60 | model.update_config_after_restoring_from_checkpoint( 61 | class_labels=cfg.model.class_labels, 62 | common_dataset_parameters=cfg.model.common_dataset_parameters, 63 | train_ds=cfg.model.get("train_ds") if cfg.do_training else None, 64 | validation_ds=cfg.model.get("validation_ds") 65 | if cfg.do_training 66 | else None, 67 | test_ds=cfg.model.get("test_ds") if cfg.do_testing else None, 68 | optim=cfg.model.get("optim") if cfg.do_training else None, 69 | ) 70 | model.set_trainer(trainer) 71 | 72 | model.setup_training_data() 73 | model.setup_validation_data() 74 | model.setup_optimization() 75 | else: 76 | model.setup_test_data(cfg.model.get("test_ds")) 77 | if cfg.do_training: 78 | trainer.fit(model) 79 | if cfg.do_testing: 80 | trainer.test(model) 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /nemo_punctuation_capitalization/punctcap_training/readme.md: -------------------------------------------------------------------------------- 1 | # punctuation via sequence tagging 2 | ### [NeMo tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/nlp/Punctuation_and_Capitalization.ipynb) 3 | * [NeMo punctuation_capitalization_train_evaluate](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/token_classification/punctuation_capitalization_train_evaluate.py) 4 | 5 | ### datasets 6 | * [Helsinki-NLP/Tatoeba-Challenge](https://github.com/Helsinki-NLP/Tatoeba-Challenge/blob/master/data/MonolingualData.md) 7 | 8 | * wikibooks.txt.gz contains phrases like this: `Wassergekühlter 6-Zylinder turboaufgeladener ladeluftgekühlter Dieselmotor mit 8.277 ccm Hubraum und Bohrung x Hub 114 x 135,1 mm vom Typ C6T-830.` 9 | -> we might need proper written2spoken-text-formatting! 
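* a minimal sketch of such written→spoken number normalization (illustrative only — it assumes the `num2words` package, which is not part of this repo's requirements, and the regex/unit handling is simplified):
```python
import re
from num2words import num2words  # assumed extra dependency, not in requirements.txt

def spell_out_numbers(text: str, lang: str = "de") -> str:
    """Replace digit groups like '8.277' or '135,1' with their spoken (German) form."""
    def _spell(m: re.Match) -> str:
        raw = m.group(0).replace(".", "").replace(",", ".")  # German thousands/decimal marks
        number = float(raw) if "." in raw else int(raw)
        return num2words(number, lang=lang)
    return re.sub(r"\d+(?:\.\d{3})*(?:,\d+)?", _spell, text)

print(spell_out_numbers("Dieselmotor mit 8.277 ccm Hubraum"))
# -> 'Dieselmotor mit achttausendzweihundertsiebenundsiebzig ccm Hubraum'
```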
-------------------------------------------------------------------------------- /nemo_punctuation_capitalization/punctcap_training/requirements.txt: -------------------------------------------------------------------------------- 1 | beartype 2 | misc_utils@git+https://github.com/dertilo/misc-utils@main#egg=misc_utils 3 | hydra-core==1.1.0 # see nemo 4 | -------------------------------------------------------------------------------- /nemo_vad/__init__.py: -------------------------------------------------------------------------------- 1 | from misc_utils.beartyped_dataclass_patch import ( 2 | beartype_all_dataclasses_of_this_files_parent, 3 | ) 4 | 5 | beartype_all_dataclasses_of_this_files_parent(__file__) 6 | -------------------------------------------------------------------------------- /nemo_vad/images/vad_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/nemo_vad/images/vad_demo.png -------------------------------------------------------------------------------- /nemo_vad/readme.md: -------------------------------------------------------------------------------- 1 | # Voice Activity Detection with Nvidia NeMo 2 | * based on this [notebook](https://github.com/NVIDIA/NeMo/blob/v1.0.0/tutorials/asr/07_Online_Offline_Microphone_VAD_Demo.ipynb) 3 | 4 | * nvidia says: `It is **not a recommended** way to do inference in production workflows. If you are interested in 5 | production-level inference using NeMo ASR models, please sign-up to Jarvis early access program: https://developer.nvidia.com/nvidia-jarvis` 6 | 7 | * setup 8 | ```shell 9 | pip install -r requirements.txt 10 | ``` 11 | * [visualize_segmentation.py](scripts/visualize_segmentation.py) 12 | ![image](images/vad_demo.png) 13 | 14 | ### [training VAD with nemo](https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/Voice_Activity_Detection.ipynb) 15 | * TODO -------------------------------------------------------------------------------- /nemo_vad/requirements.txt: -------------------------------------------------------------------------------- 1 | librosa 2 | nemo_toolkit[asr]==1.11.0 -------------------------------------------------------------------------------- /nemo_vad/scripts/visualize_segmentation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from ml4audio.audio_utils.audio_io import ( 7 | break_array_into_chunks, 8 | convert_to_16bit_array, 9 | load_resample_with_nemo, 10 | ) 11 | from ml4audio.audio_utils.torchaudio_utils import torchaudio_info 12 | from nemo_vad.nemo_streaming_vad import NeMoVAD 13 | 14 | 15 | def offline_inference(vad: NeMoVAD, signal_chunks): 16 | preds = [] 17 | proba_b = [] 18 | proba_s = [] 19 | 20 | for signal in signal_chunks: 21 | result = vad.predict(signal) 22 | 23 | preds.append(result.label_id) 24 | proba_b.append(result.probs_background) 25 | proba_s.append(result.probs_speech) 26 | 27 | vad.reset() 28 | 29 | return preds, proba_b, proba_s 30 | 31 | 32 | def visualize(results, audio, sample_rate, threshold, dur): 33 | """ 34 | copypasted from: https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/Online_Offline_Microphone_VAD_Demo.ipynb 35 | """ 36 | import librosa.display 37 | 38 | plt.figure(figsize=[20, 10]) 39 | num = len(results) 40 | for i, (FRAME_LEN, buffer_size, _, _, proba_s) in 
enumerate(results): 41 | len_pred = len(results[i][2]) 42 | ax1 = plt.subplot(num + 1, 1, i + 1) 43 | 44 | ax1.plot(np.arange(audio.size) / sample_rate, audio, "b") 45 | ax1.set_xlim([-0.01, int(dur) + 1]) 46 | ax1.tick_params(axis="y", labelcolor="b") 47 | ax1.set_ylabel("Signal") 48 | ax1.set_ylim([-1, 1]) 49 | 50 | pred = [1 if p > threshold else 0 for p in proba_s] 51 | ax2 = ax1.twinx() 52 | ax2.plot( 53 | np.arange(len_pred) / (1 / FRAME_LEN), np.array(pred), "r", label="pred" 54 | ) 55 | ax2.plot( 56 | np.arange(len_pred) / (1 / FRAME_LEN), 57 | np.array(proba_s), 58 | "g--", 59 | label="speech prob", 60 | ) 61 | ax2.tick_params(axis="y", labelcolor="r") 62 | legend = ax2.legend(loc="lower right", shadow=True) 63 | ax1.set_ylabel("prediction") 64 | 65 | ax2.set_title(f"step {FRAME_LEN}s, buffer size {buffer_size}s") 66 | ax2.set_ylabel("Preds and Probas") 67 | ax = plt.subplot(num + 1, 1, i + 2) 68 | S = librosa.feature.melspectrogram(y=audio, sr=sample_rate, n_mels=64, fmax=8000) 69 | S_dB = librosa.power_to_db(S, ref=np.max) 70 | librosa.display.specshow( 71 | S_dB, x_axis="time", y_axis="mel", sr=sample_rate, fmax=8000 72 | ) 73 | ax.set_title("Mel-frequency spectrogram") 74 | ax.grid() 75 | plt.savefig("vad.png") 76 | 77 | 78 | def main(): 79 | file = "nemo_vad/tests/resources/VAD_demo.wav" 80 | # if not os.path.exists(file): 81 | # os.system( 82 | # 'wget "https://dldata-public.s3.us-east-2.amazonaws.com/VAD_demo.wav" ' 83 | # ) 84 | 85 | sr = 16_000 86 | audio = load_resample_with_nemo(file, sr) 87 | speech_array = convert_to_16bit_array(audio) 88 | 89 | num_frames, sample_rate, dur = torchaudio_info(file) 90 | # audio, sample_rate = librosa.load(file, sr=sample_rate) 91 | # dur = librosa.get_duration(audio) 92 | # print(dur) 93 | 94 | threshold = 0.2 95 | STEP_LIST = [0.01, 0.01] 96 | WINDOW_SIZE_LIST = [0.31, 0.15] 97 | 98 | results = [] 99 | for STEP, WINDOW_SIZE in zip( 100 | STEP_LIST, 101 | WINDOW_SIZE_LIST, 102 | ): 103 | 104 | arrays = list(break_array_into_chunks(speech_array, int(sr * STEP))) 105 | 106 | vad = NeMoVAD( 107 | threshold=threshold, 108 | frame_duration=STEP, 109 | window_len_in_secs=WINDOW_SIZE, 110 | input_sample_rate=sr, 111 | ).build() 112 | 113 | print(f"====== STEP is {STEP}s, WINDOW_SIZE is {WINDOW_SIZE}s ====== ") 114 | preds, proba_b, proba_s = offline_inference( 115 | vad, 116 | arrays, 117 | ) 118 | results.append([STEP, WINDOW_SIZE, preds, proba_b, proba_s]) 119 | 120 | visualize(results, audio, sample_rate, threshold, dur) 121 | 122 | 123 | if __name__ == "__main__": 124 | 125 | main() 126 | -------------------------------------------------------------------------------- /nemo_vad/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/nemo_vad/tests/__init__.py -------------------------------------------------------------------------------- /nemo_vad/tests/resources/VAD_demo.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/nemo_vad/tests/resources/VAD_demo.wav -------------------------------------------------------------------------------- /nemo_vad/tests/test_nemo_offline_vad.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tempfile 4 | 5 | import numpy as np 6 | from beartype import 
beartype 7 | from omegaconf import OmegaConf 8 | 9 | from misc_utils.prefix_suffix import BASE_PATHES 10 | from ml4audio.audio_utils.audio_io import load_resample_with_soundfile, ffmpeg_load_trim 11 | from nemo_vad.nemo_offline_vad import NemoOfflineVAD 12 | from nemo_vad.tests.vad_infer_almost_original import ( 13 | nemo_offline_vad_infer_main_original, 14 | ) 15 | import logging 16 | 17 | from conftest import get_test_cache_base 18 | 19 | logging.getLogger("nemo_logger").setLevel(logging.ERROR) 20 | 21 | # fmt: off 22 | # used vad_infer_almost_original.py to create this expected 23 | expected = [(0.31, 2.93), (3.27, 6.109999999999999), (6.81, 9.83), (10.69, 13.149999999999999), (13.69, 16.35), (17.21, 19.23), (19.54, 20.45), (21.37, 24.37)] 24 | # fmt: on 25 | 26 | BASE_PATHES["cache_root"] = get_test_cache_base() 27 | 28 | 29 | @beartype 30 | def vad_assertions(start_ends: list[tuple[float, float]]): 31 | assert len(start_ends) == len(expected), f"{len(start_ends)=},{len(expected)=}" 32 | starts, ends = [np.asarray(x) for x in zip(*start_ends)] 33 | starts_exp, ends_exp = [np.asarray(x) for x in zip(*expected)] 34 | print(f"{starts=},{ends=}") 35 | print(f"{starts_exp=},{ends_exp=}") 36 | # assert start_ends==expected 37 | assert np.allclose(starts, starts_exp, atol=1e-2) 38 | assert np.allclose(ends, ends_exp, atol=1e-2) 39 | 40 | 41 | # for parameters see: https://github.com/NVIDIA/NeMo/blob/aff169747378bcbcec3fc224748242b36205413f/examples/asr/conf/vad/vad_inference_postprocessing.yaml 42 | default_vad_config = { 43 | "name": "vad_inference_postprocessing", 44 | "dataset": None, 45 | "num_workers": 0, 46 | "sample_rate": 16000, 47 | "gen_seg_table": True, 48 | "write_to_manifest": True, 49 | "prepare_manifest": {"auto_split": True, "split_duration": 400}, 50 | "vad": { 51 | "model_path": "vad_marblenet", 52 | "parameters": { 53 | "normalize_audio": False, 54 | "window_length_in_sec": 0.15, 55 | "shift_length_in_sec": 0.01, 56 | "smoothing": "median", 57 | "overlap": 0.875, 58 | "postprocessing": { 59 | "onset": 0.4, 60 | "offset": 0.7, # TODO(tilo): makes no sense to me 61 | "pad_onset": 0.05, 62 | "pad_offset": -0.1, 63 | "min_duration_on": 0.2, 64 | "min_duration_off": 0.2, 65 | "filter_speech_first": True, 66 | }, 67 | }, 68 | }, 69 | "prepared_manifest_vad_input": None, 70 | "frame_out_dir": "vad_frame", 71 | "smoothing_out_dir": None, 72 | "table_out_dir": None, 73 | "out_manifest_filepath": None, 74 | } 75 | 76 | # TODO: the test is broken! 
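# Note on the postprocessing thresholds in default_vad_config above (as understood from
# NeMo's VAD postprocessing, worth double-checking against the NeMo docs): `onset` is the
# speech-probability threshold at which a new speech segment is opened, `offset` the
# threshold below which an open segment is closed, and `pad_onset`/`pad_offset` then
# widen (or, if negative, shrink) the segment boundaries by the given number of seconds.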
77 | def test_nemo_offline_vad( 78 | librispeech_audio_file, 79 | ): 80 | audio = ffmpeg_load_trim(librispeech_audio_file) 81 | 82 | vad = NemoOfflineVAD(name="test-vad", cfg=default_vad_config) 83 | vad.build() 84 | with vad: 85 | start_ends, _ = vad.predict(audio) 86 | vad_assertions(start_ends) 87 | 88 | 89 | # def test_nemo_original_vad( 90 | # librispeech_audio_file, 91 | # config_yaml="tests/resources/vad_inference_postprocessing_original.yaml", 92 | # ): 93 | # 94 | # with tempfile.TemporaryDirectory() as tmpdir: 95 | # start_ends = nemo_offline_vad_infer_main_original( 96 | # audio_file=librispeech_audio_file, 97 | # config_yaml=config_yaml, 98 | # data_dir=tmpdir, 99 | # ) 100 | # shutil.rmtree("vad_frame_for_testing", ignore_errors=True) 101 | # vad_assertions(start_ends) 102 | -------------------------------------------------------------------------------- /nemo_vad/tests/test_vad.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from numpy.testing import assert_allclose 4 | from tqdm import tqdm 5 | 6 | from ml4audio.audio_utils.audio_io import ( 7 | load_and_resample_16bit_PCM, 8 | break_array_into_chunks, 9 | ) 10 | from nemo_vad.nemo_streaming_vad import NeMoVAD 11 | from nemo_vad.streaming_vad_segmentation import StreamingSignalSegmentor 12 | 13 | raw_audio_chunks_dur = 0.01 14 | 15 | file = "nemo_vad/tests/resources/VAD_demo.wav" 16 | SR = 16_000 17 | 18 | 19 | @pytest.fixture() 20 | def audio_arrays(): 21 | speech_array = load_and_resample_16bit_PCM(file, SR) 22 | arrays = list(break_array_into_chunks(speech_array, int(SR * raw_audio_chunks_dur))) 23 | return arrays 24 | 25 | 26 | # fmt: off 27 | expected_speech_probas=[0.124, 0.035, 0.644, 0.324, 0.455, 0.111, 0.648, 0.662, 0.954, 0.985, 0.963, 0.681, 0.293, 0.635, 0.99, 0.975, 1.0, 0.987, 0.978, 0.941, 1.0, 1.0, 1.0, 0.999, 0.999, 0.987, 0.989, 0.998, 0.999, 0.996] 28 | # fmt: on 29 | 30 | 31 | def test_nemo_vad(): 32 | speech_array = load_and_resample_16bit_PCM(file, SR) 33 | arrays = list(break_array_into_chunks(speech_array, int(SR * 0.1))) 34 | 35 | vad = NeMoVAD( 36 | threshold=0.3, 37 | frame_duration=0.1, 38 | window_len_in_secs=4 * 0.1, 39 | input_sample_rate=SR, 40 | ).build() 41 | probas = np.array([vad.predict(signal).probs_speech for signal in arrays]) 42 | assert_allclose(np.array(expected_speech_probas), probas, atol=0.01) 43 | 44 | 45 | @pytest.mark.parametrize( 46 | "params", 47 | [ 48 | (0.3, 0.05, 42840, 3), 49 | (0.5, 0.1, 46040, 2), 50 | (0.8, 0.1, 36440, 3), 51 | ], 52 | ) 53 | def test_vad_segmentation(params, audio_arrays): 54 | threshold, frame_duration, exp_voice_dur, num_segs = params 55 | # assert frame_duration>raw_audio_chunks_dur 56 | vad = NeMoVAD( 57 | threshold=threshold, 58 | frame_duration=frame_duration, 59 | window_len_in_secs=4 * frame_duration, 60 | input_sample_rate=SR, 61 | ) 62 | segmenter = StreamingSignalSegmentor(vad=vad).build() 63 | 64 | g = (segmenter.handle_audio_array(a) for a in audio_arrays) 65 | voice_segs = [vs for vs in g if vs is not None and vs.is_final()] 66 | last_seg = segmenter.flush() 67 | if last_seg is not None: 68 | voice_segs = voice_segs + [last_seg] 69 | 70 | voice_dur = sum([len(vs.array) for vs in voice_segs]) 71 | audio_dur = sum(len(a) for a in audio_arrays) 72 | assert audio_dur == 47640 73 | assert voice_dur == exp_voice_dur 74 | assert len(voice_segs) == num_segs 75 | -------------------------------------------------------------------------------- 
/nemo_vad/vad_inference_postprocessing.yaml: -------------------------------------------------------------------------------- 1 | name: &name "vad_inference_postprocessing" 2 | 3 | dataset: null # Path of json file of evaluation data. Audio files should have unique names 4 | num_workers: 1 5 | sample_rate: 16000 6 | 7 | # functionality 8 | gen_seg_table: True # whether to convert frame-level predictions to speech/no-speech segments in start/end-time format 9 | write_to_manifest: True # whether to write the above segments to a single manifest json file. 10 | 11 | prepare_manifest: 12 | auto_split: True # whether to automatically split manifest entries by split_duration to avoid potential CUDA out-of-memory issues. 13 | split_duration: 400 # try a smaller number if you still have CUDA memory issues 14 | 15 | vad: 16 | # model_path: "vad_multilingual_marblenet" # loading this from ngc-hub seems NOT to be working! 17 | model_path: null # provide a ".nemo" model path 18 | parameters: # Tuned parameters for CH109! (with 11 moved multi-speech sessions as dev set) 19 | normalize_audio: False 20 | window_length_in_sec: 0.15 # window length in sec for VAD context input 21 | shift_length_in_sec: 0.01 # shift length in sec for generating frame-level VAD predictions 22 | smoothing: "median" # false or type of smoothing method (eg: median) 23 | overlap: 0.875 # overlap ratio for overlapped mean/median smoothing filter 24 | postprocessing: 25 | onset: 0.3 # onset threshold for detecting the beginning and end of speech 26 | # choosing offset this difference leads to a high `miss_speaker` score, which only indicates that I was too lazy to manually label short audio segments -> the reference is not too accurate! 20 | * at the very beginning if you "zoom in" you can actually see one of the rare "true" diarization errors (`SER`=="speaker_confusion") 21 | * there is a very short green segment where `big_colored_neglaces` speaks which did not get recognized/clustered properly! 
22 | ![img_1.png](images/dw_africa_queen_elizabeth_speaker_segments.png) -------------------------------------------------------------------------------- /speaker-diarization/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/speaker-diarization/requirements.txt -------------------------------------------------------------------------------- /speaker-diarization/scripts/audio_segmentation_via_asr.py: -------------------------------------------------------------------------------- 1 | from beartype import beartype 2 | 3 | from misc_utils.beartypes import NeNpFloatDim1 4 | from ml4audio.asr_inference.asr_chunk_infer_glue_pipeline import ( 5 | Aschinglupi, 6 | calc_final_transcript, 7 | CompleteMessage, 8 | ) 9 | from ml4audio.asr_inference.hfwav2vec2_asr_decode_inferencer import ( 10 | HFASRDecodeInferencer, 11 | ) 12 | from ml4audio.asr_inference.logits_inferencer.asr_logits_inferencer import ( 13 | HfCheckpoint, 14 | ) 15 | from ml4audio.asr_inference.logits_inferencer.hfwav2vec2_logits_inferencer import ( 16 | HFWav2Vec2LogitsInferencer, 17 | ) 18 | from ml4audio.asr_inference.transcript_gluer import ( 19 | TranscriptGluer, 20 | ) 21 | from ml4audio.audio_utils.aligned_transcript import AlignedTranscript 22 | from ml4audio.audio_utils.audio_io import ( 23 | convert_to_16bit_array, 24 | break_array_into_chunks, 25 | audio_messages_from_file, 26 | audio_messages_from_chunks, 27 | ) 28 | from ml4audio.audio_utils.overlap_array_chunker import ( 29 | OverlapArrayChunker, 30 | ) 31 | from ml4audio.text_processing.ctc_decoding import GreedyDecoder 32 | 33 | 34 | def wav2vec2_decode_inferencer(model="jonatasgrosman/wav2vec2-large-xlsr-53-german"): 35 | 36 | # TODO(tilo): WTF! I copypasted this from a test! 37 | # if not hasattr(request, "param"): 38 | expected_sample_rate = 16000 39 | # else: 40 | # expected_sample_rate = request.param 41 | 42 | # model = "facebook/wav2vec2-base-960h" 43 | logits_inferencer = HFWav2Vec2LogitsInferencer( 44 | checkpoint=HfCheckpoint( 45 | name=model, 46 | model_name_or_path=model, 47 | ), 48 | input_sample_rate=expected_sample_rate, 49 | ) 50 | asr = HFASRDecodeInferencer( 51 | logits_inferencer=logits_inferencer, 52 | decoder=GreedyDecoder(tokenizer_name_or_path=model), 53 | ) 54 | asr.build() 55 | return asr 56 | 57 | 58 | @beartype 59 | def aschinglupi_infer_file( 60 | wav_file: str, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-german" 61 | ) -> AlignedTranscript: 62 | """ 63 | only for debugging -> loads asr-model with every call! 64 | """ 65 | SR = 16000 66 | asr_input = list(audio_messages_from_file(wav_file, SR)) 67 | return aschinglupi_infer_messages(asr_input, model_name, SR) 68 | 69 | 70 | @beartype 71 | def aschinglupi_infer_array( 72 | array: NeNpFloatDim1, 73 | model_name, 74 | target_sample_rate: int, 75 | chunk_duration: float = 0.1, 76 | ) -> AlignedTranscript: 77 | """ 78 | only for debugging -> loads asr-model with every call! 79 | # TODO: heavy code duplication! 
see: ml4audio/asr_inference/asr_chunk_infer_glue_pipeline.py 80 | """ 81 | a = convert_to_16bit_array(array) 82 | chunks_g = break_array_into_chunks(a, int(target_sample_rate * chunk_duration)) 83 | messages_g = list(audio_messages_from_chunks("foobar", chunks_g)) 84 | return aschinglupi_infer_messages(messages_g, model_name, target_sample_rate) 85 | 86 | 87 | @beartype 88 | def aschinglupi_infer_messages( 89 | asr_input: CompleteMessage, model_name: str, SR 90 | ) -> AlignedTranscript: 91 | """ 92 | only for debugging -> loads asr-model with every call! 93 | """ 94 | window_dur = 8.0 95 | step_dur = 4.0 96 | streaming_asr = Aschinglupi( 97 | hf_asr_decoding_inferencer=wav2vec2_decode_inferencer(model_name), 98 | transcript_gluer=TranscriptGluer(), 99 | audio_bufferer=OverlapArrayChunker( 100 | chunk_size=int(window_dur * SR), 101 | # minimum_chunk_size=int(1 * SR), # one second! 102 | min_step_size=int(step_dur * SR), 103 | ), 104 | ).build() 105 | at = calc_final_transcript(streaming_asr, asr_input) 106 | return at 107 | -------------------------------------------------------------------------------- /speaker-diarization/scripts/debug_speaker_clustering_service.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from tempfile import NamedTemporaryFile 4 | 5 | import requests 6 | from data_io.readwrite_files import write_lines 7 | 8 | from ml4audio.speaker_tasks.speaker_embedding_utils import ( 9 | read_sel_from_rttm, 10 | format_rttm_lines, 11 | ) 12 | from ml4audio.speaker_tasks.speechbrain_der import speechbrain_DER 13 | 14 | if __name__ == "__main__": 15 | 16 | for endpoint in ["predict", "predict_unsegmented"]: 17 | audio_file = "nemo_diarization/tests/resources/oLnl1D6owYA.opus" 18 | rttm_ref = "nemo_diarization/tests/resources/oLnl1D6owYA_ref.rttm" 19 | 20 | SR = 16_000 21 | start_end_speaker = read_sel_from_rttm(rttm_ref) 22 | 23 | f = open(audio_file, "rb") 24 | files = { 25 | "file": (f.name, f, "multipart/form-data"), 26 | } 27 | if endpoint == "predict": 28 | files["segments"] = ( 29 | None, 30 | json.dumps([(s, e) for s, e, _ in start_end_speaker]), 31 | "application/json", 32 | ) 33 | port = 8001 34 | r = requests.post(f"http://localhost:{port}/{endpoint}", files=files) 35 | response = r.json() 36 | print(f"{response}") 37 | s_e_labels = [ 38 | (d["start"], d["end"], d["label"]) for d in response["labeled_segments"] 39 | ] 40 | 41 | file_id = Path(audio_file).stem 42 | 43 | with NamedTemporaryFile(suffix=".rttm") as tmp_file: 44 | rttm_pred_file = tmp_file.name 45 | write_lines(rttm_pred_file, format_rttm_lines(s_e_labels, file_id=file_id)) 46 | miss_speaker, fa_speaker, sers, ders = speechbrain_DER( 47 | rttm_ref, 48 | rttm_pred_file, 49 | ignore_overlap=True, 50 | collar=0.25, 51 | individual_file_scores=True, 52 | ) 53 | print(f"{(miss_speaker, fa_speaker, sers, ders)=}") 54 | -------------------------------------------------------------------------------- /speaker-diarization/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import find_packages 3 | from setuptools import setup 4 | 5 | 6 | def req_file(filename, folder="./"): 7 | with open(os.path.join(folder, filename), encoding="utf-8") as f: 8 | content = f.readlines() 9 | # you may also want to remove whitespace characters 10 | # Example: `\n` at the end of each line 11 | return [x.strip() for x in content] 12 | 13 | 14 | install_requires = 
req_file("requirements.txt") 15 | 16 | with open("README.md") as f: 17 | readme = f.read() 18 | 19 | 20 | setup( 21 | name="speaker_diarization", 22 | version="0.1", 23 | author="Tilo Himmelsbach", 24 | author_email="dertilo@gmail.com", 25 | packages=find_packages(include=["speaker_diarization*"]), 26 | license="MIT License", 27 | long_description=readme, 28 | install_requires=install_requires, 29 | python_requires=">=3.9", 30 | ) 31 | -------------------------------------------------------------------------------- /speaker-diarization/speaker_diarization/__init__.py: -------------------------------------------------------------------------------- 1 | from misc_utils.beartyped_dataclass_patch import ( 2 | beartype_all_dataclasses_of_this_files_parent, 3 | ) 4 | 5 | beartype_all_dataclasses_of_this_files_parent(__file__) 6 | -------------------------------------------------------------------------------- /speaker-diarization/speaker_diarization/diarization/__init__.py: -------------------------------------------------------------------------------- 1 | from misc_utils.beartyped_dataclass_patch import ( 2 | beartype_all_dataclasses_of_this_files_parent, 3 | ) 4 | 5 | beartype_all_dataclasses_of_this_files_parent(__file__) 6 | -------------------------------------------------------------------------------- /speaker-diarization/speaker_diarization/diarization/nemo_diarizers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from tempfile import TemporaryDirectory 4 | from typing import Any 5 | 6 | import soundfile 7 | from beartype import beartype 8 | from omegaconf import OmegaConf 9 | 10 | from data_io.readwrite_files import write_lines 11 | from ml4audio.audio_utils.audio_segmentation_utils import ( 12 | StartEndArraysNonOverlap, 13 | NonOverlSegs, 14 | StartEndLabels, 15 | ) 16 | from ml4audio.speaker_tasks.diarization.speaker_diarization_inferencer import ( 17 | SpeakerDiarizationInferencer, 18 | ) 19 | from ml4audio.speaker_tasks.speaker_embedding_utils import ( 20 | format_rttm_lines, 21 | read_sel_from_rttm, 22 | ) 23 | from nemo.collections.asr.models import ClusteringDiarizer 24 | from nemo_vad.nemo_offline_vad import NemoOfflineVAD, create_manifest 25 | 26 | 27 | @dataclass 28 | class NemoVadDiarizer(SpeakerDiarizationInferencer): 29 | nemo_vad: NemoOfflineVAD 30 | # cfg:DictConfig 31 | 32 | def _build_self(self) -> Any: 33 | self.cfg = OmegaConf.load( 34 | "ml4audio/speaker_tasks/diarization/offline_diarization.yaml" 35 | ) 36 | 37 | @beartype 38 | def _run_vad(self, s_e_a: StartEndArraysNonOverlap) -> NonOverlSegs: 39 | assert len(s_e_a) == 1 40 | SR = self.nemo_vad.sample_rate 41 | s, e, a = s_e_a[0] 42 | return [ 43 | (s + vad_s, s + vad_e) for vad_s, vad_e in self.nemo_vad.predict(audio=a)[0] 44 | ] 45 | 46 | @beartype 47 | def predict(self, s_e_a: StartEndArraysNonOverlap) -> StartEndLabels: 48 | segments = self._run_vad(s_e_a) 49 | _, _, array = s_e_a[0] 50 | 51 | # tmpdir = f"{os.getcwd()}/nemo_diar" 52 | # os.makedirs(tmpdir) 53 | with TemporaryDirectory(prefix="/tmp/nemo_tmp_dir") as tmpdir: 54 | fileid = "audio" 55 | audio_file = f"{tmpdir}/{fileid}.wav" 56 | manifest_file = f"{tmpdir}/manifest.json" 57 | rttm_file = f"{tmpdir}/vad.rttm" 58 | 59 | soundfile.write(audio_file, array, samplerate=16000) 60 | # manifest_file = f"speaker_tasks/tests/resources/input_manifest.json" 61 | write_lines( 62 | rttm_file, 63 | format_rttm_lines( 64 | [(s, e, "NOSPEAKER") for s, e in 
segments], file_id=fileid 65 | ), 66 | ) 67 | 68 | create_manifest(manifest_file, audio_file, rttm_file) 69 | self.cfg.diarizer.manifest_filepath = manifest_file 70 | self.cfg.diarizer.vad.model_path = None 71 | self.cfg.diarizer.speaker_embeddings.model_path = ( 72 | "titanet-large" # TODO: wtf hardcoded! 73 | ) 74 | self.cfg.diarizer.oracle_vad = True 75 | # self.cfg.diarizer.vad.parameters.window_length_in_sec = params["window"] 76 | # self.cfg.diarizer.vad.parameters.shift_length_in_sec = params["step_dur"] 77 | self.cfg.diarizer.out_dir = tmpdir 78 | self.cfg.device = "cpu" 79 | sd_model = ClusteringDiarizer(cfg=self.cfg) 80 | sd_model.diarize() 81 | 82 | rttm_file = next(iter(Path(f"{tmpdir}/pred_rttms").glob("*.rttm"))) 83 | return read_sel_from_rttm(str(rttm_file)) 84 | -------------------------------------------------------------------------------- /speaker-diarization/speaker_diarization/diarization/offline_diarization.yaml: -------------------------------------------------------------------------------- 1 | name: &name "ClusterDiarizer" 2 | 3 | num_workers: 1 4 | sample_rate: 16000 5 | batch_size: 64 6 | 7 | diarizer: 8 | manifest_filepath: ??? # should be provided during the experiment 9 | out_dir: ??? # should be provided during the experiment 10 | oracle_vad: False # use speech activity (VAD) model for timestamps 11 | collar: 0.25 # no-score zone around reference segment boundaries 12 | ignore_overlap: True # ignore overlap segments while scoring 13 | 14 | vad: 15 | # model_path: "vad_multilingual_marblenet" # loading this from ngc-hub seem NOT to be working! 16 | model_path: "speaker_tasks/vad_multilingual_marblenet.nemo" # provide a ".nemo" model path 17 | external_vad_manifest: null # whether to use an external VAD 18 | 19 | parameters: 20 | window_length_in_sec: 0.15 # Window length in sec for VAD context input 21 | shift_length_in_sec: 0.01 # Shift length in sec for generate frame level VAD prediction 22 | smoothing: "median" # False or type of smoothing method (eg: median) 23 | overlap: 0.15 # Overlap ratio for overlapped mean/median smoothing filter 24 | onset: 0.3 # Onset threshold for detecting the beginning and end of a speech 25 | offset: 0.2 # Offset threshold for detecting the end of a speech 26 | pad_onset: 0.1 # Adding durations before each speech segment 27 | pad_offset: 0.1 # Adding durations after each speech segment 28 | min_duration_on: 0.7 # Threshold for small non_speech deletion 29 | min_duration_off: 1.0 # Threshold for short speech segment deletion 30 | filter_speech_first: True 31 | 32 | speaker_embeddings: 33 | model_path: ??? # should be provided during the experiment 34 | parameters: 35 | window_length_in_sec: 1.5 # Window length(s) in sec (floating-point number). Either a number or a list. Ex) 1.5 or [1.5,1.0,0.5] 36 | shift_length_in_sec: 0.75 # Shift length(s) in sec (floating-point number). Either a number or a list. 
Ex) 0.75 or [0.75,0.5,0.25] 37 | multiscale_weights: null 38 | # window_length_in_sec: [1.5,1.0,0.5] # window length in sec 39 | # shift_length_in_sec: [0.75,0.5,0.25] # shift length in sec 40 | # multiscale_weights: [0.33,0.33,0.33] # do multi-scaling 41 | save_embeddings: False # don't save extracted embeddings 42 | 43 | clustering: 44 | parameters: 45 | oracle_num_speakers: False # use a non-oracle setting 46 | max_num_speakers: 50 # max number of speakers for each recording 47 | sparse_search_volume: 100 # values that will be examined with time 48 | enhanced_count_thres: 90 # tilo: this is a crutch/workaround for a not properly working clustering algorithm! 49 | max_rp_threshold: 0.25 # range of p-value search: 0 < p <= threshold 50 | sparse_search: False # use all estimated speakers not sparse list 51 | -------------------------------------------------------------------------------- /speaker-diarization/speaker_diarization/diarization/pyannote_diarizers.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from tempfile import TemporaryDirectory 4 | from typing import Any 5 | 6 | import soundfile 7 | from beartype import beartype 8 | 9 | from ml4audio.audio_utils.audio_segmentation_utils import ( 10 | StartEndArraysNonOverlap, 11 | StartEndLabels, 12 | ) 13 | from ml4audio.speaker_tasks.diarization.speaker_diarization_inferencer import ( 14 | SpeakerDiarizationInferencer, 15 | ) 16 | from ml4audio.speaker_tasks.speaker_embedding_utils import read_sel_from_rttm 17 | from pyannote.audio import Pipeline 18 | 19 | 20 | @dataclass 21 | class PyannoteDiarizer(SpeakerDiarizationInferencer): 22 | pipeline: Pipeline = field(init=False, repr=False) 23 | 24 | def _build_self(self) -> Any: 25 | self.pipeline = Pipeline.from_pretrained( 26 | "pyannote/speaker-diarization@2.1", 27 | use_auth_token=os.environ["PYANNOTE_TOKEN"], 28 | ) 29 | 30 | @beartype 31 | def predict(self, s_e_a: StartEndArraysNonOverlap) -> StartEndLabels: 32 | assert len(s_e_a) == 1 33 | _, _, array = s_e_a[0] 34 | 35 | with TemporaryDirectory(prefix="/tmp/nemo_tmp_dir") as tmpdir: 36 | fileid = "audio" 37 | audio_file = f"{tmpdir}/{fileid}.wav" 38 | rttm_file = f"{tmpdir}/vad.rttm" 39 | 40 | soundfile.write(audio_file, array, samplerate=16000) 41 | diarization = self.pipeline(audio_file) 42 | with open(rttm_file, "w") as f: 43 | diarization.write_rttm(f) 44 | 45 | return read_sel_from_rttm(str(rttm_file)) 46 | -------------------------------------------------------------------------------- /speaker-diarization/speaker_diarization/diarization/speaker_diarization_inferencer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from beartype import beartype 4 | from misc_utils.buildable import Buildable 5 | 6 | from ml4audio.audio_utils.audio_segmentation_utils import ( 7 | StartEndArraysNonOverlap, 8 | StartEndLabels, 9 | ) 10 | 11 | 12 | @dataclass 13 | class SpeakerDiarizationInferencer(Buildable): 14 | @beartype 15 | def predict(self, s_e_a: StartEndArraysNonOverlap) -> StartEndLabels: 16 | raise NotImplementedError 17 | -------------------------------------------------------------------------------- /speaker-diarization/speaker_diarization/nemo_speaker_embedder.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from dataclasses import field, dataclass 3 | from typing import Any 4 | 5 | import torch 6 
| from beartype import beartype 7 | from tqdm import tqdm 8 | 9 | from misc_utils.beartypes import ( 10 | NeList, 11 | NpFloatDim1, 12 | ) 13 | from misc_utils.buildable import Buildable 14 | from misc_utils.buildable_data import BuildableData, SlugStr 15 | from misc_utils.dataclass_utils import UNDEFINED 16 | from misc_utils.prefix_suffix import PrefixSuffix, BASE_PATHES 17 | from misc_utils.processing_utils import iterable_to_batches 18 | from ml4audio.audio_utils.nemo_utils import load_EncDecSpeakerLabelModel 19 | from ml4audio.speaker_tasks.speaker_embedding_utils import SignalEmbedder 20 | from nemo.collections.asr.models import EncDecSpeakerLabelModel 21 | 22 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 23 | 24 | 25 | @beartype 26 | def embed_audio_chunks_with_nemo( 27 | speaker_model: EncDecSpeakerLabelModel, 28 | overlapping_chunks: NeList[NpFloatDim1], 29 | batch_size: int, 30 | ) -> NeList[NpFloatDim1]: 31 | """ 32 | based on: https://github.com/NVIDIA/NeMo/blob/aff169747378bcbcec3fc224748242b36205413f/examples/speaker_tasks/recognition/extract_speaker_embeddings.py 33 | 34 | based on: https://github.com/NVIDIA/NeMo/blob/aff169747378bcbcec3fc224748242b36205413f/nemo/collections/asr/models/clustering_diarizer.py#L329 35 | """ 36 | if batch_size != 1: 37 | raise NotImplementedError("only batch size 1 is supported, don't ask me why!") 38 | speaker_model = speaker_model.to(DEVICE) 39 | speaker_model.eval() 40 | 41 | all_embs = [] 42 | for test_batch in tqdm( 43 | iterable_to_batches(overlapping_chunks, batch_size=batch_size), 44 | desc="embedding with nemo", 45 | ): 46 | audio_tensors = [torch.from_numpy(x).to(DEVICE) for x in test_batch] 47 | audio_signal_len = torch.as_tensor([len(a) for a in audio_tensors]).to(DEVICE) 48 | no_need_for_padding_cause_all_have_same_len = ( 49 | len(set([len(a) for a in test_batch])) == 1 50 | ) 51 | assert no_need_for_padding_cause_all_have_same_len, set( 52 | [len(a) for a in test_batch] 53 | ) 54 | audio_tensor = torch.concat([x.unsqueeze(0) for x in audio_tensors], dim=0) 55 | # probably based on: https://github.com/NVIDIA/NeMo/blob/4f06f3458b3d4d5e8ed3f5174d84e255a526321a/nemo/collections/asr/models/clustering_diarizer.py#L351 56 | with torch.no_grad(): 57 | _, embs = speaker_model.forward( 58 | input_signal=audio_tensor, input_signal_length=audio_signal_len 59 | ) 60 | emb_shape = embs.shape[-1] 61 | embs = embs.view(-1, emb_shape) 62 | all_embs.extend(embs.cpu().detach().numpy()) 63 | 64 | return all_embs 65 | 66 | 67 | @dataclass 68 | class NemoAudioEmbedder(SignalEmbedder): 69 | model_name: str = UNDEFINED 70 | _speaker_model: EncDecSpeakerLabelModel = field(init=False, repr=False) 71 | 72 | base_dir: PrefixSuffix = field( 73 | default_factory=lambda: PrefixSuffix("cache_root", "MODELS/NEMO_MODELS") 74 | ) 75 | 76 | @property 77 | def name(self) -> SlugStr: 78 | return f"{self.model_name}" 79 | 80 | @property 81 | def _is_data_valid(self) -> bool: 82 | return os.path.isfile(self.model_file) 83 | 84 | @property 85 | def model_file(self): 86 | return f"{self.data_dir}/model.nemo" 87 | 88 | def _build_data(self) -> Any: 89 | model = load_EncDecSpeakerLabelModel(self.model_name) 90 | model.save_to(self.model_file) 91 | 92 | def __enter__(self): 93 | self._speaker_model = EncDecSpeakerLabelModel.restore_from( 94 | restore_path=self.model_file 95 | ) 96 | 97 | def __exit__(self, exc_type=None, exc_val=None, exc_tb=None): 98 | del self._speaker_model 99 | 100 | @beartype 101 | def predict(self, arrays: NeList[NpFloatDim1]) -> 
NeList[NpFloatDim1]: 102 | return embed_audio_chunks_with_nemo(self._speaker_model, arrays, batch_size=1) 103 | 104 | 105 | if __name__ == "__main__": 106 | BASE_PATHES["cache_root"] = "/tmp/cache_root" 107 | embedder = NemoAudioEmbedder( 108 | model_name="titanet_large", 109 | ).build() 110 | -------------------------------------------------------------------------------- /speaker-diarization/tests/resources/oLnl1D6owYA.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/speaker-diarization/tests/resources/oLnl1D6owYA.opus -------------------------------------------------------------------------------- /speaker-diarization/tests/resources/oLnl1D6owYA_ref.rttm: -------------------------------------------------------------------------------- 1 | SPEAKER oLnl1D6owYA 1 0.662 9.082 eddy_micah_jr 2 | SPEAKER oLnl1D6owYA 1 11.248 0.401 big_colored_necklaces 3 | SPEAKER oLnl1D6owYA 1 12.371 0.641 guy_on_the_street 4 | SPEAKER oLnl1D6owYA 1 13.715 5.502 white_dude_in_shoppingmall 5 | SPEAKER oLnl1D6owYA 1 19.217 6.287 big_dude_braids 6 | SPEAKER oLnl1D6owYA 1 27.268 25.854 eddy_micah_jr 7 | SPEAKER oLnl1D6owYA 1 55.208 10.275 translator 8 | SPEAKER oLnl1D6owYA 1 66.988 9.012 women_with_headscarf 9 | SPEAKER oLnl1D6owYA 1 87.068 21.202 eddy_micah_jr 10 | SPEAKER oLnl1D6owYA 1 109.133 10.897 dude_in_street_blue_shirt 11 | SPEAKER oLnl1D6owYA 1 121.173 9.946 woman_purple_shirt 12 | SPEAKER oLnl1D6owYA 1 131.119 17.392 woman_orange_sweater 13 | SPEAKER oLnl1D6owYA 1 149.474 13.874 big_colored_necklaces 14 | SPEAKER oLnl1D6owYA 1 164.070 10.192 women_in_street_afro_cut 15 | SPEAKER oLnl1D6owYA 1 174.262 9.905 big_dude_braids 16 | SPEAKER oLnl1D6owYA 1 184.167 12.384 women_in_street_blue_jacket 17 | SPEAKER oLnl1D6owYA 1 198.095 10.875 eddy_micah_jr 18 | SPEAKER oLnl1D6owYA 1 211.220 16.128 thiago_melo 19 | SPEAKER oLnl1D6owYA 1 229.213 6.857 speaker_in_ancient_footage 20 | SPEAKER oLnl1D6owYA 1 237.514 2.275 thiago_melo 21 | SPEAKER oLnl1D6owYA 1 240.852 8.912 another_speaker_in_ancient_footage 22 | SPEAKER oLnl1D6owYA 1 250.386 1.423 another_speaker_in_ancient_footage 23 | SPEAKER oLnl1D6owYA 1 253.975 67.207 thiago_melo 24 | SPEAKER oLnl1D6owYA 1 326.216 16.711 eddy_micah_jr 25 | SPEAKER oLnl1D6owYA 1 344.070 61.179 alima_bawah 26 | SPEAKER oLnl1D6owYA 1 405.250 18.379 eddy_micah_jr 27 | SPEAKER oLnl1D6owYA 1 424.431 46.857 macharia_munene 28 | SPEAKER oLnl1D6owYA 1 472.892 12.130 eddy_micah_jr 29 | SPEAKER oLnl1D6owYA 1 486.827 35.649 macharia_munene 30 | SPEAKER oLnl1D6owYA 1 523.559 13.624 eddy_micah_jr 31 | -------------------------------------------------------------------------------- /speaker_clustering_service/app/__init__.py: -------------------------------------------------------------------------------- 1 | from misc_utils.beartyped_dataclass_patch import ( 2 | beartype_all_dataclasses_of_this_files_parent, 3 | ) 4 | 5 | beartype_all_dataclasses_of_this_files_parent(__file__) 6 | -------------------------------------------------------------------------------- /speaker_clustering_service/app/main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any, Optional, Dict 4 | 5 | import uvicorn 6 | from beartype.door import is_bearable 7 | from fastapi import FastAPI, UploadFile, Form 8 | from misc_utils.dataclass_utils import ( 9 | encode_dataclass, 10 | ) 11 | 12 | from 
ml4audio.audio_utils.audio_segmentation_utils import ( 13 | expand_merge_segments, 14 | merge_short_segments, 15 | ) 16 | from ml4audio.speaker_tasks.speaker_clusterer import UmascanSpeakerClusterer 17 | from ml4audio.service_utils.fastapi_utils import ( 18 | read_uploaded_audio_file, 19 | get_full_model_config, 20 | ) 21 | 22 | DEBUG = os.environ.get("DEBUG", "False").lower() != "false" 23 | if DEBUG: 24 | print("DEBUGGING MODE") 25 | 26 | 27 | app = FastAPI(debug=DEBUG) 28 | 29 | inferencer: Optional[UmascanSpeakerClusterer] = None 30 | 31 | 32 | SR = 16_000 33 | 34 | 35 | def _form_response(file, s_e_labels): 36 | return { 37 | "filename": file.filename, 38 | "labeled_segments": [ 39 | {"start": s, "end": e, "label": l} for s, e, l in s_e_labels 40 | ], 41 | } 42 | 43 | 44 | @app.post("/predict") 45 | async def upload_and_process_audio_file(file: UploadFile, segments: str = Form()): 46 | """ 47 | TODO(tilo): cannot go with normal sync def method, cause: 48 | fastapi wants to run things in multiprocessing-processes -> therefore needs to pickle stuff 49 | some parts of nemo cannot be pickled: "_pickle.PicklingError: Can't pickle " 50 | 51 | # use like this 52 | f = open(audio_file, "rb") 53 | files = { 54 | "file": (f.name, f, "multipart/form-data"), 55 | "segments": ( 56 | None, 57 | json.dumps([(s, e) for s, e, _ in start_end_speaker]), 58 | "application/json", 59 | ), 60 | } 61 | port = 8001 62 | r = requests.post(f"http://localhost:{port}/predict", files=files) 63 | """ 64 | segments = json.loads(segments) 65 | is_bearable(segments, list[list[float]]) # TODO: does not type-narrow mypy! 66 | segments: list[tuple[float, float]] = [(s, e) for s, e in segments] 67 | global inferencer 68 | 69 | audio = await read_uploaded_audio_file(file) 70 | 71 | s_e_times = expand_merge_segments(segments, min_gap_dur=0.7, expand_by=0.1) 72 | s_e_times = merge_short_segments(s_e_times, min_dur=1.5) 73 | s_e_audio = [(s, e, audio[round(s * SR) : round(e * SR)]) for s, e in s_e_times] 74 | assert all((len(a) > 1000 for (s, e), a in s_e_audio)) 75 | 76 | s_e_labels, _ = inferencer.predict(s_e_audio) 77 | 78 | return _form_response(file, s_e_labels) 79 | 80 | 81 | @app.post("/predict_unsegmented") 82 | async def upload_and_process_audio_file_unsegmented(file: UploadFile): 83 | """""" 84 | global inferencer 85 | assert isinstance(inferencer, UmascanSpeakerClusterer) 86 | audio = await read_uploaded_audio_file(file) 87 | dur = float(len(audio)) / SR 88 | s_e_labels, _ = inferencer.predict([((0.0, dur), audio)]) 89 | 90 | return _form_response(file, s_e_labels) 91 | 92 | 93 | @app.get("/get_inferencer_dataclass") 94 | def get_inferencer_dataclass() -> Dict[str, Any]: 95 | global inferencer 96 | if inferencer is not None: 97 | d = encode_dataclass(inferencer) 98 | else: 99 | d = {"response": "no model loaded yet!"} 100 | return d 101 | 102 | 103 | @app.get("/inferencer_config") 104 | def get_model_config() -> Dict[str, Any]: 105 | global inferencer 106 | if inferencer is not None: 107 | d = get_full_model_config(inferencer) 108 | 109 | else: 110 | d = {"response": "no model loaded yet!"} 111 | return d 112 | 113 | 114 | @app.on_event("startup") 115 | def startup_event(): 116 | global inferencer 117 | model_name = "ecapa_tdnn" # TODO(tilo): try out titanet! 
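# the clusterer is built in FastAPI's startup hook (not at import time) so that the heavy NeMo speaker-embedding model is downloaded and initialised once per worker process, before the first request is served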
118 | inferencer = UmascanSpeakerClusterer(model_name=model_name, metric="cosine").build() 119 | 120 | 121 | if __name__ == "__main__": 122 | 123 | uvicorn.run( 124 | "app.main:app", 125 | host="127.0.0.1", 126 | port=2700, 127 | reload=True if DEBUG else False 128 | # log_level="debug" 129 | ) 130 | -------------------------------------------------------------------------------- /speaker_clustering_service/build_model_in_docker.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | 3 | from misc_utils.dataclass_utils import ( 4 | to_dict, 5 | ) 6 | 7 | from ml4audio.speaker_tasks.speaker_clusterer import UmascanSpeakerClusterer 8 | 9 | if __name__ == "__main__": 10 | """ 11 | maybe it acts as a kind of sanity/integration test?? 12 | it also downloads the model, somewhere into the .cache folder 13 | """ 14 | 15 | inferencer = UmascanSpeakerClusterer( 16 | model_name="ecapa_tdnn", metric="cosine" 17 | ).build() 18 | pprint(to_dict(inferencer)) 19 | -------------------------------------------------------------------------------- /speaker_clustering_service/docker/fastapi_cpu/Dockerfile: -------------------------------------------------------------------------------- 1 | # see: https://github.com/docker/for-mac/issues/2155 2 | # global arg -> must declare it before every FROM or it will be local 3 | ARG MODEL_IMAGE 4 | # FROM ${MODEL_IMAGE} as model_image 5 | 6 | FROM python:3.9-bullseye AS dependencies 7 | WORKDIR /code 8 | ENV APT_INSTALL="apt-get install -y --no-install-recommends" 9 | 10 | RUN apt-get update && \ 11 | DEBIAN_FRONTEND=noninteractive $APT_INSTALL \ 12 | build-essential \ 13 | ca-certificates \ 14 | wget \ 15 | git \ 16 | g++ \ 17 | cmake \ 18 | vim \ 19 | # for testing \ 20 | # libsndfile 21 | libsndfile1-dev \ 22 | # portaudio 23 | portaudio19-dev python3-pyaudio \ 24 | # ffmpeg 25 | ffmpeg libavcodec-extra \ 26 | # sox \ 27 | sox libsox-dev && \ 28 | apt-get clean && \ 29 | apt-get -y autoremove && \ 30 | rm -rf /var/lib/apt/lists/* 31 | 32 | 33 | ENV PATH="/venv/bin:$PATH" 34 | ENV PIP_INSTALL="/venv/bin/pip install --no-cache-dir --upgrade" 35 | 36 | RUN apt-get update && apt-get install -y python3-venv 37 | RUN python3 -m venv /venv && $PIP_INSTALL pip packaging setuptools 38 | RUN $PIP_INSTALL torchaudio@https://download.pytorch.org/whl/cpu/torchaudio-0.11.0%2Bcpu-cp39-cp39-linux_x86_64.whl 39 | RUN $PIP_INSTALL Cython 40 | 41 | # to trigger re-run of following, "disable" caching, see: https://stackoverflow.com/questions/35134713/disable-cache-for-specific-run-commands 42 | # use with: --build-arg CACHEBUST=$(date +%s) 43 | 44 | ARG CACHEBUST=-1 45 | RUN echo "$CACHEBUST" 46 | 47 | COPY requirements.txt requirements.txt 48 | RUN $PIP_INSTALL -r requirements.txt 49 | 50 | # ================================================================== 51 | # BUILD MODELS - stage 52 | # ------------------------------------------------------------------ 53 | 54 | FROM dependencies AS build_models 55 | ENV CACHE_ROOT="/model" 56 | # COPY --from=model_image .
/model 57 | COPY build_model_in_docker.py /code/build_model_in_docker.py 58 | RUN python /code/build_model_in_docker.py 59 | # 60 | RUN rm -rf /venv/lib/python3.9/site-packages/sklearn/ensemble 61 | RUN rm -rf /venv/lib/python3.9/site-packages/pynini.libs 62 | RUN rm -rf /root/.cache/pip 63 | RUN rm -rf /root/.cache/matplotlib 64 | 65 | # ================================================================== 66 | # PRODUCTION - stage 67 | # ------------------------------------------------------------------ 68 | FROM python:3.9.13-slim-buster AS production 69 | LABEL maintainer="Tilo Himmelsbach" 70 | WORKDIR /code 71 | ENV PATH="/venv/bin:$PATH" 72 | ENV APT_INSTALL="apt-get install -y --no-install-recommends" 73 | 74 | RUN apt-get update && \ 75 | DEBIAN_FRONTEND=noninteractive $APT_INSTALL \ 76 | # libsndfile TODO: currently asr_logits_inferencer uses librosa to resample!! 77 | libsndfile1-dev \ 78 | # portaudio 79 | portaudio19-dev python3-pyaudio \ 80 | # ffmpeg 81 | ffmpeg libavcodec-extra \ 82 | # sox \ 83 | sox libsox-dev && \ 84 | 85 | apt-get clean && \ 86 | apt-get -y autoremove && \ 87 | rm -rf /var/lib/apt/lists/* 88 | 89 | # maybe for better docker-caching copy from model-image here, this only works if build_models-stage does not modify the models!! does it? well it could! 90 | #COPY --from=model_image . /model 91 | # COPY --from=build_models /model /model 92 | COPY --from=build_models /venv /venv 93 | COPY --from=build_models /root/.cache /root/.cache 94 | 95 | ENV HF_DATASETS_OFFLINE=1 96 | ENV TRANSFORMERS_OFFLINE=1 97 | 98 | # PYTHONFAULTHANDLER TODO: what is this actually for? 99 | ENV PYTHONFAULTHANDLER=1 100 | ENV CACHE_ROOT="/model" 101 | # ENV JINA_MP_START_METHOD=spawn 102 | 103 | COPY app /code/app 104 | 105 | CMD ["/bin/bash", "-c", "source /venv/bin/activate && \ 106 | uvicorn app.main:app --host 0.0.0.0 --port 8000"] -------------------------------------------------------------------------------- /speaker_clustering_service/readme.md: -------------------------------------------------------------------------------- 1 | # speaker clustering service 2 | ### manually build docker-image 3 | ```commandline 4 | IMAGE=selmaproject/iais-speaker-clustering-services:latest 5 | DOCKER_BUILDKIT=1 docker build -f docker/fastapi_cpu/Dockerfile -t $IMAGE . 6 | docker run -it --rm --shm-size 8G -p 8000:8000 $IMAGE bash 7 | 8 | # for debugging 9 | docker run -it -v ${PWD}:/code -v $CODE_DIR/misc-utils:/code/misc-utils -v $CODE_DIR/ml4audio:/code/ml4audio -v $CODE_DIR/misc-utils:/code/misc-utils -p 8001:8000 --rm $IMAGE bash 10 | 11 | export PYTHONPATH=/code:/code/ml4audio && uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload 12 | 13 | curl -F 'file=@path/to/local/file' localhost:8000/predict_unsegmented 14 | # "production" 15 | docker run -p 8001:8000 --rm $IMAGE 16 | ``` 17 | -------------------------------------------------------------------------------- /speaker_clustering_service/requirements.txt: -------------------------------------------------------------------------------- 1 | ml4audio@git+https://github.com/SELMA-project/ml4audio@main#egg=ml4audio 2 | # if pushed changes to ml4audio do cache-bust here! -> docker (buildkit) does not reinstall if this file is not changing! 3 | # datasets # why?
4 | python-levenshtein 5 | beartype==0.11.0 6 | numba==0.53.1 7 | librosa 8 | 9 | fastapi #==0.78.0 10 | Flask #==2.1.2 11 | icdiff 12 | torchaudio@https://download.pytorch.org/whl/cpu/torchaudio-0.11.0%2Bcpu-cp39-cp39-linux_x86_64.whl 13 | tqdm 14 | transformers==4.22.1 15 | python-multipart 16 | uvicorn[standard] 17 | omegaconf 18 | nemo_toolkit[asr]==1.11.0 19 | # wandb # WTF!! nemo wants it!! 20 | numpy==1.22.0 # see: https://github.com/scikit-learn-contrib/hdbscan/issues/457 21 | hdbscan==0.8.28 22 | joblib==1.1.0 23 | umap_learn==0.5.3 -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # sys.path.append(os.path.dirname(__file__)) # TODO: WTF! this is a hack! 2 | from warnings import filterwarnings 3 | 4 | from beartype.roar import BeartypeDecorHintPep585DeprecationWarning 5 | 6 | from ml4audio.audio_utils.test_utils import get_test_vocab, TEST_RESOURCES 7 | 8 | filterwarnings("ignore", category=BeartypeDecorHintPep585DeprecationWarning) 9 | 10 | import pytest 11 | 12 | 13 | @pytest.fixture 14 | def vocab(): 15 | return get_test_vocab() 16 | 17 | 18 | @pytest.fixture 19 | def arpa_file(): 20 | return f"{TEST_RESOURCES}/lm.arpa" 21 | -------------------------------------------------------------------------------- /tests/test_arpa_from_corpus.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | from data_io.readwrite_files import read_lines 4 | from misc_utils.buildable import BuildableList 5 | from misc_utils.prefix_suffix import PrefixSuffix, BASE_PATHES 6 | from ml4audio.text_processing.asr_text_cleaning import ( 7 | VocabCasingAwareTextCleaner, 8 | Casing, 9 | ) 10 | from ml4audio.text_processing.kenlm_arpa import ArpaArgs, ArpaBuilder 11 | from ml4audio.text_processing.word_based_text_corpus import ( 12 | WordBasedLMCorpus, 13 | RglobRawCorpus, 14 | ) 15 | from conftest import TEST_RESOURCES 16 | 17 | 18 | def test_arpa_from_corpus(vocab): 19 | test_corpus_dir = TEST_RESOURCES 20 | normalizer = VocabCasingAwareTextCleaner( 21 | casing=Casing.upper, text_cleaner_name="en", letter_vocab=vocab 22 | ) 23 | 24 | with tempfile.TemporaryDirectory() as cache_base: 25 | BASE_PATHES["tmp"] = cache_base 26 | cache_base = PrefixSuffix("tmp", "") 27 | 28 | arpa_args = ArpaArgs( 29 | order=5, 30 | prune="|".join(str(k) for k in [0, 8, 16]), 31 | ) 32 | 33 | lm_corpus = WordBasedLMCorpus( 34 | name="test", 35 | cache_base=cache_base, 36 | raw_corpora=BuildableList[RglobRawCorpus]( 37 | [ 38 | RglobRawCorpus( 39 | cache_base=cache_base, 40 | corpus_dir=test_corpus_dir, 41 | file_pattern="*corpus.txt", 42 | ) 43 | ] 44 | ), 45 | transcript_cleaner=normalizer, 46 | ) 47 | 48 | arpa_builder = ArpaBuilder( 49 | cache_base=cache_base, 50 | arpa_args=arpa_args, 51 | corpus=lm_corpus, 52 | ) 53 | 54 | arpa_builder.build() 55 | some_lines = list(read_lines(arpa_builder.arpa_filepath, limit=20)) 56 | # fmt: off 57 | expected_lines=['', '\\data\\', 'ngram 1=669', 'ngram 2=2', 'ngram 3=0', 'ngram 4=0', 'ngram 5=0', '', '\\1-grams:', '-3.1839507\t\t0', '0\t\t-0.03075325', '-1.26055\t\t0', '-1.8272647\tA\t0', '-2.8700492\tMYTH\t0', '-1.8468318\tIS\t0', '-3.0925605\tFANCIFUL\t0', '-3.0925605\tEXPLANATION\t0', '-1.4614682\tOF\t-0.07636787', '-2.8700492\tGIVEN\t0', '-3.0925605\tPHENOMENON\t0'] 58 | # fmt: on 59 | assert ( 60 | expected_lines == some_lines 61 | ) # TODO: not sure how to check arpa-file for validity 62 | 
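# Hedged usage sketch (not part of the original test): once ArpaBuilder has produced the
# ARPA file, it could be handed to pyctcdecode (the library this repo already uses for
# LM-fused CTC decoding) to build a beam-search decoder. `build_ctcdecoder` is
# pyctcdecode's public helper; `logits` below is a stand-in for a (time, vocab)-shaped
# acoustic-model output and is purely illustrative.
#
#   from pyctcdecode import build_ctcdecoder
#   decoder = build_ctcdecoder(
#       labels=vocab,                                      # letter vocabulary, e.g. the `vocab` fixture used above
#       kenlm_model_path=str(arpa_builder.arpa_filepath),  # ARPA file built by ArpaBuilder
#   )
#   text = decoder.decode(logits)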
-------------------------------------------------------------------------------- /urdu_asr/readme.md: -------------------------------------------------------------------------------- 1 | # evaluating some Urdu-ASR-models from huggingface-hub 2 | ## TL;DR 3 | * this one looks good: https://huggingface.co/anuragshas/wav2vec2-xls-r-300m-ur-cv9-with-lm 4 | * model was trained on `cv-corpus-9.0` but for fun I also evaluated it on `cv-corpus-11.0` 5 | 6 | ## evaluated on first 1k samples of common-voice test set 7 | 8 | corpora \ service | anuragshas_wav2vec2-xls-r-300m-ur-cv9-with-lm-greedy 9 | --- | --- 10 | cv-corpus-11.0-2022-09-21-ur-train-clean | 31.6% 11 | cv-corpus-11.0-2022-09-21-ur-test-clean | 37.5% 12 | cv-corpus-9.0-2022-04-27-ur-train-clean | 15.0% 13 | cv-corpus-9.0-2022-04-27-ur-test-clean | 35.7% 14 | 15 | ## evaluated on first 1k samples of common-voice test set 16 | ### service - corpora, cased wer (same reference within rows&cols) 17 | service \ corpora | cv-corpus-11.0-2022-09-21-ur-test-clean 18 | --- | --- 19 | Maniac_wav2vec2-xls-r-urdu-greedy | 66.0% 20 | Maniac_wav2vec2-xls-r-60-urdu-greedy | 79.0% 21 | kingabzpro_wav2vec2-urdu-greedy | 50.8% 22 | kingabzpro_wav2vec2-large-xls-r-300m-Urdu-greedy | 51.6% 23 | kingabzpro_wav2vec2-60-urdu-greedy | 53.0% 24 | kingabzpro_wav2vec2-60-Urdu-V8-greedy | 50.7% 25 | anuragshas_wav2vec2-large-xls-r-300m-ur-cv8-greedy | 51.7% 26 | anuragshas_wav2vec2-xls-r-300m-ur-cv9-with-lm-greedy | 34.4% 27 | 28 | 29 | ## common-voice urdu datasets 30 | * version 9: `4.2`/`3.4` hours train/test, thats `3674` samples in train-set 31 | * version 11: `4.9`/`3.8` hours train/test 32 | 33 | ## looks like this 34 | ![image](some_urdu.png) -------------------------------------------------------------------------------- /urdu_asr/some_urdu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/urdu_asr/some_urdu.png -------------------------------------------------------------------------------- /whisper-streaming/README.md: -------------------------------------------------------------------------------- 1 | # live (streaming capable) ASR via Whisper 2 | * https://github.com/collabora/WhisperLive 3 | * https://github.com/davabase/whisper_real_time 4 | * https://github.com/JonathanFly/faster-whisper-livestream-translator 5 | * https://github.com/ufal/whisper_streaming 6 | ```commandline 7 | 8 | conda activate py39_torch2 9 | ENV_NAME=whisper_live 10 | python -m venv ${ENVS_PATH}/${ENV_NAME} --system-site-packages 11 | source ${ENVS_PATH}/$ENV_NAME/bin/activate 12 | 13 | ``` -------------------------------------------------------------------------------- /whisper-streaming/requirements.txt: -------------------------------------------------------------------------------- 1 | faster-whisper -------------------------------------------------------------------------------- /whisper-streaming/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import find_packages 3 | from setuptools import setup 4 | 5 | 6 | def req_file(filename, folder="./"): 7 | with open(os.path.join(folder, filename), encoding="utf-8") as f: 8 | content = f.readlines() 9 | # you may also want to remove whitespace characters 10 | # Example: `\n` at the end of each line 11 | return [x.strip() for x in content] 12 | 13 | 14 | install_requires = req_file("requirements.txt") 15 | 16 | with 
open("README.md") as f: 17 | readme = f.read() 18 | 19 | 20 | setup( 21 | name="whisper-streaming", 22 | version="0.1", 23 | author="Tilo Himmelsbach", 24 | author_email="dertilo@gmail.com", 25 | packages=find_packages(include=["whisper_streaming*"]), 26 | license="MIT License", 27 | long_description=readme, 28 | install_requires=install_requires, 29 | python_requires=">=3.9", 30 | ) 31 | -------------------------------------------------------------------------------- /whisper-streaming/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/whisper-streaming/tests/__init__.py -------------------------------------------------------------------------------- /whisper-streaming/tests/conftest.py: -------------------------------------------------------------------------------- 1 | from warnings import filterwarnings 2 | 3 | from beartype.roar import BeartypeDecorHintPep585DeprecationWarning 4 | 5 | from misc_utils.prefix_suffix import BASE_PATHES 6 | from ml4audio.audio_utils.test_utils import ( 7 | get_test_cache_base, 8 | TEST_RESOURCES, 9 | ) 10 | 11 | filterwarnings("ignore", category=BeartypeDecorHintPep585DeprecationWarning) 12 | 13 | from data_io.readwrite_files import read_lines 14 | import pytest 15 | 16 | cache_base = get_test_cache_base() 17 | BASE_PATHES["cache_root"] = cache_base 18 | 19 | 20 | @pytest.fixture 21 | def librispeech_ref(): 22 | ref_txt = ( 23 | f"{TEST_RESOURCES}/LibriSpeech_dev-other_116_288046_116-288046-0011_ref.txt" 24 | ) 25 | raw_ref = next(iter(read_lines(ref_txt))) 26 | return raw_ref 27 | 28 | 29 | @pytest.fixture 30 | def librispeech_audio_file(): 31 | return f"{TEST_RESOURCES}/LibriSpeech_dev-other_116_288046_116-288046-0011.opus" 32 | -------------------------------------------------------------------------------- /whisper-streaming/tests/resources/LibriSpeech_dev-other_116_288046_116-288046-0011.opus: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/whisper-streaming/tests/resources/LibriSpeech_dev-other_116_288046_116-288046-0011.opus -------------------------------------------------------------------------------- /whisper-streaming/tests/resources/LibriSpeech_dev-other_116_288046_116-288046-0011_ref.txt: -------------------------------------------------------------------------------- 1 | NOT HAVING THE COURAGE OR THE INDUSTRY OF OUR NEIGHBOUR WHO WORKS LIKE A BUSY BEE IN THE WORLD OF MEN AND BOOKS SEARCHING WITH THE SWEAT OF HIS BROW FOR THE REAL BREAD OF LIFE WETTING THE OPEN PAGE BEFORE HIM WITH HIS TEARS PUSHING INTO THE WE HOURS OF THE NIGHT HIS QUEST ANIMATED BY THE FAIREST OF ALL LOVES THE LOVE OF TRUTH WE EASE OUR OWN INDOLENT CONSCIENCE BY CALLING HIM NAMES -------------------------------------------------------------------------------- /whisper-streaming/tests/test_whisper_streaming.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from beartype import beartype 4 | 5 | from ml4audio.asr_inference.faster_whisper_inferencer import ( 6 | FasterWhisperArray2SegmentedTranscripts, 7 | FasterWhisperArgs, 8 | ) 9 | from ml4audio.asr_inference.inference import StartEndTextsNonOverlap 10 | from ml4audio.audio_utils.overlap_array_chunker import ( 11 | OverlapArrayChunker, 12 | ) 13 | from 
ml4audio.audio_utils.audio_io import audio_messages_from_file 14 | from ml4audio.text_processing.asr_metrics import calc_cer 15 | from ml4audio.text_processing.asr_text_cleaning import ( 16 | VocabCasingAwareTextCleaner, 17 | Casing, 18 | ) 19 | from ml4audio.text_processing.pretty_diff import smithwaterman_aligned_icdiff 20 | from whisper_streaming.whisper_streaming import ( 21 | WhisperStreamer, 22 | concat_transcript, 23 | OverlappingSegment, 24 | accumulate_transcript, 25 | ) 26 | 27 | 28 | @pytest.mark.parametrize( 29 | "name,step_dur,window_dur,max_CER,num_responses_expected", 30 | [ 31 | # fmt: off 32 | # one might be tempted to interpret those CER-values, but I think there is not pattern here, everything around 5% or lower is acceptable, 3% is NOT better than 5%! its just a very-small+very instable "base-whisper"-model! 33 | ("non-overlapping",4.0, 4.0, 0.045, 7), 34 | ("good-overlap",2.0, 4.0, 0.042, 12), 35 | ("big-overlap",1.0, 4.0, 0.05, 22), 36 | # fmt: on 37 | ], 38 | ) 39 | def test_whisper_streaming( 40 | name: str, 41 | librispeech_audio_file: str, 42 | librispeech_ref: str, 43 | step_dur: float, 44 | window_dur: float, 45 | max_CER: float, 46 | num_responses_expected: int, 47 | ): 48 | inferencer = FasterWhisperArray2SegmentedTranscripts( 49 | model_name="base", whisper_args=FasterWhisperArgs(language="en") 50 | ) 51 | inferencer.build() 52 | SR = inferencer.sample_rate 53 | asr_input = list( 54 | audio_messages_from_file(librispeech_audio_file, SR, chunk_duration=0.1) 55 | ) 56 | assert asr_input[-1].end_of_signal 57 | audio_signal = np.concatenate([ac.array for ac in asr_input]) 58 | wav_length = 393920 59 | opus_is_alittle_longer = 70 60 | audio_len = audio_signal.shape[0] 61 | print(f"audio-dur: {audio_len/16000}") 62 | assert audio_len == wav_length + opus_is_alittle_longer 63 | 64 | streaming_asr: WhisperStreamer = WhisperStreamer( 65 | asr_inferencer=inferencer, 66 | audio_bufferer=OverlapArrayChunker( 67 | chunk_size=int(window_dur * SR), 68 | # minimum_chunk_size=int(1 * SR), # one second! 
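# chunk_size / min_step_size: emit a window of window_dur seconds every step_dur seconds, so consecutive windows overlap by (window_dur - step_dur) seconds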
69 | min_step_size=int(step_dur * SR), 70 | # max_step_size=int(max_step_dur * SR) if max_step_dur is not None else None, 71 | ), 72 | overwrite_last_k_words=3, 73 | ) 74 | streaming_asr.build() 75 | transcript: str = "" 76 | num_responses = 0 77 | with streaming_asr: 78 | for inpt in asr_input: 79 | for overlap_segment, new_segments in streaming_asr.handle_inference_input( 80 | inpt 81 | ): 82 | num_responses += 1 83 | transcript = accumulate_transcript( 84 | overlap_segment, new_segments, transcript 85 | ) 86 | print(f"{overlap_segment=}###{new_segments=}") 87 | assert " " not in transcript 88 | hyp = transcript 89 | 90 | cleaner = VocabCasingAwareTextCleaner( 91 | casing=Casing.upper, 92 | text_cleaner_name="en", 93 | letter_vocab=list(set(librispeech_ref)), 94 | ) 95 | hyp = cleaner(hyp) 96 | ref = librispeech_ref 97 | print(smithwaterman_aligned_icdiff(ref, hyp)) 98 | cer = calc_cer([ref], [hyp]) 99 | print(f"{name}: {step_dur=},{window_dur=},{cer=}") 100 | assert cer <= max_CER 101 | assert num_responses_expected == num_responses, f"{num_responses=}" 102 | -------------------------------------------------------------------------------- /whisper-streaming/whisper_streaming/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SELMA-project/ml4audio/4d771c3267f5ac9c8b3f59b9435f260e1f946083/whisper-streaming/whisper_streaming/__init__.py --------------------------------------------------------------------------------