├── .gitattributes ├── .github ├── stale.yml └── workflows │ └── ci.yaml ├── .gitignore ├── LICENSE ├── README.md ├── docker-compose.yml ├── dockerfile ├── examples ├── android │ ├── .gitignore │ ├── README.md │ ├── app │ │ ├── .gitignore │ │ ├── build.gradle │ │ ├── proguard-rules.pro │ │ └── src │ │ │ ├── androidTest │ │ │ └── java │ │ │ │ └── com │ │ │ │ └── tensorspeech │ │ │ │ └── tensorflowtts │ │ │ │ └── ExampleInstrumentedTest.java │ │ │ ├── main │ │ │ ├── AndroidManifest.xml │ │ │ ├── assets │ │ │ │ ├── fastspeech2_quant.tflite │ │ │ │ └── mbmelgan.tflite │ │ │ ├── java │ │ │ │ └── com │ │ │ │ │ └── tensorspeech │ │ │ │ │ └── tensorflowtts │ │ │ │ │ ├── MainActivity.java │ │ │ │ │ ├── dispatcher │ │ │ │ │ ├── OnTtsStateListener.java │ │ │ │ │ └── TtsStateDispatcher.java │ │ │ │ │ ├── module │ │ │ │ │ ├── AbstractModule.java │ │ │ │ │ ├── FastSpeech2.java │ │ │ │ │ └── MBMelGan.java │ │ │ │ │ ├── tts │ │ │ │ │ ├── InputWorker.java │ │ │ │ │ ├── TtsManager.java │ │ │ │ │ └── TtsPlayer.java │ │ │ │ │ └── utils │ │ │ │ │ ├── NumberNorm.java │ │ │ │ │ ├── Processor.java │ │ │ │ │ └── ThreadPoolManager.java │ │ │ └── res │ │ │ │ ├── drawable-v24 │ │ │ │ └── ic_launcher_foreground.xml │ │ │ │ ├── drawable │ │ │ │ └── ic_launcher_background.xml │ │ │ │ ├── layout │ │ │ │ └── activity_main.xml │ │ │ │ ├── mipmap-anydpi-v26 │ │ │ │ ├── ic_launcher.xml │ │ │ │ └── ic_launcher_round.xml │ │ │ │ ├── mipmap-hdpi │ │ │ │ ├── ic_launcher.png │ │ │ │ └── ic_launcher_round.png │ │ │ │ ├── mipmap-mdpi │ │ │ │ ├── ic_launcher.png │ │ │ │ └── ic_launcher_round.png │ │ │ │ ├── mipmap-xhdpi │ │ │ │ ├── ic_launcher.png │ │ │ │ └── ic_launcher_round.png │ │ │ │ ├── mipmap-xxhdpi │ │ │ │ ├── ic_launcher.png │ │ │ │ └── ic_launcher_round.png │ │ │ │ ├── mipmap-xxxhdpi │ │ │ │ ├── ic_launcher.png │ │ │ │ └── ic_launcher_round.png │ │ │ │ └── values │ │ │ │ ├── colors.xml │ │ │ │ ├── strings.xml │ │ │ │ └── styles.xml │ │ │ └── test │ │ │ └── java │ │ │ └── com │ │ │ └── tensorspeech 
│ │ │ └── tensorflowtts │ │ │ └── ExampleUnitTest.java │ ├── build.gradle │ ├── gradle.properties │ ├── gradle │ │ └── wrapper │ │ │ ├── gradle-wrapper.jar │ │ │ └── gradle-wrapper.properties │ ├── gradlew │ ├── gradlew.bat │ └── settings.gradle ├── cpptflite │ ├── .gitignore │ ├── CMakeLists.txt │ ├── README.md │ ├── demo │ │ ├── main.cpp │ │ └── text2ids.py │ ├── results │ │ ├── lj_ori_mel.png │ │ ├── lj_tflite_mel.png │ │ ├── tflite_mel.png │ │ └── tflite_mel2.png │ └── src │ │ ├── AudioFile.h │ │ ├── MelGenerateTF.cpp │ │ ├── MelGenerateTF.h │ │ ├── TTSBackend.cpp │ │ ├── TTSBackend.h │ │ ├── TTSFrontend.cpp │ │ ├── TTSFrontend.h │ │ ├── TfliteBase.cpp │ │ ├── TfliteBase.h │ │ ├── VocoderTF.cpp │ │ ├── VocoderTF.h │ │ ├── VoxCommon.cpp │ │ └── VoxCommon.h ├── cppwin │ ├── .gitattributes │ ├── .gitignore │ ├── README.md │ ├── TensorflowTTSCppInference.pro │ ├── TensorflowTTSCppInference.sln │ └── TensorflowTTSCppInference │ │ ├── EnglishPhoneticProcessor.cpp │ │ ├── EnglishPhoneticProcessor.h │ │ ├── FastSpeech2.cpp │ │ ├── FastSpeech2.h │ │ ├── MultiBandMelGAN.cpp │ │ ├── MultiBandMelGAN.h │ │ ├── TensorflowTTSCppInference.cpp │ │ ├── TensorflowTTSCppInference.vcxproj │ │ ├── TensorflowTTSCppInference.vcxproj.filters │ │ ├── TextTokenizer.cpp │ │ ├── TextTokenizer.h │ │ ├── Voice.cpp │ │ ├── Voice.h │ │ ├── VoxCommon.cpp │ │ ├── VoxCommon.hpp │ │ ├── ext │ │ ├── AudioFile.hpp │ │ ├── CppFlow │ │ │ ├── include │ │ │ │ ├── Model.h │ │ │ │ └── Tensor.h │ │ │ └── src │ │ │ │ ├── Model.cpp │ │ │ │ └── Tensor.cpp │ │ ├── ZCharScanner.cpp │ │ ├── ZCharScanner.h │ │ ├── cxxopts.hpp │ │ └── json.hpp │ │ ├── phonemizer.cpp │ │ ├── phonemizer.h │ │ ├── tfg2p.cpp │ │ └── tfg2p.h ├── fastspeech │ ├── README.md │ ├── conf │ │ ├── fastspeech.v1.yaml │ │ └── fastspeech.v3.yaml │ ├── decode_fastspeech.py │ ├── fastspeech_dataset.py │ ├── fig │ │ └── fastspeech.v1.png │ └── train_fastspeech.py ├── fastspeech2 │ ├── README.md │ ├── conf │ │ ├── fastspeech2.baker.v2.yaml │ │ ├── 
fastspeech2.jsut.v1.yaml │ │ ├── fastspeech2.kss.v1.yaml │ │ ├── fastspeech2.kss.v2.yaml │ │ ├── fastspeech2.v1.yaml │ │ └── fastspeech2.v2.yaml │ ├── decode_fastspeech2.py │ ├── extractfs_postnets.py │ ├── fastspeech2_dataset.py │ └── train_fastspeech2.py ├── fastspeech2_libritts │ ├── README.md │ ├── conf │ │ └── fastspeech2libritts.yaml │ ├── fastspeech2_dataset.py │ ├── libri_experiment │ │ └── prepare_libri.ipynb │ ├── scripts │ │ ├── build.sh │ │ ├── docker │ │ │ └── Dockerfile │ │ ├── interactive.sh │ │ └── train_libri.sh │ └── train_fastspeech2.py ├── hifigan │ ├── README.md │ ├── conf │ │ ├── hifigan.v1.yaml │ │ └── hifigan.v2.yaml │ └── train_hifigan.py ├── ios │ ├── .gitignore │ ├── Podfile │ ├── Podfile.lock │ ├── README.md │ ├── TF_TTS_Demo.xcodeproj │ │ └── project.pbxproj │ └── TF_TTS_Demo │ │ ├── Assets.xcassets │ │ ├── AccentColor.colorset │ │ │ └── Contents.json │ │ ├── AppIcon.appiconset │ │ │ └── Contents.json │ │ └── Contents.json │ │ ├── ContentView.swift │ │ ├── FastSpeech2.swift │ │ ├── Info.plist │ │ ├── MBMelGAN.swift │ │ ├── Preview Content │ │ └── Preview Assets.xcassets │ │ │ └── Contents.json │ │ ├── TF_TTS_DemoApp.swift │ │ └── TTS.swift ├── melgan │ ├── README.md │ ├── audio_mel_dataset.py │ ├── conf │ │ └── melgan.v1.yaml │ ├── decode_melgan.py │ ├── fig │ │ └── melgan.v1.png │ └── train_melgan.py ├── melgan_stft │ ├── README.md │ ├── conf │ │ └── melgan_stft.v1.yaml │ ├── fig │ │ ├── melgan.stft.v1.eval.png │ │ └── melgan.stft.v1.train.png │ └── train_melgan_stft.py ├── mfa_extraction │ ├── README.md │ ├── fix_mismatch.py │ ├── requirements.txt │ ├── run_mfa.py │ ├── scripts │ │ └── prepare_mfa.sh │ └── txt_grid_parser.py ├── multiband_melgan │ ├── README.md │ ├── conf │ │ ├── multiband_melgan.baker.v1.yaml │ │ ├── multiband_melgan.synpaflex.v1.yaml │ │ └── multiband_melgan.v1.yaml │ ├── decode_mb_melgan.py │ ├── fig │ │ ├── eval.png │ │ └── train.png │ └── train_multiband_melgan.py ├── multiband_melgan_hf │ ├── README.md │ ├── 
conf │ │ ├── multiband_melgan_hf.lju.v1.yml │ │ └── multiband_melgan_hf.lju.v1ft.yml │ ├── decode_mb_melgan.py │ ├── fig │ │ ├── eval.png │ │ └── train.png │ └── train_multiband_melgan_hf.py ├── parallel_wavegan │ ├── README.md │ ├── conf │ │ └── parallel_wavegan.v1.yaml │ ├── convert_pwgan_from_pytorch_to_tensorflow.ipynb │ ├── decode_parallel_wavegan.py │ └── train_parallel_wavegan.py └── tacotron2 │ ├── README.md │ ├── conf │ ├── tacotron2.baker.v1.yaml │ ├── tacotron2.jsut.v1.yaml │ ├── tacotron2.kss.v1.yaml │ ├── tacotron2.lju.v1.yaml │ ├── tacotron2.synpaflex.v1.yaml │ └── tacotron2.v1.yaml │ ├── decode_tacotron2.py │ ├── export_align.py │ ├── extract_duration.py │ ├── extract_postnets.py │ ├── fig │ ├── alignment.gif │ └── tensorboard.png │ ├── tacotron_dataset.py │ └── train_tacotron2.py ├── notebooks ├── Parallel_WaveGAN_TFLite.ipynb ├── TensorFlowTTS_FastSpeech_with_TFLite.ipynb ├── TensorFlowTTS_Tacotron2_with_TFLite.ipynb ├── fastspeech2_inference.ipynb ├── fastspeech_inference.ipynb ├── griffin_lim_tensorflow.ipynb ├── multiband_melgan_inference.ipynb ├── prepare_synpaflex.ipynb └── tacotron2_inference.ipynb ├── preprocess ├── baker_preprocess.yaml ├── jsut_preprocess.yaml ├── kss_preprocess.yaml ├── libritts_preprocess.yaml ├── ljspeech_preprocess.yaml ├── ljspeechu_preprocess.yaml ├── synpaflex_preprocess.yaml └── thorsten_preprocess.yaml ├── setup.cfg ├── setup.py ├── tensorflow_tts ├── __init__.py ├── bin │ ├── __init__.py │ └── preprocess.py ├── configs │ ├── __init__.py │ ├── base_config.py │ ├── fastspeech.py │ ├── fastspeech2.py │ ├── hifigan.py │ ├── mb_melgan.py │ ├── melgan.py │ ├── parallel_wavegan.py │ └── tacotron2.py ├── datasets │ ├── __init__.py │ ├── abstract_dataset.py │ ├── audio_dataset.py │ └── mel_dataset.py ├── inference │ ├── __init__.py │ ├── auto_config.py │ ├── auto_model.py │ ├── auto_processor.py │ └── savable_models.py ├── losses │ ├── __init__.py │ ├── spectrogram.py │ └── stft.py ├── models │ ├── __init__.py │ ├── 
base_model.py │ ├── fastspeech.py │ ├── fastspeech2.py │ ├── hifigan.py │ ├── mb_melgan.py │ ├── melgan.py │ ├── parallel_wavegan.py │ └── tacotron2.py ├── optimizers │ ├── __init__.py │ ├── adamweightdecay.py │ └── gradient_accumulate.py ├── processor │ ├── __init__.py │ ├── baker.py │ ├── base_processor.py │ ├── jsut.py │ ├── kss.py │ ├── libritts.py │ ├── ljspeech.py │ ├── ljspeechu.py │ ├── pretrained │ │ ├── baker_mapper.json │ │ ├── jsut_mapper.json │ │ ├── kss_mapper.json │ │ ├── libritts_mapper.json │ │ ├── ljspeech_mapper.json │ │ ├── ljspeechu_mapper.json │ │ ├── synpaflex_mapper.json │ │ └── thorsten_mapper.json │ ├── synpaflex.py │ └── thorsten.py ├── trainers │ ├── __init__.py │ └── base_trainer.py └── utils │ ├── __init__.py │ ├── cleaners.py │ ├── decoder.py │ ├── griffin_lim.py │ ├── group_conv.py │ ├── korean.py │ ├── number_norm.py │ ├── outliers.py │ ├── strategy.py │ ├── utils.py │ └── weight_norm.py └── test ├── files ├── baker_mapper.json ├── kss_mapper.json ├── libritts_mapper.json ├── ljspeech_mapper.json ├── mapper.json └── train.txt ├── test_auto.py ├── test_base_processor.py ├── test_fastspeech.py ├── test_fastspeech2.py ├── test_hifigan.py ├── test_mb_melgan.py ├── test_melgan.py ├── test_melgan_layers.py ├── test_parallel_wavegan.py └── test_tacotron2.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-language=Python -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 60 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 7 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | # Label to use when marking an issue as stale 10 | staleLabel: wontfix 11 | # Comment to 
post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. 15 | # Comment to post when closing a stale issue. Set to `false` to disable 16 | closeComment: false 17 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | schedule: 11 | - cron: 0 0 * * 1 12 | 13 | jobs: 14 | linter_and_test: 15 | runs-on: ubuntu-18.04 16 | strategy: 17 | max-parallel: 10 18 | matrix: 19 | python-version: [3.7] 20 | tensorflow-version: [2.7.0] 21 | steps: 22 | - uses: actions/checkout@master 23 | - uses: actions/setup-python@v1 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | architecture: 'x64' 27 | - uses: actions/cache@v1 28 | with: 29 | path: ~/.cache/pip 30 | key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.pytorch-version }}-pip-${{ hashFiles('**/setup.py') }} 31 | restore-keys: | 32 | ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.tensorflow-version }}-pip- 33 | - name: Install dependencies 34 | run: | 35 | # install python modules 36 | python -m pip install --upgrade pip 37 | pip install -q -U numpy 38 | pip install git+https://github.com/repodiac/german_transliterate.git#egg=german_transliterate 39 | pip install -q tensorflow-gpu==${{ matrix.tensorflow-version }} 40 | pip install -q -e . 41 | pip install -q -e .[test] 42 | pip install typing_extensions 43 | sudo apt-get install libsndfile1-dev 44 | python -m pip install black 45 | - name: black 46 | run: | 47 | python -m black . 
48 | - name: Pytest 49 | run: | 50 | pytest test -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # general 3 | *~ 4 | *.pyc 5 | \#*\# 6 | .\#* 7 | *DS_Store 8 | out.txt 9 | TensorFlowTTS.egg-info/ 10 | doc/_build 11 | slurm-*.out 12 | tmp* 13 | .eggs/ 14 | .hypothesis/ 15 | .idea 16 | .backup/ 17 | .pytest_cache/ 18 | __pycache__/ 19 | .coverage* 20 | coverage.xml* 21 | .vscode* 22 | .nfs* 23 | .ipynb_checkpoints 24 | ljspeech 25 | *.h5 26 | *.npy 27 | ./*.wav 28 | !docker-compose.yml 29 | /Pipfile 30 | /Pipfile.lock 31 | /datasets 32 | /examples/tacotron2/exp/ 33 | /temp/ 34 | LibriTTS/ 35 | dataset/ 36 | mfa/ 37 | kss/ 38 | baker/ 39 | libritts/ 40 | dump_baker/ 41 | dump_ljspeech/ 42 | dump_kss/ 43 | dump_libritts/ 44 | /notebooks/test_saved/ 45 | build/ 46 | dist/ -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '2.6' 2 | services: 3 | tensorflowtts: 4 | build: . 
5 | volumes: 6 | - .:/workspace 7 | runtime: nvidia 8 | tty: true 9 | command: /bin/bash 10 | environment: 11 | - CUDA_VISIBLE_DEVICES 12 | -------------------------------------------------------------------------------- /dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:2.6.0-gpu 2 | RUN apt-get update 3 | RUN apt-get install -y zsh tmux wget git libsndfile1 4 | RUN pip install ipython && \ 5 | pip install git+https://github.com/TensorSpeech/TensorflowTTS.git && \ 6 | pip install git+https://github.com/repodiac/german_transliterate.git#egg=german_transliterate 7 | RUN mkdir /workspace 8 | WORKDIR /workspace 9 | -------------------------------------------------------------------------------- /examples/android/.gitignore: -------------------------------------------------------------------------------- 1 | # Android Studio 2 | *.iml 3 | .gradle 4 | /local.properties 5 | /.idea 6 | .DS_Store 7 | /build 8 | /captures 9 | 10 | # Built application files 11 | *.apk 12 | !prebuiltapps/*.apk 13 | *.ap_ 14 | 15 | # Files for the Dalvik VM 16 | *.dex 17 | 18 | # Java class files 19 | *.class 20 | 21 | # Generated files 22 | bin/ 23 | gen/ 24 | 25 | # Gradle files 26 | .gradle/ 27 | build/ 28 | */build/ 29 | 30 | # Local configuration file (sdk path, etc) 31 | local.properties 32 | 33 | # Proguard folder generated by Eclipse 34 | proguard/ 35 | 36 | # Log Files 37 | *.log 38 | 39 | # project 40 | project.properties 41 | .classpath 42 | .project 43 | .settings/ 44 | 45 | # Intellij project files 46 | *.ipr 47 | *.iws 48 | .idea/ 49 | app/.gradle/ 50 | .idea/libraries 51 | .idea/workspace.xml 52 | .idea/vcs.xml 53 | .idea/scopes/scope_setting.xml 54 | .idea/moudles.xml 55 | .idea/misc.xml 56 | .idea/inspectionProfiles/Project_Default.xml 57 | .idea/inspectionProfiles/profiles_setting.xml 58 | .idea/encodings.xml 59 | .idea/.name 60 | 
-------------------------------------------------------------------------------- /examples/android/README.md: -------------------------------------------------------------------------------- 1 | ### Android Demo 2 | 3 | This is a simple Android demo which will load converted FastSpeech2 and Multi-Band MelGAN modules to synthesize audio. 4 | In order to optimize the synthesis speed, two LinkedBlockingQueues have been implemented. 5 | 6 | 7 | ### HOW-TO 8 | 1. Import this project into Android Studio. 9 | 2. Run the app! 10 | 11 | ### LICENSE 12 | The license used for this code is [CC BY-NC 3.0](https://creativecommons.org/licenses/by-nc/3.0/). Please read the license carefully before you use it. 13 | 14 | ### Contributors 15 | [Xuefeng Ding](https://github.com/mapledxf) 16 | -------------------------------------------------------------------------------- /examples/android/app/.gitignore: -------------------------------------------------------------------------------- 1 | /build 2 | -------------------------------------------------------------------------------- /examples/android/app/build.gradle: -------------------------------------------------------------------------------- 1 | apply plugin: 'com.android.application' 2 | 3 | android { 4 | compileSdkVersion 29 5 | buildToolsVersion "29.0.2" 6 | defaultConfig { 7 | applicationId "com.tensorspeech.tensorflowtts" 8 | minSdkVersion 21 9 | targetSdkVersion 29 10 | versionCode 1 11 | versionName "1.0" 12 | } 13 | buildTypes { 14 | release { 15 | minifyEnabled false 16 | proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' 17 | } 18 | } 19 | aaptOptions { 20 | noCompress "tflite" 21 | } 22 | compileOptions { 23 | sourceCompatibility = '1.8' 24 | targetCompatibility = '1.8' 25 | } 26 | lintOptions { 27 | abortOnError false 28 | } 29 | } 30 | 31 | dependencies { 32 | implementation fileTree(dir: 'libs', include: ['*.jar']) 33 | implementation 'androidx.appcompat:appcompat:1.1.0' 34 |
implementation 'androidx.constraintlayout:constraintlayout:1.1.3' 35 | 36 | implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly' 37 | implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:0.0.0-nightly' 38 | implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly' 39 | } 40 | -------------------------------------------------------------------------------- /examples/android/app/proguard-rules.pro: -------------------------------------------------------------------------------- 1 | # Add project specific ProGuard rules here. 2 | # You can control the set of applied configuration files using the 3 | # proguardFiles setting in build.gradle. 4 | # 5 | # For more details, see 6 | # http://developer.android.com/guide/developing/tools/proguard.html 7 | 8 | # If your project uses WebView with JS, uncomment the following 9 | # and specify the fully qualified class name to the JavaScript interface 10 | # class: 11 | #-keepclassmembers class fqcn.of.javascript.interface.for.webview { 12 | # public *; 13 | #} 14 | 15 | # Uncomment this to preserve the line number information for 16 | # debugging stack traces. 17 | #-keepattributes SourceFile,LineNumberTable 18 | 19 | # If you keep the line number information, uncomment this to 20 | # hide the original source file name. 
21 | #-renamesourcefileattribute SourceFile 22 | -------------------------------------------------------------------------------- /examples/android/app/src/androidTest/java/com/tensorspeech/tensorflowtts/ExampleInstrumentedTest.java: -------------------------------------------------------------------------------- 1 | package com.tensorspeech.tensorflowtts; 2 | 3 | import android.content.Context; 4 | 5 | import androidx.test.platform.app.InstrumentationRegistry; 6 | import androidx.test.ext.junit.runners.AndroidJUnit4; 7 | 8 | import org.junit.Test; 9 | import org.junit.runner.RunWith; 10 | 11 | import static org.junit.Assert.*; 12 | 13 | /** 14 | * Instrumented test, which will execute on an Android device. 15 | * 16 | * @see Testing documentation 17 | */ 18 | @RunWith(AndroidJUnit4.class) 19 | public class ExampleInstrumentedTest { 20 | @Test 21 | public void useAppContext() { 22 | // Context of the app under test. 23 | Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext(); 24 | 25 | assertEquals("com.tensorspeech.tensorflowtts", appContext.getPackageName()); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /examples/android/app/src/main/AndroidManifest.xml: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /examples/android/app/src/main/assets/fastspeech2_quant.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/assets/fastspeech2_quant.tflite -------------------------------------------------------------------------------- /examples/android/app/src/main/assets/mbmelgan.tflite: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/assets/mbmelgan.tflite -------------------------------------------------------------------------------- /examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/MainActivity.java: -------------------------------------------------------------------------------- 1 | package com.tensorspeech.tensorflowtts; 2 | 3 | import android.os.Bundle; 4 | import android.text.TextUtils; 5 | import android.view.View; 6 | import android.widget.EditText; 7 | import android.widget.RadioGroup; 8 | 9 | import androidx.appcompat.app.AppCompatActivity; 10 | 11 | import com.tensorspeech.tensorflowtts.dispatcher.OnTtsStateListener; 12 | import com.tensorspeech.tensorflowtts.dispatcher.TtsStateDispatcher; 13 | import com.tensorspeech.tensorflowtts.tts.TtsManager; 14 | import com.tensorspeech.tensorflowtts.utils.ThreadPoolManager; 15 | 16 | /** 17 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"} 18 | * Created 2020-07-20 17:25 19 | */ 20 | public class MainActivity extends AppCompatActivity { 21 | private static final String DEFAULT_INPUT_TEXT = "Unless you work on a ship, it's unlikely that you use the word boatswain in everyday conversation, so it's understandably a tricky one. 
The word - which refers to a petty officer in charge of hull maintenance is not pronounced boats-wain Rather, it's bo-sun to reflect the salty pronunciation of sailors, as The Free Dictionary explains./Blue opinion poll conducted for the National Post."; 22 | 23 | private View speakBtn; 24 | private RadioGroup speedGroup; 25 | 26 | @Override 27 | protected void onCreate(Bundle savedInstanceState) { 28 | super.onCreate(savedInstanceState); 29 | setContentView(R.layout.activity_main); 30 | 31 | TtsManager.getInstance().init(this); 32 | 33 | TtsStateDispatcher.getInstance().addListener(new OnTtsStateListener() { 34 | @Override 35 | public void onTtsReady() { 36 | speakBtn.setEnabled(true); 37 | } 38 | 39 | @Override 40 | public void onTtsStart(String text) { 41 | } 42 | 43 | @Override 44 | public void onTtsStop() { 45 | } 46 | }); 47 | 48 | EditText input = findViewById(R.id.input); 49 | input.setHint(DEFAULT_INPUT_TEXT); 50 | 51 | speedGroup = findViewById(R.id.speed_chooser); 52 | speedGroup.check(R.id.normal); 53 | 54 | speakBtn = findViewById(R.id.start); 55 | speakBtn.setEnabled(false); 56 | speakBtn.setOnClickListener(v -> 57 | ThreadPoolManager.getInstance().execute(() -> { 58 | float speed ; 59 | switch (speedGroup.getCheckedRadioButtonId()) { 60 | case R.id.fast: 61 | speed = 0.8F; 62 | break; 63 | case R.id.slow: 64 | speed = 1.2F; 65 | break; 66 | case R.id.normal: 67 | default: 68 | speed = 1.0F; 69 | break; 70 | } 71 | 72 | String inputText = input.getText().toString(); 73 | if (TextUtils.isEmpty(inputText)) { 74 | inputText = DEFAULT_INPUT_TEXT; 75 | } 76 | TtsManager.getInstance().speak(inputText, speed, true); 77 | })); 78 | 79 | findViewById(R.id.stop).setOnClickListener(v -> 80 | TtsManager.getInstance().stopTts()); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/dispatcher/OnTtsStateListener.java: 
-------------------------------------------------------------------------------- 1 | package com.tensorspeech.tensorflowtts.dispatcher; 2 | 3 | /** 4 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"} 5 | * Created 2020-07-28 14:25 6 | */ 7 | public interface OnTtsStateListener { 8 | public void onTtsReady(); 9 | 10 | public void onTtsStart(String text); 11 | 12 | public void onTtsStop(); 13 | } -------------------------------------------------------------------------------- /examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/dispatcher/TtsStateDispatcher.java: -------------------------------------------------------------------------------- 1 | package com.tensorspeech.tensorflowtts.dispatcher; 2 | 3 | import android.os.Handler; 4 | import android.os.Looper; 5 | import android.util.Log; 6 | 7 | import java.util.concurrent.CopyOnWriteArrayList; 8 | 9 | /** 10 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"} 11 | * Created 2020-07-28 14:25 12 | */ 13 | public class TtsStateDispatcher { 14 | private static final String TAG = "TtsStateDispatcher"; 15 | private static volatile TtsStateDispatcher instance; 16 | private static final Object INSTANCE_WRITE_LOCK = new Object(); 17 | 18 | public static TtsStateDispatcher getInstance() { 19 | if (instance == null) { 20 | synchronized (INSTANCE_WRITE_LOCK) { 21 | if (instance == null) { 22 | instance = new TtsStateDispatcher(); 23 | } 24 | } 25 | } 26 | return instance; 27 | } 28 | 29 | private final Handler handler = new Handler(Looper.getMainLooper()); 30 | 31 | private CopyOnWriteArrayList mListeners = new CopyOnWriteArrayList<>(); 32 | 33 | public void release() { 34 | Log.d(TAG, "release: "); 35 | mListeners.clear(); 36 | } 37 | 38 | public void addListener(OnTtsStateListener listener) { 39 | if (mListeners.contains(listener)) { 40 | return; 41 | } 42 | Log.d(TAG, "addListener: " + listener.getClass()); 43 | mListeners.add(listener); 44 | } 45 | 46 | public void 
removeListener(OnTtsStateListener listener) { 47 | if (mListeners.contains(listener)) { 48 | Log.d(TAG, "removeListener: " + listener.getClass()); 49 | mListeners.remove(listener); 50 | } 51 | } 52 | 53 | public void onTtsStart(String text){ 54 | Log.d(TAG, "onTtsStart: "); 55 | if (!mListeners.isEmpty()) { 56 | for (OnTtsStateListener listener : mListeners) { 57 | handler.post(() -> listener.onTtsStart(text)); 58 | } 59 | } 60 | } 61 | 62 | public void onTtsStop(){ 63 | Log.d(TAG, "onTtsStop: "); 64 | if (!mListeners.isEmpty()) { 65 | for (OnTtsStateListener listener : mListeners) { 66 | handler.post(listener::onTtsStop); 67 | } 68 | } 69 | } 70 | 71 | public void onTtsReady(){ 72 | Log.d(TAG, "onTtsReady: "); 73 | if (!mListeners.isEmpty()) { 74 | for (OnTtsStateListener listener : mListeners) { 75 | handler.post(listener::onTtsReady); 76 | } 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/module/AbstractModule.java: -------------------------------------------------------------------------------- 1 | package com.tensorspeech.tensorflowtts.module; 2 | 3 | import org.tensorflow.lite.Interpreter; 4 | 5 | /** 6 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"} 7 | * Created 2020-07-20 17:25 8 | * 9 | */ 10 | abstract class AbstractModule { 11 | 12 | Interpreter.Options getOption() { 13 | Interpreter.Options options = new Interpreter.Options(); 14 | options.setNumThreads(5); 15 | return options; 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/module/FastSpeech2.java: -------------------------------------------------------------------------------- 1 | package com.tensorspeech.tensorflowtts.module; 2 | 3 | import android.annotation.SuppressLint; 4 | import android.util.Log; 5 | 6 | import 
org.tensorflow.lite.DataType; 7 | import org.tensorflow.lite.Interpreter; 8 | import org.tensorflow.lite.Tensor; 9 | import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; 10 | 11 | import java.io.File; 12 | import java.nio.FloatBuffer; 13 | import java.util.Arrays; 14 | import java.util.HashMap; 15 | import java.util.Map; 16 | 17 | /** 18 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"} 19 | * Created 2020-07-20 17:26 20 | * 21 | */ 22 | public class FastSpeech2 extends AbstractModule { 23 | private static final String TAG = "FastSpeech2"; 24 | private Interpreter mModule; 25 | 26 | public FastSpeech2(String modulePath) { 27 | try { 28 | mModule = new Interpreter(new File(modulePath), getOption()); 29 | int input = mModule.getInputTensorCount(); 30 | for (int i = 0; i < input; i++) { 31 | Tensor inputTensor = mModule.getInputTensor(i); 32 | Log.d(TAG, "input:" + i + 33 | " name:" + inputTensor.name() + 34 | " shape:" + Arrays.toString(inputTensor.shape()) + 35 | " dtype:" + inputTensor.dataType()); 36 | } 37 | 38 | int output = mModule.getOutputTensorCount(); 39 | for (int i = 0; i < output; i++) { 40 | Tensor outputTensor = mModule.getOutputTensor(i); 41 | Log.d(TAG, "output:" + i + 42 | " name:" + outputTensor.name() + 43 | " shape:" + Arrays.toString(outputTensor.shape()) + 44 | " dtype:" + outputTensor.dataType()); 45 | } 46 | Log.d(TAG, "successfully init"); 47 | } catch (Exception e) { 48 | e.printStackTrace(); 49 | } 50 | } 51 | 52 | public TensorBuffer getMelSpectrogram(int[] inputIds, float speed) { 53 | Log.d(TAG, "input id length: " + inputIds.length); 54 | mModule.resizeInput(0, new int[]{1, inputIds.length}); 55 | mModule.allocateTensors(); 56 | 57 | @SuppressLint("UseSparseArrays") 58 | Map outputMap = new HashMap<>(); 59 | 60 | FloatBuffer outputBuffer = FloatBuffer.allocate(350000); 61 | outputMap.put(0, outputBuffer); 62 | 63 | int[][] inputs = new int[1][inputIds.length]; 64 | inputs[0] = inputIds; 65 | 66 | long time = 
System.currentTimeMillis(); 67 | mModule.runForMultipleInputsOutputs( 68 | new Object[]{inputs, new int[1][1], new int[]{0}, new float[]{speed}, new float[]{1F}, new float[]{1F}}, 69 | outputMap); 70 | Log.d(TAG, "time cost: " + (System.currentTimeMillis() - time)); 71 | 72 | int size = mModule.getOutputTensor(0).shape()[2]; 73 | int[] shape = {1, outputBuffer.position() / size, size}; 74 | TensorBuffer spectrogram = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); 75 | float[] outputArray = new float[outputBuffer.position()]; 76 | outputBuffer.rewind(); 77 | outputBuffer.get(outputArray); 78 | spectrogram.loadArray(outputArray); 79 | 80 | return spectrogram; 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/module/MBMelGan.java: -------------------------------------------------------------------------------- 1 | package com.tensorspeech.tensorflowtts.module; 2 | 3 | import android.util.Log; 4 | 5 | import org.tensorflow.lite.Interpreter; 6 | import org.tensorflow.lite.Tensor; 7 | import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; 8 | 9 | import java.io.File; 10 | import java.nio.FloatBuffer; 11 | import java.util.Arrays; 12 | 13 | /** 14 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"} 15 | * Created 2020-07-20 17:26 16 | * 17 | */ 18 | public class MBMelGan extends AbstractModule { 19 | private static final String TAG = "MBMelGan"; 20 | private Interpreter mModule; 21 | 22 | public MBMelGan(String modulePath) { 23 | try { 24 | mModule = new Interpreter(new File(modulePath), getOption()); 25 | int input = mModule.getInputTensorCount(); 26 | for (int i = 0; i < input; i++) { 27 | Tensor inputTensor = mModule.getInputTensor(i); 28 | Log.d(TAG, "input:" + i 29 | + " name:" + inputTensor.name() 30 | + " shape:" + Arrays.toString(inputTensor.shape()) + 31 | " dtype:" + inputTensor.dataType()); 32 | } 33 | 34 | int 
output = mModule.getOutputTensorCount(); 35 | for (int i = 0; i < output; i++) { 36 | Tensor outputTensor = mModule.getOutputTensor(i); 37 | Log.d(TAG, "output:" + i 38 | + " name:" + outputTensor.name() 39 | + " shape:" + Arrays.toString(outputTensor.shape()) 40 | + " dtype:" + outputTensor.dataType()); 41 | } 42 | Log.d(TAG, "successfully init"); 43 | } catch (Exception e) { 44 | e.printStackTrace(); 45 | } 46 | } 47 | 48 | 49 | public float[] getAudio(TensorBuffer input) { 50 | mModule.resizeInput(0, input.getShape()); 51 | mModule.allocateTensors(); 52 | 53 | FloatBuffer outputBuffer = FloatBuffer.allocate(350000); 54 | 55 | long time = System.currentTimeMillis(); 56 | mModule.run(input.getBuffer(), outputBuffer); 57 | Log.d(TAG, "time cost: " + (System.currentTimeMillis() - time)); 58 | 59 | float[] audioArray = new float[outputBuffer.position()]; 60 | outputBuffer.rewind(); 61 | outputBuffer.get(audioArray); 62 | return audioArray; 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/tts/TtsManager.java: -------------------------------------------------------------------------------- 1 | package com.tensorspeech.tensorflowtts.tts; 2 | 3 | import android.content.Context; 4 | import android.util.Log; 5 | 6 | import com.tensorspeech.tensorflowtts.dispatcher.TtsStateDispatcher; 7 | import com.tensorspeech.tensorflowtts.utils.ThreadPoolManager; 8 | 9 | import java.io.File; 10 | import java.io.FileOutputStream; 11 | import java.io.InputStream; 12 | import java.io.OutputStream; 13 | 14 | /** 15 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"} 16 | * Created 2020-07-28 14:25 17 | */ 18 | public class TtsManager { 19 | private static final String TAG = "TtsManager"; 20 | 21 | private static final Object INSTANCE_WRITE_LOCK = new Object(); 22 | 23 | private static volatile TtsManager instance; 24 | 25 | public static TtsManager 
getInstance() { 26 | if (instance == null) { 27 | synchronized (INSTANCE_WRITE_LOCK) { 28 | if (instance == null) { 29 | instance = new TtsManager(); 30 | } 31 | } 32 | } 33 | return instance; 34 | } 35 | 36 | private InputWorker mWorker; 37 | 38 | private final static String FASTSPEECH2_MODULE = "fastspeech2_quant.tflite"; 39 | private final static String MELGAN_MODULE = "mbmelgan.tflite"; 40 | 41 | public void init(Context context) { 42 | ThreadPoolManager.getInstance().getSingleExecutor("init").execute(() -> { 43 | try { 44 | String fastspeech = copyFile(context, FASTSPEECH2_MODULE); 45 | String vocoder = copyFile(context, MELGAN_MODULE); 46 | mWorker = new InputWorker(fastspeech, vocoder); 47 | } catch (Exception e) { 48 | Log.e(TAG, "mWorker init failed", e); 49 | } 50 | 51 | TtsStateDispatcher.getInstance().onTtsReady(); 52 | }); 53 | } 54 | 55 | private String copyFile(Context context, String strOutFileName) { 56 | Log.d(TAG, "start copy file " + strOutFileName); 57 | File file = context.getFilesDir(); 58 | 59 | String tmpFile = file.getAbsolutePath() + "/" + strOutFileName; 60 | File f = new File(tmpFile); 61 | if (f.exists()) { 62 | Log.d(TAG, "file exists " + strOutFileName); 63 | return f.getAbsolutePath(); 64 | } 65 | 66 | try (OutputStream myOutput = new FileOutputStream(f); 67 | InputStream myInput = context.getAssets().open(strOutFileName)) { 68 | byte[] buffer = new byte[1024]; 69 | int length = myInput.read(buffer); 70 | while (length > 0) { 71 | myOutput.write(buffer, 0, length); 72 | length = myInput.read(buffer); 73 | } 74 | myOutput.flush(); 75 | Log.d(TAG, "Copy task successful"); 76 | } catch (Exception e) { 77 | Log.e(TAG, "copyFile: Failed to copy", e); 78 | } finally { 79 | Log.d(TAG, "end copy file " + strOutFileName); 80 | } 81 | return f.getAbsolutePath(); 82 | } 83 | 84 | public void stopTts() { 85 | mWorker.interrupt(); 86 | } 87 | 88 | public void speak(String inputText, float speed, boolean interrupt) { 89 | if (interrupt) { 90 | 
stopTts(); 91 | } 92 | 93 | ThreadPoolManager.getInstance().execute(() -> 94 | mWorker.processInput(inputText, speed)); 95 | } 96 | 97 | } 98 | -------------------------------------------------------------------------------- /examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/tts/TtsPlayer.java: -------------------------------------------------------------------------------- 1 | package com.tensorspeech.tensorflowtts.tts; 2 | 3 | import android.media.AudioAttributes; 4 | import android.media.AudioFormat; 5 | import android.media.AudioManager; 6 | import android.media.AudioTrack; 7 | import android.util.Log; 8 | 9 | import com.tensorspeech.tensorflowtts.utils.ThreadPoolManager; 10 | 11 | import java.util.concurrent.LinkedBlockingQueue; 12 | 13 | /** 14 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"} 15 | * Created 2020-07-20 18:22 16 | */ 17 | class TtsPlayer { 18 | private static final String TAG = "TtsPlayer"; 19 | 20 | private final AudioTrack mAudioTrack; 21 | 22 | private final static int FORMAT = AudioFormat.ENCODING_PCM_FLOAT; 23 | private final static int SAMPLERATE = 22050; 24 | private final static int CHANNEL = AudioFormat.CHANNEL_OUT_MONO; 25 | private final static int BUFFER_SIZE = AudioTrack.getMinBufferSize(SAMPLERATE, CHANNEL, FORMAT); 26 | private LinkedBlockingQueue mAudioQueue = new LinkedBlockingQueue<>(); 27 | private AudioData mCurrentAudioData; 28 | 29 | TtsPlayer() { 30 | mAudioTrack = new AudioTrack( 31 | new AudioAttributes.Builder() 32 | .setUsage(AudioAttributes.USAGE_MEDIA) 33 | .setContentType(AudioAttributes.CONTENT_TYPE_MUSIC) 34 | .build(), 35 | new AudioFormat.Builder() 36 | .setSampleRate(22050) 37 | .setEncoding(FORMAT) 38 | .setChannelMask(CHANNEL) 39 | .build(), 40 | BUFFER_SIZE, 41 | AudioTrack.MODE_STREAM, AudioManager.AUDIO_SESSION_ID_GENERATE 42 | ); 43 | mAudioTrack.play(); 44 | 45 | ThreadPoolManager.getInstance().getSingleExecutor("audio").execute(() -> { 46 | //noinspection 
InfiniteLoopStatement 47 | while (true) { 48 | try { 49 | mCurrentAudioData = mAudioQueue.take(); 50 | Log.d(TAG, "playing: " + mCurrentAudioData.text); 51 | int index = 0; 52 | while (index < mCurrentAudioData.audio.length && !mCurrentAudioData.isInterrupt) { 53 | int buffer = Math.min(BUFFER_SIZE, mCurrentAudioData.audio.length - index); 54 | mAudioTrack.write(mCurrentAudioData.audio, index, buffer, AudioTrack.WRITE_BLOCKING); 55 | index += BUFFER_SIZE; 56 | } 57 | } catch (Exception e) { 58 | Log.e(TAG, "Exception: ", e); 59 | } 60 | } 61 | }); 62 | } 63 | 64 | void play(AudioData audioData) { 65 | Log.d(TAG, "add audio data to queue: " + audioData.text); 66 | mAudioQueue.offer(audioData); 67 | } 68 | 69 | void interrupt() { 70 | mAudioQueue.clear(); 71 | if (mCurrentAudioData != null) { 72 | mCurrentAudioData.interrupt(); 73 | } 74 | } 75 | 76 | static class AudioData { 77 | private String text; 78 | private float[] audio; 79 | private boolean isInterrupt; 80 | 81 | AudioData(String text, float[] audio) { 82 | this.text = text; 83 | this.audio = audio; 84 | } 85 | 86 | private void interrupt() { 87 | isInterrupt = true; 88 | } 89 | } 90 | 91 | } 92 | -------------------------------------------------------------------------------- /examples/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml: -------------------------------------------------------------------------------- 1 | 7 | 12 | 13 | 19 | 22 | 25 | 26 | 27 | 28 | 34 | 35 | -------------------------------------------------------------------------------- /examples/android/app/src/main/res/layout/activity_main.xml: -------------------------------------------------------------------------------- 1 | 2 | 9 | 10 | 17 | 18 | 23 | 24 | 28 | 29 | 34 | 35 | 40 | 41 | 46 | 47 | 52 | 53 | 54 | 55 | 56 | 61 | 62 | 68 | 69 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /examples/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /examples/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /examples/android/app/src/main/res/mipmap-hdpi/ic_launcher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-hdpi/ic_launcher.png -------------------------------------------------------------------------------- /examples/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.png -------------------------------------------------------------------------------- /examples/android/app/src/main/res/mipmap-mdpi/ic_launcher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-mdpi/ic_launcher.png -------------------------------------------------------------------------------- /examples/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.png 
-------------------------------------------------------------------------------- /examples/android/app/src/main/res/mipmap-xhdpi/ic_launcher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-xhdpi/ic_launcher.png -------------------------------------------------------------------------------- /examples/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png -------------------------------------------------------------------------------- /examples/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.png -------------------------------------------------------------------------------- /examples/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png -------------------------------------------------------------------------------- /examples/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png 
-------------------------------------------------------------------------------- /examples/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png -------------------------------------------------------------------------------- /examples/android/app/src/main/res/values/colors.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | #008577 4 | #00574B 5 | #D81B60 6 | 7 | -------------------------------------------------------------------------------- /examples/android/app/src/main/res/values/strings.xml: -------------------------------------------------------------------------------- 1 | 2 | TensorflowTTS 3 | 4 | -------------------------------------------------------------------------------- /examples/android/app/src/main/res/values/styles.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /examples/android/app/src/test/java/com/tensorspeech/tensorflowtts/ExampleUnitTest.java: -------------------------------------------------------------------------------- 1 | package com.tensorspeech.tensorflowtts; 2 | 3 | import org.junit.Test; 4 | 5 | import static org.junit.Assert.*; 6 | 7 | /** 8 | * Example local unit test, which will execute on the development machine (host). 
9 | * 10 | * @see Testing documentation 11 | */ 12 | public class ExampleUnitTest { 13 | @Test 14 | public void addition_isCorrect() { 15 | assertEquals(4, 2 + 2); 16 | } 17 | } -------------------------------------------------------------------------------- /examples/android/build.gradle: -------------------------------------------------------------------------------- 1 | // Top-level build file where you can add configuration options common to all sub-projects/modules. 2 | 3 | buildscript { 4 | repositories { 5 | google() 6 | jcenter() 7 | 8 | } 9 | dependencies { 10 | classpath 'com.android.tools.build:gradle:3.5.2' 11 | 12 | // NOTE: Do not place your application dependencies here; they belong 13 | // in the individual module build.gradle files 14 | } 15 | } 16 | 17 | allprojects { 18 | repositories { 19 | google() 20 | jcenter() 21 | 22 | } 23 | } 24 | 25 | task clean(type: Delete) { 26 | delete rootProject.buildDir 27 | } 28 | -------------------------------------------------------------------------------- /examples/android/gradle.properties: -------------------------------------------------------------------------------- 1 | # Project-wide Gradle settings. 2 | # IDE (e.g. Android Studio) users: 3 | # Gradle settings configured through the IDE *will override* 4 | # any settings specified in this file. 5 | # For more details on how to configure your build environment visit 6 | # http://www.gradle.org/docs/current/userguide/build_environment.html 7 | # Specifies the JVM arguments used for the daemon process. 8 | # The setting is particularly useful for tweaking memory settings. 9 | org.gradle.jvmargs=-Xmx1536m 10 | # When configured, Gradle will run in incubating parallel mode. 11 | # This option should only be used with decoupled projects. 
More details, visit 12 | # http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects 13 | # org.gradle.parallel=true 14 | # AndroidX package structure to make it clearer which packages are bundled with the 15 | # Android operating system, and which are packaged with your app's APK 16 | # https://developer.android.com/topic/libraries/support-library/androidx-rn 17 | android.useAndroidX=true 18 | # Automatically convert third-party libraries to use AndroidX 19 | android.enableJetifier=true 20 | 21 | -------------------------------------------------------------------------------- /examples/android/gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /examples/android/gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | #Mon Jul 20 11:21:10 CST 2020 2 | distributionBase=GRADLE_USER_HOME 3 | distributionPath=wrapper/dists 4 | zipStoreBase=GRADLE_USER_HOME 5 | zipStorePath=wrapper/dists 6 | distributionUrl=https\://services.gradle.org/distributions/gradle-5.4.1-all.zip 7 | -------------------------------------------------------------------------------- /examples/android/gradlew.bat: -------------------------------------------------------------------------------- 1 | @if "%DEBUG%" == "" @echo off 2 | @rem ########################################################################## 3 | @rem 4 | @rem Gradle startup script for Windows 5 | @rem 6 | @rem ########################################################################## 7 | 8 | @rem Set local scope for the variables with windows NT shell 9 | if "%OS%"=="Windows_NT" setlocal 10 | 11 | set DIRNAME=%~dp0 12 | if "%DIRNAME%" 
== "" set DIRNAME=. 13 | set APP_BASE_NAME=%~n0 14 | set APP_HOME=%DIRNAME% 15 | 16 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 17 | set DEFAULT_JVM_OPTS= 18 | 19 | @rem Find java.exe 20 | if defined JAVA_HOME goto findJavaFromJavaHome 21 | 22 | set JAVA_EXE=java.exe 23 | %JAVA_EXE% -version >NUL 2>&1 24 | if "%ERRORLEVEL%" == "0" goto init 25 | 26 | echo. 27 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 28 | echo. 29 | echo Please set the JAVA_HOME variable in your environment to match the 30 | echo location of your Java installation. 31 | 32 | goto fail 33 | 34 | :findJavaFromJavaHome 35 | set JAVA_HOME=%JAVA_HOME:"=% 36 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 37 | 38 | if exist "%JAVA_EXE%" goto init 39 | 40 | echo. 41 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 42 | echo. 43 | echo Please set the JAVA_HOME variable in your environment to match the 44 | echo location of your Java installation. 45 | 46 | goto fail 47 | 48 | :init 49 | @rem Get command-line arguments, handling Windows variants 50 | 51 | if not "%OS%" == "Windows_NT" goto win9xME_args 52 | 53 | :win9xME_args 54 | @rem Slurp the command line arguments. 
55 | set CMD_LINE_ARGS= 56 | set _SKIP=2 57 | 58 | :win9xME_args_slurp 59 | if "x%~1" == "x" goto execute 60 | 61 | set CMD_LINE_ARGS=%* 62 | 63 | :execute 64 | @rem Setup the command line 65 | 66 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 67 | 68 | @rem Execute Gradle 69 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% 70 | 71 | :end 72 | @rem End local scope for the variables with windows NT shell 73 | if "%ERRORLEVEL%"=="0" goto mainEnd 74 | 75 | :fail 76 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 77 | rem the _cmd.exe /c_ return code! 78 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 79 | exit /b 1 80 | 81 | :mainEnd 82 | if "%OS%"=="Windows_NT" endlocal 83 | 84 | :omega 85 | -------------------------------------------------------------------------------- /examples/android/settings.gradle: -------------------------------------------------------------------------------- 1 | include ':app' 2 | rootProject.name='TensorflowTTS' 3 | -------------------------------------------------------------------------------- /examples/cpptflite/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | /build 3 | /models 4 | /lib 5 | lib.zip 6 | models.zip 7 | models_ljspeech.zip -------------------------------------------------------------------------------- /examples/cpptflite/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.6) 2 | PROJECT(TfliteTTS) 3 | 4 | option(MAPPER "Processor select (supported BAKER or LJSPEECH)") 5 | if (${MAPPER} STREQUAL "LJSPEECH") 6 | add_definitions(-DLJSPEECH) 7 | elseif (${MAPPER} STREQUAL "BAKER") 8 | add_definitions(-DBAKER) 9 | else () 10 | message(FATAL_ERROR "MAPPER is only supported BAKER or LJSPEECH") 11 | endif() 12 | 
13 | message(STATUS "MAPPER is selected: "${MAPPER}) 14 | 15 | include_directories(lib) 16 | include_directories(lib/flatbuffers/include) 17 | include_directories(src) 18 | 19 | aux_source_directory(src DIR_SRCS) 20 | 21 | SET(CMAKE_CXX_COMPILER "g++") 22 | 23 | SET(CMAKE_CXX_FLAGS "-O3 -DNDEBUG -Wl,--no-as-needed -ldl -pthread -fpermissive") 24 | 25 | add_executable(demo demo/main.cpp ${DIR_SRCS}) 26 | 27 | find_library(tflite_LIB tensorflow-lite lib) 28 | 29 | target_link_libraries(demo ${tflite_LIB}) -------------------------------------------------------------------------------- /examples/cpptflite/demo/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "VoxCommon.h" 5 | #include "TTSFrontend.h" 6 | #include "TTSBackend.h" 7 | 8 | typedef struct 9 | { 10 | const char* mapperJson; 11 | unsigned int sampleRate; 12 | } Processor; 13 | 14 | int main(int argc, char* argv[]) 15 | { 16 | if (argc != 3) 17 | { 18 | fprintf(stderr, "demo text wavfile\n"); 19 | return 1; 20 | } 21 | 22 | const char* cmd = "python3 ../demo/text2ids.py"; 23 | 24 | Processor proc; 25 | #if LJSPEECH 26 | proc.mapperJson = "../../../tensorflow_tts/processor/pretrained/ljspeech_mapper.json"; 27 | proc.sampleRate = 22050; 28 | #elif BAKER 29 | proc.mapperJson = "../../../tensorflow_tts/processor/pretrained/baker_mapper.json"; 30 | proc.sampleRate = 24000; 31 | #endif 32 | 33 | const char* melgenfile = "../models/fastspeech2_quan.tflite"; 34 | const char* vocoderfile = "../models/mb_melgan.tflite"; 35 | 36 | // Init 37 | TTSFrontend ttsfrontend(proc.mapperJson, cmd); 38 | TTSBackend ttsbackend(melgenfile, vocoderfile); 39 | 40 | // Process 41 | ttsfrontend.text2ids(argv[1]); 42 | std::vector phonesIds = ttsfrontend.getPhoneIds(); 43 | 44 | ttsbackend.inference(phonesIds); 45 | MelGenData mel = ttsbackend.getMel(); 46 | std::vector audio = ttsbackend.getAudio(); 47 | 48 | std::cout << "********* Phones' ID 
*********" << std::endl; 49 | 50 | for (auto iter: phonesIds) 51 | { 52 | std::cout << iter << " "; 53 | } 54 | std::cout << std::endl; 55 | 56 | std::cout << "********* MEL SHAPE **********" << std::endl; 57 | for (auto index : mel.melShape) 58 | { 59 | std::cout << index << " "; 60 | } 61 | std::cout << std::endl; 62 | 63 | std::cout << "********* AUDIO LEN **********" << std::endl; 64 | std::cout << audio.size() << std::endl; 65 | 66 | VoxUtil::ExportWAV(argv[2], audio, proc.sampleRate); 67 | std::cout << "Wavfile: " << argv[2] << " creats." << std::endl; 68 | 69 | return 0; 70 | } -------------------------------------------------------------------------------- /examples/cpptflite/demo/text2ids.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | 4 | eng_pat = re.compile("[a-zA-Z]+") 5 | 6 | if __name__ == "__main__": 7 | argvs = sys.argv 8 | 9 | if (len(argvs) != 3): 10 | print("usage: python3 {} mapper.json text".format(argvs[0])) 11 | else: 12 | from tensorflow_tts.inference import AutoProcessor 13 | mapper_json = argvs[1] 14 | processor = AutoProcessor.from_pretrained(pretrained_path=mapper_json) 15 | 16 | input_text = argvs[2] 17 | 18 | if eng_pat.match(input_text): 19 | input_ids = processor.text_to_sequence(input_text) 20 | else: 21 | input_ids = processor.text_to_sequence(input_text, inference=True) 22 | 23 | print(" ".join(str(i) for i in input_ids)) -------------------------------------------------------------------------------- /examples/cpptflite/results/lj_ori_mel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/cpptflite/results/lj_ori_mel.png -------------------------------------------------------------------------------- /examples/cpptflite/results/lj_tflite_mel.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/cpptflite/results/lj_tflite_mel.png -------------------------------------------------------------------------------- /examples/cpptflite/results/tflite_mel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/cpptflite/results/tflite_mel.png -------------------------------------------------------------------------------- /examples/cpptflite/results/tflite_mel2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/cpptflite/results/tflite_mel2.png -------------------------------------------------------------------------------- /examples/cpptflite/src/MelGenerateTF.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "MelGenerateTF.h" 3 | 4 | MelGenData MelGenerateTF::infer(const std::vector inputIds) 5 | { 6 | 7 | MelGenData output; 8 | 9 | int32_t idsLen = inputIds.size(); 10 | 11 | std::vector> inputIndexsShape{ {1, idsLen}, {1}, {1}, {1}, {1} }; 12 | 13 | int32_t shapeI = 0; 14 | for (auto index : inputIndexs) 15 | { 16 | interpreter->ResizeInputTensor(index, inputIndexsShape[shapeI]); 17 | shapeI++; 18 | } 19 | 20 | TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk); 21 | 22 | int32_t* input_ids_ptr = interpreter->typed_tensor(inputIndexs[0]); 23 | memcpy(input_ids_ptr, inputIds.data(), int_size * idsLen); 24 | 25 | int32_t* speaker_ids_ptr = interpreter->typed_tensor(inputIndexs[1]); 26 | memcpy(speaker_ids_ptr, _speakerId.data(), int_size); 27 | 28 | float* speed_ratios_ptr = interpreter->typed_tensor(inputIndexs[2]); 
29 | memcpy(speed_ratios_ptr, _speedRatio.data(), float_size); 30 | 31 | float* speed_ratios2_ptr = interpreter->typed_tensor(inputIndexs[3]); 32 | memcpy(speed_ratios2_ptr, _f0Ratio.data(), float_size); 33 | 34 | float* speed_ratios3_ptr = interpreter->typed_tensor(inputIndexs[4]); 35 | memcpy(speed_ratios3_ptr, _enegyRatio.data(), float_size); 36 | 37 | TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk); 38 | 39 | TfLiteTensor* melGenTensor = interpreter->tensor(ouptIndex); 40 | 41 | for (int i=0; idims->size; i++) 42 | { 43 | output.melShape.push_back(melGenTensor->dims->data[i]); 44 | } 45 | 46 | output.bytes = melGenTensor->bytes; 47 | 48 | output.melData = interpreter->typed_tensor(ouptIndex); 49 | 50 | return output; 51 | } -------------------------------------------------------------------------------- /examples/cpptflite/src/MelGenerateTF.h: -------------------------------------------------------------------------------- 1 | #ifndef MELGENERATETF_H 2 | #define MELGENERATETF_H 3 | 4 | #include "TfliteBase.h" 5 | 6 | class MelGenerateTF : public TfliteBase 7 | { 8 | public: 9 | 10 | MelGenerateTF(const char* modelFilename):TfliteBase(modelFilename), 11 | inputIndexs(interpreter->inputs()), 12 | ouptIndex(interpreter->outputs()[1]) {}; 13 | 14 | MelGenData infer(const std::vector inputIds); 15 | 16 | private: 17 | std::vector _speakerId{0}; 18 | std::vector _speedRatio{1.0}; 19 | std::vector _f0Ratio{1.0}; 20 | std::vector _enegyRatio{1.0}; 21 | 22 | const std::vector inputIndexs; 23 | const int32_t ouptIndex; 24 | 25 | }; 26 | 27 | #endif // MELGENERATETF_H -------------------------------------------------------------------------------- /examples/cpptflite/src/TTSBackend.cpp: -------------------------------------------------------------------------------- 1 | #include "TTSBackend.h" 2 | 3 | void TTSBackend::inference(std::vector phonesIds) 4 | { 5 | _mel = MelGen.infer(phonesIds); 6 | _audio = Vocoder.infer(_mel); 7 | } 
-------------------------------------------------------------------------------- /examples/cpptflite/src/TTSBackend.h: -------------------------------------------------------------------------------- 1 | #ifndef TTSBACKEND_H 2 | #define TTSBACKEND_H 3 | 4 | #include 5 | #include 6 | #include "MelGenerateTF.h" 7 | #include "VocoderTF.h" 8 | 9 | class TTSBackend 10 | { 11 | public: 12 | TTSBackend(const char* melgenfile, const char* vocoderfile): 13 | MelGen(melgenfile), Vocoder(vocoderfile) 14 | { 15 | std::cout << "TTSBackend Init" << std::endl; 16 | std::cout << melgenfile << std::endl; 17 | std::cout << vocoderfile << std::endl; 18 | }; 19 | 20 | void inference(std::vector phonesIds); 21 | 22 | MelGenData getMel() const {return _mel;} 23 | std::vector getAudio() const {return _audio;} 24 | 25 | private: 26 | MelGenerateTF MelGen; 27 | VocoderTF Vocoder; 28 | 29 | MelGenData _mel; 30 | std::vector _audio; 31 | }; 32 | 33 | #endif // TTSBACKEND_H -------------------------------------------------------------------------------- /examples/cpptflite/src/TTSFrontend.cpp: -------------------------------------------------------------------------------- 1 | #include "TTSFrontend.h" 2 | 3 | void TTSFrontend::text2ids(const std::string &text) 4 | { 5 | _phonesIds = strSplit(getCmdResult(text)); 6 | } 7 | 8 | std::string TTSFrontend::getCmdResult(const std::string &text) 9 | { 10 | char buf[1000] = {0}; 11 | FILE *pf = NULL; 12 | 13 | if( (pf = popen((_strCmd + " " + _mapperJson + " \"" + text + "\"").c_str(), "r")) == NULL ) 14 | { 15 | return ""; 16 | } 17 | 18 | while(fgets(buf, sizeof(buf), pf)) 19 | { 20 | continue; 21 | } 22 | 23 | std::string strResult(buf); 24 | pclose(pf); 25 | 26 | return strResult; 27 | } 28 | 29 | std::vector TTSFrontend::strSplit(const std::string &idStr) 30 | { 31 | std::vector idsVector; 32 | 33 | std::regex rgx ("\\s+"); 34 | std::sregex_token_iterator iter(idStr.begin(), idStr.end(), rgx, -1); 35 | std::sregex_token_iterator end; 36 | 37 | 
while (iter != end) { 38 | idsVector.push_back(stoi(*iter)); 39 | ++iter; 40 | } 41 | 42 | return idsVector; 43 | } -------------------------------------------------------------------------------- /examples/cpptflite/src/TTSFrontend.h: -------------------------------------------------------------------------------- 1 | #ifndef TTSFRONTEND_H 2 | #define TTSFRONTEND_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | class TTSFrontend 11 | { 12 | public: 13 | 14 | /** 15 | * Converting text to phoneIDs. 16 | * A temporary method using command to process text in this demo, 17 | * which should be replaced by a pronunciation processing module. 18 | *@param strCmd Command to call the method of processor.text_to_sequence() 19 | */ 20 | TTSFrontend(const std::string &mapperJson, 21 | const std::string &strCmd): 22 | _mapperJson(mapperJson), 23 | _strCmd(strCmd) 24 | { 25 | std::cout << "TTSFrontend Init" << std::endl; 26 | std::cout << _mapperJson << std::endl; 27 | std::cout << _strCmd << std::endl; 28 | }; 29 | 30 | void text2ids(const std::string &text); 31 | 32 | std::vector getPhoneIds() const {return _phonesIds;} 33 | private: 34 | 35 | const std::string _mapperJson; 36 | const std::string _strCmd; 37 | 38 | std::vector _phonesIds; 39 | 40 | std::string getCmdResult(const std::string &text); 41 | std::vector strSplit(const std::string &idStr); 42 | }; 43 | 44 | #endif // TTSFRONTEND_H -------------------------------------------------------------------------------- /examples/cpptflite/src/TfliteBase.cpp: -------------------------------------------------------------------------------- 1 | #include "TfliteBase.h" 2 | 3 | TfliteBase::TfliteBase(const char* modelFilename) 4 | { 5 | interpreterBuild(modelFilename); 6 | } 7 | 8 | TfliteBase::~TfliteBase() 9 | { 10 | ; 11 | } 12 | 13 | void TfliteBase::interpreterBuild(const char* modelFilename) 14 | { 15 | model = tflite::FlatBufferModel::BuildFromFile(modelFilename); 16 | 17 |
TFLITE_MINIMAL_CHECK(model != nullptr); 18 | 19 | tflite::InterpreterBuilder builder(*model, resolver); 20 | 21 | builder(&interpreter); 22 | 23 | TFLITE_MINIMAL_CHECK(interpreter != nullptr); 24 | } 25 | -------------------------------------------------------------------------------- /examples/cpptflite/src/TfliteBase.h: -------------------------------------------------------------------------------- 1 | #ifndef TFLITEBASE_H 2 | #define TFLITEBASE_H 3 | 4 | #include "tensorflow/lite/interpreter.h" 5 | #include "tensorflow/lite/kernels/register.h" 6 | #include "tensorflow/lite/model.h" 7 | #include "tensorflow/lite/optional_debug_tools.h" 8 | 9 | #define TFLITE_MINIMAL_CHECK(x) \ 10 | if (!(x)) { \ 11 | fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \ 12 | exit(1); \ 13 | } 14 | 15 | typedef struct 16 | { 17 | float *melData; 18 | std::vector melShape; 19 | int32_t bytes; 20 | } MelGenData; 21 | 22 | class TfliteBase 23 | { 24 | public: 25 | uint32_t int_size = sizeof(int32_t); 26 | uint32_t float_size = sizeof(float); 27 | 28 | std::unique_ptr interpreter; 29 | 30 | TfliteBase(const char* modelFilename); 31 | ~TfliteBase(); 32 | 33 | private: 34 | std::unique_ptr model; 35 | tflite::ops::builtin::BuiltinOpResolver resolver; 36 | 37 | void interpreterBuild(const char* modelFilename); 38 | }; 39 | 40 | #endif // TFLITEBASE_H -------------------------------------------------------------------------------- /examples/cpptflite/src/VocoderTF.cpp: -------------------------------------------------------------------------------- 1 | #include "VocoderTF.h" 2 | 3 | std::vector VocoderTF::infer(const MelGenData mel) 4 | { 5 | std::vector audio; 6 | 7 | interpreter->ResizeInputTensor(inputIndex, mel.melShape); 8 | TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk); 9 | 10 | float* melDataPtr = interpreter->typed_input_tensor(inputIndex); 11 | memcpy(melDataPtr, mel.melData, mel.bytes); 12 | 13 | TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk); 
14 | 15 | TfLiteTensor* audioTensor = interpreter->tensor(outputIndex); 16 | 17 | float* outputPtr = interpreter->typed_output_tensor(0); 18 | 19 | int32_t audio_len = audioTensor->bytes / float_size; 20 | 21 | for (int i=0; iinputs()[0]), 12 | outputIndex(interpreter->outputs()[0]) {}; 13 | 14 | std::vector infer(const MelGenData mel); 15 | 16 | private: 17 | 18 | const int32_t inputIndex; 19 | const int32_t outputIndex; 20 | }; 21 | 22 | #endif // VOCODERTF_H -------------------------------------------------------------------------------- /examples/cpptflite/src/VoxCommon.cpp: -------------------------------------------------------------------------------- 1 | #include "VoxCommon.h" 2 | 3 | void VoxUtil::ExportWAV(const std::string & Filename, const std::vector& Data, unsigned SampleRate) { 4 | AudioFile::AudioBuffer Buffer; 5 | Buffer.resize(1); 6 | 7 | 8 | Buffer[0] = Data; 9 | size_t BufSz = Data.size(); 10 | 11 | 12 | AudioFile File; 13 | File.setAudioBuffer(Buffer); 14 | File.setAudioBufferSize(1, (int)BufSz); 15 | File.setNumSamplesPerChannel((int)BufSz); 16 | File.setNumChannels(1); 17 | File.setBitDepth(32); 18 | File.setSampleRate(SampleRate); 19 | 20 | File.save(Filename, AudioFileFormat::Wave); 21 | } 22 | -------------------------------------------------------------------------------- /examples/cpptflite/src/VoxCommon.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | /* 3 | VoxCommon.hpp : Defines common data structures and constants to be used with TensorVox 4 | */ 5 | #include 6 | #include 7 | #include "AudioFile.h" 8 | // #include "ext/CppFlow/include/Tensor.h" 9 | // #include 10 | 11 | #define IF_RETURN(cond,ret) if (cond){return ret;} 12 | #define VX_IF_EXCEPT(cond,ex) if (cond){throw std::invalid_argument(ex);} 13 | 14 | 15 | template 16 | struct TFTensor { 17 | std::vector Data; 18 | std::vector Shape; 19 | size_t TotalSize; 20 | }; 21 | 22 | namespace VoxUtil { 23 | 24 | void 
ExportWAV(const std::string& Filename, const std::vector& Data, unsigned SampleRate); 25 | } 26 | -------------------------------------------------------------------------------- /examples/cppwin/.gitattributes: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Set default behavior to automatically normalize line endings. 3 | ############################################################################### 4 | * text=auto 5 | 6 | ############################################################################### 7 | # Set default behavior for command prompt diff. 8 | # 9 | # This is need for earlier builds of msysgit that does not have it on by 10 | # default for csharp files. 11 | # Note: This is only used by command line 12 | ############################################################################### 13 | #*.cs diff=csharp 14 | 15 | ############################################################################### 16 | # Set the merge driver for project and solution files 17 | # 18 | # Merging from the command prompt will add diff markers to the files if there 19 | # are conflicts (Merging from VS is not affected by the settings below, in VS 20 | # the diff markers are never inserted). Diff markers may cause the following 21 | # file extensions to fail to load in VS. An alternative would be to treat 22 | # these files as binary and thus will always conflict and require user 23 | # intervention with every merge. 
To do so, just uncomment the entries below 24 | ############################################################################### 25 | #*.sln merge=binary 26 | #*.csproj merge=binary 27 | #*.vbproj merge=binary 28 | #*.vcxproj merge=binary 29 | #*.vcproj merge=binary 30 | #*.dbproj merge=binary 31 | #*.fsproj merge=binary 32 | #*.lsproj merge=binary 33 | #*.wixproj merge=binary 34 | #*.modelproj merge=binary 35 | #*.sqlproj merge=binary 36 | #*.wwaproj merge=binary 37 | 38 | ############################################################################### 39 | # behavior for image files 40 | # 41 | # image files are treated as binary by default. 42 | ############################################################################### 43 | #*.jpg binary 44 | #*.png binary 45 | #*.gif binary 46 | 47 | ############################################################################### 48 | # diff behavior for common document formats 49 | # 50 | # Convert binary document formats to text before diffing them. This feature 51 | # is only available from the command line. Turn it on by uncommenting the 52 | # entries below. 
53 | ############################################################################### 54 | #*.doc diff=astextplain 55 | #*.DOC diff=astextplain 56 | #*.docx diff=astextplain 57 | #*.DOCX diff=astextplain 58 | #*.dot diff=astextplain 59 | #*.DOT diff=astextplain 60 | #*.pdf diff=astextplain 61 | #*.PDF diff=astextplain 62 | #*.rtf diff=astextplain 63 | #*.RTF diff=astextplain 64 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference.pro: -------------------------------------------------------------------------------- 1 | TEMPLATE = app 2 | CONFIG += console c++14 3 | CONFIG -= app_bundle 4 | CONFIG -= qt 5 | TARGET = TFTTSCppInfer 6 | 7 | HEADERS += \ 8 | TensorflowTTSCppInference/EnglishPhoneticProcessor.h \ 9 | TensorflowTTSCppInference/FastSpeech2.h \ 10 | TensorflowTTSCppInference/MultiBandMelGAN.h \ 11 | TensorflowTTSCppInference/TextTokenizer.h \ 12 | TensorflowTTSCppInference/Voice.h \ 13 | TensorflowTTSCppInference/VoxCommon.hpp \ 14 | TensorflowTTSCppInference/ext/AudioFile.hpp \ 15 | TensorflowTTSCppInference/ext/CppFlow/include/Model.h \ 16 | TensorflowTTSCppInference/ext/CppFlow/include/Tensor.h \ 17 | TensorflowTTSCppInference/ext/ZCharScanner.h \ 18 | TensorflowTTSCppInference/phonemizer.h \ 19 | TensorflowTTSCppInference/tfg2p.h \ 20 | 21 | SOURCES += \ 22 | TensorflowTTSCppInference/EnglishPhoneticProcessor.cpp \ 23 | TensorflowTTSCppInference/FastSpeech2.cpp \ 24 | TensorflowTTSCppInference/MultiBandMelGAN.cpp \ 25 | TensorflowTTSCppInference/TensorflowTTSCppInference.cpp \ 26 | TensorflowTTSCppInference/TextTokenizer.cpp \ 27 | TensorflowTTSCppInference/Voice.cpp \ 28 | TensorflowTTSCppInference/VoxCommon.cpp \ 29 | TensorflowTTSCppInference/ext/CppFlow/src/Model.cpp \ 30 | TensorflowTTSCppInference/ext/CppFlow/src/Tensor.cpp \ 31 | TensorflowTTSCppInference/phonemizer.cpp \ 32 | TensorflowTTSCppInference/tfg2p.cpp \ 33 | TensorflowTTSCppInference/ext/ZCharScanner.cpp \ 34 | 35 | 
INCLUDEPATH += $$PWD/deps/include 36 | LIBS += -L$$PWD/deps/lib -ltensorflow 37 | 38 | # GCC shits itself on memcp in AudioFile.hpp (l-1186) unless we add this 39 | QMAKE_CXXFLAGS += -fpermissive 40 | 41 | 42 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 15 4 | VisualStudioVersion = 15.0.28307.136 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "TensorflowTTSCppInference", "TensorflowTTSCppInference\TensorflowTTSCppInference.vcxproj", "{67C98279-9BA3-49F7-9FE4-2C0DF77A2875}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|x64 = Debug|x64 11 | Debug|x86 = Debug|x86 12 | Release|x64 = Release|x64 13 | Release|x86 = Release|x86 14 | EndGlobalSection 15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 16 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Debug|x64.ActiveCfg = Debug|x64 17 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Debug|x64.Build.0 = Debug|x64 18 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Debug|x86.ActiveCfg = Debug|Win32 19 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Debug|x86.Build.0 = Debug|Win32 20 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Release|x64.ActiveCfg = Release|x64 21 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Release|x64.Build.0 = Release|x64 22 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Release|x86.ActiveCfg = Release|Win32 23 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Release|x86.Build.0 = Release|Win32 24 | EndGlobalSection 25 | GlobalSection(SolutionProperties) = preSolution 26 | HideSolutionNode = FALSE 27 | EndGlobalSection 28 | GlobalSection(ExtensibilityGlobals) = postSolution 29 | SolutionGuid = {08E7CCCB-028D-4BFC-9CDC-E8957E50F8EA} 30 | EndGlobalSection 31 | EndGlobal 32 | 
-------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/EnglishPhoneticProcessor.cpp: -------------------------------------------------------------------------------- 1 | #include "EnglishPhoneticProcessor.h" 2 | #include "VoxCommon.hpp" 3 | 4 | using namespace std; 5 | 6 | bool EnglishPhoneticProcessor::Initialize(Phonemizer* InPhn) 7 | { 8 | 9 | 10 | Phoner = InPhn; 11 | Tokenizer.SetAllowedChars(Phoner->GetGraphemeChars()); 12 | 13 | 14 | 15 | return true; 16 | } 17 | 18 | std::string EnglishPhoneticProcessor::ProcessTextPhonetic(const std::string& InText, const std::vector &InPhonemes,ETTSLanguage::Enum InLanguage) 19 | { 20 | if (!Phoner) 21 | return "ERROR"; 22 | 23 | 24 | 25 | vector Words = Tokenizer.Tokenize(InText,InLanguage); 26 | 27 | string Assemble = ""; 28 | // Make a copy of the dict passed. 29 | 30 | for (size_t w = 0; w < Words.size();w++) 31 | { 32 | const string& Word = Words[w]; 33 | 34 | if (Word.find("@") != std::string::npos){ 35 | std::string AddPh = Word.substr(1); // Remove the @ 36 | size_t OutId = 0; 37 | if (VoxUtil::FindInVec(AddPh,InPhonemes,OutId)) 38 | { 39 | Assemble.append(InPhonemes[OutId]); 40 | Assemble.append(" "); 41 | 42 | 43 | } 44 | 45 | continue; 46 | 47 | } 48 | 49 | 50 | 51 | 52 | size_t OverrideIdx = 0; 53 | 54 | 55 | 56 | std::string Res = Phoner->ProcessWord(Word,0.001f); 57 | 58 | // Cache the word in the override dict so next time we don't have to research it 59 | 60 | Assemble.append(Res); 61 | Assemble.append(" "); 62 | 63 | 64 | 65 | 66 | 67 | } 68 | 69 | 70 | // Delete last space if there is 71 | 72 | 73 | if (Assemble[Assemble.size() - 1] == ' ') 74 | Assemble.pop_back(); 75 | 76 | 77 | return Assemble; 78 | } 79 | 80 | EnglishPhoneticProcessor::EnglishPhoneticProcessor() 81 | { 82 | Phoner = nullptr; 83 | } 84 | 85 | EnglishPhoneticProcessor::EnglishPhoneticProcessor(Phonemizer *InPhn) 86 | { 87 | Initialize(InPhn); 88 | 89 | } 90 | 91 | 92 
| 93 | EnglishPhoneticProcessor::~EnglishPhoneticProcessor() 94 | { 95 | if (Phoner) 96 | delete Phoner; 97 | } 98 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/EnglishPhoneticProcessor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "TextTokenizer.h" 3 | #include "phonemizer.h" 4 | 5 | class EnglishPhoneticProcessor 6 | { 7 | private: 8 | TextTokenizer Tokenizer; 9 | Phonemizer* Phoner; 10 | 11 | inline bool FileExists(const std::string& name) { 12 | std::ifstream f(name.c_str()); 13 | return f.good(); 14 | } 15 | 16 | public: 17 | bool Initialize(Phonemizer *InPhn); 18 | std::string ProcessTextPhonetic(const std::string& InText, const std::vector &InPhonemes,ETTSLanguage::Enum InLanguage); 19 | EnglishPhoneticProcessor(); 20 | EnglishPhoneticProcessor(Phonemizer *InPhn); 21 | ~EnglishPhoneticProcessor(); 22 | }; 23 | 24 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/FastSpeech2.cpp: -------------------------------------------------------------------------------- 1 | #include "FastSpeech2.h" 2 | #include 3 | 4 | 5 | FastSpeech2::FastSpeech2() 6 | { 7 | FastSpeech = nullptr; 8 | } 9 | 10 | FastSpeech2::FastSpeech2(const std::string & SavedModelFolder) 11 | { 12 | Initialize(SavedModelFolder); 13 | } 14 | 15 | 16 | bool FastSpeech2::Initialize(const std::string & SavedModelFolder) 17 | { 18 | try { 19 | FastSpeech = new Model(SavedModelFolder); 20 | } 21 | catch (...) 
{ 22 | FastSpeech = nullptr; 23 | return false; 24 | 25 | } 26 | return true; 27 | } 28 | 29 | TFTensor FastSpeech2::DoInference(const std::vector& InputIDs, int32_t SpeakerID, float Speed, float Energy, float F0, int32_t EmotionID) 30 | { 31 | if (!FastSpeech) 32 | throw std::invalid_argument("Tried to do inference on unloaded or invalid model!"); 33 | 34 | // Convenience reference so that we don't have to constantly derefer pointers. 35 | Model& Mdl = *FastSpeech; 36 | 37 | // Define the tensors 38 | Tensor input_ids{ Mdl,"serving_default_input_ids" }; 39 | Tensor energy_ratios{ Mdl,"serving_default_energy_ratios" }; 40 | Tensor f0_ratios{ Mdl,"serving_default_f0_ratios" }; 41 | Tensor speaker_ids{ Mdl,"serving_default_speaker_ids" }; 42 | Tensor speed_ratios{ Mdl,"serving_default_speed_ratios" }; 43 | Tensor* emotion_ids = nullptr; 44 | 45 | // This is a multi-emotion model 46 | if (EmotionID != -1) 47 | { 48 | emotion_ids = new Tensor{Mdl,"serving_default_emotion_ids"}; 49 | emotion_ids->set_data(std::vector{EmotionID}); 50 | 51 | } 52 | 53 | 54 | // This is the shape of the input IDs, our equivalent to tf.expand_dims. 
55 | std::vector InputIDShape = { 1, (int64_t)InputIDs.size() }; 56 | 57 | input_ids.set_data(InputIDs, InputIDShape); 58 | energy_ratios.set_data(std::vector{ Energy }); 59 | f0_ratios.set_data(std::vector{F0}); 60 | speaker_ids.set_data(std::vector{SpeakerID}); 61 | speed_ratios.set_data(std::vector{Speed}); 62 | 63 | // Define output tensor 64 | Tensor output{ Mdl,"StatefulPartitionedCall" }; 65 | 66 | 67 | // Vector of input tensors 68 | std::vector inputs = { &input_ids,&speaker_ids,&speed_ratios,&f0_ratios,&energy_ratios }; 69 | 70 | if (EmotionID != -1) 71 | inputs.push_back(emotion_ids); 72 | 73 | 74 | // Do inference 75 | FastSpeech->run(inputs, output); 76 | 77 | // Define output and return it 78 | TFTensor Output = VoxUtil::CopyTensor(output); 79 | 80 | // We allocated the emotion_ids tensor dynamically, delete it 81 | if (emotion_ids) 82 | delete emotion_ids; 83 | 84 | // We could just straight out define it in the return statement, but I like it more this way 85 | 86 | return Output; 87 | } 88 | 89 | FastSpeech2::~FastSpeech2() 90 | { 91 | if (FastSpeech) 92 | delete FastSpeech; 93 | } 94 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/FastSpeech2.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ext/CppFlow/include/Model.h" 4 | #include "VoxCommon.hpp" 5 | class FastSpeech2 6 | { 7 | private: 8 | Model* FastSpeech; 9 | 10 | public: 11 | FastSpeech2(); 12 | FastSpeech2(const std::string& SavedModelFolder); 13 | 14 | /* 15 | Initialize and load the model 16 | 17 | -> SavedModelFolder: Folder where the .pb, variables, and other characteristics of the exported SavedModel 18 | <- Returns: (bool)Success 19 | */ 20 | bool Initialize(const std::string& SavedModelFolder); 21 | 22 | /* 23 | Do inference on a FastSpeech2 model. 
24 | 25 | -> InputIDs: Input IDs of tokens for inference 26 | -> SpeakerID: ID of the speaker in the model to do inference on. If single speaker, always leave at 0. If multispeaker, refer to your model. 27 | -> Speed, Energy, F0: Parameters for FS2 inference. Leave at 1.f for defaults 28 | 29 | <- Returns: TFTensor with shape {1,,80} containing contents of mel spectrogram. 30 | */ 31 | TFTensor DoInference(const std::vector& InputIDs, int32_t SpeakerID = 0, float Speed = 1.f, float Energy = 1.f, float F0 = 1.f,int32_t EmotionID = -1); 32 | 33 | 34 | 35 | ~FastSpeech2(); 36 | }; 37 | 38 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/MultiBandMelGAN.cpp: -------------------------------------------------------------------------------- 1 | #include "MultiBandMelGAN.h" 2 | #include 3 | #define IF_EXCEPT(cond,ex) if (cond){throw std::invalid_argument(ex);} 4 | 5 | 6 | 7 | bool MultiBandMelGAN::Initialize(const std::string & VocoderPath) 8 | { 9 | try { 10 | MelGAN = new Model(VocoderPath); 11 | } 12 | catch (...) { 13 | MelGAN = nullptr; 14 | return false; 15 | 16 | } 17 | return true; 18 | 19 | 20 | } 21 | 22 | TFTensor MultiBandMelGAN::DoInference(const TFTensor& InMel) 23 | { 24 | IF_EXCEPT(!MelGAN, "Tried to infer MB-MelGAN on uninitialized model!!!!") 25 | 26 | // Convenience reference so that we don't have to constantly derefer pointers. 
27 | Model& Mdl = *MelGAN; 28 | 29 | Tensor input_mels{ Mdl,"serving_default_mels" }; 30 | input_mels.set_data(InMel.Data, InMel.Shape); 31 | 32 | Tensor out_audio{ Mdl,"StatefulPartitionedCall" }; 33 | 34 | MelGAN->run(input_mels, out_audio); 35 | 36 | TFTensor RetTensor = VoxUtil::CopyTensor(out_audio); 37 | 38 | return RetTensor; 39 | 40 | 41 | } 42 | 43 | MultiBandMelGAN::MultiBandMelGAN() 44 | { 45 | MelGAN = nullptr; 46 | } 47 | 48 | 49 | MultiBandMelGAN::~MultiBandMelGAN() 50 | { 51 | if (MelGAN) 52 | delete MelGAN; 53 | 54 | } 55 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/MultiBandMelGAN.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ext/CppFlow/include/Model.h" 4 | #include "VoxCommon.hpp" 5 | class MultiBandMelGAN 6 | { 7 | private: 8 | Model* MelGAN; 9 | 10 | 11 | public: 12 | bool Initialize(const std::string& VocoderPath); 13 | 14 | 15 | // Do MultiBand MelGAN inference including PQMF 16 | // -> InMel: Mel spectrogram (shape [1, xx, 80]) 17 | // <- Returns: Tensor data [4, xx, 1] 18 | TFTensor DoInference(const TFTensor& InMel); 19 | 20 | MultiBandMelGAN(); 21 | ~MultiBandMelGAN(); 22 | }; 23 | 24 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/TextTokenizer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include "VoxCommon.hpp" 5 | 6 | class TextTokenizer 7 | { 8 | private: 9 | std::string AllowedChars; 10 | std::string IntToStr(int number); 11 | 12 | std::vector ExpandNumbers(const std::vector& SpaceTokens); 13 | public: 14 | TextTokenizer(); 15 | ~TextTokenizer(); 16 | 17 | std::vector Tokenize(const std::string& InTxt,ETTSLanguage::Enum Language = ETTSLanguage::English); 18 | void SetAllowedChars(const std::string &value); 19 | }; 20 | 
21 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/Voice.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "FastSpeech2.h" 4 | #include "MultiBandMelGAN.h" 5 | #include "EnglishPhoneticProcessor.h" 6 | 7 | 8 | class Voice 9 | { 10 | private: 11 | FastSpeech2 MelPredictor; 12 | MultiBandMelGAN Vocoder; 13 | EnglishPhoneticProcessor Processor; 14 | VoiceInfo VoxInfo; 15 | 16 | 17 | 18 | std::vector Phonemes; 19 | std::vector PhonemeIDs; 20 | 21 | 22 | 23 | std::vector PhonemesToID(const std::string& InTxt); 24 | 25 | std::vector Speakers; 26 | std::vector Emotions; 27 | 28 | void ReadPhonemes(const std::string& PhonemePath); 29 | 30 | void ReadSpeakers(const std::string& SpeakerPath); 31 | 32 | void ReadEmotions(const std::string& EmotionPath); 33 | 34 | 35 | void ReadModelInfo(const std::string& ModelInfoPath); 36 | 37 | std::vector GetLinedFile(const std::string& Path); 38 | 39 | 40 | std::string ModelInfo; 41 | 42 | public: 43 | /* Voice constructor, arguments obligatory. 44 | -> VoxPath: Path of folder where models are contained. 45 | -- Must be a folder without an ending slash with UNIX slashes, can be relative or absolute (eg: MyVoices/Karen) 46 | -- The folder must contain the following elements: 47 | --- melgen: Folder generated where a FastSpeech2 model was saved as SavedModel, with .pb, variables, etc 48 | --- vocoder: Folder where a Multi-Band MelGAN model was saved as SavedModel. 49 | --- info.json: Model information 50 | --- phonemes.txt: Tab delimited file containing PHONEME \t ID, for inputting to the FS2 model. 
51 | 52 | --- If multispeaker, a lined .txt file called speakers.txt 53 | --- If multi-emotion, a lined .txt file called emotions.txt 54 | 55 | */ 56 | Voice(const std::string& VoxPath, const std::string& inName,Phonemizer* InPhn); 57 | 58 | void AddPhonemizer(Phonemizer* InPhn); 59 | 60 | 61 | std::vector Vocalize(const std::string& Prompt, float Speed = 1.f, int32_t SpeakerID = 0, float Energy = 1.f, float F0 = 1.f,int32_t EmotionID = -1); 62 | 63 | std::string Name; 64 | inline const VoiceInfo& GetInfo(){return VoxInfo;} 65 | 66 | inline const std::vector& GetSpeakers(){return Speakers;} 67 | inline const std::vector& GetEmotions(){return Emotions;} 68 | 69 | inline const std::string& GetModelInfo(){return ModelInfo;} 70 | 71 | ~Voice(); 72 | }; 73 | 74 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/VoxCommon.cpp: -------------------------------------------------------------------------------- 1 | #include "VoxCommon.hpp" 2 | #include "ext/json.hpp" 3 | using namespace nlohmann; 4 | 5 | const std::vector Text2MelNames = {"FastSpeech2","Tacotron2"}; 6 | const std::vector VocoderNames = {"Multi-Band MelGAN"}; 7 | const std::vector RepoNames = {"TensorflowTTS"}; 8 | 9 | const std::vector LanguageNames = {"English","Spanish"}; 10 | 11 | 12 | void VoxUtil::ExportWAV(const std::string & Filename, const std::vector& Data, unsigned SampleRate) { 13 | AudioFile::AudioBuffer Buffer; 14 | Buffer.resize(1); 15 | 16 | 17 | Buffer[0] = Data; 18 | size_t BufSz = Data.size(); 19 | 20 | 21 | AudioFile File; 22 | File.setAudioBuffer(Buffer); 23 | File.setAudioBufferSize(1, (int)BufSz); 24 | File.setNumSamplesPerChannel((int)BufSz); 25 | File.setNumChannels(1); 26 | File.setBitDepth(32); 27 | File.setSampleRate(SampleRate); 28 | 29 | File.save(Filename, AudioFileFormat::Wave); 30 | 31 | 32 | 33 | } 34 | 35 | VoiceInfo VoxUtil::ReadModelJSON(const std::string &InfoFilename) 36 | { 37 | const size_t 
MaxNoteSize = 80; 38 | 39 | std::ifstream JFile(InfoFilename); 40 | json JS; 41 | 42 | JFile >> JS; 43 | 44 | 45 | JFile.close(); 46 | 47 | auto Arch = JS["architecture"]; 48 | 49 | ArchitectureInfo CuArch; 50 | CuArch.Repo = Arch["repo"].get(); 51 | CuArch.Text2Mel = Arch["text2mel"].get(); 52 | CuArch.Vocoder = Arch["vocoder"].get(); 53 | 54 | // Now fill the strings 55 | CuArch.s_Repo = RepoNames[CuArch.Repo]; 56 | CuArch.s_Text2Mel = Text2MelNames[CuArch.Text2Mel]; 57 | CuArch.s_Vocoder = VocoderNames[CuArch.Vocoder]; 58 | 59 | 60 | uint32_t Lang = JS["language"].get(); 61 | VoiceInfo Inf{JS["name"].get(), 62 | JS["author"].get(), 63 | JS["version"].get(), 64 | JS["description"].get(), 65 | CuArch, 66 | JS["note"].get(), 67 | JS["sarate"].get(), 68 | Lang, 69 | LanguageNames[Lang], 70 | " " + JS["pad"].get()}; // Add a space for separation since we directly append the value to the prompt 71 | 72 | if (Inf.Note.size() > MaxNoteSize) 73 | Inf.Note = Inf.Note.substr(0,MaxNoteSize); 74 | 75 | return Inf; 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | } 84 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/VoxCommon.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | /* 3 | VoxCommon.hpp : Defines common data structures and constants to be used with TensorVox 4 | */ 5 | #include 6 | #include 7 | #include "ext/AudioFile.hpp" 8 | #include "ext/CppFlow/include/Tensor.h" 9 | 10 | #define IF_RETURN(cond,ret) if (cond){return ret;} 11 | 12 | 13 | 14 | template 15 | struct TFTensor { 16 | std::vector Data; 17 | std::vector Shape; 18 | size_t TotalSize; 19 | 20 | }; 21 | 22 | 23 | namespace ETTSRepo { 24 | enum Enum{ 25 | TensorflowTTS = 0, 26 | MozillaTTS // not implemented yet 27 | }; 28 | 29 | } 30 | namespace EText2MelModel { 31 | enum Enum{ 32 | FastSpeech2 = 0, 33 | Tacotron2 // not implemented yet 34 | }; 35 | 36 | } 37 | 38 | namespace 
EVocoderModel{ 39 | enum Enum{ 40 | MultiBandMelGAN = 0 41 | }; 42 | } 43 | 44 | namespace ETTSLanguage{ 45 | enum Enum{ 46 | English = 0, 47 | Spanish 48 | }; 49 | 50 | } 51 | 52 | 53 | 54 | struct ArchitectureInfo{ 55 | int Repo; 56 | int Text2Mel; 57 | int Vocoder; 58 | 59 | // String versions of the info, for displaying. 60 | // We want boilerplate int index to str conversion code to be low. 61 | std::string s_Repo; 62 | std::string s_Text2Mel; 63 | std::string s_Vocoder; 64 | 65 | }; 66 | struct VoiceInfo{ 67 | std::string Name; 68 | std::string Author; 69 | int32_t Version; 70 | std::string Description; 71 | ArchitectureInfo Architecture; 72 | std::string Note; 73 | 74 | uint32_t SampleRate; 75 | 76 | uint32_t Language; 77 | std::string s_Language; 78 | 79 | std::string EndPadding; 80 | 81 | 82 | 83 | }; 84 | 85 | namespace VoxUtil { 86 | 87 | VoiceInfo ReadModelJSON(const std::string& InfoFilename); 88 | 89 | 90 | template 91 | TFTensor CopyTensor(Tensor& InTens) 92 | { 93 | std::vector Data = InTens.get_data(); 94 | std::vector Shape = InTens.get_shape(); 95 | size_t TotalSize = 1; 96 | for (const int64_t& Dim : Shape) 97 | TotalSize *= Dim; 98 | 99 | return TFTensor{Data, Shape, TotalSize}; 100 | 101 | 102 | } 103 | 104 | template 105 | bool FindInVec(V In, const std::vector& Vec, size_t& OutIdx, size_t start = 0) { 106 | for (size_t xx = start;xx < Vec.size();xx++) 107 | { 108 | if (Vec[xx] == In) { 109 | OutIdx = xx; 110 | return true; 111 | 112 | } 113 | 114 | } 115 | 116 | 117 | return false; 118 | 119 | } 120 | template 121 | bool FindInVec2(V In, const std::vector& Vec, size_t& OutIdx, size_t start = 0) { 122 | for (size_t xx = start;xx < Vec.size();xx++) 123 | { 124 | if (Vec[xx] == In) { 125 | OutIdx = xx; 126 | return true; 127 | 128 | } 129 | 130 | } 131 | 132 | 133 | return false; 134 | 135 | } 136 | 137 | void ExportWAV(const std::string& Filename, const std::vector& Data, unsigned SampleRate); 138 | } 139 | 
-------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/ext/CppFlow/include/Model.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by sergio on 12/05/19. 3 | // 4 | 5 | #ifndef CPPFLOW_MODEL_H 6 | #define CPPFLOW_MODEL_H 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #pragma warning(push, 0) 16 | #include 17 | #include "Tensor.h" 18 | #pragma warning(pop) 19 | class Tensor; 20 | 21 | class Model { 22 | public: 23 | // Pass a path to the model file and optional Tensorflow config options. See examples/load_model/main.cpp. 24 | explicit Model(const std::string& model_filename, const std::vector& config_options = {}); 25 | 26 | // Rule of five, moving is easy as the pointers can be copied, copying not as i have no idea how to copy 27 | // the contents of the pointer (i guess dereferencing won't do a deep copy) 28 | Model(const Model &model) = delete; 29 | Model(Model &&model) = default; 30 | Model& operator=(const Model &model) = delete; 31 | Model& operator=(Model &&model) = default; 32 | 33 | ~Model(); 34 | 35 | void init(); 36 | void restore(const std::string& ckpt); 37 | void save(const std::string& ckpt); 38 | void restore_savedmodel(const std::string& savedmdl); 39 | std::vector get_operations() const; 40 | 41 | // Original Run 42 | void run(const std::vector& inputs, const std::vector& outputs); 43 | 44 | // Run with references 45 | void run(Tensor& input, const std::vector& outputs); 46 | void run(const std::vector& inputs, Tensor& output); 47 | void run(Tensor& input, Tensor& output); 48 | 49 | // Run with pointers 50 | void run(Tensor* input, const std::vector& outputs); 51 | void run(const std::vector& inputs, Tensor* output); 52 | void run(Tensor* input, Tensor* output); 53 | 54 | private: 55 | TF_Graph* graph; 56 | TF_Session* session; 57 | TF_Status* status; 58 | 59 | // Read 
a file from a string 60 | static TF_Buffer* read(const std::string&); 61 | 62 | bool status_check(bool throw_exc) const; 63 | void error_check(bool condition, const std::string &error) const; 64 | 65 | public: 66 | friend class Tensor; 67 | }; 68 | 69 | 70 | #endif //CPPFLOW_MODEL_H 71 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/ext/CppFlow/include/Tensor.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by sergio on 13/05/19. 3 | // 4 | 5 | #ifndef CPPFLOW_TENSOR_H 6 | #define CPPFLOW_TENSOR_H 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | // Prevent warnings from Tensorflow C API headers 16 | 17 | #pragma warning(push, 0) 18 | #include 19 | #include "Model.h" 20 | #pragma warning(pop) 21 | 22 | class Model; 23 | 24 | class Tensor { 25 | public: 26 | Tensor(const Model& model, const std::string& operation); 27 | 28 | // Rule of five, moving is easy as the pointers can be copied, copying not as i have no idea how to copy 29 | // the contents of the pointer (i guess dereferencing won't do a deep copy) 30 | Tensor(const Tensor &tensor) = delete; 31 | Tensor(Tensor &&tensor) = default; 32 | Tensor& operator=(const Tensor &tensor) = delete; 33 | Tensor& operator=(Tensor &&tensor) = default; 34 | 35 | ~Tensor(); 36 | 37 | void clean(); 38 | 39 | template 40 | void set_data(std::vector new_data); 41 | 42 | template 43 | void set_data(std::vector new_data, const std::vector& new_shape); 44 | 45 | void set_data(const std::string & new_data, Model & inmodel); 46 | template 47 | std::vector get_data(); 48 | 49 | std::vector get_shape(); 50 | 51 | private: 52 | TF_Tensor* val; 53 | TF_Output op; 54 | TF_DataType type; 55 | std::vector shape; 56 | std::unique_ptr> actual_shape; 57 | void* data; 58 | int flag; 59 | 60 | // Aux functions 61 | void error_check(bool condition, const std::string& 
error); 62 | 63 | 64 | 65 | 66 | template 67 | static TF_DataType deduce_type(); 68 | 69 | void deduce_shape(); 70 | 71 | public: 72 | friend class Model; 73 | }; 74 | 75 | #endif //CPPFLOW_TENSOR_H 76 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/ext/ZCharScanner.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define GBasicCharScanner ZStringDelimiter 4 | 5 | #include 6 | #include 7 | 8 | #define ZSDEL_USE_STD_STRING 9 | #ifndef ZSDEL_USE_STD_STRING 10 | #include "golem_string.h" 11 | #else 12 | #define GString std::string 13 | #endif 14 | 15 | typedef std::vector::const_iterator TokenIterator; 16 | 17 | // ZStringDelimiter 18 | // ============== 19 | // Simple class to delimit and split strings. 20 | // You can use operator[] to access them 21 | // Or you can use the itBegin() and itEnd() to get some iterators 22 | // ================= 23 | class ZStringDelimiter 24 | { 25 | private: 26 | int key_search(const GString & s, const GString & key); 27 | void UpdateTokens(); 28 | std::vector m_vTokens; 29 | std::vector m_vDelimiters; 30 | 31 | GString m_sString; 32 | 33 | void DelimStr(const GString& s, const GString& delimiter, const bool& removeEmptyEntries = false); 34 | void BarRange(const int& min, const int& max); 35 | void Bar(const int& pos); 36 | size_t tokenIndex; 37 | public: 38 | ZStringDelimiter(); 39 | bool PgBar; 40 | 41 | #ifdef _AFX_ALL_WARNINGS 42 | CProgressCtrl* m_pBar; 43 | #endif 44 | 45 | ZStringDelimiter(const GString& in_iStr) { 46 | m_sString = in_iStr; 47 | PgBar = false; 48 | 49 | } 50 | 51 | bool GetFirstToken(GString& in_out); 52 | bool GetNextToken(GString& in_sOut); 53 | 54 | // std::String alts 55 | 56 | size_t szTokens() { return m_vTokens.size(); } 57 | GString operator[](const size_t& in_index); 58 | 59 | GString Reassemble(const GString & delim, const int & nelem = -1); 60 | 61 | // Override 
to reassemble provided tokens. 62 | GString Reassemble(const GString & delim, const std::vector& Strs,int nelem = -1); 63 | 64 | // Get a const reference to the tokens 65 | const std::vector& GetTokens() { return m_vTokens; } 66 | 67 | TokenIterator itBegin() { return m_vTokens.begin(); } 68 | TokenIterator itEnd() { return m_vTokens.end(); } 69 | 70 | void SetText(const GString& in_Txt) { 71 | m_sString = in_Txt; 72 | if (m_vDelimiters.size()) 73 | UpdateTokens(); 74 | } 75 | void AddDelimiter(const GString& in_Delim); 76 | 77 | ~ZStringDelimiter(); 78 | }; 79 | 80 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/phonemizer.h: -------------------------------------------------------------------------------- 1 | #ifndef PHONEMIZER_H 2 | #define PHONEMIZER_H 3 | #include "tfg2p.h" 4 | #include 5 | #include 6 | #include 7 | 8 | struct IdStr{ 9 | int32_t ID; 10 | std::string STR; 11 | }; 12 | 13 | 14 | struct StrStr{ 15 | std::string Word; 16 | std::string Phn; 17 | }; 18 | 19 | 20 | class Phonemizer 21 | { 22 | private: 23 | TFG2P G2pModel; 24 | 25 | std::vector CharId; 26 | std::vector PhnId; 27 | 28 | 29 | 30 | 31 | 32 | 33 | std::vector GetDelimitedFile(const std::string& InFname); 34 | 35 | 36 | // Sorry, can't use set, unordered_map or any other types. (I tried) 37 | std::vector Dictionary; 38 | 39 | void LoadDictionary(const std::string& InDictFn); 40 | 41 | std::string DictLookup(const std::string& InWord); 42 | 43 | 44 | 45 | std::string PhnLanguage; 46 | public: 47 | Phonemizer(); 48 | /* 49 | * Initialize a phonemizer 50 | * Expects: 51 | * - Two files consisting in TOKEN \t ID: 52 | * -- char2id.txt: Translation from input character to ID the model can accept 53 | * -- phn2id.txt: Translation from output ID from the model to phoneme 54 | * - A model/ folder where a G2P-Tensorflow model was saved as SavedModel 55 | * - dict.txt: Phonetic dictionary. 
First it searches the word there and if it can't be found then it uses the model. 56 | 57 | */ 58 | bool Initialize(const std::string InPath); 59 | std::string ProcessWord(const std::string& InWord, float Temperature = 0.1f); 60 | std::string GetPhnLanguage() const; 61 | void SetPhnLanguage(const std::string &value); 62 | 63 | std::string GetGraphemeChars(); 64 | 65 | }; 66 | 67 | 68 | bool operator<(const StrStr& right,const StrStr& left); 69 | #endif // PHONEMIZER_H 70 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/tfg2p.cpp: -------------------------------------------------------------------------------- 1 | #include "tfg2p.h" 2 | #include 3 | TFG2P::TFG2P() 4 | { 5 | G2P = nullptr; 6 | 7 | } 8 | 9 | TFG2P::TFG2P(const std::string &SavedModelFolder) 10 | { 11 | G2P = nullptr; 12 | 13 | Initialize(SavedModelFolder); 14 | } 15 | 16 | bool TFG2P::Initialize(const std::string &SavedModelFolder) 17 | { 18 | try { 19 | 20 | G2P = new Model(SavedModelFolder); 21 | 22 | } 23 | catch (...) { 24 | G2P = nullptr; 25 | return false; 26 | 27 | } 28 | return true; 29 | } 30 | 31 | TFTensor TFG2P::DoInference(const std::vector &InputIDs, float Temperature) 32 | { 33 | if (!G2P) 34 | throw std::invalid_argument("Tried to do inference on unloaded or invalid model!"); 35 | 36 | // Convenience reference so that we don't have to constantly derefer pointers. 37 | Model& Mdl = *G2P; 38 | 39 | 40 | // Convenience reference so that we don't have to constantly derefer pointers. 
41 | 42 | Tensor input_ids{ Mdl,"serving_default_input_ids" }; 43 | Tensor input_len{Mdl,"serving_default_input_len"}; 44 | Tensor input_temp{Mdl,"serving_default_input_temperature"}; 45 | 46 | input_ids.set_data(InputIDs, std::vector{(int64_t)InputIDs.size()}); 47 | input_len.set_data(std::vector{(int32_t)InputIDs.size()}); 48 | input_temp.set_data(std::vector{Temperature}); 49 | 50 | 51 | 52 | std::vector Inputs {&input_ids,&input_len,&input_temp}; 53 | Tensor out_ids{ Mdl,"StatefulPartitionedCall" }; 54 | 55 | Mdl.run(Inputs, out_ids); 56 | 57 | TFTensor RetTensor = VoxUtil::CopyTensor(out_ids); 58 | 59 | return RetTensor; 60 | 61 | 62 | } 63 | 64 | TFG2P::~TFG2P() 65 | { 66 | if (G2P) 67 | delete G2P; 68 | 69 | } 70 | -------------------------------------------------------------------------------- /examples/cppwin/TensorflowTTSCppInference/tfg2p.h: -------------------------------------------------------------------------------- 1 | #ifndef TFG2P_H 2 | #define TFG2P_H 3 | #include "ext/CppFlow/include/Model.h" 4 | #include "VoxCommon.hpp" 5 | 6 | 7 | class TFG2P 8 | { 9 | private: 10 | Model* G2P; 11 | 12 | public: 13 | TFG2P(); 14 | TFG2P(const std::string& SavedModelFolder); 15 | 16 | /* 17 | Initialize and load the model 18 | 19 | -> SavedModelFolder: Folder where the .pb, variables, and other characteristics of the exported SavedModel 20 | <- Returns: (bool)Success 21 | */ 22 | bool Initialize(const std::string& SavedModelFolder); 23 | 24 | /* 25 | Do inference on a G2P-TF-RNN model. 26 | 27 | -> InputIDs: Input IDs of tokens for inference 28 | -> Temperature: Temperature of the RNN, values higher than 0.1 cause instability. 
29 | 30 | <- Returns: TFTensor containing phoneme IDs 31 | */ 32 | TFTensor DoInference(const std::vector& InputIDs, float Temperature = 0.1f); 33 | 34 | ~TFG2P(); 35 | 36 | }; 37 | 38 | #endif // TFG2P_H 39 | -------------------------------------------------------------------------------- /examples/fastspeech/fig/fastspeech.v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/fastspeech/fig/fastspeech.v1.png -------------------------------------------------------------------------------- /examples/fastspeech2_libritts/README.md: -------------------------------------------------------------------------------- 1 | # Fast speech 2 multi-speaker english lang based 2 | 3 | ## Prepare 4 | Everything is done from main repo folder so TensorflowTTS/ 5 | 6 | 0. Optional* [Download](http://www.openslr.org/60/) and prepare libritts (helper to prepare libri in examples/fastspeech2_libritts/libri_experiment/prepare_libri.ipynb) 7 | - Dataset structure after finish this step: 8 | ``` 9 | |- TensorFlowTTS/ 10 | | |- LibriTTS/ 11 | | |- |- train-clean-100/ 12 | | |- |- SPEAKERS.txt 13 | | |- |- ... 14 | | |- libritts/ 15 | | |- |- 200/ 16 | | |- |- |- 200_124139_000001_000000.txt 17 | | |- |- |- 200_124139_000001_000000.wav 18 | | |- |- |- ... 19 | | |- |- 250/ 20 | | |- |- ... 21 | | |- tensorflow_tts/ 22 | | |- models/ 23 | | |- ... 24 | ``` 25 | 1. Extract Duration (use examples/mfa_extraction or pretrained tacotron2) 26 | 2. Optional* build docker 27 | - ``` 28 | bash examples/fastspeech2_libritts/scripts/build.sh 29 | ``` 30 | 3. Optional* run docker 31 | - ``` 32 | bash examples/fastspeech2_libritts/scripts/interactive.sh 33 | ``` 34 | 4. 
Preprocessing: 35 | - ``` 36 | tensorflow-tts-preprocess --rootdir ./libritts \ 37 | --outdir ./dump_libritts \ 38 | --config preprocess/libritts_preprocess.yaml \ 39 | --dataset libritts 40 | ``` 41 | 42 | 5. Normalization: 43 | - ``` 44 | tensorflow-tts-normalize --rootdir ./dump_libritts \ 45 | --outdir ./dump_libritts \ 46 | --config preprocess/libritts_preprocess.yaml \ 47 | --dataset libritts 48 | ``` 49 | 50 | 6. Change CharactorDurationF0EnergyMelDataset speaker mapper in fastspeech2_dataset to match your dataset (if you use libri with mfa_extraction you don't need to change anything) 51 | 7. Change train_libri.sh to match your dataset and run: 52 | - ``` 53 | bash examples/fastspeech2_libritts/scripts/train_libri.sh 54 | ``` 55 | 8. Optional* If you have problems with mismatched tensor sizes, check step 5 in the `examples/mfa_extraction` directory 56 | 57 | ## Comments 58 | 59 | This version uses the popular '|'-delimited train.txt format used in other repos. Training files should look like this => 60 | 61 | Wav Path | Text | Speaker Name 62 | 63 | Wav Path2 | Text | Speaker Name 64 | 65 | -------------------------------------------------------------------------------- /examples/fastspeech2_libritts/scripts/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | docker build --rm -t tftts -f examples/fastspeech2_libritts/scripts/docker/Dockerfile . 3 | -------------------------------------------------------------------------------- /examples/fastspeech2_libritts/scripts/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:2.2.0-gpu 2 | RUN apt-get update 3 | RUN apt-get install -y zsh tmux wget git libsndfile1 4 | ADD . /workspace/tts 5 | WORKDIR /workspace/tts 6 | RUN pip install . 
7 | 8 | -------------------------------------------------------------------------------- /examples/fastspeech2_libritts/scripts/interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | docker run --gpus all --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 -it --rm --ipc=host -p 8888:8888 -v $PWD:/workspace/tts/ tftts bash 3 | -------------------------------------------------------------------------------- /examples/fastspeech2_libritts/scripts/train_libri.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python examples/fastspeech2_libritts/train_fastspeech2.py \ 2 | --train-dir ./dump/train/ \ 3 | --dev-dir ./dump/valid/ \ 4 | --outdir ./examples/fastspeech2_libritts/outdir_libri/ \ 5 | --config ./examples/fastspeech2_libritts/conf/fastspeech2libritts.yaml \ 6 | --use-norm 1 \ 7 | --f0-stat ./dump/stats_f0.npy \ 8 | --energy-stat ./dump/stats_energy.npy \ 9 | --mixed_precision 1 \ 10 | --dataset_config preprocess/libritts_preprocess.yaml \ 11 | --dataset_stats dump/stats.npy -------------------------------------------------------------------------------- /examples/ios/.gitignore: -------------------------------------------------------------------------------- 1 | Pods 2 | *.xcworkspace 3 | xcuserdata 4 | -------------------------------------------------------------------------------- /examples/ios/Podfile: -------------------------------------------------------------------------------- 1 | platform :ios, '14.0' 2 | 3 | target 'TF_TTS_Demo' do 4 | pod 'TensorFlowLiteSwift' 5 | pod 'TensorFlowLiteSelectTfOps' 6 | end 7 | -------------------------------------------------------------------------------- /examples/ios/Podfile.lock: -------------------------------------------------------------------------------- 1 | PODS: 2 | - TensorFlowLiteC (2.4.0): 3 | - TensorFlowLiteC/Core (= 2.4.0) 4 | - TensorFlowLiteC/Core (2.4.0) 5 | - 
TensorFlowLiteSelectTfOps (2.4.0) 6 | - TensorFlowLiteSwift (2.4.0): 7 | - TensorFlowLiteSwift/Core (= 2.4.0) 8 | - TensorFlowLiteSwift/Core (2.4.0): 9 | - TensorFlowLiteC (= 2.4.0) 10 | 11 | DEPENDENCIES: 12 | - TensorFlowLiteSelectTfOps 13 | - TensorFlowLiteSwift 14 | 15 | SPEC REPOS: 16 | trunk: 17 | - TensorFlowLiteC 18 | - TensorFlowLiteSelectTfOps 19 | - TensorFlowLiteSwift 20 | 21 | SPEC CHECKSUMS: 22 | TensorFlowLiteC: 09f8ac75a76caeadb19bcfa694e97454cc1ecf87 23 | TensorFlowLiteSelectTfOps: f8053d3ec72032887b832d2d060015d8b7031144 24 | TensorFlowLiteSwift: f062dc1178120100d825d7799fd9f115b4a37fee 25 | 26 | PODFILE CHECKSUM: 12da12fb22671b6cdc578320043f5d310043aded 27 | 28 | COCOAPODS: 1.10.1 29 | -------------------------------------------------------------------------------- /examples/ios/README.md: -------------------------------------------------------------------------------- 1 | # iOS Demo 2 | 3 | This app demonstrates using FastSpeech2 and MB MelGAN models on iOS. 4 | 5 | ## How to build 6 | 7 | Download the LJ Speech TFLite models from https://github.com/luan78zaoha/TTS_tflite_cpp/releases/tag/0.1.0 and unpack them into the TF_TTS_Demo directory containing the Swift files. 8 | 9 | It uses [CocoaPods](https://cocoapods.org) to link with TensorFlowLiteSwift. 
10 | 11 | ``` 12 | pod install 13 | open TF_TTS_Demo.xcworkspace 14 | ``` 15 | 16 | -------------------------------------------------------------------------------- /examples/ios/TF_TTS_Demo/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /examples/ios/TF_TTS_Demo/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "iphone", 5 | "scale" : "2x", 6 | "size" : "20x20" 7 | }, 8 | { 9 | "idiom" : "iphone", 10 | "scale" : "3x", 11 | "size" : "20x20" 12 | }, 13 | { 14 | "idiom" : "iphone", 15 | "scale" : "2x", 16 | "size" : "29x29" 17 | }, 18 | { 19 | "idiom" : "iphone", 20 | "scale" : "3x", 21 | "size" : "29x29" 22 | }, 23 | { 24 | "idiom" : "iphone", 25 | "scale" : "2x", 26 | "size" : "40x40" 27 | }, 28 | { 29 | "idiom" : "iphone", 30 | "scale" : "3x", 31 | "size" : "40x40" 32 | }, 33 | { 34 | "idiom" : "iphone", 35 | "scale" : "2x", 36 | "size" : "60x60" 37 | }, 38 | { 39 | "idiom" : "iphone", 40 | "scale" : "3x", 41 | "size" : "60x60" 42 | }, 43 | { 44 | "idiom" : "ipad", 45 | "scale" : "1x", 46 | "size" : "20x20" 47 | }, 48 | { 49 | "idiom" : "ipad", 50 | "scale" : "2x", 51 | "size" : "20x20" 52 | }, 53 | { 54 | "idiom" : "ipad", 55 | "scale" : "1x", 56 | "size" : "29x29" 57 | }, 58 | { 59 | "idiom" : "ipad", 60 | "scale" : "2x", 61 | "size" : "29x29" 62 | }, 63 | { 64 | "idiom" : "ipad", 65 | "scale" : "1x", 66 | "size" : "40x40" 67 | }, 68 | { 69 | "idiom" : "ipad", 70 | "scale" : "2x", 71 | "size" : "40x40" 72 | }, 73 | { 74 | "idiom" : "ipad", 75 | "scale" : "1x", 76 | "size" : "76x76" 77 | }, 78 | { 79 | "idiom" : "ipad", 80 | "scale" : 
"2x", 81 | "size" : "76x76" 82 | }, 83 | { 84 | "idiom" : "ipad", 85 | "scale" : "2x", 86 | "size" : "83.5x83.5" 87 | }, 88 | { 89 | "idiom" : "ios-marketing", 90 | "scale" : "1x", 91 | "size" : "1024x1024" 92 | } 93 | ], 94 | "info" : { 95 | "author" : "xcode", 96 | "version" : 1 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /examples/ios/TF_TTS_Demo/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /examples/ios/TF_TTS_Demo/ContentView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ContentView.swift 3 | // TF TTS Demo 4 | // 5 | // Created by 안창범 on 2021/03/16. 6 | // 7 | 8 | import SwiftUI 9 | 10 | struct ContentView: View { 11 | @StateObject var tts = TTS() 12 | 13 | @State var text = "The Rhodes Must Fall campaigners said the announcement was hopeful, but warned they would remain cautious until the college had actually carried out the removal." 14 | 15 | var body: some View { 16 | VStack { 17 | TextEditor(text: $text) 18 | Button { 19 | tts.speak(string: text) 20 | } label: { 21 | Label("Speak", systemImage: "speaker.1") 22 | } 23 | } 24 | .padding() 25 | } 26 | } 27 | 28 | struct ContentView_Previews: PreviewProvider { 29 | static var previews: some View { 30 | ContentView() 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /examples/ios/TF_TTS_Demo/FastSpeech2.swift: -------------------------------------------------------------------------------- 1 | // 2 | // FastSpeech2.swift 3 | // HelloTensorFlowTTS 4 | // 5 | // Created by 안창범 on 2021/03/09. 
6 | // 7 | 8 | import Foundation 9 | import TensorFlowLite 10 | 11 | class FastSpeech2 { 12 | let interpreter: Interpreter 13 | 14 | var speakerId: Int32 = 0 15 | 16 | var f0Ratio: Float = 1 17 | 18 | var energyRatio: Float = 1 19 | 20 | init(url: URL) throws { 21 | var options = Interpreter.Options() 22 | options.threadCount = 5 23 | interpreter = try Interpreter(modelPath: url.path, options: options) 24 | } 25 | 26 | func getMelSpectrogram(inputIds: [Int32], speedRatio: Float) throws -> Tensor { 27 | try interpreter.resizeInput(at: 0, to: [1, inputIds.count]) 28 | try interpreter.allocateTensors() 29 | 30 | let data = inputIds.withUnsafeBufferPointer(Data.init) 31 | try interpreter.copy(data, toInputAt: 0) 32 | try interpreter.copy(Data(bytes: &speakerId, count: 4), toInputAt: 1) 33 | var speedRatio = speedRatio 34 | try interpreter.copy(Data(bytes: &speedRatio, count: 4), toInputAt: 2) 35 | try interpreter.copy(Data(bytes: &f0Ratio, count: 4), toInputAt: 3) 36 | try interpreter.copy(Data(bytes: &energyRatio, count: 4), toInputAt: 4) 37 | 38 | let t0 = Date() 39 | try interpreter.invoke() 40 | print("fastspeech2: \(Date().timeIntervalSince(t0))s") 41 | 42 | return try interpreter.output(at: 1) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /examples/ios/TF_TTS_Demo/Info.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | CFBundleDevelopmentRegion 6 | $(DEVELOPMENT_LANGUAGE) 7 | CFBundleExecutable 8 | $(EXECUTABLE_NAME) 9 | CFBundleIdentifier 10 | $(PRODUCT_BUNDLE_IDENTIFIER) 11 | CFBundleInfoDictionaryVersion 12 | 6.0 13 | CFBundleName 14 | $(PRODUCT_NAME) 15 | CFBundlePackageType 16 | $(PRODUCT_BUNDLE_PACKAGE_TYPE) 17 | CFBundleShortVersionString 18 | 1.0 19 | CFBundleVersion 20 | 1 21 | LSRequiresIPhoneOS 22 | 23 | UIApplicationSceneManifest 24 | 25 | UIApplicationSupportsMultipleScenes 26 | 27 | 28 | UIApplicationSupportsIndirectInputEvents 29 | 30 | 
UILaunchScreen 31 | 32 | UIRequiredDeviceCapabilities 33 | 34 | armv7 35 | 36 | UISupportedInterfaceOrientations 37 | 38 | UIInterfaceOrientationPortrait 39 | UIInterfaceOrientationLandscapeLeft 40 | UIInterfaceOrientationLandscapeRight 41 | 42 | UISupportedInterfaceOrientations~ipad 43 | 44 | UIInterfaceOrientationPortrait 45 | UIInterfaceOrientationPortraitUpsideDown 46 | UIInterfaceOrientationLandscapeLeft 47 | UIInterfaceOrientationLandscapeRight 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /examples/ios/TF_TTS_Demo/MBMelGAN.swift: -------------------------------------------------------------------------------- 1 | // 2 | // MBMelGAN.swift 3 | // HelloTensorFlowTTS 4 | // 5 | // Created by 안창범 on 2021/03/09. 6 | // 7 | 8 | import Foundation 9 | import TensorFlowLite 10 | 11 | class MBMelGan { 12 | let interpreter: Interpreter 13 | 14 | init(url: URL) throws { 15 | var options = Interpreter.Options() 16 | options.threadCount = 5 17 | interpreter = try Interpreter(modelPath: url.path, options: options) 18 | } 19 | 20 | func getAudio(input: Tensor) throws -> Data { 21 | try interpreter.resizeInput(at: 0, to: input.shape) 22 | try interpreter.allocateTensors() 23 | 24 | try interpreter.copy(input.data, toInputAt: 0) 25 | 26 | let t0 = Date() 27 | try interpreter.invoke() 28 | print("mbmelgan: \(Date().timeIntervalSince(t0))s") 29 | 30 | return try interpreter.output(at: 0).data 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /examples/ios/TF_TTS_Demo/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /examples/ios/TF_TTS_Demo/TF_TTS_DemoApp.swift: 
-------------------------------------------------------------------------------- 1 | // 2 | // TF_TTS_DemoApp.swift 3 | // TF TTS Demo 4 | // 5 | // Created by 안창범 on 2021/03/16. 6 | // 7 | 8 | import SwiftUI 9 | 10 | @main 11 | struct TF_TTS_DemoApp: App { 12 | var body: some Scene { 13 | WindowGroup { 14 | ContentView() 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /examples/melgan/fig/melgan.v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/melgan/fig/melgan.v1.png -------------------------------------------------------------------------------- /examples/melgan_stft/fig/melgan.stft.v1.eval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/melgan_stft/fig/melgan.stft.v1.eval.png -------------------------------------------------------------------------------- /examples/melgan_stft/fig/melgan.stft.v1.train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/melgan_stft/fig/melgan.stft.v1.train.png -------------------------------------------------------------------------------- /examples/mfa_extraction/README.md: -------------------------------------------------------------------------------- 1 | # MFA based extraction for FastSpeech 2 | 3 | ## Prepare 4 | Everything is done from main repo folder so TensorflowTTS/ 5 | 6 | 0. Optional* Modify MFA scripts to work with your language (https://montreal-forced-aligner.readthedocs.io/en/latest/pretrained_models.html) 7 | 8 | 1. 
Download the pretrained MFA model and lexicon, then extract the TextGrids: 9 | 10 | - ``` 11 | bash examples/mfa_extraction/scripts/prepare_mfa.sh 12 | ``` 13 | 14 | - ``` 15 | python examples/mfa_extraction/run_mfa.py \ 16 | --corpus_directory ./libritts \ 17 | --output_directory ./mfa/parsed \ 18 | --jobs 8 19 | ``` 20 | 21 | After this step, the TextGrids are located at `./mfa/parsed`. 22 | 23 | 2. Extract duration from textgrid files: 24 | - ``` 25 | python examples/mfa_extraction/txt_grid_parser.py \ 26 | --yaml_path examples/fastspeech2_libritts/conf/fastspeech2libritts.yaml \ 27 | --dataset_path ./libritts \ 28 | --text_grid_path ./mfa/parsed \ 29 | --output_durations_path ./libritts/durations \ 30 | --sample_rate 24000 31 | ``` 32 | 33 | - Dataset structure after finishing this step: 34 | ``` 35 | |- TensorFlowTTS/ 36 | | |- LibriTTS/ 37 | | |- |- train-clean-100/ 38 | | |- |- SPEAKERS.txt 39 | | |- |- ... 40 | | |- dataset/ 41 | | |- |- 200/ 42 | | |- |- |- 200_124139_000001_000000.txt 43 | | |- |- |- 200_124139_000001_000000.wav 44 | | |- |- |- ... 45 | | |- |- 250/ 46 | | |- |- ... 47 | | |- |- durations/ 48 | | |- |- train.txt 49 | | |- tensorflow_tts/ 50 | | |- models/ 51 | | |- ... 52 | ``` 53 | 3. Optional* add your own dataset parser based on tensorflow_tts/processor/experiment/example_dataset.py (if the base processor dataset doesn't match yours) 54 | 55 | 4. Run preprocess and normalization (steps 4 and 5 in `examples/fastspeech2_libritts/README.md`) 56 | 57 | 5. 
Run fix_mismatch.py to fix differences of a few frames between the audio and duration files: 58 | 59 | - ``` 60 | python examples/mfa_extraction/fix_mismatch.py \ 61 | --base_path ./dump \ 62 | --trimmed_dur_path ./dataset/trimmed-durations \ 63 | --dur_path ./dataset/durations 64 | ``` 65 | 66 | ## Problems with MFA extraction 67 | MFA seems to have problems with trimmed files; in my experiments it works better with ~100ms of silence at the start and end 68 | 69 | Short files can produce many false positives, such as silence-only extractions (LibriTTS example), so I would only use samples longer than 2s 70 | -------------------------------------------------------------------------------- /examples/mfa_extraction/requirements.txt: -------------------------------------------------------------------------------- 1 | textgrid 2 | click 3 | g2p_en -------------------------------------------------------------------------------- /examples/mfa_extraction/run_mfa.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2020 TensorFlowTTS Team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Runing mfa to extract textgrids.""" 16 | 17 | from subprocess import call 18 | from pathlib import Path 19 | 20 | import click 21 | import os 22 | 23 | 24 | @click.command() 25 | @click.option("--mfa_path", default=os.path.join('mfa', 'montreal-forced-aligner', 'bin', 'mfa_align')) 26 | @click.option("--corpus_directory", default="libritts") 27 | @click.option("--lexicon", default=os.path.join('mfa', 'lexicon', 'librispeech-lexicon.txt')) 28 | @click.option("--acoustic_model_path", default=os.path.join('mfa', 'montreal-forced-aligner', 'pretrained_models', 'english.zip')) 29 | @click.option("--output_directory", default=os.path.join('mfa', 'parsed')) 30 | @click.option("--jobs", default="8") 31 | def run_mfa( 32 | mfa_path: str, 33 | corpus_directory: str, 34 | lexicon: str, 35 | acoustic_model_path: str, 36 | output_directory: str, 37 | jobs: str, 38 | ): 39 | Path(output_directory).mkdir(parents=True, exist_ok=True) 40 | call( 41 | [ 42 | f".{os.path.sep}{mfa_path}", 43 | corpus_directory, 44 | lexicon, 45 | acoustic_model_path, 46 | output_directory, 47 | f"-j {jobs}" 48 | ] 49 | ) 50 | 51 | 52 | if __name__ == "__main__": 53 | run_mfa() 54 | -------------------------------------------------------------------------------- /examples/mfa_extraction/scripts/prepare_mfa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir mfa 3 | cd mfa 4 | wget https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/download/v1.1.0-beta.2/montreal-forced-aligner_linux.tar.gz 5 | tar -zxvf montreal-forced-aligner_linux.tar.gz 6 | cd mfa 7 | mkdir lexicon 8 | cd lexicon 9 | wget http://www.openslr.org/resources/11/librispeech-lexicon.txt -------------------------------------------------------------------------------- /examples/multiband_melgan/fig/eval.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/multiband_melgan/fig/eval.png -------------------------------------------------------------------------------- /examples/multiband_melgan/fig/train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/multiband_melgan/fig/train.png -------------------------------------------------------------------------------- /examples/multiband_melgan_hf/fig/eval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/multiband_melgan_hf/fig/eval.png -------------------------------------------------------------------------------- /examples/multiband_melgan_hf/fig/train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/multiband_melgan_hf/fig/train.png -------------------------------------------------------------------------------- /examples/tacotron2/fig/alignment.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/tacotron2/fig/alignment.gif -------------------------------------------------------------------------------- /examples/tacotron2/fig/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/tacotron2/fig/tensorboard.png -------------------------------------------------------------------------------- /notebooks/multiband_melgan_inference.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import yaml\n", 10 | "import numpy as np\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "\n", 13 | "import tensorflow as tf\n", 14 | "\n", 15 | "from tensorflow_tts.inference import AutoConfig\n", 16 | "from tensorflow_tts.inference import TFAutoModel" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "mb_melgan = TFAutoModel.from_pretrained(\"tensorspeech/tts-mb_melgan-ljspeech-en\")" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "# Save to Pb" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "tf.saved_model.save(mb_melgan, \"./mb_melgan\", signatures=mb_melgan.inference)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "# Load and Inference" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "mb_melgan = tf.saved_model.load(\"./mb_melgan\")" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "mels = np.load(\"../dump/valid/norm-feats/LJ001-0009-norm-feats.npy\")" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "audios = mb_melgan.inference(mels[None, ...])" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "plt.plot(audios[0, :, 0])" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 
90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [] 93 | } 94 | ], 95 | "metadata": { 96 | "kernelspec": { 97 | "display_name": "Python 3", 98 | "language": "python", 99 | "name": "python3" 100 | }, 101 | "language_info": { 102 | "codemirror_mode": { 103 | "name": "ipython", 104 | "version": 3 105 | }, 106 | "file_extension": ".py", 107 | "mimetype": "text/x-python", 108 | "name": "python", 109 | "nbconvert_exporter": "python", 110 | "pygments_lexer": "ipython3", 111 | "version": "3.7.7" 112 | } 113 | }, 114 | "nbformat": 4, 115 | "nbformat_minor": 4 116 | } 117 | -------------------------------------------------------------------------------- /preprocess/baker_preprocess.yaml: -------------------------------------------------------------------------------- 1 | ########################################################### 2 | # FEATURE EXTRACTION SETTING # 3 | ########################################################### 4 | sampling_rate: 24000 # Sampling rate. 5 | fft_size: 2048 # FFT size. 6 | hop_size: 300 # Hop size. (fixed value, don't change) 7 | win_length: 1200 # Window length. 8 | # If set to null, it will be the same as fft_size. 9 | window: "hann" # Window function. 10 | num_mels: 80 # Number of mel basis. 11 | fmin: 80 # Minimum freq in mel basis calculation. 12 | fmax: 7600 # Maximum frequency in mel basis calculation. 13 | global_gain_scale: 1.0 # Will be multiplied to all of waveform. 14 | trim_silence: true # Whether to trim the start and end of silence. 15 | trim_threshold_in_db: 60 # Need to tune carefully if the recording is not good. 16 | trim_frame_size: 2048 # Frame size in trimming. 17 | trim_hop_size: 512 # Hop size in trimming. 18 | format: "npy" # Feature file format. Only "npy" is supported. 
19 | 20 | -------------------------------------------------------------------------------- /preprocess/jsut_preprocess.yaml: -------------------------------------------------------------------------------- 1 | ########################################################### 2 | # FEATURE EXTRACTION SETTING # 3 | ########################################################### 4 | sampling_rate: 24000 # Sampling rate. 5 | fft_size: 2048 # FFT size. 6 | hop_size: 300 # Hop size. (fixed value, don't change) 7 | win_length: 1200 # Window length. 8 | # If set to null, it will be the same as fft_size. 9 | window: "hann" # Window function. 10 | num_mels: 80 # Number of mel basis. 11 | fmin: 80 # Minimum freq in mel basis calculation. 12 | fmax: 7600 # Maximum frequency in mel basis calculation. 13 | global_gain_scale: 1.0 # Will be multiplied to all of waveform. 14 | trim_silence: true # Whether to trim the start and end of silence. 15 | trim_threshold_in_db: 60 # Need to tune carefully if the recording is not good. 16 | trim_frame_size: 2048 # Frame size in trimming. 17 | trim_hop_size: 512 # Hop size in trimming. 18 | format: "npy" # Feature file format. Only "npy" is supported. 19 | 20 | -------------------------------------------------------------------------------- /preprocess/kss_preprocess.yaml: -------------------------------------------------------------------------------- 1 | ########################################################### 2 | # FEATURE EXTRACTION SETTING # 3 | ########################################################### 4 | sampling_rate: 22050 # Sampling rate. 5 | fft_size: 1024 # FFT size. 6 | hop_size: 256 # Hop size. (fixed value, don't change) 7 | win_length: null # Window length. 8 | # If set to null, it will be the same as fft_size. 9 | window: "hann" # Window function. 10 | num_mels: 80 # Number of mel basis. 11 | fmin: 80 # Minimum freq in mel basis calculation. 12 | fmax: 7600 # Maximum frequency in mel basis calculation. 
13 | global_gain_scale: 1.0 # Will be multiplied to all of waveform. 14 | trim_silence: true # Whether to trim the start and end of silence. 15 | trim_threshold_in_db: 30 # Need to tune carefully if the recording is not good. 16 | trim_frame_size: 2048 # Frame size in trimming. 17 | trim_hop_size: 512 # Hop size in trimming. 18 | format: "npy" # Feature file format. Only "npy" is supported. 19 | 20 | -------------------------------------------------------------------------------- /preprocess/libritts_preprocess.yaml: -------------------------------------------------------------------------------- 1 | ########################################################### 2 | # FEATURE EXTRACTION SETTING # 3 | ########################################################### 4 | sampling_rate: 24000 # Sampling rate. 5 | fft_size: 1024 # FFT size. 6 | hop_size: 300 # Hop size. (fixed value, don't change) 7 | win_length: null # Window length. 8 | # If set to null, it will be the same as fft_size. 9 | window: "hann" # Window function. 10 | num_mels: 80 # Number of mel basis. 11 | fmin: 80 # Minimum freq in mel basis calculation. 12 | fmax: 7600 # Maximum frequency in mel basis calculation. 13 | global_gain_scale: 1.0 # Will be multiplied to all of waveform. 14 | trim_silence: true # Whether to trim the start and end of silence. 15 | trim_threshold_in_db: 60 # Need to tune carefully if the recording is not good. 16 | trim_frame_size: 2048 # Frame size in trimming. 17 | trim_hop_size: 512 # Hop size in trimming. 18 | format: "npy" # Feature file format. Only "npy" is supported.
19 | trim_mfa: true 20 | 21 | -------------------------------------------------------------------------------- /preprocess/ljspeech_preprocess.yaml: -------------------------------------------------------------------------------- 1 | ########################################################### 2 | # FEATURE EXTRACTION SETTING # 3 | ########################################################### 4 | sampling_rate: 22050 # Sampling rate. 5 | fft_size: 1024 # FFT size. 6 | hop_size: 256 # Hop size. (fixed value, don't change) 7 | win_length: null # Window length. 8 | # If set to null, it will be the same as fft_size. 9 | window: "hann" # Window function. 10 | num_mels: 80 # Number of mel basis. 11 | fmin: 80 # Minimum freq in mel basis calculation. 12 | fmax: 7600 # Maximum frequency in mel basis calculation. 13 | global_gain_scale: 1.0 # Will be multiplied to all of waveform. 14 | trim_silence: true # Whether to trim the start and end of silence. 15 | trim_threshold_in_db: 60 # Need to tune carefully if the recording is not good. 16 | trim_frame_size: 2048 # Frame size in trimming. 17 | trim_hop_size: 512 # Hop size in trimming. 18 | format: "npy" # Feature file format. Only "npy" is supported. 19 | 20 | -------------------------------------------------------------------------------- /preprocess/ljspeechu_preprocess.yaml: -------------------------------------------------------------------------------- 1 | ########################################################### 2 | # FEATURE EXTRACTION SETTING # 3 | ########################################################### 4 | sampling_rate: 44100 # Sampling rate. 5 | fft_size: 2048 # FFT size. 6 | hop_size: 512 # Hop size. (fixed value, don't change) 7 | win_length: 2048 # Window length. 8 | # If set to null, it will be the same as fft_size. 9 | window: "hann" # Window function. 10 | num_mels: 80 # Number of mel basis. 11 | fmin: 20 # Minimum freq in mel basis calculation. 
12 | fmax: 11025 # Maximum frequency in mel basis calculation. 13 | global_gain_scale: 1.0 # Will be multiplied to all of waveform. 14 | trim_silence: false # Whether to trim the start and end of silence 15 | trim_threshold_in_db: 60 # Need to tune carefully if the recording is not good. 16 | trim_frame_size: 2048 # Frame size in trimming. 17 | trim_hop_size: 512 # Hop size in trimming. 18 | format: "npy" # Feature file format. Only "npy" is supported. 19 | trim_mfa: false -------------------------------------------------------------------------------- /preprocess/synpaflex_preprocess.yaml: -------------------------------------------------------------------------------- 1 | ########################################################### 2 | # FEATURE EXTRACTION SETTING # 3 | ########################################################### 4 | sampling_rate: 22050 # Sampling rate. 5 | fft_size: 1024 # FFT size. 6 | hop_size: 256 # Hop size. (fixed value, don't change) 7 | win_length: null # Window length. 8 | # If set to null, it will be the same as fft_size. 9 | window: "hann" # Window function. 10 | num_mels: 80 # Number of mel basis. 11 | fmin: 80 # Minimum freq in mel basis calculation. 12 | fmax: 7600 # Maximum frequency in mel basis calculation. 13 | global_gain_scale: 1.0 # Will be multiplied to all of waveform. 14 | trim_silence: true # Whether to trim the start and end of silence. 15 | trim_threshold_in_db: 20 # Need to tune carefully if the recording is not good. 16 | trim_frame_size: 2048 # Frame size in trimming. 17 | trim_hop_size: 512 # Hop size in trimming. 18 | format: "npy" # Feature file format. Only "npy" is supported. 
19 | 20 | -------------------------------------------------------------------------------- /preprocess/thorsten_preprocess.yaml: -------------------------------------------------------------------------------- 1 | ########################################################### 2 | # FEATURE EXTRACTION SETTING # 3 | ########################################################### 4 | sampling_rate: 22050 # Sampling rate. 5 | fft_size: 1024 # FFT size. 6 | hop_size: 256 # Hop size. (fixed value, don't change) 7 | win_length: null # Window length. 8 | # If set to null, it will be the same as fft_size. 9 | window: "hann" # Window function. 10 | num_mels: 80 # Number of mel basis. 11 | fmin: 80 # Minimum freq in mel basis calculation. 12 | fmax: 7600 # Maximum frequency in mel basis calculation. 13 | global_gain_scale: 1.0 # Will be multiplied to all of waveform. 14 | trim_silence: true # Whether to trim the start and end of silence. 15 | trim_threshold_in_db: 60 # Need to tune carefully if the recording is not good. 16 | trim_frame_size: 2048 # Frame size in trimming. 17 | trim_hop_size: 512 # Hop size in trimming. 18 | format: "npy" # Feature file format. Only "npy" is supported. 
19 | 20 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | addopts = --verbose --durations=0 6 | testpaths = test 7 | 8 | [flake8] 9 | ignore = H102,W504,H238,D104,H306,H405,D205 10 | # 120 is a workaround, 79 is good 11 | max-line-length = 120 12 | -------------------------------------------------------------------------------- /tensorflow_tts/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0" 2 | -------------------------------------------------------------------------------- /tensorflow_tts/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/tensorflow_tts/bin/__init__.py -------------------------------------------------------------------------------- /tensorflow_tts/configs/__init__.py: -------------------------------------------------------------------------------- 1 | from tensorflow_tts.configs.base_config import BaseConfig 2 | from tensorflow_tts.configs.fastspeech import FastSpeechConfig 3 | from tensorflow_tts.configs.fastspeech2 import FastSpeech2Config 4 | from tensorflow_tts.configs.melgan import ( 5 | MelGANDiscriminatorConfig, 6 | MelGANGeneratorConfig, 7 | ) 8 | from tensorflow_tts.configs.mb_melgan import ( 9 | MultiBandMelGANDiscriminatorConfig, 10 | MultiBandMelGANGeneratorConfig, 11 | ) 12 | from tensorflow_tts.configs.hifigan import ( 13 | HifiGANGeneratorConfig, 14 | HifiGANDiscriminatorConfig, 15 | ) 16 | from tensorflow_tts.configs.tacotron2 import Tacotron2Config 17 | from tensorflow_tts.configs.parallel_wavegan import ParallelWaveGANGeneratorConfig 18 | from tensorflow_tts.configs.parallel_wavegan import ParallelWaveGANDiscriminatorConfig 19 | 
class BaseConfig(abc.ABC):
    """Shared base of the config classes: keeps the raw config and can write it to disk."""

    def set_config_params(self, config_params):
        """Remember the raw configuration that `save_pretrained` serializes.

        Args:
            config_params: configuration object; stored as-is, without copying.
        """
        self.config_params = config_params

    def save_pretrained(self, saved_path):
        """Write the stored configuration as YAML into the `saved_path` directory.

        The directory is created when it does not exist yet and the file is
        named CONFIG_FILE_NAME. This class never sets `config_params` itself,
        so `set_config_params` (or a subclass) has to provide it beforehand.

        Args:
            saved_path: directory that receives the config file.
        """
        os.makedirs(saved_path, exist_ok=True)
        target_file = os.path.join(saved_path, CONFIG_FILE_NAME)
        with open(target_file, "w") as stream:
            yaml.dump(self.config_params, stream, Dumper=yaml.Dumper)
