├── .clang-format
├── .flake8
├── .gitignore
├── .pylintrc
├── .style.yapf
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE
├── README.md
├── athena
│   ├── __init__.py
│   ├── cmvn_main.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── language_set.py
│   │   │   ├── speech_recognition.py
│   │   │   ├── speech_recognition_test.py
│   │   │   └── speech_set.py
│   │   ├── feature_normalizer.py
│   │   └── text_featurizer.py
│   ├── decode_main.py
│   ├── horovod_main.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── commons.py
│   │   ├── functional.py
│   │   └── transformer.py
│   ├── loss.py
│   ├── main.py
│   ├── metrics.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── customized.py
│   │   ├── deep_speech.py
│   │   ├── masked_pc.py
│   │   ├── mtl_seq2seq.py
│   │   ├── rnn_lm.py
│   │   └── speech_transformer.py
│   ├── solver.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── beam_search.py
│   │   ├── ctc_scorer.py
│   │   ├── lm_scorer.py
│   │   └── sph2pipe
│   ├── transform
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── audio_featurizer.py
│   │   └── feats
│   │       ├── __init__.py
│   │       ├── base_frontend.py
│   │       ├── cmvn.py
│   │       ├── cmvn_test.py
│   │       ├── fbank.py
│   │       ├── fbank_pitch.py
│   │       ├── fbank_pitch_test.py
│   │       ├── fbank_test.py
│   │       ├── framepow.py
│   │       ├── framepow_test.py
│   │       ├── mfcc.py
│   │       ├── mfcc_test.py
│   │       ├── ops
│   │       │   ├── Makefile
│   │       │   ├── __init__.py
│   │       │   ├── kernels
│   │       │   │   ├── complex_defines.h
│   │       │   │   ├── delta_delta.cc
│   │       │   │   ├── delta_delta.h
│   │       │   │   ├── delta_delta_op.cc
│   │       │   │   ├── delta_delta_op_test.py
│   │       │   │   ├── fbank.cc
│   │       │   │   ├── fbank.h
│   │       │   │   ├── fbank_op.cc
│   │       │   │   ├── fbank_op_test.py
│   │       │   │   ├── framepow.cc
│   │       │   │   ├── framepow.h
│   │       │   │   ├── framepow_op.cc
│   │       │   │   ├── mfcc_dct.cc
│   │       │   │   ├── mfcc_dct.h
│   │       │   │   ├── mfcc_dct_op.cc
│   │       │   │   ├── mfcc_mel_filterbank.cc
│   │       │   │   ├── mfcc_mel_filterbank.h
│   │       │   │   ├── pitch.cc
│   │       │   │   ├── pitch.h
│   │       │   │   ├── pitch_op.cc
│   │       │   │   ├── resample.cc
│   │       │   │   ├── resample.h
│   │       │   │   ├── spectrum.cc
│   │       │   │   ├── spectrum.h
│   │       │   │   ├── spectrum_op.cc
│   │       │   │   ├── spectrum_op_test.py
│   │       │   │   ├── speed_op.cc
│   │       │   │   ├── support_functions.cc
│   │       │   │   ├── support_functions.h
│   │       │   │   └── x_ops.cc
│   │       │   └── py_x_ops.py
│   │       ├── pitch.py
│   │       ├── pitch_test.py
│   │       ├── read_wav.py
│   │       ├── read_wav_test.py
│   │       ├── spectrum.py
│   │       ├── spectrum_test.py
│   │       ├── write_wav.py
│   │       └── write_wav_test.py
│   └── utils
│       ├── __init__.py
│       ├── checkpoint.py
│       ├── data_queue.py
│       ├── hparam.py
│       ├── hparam_test.py
│       ├── learning_rate.py
│       ├── metric_check.py
│       ├── misc.py
│       └── vocabs
│           ├── ch-en.vocab
│           ├── ch.vocab
│           └── en.vocab
├── docs
│   ├── README.md
│   ├── TheTrainningEfficiency.md
│   ├── development
│   │   └── contributing.md
│   ├── transform
│   │   ├── img
│   │   │   ├── DCT.png
│   │   │   ├── DFT.png
│   │   │   ├── MFCC.png
│   │   │   ├── MelBank.png
│   │   │   ├── Mel_filter.png
│   │   │   ├── amplitude_spectrum.png
│   │   │   ├── audio_data.png
│   │   │   ├── delta.png
│   │   │   ├── fbank.png
│   │   │   ├── hamming.png
│   │   │   ├── logMel.png
│   │   │   ├── logpower_spectrum.png
│   │   │   ├── mel_freq.png
│   │   │   ├── melbanks.png
│   │   │   ├── phase_spectrum.png
│   │   │   ├── power_spectrum.png
│   │   │   ├── spectrum.png
│   │   │   ├── spectrum_emph.png
│   │   │   └── spectrum_orig.png
│   │   ├── speech_feature.md
│   │   └── user_manual.md
│   └── using_docker.md
├── examples
│   ├── asr
│   │   ├── README.md
│   │   ├── aidatatang_200zh
│   │   │   └── local
│   │   │       └── prepare_data.py
│   │   ├── aishell
│   │   │   ├── data
│   │   │   │   └── vocab
│   │   │   ├── local
│   │   │   │   └── prepare_data.py
│   │   │   ├── mtl_transformer.json
│   │   │   ├── mtl_transformer_sp.json
│   │   │   └── rnnlm.json
│   │   ├── hkust
│   │   │   ├── README.md
│   │   │   ├── data
│   │   │   │   └── vocab
│   │   │   ├── deep_speech.json
│   │   │   ├── local
│   │   │   │   ├── prepare_data.py
│   │   │   │   └── segment_word.py
│   │   │   ├── mpc.json
│   │   │   ├── mtl_transformer.json
│   │   │   ├── mtl_transformer_sp.json
│   │   │   ├── rnnlm.json
│   │   │   ├── run.sh
│   │   │   └── transformer.json
│   │   ├── librispeech
│   │   │   ├── data
│   │   │   │   └── librispeech_unigram5000.model
│   │   │   ├── prepare_data.py
│   │   │   └── transformer.json
│   │   ├── magic_data
│   │   │   └── local
│   │   │       └── prepare_data.py
│   │   ├── primewords
│   │   │   └── local
│   │   │       └── prepare_data.py
│   │   └── switchboard_fisher
│   │       └── prepare_data.py
│   └── translate
│       └── spa-eng-example
│           └── prepare_data.py
├── requirements.txt
├── setup.py
└── tools
    ├── env.sh
    ├── install.sh
    ├── install_kenlm.sh
    └── install_sph2pipe.sh
/.clang-format:
--------------------------------------------------------------------------------
1 | BasedOnStyle: Google
2 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | # .flake8
2 | #
3 | # DESCRIPTION
4 | # Configuration file for the python linter flake8.
5 | #
6 | # This configuration is based on the generic
7 | # configuration published on GitHub.
8 | #
9 | # AUTHOR krnd
10 | # VERSION v1.0
11 | #
12 | # SEE ALSO
13 | # http://flake8.pycqa.org/en/latest/user/options.html
14 | # http://flake8.pycqa.org/en/latest/user/error-codes.html
15 | # http://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes
16 | # http://gist.github.com/krnd
17 | #
18 |
19 | [flake8]
20 |
21 | # Specify the number of subprocesses that Flake8 will use to run checks in parallel.
22 | jobs = auto
23 |
24 | # Increase the verbosity of Flake8’s output.
25 | verbose = 0
26 |
27 | # Decrease the verbosity of Flake8’s output.
28 | quiet = 0
29 |
30 | # Select the formatter used to display errors to the user.
31 | format = default
32 |
33 | # Print the total number of errors.
34 | count = True
35 |
36 | # Print the source code generating the error/warning in question.
37 | show-source = True
38 |
39 | # Count the number of occurrences of each error/warning code and print a report.
40 | statistics = True
41 |
42 |
43 | # Redirect all output to the specified file.
44 | output-file = /tmp/flake8.log
45 |
46 | # Also print output to stdout if output-file has been configured.
47 | tee = True
48 |
49 | # Provide a comma-separated list of glob patterns to exclude from checks.
50 | exclude =
51 | # git folder
52 | .git,
53 | # python cache
54 | __pycache__,
55 | tools
56 |
57 | # Provide a comma-separated list of glob patterns to include in checks.
58 | filename = *.py
59 |
60 | # Provide a custom list of builtin functions, objects, names, etc.
61 | builtins =
62 |
63 | # Report all errors, even if it is on the same line as a `# NOQA` comment.
64 | disable-noqa = False
65 |
66 | # Set the maximum length that any line (with some exceptions) may be.
67 | max-line-length = 100
68 |
69 | # Set the maximum allowed McCabe complexity value for a block of code.
70 | max-complexity = 10
71 |
72 | # Toggle whether pycodestyle should enforce matching the indentation of the opening bracket’s line.
73 | # influences E131 and E133
74 | hang-closing = True
75 |
76 |
77 | # ERROR CODES
78 | #
79 | # E/W - PEP8 errors/warnings (pycodestyle)
80 | # F - linting errors (pyflakes)
81 | # C - McCabe complexity error (mccabe)
82 | #
83 | # W503 - line break before binary operator
84 |
85 | # Specify a list of codes to ignore.
86 | ignore = W503
87 |
88 | # Specify the list of error codes you wish Flake8 to report.
89 | select = E9,F6,F81,F82,F83,F9
90 |
91 | # Enable off-by-default extensions.
92 | enable-extensions =
93 |
94 | # Enable PyFlakes syntax checking of doctests in docstrings.
95 | doctests = True
96 |
97 | # Specify which files are checked by PyFlakes for doctest syntax.
98 | include-in-doctest =
99 |
100 | # Specify which files are not to be checked by PyFlakes for doctest syntax.
101 | exclude-in-doctest =
102 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__
3 | *.pyc
4 | tags
5 | .vscode
6 |
7 | /tools/sph2pipe_v2.5.tar.gz
8 | /tools/sph2pipe_v2.5
9 |
--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
1 | [style]
2 | based_on_style = chromium
3 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guideline
2 |
3 | Thanks for considering contributing to this project. All issues and pull requests are highly appreciated.
4 |
5 | ## Pull Requests
6 |
7 | Before sending a pull request to this project, please read and follow the guidelines below.
8 |
9 | 1. Branch: We accept pull requests on the `master` branch.
10 | 2. Coding style: Follow the coding style used in Athena.
11 | 3. Commit message: Use English and check your spelling.
12 | 4. Test: Make sure to test your code.
13 |
14 | Add the device model, API version, related logs, screenshots, and other relevant information to your pull request if possible.
15 |
16 | NOTE: We assume all your contributions can be licensed under the [Apache License 2.0](https://github.com/didichuxing/athena/tree/master/LICENSE).
17 |
18 | ## Issues
19 |
20 | We love clearly described issues. :)
21 |
22 | The following information can help us resolve the issue faster.
23 |
24 | * Device model and hardware information.
25 | * API version.
26 | * Logs.
27 | * Screenshots.
28 | * Steps to reproduce the issue.
29 |
30 | ## Coding Styles
31 |
32 | Please follow the coding styles described [here](https://git.xiaojukeji.com/speech-am/athena/tree/master/docs/development/contributing.md).
33 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM tensorflow/tensorflow:2.1.0-gpu-py3
2 |
3 | ENV CUDNN_VERSION=7.6.5.32-1+cuda10.1
4 | ENV NCCL_VERSION=2.4.8-1+cuda10.1
5 |
6 | # Set default shell to /bin/bash
7 | SHELL ["/bin/bash", "-cu"]
8 |
9 | RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends wget \
10 | vim \
11 | git \
12 | libcudnn7=${CUDNN_VERSION} \
13 | libnccl2=${NCCL_VERSION} \
14 | libnccl-dev=${NCCL_VERSION}
15 |
16 | # Install Open MPI
17 | RUN mkdir /tmp/openmpi && \
18 | cd /tmp/openmpi && \
19 | wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \
20 | tar zxf openmpi-4.0.0.tar.gz && \
21 | cd openmpi-4.0.0 && \
22 | ./configure --enable-orterun-prefix-by-default && \
23 | make -j $(nproc) all && \
24 | make install && \
25 | ldconfig && \
26 | rm -rf /tmp/openmpi
27 |
28 | # Install Horovod, temporarily using CUDA stubs
29 | RUN ldconfig /usr/local/cuda/targets/x86_64-linux/lib/stubs && \
30 |     HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_GPU_BROADCAST=NCCL HOROVOD_WITH_TENSORFLOW=1 \
31 |     pip --default-timeout=1000 install --no-cache-dir git+https://github.com/horovod/horovod && \
32 | ldconfig
33 |
34 | # Install OpenSSH for MPI to communicate between containers
35 | RUN apt-get install -y --no-install-recommends openssh-client openssh-server && \
36 | mkdir -p /var/run/sshd
37 |
38 | # Allow OpenSSH to talk to containers without asking for confirmation
39 | RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \
40 | echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \
41 | mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config
42 |
43 | RUN pip --default-timeout=1000 install sox \
44 | absl-py \
45 | yapf \
46 | pylint \
47 | flake8 \
48 | tqdm \
49 | sentencepiece \
50 | librosa \
51 | kenlm \
52 | pandas \
53 | jieba
54 |
55 | # Install Athena
56 | RUN git clone https://github.com/didi/athena.git /athena && \
57 |     cd /athena && python setup.py bdist_wheel && \
58 | python -m pip install --ignore-installed dist/athena-0.1.0*.whl
59 |
--------------------------------------------------------------------------------
/athena/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """ module """
17 | # data
18 | from .data import SpeechRecognitionDatasetBuilder
19 | from .data import SpeechDatasetBuilder
20 | from .data import LanguageDatasetBuilder
21 | from .data import FeatureNormalizer
22 | from .data.text_featurizer import TextFeaturizer
23 |
24 | # layers
25 | from .layers.functional import make_positional_encoding
26 | from .layers.functional import collapse4d
27 | from .layers.functional import gelu
28 | from .layers.commons import PositionalEncoding
29 | from .layers.commons import Collapse4D
30 | from .layers.commons import TdnnLayer
31 | from .layers.commons import Gelu
32 | from .layers.attention import MultiHeadAttention
33 | from .layers.attention import BahdanauAttention
34 | from .layers.attention import HanAttention
35 | from .layers.attention import MatchAttention
36 | from .layers.transformer import Transformer
37 | from .layers.transformer import TransformerEncoder
38 | from .layers.transformer import TransformerDecoder
39 | from .layers.transformer import TransformerEncoderLayer
40 | from .layers.transformer import TransformerDecoderLayer
41 |
42 | # models
43 | from .models.base import BaseModel
44 | from .models.speech_transformer import SpeechTransformer, SpeechTransformer2
45 | from .models.masked_pc import MaskedPredictCoding
46 | from .models.deep_speech import DeepSpeechModel
47 | from .models.mtl_seq2seq import MtlTransformerCtc
48 | from .models.rnn_lm import RNNLM
49 |
50 | # solver & loss & accuracy
51 | from .solver import BaseSolver
52 | from .solver import HorovodSolver
53 | from .solver import DecoderSolver
54 | from .loss import CTCLoss
55 | from .loss import Seq2SeqSparseCategoricalCrossentropy
56 | from .metrics import CTCAccuracy
57 | from .metrics import Seq2SeqSparseCategoricalAccuracy
58 |
59 | # utils
60 | from .utils.checkpoint import Checkpoint
61 | from .utils.learning_rate import WarmUpLearningSchedule, WarmUpAdam
62 | from .utils.learning_rate import (
63 | ExponentialDecayLearningRateSchedule,
64 | ExponentialDecayAdam,
65 | )
66 | from .utils.hparam import HParams, register_and_parse_hparams
67 | from .utils.misc import generate_square_subsequent_mask
68 | from .utils.misc import get_wave_file_length
69 | from .utils.misc import set_default_summary_writer
70 |
71 | # tools
72 | from .tools.beam_search import BeamSearchDecoder
73 |
--------------------------------------------------------------------------------
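
All of the names imported above form the public interface of the package, so downstream code can pull them directly from `athena`. A minimal sketch (any of the exported names listed in this file work the same way):

    from athena import (
        SpeechRecognitionDatasetBuilder,  # dataset pipeline
        SpeechTransformer,                # ASR model
        CTCLoss,                          # loss
        WarmUpAdam,                       # optimizer with warmup schedule
        Checkpoint,                       # checkpoint manager
    )
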
/athena/cmvn_main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) ATHENA AUTHORS
3 | # All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | # Only support tensorflow 2.0
18 | # pylint: disable=invalid-name, no-member
19 | r""" a sample implementation of LAS for HKUST """
20 | import sys
21 | import json
22 | import tensorflow as tf
23 | from absl import logging
24 | from athena.main import parse_config, SUPPORTED_DATASET_BUILDER
25 |
26 | if __name__ == "__main__":
27 | logging.set_verbosity(logging.INFO)
28 | if len(sys.argv) < 3:
29 | logging.warning('Usage: python {} config_json_file data_csv_file'.format(sys.argv[0]))
30 | sys.exit()
31 | tf.random.set_seed(1)
32 |
33 | jsonfile = sys.argv[1]
34 | with open(jsonfile) as file:
35 | config = json.load(file)
36 | p = parse_config(config)
37 | if "speed_permutation" in p.dataset_config:
38 | p.dataset_config['speed_permutation'] = [1.0]
39 | csv_file = sys.argv[2]
40 | dataset_builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](p.dataset_config)
41 | dataset_builder.load_csv(csv_file).compute_cmvn_if_necessary(True)
42 |
--------------------------------------------------------------------------------
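
For reference, cmvn_main.py is the script equivalent of the following sketch (paths are hypothetical; any config/csv pair produced by the recipes under examples/asr should work the same way):

    import json
    from athena.main import parse_config, SUPPORTED_DATASET_BUILDER

    # assumed example paths; the csv is produced by the recipe's prepare_data.py
    with open("examples/asr/hkust/mtl_transformer.json") as f:
        p = parse_config(json.load(f))
    builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](p.dataset_config)
    builder.load_csv("examples/asr/hkust/data/train.csv").compute_cmvn_if_necessary(True)
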
/athena/data/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) ATHENA AUTHORS
3 | # All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """ data """
18 | from .datasets.speech_recognition import SpeechRecognitionDatasetBuilder
19 | from .datasets.speech_set import SpeechDatasetBuilder
20 | from .datasets.language_set import LanguageDatasetBuilder
21 | from .feature_normalizer import FeatureNormalizer
22 | from .text_featurizer import TextFeaturizer, SentencePieceFeaturizer
23 |
--------------------------------------------------------------------------------
/athena/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """ data.datasets """
17 |
--------------------------------------------------------------------------------
/athena/data/datasets/speech_recognition_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 ATHENA AUTHORS; Xiangang Li
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | # pylint: disable=no-member, invalid-name
17 | """ audio dataset """
18 | import time
19 | import tqdm
20 | from absl import logging
21 | from athena import SpeechRecognitionDatasetBuilder
22 |
23 | def test():
24 | ''' test the speed of dataset '''
25 | data_csv = "/tmp-data/dataset/opensource/hkust/train.csv"
26 | dataset_builder = SpeechRecognitionDatasetBuilder(
27 | {
28 | "audio_config": {
29 | "type": "Fbank",
30 | "filterbank_channel_count": 40,
31 | "sample_rate": 8000,
32 | "local_cmvn": False,
33 | },
34 | "speed_permutation": [0.9, 1.0],
35 | "vocab_file": "examples/asr/hkust/data/vocab"
36 | }
37 | )
38 | dataset = dataset_builder.load_csv(data_csv).as_dataset(16, 4)
39 | start = time.time()
40 | for _ in tqdm.tqdm(dataset, total=len(dataset_builder)//16):
41 | pass
42 | logging.info(time.time() - start)
43 |
44 |
45 | if __name__ == '__main__':
46 | logging.set_verbosity(logging.INFO)
47 | test()
48 |
--------------------------------------------------------------------------------
/athena/decode_main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) ATHENA AUTHORS
3 | # All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | # Only support tensorflow 2.0
18 | # pylint: disable=invalid-name, no-member
19 | r""" a sample implementation of LAS for HKUST """
20 | import sys
21 | import json
22 | import tensorflow as tf
23 | from absl import logging
24 | from athena import DecoderSolver
25 | from athena.main import (
26 | parse_config,
27 | build_model_from_jsonfile
28 | )
29 |
30 |
31 | def decode(jsonfile):
32 | """ entry point for model decoding, do some preparation work """
33 | p, model, _, checkpointer, dataset_builder = build_model_from_jsonfile(jsonfile, 0)
34 | checkpointer.restore_from_best()
35 | solver = DecoderSolver(model, config=p.decode_config)
36 | dataset_builder = dataset_builder.load_csv(p.test_csv).compute_cmvn_if_necessary(True)
37 | solver.decode(dataset_builder.as_dataset(batch_size=1))
38 |
39 |
40 | if __name__ == "__main__":
41 | logging.set_verbosity(logging.INFO)
42 | tf.random.set_seed(1)
43 |
44 | JSON_FILE = sys.argv[1]
45 | CONFIG = None
46 | with open(JSON_FILE) as f:
47 | CONFIG = json.load(f)
48 | PARAMS = parse_config(CONFIG)
49 | DecoderSolver.initialize_devices(PARAMS.solver_gpu)
50 | decode(JSON_FILE)
51 |
--------------------------------------------------------------------------------
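
decode_main.py takes the experiment json as its single argument, e.g. `python athena/decode_main.py examples/asr/hkust/mtl_transformer.json` (the config path here is illustrative); it restores the best checkpoint, recomputes CMVN if necessary, and decodes `p.test_csv` with batch size 1.
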
/athena/horovod_main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 ATHENA AUTHORS; Xiangang Li
3 | # All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | # Only support tensorflow 2.0
18 | # pylint: disable=invalid-name, no-member
19 | r""" a sample implementation of LAS for HKUST """
20 | import sys
21 | import json
22 | import tensorflow as tf
23 | import horovod.tensorflow as hvd
24 | from absl import logging
25 | from athena import HorovodSolver
26 | from athena.main import parse_config, train
27 |
28 | if __name__ == "__main__":
29 | logging.set_verbosity(logging.INFO)
30 | tf.random.set_seed(1)
31 |
32 | JSON_FILE = sys.argv[1]
33 | CONFIG = None
34 | with open(JSON_FILE) as f:
35 | CONFIG = json.load(f)
36 | PARAMS = parse_config(CONFIG)
37 | HorovodSolver.initialize_devices()
38 | train(JSON_FILE, HorovodSolver, hvd.size(), hvd.local_rank())
39 |
--------------------------------------------------------------------------------
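
horovod_main.py expects to be launched with one process per GPU. With Horovod that is typically something like `horovodrun -np 4 python athena/horovod_main.py examples/asr/hkust/mtl_transformer.json` (process count and config path are illustrative); each process then trains as one of `hvd.size()` workers, indexed by `hvd.local_rank()`.
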
/athena/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """ module """
17 |
--------------------------------------------------------------------------------
/athena/layers/commons.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 ATHENA AUTHORS; Xiangang Li
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | # pylint: disable=too-few-public-methods, invalid-name
17 | # pylint: disable=no-self-use, missing-function-docstring
18 | """Utils for common layers."""
19 |
20 | import tensorflow as tf
21 | from athena.layers.functional import make_positional_encoding, collapse4d, gelu
22 |
23 | from athena.layers.functional import splice
24 |
25 |
26 | class PositionalEncoding(tf.keras.layers.Layer):
27 | """ positional encoding can be used in transformer """
28 |
29 | def __init__(self, d_model, max_position=800, scale=False):
30 | super().__init__()
31 | self.d_model = d_model
32 | self.scale = scale
33 | self.pos_encoding = make_positional_encoding(max_position, d_model)
34 |
35 | def call(self, x):
36 | """ call function """
37 | seq_len = tf.shape(x)[1]
38 | if self.scale:
39 | x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
40 | x += self.pos_encoding[:, :seq_len, :]
41 | return x
42 |
43 |
44 | class Collapse4D(tf.keras.layers.Layer):
45 | """ callapse4d can be used in cnn-lstm for speech processing
46 | reshape from [N T D C] -> [N T D*C]
47 | """
48 |
49 | def call(self, x):
50 | return collapse4d(x)
51 |
52 |
53 | class Gelu(tf.keras.layers.Layer):
54 | """Gaussian Error Linear Unit.
55 |     This is a smoother version of the ReLU.
56 | Original paper: https://arxiv.org/abs/1606.08415
57 | Args:
58 | x: float Tensor to perform activation.
59 | Returns:
60 | `x` with the GELU activation applied.
61 | """
62 |
63 | def call(self, x):
64 | return gelu(x)
65 |
66 |
67 | class TdnnLayer(tf.keras.layers.Layer):
68 | """ An implement of Tdnn Layer
69 | Args:
70 | context: a int of left and right context, or
71 | a list of context indexes, e.g. (-2, 0, 2).
72 | output_dim: the dim of the linear transform
73 | """
74 |
75 | def __init__(self, context, output_dim, use_bias=False, **kwargs):
76 | super().__init__(**kwargs)
77 |
78 | if hasattr(context, "__iter__"):
79 | self.context_size = len(context)
80 | self.context_list = context
81 | else:
82 | self.context_size = context * 2 + 1
83 | self.context_list = range(-context, context + 1)
84 |
85 | self.output_dim = output_dim
86 | self.linear = tf.keras.layers.Dense(output_dim, use_bias=use_bias)
87 |
88 | def call(self, x, training=None, mask=None):
89 | x = splice(x, self.context_list)
90 | x = self.linear(x, training=training, mask=mask)
91 | return x
92 |
93 |
94 | SUPPORTED_RNNS = {
95 | "lstm": tf.keras.layers.LSTMCell,
96 | "gru": tf.keras.layers.GRUCell,
97 | "cudnnlstm": tf.keras.layers.LSTMCell,
98 | "cudnngru": tf.keras.layers.GRUCell
99 | }
100 |
101 |
102 | ACTIVATIONS = {
103 | "relu": tf.nn.relu,
104 | "relu6": tf.nn.relu6,
105 | "elu": tf.nn.elu,
106 | "selu": tf.nn.selu,
107 | "gelu": gelu,
108 | "leaky_relu": tf.nn.leaky_relu,
109 | "sigmoid": tf.nn.sigmoid,
110 | "softplus": tf.nn.softplus,
111 | "softsign": tf.nn.softsign,
112 | "tanh": tf.nn.tanh,
113 | }
114 |
--------------------------------------------------------------------------------
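
A minimal usage sketch of the layers above (shapes are made up; PositionalEncoding adds sinusoidal encodings without changing the shape, and TdnnLayer splices the +/-2 context frames before its linear transform):

    import tensorflow as tf
    from athena.layers.commons import PositionalEncoding, TdnnLayer

    x = tf.random.normal([2, 8, 512])                      # (batch, time, dim)
    x = PositionalEncoding(d_model=512, scale=True)(x)     # same shape, encodings added
    y = TdnnLayer(context=(-2, 0, 2), output_dim=256)(x)   # splice 3 frames, then Dense
    print(y.shape)                                         # (2, 8, 256)
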
/athena/layers/functional.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 ATHENA AUTHORS; Xiangang Li
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | # pylint: disable=invalid-name
17 | """Utils for common layers."""
18 |
19 | import numpy as np
20 | import tensorflow as tf
21 | from ..utils.misc import tensor_shape
22 | from tensorflow.python.framework import ops
23 |
24 |
25 | def make_positional_encoding(position, d_model):
26 | """ generate a postional encoding list """
27 |
28 | def get_angles(pos, i, d_model):
29 | angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
30 | return pos * angle_rates
31 |
32 | angle_rads = get_angles(
33 | np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model
34 | )
35 | angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
36 | angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
37 | pos_encoding = angle_rads[np.newaxis, ...]
38 | return tf.cast(pos_encoding, dtype=tf.float32)
39 |
40 |
41 | def collapse4d(x, name=None):
42 | """ reshape from [N T D C] -> [N T D*C]
43 |     using tf.shape(x), which generates a tensor, instead of x.shape
44 | """
45 | with ops.name_scope(name, "collapse4d") as name:
46 | shape = tensor_shape(x)
47 | N = shape[0]
48 | T = shape[1]
49 | D = shape[2]
50 | C = shape[3]
51 | DC = D * C
52 | out = tf.reshape(x, [N, T, DC])
53 | return out
54 |
55 |
56 | def splice(x, context):
57 | """
58 | Splice a tensor along the last dimension with context.
59 | e.g.:
60 | t = [[[1, 2, 3],
61 | [4, 5, 6],
62 | [7, 8, 9]]]
63 | splice_tensor(t, [0, 1]) =
64 | [[[1, 2, 3, 4, 5, 6],
65 | [4, 5, 6, 7, 8, 9],
66 | [7, 8, 9, 7, 8, 9]]]
67 |
68 | Args:
69 |         x: a tf.Tensor with shape (B, T, D), a.k.a. (N, H, W)
70 | context: a list of context offsets
71 |
72 | Returns:
73 | spliced tensor with shape (..., D * len(context))
74 | """
75 | input_shape = tf.shape(x)
76 | B, T = input_shape[0], input_shape[1]
77 | context_len = len(context)
78 | array = tf.TensorArray(x.dtype, size=context_len)
79 | for idx, offset in enumerate(context):
80 | begin = offset
81 | end = T + offset
82 | if begin < 0:
83 | begin = 0
84 | sliced = x[:, begin:end, :]
85 | tiled = tf.tile(x[:, 0:1, :], [1, abs(offset), 1])
86 | final = tf.concat((tiled, sliced), axis=1)
87 | else:
88 | end = T
89 | sliced = x[:, begin:end, :]
90 | tiled = tf.tile(x[:, -1:, :], [1, abs(offset), 1])
91 | final = tf.concat((sliced, tiled), axis=1)
92 | array = array.write(idx, final)
93 | spliced = array.stack()
94 | spliced = tf.transpose(spliced, (1, 2, 0, 3))
95 | spliced = tf.reshape(spliced, (B, T, -1))
96 | return spliced
97 |
98 |
99 | def gelu(x):
100 | """Gaussian Error Linear Unit.
101 |     This is a smoother version of the ReLU.
102 | Original paper: https://arxiv.org/abs/1606.08415
103 | Args:
104 | x: float Tensor to perform activation.
105 | Returns:
106 | `x` with the GELU activation applied.
107 | """
108 | cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
109 | return x * cdf
110 |
--------------------------------------------------------------------------------
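
In formula form, make_positional_encoding implements the standard sinusoidal encoding, and gelu implements the tanh approximation from the paper cited in its docstring:

$$\mathrm{PE}(pos,\,2i) = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad \mathrm{PE}(pos,\,2i+1) = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

$$\mathrm{gelu}(x) = 0.5\,x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\big(x + 0.044715\,x^{3}\big)\right)\right)$$
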
/athena/loss.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) ATHENA AUTHORS
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | # Only support eager mode and TF>=2.0.0
17 | # pylint: disable=too-few-public-methods
18 | """ some losses """
19 | import tensorflow as tf
20 | from .utils.misc import insert_eos_in_labels
21 |
22 |
23 | class CTCLoss(tf.keras.losses.Loss):
24 | """ CTC LOSS
25 |     CTC loss implemented with TensorFlow
26 | """
27 |
28 | def __init__(self, logits_time_major=False, blank_index=-1, name="CTCLoss"):
29 | super().__init__(name=name)
30 | self.logits_time_major = logits_time_major
31 | self.blank_index = blank_index
32 | self.need_logit_length = True
33 |
34 | def __call__(self, logits, samples, logit_length=None):
35 | assert logit_length is not None
36 | # use v2 ctc_loss
37 | ctc_loss = tf.nn.ctc_loss(
38 | labels=samples["output"],
39 | logits=logits,
40 | logit_length=logit_length,
41 | label_length=samples["output_length"],
42 | logits_time_major=self.logits_time_major,
43 | blank_index=self.blank_index,
44 | )
45 | return tf.reduce_mean(ctc_loss)
46 |
47 |
48 | class Seq2SeqSparseCategoricalCrossentropy(tf.keras.losses.CategoricalCrossentropy):
49 | """ Seq2SeqSparseCategoricalCrossentropy LOSS
50 | CategoricalCrossentropy calculated at each character for each sequence in a batch
51 | """
52 |
53 | def __init__(self, num_classes, eos=-1, by_token=False, by_sequence=True,
54 | from_logits=True, label_smoothing=0.0):
55 | super().__init__(from_logits=from_logits, label_smoothing=label_smoothing, reduction="none")
56 | self.by_token = by_token
57 | self.by_sequence = by_sequence
58 | self.num_classes = num_classes
59 | self.eos = num_classes + eos if eos < 0 else eos
60 |
61 | def __call__(self, logits, samples, logit_length=None):
62 | labels = insert_eos_in_labels(samples["output"], self.eos, samples["output_length"])
63 | mask = tf.math.logical_not(tf.math.equal(labels, 0))
64 | labels = tf.one_hot(indices=labels, depth=self.num_classes)
65 | seq_len = tf.shape(labels)[1]
66 | logits = logits[:, :seq_len, :]
67 | loss = self.call(labels, logits)
68 | mask = tf.cast(mask, dtype=loss.dtype)
69 | loss *= mask
70 | if self.by_token:
71 | return tf.divide(tf.reduce_sum(loss), tf.reduce_sum(mask))
72 | if self.by_sequence:
73 | loss = tf.reduce_sum(loss, axis=-1)
74 | return tf.reduce_mean(loss)
75 |
76 |
77 | class MPCLoss(tf.keras.losses.Loss):
78 | """MPC LOSS
79 |     L1 loss over the masked acoustic features in a batch
80 | """
81 |
82 | def __init__(self, name="MPCLoss"):
83 | super().__init__(name=name)
84 |
85 | def __call__(self, logits, samples, logit_length=None):
86 | target = samples["output"]
87 | shape = tf.shape(logits)
88 | target = tf.reshape(target, shape)
89 | loss = target - logits
90 | # mpc mask
91 | mask = tf.cast(tf.math.equal(tf.reshape(samples["input"], shape), 0), loss.dtype)
92 | loss *= mask
93 | # sequence length mask
94 | seq_mask = tf.sequence_mask(logit_length, shape[1], dtype=loss.dtype)
95 | seq_mask = tf.tile(seq_mask[:, :, tf.newaxis], [1, 1, shape[2]])
96 | loss *= seq_mask
97 | loss = tf.reduce_sum(tf.abs(loss, name="L1_loss"), 2)
98 | loss = tf.reduce_mean(loss)
99 | return loss
100 |
--------------------------------------------------------------------------------
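
A minimal sketch of driving CTCLoss with the samples dict it expects (shapes and label values are made up; labels are zero-padded and measured by output_length, and the blank is the last class):

    import tensorflow as tf
    from athena.loss import CTCLoss

    loss_fn = CTCLoss(blank_index=-1)                      # blank = last class
    logits = tf.random.normal([2, 50, 11])                 # (batch, frames, classes)
    samples = {
        "output": tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]]),
        "output_length": tf.constant([3, 2]),
    }
    logit_length = tf.constant([50, 50])
    print(loss_fn(logits, samples, logit_length))          # scalar mean CTC loss
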
/athena/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) ATHENA AUTHORS
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """ module """
17 |
--------------------------------------------------------------------------------
/athena/models/base.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) ATHENA AUTHORS
3 | # All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | # Only support eager mode
18 | # pylint: disable=useless-super-delegation, unused-argument, no-self-use
19 |
20 | """ base model for models """
21 | from absl import logging
22 | import tensorflow as tf
23 |
24 | class BaseModel(tf.keras.Model):
25 | """Base class for model."""
26 |
27 | def __init__(self, **kwargs):
28 | super().__init__(**kwargs)
29 | self.loss_function = None
30 | self.metric = None
31 |
32 | def call(self, samples, training=None):
33 | """ call model """
34 | raise NotImplementedError()
35 |
36 | #pylint: disable=not-callable
37 | def get_loss(self, logits, samples, training=None):
38 | """ get loss """
39 | if self.loss_function is None:
40 | loss = 0.0
41 | else:
42 | logit_length = self.compute_logit_length(samples)
43 | loss = self.loss_function(logits, samples, logit_length)
44 | if self.metric is None:
45 | metrics = {}
46 | else:
47 | self.metric(logits, samples, logit_length)
48 | metrics = {self.metric.name: self.metric.result()}
49 | return loss, metrics
50 |
51 | def compute_logit_length(self, samples):
52 | """ compute the logit length """
53 | return samples["input_length"]
54 |
55 | def reset_metrics(self):
56 | """ reset the metrics """
57 | if self.metric is not None:
58 | self.metric.reset_states()
59 |
60 | def prepare_samples(self, samples):
61 | """ for special data prepare
62 | carefully: do not change the shape of samples
63 | """
64 | return samples
65 |
66 | def restore_from_pretrained_model(self, pretrained_model, model_type=""):
67 | """ restore from pretrained model
68 | """
69 | logging.info("restore from pretrained model")
70 |
71 | def decode(self, samples, hparams):
72 | """ decode interface
73 | """
74 | logging.info("sorry, this model do not support decode")
75 |
--------------------------------------------------------------------------------
/athena/models/customized.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2019 ATHENA AUTHORS; Xiangang Li
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | # Only support eager mode and TF>=2.0.0
16 | # pylint: disable=no-member, invalid-name, relative-beyond-top-level
17 | """ a implementation of customized model """
18 |
19 | import tensorflow as tf
20 | from absl import logging
21 | from athena import layers as athena_layers
22 | from athena.utils.hparam import register_and_parse_hparams
23 | from .base import BaseModel
24 |
25 |
26 | def build_layer(input_data, layer_config):
27 | """
28 | Build one or multiple layers from layer configuration.
29 | args:
30 | layer_config: list of multiple layers, or dict of single layer parameters.
31 |     returns: the output tensor produced by applying the configured layer(s).
32 | """
33 | if isinstance(layer_config, list):
34 | # Recursively build each layer.
35 | output = input_data
36 | for one_layer_config in layer_config:
37 | logging.info(f"Layer conf: {one_layer_config}")
38 | output = build_layer(output, one_layer_config)
39 | return output
40 |
41 | # Build one layer.
42 | for layer_type in layer_config:
43 | layer_args = layer_config[layer_type]
44 | logging.info(f"Final layer config: {layer_type}: {layer_args}")
45 | layer_cls = getattr(athena_layers, layer_type)
46 | layer = layer_cls(**layer_args)
47 | return layer(input_data)
48 |
49 |
50 | class CustomizedModel(BaseModel):
51 | """ a simple customized model """
52 | default_config = {
53 | "topo": [{}]
54 | }
55 | def __init__(self, num_classes, sample_shape, config=None):
56 | super().__init__()
57 | self.hparams = register_and_parse_hparams(self.default_config, config)
58 |
59 | logging.info(f"Network topology config: {self.hparams.topo}")
60 | input_feature = tf.keras.layers.Input(shape=sample_shape["input"], dtype=tf.float32)
61 | inner = build_layer(input_feature, self.hparams.topo)
62 | inner = tf.keras.layers.Dense(num_classes)(inner)
63 | self.model = tf.keras.Model(inputs=input_feature, outputs=inner)
64 | logging.info(self.model.summary())
65 |
66 | def call(self, samples, training=None):
67 | return self.model(samples["input"], training=training)
68 |
--------------------------------------------------------------------------------
/athena/models/deep_speech.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2019 ATHENA AUTHORS; Xiangang Li
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | # Only support eager mode and TF>=2.0.0
16 | # pylint: disable=no-member, invalid-name, relative-beyond-top-level
17 | # pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes
18 | """ a implementation of deep speech 2 model can be used as a sample for ctc model """
19 |
20 | import tensorflow as tf
21 | from absl import logging
22 | from ..utils.hparam import register_and_parse_hparams
23 | from .base import BaseModel
24 | from ..loss import CTCLoss
25 | from ..metrics import CTCAccuracy
26 | from ..layers.commons import SUPPORTED_RNNS
27 |
28 |
29 | class DeepSpeechModel(BaseModel):
30 | """ a sample implementation of CTC model """
31 | default_config = {
32 | "conv_filters": 256,
33 | "rnn_hidden_size": 1024,
34 | "num_rnn_layers": 6,
35 | "rnn_type": "gru"
36 | }
37 | def __init__(self, num_classes, sample_shape, config=None):
38 | super().__init__()
39 | self.num_classes = num_classes + 1
40 | self.loss_function = CTCLoss(blank_index=-1)
41 | self.metric = CTCAccuracy()
42 | self.hparams = register_and_parse_hparams(self.default_config, config, cls=self.__class__)
43 |
44 | layers = tf.keras.layers
45 | input_feature = layers.Input(shape=sample_shape["input"], dtype=tf.float32)
46 | inner = layers.Conv2D(
47 | filters=self.hparams.conv_filters,
48 | kernel_size=(41, 11),
49 | strides=(2, 2),
50 | padding="same",
51 | use_bias=False,
52 | )(input_feature)
53 | inner = layers.BatchNormalization()(inner)
54 | inner = tf.nn.relu6(inner)
55 | inner = layers.Conv2D(
56 | filters=self.hparams.conv_filters,
57 | kernel_size=(21, 11),
58 | strides=(2, 1),
59 | padding="same",
60 | use_bias=False,
61 | )(inner)
62 | inner = layers.BatchNormalization()(inner)
63 | inner = tf.nn.relu6(inner)
64 | _, _, dim, channels = inner.get_shape().as_list()
65 | output_dim = dim * channels
66 | inner = layers.Reshape((-1, output_dim))(inner)
67 | rnn_type = self.hparams.rnn_type
68 | rnn_hidden_size = self.hparams.rnn_hidden_size
69 |
70 | for _ in range(self.hparams.num_rnn_layers):
71 | inner = tf.keras.layers.RNN(
72 | cell=[SUPPORTED_RNNS[rnn_type](rnn_hidden_size)],
73 | return_sequences=True
74 | )(inner)
75 | inner = layers.BatchNormalization()(inner)
76 | inner = layers.Dense(rnn_hidden_size, activation=tf.nn.relu6)(inner)
77 | inner = layers.Dense(self.num_classes)(inner)
78 | self.net = tf.keras.Model(inputs=input_feature, outputs=inner)
79 | logging.info(self.net.summary())
80 |
81 | def call(self, samples, training=None):
82 | """ call function """
83 | return self.net(samples["input"], training=training)
84 |
85 | def compute_logit_length(self, samples):
86 | """ used for get logit length """
87 | input_length = tf.cast(samples["input_length"], tf.float32)
88 | logit_length = tf.math.ceil(input_length / 2)
89 | logit_length = tf.math.ceil(logit_length / 2)
90 | logit_length = tf.cast(logit_length, tf.int32)
91 | return logit_length
92 |
--------------------------------------------------------------------------------
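
A minimal instantiation sketch (made-up feature shape and layer sizes; note how compute_logit_length mirrors the two convolutions with stride 2 in time, so 100 input frames become ceil(ceil(100/2)/2) = 25 logit frames):

    import tensorflow as tf
    from athena.models.deep_speech import DeepSpeechModel

    model = DeepSpeechModel(
        num_classes=1000,                        # the CTC blank is added internally
        sample_shape={"input": [None, 40, 1]},   # (time, freq, channels)
        config={"conv_filters": 32, "rnn_hidden_size": 64, "num_rnn_layers": 1},
    )
    samples = {"input": tf.random.normal([2, 100, 40, 1]),
               "input_length": tf.constant([100, 80])}
    logits = model(samples, training=False)      # (2, 25, 1001)
    print(model.compute_logit_length(samples))   # [25, 20]
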
/athena/models/rnn_lm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 ATHENA AUTHORS; Xiangang Li; Xiaoning Lei
3 | # All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | # Only support eager mode
18 | # pylint: disable=no-member, invalid-name
19 | """ RNN language model implementation"""
20 |
21 | import tensorflow as tf
22 | from .base import BaseModel
23 | from ..utils.misc import insert_eos_in_labels, insert_sos_in_labels
24 | from ..utils.hparam import register_and_parse_hparams
25 | from ..layers.commons import SUPPORTED_RNNS
26 |
27 | class RNNLM(BaseModel):
28 | """Standard implementation of a RNNLM. Model mainly consists of embeding layer,
29 | rnn layers(with dropout), and the full connection layer, which are all incuded
30 | in self.model_for_rnn
31 | """
32 | default_config = {
33 | "d_model": 512, # the dim of model
34 | "rnn_type": 'lstm', # the supported rnn type
35 | "num_layer": 2, # the number of rnn layer
36 | "dropout_rate": 0.1, # dropout for model
37 | "sos": -1, # sos can be -1 or -2
38 | "eos": -1 # eos can be -1 or -2
39 | }
40 | def __init__(self, num_classes, sample_shape, config=None):
41 | """ config including the params for build lm """
42 | super(RNNLM, self).__init__()
43 | p = register_and_parse_hparams(self.default_config, config)
44 | self.num_classes = (
45 | num_classes + 1
46 | if p.sos == p.eos
47 | else num_classes + 2
48 | )
49 | self.sos = self.num_classes + p.sos
50 | self.eos = self.num_classes + p.eos
51 | self.metric = tf.keras.metrics.Mean(name="AverageLoss")
52 |
53 | layers = tf.keras.layers
54 | input_features = layers.Input(shape=sample_shape["output"], dtype=tf.int32)
55 | inner = tf.keras.layers.Embedding(self.num_classes, p.d_model)(input_features)
56 | for _ in range(p.num_layer):
57 | inner = tf.keras.layers.Dropout(p.dropout_rate)(inner)
58 | inner = tf.keras.layers.RNN(
59 | cell=[SUPPORTED_RNNS[p.rnn_type](p.d_model)],
60 | return_sequences=True
61 | )(inner)
62 | inner = tf.keras.layers.Dropout(p.dropout_rate)(inner)
63 | inner = tf.keras.layers.Dense(self.num_classes)(inner)
64 | self.rnnlm = tf.keras.Model(inputs=input_features, outputs=inner)
65 |
66 | def call(self, samples, training: bool = None):
67 | x = insert_sos_in_labels(samples['input'], self.sos)
68 | return self.rnnlm(x, training=training)
69 |
70 | def save_model(self, path):
71 | """
72 | for saving model and current weight, path is h5 file name, like 'my_model.h5'
73 | usage:
74 | new_model = tf.keras.models.load_model(path)
75 | """
76 | self.rnnlm.save(path)
77 |
78 | def get_loss(self, logits, samples, training=None):
79 | """ get loss """
80 | labels = samples['output']
81 | labels = insert_eos_in_labels(labels, self.eos, samples['output_length'])
82 | labels = tf.one_hot(labels, self.num_classes)
83 | loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
84 | n_token = tf.cast(tf.reduce_sum(samples['output_length'] + 1), tf.float32)
85 | self.metric.update_state(loss)
86 | metrics = {self.metric.name: self.metric.result()}
87 | return tf.reduce_sum(loss) / n_token, metrics
88 |
--------------------------------------------------------------------------------
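
A minimal sketch of the RNNLM interface (made-up token ids; with sos == eos == -1 one extra class is reserved, and call() prepends sos, so the logits are one step longer than the input):

    import tensorflow as tf
    from athena.models.rnn_lm import RNNLM

    model = RNNLM(num_classes=100, sample_shape={"output": [None]},
                  config={"d_model": 32, "num_layer": 1})
    samples = {"input": tf.constant([[5, 6, 7]]),
               "output": tf.constant([[5, 6, 7]]),
               "output_length": tf.constant([3])}
    logits = model(samples, training=False)        # (1, 4, 101)
    loss, metrics = model.get_loss(logits, samples)
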
/athena/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) ATHENA AUTHORS
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """ tools """
17 |
--------------------------------------------------------------------------------
/athena/tools/lm_scorer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import kenlm
3 |
4 |
5 | class NGramScorer(object):
6 | """
7 | KenLM language model
8 | """
9 |
10 | def __init__(self, lm_path, sos, eos, num_syms, lm_weight=0.1):
11 | """
12 | Basic params will be initialized, the kenlm model will be created from
13 | the lm_path
14 | Args:
15 | lm_path: the saved lm model path
16 | sos: start symbol
17 | eos: end symbol
18 | num_syms: number of classes
19 | lm_weight: the lm weight
20 | """
21 | self.lang_model = kenlm.Model(lm_path)
22 | self.state_index = 0
23 | self.sos = sos
24 | self.eos = eos
25 | self.num_syms = num_syms
26 | self.lm_weight = lm_weight
27 | kenlm_state = kenlm.State()
28 | self.lang_model.BeginSentenceWrite(kenlm_state)
29 | self.cand_kenlm_states = np.array([[kenlm_state] * num_syms])
30 |
31 | def score(self, candidate_holder, new_scores):
32 | """
33 |         Call this function to compute the NGram score of the next prediction
34 |         based on historical predictions; all scoring functions share a common interface
35 |         Args:
36 |             candidate_holder: the holder of the current beam candidates
37 |         Returns:
38 |             score: the NGram weighted score
39 |             cand_states: the candidate states, returned unchanged
40 |         """
41 | cand_seqs = candidate_holder.cand_seqs
42 | cand_parents = candidate_holder.cand_parents
43 | cand_syms = cand_seqs[:, -1]
44 | score = self.get_score(cand_parents, cand_syms, self.lang_model)
45 | score = self.lm_weight * score
46 | return score, candidate_holder.cand_states
47 |
48 | def get_score(self, cand_parents, cand_syms, lang_model):
49 | """
50 | the saved lm model will be called here
51 | Args:
52 | cand_parents: last selected top candidates
53 | cand_syms: last selected top char index
54 | lang_model: the language model
55 | Return:
56 | scores: the lm scores
57 | """
58 | scale = 1.0 / np.log10(np.e) # convert log10 to ln
59 |
60 | num_cands = len(cand_syms)
61 | scores = np.zeros((num_cands, self.num_syms))
62 | new_states = np.zeros((num_cands, self.num_syms), dtype=object)
63 | chars = [str(x) for x in range(self.num_syms)]
64 | chars[self.sos] = ""
65 | chars[self.eos] = ""
66 | chars[0] = ""
67 |
68 | for i in range(num_cands):
69 | parent = cand_parents[i]
70 | kenlm_state_list = self.cand_kenlm_states[parent]
71 | kenlm_state = kenlm_state_list[cand_syms[i]]
72 | for sym in range(self.num_syms):
73 | char = chars[sym]
74 | out_state = kenlm.State()
75 | score = scale * lang_model.BaseScore(kenlm_state, char, out_state)
76 | scores[i, sym] = score
77 | new_states[i, sym] = out_state
78 | self.cand_kenlm_states = new_states
79 | return scores
80 |
--------------------------------------------------------------------------------
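
KenLM's BaseScore returns log10 probabilities; get_score rescales them to natural log before weighting, since the rest of the beam search scores in ln space:

$$\text{scale} = \frac{1}{\log_{10} e} = \ln 10 \approx 2.3026$$
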
/athena/tools/sph2pipe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/athena/tools/sph2pipe
--------------------------------------------------------------------------------
/athena/transform/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from athena.transform import audio_featurizer
17 |
18 | from athena.transform.audio_featurizer import AudioFeaturizer
19 | from athena.transform.feats.cmvn import compute_cmvn
20 | from athena.transform.feats.read_wav import read_wav
21 |
--------------------------------------------------------------------------------
/athena/transform/audio_featurizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """This module provides a general interface for feature extraction."""
17 |
18 | import tensorflow as tf
19 | from athena.transform import feats
20 |
21 |
22 | class AudioFeaturizer:
23 |     """
24 |     Interface for audio feature extraction.
25 |     """
26 | #pylint: disable=dangerous-default-value
27 | def __init__(self, config={"type": "Fbank"}):
28 |         """init
29 |         :param config: dict configuring the feature extractor;
30 |             config['type'] selects the feature,
31 |             one of 'ReadWav', 'Fbank', 'Spectrum'
32 |             The config for Fbank:
33 |                 'sample_rate' 16000
34 |                 'window_length' 0.025
35 |                 'frame_length' 0.010
36 |                 'upper_frequency_limit' 20
37 |                 'filterbank_channel_count' 40
38 |             The config for Spectrum:
39 |                 'sample_rate' 16000
40 |                 'window_length' 0.025
41 |                 'frame_length' 0.010
42 |                 'output_type' 1
43 |         """
44 |
45 | assert "type" in config
46 |
47 | self.name = config["type"]
48 | self.feat = getattr(feats, self.name).params(config).instantiate()
49 |
50 | if self.name != "ReadWav":
51 | self.read_wav = getattr(feats, "ReadWav").params(config).instantiate()
52 |
53 | #pylint:disable=invalid-name
54 | def __call__(self, audio=None, sr=None, speed=1.0):
55 |         """extract features from audio data or an audio file
56 |         :param audio: audio data or audio file path
57 |         :param sr: sample rate
58 |         :return: feature
59 |         """
60 |
61 | if audio is not None and not tf.is_tensor(audio):
62 | audio = tf.convert_to_tensor(audio)
63 | if sr is not None and not tf.is_tensor(sr):
64 | sr = tf.convert_to_tensor(sr)
65 |
66 | return self.__impl(audio, sr, speed)
67 |
68 | @tf.function
69 | def __impl(self, audio=None, sr=None, speed=1.0):
70 |         """
71 |         :param audio: audio data or audio file path, as a tensor
72 |         :param sr: sample rate, as a tensor
73 |         :return: feature
74 |         """
75 | if self.name == "ReadWav" or self.name == "CMVN":
76 | return self.feat(audio, speed)
77 | elif audio.dtype is tf.string:
78 | audio_data, sr = self.read_wav(audio, speed)
79 | return self.feat(audio_data, sr)
80 | else:
81 | return self.feat(audio, sr)
82 |
83 | @property
84 | def dim(self):
85 |         """return the dimension of the feature
86 |         (1 if the featurizer is only ReadWav)
87 |         """
88 | return self.feat.dim()
89 |
90 | @property
91 | def num_channels(self):
92 |         """return the number of channels of the feature"""
93 | return self.feat.num_channels()
94 |
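Note: a minimal usage sketch of AudioFeaturizer, assuming a hypothetical
16 kHz wav file "test.wav":

    import tensorflow as tf
    from athena.transform import AudioFeaturizer

    featurizer = AudioFeaturizer({"type": "Fbank",
                                  "sample_rate": 16000,
                                  "filterbank_channel_count": 40})
    # a string tensor is treated as a path: ReadWav runs first, then Fbank
    feats = featurizer(tf.constant("test.wav"))
    print(featurizer.dim, featurizer.num_channels)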
--------------------------------------------------------------------------------
/athena/transform/feats/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from athena.transform.feats.read_wav import ReadWav
17 | from athena.transform.feats.spectrum import Spectrum
18 | from athena.transform.feats.framepow import Framepow
19 | from athena.transform.feats.pitch import Pitch
20 | from athena.transform.feats.mfcc import Mfcc
21 | from athena.transform.feats.write_wav import WriteWav
22 | from athena.transform.feats.fbank import Fbank
23 | from athena.transform.feats.cmvn import CMVN, compute_cmvn
24 | from athena.transform.feats.fbank_pitch import FbankPitch
25 |
--------------------------------------------------------------------------------
/athena/transform/feats/base_frontend.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """ base interface of Frontend """
17 |
18 | import abc
19 | import tensorflow as tf
20 |
21 |
22 | class ABCFrontend(metaclass=abc.ABCMeta):
23 |     """ abstract base class of Frontend """
24 |
25 | def __init__(self, config):
26 | raise NotImplementedError()
27 |
28 | @abc.abstractmethod
29 | def call(self, *args):
30 | """ implementation func """
31 | raise NotImplementedError()
32 |
33 |
34 | class BaseFrontend(ABCFrontend):
35 |     """ wrapper of the abstract Frontend """
36 |
37 | def __init__(self, config: dict):
38 | self._config = config
39 |
40 | @property
41 | def config(self):
42 | """ config property """
43 | return self._config
44 |
45 | @classmethod
46 | def params(cls, config=None):
47 | """ set params """
48 | raise NotImplementedError()
49 |
50 | def __call__(self, *args):
51 | """ call """
52 | return self.call(*args)
53 |
54 | def dim(self):
55 | return 1
56 |
57 | def num_channels(self):
58 | return 1
59 |
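Note: a hypothetical minimal subclass, to illustrate the params()/call()
contract that the concrete frontends below all follow:

    from athena.utils.hparam import HParams
    from athena.transform.feats.base_frontend import BaseFrontend

    class Identity(BaseFrontend):
        """Toy frontend that returns its input unchanged (illustration only)."""

        @classmethod
        def params(cls, config=None):
            hparams = HParams(cls=cls)  # cls is stored so instantiate() can build us
            hparams.add_hparam("type", "Identity")
            if config is not None:
                hparams.parse(config, True)
            return hparams

        def call(self, audio_data, sample_rate=None):
            return audio_data

    identity = Identity.params().instantiate()  # the pattern used throughout feats/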
--------------------------------------------------------------------------------
/athena/transform/feats/cmvn.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """This module does CMVN on features."""
17 |
18 | import numpy as np
19 |
20 | import tensorflow as tf
21 |
22 | from athena.utils.hparam import HParams
23 | from athena.transform.feats.base_frontend import BaseFrontend
24 |
25 |
26 | class CMVN(BaseFrontend):
27 | """
28 | Do CMVN on features.
29 | """
30 | def __init__(self, config: dict):
31 | super().__init__(config)
32 |
33 | self.global_cmvn = False
34 | if len(config.global_mean) > 1:
35 | self.global_cmvn = True
36 |
37 | @classmethod
38 | def params(cls, config=None):
39 | """ set params """
40 |
41 | hparams = HParams(cls=cls)
42 | hparams.add_hparam("type", "CMVN")
43 | hparams.add_hparam("global_mean", [0.0])
44 | hparams.add_hparam("global_variance", [1.0])
45 | hparams.add_hparam("local_cmvn", False)
46 |
47 | if config is not None:
48 | hparams.parse(config, True)
49 |
50 | assert len(hparams.global_mean) == len(
51 | hparams.global_variance
52 |         ), "Error: global_mean length {} is not equal to global_variance length {}".format(
53 | len(hparams.global_mean), len(hparams.global_variance)
54 | )
55 |
56 | hparams.global_variance = (np.sqrt(hparams.global_variance) + 1e-6).tolist()
57 | return hparams
58 |
59 | def call(self, audio_feature, speed=1.0):
60 | params = self.config
61 | if self.global_cmvn:
62 | audio_feature = (
63 | audio_feature - params.global_mean
64 | ) / params.global_variance
65 |
66 | if params.local_cmvn:
67 | mean, var = tf.compat.v1.nn.moments(audio_feature, axes=0)
68 | audio_feature = (audio_feature - mean) / (
69 | tf.compat.v1.math.sqrt(var) + 1e-6
70 | )
71 |
72 | return audio_feature
73 |
74 | def dim(self):
75 | params = self.config
76 | return len(params.global_mean)
77 |
78 |
79 | def compute_cmvn(audio_feature, mean=None, variance=None, local_cmvn=False):
80 | if mean is not None:
81 | assert variance is not None
82 | audio_feature = (audio_feature - mean) / variance
83 | if local_cmvn:
84 | mean, var = tf.compat.v1.nn.moments(audio_feature, axes=0)
85 | audio_feature = (audio_feature - mean) / (tf.compat.v1.math.sqrt(var) + 1e-6)
86 | return audio_feature
87 |
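Note: a small usage sketch of compute_cmvn above. CMVN.params() rewrites
global_variance to sqrt(variance) + 1e-6, so the variance argument here is
effectively a standard deviation:

    import numpy as np
    import tensorflow as tf
    from athena.transform.feats.cmvn import compute_cmvn

    feats = tf.random.uniform([10, 40])
    mean = np.zeros(40, dtype=np.float32)
    std = np.ones(40, dtype=np.float32)   # pass a std here, not a raw variance
    normalized = compute_cmvn(feats, mean=mean, variance=std, local_cmvn=True)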
--------------------------------------------------------------------------------
/athena/transform/feats/cmvn_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """This module tests the CMVN op."""
17 |
18 | import numpy as np
19 | import tensorflow as tf
20 | from athena.transform.feats.cmvn import CMVN
21 |
22 |
23 | class CMVNTest(tf.test.TestCase):
24 | """
25 | CMVN test.
26 | """
27 | def test_cmvn(self):
28 | dim = 40
29 |         cmvn = CMVN.params(
30 |             {"global_mean": np.zeros(dim).tolist(), "global_variance": np.ones(dim).tolist()}
31 |         ).instantiate()
32 |         audio_feature = tf.random.uniform(shape=[3, 40], dtype=tf.float32, maxval=1.0)
33 | print(audio_feature)
34 | normalized = cmvn(audio_feature)
35 | print("normalized = ", normalized)
36 | print("dim is ", cmvn.dim())
37 |
38 |         cmvn = CMVN.params(
39 |             {
40 |                 "global_mean": np.zeros(dim).tolist(),
41 |                 "global_variance": np.ones(dim).tolist(),
42 |                 "local_cmvn": False,
43 |             }
44 |         ).instantiate()
45 | normalized = cmvn(audio_feature)
46 | self.assertAllClose(audio_feature, normalized)
47 |
48 |
49 | if __name__ == "__main__":
50 | if tf.__version__ < "2.0.0":
51 | tf.compat.v1.enable_eager_execution()
52 | tf.test.main()
53 |
--------------------------------------------------------------------------------
/athena/transform/feats/fbank_pitch_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """This module tests the Fbank && Pitch feature extractor."""
17 |
18 | import os
19 | from pathlib import Path
20 | import tensorflow as tf
21 | from tensorflow.python.framework.ops import disable_eager_execution
22 | from athena.transform.feats.read_wav import ReadWav
23 | from athena.transform.feats.fbank_pitch import FbankPitch
24 |
25 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
26 |
27 | class FbankPitchTest(tf.test.TestCase):
28 | """
29 | Fbank && Pitch extraction test.
30 | """
31 | def test_FbankPitch(self):
32 | wav_path = str(Path(os.environ['MAIN_ROOT']).joinpath('examples/sm1_cln.wav'))
33 |
34 | with self.session():
35 | read_wav = ReadWav.params().instantiate()
36 | input_data, sample_rate = read_wav(wav_path)
37 | config = {'window_length': 0.025, 'output_type': 1, 'frame_length': 0.010, 'dither': 0.0}
38 | fbank_pitch = FbankPitch.params(config).instantiate()
39 | fbank_pitch_test = fbank_pitch(input_data, sample_rate)
40 |
41 | if tf.executing_eagerly():
42 | self.assertEqual(tf.rank(fbank_pitch_test).numpy(), 3)
43 | print(fbank_pitch_test.numpy()[0:2, :, 0])
44 | else:
45 | self.assertEqual(tf.rank(fbank_pitch_test).eval(), 3)
46 | print(fbank_pitch_test.eval()[0:2, :, 0])
47 |
48 | if __name__ == '__main__':
49 |
50 | is_eager = True
51 | if not is_eager:
52 | disable_eager_execution()
53 | else:
54 | if tf.__version__ < '2.0.0':
55 | tf.compat.v1.enable_eager_execution()
56 | tf.test.main()
57 |
--------------------------------------------------------------------------------
/athena/transform/feats/fbank_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | import os
18 | from pathlib import Path
19 | import numpy as np
20 | import tensorflow as tf
21 | from tensorflow.python.framework.ops import disable_eager_execution
22 | from athena.transform.feats.read_wav import ReadWav
23 | from athena.transform.feats.fbank import Fbank
24 |
25 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
26 |
27 |
28 | class FbankTest(tf.test.TestCase):
29 | def test_fbank(self):
30 | # 16kHz && 8kHz test
31 | wav_path_16k = str(
32 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/sm1_cln.wav")
33 | )
34 | wav_path_8k = str(
35 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/english.wav")
36 | )
37 |
38 | with self.session():
39 | # value test
40 | read_wav = ReadWav.params().instantiate()
41 | input_data, sample_rate = read_wav(wav_path_16k)
42 | fbank = Fbank.params({"delta_delta": False}).instantiate()
43 | fbank_test = fbank(input_data, sample_rate)
44 |             real_fbank_feats = np.array(
45 | [
46 | [3.768338, 4.946218, 6.289874, 6.330853, 6.761764, 6.884573],
47 | [3.803553, 5.450971, 6.547878, 5.796172, 6.397846, 7.242926],
48 | ]
49 | )
50 |             # self.assertAllClose(np.squeeze(fbank_test.eval()[0:2, 0:6, 0]),
51 |             #                     real_fbank_feats, rtol=1e-05, atol=1e-05)
52 | if tf.executing_eagerly():
53 | print(fbank_test.numpy()[0:2, 0:6, 0])
54 | else:
55 | print(fbank_test.eval()[0:2, 0:6, 0])
56 | count = 1
57 |
58 | for wav_file in [wav_path_8k, wav_path_16k]:
59 |
60 | read_wav = ReadWav.params().instantiate()
61 | input_data, sample_rate = read_wav(wav_file)
62 | if tf.executing_eagerly():
63 | print(wav_file, sample_rate.numpy())
64 | else:
65 | print(wav_file, sample_rate.eval())
66 |
67 | conf = {
68 | "delta_delta": True,
69 | "lower_frequency_limit": 100,
70 | "upper_frequency_limit": 0,
71 | }
72 | fbank = Fbank.params(conf).instantiate()
73 | fbank_test = fbank(input_data, sample_rate)
74 | if tf.executing_eagerly():
75 | print(fbank_test.numpy())
76 | else:
77 | print(fbank_test.eval())
78 | print(fbank.num_channels())
79 |
80 | conf = {
81 | "delta_delta": False,
82 | "lower_frequency_limit": 100,
83 | "upper_frequency_limit": 0,
84 | }
85 | fbank = Fbank.params(conf).instantiate()
86 | fbank_test = fbank(input_data, sample_rate)
87 | print(fbank_test)
88 | print(fbank.num_channels())
89 | count += 1
90 | del read_wav
91 | del fbank
92 |
93 |
94 | if __name__ == "__main__":
95 |
96 | is_eager = True
97 | if not is_eager:
98 | disable_eager_execution()
99 | else:
100 | if tf.__version__ < "2.0.0":
101 | tf.compat.v1.enable_eager_execution()
102 | tf.test.main()
103 |
--------------------------------------------------------------------------------
/athena/transform/feats/framepow.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """This module extracts framepow features per frame."""
17 |
18 | import tensorflow as tf
19 | from athena.utils.hparam import HParams
20 | from athena.transform.feats.ops import py_x_ops
21 | from athena.transform.feats.base_frontend import BaseFrontend
22 |
23 |
24 | class Framepow(BaseFrontend):
25 | """
26 | Compute power of every frame in speech. Return a float tensor with
27 | shape (1 * num_frames).
28 | """
29 | def __init__(self, config: dict):
30 | super().__init__(config)
31 |
32 | @classmethod
33 | def params(cls, config=None):
34 | """
35 | Set params.
36 | :param config: contains four optional parameters:
37 | --window_length : Window length in seconds. (float, default = 0.025)
38 | --frame_length : Hop length in seconds. (float, default = 0.010)
39 |             --snip_edges : If 1, the last frame (shorter than window_length)
40 |                            will be cut off. Otherwise, half a frame_length of
41 |                            data will be padded at the edges. (int, default = 1)
42 |             --remove_dc_offset : Subtract mean from waveform on each frame (bool, default = true)
43 |         :return: An object of class HParams, which is a set of hyperparameters as name-value pairs.
44 | """
45 |
46 | window_length = 0.025
47 | frame_length = 0.010
48 | snip_edges = 1
49 | remove_dc_offset = True
50 |
51 | hparams = HParams(cls=cls)
52 | hparams.add_hparam("window_length", window_length)
53 | hparams.add_hparam("frame_length", frame_length)
54 | hparams.add_hparam("snip_edges", snip_edges)
55 | hparams.add_hparam("remove_dc_offset", remove_dc_offset)
56 |
57 | if config is not None:
58 | hparams.parse(config, True)
59 |
60 | return hparams
61 |
62 | def call(self, audio_data, sample_rate):
63 | """
64 |         Calculate the power of every frame in speech.
65 |         :param audio_data: the audio signal from which to compute spectrum.
66 |                            Should be a (1, N) tensor.
67 |         :param sample_rate: [optional] the sample rate of the signal we are
68 |                            working with, default is 16kHz.
69 |         :return: A float tensor of size (1 * num_frames) containing the power
70 |             of every frame in speech.
71 | """
72 |
73 | p = self.config
74 | with tf.name_scope('framepow'):
75 | sample_rate = tf.cast(sample_rate, dtype=float)
76 | framepow = py_x_ops.frame_pow(
77 | audio_data,
78 | sample_rate,
79 | snip_edges=p.snip_edges,
80 | remove_dc_offset=p.remove_dc_offset,
81 | window_length=p.window_length,
82 | frame_length=p.frame_length)
83 |
84 | return tf.squeeze(framepow)
85 |
86 | def dim(self):
87 | return 1
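
Note: for intuition, a rough numpy equivalent of what py_x_ops.frame_pow
computes when snip_edges is 1 (see kernels/framepow.cc later in this
listing); illustrative only, not the op itself:

    import numpy as np

    def frame_pow_ref(audio, sample_rate, window_length=0.025,
                      frame_length=0.010, remove_dc_offset=True):
        """Reference log frame energy (snip_edges == 1 behavior)."""
        win = int(window_length * sample_rate)
        hop = int(frame_length * sample_rate)
        num_frames = (len(audio) - win) // hop + 1
        out = np.zeros(num_frames)
        for n in range(num_frames):
            frame = audio[n * hop:n * hop + win].astype(np.float64)
            if remove_dc_offset:
                frame = frame - frame.mean()
            out[n] = np.log((frame * frame).sum())
        return out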
--------------------------------------------------------------------------------
/athena/transform/feats/framepow_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """This module tests the framepow feature extractor."""
17 |
18 | import os
19 | from pathlib import Path
20 | import numpy as np
21 | import tensorflow as tf
22 | from tensorflow.python.framework.ops import disable_eager_execution
23 | from athena.transform.feats.read_wav import ReadWav
24 | from athena.transform.feats.framepow import Framepow
25 |
26 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
27 |
28 |
29 | class FramePowTest(tf.test.TestCase):
30 | """
31 | Framepow extraction test.
32 | """
33 | def test_framepow(self):
34 | wav_path_16k = str(
35 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/sm1_cln.wav")
36 | )
37 |
38 | with self.session():
39 | read_wav = ReadWav.params().instantiate()
40 | input_data, sample_rate = read_wav(wav_path_16k)
41 | config = {"snip_edges": 1}
42 | framepow = Framepow.params(config).instantiate()
43 | framepow_test = framepow(input_data, sample_rate)
44 |
45 | real_framepow_feats = np.array(
46 | [9.819611, 9.328745, 9.247337, 9.26451, 9.266059]
47 | )
48 |
49 | if tf.executing_eagerly():
50 | self.assertAllClose(
51 | framepow_test.numpy()[0:5],
52 | real_framepow_feats,
53 | rtol=1e-05,
54 | atol=1e-05,
55 | )
56 | print(framepow_test.numpy()[0:5])
57 | else:
58 | self.assertAllClose(
59 | framepow_test.eval()[0:5],
60 | real_framepow_feats,
61 | rtol=1e-05,
62 | atol=1e-05,
63 | )
64 | print(framepow_test.eval()[0:5])
65 |
66 |
67 | if __name__ == "__main__":
68 | is_eager = True
69 | if not is_eager:
70 | disable_eager_execution()
71 | else:
72 | if tf.__version__ < "2.0.0":
73 | tf.compat.v1.enable_eager_execution()
74 | tf.test.main()
75 |
--------------------------------------------------------------------------------
/athena/transform/feats/mfcc_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """This module tests the MFCC feature extractor."""
17 |
18 | import os
19 | from pathlib import Path
20 | import numpy as np
21 | import tensorflow as tf
22 | from tensorflow.python.framework.ops import disable_eager_execution
23 | from athena.transform.feats.read_wav import ReadWav
24 | from athena.transform.feats.mfcc import Mfcc
25 |
26 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
27 |
28 |
29 | class MfccTest(tf.test.TestCase):
30 | """
31 | MFCC extraction test.
32 | """
33 | def test_mfcc(self):
34 | wav_path_16k = str(
35 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/sm1_cln.wav")
36 | )
37 |
38 | with self.session():
39 | read_wav = ReadWav.params().instantiate()
40 | input_data, sample_rate = read_wav(wav_path_16k)
41 | config = {"use_energy": True}
42 | mfcc = Mfcc.params(config).instantiate()
43 | mfcc_test = mfcc(input_data, sample_rate)
44 |
45 | real_mfcc_feats = np.array(
46 | [
47 | [9.819611, -30.58736, -7.088838, -10.67966, -1.646479, -4.36086],
48 | [9.328745, -30.73371, -6.128432, -7.930599, 3.208357, -1.086456],
49 | ]
50 | )
51 |
52 | if tf.executing_eagerly():
53 | self.assertAllClose(
54 | mfcc_test.numpy()[0, 0:2, 0:6],
55 | real_mfcc_feats,
56 | rtol=1e-05,
57 | atol=1e-05,
58 | )
59 | else:
60 | self.assertAllClose(
61 | mfcc_test.eval()[0, 0:2, 0:6],
62 | real_mfcc_feats,
63 | rtol=1e-05,
64 | atol=1e-05,
65 | )
66 |
67 |
68 | if __name__ == "__main__":
69 |
70 | is_eager = True
71 | if not is_eager:
72 | disable_eager_execution()
73 | else:
74 | if tf.__version__ < "2.0.0":
75 | tf.compat.v1.enable_eager_execution()
76 | tf.test.main()
77 |
--------------------------------------------------------------------------------
/athena/transform/feats/ops/Makefile:
--------------------------------------------------------------------------------
1 | # Find where we're running from, so we can store generated files here.
2 |
3 | ifeq ($(origin MAKEFILE_DIR), undefined)
4 | MAKEFILE_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
5 | MAIN_ROOT := $(realpath $(MAKEFILE_DIR)/../../)
6 | endif
7 |
8 | #$(info $(MAKEFILE_DIR))
9 | #$(info $(MAIN_ROOT))
10 |
11 | CXX := g++
12 | NVCC := nvcc
13 | PYTHON_BIN_PATH= python3
14 | CC :=
15 | AR :=
16 | CXXFLAGS :=
17 | LDFLAGS :=
18 | STDLIB :=
19 |
20 | # Try to figure out the host system
21 | HOST_OS :=
22 | ifeq ($(OS),Windows_NT)
23 | HOST_OS = windows
24 | else
25 | UNAME_S := $(shell uname -s)
26 | ifeq ($(UNAME_S),Linux)
27 | HOST_OS := linux
28 | endif
29 | ifeq ($(UNAME_S),Darwin)
30 | HOST_OS := ios
31 | endif
32 | endif
33 |
34 | #HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi)
35 | HOST_ARCH=x86_64
36 | TARGET := $(HOST_OS)
37 | TARGET_ARCH := $(HOST_ARCH)
38 |
39 | GENDIR := $(MAKEFILE_DIR)/gen/
40 | TGTDIR := $(GENDIR)$(TARGET)_$(TARGET_ARCH)/
41 | OBJDIR := $(TGTDIR)obj/
42 | BINDIR := $(TGTDIR)bin/
43 | LIBDIR := $(TGTDIR)lib/
44 |
45 |
46 | TF_CFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
47 | # Fix TF LDFLAGS issue on macOS.
48 | TF_LFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))' | sed "s/-l:libtensorflow_framework.1.dylib/-ltensorflow_framework.1/")
49 | #TF_INCLUDES := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')
50 | TF_LIBS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')
51 | CXXFLAGS += -fPIC -shared -O2 -std=c++11 -DFEATURE_VERSION=\"$(shell git rev-parse --short HEAD)\" $(TF_CFLAGS)
52 | INCLUDES := -I$(MAIN_ROOT) \
53 | -I$(MAIN_ROOT)/feats/ops \
54 |
55 | LDFLAGS += $(TF_LFLAGS)
56 |
57 | CORE_CC_EXCLUDE_SRCS := \
58 | $(wildcard kernels/*test.cc) \
59 | $(wildcard kernels/*test_util.cc)
60 |
61 | # src and tgts
62 | LIB_SRCS_ALL := $(wildcard kernels/*.cc)
63 | LIB_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(LIB_SRCS_ALL))
64 | LIB_OBJS := $(addprefix $(OBJDIR), $(patsubst %.cc, %.o, $(patsubst %.c, %.o, $(LIB_SRCS))))
65 |
66 | # lib
67 | SHARED_LIB := x_ops.so
68 |
69 | TEST_SRC := $(wildcard kernels/*_test.cc)
70 | TEST_OBJ := $(addprefix $(OBJDIR), $(patsubst %.cc, $(OBJS_DIR)%.o, $(TEST_SRC)))
71 | TEST_BIN := $(addprefix $(BINDIR), $(patsubst %.cc, $(OBJS_DIR)%.bin, $(TEST_SRC)))
72 | #TEST_BIN := $(BINDIR)test
73 |
74 | all: $(SHARED_LIB) $(TEST_BIN)
75 |
76 | $(OBJDIR)%.o: %.cc
77 | @mkdir -p $(dir $@)
78 | $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ $(LDFLAGS)
79 |
80 | $(SHARED_LIB): $(LIB_OBJS)
81 | @mkdir -p $(dir $@)
82 | $(CXX) -fPIC -shared -o $@ $^ $(STDLIB) $(LDFLAGS)
83 |
84 | $(STATIC_LIB): $(LIB_OBJS)
85 | @mkdir -p $(dir $@)
86 | $(AR) crsv $@ $^
87 |
88 | ${TEST_BIN}: $(TEST_OBJ) $(STATIC_LIB)
89 | @mkdir -p $(dir $@)
90 | $(CXX) $(LDFLAGS) $^ -o $@ $(STATIC_LIB)
91 |
92 | .PHONY: clean
93 | clean:
94 | -rm -rf $(GENDIR)
95 | -rm -f x_ops.so
96 |
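Note: the resulting x_ops.so is loaded by py_x_ops.py through TensorFlow's
custom-op mechanism, roughly as below (the exact path resolution in
py_x_ops.py may differ):

    import os
    import tensorflow as tf

    # load the kernels built by the Makefile above; the path is illustrative
    so_path = os.path.join(os.path.dirname(__file__), "x_ops.so")
    gen_x_ops = tf.load_op_library(so_path)

    # registered ops become Python functions, e.g.
    # gen_x_ops.fbank(...), gen_x_ops.frame_pow(...), gen_x_ops.delta_delta(...)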
--------------------------------------------------------------------------------
/athena/transform/feats/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/delta_delta.cc:
--------------------------------------------------------------------------------
1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | ==============================================================================*/
16 |
17 | #include "kernels/delta_delta.h"
18 |
19 | #include <vector>
20 |
21 | namespace delta {
22 |
23 | const int kOrder = 2;
24 | const int kWindow = 2;
25 |
26 | DeltaDelta::DeltaDelta()
27 | : initialized_(false), order_(kOrder), window_(kWindow) {}
28 |
29 | DeltaDelta::~DeltaDelta() {
30 | for (int i = 0; i < scales_.size(); i++)
31 |     std::vector<double>().swap(scales_[i]);
32 | (scales_).clear();
33 | }
34 |
35 | bool DeltaDelta::Initialize(int order, int window) {
36 | if (order < 0 || order > 1000) {
37 | LOG(ERROR) << "order must be greater than zero and less than 1000.";
38 | return false;
39 | }
40 | if (window < 0 || window > 1000) {
41 | LOG(ERROR) << "window must be greater than zero and less than 1000.";
42 | return false;
43 | }
44 | order_ = order;
45 | window_ = window;
46 |
47 | scales_.resize(order_ + 1);
48 | scales_[0].resize(1);
49 | scales_[0][0] =
50 | 1.0; // trival window for 0th order delta [i.e. baseline feats]
51 |
52 | for (int i = 1; i <= order_; i++) {
53 |     std::vector<double> &prev_scales = scales_[i - 1], &cur_scales = scales_[i];
54 |
55 | int window = window_;
56 | if (window == 0) {
57 | LOG(ERROR) << "window must not be zero.";
58 | return false;
59 | }
60 |     int prev_offset = (static_cast<int>(prev_scales.size() - 1)) / 2,
61 | cur_offset = prev_offset + window;
62 | cur_scales.resize(prev_scales.size() + 2 * window);
63 |
64 | double normalizer = 0.0;
65 | for (int j = -window; j <= window; j++) {
66 | normalizer += j * j;
67 | for (int k = -prev_offset; k <= prev_offset; k++) {
68 | cur_scales[j + k + cur_offset] +=
69 |             static_cast<double>(j) * prev_scales[k + prev_offset];
70 | }
71 | }
72 |
73 | for (int i = 0; i < cur_scales.size(); i++) {
74 | cur_scales[i] *= (1.0 / normalizer);
75 | }
76 | }
77 |
78 | initialized_ = true;
79 | return initialized_;
80 | }
81 |
82 | // process one frame per time
83 | void DeltaDelta::Compute(const Tensor& input_feats, int frame,
84 |                          std::vector<double>* output) const {
85 | if (!initialized_) {
86 | LOG(ERROR) << "DeltaDelta not initialized.";
87 | return;
88 | }
89 |
90 | int num_frames = input_feats.dim_size(0);
91 | int feat_dim = input_feats.dim_size(1);
92 | int output_dim = feat_dim * (order_ + 1);
93 |
94 | output->resize(output_dim);
95 |   auto input = input_feats.matrix<float>();
96 |
97 | for (int i = 0; i <= order_; i++) {
98 |     const std::vector<double>& scales = scales_[i];
99 | int max_offset = (scales.size() - 1) / 2;
100 | // e.g. max_offset=2, (-2, 2)
101 | for (int j = -max_offset; j <= max_offset; j++) {
102 |       // if asked to read past an edge, clamp to the first/last frame
103 | int offset_frame = frame + j;
104 | if (offset_frame < 0)
105 | offset_frame = 0;
106 | else if (offset_frame >= num_frames)
107 | offset_frame = num_frames - 1;
108 |
109 |       // scales[0] for `-max_offset` frame
110 | double scale = scales[j + max_offset];
111 | if (scale != 0.0) {
112 | for (int k = 0; k < feat_dim; k++) {
113 | (*output)[i + k * (order_ + 1)] += input(offset_frame, k) * scale;
114 | }
115 | }
116 | }
117 | }
118 | return;
119 | }
120 |
121 | } // namespace delta
122 |
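Note: the recursion in Initialize() above is the standard Kaldi-style
construction: each order's regression window is the previous window convolved
with a normalized ramp [-window, ..., window] / sum(j*j). A numpy sketch of
the same construction:

    import numpy as np

    def delta_scales(order=2, window=2):
        """Regression windows for orders 0..order; order 0 is just [1.0]."""
        scales = [np.array([1.0])]
        for _ in range(order):
            ramp = np.arange(-window, window + 1, dtype=np.float64)
            kernel = ramp / (ramp * ramp).sum()  # normalizer = sum of j*j
            scales.append(np.convolve(scales[-1], kernel))
        return scales

    # delta_scales()[1] -> [-0.2, -0.1, 0.0, 0.1, 0.2], the classic delta window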
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/delta_delta.h:
--------------------------------------------------------------------------------
1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | ==============================================================================*/
16 |
17 | #ifndef DELTA_LAYERS_OPS_KERNELS_DELTA_DELTA_H_
18 | #define DELTA_LAYERS_OPS_KERNELS_DELTA_DELTA_H_
19 |
20 | #include <vector>
21 |
22 | #include "tensorflow/core/framework/op_kernel.h"
23 | #include "tensorflow/core/framework/tensor.h"
24 | #include "tensorflow/core/platform/logging.h"
25 |
26 | using namespace tensorflow; // NOLINT
27 |
28 | namespace delta {
29 |
30 | // This class provides a low-level function to compute delta features.
31 | // The function takes as input a matrix of features and a frame index
32 | // that it should compute the deltas on. It puts its output in an object
33 | // of type VectorBase, of size (original-feature-dimension) * (opts.order+1).
34 | // This is not the most efficient way to do the computation, but it's
35 | // state-free and thus easier to understand
36 | class DeltaDelta {
37 | public:
38 | DeltaDelta();
39 | ~DeltaDelta();
40 |
41 | bool Initialize(int order, int window);
42 |
43 | // Input is a single feature frame. Output is populated with
44 | // the feature, delta, delta-delta values.
45 | void Compute(const Tensor& input_feats, int frame,
46 |                std::vector<double>* output) const;
47 |
48 | void set_order(int order) {
49 | CHECK(!initialized_) << "Set order before calling Initialize.";
50 | order_ = order;
51 | }
52 |
53 | void set_window(int window) {
54 | CHECK(!initialized_) << "Set window before calling Initialize.";
55 | window_ = window;
56 | }
57 |
58 | private:
59 | int order_;
60 | // e.g. 2; controls window size (window size is 2*window + 1)
61 | // the behavior at the edges is to replicate the first or last frame.
62 | // this is not configurable.
63 | int window_;
64 |
65 | // a scaling window for each of the orders, including zero:
66 | // multiply the features for each dimension by this window.
67 |   std::vector<std::vector<double> > scales_;
68 |
69 | bool initialized_;
70 | TF_DISALLOW_COPY_AND_ASSIGN(DeltaDelta);
71 | }; // class DeltaDelta
72 |
73 | } // namespace delta
74 |
75 | #endif // DELTA_LAYERS_OPS_KERNELS_DELTA_DELTA_H_
76 |
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/delta_delta_op.cc:
--------------------------------------------------------------------------------
1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | ==============================================================================*/
16 |
17 | // See docs in ../ops/audio_ops.cc
18 | #include "kernels/delta_delta.h"
19 | #include "tensorflow/core/framework/op_kernel.h"
20 | #include "tensorflow/core/framework/register_types.h"
21 | #include "tensorflow/core/framework/tensor.h"
22 | #include "tensorflow/core/framework/tensor_shape.h"
23 | #include "tensorflow/core/framework/types.h"
24 | #include "tensorflow/core/lib/core/status.h"
25 |
26 | // https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/CXX11/src/Tensor/README.md
27 |
28 | namespace delta {
29 |
30 | // add first and second order derivatives
31 | class DeltaDeltaOp : public OpKernel {
32 | public:
33 | explicit DeltaDeltaOp(OpKernelConstruction* context) : OpKernel(context) {
34 | OP_REQUIRES_OK(context, context->GetAttr("order", &order_));
35 | OP_REQUIRES_OK(context, context->GetAttr("window", &window_));
36 | }
37 |
38 | void Compute(OpKernelContext* context) override {
39 | const Tensor& feats = context->input(0);
40 | OP_REQUIRES(context, feats.dims() == 2,
41 | errors::InvalidArgument("features must be 2-dimensional",
42 | feats.shape().DebugString()));
43 | // feats shape [time, feat dim]
44 | const int time = feats.dim_size(0); // num frames
45 | const int feat_dim = feats.dim_size(1);
46 | const int output_dim = feat_dim * (order_ + 1);
47 |
48 | DeltaDelta delta;
49 | OP_REQUIRES(
50 | context, delta.Initialize(order_, window_),
51 | errors::InvalidArgument("DeltaDelta initialization failed for order ",
52 | order_, " and window ", window_));
53 |
54 | Tensor* output_tensor = nullptr;
55 | OP_REQUIRES_OK(context,
56 | context->allocate_output(0, TensorShape({time, output_dim}),
57 | &output_tensor));
58 |
59 | // TType::Tensor feats_t = feats.tensor;
60 |     float* output_flat = output_tensor->flat<float>().data();
61 |
62 | for (int t = 0; t < time; t++) {
63 | float* row = output_flat + t * output_dim;
64 |
65 | // add delta-delta
66 |       std::vector<double> out;
67 | delta.Compute(feats, t, &out);
68 |
69 | // fill output buffer
70 | DCHECK_EQ(output_dim, out.size());
71 | for (int i = 0; i < output_dim; i++) {
72 |         row[i] = static_cast<float>(out[i]);
73 | }
74 | }
75 | }
76 |
77 | private:
78 | int order_;
79 | int window_;
80 | };
81 |
82 | REGISTER_KERNEL_BUILDER(Name("DeltaDelta").Device(DEVICE_CPU), DeltaDeltaOp);
83 |
84 | } // namespace delta
85 |
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/fbank.cc:
--------------------------------------------------------------------------------
1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | ==============================================================================*/
16 |
17 | #include "kernels/fbank.h"
18 |
19 | #include <math.h>
20 |
21 | #include "tensorflow/core/platform/logging.h"
22 |
23 | namespace delta {
24 |
25 | const double kDefaultUpperFrequencyLimit = 4000;
26 | const double kDefaultLowerFrequencyLimit = 20;
27 | const double kFilterbankFloor = 1e-12;
28 | const int kDefaultFilterbankChannelCount = 40;
29 |
30 | Fbank::Fbank()
31 | : initialized_(false),
32 | lower_frequency_limit_(kDefaultLowerFrequencyLimit),
33 | upper_frequency_limit_(kDefaultUpperFrequencyLimit),
34 | filterbank_channel_count_(kDefaultFilterbankChannelCount) {}
35 |
36 | Fbank::~Fbank() {}
37 |
38 | bool Fbank::Initialize(int input_length, double input_sample_rate) {
39 | if (input_length < 1) {
40 | LOG(ERROR) << "Input length must be positive.";
41 | return false;
42 | }
43 | input_length_ = input_length;
44 |
45 | bool initialized = mel_filterbank_.Initialize(
46 | input_length, input_sample_rate, filterbank_channel_count_,
47 | lower_frequency_limit_, upper_frequency_limit_);
48 | initialized_ = initialized;
49 | return initialized;
50 | }
51 |
52 | void Fbank::Compute(const std::vector<double>& spectrogram_frame,
53 |                     std::vector<double>* output) const {
54 | if (!initialized_) {
55 | LOG(ERROR) << "Fbank not initialized.";
56 | return;
57 | }
58 |
59 | output->resize(filterbank_channel_count_);
60 | mel_filterbank_.Compute(spectrogram_frame, output);
61 | for (int i = 0; i < output->size(); ++i) {
62 | double val = (*output)[i];
63 | if (val < kFilterbankFloor) {
64 | val = kFilterbankFloor;
65 | }
66 | (*output)[i] = log(val);
67 | }
68 | }
69 |
70 | } // namespace delta
71 |
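Note: the flooring loop above is just a floored log; a one-line numpy
equivalent of the per-channel operation:

    import numpy as np

    def log_fbank(mel_energies, floor=1e-12):  # floor == kFilterbankFloor
        return np.log(np.maximum(mel_energies, floor))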
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/fbank.h:
--------------------------------------------------------------------------------
1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | ==============================================================================*/
16 |
17 | #ifndef DELTA_LAYERS_OPS_KERNELS_FBANK_H_
18 | #define DELTA_LAYERS_OPS_KERNELS_FBANK_H_
19 |
20 | #include <vector>  // NOLINT
21 |
22 | #include "kernels/mfcc_mel_filterbank.h"
23 |
24 | #include "tensorflow/core/framework/op_kernel.h"
25 | #include "tensorflow/core/platform/logging.h"
26 |
27 | using namespace tensorflow; // NOLINT
28 |
29 | namespace delta {
30 |
31 | // logfbank
32 | class Fbank {
33 | public:
34 | Fbank();
35 | ~Fbank();
36 | bool Initialize(int input_length, double input_sample_rate);
37 | // Input is a single squared-magnitude spectrogram frame. The input spectrum
38 | // is converted to linear magnitude and weighted into bands using a
39 | // triangular mel filterbank. Output is populated with the lowest
40 | // fbank_channel_count
41 | // of these values.
42 |   void Compute(const std::vector<double>& spectrogram_frame,
43 |                std::vector<double>* output) const;
44 |
45 | void set_upper_frequency_limit(double upper_frequency_limit) {
46 | CHECK(!initialized_) << "Set frequency limits before calling Initialize.";
47 | upper_frequency_limit_ = upper_frequency_limit;
48 | }
49 |
50 | void set_lower_frequency_limit(double lower_frequency_limit) {
51 | CHECK(!initialized_) << "Set frequency limits before calling Initialize.";
52 | lower_frequency_limit_ = lower_frequency_limit;
53 | }
54 |
55 | void set_filterbank_channel_count(int filterbank_channel_count) {
56 | CHECK(!initialized_) << "Set channel count before calling Initialize.";
57 | filterbank_channel_count_ = filterbank_channel_count;
58 | }
59 |
60 | private:
61 | MfccMelFilterbank mel_filterbank_;
62 | int input_length_;
63 | bool initialized_;
64 | double lower_frequency_limit_;
65 | double upper_frequency_limit_;
66 | int filterbank_channel_count_;
67 | TF_DISALLOW_COPY_AND_ASSIGN(Fbank);
68 | }; // class Fbank
69 |
70 | } // namespace delta
71 |
72 | #endif // DELTA_LAYERS_OPS_KERNELS_FBANK_H_
73 |
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/fbank_op_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """ fbank op unittest"""
17 | import numpy as np
18 | import tensorflow as tf
19 |
20 | from athena.transform.feats.ops import py_x_ops
21 |
22 |
23 | class FbankOpTest(tf.test.TestCase):
24 | """ fbank op unittest"""
25 |
26 | def setUp(self):
27 | """ setup """
28 |
29 | def tearDown(self):
30 |         """ tear down """
31 |
32 | def test_fbank(self):
33 | """ test fbank op"""
34 | with self.session():
35 | data = np.arange(513)
36 | spectrogram = tf.constant(data[None, None, :], dtype=tf.float32)
37 | sample_rate = tf.constant(22050, tf.int32)
38 | output = py_x_ops.fbank(
39 | spectrogram, sample_rate, filterbank_channel_count=20
40 | )
41 |
42 | output_true = np.array(
43 | [
44 | 1.887894,
45 | 2.2693727,
46 | 2.576507,
47 | 2.8156495,
48 | 3.036504,
49 | 3.2296343,
50 | 3.4274294,
51 | 3.5987632,
52 | 3.771217,
53 | 3.937401,
54 | 4.0988584,
55 | 4.2570987,
56 | 4.4110703,
57 | 4.563661,
58 | 4.7140336,
59 | 4.8626432,
60 | 5.009346,
61 | 5.1539173,
62 | 5.2992935,
63 | 5.442024,
64 | ]
65 | )
66 | self.assertEqual(tf.rank(output).eval(), 3)
67 | self.assertEqual(output.shape, (1, 1, 20))
68 | self.assertAllClose(output.eval(), output_true[None, None, :])
69 |
70 |
71 | if __name__ == "__main__":
72 | tf.test.main()
73 |
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/framepow.cc:
--------------------------------------------------------------------------------
1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | ==============================================================================*/
16 |
17 | #include "kernels/framepow.h"
18 | #include <math.h>
19 | #include <stdio.h>
20 | #include <stdlib.h>
21 | #include <string.h>
22 |
23 | namespace delta {
24 | const float window_length_sec = 0.025;
25 | const float frame_length_sec = 0.010;
26 |
27 | FramePow::FramePow() {
28 | window_length_sec_ = window_length_sec;
29 | frame_length_sec_ = frame_length_sec;
30 | i_snip_edges = 1;
31 | i_remove_dc_offset = true;
32 | pf_FrmEng = NULL;
33 | }
34 |
35 | FramePow::~FramePow() { free(pf_FrmEng); }
36 |
37 | void FramePow::set_window_length_sec(float window_length_sec) {
38 | window_length_sec_ = window_length_sec;
39 | }
40 |
41 | void FramePow::set_frame_length_sec(float frame_length_sec) {
42 | frame_length_sec_ = frame_length_sec;
43 | }
44 |
45 | void FramePow::set_snip_edges(int snip_edges) { i_snip_edges = snip_edges; }
46 |
47 | void FramePow::set_remove_dc_offset(bool remove_dc_offset) {
48 | i_remove_dc_offset = remove_dc_offset;
49 | }
50 |
51 | int FramePow::init_eng(int input_size, float sample_rate) {
52 | f_SamRat = sample_rate;
53 |   i_WinLen = static_cast<int>(window_length_sec_ * f_SamRat);
54 |   i_FrmLen = static_cast<int>(frame_length_sec_ * f_SamRat);
55 | if (i_snip_edges == 1)
56 | i_NumFrm = (input_size - i_WinLen) / i_FrmLen + 1;
57 | else
58 | i_NumFrm = (input_size + i_FrmLen / 2) / i_FrmLen;
59 |
60 |   pf_FrmEng = static_cast<float*>(malloc(sizeof(float) * i_NumFrm));
61 |
62 | return 1;
63 | }
64 |
65 | int FramePow::proc_eng(const float* mic_buf, int input_size) {
66 | int i, n, k;
67 |   float* win = static_cast<float*>(malloc(sizeof(float) * i_WinLen));
68 |
69 | for (n = 0; n < i_NumFrm; n++) {
70 | pf_FrmEng[n] = 0.0;
71 | float sum = 0.0;
72 | float energy = 0.0;
73 | for (k = 0; k < i_WinLen; k++) {
74 | int index = n * i_FrmLen + k;
75 | if (index < input_size)
76 | win[k] = mic_buf[index];
77 | else
78 | win[k] = 0.0f;
79 | sum += win[k];
80 | }
81 |
82 | if (i_remove_dc_offset == true) {
83 | float mean = sum / i_WinLen;
84 | for (int l = 0; l < i_WinLen; l++) win[l] -= mean;
85 | }
86 |
87 | for (i = 0; i < i_WinLen; i++) {
88 | energy += win[i] * win[i];
89 | }
90 |
91 | pf_FrmEng[n] = log(energy);
92 |
93 | }
94 |
95 | free(win);
96 | return 1;
97 | }
98 |
99 | int FramePow::get_eng(float* output) {
100 | memcpy(output, pf_FrmEng, sizeof(float) * i_NumFrm);
101 |
102 | return 1;
103 | }
104 |
105 | int FramePow::write_eng() {
106 | FILE* fp;
107 | fp = fopen("frame_energy.txt", "w");
108 | int n;
109 | for (n = 0; n < i_NumFrm; n++) {
110 | fprintf(fp, "%4.6f\n", pf_FrmEng[n]);
111 | }
112 | fclose(fp);
113 | return 1;
114 | }
115 | } // namespace delta
116 |
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/framepow.h:
--------------------------------------------------------------------------------
1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | ==============================================================================*/
16 |
17 | #ifndef DELTA_LAYERS_OPS_KERNELS_FRAMEPOW_H_
18 | #define DELTA_LAYERS_OPS_KERNELS_FRAMEPOW_H_
19 |
20 | #include "tensorflow/core/framework/op_kernel.h"
21 | #include "tensorflow/core/platform/logging.h"
22 |
23 | using namespace tensorflow; // NOLINT
24 |
25 | namespace delta {
26 | class FramePow {
27 | private:
28 | float window_length_sec_;
29 | float frame_length_sec_;
30 | int i_snip_edges;
31 | bool i_remove_dc_offset;
32 |
33 | float f_SamRat;
34 | int i_WinLen;
35 | int i_FrmLen;
36 | int i_NumFrm;
37 |
38 | float* pf_FrmEng;
39 |
40 | public:
41 | FramePow();
42 |
43 | ~FramePow();
44 |
45 | void set_window_length_sec(float window_length_sec);
46 |
47 | void set_frame_length_sec(float frame_length_sec);
48 |
49 | void set_snip_edges(int snip_edges);
50 |
51 | void set_remove_dc_offset(bool remove_dc_offset);
52 |
53 | int init_eng(int input_size, float sample_rate);
54 |
55 | int proc_eng(const float* mic_buf, int input_size);
56 |
57 | int get_eng(float* output);
58 |
59 | int write_eng();
60 |
61 | TF_DISALLOW_COPY_AND_ASSIGN(FramePow);
62 | };
63 | } // namespace delta
64 | #endif // DELTA_LAYERS_OPS_KERNELS_FRAMEPOW_H_
65 |
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/framepow_op.cc:
--------------------------------------------------------------------------------
1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | ==============================================================================*/
16 |
17 | #include "kernels/framepow.h"
18 |
19 | #include "tensorflow/core/framework/op_kernel.h"
20 | #include "tensorflow/core/framework/register_types.h"
21 | #include "tensorflow/core/framework/tensor.h"
22 | #include "tensorflow/core/framework/tensor_shape.h"
23 | #include "tensorflow/core/framework/types.h"
24 | #include "tensorflow/core/lib/core/status.h"
25 |
26 | namespace delta {
27 | class FramePowOp : public OpKernel {
28 | public:
29 | explicit FramePowOp(OpKernelConstruction* context) : OpKernel(context) {
30 | OP_REQUIRES_OK(context, context->GetAttr("window_length", &window_length_));
31 | OP_REQUIRES_OK(context, context->GetAttr("frame_length", &frame_length_));
32 | OP_REQUIRES_OK(context, context->GetAttr("snip_edges", &snip_edges_));
33 | OP_REQUIRES_OK(context,
34 | context->GetAttr("remove_dc_offset", &remove_dc_offset_));
35 | }
36 |
37 | void Compute(OpKernelContext* context) override {
38 | const Tensor& input_tensor = context->input(0);
39 | OP_REQUIRES(context, input_tensor.dims() == 1,
40 | errors::InvalidArgument("input signal must be 1-dimensional",
41 | input_tensor.shape().DebugString()));
42 |
43 | const Tensor& sample_rate_tensor = context->input(1);
44 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(sample_rate_tensor.shape()),
45 | errors::InvalidArgument(
46 | "Input sample rate should be a scalar tensor, got ",
47 | sample_rate_tensor.shape().DebugString(), " instead."));
48 |     const float sample_rate = sample_rate_tensor.scalar<float>()();
49 |
50 | // shape
51 | const int L = input_tensor.dim_size(0);
52 | FramePow cls_eng;
53 | cls_eng.set_window_length_sec(window_length_);
54 | cls_eng.set_frame_length_sec(frame_length_);
55 | cls_eng.set_snip_edges(snip_edges_);
56 | cls_eng.set_remove_dc_offset(remove_dc_offset_);
57 | OP_REQUIRES(context, cls_eng.init_eng(L, sample_rate),
58 | errors::InvalidArgument(
59 | "framepow_class initialization failed for length ", L,
60 | " and sample rate ", sample_rate));
61 |
62 | Tensor* output_tensor = nullptr;
63 | int i_WinLen = static_cast<int>(window_length_ * sample_rate);
64 | int i_FrmLen = static_cast<int>(frame_length_ * sample_rate);
65 | int i_NumFrm = (L - i_WinLen) / i_FrmLen + 1;  // default: count only full windows
66 | if (snip_edges_ == 2) i_NumFrm = (L + i_FrmLen / 2) / i_FrmLen;  // keep edge frames
67 | if (i_NumFrm < 1) i_NumFrm = 1;
68 | OP_REQUIRES_OK(context, context->allocate_output(
69 | 0, TensorShape({1, i_NumFrm}), &output_tensor));
70 |
71 | const float* input_flat = input_tensor.flat<float>().data();
72 | float* output_flat = output_tensor->flat<float>().data();
73 |
74 | int ret;
75 | ret = cls_eng.proc_eng(input_flat, L);
76 | ret = cls_eng.get_eng(output_flat);
77 | }
78 |
79 | private:
80 | float window_length_;
81 | float frame_length_;
82 | int snip_edges_;
83 | bool remove_dc_offset_;
84 | };
85 |
86 | REGISTER_KERNEL_BUILDER(Name("FramePow").Device(DEVICE_CPU), FramePowOp);
87 |
88 | } // namespace delta
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/mfcc_dct.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #include "kernels/mfcc_dct.h"
17 |
18 | #include <math.h>
19 |
20 | #include "tensorflow/core/platform/logging.h"
21 |
22 | namespace delta {
23 |
24 | const float kDefaultCepstralLifter = 22;
25 | const int kDefaultCoefficientCount = 13;
26 |
27 | MfccDct::MfccDct()
28 | : initialized_(false),
29 | coefficient_count_(kDefaultCoefficientCount),
30 | cepstral_lifter_(kDefaultCepstralLifter) {}
31 |
32 | bool MfccDct::Initialize(int input_length, int coefficient_count) {
33 | coefficient_count_ = coefficient_count;
34 | input_length_ = input_length;
35 |
36 | if (coefficient_count_ < 1) {
37 | LOG(ERROR) << "Coefficient count must be positive.";
38 | return false;
39 | }
40 |
41 | if (input_length < 1) {
42 | LOG(ERROR) << "Input length must be positive.";
43 | return false;
44 | }
45 |
46 | if (coefficient_count_ > input_length_) {
47 | LOG(ERROR) << "Coefficient count must be less than or equal to "
48 | << "input length.";
49 | return false;
50 | }
51 |
52 | cosines_.resize(coefficient_count_);
53 | double fnorm = sqrt(2.0 / input_length_);
54 | // Some platforms don't have M_PI, so define a local constant here.
55 | const double pi = std::atan(1) * 4;
56 | double arg = pi / input_length_;
57 | for (int i = 0; i < coefficient_count_; ++i) {
58 | cosines_[i].resize(input_length_);
59 | for (int j = 0; j < input_length_; ++j) {
60 | cosines_[i][j] = fnorm * cos(i * arg * (j + 0.5));  // orthonormal DCT-II basis
61 | }
62 | }
63 |
64 | lifter_coeffs_.resize(coefficient_count_);
65 | for (int j = 0; j < coefficient_count_; ++j)
66 | lifter_coeffs_[j] =
67 | 1.0 + 0.5 * cepstral_lifter_ * sin(PI * j / cepstral_lifter_);  // HTK-style cepstral liftering
68 |
69 | initialized_ = true;
70 | return true;
71 | }
72 |
73 | void MfccDct::set_coefficient_count(int coefficient_count) {
74 | coefficient_count_ = coefficient_count;
75 | }
76 |
77 | void MfccDct::set_cepstral_lifter(float cepstral_lifter) {
78 | cepstral_lifter_ = cepstral_lifter;
79 | }
80 |
81 | void MfccDct::Compute(const std::vector<double> &input,
82 | std::vector<double> *output) const {
83 | if (!initialized_) {
84 | LOG(ERROR) << "DCT not initialized.";
85 | return;
86 | }
87 |
88 | output->resize(coefficient_count_);
89 | int length = input.size();
90 | if (length > input_length_) {
91 | length = input_length_;
92 | }
93 |
94 | double res;
95 | for (int i = 0; i < coefficient_count_; ++i) {
96 | double sum = 0.0;
97 | for (int j = 0; j < length; ++j) {
98 | sum += cosines_[i][j] * input[j];
99 | }
100 | res = sum;
101 | if (cepstral_lifter_ != 0) res *= lifter_coeffs_[i];
102 | (*output)[i] = res;
103 | }
104 | }
105 |
106 | } // namespace delta
107 |
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/mfcc_dct.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | // Basic minimal DCT class for MFCC speech processing.
17 |
18 | #ifndef DELTA_LAYERS_OPS_KERNELS_MFCC_DCT_H_  // NOLINT
19 | #define DELTA_LAYERS_OPS_KERNELS_MFCC_DCT_H_  // NOLINT
20 |
21 | #include <vector>
22 |
23 | #include "tensorflow/core/framework/op_kernel.h"
24 | #include "tensorflow/core/platform/logging.h"
25 | #include "kernels/support_functions.h"
26 |
27 | using namespace tensorflow; // NOLINT
28 | #define PI (3.141592653589793)
29 |
30 | namespace delta {
31 |
32 | class MfccDct {
33 | public:
34 | MfccDct();
35 | bool Initialize(int input_length, int coefficient_count);
36 | void Compute(const std::vector<double>& input,
37 | std::vector<double>* output) const;
38 | void set_coefficient_count(int coefficient_count);
39 | void set_cepstral_lifter(float cepstral_lifter);
40 |
41 | private:
42 | bool initialized_;
43 | int coefficient_count_;
44 | float cepstral_lifter_;
45 | int input_length_;
46 | std::vector<std::vector<double> > cosines_;
47 | std::vector<double> lifter_coeffs_;
48 | TF_DISALLOW_COPY_AND_ASSIGN(MfccDct);
49 | };
50 |
51 | } // namespace delta
52 |
53 | #endif // DELTA_LAYERS_OPS_KERNELS_MFCC_DCT_H_
54 |
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/mfcc_mel_filterbank.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | // Basic class for applying a mel-scale mapping to a power spectrum.
17 |
18 | #ifndef DELTA_LAYERS_OPS_KERNELS_MFCC_MEL_FILTERBANK_H_
19 | #define DELTA_LAYERS_OPS_KERNELS_MFCC_MEL_FILTERBANK_H_
20 |
21 | #include <vector>
22 |
23 | #include "tensorflow/core/framework/op_kernel.h"
24 |
25 | namespace tensorflow {
26 |
27 | class MfccMelFilterbank {
28 | public:
29 | MfccMelFilterbank();
30 | ~MfccMelFilterbank();
31 | bool Initialize(int input_length, // Number of unique FFT bins fftsize/2+1.
32 | double input_sample_rate, int output_channel_count,
33 | double lower_frequency_limit, double upper_frequency_limit);
34 |
35 | // Takes a squared-magnitude spectrogram slice as input, computes a
36 | // triangular-mel-weighted linear-magnitude filterbank, and places the result
37 | // in output.
38 | void Compute(const std::vector<double>& input,
39 | std::vector<double>* output) const;
40 |
41 | private:
42 | double FreqToMel(double freq) const;
43 | bool initialized_;
44 | int num_channels_;
45 | double sample_rate_;
46 | int input_length_;
47 | std::vector<double> center_frequencies_;  // In mel, for each mel channel.
48 |
49 | // Each FFT bin b contributes to two triangular mel channels, with
50 | // proportion weights_[b] going into mel channel band_mapper_[b], and
51 | // proportion (1 - weights_[b]) going into channel band_mapper_[b] + 1.
52 | // Thus, weights_ contains the weighting applied to each FFT bin for the
53 | // upper-half of the triangular band.
54 | std::vector<double> weights_;  // Right-side weight for this fft bin.
55 |
56 | // FFT bin i contributes to the upper side of mel channel band_mapper_[i]
57 | std::vector<int> band_mapper_;
58 | int start_index_; // Lowest FFT bin used to calculate mel spectrum.
59 | int end_index_; // Highest FFT bin used to calculate mel spectrum.
60 |
61 | TF_DISALLOW_COPY_AND_ASSIGN(MfccMelFilterbank);
62 | };
63 |
64 | } // namespace tensorflow
65 | #endif // DELTA_LAYERS_OPS_KERNELS_MFCC_MEL_FILTERBANK_H_
66 |
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/resample.h:
--------------------------------------------------------------------------------
1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | ==============================================================================*/
16 |
17 |
18 | #ifndef DELTA_LAYERS_OPS_KERNELS_RESAMPLE_H_
19 | #define DELTA_LAYERS_OPS_KERNELS_RESAMPLE_H_
20 |
21 | #include <math.h>
22 | #include <stdio.h>
23 | #include <stdlib.h>
24 | #include <string.h>
25 | #include <vector>
26 | #include <iostream>
27 | #include "tensorflow/core/framework/op_kernel.h"
28 | #include "tensorflow/core/platform/logging.h"
29 | #include "kernels/support_functions.h"
30 |
31 | using namespace tensorflow; // NOLINT
32 | using namespace std;
33 |
34 | namespace delta {
35 |
36 | class ArbitraryResample {
37 | public:
38 | ArbitraryResample(int num_samples_in,
39 | BaseFloat samp_rate_hz,
40 | BaseFloat filter_cutoff_hz,
41 | const vector<BaseFloat> &sample_points_secs,
42 | int num_zeros);
43 |
44 | int NumSamplesIn() const { return num_samples_in_; }
45 |
46 | int NumSamplesOut() const { return weights_.size(); }
47 | void Resample(const std::vector<std::vector<BaseFloat> > &input,
48 | std::vector<std::vector<BaseFloat> > *output) const;
49 |
50 | void Resample(const vector<BaseFloat> &input,
51 | vector<BaseFloat> *output) const;
52 | private:
53 | void SetIndexes(const vector<BaseFloat> &sample_points);
54 |
55 | void SetWeights(const vector<BaseFloat> &sample_points);
56 |
57 | BaseFloat FilterFunc(BaseFloat t) const;
58 |
59 | int num_samples_in_;
60 | BaseFloat samp_rate_in_;
61 | BaseFloat filter_cutoff_;
62 | int num_zeros_;
63 |
64 | std::vector<int> first_index_;
65 | std::vector<std::vector<BaseFloat> > weights_;
66 | };
67 |
68 | class LinearResample {
69 | public:
70 | LinearResample(int samp_rate_in_hz,
71 | int samp_rate_out_hz,
72 | BaseFloat filter_cutoff_hz,
73 | int num_zeros);
74 |
75 | void Resample(const vector<BaseFloat> &input,
76 | bool flush,
77 | vector<BaseFloat> *output);
78 |
79 | void Reset();
80 |
81 | //// Return the input and output sampling rates (for checks, for example)
82 | inline int GetInputSamplingRate() { return samp_rate_in_; }
83 | inline int GetOutputSamplingRate() { return samp_rate_out_; }
84 | private:
85 | int GetNumOutputSamples(int input_num_samp, bool flush) const;
86 |
87 | inline void GetIndexes(int samp_out,
88 | int *first_samp_in,
89 | int *samp_out_wrapped) const;
90 |
91 | void SetRemainder(const vector<BaseFloat> &input);
92 |
93 | void SetIndexesAndWeights();
94 |
95 | BaseFloat FilterFunc(BaseFloat) const;
96 |
97 | // The following variables are provided by the user.
98 | int samp_rate_in_;
99 | int samp_rate_out_;
100 | BaseFloat filter_cutoff_;
101 | int num_zeros_;
102 |
103 | int input_samples_in_unit_;
104 | int output_samples_in_unit_;
105 |
106 | std::vector<int> first_index_;
107 | std::vector<std::vector<BaseFloat> > weights_;
108 |
109 | int input_sample_offset_;
110 | int output_sample_offset_;
111 | vector<BaseFloat> input_remainder_;
112 | };
113 |
114 | void ResampleWaveform(BaseFloat orig_freq, const vector<BaseFloat> &wave,
115 | BaseFloat new_freq, vector<BaseFloat> *new_wave);
116 |
117 | inline void DownsampleWaveForm(BaseFloat orig_freq, const vector<BaseFloat> &wave,
118 | BaseFloat new_freq, vector<BaseFloat> *new_wave) {
119 | ResampleWaveform(orig_freq, wave, new_freq, new_wave);
120 | }
121 |
122 | } // namespace delta
123 | #endif // DELTA_LAYERS_OPS_KERNELS_RESAMPLE_H_
124 |
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/spectrum.h:
--------------------------------------------------------------------------------
1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | ==============================================================================*/
16 |
17 | #ifndef DELTA_LAYERS_OPS_KERNELS_SPECTRUM_H_
18 | #define DELTA_LAYERS_OPS_KERNELS_SPECTRUM_H_
19 |
20 | #include <string.h>
21 | #include "tensorflow/core/framework/op_kernel.h"
22 | #include "tensorflow/core/platform/logging.h"
23 |
24 | #include "kernels/complex_defines.h"
25 | #include "kernels/support_functions.h"
26 |
27 | using namespace tensorflow; // NOLINT
28 |
29 | namespace delta {
30 | class Spectrum {
31 | private:
32 | float window_length_sec_;
33 | float frame_length_sec_;
34 |
35 | float f_SamRat;
36 | int i_WinLen;
37 | int i_FrmLen;
38 | int i_NumFrm;
39 | int i_NumFrq;
40 | int i_FFTSiz;
41 | float f_PreEph;
42 | char s_WinTyp[40];
43 | int i_OutTyp; // 1: PSD, 2:log(PSD)
44 | int i_snip_edges;
45 | int i_raw_energy;
46 | bool i_remove_dc_offset;
47 | bool i_is_fbank;
48 | float i_dither;
49 |
50 | float* pf_WINDOW;
51 | float* pf_SPC;
52 |
53 | xcomplex* win;
54 | float* win_buf;
55 | float* eph_buf;
56 | float* win_temp;
57 | xcomplex* fftwin;
58 | float* fft_buf;
59 |
60 | public:
61 | Spectrum();
62 |
63 | void set_window_length_sec(float window_length_sec);
64 |
65 | void set_frame_length_sec(float frame_length_sec);
66 |
67 | void set_output_type(int output_type);
68 |
69 | void set_snip_edges(int snip_edges);
70 |
71 | void set_raw_energy(int raw_energy);
72 |
73 | void set_preEph(float preEph);
74 |
75 | void set_window_type(char* window_type);
76 |
77 | void set_is_fbank(bool is_fbank);
78 |
79 | void set_remove_dc_offset(bool remove_dc_offset);
80 |
81 | void set_dither(float dither);
82 |
83 | int init_spc(int input_size, float sample_rate);
84 |
85 | int proc_spc(const float* mic_buf, int input_size);
86 |
87 | int get_spc(float* output);
88 |
89 | int write_spc();
90 |
91 | TF_DISALLOW_COPY_AND_ASSIGN(Spectrum);
92 | };
93 | } // namespace delta
94 | #endif // DELTA_LAYERS_OPS_KERNELS_SPECTRUM_H_
95 |
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/spectrum_op_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """ spectrum Op unit-test """
17 | import os
18 | from pathlib import Path
19 |
20 | import numpy as np
21 | import tensorflow as tf
22 | from absl import logging
23 |
24 | from athena.transform.feats.ops import py_x_ops
25 |
26 |
27 | class SpecOpTest(tf.test.TestCase):
28 | """ spectrum op unittest"""
29 |
30 | def setUp(self):
31 | """set up"""
32 | self.wavpath = str(
33 | Path(os.environ["MAIN_ROOT"]).joinpath("delta/layers/ops/data/sm1_cln.wav")
34 | )
35 |
36 | def tearDown(self):
37 | """tear down"""
38 |
39 | def test_spectrum(self):
40 | """ test spectrum op"""
41 | with self.session(use_gpu=False, force_gpu=False):
42 | sample_rate, input_data = feat_lib.load_wav(self.wavpath, sr=16000)  # NOTE: feat_lib (a wav-loading helper from the original DELTA codebase) is not imported above
43 | logging.info(
44 | f"input shape: {input_data.shape}, sample rate dtype: {sample_rate.dtype}"
45 | )
46 | self.assertEqual(sample_rate, 16000)
47 |
48 | output = py_x_ops.spectrum(input_data, sample_rate)
49 |
50 | # pylint: disable=bad-whitespace
51 | output_true = np.array(
52 | [
53 | [-16.863441, -16.910473, -17.077059, -16.371634, -16.845686],
54 | [-17.922068, -20.396345, -19.396944, -17.331493, -16.118851],
55 | [-17.017776, -17.551350, -20.332376, -17.403994, -16.617926],
56 | [-19.873854, -17.644503, -20.679525, -17.093716, -16.535091],
57 | [-17.074402, -17.295971, -16.896650, -15.995432, -16.560730],
58 | ]
59 | )
60 | # pylint: enable=bad-whitespace
61 |
62 | self.assertEqual(tf.rank(output).eval(), 2)
63 | logging.info("Shape of spectrum: {}".format(output.shape))
64 | self.assertAllClose(output.eval()[4:9, 4:9], output_true)
65 |
66 |
67 | if __name__ == "__main__":
68 | tf.test.main()
69 |
--------------------------------------------------------------------------------
/athena/transform/feats/ops/kernels/speed_op.cc:
--------------------------------------------------------------------------------
1 | /* Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | ==============================================================================*/
16 |
17 | #include "kernels/resample.h"
18 | #include "tensorflow/core/framework/op_kernel.h"
19 | #include "tensorflow/core/framework/register_types.h"
20 | #include "tensorflow/core/framework/tensor.h"
21 | #include "tensorflow/core/framework/tensor_shape.h"
22 | #include "tensorflow/core/framework/types.h"
23 | #include "tensorflow/core/lib/core/status.h"
24 |
25 | namespace delta {
26 |
27 | class SpeedOp : public OpKernel {
28 | public:
29 | explicit SpeedOp(OpKernelConstruction* context) : OpKernel(context) {
30 | OP_REQUIRES_OK(context,
31 | context->GetAttr("lowpass_filter_width", &lowpass_filter_width_));
32 | }
33 |
34 | void Compute(OpKernelContext* context) override {
35 | const Tensor& input_tensor = context->input(0);
36 | OP_REQUIRES(context, input_tensor.dims() == 1,
37 | errors::InvalidArgument("input signal must be 1-dimensional",
38 | input_tensor.shape().DebugString()));
39 | const Tensor& sample_rate_tensor = context->input(1);
40 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(sample_rate_tensor.shape()),
41 | errors::InvalidArgument(
42 | "Input sample_rate should be a scalar tensor, got ",
43 | sample_rate_tensor.shape().DebugString(), " instead."));
44 | const Tensor& resample_rate_tensor = context->input(2);
45 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(resample_rate_tensor.shape()),
46 | errors::InvalidArgument(
47 | "Resample sample_rate should be a scalar tensor, got ",
48 | resample_rate_tensor.shape().DebugString(), " instead."));
49 | const int sample_rate = static_cast<int>(sample_rate_tensor.scalar<int32>()());
50 | const int resample_freq = static_cast<int>(resample_rate_tensor.scalar<int32>()());
51 | const float* input_flat = input_tensor.flat<float>().data();
52 | const int L = input_tensor.dim_size(0);
53 |
54 | lowpass_cutoff_ = min(resample_freq / 2, sample_rate / 2);
55 | LinearResample cls_resample_(sample_rate, resample_freq,
56 | lowpass_cutoff_,
57 | lowpass_filter_width_);
58 | vector<BaseFloat> waveform(L);
59 | for (int i = 0; i < L; i++){
60 | waveform[i] = static_cast<BaseFloat>(input_flat[i]);
61 | }
62 | vector<BaseFloat> downsampled_wave;
63 | cls_resample_.Resample(waveform, false, &downsampled_wave);
64 | int output_length = downsampled_wave.size();
65 | Tensor* output_tensor = nullptr;
66 | OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({1, output_length}),
67 | &output_tensor));
68 | float* output_flat = output_tensor->flat().data();
69 | for (int j = 0; j < output_length; j++)
70 | output_flat[j] = downsampled_wave[j];
71 |
72 | std::vector<BaseFloat>().swap(downsampled_wave);
73 | std::vector<BaseFloat>().swap(waveform);
74 | cls_resample_.Reset();
75 | }
76 |
77 | private:
78 | float lowpass_cutoff_;
79 | int lowpass_filter_width_;
80 | };
81 |
82 | REGISTER_KERNEL_BUILDER(Name("Speed").Device(DEVICE_CPU), SpeedOp);
83 |
84 | } // namespace delta
--------------------------------------------------------------------------------
/athena/transform/feats/ops/py_x_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """ python custom ops """
17 | import tensorflow.compat.v1 as tf
18 | from absl import logging
19 |
20 | so_lib_file = tf.io.gfile.glob(tf.resource_loader.get_data_files_path() + "/x_ops*.so")[
21 | 0
22 | ].split("/")[-1]
23 | path = tf.resource_loader.get_path_to_datafile(so_lib_file)
24 | logging.info("x_ops*.so path:{}".format(path))
25 |
26 | gen_x_ops = tf.load_op_library(tf.resource_loader.get_path_to_datafile(so_lib_file))
27 |
28 | spectrum = gen_x_ops.spectrum
29 | fbank = gen_x_ops.fbank
30 | delta_delta = gen_x_ops.delta_delta
31 | pitch = gen_x_ops.pitch
32 | mfcc = gen_x_ops.mfcc_dct
33 | frame_pow = gen_x_ops.frame_pow
34 | speed = gen_x_ops.speed
35 |
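36 | # Example usage (illustrative; all feature ops take a 1-D float waveform tensor
37 | # plus a scalar sample rate, as exercised by the unit tests):
38 | #   feat = spectrum(audio_data, sample_rate)  # log-power spectrum frames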
--------------------------------------------------------------------------------
/athena/transform/feats/pitch_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Unit test for pitch feature extraction."""
17 |
18 | import os
19 | from pathlib import Path
20 | import tensorflow as tf
21 | from tensorflow.python.framework.ops import disable_eager_execution
22 | from athena.transform.feats.read_wav import ReadWav
23 | from athena.transform.feats.pitch import Pitch
24 |
25 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
26 |
27 |
28 | class PitchTest(tf.test.TestCase):
29 | """
30 | Pitch extraction test.
31 | """
32 | def test_pitch(self):
33 | wav_path_16k = str(
34 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/sm1_cln.wav")
35 | )
36 | wav_path_8k = str(
37 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/english.wav")
38 | )
39 |
40 | with self.session():
41 | for wav_file in [wav_path_16k]:
42 | read_wav = ReadWav.params().instantiate()
43 | input_data, sample_rate = read_wav(wav_file)
44 |
45 | pitch = Pitch.params(
46 | {"window_length": 0.025, "soft_min_f0": 10.0}
47 | ).instantiate()
48 | pitch_test = pitch(input_data, sample_rate)
49 |
50 | if tf.executing_eagerly():
51 | self.assertEqual(tf.rank(pitch_test).numpy(), 2)
52 | else:
53 | self.assertEqual(tf.rank(pitch_test).eval(), 2)
54 |
55 | output_true = [
56 | [-0.1366025, 143.8855],
57 | [-0.0226383, 143.8855],
58 | [-0.08464742, 143.8855],
59 | [-0.08458386, 143.8855],
60 | [-0.1208689, 143.8855],
61 | ]
62 |
63 | if wav_file == wav_path_16k:
64 | if tf.executing_eagerly():
65 | print("Transform: ", pitch_test.numpy()[0:5, :])
66 | print("kaldi:", output_true)
67 | self.assertAllClose(
68 | pitch_test.numpy()[0:5, :],
69 | output_true,
70 | rtol=1e-05,
71 | atol=1e-05,
72 | )
73 | else:
74 | print("Transform: ", pitch_test.eval())
75 | print("kaldi:", output_true)
76 | self.assertAllClose(
77 | pitch_test.eval()[0:5, :],
78 | output_true,
79 | rtol=1e-05,
80 | atol=1e-05,
81 | )
82 |
83 |
84 | if __name__ == "__main__":
85 |
86 | is_eager = True
87 | if not is_eager:
88 | disable_eager_execution()
89 | else:
90 | if tf.__version__ < "2.0.0":
91 | tf.compat.v1.enable_eager_execution()
92 | tf.test.main()
93 |
--------------------------------------------------------------------------------
/athena/transform/feats/read_wav.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """The model reads audio samples from a wav file."""
17 |
18 | import tensorflow as tf
19 | from athena.utils.hparam import HParams
20 | from athena.transform.feats.base_frontend import BaseFrontend
21 | from athena.transform.feats.ops import py_x_ops
22 |
23 |
24 | class ReadWav(BaseFrontend):
25 | """
26 | Read audio sample from wav file, return sample data and sample rate.
27 | """
28 | def __init__(self, config: dict):
29 | super().__init__(config)
30 |
31 | @classmethod
32 | def params(cls, config=None):
33 | """
34 | Set params.
35 | :param config: contains one optional parameter: audio_channels(int, default=1).
36 | :return: An object of class HParams, which is a set of hyperparameters as
37 | name-value pairs.
38 | """
39 | audio_channels = 1
40 |
41 | hparams = HParams(cls=cls)
42 | hparams.add_hparam('type', 'ReadWav')
43 | hparams.add_hparam('audio_channels', audio_channels)
44 |
45 | if config is not None:
46 | hparams.parse(config, True)
47 |
48 | return hparams
49 |
50 | def call(self, wavfile, speed=1.0):
51 | """
52 | Get audio data and sample rate from a wavfile.
53 | :param wavfile: filepath of the wav file.
54 | :param speed: speed of the wanted samples (float, default=1.0).
55 | :return: 2 values. The first is a Tensor of audio data.
56 | The second return value is the sample rate of the input wav
57 | file, which is an int32 tensor.
58 | """
59 | p = self.config
60 | contents = tf.io.read_file(wavfile)
61 | audio_data, sample_rate = tf.compat.v1.audio.decode_wav(
62 | contents, desired_channels=p.audio_channels)
63 | if (speed == 1.0):
64 | return tf.squeeze(audio_data * 32768, axis=-1), tf.cast(sample_rate, dtype=tf.int32)
65 | else:
66 | resample_rate = tf.cast(sample_rate, dtype=tf.float32) * tf.cast(1.0 / speed, dtype=tf.float32)
67 | speed_data = py_x_ops.speed(tf.squeeze(audio_data * 32768, axis=-1),
68 | tf.cast(sample_rate, dtype=tf.int32),
69 | tf.cast(resample_rate, dtype=tf.int32),
70 | lowpass_filter_width=5)
71 | return tf.squeeze(speed_data), tf.cast(sample_rate, dtype=tf.int32)
72 |
73 |
74 | def read_wav(wavfile, audio_channels=1):
75 | """ read wav from file
76 | args: audio_channels: number of channels wanted (int, default=1)
77 | returns: the audio data scaled to the int16 range, and the sample rate as an int32 tensor
78 | """
79 | contents = tf.io.read_file(wavfile)
80 | audio_data, sample_rate = tf.compat.v1.audio.decode_wav(
81 | contents, desired_channels=audio_channels
82 | )
83 |
84 | return tf.squeeze(audio_data * 32768, axis=-1), tf.cast(sample_rate, dtype=tf.int32)
85 |
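86 | # Example (illustrative; 'example.wav' is a placeholder path):
87 | #   read_wav = ReadWav.params().instantiate()
88 | #   audio_data, sample_rate = read_wav('example.wav', speed=0.9)
89 | #   # speed != 1.0 resamples the waveform at rate sample_rate / speed via the Speed op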
--------------------------------------------------------------------------------
/athena/transform/feats/read_wav_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Unit test for the ReadWav OP."""
17 |
18 | import os
19 | from pathlib import Path
20 | import librosa
21 | import tensorflow as tf
22 | from tensorflow.python.framework.ops import disable_eager_execution
23 | from athena.transform.feats.read_wav import ReadWav
24 |
25 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
26 |
27 | class ReadWavTest(tf.test.TestCase):
28 | """
29 | ReadWav OP test.
30 | """
31 | def test_read_wav(self):
32 | wav_path = str(Path(os.environ['MAIN_ROOT']).joinpath('examples/sm1_cln.wav'))
33 |
34 | with self.session():
35 | speed = 0.9
36 | read_wav = ReadWav.params().instantiate()
37 | input_data, sample_rate = read_wav(wav_path, speed)
38 |
39 | audio_data_true, sample_rate_true = librosa.load(wav_path, sr=16000)
40 | if (speed == 1.0):
41 | if tf.executing_eagerly():
42 | self.assertAllClose(input_data.numpy() / 32768, audio_data_true)
43 | self.assertAllClose(sample_rate.numpy(), sample_rate_true)
44 | else:
45 | self.assertAllClose(input_data.eval() / 32768, audio_data_true)
46 | self.assertAllClose(sample_rate.eval(), sample_rate_true)
47 |
48 |
49 | if __name__ == '__main__':
50 |
51 | is_eager = False
52 | if not is_eager:
53 | disable_eager_execution()
54 | else:
55 | if tf.__version__ < '2.0.0':
56 | tf.compat.v1.enable_eager_execution()
57 | tf.test.main()
58 |
--------------------------------------------------------------------------------
/athena/transform/feats/spectrum_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Unit test for spectrum feature extraction."""
17 |
18 | import os
19 | from pathlib import Path
20 | import numpy as np
21 | import tensorflow as tf
22 | from tensorflow.python.framework.ops import disable_eager_execution
23 | from athena.transform.feats.read_wav import ReadWav
24 | from athena.transform.feats.spectrum import Spectrum
25 |
26 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
27 |
28 |
29 | class SpectrumTest(tf.test.TestCase):
30 | '''
31 | Spectrum extraction test.
32 | '''
33 | def test_spectrum(self):
34 | wav_path_16k = str(
35 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/sm1_cln.wav")
36 | )
37 | wav_path_8k = str(
38 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/english.wav")
39 | )
40 |
41 | with self.session():
42 | for wav_file in [wav_path_8k, wav_path_16k]:
43 | read_wav = ReadWav.params().instantiate()
44 | input_data, sample_rate = read_wav(wav_file)
45 |
46 | spectrum = Spectrum.params(
47 | {"window_length": 0.025, "dither": 0.0}
48 | ).instantiate()
49 | spectrum_test = spectrum(input_data, sample_rate)
50 |
51 | output_true = np.array(
52 | [
53 | [9.819611, 2.84503, 3.660894, 2.7779, 1.212233],
54 | [9.328745, 2.553949, 3.276319, 3.000918, 2.499342],
55 | ]
56 | )
57 | if tf.executing_eagerly():
58 | self.assertEqual(tf.rank(spectrum_test).numpy(), 2)
59 | else:
60 | self.assertEqual(tf.rank(spectrum_test).eval(), 2)
61 |
62 | if wav_file == wav_path_16k:
63 | if tf.executing_eagerly():
64 | self.assertAllClose(
65 | spectrum_test.numpy()[0:2, 0:5],
66 | output_true,
67 | rtol=1e-05,
68 | atol=1e-05,
69 | )
70 | else:
71 | self.assertAllClose(
72 | spectrum_test.eval()[0:2, 0:5],
73 | output_true,
74 | rtol=1e-05,
75 | atol=1e-05,
76 | )
77 |
78 |
79 | if __name__ == "__main__":
80 |
81 | is_eager = True
82 | if not is_eager:
83 | disable_eager_execution()
84 | else:
85 | if tf.__version__ < "2.0.0":
86 | tf.compat.v1.enable_eager_execution()
87 | tf.test.main()
88 |
--------------------------------------------------------------------------------
/athena/transform/feats/write_wav.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """The model writes audio samples to a wav file."""
17 |
18 | import tensorflow as tf
19 | from athena.utils.hparam import HParams
20 | from athena.transform.feats.base_frontend import BaseFrontend
21 |
22 |
23 | class WriteWav(BaseFrontend):
24 | """
25 | Encode audio data (input) using the sample rate (input),
26 | and return a write-wav operation.
27 | """
28 |
29 | def __init__(self, config: dict):
30 | super().__init__(config)
31 |
32 | @classmethod
33 | def params(cls, config=None):
34 | """
35 | Set params.
36 | :param config: contains one optional parameter: sample_rate(int, default=16000).
37 | :return: An object of class HParams, which is a set of hyperparameters as
38 | name-value pairs.
39 | """
40 |
41 | sample_rate = 16000
42 |
43 | hparams = HParams(cls=cls)
44 | hparams.add_hparam('sample_rate', sample_rate)
45 |
46 | if config is not None:
47 | hparams.override_from_dict(config)
48 |
49 | return hparams
50 |
51 | def call(self, filename, audio_data, sample_rate):
52 | """
53 | Write wav using audio_data[tensor].
54 | :param filename: filepath of wav.
55 | :param audio_data: a tensor containing data of a wav.
56 | :param sample_rate: the sample rate of the signal we are working with.
57 | :return: a write-wav operation.
58 | """
59 | filename = tf.constant(filename)
60 |
61 | with tf.name_scope('writewav'):
62 | audio_data = tf.cast(audio_data, dtype=tf.float32)
63 | contents = tf.audio.encode_wav(
64 | tf.expand_dims(audio_data, 1), tf.cast(sample_rate, dtype=tf.int32))
65 | w_op = tf.io.write_file(filename, contents)
66 |
67 | return w_op
68 |
--------------------------------------------------------------------------------
/athena/transform/feats/write_wav_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Unit test for the WriteWav OP."""
17 |
18 | import os
19 |
20 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
21 | from pathlib import Path
22 | import librosa
23 | import tensorflow as tf
24 | from tensorflow.python.framework.ops import disable_eager_execution
25 | from athena.transform.feats.read_wav import ReadWav
26 | from athena.transform.feats.write_wav import WriteWav
27 |
28 |
29 | class WriteWavTest(tf.test.TestCase):
30 | """
31 | WriteWav OP test.
32 | """
33 | def test_write_wav(self):
34 | wav_path = str(Path(os.environ["MAIN_ROOT"]).joinpath("examples/sm1_cln.wav"))
35 |
36 | with self.cached_session() as sess:
37 | config = {"speed": 1.1}
38 | read_wav = ReadWav.params(config).instantiate()
39 | input_data, sample_rate = read_wav(wav_path)
40 | write_wav = WriteWav.params().instantiate()
41 | new_path = str(
42 | Path(os.environ["MAIN_ROOT"]).joinpath("examples/sm1_cln_resample.wav")
43 | )
44 | writewav_op = write_wav(new_path, input_data / 32768, sample_rate)
45 | sess.run(writewav_op)
46 |
47 |
48 | if __name__ == "__main__":
49 | is_eager = True
50 | if not is_eager:
51 | disable_eager_execution()
52 | else:
53 | if tf.__version__ < "2.0.0":
54 | tf.compat.v1.enable_eager_execution()
55 | tf.test.main()
56 |
--------------------------------------------------------------------------------
/athena/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) ATHENA AUTHORS
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """ utils """
17 |
--------------------------------------------------------------------------------
/athena/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) ATHENA AUTHORS
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | # Only support eager mode
17 | # pylint: disable=invalid-name
18 | r""" check point manager """
19 | import os
20 | import tensorflow as tf
21 | from absl import logging
22 | import numpy as np
23 |
24 | class Checkpoint(tf.train.Checkpoint):
25 | """ A wrapper for Tensorflow checkpoint
26 |
27 | Args:
28 | checkpoint_directory: the directory for checkpoint
29 | summary_directory: the directory for summary used in Tensorboard
30 | __init__ provide the optimizer and model
31 | __call__ save the model
32 |
33 | Example:
34 | transformer = SpeechTransformer(target_vocab_size=dataset_builder.target_dim)
35 | optimizer = tf.keras.optimizers.Adam()
36 | ckpt = Checkpoint(checkpoint_directory='./train', summary_directory='./event',
37 | transformer=transformer, optimizer=optimizer)
38 | solver = BaseSolver(transformer)
39 | for epoch in dataset:
40 | ckpt()
41 | """
42 |
43 | def __init__(self, checkpoint_directory=None, **kwargs):
44 | super().__init__(**kwargs)
45 | self.best_loss = np.inf
46 | if checkpoint_directory is None:
47 | checkpoint_directory = os.path.join(os.path.expanduser("~"), ".athena")
48 | self.checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
49 | self.checkpoint_directory = checkpoint_directory
50 | logging.info("trying to restore from : %s" % checkpoint_directory)
51 | # load from checkpoint if previous models exist in checkpoint dir
52 | self.restore(tf.train.latest_checkpoint(checkpoint_directory))
53 |
54 | def _compare_and_save_best(self, loss, save_path):
55 | """ compare and save the best model in best_loss """
56 | if loss is None:
57 | return
58 | if loss < self.best_loss:
59 | self.best_loss = loss
60 | with open(os.path.join(self.checkpoint_directory, 'best_loss'), 'w') as wf:
61 | checkpoint = save_path.split('/')[-1]
62 | wf.write('model_checkpoint_path: "%s"' % checkpoint)
63 |
64 | def __call__(self, loss=None):
65 | logging.info("saving model in :%s" % self.checkpoint_prefix)
66 | save_path = self.save(file_prefix=self.checkpoint_prefix)
67 | self._compare_and_save_best(loss, save_path)
68 |
69 | def restore_from_best(self):
70 | """ restore from the best model """
71 | self.restore(
72 | tf.train.latest_checkpoint(
73 | self.checkpoint_directory,
74 | latest_filename='best_loss'
75 | )
76 | )
--------------------------------------------------------------------------------
/athena/utils/data_queue.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) ATHENA AUTHORS
3 | # All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | # pylint: disable=missing-function-docstring, invalid-name
18 | """ data queue for multi thread """
19 | import time
20 | import threading
21 | import queue
22 |
23 |
24 | class DataQueue:
25 | """Queue for data prefetching
26 | args:
27 | generator(generator): instance of generator which feed data
28 | capacity(int): maximum data to prefetch
29 | num_threads(int): control concurrency, only take effect when do preprocessing
30 | wait_time(float): time to sleep when queue is full
31 | """
32 |
33 | def __init__(self, generator, capacity=20, num_threads=4, max_index=10000, wait_time=0.0001):
34 | self.generator = generator
35 | self.capacity = capacity
36 | self.wait_time = wait_time
37 | self.queue = queue.Queue()
38 | self.index = 0
39 | self.max_index = max_index
40 |
41 | self._stop = threading.Event()
42 | self._lock = threading.Lock()
43 |
44 | self.threads = [
45 | threading.Thread(target=self.generator_task) for _ in range(num_threads)
46 | ]
47 |
48 | for t in self.threads:
49 | t.daemon = True  # daemon threads will not block program exit
50 | t.start()
51 |
52 | def __del__(self):
53 | self.stop()
54 |
55 | def get(self):
56 | return self.queue.get()
57 |
58 | def stop(self):
59 | self._stop.set()
60 |
61 | def generator_task(self):
62 | """Enqueue batch data"""
63 | while not self._stop.is_set():
64 | try:
65 | if self.index >= self.max_index:
66 | continue
67 | batch = self.generator(self.index)
68 | self._lock.acquire()
69 | if self.queue.qsize() < self.capacity:
70 | try:
71 | self.index = self.index + 1
72 | except ValueError as e:
73 | print(e)
74 | self._lock.release()
75 | continue
76 | self.queue.put(batch)
77 | self._lock.release()
78 | else:
79 | self._lock.release()
80 | time.sleep(self.wait_time)
81 | except Exception as e:
82 | print(e)
83 | self._stop.set()
84 | raise
85 |
86 |
87 | def test():
88 | """
89 | Test data queue.
90 | Expected output:
91 | the integers produced by the generator threads.
92 | """
93 |
94 | def generator(i):
95 | return i
96 |
97 | train_queue = DataQueue(generator, capacity=8, num_threads=4)
98 | for _ in range(92):
99 | print(train_queue.get())
100 | train_queue.stop()
101 |
102 |
103 | if __name__ == "__main__":
104 | test()
105 |
--------------------------------------------------------------------------------
/athena/utils/metric_check.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) ATHENA AUTHORS
3 | # All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """MetricChecker"""
18 | import time
19 | import tensorflow as tf
20 | import numpy as np
21 |
22 |
23 | class MetricChecker:
24 | """Hold and save best metric checkpoint
25 | Args:
26 | optimizer: the optimizer whose iteration count and learning rate
27 | are used in the summaries
28 | """
29 |
30 | def __init__(self, optimizer):
31 | self.best_loss = tf.constant(np.inf)
32 | self.optimizer = optimizer
33 | self.time_last_call = time.time()
34 | self.steps_last_call = 0
35 |
36 | def __call__(self, loss, metrics, evaluate_epoch=-1):
37 | """summarize the basic metrics like loss, lr
38 | Args:
39 | loss: the current loss value
40 | metrics: average metrics of all previous steps in one epoch;
41 | must be provided when evaluating
42 | evaluate_epoch:
43 | if evaluate_epoch == -1: training mode, write train summaries
44 | if evaluate_epoch >= 0: evaluate mode, write evaluate summaries
45 | if evaluate_epoch < -1: evaluate mode (no tf.summary written)
46 | Returns:
47 | logging_str: report of loss and metrics (and best loss if improved)
48 | """
49 | if evaluate_epoch == -1:
50 | return self.summary_train(loss, metrics)
51 | return self.summary_evaluate(loss, metrics, evaluate_epoch)
52 |
53 | def summary_train(self, loss, metrics):
54 | """ generate summary of learning_rate, loss, metrics, speed and write on Tensorboard
55 | """
56 | global_steps = tf.convert_to_tensor(self.optimizer.iterations)
57 | learning_rate = (
58 | self.optimizer.lr
59 | if isinstance(self.optimizer.lr, tf.Variable)
60 | else (self.optimizer.lr(global_steps))
61 | )
62 |
63 | tf.summary.scalar("loss", loss, step=global_steps)
64 | tf.summary.scalar("learning_rate", learning_rate, step=global_steps)
65 | for name in metrics:
66 | metric = metrics[name]
67 | tf.summary.scalar(name, metric, step=global_steps)
68 |
69 | reports = ""
70 | reports += "global_steps: %d\t" % (global_steps)
71 | reports += "learning_rate: %.4e\t" % (learning_rate)
72 | reports += "loss: %.4f\t" % (loss)
73 | for name in metrics:
74 | metric = metrics[name]
75 | reports += "%s: %.4f\t" % (name, metric)
76 | right_now = time.time()
77 | duration = right_now - self.time_last_call
78 | self.time_last_call = right_now
79 | if self.steps_last_call != 0:
80 | # average duration over log_interval steps
81 | sec_per_iter = duration / tf.cast(
82 | (global_steps - self.steps_last_call), tf.float32
83 | )
84 | reports += "sec/iter: %.4f" % (sec_per_iter)
85 | self.steps_last_call = global_steps
86 |
87 | return reports
88 |
89 | def summary_evaluate(self, loss, metrics, epoch=-1):
90 | """ If epoch > 0, return a summary of loss and metrics on dev set and write on Tensorboard
91 | Otherwise, just return evaluate loss and metrics
92 | """
93 | reports = ""
94 | if epoch >= 0:
95 | tf.summary.scalar("evaluate_loss", loss, step=epoch)
96 | for name in metrics:
97 | metric = metrics[name]
98 | tf.summary.scalar("evaluate_" + name, metric, step=epoch)
99 | reports += "epoch: %d\t" % (epoch)
100 | reports += "loss: %.4f\t" % (loss)
101 | for name in metrics:
102 | metric = metrics[name]
103 | reports += "%s: %.4f\t" % (name, metric)
104 | return reports
105 |
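106 | # Example (illustrative):
107 | #   checker = MetricChecker(optimizer)
108 | #   reports = checker(loss, metrics)         # training summary (evaluate_epoch == -1)
109 | #   reports = checker(loss, metrics, epoch)  # dev-set summary for the given epoch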
--------------------------------------------------------------------------------
/athena/utils/vocabs/en.vocab:
--------------------------------------------------------------------------------
1 | a 1
2 | b 2
3 | c 3
4 | d 4
5 | e 5
6 | f 6
7 | g 7
8 | h 8
9 | i 9
10 | j 10
11 | k 11
12 | l 12
13 | m 13
14 | n 14
15 | o 15
16 | p 16
17 | q 17
18 | r 18
19 | s 19
20 | t 20
21 | u 21
22 | v 22
23 | w 23
24 | x 24
25 | y 25
26 | z 26
27 | ' 27
28 | - 28
29 | 0
30 | 29
31 |
--------------------------------------------------------------------------------
/docs/TheTrainningEfficiency.md:
--------------------------------------------------------------------------------
1 | # The Efficiency of GPU Training Using ``Horovod``+``Tensorflow``
2 |
3 | ## Experimental
4 |
5 | The Training Environment: ``Athena``
6 |
7 |
8 | Training Data: a subset of 1000 samples randomly selected from the HKUST training dataset.
9 |
10 |
11 |
12 | Network: ``LAS`` model
13 |
14 | Primary Network Configuration: ``NUM_EPOCHS`` 1, ``BATCH_SIZE`` 10
15 |
16 |
17 |
18 | We measured how the training time changes with different numbers of servers and GPUs when using `Horovod`+`Tensorflow`. The training data, network structure, and other settings are kept the same, and each run trains `one` `epoch`. The experimental results are as follows:
19 |
20 | ### Training time using ``Horovod``+``Tensorflow``
21 |
22 |
23 | Server and GPU number | 1S-1GPU | 1S-2GPUs | 1S-3GPUs | 1S-4GPUs | 2Ss-2GPUs | 2Ss-4GPUs | 2Ss-6GPUs | 2Ss-8GPUs |
24 | :-------:|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|:--------:|:--------:|
25 | training time(s/1 epoch) | 121.409 | 83.111 | 61.607 | 54.507 | 82.486 | 49.888 | 33.333 | 28.101 |
26 |
27 | ## Result Analysis
28 | 1. As the table shows, the more GPUs are used, the shorter the training time. For example, comparing the training time of 1 server with 1 GPU against 1 server with 4 GPUs gives a ratio of `1S-4GPUs:1S-1GPU=1:2.22`. Moreover, another data point is recorded as `2Ss-8GPUs:1S-1GPU=1:4.3`. From this we can see that increasing the number of GPUs used in training saves training time and improves efficiency.
29 |
30 | 2. The inter-server communication overhead with `Horovod` is very small. We trained the same model using 1 server with 2 GPUs and using 2 servers with 1 GPU each; the training time ratio is `1S-2GPUs:2Ss-2GPUs=1:1`.
31 |
32 |
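33 | A minimal sketch of the ``Horovod``+``Tensorflow`` data-parallel pattern measured above (illustrative only, not Athena's actual solver code; the model and data are placeholders):
34 |
35 | ```python
36 | import tensorflow as tf
37 | import horovod.tensorflow as hvd
38 |
39 | hvd.init()  # one process per GPU, launched e.g. with horovodrun
40 | gpus = tf.config.experimental.list_physical_devices("GPU")
41 | if gpus:  # pin each worker process to its own GPU
42 |     tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")
43 |
44 | optimizer = tf.keras.optimizers.Adam(1e-3 * hvd.size())  # scale lr by worker count
45 |
46 | @tf.function
47 | def train_step(model, features, labels):
48 |     with tf.GradientTape() as tape:
49 |         logits = model(features)
50 |         loss = tf.reduce_mean(
51 |             tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True))
52 |     tape = hvd.DistributedGradientTape(tape)  # all-reduce gradients across workers
53 |     grads = tape.gradient(loss, model.trainable_variables)
54 |     optimizer.apply_gradients(zip(grads, model.trainable_variables))
55 |     return loss
56 | ```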
--------------------------------------------------------------------------------
/docs/development/contributing.md:
--------------------------------------------------------------------------------
1 | # Contributing Guide
2 |
3 | ## License
4 |
5 | Every source file should contain a license header. See the existing files for an example.
6 |
7 | ## Name style
8 |
9 | All names in Python and C++ use [snake case style](https://en.wikipedia.org/wiki/Snake_case), except for `op` names for `Tensorflow`.
10 | For Golang, use Camel-Case for `variable name`s and `interface`s.
11 |
12 | ## Python style
13 |
14 | Changes to Python code should conform to the [Chromium Python Style Guide](https://chromium.googlesource.com/chromium/src/+/master/styleguide/python/python.md).
15 | You can use [yapf](https://github.com/google/yapf) to check the style.
16 | The style configuration is `.style.yapf`.
17 |
18 | ## C++ style
19 |
20 | Changes to C++ code should conform to [Google C++ Style Guide](https://github.com/google/styleguide).
21 | You can use [cpplint](https://github.com/google/styleguide/tree/gh-pages/cpplint) to check the style and use [clang-format](https://clang.llvm.org/docs/ClangFormat.html) to format the code.
22 | The style configuration is `.clang-format`.
23 |
24 | ## C++ macro
25 |
26 | C++ macros should start with `ATHENA_`, except for most common ones like `LOG` and `VLOG`.
27 |
28 | ## Golang style
29 |
30 | For Golang style, please see the docs below:
31 |
32 | * [How to Write Go Code](https://golang.org/doc/code.html)
33 | * [Effective Go](https://golang.org/doc/effective_go.html#interface-names)
34 | * [Go Code Review Comments](https://github.com/golang/go/wiki/CodeReviewComments)
35 | * [Golang Style in Chinese](https://juejin.im/post/5c16f16c5188252dcb30ff42)
36 |
37 | Before committing Golang code, please use `go fmt` and `go vet` to format and lint the code.
38 |
39 | ## Logging guideline
40 |
41 | For `python`, use [abseil-py](https://github.com/abseil/abseil-py) ([more info](https://abseil.io/docs/python/)).
42 |
43 | For C++, use [abseil-cpp](https://github.com/abseil/abseil-cpp) ([more info](https://abseil.io/docs/cpp/)).
44 |
45 | For Golang, use [glog](https://github.com/golang/glog).
46 |
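A minimal sketch of the Python logging convention, matching how scripts in this repository use `absl.logging`:

```python
from absl import logging

# Set the verbosity once at program start, then log as usual.
logging.set_verbosity(logging.INFO)
logging.info("processing subset %s", "train")
```
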
47 | ## Unit test
48 |
49 | For Python, use `tf.test.TestCase`.
50 | 
51 | For C++, use [googletest](https://github.com/google/googletest).
52 | 
53 | For Golang, use `go test`.
54 |
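A minimal `tf.test.TestCase` sketch (the class and test names are illustrative):

```python
import tensorflow as tf

class AddTest(tf.test.TestCase):
    """Checks a trivial op; assertAllClose compares tensors numerically."""
    def test_add(self):
        self.assertAllClose(tf.add(1.0, 2.0), 3.0)

if __name__ == "__main__":
    tf.test.main()
```
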
--------------------------------------------------------------------------------
/docs/transform/img/DCT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/DCT.png
--------------------------------------------------------------------------------
/docs/transform/img/DFT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/DFT.png
--------------------------------------------------------------------------------
/docs/transform/img/MFCC.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/MFCC.png
--------------------------------------------------------------------------------
/docs/transform/img/MelBank.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/MelBank.png
--------------------------------------------------------------------------------
/docs/transform/img/Mel_filter.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/Mel_filter.png
--------------------------------------------------------------------------------
/docs/transform/img/amplitude_spectrum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/amplitude_spectrum.png
--------------------------------------------------------------------------------
/docs/transform/img/audio_data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/audio_data.png
--------------------------------------------------------------------------------
/docs/transform/img/delta.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/delta.png
--------------------------------------------------------------------------------
/docs/transform/img/fbank.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/fbank.png
--------------------------------------------------------------------------------
/docs/transform/img/hamming.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/hamming.png
--------------------------------------------------------------------------------
/docs/transform/img/logMel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/logMel.png
--------------------------------------------------------------------------------
/docs/transform/img/logpower_spectrum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/logpower_spectrum.png
--------------------------------------------------------------------------------
/docs/transform/img/mel_freq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/mel_freq.png
--------------------------------------------------------------------------------
/docs/transform/img/melbanks.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/melbanks.png
--------------------------------------------------------------------------------
/docs/transform/img/phase_spectrum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/phase_spectrum.png
--------------------------------------------------------------------------------
/docs/transform/img/power_spectrum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/power_spectrum.png
--------------------------------------------------------------------------------
/docs/transform/img/spectrum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/spectrum.png
--------------------------------------------------------------------------------
/docs/transform/img/spectrum_emph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/spectrum_emph.png
--------------------------------------------------------------------------------
/docs/transform/img/spectrum_orig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/docs/transform/img/spectrum_orig.png
--------------------------------------------------------------------------------
/docs/transform/speech_feature.md:
--------------------------------------------------------------------------------
1 | # Introduction to Speech Features
2 | 
3 | ## Speech Data
4 | Before feature processing, we need to obtain the audio data and sample rate from a wav-format audio file. read_wav.py provides the ReadWav interface, used as follows:
5 | 
6 |     config = {'sample_rate': 16000.0}
7 |     wav_path = 'sm1_cln.wav'
8 |     read_wav = ReadWav.params(config).instantiate()
9 |     audio_data, sample_rate = read_wav(wav_path)
10 | 
11 | The audio data obtained is float-typed, with values in the range [-1, 1]. The original signal in the time domain looks like this:
12 | 
13 | 
14 | ## Short-Time Fourier Transform (STFT)
15 | Speech varies quickly and in a complex way in the time domain, so we usually use the Fourier transform to convert it to the frequency domain. The Fourier transform applied to a digital speech signal is the Short-Time Fourier Transform (STFT). The STFT generally consists of pre-emphasis (optional), framing, windowing, and the Fourier transform.
16 | ### Pre-emphasis
17 | Because the average power spectrum of a speech signal is affected by glottal excitation and mouth-nose radiation, the high-frequency band rolls off at about 6 dB/octave above roughly 800 Hz, so when computing the spectrum, the higher the frequency, the smaller the corresponding component. The purpose of pre-emphasis is to boost the high-frequency components and flatten the spectrum, so that the spectrum can be computed with the same signal-to-noise ratio over the whole band from low to high frequencies, which facilitates spectral analysis. Pre-emphasis is usually implemented with a first-order FIR high-pass filter whose transfer function is (a is the pre-emphasis coefficient):
18 | 
19 |     y(n) = x(n) - ax(n-1), 0.9 < a < 1.0
20 | The spectra of a speech signal before and after pre-emphasis are compared below (a=0.97).
21 | 
22 | 
23 | 
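A minimal NumPy sketch of the pre-emphasis filter defined above (the function name is illustrative):

```python
import numpy as np

def pre_emphasis(x, a=0.97):
    # y(n) = x(n) - a * x(n-1); the first sample is kept unchanged.
    return np.append(x[0], x[1:] - a * x[:-1])
```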
24 | ### Framing
25 | The Fourier transform requires a stationary signal, but speech is non-stationary at the macro level, so we first cut the speech signal into short fixed-length segments (framing); within a frame the signal can be regarded as stationary. The frame length must (1) be short enough that the signal within a frame is stationary, so a frame should be shorter than one phoneme [50-200ms], and (2) contain enough vibration periods: the fundamental frequency of speech is about 100Hz (T=10ms) for male voices and 200Hz (T=5ms) for female voices, so at least 20ms is usually taken. In summary, the frame length is generally 20-50ms, i.e. frame_length=0.020-0.050 (in seconds). Kaldi's default frame length is 0.025s, with a frame shift of 0.010s.
26 | ### Windowing
27 | Truncating a time-domain signal causes spectral leakage unless the truncation is periodic. To reduce the spectral leakage error, we use a weighting function, i.e. a window function, so that the time-domain signal better satisfies the periodicity requirement of the Fourier transform. Multiplying each frame by a smooth window function lets both ends of the frame decay smoothly to zero, which reduces the intensity of the side lobes after the Fourier transform and yields a higher-quality spectrum. Common window functions include the rectangular, Hann, and Hamming windows. The cost of windowing is that both ends of a frame are attenuated; to avoid losing information, frames are taken not back-to-back but with some overlap. The time difference between the start positions of two adjacent frames is the frame shift, usually 0-0.5 of the frame length, i.e. hop_length=(0-0.5) * frame_length.
28 | The common Hamming window is defined as follows (N is the total number of samples, n is an integer from 0 to N):
29 | 
30 | 
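An illustrative NumPy sketch of framing plus Hamming windowing with the typical 25ms/10ms settings (function and variable names are assumptions of this example):

```python
import numpy as np

def frame_and_window(x, sample_rate=16000, frame_length=0.025, hop_length=0.010):
    frame_size = int(frame_length * sample_rate)  # 400 samples at 16 kHz
    hop_size = int(hop_length * sample_rate)      # 160 samples at 16 kHz
    num_frames = 1 + (len(x) - frame_size) // hop_size
    frames = np.stack([x[i * hop_size : i * hop_size + frame_size]
                       for i in range(num_frames)])
    return frames * np.hamming(frame_size)        # taper both ends of each frame
```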
31 | ### Fast Fourier Transform (FFT)
32 | Since we work with digital audio, we need the Discrete Fourier Transform (DFT), defined as:
33 | 
34 | 
35 | 
36 | Because the DFT has high computational complexity, we usually use the Fast Fourier Transform (FFT). The number of FFT points (Nfft) is generally a power of 2 and must be at least the number of samples in one frame; if the frame length is 400 samples, Nfft must be at least 512. When the frame length does not equal Nfft, zero-padding can be used. The frequency resolution of the spectrum = sample_rate / Nfft, and the frequency corresponding to the m-th bin is f(m) = m × frequency resolution.
37 | For single-channel speech, we obtain a complex matrix with dimensions (num_Frequencies, num_frames), which means the speech signal is composed of num_Frequencies sine waves with different phases. For each time-frequency point, the absolute value of the FFT is the amplitude of the sine wave at that frequency and the phase is its initial phase, so the corresponding spectra are called the amplitude spectrum and the phase spectrum, as shown below.
38 | 
39 | 
40 | 
41 | ## Power Spectrum (Spectrum)
42 | After computing the STFT, we have the signal in the frequency domain. The energy in each frequency band differs, and different phonemes have different energy spectra. It is computed as:
43 | 
44 | 
45 | The log spectrum, logP = 10 × log(P), is also commonly used. The power spectrum and the log power spectrum are shown below.
46 | 
47 | 
48 | 
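A minimal sketch of the two steps above, computing the power spectrum of windowed frames under one common normalization convention (dividing by Nfft; the names are illustrative):

```python
import numpy as np

def power_spectrum(frames, nfft=512):
    # Zero-pad each frame to nfft points; rfft keeps the non-negative frequencies.
    spec = np.fft.rfft(frames, n=nfft)
    return (np.abs(spec) ** 2) / nfft
```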
49 | ## FilterBank
50 | The human ear's response to the sound spectrum is nonlinear: the ear resolves low-frequency sounds better than high-frequency ones. Experience shows that a front-end algorithm that processes audio in a way similar to the human ear can improve speech recognition performance. FilterBank is such an algorithm. Obtaining FBank features on top of the power spectrum requires Mel filtering and a log operation.
51 | ### Mel Filtering
52 | By converting frequency to Mel frequency, our features match human auditory perception better. The conversion between frequency f and Mel frequency m is:
53 | 
54 | #### Mel filter bank
55 | A Mel filter bank is a set of roughly 20-40 triangular filters (23 by default in Kaldi, 26 for MFCC); each triangular filter covers approximately one critical bandwidth of the human ear. The triangular windows can cover the whole frequency range from 0 to Nyquist, but we usually set upper and lower frequency limits to mask out frequency ranges that are unneeded or noisy. Mel filters come in two common forms: one where the center-frequency response is fixed at 1 and the filter's area varies with its bandwidth, and one where the height changes with the bandwidth so that the area stays constant. The latter is expressed as:
56 | 
57 | Here m denotes the m-th filter, k is the abscissa (the independent variable), and f(m) is the abscissa of the center of the m-th filter. The resulting filters look like this:
58 | 
59 | The Mel filter bank has two main roles: (1) it smooths the energy spectrum, removes harmonics, and highlights the formants of the speech; (2) it reduces the amount of computation.
60 | Filtering the power spectrum estimate from the previous step with the Mel filter bank yields a feature vector whose dimensionality equals the number of triangular filters, expressed as:
61 | 
62 | ### Log Operation
63 | This step takes the logarithm of the previous result, which amplifies energy differences in the low-energy regions. That is, the FBank features are
64 | 
65 | The FilterBank features look like this (upper frequency limit 8000Hz, lower limit 20Hz, feature dimension 23):
66 | 
67 | ## Mel-Frequency Cepstral Coefficients (MFCC)
68 | FBank features are already close to the response characteristics of the human ear, but they still have a shortcoming: adjacent FBank features are highly correlated (adjacent filter banks overlap), so when we model phonemes with an HMM we almost always first apply a discrete cosine transform (DCT), which yields the MFCC (Mel-scale Frequency Cepstral Coefficients) features. The essence of the DCT is to remove the correlation between signal dimensions and map the signal to a lower-dimensional space.
69 | The DCT is expressed as:
70 | 
71 | N is the FBank feature dimension and M is the feature dimension after the DCT. When the DCT is applied to typical speech signals, the first few coefficients of the result are particularly large and the later ones are small, so usually only the first 12-20 are kept, which also further compresses the data. The MFCC features look like this:
72 | 
73 | Comparing FBank and MFCC: (1) FBank features are more correlated, and DNN/CNN models can exploit this correlation well, so using FBank features can lower the WER further; (2) MFCC is more discriminative, and for GMMs with diagonal covariance matrices, which ignore correlations between feature dimensions, MFCC is the more suitable feature.
74 | ## Delta Features
75 | The standard cepstral parameters (MFCC) are extracted from a segment of speech and only reflect the static characteristics of the speech; the dynamic characteristics can be described by the difference spectrum of these static features. Experiments show that combining dynamic and static features effectively improves recognition performance. The delta parameters can be computed with the following formula (t is the number of frames, with a typical value of 2):
76 | 
77 | ## References
78 | https://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html
79 |
--------------------------------------------------------------------------------
/docs/transform/user_manual.md:
--------------------------------------------------------------------------------
1 | # User Manual
2 |
3 | ## Introduction
4 |
5 | Transform is a data preprocessing toolkit.
6 |
7 | ## Usage
8 |
9 | Transform supports speech feature extraction:
10 |
11 | ### 1. Import module
12 |
13 | ```python
14 | from athena.transform import AudioFeaturizer
15 | ```
16 |
17 | ### 2. Initialize a feature extractor object
18 |
19 | #### Read wav
20 |
21 | ```python
22 | conf = {'type':'ReadWav'}
23 | feature_ext = AudioFeaturizer(conf)
24 |
25 | '''
26 | Other default args:
27 | 'audio_channels':1
28 |
29 | The shape of the output:
30 | [T, 1, 1]
31 | '''
32 | ```
33 |
34 | #### Spectrum
35 |
36 | ```python
37 | conf = {'type':'Spectrum'}
38 | feature_ext = AudioFeaturizer(conf)
39 |
40 | '''
41 | Other default args:
42 | 'sample_rate' : 16000
43 | 'window_length' : 0.025
44 | 'frame_length' : 0.010
45 | 'global_mean': global mean
46 | 'global_variance': global variance
47 | 'local_cmvn' : default True, apply per-utterance cmvn
48 |
49 | The shape of the output:
50 | [T, dim, 1]
51 | '''
52 | ```
53 |
54 | #### FilterBank
55 |
56 | ```python
57 | conf = {'type':'Fbank'}
58 | feature_ext = AudioFeaturizer(conf)
59 |
60 | '''
61 | Other default args:
62 | 'sample_rate' : sample rate, 16000
63 | 'window_length' : window length, 0.025 s
64 | 'frame_length' : frame shift, 0.010 s
65 | 'upper_frequency_limit' : 4000
66 | 'lower_frequency_limit': 20
67 | 'filterbank_channel_count' : 40
68 | 'delta_delta' : whether to compute deltas, False
69 | 'window' : delta window size, 2
70 | 'order' : delta order, 2
71 | 'global_mean': global mean
72 | 'global_variance': global variance
73 | 'local_cmvn' : default True, apply per-utterance cmvn
74 |
75 | Returns:
76 | A tensor of shape [T, dim, num_channel].
77 | dim = 40
78 | num_channel = 1 if 'delta_delta' == False else 1 + 'order'
79 | '''
80 | ```
81 |
82 | #### CMVN
83 |
84 | ```python
85 | conf = {'type':'CMVN'}
86 | cmvn = AudioFeaturizer(conf)
87 |
88 | '''
89 | Other configuration
90 |
91 | 'global_mean': global mean
92 | 'global_variance': global variance
93 | 'local_cmvn' : default True
94 | 
95 | If 'global_mean' and 'global_variance' are both set, global cmvn is applied; otherwise it is skipped.
96 | Set 'local_cmvn' to False to disable per-utterance cmvn.
97 | '''
98 | ```
99 |
100 | ### 3. Extract features
101 |
102 | ```python
103 | feat = feature_ext(audio)
104 | '''
105 | audio : Audio file or audio data.
106 | feat : A tensor containing speech feature.
107 | '''
108 | ```
109 |
110 | ### 4. Get feature dim and the number of channels
111 |
112 | ```python
113 | dim = feature_ext.dim
114 | # dim : an int scalar, the feature dimension
115 | num_channels = feature_ext.num_channels
116 | # num_channels : an int scalar, the number of channels of the output feature
117 | # the shape of the output features: [None, dim, num_channels]
118 | ```
119 |
120 | ### 5. Example
121 |
122 | #### 5.1 Extract filterbank features from an audio file:
123 |
124 | ```python
125 | import tensorflow as tf
126 | from athena.transform import AudioFeaturizer
127 | 
128 | audio_file = 'english.wav'
129 | conf = {'type':'Fbank',
130 |         'sample_rate':8000,
131 |         'delta_delta': True
132 |        }
133 | feature_ext = AudioFeaturizer(conf)
134 | dim = feature_ext.dim
135 | print('Dim size is ', dim)
136 | num_channels = feature_ext.num_channels
137 | print('num_channels is ', num_channels)
138 | 
139 | feat = feature_ext(audio_file)
140 | # In eager mode the returned tensor can be converted to numpy directly.
141 | fbank = feat.numpy()
142 | print('Fbank shape is ', fbank.shape)
143 | ```
144 |
145 | Result:
146 |
147 | ```
148 | Dim size is 40
149 | Fbank shape is (346, 40, 3) # [T, D, C]
150 | ```
151 |
152 | #### 5.2 CMVN usage:
153 |
154 | ```python
155 | import numpy as np
156 | import tensorflow as tf
157 | from athena.transform import AudioFeaturizer
158 | dim = 23
159 | config = {'type': 'CMVN',
160 | 'global_mean': np.zeros(dim).tolist(),
161 | 'global_variance': np.ones(dim).tolist(),
162 | 'local_cmvn': True}
163 | cmvn = AudioFeaturizer(config)
164 | audio_feature = tf.compat.v1.random_uniform(shape=[10, dim], dtype=tf.float32, maxval=1.0)
165 | print('cmvn : ', cmvn(audio_feature))
166 | ```
167 |
--------------------------------------------------------------------------------
/docs/using_docker.md:
--------------------------------------------------------------------------------
1 | # Installation using Docker
2 |
3 | ## Install Docker
4 |
5 | Make sure `docker` has been installed. You can refer to the [official tutorial](https://docs.docker.com/install/).
6 |
7 | Install the NVIDIA Container Toolkit if you are using GPUs. You can refer to [nvidia-docker](https://github.com/NVIDIA/nvidia-docker).
8 |
9 | ## Docker image
10 |
11 | You can build the image locally or use the pre-built Docker images available on DockerHub.
12 |
13 | ### Pull Docker Image
14 |
15 | You can directly pull the pre-built docker images for athena. We have created the following docker image:
16 |
17 | [athena-tf-2.1.0-gpu-py3](https://hub.docker.com/repository/docker/garygao99/athena/)
18 |
19 | Download the image as follows:
20 | ```
21 | docker pull garygao99/athena:tf-2.1.0-gpu-py3
22 | ```
23 |
24 | ### Build images
25 |
26 | You can build the image locally as follows:
27 | ```
28 | docker build -t garygao99/athena:tf-2.1.0-gpu-py3 .
29 | ```
30 |
31 | ## Create Container
32 |
33 | After the image is downloaded, create a container.
34 |
35 | ```bash
36 | docker run -it --gpus all garygao99/athena:tf-2.1.0-gpu-py3 /bin/bash
37 | ```
38 |
--------------------------------------------------------------------------------
/examples/asr/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Examples for HKUST
3 |
4 | ## 1 Transformer
5 |
6 | ```bash
7 | source env.sh
8 | python examples/asr/hkust/local/prepare_data.py /tmp-data/dataset/opensource/hkust
9 | python athena/main.py examples/asr/hkust/transformer.json
10 | ```
11 |
12 | ## 2 MTL_Transformer_CTC
13 |
14 | ```bash
15 | source env.sh
16 | python examples/asr/hkust/local/prepare_data.py /tmp-data/dataset/opensource/hkust
17 | python athena/main.py examples/asr/hkust/mtl_transformer.json
18 | ```
19 |
20 | # Examples for Librispeech
21 |
22 | TODO: need test
23 |
--------------------------------------------------------------------------------
/examples/asr/aishell/mtl_transformer.json:
--------------------------------------------------------------------------------
1 | {
2 | "batch_size":32,
3 | "num_epochs":15,
4 | "sorta_epoch":1,
5 | "ckpt":"examples/asr/aishell/ckpts/mtl_transformer_ctc/",
6 | "summary_dir":"examples/asr/aishell/ckpts/mtl_transformer_ctc/event",
7 |
8 | "solver_gpu":[2],
9 | "solver_config":{
10 | "clip_norm":100,
11 | "log_interval":10,
12 | "enable_tf_function":true
13 | },
14 |
15 | "model":"mtl_transformer_ctc",
16 | "num_classes": null,
17 | "pretrained_model": null,
18 | "model_config":{
19 | "model":"speech_transformer",
20 | "model_config":{
21 | "return_encoder_output":true,
22 | "num_filters":512,
23 | "d_model":512,
24 | "num_heads":8,
25 | "num_encoder_layers":12,
26 | "num_decoder_layers":6,
27 | "dff":1280,
28 | "rate":0.1,
29 | "label_smoothing_rate":0.0,
30 | "schedual_sampling_rate":0.9
31 | },
32 | "mtl_weight":0.5
33 | },
34 |
35 | "decode_config":{
36 | "beam_search":true,
37 | "beam_size":4,
38 | "ctc_weight":0.3,
39 | "lm_weight":0.1,
40 | "lm_path":"examples/asr/aishell/data/lm.bin"
41 | },
42 |
43 | "optimizer":"warmup_adam",
44 | "optimizer_config":{
45 | "d_model":512,
46 | "warmup_steps":8000,
47 | "k":0.5,
48 | "decay_steps": 50000,
49 | "decay_rate": 0.1
50 | },
51 |
52 | "dataset_builder": "speech_recognition_dataset",
53 | "dataset_config":{
54 | "audio_config":{
55 | "type":"Fbank",
56 | "filterbank_channel_count":40,
57 | "local_cmvn":false
58 | },
59 | "cmvn_file":"examples/asr/aishell/data/cmvn",
60 | "text_config": {
61 | "type":"vocab",
62 | "model":"examples/asr/aishell/data/vocab"
63 | },
64 | "input_length_range":[10, 8000]
65 | },
66 | "num_data_threads": 1,
67 | "train_csv":"examples/asr/aishell/data/train.csv",
68 | "dev_csv":"examples/asr/aishell/data/dev.csv",
69 | "test_csv":"examples/asr/aishell/data/test.csv"
70 | }
71 |
--------------------------------------------------------------------------------
/examples/asr/aishell/mtl_transformer_sp.json:
--------------------------------------------------------------------------------
1 | {
2 | "batch_size":32,
3 | "num_epochs":8,
4 | "sorta_epoch":2,
5 | "ckpt":"examples/asr/aishell/ckpts/mtl_transformer_ctc/",
6 | "summary_dir":"examples/asr/aishell/ckpts/mtl_transformer_ctc/event",
7 |
8 | "solver_gpu":[2],
9 | "solver_config":{
10 | "clip_norm":100,
11 | "log_interval":10,
12 | "enable_tf_function":true
13 | },
14 |
15 | "model":"mtl_transformer_ctc",
16 | "num_classes": null,
17 | "pretrained_model": null,
18 | "model_config":{
19 | "model":"speech_transformer",
20 | "model_config":{
21 | "return_encoder_output":true,
22 | "num_filters":512,
23 | "d_model":512,
24 | "num_heads":8,
25 | "num_encoder_layers":12,
26 | "num_decoder_layers":6,
27 | "dff":1280,
28 | "rate":0.1,
29 | "label_smoothing_rate":0.0,
30 | "schedual_sampling_rate":0.9
31 | },
32 | "mtl_weight":0.5
33 | },
34 |
35 | "decode_config":{
36 | "beam_search":true,
37 | "beam_size":4,
38 | "ctc_weight":0.3,
39 | "lm_weight":0.1,
40 | "lm_path":"examples/asr/aishell/data/lm.bin"
41 | },
42 |
43 | "optimizer":"warmup_adam",
44 | "optimizer_config":{
45 | "d_model":512,
46 | "warmup_steps":3500,
47 | "k":0.7,
48 | "decay_steps": 55050,
49 | "decay_rate": 0.1
50 | },
51 |
52 | "dataset_builder": "speech_recognition_dataset",
53 | "dataset_config":{
54 | "audio_config":{
55 | "type":"Fbank",
56 | "filterbank_channel_count":40,
57 | "local_cmvn":false
58 | },
59 | "cmvn_file":"examples/asr/aishell/data/cmvn",
60 | "text_config": {
61 | "type":"vocab",
62 | "model":"examples/asr/aishell/data/vocab"
63 | },
64 | "input_length_range":[10, 8000],
65 | "speed_permutation": [0.9, 1.0, 1.1]
66 | },
67 | "num_data_threads": 1,
68 | "train_csv":"examples/asr/aishell/data/train.csv",
69 | "dev_csv":"examples/asr/aishell/data/dev.csv",
70 | "test_csv":"examples/asr/aishell/data/test.csv"
71 | }
72 |
--------------------------------------------------------------------------------
/examples/asr/aishell/rnnlm.json:
--------------------------------------------------------------------------------
1 | {
2 | "batch_size":32,
3 | "num_epochs":16,
4 | "sorta_epoch":2,
5 | "ckpt":"examples/asr/aishell/ckpts/rnnlm",
6 |
7 | "solver_gpu":[1],
8 | "solver_config":{
9 | "clip_norm":100,
10 | "log_interval":10,
11 | "enable_tf_function":true
12 | },
13 |
14 |
15 | "model":"rnnlm",
16 | "num_classes": null,
17 | "pretrained_model": null,
18 | "model_config":{
19 | "d_model": 512,
20 | "rnn_type": "lstm",
21 | "num_layer": 4,
22 | "dropout_rate": 0.1,
23 | "sos": -1,
24 | "eos": -1
25 | },
26 |
27 | "optimizer":"warmup_adam",
28 | "optimizer_config":{
29 | "d_model":512,
30 | "warmup_steps":3500,
31 | "k":0.7,
32 | "decay_steps": 56296,
33 | "decay_rate": 0.1
34 | },
35 |
36 | "dataset_builder": "language_dataset",
37 | "dataset_config":{
38 | "input_text_config":{
39 | "type":"vocab",
40 | "model":"examples/asr/aishell/data/vocab"
41 | },
42 | "output_text_config":{
43 | "type":"vocab",
44 | "model":"examples/asr/aishell/data/vocab"
45 | }
46 | },
47 | "num_data_threads": 1,
48 | "train_csv":"examples/asr/aishell/data/train.trans.csv",
49 | "dev_csv":"examples/asr/aishell/data/dev.trans.csv",
50 | "test_csv":"examples/asr/aishell/data/test.trans.csv"
51 | }
52 |
53 |
--------------------------------------------------------------------------------
/examples/asr/hkust/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Examples for HKUST
3 |
4 | ## 1 Transformer
5 |
6 | ```bash
7 | source env.sh
8 | python examples/asr/hkust/local/prepare_data.py /tmp-data/dataset/opensource/hkust
9 | python athena/main.py examples/asr/hkust/transformer.json
10 | ```
11 |
12 | ## 2 MTL_Transformer_CTC
13 |
14 | ```bash
15 | source env.sh
16 | python examples/asr/hkust/local/prepare_data.py /tmp-data/dataset/opensource/hkust
17 | python athena/main.py examples/asr/hkust/mtl_transformer.json
18 | ```
19 |
--------------------------------------------------------------------------------
/examples/asr/hkust/deep_speech.json:
--------------------------------------------------------------------------------
1 | {
2 | "batch_size":32,
3 | "num_epochs":20,
4 | "sorta_epoch":1,
5 | "ckpt":"examples/asr/hkust/ckpts/deep_speech",
6 | "solver_gpu":[2],
7 | "solver_config":{
8 | "clip_norm":100.0,
9 | "log_interval":1
10 | },
11 |
12 | "model":"deep_speech",
13 | "num_classes": null,
14 | "pretrained_model": null,
15 | "model_config":{
16 | "conv_filters":64,
17 | "rnn_hidden_size":1680,
18 | "rnn_type":"cudnngru",
19 | "num_rnn_layers":6
20 | },
21 |
22 | "optimizer":"warmup_adam",
23 | "optimizer_config":{
24 | "d_model":512,
25 | "warmup_steps":8000,
26 | "k":0.5
27 | },
28 |
29 | "dataset_builder": "speech_recognition_dataset",
30 | "dataset_config":{
31 | "audio_config":{
32 | "type":"Fbank",
33 | "filterbank_channel_count":40,
34 | "local_cmvn":false
35 | },
36 | "cmvn_file":"examples/asr/hkust/data/cmvn",
37 | "text_config": {
38 | "type":"vocab",
39 | "model":"examples/asr/hkust/data/vocab"
40 | },
41 | "input_length_range":[10, 8000]
42 | },
43 | "num_data_threads": 1,
44 | "train_csv":"examples/asr/hkust/data/train.csv",
45 | "dev_csv":"examples/asr/hkust/data/dev.csv",
46 | "test_csv":"examples/asr/hkust/data/dev.csv"
47 | }
--------------------------------------------------------------------------------
/examples/asr/hkust/local/segment_word.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) ATHENA AUTHORS
3 | # All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """ segment word """
18 |
19 | import codecs
20 | import sys
21 | import re
22 | from absl import logging
23 | import jieba
24 |
25 |
26 | def segment_trans(vocab_file, text_file):
27 | ''' segment transcripts according to vocab
28 | using Maximum Matching Algorithm
29 | Args:
30 | vocab_file: vocab file
31 | text_file: transcripts file
32 | Returns:
33 | seg_trans: segment words
34 | '''
35 | jieba.set_dictionary(vocab_file)
36 | with open(text_file, "r", encoding="utf-8") as text:
37 | lines = text.readlines()
38 | sents = ''
39 | for line in lines:
40 | seg_line = jieba.cut(line.strip(), HMM=False)
41 | seg_line = ' '.join(seg_line)
42 | sents += seg_line + '\n'
43 | return sents
44 |
45 |
46 | if __name__ == "__main__":
47 | logging.set_verbosity(logging.INFO)
48 | sys.stdout = codecs.getwriter("utf-8")(sys.stdout.detach())
49 | if len(sys.argv) < 3:
50 | logging.warning('Usage: python {} vocab_file text_file'.format(sys.argv[0]))
51 | sys.exit()
52 | print(segment_trans(sys.argv[1], sys.argv[2]))
53 |
--------------------------------------------------------------------------------
/examples/asr/hkust/mpc.json:
--------------------------------------------------------------------------------
1 | {
2 | "batch_size":64,
3 | "num_epochs":60,
4 | "sorta_epoch":2,
5 | "ckpt":"examples/asr/hkust/ckpts/mpc",
6 | "summary_dir":"examples/asr/hkust/ckpts/mpc/event",
7 |
8 | "solver_gpu":[0],
9 | "solver_config":{
10 | "clip_norm":100,
11 | "log_interval":10,
12 | "enable_tf_function":true
13 | },
14 |
15 | "model": "mpc",
16 | "num_classes": 40,
17 | "model_config":{
18 | "return_encoder_output":false,
19 | "num_filters":512,
20 | "d_model":512,
21 | "num_heads":8,
22 | "num_encoder_layers":12,
23 | "dff":1280,
24 | "rate":0.1,
25 | "chunk_size":1,
26 | "keep_probability":0.8
27 | },
28 |
29 | "optimizer": "warmup_adam",
30 | "optimizer_config":{
31 | "d_model":512,
32 | "warmup_steps":5000,
33 | "k":0.5
34 | },
35 |
36 | "dataset_builder":"speech_dataset",
37 | "dataset_config":{
38 | "audio_config":{
39 | "type":"Fbank",
40 | "filterbank_channel_count":40,
41 | "local_cmvn":false
42 | },
43 | "cmvn_file":"examples/asr/hkust/data/cmvn",
44 | "input_length_range":[10, 8000]
45 | },
46 | "num_data_threads": 1,
47 | "train_csv":"examples/asr/hkust/data/train.csv",
48 | "dev_csv":"examples/asr/hkust/data/dev.csv"
49 | }
50 |
--------------------------------------------------------------------------------
/examples/asr/hkust/mtl_transformer.json:
--------------------------------------------------------------------------------
1 | {
2 | "batch_size":32,
3 | "num_epochs":8,
4 | "sorta_epoch":2,
5 | "ckpt":"examples/asr/hkust/ckpts/mtl_transformer_ctc/",
6 | "summary_dir":"examples/asr/hkust/ckpts/mtl_transformer_ctc/event",
7 |
8 | "solver_gpu":[0],
9 | "solver_config":{
10 | "clip_norm":100,
11 | "log_interval":10,
12 | "enable_tf_function":true
13 | },
14 |
15 | "model":"mtl_transformer_ctc",
16 | "num_classes": null,
17 | "pretrained_model": "examples/asr/hkust/mpc.json",
18 | "model_config":{
19 | "model":"speech_transformer",
20 | "model_config":{
21 | "return_encoder_output":true,
22 | "num_filters":512,
23 | "d_model":512,
24 | "num_heads":8,
25 | "num_encoder_layers":12,
26 | "num_decoder_layers":6,
27 | "dff":1280,
28 | "rate":0.1,
29 | "label_smoothing_rate":0.0,
30 | "schedual_sampling_rate":0.9
31 | },
32 | "mtl_weight":0.5
33 | },
34 |
35 | "decode_config":{
36 | "beam_search":true,
37 | "beam_size":4,
38 | "ctc_weight":0.3,
39 | "lm_weight":0.1,
40 | "lm_path":"examples/asr/hkust/data/4gram.arpa"
41 | },
42 |
43 | "optimizer":"warmup_adam",
44 | "optimizer_config":{
45 | "d_model":512,
46 | "warmup_steps":8000,
47 | "k":0.5,
48 | "decay_steps":22000,
49 | "decay_rate":0.1
50 | },
51 |
52 | "dataset_builder": "speech_recognition_dataset",
53 | "dataset_config":{
54 | "audio_config":{
55 | "type":"Fbank",
56 | "filterbank_channel_count":40,
57 | "local_cmvn":false
58 | },
59 | "cmvn_file":"examples/asr/hkust/data/cmvn",
60 | "text_config": {
61 | "type":"vocab",
62 | "model":"examples/asr/hkust/data/vocab"
63 | },
64 | "speed_permutation":[0.9, 1.0, 1.1],
65 | "input_length_range":[10, 8000]
66 | },
67 | "num_data_threads": 1,
68 | "train_csv":"examples/asr/hkust/data/train.csv",
69 | "dev_csv":"examples/asr/hkust/data/dev.csv",
70 | "test_csv":"examples/asr/hkust/data/dev.csv"
71 | }
72 |
--------------------------------------------------------------------------------
/examples/asr/hkust/mtl_transformer_sp.json:
--------------------------------------------------------------------------------
1 | {
2 | "batch_size":32,
3 | "num_epochs":15,
4 | "sorta_epoch":2,
5 | "ckpt":"examples/asr/hkust/ckpts/mtl_transformer_ctc_sp/",
6 | "summary_dir":"examples/asr/hkust/ckpts/mtl_transformer_ctc_sp/event",
7 |
8 | "solver_gpu":[0],
9 | "solver_config":{
10 | "clip_norm":100,
11 | "log_interval":10,
12 | "enable_tf_function":true
13 | },
14 |
15 | "model":"mtl_transformer_ctc",
16 | "num_classes": null,
17 | "pretrained_model": null,
18 | "model_config":{
19 | "model":"speech_transformer",
20 | "model_config":{
21 | "return_encoder_output":true,
22 | "num_filters":512,
23 | "d_model":512,
24 | "num_heads":8,
25 | "num_encoder_layers":12,
26 | "num_decoder_layers":6,
27 | "dff":1280,
28 | "rate":0.1,
29 | "label_smoothing_rate":0.0,
30 | "schedual_sampling_rate":0.9
31 | },
32 | "mtl_weight":0.5
33 | },
34 |
35 | "decode_config":{
36 | "beam_search":true,
37 | "beam_size":4,
38 | "ctc_weight":0.3,
39 | "lm_weight":0.1,
40 | "lm_path":"examples/asr/hkust/data/4gram.arpa"
41 | },
42 |
43 | "optimizer":"warmup_adam",
44 | "optimizer_config":{
45 | "d_model":512,
46 | "warmup_steps":8000,
47 | "k":0.5,
48 | "decay_steps": 140000,
49 | "decay_rate": 0.1
50 | },
51 |
52 | "dataset_builder": "speech_recognition_dataset",
53 | "dataset_config":{
54 | "audio_config":{
55 | "type":"Fbank",
56 | "filterbank_channel_count":40,
57 | "local_cmvn":false
58 | },
59 | "cmvn_file":"examples/asr/hkust/data/cmvn",
60 | "text_config": {
61 | "type":"vocab",
62 | "model":"examples/asr/hkust/data/vocab"
63 | },
64 | "input_length_range":[10, 8000],
65 | "speed_permutation": [0.9, 1.0, 1.1]
66 | },
67 | "num_data_threads": 1,
68 | "train_csv":"examples/asr/hkust/data/train.csv",
69 | "dev_csv":"examples/asr/hkust/data/dev.csv",
70 | "test_csv":"examples/asr/hkust/data/dev.csv"
71 | }
72 |
--------------------------------------------------------------------------------
/examples/asr/hkust/rnnlm.json:
--------------------------------------------------------------------------------
1 | {
2 | "batch_size":32,
3 | "num_epochs":100,
4 | "sorta_epoch":0,
5 | "ckpt":"examples/asr/hkust/ckpts/rnnlm",
6 |
7 | "solver_gpu":[0],
8 | "solver_config":{
9 | "clip_norm":100,
10 | "log_interval":10,
11 | "enable_tf_function":true
12 | },
13 |
14 |
15 | "model":"rnnlm",
16 | "num_classes": null,
17 | "pretrained_model": null,
18 | "model_config":{
19 | "d_model": 512,
20 | "rnn_type": "lstm",
21 | "num_layer": 4,
22 | "dropout_rate": 0.1,
23 | "sos": -1,
24 | "eos": -1
25 | },
26 |
27 | "optimizer":"warmup_adam",
28 | "optimizer_config":{
29 | "d_model":512,
30 | "warmup_steps":8000,
31 | "k":0.5
32 | },
33 |
34 | "dataset_builder": "language_dataset",
35 | "dataset_config":{
36 | "input_text_config":{
37 | "type":"vocab",
38 | "model":"examples/asr/hkust/data/vocab"
39 | },
40 | "output_text_config":{
41 | "type":"vocab",
42 | "model":"examples/asr/hkust/data/vocab"
43 | }
44 | },
45 | "num_data_threads": 1,
46 | "train_csv":"examples/asr/hkust/data/train.trans.csv",
47 | "dev_csv":"examples/asr/hkust/data/dev.trans.csv",
48 | "test_csv":"examples/asr/hkust/data/dev.trans.csv"
49 | }
--------------------------------------------------------------------------------
/examples/asr/hkust/run.sh:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) ATHENA AUTHORS
3 | # All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 |
18 | if [ "athena" != $(basename "$PWD") ]; then
19 | echo "You should run this script in athena directory!!"
20 | exit 1
21 | fi
22 |
23 | source tools/env.sh
24 |
25 | stage=0
26 | stop_stage=100
27 |
28 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
29 | # prepare data
30 | echo "Preparing data"
31 | python examples/asr/hkust/local/prepare_data.py /nfs/project/datasets/opensource_data/hkust
32 | mkdir -p examples/asr/hkust/data
33 | cp /nfs/project/datasets/opensource_data/hkust/{train,dev}.csv examples/asr/hkust/data/
34 |
35 | # cal cmvn
36 | cat examples/asr/hkust/data/train.csv > examples/asr/hkust/data/all.csv
37 | tail -n +2 examples/asr/hkust/data/dev.csv >> examples/asr/hkust/data/all.csv
38 | python athena/cmvn_main.py examples/asr/hkust/mpc.json examples/asr/hkust/data/all.csv
39 | fi
40 |
41 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
42 | # pretrain stage
43 | echo "Pretraining"
44 | # we recommend training with multi-gpu, for single gpu, run "python athena/main.py examples/asr/hkust/mpc.json" instead
45 | horovodrun -np 4 -H localhost:4 python athena/horovod_main.py examples/asr/hkust/mpc.json
46 | fi
47 |
48 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
49 | # finetuning stage
50 | echo "Fine-tuning"
51 | # we recommend training with multi-gpu, for single gpu, run "python athena/main.py examples/asr/hkust/mtl_transformer.json" instead
52 | horovodrun -np 4 -H localhost:4 python athena/horovod_main.py examples/asr/hkust/mtl_transformer.json
53 | fi
54 |
55 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
56 | # decoding stage
57 | echo "Decoding"
58 | # prepare language model
59 | tail -n +2 examples/asr/hkust/data/train.csv | cut -f 3 > examples/asr/hkust/data/text
60 | python examples/asr/hkust/local/segment_word.py examples/asr/hkust/data/vocab \
61 | examples/asr/hkust/data/text > examples/asr/hkust/data/text.seg
62 | tools/kenlm/build/bin/lmplz -o 4 < examples/asr/hkust/data/text.seg > examples/asr/hkust/data/4gram.arpa
63 |
64 | python athena/decode_main.py examples/asr/hkust/mtl_transformer.json
65 | fi
66 |
67 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
68 | echo "training rnnlm"
69 | tail -n +2 examples/asr/hkust/data/train.csv | awk '{print $3"\t"$3}' > examples/asr/hkust/data/train.trans.csv
70 | tail -n +2 examples/asr/hkust/data/dev.csv | awk '{print $3"\t"$3}' > examples/asr/hkust/data/dev.trans.csv
71 | python athena/main.py examples/asr/hkust/rnnlm.json
72 | fi
73 |
74 |
--------------------------------------------------------------------------------
/examples/asr/hkust/transformer.json:
--------------------------------------------------------------------------------
1 | {
2 | "batch_size":32,
3 | "num_epochs":20,
4 | "sorta_epoch":1,
5 | "ckpt":"examples/asr/hkust/ckpts/transformer",
6 |
7 | "solver_gpu":[0],
8 | "solver_config":{
9 | "clip_norm":100,
10 | "log_interval":10,
11 | "enable_tf_function":true
12 | },
13 |
14 |
15 | "model":"speech_transformer",
16 | "num_classes": null,
17 | "pretrained_model": null,
18 | "model_config":{
19 | "return_encoder_output":false,
20 | "num_filters":512,
21 | "d_model":512,
22 | "num_heads":8,
23 | "num_encoder_layers":12,
24 | "num_decoder_layers":6,
25 | "dff":1280,
26 | "rate":0.1,
27 | "label_smoothing_rate":0.0
28 | },
29 |
30 | "optimizer":"warmup_adam",
31 | "optimizer_config":{
32 | "d_model":512,
33 | "warmup_steps":8000,
34 | "k":0.5
35 | },
36 |
37 | "dataset_builder": "speech_recognition_dataset",
38 | "dataset_config":{
39 | "audio_config":{
40 | "type":"Fbank",
41 | "filterbank_channel_count":40,
42 | "local_cmvn":false
43 | },
44 | "text_config": {
45 | "type":"vocab",
46 | "model":"examples/asr/hkust/data/vocab"
47 | },
48 | "cmvn_file":"examples/asr/hkust/data/cmvn",
49 | "input_length_range":[10, 5000]
50 | },
51 | "num_data_threads": 1,
52 | "train_csv":"examples/asr/hkust/data/train.csv",
53 | "dev_csv":"examples/asr/hkust/data/dev.csv",
54 | "test_csv":"examples/asr/hkust/data/dev.csv"
55 | }
--------------------------------------------------------------------------------
/examples/asr/librispeech/data/librispeech_unigram5000.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/didi/athena/b11aea228b8f54430f0c43df7a20259c54691aee/examples/asr/librispeech/data/librispeech_unigram5000.model
--------------------------------------------------------------------------------
/examples/asr/librispeech/transformer.json:
--------------------------------------------------------------------------------
1 | {
2 | "batch_size":16,
3 | "num_epochs":50,
4 | "sorta_epoch":2,
5 | "ckpt":"examples/asr/librispeech/ckpts/debug",
6 | "solver_gpu":[0],
7 | "solver_config":{
8 | "clip_norm":100.0,
9 | "log_interval":1
10 | },
11 |
12 | "model":"speech_transformer",
13 | "model_config":{
14 | "return_encoder_output":false,
15 | "num_filters":512,
16 | "d_model":512,
17 | "num_heads":8,
18 | "num_encoder_layers":12,
19 | "num_decoder_layers":6,
20 | "dff":1280,
21 | "rate":0.1,
22 | "label_smoothing_rate":0.0,
23 | "schedual_sampling_rate":0.9
24 | },
25 |
26 | "optimizer":"warmup_adam",
27 | "optimizer_config":{
28 | "d_model":512,
29 | "warmup_steps":8000,
30 | "k":0.5
31 | },
32 |
33 | "trainset_config":{
34 | "data_csv":"/tmp-data/dataset/opensource/librispeech/wav/train-clean-100.csv",
35 | "audio_config":{
36 | "type":"Fbank",
37 | "filterbank_channel_count":40,
38 | "local_cmvn":false
39 | },
40 | "vocab_file":"examples/asr/hkust/data/librispeech_unigram5000.model",
41 | "subword":true,
42 | "audio_min_length":10,
43 | "audio_max_length":10000,
44 | "force_process":false
45 | },
46 | "cmvn_file":"examples/asr/librispeech/data/cmvn",
47 | "dev_csv":"/tmp-data/dataset/opensource/librispeech/wav/dev-clean.csv",
48 | "test_csv":"/tmp-data/dataset/opensource/librispeech/wav/test-clean.csv"
49 | }
50 |
--------------------------------------------------------------------------------
/examples/asr/magic_data/local/prepare_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """ magic_data dataset """
17 |
18 | import os
19 | import sys
20 | import codecs
21 | import pandas
22 | from absl import logging
23 |
24 | import tensorflow as tf
25 | from athena import get_wave_file_length
26 |
27 | SUBSETS = ["train", "dev"]
28 |
29 | def convert_audio_and_split_transcript(directory, subset, out_csv_file):
30 |
31 | gfile = tf.compat.v1.gfile
32 | logging.info("Processing audio and transcript for {}".format(subset))
33 | audio_dir = os.path.join(directory, subset)
34 | trans_dir = os.path.join(directory, subset, "TRANS.txt")
35 |
36 | files = []
37 | with codecs.open(trans_dir,'r', encoding='utf-8') as f:
38 | lines = f.readlines()
39 | for line in lines[1:]:
40 | items = line.strip().split('\t')
41 | wav_filename = items[0]
42 | labels = items[2]
43 | speaker = items[1]
44 | files.append((wav_filename, speaker, labels))
45 | files_size_dict = {}
46 | for root, subdirs, _ in gfile.Walk(audio_dir):
47 | for subdir in subdirs:
48 | for filename in os.listdir(os.path.join(root, subdir)):
49 | files_size_dict[filename] = (
50 | get_wave_file_length(os.path.join(root, subdir, filename)),
51 | subdir,
52 | )
53 | content = []
54 | for wav_filename, speaker, trans in files:
55 | if wav_filename in files_size_dict:
56 | filesize, subdir = files_size_dict[wav_filename]
57 | abspath = os.path.join(audio_dir, subdir, wav_filename)
58 | content.append((abspath, filesize, trans, speaker))
59 |
60 | files = content
61 | df = pandas.DataFrame(
62 | data=files, columns=["wav_filename", "wav_length_ms", "transcript", "speaker"]
63 | )
64 | df.to_csv(out_csv_file, index=False, sep="\t")
65 | logging.info("Successfully generated csv file {}".format(out_csv_file))
66 |
67 | def processor(directory, subset, force_process):
68 | """ download and process """
69 | if subset not in SUBSETS:
70 | raise ValueError(subset, "is not in magic_data")
71 | if force_process:
72 | logging.info("force process is set to be true")
73 |
74 |     subset_csv = os.path.join(directory, subset + ".csv")
75 |     if not force_process and os.path.exists(subset_csv):
76 |         logging.info("{} already exists".format(subset_csv))
77 |         return subset_csv
78 |     logging.info("Processing the magic_data subset {} in {}".format(subset, directory))
79 |     convert_audio_and_split_transcript(directory, subset, subset_csv)
80 | logging.info("Finished processing magic_data subset {}".format(subset))
81 | return subset_csv
82 |
83 | if __name__ == "__main__":
84 | logging.set_verbosity(logging.INFO)
85 | DIR = sys.argv[1]
86 | for SUBSET in SUBSETS:
87 | processor(DIR, SUBSET, True)
88 |
89 |
--------------------------------------------------------------------------------
/examples/asr/primewords/local/prepare_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """ primewords dataset """
17 |
18 | import os
19 | import sys
20 | import codecs
21 | import pandas
22 | from absl import logging
23 |
24 | import tensorflow as tf
25 | from athena import get_wave_file_length
26 |
27 | def convert_audio_and_split_transcript(directory, out_csv_file):
28 |
29 | gfile = tf.compat.v1.gfile
30 | logging.info("Processing audio and transcript for {}".format('primewords'))
31 | audio_dir = os.path.join(directory, "audio_files")
32 | trans_dir = os.path.join(directory, "set1_transcript.json")
33 |
34 | files = []
35 | with codecs.open(trans_dir, 'r', encoding='utf-8') as f:
36 | items = eval(f.readline())
37 | for item in items:
38 | wav_filename = item['file']
39 | # labels = item['text']
40 |
41 | labels = ''.join([x for x in item['text'].split()])
42 | # speaker = item['user_id']
43 | files.append((wav_filename, labels))
44 |
45 |     files_size_dict = {}
46 |     for subdir in os.listdir(audio_dir):
47 |         for root, subsubdirs, _ in gfile.Walk(os.path.join(audio_dir, subdir)):
48 |             for subsubdir in subsubdirs:
49 |                 for filename in os.listdir(os.path.join(root, subsubdir)):
50 |                     # store the wave length (ms) plus the location of the file
51 |                     files_size_dict[filename] = (
52 |                         get_wave_file_length(os.path.join(root, subsubdir, filename)),
53 |                         root,
54 |                         subsubdir
55 |                     )
55 | # print(os.path.join(root,subsubdir,filename))
56 | content = []
57 | for wav_filename, trans in files:
58 | if wav_filename in files_size_dict:
59 | filesize, root, subsubdir = files_size_dict[wav_filename]
60 | abspath = os.path.join(root, subsubdir, wav_filename)
61 | content.append((abspath, filesize, trans, None))
62 | files = content
63 | df = pandas.DataFrame(
64 | data=files, columns=["wav_filename", "wav_length_ms", "transcript", "speaker"]
65 | )
66 |
67 | df.to_csv(out_csv_file, index=False, sep="\t")
68 | logging.info("Successfully generated csv file {}".format(out_csv_file))
69 |
70 | def processor(directory, force_process):
71 |     if force_process:
72 |         logging.info("force process is set to be true")
73 | 
74 |     subset_csv = os.path.join(directory, 'train' + ".csv")
75 |     if not force_process and os.path.exists(subset_csv):
76 |         logging.info("{} already exists".format(subset_csv))
77 |         return subset_csv
78 |     logging.info("Processing the primewords subset {} in {}".format('train', directory))
79 |     convert_audio_and_split_transcript(directory, subset_csv)
80 | logging.info("Finished processing primewords subset {}".format('train'))
81 | return subset_csv
82 |
83 | if __name__ == '__main__':
84 |
85 | logging.set_verbosity(logging.INFO)
86 | DIR = sys.argv[1]
87 | processor(DIR,True)
88 |
89 |
--------------------------------------------------------------------------------
/examples/translate/spa-eng-example/prepare_data.py:
--------------------------------------------------------------------------------
1 | # Only support eager mode
2 | # pylint: disable=invalid-name
3 |
4 | """ An example using translation dataset """
5 | import os
6 | import re
7 | import io
8 | import sys
9 | import unicodedata
10 | import pandas
11 | import tensorflow as tf
12 | from absl import logging
13 | from athena import LanguageDatasetBuilder
14 |
15 |
16 | def preprocess_sentence(w):
17 | """ preprocess_sentence """
18 | w = ''.join(c for c in unicodedata.normalize('NFD', w.lower().strip())
19 | if unicodedata.category(c) != 'Mn')
20 |     w = re.sub(r"([?.!,¿])", r" \1 ", w)
21 |     w = re.sub(r'[" "]+', " ", w)
22 |     w = re.sub(r"[^a-zA-Z?.!,¿]+", " ", w)
23 | w = w.rstrip().strip()
24 | return w
25 |
26 | def create_dataset(csv_file="examples/translate/spa-eng-example/data/train.csv"):
27 | """ create and store in csv file """
28 | path_to_zip = tf.keras.utils.get_file('spa-eng.zip',
29 | origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',
30 | extract=True)
31 | path_to_file = os.path.dirname(path_to_zip)+"/spa-eng/spa.txt"
32 | lines = io.open(path_to_file, encoding='UTF-8').read().strip().split('\n')
33 | word_pairs = [[preprocess_sentence(w) for w in l.split('\t')] for l in lines[:None]]
34 | df = pandas.DataFrame(
35 | data=word_pairs, columns=["input_transcripts", "output_transcripts"]
36 | )
37 | csv_dir = os.path.dirname(csv_file)
38 | if not os.path.exists(csv_dir):
39 | os.mkdir(csv_dir)
40 | df.to_csv(csv_file, index=False, sep="\t")
41 | logging.info("Successfully generated csv file {}".format(csv_file))
42 |
43 | if __name__ == "__main__":
44 | logging.set_verbosity(logging.INFO)
45 | CSV_FILE = sys.argv[1]
46 | create_dataset(CSV_FILE)
47 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==2.0.3
2 | sox
3 | absl-py
4 | yapf
5 | pylint
6 | flake8
7 | horovod
8 | tqdm
9 | sentencepiece
10 | librosa
11 | kenlm
12 | jieba
13 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) ATHENA AUTHORS
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | import sys
18 | from glob import glob
19 | import shutil
20 | import setuptools
21 | import tensorflow as tf
22 |
23 | TF_INCLUDE, TF_CFLAG = tf.sysconfig.get_compile_flags()
24 | TF_INCLUDE = TF_INCLUDE.split("-I")[1]
25 |
26 | TF_LIB_INC, TF_SO_LIB = tf.sysconfig.get_link_flags()
27 | TF_SO_LIB = TF_SO_LIB.replace(
28 | "-l:libtensorflow_framework.1.dylib", "-ltensorflow_framework.1"
29 | )
30 | TF_LIB_INC = TF_LIB_INC.split("-L")[1]
31 | TF_SO_LIB = TF_SO_LIB.split("-l")[1]
32 |
33 | compile_args = [TF_CFLAG, "-fPIC", "-shared", "-O2", "-std=c++11"]
34 | if sys.platform == "darwin":  # Mac OS X before Mavericks (10.9)
35 |     compile_args.append("-stdlib=libc++")
36 |
37 | module = setuptools.Extension(
38 | "athena.transform.feats.ops.x_ops",
39 | sources=glob("athena/transform/feats/ops/kernels/*.cc"),
40 | depends=glob("athena/transform/feats/ops/kernels/*.h"),
41 |     extra_compile_args=compile_args,
42 | include_dirs=["athena/transform/feats/ops", TF_INCLUDE],
43 | library_dirs=[TF_LIB_INC],
44 | libraries=[TF_SO_LIB],
45 | language="c++",
46 | )
47 |
48 |
49 | with open("README.md", "r") as fh:
50 | long_description = fh.read()
51 | setuptools.setup(
52 | name="athena",
53 | version="0.1.0",
54 | author="ATHENA AUTHORS",
55 | author_email="athena@gmail.com",
56 | description="for speech recognition",
57 | long_description=long_description,
58 | long_description_content_type="text/markdown",
59 | url="https://github.com/didichuxing/athena",
60 | packages=setuptools.find_packages(),
61 | package_data={"": ["x_ops*.so", "*.vocab", "sph2pipe"]},
62 | exclude_package_data={"feats": ["*_test*"]},
63 | ext_modules=[module],
64 | python_requires=">=3",
65 | )
66 | path = glob("build/lib.*/athena/transform/feats/ops/x_ops.*.so")
67 | shutil.copy(path[0], "athena/transform/feats/ops/x_ops.so")
68 |
--------------------------------------------------------------------------------
/tools/env.sh:
--------------------------------------------------------------------------------
1 | # check if we are executing or sourcing.
2 | if [[ "$0" == "$BASH_SOURCE" ]]
3 | then
4 | echo "You must source this script rather than executing it."
5 |     exit 1
6 | fi
7 |
8 | # don't use readlink because macOS won't support it.
9 | if [[ "$BASH_SOURCE" == "/"* ]]
10 | then
11 | export MAIN_ROOT=$(dirname $BASH_SOURCE)
12 | else
13 | export MAIN_ROOT=$PWD/$(dirname $BASH_SOURCE)
14 | fi
15 |
16 | # pip bins
17 | export PATH=$PATH:~/.local/bin
18 |
19 | # athena
20 | export PYTHONPATH=${PYTHONPATH}:$MAIN_ROOT:
21 |
--------------------------------------------------------------------------------
/tools/install.sh:
--------------------------------------------------------------------------------
1 |
2 | rm -rf build/ athena.egg-info/ dist/
3 | python setup.py bdist_wheel sdist
4 | python -m pip install --ignore-installed dist/athena-0.1.0*.whl
5 |
--------------------------------------------------------------------------------
/tools/install_kenlm.sh:
--------------------------------------------------------------------------------
1 | # reference to https://github.com/kpu/kenlm
2 | cd tools
3 | wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz
4 |
5 | mkdir -p kenlm/build
6 | cd kenlm/build
7 | cmake ..
8 | make -j 4
9 |
--------------------------------------------------------------------------------
/tools/install_sph2pipe.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Tool to convert sph audio format to wav format.
3 | cd tools
4 | if [ ! -e sph2pipe_v2.5.tar.gz ]; then
5 | wget -T 10 -t 3 http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz || exit 1
6 | fi
7 |
8 | tar -zxvf sph2pipe_v2.5.tar.gz || exit 1
9 | cd sph2pipe_v2.5/
10 | gcc -o sph2pipe *.c -lm || exit 1
11 |
--------------------------------------------------------------------------------