├── .dockerignore ├── .github └── workflows │ ├── python-publish.yml │ └── wiki-publish.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pylintrc ├── .vscode ├── extensions.json ├── launch.json └── settings.json ├── Dockerfile ├── LICENSE ├── MANIFEST.in ├── README.md ├── docker-compose.yml ├── docs ├── features.md ├── figs │ ├── log_gammatone_spectrogram.png │ ├── log_mel_spectrogram.png │ ├── mfcc.png │ └── spectrogram.png ├── tokenizers.md └── tutorials │ ├── testing.md │ ├── tflite.md │ └── training.md ├── examples ├── datasets │ ├── librispeech │ │ ├── characters │ │ │ ├── char.yml.j2 │ │ │ ├── english.metadata.json │ │ │ └── english.vocab │ │ ├── config.yml.j2 │ │ ├── prepare_transcript.py │ │ ├── sentencepiece │ │ │ ├── sp.256.yml.j2 │ │ │ ├── sp.yml.j2 │ │ │ ├── train_8000&960.model │ │ │ ├── train_bpe_1000.metadata.json │ │ │ ├── train_bpe_1000.model │ │ │ ├── train_bpe_1000.vocab │ │ │ ├── train_bpe_256.metadata.json │ │ │ ├── train_bpe_256.model │ │ │ └── train_bpe_256.vocab │ │ ├── subwords │ │ │ ├── train_1030_4.metadata.json │ │ │ └── train_1030_4.subwords │ │ └── wordpiece │ │ │ ├── train_1000.metadata.json │ │ │ ├── train_1000.vocab │ │ │ ├── train_1000_whitespace.metadata.json │ │ │ ├── train_1000_whitespace.vocab │ │ │ ├── wp.yml.j2 │ │ │ └── wp_whitespace.yml.j2 │ ├── vietbud500 │ │ ├── config.yml.j2 │ │ ├── download.py │ │ └── sentencepiece │ │ │ ├── sp.256.yml.j2 │ │ │ ├── sp.yml.j2 │ │ │ ├── train_bpe_1000.metadata.json │ │ │ ├── train_bpe_1000.model │ │ │ ├── train_bpe_1000.vocab │ │ │ ├── train_bpe_256.metadata.json │ │ │ ├── train_bpe_256.model │ │ │ └── train_bpe_256.vocab │ └── vivos │ │ └── vietnamese.characters ├── inferences │ ├── README.md │ ├── main.py │ ├── rnn_transducer.py │ ├── streaming_tflite_conformer.py │ ├── tflite.py │ └── wavs │ │ ├── 1089-134691-0000.flac │ │ └── 2033-164915-0001.flac └── models │ ├── ctc │ ├── conformer │ │ ├── results │ │ │ └── sentencepiece │ │ │ │ ├── README.md │ │ │ │ └── figs │ │ │ │ ├── librispeech-small-streaming-batch-loss.jpg │ │ │ │ ├── librispeech-small-streaming-epoch-loss.jpg │ │ │ │ └── librispeech-small-streaming-lr.jpg │ │ ├── small-streaming.yml.j2 │ │ └── small.yml.j2 │ ├── deepspeech2 │ │ ├── base.yml.j2 │ │ └── uni.yml.j2 │ ├── jasper │ │ └── base.yml.j2 │ └── transformer │ │ ├── README.md │ │ ├── base-streaming.yml.j2 │ │ └── base.yml.j2 │ └── transducer │ ├── conformer │ ├── README.md │ ├── inference │ │ ├── gen_saved_model.py │ │ └── run_saved_model.py │ ├── results │ │ ├── sentencepiece │ │ │ ├── README.md │ │ │ └── figs │ │ │ │ ├── vietbud500-small-streaming-batch-loss.jpg │ │ │ │ ├── vietbud500-small-streaming-epoch-loss.jpg │ │ │ │ └── vietbud500-small-streaming-lr.jpg │ │ └── subword - deprecated │ │ │ └── figs │ │ │ ├── conformer.svg │ │ │ └── subword_conformer_loss.svg │ ├── small-streaming.yml.j2 │ └── small.yml.j2 │ ├── contextnet │ ├── README.md │ ├── results │ │ └── wordpiece │ │ │ ├── README.md │ │ │ └── figs │ │ │ ├── contextnet-small-wp1k-whitespace-batch-loss.svg │ │ │ ├── contextnet-small-wp1k-whitespace-epoch-loss.svg │ │ │ └── contextnet-small-wp1k-whitespace-lr.svg │ └── small.yml.j2 │ └── rnnt │ ├── README.md │ ├── results │ ├── sentencepiece │ │ ├── README.md │ │ └── figs │ │ │ ├── rnnt-tiny-sp256-batch-loss.svg │ │ │ └── rnnt-tiny-sp256-epoch-loss.svg │ └── subword - deprecated │ │ ├── README.md │ │ └── figs │ │ ├── epoch_learning_rate.svg │ │ └── subword_rnnt_loss.svg │ └── small.yml.j2 ├── pyproject.toml ├── requirements.apple.txt ├── requirements.cpu.txt ├── requirements.dev.txt 
├── requirements.gpu.txt ├── requirements.text.txt ├── requirements.tpu.txt ├── requirements.txt ├── scripts ├── install_ctc_decoders.sh ├── install_ctc_loss.sh └── install_rnnt_loss.sh ├── setup.cfg ├── setup.py ├── setup.sh ├── tensorflow_asr ├── __init__.py ├── abstracts.py ├── augmentations │ ├── README.md │ ├── __init__.py │ ├── augmentation.py │ └── methods │ │ ├── __init__.py │ │ ├── base_method.py │ │ ├── gaussnoise.py │ │ └── specaugment.py ├── callbacks.py ├── configs.py ├── datasets.py ├── features │ ├── README.md │ ├── __init__.py │ └── gammatone.py ├── losses │ ├── __init__.py │ ├── base_loss.py │ ├── ctc_loss.py │ ├── impl │ │ ├── __init__.py │ │ ├── ctc_tpu.py │ │ └── rnnt.py │ └── rnnt_loss.py ├── metrics │ ├── __init__.py │ └── error_rates.py ├── models │ ├── __init__.py │ ├── activations │ │ ├── __init__.py │ │ └── glu.py │ ├── base_layer.py │ ├── base_model.py │ ├── ctc │ │ ├── __init__.py │ │ ├── base_ctc.py │ │ ├── conformer.py │ │ ├── deepspeech2.py │ │ ├── jasper.py │ │ └── transformer.py │ ├── decoders │ │ └── __init__.py │ ├── encoders │ │ ├── __init__.py │ │ ├── conformer.py │ │ ├── contextnet.py │ │ ├── deepspeech2.py │ │ ├── jasper.py │ │ ├── rnnt.py │ │ └── transformer.py │ ├── layers │ │ ├── __init__.py │ │ ├── blurpool.py │ │ ├── convolution.py │ │ ├── embedding.py │ │ ├── feature_extraction.py │ │ ├── general.py │ │ ├── memory.py │ │ ├── multihead_attention.py │ │ ├── norm.py │ │ ├── positional_encoding.py │ │ ├── residual.py │ │ ├── sequence_wise_bn.py │ │ └── subsampling.py │ └── transducer │ │ ├── __init__.py │ │ ├── base_transducer.py │ │ ├── conformer.py │ │ ├── contextnet.py │ │ ├── rnnt.py │ │ └── transformer.py ├── optimizers │ ├── __init__.py │ ├── accumulation.py │ ├── regularizers.py │ └── schedules.py ├── schemas.py ├── scripts │ ├── __init__.py │ ├── save.py │ ├── test.py │ ├── tflite.py │ ├── train.py │ └── utils │ │ ├── __init__.py │ │ ├── create_datasets_metadata.py │ │ ├── create_mls_trans.py │ │ └── create_tfrecords.py ├── tokenizers.py └── utils │ ├── __init__.py │ ├── app_util.py │ ├── cli_util.py │ ├── data_util.py │ ├── env_util.py │ ├── feature_util.py │ ├── file_util.py │ ├── keras_util.py │ ├── layer_util.py │ ├── math_util.py │ ├── metric_util.py │ ├── plot_util.py │ ├── shape_util.py │ └── tf_util.py └── tests ├── __init__.py ├── conftest.py ├── featurizer ├── test_sentencepiece.py ├── test_speech_featurizer.py └── transcripts_librispeech_train_clean_100.tsv ├── test.flac ├── test_bug.py ├── test_callbacks.py ├── test_layers.py ├── test_mask.py ├── test_relpe.py ├── test_rnnt_loss.py ├── test_schedules.py ├── test_tokenizers.py └── test_utils.py /.dockerignore: -------------------------------------------------------------------------------- 1 | LibriSpeech 2 | Models 3 | .venv* 4 | venv* 5 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.10.x' 21 | -
name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /.github/workflows/wiki-publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Wiki Pages 2 | on: 3 | push: 4 | branches: [main] 5 | concurrency: 6 | group: publish-wiki 7 | cancel-in-progress: true 8 | permissions: 9 | contents: write 10 | jobs: 11 | publish-wiki: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4.1.4 15 | - uses: nglehuy/github-wiki-action@master 16 | with: 17 | token: ${{ secrets.TOKEN }} 18 | path: docs 19 | preprocess: true 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | dist 3 | *.egg-info 4 | tensorflow 5 | externals 6 | .vim 7 | .session.vim 8 | Session.vim 9 | .idea 10 | __pycache__ 11 | .pytest* 12 | venv* 13 | .venv* 14 | my_train 15 | .DS_Store 16 | models/* 17 | !models/README.md -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: black-formatter-fix 5 | name: black-formatter-fix 6 | entry: bash -c "for f in $@; do black --verbose $f; done" 7 | language: system 8 | types: [python] 9 | stages: [pre-commit] 10 | fail_fast: true 11 | verbose: true 12 | - id: isort-fix 13 | name: isort-fix 14 | entry: bash -c "for f in $@; do echo -e \"Organize import for file $f\" && isort $f; done" 15 | language: system 16 | types: [python] 17 | stages: [pre-commit] 18 | fail_fast: true 19 | verbose: true 20 | - id: pylint-check 21 | name: pylint-check 22 | entry: bash -c "for f in $@; do pylint --rcfile=.pylintrc -rn -sn $f; done" 23 | language: system 24 | types: [python] 25 | stages: [pre-commit] 26 | fail_fast: true 27 | require_serial: true 28 | verbose: true 29 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-python.isort", 4 | "ms-python.black-formatter", 5 | "ms-python.pylint", 6 | "ms-python.vscode-pylance", 7 | "ms-python.python" 8 | ] 9 | } -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Test RNNT Loss", 9 | "type": "python", 10 | "request": "launch", 11 | "module": "pytest", 12 | "justMyCode": true, 13 | "args": [ 14 | "-s", 15 | "./tests/test_rnnt_loss.py" 16 | ] 17 | }, 18 | { 19 | "name": "Test Prediction", 20 | "type": "python", 21 | "request": "launch", 22 | "justMyCode": true, 23 | "program": "./examples/inferences/main.py", 24 | "args": [ 25 | "--file-path", 26 | "/Users/nglehuy/Data/Persona/MachineLearning/Datasets/LibriSpeech/test-clean/61/70970/61-70970-0030.flac", 27 | "--config-path", 28 | "~/Data/Persona/Projects/TensorFlowASR/examples/models/transducer/contextnet/small.yml.j2", 29 | "--h5", 30 | "~/Data/Persona/MachineLearning/Models/transducer/sp1k-contextnet/small/28.h5" 31 | ] 32 | } 33 | ] 34 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[python]": { 3 | "editor.defaultFormatter": "ms-python.black-formatter", 4 | "editor.tabSize": 4 5 | }, 6 | "[markdown]": { 7 | "editor.tabSize": 2, 8 | "editor.indentSize": 2, 9 | "editor.detectIndentation": false 10 | }, 11 | "[json]": { 12 | "editor.tabSize": 2 13 | }, 14 | "[yaml]": { 15 | "editor.tabSize": 2 16 | }, 17 | "autoDocstring.docstringFormat": "numpy", 18 | "black-formatter.args": ["--config", "${workspaceFolder}/pyproject.toml"], 19 | "black-formatter.path": ["${interpreter}", "-m", "black"], 20 | "editor.codeActionsOnSave": { 21 | "source.fixAll": "explicit", 22 | "source.organizeImports": "explicit" 23 | }, 24 | "editor.formatOnSave": true, 25 | "isort.args": ["--settings-file", "${workspaceFolder}/pyproject.toml"], 26 | "pylint.args": ["--rcfile=${workspaceFolder}/.pylintrc"], 27 | "pylint.path": ["${interpreter}", "-m", "pylint"], 28 | "python.analysis.fixAll": ["source.unusedImports", "source.convertImportFormat"], 29 | "python.analysis.importFormat": "absolute", 30 | "markdown.extension.list.indentationSize": "inherit" 31 | } 32 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:2.18.0-gpu 2 | 3 | RUN apt-get update \ 4 | && apt-get upgrade -y \ 5 | && apt-get install -y \ 6 | && apt-get -y install apt-utils gcc libpq-dev libsndfile-dev git build-essential cmake screen 7 | 8 | # Clear cache 9 | RUN apt clean && apt-get clean 10 | 11 | # Install dependencies 12 | COPY requirements*.txt / 13 | RUN pip --no-cache-dir install -r /requirements.txt -r /requirements.gpu.txt 14 | 15 | # Install rnnt_loss 16 | COPY scripts /scripts 17 | ARG install_rnnt_loss=true 18 | ARG using_gpu=true 19 | RUN if [ "$install_rnnt_loss" = "true" ] ; \ 20 | then if [ "$using_gpu" = "true" ] ; then export CUDA_HOME=/usr/local/cuda ; else echo 'Using CPU' ; fi \ 21 | && ./scripts/install_rnnt_loss.sh ; \ 22 | else echo 'Using pure TensorFlow'; fi 23 | 24 | RUN echo "export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt -------------------------------------------------------------------------------- /docker-compose.yml:
-------------------------------------------------------------------------------- 1 | version: "3.9" 2 | 3 | services: 4 | tensorflow_asr: 5 | build: 6 | context: . 7 | args: 8 | using_gpu: "true" 9 | install_rnnt_loss: "true" 10 | tty: true 11 | runtime: nvidia 12 | environment: 13 | - NVIDIA_VISIBLE_DEVICES=all 14 | - NVIDIA_DRIVER_CAPABILITIES=all 15 | ipc: "host" 16 | ports: 17 | - 6006:6006 18 | working_dir: /app 19 | volumes: 20 | - ./:/app 21 | -------------------------------------------------------------------------------- /docs/features.md: -------------------------------------------------------------------------------- 1 | # Speech Features Extraction 2 | 3 | See [feature_extraction.py](../tensorflow_asr/models/layers/feature_extraction.py) for more details 4 | 5 | **Speech features** are extracted from the **Signal** with `sample_rate`, `frame_ms`, `stride_ms` and `num_feature_bins`. 6 | 7 | Speech features have the shape `(B, T, num_feature_bins, num_channels)` and contain 1 to 4 channels: 8 | 9 | 1. Spectrogram, Log Mel Spectrogram, Log Gammatone Spectrogram or MFCCs 10 | 2. TODO: Delta features: like `librosa.feature.delta` from the features extracted on channel 1. 11 | 3. TODO: Delta deltas features: like `librosa.feature.delta` with `order=2` from the features extracted on channel 1. 12 | 4. TODO: Pitch features: like `librosa.core.piptrack` from the signal 13 | 14 | Implemented as a TensorFlow Keras [layer](../tensorflow_asr/models/layers/feature_extraction.py) 15 | 16 | ![Spectrogram](./figs/spectrogram.png) 17 | 18 | ![Log Mel Spectrogram](./figs/log_mel_spectrogram.png) 19 | 20 | ![MFCCs](./figs/mfcc.png) 21 | 22 | ![Log Gammatone Spectrogram](./figs/log_gammatone_spectrogram.png) 23 | -------------------------------------------------------------------------------- /docs/figs/log_gammatone_spectrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/docs/figs/log_gammatone_spectrogram.png -------------------------------------------------------------------------------- /docs/figs/log_mel_spectrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/docs/figs/log_mel_spectrogram.png -------------------------------------------------------------------------------- /docs/figs/mfcc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/docs/figs/mfcc.png -------------------------------------------------------------------------------- /docs/figs/spectrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/docs/figs/spectrogram.png -------------------------------------------------------------------------------- /docs/tokenizers.md: -------------------------------------------------------------------------------- 1 | - [Tokenizers](#tokenizers) 2 | - [1. Character Tokenizer](#1-character-tokenizer) 3 | - [2. Wordpiece Tokenizer](#2-wordpiece-tokenizer) 4 | - [3. Sentencepiece Tokenizer](#3-sentencepiece-tokenizer) 5 | 6 | # Tokenizers 7 |
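All three tokenizers below map text to integer token indices that the models consume. As a rough illustration of that mapping (a minimal sketch with a toy vocabulary, not the library's implementation in `tensorflow_asr/tokenizers.py`):

```python
# Toy character-to-index mapping; index 0 is reserved for the blank token,
# as described in the sections below. Real vocabularies live under
# examples/datasets/ and the real tokenizers also handle normalization and OOV.
vocab = ["a", "b", "c"]
char_to_index = {c: i + 1 for i, c in enumerate(vocab)}  # 0 = blank

def tokenize(text):
    return [char_to_index[c] for c in text if c in char_to_index]

print(tokenize("cab"))  # [3, 1, 2]
```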
8 | ## 1. Character Tokenizer 9 | 10 | See [librispeech config](../examples/datasets/librispeech/characters/char.yml.j2) 11 | 12 | This splits the text into characters and then maps each character to an index. Indices start from 1, and 0 is reserved for the blank token. This tokenizer is only suitable for languages that have a small number of characters, where a character is not a combination of other characters, for example English, Vietnamese, etc. 13 | 14 | ## 2. Wordpiece Tokenizer 15 | 16 | See [librispeech config](../examples/datasets/librispeech/wordpiece/wp.yml.j2) for wordpiece split on whitespace 17 | 18 | See [librispeech config](../examples/datasets/librispeech/wordpiece/wp_whitespace.yml.j2) for wordpiece where whitespace is a separate token 19 | 20 | This splits the text into words and then splits each word into subwords. The subwords are then mapped to indices. The blank token can be set as index 0. This tokenizer works for languages that have a large number of words, where each word can be a combination of other words, therefore it can be applied to any language. 21 | 22 | ## 3. Sentencepiece Tokenizer 23 | 24 | See [librispeech config](../examples/datasets/librispeech/sentencepiece/sp.yml.j2) 25 | 26 | This splits the whole sentence into subwords and then maps each subword to an index. The blank token can be set as index 0. This tokenizer works for languages that have a large number of words, where each word can be a combination of other words, therefore it can be applied to any language. -------------------------------------------------------------------------------- /docs/tutorials/testing.md: -------------------------------------------------------------------------------- 1 | - [Testing Tutorial](#testing-tutorial) 2 | - [1. Installation](#1-installation) 3 | - [2. Prepare transcripts files](#2-prepare-transcripts-files) 4 | - [3. Prepare config file](#3-prepare-config-file) 5 | - [4. Run testing](#4-run-testing) 6 | 7 | 8 | # Testing Tutorial 9 | 10 | These commands are examples for the LibriSpeech dataset, but similar steps apply to other datasets 11 | 12 | ## 1. Installation 13 | 14 | ```bash 15 | ./setup.sh [tpu|gpu|cpu] install 16 | ``` 17 | 18 | ## 2. Prepare transcripts files 19 | 20 | This is an example of preparing transcript files for the LibriSpeech corpus 21 | 22 | ```bash 23 | python examples/datasets/librispeech/prepare_transcript.py \ 24 | --directory=/path/to/dataset/test-clean \ 25 | --output=/path/to/dataset/test-clean/transcripts.tsv 26 | ``` 27 | 28 | Do the same thing with `test-other` 29 | 30 | For other datasets, please write your own script to prepare the transcript files; take a look at [`prepare_transcript.py`](../../examples/datasets/librispeech/prepare_transcript.py) for reference 31 | 32 | ## 3. Prepare config file 33 | 34 | The config file follows the `config.yml.j2` format, which is Jinja2 templating with YAML content 35 | 36 | Please take a look at the example config files in `examples/*/*.yml.j2` 37 | 38 | The config file is the same as the config used for training 39 | 40 | The inputs, outputs and other options of vocabulary are defined in the config file 41 | 42 | For example: 43 | 44 | ```jinja2 45 | {% import "examples/datasets/librispeech/sentencepiece/sp.yml.j2" as decoder_config with context %} 46 | {{decoder_config}} 47 | 48 | {% import "examples/models/transducer/conformer/small.yml.j2" as config with context %} 49 | {{config}} 50 | ``` 51 |
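To see what a composed config expands to, it can be rendered by hand. A minimal sketch, assuming `jinja2` and `PyYAML` are installed and using placeholder paths; the CLI performs this rendering internally (see `tensorflow_asr/configs.py`), so details may differ:

```python
import yaml
from jinja2 import Environment, FileSystemLoader

repodir = "/path/to/TensorFlowASR"  # placeholder
env = Environment(loader=FileSystemLoader(repodir))
template = env.get_template("examples/datasets/librispeech/sentencepiece/sp.yml.j2")
rendered = template.render(repodir=repodir, datadir="/path/to/datadir")
config = yaml.safe_load(rendered)
print(config["decoder_config"]["vocab_size"])  # 1000 for this config
```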
52 | ## 4. Run testing 53 | 54 | ```bash 55 | tensorflow_asr test \ 56 | --config-path /path/to/config.yml.j2 \ 57 | --dataset-type slice \ 58 | --datadir /path/to/datadir \ 59 | --outputdir /path/to/modeldir/tests \ 60 | --h5 /path/to/modeldir/weights.h5 61 | ## See other params 62 | tensorflow_asr test --help 63 | ``` -------------------------------------------------------------------------------- /docs/tutorials/tflite.md: -------------------------------------------------------------------------------- 1 | - [TFLite Tutorial](#tflite-tutorial) 2 | - [Conversion](#conversion) 3 | - [Inference](#inference) 4 | - [1. Input](#1-input) 5 | - [2. Output](#2-output) 6 | - [3. Example script](#3-example-script) 7 | 8 | 9 | # TFLite Tutorial 10 | 11 | ## Conversion 12 | 13 | ```bash 14 | tensorflow_asr tflite \ 15 | --config-path=/path/to/config.yml.j2 \ 16 | --h5=/path/to/weight.h5 \ 17 | --bs=1 \ 18 | --beam-width=0 \ 19 | --output=/path/to/output.tflite 20 | ## See other params (bs = batch size; beam-width > 0 enables beam search) 21 | tensorflow_asr tflite --help 22 | ``` 23 | 24 | ## Inference 25 | 26 | ### 1. Input 27 | 28 | The input of each tflite model depends on the model's parameters and configs. 29 | 30 | The `inputs`, `inputs_length` and `previous_tokens` are the same as below for all models. 31 | 32 | ```python 33 | schemas.PredictInput( 34 | inputs=tf.TensorSpec([batch_size, None], dtype=tf.float32), 35 | inputs_length=tf.TensorSpec([batch_size], dtype=tf.int32), 36 | previous_tokens=tf.TensorSpec.from_tensor(self.get_initial_tokens(batch_size)), 37 | previous_encoder_states=tf.TensorSpec.from_tensor(self.get_initial_encoder_states(batch_size)), 38 | previous_decoder_states=tf.TensorSpec.from_tensor(self.get_initial_decoder_states(batch_size)), 39 | ) 40 | ``` 41 | 42 | For models that don't have encoder states or decoder states, the default values are `tf.zeros([], dtype=self.dtype)` tensors for `previous_encoder_states` and `previous_decoder_states`. This is just for tflite conversion because tflite does not allow `None` value in `input_signature`. However, the output `next_encoder_states` and `next_decoder_states` are still `None`, so we can simply ignore those outputs. 43 | 44 | ### 2. Output 45 | 46 | ```python 47 | schemas.PredictOutputWithTranscript( 48 | transcript=self.tokenizer.detokenize(outputs.tokens), 49 | tokens=outputs.tokens, 50 | next_tokens=outputs.next_tokens, 51 | next_encoder_states=outputs.next_encoder_states, 52 | next_decoder_states=outputs.next_decoder_states, 53 | ) 54 | ``` 55 | 56 | This is for supporting streaming inference. 57 | 58 | Each output corresponds to one input chunk of the audio signal. 59 | 60 | Then we can overwrite `previous_tokens`, `previous_encoder_states` and `previous_decoder_states` with `next_tokens`, `next_encoder_states` and `next_decoder_states` for the next chunk of audio signal. 61 | 62 | And continue until the end of the audio signal. 63 | 64 | ### 3. Example script 65 | 66 | See [examples/inferences/tflite.py](../../examples/inferences/tflite.py) for more details.
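For a rough idea of what that streaming loop can look like with the plain TFLite interpreter, here is a hedged sketch; the tensor ordering is assumed to follow `schemas.PredictInput` / `PredictOutputWithTranscript` above, the chunk size is an arbitrary choice, and the maintained script remains [examples/inferences/tflite.py](../../examples/inferences/tflite.py):

```python
import numpy as np
import tensorflow as tf

CHUNK = 16000  # 1 second at 16 kHz, chosen only for illustration

interpreter = tf.lite.Interpreter(model_path="/path/to/output.tflite")
inp = interpreter.get_input_details()
interpreter.resize_tensor_input(inp[0]["index"], [1, CHUNK])
interpreter.allocate_tensors()
inp, out = interpreter.get_input_details(), interpreter.get_output_details()

# Zero-initialized previous tokens/states; shapes come from the model itself
# (valid when the blank index is 0, as in the example configs).
prev = [np.zeros(d["shape"], dtype=d["dtype"]) for d in inp[2:]]

signal = np.zeros([1, 10 * CHUNK], dtype=np.float32)  # placeholder audio
for start in range(0, signal.shape[1], CHUNK):
    interpreter.set_tensor(inp[0]["index"], signal[:, start : start + CHUNK])
    interpreter.set_tensor(inp[1]["index"], np.array([CHUNK], dtype=np.int32))
    for d, v in zip(inp[2:], prev):
        interpreter.set_tensor(d["index"], v)
    interpreter.invoke()
    print(interpreter.get_tensor(out[0]["index"]))  # transcript so far
    prev = [interpreter.get_tensor(d["index"]) for d in out[2:]]  # carry states over
```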
-------------------------------------------------------------------------------- /docs/tutorials/training.md: -------------------------------------------------------------------------------- 1 | - [Training Tutorial](#training-tutorial) 2 | - [1. Install packages](#1-install-packages) 3 | - [2. Prepare transcripts files](#2-prepare-transcripts-files) 4 | - [3. Prepare config file](#3-prepare-config-file) 5 | - [4. \[Optional\]\[Required if using TPUs\] Create tfrecords](#4-optionalrequired-if-using-tpus-create-tfrecords) 6 | - [5. Generate vocabulary and metadata](#5-generate-vocabulary-and-metadata) 7 | - [6. Run training](#6-run-training) 8 | 9 | 10 | # Training Tutorial 11 | 12 | These commands are examples for the LibriSpeech dataset, but similar steps apply to other datasets 13 | 14 | ## 1. Install packages 15 | 16 | ```bash 17 | ./setup.sh [tpu|gpu|cpu] install 18 | ``` 19 | 20 | ## 2. Prepare transcripts files 21 | 22 | This is an example of preparing transcript files for the LibriSpeech corpus 23 | 24 | ```bash 25 | python examples/datasets/librispeech/prepare_transcript.py \ 26 | --directory=/path/to/dataset/train-clean-100 \ 27 | --output=/path/to/dataset/train-clean-100/transcripts.tsv 28 | ``` 29 | 30 | Do the same thing with `train-clean-360`, `train-other-500`, `dev-clean`, `dev-other`, `test-clean`, `test-other` 31 | 32 | For other datasets, please write your own script to prepare the transcript files; take a look at [`prepare_transcript.py`](../../examples/datasets/librispeech/prepare_transcript.py) for reference 33 | 34 | ## 3. Prepare config file 35 | 36 | The config file follows the `config.yml.j2` format, which is Jinja2 templating with YAML content 37 | 38 | Please take a look at the example config files in `examples/*/*.yml.j2` 39 | 40 | For example: 41 | 42 | ```jinja2 43 | {% import "examples/datasets/librispeech/sentencepiece/sp.yml.j2" as decoder_config with context %} 44 | {{decoder_config}} 45 | 46 | {% import "examples/models/transducer/conformer/small.yml.j2" as config with context %} 47 | {{config}} 48 | ``` 49 | 50 | ## 4. \[Optional\]\[Required if using TPUs\] Create tfrecords 51 | 52 | If you want to train with tfrecords: 53 | 54 | ```bash 55 | tensorflow_asr utils create_tfrecords \ 56 | --config-path=/path/to/config.yml.j2 \ 57 | --modes=\["train","eval","test"\] \ 58 | --datadir=/path/to/datadir 59 | ``` 60 | 61 | You can reduce the flag `--modes` to `--modes=\["train","eval"\]` to create only the train and eval datasets 62 | 63 | ## 5. Generate vocabulary and metadata 64 | 65 | This step requires defining the path to the vocabulary file and other vocabulary-generation options in the config file. 66 | 67 | ```bash 68 | tensorflow_asr utils create_datasets_metadata \ 69 | --config-path=/path/to/config.yml.j2 \ 70 | --datadir=/path/to/datadir \ 71 | --dataset-type="slice" 72 | ``` 73 | 74 | The inputs, outputs and other options of vocabulary are defined in the config file 75 |
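The generated metadata is a small JSON file recording the maximum input/label lengths and the number of entries per stage (see the `*.metadata.json` files under `examples/datasets/`). A quick way to inspect it, as a sketch with a placeholder path:

```python
import json

with open("/path/to/train_bpe_1000.metadata.json", encoding="utf-8") as f:
    metadata = json.load(f)

for stage, info in metadata.items():  # "train" and "eval"
    print(stage, info["max_input_length"], info["max_label_length"], info["num_entries"])
```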
76 | ## 6. Run training 77 | 78 | ```bash 79 | tensorflow_asr train \ 80 | --config-path=/path/to/config.yml.j2 \ 81 | --modeldir=/path/to/modeldir \ 82 | --datadir=/path/to/datadir \ 83 | --dataset-type=tfrecord \ 84 | --dataset-cache \ 85 | --mxp=strict \ 86 | --bs=4 \ 87 | --ga-steps=8 \ 88 | --verbose=1 \ 89 | --jit-compile \ 90 | --device-type=tpu \ 91 | --tpu-address=local 92 | ## See other params (dataset-type is one of "tfrecord", "generator" or "slice") 93 | tensorflow_asr train --help 94 | ``` -------------------------------------------------------------------------------- /examples/datasets/librispeech/characters/char.yml.j2: -------------------------------------------------------------------------------- 1 | {% set vocabsize = 29 %} 2 | {% set vocabprefix = repodir ~ "/examples/datasets/librispeech/characters/english" %} 3 | {% set metadata = vocabprefix ~ ".metadata.json" %} 4 | 5 | decoder_config: 6 | type: characters 7 | blank_index: 0 8 | beam_width: 0 9 | norm_score: True 10 | lm_config: null 11 | vocabulary: {{vocabprefix}}.vocab 12 | vocab_size: {{vocabsize}} 13 | 14 | {% import "examples/datasets/librispeech/config.yml.j2" as data_config with context %} 15 | {{data_config}} -------------------------------------------------------------------------------- /examples/datasets/librispeech/characters/english.metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "max_input_length": 475760, 4 | "max_label_length": 524, 5 | "num_entries": 281241 6 | }, 7 | "eval": { 8 | "max_input_length": 562480, 9 | "max_label_length": 516, 10 | "num_entries": 5567 11 | } 12 | } -------------------------------------------------------------------------------- /examples/datasets/librispeech/characters/english.vocab: -------------------------------------------------------------------------------- 1 | # List of alphabets (utf-8 encoded). Note that '#' starts a comment line, which 2 | # will be ignored by the parser.
3 | # begin of vocabulary 4 | 5 | 6 | a 7 | b 8 | c 9 | d 10 | e 11 | f 12 | g 13 | h 14 | i 15 | j 16 | k 17 | l 18 | m 19 | n 20 | o 21 | p 22 | q 23 | r 24 | s 25 | t 26 | u 27 | v 28 | w 29 | x 30 | y 31 | z 32 | ' 33 | # end of vocabulary 34 | -------------------------------------------------------------------------------- /examples/datasets/librispeech/config.yml.j2: -------------------------------------------------------------------------------- 1 | data_config: 2 | train_dataset_config: 3 | enabled: True 4 | sample_rate: 16000 5 | data_paths: 6 | - {{datadir}}/train-clean-100/transcripts.tsv 7 | - {{datadir}}/train-clean-360/transcripts.tsv 8 | - {{datadir}}/train-other-500/transcripts.tsv 9 | tfrecords_dir: {{datadir}}/tfrecords 10 | tfrecords_shards: 32 11 | shuffle: True 12 | cache: False 13 | buffer_size: 1024 14 | drop_remainder: True 15 | stage: train 16 | metadata: {{metadata}} 17 | indefinite: True 18 | 19 | eval_dataset_config: 20 | enabled: True 21 | sample_rate: 16000 22 | data_paths: 23 | - {{datadir}}/dev-clean/transcripts.tsv 24 | - {{datadir}}/dev-other/transcripts.tsv 25 | tfrecords_dir: {{datadir}}/tfrecords 26 | buffer_size: 1024 27 | tfrecords_shards: 2 28 | shuffle: True 29 | cache: False 30 | drop_remainder: True 31 | stage: eval 32 | metadata: {{metadata}} 33 | indefinite: True 34 | 35 | test_dataset_configs: 36 | - name: test-clean 37 | enabled: True 38 | sample_rate: 16000 39 | data_paths: 40 | - {{datadir}}/test-clean/transcripts.tsv 41 | tfrecords_dir: {{datadir}}/tfrecords 42 | shuffle: False 43 | cache: False 44 | buffer_size: null 45 | drop_remainder: False 46 | stage: test 47 | indefinite: False 48 | - name: test-other 49 | enabled: True 50 | sample_rate: 16000 51 | data_paths: 52 | - {{datadir}}/test-other/transcripts.tsv 53 | tfrecords_dir: {{datadir}}/tfrecords 54 | shuffle: False 55 | cache: False 56 | buffer_size: null 57 | drop_remainder: False 58 | stage: test 59 | indefinite: False -------------------------------------------------------------------------------- /examples/datasets/librispeech/prepare_transcript.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
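# Example usage (see docs/tutorials/training.md; paths are placeholders):
#   python examples/datasets/librispeech/prepare_transcript.py \
#     --directory=/path/to/dataset/train-clean-100 \
#     --output=/path/to/dataset/train-clean-100/transcripts.tsv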
14 | 15 | import glob 16 | import os 17 | import unicodedata 18 | 19 | import librosa 20 | 21 | from tensorflow_asr.utils import cli_util, file_util 22 | 23 | 24 | def main( 25 | directory: str, 26 | output: str, 27 | ): 28 | directory = file_util.preprocess_paths(directory, isdir=True) 29 | output = file_util.preprocess_paths(output) 30 | 31 | transcripts = [] 32 | 33 | text_files = glob.glob(os.path.join(directory, "**", "*.txt"), recursive=True) 34 | 35 | from tqdm.auto import tqdm 36 | 37 | for text_file in tqdm(text_files, desc="[Loading]", disable=False): 38 | current_dir = os.path.dirname(text_file) 39 | with open(text_file, "r", encoding="utf-8") as txt: 40 | lines = txt.read().splitlines() 41 | for line in lines: 42 | line = line.split(" ", maxsplit=1) 43 | audio_file = os.path.join(current_dir, line[0] + ".flac") 44 | y, sr = librosa.load(audio_file, sr=None) 45 | duration = librosa.get_duration(y=y, sr=sr) 46 | text = unicodedata.normalize("NFKC", line[1]) 47 | transcripts.append(f"{audio_file}\t{duration}\t{text.lower()}\n") 48 | 49 | with open(output, "w", encoding="utf-8") as out: 50 | out.write("PATH\tDURATION\tTRANSCRIPT\n") 51 | for line in tqdm(transcripts, desc="[Writing]"): 52 | out.write(line) 53 | 54 | 55 | if __name__ == "__main__": 56 | cli_util.run(main) 57 | -------------------------------------------------------------------------------- /examples/datasets/librispeech/sentencepiece/sp.256.yml.j2: -------------------------------------------------------------------------------- 1 | {% set vocabsize = 256 %} 2 | {% set vocabprefix = repodir ~ "/examples/datasets/librispeech/sentencepiece/train_bpe_" ~ vocabsize %} 3 | {% set metadata = vocabprefix ~ ".metadata.json" %} 4 | 5 | decoder_config: 6 | type: sentencepiece 7 | blank_index: 0 8 | unknown_token: "<unk>" 9 | unknown_index: 0 10 | pad_token: "<pad>" 11 | pad_index: -1 12 | bos_token: "<s>" 13 | bos_index: -1 14 | eos_token: "</s>" 15 | eos_index: -1 16 | beam_width: 0 17 | norm_score: True 18 | lm_config: null 19 | model_type: bpe 20 | vocabulary: {{vocabprefix}}.model 21 | vocab_size: {{vocabsize}} 22 | reserved_tokens: null 23 | normalization_form: NFKC 24 | max_sentencepiece_length: 16 25 | max_sentence_length: 1048576 26 | character_coverage: 1.0 27 | keep_whitespace: False 28 | 29 | {% import "examples/datasets/librispeech/config.yml.j2" as data_config with context %} 30 | {{data_config}} -------------------------------------------------------------------------------- /examples/datasets/librispeech/sentencepiece/sp.yml.j2: -------------------------------------------------------------------------------- 1 | {% set vocabsize = 1000 %} 2 | {% set vocabprefix = repodir ~ "/examples/datasets/librispeech/sentencepiece/train_bpe_" ~ vocabsize %} 3 | {% set metadata = vocabprefix ~ ".metadata.json" %} 4 | 5 | decoder_config: 6 | type: sentencepiece 7 | blank_index: 0 8 | unknown_token: "<unk>" 9 | unknown_index: 0 10 | pad_token: "<pad>" 11 | pad_index: -1 12 | bos_token: "<s>" 13 | bos_index: -1 14 | eos_token: "</s>" 15 | eos_index: -1 16 | beam_width: 0 17 | norm_score: True 18 | lm_config: null 19 | model_type: bpe 20 | vocabulary: {{vocabprefix}}.model 21 | vocab_size: {{vocabsize}} 22 | reserved_tokens: null 23 | normalization_form: NFKC 24 | max_sentencepiece_length: 16 25 | max_sentence_length: 1048576 26 | character_coverage: 1.0 27 | keep_whitespace: False 28 | 29 | {% import "examples/datasets/librispeech/config.yml.j2" as data_config with context %} 30 | {{data_config}}
-------------------------------------------------------------------------------- /examples/datasets/librispeech/sentencepiece/train_8000&960.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/examples/datasets/librispeech/sentencepiece/train_8000&960.model -------------------------------------------------------------------------------- /examples/datasets/librispeech/sentencepiece/train_bpe_1000.metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "max_input_length": 475760, 4 | "max_label_length": 230, 5 | "num_entries": 281241 6 | }, 7 | "eval": { 8 | "max_input_length": 562480, 9 | "max_label_length": 225, 10 | "num_entries": 5567 11 | } 12 | } -------------------------------------------------------------------------------- /examples/datasets/librispeech/sentencepiece/train_bpe_1000.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/examples/datasets/librispeech/sentencepiece/train_bpe_1000.model -------------------------------------------------------------------------------- /examples/datasets/librispeech/sentencepiece/train_bpe_256.metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "max_input_length": 475760, 4 | "max_label_length": 270, 5 | "num_entries": 281241 6 | }, 7 | "eval": { 8 | "max_input_length": 562480, 9 | "max_label_length": 260, 10 | "num_entries": 5567 11 | } 12 | } -------------------------------------------------------------------------------- /examples/datasets/librispeech/sentencepiece/train_bpe_256.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/examples/datasets/librispeech/sentencepiece/train_bpe_256.model -------------------------------------------------------------------------------- /examples/datasets/librispeech/sentencepiece/train_bpe_256.vocab: -------------------------------------------------------------------------------- 1 | 0 2 | ▁t -0 3 | he -1 4 | ▁a -2 5 | ▁the -3 6 | in -4 7 | ▁s -5 8 | ▁w -6 9 | ▁o -7 10 | re -8 11 | nd -9 12 | ▁b -10 13 | ▁h -11 14 | er -12 15 | ▁m -13 16 | ▁i -14 17 | ou -15 18 | ▁c -16 19 | ▁f -17 20 | at -18 21 | ed -19 22 | ▁and -20 23 | en -21 24 | ▁to -22 25 | ▁of -23 26 | on -24 27 | is -25 28 | ▁d -26 29 | ing -27 30 | ▁th -28 31 | ▁p -29 32 | ▁he -30 33 | or -31 34 | ▁l -32 35 | es -33 36 | ▁in -34 37 | ll -35 38 | it -36 39 | ar -37 40 | as -38 41 | an -39 42 | ▁n -40 43 | ▁g -41 44 | om -42 45 | ▁be -43 46 | ▁ha -44 47 | ▁e -45 48 | le -46 49 | ot -47 50 | ▁y -48 51 | ut -49 52 | ow -50 53 | ic -51 54 | ▁wh -52 55 | ▁it -53 56 | ld -54 57 | ve -55 58 | ▁that -56 59 | ly -57 60 | ▁was -58 61 | id -59 62 | se -60 63 | st -61 64 | ▁on -62 65 | gh -63 66 | ent -64 67 | ▁re -65 68 | ▁you -66 69 | im -67 70 | ce -68 71 | ▁u -69 72 | ver -70 73 | ion -71 74 | ▁as -72 75 | et -73 76 | ▁for -74 77 | ay -75 78 | ▁we -76 79 | ▁his -77 80 | ith -78 81 | al -79 82 | ir -80 83 | ▁r -81 84 | ▁with -82 85 | ▁st -83 86 | ad -84 87 | ur -85 88 | ght -86 89 | ▁an -87 90 | ▁her -88 91 | ▁not -89 92 | ▁had -90 93 | ▁is -91 94 | ter -92 95 | her -93 96 | ac -94 97 | am -95 98 | ▁at -96 99 | oo 
-97 100 | ▁but -98 101 | ould -99 102 | ▁she -100 103 | ▁k -101 104 | ▁se -102 105 | ▁sa -103 106 | ▁sh -104 107 | ▁fr -105 108 | ▁him -106 109 | ▁so -107 110 | ill -108 111 | ▁me -109 112 | ain -110 113 | ▁su -111 114 | ight -112 115 | ch -113 116 | red -114 117 | ct -115 118 | all -116 119 | ro -117 120 | ke -118 121 | ess -119 122 | il -120 123 | ore -121 124 | ▁de -122 125 | ▁they -123 126 | ▁my -124 127 | ▁whe -125 128 | ▁all -126 129 | ich -127 130 | ▁ne -128 131 | ri -129 132 | ▁by -130 133 | ▁have -131 134 | ome -132 135 | pp -133 136 | ▁this -134 137 | ▁li -135 138 | ▁do -136 139 | ▁con -137 140 | us -138 141 | ▁which -139 142 | ▁ch -140 143 | ul -141 144 | qu -142 145 | ▁j -143 146 | ▁up -144 147 | ▁said -145 148 | ▁from -146 149 | ard -147 150 | ge -148 151 | ▁or -149 152 | ▁v -150 153 | ▁one -151 154 | th -152 155 | ▁no -153 156 | ▁ex -154 157 | ▁were -155 158 | ▁there -156 159 | pe -157 160 | and -158 161 | est -159 162 | ▁man -160 163 | ▁who -161 164 | ble -162 165 | ant -163 166 | ie -164 167 | ▁al -165 168 | res -166 169 | ous -167 170 | ust -168 171 | very -169 172 | ation -170 173 | ▁fe -171 174 | ▁them -172 175 | lf -173 176 | ▁when -174 177 | ind -175 178 | nt -176 179 | ame -177 180 | ra -178 181 | ▁go -179 182 | ers -180 183 | ast -181 184 | fe -182 185 | ood -183 186 | ▁kn -184 187 | ▁int -185 188 | ist -186 189 | art -187 190 | ▁are -188 191 | out -189 192 | ▁would -190 193 | ▁le -191 194 | os -192 195 | ▁their -193 196 | ong -194 197 | ▁what -195 198 | our -196 199 | ▁if -197 200 | ound -198 201 | ▁com -199 202 | ▁ab -200 203 | ▁out -201 204 | ▁wor -202 205 | em -203 206 | ▁will -204 207 | ak -205 208 | ▁mis -206 209 | ate -207 210 | ol -208 211 | um -209 212 | un -210 213 | itt -211 214 | ough -212 215 | ked -213 216 | ap -214 217 | ig -215 218 | one -216 219 | ▁been -217 220 | own -218 221 | ive -219 222 | ▁then -220 223 | ▁br -221 224 | ven -222 225 | if -223 226 | ▁ar -224 227 | ▁tr -225 228 | self -226 229 | ▁ -227 230 | e -228 231 | t -229 232 | a -230 233 | o -231 234 | n -232 235 | i -233 236 | h -234 237 | s -235 238 | r -236 239 | d -237 240 | l -238 241 | u -239 242 | m -240 243 | c -241 244 | w -242 245 | f -243 246 | g -244 247 | y -245 248 | p -246 249 | b -247 250 | v -248 251 | k -249 252 | ' -250 253 | x -251 254 | j -252 255 | q -253 256 | z -254 257 | -------------------------------------------------------------------------------- /examples/datasets/librispeech/subwords/train_1030_4.metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "max_input_length": 2974, 4 | "max_label_length": 207, 5 | "num_entries": 281241 6 | }, 7 | "eval": { 8 | "max_input_length": 3516, 9 | "max_label_length": 194, 10 | "num_entries": 5567 11 | } 12 | } -------------------------------------------------------------------------------- /examples/datasets/librispeech/wordpiece/train_1000.metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "max_input_length": 475760, 4 | "max_label_length": 202, 5 | "num_entries": 281241 6 | }, 7 | "eval": { 8 | "max_input_length": 562480, 9 | "max_label_length": 190, 10 | "num_entries": 5567 11 | } 12 | } -------------------------------------------------------------------------------- /examples/datasets/librispeech/wordpiece/train_1000_whitespace.metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "max_input_length": 475760, 4 | 
"max_label_length": 281, 5 | "num_entries": 281241 6 | }, 7 | "eval": { 8 | "max_input_length": 562480, 9 | "max_label_length": 279, 10 | "num_entries": 5567 11 | } 12 | } -------------------------------------------------------------------------------- /examples/datasets/librispeech/wordpiece/wp.yml.j2: -------------------------------------------------------------------------------- 1 | {% set vocabsize = 1000 %} 2 | {% set vocabprefix = repodir ~ "/examples/datasets/librispeech/wordpiece/train_" ~ vocabsize %} 3 | {% set metadata = vocabprefix ~ ".metadata.json" %} 4 | 5 | decoder_config: 6 | type: wordpiece 7 | blank_index: 0 8 | unknown_token: "" 9 | unknown_index: 0 10 | beam_width: 0 11 | norm_score: True 12 | lm_config: null 13 | vocabulary: {{vocabprefix}}.vocab 14 | keep_whitespace: False 15 | vocab_size: {{vocabsize}} 16 | max_token_length: 50 17 | max_unique_chars: 1000 18 | reserved_tokens: 19 | - "" 20 | normalization_form: NFKC 21 | num_iterations: 4 22 | 23 | {% import "examples/datasets/librispeech/config.yml.j2" as data_config with context %} 24 | {{data_config}} -------------------------------------------------------------------------------- /examples/datasets/librispeech/wordpiece/wp_whitespace.yml.j2: -------------------------------------------------------------------------------- 1 | {% set vocabsize = 1000 %} 2 | {% set vocabprefix = repodir ~ "/examples/datasets/librispeech/wordpiece/train_" ~ vocabsize ~ "_whitespace" %} 3 | {% set metadata = vocabprefix ~ ".metadata.json" %} 4 | 5 | decoder_config: 6 | type: wordpiece 7 | blank_index: 0 8 | unknown_token: "" 9 | unknown_index: 0 10 | beam_width: 0 11 | norm_score: True 12 | lm_config: null 13 | vocabulary: {{vocabprefix}}.vocab 14 | keep_whitespace: True 15 | vocab_size: {{vocabsize}} 16 | max_token_length: 50 17 | max_unique_chars: 1000 18 | reserved_tokens: 19 | - "" 20 | normalization_form: NFKC 21 | num_iterations: 4 22 | 23 | {% import "examples/datasets/librispeech/config.yml.j2" as data_config with context %} 24 | {{data_config}} -------------------------------------------------------------------------------- /examples/datasets/vietbud500/config.yml.j2: -------------------------------------------------------------------------------- 1 | data_config: 2 | train_dataset_config: 3 | enabled: True 4 | sample_rate: 16000 5 | data_paths: 6 | - {{datadir}}/train/transcripts.tsv 7 | tfrecords_dir: {{datadir}}/tfrecords 8 | tfrecords_shards: 32 9 | shuffle: True 10 | cache: False 11 | buffer_size: 1024 12 | drop_remainder: True 13 | stage: train 14 | metadata: {{metadata}} 15 | indefinite: True 16 | 17 | eval_dataset_config: 18 | enabled: True 19 | sample_rate: 16000 20 | data_paths: 21 | - {{datadir}}/validation/transcripts.tsv 22 | tfrecords_dir: {{datadir}}/tfrecords 23 | buffer_size: 1024 24 | tfrecords_shards: 2 25 | shuffle: True 26 | cache: False 27 | drop_remainder: True 28 | stage: eval 29 | metadata: {{metadata}} 30 | indefinite: True 31 | 32 | test_dataset_configs: 33 | - name: test 34 | enabled: True 35 | sample_rate: 16000 36 | data_paths: 37 | - {{datadir}}/test/transcripts.tsv 38 | tfrecords_dir: {{datadir}}/tfrecords 39 | shuffle: False 40 | cache: False 41 | buffer_size: null 42 | drop_remainder: False 43 | stage: test 44 | indefinite: False -------------------------------------------------------------------------------- /examples/datasets/vietbud500/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import datasets 4 | import librosa 5 | 
import soundfile 6 | from tqdm import tqdm 7 | 8 | from tensorflow_asr.utils import cli_util, data_util 9 | 10 | MAPPING = { 11 | "audio.array": "audio", 12 | "audio.sampling_rate": "sample_rate", 13 | "transcription": "transcript", 14 | } 15 | 16 | 17 | def load_item_from_mapping(item): 18 | data = {} 19 | for path, key in MAPPING.items(): 20 | data[key] = data_util.get(item, path) 21 | if not all(x in data for x in ["audio", "transcript"]): 22 | return None 23 | return data["audio"], int(data["sample_rate"]), str(data["transcript"]) 24 | 25 | 26 | def main( 27 | directory: str, 28 | token: str, 29 | ): 30 | dataset_list = datasets.load_dataset("linhtran92/viet_bud500", token=token, streaming=True, keep_in_memory=False) 31 | for stage in dataset_list.keys(): 32 | print(f"[Loading {stage}]") 33 | output = os.path.realpath(os.path.join(directory, stage, "audio")) 34 | tsv_output = os.path.realpath(os.path.join(directory, stage, "transcripts.tsv")) 35 | os.makedirs(output, exist_ok=True) 36 | with open(tsv_output, "w", encoding="utf-8") as out: 37 | out.write("PATH\tDURATION\tTRANSCRIPT\n") 38 | index = 1 39 | for item in tqdm(dataset_list[stage], desc=f"[Loading to {output}]", disable=False): 40 | data = load_item_from_mapping(item) 41 | if data is None: 42 | continue 43 | audio, sample_rate, transcript = data 44 | path = os.path.join(output, f"{index}.wav") 45 | soundfile.write(path, audio, sample_rate) 46 | duration = librosa.get_duration(y=audio, sr=sample_rate) 47 | out.write(f"{path}\t{duration}\t{transcript}\n") 48 | index += 1 49 | 50 | 51 | if __name__ == "__main__": 52 | cli_util.run(main) 53 | -------------------------------------------------------------------------------- /examples/datasets/vietbud500/sentencepiece/sp.256.yml.j2: -------------------------------------------------------------------------------- 1 | {% set vocabsize = 256 %} 2 | {% set vocabprefix = repodir ~ "/examples/datasets/vietbud500/sentencepiece/train_bpe_" ~ vocabsize %} 3 | {% set metadata = vocabprefix ~ ".metadata.json" %} 4 | 5 | decoder_config: 6 | type: sentencepiece 7 | blank_index: 0 8 | unknown_token: "<unk>" 9 | unknown_index: 0 10 | pad_token: "<pad>" 11 | pad_index: -1 12 | bos_token: "<s>" 13 | bos_index: -1 14 | eos_token: "</s>" 15 | eos_index: -1 16 | beam_width: 0 17 | norm_score: True 18 | lm_config: null 19 | model_type: bpe 20 | vocabulary: {{vocabprefix}}.model 21 | vocab_size: {{vocabsize}} 22 | reserved_tokens: null 23 | normalization_form: NFKC 24 | max_sentencepiece_length: 16 25 | max_sentence_length: 1048576 26 | character_coverage: 1.0 27 | keep_whitespace: False 28 | 29 | {% import "examples/datasets/vietbud500/config.yml.j2" as data_config with context %} 30 | {{data_config}} -------------------------------------------------------------------------------- /examples/datasets/vietbud500/sentencepiece/sp.yml.j2: -------------------------------------------------------------------------------- 1 | {% set vocabsize = 1000 %} 2 | {% set vocabprefix = repodir ~ "/examples/datasets/vietbud500/sentencepiece/train_bpe_" ~ vocabsize %} 3 | {% set metadata = vocabprefix ~ ".metadata.json" %} 4 | 5 | decoder_config: 6 | type: sentencepiece 7 | blank_index: 0 8 | unknown_token: "<unk>" 9 | unknown_index: 0 10 | pad_token: "<pad>" 11 | pad_index: -1 12 | bos_token: "<s>" 13 | bos_index: -1 14 | eos_token: "</s>" 15 | eos_index: -1 16 | beam_width: 0 17 | norm_score: True 18 | lm_config: null 19 | model_type: bpe 20 | vocabulary: {{vocabprefix}}.model 21 | vocab_size: {{vocabsize}} 22 | reserved_tokens: null 23 |
normalization_form: NFKC 24 | max_sentencepiece_length: 16 25 | max_sentence_length: 1048576 26 | character_coverage: 1.0 27 | keep_whitespace: False 28 | 29 | {% import "examples/datasets/vietbud500/config.yml.j2" as data_config with context %} 30 | {{data_config}} -------------------------------------------------------------------------------- /examples/datasets/vietbud500/sentencepiece/train_bpe_1000.metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "max_input_length": 256498, 4 | "max_label_length": 75, 5 | "num_entries": 634158 6 | }, 7 | "eval": { 8 | "max_input_length": 117571, 9 | "max_label_length": 42, 10 | "num_entries": 7500 11 | } 12 | } -------------------------------------------------------------------------------- /examples/datasets/vietbud500/sentencepiece/train_bpe_1000.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/examples/datasets/vietbud500/sentencepiece/train_bpe_1000.model -------------------------------------------------------------------------------- /examples/datasets/vietbud500/sentencepiece/train_bpe_256.metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "max_input_length": 256498, 4 | "max_label_length": 100, 5 | "num_entries": 634158 6 | }, 7 | "eval": { 8 | "max_input_length": 117571, 9 | "max_label_length": 57, 10 | "num_entries": 7500 11 | } 12 | } -------------------------------------------------------------------------------- /examples/datasets/vietbud500/sentencepiece/train_bpe_256.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/examples/datasets/vietbud500/sentencepiece/train_bpe_256.model -------------------------------------------------------------------------------- /examples/datasets/vietbud500/sentencepiece/train_bpe_256.vocab: -------------------------------------------------------------------------------- 1 | 0 2 | ▁c -0 3 | ng -1 4 | ▁t -2 5 | nh -3 6 | ▁đ -4 7 | ▁m -5 8 | ▁l -6 9 | ▁th -7 10 | ▁v -8 11 | ▁ch -9 12 | ▁b -10 13 | ▁nh -11 14 | ▁k -12 15 | ▁n -13 16 | ▁h -14 17 | ▁kh -15 18 | ▁ng -16 19 | ▁s -17 20 | ▁g -18 21 | ▁là -19 22 | ông -20 23 | ▁tr -21 24 | ▁r -22 25 | ▁không -23 26 | ời -24 27 | ▁p -25 28 | ▁ph -26 29 | ▁cá -27 30 | ▁có -28 31 | ên -29 32 | ▁d -30 33 | ôi -31 34 | ình -32 35 | ▁gi -33 36 | anh -34 37 | qu -35 38 | ▁qu -36 39 | ▁và -37 40 | ột -38 41 | ới -39 42 | ▁củ -40 43 | ▁của -41 44 | iế -42 45 | ười -43 46 | ▁như -44 47 | ▁một -45 48 | ▁tôi -46 49 | ▁nó -47 50 | ▁mà -48 51 | ▁người -49 52 | iệ -50 53 | ▁x -51 54 | ▁anh -52 55 | ▁đư -53 56 | ại -54 57 | ất -55 58 | ấy -56 59 | ▁nà -57 60 | ▁mình -58 61 | ▁đi -59 62 | ▁thì -60 63 | ▁cái -61 64 | ợc -62 65 | em -63 66 | ▁được -64 67 | ay -65 68 | ▁cũ -66 69 | uy -67 70 | ▁co -68 71 | ▁cũng -69 72 | ững -70 73 | ong -71 74 | ▁những -72 75 | ▁cho -73 76 | ▁con -74 77 | ai -75 78 | ải -76 79 | ▁em -77 80 | ▁ngh -78 81 | ▁cả -79 82 | ều -80 83 | ▁đó -81 84 | ▁cô -82 85 | ồi -83 86 | ▁lại -84 87 | ▁với -85 88 | ch -86 89 | ao -87 90 | ân -88 91 | ▁này -89 92 | ▁đã -90 93 | ▁trong -91 94 | ần -92 95 | uố -93 96 | ▁để -94 97 | ▁làm -95 98 | ▁nói -96 99 | ▁ta -97 100 | ạn -98 101 | ▁phải -99 102 | ▁ra -100 103 | ây -101 104 | ▁chú -102 105 | ▁nhưng -103 106 
| ướ -104 107 | ang -105 108 | au -106 109 | ▁rồi -107 110 | ▁sẽ -108 111 | âu -109 112 | ến -110 113 | ▁về -111 114 | ▁nhi -112 115 | iết -113 116 | an -114 117 | ác -115 118 | ▁khi -116 119 | òn -117 120 | ▁ti -118 121 | ▁gì -119 122 | ▁thế -120 123 | ▁bạn -121 124 | ước -122 125 | ▁ở -123 126 | ▁họ -124 127 | ▁đến -125 128 | ▁còn -126 129 | ▁thể -127 130 | ▁các -128 131 | ết -129 132 | ▁mẹ -130 133 | ▁việ -131 134 | ươ -132 135 | ật -133 136 | ▁ông -134 137 | ▁biết -135 138 | úc -136 139 | ▁nhà -137 140 | ▁chúng -138 141 | ương -139 142 | ận -140 143 | oà -141 144 | ▁chị -142 145 | ành -143 146 | ▁bà -144 147 | ơn -145 148 | ▁thấy -146 149 | ▁từ -147 150 | ầu -148 151 | ậy -149 152 | ▁chuy -150 153 | ▁nào -151 154 | ăn -152 155 | ▁chỉ -153 156 | ờng -154 157 | ữa -155 158 | ▁rất -156 159 | ▁sự -157 160 | ồng -158 161 | ▁nhiều -159 162 | ùng -160 163 | ▁ -161 164 | n -162 165 | h -163 166 | c -164 167 | i -165 168 | t -166 169 | g -167 170 | m -168 171 | a -169 172 | đ -170 173 | à -171 174 | u -172 175 | l -173 176 | o -174 177 | ư -175 178 | y -176 179 | ô -177 180 | v -178 181 | r -179 182 | b -180 183 | k -181 184 | á -182 185 | ó -183 186 | ì -184 187 | s -185 188 | ế -186 189 | p -187 190 | ờ -188 191 | ấ -189 192 | ạ -190 193 | ả -191 194 | ê -192 195 | ộ -193 196 | ớ -194 197 | â -195 198 | ố -196 199 | ệ -197 200 | ề -198 201 | ủ -199 202 | d -200 203 | ậ -201 204 | ể -202 205 | e -203 206 | ợ -204 207 | ú -205 208 | q -206 209 | ữ -207 210 | ơ -208 211 | ồ -209 212 | ọ -210 213 | ầ -211 214 | ị -212 215 | ứ -213 216 | x -214 217 | ắ -215 218 | ã -216 219 | ở -217 220 | ũ -218 221 | ự -219 222 | í -220 223 | ò -221 224 | ă -222 225 | ừ -223 226 | ặ -224 227 | ẽ -225 228 | ẹ -226 229 | ù -227 230 | ỏ -228 231 | ụ -229 232 | ổ -230 233 | ỉ -231 234 | ĩ -232 235 | ằ -233 236 | ẫ -234 237 | ý -235 238 | é -236 239 | ử -237 240 | ỗ -238 241 | ẻ -239 242 | ẳ -240 243 | ẩ -241 244 | ễ -242 245 | è -243 246 | ỡ -244 247 | õ -245 248 | ỳ -246 249 | ỹ -247 250 | ỷ -248 251 | ẵ -249 252 | ỵ -250 253 | w -251 254 | f -252 255 | j -253 256 | z -254 257 | -------------------------------------------------------------------------------- /examples/datasets/vivos/vietnamese.characters: -------------------------------------------------------------------------------- 1 | # List of alphabets (utf-8 encoded). Note that '#' starts a comment line, which 2 | # will be ignored by the parser. 
3 | # begin of vocabulary 4 | 5 | 6 | a 7 | b 8 | c 9 | d 10 | e 11 | f 12 | g 13 | h 14 | i 15 | j 16 | k 17 | l 18 | m 19 | n 20 | o 21 | p 22 | q 23 | r 24 | s 25 | t 26 | u 27 | v 28 | w 29 | x 30 | y 31 | z 32 | á 33 | à 34 | ạ 35 | ã 36 | ả 37 | ă 38 | ắ 39 | ằ 40 | ặ 41 | ẵ 42 | ẳ 43 | â 44 | ấ 45 | ầ 46 | ậ 47 | ẫ 48 | ẩ 49 | đ 50 | é 51 | è 52 | ẹ 53 | ẽ 54 | ẻ 55 | ê 56 | ế 57 | ề 58 | ệ 59 | ễ 60 | ể 61 | í 62 | ì 63 | ị 64 | ĩ 65 | ỉ 66 | ó 67 | ò 68 | ọ 69 | õ 70 | ỏ 71 | ơ 72 | ớ 73 | ờ 74 | ợ 75 | ỡ 76 | ở 77 | ô 78 | ố 79 | ồ 80 | ộ 81 | ỗ 82 | ổ 83 | ú 84 | ù 85 | ụ 86 | ũ 87 | ủ 88 | ư 89 | ứ 90 | ừ 91 | ự 92 | ữ 93 | ử 94 | ý 95 | ỳ 96 | ỵ 97 | ỹ 98 | ỷ 99 | ' 100 | # end of vocabulary 101 | -------------------------------------------------------------------------------- /examples/inferences/README.md: -------------------------------------------------------------------------------- 1 | # TFLite Demonstrations 2 | 3 | ## Streaming 4 | 5 | We must install this dependencies for running streaming mode 6 | 7 | - [python-sounddevice](https://python-sounddevice.readthedocs.io/en/0.4.1/installation.html) 8 | 9 | To run, please read the code in this example. 10 | 11 | ## Non Streaming 12 | 13 | Please read the code in this example. 14 | 15 | ## Wave files 16 | 17 | Wave files are chosen from LibriSpeech test-clean and test-other. -------------------------------------------------------------------------------- /examples/inferences/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
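# Example usage (flags as in .vscode/launch.json "Test Prediction"; paths are placeholders):
#   python examples/inferences/main.py \
#     --file-path=/path/to/audio.flac \
#     --config-path=/path/to/config.yml.j2 \
#     --h5=/path/to/weights.h5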
14 | 15 | import os 16 | 17 | from tensorflow_asr import keras, schemas, tf, tokenizers 18 | from tensorflow_asr.configs import Config 19 | from tensorflow_asr.models import base_model 20 | from tensorflow_asr.utils import cli_util, data_util, env_util, file_util 21 | 22 | logger = tf.get_logger() 23 | 24 | 25 | def main( 26 | file_path: str, 27 | config_path: str, 28 | h5: str, 29 | repodir: str = os.getcwd(), 30 | ): 31 | env_util.setup_seed() 32 | file_path = file_util.preprocess_paths(file_path) 33 | 34 | config = Config(config_path, training=False, repodir=repodir) 35 | tokenizer = tokenizers.get(config) 36 | 37 | model: base_model.BaseModel = keras.Model.from_config(config.model_config) 38 | model.make(batch_size=1) 39 | model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5), skip_mismatch=False) 40 | model.summary() 41 | 42 | signal = data_util.read_raw_audio(data_util.load_and_convert_to_wav(file_path)) 43 | signal = tf.reshape(signal, [1, -1]) 44 | signal_length = tf.reshape(tf.shape(signal)[1], [1]) 45 | 46 | outputs = model.recognize( 47 | schemas.PredictInput( 48 | inputs=signal, 49 | inputs_length=signal_length, 50 | previous_tokens=model.get_initial_tokens(), 51 | previous_encoder_states=model.get_initial_encoder_states(), 52 | previous_decoder_states=model.get_initial_decoder_states(), 53 | ) 54 | ) 55 | transcript = tokenizer.detokenize(outputs.tokens)[0].numpy().decode("utf-8") 56 | logger.info(f"Transcript: {transcript}") 57 | 58 | 59 | if __name__ == "__main__": 60 | cli_util.run(main) 61 | -------------------------------------------------------------------------------- /examples/inferences/rnn_transducer.py: -------------------------------------------------------------------------------- 1 | # # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # # 3 | # # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # # you may not use this file except in compliance with the License. 5 | # # You may obtain a copy of the License at 6 | # # 7 | # # http://www.apache.org/licenses/LICENSE-2.0 8 | # # 9 | # # Unless required by applicable law or agreed to in writing, software 10 | # # distributed under the License is distributed on an "AS IS" BASIS, 11 | # # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # # See the License for the specific language governing permissions and 13 | # # limitations under the License. 
14 | 15 | # import argparse 16 | 17 | # from tensorflow_asr.utils import data_util, env_util, math_util 18 | 19 | # logger = env_util.setup_environment() 20 | # import tensorflow as tf 21 | 22 | # parser = argparse.ArgumentParser(prog="Rnn Transducer non streaming") 23 | 24 | # parser.add_argument("filename", metavar="FILENAME", help="audio file to be played back") 25 | 26 | # parser.add_argument("--config", type=str, default=None, help="Path to rnnt config yaml") 27 | 28 | # parser.add_argument("--saved", type=str, default=None, help="Path to rnnt saved h5 weights") 29 | 30 | # parser.add_argument("--beam_width", type=int, default=0, help="Beam width") 31 | 32 | # parser.add_argument("--timestamp", default=False, action="store_true", help="Return with timestamp") 33 | 34 | # parser.add_argument("--device", type=int, default=0, help="Device's id to run test on") 35 | 36 | # parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu") 37 | 38 | # parser.add_argument("--subwords", default=False, action="store_true", help="Path to file that stores generated subwords") 39 | 40 | # parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") 41 | 42 | # args = parser.parse_args() 43 | 44 | # env_util.setup_devices([args.device], cpu=args.cpu) 45 | 46 | # from tensorflow_asr.configs import Config 47 | # from tensorflow_asr.features.speech_featurizers import SpeechFeaturizer, read_raw_audio 48 | # from tensorflow_asr.models.transducer.rnnt import RnnTransducer 49 | # from tensorflow_asr.tokenizers import CharTokenizer, SentencePieceTokenizer, SubwordFeaturizer 50 | 51 | # config = Config(args.config) 52 | # speech_featurizer = SpeechFeaturizer(config.speech_config) 53 | # if args.sentence_piece: 54 | # logger.info("Loading SentencePiece model ...") 55 | # text_featurizer = SentencePieceTokenizer(config.decoder_config) 56 | # elif args.subwords: 57 | # logger.info("Loading subwords ...") 58 | # text_featurizer = SubwordFeaturizer(config.decoder_config) 59 | # else: 60 | # text_featurizer = CharTokenizer(config.decoder_config) 61 | # text_featurizer.decoder_config.beam_width = args.beam_width 62 | 63 | # # build model 64 | # rnnt = RnnTransducer(**config.model_config, vocab_size=text_featurizer.num_classes) 65 | # rnnt.make(speech_featurizer.shape) 66 | # rnnt.load_weights(args.saved, by_name=True, skip_mismatch=True) 67 | # rnnt.summary() 68 | # rnnt.add_featurizers(speech_featurizer, text_featurizer) 69 | 70 | # signal = read_raw_audio(args.filename) 71 | # features = speech_featurizer.tf_extract(signal) 72 | # input_length = math_util.get_reduced_length(tf.shape(features)[0], rnnt.time_reduction_factor) 73 | 74 | # if args.beam_width: 75 | # transcript = rnnt.recognize_beam(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...])) 76 | # logger.info("Transcript:", transcript[0].numpy().decode("UTF-8")) 77 | # elif args.timestamp: 78 | # transcript, stime, etime, _, _, _ = rnnt.recognize_tflite_with_timestamp( 79 | # signal=signal, 80 | # predicted=tf.constant(text_featurizer.blank, dtype=tf.int32), 81 | # encoder_states=rnnt.encoder.get_initial_state(), 82 | # prediction_states=rnnt.predict_net.get_initial_state(), 83 | # ) 84 | # logger.info("Transcript:", transcript) 85 | # logger.info("Start time:", stime) 86 | # logger.info("End time:", etime) 87 | # else: 88 | # transcript = rnnt.recognize(data_util.create_inputs(inputs=features[None, ...], 
inputs_length=input_length[None, ...])) 89 | # logger.info("Transcript:", transcript[0].numpy().decode("UTF-8")) 90 | -------------------------------------------------------------------------------- /examples/inferences/tflite.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | 17 | import tensorflow_text as tft 18 | from tensorflow.lite.python import interpreter 19 | 20 | from tensorflow_asr import tf 21 | from tensorflow_asr.utils import cli_util, data_util 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def main( 27 | audio_file_path: str, 28 | tflite: str, 29 | sample_rate: int = 16000, 30 | blank: int = 0, 31 | ): 32 | wav = data_util.load_and_convert_to_wav(audio_file_path, sample_rate=sample_rate) 33 | signal = data_util.read_raw_audio(wav) 34 | signal = tf.reshape(signal, [1, -1]) 35 | signal_length = tf.reshape(tf.shape(signal)[1], [1]) 36 | 37 | tflitemodel = interpreter.InterpreterWithCustomOps(model_path=tflite, custom_op_registerers=tft.tflite_registrar.SELECT_TFTEXT_OPS) 38 | input_details = tflitemodel.get_input_details() 39 | output_details = tflitemodel.get_output_details() 40 | 41 | tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape, strict=True) 42 | tflitemodel.allocate_tensors() 43 | tflitemodel.set_tensor(input_details[0]["index"], signal) 44 | tflitemodel.set_tensor(input_details[1]["index"], signal_length) 45 | tflitemodel.set_tensor(input_details[2]["index"], tf.ones(input_details[2]["shape"], dtype=input_details[2]["dtype"]) * blank) 46 | tflitemodel.set_tensor(input_details[3]["index"], tf.zeros(input_details[3]["shape"], dtype=input_details[3]["dtype"])) 47 | tflitemodel.set_tensor(input_details[4]["index"], tf.zeros(input_details[4]["shape"], dtype=input_details[4]["dtype"])) 48 | 49 | tflitemodel.invoke() 50 | 51 | transcript = tflitemodel.get_tensor(output_details[0]["index"]) 52 | tokens = tflitemodel.get_tensor(output_details[1]["index"]) 53 | next_tokens = tflitemodel.get_tensor(output_details[2]["index"]) 54 | if len(output_details) > 4: 55 | next_encoder_states = tflitemodel.get_tensor(output_details[3]["index"]) 56 | next_decoder_states = tflitemodel.get_tensor(output_details[4]["index"]) 57 | elif len(output_details) > 3: 58 | next_encoder_states = None 59 | next_decoder_states = tflitemodel.get_tensor(output_details[3]["index"]) 60 | else: 61 | next_encoder_states = None 62 | next_decoder_states = None 63 | 64 | logger.info(f"Transcript: {transcript}") 65 | logger.info(f"Tokens: {tokens}") 66 | logger.info(f"Next tokens: {next_tokens}") 67 | logger.info(f"Next encoder states: {None if next_encoder_states is None else next_encoder_states.shape}") 68 | logger.info(f"Next decoder states: {None if next_decoder_states is None else next_decoder_states.shape}") 69 | 70 | 71 | if __name__ == "__main__": 72 | cli_util.run(main) 73 | 
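For reference, the script above is driven through `cli_util.run`, so a single-utterance run might look like the sketch below. The `.tflite` path is a placeholder for a model exported beforehand (see `docs/tutorials/tflite.md`), while the audio file ships with the repo:

```python
# Hypothetical direct call of main() from examples/inferences/tflite.py;
# equivalent CLI (assuming a fire-style entry point):
#   python examples/inferences/tflite.py \
#     --audio_file_path=examples/inferences/wavs/2033-164915-0001.flac \
#     --tflite=exported/conformer.tflite
main(
    audio_file_path="examples/inferences/wavs/2033-164915-0001.flac",
    tflite="exported/conformer.tflite",  # placeholder path to an exported model
    sample_rate=16000,  # must match the rate the model was trained with
    blank=0,            # blank token index; 0 in the configs of this repo
)
```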
-------------------------------------------------------------------------------- /examples/inferences/wavs/1089-134691-0000.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/examples/inferences/wavs/1089-134691-0000.flac -------------------------------------------------------------------------------- /examples/inferences/wavs/2033-164915-0001.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/examples/inferences/wavs/2033-164915-0001.flac -------------------------------------------------------------------------------- /examples/models/ctc/conformer/results/sentencepiece/README.md: -------------------------------------------------------------------------------- 1 | - [\[English\] LibriSpeech](#english-librispeech) 2 | - [I. Small + SentencePiece 256](#i-small--sentencepiece-256) 3 | - [II. Small + Streaming + SentencePiece 256](#ii-small--streaming--sentencepiece-256) 4 | 5 | # [English] LibriSpeech 6 | 7 | ## I. Small + SentencePiece 256 8 | 9 | | Category | Description | 10 | | :---------------- | :--------------------------------------------------------------------------------------- | 11 | | Config | [small.yml.j2](../../small.yml.j2) | 12 | | Tensorflow | **2.18.0** | 13 | | Device | Google Cloud TPUs v4-8 | 14 | | Mixed Precision | strict | 15 | | Global Batch Size | 8 * 4 * 8 = 256 (as 4 TPUs, 8 Gradient Accumulation Steps) | 16 | | Max Epochs | 450 | 17 | | Pretrained | [Link](https://www.kaggle.com/models/lordh9072/tfasr-conformer-ctc/tensorFlow2/v3-small) | 18 | 19 | **Config:** 20 | 21 | ```jinja2 22 | {% import "examples/datasets/librispeech/sentencepiece/sp.256.yml.j2" as decoder_config with context %} 23 | {{decoder_config}} 24 | {% import "examples/models/ctc/conformer/small.yml.j2" as config with context %} 25 | {{config}} 26 | ``` 27 | 28 | **Results:** 29 | 30 | | Epoch | Dataset | decoding | wer | cer | mer | wil | wip | 31 | | :---- | :--------- | :------- | :-------- | :-------- | :-------- | :------- | :------- | 32 | | 170 | test-clean | greedy | 0.0967171 | 0.031954 | 0.0958403 | 0.168307 | 0.831693 | 33 | | 170 | test-other | greedy | 0.201612 | 0.0812955 | 0.197415 | 0.330207 | 0.669793 | 34 | 35 | 36 | ## II. Small + Streaming + SentencePiece 256 37 | 38 | | Category | Description | 39 | | :---------------- | :------------------------------------------------------------------------------------------------- | 40 | | Config | [small-streaming.yml.j2](../../small-streaming.yml.j2) | 41 | | Tensorflow | **2.18.0** | 42 | | Device | Google Cloud TPUs v4-8 | 43 | | Mixed Precision | strict | 44 | | Global Batch Size | 8 * 4 * 8 = 256 (as 4 TPUs, 8 Gradient Accumulation Steps) | 45 | | Max Epochs | 450 | 46 | | Pretrained | [Link](https://www.kaggle.com/models/lordh9072/tfasr-conformer-ctc/tensorFlow2/v3-small-streaming) | 47 | 48 | **Config:** 49 | 50 | ```jinja2 51 | {% import "examples/datasets/librispeech/sentencepiece/sp.256.yml.j2" as decoder_config with context %} 52 | {{decoder_config}} 53 | {% import "examples/models/ctc/conformer/small-streaming.yml.j2" as config with context %} 54 | {{config}} 55 | ``` 56 | 57 | **Tensorboard:** 58 | 59 | 60 | 61 | 65 | 69 | 73 | 74 |
62 | ![Epoch Loss](./figs/librispeech-small-streaming-epoch-loss.jpg)
66 | ![Batch Loss](./figs/librispeech-small-streaming-batch-loss.jpg)
70 | ![Learning Rate](./figs/librispeech-small-streaming-lr.jpg)
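In the streaming config referenced above, `encoder_chunk_size: 16` and `encoder_history_size: 64` are given in frames; if those counts are taken as post-subsampling encoder frames (an assumption), a rough conversion to audio time follows from the 10 ms feature stride and the two time-stride-2 subsampling layers of [small-streaming.yml.j2](../../small-streaming.yml.j2):

```python
stride_ms = 10          # speech_config.stride_ms
time_reduction = 2 * 2  # Conv2dSubsampling strides [2, 2] along time
ms_per_encoder_frame = stride_ms * time_reduction  # 40 ms per encoder frame

chunk_size = 16    # encoder_chunk_size (frames)
history_size = 64  # encoder_history_size = 4 * chunk_size (frames)

print(chunk_size * ms_per_encoder_frame)    # 640  -> each streaming chunk covers 640 ms
print(history_size * ms_per_encoder_frame)  # 2560 -> with 2.56 s of left context
```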
75 | 76 | **Results:** 77 | 78 | | Epoch | Dataset | decoding | wer | cer | mer | wil | wip | 79 | | :---- | :--------- | :------- | :-------- | :-------- | :-------- | :------ | :------ | 80 | | 60 | test-clean | greedy | 0.0848106 | 0.0286257 | 0.0841686 | 0.14896 | 0.85104 | 81 | | 60 | test-other | greedy | 0.217221 | 0.0913044 | 0.213409 | 0.3555 | 0.6445 | -------------------------------------------------------------------------------- /examples/models/ctc/conformer/results/sentencepiece/figs/librispeech-small-streaming-batch-loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/examples/models/ctc/conformer/results/sentencepiece/figs/librispeech-small-streaming-batch-loss.jpg -------------------------------------------------------------------------------- /examples/models/ctc/conformer/results/sentencepiece/figs/librispeech-small-streaming-epoch-loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/examples/models/ctc/conformer/results/sentencepiece/figs/librispeech-small-streaming-epoch-loss.jpg -------------------------------------------------------------------------------- /examples/models/ctc/conformer/results/sentencepiece/figs/librispeech-small-streaming-lr.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/examples/models/ctc/conformer/results/sentencepiece/figs/librispeech-small-streaming-lr.jpg -------------------------------------------------------------------------------- /examples/models/ctc/conformer/small-streaming.yml.j2: -------------------------------------------------------------------------------- 1 | model_config: 2 | class_name: tensorflow_asr.models.ctc.conformer>Conformer 3 | config: 4 | name: conformer 5 | speech_config: 6 | sample_rate: 16000 7 | frame_ms: 25 8 | stride_ms: 10 9 | nfft: 512 10 | num_feature_bins: 80 11 | feature_type: log_mel_spectrogram 12 | augmentation_config: null 13 | encoder_subsampling: 14 | class_name: tensorflow_asr.models.layers.subsampling>Conv2dSubsampling 15 | config: 16 | filters: [176, 176] 17 | kernels: [3, 3] 18 | strides: [2, 2] 19 | paddings: ["causal", "causal"] 20 | norms: ["layer", "layer"] 21 | activations: ["swish", "swish"] 22 | encoder_ffm_residual_factor: 0.5 23 | encoder_mhsam_residual_factor: 1.0 24 | encoder_convm_residual_factor: 1.0 25 | encoder_dmodel: 176 26 | encoder_num_blocks: 16 27 | encoder_head_size: 44 # == dmodel // num_heads 28 | encoder_num_heads: 4 29 | encoder_mha_type: relmha 30 | encoder_interleave_relpe: True 31 | encoder_use_attention_causal_mask: False 32 | encoder_use_attention_auto_mask: True 33 | encoder_mhsam_use_attention_bias: True 34 | encoder_convm_dw_norm_type: layer 35 | encoder_kernel_size: 31 36 | encoder_dropout: 0.1 37 | encoder_padding: causal 38 | encoder_memory_length: null 39 | encoder_history_size: 64 # frames = 4 * chunk_size 40 | encoder_chunk_size: 16 # frames 41 | blank: 0 42 | vocab_size: {{decoder_config.vocabsize}} 43 | kernel_regularizer: 44 | class_name: l2 45 | config: 46 | l2: 1e-6 47 | 48 | learning_config: 49 | optimizer_config: 50 | class_name: Adam 51 | config: 52 | learning_rate: 53 | class_name: 
tensorflow_asr.optimizers.schedules>TransformerSchedule 54 | config: 55 | dmodel: 176 56 | warmup_steps: 10000 57 | max_lr: 0.05/(176**0.5) 58 | min_lr: null 59 | scale: 2.0 60 | beta_1: 0.9 61 | beta_2: 0.98 62 | epsilon: 1e-9 63 | weight_decay: 1e-6 64 | 65 | gwn_config: null 66 | 67 | gradn_config: null 68 | 69 | batch_size: 8 70 | ga_steps: 4 71 | num_epochs: 450 72 | 73 | callbacks: 74 | - class_name: tensorflow_asr.callbacks>TerminateOnNaN 75 | config: {} 76 | - class_name: tensorflow_asr.callbacks>ModelCheckpoint 77 | config: 78 | filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 79 | save_best_only: False 80 | save_weights_only: True 81 | save_freq: epoch 82 | - class_name: tensorflow_asr.callbacks>TensorBoard 83 | config: 84 | log_dir: {{modeldir}}/tensorboard 85 | histogram_freq: 0 86 | write_graph: False 87 | write_images: False 88 | write_steps_per_second: False 89 | update_freq: batch 90 | profile_batch: 0 91 | - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore 92 | config: 93 | model_handle: {{kaggle_model_handle}} 94 | model_dir: {{modeldir}} 95 | save_freq: epoch -------------------------------------------------------------------------------- /examples/models/ctc/conformer/small.yml.j2: -------------------------------------------------------------------------------- 1 | model_config: 2 | class_name: tensorflow_asr.models.ctc.conformer>Conformer 3 | config: 4 | name: conformer 5 | speech_config: 6 | sample_rate: 16000 7 | frame_ms: 25 8 | stride_ms: 10 9 | nfft: 512 10 | num_feature_bins: 80 11 | feature_type: log_mel_spectrogram 12 | augmentation_config: 13 | feature_augment: 14 | time_masking: 15 | prob: 0.5 16 | num_masks: 5 17 | mask_factor: -1 # whole utterance 18 | p_upperbound: 0.05 19 | mask_value: 0 20 | freq_masking: 21 | prob: 0.5 22 | num_masks: 2 23 | mask_factor: 27 24 | mask_value: 0 25 | encoder_subsampling: 26 | class_name: tensorflow_asr.models.layers.subsampling>Conv2dSubsampling 27 | config: 28 | filters: [176, 176] 29 | kernels: [3, 3] 30 | strides: [2, 2] 31 | paddings: ["causal", "causal"] 32 | norms: ["batch", "batch"] 33 | activations: ["swish", "swish"] 34 | encoder_ffm_residual_factor: 0.5 35 | encoder_mhsam_residual_factor: 1.0 36 | encoder_convm_residual_factor: 1.0 37 | encoder_dmodel: 176 38 | encoder_num_blocks: 16 39 | encoder_head_size: 44 # == dmodel // num_heads 40 | encoder_num_heads: 4 41 | encoder_mha_type: relmha 42 | encoder_interleave_relpe: True 43 | encoder_use_attention_causal_mask: False 44 | encoder_use_attention_auto_mask: True 45 | encoder_mhsam_use_attention_bias: True 46 | encoder_kernel_size: 31 47 | encoder_dropout: 0.1 48 | encoder_padding: causal 49 | encoder_memory_length: null 50 | blank: 0 51 | vocab_size: {{decoder_config.vocabsize}} 52 | kernel_regularizer: 53 | class_name: l2 54 | config: 55 | l2: 1e-6 56 | 57 | learning_config: 58 | optimizer_config: 59 | class_name: Adam 60 | config: 61 | learning_rate: 62 | class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule 63 | config: 64 | dmodel: 176 65 | warmup_steps: 10000 66 | max_lr: 0.05/(176**0.5) 67 | min_lr: null 68 | scale: 2.0 69 | beta_1: 0.9 70 | beta_2: 0.98 71 | epsilon: 1e-9 72 | weight_decay: 1e-6 73 | 74 | gwn_config: null 75 | 76 | gradn_config: null 77 | 78 | batch_size: 8 79 | ga_steps: 4 80 | num_epochs: 450 81 | 82 | callbacks: 83 | - class_name: tensorflow_asr.callbacks>TerminateOnNaN 84 | config: {} 85 | - class_name: tensorflow_asr.callbacks>ModelCheckpoint 86 | config: 87 | filepath: 
{{modeldir}}/checkpoints/{epoch:02d}.weights.h5 88 | save_best_only: False 89 | save_weights_only: True 90 | save_freq: epoch 91 | - class_name: tensorflow_asr.callbacks>TensorBoard 92 | config: 93 | log_dir: {{modeldir}}/tensorboard 94 | histogram_freq: 0 95 | write_graph: False 96 | write_images: False 97 | write_steps_per_second: False 98 | update_freq: batch 99 | profile_batch: 0 100 | - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore 101 | config: 102 | model_handle: {{kaggle_model_handle}} 103 | model_dir: {{modeldir}} 104 | save_freq: epoch -------------------------------------------------------------------------------- /examples/models/ctc/deepspeech2/base.yml.j2: -------------------------------------------------------------------------------- 1 | model_config: 2 | class_name: tensorflow_asr.models.ctc.deepspeech2>DeepSpeech2 3 | config: 4 | name: deepspeech2 5 | speech_config: 6 | sample_rate: 16000 7 | frame_ms: 25 8 | stride_ms: 10 9 | nfft: 512 10 | num_feature_bins: 160 11 | feature_type: spectrogram 12 | augmentation_config: 13 | feature_augment: 14 | time_masking: 15 | prob: 1.0 16 | num_masks: 5 17 | mask_factor: -1 # whole utterance 18 | p_upperbound: 0.05 19 | mask_value: 0 20 | freq_masking: 21 | prob: 1.0 22 | num_masks: 1 23 | mask_factor: 27 24 | mask_value: 0 25 | conv_type: conv2d 26 | conv_kernels: [ [ 11, 41 ], [ 11, 21 ] ] 27 | conv_strides: [ [ 2, 2 ], [ 1, 2 ] ] 28 | conv_filters: [ 32, 32 ] 29 | conv_activation: relu 30 | conv_padding: same 31 | conv_initializer: he_uniform 32 | rnn_nlayers: 5 33 | rnn_type: lstm 34 | rnn_units: 512 35 | rnn_bidirectional: True 36 | rnn_unroll: False 37 | rnn_rowconv: 0 38 | rnn_rowconv_activation: relu 39 | rnn_dropout: 0.5 40 | fc_nlayers: 1 41 | fc_units: 1024 42 | fc_activation: relu 43 | fc_dropout: 0.5 44 | fc_initializer: he_uniform 45 | blank: 0 46 | vocab_size: {{decoder_config.vocabsize}} 47 | kernel_regularizer: 48 | class_name: l2 49 | config: 50 | l2: 0.0005 51 | bias_regularizer: 52 | class_name: l2 53 | config: 54 | l2: 0.0005 55 | 56 | learning_config: 57 | optimizer_config: 58 | class_name: Adam 59 | config: 60 | learning_rate: 61 | class_name: ExponentialDecay 62 | module: keras.src.optimizers.schedules.learning_rate_schedule 63 | config: 64 | initial_learning_rate: 0.0001 65 | decay_steps: 5000 66 | decay_rate: 0.9 67 | staircase: True 68 | 69 | gwn_config: null 70 | 71 | gradn_config: null 72 | 73 | batch_size: 16 74 | ga_steps: 4 75 | num_epochs: 450 76 | 77 | callbacks: 78 | - class_name: tensorflow_asr.callbacks>TerminateOnNaN 79 | config: {} 80 | - class_name: tensorflow_asr.callbacks>ModelCheckpoint 81 | config: 82 | filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 83 | save_best_only: False 84 | save_weights_only: True 85 | save_freq: epoch 86 | - class_name: tensorflow_asr.callbacks>TensorBoard 87 | config: 88 | log_dir: {{modeldir}}/tensorboard 89 | histogram_freq: 0 90 | write_graph: False 91 | write_images: False 92 | write_steps_per_second: False 93 | update_freq: batch 94 | profile_batch: 0 95 | - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore 96 | config: 97 | model_handle: {{kaggle_model_handle}} 98 | model_dir: {{modeldir}} 99 | save_freq: epoch 100 | 101 | -------------------------------------------------------------------------------- /examples/models/ctc/deepspeech2/uni.yml.j2: -------------------------------------------------------------------------------- 1 | model_config: 2 | class_name: tensorflow_asr.models.ctc.deepspeech2>DeepSpeech2 3 | 
config: 4 | name: deepspeech2 5 | speech_config: 6 | sample_rate: 16000 7 | frame_ms: 25 8 | stride_ms: 10 9 | nfft: 512 10 | num_feature_bins: 160 11 | feature_type: spectrogram 12 | augmentation_config: 13 | feature_augment: 14 | time_masking: 15 | prob: 1.0 16 | num_masks: 5 17 | mask_factor: -1 # whole utterance 18 | p_upperbound: 0.05 19 | mask_value: 0 20 | freq_masking: 21 | prob: 1.0 22 | num_masks: 1 23 | mask_factor: 27 24 | mask_value: 0 25 | conv_type: conv2d 26 | conv_kernels: [ [ 11, 41 ], [ 11, 21 ] ] 27 | conv_strides: [ [ 2, 2 ], [ 1, 2 ] ] 28 | conv_filters: [ 32, 32 ] 29 | conv_activation: relu 30 | conv_padding: causal 31 | conv_initializer: he_uniform 32 | rnn_nlayers: 5 33 | rnn_type: lstm 34 | rnn_units: 512 35 | rnn_bidirectional: False 36 | rnn_unroll: False 37 | rnn_rowconv: 3 38 | rnn_rowconv_activation: relu 39 | rnn_dropout: 0.1 40 | fc_nlayers: 1 41 | fc_units: 1024 42 | fc_activation: relu 43 | fc_dropout: 0.1 44 | fc_initializer: he_uniform 45 | blank: 0 46 | vocab_size: {{decoder_config.vocabsize}} 47 | kernel_regularizer: 48 | class_name: l2 49 | config: 50 | l2: 0.0005 51 | bias_regularizer: 52 | class_name: l2 53 | config: 54 | l2: 0.0005 55 | 56 | learning_config: 57 | optimizer_config: 58 | class_name: Adam 59 | config: 60 | learning_rate: 61 | class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule 62 | config: 63 | dmodel: 512 64 | warmup_steps: 10000 65 | min_lr: 1e-6 66 | scale: 2.0 67 | beta_1: 0.9 68 | beta_2: 0.98 69 | epsilon: 1e-9 70 | weight_decay: 1e-6 71 | 72 | gwn_config: null 73 | 74 | gradn_config: null 75 | 76 | batch_size: 16 77 | ga_steps: 4 78 | num_epochs: 450 79 | 80 | callbacks: 81 | - class_name: tensorflow_asr.callbacks>TerminateOnNaN 82 | config: {} 83 | - class_name: tensorflow_asr.callbacks>ModelCheckpoint 84 | config: 85 | filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 86 | save_best_only: False 87 | save_weights_only: True 88 | save_freq: epoch 89 | - class_name: tensorflow_asr.callbacks>TensorBoard 90 | config: 91 | log_dir: {{modeldir}}/tensorboard 92 | histogram_freq: 0 93 | write_graph: False 94 | write_images: False 95 | write_steps_per_second: False 96 | update_freq: batch 97 | profile_batch: 0 98 | - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore 99 | config: 100 | model_handle: {{kaggle_model_handle}} 101 | model_dir: {{modeldir}} 102 | save_freq: epoch 103 | 104 | -------------------------------------------------------------------------------- /examples/models/ctc/jasper/base.yml.j2: -------------------------------------------------------------------------------- 1 | model_config: 2 | class_name: tensorflow_asr.models.ctc.jasper>Jasper 3 | config: 4 | name: jasper 5 | speech_config: 6 | sample_rate: 16000 7 | frame_ms: 25 8 | stride_ms: 10 9 | nfft: 512 10 | num_feature_bins: 80 11 | feature_type: log_mel_spectrogram 12 | log_base: "10" 13 | dense: True 14 | first_additional_block_channels: 256 15 | first_additional_block_kernels: 11 16 | first_additional_block_strides: 2 17 | first_additional_block_dilation: 1 18 | first_additional_block_dropout: 0.2 19 | nsubblocks: 3 20 | block_channels: [ 256, 384, 512, 640, 768 ] 21 | block_kernels: [ 11, 13, 17, 21, 25 ] 22 | block_dropout: [ 0.2, 0.2, 0.2, 0.3, 0.3 ] 23 | second_additional_block_channels: 896 24 | second_additional_block_kernels: 1 25 | second_additional_block_strides: 1 26 | second_additional_block_dilation: 2 27 | second_additional_block_dropout: 0.4 28 | third_additional_block_channels: 1024 29 | 
third_additional_block_kernels: 1 30 | third_additional_block_strides: 1 31 | third_additional_block_dilation: 1 32 | third_additional_block_dropout: 0.4 33 | blank: 0 34 | vocab_size: {{decoder_config.vocabsize}} 35 | kernel_regularizer: 36 | class_name: l2 37 | config: 38 | l2: 1e-6 39 | 40 | learning_config: 41 | optimizer_config: 42 | class_name: Adam 43 | config: 44 | learning_rate: 0.001 45 | beta_1: 0.9 46 | beta_2: 0.98 47 | epsilon: 1e-9 48 | 49 | gwn_config: null 50 | 51 | gradn_config: null 52 | 53 | batch_size: 16 54 | ga_steps: 4 55 | num_epochs: 450 56 | 57 | callbacks: 58 | - class_name: tensorflow_asr.callbacks>TerminateOnNaN 59 | config: {} 60 | - class_name: tensorflow_asr.callbacks>ModelCheckpoint 61 | config: 62 | filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 63 | save_best_only: False 64 | save_weights_only: True 65 | save_freq: epoch 66 | - class_name: tensorflow_asr.callbacks>TensorBoard 67 | config: 68 | log_dir: {{modeldir}}/tensorboard 69 | histogram_freq: 0 70 | write_graph: False 71 | write_images: False 72 | write_steps_per_second: False 73 | update_freq: batch 74 | profile_batch: 0 75 | - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore 76 | config: 77 | model_handle: {{kaggle_model_handle}} 78 | model_dir: {{modeldir}} 79 | save_freq: epoch 80 | 81 | -------------------------------------------------------------------------------- /examples/models/ctc/transformer/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/examples/models/ctc/transformer/README.md -------------------------------------------------------------------------------- /examples/models/ctc/transformer/base-streaming.yml.j2: -------------------------------------------------------------------------------- 1 | model_config: 2 | class_name: tensorflow_asr.models.ctc.transformer>Transformer 3 | config: 4 | name: transformer 5 | speech_config: 6 | sample_rate: 16000 7 | frame_ms: 25 8 | stride_ms: 10 9 | nfft: 512 10 | num_feature_bins: 80 11 | feature_type: log_mel_spectrogram 12 | augmentation_config: 13 | feature_augment: 14 | time_masking: 15 | prob: 1.0 16 | num_masks: 5 17 | mask_factor: -1 18 | p_upperbound: 0.05 19 | freq_masking: 20 | prob: 1.0 21 | num_masks: 2 22 | mask_factor: 27 23 | encoder_subsampling: 24 | type: conv2d 25 | filters: [512, 512] 26 | kernels: [3, 3] 27 | strides: [2, 2] 28 | paddings: ["causal", "causal"] 29 | norms: ["batch", "batch"] 30 | activations: ["relu", "relu"] 31 | encoder_dropout: 0.1 32 | encoder_residual_factor: 1.0 33 | encoder_norm_position: post 34 | encoder_dmodel: 512 35 | encoder_dff: 1024 36 | encoder_num_blocks: 6 37 | encoder_head_size: 128 38 | encoder_num_heads: 4 39 | encoder_mha_type: mha 40 | encoder_interleave_relpe: True 41 | encoder_use_attention_causal_mask: False 42 | encoder_use_attention_auto_mask: True 43 | encoder_pwffn_activation: relu 44 | encoder_memory_length: null 45 | encoder_history_size: 64 # frames = 4 * chunk_size 46 | encoder_chunk_size: 16 # frames 47 | blank: 0 48 | vocab_size: {{decoder_config.vocabsize}} 49 | kernel_regularizer: 50 | class_name: l2 51 | config: 52 | l2: 1e-6 53 | 54 | learning_config: 55 | optimizer_config: 56 | class_name: Adam 57 | config: 58 | learning_rate: 59 | class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule 60 | config: 61 | dmodel: 512 62 | warmup_steps: 10000 63 | max_lr: null 64 | min_lr: null 65 | beta_1: 
0.9 66 | beta_2: 0.98 67 | epsilon: 1e-9 68 | 69 | gwn_config: 70 | predict_net_step: 0 71 | predict_net_stddev: 0.075 72 | 73 | gradn_config: null 74 | 75 | batch_size: 8 76 | ga_steps: 4 77 | num_epochs: 450 78 | 79 | callbacks: 80 | - class_name: tensorflow_asr.callbacks>TerminateOnNaN 81 | config: {} 82 | - class_name: tensorflow_asr.callbacks>ModelCheckpoint 83 | config: 84 | filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 85 | save_best_only: False 86 | save_weights_only: True 87 | save_freq: epoch 88 | - class_name: tensorflow_asr.callbacks>TensorBoard 89 | config: 90 | log_dir: {{modeldir}}/tensorboard 91 | histogram_freq: 0 92 | write_graph: False 93 | write_images: False 94 | write_steps_per_second: False 95 | update_freq: batch 96 | profile_batch: 0 97 | - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore 98 | config: 99 | model_handle: {{kaggle_model_handle}} 100 | model_dir: {{modeldir}} 101 | save_freq: epoch 102 | -------------------------------------------------------------------------------- /examples/models/ctc/transformer/base.yml.j2: -------------------------------------------------------------------------------- 1 | model_config: 2 | class_name: tensorflow_asr.models.ctc.transformer>Transformer 3 | config: 4 | name: transformer 5 | speech_config: 6 | sample_rate: 16000 7 | frame_ms: 25 8 | stride_ms: 10 9 | nfft: 512 10 | num_feature_bins: 80 11 | feature_type: log_mel_spectrogram 12 | augmentation_config: 13 | feature_augment: 14 | time_masking: 15 | prob: 1.0 16 | num_masks: 5 17 | mask_factor: -1 18 | p_upperbound: 0.05 19 | freq_masking: 20 | prob: 1.0 21 | num_masks: 2 22 | mask_factor: 27 23 | encoder_subsampling: 24 | type: conv2d 25 | filters: [512, 512] 26 | kernels: [3, 3] 27 | strides: [2, 2] 28 | paddings: ["causal", "causal"] 29 | norms: ["batch", "batch"] 30 | activations: ["relu", "relu"] 31 | encoder_dropout: 0.1 32 | encoder_residual_factor: 1.0 33 | encoder_norm_position: post 34 | encoder_dmodel: 512 35 | encoder_dff: 1024 36 | encoder_num_blocks: 6 37 | encoder_head_size: 128 38 | encoder_num_heads: 4 39 | encoder_mha_type: mha 40 | encoder_interleave_relpe: True 41 | encoder_use_attention_causal_mask: False 42 | encoder_use_attention_auto_mask: True 43 | encoder_pwffn_activation: relu 44 | encoder_memory_length: null 45 | blank: 0 46 | vocab_size: {{decoder_config.vocabsize}} 47 | kernel_regularizer: 48 | class_name: l2 49 | config: 50 | l2: 1e-6 51 | 52 | learning_config: 53 | optimizer_config: 54 | class_name: Adam 55 | config: 56 | learning_rate: 57 | class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule 58 | config: 59 | dmodel: 512 60 | warmup_steps: 10000 61 | max_lr: null 62 | min_lr: null 63 | beta_1: 0.9 64 | beta_2: 0.98 65 | epsilon: 1e-9 66 | 67 | gwn_config: null 68 | 69 | gradn_config: null 70 | 71 | batch_size: 8 72 | ga_steps: 4 73 | num_epochs: 450 74 | 75 | callbacks: 76 | - class_name: tensorflow_asr.callbacks>TerminateOnNaN 77 | config: {} 78 | - class_name: tensorflow_asr.callbacks>ModelCheckpoint 79 | config: 80 | filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 81 | save_best_only: False 82 | save_weights_only: True 83 | save_freq: epoch 84 | - class_name: tensorflow_asr.callbacks>TensorBoard 85 | config: 86 | log_dir: {{modeldir}}/tensorboard 87 | histogram_freq: 0 88 | write_graph: False 89 | write_images: False 90 | write_steps_per_second: False 91 | update_freq: batch 92 | profile_batch: 0 93 | - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore 94 | config: 
95 | model_handle: {{kaggle_model_handle}} 96 | model_dir: {{modeldir}} 97 | save_freq: epoch 98 | -------------------------------------------------------------------------------- /examples/models/transducer/conformer/README.md: -------------------------------------------------------------------------------- 1 | # Conformer Transducer 2 | 3 | ## Results 4 | 5 | See [results](./results) for more details. -------------------------------------------------------------------------------- /examples/models/transducer/conformer/inference/gen_saved_model.py: -------------------------------------------------------------------------------- 1 | # # pylint: disable=no-member 2 | # # Copyright 2020 Huy Le Nguyen (@nglehuy) 3 | # # 4 | # # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # # you may not use this file except in compliance with the License. 6 | # # You may obtain a copy of the License at 7 | # # 8 | # # http://www.apache.org/licenses/LICENSE-2.0 9 | # # 10 | # # Unless required by applicable law or agreed to in writing, software 11 | # # distributed under the License is distributed on an "AS IS" BASIS, 12 | # # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # # See the License for the specific language governing permissions and 14 | # # limitations under the License. 15 | 16 | # import os 17 | 18 | # import fire 19 | # from tensorflow_asr import tf, keras 20 | 21 | # from tensorflow_asr.configs import Config 22 | # from tensorflow_asr.helpers import featurizer_helpers 23 | # from tensorflow_asr.models.transducer.conformer import Conformer 24 | # from tensorflow_asr.utils import env_util 25 | 26 | # logger = env_util.setup_environment() 27 | 28 | # DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") 29 | 30 | 31 | # def main( 32 | # config_path: str = DEFAULT_YAML, 33 | # saved: str = None, 34 | # output_dir: str = None, 35 | # ): 36 | # assert saved and output_dir 37 | # tf.random.set_seed(0) 38 | # keras.backend.clear_session() 39 | 40 | # logger.info("Load config and featurizers ...") 41 | # config = Config(config_path) 42 | # speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(config=config) 43 | 44 | # logger.info("Build and load model ...") 45 | # conformer = Conformer(**config.model_config, vocab_size=text_featurizer.num_classes) 46 | # conformer.make(speech_featurizer.shape) 47 | # conformer.add_featurizers(speech_featurizer, text_featurizer) 48 | # conformer.load_weights(saved, by_name=True) 49 | # conformer.summary() 50 | 51 | # logger.info("Save model ...") 52 | # tf.saved_model.save(conformer, export_dir=output_dir, signatures=conformer.recognize_from_signal.get_concrete_function()) 53 | 54 | 55 | # if __name__ == "__main__": 56 | # fire.Fire(main) 57 | -------------------------------------------------------------------------------- /examples/models/transducer/conformer/inference/run_saved_model.py: -------------------------------------------------------------------------------- 1 | # # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # # 3 | # # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # # you may not use this file except in compliance with the License. 
5 | # # You may obtain a copy of the License at 6 | # # 7 | # # http://www.apache.org/licenses/LICENSE-2.0 8 | # # 9 | # # Unless required by applicable law or agreed to in writing, software 10 | # # distributed under the License is distributed on an "AS IS" BASIS, 11 | # # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # # See the License for the specific language governing permissions and 13 | # # limitations under the License. 14 | 15 | # import os 16 | 17 | # import fire 18 | # from tensorflow_asr import tf, keras 19 | 20 | # from tensorflow_asr.features.speech_featurizers import read_raw_audio 21 | # from tensorflow_asr.utils import env_util 22 | 23 | # logger = env_util.setup_environment() 24 | 25 | # DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") 26 | 27 | 28 | # def main( 29 | # saved_model: str = None, 30 | # filename: str = None, 31 | # ): 32 | # keras.backend.clear_session() 33 | 34 | # module = tf.saved_model.load(export_dir=saved_model) 35 | 36 | # signal = read_raw_audio(filename) 37 | # transcript = module.pred(signal) 38 | 39 | # print("Transcript: ", "".join([chr(u) for u in transcript])) 40 | 41 | 42 | # if __name__ == "__main__": 43 | # fire.Fire(main) 44 | -------------------------------------------------------------------------------- /examples/models/transducer/conformer/results/sentencepiece/figs/vietbud500-small-streaming-batch-loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/examples/models/transducer/conformer/results/sentencepiece/figs/vietbud500-small-streaming-batch-loss.jpg -------------------------------------------------------------------------------- /examples/models/transducer/conformer/results/sentencepiece/figs/vietbud500-small-streaming-epoch-loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/examples/models/transducer/conformer/results/sentencepiece/figs/vietbud500-small-streaming-epoch-loss.jpg -------------------------------------------------------------------------------- /examples/models/transducer/conformer/results/sentencepiece/figs/vietbud500-small-streaming-lr.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/examples/models/transducer/conformer/results/sentencepiece/figs/vietbud500-small-streaming-lr.jpg -------------------------------------------------------------------------------- /examples/models/transducer/conformer/small-streaming.yml.j2: -------------------------------------------------------------------------------- 1 | model_config: 2 | class_name: tensorflow_asr.models.transducer.conformer>Conformer 3 | config: 4 | name: conformer 5 | speech_config: 6 | sample_rate: 16000 7 | frame_ms: 25 8 | stride_ms: 10 9 | nfft: 512 10 | num_feature_bins: 80 11 | feature_type: log_mel_spectrogram 12 | encoder_subsampling: 13 | class_name: tensorflow_asr.models.layers.subsampling>Conv2dSubsampling 14 | config: 15 | filters: [144, 144] 16 | kernels: [3, 3] 17 | strides: [2, 2] 18 | paddings: ["causal", "causal"] 19 | norms: ["layer", "layer"] 20 | activations: ["swish", "swish"] 21 | encoder_ffm_residual_factor: 0.5 22 | encoder_mhsam_residual_factor: 1.0 23 | 
encoder_convm_residual_factor: 1.0 24 | encoder_dmodel: 144 25 | encoder_num_blocks: 16 26 | encoder_head_size: 36 # == dmodel // num_heads 27 | encoder_num_heads: 4 28 | encoder_mha_type: relmha 29 | encoder_interleave_relpe: True 30 | encoder_use_attention_causal_mask: False 31 | encoder_use_attention_auto_mask: True 32 | encoder_mhsam_use_attention_bias: False 33 | encoder_convm_dw_norm_type: layer 34 | encoder_kernel_size: 31 35 | encoder_dropout: 0.1 36 | encoder_padding: causal 37 | encoder_memory_length: null 38 | encoder_history_size: 64 # frames = 4 * chunk_size 39 | encoder_chunk_size: 16 # frames 40 | prediction_label_encode_mode: embedding 41 | prediction_embed_dim: 320 42 | prediction_num_rnns: 1 43 | prediction_rnn_units: 320 44 | prediction_rnn_type: lstm 45 | prediction_rnn_implementation: 2 46 | prediction_rnn_unroll: False 47 | prediction_layer_norm: True 48 | prediction_projection_units: 0 49 | joint_dim: 320 50 | prejoint_encoder_linear: True 51 | prejoint_prediction_linear: True 52 | postjoint_linear: False 53 | joint_activation: tanh 54 | joint_mode: add 55 | blank: 0 56 | vocab_size: {{decoder_config.vocabsize}} 57 | kernel_regularizer: 58 | class_name: l2 59 | config: 60 | l2: 1e-6 61 | 62 | learning_config: 63 | optimizer_config: 64 | class_name: Adam 65 | config: 66 | learning_rate: 67 | class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule 68 | config: 69 | dmodel: 144 70 | warmup_steps: 10000 71 | max_lr: 0.05/(144**0.5) 72 | min_lr: null 73 | scale: 2.0 74 | beta_1: 0.9 75 | beta_2: 0.98 76 | epsilon: 1e-9 77 | weight_decay: 1e-6 78 | 79 | gwn_config: null 80 | 81 | gradn_config: null 82 | 83 | batch_size: 8 84 | ga_steps: 4 85 | num_epochs: 300 86 | 87 | callbacks: 88 | - class_name: tensorflow_asr.callbacks>TerminateOnNaN 89 | config: {} 90 | - class_name: tensorflow_asr.callbacks>ModelCheckpoint 91 | config: 92 | filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 93 | save_best_only: False 94 | save_weights_only: True 95 | save_freq: epoch 96 | - class_name: tensorflow_asr.callbacks>TensorBoard 97 | config: 98 | log_dir: {{modeldir}}/tensorboard 99 | histogram_freq: 0 100 | write_graph: False 101 | write_images: False 102 | write_steps_per_second: False 103 | update_freq: batch 104 | profile_batch: 0 105 | - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore 106 | config: 107 | model_handle: {{kaggle_model_handle}} 108 | model_dir: {{modeldir}} 109 | save_freq: epoch -------------------------------------------------------------------------------- /examples/models/transducer/conformer/small.yml.j2: -------------------------------------------------------------------------------- 1 | model_config: 2 | class_name: tensorflow_asr.models.transducer.conformer>Conformer 3 | config: 4 | name: conformer 5 | speech_config: 6 | sample_rate: 16000 7 | frame_ms: 25 8 | stride_ms: 10 9 | nfft: 512 10 | num_feature_bins: 80 11 | feature_type: log_mel_spectrogram 12 | augmentation_config: 13 | feature_augment: 14 | time_masking: 15 | prob: 1.0 16 | num_masks: 10 17 | mask_factor: -1 18 | p_upperbound: 0.05 19 | mask_value: 0 20 | freq_masking: 21 | prob: 1.0 22 | num_masks: 1 23 | mask_factor: 27 24 | mask_value: 0 25 | encoder_subsampling: 26 | class_name: tensorflow_asr.models.layers.subsampling>Conv2dSubsampling 27 | config: 28 | filters: [144, 144] 29 | kernels: [3, 3] 30 | strides: [2, 2] 31 | paddings: ["causal", "causal"] 32 | norms: ["batch", "batch"] 33 | activations: ["swish", "swish"] 34 | encoder_ffm_residual_factor: 0.5 35 | 
encoder_mhsam_residual_factor: 1.0 36 | encoder_convm_residual_factor: 1.0 37 | encoder_dmodel: 144 38 | encoder_num_blocks: 16 39 | encoder_head_size: 36 # == dmodel // num_heads 40 | encoder_num_heads: 4 41 | encoder_mha_type: relmha 42 | encoder_interleave_relpe: True 43 | encoder_use_attention_causal_mask: False 44 | encoder_use_attention_auto_mask: True 45 | encoder_mhsam_use_attention_bias: False 46 | encoder_kernel_size: 31 47 | encoder_dropout: 0.1 48 | encoder_padding: causal 49 | encoder_memory_length: null 50 | prediction_label_encode_mode: embedding 51 | prediction_embed_dim: 320 52 | prediction_num_rnns: 1 53 | prediction_rnn_units: 320 54 | prediction_rnn_type: lstm 55 | prediction_rnn_implementation: 2 56 | prediction_rnn_unroll: False 57 | prediction_layer_norm: True 58 | prediction_projection_units: 0 59 | joint_dim: 320 60 | prejoint_encoder_linear: True 61 | prejoint_prediction_linear: True 62 | postjoint_linear: False 63 | joint_activation: tanh 64 | joint_mode: add 65 | blank: 0 66 | vocab_size: {{decoder_config.vocabsize}} 67 | kernel_regularizer: 68 | class_name: l2 69 | config: 70 | l2: 1e-6 71 | 72 | learning_config: 73 | optimizer_config: 74 | class_name: Adam 75 | config: 76 | learning_rate: 77 | class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule 78 | config: 79 | dmodel: 144 80 | warmup_steps: 10000 81 | max_lr: 0.05/(144**0.5) 82 | min_lr: null 83 | scale: 2.0 84 | beta_1: 0.9 85 | beta_2: 0.98 86 | epsilon: 1e-9 87 | weight_decay: 1e-6 88 | 89 | gwn_config: null 90 | 91 | gradn_config: null 92 | 93 | batch_size: 2 94 | ga_steps: 16 95 | num_epochs: 300 96 | 97 | callbacks: 98 | - class_name: tensorflow_asr.callbacks>TerminateOnNaN 99 | config: {} 100 | - class_name: tensorflow_asr.callbacks>ModelCheckpoint 101 | config: 102 | filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 103 | save_best_only: False 104 | save_weights_only: True 105 | save_freq: epoch 106 | - class_name: tensorflow_asr.callbacks>TensorBoard 107 | config: 108 | log_dir: {{modeldir}}/tensorboard 109 | histogram_freq: 0 110 | write_graph: False 111 | write_images: False 112 | write_steps_per_second: False 113 | update_freq: batch 114 | profile_batch: 0 115 | - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore 116 | config: 117 | model_handle: {{kaggle_model_handle}} 118 | model_dir: {{modeldir}} 119 | save_freq: epoch -------------------------------------------------------------------------------- /examples/models/transducer/contextnet/README.md: -------------------------------------------------------------------------------- 1 | # ContextNet Transducer 2 | 3 | ## Results 4 | 5 | See [results](./results) for more details. 
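The results sections in these READMEs report `wer`, `cer`, `mer`, `wil` and `wip`. As a minimal sketch, the same metrics can be computed with the `jiwer` package pinned in `requirements.txt` (the reference/hypothesis strings below are illustrative only):

```python
import jiwer

reference = "the quick brown fox jumps over the lazy dog"
hypothesis = "the quick brown box jumps over the lazy"

out = jiwer.process_words(reference, hypothesis)
print(out.wer, out.mer, out.wil, out.wip)  # word-level error metrics
print(jiwer.cer(reference, hypothesis))    # character error rate
```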
-------------------------------------------------------------------------------- /examples/models/transducer/contextnet/results/wordpiece/README.md: --------------------------------------------------------------------------------
1 | **Table of Contents** 2 | - [WordPiece 1k With Whitespace + Small + LibriSpeech](#wordpiece-1k-with-whitespace--small--librispeech) 3 | - [Epoch Loss](#epoch-loss) 4 | - [Batch Loss](#batch-loss) 5 | - [Training Learning Rate](#training-learning-rate) 6 | - [Results](#results) 7 | 8 | # WordPiece 1k With Whitespace + Small + LibriSpeech 9 | 10 | 11 | | Category | Description | 12 | | :---------------- | :--------------------------------- | 13 | | Config | [small.yml.j2](../../small.yml.j2) | 14 | | Tensorflow | **2.13.x** | 15 | | Device | Google Colab TPUs | 16 | | Global Batch Size | 2 * 16 * 8 = 256 (as 8 TPUs) | 17 | 18 | 19 | ### Epoch Loss 20 | 21 | ![Epoch Loss](./figs/contextnet-small-wp1k-whitespace-epoch-loss.svg) 22 | 23 | ### Batch Loss 24 | 25 | ![Batch Loss](./figs/contextnet-small-wp1k-whitespace-batch-loss.svg) 26 | 27 | ### Training Learning Rate 28 | 29 | ![Learning Rate](./figs/contextnet-small-wp1k-whitespace-lr.svg) 30 | 31 | ### Results 32 | 33 | Pretrained model here: [link](https://drive.google.com/drive/folders/1xT3j_L5q4oSBeUiLArnBPliZ0g9k-N7O?usp=drive_link) 34 | 35 | ```json 36 | [ 37 | { 38 | "epoch": 273, 39 | "test-clean": { 40 | "greedy": { 41 | "wer": 0.07923767498478393, 42 | "cer": 0.0336269669307001, 43 | "mer": 0.07840111410128536, 44 | "wil": 0.13531145375649656, 45 | "wip": 0.8646885462435034 46 | } 47 | }, 48 | "test-other": { 49 | "greedy": { 50 | "wer": 0.19121945627877654, 51 | "cer": 0.09776798480704507, 52 | "mer": 0.1870526453493805, 53 | "wil": 0.3107931720744128, 54 | "wip": 0.6892068279255872 55 | } 56 | } 57 | } 58 | ] 59 | ```
-------------------------------------------------------------------------------- /examples/models/transducer/rnnt/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/examples/models/transducer/rnnt/README.md
-------------------------------------------------------------------------------- /examples/models/transducer/rnnt/results/sentencepiece/README.md: --------------------------------------------------------------------------------
1 | - [SentencePiece 256 + Tiny + LibriSpeech](#sentencepiece-256--tiny--librispeech) 2 | - [Training Loss](#training-loss) 3 | - [1. Epoch Loss](#1-epoch-loss) 4 | - [2. Batch Loss](#2-batch-loss) 5 | - [Results](#results) 6 | 7 | 8 | # SentencePiece 256 + Tiny + LibriSpeech 9 | 10 | | Category | Description | 11 | | :---------------- | :------------------------------- | 12 | | Config | [tiny.yml.j2](../../tiny.yml.j2) | 13 | | Tensorflow | **2.15.x** | 14 | | Device | NVIDIA GeForce GTX 1650 | 15 | | Global Batch Size | 3 | 16 | | Max Epochs | 300 | 17 | 18 | 19 | ### Training Loss 20 | 21 | #### 1. Epoch Loss 22 | 23 | ![Epoch Loss](./figs/rnnt-tiny-sp256-epoch-loss.svg) 24 | 25 | #### 2.
Batch Loss 26 | 27 | ![Batch Loss](./figs/rnnt-tiny-sp256-batch-loss.svg) 28 | 29 | 30 | ### Results 31 | 32 | Pretrained model here: [link](https://drive.google.com/drive/folders/1h0BrCzZo8JTz_MUU5bJPJ3UBqroBnsuv?usp=sharing) 33 | 34 | ```json 35 | [ 36 | { 37 | "epoch": 136, 38 | "test-clean": { 39 | "greedy": { 40 | "wer": 0.15853241022519782, 41 | "cer": 0.07179696657549817, 42 | "mer": 0.15537908021549876, 43 | "wil": 0.2587056704145151, 44 | "wip": 0.7412943295854849 45 | } 46 | }, 47 | "test-other": { 48 | "greedy": { 49 | "wer": 0.3457577899623636, 50 | "cer": 0.18733822655980759, 51 | "mer": 0.33391759995571874, 52 | "wil": 0.5185365485613327, 53 | "wip": 0.48146345143866726 54 | } 55 | } 56 | } 57 | ] 58 | ```
-------------------------------------------------------------------------------- /examples/models/transducer/rnnt/results/subword - deprecated/README.md: --------------------------------------------------------------------------------
1 | # RNN Transducer Subwords 2 | 3 | - [RNN Transducer Subwords](#rnn-transducer-subwords) 4 | - [v1.0.x](#v10x) 5 | 6 | 7 | ## v1.0.x 8 | 9 | **Summary** 10 | 11 | - Number of subwords: 1008 12 | - Maximum length of a subword: 10 13 | - Subwords corpus: all training sets 14 | - Number of parameters: 54,914,480 15 | - Number of epochs: 21 16 | - Trained on: 8 Google Colab TPUs 17 | - Training time: 10.5 non-continuous days (2 epochs per day, since Colab allows only 12 hours/day and each epoch took about 4.5 hours) => 94.5 continuous hours (~3.94 days) 18 | 19 | **Pretrained weights and config**: go to [drive](https://drive.google.com/drive/folders/1rYpiYF0F9JIsAKN2DCFFtEdfNzVbBLHe?usp=sharing) 20 | 21 | **Epoch Transducer Loss** 22 | 23 | ![subword_rnnt_loss](./figs/subword_rnnt_loss.svg) 24 | 25 | **Epoch Learning Rate** 26 | 27 | ![epoch_learning_rate](./figs/epoch_learning_rate.svg) 28 | 29 | **Error Rates** 30 | 31 | | **Test-clean** | Test batch size | Epoch | WER (%) | CER (%) | 32 | | :------------: | :-------------: | :---: | :---------------: | :---------------: | 33 | | _Greedy_ | 8 | 21 | 13.13907504081726 | 6.023869663476944 | 34 | | _Greedy_ | 8 | 25 | 12.79481202363968 | 5.671864375472069 |
-------------------------------------------------------------------------------- /examples/models/transducer/rnnt/small.yml.j2: --------------------------------------------------------------------------------
1 | model_config: 2 | class_name: tensorflow_asr.models.transducer.rnnt>RnnTransducer 3 | config: 4 | name: rnn_transducer 5 | speech_config: 6 | sample_rate: 16000 7 | frame_ms: 25 8 | stride_ms: 10 9 | num_feature_bins: 80 10 | nfft: 512 11 | feature_type: log_mel_spectrogram 12 | augmentation_config: 13 | feature_augment: 14 | time_masking: 15 | prob: 1.0 16 | num_masks: 5 17 | mask_factor: -1 18 | p_upperbound: 0.05 19 | freq_masking: 20 | prob: 1.0 21 | num_masks: 1 22 | mask_factor: 27 23 | encoder_reduction_positions: [ post, post, post, post ] 24 | encoder_reduction_factors: [ 3, 0, 2, 0 ] # reduce 3x to 30ms frames, then 2x more after the second layer 25 | encoder_dmodel: 320 26 | encoder_rnn_type: lstm 27 | encoder_rnn_units: 1024 28 | encoder_nlayers: 4 29 | encoder_layer_norm: True 30 | prediction_label_encode_mode: embedding 31 | prediction_embed_dim: 512 32 | prediction_num_rnns: 1 33 | prediction_rnn_units: 1024 34 | prediction_rnn_type: lstm 35 | prediction_rnn_unroll: False 36 | prediction_layer_norm: True 37 | prediction_projection_units: 0 38 | joint_dim: 320 39 | prejoint_encoder_linear: True 40 | prejoint_prediction_linear: True 41 | postjoint_linear: False 42 | joint_activation: tanh 43 | joint_mode:
add 44 | blank: 0 45 | vocab_size: {{decoder_config.vocabsize}} 46 | kernel_regularizer: 47 | class_name: l2 48 | config: 49 | l2: 1e-6 50 | 51 | learning_config: 52 | optimizer_config: 53 | class_name: Adam 54 | config: 55 | learning_rate: 56 | class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule 57 | config: 58 | dmodel: 320 59 | warmup_steps: 10000 60 | max_lr: null 61 | min_lr: 1e-6 62 | scale: 2.0 63 | beta_1: 0.9 64 | beta_2: 0.98 65 | epsilon: 1e-9 66 | weight_decay: 1e-6 67 | 68 | gwn_config: 69 | predict_net_step: 20000 70 | predict_net_stddev: 0.075 71 | 72 | gradn_config: null 73 | 74 | batch_size: 4 75 | ga_steps: 8 76 | num_epochs: 300 77 | 78 | callbacks: 79 | - class_name: tensorflow_asr.callbacks>TerminateOnNaN 80 | config: {} 81 | - class_name: tensorflow_asr.callbacks>ModelCheckpoint 82 | config: 83 | filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 84 | save_best_only: False 85 | save_weights_only: True 86 | save_freq: epoch 87 | - class_name: tensorflow_asr.callbacks>TensorBoard 88 | config: 89 | log_dir: {{modeldir}}/tensorboard 90 | histogram_freq: 0 91 | write_graph: False 92 | write_images: False 93 | write_steps_per_second: False 94 | update_freq: batch 95 | profile_batch: 0 96 | - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore 97 | config: 98 | model_handle: {{kaggle_model_handle}} 99 | model_dir: {{modeldir}} 100 | save_freq: epoch 101 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 150 3 | 4 | [tool.isort] 5 | profile = "black" 6 | line_length = 150 7 | 8 | [tool.pytest.ini_options] 9 | minversion = "6.0" 10 | log_cli = true 11 | log_cli_level = "WARNING" 12 | log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 13 | log_cli_date_format = "%Y-%m-%d %H:%M:%S" 14 | testpaths = "tests" 15 | python_files = "test_*.py" 16 | addopts = "-s --durations=0" 17 | filterwarnings = ["error", "ignore::UserWarning", "ignore::DeprecationWarning"] 18 | asyncio_mode = "auto" 19 | asyncio_default_fixture_loop_scope = "session" 20 | -------------------------------------------------------------------------------- /requirements.apple.txt: -------------------------------------------------------------------------------- 1 | tensorflow~=2.18.0 2 | tensorflow-text @ https://github.com/sun1638650145/Libraries-and-Extensions-for-TensorFlow-for-Apple-Silicon/releases/download/v2.18/tensorflow_text-2.18.1-cp312-cp312-macosx_11_0_arm64.whl -------------------------------------------------------------------------------- /requirements.cpu.txt: -------------------------------------------------------------------------------- 1 | tensorflow~=2.18.0 2 | tensorflow-text~=2.18.0 -------------------------------------------------------------------------------- /requirements.dev.txt: -------------------------------------------------------------------------------- 1 | pytest>=7.4.1 2 | black>=24.3.0 3 | pylint>=3.2.4 4 | matplotlib>=3.7.2 5 | pydot-ng>=2.0.0 6 | graphviz>=0.20.1 7 | pre-commit>=3.7.0 8 | tf2onnx>=1.16.1 9 | netron>=8.0.3 -------------------------------------------------------------------------------- /requirements.gpu.txt: -------------------------------------------------------------------------------- 1 | tensorflow[and-cuda]~=2.18.0 2 | tensorflow-text~=2.18.0 -------------------------------------------------------------------------------- /requirements.text.txt: 
-------------------------------------------------------------------------------- 1 | tensorflow-text~=2.18.0 -------------------------------------------------------------------------------- /requirements.tpu.txt: -------------------------------------------------------------------------------- 1 | tensorflow-tpu~=2.18.0 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | SoundFile~=0.12.1 2 | nltk>=3.9.0 3 | sentencepiece~=0.2.0 4 | tqdm>=4.67.1 5 | librosa~=0.10.1 6 | PyYAML~=6.0.1 7 | sounddevice~=0.4.6 8 | jinja2~=3.1.3 9 | fire>=0.7.0 10 | jiwer~=3.0.3 11 | keras-nightly~=3.9.0.dev # https://github.com/keras-team/keras/issues/20568#issuecomment-2510432421 12 | cached_property~=2.0.1 13 | ipywidgets~=8.1.5 14 | ipython<9.0.0 15 | kagglehub~=0.3.6 16 | datasets~=3.5.1 17 | tabulate~=0.9.0 -------------------------------------------------------------------------------- /scripts/install_ctc_decoders.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | PROJECT_DIR=$(realpath "$(dirname $0)/..") 4 | 5 | mkdir -p $PROJECT_DIR/externals 6 | cd $PROJECT_DIR/externals || exit 7 | 8 | # Install baidu's beamsearch_with_lm 9 | if [ ! -d ctc_decoders ]; then 10 | git clone --depth 1 https://github.com/nglehuy/ctc_decoders.git 11 | cd ./ctc_decoders || exit 12 | chmod a+x setup.sh 13 | ./setup.sh 14 | fi 15 | 16 | cd $PROJECT_DIR || exit 17 | -------------------------------------------------------------------------------- /scripts/install_ctc_loss.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | PROJECT_DIR=$(realpath "$(dirname $0)/..") 4 | cd "$PROJECT_DIR" || exit 5 | 6 | mkdir -p $PROJECT_DIR/externals 7 | cd $PROJECT_DIR/externals || exit 8 | 9 | TF_VERSION=$(python3 -c "import tensorflow as tf; print(tf.__version__)") 10 | 11 | # Install ctc_loss (warp-ctc) 12 | if [ ! -d warp-ctc ]; then 13 | git clone --depth 1 https://github.com/nglehuy/warp-ctc.git 14 | cd $PROJECT_DIR/externals/warp-ctc/tensorflow_binding 15 | if [ ! -d tensorflow ]; then 16 | git clone --depth 1 --branch v$TF_VERSION https://github.com/tensorflow/tensorflow.git 17 | fi 18 | cd ../../ 19 | fi 20 | 21 | export TENSORFLOW_SRC_PATH="$PROJECT_DIR/externals/warp-ctc/tensorflow_binding/tensorflow" 22 | 23 | rm -rf $PROJECT_DIR/externals/warp-ctc/build 24 | mkdir -p $PROJECT_DIR/externals/warp-ctc/build 25 | cd $PROJECT_DIR/externals/warp-ctc/build || exit 26 | 27 | if [ "$CUDA_HOME" ]; then 28 | cmake \ 29 | -DWITH_GPU=ON \ 30 | -DCUDA_TOOLKIT_ROOT_DIR="$CUDA_HOME" .. 31 | else 32 | cmake \ 33 | -DWITH_GPU=OFF \ 34 | .. 35 | fi 36 | 37 | make -j $(nproc) 38 | 39 | cd $PROJECT_DIR/externals/warp-ctc/tensorflow_binding || exit 40 | 41 | if [ "$CUDA_HOME" ]; then 42 | CUDA="$CUDA_HOME" python3 setup.py install 43 | else 44 | python3 setup.py install 45 | fi 46 | 47 | cd $PROJECT_DIR || exit -------------------------------------------------------------------------------- /scripts/install_rnnt_loss.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | PROJECT_DIR=$(realpath "$(dirname $0)/..") 4 | cd "$PROJECT_DIR" || exit 5 | 6 | mkdir -p $PROJECT_DIR/externals 7 | cd $PROJECT_DIR/externals || exit 8 | 9 | TF_VERSION=$(python3 -c "import tensorflow as tf; print(tf.__version__)") 10 | 11 | # Install rnnt_loss 12 | if [ ! 
-d warp-transducer ]; then 13 | git clone --depth 1 https://github.com/nglehuy/warp-transducer.git 14 | cd $PROJECT_DIR/externals/warp-transducer/tensorflow_binding 15 | if [ ! -d tensorflow ]; then 16 | git clone --depth 1 --branch v$TF_VERSION https://github.com/tensorflow/tensorflow.git 17 | fi 18 | cd ../../ 19 | fi 20 | 21 | export TENSORFLOW_SRC_PATH="$PROJECT_DIR/externals/warp-transducer/tensorflow_binding/tensorflow" 22 | 23 | rm -rf $PROJECT_DIR/externals/warp-transducer/build 24 | mkdir -p $PROJECT_DIR/externals/warp-transducer/build 25 | cd $PROJECT_DIR/externals/warp-transducer/build || exit 26 | 27 | if [ "$CUDA_HOME" ]; then 28 | cmake \ 29 | -DUSE_NAIVE_KERNEL=OFF \ 30 | -DWITH_GPU=ON \ 31 | -DCMAKE_C_COMPILER_LAUNCHER="$(which gcc)" \ 32 | -DCMAKE_CXX_COMPILER_LAUNCHER="$(which g++)" \ 33 | -DCUDA_TOOLKIT_ROOT_DIR="$CUDA_HOME" .. 34 | else 35 | cmake \ 36 | -DUSE_NAIVE_KERNEL=OFF \ 37 | -DWITH_GPU=OFF \ 38 | -DCMAKE_C_COMPILER_LAUNCHER="$(which gcc)" \ 39 | -DCMAKE_CXX_COMPILER_LAUNCHER="$(which g++)" .. 40 | fi 41 | 42 | make -j $(nproc) 43 | 44 | cd $PROJECT_DIR/externals/warp-transducer/tensorflow_binding || exit 45 | 46 | if [ "$CUDA_HOME" ]; then 47 | CUDA="$CUDA_HOME" python3 setup.py install 48 | else 49 | python3 setup.py install 50 | fi 51 | 52 | cd $PROJECT_DIR || exit -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E402,E701,E702,E704,E251,E203,W503,W504,C901,E501 3 | max-line-length = 150 4 | 5 | [pep8] 6 | ignore = E402,E701,E702,E704,E251,E203,W503,W504,C901,E501 7 | max-line-length = 150 8 | indent-size = 4 9 | 10 | [options.entry_points] 11 | console_scripts = 12 | tensorflow_asr = tensorflow_asr.scripts:main -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import glob 16 | import os 17 | 18 | from setuptools import find_packages, setup 19 | 20 | install_requires = [] 21 | extras_requires = {} 22 | 23 | for req_file in glob.glob("requirements*.txt", recursive=False): 24 | name = os.path.basename(req_file).split(".") 25 | extra = name[1] if len(name) > 2 else None # "requirements.gpu.txt" -> "gpu"; plain "requirements.txt" -> None 26 | with open(req_file, "r", encoding="utf-8") as fr: 27 | if not extra: 28 | install_requires = fr.readlines() 29 | else: 30 | extras_requires[extra] = fr.readlines() 31 | 32 | with open("README.md", "r", encoding="utf-8") as fh: 33 | long_description = fh.read() 34 | 35 | setup( 36 | name="TensorFlowASR", 37 | version="3.0.0", 38 | author="Huy Le Nguyen", 39 | author_email="nlhuy.cs.16@gmail.com", 40 | description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2", 41 | long_description=long_description, 42 | long_description_content_type="text/markdown", 43 | url="https://github.com/TensorSpeech/TensorFlowASR", 44 | packages=find_packages(include=("tensorflow_asr", "tensorflow_asr.*")), 45 | install_requires=install_requires, 46 | extras_require=extras_requires, 47 | classifiers=[ 48 | "Programming Language :: Python :: 3.6", 49 | "Programming Language :: Python :: 3.7", 50 | "Programming Language :: Python :: 3.8", 51 | "Programming Language :: Python :: 3.9", 52 | "Intended Audience :: Science/Research", 53 | "Operating System :: POSIX :: Linux", 54 | "License :: OSI Approved :: Apache Software License", 55 | "Topic :: Software Development :: Libraries :: Python Modules", 56 | ], 57 | python_requires=">=3.8, <4", 58 | ) 59 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python3 -m pip install -r requirements.text.txt 4 | 5 | case "$1" in 6 | tpu) 7 | python3 -m pip uninstall -y tensorflow 8 | python3 -m pip install -r requirements.tpu.txt -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force 9 | ;; 10 | gpu) 11 | python3 -m pip install -r requirements.gpu.txt 12 | ;; 13 | cpu) 14 | python3 -m pip install -r requirements.cpu.txt 15 | ;; 16 | apple) 17 | python3 -m pip install -r requirements.apple.txt 18 | ;; 19 | *) echo -e "Usage: $0 {tpu|gpu|cpu|apple} [dev|install]" 20 | esac 21 | 22 | python3 -m pip uninstall -y keras # use keras-nightly 23 | python3 -m pip install -r requirements.txt --force 24 | 25 | case "$2" in 26 | dev) 27 | python3 -m pip install -r requirements.dev.txt 28 | python3 -m pip install -e . 29 | ;; 30 | install) 31 | python3 -m pip install -e . 
32 | ;; 33 | esac -------------------------------------------------------------------------------- /tensorflow_asr/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=protected-access 2 | import os 3 | 4 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = os.environ.get("TF_CPP_MIN_LOG_LEVEL") or "3" 5 | os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = os.environ.get("TF_FORCE_GPU_ALLOW_GROWTH", "true") 6 | 7 | # import submodules to register keras objects 8 | import glob 9 | from os.path import basename, dirname, isdir, isfile, join 10 | 11 | import keras 12 | import tensorflow as tf # for reference 13 | 14 | from tensorflow_asr.utils import env_util # import here first to apply logging 15 | 16 | for fd in glob.glob(join(dirname(__file__), "*")): 17 | if not isfile(fd) and not isdir(fd): 18 | continue 19 | if isfile(fd) and not fd.endswith(".py"): 20 | continue 21 | fd = fd if isdir(fd) else fd[:-3] 22 | fd = basename(fd) 23 | if fd.startswith("__"): 24 | continue 25 | __import__(f"{__name__}.{fd}") 26 | 27 | 28 | __all__ = ["keras", "tf", "env_util"] 29 | -------------------------------------------------------------------------------- /tensorflow_asr/abstracts.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from abc import ABC, abstractmethod 3 | 4 | from tensorflow_asr import tf 5 | 6 | 7 | class AbstractTokenizer(ABC): 8 | initialized: bool 9 | 10 | @abstractmethod 11 | def make(self): 12 | pass 13 | 14 | @abstractmethod 15 | def tokenize(self, text: str) -> tf.Tensor: 16 | pass 17 | 18 | @abstractmethod 19 | def detokenize(self, indices: tf.Tensor) -> tf.Tensor: 20 | pass 21 | 22 | @abstractmethod 23 | def prepand_blank(self, text: tf.Tensor) -> tf.Tensor: 24 | pass 25 | 26 | 27 | class AbstractDataset(ABC): 28 | name: str 29 | num_entries: int 30 | 31 | @abstractmethod 32 | def read_entries(self): 33 | pass 34 | 35 | @abstractmethod 36 | def generator(self) -> typing.Generator: 37 | pass 38 | 39 | @abstractmethod 40 | def vocab_generator(self) -> typing.Generator: 41 | pass 42 | -------------------------------------------------------------------------------- /tensorflow_asr/augmentations/README.md: -------------------------------------------------------------------------------- 1 | # Augmentations 2 | 3 | ```yaml 4 | augmentations: 5 | prob: 0.5 # a number between 0.0 and 1.0 that controls how often signal_augment and feature_augment are applied 6 | signal_augment: ... # augmentation on signal 7 | feature_augment: ... # augmentation on feature extracted from signal 8 | ``` 9 | 10 | ## Methods 11 | 12 | See [methods](./methods) 13 | 14 | Currently we have: 15 | - SpecAugment: Time Masking and Frequency Masking 16 | - GaussNoise: additive Gaussian noise on the raw signal 17 | 18 | Custom augmentation methods inherit from the `AugmentationMethod` class and must define the `augment` function, as in the sketch below. 
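For reference, here is a minimal sketch of such a custom method. It follows the `(inputs, inputs_length)` tuple convention used by the built-in methods; the class name `SignalGain` and its parameters are hypothetical, not part of the library:

```python
import tensorflow as tf

from tensorflow_asr.augmentations.methods.base_method import AugmentationMethod


class SignalGain(AugmentationMethod):
    """Hypothetical example: randomly scale the raw signal amplitude."""

    def __init__(self, min_gain: float = 0.8, max_gain: float = 1.2, prob: float = 0.5):
        super().__init__(prob=prob)
        self.min_gain = min_gain
        self.max_gain = max_gain

    def augment(self, args):
        inputs, inputs_length = args
        # Gate on `prob` inside `augment`, the same way the built-in GaussNoise does
        p = tf.random.uniform(shape=[], minval=0, maxval=1, dtype=tf.float32)
        gain = tf.random.uniform(shape=[], minval=self.min_gain, maxval=self.max_gain, dtype=inputs.dtype)
        gain = tf.where(tf.less_equal(p, self.prob), gain, tf.constant(1, inputs.dtype))
        return inputs * gain, inputs_length
```

To make a custom method usable from the YAML config, it also needs an entry in the `AUGMENTATIONS` dict in `tensorflow_asr/augmentations/augmentation.py`.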
-------------------------------------------------------------------------------- /tensorflow_asr/augmentations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/tensorflow_asr/augmentations/__init__.py -------------------------------------------------------------------------------- /tensorflow_asr/augmentations/augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import List 16 | 17 | from tensorflow_asr import tf 18 | from tensorflow_asr.augmentations.methods import gaussnoise, specaugment 19 | from tensorflow_asr.augmentations.methods.base_method import AugmentationMethod 20 | 21 | AUGMENTATIONS = { 22 | "gauss_noise": gaussnoise.GaussNoise, 23 | "freq_masking": specaugment.FreqMasking, 24 | "time_masking": specaugment.TimeMasking, 25 | } 26 | 27 | 28 | class Augmentation: 29 | def __init__(self, config: dict = None): 30 | _config = config or {} 31 | self.signal_augmentations = self.parse(_config.pop("signal_augment", {})) 32 | self.feature_augmentations = self.parse(_config.pop("feature_augment", {})) 33 | 34 | def _augment(self, inputs, augmentations: List[AugmentationMethod]): 35 | outputs = inputs 36 | for au in augmentations: 37 | outputs = au.augment(outputs) 38 | # p = tf.random.uniform(shape=[], dtype=tf.float32) 39 | # outputs = tf.cond(tf.less(p, au.prob), lambda: au.augment(outputs), lambda: outputs) 40 | return outputs 41 | 42 | def signal_augment(self, inputs, inputs_length): 43 | """ 44 | Augment audio signals 45 | 46 | Parameters 47 | ---------- 48 | inputs : tf.Tensor, shape [B, None] 49 | Original audio signals 50 | inputs_length : tf.Tensor, shape [B] 51 | Original audio signals length 52 | 53 | Returns 54 | ------- 55 | tf.Tensor, shape [B, None] 56 | Augmented audio signals 57 | """ 58 | return tf.map_fn( 59 | fn=lambda x: self._augment(x, self.signal_augmentations), 60 | elems=(inputs, inputs_length), 61 | fn_output_signature=( 62 | tf.TensorSpec.from_tensor(inputs[0]), 63 | tf.TensorSpec.from_tensor(inputs_length[0]), 64 | ), 65 | ) 66 | 67 | def feature_augment(self, inputs, inputs_length): 68 | """ 69 | Augment audio features 70 | 71 | Parameters 72 | ---------- 73 | inputs : tf.Tensor, shape [B, T, F] 74 | Original audio features 75 | inputs_length : tf.Tensor, shape [B] 76 | Original audio features length 77 | 78 | Returns 79 | ------- 80 | tf.Tensor, shape [B, T, F] 81 | Augmented audio features 82 | """ 83 | return tf.map_fn( 84 | fn=lambda x: self._augment(x, self.feature_augmentations), 85 | elems=(inputs, inputs_length), 86 | fn_output_signature=( 87 | tf.TensorSpec.from_tensor(inputs[0]), 88 | tf.TensorSpec.from_tensor(inputs_length[0]), 89 | ), 90 | ) 91 | 92 | @staticmethod 93 | def parse(config: dict) -> list: 94 | 
augmentations = [] 95 | for key, value in sorted(config.items(), key=lambda x: x[0]): 96 | au = AUGMENTATIONS.get(key, None) 97 | if au is None: 98 | raise KeyError(f"No tf augmentation named: {key}\n" f"Available tf augmentations: {AUGMENTATIONS.keys()}") 99 | aug = au(**value) if value is not None else au() 100 | augmentations.append(aug) 101 | return augmentations 102 | -------------------------------------------------------------------------------- /tensorflow_asr/augmentations/methods/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/tensorflow_asr/augmentations/methods/__init__.py -------------------------------------------------------------------------------- /tensorflow_asr/augmentations/methods/base_method.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | class AugmentationMethod: 17 | def __init__(self, prob: float = 0.5): 18 | self.prob = prob 19 | 20 | def augment(self, *args, **kwargs): 21 | raise NotImplementedError() 22 | -------------------------------------------------------------------------------- /tensorflow_asr/augmentations/methods/gaussnoise.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
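As a quick sanity check of the config flow above — `parse` turns each config key into a method instance, and `feature_augment`/`signal_augment` map those methods over a batch — here is a minimal driving sketch (shapes and parameter values are illustrative, mirroring the `freq_masking` options used in the example configs):

```python
import tensorflow as tf

from tensorflow_asr.augmentations.augmentation import Augmentation

aug = Augmentation(
    {
        "feature_augment": {
            "freq_masking": {"prob": 1.0, "num_masks": 1, "mask_factor": 27},
        }
    }
)
features = tf.random.normal([2, 100, 80])  # [B, T, F] log-mel-like features
features_length = tf.constant([100, 80])   # true lengths before padding
features, features_length = aug.feature_augment(features, features_length)
print(features.shape)  # (2, 100, 80), with one frequency band masked per utterance
```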
14 | 15 | from tensorflow_asr import tf 16 | from tensorflow_asr.augmentations.methods.base_method import AugmentationMethod 17 | 18 | 19 | class GaussNoise(AugmentationMethod): 20 | def __init__( 21 | self, 22 | mean: float = 0.0, 23 | stddev: float = 0.075, 24 | prob: float = 0.5, 25 | ): 26 | super().__init__(prob=prob) 27 | self.mean = mean 28 | self.stddev = stddev 29 | 30 | def augment(self, args): 31 | inputs, inputs_length = args 32 | prob = tf.random.uniform(shape=[], minval=0, maxval=1, dtype=tf.float32) 33 | do_apply = tf.where(tf.less_equal(prob, self.prob), tf.constant(1, inputs.dtype), tf.constant(0, inputs.dtype)) 34 | noise = tf.random.normal(shape=tf.shape(inputs), mean=self.mean, stddev=self.stddev, dtype=inputs.dtype) 35 | noise *= tf.sequence_mask(inputs_length, inputs.shape[1], dtype=inputs.dtype) 36 | noise *= do_apply 37 | return tf.add(inputs, noise), inputs_length 38 | -------------------------------------------------------------------------------- /tensorflow_asr/features/README.md: -------------------------------------------------------------------------------- 1 | # Feature Extraction 2 | 3 | ## Speech Features 4 | 5 | **Speech features** are extracted from the **Signal** with `sample_rate`, `frame_ms`, `stride_ms` and `num_feature_bins`. 6 | 7 | Speech features have the shape `(B, T, num_feature_bins, num_channels)` and contain 1-4 channels: 8 | 9 | 1. Spectrogram, Log Mel Spectrogram, Log Gammatone Spectrogram or MFCCs 10 | 2. Delta features: `librosa.feature.delta` from the features extracted on channel 1. 11 | 3. Delta-delta features: `librosa.feature.delta` with `order=2` from the features extracted on channel 1. 12 | 4. Pitch features: `librosa.core.piptrack` from the signal 13 | 14 | There are 2 classes for speech feature extraction: `NumpySpeechFeaturizer` (uses `librosa`) and `TFSpeechFeaturizer` (uses `tf.signal`). The TF-based class does not support `delta, delta_delta, pitch` features yet. 15 | 16 | _Note_: the `TFSpeechFeaturizer` class **should be used** if you want to deploy with `tflite`. 17 | 18 | ![Spectrogram](./figs/spectrogram.png) 19 | 20 | ![Log Mel Spectrogram](./figs/log_mel_spectrogram.png) 21 | 22 | ![MFCCs](./figs/mfcc.png) 23 | 24 | ![Log Gammatone Spectrogram](./figs/log_gammatone_spectrogram.png) 25 | 26 | ## Text Features 27 | 28 | **Text features** are read as indices from a vocabulary file such as the default `tensorflow_asr.featurizers.english.txt`, plus 1 for the blank index. 29 | 30 | The **blank** index is either `0` or `num_classes - 1`, where `num_classes` is the number of characters in your language (excluding blank). 31 | 32 | Class `TextFeaturizer` is initialized with the `decoder_config` as follows: 33 | 34 | ```yaml 35 | decoder_config: 36 | vocabulary: path_to_vocab_txt 37 | blank_at_zero: bool 38 | beam_width: int 39 | lm_config: ... 
40 | ``` 41 | 42 | ## TODO 43 | 44 | - Implement `TFSpeechFeaturizer` to extract `delta, delta_delta, pitch` features 45 | -------------------------------------------------------------------------------- /tensorflow_asr/features/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/tensorflow_asr/features/__init__.py -------------------------------------------------------------------------------- /tensorflow_asr/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/tensorflow_asr/losses/__init__.py -------------------------------------------------------------------------------- /tensorflow_asr/losses/base_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from tensorflow_asr import keras, schemas, tf 16 | from tensorflow_asr.utils import env_util 17 | 18 | logger = tf.get_logger() 19 | 20 | 21 | class BaseLoss(keras.losses.Loss): 22 | def __init__(self, blank=0, reduction="sum_over_batch_size", name=None): 23 | super().__init__(reduction=reduction, name=name) 24 | assert blank == 0, "Only support blank=0" 25 | self.blank = blank 26 | self.use_tpu = env_util.has_devices("TPU") 27 | 28 | def call( 29 | self, 30 | y_true: schemas.TrainLabel, 31 | y_pred: schemas.TrainOutput, 32 | ): 33 | logit_length = tf.cast(y_pred.logits_length, tf.int32) 34 | labels = tf.cast(y_true.labels, tf.int32) 35 | label_length = tf.cast(y_true.labels_length, tf.int32) 36 | logit_length = tf.where(tf.less(logit_length, label_length), label_length, logit_length) # pad logit_length to label_length 37 | return y_pred.logits, logit_length, labels, label_length 38 | 39 | def get_config(self): 40 | config = super().get_config() 41 | config.update({"blank": self.blank}) 42 | return config 43 | -------------------------------------------------------------------------------- /tensorflow_asr/losses/ctc_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Copyright 2021 Alexey Tochin 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | # ============================================================================== 29 | 30 | import logging 31 | import os 32 | 33 | from tensorflow_asr import tf 34 | from tensorflow_asr.losses.base_loss import BaseLoss 35 | from tensorflow_asr.losses.impl.ctc_tpu import ctc_loss_tpu 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | TFASR_USE_TF_CTC = os.getenv("TFASR_USE_TF_CTC", "False") in ("true", "True", "1") 40 | 41 | 42 | class CtcLoss(BaseLoss): 43 | def __init__(self, blank=0, reduction="sum_over_batch_size", name=None): 44 | super().__init__(blank=blank, reduction=reduction, name=name) 45 | logger.info("Use CTC loss TPU implementation" if self.use_tpu and not TFASR_USE_TF_CTC else "Use CTC loss") 46 | 47 | def call(self, y_true, y_pred): 48 | logits, logit_length, labels, label_length = super().call(y_true, y_pred) 49 | if self.use_tpu and not TFASR_USE_TF_CTC: 50 | return ctc_loss_tpu( 51 | labels=labels, 52 | logits=logits, 53 | label_length=label_length, 54 | logit_length=logit_length, 55 | blank_index=self.blank, 56 | ) 57 | return tf.nn.ctc_loss( 58 | logits=logits, 59 | logit_length=logit_length, 60 | labels=labels, 61 | label_length=label_length, 62 | logits_time_major=False, 63 | unique=tf.nn.ctc_unique_labels(labels) if self.use_tpu else None, 64 | blank_index=self.blank, 65 | name=self.name, 66 | ) 67 | -------------------------------------------------------------------------------- /tensorflow_asr/losses/impl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/tensorflow_asr/losses/impl/__init__.py -------------------------------------------------------------------------------- /tensorflow_asr/losses/rnnt_loss.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-name-in-module,unexpected-keyword-arg,no-value-for-parameter 2 | # Copyright 2020 Huy Le Nguyen (@nglehuy) and M. Yusuf Sarıgöz (@monatis) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
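Before the RNNT loss implementation below, a minimal sketch of driving the `CtcLoss` defined above directly (dummy shapes; note the environment variable is read at import time, so it must be set before the package is imported):

```python
import os

os.environ["TFASR_USE_TF_CTC"] = "1"  # force tf.nn.ctc_loss even on TPU

import tensorflow as tf

from tensorflow_asr import schemas
from tensorflow_asr.losses.ctc_loss import CtcLoss

loss_fn = CtcLoss(blank=0)
# Dummy batch: 2 utterances, 50 frames, 29 output classes, labels drawn from [1, 29)
y_pred = schemas.TrainOutput(
    logits=tf.random.normal([2, 50, 29]),
    logits_length=tf.constant([50, 42]),
)
y_true = schemas.TrainLabel(
    labels=tf.random.uniform([2, 10], minval=1, maxval=29, dtype=tf.int32),
    labels_length=tf.constant([10, 7]),
)
print(loss_fn.call(y_true, y_pred))  # per-utterance CTC loss, shape (2,)
```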
15 | # RNNT loss implementation in pure TensorFlow is borrowed from [iamjanvijay's repo](https://github.com/iamjanvijay/rnnt) 16 | 17 | 18 | import logging 19 | import os 20 | 21 | from tensorflow_asr.losses.base_loss import BaseLoss 22 | from tensorflow_asr.losses.impl.rnnt import rnnt_loss, warp_rnnt_loss 23 | from tensorflow_asr.utils import env_util 24 | 25 | TFASR_USE_CPU_LOSS = os.getenv("TFASR_USE_CPU_LOSS", "False") in ("true", "True", "1") 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | class RnntLoss(BaseLoss): 31 | def __init__( 32 | self, 33 | blank, 34 | reduction="sum_over_batch_size", 35 | output_shapes=None, 36 | name=None, 37 | ): 38 | super().__init__(blank=blank, reduction=reduction, name=name) 39 | self.use_cpu = TFASR_USE_CPU_LOSS or (not env_util.has_devices("GPU") and not env_util.has_devices("TPU")) 40 | self.output_shapes = output_shapes 41 | # fmt: off 42 | logger.info(f"[RNNT loss] Use {'CPU' if self.use_cpu else 'GPU/TPU'} implementation in {'Tensorflow' if warp_rnnt_loss is None else 'WarpRNNT'}") # pylint: disable=line-too-long 43 | # fmt: on 44 | if self.output_shapes: 45 | logger.info(f"[RNNT loss] Use model's output shapes: {self.output_shapes}") 46 | if not all(self.output_shapes): 47 | logger.info("[RNNT loss] Detected dynamic shape") 48 | self.output_shapes = None 49 | 50 | def call(self, y_true, y_pred): 51 | logits, logit_length, labels, label_length = super().call(y_true, y_pred) 52 | return rnnt_loss( 53 | logits=logits, 54 | logits_length=logit_length, 55 | labels=labels, 56 | labels_length=label_length, 57 | blank=self.blank, 58 | name=self.name, 59 | use_cpu=self.use_cpu, 60 | output_shapes=self.output_shapes, 61 | ) 62 | 63 | def get_config(self): 64 | conf = super().get_config() 65 | conf.update({"output_shapes": self.output_shapes}) 66 | return conf 67 | -------------------------------------------------------------------------------- /tensorflow_asr/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/tensorflow_asr/metrics/__init__.py -------------------------------------------------------------------------------- /tensorflow_asr/metrics/error_rates.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from tensorflow_asr import keras, tf 16 | 17 | 18 | class ErrorRate(keras.metrics.Metric): 19 | """Metric for WER or CER""" 20 | 21 | def __init__(self, name="error_rate", **kwargs): 22 | super().__init__(name=name, **kwargs) 23 | self.numerator = self.add_weight(name="numerator", initializer="zeros") 24 | self.denominator = self.add_weight(name="denominator", initializer="zeros") 25 | 26 | def update_state(self, data): 27 | numer, denom = data 28 | self.numerator.assign_add(tf.reduce_sum(numer)) 29 | self.denominator.assign_add(tf.reduce_sum(denom)) 30 | 31 | def result(self): 32 | return tf.math.divide(self.numerator, self.denominator) 33 | -------------------------------------------------------------------------------- /tensorflow_asr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from os.path import basename, dirname, isdir, isfile, join 3 | 4 | for fd in glob.glob(join(dirname(__file__), "*")): 5 | if not isfile(fd) and not isdir(fd): 6 | continue 7 | if isfile(fd) and not fd.endswith(".py"): 8 | continue 9 | fd = fd if isdir(fd) else fd[:-3] 10 | fd = basename(fd) 11 | if fd.startswith("__"): 12 | continue 13 | __import__(f"{__name__}.{fd}") 14 | -------------------------------------------------------------------------------- /tensorflow_asr/models/activations/__init__.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from os.path import basename, dirname, isdir, isfile, join 3 | 4 | for fd in glob.glob(join(dirname(__file__), "*")): 5 | if not isfile(fd) and not isdir(fd): 6 | continue 7 | if isfile(fd) and not fd.endswith(".py"): 8 | continue 9 | fd = fd if isdir(fd) else fd[:-3] 10 | fd = basename(fd) 11 | if fd.startswith("__"): 12 | continue 13 | __import__(f"{__name__}.{fd}") 14 | -------------------------------------------------------------------------------- /tensorflow_asr/models/activations/glu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
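A short usage sketch for the `ErrorRate` metric above — it accumulates a global numerator/denominator pair instead of averaging per-batch rates, so `result()` is a corpus-level rate (numbers here are illustrative):

```python
import tensorflow as tf

from tensorflow_asr.metrics.error_rates import ErrorRate

wer = ErrorRate(name="wer")
# Each update takes (numerator, denominator), e.g. per-utterance word edit
# distances and reference word counts for one batch:
wer.update_state((tf.constant([3.0, 1.0]), tf.constant([10.0, 8.0])))
wer.update_state((tf.constant([0.0, 2.0]), tf.constant([5.0, 9.0])))
print(float(wer.result()))  # (3+1+0+2) / (10+8+5+9) = 6/32 = 0.1875
```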
14 | 15 | from tensorflow_asr import keras, tf 16 | from tensorflow_asr.models.base_layer import Layer 17 | 18 | 19 | @keras.utils.register_keras_serializable(package=__name__) 20 | class GLU(Layer): 21 | def __init__(self, axis=-1, name="glu", **kwargs): 22 | super().__init__(name=name, **kwargs) 23 | self.axis = axis 24 | 25 | def call(self, inputs): 26 | a, b = tf.split(inputs, 2, axis=self.axis) 27 | b = tf.nn.sigmoid(b) 28 | return tf.multiply(a, b) 29 | 30 | def compute_output_shape(self, input_shape): 31 | B, T, V = input_shape 32 | return (B, T, V // 2) 33 | -------------------------------------------------------------------------------- /tensorflow_asr/models/base_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from tensorflow_asr import keras 16 | from tensorflow_asr.utils import math_util 17 | 18 | 19 | @keras.utils.register_keras_serializable(package=__name__) 20 | class Layer(keras.layers.Layer): 21 | def __init__( 22 | self, 23 | trainable=True, 24 | name=None, 25 | dtype=None, 26 | **kwargs, 27 | ): 28 | super().__init__(trainable=trainable, name=name, dtype=dtype, **kwargs) 29 | self.supports_masking = True 30 | 31 | def compute_output_shape(self, input_shape): 32 | return input_shape 33 | 34 | 35 | @keras.utils.register_keras_serializable(package=__name__) 36 | class Reshape(Layer): 37 | def call(self, inputs): 38 | outputs, outputs_length = inputs 39 | outputs = math_util.merge_two_last_dims(outputs) 40 | return outputs, outputs_length 41 | 42 | def compute_output_shape(self, input_shape): 43 | output_shape, output_length_shape = input_shape 44 | output_shape = output_shape[:2] + (output_shape[2] * output_shape[3],) 45 | return output_shape, output_length_shape 46 | -------------------------------------------------------------------------------- /tensorflow_asr/models/ctc/__init__.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from os.path import basename, dirname, isdir, isfile, join 3 | 4 | for fd in glob.glob(join(dirname(__file__), "*")): 5 | if not isfile(fd) and not isdir(fd): 6 | continue 7 | if isfile(fd) and not fd.endswith(".py"): 8 | continue 9 | fd = fd if isdir(fd) else fd[:-3] 10 | fd = basename(fd) 11 | if fd.startswith("__"): 12 | continue 13 | __import__(f"{__name__}.{fd}") 14 | -------------------------------------------------------------------------------- /tensorflow_asr/models/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from os.path import basename, dirname, isdir, isfile, join 3 | 4 | for fd in glob.glob(join(dirname(__file__), "*")): 5 | if not isfile(fd) and not isdir(fd): 6 | continue 7 | if isfile(fd) and not fd.endswith(".py"): 8 | continue 9 | fd = fd if isdir(fd) else fd[:-3] 10 | fd = basename(fd) 11 | if fd.startswith("__"): 12 
| continue 13 | __import__(f"{__name__}.{fd}") 14 | -------------------------------------------------------------------------------- /tensorflow_asr/models/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from os.path import basename, dirname, isdir, isfile, join 3 | 4 | for fd in glob.glob(join(dirname(__file__), "*")): 5 | if not isfile(fd) and not isdir(fd): 6 | continue 7 | if isfile(fd) and not fd.endswith(".py"): 8 | continue 9 | fd = fd if isdir(fd) else fd[:-3] 10 | fd = basename(fd) 11 | if fd.startswith("__"): 12 | continue 13 | __import__(f"{__name__}.{fd}") 14 | -------------------------------------------------------------------------------- /tensorflow_asr/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from os.path import basename, dirname, isdir, isfile, join 3 | 4 | for fd in glob.glob(join(dirname(__file__), "*")): 5 | if not isfile(fd) and not isdir(fd): 6 | continue 7 | if isfile(fd) and not fd.endswith(".py"): 8 | continue 9 | fd = fd if isdir(fd) else fd[:-3] 10 | fd = basename(fd) 11 | if fd.startswith("__"): 12 | continue 13 | __import__(f"{__name__}.{fd}") 14 | -------------------------------------------------------------------------------- /tensorflow_asr/models/layers/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
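Ahead of the embedding layer below: it wraps `keras.layers.Embedding` so that `(tokens, lengths)` tuples flow through unchanged, which is the convention across these layers. A minimal shape sketch (values illustrative):

```python
import tensorflow as tf

from tensorflow_asr.models.layers.embedding import Embedding

embed = Embedding(vocab_size=29, embed_dim=8)
tokens = tf.constant([[1, 5, 7, 0]])  # [B, U], zero-padded token ids
embedded, length = embed((tokens, tf.constant([3])))
print(embedded.shape)  # (1, 4, 8): one 8-dim vector per token position
```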
14 | 15 | from tensorflow_asr import keras, tf 16 | from tensorflow_asr.models.base_layer import Layer 17 | 18 | 19 | @keras.utils.register_keras_serializable(package=__name__) 20 | class Embedding(keras.layers.Embedding): 21 | def __init__( 22 | self, 23 | vocab_size, 24 | embed_dim, 25 | initializer="uniform", 26 | regularizer=None, 27 | contraint=None, 28 | **kwargs, 29 | ): 30 | super().__init__( 31 | input_dim=vocab_size, 32 | output_dim=embed_dim, 33 | embeddings_initializer=initializer, 34 | embeddings_regularizer=regularizer, 35 | embeddings_constraint=contraint, 36 | mask_zero=False, 37 | **kwargs, 38 | ) 39 | self.supports_masking = True 40 | 41 | def call(self, inputs): 42 | outputs, outputs_length = inputs 43 | outputs = super().call(outputs) 44 | return outputs, outputs_length 45 | 46 | def call_next(self, inputs): 47 | outputs = tf.cast(tf.expand_dims(inputs, axis=-1), dtype=tf.int32) 48 | return tf.gather_nd(self.embeddings, outputs) # https://github.com/tensorflow/tensorflow/issues/42410 49 | 50 | def compute_mask(self, inputs, mask=None): 51 | outputs, outputs_length = inputs 52 | mask = tf.sequence_mask(outputs_length, maxlen=tf.shape(outputs)[1], dtype=tf.bool) 53 | return mask, None 54 | 55 | def compute_output_shape(self, input_shape): 56 | output_shape, output_length_shape = input_shape 57 | output_shape = super().compute_output_shape(output_shape) 58 | return output_shape, output_length_shape 59 | 60 | 61 | @keras.utils.register_keras_serializable(package=__name__) 62 | class OneHotBlank(Layer): 63 | """ 64 | https://arxiv.org/pdf/1211.3711.pdf 65 | The inputs are encoded as one-hot vectors; 66 | that is, if Y consists of K labels and yu = k, then y^u is a length K vector whose elements are all zero 67 | except the k-th, which is one. 
∅ is encoded as a length K vector of zeros 68 | """ 69 | 70 | def __init__(self, blank, depth, name="one_hot_blank", **kwargs): 71 | super().__init__(name=name, **kwargs) 72 | self.blank = blank 73 | self.depth = depth 74 | 75 | def call(self, inputs): 76 | outputs, outputs_length = inputs 77 | minus_one_at_blank = tf.where(tf.equal(outputs, self.blank), -1, outputs) 78 | outputs = tf.one_hot(minus_one_at_blank, depth=self.depth, dtype=self.dtype) 79 | return outputs, outputs_length 80 | 81 | def call_next(self, inputs): 82 | outputs, _ = self.call((inputs, None)) 83 | return outputs 84 | 85 | def compute_mask(self, inputs, mask=None): 86 | outputs, outputs_length = inputs 87 | mask = tf.sequence_mask(outputs_length, maxlen=tf.shape(outputs)[1], dtype=tf.bool) 88 | return mask, None 89 | 90 | def compute_output_shape(self, input_shape): 91 | output_shape, output_length_shape = input_shape 92 | output_shape = output_shape + (self.depth,) 93 | return output_shape, output_length_shape 94 | -------------------------------------------------------------------------------- /tensorflow_asr/models/layers/general.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from keras.src import activations, backend 3 | 4 | from tensorflow_asr.utils import math_util 5 | 6 | 7 | class Dropout(keras.layers.Dropout): 8 | def __init__(self, rate, noise_shape=None, seed=None, **kwargs): 9 | super().__init__(rate, noise_shape, seed, **kwargs) 10 | self.built = False 11 | 12 | 13 | class Identity(keras.layers.Identity): 14 | def __init__(self, **kwargs): 15 | super().__init__(**kwargs) 16 | self.built = False 17 | 18 | 19 | class Activation(keras.layers.Activation): 20 | def __init__(self, activation, **kwargs): 21 | super().__init__(activation, **kwargs) 22 | self.built = False 23 | 24 | 25 | class Softmax(keras.layers.Softmax): 26 | """ 27 | Softmax activation layer with better numerical stability to avoid Inf or NaN 28 | """ 29 | 30 | def call(self, inputs, mask=None): 31 | if mask is not None: 32 | inputs = math_util.masked_fill( 33 | inputs, 34 | mask=mask, 35 | value=math_util.large_compatible_negative_number(self.dtype), 36 | ) 37 | if isinstance(self.axis, (tuple, list)): 38 | if len(self.axis) > 1: 39 | return backend.numpy.exp(inputs - backend.math.logsumexp(inputs, axis=self.axis, keepdims=True)) 40 | return activations.softmax(inputs, axis=self.axis[0]) 41 | return activations.softmax(inputs, axis=self.axis) 42 | -------------------------------------------------------------------------------- /tensorflow_asr/models/layers/memory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
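Before the memory layer below, a concrete check of the blank-as-zeros encoding implemented by `OneHotBlank` above (blank index 0, `depth` = number of real labels K):

```python
import tensorflow as tf

from tensorflow_asr.models.layers.embedding import OneHotBlank

layer = OneHotBlank(blank=0, depth=3)
tokens = tf.constant([[0, 1, 2]])  # [B, U]; 0 is the blank token
onehot, _ = layer((tokens, tf.constant([3])))
print(onehot.numpy())
# [[[0. 0. 0.]    <- blank maps to the all-zeros vector
#   [0. 1. 0.]    <- label 1 -> one-hot at position 1
#   [0. 0. 1.]]]  <- label 2 -> one-hot at position 2
```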
14 | 15 | from keras.src import backend 16 | 17 | from tensorflow_asr import keras, tf 18 | from tensorflow_asr.models.base_layer import Layer 19 | from tensorflow_asr.utils import math_util 20 | 21 | 22 | def _create_num_masked(tensor_mask): 23 | return tf.vectorized_map(lambda x: math_util.count(tf.cast(x, tf.int32), value=0), elems=tensor_mask, warn=False) 24 | 25 | 26 | def _shift(tensor, shift): 27 | shifted_tensor, _ = tf.vectorized_map(lambda x: (tf.roll(x[0], shift=x[1], axis=0), x[1]), elems=(tensor, shift), warn=False) 28 | return shifted_tensor 29 | 30 | 31 | @keras.utils.register_keras_serializable(package=__name__) 32 | class Memory(Layer): 33 | """ 34 | Memory Layer 35 | This layer `call` method will do 2 things: 36 | 1. prepend memory hidden states to inputs -> new_inputs 37 | 2. concatenating memory and inputs, then slice to memory length -> new_memory 38 | """ 39 | 40 | def __init__(self, memory_length, dmodel, **kwargs): 41 | super().__init__(trainable=False, **kwargs) 42 | assert memory_length > 0, "memory_length must be integer" 43 | self.memory_length = memory_length 44 | self.dmodel = dmodel 45 | 46 | def _get_inputs(self, inputs, default_mask_value=1): 47 | inputs_mask = backend.get_keras_mask(inputs) 48 | if inputs_mask is None: 49 | batch_size, max_length, *_ = tf.shape(inputs) 50 | inputs_mask = tf.cast(tf.ones((batch_size, max_length), dtype=tf.int32) * default_mask_value, dtype=tf.bool) 51 | return inputs, inputs_mask 52 | 53 | def get_initial_state(self, batch_size: int): 54 | memory = tf.zeros(shape=(batch_size, self.memory_length, self.dmodel), dtype=self.dtype) 55 | backend.set_keras_mask(memory, tf.zeros(shape=(batch_size, self.memory_length), dtype=tf.bool)) 56 | return memory 57 | 58 | def call(self, inputs, memories=None, training=False): 59 | if memories is None: 60 | return None 61 | inputs, inputs_mask = self._get_inputs(inputs) 62 | memory, memory_mask = self._get_inputs(memories) 63 | # create new_inputs by prepending memory to inputs 64 | if training: 65 | memory = tf.stop_gradient(memory) 66 | memory_mask = tf.stop_gradient(memory_mask) 67 | new_inputs = tf.concat([memory, inputs], 1) # prepend memory and inputs 68 | new_inputs_mask = tf.concat([memory_mask, inputs_mask], 1) 69 | new_inputs._keras_mask = new_inputs_mask # pylint: disable=protected-access 70 | # create new_memory by slicing new_inputs to memory length 71 | new_memory = tf.slice( 72 | new_inputs, 73 | begin=[0, tf.shape(new_inputs)[1] - self.memory_length, 0], 74 | size=[-1, self.memory_length, -1], 75 | ) 76 | new_memory_mask = tf.slice( 77 | new_inputs_mask, 78 | begin=[0, tf.shape(new_inputs_mask)[1] - self.memory_length], 79 | size=[-1, self.memory_length], 80 | ) 81 | new_memory._keras_mask = new_memory_mask # pylint: disable=protected-access 82 | return new_inputs, new_memory 83 | 84 | def compute_output_shape(self, input_shape): 85 | return input_shape, (input_shape[0], self.memory_length, self.dmodel) 86 | -------------------------------------------------------------------------------- /tensorflow_asr/models/layers/residual.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
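A shape-level sketch of the streaming `Memory` layer above (values illustrative): each call returns the memory-prepended inputs to attend over, plus the trailing slice to carry into the next chunk:

```python
import tensorflow as tf

from tensorflow_asr.models.layers.memory import Memory

mem = Memory(memory_length=4, dmodel=8)
state = mem.get_initial_state(batch_size=2)  # zeros, fully masked out
chunk = tf.random.normal([2, 6, 8])          # current segment [B, T, dmodel]
new_inputs, state = mem(chunk, memories=state)
print(new_inputs.shape)  # (2, 10, 8): 4 memory frames prepended to 6 new ones
print(state.shape)       # (2, 4, 8): last 4 frames kept for the next chunk
```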
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Optional 16 | 17 | from tensorflow_asr import keras, tf 18 | from tensorflow_asr.models.base_layer import Layer 19 | 20 | 21 | @keras.utils.register_keras_serializable(package=__name__) 22 | class Residual(Layer): 23 | """Applying residual addition to layers 24 | - Normal addition with constant factor 25 | - Rezero: which improves convergence speed. This implements the paper: 26 | ReZero is All You Need: Fast Convergence at Large Depth. 27 | (https://arxiv.org/pdf/2003.04887.pdf). 28 | """ 29 | 30 | def __init__( 31 | self, 32 | factor="rezero", 33 | initializer: keras.initializers.Initializer = "zeros", 34 | regularizer: Optional[keras.regularizers.Regularizer] = None, 35 | name="residual", 36 | **kwargs, 37 | ): 38 | super().__init__(name=name, trainable=False, **kwargs) 39 | self._factor = factor 40 | self._initializer = initializer 41 | self._regularizer = regularizer 42 | 43 | def build(self, input_shape): 44 | if self._factor == "rezero": 45 | self._alpha = self.add_weight( 46 | name="alpha", 47 | shape=[], 48 | initializer=self._initializer, 49 | regularizer=self._regularizer, 50 | trainable=True, 51 | dtype=self.variable_dtype, 52 | ) 53 | else: 54 | assert isinstance(self._factor, (int, float)) 55 | self._alpha = self._factor 56 | return super().build(input_shape) 57 | 58 | def call(self, inputs): 59 | x, residual_x = inputs 60 | alpha = tf.cast(tf.convert_to_tensor(self._alpha, dtype=self.dtype), residual_x.dtype) 61 | x = x + alpha * residual_x 62 | return x 63 | 64 | def compute_output_shape(self, input_shape): 65 | return input_shape[0] 66 | -------------------------------------------------------------------------------- /tensorflow_asr/models/layers/sequence_wise_bn.py: -------------------------------------------------------------------------------- 1 | # pylint:disable=attribute-defined-outside-init 2 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
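A tiny check of the ReZero behaviour above — with the default `zeros` initializer the learned alpha starts at 0, so every residual branch begins as an identity map and only gradually switches on during training:

```python
import tensorflow as tf

from tensorflow_asr.models.layers.residual import Residual

res = Residual(factor="rezero")
x = tf.random.normal([2, 10, 16])   # main path
fx = tf.random.normal([2, 10, 16])  # residual branch output
y = res((x, fx))                    # y = x + alpha * fx
print(bool(tf.reduce_all(tf.equal(y, x))))  # True: alpha == 0 at initialization
```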
15 | 16 | from tensorflow_asr import keras, tf 17 | 18 | 19 | # https://arxiv.org/abs/1510.01378 20 | class SequenceBatchNorm(keras.layers.Layer): 21 | def __init__(self, name, time_major=False, gamma_regularizer=None, beta_regularizer=None, **kwargs): 22 | super().__init__(name=name, **kwargs) 23 | self.time_major = time_major 24 | self.gamma_regularizer = keras.regularizers.get(gamma_regularizer) 25 | self.beta_regularizer = keras.regularizers.get(beta_regularizer) 26 | 27 | def build( 28 | self, 29 | input_shape, 30 | ): 31 | self.beta = self.add_weight( 32 | shape=[input_shape[-1]], 33 | name="beta", 34 | initializer="zeros", 35 | regularizer=self.beta_regularizer, 36 | constraint=None, 37 | trainable=True, 38 | dtype=self.variable_dtype, 39 | ) 40 | self.gamma = self.add_weight( 41 | shape=[input_shape[-1]], 42 | name="gamma", 43 | initializer="ones", 44 | regularizer=self.gamma_regularizer, 45 | constraint=None, 46 | trainable=True, 47 | dtype=self.variable_dtype, 48 | ) 49 | 50 | def call( 51 | self, 52 | inputs, 53 | **kwargs, 54 | ): 55 | mean, variance = tf.nn.moments(inputs, axes=[0, 1], keepdims=False) 56 | if self.time_major: 57 | total_padded_frames = tf.cast(tf.shape(inputs)[0], keras.backend.dtype(mean)) 58 | batch_size = tf.cast(tf.shape(inputs)[1], keras.backend.dtype(mean)) 59 | else: 60 | total_padded_frames = tf.cast(tf.shape(inputs)[1], keras.backend.dtype(mean)) 61 | batch_size = tf.cast(tf.shape(inputs)[0], keras.backend.dtype(mean)) 62 | total_unpadded_frames_batch = tf.math.count_nonzero(inputs, axis=[0, 1], keepdims=False, dtype=keras.backend.dtype(mean)) 63 | mean = (mean * total_padded_frames * batch_size) / total_unpadded_frames_batch 64 | variance = (variance * total_padded_frames * batch_size) / total_unpadded_frames_batch 65 | return tf.nn.batch_normalization( 66 | inputs, 67 | mean=mean, 68 | variance=variance, 69 | offset=self.beta, 70 | scale=self.gamma, 71 | variance_epsilon=keras.backend.epsilon(), 72 | ) 73 | -------------------------------------------------------------------------------- /tensorflow_asr/models/transducer/__init__.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from os.path import basename, dirname, isdir, isfile, join 3 | 4 | for fd in glob.glob(join(dirname(__file__), "*")): 5 | if not isfile(fd) and not isdir(fd): 6 | continue 7 | if isfile(fd) and not fd.endswith(".py"): 8 | continue 9 | fd = fd if isdir(fd) else fd[:-3] 10 | fd = basename(fd) 11 | if fd.startswith("__"): 12 | continue 13 | __import__(f"{__name__}.{fd}") 14 | -------------------------------------------------------------------------------- /tensorflow_asr/models/transducer/contextnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
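For intuition on the statistics correction in `SequenceBatchNorm` above: `tf.nn.moments` averages over all B·T positions including zero padding, and rescaling by (padded count / non-zero count) recovers the statistics over real frames, since padded zeros contribute nothing to the sum. A standalone sketch of the mean case:

```python
import tensorflow as tf

# [B=1, T=4, F=1]: two real frames (2.0, 4.0) and two zero-padded frames
x = tf.constant([[[2.0], [4.0], [0.0], [0.0]]])
naive_mean = tf.nn.moments(x, axes=[0, 1])[0]                      # (2+4)/4 = 1.5
nonzero = tf.math.count_nonzero(x, axis=[0, 1], dtype=tf.float32)  # 2 real frames
corrected = naive_mean * (4.0 * 1.0) / nonzero                     # 1.5 * 4 / 2 = 3.0
print(float(corrected))  # 3.0, the mean over the real frames only
```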
14 | 15 | from typing import List 16 | 17 | from tensorflow_asr import keras 18 | from tensorflow_asr.models.encoders.contextnet import L2, ContextNetEncoder 19 | from tensorflow_asr.models.transducer.base_transducer import Transducer 20 | 21 | 22 | @keras.utils.register_keras_serializable(package=__name__) 23 | class ContextNet(Transducer): 24 | def __init__( 25 | self, 26 | blank: int, 27 | vocab_size: int, 28 | speech_config: dict, 29 | encoder_blocks: List[dict], 30 | encoder_alpha: float = 0.5, 31 | encoder_trainable: bool = True, 32 | prediction_label_encode_mode: str = "embedding", 33 | prediction_embed_dim: int = 512, 34 | prediction_num_rnns: int = 1, 35 | prediction_rnn_units: int = 320, 36 | prediction_rnn_type: str = "lstm", 37 | prediction_rnn_implementation: int = 2, 38 | prediction_rnn_unroll: bool = False, 39 | prediction_layer_norm: bool = True, 40 | prediction_projection_units: int = 0, 41 | prediction_trainable: bool = True, 42 | joint_dim: int = 1024, 43 | joint_activation: str = "tanh", 44 | prejoint_encoder_linear: bool = True, 45 | prejoint_prediction_linear: bool = True, 46 | postjoint_linear: bool = False, 47 | joint_mode: str = "add", 48 | joint_trainable: bool = True, 49 | kernel_regularizer=L2, 50 | bias_regularizer=None, 51 | name: str = "contextnet", 52 | **kwargs, 53 | ): 54 | super().__init__( 55 | speech_config=speech_config, 56 | encoder=ContextNetEncoder( 57 | blocks=encoder_blocks, 58 | alpha=encoder_alpha, 59 | kernel_regularizer=kernel_regularizer, 60 | bias_regularizer=bias_regularizer, 61 | trainable=encoder_trainable, 62 | name="encoder", 63 | ), 64 | blank=blank, 65 | vocab_size=vocab_size, 66 | prediction_label_encoder_mode=prediction_label_encode_mode, 67 | prediction_embed_dim=prediction_embed_dim, 68 | prediction_num_rnns=prediction_num_rnns, 69 | prediction_rnn_units=prediction_rnn_units, 70 | prediction_rnn_type=prediction_rnn_type, 71 | prediction_rnn_implementation=prediction_rnn_implementation, 72 | prediction_rnn_unroll=prediction_rnn_unroll, 73 | prediction_layer_norm=prediction_layer_norm, 74 | prediction_trainable=prediction_trainable, 75 | prediction_projection_units=prediction_projection_units, 76 | joint_dim=joint_dim, 77 | joint_activation=joint_activation, 78 | prejoint_encoder_linear=prejoint_encoder_linear, 79 | prejoint_prediction_linear=prejoint_prediction_linear, 80 | postjoint_linear=postjoint_linear, 81 | joint_mode=joint_mode, 82 | joint_trainable=joint_trainable, 83 | kernel_regularizer=kernel_regularizer, 84 | bias_regularizer=bias_regularizer, 85 | name=name, 86 | **kwargs, 87 | ) 88 | self.dmodel = self.encoder.blocks[-1].dmodel 89 | self.time_reduction_factor = 1 90 | for block in self.encoder.blocks: 91 | self.time_reduction_factor *= block.time_reduction_factor 92 | -------------------------------------------------------------------------------- /tensorflow_asr/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from os.path import basename, dirname, isdir, isfile, join 3 | 4 | for fd in glob.glob(join(dirname(__file__), "*")): 5 | if not isfile(fd) and not isdir(fd): 6 | continue 7 | if isfile(fd) and not fd.endswith(".py"): 8 | continue 9 | fd = fd if isdir(fd) else fd[:-3] 10 | fd = basename(fd) 11 | if fd.startswith("__"): 12 | continue 13 | __import__(f"{__name__}.{fd}") 14 | -------------------------------------------------------------------------------- /tensorflow_asr/optimizers/accumulation.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Gradient accumulation for the TF2 custom training loop. 3 | """ 4 | 5 | from keras.src.optimizers.base_optimizer import BaseOptimizer 6 | 7 | from tensorflow_asr import tf 8 | 9 | 10 | class GradientAccumulator: 11 | # We use the ON_READ synchronization policy so that no synchronization is 12 | # performed on assignment. To get the value, we call .value() which returns the 13 | # value on the current replica without synchronization. 14 | 15 | def __init__(self, ga_steps, optimizer: BaseOptimizer, name="ga"): 16 | self.name = name 17 | if ga_steps is None: 18 | raise ValueError("ga_steps must be defined") 19 | self._ga_steps = ga_steps 20 | self._optimizer = optimizer 21 | self._accumulated_gradients = [] 22 | self.built = False 23 | 24 | def build(self, variables): 25 | if not self._optimizer.built: 26 | self._optimizer.build(variables) 27 | for variable in variables: 28 | self._accumulated_gradients.append( 29 | self._optimizer.add_variable_from_reference( 30 | variable, 31 | name="gradient_accumulator", 32 | ) 33 | ) 34 | self.built = True 35 | 36 | @property 37 | def total_steps(self): 38 | return self._ga_steps 39 | 40 | # def is_apply_step(self, step): 41 | # return tf.math.equal(step % self._ga_steps, 0) 42 | 43 | def reset(self): 44 | for g_acc in self._accumulated_gradients: 45 | g_acc.assign(tf.zeros(g_acc.shape, dtype=g_acc.dtype)) 46 | 47 | def _get_acc_grads(self, trainable_variables): 48 | # `trainable_variables` might have been filtered in previous 49 | # processing steps, so we need to ensure the correct mapping between 50 | # `self._accumulated_gradients` and `trainable_variables` 51 | acc_grads = [self._accumulated_gradients[self._optimizer._get_variable_index(v)] for v in trainable_variables] 52 | return acc_grads 53 | 54 | def accumulate(self, grads, trainable_variables): 55 | """Accumulates :obj:`grads` on the current replica.""" 56 | if not self.built: 57 | self.build(trainable_variables) 58 | # return [None if x is None else x if y is None else x + y for x, y in zip(gradients, per_ga_gradients)] 59 | acc_grads = self._get_acc_grads(trainable_variables) 60 | new_g_accs = [(g + acc_g) for g, acc_g in zip(grads, acc_grads)] 61 | for n_g_acc, g_acc in zip(new_g_accs, acc_grads): 62 | g_acc.assign(n_g_acc) 63 | 64 | def gradients(self, grads, trainable_variables): 65 | """Gets the averaged gradients for the apply step.""" 66 | if not self.built: 67 | self.build(trainable_variables) 68 | acc_grads = self._get_acc_grads(trainable_variables) 69 | grads = [(g + acc_g) / self._ga_steps for g, acc_g in zip(grads, acc_grads)] 70 | return grads 71 | -------------------------------------------------------------------------------- /tensorflow_asr/optimizers/regularizers.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from tensorflow_asr import keras, tf 4 | 5 | 6 | @keras.utils.register_keras_serializable(package=__name__) 7 | class TimeDependentGaussianGradientNoise(keras.regularizers.Regularizer): 8 | """ 9 | Reference: https://openreview.net/pdf/ZY9xxQDMMu5Pk8ELfEz4.pdf 10 | """ 11 | 12 | def __init__( 13 | self, 14 | mean: float = 0.0, 15 | eta: float = 1.0, # {0.01, 0.3, 1.0} 16 | gamma: float = 0.55, 17 | ): 18 | self.mean = mean 19 | self.eta = eta 20 | self.gamma = gamma 21 | super().__init__() 22 | 23 | def noise(self, step: tf.Tensor, gradient: tf.Tensor): 24 | sigma_squared = self.eta / ((1 + tf.cast(step, dtype=gradient.dtype)) ** self.gamma) 25 | return tf.random.normal(mean=self.mean, stddev=tf.math.sqrt(sigma_squared), shape=tf.shape(gradient), dtype=gradient.dtype) 26 | 27 | def __call__(self, step: tf.Tensor, gradients: List[tf.Tensor]): 28 | """ 29 | Apply time-dependent Gaussian noise to the gradients 30 | 31 | Parameters 32 | ---------- 33 | step : tf.Tensor 34 | Training step 35 | gradients : List[tf.Tensor] 36 | Gradients calculated from optimizer 37 | 38 | Returns 39 | ------- 40 | List[tf.Tensor] 41 | Gradients with noise added 42 | """ 43 | return list(tf.add(gradient, self.noise(step, gradient=gradient)) for gradient in gradients) 44 | 45 | def get_config(self): 46 | return { 47 | "mean": self.mean, 48 | "eta": self.eta, 49 | "gamma": self.gamma, 50 | } 51 | -------------------------------------------------------------------------------- /tensorflow_asr/schemas.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import typing 16 | 17 | import tensorflow as tf 18 | 19 | 20 | class TrainInput(typing.NamedTuple): 21 | inputs: tf.Tensor 22 | inputs_length: tf.Tensor 23 | predictions: tf.Tensor 24 | predictions_length: tf.Tensor 25 | 26 | 27 | class TrainOutput(typing.NamedTuple): 28 | logits: tf.Tensor 29 | logits_length: tf.Tensor 30 | 31 | 32 | class TrainLabel(typing.NamedTuple): 33 | labels: tf.Tensor 34 | labels_length: tf.Tensor 35 | 36 | 37 | class TrainData(typing.NamedTuple): 38 | inputs: TrainInput 39 | labels: TrainLabel 40 | 41 | 42 | class PredictInput(typing.NamedTuple): 43 | inputs: tf.Tensor 44 | inputs_length: tf.Tensor 45 | previous_tokens: typing.Optional[tf.Tensor] = None 46 | previous_encoder_states: typing.Optional[tf.Tensor] = None 47 | previous_decoder_states: typing.Optional[tf.Tensor] = None 48 | 49 | 50 | class PredictOutput(typing.NamedTuple): 51 | tokens: tf.Tensor 52 | next_tokens: tf.Tensor 53 | next_encoder_states: typing.Optional[tf.Tensor] = None 54 | next_decoder_states: typing.Optional[tf.Tensor] = None 55 | 56 | 57 | class PredictOutputWithTranscript(typing.NamedTuple): 58 | transcript: tf.Tensor 59 | tokens: tf.Tensor 60 | next_tokens: tf.Tensor 61 | next_encoder_states: typing.Optional[tf.Tensor] = None 62 | next_decoder_states: typing.Optional[tf.Tensor] = None 63 | -------------------------------------------------------------------------------- /tensorflow_asr/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | from tensorflow_asr.scripts import save, test, tflite, train 2 | from tensorflow_asr.scripts.utils import create_datasets_metadata, create_mls_trans, create_tfrecords 3 | from tensorflow_asr.utils import cli_util 4 | 5 | 6 | def main(): 7 | cli_util.run( 8 | { 9 | "train": train.main, 10 | "test": test.main, 11 | "tflite": tflite.main, 12 | "save": save.main, 13 | "utils": { 14 | "create_mls_trans":
create_mls_trans.main, 15 | "create_tfrecords": create_tfrecords.main, 16 | "create_datasets_metadata": create_datasets_metadata.main, 17 | }, 18 | } 19 | ) 20 | -------------------------------------------------------------------------------- /tensorflow_asr/scripts/save.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import os 17 | 18 | from tensorflow_asr import keras, tf, tokenizers 19 | from tensorflow_asr.configs import Config 20 | from tensorflow_asr.models.base_model import BaseModel 21 | from tensorflow_asr.utils import cli_util, env_util, keras_util 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def main( 27 | config_path: str, 28 | output: str, 29 | h5: str = None, 30 | bs: int = 2, 31 | save_format: str = "h5", 32 | repodir: str = os.getcwd(), 33 | ): 34 | assert output 35 | keras.backend.clear_session() 36 | env_util.setup_seed() 37 | 38 | config = Config(config_path, training=False, repodir=repodir) 39 | tokenizer = tokenizers.get(config) 40 | tokenizer.make() 41 | 42 | logger.info(f"Configs: {str(config)}") 43 | 44 | model: BaseModel = keras_util.model_from_config(config.model_config) 45 | model.tokenizer = tokenizer 46 | model.make(batch_size=bs) 47 | if h5 and tf.io.gfile.exists(h5): 48 | model.load_weights(h5, skip_mismatch=False) 49 | model.summary() 50 | 51 | model.save(output, save_format=save_format) 52 | loaded_model: BaseModel = keras.models.load_model(output) 53 | logger.info(loaded_model.to_json()) 54 | loaded_model.summary() 55 | 56 | 57 | if __name__ == "__main__": 58 | cli_util.run(main) 59 | -------------------------------------------------------------------------------- /tensorflow_asr/scripts/test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
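The `main()` dispatcher in `scripts/__init__.py` above hands a nested command dict to python-fire through `cli_util.run`, so each script's `main` becomes a subcommand. A hedged sketch of what invocation looks like, assuming the package exposes a console entry point named `tensorflow_asr` (check `pyproject.toml` for the actual script name):

```python
# Shell equivalents (python-fire accepts --flag_name or --flag-name spellings):
#   tensorflow_asr save --config_path=config.yml.j2 --output=/models/saved.h5
#   tensorflow_asr utils create_tfrecords --config_path=config.yml.j2 --datadir=/data --modes='["train","eval"]'
from tensorflow_asr.scripts import main

main()  # python-fire parses sys.argv against the nested command dict above
```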
14 | 15 | 16 | import logging 17 | import os 18 | 19 | from tensorflow_asr import datasets, tf, tokenizers # import to aid logging messages 20 | from tensorflow_asr.callbacks import PredictLogger 21 | from tensorflow_asr.configs import Config 22 | from tensorflow_asr.models.base_model import BaseModel 23 | from tensorflow_asr.utils import app_util, cli_util, env_util, file_util, keras_util 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def main( 29 | config_path: str, 30 | dataset_type: str, 31 | datadir: str, 32 | outputdir: str, 33 | h5: str = None, 34 | mxp: str = "none", 35 | bs: int = 1, 36 | jit_compile: bool = False, 37 | repodir: str = os.getcwd(), 38 | ): 39 | 40 | outputdir = file_util.preprocess_paths(outputdir, isdir=True) 41 | checkpoint_name = os.path.splitext(os.path.basename(h5))[0] 42 | 43 | env_util.setup_seed() 44 | env_util.setup_mxp(mxp=mxp) 45 | 46 | config = Config(config_path, training=False, repodir=repodir, datadir=datadir) 47 | batch_size = bs 48 | 49 | tokenizer = tokenizers.get(config) 50 | tokenizer.make() 51 | 52 | logger.info(f"Configs: {str(config)}") 53 | 54 | model: BaseModel = keras_util.model_from_config(config.model_config) 55 | model.tokenizer = tokenizer 56 | model.make(batch_size=batch_size) 57 | model.load_weights(h5, skip_mismatch=False) 58 | model.jit_compile = jit_compile 59 | model.summary() 60 | 61 | for test_data_config in config.data_config.test_dataset_configs: 62 | if not test_data_config.name: 63 | raise ValueError("Test dataset name must be provided") 64 | logger.info(f"Testing dataset: {test_data_config.name}") 65 | 66 | output = os.path.join(outputdir, f"{test_data_config.name}-{checkpoint_name}.tsv") 67 | 68 | test_dataset = datasets.get(tokenizer=tokenizer, dataset_config=test_data_config, dataset_type=dataset_type) 69 | test_data_loader = test_dataset.create(batch_size) 70 | 71 | overwrite = True 72 | if tf.io.gfile.exists(output): 73 | while overwrite not in ["yes", "no"]: 74 | overwrite = input(f"File {output} exists, overwrite? (yes/no): ").lower() 75 | overwrite = overwrite == "yes" 76 | 77 | if overwrite: 78 | with file_util.save_file(output) as output_file_path: 79 | model.predict( 80 | test_data_loader, 81 | verbose=1, 82 | callbacks=[ 83 | PredictLogger(test_dataset=test_dataset, output_file_path=output_file_path), 84 | ], 85 | ) 86 | 87 | evaluation_outputs = app_util.evaluate_hypotheses(output) 88 | logger.info(f"Results:\n{evaluation_outputs.to_markdown()}") 89 | 90 | 91 | if __name__ == "__main__": 92 | cli_util.run(main) 93 | -------------------------------------------------------------------------------- /tensorflow_asr/scripts/tflite.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
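For each configured test set, the `test` script above writes hypotheses to `<dataset_name>-<checkpoint_name>.tsv` through `PredictLogger` (prompting before overwriting an existing file), then scores the file with `app_util.evaluate_hypotheses`. A hedged sketch of calling it directly; every path and the `dataset_type` value are placeholders:

```python
from tensorflow_asr.scripts.test import main as run_test

run_test(
    config_path="examples/datasets/librispeech/config.yml.j2",
    dataset_type="slice",          # assumed; must be a type accepted by datasets.get()
    datadir="/data/LibriSpeech",
    outputdir="/tmp/results",      # receives one <dataset_name>-<checkpoint_name>.tsv per test set
    h5="/models/checkpoints/latest.h5",
    bs=1,
)
```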
14 | 15 | import logging 16 | import os 17 | 18 | from tensorflow_asr import keras, tf, tokenizers # import to aid logging messages 19 | from tensorflow_asr.configs import Config 20 | from tensorflow_asr.models.base_model import BaseModel 21 | from tensorflow_asr.utils import app_util, cli_util, env_util, keras_util 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def main( 27 | config_path: str, 28 | output: str, 29 | h5: str = None, 30 | bs: int = 1, 31 | beam_width: int = 0, 32 | repodir: str = os.getcwd(), 33 | ): 34 | assert output 35 | keras.backend.clear_session() 36 | env_util.setup_seed() 37 | 38 | config = Config(config_path, training=False, repodir=repodir) 39 | tokenizer = tokenizers.get(config) 40 | tokenizer.make() 41 | 42 | logger.info(f"Configs: {str(config)}") 43 | 44 | model: BaseModel = keras_util.model_from_config(config.model_config) 45 | model.tokenizer = tokenizer 46 | model.make(batch_size=bs) 47 | if h5 and tf.io.gfile.exists(h5): 48 | model.load_weights(h5, skip_mismatch=False) 49 | model.summary() 50 | 51 | app_util.convert_tflite(model=model, output=output, batch_size=bs, beam_width=beam_width) 52 | 53 | 54 | if __name__ == "__main__": 55 | cli_util.run(main) 56 | -------------------------------------------------------------------------------- /tensorflow_asr/scripts/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/tensorflow_asr/scripts/utils/__init__.py -------------------------------------------------------------------------------- /tensorflow_asr/scripts/utils/create_datasets_metadata.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
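Once `tflite.main` above has produced a `.tflite` file, it can be exercised with the standard TF Lite interpreter. A minimal sketch under assumed shapes; models converted by this repo typically expose richer signatures (input lengths, streaming states), so see `examples/inferences/tflite.py` for the real input layout:

```python
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="/models/conformer.tflite")  # path assumed
interpreter.allocate_tensors()
inp = interpreter.get_input_details()
out = interpreter.get_output_details()

signal = np.zeros([1, 16000], dtype=np.float32)  # one second of 16 kHz audio, shape assumed
interpreter.resize_tensor_input(inp[0]["index"], signal.shape)
interpreter.allocate_tensors()
interpreter.set_tensor(inp[0]["index"], signal)
interpreter.invoke()
print(interpreter.get_tensor(out[0]["index"]))  # decoded output, model dependent
```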
14 | 15 | 16 | import logging 17 | import os 18 | 19 | from tensorflow_asr import datasets, tokenizers 20 | from tensorflow_asr.configs import Config 21 | from tensorflow_asr.utils import cli_util 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def main( 27 | config_path: str, 28 | datadir: str, 29 | dataset_type: str, 30 | repodir: str = os.getcwd(), 31 | ): 32 | config = Config(config_path, repodir=repodir, datadir=datadir) 33 | if not config.decoder_config.vocabulary: 34 | raise ValueError("decoder_config.vocabulary must be defined") 35 | 36 | tokenizer = tokenizers.get(config) 37 | 38 | logger.info("Preparing train metadata ...") 39 | config.data_config.train_dataset_config.drop_remainder = False 40 | config.data_config.train_dataset_config.shuffle = False 41 | train_dataset = datasets.get( 42 | tokenizer=tokenizer, 43 | dataset_config=config.data_config.train_dataset_config, 44 | dataset_type=dataset_type, 45 | ) 46 | tokenizer.build(train_dataset) 47 | tokenizer.make() 48 | train_dataset.update_metadata() 49 | 50 | logger.info("Preparing eval metadata ...") 51 | config.data_config.eval_dataset_config.drop_remainder = False 52 | config.data_config.eval_dataset_config.shuffle = False 53 | eval_dataset = datasets.get( 54 | tokenizer=tokenizer, 55 | dataset_config=config.data_config.eval_dataset_config, 56 | dataset_type=dataset_type, 57 | ) 58 | eval_dataset.update_metadata() 59 | 60 | 61 | if __name__ == "__main__": 62 | cli_util.run(main) 63 | -------------------------------------------------------------------------------- /tensorflow_asr/scripts/utils/create_mls_trans.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 M. Yusuf Sarıgöz (@monatis) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
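The metadata script above deliberately disables shuffling and `drop_remainder` so the scan covers every utterance, builds the tokenizer from the train set, then updates each dataset's metadata (the `*.metadata.json` files under `examples/datasets` appear to record statistics such as maximum input/label lengths used for static-shape padding). A hedged direct invocation with placeholder paths:

```python
from tensorflow_asr.scripts.utils.create_datasets_metadata import main as create_metadata

create_metadata(
    config_path="examples/datasets/librispeech/config.yml.j2",
    datadir="/data/LibriSpeech",
    dataset_type="slice",  # assumed dataset type
)
```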
14 | 15 | import argparse 16 | import os 17 | 18 | import librosa 19 | 20 | from tensorflow_asr import keras 21 | 22 | # example usage: python create_mls_trans.py --dataset-home /mnt/datasets/mls --language polish --opus 23 | 24 | base_url = "https://dl.fbaipublicfiles.com/mls/" 25 | 26 | langs = ["dutch", "english", "german", "french", "italian", "portuguese", "polish", "spanish"] 27 | 28 | splits = ["dev", "test", "train"] 29 | 30 | chars = set() 31 | 32 | 33 | def prepare_split(dataset_dir, split, opus=False): 34 | # Setup necessary paths 35 | split_home = os.path.join(dataset_dir, split) 36 | transcripts_infile = os.path.join(split_home, "transcripts.txt") 37 | transcripts_outfile = os.path.join(split_home, "transcripts_tfasr.tsv") 38 | audio_home = os.path.join(split_home, "audio") 39 | extension = ".opus" if opus else ".flac" 40 | transcripts = [] 41 | 42 | from tqdm.auto import tqdm 43 | 44 | # Make paths absolute, get durations and read chars to form alphabet later on 45 | with open(transcripts_infile, "r", encoding="utf8") as infile: 46 | for line in tqdm(infile.readlines(), desc=f"Reading from {transcripts_infile}...", disable=False): 47 | file_id, transcript = line.strip().split("\t") 48 | speaker_id, book_id, _ = file_id.split("_") 49 | audio_path = os.path.join(audio_home, speaker_id, book_id, f"{file_id}{extension}") 50 | y, sr = librosa.load(audio_path, sr=None) 51 | duration = librosa.get_duration(y=y, sr=sr) 52 | transcripts.append(f"{audio_path}\t{duration}\t{transcript}\n") 53 | for char in transcript: 54 | chars.add(char) 55 | 56 | # Write transcripts to file 57 | with open(transcripts_outfile, "w", encoding="utf8") as outfile: 58 | outfile.write("PATH\tDURATION\tTRANSCRIPT\n") 59 | for t in tqdm(transcripts, desc=f"Writing to {transcripts_outfile}", disable=False): 60 | outfile.write(t) 61 | 62 | 63 | def make_alphabet_file(filepath, chars_list, lang): 64 | print(f"Writing alphabet to {filepath}...") 65 | with open(filepath, "w", encoding="utf8") as outfile: 66 | outfile.write(f"# Alphabet file for language {lang}\n") 67 | outfile.write("Automatically generated. Do not edit\n#\n") 68 | for char in sorted(list(chars_list)): 69 | outfile.write(f"{char}\n") 70 | 71 | outfile.write("# end of file") 72 | 73 | 74 | def main(): 75 | ap = argparse.ArgumentParser(description="Download and prepare MLS dataset in a given language") 76 | ap.add_argument( 77 | "--dataset-home", "-d", default=None, required=False, help="Path to home directory to download and prepare dataset. Defaults to ~/.keras" 78 | ) 79 | ap.add_argument("--language", "-l", type=str, choices=langs, default=None, required=True, help="Any name of language included in MLS") 80 | ap.add_argument("--opus", default=False, action="store_true", help="Whether to use dataset in opus format or not") 81 | 82 | args = ap.parse_args() 83 | fname = "mls_{}{}.tar.gz".format(args.language, "_opus" if args.opus else "") 84 | subdir = fname[:-7] 85 | dataset_home = os.path.abspath(args.dataset_home) if args.dataset_home else os.path.join(os.path.expanduser("~"), ".keras") 86 | dataset_dir = os.path.join(dataset_home, subdir) 87 | full_url = base_url + fname 88 | 89 | downloaded_file = keras.utils.get_file(fname, full_url, cache_subdir=dataset_home, extract=True) 90 | 91 | print(f"Dataset extracted to {dataset_dir}. Preparing...") 92 | 93 | for split in splits: 94 | prepare_split(dataset_dir=dataset_dir, split=split, opus=args.opus) 95 | 96 | make_alphabet_file(os.path.join(dataset_dir, "alphabet.txt"), chars, args.language) 97 | 98 | 99 | if __name__ == "__main__": 100 | main() 101 | -------------------------------------------------------------------------------- /tensorflow_asr/scripts/utils/create_tfrecords.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from typing import List 17 | 18 | from tensorflow_asr import datasets, tokenizers 19 | from tensorflow_asr.configs import Config 20 | from tensorflow_asr.utils import cli_util 21 | 22 | 23 | def main( 24 | config_path: str, 25 | datadir: str, 26 | modes: List[str], 27 | repodir: str = os.getcwd(), 28 | dataset_type: str = "tfrecord", 29 | ): 30 | config = Config(config_path, repodir=repodir, datadir=datadir) 31 | tokenizer = tokenizers.get(config=config) 32 | tokenizer.make() 33 | for mode in modes: 34 | dat = datasets.get( 35 | tokenizer=tokenizer, 36 | dataset_config=getattr(config.data_config, f"{mode}_dataset_config"), 37 | dataset_type=dataset_type, 38 | ) 39 | dat.create_tfrecords() 40 | 41 | 42 | if __name__ == "__main__": 43 | cli_util.run(main) 44 | -------------------------------------------------------------------------------- /tensorflow_asr/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/tensorflow_asr/utils/__init__.py -------------------------------------------------------------------------------- /tensorflow_asr/utils/cli_util.py: -------------------------------------------------------------------------------- 1 | import fire 2 | 3 | 4 | def run( 5 | component, 6 | command=None, 7 | name=None, 8 | ): 9 | """ 10 | Run a component with the CLI; help output is printed to stdout 11 | as per https://github.com/google/python-fire/issues/188#issuecomment-791972163 12 | 13 | Args: 14 | component: function or class to run 15 | command (optional): any. Defaults to None. 16 | name (str, optional): display name of the component. Defaults to None. 17 | """ 18 | fire.core.Display = lambda lines, out: print(*lines, file=out) 19 | fire.Fire(component, command=command, name=name) 20 | -------------------------------------------------------------------------------- /tensorflow_asr/utils/data_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # tf.data.Dataset does not work well with namedtuples, so we use dicts instead 16 | 17 | import os 18 | from functools import reduce 19 | from typing import Any 20 | 21 | import librosa 22 | import tensorflow as tf 23 | 24 | 25 | def load_and_convert_to_wav( 26 | path: str, 27 | sample_rate: int = None, 28 | ): 29 | wave, rate = librosa.load(os.path.realpath(os.path.expanduser(path)), sr=sample_rate, mono=True) 30 | return tf.audio.encode_wav(tf.expand_dims(wave, axis=-1), sample_rate=rate) 31 | 32 | 33 | def read_raw_audio(audio: tf.Tensor): 34 | wave, _ = tf.audio.decode_wav(audio, desired_channels=1, desired_samples=-1) 35 | return tf.reshape(wave, shape=[-1]) # flatten to 1-D for use with tf.signal 36 | 37 | 38 | def get( 39 | obj: dict, 40 | path: str, 41 | default: Any = None, 42 | ): 43 | path = str(path) 44 | 45 | def _reduce_fn(d, key): 46 | if isinstance(d, dict): 47 | return d.get(key, default) 48 | if isinstance(d, list): 49 | try: 50 | return d[int(key)] 51 | except (IndexError, ValueError): 52 | return default 53 | return default 54 | 55 | return reduce(_reduce_fn, path.split("."), obj) 56 | -------------------------------------------------------------------------------- /tensorflow_asr/utils/feature_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
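`get` in `data_util.py` above walks a dotted path through nested dicts and lists, returning `default` at the first miss. A quick sketch:

```python
from tensorflow_asr.utils.data_util import get

cfg = {"model": {"encoder": {"blocks": [{"filters": 256}, {"filters": 512}]}}}
print(get(cfg, "model.encoder.blocks.1.filters"))    # 512: list indices are parsed from the path
print(get(cfg, "model.decoder.units", default=320))  # 320: a missing branch falls back to default
```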
14 | 15 | from tensorflow_asr import tf 16 | 17 | 18 | def float_feature( 19 | list_of_floats, 20 | ): 21 | return tf.train.Feature(float_list=tf.train.FloatList(value=list_of_floats)) 22 | 23 | 24 | def int64_feature( 25 | list_of_ints, 26 | ): 27 | return tf.train.Feature(int64_list=tf.train.Int64List(value=list_of_ints)) 28 | 29 | 30 | def bytestring_feature( 31 | list_of_bytestrings, 32 | ): 33 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=list_of_bytestrings)) 34 | -------------------------------------------------------------------------------- /tensorflow_asr/utils/keras_util.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from keras.src.saving import serialization_lib 3 | 4 | 5 | def model_from_config(model_config: dict, custom_objects=None): 6 | return serialization_lib.deserialize_keras_object(model_config, custom_objects=custom_objects) 7 | 8 | 9 | def reduce_per_replica(values, strategy, reduction): 10 | if reduction == "auto": 11 | if isinstance(strategy, tf.distribute.TPUStrategy): 12 | reduction = "first" 13 | else: 14 | reduction = "mean" 15 | 16 | def _reduce(v): 17 | """Reduce a single `PerReplica` object.""" 18 | if reduction == "first": 19 | return strategy.experimental_local_results(v)[0] 20 | if reduction == "sum": 21 | return strategy.reduce("SUM", v, axis=None) 22 | if reduction == "mean": 23 | return strategy.reduce("MEAN", v, axis=None) 24 | raise ValueError("`reduction` must be one of " '"first", "mean", "sum", or "auto". ' f"Received: reduction={reduction}.") 25 | 26 | return tf.nest.map_structure(_reduce, values) 27 | -------------------------------------------------------------------------------- /tensorflow_asr/utils/layer_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
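The three helpers in `feature_util.py` above wrap Python lists into `tf.train.Feature` protos. A sketch of composing them into a serialized `tf.train.Example`; the field names are illustrative, not the repo's actual TFRecord schema:

```python
import tensorflow as tf

from tensorflow_asr.utils.feature_util import bytestring_feature, float_feature, int64_feature

example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "audio": bytestring_feature([b"<wav bytes>"]),
            "duration": float_feature([3.27]),
            "tokens": int64_feature([4, 12, 9]),
        }
    )
)
serialized = example.SerializeToString()  # ready to write with tf.io.TFRecordWriter
```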
14 | 15 | from typing import List 16 | 17 | from tensorflow_asr import keras, tf 18 | from tensorflow_asr.models.layers.convolution import Conv1D, Conv2D 19 | 20 | 21 | def get_rnn( 22 | rnn_type: str, 23 | ): 24 | assert rnn_type in ["lstm", "gru", "rnn"] 25 | if rnn_type == "lstm": 26 | return keras.layers.LSTM 27 | if rnn_type == "gru": 28 | return keras.layers.GRU 29 | return keras.layers.SimpleRNN 30 | 31 | 32 | def get_conv( 33 | conv_type: str, 34 | ): 35 | assert conv_type in ["conv1d", "conv2d"] 36 | if conv_type == "conv1d": 37 | return Conv1D 38 | return Conv2D 39 | 40 | 41 | def add_gwn( 42 | trainable_weights: List[tf.Variable], 43 | stddev: float = 1.0, 44 | ): 45 | original_weights = [] 46 | for weight in trainable_weights: 47 | noise = tf.stop_gradient(tf.random.normal(mean=0.0, stddev=stddev, shape=weight.shape, dtype=weight.dtype)) 48 | original_weights.append(tf.identity(weight)) # snapshot the current value; appending the variable itself would also be mutated by assign_add below 49 | weight.assign_add(noise) 50 | return original_weights 51 | 52 | 53 | def sub_gwn( 54 | original_weights: list, 55 | trainable_weights: list, 56 | ): 57 | for i, weight in enumerate(trainable_weights): 58 | weight.assign(original_weights[i]) 59 | -------------------------------------------------------------------------------- /tensorflow_asr/utils/plot_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def plotmesh(data, title="data", scale_ysize=4, invert_yaxis=True): 7 | xsize = data.shape[1] 8 | ysize = data.shape[0] 9 | gcd = math.gcd(xsize, ysize) 10 | xsize /= gcd 11 | ysize /= gcd 12 | xsize = (xsize * scale_ysize) / ysize 13 | ysize = scale_ysize 14 | figsize = [xsize, ysize] 15 | fig, ax = plt.subplots(figsize=figsize) 16 | ax.set_title(title, fontweight="bold") 17 | ax.minorticks_on() 18 | if invert_yaxis: 19 | ax.invert_yaxis() 20 | img = ax.pcolormesh(data, cmap="viridis") 21 | cbar = fig.colorbar(img, ax=ax, format="%.2f", pad=0.01) 22 | cbar.minorticks_on() 23 | fig.tight_layout() 24 | plt.show() 25 | -------------------------------------------------------------------------------- /tensorflow_asr/utils/shape_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huy Le Nguyen (@nglehuy) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
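`add_gwn`/`sub_gwn` in `layer_util.py` above implement the usual Gaussian weight-noise trick: perturb the weights for the forward and backward pass, then restore the clean values before the optimizer update. A hedged sketch of one training step with plain Keras objects (the stddev of 0.075 is an arbitrary example, not a repo default):

```python
import tensorflow as tf

from tensorflow_asr.utils.layer_util import add_gwn, sub_gwn

model = tf.keras.Sequential([tf.keras.layers.Dense(8)])
model.build(input_shape=(None, 4))
x, y = tf.random.normal([2, 4]), tf.random.normal([2, 8])

originals = add_gwn(model.trainable_weights, stddev=0.075)  # snapshot, then perturb in place
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
grads = tape.gradient(loss, model.trainable_weights)
sub_gwn(originals, model.trainable_weights)  # restore the clean weights
# ...apply `grads` with an optimizer here
```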
14 | 15 | from tensorflow_asr import tf 16 | 17 | 18 | def shape_list(x, out_type=tf.int32): 19 | """Deal with dynamic shape in tensorflow cleanly.""" 20 | static = x.shape.as_list() 21 | dynamic = tf.shape(x, out_type=out_type) 22 | return [dynamic[i] if s is None else s for i, s in enumerate(static)] 23 | 24 | 25 | def shape_list_per_replica(x, per_replica_batch_size): 26 | _, *rest_shape = x.shape 27 | shapes = (int(per_replica_batch_size),) + tuple(rest_shape) 28 | return shapes 29 | 30 | 31 | def get_shape_invariants(tensor): 32 | shapes = shape_list(tensor) 33 | return tf.TensorShape([i if isinstance(i, int) else None for i in shapes]) 34 | 35 | 36 | def get_float_spec(tensor): 37 | shape = get_shape_invariants(tensor) 38 | return tf.TensorSpec(shape, dtype=tf.float32) 39 | 40 | 41 | def get_dim(tensor, i): 42 | """Get value of tensor shape[i] preferring static value if available.""" 43 | return tf.compat.dimension_value(tensor.shape[i]) or tf.shape(tensor)[i] 44 | -------------------------------------------------------------------------------- /tensorflow_asr/utils/tf_util.py: -------------------------------------------------------------------------------- 1 | # # import importlib 2 | 3 | # import tensorflow as tf 4 | # from keras.src.utils import tf_utils 5 | 6 | # from tensorflow_asr.utils.env_util import KERAS_SRC 7 | 8 | # # tf_utils = importlib.import_module(f"{KERAS_SRC}.utils.tf_utils") 9 | 10 | 11 | # def convert_shapes(input_shape, to_tuples=True): 12 | # if input_shape is None: 13 | # return None 14 | 15 | # def _is_shape_component(value): 16 | # return value is None or isinstance(value, (int, tf.compat.v1.Dimension)) 17 | 18 | # def _is_atomic_shape(input_shape): 19 | # # Ex: TensorShape or (None, 10, 32) or 5 or `None` 20 | # if _is_shape_component(input_shape): 21 | # return True 22 | # if isinstance(input_shape, tf.TensorShape): 23 | # return True 24 | # if isinstance(input_shape, (tuple, list)) and all(_is_shape_component(ele) for ele in input_shape): 25 | # return True 26 | # return False 27 | 28 | # def _convert_shape(input_shape): 29 | # if input_shape is None: 30 | # return None 31 | # input_shape = tf.TensorShape(input_shape) 32 | # if to_tuples: 33 | # input_shape = tuple(input_shape.as_list()) 34 | # return input_shape 35 | 36 | # return tf_utils.map_structure_with_atomic(_is_atomic_shape, _convert_shape, input_shape) 37 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/tests/conftest.py -------------------------------------------------------------------------------- /tests/featurizer/test_speech_featurizer.py: -------------------------------------------------------------------------------- 1 | # # %% 2 | # import librosa 3 | # import librosa.display 4 | # import matplotlib.pyplot as plt 5 | # import numpy as np 6 | 7 | # from tensorflow_asr import tf 8 | # from tensorflow_asr.augmentations.methods import specaugment 9 | # from tensorflow_asr.configs import SpeechConfig 10 | # from tensorflow_asr.features import speech_featurizers 11 
| 12 | # speech_conf = SpeechConfig( 13 | # { 14 | # "sample_rate": 16000, 15 | # "frame_ms": 25, 16 | # "stride_ms": 10, 17 | # "feature_type": "log_mel_spectrogram", 18 | # "num_feature_bins": 80, 19 | # # "compute_energy": True, 20 | # # "use_natural_log": False, 21 | # # "use_librosa_like_stft": True, 22 | # # "fft_overdrive": False, 23 | # # "normalize_feature": False, 24 | # } 25 | # ) 26 | # signal = speech_featurizers.read_raw_audio("./test.flac", speech_conf.sample_rate) 27 | 28 | # print(f"signal length: {len(signal)}") 29 | # sf = speech_featurizers.SpeechFeaturizer(speech_conf) 30 | # ft = sf.extract(signal) 31 | # freq_mask = specaugment.FreqMasking(prob=1, mask_value="min") 32 | # ft = freq_mask.augment(ft) 33 | # time_mask = specaugment.TimeMasking(prob=1, p_upperbound=0.05) 34 | # ft = time_mask.augment(ft) 35 | # ft = tf.squeeze(ft, axis=-1) 36 | # ft = ft.numpy().T 37 | # print(ft.shape) 38 | 39 | # plt.figure(figsize=(24, 5)) 40 | # ax = plt.gca() 41 | # ax.set_title("log_mel_spectrogram", fontweight="bold") 42 | # librosa.display.specshow(ft, cmap="viridis") 43 | # v1 = np.linspace(ft.min(), ft.max(), 8, endpoint=True) 44 | # plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1) 45 | # plt.tight_layout() 46 | # plt.show() 47 | 48 | # sf.speech_config.normalize_per_frame = True 49 | # ft = sf.extract(signal) 50 | # ft = tf.squeeze(ft, axis=-1) 51 | # ft = ft.numpy().T 52 | # print(ft.shape) 53 | 54 | # plt.figure(figsize=(24, 5)) 55 | # ax = plt.gca() 56 | # ax.set_title("log_mel_spectrogram", fontweight="bold") 57 | # librosa.display.specshow(ft, cmap="viridis") 58 | # v1 = np.linspace(ft.min(), ft.max(), 8, endpoint=True) 59 | # plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1) 60 | # plt.tight_layout() 61 | # plt.show() 62 | 63 | # print(np.std(ft)) 64 | # print(np.mean(ft)) 65 | 66 | # nframes = 5 67 | # chunk_size = (nframes - 1) * sf.speech_config.frame_step + sf.speech_config.frame_length 68 | # stride = nframes * sf.speech_config.frame_step 69 | # print(f"With chunk size: {chunk_size} and nfft: {sf.nfft}") 70 | # signal_length = len(signal) 71 | # all_ft = None 72 | # for i in range(int(np.ceil((signal_length - chunk_size) / stride))): # this ensure the fft shape of chunked signal is the same with whole signal 73 | # chunk = signal[i * stride : i * stride + chunk_size] 74 | # # cft = sf.power_to_db(sf.stft(chunk)) 75 | # cft = sf.extract(chunk) 76 | # cft = tf.squeeze(cft, axis=-1) 77 | # cft = cft.numpy() 78 | # if all_ft is None: 79 | # all_ft = cft 80 | # else: 81 | # all_ft = np.concatenate([all_ft, cft], axis=0) 82 | # all_ft = all_ft.T 83 | # all_ft = np.pad(all_ft, [[0, 0], [0, ft.shape[-1] - all_ft.shape[-1]]]) 84 | # print(all_ft.shape) 85 | 86 | # plt.figure(figsize=(24, 5)) 87 | # ax = plt.gca() 88 | # ax.set_title(f"chunked log_mel_spectrogram", fontweight="bold") 89 | # librosa.display.specshow(all_ft, cmap="viridis") 90 | # v1 = np.linspace(all_ft.min(), all_ft.max(), 8, endpoint=True) 91 | # plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1) 92 | # plt.tight_layout() 93 | # plt.show() 94 | 95 | # dft = all_ft - ft 96 | 97 | # plt.figure(figsize=(24, 5)) 98 | # ax = plt.gca() 99 | # ax.set_title(f"diff of chunked log_mel_spectrogram with whole log_mel_spectrogram", fontweight="bold") 100 | # librosa.display.specshow(dft, cmap="viridis") 101 | # v1 = np.linspace(dft.min(), dft.max(), 8, endpoint=True) 102 | # plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1) 103 | # 
plt.tight_layout() 104 | # plt.show() 105 | 106 | # plt.figure(figsize=(24, 5)) 107 | # ax = plt.gca() 108 | # ax.set_title(f"RMSE of chunked log_mel_spectrogram with whole log_mel_spectrogram", fontweight="bold") 109 | # plt.plot(np.sqrt(np.mean(dft**2, axis=0))) 110 | # plt.tight_layout() 111 | # plt.show() 112 | 113 | # # %% 114 | -------------------------------------------------------------------------------- /tests/test.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowASR/f092d39743d74308cb7959001c4657b98d01e7be/tests/test.flac -------------------------------------------------------------------------------- /tests/test_bug.py: -------------------------------------------------------------------------------- 1 | import keras 2 | 3 | 4 | class Model(keras.Model): 5 | def __init__(self, *args, **kwargs): 6 | super().__init__(*args, **kwargs) 7 | self.dense = keras.layers.Dense(10) 8 | self.mha = keras.layers.MultiHeadAttention(10, 10, output_shape=(100,)) 9 | 10 | def call(self, inputs): 11 | x = self.dense(inputs) 12 | return self.mha(x, x, x) 13 | 14 | 15 | model = Model() 16 | model(keras.Input(shape=(10, 10))) 17 | model.summary() 18 | -------------------------------------------------------------------------------- /tests/test_callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | from tensorflow_asr.callbacks import KaggleModelBackupAndRestore 5 | 6 | 7 | def test_kaggle_model_backup_and_restore(): 8 | model_handle = os.getenv("TEST_MODEL_HANDLE") 9 | if not model_handle: 10 | return 11 | with tempfile.TemporaryDirectory() as temp_dir: 12 | os.environ["KAGGLEHUB_CACHE"] = os.path.join(temp_dir, "cache") 13 | os.makedirs(os.environ["KAGGLEHUB_CACHE"], exist_ok=True) 14 | model_dir = os.path.join(temp_dir, "model") 15 | os.makedirs(model_dir, exist_ok=True) 16 | with open(os.path.join(model_dir, "model.h5"), "w", encoding="utf-8") as f: 17 | f.write("dummy model data") 18 | callback = KaggleModelBackupAndRestore( 19 | model_handle=model_handle, 20 | model_dir=model_dir, 21 | save_freq=1, 22 | ) 23 | callback._backup_kaggle(logs={}, notes="Backed up model at batch") 24 | -------------------------------------------------------------------------------- /tests/test_layers.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=line-too-long 2 | import os 3 | 4 | import librosa 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | from tensorflow_asr import tf 9 | from tensorflow_asr.augmentations.augmentation import Augmentation 10 | from tensorflow_asr.models.layers.feature_extraction import FeatureExtraction 11 | from tensorflow_asr.utils import data_util, file_util 12 | 13 | # config_path = os.path.join(os.path.dirname(__file__), "..", "examples", "configs", "log_mel_spectrogram.yml.j2") 14 | # config = file_util.load_yaml(config_path) 15 | 16 | audio_file_path = os.path.join(os.path.dirname(__file__), "test.flac") 17 | 18 | 19 | def plot_specs(ft, title): 20 | ft = ft.numpy() if isinstance(ft, tf.Tensor) else ft 21 | ft = np.squeeze(ft) 22 | ft = ft.T 23 | plt.figure(figsize=(24, 5)) 24 | ax = plt.gca() 25 | ax.set_title(title, fontweight="bold") 26 | librosa.display.specshow(ft, cmap="viridis") 27 | v1 = np.linspace(ft.min(), ft.max(), 8, endpoint=True) 28 | plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1) 29 | 
plt.tight_layout() 30 | plt.show() 31 | 32 | 33 | def test_feature_extraction(): 34 | signal = data_util.load_and_convert_to_wav(audio_file_path) 35 | signal = tf.expand_dims(data_util.read_raw_audio(signal), axis=0) 36 | signal_length = tf.expand_dims(tf.shape(signal)[1], axis=0) 37 | signal = tf.pad(signal, paddings=[[0, 0], [0, 16000]], mode="CONSTANT", constant_values=0.0) 38 | 39 | feature_extraction_layer = FeatureExtraction() 40 | 41 | for ftype in ("spectrogram", "log_mel_spectrogram", "log_gammatone_spectrogram", "mfcc"): 42 | feature_extraction_layer.feature_type = ftype 43 | ft, _ = feature_extraction_layer((signal, signal_length)) 44 | plot_specs(ft, feature_extraction_layer.feature_type) 45 | 46 | mask, _ = feature_extraction_layer.compute_mask((signal, signal_length)) 47 | print(mask) 48 | 49 | feature_extraction_layer.feature_type = "log_mel_spectrogram" 50 | feature_extraction_layer.preemphasis = 0.0 51 | ft1, _ = feature_extraction_layer((signal, signal_length)) 52 | feature_extraction_layer.preemphasis = 0.97 53 | ft2, _ = feature_extraction_layer((signal, signal_length)) 54 | ft = ft1 - ft2 55 | plot_specs(ft, feature_extraction_layer.feature_type) 56 | 57 | feature_extraction_layer.augmentations = Augmentation( 58 | { 59 | "feature_augment": { 60 | "freq_masking": { 61 | "num_masks": 2, 62 | "mask_factor": 27, 63 | "prob": 0.0, 64 | "mask_value": 0, 65 | }, 66 | "time_masking": { 67 | "num_masks": 2, 68 | "mask_factor": -1, 69 | "prob": 0.0, 70 | "mask_value": 0, 71 | "p_upperbound": 0.05, 72 | }, 73 | } 74 | } 75 | ) 76 | feature_extraction_layer.preemphasis = 0.0 77 | ft1, _ = feature_extraction_layer((signal, signal_length), training=True) 78 | plot_specs(ft1, feature_extraction_layer.feature_type) 79 | -------------------------------------------------------------------------------- /tests/test_mask.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tensorflow_asr.models.layers.multihead_attention import compute_streaming_mask 4 | 5 | 6 | def test_mha_streaming_mask(): 7 | mask = compute_streaming_mask(2, 2, tf.zeros([5, 8, 8])) 8 | print(mask) 9 | assert tf.reduce_all( 10 | tf.equal( 11 | mask, 12 | tf.constant( 13 | [ 14 | [ 15 | [True, True, False, False, False, False, False, False], 16 | [True, True, False, False, False, False, False, False], 17 | [True, True, True, True, False, False, False, False], 18 | [True, True, True, True, False, False, False, False], 19 | [False, False, True, True, True, True, False, False], 20 | [False, False, True, True, True, True, False, False], 21 | [False, False, False, False, True, True, True, True], 22 | [False, False, False, False, True, True, True, True], 23 | ] 24 | ] 25 | ), 26 | ) 27 | ).numpy() 28 | 29 | mask = compute_streaming_mask(3, 3, tf.zeros([5, 14, 14])) 30 | print(mask) 31 | assert tf.reduce_all( 32 | tf.equal( 33 | mask, 34 | tf.constant( 35 | [ 36 | [ 37 | [True, True, True, False, False, False, False, False, False, False, False, False, False, False], 38 | [True, True, True, False, False, False, False, False, False, False, False, False, False, False], 39 | [True, True, True, False, False, False, False, False, False, False, False, False, False, False], 40 | [True, True, True, True, True, True, False, False, False, False, False, False, False, False], 41 | [True, True, True, True, True, True, False, False, False, False, False, False, False, False], 42 | [True, True, True, True, True, True, False, False, False, False, False, False, False, 
False], 43 | [False, False, False, True, True, True, True, True, True, False, False, False, False, False], 44 | [False, False, False, True, True, True, True, True, True, False, False, False, False, False], 45 | [False, False, False, True, True, True, True, True, True, False, False, False, False, False], 46 | [False, False, False, False, False, False, True, True, True, True, True, True, False, False], 47 | [False, False, False, False, False, False, True, True, True, True, True, True, False, False], 48 | [False, False, False, False, False, False, True, True, True, True, True, True, False, False], 49 | [False, False, False, False, False, False, False, False, False, True, True, True, True, True], 50 | [False, False, False, False, False, False, False, False, False, True, True, True, True, True], 51 | ] 52 | ] 53 | ), 54 | ) 55 | ).numpy() 56 | -------------------------------------------------------------------------------- /tests/test_relpe.py: -------------------------------------------------------------------------------- 1 | from tensorflow_asr import tf 2 | from tensorflow_asr.models.layers.multihead_attention import rel_left_shift 3 | from tensorflow_asr.models.layers.positional_encoding import RelativeSinusoidalPositionalEncoding 4 | from tensorflow_asr.utils import plot_util 5 | 6 | 7 | def test(): 8 | batch_size, input_length, max_length, dmodel = 2, 300, 500, 144 9 | causal = False 10 | layer = RelativeSinusoidalPositionalEncoding(interleave=True, memory_length=input_length, causal=causal) 11 | _, pe = layer((tf.random.normal([batch_size, max_length, dmodel]), tf.convert_to_tensor([input_length, input_length + 10])), training=False) 12 | shift = tf.einsum("brd,btd->btr", pe, tf.ones([batch_size, max_length, dmodel])) 13 | shift = rel_left_shift(shift[0][None, None, ...], causal=causal) 14 | pe = tf.transpose(pe[0], perm=[1, 0]) 15 | pe = pe.numpy() 16 | print(pe.shape) 17 | shift = shift[0][0] 18 | shift = shift.numpy() 19 | print(shift.shape) 20 | plot_util.plotmesh(pe, title="sinusoid position encoding", invert_yaxis=False) 21 | plot_util.plotmesh(shift, title="relshift") 22 | 23 | 24 | def test_relshift(): 25 | a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 26 | print(a) 27 | a = a[None, ...] 28 | a = a[None, ...] 
29 | b = rel_left_shift(a, causal=True) 30 | b = tf.squeeze(b, 0) 31 | b = tf.squeeze(b, 0) 32 | print(b) 33 | -------------------------------------------------------------------------------- /tests/test_rnnt_loss.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from tensorflow_asr import tf 4 | from tensorflow_asr.losses.rnnt_loss import compute_rnnt_loss_and_grad_helper 5 | 6 | B = 1 7 | T = 743 8 | U = 200 9 | V = 1000 10 | blank = 0 11 | 12 | 13 | # @tf.function 14 | def run(): 15 | logits = tf.random.normal([B, T, U + 1, V], dtype=tf.float32) 16 | labels = tf.repeat(tf.range(U, dtype=tf.int32)[None, :], B, 0) 17 | logit_length = tf.repeat(tf.convert_to_tensor([T], dtype=tf.int32), B, 0) 18 | label_length = tf.repeat(tf.convert_to_tensor([U], dtype=tf.int32), B, 0) 19 | 20 | t0 = time.time() 21 | loss, grad = compute_rnnt_loss_and_grad_helper( 22 | logits=logits, 23 | labels=labels, 24 | label_length=label_length, 25 | logit_length=logit_length, 26 | ) 27 | t1 = time.time() 28 | tf.print(loss) 29 | print("Took", t1 - t0) 30 | 31 | 32 | def test(): 33 | tf.config.run_functions_eagerly(False) 34 | run() 35 | -------------------------------------------------------------------------------- /tests/test_schedules.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | from tensorflow_asr.optimizers.schedules import CyclicTransformerSchedule, TransformerSchedule 4 | 5 | 6 | def test_transformer_schedule(): 7 | sched = TransformerSchedule(dmodel=176, scale=10.0, warmup_steps=10000, max_lr="0.05/(176**0.5)", min_lr=None) 8 | sched2 = CyclicTransformerSchedule(dmodel=320, step_size=10000, warmup_steps=15000, max_lr=0.0025) 9 | lrs = [sched(i).numpy() for i in range(100000)] 10 | print(lrs[:100]) 11 | plt.plot(lrs) 12 | plt.show() 13 | lrs = [sched2(i).numpy() for i in range(100000)] 14 | print(lrs[:100]) 15 | plt.plot(lrs) 16 | plt.show() 17 | -------------------------------------------------------------------------------- /tests/test_tokenizers.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=line-too-long 2 | import os 3 | 4 | from tensorflow_asr import tf 5 | from tensorflow_asr.configs import DecoderConfig 6 | from tensorflow_asr.tokenizers import CharTokenizer, SentencePieceTokenizer, WordPieceTokenizer 7 | from tensorflow_asr.utils import file_util 8 | 9 | file_util.ENABLE_PATH_PREPROCESS = False 10 | 11 | repodir = os.path.realpath(os.path.join(os.path.dirname(__file__), "..")) 12 | 13 | 14 | text = "i'm good but it would have broken down after ten miles of that hard trail dawn came while they wound over the crest of the range and with the sun in their faces they took the downgrade it was well into the morning before nash reached logan" 15 | # text = "a b" 16 | 17 | 18 | def test_char(): 19 | config_path = os.path.join(os.path.dirname(__file__), "..", "examples", "configs", "librispeech", "characters", "char.yml.j2") 20 | config = file_util.load_yaml(config_path, repodir=repodir) 21 | decoder_config = DecoderConfig(config["decoder_config"]) 22 | featurizer = CharTokenizer(decoder_config=decoder_config) 23 | print(featurizer.num_classes) 24 | print(text) 25 | indices = featurizer.tokenize(text) 26 | print(indices.numpy()) 27 | batch_indices = tf.stack([indices, indices], axis=0) 28 | reversed_text = featurizer.detokenize(batch_indices) 29 | print(reversed_text.numpy()) 30 | upoints = 
featurizer.detokenize_unicode_points(indices) 31 | print(upoints.numpy()) 32 | 33 | 34 | def test_wp(): 35 | config_path = os.path.join(os.path.dirname(__file__), "..", "examples", "configs", "librispeech", "wordpiece", "wp.yml.j2") 36 | config = file_util.load_yaml(config_path, repodir=repodir) 37 | decoder_config = DecoderConfig(config["decoder_config"]) 38 | featurizer = WordPieceTokenizer(decoder_config=decoder_config) 39 | print(featurizer.num_classes) 40 | print(text) 41 | indices = featurizer.tokenize(text) 42 | print(indices.numpy()) 43 | batch_indices = tf.stack([indices, indices], axis=0) 44 | reversed_text = featurizer.detokenize(batch_indices) 45 | print(reversed_text.numpy()) 46 | upoints = featurizer.detokenize_unicode_points(indices) 47 | print(upoints.numpy()) 48 | 49 | 50 | def test_sp(): 51 | config_path = os.path.join(os.path.dirname(__file__), "..", "examples", "configs", "librispeech", "sentencepiece", "sp.yml.j2") 52 | config = file_util.load_yaml(config_path, repodir=repodir) 53 | decoder_config = DecoderConfig(config["decoder_config"]) 54 | featurizer = SentencePieceTokenizer(decoder_config=decoder_config) 55 | print(featurizer.num_classes) 56 | print(text) 57 | indices = featurizer.tokenize(text) 58 | print(indices) 59 | indices = list(indices.numpy()) 60 | indices += [0, 0] 61 | batch_indices = tf.stack([indices, indices], axis=0) 62 | reversed_text = featurizer.detokenize(batch_indices) 63 | print(reversed_text.numpy()) 64 | upoints = featurizer.detokenize_unicode_points(indices) 65 | print(upoints.numpy()) 66 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tensorflow_asr import tf 4 | from tensorflow_asr.utils import file_util, math_util 5 | 6 | 7 | def test_load_yaml(): 8 | a = file_util.load_yaml(f"{os.path.dirname(__file__)}/../examples/conformer/config_wp.yml") 9 | print(a) 10 | 11 | 12 | def test_mask_fill(): 13 | a = math_util.masked_fill( 14 | tf.convert_to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], tf.float32), 15 | [[True, True, True], [True, False, True], [False, True, True]], 16 | value=-1e9, 17 | ) 18 | print(a.numpy()) 19 | 20 | 21 | def test_dataset(): 22 | a = [1, 2, 3, 4, 5, 6, 7] 23 | batch = 2 24 | ds = tf.data.Dataset.from_tensor_slices(a) 25 | ds = ds.cache() 26 | ds = ds.shuffle(3) 27 | ds = ds.repeat(3) 28 | ds = ds.batch(batch, drop_remainder=True) 29 | print(list(ds.as_numpy_iterator())) 30 | 31 | 32 | def test_split_batch(): 33 | a = tf.ones((12, 2, 4), tf.float32) 34 | b = math_util.split_tensor_by_ga(a, 4, 3) 35 | print(b) 36 | --------------------------------------------------------------------------------
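A final sketch for `shape_list` from `utils/shape_util.py` earlier in this listing: it mixes static dimensions (plain Python ints) with dynamic ones (scalar tensors), which keeps shape arithmetic working inside `tf.function` traces with unknown batch or time sizes:

```python
import tensorflow as tf

from tensorflow_asr.utils.shape_util import shape_list

@tf.function(input_signature=[tf.TensorSpec([None, None, 80], tf.float32)])
def flatten_frames(x):
    batch, time, feat = shape_list(x)  # batch/time are scalar tensors, feat is the static int 80
    return tf.reshape(x, [batch * time, feat])

print(flatten_frames(tf.zeros([2, 3, 80])).shape)  # (6, 80)
```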