class FastSpeech2Config(FastSpeechConfig):
    """FastSpeech2 configuration: FastSpeechConfig plus variant-predictor settings."""

    def __init__(
        self,
        variant_prediction_num_conv_layers=2,
        variant_kernel_size=9,
        variant_dropout_rate=0.5,
        variant_predictor_filter=256,
        variant_predictor_kernel_size=3,
        variant_predictor_dropout_rate=0.5,
        **kwargs
    ):
        """Forward `kwargs` to FastSpeechConfig, then record the variant-predictor options.

        Args:
            variant_prediction_num_conv_layers: stored under the same name.
            variant_kernel_size: accepted but not stored on the instance.
            variant_dropout_rate: accepted but not stored on the instance.
            variant_predictor_filter: stored under the same name.
            variant_predictor_kernel_size: stored under the same name.
            variant_predictor_dropout_rate: stored under the same name.
            **kwargs: passed unchanged to `FastSpeechConfig.__init__`.
        """
        super().__init__(**kwargs)
        # Only the four options below end up as attributes; variant_kernel_size
        # and variant_dropout_rate are part of the signature but are dropped.
        self.variant_predictor_filter = variant_predictor_filter
        self.variant_predictor_kernel_size = variant_predictor_kernel_size
        self.variant_predictor_dropout_rate = variant_predictor_dropout_rate
        self.variant_prediction_num_conv_layers = variant_prediction_num_conv_layers
