├── .gitattributes
├── .github
│   ├── stale.yml
│   └── workflows
│       └── ci.yaml
├── .gitignore
├── LICENSE
├── README.md
├── docker-compose.yml
├── dockerfile
├── examples
│   ├── android
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── app
│   │   │   ├── .gitignore
│   │   │   ├── build.gradle
│   │   │   ├── proguard-rules.pro
│   │   │   └── src
│   │   │       ├── androidTest
│   │   │       │   └── java
│   │   │       │       └── com
│   │   │       │           └── tensorspeech
│   │   │       │               └── tensorflowtts
│   │   │       │                   └── ExampleInstrumentedTest.java
│   │   │       ├── main
│   │   │       │   ├── AndroidManifest.xml
│   │   │       │   ├── assets
│   │   │       │   │   ├── fastspeech2_quant.tflite
│   │   │       │   │   └── mbmelgan.tflite
│   │   │       │   ├── java
│   │   │       │   │   └── com
│   │   │       │   │       └── tensorspeech
│   │   │       │   │           └── tensorflowtts
│   │   │       │   │               ├── MainActivity.java
│   │   │       │   │               ├── dispatcher
│   │   │       │   │               │   ├── OnTtsStateListener.java
│   │   │       │   │               │   └── TtsStateDispatcher.java
│   │   │       │   │               ├── module
│   │   │       │   │               │   ├── AbstractModule.java
│   │   │       │   │               │   ├── FastSpeech2.java
│   │   │       │   │               │   └── MBMelGan.java
│   │   │       │   │               ├── tts
│   │   │       │   │               │   ├── InputWorker.java
│   │   │       │   │               │   ├── TtsManager.java
│   │   │       │   │               │   └── TtsPlayer.java
│   │   │       │   │               └── utils
│   │   │       │   │                   ├── NumberNorm.java
│   │   │       │   │                   ├── Processor.java
│   │   │       │   │                   └── ThreadPoolManager.java
│   │   │       │   └── res
│   │   │       │       ├── drawable-v24
│   │   │       │       │   └── ic_launcher_foreground.xml
│   │   │       │       ├── drawable
│   │   │       │       │   └── ic_launcher_background.xml
│   │   │       │       ├── layout
│   │   │       │       │   └── activity_main.xml
│   │   │       │       ├── mipmap-anydpi-v26
│   │   │       │       │   ├── ic_launcher.xml
│   │   │       │       │   └── ic_launcher_round.xml
│   │   │       │       ├── mipmap-hdpi
│   │   │       │       │   ├── ic_launcher.png
│   │   │       │       │   └── ic_launcher_round.png
│   │   │       │       ├── mipmap-mdpi
│   │   │       │       │   ├── ic_launcher.png
│   │   │       │       │   └── ic_launcher_round.png
│   │   │       │       ├── mipmap-xhdpi
│   │   │       │       │   ├── ic_launcher.png
│   │   │       │       │   └── ic_launcher_round.png
│   │   │       │       ├── mipmap-xxhdpi
│   │   │       │       │   ├── ic_launcher.png
│   │   │       │       │   └── ic_launcher_round.png
│   │   │       │       ├── mipmap-xxxhdpi
│   │   │       │       │   ├── ic_launcher.png
│   │   │       │       │   └── ic_launcher_round.png
│   │   │       │       └── values
│   │   │       │           ├── colors.xml
│   │   │       │           ├── strings.xml
│   │   │       │           └── styles.xml
│   │   │       └── test
│   │   │           └── java
│   │   │               └── com
│   │   │                   └── tensorspeech
│   │   │                       └── tensorflowtts
│   │   │                           └── ExampleUnitTest.java
│   │   ├── build.gradle
│   │   ├── gradle.properties
│   │   ├── gradle
│   │   │   └── wrapper
│   │   │       ├── gradle-wrapper.jar
│   │   │       └── gradle-wrapper.properties
│   │   ├── gradlew
│   │   ├── gradlew.bat
│   │   └── settings.gradle
│   ├── cpptflite
│   │   ├── .gitignore
│   │   ├── CMakeLists.txt
│   │   ├── README.md
│   │   ├── demo
│   │   │   ├── main.cpp
│   │   │   └── text2ids.py
│   │   ├── results
│   │   │   ├── lj_ori_mel.png
│   │   │   ├── lj_tflite_mel.png
│   │   │   ├── tflite_mel.png
│   │   │   └── tflite_mel2.png
│   │   └── src
│   │       ├── AudioFile.h
│   │       ├── MelGenerateTF.cpp
│   │       ├── MelGenerateTF.h
│   │       ├── TTSBackend.cpp
│   │       ├── TTSBackend.h
│   │       ├── TTSFrontend.cpp
│   │       ├── TTSFrontend.h
│   │       ├── TfliteBase.cpp
│   │       ├── TfliteBase.h
│   │       ├── VocoderTF.cpp
│   │       ├── VocoderTF.h
│   │       ├── VoxCommon.cpp
│   │       └── VoxCommon.h
│   ├── cppwin
│   │   ├── .gitattributes
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── TensorflowTTSCppInference.pro
│   │   ├── TensorflowTTSCppInference.sln
│   │   └── TensorflowTTSCppInference
│   │       ├── EnglishPhoneticProcessor.cpp
│   │       ├── EnglishPhoneticProcessor.h
│   │       ├── FastSpeech2.cpp
│   │       ├── FastSpeech2.h
│   │       ├── MultiBandMelGAN.cpp
│   │       ├── MultiBandMelGAN.h
│   │       ├── TensorflowTTSCppInference.cpp
│   │       ├── TensorflowTTSCppInference.vcxproj
│   │       ├── TensorflowTTSCppInference.vcxproj.filters
│   │       ├── TextTokenizer.cpp
│   │       ├── TextTokenizer.h
│   │       ├── Voice.cpp
│   │       ├── Voice.h
│   │       ├── VoxCommon.cpp
│   │       ├── VoxCommon.hpp
│   │       ├── ext
│   │       │   ├── AudioFile.hpp
│   │       │   ├── CppFlow
│   │       │   │   ├── include
│   │       │   │   │   ├── Model.h
│   │       │   │   │   └── Tensor.h
│   │       │   │   └── src
│   │       │   │       ├── Model.cpp
│   │       │   │       └── Tensor.cpp
│   │       │   ├── ZCharScanner.cpp
│   │       │   ├── ZCharScanner.h
│   │       │   ├── cxxopts.hpp
│   │       │   └── json.hpp
│   │       ├── phonemizer.cpp
│   │       ├── phonemizer.h
│   │       ├── tfg2p.cpp
│   │       └── tfg2p.h
│   ├── fastspeech
│   │   ├── README.md
│   │   ├── conf
│   │   │   ├── fastspeech.v1.yaml
│   │   │   └── fastspeech.v3.yaml
│   │   ├── decode_fastspeech.py
│   │   ├── fastspeech_dataset.py
│   │   ├── fig
│   │   │   └── fastspeech.v1.png
│   │   └── train_fastspeech.py
│   ├── fastspeech2
│   │   ├── README.md
│   │   ├── conf
│   │   │   ├── fastspeech2.baker.v2.yaml
│   │   │   ├── fastspeech2.jsut.v1.yaml
│   │   │   ├── fastspeech2.kss.v1.yaml
│   │   │   ├── fastspeech2.kss.v2.yaml
│   │   │   ├── fastspeech2.v1.yaml
│   │   │   └── fastspeech2.v2.yaml
│   │   ├── decode_fastspeech2.py
│   │   ├── extractfs_postnets.py
│   │   ├── fastspeech2_dataset.py
│   │   └── train_fastspeech2.py
│   ├── fastspeech2_libritts
│   │   ├── README.md
│   │   ├── conf
│   │   │   └── fastspeech2libritts.yaml
│   │   ├── fastspeech2_dataset.py
│   │   ├── libri_experiment
│   │   │   └── prepare_libri.ipynb
│   │   ├── scripts
│   │   │   ├── build.sh
│   │   │   ├── docker
│   │   │   │   └── Dockerfile
│   │   │   ├── interactive.sh
│   │   │   └── train_libri.sh
│   │   └── train_fastspeech2.py
│   ├── hifigan
│   │   ├── README.md
│   │   ├── conf
│   │   │   ├── hifigan.v1.yaml
│   │   │   └── hifigan.v2.yaml
│   │   └── train_hifigan.py
│   ├── ios
│   │   ├── .gitignore
│   │   ├── Podfile
│   │   ├── Podfile.lock
│   │   ├── README.md
│   │   ├── TF_TTS_Demo.xcodeproj
│   │   │   └── project.pbxproj
│   │   └── TF_TTS_Demo
│   │       ├── Assets.xcassets
│   │       │   ├── AccentColor.colorset
│   │       │   │   └── Contents.json
│   │       │   ├── AppIcon.appiconset
│   │       │   │   └── Contents.json
│   │       │   └── Contents.json
│   │       ├── ContentView.swift
│   │       ├── FastSpeech2.swift
│   │       ├── Info.plist
│   │       ├── MBMelGAN.swift
│   │       ├── Preview Content
│   │       │   └── Preview Assets.xcassets
│   │       │       └── Contents.json
│   │       ├── TF_TTS_DemoApp.swift
│   │       └── TTS.swift
│   ├── melgan
│   │   ├── README.md
│   │   ├── audio_mel_dataset.py
│   │   ├── conf
│   │   │   └── melgan.v1.yaml
│   │   ├── decode_melgan.py
│   │   ├── fig
│   │   │   └── melgan.v1.png
│   │   └── train_melgan.py
│   ├── melgan_stft
│   │   ├── README.md
│   │   ├── conf
│   │   │   └── melgan_stft.v1.yaml
│   │   ├── fig
│   │   │   ├── melgan.stft.v1.eval.png
│   │   │   └── melgan.stft.v1.train.png
│   │   └── train_melgan_stft.py
│   ├── mfa_extraction
│   │   ├── README.md
│   │   ├── fix_mismatch.py
│   │   ├── requirements.txt
│   │   ├── run_mfa.py
│   │   ├── scripts
│   │   │   └── prepare_mfa.sh
│   │   └── txt_grid_parser.py
│   ├── multiband_melgan
│   │   ├── README.md
│   │   ├── conf
│   │   │   ├── multiband_melgan.baker.v1.yaml
│   │   │   ├── multiband_melgan.synpaflex.v1.yaml
│   │   │   └── multiband_melgan.v1.yaml
│   │   ├── decode_mb_melgan.py
│   │   ├── fig
│   │   │   ├── eval.png
│   │   │   └── train.png
│   │   └── train_multiband_melgan.py
│   ├── multiband_melgan_hf
│   │   ├── README.md
│   │   ├── conf
│   │   │   ├── multiband_melgan_hf.lju.v1.yml
│   │   │   └── multiband_melgan_hf.lju.v1ft.yml
│   │   ├── decode_mb_melgan.py
│   │   ├── fig
│   │   │   ├── eval.png
│   │   │   └── train.png
│   │   └── train_multiband_melgan_hf.py
│   ├── parallel_wavegan
│   │   ├── README.md
│   │   ├── conf
│   │   │   └── parallel_wavegan.v1.yaml
│   │   ├── convert_pwgan_from_pytorch_to_tensorflow.ipynb
│   │   ├── decode_parallel_wavegan.py
│   │   └── train_parallel_wavegan.py
│   └── tacotron2
│       ├── README.md
│       ├── conf
│       │   ├── tacotron2.baker.v1.yaml
│       │   ├── tacotron2.jsut.v1.yaml
│       │   ├── tacotron2.kss.v1.yaml
│       │   ├── tacotron2.lju.v1.yaml
│       │   ├── tacotron2.synpaflex.v1.yaml
│       │   └── tacotron2.v1.yaml
│       ├── decode_tacotron2.py
│       ├── export_align.py
│       ├── extract_duration.py
│       ├── extract_postnets.py
│       ├── fig
│       │   ├── alignment.gif
│       │   └── tensorboard.png
│       ├── tacotron_dataset.py
│       └── train_tacotron2.py
├── notebooks
│   ├── Parallel_WaveGAN_TFLite.ipynb
│   ├── TensorFlowTTS_FastSpeech_with_TFLite.ipynb
│   ├── TensorFlowTTS_Tacotron2_with_TFLite.ipynb
│   ├── fastspeech2_inference.ipynb
│   ├── fastspeech_inference.ipynb
│   ├── griffin_lim_tensorflow.ipynb
│   ├── multiband_melgan_inference.ipynb
│   ├── prepare_synpaflex.ipynb
│   └── tacotron2_inference.ipynb
├── preprocess
│   ├── baker_preprocess.yaml
│   ├── jsut_preprocess.yaml
│   ├── kss_preprocess.yaml
│   ├── libritts_preprocess.yaml
│   ├── ljspeech_preprocess.yaml
│   ├── ljspeechu_preprocess.yaml
│   ├── synpaflex_preprocess.yaml
│   └── thorsten_preprocess.yaml
├── setup.cfg
├── setup.py
├── tensorflow_tts
│   ├── __init__.py
│   ├── bin
│   │   ├── __init__.py
│   │   └── preprocess.py
│   ├── configs
│   │   ├── __init__.py
│   │   ├── base_config.py
│   │   ├── fastspeech.py
│   │   ├── fastspeech2.py
│   │   ├── hifigan.py
│   │   ├── mb_melgan.py
│   │   ├── melgan.py
│   │   ├── parallel_wavegan.py
│   │   └── tacotron2.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── abstract_dataset.py
│   │   ├── audio_dataset.py
│   │   └── mel_dataset.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── auto_config.py
│   │   ├── auto_model.py
│   │   ├── auto_processor.py
│   │   └── savable_models.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── spectrogram.py
│   │   └── stft.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── fastspeech.py
│   │   ├── fastspeech2.py
│   │   ├── hifigan.py
│   │   ├── mb_melgan.py
│   │   ├── melgan.py
│   │   ├── parallel_wavegan.py
│   │   └── tacotron2.py
│   ├── optimizers
│   │   ├── __init__.py
│   │   ├── adamweightdecay.py
│   │   └── gradient_accumulate.py
│   ├── processor
│   │   ├── __init__.py
│   │   ├── baker.py
│   │   ├── base_processor.py
│   │   ├── jsut.py
│   │   ├── kss.py
│   │   ├── libritts.py
│   │   ├── ljspeech.py
│   │   ├── ljspeechu.py
│   │   ├── pretrained
│   │   │   ├── baker_mapper.json
│   │   │   ├── jsut_mapper.json
│   │   │   ├── kss_mapper.json
│   │   │   ├── libritts_mapper.json
│   │   │   ├── ljspeech_mapper.json
│   │   │   ├── ljspeechu_mapper.json
│   │   │   ├── synpaflex_mapper.json
│   │   │   └── thorsten_mapper.json
│   │   ├── synpaflex.py
│   │   └── thorsten.py
│   ├── trainers
│   │   ├── __init__.py
│   │   └── base_trainer.py
│   └── utils
│       ├── __init__.py
│       ├── cleaners.py
│       ├── decoder.py
│       ├── griffin_lim.py
│       ├── group_conv.py
│       ├── korean.py
│       ├── number_norm.py
│       ├── outliers.py
│       ├── strategy.py
│       ├── utils.py
│       └── weight_norm.py
└── test
    ├── files
    │   ├── baker_mapper.json
    │   ├── kss_mapper.json
    │   ├── libritts_mapper.json
    │   ├── ljspeech_mapper.json
    │   ├── mapper.json
    │   └── train.txt
    ├── test_auto.py
    ├── test_base_processor.py
    ├── test_fastspeech.py
    ├── test_fastspeech2.py
    ├── test_hifigan.py
    ├── test_mb_melgan.py
    ├── test_melgan.py
    ├── test_melgan_layers.py
    ├── test_parallel_wavegan.py
    └── test_tacotron2.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-language=Python
--------------------------------------------------------------------------------
/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # Number of days of inactivity before an issue becomes stale
2 | daysUntilStale: 60
3 | # Number of days of inactivity before a stale issue is closed
4 | daysUntilClose: 7
5 | # Issues with these labels will never be considered stale
6 | exemptLabels:
7 |   - pinned
8 |   - security
9 | # Label to use when marking an issue as stale
10 | staleLabel: wontfix
11 | # Comment to post when marking an issue as stale. Set to `false` to disable
12 | markComment: >
13 |   This issue has been automatically marked as stale because it has not had
14 |   recent activity. It will be closed if no further activity occurs.
15 | # Comment to post when closing a stale issue. Set to `false` to disable
16 | closeComment: false
17 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yaml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 |   push:
5 |     branches:
6 |       - master
7 |   pull_request:
8 |     branches:
9 |       - master
10 |   schedule:
11 |     - cron: 0 0 * * 1
12 |
13 | jobs:
14 |   linter_and_test:
15 |     runs-on: ubuntu-18.04
16 |     strategy:
17 |       max-parallel: 10
18 |       matrix:
19 |         python-version: [3.7]
20 |         tensorflow-version: [2.7.0]
21 |     steps:
22 |       - uses: actions/checkout@master
23 |       - uses: actions/setup-python@v1
24 |         with:
25 |           python-version: ${{ matrix.python-version }}
26 |           architecture: 'x64'
27 |       - uses: actions/cache@v1
28 |         with:
29 |           path: ~/.cache/pip
30 |           key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.tensorflow-version }}-pip-${{ hashFiles('**/setup.py') }}
31 |           restore-keys: |
32 |             ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.tensorflow-version }}-pip-
33 |       - name: Install dependencies
34 |         run: |
35 |           # install python modules
36 |           python -m pip install --upgrade pip
37 |           pip install -q -U numpy
38 |           pip install git+https://github.com/repodiac/german_transliterate.git#egg=german_transliterate
39 |           pip install -q tensorflow-gpu==${{ matrix.tensorflow-version }}
40 |           pip install -q -e .
41 |           pip install -q -e .[test]
42 |           pip install typing_extensions
43 |           sudo apt-get install libsndfile1-dev
44 |           python -m pip install black
45 |       - name: black
46 |         run: |
47 |           python -m black .
48 |       - name: Pytest
49 |         run: |
50 |           pytest test
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # general
3 | *~
4 | *.pyc
5 | \#*\#
6 | .\#*
7 | *DS_Store
8 | out.txt
9 | TensorFlowTTS.egg-info/
10 | doc/_build
11 | slurm-*.out
12 | tmp*
13 | .eggs/
14 | .hypothesis/
15 | .idea
16 | .backup/
17 | .pytest_cache/
18 | __pycache__/
19 | .coverage*
20 | coverage.xml*
21 | .vscode*
22 | .nfs*
23 | .ipynb_checkpoints
24 | ljspeech
25 | *.h5
26 | *.npy
27 | ./*.wav
28 | !docker-compose.yml
29 | /Pipfile
30 | /Pipfile.lock
31 | /datasets
32 | /examples/tacotron2/exp/
33 | /temp/
34 | LibriTTS/
35 | dataset/
36 | mfa/
37 | kss/
38 | baker/
39 | libritts/
40 | dump_baker/
41 | dump_ljspeech/
42 | dump_kss/
43 | dump_libritts/
44 | /notebooks/test_saved/
45 | build/
46 | dist/
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '2.3'
2 | services:
3 |   tensorflowtts:
4 |     build: .
5 |     volumes:
6 |       - .:/workspace
7 |     runtime: nvidia
8 |     tty: true
9 |     command: /bin/bash
10 |     environment:
11 |       - CUDA_VISIBLE_DEVICES
12 |
--------------------------------------------------------------------------------
/dockerfile:
--------------------------------------------------------------------------------
1 | FROM tensorflow/tensorflow:2.6.0-gpu
2 | RUN apt-get update
3 | RUN apt-get install -y zsh tmux wget git libsndfile1
4 | RUN pip install ipython && \
5 | pip install git+https://github.com/TensorSpeech/TensorflowTTS.git && \
6 | pip install git+https://github.com/repodiac/german_transliterate.git#egg=german_transliterate
7 | RUN mkdir /workspace
8 | WORKDIR /workspace
9 |
--------------------------------------------------------------------------------
/examples/android/.gitignore:
--------------------------------------------------------------------------------
1 | # Android Studio
2 | *.iml
3 | .gradle
4 | /local.properties
5 | /.idea
6 | .DS_Store
7 | /build
8 | /captures
9 |
10 | # Built application files
11 | *.apk
12 | !prebuiltapps/*.apk
13 | *.ap_
14 |
15 | # Files for the Dalvik VM
16 | *.dex
17 |
18 | # Java class files
19 | *.class
20 |
21 | # Generated files
22 | bin/
23 | gen/
24 |
25 | # Gradle files
26 | .gradle/
27 | build/
28 | */build/
29 |
30 | # Local configuration file (sdk path, etc)
31 | local.properties
32 |
33 | # Proguard folder generated by Eclipse
34 | proguard/
35 |
36 | # Log Files
37 | *.log
38 |
39 | # project
40 | project.properties
41 | .classpath
42 | .project
43 | .settings/
44 |
45 | # Intellij project files
46 | *.ipr
47 | *.iws
48 | .idea/
49 | app/.gradle/
50 | .idea/libraries
51 | .idea/workspace.xml
52 | .idea/vcs.xml
53 | .idea/scopes/scope_setting.xml
54 | .idea/modules.xml
55 | .idea/misc.xml
56 | .idea/inspectionProfiles/Project_Default.xml
57 | .idea/inspectionProfiles/profiles_setting.xml
58 | .idea/encodings.xml
59 | .idea/.name
60 |
--------------------------------------------------------------------------------
/examples/android/README.md:
--------------------------------------------------------------------------------
1 | ### Android Demo
2 |
3 | This is a simple Android demo that loads the converted FastSpeech2 and Multi-Band MelGAN TFLite modules to synthesize audio.
4 | To optimize synthesis speed, the demo pipelines synthesis and playback through two LinkedBlockingQueues, sketched below.
5 |
6 |
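The two-queue idea in a nutshell, as a minimal hypothetical sketch (class and method names here are illustrative, not the demo's actual API; the real implementation lives in `InputWorker` and `TtsPlayer`):

```java
import java.util.concurrent.LinkedBlockingQueue;

// Producer/consumer pipeline: one thread turns text fragments into PCM,
// another drains finished PCM into the audio device. Both queues block,
// so synthesis of the next fragment overlaps playback of the previous one.
public class PipelineSketch {
    private final LinkedBlockingQueue<String> textQueue = new LinkedBlockingQueue<>();
    private final LinkedBlockingQueue<float[]> audioQueue = new LinkedBlockingQueue<>();

    void start() {
        new Thread(() -> {                    // synthesis thread
            try {
                while (true) {
                    String fragment = textQueue.take();    // blocks until text arrives
                    audioQueue.put(synthesize(fragment));  // FastSpeech2 -> MB-MelGAN
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }).start();

        new Thread(() -> {                    // playback thread
            try {
                while (true) {
                    play(audioQueue.take());               // blocks until PCM is ready
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }).start();
    }

    void speak(String text) throws InterruptedException {
        textQueue.put(text);
    }

    // Placeholders standing in for the TFLite modules and AudioTrack.
    private float[] synthesize(String text) { return new float[0]; }
    private void play(float[] pcm) {}
}
```

Because both hand-offs block, the UI thread never waits on inference, and playback starts as soon as the first fragment is synthesized rather than after the whole input has been processed.
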
7 | ### HOW-TO
8 | 1. Import this project into Android Studio.
9 | 2. Run the app!
10 |
11 | ### LICENSE
12 | The license used for this code is [CC BY-NC 3.0](https://creativecommons.org/licenses/by-nc/3.0/). Please read the license carefully before you use it.
13 |
14 | ### Contributors
15 | [Xuefeng Ding](https://github.com/mapledxf)
16 |
--------------------------------------------------------------------------------
/examples/android/app/.gitignore:
--------------------------------------------------------------------------------
1 | /build
2 |
--------------------------------------------------------------------------------
/examples/android/app/build.gradle:
--------------------------------------------------------------------------------
1 | apply plugin: 'com.android.application'
2 |
3 | android {
4 | compileSdkVersion 29
5 | buildToolsVersion "29.0.2"
6 | defaultConfig {
7 | applicationId "com.tensorspeech.tensorflowtts"
8 | minSdkVersion 21
9 | targetSdkVersion 29
10 | versionCode 1
11 | versionName "1.0"
12 | }
13 | buildTypes {
14 | release {
15 | minifyEnabled false
16 | proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
17 | }
18 | }
19 | aaptOptions {
20 | noCompress "tflite"
21 | }
22 | compileOptions {
23 | sourceCompatibility = '1.8'
24 | targetCompatibility = '1.8'
25 | }
26 | lintOptions {
27 | abortOnError false
28 | }
29 | }
30 |
31 | dependencies {
32 | implementation fileTree(dir: 'libs', include: ['*.jar'])
33 | implementation 'androidx.appcompat:appcompat:1.1.0'
34 | implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
35 |
36 | implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
37 | implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:0.0.0-nightly'
38 | implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
39 | }
40 |
--------------------------------------------------------------------------------
/examples/android/app/proguard-rules.pro:
--------------------------------------------------------------------------------
1 | # Add project specific ProGuard rules here.
2 | # You can control the set of applied configuration files using the
3 | # proguardFiles setting in build.gradle.
4 | #
5 | # For more details, see
6 | # http://developer.android.com/guide/developing/tools/proguard.html
7 |
8 | # If your project uses WebView with JS, uncomment the following
9 | # and specify the fully qualified class name to the JavaScript interface
10 | # class:
11 | #-keepclassmembers class fqcn.of.javascript.interface.for.webview {
12 | # public *;
13 | #}
14 |
15 | # Uncomment this to preserve the line number information for
16 | # debugging stack traces.
17 | #-keepattributes SourceFile,LineNumberTable
18 |
19 | # If you keep the line number information, uncomment this to
20 | # hide the original source file name.
21 | #-renamesourcefileattribute SourceFile
22 |
--------------------------------------------------------------------------------
/examples/android/app/src/androidTest/java/com/tensorspeech/tensorflowtts/ExampleInstrumentedTest.java:
--------------------------------------------------------------------------------
1 | package com.tensorspeech.tensorflowtts;
2 |
3 | import android.content.Context;
4 |
5 | import androidx.test.platform.app.InstrumentationRegistry;
6 | import androidx.test.ext.junit.runners.AndroidJUnit4;
7 |
8 | import org.junit.Test;
9 | import org.junit.runner.RunWith;
10 |
11 | import static org.junit.Assert.*;
12 |
13 | /**
14 | * Instrumented test, which will execute on an Android device.
15 | *
16 | * @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
17 | */
18 | @RunWith(AndroidJUnit4.class)
19 | public class ExampleInstrumentedTest {
20 | @Test
21 | public void useAppContext() {
22 | // Context of the app under test.
23 | Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
24 |
25 | assertEquals("com.tensorspeech.tensorflowtts", appContext.getPackageName());
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/examples/android/app/src/main/AndroidManifest.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="utf-8"?>
2 | <manifest xmlns:android="http://schemas.android.com/apk/res/android"
3 |     package="com.tensorspeech.tensorflowtts">
4 |
5 |     <application
6 |         android:allowBackup="true"
7 |         android:icon="@mipmap/ic_launcher"
8 |         android:label="@string/app_name"
9 |         android:roundIcon="@mipmap/ic_launcher_round"
10 |         android:supportsRtl="true"
11 |         android:theme="@style/AppTheme">
12 |         <activity android:name=".MainActivity">
13 |             <intent-filter>
14 |                 <action android:name="android.intent.action.MAIN" />
15 |                 <category android:name="android.intent.category.LAUNCHER" />
16 |             </intent-filter>
17 |         </activity>
18 |     </application>
19 |
20 | </manifest>
21 |
--------------------------------------------------------------------------------
/examples/android/app/src/main/assets/fastspeech2_quant.tflite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/assets/fastspeech2_quant.tflite
--------------------------------------------------------------------------------
/examples/android/app/src/main/assets/mbmelgan.tflite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/assets/mbmelgan.tflite
--------------------------------------------------------------------------------
/examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/MainActivity.java:
--------------------------------------------------------------------------------
1 | package com.tensorspeech.tensorflowtts;
2 |
3 | import android.os.Bundle;
4 | import android.text.TextUtils;
5 | import android.view.View;
6 | import android.widget.EditText;
7 | import android.widget.RadioGroup;
8 |
9 | import androidx.appcompat.app.AppCompatActivity;
10 |
11 | import com.tensorspeech.tensorflowtts.dispatcher.OnTtsStateListener;
12 | import com.tensorspeech.tensorflowtts.dispatcher.TtsStateDispatcher;
13 | import com.tensorspeech.tensorflowtts.tts.TtsManager;
14 | import com.tensorspeech.tensorflowtts.utils.ThreadPoolManager;
15 |
16 | /**
17 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"}
18 | * Created 2020-07-20 17:25
19 | */
20 | public class MainActivity extends AppCompatActivity {
21 | private static final String DEFAULT_INPUT_TEXT = "Unless you work on a ship, it's unlikely that you use the word boatswain in everyday conversation, so it's understandably a tricky one. The word - which refers to a petty officer in charge of hull maintenance is not pronounced boats-wain Rather, it's bo-sun to reflect the salty pronunciation of sailors, as The Free Dictionary explains./Blue opinion poll conducted for the National Post.";
22 |
23 | private View speakBtn;
24 | private RadioGroup speedGroup;
25 |
26 | @Override
27 | protected void onCreate(Bundle savedInstanceState) {
28 | super.onCreate(savedInstanceState);
29 | setContentView(R.layout.activity_main);
30 |
31 | TtsManager.getInstance().init(this);
32 |
33 | TtsStateDispatcher.getInstance().addListener(new OnTtsStateListener() {
34 | @Override
35 | public void onTtsReady() {
36 | speakBtn.setEnabled(true);
37 | }
38 |
39 | @Override
40 | public void onTtsStart(String text) {
41 | }
42 |
43 | @Override
44 | public void onTtsStop() {
45 | }
46 | });
47 |
48 | EditText input = findViewById(R.id.input);
49 | input.setHint(DEFAULT_INPUT_TEXT);
50 |
51 | speedGroup = findViewById(R.id.speed_chooser);
52 | speedGroup.check(R.id.normal);
53 |
54 | speakBtn = findViewById(R.id.start);
55 | speakBtn.setEnabled(false);
56 | speakBtn.setOnClickListener(v ->
57 | ThreadPoolManager.getInstance().execute(() -> {
58 | float speed;
59 | switch (speedGroup.getCheckedRadioButtonId()) {
60 | case R.id.fast:
61 | speed = 0.8F;
62 | break;
63 | case R.id.slow:
64 | speed = 1.2F;
65 | break;
66 | case R.id.normal:
67 | default:
68 | speed = 1.0F;
69 | break;
70 | }
71 |
72 | String inputText = input.getText().toString();
73 | if (TextUtils.isEmpty(inputText)) {
74 | inputText = DEFAULT_INPUT_TEXT;
75 | }
76 | TtsManager.getInstance().speak(inputText, speed, true);
77 | }));
78 |
79 | findViewById(R.id.stop).setOnClickListener(v ->
80 | TtsManager.getInstance().stopTts());
81 | }
82 | }
83 |
--------------------------------------------------------------------------------
/examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/dispatcher/OnTtsStateListener.java:
--------------------------------------------------------------------------------
1 | package com.tensorspeech.tensorflowtts.dispatcher;
2 |
3 | /**
4 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"}
5 | * Created 2020-07-28 14:25
6 | */
7 | public interface OnTtsStateListener {
8 | public void onTtsReady();
9 |
10 | public void onTtsStart(String text);
11 |
12 | public void onTtsStop();
13 | }
--------------------------------------------------------------------------------
/examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/dispatcher/TtsStateDispatcher.java:
--------------------------------------------------------------------------------
1 | package com.tensorspeech.tensorflowtts.dispatcher;
2 |
3 | import android.os.Handler;
4 | import android.os.Looper;
5 | import android.util.Log;
6 |
7 | import java.util.concurrent.CopyOnWriteArrayList;
8 |
9 | /**
10 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"}
11 | * Created 2020-07-28 14:25
12 | */
13 | public class TtsStateDispatcher {
14 | private static final String TAG = "TtsStateDispatcher";
15 | private static volatile TtsStateDispatcher instance;
16 | private static final Object INSTANCE_WRITE_LOCK = new Object();
17 |
18 | public static TtsStateDispatcher getInstance() {
19 | if (instance == null) {
20 | synchronized (INSTANCE_WRITE_LOCK) {
21 | if (instance == null) {
22 | instance = new TtsStateDispatcher();
23 | }
24 | }
25 | }
26 | return instance;
27 | }
28 |
29 | private final Handler handler = new Handler(Looper.getMainLooper());
30 |
31 | private CopyOnWriteArrayList<OnTtsStateListener> mListeners = new CopyOnWriteArrayList<>();
32 |
33 | public void release() {
34 | Log.d(TAG, "release: ");
35 | mListeners.clear();
36 | }
37 |
38 | public void addListener(OnTtsStateListener listener) {
39 | if (mListeners.contains(listener)) {
40 | return;
41 | }
42 | Log.d(TAG, "addListener: " + listener.getClass());
43 | mListeners.add(listener);
44 | }
45 |
46 | public void removeListener(OnTtsStateListener listener) {
47 | if (mListeners.contains(listener)) {
48 | Log.d(TAG, "removeListener: " + listener.getClass());
49 | mListeners.remove(listener);
50 | }
51 | }
52 |
53 | public void onTtsStart(String text){
54 | Log.d(TAG, "onTtsStart: ");
55 | if (!mListeners.isEmpty()) {
56 | for (OnTtsStateListener listener : mListeners) {
57 | handler.post(() -> listener.onTtsStart(text));
58 | }
59 | }
60 | }
61 |
62 | public void onTtsStop(){
63 | Log.d(TAG, "onTtsStop: ");
64 | if (!mListeners.isEmpty()) {
65 | for (OnTtsStateListener listener : mListeners) {
66 | handler.post(listener::onTtsStop);
67 | }
68 | }
69 | }
70 |
71 | public void onTtsReady(){
72 | Log.d(TAG, "onTtsReady: ");
73 | if (!mListeners.isEmpty()) {
74 | for (OnTtsStateListener listener : mListeners) {
75 | handler.post(listener::onTtsReady);
76 | }
77 | }
78 | }
79 | }
80 |
--------------------------------------------------------------------------------
/examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/module/AbstractModule.java:
--------------------------------------------------------------------------------
1 | package com.tensorspeech.tensorflowtts.module;
2 |
3 | import org.tensorflow.lite.Interpreter;
4 |
5 | /**
6 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"}
7 | * Created 2020-07-20 17:25
8 | *
9 | */
10 | abstract class AbstractModule {
11 |
12 | Interpreter.Options getOption() {
13 | Interpreter.Options options = new Interpreter.Options();
14 | options.setNumThreads(5);
15 | return options;
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/module/FastSpeech2.java:
--------------------------------------------------------------------------------
1 | package com.tensorspeech.tensorflowtts.module;
2 |
3 | import android.annotation.SuppressLint;
4 | import android.util.Log;
5 |
6 | import org.tensorflow.lite.DataType;
7 | import org.tensorflow.lite.Interpreter;
8 | import org.tensorflow.lite.Tensor;
9 | import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
10 |
11 | import java.io.File;
12 | import java.nio.FloatBuffer;
13 | import java.util.Arrays;
14 | import java.util.HashMap;
15 | import java.util.Map;
16 |
17 | /**
18 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"}
19 | * Created 2020-07-20 17:26
20 | *
21 | */
22 | public class FastSpeech2 extends AbstractModule {
23 | private static final String TAG = "FastSpeech2";
24 | private Interpreter mModule;
25 |
26 | public FastSpeech2(String modulePath) {
27 | try {
28 | mModule = new Interpreter(new File(modulePath), getOption());
29 | int input = mModule.getInputTensorCount();
30 | for (int i = 0; i < input; i++) {
31 | Tensor inputTensor = mModule.getInputTensor(i);
32 | Log.d(TAG, "input:" + i +
33 | " name:" + inputTensor.name() +
34 | " shape:" + Arrays.toString(inputTensor.shape()) +
35 | " dtype:" + inputTensor.dataType());
36 | }
37 |
38 | int output = mModule.getOutputTensorCount();
39 | for (int i = 0; i < output; i++) {
40 | Tensor outputTensor = mModule.getOutputTensor(i);
41 | Log.d(TAG, "output:" + i +
42 | " name:" + outputTensor.name() +
43 | " shape:" + Arrays.toString(outputTensor.shape()) +
44 | " dtype:" + outputTensor.dataType());
45 | }
46 | Log.d(TAG, "successfully init");
47 | } catch (Exception e) {
48 | e.printStackTrace();
49 | }
50 | }
51 |
52 | public TensorBuffer getMelSpectrogram(int[] inputIds, float speed) {
53 | Log.d(TAG, "input id length: " + inputIds.length);
54 | mModule.resizeInput(0, new int[]{1, inputIds.length});
55 | mModule.allocateTensors();
56 |
57 | @SuppressLint("UseSparseArrays")
58 | Map<Integer, Object> outputMap = new HashMap<>();
59 |
60 | FloatBuffer outputBuffer = FloatBuffer.allocate(350000);
61 | outputMap.put(0, outputBuffer);
62 |
63 | int[][] inputs = new int[1][inputIds.length];
64 | inputs[0] = inputIds;
65 |
66 | long time = System.currentTimeMillis();
67 | mModule.runForMultipleInputsOutputs(
68 | new Object[]{inputs, new int[1][1], new int[]{0}, new float[]{speed}, new float[]{1F}, new float[]{1F}},
69 | outputMap);
70 | Log.d(TAG, "time cost: " + (System.currentTimeMillis() - time));
71 |
72 | int size = mModule.getOutputTensor(0).shape()[2];
73 | int[] shape = {1, outputBuffer.position() / size, size};
74 | TensorBuffer spectrogram = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
75 | float[] outputArray = new float[outputBuffer.position()];
76 | outputBuffer.rewind();
77 | outputBuffer.get(outputArray);
78 | spectrogram.loadArray(outputArray);
79 |
80 | return spectrogram;
81 | }
82 | }
83 |
--------------------------------------------------------------------------------
/examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/module/MBMelGan.java:
--------------------------------------------------------------------------------
1 | package com.tensorspeech.tensorflowtts.module;
2 |
3 | import android.util.Log;
4 |
5 | import org.tensorflow.lite.Interpreter;
6 | import org.tensorflow.lite.Tensor;
7 | import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
8 |
9 | import java.io.File;
10 | import java.nio.FloatBuffer;
11 | import java.util.Arrays;
12 |
13 | /**
14 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"}
15 | * Created 2020-07-20 17:26
16 | *
17 | */
18 | public class MBMelGan extends AbstractModule {
19 | private static final String TAG = "MBMelGan";
20 | private Interpreter mModule;
21 |
22 | public MBMelGan(String modulePath) {
23 | try {
24 | mModule = new Interpreter(new File(modulePath), getOption());
25 | int input = mModule.getInputTensorCount();
26 | for (int i = 0; i < input; i++) {
27 | Tensor inputTensor = mModule.getInputTensor(i);
28 | Log.d(TAG, "input:" + i
29 | + " name:" + inputTensor.name()
30 | + " shape:" + Arrays.toString(inputTensor.shape()) +
31 | " dtype:" + inputTensor.dataType());
32 | }
33 |
34 | int output = mModule.getOutputTensorCount();
35 | for (int i = 0; i < output; i++) {
36 | Tensor outputTensor = mModule.getOutputTensor(i);
37 | Log.d(TAG, "output:" + i
38 | + " name:" + outputTensor.name()
39 | + " shape:" + Arrays.toString(outputTensor.shape())
40 | + " dtype:" + outputTensor.dataType());
41 | }
42 | Log.d(TAG, "successfully init");
43 | } catch (Exception e) {
44 | e.printStackTrace();
45 | }
46 | }
47 |
48 |
49 | public float[] getAudio(TensorBuffer input) {
50 | mModule.resizeInput(0, input.getShape());
51 | mModule.allocateTensors();
52 |
53 | FloatBuffer outputBuffer = FloatBuffer.allocate(350000);
54 |
55 | long time = System.currentTimeMillis();
56 | mModule.run(input.getBuffer(), outputBuffer);
57 | Log.d(TAG, "time cost: " + (System.currentTimeMillis() - time));
58 |
59 | float[] audioArray = new float[outputBuffer.position()];
60 | outputBuffer.rewind();
61 | outputBuffer.get(audioArray);
62 | return audioArray;
63 | }
64 | }
65 |
--------------------------------------------------------------------------------
/examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/tts/TtsManager.java:
--------------------------------------------------------------------------------
1 | package com.tensorspeech.tensorflowtts.tts;
2 |
3 | import android.content.Context;
4 | import android.util.Log;
5 |
6 | import com.tensorspeech.tensorflowtts.dispatcher.TtsStateDispatcher;
7 | import com.tensorspeech.tensorflowtts.utils.ThreadPoolManager;
8 |
9 | import java.io.File;
10 | import java.io.FileOutputStream;
11 | import java.io.InputStream;
12 | import java.io.OutputStream;
13 |
14 | /**
15 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"}
16 | * Created 2020-07-28 14:25
17 | */
18 | public class TtsManager {
19 | private static final String TAG = "TtsManager";
20 |
21 | private static final Object INSTANCE_WRITE_LOCK = new Object();
22 |
23 | private static volatile TtsManager instance;
24 |
25 | public static TtsManager getInstance() {
26 | if (instance == null) {
27 | synchronized (INSTANCE_WRITE_LOCK) {
28 | if (instance == null) {
29 | instance = new TtsManager();
30 | }
31 | }
32 | }
33 | return instance;
34 | }
35 |
36 | private InputWorker mWorker;
37 |
38 | private final static String FASTSPEECH2_MODULE = "fastspeech2_quant.tflite";
39 | private final static String MELGAN_MODULE = "mbmelgan.tflite";
40 |
41 | public void init(Context context) {
42 | ThreadPoolManager.getInstance().getSingleExecutor("init").execute(() -> {
43 | try {
44 | String fastspeech = copyFile(context, FASTSPEECH2_MODULE);
45 | String vocoder = copyFile(context, MELGAN_MODULE);
46 | mWorker = new InputWorker(fastspeech, vocoder);
47 | } catch (Exception e) {
48 | Log.e(TAG, "mWorker init failed", e);
49 | }
50 |
51 | TtsStateDispatcher.getInstance().onTtsReady();
52 | });
53 | }
54 |
55 | private String copyFile(Context context, String strOutFileName) {
56 | Log.d(TAG, "start copy file " + strOutFileName);
57 | File file = context.getFilesDir();
58 |
59 | String tmpFile = file.getAbsolutePath() + "/" + strOutFileName;
60 | File f = new File(tmpFile);
61 | if (f.exists()) {
62 | Log.d(TAG, "file exists " + strOutFileName);
63 | return f.getAbsolutePath();
64 | }
65 |
66 | try (OutputStream myOutput = new FileOutputStream(f);
67 | InputStream myInput = context.getAssets().open(strOutFileName)) {
68 | byte[] buffer = new byte[1024];
69 | int length = myInput.read(buffer);
70 | while (length > 0) {
71 | myOutput.write(buffer, 0, length);
72 | length = myInput.read(buffer);
73 | }
74 | myOutput.flush();
75 | Log.d(TAG, "Copy task successful");
76 | } catch (Exception e) {
77 | Log.e(TAG, "copyFile: Failed to copy", e);
78 | } finally {
79 | Log.d(TAG, "end copy file " + strOutFileName);
80 | }
81 | return f.getAbsolutePath();
82 | }
83 |
84 | public void stopTts() {
85 | mWorker.interrupt();
86 | }
87 |
88 | public void speak(String inputText, float speed, boolean interrupt) {
89 | if (interrupt) {
90 | stopTts();
91 | }
92 |
93 | ThreadPoolManager.getInstance().execute(() ->
94 | mWorker.processInput(inputText, speed));
95 | }
96 |
97 | }
98 |
--------------------------------------------------------------------------------
/examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/tts/TtsPlayer.java:
--------------------------------------------------------------------------------
1 | package com.tensorspeech.tensorflowtts.tts;
2 |
3 | import android.media.AudioAttributes;
4 | import android.media.AudioFormat;
5 | import android.media.AudioManager;
6 | import android.media.AudioTrack;
7 | import android.util.Log;
8 |
9 | import com.tensorspeech.tensorflowtts.utils.ThreadPoolManager;
10 |
11 | import java.util.concurrent.LinkedBlockingQueue;
12 |
13 | /**
14 | * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"}
15 | * Created 2020-07-20 18:22
16 | */
17 | class TtsPlayer {
18 | private static final String TAG = "TtsPlayer";
19 |
20 | private final AudioTrack mAudioTrack;
21 |
22 | private final static int FORMAT = AudioFormat.ENCODING_PCM_FLOAT;
23 | private final static int SAMPLERATE = 22050;
24 | private final static int CHANNEL = AudioFormat.CHANNEL_OUT_MONO;
25 | private final static int BUFFER_SIZE = AudioTrack.getMinBufferSize(SAMPLERATE, CHANNEL, FORMAT);
26 | private LinkedBlockingQueue<AudioData> mAudioQueue = new LinkedBlockingQueue<>();
27 | private AudioData mCurrentAudioData;
28 |
29 | TtsPlayer() {
30 | mAudioTrack = new AudioTrack(
31 | new AudioAttributes.Builder()
32 | .setUsage(AudioAttributes.USAGE_MEDIA)
33 | .setContentType(AudioAttributes.CONTENT_TYPE_MUSIC)
34 | .build(),
35 | new AudioFormat.Builder()
36 | .setSampleRate(22050)
37 | .setEncoding(FORMAT)
38 | .setChannelMask(CHANNEL)
39 | .build(),
40 | BUFFER_SIZE,
41 | AudioTrack.MODE_STREAM, AudioManager.AUDIO_SESSION_ID_GENERATE
42 | );
43 | mAudioTrack.play();
44 |
45 | ThreadPoolManager.getInstance().getSingleExecutor("audio").execute(() -> {
46 | //noinspection InfiniteLoopStatement
47 | while (true) {
48 | try {
49 | mCurrentAudioData = mAudioQueue.take();
50 | Log.d(TAG, "playing: " + mCurrentAudioData.text);
51 | int index = 0;
52 | while (index < mCurrentAudioData.audio.length && !mCurrentAudioData.isInterrupt) {
53 | int buffer = Math.min(BUFFER_SIZE, mCurrentAudioData.audio.length - index);
54 | mAudioTrack.write(mCurrentAudioData.audio, index, buffer, AudioTrack.WRITE_BLOCKING);
55 | index += BUFFER_SIZE;
56 | }
57 | } catch (Exception e) {
58 | Log.e(TAG, "Exception: ", e);
59 | }
60 | }
61 | });
62 | }
63 |
64 | void play(AudioData audioData) {
65 | Log.d(TAG, "add audio data to queue: " + audioData.text);
66 | mAudioQueue.offer(audioData);
67 | }
68 |
69 | void interrupt() {
70 | mAudioQueue.clear();
71 | if (mCurrentAudioData != null) {
72 | mCurrentAudioData.interrupt();
73 | }
74 | }
75 |
76 | static class AudioData {
77 | private String text;
78 | private float[] audio;
79 | private boolean isInterrupt;
80 |
81 | AudioData(String text, float[] audio) {
82 | this.text = text;
83 | this.audio = audio;
84 | }
85 |
86 | private void interrupt() {
87 | isInterrupt = true;
88 | }
89 | }
90 |
91 | }
92 |
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml:
--------------------------------------------------------------------------------
1 |
7 |
12 |
13 |
19 |
22 |
25 |
26 |
27 |
28 |
34 |
35 |
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/layout/activity_main.xml:
--------------------------------------------------------------------------------
1 |
2 |
9 |
10 |
17 |
18 |
23 |
24 |
28 |
29 |
34 |
35 |
40 |
41 |
46 |
47 |
52 |
53 |
54 |
55 |
56 |
61 |
62 |
68 |
69 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="utf-8"?>
2 | <adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
3 |     <background android:drawable="@drawable/ic_launcher_background" />
4 |     <foreground android:drawable="@drawable/ic_launcher_foreground" />
5 | </adaptive-icon>
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="utf-8"?>
2 | <adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
3 |     <background android:drawable="@drawable/ic_launcher_background" />
4 |     <foreground android:drawable="@drawable/ic_launcher_foreground" />
5 | </adaptive-icon>
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/mipmap-hdpi/ic_launcher.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-hdpi/ic_launcher.png
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.png
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/mipmap-mdpi/ic_launcher.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-mdpi/ic_launcher.png
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.png
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/mipmap-xhdpi/ic_launcher.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-xhdpi/ic_launcher.png
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.png
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/values/colors.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="utf-8"?>
2 | <resources>
3 |     <color name="colorPrimary">#008577</color>
4 |     <color name="colorPrimaryDark">#00574B</color>
5 |     <color name="colorAccent">#D81B60</color>
6 | </resources>
7 |
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/values/strings.xml:
--------------------------------------------------------------------------------
1 | <resources>
2 |     <string name="app_name">TensorflowTTS</string>
3 | </resources>
4 |
--------------------------------------------------------------------------------
/examples/android/app/src/main/res/values/styles.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="utf-8"?>
2 | <resources>
3 |
4 |     <!-- Base application theme. -->
5 |     <style name="AppTheme" parent="Theme.AppCompat.Light.DarkActionBar">
6 |         <!-- Customize your theme here. -->
7 |         <item name="colorPrimary">@color/colorPrimary</item>
8 |         <item name="colorPrimaryDark">@color/colorPrimaryDark</item>
9 |         <item name="colorAccent">@color/colorAccent</item>
10 |     </style>
11 |
12 | </resources>
--------------------------------------------------------------------------------
/examples/android/app/src/test/java/com/tensorspeech/tensorflowtts/ExampleUnitTest.java:
--------------------------------------------------------------------------------
1 | package com.tensorspeech.tensorflowtts;
2 |
3 | import org.junit.Test;
4 |
5 | import static org.junit.Assert.*;
6 |
7 | /**
8 | * Example local unit test, which will execute on the development machine (host).
9 | *
10 | * @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
11 | */
12 | public class ExampleUnitTest {
13 | @Test
14 | public void addition_isCorrect() {
15 | assertEquals(4, 2 + 2);
16 | }
17 | }
--------------------------------------------------------------------------------
/examples/android/build.gradle:
--------------------------------------------------------------------------------
1 | // Top-level build file where you can add configuration options common to all sub-projects/modules.
2 |
3 | buildscript {
4 | repositories {
5 | google()
6 | jcenter()
7 |
8 | }
9 | dependencies {
10 | classpath 'com.android.tools.build:gradle:3.5.2'
11 |
12 | // NOTE: Do not place your application dependencies here; they belong
13 | // in the individual module build.gradle files
14 | }
15 | }
16 |
17 | allprojects {
18 | repositories {
19 | google()
20 | jcenter()
21 |
22 | }
23 | }
24 |
25 | task clean(type: Delete) {
26 | delete rootProject.buildDir
27 | }
28 |
--------------------------------------------------------------------------------
/examples/android/gradle.properties:
--------------------------------------------------------------------------------
1 | # Project-wide Gradle settings.
2 | # IDE (e.g. Android Studio) users:
3 | # Gradle settings configured through the IDE *will override*
4 | # any settings specified in this file.
5 | # For more details on how to configure your build environment visit
6 | # http://www.gradle.org/docs/current/userguide/build_environment.html
7 | # Specifies the JVM arguments used for the daemon process.
8 | # The setting is particularly useful for tweaking memory settings.
9 | org.gradle.jvmargs=-Xmx1536m
10 | # When configured, Gradle will run in incubating parallel mode.
11 | # This option should only be used with decoupled projects. More details, visit
12 | # http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
13 | # org.gradle.parallel=true
14 | # AndroidX package structure to make it clearer which packages are bundled with the
15 | # Android operating system, and which are packaged with your app's APK
16 | # https://developer.android.com/topic/libraries/support-library/androidx-rn
17 | android.useAndroidX=true
18 | # Automatically convert third-party libraries to use AndroidX
19 | android.enableJetifier=true
20 |
21 |
--------------------------------------------------------------------------------
/examples/android/gradle/wrapper/gradle-wrapper.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/android/gradle/wrapper/gradle-wrapper.jar
--------------------------------------------------------------------------------
/examples/android/gradle/wrapper/gradle-wrapper.properties:
--------------------------------------------------------------------------------
1 | #Mon Jul 20 11:21:10 CST 2020
2 | distributionBase=GRADLE_USER_HOME
3 | distributionPath=wrapper/dists
4 | zipStoreBase=GRADLE_USER_HOME
5 | zipStorePath=wrapper/dists
6 | distributionUrl=https\://services.gradle.org/distributions/gradle-5.4.1-all.zip
7 |
--------------------------------------------------------------------------------
/examples/android/gradlew.bat:
--------------------------------------------------------------------------------
1 | @if "%DEBUG%" == "" @echo off
2 | @rem ##########################################################################
3 | @rem
4 | @rem Gradle startup script for Windows
5 | @rem
6 | @rem ##########################################################################
7 |
8 | @rem Set local scope for the variables with windows NT shell
9 | if "%OS%"=="Windows_NT" setlocal
10 |
11 | set DIRNAME=%~dp0
12 | if "%DIRNAME%" == "" set DIRNAME=.
13 | set APP_BASE_NAME=%~n0
14 | set APP_HOME=%DIRNAME%
15 |
16 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
17 | set DEFAULT_JVM_OPTS=
18 |
19 | @rem Find java.exe
20 | if defined JAVA_HOME goto findJavaFromJavaHome
21 |
22 | set JAVA_EXE=java.exe
23 | %JAVA_EXE% -version >NUL 2>&1
24 | if "%ERRORLEVEL%" == "0" goto init
25 |
26 | echo.
27 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
28 | echo.
29 | echo Please set the JAVA_HOME variable in your environment to match the
30 | echo location of your Java installation.
31 |
32 | goto fail
33 |
34 | :findJavaFromJavaHome
35 | set JAVA_HOME=%JAVA_HOME:"=%
36 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe
37 |
38 | if exist "%JAVA_EXE%" goto init
39 |
40 | echo.
41 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
42 | echo.
43 | echo Please set the JAVA_HOME variable in your environment to match the
44 | echo location of your Java installation.
45 |
46 | goto fail
47 |
48 | :init
49 | @rem Get command-line arguments, handling Windows variants
50 |
51 | if not "%OS%" == "Windows_NT" goto win9xME_args
52 |
53 | :win9xME_args
54 | @rem Slurp the command line arguments.
55 | set CMD_LINE_ARGS=
56 | set _SKIP=2
57 |
58 | :win9xME_args_slurp
59 | if "x%~1" == "x" goto execute
60 |
61 | set CMD_LINE_ARGS=%*
62 |
63 | :execute
64 | @rem Setup the command line
65 |
66 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
67 |
68 | @rem Execute Gradle
69 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
70 |
71 | :end
72 | @rem End local scope for the variables with windows NT shell
73 | if "%ERRORLEVEL%"=="0" goto mainEnd
74 |
75 | :fail
76 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
77 | rem the _cmd.exe /c_ return code!
78 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
79 | exit /b 1
80 |
81 | :mainEnd
82 | if "%OS%"=="Windows_NT" endlocal
83 |
84 | :omega
85 |
--------------------------------------------------------------------------------
/examples/android/settings.gradle:
--------------------------------------------------------------------------------
1 | include ':app'
2 | rootProject.name='TensorflowTTS'
3 |
--------------------------------------------------------------------------------
/examples/cpptflite/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | /build
3 | /models
4 | /lib
5 | lib.zip
6 | models.zip
7 | models_ljspeech.zip
--------------------------------------------------------------------------------
/examples/cpptflite/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 2.6)
2 | PROJECT(TfliteTTS)
3 |
4 | option(MAPPER "Processor select (supported BAKER or LJSPEECH)")
5 | if (${MAPPER} STREQUAL "LJSPEECH")
6 | add_definitions(-DLJSPEECH)
7 | elseif (${MAPPER} STREQUAL "BAKER")
8 | add_definitions(-DBAKER)
9 | else ()
10 | message(FATAL_ERROR "MAPPER is only supported BAKER or LJSPEECH")
11 | endif()
12 |
13 | message(STATUS "MAPPER is selected: "${MAPPER})
14 |
15 | include_directories(lib)
16 | include_directories(lib/flatbuffers/include)
17 | include_directories(src)
18 |
19 | aux_source_directory(src DIR_SRCS)
20 |
21 | SET(CMAKE_CXX_COMPILER "g++")
22 |
23 | SET(CMAKE_CXX_FLAGS "-O3 -DNDEBUG -Wl,--no-as-needed -ldl -pthread -fpermissive")
24 |
25 | add_executable(demo demo/main.cpp ${DIR_SRCS})
26 |
27 | find_library(tflite_LIB tensorflow-lite lib)
28 |
29 | target_link_libraries(demo ${tflite_LIB})
--------------------------------------------------------------------------------
/examples/cpptflite/demo/main.cpp:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 | #include <vector>
3 | #include <cstdio>
4 | #include "VoxCommon.h"
5 | #include "TTSFrontend.h"
6 | #include "TTSBackend.h"
7 |
8 | typedef struct
9 | {
10 | const char* mapperJson;
11 | unsigned int sampleRate;
12 | } Processor;
13 |
14 | int main(int argc, char* argv[])
15 | {
16 | if (argc != 3)
17 | {
18 | fprintf(stderr, "demo text wavfile\n");
19 | return 1;
20 | }
21 |
22 | const char* cmd = "python3 ../demo/text2ids.py";
23 |
24 | Processor proc;
25 | #if LJSPEECH
26 | proc.mapperJson = "../../../tensorflow_tts/processor/pretrained/ljspeech_mapper.json";
27 | proc.sampleRate = 22050;
28 | #elif BAKER
29 | proc.mapperJson = "../../../tensorflow_tts/processor/pretrained/baker_mapper.json";
30 | proc.sampleRate = 24000;
31 | #endif
32 |
33 | const char* melgenfile = "../models/fastspeech2_quan.tflite";
34 | const char* vocoderfile = "../models/mb_melgan.tflite";
35 |
36 | // Init
37 | TTSFrontend ttsfrontend(proc.mapperJson, cmd);
38 | TTSBackend ttsbackend(melgenfile, vocoderfile);
39 |
40 | // Process
41 | ttsfrontend.text2ids(argv[1]);
42 | std::vector<int32_t> phonesIds = ttsfrontend.getPhoneIds();
43 |
44 | ttsbackend.inference(phonesIds);
45 | MelGenData mel = ttsbackend.getMel();
46 | std::vector<float> audio = ttsbackend.getAudio();
47 |
48 | std::cout << "********* Phones' ID *********" << std::endl;
49 |
50 | for (auto iter: phonesIds)
51 | {
52 | std::cout << iter << " ";
53 | }
54 | std::cout << std::endl;
55 |
56 | std::cout << "********* MEL SHAPE **********" << std::endl;
57 | for (auto index : mel.melShape)
58 | {
59 | std::cout << index << " ";
60 | }
61 | std::cout << std::endl;
62 |
63 | std::cout << "********* AUDIO LEN **********" << std::endl;
64 | std::cout << audio.size() << std::endl;
65 |
66 | VoxUtil::ExportWAV(argv[2], audio, proc.sampleRate);
67 | std::cout << "Wavfile: " << argv[2] << " created." << std::endl;
68 |
69 | return 0;
70 | }
--------------------------------------------------------------------------------
/examples/cpptflite/demo/text2ids.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import re
3 |
4 | eng_pat = re.compile("[a-zA-Z]+")
5 |
6 | if __name__ == "__main__":
7 | argvs = sys.argv
8 |
9 | if (len(argvs) != 3):
10 | print("usage: python3 {} mapper.json text".format(argvs[0]))
11 | else:
12 | from tensorflow_tts.inference import AutoProcessor
13 | mapper_json = argvs[1]
14 | processor = AutoProcessor.from_pretrained(pretrained_path=mapper_json)
15 |
16 | input_text = argvs[2]
17 |
18 | if eng_pat.match(input_text):
19 | input_ids = processor.text_to_sequence(input_text)
20 | else:
21 | input_ids = processor.text_to_sequence(input_text, inference=True)
22 |
23 | print(" ".join(str(i) for i in input_ids))
--------------------------------------------------------------------------------
/examples/cpptflite/results/lj_ori_mel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/cpptflite/results/lj_ori_mel.png
--------------------------------------------------------------------------------
/examples/cpptflite/results/lj_tflite_mel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/cpptflite/results/lj_tflite_mel.png
--------------------------------------------------------------------------------
/examples/cpptflite/results/tflite_mel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/cpptflite/results/tflite_mel.png
--------------------------------------------------------------------------------
/examples/cpptflite/results/tflite_mel2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/cpptflite/results/tflite_mel2.png
--------------------------------------------------------------------------------
/examples/cpptflite/src/MelGenerateTF.cpp:
--------------------------------------------------------------------------------
1 | #include <cstring>
2 | #include "MelGenerateTF.h"
3 |
4 | MelGenData MelGenerateTF::infer(const std::vector<int32_t> inputIds)
5 | {
6 |
7 | MelGenData output;
8 |
9 | int32_t idsLen = inputIds.size();
10 |
11 | std::vector<std::vector<int>> inputIndexsShape{ {1, idsLen}, {1}, {1}, {1}, {1} }; // input_ids, speaker_ids, speed_ratios, f0_ratios, energy_ratios
12 |
13 | int32_t shapeI = 0;
14 | for (auto index : inputIndexs)
15 | {
16 | interpreter->ResizeInputTensor(index, inputIndexsShape[shapeI]);
17 | shapeI++;
18 | }
19 |
20 | TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk);
21 |
22 | int32_t* input_ids_ptr = interpreter->typed_tensor<int32_t>(inputIndexs[0]);
23 | memcpy(input_ids_ptr, inputIds.data(), int_size * idsLen);
24 |
25 | int32_t* speaker_ids_ptr = interpreter->typed_tensor<int32_t>(inputIndexs[1]);
26 | memcpy(speaker_ids_ptr, _speakerId.data(), int_size);
27 |
28 | float* speed_ratios_ptr = interpreter->typed_tensor<float>(inputIndexs[2]);
29 | memcpy(speed_ratios_ptr, _speedRatio.data(), float_size);
30 | 
31 | float* f0_ratios_ptr = interpreter->typed_tensor<float>(inputIndexs[3]);
32 | memcpy(f0_ratios_ptr, _f0Ratio.data(), float_size);
33 | 
34 | float* energy_ratios_ptr = interpreter->typed_tensor<float>(inputIndexs[4]);
35 | memcpy(energy_ratios_ptr, _enegyRatio.data(), float_size);
36 |
37 | TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk);
38 |
39 | TfLiteTensor* melGenTensor = interpreter->tensor(ouptIndex);
40 |
41 | for (int i=0; i < melGenTensor->dims->size; i++)
42 | {
43 | output.melShape.push_back(melGenTensor->dims->data[i]);
44 | }
45 |
46 | output.bytes = melGenTensor->bytes;
47 |
48 | output.melData = interpreter->typed_tensor<float>(ouptIndex);
49 |
50 | return output;
51 | }
--------------------------------------------------------------------------------
/examples/cpptflite/src/MelGenerateTF.h:
--------------------------------------------------------------------------------
1 | #ifndef MELGENERATETF_H
2 | #define MELGENERATETF_H
3 |
4 | #include "TfliteBase.h"
5 |
6 | class MelGenerateTF : public TfliteBase
7 | {
8 | public:
9 |
10 | MelGenerateTF(const char* modelFilename):TfliteBase(modelFilename),
11 | inputIndexs(interpreter->inputs()),
12 | ouptIndex(interpreter->outputs()[1]) {};
13 |
14 | MelGenData infer(const std::vector<int32_t> inputIds);
15 |
16 | private:
17 | std::vector<int32_t> _speakerId{0};
18 | std::vector<float> _speedRatio{1.0};
19 | std::vector<float> _f0Ratio{1.0};
20 | std::vector<float> _enegyRatio{1.0};
21 |
22 | const std::vector<int> inputIndexs;
23 | const int32_t ouptIndex;
24 |
25 | };
26 |
27 | #endif // MELGENERATETF_H
--------------------------------------------------------------------------------
/examples/cpptflite/src/TTSBackend.cpp:
--------------------------------------------------------------------------------
1 | #include "TTSBackend.h"
2 |
3 | void TTSBackend::inference(std::vector<int32_t> phonesIds)
4 | {
5 | _mel = MelGen.infer(phonesIds);
6 | _audio = Vocoder.infer(_mel);
7 | }
--------------------------------------------------------------------------------
/examples/cpptflite/src/TTSBackend.h:
--------------------------------------------------------------------------------
1 | #ifndef TTSBACKEND_H
2 | #define TTSBACKEND_H
3 |
4 | #include <iostream>
5 | #include <vector>
6 | #include "MelGenerateTF.h"
7 | #include "VocoderTF.h"
8 |
9 | class TTSBackend
10 | {
11 | public:
12 | TTSBackend(const char* melgenfile, const char* vocoderfile):
13 | MelGen(melgenfile), Vocoder(vocoderfile)
14 | {
15 | std::cout << "TTSBackend Init" << std::endl;
16 | std::cout << melgenfile << std::endl;
17 | std::cout << vocoderfile << std::endl;
18 | };
19 |
20 | void inference(std::vector<int32_t> phonesIds);
21 |
22 | MelGenData getMel() const {return _mel;}
23 | std::vector<float> getAudio() const {return _audio;}
24 |
25 | private:
26 | MelGenerateTF MelGen;
27 | VocoderTF Vocoder;
28 |
29 | MelGenData _mel;
30 | std::vector<float> _audio;
31 | };
32 |
33 | #endif // TTSBACKEND_H
--------------------------------------------------------------------------------
/examples/cpptflite/src/TTSFrontend.cpp:
--------------------------------------------------------------------------------
1 | #include "TTSFrontend.h"
2 |
3 | void TTSFrontend::text2ids(const std::string &text)
4 | {
5 | _phonesIds = strSplit(getCmdResult(text));
6 | }
7 |
8 | std::string TTSFrontend::getCmdResult(const std::string &text) // run the external command and keep the last line it prints (the ID sequence)
9 | {
10 | char buf[1000] = {0};
11 | FILE *pf = NULL;
12 |
13 | if( (pf = popen((_strCmd + " " + _mapperJson + " \"" + text + "\"").c_str(), "r")) == NULL )
14 | {
15 | return "";
16 | }
17 |
18 | while(fgets(buf, sizeof(buf), pf))
19 | {
20 | continue;
21 | }
22 |
23 | std::string strResult(buf);
24 | pclose(pf);
25 |
26 | return strResult;
27 | }
28 |
29 | std::vector<int32_t> TTSFrontend::strSplit(const std::string &idStr)
30 | {
31 | std::vector<int32_t> idsVector;
32 |
33 | std::regex rgx ("\\s+");
34 | std::sregex_token_iterator iter(idStr.begin(), idStr.end(), rgx, -1);
35 | std::sregex_token_iterator end;
36 |
37 | while (iter != end) {
38 | idsVector.push_back(stoi(*iter));
39 | ++iter;
40 | }
41 |
42 | return idsVector;
43 | }
--------------------------------------------------------------------------------
/examples/cpptflite/src/TTSFrontend.h:
--------------------------------------------------------------------------------
1 | #ifndef TTSFRONTEND_H
2 | #define TTSFRONTEND_H
3 |
4 | #include <iostream>
5 | #include <string>
6 | #include <vector>
7 | #include <regex>
8 | #include <cstdio>
9 |
10 | class TTSFrontend
11 | {
12 | public:
13 |
14 | /**
15 | * Converts text to phoneme IDs.
16 | * A temporary method that shells out to a command in this demo;
17 | * it should be replaced by a real pronunciation-processing module.
18 | *@param strCmd Command used to call processor.text_to_sequence()
19 | */
20 | TTSFrontend(const std::string &mapperJson,
21 | const std::string &strCmd):
22 | _mapperJson(mapperJson),
23 | _strCmd(strCmd)
24 | {
25 | std::cout << "TTSFrontend Init" << std::endl;
26 | std::cout << _mapperJson << std::endl;
27 | std::cout << _strCmd << std::endl;
28 | };
29 |
30 | void text2ids(const std::string &text);
31 |
32 | std::vector<int32_t> getPhoneIds() const {return _phonesIds;}
33 | private:
34 |
35 | const std::string _mapperJson;
36 | const std::string _strCmd;
37 |
38 | std::vector<int32_t> _phonesIds;
39 |
40 | std::string getCmdResult(const std::string &text);
41 | std::vector<int32_t> strSplit(const std::string &idStr);
42 | };
43 |
44 | #endif // TTSFRONTEND_H
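45 | 
46 | // Hypothetical usage sketch (mapper path and command are assumptions; see demo/main.cpp):
47 | //   TTSFrontend frontend("ljspeech_mapper.json", "python3 ../demo/text2ids.py");
48 | //   frontend.text2ids("Hello world.");
49 | //   std::vector<int32_t> ids = frontend.getPhoneIds();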
--------------------------------------------------------------------------------
/examples/cpptflite/src/TfliteBase.cpp:
--------------------------------------------------------------------------------
1 | #include "TfliteBase.h"
2 |
3 | TfliteBase::TfliteBase(const char* modelFilename)
4 | {
5 | interpreterBuild(modelFilename);
6 | }
7 |
8 | TfliteBase::~TfliteBase()
9 | {
10 | ;
11 | }
12 |
13 | void TfliteBase::interpreterBuild(const char* modelFilename)
14 | {
15 | model = tflite::FlatBufferModel::BuildFromFile(modelFilename);
16 |
17 | TFLITE_MINIMAL_CHECK(model != nullptr);
18 |
19 | tflite::InterpreterBuilder builder(*model, resolver);
20 |
21 | builder(&interpreter);
22 |
23 | TFLITE_MINIMAL_CHECK(interpreter != nullptr);
24 | }
25 |
--------------------------------------------------------------------------------
/examples/cpptflite/src/TfliteBase.h:
--------------------------------------------------------------------------------
1 | #ifndef TFLITEBASE_H
2 | #define TFLITEBASE_H
3 |
4 | #include "tensorflow/lite/interpreter.h"
5 | #include "tensorflow/lite/kernels/register.h"
6 | #include "tensorflow/lite/model.h"
7 | #include "tensorflow/lite/optional_debug_tools.h"
8 |
9 | #define TFLITE_MINIMAL_CHECK(x) \
10 | if (!(x)) { \
11 | fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \
12 | exit(1); \
13 | }
14 |
15 | typedef struct
16 | {
17 | float *melData;
18 | std::vector<int> melShape;
19 | int32_t bytes;
20 | } MelGenData;
21 |
22 | class TfliteBase
23 | {
24 | public:
25 | uint32_t int_size = sizeof(int32_t);
26 | uint32_t float_size = sizeof(float);
27 |
28 | std::unique_ptr<tflite::Interpreter> interpreter;
29 |
30 | TfliteBase(const char* modelFilename);
31 | ~TfliteBase();
32 |
33 | private:
34 | std::unique_ptr<tflite::FlatBufferModel> model;
35 | tflite::ops::builtin::BuiltinOpResolver resolver;
36 |
37 | void interpreterBuild(const char* modelFilename);
38 | };
39 |
40 | #endif // TFLITEBASE_H
--------------------------------------------------------------------------------
/examples/cpptflite/src/VocoderTF.cpp:
--------------------------------------------------------------------------------
1 | #include "VocoderTF.h"
2 |
3 | std::vector<float> VocoderTF::infer(const MelGenData mel)
4 | {
5 | std::vector<float> audio;
6 |
7 | interpreter->ResizeInputTensor(inputIndex, mel.melShape);
8 | TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk);
9 |
10 | float* melDataPtr = interpreter->typed_input_tensor<float>(inputIndex);
11 | memcpy(melDataPtr, mel.melData, mel.bytes);
12 |
13 | TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk);
14 |
15 | TfLiteTensor* audioTensor = interpreter->tensor(outputIndex);
16 |
17 | float* outputPtr = interpreter->typed_output_tensor<float>(0);
18 |
19 | int32_t audio_len = audioTensor->bytes / float_size;
20 |
21 | for (int i=0; i<audio_len; i++)
22 | {
23 | audio.push_back(outputPtr[i]);
24 | }
25 | 
26 | return audio;
27 | }
--------------------------------------------------------------------------------
/examples/cpptflite/src/VocoderTF.h:
--------------------------------------------------------------------------------
1 | #ifndef VOCODERTF_H
2 | #define VOCODERTF_H
3 | 
4 | #include "TfliteBase.h"
5 | 
6 | class VocoderTF : public TfliteBase
7 | {
8 | public:
9 | 
10 | VocoderTF(const char* modelFilename):TfliteBase(modelFilename),
11 | inputIndex(interpreter->inputs()[0]),
12 | outputIndex(interpreter->outputs()[0]) {};
13 |
14 | std::vector<float> infer(const MelGenData mel);
15 |
16 | private:
17 |
18 | const int32_t inputIndex;
19 | const int32_t outputIndex;
20 | };
21 |
22 | #endif // VOCODERTF_H
--------------------------------------------------------------------------------
/examples/cpptflite/src/VoxCommon.cpp:
--------------------------------------------------------------------------------
1 | #include "VoxCommon.h"
2 |
3 | void VoxUtil::ExportWAV(const std::string & Filename, const std::vector& Data, unsigned SampleRate) {
4 | AudioFile<float>::AudioBuffer Buffer;
5 | Buffer.resize(1);
6 |
7 |
8 | Buffer[0] = Data;
9 | size_t BufSz = Data.size();
10 |
11 |
12 | AudioFile<float> File;
13 | File.setAudioBuffer(Buffer);
14 | File.setAudioBufferSize(1, (int)BufSz);
15 | File.setNumSamplesPerChannel((int)BufSz);
16 | File.setNumChannels(1);
17 | File.setBitDepth(32);
18 | File.setSampleRate(SampleRate);
19 |
20 | File.save(Filename, AudioFileFormat::Wave);
21 | }
22 |
--------------------------------------------------------------------------------
/examples/cpptflite/src/VoxCommon.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | /*
3 | VoxCommon.h : Defines common data structures and constants to be used with TensorVox
4 | */
5 | #include <string>
6 | #include <vector>
7 | #include "AudioFile.h"
8 | // #include "ext/CppFlow/include/Tensor.h"
9 | // #include
10 |
11 | #define IF_RETURN(cond,ret) if (cond){return ret;}
12 | #define VX_IF_EXCEPT(cond,ex) if (cond){throw std::invalid_argument(ex);}
13 |
14 |
15 | template <typename T>
16 | struct TFTensor {
17 | std::vector<T> Data;
18 | std::vector<int64_t> Shape;
19 | size_t TotalSize;
20 | };
21 |
22 | namespace VoxUtil {
23 |
24 | void ExportWAV(const std::string& Filename, const std::vector<float>& Data, unsigned SampleRate);
25 | }
26 |
--------------------------------------------------------------------------------
/examples/cppwin/.gitattributes:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Set default behavior to automatically normalize line endings.
3 | ###############################################################################
4 | * text=auto
5 |
6 | ###############################################################################
7 | # Set default behavior for command prompt diff.
8 | #
9 | # This is need for earlier builds of msysgit that does not have it on by
10 | # default for csharp files.
11 | # Note: This is only used by command line
12 | ###############################################################################
13 | #*.cs diff=csharp
14 |
15 | ###############################################################################
16 | # Set the merge driver for project and solution files
17 | #
18 | # Merging from the command prompt will add diff markers to the files if there
19 | # are conflicts (Merging from VS is not affected by the settings below, in VS
20 | # the diff markers are never inserted). Diff markers may cause the following
21 | # file extensions to fail to load in VS. An alternative would be to treat
22 | # these files as binary and thus will always conflict and require user
23 | # intervention with every merge. To do so, just uncomment the entries below
24 | ###############################################################################
25 | #*.sln merge=binary
26 | #*.csproj merge=binary
27 | #*.vbproj merge=binary
28 | #*.vcxproj merge=binary
29 | #*.vcproj merge=binary
30 | #*.dbproj merge=binary
31 | #*.fsproj merge=binary
32 | #*.lsproj merge=binary
33 | #*.wixproj merge=binary
34 | #*.modelproj merge=binary
35 | #*.sqlproj merge=binary
36 | #*.wwaproj merge=binary
37 |
38 | ###############################################################################
39 | # behavior for image files
40 | #
41 | # image files are treated as binary by default.
42 | ###############################################################################
43 | #*.jpg binary
44 | #*.png binary
45 | #*.gif binary
46 |
47 | ###############################################################################
48 | # diff behavior for common document formats
49 | #
50 | # Convert binary document formats to text before diffing them. This feature
51 | # is only available from the command line. Turn it on by uncommenting the
52 | # entries below.
53 | ###############################################################################
54 | #*.doc diff=astextplain
55 | #*.DOC diff=astextplain
56 | #*.docx diff=astextplain
57 | #*.DOCX diff=astextplain
58 | #*.dot diff=astextplain
59 | #*.DOT diff=astextplain
60 | #*.pdf diff=astextplain
61 | #*.PDF diff=astextplain
62 | #*.rtf diff=astextplain
63 | #*.RTF diff=astextplain
64 |
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference.pro:
--------------------------------------------------------------------------------
1 | TEMPLATE = app
2 | CONFIG += console c++14
3 | CONFIG -= app_bundle
4 | CONFIG -= qt
5 | TARGET = TFTTSCppInfer
6 |
7 | HEADERS += \
8 | TensorflowTTSCppInference/EnglishPhoneticProcessor.h \
9 | TensorflowTTSCppInference/FastSpeech2.h \
10 | TensorflowTTSCppInference/MultiBandMelGAN.h \
11 | TensorflowTTSCppInference/TextTokenizer.h \
12 | TensorflowTTSCppInference/Voice.h \
13 | TensorflowTTSCppInference/VoxCommon.hpp \
14 | TensorflowTTSCppInference/ext/AudioFile.hpp \
15 | TensorflowTTSCppInference/ext/CppFlow/include/Model.h \
16 | TensorflowTTSCppInference/ext/CppFlow/include/Tensor.h \
17 | TensorflowTTSCppInference/ext/ZCharScanner.h \
18 | TensorflowTTSCppInference/phonemizer.h \
19 | TensorflowTTSCppInference/tfg2p.h \
20 |
21 | SOURCES += \
22 | TensorflowTTSCppInference/EnglishPhoneticProcessor.cpp \
23 | TensorflowTTSCppInference/FastSpeech2.cpp \
24 | TensorflowTTSCppInference/MultiBandMelGAN.cpp \
25 | TensorflowTTSCppInference/TensorflowTTSCppInference.cpp \
26 | TensorflowTTSCppInference/TextTokenizer.cpp \
27 | TensorflowTTSCppInference/Voice.cpp \
28 | TensorflowTTSCppInference/VoxCommon.cpp \
29 | TensorflowTTSCppInference/ext/CppFlow/src/Model.cpp \
30 | TensorflowTTSCppInference/ext/CppFlow/src/Tensor.cpp \
31 | TensorflowTTSCppInference/phonemizer.cpp \
32 | TensorflowTTSCppInference/tfg2p.cpp \
33 | TensorflowTTSCppInference/ext/ZCharScanner.cpp \
34 |
35 | INCLUDEPATH += $$PWD/deps/include
36 | LIBS += -L$$PWD/deps/lib -ltensorflow
37 |
38 | # GCC fails on the memcpy in AudioFile.hpp (l. 1186) unless we add this
39 | QMAKE_CXXFLAGS += -fpermissive
40 |
41 |
42 |
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio 15
4 | VisualStudioVersion = 15.0.28307.136
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "TensorflowTTSCppInference", "TensorflowTTSCppInference\TensorflowTTSCppInference.vcxproj", "{67C98279-9BA3-49F7-9FE4-2C0DF77A2875}"
7 | EndProject
8 | Global
9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
10 | Debug|x64 = Debug|x64
11 | Debug|x86 = Debug|x86
12 | Release|x64 = Release|x64
13 | Release|x86 = Release|x86
14 | EndGlobalSection
15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
16 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Debug|x64.ActiveCfg = Debug|x64
17 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Debug|x64.Build.0 = Debug|x64
18 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Debug|x86.ActiveCfg = Debug|Win32
19 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Debug|x86.Build.0 = Debug|Win32
20 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Release|x64.ActiveCfg = Release|x64
21 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Release|x64.Build.0 = Release|x64
22 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Release|x86.ActiveCfg = Release|Win32
23 | {67C98279-9BA3-49F7-9FE4-2C0DF77A2875}.Release|x86.Build.0 = Release|Win32
24 | EndGlobalSection
25 | GlobalSection(SolutionProperties) = preSolution
26 | HideSolutionNode = FALSE
27 | EndGlobalSection
28 | GlobalSection(ExtensibilityGlobals) = postSolution
29 | SolutionGuid = {08E7CCCB-028D-4BFC-9CDC-E8957E50F8EA}
30 | EndGlobalSection
31 | EndGlobal
32 |
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/EnglishPhoneticProcessor.cpp:
--------------------------------------------------------------------------------
1 | #include "EnglishPhoneticProcessor.h"
2 | #include "VoxCommon.hpp"
3 |
4 | using namespace std;
5 |
6 | bool EnglishPhoneticProcessor::Initialize(Phonemizer* InPhn)
7 | {
8 |
9 |
10 | Phoner = InPhn;
11 | Tokenizer.SetAllowedChars(Phoner->GetGraphemeChars());
12 |
13 |
14 |
15 | return true;
16 | }
17 |
18 | std::string EnglishPhoneticProcessor::ProcessTextPhonetic(const std::string& InText, const std::vector<std::string> &InPhonemes,ETTSLanguage::Enum InLanguage)
19 | {
20 | if (!Phoner)
21 | return "ERROR";
22 |
23 |
24 |
25 | vector<string> Words = Tokenizer.Tokenize(InText,InLanguage);
26 |
27 | string Assemble = "";
28 | // Make a copy of the dict passed.
29 |
30 | for (size_t w = 0; w < Words.size();w++)
31 | {
32 | const string& Word = Words[w];
33 |
34 | if (Word.find("@") != std::string::npos){
35 | std::string AddPh = Word.substr(1); // Remove the @
36 | size_t OutId = 0;
37 | if (VoxUtil::FindInVec(AddPh,InPhonemes,OutId))
38 | {
39 | Assemble.append(InPhonemes[OutId]);
40 | Assemble.append(" ");
41 |
42 |
43 | }
44 |
45 | continue;
46 |
47 | }
48 |
49 |
50 |
51 |
52 | size_t OverrideIdx = 0;
53 |
54 |
55 |
56 | std::string Res = Phoner->ProcessWord(Word,0.001f);
57 |
58 | // Cache the word in the override dict so next time we don't have to research it
59 |
60 | Assemble.append(Res);
61 | Assemble.append(" ");
62 |
63 |
64 |
65 |
66 |
67 | }
68 |
69 |
70 | // Delete the trailing space if there is one
71 | 
72 | 
73 | if (!Assemble.empty() && Assemble.back() == ' ')
74 | Assemble.pop_back();
75 |
76 |
77 | return Assemble;
78 | }
79 |
80 | EnglishPhoneticProcessor::EnglishPhoneticProcessor()
81 | {
82 | Phoner = nullptr;
83 | }
84 |
85 | EnglishPhoneticProcessor::EnglishPhoneticProcessor(Phonemizer *InPhn)
86 | {
87 | Initialize(InPhn);
88 |
89 | }
90 |
91 |
92 |
93 | EnglishPhoneticProcessor::~EnglishPhoneticProcessor()
94 | {
95 | if (Phoner)
96 | delete Phoner;
97 | }
98 |
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/EnglishPhoneticProcessor.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "TextTokenizer.h"
3 | #include "phonemizer.h"
4 |
5 | class EnglishPhoneticProcessor
6 | {
7 | private:
8 | TextTokenizer Tokenizer;
9 | Phonemizer* Phoner;
10 |
11 | inline bool FileExists(const std::string& name) {
12 | std::ifstream f(name.c_str());
13 | return f.good();
14 | }
15 |
16 | public:
17 | bool Initialize(Phonemizer *InPhn);
18 | std::string ProcessTextPhonetic(const std::string& InText, const std::vector<std::string> &InPhonemes,ETTSLanguage::Enum InLanguage);
19 | EnglishPhoneticProcessor();
20 | EnglishPhoneticProcessor(Phonemizer *InPhn);
21 | ~EnglishPhoneticProcessor();
22 | };
23 |
24 |
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/FastSpeech2.cpp:
--------------------------------------------------------------------------------
1 | #include "FastSpeech2.h"
2 | #include <stdexcept>
3 |
4 |
5 | FastSpeech2::FastSpeech2()
6 | {
7 | FastSpeech = nullptr;
8 | }
9 |
10 | FastSpeech2::FastSpeech2(const std::string & SavedModelFolder)
11 | {
12 | Initialize(SavedModelFolder);
13 | }
14 |
15 |
16 | bool FastSpeech2::Initialize(const std::string & SavedModelFolder)
17 | {
18 | try {
19 | FastSpeech = new Model(SavedModelFolder);
20 | }
21 | catch (...) {
22 | FastSpeech = nullptr;
23 | return false;
24 |
25 | }
26 | return true;
27 | }
28 |
29 | TFTensor<float> FastSpeech2::DoInference(const std::vector<int32_t>& InputIDs, int32_t SpeakerID, float Speed, float Energy, float F0, int32_t EmotionID)
30 | {
31 | if (!FastSpeech)
32 | throw std::invalid_argument("Tried to do inference on unloaded or invalid model!");
33 |
34 | // Convenience reference so that we don't have to constantly derefer pointers.
35 | Model& Mdl = *FastSpeech;
36 |
37 | // Define the tensors
38 | Tensor input_ids{ Mdl,"serving_default_input_ids" };
39 | Tensor energy_ratios{ Mdl,"serving_default_energy_ratios" };
40 | Tensor f0_ratios{ Mdl,"serving_default_f0_ratios" };
41 | Tensor speaker_ids{ Mdl,"serving_default_speaker_ids" };
42 | Tensor speed_ratios{ Mdl,"serving_default_speed_ratios" };
43 | Tensor* emotion_ids = nullptr;
44 |
45 | // This is a multi-emotion model
46 | if (EmotionID != -1)
47 | {
48 | emotion_ids = new Tensor{Mdl,"serving_default_emotion_ids"};
49 | emotion_ids->set_data(std::vector<int32_t>{EmotionID});
50 |
51 | }
52 |
53 |
54 | // This is the shape of the input IDs, our equivalent to tf.expand_dims.
55 | std::vector<int64_t> InputIDShape = { 1, (int64_t)InputIDs.size() };
56 |
57 | input_ids.set_data(InputIDs, InputIDShape);
58 | energy_ratios.set_data(std::vector<float>{ Energy });
59 | f0_ratios.set_data(std::vector<float>{F0});
60 | speaker_ids.set_data(std::vector<int32_t>{SpeakerID});
61 | speed_ratios.set_data(std::vector<float>{Speed});
62 |
63 | // Define output tensor
64 | Tensor output{ Mdl,"StatefulPartitionedCall" };
65 |
66 |
67 | // Vector of input tensors
68 | std::vector<Tensor*> inputs = { &input_ids,&speaker_ids,&speed_ratios,&f0_ratios,&energy_ratios };
69 |
70 | if (EmotionID != -1)
71 | inputs.push_back(emotion_ids);
72 |
73 |
74 | // Do inference
75 | FastSpeech->run(inputs, output);
76 |
77 | // Define output and return it
78 | TFTensor<float> Output = VoxUtil::CopyTensor<float>(output);
79 |
80 | // We allocated the emotion_ids tensor dynamically, delete it
81 | if (emotion_ids)
82 | delete emotion_ids;
83 |
84 | // We could just straight out define it in the return statement, but I like it more this way
85 |
86 | return Output;
87 | }
88 |
89 | FastSpeech2::~FastSpeech2()
90 | {
91 | if (FastSpeech)
92 | delete FastSpeech;
93 | }
94 |
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/FastSpeech2.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "ext/CppFlow/include/Model.h"
4 | #include "VoxCommon.hpp"
5 | class FastSpeech2
6 | {
7 | private:
8 | Model* FastSpeech;
9 |
10 | public:
11 | FastSpeech2();
12 | FastSpeech2(const std::string& SavedModelFolder);
13 |
14 | /*
15 | Initialize and load the model
16 |
17 | -> SavedModelFolder: Folder where the .pb, variables, and other components of the exported SavedModel are stored
18 | <- Returns: (bool)Success
19 | */
20 | bool Initialize(const std::string& SavedModelFolder);
21 |
22 | /*
23 | Do inference on a FastSpeech2 model.
24 |
25 | -> InputIDs: Input IDs of tokens for inference
26 | -> SpeakerID: ID of the speaker in the model to do inference on. If single speaker, always leave at 0. If multispeaker, refer to your model.
27 | -> Speed, Energy, F0: Parameters for FS2 inference. Leave at 1.f for defaults
28 |
29 | <- Returns: TFTensor<float> with shape {1, xx, 80} containing the mel spectrogram.
30 | */
31 | TFTensor<float> DoInference(const std::vector<int32_t>& InputIDs, int32_t SpeakerID = 0, float Speed = 1.f, float Energy = 1.f, float F0 = 1.f,int32_t EmotionID = -1);
32 |
33 |
34 |
35 | ~FastSpeech2();
36 | };
37 |
38 |
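39 | // Hypothetical usage sketch (model path is an assumption):
40 | //   FastSpeech2 fs2("models/fastspeech2");           // SavedModel folder
41 | //   TFTensor<float> mel = fs2.DoInference(inputIds); // defaults: speaker 0, speed/energy/F0 = 1.f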
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/MultiBandMelGAN.cpp:
--------------------------------------------------------------------------------
1 | #include "MultiBandMelGAN.h"
2 | #include <stdexcept>
3 | #define IF_EXCEPT(cond,ex) if (cond){throw std::invalid_argument(ex);}
4 |
5 |
6 |
7 | bool MultiBandMelGAN::Initialize(const std::string & VocoderPath)
8 | {
9 | try {
10 | MelGAN = new Model(VocoderPath);
11 | }
12 | catch (...) {
13 | MelGAN = nullptr;
14 | return false;
15 |
16 | }
17 | return true;
18 |
19 |
20 | }
21 |
22 | TFTensor<float> MultiBandMelGAN::DoInference(const TFTensor<float>& InMel)
23 | {
24 | IF_EXCEPT(!MelGAN, "Tried to infer MB-MelGAN on uninitialized model!!!!")
25 |
26 | // Convenience reference so that we don't have to constantly derefer pointers.
27 | Model& Mdl = *MelGAN;
28 |
29 | Tensor input_mels{ Mdl,"serving_default_mels" };
30 | input_mels.set_data(InMel.Data, InMel.Shape);
31 |
32 | Tensor out_audio{ Mdl,"StatefulPartitionedCall" };
33 |
34 | MelGAN->run(input_mels, out_audio);
35 |
36 | TFTensor<float> RetTensor = VoxUtil::CopyTensor<float>(out_audio);
37 |
38 | return RetTensor;
39 |
40 |
41 | }
42 |
43 | MultiBandMelGAN::MultiBandMelGAN()
44 | {
45 | MelGAN = nullptr;
46 | }
47 |
48 |
49 | MultiBandMelGAN::~MultiBandMelGAN()
50 | {
51 | if (MelGAN)
52 | delete MelGAN;
53 |
54 | }
55 |
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/MultiBandMelGAN.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "ext/CppFlow/include/Model.h"
4 | #include "VoxCommon.hpp"
5 | class MultiBandMelGAN
6 | {
7 | private:
8 | Model* MelGAN;
9 |
10 |
11 | public:
12 | bool Initialize(const std::string& VocoderPath);
13 |
14 |
15 | // Do MultiBand MelGAN inference including PQMF
16 | // -> InMel: Mel spectrogram (shape [1, xx, 80])
17 | // <- Returns: Tensor data [4, xx, 1]
18 | TFTensor<float> DoInference(const TFTensor<float>& InMel);
19 |
20 | MultiBandMelGAN();
21 | ~MultiBandMelGAN();
22 | };
23 |
24 |
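25 | // Hypothetical usage sketch (model path is an assumption):
26 | //   MultiBandMelGAN vocoder;
27 | //   vocoder.Initialize("models/vocoder");
28 | //   TFTensor<float> audio = vocoder.DoInference(mel); // shapes as documented above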
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/TextTokenizer.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <string>
3 | #include <vector>
4 | #include "VoxCommon.hpp"
5 |
6 | class TextTokenizer
7 | {
8 | private:
9 | std::string AllowedChars;
10 | std::string IntToStr(int number);
11 |
12 | std::vector<std::string> ExpandNumbers(const std::vector<std::string>& SpaceTokens);
13 | public:
14 | TextTokenizer();
15 | ~TextTokenizer();
16 |
17 | std::vector<std::string> Tokenize(const std::string& InTxt,ETTSLanguage::Enum Language = ETTSLanguage::English);
18 | void SetAllowedChars(const std::string &value);
19 | };
20 |
21 |
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/Voice.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "FastSpeech2.h"
4 | #include "MultiBandMelGAN.h"
5 | #include "EnglishPhoneticProcessor.h"
6 |
7 |
8 | class Voice
9 | {
10 | private:
11 | FastSpeech2 MelPredictor;
12 | MultiBandMelGAN Vocoder;
13 | EnglishPhoneticProcessor Processor;
14 | VoiceInfo VoxInfo;
15 |
16 |
17 |
18 | std::vector<std::string> Phonemes;
19 | std::vector<int32_t> PhonemeIDs;
20 |
21 |
22 |
23 | std::vector<int32_t> PhonemesToID(const std::string& InTxt);
24 |
25 | std::vector<std::string> Speakers;
26 | std::vector<std::string> Emotions;
27 |
28 | void ReadPhonemes(const std::string& PhonemePath);
29 |
30 | void ReadSpeakers(const std::string& SpeakerPath);
31 |
32 | void ReadEmotions(const std::string& EmotionPath);
33 |
34 |
35 | void ReadModelInfo(const std::string& ModelInfoPath);
36 |
37 | std::vector<std::string> GetLinedFile(const std::string& Path);
38 |
39 |
40 | std::string ModelInfo;
41 |
42 | public:
43 | /* Voice constructor, arguments obligatory.
44 | -> VoxPath: Path of folder where models are contained.
45 | -- Must be a folder without an ending slash with UNIX slashes, can be relative or absolute (eg: MyVoices/Karen)
46 | -- The folder must contain the following elements:
47 | --- melgen: Folder where a FastSpeech2 model was saved as SavedModel, with .pb, variables, etc
48 | --- vocoder: Folder where a Multi-Band MelGAN model was saved as SavedModel.
49 | --- info.json: Model information
50 | --- phonemes.txt: Tab delimited file containing PHONEME \t ID, for inputting to the FS2 model.
51 |
52 | --- If multispeaker, a lined .txt file called speakers.txt
53 | --- If multi-emotion, a lined .txt file called emotions.txt
54 |
55 | */
56 | Voice(const std::string& VoxPath, const std::string& inName,Phonemizer* InPhn);
57 |
58 | void AddPhonemizer(Phonemizer* InPhn);
59 |
60 |
61 | std::vector<float> Vocalize(const std::string& Prompt, float Speed = 1.f, int32_t SpeakerID = 0, float Energy = 1.f, float F0 = 1.f,int32_t EmotionID = -1);
62 |
63 | std::string Name;
64 | inline const VoiceInfo& GetInfo(){return VoxInfo;}
65 |
66 | inline const std::vector<std::string>& GetSpeakers(){return Speakers;}
67 | inline const std::vector<std::string>& GetEmotions(){return Emotions;}
68 |
69 | inline const std::string& GetModelInfo(){return ModelInfo;}
70 |
71 | ~Voice();
72 | };
73 |
74 |
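75 | // Hypothetical usage sketch (folder layout as documented above; phonemizer built separately):
76 | //   Voice voice("MyVoices/Karen", "Karen", phonemizer);
77 | //   std::vector<float> samples = voice.Vocalize("Hello world.");
78 | //   VoxUtil::ExportWAV("out.wav", samples, voice.GetInfo().SampleRate);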
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/VoxCommon.cpp:
--------------------------------------------------------------------------------
1 | #include "VoxCommon.hpp"
2 | #include "ext/json.hpp"
3 | using namespace nlohmann;
4 |
5 | const std::vector<std::string> Text2MelNames = {"FastSpeech2","Tacotron2"};
6 | const std::vector<std::string> VocoderNames = {"Multi-Band MelGAN"};
7 | const std::vector<std::string> RepoNames = {"TensorflowTTS"};
8 | 
9 | const std::vector<std::string> LanguageNames = {"English","Spanish"};
10 |
11 |
12 | void VoxUtil::ExportWAV(const std::string & Filename, const std::vector& Data, unsigned SampleRate) {
13 | AudioFile<float>::AudioBuffer Buffer;
14 | Buffer.resize(1);
15 |
16 |
17 | Buffer[0] = Data;
18 | size_t BufSz = Data.size();
19 |
20 |
21 | AudioFile<float> File;
22 | File.setAudioBuffer(Buffer);
23 | File.setAudioBufferSize(1, (int)BufSz);
24 | File.setNumSamplesPerChannel((int)BufSz);
25 | File.setNumChannels(1);
26 | File.setBitDepth(32);
27 | File.setSampleRate(SampleRate);
28 |
29 | File.save(Filename, AudioFileFormat::Wave);
30 |
31 |
32 |
33 | }
34 |
35 | VoiceInfo VoxUtil::ReadModelJSON(const std::string &InfoFilename)
36 | {
37 | const size_t MaxNoteSize = 80;
38 |
39 | std::ifstream JFile(InfoFilename);
40 | json JS;
41 |
42 | JFile >> JS;
43 |
44 |
45 | JFile.close();
46 |
47 | auto Arch = JS["architecture"];
48 |
49 | ArchitectureInfo CuArch;
50 | CuArch.Repo = Arch["repo"].get<int>();
51 | CuArch.Text2Mel = Arch["text2mel"].get<int>();
52 | CuArch.Vocoder = Arch["vocoder"].get<int>();
53 |
54 | // Now fill the strings
55 | CuArch.s_Repo = RepoNames[CuArch.Repo];
56 | CuArch.s_Text2Mel = Text2MelNames[CuArch.Text2Mel];
57 | CuArch.s_Vocoder = VocoderNames[CuArch.Vocoder];
58 |
59 |
60 | uint32_t Lang = JS["language"].get<uint32_t>();
61 | VoiceInfo Inf{JS["name"].get<std::string>(),
62 | JS["author"].get<std::string>(),
63 | JS["version"].get<int32_t>(),
64 | JS["description"].get<std::string>(),
65 | CuArch,
66 | JS["note"].get<std::string>(),
67 | JS["sarate"].get<uint32_t>(),
68 | Lang,
69 | LanguageNames[Lang],
70 | " " + JS["pad"].get<std::string>()}; // Add a space for separation since we directly append the value to the prompt
71 |
72 | if (Inf.Note.size() > MaxNoteSize)
73 | Inf.Note = Inf.Note.substr(0,MaxNoteSize);
74 |
75 | return Inf;
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 | }
84 |
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/VoxCommon.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 | /*
3 | VoxCommon.hpp : Defines common data structures and constants to be used with TensorVox
4 | */
5 | #include <string>
6 | #include <vector>
7 | #include "ext/AudioFile.hpp"
8 | #include "ext/CppFlow/include/Tensor.h"
9 |
10 | #define IF_RETURN(cond,ret) if (cond){return ret;}
11 |
12 |
13 |
14 | template <typename T>
15 | struct TFTensor {
16 | std::vector<T> Data;
17 | std::vector<int64_t> Shape;
18 | size_t TotalSize;
19 |
20 | };
21 |
22 |
23 | namespace ETTSRepo {
24 | enum Enum{
25 | TensorflowTTS = 0,
26 | MozillaTTS // not implemented yet
27 | };
28 |
29 | }
30 | namespace EText2MelModel {
31 | enum Enum{
32 | FastSpeech2 = 0,
33 | Tacotron2 // not implemented yet
34 | };
35 |
36 | }
37 |
38 | namespace EVocoderModel{
39 | enum Enum{
40 | MultiBandMelGAN = 0
41 | };
42 | }
43 |
44 | namespace ETTSLanguage{
45 | enum Enum{
46 | English = 0,
47 | Spanish
48 | };
49 |
50 | }
51 |
52 |
53 |
54 | struct ArchitectureInfo{
55 | int Repo;
56 | int Text2Mel;
57 | int Vocoder;
58 |
59 | // String versions of the info, for displaying.
60 | // We want boilerplate int index to str conversion code to be low.
61 | std::string s_Repo;
62 | std::string s_Text2Mel;
63 | std::string s_Vocoder;
64 |
65 | };
66 | struct VoiceInfo{
67 | std::string Name;
68 | std::string Author;
69 | int32_t Version;
70 | std::string Description;
71 | ArchitectureInfo Architecture;
72 | std::string Note;
73 |
74 | uint32_t SampleRate;
75 |
76 | uint32_t Language;
77 | std::string s_Language;
78 |
79 | std::string EndPadding;
80 |
81 |
82 |
83 | };
84 |
85 | namespace VoxUtil {
86 |
87 | VoiceInfo ReadModelJSON(const std::string& InfoFilename);
88 |
89 |
90 | template <typename T>
91 | TFTensor<T> CopyTensor(Tensor& InTens)
92 | {
93 | std::vector<T> Data = InTens.get_data<T>();
94 | std::vector<int64_t> Shape = InTens.get_shape();
95 | size_t TotalSize = 1;
96 | for (const int64_t& Dim : Shape)
97 | TotalSize *= Dim;
98 |
99 | return TFTensor<T>{Data, Shape, TotalSize};
100 |
101 |
102 | }
103 |
104 | template <typename V>
105 | bool FindInVec(V In, const std::vector<V>& Vec, size_t& OutIdx, size_t start = 0) {
106 | for (size_t xx = start;xx < Vec.size();xx++)
107 | {
108 | if (Vec[xx] == In) {
109 | OutIdx = xx;
110 | return true;
111 |
112 | }
113 |
114 | }
115 |
116 |
117 | return false;
118 |
119 | }
120 | template <typename T, typename V>
121 | bool FindInVec2(V In, const std::vector<T>& Vec, size_t& OutIdx, size_t start = 0) {
122 | for (size_t xx = start;xx < Vec.size();xx++)
123 | {
124 | if (Vec[xx] == In) {
125 | OutIdx = xx;
126 | return true;
127 |
128 | }
129 |
130 | }
131 |
132 |
133 | return false;
134 |
135 | }
136 |
137 | void ExportWAV(const std::string& Filename, const std::vector<float>& Data, unsigned SampleRate);
138 | }
139 |
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/ext/CppFlow/include/Model.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by sergio on 12/05/19.
3 | //
4 |
5 | #ifndef CPPFLOW_MODEL_H
6 | #define CPPFLOW_MODEL_H
7 |
8 | #include <algorithm>
9 | #include <fstream>
10 | #include <iostream>
11 | #include <memory>
12 | #include <numeric>
13 | #include <string>
14 | #include <vector>
15 | #pragma warning(push, 0)
16 | #include <tensorflow/c/c_api.h>
17 | #include "Tensor.h"
18 | #pragma warning(pop)
19 | class Tensor;
20 |
21 | class Model {
22 | public:
23 | // Pass a path to the model file and optional Tensorflow config options. See examples/load_model/main.cpp.
24 | explicit Model(const std::string& model_filename, const std::vector<uint8_t>& config_options = {});
25 |
26 | // Rule of five, moving is easy as the pointers can be copied, copying not as i have no idea how to copy
27 | // the contents of the pointer (i guess dereferencing won't do a deep copy)
28 | Model(const Model &model) = delete;
29 | Model(Model &&model) = default;
30 | Model& operator=(const Model &model) = delete;
31 | Model& operator=(Model &&model) = default;
32 |
33 | ~Model();
34 |
35 | void init();
36 | void restore(const std::string& ckpt);
37 | void save(const std::string& ckpt);
38 | void restore_savedmodel(const std::string& savedmdl);
39 | std::vector<std::string> get_operations() const;
40 |
41 | // Original Run
42 | void run(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs);
43 |
44 | // Run with references
45 | void run(Tensor& input, const std::vector<Tensor*>& outputs);
46 | void run(const std::vector<Tensor*>& inputs, Tensor& output);
47 | void run(Tensor& input, Tensor& output);
48 |
49 | // Run with pointers
50 | void run(Tensor* input, const std::vector<Tensor*>& outputs);
51 | void run(const std::vector<Tensor*>& inputs, Tensor* output);
52 | void run(Tensor* input, Tensor* output);
53 |
54 | private:
55 | TF_Graph* graph;
56 | TF_Session* session;
57 | TF_Status* status;
58 |
59 | // Read a file from a string
60 | static TF_Buffer* read(const std::string&);
61 |
62 | bool status_check(bool throw_exc) const;
63 | void error_check(bool condition, const std::string &error) const;
64 |
65 | public:
66 | friend class Tensor;
67 | };
68 |
69 |
70 | #endif //CPPFLOW_MODEL_H
71 |
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/ext/CppFlow/include/Tensor.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by sergio on 13/05/19.
3 | //
4 |
5 | #ifndef CPPFLOW_TENSOR_H
6 | #define CPPFLOW_TENSOR_H
7 |
8 | #include <algorithm>
9 | #include <cstring>
10 | #include <memory>
11 | #include <numeric>
12 | #include <string>
13 | #include <vector>
14 |
15 | // Prevent warnings from Tensorflow C API headers
16 |
17 | #pragma warning(push, 0)
18 | #include <tensorflow/c/c_api.h>
19 | #include "Model.h"
20 | #pragma warning(pop)
21 |
22 | class Model;
23 |
24 | class Tensor {
25 | public:
26 | Tensor(const Model& model, const std::string& operation);
27 |
28 | // Rule of five, moving is easy as the pointers can be copied, copying not as i have no idea how to copy
29 | // the contents of the pointer (i guess dereferencing won't do a deep copy)
30 | Tensor(const Tensor &tensor) = delete;
31 | Tensor(Tensor &&tensor) = default;
32 | Tensor& operator=(const Tensor &tensor) = delete;
33 | Tensor& operator=(Tensor &&tensor) = default;
34 |
35 | ~Tensor();
36 |
37 | void clean();
38 |
39 | template <typename T>
40 | void set_data(std::vector<T> new_data);
41 |
42 | template <typename T>
43 | void set_data(std::vector<T> new_data, const std::vector<int64_t>& new_shape);
44 |
45 | void set_data(const std::string & new_data, Model & inmodel);
46 | template <typename T>
47 | std::vector<T> get_data();
48 |
49 | std::vector<int64_t> get_shape();
50 |
51 | private:
52 | TF_Tensor* val;
53 | TF_Output op;
54 | TF_DataType type;
55 | std::vector<int64_t> shape;
56 | std::unique_ptr<std::vector<int64_t>> actual_shape;
57 | void* data;
58 | int flag;
59 |
60 | // Aux functions
61 | void error_check(bool condition, const std::string& error);
62 |
63 |
64 |
65 |
66 | template <typename T>
67 | static TF_DataType deduce_type();
68 |
69 | void deduce_shape();
70 |
71 | public:
72 | friend class Model;
73 | };
74 |
75 | #endif //CPPFLOW_TENSOR_H
76 |
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/ext/ZCharScanner.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #define GBasicCharScanner ZStringDelimiter
4 |
5 | #include <string>
6 | #include <vector>
7 |
8 | #define ZSDEL_USE_STD_STRING
9 | #ifndef ZSDEL_USE_STD_STRING
10 | #include "golem_string.h"
11 | #else
12 | #define GString std::string
13 | #endif
14 |
15 | typedef std::vector<GString>::const_iterator TokenIterator;
16 |
17 | // ZStringDelimiter
18 | // ==============
19 | // Simple class to delimit and split strings.
20 | // You can use operator[] to access them
21 | // Or you can use the itBegin() and itEnd() to get some iterators
22 | // =================
23 | class ZStringDelimiter
24 | {
25 | private:
26 | int key_search(const GString & s, const GString & key);
27 | void UpdateTokens();
28 | std::vector<GString> m_vTokens;
29 | std::vector<GString> m_vDelimiters;
30 |
31 | GString m_sString;
32 |
33 | void DelimStr(const GString& s, const GString& delimiter, const bool& removeEmptyEntries = false);
34 | void BarRange(const int& min, const int& max);
35 | void Bar(const int& pos);
36 | size_t tokenIndex;
37 | public:
38 | ZStringDelimiter();
39 | bool PgBar;
40 |
41 | #ifdef _AFX_ALL_WARNINGS
42 | CProgressCtrl* m_pBar;
43 | #endif
44 |
45 | ZStringDelimiter(const GString& in_iStr) {
46 | m_sString = in_iStr;
47 | PgBar = false;
48 |
49 | }
50 |
51 | bool GetFirstToken(GString& in_out);
52 | bool GetNextToken(GString& in_sOut);
53 |
54 | // std::String alts
55 |
56 | size_t szTokens() { return m_vTokens.size(); }
57 | GString operator[](const size_t& in_index);
58 |
59 | GString Reassemble(const GString & delim, const int & nelem = -1);
60 |
61 | // Override to reassemble provided tokens.
62 | GString Reassemble(const GString & delim, const std::vector<GString>& Strs,int nelem = -1);
63 |
64 | // Get a const reference to the tokens
65 | const std::vector<GString>& GetTokens() { return m_vTokens; }
66 |
67 | TokenIterator itBegin() { return m_vTokens.begin(); }
68 | TokenIterator itEnd() { return m_vTokens.end(); }
69 |
70 | void SetText(const GString& in_Txt) {
71 | m_sString = in_Txt;
72 | if (m_vDelimiters.size())
73 | UpdateTokens();
74 | }
75 | void AddDelimiter(const GString& in_Delim);
76 |
77 | ~ZStringDelimiter();
78 | };
79 |
80 |
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/phonemizer.h:
--------------------------------------------------------------------------------
1 | #ifndef PHONEMIZER_H
2 | #define PHONEMIZER_H
3 | #include "tfg2p.h"
4 | #include <string>
5 | #include <vector>
6 | #include <fstream>
7 |
8 | struct IdStr{
9 | int32_t ID;
10 | std::string STR;
11 | };
12 |
13 |
14 | struct StrStr{
15 | std::string Word;
16 | std::string Phn;
17 | };
18 |
19 |
20 | class Phonemizer
21 | {
22 | private:
23 | TFG2P G2pModel;
24 |
25 | std::vector<IdStr> CharId;
26 | std::vector<IdStr> PhnId;
27 |
28 |
29 |
30 |
31 |
32 |
33 | std::vector<std::string> GetDelimitedFile(const std::string& InFname);
34 |
35 |
36 | // Sorry, can't use set, unordered_map or any other types. (I tried)
37 | std::vector<StrStr> Dictionary;
38 |
39 | void LoadDictionary(const std::string& InDictFn);
40 |
41 | std::string DictLookup(const std::string& InWord);
42 |
43 |
44 |
45 | std::string PhnLanguage;
46 | public:
47 | Phonemizer();
48 | /*
49 | * Initialize a phonemizer
50 | * Expects:
51 | * - Two files consisting of TOKEN \t ID pairs:
52 | * -- char2id.txt: Translation from input character to ID the model can accept
53 | * -- phn2id.txt: Translation from output ID from the model to phoneme
54 | * - A model/ folder where a G2P-Tensorflow model was saved as SavedModel
55 | * - dict.txt: Phonetic dictionary. First it searches the word there and if it can't be found then it uses the model.
56 |
57 | */
58 | bool Initialize(const std::string InPath);
59 | std::string ProcessWord(const std::string& InWord, float Temperature = 0.1f);
60 | std::string GetPhnLanguage() const;
61 | void SetPhnLanguage(const std::string &value);
62 |
63 | std::string GetGraphemeChars();
64 |
65 | };
66 |
67 |
68 | bool operator<(const StrStr& right,const StrStr& left);
69 | #endif // PHONEMIZER_H
70 |
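71 | // Hypothetical usage sketch (folder layout as documented above):
72 | //   Phonemizer phn;
73 | //   phn.Initialize("g2p/English");
74 | //   std::string phones = phn.ProcessWord("hello"); // dictionary first, G2P model as fallback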
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/tfg2p.cpp:
--------------------------------------------------------------------------------
1 | #include "tfg2p.h"
2 | #include <stdexcept>
3 | TFG2P::TFG2P()
4 | {
5 | G2P = nullptr;
6 |
7 | }
8 |
9 | TFG2P::TFG2P(const std::string &SavedModelFolder)
10 | {
11 | G2P = nullptr;
12 |
13 | Initialize(SavedModelFolder);
14 | }
15 |
16 | bool TFG2P::Initialize(const std::string &SavedModelFolder)
17 | {
18 | try {
19 |
20 | G2P = new Model(SavedModelFolder);
21 |
22 | }
23 | catch (...) {
24 | G2P = nullptr;
25 | return false;
26 |
27 | }
28 | return true;
29 | }
30 |
31 | TFTensor<int32_t> TFG2P::DoInference(const std::vector<int32_t> &InputIDs, float Temperature)
32 | {
33 | if (!G2P)
34 | throw std::invalid_argument("Tried to do inference on unloaded or invalid model!");
35 |
36 | // Convenience reference so that we don't have to constantly derefer pointers.
37 | Model& Mdl = *G2P;
38 |
39 |
40 | 
41 |
42 | Tensor input_ids{ Mdl,"serving_default_input_ids" };
43 | Tensor input_len{Mdl,"serving_default_input_len"};
44 | Tensor input_temp{Mdl,"serving_default_input_temperature"};
45 |
46 | input_ids.set_data(InputIDs, std::vector<int64_t>{(int64_t)InputIDs.size()});
47 | input_len.set_data(std::vector<int32_t>{(int32_t)InputIDs.size()});
48 | input_temp.set_data(std::vector<float>{Temperature});
49 |
50 |
51 |
52 | std::vector<Tensor*> Inputs {&input_ids,&input_len,&input_temp};
53 | Tensor out_ids{ Mdl,"StatefulPartitionedCall" };
54 |
55 | Mdl.run(Inputs, out_ids);
56 |
57 | TFTensor<int32_t> RetTensor = VoxUtil::CopyTensor<int32_t>(out_ids);
58 |
59 | return RetTensor;
60 |
61 |
62 | }
63 |
64 | TFG2P::~TFG2P()
65 | {
66 | if (G2P)
67 | delete G2P;
68 |
69 | }
70 |
--------------------------------------------------------------------------------
/examples/cppwin/TensorflowTTSCppInference/tfg2p.h:
--------------------------------------------------------------------------------
1 | #ifndef TFG2P_H
2 | #define TFG2P_H
3 | #include "ext/CppFlow/include/Model.h"
4 | #include "VoxCommon.hpp"
5 |
6 |
7 | class TFG2P
8 | {
9 | private:
10 | Model* G2P;
11 |
12 | public:
13 | TFG2P();
14 | TFG2P(const std::string& SavedModelFolder);
15 |
16 | /*
17 | Initialize and load the model
18 |
19 | -> SavedModelFolder: Folder where the .pb, variables, and other components of the exported SavedModel are stored
20 | <- Returns: (bool)Success
21 | */
22 | bool Initialize(const std::string& SavedModelFolder);
23 |
24 | /*
25 | Do inference on a G2P-TF-RNN model.
26 |
27 | -> InputIDs: Input IDs of tokens for inference
28 | -> Temperature: Temperature of the RNN, values higher than 0.1 cause instability.
29 |
30 | <- Returns: TFTensor containing phoneme IDs
31 | */
32 | TFTensor<int32_t> DoInference(const std::vector<int32_t>& InputIDs, float Temperature = 0.1f);
33 |
34 | ~TFG2P();
35 |
36 | };
37 |
38 | #endif // TFG2P_H
39 |
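40 | // Hypothetical usage sketch (model path is an assumption):
41 | //   TFG2P g2p("g2p/model");
42 | //   TFTensor<int32_t> phonemeIds = g2p.DoInference(charIds, 0.1f);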
--------------------------------------------------------------------------------
/examples/fastspeech/fig/fastspeech.v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/fastspeech/fig/fastspeech.v1.png
--------------------------------------------------------------------------------
/examples/fastspeech2_libritts/README.md:
--------------------------------------------------------------------------------
1 | # FastSpeech 2 multi-speaker English model (LibriTTS-based)
2 |
3 | ## Prepare
4 | Everything is done from the main repo folder, i.e. TensorFlowTTS/
5 |
6 | 0. Optional* [Download](http://www.openslr.org/60/) and prepare LibriTTS (a helper notebook lives in examples/fastspeech2_libritts/libri_experiment/prepare_libri.ipynb)
7 | - Dataset structure after finish this step:
8 | ```
9 | |- TensorFlowTTS/
10 | | |- LibriTTS/
11 | | |- |- train-clean-100/
12 | | |- |- SPEAKERS.txt
13 | | |- |- ...
14 | | |- libritts/
15 | | |- |- 200/
16 | | |- |- |- 200_124139_000001_000000.txt
17 | | |- |- |- 200_124139_000001_000000.wav
18 | | |- |- |- ...
19 | | |- |- 250/
20 | | |- |- ...
21 | | |- tensorflow_tts/
22 | | |- models/
23 | | |- ...
24 | ```
25 | 1. Extract durations (use examples/mfa_extraction or a pretrained Tacotron 2)
26 | 2. Optional* build docker
27 | - ```
28 | bash examples/fastspeech2_libritts/scripts/build.sh
29 | ```
30 | 3. Optional* run docker
31 | - ```
32 | bash examples/fastspeech2_libritts/scripts/interactive.sh
33 | ```
34 | 4. Preprocessing:
35 | - ```
36 | tensorflow-tts-preprocess --rootdir ./libritts \
37 | --outdir ./dump_libritts \
38 | --config preprocess/libritts_preprocess.yaml \
39 | --dataset libritts
40 | ```
41 |
42 | 5. Normalization:
43 | - ```
44 | tensorflow-tts-normalize --rootdir ./dump_libritts \
45 | --outdir ./dump_libritts \
46 | --config preprocess/libritts_preprocess.yaml \
47 | --dataset libritts
48 | ```
49 |
50 | 6. Change the CharactorDurationF0EnergyMelDataset speaker mapper in fastspeech2_dataset to match your dataset (if you use LibriTTS with mfa_extraction, you don't need to change anything)
51 | 7. Change train_libri.sh to match your dataset and run:
52 | - ```
53 | bash examples/fastspeech2_libritts/scripts/train_libri.sh
54 | ```
55 | 8. Optional* If you run into tensor size mismatch problems, check step 5 in the `examples/mfa_extraction` directory
56 |
57 | ## Comments
58 |
59 | This version uses the popular train.txt '|' split used in other repos. Training files should look like this (see the example below):
60 |
61 | Wav Path | Text | Speaker Name
62 |
63 | Wav Path2 | Text | Speaker Name
64 |
65 |
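66 | For example (hypothetical wav paths, texts and speaker names):
67 | 
68 | ```
69 | ./libritts/200/200_124139_000001_000000.wav|This is the first utterance.|200
70 | ./libritts/250/250_124141_000002_000000.wav|And this one is from another speaker.|250
71 | ```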
--------------------------------------------------------------------------------
/examples/fastspeech2_libritts/scripts/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | docker build --rm -t tftts -f examples/fastspeech2_libritts/scripts/docker/Dockerfile .
3 |
--------------------------------------------------------------------------------
/examples/fastspeech2_libritts/scripts/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM tensorflow/tensorflow:2.2.0-gpu
2 | RUN apt-get update
3 | RUN apt-get install -y zsh tmux wget git libsndfile1
4 | ADD . /workspace/tts
5 | WORKDIR /workspace/tts
6 | RUN pip install .
7 |
8 |
--------------------------------------------------------------------------------
/examples/fastspeech2_libritts/scripts/interactive.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | docker run --gpus all --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 -it --rm --ipc=host -p 8888:8888 -v $PWD:/workspace/tts/ tftts bash
3 |
--------------------------------------------------------------------------------
/examples/fastspeech2_libritts/scripts/train_libri.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python examples/fastspeech2_libritts/train_fastspeech2.py \
2 | --train-dir ./dump/train/ \
3 | --dev-dir ./dump/valid/ \
4 | --outdir ./examples/fastspeech2_libritts/outdir_libri/ \
5 | --config ./examples/fastspeech2_libritts/conf/fastspeech2libritts.yaml \
6 | --use-norm 1 \
7 | --f0-stat ./dump/stats_f0.npy \
8 | --energy-stat ./dump/stats_energy.npy \
9 | --mixed_precision 1 \
10 | --dataset_config preprocess/libritts_preprocess.yaml \
11 | --dataset_stats dump/stats.npy
--------------------------------------------------------------------------------
/examples/ios/.gitignore:
--------------------------------------------------------------------------------
1 | Pods
2 | *.xcworkspace
3 | xcuserdata
4 |
--------------------------------------------------------------------------------
/examples/ios/Podfile:
--------------------------------------------------------------------------------
1 | platform :ios, '14.0'
2 |
3 | target 'TF_TTS_Demo' do
4 | pod 'TensorFlowLiteSwift'
5 | pod 'TensorFlowLiteSelectTfOps'
6 | end
7 |
--------------------------------------------------------------------------------
/examples/ios/Podfile.lock:
--------------------------------------------------------------------------------
1 | PODS:
2 | - TensorFlowLiteC (2.4.0):
3 | - TensorFlowLiteC/Core (= 2.4.0)
4 | - TensorFlowLiteC/Core (2.4.0)
5 | - TensorFlowLiteSelectTfOps (2.4.0)
6 | - TensorFlowLiteSwift (2.4.0):
7 | - TensorFlowLiteSwift/Core (= 2.4.0)
8 | - TensorFlowLiteSwift/Core (2.4.0):
9 | - TensorFlowLiteC (= 2.4.0)
10 |
11 | DEPENDENCIES:
12 | - TensorFlowLiteSelectTfOps
13 | - TensorFlowLiteSwift
14 |
15 | SPEC REPOS:
16 | trunk:
17 | - TensorFlowLiteC
18 | - TensorFlowLiteSelectTfOps
19 | - TensorFlowLiteSwift
20 |
21 | SPEC CHECKSUMS:
22 | TensorFlowLiteC: 09f8ac75a76caeadb19bcfa694e97454cc1ecf87
23 | TensorFlowLiteSelectTfOps: f8053d3ec72032887b832d2d060015d8b7031144
24 | TensorFlowLiteSwift: f062dc1178120100d825d7799fd9f115b4a37fee
25 |
26 | PODFILE CHECKSUM: 12da12fb22671b6cdc578320043f5d310043aded
27 |
28 | COCOAPODS: 1.10.1
29 |
--------------------------------------------------------------------------------
/examples/ios/README.md:
--------------------------------------------------------------------------------
1 | # iOS Demo
2 |
3 | This app demonstrates using FastSpeech2 and MB MelGAN models on iOS.
4 |
5 | ## How to build
6 |
7 | Download the LJ Speech TFLite models from https://github.com/luan78zaoha/TTS_tflite_cpp/releases/tag/0.1.0 and unpack them into the TF_TTS_Demo directory containing the Swift files.
8 | 
9 | The app uses [CocoaPods](https://cocoapods.org) to link against TensorFlowLiteSwift:
10 |
11 | ```
12 | pod install
13 | open TF_TTS_Demo.xcworkspace
14 | ```
15 |
16 |
--------------------------------------------------------------------------------
/examples/ios/TF_TTS_Demo/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "idiom" : "universal"
5 | }
6 | ],
7 | "info" : {
8 | "author" : "xcode",
9 | "version" : 1
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/examples/ios/TF_TTS_Demo/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "idiom" : "iphone",
5 | "scale" : "2x",
6 | "size" : "20x20"
7 | },
8 | {
9 | "idiom" : "iphone",
10 | "scale" : "3x",
11 | "size" : "20x20"
12 | },
13 | {
14 | "idiom" : "iphone",
15 | "scale" : "2x",
16 | "size" : "29x29"
17 | },
18 | {
19 | "idiom" : "iphone",
20 | "scale" : "3x",
21 | "size" : "29x29"
22 | },
23 | {
24 | "idiom" : "iphone",
25 | "scale" : "2x",
26 | "size" : "40x40"
27 | },
28 | {
29 | "idiom" : "iphone",
30 | "scale" : "3x",
31 | "size" : "40x40"
32 | },
33 | {
34 | "idiom" : "iphone",
35 | "scale" : "2x",
36 | "size" : "60x60"
37 | },
38 | {
39 | "idiom" : "iphone",
40 | "scale" : "3x",
41 | "size" : "60x60"
42 | },
43 | {
44 | "idiom" : "ipad",
45 | "scale" : "1x",
46 | "size" : "20x20"
47 | },
48 | {
49 | "idiom" : "ipad",
50 | "scale" : "2x",
51 | "size" : "20x20"
52 | },
53 | {
54 | "idiom" : "ipad",
55 | "scale" : "1x",
56 | "size" : "29x29"
57 | },
58 | {
59 | "idiom" : "ipad",
60 | "scale" : "2x",
61 | "size" : "29x29"
62 | },
63 | {
64 | "idiom" : "ipad",
65 | "scale" : "1x",
66 | "size" : "40x40"
67 | },
68 | {
69 | "idiom" : "ipad",
70 | "scale" : "2x",
71 | "size" : "40x40"
72 | },
73 | {
74 | "idiom" : "ipad",
75 | "scale" : "1x",
76 | "size" : "76x76"
77 | },
78 | {
79 | "idiom" : "ipad",
80 | "scale" : "2x",
81 | "size" : "76x76"
82 | },
83 | {
84 | "idiom" : "ipad",
85 | "scale" : "2x",
86 | "size" : "83.5x83.5"
87 | },
88 | {
89 | "idiom" : "ios-marketing",
90 | "scale" : "1x",
91 | "size" : "1024x1024"
92 | }
93 | ],
94 | "info" : {
95 | "author" : "xcode",
96 | "version" : 1
97 | }
98 | }
99 |
--------------------------------------------------------------------------------
/examples/ios/TF_TTS_Demo/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/examples/ios/TF_TTS_Demo/ContentView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // ContentView.swift
3 | // TF TTS Demo
4 | //
5 | // Created by 안창범 on 2021/03/16.
6 | //
7 |
8 | import SwiftUI
9 |
10 | struct ContentView: View {
11 | @StateObject var tts = TTS()
12 |
13 | @State var text = "The Rhodes Must Fall campaigners said the announcement was hopeful, but warned they would remain cautious until the college had actually carried out the removal."
14 |
15 | var body: some View {
16 | VStack {
17 | TextEditor(text: $text)
18 | Button {
19 | tts.speak(string: text)
20 | } label: {
21 | Label("Speak", systemImage: "speaker.1")
22 | }
23 | }
24 | .padding()
25 | }
26 | }
27 |
28 | struct ContentView_Previews: PreviewProvider {
29 | static var previews: some View {
30 | ContentView()
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/examples/ios/TF_TTS_Demo/FastSpeech2.swift:
--------------------------------------------------------------------------------
1 | //
2 | // FastSpeech2.swift
3 | // HelloTensorFlowTTS
4 | //
5 | // Created by 안창범 on 2021/03/09.
6 | //
7 |
8 | import Foundation
9 | import TensorFlowLite
10 |
11 | class FastSpeech2 {
12 | let interpreter: Interpreter
13 |
14 | var speakerId: Int32 = 0
15 |
16 | var f0Ratio: Float = 1
17 |
18 | var energyRatio: Float = 1
19 |
20 | init(url: URL) throws {
21 | var options = Interpreter.Options()
22 | options.threadCount = 5
23 | interpreter = try Interpreter(modelPath: url.path, options: options)
24 | }
25 |
26 | func getMelSpectrogram(inputIds: [Int32], speedRatio: Float) throws -> Tensor {
27 | try interpreter.resizeInput(at: 0, to: [1, inputIds.count])
28 | try interpreter.allocateTensors()
29 |
30 |         let data = inputIds.withUnsafeBufferPointer(Data.init)  // pack the Int32 ids into a byte buffer
31 |         try interpreter.copy(data, toInputAt: 0)  // input ids
32 |         try interpreter.copy(Data(bytes: &speakerId, count: 4), toInputAt: 1)  // speaker id
33 |         var speedRatio = speedRatio  // mutable copy so its address can be taken below
34 |         try interpreter.copy(Data(bytes: &speedRatio, count: 4), toInputAt: 2)  // speed ratio
35 |         try interpreter.copy(Data(bytes: &f0Ratio, count: 4), toInputAt: 3)  // f0 ratio
36 |         try interpreter.copy(Data(bytes: &energyRatio, count: 4), toInputAt: 4)  // energy ratio
37 |
38 | let t0 = Date()
39 | try interpreter.invoke()
40 | print("fastspeech2: \(Date().timeIntervalSince(t0))s")
41 |
42 | return try interpreter.output(at: 1)
43 | }
44 | }
45 |
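46 | // Note: the mel tensor returned by getMelSpectrogram is intended to be fed to
47 | // MBMelGan.getAudio(input:) (see MBMelGAN.swift) to synthesize the waveform.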
--------------------------------------------------------------------------------
/examples/ios/TF_TTS_Demo/Info.plist:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | 	<key>CFBundleDevelopmentRegion</key>
6 | 	<string>$(DEVELOPMENT_LANGUAGE)</string>
7 | 	<key>CFBundleExecutable</key>
8 | 	<string>$(EXECUTABLE_NAME)</string>
9 | 	<key>CFBundleIdentifier</key>
10 | 	<string>$(PRODUCT_BUNDLE_IDENTIFIER)</string>
11 | 	<key>CFBundleInfoDictionaryVersion</key>
12 | 	<string>6.0</string>
13 | 	<key>CFBundleName</key>
14 | 	<string>$(PRODUCT_NAME)</string>
15 | 	<key>CFBundlePackageType</key>
16 | 	<string>$(PRODUCT_BUNDLE_PACKAGE_TYPE)</string>
17 | 	<key>CFBundleShortVersionString</key>
18 | 	<string>1.0</string>
19 | 	<key>CFBundleVersion</key>
20 | 	<string>1</string>
21 | 	<key>LSRequiresIPhoneOS</key>
22 | 	<true/>
23 | 	<key>UIApplicationSceneManifest</key>
24 | 	<dict>
25 | 		<key>UIApplicationSupportsMultipleScenes</key>
26 | 		<false/>
27 | 	</dict>
28 | 	<key>UIApplicationSupportsIndirectInputEvents</key>
29 | 	<true/>
30 | 	<key>UILaunchScreen</key>
31 | 	<dict/>
32 | 	<key>UIRequiredDeviceCapabilities</key>
33 | 	<array>
34 | 		<string>armv7</string>
35 | 	</array>
36 | 	<key>UISupportedInterfaceOrientations</key>
37 | 	<array>
38 | 		<string>UIInterfaceOrientationPortrait</string>
39 | 		<string>UIInterfaceOrientationLandscapeLeft</string>
40 | 		<string>UIInterfaceOrientationLandscapeRight</string>
41 | 	</array>
42 | 	<key>UISupportedInterfaceOrientations~ipad</key>
43 | 	<array>
44 | 		<string>UIInterfaceOrientationPortrait</string>
45 | 		<string>UIInterfaceOrientationPortraitUpsideDown</string>
46 | 		<string>UIInterfaceOrientationLandscapeLeft</string>
47 | 		<string>UIInterfaceOrientationLandscapeRight</string>
48 | 	</array>
49 | </dict>
50 | </plist>
51 | 
--------------------------------------------------------------------------------
/examples/ios/TF_TTS_Demo/MBMelGAN.swift:
--------------------------------------------------------------------------------
1 | //
2 | // MBMelGAN.swift
3 | // HelloTensorFlowTTS
4 | //
5 | // Created by 안창범 on 2021/03/09.
6 | //
7 |
8 | import Foundation
9 | import TensorFlowLite
10 |
11 | class MBMelGan {
12 | let interpreter: Interpreter
13 |
14 | init(url: URL) throws {
15 | var options = Interpreter.Options()
16 | options.threadCount = 5
17 | interpreter = try Interpreter(modelPath: url.path, options: options)
18 | }
19 |
20 | func getAudio(input: Tensor) throws -> Data {
21 | try interpreter.resizeInput(at: 0, to: input.shape)
22 | try interpreter.allocateTensors()
23 |
24 | try interpreter.copy(input.data, toInputAt: 0)
25 |
26 | let t0 = Date()
27 | try interpreter.invoke()
28 | print("mbmelgan: \(Date().timeIntervalSince(t0))s")
29 |
30 | return try interpreter.output(at: 0).data
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/examples/ios/TF_TTS_Demo/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/examples/ios/TF_TTS_Demo/TF_TTS_DemoApp.swift:
--------------------------------------------------------------------------------
1 | //
2 | // TF_TTS_DemoApp.swift
3 | // TF TTS Demo
4 | //
5 | // Created by 안창범 on 2021/03/16.
6 | //
7 |
8 | import SwiftUI
9 |
10 | @main
11 | struct TF_TTS_DemoApp: App {
12 | var body: some Scene {
13 | WindowGroup {
14 | ContentView()
15 | }
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/examples/melgan/fig/melgan.v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/melgan/fig/melgan.v1.png
--------------------------------------------------------------------------------
/examples/melgan_stft/fig/melgan.stft.v1.eval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/melgan_stft/fig/melgan.stft.v1.eval.png
--------------------------------------------------------------------------------
/examples/melgan_stft/fig/melgan.stft.v1.train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/melgan_stft/fig/melgan.stft.v1.train.png
--------------------------------------------------------------------------------
/examples/mfa_extraction/README.md:
--------------------------------------------------------------------------------
1 | # MFA based extraction for FastSpeech
2 |
3 | ## Prepare
4 | All commands are run from the main repository folder, i.e. `TensorFlowTTS/`.
5 |
6 | 0. (Optional) Modify the MFA scripts to work with your language (see https://montreal-forced-aligner.readthedocs.io/en/latest/pretrained_models.html)
7 |
8 | 1. Download pretrained mfa, lexicon and run extract textgrids:
9 |
10 | - ```
11 | bash examples/mfa_extraction/scripts/prepare_mfa.sh
12 | ```
13 |
14 | - ```
15 | python examples/mfa_extraction/run_mfa.py \
16 | --corpus_directory ./libritts \
17 | --output_directory ./mfa/parsed \
18 | --jobs 8
19 | ```
20 |
21 | After this step, the TextGrid files are located at `./mfa/parsed`.
22 |
23 | 2. Extract durations from the TextGrid files:
24 | - ```
25 | python examples/mfa_extraction/txt_grid_parser.py \
26 | --yaml_path examples/fastspeech2_libritts/conf/fastspeech2libritts.yaml \
27 | --dataset_path ./libritts \
28 | --text_grid_path ./mfa/parsed \
29 | --output_durations_path ./libritts/durations \
30 | --sample_rate 24000
31 | ```
32 |
33 | - Dataset structure after finishing this step:
34 | ```
35 | |- TensorFlowTTS/
36 | | |- LibriTTS/
37 | | |- |- train-clean-100/
38 | | |- |- SPEAKERS.txt
39 | | |- |- ...
40 | | |- dataset/
41 | | |- |- 200/
42 | | |- |- |- 200_124139_000001_000000.txt
43 | | |- |- |- 200_124139_000001_000000.wav
44 | | |- |- |- ...
45 | | |- |- 250/
46 | | |- |- ...
47 | | |- |- durations/
48 | | |- |- train.txt
49 | | |- tensorflow_tts/
50 | | |- models/
51 | | |- ...
52 | ```
53 | 3. (Optional) Add your own dataset parser based on `tensorflow_tts/processor/experiment/example_dataset.py` (if the base processor does not match your dataset).
54 | 
55 | 4. Run preprocessing and normalization (steps 4 and 5 in `examples/fastspeech2_libritts/README.MD`).
56 | 
57 | 5. Run the mismatch fixer to correct small frame-count differences between the audio and duration files:
58 |
59 | - ```
60 | python examples/mfa_extraction/fix_mismatch.py \
61 | --base_path ./dump \
62 | --trimmed_dur_path ./dataset/trimmed-durations \
63 | --dur_path ./dataset/durations
64 | ```
65 |
66 | ## Problems with MFA extraction
67 | MFA seems to have problems with trimmed files; in my experiments it works better with ~100 ms of silence at the start and end.
68 | 
69 | Short files can produce many false positives, such as silence-only extractions (seen on LibriTTS), so I would keep only samples longer than 2 s; a filtering sketch follows below.
70 |
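71 | A minimal sketch of such a length filter (not part of this repo; assumes `soundfile` is installed and the `wav` + `txt` layout shown above -- adjust the paths to your dataset):
72 | 
73 | ```
74 | import shutil
75 | from pathlib import Path
76 | 
77 | import soundfile as sf
78 | 
79 | MIN_SECONDS = 2.0  # drop clips shorter than this before running MFA
80 | 
81 | 
82 | def filter_short_clips(corpus_dir: str, rejected_dir: str) -> None:
83 |     """Move too-short wav files (and their transcripts) out of the corpus."""
84 |     rejected = Path(rejected_dir)
85 |     rejected.mkdir(parents=True, exist_ok=True)
86 |     for wav in list(Path(corpus_dir).rglob("*.wav")):
87 |         info = sf.info(str(wav))
88 |         if info.frames / info.samplerate < MIN_SECONDS:
89 |             shutil.move(str(wav), str(rejected / wav.name))
90 |             txt = wav.with_suffix(".txt")
91 |             if txt.exists():
92 |                 shutil.move(str(txt), str(rejected / txt.name))
93 | 
94 | 
95 | if __name__ == "__main__":
96 |     filter_short_clips("./libritts", "./libritts_rejected")
97 | ```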
--------------------------------------------------------------------------------
/examples/mfa_extraction/requirements.txt:
--------------------------------------------------------------------------------
1 | textgrid
2 | click
3 | g2p_en
--------------------------------------------------------------------------------
/examples/mfa_extraction/run_mfa.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 TensorFlowTTS Team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Running MFA to extract TextGrids."""
16 |
17 | from subprocess import call
18 | from pathlib import Path
19 |
20 | import click
21 | import os
22 |
23 |
24 | @click.command()
25 | @click.option("--mfa_path", default=os.path.join('mfa', 'montreal-forced-aligner', 'bin', 'mfa_align'))
26 | @click.option("--corpus_directory", default="libritts")
27 | @click.option("--lexicon", default=os.path.join('mfa', 'lexicon', 'librispeech-lexicon.txt'))
28 | @click.option("--acoustic_model_path", default=os.path.join('mfa', 'montreal-forced-aligner', 'pretrained_models', 'english.zip'))
29 | @click.option("--output_directory", default=os.path.join('mfa', 'parsed'))
30 | @click.option("--jobs", default="8")
31 | def run_mfa(
32 | mfa_path: str,
33 | corpus_directory: str,
34 | lexicon: str,
35 | acoustic_model_path: str,
36 | output_directory: str,
37 | jobs: str,
38 | ):
39 | Path(output_directory).mkdir(parents=True, exist_ok=True)
40 | call(
41 | [
42 | f".{os.path.sep}{mfa_path}",
43 | corpus_directory,
44 | lexicon,
45 | acoustic_model_path,
46 | output_directory,
47 |             "-j", jobs,  # pass the flag and its value as separate argv entries
48 | ]
49 | )
50 |
51 |
52 | if __name__ == "__main__":
53 | run_mfa()
54 |
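55 | # Example invocation (defaults match the layout created by
56 | # examples/mfa_extraction/scripts/prepare_mfa.sh):
57 | #   python examples/mfa_extraction/run_mfa.py \
58 | #     --corpus_directory ./libritts --output_directory ./mfa/parsed --jobs 8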
--------------------------------------------------------------------------------
/examples/mfa_extraction/scripts/prepare_mfa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir mfa
3 | cd mfa
4 | wget https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/download/v1.1.0-beta.2/montreal-forced-aligner_linux.tar.gz
5 | tar -zxvf montreal-forced-aligner_linux.tar.gz
6 | mkdir lexicon
7 | cd lexicon
8 | wget http://www.openslr.org/resources/11/librispeech-lexicon.txt
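9 | 
10 | # Resulting layout (matches the defaults in examples/mfa_extraction/run_mfa.py):
11 | #   mfa/montreal-forced-aligner/bin/mfa_align
12 | #   mfa/montreal-forced-aligner/pretrained_models/english.zip
13 | #   mfa/lexicon/librispeech-lexicon.txt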
--------------------------------------------------------------------------------
/examples/multiband_melgan/fig/eval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/multiband_melgan/fig/eval.png
--------------------------------------------------------------------------------
/examples/multiband_melgan/fig/train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/multiband_melgan/fig/train.png
--------------------------------------------------------------------------------
/examples/multiband_melgan_hf/fig/eval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/multiband_melgan_hf/fig/eval.png
--------------------------------------------------------------------------------
/examples/multiband_melgan_hf/fig/train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/multiband_melgan_hf/fig/train.png
--------------------------------------------------------------------------------
/examples/tacotron2/fig/alignment.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/tacotron2/fig/alignment.gif
--------------------------------------------------------------------------------
/examples/tacotron2/fig/tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/examples/tacotron2/fig/tensorboard.png
--------------------------------------------------------------------------------
/notebooks/multiband_melgan_inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import yaml\n",
10 | "import numpy as np\n",
11 | "import matplotlib.pyplot as plt\n",
12 | "\n",
13 | "import tensorflow as tf\n",
14 | "\n",
15 | "from tensorflow_tts.inference import AutoConfig\n",
16 | "from tensorflow_tts.inference import TFAutoModel"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "mb_melgan = TFAutoModel.from_pretrained(\"tensorspeech/tts-mb_melgan-ljspeech-en\")"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "# Save to Pb"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "tf.saved_model.save(mb_melgan, \"./mb_melgan\", signatures=mb_melgan.inference)"
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {},
47 | "source": [
48 | "# Load and Inference"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "mb_melgan = tf.saved_model.load(\"./mb_melgan\")"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "mels = np.load(\"../dump/valid/norm-feats/LJ001-0009-norm-feats.npy\")"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "audios = mb_melgan.inference(mels[None, ...])"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "plt.plot(audios[0, :, 0])"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": []
93 | }
94 | ],
95 | "metadata": {
96 | "kernelspec": {
97 | "display_name": "Python 3",
98 | "language": "python",
99 | "name": "python3"
100 | },
101 | "language_info": {
102 | "codemirror_mode": {
103 | "name": "ipython",
104 | "version": 3
105 | },
106 | "file_extension": ".py",
107 | "mimetype": "text/x-python",
108 | "name": "python",
109 | "nbconvert_exporter": "python",
110 | "pygments_lexer": "ipython3",
111 | "version": "3.7.7"
112 | }
113 | },
114 | "nbformat": 4,
115 | "nbformat_minor": 4
116 | }
117 |
--------------------------------------------------------------------------------
/preprocess/baker_preprocess.yaml:
--------------------------------------------------------------------------------
1 | ###########################################################
2 | # FEATURE EXTRACTION SETTING #
3 | ###########################################################
4 | sampling_rate: 24000 # Sampling rate.
5 | fft_size: 2048 # FFT size.
6 | hop_size: 300 # Hop size. (fixed value, don't change)
7 | win_length: 1200 # Window length.
8 | # If set to null, it will be the same as fft_size.
9 | window: "hann" # Window function.
10 | num_mels: 80 # Number of mel basis.
11 | fmin: 80 # Minimum freq in mel basis calculation.
12 | fmax: 7600 # Maximum frequency in mel basis calculation.
13 | global_gain_scale: 1.0     # Multiplied with the entire waveform.
14 | trim_silence: true # Whether to trim the start and end of silence.
15 | trim_threshold_in_db: 60 # Need to tune carefully if the recording is not good.
16 | trim_frame_size: 2048 # Frame size in trimming.
17 | trim_hop_size: 512 # Hop size in trimming.
18 | format: "npy" # Feature file format. Only "npy" is supported.
19 |
20 |
--------------------------------------------------------------------------------
/preprocess/jsut_preprocess.yaml:
--------------------------------------------------------------------------------
1 | ###########################################################
2 | # FEATURE EXTRACTION SETTING #
3 | ###########################################################
4 | sampling_rate: 24000 # Sampling rate.
5 | fft_size: 2048 # FFT size.
6 | hop_size: 300 # Hop size. (fixed value, don't change)
7 | win_length: 1200 # Window length.
8 | # If set to null, it will be the same as fft_size.
9 | window: "hann" # Window function.
10 | num_mels: 80 # Number of mel basis.
11 | fmin: 80 # Minimum freq in mel basis calculation.
12 | fmax: 7600 # Maximum frequency in mel basis calculation.
13 | global_gain_scale: 1.0     # Multiplied with the entire waveform.
14 | trim_silence: true # Whether to trim the start and end of silence.
15 | trim_threshold_in_db: 60 # Need to tune carefully if the recording is not good.
16 | trim_frame_size: 2048 # Frame size in trimming.
17 | trim_hop_size: 512 # Hop size in trimming.
18 | format: "npy" # Feature file format. Only "npy" is supported.
19 |
20 |
--------------------------------------------------------------------------------
/preprocess/kss_preprocess.yaml:
--------------------------------------------------------------------------------
1 | ###########################################################
2 | # FEATURE EXTRACTION SETTING #
3 | ###########################################################
4 | sampling_rate: 22050 # Sampling rate.
5 | fft_size: 1024 # FFT size.
6 | hop_size: 256 # Hop size. (fixed value, don't change)
7 | win_length: null # Window length.
8 | # If set to null, it will be the same as fft_size.
9 | window: "hann" # Window function.
10 | num_mels: 80 # Number of mel basis.
11 | fmin: 80 # Minimum freq in mel basis calculation.
12 | fmax: 7600 # Maximum frequency in mel basis calculation.
13 | global_gain_scale: 1.0     # Multiplied with the entire waveform.
14 | trim_silence: true # Whether to trim the start and end of silence.
15 | trim_threshold_in_db: 30 # Need to tune carefully if the recording is not good.
16 | trim_frame_size: 2048 # Frame size in trimming.
17 | trim_hop_size: 512 # Hop size in trimming.
18 | format: "npy" # Feature file format. Only "npy" is supported.
19 |
20 |
--------------------------------------------------------------------------------
/preprocess/libritts_preprocess.yaml:
--------------------------------------------------------------------------------
1 | ###########################################################
2 | # FEATURE EXTRACTION SETTING #
3 | ###########################################################
4 | sampling_rate: 24000 # Sampling rate.
5 | fft_size: 1024 # FFT size.
6 | hop_size: 300 # Hop size. (fixed value, don't change)
7 | win_length: null # Window length.
8 | # If set to null, it will be the same as fft_size.
9 | window: "hann" # Window function.
10 | num_mels: 80 # Number of mel basis.
11 | fmin: 80 # Minimum freq in mel basis calculation.
12 | fmax: 7600 # Maximum frequency in mel basis calculation.
13 | global_gain_scale: 1.0     # Multiplied with the entire waveform.
14 | trim_silence: true # Whether to trim the start and end of silence.
15 | trim_threshold_in_db: 60 # Need to tune carefully if the recording is not good.
16 | trim_frame_size: 2048 # Frame size in trimming.
17 | trim_hop_size: 512 # Hop size in trimming.
18 | format: "npy" # Feature file format. Only "npy" is supported.
19 | trim_mfa: true
20 |
21 |
--------------------------------------------------------------------------------
/preprocess/ljspeech_preprocess.yaml:
--------------------------------------------------------------------------------
1 | ###########################################################
2 | # FEATURE EXTRACTION SETTING #
3 | ###########################################################
4 | sampling_rate: 22050 # Sampling rate.
5 | fft_size: 1024 # FFT size.
6 | hop_size: 256 # Hop size. (fixed value, don't change)
7 | win_length: null # Window length.
8 | # If set to null, it will be the same as fft_size.
9 | window: "hann" # Window function.
10 | num_mels: 80 # Number of mel basis.
11 | fmin: 80 # Minimum freq in mel basis calculation.
12 | fmax: 7600 # Maximum frequency in mel basis calculation.
13 | global_gain_scale: 1.0     # Multiplied with the entire waveform.
14 | trim_silence: true # Whether to trim the start and end of silence.
15 | trim_threshold_in_db: 60 # Need to tune carefully if the recording is not good.
16 | trim_frame_size: 2048 # Frame size in trimming.
17 | trim_hop_size: 512 # Hop size in trimming.
18 | format: "npy" # Feature file format. Only "npy" is supported.
19 |
20 |
--------------------------------------------------------------------------------
/preprocess/ljspeechu_preprocess.yaml:
--------------------------------------------------------------------------------
1 | ###########################################################
2 | # FEATURE EXTRACTION SETTING #
3 | ###########################################################
4 | sampling_rate: 44100 # Sampling rate.
5 | fft_size: 2048 # FFT size.
6 | hop_size: 512 # Hop size. (fixed value, don't change)
7 | win_length: 2048 # Window length.
8 | # If set to null, it will be the same as fft_size.
9 | window: "hann" # Window function.
10 | num_mels: 80 # Number of mel basis.
11 | fmin: 20 # Minimum freq in mel basis calculation.
12 | fmax: 11025 # Maximum frequency in mel basis calculation.
13 | global_gain_scale: 1.0     # Multiplied with the entire waveform.
14 | trim_silence: false # Whether to trim the start and end of silence
15 | trim_threshold_in_db: 60 # Need to tune carefully if the recording is not good.
16 | trim_frame_size: 2048 # Frame size in trimming.
17 | trim_hop_size: 512 # Hop size in trimming.
18 | format: "npy" # Feature file format. Only "npy" is supported.
19 | trim_mfa: false
--------------------------------------------------------------------------------
/preprocess/synpaflex_preprocess.yaml:
--------------------------------------------------------------------------------
1 | ###########################################################
2 | # FEATURE EXTRACTION SETTING #
3 | ###########################################################
4 | sampling_rate: 22050 # Sampling rate.
5 | fft_size: 1024 # FFT size.
6 | hop_size: 256 # Hop size. (fixed value, don't change)
7 | win_length: null # Window length.
8 | # If set to null, it will be the same as fft_size.
9 | window: "hann" # Window function.
10 | num_mels: 80 # Number of mel basis.
11 | fmin: 80 # Minimum freq in mel basis calculation.
12 | fmax: 7600 # Maximum frequency in mel basis calculation.
13 | global_gain_scale: 1.0     # Multiplied with the entire waveform.
14 | trim_silence: true # Whether to trim the start and end of silence.
15 | trim_threshold_in_db: 20 # Need to tune carefully if the recording is not good.
16 | trim_frame_size: 2048 # Frame size in trimming.
17 | trim_hop_size: 512 # Hop size in trimming.
18 | format: "npy" # Feature file format. Only "npy" is supported.
19 |
20 |
--------------------------------------------------------------------------------
/preprocess/thorsten_preprocess.yaml:
--------------------------------------------------------------------------------
1 | ###########################################################
2 | # FEATURE EXTRACTION SETTING #
3 | ###########################################################
4 | sampling_rate: 22050 # Sampling rate.
5 | fft_size: 1024 # FFT size.
6 | hop_size: 256 # Hop size. (fixed value, don't change)
7 | win_length: null # Window length.
8 | # If set to null, it will be the same as fft_size.
9 | window: "hann" # Window function.
10 | num_mels: 80 # Number of mel basis.
11 | fmin: 80 # Minimum freq in mel basis calculation.
12 | fmax: 7600 # Maximum frequency in mel basis calculation.
13 | global_gain_scale: 1.0     # Multiplied with the entire waveform.
14 | trim_silence: true # Whether to trim the start and end of silence.
15 | trim_threshold_in_db: 60 # Need to tune carefully if the recording is not good.
16 | trim_frame_size: 2048 # Frame size in trimming.
17 | trim_hop_size: 512 # Hop size in trimming.
18 | format: "npy" # Feature file format. Only "npy" is supported.
19 |
20 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [aliases]
2 | test=pytest
3 |
4 | [tool:pytest]
5 | addopts = --verbose --durations=0
6 | testpaths = test
7 |
8 | [flake8]
9 | ignore = H102,W504,H238,D104,H306,H405,D205
10 | # 120 is a workaround, 79 is good
11 | max-line-length = 120
12 |
--------------------------------------------------------------------------------
/tensorflow_tts/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0"
2 |
--------------------------------------------------------------------------------
/tensorflow_tts/bin/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorSpeech/TensorFlowTTS/136877136355c82d7ba474ceb7a8f133bd84767e/tensorflow_tts/bin/__init__.py
--------------------------------------------------------------------------------
/tensorflow_tts/configs/__init__.py:
--------------------------------------------------------------------------------
1 | from tensorflow_tts.configs.base_config import BaseConfig
2 | from tensorflow_tts.configs.fastspeech import FastSpeechConfig
3 | from tensorflow_tts.configs.fastspeech2 import FastSpeech2Config
4 | from tensorflow_tts.configs.melgan import (
5 | MelGANDiscriminatorConfig,
6 | MelGANGeneratorConfig,
7 | )
8 | from tensorflow_tts.configs.mb_melgan import (
9 | MultiBandMelGANDiscriminatorConfig,
10 | MultiBandMelGANGeneratorConfig,
11 | )
12 | from tensorflow_tts.configs.hifigan import (
13 | HifiGANGeneratorConfig,
14 | HifiGANDiscriminatorConfig,
15 | )
16 | from tensorflow_tts.configs.tacotron2 import Tacotron2Config
17 | from tensorflow_tts.configs.parallel_wavegan import ParallelWaveGANGeneratorConfig
18 | from tensorflow_tts.configs.parallel_wavegan import ParallelWaveGANDiscriminatorConfig
19 |
--------------------------------------------------------------------------------
/tensorflow_tts/configs/base_config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 TensorFlowTTS Team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Base Config for all config."""
16 |
17 | import abc
18 | import yaml
19 | import os
20 |
21 | from tensorflow_tts.utils.utils import CONFIG_FILE_NAME
22 |
23 |
24 | class BaseConfig(abc.ABC):
25 | def set_config_params(self, config_params):
26 | self.config_params = config_params
27 |
28 | def save_pretrained(self, saved_path):
29 | """Save config to file"""
30 | os.makedirs(saved_path, exist_ok=True)
31 | with open(os.path.join(saved_path, CONFIG_FILE_NAME), "w") as file:
32 | yaml.dump(self.config_params, file, Dumper=yaml.Dumper)
33 |
--------------------------------------------------------------------------------
/tensorflow_tts/configs/fastspeech2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 Minh Nguyen (@dathudeptrai)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """FastSpeech2 Config object."""
16 |
17 |
18 | from tensorflow_tts.configs import FastSpeechConfig
19 |
20 |
21 | class FastSpeech2Config(FastSpeechConfig):
22 | """Initialize FastSpeech2 Config."""
23 |
24 | def __init__(
25 | self,
26 | variant_prediction_num_conv_layers=2,
27 | variant_kernel_size=9,
28 | variant_dropout_rate=0.5,
29 | variant_predictor_filter=256,
30 | variant_predictor_kernel_size=3,
31 | variant_predictor_dropout_rate=0.5,
32 | **kwargs
33 | ):
34 | super().__init__(**kwargs)
35 | self.variant_prediction_num_conv_layers = variant_prediction_num_conv_layers
36 | self.variant_predictor_kernel_size = variant_predictor_kernel_size
37 | self.variant_predictor_dropout_rate = variant_predictor_dropout_rate
38 | self.variant_predictor_filter = variant_predictor_filter
39 |
--------------------------------------------------------------------------------
/tensorflow_tts/configs/hifigan.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 TensorflowTTS Team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """HifiGAN Config object."""
16 |
17 |
18 | from tensorflow_tts.configs import BaseConfig
19 |
20 |
21 | class HifiGANGeneratorConfig(BaseConfig):
22 | """Initialize HifiGAN Generator Config."""
23 |
24 | def __init__(
25 | self,
26 | out_channels=1,
27 | kernel_size=7,
28 | filters=128,
29 | use_bias=True,
30 | upsample_scales=[8, 8, 2, 2],
31 | stacks=3,
32 | stack_kernel_size=[3, 7, 11],
33 | stack_dilation_rate=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
34 | nonlinear_activation="LeakyReLU",
35 | nonlinear_activation_params={"alpha": 0.2},
36 | padding_type="REFLECT",
37 | use_final_nolinear_activation=True,
38 | is_weight_norm=True,
39 | initializer_seed=42,
40 | **kwargs
41 | ):
42 | """Init parameters for HifiGAN Generator model."""
43 | self.out_channels = out_channels
44 | self.kernel_size = kernel_size
45 | self.filters = filters
46 | self.use_bias = use_bias
47 | self.upsample_scales = upsample_scales
48 | self.stacks = stacks
49 | self.stack_kernel_size = stack_kernel_size
50 | self.stack_dilation_rate = stack_dilation_rate
51 | self.nonlinear_activation = nonlinear_activation
52 | self.nonlinear_activation_params = nonlinear_activation_params
53 | self.padding_type = padding_type
54 | self.use_final_nolinear_activation = use_final_nolinear_activation
55 | self.is_weight_norm = is_weight_norm
56 | self.initializer_seed = initializer_seed
57 |
58 |
59 | class HifiGANDiscriminatorConfig(object):
60 | """Initialize HifiGAN Discriminator Config."""
61 |
62 | def __init__(
63 | self,
64 | out_channels=1,
65 | period_scales=[2, 3, 5, 7, 11],
66 | n_layers=5,
67 | kernel_size=5,
68 | strides=3,
69 | filters=8,
70 | filter_scales=4,
71 | max_filters=1024,
72 | nonlinear_activation="LeakyReLU",
73 | nonlinear_activation_params={"alpha": 0.2},
74 | is_weight_norm=True,
75 | initializer_seed=42,
76 | **kwargs
77 | ):
78 |         """Init parameters for HifiGAN Discriminator model."""
79 | self.out_channels = out_channels
80 | self.period_scales = period_scales
81 | self.n_layers = n_layers
82 | self.kernel_size = kernel_size
83 | self.strides = strides
84 | self.filters = filters
85 | self.filter_scales = filter_scales
86 | self.max_filters = max_filters
87 | self.nonlinear_activation = nonlinear_activation
88 | self.nonlinear_activation_params = nonlinear_activation_params
89 | self.is_weight_norm = is_weight_norm
90 | self.initializer_seed = initializer_seed
91 |
--------------------------------------------------------------------------------
/tensorflow_tts/configs/mb_melgan.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 Minh Nguyen (@dathudeptrai)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Multi-band MelGAN Config object."""
16 |
17 | from tensorflow_tts.configs import MelGANDiscriminatorConfig, MelGANGeneratorConfig
18 |
19 |
20 | class MultiBandMelGANGeneratorConfig(MelGANGeneratorConfig):
21 | """Initialize Multi-band MelGAN Generator Config."""
22 |
23 | def __init__(self, **kwargs):
24 | super().__init__(**kwargs)
25 | self.subbands = kwargs.pop("subbands", 4)
26 | self.taps = kwargs.pop("taps", 62)
27 | self.cutoff_ratio = kwargs.pop("cutoff_ratio", 0.142)
28 | self.beta = kwargs.pop("beta", 9.0)
29 |
30 |
31 | class MultiBandMelGANDiscriminatorConfig(MelGANDiscriminatorConfig):
32 | """Initialize Multi-band MelGAN Discriminator Config."""
33 |
34 | def __init__(self, **kwargs):
35 | super().__init__(**kwargs)
36 |
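37 | # Note: subbands, taps, cutoff_ratio and beta parameterize the PQMF filter
38 | # bank (see TFPQMF in tensorflow_tts.models.mb_melgan) that splits and
39 | # re-merges the multi-band waveform.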
--------------------------------------------------------------------------------
/tensorflow_tts/configs/parallel_wavegan.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 TensorFlowTTS Team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ParallelWaveGAN Config object."""
16 |
17 |
18 | from tensorflow_tts.configs import BaseConfig
19 |
20 |
21 | class ParallelWaveGANGeneratorConfig(BaseConfig):
22 | """Initialize ParallelWaveGAN Generator Config."""
23 |
24 | def __init__(
25 | self,
26 | out_channels=1,
27 | kernel_size=3,
28 | n_layers=30,
29 | stacks=3,
30 | residual_channels=64,
31 | gate_channels=128,
32 | skip_channels=64,
33 | aux_channels=80,
34 | aux_context_window=2,
35 | dropout_rate=0.0,
36 | use_bias=True,
37 | use_causal_conv=False,
38 | upsample_conditional_features=True,
39 | upsample_params={"upsample_scales": [4, 4, 4, 4]},
40 | initializer_seed=42,
41 | **kwargs,
42 | ):
43 | """Init parameters for ParallelWaveGAN Generator model."""
44 | self.out_channels = out_channels
45 | self.kernel_size = kernel_size
46 | self.n_layers = n_layers
47 | self.stacks = stacks
48 | self.residual_channels = residual_channels
49 | self.gate_channels = gate_channels
50 | self.skip_channels = skip_channels
51 | self.aux_channels = aux_channels
52 | self.aux_context_window = aux_context_window
53 | self.dropout_rate = dropout_rate
54 | self.use_bias = use_bias
55 | self.use_causal_conv = use_causal_conv
56 | self.upsample_conditional_features = upsample_conditional_features
57 | self.upsample_params = upsample_params
58 | self.initializer_seed = initializer_seed
59 |
60 |
61 | class ParallelWaveGANDiscriminatorConfig(object):
62 | """Initialize ParallelWaveGAN Discriminator Config."""
63 |
64 | def __init__(
65 | self,
66 | out_channels=1,
67 | kernel_size=3,
68 | n_layers=10,
69 | conv_channels=64,
70 | use_bias=True,
71 | dilation_factor=1,
72 | nonlinear_activation="LeakyReLU",
73 | nonlinear_activation_params={"alpha": 0.2},
74 | initializer_seed=42,
75 | apply_sigmoid_at_last=False,
76 | **kwargs,
77 | ):
78 |         """Init parameters for ParallelWaveGAN Discriminator model."""
79 | self.out_channels = out_channels
80 | self.kernel_size = kernel_size
81 | self.n_layers = n_layers
82 | self.conv_channels = conv_channels
83 | self.use_bias = use_bias
84 | self.dilation_factor = dilation_factor
85 | self.nonlinear_activation = nonlinear_activation
86 | self.nonlinear_activation_params = nonlinear_activation_params
87 | self.initializer_seed = initializer_seed
88 | self.apply_sigmoid_at_last = apply_sigmoid_at_last
89 |
--------------------------------------------------------------------------------
/tensorflow_tts/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
2 | from tensorflow_tts.datasets.audio_dataset import AudioDataset
3 | from tensorflow_tts.datasets.mel_dataset import MelDataset
4 |
--------------------------------------------------------------------------------
/tensorflow_tts/datasets/abstract_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 Minh Nguyen (@dathudeptrai)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Abstract Dataset modules."""
16 |
17 | import abc
18 |
19 | import tensorflow as tf
20 |
21 |
22 | class AbstractDataset(metaclass=abc.ABCMeta):
23 | """Abstract Dataset module for Dataset Loader."""
24 |
25 | @abc.abstractmethod
26 | def get_args(self):
27 | """Return args for generator function."""
28 | pass
29 |
30 | @abc.abstractmethod
31 | def generator(self):
32 | """Generator function, should have args from get_args function."""
33 | pass
34 |
35 | @abc.abstractmethod
36 | def get_output_dtypes(self):
37 | """Return output dtypes for each element from generator."""
38 | pass
39 |
40 | @abc.abstractmethod
41 | def get_len_dataset(self):
42 |         """Return the number of samples in the dataset."""
43 | pass
44 |
45 | def create(
46 | self,
47 | allow_cache=False,
48 | batch_size=1,
49 | is_shuffle=False,
50 | map_fn=None,
51 | reshuffle_each_iteration=True,
52 | ):
53 | """Create tf.dataset function."""
54 | output_types = self.get_output_dtypes()
55 | datasets = tf.data.Dataset.from_generator(
56 | self.generator, output_types=output_types, args=(self.get_args())
57 | )
58 |
59 | if allow_cache:
60 | datasets = datasets.cache()
61 |
62 | if is_shuffle:
63 | datasets = datasets.shuffle(
64 | self.get_len_dataset(),
65 | reshuffle_each_iteration=reshuffle_each_iteration,
66 | )
67 |
68 | if batch_size > 1 and map_fn is None:
69 |             raise ValueError("map function must be defined when batch_size > 1.")
70 |
71 | if map_fn is not None:
72 | datasets = datasets.map(map_fn, tf.data.experimental.AUTOTUNE)
73 |
74 | datasets = datasets.batch(batch_size)
75 | datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
76 |
77 | return datasets
78 |
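79 | # Concrete datasets implement generator(), get_args(), get_output_dtypes()
80 | # and get_len_dataset(); create() then wraps the generator in an (optionally
81 | # cached and shuffled) mapped, batched, prefetched tf.data pipeline.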
--------------------------------------------------------------------------------
/tensorflow_tts/inference/__init__.py:
--------------------------------------------------------------------------------
1 | from tensorflow_tts.inference.auto_model import TFAutoModel
2 | from tensorflow_tts.inference.auto_config import AutoConfig
3 | from tensorflow_tts.inference.auto_processor import AutoProcessor
4 |
--------------------------------------------------------------------------------
/tensorflow_tts/inference/auto_config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 The HuggingFace Inc. team and Minh Nguyen (@dathudeptrai)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tensorflow Auto Config modules."""
16 |
17 | import logging
18 | import yaml
19 | import os
20 | from collections import OrderedDict
21 |
22 | from tensorflow_tts.configs import (
23 | FastSpeechConfig,
24 | FastSpeech2Config,
25 | MelGANGeneratorConfig,
26 | MultiBandMelGANGeneratorConfig,
27 | HifiGANGeneratorConfig,
28 | Tacotron2Config,
29 | ParallelWaveGANGeneratorConfig,
30 | )
31 |
32 | from tensorflow_tts.utils import CACHE_DIRECTORY, CONFIG_FILE_NAME, LIBRARY_NAME
33 | from tensorflow_tts import __version__ as VERSION
34 | from huggingface_hub import hf_hub_url, cached_download
35 |
36 | CONFIG_MAPPING = OrderedDict(
37 | [
38 | ("fastspeech", FastSpeechConfig),
39 | ("fastspeech2", FastSpeech2Config),
40 | ("multiband_melgan_generator", MultiBandMelGANGeneratorConfig),
41 | ("melgan_generator", MelGANGeneratorConfig),
42 | ("hifigan_generator", HifiGANGeneratorConfig),
43 | ("tacotron2", Tacotron2Config),
44 | ("parallel_wavegan_generator", ParallelWaveGANGeneratorConfig),
45 | ]
46 | )
47 |
48 |
49 | class AutoConfig:
50 | def __init__(self):
51 | raise EnvironmentError(
52 | "AutoConfig is designed to be instantiated "
53 | "using the `AutoConfig.from_pretrained(pretrained_path)` method."
54 | )
55 |
56 | @classmethod
57 | def from_pretrained(cls, pretrained_path, **kwargs):
58 | # load weights from hf hub
59 | if not os.path.isfile(pretrained_path):
60 | # retrieve correct hub url
61 | download_url = hf_hub_url(
62 | repo_id=pretrained_path, filename=CONFIG_FILE_NAME
63 | )
64 |
65 | pretrained_path = str(
66 | cached_download(
67 | url=download_url,
68 | library_name=LIBRARY_NAME,
69 | library_version=VERSION,
70 | cache_dir=CACHE_DIRECTORY,
71 | )
72 | )
73 |
74 | with open(pretrained_path) as f:
75 | config = yaml.load(f, Loader=yaml.Loader)
76 |
77 | try:
78 | model_type = config["model_type"]
79 | config_class = CONFIG_MAPPING[model_type]
80 | config_class = config_class(**config[model_type + "_params"], **kwargs)
81 | config_class.set_config_params(config)
82 | return config_class
83 | except Exception:
84 | raise ValueError(
85 | "Unrecognized config in {}. "
86 | "Should have a `model_type` key in its config.yaml, or contain one of the following strings "
87 | "in its name: {}".format(
88 | pretrained_path, ", ".join(CONFIG_MAPPING.keys())
89 | )
90 | )
91 |
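92 | # Usage sketch: `pretrained_path` may be a local config yaml file or a
93 | # Hugging Face hub repo id. Illustration (repo id borrowed from the notebook
94 | # in this repo; any hub model that ships a config file works the same way):
95 | #   config = AutoConfig.from_pretrained("tensorspeech/tts-mb_melgan-ljspeech-en")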
--------------------------------------------------------------------------------
/tensorflow_tts/inference/auto_processor.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 The TensorFlowTTS Team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tensorflow Auto Processor modules."""
16 |
17 | import logging
18 | import json
19 | import os
20 | from collections import OrderedDict
21 |
22 | from tensorflow_tts.processor import (
23 | LJSpeechProcessor,
24 | KSSProcessor,
25 | BakerProcessor,
26 | LibriTTSProcessor,
27 | ThorstenProcessor,
28 | LJSpeechUltimateProcessor,
29 | SynpaflexProcessor,
30 | JSUTProcessor,
31 | )
32 |
33 | from tensorflow_tts.utils import CACHE_DIRECTORY, PROCESSOR_FILE_NAME, LIBRARY_NAME
34 | from tensorflow_tts import __version__ as VERSION
35 | from huggingface_hub import hf_hub_url, cached_download
36 |
37 | CONFIG_MAPPING = OrderedDict(
38 | [
39 | ("LJSpeechProcessor", LJSpeechProcessor),
40 | ("KSSProcessor", KSSProcessor),
41 | ("BakerProcessor", BakerProcessor),
42 | ("LibriTTSProcessor", LibriTTSProcessor),
43 | ("ThorstenProcessor", ThorstenProcessor),
44 | ("LJSpeechUltimateProcessor", LJSpeechUltimateProcessor),
45 | ("SynpaflexProcessor", SynpaflexProcessor),
46 | ("JSUTProcessor", JSUTProcessor),
47 | ]
48 | )
49 |
50 |
51 | class AutoProcessor:
52 | def __init__(self):
53 | raise EnvironmentError(
54 | "AutoProcessor is designed to be instantiated "
55 | "using the `AutoProcessor.from_pretrained(pretrained_path)` method."
56 | )
57 |
58 | @classmethod
59 | def from_pretrained(cls, pretrained_path, **kwargs):
60 | # load weights from hf hub
61 | if not os.path.isfile(pretrained_path):
62 | # retrieve correct hub url
63 | download_url = hf_hub_url(repo_id=pretrained_path, filename=PROCESSOR_FILE_NAME)
64 |
65 | pretrained_path = str(
66 | cached_download(
67 | url=download_url,
68 | library_name=LIBRARY_NAME,
69 | library_version=VERSION,
70 | cache_dir=CACHE_DIRECTORY,
71 | )
72 | )
73 | with open(pretrained_path, "r") as f:
74 | config = json.load(f)
75 |
76 | try:
77 | processor_name = config["processor_name"]
78 | processor_class = CONFIG_MAPPING[processor_name]
79 | processor_class = processor_class(
80 | data_dir=None, loaded_mapper_path=pretrained_path
81 | )
82 | return processor_class
83 | except Exception:
84 | raise ValueError(
85 | "Unrecognized processor in {}. "
86 | "Should have a `processor_name` key in its config.json, or contain one of the following strings "
87 | "in its name: {}".format(
88 | pretrained_path, ", ".join(CONFIG_MAPPING.keys())
89 | )
90 | )
91 |
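92 | # Usage sketch (assumes the target hub repo ships the processor mapping file
93 | # named by PROCESSOR_FILE_NAME; text_to_sequence is the mapper API used by the
94 | # processors in this repo):
95 | #   processor = AutoProcessor.from_pretrained("tensorspeech/tts-mb_melgan-ljspeech-en")
96 | #   input_ids = processor.text_to_sequence("Hello world.")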
--------------------------------------------------------------------------------
/tensorflow_tts/inference/savable_models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 TensorFlowTTS Team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tensorflow Savable Model modules."""
16 |
17 | import numpy as np
18 | import tensorflow as tf
19 |
20 | from tensorflow_tts.models import (
21 | TFFastSpeech,
22 | TFFastSpeech2,
23 | TFMelGANGenerator,
24 | TFMBMelGANGenerator,
25 | TFHifiGANGenerator,
26 | TFTacotron2,
27 | TFParallelWaveGANGenerator,
28 | )
29 |
30 |
31 | class SavableTFTacotron2(TFTacotron2):
32 | def __init__(self, config, **kwargs):
33 | super().__init__(config, **kwargs)
34 |
35 | def call(self, inputs, training=False):
36 | input_ids, input_lengths, speaker_ids = inputs
37 | return super().inference(input_ids, input_lengths, speaker_ids)
38 |
39 | def _build(self):
40 | input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9]], dtype=tf.int32)
41 | input_lengths = tf.convert_to_tensor([9], dtype=tf.int32)
42 | speaker_ids = tf.convert_to_tensor([0], dtype=tf.int32)
43 | self([input_ids, input_lengths, speaker_ids])
44 |
45 |
46 | class SavableTFFastSpeech(TFFastSpeech):
47 | def __init__(self, config, **kwargs):
48 | super().__init__(config, **kwargs)
49 |
50 | def call(self, inputs, training=False):
51 | input_ids, speaker_ids, speed_ratios = inputs
52 | return super()._inference(input_ids, speaker_ids, speed_ratios)
53 |
54 | def _build(self):
55 | input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], tf.int32)
56 | speaker_ids = tf.convert_to_tensor([0], tf.int32)
57 | speed_ratios = tf.convert_to_tensor([1.0], tf.float32)
58 | self([input_ids, speaker_ids, speed_ratios])
59 |
60 |
61 | class SavableTFFastSpeech2(TFFastSpeech2):
62 | def __init__(self, config, **kwargs):
63 | super().__init__(config, **kwargs)
64 |
65 | def call(self, inputs, training=False):
66 | input_ids, speaker_ids, speed_ratios, f0_ratios, energy_ratios = inputs
67 | return super()._inference(
68 | input_ids, speaker_ids, speed_ratios, f0_ratios, energy_ratios
69 | )
70 |
71 | def _build(self):
72 | input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], tf.int32)
73 | speaker_ids = tf.convert_to_tensor([0], tf.int32)
74 | speed_ratios = tf.convert_to_tensor([1.0], tf.float32)
75 | f0_ratios = tf.convert_to_tensor([1.0], tf.float32)
76 | energy_ratios = tf.convert_to_tensor([1.0], tf.float32)
77 | self([input_ids, speaker_ids, speed_ratios, f0_ratios, energy_ratios])
78 |
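79 | # These wrappers route call() to the models' inference paths with fixed input
80 | # signatures so they can be traced and exported with tf.saved_model.save,
81 | # analogous to the "Save to Pb" step in
82 | # notebooks/multiband_melgan_inference.ipynb for the vocoder.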
--------------------------------------------------------------------------------
/tensorflow_tts/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from tensorflow_tts.losses.spectrogram import TFMelSpectrogram
2 | from tensorflow_tts.losses.stft import TFMultiResolutionSTFT
3 |
--------------------------------------------------------------------------------
/tensorflow_tts/losses/spectrogram.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 Minh Nguyen (@dathudeptrai)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Spectrogram-based loss modules."""
16 |
17 | import tensorflow as tf
18 |
19 |
20 | class TFMelSpectrogram(tf.keras.layers.Layer):
21 | """Mel Spectrogram loss."""
22 |
23 | def __init__(
24 | self,
25 | n_mels=80,
26 | f_min=80.0,
27 | f_max=7600,
28 | frame_length=1024,
29 | frame_step=256,
30 | fft_length=1024,
31 | sample_rate=16000,
32 | **kwargs
33 | ):
34 | """Initialize."""
35 | super().__init__(**kwargs)
36 | self.frame_length = frame_length
37 | self.frame_step = frame_step
38 | self.fft_length = fft_length
39 |
40 | self.linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
41 | n_mels, fft_length // 2 + 1, sample_rate, f_min, f_max
42 | )
43 |
44 | def _calculate_log_mels_spectrogram(self, signals):
45 |         """Calculate log-mel spectrogram.
46 | Args:
47 | signals (Tensor): signal (B, T).
48 | Returns:
49 | Tensor: Mel spectrogram (B, T', 80)
50 | """
51 | stfts = tf.signal.stft(
52 | signals,
53 | frame_length=self.frame_length,
54 | frame_step=self.frame_step,
55 | fft_length=self.fft_length,
56 | )
57 | linear_spectrograms = tf.abs(stfts)
58 | mel_spectrograms = tf.tensordot(
59 | linear_spectrograms, self.linear_to_mel_weight_matrix, 1
60 | )
61 | mel_spectrograms.set_shape(
62 | linear_spectrograms.shape[:-1].concatenate(
63 | self.linear_to_mel_weight_matrix.shape[-1:]
64 | )
65 | )
66 | log_mel_spectrograms = tf.math.log(mel_spectrograms + 1e-6) # prevent nan.
67 | return log_mel_spectrograms
68 |
69 | def call(self, y, x):
70 | """Calculate forward propagation.
71 | Args:
72 | y (Tensor): Groundtruth signal (B, T).
73 | x (Tensor): Predicted signal (B, T).
74 | Returns:
75 | Tensor: Mean absolute Error Spectrogram Loss.
76 | """
77 | y_mels = self._calculate_log_mels_spectrogram(y)
78 | x_mels = self._calculate_log_mels_spectrogram(x)
79 | return tf.reduce_mean(
80 | tf.abs(y_mels - x_mels), axis=list(range(1, len(x_mels.shape)))
81 | )
82 |
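83 | # Usage sketch: y and x are (B, T) waveforms; the call returns a (B,) mean
84 | # absolute log-mel error. Note the default sample_rate here is 16000 while the
85 | # preprocess configs in this repo use 22050/24000, so pass sample_rate
86 | # explicitly to match your data, e.g.:
87 | #   loss = TFMelSpectrogram(sample_rate=22050)(y_true, y_pred)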
--------------------------------------------------------------------------------
/tensorflow_tts/models/__init__.py:
--------------------------------------------------------------------------------
1 | from tensorflow_tts.models.base_model import BaseModel
2 | from tensorflow_tts.models.fastspeech import TFFastSpeech
3 | from tensorflow_tts.models.fastspeech2 import TFFastSpeech2
4 | from tensorflow_tts.models.melgan import (
5 | TFMelGANDiscriminator,
6 | TFMelGANGenerator,
7 | TFMelGANMultiScaleDiscriminator,
8 | )
9 | from tensorflow_tts.models.mb_melgan import TFPQMF
10 | from tensorflow_tts.models.mb_melgan import TFMBMelGANGenerator
11 | from tensorflow_tts.models.hifigan import (
12 | TFHifiGANGenerator,
13 | TFHifiGANMultiPeriodDiscriminator,
14 | TFHifiGANPeriodDiscriminator
15 | )
16 | from tensorflow_tts.models.tacotron2 import TFTacotron2
17 | from tensorflow_tts.models.parallel_wavegan import TFParallelWaveGANGenerator
18 | from tensorflow_tts.models.parallel_wavegan import TFParallelWaveGANDiscriminator
19 |
--------------------------------------------------------------------------------
/tensorflow_tts/models/base_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 TensorFlowTTS Team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Base Model for all model."""
16 |
17 | import tensorflow as tf
18 | import yaml
19 | import os
20 | import numpy as np
21 |
22 | from tensorflow_tts.utils.utils import MODEL_FILE_NAME, CONFIG_FILE_NAME
23 |
24 |
25 | class BaseModel(tf.keras.Model):
26 | def set_config(self, config):
27 | self.config = config
28 |
29 | def save_pretrained(self, saved_path):
30 | """Save config and weights to file"""
31 | os.makedirs(saved_path, exist_ok=True)
32 | self.config.save_pretrained(saved_path)
33 | self.save_weights(os.path.join(saved_path, MODEL_FILE_NAME))
34 |
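
A minimal sketch of the save round trip this class enables, using TFFastSpeech
from models/__init__.py below as an illustrative subclass (the output path is
arbitrary, and set_config is called explicitly so save_pretrained can dump the
config):

    from tensorflow_tts.configs import FastSpeechConfig
    from tensorflow_tts.models import TFFastSpeech

    config = FastSpeechConfig()
    model = TFFastSpeech(config, name="fastspeech")
    model._build()                         # create variables before saving
    model.set_config(config)               # attach config for save_pretrained
    model.save_pretrained("./test_saved")  # writes config.yml and model.h5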
--------------------------------------------------------------------------------
/tensorflow_tts/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from tensorflow_tts.optimizers.adamweightdecay import AdamWeightDecay, WarmUp
2 | from tensorflow_tts.optimizers.gradient_accumulate import GradientAccumulator
3 |
--------------------------------------------------------------------------------
/tensorflow_tts/processor/__init__.py:
--------------------------------------------------------------------------------
1 | from tensorflow_tts.processor.base_processor import BaseProcessor
2 |
3 | from tensorflow_tts.processor.ljspeech import LJSpeechProcessor
4 | from tensorflow_tts.processor.baker import BakerProcessor
5 | from tensorflow_tts.processor.kss import KSSProcessor
6 | from tensorflow_tts.processor.libritts import LibriTTSProcessor
7 | from tensorflow_tts.processor.thorsten import ThorstenProcessor
8 | from tensorflow_tts.processor.ljspeechu import LJSpeechUltimateProcessor
9 | from tensorflow_tts.processor.synpaflex import SynpaflexProcessor
10 | from tensorflow_tts.processor.jsut import JSUTProcessor
11 |
--------------------------------------------------------------------------------
/tensorflow_tts/processor/pretrained/jsut_mapper.json:
--------------------------------------------------------------------------------
1 | {
2 | "symbol_to_id": {
3 | "pad": 0,
4 | "sil": 1,
5 | "N": 2,
6 | "a": 3,
7 | "b": 4,
8 | "by": 5,
9 | "ch": 6,
10 | "cl": 7,
11 | "d": 8,
12 | "dy": 9,
13 | "e": 10,
14 | "f": 11,
15 | "g": 12,
16 | "gy": 13,
17 | "h": 14,
18 | "hy": 15,
19 | "i": 16,
20 | "j": 17,
21 | "k": 18,
22 | "ky": 19,
23 | "m": 20,
24 | "my": 21,
25 | "n": 22,
26 | "ny": 23,
27 | "o": 24,
28 | "p": 25,
29 | "pau": 26,
30 | "py": 27,
31 | "r": 28,
32 | "ry": 29,
33 | "s": 30,
34 | "sh": 31,
35 | "t": 32,
36 | "ts": 33,
37 | "u": 34,
38 | "v": 35,
39 | "w": 36,
40 | "y": 37,
41 | "z": 38,
42 | "eos": 39
43 | },
44 | "id_to_symbol": {
45 | "0": "pad",
46 | "1": "sil",
47 | "2": "N",
48 | "3": "a",
49 | "4": "b",
50 | "5": "by",
51 | "6": "ch",
52 | "7": "cl",
53 | "8": "d",
54 | "9": "dy",
55 | "10": "e",
56 | "11": "f",
57 | "12": "g",
58 | "13": "gy",
59 | "14": "h",
60 | "15": "hy",
61 | "16": "i",
62 | "17": "j",
63 | "18": "k",
64 | "19": "ky",
65 | "20": "m",
66 | "21": "my",
67 | "22": "n",
68 | "23": "ny",
69 | "24": "o",
70 | "25": "p",
71 | "26": "pau",
72 | "27": "py",
73 | "28": "r",
74 | "29": "ry",
75 | "30": "s",
76 | "31": "sh",
77 | "32": "t",
78 | "33": "ts",
79 | "34": "u",
80 | "35": "v",
81 | "36": "w",
82 | "37": "y",
83 | "38": "z",
84 | "39": "eos"
85 | },
86 | "speakers_map": {
87 | "jsut": 0
88 | },
89 | "processor_name": "JSUTProcessor"
90 | }
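
For orientation, a tiny sketch of how a mapper like this is consumed: load the
JSON and turn a phoneme sequence into ids (the input sequence is illustrative):

    import json

    with open("tensorflow_tts/processor/pretrained/jsut_mapper.json") as f:
        mapper = json.load(f)

    symbol_to_id = mapper["symbol_to_id"]
    phonemes = ["k", "o", "N", "n", "i", "ch", "i", "w", "a"]
    ids = [symbol_to_id[p] for p in phonemes] + [symbol_to_id["eos"]]
    print(ids)  # -> [18, 24, 2, 22, 16, 6, 16, 36, 3, 39]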
--------------------------------------------------------------------------------
/tensorflow_tts/processor/pretrained/kss_mapper.json:
--------------------------------------------------------------------------------
1 | {"symbol_to_id": {"pad": 0, "-": 7, "!": 2, "'": 3, "(": 4, ")": 5, ",": 6, ".": 8, ":": 9, ";": 10, "?": 11, " ": 12, "\u1100": 13, "\u1101": 14, "\u1102": 15, "\u1103": 16, "\u1104": 17, "\u1105": 18, "\u1106": 19, "\u1107": 20, "\u1108": 21, "\u1109": 22, "\u110a": 23, "\u110b": 24, "\u110c": 25, "\u110d": 26, "\u110e": 27, "\u110f": 28, "\u1110": 29, "\u1111": 30, "\u1112": 31, "\u1161": 32, "\u1162": 33, "\u1163": 34, "\u1164": 35, "\u1165": 36, "\u1166": 37, "\u1167": 38, "\u1168": 39, "\u1169": 40, "\u116a": 41, "\u116b": 42, "\u116c": 43, "\u116d": 44, "\u116e": 45, "\u116f": 46, "\u1170": 47, "\u1171": 48, "\u1172": 49, "\u1173": 50, "\u1174": 51, "\u1175": 52, "\u11a8": 53, "\u11a9": 54, "\u11aa": 55, "\u11ab": 56, "\u11ac": 57, "\u11ad": 58, "\u11ae": 59, "\u11af": 60, "\u11b0": 61, "\u11b1": 62, "\u11b2": 63, "\u11b3": 64, "\u11b4": 65, "\u11b5": 66, "\u11b6": 67, "\u11b7": 68, "\u11b8": 69, "\u11b9": 70, "\u11ba": 71, "\u11bb": 72, "\u11bc": 73, "\u11bd": 74, "\u11be": 75, "\u11bf": 76, "\u11c0": 77, "\u11c1": 78, "\u11c2": 79, "eos": 80}, "id_to_symbol": {"0": "pad", "1": "-", "2": "!", "3": "'", "4": "(", "5": ")", "6": ",", "7": "-", "8": ".", "9": ":", "10": ";", "11": "?", "12": " ", "13": "\u1100", "14": "\u1101", "15": "\u1102", "16": "\u1103", "17": "\u1104", "18": "\u1105", "19": "\u1106", "20": "\u1107", "21": "\u1108", "22": "\u1109", "23": "\u110a", "24": "\u110b", "25": "\u110c", "26": "\u110d", "27": "\u110e", "28": "\u110f", "29": "\u1110", "30": "\u1111", "31": "\u1112", "32": "\u1161", "33": "\u1162", "34": "\u1163", "35": "\u1164", "36": "\u1165", "37": "\u1166", "38": "\u1167", "39": "\u1168", "40": "\u1169", "41": "\u116a", "42": "\u116b", "43": "\u116c", "44": "\u116d", "45": "\u116e", "46": "\u116f", "47": "\u1170", "48": "\u1171", "49": "\u1172", "50": "\u1173", "51": "\u1174", "52": "\u1175", "53": "\u11a8", "54": "\u11a9", "55": "\u11aa", "56": "\u11ab", "57": "\u11ac", "58": "\u11ad", "59": "\u11ae", "60": "\u11af", "61": "\u11b0", "62": "\u11b1", "63": "\u11b2", "64": "\u11b3", "65": "\u11b4", "66": "\u11b5", "67": "\u11b6", "68": "\u11b7", "69": "\u11b8", "70": "\u11b9", "71": "\u11ba", "72": "\u11bb", "73": "\u11bc", "74": "\u11bd", "75": "\u11be", "76": "\u11bf", "77": "\u11c0", "78": "\u11c1", "79": "\u11c2", "80": "eos"}, "speakers_map": {"kss": 0}, "processor_name": "KSSProcessor"}
--------------------------------------------------------------------------------
/tensorflow_tts/processor/pretrained/libritts_mapper.json:
--------------------------------------------------------------------------------
1 | {"symbol_to_id": {"@": 0, "@": 1, "@": 2, "@": 3, "@AA0": 4, "@AA1": 5, "@AA2": 6, "@AE0": 7, "@AE1": 8, "@AE2": 9, "@AH0": 10, "@AH1": 11, "@AH2": 12, "@AO0": 13, "@AO1": 14, "@AO2": 15, "@AW0": 16, "@AW1": 17, "@AW2": 18, "@AY0": 19, "@AY1": 20, "@AY2": 21, "@B": 22, "@CH": 23, "@D": 24, "@DH": 25, "@EH0": 26, "@EH1": 27, "@EH2": 28, "@ER0": 29, "@ER1": 30, "@ER2": 31, "@EY0": 32, "@EY1": 33, "@EY2": 34, "@F": 35, "@G": 36, "@HH": 37, "@IH0": 38, "@IH1": 39, "@IH2": 40, "@IY0": 41, "@IY1": 42, "@IY2": 43, "@JH": 44, "@K": 45, "@L": 46, "@M": 47, "@N": 48, "@NG": 49, "@OW0": 50, "@OW1": 51, "@OW2": 52, "@OY0": 53, "@OY1": 54, "@OY2": 55, "@P": 56, "@R": 57, "@S": 58, "@SH": 59, "@T": 60, "@TH": 61, "@UH0": 62, "@UH1": 63, "@UH2": 64, "@UW": 65, "@UW0": 66, "@UW1": 67, "@UW2": 68, "@V": 69, "@W": 70, "@Y": 71, "@Z": 72, "@ZH": 73, "@SIL": 74, "@END": 75, "!": 76, "'": 77, "(": 78, ")": 79, ",": 80, ".": 81, ":": 82, ";": 83, "?": 84, " ": 85}, "id_to_symbol": {"0": "@", "1": "@", "2": "@", "3": "@", "4": "@AA0", "5": "@AA1", "6": "@AA2", "7": "@AE0", "8": "@AE1", "9": "@AE2", "10": "@AH0", "11": "@AH1", "12": "@AH2", "13": "@AO0", "14": "@AO1", "15": "@AO2", "16": "@AW0", "17": "@AW1", "18": "@AW2", "19": "@AY0", "20": "@AY1", "21": "@AY2", "22": "@B", "23": "@CH", "24": "@D", "25": "@DH", "26": "@EH0", "27": "@EH1", "28": "@EH2", "29": "@ER0", "30": "@ER1", "31": "@ER2", "32": "@EY0", "33": "@EY1", "34": "@EY2", "35": "@F", "36": "@G", "37": "@HH", "38": "@IH0", "39": "@IH1", "40": "@IH2", "41": "@IY0", "42": "@IY1", "43": "@IY2", "44": "@JH", "45": "@K", "46": "@L", "47": "@M", "48": "@N", "49": "@NG", "50": "@OW0", "51": "@OW1", "52": "@OW2", "53": "@OY0", "54": "@OY1", "55": "@OY2", "56": "@P", "57": "@R", "58": "@S", "59": "@SH", "60": "@T", "61": "@TH", "62": "@UH0", "63": "@UH1", "64": "@UH2", "65": "@UW", "66": "@UW0", "67": "@UW1", "68": "@UW2", "69": "@V", "70": "@W", "71": "@Y", "72": "@Z", "73": "@ZH", "74": "@SIL", "75": "@END", "76": "!", "77": "'", "78": "(", "79": ")", "80": ",", "81": ".", "82": ":", "83": ";", "84": "?", "85": " "}, "speakers_map": {"200": 0, "1841": 1, "3664": 2, "6454": 3, "8108": 4, "2416": 5, "4680": 6, "6147": 7, "412": 8, "2952": 9, "8838": 10, "2836": 11, "1263": 12, "5322": 13, "3830": 14, "7447": 15, "1116": 16, "8312": 17, "8123": 18, "250": 19}, "processor_name": "LibriTTSProcessor"}
--------------------------------------------------------------------------------
/tensorflow_tts/processor/pretrained/ljspeech_mapper.json:
--------------------------------------------------------------------------------
1 | {"symbol_to_id": {"pad": 0, "-": 1, "!": 2, "'": 3, "(": 4, ")": 5, ",": 6, ".": 7, ":": 8, ";": 9, "?": 10, " ": 11, "A": 12, "B": 13, "C": 14, "D": 15, "E": 16, "F": 17, "G": 18, "H": 19, "I": 20, "J": 21, "K": 22, "L": 23, "M": 24, "N": 25, "O": 26, "P": 27, "Q": 28, "R": 29, "S": 30, "T": 31, "U": 32, "V": 33, "W": 34, "X": 35, "Y": 36, "Z": 37, "a": 38, "b": 39, "c": 40, "d": 41, "e": 42, "f": 43, "g": 44, "h": 45, "i": 46, "j": 47, "k": 48, "l": 49, "m": 50, "n": 51, "o": 52, "p": 53, "q": 54, "r": 55, "s": 56, "t": 57, "u": 58, "v": 59, "w": 60, "x": 61, "y": 62, "z": 63, "@AA": 64, "@AA0": 65, "@AA1": 66, "@AA2": 67, "@AE": 68, "@AE0": 69, "@AE1": 70, "@AE2": 71, "@AH": 72, "@AH0": 73, "@AH1": 74, "@AH2": 75, "@AO": 76, "@AO0": 77, "@AO1": 78, "@AO2": 79, "@AW": 80, "@AW0": 81, "@AW1": 82, "@AW2": 83, "@AY": 84, "@AY0": 85, "@AY1": 86, "@AY2": 87, "@B": 88, "@CH": 89, "@D": 90, "@DH": 91, "@EH": 92, "@EH0": 93, "@EH1": 94, "@EH2": 95, "@ER": 96, "@ER0": 97, "@ER1": 98, "@ER2": 99, "@EY": 100, "@EY0": 101, "@EY1": 102, "@EY2": 103, "@F": 104, "@G": 105, "@HH": 106, "@IH": 107, "@IH0": 108, "@IH1": 109, "@IH2": 110, "@IY": 111, "@IY0": 112, "@IY1": 113, "@IY2": 114, "@JH": 115, "@K": 116, "@L": 117, "@M": 118, "@N": 119, "@NG": 120, "@OW": 121, "@OW0": 122, "@OW1": 123, "@OW2": 124, "@OY": 125, "@OY0": 126, "@OY1": 127, "@OY2": 128, "@P": 129, "@R": 130, "@S": 131, "@SH": 132, "@T": 133, "@TH": 134, "@UH": 135, "@UH0": 136, "@UH1": 137, "@UH2": 138, "@UW": 139, "@UW0": 140, "@UW1": 141, "@UW2": 142, "@V": 143, "@W": 144, "@Y": 145, "@Z": 146, "@ZH": 147, "eos": 148}, "id_to_symbol": {"0": "pad", "1": "-", "2": "!", "3": "'", "4": "(", "5": ")", "6": ",", "7": ".", "8": ":", "9": ";", "10": "?", "11": " ", "12": "A", "13": "B", "14": "C", "15": "D", "16": "E", "17": "F", "18": "G", "19": "H", "20": "I", "21": "J", "22": "K", "23": "L", "24": "M", "25": "N", "26": "O", "27": "P", "28": "Q", "29": "R", "30": "S", "31": "T", "32": "U", "33": "V", "34": "W", "35": "X", "36": "Y", "37": "Z", "38": "a", "39": "b", "40": "c", "41": "d", "42": "e", "43": "f", "44": "g", "45": "h", "46": "i", "47": "j", "48": "k", "49": "l", "50": "m", "51": "n", "52": "o", "53": "p", "54": "q", "55": "r", "56": "s", "57": "t", "58": "u", "59": "v", "60": "w", "61": "x", "62": "y", "63": "z", "64": "@AA", "65": "@AA0", "66": "@AA1", "67": "@AA2", "68": "@AE", "69": "@AE0", "70": "@AE1", "71": "@AE2", "72": "@AH", "73": "@AH0", "74": "@AH1", "75": "@AH2", "76": "@AO", "77": "@AO0", "78": "@AO1", "79": "@AO2", "80": "@AW", "81": "@AW0", "82": "@AW1", "83": "@AW2", "84": "@AY", "85": "@AY0", "86": "@AY1", "87": "@AY2", "88": "@B", "89": "@CH", "90": "@D", "91": "@DH", "92": "@EH", "93": "@EH0", "94": "@EH1", "95": "@EH2", "96": "@ER", "97": "@ER0", "98": "@ER1", "99": "@ER2", "100": "@EY", "101": "@EY0", "102": "@EY1", "103": "@EY2", "104": "@F", "105": "@G", "106": "@HH", "107": "@IH", "108": "@IH0", "109": "@IH1", "110": "@IH2", "111": "@IY", "112": "@IY0", "113": "@IY1", "114": "@IY2", "115": "@JH", "116": "@K", "117": "@L", "118": "@M", "119": "@N", "120": "@NG", "121": "@OW", "122": "@OW0", "123": "@OW1", "124": "@OW2", "125": "@OY", "126": "@OY0", "127": "@OY1", "128": "@OY2", "129": "@P", "130": "@R", "131": "@S", "132": "@SH", "133": "@T", "134": "@TH", "135": "@UH", "136": "@UH0", "137": "@UH1", "138": "@UH2", "139": "@UW", "140": "@UW0", "141": "@UW1", "142": "@UW2", "143": "@V", "144": "@W", "145": "@Y", "146": "@Z", "147": "@ZH", "148": "eos"}, "speakers_map": {"ljspeech": 0}, "processor_name": 
"LJSpeechProcessor"}
--------------------------------------------------------------------------------
/tensorflow_tts/processor/pretrained/ljspeechu_mapper.json:
--------------------------------------------------------------------------------
1 | {"symbol_to_id": {"pad": 0, "-": 1, "!": 2, "'": 3, "(": 4, ")": 5, ",": 6, ".": 7, ":": 8, ";": 9, "?": 10, "@AA": 11, "@AA0": 12, "@AA1": 13, "@AA2": 14, "@AE": 15, "@AE0": 16, "@AE1": 17, "@AE2": 18, "@AH": 19, "@AH0": 20, "@AH1": 21, "@AH2": 22, "@AO": 23, "@AO0": 24, "@AO1": 25, "@AO2": 26, "@AW": 27, "@AW0": 28, "@AW1": 29, "@AW2": 30, "@AY": 31, "@AY0": 32, "@AY1": 33, "@AY2": 34, "@B": 35, "@CH": 36, "@D": 37, "@DH": 38, "@EH": 39, "@EH0": 40, "@EH1": 41, "@EH2": 42, "@ER": 43, "@ER0": 44, "@ER1": 45, "@ER2": 46, "@EY": 47, "@EY0": 48, "@EY1": 49, "@EY2": 50, "@F": 51, "@G": 52, "@HH": 53, "@IH": 54, "@IH0": 55, "@IH1": 56, "@IH2": 57, "@IY": 58, "@IY0": 59, "@IY1": 60, "@IY2": 61, "@JH": 62, "@K": 63, "@L": 64, "@M": 65, "@N": 66, "@NG": 67, "@OW": 68, "@OW0": 69, "@OW1": 70, "@OW2": 71, "@OY": 72, "@OY0": 73, "@OY1": 74, "@OY2": 75, "@P": 76, "@R": 77, "@S": 78, "@SH": 79, "@T": 80, "@TH": 81, "@UH": 82, "@UH0": 83, "@UH1": 84, "@UH2": 85, "@UW": 86, "@UW0": 87, "@UW1": 88, "@UW2": 89, "@V": 90, "@W": 91, "@Y": 92, "@Z": 93, "@ZH": 94, "eos": 95}, "id_to_symbol": {"0": "pad", "1": "-", "2": "!", "3": "'", "4": "(", "5": ")", "6": ",", "7": ".", "8": ":", "9": ";", "10": "?", "11": "@AA", "12": "@AA0", "13": "@AA1", "14": "@AA2", "15": "@AE", "16": "@AE0", "17": "@AE1", "18": "@AE2", "19": "@AH", "20": "@AH0", "21": "@AH1", "22": "@AH2", "23": "@AO", "24": "@AO0", "25": "@AO1", "26": "@AO2", "27": "@AW", "28": "@AW0", "29": "@AW1", "30": "@AW2", "31": "@AY", "32": "@AY0", "33": "@AY1", "34": "@AY2", "35": "@B", "36": "@CH", "37": "@D", "38": "@DH", "39": "@EH", "40": "@EH0", "41": "@EH1", "42": "@EH2", "43": "@ER", "44": "@ER0", "45": "@ER1", "46": "@ER2", "47": "@EY", "48": "@EY0", "49": "@EY1", "50": "@EY2", "51": "@F", "52": "@G", "53": "@HH", "54": "@IH", "55": "@IH0", "56": "@IH1", "57": "@IH2", "58": "@IY", "59": "@IY0", "60": "@IY1", "61": "@IY2", "62": "@JH", "63": "@K", "64": "@L", "65": "@M", "66": "@N", "67": "@NG", "68": "@OW", "69": "@OW0", "70": "@OW1", "71": "@OW2", "72": "@OY", "73": "@OY0", "74": "@OY1", "75": "@OY2", "76": "@P", "77": "@R", "78": "@S", "79": "@SH", "80": "@T", "81": "@TH", "82": "@UH", "83": "@UH0", "84": "@UH1", "85": "@UH2", "86": "@UW", "87": "@UW0", "88": "@UW1", "89": "@UW2", "90": "@V", "91": "@W", "92": "@Y", "93": "@Z", "94": "@ZH", "95": "eos"}, "speakers_map": {"ljspeech": 0}, "processor_name": "LJSpeechUltimateProcessor"}
--------------------------------------------------------------------------------
/tensorflow_tts/processor/pretrained/synpaflex_mapper.json:
--------------------------------------------------------------------------------
1 | {"symbol_to_id": {"pad": 0, "!": 1, "/": 2, "'": 3, "(": 4, ")": 5, ",": 6, "-": 7, ".": 8, ":": 9, ";": 10, "?": 11, " ": 12, "A": 13, "B": 14, "C": 15, "D": 16, "E": 17, "F": 18, "G": 19, "H": 20, "I": 21, "J": 22, "K": 23, "L": 24, "M": 25, "N": 26, "O": 27, "P": 28, "Q": 29, "R": 30, "S": 31, "T": 32, "U": 33, "V": 34, "W": 35, "X": 36, "Y": 37, "Z": 38, "a": 39, "b": 40, "c": 41, "d": 42, "e": 43, "f": 44, "g": 45, "h": 46, "i": 47, "j": 48, "k": 49, "l": 50, "m": 51, "n": 52, "o": 53, "p": 54, "q": 55, "r": 56, "s": 57, "t": 58, "u": 59, "v": 60, "w": 61, "x": 62, "y": 63, "z": 64, "\u00e9": 65, "\u00e8": 66, "\u00e0": 67, "\u00f9": 68, "\u00e2": 69, "\u00ea": 70, "\u00ee": 71, "\u00f4": 72, "\u00fb": 73, "\u00e7": 74, "\u00e4": 75, "\u00eb": 76, "\u00ef": 77, "\u00f6": 78, "\u00fc": 79, "\u00ff": 80, "\u0153": 81, "\u00e6": 82, "eos": 83}, "id_to_symbol": {"0": "pad", "1": "!", "2": "/", "3": "'", "4": "(", "5": ")", "6": ",", "7": "-", "8": ".", "9": ":", "10": ";", "11": "?", "12": " ", "13": "A", "14": "B", "15": "C", "16": "D", "17": "E", "18": "F", "19": "G", "20": "H", "21": "I", "22": "J", "23": "K", "24": "L", "25": "M", "26": "N", "27": "O", "28": "P", "29": "Q", "30": "R", "31": "S", "32": "T", "33": "U", "34": "V", "35": "W", "36": "X", "37": "Y", "38": "Z", "39": "a", "40": "b", "41": "c", "42": "d", "43": "e", "44": "f", "45": "g", "46": "h", "47": "i", "48": "j", "49": "k", "50": "l", "51": "m", "52": "n", "53": "o", "54": "p", "55": "q", "56": "r", "57": "s", "58": "t", "59": "u", "60": "v", "61": "w", "62": "x", "63": "y", "64": "z", "65": "\u00e9", "66": "\u00e8", "67": "\u00e0", "68": "\u00f9", "69": "\u00e2", "70": "\u00ea", "71": "\u00ee", "72": "\u00f4", "73": "\u00fb", "74": "\u00e7", "75": "\u00e4", "76": "\u00eb", "77": "\u00ef", "78": "\u00f6", "79": "\u00fc", "80": "\u00ff", "81": "\u0153", "82": "\u00e6", "83": "eos"}, "speakers_map": {"synpaflex": 0}, "processor_name": "SynpaflexProcessor"}
2 |
--------------------------------------------------------------------------------
/tensorflow_tts/processor/pretrained/thorsten_mapper.json:
--------------------------------------------------------------------------------
1 | {"symbol_to_id": {"pad": 0, "-": 1, "!": 2, "'": 3, "(": 4, ")": 5, ",": 6, ".": 7, "?": 8, " ": 9, "A": 10, "B": 11, "C": 12, "D": 13, "E": 14, "F": 15, "G": 16, "H": 17, "I": 18, "J": 19, "K": 20, "L": 21, "M": 22, "N": 23, "O": 24, "P": 25, "Q": 26, "R": 27, "S": 28, "T": 29, "U": 30, "V": 31, "W": 32, "X": 33, "Y": 34, "Z": 35, "a": 36, "b": 37, "c": 38, "d": 39, "e": 40, "f": 41, "g": 42, "h": 43, "i": 44, "j": 45, "k": 46, "l": 47, "m": 48, "n": 49, "o": 50, "p": 51, "q": 52, "r": 53, "s": 54, "t": 55, "u": 56, "v": 57, "w": 58, "x": 59, "y": 60, "z": 61, "eos": 62}, "id_to_symbol": {"0": "pad", "1": "-", "2": "!", "3": "'", "4": "(", "5": ")", "6": ",", "7": ".", "8": "?", "9": " ", "10": "A", "11": "B", "12": "C", "13": "D", "14": "E", "15": "F", "16": "G", "17": "H", "18": "I", "19": "J", "20": "K", "21": "L", "22": "M", "23": "N", "24": "O", "25": "P", "26": "Q", "27": "R", "28": "S", "29": "T", "30": "U", "31": "V", "32": "W", "33": "X", "34": "Y", "35": "Z", "36": "a", "37": "b", "38": "c", "39": "d", "40": "e", "41": "f", "42": "g", "43": "h", "44": "i", "45": "j", "46": "k", "47": "l", "48": "m", "49": "n", "50": "o", "51": "p", "52": "q", "53": "r", "54": "s", "55": "t", "56": "u", "57": "v", "58": "w", "59": "x", "60": "y", "61": "z", "62": "eos"}, "speakers_map": {"thorsten": 0}, "processor_name": "ThorstenProcessor"}
--------------------------------------------------------------------------------
/tensorflow_tts/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from tensorflow_tts.trainers.base_trainer import GanBasedTrainer, Seq2SeqBasedTrainer
2 |
--------------------------------------------------------------------------------
/tensorflow_tts/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from tensorflow_tts.utils.cleaners import (
2 | basic_cleaners,
3 | collapse_whitespace,
4 | convert_to_ascii,
5 | english_cleaners,
6 | expand_abbreviations,
7 | expand_numbers,
8 | lowercase,
9 | transliteration_cleaners,
10 | )
11 | from tensorflow_tts.utils.decoder import dynamic_decode
12 | from tensorflow_tts.utils.griffin_lim import TFGriffinLim, griffin_lim_lb
13 | from tensorflow_tts.utils.group_conv import GroupConv1D
14 | from tensorflow_tts.utils.number_norm import normalize_numbers
15 | from tensorflow_tts.utils.outliers import remove_outlier
16 | from tensorflow_tts.utils.strategy import (
17 | calculate_2d_loss,
18 | calculate_3d_loss,
19 | return_strategy,
20 | )
21 | from tensorflow_tts.utils.utils import find_files, MODEL_FILE_NAME, CONFIG_FILE_NAME, PROCESSOR_FILE_NAME, CACHE_DIRECTORY, LIBRARY_NAME
22 | from tensorflow_tts.utils.weight_norm import WeightNormalization
23 |
--------------------------------------------------------------------------------
/tensorflow_tts/utils/outliers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 Minh Nguyen (@dathudeptrai)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Outlier detection and removal."""
16 | import numpy as np
17 |
18 |
19 | def is_outlier(x, p25, p75):
20 | """Check if value is an outlier."""
21 | lower = p25 - 1.5 * (p75 - p25)
22 | upper = p75 + 1.5 * (p75 - p25)
23 | return x <= lower or x >= upper
24 |
25 |
26 | def remove_outlier(x, p_bottom: int = 25, p_top: int = 75):
27 | """Remove outlier from x."""
28 | p_bottom = np.percentile(x, p_bottom)
29 | p_top = np.percentile(x, p_top)
30 |
31 | indices_of_outliers = []
32 | for ind, value in enumerate(x):
33 | if is_outlier(value, p_bottom, p_top):
34 | indices_of_outliers.append(ind)
35 |
36 | x[indices_of_outliers] = 0.0  # zero outliers so they can't dominate the max below.
37 |
38 | # replace outliers by the max of the remaining values.
39 | x[indices_of_outliers] = np.max(x)
40 | return x
41 |
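
A quick illustration of remove_outlier on synthetic f0 values (numbers are
arbitrary; note the in-place zeroing before the max is taken):

    import numpy as np
    from tensorflow_tts.utils import remove_outlier

    f0 = np.array([1.0, 1.1, 0.9, 1.2, 50.0])  # 50.0 lies far outside the IQR fences
    cleaned = remove_outlier(f0.copy())
    # 50.0 is zeroed first, then replaced by the max of the remaining values.
    print(cleaned)  # -> [1.  1.1  0.9  1.2  1.2]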
--------------------------------------------------------------------------------
/tensorflow_tts/utils/strategy.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 Minh Nguyen (@dathudeptrai)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Strategy util functions"""
16 | import tensorflow as tf
17 |
18 |
19 | def return_strategy():
20 | physical_devices = tf.config.list_physical_devices("GPU")
21 | if len(physical_devices) == 0:
22 | return tf.distribute.OneDeviceStrategy(device="/cpu:0")
23 | elif len(physical_devices) == 1:
24 | return tf.distribute.OneDeviceStrategy(device="/gpu:0")
25 | else:
26 | return tf.distribute.MirroredStrategy()
27 |
28 |
29 | def calculate_3d_loss(y_gt, y_pred, loss_fn):
30 | """Calculate 3D loss, typically a mel-spectrogram loss."""
31 | y_gt_T = tf.shape(y_gt)[1]
32 | y_pred_T = tf.shape(y_pred)[1]
33 |
34 | # there can be a length mismatch when training on multiple GPUs.
35 | # we need to slice the longer tensor to make sure the loss
36 | # is calculated correctly.
37 | if y_gt_T > y_pred_T:
38 | y_gt = tf.slice(y_gt, [0, 0, 0], [-1, y_pred_T, -1])
39 | elif y_pred_T > y_gt_T:
40 | y_pred = tf.slice(y_pred, [0, 0, 0], [-1, y_gt_T, -1])
41 |
42 | loss = loss_fn(y_gt, y_pred)
43 | if not isinstance(loss, tuple):
44 | loss = tf.reduce_mean(loss, list(range(1, len(loss.shape)))) # shape = [B]
45 | else:
46 | loss = list(loss)
47 | for i in range(len(loss)):
48 | loss[i] = tf.reduce_mean(
49 | loss[i], list(range(1, len(loss[i].shape)))
50 | ) # shape = [B]
51 | return loss
52 |
53 |
54 | def calculate_2d_loss(y_gt, y_pred, loss_fn):
55 | """Calculate 2D loss, typically a durations/f0s/energies loss."""
56 | y_gt_T = tf.shape(y_gt)[1]
57 | y_pred_T = tf.shape(y_pred)[1]
58 |
59 | # there can be a length mismatch when training on multiple GPUs.
60 | # we need to slice the longer tensor to make sure the loss
61 | # is calculated correctly.
62 | if y_gt_T > y_pred_T:
63 | y_gt = tf.slice(y_gt, [0, 0], [-1, y_pred_T])
64 | elif y_pred_T > y_gt_T:
65 | y_pred = tf.slice(y_pred, [0, 0], [-1, y_gt_T])
66 |
67 | loss = loss_fn(y_gt, y_pred)
68 | if not isinstance(loss, tuple):
69 | loss = tf.reduce_mean(loss, list(range(1, len(loss.shape)))) # shape = [B]
70 | else:
71 | loss = list(loss)
72 | for i in range(len(loss)):
73 | loss[i] = tf.reduce_mean(
74 | loss[i], list(range(1, len(loss[i].shape)))
75 | ) # shape = [B]
76 |
77 | return loss
78 |
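
A small sketch of calculate_3d_loss with a plain MAE loss and deliberately
mismatched time lengths (shapes are illustrative):

    import tensorflow as tf
    from tensorflow_tts.utils import calculate_3d_loss

    y_gt = tf.random.uniform([4, 100, 80])   # e.g. ground-truth mels
    y_pred = tf.random.uniform([4, 98, 80])  # predictions two frames shorter
    mae = tf.keras.losses.MeanAbsoluteError(
        reduction=tf.keras.losses.Reduction.NONE
    )
    loss = calculate_3d_loss(y_gt, y_pred, loss_fn=mae)  # y_gt sliced to 98 frames
    print(loss.shape)  # (4,): one value per batch element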
--------------------------------------------------------------------------------
/tensorflow_tts/utils/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2019 Tomoki Hayashi
4 | # MIT License (https://opensource.org/licenses/MIT)
5 | """Utility functions."""
6 |
7 | import fnmatch
8 | import os
9 | import re
10 | import tempfile
11 | from pathlib import Path
12 |
13 | import tensorflow as tf
14 |
15 | MODEL_FILE_NAME = "model.h5"
16 | CONFIG_FILE_NAME = "config.yml"
17 | PROCESSOR_FILE_NAME = "processor.json"
18 | LIBRARY_NAME = "tensorflow_tts"
19 | CACHE_DIRECTORY = os.path.join(Path.home(), ".cache", LIBRARY_NAME)
20 |
21 |
22 | def find_files(root_dir, query="*.wav", include_root_dir=True):
23 | """Find files recursively.
24 | Args:
25 | root_dir (str): Root directory to search.
26 | query (str): Query to find.
27 | include_root_dir (bool): If False, root_dir name is not included.
28 | Returns:
29 | list: List of found filenames.
30 | """
31 | files = []
32 | for root, _, filenames in os.walk(root_dir, followlinks=True):
33 | for filename in fnmatch.filter(filenames, query):
34 | files.append(os.path.join(root, filename))
35 | if not include_root_dir:
36 | files = [file_.replace(root_dir + "/", "") for file_ in files]
37 |
38 | return files
39 |
40 |
41 | def _path_requires_gfile(filepath):
42 | """Checks if the given path requires use of GFile API.
43 |
44 | Args:
45 | filepath (str): Path to check.
46 | Returns:
47 | bool: True if the given path needs GFile API to access, such as
48 | "s3://some/path" and "gs://some/path".
49 | """
50 | # If the filepath contains a protocol (e.g. "gs://"), it should be handled
51 | # using TensorFlow GFile API.
52 | return bool(re.match(r"^[a-z]+://", filepath))
53 |
54 |
55 | def save_weights(model, filepath):
56 | """Save model weights.
57 |
58 | Same as model.save_weights(filepath), but supports saving to S3 or GCS
59 | buckets using TensorFlow GFile API.
60 |
61 | Args:
62 | model (tf.keras.Model): Model to save.
63 | filepath (str): Path to save the model weights to.
64 | """
65 | if not _path_requires_gfile(filepath):
66 | model.save_weights(filepath)
67 | return
68 |
69 | # Save to a local temp file and copy to the desired path using GFile API.
70 | _, ext = os.path.splitext(filepath)
71 | with tempfile.NamedTemporaryFile(suffix=ext) as temp_file:
72 | model.save_weights(temp_file.name)
73 | # To preserve the original semantics, we need to overwrite the target
74 | # file.
75 | tf.io.gfile.copy(temp_file.name, filepath, overwrite=True)
76 |
77 |
78 | def load_weights(model, filepath):
79 | """Load model weights.
80 |
81 | Same as model.load_weights(filepath), but supports loading from S3 or GCS
82 | buckets using TensorFlow GFile API.
83 |
84 | Args:
85 | model (tf.keras.Model): Model to load weights to.
86 | filepath (str): Path to the weights file.
87 | """
88 | if not _path_requires_gfile(filepath):
89 | model.load_weights(filepath)
90 | return
91 |
92 | # Make a local copy and load it.
93 | _, ext = os.path.splitext(filepath)
94 | with tempfile.NamedTemporaryFile(suffix=ext) as temp_file:
95 | # The temp file already exists on disk, so the copy must overwrite it.
96 | tf.io.gfile.copy(filepath, temp_file.name, overwrite=True)
97 | model.load_weights(temp_file.name)
98 |
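
An illustrative combination of find_files with the GFile-aware savers above
(directory and bucket names are made up; save_weights is imported from this
module directly since utils/__init__.py does not re-export it):

    from tensorflow_tts.utils.utils import find_files, save_weights

    wavs = find_files("./datasets/ljspeech/wavs", query="*.wav")

    # "model" is any tf.keras.Model instance.
    # Local path: falls through to model.save_weights directly.
    save_weights(model, "./outdir/model.h5")

    # Bucket-style path: detected by _path_requires_gfile, staged via a temp file.
    save_weights(model, "gs://my-bucket/outdir/model.h5")  # hypothetical bucket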
--------------------------------------------------------------------------------
/test/files/kss_mapper.json:
--------------------------------------------------------------------------------
1 | {"symbol_to_id": {"pad": 0, "-": 7, "!": 2, "'": 3, "(": 4, ")": 5, ",": 6, ".": 8, ":": 9, ";": 10, "?": 11, " ": 12, "\u1100": 13, "\u1101": 14, "\u1102": 15, "\u1103": 16, "\u1104": 17, "\u1105": 18, "\u1106": 19, "\u1107": 20, "\u1108": 21, "\u1109": 22, "\u110a": 23, "\u110b": 24, "\u110c": 25, "\u110d": 26, "\u110e": 27, "\u110f": 28, "\u1110": 29, "\u1111": 30, "\u1112": 31, "\u1161": 32, "\u1162": 33, "\u1163": 34, "\u1164": 35, "\u1165": 36, "\u1166": 37, "\u1167": 38, "\u1168": 39, "\u1169": 40, "\u116a": 41, "\u116b": 42, "\u116c": 43, "\u116d": 44, "\u116e": 45, "\u116f": 46, "\u1170": 47, "\u1171": 48, "\u1172": 49, "\u1173": 50, "\u1174": 51, "\u1175": 52, "\u11a8": 53, "\u11a9": 54, "\u11aa": 55, "\u11ab": 56, "\u11ac": 57, "\u11ad": 58, "\u11ae": 59, "\u11af": 60, "\u11b0": 61, "\u11b1": 62, "\u11b2": 63, "\u11b3": 64, "\u11b4": 65, "\u11b5": 66, "\u11b6": 67, "\u11b7": 68, "\u11b8": 69, "\u11b9": 70, "\u11ba": 71, "\u11bb": 72, "\u11bc": 73, "\u11bd": 74, "\u11be": 75, "\u11bf": 76, "\u11c0": 77, "\u11c1": 78, "\u11c2": 79, "eos": 80}, "id_to_symbol": {"0": "pad", "1": "-", "2": "!", "3": "'", "4": "(", "5": ")", "6": ",", "7": "-", "8": ".", "9": ":", "10": ";", "11": "?", "12": " ", "13": "\u1100", "14": "\u1101", "15": "\u1102", "16": "\u1103", "17": "\u1104", "18": "\u1105", "19": "\u1106", "20": "\u1107", "21": "\u1108", "22": "\u1109", "23": "\u110a", "24": "\u110b", "25": "\u110c", "26": "\u110d", "27": "\u110e", "28": "\u110f", "29": "\u1110", "30": "\u1111", "31": "\u1112", "32": "\u1161", "33": "\u1162", "34": "\u1163", "35": "\u1164", "36": "\u1165", "37": "\u1166", "38": "\u1167", "39": "\u1168", "40": "\u1169", "41": "\u116a", "42": "\u116b", "43": "\u116c", "44": "\u116d", "45": "\u116e", "46": "\u116f", "47": "\u1170", "48": "\u1171", "49": "\u1172", "50": "\u1173", "51": "\u1174", "52": "\u1175", "53": "\u11a8", "54": "\u11a9", "55": "\u11aa", "56": "\u11ab", "57": "\u11ac", "58": "\u11ad", "59": "\u11ae", "60": "\u11af", "61": "\u11b0", "62": "\u11b1", "63": "\u11b2", "64": "\u11b3", "65": "\u11b4", "66": "\u11b5", "67": "\u11b6", "68": "\u11b7", "69": "\u11b8", "70": "\u11b9", "71": "\u11ba", "72": "\u11bb", "73": "\u11bc", "74": "\u11bd", "75": "\u11be", "76": "\u11bf", "77": "\u11c0", "78": "\u11c1", "79": "\u11c2", "80": "eos"}, "speakers_map": {"kss": 0}, "processor_name": "KSSProcessor"}
--------------------------------------------------------------------------------
/test/files/libritts_mapper.json:
--------------------------------------------------------------------------------
1 | {"symbol_to_id": {"@": 0, "@": 1, "@": 2, "@": 3, "@AA0": 4, "@AA1": 5, "@AA2": 6, "@AE0": 7, "@AE1": 8, "@AE2": 9, "@AH0": 10, "@AH1": 11, "@AH2": 12, "@AO0": 13, "@AO1": 14, "@AO2": 15, "@AW0": 16, "@AW1": 17, "@AW2": 18, "@AY0": 19, "@AY1": 20, "@AY2": 21, "@B": 22, "@CH": 23, "@D": 24, "@DH": 25, "@EH0": 26, "@EH1": 27, "@EH2": 28, "@ER0": 29, "@ER1": 30, "@ER2": 31, "@EY0": 32, "@EY1": 33, "@EY2": 34, "@F": 35, "@G": 36, "@HH": 37, "@IH0": 38, "@IH1": 39, "@IH2": 40, "@IY0": 41, "@IY1": 42, "@IY2": 43, "@JH": 44, "@K": 45, "@L": 46, "@M": 47, "@N": 48, "@NG": 49, "@OW0": 50, "@OW1": 51, "@OW2": 52, "@OY0": 53, "@OY1": 54, "@OY2": 55, "@P": 56, "@R": 57, "@S": 58, "@SH": 59, "@T": 60, "@TH": 61, "@UH0": 62, "@UH1": 63, "@UH2": 64, "@UW": 65, "@UW0": 66, "@UW1": 67, "@UW2": 68, "@V": 69, "@W": 70, "@Y": 71, "@Z": 72, "@ZH": 73, "@SIL": 74, "@END": 75, "!": 76, "'": 77, "(": 78, ")": 79, ",": 80, ".": 81, ":": 82, ";": 83, "?": 84, " ": 85}, "id_to_symbol": {"0": "@", "1": "@", "2": "@", "3": "@", "4": "@AA0", "5": "@AA1", "6": "@AA2", "7": "@AE0", "8": "@AE1", "9": "@AE2", "10": "@AH0", "11": "@AH1", "12": "@AH2", "13": "@AO0", "14": "@AO1", "15": "@AO2", "16": "@AW0", "17": "@AW1", "18": "@AW2", "19": "@AY0", "20": "@AY1", "21": "@AY2", "22": "@B", "23": "@CH", "24": "@D", "25": "@DH", "26": "@EH0", "27": "@EH1", "28": "@EH2", "29": "@ER0", "30": "@ER1", "31": "@ER2", "32": "@EY0", "33": "@EY1", "34": "@EY2", "35": "@F", "36": "@G", "37": "@HH", "38": "@IH0", "39": "@IH1", "40": "@IH2", "41": "@IY0", "42": "@IY1", "43": "@IY2", "44": "@JH", "45": "@K", "46": "@L", "47": "@M", "48": "@N", "49": "@NG", "50": "@OW0", "51": "@OW1", "52": "@OW2", "53": "@OY0", "54": "@OY1", "55": "@OY2", "56": "@P", "57": "@R", "58": "@S", "59": "@SH", "60": "@T", "61": "@TH", "62": "@UH0", "63": "@UH1", "64": "@UH2", "65": "@UW", "66": "@UW0", "67": "@UW1", "68": "@UW2", "69": "@V", "70": "@W", "71": "@Y", "72": "@Z", "73": "@ZH", "74": "@SIL", "75": "@END", "76": "!", "77": "'", "78": "(", "79": ")", "80": ",", "81": ".", "82": ":", "83": ";", "84": "?", "85": " "}, "speakers_map": {"200": 0, "1841": 1, "3664": 2, "6454": 3, "8108": 4, "2416": 5, "4680": 6, "6147": 7, "412": 8, "2952": 9, "8838": 10, "2836": 11, "1263": 12, "5322": 13, "3830": 14, "7447": 15, "1116": 16, "8312": 17, "8123": 18, "250": 19}, "processor_name": "LibriTTSProcessor"}
--------------------------------------------------------------------------------
/test/files/ljspeech_mapper.json:
--------------------------------------------------------------------------------
1 | {"symbol_to_id": {"pad": 0, "-": 1, "!": 2, "'": 3, "(": 4, ")": 5, ",": 6, ".": 7, ":": 8, ";": 9, "?": 10, " ": 11, "A": 12, "B": 13, "C": 14, "D": 15, "E": 16, "F": 17, "G": 18, "H": 19, "I": 20, "J": 21, "K": 22, "L": 23, "M": 24, "N": 25, "O": 26, "P": 27, "Q": 28, "R": 29, "S": 30, "T": 31, "U": 32, "V": 33, "W": 34, "X": 35, "Y": 36, "Z": 37, "a": 38, "b": 39, "c": 40, "d": 41, "e": 42, "f": 43, "g": 44, "h": 45, "i": 46, "j": 47, "k": 48, "l": 49, "m": 50, "n": 51, "o": 52, "p": 53, "q": 54, "r": 55, "s": 56, "t": 57, "u": 58, "v": 59, "w": 60, "x": 61, "y": 62, "z": 63, "@AA": 64, "@AA0": 65, "@AA1": 66, "@AA2": 67, "@AE": 68, "@AE0": 69, "@AE1": 70, "@AE2": 71, "@AH": 72, "@AH0": 73, "@AH1": 74, "@AH2": 75, "@AO": 76, "@AO0": 77, "@AO1": 78, "@AO2": 79, "@AW": 80, "@AW0": 81, "@AW1": 82, "@AW2": 83, "@AY": 84, "@AY0": 85, "@AY1": 86, "@AY2": 87, "@B": 88, "@CH": 89, "@D": 90, "@DH": 91, "@EH": 92, "@EH0": 93, "@EH1": 94, "@EH2": 95, "@ER": 96, "@ER0": 97, "@ER1": 98, "@ER2": 99, "@EY": 100, "@EY0": 101, "@EY1": 102, "@EY2": 103, "@F": 104, "@G": 105, "@HH": 106, "@IH": 107, "@IH0": 108, "@IH1": 109, "@IH2": 110, "@IY": 111, "@IY0": 112, "@IY1": 113, "@IY2": 114, "@JH": 115, "@K": 116, "@L": 117, "@M": 118, "@N": 119, "@NG": 120, "@OW": 121, "@OW0": 122, "@OW1": 123, "@OW2": 124, "@OY": 125, "@OY0": 126, "@OY1": 127, "@OY2": 128, "@P": 129, "@R": 130, "@S": 131, "@SH": 132, "@T": 133, "@TH": 134, "@UH": 135, "@UH0": 136, "@UH1": 137, "@UH2": 138, "@UW": 139, "@UW0": 140, "@UW1": 141, "@UW2": 142, "@V": 143, "@W": 144, "@Y": 145, "@Z": 146, "@ZH": 147, "eos": 148}, "id_to_symbol": {"0": "pad", "1": "-", "2": "!", "3": "'", "4": "(", "5": ")", "6": ",", "7": ".", "8": ":", "9": ";", "10": "?", "11": " ", "12": "A", "13": "B", "14": "C", "15": "D", "16": "E", "17": "F", "18": "G", "19": "H", "20": "I", "21": "J", "22": "K", "23": "L", "24": "M", "25": "N", "26": "O", "27": "P", "28": "Q", "29": "R", "30": "S", "31": "T", "32": "U", "33": "V", "34": "W", "35": "X", "36": "Y", "37": "Z", "38": "a", "39": "b", "40": "c", "41": "d", "42": "e", "43": "f", "44": "g", "45": "h", "46": "i", "47": "j", "48": "k", "49": "l", "50": "m", "51": "n", "52": "o", "53": "p", "54": "q", "55": "r", "56": "s", "57": "t", "58": "u", "59": "v", "60": "w", "61": "x", "62": "y", "63": "z", "64": "@AA", "65": "@AA0", "66": "@AA1", "67": "@AA2", "68": "@AE", "69": "@AE0", "70": "@AE1", "71": "@AE2", "72": "@AH", "73": "@AH0", "74": "@AH1", "75": "@AH2", "76": "@AO", "77": "@AO0", "78": "@AO1", "79": "@AO2", "80": "@AW", "81": "@AW0", "82": "@AW1", "83": "@AW2", "84": "@AY", "85": "@AY0", "86": "@AY1", "87": "@AY2", "88": "@B", "89": "@CH", "90": "@D", "91": "@DH", "92": "@EH", "93": "@EH0", "94": "@EH1", "95": "@EH2", "96": "@ER", "97": "@ER0", "98": "@ER1", "99": "@ER2", "100": "@EY", "101": "@EY0", "102": "@EY1", "103": "@EY2", "104": "@F", "105": "@G", "106": "@HH", "107": "@IH", "108": "@IH0", "109": "@IH1", "110": "@IH2", "111": "@IY", "112": "@IY0", "113": "@IY1", "114": "@IY2", "115": "@JH", "116": "@K", "117": "@L", "118": "@M", "119": "@N", "120": "@NG", "121": "@OW", "122": "@OW0", "123": "@OW1", "124": "@OW2", "125": "@OY", "126": "@OY0", "127": "@OY1", "128": "@OY2", "129": "@P", "130": "@R", "131": "@S", "132": "@SH", "133": "@T", "134": "@TH", "135": "@UH", "136": "@UH0", "137": "@UH1", "138": "@UH2", "139": "@UW", "140": "@UW0", "141": "@UW1", "142": "@UW2", "143": "@V", "144": "@W", "145": "@Y", "146": "@Z", "147": "@ZH", "148": "eos"}, "speakers_map": {"ljspeech": 0}, "processor_name": 
"LJSpeechProcessor"}
--------------------------------------------------------------------------------
/test/files/mapper.json:
--------------------------------------------------------------------------------
1 | {
2 | "speakers_map": {
3 | "test_one": 0,
4 | "test_two": 1
5 | },
6 | "symbol_to_id": {
7 | "a": 0,
8 | "b": 1,
9 | "@ph": 2
10 | },
11 | "id_to_symbol": {
12 | "0": "a",
13 | "1": "b",
14 | "2": "@ph"
15 | },
16 | "processor_name": "KSSProcessor"
17 | }
--------------------------------------------------------------------------------
/test/files/train.txt:
--------------------------------------------------------------------------------
1 | speaker1/libri1.wav|in fact its just a test.|One
2 | speaker2/libri2|in fact its just a speaker number one.|Two
3 |
--------------------------------------------------------------------------------
/test/test_auto.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 Minh Nguyen (@dathudeptrai)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import logging
17 | import os
18 |
19 | import pytest
20 | import tensorflow as tf
21 |
22 | from tensorflow_tts.inference import AutoConfig
23 | from tensorflow_tts.inference import AutoProcessor
24 | from tensorflow_tts.inference import TFAutoModel
25 |
26 | os.environ["CUDA_VISIBLE_DEVICES"] = ""
27 |
28 | logging.basicConfig(
29 | level=logging.DEBUG,
30 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
31 | )
32 |
33 |
34 | @pytest.mark.parametrize(
35 | "mapper_path",
36 | [
37 | "./test/files/baker_mapper.json",
38 | "./test/files/kss_mapper.json",
39 | "./test/files/libritts_mapper.json",
40 | "./test/files/ljspeech_mapper.json",
41 | ]
42 | )
43 | def test_auto_processor(mapper_path):
44 | processor = AutoProcessor.from_pretrained(pretrained_path=mapper_path)
45 | processor.save_pretrained("./test_saved")
46 | processor = AutoProcessor.from_pretrained("./test_saved/processor.json")
47 |
48 |
49 | @pytest.mark.parametrize(
50 | "config_path",
51 | [
52 | "./examples/fastspeech/conf/fastspeech.v1.yaml",
53 | "./examples/fastspeech/conf/fastspeech.v3.yaml",
54 | "./examples/fastspeech2/conf/fastspeech2.v1.yaml",
55 | "./examples/fastspeech2/conf/fastspeech2.v2.yaml",
56 | "./examples/fastspeech2/conf/fastspeech2.kss.v1.yaml",
57 | "./examples/fastspeech2/conf/fastspeech2.kss.v2.yaml",
58 | "./examples/melgan/conf/melgan.v1.yaml",
59 | "./examples/melgan_stft/conf/melgan_stft.v1.yaml",
60 | "./examples/multiband_melgan/conf/multiband_melgan.v1.yaml",
61 | "./examples/tacotron2/conf/tacotron2.v1.yaml",
62 | "./examples/tacotron2/conf/tacotron2.kss.v1.yaml",
63 | "./examples/parallel_wavegan/conf/parallel_wavegan.v1.yaml",
64 | "./examples/hifigan/conf/hifigan.v1.yaml",
65 | "./examples/hifigan/conf/hifigan.v2.yaml",
66 | ]
67 | )
68 | def test_auto_model(config_path):
69 | config = AutoConfig.from_pretrained(pretrained_path=config_path)
70 | model = TFAutoModel.from_pretrained(pretrained_path=None, config=config)
71 |
72 | # test save_pretrained
73 | config.save_pretrained("./test_saved")
74 | model.save_pretrained("./test_saved")
75 |
76 | # test from_pretrained
77 | config = AutoConfig.from_pretrained("./test_saved/config.yml")
78 | model = TFAutoModel.from_pretrained("./test_saved/model.h5", config=config)
79 |
--------------------------------------------------------------------------------
/test/test_fastspeech.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 Minh Nguyen (@dathudeptrai)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import logging
17 | import os
18 |
19 | import pytest
20 | import tensorflow as tf
21 |
22 | from tensorflow_tts.configs import FastSpeechConfig
23 | from tensorflow_tts.models import TFFastSpeech
24 |
25 | os.environ["CUDA_VISIBLE_DEVICES"] = ""
26 |
27 | logging.basicConfig(
28 | level=logging.DEBUG,
29 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
30 | )
31 |
32 |
33 | @pytest.mark.parametrize("new_size", [100, 200, 300])
34 | def test_fastspeech_resize_positional_embeddings(new_size):
35 | config = FastSpeechConfig()
36 | fastspeech = TFFastSpeech(config, name="fastspeech")
37 | fastspeech._build()
38 | fastspeech.save_weights("./test.h5")
39 | fastspeech.resize_positional_embeddings(new_size)
40 | fastspeech.load_weights("./test.h5", by_name=True, skip_mismatch=True)
41 |
42 |
43 | @pytest.mark.parametrize("num_hidden_layers,n_speakers", [(2, 1), (3, 2), (4, 3)])
44 | def test_fastspeech_trainable(num_hidden_layers, n_speakers):
45 | config = FastSpeechConfig(
46 | encoder_num_hidden_layers=num_hidden_layers,
47 | decoder_num_hidden_layers=num_hidden_layers + 1,
48 | n_speakers=n_speakers,
49 | )
50 |
51 | fastspeech = TFFastSpeech(config, name="fastspeech")
52 | optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
53 |
54 | # fake inputs
55 | input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], tf.int32)
56 | attention_mask = tf.convert_to_tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], tf.int32)
57 | speaker_ids = tf.convert_to_tensor([0], tf.int32)
58 | duration_gts = tf.convert_to_tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], tf.int32)
59 |
60 | mel_gts = tf.random.uniform(shape=[1, 10, 80], dtype=tf.float32)
61 |
62 | @tf.function
63 | def one_step_training():
64 | with tf.GradientTape() as tape:
65 | mel_outputs_before, _, duration_outputs = fastspeech(
66 | input_ids, speaker_ids, duration_gts, training=True
67 | )
68 | duration_loss = tf.keras.losses.MeanSquaredError()(
69 | duration_gts, duration_outputs
70 | )
71 | mel_loss = tf.keras.losses.MeanSquaredError()(mel_gts, mel_outputs_before)
72 | loss = duration_loss + mel_loss
73 | gradients = tape.gradient(loss, fastspeech.trainable_variables)
74 | optimizer.apply_gradients(zip(gradients, fastspeech.trainable_variables))
75 |
76 | tf.print(loss)
77 |
78 | import time
79 |
80 | for i in range(2):
81 | if i == 1:
82 | start = time.time()
83 | one_step_training()
84 | print(time.time() - start)
85 |
--------------------------------------------------------------------------------
/test/test_mb_melgan.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 Minh Nguyen (@dathudeptrai)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import tensorflow as tf
17 |
18 | import logging
19 | import os
20 |
21 | import numpy as np
22 | import pytest
23 |
24 | from tensorflow_tts.configs import MultiBandMelGANGeneratorConfig
25 | from tensorflow_tts.models import TFPQMF, TFMelGANGenerator
26 |
27 | os.environ["CUDA_VISIBLE_DEVICES"] = ""
28 |
29 | logging.basicConfig(
30 | level=logging.DEBUG,
31 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
32 | )
33 |
34 |
35 | def make_multi_band_melgan_generator_args(**kwargs):
36 | defaults = dict(
37 | out_channels=1,
38 | kernel_size=7,
39 | filters=512,
40 | use_bias=True,
41 | upsample_scales=[8, 8, 2, 2],
42 | stack_kernel_size=3,
43 | stacks=3,
44 | nonlinear_activation="LeakyReLU",
45 | nonlinear_activation_params={"alpha": 0.2},
46 | padding_type="REFLECT",
47 | subbands=4,
48 | taps=62,
49 | cutoff_ratio=0.15,
50 | beta=9.0,
51 | )
52 | defaults.update(kwargs)
53 | return defaults
54 |
55 |
56 | @pytest.mark.parametrize(
57 | "dict_g",
58 | [
59 | {"subbands": 4, "upsample_scales": [2, 4, 8], "stacks": 4, "out_channels": 4},
60 | {"subbands": 4, "upsample_scales": [4, 4, 4], "stacks": 5, "out_channels": 4},
61 | ],
62 | )
63 | def test_multi_band_melgan(dict_g):
64 | args_g = make_multi_band_melgan_generator_args(**dict_g)
65 | args_g = MultiBandMelGANGeneratorConfig(**args_g)
66 | generator = TFMelGANGenerator(args_g, name="multi_band_melgan")
67 | generator._build()
68 |
69 | pqmf = TFPQMF(args_g, name="pqmf")
70 |
71 | fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
72 | fake_y = tf.random.uniform(shape=[1, 100 * 256, 1], dtype=tf.float32)
73 | y_hat_subbands = generator(fake_mels)
74 |
75 | y_hat = pqmf.synthesis(y_hat_subbands)
76 | y_subbands = pqmf.analysis(fake_y)
77 |
78 | assert np.shape(y_subbands) == np.shape(y_hat_subbands)
79 | assert np.shape(fake_y) == np.shape(y_hat)
80 |
--------------------------------------------------------------------------------
/test/test_melgan.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 Minh Nguyen (@dathudeptrai)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import logging
17 | import os
18 |
19 | import pytest
20 | import tensorflow as tf
21 |
22 | from tensorflow_tts.configs import MelGANDiscriminatorConfig, MelGANGeneratorConfig
23 | from tensorflow_tts.models import TFMelGANGenerator, TFMelGANMultiScaleDiscriminator
24 |
25 | os.environ["CUDA_VISIBLE_DEVICES"] = ""
26 |
27 | logging.basicConfig(
28 | level=logging.DEBUG,
29 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
30 | )
31 |
32 |
33 | def make_melgan_generator_args(**kwargs):
34 | defaults = dict(
35 | out_channels=1,
36 | kernel_size=7,
37 | filters=512,
38 | use_bias=True,
39 | upsample_scales=[8, 8, 2, 2],
40 | stack_kernel_size=3,
41 | stacks=3,
42 | nonlinear_activation="LeakyReLU",
43 | nonlinear_activation_params={"alpha": 0.2},
44 | padding_type="REFLECT",
45 | )
46 | defaults.update(kwargs)
47 | return defaults
48 |
49 |
50 | def make_melgan_discriminator_args(**kwargs):
51 | defaults = dict(
52 | out_channels=1,
53 | scales=3,
54 | downsample_pooling="AveragePooling1D",
55 | downsample_pooling_params={"pool_size": 4, "strides": 2},
56 | kernel_sizes=[5, 3],
57 | filters=16,
58 | max_downsample_filters=1024,
59 | use_bias=True,
60 | downsample_scales=[4, 4, 4, 4],
61 | nonlinear_activation="LeakyReLU",
62 | nonlinear_activation_params={"alpha": 0.2},
63 | padding_type="REFLECT",
64 | )
65 | defaults.update(kwargs)
66 | return defaults
67 |
68 |
69 | @pytest.mark.parametrize(
70 | "dict_g, dict_d, dict_loss",
71 | [
72 | ({}, {}, {}),
73 | ({"kernel_size": 3}, {}, {}),
74 | ({"filters": 1024}, {}, {}),
75 | ({"stack_kernel_size": 5}, {}, {}),
76 | ({"stack_kernel_size": 5, "stacks": 2}, {}, {}),
77 | ({"upsample_scales": [4, 4, 4, 4]}, {}, {}),
78 | ({"upsample_scales": [8, 8, 2, 2]}, {}, {}),
79 | ({"filters": 1024, "upsample_scales": [8, 8, 2, 2]}, {}, {}),
80 | ],
81 | )
82 | def test_melgan_trainable(dict_g, dict_d, dict_loss):
83 | batch_size = 4
84 | batch_length = 4096
85 | args_g = make_melgan_generator_args(**dict_g)
86 | args_d = make_melgan_discriminator_args(**dict_d)
87 |
88 | args_g = MelGANGeneratorConfig(**args_g)
89 | args_d = MelGANDiscriminatorConfig(**args_d)
90 |
91 | generator = TFMelGANGenerator(args_g)
92 | discriminator = TFMelGANMultiScaleDiscriminator(args_d)
93 |
--------------------------------------------------------------------------------
/test/test_melgan_layers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 Minh Nguyen (@dathudeptrai)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import logging
17 | import os
18 |
19 | import numpy as np
20 | import pytest
21 | import tensorflow as tf
22 |
23 | from tensorflow_tts.models.melgan import (
24 | TFConvTranspose1d,
25 | TFReflectionPad1d,
26 | TFResidualStack,
27 | )
28 |
29 | os.environ["CUDA_VISIBLE_DEVICES"] = ""
30 |
31 | logging.basicConfig(
32 | level=logging.DEBUG,
33 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
34 | )
35 |
36 |
37 | @pytest.mark.parametrize("padding_size", [(3), (5)])
38 | def test_padding(padding_size):
39 | fake_input_1d = tf.random.normal(shape=[4, 8000, 256], dtype=tf.float32)
40 | out = TFReflectionPad1d(padding_size=padding_size)(fake_input_1d)
41 | assert np.array_equal(
42 | tf.keras.backend.int_shape(out), [4, 8000 + 2 * padding_size, 256]
43 | )
44 |
45 |
46 | @pytest.mark.parametrize(
47 | "filters,kernel_size,strides,padding,is_weight_norm",
48 | [(512, 40, 8, "same", False), (768, 15, 8, "same", True)],
49 | )
50 | def test_convtranpose1d(filters, kernel_size, strides, padding, is_weight_norm):
51 | fake_input_1d = tf.random.normal(shape=[4, 8000, 256], dtype=tf.float32)
52 | conv1d_transpose = TFConvTranspose1d(
53 | filters=filters,
54 | kernel_size=kernel_size,
55 | strides=strides,
56 | padding=padding,
57 | is_weight_norm=is_weight_norm,
58 | initializer_seed=42,
59 | )
60 | out = conv1d_transpose(fake_input_1d)
61 | assert np.array_equal(tf.keras.backend.int_shape(out), [4, 8000 * strides, filters])
62 |
63 |
64 | @pytest.mark.parametrize(
65 | "kernel_size,filters,dilation_rate,use_bias,nonlinear_activation,nonlinear_activation_params,is_weight_norm",
66 | [
67 | (3, 256, 1, True, "LeakyReLU", {"alpha": 0.3}, True),
68 | (3, 256, 3, True, "ReLU", {}, False),
69 | ],
70 | )
71 | def test_residualblock(
72 | kernel_size,
73 | filters,
74 | dilation_rate,
75 | use_bias,
76 | nonlinear_activation,
77 | nonlinear_activation_params,
78 | is_weight_norm,
79 | ):
80 | fake_input_1d = tf.random.normal(shape=[4, 8000, 256], dtype=tf.float32)
81 | residual_block = TFResidualStack(
82 | kernel_size=kernel_size,
83 | filters=filters,
84 | dilation_rate=dilation_rate,
85 | use_bias=use_bias,
86 | nonlinear_activation=nonlinear_activation,
87 | nonlinear_activation_params=nonlinear_activation_params,
88 | is_weight_norm=is_weight_norm,
89 | initializer_seed=42,
90 | )
91 | out = residual_block(fake_input_1d)
92 | assert np.array_equal(tf.keras.backend.int_shape(out), [4, 8000, filters])
93 |
--------------------------------------------------------------------------------
/test/test_parallel_wavegan.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2020 TensorFlowTTS Team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import logging
17 | import os
18 |
19 | import pytest
20 | import tensorflow as tf
21 |
22 | from tensorflow_tts.configs import (
23 | ParallelWaveGANGeneratorConfig,
24 | ParallelWaveGANDiscriminatorConfig,
25 | )
26 | from tensorflow_tts.models import (
27 | TFParallelWaveGANGenerator,
28 | TFParallelWaveGANDiscriminator,
29 | )
30 |
31 | os.environ["CUDA_VISIBLE_DEVICES"] = ""
32 |
33 | logging.basicConfig(
34 | level=logging.DEBUG,
35 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
36 | )
37 |
38 |
39 | def make_pwgan_generator_args(**kwargs):
40 | defaults = dict(
41 | out_channels=1,
42 | kernel_size=3,
43 | n_layers=30,
44 | stacks=3,
45 | residual_channels=64,
46 | gate_channels=128,
47 | skip_channels=64,
48 | aux_channels=80,
49 | aux_context_window=2,
50 | dropout_rate=0.0,
51 | use_bias=True,
52 | use_causal_conv=False,
53 | upsample_conditional_features=True,
54 | upsample_params={"upsample_scales": [4, 4, 4, 4]},
55 | initializer_seed=42,
56 | )
57 | defaults.update(kwargs)
58 | return defaults
59 |
60 |
61 | def make_pwgan_discriminator_args(**kwargs):
62 | defaults = dict(
63 | out_channels=1,
64 | kernel_size=3,
65 | n_layers=10,
66 | conv_channels=64,
67 | use_bias=True,
68 | dilation_factor=1,
69 | nonlinear_activation="LeakyReLU",
70 | nonlinear_activation_params={"alpha": 0.2},
71 | initializer_seed=42,
72 | apply_sigmoid_at_last=False,
73 | )
74 | defaults.update(kwargs)
75 | return defaults
76 |
77 |
78 | @pytest.mark.parametrize(
79 | "dict_g, dict_d",
80 | [
81 | ({}, {}),
82 | (
83 | {"kernel_size": 3, "aux_context_window": 5, "residual_channels": 128},
84 | {"dilation_factor": 2},
85 | ),
86 | ({"stacks": 4, "n_layers": 40}, {"conv_channels": 128}),
87 | ],
88 | )
89 | def test_parallel_wavegan_trainable(dict_g, dict_d):
90 | random_c = tf.random.uniform(shape=[4, 32, 80], dtype=tf.float32)
91 |
92 | args_g = make_pwgan_generator_args(**dict_g)
93 | args_d = make_pwgan_discriminator_args(**dict_d)
94 |
95 | args_g = ParallelWaveGANGeneratorConfig(**args_g)
96 | args_d = ParallelWaveGANDiscriminatorConfig(**args_d)
97 |
98 | generator = TFParallelWaveGANGenerator(args_g)
99 | generator._build()
100 | discriminator = TFParallelWaveGANDiscriminator(args_d)
101 | discriminator._build()
102 |
103 | generated_audios = generator(random_c, training=True)
104 | discriminator(generated_audios)
105 |
106 | generator.summary()
107 | discriminator.summary()
108 |
--------------------------------------------------------------------------------