├── speechcatcher
│   ├── __init__.py
│   ├── decoder
│   │   └── __init__.py
│   ├── model
│   │   ├── __pycache__
│   │   │   ├── ctc.cpython-312.pyc
│   │   │   ├── __init__.cpython-312.pyc
│   │   │   ├── checkpoint_loader.cpython-312.pyc
│   │   │   └── espnet_asr_model.cpython-312.pyc
│   │   ├── layers
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   ├── convolution.cpython-312.pyc
│   │   │   │   ├── feed_forward.cpython-312.pyc
│   │   │   │   ├── normalization.cpython-312.pyc
│   │   │   │   └── positional_encoding.cpython-312.pyc
│   │   │   ├── __init__.py
│   │   │   ├── normalization.py
│   │   │   ├── feed_forward.py
│   │   │   ├── convolution.py
│   │   │   └── positional_encoding.py
│   │   ├── decoder
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   ├── decoder_layer.cpython-312.pyc
│   │   │   │   └── transformer_decoder.cpython-312.pyc
│   │   │   ├── __init__.py
│   │   │   └── decoder_layer.py
│   │   ├── encoder
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   ├── subsampling.cpython-312.pyc
│   │   │   │   ├── contextual_block_encoder_layer.cpython-312.pyc
│   │   │   │   └── contextual_block_transformer_encoder.cpython-312.pyc
│   │   │   ├── __init__.py
│   │   │   └── subsampling.py
│   │   ├── frontend
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   └── stft_frontend.cpython-312.pyc
│   │   │   └── __init__.py
│   │   ├── attention
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   └── multi_head_attention.cpython-312.pyc
│   │   │   └── __init__.py
│   │   └── __init__.py
│   ├── beam_search
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-312.pyc
│   │   │   ├── scorers.cpython-312.pyc
│   │   │   ├── beam_search.cpython-312.pyc
│   │   │   ├── hypothesis.cpython-312.pyc
│   │   │   └── ctc_prefix_score_full.cpython-312.pyc
│   │   ├── __init__.py
│   │   └── hypothesis.py
│   ├── decode_kaldidir.py
│   ├── vosk_test_client.py
│   ├── compute_wer.py
│   └── websocket_demo.html
├── speechcatcher_de_live.gif
├── docs
│   ├── SLT2021_tsunoo
│   │   ├── 2006.14941
│   │   └── contextual_block.eps
│   ├── implementation
│   │   ├── weight-loading.md
│   │   └── root-cause-analysis.md
│   ├── analysis
│   │   ├── initial-comparison.md
│   │   └── streaming-analysis.md
│   └── debugging
│       └── investigation.md
├── requirements.txt
├── .gitignore
├── speechcatcher_server.service
├── .github
│   └── workflows
│       └── python-package-test.yml
├── LICENSE
├── setup.py
├── tests
│   ├── test_fp16_cuda.py
│   ├── test_fp16_debug.py
│   ├── test_beam_trace_chunks.py
│   ├── test_single_chunk.py
│   ├── test_bbd_debug.py
│   ├── test_batch_mode.py
│   ├── test_espnet_runtime_config.py
│   ├── test_encoder_difference.py
│   ├── test_espnet_decoding_start.py
│   ├── test_espnet_streaming_beams.py
│   ├── test_our_transcribe.py
│   ├── test_full_transcription_debug.py
│   ├── test_espnet_beam_first_step.py
│   ├── test_encoder_buffer_debug.py
│   ├── test_state_preservation.py
│   ├── test_ctc_timing.py
│   ├── test_streaming_chunks.py
│   ├── test_final_chunk_debug.py
│   ├── test_espnet_full.py
│   ├── test_beam_search_trace.py
│   ├── test_espnet_beam_search_trace.py
│   ├── test_full_comparison.py
│   ├── test_compare_beam_config.py
│   ├── test_espnet_transcribe.py
│   ├── test_espnet_final_chunk.py
│   ├── test_bbd_state.py
│   ├── test_token_scoring_chunk5.py
│   ├── test_decoder_scores_debug.py
│   ├── test_waveform_buffering.py
│   ├── test_step_by_step_trace.py
│   ├── test_model_mode.py
│   ├── test_combined_scores.py
│   ├── test_normalization.py
│   ├── test_multi_chunk_comparison.py
│   ├── test_normalization_only.py
│   ├── model
│   │   └── test_manual.py
│   ├── test_espnet_vs_ours.py
│   └── test_exact_score_comparison.py
└── pyproject.toml

/speechcatcher/__init__.py:
--------------------------------------------------------------------------------
 1 | 
 2 | 
--------------------------------------------------------------------------------
/speechcatcher/decoder/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/speechcatcher_de_live.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/speechcatcher-asr/speechcatcher/HEAD/speechcatcher_de_live.gif
--------------------------------------------------------------------------------
/docs/SLT2021_tsunoo/2006.14941:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/speechcatcher-asr/speechcatcher/HEAD/docs/SLT2021_tsunoo/2006.14941
--------------------------------------------------------------------------------
/docs/SLT2021_tsunoo/contextual_block.eps:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/speechcatcher-asr/speechcatcher/HEAD/docs/SLT2021_tsunoo/contextual_block.eps
--------------------------------------------------------------------------------
/speechcatcher/model/frontend/__init__.py:
--------------------------------------------------------------------------------
 1 | """Frontend modules for audio feature extraction."""
 2 | 
 3 | from speechcatcher.model.frontend.stft_frontend import STFTFrontend
 4 | 
 5 | __all__ = ["STFTFrontend"]
 6 | 
--------------------------------------------------------------------------------
/speechcatcher/model/__init__.py:
--------------------------------------------------------------------------------
 1 | """Model modules for speechcatcher."""
 2 | 
 3 | from speechcatcher.model.ctc import CTC
 4 | from speechcatcher.model.espnet_asr_model import ESPnetASRModel
 5 | 
 6 | __all__ = [
 7 |     "CTC",
 8 |     "ESPnetASRModel",
 9 | ]
10 | 
--------------------------------------------------------------------------------
/speechcatcher/model/decoder/__init__.py:
--------------------------------------------------------------------------------
 1 | """Decoder modules for streaming ASR."""
 2 | 
 3 | from speechcatcher.model.decoder.decoder_layer import TransformerDecoderLayer
 4 | from speechcatcher.model.decoder.transformer_decoder import TransformerDecoder
 5 | 
 6 | __all__ = [
 7 |     "TransformerDecoderLayer",
 8 |     "TransformerDecoder",
 9 | ]
10 | 
--------------------------------------------------------------------------------
/speechcatcher/model/attention/__init__.py:
--------------------------------------------------------------------------------
 1 | """Attention mechanisms for Transformer and Conformer models."""
 2 | 
 3 | from speechcatcher.model.attention.multi_head_attention import (
 4 |     MultiHeadedAttention,
 5 |     RelPositionMultiHeadedAttention,
 6 | )
 7 | 
 8 | __all__ = [
 9 |     "MultiHeadedAttention",
10 |     "RelPositionMultiHeadedAttention",
11 | ]
12 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | setuptools
 2 | torch
 3 | torchaudio
 4 | ffmpeg-python
 5 | espnet_streaming_decoder @ git+https://github.com/speechcatcher-asr/espnet_streaming_decoder
 6 | espnet_model_zoo @ git+https://github.com/speechcatcher-asr/espnet_model_zoo
 7 | soundfile
 8 | six
 9 | pyaudio
10 | python_speech_features
11 | scipy
12 | tqdm
13 | somajo
14 | websockets>=14.2
15 | # CRITICAL: sentencepiece 0.2.1+ has Python 3.13 wheels
16 | sentencepiece>=0.2.1
17 | 
--------------------------------------------------------------------------------
/speechcatcher/model/encoder/__init__.py:
--------------------------------------------------------------------------------
 1 | """Encoder modules for streaming ASR."""
 2 | 
 3 | from speechcatcher.model.encoder.contextual_block_encoder_layer import ContextualBlockEncoderLayer
 4 | from speechcatcher.model.encoder.contextual_block_transformer_encoder import ContextualBlockTransformerEncoder
 5 | from speechcatcher.model.encoder.subsampling import Conv2dSubsampling
 6 | 
 7 | __all__ = [
 8 |     "Conv2dSubsampling",
 9 |     "ContextualBlockEncoderLayer",
10 |     "ContextualBlockTransformerEncoder",
11 | ]
12 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Large video files
 2 | *.mp4
 3 | 
 4 | # Test output files
 5 | *.mp4.txt
 6 | *.mp4.json
 7 | *.wav
 8 | 
 9 | # Python cache
10 | __pycache__/
11 | *.pyc
12 | *.pyo
13 | *.pyd
14 | .Python
15 | *.so
16 | *.egg
17 | *.egg-info/
18 | dist/
19 | build/
20 | *.whl
21 | 
22 | # IDE
23 | .vscode/
24 | .idea/
25 | *.swp
26 | *.swo
27 | *~
28 | 
29 | # Jupyter
30 | .ipynb_checkpoints/
31 | *.ipynb
32 | 
33 | # Temporary files
34 | .tmp/
35 | tmp/
36 | temp/
37 | 
38 | # Model files (if large)
39 | *.pt
40 | *.pth
41 | *.bin
42 | *.onnx
43 | 
44 | # Logs
45 | *.log
46 | 
--------------------------------------------------------------------------------
/speechcatcher_server.service:
--------------------------------------------------------------------------------
 1 | [Unit]
 2 | Description=Speechcatcher Server
 3 | After=network.target
 4 | 
 5 | [Service]
 6 | User=me
 7 | Group=me
 8 | WorkingDirectory=/home/me/speechcatcher_env
 9 | Environment="VIRTUAL_ENV=/home/me/speechcatcher_env"
10 | Environment="PATH=/home/me/speechcatcher_env/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
11 | ExecStart=/home/me/speechcatcher_env/bin/speechcatcher_server --host 127.0.0.1 --port 2700 --beamsize 3 --vosk-output-format --finalize-update-iters 5 --max_partial_iters 512
12 | Restart=always
13 | RestartSec=5
14 | 
15 | [Install]
16 | WantedBy=multi-user.target
17 | 
--------------------------------------------------------------------------------
/speechcatcher/model/layers/__init__.py:
--------------------------------------------------------------------------------
 1 | """Neural network layers for Transformer and Conformer models."""
 2 | 
 3 | from speechcatcher.model.layers.feed_forward import PositionwiseFeedForward
 4 | from speechcatcher.model.layers.positional_encoding import (
 5 |     PositionalEncoding,
 6 |     RelPositionalEncoding,
 7 |     StreamPositionalEncoding,
 8 | )
 9 | from speechcatcher.model.layers.convolution import ConvolutionModule
10 | from speechcatcher.model.layers.normalization import LayerNorm
11 | 
12 | __all__ = [
13 |     "PositionwiseFeedForward",
14 |     "PositionalEncoding",
15 |     "RelPositionalEncoding",
16 |     "StreamPositionalEncoding",
17 |     "ConvolutionModule",
18 |     "LayerNorm",
19 | ]
20 | 
--------------------------------------------------------------------------------
/speechcatcher/beam_search/__init__.py:
--------------------------------------------------------------------------------
 1 | """Beam search modules for streaming ASR."""
 2 | 
 3 | from speechcatcher.beam_search.beam_search import (
 4 |     BeamSearch,
 5 |     BlockwiseSynchronousBeamSearch,
 6 |     create_beam_search,
 7 | )
 8 | from speechcatcher.beam_search.hypothesis import (
 9 |     BeamState,
10 |     Hypothesis,
11 |     create_initial_hypothesis,
12 | )
13 | from speechcatcher.beam_search.scorers import (
14 |     CTCPrefixScorer,
15 |     DecoderScorer,
16 |     ScorerInterface,
17 | )
18 | 
19 | __all__ = [
20 |     "BeamSearch",
21 |     "BlockwiseSynchronousBeamSearch",
22 |     "create_beam_search",
23 |     "BeamState",
24 |     "Hypothesis",
25 |     "create_initial_hypothesis",
26 |     "CTCPrefixScorer",
27 |     "DecoderScorer",
28 |     "ScorerInterface",
29 | ]
30 | 
--------------------------------------------------------------------------------
/speechcatcher/decode_kaldidir.py:
--------------------------------------------------------------------------------
 1 | from kaldiio import ReadHelper
 2 | import os
 3 | import sys
 4 | import speechcatcher
 5 | 
 6 | if __name__ == '__main__':
 7 |     testset_dir = "data/tuda_raw_test/"
 8 | 
 9 |     short_tag = 'de_streaming_transformer_m'
10 |     speech2text = speechcatcher.load_model(speechcatcher.tags[short_tag])
11 | 
12 |     with open(testset_dir + f'/{short_tag}_decoded', 'w') as testset_dir_decoded:
13 |         with ReadHelper(f'scp,p:{os.path.join(testset_dir, "wav.scp")}') as reader:
14 |             for key, (rate, speech) in reader:
15 |                 print(key, min(speech), max(speech), len(speech))
16 |                 try:
17 |                     text = speechcatcher.recognize(speech2text, speech, rate, quiet=True, progress=False)
18 |                     print(f'{key} {text}')
19 |                     testset_dir_decoded.write(f'{key} {text}\n')
20 |                 except Exception as e:
21 |                     print('Warning: could not decode:', key, e)
22 | 
--------------------------------------------------------------------------------
/speechcatcher/model/layers/normalization.py:
--------------------------------------------------------------------------------
 1 | """Layer normalization variants."""
 2 | 
 3 | import torch
 4 | import torch.nn as nn
 5 | 
 6 | 
 7 | class LayerNorm(nn.LayerNorm):
 8 |     """Layer normalization with optional dimension parameter.
 9 | 
10 |     This is a thin wrapper around torch.nn.LayerNorm that provides
11 |     compatibility with the ESPnet interface and allows for easier
12 |     pre-norm / post-norm switching.
13 | 
14 |     Args:
15 |         dim: Normalization dimension
16 |         eps: Epsilon for numerical stability (default: 1e-12)
17 | 
18 |     Shape:
19 |         - Input: (*, dim)
20 |         - Output: (*, dim)
21 |     """
22 | 
23 |     def __init__(self, dim: int, eps: float = 1e-12):
24 |         super().__init__(dim, eps=eps)
25 | 
26 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
27 |         """Forward pass.
28 | 
29 |         Args:
30 |             x: Input tensor (*, dim)
31 | 
32 |         Returns:
33 |             Normalized tensor (*, dim)
34 |         """
35 |         return super().forward(x)
36 | 
--------------------------------------------------------------------------------
/.github/workflows/python-package-test.yml:
--------------------------------------------------------------------------------
 1 | name: Python Package Test
 2 | 
 3 | on: [push]
 4 | 
 5 | jobs:
 6 |   test:
 7 |     runs-on: ubuntu-24.04
 8 |     strategy:
 9 |       matrix:
10 |         python-version: ['3.12']
11 |         include:
12 |           - python-version: '3.12'
13 |             python-dev-package: 'python3.12-dev'
14 |     steps:
15 |       - name: Checkout code
16 |         uses: actions/checkout@v2
17 |       - name: Install system dependencies
18 |         run: |
19 |           sudo apt-get update
20 |           sudo apt-get install -y portaudio19-dev ffmpeg ${{ matrix.python-dev-package }}
21 |       - name: Set up Python
22 |         uses: actions/setup-python@v2
23 |         with:
24 |           python-version: ${{ matrix.python-version }}
25 |       - name: Install dependencies
26 |         run: |
27 |           python -m pip install --upgrade pip
28 |           pip install .
29 |       - name: Test installation
30 |         run: |
31 |           which speechcatcher
32 |           speechcatcher --help
33 |           speechcatcher https://upload.wikimedia.org/wikipedia/commons/6/65/LibriVox_-_Fontane_Herr_von_Ribbeck.ogg
34 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2022 speechcatcher-asr
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('requirements.txt') as f: 4 | requirements = f.read().splitlines() 5 | 6 | setup( 7 | name='speechcatcher', 8 | version='0.5.0', 9 | author="Benjamin Milde", 10 | author_email="bmilde@users.noreply.github.com", 11 | description="Speechcatcher is an open source toolbox for transcribing speech from media files (audio/video).", 12 | url="https://github.com/speechcatcher-asr/speechcatcher", 13 | packages=find_packages(), 14 | install_requires=requirements, 15 | entry_points={ 16 | 'console_scripts': [ 17 | 'speechcatcher=speechcatcher.speechcatcher:main', 18 | 'speechcatcher_compute_wer=speechcatcher.compute_wer:main', 19 | 'speechcatcher_server=speechcatcher.speechcatcher_server:main', 20 | 'speechcatcher_vosk_test_client=speechcatcher.vosk_test_client:main', 21 | 'speechcatcher_simple_endpointing=speechcatcher.simple_endpointing:main' 22 | ] 23 | }, 24 | classifiers=[ 25 | "Programming Language :: Python :: 3", 26 | "License :: OSI Approved :: MIT License", 27 | "Operating System :: OS Independent", 28 | ] 29 | ) 30 | -------------------------------------------------------------------------------- /tests/test_fp16_cuda.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test FP16 on CUDA.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | from speechcatcher.speechcatcher import convert_inputfile, load_model, tags 11 | 12 | os.makedirs('.tmp/', exist_ok=True) 13 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 14 | if not os.path.exists(wavfile_path): 15 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 16 | 17 | with wave.open(wavfile_path, 'rb') as f: 18 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 19 | speech = raw_audio.astype(np.float32) / 32768.0 20 | 21 | print("Loading model with FP16 on CUDA...") 22 | s2t = load_model(tags['de_streaming_transformer_xl'], device='cuda', beam_size=5, quiet=False, fp16=True) 23 | 24 | print("\nModel loaded successfully!") 25 | print(f"Model device: {next(s2t.model.parameters()).device}") 26 | print(f"Model dtype: {next(s2t.model.parameters()).dtype}") 27 | 28 | print("\nProcessing first chunk...") 29 | chunk_size = 8000 30 | chunk = speech[0:chunk_size] 31 | 32 | try: 33 | result = s2t(chunk, is_final=False) 34 | print(f"Success! Result: {result}") 35 | except Exception as e: 36 | print(f"Error: {e}") 37 | import traceback 38 | traceback.print_exc() 39 | -------------------------------------------------------------------------------- /speechcatcher/model/layers/feed_forward.py: -------------------------------------------------------------------------------- 1 | """Position-wise feed-forward network for Transformer/Conformer.""" 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class PositionwiseFeedForward(nn.Module): 10 | """Positionwise feed-forward network. 
11 | 12 | This implements the FFN layer from "Attention is All You Need": 13 | FFN(x) = max(0, xW1 + b1)W2 + b2 14 | 15 | Args: 16 | input_dim: Input dimension 17 | hidden_dim: Hidden dimension (typically 4x input_dim for Transformers) 18 | output_dim: Output dimension (typically same as input_dim) 19 | dropout_rate: Dropout rate 20 | activation: Activation function (default: ReLU) 21 | 22 | Shape: 23 | - Input: (batch, time, input_dim) 24 | - Output: (batch, time, output_dim) 25 | """ 26 | 27 | def __init__( 28 | self, 29 | input_dim: int, 30 | hidden_dim: int, 31 | output_dim: int, 32 | dropout_rate: float = 0.1, 33 | activation: nn.Module = nn.ReLU(), 34 | ): 35 | super().__init__() 36 | self.w_1 = nn.Linear(input_dim, hidden_dim) 37 | self.w_2 = nn.Linear(hidden_dim, output_dim) 38 | self.dropout = nn.Dropout(dropout_rate) 39 | self.activation = activation 40 | 41 | def forward(self, x: torch.Tensor) -> torch.Tensor: 42 | """Forward pass. 43 | 44 | Args: 45 | x: Input tensor (batch, time, input_dim) 46 | 47 | Returns: 48 | Output tensor (batch, time, output_dim) 49 | """ 50 | return self.w_2(self.dropout(self.activation(self.w_1(x)))) 51 | -------------------------------------------------------------------------------- /tests/test_fp16_debug.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Debug FP16 vs FP32 outputs.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | from speechcatcher.speechcatcher import convert_inputfile, load_model, tags 11 | 12 | os.makedirs('.tmp/', exist_ok=True) 13 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 14 | if not os.path.exists(wavfile_path): 15 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 16 | 17 | with wave.open(wavfile_path, 'rb') as f: 18 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 19 | speech = raw_audio.astype(np.float32) / 32768.0 20 | 21 | print("="*80) 22 | print("FP32 (baseline)") 23 | print("="*80) 24 | s2t_fp32 = load_model(tags['de_streaming_transformer_xl'], device='cuda', beam_size=5, quiet=True, fp16=False) 25 | s2t_fp32.reset() 26 | 27 | chunk_size = 8000 28 | for i in range(len(speech) // chunk_size + 1): 29 | chunk = speech[i*chunk_size:min((i+1)*chunk_size, len(speech))] 30 | if len(chunk) == 0: 31 | break 32 | result = s2t_fp32(chunk, is_final=(i == len(speech) // chunk_size)) 33 | if result and result[0][0]: 34 | print(f"Chunk {i+1}: '{result[0][0]}'") 35 | 36 | print("\n" + "="*80) 37 | print("FP16 (testing)") 38 | print("="*80) 39 | s2t_fp16 = load_model(tags['de_streaming_transformer_xl'], device='cuda', beam_size=5, quiet=True, fp16=True) 40 | s2t_fp16.reset() 41 | 42 | for i in range(len(speech) // chunk_size + 1): 43 | chunk = speech[i*chunk_size:min((i+1)*chunk_size, len(speech))] 44 | if len(chunk) == 0: 45 | break 46 | result = s2t_fp16(chunk, is_final=(i == len(speech) // chunk_size)) 47 | if result and result[0][0]: 48 | print(f"Chunk {i+1}: '{result[0][0]}'") 49 | if i == 0 and result[0][2]: # Show token IDs for first chunk 50 | print(f" Token IDs: {result[0][2][:20]}") 51 | -------------------------------------------------------------------------------- /tests/test_beam_trace_chunks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Trace beam state after each chunk for both implementations.""" 3 | 4 | import torch 5 | import numpy as np 
6 | import wave 7 | import hashlib 8 | import os 9 | import logging 10 | 11 | logging.basicConfig(level=logging.INFO, format='%(message)s') 12 | 13 | from speechcatcher.speechcatcher import convert_inputfile, load_model, tags 14 | 15 | os.makedirs('.tmp/', exist_ok=True) 16 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 17 | if not os.path.exists(wavfile_path): 18 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 19 | 20 | with wave.open(wavfile_path, 'rb') as f: 21 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 22 | speech = raw_audio.astype(np.float32) / 32768.0 23 | 24 | # Load our model 25 | our_s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 26 | our_s2t.reset() 27 | 28 | chunk_size = 8000 29 | 30 | print("="*80) 31 | print("OUR IMPLEMENTATION") 32 | print("="*80) 33 | 34 | for chunk_idx in range(len(speech) // chunk_size + 1): 35 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 36 | if len(chunk) == 0: 37 | break 38 | result = our_s2t(chunk, is_final=False) 39 | 40 | # Show beam state 41 | if hasattr(our_s2t, 'beam_state'): 42 | print(f"\nAfter Chunk {chunk_idx+1}:") 43 | if result: 44 | print(f" Output: '{result[0][0]}'") 45 | else: 46 | print(f" Output: (none)") 47 | print(f" Output index: {our_s2t.beam_state.output_index}") 48 | print(f" Beam size: {len(our_s2t.beam_state.hypotheses)}") 49 | for i, hyp in enumerate(our_s2t.beam_state.hypotheses[:3]): 50 | yseq = hyp.yseq.tolist()[:10] 51 | has_eos = hyp.yseq[-1].item() == 1023 52 | yseq_len = len(hyp.yseq) - 1 # Excluding SOS 53 | print(f" [{i+1}] len={yseq_len}, yseq={yseq}, score={hyp.score:.4f}, has_eos={has_eos}") 54 | -------------------------------------------------------------------------------- /tests/test_single_chunk.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test with a single large chunk to isolate streaming issues.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | print("="*80) 11 | print("SINGLE CHUNK TEST") 12 | print("="*80) 13 | 14 | # Load audio 15 | print("\n[1] Loading audio...") 16 | from speechcatcher.speechcatcher import convert_inputfile, load_model, tags 17 | 18 | os.makedirs('.tmp/', exist_ok=True) 19 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 20 | if not os.path.exists(wavfile_path): 21 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 22 | 23 | with wave.open(wavfile_path, 'rb') as wavfile_in: 24 | rate = wavfile_in.getframerate() 25 | buf = wavfile_in.readframes(-1) 26 | speech = np.frombuffer(buf, dtype='int16') 27 | 28 | # Normalize to [-1, 1] 29 | speech = speech.astype(np.float32) / 32768.0 30 | 31 | print(f"Audio: rate={rate} Hz, shape={speech.shape}, range=[{speech.min():.4f}, {speech.max():.4f}]") 32 | 33 | # Load model 34 | print("\n[2] Loading model...") 35 | speech2text = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 36 | print("✅ Model loaded") 37 | 38 | # Process as a SINGLE chunk with is_final=True 39 | print("\n[3] Processing as single chunk...") 40 | results = speech2text(speech=speech, is_final=True) 41 | 42 | print("\n" + "="*80) 43 | print("RESULTS") 44 | print("="*80) 45 | 46 | if results and len(results) > 0: 47 | text, tokens, token_ids = results[0] 48 | print(f"\n✅ Text: '{text}'") 49 | print(f"\n✅ Tokens 
({len(tokens)}): {tokens[:20]}...") 50 | print(f"\n✅ Token IDs ({len(token_ids)}): {token_ids[:20]}...") 51 | 52 | # Check for Arabic token 53 | if 'م' in tokens: 54 | count = tokens.count('م') 55 | print(f"\n⚠️ Arabic 'م' appears {count} times") 56 | elif 1023 in token_ids: 57 | count = token_ids.count(1023) 58 | print(f"\n⚠️ Token ID 1023 appears {count} times") 59 | else: 60 | print(f"\n✅ No problematic tokens!") 61 | else: 62 | print("\n❌ No results!") 63 | 64 | print("\n" + "="*80) 65 | print("SINGLE CHUNK TEST COMPLETE") 66 | print("="*80) 67 | -------------------------------------------------------------------------------- /tests/test_bbd_debug.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Debug BBD behavior.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | import logging 10 | 11 | # Enable debug logging 12 | logging.basicConfig(level=logging.DEBUG, format='%(name)s - %(levelname)s - %(message)s') 13 | 14 | print("="*80) 15 | print("BBD DEBUG TEST") 16 | print("="*80) 17 | 18 | # Load audio 19 | print("\n[1] Loading audio...") 20 | from speechcatcher.speechcatcher import convert_inputfile, load_model, tags 21 | 22 | os.makedirs('.tmp/', exist_ok=True) 23 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 24 | if not os.path.exists(wavfile_path): 25 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 26 | 27 | with wave.open(wavfile_path, 'rb') as wavfile_in: 28 | rate = wavfile_in.getframerate() 29 | buf = wavfile_in.readframes(-1) 30 | speech = np.frombuffer(buf, dtype='int16') 31 | 32 | # Normalize to [-1, 1] 33 | speech = speech.astype(np.float32) / 32768.0 34 | 35 | print(f"Audio: rate={rate} Hz, shape={speech.shape}, range=[{speech.min():.4f}, {speech.max():.4f}]") 36 | 37 | # Load model 38 | print("\n[2] Loading model with BBD enabled...") 39 | speech2text = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 40 | print("✅ Model loaded") 41 | print(f"BBD enabled: {speech2text.beam_search.use_bbd}") 42 | print(f"BBD conservative: {speech2text.beam_search.bbd_conservative}") 43 | 44 | # Process as a SINGLE chunk with is_final=True 45 | print("\n[3] Processing as single chunk...") 46 | results = speech2text(speech=speech, is_final=True) 47 | 48 | print("\n" + "="*80) 49 | print("RESULTS") 50 | print("="*80) 51 | 52 | if results and len(results) > 0: 53 | text, tokens, token_ids = results[0] 54 | print(f"\n✅ Text: '{text}'") 55 | print(f"\n✅ Token count: {len(tokens)}") 56 | 57 | # Check for Arabic token 58 | if 'م' in tokens: 59 | count = tokens.count('م') 60 | print(f"\n⚠️ Arabic 'م' appears {count} times") 61 | else: 62 | print(f"\n✅ No Arabic characters!") 63 | else: 64 | print("\n❌ No results!") 65 | 66 | print("\n" + "="*80) 67 | print("BBD DEBUG TEST COMPLETE") 68 | print("="*80) 69 | -------------------------------------------------------------------------------- /tests/test_batch_mode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test batch (non-streaming) mode.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | print("="*80) 11 | print("BATCH MODE TEST") 12 | print("="*80) 13 | 14 | # Load audio 15 | print("\n[1] Loading audio...") 16 | from speechcatcher.speechcatcher import convert_inputfile 17 | 18 | os.makedirs('.tmp/', exist_ok=True) 19 
| wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 20 | if not os.path.exists(wavfile_path): 21 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 22 | 23 | with wave.open(wavfile_path, 'rb') as wavfile_in: 24 | rate = wavfile_in.getframerate() 25 | buf = wavfile_in.readframes(-1) 26 | speech = np.frombuffer(buf, dtype='int16') 27 | 28 | print(f"Audio: rate={rate} Hz, shape={speech.shape}") 29 | 30 | # Load ESPnet model in BATCH mode (no streaming) 31 | print("\n[2] Loading ESPnet model in BATCH mode...") 32 | from espnet2.bin.asr_inference import Speech2Text 33 | 34 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 35 | model_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth" 36 | 37 | espnet_batch = Speech2Text( 38 | asr_train_config=config_path, 39 | asr_model_file=model_path, 40 | device="cpu", 41 | beam_size=5, 42 | ) 43 | print("✅ ESPnet batch model loaded") 44 | 45 | # Transcribe with batch mode 46 | print("\n[3] Running batch transcription...") 47 | 48 | # ESPnet expects float audio normalized to [-1, 1] 49 | speech_float = speech.astype(np.float32) / 32768.0 50 | print(f"Audio range: [{speech_float.min():.4f}, {speech_float.max():.4f}]") 51 | 52 | nbests = espnet_batch(speech_float) 53 | text = nbests[0][0] # Get best hypothesis text 54 | 55 | print("\n" + "="*80) 56 | print("RESULTS") 57 | print("="*80) 58 | 59 | print(f"\n✅ Transcription: '{text}'") 60 | 61 | print("\n" + "="*80) 62 | print("BATCH MODE TEST COMPLETE") 63 | print("="*80) 64 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "speechcatcher" 7 | version = "0.5.0" 8 | description = "Speechcatcher is an open source toolbox for transcribing speech from media files (audio/video)." 
9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | license = {text = "MIT"} 12 | authors = [ 13 | {name = "Benjamin Milde", email = "bmilde@users.noreply.github.com"} 14 | ] 15 | keywords = ["asr", "speech-recognition", "streaming", "espnet"] 16 | classifiers = [ 17 | "Programming Language :: Python :: 3", 18 | "Programming Language :: Python :: 3.8", 19 | "Programming Language :: Python :: 3.9", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | "Programming Language :: Python :: 3.12", 23 | "Programming Language :: Python :: 3.13", 24 | "License :: OSI Approved :: MIT License", 25 | "Operating System :: OS Independent", 26 | ] 27 | 28 | dependencies = [ 29 | "setuptools", 30 | "torch", 31 | "torchaudio", 32 | "ffmpeg-python", 33 | "espnet_streaming_decoder @ git+https://github.com/speechcatcher-asr/espnet_streaming_decoder", 34 | # Using local espnet_model_zoo fork that eliminates espnet dependency 35 | "espnet_model_zoo @ git+https://github.com/speechcatcher-asr/espnet_model_zoo", 36 | "soundfile", 37 | "six", 38 | "pyaudio", 39 | "python_speech_features", 40 | "scipy", 41 | "tqdm", 42 | "somajo", 43 | "websockets>=14.2", 44 | # CRITICAL: Override older sentencepiece pins from subdependencies 45 | # sentencepiece 0.2.1+ has Python 3.13 wheels and fixes build issues 46 | "sentencepiece>=0.2.1", 47 | ] 48 | 49 | [project.urls] 50 | Homepage = "https://github.com/speechcatcher-asr/speechcatcher" 51 | Repository = "https://github.com/speechcatcher-asr/speechcatcher.git" 52 | 53 | [project.scripts] 54 | speechcatcher = "speechcatcher.speechcatcher:main" 55 | speechcatcher_compute_wer = "speechcatcher.compute_wer:main" 56 | speechcatcher_server = "speechcatcher.speechcatcher_server:main" 57 | speechcatcher_vosk_test_client = "speechcatcher.vosk_test_client:main" 58 | speechcatcher_simple_endpointing = "speechcatcher.simple_endpointing:main" 59 | 60 | [tool.setuptools] 61 | packages = ["speechcatcher"] 62 | 63 | [tool.setuptools.package-data] 64 | speechcatcher = [] 65 | -------------------------------------------------------------------------------- /tests/test_espnet_runtime_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Check ESPnet's actual runtime beam search configuration.""" 3 | 4 | import torch 5 | 6 | print("="*80) 7 | print("ESPnet RUNTIME CONFIGURATION") 8 | print("="*80) 9 | 10 | from espnet_streaming_decoder.asr_inference_streaming import Speech2TextStreaming as ESPnetStreaming 11 | 12 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 13 | model_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth" 14 | 15 | print("\n[Loading ESPnet with beam_size=5, ctc_weight=0.3...]") 16 | espnet_s2t = ESPnetStreaming( 17 | asr_train_config=config_path, 18 | asr_model_file=model_path, 19 | device="cpu", 20 | beam_size=5, 21 | ctc_weight=0.3, 22 | ) 23 | 24 | print("✅ Loaded\n") 25 | 26 | # Check beam search configuration 27 | if hasattr(espnet_s2t, 'beam_search'): 28 | bs = espnet_s2t.beam_search 29 | print("Beam Search 
Object:") 30 | print(f" Type: {type(bs).__name__}") 31 | print(f" beam_size: {getattr(bs, 'beam_size', 'N/A')}") 32 | print(f" sos: {getattr(bs, 'sos', 'N/A')}") 33 | print(f" eos: {getattr(bs, 'eos', 'N/A')}") 34 | 35 | if hasattr(bs, 'weights'): 36 | print(f"\n Scorer Weights:") 37 | for name, weight in bs.weights.items(): 38 | print(f" {name}: {weight}") 39 | 40 | if hasattr(bs, 'scorers'): 41 | print(f"\n Scorers:") 42 | for name, scorer in bs.scorers.items(): 43 | print(f" {name}: {type(scorer).__name__}") 44 | 45 | # Check if it's blockwise synchronous 46 | print("\nBlockwise Settings:") 47 | print(f" block_size: {getattr(bs, 'block_size', 'N/A')}") 48 | print(f" hop_size: {getattr(bs, 'hop_size', 'N/A')}") 49 | print(f" look_ahead: {getattr(bs, 'look_ahead', 'N/A')}") 50 | 51 | # Check BBD settings 52 | if hasattr(bs, 'use_bbd'): 53 | print(f"\nBBD Settings:") 54 | print(f" use_bbd: {bs.use_bbd}") 55 | if hasattr(bs, 'bbd_conservative'): 56 | print(f" bbd_conservative: {bs.bbd_conservative}") 57 | 58 | print("\n" + "="*80) 59 | -------------------------------------------------------------------------------- /tests/test_encoder_difference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Check encoder output difference.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | from speechcatcher.speechcatcher import convert_inputfile 11 | 12 | os.makedirs('.tmp/', exist_ok=True) 13 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 14 | if not os.path.exists(wavfile_path): 15 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 16 | 17 | with wave.open(wavfile_path, 'rb') as f: 18 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 19 | speech = raw_audio.astype(np.float32) / 32768.0 20 | 21 | from espnet_streaming_decoder.asr_inference_streaming import Speech2TextStreaming as ESPnetStreaming 22 | from speechcatcher.speechcatcher import load_model, tags 23 | 24 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 25 | model_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth" 26 | 27 | espnet_s2t = ESPnetStreaming( 28 | asr_train_config=config_path, 29 | asr_model_file=model_path, 30 | device="cpu", 31 | beam_size=5, 32 | ctc_weight=0.3, 33 | ) 34 | espnet_s2t.reset() 35 | 36 | our_s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 37 | our_s2t.reset() 38 | 39 | chunk_size = 8000 40 | for chunk_idx in range(5): 41 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 42 | espnet_s2t(chunk, is_final=False) 43 | our_s2t(chunk, is_final=False) 44 | 45 | espnet_enc = espnet_s2t.beam_search.encbuffer[:40].unsqueeze(0) 46 | our_enc = our_s2t.beam_search.encoder_buffer[:, :40, :] 47 | 48 | print("Encoder output comparison (40 frames):") 49 | print(f"ESPnet shape: {espnet_enc.shape}") 50 | print(f"Ours shape: {our_enc.shape}") 51 | print(f"\nMax diff: {(espnet_enc - our_enc).abs().max().item():.10f}") 52 | 
print(f"Mean diff: {(espnet_enc - our_enc).abs().mean().item():.10f}") 53 | print(f"\nClose (atol=1e-5): {torch.allclose(espnet_enc, our_enc, atol=1e-5)}") 54 | print(f"Close (atol=1e-4): {torch.allclose(espnet_enc, our_enc, atol=1e-4)}") 55 | print(f"Close (atol=1e-3): {torch.allclose(espnet_enc, our_enc, atol=1e-3)}") 56 | -------------------------------------------------------------------------------- /tests/test_espnet_decoding_start.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Check when ESPnet starts decoding.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | from speechcatcher.speechcatcher import convert_inputfile 11 | 12 | os.makedirs('.tmp/', exist_ok=True) 13 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 14 | if not os.path.exists(wavfile_path): 15 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 16 | 17 | with wave.open(wavfile_path, 'rb') as f: 18 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 19 | speech = raw_audio.astype(np.float32) / 32768.0 20 | 21 | from espnet_streaming_decoder.asr_inference_streaming import Speech2TextStreaming as ESPnetStreaming 22 | 23 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 24 | model_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth" 25 | 26 | espnet_s2t = ESPnetStreaming( 27 | asr_train_config=config_path, 28 | asr_model_file=model_path, 29 | device="cpu", 30 | beam_size=5, 31 | ctc_weight=0.3, 32 | ) 33 | espnet_s2t.reset() 34 | 35 | print("="*80) 36 | print("ESPnet Decoding Start Timing") 37 | print("="*80) 38 | 39 | chunk_size = 8000 40 | for chunk_idx in range(6): 41 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 42 | 43 | results = espnet_s2t(chunk, is_final=False) 44 | 45 | # Check beam state 46 | if hasattr(espnet_s2t, 'beam_search'): 47 | bs = espnet_s2t.beam_search 48 | if hasattr(bs, 'encbuffer') and bs.encbuffer is not None: 49 | print(f"\nChunk {chunk_idx+1}: encbuffer shape = {bs.encbuffer.shape}") 50 | else: 51 | print(f"\nChunk {chunk_idx+1}: encbuffer = None") 52 | 53 | if hasattr(bs, 'running') and bs.running: 54 | print(f" Running hypotheses: {len(bs.running)}") 55 | print(f" Top hypothesis: {bs.running[0].yseq[:5] if len(bs.running[0].yseq) < 5 else bs.running[0].yseq[:5]}") 56 | 57 | if results: 58 | print(f" OUTPUT: '{results[0][0]}'") 59 | else: 60 | print(f" OUTPUT: (none)") 61 | 62 | print("\n" + "="*80) 63 | -------------------------------------------------------------------------------- /tests/test_espnet_streaming_beams.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Check ESPnet's beam states during streaming.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | from speechcatcher.speechcatcher import convert_inputfile 11 | 12 | os.makedirs('.tmp/', exist_ok=True) 13 | wavfile_path = '.tmp/' + 
hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 14 | if not os.path.exists(wavfile_path): 15 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 16 | 17 | with wave.open(wavfile_path, 'rb') as f: 18 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 19 | speech = raw_audio.astype(np.float32) / 32768.0 20 | 21 | from espnet_streaming_decoder.asr_inference_streaming import Speech2TextStreaming as ESPnetStreaming 22 | 23 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 24 | model_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth" 25 | 26 | espnet_s2t = ESPnetStreaming( 27 | asr_train_config=config_path, 28 | asr_model_file=model_path, 29 | device="cpu", 30 | beam_size=5, 31 | ctc_weight=0.3, 32 | ) 33 | espnet_s2t.reset() 34 | 35 | chunk_size = 8000 36 | 37 | for chunk_idx in range(len(speech) // chunk_size): 38 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 39 | result = espnet_s2t(chunk, is_final=False) 40 | 41 | # Check beam state 42 | if hasattr(espnet_s2t, 'beam_search'): 43 | bs = espnet_s2t.beam_search 44 | if hasattr(bs, 'running_hyps') and bs.running_hyps: 45 | print(f"\n{'='*80}") 46 | print(f"After Chunk {chunk_idx+1}:") 47 | if result: 48 | print(f" Output: '{result[0][0]}'") 49 | else: 50 | print(f" Output: (none)") 51 | print(f" Beam hypotheses: {len(bs.running_hyps)}") 52 | for i, hyp in enumerate(bs.running_hyps[:3]): 53 | if torch.is_tensor(hyp): 54 | yseq = hyp.tolist()[:15] 55 | print(f" [{i+1}] yseq={yseq}") 56 | else: 57 | yseq = hyp.yseq.tolist()[:15] 58 | score = hyp.score if hasattr(hyp, 'score') else 0.0 59 | print(f" [{i+1}] score={score:.4f}, yseq={yseq}") 60 | -------------------------------------------------------------------------------- /tests/test_our_transcribe.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test our full transcription pipeline.""" 3 | 4 | from speechcatcher import speechcatcher 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | if __name__ == '__main__': 11 | print("="*80) 12 | print("OUR FULL PIPELINE TEST") 13 | print("="*80) 14 | 15 | # Load model 16 | print("\n[1] Loading model...") 17 | short_tag = 'de_streaming_transformer_xl' 18 | speech2text = speechcatcher.load_model(speechcatcher.tags[short_tag], beam_size=5, quiet=True) 19 | print("✅ Model loaded") 20 | 21 | # Load audio 22 | print("\n[2] Loading audio...") 23 | os.makedirs('.tmp/', exist_ok=True) 24 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 25 | if not os.path.exists(wavfile_path): 26 | speechcatcher.convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 27 | 28 | with wave.open(wavfile_path, 'rb') as wavfile_in: 29 | rate = wavfile_in.getframerate() 30 | buf = wavfile_in.readframes(-1) 31 | speech = np.frombuffer(buf, dtype='int16') 32 | 33 | print(f"Audio: rate={rate} Hz, shape={speech.shape}") 34 | 35 | # Transcribe 36 | print("\n[3] Running transcription...") 37 | complete_text, 
paragraphs = speechcatcher.recognize(speech2text, speech, rate, quiet=False, progress=True) 38 | 39 | print("\n" + "="*80) 40 | print("RESULTS") 41 | print("="*80) 42 | 43 | print(f"\n✅ Complete text: '{complete_text}'") 44 | print(f"\n✅ Number of paragraphs: {len(paragraphs)}") 45 | 46 | if paragraphs: 47 | for i, para in enumerate(paragraphs): 48 | print(f"\nParagraph {i+1}:") 49 | print(f" Text: '{para.get('text', '')}'") 50 | print(f" Start: {para.get('start', 0):.2f}s, End: {para.get('end', 0):.2f}s") 51 | 52 | # Check token IDs if available 53 | if 'tokens' in para: 54 | tokens = para['tokens'] 55 | print(f" Tokens ({len(tokens)}): {tokens[:10]}...") 56 | 57 | # Check for problematic token 1023 58 | # Note: tokens might be strings not IDs in this output 59 | # Let's check the actual token representation 60 | if isinstance(tokens[0], str): 61 | # Tokens are strings (BPE pieces) 62 | if 'م' in tokens: 63 | count = tokens.count('م') 64 | print(f" ⚠️ Arabic 'م' appears {count} times in tokens") 65 | else: 66 | print(f" ✅ No Arabic characters in tokens") 67 | 68 | print("\n" + "="*80) 69 | print("OUR FULL PIPELINE TEST COMPLETE") 70 | print("="*80) 71 | -------------------------------------------------------------------------------- /tests/test_full_transcription_debug.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test full transcription with debug output.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | import logging 10 | 11 | # Enable debug logging 12 | logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s') 13 | 14 | print("="*80) 15 | print("FULL TRANSCRIPTION TEST") 16 | print("="*80) 17 | 18 | # Load audio 19 | from speechcatcher.speechcatcher import convert_inputfile, load_model, tags 20 | 21 | os.makedirs('.tmp/', exist_ok=True) 22 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 23 | if not os.path.exists(wavfile_path): 24 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 25 | 26 | with wave.open(wavfile_path, 'rb') as wavfile_in: 27 | rate = wavfile_in.getframerate() 28 | buf = wavfile_in.readframes(-1) 29 | raw_audio = np.frombuffer(buf, dtype='int16') 30 | 31 | speech = raw_audio.astype(np.float32) / 32768.0 32 | print(f"Audio: {len(speech)} samples ({len(speech)/rate:.2f}s)") 33 | 34 | # Load model 35 | print("\nLoading model...") 36 | our_s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 37 | our_s2t.reset() 38 | print("✅ Loaded\n") 39 | 40 | # Process chunks 41 | chunk_size = 8000 42 | num_chunks = (len(speech) + chunk_size - 1) // chunk_size 43 | 44 | print(f"Processing {num_chunks} chunks of {chunk_size} samples each\n") 45 | 46 | for chunk_idx in range(min(10, num_chunks)): 47 | start = chunk_idx * chunk_size 48 | end = min((chunk_idx + 1) * chunk_size, len(speech)) 49 | chunk = speech[start:end] 50 | is_final = (chunk_idx == num_chunks - 1) 51 | 52 | print(f"{'='*60}") 53 | print(f"CHUNK {chunk_idx+1}/{num_chunks} (is_final={is_final})") 54 | print(f"{'='*60}") 55 | 56 | with torch.no_grad(): 57 | results = our_s2t(chunk, is_final=is_final) 58 | 59 | # Debug hypotheses 60 | if our_s2t.beam_state and our_s2t.beam_state.hypotheses: 61 | best_hyp = our_s2t.beam_state.hypotheses[0] 62 | print(f"Best hypothesis:") 63 | print(f" yseq: {best_hyp.yseq.tolist()}") 64 | print(f" score: {best_hyp.score:.4f}") 65 | 66 | # Check buffer 67 | if 
our_s2t.beam_search.encoder_buffer is not None: 68 | print(f"Encoder buffer: {our_s2t.beam_search.encoder_buffer.shape}") 69 | print(f"Processed blocks: {our_s2t.beam_search.processed_block}") 70 | 71 | # Results 72 | if results and len(results) > 0: 73 | text, tokens, token_ids = results[0] 74 | print(f"Result text: '{text}'") 75 | print(f"Token IDs: {token_ids}") 76 | else: 77 | print("Result: (empty)") 78 | 79 | print() 80 | 81 | print("="*80) 82 | print("TEST COMPLETE") 83 | print("="*80) 84 | -------------------------------------------------------------------------------- /tests/test_espnet_beam_first_step.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Trace ESPnet's beam search at first decoding step.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | import logging 10 | 11 | # Enable detailed logging 12 | logging.basicConfig(level=logging.DEBUG, format='%(name)s: %(message)s') 13 | 14 | from speechcatcher.speechcatcher import convert_inputfile 15 | 16 | os.makedirs('.tmp/', exist_ok=True) 17 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 18 | if not os.path.exists(wavfile_path): 19 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 20 | 21 | with wave.open(wavfile_path, 'rb') as f: 22 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 23 | speech = raw_audio.astype(np.float32) / 32768.0 24 | 25 | from espnet_streaming_decoder.asr_inference_streaming import Speech2TextStreaming as ESPnetStreaming 26 | 27 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 28 | model_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth" 29 | 30 | espnet_s2t = ESPnetStreaming( 31 | asr_train_config=config_path, 32 | asr_model_file=model_path, 33 | device="cpu", 34 | beam_size=5, 35 | ctc_weight=0.3, 36 | ) 37 | espnet_s2t.reset() 38 | 39 | chunk_size = 8000 40 | 41 | # Process chunks 1-4 (no decoding yet) 42 | for chunk_idx in range(4): 43 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 44 | espnet_s2t(chunk, is_final=False) 45 | 46 | print("="*80) 47 | print("Before chunk 5 (no decoding yet)") 48 | print("="*80) 49 | print(f"encbuffer: {espnet_s2t.beam_search.encbuffer.shape if espnet_s2t.beam_search.encbuffer is not None else None}") 50 | print(f"running_hyps: {len(espnet_s2t.beam_search.running_hyps) if espnet_s2t.beam_search.running_hyps else 0}") 51 | 52 | # Process chunk 5 (first decoding) 53 | print("\n" + "="*80) 54 | print("Processing chunk 5 (with DEBUG logging)...") 55 | print("="*80) 56 | 57 | chunk = speech[4*chunk_size : min(5*chunk_size, len(speech))] 58 | result = espnet_s2t(chunk, is_final=False) 59 | 60 | print("\n" + "="*80) 61 | print("After chunk 5") 62 | print("="*80) 63 | print(f"Output: {result}") 64 | 65 | if espnet_s2t.beam_search.running_hyps: 66 | print(f"\nTop 5 hypotheses:") 67 | for i, hyp in enumerate(espnet_s2t.beam_search.running_hyps[:5]): 68 | print(f" [{i+1}] score={hyp.score:.6f}, 
yseq={hyp.yseq.tolist()}") 69 | -------------------------------------------------------------------------------- /tests/test_encoder_buffer_debug.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Debug encoder buffer accumulation.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | import logging 10 | 11 | # Enable ALL debug logging 12 | logging.basicConfig(level=logging.DEBUG, format='%(name)s - %(levelname)s - %(message)s') 13 | 14 | print("="*80) 15 | print("ENCODER BUFFER DEBUG") 16 | print("="*80) 17 | 18 | # Load audio 19 | print("\n[1] Loading audio...") 20 | from speechcatcher.speechcatcher import convert_inputfile, load_model, tags 21 | 22 | os.makedirs('.tmp/', exist_ok=True) 23 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 24 | if not os.path.exists(wavfile_path): 25 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 26 | 27 | with wave.open(wavfile_path, 'rb') as wavfile_in: 28 | rate = wavfile_in.getframerate() 29 | buf = wavfile_in.readframes(-1) 30 | raw_audio = np.frombuffer(buf, dtype='int16') 31 | 32 | speech = raw_audio.astype(np.float32) / 32768.0 33 | print(f"Audio: {len(speech)} samples ({len(speech)/rate:.2f}s)") 34 | 35 | # Load model 36 | print("\n[2] Loading model...") 37 | our_s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 38 | our_s2t.reset() 39 | print(f"✅ Loaded") 40 | 41 | # Process in chunks with detailed logging 42 | print("\n" + "="*80) 43 | print("PROCESSING CHUNKS WITH DEBUG") 44 | print("="*80) 45 | 46 | chunk_size = 8000 47 | chunks = [ 48 | speech[i*chunk_size : min((i+1)*chunk_size, len(speech))] 49 | for i in range((len(speech) + chunk_size - 1) // chunk_size) 50 | ] 51 | 52 | print(f"\nTotal chunks: {len(chunks)}") 53 | print(f"Chunk sizes: {[len(c) for c in chunks]}\n") 54 | 55 | for chunk_idx, chunk in enumerate(chunks[:5]): # Process first 5 chunks 56 | is_final = (chunk_idx == len(chunks) - 1) 57 | 58 | print(f"\n{'='*80}") 59 | print(f"CHUNK {chunk_idx+1}/{len(chunks)} (is_final={is_final})") 60 | print(f"{'='*80}") 61 | 62 | print(f"\n[Before] Encoder buffer: {our_s2t.beam_search.encoder_buffer.shape if our_s2t.beam_search.encoder_buffer is not None else 'None'}") 63 | print(f"[Before] Processed blocks: {our_s2t.beam_search.processed_block}") 64 | 65 | # Process chunk 66 | with torch.no_grad(): 67 | results = our_s2t(chunk, is_final=is_final) 68 | 69 | print(f"\n[After] Encoder buffer: {our_s2t.beam_search.encoder_buffer.shape if our_s2t.beam_search.encoder_buffer is not None else 'None'}") 70 | print(f"[After] Processed blocks: {our_s2t.beam_search.processed_block}") 71 | 72 | if results and len(results) > 0: 73 | text, tokens, token_ids = results[0] 74 | print(f"\n✅ Result: '{text}' ({len(tokens)} tokens)") 75 | else: 76 | print(f"\n⚠️ No results yet") 77 | 78 | print("\n" + "="*80) 79 | print("DEBUG COMPLETE") 80 | print("="*80) 81 | -------------------------------------------------------------------------------- /tests/test_state_preservation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test if encoder states are preserved between chunks.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | import logging 10 | 11 | # Enable debug logging for speechcatcher modules 12 | 
logging.basicConfig(level=logging.DEBUG, format='%(name)s - %(levelname)s - %(message)s') 13 | logging.getLogger('speechcatcher').setLevel(logging.DEBUG) 14 | 15 | print("="*80) 16 | print("STATE PRESERVATION TEST") 17 | print("="*80) 18 | 19 | # Load audio 20 | from speechcatcher.speechcatcher import convert_inputfile, load_model, tags 21 | 22 | os.makedirs('.tmp/', exist_ok=True) 23 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 24 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 25 | 26 | with wave.open(wavfile_path, 'rb') as wavfile_in: 27 | buf = wavfile_in.readframes(-1) 28 | raw_audio = np.frombuffer(buf, dtype='int16') 29 | 30 | speech = raw_audio.astype(np.float32) / 32768.0 31 | 32 | # Load model 33 | our_s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 34 | our_s2t.reset() 35 | 36 | # Process chunks 37 | chunk_size = 8000 38 | chunks = [speech[i*chunk_size : min((i+1)*chunk_size, len(speech))] for i in range(5)] 39 | 40 | print("\nProcessing first 5 chunks...\n") 41 | 42 | for chunk_idx, chunk in enumerate(chunks): 43 | is_final = False 44 | 45 | print(f"CHUNK {chunk_idx+1}:") 46 | print(f" Before call:") 47 | print(f" beam_state: {our_s2t.beam_state}") 48 | if our_s2t.beam_state: 49 | print(f" beam_state.encoder_states: {our_s2t.beam_state.encoder_states}") 50 | if our_s2t.beam_state.encoder_states: 51 | print(f" encoder_states keys: {list(our_s2t.beam_state.encoder_states.keys())[:3]}") 52 | print(f" beam_search.encoder_buffer: {our_s2t.beam_search.encoder_buffer.shape if our_s2t.beam_search.encoder_buffer is not None else 'None'}") 53 | 54 | # Process chunk through full pipeline 55 | with torch.no_grad(): 56 | results = our_s2t(chunk, is_final=is_final) 57 | 58 | print(f" After call:") 59 | print(f" beam_state: {our_s2t.beam_state}") 60 | if our_s2t.beam_state: 61 | print(f" beam_state.encoder_states: {our_s2t.beam_state.encoder_states}") 62 | if our_s2t.beam_state.encoder_states: 63 | print(f" encoder_states keys: {list(our_s2t.beam_state.encoder_states.keys())[:3]}") 64 | print(f" beam_search.encoder_buffer: {our_s2t.beam_search.encoder_buffer.shape if our_s2t.beam_search.encoder_buffer is not None else 'None'}") 65 | 66 | if results and len(results) > 0: 67 | text, _, _ = results[0] 68 | print(f" Result: '{text}'") 69 | else: 70 | print(f" Result: (empty)") 71 | print() 72 | 73 | print("="*80) 74 | print("STATE PRESERVATION TEST COMPLETE") 75 | print("="*80) 76 | -------------------------------------------------------------------------------- /tests/test_ctc_timing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test CTC timing with detailed logging.""" 3 | 4 | import logging 5 | import sys 6 | import time 7 | 8 | # Enable detailed logging 9 | logging.basicConfig( 10 | level=logging.INFO, 11 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 12 | stream=sys.stderr 13 | ) 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | # Import after logging setup 18 | from speechcatcher.speech2text_streaming import Speech2TextStreaming 19 | 20 | logger.info("Loading model...") 21 | # Manually load with higher CTC weight 22 | from espnet_model_zoo.downloader import ModelDownloader 23 | from pathlib import Path 24 | from speechcatcher.speech2text_streaming import Speech2TextStreaming 25 | 26 | espnet_model_downloader = ModelDownloader("~/.cache/espnet") 27 | info = 
espnet_model_downloader.download_and_unpack( 28 | "speechcatcher/speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024", 29 | quiet=False 30 | ) 31 | 32 | model_dir = None 33 | for key in ['asr_model_file', 'asr_train_config', 'model_file', 'train_config']: 34 | if key in info and info[key]: 35 | model_dir = Path(info[key]).parent 36 | break 37 | 38 | logger.info(f"Model dir: {model_dir}") 39 | 40 | speech2text = Speech2TextStreaming( 41 | model_dir=model_dir, 42 | beam_size=10, 43 | ctc_weight=0.0, # DISABLE CTC to test decoder-only 44 | device="cpu", 45 | dtype="float32" 46 | ) 47 | 48 | logger.info("Model loaded successfully") 49 | 50 | # Load audio 51 | import torchaudio 52 | logger.info("Loading audio file...") 53 | waveform, sample_rate = torchaudio.load("Neujahrsansprache_5s.mp4") 54 | 55 | # Resample to 16kHz if needed 56 | if sample_rate != 16000: 57 | logger.info(f"Resampling from {sample_rate}Hz to 16000Hz") 58 | waveform = torchaudio.functional.resample(waveform, sample_rate, 16000) 59 | sample_rate = 16000 60 | 61 | # Convert to mono 62 | if waveform.shape[0] > 1: 63 | waveform = waveform.mean(dim=0, keepdim=True) 64 | 65 | logger.info(f"Audio loaded: shape={waveform.shape}, duration={waveform.shape[1]/sample_rate:.2f}s") 66 | 67 | # Check beam search configuration 68 | logger.info(f"Beam search scorers: {list(speech2text.beam_search.scorers.keys())}") 69 | logger.info(f"Beam search weights: {speech2text.beam_search.weights}") 70 | 71 | # Process audio 72 | start_time = time.time() 73 | logger.info("Starting transcription...") 74 | 75 | try: 76 | results = speech2text(speech=waveform.squeeze(0).numpy(), is_final=True) 77 | elapsed = time.time() - start_time 78 | logger.info(f"Transcription complete in {elapsed:.2f}s") 79 | 80 | print("\n=== RESULTS ===") 81 | print(results) 82 | 83 | except KeyboardInterrupt: 84 | elapsed = time.time() - start_time 85 | logger.error(f"Interrupted after {elapsed:.2f}s") 86 | sys.exit(1) 87 | except Exception as e: 88 | elapsed = time.time() - start_time 89 | logger.error(f"Error after {elapsed:.2f}s: {e}", exc_info=True) 90 | sys.exit(1) 91 | -------------------------------------------------------------------------------- /tests/test_streaming_chunks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test BBD with multiple streaming chunks.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | import logging 10 | 11 | # Enable debug logging for beam search 12 | logging.basicConfig(level=logging.INFO, format='%(message)s') 13 | logging.getLogger('speechcatcher.beam_search.beam_search').setLevel(logging.DEBUG) 14 | 15 | print("="*80) 16 | print("STREAMING CHUNKS TEST (BBD)") 17 | print("="*80) 18 | 19 | # Load audio 20 | print("\n[1] Loading audio...") 21 | from speechcatcher.speechcatcher import convert_inputfile, load_model, tags 22 | 23 | os.makedirs('.tmp/', exist_ok=True) 24 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 25 | if not os.path.exists(wavfile_path): 26 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 27 | 28 | with wave.open(wavfile_path, 'rb') as wavfile_in: 29 | rate = wavfile_in.getframerate() 30 | buf = wavfile_in.readframes(-1) 31 | speech = np.frombuffer(buf, dtype='int16') 32 | 33 | # Normalize to [-1, 1] 34 | speech = speech.astype(np.float32) / 32768.0 35 | 36 | print(f"Audio: rate={rate} Hz, 
shape={speech.shape}, range=[{speech.min():.4f}, {speech.max():.4f}]") 37 | 38 | # Load model 39 | print("\n[2] Loading model with BBD enabled...") 40 | speech2text = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 41 | print("✅ Model loaded") 42 | print(f"BBD enabled: {speech2text.beam_search.use_bbd}") 43 | print(f"BBD conservative: {speech2text.beam_search.bbd_conservative}") 44 | 45 | # Reset state 46 | speech2text.reset() 47 | 48 | # Process in CHUNKS (streaming mode) 49 | print("\n[3] Processing in streaming chunks...") 50 | 51 | chunk_size = 8000 # ~0.5s at 16kHz 52 | num_chunks = (len(speech) + chunk_size - 1) // chunk_size 53 | 54 | print(f"Total chunks: {num_chunks}") 55 | 56 | results_list = [] 57 | for i in range(num_chunks): 58 | start = i * chunk_size 59 | end = min((i + 1) * chunk_size, len(speech)) 60 | chunk = speech[start:end] 61 | is_final = (i == num_chunks - 1) 62 | 63 | print(f"\n--- Chunk {i+1}/{num_chunks} (is_final={is_final}) ---") 64 | results = speech2text(speech=chunk, is_final=is_final) 65 | 66 | if results and len(results) > 0: 67 | text, tokens, token_ids = results[0] 68 | print(f"Result: '{text[:50]}...' ({len(tokens)} tokens)") 69 | results_list.append((text, tokens, token_ids)) 70 | 71 | print("\n" + "="*80) 72 | print("RESULTS") 73 | print("="*80) 74 | 75 | if results_list: 76 | final_text, final_tokens, final_token_ids = results_list[-1] 77 | print(f"\n✅ Final text: '{final_text}'") 78 | print(f"\n✅ Token count: {len(final_tokens)}") 79 | 80 | # Check for Arabic token 81 | if 'م' in final_tokens: 82 | count = final_tokens.count('م') 83 | print(f"\n⚠️ Arabic 'م' appears {count} times") 84 | else: 85 | print(f"\n✅ No Arabic characters!") 86 | else: 87 | print("\n❌ No results!") 88 | 89 | print("\n" + "="*80) 90 | print("STREAMING CHUNKS TEST COMPLETE") 91 | print("="*80) 92 | -------------------------------------------------------------------------------- /tests/test_final_chunk_debug.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Debug final chunk processing.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | from speechcatcher.speechcatcher import convert_inputfile 11 | 12 | os.makedirs('.tmp/', exist_ok=True) 13 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 14 | if not os.path.exists(wavfile_path): 15 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 16 | 17 | with wave.open(wavfile_path, 'rb') as f: 18 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 19 | speech = raw_audio.astype(np.float32) / 32768.0 20 | 21 | # Load our model 22 | from speechcatcher.speechcatcher import load_model, tags 23 | import json 24 | 25 | with open("/tmp/espnet_token_list.json", "r") as f: 26 | token_list = json.load(f) 27 | 28 | our_s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 29 | our_s2t.reset() 30 | 31 | chunk_size = 8000 32 | 33 | print("="*80) 34 | print("PROCESSING CHUNKS") 35 | print("="*80) 36 | 37 | # Process all chunks except the last 38 | for chunk_idx in range(len(speech) // chunk_size): 39 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 40 | result = our_s2t(chunk, is_final=False) 41 | if result: 42 | print(f"Chunk {chunk_idx+1}: '{result[0][0]}'") 43 | 44 | # Now process final chunk 45 | final_chunk_idx = len(speech) // chunk_size 46 | chunk = 
speech[final_chunk_idx*chunk_size:] 47 | print(f"\n{'='*80}") 48 | print(f"FINAL CHUNK {final_chunk_idx+1} (is_final=True)") 49 | print(f"{'='*80}") 50 | print(f"Chunk length: {len(chunk)} samples") 51 | 52 | # Check beam state before final 53 | if hasattr(our_s2t, 'beam_state'): 54 | print(f"\nBeam state BEFORE final chunk:") 55 | print(f" Output index: {our_s2t.beam_state.output_index}") 56 | print(f" Hypotheses: {len(our_s2t.beam_state.hypotheses)}") 57 | for i, hyp in enumerate(our_s2t.beam_state.hypotheses[:3]): 58 | token_ids = hyp.yseq.tolist() if torch.is_tensor(hyp.yseq) else hyp.yseq 59 | tokens = [token_list[tid] for tid in token_ids] 60 | print(f" [{i+1}] score={hyp.score:.4f}") 61 | print(f" yseq={token_ids[:20]}") 62 | print(f" tokens={' '.join(tokens[:20])}") 63 | 64 | result = our_s2t(chunk, is_final=True) 65 | 66 | # Check beam state after final 67 | if hasattr(our_s2t, 'beam_state'): 68 | print(f"\nBeam state AFTER final chunk:") 69 | print(f" Output index: {our_s2t.beam_state.output_index}") 70 | print(f" Hypotheses: {len(our_s2t.beam_state.hypotheses)}") 71 | for i, hyp in enumerate(our_s2t.beam_state.hypotheses[:3]): 72 | token_ids = hyp.yseq.tolist() if torch.is_tensor(hyp.yseq) else hyp.yseq 73 | tokens = [token_list[tid] for tid in token_ids] 74 | print(f" [{i+1}] score={hyp.score:.4f}") 75 | print(f" yseq={token_ids[:20]}") 76 | print(f" tokens={' '.join(tokens[:20])}") 77 | 78 | if result: 79 | print(f"\nFinal output: '{result[0][0]}'") 80 | print(f"Token IDs: {result[0][2][:30]}") 81 | print(f"Tokens: {result[0][1][:30]}") 82 | 83 | print("\n" + "="*80) 84 | -------------------------------------------------------------------------------- /tests/test_espnet_full.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test ESPnet's full transcription pipeline.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | print("="*80) 11 | print("ESPNET FULL PIPELINE TEST") 12 | print("="*80) 13 | 14 | # Load audio 15 | print("\n[1] Loading audio...") 16 | from speechcatcher.speechcatcher import convert_inputfile 17 | 18 | os.makedirs('.tmp/', exist_ok=True) 19 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 20 | if not os.path.exists(wavfile_path): 21 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 22 | 23 | with wave.open(wavfile_path, 'rb') as wavfile_in: 24 | rate = wavfile_in.getframerate() 25 | buf = wavfile_in.readframes(-1) 26 | speech = np.frombuffer(buf, dtype='int16') 27 | 28 | print(f"Audio: rate={rate} Hz, shape={speech.shape}") 29 | 30 | # Load ESPnet model 31 | print("\n[2] Loading ESPnet model...") 32 | from espnet2.bin.asr_inference_streaming import Speech2TextStreaming as ESPnetS2T 33 | 34 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 35 | model_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth" 36 | 37 | espnet_model = ESPnetS2T( 38 | asr_train_config=config_path, 39 | asr_model_file=model_path, 40 | device="cpu", 41 | 
beam_size=5, # Match our beam size 42 | ) 43 | print("✅ ESPnet model loaded") 44 | 45 | # Transcribe 46 | print("\n[3] Running ESPnet transcription...") 47 | 48 | # ESPnet expects float audio normalized to [-1, 1] 49 | speech_float = speech.astype(np.float32) 50 | 51 | # Try different chunk sizes to match our streaming 52 | chunk_size = 8000 # About 0.5s at 16kHz 53 | 54 | results = [] 55 | for i in range(0, len(speech_float), chunk_size): 56 | chunk = speech_float[i:i + chunk_size] 57 | is_final = (i + chunk_size >= len(speech_float)) 58 | 59 | result = espnet_model(chunk, is_final=is_final) 60 | if result: 61 | results.append(result) 62 | print(f"Chunk {i//chunk_size + 1}: '{result[0]}'") 63 | 64 | print("\n" + "="*80) 65 | print("RESULTS") 66 | print("="*80) 67 | 68 | if results: 69 | final_text = results[-1][0] if results else "" 70 | print(f"\n✅ Final text: '{final_text}'") 71 | 72 | # Compare with expected 73 | expected_path = "Neujahrsansprache_5s.mp4.txt.expected" 74 | if os.path.exists(expected_path): 75 | with open(expected_path, 'r') as f: 76 | expected = f.read().strip() 77 | print(f"\n📋 Expected: '{expected}'") 78 | 79 | if final_text.strip() == expected: 80 | print("\n✅ PERFECT MATCH!") 81 | else: 82 | print("\n❌ MISMATCH!") 83 | else: 84 | print("\n❌ No results produced!") 85 | 86 | print("\n" + "="*80) 87 | print("ESPNET FULL PIPELINE TEST COMPLETE") 88 | print("="*80) 89 | -------------------------------------------------------------------------------- /tests/test_beam_search_trace.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Trace beam search evolution block by block to see why token 738 disappears.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | print("="*80) 11 | print("BEAM SEARCH TRACE (Block-by-Block)") 12 | print("="*80) 13 | 14 | # Load audio 15 | from speechcatcher.speechcatcher import convert_inputfile 16 | 17 | os.makedirs('.tmp/', exist_ok=True) 18 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 19 | if not os.path.exists(wavfile_path): 20 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 21 | 22 | with wave.open(wavfile_path, 'rb') as f: 23 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 24 | speech = raw_audio.astype(np.float32) / 32768.0 25 | 26 | # Load ESPnet token list 27 | import json 28 | with open("/tmp/espnet_token_list.json", "r") as f: 29 | espnet_token_list = json.load(f) 30 | 31 | # Load our model - disable DEBUG logging for cleaner output 32 | print("\n[1] Loading model...") 33 | from speechcatcher.speechcatcher import load_model, tags 34 | 35 | s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 36 | s2t.reset() 37 | print("✅ Model loaded") 38 | 39 | # Process chunk by chunk and print beam state after each 40 | print("\n[2] Processing chunks with beam search trace...") 41 | chunk_size = 8000 42 | 43 | for chunk_idx in range(6): 44 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 45 | is_final = False 46 | 47 | print(f"\n{'='*80}") 48 | print(f"CHUNK {chunk_idx + 1}") 49 | print('='*80) 50 | 51 | results = s2t(chunk, is_final=is_final) 52 | 53 | # Print beam search state 54 | if hasattr(s2t, 'beam_state') and s2t.beam_state is not None: 55 | print(f"\n--- BEAM STATE ---") 56 | print(f"Output index: {s2t.beam_state.output_index}") 57 | print(f"Processed frames: 
{s2t.beam_state.processed_frames}") 58 | print(f"Number of hypotheses: {len(s2t.beam_state.hypotheses)}") 59 | 60 | if hasattr(s2t, 'beam_search'): 61 | print(f"Processed blocks: {s2t.beam_search.processed_block}") 62 | if s2t.beam_search.encoder_buffer is not None: 63 | print(f"Encoder buffer shape: {s2t.beam_search.encoder_buffer.shape}") 64 | 65 | print(f"\nTop 3 hypotheses:") 66 | for i, hyp in enumerate(s2t.beam_state.hypotheses[:3]): 67 | # Decode tokens 68 | token_ids = hyp.yseq.tolist() if torch.is_tensor(hyp.yseq) else hyp.yseq 69 | tokens = [espnet_token_list[tid] for tid in token_ids] 70 | tokens_str = " ".join(f"{tid}:{tok}" for tid, tok in zip(token_ids, tokens)) 71 | 72 | print(f" [{i+1}] score={hyp.score:.4f}") 73 | print(f" yseq={token_ids}") 74 | print(f" tokens={tokens_str}") 75 | 76 | # Print results 77 | if results and results[0]: 78 | text, tokens, token_ids = results[0] 79 | print(f"\n--- OUTPUT ---") 80 | print(f"Text: '{text}'") 81 | print(f"Token IDs: {token_ids}") 82 | print(f"Tokens: {tokens}") 83 | else: 84 | print(f"\n--- OUTPUT ---") 85 | print("(no output)") 86 | 87 | print("\n" + "="*80) 88 | print("TRACE COMPLETE") 89 | print("="*80) 90 | -------------------------------------------------------------------------------- /docs/implementation/weight-loading.md: -------------------------------------------------------------------------------- 1 | # Weight Loading Compatibility Notes 2 | 3 | ## Model Configuration (from config.yaml) 4 | 5 | **Model Type:** `de_streaming_transformer_xl` 6 | **Encoder:** `contextual_block_transformer` (NOT Conformer for this model!) 7 | **Decoder:** `transformer` 8 | 9 | ### Encoder Config 10 | ```yaml 11 | encoder: contextual_block_transformer 12 | encoder_conf: 13 | attention_dropout_rate: 0.0 14 | attention_heads: 8 15 | block_size: 40 16 | ctx_pos_enc: true 17 | dropout_rate: 0.1 18 | hop_size: 16 19 | init_average: true 20 | input_layer: conv2d 21 | linear_units: 2048 22 | look_ahead: 16 23 | normalize_before: true 24 | num_blocks: 30 25 | output_size: 256 26 | positional_dropout_rate: 0.1 27 | ``` 28 | 29 | ### Decoder Config 30 | ```yaml 31 | decoder: transformer 32 | decoder_conf: 33 | attention_heads: 8 34 | dropout_rate: 0.1 35 | linear_units: 2048 36 | num_blocks: 14 37 | positional_dropout_rate: 0.1 38 | self_attention_dropout_rate: 0.0 39 | src_attention_dropout_rate: 0.0 40 | ``` 41 | 42 | ### Frontend Config 43 | ```yaml 44 | frontend: default 45 | frontend_conf: 46 | fs: 16k 47 | hop_length: 160 # Different from our default 128! 48 | n_fft: 512 49 | win_length: 400 # Different from our default 512! 50 | ``` 51 | 52 | ### Model Config 53 | ```yaml 54 | model_conf: 55 | ctc_weight: 0.3 56 | length_normalized_loss: false 57 | lsm_weight: 0.1 58 | ``` 59 | 60 | ### Normalization 61 | ```yaml 62 | normalize: global_mvn 63 | normalize_conf: 64 | stats_file: .../feats_stats.npz 65 | ``` 66 | 67 | ### Token List 68 | - vocab_size: 1182 (including , , ) 69 | - token_type: bpe 70 | - bpemodel path available in config 71 | 72 | ## Critical Compatibility Requirements 73 | 74 | 1. **Frontend parameters must match exactly:** 75 | - n_fft: 512 ✓ 76 | - hop_length: **160** (not 128!) 77 | - win_length: **400** (not 512!) 78 | 79 | 2. **Encoder is Transformer, NOT Conformer** 80 | - Must implement `ContextualBlockTransformerEncoder` 81 | - No ConvolutionModule needed for this model 82 | 83 | 3. 
**Weight loading:** 84 | - Checkpoint path: `~/.cache/espnet/.../valid.acc.ave_6best.pth` 85 | - Must map layer names correctly from ESPnet to our implementation 86 | - Preserve exact parameter names for compatibility 87 | 88 | 4. **Normalization stats:** 89 | - Load `feats_stats.npz` for global mean/variance normalization 90 | - Apply before encoder 91 | 92 | ## Implementation Priority 93 | 94 | Given the model is **Transformer-based** (not Conformer): 95 | 96 | 1. ✅ Phase 0: Layers & Attention (DONE) 97 | 2. → **Phase 1a:** Implement ContextualBlockTransformerEncoder (PRIORITY) 98 | 3. → **Phase 1b:** Implement weight loading utilities 99 | 4. → Phase 2: Implement TransformerDecoder 100 | 5. → Phase 3: Implement ESPnetASRModel wrapper 101 | 6. → Phase 4: Test weight loading with real checkpoint 102 | 7. → Phase 5: Beam search & scorers 103 | 8. → Phase 6: Speech2TextStreaming API 104 | 9. → Phase 7: End-to-end test with Neujahrsansprache.mp4 105 | 106 | ## Weight Name Mapping (ESPnet → Our Implementation) 107 | 108 | Will need to verify and document exact mapping, e.g.: 109 | - `encoder.encoders.0.self_attn.linear_q.weight` → `encoder.layers[0].self_attn.linear_q.weight` 110 | - etc. 111 | -------------------------------------------------------------------------------- /tests/test_espnet_beam_search_trace.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Trace ESPnet's beam search to see what it does with token 1023.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | print("="*80) 11 | print("ESPNET BEAM SEARCH TRACE (Why does it output 'liebe'?)") 12 | print("="*80) 13 | 14 | # Load audio 15 | from speechcatcher.speechcatcher import convert_inputfile 16 | 17 | os.makedirs('.tmp/', exist_ok=True) 18 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 19 | if not os.path.exists(wavfile_path): 20 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 21 | 22 | with wave.open(wavfile_path, 'rb') as wavfile_in: 23 | buf = wavfile_in.readframes(-1) 24 | raw_audio = np.frombuffer(buf, dtype='int16') 25 | 26 | speech = raw_audio.astype(np.float32) / 32768.0 27 | 28 | # Load ESPnet 29 | print("\n[1] Loading ESPnet...") 30 | from espnet_streaming_decoder.asr_inference_streaming import Speech2TextStreaming as ESPnetStreaming 31 | 32 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 33 | model_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth" 34 | 35 | espnet_s2t = ESPnetStreaming( 36 | asr_train_config=config_path, 37 | asr_model_file=model_path, 38 | device="cpu", 39 | beam_size=5, 40 | ctc_weight=0.3, 41 | ) 42 | espnet_s2t.reset() 43 | print("✅ ESPnet loaded") 44 | 45 | # Process chunks 1-5 46 | print("\n[2] Processing chunks 1-5...") 47 | chunk_size = 8000 48 | 49 | for chunk_idx in range(5): 50 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 51 | is_final = False 52 | 53 | results = espnet_s2t(chunk, is_final=is_final) 54 | 55 | if 
chunk_idx == 4: 56 | print(f"\nChunk {chunk_idx+1} results:") 57 | if results and len(results) > 0: 58 | print(f" Results: {results}") 59 | # ESPnet returns different format 60 | result = results[0] 61 | print(f" Result type: {type(result)}") 62 | print(f" Result: {result}") 63 | else: 64 | print(f" (no results yet)") 65 | 66 | # Check beam search state 67 | print(f"\nBeam search state:") 68 | print(f" processed_block: {espnet_s2t.beam_search.processed_block}") 69 | if hasattr(espnet_s2t.beam_search, 'encbuffer'): 70 | print(f" encbuffer shape: {espnet_s2t.beam_search.encbuffer.shape if espnet_s2t.beam_search.encbuffer is not None else 'None'}") 71 | if hasattr(espnet_s2t.beam_search, 'running_hyps'): 72 | print(f" running_hyps: {len(espnet_s2t.beam_search.running_hyps)} hypotheses") 73 | # Show hypotheses 74 | for i, hyp in enumerate(espnet_s2t.beam_search.running_hyps[:3]): 75 | print(f" Hyp {i}: yseq={hyp.yseq.tolist()[:15]}, score={hyp.score:.4f}") 76 | 77 | print("\n" + "="*80) 78 | print("TEST COMPLETE") 79 | print("="*80) 80 | -------------------------------------------------------------------------------- /tests/test_full_comparison.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Full transcription comparison between ESPnet and ours.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | from speechcatcher.speechcatcher import convert_inputfile 11 | 12 | os.makedirs('.tmp/', exist_ok=True) 13 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 14 | if not os.path.exists(wavfile_path): 15 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 16 | 17 | with wave.open(wavfile_path, 'rb') as f: 18 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 19 | speech = raw_audio.astype(np.float32) / 32768.0 20 | 21 | # Load ESPnet 22 | from espnet_streaming_decoder.asr_inference_streaming import Speech2TextStreaming as ESPnetStreaming 23 | 24 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 25 | model_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth" 26 | 27 | print("="*80) 28 | print("FULL TRANSCRIPTION COMPARISON") 29 | print("="*80) 30 | 31 | print("\n[ESPnet]") 32 | espnet_s2t = ESPnetStreaming( 33 | asr_train_config=config_path, 34 | asr_model_file=model_path, 35 | device="cpu", 36 | beam_size=5, 37 | ctc_weight=0.3, 38 | ) 39 | espnet_s2t.reset() 40 | 41 | chunk_size = 8000 42 | espnet_outputs = [] 43 | for chunk_idx in range(len(speech) // chunk_size + 1): 44 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 45 | if len(chunk) == 0: 46 | break 47 | 48 | is_final = (chunk_idx == len(speech) // chunk_size) 49 | result = espnet_s2t(chunk, is_final=is_final) 50 | if result: 51 | espnet_outputs.append((chunk_idx+1, result[0][0])) 52 | print(f" Chunk {chunk_idx+1}: '{result[0][0]}'") 53 | 54 | # Load ours 55 | print("\n[Ours]") 56 | from speechcatcher.speechcatcher import load_model, tags 57 | 58 | our_s2t = 
load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 59 | our_s2t.reset() 60 | 61 | our_outputs = [] 62 | for chunk_idx in range(len(speech) // chunk_size + 1): 63 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 64 | if len(chunk) == 0: 65 | break 66 | 67 | is_final = (chunk_idx == len(speech) // chunk_size) 68 | result = our_s2t(chunk, is_final=is_final) 69 | if result: 70 | our_outputs.append((chunk_idx+1, result[0][0])) 71 | print(f" Chunk {chunk_idx+1}: '{result[0][0]}'") 72 | 73 | print("\n" + "="*80) 74 | print("COMPARISON") 75 | print("="*80) 76 | 77 | print(f"\nESPnet: {' | '.join([f'C{c}: {t}' for c, t in espnet_outputs])}") 78 | print(f"Ours: {' | '.join([f'C{c}: {t}' for c, t in our_outputs])}") 79 | 80 | if len(espnet_outputs) > 0 and len(our_outputs) > 0: 81 | espnet_final = espnet_outputs[-1][1] 82 | our_final = our_outputs[-1][1] 83 | 84 | print(f"\nFinal outputs:") 85 | print(f" ESPnet: '{espnet_final}'") 86 | print(f" Ours: '{our_final}'") 87 | print(f" Match: {'✅' if espnet_final == our_final else '❌'}") 88 | 89 | print("\n" + "="*80) 90 | -------------------------------------------------------------------------------- /tests/test_compare_beam_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Compare beam search configuration between ESPnet and ours.""" 3 | 4 | import torch 5 | import yaml 6 | 7 | print("="*80) 8 | print("BEAM SEARCH CONFIGURATION COMPARISON") 9 | print("="*80) 10 | 11 | # Load ESPnet config 12 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 13 | 14 | with open(config_path, 'r') as f: 15 | config = yaml.safe_load(f) 16 | 17 | print("\n[ESPnet Config]") 18 | print("\nBeam Search Parameters:") 19 | for key in ['beam_size', 'ctc_weight', 'penalty', 'maxlenratio', 'minlenratio']: 20 | if key in config: 21 | print(f" {key}: {config[key]}") 22 | 23 | print("\nStreaming Parameters:") 24 | for key in ['streaming', 'block_size', 'hop_size', 'look_ahead']: 25 | if key in config.get('encoder_conf', {}): 26 | print(f" encoder.{key}: {config['encoder_conf'][key]}") 27 | 28 | print("\nDecoder Parameters:") 29 | decoder_conf = config.get('decoder_conf', {}) 30 | for key in ['attention_heads', 'linear_units', 'num_blocks', 'dropout_rate']: 31 | if key in decoder_conf: 32 | print(f" decoder.{key}: {decoder_conf[key]}") 33 | 34 | # Load our implementation 35 | print("\n" + "="*80) 36 | print("[Our Implementation]") 37 | print("="*80) 38 | 39 | from speechcatcher.speechcatcher import load_model, tags 40 | 41 | s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 42 | 43 | print("\nBeam Search Parameters:") 44 | if hasattr(s2t, 'beam_search'): 45 | bs = s2t.beam_search 46 | print(f" beam_size: {bs.beam_size}") 47 | print(f" vocab_size: {bs.vocab_size}") 48 | print(f" sos_id: {bs.sos_id}") 49 | print(f" eos_id: {bs.eos_id}") 50 | print(f" block_size: {bs.block_size}") 51 | print(f" hop_size: {bs.hop_size}") 52 | print(f" look_ahead: {bs.look_ahead}") 53 | print(f" use_bbd: {bs.use_bbd}") 54 | print(f" bbd_conservative: {bs.bbd_conservative}") 55 | 56 | print("\nScorer Weights:") 57 | for name, weight in bs.weights.items(): 58 | print(f" {name}: {weight}") 59 | 60 | print("\n" + "="*80) 61 | 
print("COMPARISON") 62 | print("="*80) 63 | 64 | # Compare key parameters 65 | espnet_block_size = config.get('encoder_conf', {}).get('block_size', None) 66 | espnet_hop_size = config.get('encoder_conf', {}).get('hop_size', None) 67 | espnet_look_ahead = config.get('encoder_conf', {}).get('look_ahead', None) 68 | 69 | print("\nBlock Processing:") 70 | print(f" ESPnet block_size: {espnet_block_size}") 71 | print(f" Ours block_size: {bs.block_size}") 72 | print(f" Match: {'✅' if espnet_block_size == bs.block_size else '❌'}") 73 | 74 | print(f"\n ESPnet hop_size: {espnet_hop_size}") 75 | print(f" Ours hop_size: {bs.hop_size}") 76 | print(f" Match: {'✅' if espnet_hop_size == bs.hop_size else '❌'}") 77 | 78 | print(f"\n ESPnet look_ahead: {espnet_look_ahead}") 79 | print(f" Ours look_ahead: {bs.look_ahead}") 80 | print(f" Match: {'✅' if espnet_look_ahead == bs.look_ahead else '❌'}") 81 | 82 | espnet_ctc_weight = config.get('ctc_weight', 0.0) 83 | our_ctc_weight = bs.weights.get('ctc', 0.0) 84 | 85 | print(f"\nCTC Weight:") 86 | print(f" ESPnet: {espnet_ctc_weight}") 87 | print(f" Ours: {our_ctc_weight}") 88 | print(f" Match: {'✅' if abs(espnet_ctc_weight - our_ctc_weight) < 0.001 else '❌'}") 89 | 90 | print("\n" + "="*80) 91 | -------------------------------------------------------------------------------- /speechcatcher/vosk_test_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import asyncio 4 | import websockets 5 | import argparse 6 | import wave 7 | import subprocess 8 | import sys 9 | import io 10 | 11 | # Function to convert audio using ffmpeg on the fly 12 | def convert_audio(input_file, sample_rate, channels, bit_depth): 13 | ffmpeg_command = [ 14 | 'ffmpeg', 15 | '-i', input_file, 16 | '-ar', str(sample_rate), 17 | '-ac', str(channels), 18 | '-f', 'wav', 19 | '-acodec', 'pcm_s16le', # 16-bit signed little-endian 20 | '-' 21 | ] 22 | process = subprocess.Popen(ffmpeg_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 23 | return process.stdout 24 | 25 | # Function to check if wav file matches the desired format 26 | def is_wav_compatible(wav_file, sample_rate, channels, bit_depth): 27 | with wave.open(wav_file, 'rb') as wf: 28 | return (wf.getframerate() == sample_rate and 29 | wf.getnchannels() == channels and 30 | wf.getsampwidth() == bit_depth // 8) 31 | 32 | # Function to process and send audio to ASR websocket server 33 | async def process_audio(websocket, audio_stream, sample_rate, is_wave): 34 | await websocket.send(f'{{ "config" : {{ "sample_rate" : {sample_rate} }} }}') 35 | 36 | buffer_size = int(sample_rate * 0.2) # 0.2 seconds of audio 37 | while True: 38 | if is_wave: 39 | data = audio_stream.readframes(buffer_size) # For wave files 40 | else: 41 | data = audio_stream.read(buffer_size) # For ffmpeg stream 42 | 43 | if len(data) == 0: 44 | break 45 | await websocket.send(data) 46 | print(await websocket.recv()) 47 | 48 | await websocket.send('{"eof" : 1}') 49 | print(await websocket.recv()) 50 | 51 | async def run_test(uri, input_file, sample_rate, channels, bit_depth): 52 | async with websockets.connect(uri) as websocket: 53 | if input_file.endswith('.wav') and is_wav_compatible(input_file, sample_rate, channels, bit_depth): 54 | print(f"Sending {input_file} directly as it is already in the correct format.") 55 | with wave.open(input_file, 'rb') as wf: 56 | await process_audio(websocket, wf, sample_rate, is_wave=True) 57 | else: 58 | print(f"Converting {input_file} using ffmpeg before 
sending.") 59 | audio_stream = convert_audio(input_file, sample_rate, channels, bit_depth) 60 | await process_audio(websocket, io.BufferedReader(audio_stream), sample_rate, is_wave=False) 61 | 62 | def main(): 63 | parser = argparse.ArgumentParser(description="Speechcatcher's Vosk-API WebSocket ASR test client with audio conversion") 64 | parser.add_argument('input_file', help='Path to the input audio file') 65 | parser.add_argument('--port', type=int, default=2700, help='WebSocket server port (default: 2700)') 66 | parser.add_argument('--host', default='localhost', help='WebSocket server host (default: localhost)') 67 | parser.add_argument('--sample-rate', type=int, default=16000, help='Sample rate in Hz (default: 16000)') 68 | parser.add_argument('--channels', type=int, default=1, help='Number of audio channels (default: 1)') 69 | parser.add_argument('--bit-depth', type=int, default=16, help='Bit depth (default: 16)') 70 | 71 | args = parser.parse_args() 72 | 73 | uri = f'ws://{args.host}:{args.port}' 74 | asyncio.run(run_test(uri, args.input_file, args.sample_rate, args.channels, args.bit_depth)) 75 | 76 | if __name__ == "__main__": 77 | main() 78 | 79 | -------------------------------------------------------------------------------- /tests/test_espnet_transcribe.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test what ESPnet actually transcribes for our sample audio.""" 3 | 4 | import logging 5 | import torch 6 | import numpy as np 7 | import wave 8 | import hashlib 9 | import os 10 | 11 | logging.basicConfig(level=logging.INFO, format='%(message)s') 12 | logger = logging.getLogger(__name__) 13 | 14 | logger.info("="*80) 15 | logger.info("ESPnet Full Transcription Test") 16 | logger.info("="*80) 17 | 18 | # Load audio 19 | from speechcatcher.speechcatcher import convert_inputfile 20 | 21 | os.makedirs('.tmp/', exist_ok=True) 22 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 23 | if not os.path.exists(wavfile_path): 24 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 25 | 26 | with wave.open(wavfile_path, 'rb') as wavfile_in: 27 | buf = wavfile_in.readframes(-1) 28 | raw_audio = np.frombuffer(buf, dtype='int16') 29 | 30 | waveform = raw_audio.astype(np.float32) 31 | logger.info(f"Audio loaded: {waveform.shape}") 32 | 33 | # Load ESPnet model 34 | from espnet2.bin.asr_inference_streaming import Speech2TextStreaming as ESPnetS2T 35 | 36 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 37 | model_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth" 38 | 39 | espnet_model = ESPnetS2T( 40 | asr_train_config=config_path, 41 | asr_model_file=model_path, 42 | device="cpu", 43 | ) 44 | 45 | logger.info("✅ ESPnet model loaded") 46 | 47 | # Transcribe 48 | logger.info("\n" + "="*80) 49 | logger.info("Running ESPnet transcription...") 50 | logger.info("="*80) 51 | 52 | results = espnet_model(waveform) 53 | 54 | logger.info(f"\nESPnet transcription result:") 55 | logger.info(f"Result type: {type(results)}") 56 | if 
isinstance(results, list) and len(results) > 0: 57 | for i, result in enumerate(results): 58 | logger.info(f"\nResult {i+1}:") 59 | logger.info(f" Type: {type(result)}") 60 | logger.info(f" Length: {len(result) if isinstance(result, (list, tuple)) else 'N/A'}") 61 | 62 | if isinstance(result, (list, tuple)): 63 | logger.info(f" Contents: {result}") 64 | # Try to extract text and token_ids 65 | if len(result) >= 2: 66 | text = result[0] if isinstance(result[0], str) else str(result[0]) 67 | token_ids = result[1] if isinstance(result[1], list) else [] 68 | logger.info(f" Text: '{text}'") 69 | if token_ids: 70 | logger.info(f" Token IDs ({len(token_ids)}): {token_ids[:20]}...") 71 | 72 | # Check if token 1023 appears 73 | if 1023 in token_ids: 74 | count = token_ids.count(1023) 75 | positions = [j for j, x in enumerate(token_ids) if x == 1023] 76 | logger.warning(f" ⚠️ Token 1023 appears {count} times at positions: {positions[:10]}") 77 | else: 78 | logger.info(f" ✅ Token 1023 does NOT appear in output") 79 | else: 80 | logger.info(f" {result}") 81 | else: 82 | logger.info(f" {results}") 83 | 84 | logger.info("\n" + "="*80) 85 | logger.info("ESPnet Transcription Test Complete") 86 | logger.info("="*80) 87 | -------------------------------------------------------------------------------- /tests/test_espnet_final_chunk.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Check what ESPnet does with the final chunk.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | from speechcatcher.speechcatcher import convert_inputfile 11 | 12 | os.makedirs('.tmp/', exist_ok=True) 13 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 14 | if not os.path.exists(wavfile_path): 15 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 16 | 17 | with wave.open(wavfile_path, 'rb') as f: 18 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 19 | speech = raw_audio.astype(np.float32) / 32768.0 20 | 21 | from espnet_streaming_decoder.asr_inference_streaming import Speech2TextStreaming as ESPnetStreaming 22 | 23 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 24 | model_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth" 25 | 26 | espnet_s2t = ESPnetStreaming( 27 | asr_train_config=config_path, 28 | asr_model_file=model_path, 29 | device="cpu", 30 | beam_size=5, 31 | ctc_weight=0.3, 32 | ) 33 | espnet_s2t.reset() 34 | 35 | chunk_size = 8000 36 | 37 | print("="*80) 38 | print("ESPnet Final Chunk Processing") 39 | print("="*80) 40 | 41 | # Process all chunks except the last 42 | for chunk_idx in range(len(speech) // chunk_size): 43 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 44 | result = espnet_s2t(chunk, is_final=False) 45 | if result: 46 | print(f"Chunk {chunk_idx+1}: '{result[0][0]}'") 47 | 48 | # Check beam state before final 49 | if hasattr(espnet_s2t, 'beam_search'): 50 | bs = espnet_s2t.beam_search 51 | print(f"\n{'='*80}") 52 | print("Beam 
state BEFORE final chunk:") 53 | print(f"{'='*80}") 54 | if hasattr(bs, 'running_hyps') and bs.running_hyps: 55 | print(f"Running hypotheses: {len(bs.running_hyps)}") 56 | for i, hyp in enumerate(bs.running_hyps[:3]): 57 | if torch.is_tensor(hyp): 58 | print(f" [{i+1}] yseq={hyp.tolist()[:10]}") 59 | else: 60 | print(f" [{i+1}] score={hyp.score:.4f}, yseq={hyp.yseq.tolist()[:10]}") 61 | else: 62 | print("No running hypotheses") 63 | 64 | # Process final chunk 65 | final_chunk_idx = len(speech) // chunk_size 66 | chunk = speech[final_chunk_idx*chunk_size:] 67 | print(f"\nProcessing final chunk {final_chunk_idx+1} (is_final=True)...") 68 | result = espnet_s2t(chunk, is_final=True) 69 | 70 | # Check beam state after final 71 | if hasattr(espnet_s2t, 'beam_search'): 72 | bs = espnet_s2t.beam_search 73 | print(f"\n{'='*80}") 74 | print("Beam state AFTER final chunk:") 75 | print(f"{'='*80}") 76 | if hasattr(bs, 'running_hyps') and bs.running_hyps: 77 | print(f"Running hypotheses: {len(bs.running_hyps)}") 78 | for i, hyp in enumerate(bs.running_hyps[:3]): 79 | if torch.is_tensor(hyp): 80 | print(f" [{i+1}] yseq={hyp.tolist()[:20]}") 81 | else: 82 | print(f" [{i+1}] score={hyp.score:.4f}, yseq={hyp.yseq.tolist()[:20]}") 83 | 84 | if result: 85 | print(f"\nFinal output: '{result[0][0]}'") 86 | print(f"Token IDs: {result[0][2]}") 87 | print(f"yseq from result: {result[0][4].yseq.tolist()}") 88 | 89 | print("\n" + "="*80) 90 | -------------------------------------------------------------------------------- /speechcatcher/compute_wer.py: -------------------------------------------------------------------------------- 1 | import jiwer 2 | from somajo import SoMaJo 3 | from jiwer import cer, wer, wil 4 | import argparse 5 | 6 | punctuation = '.,!?;:-_"\'' 7 | 8 | def to_word_list(tokenizer, paragraphs, remove_punctuation=False): 9 | 10 | paragraph_ids = list(sorted(paragraphs.keys())) 11 | paragraphs = [paragraphs[key] for key in paragraph_ids] 12 | 13 | if remove_punctuation: 14 | paragraph_tokens = [[token.text for token in sentence if token.text not in punctuation] for sentence in tokenizer.tokenize_text(paragraphs)] 15 | else: 16 | paragraph_tokens = [[token.text for token in sentence] for sentence in tokenizer.tokenize_text(paragraphs)] 17 | 18 | 19 | assert(len(paragraph_ids) == len(paragraph_tokens)) 20 | 21 | return dict(zip(paragraph_ids, paragraph_tokens)) 22 | 23 | def calculate_word_error_rate(kaldi_test_text_file, decoded_text_file, language="de_CMC", remove_punctuation=False, split_camel_case=True, lower_case=False): 24 | tokenizer = SoMaJo(language=language, split_camel_case=split_camel_case, split_sentences=False) 25 | 26 | test_utterances = {} 27 | decoded_utterances = {} 28 | 29 | with open(kaldi_test_text_file) as kaldi_test_text_file_in: 30 | for line in kaldi_test_text_file_in: 31 | utterance_id, transcription = line.strip().split(' ', 1) 32 | test_utterances[utterance_id] = transcription 33 | 34 | with open(decoded_text_file) as decoded_text_file_in: 35 | for line in decoded_text_file_in: 36 | utterance_id, transcription = line.strip().split(' ', 1) 37 | decoded_utterances[utterance_id] = transcription 38 | 39 | test_utterances = to_word_list(tokenizer, test_utterances, remove_punctuation) 40 | decoded_utterances = to_word_list(tokenizer, decoded_utterances, remove_punctuation) 41 | decoded_utterances_ids = list(sorted(decoded_utterances.keys())) 42 | 43 | if lower_case: 44 | test_utterances_list = [' '.join(test_utterances[utt_id]).lower() for utt_id in decoded_utterances_ids] 45 
| decoded_utterances_list = [' '.join(decoded_utterances[utt_id]).lower() for utt_id in decoded_utterances_ids] 46 | else: 47 | test_utterances_list = [' '.join(test_utterances[utt_id]) for utt_id in decoded_utterances_ids] 48 | decoded_utterances_list = [' '.join(decoded_utterances[utt_id]) for utt_id in decoded_utterances_ids] 49 | 50 | 51 | print('CER:', cer(test_utterances_list, decoded_utterances_list)) 52 | print('WER:', wer(test_utterances_list, decoded_utterances_list)) 53 | print('WIL:', wil(test_utterances_list, decoded_utterances_list)) 54 | 55 | def main(): 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('--remove_punctuation', dest='remove_punctuation', 58 | help='Remove punctuation before calculating metrics', action='store_true') 59 | parser.add_argument('--lower_case', dest='lower_case', 60 | help='Lower case all words before calculating metrics', action='store_true') 61 | parser.add_argument('--language', dest='language', help='Language for the SoMaJo tokenizer. ' 62 | 'Necessary for WER computation. Default: de_CMC', 63 | default='de_CMC') 64 | 65 | 66 | parser.add_argument('testset_text', type=str, help='The test file containing the test set utterances') 67 | parser.add_argument('decoded_text', type=str, help='The decoded file containing the decoded utterances') 68 | 69 | args = parser.parse_args() 70 | 71 | calculate_word_error_rate(args.testset_text, args.decoded_text, language=args.language, remove_punctuation=args.remove_punctuation, lower_case=args.lower_case) 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /speechcatcher/model/layers/convolution.py: -------------------------------------------------------------------------------- 1 | """Convolution module for Conformer.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class Swish(nn.Module): 9 | """Swish activation function: x * sigmoid(x).""" 10 | 11 | def forward(self, x: torch.Tensor) -> torch.Tensor: 12 | return x * torch.sigmoid(x) 13 | 14 | 15 | class ConvolutionModule(nn.Module): 16 | """Convolution module for Conformer. 17 | 18 | This implements the convolution module from "Conformer: Convolution-augmented 19 | Transformer for Speech Recognition" (Gulati et al., 2020). 
20 | 21 | Architecture: 22 | LayerNorm -> Pointwise Conv (expansion) -> GLU -> Depthwise Conv 23 | -> BatchNorm -> Swish -> Pointwise Conv (projection) -> Dropout 24 | 25 | Args: 26 | channels: Number of input/output channels 27 | kernel_size: Kernel size for depthwise convolution (default: 31) 28 | dropout_rate: Dropout rate (default: 0.1) 29 | bias: Whether to use bias in convolutions (default: True) 30 | 31 | Shape: 32 | - Input: (batch, time, channels) 33 | - Output: (batch, time, channels) 34 | """ 35 | 36 | def __init__( 37 | self, 38 | channels: int, 39 | kernel_size: int = 31, 40 | dropout_rate: float = 0.1, 41 | bias: bool = True, 42 | ): 43 | super().__init__() 44 | assert kernel_size % 2 == 1, "Kernel size must be odd for 'same' padding" 45 | 46 | self.layernorm = nn.LayerNorm(channels) 47 | 48 | # Pointwise expansion (2x channels for GLU) 49 | self.pointwise_conv1 = nn.Conv1d( 50 | channels, 51 | 2 * channels, 52 | kernel_size=1, 53 | stride=1, 54 | padding=0, 55 | bias=bias, 56 | ) 57 | 58 | # Depthwise convolution 59 | self.depthwise_conv = nn.Conv1d( 60 | channels, 61 | channels, 62 | kernel_size=kernel_size, 63 | stride=1, 64 | padding=(kernel_size - 1) // 2, 65 | groups=channels, # Depthwise 66 | bias=bias, 67 | ) 68 | 69 | self.batch_norm = nn.BatchNorm1d(channels) 70 | self.activation = Swish() 71 | 72 | # Pointwise projection 73 | self.pointwise_conv2 = nn.Conv1d( 74 | channels, 75 | channels, 76 | kernel_size=1, 77 | stride=1, 78 | padding=0, 79 | bias=bias, 80 | ) 81 | 82 | self.dropout = nn.Dropout(dropout_rate) 83 | 84 | def forward(self, x: torch.Tensor) -> torch.Tensor: 85 | """Forward pass. 86 | 87 | Args: 88 | x: Input tensor (batch, time, channels) 89 | 90 | Returns: 91 | Output tensor (batch, time, channels) 92 | """ 93 | # LayerNorm 94 | x = self.layernorm(x) 95 | 96 | # Transpose to (batch, channels, time) for Conv1d 97 | x = x.transpose(1, 2) 98 | 99 | # Pointwise expansion 100 | x = self.pointwise_conv1(x) 101 | 102 | # GLU (Gated Linear Unit): split into two halves and apply gate 103 | x_a, x_b = x.chunk(2, dim=1) 104 | x = x_a * torch.sigmoid(x_b) 105 | 106 | # Depthwise convolution 107 | x = self.depthwise_conv(x) 108 | 109 | # BatchNorm + Swish 110 | x = self.batch_norm(x) 111 | x = self.activation(x) 112 | 113 | # Pointwise projection 114 | x = self.pointwise_conv2(x) 115 | 116 | # Transpose back to (batch, time, channels) 117 | x = x.transpose(1, 2) 118 | 119 | # Dropout 120 | return self.dropout(x) 121 | -------------------------------------------------------------------------------- /docs/analysis/initial-comparison.md: -------------------------------------------------------------------------------- 1 | # Segment-by-Segment Comparison: Native vs ESPnet Decoder 2 | 3 | ## Summary Statistics 4 | 5 | | Segment | Duration | Native Words | ESPnet Words | Difference | Native % | Status | 6 | |---------|----------|--------------|--------------|------------|----------|---------| 7 | | 0 | 60.0s | 70 | 105 | -35 | 66.7% | ⚠️ Moderate loss | 8 | | 1 | 60.0s | 37 | 92 | -55 | 40.2% | 🔴 Severe loss + repetition | 9 | | 2 | 60.0s | 4 | 96 | -92 | 4.2% | 🔴 CRITICAL failure | 10 | | 3 | 60.0s | 6 | 111 | -105 | 5.4% | 🔴 CRITICAL failure | 11 | | 4 | 60.0s | 5 | 5 | 0 | 100.0% | ✅ Identical | 12 | | 5 | 60.0s | 53 | 53 | 0 | 100.0% | ✅ Identical | 13 | | 6 | 60.0s | 78 | 78 | 0 | 100.0% | ✅ Identical | 14 | | 7 | 65.0s | 4 | 4 | 0 | 100.0% | ✅ Identical | 15 | | **TOTAL** | **485.0s** | **257** | **544** | **-287** | **47.2%** | - | 16 | 17 | ## Critical Findings 
18 | 19 | ### 1. **Bimodal Performance Pattern** 20 | - **Segments 0-3** (first 4 minutes): Native achieves only 29.0% of ESPnet output (117 vs 404 words) 21 | - **Segments 4-7** (last 4 minutes): Native matches ESPnet 100% (140 vs 140 words) 22 | 23 | ### 2. **Worst Offenders** 24 | 25 | #### Segment 2 (4.2% accuracy) - WORST PERFORMANCE 26 | **ESPnet (96 words):** Full coherent transcription about university, administration, libraries, past vs future, flexibility, innovation, etc. 27 | 28 | **Native (4 words):** `und mussten sich, äh,` 29 | 30 | **Issue:** Decoder produced almost nothing, with special tokens appearing in output 31 | 32 | #### Segment 3 (5.4% accuracy) - SECOND WORST 33 | **ESPnet (111 words):** Full transcription about fear of the future, virus, vigilance vs anxiety, etc. 34 | 35 | **Native (6 words):** `andere. Ja, das ist so.` 36 | 37 | **Issue:** Decoder produced trivial output, completely missing the content 38 | 39 | #### Segment 1 (40.2% accuracy) - REPETITION FAILURE 40 | **ESPnet (92 words):** Full transcription about current situation, temporal phases, university past, etc. 41 | 42 | **Native (37 words):** Contains massive repetition: `Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Under. Und` 43 | 44 | **Issue:** BBD failed completely, allowing massive token repetition 45 | 46 | ### 3. **Perfect Segments (4-7)** 47 | All four segments in the second half produced **identical output** between native and ESPnet decoders, suggesting: 48 | - No BBD issues in second half 49 | - No state management issues 50 | - Identical decoder behavior 51 | 52 | ## Hypothesis 53 | 54 | The dramatic difference between first half (0-3) and second half (4-7) suggests: 55 | 56 | 1. **NOT a temporal embedding issue** - If it were temporal embeddings, we'd expect degradation after 180s (3 minutes), but problems START immediately and then IMPROVE after 4 minutes 57 | 58 | 2. **Possible audio content dependency** - The decoder may handle certain acoustic conditions or speech patterns differently 59 | 60 | 3. **BBD sensitivity to speech content** - The first 4 minutes may contain speech patterns that trigger BBD inappropriately, while the second half does not 61 | 62 | 4. **Encoder buffer issues** - First segments may have problematic encoder state handling that resolves by segment 4 63 | 64 | ## Next Steps 65 | 66 | **Priority 1:** Investigate segments 2 and 3 (worst offenders with <6% accuracy) 67 | **Priority 2:** Investigate segment 1 (BBD repetition failure) 68 | **Priority 3:** Understand why segments 4-7 work perfectly 69 | -------------------------------------------------------------------------------- /speechcatcher/model/encoder/subsampling.py: -------------------------------------------------------------------------------- 1 | """Subsampling layers for encoder.""" 2 | 3 | import math 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class Conv2dSubsampling(nn.Module): 11 | """Convolutional 2D subsampling without positional encoding. 12 | 13 | This module reduces the time dimension by applying strided convolutions. 14 | Based on ESPnet's Conv2dSubsamplingWOPosEnc. 
15 | 16 | Args: 17 | input_dim: Input feature dimension 18 | output_dim: Output dimension (model dimension) 19 | dropout_rate: Dropout rate 20 | kernels: List of kernel sizes for each conv layer (default: [3, 3]) 21 | strides: List of strides for each conv layer (default: [2, 2]) 22 | 23 | Shape: 24 | - Input: (batch, time, input_dim) 25 | - Output: (batch, time', output_dim) where time' = time // prod(strides) 26 | 27 | Example: 28 | With kernels=[3, 3], strides=[2, 2], subsample factor is 4x 29 | """ 30 | 31 | def __init__( 32 | self, 33 | input_dim: int, 34 | output_dim: int, 35 | dropout_rate: float = 0.0, 36 | kernels: list = None, 37 | strides: list = None, 38 | ): 39 | super().__init__() 40 | 41 | if kernels is None: 42 | kernels = [3, 3] 43 | if strides is None: 44 | strides = [2, 2] 45 | 46 | assert len(kernels) == len(strides), "kernels and strides must have same length" 47 | 48 | self.kernels = kernels 49 | self.strides = strides 50 | 51 | # Build conv layers 52 | conv_layers = [] 53 | for i, (kernel, stride) in enumerate(zip(kernels, strides)): 54 | in_channels = 1 if i == 0 else output_dim 55 | conv_layers.extend([ 56 | nn.Conv2d(in_channels, output_dim, kernel, stride), 57 | nn.ReLU(), 58 | ]) 59 | 60 | self.conv = nn.Sequential(*conv_layers) 61 | 62 | # Calculate output feature dimension after convolutions 63 | # Each conv reduces: out = (in - kernel) / stride + 1 64 | out_len = input_dim 65 | for kernel, stride in zip(kernels, strides): 66 | out_len = math.floor((out_len - kernel) / stride + 1) 67 | 68 | # Linear projection to output_dim 69 | self.out = nn.Linear(output_dim * out_len, output_dim) 70 | 71 | def forward( 72 | self, 73 | x: torch.Tensor, 74 | x_mask: Optional[torch.Tensor] = None 75 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 76 | """Forward pass. 
77 | 78 | Args: 79 | x: Input tensor (batch, time, input_dim) 80 | x_mask: Optional mask tensor (batch, 1, time) 81 | 82 | Returns: 83 | Tuple of: 84 | - Subsampled output (batch, time', output_dim) 85 | - Subsampled mask (batch, 1, time') or None 86 | """ 87 | # Add channel dimension: (batch, time, input_dim) -> (batch, 1, time, input_dim) 88 | x = x.unsqueeze(1) 89 | 90 | # Apply convolutions: (batch, 1, time, input_dim) -> (batch, output_dim, time', feat') 91 | x = self.conv(x) 92 | 93 | # Reshape and project 94 | batch, channels, time, feat = x.size() 95 | # (batch, channels, time, feat) -> (batch, time, channels * feat) 96 | x = x.transpose(1, 2).contiguous().view(batch, time, channels * feat) 97 | # (batch, time, channels * feat) -> (batch, time, output_dim) 98 | x = self.out(x) 99 | 100 | # Subsample mask if provided 101 | if x_mask is not None: 102 | for kernel, stride in zip(self.kernels, self.strides): 103 | # Subsample mask: take every stride-th element, accounting for kernel 104 | x_mask = x_mask[:, :, :-kernel+1:stride] 105 | 106 | return x, x_mask 107 | -------------------------------------------------------------------------------- /tests/test_bbd_state.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Check BBD state to see if rollback is happening correctly.""" 3 | 4 | import numpy as np 5 | import wave 6 | import hashlib 7 | import os 8 | import torch 9 | 10 | print("="*80) 11 | print("BBD STATE TRACE") 12 | print("="*80) 13 | 14 | # Load audio 15 | from speechcatcher.speechcatcher import convert_inputfile 16 | 17 | os.makedirs('.tmp/', exist_ok=True) 18 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 19 | if not os.path.exists(wavfile_path): 20 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 21 | 22 | with wave.open(wavfile_path, 'rb') as f: 23 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 24 | speech = raw_audio.astype(np.float32) / 32768.0 25 | 26 | # Load our implementation 27 | print("\nLoading model...") 28 | from speechcatcher.speechcatcher import load_model, tags 29 | 30 | s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 31 | s2t.reset() 32 | print("✅ Model loaded") 33 | 34 | # Load BPE for decoding 35 | import sentencepiece as spm 36 | sp = spm.SentencePieceProcessor() 37 | sp.Load("/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/data/de_token_list/bpe_unigram1024/bpe.model") 38 | 39 | # Process chunks 40 | chunk_size = 8000 41 | print("\nProcessing chunks with BBD trace...") 42 | 43 | for chunk_idx in range(6): 44 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 45 | is_final = False 46 | 47 | print(f"\n{'='*80}") 48 | print(f"CHUNK {chunk_idx+1}") 49 | print(f"{'='*80}") 50 | 51 | # Check state before 52 | if hasattr(s2t, 'beam_search') and hasattr(s2t.beam_search, 'state'): 53 | bs = s2t.beam_search 54 | if bs.state and bs.state.hypotheses: 55 | print(f"\nBefore processing:") 56 | print(f" output_index: {bs.state.output_index}") 57 | print(f" processed_frames: {bs.state.processed_frames}") 58 | print(f" Top hypothesis:") 59 | hyp = bs.state.hypotheses[0] 60 | yseq = hyp.yseq.tolist() 61 | tokens = [sp.IdToPiece(t) for t in yseq] 62 | print(f" yseq: {yseq}") 63 | print(f" tokens: {tokens}") 64 | print(f" score: 
{hyp.score:.4f}") 65 | 66 | # Process 67 | results = s2t(chunk, is_final=is_final) 68 | 69 | # Check state after 70 | if hasattr(s2t, 'beam_search') and hasattr(s2t.beam_search, 'state'): 71 | bs = s2t.beam_search 72 | if bs.state and bs.state.hypotheses: 73 | print(f"\nAfter processing:") 74 | print(f" output_index: {bs.state.output_index}") 75 | print(f" processed_frames: {bs.state.processed_frames}") 76 | print(f" Top hypothesis:") 77 | hyp = bs.state.hypotheses[0] 78 | yseq = hyp.yseq.tolist() 79 | tokens = [sp.IdToPiece(t) for t in yseq] 80 | print(f" yseq: {yseq}") 81 | print(f" tokens: {tokens}") 82 | print(f" score: {hyp.score:.4f}") 83 | 84 | # Show results 85 | print(f"\nResults:") 86 | if results and results[0][0]: 87 | result = results[0] 88 | print(f" Text: \"{result[0]}\"") 89 | print(f" Tokens: {result[1]}") 90 | print(f" Token IDs: {result[2]}") 91 | 92 | # Extract just the NEW tokens output this chunk 93 | if chunk_idx > 0 and prev_output_index is not None: 94 | new_token_count = bs.state.output_index - prev_output_index 95 | print(f" New tokens this chunk: {new_token_count}") 96 | else: 97 | print(f" (no output)") 98 | 99 | # Save output_index for next iteration 100 | if hasattr(s2t, 'beam_search') and hasattr(s2t.beam_search, 'state'): 101 | prev_output_index = s2t.beam_search.state.output_index 102 | else: 103 | prev_output_index = None 104 | 105 | print("\n" + "="*80) 106 | print("TEST COMPLETE") 107 | print("="*80) 108 | -------------------------------------------------------------------------------- /tests/test_token_scoring_chunk5.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Compare scores at chunk 5 when ESPnet has actually decoded.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | print("="*80) 11 | print("TOKEN SCORING AT CHUNK 5") 12 | print("="*80) 13 | 14 | # Load audio 15 | from speechcatcher.speechcatcher import convert_inputfile 16 | 17 | os.makedirs('.tmp/', exist_ok=True) 18 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 19 | if not os.path.exists(wavfile_path): 20 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 21 | 22 | with wave.open(wavfile_path, 'rb') as f: 23 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 24 | speech = raw_audio.astype(np.float32) / 32768.0 25 | 26 | # Load ESPnet token list 27 | import json 28 | with open("/tmp/espnet_token_list.json", "r") as f: 29 | espnet_token_list = json.load(f) 30 | 31 | # Load ESPnet 32 | print("\n[1] Loading ESPnet...") 33 | from espnet_streaming_decoder.asr_inference_streaming import Speech2TextStreaming as ESPnetStreaming 34 | 35 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 36 | model_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth" 37 | 38 | espnet_s2t = ESPnetStreaming( 39 | asr_train_config=config_path, 40 | asr_model_file=model_path, 41 | device="cpu", 42 | beam_size=5, 43 | ctc_weight=0.3, 44 | ) 45 | espnet_s2t.reset() 46 | print("✅ ESPnet 
loaded") 47 | 48 | # Load ours 49 | print("\n[2] Loading our implementation...") 50 | from speechcatcher.speechcatcher import load_model, tags 51 | 52 | our_s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 53 | our_s2t.reset() 54 | print("✅ Ours loaded") 55 | 56 | # Process chunks 1-5 57 | print("\n[3] Processing chunks 1-5...") 58 | chunk_size = 8000 59 | 60 | for chunk_idx in range(5): 61 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 62 | 63 | # ESPnet 64 | espnet_results = espnet_s2t(chunk, is_final=False) 65 | if espnet_results: 66 | print(f" Chunk {chunk_idx+1} - ESPnet output: {espnet_results[0]}") 67 | 68 | # Ours 69 | our_results = our_s2t(chunk, is_final=False) 70 | if our_results: 71 | print(f" Chunk {chunk_idx+1} - Our output: {our_results[0]}") 72 | 73 | print("✅ Processed chunks 1-5") 74 | 75 | # Compare beam states 76 | print("\n" + "="*80) 77 | print("BEAM STATE AFTER CHUNK 5") 78 | print("="*80) 79 | 80 | # ESPnet beam state 81 | print("\n[ESPnet] Top 5 hypotheses:") 82 | if hasattr(espnet_s2t, 'beam_state') and espnet_s2t.beam_state is not None and espnet_s2t.beam_state.hypotheses: 83 | for i, hyp in enumerate(espnet_s2t.beam_state.hypotheses[:5]): 84 | token_ids = hyp.yseq.tolist() if torch.is_tensor(hyp.yseq) else hyp.yseq 85 | tokens = [espnet_token_list[tid] for tid in token_ids] 86 | print(f" [{i+1}] score={hyp.score:.6f}") 87 | print(f" yseq={token_ids}") 88 | print(f" tokens={' '.join(tokens)}") 89 | else: 90 | print(" (no hypotheses)") 91 | 92 | # Our beam state 93 | print("\n[Ours] Top 5 hypotheses:") 94 | if hasattr(our_s2t, 'beam_state') and our_s2t.beam_state is not None and our_s2t.beam_state.hypotheses: 95 | for i, hyp in enumerate(our_s2t.beam_state.hypotheses[:5]): 96 | token_ids = hyp.yseq.tolist() if torch.is_tensor(hyp.yseq) else hyp.yseq 97 | tokens = [espnet_token_list[tid] for tid in token_ids] 98 | print(f" [{i+1}] score={hyp.score:.6f}") 99 | print(f" yseq={token_ids}") 100 | print(f" tokens={' '.join(tokens)}") 101 | else: 102 | print(" (no hypotheses)") 103 | 104 | print("\n" + "="*80) 105 | -------------------------------------------------------------------------------- /tests/test_decoder_scores_debug.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Debug decoder scores to see why it predicts token 1023.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | print("="*80) 11 | print("DECODER SCORES DEBUG") 12 | print("="*80) 13 | 14 | # Load audio 15 | from speechcatcher.speechcatcher import convert_inputfile, load_model, tags 16 | 17 | os.makedirs('.tmp/', exist_ok=True) 18 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 19 | if not os.path.exists(wavfile_path): 20 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 21 | 22 | with wave.open(wavfile_path, 'rb') as wavfile_in: 23 | rate = wavfile_in.getframerate() 24 | buf = wavfile_in.readframes(-1) 25 | raw_audio = np.frombuffer(buf, dtype='int16') 26 | 27 | speech = raw_audio.astype(np.float32) / 32768.0 28 | 29 | # Load model 30 | print("\nLoading model...") 31 | our_s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 32 | our_s2t.reset() 33 | print("✅ Loaded\n") 34 | 35 | # Process chunks until we have encoder output 36 | chunk_size = 8000 37 | for chunk_idx in range(5): 38 | chunk = speech[chunk_idx*chunk_size : 
min((chunk_idx+1)*chunk_size, len(speech))] 39 | is_final = False 40 | 41 | with torch.no_grad(): 42 | results = our_s2t(chunk, is_final=is_final) 43 | 44 | if our_s2t.beam_search.encoder_buffer is not None and our_s2t.beam_search.encoder_buffer.size(1) >= 40: 45 | print(f"Chunk {chunk_idx+1}: Encoder buffer has {our_s2t.beam_search.encoder_buffer.size(1)} frames") 46 | 47 | # Get encoder output (first 40 frames) 48 | encoder_out = our_s2t.beam_search.encoder_buffer.narrow(1, 0, 40) 49 | encoder_out_lens = torch.tensor([40], dtype=torch.long) 50 | 51 | print(f"\nTesting decoder with {encoder_out.shape} encoder output\n") 52 | 53 | # Initialize with SOS 54 | from speechcatcher.beam_search.hypothesis import create_initial_hypothesis 55 | hyp = create_initial_hypothesis(sos_id=1, device="cpu") 56 | 57 | print("="*60) 58 | print("STEP 1: Score all tokens from SOS") 59 | print("="*60) 60 | 61 | # Use beam search's batch_score_hypotheses (handles CTC + decoder) 62 | from speechcatcher.beam_search.beam_search import BeamSearch 63 | 64 | temp_beam_search = BeamSearch( 65 | scorers=our_s2t.beam_search.scorers, 66 | weights=our_s2t.beam_search.weights, 67 | beam_size=5, 68 | vocab_size=1024, 69 | sos_id=1, 70 | eos_id=2, 71 | device="cpu", 72 | ) 73 | 74 | with torch.no_grad(): 75 | combined_scores, _ = temp_beam_search.batch_score_hypotheses([hyp], encoder_out) 76 | 77 | # Show top 20 tokens 78 | top_scores, top_tokens = torch.topk(combined_scores[0], 20) 79 | 80 | dec_weight = our_s2t.beam_search.weights["decoder"] 81 | ctc_weight = our_s2t.beam_search.weights["ctc"] 82 | print(f"\nTop 20 tokens (combined = {dec_weight:.1f}*decoder + {ctc_weight:.1f}*ctc):") 83 | print(f"{'Rank':<6} {'Token':<8} {'Score':<12}") 84 | print("-" * 30) 85 | for i, (score, token) in enumerate(zip(top_scores.tolist(), top_tokens.tolist())): 86 | print(f"{i+1:<6} {token:<8} {score:>11.4f}") 87 | 88 | # Check specific tokens 89 | print(f"\nSpecific tokens:") 90 | print(f" Token 1023 (م): {combined_scores[0, 1023].item():.4f}") 91 | 92 | # Load BPE to decode some tokens 93 | try: 94 | import sentencepiece as spm 95 | sp = spm.SentencePieceProcessor() 96 | sp.Load("/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/data/de_token_list/bpe_unigram1024/bpe.model") 97 | 98 | print(f"\nTop tokens decoded:") 99 | for i, token in enumerate(top_tokens[:10].tolist()): 100 | piece = sp.IdToPiece(token) 101 | print(f" {i+1}. 
Token {token}: '{piece}'") 102 | except: 103 | pass 104 | 105 | break 106 | 107 | print("\n" + "="*80) 108 | print("DEBUG COMPLETE") 109 | print("="*80) 110 | -------------------------------------------------------------------------------- /tests/test_waveform_buffering.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test waveform buffering - does it wait for enough samples?""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | print("="*80) 11 | print("WAVEFORM BUFFERING TEST") 12 | print("="*80) 13 | 14 | # Load audio 15 | from speechcatcher.speechcatcher import convert_inputfile, load_model, tags 16 | 17 | os.makedirs('.tmp/', exist_ok=True) 18 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 19 | if not os.path.exists(wavfile_path): 20 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 21 | 22 | with wave.open(wavfile_path, 'rb') as wavfile_in: 23 | rate = wavfile_in.getframerate() 24 | buf = wavfile_in.readframes(-1) 25 | raw_audio = np.frombuffer(buf, dtype='int16') 26 | 27 | speech = raw_audio.astype(np.float32) / 32768.0 28 | print(f"Audio: {len(speech)} samples ({len(speech)/rate:.2f}s)") 29 | 30 | # Load model 31 | print("\nLoading model...") 32 | our_s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 33 | our_s2t.reset() 34 | 35 | print(f"Frontend params: win_length={our_s2t.win_length}, hop_length={our_s2t.hop_length}") 36 | print(f"Minimum samples needed: {our_s2t.win_length + 1}") 37 | 38 | # Test with VERY small chunks (smaller than win_length) 39 | print("\n" + "="*80) 40 | print("TEST 1: Chunks smaller than win_length") 41 | print("="*80) 42 | 43 | our_s2t.reset() 44 | chunk_size = 100 # Much smaller than win_length=400 45 | 46 | print(f"\nChunk size: {chunk_size} samples (< win_length={our_s2t.win_length})") 47 | print("Expected: Chunks 1-3 should buffer, chunk 4 should process\n") 48 | 49 | for chunk_idx in range(6): 50 | start = chunk_idx * chunk_size 51 | end = min((chunk_idx + 1) * chunk_size, len(speech)) 52 | chunk = speech[start:end] 53 | 54 | is_final = False 55 | 56 | print(f"Chunk {chunk_idx+1}: {len(chunk)} samples") 57 | 58 | # Check frontend state BEFORE 59 | buffered_before = 0 60 | if our_s2t.frontend_states and "waveform_buffer" in our_s2t.frontend_states: 61 | buffered_before = our_s2t.frontend_states["waveform_buffer"].size(0) 62 | 63 | with torch.no_grad(): 64 | results = our_s2t(chunk, is_final=is_final) 65 | 66 | # Check frontend state AFTER 67 | buffered_after = 0 68 | if our_s2t.frontend_states and "waveform_buffer" in our_s2t.frontend_states: 69 | buffered_after = our_s2t.frontend_states["waveform_buffer"].size(0) 70 | 71 | # Check encoder buffer 72 | enc_buf_size = 0 73 | if our_s2t.beam_search.encoder_buffer is not None: 74 | enc_buf_size = our_s2t.beam_search.encoder_buffer.size(1) 75 | 76 | print(f" Waveform buffer: {buffered_before} → {buffered_after} samples") 77 | print(f" Encoder buffer: {enc_buf_size} frames") 78 | print(f" Results: {'(empty)' if not results or len(results) == 0 or not results[0][0] else results[0][0]}") 79 | print() 80 | 81 | # Test with normal chunks 82 | print("="*80) 83 | print("TEST 2: Normal chunks (larger than win_length)") 84 | print("="*80) 85 | 86 | our_s2t.reset() 87 | chunk_size = 8000 88 | 89 | print(f"\nChunk size: {chunk_size} samples (>> win_length={our_s2t.win_length})") 90 | print("Expected: 
Every chunk should process immediately\n") 91 | 92 | for chunk_idx in range(3): 93 | start = chunk_idx * chunk_size 94 | end = min((chunk_idx + 1) * chunk_size, len(speech)) 95 | chunk = speech[start:end] 96 | 97 | is_final = False 98 | 99 | print(f"Chunk {chunk_idx+1}: {len(chunk)} samples") 100 | 101 | buffered_before = 0 102 | if our_s2t.frontend_states and "waveform_buffer" in our_s2t.frontend_states: 103 | buffered_before = our_s2t.frontend_states["waveform_buffer"].size(0) 104 | 105 | with torch.no_grad(): 106 | results = our_s2t(chunk, is_final=is_final) 107 | 108 | buffered_after = 0 109 | if our_s2t.frontend_states and "waveform_buffer" in our_s2t.frontend_states: 110 | buffered_after = our_s2t.frontend_states["waveform_buffer"].size(0) 111 | 112 | enc_buf_size = 0 113 | if our_s2t.beam_search.encoder_buffer is not None: 114 | enc_buf_size = our_s2t.beam_search.encoder_buffer.size(1) 115 | 116 | print(f" Waveform buffer: {buffered_before} → {buffered_after} samples") 117 | print(f" Encoder buffer: {enc_buf_size} frames") 118 | print() 119 | 120 | print("="*80) 121 | print("TEST COMPLETE") 122 | print("="*80) 123 | print("\nConclusion:") 124 | print("- Small chunks should accumulate in waveform buffer until > win_length") 125 | print("- Large chunks should process immediately with STFT overlap buffering") 126 | -------------------------------------------------------------------------------- /tests/test_step_by_step_trace.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Step-by-step trace to see where beam searches diverge.""" 3 | 4 | import numpy as np 5 | import wave 6 | import hashlib 7 | import os 8 | import torch 9 | 10 | print("="*80) 11 | print("STEP-BY-STEP BEAM SEARCH TRACE") 12 | print("="*80) 13 | 14 | # Load audio 15 | from speechcatcher.speechcatcher import convert_inputfile 16 | 17 | os.makedirs('.tmp/', exist_ok=True) 18 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 19 | if not os.path.exists(wavfile_path): 20 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 21 | 22 | with wave.open(wavfile_path, 'rb') as f: 23 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 24 | speech = raw_audio.astype(np.float32) / 32768.0 25 | 26 | # Load BPE for decoding 27 | import sentencepiece as spm 28 | sp = spm.SentencePieceProcessor() 29 | sp.Load("/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/data/de_token_list/bpe_unigram1024/bpe.model") 30 | 31 | # Load ESPnet 32 | print("\n[1] Loading ESPnet...") 33 | from espnet_streaming_decoder.asr_inference_streaming import Speech2TextStreaming as ESPnetStreaming 34 | 35 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 36 | model_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth" 37 | 38 | espnet_s2t = ESPnetStreaming( 39 | asr_train_config=config_path, 40 | asr_model_file=model_path, 41 | device="cpu", 42 | beam_size=5, 
43 | ctc_weight=0.3, 44 | ) 45 | espnet_s2t.reset() 46 | print("✅ ESPnet loaded") 47 | 48 | # Load ours 49 | print("\n[2] Loading our implementation...") 50 | from speechcatcher.speechcatcher import load_model, tags 51 | 52 | our_s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 53 | our_s2t.reset() 54 | print("✅ Ours loaded") 55 | 56 | # Process chunks 1-5 57 | print("\n[3] Processing chunks 1-5...") 58 | chunk_size = 8000 59 | 60 | for chunk_idx in range(5): 61 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 62 | 63 | espnet_results = espnet_s2t(chunk, is_final=False) 64 | our_results = our_s2t(chunk, is_final=False) 65 | 66 | print("\n" + "="*80) 67 | print("CHUNK 5 - HYPOTHESIS COMPARISON") 68 | print("="*80) 69 | 70 | # ESPnet hypotheses 71 | print("\n[ESPnet]") 72 | if hasattr(espnet_s2t.beam_search, 'running_hyps') and espnet_s2t.beam_search.running_hyps: 73 | print(f"Number of hypotheses: {len(espnet_s2t.beam_search.running_hyps)}") 74 | for i, hyp in enumerate(espnet_s2t.beam_search.running_hyps[:3]): 75 | if hasattr(hyp, 'yseq'): 76 | yseq = hyp.yseq.tolist() if torch.is_tensor(hyp.yseq) else hyp.yseq 77 | tokens = [sp.IdToPiece(t) for t in yseq] 78 | score = hyp.score.item() if torch.is_tensor(hyp.score) else hyp.score 79 | print(f" Hyp {i+1}:") 80 | print(f" yseq: {yseq}") 81 | print(f" tokens: {tokens}") 82 | print(f" score: {score:.4f}") 83 | 84 | # Our hypotheses 85 | print("\n[Ours]") 86 | if hasattr(our_s2t, 'beam_state') and our_s2t.beam_state and our_s2t.beam_state.hypotheses: 87 | print(f"Number of hypotheses: {len(our_s2t.beam_state.hypotheses)}") 88 | print(f"output_index: {our_s2t.beam_state.output_index}") 89 | for i, hyp in enumerate(our_s2t.beam_state.hypotheses[:3]): 90 | yseq = hyp.yseq.tolist() 91 | tokens = [sp.IdToPiece(t) for t in yseq] 92 | print(f" Hyp {i+1}:") 93 | print(f" yseq: {yseq}") 94 | print(f" tokens: {tokens}") 95 | print(f" score: {hyp.score:.4f}") 96 | 97 | # Compare outputs 98 | print("\n" + "="*80) 99 | print("OUTPUT COMPARISON") 100 | print("="*80) 101 | 102 | espnet_text = espnet_results[0][0] if espnet_results and espnet_results[0][0] else "(none)" 103 | our_text = our_results[0][0] if our_results and our_results[0][0] else "(none)" 104 | 105 | print(f"\nESPnet: \"{espnet_text}\"") 106 | print(f"Ours: \"{our_text}\"") 107 | 108 | if espnet_text == our_text: 109 | print("\n✅ OUTPUTS MATCH!") 110 | else: 111 | print(f"\n❌ OUTPUTS DIFFER!") 112 | 113 | print("\n" + "="*80) 114 | print("TEST COMPLETE") 115 | print("="*80) 116 | -------------------------------------------------------------------------------- /tests/test_model_mode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Check if models are in eval mode and test dropout behavior.""" 3 | 4 | import logging 5 | import torch 6 | 7 | logging.basicConfig(level=logging.INFO, format='%(message)s') 8 | logger = logging.getLogger(__name__) 9 | 10 | # ============================================================================ 11 | # Load both models 12 | # ============================================================================ 13 | logger.info("="*80) 14 | logger.info("MODEL MODE CHECK") 15 | logger.info("="*80) 16 | 17 | from speechcatcher.speechcatcher import load_model 18 | from espnet2.bin.asr_inference_streaming import Speech2TextStreaming as ESPnetS2T 19 | 20 | logger.info("\n[1] Loading models...") 21 | 22 | our_model = load_model( 23 | 
tag="speechcatcher/speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024", 24 | device="cpu", 25 | beam_size=5, 26 | quiet=True 27 | ) 28 | 29 | espnet_model = ESPnetS2T( 30 | asr_train_config="/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml", 31 | asr_model_file="/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth", 32 | device="cpu", 33 | ) 34 | 35 | logger.info("✅ Models loaded") 36 | 37 | # ============================================================================ 38 | # Check training mode 39 | # ============================================================================ 40 | logger.info("\n[2] Checking training mode...") 41 | 42 | logger.info(f"\nOur model training: {our_model.model.training}") 43 | logger.info(f"ESPnet model training: {espnet_model.asr_model.training}") 44 | 45 | # Check decoder specifically 46 | logger.info(f"\nOur decoder training: {our_model.model.decoder.training}") 47 | logger.info(f"ESPnet decoder training: {espnet_model.asr_model.decoder.training}") 48 | 49 | # Check for dropout layers 50 | logger.info("\n[3] Checking dropout layers...") 51 | 52 | def check_dropout_layers(model, name): 53 | dropout_layers = [] 54 | for module_name, module in model.named_modules(): 55 | if isinstance(module, torch.nn.Dropout): 56 | dropout_layers.append((module_name, module.p, module.training)) 57 | 58 | logger.info(f"\n{name} dropout layers:") 59 | if dropout_layers: 60 | for name, p, training in dropout_layers[:10]: # Show first 10 61 | logger.info(f" {name}: p={p:.2f}, training={training}") 62 | if len(dropout_layers) > 10: 63 | logger.info(f" ... 
and {len(dropout_layers) - 10} more") 64 | else: 65 | logger.info(" No dropout layers found") 66 | 67 | return dropout_layers 68 | 69 | our_dropouts = check_dropout_layers(our_model.model, "Our model") 70 | espnet_dropouts = check_dropout_layers(espnet_model.asr_model, "ESPnet model") 71 | 72 | # ============================================================================ 73 | # Set both to eval mode explicitly 74 | # ============================================================================ 75 | logger.info("\n[4] Setting both to eval mode...") 76 | 77 | our_model.model.eval() 78 | espnet_model.asr_model.eval() 79 | 80 | logger.info(f"Our model training: {our_model.model.training}") 81 | logger.info(f"ESPnet model training: {espnet_model.asr_model.training}") 82 | 83 | # ============================================================================ 84 | # Test with same random input to verify determinism 85 | # ============================================================================ 86 | logger.info("\n[5] Testing determinism with same random input...") 87 | 88 | torch.manual_seed(42) 89 | input1 = torch.randn(1, 124, 256) 90 | 91 | torch.manual_seed(42) 92 | input2 = torch.randn(1, 124, 256) 93 | 94 | logger.info(f"Inputs identical: {torch.allclose(input1, input2)}") 95 | 96 | # Test our decoder 97 | sos_id = 1 98 | yseq = torch.tensor([sos_id]) 99 | 100 | with torch.no_grad(): 101 | our_out1, _ = our_model.model.decoder.score(yseq, state=None, x=input1[0]) 102 | our_out2, _ = our_model.model.decoder.score(yseq, state=None, x=input2[0]) 103 | 104 | logger.info(f"\nOur decoder outputs identical: {torch.allclose(our_out1, our_out2)}") 105 | logger.info(f" Max diff: {(our_out1 - our_out2).abs().max():.6f}") 106 | 107 | # Test ESPnet decoder 108 | with torch.no_grad(): 109 | espnet_out1, _ = espnet_model.asr_model.decoder.score(yseq, state=None, x=input1[0]) 110 | espnet_out2, _ = espnet_model.asr_model.decoder.score(yseq, state=None, x=input2[0]) 111 | 112 | logger.info(f"\nESPnet decoder outputs identical: {torch.allclose(espnet_out1, espnet_out2)}") 113 | logger.info(f" Max diff: {(espnet_out1 - espnet_out2).abs().max():.6f}") 114 | 115 | logger.info("\n" + "="*80) 116 | logger.info("MODEL MODE CHECK COMPLETE") 117 | logger.info("="*80) 118 | -------------------------------------------------------------------------------- /speechcatcher/model/decoder/decoder_layer.py: -------------------------------------------------------------------------------- 1 | """Transformer decoder layer.""" 2 | 3 | from typing import Optional, Tuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from speechcatcher.model.layers import LayerNorm 9 | 10 | 11 | class TransformerDecoderLayer(nn.Module): 12 | """Single Transformer decoder layer. 13 | 14 | This layer implements a standard Transformer decoder block with: 15 | 1. Masked self-attention on target sequence 16 | 2. Cross-attention to encoder output 17 | 3. 
Position-wise feed-forward network 18 | 19 | Args: 20 | size: Model dimension 21 | self_attn: Self-attention module (MultiHeadedAttention) 22 | src_attn: Source attention module (MultiHeadedAttention) 23 | feed_forward: Feed-forward module (PositionwiseFeedForward) 24 | dropout_rate: Dropout rate 25 | normalize_before: Whether to apply layer norm before attention/FFN 26 | concat_after: Whether to concat attention input and output (adds linear layer) 27 | 28 | Shape: 29 | - Target input: (batch, target_len, size) 30 | - Memory input: (batch, source_len, size) 31 | - Output: (batch, target_len, size) 32 | """ 33 | 34 | def __init__( 35 | self, 36 | size: int, 37 | self_attn: nn.Module, 38 | src_attn: nn.Module, 39 | feed_forward: nn.Module, 40 | dropout_rate: float, 41 | normalize_before: bool = True, 42 | concat_after: bool = False, 43 | ): 44 | super().__init__() 45 | self.size = size 46 | self.self_attn = self_attn 47 | self.src_attn = src_attn 48 | self.feed_forward = feed_forward 49 | self.norm1 = LayerNorm(size) 50 | self.norm2 = LayerNorm(size) 51 | self.norm3 = LayerNorm(size) 52 | self.dropout = nn.Dropout(dropout_rate) 53 | self.normalize_before = normalize_before 54 | self.concat_after = concat_after 55 | 56 | if self.concat_after: 57 | self.concat_linear1 = nn.Linear(size + size, size) 58 | self.concat_linear2 = nn.Linear(size + size, size) 59 | 60 | def forward( 61 | self, 62 | tgt: torch.Tensor, 63 | tgt_mask: Optional[torch.Tensor], 64 | memory: torch.Tensor, 65 | memory_mask: Optional[torch.Tensor], 66 | cache: Optional[torch.Tensor] = None, 67 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: 68 | """Forward pass. 69 | 70 | Args: 71 | tgt: Target sequence (batch, target_len, size) 72 | tgt_mask: Target mask (batch, target_len, target_len) - causal mask 73 | memory: Encoder output (batch, source_len, size) 74 | memory_mask: Memory mask (batch, 1, source_len) or (batch, target_len, source_len) 75 | cache: Cached previous output for incremental decoding (batch, target_len-1, size) 76 | 77 | Returns: 78 | Tuple of (output, tgt_mask, memory, memory_mask) 79 | """ 80 | # Self-attention block 81 | residual = tgt 82 | if self.normalize_before: 83 | tgt = self.norm1(tgt) 84 | 85 | if cache is None: 86 | # Full sequence processing 87 | tgt_q = tgt 88 | tgt_q_mask = tgt_mask 89 | else: 90 | # Incremental decoding: compute only the last frame 91 | assert cache.shape == (tgt.shape[0], tgt.shape[1] - 1, self.size), \ 92 | f"Cache shape {cache.shape} != expected {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" 93 | tgt_q = tgt[:, -1:, :] 94 | residual = residual[:, -1:, :] 95 | tgt_q_mask = None if tgt_mask is None else tgt_mask[:, -1:, :] 96 | 97 | if self.concat_after: 98 | tgt_concat = torch.cat((tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1) 99 | x = residual + self.concat_linear1(tgt_concat) 100 | else: 101 | x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) 102 | 103 | if not self.normalize_before: 104 | x = self.norm1(x) 105 | 106 | # Cross-attention block 107 | residual = x 108 | if self.normalize_before: 109 | x = self.norm2(x) 110 | 111 | if self.concat_after: 112 | x_concat = torch.cat((x, self.src_attn(x, memory, memory, memory_mask)), dim=-1) 113 | x = residual + self.concat_linear2(x_concat) 114 | else: 115 | x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask)) 116 | 117 | if not self.normalize_before: 118 | x = self.norm2(x) 119 | 120 | # Feed-forward block 121 | residual = x 122 
| if self.normalize_before: 123 | x = self.norm3(x) 124 | x = residual + self.dropout(self.feed_forward(x)) 125 | if not self.normalize_before: 126 | x = self.norm3(x) 127 | 128 | # Concatenate cache for incremental decoding 129 | if cache is not None: 130 | x = torch.cat([cache, x], dim=1) 131 | 132 | return x, tgt_mask, memory, memory_mask 133 | -------------------------------------------------------------------------------- /tests/test_combined_scores.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test combined decoder + CTC scores to see which token wins.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | 10 | print("="*80) 11 | print("COMBINED SCORES TEST (Decoder + CTC)") 12 | print("="*80) 13 | 14 | # Load audio 15 | from speechcatcher.speechcatcher import convert_inputfile 16 | 17 | os.makedirs('.tmp/', exist_ok=True) 18 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 19 | if not os.path.exists(wavfile_path): 20 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 21 | 22 | with wave.open(wavfile_path, 'rb') as f: 23 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 24 | speech = raw_audio.astype(np.float32) / 32768.0 25 | 26 | # Load BPE 27 | import sentencepiece as spm 28 | sp = spm.SentencePieceProcessor() 29 | sp.Load("/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/data/de_token_list/bpe_unigram1024/bpe.model") 30 | 31 | # Load our implementation 32 | print("\n[1] Loading model...") 33 | from speechcatcher.speechcatcher import load_model, tags 34 | 35 | s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 36 | s2t.reset() 37 | print("✅ Model loaded") 38 | 39 | # Process chunks 1-5 to build encoder output 40 | print("\n[2] Building encoder output from chunks 1-5...") 41 | chunk_size = 8000 42 | 43 | buffer = [] 44 | encoder_states = None 45 | 46 | for chunk_idx in range(5): 47 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 48 | is_final = False 49 | speech_tensor = torch.from_numpy(chunk) 50 | 51 | feats, feats_len, frontend_states = s2t.apply_frontend( 52 | speech=speech_tensor, prev_states=s2t.frontend_states if chunk_idx == 0 else frontend_states, is_final=is_final 53 | ) 54 | if feats is not None: 55 | with torch.no_grad(): 56 | enc, enc_len, encoder_states = s2t.model.encoder( 57 | feats, feats_len, 58 | prev_states=encoder_states, is_final=is_final, infer_mode=True, 59 | ) 60 | if enc.size(1) > 0: 61 | buffer.append(enc) 62 | 63 | # Extract 40-frame block 64 | encoder_out = torch.cat(buffer, dim=1).narrow(1, 0, 40) 65 | print(f"Encoder output shape: {encoder_out.shape}") 66 | 67 | # Get decoder and CTC scores 68 | print("\n[3] Computing decoder and CTC scores...") 69 | 70 | # Decoder scores (starting from SOS=1023) 71 | ys = torch.tensor([[1023]], dtype=torch.long) 72 | with torch.no_grad(): 73 | decoder_out, _ = s2t.model.decoder.forward_one_step(ys, None, encoder_out, cache=None) 74 | decoder_log_probs = torch.log_softmax(decoder_out, dim=-1)[0] # (vocab_size,) 75 | 76 | # CTC scores (for initial hypothesis with just SOS) 77 | with torch.no_grad(): 78 | ctc_logits = s2t.model.ctc.ctc_lo(encoder_out) # (1, 40, vocab_size) 79 | ctc_log_probs = torch.log_softmax(ctc_logits, dim=-1)[0] # 
(40, vocab_size) 80 | 81 | # For CTC prefix scoring, we need to compute the score for each token 82 | # This is complex, but for a quick test, let's just look at the sum of CTC log probs 83 | # across all frames for each token 84 | ctc_sum_scores = ctc_log_probs.sum(dim=0) # (vocab_size,) 85 | 86 | print("\n" + "="*80) 87 | print("SCORE BREAKDOWN FOR KEY TOKENS") 88 | print("="*80) 89 | 90 | # Compare key tokens: 372 (▁dieses), 738 (trag), 1023 (م) 91 | key_tokens = [ 92 | (372, '▁dieses'), 93 | (738, 'trag'), 94 | (1023, 'م'), 95 | ] 96 | 97 | decoder_weight = 0.7 98 | ctc_weight = 0.3 99 | 100 | print(f"\nWeights: decoder={decoder_weight}, ctc={ctc_weight}") 101 | print(f"\nToken breakdown:\n") 102 | 103 | for token_id, token_text in key_tokens: 104 | dec_score = decoder_log_probs[token_id].item() 105 | ctc_score = ctc_sum_scores[token_id].item() 106 | combined = decoder_weight * dec_score + ctc_weight * ctc_score 107 | 108 | print(f"Token {token_id:4d} '{token_text}':") 109 | print(f" Decoder score: {dec_score:>10.4f} (weighted: {decoder_weight * dec_score:>10.4f})") 110 | print(f" CTC sum score: {ctc_score:>10.4f} (weighted: {ctc_weight * ctc_score:>10.4f})") 111 | print(f" Combined: {combined:>10.4f}") 112 | print() 113 | 114 | # Find top 10 by combined score 115 | combined_scores = decoder_weight * decoder_log_probs + ctc_weight * ctc_sum_scores 116 | top_scores, top_tokens = combined_scores.topk(10) 117 | 118 | print("="*80) 119 | print("TOP 10 BY COMBINED SCORE") 120 | print("="*80) 121 | print() 122 | 123 | for i, (score, token) in enumerate(zip(top_scores.tolist(), top_tokens.tolist())): 124 | piece = sp.IdToPiece(token) 125 | dec = decoder_log_probs[token].item() 126 | ctc = ctc_sum_scores[token].item() 127 | print(f"{i+1:2d}. Token {token:4d} '{piece:12s}': {score:>8.4f} (dec: {dec:>7.3f}, ctc: {ctc:>8.3f})") 128 | 129 | print("\n" + "="*80) 130 | print("NOTE: CTC scores here are just summed log probs, not true CTC prefix scores") 131 | print("Real CTC prefix scoring uses forward algorithm with blank handling") 132 | print("="*80) 133 | -------------------------------------------------------------------------------- /tests/test_normalization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test normalization step.""" 3 | 4 | import logging 5 | import sys 6 | import numpy as np 7 | import torch 8 | 9 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(message)s') 10 | logger = logging.getLogger(__name__) 11 | 12 | # Load audio 13 | from speechcatcher.speechcatcher import convert_inputfile 14 | import wave 15 | import hashlib 16 | import os 17 | 18 | logger.info("Loading audio...") 19 | os.makedirs('.tmp/', exist_ok=True) 20 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 21 | if not os.path.exists(wavfile_path): 22 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 23 | 24 | with wave.open(wavfile_path, 'rb') as wavfile_in: 25 | buf = wavfile_in.readframes(-1) 26 | raw_audio = np.frombuffer(buf, dtype='int16') 27 | 28 | waveform = torch.from_numpy(raw_audio).float() 29 | 30 | # Extract features 31 | from speechcatcher.model.frontend.stft_frontend import STFTFrontend 32 | 33 | frontend = STFTFrontend(n_fft=512, hop_length=160, win_length=400, n_mels=80) 34 | with torch.no_grad(): 35 | features, lengths = frontend(waveform) 36 | 37 | logger.info(f"Features (before norm): shape={features.shape}") 38 | 
logger.info(f"Features stats: min={features.min():.4f}, max={features.max():.4f}, mean={features.mean():.4f}, std={features.std():.4f}") 39 | 40 | # Load model to get normalization stats 41 | from speechcatcher.speechcatcher import load_model 42 | 43 | speech2text = load_model( 44 | tag="speechcatcher/speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024", 45 | device="cpu", 46 | beam_size=5, 47 | quiet=True 48 | ) 49 | 50 | # Check if model has normalization 51 | if hasattr(speech2text.model, 'normalize'): 52 | logger.info(f"\n✅ Model has normalize module: {speech2text.model.normalize}") 53 | logger.info(f"Normalize type: {type(speech2text.model.normalize)}") 54 | 55 | # Check normalization stats 56 | if hasattr(speech2text.model.normalize, 'mean'): 57 | logger.info(f"Norm mean shape: {speech2text.model.normalize.mean.shape}") 58 | logger.info(f"Norm mean stats: min={speech2text.model.normalize.mean.min():.4f}, max={speech2text.model.normalize.mean.max():.4f}") 59 | if hasattr(speech2text.model.normalize, 'std'): 60 | logger.info(f"Norm std shape: {speech2text.model.normalize.std.shape}") 61 | logger.info(f"Norm std stats: min={speech2text.model.normalize.std.min():.4f}, max={speech2text.model.normalize.std.max():.4f}") 62 | 63 | # Apply normalization manually 64 | logger.info("\n" + "="*60) 65 | logger.info("Applying normalization...") 66 | logger.info("="*60) 67 | 68 | with torch.no_grad(): 69 | normalized_features = speech2text.model.normalize(features) 70 | 71 | logger.info(f"Normalized features: shape={normalized_features.shape}") 72 | logger.info(f"Normalized stats: min={normalized_features.min():.4f}, max={normalized_features.max():.4f}, mean={normalized_features.mean():.4f}, std={normalized_features.std():.4f}") 73 | 74 | else: 75 | logger.warning("❌ Model does NOT have normalize module!") 76 | 77 | # Check if normalization is in the model's forward pass 78 | logger.info("\n" + "="*60) 79 | logger.info("Testing full model forward pass...") 80 | logger.info("="*60) 81 | 82 | # Run through the model's frontend + normalize 83 | with torch.no_grad(): 84 | # Frontend 85 | model_features, model_lengths = speech2text.model.frontend(waveform.unsqueeze(0)) 86 | logger.info(f"Model frontend output: shape={model_features.shape}") 87 | logger.info(f"Model frontend stats: min={model_features.min():.4f}, max={model_features.max():.4f}, mean={model_features.mean():.4f}") 88 | 89 | # Normalize (if exists) 90 | if hasattr(speech2text.model, 'normalize'): 91 | model_features = speech2text.model.normalize(model_features) 92 | logger.info(f"After normalize: shape={model_features.shape}") 93 | logger.info(f"After normalize stats: min={model_features.min():.4f}, max={model_features.max():.4f}, mean={model_features.mean():.4f}, std={model_features.std():.4f}") 94 | 95 | # Encoder 96 | logger.info("\nRunning encoder...") 97 | encoder_out, encoder_out_lens, encoder_states = speech2text.model.encoder( 98 | model_features, 99 | model_lengths, 100 | prev_states=None, 101 | is_final=True, 102 | infer_mode=True 103 | ) 104 | 105 | logger.info(f"Encoder output: shape={encoder_out.shape}") 106 | logger.info(f"Encoder stats: min={encoder_out.min():.4f}, max={encoder_out.max():.4f}, mean={encoder_out.mean():.4f}, std={encoder_out.std():.4f}") 107 | 108 | # Check for NaN/Inf 109 | if torch.isnan(encoder_out).any(): 110 | logger.error("❌ Encoder output contains NaN!") 111 | if torch.isinf(encoder_out).any(): 112 | logger.error("❌ Encoder output contains Inf!") 113 | if not 
torch.isnan(encoder_out).any() and not torch.isinf(encoder_out).any(): 114 | logger.info("✅ Encoder output is clean (no NaN/Inf)") 115 | 116 | logger.info("\n" + "="*60) 117 | logger.info("Normalization test complete!") 118 | logger.info("="*60) 119 | -------------------------------------------------------------------------------- /tests/test_multi_chunk_comparison.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Multi-chunk comparison: ESPnet streaming vs our implementation.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | import logging 10 | 11 | logging.basicConfig(level=logging.WARNING, format='%(message)s') 12 | 13 | print("="*80) 14 | print("MULTI-CHUNK STREAMING COMPARISON") 15 | print("="*80) 16 | 17 | # Load audio 18 | print("\n[1] Loading audio...") 19 | from speechcatcher.speechcatcher import convert_inputfile 20 | 21 | os.makedirs('.tmp/', exist_ok=True) 22 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 23 | if not os.path.exists(wavfile_path): 24 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 25 | 26 | with wave.open(wavfile_path, 'rb') as wavfile_in: 27 | rate = wavfile_in.getframerate() 28 | buf = wavfile_in.readframes(-1) 29 | raw_audio = np.frombuffer(buf, dtype='int16') 30 | 31 | speech = raw_audio.astype(np.float32) / 32768.0 32 | print(f"Audio: {len(speech)} samples ({len(speech)/rate:.2f}s)") 33 | 34 | # Load ESPnet 35 | print("\n[2] Loading ESPnet streaming decoder...") 36 | from espnet_streaming_decoder.asr_inference_streaming import Speech2TextStreaming as ESPnetStreaming 37 | 38 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 39 | model_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth" 40 | 41 | espnet_s2t = ESPnetStreaming( 42 | asr_train_config=config_path, 43 | asr_model_file=model_path, 44 | device="cpu", 45 | beam_size=5, 46 | ctc_weight=0.3, 47 | disable_repetition_detection=False, 48 | ) 49 | espnet_s2t.reset() 50 | print(f"✅ Loaded (block={espnet_s2t.beam_search.block_size}, hop={espnet_s2t.beam_search.hop_size}, lookahead={espnet_s2t.beam_search.look_ahead})") 51 | 52 | # Load ours 53 | print("\n[3] Loading our implementation...") 54 | from speechcatcher.speechcatcher import load_model, tags 55 | 56 | our_s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 57 | our_s2t.reset() 58 | print(f"✅ Loaded") 59 | 60 | # Process in chunks 61 | print("\n" + "="*80) 62 | print("PROCESSING CHUNKS") 63 | print("="*80) 64 | 65 | chunk_size = 8000 # Match typical streaming chunk size 66 | num_chunks = (len(speech) + chunk_size - 1) // chunk_size 67 | 68 | print(f"\nChunk size: {chunk_size} samples ({chunk_size/rate:.2f}s)") 69 | print(f"Total chunks: {num_chunks}\n") 70 | 71 | for chunk_idx in range(num_chunks): 72 | start = chunk_idx * chunk_size 73 | end = min((chunk_idx + 1) * chunk_size, len(speech)) 74 | chunk = speech[start:end] 75 | is_final = (chunk_idx == num_chunks - 1) 76 | 77 | print(f"--- Chunk 
{chunk_idx+1}/{num_chunks} (is_final={is_final}) ---") 78 | 79 | # ESPnet 80 | with torch.no_grad(): 81 | espnet_results = espnet_s2t(chunk, is_final=is_final) 82 | 83 | espnet_text = "" 84 | if espnet_results: 85 | espnet_text, _, _, _, _ = espnet_results[0] 86 | 87 | # Check encoder output size 88 | if hasattr(espnet_s2t, 'encoder_states') and espnet_s2t.encoder_states: 89 | enc_info = f"enc_states=present" 90 | else: 91 | enc_info = f"enc_states=None" 92 | 93 | # Check beam search state 94 | bs = espnet_s2t.beam_search 95 | bs_info = f"block={bs.processed_block}, idx={bs.process_idx}" 96 | if bs.encbuffer is not None: 97 | bs_info += f", encbuf={bs.encbuffer.shape[0]}" 98 | 99 | print(f" ESPnet: '{espnet_text[:50]}' ({bs_info}, {enc_info})") 100 | 101 | # Ours 102 | with torch.no_grad(): 103 | our_results = our_s2t(chunk, is_final=is_final) 104 | 105 | our_text = "" 106 | if our_results and len(our_results) > 0: 107 | our_text, _, _ = our_results[0] 108 | 109 | # Check beam state 110 | if our_s2t.beam_state: 111 | bs_info = f"output_idx={our_s2t.beam_state.output_index}, hyps={len(our_s2t.beam_state.hypotheses)}" 112 | if our_s2t.beam_state.encoder_out is not None: 113 | bs_info += f", enc_out={our_s2t.beam_state.encoder_out.shape[1]}" 114 | else: 115 | bs_info = "beam_state=None" 116 | 117 | print(f" Ours: '{our_text[:50]}' ({bs_info})") 118 | 119 | # Break if final 120 | if is_final: 121 | break 122 | 123 | print("\n" + "="*80) 124 | print("FINAL RESULTS") 125 | print("="*80) 126 | 127 | # Get final outputs 128 | espnet_s2t.reset() 129 | espnet_final = espnet_s2t(speech, is_final=True) 130 | espnet_final_text = "" 131 | if espnet_final: 132 | espnet_final_text, _, _, _, _ = espnet_final[0] 133 | 134 | our_s2t.reset() 135 | our_final = our_s2t(speech, is_final=True) 136 | our_final_text = "" 137 | if our_final and len(our_final) > 0: 138 | our_final_text, _, _ = our_final[0] 139 | 140 | print(f"\nESPnet (full audio): '{espnet_final_text}'") 141 | print(f"\nOurs (full audio): '{our_final_text}'") 142 | 143 | print("\n" + "="*80) 144 | print("COMPARISON COMPLETE") 145 | print("="*80) 146 | -------------------------------------------------------------------------------- /tests/test_normalization_only.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test ONLY the normalization step to isolate the difference.""" 3 | 4 | import torch 5 | import numpy as np 6 | 7 | # Load same features from both pipelines BEFORE normalization 8 | from speechcatcher.speechcatcher import load_model 9 | from espnet2.bin.asr_inference_streaming import Speech2TextStreaming as ESPnetS2T 10 | import wave 11 | import hashlib 12 | import os 13 | 14 | print("Loading audio...") 15 | os.makedirs('.tmp/', exist_ok=True) 16 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 17 | with wave.open(wavfile_path, 'rb') as wf: 18 | buf = wf.readframes(-1) 19 | raw_audio = np.frombuffer(buf, dtype='int16') 20 | 21 | waveform = torch.from_numpy(raw_audio).float() 22 | 23 | # Load models 24 | print("Loading models...") 25 | our_model = load_model( 26 | tag="speechcatcher/speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024", 27 | device="cpu", 28 | beam_size=5, 29 | quiet=True 30 | ) 31 | 32 | espnet_model = ESPnetS2T( 33 | 
asr_train_config="/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml", 34 | asr_model_file="/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth", 35 | device="cpu", 36 | ) 37 | 38 | print("\nExtracting features from frontend (before normalization)...") 39 | 40 | # Get features BEFORE normalization using ESPnet's frontend 41 | speech_batch = waveform.unsqueeze(0) 42 | speech_lengths = torch.tensor([waveform.shape[0]]) 43 | 44 | with torch.no_grad(): 45 | raw_features, feat_lengths = espnet_model.asr_model.frontend(speech_batch, speech_lengths) 46 | 47 | print(f"Raw features: {raw_features.shape}") 48 | print(f"Raw features stats: min={raw_features.min():.4f}, max={raw_features.max():.4f}, mean={raw_features.mean():.4f}, std={raw_features.std():.4f}") 49 | 50 | # Now apply normalization TWO ways 51 | print("\n" + "="*80) 52 | print("NORMALIZATION COMPARISON") 53 | print("="*80) 54 | 55 | # Method 1: Our way (numpy) 56 | print("\n[1] Our normalization (numpy-based):") 57 | raw_np = raw_features.squeeze(0).cpu().numpy() 58 | our_normalized_np = (raw_np - our_model.mean) / our_model.std 59 | our_normalized = torch.from_numpy(our_normalized_np).unsqueeze(0) 60 | 61 | print(f" Result: min={our_normalized.min():.4f}, max={our_normalized.max():.4f}, mean={our_normalized.mean():.4f}, std={our_normalized.std():.4f}") 62 | 63 | # Method 2: ESPnet way (torch, with masking) 64 | print("\n[2] ESPnet normalization (torch-based with mask):") 65 | espnet_normalized, _ = espnet_model.asr_model.normalize(raw_features, feat_lengths) 66 | 67 | print(f" Result: min={espnet_normalized.min():.4f}, max={espnet_normalized.max():.4f}, mean={espnet_normalized.mean():.4f}, std={espnet_normalized.std():.4f}") 68 | 69 | # Method 3: Manual torch (no mask) 70 | print("\n[3] Manual torch normalization (no mask):") 71 | mean_torch = torch.from_numpy(our_model.mean).to(raw_features.dtype) 72 | std_torch = torch.from_numpy(our_model.std).to(raw_features.dtype) 73 | manual_normalized = (raw_features - mean_torch) / std_torch 74 | 75 | print(f" Result: min={manual_normalized.min():.4f}, max={manual_normalized.max():.4f}, mean={manual_normalized.mean():.4f}, std={manual_normalized.std():.4f}") 76 | 77 | # Compare 78 | print("\n" + "="*80) 79 | print("COMPARISON") 80 | print("="*80) 81 | 82 | diff_1_2 = (our_normalized - espnet_normalized).abs() 83 | print(f"\nOur vs ESPnet:") 84 | print(f" Max diff: {diff_1_2.max():.6f}") 85 | print(f" Mean diff: {diff_1_2.mean():.6f}") 86 | 87 | diff_3_2 = (manual_normalized - espnet_normalized).abs() 88 | print(f"\nManual torch vs ESPnet:") 89 | print(f" Max diff: {diff_3_2.max():.6f}") 90 | print(f" Mean diff: {diff_3_2.mean():.6f}") 91 | 92 | diff_1_3 = (our_normalized - manual_normalized).abs() 93 | print(f"\nOur vs Manual torch:") 94 | print(f" Max diff: {diff_1_3.max():.6f}") 95 | print(f" Mean diff: {diff_1_3.mean():.6f}") 96 | 97 | if torch.allclose(our_normalized, espnet_normalized, atol=1e-4): 98 | print("\n✅ All normalizations match!") 99 | else: 100 | print("\n❌ Normalizations differ!") 101 | 102 | # Check what ESPnet's mask does 103 | print("\n" + "="*80) 104 | 
print("CHECKING MASK EFFECT") 105 | print("="*80) 106 | from espnet_model_zoo.downloader import ModelDownloader 107 | from espnet2.torch_utils.set_all_random_seed import set_all_random_seed 108 | from espnet2.utils.types import str2bool 109 | from espnet_model_zoo.downloader import ModelDownloader 110 | import espnet2.tasks.asr 111 | from espnet2.asr.frontend.default import DefaultFrontend 112 | # from espnet2.utils.get_default_kwargs import get_default_kwargs 113 | 114 | # Check if there's a mask being applied 115 | print(f"Feature lengths: {feat_lengths}") 116 | print(f"Max length: {raw_features.size(1)}") 117 | if feat_lengths[0] < raw_features.size(1): 118 | print(f"⚠️ Padding detected! {raw_features.size(1) - feat_lengths[0]} frames are padded") 119 | else: 120 | print("No padding - full sequence used") 121 | -------------------------------------------------------------------------------- /speechcatcher/websocket_demo.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Speechcatcher ASR Demo 7 | 29 | 30 | 31 |
-------------------------------------------------------------------------------- /docs/analysis/streaming-analysis.md: -------------------------------------------------------------------------------- 1 | # Streaming Implementation Comparison 2 | 3 | ## ESPnet Streaming Decoder vs Our Implementation 4 | 5 | ### Key Parameters from Paper (Tsunoo et al., 2021) 6 | 7 | Paper configuration: `{N_l, N_c, N_r} = {16, 16, 8}` frames (after 4x subsampling) 8 | - **block_size** = 40 frames (total context per block) 9 | - **hop_size** = 16 frames (how much we advance per block) 10 | - **look_ahead** = 16 frames (future context) 11 | 12 | ### ESPnet Streaming Decoder Architecture 13 | 14 | **File**: `espnet_streaming_decoder/asr_inference_streaming.py` 15 | 16 | ```python 17 | # Line 325-331 18 | enc, _, self.encoder_states = self.asr_model.encoder( 19 | feats, 20 | feats_lengths, 21 | self.encoder_states, 22 | is_final=is_final, 23 | infer_mode=True, # ← USES infer_mode=True! 24 | ) 25 | ``` 26 | 27 | **File**: `espnet_streaming_decoder/batch_beam_search_online.py` 28 | 29 | ```python 30 | # Lines 133-135: Block calculation 31 | cur_end_frame = ( 32 | self.block_size - self.look_ahead + self.hop_size * self.processed_block 33 | ) 34 | 35 | # Block 0: 40 - 16 + 16*0 = 24 frames 36 | # Block 1: 40 - 16 + 16*1 = 40 frames 37 | # Block 2: 40 - 16 + 16*2 = 56 frames 38 | ``` 39 | 40 | **BBD Implementation** (Lines 209-218): 41 | ```python 42 | # Simple repetition detection 43 | elif ( 44 | not self.disable_repetition_detection 45 | and not prev_repeat 46 | and best.yseq[i, -1] in best.yseq[i, :-1] # Is last token in history? 47 | and not is_final 48 | ): 49 | prev_repeat = True 50 | ``` 51 | 52 | **Rollback** (Lines 258-261): 53 | ```python 54 | # NON-conservative rollback (1 step) 55 | if self.process_idx > 1 and len(self.prev_hyps) > 0: 56 | self.running_hyps = self.prev_hyps 57 | self.process_idx -= 1 # Go back ONE step 58 | self.prev_hyps = [] 59 | ``` 60 | 61 | **EOS Detection** (Lines 230-232): 62 | ```python 63 | if len(local_ended_hyps) > 0 and not is_final: 64 | logging.info("Detected hyp(s) reaching EOS in this block.") 65 | break 66 | ``` 67 | 68 | ### Our Implementation 69 | 70 | **File**: `speechcatcher/beam_search/beam_search.py` 71 | 72 | ```python 73 | # Line 372-378 74 | encoder_out, encoder_out_lens, encoder_states = self.encoder( 75 | features, 76 | feature_lens, 77 | prev_states=prev_state.encoder_states, 78 | is_final=is_final, 79 | infer_mode=False, # ← WE USE infer_mode=False!
80 | ) 81 | ``` 82 | 83 | **BBD Implementation** (Lines 347-419): 84 | ```python 85 | # Complex reliability score calculation (Equations 12-13) 86 | r_score = max score among repetitions 87 | s_score = alpha_next - r_score 88 | 89 | if s_score <= 0: 90 | # Unreliable hypothesis 91 | ``` 92 | 93 | **Rollback** (Lines 507-513): 94 | ```python 95 | # CONSERVATIVE rollback (2 steps) by default 96 | if self.bbd_conservative and len(prev_step_hypotheses) > 0: 97 | new_state.hypotheses = prev_step_hypotheses 98 | new_state.output_index -= 2 # Go back TWO steps 99 | ``` 100 | 101 | ### Critical Differences 102 | 103 | | Aspect | ESPnet Streaming | Our Implementation | Impact | 104 | |--------|------------------|-------------------|---------| 105 | | **Encoder mode** | `infer_mode=True` | `infer_mode=False` | ❌ WRONG MODE | 106 | | **Block size** | 40 frames (encoder output) | Not used | ❌ NO BLOCKING | 107 | | **Hop size** | 16 frames | Not used | ❌ NO HOPPING | 108 | | **Look ahead** | 16 frames | Not used | ❌ NO LOOKAHEAD | 109 | | **BBD detection** | Simple: `last_token in prev_tokens` | Complex: Eq 12-13 | ⚠️ Overly complex | 110 | | **Rollback** | 1 step (non-conservative) | 2 steps (conservative) | ⚠️ Too conservative | 111 | | **Block boundary** | Explicit block-by-block | Per audio chunk | ❌ WRONG GRANULARITY | 112 | 113 | ### Audio Chunk Size Analysis 114 | 115 | **Current chunk_length**: 8192 samples (0.512s at 16kHz) 116 | 117 | **After frontend processing**: 118 | - STFT with hop_length=160: 8192 / 160 = **51 STFT frames** 119 | - Subsampling by 4x: 51 / 4 = **~12 encoder frames** 120 | 121 | **Required for first block**: 122 | - block_size - look_ahead = 40 - 16 = **24 encoder frames minimum** 123 | 124 | **Problem**: Our chunks give us ~12 frames, but we need at least 24! 125 | 126 | ### Required Chunk Size Calculation 127 | 128 | To get 24 encoder frames: 129 | - 24 encoder frames × 4 (subsampling) × 160 (hop_length) = **15,360 samples** 130 | 131 | To get 40 encoder frames (full block): 132 | - 40 encoder frames × 4 × 160 = **25,600 samples** 133 | 134 | **Recommended**: `chunk_length = 25600` (1.6 seconds at 16kHz) 135 | 136 | ### Action Items 137 | 138 | 1. ✅ Change encoder to `infer_mode=True` 139 | 2. ❌ Implement proper block-based processing 140 | - Track `cur_end_frame` based on block_size, hop_size, look_ahead 141 | - Process encoder output in blocks, not raw audio chunks 142 | 3. ❌ Simplify BBD to match ESPnet 143 | - Replace complex reliability score with simple repetition check 144 | - Change rollback from 2 steps to 1 step 145 | 4. ⚠️ Increase chunk_length from 8192 to 25600 146 | - OR process multiple audio chunks before decoding 147 | 148 | ### Why Batch Mode Works But Streaming Doesn't 149 | 150 | **Batch mode** (`infer_mode=False`): 151 | - Sees ENTIRE utterance at once 152 | - No need for block boundaries 153 | - Can look at all future context 154 | 155 | **Streaming mode** (`infer_mode=True`): 156 | - Processes in LIMITED blocks 157 | - Needs explicit block boundaries 158 | - Has to stop when context insufficient 159 | 160 | **Our bug**: Using `infer_mode=False` (batch mode) while trying to stream! 
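### Worked Example: Block Schedule and Chunk Size

The arithmetic above can be checked with a few lines of Python. This is an illustrative sketch only: the constants come from this analysis (block_size=40, hop_size=16, look_ahead=16, 4x subsampling, STFT hop of 160 samples at 16 kHz), and the helper names are ours, not ESPnet's.

```python
BLOCK_SIZE = 40   # encoder frames of context per block
HOP_SIZE = 16     # encoder frames advanced per block
LOOK_AHEAD = 16   # encoder frames of future context
SUBSAMPLING = 4   # frontend subsampling factor
STFT_HOP = 160    # audio samples per STFT frame at 16 kHz


def block_end_frame(processed_block: int) -> int:
    """End frame (exclusive) that can be decoded once `processed_block` blocks are done."""
    return BLOCK_SIZE - LOOK_AHEAD + HOP_SIZE * processed_block


def min_samples_for_frames(n_encoder_frames: int) -> int:
    """Raw 16 kHz samples needed to produce n_encoder_frames encoder frames."""
    return n_encoder_frames * SUBSAMPLING * STFT_HOP


if __name__ == "__main__":
    for b in range(3):
        print(f"block {b}: decode frames [0, {block_end_frame(b)})")   # 24, 40, 56
    print(min_samples_for_frames(block_end_frame(0)))                  # 15360 samples for the first block
    print(min_samples_for_frames(BLOCK_SIZE))                          # 25600 samples for a full block
```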
161 | -------------------------------------------------------------------------------- /tests/model/test_manual.py: -------------------------------------------------------------------------------- 1 | """Manual test to verify layers work correctly.""" 2 | 3 | import sys 4 | import torch 5 | 6 | # Add parent directory to path 7 | sys.path.insert(0, '/home/ben/speechcatcher') 8 | 9 | from speechcatcher.model.layers import ( 10 | PositionwiseFeedForward, 11 | PositionalEncoding, 12 | RelPositionalEncoding, 13 | StreamPositionalEncoding, 14 | ConvolutionModule, 15 | LayerNorm, 16 | ) 17 | from speechcatcher.model.attention import ( 18 | MultiHeadedAttention, 19 | RelPositionMultiHeadedAttention, 20 | ) 21 | from speechcatcher.model.frontend import STFTFrontend 22 | 23 | 24 | def test_feed_forward(): 25 | """Test PositionwiseFeedForward.""" 26 | print("Testing PositionwiseFeedForward...") 27 | layer = PositionwiseFeedForward(256, 1024, 256) 28 | x = torch.randn(2, 10, 256) 29 | out = layer(x) 30 | assert out.shape == (2, 10, 256), f"Expected (2, 10, 256), got {out.shape}" 31 | print("✓ PositionwiseFeedForward passed") 32 | 33 | 34 | def test_positional_encoding(): 35 | """Test positional encoding variants.""" 36 | print("\nTesting PositionalEncoding...") 37 | pe = PositionalEncoding(256) 38 | x = torch.randn(2, 100, 256) 39 | out = pe(x) 40 | assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}" 41 | print("✓ PositionalEncoding passed") 42 | 43 | print("Testing RelPositionalEncoding...") 44 | rel_pe = RelPositionalEncoding(256) 45 | out, pos_emb = rel_pe(x) 46 | assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}" 47 | assert pos_emb.shape == (1, 100, 256), f"Expected (1, 100, 256), got {pos_emb.shape}" 48 | print("✓ RelPositionalEncoding passed") 49 | 50 | print("Testing StreamPositionalEncoding...") 51 | spe = StreamPositionalEncoding(256) 52 | x1 = torch.randn(1, 10, 256) 53 | x2 = torch.randn(1, 10, 256) 54 | out1 = spe(x1) 55 | out2 = spe(x2) 56 | assert out1.shape == x1.shape 57 | assert out2.shape == x2.shape 58 | spe.reset() 59 | print("✓ StreamPositionalEncoding passed") 60 | 61 | 62 | def test_convolution(): 63 | """Test ConvolutionModule.""" 64 | print("\nTesting ConvolutionModule...") 65 | conv = ConvolutionModule(256, kernel_size=31) 66 | x = torch.randn(2, 100, 256) 67 | out = conv(x) 68 | assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}" 69 | print("✓ ConvolutionModule passed") 70 | 71 | 72 | def test_layer_norm(): 73 | """Test LayerNorm.""" 74 | print("\nTesting LayerNorm...") 75 | ln = LayerNorm(256) 76 | x = torch.randn(2, 100, 256) 77 | out = ln(x) 78 | assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}" 79 | print("✓ LayerNorm passed") 80 | 81 | 82 | def test_multi_head_attention(): 83 | """Test MultiHeadedAttention.""" 84 | print("\nTesting MultiHeadedAttention...") 85 | attn = MultiHeadedAttention(n_head=4, n_feat=256) 86 | x = torch.randn(2, 10, 256) 87 | out = attn(x, x, x) 88 | assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}" 89 | print("✓ MultiHeadedAttention passed") 90 | 91 | print("Testing MultiHeadedAttention with cache...") 92 | q = torch.randn(1, 1, 256) 93 | k = torch.randn(1, 1, 256) 94 | v = torch.randn(1, 1, 256) 95 | out, cache = attn.forward_with_cache(q, k, v) 96 | assert out.shape == (1, 1, 256) 97 | assert cache[0].shape[2] == 1 # K cache time dimension 98 | assert cache[1].shape[2] == 1 # V cache time dimension 99 | print("✓ MultiHeadedAttention with cache passed") 100 | 
101 | 102 | def test_rel_position_attention(): 103 | """Test RelPositionMultiHeadedAttention.""" 104 | print("\nTesting RelPositionMultiHeadedAttention...") 105 | attn = RelPositionMultiHeadedAttention(n_head=4, n_feat=256) 106 | x = torch.randn(2, 10, 256) 107 | pos_emb = torch.randn(1, 10, 256) 108 | out = attn(x, x, x, pos_emb) 109 | assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}" 110 | print("✓ RelPositionMultiHeadedAttention passed") 111 | 112 | 113 | def test_stft_frontend(): 114 | """Test STFTFrontend.""" 115 | print("\nTesting STFTFrontend...") 116 | frontend = STFTFrontend(n_fft=512, hop_length=128, n_mels=80) 117 | 118 | # Generate 1 second of audio at 16kHz 119 | waveform = torch.randn(2, 16000) 120 | features, lengths = frontend(waveform) 121 | 122 | print(f" Input shape: {waveform.shape}") 123 | print(f" Output shape: {features.shape}") 124 | print(f" Lengths: {lengths}") 125 | 126 | assert features.shape[0] == 2 # Batch size 127 | assert features.shape[2] == 80 # n_mels 128 | assert lengths.shape == (2,) 129 | print("✓ STFTFrontend passed") 130 | 131 | 132 | def main(): 133 | """Run all tests.""" 134 | print("="*60) 135 | print("Running manual tests for model layers") 136 | print("="*60) 137 | 138 | try: 139 | test_feed_forward() 140 | test_positional_encoding() 141 | test_convolution() 142 | test_layer_norm() 143 | test_multi_head_attention() 144 | test_rel_position_attention() 145 | test_stft_frontend() 146 | 147 | print("\n" + "="*60) 148 | print("✓ All tests passed!") 149 | print("="*60) 150 | return 0 151 | except Exception as e: 152 | print(f"\n✗ Test failed with error: {e}") 153 | import traceback 154 | traceback.print_exc() 155 | return 1 156 | 157 | 158 | if __name__ == "__main__": 159 | sys.exit(main()) 160 | -------------------------------------------------------------------------------- /docs/debugging/investigation.md: -------------------------------------------------------------------------------- 1 | # Investigation Summary: Native vs ESPnet Decoder Comparison 2 | 3 | ## Objective 4 | Investigate why native decoder produces only ~35% of ESPnet output (238 vs 544 words on full video). 5 | 6 | ## Methodology 7 | Step-by-step investigation of segment_1 (worst offender: 37 native words vs 96 ESPnet words). 8 | 9 | ## Key Findings 10 | 11 | ### 1. BBD Implementation is Correct ✓ 12 | - **Both decoders show identical BBD triggering patterns** 13 | - Step 7 → Step 1 → Step 0 repeatedly 14 | - BBD logic matches ESPnet perfectly 15 | - **Conclusion**: BBD is NOT the problem 16 | 17 | ### 2. Encoder Outputs are Identical ✓ 18 | - Added debug logging to capture encoder statistics 19 | - Native decoder block 0: mean=0.056936, std=0.493884, shape=[1, 16, 256] 20 | - Encoder is deterministic and produces identical outputs 21 | - **Conclusion**: Encoder is NOT the problem 22 | 23 | ### 3. Root Cause: Premature EOS Prediction ✗ 24 | Added hypothesis selection debugging, discovered: 25 | 26 | **Block 0, Step 5**: 27 | ``` 28 | best_yseq=[1023, 72, 104, 260, 66, 3, 1023] 29 | ^EOS at position 6! 30 | ``` 31 | 32 | **Block 2, Step 2**: 33 | ``` 34 | best_yseq=[..., 41, 1023, 1023] 35 | ^TWO EOS tokens back-to-back! 36 | ``` 37 | 38 | **Why this is catastrophic**: 39 | 1. Decoder assigns high probability to EOS (token ID 1023) prematurely 40 | 2. Hypotheses contain multiple EOS tokens in middle of sequence 41 | 3. BBD detects token 1023 repetition and triggers immediately 42 | 4. Block decoding stops before meaningful output is generated 43 | 5. 
Only 5-37 words produced instead of 96+ 44 | 45 | ### 4. Secondary Issue: Special Token Leaking ✓ FIXED 46 | **Before fix**: Native output contained `` tokens in text: 47 | ``` 48 | Forderung, die... unsere... 49 | ``` 50 | 51 | **Fix applied**: Filter tokens 0, 1, 1023 from output (`speech2text_streaming.py:511`) 52 | ```python 53 | # Before: 54 | token_ids_filtered = [tid for tid in token_ids if tid != 0] 55 | 56 | # After: 57 | token_ids_filtered = [tid for tid in token_ids if tid not in [0, 1, 1023]] 58 | ``` 59 | 60 | **Result**: Output is now clean, but still very short (5 words vs 96) 61 | 62 | ## Why ESPnet Doesn't Have This Problem 63 | 64 | The original ESPnet decoder handles EOS differently, likely through: 65 | 1. **Different hypothesis pruning**: Discards premature EOS hypotheses 66 | 2. **Different EOS scoring**: Penalizes premature EOS predictions 67 | 3. **Different state management**: Prevents EOS prediction loops 68 | 4. **Different beam selection**: Prefers hypotheses that haven't reached EOS yet 69 | 70 | ## Impact Summary 71 | 72 | ### Segments 0-3 (First 4 minutes) 73 | | Segment | Native | ESPnet | Native % | Issue | 74 | |---------|--------|--------|----------|-------| 75 | | 0 | 70 | 105 | 66.7% | Premature EOS | 76 | | 1 | 5* | 96 | 5.2% | Severe premature EOS | 77 | | 2 | 4 | 96 | 4.2% | Critical premature EOS | 78 | | 3 | 6 | 111 | 5.4% | Critical premature EOS | 79 | 80 | *After special token filtering fix 81 | 82 | ### Segments 4-7 (Last 4 minutes) 83 | All four segments produce **identical output** (100% match), suggesting: 84 | - No premature EOS in these segments 85 | - Audio content dependency 86 | - Possible acoustic differences that trigger/prevent EOS 87 | 88 | ## Files Modified 89 | 90 | 1. **`speechcatcher/speech2text_streaming.py:511`** 91 | - Fixed special token filtering 92 | - Now removes tokens 0, 1, 1023 from output 93 | 94 | 2. **`speechcatcher/beam_search/beam_search.py:475-489`** 95 | - Added encoder output debugging 96 | - Logs mean/std/min/max statistics 97 | - Saves first block to `/tmp/encoder_debug/block0_native.pt` 98 | 99 | 3. **`speechcatcher/beam_search/beam_search.py:666-671`** 100 | - Added hypothesis selection debugging 101 | - Logs best hypothesis score and token sequence after each step 102 | 103 | 4. **`speechcatcher/beam_search/beam_search.py:758-769`** 104 | - Added `decoder_name` parameter to `create_beam_search()` 105 | - Enables distinguishing ESPnet vs native in debug logs 106 | 107 | ## Next Steps 108 | 109 | ### Option 1: Match ESPnet's EOS Handling (Recommended) 110 | Investigate how ESPnet's original decoder prevents premature EOS: 111 | 1. Check ESPnet's `batch_beam_search_online.py` for EOS handling 112 | 2. Look for hypothesis filtering that discards premature EOS 113 | 3. Check if ESPnet uses different EOS scoring/penalties 114 | 4. 
Implement similar logic in native decoder 115 | 116 | ### Option 2: Post-Filter EOS Tokens 117 | Add logic to prevent decoder from predicting EOS until minimum length reached: 118 | - Block EOS token in beam search until min_length tokens generated 119 | - Similar to length penalty in standard beam search 120 | - Risk: May not match ESPnet's exact behavior 121 | 122 | ### Option 3: Investigate Acoustic Dependency 123 | Understand why segments 4-7 work perfectly but 0-3 fail: 124 | - Analyze audio characteristics (volume, speech rate, silence) 125 | - Check if encoder outputs differ in problematic segments 126 | - May reveal model limitations or training data bias 127 | 128 | ## Recommendation 129 | 130 | **Start with Option 1**: Study ESPnet's EOS handling and match it exactly. This is the most principled approach and will ensure our decoder behaves identically to ESPnet. 131 | 132 | The investigation infrastructure is now in place: 133 | - Encoder debugging 134 | - Hypothesis selection logging 135 | - Special token filtering 136 | - All tools needed to compare behaviors in detail 137 | -------------------------------------------------------------------------------- /speechcatcher/model/layers/positional_encoding.py: -------------------------------------------------------------------------------- 1 | """Positional encoding modules for Transformer/Conformer.""" 2 | 3 | import math 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class PositionalEncoding(nn.Module): 11 | """Absolute positional encoding. 12 | 13 | PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) 14 | PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)) 15 | 16 | Args: 17 | d_model: Embedding dimension 18 | dropout_rate: Dropout rate 19 | max_len: Maximum sequence length (default: 5000) 20 | reverse: Whether to reverse the positional encoding (default: False) 21 | 22 | Shape: 23 | - Input: (batch, time, d_model) 24 | - Output: (batch, time, d_model) 25 | """ 26 | 27 | def __init__( 28 | self, 29 | d_model: int, 30 | dropout_rate: float = 0.1, 31 | max_len: int = 5000, 32 | reverse: bool = False, 33 | ): 34 | super().__init__() 35 | self.d_model = d_model 36 | self.reverse = reverse 37 | self.dropout = nn.Dropout(p=dropout_rate) 38 | 39 | # Compute positional encoding 40 | pe = torch.zeros(max_len, d_model) 41 | position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) 42 | div_term = torch.exp( 43 | torch.arange(0, d_model, 2, dtype=torch.float32) 44 | * -(math.log(10000.0) / d_model) 45 | ) 46 | pe[:, 0::2] = torch.sin(position * div_term) 47 | pe[:, 1::2] = torch.cos(position * div_term) 48 | pe = pe.unsqueeze(0) # (1, max_len, d_model) 49 | self.register_buffer("pe", pe) 50 | 51 | def forward( 52 | self, x: torch.Tensor, offset: int = 0 53 | ) -> torch.Tensor: 54 | """Forward pass. 55 | 56 | Args: 57 | x: Input tensor (batch, time, d_model) 58 | offset: Positional offset for streaming (default: 0) 59 | 60 | Returns: 61 | Output with positional encoding (batch, time, d_model) 62 | """ 63 | # Scale input by sqrt(d_model) as in original Transformer 64 | x = x * math.sqrt(self.d_model) 65 | 66 | # Add positional encoding 67 | seq_len = x.size(1) 68 | if self.reverse: 69 | # For right-to-left processing 70 | pe = self.pe[:, offset : offset + seq_len].flip(1) 71 | else: 72 | pe = self.pe[:, offset : offset + seq_len] 73 | 74 | x = x + pe 75 | return self.dropout(x) 76 | 77 | 78 | class RelPositionalEncoding(PositionalEncoding): 79 | """Relative positional encoding for Conformer. 
80 | 81 | This extends absolute positional encoding by providing both 82 | the positionally encoded input and the positional encoding itself 83 | for use in relative positional multi-head attention. 84 | 85 | Args: 86 | d_model: Embedding dimension 87 | dropout_rate: Dropout rate 88 | max_len: Maximum sequence length (default: 5000) 89 | 90 | Shape: 91 | - Input: (batch, time, d_model) 92 | - Output: Tuple of 93 | - (batch, time, d_model): Positionally encoded input 94 | - (1, time, d_model): Positional encoding for relative attention 95 | """ 96 | 97 | def forward( 98 | self, x: torch.Tensor, offset: int = 0 99 | ) -> Tuple[torch.Tensor, torch.Tensor]: 100 | """Forward pass. 101 | 102 | Args: 103 | x: Input tensor (batch, time, d_model) 104 | offset: Positional offset for streaming (default: 0) 105 | 106 | Returns: 107 | Tuple of: 108 | - Output with positional encoding (batch, time, d_model) 109 | - Positional encoding tensor (1, time, d_model) 110 | """ 111 | # Scale input by sqrt(d_model) 112 | x = x * math.sqrt(self.d_model) 113 | 114 | # Get positional encoding 115 | seq_len = x.size(1) 116 | pe = self.pe[:, offset : offset + seq_len] 117 | 118 | # Add positional encoding and apply dropout 119 | x = self.dropout(x + pe) 120 | 121 | # Return both the encoded input and the positional encoding 122 | return x, pe 123 | 124 | 125 | class StreamPositionalEncoding(PositionalEncoding): 126 | """Streaming positional encoding with state management. 127 | 128 | This variant maintains an internal position counter for streaming 129 | applications, allowing proper positional encoding across chunks. 130 | 131 | Args: 132 | d_model: Embedding dimension 133 | dropout_rate: Dropout rate 134 | max_len: Maximum sequence length (default: 5000) 135 | 136 | Shape: 137 | - Input: (batch, time, d_model) 138 | - Output: (batch, time, d_model) 139 | """ 140 | 141 | def __init__( 142 | self, 143 | d_model: int, 144 | dropout_rate: float = 0.1, 145 | max_len: int = 5000, 146 | ): 147 | super().__init__(d_model, dropout_rate, max_len) 148 | self.register_buffer("_current_position", torch.tensor(0, dtype=torch.long)) 149 | 150 | def forward( 151 | self, x: torch.Tensor, offset: Optional[int] = None 152 | ) -> torch.Tensor: 153 | """Forward pass with automatic position tracking. 154 | 155 | Args: 156 | x: Input tensor (batch, time, d_model) 157 | offset: Manual positional offset (if None, uses internal counter) 158 | 159 | Returns: 160 | Output with positional encoding (batch, time, d_model) 161 | """ 162 | if offset is None: 163 | offset = self._current_position.item() 164 | self._current_position += x.size(1) 165 | 166 | return super().forward(x, offset) 167 | 168 | def reset(self): 169 | """Reset the internal position counter.""" 170 | self._current_position.zero_() 171 | -------------------------------------------------------------------------------- /docs/implementation/root-cause-analysis.md: -------------------------------------------------------------------------------- 1 | # ROOT CAUSE: Missing Encoder Buffer in Beam Search 2 | 3 | ## Summary 4 | 5 | **We found the root cause!** Our implementation is missing the critical encoder output buffering mechanism that ESPnet uses. 6 | 7 | ## The Problem 8 | 9 | ### What ESPnet Does (CORRECT): 10 | 11 | ```python 12 | # batch_beam_search_online.py:118-144 13 | def forward(self, x, is_final=True): 14 | # 1. 
ACCUMULATE encoder outputs in buffer 15 | if self.encbuffer is None: 16 | self.encbuffer = x 17 | else: 18 | self.encbuffer = torch.cat([self.encbuffer, x], axis=0) 19 | 20 | # 2. EXTRACT blocks from buffer 21 | while True: 22 | cur_end_frame = ( 23 | self.block_size - self.look_ahead + 24 | self.hop_size * self.processed_block 25 | ) 26 | if cur_end_frame < self.encbuffer.shape[0]: 27 | h = self.encbuffer.narrow(0, 0, cur_end_frame) # Extract block 28 | self.process_one_block(h, ...) 29 | self.processed_block += 1 30 | else: 31 | break 32 | ``` 33 | 34 | **Block extraction formula:** 35 | - Block 0: frames [0, 24) (block_size - look_ahead = 40 - 16) 36 | - Block 1: frames [0, 40) (24 + 16*1) 37 | - Block 2: frames [0, 56) (24 + 16*2) 38 | - Block 3: frames [0, 72) (24 + 16*3) 39 | 40 | ### What We Do (WRONG): 41 | 42 | ```python 43 | # beam_search.py:process_block() 44 | def process_block(self, features, is_final): 45 | # 1. Directly encode the current chunk 46 | encoder_out, encoder_out_lens, encoder_states = self.encoder( 47 | features, ..., prev_states=prev_state.encoder_states 48 | ) 49 | 50 | # 2. Directly use encoder output (NO BUFFERING!) 51 | # This gives us different sizes each time: 0, 24, 16, 0, 16, ... 52 | scores, new_states_dict = self.beam_search.batch_score_hypotheses( 53 | hypotheses, encoder_out # ← WRONG! Not extracted from buffer! 54 | ) 55 | ``` 56 | 57 | ## Evidence from Test Output 58 | 59 | ``` 60 | Chunk 4/11: 61 | ESPnet: encbuf=24 ← First encoder output 62 | Ours: enc_out=24 ← Same 63 | 64 | Chunk 5/11: 65 | ESPnet: encbuf=40 ← Buffer grew! (24 + 16 hop) 66 | Ours: enc_out=16 ← Wrong! Encoder output reset 67 | 68 | ESPnet: "liebe" (correct German!) 69 | Ours: "م" (Arabic garbage) 70 | ``` 71 | 72 | **ESPnet's encbuffer grows**: 0 → 24 → 40 → 56 → 72 → 88 73 | **Our enc_out fluctuates**: 0 → 24 → 16 → 0 → 16 → 0 74 | 75 | ## Why This Causes Token 1023 (Arabic 'م') 76 | 77 | 1. **Insufficient context**: Decoder sees only 16 frames instead of 40 78 | 2. **Truncated attention**: Can't attend to full context window 79 | 3. **Poor predictions**: Model trained on 40-frame blocks, gets 16-frame blocks 80 | 4. **Fallback behavior**: Predicts high-frequency token (1023) when confused 81 | 82 | ## The Fix 83 | 84 | We need to implement encoder output buffering in `BlockwiseSynchronousBeamSearch`: 85 | 86 | ```python 87 | class BlockwiseSynchronousBeamSearch: 88 | def __init__(self, ...): 89 | self.encoder_buffer = None # Accumulated encoder outputs 90 | self.processed_block = 0 91 | 92 | def process_block(self, features, is_final): 93 | # 1. Encode features 94 | encoder_out = self.encoder(features, ...) 95 | 96 | # 2. ACCUMULATE in buffer 97 | if self.encoder_buffer is None: 98 | self.encoder_buffer = encoder_out 99 | else: 100 | self.encoder_buffer = torch.cat([self.encoder_buffer, encoder_out], dim=1) 101 | 102 | # 3. EXTRACT block(s) from buffer 103 | while True: 104 | cur_end_frame = ( 105 | self.block_size - self.look_ahead + 106 | self.hop_size * self.processed_block 107 | ) 108 | 109 | if cur_end_frame <= self.encoder_buffer.shape[1]: 110 | # Extract block: [0, cur_end_frame) 111 | block = self.encoder_buffer[:, :cur_end_frame, :] 112 | 113 | # Decode this block 114 | self.decode_block(block, is_final=False) 115 | self.processed_block += 1 116 | else: 117 | break 118 | 119 | # 4. 
If final, decode remaining buffer 120 | if is_final and self.encoder_buffer.shape[1] > 0: 121 | self.decode_block(self.encoder_buffer, is_final=True) 122 | ``` 123 | 124 | ## Architecture Comparison 125 | 126 | ### ESPnet Architecture (CORRECT): 127 | 128 | ``` 129 | Audio Chunks → Encoder → [BUFFER] → Block Extractor → Decoder 130 | ↓ ↓ ↓ ↓ ↓ 131 | 8000 samples Varies Grows Fixed blocks Good output 132 | (0.5s) 0→24→16 0→24→40 24,40,56... "Liebe..." 133 | ``` 134 | 135 | ### Our Architecture (WRONG): 136 | 137 | ``` 138 | Audio Chunks → Encoder → [NO BUFFER] → Decoder 139 | ↓ ↓ ↓ 140 | 25600 samples Varies Bad output 141 | (1.6s) 0→24→16 "م..." 142 | ``` 143 | 144 | ## Next Steps 145 | 146 | 1. ✅ Identified root cause: Missing encoder buffer 147 | 2. ⏳ Implement encoder buffering in `BlockwiseSynchronousBeamSearch` 148 | 3. ⏳ Implement block extraction logic 149 | 4. ⏳ Test with multi-chunk streaming 150 | 5. ⏳ Verify output matches ESPnet: "Liebe Mitglieder..." 151 | 152 | ## Key Parameters 153 | 154 | From config.yaml and test output: 155 | - `block_size`: 40 frames (encoder output frames) 156 | - `hop_size`: 16 frames (how much to advance per block) 157 | - `look_ahead`: 16 frames (future context) 158 | - First block: frames [0, 24) 159 | - Second block: frames [0, 40) 160 | - Hop between blocks: 16 frames 161 | 162 | ## Impact 163 | 164 | This explains EVERYTHING: 165 | - ✅ Why batch mode works (sees full context) 166 | - ✅ Why streaming mode fails (insufficient context) 167 | - ✅ Why token 1023 appears (model confused by short context) 168 | - ✅ Why BBD detects repetitions (model trying to end early) 169 | - ✅ Why ESPnet streaming works (proper buffering) 170 | -------------------------------------------------------------------------------- /tests/test_espnet_vs_ours.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Compare our decoder output vs ESPnet's decoder output with same inputs.""" 3 | 4 | import logging 5 | import torch 6 | import numpy as np 7 | from pathlib import Path 8 | 9 | logging.basicConfig(level=logging.INFO, format='%(message)s') 10 | logger = logging.getLogger(__name__) 11 | 12 | logger.info("="*80) 13 | logger.info("DECODER OUTPUT COMPARISON: Our Implementation vs ESPnet") 14 | logger.info("="*80) 15 | 16 | # ============================================================================ 17 | # Load our model 18 | # ============================================================================ 19 | logger.info("\n[1] Loading OUR model...") 20 | from speechcatcher.speechcatcher import load_model 21 | 22 | our_model = load_model( 23 | tag="speechcatcher/speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024", 24 | device="cpu", 25 | beam_size=5, 26 | quiet=True 27 | ) 28 | 29 | logger.info("✅ Our model loaded") 30 | 31 | # ============================================================================ 32 | # Load ESPnet model directly 33 | # ============================================================================ 34 | logger.info("\n[2] Loading ESPnet model...") 35 | 36 | from espnet2.bin.asr_inference_streaming import Speech2TextStreaming as ESPnetS2T 37 | 38 | espnet_model = ESPnetS2T( 39 | asr_train_config="/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml", 40 | 
asr_model_file="/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth", 41 | device="cpu", 42 | ) 43 | 44 | logger.info("✅ ESPnet model loaded") 45 | 46 | # ============================================================================ 47 | # Create test inputs 48 | # ============================================================================ 49 | logger.info("\n[3] Creating test inputs...") 50 | 51 | # Use SAME encoder output for both 52 | encoder_out = torch.randn(1, 124, 256) # Fixed random encoder output 53 | sos_id = 1 54 | yseq = torch.tensor([sos_id]) 55 | 56 | logger.info(f"Encoder output: {encoder_out.shape}") 57 | logger.info(f"Input sequence: {yseq}") 58 | 59 | # ============================================================================ 60 | # Test our decoder 61 | # ============================================================================ 62 | logger.info("\n[4] Testing OUR decoder...") 63 | 64 | with torch.no_grad(): 65 | our_decoder_out, our_state = our_model.model.decoder.score( 66 | yseq, 67 | state=None, 68 | x=encoder_out[0] # (124, 256) 69 | ) 70 | 71 | logger.info(f"Our decoder output: {our_decoder_out.shape}") 72 | logger.info(f"Our decoder stats: min={our_decoder_out.min():.4f}, max={our_decoder_out.max():.4f}, mean={our_decoder_out.mean():.4f}") 73 | 74 | top_k = 5 75 | top_scores, top_tokens = torch.topk(our_decoder_out, k=top_k) 76 | logger.info(f"\nOur top {top_k} predictions:") 77 | for i, (score, token) in enumerate(zip(top_scores.tolist(), top_tokens.tolist())): 78 | if our_model.tokenizer: 79 | token_text = our_model.tokenizer.id_to_piece(int(token)) 80 | logger.info(f" {i+1}. Token {token:4d} ({token_text:20s}): {score:.4f}") 81 | 82 | # ============================================================================ 83 | # Test ESPnet decoder 84 | # ============================================================================ 85 | logger.info(f"\n[5] Testing ESPnet decoder...") 86 | 87 | with torch.no_grad(): 88 | # ESPnet decoder.score() has same signature 89 | espnet_decoder_out, espnet_state = espnet_model.asr_model.decoder.score( 90 | yseq, 91 | state=None, 92 | x=encoder_out[0] # (124, 256) 93 | ) 94 | 95 | logger.info(f"ESPnet decoder output: {espnet_decoder_out.shape}") 96 | logger.info(f"ESPnet decoder stats: min={espnet_decoder_out.min():.4f}, max={espnet_decoder_out.max():.4f}, mean={espnet_decoder_out.mean():.4f}") 97 | 98 | top_scores, top_tokens = torch.topk(espnet_decoder_out, k=top_k) 99 | logger.info(f"\nESPnet top {top_k} predictions:") 100 | for i, (score, token) in enumerate(zip(top_scores.tolist(), top_tokens.tolist())): 101 | if our_model.tokenizer: 102 | token_text = our_model.tokenizer.id_to_piece(int(token)) 103 | logger.info(f" {i+1}. 
Token {token:4d} ({token_text:20s}): {score:.4f}") 104 | 105 | # ============================================================================ 106 | # Compare outputs 107 | # ============================================================================ 108 | logger.info(f"\n[6] Comparing outputs...") 109 | 110 | diff = (our_decoder_out - espnet_decoder_out).abs() 111 | logger.info(f"\nAbsolute difference:") 112 | logger.info(f" Max: {diff.max():.6f}") 113 | logger.info(f" Mean: {diff.mean():.6f}") 114 | logger.info(f" Min: {diff.min():.6f}") 115 | 116 | if torch.allclose(our_decoder_out, espnet_decoder_out, atol=1e-4): 117 | logger.info("\n✅ Outputs are IDENTICAL (within tolerance)!") 118 | else: 119 | logger.error("\n❌ Outputs are DIFFERENT!") 120 | 121 | # Find which tokens differ most 122 | _, our_top_token = our_decoder_out.topk(1) 123 | _, espnet_top_token = espnet_decoder_out.topk(1) 124 | 125 | logger.info(f"\nTop prediction:") 126 | logger.info(f" Ours: token {our_top_token.item()}") 127 | logger.info(f" ESPnet: token {espnet_top_token.item()}") 128 | 129 | if our_top_token.item() == 1023: 130 | logger.error(f" ❌ Our model predicts token 1023 (wrong!)") 131 | if espnet_top_token.item() != our_top_token.item(): 132 | logger.error(f" ❌ Different top predictions!") 133 | 134 | logger.info("\n" + "="*80) 135 | logger.info("COMPARISON COMPLETE") 136 | logger.info("="*80) 137 | -------------------------------------------------------------------------------- /speechcatcher/beam_search/hypothesis.py: -------------------------------------------------------------------------------- 1 | """Hypothesis and beam state classes for beam search.""" 2 | 3 | from dataclasses import dataclass, field 4 | from typing import Any, Dict, List, Optional 5 | 6 | import torch 7 | 8 | 9 | @dataclass 10 | class Hypothesis: 11 | """Single hypothesis in beam search. 12 | 13 | Matches ESPnet's Hypothesis structure for compatibility. 14 | 15 | Attributes: 16 | yseq: Token sequence (torch.Tensor of token IDs) 17 | score: Total log probability score 18 | scores: Score breakdown by component (e.g., 'decoder', 'ctc', 'lm') 19 | states: Per-scorer states dict {scorer_name: state} 20 | xpos: Encoder frame positions for each token (torch.Tensor) 21 | """ 22 | 23 | yseq: torch.Tensor = field(default_factory=lambda: torch.tensor([], dtype=torch.long)) 24 | score: float = 0.0 25 | scores: Dict[str, float] = field(default_factory=dict) 26 | states: Dict[str, Any] = field(default_factory=dict) 27 | xpos: torch.Tensor = field(default_factory=lambda: torch.tensor([], dtype=torch.long)) 28 | 29 | def asdict(self) -> Dict[str, Any]: 30 | """Convert to dictionary (for compatibility with ESPnet).""" 31 | return { 32 | "yseq": self.yseq, 33 | "score": self.score, 34 | "scores": self.scores, 35 | } 36 | 37 | def __repr__(self) -> str: 38 | yseq_list = self.yseq.tolist() if len(self.yseq) > 0 else [] 39 | yseq_str = str(yseq_list[:10]) + ('...' if len(yseq_list) > 10 else '') 40 | return f"Hypothesis(yseq={yseq_str}, score={self.score:.2f})" 41 | 42 | 43 | @dataclass 44 | class BeamState: 45 | """Beam state for blockwise synchronous beam search. 
46 | 47 | Attributes: 48 | hypotheses: List of active hypotheses 49 | encoder_states: Encoder streaming states 50 | encoder_out: Current encoder output (batch, time, dim) 51 | encoder_out_lens: Encoder output lengths 52 | processed_frames: Number of frames processed so far 53 | is_final: Whether this is the final block 54 | output_index: Current output token index (for BBD) 55 | evaluated_hyps: Set of hypothesis sequences already evaluated (Ω_R for BBD) 56 | """ 57 | 58 | hypotheses: List[Hypothesis] = field(default_factory=list) 59 | encoder_states: Optional[Dict] = None 60 | encoder_out: Optional[torch.Tensor] = None 61 | encoder_out_lens: Optional[torch.Tensor] = None 62 | processed_frames: int = 0 63 | is_final: bool = False 64 | output_index: int = 0 65 | evaluated_hyps: set = field(default_factory=set) 66 | 67 | def __repr__(self) -> str: 68 | return ( 69 | f"BeamState(n_hyps={len(self.hypotheses)}, " 70 | f"processed_frames={self.processed_frames}, " 71 | f"is_final={self.is_final})" 72 | ) 73 | 74 | 75 | def create_initial_hypothesis(sos_id: int = 1023, device: str = "cpu") -> Hypothesis: 76 | """Create initial hypothesis with SOS token. 77 | 78 | Args: 79 | sos_id: Start-of-sentence token ID 80 | device: Device to place tensors on 81 | 82 | Returns: 83 | Initial hypothesis 84 | """ 85 | return Hypothesis( 86 | yseq=torch.tensor([sos_id], dtype=torch.long, device=device), 87 | score=0.0, 88 | scores={}, 89 | states={}, 90 | xpos=torch.tensor([0], dtype=torch.long, device=device), 91 | ) 92 | 93 | 94 | def batch_hypotheses(hypotheses: List[Hypothesis], device: str = "cpu") -> Dict[str, torch.Tensor]: 95 | """Batch multiple hypotheses for parallel scoring. 96 | 97 | Args: 98 | hypotheses: List of hypotheses to batch 99 | device: Device to place tensors on 100 | 101 | Returns: 102 | Dictionary with batched tensors 103 | """ 104 | if not hypotheses: 105 | return {} 106 | 107 | # Batch yseq (already torch.Tensor in each hypothesis) 108 | max_len = max(len(h.yseq) for h in hypotheses) 109 | yseq_batch = torch.zeros(len(hypotheses), max_len, dtype=torch.long, device=device) 110 | 111 | for i, h in enumerate(hypotheses): 112 | yseq_batch[i, : len(h.yseq)] = h.yseq.to(device) 113 | 114 | # Batch states - now dict-based 115 | # States is Dict[str, Any] where each scorer has its own state 116 | # We need to reorganize this for batch processing 117 | states_batch = {} 118 | if hypotheses[0].states: 119 | # Get all scorer names 120 | scorer_names = hypotheses[0].states.keys() 121 | for scorer_name in scorer_names: 122 | # Each scorer's state is a list of layer states 123 | states_batch[scorer_name] = [h.states[scorer_name] for h in hypotheses] 124 | 125 | return { 126 | "yseq": yseq_batch, 127 | "states": states_batch, 128 | "scores": [h.score for h in hypotheses], 129 | } 130 | 131 | 132 | def top_k_hypotheses(hypotheses: List[Hypothesis], k: int) -> List[Hypothesis]: 133 | """Select top-k hypotheses by score. 134 | 135 | Args: 136 | hypotheses: List of hypotheses 137 | k: Number of hypotheses to keep 138 | 139 | Returns: 140 | Top-k hypotheses sorted by score (descending) 141 | """ 142 | return sorted(hypotheses, key=lambda h: h.score, reverse=True)[:k] 143 | 144 | 145 | def append_token(tensor: torch.Tensor, token_id: int) -> torch.Tensor: 146 | """Append a token to a tensor sequence. 
147 | 148 | Args: 149 | tensor: Original sequence tensor 150 | token_id: Token ID to append 151 | 152 | Returns: 153 | New tensor with token appended 154 | """ 155 | return torch.cat([tensor, torch.tensor([token_id], dtype=torch.long, device=tensor.device)]) 156 | 157 | 158 | def append_position(xpos: torch.Tensor, position: int) -> torch.Tensor: 159 | """Append an encoder position to xpos tensor. 160 | 161 | Args: 162 | xpos: Original position tensor 163 | position: Encoder frame position to append 164 | 165 | Returns: 166 | New tensor with position appended 167 | """ 168 | return torch.cat([xpos, torch.tensor([position], dtype=torch.long, device=xpos.device)]) 169 | -------------------------------------------------------------------------------- /tests/test_exact_score_comparison.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Compare exact decoder + CTC scores for tokens 372 and 738.""" 3 | 4 | import torch 5 | import numpy as np 6 | import wave 7 | import hashlib 8 | import os 9 | import json 10 | 11 | print("="*80) 12 | print("EXACT SCORE COMPARISON (Tokens 372 vs 738)") 13 | print("="*80) 14 | 15 | # Load audio 16 | from speechcatcher.speechcatcher import convert_inputfile 17 | 18 | os.makedirs('.tmp/', exist_ok=True) 19 | wavfile_path = '.tmp/' + hashlib.sha1("Neujahrsansprache_5s.mp4".encode("utf-8")).hexdigest() + '.wav' 20 | if not os.path.exists(wavfile_path): 21 | convert_inputfile("Neujahrsansprache_5s.mp4", wavfile_path) 22 | 23 | with wave.open(wavfile_path, 'rb') as f: 24 | raw_audio = np.frombuffer(f.readframes(-1), dtype='int16') 25 | speech = raw_audio.astype(np.float32) / 32768.0 26 | 27 | # Load token list 28 | with open("/tmp/espnet_token_list.json", "r") as f: 29 | token_list = json.load(f) 30 | 31 | # Load ESPnet 32 | print("\n[1] Loading ESPnet...") 33 | from espnet_streaming_decoder.asr_inference_streaming import Speech2TextStreaming as ESPnetStreaming 34 | 35 | config_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/config.yaml" 36 | model_path = "/home/ben/.cache/espnet/models--speechcatcher--speechcatcher_german_espnet_streaming_transformer_26k_train_size_xl_raw_de_bpe1024/snapshots/469c3474c28025d77cd7c1e1671638b56de53c2d/exp/asr_train_asr_streaming_transformer_size_xl_raw_de_bpe1024/valid.acc.ave_6best.pth" 37 | 38 | espnet_s2t = ESPnetStreaming( 39 | asr_train_config=config_path, 40 | asr_model_file=model_path, 41 | device="cpu", 42 | beam_size=5, 43 | ctc_weight=0.3, 44 | ) 45 | espnet_s2t.reset() 46 | print("✅ ESPnet loaded") 47 | 48 | # Load ours 49 | print("\n[2] Loading our implementation...") 50 | from speechcatcher.speechcatcher import load_model, tags 51 | 52 | our_s2t = load_model(tags['de_streaming_transformer_xl'], beam_size=5, quiet=True) 53 | our_s2t.reset() 54 | print("✅ Ours loaded") 55 | 56 | # Process chunks 1-5 to get to first decoding point 57 | print("\n[3] Processing chunks 1-5...") 58 | chunk_size = 8000 59 | 60 | for chunk_idx in range(5): 61 | chunk = speech[chunk_idx*chunk_size : min((chunk_idx+1)*chunk_size, len(speech))] 62 | espnet_s2t(chunk, is_final=False) 63 | our_s2t(chunk, is_final=False) 64 | 65 | print("✅ Processed chunks 1-5") 66 | 67 | # Now we should have 40-frame encoder buffer and be ready to decode block 0 68 | print("\n[4] Extracting 40-frame 
encoder output...") 69 | 70 | # Get SAME 40 frames from both 71 | espnet_enc = espnet_s2t.beam_search.encbuffer[:40].unsqueeze(0) # (1, 40, 256) 72 | our_enc = our_s2t.beam_search.encoder_buffer[:, :40, :] # (1, 40, 256) 73 | 74 | print(f"ESPnet encoder: {espnet_enc.shape}") 75 | print(f"Our encoder: {our_enc.shape}") 76 | print(f"Match: {torch.allclose(espnet_enc, our_enc, atol=1e-6)}") 77 | 78 | # Use SAME encoder output for both 79 | enc = espnet_enc 80 | 81 | print("\n" + "="*80) 82 | print("DECODER SCORES (from SOS=1023)") 83 | print("="*80) 84 | 85 | # Initial hypothesis: [SOS] 86 | ys = torch.tensor([[1023]], dtype=torch.long) 87 | 88 | # ESPnet decoder 89 | with torch.no_grad(): 90 | espnet_dec_out, _ = espnet_s2t.asr_model.decoder.forward_one_step(ys, None, enc, cache=None) 91 | espnet_dec_logprobs = torch.log_softmax(espnet_dec_out, dim=-1)[0] 92 | 93 | print(f"\nToken 372 (▁Li): {espnet_dec_logprobs[372].item():.6f}") 94 | print(f"Token 738 (▁liebe): {espnet_dec_logprobs[738].item():.6f}") 95 | print(f"Difference: {(espnet_dec_logprobs[372] - espnet_dec_logprobs[738]).item():.6f}") 96 | 97 | # Our decoder (should be identical) 98 | with torch.no_grad(): 99 | our_dec_out, _ = our_s2t.model.decoder.forward_one_step(ys, None, enc, cache=None) 100 | our_dec_logprobs = torch.log_softmax(our_dec_out, dim=-1)[0] 101 | 102 | print(f"\n[Verification - should match ESPnet]") 103 | print(f"Token 372 (▁Li): {our_dec_logprobs[372].item():.6f}") 104 | print(f"Token 738 (▁liebe): {our_dec_logprobs[738].item():.6f}") 105 | 106 | print("\n" + "="*80) 107 | print("CTC SCORES") 108 | print("="*80) 109 | 110 | # Get CTC logits for 40 frames 111 | with torch.no_grad(): 112 | espnet_ctc_logits = espnet_s2t.asr_model.ctc.ctc_lo(enc) # (1, 40, 1024) 113 | espnet_ctc_logprobs = torch.log_softmax(espnet_ctc_logits, dim=-1)[0] # (40, 1024) 114 | 115 | # Simple sum across frames (not true CTC prefix scoring, but gives intuition) 116 | espnet_ctc_sum_372 = espnet_ctc_logprobs[:, 372].sum().item() 117 | espnet_ctc_sum_738 = espnet_ctc_logprobs[:, 738].sum().item() 118 | 119 | print(f"\nToken 372 (▁Li) sum: {espnet_ctc_sum_372:.6f}") 120 | print(f"Token 738 (▁liebe) sum: {espnet_ctc_sum_738:.6f}") 121 | print(f"Difference: {(espnet_ctc_sum_372 - espnet_ctc_sum_738):.6f}") 122 | 123 | print("\n" + "="*80) 124 | print("COMBINED SCORES (decoder=0.7, ctc=0.3)") 125 | print("="*80) 126 | 127 | # NOTE: This is simplified - real CTC prefix scoring is more complex 128 | dec_372 = espnet_dec_logprobs[372].item() 129 | dec_738 = espnet_dec_logprobs[738].item() 130 | 131 | # Simplified combined score 132 | combined_372 = 0.7 * dec_372 + 0.3 * (espnet_ctc_sum_372 / 40) 133 | combined_738 = 0.7 * dec_738 + 0.3 * (espnet_ctc_sum_738 / 40) 134 | 135 | print(f"\nToken 372 (▁Li):") 136 | print(f" Decoder (0.7x): {0.7 * dec_372:.6f}") 137 | print(f" CTC (0.3x): {0.3 * (espnet_ctc_sum_372 / 40):.6f}") 138 | print(f" Combined: {combined_372:.6f}") 139 | 140 | print(f"\nToken 738 (▁liebe):") 141 | print(f" Decoder (0.7x): {0.7 * dec_738:.6f}") 142 | print(f" CTC (0.3x): {0.3 * (espnet_ctc_sum_738 / 40):.6f}") 143 | print(f" Combined: {combined_738:.6f}") 144 | 145 | print(f"\nWinner: Token {'372 (▁Li)' if combined_372 > combined_738 else '738 (▁liebe)'}") 146 | print(f"Margin: {abs(combined_372 - combined_738):.6f}") 147 | 148 | print("\n" + "="*80) 149 | print("NOTE: CTC scoring here is simplified (sum/frames). 
Real CTC prefix") 150 | print("scoring uses forward algorithm with proper blank handling.") 151 | print("="*80) 152 | --------------------------------------------------------------------------------
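The closing note in `tests/test_exact_score_comparison.py` points out that summing per-frame CTC log-probabilities is only a rough proxy, and that proper CTC scoring runs the forward algorithm with blank handling. For reference, below is a minimal, self-contained sketch of that forward recursion for scoring one complete label sequence against per-frame log-probabilities. It is not the prefix-scoring variant used inside the beam search, and the function name and signature are ours, not part of the repository.

```python
import torch


def ctc_log_likelihood(log_probs: torch.Tensor, labels: list, blank: int = 0) -> float:
    """Log P(labels | log_probs) under CTC, computed with the forward algorithm.

    log_probs: (T, V) per-frame log-probabilities (e.g. log_softmax of ctc_lo output).
    labels:    non-empty list of target token ids, without blanks.
    """
    assert len(labels) > 0
    T, _ = log_probs.shape

    # Interleave blanks: [blank, l1, blank, l2, ..., blank]
    ext = [blank]
    for label in labels:
        ext += [label, blank]
    S = len(ext)

    neg_inf = float("-inf")
    alpha = torch.full((S,), neg_inf)
    alpha[0] = log_probs[0, ext[0]]   # start with the leading blank ...
    alpha[1] = log_probs[0, ext[1]]   # ... or directly with the first label

    for t in range(1, T):
        new_alpha = torch.full((S,), neg_inf)
        for s in range(S):
            paths = [alpha[s]]                      # stay on the same symbol
            if s >= 1:
                paths.append(alpha[s - 1])          # advance by one symbol
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                paths.append(alpha[s - 2])          # skip the blank between two different labels
            new_alpha[s] = torch.logsumexp(torch.stack(paths), dim=0) + log_probs[t, ext[s]]
        alpha = new_alpha

    # A valid alignment ends on the last label or on the trailing blank.
    return torch.logsumexp(torch.stack([alpha[-1], alpha[-2]]), dim=0).item()
```

With the 40-frame `espnet_ctc_logprobs` from the test above, something like `ctc_log_likelihood(espnet_ctc_logprobs, [372])` would give the exact full-sequence CTC score that the simplified per-frame sum approximates.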