class HifiGANGeneratorConfig(BaseConfig):
    """Hyper-parameter container for the HifiGAN generator."""

    def __init__(
        self,
        out_channels=1,
        kernel_size=7,
        filters=128,
        use_bias=True,
        upsample_scales=[8, 8, 2, 2],
        stacks=3,
        stack_kernel_size=[3, 7, 11],
        stack_dilation_rate=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"alpha": 0.2},
        padding_type="REFLECT",
        use_final_nolinear_activation=True,
        is_weight_norm=True,
        initializer_seed=42,
        **kwargs
    ):
        """Init parameters for HifiGAN Generator model.

        Every named argument is stored on the instance under an attribute of
        the same name. List and dict arguments are deep-copied first, so that
        mutating an attribute of one config can never alter the shared default
        objects of this signature (or the caller's own containers).

        Args:
            **kwargs: accepted for call-site compatibility and ignored.
        """
        # Function-scope import: keeps this block self-contained without
        # touching the module's import section.
        import copy

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.filters = filters
        self.use_bias = use_bias
        self.upsample_scales = copy.deepcopy(upsample_scales)
        self.stacks = stacks
        self.stack_kernel_size = copy.deepcopy(stack_kernel_size)
        self.stack_dilation_rate = copy.deepcopy(stack_dilation_rate)
        self.nonlinear_activation = nonlinear_activation
        self.nonlinear_activation_params = copy.deepcopy(nonlinear_activation_params)
        self.padding_type = padding_type
        self.use_final_nolinear_activation = use_final_nolinear_activation
        self.is_weight_norm = is_weight_norm
        self.initializer_seed = initializer_seed


