├── .clang-format
├── .flake8
├── .gitignore
├── .pylintrc
├── .style.yapf
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE
├── README.md
├── athena
│   ├── __init__.py
│   ├── cmvn_main.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── language_set.py
│   │   │   ├── speech_recognition.py
│   │   │   ├── speech_recognition_test.py
│   │   │   └── speech_set.py
│   │   ├── feature_normalizer.py
│   │   └── text_featurizer.py
│   ├── decode_main.py
│   ├── horovod_main.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── commons.py
│   │   ├── functional.py
│   │   └── transformer.py
│   ├── loss.py
│   ├── main.py
│   ├── metrics.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── customized.py
│   │   ├── deep_speech.py
│   │   ├── masked_pc.py
│   │   ├── mtl_seq2seq.py
│   │   ├── rnn_lm.py
│   │   └── speech_transformer.py
│   ├── solver.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── beam_search.py
│   │   ├── ctc_scorer.py
│   │   ├── lm_scorer.py
│   │   └── sph2pipe
│   ├── transform
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── audio_featurizer.py
│   │   └── feats
│   │       ├── __init__.py
│   │       ├── base_frontend.py
│   │       ├── cmvn.py
│   │       ├── cmvn_test.py
│   │       ├── fbank.py
│   │       ├── fbank_pitch.py
│   │       ├── fbank_pitch_test.py
│   │       ├── fbank_test.py
│   │       ├── framepow.py
│   │       ├── framepow_test.py
│   │       ├── mfcc.py
│   │       ├── mfcc_test.py
│   │       ├── ops
│   │       │   ├── Makefile
│   │       │   ├── __init__.py
│   │       │   ├── kernels
│   │       │   │   ├── complex_defines.h
│   │       │   │   ├── delta_delta.cc
│   │       │   │   ├── delta_delta.h
│   │       │   │   ├── delta_delta_op.cc
│   │       │   │   ├── delta_delta_op_test.py
│   │       │   │   ├── fbank.cc
│   │       │   │   ├── fbank.h
│   │       │   │   ├── fbank_op.cc
│   │       │   │   ├── fbank_op_test.py
│   │       │   │   ├── framepow.cc
│   │       │   │   ├── framepow.h
│   │       │   │   ├── framepow_op.cc
│   │       │   │   ├── mfcc_dct.cc
│   │       │   │   ├── mfcc_dct.h
│   │       │   │   ├── mfcc_dct_op.cc
│   │       │   │   ├── mfcc_mel_filterbank.cc
│   │       │   │   ├── mfcc_mel_filterbank.h
│   │       │   │   ├── pitch.cc
│   │       │   │   ├── pitch.h
│   │       │   │   ├── pitch_op.cc
│   │       │   │   ├── resample.cc
│   │       │   │   ├── resample.h
│   │       │   │   ├── spectrum.cc
│   │       │   │   ├── spectrum.h
│   │       │   │   ├── spectrum_op.cc
│   │       │   │   ├── spectrum_op_test.py
│   │       │   │   ├── speed_op.cc
│   │       │   │   ├── support_functions.cc
│   │       │   │   ├── support_functions.h
│   │       │   │   └── x_ops.cc
│   │       │   └── py_x_ops.py
│   │       ├── pitch.py
│   │       ├── pitch_test.py
│   │       ├── read_wav.py
│   │       ├── read_wav_test.py
│   │       ├── spectrum.py
│   │       ├── spectrum_test.py
│   │       ├── write_wav.py
│   │       └── write_wav_test.py
│   └── utils
│       ├── __init__.py
│       ├── checkpoint.py
│       ├── data_queue.py
│       ├── hparam.py
│       ├── hparam_test.py
│       ├── learning_rate.py
│       ├── metric_check.py
│       ├── misc.py
│       └── vocabs
│           ├── ch-en.vocab
│           ├── ch.vocab
│           └── en.vocab
├── docs
│   ├── README.md
│   ├── TheTrainningEfficiency.md
│   ├── development
│   │   └── contributing.md
│   ├── transform
│   │   ├── img
│   │   │   ├── DCT.png
│   │   │   ├── DFT.png
│   │   │   ├── MFCC.png
│   │   │   ├── MelBank.png
│   │   │   ├── Mel_filter.png
│   │   │   ├── amplitude_spectrum.png
│   │   │   ├── audio_data.png
│   │   │   ├── delta.png
│   │   │   ├── fbank.png
│   │   │   ├── hamming.png
│   │   │   ├── logMel.png
│   │   │   ├── logpower_spectrum.png
│   │   │   ├── mel_freq.png
│   │   │   ├── melbanks.png
│   │   │   ├── phase_spectrum.png
│   │   │   ├── power_spectrum.png
│   │   │   ├── spectrum.png
│   │   │   ├── spectrum_emph.png
│   │   │   └── spectrum_orig.png
│   │   ├── speech_feature.md
│   │   └── user_manual.md
│   └── using_docker.md
├── examples
│   ├── asr
│   │   ├── README.md
│   │   ├── aidatatang_200zh
│   │   │   └── local
│   │   │       └── prepare_data.py
│   │   ├── aishell
│   │   │   ├── data
│   │   │   │   └── vocab
│   │   │   ├── local
│   │   │   │   └── prepare_data.py
│   │   │   ├── mtl_transformer.json
│   │   │   ├── mtl_transformer_sp.json
│   │   │   └── rnnlm.json
│   │   ├── hkust
│   │   │   ├── README.md
│   │   │   ├── data
│   │   │   │   └── vocab
│   │   │   ├── deep_speech.json
│   │   │   ├── local
│   │   │   │   ├── prepare_data.py
│   │   │   │   └── segment_word.py
│   │   │   ├── mpc.json
│   │   │   ├── mtl_transformer.json
│   │   │   ├── mtl_transformer_sp.json
│   │   │   ├── rnnlm.json
│   │   │   ├── run.sh
│   │   │   └── transformer.json
│   │   ├── librispeech
│   │   │   ├── data
│   │   │   │   └── librispeech_unigram5000.model
│   │   │   ├── prepare_data.py
│   │   │   └── transformer.json
│   │   ├── magic_data
│   │   │   └── local
│   │   │       └── prepare_data.py
│   │   ├── primewords
│   │   │   └── local
│   │   │       └── prepare_data.py
│   │   └── switchboard_fisher
│   │       └── prepare_data.py
│   └── translate
│       └── spa-eng-example
│           └── prepare_data.py
├── requirements.txt
├── setup.py
└── tools
    ├── env.sh
    ├── install.sh
    ├── install_kenlm.sh
    └── install_sph2pipe.sh
--------------------------------------------------------------------------------
/.clang-format:
--------------------------------------------------------------------------------
BasedOnStyle: Google
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
# .flake8
#
# DESCRIPTION
#     Configuration file for the python linter flake8.
#
#     This configuration is based on the generic
#     configuration published on GitHub.
#
# AUTHOR    krnd
# VERSION   v1.0
#
# SEE ALSO
#     http://flake8.pycqa.org/en/latest/user/options.html
#     http://flake8.pycqa.org/en/latest/user/error-codes.html
#     http://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes
#     http://gist.github.com/krnd
#

[flake8]

# Specify the number of subprocesses that Flake8 will use to run checks in parallel.
jobs = auto

# Increase the verbosity of Flake8’s output.
verbose = 0

# Decrease the verbosity of Flake8’s output.
quiet = 0

# Select the formatter used to display errors to the user.
format = default

# Print the total number of errors.
count = True

# Print the source code generating the error/warning in question.
show-source = True

# Count the number of occurrences of each error/warning code and print a report.
statistics = True


# Redirect all output to the specified file.
output-file = /tmp/flake8.log

# Also print output to stdout if output-file has been configured.
tee = True

# Provide a comma-separated list of glob patterns to exclude from checks.
exclude =
    # git folder
    .git,
    # python cache
    __pycache__,
    tools

# Provide a comma-separated list of glob patterns to include for checks.
filename = *.py

# Provide a custom list of builtin functions, objects, names, etc.
builtins =

# Report all errors, even if it is on the same line as a `# NOQA` comment.
disable-noqa = False

# Set the maximum length that any line (with some exceptions) may be.
max-line-length = 100

# Set the maximum allowed McCabe complexity value for a block of code.
max-complexity = 10

# Toggle whether pycodestyle should enforce matching the indentation of the opening bracket’s line.
# influences E131 and E133
hang-closing = True


# ERROR CODES
#
# E/W - PEP8 errors/warnings (pycodestyle)
# F   - linting errors (pyflakes)
# C   - McCabe complexity error (mccabe)
#
# W503 - line break before binary operator

# Specify a list of codes to ignore.
ignore = W503

# Specify the list of error codes you wish Flake8 to report.
select = E9,F6,F81,F82,F83,F9

# Enable off-by-default extensions.
enable-extensions =

# Enable PyFlakes syntax checking of doctests in docstrings.
doctests = True

# Specify which files are checked by PyFlakes for doctest syntax.
include-in-doctest =

# Specify which files are not to be checked by PyFlakes for doctest syntax.
exclude-in-doctest =
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea
__pycache__
*.pyc
tags
.vscode

/tools/sph2pipe_v2.5.tar.gz
/tools/sph2pipe_v2.5
--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
[style]
based_on_style = chromium
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# Contribution Guideline

Thanks for considering contributing to this project. All issues and pull requests are highly appreciated.

## Pull Requests

Before sending a pull request to this project, please read and follow the guidelines below.

1. Branch: We accept pull requests on the `master` branch.
2. Coding style: Follow the coding style used in Athena.
3. Commit message: Use English and check your spelling.
4. Test: Make sure to test your code.

Add the device model, API version, related logs, screenshots, and other related information to your pull request if possible.

NOTE: We assume all your contributions can be licensed under the [Apache License 2.0](https://github.com/didichuxing/athena/tree/master/LICENSE).

## Issues

We love clearly described issues. :)

The following information can help us resolve an issue faster.

* Device model and hardware information.
* API version.
* Logs.
* Screenshots.
* Steps to reproduce the issue.

## Coding Styles

Please follow the coding styles described [here](https://git.xiaojukeji.com/speech-am/athena/tree/master/docs/development/contributing.md).
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM tensorflow/tensorflow:2.1.0-gpu-py3

ENV CUDNN_VERSION=7.6.5.32-1+cuda10.1
ENV NCCL_VERSION=2.4.8-1+cuda10.1

# Set default shell to /bin/bash
SHELL ["/bin/bash", "-cu"]

RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends \
        wget \
        vim \
        git \
        libcudnn7=${CUDNN_VERSION} \
        libnccl2=${NCCL_VERSION} \
        libnccl-dev=${NCCL_VERSION}

# Install Open MPI
RUN mkdir /tmp/openmpi && \
    cd /tmp/openmpi && \
    wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \
    tar zxf openmpi-4.0.0.tar.gz && \
    cd openmpi-4.0.0 && \
    ./configure --enable-orterun-prefix-by-default && \
    make -j $(nproc) all && \
    make install && \
    ldconfig && \
    rm -rf /tmp/openmpi

# Install Horovod, temporarily using CUDA stubs. The build flags must be set
# on the same command as pip, otherwise they never reach the Horovod build.
RUN ldconfig /usr/local/cuda/targets/x86_64-linux/lib/stubs && \
    HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_GPU_BROADCAST=NCCL HOROVOD_WITH_TENSORFLOW=1 \
    pip --default-timeout=1000 install --no-cache-dir git+https://github.com/horovod/horovod && \
    ldconfig

# Install OpenSSH for MPI to communicate between containers
RUN apt-get install -y --no-install-recommends openssh-client openssh-server && \
    mkdir -p /var/run/sshd

# Allow OpenSSH to talk to containers without asking for confirmation
RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \
    echo "    StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \
    mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config

RUN pip --default-timeout=1000 install sox \
    absl-py \
    yapf \
    pylint \
    flake8 \
    tqdm \
    sentencepiece \
    librosa \
    kenlm \
    pandas \
    jieba

# Install Athena
RUN git clone https://github.com/didi/athena.git /athena && \
    cd /athena && python setup.py bdist_wheel && \
    python -m pip install --ignore-installed dist/athena-0.1.0*.whl
--------------------------------------------------------------------------------
/athena/__init__.py:
--------------------------------------------------------------------------------
# Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" module """
# data
from .data import SpeechRecognitionDatasetBuilder
from .data import SpeechDatasetBuilder
from .data import LanguageDatasetBuilder
from .data import FeatureNormalizer
from .data.text_featurizer import TextFeaturizer

# layers
from .layers.functional import make_positional_encoding
from .layers.functional import collapse4d
from .layers.functional import gelu
from .layers.commons import PositionalEncoding
from .layers.commons import Collapse4D
from .layers.commons import TdnnLayer
from .layers.commons import Gelu
from .layers.attention import MultiHeadAttention
from .layers.attention import BahdanauAttention
from .layers.attention import HanAttention
from .layers.attention import MatchAttention
from .layers.transformer import Transformer
from .layers.transformer import TransformerEncoder
from .layers.transformer import TransformerDecoder
from .layers.transformer import TransformerEncoderLayer
from .layers.transformer import TransformerDecoderLayer

# models
from .models.base import BaseModel
from .models.speech_transformer import SpeechTransformer, SpeechTransformer2
from .models.masked_pc import MaskedPredictCoding
from .models.deep_speech import DeepSpeechModel
from .models.mtl_seq2seq import MtlTransformerCtc
from .models.rnn_lm import RNNLM

# solver & loss & accuracy
from .solver import BaseSolver
from .solver import HorovodSolver
from .solver import DecoderSolver
from .loss import CTCLoss
from .loss import Seq2SeqSparseCategoricalCrossentropy
from .metrics import CTCAccuracy
from .metrics import Seq2SeqSparseCategoricalAccuracy

# utils
from .utils.checkpoint import Checkpoint
from .utils.learning_rate import WarmUpLearningSchedule, WarmUpAdam
from .utils.learning_rate import (
    ExponentialDecayLearningRateSchedule,
    ExponentialDecayAdam,
)
from .utils.hparam import HParams, register_and_parse_hparams
from .utils.misc import generate_square_subsequent_mask
from .utils.misc import get_wave_file_length
from .utils.misc import set_default_summary_writer

# tools
from .tools.beam_search import BeamSearchDecoder
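--------------------------------------------------------------------------------
Usage sketch (not a repository file): the athena package API
--------------------------------------------------------------------------------
The package root above re-exports the main building blocks, so downstream
scripts can import everything from `athena` directly. A minimal sketch; the
dataset config mirrors athena/data/datasets/speech_recognition_test.py rather
than a recommended setup.

import tensorflow as tf
from athena import SpeechRecognitionDatasetBuilder, CTCLoss

# Build a dataset builder and a loss object from the package root.
dataset_builder = SpeechRecognitionDatasetBuilder({
    "audio_config": {"type": "Fbank", "filterbank_channel_count": 40},
    "vocab_file": "examples/asr/hkust/data/vocab",
})
ctc_loss = CTCLoss(blank_index=-1)  # the last class index acts as the CTC blank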
--------------------------------------------------------------------------------
/athena/cmvn_main.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright (C) ATHENA AUTHORS
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Only supports TensorFlow 2.0
# pylint: disable=invalid-name, no-member
r""" entry point for computing the CMVN statistics of a dataset """
import sys
import json
import tensorflow as tf
from absl import logging
from athena.main import parse_config, SUPPORTED_DATASET_BUILDER

if __name__ == "__main__":
    logging.set_verbosity(logging.INFO)
    if len(sys.argv) < 3:
        logging.warning('Usage: python {} config_json_file data_csv_file'.format(sys.argv[0]))
        sys.exit()
    tf.random.set_seed(1)

    jsonfile = sys.argv[1]
    with open(jsonfile) as file:
        config = json.load(file)
    p = parse_config(config)
    if "speed_permutation" in p.dataset_config:
        p.dataset_config['speed_permutation'] = [1.0]
    csv_file = sys.argv[2]
    dataset_builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](p.dataset_config)
    dataset_builder.load_csv(csv_file).compute_cmvn_if_necessary(True)
--------------------------------------------------------------------------------
/athena/data/__init__.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright (C) ATHENA AUTHORS
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" data """
from .datasets.speech_recognition import SpeechRecognitionDatasetBuilder
from .datasets.speech_set import SpeechDatasetBuilder
from .datasets.language_set import LanguageDatasetBuilder
from .feature_normalizer import FeatureNormalizer
from .text_featurizer import TextFeaturizer, SentencePieceFeaturizer
--------------------------------------------------------------------------------
/athena/data/datasets/__init__.py:
--------------------------------------------------------------------------------
# Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" data.datasets """
--------------------------------------------------------------------------------
/athena/data/datasets/speech_recognition_test.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright (C) 2019 ATHENA AUTHORS; Xiangang Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=no-member, invalid-name
""" audio dataset """
import time
import tqdm
from absl import logging
from athena import SpeechRecognitionDatasetBuilder

def test():
    ''' test the speed of the dataset '''
    data_csv = "/tmp-data/dataset/opensource/hkust/train.csv"
    dataset_builder = SpeechRecognitionDatasetBuilder(
        {
            "audio_config": {
                "type": "Fbank",
                "filterbank_channel_count": 40,
                "sample_rate": 8000,
                "local_cmvn": False,
            },
            "speed_permutation": [0.9, 1.0],
            "vocab_file": "examples/asr/hkust/data/vocab"
        }
    )
    dataset = dataset_builder.load_csv(data_csv).as_dataset(16, 4)
    start = time.time()
    for _ in tqdm.tqdm(dataset, total=len(dataset_builder)//16):
        pass
    logging.info(time.time() - start)


if __name__ == '__main__':
    logging.set_verbosity(logging.INFO)
    test()
--------------------------------------------------------------------------------
/athena/decode_main.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright (C) ATHENA AUTHORS
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Only supports TensorFlow 2.0
# pylint: disable=invalid-name, no-member
r""" decoding entry point: restores the best checkpoint and decodes the test set """
import sys
import json
import tensorflow as tf
from absl import logging
from athena import DecoderSolver
from athena.main import (
    parse_config,
    build_model_from_jsonfile
)


def decode(jsonfile):
    """ entry point for model decoding, do some preparation work """
    p, model, _, checkpointer, dataset_builder = build_model_from_jsonfile(jsonfile, 0)
    checkpointer.restore_from_best()
    solver = DecoderSolver(model, config=p.decode_config)
    dataset_builder = dataset_builder.load_csv(p.test_csv).compute_cmvn_if_necessary(True)
    solver.decode(dataset_builder.as_dataset(batch_size=1))


if __name__ == "__main__":
    logging.set_verbosity(logging.INFO)
    tf.random.set_seed(1)

    JSON_FILE = sys.argv[1]
    CONFIG = None
    with open(JSON_FILE) as f:
        CONFIG = json.load(f)
    PARAMS = parse_config(CONFIG)
    DecoderSolver.initialize_devices(PARAMS.solver_gpu)
    decode(JSON_FILE)
--------------------------------------------------------------------------------
/athena/horovod_main.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright (C) 2019 ATHENA AUTHORS; Xiangang Li
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Only supports TensorFlow 2.0
# pylint: disable=invalid-name, no-member
r""" training entry point for multi-GPU training with Horovod """
import sys
import json
import tensorflow as tf
import horovod.tensorflow as hvd
from absl import logging
from athena import HorovodSolver
from athena.main import parse_config, train

if __name__ == "__main__":
    logging.set_verbosity(logging.INFO)
    tf.random.set_seed(1)

    JSON_FILE = sys.argv[1]
    CONFIG = None
    with open(JSON_FILE) as f:
        CONFIG = json.load(f)
    PARAMS = parse_config(CONFIG)
    HorovodSolver.initialize_devices()
    train(JSON_FILE, HorovodSolver, hvd.size(), hvd.local_rank())
--------------------------------------------------------------------------------
/athena/layers/__init__.py:
--------------------------------------------------------------------------------
# Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" module """
--------------------------------------------------------------------------------
/athena/layers/commons.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright (C) 2019 ATHENA AUTHORS; Xiangang Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=too-few-public-methods, invalid-name
# pylint: disable=no-self-use, missing-function-docstring
"""Utils for common layers."""

import tensorflow as tf
from athena.layers.functional import make_positional_encoding, collapse4d, gelu
from athena.layers.functional import splice


class PositionalEncoding(tf.keras.layers.Layer):
    """ positional encoding that can be used in a transformer """

    def __init__(self, d_model, max_position=800, scale=False):
        super().__init__()
        self.d_model = d_model
        self.scale = scale
        self.pos_encoding = make_positional_encoding(max_position, d_model)

    def call(self, x):
        """ call function """
        seq_len = tf.shape(x)[1]
        if self.scale:
            x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]
        return x


class Collapse4D(tf.keras.layers.Layer):
    """ collapse4d can be used in cnn-lstm for speech processing
    reshape from [N T D C] -> [N T D*C]
    """

    def call(self, x):
        return collapse4d(x)


class Gelu(tf.keras.layers.Layer):
    """Gaussian Error Linear Unit.
    This is a smoother version of the RELU.
    Original paper: https://arxiv.org/abs/1606.08415
    Args:
        x: float Tensor to perform activation.
    Returns:
        `x` with the GELU activation applied.
    """

    def call(self, x):
        return gelu(x)


class TdnnLayer(tf.keras.layers.Layer):
    """ An implementation of a TDNN layer.
    Args:
        context: an int of left and right context, or
            a list of context indexes, e.g. (-2, 0, 2).
        output_dim: the dim of the linear transform
    """

    def __init__(self, context, output_dim, use_bias=False, **kwargs):
        super().__init__(**kwargs)

        if hasattr(context, "__iter__"):
            self.context_size = len(context)
            self.context_list = context
        else:
            self.context_size = context * 2 + 1
            self.context_list = range(-context, context + 1)

        self.output_dim = output_dim
        self.linear = tf.keras.layers.Dense(output_dim, use_bias=use_bias)

    def call(self, x, training=None, mask=None):
        x = splice(x, self.context_list)
        x = self.linear(x, training=training, mask=mask)
        return x


SUPPORTED_RNNS = {
    "lstm": tf.keras.layers.LSTMCell,
    "gru": tf.keras.layers.GRUCell,
    "cudnnlstm": tf.keras.layers.LSTMCell,
    "cudnngru": tf.keras.layers.GRUCell
}


ACTIVATIONS = {
    "relu": tf.nn.relu,
    "relu6": tf.nn.relu6,
    "elu": tf.nn.elu,
    "selu": tf.nn.selu,
    "gelu": gelu,
    "leaky_relu": tf.nn.leaky_relu,
    "sigmoid": tf.nn.sigmoid,
    "softplus": tf.nn.softplus,
    "softsign": tf.nn.softsign,
    "tanh": tf.nn.tanh,
}
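--------------------------------------------------------------------------------
Usage sketch (not a repository file): athena/layers/commons.py
--------------------------------------------------------------------------------
A minimal eager-mode sketch of PositionalEncoding and TdnnLayer; the batch
size, sequence length, and dimensions are illustrative assumptions.

import tensorflow as tf
from athena.layers.commons import PositionalEncoding, TdnnLayer

x = tf.random.normal([2, 20, 40])   # (batch, time, feature)

# Add transformer-style positional encodings to a 40-dim input.
pos = PositionalEncoding(d_model=40, scale=True)
y = pos(x)                          # shape unchanged: (2, 20, 40)

# A symmetric context of 2 splices 5 frames (5 * 40 = 200 dims per step)
# before the linear projection down to 64 dims.
tdnn = TdnnLayer(context=2, output_dim=64)
z = tdnn(x)                         # -> (2, 20, 64)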
--------------------------------------------------------------------------------
/athena/layers/functional.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright (C) 2019 ATHENA AUTHORS; Xiangang Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Utils for common layers."""

import numpy as np
import tensorflow as tf
from ..utils.misc import tensor_shape
from tensorflow.python.framework import ops


def make_positional_encoding(position, d_model):
    """ generate a positional encoding list """

    def get_angles(pos, i, d_model):
        angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
        return pos * angle_rates

    angle_rads = get_angles(
        np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model
    )
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[np.newaxis, ...]
    return tf.cast(pos_encoding, dtype=tf.float32)


def collapse4d(x, name=None):
    """ reshape from [N T D C] -> [N T D*C]
    using tf.shape(x), which generates a tensor instead of x.shape
    """
    with ops.name_scope(name, "collapse4d") as name:
        shape = tensor_shape(x)
        N = shape[0]
        T = shape[1]
        D = shape[2]
        C = shape[3]
        DC = D * C
        out = tf.reshape(x, [N, T, DC])
        return out


def splice(x, context):
    """
    Splice a tensor along the time dimension with context and concatenate
    the spliced frames along the last dimension.
    e.g.:
        t = [[[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]]]
        splice_tensor(t, [0, 1]) =
            [[[1, 2, 3, 4, 5, 6],
              [4, 5, 6, 7, 8, 9],
              [7, 8, 9, 7, 8, 9]]]

    Args:
        x: a tf.Tensor with shape (B, T, D)
        context: a list of context offsets

    Returns:
        spliced tensor with shape (..., D * len(context))
    """
    input_shape = tf.shape(x)
    B, T = input_shape[0], input_shape[1]
    context_len = len(context)
    array = tf.TensorArray(x.dtype, size=context_len)
    for idx, offset in enumerate(context):
        begin = offset
        end = T + offset
        if begin < 0:
            begin = 0
            sliced = x[:, begin:end, :]
            tiled = tf.tile(x[:, 0:1, :], [1, abs(offset), 1])
            final = tf.concat((tiled, sliced), axis=1)
        else:
            end = T
            sliced = x[:, begin:end, :]
            tiled = tf.tile(x[:, -1:, :], [1, abs(offset), 1])
            final = tf.concat((sliced, tiled), axis=1)
        array = array.write(idx, final)
    spliced = array.stack()
    spliced = tf.transpose(spliced, (1, 2, 0, 3))
    spliced = tf.reshape(spliced, (B, T, -1))
    return spliced


def gelu(x):
    """Gaussian Error Linear Unit.
    This is a smoother version of the RELU.
    Original paper: https://arxiv.org/abs/1606.08415
    Args:
        x: float Tensor to perform activation.
    Returns:
        `x` with the GELU activation applied.
    """
    cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
    return x * cdf
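--------------------------------------------------------------------------------
Usage sketch (not a repository file): splice from athena/layers/functional.py
--------------------------------------------------------------------------------
The docstring example above, run as code: with context [0, 1] every frame is
concatenated with its right neighbour, and the last frame is repeated at the
boundary.

import tensorflow as tf
from athena.layers.functional import splice

t = tf.constant([[[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]]])
out = splice(t, [0, 1])
# out: shape (1, 3, 6) ==
# [[[1, 2, 3, 4, 5, 6],
#   [4, 5, 6, 7, 8, 9],
#   [7, 8, 9, 7, 8, 9]]]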
--------------------------------------------------------------------------------
/athena/loss.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright (C) ATHENA AUTHORS
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Only supports eager mode and TF>=2.0.0
# pylint: disable=too-few-public-methods
""" some losses """
import tensorflow as tf
from .utils.misc import insert_eos_in_labels


class CTCLoss(tf.keras.losses.Loss):
    """ CTC LOSS
    CTC LOSS implemented with TensorFlow
    """

    def __init__(self, logits_time_major=False, blank_index=-1, name="CTCLoss"):
        super().__init__(name=name)
        self.logits_time_major = logits_time_major
        self.blank_index = blank_index
        self.need_logit_length = True

    def __call__(self, logits, samples, logit_length=None):
        assert logit_length is not None
        # use v2 ctc_loss
        ctc_loss = tf.nn.ctc_loss(
            labels=samples["output"],
            logits=logits,
            logit_length=logit_length,
            label_length=samples["output_length"],
            logits_time_major=self.logits_time_major,
            blank_index=self.blank_index,
        )
        return tf.reduce_mean(ctc_loss)


class Seq2SeqSparseCategoricalCrossentropy(tf.keras.losses.CategoricalCrossentropy):
    """ Seq2SeqSparseCategoricalCrossentropy LOSS
    CategoricalCrossentropy calculated at each character for each sequence in a batch
    """

    def __init__(self, num_classes, eos=-1, by_token=False, by_sequence=True,
                 from_logits=True, label_smoothing=0.0):
        super().__init__(from_logits=from_logits, label_smoothing=label_smoothing, reduction="none")
        self.by_token = by_token
        self.by_sequence = by_sequence
        self.num_classes = num_classes
        self.eos = num_classes + eos if eos < 0 else eos

    def __call__(self, logits, samples, logit_length=None):
        labels = insert_eos_in_labels(samples["output"], self.eos, samples["output_length"])
        mask = tf.math.logical_not(tf.math.equal(labels, 0))
        labels = tf.one_hot(indices=labels, depth=self.num_classes)
        seq_len = tf.shape(labels)[1]
        logits = logits[:, :seq_len, :]
        loss = self.call(labels, logits)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        if self.by_token:
            return tf.divide(tf.reduce_sum(loss), tf.reduce_sum(mask))
        if self.by_sequence:
            loss = tf.reduce_sum(loss, axis=-1)
        return tf.reduce_mean(loss)


class MPCLoss(tf.keras.losses.Loss):
    """MPC LOSS
    L1 loss for each masked acoustic feature in a batch
    """

    def __init__(self, name="MPCLoss"):
        super().__init__(name=name)

    def __call__(self, logits, samples, logit_length=None):
        target = samples["output"]
        shape = tf.shape(logits)
        target = tf.reshape(target, shape)
        loss = target - logits
        # mpc mask
        mask = tf.cast(tf.math.equal(tf.reshape(samples["input"], shape), 0), loss.dtype)
        loss *= mask
        # sequence length mask
        seq_mask = tf.sequence_mask(logit_length, shape[1], dtype=loss.dtype)
        seq_mask = tf.tile(seq_mask[:, :, tf.newaxis], [1, 1, shape[2]])
        loss *= seq_mask
        loss = tf.reduce_sum(tf.abs(loss, name="L1_loss"), 2)
        loss = tf.reduce_mean(loss)
        return loss
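--------------------------------------------------------------------------------
Usage sketch (not a repository file): CTCLoss from athena/loss.py
--------------------------------------------------------------------------------
A toy call of the CTC loss above on random logits; all sizes are illustrative
assumptions. `samples` follows the dict layout the dataset builders produce
(zero-padded "output" plus "output_length").

import tensorflow as tf
from athena.loss import CTCLoss

batch, frames, classes = 2, 50, 30
logits = tf.random.normal([batch, frames, classes])
samples = {
    "output": tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]]),  # zero-padded labels
    "output_length": tf.constant([3, 2]),
}
logit_length = tf.fill([batch], frames)

ctc = CTCLoss(blank_index=-1)               # class 29 acts as the CTC blank
loss = ctc(logits, samples, logit_length)   # scalar: mean loss over the batch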
--------------------------------------------------------------------------------
/athena/models/__init__.py:
--------------------------------------------------------------------------------
# Copyright (C) ATHENA AUTHORS
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" module """
--------------------------------------------------------------------------------
/athena/models/base.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright (C) ATHENA AUTHORS
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Only supports eager mode
# pylint: disable=useless-super-delegation, unused-argument, no-self-use

""" base model for models """
from absl import logging
import tensorflow as tf

class BaseModel(tf.keras.Model):
    """Base class for models."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.loss_function = None
        self.metric = None

    def call(self, samples, training=None):
        """ call model """
        raise NotImplementedError()

    #pylint: disable=not-callable
    def get_loss(self, logits, samples, training=None):
        """ get loss """
        logit_length = self.compute_logit_length(samples)
        if self.loss_function is None:
            loss = 0.0
        else:
            loss = self.loss_function(logits, samples, logit_length)
        if self.metric is None:
            metrics = {}
        else:
            self.metric(logits, samples, logit_length)
            metrics = {self.metric.name: self.metric.result()}
        return loss, metrics

    def compute_logit_length(self, samples):
        """ compute the logit length """
        return samples["input_length"]

    def reset_metrics(self):
        """ reset the metrics """
        if self.metric is not None:
            self.metric.reset_states()

    def prepare_samples(self, samples):
        """ for special data preparation;
        be careful: do not change the shape of samples
        """
        return samples

    def restore_from_pretrained_model(self, pretrained_model, model_type=""):
        """ restore from a pretrained model """
        logging.info("restore from pretrained model")

    def decode(self, samples, hparams):
        """ decode interface """
        logging.info("sorry, this model does not support decoding")
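--------------------------------------------------------------------------------
Usage sketch (not a repository file): subclassing BaseModel
--------------------------------------------------------------------------------
A deliberately small subclass showing the BaseModel contract: implement
call(), attach a loss_function and a metric, and get_loss() comes for free.
This sketch assumes 2-D (time, feature) inputs; DeepSpeechModel below is the
real example in the repository.

import tensorflow as tf
from athena import CTCLoss, CTCAccuracy
from athena.models.base import BaseModel

class TinyCTCModel(BaseModel):
    """ minimal CTC model: one hidden layer over per-frame features """

    def __init__(self, num_classes, sample_shape):
        super().__init__()
        self.loss_function = CTCLoss(blank_index=-1)
        self.metric = CTCAccuracy()
        layers = tf.keras.layers
        input_feature = layers.Input(shape=sample_shape["input"], dtype=tf.float32)
        inner = layers.Dense(256, activation="relu")(input_feature)
        inner = layers.Dense(num_classes + 1)(inner)   # +1 for the CTC blank
        self.net = tf.keras.Model(inputs=input_feature, outputs=inner)

    def call(self, samples, training=None):
        return self.net(samples["input"], training=training)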
--------------------------------------------------------------------------------
/athena/models/customized.py:
--------------------------------------------------------------------------------
# Copyright (C) 2019 ATHENA AUTHORS; Xiangang Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Only supports eager mode and TF>=2.0.0
# pylint: disable=no-member, invalid-name, relative-beyond-top-level
""" an implementation of a customized model """

import tensorflow as tf
from absl import logging
from athena import layers as athena_layers
from athena.utils.hparam import register_and_parse_hparams
from .base import BaseModel


def build_layer(input_data, layer_config):
    """
    Build one or multiple layers from a layer configuration.
    Args:
        layer_config: a list of multiple layers, or a dict of single-layer parameters.
    Returns: the output tensor after applying the configured layer(s).
    """
    if isinstance(layer_config, list):
        # Recursively build each layer.
        output = input_data
        for one_layer_config in layer_config:
            logging.info(f"Layer conf: {one_layer_config}")
            output = build_layer(output, one_layer_config)
        return output

    # Build one layer.
    for layer_type in layer_config:
        layer_args = layer_config[layer_type]
        logging.info(f"Final layer config: {layer_type}: {layer_args}")
        layer_cls = getattr(athena_layers, layer_type)
        layer = layer_cls(**layer_args)
        return layer(input_data)


class CustomizedModel(BaseModel):
    """ a simple customized model """
    default_config = {
        "topo": [{}]
    }

    def __init__(self, num_classes, sample_shape, config=None):
        super().__init__()
        self.hparams = register_and_parse_hparams(self.default_config, config)

        logging.info(f"Network topology config: {self.hparams.topo}")
        input_feature = tf.keras.layers.Input(shape=sample_shape["input"], dtype=tf.float32)
        inner = build_layer(input_feature, self.hparams.topo)
        inner = tf.keras.layers.Dense(num_classes)(inner)
        self.model = tf.keras.Model(inputs=input_feature, outputs=inner)
        logging.info(self.model.summary())

    def call(self, samples, training=None):
        return self.model(samples["input"], training=training)
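--------------------------------------------------------------------------------
Usage sketch (not a repository file): a "topo" config for CustomizedModel
--------------------------------------------------------------------------------
build_layer() above resolves each key via getattr on athena.layers, so any
class reachable there can appear in the topology. The layer names and kwargs
below are illustrative assumptions, not values from the repository;
construction is left commented out because it only works where those classes
resolve as attributes of athena.layers.

from athena.models.customized import CustomizedModel

config = {
    "topo": [
        {"TdnnLayer": {"context": 2, "output_dim": 256}},
        {"Gelu": {}},
        {"TdnnLayer": {"context": 1, "output_dim": 256}},
    ]
}
# model = CustomizedModel(num_classes=30, sample_shape={"input": (None, 40)},
#                         config=config)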
--------------------------------------------------------------------------------
/athena/models/deep_speech.py:
--------------------------------------------------------------------------------
# Copyright (C) 2019 ATHENA AUTHORS; Xiangang Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Only supports eager mode and TF>=2.0.0
# pylint: disable=no-member, invalid-name, relative-beyond-top-level
# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes
""" an implementation of the Deep Speech 2 model; it can be used as a sample CTC model """

import tensorflow as tf
from absl import logging
from ..utils.hparam import register_and_parse_hparams
from .base import BaseModel
from ..loss import CTCLoss
from ..metrics import CTCAccuracy
from ..layers.commons import SUPPORTED_RNNS


class DeepSpeechModel(BaseModel):
    """ a sample implementation of a CTC model """
    default_config = {
        "conv_filters": 256,
        "rnn_hidden_size": 1024,
        "num_rnn_layers": 6,
        "rnn_type": "gru"
    }

    def __init__(self, num_classes, sample_shape, config=None):
        super().__init__()
        self.num_classes = num_classes + 1
        self.loss_function = CTCLoss(blank_index=-1)
        self.metric = CTCAccuracy()
        self.hparams = register_and_parse_hparams(self.default_config, config, cls=self.__class__)

        layers = tf.keras.layers
        input_feature = layers.Input(shape=sample_shape["input"], dtype=tf.float32)
        inner = layers.Conv2D(
            filters=self.hparams.conv_filters,
            kernel_size=(41, 11),
            strides=(2, 2),
            padding="same",
            use_bias=False,
        )(input_feature)
        inner = layers.BatchNormalization()(inner)
        inner = tf.nn.relu6(inner)
        inner = layers.Conv2D(
            filters=self.hparams.conv_filters,
            kernel_size=(21, 11),
            strides=(2, 1),
            padding="same",
            use_bias=False,
        )(inner)
        inner = layers.BatchNormalization()(inner)
        inner = tf.nn.relu6(inner)
        _, _, dim, channels = inner.get_shape().as_list()
        output_dim = dim * channels
        inner = layers.Reshape((-1, output_dim))(inner)
        rnn_type = self.hparams.rnn_type
        rnn_hidden_size = self.hparams.rnn_hidden_size

        for _ in range(self.hparams.num_rnn_layers):
            inner = tf.keras.layers.RNN(
                cell=[SUPPORTED_RNNS[rnn_type](rnn_hidden_size)],
                return_sequences=True
            )(inner)
            inner = layers.BatchNormalization()(inner)
        inner = layers.Dense(rnn_hidden_size, activation=tf.nn.relu6)(inner)
        inner = layers.Dense(self.num_classes)(inner)
        self.net = tf.keras.Model(inputs=input_feature, outputs=inner)
        logging.info(self.net.summary())

    def call(self, samples, training=None):
        """ call function """
        return self.net(samples["input"], training=training)

    def compute_logit_length(self, samples):
        """ used to get the logit length """
        input_length = tf.cast(samples["input_length"], tf.float32)
        logit_length = tf.math.ceil(input_length / 2)
        logit_length = tf.math.ceil(logit_length / 2)
        logit_length = tf.cast(logit_length, tf.int32)
        return logit_length
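--------------------------------------------------------------------------------
Worked example (not a repository file): DeepSpeechModel logit lengths
--------------------------------------------------------------------------------
compute_logit_length() above mirrors the two stride-2 time axes of the conv
stack: a 100-frame utterance yields ceil(100 / 2) = 50, then
ceil(50 / 2) = 25 logit frames. The same arithmetic as a sketch:

import tensorflow as tf

input_length = tf.constant([100.0, 75.0])
after_conv1 = tf.math.ceil(input_length / 2)    # [50., 38.]
logit_length = tf.math.ceil(after_conv1 / 2)    # [25., 19.]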
--------------------------------------------------------------------------------
/athena/models/rnn_lm.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright (C) 2019 ATHENA AUTHORS; Xiangang Li; Xiaoning Lei
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Only supports eager mode
# pylint: disable=no-member, invalid-name
""" RNN language model implementation"""

import tensorflow as tf
from .base import BaseModel
from ..utils.misc import insert_eos_in_labels, insert_sos_in_labels
from ..utils.hparam import register_and_parse_hparams
from ..layers.commons import SUPPORTED_RNNS

class RNNLM(BaseModel):
    """Standard implementation of an RNNLM. The model mainly consists of an
    embedding layer, RNN layers (with dropout), and a fully connected layer,
    which are all included in self.rnnlm
    """
    default_config = {
        "d_model": 512,        # the dim of the model
        "rnn_type": 'lstm',    # the supported rnn type
        "num_layer": 2,        # the number of rnn layers
        "dropout_rate": 0.1,   # dropout for the model
        "sos": -1,             # sos can be -1 or -2
        "eos": -1              # eos can be -1 or -2
    }

    def __init__(self, num_classes, sample_shape, config=None):
        """ config includes the params for building the lm """
        super(RNNLM, self).__init__()
        p = register_and_parse_hparams(self.default_config, config)
        self.num_classes = (
            num_classes + 1
            if p.sos == p.eos
            else num_classes + 2
        )
        self.sos = self.num_classes + p.sos
        self.eos = self.num_classes + p.eos
        self.metric = tf.keras.metrics.Mean(name="AverageLoss")

        layers = tf.keras.layers
        input_features = layers.Input(shape=sample_shape["output"], dtype=tf.int32)
        inner = tf.keras.layers.Embedding(self.num_classes, p.d_model)(input_features)
        for _ in range(p.num_layer):
            inner = tf.keras.layers.Dropout(p.dropout_rate)(inner)
            inner = tf.keras.layers.RNN(
                cell=[SUPPORTED_RNNS[p.rnn_type](p.d_model)],
                return_sequences=True
            )(inner)
        inner = tf.keras.layers.Dropout(p.dropout_rate)(inner)
        inner = tf.keras.layers.Dense(self.num_classes)(inner)
        self.rnnlm = tf.keras.Model(inputs=input_features, outputs=inner)

    def call(self, samples, training: bool = None):
        x = insert_sos_in_labels(samples['input'], self.sos)
        return self.rnnlm(x, training=training)

    def save_model(self, path):
        """
        for saving the model and current weights; path is an h5 file name, like 'my_model.h5'
        usage:
            new_model = tf.keras.models.load_model(path)
        """
        self.rnnlm.save(path)

    def get_loss(self, logits, samples, training=None):
        """ get loss """
        labels = samples['output']
        labels = insert_eos_in_labels(labels, self.eos, samples['output_length'])
        labels = tf.one_hot(labels, self.num_classes)
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
        n_token = tf.cast(tf.reduce_sum(samples['output_length'] + 1), tf.float32)
        self.metric.update_state(loss)
        metrics = {self.metric.name: self.metric.result()}
        return tf.reduce_sum(loss) / n_token, metrics
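--------------------------------------------------------------------------------
Usage sketch (not a repository file): RNNLM
--------------------------------------------------------------------------------
A toy forward/loss pass through the RNNLM above. The tiny vocabulary and the
sample_shape are illustrative assumptions; with the default sos == eos, one
extra class is appended, so the logits have num_classes + 1 outputs.

import tensorflow as tf
from athena.models.rnn_lm import RNNLM

lm = RNNLM(num_classes=100, sample_shape={"output": [None]},
           config={"d_model": 32, "num_layer": 1})

samples = {
    "input": tf.constant([[5, 6, 7]]),    # token ids
    "output": tf.constant([[5, 6, 7]]),   # LM targets
    "output_length": tf.constant([3]),
}
logits = lm(samples, training=False)      # (1, 4, 101): sos is prepended
loss, metrics = lm.get_loss(logits, samples)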
--------------------------------------------------------------------------------
/athena/tools/__init__.py:
--------------------------------------------------------------------------------
# Copyright (C) ATHENA AUTHORS
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" tools """
--------------------------------------------------------------------------------
/athena/tools/lm_scorer.py:
--------------------------------------------------------------------------------
import numpy as np
import kenlm


class NGramScorer(object):
    """
    KenLM n-gram language model scorer
    """

    def __init__(self, lm_path, sos, eos, num_syms, lm_weight=0.1):
        """
        Basic params will be initialized; the kenlm model will be created from
        the lm_path
        Args:
            lm_path: the saved lm model path
            sos: start symbol
            eos: end symbol
            num_syms: number of classes
            lm_weight: the lm weight
        """
        self.lang_model = kenlm.Model(lm_path)
        self.state_index = 0
        self.sos = sos
        self.eos = eos
        self.num_syms = num_syms
        self.lm_weight = lm_weight
        kenlm_state = kenlm.State()
        self.lang_model.BeginSentenceWrite(kenlm_state)
        self.cand_kenlm_states = np.array([[kenlm_state] * num_syms])

    def score(self, candidate_holder, new_scores):
        """
        Call this function to compute the NGram score of the next prediction
        based on historical predictions; the scoring function shares a common interface
        Args:
            candidate_holder: holds the current beam candidates
                (cand_seqs, cand_parents and cand_states)
        Returns:
            score: the NGram weighted score
            cand_states: the unchanged candidate states
        """
        cand_seqs = candidate_holder.cand_seqs
        cand_parents = candidate_holder.cand_parents
        cand_syms = cand_seqs[:, -1]
        score = self.get_score(cand_parents, cand_syms, self.lang_model)
        score = self.lm_weight * score
        return score, candidate_holder.cand_states

    def get_score(self, cand_parents, cand_syms, lang_model):
        """
        the saved lm model will be called here
        Args:
            cand_parents: last selected top candidates
            cand_syms: last selected top char index
            lang_model: the language model
        Returns:
            scores: the lm scores
        """
        scale = 1.0 / np.log10(np.e)  # convert log10 to ln

        num_cands = len(cand_syms)
        scores = np.zeros((num_cands, self.num_syms))
        new_states = np.zeros((num_cands, self.num_syms), dtype=object)
        chars = [str(x) for x in range(self.num_syms)]
        chars[self.sos] = "<s>"
        chars[self.eos] = "</s>"
        chars[0] = "<space>"

        for i in range(num_cands):
            parent = cand_parents[i]
            kenlm_state_list = self.cand_kenlm_states[parent]
            kenlm_state = kenlm_state_list[cand_syms[i]]
            for sym in range(self.num_syms):
                char = chars[sym]
                out_state = kenlm.State()
                score = scale * lang_model.BaseScore(kenlm_state, char, out_state)
                scores[i, sym] = score
                new_states[i, sym] = out_state
        self.cand_kenlm_states = new_states
        return scores
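--------------------------------------------------------------------------------
Usage sketch (not a repository file): NGramScorer
--------------------------------------------------------------------------------
Constructing the KenLM scorer above. "lm.bin" is a hypothetical model path,
and sos/eos/num_syms must match the acoustic model's vocabulary; the values
here are assumptions for illustration.

from athena.tools.lm_scorer import NGramScorer

# hypothetical KenLM binary and vocabulary sizes
scorer = NGramScorer("lm.bin", sos=4232, eos=4232, num_syms=4233, lm_weight=0.3)
# During beam search the decoder calls the shared scoring interface:
#     score, cand_states = scorer.score(candidate_holder, new_scores)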
--------------------------------------------------------------------------------
/athena/tools/sph2pipe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/athena/tools/sph2pipe
--------------------------------------------------------------------------------
/athena/transform/__init__.py:
--------------------------------------------------------------------------------
# Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from athena.transform import audio_featurizer

from athena.transform.audio_featurizer import AudioFeaturizer
from athena.transform.feats.cmvn import compute_cmvn
from athena.transform.feats.read_wav import read_wav
--------------------------------------------------------------------------------
/athena/transform/audio_featurizer.py:
--------------------------------------------------------------------------------
# Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The module provides a general interface for feature extraction."""

import tensorflow as tf
from athena.transform import feats


class AudioFeaturizer:
    """
    Interface for audio feature extraction.
    """
    #pylint: disable=dangerous-default-value
    def __init__(self, config={"type": "Fbank"}):
        """init
        :param config: a dict of feature parameters
            'type': 'ReadWav', 'Fbank', 'Spectrum'
            The config for Fbank:
                'sample_rate' 16000
                'window_length' 0.025
                'frame_length' 0.010
                'upper_frequency_limit' 20
                'filterbank_channel_count' 40
            The config for Spectrum:
                'sample_rate' 16000
                'window_length' 0.025
                'frame_length' 0.010
                'output_type' 1
        """

        assert "type" in config

        self.name = config["type"]
        self.feat = getattr(feats, self.name).params(config).instantiate()

        if self.name != "ReadWav":
            self.read_wav = getattr(feats, "ReadWav").params(config).instantiate()

    #pylint:disable=invalid-name
    def __call__(self, audio=None, sr=None, speed=1.0):
        """extract features from audio data
        :param audio: audio data or an audio file
        :param sr: sample rate
        :return: feature
        """

        if audio is not None and not tf.is_tensor(audio):
            audio = tf.convert_to_tensor(audio)
        if sr is not None and not tf.is_tensor(sr):
            sr = tf.convert_to_tensor(sr)

        return self.__impl(audio, sr, speed)

    @tf.function
    def __impl(self, audio=None, sr=None, speed=1.0):
        """
        :param audio: audio data or an audio file, a tensor
        :param sr: sample rate, a tensor
        :return: feature
        """
        if self.name == "ReadWav" or self.name == "CMVN":
            return self.feat(audio, speed)
        elif audio.dtype is tf.string:
            audio_data, sr = self.read_wav(audio, speed)
            return self.feat(audio_data, sr)
        else:
            return self.feat(audio, sr)

    @property
    def dim(self):
        """return the dimension of the feature;
        if only ReadWav, return 1
        """
        return self.feat.dim()

    @property
    def num_channels(self):
        """return the number of channels of the feature"""
        return self.feat.num_channels()
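--------------------------------------------------------------------------------
Usage sketch (not a repository file): AudioFeaturizer
--------------------------------------------------------------------------------
Extracting fbank features with the interface above; "test.wav" is a
placeholder path. When a string tensor is passed, ReadWav runs first and its
output is fed to the configured frontend.

from athena.transform import AudioFeaturizer

featurizer = AudioFeaturizer({
    "type": "Fbank",
    "sample_rate": 16000,
    "filterbank_channel_count": 40,
})
feat = featurizer("test.wav")   # fbank features of the wav file
print(featurizer.dim, featurizer.num_channels)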
30 | :param config: 31 | 'type': 'ReadWav', 'Fbank', 'Spectrum' 32 | The config for fbank 33 | 'sample_rate' 16000 34 | 'window_length' 0.025 35 | 'frame_length' 0.010 36 | 'upper_frequency_limit' 20 37 | 'filterbank_channel_count' 40 38 | The config for Spectrum 39 | 'sample_rate' 16000 40 | 'window_length' 0.025 41 | 'frame_length' 0.010 42 | 'output_type' 1 43 | """ 44 | 45 | assert "type" in config 46 | 47 | self.name = config["type"] 48 | self.feat = getattr(feats, self.name).params(config).instantiate() 49 | 50 | if self.name != "ReadWav": 51 | self.read_wav = getattr(feats, "ReadWav").params(config).instantiate() 52 | 53 | #pylint:disable=invalid-name 54 | def __call__(self, audio=None, sr=None, speed=1.0): 55 | """extract feature from audio data 56 | :param audio: audio data or audio file 57 | :param sr: sample rate 58 | :return: feature 59 | """ 60 | 61 | if audio is not None and not tf.is_tensor(audio): 62 | audio = tf.convert_to_tensor(audio) 63 | if sr is not None and not tf.is_tensor(sr): 64 | sr = tf.convert_to_tensor(sr) 65 | 66 | return self.__impl(audio, sr, speed) 67 | 68 | @tf.function 69 | def __impl(self, audio=None, sr=None, speed=1.0): 70 | """ 71 | :param audio: audio data or audio file, a tensor 72 | :param sr: sample rate, a tensor 73 | :return: feature 74 | """ 75 | if self.name == "ReadWav" or self.name == "CMVN": 76 | return self.feat(audio, speed) 77 | elif audio.dtype is tf.string: 78 | audio_data, sr = self.read_wav(audio, speed) 79 | return self.feat(audio_data, sr) 80 | else: 81 | return self.feat(audio, sr) 82 | 83 | @property 84 | def dim(self): 85 | """return the dimension of the feature 86 | if the featurizer is only ReadWav, returns 1 87 | """ 88 | return self.feat.dim() 89 | 90 | @property 91 | def num_channels(self): 92 | """return the number of channels of the feature""" 93 | return self.feat.num_channels() 94 | -------------------------------------------------------------------------------- /athena/transform/feats/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # ============================================================================== 16 | from athena.transform.feats.read_wav import ReadWav 17 | from athena.transform.feats.spectrum import Spectrum 18 | from athena.transform.feats.framepow import Framepow 19 | from athena.transform.feats.pitch import Pitch 20 | from athena.transform.feats.mfcc import Mfcc 21 | from athena.transform.feats.write_wav import WriteWav 22 | from athena.transform.feats.fbank import Fbank 23 | from athena.transform.feats.cmvn import CMVN, compute_cmvn 24 | from athena.transform.feats.fbank_pitch import FbankPitch 25 | -------------------------------------------------------------------------------- /athena/transform/feats/base_frontend.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """ base interface of Frontend """ 17 | 18 | import abc 19 | import tensorflow as tf 20 | 21 | 22 | class ABCFrontend(metaclass=abc.ABCMeta): 23 | """ abstract of Frontend """ 24 | 25 | def __init__(self, config): 26 | raise NotImplementedError() 27 | 28 | @abc.abstractmethod 29 | def call(self, *args): 30 | """ implementation func """ 31 | raise NotImplementedError() 32 | 33 | 34 | class BaseFrontend(ABCFrontend): 35 | """ wrapper of abstract Frontend""" 36 | 37 | def __init__(self, config: dict): 38 | self._config = config 39 | 40 | @property 41 | def config(self): 42 | """ config property """ 43 | return self._config 44 | 45 | @classmethod 46 | def params(cls, config=None): 47 | """ set params """ 48 | raise NotImplementedError() 49 | 50 | def __call__(self, *args): 51 | """ call """ 52 | return self.call(*args) 53 | 54 | def dim(self): 55 | return 1 56 | 57 | def num_channels(self): 58 | return 1 59 | -------------------------------------------------------------------------------- /athena/transform/feats/cmvn.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # ============================================================================== 16 | """This module does CMVN on features.""" 17 | 18 | import numpy as np 19 | 20 | import tensorflow as tf 21 | 22 | from athena.utils.hparam import HParams 23 | from athena.transform.feats.base_frontend import BaseFrontend 24 | 25 | 26 | class CMVN(BaseFrontend): 27 | """ 28 | Do CMVN on features. 29 | """ 30 | def __init__(self, config: dict): 31 | super().__init__(config) 32 | 33 | self.global_cmvn = False 34 | if len(config.global_mean) > 1: 35 | self.global_cmvn = True 36 | 37 | @classmethod 38 | def params(cls, config=None): 39 | """ set params """ 40 | 41 | hparams = HParams(cls=cls) 42 | hparams.add_hparam("type", "CMVN") 43 | hparams.add_hparam("global_mean", [0.0]) 44 | hparams.add_hparam("global_variance", [1.0]) 45 | hparams.add_hparam("local_cmvn", False) 46 | 47 | if config is not None: 48 | hparams.parse(config, True) 49 | 50 | assert len(hparams.global_mean) == len( 51 | hparams.global_variance 52 | ), "Error: global_mean length {} is not equal to global_variance length {}".format( 53 | len(hparams.global_mean), len(hparams.global_variance) 54 | ) 55 | 56 | hparams.global_variance = (np.sqrt(hparams.global_variance) + 1e-6).tolist()  # store stddev + epsilon so call() can divide directly 57 | return hparams 58 | 59 | def call(self, audio_feature, speed=1.0): 60 | params = self.config 61 | if self.global_cmvn: 62 | audio_feature = ( 63 | audio_feature - params.global_mean 64 | ) / params.global_variance 65 | 66 | if params.local_cmvn: 67 | mean, var = tf.compat.v1.nn.moments(audio_feature, axes=0) 68 | audio_feature = (audio_feature - mean) / ( 69 | tf.compat.v1.math.sqrt(var) + 1e-6 70 | ) 71 | 72 | return audio_feature 73 | 74 | def dim(self): 75 | params = self.config 76 | return len(params.global_mean) 77 | 78 | 79 | def compute_cmvn(audio_feature, mean=None, variance=None, local_cmvn=False): 80 | if mean is not None: 81 | assert variance is not None 82 | audio_feature = (audio_feature - mean) / variance 83 | if local_cmvn: 84 | mean, var = tf.compat.v1.nn.moments(audio_feature, axes=0) 85 | audio_feature = (audio_feature - mean) / (tf.compat.v1.math.sqrt(var) + 1e-6) 86 | return audio_feature 87 | -------------------------------------------------------------------------------- /athena/transform/feats/cmvn_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """This module tests the CMVN OP.""" 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from athena.transform.feats.cmvn import CMVN 21 | 22 | 23 | class CMVNTest(tf.test.TestCase): 24 | """ 25 | CMVN test. 
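A minimal sketch of the math exercised here (CMVN.params() above pre-converts
global_variance into a standard deviation, so call() divides by it directly):

    normalized = (feat - global_mean) / (np.sqrt(global_variance) + 1e-6)

With local_cmvn set, the features are additionally normalized by their own
per-utterance mean and variance.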
26 | """ 27 | def test_cmvn(self): 28 | dim = 40 29 | cmvn = CMVN.params( 30 | {"global_mean": np.zeros(dim).tolist(), "global_variance": np.ones(dim).tolist()} 31 | ).instantiate() 32 | audio_feature = tf.random.uniform(shape=[3, 40], dtype=tf.float32, maxval=1.0) 33 | print(audio_feature) 34 | normalized = cmvn(audio_feature) 35 | print("normalized = ", normalized) 36 | print("dim is ", cmvn.dim()) 37 | 38 | cmvn = CMVN.params( 39 | { 40 | "global_mean": np.zeros(dim).tolist(), 41 | "global_variance": np.ones(dim).tolist(), 42 | "local_cmvn": False, 43 | } 44 | ).instantiate() 45 | normalized = cmvn(audio_feature) 46 | self.assertAllClose(audio_feature, normalized) 47 | 48 | 49 | if __name__ == "__main__": 50 | if tf.__version__ < "2.0.0": 51 | tf.compat.v1.enable_eager_execution() 52 | tf.test.main() 53 | -------------------------------------------------------------------------------- /athena/transform/feats/fbank_pitch_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """This module tests the Fbank && Pitch frontend.""" 17 | 18 | import os 19 | from pathlib import Path 20 | import tensorflow as tf 21 | from tensorflow.python.framework.ops import disable_eager_execution 22 | from athena.transform.feats.read_wav import ReadWav 23 | from athena.transform.feats.fbank_pitch import FbankPitch 24 | 25 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 26 | 27 | class FbankPitchTest(tf.test.TestCase): 28 | """ 29 | Fbank && Pitch extraction test. 30 | """ 31 | def test_FbankPitch(self): 32 | wav_path = str(Path(os.environ['MAIN_ROOT']).joinpath('examples/sm1_cln.wav')) 33 | 34 | with self.session(): 35 | read_wav = ReadWav.params().instantiate() 36 | input_data, sample_rate = read_wav(wav_path) 37 | config = {'window_length': 0.025, 'output_type': 1, 'frame_length': 0.010, 'dither': 0.0} 38 | fbank_pitch = FbankPitch.params(config).instantiate() 39 | fbank_pitch_test = fbank_pitch(input_data, sample_rate) 40 | 41 | if tf.executing_eagerly(): 42 | self.assertEqual(tf.rank(fbank_pitch_test).numpy(), 3) 43 | print(fbank_pitch_test.numpy()[0:2, :, 0]) 44 | else: 45 | self.assertEqual(tf.rank(fbank_pitch_test).eval(), 3) 46 | print(fbank_pitch_test.eval()[0:2, :, 0]) 47 | 48 | if __name__ == '__main__': 49 | 50 | is_eager = True 51 | if not is_eager: 52 | disable_eager_execution() 53 | else: 54 | if tf.__version__ < '2.0.0': 55 | tf.compat.v1.enable_eager_execution() 56 | tf.test.main() 57 | -------------------------------------------------------------------------------- /athena/transform/feats/fbank_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | import os 18 | from pathlib import Path 19 | import numpy as np 20 | import tensorflow as tf 21 | from tensorflow.python.framework.ops import disable_eager_execution 22 | from athena.transform.feats.read_wav import ReadWav 23 | from athena.transform.feats.fbank import Fbank 24 | 25 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 26 | 27 | 28 | class FbankTest(tf.test.TestCase): 29 | def test_fbank(self): 30 | # 16kHz && 8kHz test 31 | wav_path_16k = str( 32 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/sm1_cln.wav") 33 | ) 34 | wav_path_8k = str( 35 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/english.wav") 36 | ) 37 | 38 | with self.session(): 39 | # value test 40 | read_wav = ReadWav.params().instantiate() 41 | input_data, sample_rate = read_wav(wav_path_16k) 42 | fbank = Fbank.params({"delta_delta": False}).instantiate() 43 | fbank_test = fbank(input_data, sample_rate) 44 | real_fbank_feats = np.array( 45 | [ 46 | [3.768338, 4.946218, 6.289874, 6.330853, 6.761764, 6.884573], 47 | [3.803553, 5.450971, 6.547878, 5.796172, 6.397846, 7.242926], 48 | ] 49 | ) 50 | # self.assertAllClose(np.squeeze(fbank_test.eval()[0:2, 0:6, 0]), 51 | # real_fbank_feats, rtol=1e-05, atol=1e-05) 52 | if tf.executing_eagerly(): 53 | print(fbank_test.numpy()[0:2, 0:6, 0]) 54 | else: 55 | print(fbank_test.eval()[0:2, 0:6, 0]) 56 | count = 1 57 | 58 | for wav_file in [wav_path_8k, wav_path_16k]: 59 | 60 | read_wav = ReadWav.params().instantiate() 61 | input_data, sample_rate = read_wav(wav_file) 62 | if tf.executing_eagerly(): 63 | print(wav_file, sample_rate.numpy()) 64 | else: 65 | print(wav_file, sample_rate.eval()) 66 | 67 | conf = { 68 | "delta_delta": True, 69 | "lower_frequency_limit": 100, 70 | "upper_frequency_limit": 0, 71 | } 72 | fbank = Fbank.params(conf).instantiate() 73 | fbank_test = fbank(input_data, sample_rate) 74 | if tf.executing_eagerly(): 75 | print(fbank_test.numpy()) 76 | else: 77 | print(fbank_test.eval()) 78 | print(fbank.num_channels()) 79 | 80 | conf = { 81 | "delta_delta": False, 82 | "lower_frequency_limit": 100, 83 | "upper_frequency_limit": 0, 84 | } 85 | fbank = Fbank.params(conf).instantiate() 86 | fbank_test = fbank(input_data, sample_rate) 87 | print(fbank_test) 88 | print(fbank.num_channels()) 89 | count += 1 90 | del read_wav 91 | del fbank 92 | 93 | 94 | if __name__ == "__main__": 95 | 96 | is_eager = True 97 | if not is_eager: 98 | disable_eager_execution() 99 | else: 100 | if tf.__version__ < "2.0.0": 101 | tf.compat.v1.enable_eager_execution() 102 | tf.test.main() 103 | -------------------------------------------------------------------------------- /athena/transform/feats/framepow.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """This module extracts framepow features per frame.""" 17 | 18 | import tensorflow as tf 19 | from athena.utils.hparam import HParams 20 | from athena.transform.feats.ops import py_x_ops 21 | from athena.transform.feats.base_frontend import BaseFrontend 22 | 23 | 24 | class Framepow(BaseFrontend): 25 | """ 26 | Compute the power of every frame in speech. Return a float tensor of 27 | log-power values with shape (num_frames,). 28 | """ 29 | def __init__(self, config: dict): 30 | super().__init__(config) 31 | 32 | @classmethod 33 | def params(cls, config=None): 34 | """ 35 | Set params. 36 | :param config: contains four optional parameters: 37 | --window_length : Window length in seconds. (float, default = 0.025) 38 | --frame_length : Hop length in seconds. (float, default = 0.010) 39 | --snip_edges : If 1, the last frame (shorter than window_length) 40 | will be cut off. If 2, frame_length // 2 of data will 41 | be padded at the edges instead. (int, default = 1) 42 | --remove_dc_offset : Subtract mean from waveform on each frame (bool, default = True) 43 | :return: An object of class HParams, which is a set of hyperparameters as name-value pairs. 44 | """ 45 | 46 | window_length = 0.025 47 | frame_length = 0.010 48 | snip_edges = 1 49 | remove_dc_offset = True 50 | 51 | hparams = HParams(cls=cls) 52 | hparams.add_hparam("window_length", window_length) 53 | hparams.add_hparam("frame_length", frame_length) 54 | hparams.add_hparam("snip_edges", snip_edges) 55 | hparams.add_hparam("remove_dc_offset", remove_dc_offset) 56 | 57 | if config is not None: 58 | hparams.parse(config, True) 59 | 60 | return hparams 61 | 62 | def call(self, audio_data, sample_rate): 63 | """ 64 | Calculate the power of every frame in speech. 65 | :param audio_data: the audio signal from which to compute spectrum. 66 | Should be a (1, N) tensor. 67 | :param sample_rate: [optional] the sample rate of the signal we are working with; 68 | default is 16kHz. 69 | :return: A float tensor of size (num_frames,) containing the log power of every 70 | frame in speech. 71 | """ 72 | 73 | p = self.config 74 | with tf.name_scope('framepow'): 75 | sample_rate = tf.cast(sample_rate, dtype=float) 76 | framepow = py_x_ops.frame_pow( 77 | audio_data, 78 | sample_rate, 79 | snip_edges=p.snip_edges, 80 | remove_dc_offset=p.remove_dc_offset, 81 | window_length=p.window_length, 82 | frame_length=p.frame_length) 83 | 84 | return tf.squeeze(framepow) 85 | 86 | def dim(self): 87 | return 1 -------------------------------------------------------------------------------- /athena/transform/feats/framepow_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """This module tests the framepow frontend.""" 17 | 18 | import os 19 | from pathlib import Path 20 | import numpy as np 21 | import tensorflow as tf 22 | from tensorflow.python.framework.ops import disable_eager_execution 23 | from athena.transform.feats.read_wav import ReadWav 24 | from athena.transform.feats.framepow import Framepow 25 | 26 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 27 | 28 | 29 | class FramePowTest(tf.test.TestCase): 30 | """ 31 | Framepow extraction test. 32 | """ 33 | def test_framepow(self): 34 | wav_path_16k = str( 35 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/sm1_cln.wav") 36 | ) 37 | 38 | with self.session(): 39 | read_wav = ReadWav.params().instantiate() 40 | input_data, sample_rate = read_wav(wav_path_16k) 41 | config = {"snip_edges": 1} 42 | framepow = Framepow.params(config).instantiate() 43 | framepow_test = framepow(input_data, sample_rate) 44 | 45 | real_framepow_feats = np.array( 46 | [9.819611, 9.328745, 9.247337, 9.26451, 9.266059] 47 | ) 48 | 49 | if tf.executing_eagerly(): 50 | self.assertAllClose( 51 | framepow_test.numpy()[0:5], 52 | real_framepow_feats, 53 | rtol=1e-05, 54 | atol=1e-05, 55 | ) 56 | print(framepow_test.numpy()[0:5]) 57 | else: 58 | self.assertAllClose( 59 | framepow_test.eval()[0:5], 60 | real_framepow_feats, 61 | rtol=1e-05, 62 | atol=1e-05, 63 | ) 64 | print(framepow_test.eval()[0:5]) 65 | 66 | 67 | if __name__ == "__main__": 68 | is_eager = True 69 | if not is_eager: 70 | disable_eager_execution() 71 | else: 72 | if tf.__version__ < "2.0.0": 73 | tf.compat.v1.enable_eager_execution() 74 | tf.test.main() 75 | -------------------------------------------------------------------------------- /athena/transform/feats/mfcc_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """This module tests the MFCC frontend.""" 17 | 18 | import os 19 | from pathlib import Path 20 | import numpy as np 21 | import tensorflow as tf 22 | from tensorflow.python.framework.ops import disable_eager_execution 23 | from athena.transform.feats.read_wav import ReadWav 24 | from athena.transform.feats.mfcc import Mfcc 25 | 26 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 27 | 28 | 29 | class MfccTest(tf.test.TestCase): 30 | """ 31 | MFCC extraction test. 32 | """ 33 | def test_mfcc(self): 34 | wav_path_16k = str( 35 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/sm1_cln.wav") 36 | ) 37 | 38 | with self.session(): 39 | read_wav = ReadWav.params().instantiate() 40 | input_data, sample_rate = read_wav(wav_path_16k) 41 | config = {"use_energy": True} 42 | mfcc = Mfcc.params(config).instantiate() 43 | mfcc_test = mfcc(input_data, sample_rate) 44 | 45 | real_mfcc_feats = np.array( 46 | [ 47 | [9.819611, -30.58736, -7.088838, -10.67966, -1.646479, -4.36086], 48 | [9.328745, -30.73371, -6.128432, -7.930599, 3.208357, -1.086456], 49 | ] 50 | ) 51 | 52 | if tf.executing_eagerly(): 53 | self.assertAllClose( 54 | mfcc_test.numpy()[0, 0:2, 0:6], 55 | real_mfcc_feats, 56 | rtol=1e-05, 57 | atol=1e-05, 58 | ) 59 | else: 60 | self.assertAllClose( 61 | mfcc_test.eval()[0, 0:2, 0:6], 62 | real_mfcc_feats, 63 | rtol=1e-05, 64 | atol=1e-05, 65 | ) 66 | 67 | 68 | if __name__ == "__main__": 69 | 70 | is_eager = True 71 | if not is_eager: 72 | disable_eager_execution() 73 | else: 74 | if tf.__version__ < "2.0.0": 75 | tf.compat.v1.enable_eager_execution() 76 | tf.test.main() 77 | -------------------------------------------------------------------------------- /athena/transform/feats/ops/Makefile: -------------------------------------------------------------------------------- 1 | # Find where we're running from, so we can store generated files here. 2 | 3 | ifeq ($(origin MAKEFILE_DIR), undefined) 4 | MAKEFILE_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) 5 | MAIN_ROOT := $(realpath $(MAKEFILE_DIR)/../../) 6 | endif 7 | 8 | #$(info $(MAKEFILE_DIR)) 9 | #$(info $(MAIN_ROOT)) 10 | 11 | CXX := g++ 12 | NVCC := nvcc 13 | PYTHON_BIN_PATH= python3 14 | CC := 15 | AR := 16 | CXXFLAGS := 17 | LDFLAGS := 18 | STDLIB := 19 | 20 | # Try to figure out the host system 21 | HOST_OS := 22 | ifeq ($(OS),Windows_NT) 23 | HOST_OS = windows 24 | else 25 | UNAME_S := $(shell uname -s) 26 | ifeq ($(UNAME_S),Linux) 27 | HOST_OS := linux 28 | endif 29 | ifeq ($(UNAME_S),Darwin) 30 | HOST_OS := ios 31 | endif 32 | endif 33 | 34 | #HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi) 35 | HOST_ARCH=x86_64 36 | TARGET := $(HOST_OS) 37 | TARGET_ARCH := $(HOST_ARCH) 38 | 39 | GENDIR := $(MAKEFILE_DIR)/gen/ 40 | TGTDIR := $(GENDIR)$(TARGET)_$(TARGET_ARCH)/ 41 | OBJDIR := $(TGTDIR)obj/ 42 | BINDIR := $(TGTDIR)bin/ 43 | LIBDIR := $(TGTDIR)lib/ 44 | 45 | 46 | TF_CFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') 47 | # Fix TF LDFLAGS issue on macOS. 
48 | TF_LFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))' | sed "s/-l:libtensorflow_framework.1.dylib/-ltensorflow_framework.1/") 49 | #TF_INCLUDES := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') 50 | TF_LIBS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())') 51 | CXXFLAGS += -fPIC -shared -O2 -std=c++11 -DFEATURE_VERSION=\"$(shell git rev-parse --short HEAD)\" $(TF_CFLAGS) 52 | INCLUDES := -I$(MAIN_ROOT) \ 53 | -I$(MAIN_ROOT)/feats/ops \ 54 | 55 | LDFLAGS += $(TF_LFLAGS) 56 | 57 | CORE_CC_EXCLUDE_SRCS := \ 58 | $(wildcard kernels/*test.cc) \ 59 | $(wildcard kernels/*test_util.cc) 60 | 61 | # src and tgts 62 | LIB_SRCS_ALL := $(wildcard kernels/*.cc) 63 | LIB_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(LIB_SRCS_ALL)) 64 | LIB_OBJS := $(addprefix $(OBJDIR), $(patsubst %.cc, %.o, $(patsubst %.c, %.o, $(LIB_SRCS)))) 65 | 66 | # lib 67 | SHARED_LIB := x_ops.so 68 | 69 | TEST_SRC := $(wildcard kernels/*_test.cc) 70 | TEST_OBJ := $(addprefix $(OBJDIR), $(patsubst %.cc, $(OBJS_DIR)%.o, $(TEST_SRC))) 71 | TEST_BIN := $(addprefix $(BINDIR), $(patsubst %.cc, $(OBJS_DIR)%.bin, $(TEST_SRC))) 72 | #TEST_BIN := $(BINDIR)test 73 | 74 | all: $(SHARED_LIB) $(TEST_BIN) 75 | 76 | $(OBJDIR)%.o: %.cc 77 | @mkdir -p $(dir $@) 78 | $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ $(LDFLAGS) 79 | 80 | $(SHARED_LIB): $(LIB_OBJS) 81 | @mkdir -p $(dir $@) 82 | $(CXX) -fPIC -shared -o $@ $^ $(STDLIB) $(LDFLAGS) 83 | 84 | $(STATIC_LIB): $(LIB_OBJS) 85 | @mkdir -p $(dir $@) 86 | $(AR) crsv $@ $^ 87 | 88 | ${TEST_BIN}: $(TEST_OBJ) $(STATIC_LIB) 89 | @mkdir -p $(dir $@) 90 | $(CXX) $(LDFLAGS) $^ -o $@ $(STATIC_LIB) 91 | 92 | .PHONY: clean 93 | clean: 94 | -rm -rf $(GENDIR) 95 | -rm -f x_ops.so 96 | -------------------------------------------------------------------------------- /athena/transform/feats/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | -------------------------------------------------------------------------------- /athena/transform/feats/ops/kernels/delta_delta.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | ==============================================================================*/ 16 | 17 | #include "kernels/delta_delta.h" 18 | 19 | #include <vector> 20 | 21 | namespace delta { 22 | 23 | const int kOrder = 2; 24 | const int kWindow = 2; 25 | 26 | DeltaDelta::DeltaDelta() 27 | : initialized_(false), order_(kOrder), window_(kWindow) {} 28 | 29 | DeltaDelta::~DeltaDelta() { 30 | for (int i = 0; i < scales_.size(); i++) 31 | std::vector<double>().swap(scales_[i]); 32 | (scales_).clear(); 33 | } 34 | 35 | bool DeltaDelta::Initialize(int order, int window) { 36 | if (order < 0 || order > 1000) { 37 | LOG(ERROR) << "order must be between 0 and 1000."; 38 | return false; 39 | } 40 | if (window < 0 || window > 1000) { 41 | LOG(ERROR) << "window must be between 0 and 1000."; 42 | return false; 43 | } 44 | order_ = order; 45 | window_ = window; 46 | 47 | scales_.resize(order_ + 1); 48 | scales_[0].resize(1); 49 | scales_[0][0] = 50 | 1.0; // trivial window for 0th order delta [i.e. baseline feats] 51 | 52 | for (int i = 1; i <= order_; i++) { 53 | std::vector<double> &prev_scales = scales_[i - 1], &cur_scales = scales_[i]; 54 | 55 | int window = window_; 56 | if (window == 0) { 57 | LOG(ERROR) << "window must not be zero."; 58 | return false; 59 | } 60 | int prev_offset = (static_cast<int>(prev_scales.size() - 1)) / 2, 61 | cur_offset = prev_offset + window; 62 | cur_scales.resize(prev_scales.size() + 2 * window); 63 | 64 | double normalizer = 0.0; 65 | for (int j = -window; j <= window; j++) { 66 | normalizer += j * j; 67 | for (int k = -prev_offset; k <= prev_offset; k++) { 68 | cur_scales[j + k + cur_offset] += 69 | static_cast<double>(j) * prev_scales[k + prev_offset]; 70 | } 71 | } 72 | 73 | for (int i = 0; i < cur_scales.size(); i++) { 74 | cur_scales[i] *= (1.0 / normalizer); 75 | } 76 | } 77 | 78 | initialized_ = true; 79 | return initialized_; 80 | } 81 | 82 | // process one frame at a time 83 | void DeltaDelta::Compute(const Tensor& input_feats, int frame, 84 | std::vector<double>* output) const { 85 | if (!initialized_) { 86 | LOG(ERROR) << "DeltaDelta not initialized."; 87 | return; 88 | } 89 | 90 | int num_frames = input_feats.dim_size(0); 91 | int feat_dim = input_feats.dim_size(1); 92 | int output_dim = feat_dim * (order_ + 1); 93 | 94 | output->resize(output_dim); 95 | auto input = input_feats.matrix<float>(); 96 | 97 | for (int i = 0; i <= order_; i++) { 98 | const std::vector<double>& scales = scales_[i]; 99 | int max_offset = (scales.size() - 1) / 2; 100 | // e.g. 
max_offset=2, (-2, 2) 101 | for (int j = -max_offset; j <= max_offset; j++) { 102 | // if asked to read past the edges, clamp to the first or last frame 103 | int offset_frame = frame + j; 104 | if (offset_frame < 0) 105 | offset_frame = 0; 106 | else if (offset_frame >= num_frames) 107 | offset_frame = num_frames - 1; 108 | 109 | // scales[0] for `-max_offset` frame 110 | double scale = scales[j + max_offset]; 111 | if (scale != 0.0) { 112 | for (int k = 0; k < feat_dim; k++) { 113 | (*output)[i + k * (order_ + 1)] += input(offset_frame, k) * scale; 114 | } 115 | } 116 | } 117 | } 118 | return; 119 | } 120 | 121 | } // namespace delta 122 | -------------------------------------------------------------------------------- /athena/transform/feats/ops/kernels/delta_delta.h: -------------------------------------------------------------------------------- 1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | ==============================================================================*/ 16 | 17 | #ifndef DELTA_LAYERS_OPS_KERNELS_DELTA_DELTA_H_ 18 | #define DELTA_LAYERS_OPS_KERNELS_DELTA_DELTA_H_ 19 | 20 | #include <vector> 21 | 22 | #include "tensorflow/core/framework/op_kernel.h" 23 | #include "tensorflow/core/framework/tensor.h" 24 | #include "tensorflow/core/platform/logging.h" 25 | 26 | using namespace tensorflow; // NOLINT 27 | 28 | namespace delta { 29 | 30 | // This class provides a low-level function to compute delta features. 31 | // The function takes as input a matrix of features and a frame index 32 | // that it should compute the deltas on. It puts its output in an object 33 | // of type std::vector, of size (original-feature-dimension) * (order+1). 34 | // This is not the most efficient way to do the computation, but it's 35 | // state-free and thus easier to understand 36 | class DeltaDelta { 37 | public: 38 | DeltaDelta(); 39 | ~DeltaDelta(); 40 | 41 | bool Initialize(int order, int window); 42 | 43 | // Input is a single feature frame. Output is populated with 44 | // the feature, delta, delta-delta values. 45 | void Compute(const Tensor& input_feats, int frame, 46 | std::vector<double>* output) const; 47 | 48 | void set_order(int order) { 49 | CHECK(!initialized_) << "Set order before calling Initialize."; 50 | order_ = order; 51 | } 52 | 53 | void set_window(int window) { 54 | CHECK(!initialized_) << "Set window before calling Initialize."; 55 | window_ = window; 56 | } 57 | 58 | private: 59 | int order_; 60 | // e.g. 2; controls window size (window size is 2*window + 1) 61 | // the behavior at the edges is to replicate the first or last frame. 62 | // this is not configurable. 63 | int window_; 64 | 65 | // a scaling window for each of the orders, including zero: 66 | // multiply the features for each dimension by this window. 
67 | std::vector<std::vector<double> > scales_; 68 | 69 | bool initialized_; 70 | TF_DISALLOW_COPY_AND_ASSIGN(DeltaDelta); 71 | }; // class DeltaDelta 72 | 73 | } // namespace delta 74 | 75 | #endif // DELTA_LAYERS_OPS_KERNELS_DELTA_DELTA_H_ 76 | -------------------------------------------------------------------------------- /athena/transform/feats/ops/kernels/delta_delta_op.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | ==============================================================================*/ 16 | 17 | // See docs in ../ops/audio_ops.cc 18 | #include "kernels/delta_delta.h" 19 | #include "tensorflow/core/framework/op_kernel.h" 20 | #include "tensorflow/core/framework/register_types.h" 21 | #include "tensorflow/core/framework/tensor.h" 22 | #include "tensorflow/core/framework/tensor_shape.h" 23 | #include "tensorflow/core/framework/types.h" 24 | #include "tensorflow/core/lib/core/status.h" 25 | 26 | // https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/CXX11/src/Tensor/README.md 27 | 28 | namespace delta { 29 | 30 | // add first and second order derivatives 31 | class DeltaDeltaOp : public OpKernel { 32 | public: 33 | explicit DeltaDeltaOp(OpKernelConstruction* context) : OpKernel(context) { 34 | OP_REQUIRES_OK(context, context->GetAttr("order", &order_)); 35 | OP_REQUIRES_OK(context, context->GetAttr("window", &window_)); 36 | } 37 | 38 | void Compute(OpKernelContext* context) override { 39 | const Tensor& feats = context->input(0); 40 | OP_REQUIRES(context, feats.dims() == 2, 41 | errors::InvalidArgument("features must be 2-dimensional", 42 | feats.shape().DebugString())); 43 | // feats shape [time, feat dim] 44 | const int time = feats.dim_size(0); // num frames 45 | const int feat_dim = feats.dim_size(1); 46 | const int output_dim = feat_dim * (order_ + 1); 47 | 48 | DeltaDelta delta; 49 | OP_REQUIRES( 50 | context, delta.Initialize(order_, window_), 51 | errors::InvalidArgument("DeltaDelta initialization failed for order ", 52 | order_, " and window ", window_)); 53 | 54 | Tensor* output_tensor = nullptr; 55 | OP_REQUIRES_OK(context, 56 | context->allocate_output(0, TensorShape({time, output_dim}), 57 | &output_tensor)); 58 | 59 | // TTypes<float>::Tensor feats_t = feats.tensor<float, 2>(); 60 | float* output_flat = output_tensor->flat<float>().data(); 61 | 62 | for (int t = 0; t < time; t++) { 63 | float* row = output_flat + t * output_dim; 64 | 65 | // add delta-delta 66 | std::vector<double> out; 67 | delta.Compute(feats, t, &out); 68 | 69 | // fill output buffer 70 | DCHECK_EQ(output_dim, out.size()); 71 | for (int i = 0; i < output_dim; i++) { 72 | row[i] = static_cast<float>(out[i]); 73 | } 74 | } 75 | } 76 | 77 | private: 78 | int order_; 79 | int window_; 80 | }; 81 | 82 | REGISTER_KERNEL_BUILDER(Name("DeltaDelta").Device(DEVICE_CPU), DeltaDeltaOp); 83 | 84 | } // namespace delta 85 | 
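For reference, here is a minimal NumPy sketch (not part of the repository; the helper names delta_scales and compute_delta_delta are made up for illustration) of what the DeltaDelta kernel above computes. Initialize() builds one tap vector per order by repeatedly convolving with the ramp window j = -window..window and normalizing by sum(j*j); Compute() applies the taps with the first and last frames replicated at the edges, interleaving [feat, delta, delta-delta] per feature dimension, as in output[i + k * (order + 1)]:

import numpy as np

def delta_scales(order=2, window=2):
    """Build per-order filter taps, mirroring DeltaDelta::Initialize."""
    scales = [np.array([1.0])]  # 0th order: the features themselves
    for _ in range(order):
        prev = scales[-1]
        prev_offset = (len(prev) - 1) // 2
        cur = np.zeros(len(prev) + 2 * window)
        normalizer = float(sum(j * j for j in range(-window, window + 1)))
        for j in range(-window, window + 1):
            for k in range(-prev_offset, prev_offset + 1):
                cur[j + k + prev_offset + window] += j * prev[k + prev_offset]
        scales.append(cur / normalizer)
    return scales

def compute_delta_delta(feats, order=2, window=2):
    """Per frame, concatenate [feat, delta, delta-delta]; edges replicated."""
    num_frames, feat_dim = feats.shape
    out = np.zeros((num_frames, feat_dim * (order + 1)))
    for i, taps in enumerate(delta_scales(order, window)):
        max_offset = (len(taps) - 1) // 2
        for j in range(-max_offset, max_offset + 1):
            rows = np.clip(np.arange(num_frames) + j, 0, num_frames - 1)
            out[:, i::order + 1] += taps[j + max_offset] * feats[rows]
    return out

With the defaults (order=2, window=2), delta_scales() yields the familiar first-order delta taps [-0.2, -0.1, 0.0, 0.1, 0.2].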
-------------------------------------------------------------------------------- /athena/transform/feats/ops/kernels/fbank.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | ==============================================================================*/ 16 | 17 | #include "kernels/fbank.h" 18 | 19 | #include <math.h> 20 | 21 | #include "tensorflow/core/platform/logging.h" 22 | 23 | namespace delta { 24 | 25 | const double kDefaultUpperFrequencyLimit = 4000; 26 | const double kDefaultLowerFrequencyLimit = 20; 27 | const double kFilterbankFloor = 1e-12; 28 | const int kDefaultFilterbankChannelCount = 40; 29 | 30 | Fbank::Fbank() 31 | : initialized_(false), 32 | lower_frequency_limit_(kDefaultLowerFrequencyLimit), 33 | upper_frequency_limit_(kDefaultUpperFrequencyLimit), 34 | filterbank_channel_count_(kDefaultFilterbankChannelCount) {} 35 | 36 | Fbank::~Fbank() {} 37 | 38 | bool Fbank::Initialize(int input_length, double input_sample_rate) { 39 | if (input_length < 1) { 40 | LOG(ERROR) << "Input length must be positive."; 41 | return false; 42 | } 43 | input_length_ = input_length; 44 | 45 | bool initialized = mel_filterbank_.Initialize( 46 | input_length, input_sample_rate, filterbank_channel_count_, 47 | lower_frequency_limit_, upper_frequency_limit_); 48 | initialized_ = initialized; 49 | return initialized; 50 | } 51 | 52 | void Fbank::Compute(const std::vector<double>& spectrogram_frame, 53 | std::vector<double>* output) const { 54 | if (!initialized_) { 55 | LOG(ERROR) << "Fbank not initialized."; 56 | return; 57 | } 58 | 59 | output->resize(filterbank_channel_count_); 60 | mel_filterbank_.Compute(spectrogram_frame, output); 61 | for (int i = 0; i < output->size(); ++i) { 62 | double val = (*output)[i]; 63 | if (val < kFilterbankFloor) { 64 | val = kFilterbankFloor; 65 | } 66 | (*output)[i] = log(val); 67 | } 68 | } 69 | 70 | } // namespace delta 71 | -------------------------------------------------------------------------------- /athena/transform/feats/ops/kernels/fbank.h: -------------------------------------------------------------------------------- 1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | ==============================================================================*/ 16 | 17 | #ifndef DELTA_LAYERS_OPS_KERNELS_FBANK_H_ 18 | #define DELTA_LAYERS_OPS_KERNELS_FBANK_H_ 19 | 20 | #include <vector>  // NOLINT 21 | 22 | #include "kernels/mfcc_mel_filterbank.h" 23 | 24 | #include "tensorflow/core/framework/op_kernel.h" 25 | #include "tensorflow/core/platform/logging.h" 26 | 27 | using namespace tensorflow; // NOLINT 28 | 29 | namespace delta { 30 | 31 | // logfbank 32 | class Fbank { 33 | public: 34 | Fbank(); 35 | ~Fbank(); 36 | bool Initialize(int input_length, double input_sample_rate); 37 | // Input is a single squared-magnitude spectrogram frame. The input spectrum 38 | // is converted to linear magnitude and weighted into bands using a 39 | // triangular mel filterbank. Output is populated with the 40 | // filterbank_channel_count 41 | // log filterbank energies. 42 | void Compute(const std::vector<double>& spectrogram_frame, 43 | std::vector<double>* output) const; 44 | 45 | void set_upper_frequency_limit(double upper_frequency_limit) { 46 | CHECK(!initialized_) << "Set frequency limits before calling Initialize."; 47 | upper_frequency_limit_ = upper_frequency_limit; 48 | } 49 | 50 | void set_lower_frequency_limit(double lower_frequency_limit) { 51 | CHECK(!initialized_) << "Set frequency limits before calling Initialize."; 52 | lower_frequency_limit_ = lower_frequency_limit; 53 | } 54 | 55 | void set_filterbank_channel_count(int filterbank_channel_count) { 56 | CHECK(!initialized_) << "Set channel count before calling Initialize."; 57 | filterbank_channel_count_ = filterbank_channel_count; 58 | } 59 | 60 | private: 61 | MfccMelFilterbank mel_filterbank_; 62 | int input_length_; 63 | bool initialized_; 64 | double lower_frequency_limit_; 65 | double upper_frequency_limit_; 66 | int filterbank_channel_count_; 67 | TF_DISALLOW_COPY_AND_ASSIGN(Fbank); 68 | }; // class Fbank 69 | 70 | } // namespace delta 71 | 72 | #endif // DELTA_LAYERS_OPS_KERNELS_FBANK_H_ 73 | -------------------------------------------------------------------------------- /athena/transform/feats/ops/kernels/fbank_op_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """ fbank op unittest""" 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from athena.transform.feats.ops import py_x_ops 21 | 22 | 23 | class FbankOpTest(tf.test.TestCase): 24 | """ fbank op unittest""" 25 | 26 | def setUp(self): 27 | """ setup """ 28 | 29 | def tearDown(self): 30 | """ tear down """ 31 | 32 | def test_fbank(self): 33 | """ test fbank op""" 34 | with self.session(): 35 | data = np.arange(513) 36 | spectrogram = tf.constant(data[None, None, :], dtype=tf.float32) 37 | sample_rate = tf.constant(22050, tf.int32) 38 | output = py_x_ops.fbank( 39 | spectrogram, sample_rate, filterbank_channel_count=20 40 | ) 41 | 42 | output_true = np.array( 43 | [ 44 | 1.887894, 45 | 2.2693727, 46 | 2.576507, 47 | 2.8156495, 48 | 3.036504, 49 | 3.2296343, 50 | 3.4274294, 51 | 3.5987632, 52 | 3.771217, 53 | 3.937401, 54 | 4.0988584, 55 | 4.2570987, 56 | 4.4110703, 57 | 4.563661, 58 | 4.7140336, 59 | 4.8626432, 60 | 5.009346, 61 | 5.1539173, 62 | 5.2992935, 63 | 5.442024, 64 | ] 65 | ) 66 | self.assertEqual(tf.rank(output).eval(), 3) 67 | self.assertEqual(output.shape, (1, 1, 20)) 68 | self.assertAllClose(output.eval(), output_true[None, None, :]) 69 | 70 | 71 | if __name__ == "__main__": 72 | tf.test.main() 73 | -------------------------------------------------------------------------------- /athena/transform/feats/ops/kernels/framepow.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | ==============================================================================*/ 16 | 17 | #include "kernels/framepow.h" 18 | #include <math.h> 19 | #include <stdio.h> 20 | #include <stdlib.h> 21 | #include <string.h> 22 | 23 | namespace delta { 24 | const float window_length_sec = 0.025; 25 | const float frame_length_sec = 0.010; 26 | 27 | FramePow::FramePow() { 28 | window_length_sec_ = window_length_sec; 29 | frame_length_sec_ = frame_length_sec; 30 | i_snip_edges = 1; 31 | i_remove_dc_offset = true; 32 | pf_FrmEng = NULL; 33 | } 34 | 35 | FramePow::~FramePow() { free(pf_FrmEng); } 36 | 37 | void FramePow::set_window_length_sec(float window_length_sec) { 38 | window_length_sec_ = window_length_sec; 39 | } 40 | 41 | void FramePow::set_frame_length_sec(float frame_length_sec) { 42 | frame_length_sec_ = frame_length_sec; 43 | } 44 | 45 | void FramePow::set_snip_edges(int snip_edges) { i_snip_edges = snip_edges; } 46 | 47 | void FramePow::set_remove_dc_offset(bool remove_dc_offset) { 48 | i_remove_dc_offset = remove_dc_offset; 49 | } 50 | 51 | int FramePow::init_eng(int input_size, float sample_rate) { 52 | f_SamRat = sample_rate; 53 | i_WinLen = static_cast<int>(window_length_sec_ * f_SamRat); 54 | i_FrmLen = static_cast<int>(frame_length_sec_ * f_SamRat); 55 | if (i_snip_edges == 1) 56 | i_NumFrm = (input_size - i_WinLen) / i_FrmLen + 1; 57 | else 58 | i_NumFrm = (input_size + i_FrmLen / 2) / i_FrmLen; 59 | 60 | pf_FrmEng = static_cast<float*>(malloc(sizeof(float) * i_NumFrm)); 61 | 62 | return 1; 63 | } 64 | 65 | int FramePow::proc_eng(const float* mic_buf, int input_size) { 66 | int i, n, k; 67 | float* win = static_cast<float*>(malloc(sizeof(float) * i_WinLen)); 68 | 69 | for (n = 0; n < i_NumFrm; n++) { 70 | pf_FrmEng[n] = 0.0; 71 | float sum = 0.0; 72 | float energy = 0.0; 73 | for (k = 0; k < i_WinLen; k++) { 74 | int index = n * i_FrmLen + k; 75 | if (index < input_size) 76 | win[k] = mic_buf[index]; 77 | else 78 | win[k] = 0.0f; 79 | sum += win[k]; 80 | } 81 | 82 | if (i_remove_dc_offset == true) { 83 | float mean = sum / i_WinLen; 84 | for (int l = 0; l < i_WinLen; l++) win[l] -= mean; 85 | } 86 | 87 | for (i = 0; i < i_WinLen; i++) { 88 | energy += win[i] * win[i]; 89 | } 90 | 91 | pf_FrmEng[n] = log(energy); 92 | 93 | } 94 | 95 | free(win); 96 | return 1; 97 | } 98 | 99 | int FramePow::get_eng(float* output) { 100 | memcpy(output, pf_FrmEng, sizeof(float) * i_NumFrm); 101 | 102 | return 1; 103 | } 104 | 105 | int FramePow::write_eng() { 106 | FILE* fp; 107 | fp = fopen("frame_energy.txt", "w"); 108 | int n; 109 | for (n = 0; n < i_NumFrm; n++) { 110 | fprintf(fp, "%4.6f\n", pf_FrmEng[n]); 111 | } 112 | fclose(fp); 113 | return 1; 114 | } 115 | } // namespace delta 116 | -------------------------------------------------------------------------------- /athena/transform/feats/ops/kernels/framepow.h: -------------------------------------------------------------------------------- 1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | ==============================================================================*/ 16 | 17 | #ifndef DELTA_LAYERS_OPS_KERNELS_FRAMEPOW_H_ 18 | #define DELTA_LAYERS_OPS_KERNELS_FRAMEPOW_H_ 19 | 20 | #include "tensorflow/core/framework/op_kernel.h" 21 | #include "tensorflow/core/platform/logging.h" 22 | 23 | using namespace tensorflow; // NOLINT 24 | 25 | namespace delta { 26 | class FramePow { 27 | private: 28 | float window_length_sec_; 29 | float frame_length_sec_; 30 | int i_snip_edges; 31 | bool i_remove_dc_offset; 32 | 33 | float f_SamRat; 34 | int i_WinLen; 35 | int i_FrmLen; 36 | int i_NumFrm; 37 | 38 | float* pf_FrmEng; 39 | 40 | public: 41 | FramePow(); 42 | 43 | ~FramePow(); 44 | 45 | void set_window_length_sec(float window_length_sec); 46 | 47 | void set_frame_length_sec(float frame_length_sec); 48 | 49 | void set_snip_edges(int snip_edges); 50 | 51 | void set_remove_dc_offset(bool remove_dc_offset); 52 | 53 | int init_eng(int input_size, float sample_rate); 54 | 55 | int proc_eng(const float* mic_buf, int input_size); 56 | 57 | int get_eng(float* output); 58 | 59 | int write_eng(); 60 | 61 | TF_DISALLOW_COPY_AND_ASSIGN(FramePow); 62 | }; 63 | } // namespace delta 64 | #endif // DELTA_LAYERS_OPS_KERNELS_FRAMEPOW_H_ 65 | -------------------------------------------------------------------------------- /athena/transform/feats/ops/kernels/framepow_op.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | ==============================================================================*/ 16 | 17 | #include "kernels/framepow.h" 18 | 19 | #include "tensorflow/core/framework/op_kernel.h" 20 | #include "tensorflow/core/framework/register_types.h" 21 | #include "tensorflow/core/framework/tensor.h" 22 | #include "tensorflow/core/framework/tensor_shape.h" 23 | #include "tensorflow/core/framework/types.h" 24 | #include "tensorflow/core/lib/core/status.h" 25 | 26 | namespace delta { 27 | class FramePowOp : public OpKernel { 28 | public: 29 | explicit FramePowOp(OpKernelConstruction* context) : OpKernel(context) { 30 | OP_REQUIRES_OK(context, context->GetAttr("window_length", &window_length_)); 31 | OP_REQUIRES_OK(context, context->GetAttr("frame_length", &frame_length_)); 32 | OP_REQUIRES_OK(context, context->GetAttr("snip_edges", &snip_edges_)); 33 | OP_REQUIRES_OK(context, 34 | context->GetAttr("remove_dc_offset", &remove_dc_offset_)); 35 | } 36 | 37 | void Compute(OpKernelContext* context) override { 38 | const Tensor& input_tensor = context->input(0); 39 | OP_REQUIRES(context, input_tensor.dims() == 1, 40 | errors::InvalidArgument("input signal must be 1-dimensional", 41 | input_tensor.shape().DebugString())); 42 | 43 | const Tensor& sample_rate_tensor = context->input(1); 44 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(sample_rate_tensor.shape()), 45 | errors::InvalidArgument( 46 | "Input sample rate should be a scalar tensor, got ", 47 | sample_rate_tensor.shape().DebugString(), " instead.")); 48 | const float sample_rate = sample_rate_tensor.scalar<float>()(); 49 | 50 | // shape 51 | const int L = input_tensor.dim_size(0); 52 | FramePow cls_eng; 53 | cls_eng.set_window_length_sec(window_length_); 54 | cls_eng.set_frame_length_sec(frame_length_); 55 | cls_eng.set_snip_edges(snip_edges_); 56 | cls_eng.set_remove_dc_offset(remove_dc_offset_); 57 | OP_REQUIRES(context, cls_eng.init_eng(L, sample_rate), 58 | errors::InvalidArgument( 59 | "framepow_class initialization failed for length ", L, 60 | " and sample rate ", sample_rate)); 61 | 62 | Tensor* output_tensor = nullptr; 63 | int i_WinLen = static_cast<int>(window_length_ * sample_rate); 64 | int i_FrmLen = static_cast<int>(frame_length_ * sample_rate); 65 | int i_NumFrm = (L - i_WinLen) / i_FrmLen + 1; 66 | if (snip_edges_ == 2) i_NumFrm = (L + i_FrmLen / 2) / i_FrmLen; 67 | if (i_NumFrm < 1) i_NumFrm = 1; 68 | OP_REQUIRES_OK(context, context->allocate_output( 69 | 0, TensorShape({1, i_NumFrm}), &output_tensor)); 70 | 71 | const float* input_flat = input_tensor.flat<float>().data(); 72 | float* output_flat = output_tensor->flat<float>().data(); 73 | 74 | int ret; 75 | ret = cls_eng.proc_eng(input_flat, L); 76 | ret = cls_eng.get_eng(output_flat); 77 | } 78 | 79 | private: 80 | float window_length_; 81 | float frame_length_; 82 | int snip_edges_; 83 | bool remove_dc_offset_; 84 | }; 85 | 86 | REGISTER_KERNEL_BUILDER(Name("FramePow").Device(DEVICE_CPU), FramePowOp); 87 | 88 | } // namespace delta -------------------------------------------------------------------------------- /athena/transform/feats/ops/kernels/mfcc_dct.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at
 6 | 
 7 | http://www.apache.org/licenses/LICENSE-2.0
 8 | 
 9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 | 
16 | #include "kernels/mfcc_dct.h"
17 | 
18 | #include <math.h>
19 | 
20 | #include "tensorflow/core/platform/logging.h"
21 | 
22 | namespace delta {
23 | 
24 | const float kDefaultCepstralLifter = 22;
25 | const int kDefaultCoefficientCount = 13;
26 | 
27 | MfccDct::MfccDct()
28 |     : initialized_(false),
29 |       coefficient_count_(kDefaultCoefficientCount),
30 |       cepstral_lifter_(kDefaultCepstralLifter) {}
31 | 
32 | bool MfccDct::Initialize(int input_length, int coefficient_count) {
33 |   coefficient_count_ = coefficient_count;
34 |   input_length_ = input_length;
35 | 
36 |   if (coefficient_count_ < 1) {
37 |     LOG(ERROR) << "Coefficient count must be positive.";
38 |     return false;
39 |   }
40 | 
41 |   if (input_length < 1) {
42 |     LOG(ERROR) << "Input length must be positive.";
43 |     return false;
44 |   }
45 | 
46 |   if (coefficient_count_ > input_length_) {
47 |     LOG(ERROR) << "Coefficient count must be less than or equal to "
48 |                << "input length.";
49 |     return false;
50 |   }
51 | 
52 |   cosines_.resize(coefficient_count_);
53 |   double fnorm = sqrt(2.0 / input_length_);
54 |   // Some platforms don't have M_PI, so define a local constant here.
55 |   const double pi = std::atan(1) * 4;
56 |   double arg = pi / input_length_;
57 |   for (int i = 0; i < coefficient_count_; ++i) {
58 |     cosines_[i].resize(input_length_);
59 |     for (int j = 0; j < input_length_; ++j) {
60 |       cosines_[i][j] = fnorm * cos(i * arg * (j + 0.5));
61 |     }
62 |   }
63 | 
64 |   lifter_coeffs_.resize(coefficient_count_);
65 |   for (int j = 0; j < coefficient_count_; ++j)
66 |     lifter_coeffs_[j] =
67 |         1.0 + 0.5 * cepstral_lifter_ * sin(PI * j / cepstral_lifter_);
68 | 
69 |   initialized_ = true;
70 |   return true;
71 | }
72 | 
73 | void MfccDct::set_coefficient_count(int coefficient_count) {
74 |   coefficient_count_ = coefficient_count;
75 | }
76 | 
77 | void MfccDct::set_cepstral_lifter(float cepstral_lifter) {
78 |   cepstral_lifter_ = cepstral_lifter;
79 | }
80 | 
81 | void MfccDct::Compute(const std::vector<double> &input,
82 |                       std::vector<double> *output) const {
83 |   if (!initialized_) {
84 |     LOG(ERROR) << "DCT not initialized.";
85 |     return;
86 |   }
87 | 
88 |   output->resize(coefficient_count_);
89 |   int length = input.size();
90 |   if (length > input_length_) {
91 |     length = input_length_;
92 |   }
93 | 
94 |   double res;
95 |   for (int i = 0; i < coefficient_count_; ++i) {
96 |     double sum = 0.0;
97 |     for (int j = 0; j < length; ++j) {
98 |       sum += cosines_[i][j] * input[j];
99 |     }
100 |     res = sum;
101 |     if (cepstral_lifter_ != 0) res *= lifter_coeffs_[i];
102 |     (*output)[i] = res;
103 |   }
104 | }
105 | 
106 | }  // namespace delta
107 | 
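The basis built in `Initialize` is an orthonormal DCT-II followed by HTK-style cepstral liftering. For reference, the same math in NumPy (a sketch for sanity-checking the kernel; `mfcc_dct` is an illustrative name, not an existing Python API):

```python
import numpy as np

def mfcc_dct(log_mel, coefficient_count=13, cepstral_lifter=22.0):
    """Orthonormal DCT-II plus cepstral liftering, as in MfccDct::Compute."""
    n = len(log_mel)
    fnorm = np.sqrt(2.0 / n)
    i = np.arange(coefficient_count)[:, None]   # output coefficient index
    j = np.arange(n)[None, :]                   # input filterbank bin index
    # cosines_[i][j] = fnorm * cos(i * (pi / n) * (j + 0.5))
    cosines = fnorm * np.cos(i * (np.pi / n) * (j + 0.5))
    # lifter_coeffs_[j] = 1 + 0.5 * L * sin(pi * j / L)
    lifter = 1.0 + 0.5 * cepstral_lifter * np.sin(np.pi * i[:, 0] / cepstral_lifter)
    return (cosines @ log_mel) * lifter

# 23 filterbank energies -> 13 cepstral coefficients
print(mfcc_dct(np.random.rand(23)).shape)  # (13,)
```

--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/mfcc_dct.h:
--------------------------------------------------------------------------------
 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 2 | 
 3 | Licensed under the Apache License, Version 2.0 (the "License");
 4 | you may not use this file except in compliance with the License.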
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | // Basic minimal DCT class for MFCC speech processing. 17 | 18 | #ifndef TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_ // NOLINT 19 | #define TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_ // NOLINT 20 | 21 | #include 22 | 23 | #include "tensorflow/core/framework/op_kernel.h" 24 | #include "tensorflow/core/platform/logging.h" 25 | #include "kernels/support_functions.h" 26 | 27 | using namespace tensorflow; // NOLINT 28 | #define PI (3.141592653589793) 29 | 30 | namespace delta { 31 | 32 | class MfccDct { 33 | public: 34 | MfccDct(); 35 | bool Initialize(int input_length, int coefficient_count); 36 | void Compute(const std::vector& input, 37 | std::vector* output) const; 38 | void set_coefficient_count(int coefficient_count); 39 | void set_cepstral_lifter(float cepstral_lifter); 40 | 41 | private: 42 | bool initialized_; 43 | int coefficient_count_; 44 | float cepstral_lifter_; 45 | int input_length_; 46 | std::vector > cosines_; 47 | std::vector lifter_coeffs_; 48 | TF_DISALLOW_COPY_AND_ASSIGN(MfccDct); 49 | }; 50 | 51 | } // namespace delta 52 | 53 | #endif // DELTA_LAYERS_OPS_KERNELS_MFCC_DCT_H_ 54 | -------------------------------------------------------------------------------- /athena/transform/feats/ops/kernels/mfcc_mel_filterbank.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | // Basic class for applying a mel-scale mapping to a power spectrum. 17 | 18 | #ifndef DELTA_LAYERS_OPS_KERNELS_MFCC_MEL_FILTERBANK_H_ 19 | #define DELTA_LAYERS_OPS_KERNELS_MFCC_MEL_FILTERBANK_H_ 20 | 21 | #include 22 | 23 | #include "tensorflow/core/framework/op_kernel.h" 24 | 25 | namespace tensorflow { 26 | 27 | class MfccMelFilterbank { 28 | public: 29 | MfccMelFilterbank(); 30 | ~MfccMelFilterbank(); 31 | bool Initialize(int input_length, // Number of unique FFT bins fftsize/2+1. 32 | double input_sample_rate, int output_channel_count, 33 | double lower_frequency_limit, double upper_frequency_limit); 34 | 35 | // Takes a squared-magnitude spectrogram slice as input, computes a 36 | // triangular-mel-weighted linear-magnitude filterbank, and places the result 37 | // in output. 
38 | void Compute(const std::vector& input, 39 | std::vector* output) const; 40 | 41 | private: 42 | double FreqToMel(double freq) const; 43 | bool initialized_; 44 | int num_channels_; 45 | double sample_rate_; 46 | int input_length_; 47 | std::vector center_frequencies_; // In mel, for each mel channel. 48 | 49 | // Each FFT bin b contributes to two triangular mel channels, with 50 | // proportion weights_[b] going into mel channel band_mapper_[b], and 51 | // proportion (1 - weights_[b]) going into channel band_mapper_[b] + 1. 52 | // Thus, weights_ contains the weighting applied to each FFT bin for the 53 | // upper-half of the triangular band. 54 | std::vector weights_; // Right-side weight for this fft bin. 55 | 56 | // FFT bin i contributes to the upper side of mel channel band_mapper_[i] 57 | std::vector band_mapper_; 58 | int start_index_; // Lowest FFT bin used to calculate mel spectrum. 59 | int end_index_; // Highest FFT bin used to calculate mel spectrum. 60 | 61 | TF_DISALLOW_COPY_AND_ASSIGN(MfccMelFilterbank); 62 | }; 63 | 64 | } // namespace tensorflow 65 | #endif // DELTA_LAYERS_OPS_KERNELS_MFCC_MEL_FILTERBANK_H_ 66 | -------------------------------------------------------------------------------- /athena/transform/feats/ops/kernels/resample.h: -------------------------------------------------------------------------------- 1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | ==============================================================================*/ 16 | 17 | 18 | #ifndef DELTA_LAYERS_OPS_KERNELS_RESAMPLE_H_ 19 | #define DELTA_LAYERS_OPS_KERNELS_RESAMPLE_H_ 20 | 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include "tensorflow/core/framework/op_kernel.h" 28 | #include "tensorflow/core/platform/logging.h" 29 | #include "kernels/support_functions.h" 30 | 31 | using namespace tensorflow; // NOLINT 32 | using namespace std; 33 | 34 | namespace delta { 35 | 36 | class ArbitraryResample { 37 | public: 38 | ArbitraryResample(int num_samples_in, 39 | BaseFloat samp_rate_hz, 40 | BaseFloat filter_cutoff_hz, 41 | const vector &sample_points_secs, 42 | int num_zeros); 43 | 44 | int NumSamplesIn() const { return num_samples_in_; } 45 | 46 | int NumSamplesOut() const { return weights_.size(); } 47 | void Resample(const std::vector > &input, 48 | std::vector > *output) const; 49 | 50 | void Resample(const vector &input, 51 | vector *output) const; 52 | private: 53 | void SetIndexes(const vector &sample_points); 54 | 55 | void SetWeights(const vector &sample_points); 56 | 57 | BaseFloat FilterFunc(BaseFloat t) const; 58 | 59 | int num_samples_in_; 60 | BaseFloat samp_rate_in_; 61 | BaseFloat filter_cutoff_; 62 | int num_zeros_; 63 | 64 | std::vector first_index_; 65 | std::vector > weights_; 66 | }; 67 | 68 | class LinearResample { 69 | public: 70 | LinearResample(int samp_rate_in_hz, 71 | int samp_rate_out_hz, 72 | BaseFloat filter_cutoff_hz, 73 | int num_zeros); 74 | 75 | void Resample(const vector &input, 76 | bool flush, 77 | vector *output); 78 | 79 | void Reset(); 80 | 81 | //// Return the input and output sampling rates (for checks, for example) 82 | inline int GetInputSamplingRate() { return samp_rate_in_; } 83 | inline int GetOutputSamplingRate() { return samp_rate_out_; } 84 | private: 85 | int GetNumOutputSamples(int input_num_samp, bool flush) const; 86 | 87 | inline void GetIndexes(int samp_out, 88 | int *first_samp_in, 89 | int *samp_out_wrapped) const; 90 | 91 | void SetRemainder(const vector &input); 92 | 93 | void SetIndexesAndWeights(); 94 | 95 | BaseFloat FilterFunc(BaseFloat) const; 96 | 97 | // The following variables are provided by the user. 98 | int samp_rate_in_; 99 | int samp_rate_out_; 100 | BaseFloat filter_cutoff_; 101 | int num_zeros_; 102 | 103 | int input_samples_in_unit_; 104 | int output_samples_in_unit_; 105 | 106 | std::vector first_index_; 107 | std::vector > weights_; 108 | 109 | int input_sample_offset_; 110 | int output_sample_offset_; 111 | vector input_remainder_; 112 | }; 113 | 114 | void ResampleWaveform(BaseFloat orig_freq, const vector &wave, 115 | BaseFloat new_freq, vector *new_wave); 116 | 117 | inline void DownsampleWaveForm(BaseFloat orig_freq, const vector &wave, 118 | BaseFloat new_freq, vector *new_wave) { 119 | ResampleWaveform(orig_freq, wave, new_freq, new_wave); 120 | } 121 | 122 | } // namespace delta 123 | #endif // DELTA_LAYERS_OPS_KERNELS_RESAMPLE_H_ 124 | -------------------------------------------------------------------------------- /athena/transform/feats/ops/kernels/spectrum.h: -------------------------------------------------------------------------------- 1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | ==============================================================================*/ 16 | 17 | #ifndef DELTA_LAYERS_OPS_KERNELS_SPECTRUM_H_ 18 | #define DELTA_LAYERS_OPS_KERNELS_SPECTRUM_H_ 19 | 20 | #include 21 | #include "tensorflow/core/framework/op_kernel.h" 22 | #include "tensorflow/core/platform/logging.h" 23 | 24 | #include "kernels/complex_defines.h" 25 | #include "kernels/support_functions.h" 26 | 27 | using namespace tensorflow; // NOLINT 28 | 29 | namespace delta { 30 | class Spectrum { 31 | private: 32 | float window_length_sec_; 33 | float frame_length_sec_; 34 | 35 | float f_SamRat; 36 | int i_WinLen; 37 | int i_FrmLen; 38 | int i_NumFrm; 39 | int i_NumFrq; 40 | int i_FFTSiz; 41 | float f_PreEph; 42 | char s_WinTyp[40]; 43 | int i_OutTyp; // 1: PSD, 2:log(PSD) 44 | int i_snip_edges; 45 | int i_raw_energy; 46 | bool i_remove_dc_offset; 47 | bool i_is_fbank; 48 | float i_dither; 49 | 50 | float* pf_WINDOW; 51 | float* pf_SPC; 52 | 53 | xcomplex* win; 54 | float* win_buf; 55 | float* eph_buf; 56 | float* win_temp; 57 | xcomplex* fftwin; 58 | float* fft_buf; 59 | 60 | public: 61 | Spectrum(); 62 | 63 | void set_window_length_sec(float window_length_sec); 64 | 65 | void set_frame_length_sec(float frame_length_sec); 66 | 67 | void set_output_type(int output_type); 68 | 69 | void set_snip_edges(int snip_edges); 70 | 71 | void set_raw_energy(int raw_energy); 72 | 73 | void set_preEph(float preEph); 74 | 75 | void set_window_type(char* window_type); 76 | 77 | void set_is_fbank(bool is_fbank); 78 | 79 | void set_remove_dc_offset(bool remove_dc_offset); 80 | 81 | void set_dither(float dither); 82 | 83 | int init_spc(int input_size, float sample_rate); 84 | 85 | int proc_spc(const float* mic_buf, int input_size); 86 | 87 | int get_spc(float* output); 88 | 89 | int write_spc(); 90 | 91 | TF_DISALLOW_COPY_AND_ASSIGN(Spectrum); 92 | }; 93 | } // namespace delta 94 | #endif // DELTA_LAYERS_OPS_KERNELS_SPECTRUM_H_ 95 | -------------------------------------------------------------------------------- /athena/transform/feats/ops/kernels/spectrum_op_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ==============================================================================
16 | """ spectrum Op unit-test """
17 | import os
18 | from pathlib import Path
19 | 
20 | import librosa
21 | import numpy as np
22 | import tensorflow as tf
23 | from absl import logging
24 | 
25 | from athena.transform.feats.ops import py_x_ops
26 | 
27 | 
28 | class SpecOpTest(tf.test.TestCase):
29 |     """ spectrum op unittest"""
30 | 
31 |     def setUp(self):
32 |         """set up"""
33 |         self.wavpath = str(
34 |             Path(os.environ["MAIN_ROOT"]).joinpath("delta/layers/ops/data/sm1_cln.wav")
35 |         )
36 | 
37 |     def tearDown(self):
38 |         """tear down"""
39 | 
40 |     def test_spectrum(self):
41 |         """ test spectrum op"""
42 |         with self.session(use_gpu=False, force_gpu=False):
43 |             # librosa stands in here as the wav loader; it returns (data, rate)
44 |             input_data, sample_rate = librosa.load(self.wavpath, sr=16000)
45 |             logging.info(
46 |                 f"input shape: {input_data.shape}, sample rate: {sample_rate}"
47 |             )
48 |             self.assertEqual(sample_rate, 16000)
49 | 
50 |             output = py_x_ops.spectrum(input_data, sample_rate)
51 | 
52 |             # pylint: disable=bad-whitespace
53 |             output_true = np.array(
54 |                 [
55 |                     [-16.863441, -16.910473, -17.077059, -16.371634, -16.845686],
56 |                     [-17.922068, -20.396345, -19.396944, -17.331493, -16.118851],
57 |                     [-17.017776, -17.551350, -20.332376, -17.403994, -16.617926],
58 |                     [-19.873854, -17.644503, -20.679525, -17.093716, -16.535091],
59 |                     [-17.074402, -17.295971, -16.896650, -15.995432, -16.560730],
60 |                 ]
61 |             )
62 |             # pylint: enable=bad-whitespace
63 | 
64 |             self.assertEqual(tf.rank(output).eval(), 2)
65 |             logging.info("Shape of spectrum: {}".format(output.shape))
66 |             self.assertAllClose(output.eval()[4:9, 4:9], output_true)
67 | 
68 | 
69 | if __name__ == "__main__":
70 |     tf.test.main()
71 | 
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/speed_op.cc:
--------------------------------------------------------------------------------
 1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
 2 | All rights reserved.
 3 | 
 4 | Licensed under the Apache License, Version 2.0 (the "License");
 5 | you may not use this file except in compliance with the License.
 6 | You may obtain a copy of the License at
 7 | 
 8 | http://www.apache.org/licenses/LICENSE-2.0
 9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | ==============================================================================*/
16 | 
17 | #include "kernels/resample.h"
18 | #include "tensorflow/core/framework/op_kernel.h"
19 | #include "tensorflow/core/framework/register_types.h"
20 | #include "tensorflow/core/framework/tensor.h"
21 | #include "tensorflow/core/framework/tensor_shape.h"
22 | #include "tensorflow/core/framework/types.h"
23 | #include "tensorflow/core/lib/core/status.h"
24 | 
25 | namespace delta {
26 | 
27 | class SpeedOp : public OpKernel {
28 |  public:
29 |   explicit SpeedOp(OpKernelConstruction* context) : OpKernel(context) {
30 |     OP_REQUIRES_OK(context,
31 |                    context->GetAttr("lowpass_filter_width", &lowpass_filter_width_));
32 |   }
33 | 
34 |   void Compute(OpKernelContext* context) override {
35 |     const Tensor& input_tensor = context->input(0);
36 |     OP_REQUIRES(context, input_tensor.dims() == 1,
37 |                 errors::InvalidArgument("input signal must be 1-dimensional",
38 |                                         input_tensor.shape().DebugString()));
39 |     const Tensor& sample_rate_tensor = context->input(1);
40 |     OP_REQUIRES(context, TensorShapeUtils::IsScalar(sample_rate_tensor.shape()),
41 |                 errors::InvalidArgument(
42 |                     "Input sample_rate should be a scalar tensor, got ",
43 |                     sample_rate_tensor.shape().DebugString(), " instead."));
44 |     const Tensor& resample_rate_tensor = context->input(2);
45 |     OP_REQUIRES(context, TensorShapeUtils::IsScalar(resample_rate_tensor.shape()),
46 |                 errors::InvalidArgument(
47 |                     "Resample sample_rate should be a scalar tensor, got ",
48 |                     resample_rate_tensor.shape().DebugString(), " instead."));
49 |     const int sample_rate = static_cast<int>(sample_rate_tensor.scalar<int32>()());
50 |     const int resample_freq = static_cast<int>(resample_rate_tensor.scalar<int32>()());
51 |     const float* input_flat = input_tensor.flat<float>().data();
52 |     const int L = input_tensor.dim_size(0);
53 | 
54 |     lowpass_cutoff_ = min(resample_freq / 2, sample_rate / 2);
55 |     LinearResample cls_resample_(sample_rate, resample_freq,
56 |                                  lowpass_cutoff_,
57 |                                  lowpass_filter_width_);
58 |     vector<BaseFloat> waveform(L);
59 |     for (int i = 0; i < L; i++) {
60 |       waveform[i] = static_cast<BaseFloat>(input_flat[i]);
61 |     }
62 |     vector<BaseFloat> downsampled_wave;
63 |     cls_resample_.Resample(waveform, false, &downsampled_wave);
64 |     int output_length = downsampled_wave.size();
65 |     Tensor* output_tensor = nullptr;
66 |     OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({1, output_length}),
67 |                                                      &output_tensor));
68 |     float* output_flat = output_tensor->flat<float>().data();
69 |     for (int j = 0; j < output_length; j++)
70 |       output_flat[j] = downsampled_wave[j];
71 | 
72 |     std::vector<BaseFloat>().swap(downsampled_wave);
73 |     std::vector<BaseFloat>().swap(waveform);
74 |     cls_resample_.Reset();
75 |   }
76 | 
77 |  private:
78 |   float lowpass_cutoff_;
79 |   int lowpass_filter_width_;
80 | };
81 | 
82 | REGISTER_KERNEL_BUILDER(Name("Speed").Device(DEVICE_CPU), SpeedOp);
83 | 
84 | }  // namespace delta
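The Speed op changes playback speed by resampling: the caller passes a target rate of `sample_rate / speed`, and the output length scales accordingly. A small sketch of that arithmetic (`speed_output_length` is an illustrative helper, not part of the op's API):

```python
def speed_output_length(num_samples, sample_rate, speed):
    """Approximate output length of the Speed op for a given speed factor.

    The op resamples from `sample_rate` to `sample_rate / speed`, so a
    speed of 0.9 slows the audio down and produces more samples.
    """
    resample_rate = int(sample_rate / speed)
    return int(round(num_samples * resample_rate / sample_rate))

# 1 s of 16 kHz audio at speed 0.9 -> resampled to 17777 Hz, ~17777 samples
print(speed_output_length(16000, 16000, 0.9))
```

--------------------------------------------------------------------------------
/athena/transform/feats/ops/py_x_ops.py:
--------------------------------------------------------------------------------
 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
 2 | # All rights reserved.
 3 | #
 4 | # Licensed under the Apache License, Version 2.0 (the "License");
 5 | # you may not use this file except in compliance with the License.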
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """ python custom ops """ 17 | import tensorflow.compat.v1 as tf 18 | from absl import logging 19 | 20 | so_lib_file = tf.io.gfile.glob(tf.resource_loader.get_data_files_path() + "/x_ops*.so")[ 21 | 0 22 | ].split("/")[-1] 23 | path = tf.resource_loader.get_path_to_datafile(so_lib_file) 24 | logging.info("x_ops*.so path:{}".format(path)) 25 | 26 | gen_x_ops = tf.load_op_library(tf.resource_loader.get_path_to_datafile(so_lib_file)) 27 | 28 | spectrum = gen_x_ops.spectrum 29 | fbank = gen_x_ops.fbank 30 | delta_delta = gen_x_ops.delta_delta 31 | pitch = gen_x_ops.pitch 32 | mfcc = gen_x_ops.mfcc_dct 33 | frame_pow = gen_x_ops.frame_pow 34 | speed = gen_x_ops.speed 35 | -------------------------------------------------------------------------------- /athena/transform/feats/pitch_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """The model tests pitch FE.""" 17 | 18 | import os 19 | from pathlib import Path 20 | import tensorflow as tf 21 | from tensorflow.python.framework.ops import disable_eager_execution 22 | from athena.transform.feats.read_wav import ReadWav 23 | from athena.transform.feats.pitch import Pitch 24 | 25 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 26 | 27 | 28 | class SpectrumTest(tf.test.TestCase): 29 | """ 30 | Pitch extraction test. 
31 | """ 32 | def test_spectrum(self): 33 | wav_path_16k = str( 34 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/sm1_cln.wav") 35 | ) 36 | wav_path_8k = str( 37 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/english.wav") 38 | ) 39 | 40 | with self.session(): 41 | for wav_file in [wav_path_16k]: 42 | read_wav = ReadWav.params().instantiate() 43 | input_data, sample_rate = read_wav(wav_file) 44 | 45 | pitch = Pitch.params( 46 | {"window_length": 0.025, "soft_min_f0": 10.0} 47 | ).instantiate() 48 | pitch_test = pitch(input_data, sample_rate) 49 | 50 | if tf.executing_eagerly(): 51 | self.assertEqual(tf.rank(pitch_test).numpy(), 2) 52 | else: 53 | self.assertEqual(tf.rank(pitch_test).eval(), 2) 54 | 55 | output_true = [ 56 | [-0.1366025, 143.8855], 57 | [-0.0226383, 143.8855], 58 | [-0.08464742, 143.8855], 59 | [-0.08458386, 143.8855], 60 | [-0.1208689, 143.8855], 61 | ] 62 | 63 | if wav_file == wav_path_16k: 64 | if tf.executing_eagerly(): 65 | print("Transform: ", pitch_test.numpy()[0:5, :]) 66 | print("kaldi:", output_true) 67 | self.assertAllClose( 68 | pitch_test.numpy()[0:5, :], 69 | output_true, 70 | rtol=1e-05, 71 | atol=1e-05, 72 | ) 73 | else: 74 | print("Transform: ", pitch_test.eval()) 75 | print("kaldi:", output_true) 76 | self.assertAllClose( 77 | pitch_test.eval()[0:5, :], 78 | output_true, 79 | rtol=1e-05, 80 | atol=1e-05, 81 | ) 82 | 83 | 84 | if __name__ == "__main__": 85 | 86 | is_eager = True 87 | if not is_eager: 88 | disable_eager_execution() 89 | else: 90 | if tf.__version__ < "2.0.0": 91 | tf.compat.v1.enable_eager_execution() 92 | tf.test.main() 93 | -------------------------------------------------------------------------------- /athena/transform/feats/read_wav.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """The model reads audio sample from wav file.""" 17 | 18 | import tensorflow as tf 19 | from athena.utils.hparam import HParams 20 | from athena.transform.feats.base_frontend import BaseFrontend 21 | from athena.transform.feats.ops import py_x_ops 22 | 23 | 24 | class ReadWav(BaseFrontend): 25 | """ 26 | Read audio sample from wav file, return sample data and sample rate. 27 | """ 28 | def __init__(self, config: dict): 29 | super().__init__(config) 30 | 31 | @classmethod 32 | def params(cls, config=None): 33 | """ 34 | Set params. 35 | :param config: contains one optional parameters: audio_channels(int, default=1). 36 | :return: An object of class HParams, which is a set of hyperparameters as 37 | name-value pairs. 
38 |         """
39 |         audio_channels = 1
40 | 
41 |         hparams = HParams(cls=cls)
42 |         hparams.add_hparam('type', 'ReadWav')
43 |         hparams.add_hparam('audio_channels', audio_channels)
44 | 
45 |         if config is not None:
46 |             hparams.parse(config, True)
47 | 
48 |         return hparams
49 | 
50 |     def call(self, wavfile, speed=1.0):
51 |         """
52 |         Get audio data and sample rate from a wavfile.
53 |         :param wavfile: filepath of the wav file
54 |         :param speed: speed factor for speed perturbation (float, default=1.0)
55 |         :return: 2 values. The first is a Tensor of audio data.
56 |                  The second is the sample rate of the input wav
57 |                  file, a tensor with int32 dtype.
58 |         """
59 |         p = self.config
60 |         contents = tf.io.read_file(wavfile)
61 |         audio_data, sample_rate = tf.compat.v1.audio.decode_wav(
62 |             contents, desired_channels=p.audio_channels)
63 |         if speed == 1.0:
64 |             return tf.squeeze(audio_data * 32768, axis=-1), tf.cast(sample_rate, dtype=tf.int32)
65 |         else:
66 |             resample_rate = tf.cast(sample_rate, dtype=tf.float32) * tf.cast(1.0 / speed, dtype=tf.float32)
67 |             speed_data = py_x_ops.speed(tf.squeeze(audio_data * 32768, axis=-1),
68 |                                         tf.cast(sample_rate, dtype=tf.int32),
69 |                                         tf.cast(resample_rate, dtype=tf.int32),
70 |                                         lowpass_filter_width=5)
71 |             return tf.squeeze(speed_data), tf.cast(sample_rate, dtype=tf.int32)
72 | 
73 | 
74 | def read_wav(wavfile, audio_channels=1):
75 |     """ read wav from file
76 |     args: audio_channels = 1
77 |     returns: tf.squeeze(audio_data * 32768, axis=-1), tf.cast(sample_rate, dtype=tf.int32)
78 |     """
79 |     contents = tf.io.read_file(wavfile)
80 |     audio_data, sample_rate = tf.compat.v1.audio.decode_wav(
81 |         contents, desired_channels=audio_channels
82 |     )
83 | 
84 |     return tf.squeeze(audio_data * 32768, axis=-1), tf.cast(sample_rate, dtype=tf.int32)
85 | 
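To make the speed handling above concrete, a short usage sketch (eager mode; the wav path is illustrative):

```python
from athena.transform.feats.read_wav import ReadWav

# read a wav with 0.9x speed perturbation (audio slows down, more samples)
read_wav = ReadWav.params({'audio_channels': 1}).instantiate()
audio, sr = read_wav('examples/sm1_cln.wav', speed=0.9)

# audio is float32 scaled to the int16 range; divide to get [-1, 1)
normalized = audio / 32768
```

--------------------------------------------------------------------------------
/athena/transform/feats/read_wav_test.py:
--------------------------------------------------------------------------------
 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
 2 | # All rights reserved.
 3 | #
 4 | # Licensed under the Apache License, Version 2.0 (the "License");
 5 | # you may not use this file except in compliance with the License.
 6 | # You may obtain a copy of the License at
 7 | #
 8 | # http://www.apache.org/licenses/LICENSE-2.0
 9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """The model tests OP of read_wav """
17 | 
18 | import os
19 | from pathlib import Path
20 | import librosa
21 | import tensorflow as tf
22 | from tensorflow.python.framework.ops import disable_eager_execution
23 | from athena.transform.feats.read_wav import ReadWav
24 | 
25 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
26 | 
27 | class ReadWavTest(tf.test.TestCase):
28 |     """
29 |     ReadWav OP test.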
30 | """ 31 | def test_read_wav(self): 32 | wav_path = str(Path(os.environ['MAIN_ROOT']).joinpath('examples/sm1_cln.wav')) 33 | 34 | with self.session(): 35 | speed = 0.9 36 | read_wav = ReadWav.params().instantiate() 37 | input_data, sample_rate = read_wav(wav_path, speed) 38 | 39 | audio_data_true, sample_rate_true = librosa.load(wav_path, sr=16000) 40 | if (speed == 1.0): 41 | if tf.executing_eagerly(): 42 | self.assertAllClose(input_data.numpy() / 32768, audio_data_true) 43 | self.assertAllClose(sample_rate.numpy(), sample_rate_true) 44 | else: 45 | self.assertAllClose(input_data.eval() / 32768, audio_data_true) 46 | self.assertAllClose(sample_rate.eval(), sample_rate_true) 47 | 48 | 49 | if __name__ == '__main__': 50 | 51 | is_eager = False 52 | if not is_eager: 53 | disable_eager_execution() 54 | else: 55 | if tf.__version__ < '2.0.0': 56 | tf.compat.v1.enable_eager_execution() 57 | tf.test.main() 58 | -------------------------------------------------------------------------------- /athena/transform/feats/spectrum_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """The model tests spectrum FE.""" 17 | 18 | import os 19 | from pathlib import Path 20 | import numpy as np 21 | import tensorflow as tf 22 | from tensorflow.python.framework.ops import disable_eager_execution 23 | from athena.transform.feats.read_wav import ReadWav 24 | from athena.transform.feats.spectrum import Spectrum 25 | 26 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 27 | 28 | 29 | class SpectrumTest(tf.test.TestCase): 30 | ''' 31 | Spectum extraction test. 
32 | ''' 33 | def test_spectrum(self): 34 | wav_path_16k = str( 35 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/sm1_cln.wav") 36 | ) 37 | wav_path_8k = str( 38 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/english.wav") 39 | ) 40 | 41 | with self.session(): 42 | for wav_file in [wav_path_8k, wav_path_16k]: 43 | read_wav = ReadWav.params().instantiate() 44 | input_data, sample_rate = read_wav(wav_file) 45 | 46 | spectrum = Spectrum.params( 47 | {"window_length": 0.025, "dither": 0.0} 48 | ).instantiate() 49 | spectrum_test = spectrum(input_data, sample_rate) 50 | 51 | output_true = np.array( 52 | [ 53 | [9.819611, 2.84503, 3.660894, 2.7779, 1.212233], 54 | [9.328745, 2.553949, 3.276319, 3.000918, 2.499342], 55 | ] 56 | ) 57 | if tf.executing_eagerly(): 58 | self.assertEqual(tf.rank(spectrum_test).numpy(), 2) 59 | else: 60 | self.assertEqual(tf.rank(spectrum_test).eval(), 2) 61 | 62 | if wav_file == wav_path_16k: 63 | if tf.executing_eagerly(): 64 | self.assertAllClose( 65 | spectrum_test.numpy()[0:2, 0:5], 66 | output_true, 67 | rtol=1e-05, 68 | atol=1e-05, 69 | ) 70 | else: 71 | self.assertAllClose( 72 | spectrum_test.eval()[0:2, 0:5], 73 | output_true, 74 | rtol=1e-05, 75 | atol=1e-05, 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | 81 | is_eager = True 82 | if not is_eager: 83 | disable_eager_execution() 84 | else: 85 | if tf.__version__ < "2.0.0": 86 | tf.compat.v1.enable_eager_execution() 87 | tf.test.main() 88 | -------------------------------------------------------------------------------- /athena/transform/feats/write_wav.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """The model write audio sample to wav file.""" 17 | 18 | import tensorflow as tf 19 | from athena.utils.hparam import HParams 20 | from athena.transform.feats.base_frontend import BaseFrontend 21 | 22 | 23 | class WriteWav(BaseFrontend): 24 | """ 25 | Encode audio data (input) using sample rate (input), 26 | return a write wav opration. 27 | """ 28 | 29 | def __init__(self, config: dict): 30 | super().__init__(config) 31 | 32 | @classmethod 33 | def params(cls, config=None): 34 | """ 35 | Set params. 36 | :param config: contains one optional parameters:sample_rate(int, default=16000). 37 | :return: An object of class HParams, which is a set of hyperparameters as 38 | name-value pairs. 39 | """ 40 | 41 | sample_rate = 16000 42 | 43 | hparams = HParams(cls=cls) 44 | hparams.add_hparam('sample_rate', sample_rate) 45 | 46 | if config is not None: 47 | hparams.override_from_dict(config) 48 | 49 | return hparams 50 | 51 | def call(self, filename, audio_data, sample_rate): 52 | """ 53 | Write wav using audio_data[tensor]. 54 | :param filename: filepath of wav. 
55 |         :param audio_data: a tensor containing the data of a wav.
56 |         :param sample_rate: the sample rate of the signal we are working with.
57 |         :return: a write-wav operation.
58 |         """
59 |         filename = tf.constant(filename)
60 | 
61 |         with tf.name_scope('writewav'):
62 |             audio_data = tf.cast(audio_data, dtype=tf.float32)
63 |             contents = tf.audio.encode_wav(
64 |                 tf.expand_dims(audio_data, 1), tf.cast(sample_rate, dtype=tf.int32))
65 |             w_op = tf.io.write_file(filename, contents)
66 | 
67 |         return w_op
68 | 
--------------------------------------------------------------------------------
/athena/transform/feats/write_wav_test.py:
--------------------------------------------------------------------------------
 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
 2 | # All rights reserved.
 3 | #
 4 | # Licensed under the Apache License, Version 2.0 (the "License");
 5 | # you may not use this file except in compliance with the License.
 6 | # You may obtain a copy of the License at
 7 | #
 8 | # http://www.apache.org/licenses/LICENSE-2.0
 9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """The model tests WriteWav OP."""
17 | 
18 | import os
19 | 
20 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
21 | from pathlib import Path
22 | import librosa
23 | import tensorflow as tf
24 | from tensorflow.python.framework.ops import disable_eager_execution
25 | from athena.transform.feats.read_wav import ReadWav
26 | from athena.transform.feats.write_wav import WriteWav
27 | 
28 | 
29 | class WriteWavTest(tf.test.TestCase):
30 |     """
31 |     WriteWav OP test.
32 |     """
33 |     def test_write_wav(self):
34 |         wav_path = str(Path(os.environ["MAIN_ROOT"]).joinpath("examples/sm1_cln.wav"))
35 | 
36 |         with self.cached_session() as sess:
37 |             config = {"speed": 1.1}
38 |             read_wav = ReadWav.params(config).instantiate()
39 |             input_data, sample_rate = read_wav(wav_path)
40 |             write_wav = WriteWav.params().instantiate()
41 |             new_path = str(
42 |                 Path(os.environ["MAIN_ROOT"]).joinpath("examples/sm1_cln_resample.wav")
43 |             )
44 |             writewav_op = write_wav(new_path, input_data / 32768, sample_rate)
45 |             sess.run(writewav_op)
46 | 
47 | 
48 | if __name__ == "__main__":
49 |     is_eager = True
50 |     if not is_eager:
51 |         disable_eager_execution()
52 |     else:
53 |         if tf.__version__ < "2.0.0":
54 |             tf.compat.v1.enable_eager_execution()
55 |     tf.test.main()
56 | 
--------------------------------------------------------------------------------
/athena/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright (C) ATHENA AUTHORS
 2 | # All rights reserved.
 3 | #
 4 | # Licensed under the Apache License, Version 2.0 (the "License");
 5 | # you may not use this file except in compliance with the License.
 6 | # You may obtain a copy of the License at
 7 | #
 8 | # http://www.apache.org/licenses/LICENSE-2.0
 9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """ utils """ 17 | -------------------------------------------------------------------------------- /athena/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) ATHENA AUTHORS 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # Only support eager mode 17 | # pylint: disable=invalid-name 18 | r""" check point manager """ 19 | import os 20 | import tensorflow as tf 21 | from absl import logging 22 | import numpy as np 23 | 24 | class Checkpoint(tf.train.Checkpoint): 25 | """ A wrapper for Tensorflow checkpoint 26 | 27 | Args: 28 | checkpoint_directory: the directory for checkpoint 29 | summary_directory: the directory for summary used in Tensorboard 30 | __init__ provide the optimizer and model 31 | __call__ save the model 32 | 33 | Example: 34 | transformer = SpeechTransformer(target_vocab_size=dataset_builder.target_dim) 35 | optimizer = tf.keras.optimizers.Adam() 36 | ckpt = Checkpoint(checkpoint_directory='./train', summary_directory='./event', 37 | transformer=transformer, optimizer=optimizer) 38 | solver = BaseSolver(transformer) 39 | for epoch in dataset: 40 | ckpt() 41 | """ 42 | 43 | def __init__(self, checkpoint_directory=None, **kwargs): 44 | super().__init__(**kwargs) 45 | self.best_loss = np.inf 46 | if checkpoint_directory is None: 47 | checkpoint_directory = os.path.join(os.path.expanduser("~"), ".athena") 48 | self.checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 49 | self.checkpoint_directory = checkpoint_directory 50 | logging.info("trying to restore from : %s" % checkpoint_directory) 51 | # load from checkpoint if previous models exist in checkpoint dir 52 | self.restore(tf.train.latest_checkpoint(checkpoint_directory)) 53 | 54 | def _compare_and_save_best(self, loss, save_path): 55 | """ compare and save the best model in best_loss """ 56 | if loss is None: 57 | return 58 | if loss < self.best_loss: 59 | self.best_loss = loss 60 | with open(os.path.join(self.checkpoint_directory, 'best_loss'), 'w') as wf: 61 | checkpoint = save_path.split('/')[-1] 62 | wf.write('model_checkpoint_path: "%s"' % checkpoint) 63 | 64 | def __call__(self, loss=None): 65 | logging.info("saving model in :%s" % self.checkpoint_prefix) 66 | save_path = self.save(file_prefix=self.checkpoint_prefix) 67 | self._compare_and_save_best(loss, save_path) 68 | 69 | def restore_from_best(self): 70 | """ restore from the best model """ 71 | self.restore( 72 | tf.train.latest_checkpoint( 73 | self.checkpoint_directory, 74 | latest_filename='best_loss' 75 | ) 76 | ) -------------------------------------------------------------------------------- /athena/utils/data_queue.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) ATHENA AUTHORS 3 | # All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | # pylint: disable=missing-function-docstring, invalid-name 18 | """ data queue for multi thread """ 19 | import time 20 | import threading 21 | import queue 22 | 23 | 24 | class DataQueue: 25 | """Queue for data prefetching 26 | args: 27 | generator(generator): instance of generator which feed data 28 | capacity(int): maximum data to prefetch 29 | num_threads(int): control concurrency, only take effect when do preprocessing 30 | wait_time(float): time to sleep when queue is full 31 | """ 32 | 33 | def __init__(self, generator, capacity=20, num_threads=4, max_index=10000, wait_time=0.0001): 34 | self.generator = generator 35 | self.capacity = capacity 36 | self.wait_time = wait_time 37 | self.queue = queue.Queue() 38 | self.index = 0 39 | self.max_index = max_index 40 | 41 | self._stop = threading.Event() 42 | self._lock = threading.Lock() 43 | 44 | self.threads = [ 45 | threading.Thread(target=self.generator_task) for _ in range(num_threads) 46 | ] 47 | 48 | for t in self.threads: 49 | t.setDaemon(True) 50 | t.start() 51 | 52 | def __del__(self): 53 | self.stop() 54 | 55 | def get(self): 56 | return self.queue.get() 57 | 58 | def stop(self): 59 | self._stop.set() 60 | 61 | def generator_task(self): 62 | """Enqueue batch data""" 63 | while not self._stop.is_set(): 64 | try: 65 | if self.index >= self.max_index: 66 | continue 67 | batch = self.generator(self.index) 68 | self._lock.acquire() 69 | if self.queue.qsize() < self.capacity: 70 | try: 71 | self.index = self.index + 1 72 | except ValueError as e: 73 | print(e) 74 | self._lock.release() 75 | continue 76 | self.queue.put(batch) 77 | self._lock.release() 78 | else: 79 | self._lock.release() 80 | time.sleep(self.wait_time) 81 | except Exception as e: 82 | print(e) 83 | self._stop.set() 84 | raise 85 | 86 | 87 | def test(): 88 | """ 89 | Test data queue. 90 | Excpet return: 91 | epoch: %d, nb_batch: %d: finish. 92 | """ 93 | 94 | def generator(i): 95 | return i 96 | 97 | train_queue = DataQueue(generator, capacity=8, num_threads=4) 98 | for _ in range(92): 99 | print(train_queue.get()) 100 | train_queue.stop() 101 | 102 | 103 | if __name__ == "__main__": 104 | test() 105 | -------------------------------------------------------------------------------- /athena/utils/metric_check.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) ATHENA AUTHORS 3 | # All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at
 8 | #
 9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """MetricChecker"""
18 | import time
19 | import tensorflow as tf
20 | import numpy as np
21 | 
22 | 
23 | class MetricChecker:
24 |     """Hold the best metrics and build training/evaluation reports
25 |     Args:
26 |         optimizer: the optimizer, used to read the global step and learning rate
27 |     """
28 | 
29 |     def __init__(self, optimizer):
30 |         self.best_loss = tf.constant(np.inf)
31 |         self.optimizer = optimizer
32 |         self.time_last_call = time.time()
33 |         self.steps_last_call = 0
34 | 
35 |     def __call__(self, loss, metrics, evaluate_epoch=-1):
36 |         """summary the basic metrics like loss, lr
37 |         Args:
38 |             loss: average loss of all previous steps in one epoch
39 |             metrics: dict mapping metric names to values
40 |             evaluate_epoch:
41 |                 >= 0: evaluation, also write tf.summary entries for this epoch
42 |                 == -1: training summary (the default)
43 |                 < -1: evaluation without tf.summary writes
44 |         Returns:
45 |             logging_str: report of loss and metrics (plus learning rate and
46 |                 speed when training)
47 |         """
48 |         if evaluate_epoch == -1:
49 |             return self.summary_train(loss, metrics)
50 |         return self.summary_evaluate(loss, metrics, evaluate_epoch)
51 | 
52 | 
53 |     def summary_train(self, loss, metrics):
54 |         """ generate summary of learning_rate, loss, metrics, speed and write on Tensorboard
55 |         """
56 |         global_steps = tf.convert_to_tensor(self.optimizer.iterations)
57 |         learning_rate = (
58 |             self.optimizer.lr
59 |             if isinstance(self.optimizer.lr, tf.Variable)
60 |             else (self.optimizer.lr(global_steps))
61 |         )
62 | 
63 |         tf.summary.scalar("loss", loss, step=global_steps)
64 |         tf.summary.scalar("learning_rate", learning_rate, step=global_steps)
65 |         for name in metrics:
66 |             metric = metrics[name]
67 |             tf.summary.scalar(name, metric, step=global_steps)
68 | 
69 |         reports = ""
70 |         reports += "global_steps: %d\t" % (global_steps)
71 |         reports += "learning_rate: %.4e\t" % (learning_rate)
72 |         reports += "loss: %.4f\t" % (loss)
73 |         for name in metrics:
74 |             metric = metrics[name]
75 |             reports += "%s: %.4f\t" % (name, metric)
76 |         right_now = time.time()
77 |         duration = right_now - self.time_last_call
78 |         self.time_last_call = right_now
79 |         if self.steps_last_call != 0:
80 |             # average duration over log_interval steps
81 |             sec_per_iter = duration / tf.cast(
82 |                 (global_steps - self.steps_last_call), tf.float32
83 |             )
84 |             reports += "sec/iter: %.4f" % (sec_per_iter)
85 |         self.steps_last_call = global_steps
86 | 
87 |         return reports
88 | 
89 |     def summary_evaluate(self, loss, metrics, epoch=-1):
90 |         """ If epoch >= 0, write a summary of loss and metrics on the dev set to Tensorboard
91 |         Otherwise, just return the evaluation loss and metrics
92 |         """
93 |         reports = ""
94 |         if epoch >= 0:
95 |             tf.summary.scalar("evaluate_loss", loss, step=epoch)
96 |             for name in metrics:
97 |                 metric = metrics[name]
98 |                 tf.summary.scalar("evaluate_" + name, metric, step=epoch)
99 |         reports += "epoch: %d\t" % (epoch)
100 |         reports += "loss: %.4f\t" % (loss)
101 |         for name in metrics:
102 |             metric = metrics[name]
103 |             reports += "%s: %.4f\t" % (name, metric)
104 |         return reports
105 | 
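A minimal sketch of how `MetricChecker` slots into a training loop (the `dataset`, `train_step`, `dev_loss`, and `dev_metrics` names are placeholders, not Athena APIs):

```python
import tensorflow as tf
from athena.utils.metric_check import MetricChecker

optimizer = tf.keras.optimizers.Adam()
checker = MetricChecker(optimizer)

# summary_train writes tf.summary scalars, so a default writer must be set
writer = tf.summary.create_file_writer("./event")
with writer.as_default():
    for step, batch in enumerate(dataset):       # placeholder dataset
        loss, metrics = train_step(batch)        # placeholder train step
        if step % 10 == 0:
            print(checker(loss, metrics))        # training report + summaries
    # after an epoch, report on the dev set without a tf.summary write
    print(checker(dev_loss, dev_metrics, evaluate_epoch=-2))
```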
--------------------------------------------------------------------------------
/athena/utils/vocabs/en.vocab:
--------------------------------------------------------------------------------
 1 | a 1
 2 | b 2
 3 | c 3
 4 | d 4
 5 | e 5
 6 | f 6
 7 | g 7
 8 | h 8
 9 | i 9
10 | j 10
11 | k 11
12 | l 12
13 | m 13
14 | n 14
15 | o 15
16 | p 16
17 | q 17
18 | r 18
19 | s 19
20 | t 20
21 | u 21
22 | v 22
23 | w 23
24 | x 24
25 | y 25
26 | z 26
27 | ' 27
28 | - 28
29 | <space> 0
30 | <unk> 29
31 | 
--------------------------------------------------------------------------------
/docs/TheTrainningEfficiency.md:
--------------------------------------------------------------------------------
 1 | # GPU Training Efficiency with ``Horovod``+``Tensorflow``
 2 | 
 3 | ## Experimental Setup
 4 | 
 5 | Training environment: ``Athena``
 6 | 
 7 | Training data: a subset of 1000 samples randomly selected from the HKUST training dataset.
 8 | 
 9 | Network: ``LAS`` model
10 | 
11 | Primary network configuration: ``NUM_EPOCHS`` 1, ``BATCH_SIZE`` 10
12 | 
13 | We measured how the training time changes with the number of servers and GPUs when using `Horovod`+`Tensorflow`, while keeping the training data, network structure, etc. identical and training for `one` `epoch`. The experimental results are as follows:
14 | 
15 | ### Training time using ``Horovod``+``Tensorflow``
16 | 
17 | Server and GPU number | 1S-1GPU | 1S-2GPUs | 1S-3GPUs | 1S-4GPUs | 2Ss-2GPUs | 2Ss-4GPUs | 2Ss-6GPUs | 2Ss-8GPUs |
18 | :-------:|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|:--------:|:--------:|
19 | training time (s/epoch) | 121.409 | 83.111 | 61.607 | 54.507 | 82.486 | 49.888 | 33.333 | 28.101 |
20 | 
21 | ## Result Analysis
22 | 1. As the table shows, the more GPUs are used, the shorter the training time. For example, comparing 1 server with 1 GPU against 1 server with 4 GPUs gives a training-time ratio of `1S-4GPUs:1S-1GPU = 1:2.22`; another data point is `2Ss-8GPUs:1S-1GPU = 1:4.3`. Increasing the number of GPUs during training therefore saves training time and improves efficiency.
23 | 
24 | 2. The communication overhead between different servers is very small with `Horovod`. We trained the same model with 2 GPUs on 1 server and with 2 servers using 1 GPU each; the training-time ratio is `1S-2GPUs:2Ss-2GPUs = 1:1`.
25 | 
--------------------------------------------------------------------------------
/docs/development/contributing.md:
--------------------------------------------------------------------------------
 1 | # Contributing Guide
 2 | 
 3 | ## License
 4 | 
 5 | The source file should contain a license header. See the existing files as an example.
 6 | 
 7 | ## Name style
 8 | 
 9 | All names in Python and C++ use [snake case style](https://en.wikipedia.org/wiki/Snake_case), except for `op` names in `Tensorflow`.
10 | For Golang, use CamelCase for `variable name` and `interface`.
11 | 
12 | ## Python style
13 | 
14 | Changes to Python code should conform to the [Chromium Python Style Guide](https://chromium.googlesource.com/chromium/src/+/master/styleguide/python/python.md).
15 | You can use [yapf](https://github.com/google/yapf) to check the style.
16 | The style configuration is `.style.yapf`.
17 | 
18 | ## C++ style
19 | 
20 | Changes to C++ code should conform to the [Google C++ Style Guide](https://github.com/google/styleguide).
21 | You can use [cpplint](https://github.com/google/styleguide/tree/gh-pages/cpplint) to check the style and [clang-format](https://clang.llvm.org/docs/ClangFormat.html) to format the code.
22 | The style configuration is `.clang-format`.
23 | 
24 | ## C++ macro
25 | 
26 | C++ macros should start with `ATHENA_`, except for the most common ones like `LOG` and `VLOG`.
27 | 
28 | ## Golang style
29 | 
30 | For Golang style, please see the docs below:
31 | 
32 | * [How to Write Go Code](https://golang.org/doc/code.html)
33 | * [Effective Go](https://golang.org/doc/effective_go.html#interface-names)
34 | * [Go Code Review Comments](https://github.com/golang/go/wiki/CodeReviewComments)
35 | * [Golang Style in Chinese](https://juejin.im/post/5c16f16c5188252dcb30ff42)
36 | 
37 | Before committing Golang code, please use `go fmt` and `go vet` to format and lint the code.
38 | 
39 | ## Logging guideline
40 | 
41 | For Python, use [abseil-py](https://github.com/abseil/abseil-py) ([more info](https://abseil.io/docs/python/)).
42 | 
43 | For C++, use [abseil-cpp](https://github.com/abseil/abseil-cpp) ([more info](https://abseil.io/docs/cpp/)).
44 | 
45 | For Golang, use [glog](https://github.com/golang/glog).
46 | 
47 | ## Unit test
48 | 
49 | For Python, use `tf.test.TestCase`.
50 | 
51 | For C++, use [googletest](https://github.com/google/googletest).
52 | 
53 | For Golang, use `go test` for unit tests.
54 | 
--------------------------------------------------------------------------------
/docs/transform/img/DCT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/DCT.png
--------------------------------------------------------------------------------
/docs/transform/img/DFT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/DFT.png
--------------------------------------------------------------------------------
/docs/transform/img/MFCC.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/MFCC.png
--------------------------------------------------------------------------------
/docs/transform/img/MelBank.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/MelBank.png
--------------------------------------------------------------------------------
/docs/transform/img/Mel_filter.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/Mel_filter.png
--------------------------------------------------------------------------------
/docs/transform/img/amplitude_spectrum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/amplitude_spectrum.png
--------------------------------------------------------------------------------
/docs/transform/img/audio_data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/audio_data.png -------------------------------------------------------------------------------- /docs/transform/img/delta.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/delta.png -------------------------------------------------------------------------------- /docs/transform/img/fbank.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/fbank.png -------------------------------------------------------------------------------- /docs/transform/img/hamming.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/hamming.png -------------------------------------------------------------------------------- /docs/transform/img/logMel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/logMel.png -------------------------------------------------------------------------------- /docs/transform/img/logpower_spectrum.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/logpower_spectrum.png -------------------------------------------------------------------------------- /docs/transform/img/mel_freq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/mel_freq.png -------------------------------------------------------------------------------- /docs/transform/img/melbanks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/melbanks.png -------------------------------------------------------------------------------- /docs/transform/img/phase_spectrum.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/phase_spectrum.png -------------------------------------------------------------------------------- /docs/transform/img/power_spectrum.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/power_spectrum.png -------------------------------------------------------------------------------- /docs/transform/img/spectrum.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/spectrum.png -------------------------------------------------------------------------------- /docs/transform/img/spectrum_emph.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/spectrum_emph.png
--------------------------------------------------------------------------------
/docs/transform/img/spectrum_orig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/spectrum_orig.png
--------------------------------------------------------------------------------
/docs/transform/speech_feature.md:
--------------------------------------------------------------------------------
# Introduction to Speech Features

## Audio Data
Before doing any feature processing, we need to obtain the audio data and the sample rate from a wav-format audio file. read_wav.py provides the ReadWav interface, used as follows:

    config = {'sample_rate': 16000.0}
    wav_path = 'sm1_cln.wav'
    read_wav = ReadWav.params(config).instantiate()
    audio_data, sample_rate = read_wav(wav_path)

The audio data obtained is of float type, with values in the range [-1, 1]. The raw signal in the time domain is shown below:
![audio data of wavfile](img/audio_data.png "Audio data")

## Short-Time Fourier Transform (STFT)
Speech changes quickly and in fairly complicated ways in the time domain, so we usually apply a Fourier transform to move it into the frequency domain. The Fourier transform applied to a digital speech signal is the Short-Time Fourier Transform (STFT). An STFT generally consists of the following steps: pre-emphasis (optional), framing, windowing, and the Fourier transform itself.
### Pre-emphasis
Because the average power spectrum of speech is shaped by the glottal excitation and by lip and nostril radiation, the high-frequency band rolls off at about 6 dB/octave above roughly 800 Hz, so when computing the spectrum of a speech signal, the higher the frequency, the smaller the corresponding component. The purpose of pre-emphasis is to boost the high-frequency components so that the spectrum becomes flatter and can be estimated with the same signal-to-noise ratio over the whole band from low to high frequencies, which makes spectral analysis easier. Pre-emphasis is usually implemented with a first-order FIR high-pass filter whose transfer function is (a is the pre-emphasis coefficient):

    y(n) = x(n) - ax(n-1), 0.9 < a < 1.0

The spectra of a speech signal before and after pre-emphasis are compared below (a = 0.97).
![Spectrum of original data](img/spectrum_orig.png "Original Spectrum")
![Spectrum of pre_emphasis data](img/spectrum_emph.png "Pre_emphasis Spectrum")

### Framing
The Fourier transform assumes a stationary signal, but speech is not stationary at the macro level, so we first cut the speech signal into short, fixed-length time segments (i.e. framing), within which the signal can be regarded as stationary. The frame length must (1) be short enough that the signal inside a frame is stationary, so a frame should be shorter than one phoneme (50-200 ms); and (2) contain enough pitch periods: the fundamental frequency of speech is about 100 Hz (T = 10 ms) for male voices and 200 Hz (T = 5 ms) for female voices, so at least 20 ms is usually taken. In summary, a frame is typically 20-50 ms, i.e. frame_length = 0.020-0.050 (in seconds). Kaldi's defaults are a frame length of 0.025 s and a frame shift of 0.010 s.
### Windowing
Truncating a time-domain signal causes spectral leakage unless the truncation happens to be periodic. To reduce the leakage error, we use a weighting function, i.e. a window function, so that the truncated time-domain signal better satisfies the periodicity assumption of the Fourier transform. Multiplying each frame by a smooth window function lets both ends of the frame decay smoothly to zero, which reduces the strength of the side lobes after the Fourier transform and gives a higher-quality spectrum. Common window functions include the rectangular, Hann, and Hamming windows. The price of windowing is that both ends of each frame are attenuated; to avoid losing information, we compensate by not cutting frames back to back but letting adjacent frames overlap. The time difference between the starting positions of two adjacent frames is the frame shift, usually 0-0.5 of the frame length, i.e. hop_length = (0-0.5) * frame_length.
The window function of the common Hamming window is given below (N is the window length in samples and n is an integer from 0 to N):
![Hamming](img/hamming.png "Hamming")

### Fast Fourier Transform (FFT)
Since what we have is digital audio, we need the discrete Fourier transform (DFT), defined as:

![DFT](img/DFT.png "DFT")

Because the DFT is computationally expensive, we usually use the fast Fourier transform (FFT) instead. The FFT size (Nfft) is normally an integer power of two and must be at least the number of samples in a frame; if a frame has 400 samples, Nfft must be at least 512. When the frame length does not equal Nfft, the frame can be zero-padded. The frequency resolution of the spectrum equals sample_rate / Nfft, and the m-th frequency bin corresponds to the frequency f(m) = m × frequency resolution.
For single-channel speech we obtain a complex matrix of shape (num_frequencies, num_frames), which means the speech signal is composed of num_frequencies sinusoids with different phases. For each time-frequency bin, the magnitude of the FFT is the amplitude of the sinusoid at that frequency and the phase is the sinusoid's initial phase, so the corresponding spectra are called the amplitude spectrum and the phase spectrum, shown below.
![Amplitude Spectrum](img/amplitude_spectrum.png "Amplitude Spectrum")
![Phase Spectrum](img/phase_spectrum.png "Phase Spectrum")

## Power Spectrum
After computing the STFT we have a frequency-domain signal. The amount of energy differs across frequency bands, and different phonemes have different energy spectra. The power spectrum is computed as:
![Power_Spectrum](img/spectrum.png "Power_Spectrum")

The log spectrum, logP = 10 × log(P), is also commonly used. The power spectrum and the log-power spectrum are shown below.
![Power Spectrum](img/power_spectrum.png "Power Spectrum")
![Log-power Spectrum](img/logpower_spectrum.png "Log-power Spectrum")

## FilterBank
The human ear's response to the sound spectrum is nonlinear: it resolves low-frequency sounds better than high-frequency ones. Experience shows that if we can design a front-end algorithm that processes audio in a way similar to the human ear, speech recognition performance can improve; FilterBank is such an algorithm. Obtaining FBank features from the power spectrum requires Mel filtering and a logarithm operation.
### Mel Filtering
By converting frequency to the Mel scale, our features can better match human auditory perception. Frequency f and Mel frequency m are converted as follows:
![Mel2Freq && Freq2Mel](img/mel_freq.png "Mel2Freq && Freq2Mel")
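
As a quick illustration, the conversion can be written directly in NumPy. This is only a sketch of the common 2595/700 form of the formula; the exact constants inside the feature kernels may differ slightly:

```python
import numpy as np

def hz_to_mel(f):
    """Hz -> Mel, using the common 2595 * log10(1 + f / 700) form."""
    return 2595.0 * np.log10(1.0 + np.asarray(f, dtype=np.float64) / 700.0)

def mel_to_hz(m):
    """Mel -> Hz, the inverse conversion."""
    return 700.0 * (10.0 ** (np.asarray(m, dtype=np.float64) / 2595.0) - 1.0)

# Example: 25 points evenly spaced on the Mel scale between 20 Hz and 8000 Hz,
# which would serve as the edges and centers of a 23-filter bank.
mel_points = np.linspace(hz_to_mel(20.0), hz_to_mel(8000.0), 23 + 2)
print(mel_to_hz(mel_points))
```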
#### Mel Filter Banks
A Mel filter bank is a set of roughly 20-40 triangular filters (Kaldi defaults to 23, and MFCC uses 26), where the range covered by each triangular filter approximates one critical band of the human ear. The triangular windows can cover the whole frequency range from 0 to Nyquist, but we usually set upper and lower frequency limits to mask out frequency ranges that are not needed or are noisy. Mel filters come in two common forms: in one, the response at the center frequency is fixed at 1 and the filter's area changes with its bandwidth; in the other, the height changes as the width grows so that the area stays constant. The latter is expressed mathematically as:
![MelBanks_maths](img/MelBank.png "MelBanks_maths")
Here m denotes the m-th filter, k is the coordinate on the horizontal axis (the independent variable), and f(m) is the horizontal coordinate of the center of the m-th filter. The filters look like this:
![MelBanks](img/melbanks.png "MelBanks")
The Mel filter bank serves two main purposes: (1) it smooths the energy spectrum, removes harmonics, and highlights the formants of the speech; (2) it reduces the amount of computation.
Filtering the power spectrum estimate from the previous step with the Mel filter bank gives a feature vector whose dimension equals the number of triangles in the filter bank, expressed mathematically as:
![MelFilter](img/Mel_filter.png "MelFilter")
### Logarithm
This step simply takes the logarithm of the previous result, which amplifies the energy differences at low energies. That is, the FBank features are
![FilterBank](img/logMel.png "FilterBank")
The FilterBank features look as follows (upper frequency limit 8000 Hz, lower limit 20 Hz, 23 feature dimensions):
![FilterBank Features](img/fbank.png "FilterBank Features")
## Mel-Frequency Cepstral Coefficients (MFCC)
FBank features are already close to the response characteristics of the human ear, but they still have a shortcoming: adjacent FBank coefficients are highly correlated (adjacent filters overlap), so when we model phonemes with HMMs we almost always need to apply a discrete cosine transform (DCT) first, which yields the MFCC (Mel-scale Frequency Cepstral Coefficients) features. The essence of the DCT is to remove the correlation between the signal's dimensions and map the signal into a lower-dimensional space.
The DCT is expressed mathematically as:
![DCT](img/DCT.png "DCT")
N is the FBank feature dimension and M is the feature dimension after the DCT. When the DCT is applied to typical speech signals, the first few coefficients of the result are particularly large and the later ones are small, so usually only the first 12-20 are kept, which also further compresses the data. MFCC features look like this:
![MFCC Features](img/MFCC.png "MFCC Features")
Comparing FBank and MFCC: (1) FBank features are more strongly correlated, and DNNs/CNNs can make better use of these correlations, so FBank features can reduce WER further; (2) MFCC is more discriminative, and since GMMs with diagonal covariance matrices ignore the correlations between feature dimensions, MFCC is better suited as the feature for them.
## Deltas
Standard cepstral parameters such as MFCC are extracted from a stretch of speech and only reflect the static properties of the speech; its dynamic properties can be described by the differential (delta) spectrum of these static features. Experiments show that combining dynamic and static features is what effectively improves recognition performance. Delta parameters can be computed with the formula below (t is the window size in frames, typically 2):
![Delta](img/delta.png "Delta")
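
As a rough illustration, the delta formula can be written in NumPy as below. This is only a sketch (window size 2, edge frames repeated as padding); the `delta_delta` op in this repository may handle the details differently:

```python
import numpy as np

def compute_delta(feat, window=2):
    """Delta of a [T, D] feature matrix using the standard regression formula."""
    denom = 2.0 * sum(i * i for i in range(1, window + 1))
    # Repeat the first/last frame so that t - window and t + window stay valid.
    padded = np.pad(feat, ((window, window), (0, 0)), mode="edge")
    delta = np.zeros(feat.shape, dtype=np.float64)
    for t in range(len(feat)):
        for i in range(1, window + 1):
            delta[t] += i * (padded[t + window + i] - padded[t + window - i])
    return delta / denom
```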

## References
https://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html
--------------------------------------------------------------------------------
/docs/transform/user_manual.md:
--------------------------------------------------------------------------------
# User Manual

## Introduction

Transform is a data preprocessing toolkit.

## Usage

Transform supports speech feature extraction:

### 1. Import module

```python
from athena.transform import AudioFeaturizer
```

### 2. Initialize a feature extractor

#### Read wav

```python
conf = {'type':'ReadWav'}
feature_ext = AudioFeaturizer(conf)

'''
Other default args:
    'audio_channels':1

The shape of the output:
    [T, 1, 1]
'''
```

#### Spectrum

```python
conf = {'type':'Spectrum'}
feature_ext = AudioFeaturizer(conf)

'''
Other default args:
    'sample_rate' : 16000
    'window_length' : 0.025
    'frame_length' : 0.010
    'global_mean': global mean
    'global_variance': global variance
    'local_cmvn' : defaults to True, apply per-utterance CMVN

The shape of the output:
    [T, dim, 1]
'''
```

#### FilterBank

```python
conf = {'type':'Fbank'}
feature_ext = AudioFeaturizer(conf)

'''
Other default args:
    'sample_rate' : sample rate, 16000
    'window_length' : window length, 0.025 s
    'frame_length' : frame shift, 0.010 s
    'upper_frequency_limit' : 4000
    'lower_frequency_limit': 20
    'filterbank_channel_count' : 40
    'delta_delta' : whether to compute deltas, False
    'window' : delta window size, 2
    'order' : delta order, 2
    'global_mean': global mean
    'global_variance': global variance
    'local_cmvn' : defaults to True, apply per-utterance CMVN

Returns:
    A tensor of shape [T, dim, num_channel].
    dim = 40
    num_channel = 1 if delta_delta == False else 1 + order
'''
```

#### CMVN

```python
conf = {'type':'CMVN'}
cmvn = AudioFeaturizer(conf)

'''
Other configuration

    'global_mean': global mean
    'global_variance': global variance
    'local_cmvn' : default true

If 'global_mean' and 'global_variance' are set, global CMVN is applied; otherwise it is skipped.
Setting 'local_cmvn' to False disables per-utterance CMVN.
'''
```

### 3. Extract features

```python
feat = feature_ext(audio)
'''
audio : Audio file or audio data.
feat : A tensor containing speech feature.
'''
```

### 4. Get feature dim and the number of channels

```python
dim = feature_ext.dim
# dim : an int scalar, the feature dimension
num_channels = feature_ext.num_channels
# num_channels : an int scalar, the number of channels of the output feature
# the shape of the output features: [None, dim, num_channels]
```

### 5. Examples

#### 5.1 Extract filterbank features from an audio file:

```python
import tensorflow as tf
from athena.transform.audio_featurizer import AudioFeaturizer

audio_file = 'englist.wav'
conf = {'type':'Fbank',
        'sample_rate':8000,
        'delta_delta': True
       }
feature_ext = AudioFeaturizer(conf)
dim = feature_ext.dim
print('Dim size is ', dim)
num_channels = feature_ext.num_channels
print('num_channels is ', num_channels)

feat = feature_ext(audio_file)
with tf.compat.v1.Session() as sess:
    fbank = sess.run(feat)
    print('Fbank shape is ', fbank.shape)
```

Result:

```
Dim size is 40
Fbank shape is (346, 40, 3) # [T, D, C]
```

#### 5.2 CMVN usage:

```python
import numpy as np
import tensorflow as tf
from athena.transform.audio_featurizer import AudioFeaturizer

dim = 23
config = {'type': 'CMVN',
          'global_mean': np.zeros(dim).tolist(),
          'global_variance': np.ones(dim).tolist(),
          'local_cmvn': True}
cmvn = AudioFeaturizer(config)
audio_feature = tf.compat.v1.random_uniform(shape=[10, dim], dtype=tf.float32, maxval=1.0)
print('cmvn : ', cmvn(audio_feature))
```
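
To produce values for 'global_mean' and 'global_variance', the statistics have to be accumulated over the training features first (athena provides `cmvn_main.py` for this). As a purely illustrative NumPy sketch, with random matrices standing in for real features:

```python
import numpy as np

# Stand-ins for real [T, dim] feature matrices from the training set.
feature_list = [np.random.rand(100, 23), np.random.rand(80, 23)]

all_feats = np.concatenate(feature_list, axis=0)  # [sum of T, dim]
config = {'type': 'CMVN',
          'global_mean': all_feats.mean(axis=0).tolist(),
          'global_variance': all_feats.var(axis=0).tolist(),
          'local_cmvn': False}
```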
--------------------------------------------------------------------------------
/docs/using_docker.md:
--------------------------------------------------------------------------------
# Installation using Docker

## Install Docker

Make sure `docker` has been installed. You can refer to the [official tutorial](https://docs.docker.com/install/).

Install the NVIDIA Container Toolkit if using GPUs. You can refer to [nvidia-docker](https://github.com/NVIDIA/nvidia-docker).

## Docker image

You can build the image locally or use the pre-built Docker images available on DockerHub.

### Pull Docker Image

You can directly pull the pre-built docker images for athena. We have created the following docker image:

[athena-tf-2.1.0-gpu-py3](https://hub.docker.com/repository/docker/garygao99/athena/)

Download the image as below:
```
docker pull garygao99/athena:tf-2.1.0-gpu-py3
```

### Build images

You can build the image locally as below:
```
docker build -t garygao99/athena:tf-2.1.0-gpu-py3 .
```

## Create Container

After the image is downloaded, create a container.

```bash
docker run -it --gpus all garygao99/athena:tf-2.1.0-gpu-py3 /bin/bash
```

--------------------------------------------------------------------------------
/examples/asr/README.md:
--------------------------------------------------------------------------------

# Examples for HKUST

## 1 Transformer

```bash
source tools/env.sh
python examples/asr/hkust/local/prepare_data.py /tmp-data/dataset/opensource/hkust
python athena/main.py examples/asr/hkust/transformer.json
```

## 2 MTL_Transformer_CTC

```bash
source tools/env.sh
python examples/asr/hkust/local/prepare_data.py /tmp-data/dataset/opensource/hkust
python athena/main.py examples/asr/hkust/mtl_transformer.json
```

# Examples for Librispeech

TODO: needs testing.

--------------------------------------------------------------------------------
/examples/asr/aishell/mtl_transformer.json:
--------------------------------------------------------------------------------
{
  "batch_size":32,
  "num_epochs":15,
  "sorta_epoch":1,
  "ckpt":"examples/asr/aishell/ckpts/mtl_transformer_ctc/",
  "summary_dir":"examples/asr/aishell/ckpts/mtl_transformer_ctc/event",

  "solver_gpu":[2],
  "solver_config":{
    "clip_norm":100,
    "log_interval":10,
    "enable_tf_function":true
  },

  "model":"mtl_transformer_ctc",
  "num_classes": null,
  "pretrained_model": null,
  "model_config":{
    "model":"speech_transformer",
    "model_config":{
      "return_encoder_output":true,
      "num_filters":512,
      "d_model":512,
      "num_heads":8,
      "num_encoder_layers":12,
      "num_decoder_layers":6,
      "dff":1280,
      "rate":0.1,
      "label_smoothing_rate":0.0,
      "schedual_sampling_rate":0.9
    },
    "mtl_weight":0.5
  },

  "decode_config":{
    "beam_search":true,
    "beam_size":4,
    "ctc_weight":0.3,
    "lm_weight":0.1,
    "lm_path":"examples/asr/aishell/data/lm.bin"
  },

  "optimizer":"warmup_adam",
  "optimizer_config":{
    "d_model":512,
    "warmup_steps":8000,
    "k":0.5,
    "decay_steps": 50000,
    "decay_rate": 0.1
  },

  "dataset_builder": "speech_recognition_dataset",
  "dataset_config":{
    "audio_config":{
      "type":"Fbank",
      "filterbank_channel_count":40,
      "local_cmvn":false
    },
    "cmvn_file":"examples/asr/aishell/data/cmvn",
    "text_config": {
      "type":"vocab",
      "model":"examples/asr/aishell/data/vocab"
    },
    "input_length_range":[10, 8000]
  },
  "num_data_threads": 1,
  "train_csv":"examples/asr/aishell/data/train.csv",
  "dev_csv":"examples/asr/aishell/data/dev.csv",
  "test_csv":"examples/asr/aishell/data/test.csv"
}
--------------------------------------------------------------------------------
/examples/asr/aishell/mtl_transformer_sp.json:
--------------------------------------------------------------------------------
{
  "batch_size":32,
  "num_epochs":8,
  "sorta_epoch":2,
  "ckpt":"examples/asr/aishell/ckpts/mtl_transformer_ctc/",
  "summary_dir":"examples/asr/aishell/ckpts/mtl_transformer_ctc/event",

  "solver_gpu":[2],
  "solver_config":{
    "clip_norm":100,
    "log_interval":10,
    "enable_tf_function":true
  },

  "model":"mtl_transformer_ctc",
  "num_classes": null,
  "pretrained_model": null,
  "model_config":{
    "model":"speech_transformer",
    "model_config":{
      "return_encoder_output":true,
      "num_filters":512,
      "d_model":512,
      "num_heads":8,
      "num_encoder_layers":12,
      "num_decoder_layers":6,
      "dff":1280,
      "rate":0.1,
      "label_smoothing_rate":0.0,
      "schedual_sampling_rate":0.9
    },
    "mtl_weight":0.5
  },

  "decode_config":{
    "beam_search":true,
    "beam_size":4,
    "ctc_weight":0.3,
    "lm_weight":0.1,
    "lm_path":"examples/asr/aishell/data/lm.bin"
  },

  "optimizer":"warmup_adam",
  "optimizer_config":{
    "d_model":512,
    "warmup_steps":3500,
    "k":0.7,
    "decay_steps": 55050,
    "decay_rate": 0.1
  },

  "dataset_builder": "speech_recognition_dataset",
  "dataset_config":{
    "audio_config":{
      "type":"Fbank",
      "filterbank_channel_count":40,
      "local_cmvn":false
    },
    "cmvn_file":"examples/asr/aishell/data/cmvn",
    "text_config": {
      "type":"vocab",
      "model":"examples/asr/aishell/data/vocab"
    },
    "input_length_range":[10, 8000],
    "speed_permutation": [0.9, 1.0, 1.1]
  },
  "num_data_threads": 1,
  "train_csv":"examples/asr/aishell/data/train.csv",
  "dev_csv":"examples/asr/aishell/data/dev.csv",
  "test_csv":"examples/asr/aishell/data/test.csv"
}
--------------------------------------------------------------------------------
/examples/asr/aishell/rnnlm.json:
--------------------------------------------------------------------------------
{
  "batch_size":32,
  "num_epochs":16,
  "sorta_epoch":2,
  "ckpt":"examples/asr/aishell/ckpts/rnnlm",

  "solver_gpu":[1],
  "solver_config":{
    "clip_norm":100,
    "log_interval":10,
    "enable_tf_function":true
  },


  "model":"rnnlm",
  "num_classes": null,
  "pretrained_model": null,
  "model_config":{
    "d_model": 512,
    "rnn_type": "lstm",
    "num_layer": 4,
    "dropout_rate": 0.1,
    "sos": -1,
    "eos": -1
  },

  "optimizer":"warmup_adam",
  "optimizer_config":{
    "d_model":512,
    "warmup_steps":3500,
    "k":0.7,
    "decay_steps": 56296,
    "decay_rate": 0.1
  },

  "dataset_builder": "language_dataset",
  "dataset_config":{
    "input_text_config":{
      "type":"vocab",
      "model":"examples/asr/aishell/data/vocab"
    },
    "output_text_config":{
      "type":"vocab",
      "model":"examples/asr/aishell/data/vocab"
    }
  },
  "num_data_threads": 1,
  "train_csv":"examples/asr/aishell/data/train.trans.csv",
  "dev_csv":"examples/asr/aishell/data/dev.trans.csv",
  "test_csv":"examples/asr/aishell/data/test.trans.csv"
}

--------------------------------------------------------------------------------
/examples/asr/hkust/README.md:
--------------------------------------------------------------------------------

# Examples for HKUST

## 1 Transformer

```bash
source tools/env.sh
python examples/asr/hkust/local/prepare_data.py /tmp-data/dataset/opensource/hkust
python athena/main.py examples/asr/hkust/transformer.json
```

## 2 MTL_Transformer_CTC

```bash
source tools/env.sh
python examples/asr/hkust/local/prepare_data.py /tmp-data/dataset/opensource/hkust
python athena/main.py examples/asr/hkust/mtl_transformer.json
```

--------------------------------------------------------------------------------
/examples/asr/hkust/deep_speech.json:
-------------------------------------------------------------------------------- 1 | { 2 | "batch_size":32, 3 | "num_epochs":20, 4 | "sorta_epoch":1, 5 | "ckpt":"examples/asr/hkust/ckpts/deep_speech", 6 | "solver_gpu":[2], 7 | "solver_config":{ 8 | "clip_norm":100.0, 9 | "log_interval":1 10 | }, 11 | 12 | "model":"deep_speech", 13 | "num_classes": null, 14 | "pretrained_model": null, 15 | "model_config":{ 16 | "conv_filters":64, 17 | "rnn_hidden_size":1680, 18 | "rnn_type":"cudnngru", 19 | "num_rnn_layers":6 20 | }, 21 | 22 | "optimizer":"warmup_adam", 23 | "optimizer_config":{ 24 | "d_model":512, 25 | "warmup_steps":8000, 26 | "k":0.5 27 | }, 28 | 29 | "dataset_builder": "speech_recognition_dataset", 30 | "dataset_config":{ 31 | "audio_config":{ 32 | "type":"Fbank", 33 | "filterbank_channel_count":40, 34 | "local_cmvn":false 35 | }, 36 | "cmvn_file":"examples/asr/hkust/data/cmvn", 37 | "text_config": { 38 | "type":"vocab", 39 | "model":"examples/asr/hkust/data/vocab" 40 | }, 41 | "input_length_range":[10, 8000] 42 | }, 43 | "num_data_threads": 1, 44 | "train_csv":"examples/asr/hkust/data/train.csv", 45 | "dev_csv":"examples/asr/hkust/data/dev.csv", 46 | "test_csv":"examples/asr/hkust/data/dev.csv" 47 | } -------------------------------------------------------------------------------- /examples/asr/hkust/local/segment_word.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) ATHENA AUTHORS 3 | # All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # ============================================================================== 17 | """ segment word """ 18 | 19 | import codecs 20 | import sys 21 | import re 22 | from absl import logging 23 | import jieba 24 | 25 | 26 | def segment_trans(vocab_file, text_file): 27 | ''' segment transcripts according to vocab 28 | using Maximum Matching Algorithm 29 | Args: 30 | vocab_file: vocab file 31 | text_file: transcripts file 32 | Returns: 33 | seg_trans: segment words 34 | ''' 35 | jieba.set_dictionary(vocab_file) 36 | with open(text_file, "r", encoding="utf-8") as text: 37 | lines = text.readlines() 38 | sents = '' 39 | for line in lines: 40 | seg_line = jieba.cut(line.strip(), HMM=False) 41 | seg_line = ' '.join(seg_line) 42 | sents += seg_line + '\n' 43 | return sents 44 | 45 | 46 | if __name__ == "__main__": 47 | logging.set_verbosity(logging.INFO) 48 | sys.stdout = codecs.getwriter("utf-8")(sys.stdout.detach()) 49 | if len(sys.argv) < 3: 50 | logging.warning('Usage: python {} vocab_file text_file'.format(sys.argv[0])) 51 | sys.exit() 52 | print(segment_trans(sys.argv[1], sys.argv[2])) 53 | -------------------------------------------------------------------------------- /examples/asr/hkust/mpc.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size":64, 3 | "num_epochs":60, 4 | "sorta_epoch":2, 5 | "ckpt":"examples/asr/hkust/ckpts/mpc", 6 | "summary_dir":"examples/asr/hkust/ckpts/mpc/event", 7 | 8 | "solver_gpu":[0], 9 | "solver_config":{ 10 | "clip_norm":100, 11 | "log_interval":10, 12 | "enable_tf_function":true 13 | }, 14 | 15 | "model": "mpc", 16 | "num_classes": 40, 17 | "model_config":{ 18 | "return_encoder_output":false, 19 | "num_filters":512, 20 | "d_model":512, 21 | "num_heads":8, 22 | "num_encoder_layers":12, 23 | "dff":1280, 24 | "rate":0.1, 25 | "chunk_size":1, 26 | "keep_probability":0.8 27 | }, 28 | 29 | "optimizer": "warmup_adam", 30 | "optimizer_config":{ 31 | "d_model":512, 32 | "warmup_steps":5000, 33 | "k":0.5 34 | }, 35 | 36 | "dataset_builder":"speech_dataset", 37 | "dataset_config":{ 38 | "audio_config":{ 39 | "type":"Fbank", 40 | "filterbank_channel_count":40, 41 | "local_cmvn":false 42 | }, 43 | "cmvn_file":"examples/asr/hkust/data/cmvn", 44 | "input_length_range":[10, 8000] 45 | }, 46 | "num_data_threads": 1, 47 | "train_csv":"examples/asr/hkust/data/train.csv", 48 | "dev_csv":"examples/asr/hkust/data/dev.csv" 49 | } 50 | -------------------------------------------------------------------------------- /examples/asr/hkust/mtl_transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size":32, 3 | "num_epochs":8, 4 | "sorta_epoch":2, 5 | "ckpt":"examples/asr/hkust/ckpts/mtl_transformer_ctc/", 6 | "summary_dir":"examples/asr/hkust/ckpts/mtl_transformer_ctc/event", 7 | 8 | "solver_gpu":[0], 9 | "solver_config":{ 10 | "clip_norm":100, 11 | "log_interval":10, 12 | "enable_tf_function":true 13 | }, 14 | 15 | "model":"mtl_transformer_ctc", 16 | "num_classes": null, 17 | "pretrained_model": "examples/asr/hkust/mpc.json", 18 | "model_config":{ 19 | "model":"speech_transformer", 20 | "model_config":{ 21 | "return_encoder_output":true, 22 | "num_filters":512, 23 | "d_model":512, 24 | "num_heads":8, 25 | "num_encoder_layers":12, 26 | "num_decoder_layers":6, 27 | "dff":1280, 28 | "rate":0.1, 29 | "label_smoothing_rate":0.0, 30 | "schedual_sampling_rate":0.9 31 | }, 32 | "mtl_weight":0.5 33 | }, 34 | 35 | "decode_config":{ 36 | "beam_search":true, 37 | 
"beam_size":4, 38 | "ctc_weight":0.3, 39 | "lm_weight":0.1, 40 | "lm_path":"examples/asr/hkust/data/4gram.arpa" 41 | }, 42 | 43 | "optimizer":"warmup_adam", 44 | "optimizer_config":{ 45 | "d_model":512, 46 | "warmup_steps":8000, 47 | "k":0.5, 48 | "decay_steps":22000, 49 | "decay_rate":0.1 50 | }, 51 | 52 | "dataset_builder": "speech_recognition_dataset", 53 | "dataset_config":{ 54 | "audio_config":{ 55 | "type":"Fbank", 56 | "filterbank_channel_count":40, 57 | "local_cmvn":false 58 | }, 59 | "cmvn_file":"examples/asr/hkust/data/cmvn", 60 | "text_config": { 61 | "type":"vocab", 62 | "model":"examples/asr/hkust/data/vocab" 63 | }, 64 | "speed_permutation":[0.9, 1.0, 1.1], 65 | "input_length_range":[10, 8000] 66 | }, 67 | "num_data_threads": 1, 68 | "train_csv":"examples/asr/hkust/data/train.csv", 69 | "dev_csv":"examples/asr/hkust/data/dev.csv", 70 | "test_csv":"examples/asr/hkust/data/dev.csv" 71 | } 72 | -------------------------------------------------------------------------------- /examples/asr/hkust/mtl_transformer_sp.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size":32, 3 | "num_epochs":15, 4 | "sorta_epoch":2, 5 | "ckpt":"examples/asr/hkust/ckpts/mtl_transformer_ctc_sp/", 6 | "summary_dir":"examples/asr/hkust/ckpts/mtl_transformer_ctc_sp/event", 7 | 8 | "solver_gpu":[0], 9 | "solver_config":{ 10 | "clip_norm":100, 11 | "log_interval":10, 12 | "enable_tf_function":true 13 | }, 14 | 15 | "model":"mtl_transformer_ctc", 16 | "num_classes": null, 17 | "pretrained_model": null, 18 | "model_config":{ 19 | "model":"speech_transformer", 20 | "model_config":{ 21 | "return_encoder_output":true, 22 | "num_filters":512, 23 | "d_model":512, 24 | "num_heads":8, 25 | "num_encoder_layers":12, 26 | "num_decoder_layers":6, 27 | "dff":1280, 28 | "rate":0.1, 29 | "label_smoothing_rate":0.0, 30 | "schedual_sampling_rate":0.9 31 | }, 32 | "mtl_weight":0.5 33 | }, 34 | 35 | "decode_config":{ 36 | "beam_search":true, 37 | "beam_size":4, 38 | "ctc_weight":0.3, 39 | "lm_weight":0.1, 40 | "lm_path":"examples/asr/hkust/data/4gram.arpa" 41 | }, 42 | 43 | "optimizer":"warmup_adam", 44 | "optimizer_config":{ 45 | "d_model":512, 46 | "warmup_steps":8000, 47 | "k":0.5, 48 | "decay_steps": 140000, 49 | "decay_rate": 0.1 50 | }, 51 | 52 | "dataset_builder": "speech_recognition_dataset", 53 | "dataset_config":{ 54 | "audio_config":{ 55 | "type":"Fbank", 56 | "filterbank_channel_count":40, 57 | "local_cmvn":false 58 | }, 59 | "cmvn_file":"examples/asr/hkust/data/cmvn", 60 | "text_config": { 61 | "type":"vocab", 62 | "model":"examples/asr/hkust/data/vocab" 63 | }, 64 | "input_length_range":[10, 8000], 65 | "speed_permutation": [0.9, 1.0, 1.1] 66 | }, 67 | "num_data_threads": 1, 68 | "train_csv":"examples/asr/hkust/data/train.csv", 69 | "dev_csv":"examples/asr/hkust/data/dev.csv", 70 | "test_csv":"examples/asr/hkust/data/dev.csv" 71 | } 72 | -------------------------------------------------------------------------------- /examples/asr/hkust/rnnlm.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size":32, 3 | "num_epochs":100, 4 | "sorta_epoch":0, 5 | "ckpt":"examples/asr/hkust/ckpts/rnnlm", 6 | 7 | "solver_gpu":[0], 8 | "solver_config":{ 9 | "clip_norm":100, 10 | "log_interval":10, 11 | "enable_tf_function":true 12 | }, 13 | 14 | 15 | "model":"rnnlm", 16 | "num_classes": null, 17 | "pretrained_model": null, 18 | "model_config":{ 19 | "d_model": 512, 20 | "rnn_type": "lstm", 21 | "num_layer": 4, 22 | 
"dropout_rate": 0.1, 23 | "sos": -1, 24 | "eos": -1 25 | }, 26 | 27 | "optimizer":"warmup_adam", 28 | "optimizer_config":{ 29 | "d_model":512, 30 | "warmup_steps":8000, 31 | "k":0.5 32 | }, 33 | 34 | "dataset_builder": "language_dataset", 35 | "dataset_config":{ 36 | "input_text_config":{ 37 | "type":"vocab", 38 | "model":"examples/asr/hkust/data/vocab" 39 | }, 40 | "output_text_config":{ 41 | "type":"vocab", 42 | "model":"examples/asr/hkust/data/vocab" 43 | } 44 | }, 45 | "num_data_threads": 1, 46 | "train_csv":"examples/asr/hkust/data/train.trans.csv", 47 | "dev_csv":"examples/asr/hkust/data/dev.trans.csv", 48 | "test_csv":"examples/asr/hkust/data/dev.trans.csv" 49 | } -------------------------------------------------------------------------------- /examples/asr/hkust/run.sh: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) ATHENA AUTHORS 3 | # All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | 18 | if [ "athena" != $(basename "$PWD") ]; then 19 | echo "You should run this script in athena directory!!" 20 | exit 1 21 | fi 22 | 23 | source tools/env.sh 24 | 25 | stage=0 26 | stop_stage=100 27 | 28 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 29 | # prepare data 30 | echo "Preparing data" 31 | python examples/asr/hkust/local/prepare_data.py /nfs/project/datasets/opensource_data/hkust 32 | mkdir -p examples/asr/hkust/data 33 | cp /nfs/project/datasets/opensource_data/hkust/{train,dev}.csv examples/asr/hkust/data/ 34 | 35 | # cal cmvn 36 | cat examples/asr/hkust/data/train.csv > examples/asr/hkust/data/all.csv 37 | tail -n +2 examples/asr/hkust/data/dev.csv >> examples/asr/hkust/data/all.csv 38 | python athena/cmvn_main.py examples/asr/hkust/mpc.json examples/asr/hkust/data/all.csv 39 | fi 40 | 41 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 42 | # pretrain stage 43 | echo "Pretraining" 44 | # we recommend training with multi-gpu, for single gpu, run "python athena/main.py examples/asr/hkust/mpc.json" instead 45 | horovodrun -np 4 -H localhost:4 python athena/horovod_main.py examples/asr/hkust/mpc.json 46 | fi 47 | 48 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 49 | # finetuning stage 50 | echo "Fine-tuning" 51 | # we recommend training with multi-gpu, for single gpu, run "python athena/main.py examples/asr/hkust/mtl_transformer.json" instead 52 | horovodrun -np 4 -H localhost:4 python athena/horovod_main.py examples/asr/hkust/mtl_transformer.json 53 | fi 54 | 55 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then 56 | # decoding stage 57 | echo "Decoding" 58 | # prepare language model 59 | tail -n +2 examples/asr/hkust/data/train.csv | cut -f 3 > examples/asr/hkust/data/text 60 | python examples/asr/hkust/local/segment_word.py examples/asr/hkust/data/vocab \ 61 | examples/asr/hkust/data/text > examples/asr/hkust/data/text.seg 62 | 
tools/kenlm/build/bin/lmplz -o 4 < examples/asr/hkust/data/text.seg > examples/asr/hkust/data/4gram.arpa 63 | 64 | python athena/decode_main.py examples/asr/hkust/mtl_transformer.json 65 | fi 66 | 67 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then 68 | echo "training rnnlm" 69 | tail -n +2 examples/asr/hkust/data/train.csv | awk '{print $3"\t"$3}' > examples/asr/hkust/data/train.trans.csv 70 | tail -n +2 examples/asr/hkust/data/dev.csv | awk '{print $3"\t"$3}' > examples/asr/hkust/data/dev.trans.csv 71 | python athena/main.py examples/asr/hkust/rnnlm.json 72 | fi 73 | 74 | -------------------------------------------------------------------------------- /examples/asr/hkust/transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size":32, 3 | "num_epochs":20, 4 | "sorta_epoch":1, 5 | "ckpt":"examples/asr/hkust/ckpts/transformer", 6 | 7 | "solver_gpu":[0], 8 | "solver_config":{ 9 | "clip_norm":100, 10 | "log_interval":10, 11 | "enable_tf_function":true 12 | }, 13 | 14 | 15 | "model":"speech_transformer", 16 | "num_classes": null, 17 | "pretrained_model": null, 18 | "model_config":{ 19 | "return_encoder_output":false, 20 | "num_filters":512, 21 | "d_model":512, 22 | "num_heads":8, 23 | "num_encoder_layers":12, 24 | "num_decoder_layers":6, 25 | "dff":1280, 26 | "rate":0.1, 27 | "label_smoothing_rate":0.0 28 | }, 29 | 30 | "optimizer":"warmup_adam", 31 | "optimizer_config":{ 32 | "d_model":512, 33 | "warmup_steps":8000, 34 | "k":0.5 35 | }, 36 | 37 | "dataset_builder": "speech_recognition_dataset", 38 | "dataset_config":{ 39 | "audio_config":{ 40 | "type":"Fbank", 41 | "filterbank_channel_count":40, 42 | "local_cmvn":false 43 | }, 44 | "text_config": { 45 | "type":"vocab", 46 | "model":"examples/asr/hkust/data/vocab" 47 | }, 48 | "cmvn_file":"examples/asr/hkust/data/cmvn", 49 | "input_length_range":[10, 5000] 50 | }, 51 | "num_data_threads": 1, 52 | "train_csv":"examples/asr/hkust/data/train.csv", 53 | "dev_csv":"examples/asr/hkust/data/dev.csv", 54 | "test_csv":"examples/asr/hkust/data/dev.csv" 55 | } -------------------------------------------------------------------------------- /examples/asr/librispeech/data/librispeech_unigram5000.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/examples/asr/librispeech/data/librispeech_unigram5000.model -------------------------------------------------------------------------------- /examples/asr/librispeech/transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size":16, 3 | "num_epochs":50, 4 | "sorta_epoch":2, 5 | "ckpt":"examples/asr/librispeech/ckpts/debug", 6 | "solver_gpu":[0], 7 | "solver_config":{ 8 | "clip_norm":100.0, 9 | "log_interval":1 10 | }, 11 | 12 | "model":"speech_transformer", 13 | "model_config":{ 14 | "return_encoder_output":false, 15 | "num_filters":512, 16 | "d_model":512, 17 | "num_heads":8, 18 | "num_encoder_layers":12, 19 | "num_decoder_layers":6, 20 | "dff":1280, 21 | "rate":0.1, 22 | "label_smoothing_rate":0.0, 23 | "schedual_sampling_rate":0.9 24 | }, 25 | 26 | "optimizer":"warmup_adam", 27 | "optimizer_config":{ 28 | "d_model":512, 29 | "warmup_steps":8000, 30 | "k":0.5 31 | }, 32 | 33 | "trainset_config":{ 34 | "data_csv":"/tmp-data/dataset/opensource/librispeech/wav/train-clean-100.csv", 35 | "audio_config":{ 36 | "type":"Fbank", 37 | "filterbank_channel_count":40, 38 | 
"local_cmvn":false 39 | }, 40 | "vocab_file":"examples/asr/hkust/data/librispeech_unigram5000.model", 41 | "subword":true, 42 | "audio_min_length":10, 43 | "audio_max_length":10000, 44 | "force_process":false 45 | }, 46 | "cmvn_file":"examples/asr/librispeech/data/cmvn", 47 | "dev_csv":"/tmp-data/dataset/opensource/librispeech/wav/dev-clean.csv", 48 | "test_csv":"/tmp-data/dataset/opensource/librispeech/wav/test-clean.csv" 49 | } 50 | -------------------------------------------------------------------------------- /examples/asr/magic_data/local/prepare_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """ magic_data dataset """ 17 | 18 | import os 19 | import sys 20 | import codecs 21 | import pandas 22 | from absl import logging 23 | 24 | import tensorflow as tf 25 | from athena import get_wave_file_length 26 | 27 | SUBSETS = ["train", "dev"] 28 | 29 | def convert_audio_and_split_transcript(directory, subset, out_csv_file): 30 | 31 | gfile = tf.compat.v1.gfile 32 | logging.info("Processing audio and transcript for {}".format(subset)) 33 | audio_dir = os.path.join(directory, subset) 34 | trans_dir = os.path.join(directory, subset, "TRANS.txt") 35 | 36 | files = [] 37 | with codecs.open(trans_dir,'r', encoding='utf-8') as f: 38 | lines = f.readlines() 39 | for line in lines[1:]: 40 | items = line.strip().split('\t') 41 | wav_filename = items[0] 42 | labels = items[2] 43 | speaker = items[1] 44 | files.append((wav_filename, speaker, labels)) 45 | files_size_dict = {} 46 | for root, subdirs, _ in gfile.Walk(audio_dir): 47 | for subdir in subdirs: 48 | for filename in os.listdir(os.path.join(root, subdir)): 49 | files_size_dict[filename] = ( 50 | get_wave_file_length(os.path.join(root, subdir, filename)), 51 | subdir, 52 | ) 53 | content = [] 54 | for wav_filename, speaker, trans in files: 55 | if wav_filename in files_size_dict: 56 | filesize, subdir = files_size_dict[wav_filename] 57 | abspath = os.path.join(audio_dir, subdir, wav_filename) 58 | content.append((abspath, filesize, trans, speaker)) 59 | 60 | files = content 61 | df = pandas.DataFrame( 62 | data=files, columns=["wav_filename", "wav_length_ms", "transcript", "speaker"] 63 | ) 64 | df.to_csv(out_csv_file, index=False, sep="\t") 65 | logging.info("Successfully generated csv file {}".format(out_csv_file)) 66 | 67 | def processor(dircetory, subset, force_process): 68 | """ download and process """ 69 | if subset not in SUBSETS: 70 | raise ValueError(subset, "is not in magic_data") 71 | if force_process: 72 | logging.info("force process is set to be true") 73 | 74 | subset_csv = os.path.join(dircetory, subset + ".csv") 75 | if not force_process and os.path.exists(subset_csv): 76 | logging.info("{} already 
exist".format(subset_csv)) 77 | return subset_csv 78 | logging.info("Processing the magic_data subset {} in {}".format(subset, dircetory)) 79 | convert_audio_and_split_transcript(dircetory, subset, subset_csv) 80 | logging.info("Finished processing magic_data subset {}".format(subset)) 81 | return subset_csv 82 | 83 | if __name__ == "__main__": 84 | logging.set_verbosity(logging.INFO) 85 | DIR = sys.argv[1] 86 | for SUBSET in SUBSETS: 87 | processor(DIR, SUBSET, True) 88 | 89 | -------------------------------------------------------------------------------- /examples/asr/primewords/local/prepare_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """ primewords dataset """ 17 | 18 | import os 19 | import sys 20 | import codecs 21 | import pandas 22 | from absl import logging 23 | 24 | import tensorflow as tf 25 | # from athena import get_wave_file_length 26 | 27 | def convert_audio_and_split_transcript(directory, out_csv_file): 28 | 29 | gfile = tf.compat.v1.gfile 30 | logging.info("Processing audio and transcript for {}".format('primewords')) 31 | audio_dir = os.path.join(directory, "audio_files") 32 | trans_dir = os.path.join(directory, "set1_transcript.json") 33 | 34 | files = [] 35 | with codecs.open(trans_dir, 'r', encoding='utf-8') as f: 36 | items = eval(f.readline()) 37 | for item in items: 38 | wav_filename = item['file'] 39 | # labels = item['text'] 40 | 41 | labels = ''.join([x for x in item['text'].split()]) 42 | # speaker = item['user_id'] 43 | files.append((wav_filename, labels)) 44 | 45 | files_size_dict = {} 46 | for subdir in os.listdir(audio_dir): 47 | for root, subsubdirs, _ in gfile.Walk(os.path.join(audio_dir, subdir)): 48 | for subsubdir in subsubdirs: 49 | for filename in os.listdir(os.path.join(root, subsubdir)): 50 | files_size_dict[filename] = ( 51 | os.path.join(root, subsubdir, filename), 52 | root, 53 | subsubdir 54 | ) 55 | # print(os.path.join(root,subsubdir,filename)) 56 | content = [] 57 | for wav_filename, trans in files: 58 | if wav_filename in files_size_dict: 59 | filesize, root, subsubdir = files_size_dict[wav_filename] 60 | abspath = os.path.join(root, subsubdir, wav_filename) 61 | content.append((abspath, filesize, trans, None)) 62 | files = content 63 | df = pandas.DataFrame( 64 | data=files, columns=["wav_filename", "wav_length_ms", "transcript", "speaker"] 65 | ) 66 | 67 | df.to_csv(out_csv_file, index=False, sep="\t") 68 | logging.info("Successfully generated csv file {}".format(out_csv_file)) 69 | 70 | def processor(dircetory, force_process): 71 | if force_process: 72 | logging.info("force process is set to be true") 73 | 74 | subset_csv = os.path.join(dircetory, 'train' + ".csv") 75 | if not force_process and 
        logging.info("{} already exists".format(subset_csv))
        return subset_csv
    logging.info("Processing the primewords subset {} in {}".format('train', directory))
    convert_audio_and_split_transcript(directory, subset_csv)
    logging.info("Finished processing primewords subset {}".format('train'))
    return subset_csv

if __name__ == '__main__':
    logging.set_verbosity(logging.INFO)
    DIR = sys.argv[1]
    processor(DIR, True)

--------------------------------------------------------------------------------
/examples/translate/spa-eng-example/prepare_data.py:
--------------------------------------------------------------------------------
# Only support eager mode
# pylint: disable=invalid-name

""" An example using translation dataset """
import os
import re
import io
import sys
import unicodedata
import pandas
import tensorflow as tf
from absl import logging
from athena import LanguageDatasetBuilder


def preprocess_sentence(w):
    """ preprocess_sentence """
    w = ''.join(c for c in unicodedata.normalize('NFD', w.lower().strip())
                if unicodedata.category(c) != 'Mn')
    w = re.sub(r"([?.!,¿])", r" \1 ", w)
    w = re.sub(r'[" "]+', " ", w)
    w = re.sub(r"[^a-zA-Z?.!,¿]+", " ", w)
    w = w.rstrip().strip()
    return w

def create_dataset(csv_file="examples/translate/spa-eng-examples/data/train.csv"):
    """ create and store in csv file """
    path_to_zip = tf.keras.utils.get_file('spa-eng.zip',
        origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',
        extract=True)
    path_to_file = os.path.dirname(path_to_zip)+"/spa-eng/spa.txt"
    lines = io.open(path_to_file, encoding='UTF-8').read().strip().split('\n')
    word_pairs = [[preprocess_sentence(w) for w in l.split('\t')] for l in lines[:None]]
    df = pandas.DataFrame(
        data=word_pairs, columns=["input_transcripts", "output_transcripts"]
    )
    csv_dir = os.path.dirname(csv_file)
    if not os.path.exists(csv_dir):
        os.mkdir(csv_dir)
    df.to_csv(csv_file, index=False, sep="\t")
    logging.info("Successfully generated csv file {}".format(csv_file))

if __name__ == "__main__":
    logging.set_verbosity(logging.INFO)
    CSV_FILE = sys.argv[1]
    create_dataset(CSV_FILE)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tensorflow==2.0.3
sox
absl-py
yapf
pylint
flake8
horovod
tqdm
sentencepiece
librosa
kenlm
jieba

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# Copyright (C) ATHENA AUTHORS
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import sys
from glob import glob
import shutil
import setuptools
import tensorflow as tf

TF_INCLUDE, TF_CFLAG = tf.sysconfig.get_compile_flags()
TF_INCLUDE = TF_INCLUDE.split("-I")[1]

TF_LIB_INC, TF_SO_LIB = tf.sysconfig.get_link_flags()
TF_SO_LIB = TF_SO_LIB.replace(
    "-l:libtensorflow_framework.1.dylib", "-ltensorflow_framework.1"
)
TF_LIB_INC = TF_LIB_INC.split("-L")[1]
TF_SO_LIB = TF_SO_LIB.split("-l")[1]

compile_args = [TF_CFLAG, "-fPIC", "-shared", "-O2", "-std=c++11"]
if sys.platform == "darwin":  # Mac OS X before Mavericks (10.9)
    compile_args.append("-stdlib=libc++")

module = setuptools.Extension(
    "athena.transform.feats.ops.x_ops",
    sources=glob("athena/transform/feats/ops/kernels/*.cc"),
    depends=glob("athena/transform/feats/ops/kernels/*.h"),
    extra_compile_args=compile_args,
    include_dirs=["athena/transform/feats/ops", TF_INCLUDE],
    library_dirs=[TF_LIB_INC],
    libraries=[TF_SO_LIB],
    language="c++",
)


with open("README.md", "r") as fh:
    long_description = fh.read()
setuptools.setup(
    name="athena",
    version="0.1.0",
    author="ATHENA AUTHORS",
    author_email="athena@gmail.com",
    description="for speech recognition",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/didichuxing/athena",
    packages=setuptools.find_packages(),
    package_data={"": ["x_ops*.so", "*.vocab", "sph2pipe"]},
    exclude_package_data={"feats": ["*_test*"]},
    ext_modules=[module],
    python_requires=">=3",
)
path = glob("build/lib.*/athena/transform/feats/ops/x_ops.*.so")
shutil.copy(path[0], "athena/transform/feats/ops/x_ops.so")
--------------------------------------------------------------------------------
/tools/env.sh:
--------------------------------------------------------------------------------
# check if we are executing or sourcing.
if [[ "$0" == "$BASH_SOURCE" ]]
then
    echo "You must source this script rather than executing it."
    exit -1
fi

# don't use readlink because macOS won't support it.
if [[ "$BASH_SOURCE" == "/"* ]]
then
    export MAIN_ROOT=$(dirname $BASH_SOURCE)
else
    export MAIN_ROOT=$PWD/$(dirname $BASH_SOURCE)
fi

# pip bins
export PATH=$PATH:~/.local/bin

# athena
export PYTHONPATH=${PYTHONPATH}:$MAIN_ROOT:
--------------------------------------------------------------------------------
/tools/install.sh:
--------------------------------------------------------------------------------

rm -rf build/ athena.egg-info/ dist/
python setup.py bdist_wheel sdist
python -m pip install --ignore-installed dist/athena-0.1.0*.whl
--------------------------------------------------------------------------------
/tools/install_kenlm.sh:
--------------------------------------------------------------------------------
# reference to https://github.com/kpu/kenlm
cd tools
wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz

mkdir -p kenlm/build
cd kenlm/build
cmake ..
make -j 4
--------------------------------------------------------------------------------
/tools/install_sph2pipe.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Tool to convert sph audio format to wav format.
cd tools
if [ ! -e sph2pipe_v2.5.tar.gz ]; then
    wget -T 10 -t 3 http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz || exit 1
fi

tar -zxvf sph2pipe_v2.5.tar.gz || exit 1
cd sph2pipe_v2.5/
gcc -o sph2pipe *.c -lm || exit 1
--------------------------------------------------------------------------------