├── .clang-format
├── .flake8
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   └── workflows
│       └── lint.yml
├── .gitignore
├── CPPLINT.cfg
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── conf.py
│   └── index.rst
├── examples
│   ├── aishell-3
│   │   ├── configs
│   │   │   ├── v1.json
│   │   │   ├── v2.json
│   │   │   └── v3.json
│   │   ├── local
│   │   │   ├── download_data.sh
│   │   │   └── prepare_data.py
│   │   ├── run.sh
│   │   ├── tools
│   │   └── vits
│   ├── baker
│   │   ├── configs
│   │   │   ├── v1.json
│   │   │   ├── v2.json
│   │   │   ├── v3.json
│   │   │   ├── vits2_v1.json
│   │   │   ├── vits2_vocos_v1.json
│   │   │   └── vocos.json
│   │   ├── local
│   │   │   └── prepare_data.py
│   │   ├── run.sh
│   │   ├── tools
│   │   └── vits
│   ├── chinese_prosody_polyphone
│   │   ├── README.md
│   │   ├── frontend
│   │   ├── lexicon
│   │   │   ├── pinyin_dict.txt
│   │   │   ├── polyphone.txt
│   │   │   └── prosody.txt
│   │   ├── run.sh
│   │   └── tools
│   ├── ljspeech
│   │   ├── configs
│   │   │   ├── v1.json
│   │   │   ├── v2.json
│   │   │   └── v3.json
│   │   ├── local
│   │   │   ├── download_data.sh
│   │   │   └── prepare_data.py
│   │   ├── path.sh
│   │   ├── run.sh
│   │   ├── tools
│   │   └── vits
│   └── multilingual
│       ├── configs
│       │   ├── v1.json
│       │   ├── v2.json
│       │   └── v3.json
│       ├── run.sh
│       ├── tools
│       └── vits
├── requirements.txt
├── runtime
│   ├── android
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── app
│   │   │   ├── .gitignore
│   │   │   ├── build.gradle
│   │   │   ├── proguard-rules.pro
│   │   │   └── src
│   │   │       ├── androidTest
│   │   │       │   └── java
│   │   │       │       └── cn
│   │   │       │           └── org
│   │   │       │               └── wenet
│   │   │       │                   └── wetts
│   │   │       │                       └── ExampleInstrumentedTest.java
│   │   │       ├── main
│   │   │       │   ├── AndroidManifest.xml
│   │   │       │   ├── assets
│   │   │       │   │   └── .gitkeep
│   │   │       │   ├── cpp
│   │   │       │   │   ├── CMakeLists.txt
│   │   │       │   │   ├── cmake
│   │   │       │   │   ├── frontend
│   │   │       │   │   ├── model
│   │   │       │   │   ├── utils
│   │   │       │   │   └── wetts.cc
│   │   │       │   ├── java
│   │   │       │   │   └── cn
│   │   │       │   │       └── org
│   │   │       │   │           └── wenet
│   │   │       │   │               └── wetts
│   │   │       │   │                   ├── MainActivity.java
│   │   │       │   │                   └── Synthesis.java
│   │   │       │   └── res
│   │   │       │       ├── drawable-v24
│   │   │       │       │   └── ic_launcher_foreground.xml
│   │   │       │       ├── drawable
│   │   │       │       │   └── ic_launcher_background.xml
│   │   │       │       ├── layout
│   │   │       │       │   └── activity_main.xml
│   │   │       │       ├── mipmap-anydpi-v26
│   │   │       │       │   ├── ic_launcher.xml
│   │   │       │       │   └── ic_launcher_round.xml
│   │   │       │       ├── mipmap-hdpi
│   │   │       │       │   ├── ic_launcher.webp
│   │   │       │       │   └── ic_launcher_round.webp
│   │   │       │       ├── mipmap-mdpi
│   │   │       │       │   ├── ic_launcher.webp
│   │   │       │       │   └── ic_launcher_round.webp
│   │   │       │       ├── mipmap-xhdpi
│   │   │       │       │   ├── ic_launcher.webp
│   │   │       │       │   └── ic_launcher_round.webp
│   │   │       │       ├── mipmap-xxhdpi
│   │   │       │       │   ├── ic_launcher.webp
│   │   │       │       │   └── ic_launcher_round.webp
│   │   │       │       ├── mipmap-xxxhdpi
│   │   │       │       │   ├── ic_launcher.webp
│   │   │       │       │   └── ic_launcher_round.webp
│   │   │       │       ├── values-night
│   │   │       │       │   └── themes.xml
│   │   │       │       ├── values
│   │   │       │       │   ├── attrs.xml
│   │   │       │       │   ├── colors.xml
│   │   │       │       │   ├── strings.xml
│   │   │       │       │   └── themes.xml
│   │   │       │       └── xml
│   │   │       │           ├── backup_rules.xml
│   │   │       │           └── data_extraction_rules.xml
│   │   │       └── test
│   │   │           └── java
│   │   │               └── cn
│   │   │                   └── org
│   │   │                       └── wenet
│   │   │                           └── wetts
│   │   │                               └── ExampleUnitTest.java
│   │   ├── build.gradle
│   │   ├── gradle.properties
│   │   ├── gradle
│   │   │   └── wrapper
│   │   │       ├── gradle-wrapper.jar
│   │   │       └── gradle-wrapper.properties
│   │   ├── gradlew
│   │   ├── gradlew.bat
│   │   └── settings.gradle
│   ├── core
│   │   ├── bin
│   │   │   ├── CMakeLists.txt
│   │   │   ├── http_server_main.cc
│   │   │   └── tts_main.cc
│   │   ├── cmake
│   │   │   ├── boost.cmake
│   │   │   ├── gflags.cmake
│   │   │   ├── glog.cmake
│   │   │   ├── gtest.cmake
│   │   │   ├── jsoncpp.cmake
│   │   │   ├── onnxruntime.cmake
│   │   │   └── wetextprocessing.cmake
│   │   ├── frontend
│   │   │   ├── CMakeLists.txt
│   │   │   ├── g2p_en.cc
│   │   │   ├── g2p_en.h
│   │   │   ├── g2p_prosody.cc
│   │   │   ├── g2p_prosody.h
│   │   │   ├── lexicon.cc
│   │   │   ├── lexicon.h
│   │   │   └── wav.h
│   │   ├── http
│   │   │   ├── CMakeLists.txt
│   │   │   ├── http_server.cc
│   │   │   └── http_server.h
│   │   ├── model
│   │   │   ├── CMakeLists.txt
│   │   │   ├── onnx_model.cc
│   │   │   ├── onnx_model.h
│   │   │   ├── tts_model.cc
│   │   │   └── tts_model.h
│   │   ├── test
│   │   │   └── CMakeLists.txt
│   │   └── utils
│   │       ├── CMakeLists.txt
│   │       ├── fst.cc
│   │       ├── fst.h
│   │       ├── string.cc
│   │       ├── string.h
│   │       ├── timer.h
│   │       ├── utils.cc
│   │       └── utils.h
│   ├── cpu_triton_stream
│   │   ├── .gitignore
│   │   ├── Dockerfile
│   │   ├── Makefile
│   │   ├── README.md
│   │   ├── client
│   │   │   ├── client.py
│   │   │   ├── stream_client.py
│   │   │   ├── text.scp
│   │   │   └── web_ui.py
│   │   ├── model_repo
│   │   │   ├── decoder
│   │   │   │   ├── 1
│   │   │   │   │   └── .gitkeep
│   │   │   │   └── config.pbtxt
│   │   │   ├── encoder
│   │   │   │   ├── 1
│   │   │   │   │   └── .gitkeep
│   │   │   │   └── config.pbtxt
│   │   │   ├── stream_tts
│   │   │   │   ├── 1
│   │   │   │   │   └── model.py
│   │   │   │   └── config.pbtxt
│   │   │   └── tts
│   │   │       ├── 1
│   │   │       │   └── model.py
│   │   │       └── config.pbtxt
│   │   ├── requirements-client.txt
│   │   └── requirements-web.txt
│   ├── gpu_triton
│   │   ├── Dockerfile
│   │   ├── README.md
│   │   ├── client
│   │   │   ├── client.py
│   │   │   ├── generate_input.py
│   │   │   └── text.scp
│   │   └── model_repo
│   │       ├── generator
│   │       │   ├── 1
│   │       │   │   └── .gitkeep
│   │       │   └── config.pbtxt
│   │       └── tts
│   │           ├── 1
│   │           │   └── model.py
│   │           └── config.pbtxt
│   ├── onnxruntime
│   │   ├── CMakeLists.txt
│   │   ├── bin
│   │   ├── cmake
│   │   ├── frontend
│   │   ├── http
│   │   ├── model
│   │   └── utils
│   └── web
│       ├── README.md
│       ├── app.py
│       └── requirements.txt
├── setup.cfg
├── setup.py
├── tools
│   ├── cleaners.py
│   ├── compute_spec_length.py
│   ├── gen_pinyin_lexicon.py
│   └── parse_options.sh
└── wetts
    ├── __init__.py
    ├── cli
    │   ├── __init__.py
    │   ├── frontend.py
    │   ├── hub.py
    │   ├── model.py
    │   └── tts.py
    ├── frontend
    │   ├── README.md
    │   ├── dataset.py
    │   ├── export_onnx.py
    │   ├── g2p_prosody.py
    │   ├── hanzi2pinyin.py
    │   ├── model.py
    │   ├── test_polyphone.py
    │   ├── test_prosody.py
    │   ├── train.py
    │   └── utils.py
    └── vits
        ├── data_utils.py
        ├── export_onnx.py
        ├── inference.py
        ├── inference_onnx.py
        ├── losses.py
        ├── model
        │   ├── attentions.py
        │   ├── decoders.py
        │   ├── discriminators.py
        │   ├── duration_predictors.py
        │   ├── encoders.py
        │   ├── flows.py
        │   ├── models.py
        │   ├── modules.py
        │   └── normalization.py
        ├── train.py
        └── utils
            ├── commons.py
            ├── mel_processing.py
            ├── monotonic_align.py
            ├── stft.py
            ├── task.py
            └── transforms.py
/.clang-format:
--------------------------------------------------------------------------------
1 | ---
2 | Language: Cpp
3 | # BasedOnStyle: Google
4 | AccessModifierOffset: -1
5 | AlignAfterOpenBracket: Align
6 | AlignConsecutiveAssignments: false
7 | AlignConsecutiveDeclarations: false
8 | AlignEscapedNewlinesLeft: true
9 | AlignOperands: true
10 | AlignTrailingComments: true
11 | AllowAllParametersOfDeclarationOnNextLine: true
12 | AllowShortBlocksOnASingleLine: false
13 | AllowShortCaseLabelsOnASingleLine: false
14 | AllowShortFunctionsOnASingleLine: All
15 | AllowShortIfStatementsOnASingleLine: true
16 | AllowShortLoopsOnASingleLine: true
17 | AlwaysBreakAfterDefinitionReturnType: None
18 | AlwaysBreakAfterReturnType: None
19 | AlwaysBreakBeforeMultilineStrings: true
20 | AlwaysBreakTemplateDeclarations: true
21 | BinPackArguments: true
22 | BinPackParameters: true
23 | BraceWrapping:
24 | AfterClass: false
25 | AfterControlStatement: false
26 | AfterEnum: false
27 | AfterFunction: false
28 | AfterNamespace: false
29 | AfterObjCDeclaration: false
30 | AfterStruct: false
31 | AfterUnion: false
32 | BeforeCatch: false
33 | BeforeElse: false
34 | IndentBraces: false
35 | BreakBeforeBinaryOperators: None
36 | BreakBeforeBraces: Attach
37 | BreakBeforeTernaryOperators: true
38 | BreakConstructorInitializersBeforeComma: false
39 | BreakAfterJavaFieldAnnotations: false
40 | BreakStringLiterals: true
41 | ColumnLimit: 80
42 | CommentPragmas: '^ IWYU pragma:'
43 | ConstructorInitializerAllOnOneLineOrOnePerLine: true
44 | ConstructorInitializerIndentWidth: 4
45 | ContinuationIndentWidth: 4
46 | Cpp11BracedListStyle: true
47 | DisableFormat: false
48 | ExperimentalAutoDetectBinPacking: false
49 | ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
50 | IncludeCategories:
51 | - Regex: '^<.*\.h>'
52 | Priority: 1
53 | - Regex: '^<.*'
54 | Priority: 2
55 | - Regex: '.*'
56 | Priority: 3
57 | IncludeIsMainRegex: '([-_](test|unittest))?$'
58 | IndentCaseLabels: true
59 | IndentWidth: 2
60 | IndentWrappedFunctionNames: false
61 | JavaScriptQuotes: Leave
62 | JavaScriptWrapImports: true
63 | KeepEmptyLinesAtTheStartOfBlocks: false
64 | MacroBlockBegin: ''
65 | MacroBlockEnd: ''
66 | MaxEmptyLinesToKeep: 1
67 | NamespaceIndentation: None
68 | ObjCBlockIndentWidth: 2
69 | ObjCSpaceAfterProperty: false
70 | ObjCSpaceBeforeProtocolList: false
71 | PenaltyBreakBeforeFirstCallParameter: 1
72 | PenaltyBreakComment: 300
73 | PenaltyBreakFirstLessLess: 120
74 | PenaltyBreakString: 1000
75 | PenaltyExcessCharacter: 1000000
76 | PenaltyReturnTypeOnItsOwnLine: 200
77 | PointerAlignment: Left
78 | ReflowComments: true
79 | SortIncludes: true
80 | SpaceAfterCStyleCast: false
81 | SpaceBeforeAssignmentOperators: true
82 | SpaceBeforeParens: ControlStatements
83 | SpaceInEmptyParentheses: false
84 | SpacesBeforeTrailingComments: 2
85 | SpacesInAngles: false
86 | SpacesInContainerLiterals: true
87 | SpacesInCStyleCastParentheses: false
88 | SpacesInParentheses: false
89 | SpacesInSquareBrackets: false
90 | Standard: Auto
91 | TabWidth: 8
92 | UseTab: Never
93 | ...
94 |
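95 | # Illustrative usage (assuming clang-format is installed): apply this style
96 | # in place to the C++ runtime sources, e.g.
97 | #   find runtime/core -name '*.cc' -o -name '*.h' | xargs clang-format -i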
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | select = B,C,E,F,P,T4,W,B9
3 | max-line-length = 80
4 | # C408 ignored because we like the dict keyword argument syntax
5 | # E501 is not flexible enough, we're using B950 instead
6 | ignore =
7 | E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
8 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying
9 | # to line this up with executable bit
10 | EXE001,
11 | # these ignores are from flake8-bugbear; please fix!
12 | B006,B007,B008,B905
13 | # these ignores are from flake8-comprehensions; please fix!
14 | C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
15 | exclude =
16 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Desktop (please complete the following information):**
27 | - OS: [e.g. iOS]
28 | - Browser [e.g. chrome, safari]
29 | - Version [e.g. 22]
30 |
31 | **Smartphone (please complete the following information):**
32 | - Device: [e.g. iPhone6]
33 | - OS: [e.g. iOS8.1]
34 | - Browser [e.g. stock browser, safari]
35 | - Version [e.g. 22]
36 |
37 | **Additional context**
38 | Add any other context about the problem here.
39 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: Lint
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 |
9 | jobs:
10 | quick-checks:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Fetch WeTTS
14 | uses: actions/checkout@v1
15 | - name: Checkout PR tip
16 | run: |
17 | set -eux
18 | if [[ "${{ github.event_name }}" == "pull_request" ]]; then
19 | # We are on a PR, so actions/checkout leaves us on a merge commit.
20 | # Check out the actual tip of the branch.
21 | git checkout ${{ github.event.pull_request.head.sha }}
22 | fi
23 | echo ::set-output name=commit_sha::$(git rev-parse HEAD)
24 | id: get_pr_tip
25 | - name: Ensure no tabs
26 | run: |
27 | (! git grep -I -l $'\t' -- . ':(exclude)*.svg' ':(exclude)**Makefile' ':(exclude)**/contrib/**' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have tabs; please convert them to spaces"; false))
28 | - name: Ensure no trailing whitespace
29 | run: |
30 | (! git grep -I -n $' $' -- . ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have trailing whitespace; please remove them"; false))
31 |
32 | flake8-py3:
33 | runs-on: ubuntu-latest
34 | steps:
35 | - name: Setup Python
36 | uses: actions/setup-python@v1
37 | with:
38 | python-version: 3.9
39 | architecture: x64
40 | - name: Fetch WeTTS
41 | uses: actions/checkout@v1
42 | - name: Checkout PR tip
43 | run: |
44 | set -eux
45 | if [[ "${{ github.event_name }}" == "pull_request" ]]; then
46 | # We are on a PR, so actions/checkout leaves us on a merge commit.
47 | # Check out the actual tip of the branch.
48 | git checkout ${{ github.event.pull_request.head.sha }}
49 | fi
50 | echo ::set-output name=commit_sha::$(git rev-parse HEAD)
51 | id: get_pr_tip
52 | - name: Run flake8
53 | run: |
54 | set -eux
55 | pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
56 | flake8 --version
57 | flake8
58 | if [ $? != 0 ]; then exit 1; fi
59 |
60 | cpplint:
61 | runs-on: ubuntu-latest
62 | steps:
63 | - name: Setup Python
64 | uses: actions/setup-python@v1
65 | with:
66 | python-version: 3.x
67 | architecture: x64
68 | - name: Fetch WeTTS
69 | uses: actions/checkout@v1
70 | - name: Checkout PR tip
71 | run: |
72 | set -eux
73 | if [[ "${{ github.event_name }}" == "pull_request" ]]; then
74 | # We are on a PR, so actions/checkout leaves us on a merge commit.
75 | # Check out the actual tip of the branch.
76 | git checkout ${{ github.event.pull_request.head.sha }}
77 | fi
78 | echo ::set-output name=commit_sha::$(git rev-parse HEAD)
79 | id: get_pr_tip
80 | - name: Run cpplint
81 | run: |
82 | set -eux
83 | pip install cpplint
84 | cpplint --version
85 | cpplint --recursive .
86 | if [ $? != 0 ]; then exit 1; fi
87 |
88 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Visual Studio Code files
7 | .vscode
8 | .vs
9 |
10 | # PyCharm files
11 | .idea
12 |
13 | # Eclipse Project settings
14 | *.*project
15 | .settings
16 |
17 | # Sublime Text settings
18 | *.sublime-workspace
19 | *.sublime-project
20 |
21 | # Editor temporaries
22 | *.swn
23 | *.swo
24 | *.swp
25 | *.swm
26 | *~
27 |
28 | # IPython notebook checkpoints
29 | .ipynb_checkpoints
30 |
31 | # macOS dir files
32 | .DS_Store
33 |
34 | exp
35 | data
36 | raw_wav
37 | tensorboard
38 | **/*build*
39 | /BZNSYP
40 |
--------------------------------------------------------------------------------
/CPPLINT.cfg:
--------------------------------------------------------------------------------
1 | root=runtime/core
2 | filter=-build/c++11
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WeTTS
2 |
3 | Production First and Production Ready End-to-End Text-to-Speech Toolkit
4 |
5 | ## Install
6 |
7 | ### Install python package
8 | ``` sh
9 | pip install git+https://github.com/wenet-e2e/wetts.git
10 | ```
11 | **Command-line usage** (use `-h` for parameters):
12 |
13 | ``` sh
14 | wetts --text "今天天气怎么样" --wav output.wav
15 | ```
16 |
17 | **Python programming usage**:
18 |
19 | ``` python
20 | import wetts
21 |
22 | # TODO
23 | ```
24 |
25 | ### Install for development & deployment
26 |
27 | We suggest installing WeTTS with Anaconda or Miniconda.
28 |
29 | Clone this repo:
30 |
31 | ```sh
32 | git clone https://github.com/wenet-e2e/wetts.git
33 | ```
34 |
35 | Create the environment:
36 |
37 | ```bash
38 | conda create -n wetts python=3.8 -y
39 | conda activate wetts
40 | pip install -r requirements.txt
41 | ```
42 |
43 | ## Roadmap
44 |
45 | We mainly focus on end-to-end, production-ready, and on-device TTS. We are going to use:
46 |
47 | * backend: end-to-end models, such as:
48 | * [VITS](https://arxiv.org/pdf/2106.06103.pdf)
49 | * frontend:
50 | * Text Normalization: [WeTextProcessing](https://github.com/wenet-e2e/WeTextProcessing)
51 | * Prosody & Polyphones: [Unified Mandarin TTS Front-end Based on Distilled BERT Model](https://arxiv.org/pdf/2012.15404.pdf)
52 |
53 | ## Dataset
54 |
55 | We plan to support a variety of open source TTS datasets, including but not limited to:
56 |
57 | * [Baker](https://www.data-baker.com/data/index/TNtts), Chinese Standard Mandarin Speech corpus open sourced by Data Baker.
58 | * [AISHELL-3](https://openslr.org/93), a large-scale and high-fidelity multi-speaker Mandarin speech corpus.
59 | * [Opencpop](https://wenet.org.cn/opencpop), Mandarin singing voice synthesis (SVS) corpus open sourced by Netease Fuxi.
60 |
61 | ## Pretrained Models
62 |
63 | | Dataset | Language | Checkpoint Model | Runtime Model |
64 | | -------------- | -------- | ---------------- | ------------- |
65 | | Baker | CN | [BERT](https://wenet.org.cn/downloads?models=wetts&version=baker_bert_exp.tar.gz) | [BERT](https://wenet.org.cn/downloads?models=wetts&version=baker_bert_onnx.tar.gz) |
66 | | Multilingual | CN | [VITS](https://wenet.org.cn/downloads?models=wetts&version=multilingual_vits_v3_exp.tar.gz) | [VITS](https://wenet.org.cn/downloads?models=wetts&version=multilingual_vits_v3_onnx.tar.gz) |
67 |
68 | ## Runtime
69 |
70 | We plan to support a variety of hardware and platforms, including:
71 |
72 | * x86
73 | * Android
74 | * Raspberry Pi
75 | * Other on-device platforms
76 |
77 | ``` bash
78 | export GLOG_logtostderr=1
79 | export GLOG_v=2
80 |
81 | cd runtime/onnxruntime
82 | cmake -B build -DCMAKE_BUILD_TYPE=Release
83 | cmake --build build
84 | ./build/bin/tts_main \
85 | --frontend_flags baker_bert_onnx/frontend.flags \
86 | --vits_flags multilingual_vits_v3_onnx/vits.flags \
87 | --sname baker \
88 | --text "hello我是小明。" \
89 | --wav_path audio.wav
90 | ```
91 |
92 | ## Discussion & Communication
93 |
94 | For Chinese users, you can also scan the QR code on the left to follow the official account of WeNet.
95 | We created a WeChat group for better discussion and quicker response.
96 | Please scan the personal QR code on the right; its owner will invite you to the chat group.
97 |
98 | | [WeNet official account QR code] | [personal contact QR code] |
99 | | ---- | ---- |
100 |
101 | Or you can directly discuss on [GitHub Issues](https://github.com/wenet-e2e/wetts/issues).
102 |
103 | ## Acknowledgement
104 |
105 | 1. We borrow a lot of code from [vits](https://github.com/jaywalnut310/vits) for the VITS implementation.
106 | 2. We refer to [PaddleSpeech](https://github.com/PaddlePaddle/PaddleSpeech) for `pinyin` lexicon generation.
107 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SPHINXPROJ = Wenet
9 | SOURCEDIR = .
10 | BUILDDIR = _build
11 |
12 | # Put it first so that "make" without argument is like "make help".
13 | help:
14 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
15 |
16 | .PHONY: help Makefile
17 |
18 | # Catch-all target: route all unknown targets to Sphinx using the new
19 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
20 | %: Makefile
21 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
22 |
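23 | # Illustrative usage (assuming sphinx-build is installed): "make html" routes
24 | # through the catch-all target above and renders the docs into _build/html.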
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | # import os
14 | # import sys
15 | # sys.path.insert(0, os.path.abspath('.'))
16 |
17 |
18 | # -- Project information -----------------------------------------------------
19 |
20 | project = 'Wenet'
21 | copyright = '2020, wenet-team'
22 | author = 'wenet-team'
23 |
24 |
25 | # -- General configuration ---------------------------------------------------
26 |
27 | # Add any Sphinx extension module names here, as strings. They can be
28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
29 | # ones.
30 | extensions = [
31 | "nbsphinx",
32 | "sphinx.ext.autodoc",
33 | 'sphinx.ext.napoleon',
34 | 'sphinx.ext.viewcode',
35 | "sphinx.ext.mathjax",
36 | "sphinx.ext.todo",
37 | # "sphinxarg.ext",
38 | "sphinx_markdown_tables",
39 | 'recommonmark',
40 | 'sphinx_rtd_theme',
41 | ]
42 |
43 | # Add any paths that contain templates here, relative to this directory.
44 | templates_path = ['_templates']
45 |
46 |
47 | # The suffix(es) of source filenames.
48 | # You can specify multiple suffix as a list of string:
49 | source_suffix = {
50 | '.rst': 'restructuredtext',
51 | '.txt': 'markdown',
52 | '.md': 'markdown',
53 | }
54 |
55 | # List of patterns, relative to source directory, that match files and
56 | # directories to ignore when looking for source files.
57 | # This pattern also affects html_static_path and html_extra_path.
58 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
59 |
60 |
61 | # -- Options for HTML output -------------------------------------------------
62 |
63 | # The theme to use for HTML and HTML Help pages. See the documentation for
64 | # a list of builtin themes.
65 | # html_theme = 'alabaster'
66 | html_theme = "sphinx_rtd_theme"
67 |
68 | # Add any paths that contain custom static files (such as style sheets) here,
69 | # relative to this directory. They are copied after the builtin static files,
70 | # so a file named "default.css" will overwrite the builtin "default.css".
71 | html_static_path = ['_static']
72 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. Wenet documentation master file, created by
2 | sphinx-quickstart on Thu Dec 3 11:43:53 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to WeTTS's documentation!
7 | =================================
8 |
9 | Production First and Production Ready End-to-End Text-to-Speech Toolkit
10 |
11 | .. toctree::
12 | :maxdepth: 1
13 | :caption: Tutorial:
14 |
15 |
16 | Indices and tables
17 | ==================
18 |
19 | * :ref:`genindex`
20 | * :ref:`modindex`
21 | * :ref:`search`
22 |
--------------------------------------------------------------------------------
/examples/aishell-3/configs/v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 32,
11 | "fp16_run": false,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "max_wav_value": 32768.0,
21 | "sampling_rate": 44100,
22 | "filter_length": 1024,
23 | "hop_length": 256,
24 | "win_length": 1024,
25 | "n_mel_channels": 80,
26 | "mel_fmin": 0.0,
27 | "mel_fmax": null
28 | },
29 | "model": {
30 | "use_mel_posterior_encoder": false,
31 | "inter_channels": 192,
32 | "hidden_channels": 192,
33 | "filter_channels": 768,
34 | "n_heads": 2,
35 | "n_layers": 6,
36 | "kernel_size": 3,
37 | "p_dropout": 0.1,
38 | "resblock": "1",
39 | "upsample_rates": [8,8,2,2],
40 | "upsample_kernel_sizes": [16,16,4,4],
41 | "upsample_initial_channel": 512,
42 | "resblock_kernel_sizes": [3,7,11],
43 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
44 | "n_layers_q": 3,
45 | "use_spectral_norm": false,
46 | "gin_channels": 256
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/examples/aishell-3/configs/v2.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 32,
11 | "fp16_run": false,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "max_wav_value": 32768.0,
21 | "sampling_rate": 22050,
22 | "filter_length": 1024,
23 | "hop_length": 256,
24 | "win_length": 1024,
25 | "n_mel_channels": 80,
26 | "mel_fmin": 0.0,
27 | "mel_fmax": null
28 | },
29 | "model": {
30 | "use_mel_posterior_encoder": false,
31 | "inter_channels": 192,
32 | "hidden_channels": 192,
33 | "filter_channels": 768,
34 | "n_heads": 2,
35 | "n_layers": 6,
36 | "kernel_size": 3,
37 | "p_dropout": 0.1,
38 | "resblock": "1",
39 | "upsample_rates": [8,8,2,2],
40 | "upsample_kernel_sizes": [16,16,4,4],
41 | "upsample_initial_channel": 128,
42 | "resblock_kernel_sizes": [3,7,11],
43 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
44 | "n_layers_q": 3,
45 | "use_spectral_norm": false,
46 | "gin_channels": 256
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/examples/aishell-3/configs/v3.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 32,
11 | "fp16_run": true,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "max_wav_value": 32768.0,
21 | "sampling_rate": 16000,
22 | "filter_length": 1024,
23 | "hop_length": 256,
24 | "win_length": 1024,
25 | "n_mel_channels": 80,
26 | "mel_fmin": 0.0,
27 | "mel_fmax": null
28 | },
29 | "model": {
30 | "use_mel_posterior_encoder": false,
31 | "inter_channels": 192,
32 | "hidden_channels": 192,
33 | "filter_channels": 768,
34 | "n_heads": 2,
35 | "n_layers": 6,
36 | "kernel_size": 3,
37 | "p_dropout": 0.1,
38 | "resblock": "2",
39 | "upsample_rates": [8,8,4],
40 | "upsample_kernel_sizes": [16,16,8],
41 | "upsample_initial_channel": 256,
42 | "resblock_kernel_sizes": [3,5,7],
43 | "resblock_dilation_sizes": [[1,2], [2,6], [3,12]],
44 | "n_layers_q": 3,
45 | "use_sdp": false,
46 | "use_spectral_norm": false,
47 | "gin_channels": 256
48 | }
49 | }
50 |
--------------------------------------------------------------------------------
/examples/aishell-3/local/download_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2022 Binbin Zhang(binbzha@qq.com)
3 |
4 | if [ $# -ne 2 ]; then
5 | echo "Usage: $0 <url> <dir>"
6 | exit 1;
7 | fi
8 |
9 | url=$1
10 | dir=$2
11 |
12 | [ ! -d $dir ] && mkdir -p $dir
13 |
14 | # Download data
15 | if [ ! -f $dir/data_aishell3.tgz ]; then
16 | if ! which wget >/dev/null; then
17 | echo "$0: wget is not installed."
18 | exit 1;
19 | fi
20 | echo "$0: downloading data from $url. This may take some time, please wait"
21 |
22 | cd $dir
23 | if ! wget --no-check-certificate $url; then
24 | echo "$0: error executing wget $url"
25 | exit 1;
26 | fi
27 | fi
28 |
29 |
30 | cd $dir
31 | if ! tar -xvzf data_aishell3.tgz; then
32 | echo "$0: error un-tarring archive $dir/data_aishell3.tgz"
33 | exit 1;
34 | fi
35 |
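36 | # Illustrative invocation, using the mirror URL from this recipe's run.sh:
37 | #   local/download_data.sh \
38 | #     https://openslr.magicdatatech.com/resources/93/data_aishell3.tgz .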
--------------------------------------------------------------------------------
/examples/aishell-3/local/prepare_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2022 Binbin Zhang(binbzha@qq.com)
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import os
18 | import sys
19 |
20 | if len(sys.argv) != 4:
21 | print("Usage: prepare_data.py lexicon in_data_dir out_data")
22 | sys.exit(-1)
23 |
24 | lexicon = {}
25 | with open(sys.argv[1], "r", encoding="utf8") as fin:
26 | for line in fin:
27 | arr = line.strip().split()
28 | lexicon[arr[0]] = arr[1:]
29 |
30 | train_set_label_file = os.path.join(sys.argv[2], "train", "label_train-set.txt")
31 | with open(train_set_label_file, encoding="utf8") as fin, open(
32 | sys.argv[3], "w", encoding="utf8"
33 | ) as fout:
34 | # skip the first five lines in label_train-set.txt
35 | lines = [x.strip() for x in fin.readlines()][5:]
36 | for line in lines:
37 | key, text, _ = line.split("|")
38 | speaker = key[:-4]
39 | wav_path = os.path.join(
40 | sys.argv[2], "train", "wav", speaker, "{}.wav".format(key)
41 | )
42 | phones = []
43 | for x in text.split():
44 | if x == "%" or x == "$":
45 | phones.append(x)
46 | elif x in lexicon:
47 | phones.extend(lexicon[x])
48 | else:
49 | print("{} OOV {}".format(key, x))
50 | sys.exit(-1)
51 | fout.write("{}|{}|sil {}\n".format(wav_path, speaker, " ".join(phones)))
52 |
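53 | # Illustrative mapping (made-up label line, not real corpus content): for a
54 | # label_train-set.txt entry "SSB00050001|ni3 hao3 %|..." the speaker is the
55 | # key minus its last four characters (SSB0005), and the script emits:
56 | #   <in_data_dir>/train/wav/SSB0005/SSB00050001.wav|SSB0005|sil n i3 h ao3 %
57 | # where "n i3 h ao3" comes from the lexicon given as the first argument.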
--------------------------------------------------------------------------------
/examples/aishell-3/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copyright 2022 Jie Chen
4 | # Copyright 2022 Binbin Zhang(binbzha@qq.com)
5 |
6 | [ -f path.sh ] && . path.sh
7 |
8 | export CUDA_VISIBLE_DEVICES="0,1,2,3"
9 |
10 | stage=0 # start from -1 if you need to download data
11 | stop_stage=3
12 |
13 | dataset_url=https://openslr.magicdatatech.com/resources/93/data_aishell3.tgz
14 | dataset_dir=. # path to dataset directory
15 |
16 | dir=exp/v1 # training dir
17 | config=configs/v1.json
18 |
19 | data=data
20 | test_audio=test_audio
21 |
22 | . tools/parse_options.sh || exit 1;
23 |
24 |
25 | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
26 | # Download data
27 | local/download_data.sh $dataset_url $dataset_dir
28 | fi
29 |
30 |
31 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
32 | # Prepare data for training/validation
33 | mkdir -p $data
34 | python tools/gen_pinyin_lexicon.py \
35 | --with-zero-initial --with-tone --with-r \
36 | $data/lexicon.txt \
37 | $data/phones.list
38 | python local/prepare_data.py \
39 | $data/lexicon.txt \
40 | $dataset_dir/data_aishell3 \
41 | $data/all.txt
42 |
43 | # Compute spec length (optional, but recommended)
44 | python tools/compute_spec_length.py \
45 | $data/all.txt \
46 | $config \
47 | $data/all_spec_length.txt
48 | mv $data/all_spec_length.txt $data/all.txt
49 |
50 | cat $data/all.txt | awk -F '|' '{print $2}' | \
51 | sort | uniq | awk '{print $0, NR-1}' > $data/speaker.txt
52 | echo 'sil 0' > $data/phones.txt
53 | cat $data/all.txt | awk -F '|' '{print $3}' | \
54 | awk '{for (i=1;i<=NF;i++) print $i}' | sort | uniq | \
55 | grep -v 'sil' | awk '{print $0, NR}' >> $data/phones.txt
56 |
57 | # Split train/validation
58 | shuf --random-source=<(yes 777) $data/all.txt > $data/train.txt
59 | head -n 100 $data/train.txt > $data/val.txt
60 | sed -i '1,100d' $data/train.txt
61 | head -n 10 $data/train.txt > $data/test.txt
62 | sed -i '1,10d' $data/train.txt
63 | fi
64 |
65 |
66 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
67 | num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F ',' '{print NF}')
68 | torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
69 | vits/train.py -c $config -m $dir \
70 | --train_data $data/train.txt \
71 | --val_data $data/val.txt \
72 | --speaker_table $data/speaker.txt \
73 | --phone_table $data/phones.txt \
74 | --num_workers 8
75 | fi
76 |
77 |
78 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
79 | mkdir -p $test_audio
80 | python vits/inference.py --cfg $config \
81 | --speaker_table $data/speaker.txt \
82 | --phone_table $data/phones.txt \
83 | --checkpoint $dir/G_90000.pth \
84 | --test_file $data/test.txt \
85 | --outdir $test_audio
86 | fi
87 |
88 |
89 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
90 | mkdir -p $test_audio
91 | python vits/export_onnx.py --cfg $config \
92 | --speaker_table $data/speaker.txt \
93 | --phone_table $data/phones.txt \
94 | --checkpoint $dir/G_90000.pth \
95 | --onnx_model $dir/G_90000.onnx
96 |
97 | python vits/inference_onnx.py --cfg $config \
98 | --speaker_table $data/speaker.txt \
99 | --phone_table $data/phones.txt \
100 | --onnx_model $dir/G_90000.onnx \
101 | --test_file $data/test.txt \
102 | --outdir $test_audio
103 | fi
104 |
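105 | # Illustrative partial runs (the --stage/--stop_stage flags are parsed by
106 | # tools/parse_options.sh):
107 | #   bash run.sh --stage -1 --stop_stage -1   # download the data only
108 | #   bash run.sh --stage 1 --stop_stage 1 --dir exp/v1 --config configs/v1.json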
--------------------------------------------------------------------------------
/examples/aishell-3/tools:
--------------------------------------------------------------------------------
1 | ../../tools
--------------------------------------------------------------------------------
/examples/aishell-3/vits:
--------------------------------------------------------------------------------
1 | ../../wetts/vits
--------------------------------------------------------------------------------
/examples/baker/configs/v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 32,
11 | "fp16_run": false,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "max_wav_value": 32768.0,
21 | "sampling_rate": 22050,
22 | "filter_length": 1024,
23 | "hop_length": 256,
24 | "win_length": 1024,
25 | "n_mel_channels": 80,
26 | "mel_fmin": 0.0,
27 | "mel_fmax": null
28 | },
29 | "model": {
30 | "use_mel_posterior_encoder": false,
31 | "inter_channels": 192,
32 | "hidden_channels": 192,
33 | "filter_channels": 768,
34 | "n_heads": 2,
35 | "n_layers": 6,
36 | "kernel_size": 3,
37 | "p_dropout": 0.1,
38 | "resblock": "1",
39 | "upsample_rates": [8,8,2,2],
40 | "upsample_kernel_sizes": [16,16,4,4],
41 | "upsample_initial_channel": 512,
42 | "resblock_kernel_sizes": [3,7,11],
43 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
44 | "n_layers_q": 3,
45 | "use_spectral_norm": false,
46 | "gin_channels": 256,
47 | "use_wd": true,
48 | "slm_model": "exp/slm/wavlm-base-plus",
49 | "slm_sr": 16000,
50 | "slm_hidden": 768,
51 | "slm_nlayers": 13,
52 | "slm_initial_channel": 64
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/examples/baker/configs/v2.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 32,
11 | "fp16_run": false,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "max_wav_value": 32768.0,
21 | "sampling_rate": 22050,
22 | "filter_length": 1024,
23 | "hop_length": 256,
24 | "win_length": 1024,
25 | "n_mel_channels": 80,
26 | "mel_fmin": 0.0,
27 | "mel_fmax": null
28 | },
29 | "model": {
30 | "use_mel_posterior_encoder": false,
31 | "inter_channels": 192,
32 | "hidden_channels": 192,
33 | "filter_channels": 768,
34 | "n_heads": 2,
35 | "n_layers": 6,
36 | "kernel_size": 3,
37 | "p_dropout": 0.1,
38 | "resblock": "1",
39 | "upsample_rates": [8,8,2,2],
40 | "upsample_kernel_sizes": [16,16,4,4],
41 | "upsample_initial_channel": 128,
42 | "resblock_kernel_sizes": [3,7,11],
43 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
44 | "n_layers_q": 3,
45 | "use_spectral_norm": false,
46 | "gin_channels": 256
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/examples/baker/configs/v3.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 32,
11 | "fp16_run": true,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "max_wav_value": 32768.0,
21 | "sampling_rate": 16000,
22 | "filter_length": 1024,
23 | "hop_length": 256,
24 | "win_length": 1024,
25 | "n_mel_channels": 80,
26 | "mel_fmin": 0.0,
27 | "mel_fmax": null
28 | },
29 | "model": {
30 | "use_mel_posterior_encoder": false,
31 | "inter_channels": 192,
32 | "hidden_channels": 192,
33 | "filter_channels": 768,
34 | "n_heads": 2,
35 | "n_layers": 6,
36 | "kernel_size": 3,
37 | "p_dropout": 0.1,
38 | "resblock": "2",
39 | "upsample_rates": [8,8,4],
40 | "upsample_kernel_sizes": [16,16,8],
41 | "upsample_initial_channel": 256,
42 | "resblock_kernel_sizes": [3,5,7],
43 | "resblock_dilation_sizes": [[1,2], [2,6], [3,12]],
44 | "n_layers_q": 3,
45 | "use_sdp": false,
46 | "use_spectral_norm": false,
47 | "gin_channels": 256
48 | }
49 | }
50 |
--------------------------------------------------------------------------------
/examples/baker/configs/vits2_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 32,
11 | "fp16_run": false,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "use_mel_posterior_encoder": true,
21 | "max_wav_value": 32768.0,
22 | "sampling_rate": 22050,
23 | "filter_length": 1024,
24 | "hop_length": 256,
25 | "win_length": 1024,
26 | "n_mel_channels": 80,
27 | "mel_fmin": 0.0,
28 | "mel_fmax": null
29 | },
30 | "model": {
31 | "use_mel_posterior_encoder": true,
32 | "use_transformer_flows": true,
33 | "transformer_flow_type": "pre_conv",
34 | "use_spk_conditioned_encoder": false,
35 | "use_noise_scaled_mas": true,
36 | "use_duration_discriminator": true,
37 | "inter_channels": 192,
38 | "hidden_channels": 192,
39 | "filter_channels": 768,
40 | "n_heads": 2,
41 | "n_layers": 6,
42 | "kernel_size": 3,
43 | "p_dropout": 0.1,
44 | "resblock": "1",
45 | "resblock_kernel_sizes": [3,7,11],
46 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
47 | "upsample_rates": [8,8,2,2],
48 | "upsample_initial_channel": 512,
49 | "upsample_kernel_sizes": [16,16,4,4],
50 | "n_layers_q": 3,
51 | "use_sdp": true,
52 | "use_spectral_norm": false,
53 | "gin_channels": 256,
54 | "use_wd": true,
55 | "slm_model": "exp/slm/wavlm-base-plus",
56 | "slm_sr": 16000,
57 | "slm_hidden": 768,
58 | "slm_nlayers": 13,
59 | "slm_initial_channel": 64
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
/examples/baker/configs/vits2_vocos_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 32,
11 | "fp16_run": false,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "use_mel_posterior_encoder": true,
21 | "max_wav_value": 32768.0,
22 | "sampling_rate": 24000,
23 | "filter_length": 1024,
24 | "hop_length": 256,
25 | "win_length": 1024,
26 | "n_mel_channels": 100,
27 | "mel_fmin": 0.0,
28 | "mel_fmax": null
29 | },
30 | "model": {
31 | "vocoder_type": "vocos",
32 | "use_mrd_disc": true,
33 | "use_mel_posterior_encoder": true,
34 | "use_transformer_flows": true,
35 | "transformer_flow_type": "pre_conv",
36 | "use_spk_conditioned_encoder": false,
37 | "use_noise_scaled_mas": true,
38 | "use_duration_discriminator": true,
39 | "inter_channels": 192,
40 | "hidden_channels": 192,
41 | "filter_channels": 768,
42 | "n_heads": 2,
43 | "n_layers": 6,
44 | "kernel_size": 3,
45 | "p_dropout": 0.1,
46 | "vocos_channels": 512,
47 | "vocos_h_channels": 1536,
48 | "vocos_out_channels": 1026,
49 | "vocos_num_layers": 8,
50 | "vocos_istft_config": {
51 | "n_fft": 1024,
52 | "hop_length": 256,
53 | "win_length": 1024,
54 | "center": true
55 | },
56 | "resblock": "1",
57 | "resblock_kernel_sizes": [3,7,11],
58 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
59 | "upsample_rates": [8,8,2,2],
60 | "upsample_initial_channel": 512,
61 | "upsample_kernel_sizes": [16,16,4,4],
62 | "n_layers_q": 3,
63 | "use_sdp": true,
64 | "use_spectral_norm": false,
65 | "gin_channels": 256
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/examples/baker/configs/vocos.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [
9 | 0.8,
10 | 0.99
11 | ],
12 | "eps": 1e-9,
13 | "batch_size": 32,
14 | "fp16_run": true,
15 | "lr_decay": 0.999875,
16 | "segment_size": 8192,
17 | "init_lr_ratio": 1,
18 | "warmup_epochs": 0,
19 | "c_mel": 45,
20 | "c_kl": 1.0
21 | },
22 | "data": {
23 | "max_wav_value": 32768.0,
24 | "sampling_rate": 16000,
25 | "filter_length": 1024,
26 | "hop_length": 256,
27 | "win_length": 1024,
28 | "n_mel_channels": 80,
29 | "mel_fmin": 0.0,
30 | "mel_fmax": null
31 | },
32 | "model": {
33 | "use_mel_posterior_encoder": false,
34 | "vocoder_type": "vocos",
35 | "inter_channels": 192,
36 | "hidden_channels": 192,
37 | "filter_channels": 768,
38 | "n_heads": 2,
39 | "n_layers": 6,
40 | "kernel_size": 3,
41 | "p_dropout": 0.1,
42 | "vocos_channels": 512,
43 | "vocos_h_channels": 1536,
44 | "vocos_out_channels": 1026,
45 | "vocos_num_layers": 8,
46 | "vocos_istft_config": {
47 | "n_fft": 1024,
48 | "hop_length": 256,
49 | "win_length": 1024,
50 | "center": true
51 | },
52 | "resblock": "1",
53 | "resblock_kernel_sizes": [3,7,11],
54 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
55 | "upsample_rates": [8,8,2,2],
56 | "upsample_initial_channel": 512,
57 | "upsample_kernel_sizes": [16,16,4,4],
58 | "n_layers_q": 3,
59 | "use_sdp": false,
60 | "use_spectral_norm": false,
61 | "gin_channels": 256
62 | }
63 | }
--------------------------------------------------------------------------------
/examples/baker/local/prepare_data.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import re
3 |
4 | lexicon = {}
5 |
6 | with open(sys.argv[1], "r", encoding="utf8") as fin:
7 | for line in fin:
8 | arr = line.strip().split()
9 | lexicon[arr[0]] = arr[1:]
10 |
11 | with open(sys.argv[2], "r", encoding="utf8") as fin:
12 | lines = fin.readlines()
13 | for i in range(0, len(lines), 2):
14 | key = lines[i][:6]
15 | content = lines[i][7:].strip()
16 | content = re.sub("[。,、“”?:……!( )—;]", "", content)
17 | if "P" in content: # ignore utt 002365
18 | continue
19 | chars = []
20 | prosody = {}
21 |
22 | j = 0
23 | while j < len(content):
24 | if content[j] == "#":
25 | prosody[len(chars) - 1] = content[j : j + 2]
26 | j += 2
27 | else:
28 | chars.append(content[j])
29 | j += 1
30 | if key == "005107":
31 | lines[i + 1] = lines[i + 1].replace(" ng1", " en1")
32 | syllable = lines[i + 1].strip().split()
33 | s_index = 0
34 | phones = []
35 | for k, char in enumerate(chars):
36 | # Handle erhua (儿化): a "儿" that merges into the preceding syllable
37 | er_flag = False
38 | if char == "儿" and (
39 | s_index == len(syllable) or syllable[s_index][0:2] != "er"
40 | ):
41 | er_flag = True
42 | else:
43 | phones.extend(lexicon[syllable[s_index]])
44 | s_index += 1
45 | if k in prosody:
46 | if er_flag:
47 | phones[-1] = prosody[k]
48 | else:
49 | phones.append(prosody[k])
50 | else:
51 | phones.append("#0")
52 | print("{}/{}.wav|baker|sil {}".format(sys.argv[3], key, " ".join(phones)))
53 |
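54 | # Illustrative example (made-up annotation, not real corpus content): for a
55 | # label "小明#1上学#4" with pinyin "xiao3 ming2 shang4 xue2", each character's
56 | # lexicon phones are followed by its prosody mark (#0 when unmarked), giving:
57 | #   <wave_dir>/000001.wav|baker|sil x iao3 #0 m ing2 #1 sh ang4 #0 x ue2 #4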
--------------------------------------------------------------------------------
/examples/baker/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copyright 2022 Binbin Zhang(binbzha@qq.com)
4 |
5 | [ -f path.sh ] && . path.sh
6 |
7 | export CUDA_VISIBLE_DEVICES="0,1,2,3" # specify your gpu id for training
8 |
9 | stage=0 # start from -1 if you need to download data
10 | stop_stage=3
11 |
12 | dir=exp/v3 # training dir
13 | config=configs/v3.json
14 |
15 | # Please download data from https://www.data-baker.com/data/index/TNtts, and
16 | # set `raw_data_dir` to your data.
17 | raw_data_dir=. # path to dataset directory
18 | data=data
19 | test_audio=test_audio
20 | ckpt_step=200000
21 |
22 | . tools/parse_options.sh || exit 1;
23 |
24 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
25 | # Prepare data for training/validation
26 | mkdir -p $data
27 | python tools/gen_pinyin_lexicon.py \
28 | --with-zero-initial --with-tone --with-r \
29 | $data/lexicon.txt \
30 | $data/phones.list
31 | python local/prepare_data.py \
32 | $data/lexicon.txt \
33 | $raw_data_dir/ProsodyLabeling/000001-010000.txt \
34 | $raw_data_dir/Wave > $data/all.txt
35 |
36 | cat $data/all.txt | awk -F '|' '{print $2}' | \
37 | sort | uniq | awk '{print $0, NR-1}' > $data/speaker.txt
38 | echo 'sil 0' > $data/phones.txt
39 | cat $data/all.txt | awk -F '|' '{print $3}' | \
40 | awk '{for (i=1;i<=NF;i++) print $i}' | sort | uniq | \
41 | grep -v 'sil' | awk '{print $0, NR}' >> $data/phones.txt
42 |
43 | # Split train/validation
44 | shuf --random-source=<(yes 777) $data/all.txt > $data/train.txt
45 | head -n 100 $data/train.txt > $data/val.txt
46 | sed -i '1,100d' $data/train.txt
47 | head -n 10 $data/train.txt > $data/test.txt
48 | sed -i '1,10d' $data/train.txt
49 | fi
50 |
51 |
52 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
53 | num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F ',' '{print NF}')
54 | torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
55 | vits/train.py -c $config -m $dir \
56 | --train_data $data/train.txt \
57 | --val_data $data/val.txt \
58 | --speaker_table $data/speaker.txt \
59 | --phone_table $data/phones.txt \
60 | --num_workers 8
61 | fi
62 |
63 |
64 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
65 | mkdir -p $test_audio
66 | python vits/inference.py --cfg $config \
67 | --speaker_table $data/speaker.txt \
68 | --phone_table $data/phones.txt \
69 | --checkpoint $dir/G_$ckpt_step.pth \
70 | --test_file $data/test.txt \
71 | --outdir $test_audio
72 | fi
73 |
74 |
75 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
76 | mkdir -p $test_audio
77 | python vits/export_onnx.py --cfg $config \
78 | --speaker_table $data/speaker.txt \
79 | --phone_table $data/phones.txt \
80 | --checkpoint $dir/G_$ckpt_step.pth \
81 | --onnx_model $dir/G_$ckpt_step.onnx
82 |
83 | python vits/inference_onnx.py --cfg $config \
84 | --speaker_table $data/speaker.txt \
85 | --phone_table $data/phones.txt \
86 | --onnx_model $dir/G_$ckpt_step.onnx \
87 | --test_file $data/test.txt \
88 | --outdir $test_audio
89 | fi
90 |
91 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
92 | mkdir -p $test_audio
93 | python vits/export_onnx.py --cfg $config \
94 | --streaming \
95 | --speaker_table $data/speaker.txt \
96 | --phone_table $data/phones.txt \
97 | --checkpoint $dir/G_$ckpt_step.pth \
98 | --onnx_model $dir/G_$ckpt_step.onnx
99 |
100 | python vits/inference_onnx.py --cfg $config \
101 | --streaming \
102 | --speaker_table $data/speaker.txt \
103 | --phone_table $data/phones.txt \
104 | --onnx_model $dir/G_$ckpt_step.onnx \
105 | --test_file $data/test.txt \
106 | --outdir $test_audio
107 | fi
108 |
--------------------------------------------------------------------------------
/examples/baker/tools:
--------------------------------------------------------------------------------
1 | ../../tools
--------------------------------------------------------------------------------
/examples/baker/vits:
--------------------------------------------------------------------------------
1 | ../../wetts/vits
--------------------------------------------------------------------------------
/examples/chinese_prosody_polyphone/README.md:
--------------------------------------------------------------------------------
1 | ## Model Method
2 |
3 | Please see [doc](../../wetts/frontend/README.md) for details.
4 |
5 | ## Data Description
6 |
7 | Here are the details of the prosody and polyphone data used in the recipe.
8 | The data are either collected from the web or contributed by the community.
9 |
10 |
11 | ### Polyphone
12 |
13 | | corpus | number | source or contributors |
14 | |--------|--------|------------------------------------|
15 | | g2pM | 100000 | https://github.com/kakaobrain/g2pM |
16 | | | | |
17 |
18 | TODO(Binbin Zhang): Add more data
19 |
20 |
21 | ### Prosody
22 |
23 | | corpus | number | source or contributors |
24 | |---------|--------|---------------------------------------------|
25 | | biaobei | 10000 | https://www.data-baker.com/open_source.html |
26 | | | | |
27 |
28 | TODO(Binbin Zhang): Add more data
29 |
30 | ## Benchmark
31 |
32 | BERT-MLT refers to joint (multi-task) training of polyphone and prosody.
33 |
34 | ### Polyphone
35 |
36 | | system | ACC |
37 | |----------------|--------|
38 | | BERT-polyphone | 0.9778 |
39 | | BERT-MLT | 0.9797 |
40 |
41 |
42 | ### Prosody
43 |
44 | | system | PW-F1 | PPH-F1 | IPH-F1 |
45 | |---------------------------|--------|--------|--------|
46 | | BERT-prosody | 0.9308 | 0.8058 | 0.8596 |
47 | | BERT-MLT | 0.9334 | 0.8088 | 0.8559 |
48 | | BERT-prosody (exclude #4) | 0.9233 | 0.7074 | 0.6120 |
49 | | BERT-MLT (exclude #4) | 0.9261 | 0.7146 | 0.6140 |
50 |
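51 | For reference, a minimal sketch of how the metrics above can be computed
52 | (labels are illustrative and scikit-learn is assumed; this snippet is not
53 | part of the recipe):
54 |
55 | ``` python
56 | from sklearn.metrics import accuracy_score, f1_score
57 |
58 | # Polyphone: accuracy over the predicted pinyin of polyphonic characters.
59 | gold_pinyin = ["le5", "zhao2", "chang2"]
60 | pred_pinyin = ["le5", "zhao2", "zhang3"]
61 | print("ACC:", accuracy_score(gold_pinyin, pred_pinyin))
62 |
63 | # Prosody: per-tier F1 over binary boundary labels, e.g. prosodic word (PW).
64 | gold_pw = [0, 1, 0, 0, 1]
65 | pred_pw = [0, 1, 0, 1, 1]
66 | print("PW-F1:", f1_score(gold_pw, pred_pw))
67 | ```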
--------------------------------------------------------------------------------
/examples/chinese_prosody_polyphone/frontend:
--------------------------------------------------------------------------------
1 | ../../wetts/frontend
--------------------------------------------------------------------------------
/examples/chinese_prosody_polyphone/lexicon/prosody.txt:
--------------------------------------------------------------------------------
1 | #0
2 | #1
3 | #2
4 | #3
5 | #4
6 |
--------------------------------------------------------------------------------
/examples/chinese_prosody_polyphone/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2022 Binbin Zhang(binbzha@qq.com)
3 |
4 | stage=0
5 | stop_stage=4
6 | url=https://wetts-1256283475.cos.ap-shanghai.myqcloud.com/data
7 |
8 | dir=exp
9 |
10 | . tools/parse_options.sh
11 |
12 |
13 | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
14 | # Download prosody and polyphone
15 | mkdir -p data/download
16 | pushd data/download
17 | wget -c $url/polyphone.tar.gz && tar zxf polyphone.tar.gz
18 | wget -c $url/prosody.tar.gz && tar zxf prosody.tar.gz
19 | popd
20 | fi
21 |
22 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
23 | # Combine prosody data
24 | mkdir -p data/prosody
25 | cat data/download/prosody/biaobei/train.txt > data/prosody/train.txt
26 | cat data/download/prosody/biaobei/cv.txt > data/prosody/cv.txt
27 | # Combine polyphone data
28 | mkdir -p data/polyphone
29 | cat data/download/polyphone/g2pM/train.txt > data/polyphone/train.txt
30 | cat data/download/polyphone/g2pM/dev.txt > data/polyphone/cv.txt
31 | cat data/download/polyphone/g2pM/test.txt > data/polyphone/test.txt
32 | fi
33 |
34 |
35 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
36 | mkdir -p $dir
37 | python frontend/train.py \
38 | --gpu 2 \
39 | --lr 0.001 \
40 | --num_epochs 10 \
41 | --batch_size 32 \
42 | --log_interval 10 \
43 | --polyphone_weight 0.1 \
44 | --polyphone_dict lexicon/polyphone.txt \
45 | --train_polyphone_data data/polyphone/train.txt \
46 | --cv_polyphone_data data/polyphone/cv.txt \
47 | --prosody_dict lexicon/prosody.txt \
48 | --train_prosody_data data/prosody/train.txt \
49 | --cv_prosody_data data/prosody/cv.txt \
50 | --model_dir $dir
51 | fi
52 |
53 |
54 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
55 | # Test polyphone, metric: accuracy
56 | python frontend/test_polyphone.py \
57 | --polyphone_dict lexicon/polyphone.txt \
58 | --prosody_dict lexicon/prosody.txt \
59 | --test_data data/polyphone/test.txt \
60 | --batch_size 32 \
61 | --checkpoint $dir/9.pt
62 |
63 | # Test prosody, metric: F1-score
64 | python frontend/test_prosody.py \
65 | --polyphone_dict lexicon/polyphone.txt \
66 | --prosody_dict lexicon/prosody.txt \
67 | --test_data data/prosody/cv.txt \
68 | --batch_size 32 \
69 | --checkpoint $dir/9.pt
70 | fi
71 |
72 |
73 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
74 | # export onnx model
75 | python frontend/export_onnx.py \
76 | --polyphone_dict lexicon/polyphone.txt \
77 | --prosody_dict lexicon/prosody.txt \
78 | --checkpoint $dir/9.pt \
79 | --onnx_model $dir/9.onnx
80 | fi
81 |
82 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
83 | # g2p
84 | # text: 八方财宝进
85 | # pinyin ['ba1', 'fang1', 'cai2', 'bao3', 'jin4']
86 | # prosody [0 1 0 0 4]
87 | python frontend/g2p_prosody.py \
88 | --text "八方财宝进" \
89 | --hanzi2pinyin_file lexicon/pinyin_dict.txt \
90 | --polyphone_file lexicon/polyphone.txt \
91 | --polyphone_prosody_model $dir/9.onnx
92 | fi
93 |
--------------------------------------------------------------------------------
/examples/chinese_prosody_polyphone/tools:
--------------------------------------------------------------------------------
1 | ../../tools
--------------------------------------------------------------------------------
/examples/ljspeech/configs/v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 32,
11 | "fp16_run": false,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "max_wav_value": 32768.0,
21 | "sampling_rate": 22050,
22 | "filter_length": 1024,
23 | "hop_length": 256,
24 | "win_length": 1024,
25 | "n_mel_channels": 80,
26 | "mel_fmin": 0.0,
27 | "mel_fmax": null
28 | },
29 | "model": {
30 | "use_mel_posterior_encoder": false,
31 | "inter_channels": 192,
32 | "hidden_channels": 192,
33 | "filter_channels": 768,
34 | "n_heads": 2,
35 | "n_layers": 6,
36 | "kernel_size": 3,
37 | "p_dropout": 0.1,
38 | "resblock": "1",
39 | "upsample_rates": [8,8,2,2],
40 | "upsample_kernel_sizes": [16,16,4,4],
41 | "upsample_initial_channel": 512,
42 | "resblock_kernel_sizes": [3,7,11],
43 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
44 | "n_layers_q": 3,
45 | "use_spectral_norm": false,
46 | "gin_channels": 256
47 | }
48 | }
49 |
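One internal consistency worth knowing when editing these configs: the decoder upsamples each latent frame back to `hop_length` waveform samples, so the product of `upsample_rates` must equal `hop_length` (here 8 x 8 x 2 x 2 = 256). A standalone sanity check (not part of the recipe), assuming it is run from the example directory:

```python
import json
import math

# prod(upsample_rates) must equal hop_length, otherwise the generator's
# output length will not match the spectrogram frame rate.
with open("configs/v1.json") as f:
    cfg = json.load(f)
rates = cfg["model"]["upsample_rates"]
hop = cfg["data"]["hop_length"]
assert math.prod(rates) == hop, (rates, hop)
print("OK:", rates, "multiply to", hop)
```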
--------------------------------------------------------------------------------
/examples/ljspeech/configs/v2.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 32,
11 | "fp16_run": false,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "max_wav_value": 32768.0,
21 | "sampling_rate": 22050,
22 | "filter_length": 1024,
23 | "hop_length": 256,
24 | "win_length": 1024,
25 | "n_mel_channels": 80,
26 | "mel_fmin": 0.0,
27 | "mel_fmax": null
28 | },
29 | "model": {
30 | "use_mel_posterior_encoder": false,
31 | "inter_channels": 192,
32 | "hidden_channels": 192,
33 | "filter_channels": 768,
34 | "n_heads": 2,
35 | "n_layers": 6,
36 | "kernel_size": 3,
37 | "p_dropout": 0.1,
38 | "resblock": "1",
39 | "upsample_rates": [8,8,2,2],
40 | "upsample_kernel_sizes": [16,16,4,4],
41 | "upsample_initial_channel": 128,
42 | "resblock_kernel_sizes": [3,7,11],
43 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
44 | "n_layers_q": 3,
45 | "use_spectral_norm": false,
46 | "gin_channels": 256
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/examples/ljspeech/configs/v3.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 32,
11 | "fp16_run": true,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "max_wav_value": 32768.0,
21 | "sampling_rate": 16000,
22 | "filter_length": 1024,
23 | "hop_length": 256,
24 | "win_length": 1024,
25 | "n_mel_channels": 80,
26 | "mel_fmin": 0.0,
27 | "mel_fmax": null
28 | },
29 | "model": {
30 | "use_mel_posterior_encoder": false,
31 | "inter_channels": 192,
32 | "hidden_channels": 192,
33 | "filter_channels": 768,
34 | "n_heads": 2,
35 | "n_layers": 6,
36 | "kernel_size": 3,
37 | "p_dropout": 0.1,
38 | "resblock": "2",
39 | "upsample_rates": [8,8,4],
40 | "upsample_kernel_sizes": [16,16,8],
41 | "upsample_initial_channel": 256,
42 | "resblock_kernel_sizes": [3,5,7],
43 | "resblock_dilation_sizes": [[1,2], [2,6], [3,12]],
44 | "n_layers_q": 3,
45 | "use_sdp": false,
46 | "use_spectral_norm": false,
47 | "gin_channels": 256
48 | }
49 | }
50 |
--------------------------------------------------------------------------------
/examples/ljspeech/local/download_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2022 Binbin Zhang(binbzha@qq.com)
3 |
4 | if [ $# -ne 2 ]; then
5 | echo "Usage: $0 "
6 | exit 0;
7 | fi
8 |
9 | url=$1
10 | dir=$2
11 |
12 | [ ! -d $dir ] && mkdir -p $dir
13 |
14 | # Download data
15 | if [ ! -f $dir/LJSpeech-1.1.tar.bz2 ]; then
16 | if ! which wget >/dev/null; then
17 | echo "$0: wget is not installed."
18 | exit 1;
19 | fi
20 | echo "$0: downloading data from $url. This may take some time, please wait"
21 |
22 | cd $dir
23 | if ! wget --no-check-certificate $url; then
24 | echo "$0: error executing wget $url"
25 | exit 1;
26 | fi
27 | fi
28 |
29 |
30 | cd $dir
31 | if ! tar -xvf LJSpeech-1.1.tar.bz2; then
32 | echo "$0: error un-tarring archive $dir/LJSpeech-1.1.tar.bz2"
33 | exit 1;
34 | fi
35 |
--------------------------------------------------------------------------------
/examples/ljspeech/local/prepare_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2022 Binbin Zhang(binbzha@qq.com)
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import argparse
18 | import csv
19 | import os
20 |
21 | from tools.cleaners import english_cleaners
22 |
23 |
24 | def get_args():
25 | parser = argparse.ArgumentParser(description="prepare data")
26 | parser.add_argument("--data_dir", required=True, help="input data dir")
27 | parser.add_argument("--output", required=True, help="output file")
28 | parser.add_argument("--use_prosody", type=lambda s: s.lower() == "true", default=True, help="whether to use prosody")
29 | args = parser.parse_args()
30 | return args
31 |
32 |
33 | def main():
34 | args = get_args()
35 |
36 | metadata = os.path.join(args.data_dir, "metadata.csv")
37 | with open(metadata) as fin, open(args.output, "w", encoding="utf8") as fout:
38 | for row in csv.reader(fin, delimiter="|"):
39 | wav_path = os.path.join(args.data_dir, f"wavs/{row[0]}.wav")
40 | phones = english_cleaners(row[-1], args.use_prosody)
41 | fout.write("{}|ljspeech|sil {}\n".format(wav_path, " ".join(phones)))
42 |
43 |
44 | if __name__ == "__main__":
45 | main()
46 |
--------------------------------------------------------------------------------
/examples/ljspeech/path.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=.:$PYTHONPATH
--------------------------------------------------------------------------------
/examples/ljspeech/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copyright 2022 Binbin Zhang(binbzha@qq.com)
4 |
5 | [ -f path.sh ] && . path.sh
6 |
7 | export CUDA_VISIBLE_DEVICES="0,1,2,3"
8 |
9 | stage=0 # start from -1 if you need to download data
10 | stop_stage=3
11 |
12 | dataset_url=https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
13 | dataset_dir=. # path to dataset directory
14 |
15 | dir=exp/v3 # training dir
16 | config=configs/v3.json
17 |
18 | data=data
19 | test_audio=test_audio
20 |
21 | . tools/parse_options.sh || exit 1;
22 |
23 |
24 | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
25 | # Download data
26 | local/download_data.sh $dataset_url $dataset_dir
27 | fi
28 |
29 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
30 | # Prepare data for training/validation
31 | mkdir -p $data
32 | python local/prepare_data.py \
33 | --data_dir $(realpath $dataset_dir)/LJSpeech-1.1 \
34 | --output $data/out.txt
35 | sed 's/#[0-9] //g' $data/out.txt > $data/all.txt
36 |
37 | cat $data/all.txt | awk -F '|' '{print $2}' | \
38 | sort | uniq | awk '{print $0, NR-1}' > $data/speaker.txt
39 | echo 'sil 0' > $data/phones.txt
40 | cat $data/all.txt | awk -F '|' '{print $3}' | \
41 | awk '{for (i=1;i<=NF;i++) print $i}' | sort | uniq | \
42 | grep -v 'sil' | awk '{print $0, NR}' >> $data/phones.txt
43 |
44 | # Split train/validation
45 | shuf --random-source=<(yes 777) $data/all.txt > $data/train.txt
46 | head -n 100 $data/train.txt > $data/val.txt
47 | sed -i '1,100d' $data/train.txt
48 | head -n 10 $data/train.txt > $data/test.txt
49 | sed -i '1,10d' $data/train.txt
50 | fi
51 |
52 |
53 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
54 | num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F ',' '{print NF}')
55 | torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
56 | vits/train.py -c $config -m $dir \
57 | --train_data $data/train.txt \
58 | --val_data $data/val.txt \
59 | --speaker_table $data/speaker.txt \
60 | --phone_table $data/phones.txt \
61 | --num_workers 8
62 | fi
63 |
64 |
65 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
66 | mkdir -p $test_audio
67 | python vits/inference.py --cfg $config \
68 | --speaker_table $data/speaker.txt \
69 | --phone_table $data/phones.txt \
70 | --checkpoint $dir/G_90000.pth \
71 | --test_file $data/test.txt \
72 | --outdir $test_audio
73 | fi
74 |
75 |
76 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
77 | mkdir -p $test_audio
78 | python vits/export_onnx.py --cfg $config \
79 | --speaker_table $data/speaker.txt \
80 | --phone_table $data/phones.txt \
81 | --checkpoint $dir/G_90000.pth \
82 | --onnx_model $dir/G_90000.onnx
83 |
84 | python vits/inference_onnx.py --cfg $config \
85 | --speaker_table $data/speaker.txt \
86 | --phone_table $data/phones.txt \
87 | --onnx_model $dir/G_90000.onnx \
88 | --test_file $data/test.txt \
89 | --outdir $test_audio
90 | fi
91 |
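Stage 0 writes `speaker.txt` and `phones.txt` as simple `symbol id` tables, with `sil` pinned to id 0. A minimal sketch of reading them back (the helper name is illustrative, not part of the recipe):

```python
def load_table(path):
    """Parse a 'symbol id' table into a dict of symbol -> int id."""
    table = {}
    with open(path, encoding="utf8") as fin:
        for line in fin:
            sym, idx = line.split()
            table[sym] = int(idx)
    return table

phones = load_table("data/phones.txt")
assert phones["sil"] == 0  # stage 0 writes 'sil 0' first
```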
--------------------------------------------------------------------------------
/examples/ljspeech/tools:
--------------------------------------------------------------------------------
1 | ../../tools
--------------------------------------------------------------------------------
/examples/ljspeech/vits:
--------------------------------------------------------------------------------
1 | ../../wetts/vits
--------------------------------------------------------------------------------
/examples/multilingual/configs/v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 32,
11 | "fp16_run": false,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "max_wav_value": 32768.0,
21 | "sampling_rate": 22050,
22 | "filter_length": 1024,
23 | "hop_length": 256,
24 | "win_length": 1024,
25 | "n_mel_channels": 80,
26 | "mel_fmin": 0.0,
27 | "mel_fmax": null
28 | },
29 | "model": {
30 | "use_mel_posterior_encoder": false,
31 | "inter_channels": 192,
32 | "hidden_channels": 192,
33 | "filter_channels": 768,
34 | "n_heads": 2,
35 | "n_layers": 6,
36 | "kernel_size": 3,
37 | "p_dropout": 0.1,
38 | "resblock": "1",
39 | "upsample_rates": [8,8,2,2],
40 | "upsample_kernel_sizes": [16,16,4,4],
41 | "upsample_initial_channel": 512,
42 | "resblock_kernel_sizes": [3,7,11],
43 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
44 | "n_layers_q": 3,
45 | "use_spectral_norm": false,
46 | "gin_channels": 256
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/examples/multilingual/configs/v2.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 32,
11 | "fp16_run": false,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "max_wav_value": 32768.0,
21 | "sampling_rate": 22050,
22 | "filter_length": 1024,
23 | "hop_length": 256,
24 | "win_length": 1024,
25 | "n_mel_channels": 80,
26 | "mel_fmin": 0.0,
27 | "mel_fmax": null
28 | },
29 | "model": {
30 | "use_mel_posterior_encoder": false,
31 | "inter_channels": 192,
32 | "hidden_channels": 192,
33 | "filter_channels": 768,
34 | "n_heads": 2,
35 | "n_layers": 6,
36 | "kernel_size": 3,
37 | "p_dropout": 0.1,
38 | "resblock": "1",
39 | "upsample_rates": [8,8,2,2],
40 | "upsample_kernel_sizes": [16,16,4,4],
41 | "upsample_initial_channel": 128,
42 | "resblock_kernel_sizes": [3,7,11],
43 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
44 | "n_layers_q": 3,
45 | "use_spectral_norm": false,
46 | "gin_channels": 256
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/examples/multilingual/configs/v3.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 32,
11 | "fp16_run": true,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "max_wav_value": 32768.0,
21 | "sampling_rate": 16000,
22 | "filter_length": 1024,
23 | "hop_length": 256,
24 | "win_length": 1024,
25 | "n_mel_channels": 80,
26 | "mel_fmin": 0.0,
27 | "mel_fmax": null
28 | },
29 | "model": {
30 | "use_mel_posterior_encoder": false,
31 | "inter_channels": 192,
32 | "hidden_channels": 192,
33 | "filter_channels": 768,
34 | "n_heads": 2,
35 | "n_layers": 6,
36 | "kernel_size": 3,
37 | "p_dropout": 0.1,
38 | "resblock": "2",
39 | "upsample_rates": [8,8,4],
40 | "upsample_kernel_sizes": [16,16,8],
41 | "upsample_initial_channel": 256,
42 | "resblock_kernel_sizes": [3,5,7],
43 | "resblock_dilation_sizes": [[1,2], [2,6], [3,12]],
44 | "n_layers_q": 3,
45 | "use_sdp": false,
46 | "use_spectral_norm": false,
47 | "gin_channels": 256
48 | }
49 | }
50 |
--------------------------------------------------------------------------------
/examples/multilingual/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copyright 2022 Binbin Zhang(binbzha@qq.com)
4 |
5 | [ -f path.sh ] && . path.sh
6 |
7 | export CUDA_VISIBLE_DEVICES="0,1,2,3"
8 |
9 | stage=0 # data preparation; audio lists come from the baker and ljspeech examples
10 | stop_stage=3
11 |
12 | dir=exp/v3 # training dir
13 | config=configs/v3.json
14 |
15 | data=data
16 | test_audio=test_audio
17 |
18 | . tools/parse_options.sh || exit 1;
19 |
20 |
21 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
22 | mkdir -p $data
23 | cat ../baker/$data/all.txt \
24 | ../ljspeech/$data/out.txt > $data/all.txt
25 |
26 | cat $data/all.txt | awk -F '|' '{print $2}' | \
27 | sort | uniq | awk '{print $0, NR-1}' > $data/speaker.txt
28 | echo 'sil 0' > $data/phones.txt
29 | cat $data/all.txt | awk -F '|' '{print $3}' | \
30 | awk '{for (i=1;i<=NF;i++) print $i}' | sort | uniq | \
31 | grep -v 'sil' | awk '{print $0, NR}' >> $data/phones.txt
32 |
33 | # Split train/validation
34 | shuf --random-source=<(yes 777) $data/all.txt > $data/train.txt
35 | head -n 100 $data/train.txt > $data/val.txt
36 | sed -i '1,100d' $data/train.txt
37 | head -n 10 $data/train.txt > $data/test.txt
38 | sed -i '1,10d' $data/train.txt
39 | fi
40 |
41 |
42 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
43 | num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F ',' '{print NF}')
44 | torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
45 | vits/train.py -c $config -m $dir \
46 | --train_data $data/train.txt \
47 | --val_data $data/val.txt \
48 | --speaker_table $data/speaker.txt \
49 | --phone_table $data/phones.txt \
50 | --num_workers 8
51 | fi
52 |
53 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
54 | mkdir -p $test_audio
55 | python vits/inference.py --cfg $config \
56 | --speaker_table $data/speaker.txt \
57 | --phone_table $data/phones.txt \
58 | --checkpoint $dir/G_90000.pth \
59 | --test_file $data/test.txt \
60 | --outdir $test_audio
61 | fi
62 |
63 |
64 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
65 | mkdir -p $test_audio
66 | python vits/export_onnx.py --cfg $config \
67 | --speaker_table $data/speaker.txt \
68 | --phone_table $data/phones.txt \
69 | --checkpoint $dir/G_90000.pth \
70 | --onnx_model $dir/G_90000.onnx
71 |
72 | python vits/inference_onnx.py --cfg $config \
73 | --speaker_table $data/speaker.txt \
74 | --phone_table $data/phones.txt \
75 | --onnx_model $dir/G_90000.onnx \
76 | --test_file $data/test.txt \
77 | --outdir $test_audio
78 | fi
79 |
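This recipe simply concatenates the baker (Chinese) and ljspeech (English) transcript lists, which assumes the two phone inventories are disjoint apart from shared symbols such as `sil`. A quick check sketch, assuming both upstream examples have already run their stage 0:

```python
def phones_of(path):
    """Collect the phone field (column 3) of a transcript list."""
    phones = set()
    with open(path, encoding="utf8") as fin:
        for line in fin:
            phones.update(line.strip().split("|")[2].split())
    return phones

zh = phones_of("../baker/data/all.txt")
en = phones_of("../ljspeech/data/out.txt")
print("shared symbols:", zh & en)  # ideally only sil and prosody tags
```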
--------------------------------------------------------------------------------
/examples/multilingual/tools:
--------------------------------------------------------------------------------
1 | ../../tools
--------------------------------------------------------------------------------
/examples/multilingual/vits:
--------------------------------------------------------------------------------
1 | ../../wetts/vits
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | g2p_en
2 | librosa
3 | nltk
4 | onnx
5 | onnxruntime
6 | scikit-learn
7 | scipy
8 | tensorboard
9 | torch
10 | torchvision
11 | tqdm
12 | transformers
13 | huggingface_hub
14 | soundfile
15 |
--------------------------------------------------------------------------------
/runtime/android/.gitignore:
--------------------------------------------------------------------------------
1 | *.iml
2 | .gradle
3 | /local.properties
4 | /.idea/caches
5 | /.idea/libraries
6 | /.idea/modules.xml
7 | /.idea/workspace.xml
8 | /.idea/navEditor.xml
9 | /.idea/assetWizardSettings.xml
10 | .DS_Store
11 | /build
12 | /captures
13 | .externalNativeBuild
14 | .cxx
15 | local.properties
16 |
--------------------------------------------------------------------------------
/runtime/android/README.md:
--------------------------------------------------------------------------------
1 | # Usage
2 |
3 | Most AI engineers are not familiar with Android development, so this is a simple how-to.
4 |
5 | 1. Train your model with your data.
6 |
7 | 2. Export the PyTorch model to an ONNX model.
8 |
9 | 3. Convert the ONNX model for mobile deployment:
10 |
11 | ```bash
12 | python -m onnxruntime.tools.convert_onnx_models_to_ort your-model.onnx
13 | ```
14 |
15 | You will get `your-model.ort` and `your-model.with_runtime_opt.ort`.
16 |
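Before packaging, it can help to confirm that the converted `.ort` file still loads; `onnxruntime.InferenceSession` accepts ORT-format models directly (file name assumed from the step above):

```python
import onnxruntime as ort

# Load the mobile-optimized model on CPU and dump its input signature.
sess = ort.InferenceSession("your-model.with_runtime_opt.ort",
                            providers=["CPUExecutionProvider"])
for inp in sess.get_inputs():
    print(inp.name, inp.shape, inp.type)
```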
17 | ```bash
18 | $ tree app/src/main/assets
19 | app/src/main/assets
20 | ├── frontend
21 | │ ├── final.ort
22 | │ ├── frontend.flags
23 | │ ├── g2p_en
24 | │ │ ├── README.md
25 | │ │ ├── cmudict.dict
26 | │ │ ├── model.fst
27 | │ │ └── phones.sym
28 | │ ├── lexicon
29 | │ │ ├── lexicon.txt
30 | │ │ ├── pinyin_dict.txt
31 | │ │ ├── polyphone.txt
32 | │ │ ├── polyphone_phone.txt
33 | │ │ └── prosody.txt
34 | │ ├── tn
35 | │ │ ├── zh_tn_tagger.fst
36 | │ │ └── zh_tn_verbalizer.fst
37 | │ └── vocab.txt
38 | └── vits
39 | ├── final.ort
40 | ├── phones.txt
41 | ├── speaker.txt
42 | └── vits.flags
43 |
44 | $ head app/src/main/assets/frontend/frontend.flags
45 | --tagger=frontend/tn/zh_tn_tagger.fst
46 | --verbalizer=frontend/tn/zh_tn_verbalizer.fst
47 | --cmudict=frontend/g2p_en/cmudict.dict
48 | --g2p_en_model=frontend/g2p_en/model.fst
49 | --g2p_en_sym=frontend/g2p_en/phones.sym
50 | --char2pinyin=frontend/lexicon/pinyin_dict.txt
51 | --pinyin2id=frontend/lexicon/polyphone.txt
52 | --pinyin2phones=frontend/lexicon/lexicon.txt
53 | --vocab=frontend/vocab.txt
54 | --g2p_prosody_model=frontend/final.ort
55 |
56 | $ cat app/src/main/assets/vits/vits.flags
57 | --sampling_rate=16000
58 | --speaker2id=vits/speaker.txt
59 | --phone2id=vits/phones.txt
60 | --vits_model=vits/final.ort
61 | ```
62 |
63 | 4. Install Android Studio, open `wetts/runtime/android`, and build.
64 |
65 | 5. Install `app/build/outputs/apk/debug/app-debug.apk` on your phone and try it.
66 |
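A missing asset typically only shows up as a crash at app startup, so it is worth checking that every path referenced by the two flags files exists under `assets` before building. A small helper sketch (not shipped with the repo; run it from `runtime/android`):

```python
import os

ASSETS = "app/src/main/assets"

# Flag values in frontend.flags/vits.flags are paths relative to assets;
# anything containing '/' is treated as a file that must exist.
for flags in ("frontend/frontend.flags", "vits/vits.flags"):
    with open(os.path.join(ASSETS, flags)) as fin:
        for line in fin:
            value = line.strip().split("=", 1)[-1]
            if "/" in value:
                path = os.path.join(ASSETS, value)
                assert os.path.exists(path), "missing asset: " + path
print("all flagged assets present")
```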
--------------------------------------------------------------------------------
/runtime/android/app/.gitignore:
--------------------------------------------------------------------------------
1 | /build
--------------------------------------------------------------------------------
/runtime/android/app/build.gradle:
--------------------------------------------------------------------------------
1 | plugins {
2 | id 'com.android.application'
3 | }
4 |
5 | android {
6 | compileSdk 32
7 |
8 | configurations {
9 | extractForNativeBuild
10 | }
11 |
12 | defaultConfig {
13 | applicationId "cn.org.wenet.wetts"
14 | minSdk 21
15 | targetSdk 32
16 | versionCode 1
17 | versionName "1.0"
18 |
19 | testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
20 | externalNativeBuild {
21 | cmake {
22 | targets "wetts"
23 | }
24 | }
25 | }
26 |
27 | buildTypes {
28 | release {
29 | minifyEnabled false
30 | proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
31 | }
32 | }
33 | externalNativeBuild {
34 | cmake {
35 | version "3.18.1"
36 | path "src/main/cpp/CMakeLists.txt"
37 | }
38 | }
39 | compileOptions {
40 | sourceCompatibility JavaVersion.VERSION_1_8
41 | targetCompatibility JavaVersion.VERSION_1_8
42 | }
43 | }
44 |
45 | dependencies {
46 | implementation 'androidx.appcompat:appcompat:1.3.0'
47 | implementation 'com.google.android.material:material:1.4.0'
48 | implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
49 | implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.15.1'
50 | extractForNativeBuild 'com.microsoft.onnxruntime:onnxruntime-android:1.15.1'
51 | implementation 'com.github.pengzhendong:wenet-openfst-android:1.0.2'
52 | extractForNativeBuild 'com.github.pengzhendong:wenet-openfst-android:1.0.2'
53 | testImplementation 'junit:junit:4.13.2'
54 | androidTestImplementation 'androidx.test.ext:junit:1.1.3'
55 | androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'
56 | }
57 |
58 | task extractAARForNativeBuild {
59 | doLast {
60 | configurations.extractForNativeBuild.files.each {
61 | def file = it.absoluteFile
62 | copy {
63 | from zipTree(file)
64 | into "$buildDir/$file.name"
65 | include "headers/**"
66 | include "jni/**"
67 | }
68 | }
69 | }
70 | }
71 |
72 | tasks.whenTaskAdded { task ->
73 | if (task.name.contains('externalNativeBuild')) {
74 | task.dependsOn(extractAARForNativeBuild)
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
/runtime/android/app/proguard-rules.pro:
--------------------------------------------------------------------------------
1 | # Add project specific ProGuard rules here.
2 | # You can control the set of applied configuration files using the
3 | # proguardFiles setting in build.gradle.
4 | #
5 | # For more details, see
6 | # http://developer.android.com/guide/developing/tools/proguard.html
7 |
8 | # If your project uses WebView with JS, uncomment the following
9 | # and specify the fully qualified class name to the JavaScript interface
10 | # class:
11 | #-keepclassmembers class fqcn.of.javascript.interface.for.webview {
12 | # public *;
13 | #}
14 |
15 | # Uncomment this to preserve the line number information for
16 | # debugging stack traces.
17 | #-keepattributes SourceFile,LineNumberTable
18 |
19 | # If you keep the line number information, uncomment this to
20 | # hide the original source file name.
21 | #-renamesourcefileattribute SourceFile
--------------------------------------------------------------------------------
/runtime/android/app/src/androidTest/java/cn/org/wenet/wetts/ExampleInstrumentedTest.java:
--------------------------------------------------------------------------------
1 | package cn.org.wenet.wetts;
2 |
3 | import android.content.Context;
4 |
5 | import androidx.test.platform.app.InstrumentationRegistry;
6 | import androidx.test.ext.junit.runners.AndroidJUnit4;
7 |
8 | import org.junit.Test;
9 | import org.junit.runner.RunWith;
10 |
11 | import static org.junit.Assert.*;
12 |
13 | /**
14 | * Instrumented test, which will execute on an Android device.
15 | *
16 | * @see Testing documentation
17 | */
18 | @RunWith(AndroidJUnit4.class)
19 | public class ExampleInstrumentedTest {
20 | @Test
21 | public void useAppContext() {
22 | // Context of the app under test.
23 | Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
24 | assertEquals("cn.org.wenet.wetts", appContext.getPackageName());
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/runtime/android/app/src/main/AndroidManifest.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/runtime/android/app/src/main/assets/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/android/app/src/main/assets/.gitkeep
--------------------------------------------------------------------------------
/runtime/android/app/src/main/cpp/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.4.1)
2 | project(wetts CXX)
3 | set(CMAKE_CXX_STANDARD 14)
4 | set(CMAKE_VERBOSE_MAKEFILE on)
5 |
6 | set(build_DIR ${CMAKE_SOURCE_DIR}/../../../build)
7 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
8 |
9 | file(GLOB ONNXRUNTIME_INCLUDE_DIRS ${build_DIR}/onnxruntime*.aar/headers)
10 | file(GLOB ONNXRUNTIME_LINK_DIRS ${build_DIR}/onnxruntime*.aar/jni/${ANDROID_ABI})
11 | link_directories(${ONNXRUNTIME_LINK_DIRS})
12 | include_directories(${ONNXRUNTIME_INCLUDE_DIRS})
13 |
14 | set(openfst_BINARY_DIR ${build_DIR}/wenet-openfst-android-1.0.2.aar/jni)
15 | link_directories(${openfst_BINARY_DIR}/${ANDROID_ABI})
16 | link_libraries(log gflags_nothreads glog fst)
17 | include_directories(${openfst_BINARY_DIR}/include)
18 |
19 | include(wetextprocessing)
20 | include_directories(${CMAKE_SOURCE_DIR})
21 |
22 | add_subdirectory(utils)
23 | add_subdirectory(frontend)
24 | add_subdirectory(model)
25 | add_dependencies(frontend wetextprocessing)
26 |
27 | add_library(wetts SHARED wetts.cc)
28 | target_link_libraries(wetts PUBLIC tts_model onnxruntime)
29 |
--------------------------------------------------------------------------------
/runtime/android/app/src/main/cpp/cmake:
--------------------------------------------------------------------------------
1 | ../../../../../core/cmake
--------------------------------------------------------------------------------
/runtime/android/app/src/main/cpp/frontend:
--------------------------------------------------------------------------------
1 | ../../../../../core/frontend
--------------------------------------------------------------------------------
/runtime/android/app/src/main/cpp/model:
--------------------------------------------------------------------------------
1 | ../../../../../core/model
--------------------------------------------------------------------------------
/runtime/android/app/src/main/cpp/utils:
--------------------------------------------------------------------------------
1 | ../../../../../core/utils
--------------------------------------------------------------------------------
/runtime/android/app/src/main/java/cn/org/wenet/wetts/Synthesis.java:
--------------------------------------------------------------------------------
1 | package cn.org.wenet.wetts;
2 |
3 | public class Synthesis {
4 |
5 | static {
6 | System.loadLibrary("wetts");
7 | }
8 |
9 | public static native void init(String modelDir);
10 | public static native void run(String text, String speaker);
11 | }
12 |
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/layout/activity_main.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="utf-8"?>
2 | <adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
3 |     <background android:drawable="@drawable/ic_launcher_background"/>
4 |     <foreground android:drawable="@drawable/ic_launcher_foreground"/>
5 | </adaptive-icon>
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="utf-8"?>
2 | <adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
3 |     <background android:drawable="@drawable/ic_launcher_background"/>
4 |     <foreground android:drawable="@drawable/ic_launcher_foreground"/>
5 | </adaptive-icon>
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/mipmap-hdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/android/app/src/main/res/mipmap-hdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/mipmap-mdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/android/app/src/main/res/mipmap-mdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/values-night/themes.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/values/attrs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/values/colors.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="utf-8"?>
2 | <resources>
3 |     <color name="purple_200">#FFBB86FC</color>
4 |     <color name="purple_500">#FF6200EE</color>
5 |     <color name="purple_700">#FF3700B3</color>
6 |     <color name="teal_200">#FF03DAC5</color>
7 |     <color name="teal_700">#FF018786</color>
8 |     <color name="black">#FF000000</color>
9 |     <color name="white">#FFFFFFFF</color>
10 |
11 |     #f16d7a
12 |     #b7d28d
13 |     #b8f1ed
14 |     #b7d28d
15 |     #b8f1ed
16 | </resources>
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/values/strings.xml:
--------------------------------------------------------------------------------
1 | <resources>
2 |     <string name="app_name">wetts</string>
3 | </resources>
4 |
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/values/themes.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/xml/backup_rules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/runtime/android/app/src/main/res/xml/data_extraction_rules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/runtime/android/app/src/test/java/cn/org/wenet/wetts/ExampleUnitTest.java:
--------------------------------------------------------------------------------
1 | package cn.org.wenet.wetts;
2 |
3 | import org.junit.Test;
4 |
5 | import static org.junit.Assert.*;
6 |
7 | /**
8 | * Example local unit test, which will execute on the development machine (host).
9 | *
10 | * @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
11 | */
12 | public class ExampleUnitTest {
13 | @Test
14 | public void addition_isCorrect() {
15 | assertEquals(4, 2 + 2);
16 | }
17 | }
--------------------------------------------------------------------------------
/runtime/android/build.gradle:
--------------------------------------------------------------------------------
1 | // Top-level build file where you can add configuration options common to all sub-projects/modules.
2 | plugins {
3 | id 'com.android.application' version '7.2.2' apply false
4 | id 'com.android.library' version '7.2.2' apply false
5 | }
6 |
7 | task clean(type: Delete) {
8 | delete rootProject.buildDir
9 | }
--------------------------------------------------------------------------------
/runtime/android/gradle.properties:
--------------------------------------------------------------------------------
1 | # Project-wide Gradle settings.
2 | # IDE (e.g. Android Studio) users:
3 | # Gradle settings configured through the IDE *will override*
4 | # any settings specified in this file.
5 | # For more details on how to configure your build environment visit
6 | # http://www.gradle.org/docs/current/userguide/build_environment.html
7 | # Specifies the JVM arguments used for the daemon process.
8 | # The setting is particularly useful for tweaking memory settings.
9 | org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
10 | # When configured, Gradle will run in incubating parallel mode.
11 | # This option should only be used with decoupled projects. More details, visit
12 | # http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
13 | # org.gradle.parallel=true
14 | # AndroidX package structure to make it clearer which packages are bundled with the
15 | # Android operating system, and which are packaged with your app's APK
16 | # https://developer.android.com/topic/libraries/support-library/androidx-rn
17 | android.useAndroidX=true
18 | # Enables namespacing of each library's R class so that its R class includes only the
19 | # resources declared in the library itself and none from the library's dependencies,
20 | # thereby reducing the size of the R class for that library
21 | android.nonTransitiveRClass=true
--------------------------------------------------------------------------------
/runtime/android/gradle/wrapper/gradle-wrapper.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/android/gradle/wrapper/gradle-wrapper.jar
--------------------------------------------------------------------------------
/runtime/android/gradle/wrapper/gradle-wrapper.properties:
--------------------------------------------------------------------------------
1 | #Sat Sep 03 16:13:01 CST 2022
2 | distributionBase=GRADLE_USER_HOME
3 | distributionUrl=https\://services.gradle.org/distributions/gradle-7.3.3-bin.zip
4 | distributionPath=wrapper/dists
5 | zipStorePath=wrapper/dists
6 | zipStoreBase=GRADLE_USER_HOME
7 |
--------------------------------------------------------------------------------
/runtime/android/gradlew.bat:
--------------------------------------------------------------------------------
1 | @rem
2 | @rem Copyright 2015 the original author or authors.
3 | @rem
4 | @rem Licensed under the Apache License, Version 2.0 (the "License");
5 | @rem you may not use this file except in compliance with the License.
6 | @rem You may obtain a copy of the License at
7 | @rem
8 | @rem https://www.apache.org/licenses/LICENSE-2.0
9 | @rem
10 | @rem Unless required by applicable law or agreed to in writing, software
11 | @rem distributed under the License is distributed on an "AS IS" BASIS,
12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | @rem See the License for the specific language governing permissions and
14 | @rem limitations under the License.
15 | @rem
16 |
17 | @if "%DEBUG%" == "" @echo off
18 | @rem ##########################################################################
19 | @rem
20 | @rem Gradle startup script for Windows
21 | @rem
22 | @rem ##########################################################################
23 |
24 | @rem Set local scope for the variables with windows NT shell
25 | if "%OS%"=="Windows_NT" setlocal
26 |
27 | set DIRNAME=%~dp0
28 | if "%DIRNAME%" == "" set DIRNAME=.
29 | set APP_BASE_NAME=%~n0
30 | set APP_HOME=%DIRNAME%
31 |
32 | @rem Resolve any "." and ".." in APP_HOME to make it shorter.
33 | for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
34 |
35 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
36 | set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
37 |
38 | @rem Find java.exe
39 | if defined JAVA_HOME goto findJavaFromJavaHome
40 |
41 | set JAVA_EXE=java.exe
42 | %JAVA_EXE% -version >NUL 2>&1
43 | if "%ERRORLEVEL%" == "0" goto execute
44 |
45 | echo.
46 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
47 | echo.
48 | echo Please set the JAVA_HOME variable in your environment to match the
49 | echo location of your Java installation.
50 |
51 | goto fail
52 |
53 | :findJavaFromJavaHome
54 | set JAVA_HOME=%JAVA_HOME:"=%
55 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe
56 |
57 | if exist "%JAVA_EXE%" goto execute
58 |
59 | echo.
60 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
61 | echo.
62 | echo Please set the JAVA_HOME variable in your environment to match the
63 | echo location of your Java installation.
64 |
65 | goto fail
66 |
67 | :execute
68 | @rem Setup the command line
69 |
70 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
71 |
72 |
73 | @rem Execute Gradle
74 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
75 |
76 | :end
77 | @rem End local scope for the variables with windows NT shell
78 | if "%ERRORLEVEL%"=="0" goto mainEnd
79 |
80 | :fail
81 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
82 | rem the _cmd.exe /c_ return code!
83 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
84 | exit /b 1
85 |
86 | :mainEnd
87 | if "%OS%"=="Windows_NT" endlocal
88 |
89 | :omega
90 |
--------------------------------------------------------------------------------
/runtime/android/settings.gradle:
--------------------------------------------------------------------------------
1 | pluginManagement {
2 | repositories {
3 | gradlePluginPortal()
4 | google()
5 | mavenCentral()
6 | }
7 | }
8 | dependencyResolutionManagement {
9 | repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
10 | repositories {
11 | google()
12 | mavenCentral()
13 | maven { url 'https://jitpack.io' }
14 | }
15 | }
16 | rootProject.name = "wetts"
17 | include ':app'
18 |
--------------------------------------------------------------------------------
/runtime/core/bin/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_executable(tts_main tts_main.cc)
2 | target_link_libraries(tts_main PUBLIC gflags tts_model)
3 |
4 | if(BUILD_SERVER)
5 | add_executable(http_server_main http_server_main.cc)
6 | target_link_libraries(http_server_main PUBLIC gflags http_server tts_model jsoncpp_lib)
7 | endif()
8 |
--------------------------------------------------------------------------------
/runtime/core/bin/http_server_main.cc:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "gflags/gflags.h"
16 | #include "glog/logging.h"
17 |
18 | #include "http/http_server.h"
19 | #include "processor/wetext_processor.h"
20 |
21 | #include "frontend/g2p_en.h"
22 | #include "frontend/g2p_prosody.h"
23 | #include "frontend/wav.h"
24 | #include "model/tts_model.h"
25 | #include "utils/string.h"
26 |
27 | // Flags
28 | DEFINE_string(frontend_flags, "", "frontend flags file");
29 | DEFINE_string(vits_flags, "", "vits flags file");
30 |
31 | // Text Normalization
32 | DEFINE_string(tagger, "", "tagger fst file");
33 | DEFINE_string(verbalizer, "", "verbalizer fst file");
34 |
35 | // Tokenizer
36 | DEFINE_string(vocab, "", "tokenizer vocab file");
37 |
38 | // G2P for English
39 | DEFINE_string(cmudict, "", "cmudict for english words");
40 | DEFINE_string(g2p_en_model, "", "english g2p fst model for oov");
41 | DEFINE_string(g2p_en_sym, "", "english g2p symbol table for oov");
42 |
43 | // G2P for Chinese
44 | DEFINE_string(char2pinyin, "", "chinese character to pinyin");
45 | DEFINE_string(pinyin2id, "", "pinyin to id");
46 | DEFINE_string(pinyin2phones, "", "pinyin to phones");
47 | DEFINE_string(g2p_prosody_model, "", "g2p prosody model file");
48 |
49 | // VITS
50 | DEFINE_string(speaker2id, "", "speaker to id");
51 | DEFINE_string(phone2id, "", "phone to id");
52 | DEFINE_string(vits_model, "", "e2e tts model file");
53 | DEFINE_int32(sampling_rate, 22050, "sampling rate of pcm");
54 |
55 |
56 | // port
57 | DEFINE_int32(port, 10086, "http listening port");
58 |
59 | int main(int argc, char* argv[]) {
60 | gflags::ParseCommandLineFlags(&argc, &argv, false);
61 | google::InitGoogleLogging(argv[0]);
62 | gflags::ReadFromFlagsFile(FLAGS_frontend_flags, "", false);
63 | gflags::ReadFromFlagsFile(FLAGS_vits_flags, "", false);
64 |
65 | auto tn = std::make_shared<wetext::Processor>(FLAGS_tagger, FLAGS_verbalizer);
66 |
67 | bool has_en = !FLAGS_cmudict.empty() && !FLAGS_g2p_en_model.empty() &&
68 | !FLAGS_g2p_en_sym.empty();
69 | std::shared_ptr<wetts::G2pEn> g2p_en =
70 | has_en ? std::make_shared<wetts::G2pEn>(FLAGS_cmudict, FLAGS_g2p_en_model,
71 | FLAGS_g2p_en_sym)
72 | : nullptr;
73 |
74 | auto g2p_prosody = std::make_shared<wetts::G2pProsody>(
75 | FLAGS_g2p_prosody_model, FLAGS_vocab, FLAGS_char2pinyin, FLAGS_pinyin2id,
76 | FLAGS_pinyin2phones, g2p_en);
77 | auto model = std::make_shared<wetts::TtsModel>(
78 | FLAGS_vits_model, FLAGS_speaker2id, FLAGS_phone2id, FLAGS_sampling_rate,
79 | tn, g2p_prosody);
80 |
81 | wetts::HttpServer server(FLAGS_port, model);
82 | LOG(INFO) << "Listening at port " << FLAGS_port;
83 | server.Start();
84 | return 0;
85 | }
86 |
--------------------------------------------------------------------------------
/runtime/core/bin/tts_main.cc:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "gflags/gflags.h"
16 | #include "glog/logging.h"
17 | #include "processor/wetext_processor.h"
18 |
19 | #include "frontend/g2p_en.h"
20 | #include "frontend/g2p_prosody.h"
21 | #include "frontend/wav.h"
22 | #include "model/tts_model.h"
23 | #include "utils/string.h"
24 |
25 | // Flags
26 | DEFINE_string(frontend_flags, "", "frontend flags file");
27 | DEFINE_string(vits_flags, "", "vits flags file");
28 |
29 | // Text Normalization
30 | DEFINE_string(tagger, "", "tagger fst file");
31 | DEFINE_string(verbalizer, "", "verbalizer fst file");
32 |
33 | // Tokenizer
34 | DEFINE_string(vocab, "", "tokenizer vocab file");
35 |
36 | // G2P for English
37 | DEFINE_string(cmudict, "", "cmudict for english words");
38 | DEFINE_string(g2p_en_model, "", "english g2p fst model for oov");
39 | DEFINE_string(g2p_en_sym, "", "english g2p symbol table for oov");
40 |
41 | // G2P for Chinese
42 | DEFINE_string(char2pinyin, "", "chinese character to pinyin");
43 | DEFINE_string(pinyin2id, "", "pinyin to id");
44 | DEFINE_string(pinyin2phones, "", "pinyin to phones");
45 | DEFINE_string(g2p_prosody_model, "", "g2p prosody model file");
46 |
47 | // VITS
48 | DEFINE_string(speaker2id, "", "speaker to id");
49 | DEFINE_string(phone2id, "", "phone to id");
50 | DEFINE_string(vits_model, "", "e2e tts model file");
51 | DEFINE_int32(sampling_rate, 22050, "sampling rate of pcm");
52 |
53 | DEFINE_string(sname, "", "speaker name");
54 | DEFINE_string(text, "", "input text");
55 | DEFINE_string(wav_path, "", "output wave path");
56 |
57 | int main(int argc, char* argv[]) {
58 | gflags::ParseCommandLineFlags(&argc, &argv, false);
59 | google::InitGoogleLogging(argv[0]);
60 | gflags::ReadFromFlagsFile(FLAGS_frontend_flags, "", false);
61 | gflags::ReadFromFlagsFile(FLAGS_vits_flags, "", false);
62 |
63 | auto tn = std::make_shared<wetext::Processor>(FLAGS_tagger, FLAGS_verbalizer);
64 |
65 | bool has_en = !FLAGS_cmudict.empty() && !FLAGS_g2p_en_model.empty() &&
66 | !FLAGS_g2p_en_sym.empty();
67 | std::shared_ptr<wetts::G2pEn> g2p_en =
68 | has_en ? std::make_shared<wetts::G2pEn>(FLAGS_cmudict, FLAGS_g2p_en_model,
69 | FLAGS_g2p_en_sym)
70 | : nullptr;
71 |
72 | auto g2p_prosody = std::make_shared<wetts::G2pProsody>(
73 | FLAGS_g2p_prosody_model, FLAGS_vocab, FLAGS_char2pinyin, FLAGS_pinyin2id,
74 | FLAGS_pinyin2phones, g2p_en);
75 | auto model = std::make_shared<wetts::TtsModel>(
76 | FLAGS_vits_model, FLAGS_speaker2id, FLAGS_phone2id, FLAGS_sampling_rate,
77 | tn, g2p_prosody);
78 |
79 | std::vector<float> audio;
80 | int sid = model->GetSid(FLAGS_sname);
81 | model->Synthesis(FLAGS_text, sid, &audio);
82 |
83 | wetts::WavWriter wav_writer(audio.data(), audio.size(), 1,
84 | FLAGS_sampling_rate, 16);
85 | wav_writer.Write(FLAGS_wav_path);
86 | return 0;
87 | }
88 |
--------------------------------------------------------------------------------
/runtime/core/cmake/boost.cmake:
--------------------------------------------------------------------------------
1 | FetchContent_Declare(boost
2 | URL https://boostorg.jfrog.io/artifactory/main/beta/1.81.0.beta1/source/boost_1_81_0_b1.tar.gz
3 | URL_HASH SHA256=135f03965b50d05baae45f49e4b7f2f3c545ff956b4500342f8fb328b8207a90
4 | )
5 | FetchContent_MakeAvailable(boost)
6 | include_directories(${boost_SOURCE_DIR})
7 |
8 | if(MSVC)
9 | add_definitions(-DBOOST_ALL_DYN_LINK -DBOOST_ALL_NO_LIB)
10 | endif()
11 |
--------------------------------------------------------------------------------
/runtime/core/cmake/gflags.cmake:
--------------------------------------------------------------------------------
1 | FetchContent_Declare(gflags
2 | URL https://github.com/gflags/gflags/archive/v2.2.2.zip
3 | URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5
4 | )
5 | FetchContent_MakeAvailable(gflags)
6 | include_directories(${gflags_BINARY_DIR}/include)
7 |
--------------------------------------------------------------------------------
/runtime/core/cmake/glog.cmake:
--------------------------------------------------------------------------------
1 | FetchContent_Declare(glog
2 | URL https://github.com/google/glog/archive/v0.4.0.zip
3 | URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
4 | )
5 | FetchContent_MakeAvailable(glog)
6 | include_directories(${glog_SOURCE_DIR}/src ${glog_BINARY_DIR})
7 |
--------------------------------------------------------------------------------
/runtime/core/cmake/gtest.cmake:
--------------------------------------------------------------------------------
1 | FetchContent_Declare(googletest
2 | URL https://github.com/google/googletest/archive/release-1.11.0.zip
3 | URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a
4 | )
5 | if(MSVC)
6 | set(gtest_force_shared_crt ON CACHE BOOL "Always use msvcrt.dll" FORCE)
7 | endif()
8 | FetchContent_MakeAvailable(googletest)
9 |
--------------------------------------------------------------------------------
/runtime/core/cmake/jsoncpp.cmake:
--------------------------------------------------------------------------------
1 | FetchContent_Declare(jsoncpp
2 | URL https://github.com/open-source-parsers/jsoncpp/archive/refs/tags/1.9.3.zip
3 | URL_HASH SHA256=7853fe085ddd5da94b9795f4b520689c21f2753c4a8f7a5097410ee6136bf671
4 | )
5 | FetchContent_MakeAvailable(jsoncpp)
6 | include_directories(${jsoncpp_SOURCE_DIR}/include)
7 |
--------------------------------------------------------------------------------
/runtime/core/cmake/onnxruntime.cmake:
--------------------------------------------------------------------------------
1 | if(ONNX)
2 | if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
3 | set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.13.1/onnxruntime-win-x64-1.13.1.zip")
4 | set(URL_HASH "SHA256=cd8318dc30352e0d615f809bd544bfd18b578289ec16621252b5db1994f09e43")
5 | elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
6 | if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
7 | set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.13.1/onnxruntime-linux-aarch64-1.13.1.tgz")
8 | set(URL_HASH "SHA256=18e441585de69ef8aab263e2e96f0325729537ebfbd17cdcee78b2eabf0594d2")
9 | else()
10 | set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.13.1/onnxruntime-linux-x64-1.13.1.tgz")
11 | set(URL_HASH "SHA256=2c7fdcfa8131b52167b1870747758cb24265952eba975318a67cc840c04ca73e")
12 | endif()
13 | elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin")
14 | if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64")
15 | set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.13.1/onnxruntime-osx-arm64-1.13.1.tgz")
16 | set(URL_HASH "SHA256=10ce30925c789715f29424a7658b41c601dfbde5d58fe21cb53ad418cde3c215")
17 | else()
18 | set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.13.1/onnxruntime-osx-x86_64-1.13.1.tgz")
19 | set(URL_HASH "SHA256=32f3fff17b01db779e9e3cbe32f27adba40460e6202a79dfd1ac76b4f20588ef")
20 | endif()
21 | else()
22 | message(FATAL_ERROR "Unsupported CMake System Name '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux' or 'Darwin')")
23 | endif()
24 |
25 | FetchContent_Declare(onnxruntime
26 | URL ${ONNX_URL}
27 | URL_HASH ${URL_HASH}
28 | )
29 | FetchContent_MakeAvailable(onnxruntime)
30 | include_directories(${onnxruntime_SOURCE_DIR}/include)
31 | link_directories(${onnxruntime_SOURCE_DIR}/lib)
32 |
33 | if(MSVC)
34 | file(GLOB ONNX_DLLS "${onnxruntime_SOURCE_DIR}/lib/*.dll")
35 | file(COPY ${ONNX_DLLS} DESTINATION ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE})
36 | endif()
37 |
38 | add_definitions(-DUSE_ONNX)
39 | endif()
40 |
--------------------------------------------------------------------------------
/runtime/core/cmake/wetextprocessing.cmake:
--------------------------------------------------------------------------------
1 | if(NOT ANDROID)
2 | FetchContent_Declare(wetextprocessing
3 | URL https://github.com/wenet-e2e/WeTextProcessing/archive/refs/tags/0.1.3.tar.gz
4 | URL_HASH SHA256=2f1c81649b2f725a5825345356be9dccb9699965cf44c9f0e842f5c0d4b6ba61
5 | SOURCE_SUBDIR runtime
6 | )
7 | FetchContent_MakeAvailable(wetextprocessing)
8 | include_directories(${openfst_SOURCE_DIR}/src/include)
9 | include_directories(${wetextprocessing_SOURCE_DIR}/runtime)
10 | link_directories(${wetextprocessing_BINARY_DIR})
11 | else()
12 | include(ExternalProject)
13 | set(ANDROID_CMAKE_ARGS
14 | -DBUILD_TESTING=OFF
15 | -DBUILD_SHARED_LIBS=OFF
16 | -DCMAKE_BUILD_TYPE=Release
17 | -DCMAKE_MAKE_PROGRAM=${CMAKE_MAKE_PROGRAM}
18 | -DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE}
19 | -DANDROID_ABI=${ANDROID_ABI}
20 | -DANDROID_NATIVE_API_LEVEL=${ANDROID_NATIVE_API_LEVEL}
21 | -DCMAKE_CXX_FLAGS=-I${openfst_BINARY_DIR}/include
22 | -DCMAKE_EXE_LINKER_FLAGS=-L${openfst_BINARY_DIR}/${ANDROID_ABI}
23 | )
24 | ExternalProject_Add(wetextprocessing
25 | URL https://github.com/wenet-e2e/WeTextProcessing/archive/refs/tags/0.1.3.tar.gz
26 | URL_HASH SHA256=2f1c81649b2f725a5825345356be9dccb9699965cf44c9f0e842f5c0d4b6ba61
27 | SOURCE_SUBDIR runtime
28 | CMAKE_ARGS ${ANDROID_CMAKE_ARGS}
29 | INSTALL_COMMAND ""
30 | )
31 | ExternalProject_Get_Property(wetextprocessing SOURCE_DIR BINARY_DIR)
32 | include_directories(${SOURCE_DIR}/runtime)
33 | link_directories(${BINARY_DIR}/processor ${BINARY_DIR}/utils)
34 | link_libraries(wetext_utils)
35 | endif()
36 |
--------------------------------------------------------------------------------
/runtime/core/frontend/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_library(frontend STATIC
2 | lexicon.cc
3 | g2p_en.cc
4 | g2p_prosody.cc
5 | )
6 |
7 | target_link_libraries(frontend PUBLIC onnx_model)
8 |
--------------------------------------------------------------------------------
/runtime/core/frontend/g2p_en.cc:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023 Zhendong Peng (pzd17@tsinghua.org.cn)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "frontend/g2p_en.h"
16 |
17 | #include "glog/logging.h"
18 |
19 | #include "utils/fst.h"
20 | #include "utils/string.h"
21 | #include "utils/utils.h"
22 |
23 | namespace wetts {
24 |
25 | G2pEn::G2pEn(const std::string& cmudict, const std::string& model,
26 | const std::string& sym) {
27 | ReadTableFile(cmudict, &cmudict_);
28 | model_.reset(fst::StdVectorFst::Read(model));
29 | sym_.reset(fst::SymbolTable::ReadText(sym));
30 | }
31 |
32 | void G2pEn::Convert(const std::string& grapheme,
33 | std::vector<std::string>* phonemes) {
34 | if (cmudict_.count(grapheme) > 0) {
35 | *phonemes = cmudict_[grapheme];
36 | } else if (grapheme.size() < 4) {
37 | // Speak short oov letter by letter, such as `ASR` and `TTS`
38 | for (int i = 0; i < grapheme.size(); i++) {
39 | std::string token{grapheme[i]};
40 | std::vector<std::string>& phones = cmudict_[token];
41 | phonemes->insert(phonemes->end(), phones.begin(), phones.end());
42 | if (i < grapheme.size() - 1) {
43 | // TODO(zhendong.peng): use prosody dict instead of hard code
44 | phonemes->emplace_back("#0");
45 | }
46 | }
47 | } else {
48 | std::vector<std::string> graphemes;
49 | SplitStringToVector(grapheme, "-", true, &graphemes);
50 | for (int i = 0; i < graphemes.size(); ++i) {
51 | std::vector<int64_t> olabels;
52 | ShortestPath(graphemes[i], model_.get(), &olabels);
53 | for (auto olabel : olabels) {
54 | const auto& phoneme = sym_->Find(olabel);
55 | phonemes->emplace_back(phoneme);
56 | }
57 | if (i != graphemes.size() - 1) {
58 | phonemes->emplace_back("#0");
59 | }
60 | }
61 | }
62 | }
63 |
64 | } // namespace wetts
65 |
--------------------------------------------------------------------------------
/runtime/core/frontend/g2p_en.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023 Zhendong Peng (pzd17@tsinghua.org.cn)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef FRONTEND_G2P_EN_H_
16 | #define FRONTEND_G2P_EN_H_
17 |
18 | #include <memory>
19 | #include <string>
20 | #include <unordered_map>
21 | #include <vector>
22 |
23 | #include "fst/fstlib.h"
24 |
25 | namespace wetts {
26 |
27 | class G2pEn {
28 | public:
29 | G2pEn(const std::string& cmudict, const std::string& model,
30 | const std::string& sym);
31 |
32 | void Convert(const std::string& grapheme, std::vector<std::string>* phonemes);
33 |
34 | private:
35 | std::unordered_map<std::string, std::vector<std::string>> cmudict_;
36 | std::shared_ptr<fst::StdVectorFst> model_;
37 | std::shared_ptr<fst::SymbolTable> sym_;
38 | };
39 |
40 | } // namespace wetts
41 |
42 | #endif // FRONTEND_G2P_EN_H_
43 |
--------------------------------------------------------------------------------
/runtime/core/frontend/g2p_prosody.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Binbin Zhang (binbzha@qq.com)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef FRONTEND_G2P_PROSODY_H_
16 | #define FRONTEND_G2P_PROSODY_H_
17 |
18 | #include <memory>
19 | #include <string>
20 | #include <unordered_map>
21 | #include <vector>
22 |
23 | #include "onnxruntime_cxx_api.h" // NOLINT
24 |
25 | #include "frontend/g2p_en.h"
26 | #include "frontend/lexicon.h"
27 | #include "model/onnx_model.h"
28 |
29 | namespace wetts {
30 |
31 | // Unified G2P & Prosody model
32 | class G2pProsody : public OnnxModel {
33 | public:
34 | explicit G2pProsody(const std::string& g2p_prosody_model,
35 | const std::string& vocab, const std::string& char2pinyin,
36 | const std::string& pinyin2id,
37 | const std::string& pinyin2phones,
38 | std::shared_ptr<G2pEn> g2p_en = nullptr);
39 | void Tokenize(const std::string& text, std::vector<std::string>* tokens,
40 | std::vector<int64_t>* token_ids);
41 | void Compute(const std::string& str, std::vector<std::string>* phonemes);
42 |
43 | private:
44 | const std::string CLS_ = "[CLS]";
45 | const std::string SEP_ = "[SEP]";
46 | const std::string UNK_ = "[UNK]";
47 | std::unordered_map<std::string, int> vocab_;
48 | std::unordered_map<std::string, int> phones_;
49 | std::shared_ptr<G2pEn> g2p_en_;
50 | std::shared_ptr<Lexicon> lexicon_;
51 | std::unordered_map<std::string, std::vector<std::string>> pinyin2phones_;
52 | };
53 |
54 | } // namespace wetts
55 |
56 | #endif // FRONTEND_G2P_PROSODY_H_
57 |
--------------------------------------------------------------------------------
/runtime/core/frontend/lexicon.cc:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Binbin Zhang (binbzha@qq.com)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "frontend/lexicon.h"
16 |
17 | #include <fstream>
18 | #include <string>
19 | #include <unordered_map>
20 | #include <utility>
21 | #include <vector>
22 |
23 | #include "glog/logging.h"
24 |
25 | #include "utils/string.h"
26 |
27 | namespace wetts {
28 |
29 | const char Lexicon::UNK[] = "<unk>";
30 |
31 | Lexicon::Lexicon(const std::string& lexicon_file) {
32 | std::ifstream is(lexicon_file);
33 | std::string line;
34 | while (getline(is, line)) {
35 | size_t pos = line.find(' ');
36 | CHECK(pos != std::string::npos);
37 | std::string word = line.substr(0, pos);
38 | std::string prons_str = line.substr(pos + 1);
39 | std::vector<std::string> prons;
40 | SplitStringToVector(prons_str, ",", true, &prons);
41 | lexicon_[word] = std::move(prons);
42 | }
43 | unk_.emplace_back(UNK);
44 | }
45 |
46 | int Lexicon::NumProns(const std::string& word) {
47 | if (lexicon_.find(word) != lexicon_.end()) {
48 | return lexicon_[word].size();
49 | } else {
50 | return 0;
51 | }
52 | }
53 |
54 | const std::vector<std::string>& Lexicon::Prons(const std::string& word) {
55 | if (lexicon_.find(word) != lexicon_.end()) {
56 | return lexicon_[word];
57 | } else {
58 | return unk_;
59 | }
60 | }
61 |
62 | } // namespace wetts
63 |
--------------------------------------------------------------------------------
/runtime/core/frontend/lexicon.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022, Binbin Zhang (binbzha@qq.com)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef FRONTEND_LEXICON_H_
16 | #define FRONTEND_LEXICON_H_
17 |
18 | #include <memory>
19 | #include <string>
20 | #include <unordered_map>
21 | #include <vector>
22 |
23 | namespace wetts {
24 |
25 | // Lexicon, the format is like
26 | // 今 jin1
27 | // 天 tian1
28 | // 好 hao3,hao4
29 |
30 | class Lexicon {
31 | public:
32 | static const char UNK[];
33 | explicit Lexicon(const std::string& lexicon_file);
34 | int NumProns(const std::string& word);
35 | const std::vector<std::string>& Prons(const std::string& word);
36 |
37 | private:
38 | std::unordered_map<std::string, std::vector<std::string>> lexicon_;
39 | std::vector<std::string> unk_;
40 | };
41 |
42 | } // namespace wetts
43 |
44 | #endif // FRONTEND_LEXICON_H_
45 |
--------------------------------------------------------------------------------
/runtime/core/http/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_library(http_server STATIC http_server.cc)
2 |
--------------------------------------------------------------------------------
/runtime/core/http/http_server.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef HTTP_HTTP_SERVER_H_
16 | #define HTTP_HTTP_SERVER_H_
17 |
18 | #include <memory>
19 | #include <string>
20 | #include <thread>
21 | #include <utility>
22 | #include <vector>
23 |
24 | #include "boost/asio/ip/tcp.hpp"
25 | #include "boost/beast/http.hpp"
26 |
27 | #include "model/tts_model.h"
28 |
29 | namespace wetts {
30 |
31 | namespace beast = boost::beast;    // from <boost/beast.hpp>
32 | namespace http = beast::http;      // from <boost/beast/http.hpp>
33 | namespace asio = boost::asio;      // from <boost/asio.hpp>
34 | using tcp = boost::asio::ip::tcp;  // from <boost/asio/ip/tcp.hpp>
35 |
36 | class ConnectionHandler {
37 | public:
38 | ConnectionHandler(tcp::socket&& socket, std::shared_ptr<TtsModel> tts_model)
39 | : socket_(std::move(socket)), tts_model_(std::move(tts_model)) {}
40 | void operator()();
41 | http::message_generator HandleRequest(const std::string& json_data);
42 |
43 | private:
44 | tcp::socket socket_;
45 | http::request<http::string_body> request_;
46 | std::shared_ptr<TtsModel> tts_model_;
47 | };
48 |
49 | class HttpServer {
50 | public:
51 | HttpServer(int port, std::shared_ptr<TtsModel> tts_model)
52 | : port_(port), tts_model_(tts_model) {}
53 |
54 | void Start();
55 |
56 | private:
57 | int port_;
58 | // The io_context is required for all I/O
59 | asio::io_context ioc_{1};
60 | std::shared_ptr<TtsModel> tts_model_;
61 | };
62 |
63 | } // namespace wetts
64 |
65 | #endif // HTTP_HTTP_SERVER_H_
66 |
--------------------------------------------------------------------------------
/runtime/core/model/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_library(onnx_model STATIC onnx_model.cc)
2 | target_link_libraries(onnx_model PUBLIC onnxruntime utils)
3 |
4 | add_library(tts_model STATIC tts_model.cc)
5 | target_link_libraries(tts_model PUBLIC wetext_processor frontend)
6 |
--------------------------------------------------------------------------------
/runtime/core/model/onnx_model.cc:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include <sstream>
16 | #include <utility>
17 |
18 | #include "model/onnx_model.h"
19 |
20 | #include "glog/logging.h"
21 |
22 | #include "utils/string.h"
23 |
24 | namespace wetts {
25 |
26 | Ort::Env OnnxModel::env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "");
27 | Ort::SessionOptions OnnxModel::session_options_ = Ort::SessionOptions();
28 |
29 | void OnnxModel::InitEngineThreads(int num_threads) {
30 | session_options_.SetIntraOpNumThreads(num_threads);
31 | session_options_.SetGraphOptimizationLevel(
32 | GraphOptimizationLevel::ORT_ENABLE_ALL);
33 | }
34 |
35 | OnnxModel::OnnxModel(const std::string& model_path) {
36 | InitEngineThreads(1);
37 | #ifdef _MSC_VER
38 | session_ = std::make_shared<Ort::Session>(env_, ToWString(model_path).c_str(),
39 | session_options_);
40 | #else
41 | session_ = std::make_shared<Ort::Session>(env_, model_path.c_str(),
42 | session_options_);
43 | #endif
44 | Ort::AllocatorWithDefaultOptions allocator;
45 | // Input info
46 | int num_nodes = session_->GetInputCount();
47 | input_node_names_.resize(num_nodes);
48 | for (int i = 0; i < num_nodes; ++i) {
49 | auto input_name = session_->GetInputNameAllocated(i, allocator);
50 | input_allocated_strings_.push_back(std::move(input_name));
51 | input_node_names_[i] = input_allocated_strings_[i].get();
52 | Ort::TypeInfo type_info = session_->GetInputTypeInfo(i);
53 | auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
54 | ONNXTensorElementDataType type = tensor_info.GetElementType();
55 | std::vector<int64_t> node_dims = tensor_info.GetShape();
56 | std::stringstream shape;
57 | for (auto j : node_dims) {
58 | shape << j;
59 | shape << " ";
60 | }
61 | LOG(INFO) << "Input " << i << " : name=" << input_node_names_[i]
62 | << " type=" << type << " dims=" << shape.str();
63 | }
64 |
65 | // Output info
66 | num_nodes = session_->GetOutputCount();
67 | output_node_names_.resize(num_nodes);
68 | for (int i = 0; i < num_nodes; ++i) {
69 | auto output_name = session_->GetOutputNameAllocated(i, allocator);
70 | output_allocated_strings_.push_back(std::move(output_name));
71 | output_node_names_[i] = output_allocated_strings_[i].get();
72 | Ort::TypeInfo type_info = session_->GetOutputTypeInfo(i);
73 | auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
74 | ONNXTensorElementDataType type = tensor_info.GetElementType();
75 | std::vector<int64_t> node_dims = tensor_info.GetShape();
76 | std::stringstream shape;
77 | for (auto j : node_dims) {
78 | shape << j;
79 | shape << " ";
80 | }
81 | LOG(INFO) << "Output " << i << " : name=" << output_node_names_[i]
82 | << " type=" << type << " dims=" << shape.str();
83 | }
84 | }
85 |
86 | } // namespace wetts
87 |
--------------------------------------------------------------------------------
/runtime/core/model/onnx_model.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef MODEL_ONNX_MODEL_H_
16 | #define MODEL_ONNX_MODEL_H_
17 |
18 | #include <memory>
19 | #include <string>
20 | #include <vector>
21 |
22 | #include "onnxruntime_cxx_api.h" // NOLINT
23 |
24 | namespace wetts {
25 |
26 | class OnnxModel {
27 | public:
28 | static void InitEngineThreads(int num_threads = 1);
29 | explicit OnnxModel(const std::string& model_path);
30 |
31 | protected:
32 | static Ort::Env env_;
33 | static Ort::SessionOptions session_options_;
34 |
35 | std::shared_ptr<Ort::Session> session_ = nullptr;
36 | Ort::MemoryInfo memory_info_ =
37 | Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
38 |
39 | std::vector<const char*> input_node_names_;
40 | std::vector<const char*> output_node_names_;
41 | std::vector<Ort::AllocatedStringPtr> input_allocated_strings_;
42 | std::vector<Ort::AllocatedStringPtr> output_allocated_strings_;
43 | };
44 |
45 | } // namespace wetts
46 |
47 | #endif // MODEL_ONNX_MODEL_H_
48 |
--------------------------------------------------------------------------------
/runtime/core/model/tts_model.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef MODEL_TTS_MODEL_H_
16 | #define MODEL_TTS_MODEL_H_
17 |
18 | #include <memory>
19 | #include <string>
20 | #include <unordered_map>
21 | #include <vector>
22 |
23 | #include "onnxruntime_cxx_api.h" // NOLINT
24 | #include "processor/wetext_processor.h"
25 |
26 | #include "frontend/g2p_prosody.h"
27 |
28 | namespace wetts {
29 |
30 | class TtsModel : public OnnxModel {
31 | public:
32 | explicit TtsModel(const std::string& model_path,
33 | const std::string& speaker2id,
34 | const std::string& phone2id,
35 | const int sampling_rate,
36 | std::shared_ptr<wetext::Processor> processor,
37 | std::shared_ptr<G2pProsody> g2p_prosody);
38 | void Forward(const std::vector<int64_t>& phonemes, const int sid,
39 | std::vector<float>* audio);
40 | void Synthesis(const std::string& text, const int sid,
41 | std::vector<float>* audio);
42 | int GetSid(const std::string& name);
43 | int sampling_rate() const { return sampling_rate_; }
44 |
45 | private:
46 | int sampling_rate_;
47 | std::unordered_map<std::string, int> phone2id_;
48 | std::unordered_map<std::string, int> speaker2id_;
49 | std::shared_ptr<wetext::Processor> tn_;
50 | std::shared_ptr<G2pProsody> g2p_prosody_;
51 | };
52 |
53 | } // namespace wetts
54 |
55 | #endif // MODEL_TTS_MODEL_H_
56 |
--------------------------------------------------------------------------------
/runtime/core/test/CMakeLists.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/core/test/CMakeLists.txt
--------------------------------------------------------------------------------
/runtime/core/utils/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_library(utils STATIC
2 | fst.cc
3 | string.cc
4 | utils.cc
5 | )
6 |
7 | target_link_libraries(utils PUBLIC glog fst)
8 |
--------------------------------------------------------------------------------
/runtime/core/utils/fst.cc:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023 Zhendong Peng (pzd17@tsinghua.org.cn)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "utils/fst.h"
16 |
17 | #include "fst/rmepsilon.h"
18 |
19 | namespace wetts {
20 |
21 | StdVectorFst ShortestPath(const std::string& input, const StdVectorFst* fst) {
22 | StdVectorFst input_fst;
23 | static StringCompiler compiler(BYTE);
24 | compiler(input, &input_fst);
25 |
26 | StdVectorFst lattice;
27 | fst::Compose(input_fst, *fst, &lattice);
28 | StdVectorFst shortest_path;
29 | fst::ShortestPath(lattice, &shortest_path, 1, true);
30 | return shortest_path;
31 | }
32 |
33 | void ShortestPath(const std::string& input, const StdVectorFst* fst,
34 | std::string* output) {
35 | StdVectorFst lattice = ShortestPath(input, fst);
36 | static StringPrinter printer(BYTE);
37 | printer(lattice, output);
38 | }
39 |
40 | void ShortestPath(const std::string& input, const StdVectorFst* fst,
41 | std::vector<int64_t>* olabels) {
42 | StdVectorFst lattice = ShortestPath(input, fst);
43 | fst::Project(&lattice, PROJECT_OUTPUT);
44 | fst::RmEpsilon(&lattice);
45 | fst::TopSort(&lattice);
46 |
47 | for (StateIterator siter(lattice); !siter.Done(); siter.Next()) {
48 | ArcIterator aiter(lattice, siter.Value());
49 | if (!aiter.Done()) {
50 | olabels->emplace_back(aiter.Value().olabel);
51 | }
52 | }
53 | }
54 |
55 | } // namespace wetts
56 |
--------------------------------------------------------------------------------
/runtime/core/utils/fst.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023 Zhendong Peng (pzd17@tsinghua.org.cn)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef UTILS_FST_H_
16 | #define UTILS_FST_H_
17 |
18 | #include <string>
19 | #include <vector>
20 |
21 | #include "fst/fstlib.h"
22 | #include "glog/logging.h"
23 |
24 | #include "utils/string.h"
25 |
26 | using fst::StdVectorFst;
27 | using fst::ProjectType::PROJECT_OUTPUT;
28 | using fst::StringTokenType::BYTE;
29 |
30 | using ArcIterator = fst::ArcIterator<StdVectorFst>;
31 | using StringPrinter = fst::StringPrinter<fst::StdArc>;
32 | using StateIterator = fst::StateIterator<StdVectorFst>;
33 | using StringCompiler = fst::StringCompiler<fst::StdArc>;
34 |
35 | namespace wetts {
36 |
37 | StdVectorFst ShortestPath(const std::string& input, const StdVectorFst* fst);
38 |
39 | void ShortestPath(const std::string& input, const StdVectorFst* fst,
40 | std::string* output);
41 |
42 | void ShortestPath(const std::string& input, const StdVectorFst* fst,
43 | std::vector<int64_t>* olabels);
44 |
45 | } // namespace wetts
46 |
47 | #endif // UTILS_FST_H_
48 |
--------------------------------------------------------------------------------
/runtime/core/utils/string.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
2 | // 2022 Zhendong Peng (pzd17@tsinghua.org.cn)
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 |
16 | #ifndef UTILS_STRING_H_
17 | #define UTILS_STRING_H_
18 |
19 | #include <algorithm>
20 | #include <cctype>
21 | #include <cstring>
22 | #include <string>
23 | #include <vector>
24 |
25 | namespace wetts {
26 |
27 | // kSpaceSymbol in UTF-8 is: ▁
28 | const char kSpaceSymbol[] = "\xe2\x96\x81";
29 |
30 | const char WHITESPACE[] = " \n\r\t\f\v";
31 |
32 | // Split the string with space or tab.
33 | void SplitString(const std::string& str, std::vector* strs);
34 |
35 | void SplitStringToVector(const std::string& full, const char* delim,
36 | bool omit_empty_strings,
37 | std::vector* out);
38 |
39 | // NOTE(Xingchen Song): we add this function to make it possible to
40 | // support multilingual recipe in the future, in which characters of
41 | // different languages are all encoded in UTF-8 format.
42 | // UTF-8 REF: https://en.wikipedia.org/wiki/UTF-8#Encoding
43 | // Split the UTF-8 string into chars.
44 | void SplitUTF8StringToChars(const std::string& str,
45 | std::vector* chars);
46 |
47 | int UTF8StringLength(const std::string& str);
48 |
49 | // Check whether the UTF-8 char is an alphabet letter or an apostrophe (').
50 | bool CheckEnglishChar(const std::string& ch);
51 |
52 | bool IsChineseChar(const std::string& ch);
53 |
54 | std::string AddSpaceForChineseChar(const std::string& str);
55 |
56 | // Check whether the UTF-8 word contains only alphabet letters or apostrophes (').
57 | bool CheckEnglishWord(const std::string& word);
58 |
59 | std::string JoinString(const std::string& c,
60 | const std::vector& strs);
61 |
62 | bool IsAlpha(const std::string& str);
63 |
64 | bool IsAlphaOrDigit(const std::string& str);
65 |
66 | // Replace ▁ with space, then remove head, tail and consecutive space.
67 | std::string ProcessBlank(const std::string& str, bool lowercase);
68 |
69 | std::string Ltrim(const std::string& str);
70 |
71 | std::string Rtrim(const std::string& str);
72 |
73 | std::string Trim(const std::string& str);
74 |
75 | std::string JoinPath(const std::string& left, const std::string& right);
76 |
77 | #ifdef _MSC_VER
78 | std::wstring ToWString(const std::string& str);
79 | #endif
80 |
81 | } // namespace wetts
82 |
83 | #endif // UTILS_STRING_H_
84 |
--------------------------------------------------------------------------------
/runtime/core/utils/timer.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef UTILS_TIMER_H_
16 | #define UTILS_TIMER_H_
17 |
18 | #include <chrono>
19 |
20 | namespace wetts {
21 |
22 | class Timer {
23 | public:
24 | Timer() : time_start_(std::chrono::steady_clock::now()) {}
25 | void Reset() { time_start_ = std::chrono::steady_clock::now(); }
26 | // return int in milliseconds
27 | int Elapsed() const {
28 | auto time_now = std::chrono::steady_clock::now();
29 | return std::chrono::duration_cast<std::chrono::milliseconds>(
30 | time_now - time_start_)
31 | .count();
32 | }
33 |
34 | private:
35 | std::chrono::time_point<std::chrono::steady_clock> time_start_;
36 | };
37 | } // namespace wetts
38 |
39 | #endif // UTILS_TIMER_H_
40 |
--------------------------------------------------------------------------------
/runtime/core/utils/utils.cc:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023 Zhendong Peng (pzd17@tsinghua.org.cn)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "utils/utils.h"
16 |
17 | namespace wetts {
18 |
19 | void ReadTableFile(const std::string& file,
20 | std::unordered_map<std::string, int>* map) {
21 | std::fstream infile(file);
22 | std::string left;
23 | int right;
24 | while (infile >> left >> right) {
25 | (*map)[left] = right;
26 | }
27 | }
28 |
29 | void ReadTableFile(const std::string& file,
30 | std::unordered_map<std::string, std::string>* map) {
31 | std::ifstream infile(file);
32 | std::string line;
33 | while (getline(infile, line)) {
34 | int pos = line.find_first_of(" \t", 0);
35 | std::string key = line.substr(0, pos);
36 | std::string value = line.substr(pos + 1, line.size() - pos);
37 | (*map)[key] = value;
38 | }
39 | }
40 |
41 | void ReadTableFile(
42 | const std::string& file,
43 | std::unordered_map<std::string, std::vector<std::string>>* map) {
44 | std::ifstream infile(file);
45 | std::string line;
46 | while (getline(infile, line)) {
47 | std::vector<std::string> strs;
48 | SplitString(line, &strs);
49 | CHECK_GE(strs.size(), 2);
50 | std::string key = strs[0];
51 | strs.erase(strs.begin());
52 | (*map)[key] = strs;
53 | }
54 | }
55 |
56 | } // namespace wetts
57 |
--------------------------------------------------------------------------------
/runtime/core/utils/utils.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023 Zhendong Peng (pzd17@tsinghua.org.cn)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef UTILS_UTILS_H_
16 | #define UTILS_UTILS_H_
17 |
18 | #include <fstream>
19 | #include <iostream>
20 | #include <string>
21 | #include <unordered_map>
22 | #include <vector>
23 |
24 | #include "glog/logging.h"
25 |
26 | #include "utils/string.h"
27 |
28 | namespace wetts {
29 |
30 | void ReadTableFile(const std::string& file,
31 | std::unordered_map<std::string, int>* map);
32 |
33 | void ReadTableFile(const std::string& file,
34 | std::unordered_map<std::string, std::string>* map);
35 |
36 | void ReadTableFile(
37 | const std::string& file,
38 | std::unordered_map<std::string, std::vector<std::string>>* map);
39 |
40 | } // namespace wetts
41 |
42 | #endif // UTILS_UTILS_H_
43 |
--------------------------------------------------------------------------------
/runtime/cpu_triton_stream/.gitignore:
--------------------------------------------------------------------------------
1 | lexicon.txt
2 | base.json
3 | phones.txt
4 | *.onnx
5 | python3.10/
6 | python3.10.tar.gz
7 | test_audios/
8 |
--------------------------------------------------------------------------------
/runtime/cpu_triton_stream/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM jackiexiao/tritonserver:23.10-onnx-py-cpu
2 | # FROM nvcr.io/nvidia/tritonserver:23.10-py3
3 | # https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html
4 | # Here we use a custom-built image instead of the official build for minimal image size
5 |
6 | RUN pip3 config set global.index-url https://mirrors.cloud.tencent.com/pypi/simple && \
7 | pip3 install --no-cache-dir pynini==2.1.5 pypinyin WeTextProcessing
8 |
9 | # if you want to include your own model, uncomment the following line
10 | # COPY ./model_repo /models
11 |
12 | ENV PYTHONIOENCODING=UTF-8
13 | # 100MB cache
14 | CMD tritonserver --model-repository=/models --cache-config local,size=104857600
--------------------------------------------------------------------------------
/runtime/cpu_triton_stream/Makefile:
--------------------------------------------------------------------------------
1 | # gpu/cpu image
2 | triton_image=nvcr.io/nvidia/tritonserver:23.10-py3
3 | # cpu only image (smaller image size)
4 | triton_image=jackiexiao/tritonserver:23.10-onnx-py-cpu
5 |
6 | exp_name=vits2_vocos_v1
7 | model_dir=${shell pwd}/model_repo
8 | repo_dir=${shell dirname ${shell dirname ${shell pwd}}}
9 | ckpt_step=200000
10 |
11 |
12 | # Build the server docker image:
13 | build_docker:
14 | docker build . -f Dockerfile -t tts_server:latest
15 |
16 | .PHONY: cp_asset
17 | # Copy the asset to the model repo
18 | cp_asset:
19 | cp ${repo_dir}/examples/baker/exp/${exp_name}/encoder_G_${ckpt_step}.onnx ${model_dir}/encoder/1/encoder.onnx
20 | cp ${repo_dir}/examples/baker/exp/${exp_name}/decoder_G_${ckpt_step}.onnx ${model_dir}/decoder/1/decoder.onnx
21 | cp ${repo_dir}/examples/baker/data/lexicon.txt ${model_dir}/lexicon.txt
22 | cp ${repo_dir}/examples/baker/data/phones.txt ${model_dir}/phones.txt
23 | cp ${repo_dir}/examples/baker/configs/${exp_name}.json ${model_dir}/base.json
24 |
25 | .PHONY: start_server
26 | start_server:
27 | docker run \
28 | --rm \
29 | --cpus 2 \
30 | -p8000:8000 -p8001:8001 -p8002:8002 \
31 | --shm-size=1g \
32 | -v ${model_dir}:/models \
33 | --name tts_triton_server \
34 | tts_server:latest \
35 | bash -c "tritonserver --model-repository=/models --cache-config local,size=104857600"
36 |
37 | # streaming client
38 | .PHONY: stream_client
39 | stream_client:
40 | cd client/ && python stream_client.py --text text.scp --outdir test_audios
41 |
42 | # non streaming client
43 | .PHONY: client
44 | client:
45 | cd client/ && python client.py --text text.scp --outdir test_audios
46 |
47 | .PHONY: web_ui
48 | web_ui:
49 | cd client/ && streamlit run web_ui.py
--------------------------------------------------------------------------------
/runtime/cpu_triton_stream/README.md:
--------------------------------------------------------------------------------
1 | # Streaming TTS CPU Triton Server
2 |
3 | ## Quick Start
4 | Run with a pre-built demo:
5 | - The VITS model in the Docker image was trained on the Baker dataset with configs/vits2_vocos_v1.json.
6 | - It was trained for only 200,000 steps, just for demonstration purposes.
7 |
8 | ```
9 | # server
10 | docker run -d -p8000:8000 -p8001:8001 jackiexiao/baker_tts_server:latest
11 |
12 | # stream client
13 | pip install -r requirements-client.txt
14 | cd client/ && python stream_client.py --text text.scp --outdir test_audios
15 | ```
16 |
17 | You will get results like the following:
18 | - Different CPUs may deliver different performance; the numbers below are for reference only.
19 | - CPU (1 core): Intel(R) Xeon(R) Platinum 8255C CPU @ 2.50GHz
20 |
21 | ```
22 | cd client/ && python3 stream_client.py --text text.scp --outdir test_audios
23 |
24 | 2|今天天气不好我们家里躺平吧
25 | chunk_id=0, chunk_latency=0.21, chunk_duration=0.75s
26 | chunk_id=1, chunk_latency=0.08, chunk_duration=0.75s
27 | chunk_id=2, chunk_latency=0.04, chunk_duration=0.75s
28 | chunk_id=3, chunk_latency=0.08, chunk_duration=0.75s
29 | chunk_id=4, chunk_latency=0.08, chunk_duration=0.22s
30 | dur=3.21, rtf=0.15, first_latency=0.211
31 | ```
32 |
33 | ## Usage / Commands
34 |
35 | See the Makefile for details. For example:
36 | ```sh
37 | make build_docker
38 | make cp_asset
39 | make start_server
40 | make stream_client
41 | make client
42 | ```
43 |
44 | You need to train and export a streaming model first. For example, go to wetts/examples/baker and run:
45 | ```
46 | # train
47 | bash run.sh --stage 0 --stop_stage 1
48 | # export streaming model
49 | bash run.sh --stage 4 --stop_stage 4
50 | ```
51 |
52 | ## PS
53 | - I enabled response_cache in model_repo/tts; to disable it, comment out `response_cache` in model_repo/tts/config.pbtxt.
54 | - The CPU-only triton server image `jackiexiao/tritonserver:23.10-onnx-py-cpu` is built from the triton server source and is only 337.83 MB (compressed size). See below for details.
55 |
56 | ```
57 | git clone https://github.com/triton-inference-server/server && cd server
58 |
59 | version=23.10
60 | git checkout r${version}
61 | python3 build.py \
62 | --enable-logging --enable-stats --enable-tracing --enable-metrics --enable-cpu-metrics \
63 | --cache=local --cache=redis \
64 | --endpoint=http --endpoint=grpc \
65 | --backend=ensemble \
66 | --backend=python \
67 | --backend=onnxruntime
68 |
69 | docker tag tritonserver:latest tritonserver:${version}-onnx-py-cpu
70 | ```
71 |
72 | ## Reference
73 | https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html
--------------------------------------------------------------------------------
/runtime/cpu_triton_stream/client/text.scp:
--------------------------------------------------------------------------------
1 | wav1|今天天气不错我们一起去爬山
2 | wav2|今天天气不好我们家里躺平吧
3 |
--------------------------------------------------------------------------------
/runtime/cpu_triton_stream/model_repo/decoder/1/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/cpu_triton_stream/model_repo/decoder/1/.gitkeep
--------------------------------------------------------------------------------
/runtime/cpu_triton_stream/model_repo/decoder/config.pbtxt:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | name: "decoder"
16 | backend: "onnxruntime"
17 | default_model_filename: "decoder.onnx"
18 |
19 | max_batch_size: 32
20 |
21 | input [
22 | {
23 | name: "z"
24 | data_type: TYPE_FP32
25 | dims: [192, -1] # (model.inter_channels, -1)
26 | },
27 | {
28 | name: "g"
29 | data_type: TYPE_FP32
30 | dims: [256, 1] # (model.gin_channels, 1)
31 | }
32 | ]
33 | output [
34 | {
35 | name: "output"
36 | data_type: TYPE_FP32
37 | dims: [-1, -1]
38 | }
39 | ]
40 | dynamic_batching {
41 | }
42 | instance_group [
43 | {
44 | count: 1
45 | kind: KIND_CPU
46 | }
47 | ]
48 | optimization { execution_accelerators {
49 | cpu_execution_accelerator : [ {
50 | name : "openvino"
51 | } ]
52 | }}
53 |
54 | parameters [
55 | {
56 | key: "intra_op_thread_count"
57 | value: { string_value: "0" }
58 | },
59 | {
60 | key: "inter_op_thread_count"
61 | value: { string_value: "0" }
62 | }
63 | ]
64 |
65 | model_warmup [{
66 | name: "zero_value_warmup"
67 | batch_size: 1
68 | inputs[
69 | {
70 | key: "z"
71 | value: {
72 | data_type: TYPE_FP32
73 | dims: [192, 100]
74 | zero_data: true
75 | }
76 | },
77 | {
78 | key: "g"
79 | value: {
80 | data_type: TYPE_FP32
81 | dims: [256, 1]
82 | zero_data: true
83 | }
84 | }
85 | ]
86 | }]
--------------------------------------------------------------------------------
/runtime/cpu_triton_stream/model_repo/encoder/1/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/cpu_triton_stream/model_repo/encoder/1/.gitkeep
--------------------------------------------------------------------------------
/runtime/cpu_triton_stream/model_repo/encoder/config.pbtxt:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | name: "encoder"
16 | backend: "onnxruntime"
17 | default_model_filename: "encoder.onnx"
18 |
19 | max_batch_size: 32
20 |
21 | input [
22 | {
23 | name: "input"
24 | data_type: TYPE_INT64
25 | dims: [-1]
26 | },
27 | {
28 | name: "input_lengths"
29 | data_type: TYPE_INT64
30 | dims: [1]
31 | reshape: { shape: [ ] }
32 | },
33 | {
34 | name: "scales"
35 | data_type: TYPE_FP32
36 | dims: [3]
37 | },
38 | {
39 | name: "sid"
40 | data_type: TYPE_INT64
41 | dims: [1]
42 | reshape: { shape: [ ] }
43 | }
44 | ]
45 | output [
46 | {
47 | name: "z"
48 | data_type: TYPE_FP32
49 | dims: [192, -1] # (model.inter_channels, -1)
50 | },
51 | {
52 | name: "g"
53 | data_type: TYPE_FP32
54 | dims: [256, 1] # (model.gin_channels, 1)
55 | }
56 | ]
57 | dynamic_batching {
58 | }
59 | instance_group [
60 | {
61 | count: 1
62 | kind: KIND_CPU
63 | }
64 | ]
65 |
66 | optimization { execution_accelerators {
67 | cpu_execution_accelerator : [ {
68 | name : "openvino"
69 | } ]
70 | }}
71 |
72 | parameters [
73 | {
74 | key: "intra_op_thread_count"
75 | value: { string_value: "1" }
76 | },
77 | {
78 | key: "inter_op_thread_count"
79 | value: { string_value: "0" }
80 | }
81 | ]
82 |
83 | model_warmup [{
84 | name: "zero_value_warmup"
85 | batch_size: 1
86 | inputs[
87 | {
88 | key: "input"
89 | value: {
90 | data_type: TYPE_INT64
91 | dims: [20]
92 | zero_data: true
93 | }
94 | },
95 | {
96 | key: "input_lengths"
97 | value: {
98 | data_type: TYPE_INT64
99 | dims: [1]
100 | zero_data: true
101 | }
102 | },
103 | {
104 | key: "scales"
105 | value: {
106 | data_type: TYPE_FP32
107 | dims: [3]
108 | zero_data: true
109 | }
110 | },
111 | {
112 | key: "sid"
113 | value: {
114 | data_type: TYPE_INT64
115 | dims: [1]
116 | zero_data: true
117 | }
118 | }
119 | ]
120 | }]
--------------------------------------------------------------------------------
/runtime/cpu_triton_stream/model_repo/stream_tts/config.pbtxt:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | name: "stream_tts"
16 | backend: "python"
17 | max_batch_size: 0
18 |
19 | model_transaction_policy {
20 | decoupled: True
21 | }
22 |
23 | input [
24 | {
25 | name: "text"
26 | data_type: TYPE_STRING
27 | dims: [1]
28 | }
29 | ]
30 | output [
31 | {
32 | name: "wav"
33 | data_type: TYPE_INT16
34 | dims: [-1]
35 | }
36 | ]
37 |
38 | dynamic_batching {
39 | }
40 |
41 | instance_group [
42 | {
43 | count: 1
44 | kind: KIND_CPU
45 | }
46 | ]
--------------------------------------------------------------------------------
/runtime/cpu_triton_stream/model_repo/tts/config.pbtxt:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | name: "tts"
16 | backend: "python"
17 | max_batch_size: 32
18 |
19 | input [
20 | {
21 | name: "text"
22 | data_type: TYPE_STRING
23 | dims: [1]
24 | reshape: { shape: [] }
25 | }
26 | ]
27 | output [
28 | {
29 | name: "wav"
30 | data_type: TYPE_INT16
31 | dims: [-1]
32 | }
33 | ]
34 | dynamic_batching {
35 | }
36 | instance_group [
37 | {
38 | count: 1
39 | kind: KIND_CPU
40 | }
41 | ]
42 | # see: https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/response_cache.html
43 | response_cache {
44 | enable: True
45 | }
--------------------------------------------------------------------------------
/runtime/cpu_triton_stream/requirements-client.txt:
--------------------------------------------------------------------------------
1 | tritonclient[all]
2 | scipy
3 | soundfile
--------------------------------------------------------------------------------
/runtime/cpu_triton_stream/requirements-web.txt:
--------------------------------------------------------------------------------
1 | tritonclient[all]
2 | scipy
3 | soundfile
4 | pydub
5 | stqdm
6 | streamlit
--------------------------------------------------------------------------------
/runtime/gpu_triton/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/tritonserver:22.09-py3
2 | # https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html
3 | # Please choose previous tritonserver:xx.xx if you encounter cuda driver mismatch issue
4 |
5 | LABEL maintainer="NVIDIA"
6 | LABEL repository="tritonserver"
7 |
8 | RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
9 | RUN pip3 install pypinyin scipy WeTextProcessing grpcio-tools tritonclient
10 | WORKDIR /workspace
11 |
--------------------------------------------------------------------------------
/runtime/gpu_triton/README.md:
--------------------------------------------------------------------------------
1 | ### TTS Triton Server
2 |
3 |
4 | ```
5 |
6 | # Build the server docker image:
7 | docker build . -f Dockerfile -t tts_server:latest
8 | # start the docker server
9 | docker run --gpus all -v <host_dir>:<container_dir> --name tts_server --net host -it tts_server:latest
10 |
11 |
12 | # export to onnx
13 | python3 vits/export_onnx.py \
14 | --checkpoint logs/exp/base/G_0.pth \
15 | --cfg configs/base.json \
16 | --onnx_model ./logs/exp/base/generator.onnx \
17 | --providers CUDAExecutionProvider \
18 | --phone data/phones.txt
19 |
20 | # model repo preparation
21 | cp generator.onnx model_repo/generator/1/
22 | # please modify the hard-coded paths in model_repo/tts/config.pbtxt
23 |
24 | # start server (inside the container)
25 | CUDA_VISIBLE_DEVICES="0" tritonserver --model-repository model_repo
26 |
27 | # start client (inside the container)
28 | python3 client.py --text text.scp --outdir test_audios
29 |
30 | # test with triton perf_analyzer tool (inside the docker)
31 | python3 generate_input.py --text text.scp
32 | perf_analyzer -m tts -b 1 -a -p 20000 --concurrency-range 100:200:50 -i gRPC --input-data=./input.json -u localhost:8001
33 | ```
--------------------------------------------------------------------------------
/runtime/gpu_triton/client/generate_input.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import json
4 |
5 | if __name__ == '__main__':
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument(
8 | '--text',
9 | type=str,
10 | required=True,
11 | default=None,
12 | help='a text file'
13 | )
14 | FLAGS = parser.parse_args()
15 |
16 | data = {"data": []}
17 | with open(FLAGS.text, "r", encoding="utf-8") as f:
18 | for line in f:
19 | audio_name, audio_text = line.strip().split("|", 1)
20 | li = {"text": [audio_text.strip('\n')]}
21 | data["data"].append(li)
22 | json.dump(data, open("input.json", "w", encoding="utf-8"), ensure_ascii=False)
23 |
--------------------------------------------------------------------------------
/runtime/gpu_triton/client/text.scp:
--------------------------------------------------------------------------------
1 | wav1 | 今天天气不错
2 | wav2 | 我们一起去爬山
3 |
--------------------------------------------------------------------------------
/runtime/gpu_triton/model_repo/generator/1/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/runtime/gpu_triton/model_repo/generator/1/.gitkeep
--------------------------------------------------------------------------------
/runtime/gpu_triton/model_repo/generator/config.pbtxt:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | name: "generator"
16 | backend: "onnxruntime"
17 | default_model_filename: "generator.onnx"
18 |
19 | max_batch_size: 32
20 |
21 | input [
22 | {
23 | name: "input"
24 | data_type: TYPE_INT64
25 | dims: [-1]
26 | },
27 | {
28 | name: "input_lengths"
29 | data_type: TYPE_INT64
30 | dims: [1]
31 | reshape: { shape: [ ] }
32 | },
33 | {
34 | name: "scales"
35 | data_type: TYPE_FP32
36 | dims: [3]
37 | }
38 | ]
39 | output [
40 | {
41 | name: "output"
42 | data_type: TYPE_FP32
43 | dims: [-1,-1]
44 | }
45 | ]
46 | dynamic_batching {
47 | }
48 | instance_group [
49 | {
50 | count: 2
51 | kind: KIND_GPU
52 | }
53 | ]
54 |
--------------------------------------------------------------------------------
/runtime/gpu_triton/model_repo/tts/config.pbtxt:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | name: "tts"
16 | backend: "python"
17 | max_batch_size: 32
18 |
19 | parameters [
20 | {
21 | key: "config",
22 | value: { string_value: "/mnt/samsung-t7/yuekai/tts/wetts/examples/baker/configs/base.json"}
23 | },
24 | {
25 | key: "token_dict"
26 | value: { string_value: "/mnt/samsung-t7/yuekai/tts/wetts/examples/baker/data/phones.txt"}
27 | },
28 | {
29 | key: "pinyin_lexicon"
30 | value: { string_value: "/mnt/samsung-t7/yuekai/tts/wetts/examples/baker/data/lexicon.txt"}
31 | }
32 | ]
33 | input [
34 | {
35 | name: "text"
36 | data_type: TYPE_STRING
37 | dims: [1]
38 | reshape: { shape: [] }
39 | }
40 | ]
41 | output [
42 | {
43 | name: "wav"
44 | data_type: TYPE_INT16
45 | dims: [-1]
46 | }
47 | ]
48 | dynamic_batching {
49 | }
50 | instance_group [
51 | {
52 | count: 2
53 | kind: KIND_CPU
54 | }
55 | ]
--------------------------------------------------------------------------------
/runtime/onnxruntime/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.25 FATAL_ERROR)
2 |
3 | project(wetts VERSION 0.1)
4 |
5 | set(CMAKE_VERBOSE_MAKEFILE OFF)
6 | option(BUILD_TESTING "whether to build unit test" OFF)
7 | option(BUILD_SERVER "whether to build server binaries" OFF)
8 | option(ONNX "whether to build with ONNX" ON)
9 |
10 | include(FetchContent)
11 | set(FETCHCONTENT_QUIET OFF)
12 | get_filename_component(fc_base "fc_base-${CMAKE_CXX_COMPILER_ID}" REALPATH BASE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
13 | set(FETCHCONTENT_BASE_DIR ${fc_base})
14 |
15 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
16 |
17 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -pthread")
18 | include_directories(${CMAKE_CURRENT_SOURCE_DIR})
19 |
20 | include(glog)
21 | include(gflags)
22 | include(onnxruntime)
23 | include(wetextprocessing)
24 |
25 | add_subdirectory(utils)
26 | add_subdirectory(model)
27 | add_subdirectory(frontend)
28 |
29 | if(BUILD_SERVER)
30 | include(boost)
31 | include(jsoncpp)
32 | add_subdirectory(http)
33 | endif()
34 |
35 | add_subdirectory(bin)
36 |
--------------------------------------------------------------------------------
/runtime/onnxruntime/bin:
--------------------------------------------------------------------------------
1 | ../core/bin
--------------------------------------------------------------------------------
/runtime/onnxruntime/cmake:
--------------------------------------------------------------------------------
1 | ../core/cmake
--------------------------------------------------------------------------------
/runtime/onnxruntime/frontend:
--------------------------------------------------------------------------------
1 | ../core/frontend
--------------------------------------------------------------------------------
/runtime/onnxruntime/http:
--------------------------------------------------------------------------------
1 | ../core/http
--------------------------------------------------------------------------------
/runtime/onnxruntime/model:
--------------------------------------------------------------------------------
1 | ../core/model
--------------------------------------------------------------------------------
/runtime/onnxruntime/utils:
--------------------------------------------------------------------------------
1 | ../core/utils/
--------------------------------------------------------------------------------
/runtime/web/README.md:
--------------------------------------------------------------------------------
1 | ## WeTTS Web Demo
2 |
3 | * How to install? `pip install -r requirements.txt`
4 | * How to start? `python app.py`
5 |
--------------------------------------------------------------------------------
/runtime/web/app.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, Shengqiang Li (shengqiang.li96@gmail.com)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import gradio as gr
17 | from wetts.cli.model import load_model
18 |
19 |
20 | def main():
21 | title = "End-to-End Speech Synthesis in WeTTS | 基于 WeTTS 的端到端语音合成"
22 | description = "WeTTS Demo"
23 | inputs = [gr.Textbox(label="text")]
24 | phones = gr.Textbox(label="phones")
25 | audio = gr.Audio(label="audio")
26 | outputs = [phones, audio]
27 | gr.Interface(
28 | synthesis,
29 | title=title,
30 | description=description,
31 | inputs=inputs,
32 | outputs=outputs
33 | ).launch(server_name='0.0.0.0', share=True)
34 |
35 |
36 | def synthesis(text):
37 | model = load_model()
38 | phones, audio = model.synthesis(text)
39 | sampling_rate = 16000
40 | return ' '.join(phones), (sampling_rate, audio)
41 |
42 |
43 | if __name__ == '__main__':
44 | main()
45 |
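Note that `synthesis` above calls `load_model()` on every request, so the ONNX sessions are rebuilt each time. A minimal variant (a sketch, not part of the repo) caches the model at module scope so loading happens once at startup:

```
model = load_model()  # load once when the app starts


def synthesis(text):
    phones, audio = model.synthesis(text)
    return ' '.join(phones), (16000, audio)
```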
--------------------------------------------------------------------------------
/runtime/web/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/wenet-e2e/wetts.git
2 | requests
3 | gradio
4 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = wetts
3 | version = 0.0.0
4 | license = Apache Software License
5 | description = End-to-end speech synthesis toolkit
6 | long_description = file: README.md
7 | classifiers =
8 | License :: OSI Approved :: Apache Software License
9 | Operating System :: OS Independent
10 | Programming Language :: Python :: 3
11 |
12 | [options]
13 | packages = find:
14 | include_package_data = True
15 | python_requires = >= 3.8
16 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | requirements = [
4 | "tqdm",
5 | "scipy",
6 | "onnxruntime",
7 | ]
8 |
9 | setup(
10 | name="wetts",
11 | install_requires=requirements,
12 | packages=find_packages(),
13 | entry_points={"console_scripts": [
14 | "wetts = wetts.cli.tts:main",
15 | ]},
16 | )
17 |
--------------------------------------------------------------------------------
/tools/cleaners.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 | from g2p_en import G2p
5 |
6 |
7 | g2p = G2p()
8 |
9 | # List of (regular expression, replacement) pairs for abbreviations:
10 | _abbreviations = [
11 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
12 | for x in [
13 | ("mrs", "misess"),
14 | ("mr", "mister"),
15 | ("dr", "doctor"),
16 | ("st", "saint"),
17 | ("co", "company"),
18 | ("jr", "junior"),
19 | ("maj", "major"),
20 | ("gen", "general"),
21 | ("drs", "doctors"),
22 | ("rev", "reverend"),
23 | ("lt", "lieutenant"),
24 | ("hon", "honorable"),
25 | ("sgt", "sergeant"),
26 | ("capt", "captain"),
27 | ("esq", "esquire"),
28 | ("ltd", "limited"),
29 | ("col", "colonel"),
30 | ("ft", "fort"),
31 | ]
32 | ]
33 |
34 | _prosodies = ["#0", "#1", "#2", "#3", "#4"]
35 |
36 |
37 | def expand_abbreviations(text):
38 | for regex, replacement in _abbreviations:
39 | text = re.sub(regex, replacement, text)
40 | return text
41 |
42 |
43 | def filter(phonemes, use_prosody):
44 | phones = []
45 | if not use_prosody:
46 | for phoneme in phonemes:
47 | is_symbol = re.match("^[-,!?.' ]+$", phoneme)
48 | if not is_symbol:
49 | phones.append(phoneme)
50 | return phones
51 |
52 | for phoneme in phonemes:
53 | if re.match("^[']+$", phoneme):
54 | continue
55 | elif re.match("^[- ]+$", phoneme):
56 | if len(phones) > 0 and "#" not in phones[-1]:
57 | phones.append(_prosodies[1])
58 | elif re.match("^[,!?.]+$", phoneme):
59 | if len(phones) > 0 and "#" in phones[-1]:
60 |             phones[-1] = max(phones[-1], _prosodies[3])  # lexicographic max works for "#0".."#4"
61 | else:
62 | phones.append(_prosodies[3])
63 | else:
64 | phones.append(phoneme)
65 | if "#" in phones[-1]:
66 | phones[-1] = _prosodies[-1]
67 | else:
68 | phones.append(_prosodies[-1])
69 | return phones
70 |
71 |
72 | def english_cleaners(text, use_prosody):
73 | """Pipeline for English text, including abbreviation expansion."""
74 | text = text.lower()
75 | text = expand_abbreviations(text)
76 | phonemes = g2p(text)
77 | phonemes = filter(phonemes, use_prosody)
78 | return phonemes
79 |
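A quick sanity check of the pipeline above (assuming `g2p_en` is installed; the exact ARPAbet output depends on the g2p_en version):

```
phones = english_cleaners("Dr. Smith lives on St. John St.", use_prosody=True)
print(phones)
# Abbreviations are expanded ("dr." -> "doctor"), punctuation and spaces are
# mapped to #1/#3 breaks, and the sequence ends with a forced #4 break.
```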
--------------------------------------------------------------------------------
/tools/compute_spec_length.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # author: @lsrami
3 |
4 | import os
5 | import sys
6 | import json
7 | from tqdm import tqdm
8 | import soundfile as sf
9 | from concurrent.futures import ThreadPoolExecutor
10 |
11 |
12 | def load_filepaths_and_text(filename, split="|"):
13 | with open(filename, encoding="utf-8") as f:
14 | filepaths_and_text = [line.strip().split(split) for line in f]
15 | return filepaths_and_text
16 |
17 |
18 | def process_item(item):
19 | audiopath = item[0]
20 | src_sampling_rate = sf.info(audiopath).samplerate
21 | text = item[2]
22 | text = text.strip().split()
23 | if min_text_len <= len(text) and len(text) <= max_text_len:
24 |         length = int(os.path.getsize(audiopath) * sampling_rate /
25 |                      src_sampling_rate) // (2 * hop_length)  # 2 bytes per 16-bit sample
26 | item.append(length)
27 | return item
28 | else:
29 | return None
30 |
31 |
32 | def main(in_file, out_file):
33 | """
34 | Filter text & store spec lengths
35 | """
36 |
37 | audiopaths_sid_text = load_filepaths_and_text(in_file, split="|")
38 |
39 | with ThreadPoolExecutor(max_workers=32) as executor:
40 | results = list(
41 | tqdm(
42 | executor.map(process_item, audiopaths_sid_text),
43 | total=len(audiopaths_sid_text),
44 | )
45 | )
46 |
47 | # Filter out None results
48 | results = [result for result in results if result is not None]
49 |
50 | with open(out_file, "w", encoding="utf-8") as f:
51 | for item in results:
52 | f.write("|".join([str(i) for i in item]) + "\n")
53 |
54 |
55 | if __name__ == "__main__":
56 | if len(sys.argv) != 4:
57 |         print(f"Usage: {sys.argv[0]} <in_file> <config_file> <out_file>")
58 | sys.exit(1)
59 | in_file, config_file, out_file = sys.argv[1:4]
60 |
61 | with open(config_file, "r", encoding="utf8") as f:
62 | data = f.read()
63 | config = json.loads(data)
64 | hparams = config["data"]
65 |
66 | min_text_len = hparams.get("min_text_len", 1)
67 | max_text_len = hparams.get("max_text_len", 190)
68 | sampling_rate = hparams.get("sampling_rate", 22050)
69 | hop_length = hparams.get("hop_length", 256)
70 | print(min_text_len, max_text_len, sampling_rate, hop_length)
71 |
72 | main(in_file, out_file)
73 |
--------------------------------------------------------------------------------
/tools/parse_options.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
4 | # Arnab Ghoshal, Karel Vesely
5 |
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
15 | # MERCHANTABILITY OR NON-INFRINGEMENT.
16 | # See the Apache 2 License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 |
20 | # Parse command-line options.
21 | # To be sourced by another script (as in ". parse_options.sh").
22 | # Option format is: --option-name arg
23 | # and shell variable "option_name" gets set to value "arg."
24 | # The exception is --help, which takes no arguments, but prints the
25 | # $help_message variable (if defined).
26 |
27 |
28 | ###
29 | ### The --config file options have lower priority than command-line
30 | ### options, so we need to import them first...
31 | ###
32 |
33 | # Now import all the configs specified by command-line, in left-to-right order
34 | for ((argpos=1; argpos<$#; argpos++)); do
35 | if [ "${!argpos}" == "--config" ]; then
36 | argpos_plus1=$((argpos+1))
37 | config=${!argpos_plus1}
38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
39 | . $config # source the config file.
40 | fi
41 | done
42 |
43 |
44 | ###
45 | ### Now we process the command-line options
46 | ###
47 | while true; do
48 | [ -z "${1:-}" ] && break; # break if there are no arguments
49 | case "$1" in
50 | # If the enclosing script is called with --help option, print the help
51 | # message and exit. Scripts should put help messages in $help_message
52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
53 | else printf "$help_message\n" 1>&2 ; fi;
54 | exit 0 ;;
55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
56 | exit 1 ;;
57 | # If the first command-line argument begins with "--" (e.g. --foo-bar),
58 | # then work out the variable name as $name, which will equal "foo_bar".
59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
60 |       # Next we test whether the variable in question is undefined-- if so it's
61 | # an invalid option and we die. Note: $0 evaluates to the name of the
62 | # enclosing script.
63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
64 | # is undefined. We then have to wrap this test inside "eval" because
65 | # foo_bar is itself inside a variable ($name).
66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
67 |
68 | oldval="`eval echo \\$$name`";
69 | # Work out whether we seem to be expecting a Boolean argument.
70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
71 | was_bool=true;
72 | else
73 | was_bool=false;
74 | fi
75 |
76 | # Set the variable to the right value-- the escaped quotes make it work if
77 | # the option had spaces, like --cmd "queue.pl -sync y"
78 | eval $name=\"$2\";
79 |
80 | # Check that Boolean-valued arguments are really Boolean.
81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
83 | exit 1;
84 | fi
85 | shift 2;
86 | ;;
87 | *) break;
88 | esac
89 | done
90 |
91 |
92 | # Check for an empty argument to the --cmd option, which can easily occur as a
93 | # result of scripting errors.
94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
95 |
96 |
97 | true; # so this script returns exit code 0.
98 |
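For context, the idiomatic Kaldi-style usage is to declare each option's default first, then source this file; options are only accepted if the variable already exists. A minimal sketch with a hypothetical `--stage` option:

```
#!/bin/bash
stage=0   # default; override with --stage N
help_message="Usage: $0 [--stage N]"
. tools/parse_options.sh || exit 1
echo "stage is $stage"
```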
--------------------------------------------------------------------------------
/wetts/__init__.py:
--------------------------------------------------------------------------------
1 | from wetts.cli.model import load_model # noqa
2 |
--------------------------------------------------------------------------------
/wetts/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenet-e2e/wetts/0abb5117171b305f9150feba5d9bb1c1796088b9/wetts/cli/__init__.py
--------------------------------------------------------------------------------
/wetts/cli/frontend.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Binbin Zhang (binbzha@qq.com)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import onnxruntime as ort
18 | import numpy as np
19 |
20 |
21 | class Frontend:
22 | def __init__(self, model_dir: str):
23 | self.session = ort.InferenceSession(
24 | os.path.join(model_dir, 'final.onnx'))
25 | self.token2id = self.read_list(os.path.join(model_dir, 'vocab.txt'))
26 | self.polyphone2id = self.read_list(
27 | os.path.join(model_dir, 'lexicon', 'polyphone.txt'))
28 | self.id2polyphone = {v: k for k, v in self.polyphone2id.items()}
29 | self.char2pinyins = self.read_char2pinyins(
30 | os.path.join(model_dir, 'lexicon', 'pinyin_dict.txt'))
31 | self.pinyin2phones = self.read_pinyin2phones(
32 | os.path.join(model_dir, 'lexicon', 'lexicon.txt'))
33 |
34 | def read_list(self, fname: str):
35 | table = {}
36 | with open(fname) as fin:
37 | for i, line in enumerate(fin):
38 | table[line.strip()] = i
39 | return table
40 |
41 | def read_char2pinyins(self, fname: str):
42 | table = {}
43 | with open(fname) as fin:
44 | for line in fin:
45 | arr = line.split()
46 | assert len(arr) == 2
47 | char, pinyins = arr[0], arr[1].split(',')
48 | table[char] = pinyins
49 | return table
50 |
51 | def read_pinyin2phones(self, fname: str):
52 | table = {}
53 | with open(fname) as fin:
54 | for line in fin:
55 | arr = line.split()
56 | assert len(arr) >= 2
57 | pinyin, phones = arr[0], arr[1:]
58 | table[pinyin] = phones
59 | return table
60 |
61 | def compute(self, text: str):
62 | # TODO(Binbin Zhang): Support English/Mix Code
63 | tokens = ['[CLS]'] + [str(x) for x in text] + ['[SEP]']
64 | token_ids = [self.token2id[x] for x in tokens]
65 | outputs = self.session.run(
66 | None, {'input': np.expand_dims(np.array(token_ids), axis=0)})
67 | pinyin_prob, prosody_prob = outputs[0][0], outputs[1][0]
68 | pinyins = []
69 | for i in range(1, len(tokens) - 1):
70 | x = tokens[i]
71 | arr = self.char2pinyins[x]
72 | if len(arr) > 1:
73 | poly_probs = [
74 | pinyin_prob[i][self.polyphone2id[p]] for p in arr
75 | ]
76 | max_idx = poly_probs.index(max(poly_probs))
77 | pinyins.append(arr[max_idx])
78 | else:
79 | pinyins.append(arr[0])
80 | prosodys = prosody_prob.argmax(axis=1).tolist()
81 | outputs = ['sil']
82 | for i in range(len(pinyins)):
83 | outputs.extend(self.pinyin2phones[pinyins[i]])
84 | outputs.append('#{}'.format(prosodys[i]))
85 |         outputs[-1] = '#4'  # force a sentence-final break
86 | return outputs
87 |
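A usage sketch of the class above (the path is illustrative; `model_dir` must contain the `final.onnx`, `vocab.txt`, and `lexicon` files listed in `__init__`):

```
frontend = Frontend('path/to/frontend_model_dir')  # hypothetical path
print(frontend.compute('你好'))
# A list starting with 'sil' and ending with a forced '#4' break, with a
# '#n' prosody token after each character's phones.
```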
--------------------------------------------------------------------------------
/wetts/cli/hub.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Mddct(hamddct@gmail.com)
2 | # 2023 Binbin Zhang(binbzha@qq.com)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | import requests
18 | import sys
19 | import tarfile
20 | from pathlib import Path
21 | from urllib.request import urlretrieve
22 |
23 | import tqdm
24 |
25 |
26 | def download(url: str, dest: str):
27 | """ download from url to dest
28 | """
29 | assert os.path.exists(dest)
30 | print('Downloading {} to {}'.format(url, dest))
31 |
32 | def progress_hook(t):
33 | last_b = [0]
34 |
35 | def update_to(b=1, bsize=1, tsize=None):
36 | if tsize not in (None, -1):
37 | t.total = tsize
38 | displayed = t.update((b - last_b[0]) * bsize)
39 | last_b[0] = b
40 | return displayed
41 |
42 | return update_to
43 |
44 | # *.tar.gz
45 | name = url.split('?')[0].split('/')[-1]
46 | tar_path = os.path.join(dest, name)
47 | with tqdm.tqdm(unit='B',
48 | unit_scale=True,
49 | unit_divisor=1024,
50 | miniters=1,
51 | desc=(name)) as t:
52 | urlretrieve(url,
53 | filename=tar_path,
54 | reporthook=progress_hook(t),
55 | data=None)
56 | t.total = t.n
57 |
58 | with tarfile.open(tar_path) as f:
59 | for tarinfo in f:
60 | if "/" not in tarinfo.name or tarinfo.isdir():
61 | continue
62 | name = tarinfo.name[tarinfo.name.find('/') + 1:]
63 | save_path = os.path.join(dest, name)
64 | print('Extracting to {}'.format(save_path))
65 | dir_name = os.path.dirname(save_path)
66 | if not os.path.exists(dir_name):
67 | os.makedirs(dir_name)
68 | fileobj = f.extractfile(tarinfo)
69 | with open(save_path, "wb") as writer:
70 | writer.write(fileobj.read())
71 |
72 |
73 | class Hub(object):
74 | Assets = {
75 | "frontend": "baker_bert_onnx.tar.gz",
76 | "multilingual": "multilingual_vits_v3_onnx.tar.gz",
77 | }
78 |
79 | def __init__(self) -> None:
80 | pass
81 |
82 | @staticmethod
83 | def get_model(key: str) -> str:
84 | if key not in Hub.Assets.keys():
85 | print('ERROR: Unsupported key {} !!!'.format(key))
86 | sys.exit(1)
87 | model = Hub.Assets[key]
88 | model_dir = os.path.join(Path.home(), ".wetts", key)
89 | if not os.path.exists(model_dir):
90 | os.makedirs(model_dir)
91 | response = requests.get(
92 | "https://modelscope.cn/api/v1/datasets/wenet/wetts_pretrained_models/oss/tree" # noqa
93 | )
94 | model_info = next(data for data in response.json()["Data"]
95 | if data["Key"] == model)
96 | model_url = model_info['Url']
97 | download(model_url, model_dir)
98 | return model_dir
99 |
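So `Hub.get_model` downloads and unpacks an asset on first use and afterwards simply returns the cached directory under `~/.wetts`:

```
from wetts.cli.hub import Hub

front_dir = Hub.get_model('frontend')        # -> ~/.wetts/frontend
backend_dir = Hub.get_model('multilingual')  # -> ~/.wetts/multilingual
```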
--------------------------------------------------------------------------------
/wetts/cli/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Binbin Zhang (binbzha@qq.com)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import numpy as np
18 | import onnxruntime as ort
19 |
20 | from wetts.cli.frontend import Frontend
21 | from wetts.cli.hub import Hub
22 |
23 |
24 | class Model:
25 | def __init__(self, backend_dir: str, front_dir: str):
26 | self.frontend = Frontend(front_dir)
27 | self.session = ort.InferenceSession(
28 | os.path.join(backend_dir, 'final.onnx'))
29 | self.phone2id = self.read_table(os.path.join(backend_dir,
30 | 'phones.txt'))
31 | self.speaker2id = self.read_table(
32 | os.path.join(backend_dir, 'speaker.txt'))
33 |
34 | def read_table(self, fname: str):
35 | table = {}
36 | with open(fname) as fin:
37 | for line in fin:
38 | arr = line.split()
39 | assert len(arr) == 2
40 | table[arr[0]] = int(arr[1])
41 | return table
42 |
43 | def synthesis(self, text: str, speaker: str = 'default'):
44 | phonemes = self.frontend.compute(text)
45 | phonemes_id = [self.phone2id[x] for x in phonemes]
46 |         scales = [0.667, 1.0, 0.8]  # noise_scale, length_scale, noise_scale_w
47 | sid = self.speaker2id.get(speaker, 0)
48 | outputs = self.session.run(
49 | None, {
50 | 'input':
51 | np.expand_dims(np.array(phonemes_id), axis=0),
52 | 'input_lengths':
53 | np.array([len(phonemes)]),
54 | 'scales':
55 | np.expand_dims(np.array(scales, dtype=np.float32), axis=0),
56 | 'sid':
57 | np.array([sid]),
58 | })
59 | audio = outputs[0][0][0]
60 | audio = (audio * 32767).astype(np.int16)
61 | return phonemes, audio
62 |
63 |
64 | def load_model():
65 | front_dir = Hub.get_model('frontend')
66 | backend_dir = Hub.get_model('multilingual')
67 | model = Model(backend_dir, front_dir)
68 | return model
69 |
--------------------------------------------------------------------------------
/wetts/cli/tts.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Binbin Zhang (binbzha@qq.com)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 |
17 | import scipy.io.wavfile as wavfile
18 |
19 | from wetts.cli.model import load_model
20 |
21 |
22 | def get_args():
23 | parser = argparse.ArgumentParser(description='')
24 |     parser.add_argument('--text', help='text to synthesize')
25 | parser.add_argument('--wav', help='output wav file')
26 | args = parser.parse_args()
27 | return args
28 |
29 |
30 | def main():
31 | args = get_args()
32 | model = load_model()
33 | phones, audio = model.synthesis(args.text)
34 | wavfile.write(args.wav, 16000, audio)
35 | print('{} => {}'.format(args.text, ' '.join(phones)))
36 |     print('Succeeded, see {}'.format(args.wav))
37 |
38 |
39 | if __name__ == '__main__':
40 | main()
41 |
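With the `wetts` console script from `setup.py` installed (e.g. via `pip install .`), this is invoked as, for example:

```
wetts --text "hello world" --wav out.wav
# or equivalently:
python -m wetts.cli.tts --text "hello world" --wav out.wav
```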
--------------------------------------------------------------------------------
/wetts/frontend/README.md:
--------------------------------------------------------------------------------
1 | # WeTTS Frontend System
2 |
3 | ## Keynotes
4 |
5 | Motivated by [Unified Mandarin TTS Front-end Based on Distilled BERT Model](https://arxiv.org/pdf/2012.15404.pdf),
6 | we aim to provide a simple, production-ready, and unified frontend solution in `wetts`.
7 |
8 |
9 | ## Roadmap
10 |
11 | - [x] Server-side prosody and polyphone prediction based on BERT.
12 | - [ ] On-device prosody and polyphone solution.
13 | - [ ] Joint training with word break and POS tagging to further improve performance (optional).
14 | - [ ] Text normalization solution.
15 |
16 | ## Data Format
17 |
18 | ### Prosody
19 |
20 | The prosody format is as follows, where `#n` denotes the prosody rank.
21 |
22 | ```
23 | 蔡少芬 #2 拍拖 #2 也不认啦 #4
24 | 瓦塔拉 #1 总统 #1 已 #1 下令 #3 坚决 #1 回应 #1 袭击者 #4
25 | ```
26 |
27 | ### Polyphone
28 |
29 | In the training corpus, the pronunciation of a polyphonic character is surrounded by `▁`.
30 |
31 |
32 | ```
33 | 宋代出现了▁le5▁燕乐音阶的记载
34 | 爆发了▁le5▁占领华尔街示威活动
35 | ```
36 |
37 |
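For clarity, a small sketch (not shipped in the repo) of how a prosody-annotated line can be turned into per-character rank labels for training:

```
import re

def parse_prosody(line):
    """'蔡少芬 #2 拍拖 #2 也不认啦 #4' -> (chars, ranks), one rank per char."""
    chars, ranks = [], []
    for token in line.split():
        if re.fullmatch(r"#[0-4]", token):
            if ranks:
                ranks[-1] = int(token[1])  # the break follows the previous char
        else:
            for ch in token:
                chars.append(ch)
                ranks.append(0)  # default: no break (#0)
    return chars, ranks
```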
--------------------------------------------------------------------------------
/wetts/frontend/export_onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, Binbin Zhang (binbzha@qq.com)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import os
17 |
18 | import torch
19 | import onnxruntime as ort
20 |
21 | from model import FrontendModel
22 | from utils import read_table
23 |
24 |
25 | def get_args():
26 |     parser = argparse.ArgumentParser(description="export onnx model")
27 | parser.add_argument("--polyphone_dict", required=True, help="polyphone dict file")
28 |     parser.add_argument("--prosody_dict", required=True, help="prosody dict file")
29 | parser.add_argument("--checkpoint", required=True, help="checkpoint model")
30 | parser.add_argument("--onnx_model", required=True, help="onnx model path")
31 | args = parser.parse_args()
32 | return args
33 |
34 |
35 | def main():
36 | args = get_args()
37 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
38 | polyphone_dict = read_table(args.polyphone_dict)
39 | prosody_dict = read_table(args.prosody_dict)
40 | num_polyphones = len(polyphone_dict)
41 | num_prosody = len(prosody_dict)
42 |
43 | # Init model
44 | model = FrontendModel(num_polyphones, num_prosody)
45 | model.load_state_dict(torch.load(args.checkpoint, map_location="cpu"))
46 | model.forward = model.export_forward
47 | model.eval()
48 |
49 | dummy_input = torch.ones(1, 10, dtype=torch.int64)
50 | torch.onnx.export(
51 | model,
52 | dummy_input,
53 | args.onnx_model,
54 | input_names=["input"],
55 | output_names=["polyphone_output", "prosody_output"],
56 | dynamic_axes={
57 | "input": {1: "T"},
58 | "polyphone_output": {1: "T"},
59 | "prosody_output": {1: "T"},
60 | },
61 | opset_version=13,
62 | verbose=False,
63 | )
64 |
65 | # Verify onnx precision
66 | torch_output = model(dummy_input)
67 | ort_sess = ort.InferenceSession(args.onnx_model)
68 | onnx_output = ort_sess.run(None, {"input": dummy_input.numpy()})
69 | print(torch_output[1])
70 | print(onnx_output[1])
71 | if torch.allclose(
72 | torch_output[0], torch.tensor(onnx_output[0]), atol=1e-3
73 | ) and torch.allclose(torch_output[1], torch.tensor(onnx_output[1]), atol=1e-3):
74 |         print("Export to onnx succeeded!")
75 | else:
76 |         print(
77 |             """Export to onnx succeeded, but pytorch/onnx have different
78 | outputs when given the same input, please check!!!"""
79 |         )
80 |
81 |
82 | if __name__ == "__main__":
83 | main()
84 |
--------------------------------------------------------------------------------
/wetts/frontend/g2p_prosody.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, Yongqiang Li (yongqiangli@alumni.hust.edu.cn)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import os
17 | import numpy as np
18 | import onnxruntime as ort
19 | from transformers import AutoTokenizer
20 |
21 | from hanzi2pinyin import Hanzi2Pinyin
22 |
23 |
24 | tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
25 |
26 |
27 | def get_args():
28 |     parser = argparse.ArgumentParser(description="g2p with prosody prediction")
29 | parser.add_argument("--text", required=True, help="input text")
30 | parser.add_argument("--hanzi2pinyin_file", required=True, help="pinyin dict")
31 | parser.add_argument("--polyphone_file", required=True, help="polyphone dict")
32 | parser.add_argument(
33 | "--polyphone_prosody_model", required=True, help="checkpoint model"
34 | )
35 | args = parser.parse_args()
36 | return args
37 |
38 |
39 | class Frontend(object):
40 | def __init__(
41 | self,
42 | hanzi2pinyin_file: str,
43 | polyphone_prosody_model: str,
44 | polyphone_file: str,
45 | ):
46 | self.hanzi2pinyin = Hanzi2Pinyin(hanzi2pinyin_file)
47 | self.ppm_sess = ort.InferenceSession(polyphone_prosody_model)
48 | self.polyphone_dict = []
49 | self.polyphone_character_dict = []
50 | with open(polyphone_file) as pp_f:
51 | for line in pp_f.readlines():
52 | self.polyphone_dict.append(line.strip())
53 |
54 |     def g2p(self, x):
55 |         # polyphone disambiguation & prosody prediction
56 |         tokens = tokenizer(list(x), is_split_into_words=True,
57 |                            return_tensors="np")["input_ids"]
58 |         ort_inputs = {"input": tokens}
59 |         ort_outs = self.ppm_sess.run(None, ort_inputs)
60 |         prosody_pred = ort_outs[1].argmax(-1)[0][1:-1]
61 |         pinyin = []
62 |         for i, char in enumerate(x):
63 |             prons = self.hanzi2pinyin.get(char)
64 |             if len(prons) > 1:
65 |                 polyphone_ids = []
66 |                 # The predicted probability for each pronunciation of the polyphone.
67 |                 preds = []
68 |                 for pron in prons:
69 |                     index = self.polyphone_dict.index(pron)
70 |                     polyphone_ids.append(index)
71 |                     # i + 1 skips the [CLS] token prepended by the tokenizer.
72 |                     preds.append(ort_outs[0][0][i + 1][index])
73 |                 preds = np.array(preds)
74 |                 best = polyphone_ids[preds.argmax(-1)]
75 |                 pinyin.append(self.polyphone_dict[best])
76 |             else:
77 |                 pinyin.append(prons[0])
78 |         return pinyin, prosody_pred
79 |
80 |
81 | def main():
82 | args = get_args()
83 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
84 |
85 | frontend = Frontend(
86 | args.hanzi2pinyin_file, args.polyphone_prosody_model, args.polyphone_file
87 | )
88 | pinyin, prosody = frontend.g2p(args.text)
89 | print("text: {} \npinyin {} \nprosody {}".format(args.text, pinyin, prosody))
90 |
91 |
92 | if __name__ == "__main__":
93 | main()
94 |
--------------------------------------------------------------------------------
/wetts/frontend/hanzi2pinyin.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, Yongqiang Li (yongqiangli@alumni.hust.edu.cn)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | class Hanzi2Pinyin:
17 | def __init__(self, dict_file: str):
18 | self.pinyin_dict = {}
19 | with open(dict_file) as f:
20 | for line in f.readlines():
21 | line = line.strip()
22 | self.pinyin_dict[line[0]] = line[2:].split(",")
23 |
24 | def get(self, x):
25 | assert x in self.pinyin_dict
26 | return self.pinyin_dict[x]
27 |
28 | def convert(self, x: str):
29 | pinyin = []
30 | for char in x:
31 |             pinyin.append(self.pinyin_dict.get(char, "UNK"))  # NOTE: "UNK" is a string, not a list
32 | return pinyin
33 |
34 |
35 | def main():
36 | hanzi2pinyin = Hanzi2Pinyin("local/pinyin_dict.txt")
37 | string = "汉字转拼音实验"
38 | pinyin = hanzi2pinyin.convert(string)
39 | print(string)
40 | print(pinyin)
41 |
42 |
43 | if __name__ == "__main__":
44 | main()
45 |
--------------------------------------------------------------------------------
/wetts/frontend/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, Binbin Zhang (binbzha@qq.com)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | from transformers import AutoModel
19 |
20 |
21 | class FrontendModel(nn.Module):
22 | def __init__(self, num_polyphones: int, num_prosody: int):
23 | super(FrontendModel, self).__init__()
24 | self.bert = AutoModel.from_pretrained("bert-base-chinese")
25 | for param in self.bert.parameters():
26 | param.requires_grad_(False)
27 | self.transform = nn.TransformerEncoderLayer(
28 | d_model=768, nhead=8, dim_feedforward=2048, batch_first=True
29 | )
30 | self.phone_classifier = nn.Linear(768, num_polyphones)
31 | self.prosody_classifier = nn.Linear(768, num_prosody)
32 |
33 | def _forward(self, x):
34 | mask = x["attention_mask"] == 0
35 | bert_output = self.bert(**x)
36 | x = self.transform(bert_output.last_hidden_state, src_key_padding_mask=mask)
37 | phone_pred = self.phone_classifier(x)
38 | prosody_pred = self.prosody_classifier(x)
39 | return phone_pred, prosody_pred
40 |
41 | def forward(self, x):
42 | return self._forward(x)
43 |
44 | def export_forward(self, x):
45 | assert x.size(0) == 1
46 | x = {
47 | "input_ids": x,
48 | "token_type_ids": torch.zeros(1, x.size(1), dtype=torch.int64),
49 | "attention_mask": torch.ones(1, x.size(1), dtype=torch.int64),
50 | }
51 | phone_logits, prosody_logits = self._forward(x)
52 | phone_pred = F.softmax(phone_logits, dim=-1)
53 | prosody_pred = F.softmax(prosody_logits, dim=-1)
54 | return phone_pred, prosody_pred
55 |
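A shape sanity check for the export path (a sketch; the sizes are illustrative, and `bert-base-chinese` is downloaded on first run):

```
import torch
from model import FrontendModel

model = FrontendModel(num_polyphones=100, num_prosody=5)  # illustrative sizes
model.eval()
with torch.no_grad():
    ids = torch.ones(1, 10, dtype=torch.int64)
    phone_pred, prosody_pred = model.export_forward(ids)
print(phone_pred.shape, prosody_pred.shape)  # (1, 10, 100) and (1, 10, 5)
```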
--------------------------------------------------------------------------------
/wetts/frontend/test_polyphone.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, Binbin Zhang (binbzha@qq.com)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import os
17 |
18 | import torch
19 | from torch.utils.data import DataLoader
20 | from tqdm import tqdm
21 |
22 | from dataset import FrontendDataset, collote_fn, IGNORE_ID
23 | from model import FrontendModel
24 | from utils import read_table
25 |
26 |
27 | def get_args():
28 |     parser = argparse.ArgumentParser(description="test polyphone accuracy")
29 | parser.add_argument("--polyphone_dict", required=True, help="polyphone dict file")
30 |     parser.add_argument("--prosody_dict", required=True, help="prosody dict file")
31 | parser.add_argument("--test_data", required=True, help="test data file")
32 | parser.add_argument("--batch_size", type=int, default=32, help="batch size")
33 | parser.add_argument("--checkpoint", required=True, help="checkpoint model")
34 | args = parser.parse_args()
35 | return args
36 |
37 |
38 | def main():
39 | args = get_args()
40 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
41 | polyphone_dict = read_table(args.polyphone_dict)
42 | prosody_dict = read_table(args.prosody_dict)
43 | num_polyphones = len(polyphone_dict)
44 | num_prosody = len(prosody_dict)
45 |
46 | test_data = FrontendDataset(
47 | polyphone_file=args.test_data, polyphone_dict=polyphone_dict
48 | )
49 | test_dataloader = DataLoader(
50 | test_data, batch_size=args.batch_size, collate_fn=collote_fn
51 | )
52 | # Init model
53 | model = FrontendModel(num_polyphones, num_prosody)
54 | model.load_state_dict(torch.load(args.checkpoint, map_location="cpu"))
55 |
56 | model.eval()
57 | num_total = 0
58 | num_correct = 0
59 | with torch.no_grad():
60 | pbar = tqdm(total=len(test_dataloader))
61 | for _, (inputs, labels, _) in enumerate(test_dataloader):
62 | logits, _ = model(inputs)
63 | mask = labels != IGNORE_ID
64 | num_total += torch.sum(mask)
65 | pred = logits.argmax(-1)
66 | equal = (pred == labels) * mask
67 | num_correct += torch.sum(equal)
68 | pbar.update(1)
69 | pbar.close()
70 |
71 | print("Accuracy: {}".format(num_correct / num_total))
72 |
73 |
74 | if __name__ == "__main__":
75 | main()
76 |
--------------------------------------------------------------------------------
/wetts/frontend/test_prosody.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, Binbin Zhang (binbzha@qq.com)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import os
17 |
18 | from sklearn.metrics import f1_score
19 | import torch
20 | from torch.utils.data import DataLoader
21 | from tqdm import tqdm
22 |
23 | from dataset import FrontendDataset, collote_fn, IGNORE_ID
24 | from model import FrontendModel
25 | from utils import read_table
26 |
27 |
28 | def get_args():
29 |     parser = argparse.ArgumentParser(description="test prosody prediction")
30 | parser.add_argument("--polyphone_dict", required=True, help="polyphone dict file")
31 |     parser.add_argument("--prosody_dict", required=True, help="prosody dict file")
32 | parser.add_argument("--test_data", required=True, help="test data file")
33 | parser.add_argument("--batch_size", type=int, default=32, help="batch size")
34 | parser.add_argument("--checkpoint", required=True, help="checkpoint model")
35 | parser.add_argument(
36 | "--exclude_end",
37 | action="store_true",
38 | default=False,
39 | help="the prosody break at the end of sentence is not counted",
40 | )
41 | args = parser.parse_args()
42 | return args
43 |
44 |
45 | def main():
46 | args = get_args()
47 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
48 | polyphone_dict = read_table(args.polyphone_dict)
49 | prosody_dict = read_table(args.prosody_dict)
50 | num_polyphones = len(polyphone_dict)
51 | num_prosody = len(prosody_dict)
52 |
53 | test_data = FrontendDataset(prosody_file=args.test_data, prosody_dict=prosody_dict)
54 | test_dataloader = DataLoader(
55 | test_data, batch_size=args.batch_size, collate_fn=collote_fn
56 | )
57 | # Init model
58 | model = FrontendModel(num_polyphones, num_prosody)
59 | model.load_state_dict(torch.load(args.checkpoint, map_location="cpu"))
60 |
61 | model.eval()
62 | with torch.no_grad():
63 | pbar = tqdm(total=len(test_dataloader))
64 | pred = []
65 | label = []
66 | for _, (inputs, _, labels) in enumerate(test_dataloader):
67 | _, logits = model(inputs)
68 | mask = labels != IGNORE_ID
69 | lengths = torch.sum(mask, dim=1)
70 | for i in range(logits.size(0)):
71 | # Remove padding
72 | if args.exclude_end:
73 | pred.extend(logits[i][1 : lengths[i], :].argmax(-1).tolist())
74 | label.extend(labels[i][1 : lengths[i]].tolist())
75 | else:
76 | pred.extend(logits[i][1 : lengths[i] + 1, :].argmax(-1).tolist())
77 | label.extend(labels[i][1 : lengths[i] + 1].tolist())
78 | pbar.update(1)
79 | pw_f1_score = f1_score(
80 | [1 if x > 0 else 0 for x in label], [1 if x > 0 else 0 for x in pred]
81 | )
82 | pph_f1_score = f1_score(
83 | [1 if x > 1 else 0 for x in label], [1 if x > 1 else 0 for x in pred]
84 | )
85 | iph_f1_score = f1_score(
86 | [1 if x > 2 else 0 for x in label], [1 if x > 2 else 0 for x in pred]
87 | )
88 | print(
89 | "pw f1_score {} pph f1_score {} iph f1_score {}".format(
90 | pw_f1_score, pph_f1_score, iph_f1_score
91 | )
92 | )
93 | pbar.close()
94 |
95 |
96 | if __name__ == "__main__":
97 | main()
98 |
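The three scores above binarize the predicted rank at increasing thresholds: prosodic word (PW) counts any break of rank `#1` or higher, prosodic phrase (PPH) `#2` or higher, and intonational phrase (IPH) `#3` or higher. Restated compactly (assuming rank labels 0-4 index `#0`-`#4`):

```
ranks = [0, 1, 2, 3, 4]
pw  = [int(x > 0) for x in ranks]  # #1 and above -> prosodic word boundary
pph = [int(x > 1) for x in ranks]  # #2 and above -> prosodic phrase boundary
iph = [int(x > 2) for x in ranks]  # #3 and above -> intonational phrase boundary
```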
--------------------------------------------------------------------------------
/wetts/frontend/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, Binbin Zhang (binbzha@qq.com)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | def read_table(dict_file):
17 |     table = {}
18 |     with open(dict_file, "r", encoding="utf8") as fin:
19 |         for idx, line in enumerate(fin):
20 |             table[line.strip()] = idx
21 |     return table
22 |
--------------------------------------------------------------------------------
/wetts/vits/model/encoders.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from model.attentions import Encoder
7 | from model.modules import WN
8 | from utils import commons
9 |
10 |
11 | class TextEncoder(nn.Module):
12 |
13 | def __init__(
14 | self,
15 | n_vocab,
16 | out_channels,
17 | hidden_channels,
18 | filter_channels,
19 | n_heads,
20 | n_layers,
21 | kernel_size,
22 | p_dropout,
23 | gin_channels=0,
24 | ):
25 | super().__init__()
26 | self.n_vocab = n_vocab
27 | self.out_channels = out_channels
28 | self.hidden_channels = hidden_channels
29 | self.filter_channels = filter_channels
30 | self.n_heads = n_heads
31 | self.n_layers = n_layers
32 | self.kernel_size = kernel_size
33 | self.p_dropout = p_dropout
34 | self.gin_channels = gin_channels
35 | self.emb = nn.Embedding(n_vocab, hidden_channels)
36 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
37 |
38 | self.encoder = Encoder(hidden_channels,
39 | filter_channels,
40 | n_heads,
41 | n_layers,
42 | kernel_size,
43 | p_dropout,
44 | gin_channels=self.gin_channels)
45 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
46 |
47 | def forward(self, x, x_lengths, g=None):
48 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
49 | x = torch.transpose(x, 1, -1) # [b, h, t]
50 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)),
51 | 1).to(x.dtype)
52 |
53 | x = self.encoder(x * x_mask, x_mask, g=g)
54 | stats = self.proj(x) * x_mask
55 |
56 | m, logs = torch.split(stats, self.out_channels, dim=1)
57 | return x, m, logs, x_mask
58 |
59 |
60 | class PosteriorEncoder(nn.Module):
61 |
62 | def __init__(
63 | self,
64 | in_channels,
65 | out_channels,
66 | hidden_channels,
67 | kernel_size,
68 | dilation_rate,
69 | n_layers,
70 | gin_channels,
71 | ):
72 | super().__init__()
73 | self.in_channels = in_channels
74 | self.out_channels = out_channels
75 | self.hidden_channels = hidden_channels
76 | self.kernel_size = kernel_size
77 | self.dilation_rate = dilation_rate
78 | self.n_layers = n_layers
79 | self.gin_channels = gin_channels
80 |
81 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
82 | self.enc = WN(
83 | hidden_channels,
84 | kernel_size,
85 | dilation_rate,
86 | n_layers,
87 | gin_channels=gin_channels,
88 | )
89 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
90 |
91 | def forward(self, x, x_lengths, g=None):
92 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)),
93 | 1).to(x.dtype)
94 | x = self.pre(x) * x_mask
95 | x = self.enc(x, x_mask, g=g)
96 | stats = self.proj(x) * x_mask
97 | m, logs = torch.split(stats, self.out_channels, dim=1)
98 | z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
99 | return z, m, logs, x_mask
100 |
--------------------------------------------------------------------------------
/wetts/vits/model/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn.utils import weight_norm, remove_weight_norm
4 |
5 | from utils import commons
6 |
7 | LRELU_SLOPE = 0.1
8 |
9 |
10 | class WN(nn.Module):
11 |
12 | def __init__(
13 | self,
14 | hidden_channels,
15 | kernel_size,
16 | dilation_rate,
17 | n_layers,
18 | gin_channels=0,
19 | p_dropout=0,
20 | ):
21 | super(WN, self).__init__()
22 | assert kernel_size % 2 == 1
23 | self.hidden_channels = hidden_channels
24 | self.kernel_size = (kernel_size, )
25 | self.dilation_rate = dilation_rate
26 | self.n_layers = n_layers
27 | self.gin_channels = gin_channels
28 | self.p_dropout = p_dropout
29 |
30 | self.in_layers = nn.ModuleList()
31 | self.res_skip_layers = nn.ModuleList()
32 | self.drop = nn.Dropout(p_dropout)
33 |
34 | cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
35 | self.cond_layer = weight_norm(cond_layer)
36 |
37 | for i in range(n_layers):
38 | dilation = dilation_rate**i
39 | padding = int((kernel_size * dilation - dilation) / 2)
40 | in_layer = nn.Conv1d(
41 | hidden_channels,
42 | 2 * hidden_channels,
43 | kernel_size,
44 | dilation=dilation,
45 | padding=padding,
46 | )
47 | in_layer = weight_norm(in_layer)
48 | self.in_layers.append(in_layer)
49 |
50 | # last one is not necessary
51 | if i < n_layers - 1:
52 | res_skip_channels = 2 * hidden_channels
53 | else:
54 | res_skip_channels = hidden_channels
55 |
56 | res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1)
57 | res_skip_layer = weight_norm(res_skip_layer)
58 | self.res_skip_layers.append(res_skip_layer)
59 |
60 | def forward(self, x, x_mask, g=None, **kwargs):
61 | output = torch.zeros_like(x)
62 | n_channels_tensor = torch.IntTensor([self.hidden_channels])
63 |
64 | if g is not None:
65 | g = self.cond_layer(g)
66 |
67 | for i in range(self.n_layers):
68 | x_in = self.in_layers[i](x)
69 | if g is not None:
70 | cond_offset = i * 2 * self.hidden_channels
71 | g_l = g[:,
72 | cond_offset:cond_offset + 2 * self.hidden_channels, :]
73 | else:
74 | g_l = torch.zeros_like(x_in)
75 |
76 | acts = commons.fused_add_tanh_sigmoid_multiply(
77 | x_in, g_l, n_channels_tensor)
78 | acts = self.drop(acts)
79 |
80 | res_skip_acts = self.res_skip_layers[i](acts)
81 | if i < self.n_layers - 1:
82 | res_acts = res_skip_acts[:, :self.hidden_channels, :]
83 | x = (x + res_acts) * x_mask
84 | output = output + res_skip_acts[:, self.hidden_channels:, :]
85 | else:
86 | output = output + res_skip_acts
87 | return output * x_mask
88 |
89 | def remove_weight_norm(self):
90 | if self.gin_channels != 0:
91 | remove_weight_norm(self.cond_layer, "weight")
92 | for l in self.in_layers:
93 | remove_weight_norm(l, "weight")
94 | for l in self.res_skip_layers:
95 | remove_weight_norm(l, "weight")
96 |
97 |
98 | class Flip(nn.Module):
99 |
100 | def forward(self, x, *args, reverse=False, **kwargs):
101 | x = torch.flip(x, [1])
102 | if not reverse:
103 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
104 | return x, logdet
105 | else:
106 | return x
107 |
--------------------------------------------------------------------------------
/wetts/vits/model/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class LayerNorm(nn.Module):
7 |
8 | def __init__(self, channels, eps=1e-5):
9 | super().__init__()
10 | self.channels = channels
11 | self.eps = eps
12 |
13 | self.gamma = nn.Parameter(torch.ones(channels))
14 | self.beta = nn.Parameter(torch.zeros(channels))
15 |
16 | def forward(self, x):
17 | x = x.transpose(1, -1)
18 | x = F.layer_norm(x, (self.channels, ), self.gamma, self.beta, self.eps)
19 | return x.transpose(1, -1)
20 |
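This `LayerNorm` normalizes the channel dimension of `[B, C, T]` tensors, which is what the two transposes are for. A quick check (a sketch, using `LayerNorm` from this file):

```
import torch

ln = LayerNorm(192)
x = torch.randn(2, 192, 50)  # [batch, channels, time]
print(ln(x).shape)           # torch.Size([2, 192, 50])
```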
--------------------------------------------------------------------------------
/wetts/vits/utils/monotonic_align.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numba
3 | import numpy as np
4 |
5 |
6 | def maximum_path(neg_cent: torch.Tensor, mask: torch.Tensor):
7 | """numba optimized version.
8 | neg_cent: [b, t_t, t_s]
9 | mask: [b, t_t, t_s]
10 | """
11 | device = neg_cent.device
12 | dtype = neg_cent.dtype
13 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32)
14 | path = np.zeros(neg_cent.shape, dtype=np.int32)
15 |
16 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
17 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
18 | maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
19 | return torch.from_numpy(path).to(device=device, dtype=dtype)
20 |
21 |
22 | @numba.jit(numba.void(numba.int32[:, :, ::1], numba.float32[:, :, ::1],
23 | numba.int32[::1], numba.int32[::1]),
24 | nopython=True,
25 | nogil=True)
26 | def maximum_path_jit(paths, values, t_ys, t_xs):
27 | b = paths.shape[0]
28 | max_neg_val = -1e9
29 | for i in range(int(b)):
30 | path = paths[i]
31 | value = values[i]
32 | t_y = t_ys[i]
33 | t_x = t_xs[i]
34 |
35 | v_prev = v_cur = 0.0
36 | index = t_x - 1
37 |
38 | for y in range(t_y):
39 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
40 | if x == y:
41 | v_cur = max_neg_val
42 | else:
43 | v_cur = value[y - 1, x]
44 | if x == 0:
45 | if y == 0:
46 | v_prev = 0.0
47 | else:
48 | v_prev = max_neg_val
49 | else:
50 | v_prev = value[y - 1, x - 1]
51 | value[y, x] += max(v_prev, v_cur)
52 |
53 | for y in range(t_y - 1, -1, -1):
54 | path[y, index] = 1
55 | if index != 0 and (index == y or value[y - 1, index]
56 | < value[y - 1, index - 1]):
57 | index = index - 1
58 |
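A usage sketch for `maximum_path` (shapes follow the docstring; the mask marks the valid text/speech positions, and the example assumes at least as many speech frames as text tokens):

```
import torch

b, t_t, t_s = 1, 6, 4                # batch, speech frames, text tokens
neg_cent = torch.randn(b, t_t, t_s)
mask = torch.ones(b, t_t, t_s)
path = maximum_path(neg_cent, mask)  # hard monotonic alignment, [b, t_t, t_s]
print(path.sum(dim=1))               # frames assigned to each text token
```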
--------------------------------------------------------------------------------