class HifiGANDiscriminatorConfig(object):
    """Hyper-parameter container for the HifiGAN discriminator."""

    def __init__(
        self,
        out_channels=1,
        period_scales=[2, 3, 5, 7, 11],
        n_layers=5,
        kernel_size=5,
        strides=3,
        filters=8,
        filter_scales=4,
        max_filters=1024,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"alpha": 0.2},
        is_weight_norm=True,
        initializer_seed=42,
        **kwargs
    ):
        """Init parameters for HifiGAN Discriminator model.

        Every named argument is stored on the instance under an attribute of
        the same name. List and dict arguments are deep-copied first, so that
        mutating an attribute of one config can never alter the shared default
        objects of this signature (or the caller's own containers).

        Args:
            **kwargs: accepted for call-site compatibility and ignored.
        """
        # Function-scope import: keeps this block self-contained without
        # touching the module's import section.
        import copy

        self.out_channels = out_channels
        self.period_scales = copy.deepcopy(period_scales)
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.strides = strides
        self.filters = filters
        self.filter_scales = filter_scales
        self.max_filters = max_filters
        self.nonlinear_activation = nonlinear_activation
        self.nonlinear_activation_params = copy.deepcopy(nonlinear_activation_params)
        self.is_weight_norm = is_weight_norm
        self.initializer_seed = initializer_seed
class MultiBandMelGANGeneratorConfig(MelGANGeneratorConfig):
    """Multi-band MelGAN generator config: MelGANGeneratorConfig plus four extra options."""

    def __init__(self, **kwargs):
        """Build the MelGAN generator config, then record the multi-band options.

        All keyword arguments (including the four options below) are handed to
        `MelGANGeneratorConfig.__init__` first. Afterwards `subbands`, `taps`,
        `cutoff_ratio` and `beta` are read from `kwargs`, falling back to
        4, 62, 0.142 and 9.0 respectively.
        """
        super().__init__(**kwargs)
        multiband_defaults = (
            ("subbands", 4),
            ("taps", 62),
            ("cutoff_ratio", 0.142),
            ("beta", 9.0),
        )
        for option, fallback in multiband_defaults:
            setattr(self, option, kwargs.get(option, fallback))


class MultiBandMelGANDiscriminatorConfig(MelGANDiscriminatorConfig):
    """Multi-band MelGAN discriminator config; adds nothing to MelGANDiscriminatorConfig."""

    def __init__(self, **kwargs):
        """Pass every keyword argument straight to `MelGANDiscriminatorConfig.__init__`."""
        super().__init__(**kwargs)
class ParallelWaveGANGeneratorConfig(BaseConfig):
    """Hyper-parameter container for the ParallelWaveGAN generator."""

    def __init__(
        self,
        out_channels=1,
        kernel_size=3,
        n_layers=30,
        stacks=3,
        residual_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80,
        aux_context_window=2,
        dropout_rate=0.0,
        use_bias=True,
        use_causal_conv=False,
        upsample_conditional_features=True,
        upsample_params={"upsample_scales": [4, 4, 4, 4]},
        initializer_seed=42,
        **kwargs,
    ):
        """Init parameters for ParallelWaveGAN Generator model.

        Every named argument is stored on the instance under an attribute of
        the same name. `upsample_params` is deep-copied first, so that mutating
        it on one config can never alter the shared default dict of this
        signature (or the caller's own dict).

        Args:
            **kwargs: accepted for call-site compatibility and ignored.
        """
        # Function-scope import: keeps this block self-contained without
        # touching the module's import section.
        import copy

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.stacks = stacks
        self.residual_channels = residual_channels
        self.gate_channels = gate_channels
        self.skip_channels = skip_channels
        self.aux_channels = aux_channels
        self.aux_context_window = aux_context_window
        self.dropout_rate = dropout_rate
        self.use_bias = use_bias
        self.use_causal_conv = use_causal_conv
        self.upsample_conditional_features = upsample_conditional_features
        self.upsample_params = copy.deepcopy(upsample_params)
        self.initializer_seed = initializer_seed


class ParallelWaveGANDiscriminatorConfig(object):
    """Hyper-parameter container for the ParallelWaveGAN discriminator."""

    def __init__(
        self,
        out_channels=1,
        kernel_size=3,
        n_layers=10,
        conv_channels=64,
        use_bias=True,
        dilation_factor=1,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"alpha": 0.2},
        initializer_seed=42,
        apply_sigmoid_at_last=False,
        **kwargs,
    ):
        """Init parameters for ParallelWaveGAN Discriminator model.

        Every named argument is stored on the instance under an attribute of
        the same name. `nonlinear_activation_params` is deep-copied first, so
        that mutating it on one config can never alter the shared default dict
        of this signature (or the caller's own dict).

        Args:
            **kwargs: accepted for call-site compatibility and ignored.
        """
        # Function-scope import: keeps this block self-contained without
        # touching the module's import section.
        import copy

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.conv_channels = conv_channels
        self.use_bias = use_bias
        self.dilation_factor = dilation_factor
        self.nonlinear_activation = nonlinear_activation
        self.nonlinear_activation_params = copy.deepcopy(nonlinear_activation_params)
        self.initializer_seed = initializer_seed
        self.apply_sigmoid_at_last = apply_sigmoid_at_last
class AbstractDataset(metaclass=abc.ABCMeta):
    """Abstract Dataset module for Dataset Loader.

    Subclasses provide the generator and its metadata; `create` turns them
    into a `tf.data.Dataset` pipeline.
    """

    @abc.abstractmethod
    def get_args(self):
        """Return args for generator function."""
        pass

    @abc.abstractmethod
    def generator(self):
        """Generator function, should have args from get_args function."""
        pass

    @abc.abstractmethod
    def get_output_dtypes(self):
        """Return output dtypes for each element from generator."""
        pass

    @abc.abstractmethod
    def get_len_dataset(self):
        """Return number of samples on dataset."""
        pass

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create tf.dataset function.

        Args:
            allow_cache: when true, cache the generated elements.
            batch_size: number of elements per batch.
            is_shuffle: when true, shuffle with a buffer of `get_len_dataset()` elements.
            map_fn: optional function mapped over the elements before batching.
            reshuffle_each_iteration: forwarded to `Dataset.shuffle`.

        Returns:
            The batched and prefetched `tf.data.Dataset`.

        Raises:
            ValueError: if `batch_size > 1` and no `map_fn` is given.
        """
        # Validate the arguments before any pipeline stage is built, so an
        # invalid call fails immediately without touching the generator hooks.
        if batch_size > 1 and map_fn is None:
            raise ValueError("map function must define when batch_size > 1.")

        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=(self.get_args())
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        if map_fn is not None:
            datasets = datasets.map(map_fn, tf.data.experimental.AUTOTUNE)

        datasets = datasets.batch(batch_size)
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)

        return datasets
# Maps the `model_type` value of a config file to the config class built for it.
CONFIG_MAPPING = OrderedDict(
    [
        ("fastspeech", FastSpeechConfig),
        ("fastspeech2", FastSpeech2Config),
        ("multiband_melgan_generator", MultiBandMelGANGeneratorConfig),
        ("melgan_generator", MelGANGeneratorConfig),
        ("hifigan_generator", HifiGANGeneratorConfig),
        ("tacotron2", Tacotron2Config),
        ("parallel_wavegan_generator", ParallelWaveGANGeneratorConfig),
    ]
)


class AutoConfig:
    """Factory that builds the matching config object from a config YAML file."""

    def __init__(self):
        """Always raise: instances are only meant to come from `from_pretrained`."""
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_path)` method."
        )

    @classmethod
    def from_pretrained(cls, pretrained_path, **kwargs):
        """Load a config file and instantiate the config class it names.

        Args:
            pretrained_path: path of a local config file; anything that is not
                an existing file is used as a hub repo id from which
                CONFIG_FILE_NAME is downloaded.
            **kwargs: extra keyword arguments passed to the config class
                together with the `<model_type>_params` section of the file.

        Returns:
            The config instance, with the full loaded config attached through
            `set_config_params`.

        Raises:
            ValueError: if reading `model_type`, looking it up in
                CONFIG_MAPPING or building the config object fails; the
                original exception is attached as `__cause__`.
        """
        # Not an existing local file: resolve it on the hub and use the
        # downloaded (cached) copy instead.
        if not os.path.isfile(pretrained_path):
            download_url = hf_hub_url(
                repo_id=pretrained_path, filename=CONFIG_FILE_NAME
            )

            pretrained_path = str(
                cached_download(
                    url=download_url,
                    library_name=LIBRARY_NAME,
                    library_version=VERSION,
                    cache_dir=CACHE_DIRECTORY,
                )
            )

        # SECURITY NOTE(review): yaml.Loader can construct arbitrary Python
        # objects and this file may have just been downloaded. Consider
        # yaml.SafeLoader if configs never rely on Python-specific tags.
        with open(pretrained_path) as f:
            config = yaml.load(f, Loader=yaml.Loader)

        try:
            model_type = config["model_type"]
            config_class = CONFIG_MAPPING[model_type]
            config_obj = config_class(**config[model_type + "_params"], **kwargs)
            config_obj.set_config_params(config)
            return config_obj
        except Exception as err:
            # Same exception type as before for callers; chaining keeps the
            # real failure (missing key, bad parameter, ...) visible.
            raise ValueError(
                "Unrecognized config in {}. "
                "Should have a `model_type` key in its config.yaml, or contain one of the following strings "
                "in its name: {}".format(
                    pretrained_path, ", ".join(CONFIG_MAPPING.keys())
                )
            ) from err
# Maps the `processor_name` value of a processor file to the processor class.
CONFIG_MAPPING = OrderedDict(
    [
        ("LJSpeechProcessor", LJSpeechProcessor),
        ("KSSProcessor", KSSProcessor),
        ("BakerProcessor", BakerProcessor),
        ("LibriTTSProcessor", LibriTTSProcessor),
        ("ThorstenProcessor", ThorstenProcessor),
        ("LJSpeechUltimateProcessor", LJSpeechUltimateProcessor),
        ("SynpaflexProcessor", SynpaflexProcessor),
        ("JSUTProcessor", JSUTProcessor),
    ]
)


class AutoProcessor:
    """Factory that builds the matching processor object from a processor JSON file."""

    def __init__(self):
        """Always raise: instances are only meant to come from `from_pretrained`."""
        raise EnvironmentError(
            "AutoProcessor is designed to be instantiated "
            "using the `AutoProcessor.from_pretrained(pretrained_path)` method."
        )

    @classmethod
    def from_pretrained(cls, pretrained_path, **kwargs):
        """Load a processor file and instantiate the processor class it names.

        Args:
            pretrained_path: path of a local JSON file; anything that is not an
                existing file is used as a hub repo id from which
                PROCESSOR_FILE_NAME is downloaded.
            **kwargs: accepted but currently not used.

        Returns:
            The processor instance, created with `data_dir=None` and
            `loaded_mapper_path` set to the resolved file path.

        Raises:
            ValueError: if reading `processor_name`, looking it up in
                CONFIG_MAPPING or building the processor fails; the original
                exception is attached as `__cause__`.
        """
        # Not an existing local file: resolve the processor file on the hub
        # and use the downloaded (cached) copy instead.
        if not os.path.isfile(pretrained_path):
            download_url = hf_hub_url(
                repo_id=pretrained_path, filename=PROCESSOR_FILE_NAME
            )

            pretrained_path = str(
                cached_download(
                    url=download_url,
                    library_name=LIBRARY_NAME,
                    library_version=VERSION,
                    cache_dir=CACHE_DIRECTORY,
                )
            )
        with open(pretrained_path, "r") as f:
            config = json.load(f)

        try:
            processor_name = config["processor_name"]
            processor_class = CONFIG_MAPPING[processor_name]
            processor = processor_class(
                data_dir=None, loaded_mapper_path=pretrained_path
            )
            return processor
        except Exception as err:
            # Same exception type as before for callers; chaining keeps the
            # real failure (missing key, constructor error, ...) visible.
            raise ValueError(
                "Unrecognized processor in {}. "
                "Should have a `processor_name` key in its config.json, or contain one of the following strings "
                "in its name: {}".format(
                    pretrained_path, ", ".join(CONFIG_MAPPING.keys())
                )
            ) from err
class SavableTFTacotron2(TFTacotron2):
    """TFTacotron2 whose ``call`` forwards to the parent's ``inference``."""

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

    def call(self, inputs, training=False):
        """Run ``inference`` on ``[input_ids, input_lengths, speaker_ids]``.

        ``training`` is accepted but not used.
        """
        char_ids, char_lengths, spk_ids = inputs
        return super().inference(char_ids, char_lengths, spk_ids)

    def _build(self):
        """Call the model once on a fixed single-utterance dummy batch."""
        dummy_batch = [
            tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9]], dtype=tf.int32),
            tf.convert_to_tensor([9], dtype=tf.int32),
            tf.convert_to_tensor([0], dtype=tf.int32),
        ]
        self(dummy_batch)


class SavableTFFastSpeech(TFFastSpeech):
    """TFFastSpeech whose ``call`` forwards to the parent's ``_inference``."""

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

    def call(self, inputs, training=False):
        """Run ``_inference`` on ``[input_ids, speaker_ids, speed_ratios]``.

        ``training` is accepted but not used.
        """
        char_ids, spk_ids, speeds = inputs
        return super()._inference(char_ids, spk_ids, speeds)

    def _build(self):
        """Call the model once on a fixed single-utterance dummy batch."""
        dummy_batch = [
            tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], tf.int32),
            tf.convert_to_tensor([0], tf.int32),
            tf.convert_to_tensor([1.0], tf.float32),
        ]
        self(dummy_batch)


class SavableTFFastSpeech2(TFFastSpeech2):
    """TFFastSpeech2 whose ``call`` forwards to the parent's ``_inference``."""

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

    def call(self, inputs, training=False):
        """Run ``_inference`` on the five-element input list.

        The list is ``[input_ids, speaker_ids, speed_ratios, f0_ratios,
        energy_ratios]``. ``training`` is accepted but not used.
        """
        char_ids, spk_ids, speeds, f0_scales, energy_scales = inputs
        return super()._inference(
            char_ids, spk_ids, speeds, f0_scales, energy_scales
        )

    def _build(self):
        """Call the model once on a fixed single-utterance dummy batch."""
        dummy_batch = [
            tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], tf.int32),
            tf.convert_to_tensor([0], tf.int32),
            tf.convert_to_tensor([1.0], tf.float32),
            tf.convert_to_tensor([1.0], tf.float32),
            tf.convert_to_tensor([1.0], tf.float32),
        ]
        self(dummy_batch)
class TFMelSpectrogram(tf.keras.layers.Layer):
    """Mel Spectrogram loss."""

    def __init__(
        self,
        n_mels=80,
        f_min=80.0,
        f_max=7600,
        frame_length=1024,
        frame_step=256,
        fft_length=1024,
        sample_rate=16000,
        **kwargs
    ):
        """Store the STFT settings and precompute the linear-to-mel matrix."""
        super().__init__(**kwargs)
        self.frame_length = frame_length
        self.frame_step = frame_step
        self.fft_length = fft_length

        # A real FFT of length N yields N // 2 + 1 frequency bins.
        num_spectrogram_bins = fft_length // 2 + 1
        self.linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
            n_mels, num_spectrogram_bins, sample_rate, f_min, f_max
        )

    def _calculate_log_mels_spectrogram(self, signals):
        """Compute the log-mel spectrogram of a batch of waveforms.

        Args:
            signals (Tensor): signal (B, T).
        Returns:
            Tensor: Log-mel spectrogram (B, T', n_mels).
        """
        magnitudes = tf.abs(
            tf.signal.stft(
                signals,
                frame_length=self.frame_length,
                frame_step=self.frame_step,
                fft_length=self.fft_length,
            )
        )
        mel_basis = self.linear_to_mel_weight_matrix
        mels = tf.tensordot(magnitudes, mel_basis, 1)
        # tensordot drops the static shape; restore it from the operands.
        mels.set_shape(magnitudes.shape[:-1].concatenate(mel_basis.shape[-1:]))
        return tf.math.log(mels + 1e-6)  # prevent nan.

    def call(self, y, x):
        """Compute the per-example L1 distance between log-mel spectrograms.

        Args:
            y (Tensor): Groundtruth signal (B, T).
            x (Tensor): Predicted signal (B, T).
        Returns:
            Tensor: Mean absolute error averaged over every axis except the
                first, i.e. one value per batch element.
        """
        target_mels = self._calculate_log_mels_spectrogram(y)
        predicted_mels = self._calculate_log_mels_spectrogram(x)
        non_batch_axes = list(range(1, len(predicted_mels.shape)))
        return tf.reduce_mean(
            tf.abs(target_mels - predicted_mels), axis=non_batch_axes
        )
class BaseModel(tf.keras.Model):
    """Keras model base that carries a config object and can save itself."""

    def set_config(self, config):
        """Attach ``config`` to the model as ``self.config``."""
        self.config = config

    def save_pretrained(self, saved_path):
        """Save config and weights to file.

        Creates ``saved_path`` if needed, asks ``self.config`` to save itself
        there, then writes the weights to ``MODEL_FILE_NAME`` inside it.
        ``set_config`` must have been called first, since ``self.config`` is
        read here.
        """
        os.makedirs(saved_path, exist_ok=True)
        weights_file = os.path.join(saved_path, MODEL_FILE_NAME)
        self.config.save_pretrained(saved_path)
        self.save_weights(weights_file)
tensorflow_tts.processor.ljspeechu import LJSpeechUltimateProcessor 9 | from tensorflow_tts.processor.synpaflex import SynpaflexProcessor 10 | from tensorflow_tts.processor.jsut import JSUTProcessor 11 | -------------------------------------------------------------------------------- /tensorflow_tts/processor/pretrained/jsut_mapper.json: -------------------------------------------------------------------------------- 1 | { 2 | "symbol_to_id": { 3 | "pad": 0, 4 | "sil": 1, 5 | "N": 2, 6 | "a": 3, 7 | "b": 4, 8 | "by": 5, 9 | "ch": 6, 10 | "cl": 7, 11 | "d": 8, 12 | "dy": 9, 13 | "e": 10, 14 | "f": 11, 15 | "g": 12, 16 | "gy": 13, 17 | "h": 14, 18 | "hy": 15, 19 | "i": 16, 20 | "j": 17, 21 | "k": 18, 22 | "ky": 19, 23 | "m": 20, 24 | "my": 21, 25 | "n": 22, 26 | "ny": 23, 27 | "o": 24, 28 | "p": 25, 29 | "pau": 26, 30 | "py": 27, 31 | "r": 28, 32 | "ry": 29, 33 | "s": 30, 34 | "sh": 31, 35 | "t": 32, 36 | "ts": 33, 37 | "u": 34, 38 | "v": 35, 39 | "w": 36, 40 | "y": 37, 41 | "z": 38, 42 | "eos": 39 43 | }, 44 | "id_to_symbol": { 45 | "0": "pad", 46 | "1": "sil", 47 | "2": "N", 48 | "3": "a", 49 | "4": "b", 50 | "5": "by", 51 | "6": "ch", 52 | "7": "cl", 53 | "8": "d", 54 | "9": "dy", 55 | "10": "e", 56 | "11": "f", 57 | "12": "g", 58 | "13": "gy", 59 | "14": "h", 60 | "15": "hy", 61 | "16": "i", 62 | "17": "j", 63 | "18": "k", 64 | "19": "ky", 65 | "20": "m", 66 | "21": "my", 67 | "22": "n", 68 | "23": "ny", 69 | "24": "o", 70 | "25": "p", 71 | "26": "pau", 72 | "27": "py", 73 | "28": "r", 74 | "29": "ry", 75 | "30": "s", 76 | "31": "sh", 77 | "32": "t", 78 | "33": "ts", 79 | "34": "u", 80 | "35": "v", 81 | "36": "w", 82 | "37": "y", 83 | "38": "z", 84 | "39": "eos" 85 | }, 86 | "speakers_map": { 87 | "jsut": 0 88 | }, 89 | "processor_name": "JSUTProcessor" 90 | } -------------------------------------------------------------------------------- /tensorflow_tts/processor/pretrained/kss_mapper.json: 
-------------------------------------------------------------------------------- 1 | {"symbol_to_id": {"pad": 0, "-": 7, "!": 2, "'": 3, "(": 4, ")": 5, ",": 6, ".": 8, ":": 9, ";": 10, "?": 11, " ": 12, "\u1100": 13, "\u1101": 14, "\u1102": 15, "\u1103": 16, "\u1104": 17, "\u1105": 18, "\u1106": 19, "\u1107": 20, "\u1108": 21, "\u1109": 22, "\u110a": 23, "\u110b": 24, "\u110c": 25, "\u110d": 26, "\u110e": 27, "\u110f": 28, "\u1110": 29, "\u1111": 30, "\u1112": 31, "\u1161": 32, "\u1162": 33, "\u1163": 34, "\u1164": 35, "\u1165": 36, "\u1166": 37, "\u1167": 38, "\u1168": 39, "\u1169": 40, "\u116a": 41, "\u116b": 42, "\u116c": 43, "\u116d": 44, "\u116e": 45, "\u116f": 46, "\u1170": 47, "\u1171": 48, "\u1172": 49, "\u1173": 50, "\u1174": 51, "\u1175": 52, "\u11a8": 53, "\u11a9": 54, "\u11aa": 55, "\u11ab": 56, "\u11ac": 57, "\u11ad": 58, "\u11ae": 59, "\u11af": 60, "\u11b0": 61, "\u11b1": 62, "\u11b2": 63, "\u11b3": 64, "\u11b4": 65, "\u11b5": 66, "\u11b6": 67, "\u11b7": 68, "\u11b8": 69, "\u11b9": 70, "\u11ba": 71, "\u11bb": 72, "\u11bc": 73, "\u11bd": 74, "\u11be": 75, "\u11bf": 76, "\u11c0": 77, "\u11c1": 78, "\u11c2": 79, "eos": 80}, "id_to_symbol": {"0": "pad", "1": "-", "2": "!", "3": "'", "4": "(", "5": ")", "6": ",", "7": "-", "8": ".", "9": ":", "10": ";", "11": "?", "12": " ", "13": "\u1100", "14": "\u1101", "15": "\u1102", "16": "\u1103", "17": "\u1104", "18": "\u1105", "19": "\u1106", "20": "\u1107", "21": "\u1108", "22": "\u1109", "23": "\u110a", "24": "\u110b", "25": "\u110c", "26": "\u110d", "27": "\u110e", "28": "\u110f", "29": "\u1110", "30": "\u1111", "31": "\u1112", "32": "\u1161", "33": "\u1162", "34": "\u1163", "35": "\u1164", "36": "\u1165", "37": "\u1166", "38": "\u1167", "39": "\u1168", "40": "\u1169", "41": "\u116a", "42": "\u116b", "43": "\u116c", "44": "\u116d", "45": "\u116e", "46": "\u116f", "47": "\u1170", "48": "\u1171", "49": "\u1172", "50": "\u1173", "51": "\u1174", "52": "\u1175", "53": "\u11a8", "54": "\u11a9", "55": "\u11aa", "56": 
"\u11ab", "57": "\u11ac", "58": "\u11ad", "59": "\u11ae", "60": "\u11af", "61": "\u11b0", "62": "\u11b1", "63": "\u11b2", "64": "\u11b3", "65": "\u11b4", "66": "\u11b5", "67": "\u11b6", "68": "\u11b7", "69": "\u11b8", "70": "\u11b9", "71": "\u11ba", "72": "\u11bb", "73": "\u11bc", "74": "\u11bd", "75": "\u11be", "76": "\u11bf", "77": "\u11c0", "78": "\u11c1", "79": "\u11c2", "80": "eos"}, "speakers_map": {"kss": 0}, "processor_name": "KSSProcessor"} -------------------------------------------------------------------------------- /tensorflow_tts/processor/pretrained/libritts_mapper.json: -------------------------------------------------------------------------------- 1 | {"symbol_to_id": {"@": 0, "@": 1, "@": 2, "@": 3, "@AA0": 4, "@AA1": 5, "@AA2": 6, "@AE0": 7, "@AE1": 8, "@AE2": 9, "@AH0": 10, "@AH1": 11, "@AH2": 12, "@AO0": 13, "@AO1": 14, "@AO2": 15, "@AW0": 16, "@AW1": 17, "@AW2": 18, "@AY0": 19, "@AY1": 20, "@AY2": 21, "@B": 22, "@CH": 23, "@D": 24, "@DH": 25, "@EH0": 26, "@EH1": 27, "@EH2": 28, "@ER0": 29, "@ER1": 30, "@ER2": 31, "@EY0": 32, "@EY1": 33, "@EY2": 34, "@F": 35, "@G": 36, "@HH": 37, "@IH0": 38, "@IH1": 39, "@IH2": 40, "@IY0": 41, "@IY1": 42, "@IY2": 43, "@JH": 44, "@K": 45, "@L": 46, "@M": 47, "@N": 48, "@NG": 49, "@OW0": 50, "@OW1": 51, "@OW2": 52, "@OY0": 53, "@OY1": 54, "@OY2": 55, "@P": 56, "@R": 57, "@S": 58, "@SH": 59, "@T": 60, "@TH": 61, "@UH0": 62, "@UH1": 63, "@UH2": 64, "@UW": 65, "@UW0": 66, "@UW1": 67, "@UW2": 68, "@V": 69, "@W": 70, "@Y": 71, "@Z": 72, "@ZH": 73, "@SIL": 74, "@END": 75, "!": 76, "'": 77, "(": 78, ")": 79, ",": 80, ".": 81, ":": 82, ";": 83, "?": 84, " ": 85}, "id_to_symbol": {"0": "@", "1": "@", "2": "@", "3": "@", "4": "@AA0", "5": "@AA1", "6": "@AA2", "7": "@AE0", "8": "@AE1", "9": "@AE2", "10": "@AH0", "11": "@AH1", "12": "@AH2", "13": "@AO0", "14": "@AO1", "15": "@AO2", "16": "@AW0", "17": "@AW1", "18": "@AW2", "19": "@AY0", "20": "@AY1", "21": "@AY2", "22": "@B", "23": "@CH", "24": "@D", "25": "@DH", "26": 
"@EH0", "27": "@EH1", "28": "@EH2", "29": "@ER0", "30": "@ER1", "31": "@ER2", "32": "@EY0", "33": "@EY1", "34": "@EY2", "35": "@F", "36": "@G", "37": "@HH", "38": "@IH0", "39": "@IH1", "40": "@IH2", "41": "@IY0", "42": "@IY1", "43": "@IY2", "44": "@JH", "45": "@K", "46": "@L", "47": "@M", "48": "@N", "49": "@NG", "50": "@OW0", "51": "@OW1", "52": "@OW2", "53": "@OY0", "54": "@OY1", "55": "@OY2", "56": "@P", "57": "@R", "58": "@S", "59": "@SH", "60": "@T", "61": "@TH", "62": "@UH0", "63": "@UH1", "64": "@UH2", "65": "@UW", "66": "@UW0", "67": "@UW1", "68": "@UW2", "69": "@V", "70": "@W", "71": "@Y", "72": "@Z", "73": "@ZH", "74": "@SIL", "75": "@END", "76": "!", "77": "'", "78": "(", "79": ")", "80": ",", "81": ".", "82": ":", "83": ";", "84": "?", "85": " "}, "speakers_map": {"200": 0, "1841": 1, "3664": 2, "6454": 3, "8108": 4, "2416": 5, "4680": 6, "6147": 7, "412": 8, "2952": 9, "8838": 10, "2836": 11, "1263": 12, "5322": 13, "3830": 14, "7447": 15, "1116": 16, "8312": 17, "8123": 18, "250": 19}, "processor_name": "LibriTTSProcessor"} -------------------------------------------------------------------------------- /tensorflow_tts/processor/pretrained/ljspeech_mapper.json: -------------------------------------------------------------------------------- 1 | {"symbol_to_id": {"pad": 0, "-": 1, "!": 2, "'": 3, "(": 4, ")": 5, ",": 6, ".": 7, ":": 8, ";": 9, "?": 10, " ": 11, "A": 12, "B": 13, "C": 14, "D": 15, "E": 16, "F": 17, "G": 18, "H": 19, "I": 20, "J": 21, "K": 22, "L": 23, "M": 24, "N": 25, "O": 26, "P": 27, "Q": 28, "R": 29, "S": 30, "T": 31, "U": 32, "V": 33, "W": 34, "X": 35, "Y": 36, "Z": 37, "a": 38, "b": 39, "c": 40, "d": 41, "e": 42, "f": 43, "g": 44, "h": 45, "i": 46, "j": 47, "k": 48, "l": 49, "m": 50, "n": 51, "o": 52, "p": 53, "q": 54, "r": 55, "s": 56, "t": 57, "u": 58, "v": 59, "w": 60, "x": 61, "y": 62, "z": 63, "@AA": 64, "@AA0": 65, "@AA1": 66, "@AA2": 67, "@AE": 68, "@AE0": 69, "@AE1": 70, "@AE2": 71, "@AH": 72, "@AH0": 73, "@AH1": 74, 
"@AH2": 75, "@AO": 76, "@AO0": 77, "@AO1": 78, "@AO2": 79, "@AW": 80, "@AW0": 81, "@AW1": 82, "@AW2": 83, "@AY": 84, "@AY0": 85, "@AY1": 86, "@AY2": 87, "@B": 88, "@CH": 89, "@D": 90, "@DH": 91, "@EH": 92, "@EH0": 93, "@EH1": 94, "@EH2": 95, "@ER": 96, "@ER0": 97, "@ER1": 98, "@ER2": 99, "@EY": 100, "@EY0": 101, "@EY1": 102, "@EY2": 103, "@F": 104, "@G": 105, "@HH": 106, "@IH": 107, "@IH0": 108, "@IH1": 109, "@IH2": 110, "@IY": 111, "@IY0": 112, "@IY1": 113, "@IY2": 114, "@JH": 115, "@K": 116, "@L": 117, "@M": 118, "@N": 119, "@NG": 120, "@OW": 121, "@OW0": 122, "@OW1": 123, "@OW2": 124, "@OY": 125, "@OY0": 126, "@OY1": 127, "@OY2": 128, "@P": 129, "@R": 130, "@S": 131, "@SH": 132, "@T": 133, "@TH": 134, "@UH": 135, "@UH0": 136, "@UH1": 137, "@UH2": 138, "@UW": 139, "@UW0": 140, "@UW1": 141, "@UW2": 142, "@V": 143, "@W": 144, "@Y": 145, "@Z": 146, "@ZH": 147, "eos": 148}, "id_to_symbol": {"0": "pad", "1": "-", "2": "!", "3": "'", "4": "(", "5": ")", "6": ",", "7": ".", "8": ":", "9": ";", "10": "?", "11": " ", "12": "A", "13": "B", "14": "C", "15": "D", "16": "E", "17": "F", "18": "G", "19": "H", "20": "I", "21": "J", "22": "K", "23": "L", "24": "M", "25": "N", "26": "O", "27": "P", "28": "Q", "29": "R", "30": "S", "31": "T", "32": "U", "33": "V", "34": "W", "35": "X", "36": "Y", "37": "Z", "38": "a", "39": "b", "40": "c", "41": "d", "42": "e", "43": "f", "44": "g", "45": "h", "46": "i", "47": "j", "48": "k", "49": "l", "50": "m", "51": "n", "52": "o", "53": "p", "54": "q", "55": "r", "56": "s", "57": "t", "58": "u", "59": "v", "60": "w", "61": "x", "62": "y", "63": "z", "64": "@AA", "65": "@AA0", "66": "@AA1", "67": "@AA2", "68": "@AE", "69": "@AE0", "70": "@AE1", "71": "@AE2", "72": "@AH", "73": "@AH0", "74": "@AH1", "75": "@AH2", "76": "@AO", "77": "@AO0", "78": "@AO1", "79": "@AO2", "80": "@AW", "81": "@AW0", "82": "@AW1", "83": "@AW2", "84": "@AY", "85": "@AY0", "86": "@AY1", "87": "@AY2", "88": "@B", "89": "@CH", "90": "@D", "91": "@DH", "92": "@EH", "93": 
"@EH0", "94": "@EH1", "95": "@EH2", "96": "@ER", "97": "@ER0", "98": "@ER1", "99": "@ER2", "100": "@EY", "101": "@EY0", "102": "@EY1", "103": "@EY2", "104": "@F", "105": "@G", "106": "@HH", "107": "@IH", "108": "@IH0", "109": "@IH1", "110": "@IH2", "111": "@IY", "112": "@IY0", "113": "@IY1", "114": "@IY2", "115": "@JH", "116": "@K", "117": "@L", "118": "@M", "119": "@N", "120": "@NG", "121": "@OW", "122": "@OW0", "123": "@OW1", "124": "@OW2", "125": "@OY", "126": "@OY0", "127": "@OY1", "128": "@OY2", "129": "@P", "130": "@R", "131": "@S", "132": "@SH", "133": "@T", "134": "@TH", "135": "@UH", "136": "@UH0", "137": "@UH1", "138": "@UH2", "139": "@UW", "140": "@UW0", "141": "@UW1", "142": "@UW2", "143": "@V", "144": "@W", "145": "@Y", "146": "@Z", "147": "@ZH", "148": "eos"}, "speakers_map": {"ljspeech": 0}, "processor_name": "LJSpeechProcessor"} -------------------------------------------------------------------------------- /tensorflow_tts/processor/pretrained/ljspeechu_mapper.json: -------------------------------------------------------------------------------- 1 | {"symbol_to_id": {"padeos": 95}, "id_to_symbol": {"0": "pad", "1": "-", "2": "!", "3": "'", "4": "(", "5": ")", "6": ",", "7": ".", "8": ":", "9": ";", "10": "?", "11": "@AA", "12": "@AA0", "13": "@AA1", "14": "@AA2", "15": "@AE", "16": "@AE0", "17": "@AE1", "18": "@AE2", "19": "@AH", "20": "@AH0", "21": "@AH1", "22": "@AH2", "23": "@AO", "24": "@AO0", "25": "@AO1", "26": "@AO2", "27": "@AW", "28": "@AW0", "29": "@AW1", "30": "@AW2", "31": "@AY", "32": "@AY0", "33": "@AY1", "34": "@AY2", "35": "@B", "36": "@CH", "37": "@D", "38": "@DH", "39": "@EH", "40": "@EH0", "41": "@EH1", "42": "@EH2", "43": "@ER", "44": "@ER0", "45": "@ER1", "46": "@ER2", "47": "@EY", "48": "@EY0", "49": "@EY1", "50": "@EY2", "51": "@F", "52": "@G", "53": "@HH", "54": "@IH", "55": "@IH0", "56": "@IH1", "57": "@IH2", "58": "@IY", "59": "@IY0", "60": "@IY1", "61": "@IY2", "62": "@JH", "63": "@K", "64": "@L", "65": "@M", "66": "@N", 
"67": "@NG", "68": "@OW", "69": "@OW0", "70": "@OW1", "71": "@OW2", "72": "@OY", "73": "@OY0", "74": "@OY1", "75": "@OY2", "76": "@P", "77": "@R", "78": "@S", "79": "@SH", "80": "@T", "81": "@TH", "82": "@UH", "83": "@UH0", "84": "@UH1", "85": "@UH2", "86": "@UW", "87": "@UW0", "88": "@UW1", "89": "@UW2", "90": "@V", "91": "@W", "92": "@Y", "93": "@Z", "94": "@ZH", "95": "eos"}, "speakers_map": {"ljspeech": 0}, "processor_name": "LJSpeechUltimateProcessor"} -------------------------------------------------------------------------------- /tensorflow_tts/processor/pretrained/synpaflex_mapper.json: -------------------------------------------------------------------------------- 1 | {"symbol_to_id": {"pad": 0, "!": 1, "/": 2, "'": 3, "(": 4, ")": 5, ",": 6, "-": 7, ".": 8, ":": 9, ";": 10, "?": 11, " ": 12, "A": 13, "B": 14, "C": 15, "D": 16, "E": 17, "F": 18, "G": 19, "H": 20, "I": 21, "J": 22, "K": 23, "L": 24, "M": 25, "N": 26, "O": 27, "P": 28, "Q": 29, "R": 30, "S": 31, "T": 32, "U": 33, "V": 34, "W": 35, "X": 36, "Y": 37, "Z": 38, "a": 39, "b": 40, "c": 41, "d": 42, "e": 43, "f": 44, "g": 45, "h": 46, "i": 47, "j": 48, "k": 49, "l": 50, "m": 51, "n": 52, "o": 53, "p": 54, "q": 55, "r": 56, "s": 57, "t": 58, "u": 59, "v": 60, "w": 61, "x": 62, "y": 63, "z": 64, "\u00e9": 65, "\u00e8": 66, "\u00e0": 67, "\u00f9": 68, "\u00e2": 69, "\u00ea": 70, "\u00ee": 71, "\u00f4": 72, "\u00fb": 73, "\u00e7": 74, "\u00e4": 75, "\u00eb": 76, "\u00ef": 77, "\u00f6": 78, "\u00fc": 79, "\u00ff": 80, "\u0153": 81, "\u00e6": 82, "eos": 83}, "id_to_symbol": {"0": "pad", "1": "!", "2": "/", "3": "'", "4": "(", "5": ")", "6": ",", "7": "-", "8": ".", "9": ":", "10": ";", "11": "?", "12": " ", "13": "A", "14": "B", "15": "C", "16": "D", "17": "E", "18": "F", "19": "G", "20": "H", "21": "I", "22": "J", "23": "K", "24": "L", "25": "M", "26": "N", "27": "O", "28": "P", "29": "Q", "30": "R", "31": "S", "32": "T", "33": "U", "34": "V", "35": "W", "36": "X", "37": "Y", "38": "Z", "39": "a", 
"40": "b", "41": "c", "42": "d", "43": "e", "44": "f", "45": "g", "46": "h", "47": "i", "48": "j", "49": "k", "50": "l", "51": "m", "52": "n", "53": "o", "54": "p", "55": "q", "56": "r", "57": "s", "58": "t", "59": "u", "60": "v", "61": "w", "62": "x", "63": "y", "64": "z", "65": "\u00e9", "66": "\u00e8", "67": "\u00e0", "68": "\u00f9", "69": "\u00e2", "70": "\u00ea", "71": "\u00ee", "72": "\u00f4", "73": "\u00fb", "74": "\u00e7", "75": "\u00e4", "76": "\u00eb", "77": "\u00ef", "78": "\u00f6", "79": "\u00fc", "80": "\u00ff", "81": "\u0153", "82": "\u00e6", "83": "eos"}, "speakers_map": {"synpaflex": 0}, "processor_name": "SynpaflexProcessor"} 2 | -------------------------------------------------------------------------------- /tensorflow_tts/processor/pretrained/thorsten_mapper.json: -------------------------------------------------------------------------------- 1 | {"symbol_to_id": {"pad": 0, "-": 1, "!": 2, "'": 3, "(": 4, ")": 5, ",": 6, ".": 7, "?": 8, " ": 9, "A": 10, "B": 11, "C": 12, "D": 13, "E": 14, "F": 15, "G": 16, "H": 17, "I": 18, "J": 19, "K": 20, "L": 21, "M": 22, "N": 23, "O": 24, "P": 25, "Q": 26, "R": 27, "S": 28, "T": 29, "U": 30, "V": 31, "W": 32, "X": 33, "Y": 34, "Z": 35, "a": 36, "b": 37, "c": 38, "d": 39, "e": 40, "f": 41, "g": 42, "h": 43, "i": 44, "j": 45, "k": 46, "l": 47, "m": 48, "n": 49, "o": 50, "p": 51, "q": 52, "r": 53, "s": 54, "t": 55, "u": 56, "v": 57, "w": 58, "x": 59, "y": 60, "z": 61, "eos": 62}, "id_to_symbol": {"0": "pad", "1": "-", "2": "!", "3": "'", "4": "(", "5": ")", "6": ",", "7": ".", "8": "?", "9": " ", "10": "A", "11": "B", "12": "C", "13": "D", "14": "E", "15": "F", "16": "G", "17": "H", "18": "I", "19": "J", "20": "K", "21": "L", "22": "M", "23": "N", "24": "O", "25": "P", "26": "Q", "27": "R", "28": "S", "29": "T", "30": "U", "31": "V", "32": "W", "33": "X", "34": "Y", "35": "Z", "36": "a", "37": "b", "38": "c", "39": "d", "40": "e", "41": "f", "42": "g", "43": "h", "44": "i", "45": "j", "46": "k", "47": "l", 
"48": "m", "49": "n", "50": "o", "51": "p", "52": "q", "53": "r", "54": "s", "55": "t", "56": "u", "57": "v", "58": "w", "59": "x", "60": "y", "61": "z", "62": "eos"}, "speakers_map": {"thorsten": 0}, "processor_name": "ThorstenProcessor"} -------------------------------------------------------------------------------- /tensorflow_tts/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from tensorflow_tts.trainers.base_trainer import GanBasedTrainer, Seq2SeqBasedTrainer 2 | -------------------------------------------------------------------------------- /tensorflow_tts/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from tensorflow_tts.utils.cleaners import ( 2 | basic_cleaners, 3 | collapse_whitespace, 4 | convert_to_ascii, 5 | english_cleaners, 6 | expand_abbreviations, 7 | expand_numbers, 8 | lowercase, 9 | transliteration_cleaners, 10 | ) 11 | from tensorflow_tts.utils.decoder import dynamic_decode 12 | from tensorflow_tts.utils.griffin_lim import TFGriffinLim, griffin_lim_lb 13 | from tensorflow_tts.utils.group_conv import GroupConv1D 14 | from tensorflow_tts.utils.number_norm import normalize_numbers 15 | from tensorflow_tts.utils.outliers import remove_outlier 16 | from tensorflow_tts.utils.strategy import ( 17 | calculate_2d_loss, 18 | calculate_3d_loss, 19 | return_strategy, 20 | ) 21 | from tensorflow_tts.utils.utils import find_files, MODEL_FILE_NAME, CONFIG_FILE_NAME, PROCESSOR_FILE_NAME, CACHE_DIRECTORY, LIBRARY_NAME 22 | from tensorflow_tts.utils.weight_norm import WeightNormalization 23 | -------------------------------------------------------------------------------- /tensorflow_tts/utils/outliers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2020 Minh Nguyen (@dathudeptrai) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you 
def is_outlier(x, p25, p75):
    """Check if value is an outlier.

    A value is an outlier when it lies strictly outside the fences
    ``[p25 - 1.5 * IQR, p75 + 1.5 * IQR]`` with ``IQR = p75 - p25``. The
    comparison is strict so that, when the inter-quartile range is zero,
    values equal to the quartiles are not flagged.

    Args:
        x: Scalar or NumPy array of values to test.
        p25: Lower quartile value.
        p75: Upper quartile value.

    Returns:
        A boolean for scalar input, or a boolean array for array input.
    """
    margin = 1.5 * (p75 - p25)
    lower = p25 - margin
    upper = p75 + margin
    # `|` instead of `or` so the same expression works element-wise on arrays.
    return (x < lower) | (x > upper)


def remove_outlier(x, p_bottom: int = 25, p_top: int = 75):
    """Replace outliers in x by the largest non-outlier value.

    Args:
        x: Array of values. A NumPy array is modified in place; other
            sequences are converted to a new array first.
        p_bottom: Percentile (0-100) used as the lower quartile.
        p_top: Percentile (0-100) used as the upper quartile.

    Returns:
        The array with every outlier overwritten. Empty input, input without
        outliers, and input where no inlier exists are returned unchanged.
    """
    x = np.asarray(x)
    if x.size == 0:
        return x

    low_q = np.percentile(x, p_bottom)
    high_q = np.percentile(x, p_top)
    outliers = is_outlier(x, low_q, high_q)

    # Take the replacement from the inliers only, so the result is also
    # correct for data that is entirely negative.
    if outliers.any() and not outliers.all():
        x[outliers] = np.max(x[~outliers])
    return x
# See the License for the specific language governing permissions and
# limitations under the License.
"""Strategy util functions"""
import tensorflow as tf


def return_strategy():
    """Pick a ``tf.distribute`` strategy from the number of visible GPUs.

    Returns:
        tf.distribute.Strategy: ``OneDeviceStrategy`` on ``/cpu:0`` when no
            GPU is listed, ``OneDeviceStrategy`` on ``/gpu:0`` for exactly
            one GPU, and ``MirroredStrategy`` otherwise.
    """
    physical_devices = tf.config.list_physical_devices("GPU")
    if not physical_devices:
        return tf.distribute.OneDeviceStrategy(device="/cpu:0")
    if len(physical_devices) == 1:
        return tf.distribute.OneDeviceStrategy(device="/gpu:0")
    return tf.distribute.MirroredStrategy()


def _reduce_loss_to_batch(loss):
    """Average a loss over every axis except the first.

    Args:
        loss: A tensor, or a tuple of tensors, as returned by a loss function.
    Returns:
        A tensor reduced over axes ``1..rank-1`` when ``loss`` is not a
        tuple; otherwise a list with each element reduced the same way.
    """
    if not isinstance(loss, tuple):
        return tf.reduce_mean(loss, list(range(1, len(loss.shape))))  # shape = [B]
    return [
        tf.reduce_mean(item, list(range(1, len(item.shape)))) for item in loss
    ]  # each item has shape = [B]


def calculate_3d_loss(y_gt, y_pred, loss_fn):
    """Calculate 3d loss, normally it's mel-spectrogram loss.

    Args:
        y_gt: Rank-3 ground-truth tensor.
        y_pred: Rank-3 predicted tensor.
        loss_fn: Callable invoked as ``loss_fn(y_gt, y_pred)``; may return a
            tensor or a tuple of tensors.
    Returns:
        The loss averaged over all axes but the first; a list of such
        tensors when ``loss_fn`` returns a tuple.
    """
    y_gt_T = tf.shape(y_gt)[1]
    y_pred_T = tf.shape(y_pred)[1]

    # There can be a length mismatch along axis 1 when training on multiple
    # GPUs, so the longer tensor is sliced down to the shorter one before
    # the loss is computed.
    if y_gt_T > y_pred_T:
        y_gt = tf.slice(y_gt, [0, 0, 0], [-1, y_pred_T, -1])
    elif y_pred_T > y_gt_T:
        y_pred = tf.slice(y_pred, [0, 0, 0], [-1, y_gt_T, -1])

    return _reduce_loss_to_batch(loss_fn(y_gt, y_pred))


def calculate_2d_loss(y_gt, y_pred, loss_fn):
    """Calculate 2d loss, normally it's durations/f0s/energies loss.

    Args:
        y_gt: Rank-2 ground-truth tensor.
        y_pred: Rank-2 predicted tensor.
        loss_fn: Callable invoked as ``loss_fn(y_gt, y_pred)``; may return a
            tensor or a tuple of tensors.
    Returns:
        The loss averaged over all axes but the first; a list of such
        tensors when ``loss_fn`` returns a tuple.
    """
    y_gt_T = tf.shape(y_gt)[1]
    y_pred_T = tf.shape(y_pred)[1]

    # There can be a length mismatch along axis 1 when training on multiple
    # GPUs, so the longer tensor is sliced down to the shorter one before
    # the loss is computed.
    if y_gt_T > y_pred_T:
        y_gt = tf.slice(y_gt, [0, 0], [-1, y_pred_T])
    elif y_pred_T > y_gt_T:
        y_pred = tf.slice(y_pred, [0, 0], [-1, y_gt_T])

    return _reduce_loss_to_batch(loss_fn(y_gt, y_pred))
# --------------------------------------------------------------------------------
# /tensorflow_tts/utils/utils.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""Utility functions."""

import fnmatch
import os
import re
import tempfile
from pathlib import Path

import tensorflow as tf

MODEL_FILE_NAME = "model.h5"
CONFIG_FILE_NAME = "config.yml"
PROCESSOR_FILE_NAME = "processor.json"
LIBRARY_NAME = "tensorflow_tts"
CACHE_DIRECTORY = os.path.join(Path.home(), ".cache", LIBRARY_NAME)


def find_files(root_dir, query="*.wav", include_root_dir=True):
    """Find files recursively.

    Args:
        root_dir (str): Root directory to search (symlinked directories are
            followed).
        query (str): ``fnmatch`` pattern matched against file names.
        include_root_dir (bool): If False, returned paths are relative to
            ``root_dir`` instead of being prefixed with it.
    Returns:
        list: List of found filenames.
    """
    files = [
        os.path.join(root, filename)
        for root, _, filenames in os.walk(root_dir, followlinks=True)
        for filename in fnmatch.filter(filenames, query)
    ]
    if not include_root_dir:
        # relpath strips only the leading root. The previous
        # ``replace(root_dir + "/", "")`` also removed later occurrences of
        # the root name inside the path and did nothing when ``root_dir``
        # ended with a separator or the OS separator was not "/".
        files = [os.path.relpath(file_, root_dir) for file_ in files]

    return files


def _path_requires_gfile(filepath):
    """Checks if the given path requires use of GFile API.

    Args:
        filepath (str): Path to check.
    Returns:
        bool: True if the given path needs GFile API to access, such as
            "s3://some/path" and "gs://some/path".
    """
    # If the filepath contains a protocol (e.g. "gs://"), it should be handled
    # using TensorFlow GFile API.
    return bool(re.match(r"^[a-z]+://", filepath))


def save_weights(model, filepath):
    """Save model weights.

    Same as model.save_weights(filepath), but supports saving to S3 or GCS
    buckets using TensorFlow GFile API.

    Args:
        model (tf.keras.Model): Model to save.
        filepath (str): Path to save the model weights to.
    """
    if not _path_requires_gfile(filepath):
        model.save_weights(filepath)
        return

    # Save to a local temp file and copy to the desired path using GFile API.
    # The temp file keeps the target's extension so Keras picks the same
    # weight format it would have used for ``filepath``.
    # NOTE(review): the temp file is reopened by name while still open, which
    # presumably fails on Windows - confirm if Windows support matters.
    _, ext = os.path.splitext(filepath)
    with tempfile.NamedTemporaryFile(suffix=ext) as temp_file:
        model.save_weights(temp_file.name)
        # To preserve the original semantics, we need to overwrite the target
        # file.
        tf.io.gfile.copy(temp_file.name, filepath, overwrite=True)


def load_weights(model, filepath):
    """Load model weights.

    Same as model.load_weights(filepath), but supports loading from S3 or GCS
    buckets using TensorFlow GFile API.

    Args:
        model (tf.keras.Model): Model to load weights to.
        filepath (str): Path to the weights file.
    """
    if not _path_requires_gfile(filepath):
        model.load_weights(filepath)
        return

    # Make a local copy and load it.
    _, ext = os.path.splitext(filepath)
    with tempfile.NamedTemporaryFile(suffix=ext) as temp_file:
        # The target temp_file should be created above, so we need to overwrite.
        tf.io.gfile.copy(filepath, temp_file.name, overwrite=True)
        model.load_weights(temp_file.name)
# --------------------------------------------------------------------------------
# /test/files/kss_mapper.json:
# --------------------------------------------------------------------------------
# 1 | {"symbol_to_id": {"pad": 0, "-": 7, "!": 2, "'": 3, "(": 4, ")": 5, ",": 6, ".": 8, ":": 9, ";": 10, "?": 11, " ": 12, "\u1100": 13, "\u1101": 14, "\u1102": 15, "\u1103": 16, "\u1104": 17, "\u1105": 18, "\u1106": 19, "\u1107": 20, "\u1108": 21, "\u1109": 22, "\u110a": 23, "\u110b": 24, "\u110c": 25, "\u110d": 26, "\u110e": 27, "\u110f": 28, "\u1110": 29, "\u1111": 30, "\u1112": 31, "\u1161": 32, "\u1162": 33, "\u1163": 34, "\u1164": 35, "\u1165": 36, "\u1166": 37, "\u1167": 38, "\u1168": 39, "\u1169": 40, "\u116a": 41, "\u116b": 42, "\u116c": 43, "\u116d": 44, "\u116e": 45, "\u116f": 46, "\u1170": 47, "\u1171": 48, "\u1172": 49, "\u1173": 50, "\u1174": 51, "\u1175": 52, "\u11a8": 53, "\u11a9": 54, "\u11aa": 55, "\u11ab": 56, "\u11ac": 57, "\u11ad": 58, "\u11ae": 59, "\u11af": 60, "\u11b0": 61, "\u11b1": 62, "\u11b2": 63, "\u11b3": 64, "\u11b4": 65, "\u11b5": 66, "\u11b6": 67, "\u11b7": 68, "\u11b8": 69, "\u11b9": 70, "\u11ba": 71, "\u11bb": 72, "\u11bc": 73, "\u11bd": 74, "\u11be": 75, "\u11bf": 76, "\u11c0": 77, "\u11c1": 78, "\u11c2": 79, "eos": 80}, "id_to_symbol": {"0": "pad", "1": "-", "2": "!", "3": "'", "4": "(", "5": ")", "6": ",", "7": "-", "8": ".", "9": ":", "10": ";", "11": "?", "12": " ", "13": "\u1100", "14": "\u1101", "15": "\u1102", "16": "\u1103", "17": "\u1104", "18": "\u1105", "19": "\u1106", "20": "\u1107", "21": "\u1108", "22": "\u1109", "23": "\u110a", "24": "\u110b", "25": "\u110c", "26": "\u110d", "27": "\u110e", "28": "\u110f", "29": "\u1110", "30": "\u1111", "31": "\u1112", "32": "\u1161", "33": "\u1162", "34": "\u1163", "35": "\u1164", "36": "\u1165", "37": "\u1166", "38": "\u1167", "39": "\u1168", "40": "\u1169", "41": "\u116a", "42":
"\u116b", "43": "\u116c", "44": "\u116d", "45": "\u116e", "46": "\u116f", "47": "\u1170", "48": "\u1171", "49": "\u1172", "50": "\u1173", "51": "\u1174", "52": "\u1175", "53": "\u11a8", "54": "\u11a9", "55": "\u11aa", "56": "\u11ab", "57": "\u11ac", "58": "\u11ad", "59": "\u11ae", "60": "\u11af", "61": "\u11b0", "62": "\u11b1", "63": "\u11b2", "64": "\u11b3", "65": "\u11b4", "66": "\u11b5", "67": "\u11b6", "68": "\u11b7", "69": "\u11b8", "70": "\u11b9", "71": "\u11ba", "72": "\u11bb", "73": "\u11bc", "74": "\u11bd", "75": "\u11be", "76": "\u11bf", "77": "\u11c0", "78": "\u11c1", "79": "\u11c2", "80": "eos"}, "speakers_map": {"kss": 0}, "processor_name": "KSSProcessor"} -------------------------------------------------------------------------------- /test/files/libritts_mapper.json: -------------------------------------------------------------------------------- 1 | {"symbol_to_id": {"@": 0, "@": 1, "@": 2, "@": 3, "@AA0": 4, "@AA1": 5, "@AA2": 6, "@AE0": 7, "@AE1": 8, "@AE2": 9, "@AH0": 10, "@AH1": 11, "@AH2": 12, "@AO0": 13, "@AO1": 14, "@AO2": 15, "@AW0": 16, "@AW1": 17, "@AW2": 18, "@AY0": 19, "@AY1": 20, "@AY2": 21, "@B": 22, "@CH": 23, "@D": 24, "@DH": 25, "@EH0": 26, "@EH1": 27, "@EH2": 28, "@ER0": 29, "@ER1": 30, "@ER2": 31, "@EY0": 32, "@EY1": 33, "@EY2": 34, "@F": 35, "@G": 36, "@HH": 37, "@IH0": 38, "@IH1": 39, "@IH2": 40, "@IY0": 41, "@IY1": 42, "@IY2": 43, "@JH": 44, "@K": 45, "@L": 46, "@M": 47, "@N": 48, "@NG": 49, "@OW0": 50, "@OW1": 51, "@OW2": 52, "@OY0": 53, "@OY1": 54, "@OY2": 55, "@P": 56, "@R": 57, "@S": 58, "@SH": 59, "@T": 60, "@TH": 61, "@UH0": 62, "@UH1": 63, "@UH2": 64, "@UW": 65, "@UW0": 66, "@UW1": 67, "@UW2": 68, "@V": 69, "@W": 70, "@Y": 71, "@Z": 72, "@ZH": 73, "@SIL": 74, "@END": 75, "!": 76, "'": 77, "(": 78, ")": 79, ",": 80, ".": 81, ":": 82, ";": 83, "?": 84, " ": 85}, "id_to_symbol": {"0": "@", "1": "@", "2": "@", "3": "@", "4": "@AA0", "5": "@AA1", "6": "@AA2", "7": "@AE0", "8": "@AE1", "9": "@AE2", "10": "@AH0", "11": "@AH1", 
"12": "@AH2", "13": "@AO0", "14": "@AO1", "15": "@AO2", "16": "@AW0", "17": "@AW1", "18": "@AW2", "19": "@AY0", "20": "@AY1", "21": "@AY2", "22": "@B", "23": "@CH", "24": "@D", "25": "@DH", "26": "@EH0", "27": "@EH1", "28": "@EH2", "29": "@ER0", "30": "@ER1", "31": "@ER2", "32": "@EY0", "33": "@EY1", "34": "@EY2", "35": "@F", "36": "@G", "37": "@HH", "38": "@IH0", "39": "@IH1", "40": "@IH2", "41": "@IY0", "42": "@IY1", "43": "@IY2", "44": "@JH", "45": "@K", "46": "@L", "47": "@M", "48": "@N", "49": "@NG", "50": "@OW0", "51": "@OW1", "52": "@OW2", "53": "@OY0", "54": "@OY1", "55": "@OY2", "56": "@P", "57": "@R", "58": "@S", "59": "@SH", "60": "@T", "61": "@TH", "62": "@UH0", "63": "@UH1", "64": "@UH2", "65": "@UW", "66": "@UW0", "67": "@UW1", "68": "@UW2", "69": "@V", "70": "@W", "71": "@Y", "72": "@Z", "73": "@ZH", "74": "@SIL", "75": "@END", "76": "!", "77": "'", "78": "(", "79": ")", "80": ",", "81": ".", "82": ":", "83": ";", "84": "?", "85": " "}, "speakers_map": {"200": 0, "1841": 1, "3664": 2, "6454": 3, "8108": 4, "2416": 5, "4680": 6, "6147": 7, "412": 8, "2952": 9, "8838": 10, "2836": 11, "1263": 12, "5322": 13, "3830": 14, "7447": 15, "1116": 16, "8312": 17, "8123": 18, "250": 19}, "processor_name": "LibriTTSProcessor"} -------------------------------------------------------------------------------- /test/files/ljspeech_mapper.json: -------------------------------------------------------------------------------- 1 | {"symbol_to_id": {"pad": 0, "-": 1, "!": 2, "'": 3, "(": 4, ")": 5, ",": 6, ".": 7, ":": 8, ";": 9, "?": 10, " ": 11, "A": 12, "B": 13, "C": 14, "D": 15, "E": 16, "F": 17, "G": 18, "H": 19, "I": 20, "J": 21, "K": 22, "L": 23, "M": 24, "N": 25, "O": 26, "P": 27, "Q": 28, "R": 29, "S": 30, "T": 31, "U": 32, "V": 33, "W": 34, "X": 35, "Y": 36, "Z": 37, "a": 38, "b": 39, "c": 40, "d": 41, "e": 42, "f": 43, "g": 44, "h": 45, "i": 46, "j": 47, "k": 48, "l": 49, "m": 50, "n": 51, "o": 52, "p": 53, "q": 54, "r": 55, "s": 56, "t": 57, "u": 58, "v": 59, 
"w": 60, "x": 61, "y": 62, "z": 63, "@AA": 64, "@AA0": 65, "@AA1": 66, "@AA2": 67, "@AE": 68, "@AE0": 69, "@AE1": 70, "@AE2": 71, "@AH": 72, "@AH0": 73, "@AH1": 74, "@AH2": 75, "@AO": 76, "@AO0": 77, "@AO1": 78, "@AO2": 79, "@AW": 80, "@AW0": 81, "@AW1": 82, "@AW2": 83, "@AY": 84, "@AY0": 85, "@AY1": 86, "@AY2": 87, "@B": 88, "@CH": 89, "@D": 90, "@DH": 91, "@EH": 92, "@EH0": 93, "@EH1": 94, "@EH2": 95, "@ER": 96, "@ER0": 97, "@ER1": 98, "@ER2": 99, "@EY": 100, "@EY0": 101, "@EY1": 102, "@EY2": 103, "@F": 104, "@G": 105, "@HH": 106, "@IH": 107, "@IH0": 108, "@IH1": 109, "@IH2": 110, "@IY": 111, "@IY0": 112, "@IY1": 113, "@IY2": 114, "@JH": 115, "@K": 116, "@L": 117, "@M": 118, "@N": 119, "@NG": 120, "@OW": 121, "@OW0": 122, "@OW1": 123, "@OW2": 124, "@OY": 125, "@OY0": 126, "@OY1": 127, "@OY2": 128, "@P": 129, "@R": 130, "@S": 131, "@SH": 132, "@T": 133, "@TH": 134, "@UH": 135, "@UH0": 136, "@UH1": 137, "@UH2": 138, "@UW": 139, "@UW0": 140, "@UW1": 141, "@UW2": 142, "@V": 143, "@W": 144, "@Y": 145, "@Z": 146, "@ZH": 147, "eos": 148}, "id_to_symbol": {"0": "pad", "1": "-", "2": "!", "3": "'", "4": "(", "5": ")", "6": ",", "7": ".", "8": ":", "9": ";", "10": "?", "11": " ", "12": "A", "13": "B", "14": "C", "15": "D", "16": "E", "17": "F", "18": "G", "19": "H", "20": "I", "21": "J", "22": "K", "23": "L", "24": "M", "25": "N", "26": "O", "27": "P", "28": "Q", "29": "R", "30": "S", "31": "T", "32": "U", "33": "V", "34": "W", "35": "X", "36": "Y", "37": "Z", "38": "a", "39": "b", "40": "c", "41": "d", "42": "e", "43": "f", "44": "g", "45": "h", "46": "i", "47": "j", "48": "k", "49": "l", "50": "m", "51": "n", "52": "o", "53": "p", "54": "q", "55": "r", "56": "s", "57": "t", "58": "u", "59": "v", "60": "w", "61": "x", "62": "y", "63": "z", "64": "@AA", "65": "@AA0", "66": "@AA1", "67": "@AA2", "68": "@AE", "69": "@AE0", "70": "@AE1", "71": "@AE2", "72": "@AH", "73": "@AH0", "74": "@AH1", "75": "@AH2", "76": "@AO", "77": "@AO0", "78": "@AO1", "79": "@AO2", "80": "@AW", 
"81": "@AW0", "82": "@AW1", "83": "@AW2", "84": "@AY", "85": "@AY0", "86": "@AY1", "87": "@AY2", "88": "@B", "89": "@CH", "90": "@D", "91": "@DH", "92": "@EH", "93": "@EH0", "94": "@EH1", "95": "@EH2", "96": "@ER", "97": "@ER0", "98": "@ER1", "99": "@ER2", "100": "@EY", "101": "@EY0", "102": "@EY1", "103": "@EY2", "104": "@F", "105": "@G", "106": "@HH", "107": "@IH", "108": "@IH0", "109": "@IH1", "110": "@IH2", "111": "@IY", "112": "@IY0", "113": "@IY1", "114": "@IY2", "115": "@JH", "116": "@K", "117": "@L", "118": "@M", "119": "@N", "120": "@NG", "121": "@OW", "122": "@OW0", "123": "@OW1", "124": "@OW2", "125": "@OY", "126": "@OY0", "127": "@OY1", "128": "@OY2", "129": "@P", "130": "@R", "131": "@S", "132": "@SH", "133": "@T", "134": "@TH", "135": "@UH", "136": "@UH0", "137": "@UH1", "138": "@UH2", "139": "@UW", "140": "@UW0", "141": "@UW1", "142": "@UW2", "143": "@V", "144": "@W", "145": "@Y", "146": "@Z", "147": "@ZH", "148": "eos"}, "speakers_map": {"ljspeech": 0}, "processor_name": "LJSpeechProcessor"} -------------------------------------------------------------------------------- /test/files/mapper.json: -------------------------------------------------------------------------------- 1 | { 2 | "speakers_map": { 3 | "test_one": 0, 4 | "test_two": 1 5 | }, 6 | "symbol_to_id": { 7 | "a": 0, 8 | "b": 1, 9 | "@ph": 2 10 | }, 11 | "id_to_symbol": { 12 | "0": "a", 13 | "1": "b", 14 | "2": "@ph" 15 | }, 16 | "processor_name": "KSSProcessor" 17 | } -------------------------------------------------------------------------------- /test/files/train.txt: -------------------------------------------------------------------------------- 1 | speaker1/libri1.wav|in fact its just a test.|One 2 | speaker2/libri2|in fact its just a speaker number one.|Two 3 | -------------------------------------------------------------------------------- /test/test_auto.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import pytest
import tensorflow as tf

from tensorflow_tts.inference import AutoConfig
from tensorflow_tts.inference import AutoProcessor
from tensorflow_tts.inference import TFAutoModel

# Hide all GPUs so the tests run on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)


@pytest.mark.parametrize(
    "mapper_path",
    [
        "./test/files/baker_mapper.json",
        "./test/files/kss_mapper.json",
        "./test/files/libritts_mapper.json",
        "./test/files/ljspeech_mapper.json",
    ],
)
def test_auto_processor(mapper_path, tmp_path):
    """Round-trip a processor: load from a mapper, save, and load again.

    Artefacts go into pytest's per-test ``tmp_path`` instead of a fixed
    ``./test_saved`` directory, so parametrized cases do not share state and
    nothing is left behind in the working directory.
    """
    save_dir = str(tmp_path)

    processor = AutoProcessor.from_pretrained(pretrained_path=mapper_path)
    processor.save_pretrained(save_dir)
    processor = AutoProcessor.from_pretrained(
        os.path.join(save_dir, "processor.json")
    )


@pytest.mark.parametrize(
    "config_path",
    [
        "./examples/fastspeech/conf/fastspeech.v1.yaml",
        "./examples/fastspeech/conf/fastspeech.v3.yaml",
        "./examples/fastspeech2/conf/fastspeech2.v1.yaml",
        "./examples/fastspeech2/conf/fastspeech2.v2.yaml",
        "./examples/fastspeech2/conf/fastspeech2.kss.v1.yaml",
        "./examples/fastspeech2/conf/fastspeech2.kss.v2.yaml",
        "./examples/melgan/conf/melgan.v1.yaml",
        "./examples/melgan_stft/conf/melgan_stft.v1.yaml",
        "./examples/multiband_melgan/conf/multiband_melgan.v1.yaml",
        "./examples/tacotron2/conf/tacotron2.v1.yaml",
        "./examples/tacotron2/conf/tacotron2.kss.v1.yaml",
        "./examples/parallel_wavegan/conf/parallel_wavegan.v1.yaml",
        "./examples/hifigan/conf/hifigan.v1.yaml",
        "./examples/hifigan/conf/hifigan.v2.yaml",
    ],
)
def test_auto_model(config_path, tmp_path):
    """Round-trip a config and its model through save/from_pretrained.

    The model is first built from the config alone (no pretrained weights),
    then both are saved into ``tmp_path`` and reloaded from the saved files.
    """
    save_dir = str(tmp_path)

    config = AutoConfig.from_pretrained(pretrained_path=config_path)
    model = TFAutoModel.from_pretrained(pretrained_path=None, config=config)

    # test save_pretrained
    config.save_pretrained(save_dir)
    model.save_pretrained(save_dir)

    # test from_pretrained
    config = AutoConfig.from_pretrained(os.path.join(save_dir, "config.yml"))
    model = TFAutoModel.from_pretrained(
        os.path.join(save_dir, "model.h5"), config=config
    )
# --------------------------------------------------------------------------------
# /test/test_fastspeech.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import pytest
import tensorflow as tf

from tensorflow_tts.configs import FastSpeechConfig
from tensorflow_tts.models import TFFastSpeech

# Hide all GPUs so the tests run on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)


@pytest.mark.parametrize("new_size", [100, 200, 300])
def test_fastspeech_resize_positional_embeddings(new_size, tmp_path):
    """Check weights reload after resizing the positional embeddings.

    Weights are saved, the positional embeddings are resized to ``new_size``
    and the saved weights are loaded back by name with mismatching layers
    skipped. The weights file lives in pytest's ``tmp_path`` rather than a
    fixed ``./test.h5`` in the working directory.
    """
    weights_path = str(tmp_path / "test.h5")

    config = FastSpeechConfig()
    fastspeech = TFFastSpeech(config, name="fastspeech")
    fastspeech._build()
    fastspeech.save_weights(weights_path)
    fastspeech.resize_positional_embeddings(new_size)
    fastspeech.load_weights(weights_path, by_name=True, skip_mismatch=True)


@pytest.mark.parametrize("num_hidden_layers,n_speakers", [(2, 1), (3, 2), (4, 3)])
def test_fastspeech_trainable(num_hidden_layers, n_speakers):
    """Run two optimisation steps on fake data and print the second's time."""
    import time

    config = FastSpeechConfig(
        encoder_num_hidden_layers=num_hidden_layers,
        decoder_num_hidden_layers=num_hidden_layers + 1,
        n_speakers=n_speakers,
    )

    fastspeech = TFFastSpeech(config, name="fastspeech")
    # ``learning_rate`` is the supported keyword; ``lr`` is a deprecated alias.
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    # fake inputs
    input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], tf.int32)
    speaker_ids = tf.convert_to_tensor([0], tf.int32)
    duration_gts = tf.convert_to_tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], tf.int32)

    mel_gts = tf.random.uniform(shape=[1, 10, 80], dtype=tf.float32)

    @tf.function
    def one_step_training():
        with tf.GradientTape() as tape:
            mel_outputs_before, _, duration_outputs = fastspeech(
                input_ids, speaker_ids, duration_gts, training=True
            )
            duration_loss = tf.keras.losses.MeanSquaredError()(
                duration_gts, duration_outputs
            )
            mel_loss = tf.keras.losses.MeanSquaredError()(mel_gts, mel_outputs_before)
            loss = duration_loss + mel_loss
        gradients = tape.gradient(loss, fastspeech.trainable_variables)
        optimizer.apply_gradients(zip(gradients, fastspeech.trainable_variables))

        tf.print(loss)

    # The first call traces the tf.function; only the second call is timed,
    # so ``start`` is always bound before it is read.
    one_step_training()
    start = time.time()
    one_step_training()
    print(time.time() - start)
# --------------------------------------------------------------------------------
# /test/test_mb_melgan.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

import logging
import os

import numpy as np
import pytest

from tensorflow_tts.configs import MultiBandMelGANGeneratorConfig
from tensorflow_tts.models import TFPQMF, TFMelGANGenerator

# Hide all GPUs so the test runs on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)


def make_multi_band_melgan_generator_args(**kwargs):
    """Return default multi-band MelGAN generator kwargs, overridden by ``kwargs``.

    Args:
        **kwargs: Entries that replace or extend the defaults below.
    Returns:
        dict: The merged keyword arguments.
    """
    defaults = dict(
        out_channels=1,
        kernel_size=7,
        filters=512,
        use_bias=True,
        upsample_scales=[8, 8, 2, 2],
        stack_kernel_size=3,
        stacks=3,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"alpha": 0.2},
        padding_type="REFLECT",
        subbands=4,
        # NOTE(review): "tabs" may be a typo for "taps" - confirm against
        # MultiBandMelGANGeneratorConfig before renaming.
        tabs=62,
        cutoff_ratio=0.15,
        beta=9.0,
    )
    defaults.update(kwargs)
    return defaults


@pytest.mark.parametrize(
    "dict_g",
    [
        {"subbands": 4, "upsample_scales": [2, 4, 8], "stacks": 4, "out_channels": 4},
        {"subbands": 4, "upsample_scales": [4, 4, 4], "stacks": 5, "out_channels": 4},
    ],
)
def test_multi_band_melgan(dict_g):
    """Check generator sub-band output and PQMF analysis/synthesis shapes agree.

    Random mels go through the generator and PQMF synthesis; a random waveform
    goes through PQMF analysis. The sub-band shapes and the full-band shapes
    are then asserted to match.
    """
    args_g = make_multi_band_melgan_generator_args(**dict_g)
    args_g = MultiBandMelGANGeneratorConfig(**args_g)
    generator = TFMelGANGenerator(args_g, name="multi_band_melgan")
    generator._build()

    pqmf = TFPQMF(args_g, name="pqmf")

    fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
    fake_y = tf.random.uniform(shape=[1, 100 * 256, 1], dtype=tf.float32)
    y_hat_subbands = generator(fake_mels)

    y_hat = pqmf.synthesis(y_hat_subbands)
    y_subbands = pqmf.analysis(fake_y)

    assert np.shape(y_subbands) == np.shape(y_hat_subbands)
    assert np.shape(fake_y) == np.shape(y_hat)
# --------------------------------------------------------------------------------
# /test/test_melgan.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import pytest
import tensorflow as tf

from tensorflow_tts.configs import MelGANDiscriminatorConfig, MelGANGeneratorConfig
from tensorflow_tts.models import TFMelGANGenerator, TFMelGANMultiScaleDiscriminator

# Hide all GPUs so the test runs on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)


def make_melgan_generator_args(**kwargs):
    """Return default MelGAN generator kwargs, overridden by ``kwargs``.

    Args:
        **kwargs: Entries that replace or extend the defaults below.
    Returns:
        dict: The merged keyword arguments.
    """
    defaults = dict(
        out_channels=1,
        kernel_size=7,
        filters=512,
        use_bias=True,
        upsample_scales=[8, 8, 2, 2],
        stack_kernel_size=3,
        stacks=3,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"alpha": 0.2},
        padding_type="REFLECT",
    )
    defaults.update(kwargs)
    return defaults


def make_melgan_discriminator_args(**kwargs):
    """Return default MelGAN discriminator kwargs, overridden by ``kwargs``.

    Args:
        **kwargs: Entries that replace or extend the defaults below.
    Returns:
        dict: The merged keyword arguments.
    """
    defaults = dict(
        out_channels=1,
        scales=3,
        downsample_pooling="AveragePooling1D",
        downsample_pooling_params={"pool_size": 4, "strides": 2},
        kernel_sizes=[5, 3],
        filters=16,
        max_downsample_filters=1024,
        use_bias=True,
        downsample_scales=[4, 4, 4, 4],
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"alpha": 0.2},
        padding_type="REFLECT",
    )
    defaults.update(kwargs)
    return defaults


@pytest.mark.parametrize(
    "dict_g, dict_d, dict_loss",
    [
        ({}, {}, {}),
        ({"kernel_size": 3}, {}, {}),
        ({"filters": 1024}, {}, {}),
        ({"stack_kernel_size": 5}, {}, {}),
        ({"stack_kernel_size": 5, "stacks": 2}, {}, {}),
        ({"upsample_scales": [4, 4, 4, 4]}, {}, {}),
        ({"upsample_scales": [8, 8, 2, 2]}, {}, {}),
        ({"filters": 1024, "upsample_scales": [8, 8, 2, 2]}, {}, {}),
    ],
)
def test_melgan_trainable(dict_g, dict_d, dict_loss):
    """Check the generator and discriminator can be constructed per config.

    ``dict_loss`` is part of the parametrization but is not used in the body.
    The unused ``batch_size``/``batch_length`` locals were removed.
    """
    args_g = make_melgan_generator_args(**dict_g)
    args_d = make_melgan_discriminator_args(**dict_d)

    args_g = MelGANGeneratorConfig(**args_g)
    args_d = MelGANDiscriminatorConfig(**args_d)

    generator = TFMelGANGenerator(args_g)
    discriminator = TFMelGANMultiScaleDiscriminator(args_d)
# --------------------------------------------------------------------------------
# /test/test_melgan_layers.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import numpy as np
import pytest
import tensorflow as tf

from tensorflow_tts.models.melgan import (
    TFConvTranspose1d,
    TFReflectionPad1d,
    TFResidualStack,
)

# Hide all GPUs so the tests run on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)


@pytest.mark.parametrize("padding_size", [(3), (5)])
def test_padding(padding_size):
    """Check reflection padding grows axis 1 by ``2 * padding_size``."""
    fake_input_1d = tf.random.normal(shape=[4, 8000, 256], dtype=tf.float32)
    out = TFReflectionPad1d(padding_size=padding_size)(fake_input_1d)
    assert np.array_equal(
        tf.keras.backend.int_shape(out), [4, 8000 + 2 * padding_size, 256]
    )


@pytest.mark.parametrize(
    "filters,kernel_size,strides,padding,is_weight_norm",
    [(512, 40, 8, "same", False), (768, 15, 8, "same", True)],
)
def test_convtranpose1d(filters, kernel_size, strides, padding, is_weight_norm):
    """Check the transposed conv scales axis 1 by ``strides`` and emits ``filters`` channels."""
    fake_input_1d = tf.random.normal(shape=[4, 8000, 256], dtype=tf.float32)
    conv1d_transpose = TFConvTranspose1d(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        is_weight_norm=is_weight_norm,
        initializer_seed=42,
    )
    out = conv1d_transpose(fake_input_1d)
    assert np.array_equal(tf.keras.backend.int_shape(out), [4, 8000 * strides, filters])


@pytest.mark.parametrize(
    "kernel_size,filters,dilation_rate,use_bias,nonlinear_activation,nonlinear_activation_params,is_weight_norm",
    [
        (3, 256, 1, True, "LeakyReLU", {"alpha": 0.3}, True),
        (3, 256, 3, True, "ReLU", {}, False),
    ],
)
def test_residualblock(
    kernel_size,
    filters,
    dilation_rate,
    use_bias,
    nonlinear_activation,
    nonlinear_activation_params,
    is_weight_norm,
):
    """Check the residual stack keeps the input length and emits ``filters`` channels."""
    fake_input_1d = tf.random.normal(shape=[4, 8000, 256], dtype=tf.float32)
    residual_block = TFResidualStack(
        kernel_size=kernel_size,
        filters=filters,
        dilation_rate=dilation_rate,
        use_bias=use_bias,
        nonlinear_activation=nonlinear_activation,
        nonlinear_activation_params=nonlinear_activation_params,
        is_weight_norm=is_weight_norm,
        initializer_seed=42,
    )
    out = residual_block(fake_input_1d)
    assert np.array_equal(tf.keras.backend.int_shape(out), [4, 8000, filters])
# --------------------------------------------------------------------------------
# /test/test_parallel_wavegan.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import pytest
import tensorflow as tf

from tensorflow_tts.configs import (
    ParallelWaveGANGeneratorConfig,
    ParallelWaveGANDiscriminatorConfig,
)
from tensorflow_tts.models import (
    TFParallelWaveGANGenerator,
    TFParallelWaveGANDiscriminator,
)

# Hide all GPUs so the test runs on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)


def make_pwgan_generator_args(**kwargs):
    """Return default Parallel WaveGAN generator kwargs, overridden by ``kwargs``.

    Args:
        **kwargs: Entries that replace or extend the defaults below.
    Returns:
        dict: The merged keyword arguments.
    """
    defaults = dict(
        out_channels=1,
        kernel_size=3,
        n_layers=30,
        stacks=3,
        residual_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80,
        aux_context_window=2,
        dropout_rate=0.0,
        use_bias=True,
        use_causal_conv=False,
        upsample_conditional_features=True,
        upsample_params={"upsample_scales": [4, 4, 4, 4]},
        initializer_seed=42,
    )
    defaults.update(kwargs)
    return defaults


def make_pwgan_discriminator_args(**kwargs):
    """Return default Parallel WaveGAN discriminator kwargs, overridden by ``kwargs``.

    Args:
        **kwargs: Entries that replace or extend the defaults below.
    Returns:
        dict: The merged keyword arguments.
    """
    defaults = dict(
        out_channels=1,
        kernel_size=3,
        n_layers=10,
        conv_channels=64,
        use_bias=True,
        dilation_factor=1,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"alpha": 0.2},
        initializer_seed=42,
        apply_sigmoid_at_last=False,
    )
    defaults.update(kwargs)
    return defaults


@pytest.mark.parametrize(
    "dict_g, dict_d",
    [
        ({}, {}),
        (
            {"kernel_size": 3, "aux_context_window": 5, "residual_channels": 128},
            {"dilation_factor": 2},
        ),
        ({"stacks": 4, "n_layers": 40}, {"conv_channels": 128}),
    ],
)
def test_melgan_trainable(dict_g, dict_d):
    """Build the generator and discriminator and run one forward pass through both.

    NOTE(review): the name says "melgan" but the body exercises Parallel
    WaveGAN - presumably a copy-paste leftover; confirm before renaming.
    """
    random_c = tf.random.uniform(shape=[4, 32, 80], dtype=tf.float32)

    args_g = make_pwgan_generator_args(**dict_g)
    args_d = make_pwgan_discriminator_args(**dict_d)

    args_g = ParallelWaveGANGeneratorConfig(**args_g)
    args_d = ParallelWaveGANDiscriminatorConfig(**args_d)

    generator = TFParallelWaveGANGenerator(args_g)
    generator._build()
    discriminator = TFParallelWaveGANDiscriminator(args_d)
    discriminator._build()

    # Feed the generator output straight into the discriminator.
    generated_audios = generator(random_c, training=True)
    discriminator(generated_audios)

    generator.summary()
    discriminator.summary()
# --------------------------------------------------------------------------------